There's been a recent string of blog posts featuring implementations of the k-Nearest Neighbour algorithm in several languages (for k = 1), and it's been cool to see how the solutions are expressed differently in different languages, as well as the vast differences in performance. The k-Nearest Neighbour algorithm is cool in itself: it's a dead simple technique for classifying (or labelling) a collection of observed data by searching through previously observed collections of data whose classifications/labels are already known, and finding the one which most nearly resembles our not-yet-classified collection (for some definition of "nearly resembles"). It's an example of a machine learning classification algorithm, one of those things that lives in the fun intersection of math and programming.
We're gonna take a look at a concrete application of the k-NN algorithm, compare the performance of the implementations from those aforementioned blog posts with new implementations in Golang and Haskell, and then examine an optimized version which takes a logical shortcut and also leverages Golang's built-in support for concurrency.
All the code and datasets can be found on GitHub. The Golang and Haskell code is also at the bottom of this post.
TL;DR: Golang wins, or, in honor of the World Cup: GOOOOOOOOOOLLLLLLLang!!!
The problem
In this particular example, we've got 5000 pixelated (28x28) greyscale (0-255) "drawings" of the digits 0 through 9. Some of them might look like this:
[Image: sample digit drawings. Source: https://onlinecourses.science.psu.edu/stat857/node/186]
These 5000 digit drawings constitute our training set. We're then given a bunch of new drawings where (let's pretend for a moment) we don't know what digits they're supposed to represent, but we know the greyscale values at each pixel. Given any such unclassified drawing, our goal is to make a reasonable guess as to what digit it's supposed to represent. The algorithm finds the drawing in the training set which most nearly resembles our unclassified drawing; our reasonable guess is then that the unclassified drawing represents the same digit as that nearest drawing in the training set. At this point, we can say that we've classified our previously unclassified drawing.
But what does "nearly resemble" mean in this case? Roughly, we want to look at how different a pair of drawings is, pixel by pixel, and aggregate those differences for all the pixels. The smaller the aggregate pixel difference, the nearer the resemblance. The standard measure of distance here is the Euclidean metric: Given two vectors x⃗, y⃗ of length 28 × 28 = 784 consisting of 8-bit unsigned integers 0…255, we define their distance to be:
$d(\vec{x}, \vec{y}) = \sqrt{\sum_{i=0}^{783} (x_i - y_i)^2}$
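In code, that might look something like the following Go sketch (a minimal illustration of the formula, not necessarily the exact function from any of the implementations below). Two details are worth noting: the pixel values have to be widened from unsigned bytes to signed integers before subtracting, and since the square root is monotonic, it could be skipped entirely when we only need to rank distances.

```go
package main

import "math"

// distance computes the Euclidean distance between two drawings,
// assuming len(x) == len(y) == 784. Pixels are widened to int before
// subtracting so the unsigned arithmetic can't underflow.
func distance(x, y []uint8) float64 {
	sum := 0
	for i := range x {
		d := int(x[i]) - int(y[i])
		sum += d * d
	}
	return math.Sqrt(float64(sum))
}
```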
In this problem we're given 500 drawings to classify, and they form our validation set. After running the algorithm against all 500, we can see what percentage of them we classified correctly (we actually are given their labels; we just pretend not to know them when doing the classification), and how long it took to do them all.
The data is given to us as a couple of CSV files, one for the training set, one for the validation set. Each row corresponds to a drawing. The first column is the label (i.e. what digit the drawing represents), and the next 784 columns are the greyscale values of each pixel in the drawing.
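Loading that format is straightforward. Here's a hedged Go sketch (the type and function names are mine, and it assumes the files have no header row):

```go
package main

import (
	"encoding/csv"
	"os"
	"strconv"
)

// labelledDrawing pairs a digit label with its 784 pixel values.
type labelledDrawing struct {
	label  int
	pixels []uint8
}

// loadCSV reads a training or validation file: column 0 is the label,
// columns 1..784 are greyscale pixel values.
func loadCSV(path string) ([]labelledDrawing, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	rows, err := csv.NewReader(f).ReadAll()
	if err != nil {
		return nil, err
	}

	drawings := make([]labelledDrawing, 0, len(rows))
	for _, row := range rows {
		label, err := strconv.Atoi(row[0])
		if err != nil {
			return nil, err
		}
		pixels := make([]uint8, len(row)-1)
		for i, field := range row[1:] {
			v, err := strconv.Atoi(field)
			if err != nil {
				return nil, err
			}
			pixels[i] = uint8(v)
		}
		drawings = append(drawings, labelledDrawing{label: label, pixels: pixels})
	}
	return drawings, nil
}
```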
Note that the above describes the k-Nearest Neighbour classification in the case k = 1. If we wanted to do it for k > 1, we would take an unclassified drawing, find the k nearest drawings in the training set, and then classify the drawing according to whichever digit is represented most amongst those k nearest drawings.
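That voting step is just a tally. A minimal Go sketch (again mine, purely for illustration):

```go
package main

// vote returns the digit appearing most often among the labels of the
// k nearest training drawings. Ties go to whichever label reaches the
// winning count first.
func vote(kNearestLabels []int) int {
	counts := map[int]int{}
	best, bestCount := 0, -1
	for _, label := range kNearestLabels {
		counts[label]++
		if counts[label] > bestCount {
			best, bestCount = label, counts[label]
		}
	}
	return best
}
```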
Blog Chain
This post was inspired by a chain of blog posts, each of which contains implementations of the algorithm in a different language (or two). All the implementations are naive, in that they pretty much do the simplest thing possible and take hardly any shortcuts to speed up or skip calculations:
- The most recent one implemented it in Factor.
- The one before that did it in Rust.
- That one was inspired by a blog post which had it in F# and OCaml, and a follow-up which improves the first OCaml implementation.
I work for Pivotal on the Cloud Foundry project and recently joined the Diego team where I was introduced to Golang. I thought it'd be fun to add naive and optimized implementations in Golang to the comparison. Then I came across an awesome primer on Haskell (http://learnyouahaskell.com/) so the incomparable @alexsuraci and I paired on adding Haskell to the mix.
Comparison
Performance comparisons between the naive implementations in each language were performed on a freshly spun up c3.xlarge EC2 instance as follows:
- Install Golang, Haskell, Rust, F#, and OCaml. Download Factor.
- Write the (naive) code for Golang and Haskell. Copy-paste the code for Rust, F#, OCaml, and Factor.
- Compile executables for Haskell, Rust, F#, and OCaml.
- Run and time the executables with `time ./<executable-name>`. Run the Golang code with `time go run golang-k-nn.go`. Run the Factor code in the `scratchpad` REPL with `[k-nn] time`.
Results
- Golang: 4.701s
- Factor: 6.358s
- OCaml: 12.757s
- F#: 23.507s
- Rust: 78.138s
- Haskell: 91.581s
[Per-language snippets for Golang, Haskell, Rust, F#, OCaml, and Factor appeared here but did not survive formatting; the full code is in the GitHub repository.]
Optimized implementation in Golang
The Golang implementation gets a major performance boost from two optimizations:
- Short-circuit distance calculations against training cases that can no longer win. In other words, if you know the distance to one potential nearest neighbour is 100, and half-way through calculating the distance to another potential nearest neighbour you already have a distance-so-far of 105, stop calculating and move on to the next candidate for nearest neighbour.
- Use goroutines to parallelize the computations. The way this was done was not ideal, because the parallelism isn't in the classification algorithm itself; instead it parallelizes the classification of the members of the validation set. However, it's easy enough to "do it right", and what's currently there is good enough to see how significant the gains are when firing on all your cores. Both ideas are sketched below.
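Here's a hedged sketch of both optimizations (simplified, with names and structure of my choosing; the real listing is in the repo). It reuses the labelledDrawing type from the loading sketch above, and it compares squared distances throughout, so the square root disappears altogether:

```go
package main

import "sync"

// squaredDistanceWithBailout accumulates the squared distance between
// x and y, but bails out as soon as the running total reaches bailout
// (the best squared distance seen so far), since this candidate can
// then no longer be the nearest neighbour.
func squaredDistanceWithBailout(x, y []uint8, bailout int) int {
	sum := 0
	for i := range x {
		d := int(x[i]) - int(y[i])
		sum += d * d
		if sum >= bailout {
			return bailout
		}
	}
	return sum
}

// classify returns the label of the training drawing nearest to pixels.
func classify(training []labelledDrawing, pixels []uint8) int {
	best := 1 << 30 // larger than any possible squared distance (784 * 255²)
	bestLabel := 0
	for _, t := range training {
		if d := squaredDistanceWithBailout(t.pixels, pixels, best); d < best {
			best, bestLabel = d, t.label
		}
	}
	return bestLabel
}

// classifyAll classifies each validation drawing in its own goroutine
// and returns the number of correct classifications. Each goroutine
// writes to a distinct slice index, so no locking is needed.
func classifyAll(training, validation []labelledDrawing) int {
	var wg sync.WaitGroup
	correct := make([]bool, len(validation))
	for i, v := range validation {
		wg.Add(1)
		go func(i int, v labelledDrawing) {
			defer wg.Done()
			correct[i] = classify(training, v.pixels) == v.label
		}(i, v)
	}
	wg.Wait()

	total := 0
	for _, ok := range correct {
		if ok {
			total++
		}
	}
	return total
}
```

Spinning up one goroutine per validation drawing (500 of them) is coarse but effective; note that on older Go releases GOMAXPROCS defaulted to 1, so you may need to set runtime.GOMAXPROCS to actually use all cores.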
Code
Golang (naive)
[The original 95-line naive Golang listing did not survive formatting; the full code is in the GitHub repository.]
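In its place, here's a hedged sketch of a naive version assembled from the pieces above (labelledDrawing, loadCSV, and distance) — an illustration of the structure, not the author's original listing. The CSV file names are assumptions:

```go
package main

import (
	"fmt"
	"math"
)

// naiveClassify returns the label of the training drawing with the
// smallest Euclidean distance to pixels, computing every distance in
// full and inspecting every training drawing.
func naiveClassify(training []labelledDrawing, pixels []uint8) int {
	best, bestLabel := math.Inf(1), 0
	for _, t := range training {
		if d := distance(t.pixels, pixels); d < best {
			best, bestLabel = d, t.label
		}
	}
	return bestLabel
}

func main() {
	// File names are assumptions; substitute the paths to your copies
	// of the training and validation CSVs.
	training, err := loadCSV("trainingsample.csv")
	if err != nil {
		panic(err)
	}
	validation, err := loadCSV("validationsample.csv")
	if err != nil {
		panic(err)
	}

	correct := 0
	for _, v := range validation {
		if naiveClassify(training, v.pixels) == v.label {
			correct++
		}
	}
	fmt.Printf("%d/%d classified correctly\n", correct, len(validation))
}
```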
Golang (optimized)
[The original 114-line optimized Golang listing did not survive formatting; the full code is in the GitHub repository. The sketch in the previous section illustrates its two optimizations.]
Haskell (naive)
[The original 45-line naive Haskell listing did not survive formatting; the full code is in the GitHub repository.]