Multi-Task Learning Explained in 5 Minutes
March 20, 2021
Video Transcript
Hi! In this video, we'll talk about multi-task learning, which is about teaching machine learning models to do multiple things at a time.
We can do a lot of things using neural networks, or machine learning models in general, such as image classification, object detection, super-resolution, text generation, and so on.
Typically a model is trained to do a single task, which is convenient, but we may want to use a single model to solve multiple problems for various reasons, such as efficiency and better generalization.
To improve efficiency, one thing we can do is to share some of the layers between different but related tasks. For example, if we want to classify a scene, detect objects in it, and also output a segmentation mask, do we really need to train three separate models? All these three tasks involve some level of scene understanding. We could potentially save some memory, computation, and power by building a unified, multi-task model.
Efficiency is especially important in embedded applications. If you are running your model on low-power devices, any amount of power you save can make a difference.
So, how can we do that from a model architecture perspective? Perhaps the simplest thing we could do would be to use a shared backbone followed by multiple heads, each dedicated to one task.
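To make the shared-backbone idea concrete, here is a minimal sketch in plain Python. Everything in it is an assumption for illustration: the MultiTaskModel class, the layer sizes, and the three head names are hypothetical, and a real model would use a deep learning framework with convolutional layers rather than these tiny dense layers.

```python
import random


def linear(x, weights, bias):
    """Dense layer: one output value per row of `weights`."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(x):
    return [max(0.0, v) for v in x]


def init_layer(n_in, n_out, rng):
    """Random weights and zero biases for a dense layer."""
    weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


class MultiTaskModel:
    """Hypothetical toy model: one shared backbone, one small head per task."""

    def __init__(self, n_in, n_shared, head_sizes, seed=0):
        rng = random.Random(seed)
        self.backbone = init_layer(n_in, n_shared, rng)
        self.heads = {
            name: init_layer(n_shared, n_out, rng)
            for name, n_out in head_sizes.items()
        }

    def forward(self, x):
        # The shared features are computed once per input...
        features = relu(linear(x, *self.backbone))
        # ...and every task-specific head reuses them.
        return {name: linear(features, *head) for name, head in self.heads.items()}


model = MultiTaskModel(
    n_in=8,
    n_shared=16,
    head_sizes={"scene_class": 5, "box_regression": 4, "segmentation": 10},
)
outputs = model.forward([0.1] * 8)
print({name: len(values) for name, values in outputs.items()})
# {'scene_class': 5, 'box_regression': 4, 'segmentation': 10}
```

The backbone runs once per input and all three heads read from the same features, which is where the savings in memory and computation come from.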
There are many architectural decisions to make here: how much capacity to allocate to the shared backbone, how big the heads should be, where the heads should branch off, and so on.
When we allocate more capacity to the shared part and less capacity to the non-shared part, the tasks are forced to share more features. Therefore, they become more tightly coupled.
Intuitively, the closer the tasks are, the more features they can share, but how close two tasks are is not always obvious. Seemingly related tasks can hurt each other's accuracy when trained together. So, in practice, it takes some trial and error to design a multi-task model.
We don't have to have hard branching in our architecture, like in the previous example. We can also define soft parameter sharing mechanisms like the cross-stitch networks from CVPR 2016, or SemifreddoNets that I proposed in this ECCV 2020 paper, where the amount of feature sharing is determined by trainable parameters. I'll put the links to the papers in the description.
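As a rough illustration of soft sharing, here is a sketch of a cross-stitch style mixing step for two tasks. It only shows the linear mixing of activations; the function name and the fixed alpha values are assumptions for the example, and in a real network the alphas would be trainable parameters updated by the optimizer.

```python
def cross_stitch(x_a, x_b, alpha):
    """Mix the activations of two tasks with a 2x2 matrix of sharing weights.

    alpha[0] holds the weights used to build task A's output,
    alpha[1] holds the weights used to build task B's output.
    """
    (a_aa, a_ab), (a_ba, a_bb) = alpha
    out_a = [a_aa * a + a_ab * b for a, b in zip(x_a, x_b)]
    out_b = [a_ba * a + a_bb * b for a, b in zip(x_a, x_b)]
    return out_a, out_b


x_a = [1.0, 2.0]    # activations from task A's branch
x_b = [10.0, 20.0]  # activations from task B's branch

# Identity weights: each task keeps its own features (no sharing).
no_sharing = cross_stitch(x_a, x_b, [(1.0, 0.0), (0.0, 1.0)])
print(no_sharing)    # ([1.0, 2.0], [10.0, 20.0])

# Equal weights: both tasks see the same averaged features (full sharing).
full_sharing = cross_stitch(x_a, x_b, [(0.5, 0.5), (0.5, 0.5)])
print(full_sharing)  # ([5.5, 11.0], [5.5, 11.0])
```

Anything between these two extremes is a partial sharing setup, and training decides where on that range each layer ends up.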
So far we focused only on the efficiency aspect of multi-task learning. Efficiency is not the only motivation for multi-task learning though. When done right, learning multiple tasks together can help a model learn a better representation, and generalize better too. But a better representation is not guaranteed when multiple tasks are learned together. Some tasks help each other but some others don't get along well. This ICML 2020 paper has some good insights about which tasks should be learned together and which ones should not.
Alright, let's say we already have a multi-task model architecture. To train it, we need to define a loss function and an optimizer. The overall loss function is typically a combination of multiple loss functions, corresponding to multiple tasks. How we combine them depends on a lot of factors. For example, some losses may be on a different scale than others, such as cross-entropy and mean squared error losses. If one loss is much larger than the others, then it may dominate the training. Some losses may be easier to optimize than others. Some losses may be more important to the overall system. Some losses may converge faster than the others, so the relative magnitudes of the losses may be constantly changing.
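Here is a small sketch of the scale problem, using made-up numbers. The task names, the predictions, and the manual weights are all hypothetical; the point is only to show how a mean squared error on large values can swamp a cross-entropy term in a plain sum.

```python
import math


def cross_entropy(probs, target_index):
    """Negative log-probability assigned to the correct class."""
    return -math.log(probs[target_index])


def mean_squared_error(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def combined_loss(losses, weights):
    """Weighted sum of the per-task losses."""
    return sum(weights[name] * value for name, value in losses.items())


losses = {
    "classification": cross_entropy([0.7, 0.2, 0.1], 0),       # about 0.36
    "depth": mean_squared_error([12.0, 30.0], [10.0, 25.0]),   # 14.5
}

# With equal weights, the depth loss makes up almost all of the total.
equal = {"classification": 1.0, "depth": 1.0}
depth_share_equal = equal["depth"] * losses["depth"] / combined_loss(losses, equal)

# A hand-picked weight brings the two terms to a similar scale.
manual = {"classification": 1.0, "depth": 0.02}
depth_share_manual = manual["depth"] * losses["depth"] / combined_loss(losses, manual)

print(round(depth_share_equal, 2), round(depth_share_manual, 2))
```

With equal weights, the depth term accounts for well over 90% of the total, so its gradients would dominate the shared layers. The hand-tuned weight fixes that for this snapshot, but it would need retuning as the losses change during training.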
As you can see, assigning weights to each one of the tasks manually may be cumbersome, and sometimes outright impossible. Luckily there are some approaches that aim to do that automatically.
In this paper, titled Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics, the authors propose to weigh multiple loss functions by considering the uncertainty of each task. I won't go into the math much, but let's take a look at how this combined loss function works intuitively. They basically define these sigmas as trainable parameters for each task. If a task has a higher uncertainty, a larger sigma will decrease its contribution to the overall loss function. The log sigmas here will penalize setting the sigmas unnecessarily large. Otherwise, the model can always say "I'm not sure" to minimize the overall loss function. Without these additional terms, the loss would go to zero as the uncertainty goes to infinity.
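The intuition can be checked numerically with a short sketch. It assumes the regression-style form of the weighting, where each task contributes its loss divided by two sigma squared, plus log sigma; the classification form in the paper uses a slightly different factor. Storing log sigma instead of sigma is a choice made here for numerical convenience, and the function names are hypothetical.

```python
import math


def uncertainty_weighted_loss(task_losses, log_sigmas):
    """Sum over tasks of loss / (2 * sigma^2) + log(sigma).

    Assumption: regression-style weighting, with log(sigma) as the
    parameter that would be trained for each task.
    """
    total = 0.0
    for loss, log_sigma in zip(task_losses, log_sigmas):
        precision = math.exp(-2.0 * log_sigma)  # equals 1 / sigma^2
        total += 0.5 * precision * loss + log_sigma
    return total


def single_task_term(loss, sigma):
    """Convenience wrapper to inspect one task's contribution."""
    return uncertainty_weighted_loss([loss], [math.log(sigma)])


# Keep the task loss fixed at 4.0 and vary sigma.
for sigma in (1.0, 2.0, 4.0, 100.0):
    print(sigma, round(single_task_term(4.0, sigma), 3))
```

For a fixed loss of 4.0, the term is smallest at sigma = 2, where sigma squared equals the loss. A very large sigma shrinks the first part toward zero but pays for it through the log term, so "I'm not sure" is no longer a free way out.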
Another approach to loss balancing is proposed in this paper, named GradNorm. GradNorm adjusts the weights of the losses in a way that balances the gradient magnitudes. The training is considered balanced when the loss ratio is similar for all tasks. The loss ratio is basically the loss for a given task as compared to its initial value. This balanced training setup aims to prevent some tasks from getting too far ahead as compared to others.
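Here is a sketch of how the target gradient magnitudes could be computed from the loss ratios, following the description in the GradNorm paper. The function name and the numbers are made up for the example, the gradient norms would normally be measured on the last shared layer, and alpha is a hyperparameter that controls how strongly the balancing pulls.

```python
def gradnorm_targets(grad_norms, current_losses, initial_losses, alpha=1.5):
    """Target gradient norm for each task.

    grad_norms: current norm of each weighted task gradient.
    Returns mean_norm * (relative training rate ** alpha) per task.
    """
    # Loss ratio: how much of the initial loss is still left for each task.
    loss_ratios = [cur / init for cur, init in zip(current_losses, initial_losses)]
    mean_ratio = sum(loss_ratios) / len(loss_ratios)
    # Above 1 means the task is training slower than average.
    relative_rates = [ratio / mean_ratio for ratio in loss_ratios]
    mean_norm = sum(grad_norms) / len(grad_norms)
    return [mean_norm * rate ** alpha for rate in relative_rates]


# Task 0 has dropped to 20% of its initial loss, task 1 only to 80%.
targets = gradnorm_targets(
    grad_norms=[2.0, 2.0],
    current_losses=[0.2, 0.8],
    initial_losses=[1.0, 1.0],
    alpha=1.0,
)
print([round(t, 2) for t in targets])  # [0.8, 3.2]
```

Both tasks currently have the same gradient norm, but the task that is lagging behind gets a larger target and the one that is ahead gets a smaller one. The loss weights are then adjusted so that each task's gradient norm moves toward its target.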
The idea of multi-task learning sounds very intuitive and promising, but in practice, it doesn't always work well. Sometimes, training a separate model for each task performs better because the gradient updates for different tasks interfere or conflict with each other, or because some tasks dominate the others. This is called negative transfer. Even when a multi-task model has a larger capacity than all of the separately trained models combined, it may still perform worse.
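One simple way to see what conflicting gradients look like is to compare the directions of two task gradients on the shared parameters. This diagnostic is an illustration added here rather than something from the papers above, and the gradient values are made up.

```python
import math


def cosine_similarity(g1, g2):
    """Cosine of the angle between two gradient vectors."""
    dot = sum(a * b for a, b in zip(g1, g2))
    norm1 = math.sqrt(sum(a * a for a in g1))
    norm2 = math.sqrt(sum(b * b for b in g2))
    return dot / (norm1 * norm2)


# Hypothetical gradients of two task losses w.r.t. the same shared weights.
grad_segmentation = [1.0, 0.5]
grad_depth_conflicting = [-0.8, 0.1]
grad_depth_aligned = [0.9, 0.4]

conflict = cosine_similarity(grad_segmentation, grad_depth_conflicting)
agreement = cosine_similarity(grad_segmentation, grad_depth_aligned)

# A negative value means the two tasks pull the shared weights in
# opposing directions, so a step that helps one task hurts the other.
print(conflict < 0, agreement > 0)  # True True
```

When the similarity is negative, the summed update partially cancels out, which is one way the shared layers can end up serving neither task well.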
Besides, it's easier to deal with separately trained models. If we need improvement on one task, we can train that model separately and try all kinds of tricks on that particular model only, without breaking the other tasks. So, with everything else being equal (same overall computational cost and power consumption), it makes more sense to train a separate model for each task.
Alright, that was pretty much it. I hope you liked it. Subscribe for more videos. And I'll see you next time.
References
- SemifreddoNets: Partially Frozen Neural Networks for Efficient Computer Vision Systems
- SemifreddoNets Video Summary
- Cross-stitch Networks for Multi-task Learning
- Which Tasks Should Be Learned Together in Multi-task Learning?
- Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics
- GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks