Sep 17

Repairing an Image in Python using Machine Learning (using KNN filters)

There are many great articles that can get you started on Machine Learning, but so many of them focus on a handful of examples that have little relevance to the real world. After all, there are only so many cases where one would need to determine the kind of iris based on its petals…

As such, one of the biggest challenges is not necessarily understanding Machine Learning (there are plenty of good resources out there), but rather finding where to apply the knowledge.

While reading about KNN filtering, it occurred to me that it would be a great candidate for repairing images with broken or missing pixels.

But first, what is KNN (K Nearest Neighbors)? In short, KNN assumes that items with similar characteristics tend to bunch up, and it predicts the value of an item based on the values of its closest neighbors. I will not go into too much detail, because there are many great explanations on the net (one such article can be found here: https://www.analyticsvidhya.com/blog/2014/10/introduction-k-neighbours-algorithm-clustering/). Instead, I will focus on explaining the rationale behind image repairing.
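To make the idea concrete, here is a minimal sketch of KNN classification in plain Python (3.8+ for math.dist). It is only an illustration with made-up points and labels. The helper knn_predict is hypothetical and not part of the article's code. Later on, scikit-learn's KNeighborsClassifier does this job far more efficiently.

```python
import math
from collections import Counter


def knn_predict(train_points, train_labels, query, k):
    """Predict a label for `query` by majority vote of its k nearest training points."""
    # Sort training indices by Euclidean distance to the query point
    order = sorted(range(len(train_points)),
                   key=lambda i: math.dist(train_points[i], query))
    # Count the labels of the k closest points and return the most common one
    votes = Counter(train_labels[i] for i in order[:k])
    return votes.most_common(1)[0][0]


points = [(0, 0), (0, 1), (1, 0), (5, 5), (5, 6), (6, 5)]
labels = ["red", "red", "red", "blue", "blue", "blue"]
print(knn_predict(points, labels, (1, 1), k=3))  # red
print(knn_predict(points, labels, (5, 4), k=3))  # blue
```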

So why did I think of images? Well, if we look at an image, we see that a lot of the pixels have colors similar to those of the pixels next to them. This observation was actually used in some older compression algorithms (anyone remember PCX files?). It means that even if a pixel is missing, we can make a pretty good guess of its original color just by looking at the pixels around it.
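Here is that intuition in its simplest form, before any machine learning. The sketch guesses a missing pixel by majority vote over its 3×3 neighborhood. It assumes the image is a list of rows of hex color strings, with None marking missing pixels. The name guess_from_neighbors is hypothetical. KNN generalizes this idea: K, rather than a fixed window, decides how many nearby pixels get a vote.

```python
from collections import Counter


def guess_from_neighbors(grid, x, y):
    """Guess the color at grid[y][x] by majority vote of its (up to 8) neighbors.

    Missing pixels are stored as None and are skipped in the vote.
    """
    height, width = len(grid), len(grid[0])
    neighbors = []
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dx == 0 and dy == 0:
                continue  # skip the pixel we are trying to guess
            nx, ny = x + dx, y + dy
            if 0 <= nx < width and 0 <= ny < height and grid[ny][nx] is not None:
                neighbors.append(grid[ny][nx])
    if not neighbors:
        return None  # nothing known around this pixel
    return Counter(neighbors).most_common(1)[0][0]


G, B = "#00ff00", "#0000ff"
image = [
    [G, G, G, B],
    [G, None, G, B],
    [G, G, B, B],
]
print(guess_from_neighbors(image, 1, 1))  # #00ff00 (7 green neighbors vs 1 blue)
```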

I ran this algorithm on several images and got impressive results. For this example, I am using a random image that I got from here: https://pixabay.com/en/tree-sea-grass-nature-ocean-2435269/

Original image

What’s more, the code is not even that long: about half of it is just methods that handle loading and saving the image files.

Prerequisites (besides Python):

We’ll be using the following libraries:
  • Pillow for image processing (loading, converting and saving images)
  • numpy and pandas to work with data frames
  • pylab from Matplotlib for plotting graphs
  • scikit-learn for the actual KNN classification

The easiest way to install these is to use pip:

pip install pandas numpy matplotlib scikit-learn Pillow
OK, now that we have everything installed, let’s get cracking.
The first step will be to load the image. We’ll be using the Pillow library to handle the image loading. If the image is not in the RGBA format, we convert it to RGBA.
from PIL import Image

def load_image(filename):
    im = Image.open(filename)
    # Compare strings with != ("is not" tests identity, not equality)
    if im.mode != 'RGBA':
        return im.convert(mode='RGBA')
    return im

Once we have the image in the format that we need, we’ll map it to tabular data (which we will later load into a pandas dataframe) with one row per pixel. Each row holds the pixel's X and Y coordinates and its color.

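def convert_to_dataframe(image):
    pixels = image.load()
    data = []
    all_colors = set()
    for x in range(image.width):
        for y in range(image.height):
            # Both axes are mirrored here; save_to_file mirrors them back
            pixel_color = rgba_to_hex(pixels[image.width - x - 1, image.height - y - 1])
            data.append([x, y, pixel_color])
            all_colors.add(pixel_color)
    # One [x, y, color] row per pixel, plus the set of distinct colors
    return data, all_colors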
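The mirrored indices above may look suspicious. The small check below uses a made-up 4×3 "image" whose colors are just labels, and it suggests why the mirroring is harmless. First, save_to_file applies the same mirroring, so the two cancel out. Second, mirroring both axes preserves distances between pixels, so KNN sees the same neighborhoods. The flip helper is hypothetical; the article's code inlines this arithmetic.

```python
import math


def flip(x, y, width, height):
    """Mirror both axes, as convert_to_dataframe and save_to_file do.

    Applying it twice returns the original coordinates.
    """
    return width - 1 - x, height - 1 - y


width, height = 4, 3
# Build [x, y, color] rows the way convert_to_dataframe does; the "color"
# of Pillow pixel (px, py) is just the label "c{px}{py}".
rows = []
for x in range(width):
    for y in range(height):
        px, py = flip(x, y, width, height)
        rows.append([x, y, "c{}{}".format(px, py)])

# Writing the rows back through the same mapping restores every pixel position
restored = {}
for x, y, color in rows:
    restored[flip(x, y, width, height)] = color

print(rows[0])            # [0, 0, 'c32']
print(restored[(0, 0)])   # c00

# Mirroring keeps distances between pixels unchanged
a, b = (0, 0), (3, 1)
print(math.dist(a, b) == math.dist(flip(*a, width, height), flip(*b, width, height)))  # True
```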

Pillow defines each pixel as a tuple of 4 values (red, green, blue and alpha), so we convert it to a single string of the form #RRGGBB. In this case I am ignoring the alpha value, as I am not working with transparent images, but updating the code to include alpha would be trivial. Note that the colors are stored as strings. We would probably get a small performance boost by storing them as 32-bit integers instead, but for the sake of simplicity I chose not to.

def rgba_to_hex(rgba_tuple):
    assert isinstance(rgba_tuple, tuple) and len(rgba_tuple) == 4
    # Keep red, green and blue; the alpha channel is dropped
    return "#{:02x}{:02x}{:02x}".format(*rgba_tuple[:3])
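Storing colors as 32-bit integers instead of strings is one possible alternative. Here is a minimal sketch of that approach. It assumes a 0xRRGGBBAA layout, which, unlike the #RRGGBB strings, also keeps alpha. The names rgba_to_int and int_to_rgba are hypothetical, and any performance gain is unmeasured here.

```python
def rgba_to_int(rgba):
    """Pack an (r, g, b, a) tuple of 0-255 values into one 32-bit integer (0xRRGGBBAA)."""
    r, g, b, a = rgba
    return (r << 24) | (g << 16) | (b << 8) | a


def int_to_rgba(value):
    """Unpack a 0xRRGGBBAA integer back into an (r, g, b, a) tuple."""
    return (value >> 24) & 0xFF, (value >> 16) & 0xFF, (value >> 8) & 0xFF, value & 0xFF


packed = rgba_to_int((18, 52, 86, 255))
print(hex(packed))          # 0x123456ff
print(int_to_rgba(packed))  # (18, 52, 86, 255)
```

Integer labels like these could stand in for the string colors as classifier targets, since scikit-learn classifiers accept integer class labels.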
While we’re at it, let’s also write the functions that save dataframes as image files. The save function receives a list of one or more dataframes, combines them, and writes the result to an image file (Pillow picks the format from the file extension). It will become clearer a little later why I chose to combine dataframes.
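def save_to_file(filename, dataframes, size):
    # Start from a magenta canvas so that pixels not covered by any dataframe stand out
    ni = Image.new('RGBA', size, '#ff00ff')
    pixels = ni.load()
    # Later dataframes overwrite earlier ones where they overlap
    for df in dataframes:
        for row in df.itertuples():
            # Undo the mirroring applied in convert_to_dataframe
            pixels[size[0] - 1 - row.x, size[1] - 1 - row.y] = hex_to_rgba(row.color)
    ni.save(filename)

def hex_to_rgba(hex_string):
    # '#RRGGBB' -> (r, g, b, 255); alpha is always fully opaque
    return int(hex_string[1:3], 16), int(hex_string[3:5], 16), int(hex_string[5:7], 16), 255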
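A quick round-trip check of the two color helpers shows what survives the trip through a string. The helpers are copied from above so that the snippet runs on its own. Red, green and blue come back unchanged, while alpha always comes back as 255. That is fine for opaque images like the one used here.

```python
def rgba_to_hex(rgba_tuple):
    # Same conversion as above: alpha is dropped
    return "#{:02x}{:02x}{:02x}".format(*rgba_tuple[:3])


def hex_to_rgba(hex_string):
    # Same conversion as above: alpha is always restored as 255
    return int(hex_string[1:3], 16), int(hex_string[3:5], 16), int(hex_string[5:7], 16), 255


print(rgba_to_hex((255, 0, 255, 255)))              # #ff00ff
print(hex_to_rgba(rgba_to_hex((10, 20, 30, 255))))  # (10, 20, 30, 255)
print(hex_to_rgba(rgba_to_hex((10, 20, 30, 128))))  # (10, 20, 30, 255): alpha is lost
```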

OK, so far we can load the image and extract the data in the format that we want. Let’s put that together:

import os
import numpy as np
import pandas as pd

def run(image_name):
    im = load_image(image_name)
    filename, file_extension = os.path.splitext(image_name)
    data, all_cols = convert_to_dataframe(im)
    df = pd.DataFrame(data, columns=['x', 'y', 'color'])

Now, let’s simulate that our image is damaged by randomly removing 30% of its pixels. The simplest way to do this is to take advantage of what numpy already gives us. np.random.uniform(low, high, size) returns an array of size random values drawn uniformly from [low, high). Comparing that array with 0.7 gives a boolean array that is True wherever the value is greater than 0.7, which is a 30% chance for each pixel. We then keep only the pixels from the original dataframe for which this value is False, which leaves roughly 70% of the original pixels. The removed pixels are set aside as the ones we will later predict.

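is_missing = np.random.uniform(0, 1, len(df)) > 0.7
# Known pixels: used to train the classifier
train = df[~is_missing]
# "Missing" pixels: their colors will be predicted later
test = df[is_missing]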
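If you want to convince yourself that this threshold removes about 30% of the pixels, the stdlib-only sketch below mirrors the numpy mask using random.random(). The 10,000-pixel count and the seed are arbitrary choices made for the example.

```python
import random

random.seed(0)  # fixed seed so the example is repeatable
n_pixels = 10_000
# Mirrors np.random.uniform(0, 1, n) > 0.7: each pixel is "missing" with probability 0.3
is_missing = [random.random() > 0.7 for _ in range(n_pixels)]
kept_fraction = is_missing.count(False) / n_pixels
print(kept_fraction)  # close to 0.7
```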

To actually see the result, the missing pixels are left magenta, which is the background color used by save_to_file. This step is optional; it just lets us see what the “corrupt” image looks like.

save_to_file('{}_missing{}'.format(filename, file_extension), [train], (im.width, im.height))

Image with missing pixels

How do we choose the value of K?

So now to the big question. We know that KNN looks at the K closest neighbors to determine the value of the current pixel, but how do we pick K? Well, that’s simple… We take the data we have (the known pixels) and split it into two sets: 80% becomes the training set and 20% becomes the test set. Then, for a range of values of K (I chose 1 to 20), we predict the colors of the test pixels and compare them against their real values. Finally, we take the value of K that gives us the best accuracy.
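The validation procedure described above can also be sketched without scikit-learn. This is a from-scratch stand-in for the get_accuracy function below, not the article's code. The function knn_accuracy_by_k is hypothetical. It uses an exact 80/20 shuffle split instead of a random mask, and it runs on a synthetic 20×20 two-color image rather than a real photo. Sorting the training pixels once per test pixel lets every K reuse the same neighbor list.

```python
import math
import random
from collections import Counter


def knn_accuracy_by_k(pixels, k_values, test_fraction=0.2, seed=0):
    """Hold out part of the known pixels and score each K on them.

    `pixels` is a list of (x, y, color) tuples; returns {k: accuracy}.
    """
    rng = random.Random(seed)
    shuffled = list(pixels)
    rng.shuffle(shuffled)
    n_test = int(len(shuffled) * test_fraction)
    test, train = shuffled[:n_test], shuffled[n_test:]

    correct = Counter()
    for x, y, true_color in test:
        # Sort the training pixels by distance once, then reuse the order for every K
        nearest = sorted(train, key=lambda p: math.dist((p[0], p[1]), (x, y)))
        for k in k_values:
            votes = Counter(color for _, _, color in nearest[:k])
            if votes.most_common(1)[0][0] == true_color:
                correct[k] += 1
    return {k: correct[k] / len(test) for k in k_values}


# Synthetic 20x20 "image": green above the diagonal, blue on and below it
pixels = [(x, y, "#00ff00" if x > y else "#0000ff") for x in range(20) for y in range(20)]
scores = knn_accuracy_by_k(pixels, k_values=range(1, 21))
best_k = max(scores, key=scores.get)
print(best_k, scores[best_k])
```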

import numpy as np
import pandas as pd
from matplotlib import pylab as pl
from sklearn.neighbors import KNeighborsClassifier

def plot_k(data):
    accuracy = get_accuracy(data)
    results = pd.DataFrame(accuracy, columns=["n", "accuracy"])
    pl.plot(results.n, results.accuracy)
    pl.title("Accuracy for variable K")
    pl.xlabel("K")
    pl.ylabel("accuracy")
    pl.show()

def get_accuracy(data):
    accuracy = []
    print("Plotting K...")
    # Hold out roughly 20% of the known pixels as a test set
    is_missing = np.random.uniform(0, 1, len(data)) > 0.8
    train = data[~is_missing]
    test = data[is_missing]
    for n in range(1, 21):
        clf = KNeighborsClassifier(n_neighbors=n)
        clf.fit(train[['x', 'y']], train['color'])
        preds = clf.predict(test[['x', 'y']])
        # Fraction of test pixels whose color was predicted exactly
        k_accuracy = np.mean(preds == test['color'].values)
        print("Neighbors: %d, Accuracy: %.3f" % (n, k_accuracy))
        accuracy.append([n, k_accuracy])
    return accuracy

Plotting the K variable

We can see from the plot that we get the best results when K is 3 (actually, 3, 5 and 7 are very close, so we could pick any one of them). After K=7, however, accuracy decreases. That’s most likely because larger neighborhoods start to include pixels from neighboring regions, so fine details get outvoted and smoothed away (underfitting). Very small values of K are the ones most prone to overfitting, because a single noisy pixel can decide the prediction.
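To see why a large K hurts, consider a small detail in an otherwise uniform image. This made-up example uses a 3×3 black square on an 11×11 white image, with the square's center pixel missing. Small values of K recover the black center. With K=25, the white pixels among the 25 nearest (17 of them) outvote the 8 black ones. The helper knn_predict is a hypothetical pure-Python stand-in for the classifier.

```python
import math
from collections import Counter


def knn_predict(train, query, k):
    """Majority vote of the k training pixels closest to `query`.

    `train` is a list of (x, y, color) tuples.
    """
    nearest = sorted(train, key=lambda p: math.dist((p[0], p[1]), query))
    return Counter(color for _, _, color in nearest[:k]).most_common(1)[0][0]


BLACK, WHITE = "#000000", "#ffffff"
# 11x11 white image with a 3x3 black square centered on (5, 5)
pixels = [(x, y, BLACK if 4 <= x <= 6 and 4 <= y <= 6 else WHITE)
          for x in range(11) for y in range(11)]
# Pretend the square's center pixel is missing
train = [p for p in pixels if (p[0], p[1]) != (5, 5)]

for k in (1, 5, 25):
    print(k, knn_predict(train, (5, 5), k))
# 1 #000000
# 5 #000000
# 25 #ffffff
```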

Note 1: I was using an image with a limited number of colors (256). With a larger palette, the accuracy can get much lower (in the 30% range), but the results still look good.

Note 2: Computing the accuracy for each K can take a long time for images with lots of colors. For instance, computing it for a 1280×915, 32bpp image took 5 hours on my MacBook Pro.

Alright, now that we have the K value, we can actually create the classifier and train it.

    clf = KNeighborsClassifier(n_neighbors=3)
    clf.fit(train[['x', 'y']], train['color'])

Classifier up and running… let’s predict!

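    preds = clf.predict(test[['x', 'y']])
    test = test.assign(color=preds)  # assign() avoids SettingWithCopyWarning on a slice of df
    # Known + predicted pixels cover the whole image (the "_restored" suffix is my own choice)
    save_to_file('{}_restored{}'.format(filename, file_extension), [train, test], (im.width, im.height))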

Now that we have the predictions, let’s combine them with the original data and voilà! We have a complete image… We lost a little of the detail compared to the original, but unless you put the two side by side, it would be hard to spot.
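This is where combining dataframes in save_to_file pays off. Passing only the training pixels leaves magenta holes, while passing both the training pixels and the predictions covers the whole canvas. The sketch below shows the same idea with plain lists of (x, y, color) tuples instead of DataFrames. The name combine is hypothetical.

```python
MAGENTA = "#ff00ff"


def combine(width, height, *frames):
    """Paint each frame's (x, y, color) rows onto a magenta canvas, in order.

    Later frames overwrite earlier ones, like the successive loops in save_to_file.
    """
    canvas = [[MAGENTA] * width for _ in range(height)]
    for frame in frames:
        for x, y, color in frame:
            canvas[y][x] = color
    return canvas


known = [(0, 0, "#112233"), (1, 0, "#112233")]      # pixels that survived
predicted = [(0, 1, "#112233"), (1, 1, "#445566")]  # pixels filled in by KNN
only_known = combine(2, 2, known)
restored = combine(2, 2, known, predicted)
print(only_known)  # [['#112233', '#112233'], ['#ff00ff', '#ff00ff']]
print(restored)    # [['#112233', '#112233'], ['#112233', '#445566']]
```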

Image with missing pixels restored

I am sure there are ways to improve this, but the results are more than satisfactory.

The full code can be found here: https://gist.github.com/radulucaciu/df0b30453338946e639449710862822d