One key challenge in modern AI applications is maintaining high levels of performance with production models. Engineers spend an increasing amount of time working on the production life cycle of their models. In the paper on the "Overton" system, Apple engineers state that "In the life cycle of many production machine-learning applications, maintaining and improving deployed models is the dominant factor in their total cost and effectiveness – much greater than the cost of de novo model construction". To ensure that AI systems meet business needs, organizations across the board have increased the effort they put into monitoring. Monitoring and improving machine learning models requires the ability to:
1 - monitor your data
2 - supervise data labeling
In this article, we will first discuss why monitoring is important in machine learning and which challenges come with it. Then we will cover the technical aspects of monitoring data.
Why is monitoring essential in Machine Learning?
After training and testing a model in an experimental setup, it is time to put it to work in production. The change from an experimental to a real-life environment will affect the performance of your model. It is important to remember that a model is developed to serve a business purpose and to generate value for an organization. It must therefore answer critical business needs:
We want models to be stable and available at all times
We want models to stay relevant in production. They must therefore maintain a high level of performance (in some applications, 98% performance might not be enough!)
If you cannot guarantee to the business that your machine learning pipeline will always perform above X% at all times, the entire data science project might not make it to production.
What are the challenges of monitoring in Machine Learning?
You cannot use performance metrics
The first idea that comes to mind when monitoring a machine learning model is to track performance metrics in "real time". Unfortunately, most of the time this won't be possible, for the following reasons:
The model you are using might impact the observations you can make, because an action is taken after each prediction
There can be a long latency between a prediction and its ground truth label becoming available
Since it is most likely not possible to track the performance of models in real time, it is necessary to find heuristics that identify potential problems during the production life cycle of your models. Otherwise, the only way to perceive a decrease in your model's performance is through a direct impact on the business, such as a drop in income.
Monitoring models requires the collaboration of multiple teams
Once you are able to set alerts on critical parts of your machine learning pipeline, you need to add humans in the loop to perform relabelling and further train your model. At Kili, we have seen that incorporating labeling during the production life cycle of models allows them to maintain a high level of performance.
Since model performance decays with time in real-life settings, it is necessary to keep humans in the loop throughout production. Humans are essential to:
Validate predictions with low confidence
Relabel data in order to retrain models
Prediction validation and relabelling need to be performed by people with expertise in the task you wish to automate. For example, in the case of tumor detection, specialized doctors should annotate the new data. To keep costs low, it is critical to incorporate humans into the inference pipeline using the right tools. At Kili, we created our platform to ease the interaction between data scientists and labelers as much as possible, making this process efficient and cutting the cost of annotation by 5x.
The next part of this article presents how to monitor your data and detect possible changes that can affect the performance of your models.
There are two main reasons why the performance of an AI application can decrease:
1. There is an issue in the data gathering/processing pipeline
This means that the data provided to the model is not as expected. This can occur because an update of an API you use to collect data includes breaking changes. Another example is the failure of a sensor that was gathering images and suddenly starts sending dark frames instead of actual pictures.
2. There has been a change in the nature/distribution of the data
This second point is what makes every machine learning model's performance decay with time. Since change is inevitable, it is most likely that the data you used to train a machine learning model ends up being less and less relevant with time.
To illustrate, consider how email spam and phishing detection are particularly subject to changes in behaviour in the data. Scammers usually adapt to the current context to trap their victims. Recently, with the pandemic, a surge in scam emails exploiting Covid-19 has been recorded. Such a change in scammers' behaviour can go unnoticed by machine learning models. In this situation, the earlier we detect the shift in phishing emails, the better we can protect users.
Unlike classic software development, it is not possible to deploy tests that cover all use cases of a machine learning model. For this reason, it is mandatory to track a machine learning model's behaviour in real time.
How to prevent breaking data changes?
The first thing to do is to be prepared for changes in the data gathering and processing pipeline. Listing all libraries and third-party applications you rely on to gather data is good practice. This will ease the task of looking up changelogs and help you stay aware of any change in data format or deprecated functions.
Simple statistics computed on the inputs of your models can alert you to breaking data changes. To prevent them, you can record and compute the following metrics on the input data:
% of missing values
univariate statistics: mean, standard deviation, quantiles
the average number of words in texts
average brightness and contrasts levels in images
Any abrupt changes in those statistics should alert you to problems affecting the serving data.
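As a minimal sketch of this idea, the snippet below computes the percentage of missing values and the univariate statistics mentioned above on a batch of tabular inputs. The function name `input_stats` and the example columns are hypothetical, not part of any specific library:

```python
import numpy as np
import pandas as pd

def input_stats(df: pd.DataFrame) -> dict:
    """Compute simple health statistics on a batch of model inputs."""
    numeric = df.select_dtypes(include=[np.number])
    return {
        # overall % of missing cells across all columns
        "pct_missing": float(df.isna().mean().mean() * 100),
        "mean": numeric.mean().to_dict(),
        "std": numeric.std().to_dict(),
        "quantiles": numeric.quantile([0.25, 0.5, 0.75]).to_dict(),
    }

# Example batch with one missing value
batch = pd.DataFrame({
    "age": [34, 51, None, 29],
    "income": [42.0, 55.5, 61.2, 38.7],
})
stats = input_stats(batch)
```

In practice you would compute these statistics per time window and alert when they deviate sharply from the values recorded on the training data.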
How to prevent behavioural changes?
Behavioural changes can occur because either the concept you are trying to predict or the data itself has changed over time. This type of change is harder to predict because of its stochastic nature. To detect changes in data distribution we can proceed by monitoring both the output and the input of the machine learning model.
One key metric to monitor the outputs is the % of predicted labels. Choose a time window (minute, hour, day) to perform the aggregation of your model's predictions. Then you should compare these metrics with your prior belief (obtained by analysis of training data). A sudden change in the distribution of labels is a strong indicator of an issue. Let's take the simple example of spam detection. If the normal rate of email identified as spam is 3% and goes up to 20% in a week, there might be reasons to investigate the change in the behaviour of the model.
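The spam example above can be sketched as a simple check comparing the share of a predicted label in the current window against the prior observed on training data. The function names and the 5-point tolerance are illustrative assumptions, not a standard API:

```python
from collections import Counter

def label_share(predictions, label):
    """Fraction of predictions equal to `label` in the current window."""
    counts = Counter(predictions)
    return counts[label] / len(predictions)

def drift_alert(current_share, prior_share, tolerance=0.05):
    """Flag when the observed share deviates from the training prior."""
    return abs(current_share - prior_share) > tolerance

# Hypothetical hourly window of spam-classifier outputs: 20% spam
window = ["ham"] * 80 + ["spam"] * 20
share = label_share(window, "spam")
alert = drift_alert(share, prior_share=0.03)  # prior: 3% spam in training data
```

A statistical test (e.g. chi-square on label counts) could replace the fixed tolerance for a more principled threshold.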
The other key measurement to take into account is the change in the distribution of the data, usually referred to as data drift. The goal is to measure how different today's inputs are from the ones used to train the model. Quantifying this drift is not an easy task, but we will see which tools can be used to perform those measurements. We can divide the techniques used to measure drift into two categories:
Using a distance between distributions
Using a classifier
Using distance between distributions
To try to quantify the change in distribution between data from different time intervals, we can use some of the distances employed in algorithms such as variational autoencoders or GANs (generative adversarial networks), mainly two distance measures:
Kullback-Leibler (KL) divergence
Sliced Wasserstein distance
Both quantify how "far" a distribution A is from a distribution B. The latter, the Wasserstein distance (i.e. the Earth Mover's distance), has two benefits compared to the KL divergence. First, it is a distance and not a divergence, which means that dist(A, B) = dist(B, A). Also, the Wasserstein distance has the nice property of remaining correlated with how far apart two distributions are, whereas the KL divergence can take an infinite value when the distributions do not overlap.
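The non-overlap issue can be demonstrated in a few lines with `scipy`: for two distributions with disjoint support, the KL divergence is infinite while the Wasserstein distance stays finite and meaningful. This is only a toy illustration on discrete distributions:

```python
import numpy as np
from scipy.stats import entropy, wasserstein_distance

# Two non-overlapping distributions over a support of four points
p = np.array([1.0, 0.0, 0.0, 0.0])  # all mass at position 0
q = np.array([0.0, 0.0, 0.0, 1.0])  # all mass at position 3

# KL(p || q) is infinite: q is 0 wherever p puts mass
kl = entropy(p, q)

# Wasserstein distance between the underlying samples: cost of
# moving the mass from position 0 to position 3
w = wasserstein_distance([0.0], [3.0])
```

If positions 0 and 3 moved further apart, `w` would grow accordingly, while `kl` would stay stuck at infinity; that is the correlation property mentioned above.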
The idea here is that we can use those measures to quantify how different the serving data is from the training data. We can apply those distances to raw data or to derived features. For example, we can process texts into multiple features (word counts, tf-idf, ...), compute those distances on the individual features, and average them to obtain an overall distance between two sets of data.
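The per-feature averaging just described can be sketched as follows, using `scipy.stats.wasserstein_distance` on each feature column. The function name `feature_drift` and the synthetic Gaussian data are assumptions made for the example:

```python
import numpy as np
from scipy.stats import wasserstein_distance

def feature_drift(train: np.ndarray, serve: np.ndarray) -> float:
    """Average 1-D Wasserstein distance across feature columns."""
    dists = [
        wasserstein_distance(train[:, j], serve[:, j])
        for j in range(train.shape[1])
    ]
    return float(np.mean(dists))

rng = np.random.default_rng(0)
train = rng.normal(0.0, 1.0, size=(1000, 3))
same = rng.normal(0.0, 1.0, size=(1000, 3))     # same distribution
shifted = rng.normal(2.0, 1.0, size=(1000, 3))  # mean shifted by 2

low = feature_drift(train, same)      # small: distributions match
high = feature_drift(train, shifted)  # close to 2: the mean shift
```

Tracking this scalar over time windows, and alerting when it exceeds a threshold calibrated on held-out training data, gives a simple drift monitor.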
Using a classifier
To quantify data drift we can also use a classifier. The concept is similar to the discriminator in GANs. The idea is to create a dataset composed of both training and serving data, each sample labelled with the class it belongs to (for example: 0 for training, 1 for serving). This dataset is then split into a training and a testing set. After training a model to classify samples as training or serving data, we can compute the following score on the testing set, where E is the test error: 2(1 - 2E)
If both sets of data are similar, this metric should be close to 0 (an error of 0.5 corresponds to random guessing). Conversely, if the model finds an underlying pattern that differentiates the training data from the serving data, the error decreases and the score increases.
While this approach can be applied to diverse types of data, it comes with the cost of choosing the right type of classifier, which represents additional work. One principle is to keep the model simple, to avoid having to monitor the model that monitors your models.
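A minimal sketch of this classifier two-sample test, keeping the model simple as suggested (a logistic regression), might look like this. The function name `drift_score` and the synthetic data are assumptions for illustration:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def drift_score(train_data, serve_data, seed=0):
    """Label each sample by origin (0 = training, 1 = serving), fit a
    simple classifier, and return 2 * (1 - 2 * E) where E is the
    test error. Score near 0 means no detectable drift."""
    X = np.vstack([train_data, serve_data])
    y = np.array([0] * len(train_data) + [1] * len(serve_data))
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.3, random_state=seed, stratify=y
    )
    clf = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
    error = 1.0 - clf.score(X_te, y_te)
    return 2 * (1 - 2 * error)

rng = np.random.default_rng(0)
a = rng.normal(0, 1, size=(500, 4))
b = rng.normal(0, 1, size=(500, 4))  # same distribution: score near 0
c = rng.normal(3, 1, size=(500, 4))  # shifted: easy to separate, score near 2

no_drift = drift_score(a, b)
drift = drift_score(a, c)
```

The logistic regression here is deliberately simple; any classifier could be used, at the cost of more tuning work.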
Managing data throughout the inference process
Logs in software engineering allow developers to track events and errors occurring in their applications. Records in logging are usually identified by their unique id and timestamp. Logs are a key element to debug and monitor an application. With them, developers can compute metrics to monitor the health of their applications.
In machine learning, we want to be able to perform aggregations at the data level. This requires organizing your data just like logs. We want to be able to identify data throughout the inference process: from raw data to processed features to predictions. Here is a list of the elements required to monitor a machine learning model:
Raw data (with a unique id and timestamp)
Processed data (model inputs)
Predictions (model outputs)
Confidence in predictions (either using simple softmax or more advanced uncertainty quantification techniques)
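One possible shape for such a record, written as one JSON line per inference with a unique id, a timestamp, and a simple softmax-style confidence, is sketched below. The function name `log_inference` and the record fields are assumptions, not a standard logging API:

```python
import io
import json
import time
import uuid

def log_inference(features, prediction, probabilities, log_file):
    """Append one inference record so later aggregations can join
    inputs, outputs and confidence by id and timestamp."""
    record = {
        "id": str(uuid.uuid4()),
        "timestamp": time.time(),
        "features": features,
        "prediction": prediction,
        # simple confidence: the top softmax probability
        "confidence": max(probabilities),
    }
    log_file.write(json.dumps(record) + "\n")
    return record

# Usage with an in-memory buffer standing in for a real log sink
buffer = io.StringIO()
rec = log_inference({"word_count": 120}, "spam", [0.1, 0.9], buffer)
```

With records in this shape, the window aggregations discussed earlier (label shares, input statistics, drift scores) become simple group-bys on the log stream.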
What to do once concept or data drift is detected?
We have talked about how to detect drift in both the data and the behaviour of a model, but we haven't yet discussed the actions we can take to solve problems and keep the model running.
The only real remediation to concept/data drift is retraining. Even if you use ensemble learning or online learning, it is critical to obtain ground truth in order to retrain your model on newly labelled data.
Regularly retraining your models requires incorporating the labelling process into the inference pipeline of your application. Using Kili, you can easily relabel batches of data using trained labellers. The powerful Python API we have developed also allows data scientists to easily integrate active learning into the loop, further improving the efficiency of the labelling.
Humans in the loop are also mandatory if you want to achieve extremely high levels of performance. In that case, you must have humans to correct or validate predictions in which you have low confidence.
Prevent breaking changes in the data gathering/processing pipeline:
List all of the dependencies your model relies on to get its input data
libraries, third-party applications, other services in the company
Monitor the model's inputs using simple statistics. Abrupt changes and extreme values are strong indicators of an issue
Prevent model decay due to a change in the nature of the data and concept:
Monitor the outputs of your model (distribution of predicted classes, number of objects detected)
Compare distributions of the training and serving data
Use a distance measure (EM distance, KL divergence)
Use a classifier
Always take business metrics into account as well. For example, a drop in sales income can indicate a malfunctioning model.
If you can, compare models with each other in real time. In an MLOps environment, multiple models can be deployed at once, and their agreement level can be compared as an additional indicator of model behaviour.
At Kili, we are developing a platform to help organizations manage and supervise their data. We make it easy to manage training and serving data by associating ids with both data and labels/predictions. If your predictions consist of multiple objects (in image segmentation, for example), you can have ids down to the instance level! This proves useful for performing future aggregations. You can also compare multiple models using consensus, or evaluate models against ground truth. And last but not least, it is easy to plug into your inference pipeline to perform batch labelling during production, allowing for retraining or online learning. Learn more about Kili on our website.
In this article, we have learned how to monitor our models to detect issues that can arise from the data. For those interested in this topic who wish to go further, I recommend a list of very interesting documents to read.