What is the best way for me to find out whether you are rich or poor, when the only thing I know is your address? Looking at your neighbourhood! That is the big idea behind the k-nearest neighbour (or KNN) algorithm, where k stands for the number of neighbours to look at. The idea couldn’t be any simpler yet the results are often very impressive indeed – so read on…
Let us take a task that is very hard to code, like identifying handwritten digits. We will be using the Semeion Handwritten Digit Data Set from the UCI Machine Learning Repository and, as a first step, split it into a training set and a test set:
# helper function for plotting images of digits in a nice way + returning the respective number
plot_digit
##      [,1] [,2] [,3] [,4] [,5] [,6]
## [1,]    3    1    2    5    7    3
## [2,]    1    5    1    6    7    6
## [3,]    6    2    8    5    9    3
## [4,]    5    7    5    7    5    9
par(old_par)
As you can see, teaching a computer to read those digits is a task that would take considerable effort and easily hundreds of lines of code. You would have to intelligently identify different regions in the images and find boundaries to work out which number is being shown. You could expect to do a lot of tweaking before you got acceptable results.
The real magic behind machine learning and artificial intelligence is that when something is too complicated to code, you let the machine program itself by just showing it lots of examples (see also my post So, what is AI really?). We will do just that with the nearest neighbour algorithm.
When talking about neighbours, it is already implied that we need some kind of distance metric to define what constitutes a neighbour. As in real life, the simplest one is the so-called Euclidean distance, which is just how far apart different points are as the crow flies. The formula used for this is just the good old Pythagorean theorem (in this case in a vectorized way), so you can see what maths at school was good for after all:
dist_eucl
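As a minimal sketch of the idea, the Euclidean distance between two numeric vectors of equal length could be written as follows. The name `euclid_dist` and the toy vectors are assumptions made for illustration and are not necessarily the code used in this post.

```r
# Euclidean distance between two numeric vectors of equal length
# (hypothetical helper name; an illustrative sketch only)
euclid_dist <- function(x1, x2) {
  stopifnot(length(x1) == length(x2))
  sqrt(sum((x1 - x2)^2))
}

# toy example: the classic 3-4-5 right triangle
a <- c(0, 0)
b <- c(3, 4)
euclid_dist(a, b)  # 5

# for binary pixel vectors, the squared distance is simply
# the number of pixels in which the two images differ
p <- c(1, 0, 1, 1)
q <- c(1, 1, 0, 1)
euclid_dist(p, q)  # sqrt(2), because two pixels differ
```

Because the subtraction and squaring work element by element, the same function handles vectors of any length, including a flattened digit image.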
The k-nearest neighbours algorithm is pretty straightforward: it just compares the digit to be identified with all other digits and chooses the k nearest ones. If the k nearest ones don’t all come up with the same answer, the majority vote (or, mathematically, the mode) is taken:
mode
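The two steps described above (find the k nearest training examples, then take the most frequent label among them) could be sketched like this. The names `majority_vote` and `knn_sketch`, the argument layout, and the toy data are all assumptions for illustration; this is not the author's original listing.

```r
# most frequent value in a vector
# (hypothetical name; on a tie, the value that appears first wins)
majority_vote <- function(x) {
  ux <- unique(x)
  ux[which.max(tabulate(match(x, ux)))]
}

# classify one observation by the majority label of its k nearest training rows
# train:  numeric matrix, one training example per row
# labels: vector of class labels, one per training row
# new:    numeric vector with as many entries as train has columns
knn_sketch <- function(train, labels, new, k) {
  # Euclidean distance from `new` to every training row
  dists <- sqrt(colSums((t(train) - new)^2))
  # indices of the k smallest distances
  nearest <- order(dists)[seq_len(k)]
  majority_vote(labels[nearest])
}

# toy data: two well-separated clusters in the plane
train  <- rbind(c(0, 0), c(0, 1), c(1, 0), c(5, 5), c(5, 6), c(6, 5))
labels <- c("a", "a", "a", "b", "b", "b")

knn_sketch(train, labels, new = c(1, 1), k = 3)  # "a"
knn_sketch(train, labels, new = c(4, 4), k = 3)  # "b"
```

Note that nothing is "trained" here: the training data is simply kept around, and all the work happens when a new observation is classified.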
So, the algorithm itself comprises barely four lines of code! Now, let us see how it performs out of sample on this complicated task with k = 9 (first a few examples are shown, and after that we have a look at the overall performance):
# show a few examples
set.seed(123) # for reproducibility
no_examples
##      [,1] [,2] [,3] [,4] [,5] [,6]
## [1,]    4    1    1    5    7    3
## [2,]    0    5    1    4    7    3
## [3,]    6    2    7    4    0    2
## [4,]    5    5    3    6    3    7
par(old_par)
prediction
Wow, it achieves an accuracy of nearly 95% out of the box, even though some of the digits are really hard to read even for humans! And we haven’t even given it the information that those images are two-dimensional, because we coded all the images simply as (one-dimensional) binary vectors.
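Accuracy here simply means the share of test digits whose predicted label matches the true one. A minimal sketch with made-up labels (the vectors `actual` and `predicted` are hypothetical and not taken from the Semeion data):

```r
# hypothetical true labels and predictions for ten test digits
actual    <- c(3, 1, 2, 5, 7, 3, 1, 5, 1, 6)
predicted <- c(3, 1, 2, 5, 7, 3, 0, 5, 1, 4)

# comparing the vectors gives TRUE/FALSE per digit;
# the mean of a logical vector is the share of TRUE values
accuracy <- mean(predicted == actual)
accuracy  # 0.8, because 8 of the 10 predictions are correct
```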
To get an idea of where it failed, have a look at the digits that were misclassified:
# show misclassified digits
err
##      [,1] [,2] [,3] [,4] [,5] [,6]
## [1,]    2    6    9    8    9    5
## [2,]    6    6    7    4    8    8
## [3,]    9    9    1    4    8    9
par(old_par)
# show what was predicted
print(matrix(prediction[err], 3, 6, byrow = TRUE), quote = FALSE, right = TRUE)
##      [,1] [,2] [,3] [,4] [,5] [,6]
## [1,]    1    5    1    9    3    6
## [2,]    5    0    1    1    9    5
## [3,]    3    7    3    1    2    2
Most of us would have difficulties reading at least some of those digits too. For example, the third digit in the first row is supposed to be a 9, yet it could also be a distorted 1. The same goes for the first digit in the last row: some people would read a 3 (like our little program) or nothing at all really, but it is supposed to be a 9. So even the mistakes the system makes are understandable.
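One way to locate such mistakes is to ask for the positions where prediction and truth disagree, and then cross-tabulate the two. The sketch below uses made-up labels; the names `actual`, `predicted` and `wrong` are hypothetical.

```r
# hypothetical true labels and predictions for six test digits
actual    <- c(9, 6, 9, 8, 1, 5)
predicted <- c(1, 6, 9, 9, 1, 5)

# positions of the misclassified digits
wrong <- which(predicted != actual)
wrong  # 1 4

# which true digit was mistaken for which predicted digit
table(actual = actual[wrong], predicted = predicted[wrong])
```

The index vector can then be used to pull out and plot exactly the images the classifier got wrong.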
Sometimes the simplest methods are perhaps not the best, but very effective indeed. You should keep that in mind!
To leave a comment for the author, please follow the link and comment on their blog: R-Bloggers – Learning Machines.