K-Means under the hood with Python

Hello World!

This article is meant to explain how the K-Means clustering algorithm works while teaching you a little Python along the way.

What is K-Means?

K-Means Clustering is an unsupervised learning algorithm that tells you how similar observations are by putting them into groups, or “clusters”.  A centroid is the center of a cluster or group.  K-Means is often used as a discovery step on new data: you first discover what the various categories might be, and once you understand the centroid labels you can apply something such as a k-nearest-neighbor classifier on top of them.

How does K-Means work?

K-Means takes in an unlabeled data set and a whole number, k.  K is the number of centroids, or clusters, you wish to find.  If you do not know how many clusters there should be, there is pre-processing you can do to estimate that more automatically, however that is out of the scope of this article.  Once you have a data set and have chosen k, K-Means begins its iterative process.  It starts by selecting k initial centroids.  It then assigns every data point to the closest centroid, moves each centroid to the average of the data assigned to it, and reshuffles all of the data into new groups based on proximity to the moved centroids.  This assign-and-move loop repeats until the assignments stop changing.
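To make that loop concrete before we build it up piece by piece, here is a minimal, vectorized sketch of the whole algorithm.  This is my own illustration for orientation only, not the code used in this article; the walkthrough below spells each step out by hand.

import numpy as np

def kmeans_sketch(data, k, iterations=10):
    # Minimal K-Means sketch: data is an n x 2 array, k is the cluster count.
    # Start by picking k existing points as the initial centroids.
    centroids = data[np.random.choice(len(data), k, replace=False)]
    for _ in range(iterations):
        # Assignment step: distance from every point to every centroid,
        # then the index of the nearest centroid for each point.
        distances = np.linalg.norm(data[:, None, :] - centroids[None, :, :], axis=2)
        labels = distances.argmin(axis=1)
        # Update step: move each centroid to the mean of the points assigned to it.
        # (A cluster that ends up empty would need extra handling in real code.)
        centroids = np.array([data[labels == j].mean(axis=0) for j in range(k)])
    return centroids, labels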

Coding Pre-requisites

I use Visual Studio Community 2015 with the Python Tools extensions.  I find it easy for interactive coding, syntax coloring, deployments, and environment management.  Secondly, you will need two packages installed with pip: plotly and numpy.  In Visual Studio, this is done simply by selecting the "Tools" menu at the very top, navigating to "Python Tools" and then "Python Environments".  You can alternatively press Ctrl + K followed by Ctrl + `.  A new window will appear.

[Screenshot kmeans_1: the Python Environments window]

Once here, select your environment, click the drop-down, and select pip.  Type "install " plus the package name, then click the suggestion that pops up, highlighted below.

[Screenshot kmeans_2: installing a package with pip from the Python Environments window]

This will install the package numpy into the highlighted environment, excellent.  Repeat this for plotly.

What is Plotly and Numpy?

Plotly is a wonderful interactive plotting library with wrappers in multiple languages.  The biggest draw is that it produces graphics built from HTML5, JavaScript, and CSS, and can therefore be used in any modern environment.  Those graphics are easily tweaked to pair a data-connected back end with a pure JS front end, or to fit right into client applications.  Numpy is a numerical computing library with a variety of linear algebra functions.  Numpy is very commonly used in machine learning with Python, as is pandas.

Show me some code already!

In this section, I am going to show a section of code at a time with explanation, and then at the very end, I will paste the entire file so it is easily copy/pasted.  This article will also include several charts to better explain what is happening.  I find visuals to be more explanatory than words.

import math
import numpy as np
import plotly as plotly
from plotly.graph_objs import Scatter, Layout

Here we have some stock import statements.  Basically, we can refer to numpy as np and plotly as plotly, and whatever we import from plotly.graph_objs we can reference simply as Scatter and Layout, without needing to bother with go.Scatter or similar prefixes.

Functions we Need

def CalculateDistance(p1,p2):
    xDiff = (p2[0,0] - p1[0,0])**2
    yDiff = (p2[0,1] - p1[0,1])**2
    return math.sqrt(xDiff + yDiff)

The distance between any two points is sqrt((x1 - x2)^2 + (y1 - y2)^2).  Here we take in two numpy arrays with the shape 1x2 (one row and two columns); index [0,0] is the x location of each point and [0,1] is the y location.  We square the x and y differences, sum them, and return the square root.
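As a quick sanity check (assuming CalculateDistance above is already defined in your session), the classic 3-4-5 triangle should come back as 5:

import numpy as np

p1 = np.array([[0.0, 0.0]])   # 1x2 point at the origin
p2 = np.array([[3.0, 4.0]])   # 1x2 point at (3, 4)
print(CalculateDistance(p1, p2))  # 5.0, since sqrt(3^2 + 4^2) = 5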

Next is a helper function for defining a Trace, or a mark on a plotly plot.  In this case we are defaulting to a scatter point.

def CreateScatterTrace(data, color, marker='point', size = 10):
    trace = Scatter(
        x = data[:,0],
        y = data[:,1],
        mode = 'markers',
        marker = dict(
            size = size,
            color = color,
            symbol = marker,
            line = dict(
                width = 2)
            )
    )
    return trace

The parameter data is a numpy array of shape n x 2, so n rows and 2 columns, where column 0 is x and column 1 is y. Remember, Python is 0-indexed.  The parameter=value syntax, such as marker='point', is how Python declares optional arguments with default values.
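If the shape talk feels abstract, a tiny hypothetical array makes it concrete; this also shows how the optional arguments can be left at their defaults or overridden by name.

import numpy as np

points = np.array([[1.0, 2.0],   # row 0: x=1, y=2
                   [3.0, 4.0],   # row 1: x=3, y=4
                   [5.0, 6.0]])  # row 2: x=5, y=6
print(points[:, 0])  # all x values: [1. 3. 5.]
print(points[:, 1])  # all y values: [2. 4. 6.]

# Defaults used as-is, then overridden by name.
trace_default = CreateScatterTrace(points, 'rgba(152, 0, 0, .8)')
trace_stars = CreateScatterTrace(points, 'rgba(152, 0, 0, .8)', marker='star', size=15)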

def Plot2Categories(cat1Data, cat2Data):
    traceCat1 = CreateScatterTrace(cat1Data, 'rgba(152, 0, 0, .8)')
    traceCat2 = CreateScatterTrace(cat2Data, 'rgba(0, 152, 0, .8)')
    plotly.offline.plot({
        "data": [traceCat1, traceCat2],
        "layout": Layout(title="2 Fake Categories")
    })

This function plots 2 sets of data points as distinct categories, visualized in red and green, using plotly's offline mode.  Offline mode lets us generate everything we need and display it interactively without being connected to a plotting server.  Executing this function will bring up a browser window with an interactive chart.

def PlotSingleCategory(data):
    trace = CreateScatterTrace(data, 'rgba(152, 0, 0, .8)')
    plotly.offline.plot({
        "data": [trace],
        "layout": Layout(title="Unlabeled Data")
    })

PlotSingleCategory is a very simple scatter plot, plotting one set of points all in red.

def PlotKMeansSingle(data, cen1, cen2):
    traceCat1 = CreateScatterTrace(data, 'rgba(0, 0, 152, .8)')
    traceCen1 = CreateScatterTrace(cen1, 'rgba(152, 0, 0, .8)', 'star', 15)
    traceCen2 = CreateScatterTrace(cen2, 'rgba(0, 152, 0, .8)', 'star', 15)
    plotly.offline.plot({
        "data": [traceCat1, traceCen1, traceCen2],
        "layout": Layout(title="KMeans")
    })

The above function simply showcases what K-Means looks like after the initial selection of centroids, but before the data has been assigned to clusters.

def PlotKMeans(cat1Data, cat2Data, cen1, cen2):
    traceCat1 = CreateScatterTrace(cat1Data, 'rgba(152, 0, 0, .8)')
    traceCat2 = CreateScatterTrace(cat2Data, 'rgba(0, 152, 0, .8)')
    traceCen1 = CreateScatterTrace(cen1, 'rgba(152, 0, 0, .8)', 'star', 15)
    traceCen2 = CreateScatterTrace(cen2, 'rgba(0, 152, 0, .8)', 'star', 15)
    plotly.offline.plot({
        "data": [traceCat1, traceCat2, traceCen1, traceCen2],
        "layout": Layout(title="KMeans")
    })

Finally, the above code plots a two-category K-Means view: the data for each category as red and green dots, and the corresponding centroids as red and green stars.

Now the Interesting Bit

Now we can get on to showcasing K-Means.  This is intended to be a showcase with plotting and discovery, not a production implementation; that has been done a million times over and already exists in several popular libraries.

cat1 = 5.5
cat2 = 3.5
#Create 2 A-Symmetrical Gaussian Distributions
cat1Data = np.column_stack((np.random.normal(cat1, 1.0, 50), 
                            np.random.normal(cat1, 1.0, 50)))
cat2Data = np.column_stack((np.random.normal(cat2, 1.0, 50), 
                            np.random.normal(cat2, 1.0, 50)))

What we are doing here is creating 2 different asymmetrical Gaussian distributions.  What in the world is an asymmetrical Gaussian distribution?  There is a great article on what a Gaussian distribution is; basically it gives us a nice healthy set of random data with a pretty bell-curve shape for demonstration purposes.  The asymmetrical bit is that I generate two different draws for the x and y values to get a bit of extra variance.  Most demonstrations reuse the same generated values for both axes, which gives you symmetrical points such as (1,1) or (2,2); with independent draws we get points like (1,2) and (5,4) instead.
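If the shapes here are hard to picture, the same construction on a smaller scale looks like this (illustration only, using 5 samples instead of the 50 above):

import numpy as np

xs = np.random.normal(5.5, 1.0, 5)  # 5 x-values drawn around a mean of 5.5
ys = np.random.normal(5.5, 1.0, 5)  # 5 independent y-values around the same mean
sample = np.column_stack((xs, ys))  # pairs them up into a 5x2 array of (x, y) points
print(sample.shape)                 # (5, 2)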

We generate two sets of data as known categories because the point of K-Means is to pick the correct k and identify groups of similar items.  By drawing from Gaussian distributions with two different means, we bake those groups in, so let's go ahead and visualize them.

#Plot 2 categories known
Plot2Categories(cat1Data, cat2Data)

[Interactive plot of the two categories]

As you can see we get a nice distribution of 2 distinct categories with some intermingling of the data crossing over the decision boundary.  Good demonstration data.

Now, remember that K-Means is an unsupervised algorithm, yet we have generated two labeled data sets at this point.  Let's merge the two together so the labels are gone and plot the data as K-Means would see it.

#merge categories and display as k-means would see them.
mergedData = np.concatenate((cat1Data, cat2Data))
PlotSingleCategory(mergedData)

np.concatenate simply appends two compatibly shaped data sets together; a tiny illustration follows, and below that is the plot of what K-Means would actually see.
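Here is that append on toy arrays (not the demo data):

import numpy as np

a = np.array([[1.0, 1.0], [2.0, 2.0]])   # 2x2
b = np.array([[3.0, 3.0], [4.0, 4.0]])   # 2x2
print(np.concatenate((a, b)).shape)      # (4, 2): rows of b stacked under rows of a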

[Plot: the merged, unlabeled data]

This presents a very different machine learning problem, which is why we use K-Means.  We know there are two categories, so we will proceed with that knowledge.

#pick 2 initial centroids intelligently
xSplit = np.percentile(mergedData[:,0], 50.0)
ySplit = np.percentile(mergedData[:,1], 50.0)
xStd = np.std(mergedData[:,0])
yStd = np.std(mergedData[:,1])
cat1Cen = np.array([xSplit - xStd, ySplit - yStd]).reshape([1,2])
cat2Cen = np.array([xSplit + xStd, ySplit + yStd]).reshape([1,2])
#Do dumber centroid selection, for longer convergence
#This is for demonstration
cat1Cen = np.array([1,1]).reshape([1,2])
cat2Cen = np.array([7,7]).reshape([1,2])

The first thing to do is pick our 2 centroids.  Here you can see two variations on how to do it.  For the demo plots in this article I used the second, dumber option, as it takes longer to converge and makes the iterations easier to see.  The first option ends up almost precisely selecting the true centroids, simply due to how the math works out: the 50th percentile is the median of the merged data, and stepping one standard deviation below and above it lands close to the two underlying means.  It is worth reading through that initialization a few times, if you are interested, while thinking about statistical percentile ranges and standard deviation.
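To see what the "intelligent" variant is doing, here is the same idea on a tiny made-up vector (toy numbers, not the demo data):

import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(np.percentile(x, 50.0))  # 3.0 -- the 50th percentile is the median
print(np.std(x))               # ~1.414 -- typical spread around the mean
# One centroid starts a standard deviation below the median and the other a
# standard deviation above it, so the two starting points straddle the data.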

#Plot before category assignments
PlotKMeansSingle(mergedData, cat1Cen, cat2Cen)

Ok, let's go ahead and plot where we end up, prior to assigning the data points to categories.

[Plot: unlabeled data with the two initial centroids shown as stars at (1,1) and (7,7)]

Excellent, we can see the centroids, represented by stars at 1,1 and 7,7.  Not quite the ideal locations, but in proper order.

#assign points to categories
cat1 = np.empty([2,])
cat2 = np.empty([2,])
for i in range(0, mergedData[:,0].size):
    distToCen1 = CalculateDistance(cat1Cen, mergedData[i,:].reshape([1,2]))
    distToCen2 = CalculateDistance(cat2Cen, mergedData[i,:].reshape([1,2]))
    if distToCen1 < distToCen2 :
        cat1 = np.concatenate((cat1, mergedData[i,:]))
    else:
        cat2 = np.concatenate((cat2, mergedData[i,:]))

This is the (re-)assignment portion of K-Means.  We take all of the data, un-assign it (it is already un-assigned here), calculate the distance from each point to each centroid, and assign the point to the closest centroid.  Here we favor centroid 2 in the case of a tie, because we use < rather than <= in the if statement.  Notice also that we create the arrays outside of the loop so that they still hold the full accumulated data when the loop exits.
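For comparison only (this is not the code the article uses), the same assignment step can be written with plain Python lists, which sidesteps the placeholder cleanup that follows:

cat1_rows, cat2_rows = [], []
for row in mergedData:
    point = row.reshape([1, 2])
    if CalculateDistance(cat1Cen, point) < CalculateDistance(cat2Cen, point):
        cat1_rows.append(row)
    else:
        cat2_rows.append(row)
# np.array(cat1_rows) and np.array(cat2_rows) would already be n x 2 arrays.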

#endfor
#delete the first two placeholder values left over from np.empty
#Also just reshape while we are at it.
cat1 = np.delete(cat1, (0,1))
cat1 = cat1.reshape([cat1.size // 2, 2])
cat2 = np.delete(cat2, (0,1))
cat2 = cat2.reshape([cat2.size // 2, 2])

#Plot first round
PlotKMeans(cat1, cat2, cat1Cen, cat2Cen)

Here we do a little cleaning up.  Creating an array with np.empty and then concatenating onto it leaves two uninitialized placeholder values at the front, so we remove those and reshape each array into an n x 2 matrix.  Since we do not know how many points ended up with each centroid, we use the array's size property and divide by the number of desired columns (2); a small reshape illustration follows.  The plot we end up with looks as below.
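Here is the reshape on a toy flat array, to show why dividing the size by 2 gives the right number of rows:

import numpy as np

flat = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # 6 values end to end
print(flat.reshape([flat.size // 2, 2]))         # 3 rows of (x, y) pairs
# [[1. 2.]
#  [3. 4.]
#  [5. 6.]]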

[Plot: first-round assignments in red and green, with the centroids as stars]

Alright, that looks like we are starting down a decent path.  We have our initial centroids and categories.  They are starting semi-close to final, but not quite there, so let's see what happens if we iterate the loop again.

#Move centroids to average of data they have
cat1Cen = np.array([np.mean(cat1[:,0]), np.mean(cat1[:,1])]).reshape([1,2])
cat2Cen = np.array([np.mean(cat2[:,0]), np.mean(cat2[:,1])]).reshape([1,2])
#Plot before re-assignment
PlotKMeans(cat1, cat2, cat1Cen, cat2Cen)

Since the centroids already exist, we simply take the average x and y values of the data assigned to each cluster and move the centroids to those coordinates (a slightly more compact way to write that update is sketched below).  Let's go ahead and plot this movement prior to re-assigning the points.
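As a quick aside before the plot: the same update can also be written with axis=0, which averages each column in one call.  This is an equivalent alternative, shown only for reference:

# Same result as the two np.mean calls above, written once per centroid.
cat1Cen_alt = np.mean(cat1, axis=0).reshape([1, 2])
cat2Cen_alt = np.mean(cat2, axis=0).reshape([1, 2])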

[Plot: centroids moved to the means of their assigned clusters, before re-assignment]

Excellent, it appears the centroids have moved significantly closer to the actual averages of the data points.  Notice how the lower centroid is very close to the mean we used for the lower Gaussian distribution (3.5) and the upper centroid is close to the mean of the upper distribution (5.5).  This is no accident; this is how K-Means works.  It will find the means.

Ok, let's do a re-assignment step.

#Re-assign
mergedData = np.concatenate((cat1, cat2))
cat1 = np.empty([2,])
cat2 = np.empty([2,])
for i in range(0, mergedData[:,0].size):
    distToCen1 = CalculateDistance(cat1Cen, mergedData[i,:].reshape([1,2]))
    distToCen2 = CalculateDistance(cat2Cen, mergedData[i,:].reshape([1,2]))
    if distToCen1 < distToCen2 :
        cat1 = np.concatenate((cat1, mergedData[i,:]))
    else:
        cat2 = np.concatenate((cat2, mergedData[i,:]))
#endfor
cat1 = np.delete(cat1, (0,1))
cat1 = cat1.reshape([cat1.size // 2, 2])
cat2 = np.delete(cat2, (0,1))
cat2 = cat2.reshape([cat2.size // 2, 2])
#plot after
PlotKMeans(cat1, cat2, cat1Cen, cat2Cen)

Notice how we merge the data back into a bunch of unknowns, then go through the process of assigning each point to the closest centroid, do some cleanup, and plot.

Here is a static representation of the plot.  The static image makes it easier to see which data points each cluster won or lost in the re-assignment.  After the static image, I will include an interactive version to showcase some more of the neatness of plotly.

[Plot: second-round assignments after the centroids moved]

It appears our lower centroid has won 7 new points.  [Interactive version of the chart]

Summary

So there you have it: that is how K-Means clustering works under the hood.  As long as you can represent your problem numerically, K-Means can group it for you.  In this example it didn't get things perfectly after 2 iterations, but it did a very fine job for having no labels.  Of course, since distance here is a radial measure rather than a linear one, there are likely some optimizations we could make for accuracy, but again, that is beyond today's scope.  I hope you learned a lot about Visual Studio, Python, Plotly, and of course K-Means.
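If you want to carry this past two rounds, the natural extension is to keep repeating the assign-and-move steps until the centroids stop moving.  Here is a sketch of that loop built from the same pieces used above (CalculateDistance and numpy means); it assumes the two-centroid setup from this article and is an illustration, not an optimized implementation:

def RunKMeans(data, cen1, cen2, maxIters=100, tol=1e-6):
    for _ in range(maxIters):
        group1, group2 = [], []
        for row in data:
            point = row.reshape([1, 2])
            if CalculateDistance(cen1, point) < CalculateDistance(cen2, point):
                group1.append(row)
            else:
                group2.append(row)
        # Move each centroid to the mean of the points it won this round.
        # (An empty cluster would need extra handling in real code.)
        newCen1 = np.mean(np.array(group1), axis=0).reshape([1, 2])
        newCen2 = np.mean(np.array(group2), axis=0).reshape([1, 2])
        moved = CalculateDistance(cen1, newCen1) + CalculateDistance(cen2, newCen2)
        cen1, cen2 = newCen1, newCen2
        if moved < tol:  # stop once the centroids have essentially stopped moving
            break
    return np.array(group1), np.array(group2), cen1, cen2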

Entire Code Set

import math
import numpy as np
import plotly as plotly
#import plotly.plotly as py   #only needed for the online publishing variants noted below
from plotly.graph_objs import Scatter, Layout

def CalculateDistance(p1,p2):
    xDiff = (p2[0,0] - p1[0,0])**2
    yDiff = (p2[0,1] - p1[0,1])**2
    return math.sqrt(xDiff + yDiff)
def CreateScatterTrace(data, color, marker='point', size = 10):
    trace = Scatter(
        x = data[:,0],
        y = data[:,1],
        mode = 'markers',
        marker = dict(
            size = size,
            color = color,
            symbol = marker,
            line = dict(
                width = 2)
            )
    )
    return trace

def Plot2Categories(cat1Data, cat2Data):
    traceCat1 = CreateScatterTrace(cat1Data, 'rgba(152, 0, 0, .8)')
    traceCat2 = CreateScatterTrace(cat2Data, 'rgba(0, 152, 0, .8)')
    plotly.offline.plot({
        "data": [traceCat1, traceCat2],
        "layout": Layout(title="2 Fake Categories")
    })
    #Online variant used to publish the interactive charts (requires a plotly account):
    #py.plot([traceCat1, traceCat2], filename = 'blog/twocategories')

def PlotSingleCategory(data):
    trace = CreateScatterTrace(data, 'rgba(152, 0, 0, .8)')
    plotly.offline.plot({
        "data": [trace],
        "layout": Layout(title="Unlabeled Data")
    })

def PlotKMeansSingle(data, cen1, cen2):
    traceCat1 = CreateScatterTrace(data, 'rgba(0, 0, 152, .8)')
    traceCen1 = CreateScatterTrace(cen1, 'rgba(152, 0, 0, .8)', 'star', 15)
    traceCen2 = CreateScatterTrace(cen2, 'rgba(0, 152, 0, .8)', 'star', 15)
    plotly.offline.plot({
        "data": [traceCat1, traceCen1, traceCen2],
        "layout": Layout(title="KMeans")
    })

def PlotKMeans(cat1Data, cat2Data, cen1, cen2):
    traceCat1 = CreateScatterTrace(cat1Data, 'rgba(152, 0, 0, .8)')
    traceCat2 = CreateScatterTrace(cat2Data, 'rgba(0, 152, 0, .8)')
    traceCen1 = CreateScatterTrace(cen1, 'rgba(152, 0, 0, .8)', 'star', 15)
    traceCen2 = CreateScatterTrace(cen2, 'rgba(0, 152, 0, .8)', 'star', 15)
    plotly.offline.plot({
        "data": [traceCat1, traceCat2, traceCen1, traceCen2],
        "layout": Layout(title="KMeans")
    })
    #Online variant used to publish the interactive charts (requires a plotly account):
    #py.plot([traceCat1, traceCat2, traceCen1, traceCen2], filename = 'blog/KMeansPlot')

cat1 = 5.5
cat2 = 3.5
#Create 2 A-Symmetrical Gaussian Distributions
cat1Data = np.column_stack((np.random.normal(cat1, 1.0, 50), 
                            np.random.normal(cat1, 1.0, 50)))
cat2Data = np.column_stack((np.random.normal(cat2, 1.0, 50), 
                            np.random.normal(cat2, 1.0, 50)))
#Plot 2 categories known
Plot2Categories(cat1Data, cat2Data)

#merge categories and display as k-means would see them.
mergedData = np.concatenate((cat1Data, cat2Data))
PlotSingleCategory(mergedData)

###################
# K-Means Round 1 #
###################
#pick 2 initial centroids intelligently
xSplit = np.percentile(mergedData[:,0], 50.0)
ySplit = np.percentile(mergedData[:,1], 50.0)
xStd = np.std(mergedData[:,0])
yStd = np.std(mergedData[:,1])
cat1Cen = np.array([xSplit - xStd, ySplit - yStd]).reshape([1,2])
cat2Cen = np.array([xSplit + xStd, ySplit + yStd]).reshape([1,2])
#Do dumber centroid selection, for longer convergence
#This is for demonstration
cat1Cen = np.array([1,1]).reshape([1,2])
cat2Cen = np.array([7,7]).reshape([1,2])

#Plot before category assignments
PlotKMeansSingle(mergedData, cat1Cen, cat2Cen)

#assign points to categories
cat1 = np.empty([2,])
cat2 = np.empty([2,])
for i in range(0, mergedData[:,0].size):
    distToCen1 = CalculateDistance(cat1Cen, mergedData[i,:].reshape([1,2]))
    distToCen2 = CalculateDistance(cat2Cen, mergedData[i,:].reshape([1,2]))
    if distToCen1 < distToCen2 :
        cat1 = np.concatenate((cat1, mergedData[i,:]))
    else:
        cat2 = np.concatenate((cat2, mergedData[i,:]))
#endfor
#delete the first two placeholder values left over from np.empty
#Also just reshape while we are at it.
cat1 = np.delete(cat1, (0,1))
cat1 = cat1.reshape([cat1.size // 2, 2])
cat2 = np.delete(cat2, (0,1))
cat2 = cat2.reshape([cat2.size // 2, 2])

#Plot first round
PlotKMeans(cat1, cat2, cat1Cen, cat2Cen)

###################
# K-Means Round 2 #
###################
#Move centroids to average of data they have
cat1Cen = np.array([np.mean(cat1[:,0]), np.mean(cat1[:,1])]).reshape([1,2])
cat2Cen = np.array([np.mean(cat2[:,0]), np.mean(cat2[:,1])]).reshape([1,2])
#Plot before re-assignment
PlotKMeans(cat1, cat2, cat1Cen, cat2Cen)
#Re-assign
mergedData = np.concatenate((cat1, cat2))
cat1 = np.empty([2,])
cat2 = np.empty([2,])
for i in range(0, mergedData[:,0].size):
    distToCen1 = CalculateDistance(cat1Cen, mergedData[i,:].reshape([1,2]))
    distToCen2 = CalculateDistance(cat2Cen, mergedData[i,:].reshape([1,2]))
    if distToCen1 < distToCen2 :
        cat1 = np.concatenate((cat1, mergedData[i,:]))
    else:
        cat2 = np.concatenate((cat2, mergedData[i,:]))
#endfor
cat1 = np.delete(cat1, (0,1))
cat1 = cat1.reshape([cat1.size // 2, 2])
cat2 = np.delete(cat2, (0,1))
cat2 = cat2.reshape([cat2.size // 2, 2])
#plot after
PlotKMeans(cat1, cat2, cat1Cen, cat2Cen)
