What Dataset should I use to retrain my model?
A complete guide to choosing the dataset you should use to retrain your model, illustrated with concrete cases.
Retraining a model is a necessary step in the life of a model. What should I do if new data becomes available? What should I do if my model’s predictions are worse than when I trained and tested it?
Suppose that you want to improve your model’s performance after some period of time. You have two main options:
You can readjust hyperparameters, change the pipeline, change the model architecture, etc. In a nutshell, you modify your model to obtain a new version, more or less close to the former one, that performs better.
You can retrain your model on a new dataset.
Changing your model architecture or hyperparameters changes its ability to learn the information in your data. If the data distribution is complex and your model is too simple to capture this complexity, you need to focus on model engineering. It is a good lever to improve your model’s performance, and it is sometimes necessary, but it is time-consuming.
Now suppose that you have a classification problem and that you have the Bayes classifier, i.e., the theoretically best classifier. The model learns from the data you give it. If the data carries very little information, or if your dataset does not represent the true data distribution, the model will predict poorly, even if it is the best classifier. Acting on the dataset used to retrain your model is underrated today, and yet it is very important. Machine learning algorithms learn from data, so improving your training dataset improves your model’s performance: bad data leads to bad performance. Kili Technology helps you manage your dataset, both assets and annotations. We provide a platform to speed up your annotation and enhance your dataset.
In this blog article, we will focus on the second lever to improve your model’s performance: retraining it on a new dataset. But how should you choose this new dataset? The answer depends on a question you need to ask yourself: why do I want to retrain my model?
In this article, we will focus on three reasons that cover many cases:
First Case: The data I used to train my model no longer represents the phenomenon I am trying to model
Second Case: The model has high variance
Third Case: The model performs poorly on some specific data
Depending on the case, we will see which actions to take to build a new training dataset and retrain the model.
First Case: The data I used to train my model no longer represents the phenomenon I am trying to model
When you develop a machine learning model, you usually assume that the probability distribution of your data does not change over time. This is generally a simplifying assumption: the world evolves, and most environments you are trying to model are not stationary. If the data distribution changes too much, a model trained on old data will not predict as accurately in the new environment as it did before. This phenomenon of non-stationary distribution is called concept drift. In this article, we won’t deal with detecting concept drift; if you are not familiar with this important question, we encourage you to read the article “How To Monitor Machine Learning Models In Production?”.
Now suppose that you decide it is time to retrain your model because, due to concept drift, its performance has dropped below your acceptance threshold. Which dataset should you use to retrain it? Should you use only new data? Should you keep old data?
Suppose that your data is indexed by creation date. A first method is to use a fixed window size for your dataset, for example by selecting the last 6 months of data at each retraining. The window size should be chosen carefully according to your environment. On the one hand, if the environment changes quickly, you may choose a small window. On the other hand, if the environment evolves slowly, you can opt for a wider window. This method is easy to apply but might be too simple for a complex environment.
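A minimal sketch of the fixed-window selection with pandas, assuming your data lives in a DataFrame with a timestamp column; the column name `created_at` and the helper `fixed_window` are hypothetical. The window is anchored on the most recent timestamp in the data rather than on today’s date, which is one reasonable choice among others.

```python
import pandas as pd

def fixed_window(df: pd.DataFrame, months: int = 6, date_col: str = "created_at") -> pd.DataFrame:
    """Keep the rows created during the last `months` months of the data."""
    latest = df[date_col].max()
    # Anchor the window on the most recent sample, not on today's date
    cutoff = latest - pd.DateOffset(months=months)
    return df[df[date_col] > cutoff]

# Example: one row per month over two years (2022-01 to 2023-12)
data = pd.DataFrame({
    "created_at": pd.date_range("2022-01-01", periods=24, freq="MS"),
    "value": range(24),
})
recent = fixed_window(data, months=6)
print(len(recent), recent["created_at"].min().date())  # 6 2023-07-01
```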
An alternative is a dynamic window size: at each retraining, you split your dataset into train and test sets in time order, the test set being the most recent data and the training set the data that precedes it. You then grid-search N, the number of most recent months of training data to keep, and select the value of N that gives the best performance on the test set. This is mainly relevant if new data arrives periodically.
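The dynamic window can be sketched as follows, assuming each sample carries a hypothetical integer `month` index (higher means more recent) and using scikit-learn’s `LogisticRegression` purely as a stand-in for your own model. The most recent month is held out as the test set, and each candidate N trains on the N months just before it; the drift in the example data is simulated.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

def best_window(X, y, month, candidates):
    """Return the N (number of most recent training months) that scores best on the latest month."""
    test_month = month.max()
    test = month == test_month
    scores = {}
    for n in candidates:
        # Training window: the n months immediately preceding the test month
        train = (month < test_month) & (month >= test_month - n)
        model = LogisticRegression(max_iter=1000).fit(X[train], y[train])
        scores[n] = accuracy_score(y[test], model.predict(X[test]))
    return max(scores, key=scores.get), scores

# Simulated data: 12 months, 200 samples per month
rng = np.random.default_rng(0)
month = np.repeat(np.arange(12), 200)
X = rng.normal(size=(month.size, 1))
# Simulated concept drift: the labelling rule flips from month 9 onwards
y = np.where(month < 9, X[:, 0] > 0, X[:, 0] < 0).astype(int)
best_n, scores = best_window(X, y, month, candidates=[1, 2, 3, 6, 11])
```

On this simulated drift, short windows that mostly see the new labelling rule should score far better than the widest window, which is dominated by outdated data.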
If instead you rarely get new data, but in larger batches, you can also use such a validation to choose your training dataset. Imagine that you trained your model on a dataset of 5,000 assets and that 1,000 new assets are now available. You can take these 1,000 new assets (as they are recent and thus more valuable) plus a fraction alpha of randomly selected assets from the old dataset. Then, you can use cross-validation to tune alpha so that the model performs best on your test data.
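Below is a sketch of the alpha search under a few assumptions: the helpers `mix_old_and_new` and `choose_alpha` are hypothetical, `LogisticRegression` stands in for your model, and for brevity each alpha is scored on a single held-out set of recent data rather than with full cross-validation. The synthetic data mimics the 5,000 old / 1,000 new example, with a labelling rule that has drifted.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def mix_old_and_new(X_old, y_old, X_new, y_new, alpha, rng):
    """All new samples plus a random fraction `alpha` of the old ones."""
    k = int(round(alpha * len(X_old)))
    idx = rng.choice(len(X_old), size=k, replace=False)
    return np.vstack([X_new, X_old[idx]]), np.concatenate([y_new, y_old[idx]])

def choose_alpha(X_old, y_old, X_new, y_new, X_test, y_test, alphas, seed=0):
    """Return the alpha whose training mix scores best on recent test data."""
    rng = np.random.default_rng(seed)
    scores = {}
    for alpha in alphas:
        X_train, y_train = mix_old_and_new(X_old, y_old, X_new, y_new, alpha, rng)
        model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
        scores[alpha] = model.score(X_test, y_test)  # accuracy
    return max(scores, key=scores.get), scores

# Synthetic example: 5,000 old samples, 1,000 new ones, and a drifted labelling rule
rng = np.random.default_rng(0)
X_old = rng.normal(size=(5000, 2))
y_old = (X_old[:, 0] > 0).astype(int)
X_new = rng.normal(size=(1000, 2))
y_new = (X_new[:, 0] + X_new[:, 1] > 0).astype(int)
X_test = rng.normal(size=(500, 2))
y_test = (X_test[:, 0] + X_test[:, 1] > 0).astype(int)
best_alpha, scores = choose_alpha(X_old, y_old, X_new, y_new, X_test, y_test,
                                  alphas=[0.0, 0.1, 0.25, 0.5, 1.0])
```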
Second Case: The model has high variance
Bias and variance are a trade-off. If the variance of your model is too high, your model is not robust because it is overfitting. One interpretation is that your model has memorized the examples you gave it instead of learning prediction rules that generalize to data it has never seen. An overfitting model generalizes poorly, and you absolutely need to address the issue.
A first idea is to decrease the complexity of your model, or to consider ensemble methods such as bagging. Bagging (bootstrap aggregating) trains several models on bootstrapped samples of your dataset and aggregates their predictions. It has shown good results in making models more robust by reducing their variance, which limits overfitting.
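To illustrate bagging, here is a small scikit-learn comparison between a single fully grown decision tree, a typical high-variance model, and a `BaggingClassifier` built from 50 such trees. The synthetic dataset and its parameters are arbitrary choices for the example, so the exact scores will differ on your data.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import BaggingClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic data with 10% label noise, which a single deep tree tends to memorize
X, y = make_classification(n_samples=500, n_features=20, n_informative=5,
                           flip_y=0.1, random_state=0)

# A fully grown decision tree: low bias, high variance
tree = DecisionTreeClassifier(random_state=0)
# Bagging: 50 trees, each fit on a bootstrap sample, with their predictions aggregated
bagged = BaggingClassifier(DecisionTreeClassifier(random_state=0),
                           n_estimators=50, random_state=0)

tree_score = cross_val_score(tree, X, y, cv=5).mean()
bagged_score = cross_val_score(bagged, X, y, cv=5).mean()
print(f"single tree: {tree_score:.3f}, bagging: {bagged_score:.3f}")
```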
Another option is to train your model on more data. As the goal of this article is to show that reworking your dataset can drastically improve your model’s performance, we will focus on this last option.
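Before paying for new data, a learning curve gives a rough idea of whether more data is likely to help. This sketch uses scikit-learn’s `learning_curve` on synthetic data with a decision tree as a stand-in model: if the gap between training and validation scores is large and shrinks as the training set grows, more (and more diverse) data is a promising lever.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import learning_curve
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=2000, n_features=20, n_informative=5,
                           flip_y=0.05, random_state=0)

# Train on 10%, 32.5%, ..., 100% of each training fold and score on the validation fold
sizes, train_scores, val_scores = learning_curve(
    DecisionTreeClassifier(random_state=0), X, y,
    train_sizes=np.linspace(0.1, 1.0, 5), cv=5,
)
for n, tr, va in zip(sizes, train_scores.mean(axis=1), val_scores.mean(axis=1)):
    # A large train/validation gap that shrinks as n grows suggests more data will help
    print(f"{n:5d} samples  train={tr:.3f}  val={va:.3f}  gap={tr - va:.3f}")
```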
If new data is available to you, it is a good idea to add it to your dataset. However, to improve your model’s performance when retraining it, the new data needs to carry new information. Remember: what matters is not the number of assets you have, but the diversity of the information they carry. Collecting new data is often costly, so this is a cost-benefit analysis: if the cost of gathering new data is higher than the benefit of a better-performing model, acquiring it is not worthwhile.
Techniques also exist to create new data from the data you already have. These techniques are not free, since they at least require time, but they may be less expensive than collecting new data. Below, we will look at some of them: data augmentation, synthetic data generation, and transfer learning.
Data augmentation
For images, you can apply transformations to your data to create new data: geometric transformations, color space transformations, random erasing, feature space augmentation, and more. Such images won’t be totally new, but they can add diversity to your dataset. For example, if your data comes from a fixed camera where objects are always seen from the same angle, geometric transformations can break this bias by simulating views of the same object from different angles. There are two types of data augmentation:
Offline Augmentation: new images are generated before training and stored with your dataset, so they are kept on disk.
Online Augmentation: new data is generated during training and never stored. In a deep learning model, the samples are augmented randomly at each epoch, so the algorithm does not see exactly the same training dataset twice.
Data Augmentation example with a dog image
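As an illustration of online augmentation, here is a NumPy-only sketch that applies three of the transformations mentioned above (a horizontal flip, a brightness change as a simple color space transformation, and random erasing) with fresh randomness at every epoch. The helpers `augment` and `online_batches` and the probability and range values are hypothetical choices; in practice you would rather use the augmentation utilities of your deep learning framework.

```python
import numpy as np

def augment(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Return a randomly transformed copy of an (H, W, C) image with values in [0, 1]."""
    out = image.copy()
    if rng.random() < 0.5:
        out = out[:, ::-1, :]  # geometric transformation: horizontal flip
    # Color space transformation: random brightness change, clipped to the valid range
    out = np.clip(out * rng.uniform(0.7, 1.3), 0.0, 1.0)
    if rng.random() < 0.5:
        # Random erasing: fill a rectangle of up to a quarter of each side with one value
        h, w = out.shape[:2]
        eh = rng.integers(1, max(1, h // 4) + 1)
        ew = rng.integers(1, max(1, w // 4) + 1)
        top = rng.integers(0, h - eh + 1)
        left = rng.integers(0, w - ew + 1)
        out[top:top + eh, left:left + ew, :] = rng.random()
    return out

def online_batches(images: np.ndarray, epochs: int, seed: int = 0):
    """Online augmentation: yield a freshly augmented copy of the dataset at each epoch."""
    rng = np.random.default_rng(seed)
    for _ in range(epochs):
        yield np.stack([augment(img, rng) for img in images])

images = np.random.default_rng(1).random((8, 32, 32, 3))  # stand-in for real images
augmented_epochs = list(online_batches(images, epochs=3))  # each element feeds one epoch
```

For offline augmentation, you would instead call `augment` a fixed number of times per image before training and save the results alongside the original dataset.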
Synthetic data generation
In deep learning, some algorithms can generate new data from your dataset: generative adversarial networks, neural style transfer, and variational autoencoders. These techniques produce images that differ more from the originals than those produced by the data augmentation techniques described above.
Generative Adversarial Networks (GANs) and Variational Autoencoders (VAEs) are well-known families of algorithms that learn the probability distribution of your dataset and generate new data from this distribution. They are well documented, so we won’t go into detail in this article.
The SMOTE algorithm is also very useful for imbalanced datasets. An imbalanced dataset is a problem and leads to a heavily biased model if it is not addressed. The two main techniques to balance a dataset are under-sampling and over-sampling. Under-sampling reduces the number of samples in the majority class; it must be done carefully because valuable information is lost. Over-sampling instead increases the number of samples in the minority class. SMOTE (Synthetic Minority Over-sampling Technique) is an over-sampling technique that generates new synthetic data for the minority class: it picks a minority sample and one of its nearest minority-class neighbors, draws a line between these two points, and generates a new point at a random position on this line.
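A minimal NumPy sketch of the SMOTE idea described above, not the reference implementation: for each synthetic point it picks a random minority sample, one of its k nearest minority neighbors, and a random position on the segment between them. The helper name `smote` and the default `k=5` are assumptions; the imbalanced-learn package provides a full implementation as `imblearn.over_sampling.SMOTE`, used through `fit_resample(X, y)`.

```python
import numpy as np

def smote(X_min: np.ndarray, n_new: int, k: int = 5, seed: int = 0) -> np.ndarray:
    """Generate n_new synthetic minority samples (needs at least 2 minority samples)."""
    rng = np.random.default_rng(seed)
    # Pairwise Euclidean distances between minority samples
    d = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)                # a point is not its own neighbor
    k = min(k, len(X_min) - 1)
    neighbors = np.argsort(d, axis=1)[:, :k]   # k nearest minority neighbors of each point
    base = rng.integers(0, len(X_min), size=n_new)
    chosen = neighbors[base, rng.integers(0, k, size=n_new)]
    gap = rng.random((n_new, 1))               # random position on each segment, in [0, 1)
    return X_min[base] + gap * (X_min[chosen] - X_min[base])

# Example: 200 majority samples vs. 20 minority samples
rng = np.random.default_rng(0)
X_majority = rng.normal(0.0, 1.0, size=(200, 2))
X_minority = rng.normal(3.0, 0.5, size=(20, 2))
X_synthetic = smote(X_minority, n_new=len(X_majority) - len(X_minority))
X_balanced_minority = np.vstack([X_minority, X_synthetic])  # now as many as the majority
```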
Transfer learning
Transfer learning is very useful in deep learning: the idea is to transfer parameters pretrained on a specific task to your target task. The pretrained layers have already seen a large dataset, and their parameters carry the information extracted from it.
For example, deep learning has recently shown very good results on image tasks. Images are a rather complex input, and in most deep learning models the first layers can be interpreted as image feature extractors: they extract information from images, and the last layers perform the task you want. It is therefore a good idea to take a pretrained state-of-the-art model such as VGGNet or ResNet and replace its last classification layer with layers suited to your task. When training, you can keep the pretrained parameters frozen and train only the layers specific to your task. You thus transfer what the pretrained model has learned. Transfer learning methods have shown good performance, often giving better accuracy and a better starting point for training.
Transfer Learning diagram
Third Case: The model performs poorly on some specific data
Suppose you have a model that performs an image classification task. Its accuracy is high, it satisfies you, and you deploy it to production. However, a client comes to you and says that your model is not accurate at all. You look at this client’s images and notice that most of them were taken at night or in low light. You look at the dataset you used to train your model and notice that it indeed lacks such images. Your dataset did not cover the data distribution well, and your model predicts poorly on these under-represented images. So how can you spot such a problem?
On a classification task, a good first step is to plot the confusion matrix, which shows whether the model performs poorly on a specific class. If so, you can find more data points of this class and add them to your dataset to retrain your model.
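A short scikit-learn sketch of this check: the confusion matrix, plus the recall of each class (correct predictions divided by the number of true samples of that class), which makes the weakest class easy to spot. The labels and predictions are made up for the example, and every label passed in `labels` must appear in `y_true` to avoid a division by zero.

```python
from sklearn.metrics import confusion_matrix

def per_class_recall(y_true, y_pred, labels):
    """Confusion matrix (rows = true class, columns = predicted class) and per-class recall."""
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    recall = cm.diagonal() / cm.sum(axis=1)
    return cm, dict(zip(labels, recall))

y_true = ["cat", "cat", "dog", "dog", "dog", "bird", "bird", "bird"]
y_pred = ["cat", "cat", "dog", "dog", "cat", "cat", "cat", "bird"]
cm, recall = per_class_recall(y_true, y_pred, labels=["bird", "cat", "dog"])
worst = min(recall, key=recall.get)  # "bird": the class to collect more data for
```

The same per-group computation works on any slice you can tag in your data, such as the night-time images in the example above, even when that slice is not one of the classes.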
Dimension reduction techniques
A more general and more complex approach is to use dimensionality reduction techniques for visualization. The best-known algorithms are t-SNE and UMAP; they are more complex and, for visualization, more powerful than Principal Component Analysis (PCA). Such algorithms reduce the dimension of the input space while keeping similar points close to each other in the output space. If the data is reduced to 2 or 3 dimensions, it can easily be visualized. If a kind of data is rare in the dataset, such a visualization may let you spot it: it would appear as a small, isolated cluster of points. If that happens, it is wise to test the model on new data of the missing type and, if the model performs poorly on it, to add more data of this type to the dataset and retrain the model.
Fashion-MNIST dataset from Zalando Research, with images of fashion items embedded with supervised UMAP using 15 neighbors. Image taken from the UMAP documentation.
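A sketch of such a visualization with scikit-learn’s t-SNE on the digits dataset, chosen only because it ships with scikit-learn; the helper `embed_2d` and the perplexity rule are assumptions. The resulting 2D points can then be scatter-plotted, for example with matplotlib, colored by label or by any metadata you have.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE

def embed_2d(features: np.ndarray, seed: int = 0) -> np.ndarray:
    """Project features to 2D with t-SNE, keeping similar points close together."""
    # t-SNE requires the perplexity to be smaller than the number of samples
    perplexity = min(30.0, (len(features) - 1) / 3)
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=seed)
    return tsne.fit_transform(features)

X, y = load_digits(return_X_y=True)
embedding = embed_2d(X[:500])
# Scatter embedding[:, 0] against embedding[:, 1], colored by y[:500];
# small isolated clusters hint at under-represented kinds of data
```

With the umap-learn package, `umap.UMAP(n_neighbors=15).fit_transform(X)` plays the same role.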
Model interpretability is a key practice when trying to improve a model. Understanding why a model made a decision is important when that decision has a big impact. But model interpretability is also a powerful tool to deal with bias.
A famous example is Amazon’s AI recruiting tool, started in 2014, which reviewed resumes and determined which applicants could move forward in the process. This model discriminated against women and was eventually abandoned. This bias is not surprising, as the data fed to the model consisted mainly of men’s resumes, Amazon’s software engineering workforce being mostly male. Model interpretability lets you understand which variables play a role in a decision and assess whether this makes sense.
Models are more or less interpretable, and we can observe experimentally that the less interpretable models are often the more accurate ones.
SHAP (SHapley Additive exPlanations) is a method based on game theory that explains predictions and is very useful for model interpretability. It gives you each feature’s importance and its impact on the model’s decision (other techniques exist to compute feature importance).
In the image below, we can see SHAP values computed on the well-known Boston housing dataset. Red features are those pushing the prediction towards a higher value, and blue features those pushing it towards a lower value. This gives a very good idea of each feature’s impact on the model.
Example of a visualization of SHAP values
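To make the game-theoretic idea concrete, here is a brute-force computation of exact Shapley values for a single prediction. It is a simplified sketch, not the `shap` library: features that are "absent" from a coalition are replaced by a single baseline value (the library averages over a background dataset instead), and the cost grows exponentially with the number of features. For real models you would use the `shap` package, for example `shap.TreeExplainer` for tree ensembles.

```python
from itertools import combinations
from math import factorial
from typing import Callable, Sequence

def shapley_values(f: Callable[[list], float], x: Sequence[float],
                   baseline: Sequence[float]) -> list:
    """Exact Shapley values of f at x, replacing absent features by their baseline value."""
    n = len(x)

    def value(subset):
        # Features in `subset` take their real value, the others the baseline value
        return f([x[j] if j in subset else baseline[j] for j in range(n)])

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions S without feature i
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for s in combinations(others, size):
                total += weight * (value(set(s) | {i}) - value(set(s)))
        phi.append(total)
    return phi

def linear_model(z):
    return 2.0 * z[0] - 1.0 * z[1] + 0.5 * z[2] + 3.0

phi = shapley_values(linear_model, x=[1.0, 2.0, 3.0], baseline=[0.0, 0.0, 0.0])
# approximately [2.0, -2.0, 1.5]: each weight times the feature's deviation from the baseline
```

For a linear model, each Shapley value reduces to the weight times the feature’s deviation from the baseline, and the values always sum to the difference between the prediction and the baseline prediction.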
If such a tool had been used on Amazon’s recruiting model, it could have shown that gender played a role in the predictions. Once a bias is identified, the dataset can be rebalanced to reduce it.
Deep learning models are among the least interpretable models. Grad-CAM is a tool developed by researchers in 2016 that makes deep learning models applied to images more interpretable. Grad-CAM visualizes the class activation of a Convolutional Neural Network (CNN), i.e. where the CNN looks in the image to make its prediction. Its output is a heatmap for a given class, as shown in the image below. It is a good idea to run this algorithm on images that your model does not predict well in order to debug it. If the model does not activate around the right object in the image, it might be a good idea to add more data of this class to your dataset.
Grad-CAM example taken from the original paper
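The core of Grad-CAM fits in a few lines once you have, for one image, the feature maps of the last convolutional layer and the gradients of the class score with respect to them (obtained with your framework’s hooks or autograd, which this sketch leaves out). Each feature map is weighted by its spatially averaged gradient, the weighted maps are summed, and a ReLU keeps only the regions that push the class score up. The function name `grad_cam` and the toy arrays are hypothetical.

```python
import numpy as np

def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM heatmap for one image and one class.

    activations: (K, H, W) feature maps of the last convolutional layer.
    gradients:   (K, H, W) gradients of the class score with respect to those maps.
    """
    # Channel weights: spatially averaged gradients
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum of the feature maps, then ReLU keeps only positive evidence for the class
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)
    # Normalize to [0, 1] for display (upsampling to the image size is left out)
    return cam / cam.max() if cam.max() > 0 else cam

acts = np.zeros((2, 4, 4))
acts[0, 0, 0] = 1.0  # channel 0 fires in the top-left corner
acts[1, 3, 3] = 1.0  # channel 1 fires in the bottom-right corner
grads = np.stack([np.ones((4, 4)), -np.ones((4, 4))])  # the class score grows with channel 0 only
heatmap = grad_cam(acts, grads)  # highlights the top-left corner only
```

The resulting heatmap is then upsampled to the input image size and overlaid on it.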
There are many more tools for model interpretability, a very important subject but not the core of this article. If you want more information on the subject, Interpretable Machine Learning: A Guide for Making Black Box Models Explainable by Christoph Molnar is a very good online book.
In this article, we wanted to show the importance of the dataset when retraining a model. Data scientists spend a lot of time reworking models to make sure they learn as much as possible from the data. But if the data is bad, the model won’t have much information to learn from. Kili Technology helps you build the perfect dataset, collaboratively and with plenty of intelligent tools to accelerate the annotation process.
We have not yet discussed what to do after retraining the model. You need to go through model validation, a general process we won’t detail here. When you retrain a model, always save the previous version: if the new model performs worse than the old one, you need to be able to roll back to the original model.
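A minimal sketch of keeping every model version so that a worse retrain can be rolled back, using only the Python standard library. The on-disk layout (`model_vN.pkl` files plus an `index.json` of validation scores) and the helpers `save_version` and `load_best` are hypothetical; dedicated model registries implement the same idea with more features. Only unpickle files you created yourself, since pickle can execute arbitrary code.

```python
import json
import pickle
import tempfile
from pathlib import Path

def save_version(model, score: float, registry: Path) -> int:
    """Store a model as a new numbered version without overwriting older ones."""
    registry.mkdir(parents=True, exist_ok=True)
    index_file = registry / "index.json"
    index = json.loads(index_file.read_text()) if index_file.exists() else []
    version = len(index) + 1
    with open(registry / f"model_v{version}.pkl", "wb") as fh:
        pickle.dump(model, fh)
    index.append({"version": version, "score": score})  # validation score of this version
    index_file.write_text(json.dumps(index))
    return version

def load_best(registry: Path):
    """Load the version with the best validation score, e.g. to roll back a worse retrain."""
    index = json.loads((registry / "index.json").read_text())
    best = max(index, key=lambda entry: entry["score"])
    with open(registry / f"model_v{best['version']}.pkl", "rb") as fh:
        return best["version"], pickle.load(fh)

with tempfile.TemporaryDirectory() as tmp:
    registry = Path(tmp)
    save_version({"name": "old"}, 0.91, registry)
    save_version({"name": "retrained"}, 0.88, registry)
    version, model = load_best(registry)  # version 1: the retrained model scored worse
```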
We presented different actions to improve your dataset when retraining your model, and we saw how to spot weaknesses in your dataset with tools that either depend on your model or are independent of it. The subject is vast and many tools were not described in this article, but we hope to have given you the necessary basics for improving your dataset when retraining your model.