Recent Artificial Intelligence achievements rely on the ability to train models on large quantities of data. This requires scaling two processes: high-quality data creation and model training. The first can be achieved by using a data labeling platform like Kili Technology, where you can quickly annotate images, videos, DICOM, text, and documents with high-quality standards. Model training can scale through distributed training.
Indeed, the performance gains of Deep Learning algorithms rely on the ability to train models on massive datasets, whose individual items are themselves sometimes getting bigger. To handle this growth in training data size, leveraging several machines efficiently is key. In this article, we will introduce distributed training and how to implement it. We will then focus on data-parallel distributed training, which covers most real-world ML cases.
What is distributed training?
Distributed training is the process of training machine learning algorithms using several machines. The goal is to make the training process scalable, which means that handling a bigger dataset can be solved by adding more machines to the training infrastructure.
Distributed training provides more CPUs and bandwidth to process large amounts of data. However, using this additional capacity optimally is a real challenge.
What is the difference between distributed training and federated learning?
Federated learning, like distributed training, assumes that the data is distributed across several machines, and can therefore be seen as a variant of it. However, it is often done not for scalability reasons, but rather for data privacy, availability, or security reasons. The data is kept on each infrastructure machine (or edge), but this is a constraint more than a convenience.
When to use distributed training?
You have no choice but to run distributed training when:
Making the training procedure go through your whole dataset takes an enormous amount of time (for example, very large text corpora used in large language model training like OpenAI's GPT series or BigScience's BLOOM, big image datasets like ImageNet, video datasets for autonomous driving, CT scan datasets, satellite image datasets…);
Storing the data is impossible on a single machine;
The data is only accessible in specific locations because of security or size (large enterprises or astronomical data for example);
The data must fit entirely in RAM to be pre-processed, which a single machine cannot accommodate.
How to do distributed training?
Distributed training does not come for free. Let’s focus on the data parallelism techniques, which deal with data distributed on several machines. We won’t describe model parallelism here, which addresses the distribution of machine learning models on several machines.
Data parallelism relies on these steps:
1. The model is initialized on the main server.
2. The workers download the model from the server.
3. The workers compute the gradients of the loss function on a batch of the training data they host.
4. The workers send the gradients to the main server.
5. The main server applies these gradients to update the model weights.
6. Go to step 2.
Data-parallel distributed training flowchart
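The steps above can be sketched as a small single-process simulation. This is only an illustration under stated assumptions: the one-parameter linear model, the toy data, and the names worker_gradient and train are hypothetical, and the "workers" are simply iterated in a loop instead of running on separate machines.

```python
import random


def worker_gradient(w, batch):
    """Gradient of the mean squared error of y ~ w * x on one worker's batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def train(shards, steps=200, lr=0.1, batch_size=8, seed=0):
    rng = random.Random(seed)
    w = 0.0  # step 1: the model is initialized on the main server
    for _ in range(steps):  # step 6: go back to step 2
        gradients = []
        for shard in shards:  # one shard per simulated worker
            local_w = w  # step 2: the worker downloads the model
            batch = rng.sample(shard, batch_size)
            # steps 3 and 4: compute the gradient locally and send it back
            gradients.append(worker_gradient(local_w, batch))
        # step 5: the server averages the gradients and updates the weights
        w -= lr * sum(gradients) / len(gradients)
    return w


# Toy dataset following y = 3x, split across 3 simulated workers.
data_rng = random.Random(42)
data = [(x, 3.0 * x) for x in (data_rng.uniform(-1, 1) for _ in range(300))]
shards = [data[i::3] for i in range(3)]

w_final = train(shards)
print(f"learned weight: {w_final:.3f}")
```

In a real deployment each shard would live on a different machine and steps 2 and 4 would be network transfers, which is where most of the engineering effort goes.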
There are several ways to make the gradient computations handle data distributed on several machines.
Synchronous updates sequence diagram with 3 workers
With synchronous updates, the main server waits for every worker to finish its gradient computation before completing the model update and serving the model again to all the workers for a new iteration. As a consequence, each iteration is only as fast as the slowest worker. This can be mitigated by approaches like:
backup workers, where the main server stops waiting as soon as N workers (fewer than the total number launched) have completed their task. This raises the gradient variance, because each update is computed from fewer samples, so more iterations are needed. The time to converge still decreases, because these extra iterations cost less than waiting for the slowest workers at every step.
flexible rapid reassignment (a.k.a. FlexRR): a slowed worker can hand over part of its work to faster workers, to catch up with the other workers.
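The effect of a slow worker on the duration of one synchronous step can be illustrated with a few lines of arithmetic. The timings and function names below are made up for illustration; they do not come from a measurement.

```python
def synchronous_step_time(worker_times):
    """The server waits for every worker, so the slowest one sets the pace."""
    return max(worker_times)


def backup_workers_step_time(worker_times, n_required):
    """The server stops waiting once the n_required fastest workers have reported."""
    return sorted(worker_times)[n_required - 1]


# Hypothetical per-worker compute times in seconds; the last worker is a straggler.
times = [1.0, 1.1, 0.9, 5.0]

sync_time = synchronous_step_time(times)
backup_time = backup_workers_step_time(times, n_required=3)
print(f"wait for all 4 workers: {sync_time}s, wait for the fastest 3: {backup_time}s")
```

The price of the shorter step is that the straggler's gradient is discarded, so the update is computed from 3 batches instead of 4.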
Asynchronous updates sequence diagram with 3 workers
The asynchronous updates work as follows: each time a worker delivers a gradient update, the main server updates the model and serves it to the workers without waiting for the others. This can cause race conditions: a worker may compute its update on a model state that is already obsolete. This only works well if the weights updated simultaneously by two concurrent workers are sufficiently distinct.
Fortunately, this is the case for very large models where only a small portion of the weights is updated at each gradient update. In this case, convergence is similar to the synchronous case, with higher computing efficiency.
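A toy example can make the role of sparsity concrete. In the sketch below, two simulated workers update a 4-weight model; worker B computes its gradient either on the model it read before worker A's update (stale) or after it (fresh). The quadratic loss, the learning rate, and all names are assumptions chosen for illustration.

```python
def gradient(w, coords, target):
    """Gradient of sum((w[i] - target[i]) ** 2 for i in coords); zero elsewhere."""
    g = [0.0] * len(w)
    for i in coords:
        g[i] = 2 * (w[i] - target[i])
    return g


def apply_update(w, g, lr=0.4):
    return [wi - lr * gi for wi, gi in zip(w, g)]


def run(coords_a, coords_b, stale):
    w0 = [0.0, 0.0, 0.0, 0.0]
    target = [1.0, 2.0, 3.0, 4.0]
    # The server applies worker A's update first.
    w1 = apply_update(w0, gradient(w0, coords_a, target))
    # Worker B read the model either before A's update (stale) or after it.
    g_b = gradient(w0 if stale else w1, coords_b, target)
    return apply_update(w1, g_b)


# Sparse case: the two workers touch disjoint weights.
sparse_stale = run([0, 1], [2, 3], stale=True)
sparse_fresh = run([0, 1], [2, 3], stale=False)

# Dense case: both workers touch every weight.
dense_stale = run([0, 1, 2, 3], [0, 1, 2, 3], stale=True)
dense_fresh = run([0, 1, 2, 3], [0, 1, 2, 3], stale=False)

print("sparse:", sparse_stale, sparse_fresh)
print("dense: ", dense_stale, dense_fresh)
```

When the workers touch disjoint weights, the stale gradient is identical to the fresh one, so asynchrony costs nothing. When they overlap, the stale gradient ignores the progress already made and, in this toy setting, overshoots the target.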
Synchronous updates can be slow, but some added engineering can make them work as fast as asynchronous updates. The latter is easier to implement but depends on the weight gradients' sparsity.
It’s all about tuning
Once you have decided on the model update policy, tuning the per-worker batch size and the learning rate is fundamental. Why is it different from a single machine's gradient descent algorithm? If you consider synchronous updates, the model is not updated until the gradients from all the worker batches are received, which makes a virtual batch size of B = N × b, where N is the number of workers and b is the batch size of a single machine. b should be large enough to compute significant gradients, and N can also be large when large-scale problems are addressed.
Single machine update from a batch
Server update from worker batches
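Assuming plain mini-batch SGD, the two updates named above can be written as follows. The learning rate \eta, the per-sample loss \ell, and the batch symbols \mathcal{B} and \mathcal{B}_k are notation introduced here for illustration; the server is assumed to average the workers' gradients.

```latex
\begin{align*}
% Single machine: one step on a batch \mathcal{B} of size b
w_{t+1} &= w_t - \eta \, \frac{1}{b} \sum_{i \in \mathcal{B}} \nabla_w \ell(x_i, y_i; w_t) \\
% Server: one synchronous step over N workers, worker k holding a batch \mathcal{B}_k of size b
w_{t+1} &= w_t - \eta \, \frac{1}{N} \sum_{k=1}^{N} \frac{1}{b} \sum_{i \in \mathcal{B}_k} \nabla_w \ell(x_i, y_i; w_t) \\
        &= w_t - \eta \, \frac{1}{N b} \sum_{i \in \mathcal{B}_1 \cup \dots \cup \mathcal{B}_N} \nabla_w \ell(x_i, y_i; w_t)
\end{align*}
```

The last line shows that the server update is a single-machine update on a batch of size B = N b.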
With such a large virtual batch, an epoch contains fewer update steps, so the loss decreases more slowly per epoch. This can be mitigated by using a higher learning rate, at the potential cost of convergence stability. Tuning is therefore important to reap all the benefits of parallelization.
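The arithmetic behind this trade-off is short enough to sketch. The dataset size and batch size below are arbitrary, and the linear scaling of the learning rate is a common heuristic used here as an assumption, not a rule stated by this article; it should be treated as a starting point for tuning.

```python
import math


def steps_per_epoch(dataset_size, per_worker_batch, num_workers):
    """Number of model updates per epoch with a virtual batch of B = N x b."""
    virtual_batch = per_worker_batch * num_workers
    return math.ceil(dataset_size / virtual_batch)


def linearly_scaled_lr(base_lr, num_workers):
    """Heuristic starting point only: multiply the single-machine rate by N."""
    return base_lr * num_workers


single_machine_steps = steps_per_epoch(60_000, per_worker_batch=64, num_workers=1)
eight_worker_steps = steps_per_epoch(60_000, per_worker_batch=64, num_workers=8)
scaled_lr = linearly_scaled_lr(0.01, num_workers=8)

print(single_machine_steps, eight_worker_steps, scaled_lr)
```

With 8 workers the model receives roughly 8 times fewer updates per epoch, which is why the step size usually has to grow for the loss to decrease at a comparable pace.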
Distributed training in practice
Once the gradient update strategy has been decided, it is time to pick a framework to integrate parallelization into your machine learning algorithm. In the rest of this document, we will use the Horovod framework. Horovod is a distributed Deep Learning framework that is compatible with PyTorch and Keras and can run on several infrastructure backends like Kubernetes, Ray, or Spark.
We will examine the example given in the repository for MNIST training:
Here are the main Horovod functions that should be used to make the model training distributed:
hvd.DistributedOptimizer: the distributed optimizer that wraps a torch optimizer. It exposes the synchronize method, which waits until the gradients from all the processes have been gathered and reduced. The reduced gradients are then used to update the model.
hvd.broadcast_parameters: broadcast the parameters (state_dict, named_parameters, parameters). This is the method used to provide the model update to the worker processes.
hvd.broadcast_optimizer_state: broadcast the optimizer to the nodes so they can compute the gradient updates with the correct learning rate.
hvd.allreduce: perform averaging or summation over all the processes running on the workers. It is used in the DistributedOptimizer, and should also be called whenever a value computed on each worker (a validation metric, for example) must be aggregated across all the processes.
hvd.rank: unique id of the process running on the worker.
hvd.size: total number of Horovod processes over all the workers.
The PyTorch DistributedSampler does the data distribution part:
torch.utils.data.distributed.DistributedSampler: samples the subset of the data associated with the process rank.
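The idea of rank-based data distribution can be sketched in pure Python. This is a simplified illustration, not the PyTorch implementation: the function name shard_indices is hypothetical, and the real sampler additionally supports shuffling with a per-epoch seed and an option to drop the last samples instead of padding.

```python
def shard_indices(dataset_size, num_replicas, rank):
    """Indices of the samples that the process with the given rank should read."""
    indices = list(range(dataset_size))
    # Pad by repeating the first indices so that every rank gets the same count.
    remainder = dataset_size % num_replicas
    if remainder:
        indices += indices[: num_replicas - remainder]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank::num_replicas]


# 10 samples distributed over 3 processes.
all_shards = [shard_indices(10, num_replicas=3, rank=r) for r in range(3)]
for r, shard in enumerate(all_shards):
    print(f"rank {r}: {shard}")
```

Every sample is read by at least one process, and all processes get the same number of samples, which keeps synchronous steps aligned across workers.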
Now you have all the tools to create a distributed training process for a PyTorch script!
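As a rough illustration of how these pieces fit together, here is a minimal sketch of a training function. It is not the repository's MNIST example: the function name train_with_horovod, its parameters, the choice of SGD with a cross-entropy loss, and the learning rate scaled by hvd.size() are all assumptions. It expects a classification model returning logits and a dataset yielding (input, label) pairs.

```python
def train_with_horovod(model, dataset, epochs=1, batch_size=64, base_lr=0.01):
    """Hypothetical data-parallel training loop built on the functions listed above."""
    # Imported lazily so this file can be loaded on a machine without Horovod.
    import torch
    import torch.nn.functional as F
    import horovod.torch as hvd
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    hvd.init()
    device = torch.device("cpu")
    if torch.cuda.is_available():
        # Pin each process to one GPU of its machine.
        torch.cuda.set_device(hvd.local_rank())
        device = torch.device("cuda")
    model.to(device)

    # Each process reads only the subset of the data associated with its rank.
    sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    # Assumption: learning rate scaled by the number of processes.
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr * hvd.size())
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters()
    )

    # Start every process from the same weights and optimizer state.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)  # reshuffle differently at each epoch
        total_loss, num_batches = 0.0, 0
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(model(data), target)
            loss.backward()
            optimizer.step()  # gradients are averaged across processes here
            total_loss += loss.item()
            num_batches += 1
        # Average the per-process mean loss over all the processes.
        local_mean = torch.tensor(total_loss / max(num_batches, 1))
        global_mean = hvd.allreduce(local_mean, name="epoch_loss").item()
        if hvd.rank() == 0:
            print(f"epoch {epoch}: mean loss {global_mean:.4f}")
    return model
```

A script calling this function would typically be launched with the horovodrun command, for example with 4 processes: horovodrun -np 4 python train.py, where train.py is a placeholder name.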
You can also have a look at other distributed training approaches that are specific to the ML frameworks, like:
PyTorch distributed data-parallel training with nn.parallel.DistributedDataParallel, working with PyTorch only.
tf.distribute, working with TensorFlow/Keras.
In this article, we introduced the distributed training concept and focused on data-parallel training. We highlighted the differences between asynchronous and synchronous training, and the care to be given to learning rate tuning. We finally discussed the toolbox to operate distributed training, leveraging Horovod and Torch features.
Distributed machine learning tools
PyTorch distributed: https://pytorch.org/tutorials/beginner/dist_overview.html
Distributed machine learning articles
A Survey on Distributed Machine Learning: https://arxiv.org/abs/1912.09789
Distributed Deep Learning Using Synchronous Stochastic Gradient Descent: https://arxiv.org/abs/1602.06709
Revisiting Distributed Synchronous SGD: https://arxiv.org/abs/1604.00981
Solving the straggler problem for iterative convergent parallel ML: https://www.cs.cmu.edu/~jinlianw/papers/harlap-socc16.pdf
Distributed machine learning book
Section Distributed Training in the book Designing Machine Learning Systems
Written by Pierre Leveau
Lead Machine Engineer @ Kili Technology