How To Monitor Machine Learning Models In Production?
One key challenge in modern AI applications is maintaining high levels of performance with production models. Engineers spend an increasing amount of time working on the production life cycle of their models. In the paper on the system “Overton”, Apple engineers state that “In the life cycle of many production machine-learning applications, maintaining and improving deployed models is the dominant factor in their total cost and effectiveness–much greater than the cost of de novo model construction“. To ensure that AI systems meet business needs, the effort put into monitoring has increased across all organizations. Monitoring and improving machine learning models requires the ability to:
1 – monitor your data
2 – supervise data labeling
In this article, we will first discuss why monitoring is important in machine learning and what challenges are related to it. Then we will talk about the technical aspects of monitoring data.
Why is monitoring essential in Machine Learning?
After training and testing a model in an experimental setup, it is time to put it to work in production. The change from an experimental to a real-life environment will affect the performance of your model. It is important to remember that a model is developed to serve a business purpose and to generate value for an organization. It must therefore answer critical business needs:
- We want models to be stable and available at all times
- We want models to stay relevant in production. They must therefore maintain a high level of performance (in some applications, even 98% accuracy might not be enough!)
If you cannot assure the business that your machine learning pipeline will always perform above X% at all times, the entire data science project might not make it to production.
What are the challenges of monitoring in Machine Learning?
You cannot use performance metrics
The first idea that comes to mind to monitor a machine learning model is to track performance metrics in “real time”. Unfortunately, most of the time this won’t be possible, for the following reasons:
- Model intervention
- the model you are using might influence the observations you can make, because an action is taken after each prediction
- Long latency between a prediction and its ground truth label becoming available
Since it is most likely not possible to track the performance of models in real time, it is necessary to find heuristics to identify potential problems during the production lifecycle of your models. Otherwise, the only way to perceive a decrease in the performance of your model is through a direct impact on the business, such as a decrease in income.
Monitoring models requires the collaboration of multiple populations
Once you are able to set alerts on critical parts of your machine learning pipeline, you need to add humans in the loop to perform relabeling and further train your model. At Kili, we have experienced that incorporating labeling during the production lifecycle of models allows them to maintain a high level of performance.
Since, in real life, model performance decays with time, it is necessary to keep humans in the loop throughout production. Humans are essential to:
- Validate predictions with low confidence
- Relabel data in order to retrain models
Prediction validation and relabeling need to be performed by people who have expertise in the task that you wish to automate. For example, in the case of tumor detection, specialized doctors should perform the annotation of new data. In order to keep costs low, it is critical to incorporate humans in the inference pipeline using the right tools. At Kili, we created the platform to ease the interaction between data scientists and labelers as much as possible, making this process efficient and cutting the cost of annotation by 5x.
The next part of this article presents how to monitor your data and detect possible changes that can affect the performance of your models.
There are two main reasons why the performance of an AI application can decrease:
1. There is an issue in the data gathering/processing pipeline
This means that the data provided to the models is not as expected. This can occur because an update of an API you are using to collect data includes breaking changes. Another example is the failure of a sensor that was gathering images and starts sending dark frames instead of actual pictures.
2. There has been a change in the nature/distribution of the data
This second point is what makes every machine learning model’s performance decay with time. Since change is inevitable, the data you used to train a machine learning model will most likely become less and less relevant over time.
To illustrate, we can look at how email spam and phishing detection are subject to changes in behavior in the data. Scammers usually adapt to the current context to trap their victims. Recently, with the pandemic, a surge in scam emails exploiting Covid-19 has been recorded. This change in behavior from scammers can go unseen by machine learning models. In this situation, the earlier we can detect this shift in phishing emails, the better we can protect users.
Unlike in classic software development, it is not possible to deploy tests that cover all use cases of a machine learning model. For this reason, it is mandatory to track a machine learning model’s behavior in real time for monitoring.
How to prevent breaking data changes?
The first thing to do is to be prepared for changes in the data gathering and processing pipeline. Listing all libraries and third-party applications on which you rely to gather data is good practice. This will ease the task of looking up changelogs and allow you to stay aware of any change in the data format or any deprecated functions.
Simple statistics computed on the input of your models can alert you to breaking data changes. To prevent breaking data, you can record and compute the following metrics on the input data:
- % of missing values
- univariate statistics: mean, standard deviation, quantiles
- average number of words in texts
- average brightness and contrasts levels in images
Any abrupt change in those statistics should alert you to problems affecting the serving data.
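The statistics above can be sketched in a few lines of NumPy. This is a minimal illustration, not a fixed API: the function name and the convention of marking missing values with NaN are assumptions for the example.

```python
import numpy as np

def input_health_stats(batch):
    """Compute simple health statistics on a batch of model inputs.

    `batch` is a 2-D array of shape (n_samples, n_features), where
    NaN marks a missing value.
    """
    batch = np.asarray(batch, dtype=float)
    return {
        # percentage of missing entries across the whole batch
        "pct_missing": float(np.isnan(batch).mean() * 100),
        # univariate statistics per feature, ignoring missing values
        "mean": np.nanmean(batch, axis=0),
        "std": np.nanstd(batch, axis=0),
        "quantiles": np.nanquantile(batch, [0.25, 0.5, 0.75], axis=0),
    }

# Example: a batch of three samples with one missing value out of six entries
stats = input_health_stats([[1.0, 2.0], [3.0, np.nan], [5.0, 6.0]])
print(round(stats["pct_missing"], 2))  # -> 16.67
```

Logging these values per batch and alerting on sudden jumps is usually enough to catch a broken sensor or a changed API payload early.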
How to prevent behavioral changes?
Behavioral changes can occur because either the concept you are trying to predict or the data itself has changed over time. This type of change is harder to predict because of its stochastic nature. To detect changes in data distribution we can proceed by monitoring both the output and the input of the machine learning model.
One key metric to monitor the outputs is the distribution of predicted labels. Choose a time window (minute, hour, day) over which to aggregate your model’s predictions. Then compare these metrics with your prior belief (obtained by analyzing the training data). A sudden change in the distribution of labels is a strong indicator of an issue. Let’s take the simple example of spam detection. If the normal rate of emails identified as spam is 3% and goes up to 20% in a week, there is good reason to investigate the change in the behavior of the model.
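A minimal sketch of this check: aggregate the predictions of a window and flag any label whose observed rate deviates from its prior by more than a tolerance. The function name and the 5-point tolerance are illustrative choices, not a prescribed setup.

```python
from collections import Counter

def label_rate_alerts(window_predictions, prior_rates, tolerance=0.05):
    """Compare the label distribution over a time window to prior rates.

    `prior_rates` maps label -> expected fraction (from training data).
    Returns the labels whose observed rate deviates from the prior by
    more than `tolerance`, together with the observed rate.
    """
    n = len(window_predictions)
    counts = Counter(window_predictions)
    alerts = {}
    for label, expected in prior_rates.items():
        observed = counts.get(label, 0) / n
        if abs(observed - expected) > tolerance:
            alerts[label] = observed
    return alerts

# 20% of this window is predicted spam, against a 3% prior -> alert
preds = ["spam"] * 20 + ["ham"] * 80
alerts = label_rate_alerts(preds, {"spam": 0.03, "ham": 0.97})
print(alerts)  # -> {'spam': 0.2, 'ham': 0.8}
```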
The other key measurement to take into account is the change in the distribution of the data, usually referred to as data drift. The goal is to measure how different today’s inputs are from the ones used to train the model. It is not an easy task to quantify this drift, but we will see what tools can be used to perform those measurements. We can divide the techniques used to measure the drift into two categories:
- Distances between distributions
- Classifier-based methods
Using distances between distributions
To quantify the change in distribution between data from different time intervals, we can use some of the distances employed in algorithms such as variational autoencoders or GANs (generative adversarial networks). Mainly:
- KL divergence
- Wasserstein distance
- Sliced Wasserstein distance
These distances quantify how “far” a distribution A is from a distribution B. The latter, the Wasserstein distance (i.e. Earth Mover’s distance), has two benefits in comparison to the KL divergence. First, it is a true distance and not a divergence, which means that dist(A,B) = dist(B,A). Also, the Wasserstein distance has the nice property of always being correlated with how far a distribution is from another, whereas the KL divergence can take an infinite value if the two distributions don’t overlap.
The idea here is that we can use those measures to quantify how different the serving data is from the training data. We can apply those distances to raw data or to derived features. For example, we can process texts into multiple features (word counts, tf-idf, …), compute those distances on the individual features, and average them to obtain an overall distance between two sets of data.
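Both distances are available through SciPy. The sketch below compares a single feature between training and serving windows; the synthetic Gaussians (serving data shifted by 0.7) and the binning grid are assumptions made for the example.

```python
import numpy as np
from scipy.stats import entropy, wasserstein_distance

rng = np.random.default_rng(0)
train_feature = rng.normal(loc=0.0, scale=1.0, size=5000)
serving_feature = rng.normal(loc=0.7, scale=1.0, size=5000)  # drifted

# The Wasserstein (EM) distance works directly on the raw samples
w = wasserstein_distance(train_feature, serving_feature)

# The KL divergence needs discretized distributions: bin both samples
# on a shared grid, and add a small epsilon to avoid division by zero
bins = np.linspace(-5, 5, 50)
p, _ = np.histogram(train_feature, bins=bins, density=True)
q, _ = np.histogram(serving_feature, bins=bins, density=True)
eps = 1e-9
kl = entropy(p + eps, q + eps)  # KL(p || q); note it is not symmetric

print(f"Wasserstein: {w:.2f}, KL: {kl:.2f}")
```

Since the two Gaussians differ only by a 0.7 shift, the Wasserstein distance comes out close to 0.7, which illustrates its nice interpretability on real-valued features.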
Using a classifier
To quantify data drift we can also use a classifier. The concept is similar to the discriminator in GANs. The idea is to create a dataset composed of both training and serving data. Both types of data are labeled with the class they belong to (for example: 0 for training, 1 for serving). This dataset is then split into a training and a testing set. After training the model to classify samples as training or serving data, we can compute the following drift metric from the classification error E on the testing set: 2(1 - 2E)
If both sets of data are similar, this metric should be close to 0 (the classifier cannot do better than random guessing, so E ≈ 0.5). Conversely, if the model finds an underlying pattern that differentiates the training data from the serving data, the error decreases and the metric increases.
While this approach can be applied to diverse types of data, it comes with the cost of choosing the right type of classifier, which represents additional work. One principle is to keep the model simple, to avoid having to monitor the model that monitors models.
Managing data throughout the inference process
Logs in software engineering allow developers to track events and errors occurring in their application. Records in logging are usually identified by their unique id and timestamp. Logs are a key element to debug and monitor an application. With them developers can compute metrics to monitor the health of their application.
In machine learning, we want to be able to perform aggregation at the data level. This requires organizing your data just like logs. We want to be able to identify data throughout the inference process: from raw data to processed features to predictions. Here is a list of the elements required to monitor a machine learning model:
- Raw data
- Processed data (model inputs)
- Predictions (model outputs)
- Confidence in predictions (either using simple softmax or more advanced uncertainty quantification techniques)
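One way to keep these four elements joinable, in the spirit of software logs, is to record them together under a unique id and a timestamp. The record below is a minimal sketch; the class and field names are illustrative, not a prescribed schema.

```python
from dataclasses import dataclass, field
from datetime import datetime, timezone
from uuid import uuid4

@dataclass
class InferenceRecord:
    """One log record per prediction, identified by a unique id and a
    timestamp, so that raw data, features, outputs and confidence can
    be joined later for aggregation or relabeling."""
    raw_data: dict        # the input as received
    features: list        # processed features fed to the model
    prediction: str       # model output
    confidence: float     # e.g. max softmax probability
    record_id: str = field(default_factory=lambda: str(uuid4()))
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )

record = InferenceRecord(
    raw_data={"text": "Win a free prize now!"},
    features=[0.12, 0.87, 0.44],
    prediction="spam",
    confidence=0.91,
)
print(record.record_id, record.prediction)
```

In production these records would typically be appended to a log store or a database, so that window aggregations like the ones above can be computed after the fact.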
What to do once concept or data drift is detected?
We have talked about how to detect drift in both the data and the behavior of a model, but we haven’t yet discussed the actions we can take to solve problems and keep the model running.
The only real remediation to concept/data drift is retraining. Even if you use ensemble learning or online learning, it is critical to obtain ground truth in order to retrain your model on newly labeled data.
Retraining your models regularly requires incorporating the labeling process into the inference pipeline of your application. Using Kili, you can easily relabel batches of data using trained labelers. The powerful Python API that we have developed also allows data scientists to easily integrate active learning in the loop to further improve the efficiency of labeling.
Humans in the loop are also mandatory if you want to achieve extremely high levels of performance. In that case, you must have humans correct or validate predictions in which you have low confidence.
Prevent breaking changes in the data gathering/processing pipeline:
- List all of the dependencies your model relies on to get its input data
- libraries, third-party applications, other services in the company
- Monitor the model’s input using simple statistics. Abrupt changes and extreme values are good signs of an issue
Prevent model decay due to a change in the nature of the data and concept:
- Monitor the output of your model (distribution of predicted classes, number of objects detected)
- Compare distributions of the training and serving data
- Use a distance measure (EM distance, KL divergence)
- Use a classifier
Always take business metrics into account as well. For example, a drop in income or sales can indicate a malfunction of the model.
If you can, compare models between them in real-time. In an MLops environment, multiple models could be deployed at once and their agreement level could be compared as an additional indicator of model behavior.
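The agreement level between two deployed models reduces to a simple rate; the sketch below is a minimal illustration of that comparison, with hypothetical function and label names.

```python
def agreement_rate(preds_a, preds_b):
    """Fraction of inputs on which two deployed models agree.

    A sudden drop in agreement between models that used to agree is an
    additional signal that one of them has changed its behavior."""
    if len(preds_a) != len(preds_b):
        raise ValueError("both models must score the same inputs")
    matches = sum(a == b for a, b in zip(preds_a, preds_b))
    return matches / len(preds_a)

rate = agreement_rate(["spam", "ham", "ham", "spam"],
                      ["spam", "ham", "spam", "spam"])
print(rate)  # -> 0.75 (the models disagree on one input out of four)
```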
At Kili, we are developing a platform to help organizations manage and supervise their data. We make it easy to manage training and serving data by associating ids with both data and labels/predictions. If your predictions consist of multiple objects (in image segmentation, for example) you can have ids down to the instance level! This proves useful for performing future aggregations. You can also compare multiple models using consensus, or evaluate models against ground truth. And last but not least, it is easy to plug into your inference pipeline to perform batch labeling during production, allowing for retraining or online learning. Learn more about Kili on our website.
In this article, we have learned how to monitor our models to detect issues that can arise from the data. For those who are interested in this topic and wish to go further, I recommend a list of very interesting documents to read.