Multi-task Learning
Introduction
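Multi-task Learning is a sub-field of Machine Learning that is widely used in deep learning. While it is somewhat similar in concept to Multi-modal Learning, the two differ: multi-modal learning combines several kinds of input (e.g. images and text), whereas multi-task learning uses shared input to solve several tasks.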
The general idea of multi-task learning is to share the input, a model, or some layers of a model in order to build a shared representation of the input data that can be used to solve multiple tasks at once.
This is useful when data is limited or when we explicitly want to learn a shared representation. It works best when the input data is relevant to all the tasks and each task has a reasonable amount of labeled data.
There are multiple benefits to this approach, notably:
- We train a single model for multiple problems instead of one model per problem, so the architecture is more compact and inference is faster, since the shared layers are computed once for all tasks.
- The label data can be sparse: a sample labeled for Task 1 but missing its Task 2 label can still be used for training, and vice versa.
- There is implicit regularization, because we optimize the combined loss of all tasks rather than a single one, which limits overfitting to any one task.
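Sparse labels are usually handled by masking: each task's loss is averaged only over the samples that have a label for that task. The sketch below is a minimal pure-Python illustration under stated assumptions (missing labels encoded as `None`, squared error for every task); `masked_task_loss` and `multitask_loss` are hypothetical helper names, not a library API.

```python
# Minimal sketch: multi-task loss where some samples lack labels for some tasks.
# Assumptions (not from the original text): a missing label is encoded as None,
# and every task uses mean squared error on scalar predictions.
from typing import Dict, List, Optional


def masked_task_loss(preds: List[float], labels: List[Optional[float]]) -> float:
    """Mean squared error over the samples that have a label for this task."""
    errors = [(p - y) ** 2 for p, y in zip(preds, labels) if y is not None]
    if not errors:
        # No labels for this task in the batch, so it adds nothing to the loss.
        return 0.0
    return sum(errors) / len(errors)


def multitask_loss(preds: Dict[str, List[float]], labels: Dict[str, List[Optional[float]]]) -> float:
    """Sum of the masked per-task losses: each sample counts for every task it is labeled for."""
    return sum(masked_task_loss(preds[task], labels[task]) for task in preds)


# Batch of 3 samples: the first has no task2 label, the second has no task1 label.
preds = {"task1": [1.0, 2.0, 3.0], "task2": [0.5, 0.0, 1.0]}
labels = {"task1": [1.0, None, 2.0], "task2": [None, 1.0, 1.0]}
print(multitask_loss(preds, labels))  # → 1.0
```

Every sample still contributes to at least one task, so no data is thrown away just because one of its labels is missing.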
One example is multilingual translation, where the same input text is translated into different languages. If training succeeds, this builds a shared representation of multiple languages.
Multi-task variations
There are two main ways to share parameters in multi-task learning: Hard parameter sharing, where layers share all of their parameters across tasks, and Soft parameter sharing, where each task keeps its own parameters but the layers are constrained by each other so that the weights stay similar without being identical.
In Hard parameter sharing, the hidden layers of a single model share their weights across all tasks, and each task typically keeps only a small task-specific output layer (head). Because the shared weights are shaped by every task, adding or removing a task usually means retraining the entire shared network.
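A minimal pure-Python sketch of hard parameter sharing, assuming the common layout of one shared hidden layer (the "trunk") plus one linear head per task. The class `HardSharingModel`, its helper functions, and the toy weights are hypothetical and untrained; the point is that the shared representation `h` is computed once and reused by every task head.

```python
# Minimal sketch of hard parameter sharing (assumed layout: shared trunk + per-task heads).
# The weights below are fixed toy values for illustration, not trained parameters.
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def linear(weights: Matrix, bias: Vector, x: Vector) -> Vector:
    """Dense layer: weights @ x + bias."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(v: Vector) -> Vector:
    return [max(0.0, a) for a in v]


class HardSharingModel:
    def __init__(self, shared_w: Matrix, shared_b: Vector, heads: Dict[str, Tuple[Matrix, Vector]]):
        # Trunk parameters: one copy, used by every task.
        self.shared_w = shared_w
        self.shared_b = shared_b
        # Task-specific heads: task name -> (weights, bias).
        self.heads = heads

    def forward(self, x: Vector) -> Dict[str, Vector]:
        # The shared representation is computed once and fed to every head.
        h = relu(linear(self.shared_w, self.shared_b, x))
        return {task: linear(w, b, h) for task, (w, b) in self.heads.items()}


model = HardSharingModel(
    shared_w=[[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]],
    shared_b=[0.0, 0.0, 0.0],
    heads={
        "task1": ([[1.0, 1.0, 0.0]], [0.0]),
        "task2": ([[0.0, 0.0, 2.0]], [0.5]),
    },
)
print(model.forward([2.0, 1.0]))  # → {'task1': [3.0], 'task2': [2.5]}
```

During training, the gradients of every task's loss would update the same `shared_w` and `shared_b`, which is why the trunk reflects all tasks at once.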
In Soft parameter sharing, each task has its own model with its own parameters, and the distance between corresponding parameters of the different models is regularized so that they stay similar. Because the parameters are not tied, a task can be added or removed by retraining its own model, without retraining the models of all the other tasks.
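The similarity constraint of soft sharing can be written as a penalty added to the sum of the task losses. The sketch below assumes one common choice, the squared L2 distance between every pair of task models' parameters (other distance measures are possible); `soft_sharing_penalty` and its `strength` parameter are hypothetical names.

```python
# Minimal sketch of a soft parameter sharing penalty.
# Assumptions: each task model's parameters for one layer are flattened into a
# list of floats, and the constraint is the squared L2 distance between every
# pair of task models.
from itertools import combinations
from typing import Dict, List


def soft_sharing_penalty(task_params: Dict[str, List[float]], strength: float) -> float:
    """Return strength * sum over task pairs (a, b) of ||theta_a - theta_b||^2."""
    penalty = 0.0
    for a, b in combinations(sorted(task_params), 2):
        pa, pb = task_params[a], task_params[b]
        if len(pa) != len(pb):
            raise ValueError("task models must have parameters of the same shape")
        penalty += sum((x - y) ** 2 for x, y in zip(pa, pb))
    return strength * penalty


# Each task keeps its own copy of the layer weights; they are free to differ.
params = {
    "task1": [0.5, -1.0, 2.0],
    "task2": [0.5, -0.5, 1.0],
}
print(soft_sharing_penalty(params, strength=0.5))  # → 0.625
```

A larger `strength` pulls the task models closer together; in the limit of a very large strength the parameters are forced to be equal, which approaches hard sharing.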
Loss computation
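The total loss is usually computed as a weighted sum of the individual task losses, optionally with a regularization term:
$$\mathcal{L}_{\text{total}} = \sum_{i=1}^{T} \lambda_i \, \mathcal{L}_i(\theta_i) + \Omega(\theta)$$
where:
- $\lambda = (\lambda_1, \dots, \lambda_T)$ is the weight vector, with $\lambda_i$ the weight of task $i$
- $\theta_i$ represents the parameters of task $i$ (e.g. weights and biases), and $\theta$ all parameters of the model
- $\mathcal{L}_i$ is the loss function of task $i$
- $T$ is the number of tasks
- $\Omega(\theta)$ is a regularization term that encourages the model to generalize well across all tasks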
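The weighted sum of task losses plus a regularization term can be sketched in a few lines. This assumes the per-task losses are already computed scalars and uses an L2 penalty on the parameters as one possible choice of regularization term; `total_loss`, `l2_regularizer`, and the task names are hypothetical.

```python
# Minimal sketch of the weighted multi-task loss:
#   total = sum_i lambda_i * L_i + Omega
# Assumptions: per-task losses are already-computed scalars, the task names are
# hypothetical, and Omega is an L2 penalty on the parameters (one possible choice).
from typing import Dict, List


def l2_regularizer(params: List[float], strength: float) -> float:
    """Omega(theta) = strength * sum of squared parameters."""
    return strength * sum(p * p for p in params)


def total_loss(task_losses: Dict[str, float], task_weights: Dict[str, float], reg: float = 0.0) -> float:
    """Weighted sum of the task losses plus an optional regularization term."""
    if task_losses.keys() != task_weights.keys():
        raise ValueError("every task needs exactly one weight")
    return sum(task_weights[t] * task_losses[t] for t in task_losses) + reg


# Two hypothetical translation tasks sharing one model.
losses = {"en_to_fr": 2.0, "en_to_de": 4.0}
weights = {"en_to_fr": 0.75, "en_to_de": 0.25}
reg = l2_regularizer([1.0, -2.0], strength=0.01)
print(total_loss(losses, weights, reg))
```

With every weight equal to 1 this reduces to a plain sum of the task losses; unequal weights let one task count more than another, for example when the tasks' losses live on different scales.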
The weight vector

