Guide to Human in the Loop Machine Learning
Machine Learning models are not perfect: their performance depends on the quality of the data they are trained on. Publicly available datasets are often insufficient for training a model for production use cases, because they might not cover all the edge cases that the model will encounter in real-world applications. To tackle this issue, we need a human in the loop who can find these edge cases and label the data required for the next round of model training.
Key Points to Remember
In this article we will try to understand why training a machine learning model needs to be an iterative process and what type of data should be added to the training dataset in each iteration. We will see how to use Human in the Loop Machine Learning in a production use case, and how it helps us build a robust Machine Learning model with limited resources and data. In a nutshell, Human in the Loop Machine Learning consists of the following steps:
Collect some data
Fine-tune a model on that data
Spot “Edge Cases”
Find more data to eliminate the edge cases
Retrain the model
Spot “Edge Cases” again, and repeat the cycle
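The cycle above can be sketched as a small control loop. This is a minimal sketch, not a prescribed implementation: `train`, `spot_edge_cases` and `find_similar_data` are hypothetical placeholders for your own training job, human review process and data-lake query, and `max_iterations` is an assumed safety cap.

```python
from typing import Callable, List


def human_in_the_loop(
    initial_data: List[str],
    train: Callable[[List[str]], object],
    spot_edge_cases: Callable[[object], List[str]],
    find_similar_data: Callable[[List[str]], List[str]],
    max_iterations: int = 5,
) -> object:
    """Collect -> fine-tune -> spot edge cases -> add data -> retrain, until no edge cases remain."""
    data = list(initial_data)                # 1. collect some data
    model = train(data)                      # 2. fine-tune a model on that data
    for _ in range(max_iterations):
        edge_cases = spot_edge_cases(model)  # 3. humans spot edge cases
        if not edge_cases:
            break                            # nothing left to fix in this round
        data.extend(find_similar_data(edge_cases))  # 4. find more data for them
        model = train(data)                  # 5. retrain, then spot edge cases again
    return model
```

The iteration cap reflects the fact that in production this loop rarely ends on its own: new edge cases keep appearing as the input distribution shifts.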
Always Start with the Data
Let’s assume that we are training a segmentation model for fighter jets. Our primary challenge is to find enough labeled data to train a decent model. In this article, we will look at a technique called Human in the Loop Machine Learning, which lets us collect training data while our model is in production. It can also be used when the model is not yet in production but we have a very large unlabeled dataset. We could always take a brute-force approach and label every image in the dataset, but annotation is expensive, and Human in the Loop Machine Learning helps us avoid spending money on labeling data that might not improve our results.
As we already know, datasets need to be processed before being fed into an ML model. We need to ensure that the classes are balanced before we train a model on the dataset. We also need to make sure that the model can segment instances of every class regardless of the relative size of the object in the image. To check this, we split our dataset into parts based on rules such as the object's class or its relative size. These parts are known as dataset slices.
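A rule-based slicer can be very small. The sketch below assumes each labeled object is a dict with hypothetical `id`, `label`, `mask_area` and `image_area` keys; the class names `f16`/`f22` and the 1% / 10% size thresholds are illustrative assumptions, not values from this article.

```python
from typing import Any, Callable, Dict, List

Sample = Dict[str, Any]  # one labeled object: id, class label, mask area, image area


def relative_size(sample: Sample) -> float:
    """Fraction of the image covered by the object's mask."""
    return sample["mask_area"] / sample["image_area"]


# Hypothetical rules: one slice per class plus three object-size buckets.
# The class names and the 1% / 10% thresholds are illustrative assumptions.
SLICE_RULES: Dict[str, Callable[[Sample], bool]] = {
    "class=f16": lambda s: s["label"] == "f16",
    "class=f22": lambda s: s["label"] == "f22",
    "size=small": lambda s: relative_size(s) < 0.01,
    "size=medium": lambda s: 0.01 <= relative_size(s) < 0.10,
    "size=large": lambda s: relative_size(s) >= 0.10,
}


def build_slices(samples: List[Sample],
                 rules: Dict[str, Callable[[Sample], bool]] = SLICE_RULES) -> Dict[str, List[str]]:
    """Map each slice name to the ids of the samples matching its rule (slices may overlap)."""
    return {name: [s["id"] for s in samples if rule(s)] for name, rule in rules.items()}
```

Counting the ids in each `class=` slice also gives a quick check of class balance.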
While evaluating our model, we need to test it on images from each slice and, in every iteration, compare the model's metric score on each slice with the score from the previous iteration. Choosing a good performance metric is also important: we should spend some time making sure that the metric we use for training and evaluation reflects our business needs.
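The per-slice comparison between iterations might look like the following sketch. It assumes you already have a per-image metric (for example IoU) keyed by image id and a mapping from slice names to image ids; the `tolerance` of 0.01 is an arbitrary assumed threshold for what counts as a regression.

```python
from statistics import mean
from typing import Dict, List


def per_slice_scores(sample_scores: Dict[str, float],
                     slices: Dict[str, List[str]]) -> Dict[str, float]:
    """Average a per-image metric (for example IoU) over the images in each slice."""
    return {name: mean(sample_scores[i] for i in ids)
            for name, ids in slices.items() if ids}


def find_regressions(current: Dict[str, float], previous: Dict[str, float],
                     tolerance: float = 0.01) -> Dict[str, float]:
    """Slices whose score dropped by more than `tolerance` since the previous iteration,
    mapped to the (negative) change in score."""
    return {name: current[name] - previous[name]
            for name in current.keys() & previous.keys()
            if current[name] < previous[name] - tolerance}
```

Storing the output of `per_slice_scores` for every iteration gives you the `previous` dictionary for the next comparison.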
Why do ML models need Fine-tuning?
As mentioned earlier, Machine Learning models are only as good as the dataset they are trained on. In production use cases we might not have enough labeled data to train a model from scratch. In that case we use transfer learning to fine-tune a pre-trained model on our relatively small dataset. In most practical use cases, we start with a model pre-trained on a large open dataset such as ImageNet for image classification or COCO for object detection. In our use case, we can start from a model trained on PASCAL VOC 2012, a standard segmentation dataset, and fine-tune it on our Fighter Jet Dataset.
Go to Papers With Code and look for the latest state-of-the-art models for the semantic segmentation task. Pick the top three model architectures, fine-tune each of them on your dataset, compare the results and pick the best one. Once we have selected a model architecture and fine-tuned the model, we can deploy it to production.
Fantastic Edge Cases and Where to Find Them
Edge cases are rare situations that occur only when a very specific set of conditions is met. For example, an ice cream truck driving towards a self-driving car while sunlight shines directly into the car’s camera can be considered an edge case. We should try to collect as much data as possible to eliminate these edge cases before deploying our model in production. Companies like Tesla use computer simulations to generate synthetic data for such cases and use it to augment their datasets. But there is no foolproof way of finding all the edge cases and collecting data to eliminate them. This is where Human in the Loop Machine Learning shines.
By default, we train a model on whatever data we can manage to label in a short period of time. We then deploy that model to production and watch its results for anomalies. These anomalies can be spotted by periodically doing the following exercises.
Sampling-based Checking
A subset of the model's predictions is stored for review by human annotators. The predictions can be sampled at fixed time intervals, for example every hour or every half hour, or by count, for example every 100th prediction. Time-based sampling lets us evaluate the model’s performance across different times of the day, while count-based sampling keeps bursts of traffic represented, so we do not miss edge cases that appear when a large number of requests arrive within a very short time period.
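Both sampling strategies fit in one small class. This is a sketch under assumptions: the class name `PredictionSampler`, the in-memory `review_queue` and the injectable `clock` are illustrative choices, and a real deployment would push sampled predictions to persistent storage instead.

```python
import time
from typing import Any, Callable, List, Optional


class PredictionSampler:
    """Queue predictions for human review: every `every_n`-th prediction,
    and/or at most one prediction per `interval_s` seconds."""

    def __init__(self, every_n: Optional[int] = None, interval_s: Optional[float] = None,
                 clock: Callable[[], float] = time.monotonic):
        self.every_n = every_n
        self.interval_s = interval_s
        self.clock = clock                    # injectable so the sampler is easy to test
        self.count = 0
        self.last_sampled_at: Optional[float] = None
        self.review_queue: List[Any] = []

    def observe(self, prediction: Any) -> bool:
        """Record one prediction; return True if it was queued for review."""
        self.count += 1
        now = self.clock()
        by_count = self.every_n is not None and self.count % self.every_n == 0
        by_time = self.interval_s is not None and (
            self.last_sampled_at is None or now - self.last_sampled_at >= self.interval_s)
        if by_count or by_time:
            self.review_queue.append(prediction)
            if by_time:
                self.last_sampled_at = now    # restart the time window
            return True
        return False
```

For example, `PredictionSampler(every_n=100)` implements "every 100th prediction", and `PredictionSampler(interval_s=1800)` implements "every half hour".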
Model Confidence-based Checking
We can also send images on which our model has low confidence for human review. Such images are often edge cases that need to be handled appropriately. We will look at how to find more data similar to these edge cases in detail later in this article.
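One simple way to score a segmentation prediction's confidence is to average, over all pixels, the highest class probability. The sketch below assumes that heuristic and a 0.7 threshold; both are illustrative, and other scores such as per-pixel entropy or the margin between the top two classes are equally valid choices.

```python
from typing import Dict, List, Sequence


def image_confidence(pixel_probs: Sequence[Sequence[float]]) -> float:
    """Mean over pixels of the highest class probability (1.0 = fully confident)."""
    return sum(max(p) for p in pixel_probs) / len(pixel_probs)


def select_for_review(predictions: Dict[str, Sequence[Sequence[float]]],
                      threshold: float = 0.7) -> List[str]:
    """Ids of images whose confidence is below `threshold`, least confident first."""
    scores = {img_id: image_confidence(probs) for img_id, probs in predictions.items()}
    return sorted((i for i, s in scores.items() if s < threshold), key=scores.get)
```

Here each image's prediction is given as a list of per-pixel probability vectors, i.e. a flattened softmax output.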
Finally, we need to pay special attention to feedback from the customers who use our model’s predictions, since it can reveal edge cases that the techniques above failed to detect.
A human annotator then reviews the edge cases found this way and flags them as bugs. A data engineer can then find images similar to the flagged ones and prepare a dataset from them.
To make this possible, we store all of the model’s prediction results, together with the query images, in a data lake. Whenever we encounter an edge case, the data engineer can query the data lake for images similar to the one in question. This lets us roll out a new version of the model soon after we spot an edge case; we only have to wait for the annotation team to label the new dataset.
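At its simplest, the data lake only needs one record per prediction that links the query image, the predicted mask and the model version. The sketch below uses SQLite purely as a stand-in for a real data lake; the table name, columns and storage paths are hypothetical.

```python
import sqlite3
from typing import List


def open_prediction_store(path: str = ":memory:") -> sqlite3.Connection:
    """Create (if needed) a table that keeps every query image next to its prediction."""
    conn = sqlite3.connect(path)
    conn.execute("""
        CREATE TABLE IF NOT EXISTS predictions (
            image_uri     TEXT PRIMARY KEY,  -- where the query image is stored
            mask_uri      TEXT NOT NULL,     -- where the predicted mask is stored
            model_version TEXT NOT NULL,
            confidence    REAL NOT NULL,
            created_at    TEXT DEFAULT CURRENT_TIMESTAMP
        )""")
    return conn


def log_prediction(conn: sqlite3.Connection, image_uri: str, mask_uri: str,
                   model_version: str, confidence: float) -> None:
    """Store one prediction record."""
    conn.execute(
        "INSERT INTO predictions (image_uri, mask_uri, model_version, confidence) "
        "VALUES (?, ?, ?, ?)",
        (image_uri, mask_uri, model_version, confidence))
    conn.commit()


def low_confidence_images(conn: sqlite3.Connection, model_version: str,
                          threshold: float) -> List[str]:
    """Images a given model version was unsure about, least confident first."""
    rows = conn.execute(
        "SELECT image_uri FROM predictions "
        "WHERE model_version = ? AND confidence < ? ORDER BY confidence",
        (model_version, threshold))
    return [row[0] for row in rows]
```

Keeping the model version on every record is what lets the data engineer tell which iteration produced a given failure.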
Filtering data efficiently to find the required data
Now let's assume that we have deployed our model in production. It generates millions of predictions every day, and we store all of that data so we can later retrain the model and improve its performance. Suppose we spot an edge case and want to query the data lake for images similar to it. Finding such images in the data lake comes with its own set of challenges.
Primarily, we might have to check millions of images before finding a decent number that are visually similar to the image in question. To solve this problem, we use a clustering algorithm. We take a lightweight backbone like ResNet18 or ResNet34 to extract feature vectors from the images, and then cluster these pre-computed features with an algorithm such as k-means to group visually similar images together. A k-nearest-neighbor search over the same features can then retrieve the images closest to a specific edge case. The annotator can look at a couple of images from each cluster to decide which cluster is most likely to contain images similar to the one in question.
We can then rerun the clustering algorithm on the selected cluster, repeating this step until we have isolated the images that are visually similar to the image in question. Because each round only searches within the cluster picked in the previous round, this narrows down the candidates far faster than reviewing the images one by one.
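The cluster-and-drill-down procedure can be sketched in plain Python. This sketch assumes the feature vectors have already been extracted (for example with a ResNet18 backbone) and are keyed by image id. To keep it self-contained, it replaces the annotator's manual choice with an automatic stand-in: in each round it keeps the cluster whose mean is closest to the edge-case image's features. In practice you would likely use a library implementation such as scikit-learn's `KMeans`.

```python
import math
import random
from typing import Dict, List, Sequence

Vector = Sequence[float]


def kmeans(features: List[Vector], k: int, n_iter: int = 20, seed: int = 0) -> List[int]:
    """Plain Lloyd's k-means; returns a cluster index for every feature vector."""
    rng = random.Random(seed)
    centroids = [list(c) for c in rng.sample(features, k)]
    labels = [0] * len(features)
    for _ in range(n_iter):
        # Assignment step: each vector joins its nearest centroid.
        labels = [min(range(k), key=lambda j: math.dist(f, centroids[j])) for f in features]
        # Update step: move each centroid to the mean of its members.
        for j in range(k):
            members = [f for f, lab in zip(features, labels) if lab == j]
            if members:
                centroids[j] = [sum(col) / len(members) for col in zip(*members)]
    return labels


def drill_down(features: Dict[str, Vector], query: Vector,
               k: int = 4, target_size: int = 50) -> List[str]:
    """Cluster repeatedly, keeping the cluster closest to the edge-case image,
    until at most `target_size` candidates remain for a human to inspect."""
    ids = list(features)
    while len(ids) > target_size and len(ids) >= k:
        vecs = [features[i] for i in ids]
        labels = kmeans(vecs, k)

        def distance_to_query(j: int) -> float:
            # Stand-in for the annotator's choice: how close is cluster j's mean to the query?
            members = [v for v, lab in zip(vecs, labels) if lab == j]
            if not members:
                return math.inf
            centre = [sum(col) / len(members) for col in zip(*members)]
            return math.dist(centre, query)

        best = min(range(k), key=distance_to_query)
        kept = [i for i, lab in zip(ids, labels) if lab == best]
        if len(kept) == len(ids):
            break  # clustering could not split the candidates any further
        ids = kept
    return ids
```

The returned ids are the candidates a human would then review, or feed into a nearest-neighbor search against the edge-case image.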
Efficiently labeling the data
To label the dataset we will have to use a data labeling tool. There are free alternatives worth exploring, but some of them have a steep learning curve and need a lot of technical know-how just to get started. I would recommend a paid tool with good support, which comes in handy when labeling relatively large datasets. These paid services usually have a free tier that you can use if you are a developer who is just testing new annotation tools. Kili Technology also offers annotation services, so consider giving it a shot.
Having said that, let's now focus on some specific techniques that can speed up our annotation process. Human annotators are good at detecting and labeling all kinds of images, but they are expensive. In many cases, the data that we want labeled can be pre-labeled automatically with Machine Learning techniques, leaving the annotators to review and correct the results.
Using Unsupervised Techniques
When we are working on a totally new problem, like 3D segmentation of plants in a video stream, we might not have any pre-trained model that can generate pseudo labels for our dataset. In that case we would have to label every frame of the video manually, but some unsupervised techniques can help. For example, if we are training a model to detect and segment different types of fighter aircraft, it is hard for any pre-trained model to generate good segmentation masks for this dataset. Here we can use a technique like superpixels, which groups neighboring pixels of similar color into small regions, to generate segmentations quickly.
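To show how superpixels are produced, here is a simplified version of SLIC (Simple Linear Iterative Clustering), one common superpixel algorithm. It is an illustrative sketch under strong assumptions: a grayscale image given as nested lists, no gradient-based seed adjustment and no connectivity post-processing. For real images you would use a library implementation such as `skimage.segmentation.slic` from scikit-image.

```python
import math
from typing import List


def simple_slic(image: List[List[float]], n_segments: int,
                compactness: float = 10.0, n_iter: int = 10) -> List[List[int]]:
    """Simplified SLIC superpixels for a grayscale image (rows of intensities).
    Returns a label map of the same shape; each label is one superpixel."""
    h, w = len(image), len(image[0])
    step = max(1, int(math.sqrt(h * w / n_segments)))  # grid spacing S
    # Cluster centres are [intensity, y, x], seeded on a regular grid.
    centers = [[float(image[y][x]), float(y), float(x)]
               for y in range(min(step // 2, h - 1), h, step)
               for x in range(min(step // 2, w - 1), w, step)]

    def distance(c: List[float], y: int, x: int) -> float:
        # Intensity difference combined with spatial distance scaled by compactness / S.
        d_color = image[y][x] - c[0]
        d_space = math.hypot(y - c[1], x - c[2])
        return math.hypot(d_color, compactness * d_space / step)

    labels = [[-1] * w for _ in range(h)]
    for _ in range(n_iter):
        best = [[math.inf] * w for _ in range(h)]
        # Assignment: each centre only competes for pixels in a window of +/- S around it.
        for k, c in enumerate(centers):
            cy, cx = int(c[1]), int(c[2])
            for y in range(max(0, cy - step), min(h, cy + step + 1)):
                for x in range(max(0, cx - step), min(w, cx + step + 1)):
                    d = distance(c, y, x)
                    if d < best[y][x]:
                        best[y][x] = d
                        labels[y][x] = k
        # Pixels that no window reached go to the nearest centre overall.
        for y in range(h):
            for x in range(w):
                if best[y][x] == math.inf:
                    labels[y][x] = min(range(len(centers)),
                                       key=lambda j: distance(centers[j], y, x))
        # Update: move every centre to the mean intensity and position of its pixels.
        sums = [[0.0, 0.0, 0.0, 0] for _ in centers]
        for y in range(h):
            for x in range(w):
                s = sums[labels[y][x]]
                s[0] += image[y][x]
                s[1] += y
                s[2] += x
                s[3] += 1
        for k, s in enumerate(sums):
            if s[3]:
                centers[k] = [s[0] / s[3], s[1] / s[3], s[2] / s[3]]
    return labels
```

The `compactness` parameter trades color similarity against spatial regularity: higher values give more grid-like superpixels. Because superpixels over-segment the image, an annotator typically selects and merges the regions belonging to an aircraft instead of drawing its outline by hand.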
Using Supervised Techniques
When a dataset is relatively easy to label, for example a dataset for training a model to segment different makes and models of cars, we can generate the segmentations automatically. Many open-source datasets contain cars, so we can find a state-of-the-art open-source model trained on one of them and use it to segment the cars in our dataset. The human annotators can then quickly assign a make and model to each segmented car and fix the masks where the model missed parts of the car, like rearview mirrors and antennae.
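The hand-off from the pre-trained model to the annotators can be as simple as turning its predictions into pre-filled tasks. This sketch assumes a hypothetical per-instance prediction format (a dict with `label`, `score` and `mask` keys, after mapping class ids to names) and an assumed minimum score of 0.5; adapt both to whatever your model and labeling tool actually use.

```python
from typing import Any, Dict, List


def build_annotation_tasks(image_id: str, instances: List[Dict[str, Any]],
                           target_label: str = "car",
                           min_score: float = 0.5) -> List[Dict[str, Any]]:
    """Turn a pre-trained model's instance predictions into pre-filled annotation tasks.
    The mask is kept as a draft; make and model are left blank for a human."""
    tasks = []
    for n, inst in enumerate(instances):
        if inst["label"] != target_label or inst["score"] < min_score:
            continue  # ignore other classes and weak detections
        tasks.append({
            "task_id": f"{image_id}-{n}",
            "draft_mask": inst["mask"],  # annotator fixes missed mirrors, antennae, ...
            "model_score": inst["score"],
            "make": None,                # assigned by the annotator
            "model": None,
            "status": "needs_review",
        })
    return tasks
```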
Using our own model from the last Iteration
Once we have trained our model on one iteration of our dataset and need to add new data to the training set, we can simply use our latest model to generate pseudo labels for it. This process becomes easier as our model improves with each iteration.
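A common way to use the latest model is to split its output by confidence: confident predictions become pseudo labels that annotators only spot-check, and the rest go to full review. In this sketch the model is any hypothetical callable returning `(mask, confidence)` for an image id, and the 0.9 acceptance threshold is an assumption.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Prediction = Tuple[object, float]  # (predicted mask, confidence score)


def pseudo_label(model: Callable[[str], Prediction], image_ids: Sequence[str],
                 accept_threshold: float = 0.9) -> Tuple[Dict[str, object], List[str]]:
    """Label new images with the latest model; confident predictions become
    pseudo labels, the rest are queued for full human annotation."""
    pseudo_labels: Dict[str, object] = {}
    needs_review: List[str] = []
    for image_id in image_ids:
        mask, confidence = model(image_id)
        if confidence >= accept_threshold:
            pseudo_labels[image_id] = mask
        else:
            needs_review.append(image_id)
    return pseudo_labels, needs_review
```

As the model improves from one iteration to the next, more images clear the threshold and the `needs_review` list shrinks, which is why labeling gets cheaper over time.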
Retraining the model
Retraining the model is the next logical step in the pipeline. I would recommend trying different hyperparameters before settling on a specific set for training. In general, early stopping helps avoid overfitting when training relatively small models. For large models, I would recommend learning rate scheduling, for example decaying the learning rate exponentially after each epoch.
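Both techniques are small enough to write out. The sketch below is framework-agnostic: `EarlyStopping` watches the validation loss, and `exponential_lr` computes lr = initial_lr × gamma^epoch. The `patience`, `min_delta` and `gamma` values are yours to tune. If you train with PyTorch, `torch.optim.lr_scheduler.ExponentialLR` provides the same decay schedule for an optimizer.

```python
class EarlyStopping:
    """Signal a stop when the validation loss has not improved by at least
    `min_delta` for `patience` consecutive epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # improvement: reset the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def exponential_lr(initial_lr: float, gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` epochs of exponential decay."""
    return initial_lr * gamma ** epoch
```

In a training loop you would call `stopper.step(val_loss)` once per epoch and break when it returns `True`, keeping the checkpoint from the best epoch.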
I’d like to draw your attention to catastrophic forgetting, where a model loses performance on data it previously handled well after being trained on new data. It is very common and should not be overlooked. When training with a human in the loop, it’s very important to test the model properly and make sure it performs well on all the slices of the dataset. Every new addition to the dataset should be treated as a separate slice, and evaluation metrics must be calculated on each slice as usual to make sure that the model is not overfitting to a specific slice of data.
Re-evaluating the model
Once we have retrained the model, we evaluate it again to confirm its performance. We should save as much metadata as possible with the model weights so that we can easily explain the model’s behaviour later. I would recommend spending some time building a Model Card for every iteration of the model. We can then deploy the model in production and start looking for more edge cases.
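A lightweight way to keep that metadata is a small structured record serialized next to each checkpoint. The fields below are an illustrative assumption of what such a card might hold for this workflow (per-slice metrics, the slices used, known limitations); a full Model Card usually also documents intended use, training data provenance and ethical considerations.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Dict, List


@dataclass
class ModelCard:
    """Metadata stored next to the weights of one model iteration (fields are illustrative)."""
    model_version: str
    base_model: str                  # the pre-trained checkpoint we fine-tuned from
    training_slices: List[str]       # dataset slices used in this iteration
    slice_metrics: Dict[str, float]  # evaluation score on every slice
    known_limitations: List[str] = field(default_factory=list)


def to_json(card: ModelCard) -> str:
    """Serialize the card so it can be written next to the weight file."""
    return json.dumps(asdict(card), indent=2, sort_keys=True)


def from_json(text: str) -> ModelCard:
    """Rebuild a card from its JSON form."""
    return ModelCard(**json.loads(text))
```

Writing `to_json(card)` to a file such as `model_card.json` (a hypothetical name) beside each checkpoint keeps the explanation of every iteration with the weights it describes.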
To conclude, Human in the Loop Machine Learning is a very important tool. It helps us build Machine Learning models interactively, so that they not only maintain their performance but also improve over time. It is a very startup-friendly approach that bigger companies can also use to train their models with limited resources, and it is especially useful when data is scarce and hard to find.