Extending Keras ImageDataGenerator to handle multi-label classification tasks

I stumbled upon this problem recently while working on one of the Kaggle competitions, which featured a multi-label, highly imbalanced satellite image dataset.

Let's talk for a moment about a neat Keras feature, keras.preprocessing.image.ImageDataGenerator. As the documentation shows, its main purpose is to augment your dataset by generating new, randomly transformed images from it. This is a common tactic for fighting small datasets and overfitting.
By default, ImageDataGenerator (through its flow_from_directory method) expects the data to be structured in a very specific way: each class has its own directory, and every image inside that directory belongs to the class named by the directory.
This is very limiting. Using this API directly will not work for multi-label problems, where a single image can belong to several classes at once.

The first few rows of the training labels look like this:

   image_name  tags
0  train_0     haze primary
1  train_1     agriculture clear primary water
2  train_2     clear primary
3  train_3     clear primary
4  train_4     agriculture clear habitation primary road

Since this is not a typical single-label target, we need a proper way to work with these labels, especially when encoding and decoding them.
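The space-separated tags column first has to become one list of labels per image. The following is a minimal sketch that uses only the standard library. It assumes the labels come as a CSV file with the image_name and tags columns shown in the table above. load_tags is a hypothetical helper, and the in-memory sample stands in for the real file.

```python
import csv
import io


def load_tags(csv_file):
    """Read an image_name/tags CSV into {image_name: [tag, ...]}.

    Assumption: the tags column holds space-separated labels.
    """
    reader = csv.DictReader(csv_file)
    # str.split() with no argument splits on any run of whitespace.
    return {row["image_name"]: row["tags"].split() for row in reader}


# Hypothetical stand-in for the real labels file, copied from the sample rows.
sample = io.StringIO(
    "image_name,tags\n"
    "train_0,haze primary\n"
    "train_1,agriculture clear primary water\n"
    "train_2,clear primary\n"
)
tags_by_name = load_tags(sample)
print(tags_by_name["train_1"])  # ['agriculture', 'clear', 'primary', 'water']
```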

What we need is a way to encode (and decode) a list of lists of labels as binary multi-hot vectors. This is one-hot encoding generalised so that several positions can be 1 at the same time.
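Here is a minimal encoder/decoder sketch in NumPy. MultiHotEncoder is a hypothetical class written for illustration. scikit-learn's MultiLabelBinarizer (fit, transform, inverse_transform) offers similar functionality off the shelf. Columns are ordered alphabetically so the mapping stays stable between runs. Decoding takes a threshold so that it also works on a model's sigmoid outputs.

```python
import numpy as np


class MultiHotEncoder:
    """Hypothetical helper: lists of string labels <-> multi-hot 0/1 vectors."""

    def fit(self, label_lists):
        # A sorted vocabulary gives every label a fixed column index.
        self.classes_ = sorted({label for labels in label_lists for label in labels})
        self._index = {label: i for i, label in enumerate(self.classes_)}
        return self

    def transform(self, label_lists):
        # One row per sample, one column per class; 1 marks each label present.
        # Labels not seen during fit raise a KeyError.
        out = np.zeros((len(label_lists), len(self.classes_)), dtype=np.float32)
        for row, labels in enumerate(label_lists):
            for label in labels:
                out[row, self._index[label]] = 1.0
        return out

    def inverse_transform(self, vectors, threshold=0.5):
        # Keeps every class whose score exceeds the threshold, so this works
        # on hard 0/1 targets and on sigmoid probabilities alike.
        vectors = np.asarray(vectors)
        return [[self.classes_[i] for i in np.flatnonzero(v > threshold)]
                for v in vectors]


encoder = MultiHotEncoder().fit([["haze", "primary"],
                                 ["agriculture", "clear", "primary", "water"]])
print(encoder.classes_)  # ['agriculture', 'clear', 'haze', 'primary', 'water']
y = encoder.transform([["haze", "primary"]])
print(y)  # [[0. 0. 1. 1. 0.]]
print(encoder.inverse_transform(y))  # [['haze', 'primary']]
```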

We can now start developing the core functionality to extend ImageDataGenerator with multi-label support.

Let's first inspect the directory tree of our dataset.

The directory contains three subfolders (valid, train and submission). This matters because it lets us select which subset of the data to use further down the line. For instance, the validation set should not be augmented.

Let’s jump into the code:

That's quite a lot of code, so let's break it into smaller chunks and see what each one does.

Lines 7-10
We fit our label encoder, which translates file names to labels and labels to binary target vectors.
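The file-name side of that translation could look something like the following sketch. It assumes each image file is named after its image_name entry (train_0.jpg for train_0). It also assumes the iterator reports file names as paths relative to the dataset directory, such as 'train/train_0.jpg'. names_from_filenames and labels_for are hypothetical helpers. The label lists they return are what a multi-hot encoder would then consume.

```python
import os


def names_from_filenames(filenames):
    """Turn paths such as 'train/train_0.jpg' into keys such as 'train_0'.

    Assumption: image files are named <image_name>.<extension>.
    """
    return [os.path.splitext(os.path.basename(f))[0] for f in filenames]


def labels_for(filenames, tags_by_name):
    # Look up the tag list of every file, preserving the input order.
    return [tags_by_name[name] for name in names_from_filenames(filenames)]


tags_by_name = {"train_0": ["haze", "primary"], "train_2": ["clear", "primary"]}
print(labels_for(["train/train_0.jpg", "train/train_2.jpg"], tags_by_name))
# [['haze', 'primary'], ['clear', 'primary']]
```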

Lines 12-22
We initialize the ImageDataGenerators. Keep in mind that the validation generator applies only one transform (rescaling), because the validation set should not be augmented.

Lines 25-28
We define a grouper function. It is essential for yielding label batches the same way the ImageDataGenerator (IDG) yields image batches, which means batches of a constant, predefined size. If the total number of samples is not divisible by the batch size, the last batch should be smaller instead of cycling back to the start of the data.
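Below is a minimal sketch of such a grouper, written for illustration rather than taken from the original code (the name and argument order are assumptions). It slices an iterable into tuples of n items and lets the final tuple be shorter. Unlike the classic zip_longest-based recipe in the itertools documentation, it does not pad that tuple, and it never wraps around. On Python 3.12 and later, itertools.batched(iterable, n) behaves the same way.

```python
from itertools import islice


def grouper(iterable, n):
    """Yield consecutive tuples of n items; the last one may be shorter."""
    it = iter(iterable)
    while True:
        # islice takes up to n items; an empty chunk means we are done.
        chunk = tuple(islice(it, n))
        if not chunk:
            return
        yield chunk


print(list(grouper(range(7), 3)))  # [(0, 1, 2), (3, 4, 5), (6,)]
```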

Lines 31-37
This is the core part, where everything comes together. The key step is on line 33, where we pass gen.filenames to our grouper. As described above, the labels must be yielded in exactly the same order and batch sizes as the IDG yields the images. This only holds if the directory iterator does not shuffle (shuffle=False), because gen.filenames is stored in a fixed order.
On line 36 we look up the labels for each file name and multi-hot encode them.
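Here is a minimal sketch of such a generator under stated assumptions. It is illustrative and not the author's code. gen is assumed to be a DirectoryIterator returned by flow_from_directory(..., shuffle=False). lookup_targets is a hypothetical callable that maps a tuple of file names to a target array, for example a file-name lookup followed by a multi-hot encoder. multilabel_flow is likewise an illustrative name.

```python
from itertools import islice


def grouper(iterable, n):
    # Consecutive n-sized tuples; the last may be shorter (no padding, no wrap).
    it = iter(iterable)
    while True:
        chunk = tuple(islice(it, n))
        if not chunk:
            return
        yield chunk


def multilabel_flow(gen, lookup_targets):
    """Yield (images, multi-label targets) batches indefinitely.

    Assumes gen does not shuffle, so its batches follow gen.filenames.
    """
    while True:  # one pass per epoch, repeated so training can run many epochs
        # The grouper comes first in zip so that zip stops at the end of the
        # epoch without pulling an extra batch out of gen.
        for names, batch in zip(grouper(gen.filenames, gen.batch_size), gen):
            # class_mode=None yields bare image arrays; other class modes yield
            # (images, dummy_labels) tuples, whose labels we discard.
            x_batch = batch[0] if isinstance(batch, tuple) else batch
            yield x_batch, lookup_targets(names)
```

The order inside zip matters. zip evaluates its arguments left to right and stops at the first exhausted one. With the grouper first, the end of an epoch is detected before gen advances, so both streams stay aligned when the outer loop starts the next epoch. To read a single subset, one option is to point flow_from_directory at the dataset root and pass classes=['train'] (or ['valid']). gen.filenames then contains paths such as 'train/train_0.jpg'. Because the last batch of an epoch is smaller, steps_per_epoch should be ceil(len(gen.filenames) / gen.batch_size).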

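A generator like this can be passed directly to the fit_generator method of a Keras model. In newer tf.keras versions, fit_generator is deprecated and model.fit accepts generators directly.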

Posted by jakub.cieslik
