As machine learning becomes more prominent, there is a growing demand to perform several inference tasks in parallel. Running a dedicated model for each task is computationally expensive, and therefore there is great interest in multi-task learning (MTL). MTL aims at learning a single model that solves several tasks efficiently. Optimizing MTL models is often achieved by first computing a single gradient per task and then aggregating the gradients to obtain a combined update direction. However, this approach does not consider an important aspect: the sensitivity in the gradient dimensions. Here, we introduce a novel gradient aggregation approach using Bayesian inference. We place a probability distribution over the task-specific parameters, which in turn induces a distribution over the gradients of the tasks. This additional valuable information allows us to quantify the uncertainty in each of the gradient dimensions, which can then be factored in when aggregating them. We empirically demonstrate the benefits of our approach on a variety of datasets, achieving state-of-the-art performance.
Machine Learning, ICML
1 Introduction
In many application domains, there is a need to perform several machine learning inference tasks simultaneously. For instance, an autonomous vehicle needs to identify and detect objects in its vicinity, perform lane detection, track the movements of other vehicles over time, and predict free space around it, all in parallel and in real-time. In deep Multi-Task Learning (MTL) the goal is to train a single neural network (NN) to solve several tasks simultaneously, thus avoiding the need to have one dedicated model for each task (Caruana, 1997). Besides reducing the computational demand at test time, MTL also has the potential to improve generalization (Baxter, 2000). It is therefore not surprising that applications of MTL are taking central roles in various domains, such as vision (Achituve et al., 2021a; Shamshad et al., 2023; Zheng et al., 2023), natural language processing (Liu et al., 2019b; Zhou et al., 2023), speech (Michelsanti et al., 2021), robotics (Devin et al., 2017; Shu et al., 2018), and general scientific problems (Wu et al., 2018), to name a few.
However, optimizing multiple tasks simultaneously is a challenging problem that may lead to degradation in performance compared to learning them individually (Standley et al., 2020; Yu et al., 2020). To address this issue, one basic recipe that many MTL optimization algorithms follow is to first calculate the gradient of each task’s loss, and then aggregate these gradients according to some specified scheme. For example, several studies focus on reducing conflicts between the gradients before averaging them (Yu et al., 2020; Wang et al., 2020), others find a convex combination of the gradients with minimal norm (Sener & Koltun, 2018; Désidéri, 2012), and some use a game-theoretic approach (Navon et al., 2022). However, by relying only on the gradient, these methods miss an important aspect: the sensitivity of the gradient in each dimension.
Our approach builds on the following observation: for each task, there may be many “good” parameter configurations. Standard MTL optimization methods take only a single value into account, and as such lose information in the aggregation step. Hence, tracking all of the parameter configurations yields a richer description of the gradient space that can be advantageous when finding an update direction. Specifically, to account for all parameter values, we propose to place a probability distribution over the task-specific parameters, which in turn induces a probability distribution over the gradients. As a result, we obtain uncertainty estimates for the gradients that reflect the sensitivity in each of their dimensions. High-uncertainty dimensions are more lenient to changes, while dimensions with lower uncertainty are more strict (see illustration in Figure 2).
To obtain a probability distribution over the task-specific parameters we take a Bayesian approach. According to the Bayesian view, a posterior distribution over parameters of interest can be derived through Bayes rule. In MTL, it is common to use a shared feature extractor network with linear task-specific layers (Ruder, 2017). Hence, if we assume a Bayesian model over the last task-specific layer weights during the back-propagation process, we obtain the posterior distributions over them. The posterior is then used to compute a Gaussian distribution over the gradients by means of moment matching. Then, to derive an update direction for the shared network, we design a novel aggregation scheme that considers the full distributions of the gradients. We name our method BayesAgg-MTL. An important implication of our approach is that BayesAgg-MTL assigns weights to the gradients at a higher resolution compared to existing methods, allocating a specific weight for each dimension and datum in the batch. We demonstrate our method’s effectiveness over baseline methods on the MTL benchmarks QM9 (Ramakrishnan et al., 2014), CIFAR-100 (Krizhevsky et al., 2009), ChestX-ray14 (Wang et al., 2017), and UTKFace (Zhang et al., 2017).
In summary, this paper makes the following novel contributions: (1) the first Bayesian formulation of gradient aggregation for multi-task learning; (2) a novel posterior approximation based on a second-order Taylor expansion; (3) a new MTL optimization algorithm based on our posterior estimation; (4) new state-of-the-art results on several MTL benchmarks compared to leading methods. Our code is publicly available at https://github.com/ssi-research/BayesAgg_MTL.
2 Background
Notations. Scalars, vectors, and matrices are denoted with lower-case letters (e.g., $x$), bold lower-case letters (e.g., $\mathbf{x}$), and bold upper-case letters (e.g., $\mathbf{X}$) respectively. All vectors are treated as column vectors. Training samples are tuples consisting of shared features across all tasks and labels of $K$ tasks, namely $(\mathbf{x}, y_1, \ldots, y_K) \in \mathcal{D}$, where $\mathcal{D}$ denotes the training set. We denote the dimensionality of the input and the output of task $k$ by $d_x$ and $d_k$ accordingly.
In this study, we focus on common NN architectures for MTL having a shared feature extractor and linear task-specific heads (Kendall et al., 2018; Sener & Koltun, 2018). The model parameters are denoted by $\theta = (\theta_{sh}, \mathbf{w}_1, \ldots, \mathbf{w}_K)$, where $\theta_{sh}$ is the vector of shared parameters and $\{\mathbf{w}_k\}_{k=1}^{K}$ are task-specific parameter vectors, each lying in $\mathbb{R}^{d_h \cdot d_k}$. The last shared feature representation is denoted by the vector $\mathbf{h} \in \mathbb{R}^{d_h}$. Hence, the output of the network for task $k$ can be described as $f_k = \mathbf{w}_k^T \mathbf{h}$. The loss of task $k$ is denoted by $\ell_k$. The gradient of loss $\ell_k$ w.r.t. $\mathbf{h}$ is $\mathbf{g}_k = \nabla_{\mathbf{h}} \ell_k$. For clarity of exposition, function dependence on input variables will be omitted from now on.
2.1 Multi-Task Learning
A prevailing approach to optimizing MTL models goes as follows. First, the gradient of each task's loss is computed. Second, an aggregation rule is imposed to combine the gradients according to some algorithm. Lastly, an update step is performed using the outcome of the aggregation step. Commonly, the aggregation rule operates on the gradients of the loss w.r.t. all parameters, or only the shared parameters (e.g., Yu et al., 2020; Navon et al., 2022; Shamsian et al., 2023). Alternatively, to avoid a costly full back-propagation process for each task, some methods suggest applying it on the last shared representation (e.g., Sener & Koltun, 2018; Liu et al., 2020; Senushkin et al., 2023). Here, to make our method fast and scalable, we take the latter approach, and note that it could be extended to full gradient aggregation.
2.2 Bayesian Inference
We wish to incorporate uncertainty estimates for the gradients into the aggregation procedure. Doing so will allow us to find an update direction that takes into account the importance of each gradient dimension for each task. A natural choice to model uncertainty is using Bayesian inference. Since we would like to get uncertainty estimates w.r.t. the last shared hidden layer, we treat only the last task-specific layer as a Bayesian layer. This “Bayesian last layer” approach is a common way to scale Bayesian inference to deep neural networks (Snoek et al., 2015; Calandra et al., 2016; Wilson et al., 2016a; Achituve et al., 2021c). We will now present some of the main concepts of Bayesian modeling that will be used as part of our method.
For simplicity, assume a single output variable; we also drop the task notation for clarity. According to the Bayesian paradigm, instead of treating the parameters $\mathbf{w}$ as deterministic values that need to be optimized, they are treated as random variables, i.e., there is a distribution over the parameters. The posterior distribution for $\mathbf{w}$, after observing the data $\mathcal{D}$, is given using Bayes rule as
$$p(\mathbf{w} \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \mathbf{w})\, p(\mathbf{w})}{p(\mathcal{D})}. \qquad (1)$$
Predictions in Bayesian inference are given by taking the expected prediction with respect to the posterior distribution. In general, the Bayesian inference procedure for $\mathbf{w}$ is intractable. However, for some specific scenarios, there exists an analytic solution. For example, in linear regression, if we assume a Gaussian likelihood with a fixed independent scalar noise $\sigma^2$ between the observations, $p(y \mid \mathbf{x}, \mathbf{w}) = \mathcal{N}(y \mid \mathbf{w}^T \mathbf{x}, \sigma^2)$, and a Gaussian prior $p(\mathbf{w}) = \mathcal{N}(\mathbf{w} \mid \mathbf{0}, \alpha^2 \mathbf{I})$, then
$$p(\mathbf{w} \mid \mathcal{D}) = \mathcal{N}(\mathbf{w} \mid \boldsymbol{\mu}, \boldsymbol{\Sigma}), \qquad \boldsymbol{\Sigma} = \left(\tfrac{1}{\sigma^2}\mathbf{X}^T\mathbf{X} + \tfrac{1}{\alpha^2}\mathbf{I}\right)^{-1}, \qquad \boldsymbol{\mu} = \tfrac{1}{\sigma^2}\boldsymbol{\Sigma}\,\mathbf{X}^T\mathbf{y}. \qquad (2)$$
Here $\mathbf{X}$ is the matrix that results from stacking the vectors $\mathbf{x}$, and $\mathbf{y}$ is the vector of stacked labels. Similarly, we denote by $\mathbf{H}$ the matrix that results from stacking the vectors $\mathbf{h}$ of hidden representations. In the specific case of deep NNs with a Bayesian last layer, we get the same inference result, only with $\mathbf{X}$ replaced by $\mathbf{H}$. Going beyond a single output variable entails defining a covariance matrix for the noise model; however, in this study we assume independence between the output variables in these cases.
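The closed-form posterior in Eq. 2 is straightforward to compute directly. Below is a minimal sketch (the function name and arguments are ours, not the paper's code) that forms the Bayesian last-layer posterior from stacked hidden representations:

```python
import numpy as np

def last_layer_posterior(H, y, noise_var=1.0, prior_var=1.0):
    """Gaussian posterior over last-layer weights for Bayesian linear
    regression with likelihood N(y | H w, noise_var * I) and prior
    N(0, prior_var * I). Each row of H is one hidden representation."""
    d = H.shape[1]
    precision = H.T @ H / noise_var + np.eye(d) / prior_var
    cov = np.linalg.inv(precision)        # posterior covariance (Eq. 2)
    mean = cov @ H.T @ y / noise_var      # posterior mean (Eq. 2)
    return mean, cov
```

With a small noise variance and a broad prior, the posterior mean approaches the least-squares solution, which is a convenient sanity check.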
Unlike regression, in classification the likelihood is not Gaussian, and the posterior can only be approximated. The common choice is to use variational inference (Wilson et al., 2016b; Achituve et al., 2021b, 2023), although there are other alternatives as well (Kristiadi et al., 2020).
3 Method
We start with an outline of the problem and our approach. Consider a deep network for multi-task learning that has a shared feature extractor part and task-specific linear layers. We propose to use Bayesian inference on the last layer as a means to train deterministic MTL models. For each task $k$, we define a Bayesian probabilistic model representing the uncertainty over the linear weights $\mathbf{w}_k$ of the last, task-specific layer. The distribution over weights induces a distribution over gradients of the loss with respect to the last shared hidden layer. Given these per-task distributions on a joint space, we propose an aggregation rule for combining the gradients of the tasks into a shared update direction that takes into account the uncertainty in the gradients (see illustration in Figure 1). Then, the back-propagation process can proceed as usual.
Since regression and classification setups yield different inference procedures according to our approach, albeit having the same general framework, we discuss the two setups separately, starting with regression.
3.1 BayesAgg-MTL for Regression Tasks
Consider a standard square loss for task $k$, $\ell_k = (\mathbf{w}_k^T \mathbf{h} - y_k)^2$, between the label $y_k$ and the network output $\mathbf{w}_k^T \mathbf{h}$. Given a random batch of examples, the gradient of the loss with respect to the hidden layer for a single example is
$$\mathbf{g}_k = \frac{\partial \ell_k}{\partial \mathbf{h}} = 2\,(\mathbf{w}_k^T \mathbf{h} - y_k)\,\mathbf{w}_k. \qquad (3)$$
Our main observation is that $\mathbf{g}_k$ is a function of $\mathbf{w}_k$. Hence, if we view $\mathbf{w}_k$ in the back-propagation process as a random variable, then $\mathbf{g}_k$ will be a random variable as well. This view will allow us to capture the uncertainty in the task gradient. Since the dimension of the hidden layer is usually small compared to the dimension of all shared parameters, operations in this space, such as matrix inverse, should not be costly.
If we fix all the shared parameters, then the posterior over $\mathbf{w}_k$ has a Gaussian distribution with known parameters via Eq. 2. As $\mathbf{g}_k$ is quadratic in $\mathbf{w}_k$, it has a generalized chi-squared distribution (Davies, 1973). However, since this distribution does not admit a closed-form density function, and since the gradient aggregation needs to be efficient as we run it at each iteration, we approximate $p(\mathbf{g}_k)$ as a Gaussian distribution. The optimal choice for the parameters of this Gaussian is given by matching its first two moments to those of the true density, as these parameters minimize the Kullback–Leibler divergence between the two distributions (Minka, 2001). Luckily, in the regression case, we can derive the first two moments from the posterior over $\mathbf{w}_k$,
$$\mathbb{E}[\mathbf{g}_k] = 2\big(a_k\,\boldsymbol{\mu}_k + \boldsymbol{\Sigma}_k \mathbf{h}\big), \qquad \mathrm{Cov}[\mathbf{g}_k] = 4\big[a_k^2\,\boldsymbol{\Sigma}_k + a_k\big(\boldsymbol{\Sigma}_k \mathbf{h}\boldsymbol{\mu}_k^T + \boldsymbol{\mu}_k\mathbf{h}^T\boldsymbol{\Sigma}_k\big) + \mathrm{Tr}\big(\mathbf{h}\mathbf{h}^T\boldsymbol{\Sigma}_k\big)\big(\boldsymbol{\Sigma}_k + \boldsymbol{\mu}_k\boldsymbol{\mu}_k^T\big) + \boldsymbol{\Sigma}_k\mathbf{h}\mathbf{h}^T\boldsymbol{\Sigma}_k\big], \qquad (4)$$
where $\boldsymbol{\mu}_k$ and $\boldsymbol{\Sigma}_k$ are the posterior mean and covariance of $\mathbf{w}_k$, $a_k = \boldsymbol{\mu}_k^T \mathbf{h} - y_k$, we assumed a scalar output, and $\mathrm{Tr}(\cdot)$ is the matrix trace. We emphasize that this approximation is for the gradient of a single data point and a single task, not for the gradient of the task with respect to the entire batch. The full derivation is presented in Appendix A.1.
Several points deserve attention here. First, note the similarity between the solution of the first moment and the gradient obtained via standard back-propagation. The two differences are that the last layer parameters, $\mathbf{w}_k$, are replaced with the posterior mean, $\boldsymbol{\mu}_k$, and an uncertainty term was added. In the extreme case of $\boldsymbol{\Sigma}_k \to \mathbf{0}$ and $\boldsymbol{\mu}_k = \mathbf{w}_k$, the mean coincides with that of standard back-propagation. Second, in the case of a multi-output task, following our independence assumption between output variables, we can obtain the moments for each output dimension separately using the same procedure, so de facto we treat each output as a different task. Finally, during training, the shared parameters are constantly being updated. Hence, to compute the posterior distribution for $\mathbf{w}_k$ we would need to iterate over the entire dataset at each update step. In practice, this can make our method computationally expensive. Therefore, we use the current batch data only to approximate the posterior over $\mathbf{w}_k$, and introduce information about the full dataset through the prior, as described next.
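Although the paper derives the moments of Eq. 4 in closed form (Appendix A.1), the quantity being approximated is easy to illustrate with Monte-Carlo sampling. The sketch below (our own illustration, not the paper's implementation) samples last-layer weights from their posterior and computes the induced per-dimension gradient mean and variance for a single example:

```python
import numpy as np

def gradient_moments_mc(post_mean, post_cov, h, y, n_samples=10000, seed=0):
    """Monte-Carlo estimate of the first two moments of the gradient
    g = d/dh (w^T h - y)^2 = 2 (w^T h - y) w induced by a Gaussian
    posterior over the last-layer weights w (single example, single task)."""
    rng = np.random.default_rng(seed)
    W = rng.multivariate_normal(post_mean, post_cov, size=n_samples)
    resid = W @ h - y                 # per-sample residual w^T h - y
    G = 2.0 * resid[:, None] * W      # per-sample gradient, shape (S, d)
    return G.mean(axis=0), G.var(axis=0)
```

As the posterior covariance shrinks to zero, the estimated mean collapses to the standard back-propagated gradient and the variance vanishes, matching the discussion above.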
Prior selection. A common choice in Bayesian deep learning is to choose uninformative priors, such as a standard Gaussian, to let the data be the main influence on the posterior (Wilson & Izmailov, 2020; Fortuin et al., 2021). However, in our case, we found this prior to be too weak. Since the posterior depends only on a single batch, we opted to introduce information about the whole dataset through the prior. A natural choice is to use the posterior distribution of the previous batch as our prior (Särkkä, 2013, Chapter 3). However, this method did not work well in our experiments and we developed an alternative. During each epoch, we collect the feature representations and labels of all examples in the dataset. At the end of the epoch, we compute the posterior based on the full data (with an isotropic Gaussian prior) and use this posterior as the prior at each step in the subsequent epoch. Updating the full-data prior more frequently is likely to have a beneficial effect on our overall model; however, it would also lengthen training. Hence, updating once per epoch strikes a good balance between performance and training time.
Aggregation step. Having an approximation for the gradient distribution of each task, we need to combine them to find an update direction for the shared parameters. Denote the mean of the gradient of the loss for task $k$ w.r.t. the hidden layer for example $i$ by $\mathbf{m}_k^i$, and similarly the covariance matrix by $\mathbf{V}_k^i$. We strive to find an update direction for the last shared layer, $\boldsymbol{\Delta}^i$, that lies in a high-density region for all tasks. Hence, we pick the $\boldsymbol{\Delta}^i$ that maximizes the following likelihood:
$$\boldsymbol{\Delta}^i = \arg\max_{\boldsymbol{\Delta}} \prod_{k=1}^{K} \mathcal{N}\big(\boldsymbol{\Delta} \mid \mathbf{m}_k^i, \mathbf{V}_k^i\big). \qquad (5)$$
Thankfully, the above optimization problem can be solved in closed-form, yielding the following solution:
$$\boldsymbol{\Delta}^i = \Big(\sum_{k=1}^{K} (\mathbf{V}_k^i)^{-1}\Big)^{-1} \sum_{k=1}^{K} (\mathbf{V}_k^i)^{-1}\,\mathbf{m}_k^i. \qquad (6)$$
However, we found that modeling the full covariance matrix can be numerically unstable and sensitive to noise in the gradient. Instead, we assume independence between the dimensions of $\mathbf{g}_k$ for all tasks, which results in diagonal covariance matrices having variances $\mathbf{v}_k^i$. The update direction now becomes:
$$\boldsymbol{\Delta}^i = \sum_{k=1}^{K} \boldsymbol{\pi}_k^i \odot \mathbf{m}_k^i, \qquad \boldsymbol{\pi}_k^i = \frac{1/\mathbf{v}_k^i}{\sum_{k'=1}^{K} 1/\mathbf{v}_{k'}^i}, \qquad (7)$$
where the division and multiplication are done element-wise. In Eq. 7 we intentionally denote by $\boldsymbol{\pi}_k^i$ the vector of uncertainty-based weights that our method assigns to the mean gradient, to highlight that the weights are unique per task, dimension, and datum. The final modification to the method involves down-scaling the impact of the precision by a hyper-parameter $\gamma$; namely, we take $\boldsymbol{\pi}_k^i \propto (1/\mathbf{v}_k^i)^{\gamma}$. Empirically, the scaling parameter helped to achieve better performance, perhaps due to misspecifications in the model (such as the diagonal Gaussian assumption over $\mathbf{g}_k$).
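The resulting update rule is a per-dimension precision-weighted average of the task gradient means. A minimal sketch (the tempered weighting with gamma follows our reading of the text above):

```python
import numpy as np

def aggregate_update(means, variances, gamma=1.0):
    """Per-dimension precision-weighted aggregation of task gradient means
    (Eq. 7). means, variances: arrays of shape (K, d) holding the
    diagonal-Gaussian moments of each task's gradient for one example.
    gamma tempers the precisions; gamma=1 is the exact maximum-likelihood
    solution, and gamma=0 reduces to a plain average of the means."""
    weights = (1.0 / variances) ** gamma
    weights = weights / weights.sum(axis=0, keepdims=True)  # normalize over tasks
    return (weights * means).sum(axis=0)
```

Note how each dimension of the update is dominated by the task that is most certain about it, rather than by a single scalar weight per task.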
Algorithm 1 BayesAgg-MTL
Input: a random batch of examples; posterior distributions over the task-specific parameters; scaling hyper-parameter $\gamma$.
For each example $i$ in the batch:
  For each task $k = 1, \ldots, K$:
    Compute $\mathbf{m}_k^i$ and $\mathbf{v}_k^i$ as in Eq. 4 for regression or Eq. 11 for classification.
    Set $\boldsymbol{\pi}_k^i \propto (1/\mathbf{v}_k^i)^{\gamma}$ (operations are done element-wise).
  End for
  Compute $\boldsymbol{\Delta}^i = \sum_{k=1}^{K} \boldsymbol{\pi}_k^i \odot \mathbf{m}_k^i$.
End for
Compute the gradient w.r.t. the shared parameters via matrix multiplication of the $\boldsymbol{\Delta}^i$ with the Jacobian of $\mathbf{h}$ w.r.t. $\theta_{sh}$.
With the aggregated gradient for each example, the back-propagation procedure proceeds as usual by averaging over all examples in the batch and then back-propagating this average to the shared parameters. To gain a better intuition about the update rule of BayesAgg-MTL, consider the illustration in Figure 2. In the figure, we plot the mean update direction of two tasks along with the uncertainty in them. The first task is more sensitive to shifts in the vertical dimension and less so to shifts in the second (horizontal) dimension, while for the second task, it is the opposite. By taking the variance information into account, BayesAgg-MTL can find an update direction that works well for both, compared to a simple average of the gradient means. We summarize our method in Algorithm 1.
Making predictions. Since we have a closed-form solution for the posterior of the task-specific parameters, BayesAgg-MTL does not learn this layer during training. Therefore, when making predictions we use the posterior mean, $\boldsymbol{\mu}_k$, computed on the full training set. We do so, instead of using full Bayesian inference, for a fair comparison with alternative MTL approaches and to have identical run-time and memory requirements when making predictions.
Connection to Nash-MTL. In Nash-MTL (Navon et al., 2022) the authors proposed a cooperative bargaining game approach to the gradient aggregation step, with the directional derivative as the utility of each player (task). They then proposed using the Nash bargaining solution, the direction that maximizes the product of all the utilities. One can consider Eq. 5 as the Nash bargaining solution with the utility of each task being its likelihood. However, unlike Navon et al. (2022), we get an analytical formula for the bargaining solution since the Gaussian exponent and the logarithm cancel out.
3.2 BayesAgg-MTL for Classification Tasks
We now turn to present our approach for classification tasks. When dealing with classification there are two sources of intractability that we need to overcome. The first is the posterior of $\mathbf{w}_k$, and the second is estimating the moments of $\mathbf{g}_k$. We describe our solution to both challenges next.
Posterior approximation. In classification tasks the likelihood is not Gaussian and, in general, we cannot compute the posterior in closed form. One common option is to approximate it using a Gaussian distribution and learn its parameters using a variational inference (VI) scheme (Saul et al., 1996; Neal & Hinton, 1998; Bishop, 2006). However, in our early experimentation, we did not find it to work well without a computationally expensive VI optimization at each update step. Alternatively to VI, the Laplace approximation (MacKay, 1992) approximates the posterior as a Gaussian using a second-order Taylor expansion. Since the expansion is done at the optimal parameter values that are learned point-wise, the Jacobian term in the expansion vanishes. Here, we follow a similar path; however, we cannot assume that the Jacobian is zero, as we are not near a stationary point during most of the training. Nevertheless, we can still find a Gaussian approximation. A similar derivation was proposed in (Immer et al., 2021), yet they eventually ignored the first-order term. Denote by $\hat{\mathbf{w}}_k$ the learned point estimate for the task parameters, and by $\mathbf{J}_k$ and $\mathbf{H}_k$ the gradient and Hessian of the negative log joint density evaluated at $\hat{\mathbf{w}}_k$. Then, at each step of the training, using Bayes rule we can obtain a posterior approximation for $\mathbf{w}_k$ via the following second-order expansion:
$$-\log p(\mathbf{w}_k \mid \mathcal{D}) \approx \mathbf{J}_k^T(\mathbf{w}_k - \hat{\mathbf{w}}_k) + \tfrac{1}{2}(\mathbf{w}_k - \hat{\mathbf{w}}_k)^T \mathbf{H}_k (\mathbf{w}_k - \hat{\mathbf{w}}_k) + \mathrm{const}. \qquad (8)$$
The above takes a quadratic form in $\mathbf{w}_k$ whose coefficients are known constants. We stress here again that, since we apply Bayesian inference to the last-layer parameters only, computing and inverting $\mathbf{H}_k$ typically does not incur a large computational overhead.
After rearranging and completing the square, we obtain a quadratic form corresponding to the following Gaussian distribution (see full derivation in Appendix A.2):
$$p(\mathbf{w}_k \mid \mathcal{D}) \approx \mathcal{N}\big(\mathbf{w}_k \mid \hat{\mathbf{w}}_k - \mathbf{H}_k^{-1}\mathbf{J}_k,\; \mathbf{H}_k^{-1}\big). \qquad (9)$$
Examining the above posterior reveals several insights. First, the posterior mean corresponds to a Newton method update step. Second, the covariance of this posterior is the same as that of the Laplace approximation. Third, as the gradient of the loss w.r.t. the parameters approaches zero, i.e., near a stationary point, the Laplace approximation is recovered.
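The posterior of Eq. 9 can be formed from any point estimate together with the gradient and Hessian at that point. A small sketch (hypothetical helper name, assuming an already positive-definite Hessian):

```python
import numpy as np

def taylor_posterior(w_hat, grad, hess):
    """Gaussian posterior approximation from a second-order Taylor expansion
    of the negative log joint around an arbitrary point estimate w_hat
    (Eq. 9): N(w_hat - H^{-1} g, H^{-1}). When grad = 0, i.e., at a
    stationary point, this reduces to the standard Laplace approximation."""
    hess_inv = np.linalg.inv(hess)
    mean = w_hat - hess_inv @ grad    # one Newton step away from w_hat
    return mean, hess_inv
```

For a quadratic negative log joint the approximation is exact: the returned mean is the true mode regardless of where the expansion is taken.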
One limitation of the approximation in Eq. 9 is that the Hessian will not be positive-definite in most cases. Therefore, we replace it with the generalized Gauss-Newton (GGN) matrix (Schraudolph, 2002; Martens & Sutskever, 2011; Daxberger et al., 2021):
$$\mathbf{H}_k \approx \sum_{i=1}^{|\mathcal{B}|} (\mathbf{J}_f^i)^T \boldsymbol{\Lambda}^i\, \mathbf{J}_f^i + \boldsymbol{\Sigma}_0^{-1}, \qquad (10)$$
where $\mathbf{J}_f^i$ is the Jacobian of the model output for task $k$ w.r.t. the last-layer parameters of that task, $\boldsymbol{\Lambda}^i$ is the Hessian of the negative log-likelihood w.r.t. the model outputs of task $k$, and $\boldsymbol{\Sigma}_0$ is the covariance of the Gaussian prior for $\mathbf{w}_k$. As in the regression case, we use here an informative prior based on the posterior from the full dataset at each training step.
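As a concrete instance of Eq. 10, consider a single binary task with a linear head and sigmoid likelihood (our assumed example, not the paper's general case). For a linear head the output Jacobian w.r.t. the weights is just the hidden representation, and the output Hessian of the negative log-likelihood is $p(1-p)$:

```python
import numpy as np

def ggn_binary_last_layer(H, w, prior_var=1.0):
    """GGN matrix (Eq. 10 style) for a linear binary-classification head
    f_i = h_i^T w with a sigmoid likelihood. The output Jacobian w.r.t. w
    is h_i, and the output Hessian of the negative log-likelihood is
    p_i (1 - p_i), giving sum_i p_i (1 - p_i) h_i h_i^T plus the prior
    precision. Each row of H is one hidden representation."""
    p = 1.0 / (1.0 + np.exp(-(H @ w)))
    lam = p * (1.0 - p)                   # per-example output Hessian
    ggn = (H * lam[:, None]).T @ H        # J^T diag(lam) J
    return ggn + np.eye(H.shape[1]) / prior_var
```

Unlike the raw Hessian, this matrix is positive-definite by construction, so it can always be inverted to form the posterior covariance.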
Moments estimation. Unlike the regression case, in classification $\mathbf{g}_k$ will depend on $\mathbf{w}_k$ through some non-linear function. Hence, obtaining the moments as in Eq. 4 in closed form is more challenging. However, since we are estimating the parameters of the last layer only, which in many cases are relatively low-dimensional, we can efficiently approximate these moments with Monte-Carlo sampling:
$$\mathbf{m}_k \approx \frac{1}{S}\sum_{s=1}^{S} \mathbf{g}_k^{(s)}, \qquad \mathbf{v}_k \approx \frac{1}{S}\sum_{s=1}^{S} \big(\mathbf{g}_k^{(s)} - \mathbf{m}_k\big)^2, \qquad \mathbf{g}_k^{(s)} = \nabla_{\mathbf{h}}\, \ell_k\big(\mathbf{w}_k^{(s)}\big). \qquad (11)$$
Here, $\mathbf{w}_k^{(s)}$ are samples from the posterior in Eq. 9, and the total number of samples is $S$. Effectively this means that we need to back-propagate gradients w.r.t. the shared hidden layer $S$ times; however, since the task-specific layers are linear, this can be done cheaply and in parallel. Having the moment estimation, we proceed with the aggregation rule as described in Section 3.1.
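For a binary sigmoid head (an assumed example), the gradient of the cross-entropy w.r.t. $\mathbf{h}$ is $(p - y)\,\mathbf{w}$, so the estimator in Eq. 11 can be sketched as:

```python
import numpy as np

def class_gradient_moments(post_mean, post_cov, h, y, n_samples=5000, seed=0):
    """Monte-Carlo moments (Eq. 11) of the binary cross-entropy gradient
    w.r.t. the shared representation h, under a Gaussian posterior over the
    linear head w. With p = sigmoid(w^T h), the per-sample gradient is
    (p - y) w; since the head is linear, all samples are evaluated at once."""
    rng = np.random.default_rng(seed)
    W = rng.multivariate_normal(post_mean, post_cov, size=n_samples)
    p = 1.0 / (1.0 + np.exp(-(W @ h)))
    G = (p - y)[:, None] * W          # per-sample gradient, shape (S, d)
    return G.mean(axis=0), G.var(axis=0)
```

The vectorized sampling reflects the remark above: because the head is linear, the $S$ back-propagations collapse into a single batched matrix operation.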
Making predictions. Unlike the regression case, here we learn the parameters of the last layer as part of the posterior approximation. Therefore, making predictions is done as usual with a forward-pass through the network.
4 Related Work
Multi-task learning is an active research area that attempts to jointly learn multiple tasks, commonly using a shared representation (Ruder, 2017; Navon et al., 2022; Liu et al., 2023; Elich et al., 2023; Shi et al., 2023; Yun & Cho, 2023). Learning a shared representation for multiple tasks imposes some challenges. One challenge is to learn an architecture that can express both task-shared and task-specific features. Another is to find the optimal balancing of the tasks, so that the different tasks are learned with equal importance. One line of research in MTL introduces novel MTL-friendly architectures, such as task-specific modules (Misra et al., 2016), attention-based networks (Liu et al., 2019a), and an ensemble of single-task models (Dimitriadis et al., 2023). A more common line of research focuses on the MTL optimization process, trying to explain its difficulties by, e.g., conflicting gradients (Wang et al., 2020) or plateaus in the loss landscape (Schaul et al., 2019). Our method focuses on the latter: improving the MTL optimization process.
Different strategies were proposed to address the MTL optimization challenge, to successfully balance the training of the different tasks and resolve their conflicts. The methods can broadly be categorized into two groups, loss-based and gradient-based (Dai et al., 2023). Loss-based approaches attempt to allocate weights for the tasks based on some criterion related to the loss, such as the difficulty of the task (Guo et al., 2018), random weights (Lin et al., 2022), the geometric mean of the task losses (Chennupati et al., 2019; Yun & Cho, 2023), and task uncertainty (Kendall et al., 2018). Regarding the last one, it weighs the tasks using the uncertainty in the observations only. This is very different from our approach, which weighs each dimension of the task gradients based on full Bayesian information.
Gradient-based methods attempt to balance the tasks by using the gradient information directly (Chen et al., 2018, 2020; Javaloy & Valera, 2022; Liu et al., 2020; Navon et al., 2022; Fernando et al., 2023; Senushkin et al., 2023). For example, GradNorm (Chen et al., 2018) dynamically tunes the gradient magnitudes to prevent imbalances between the tasks during training. PCGrad (Yu et al., 2020) identifies gradient conflicts as the main optimization issue in MTL and attempts to reduce them by projecting each gradient onto the other tasks' normal plane. Nash-MTL (Navon et al., 2022) suggests treating MTL as a bargaining game to find Pareto optimal solutions. Several studies suggested adaptations of the multiple-gradient descent algorithm (MGDA) (Désidéri, 2012; Sener & Koltun, 2018), such as CAGrad (Liu et al., 2021) and MoCo (Fernando et al., 2023). As opposed to previous methods, our approach considers both the mean and the variance of the gradients to derive an update direction.
Lastly, some studies recently suggested performing model merging based on the uncertainty of the parameters (Matena & Raffel, 2022; Daheim et al., 2023). The goal there is usually to combine models for various tasks, such as model ensembling, federated learning, and robust fine-tuning. Unlike these methods, we assume a Bayesian model on the last layer only and propagate the uncertainty to the gradients for gradient aggregation.
[Table 1: Test results on QM9 for LS, SI, RLW, DWA, UW, MGDA, PCGrad, CAGrad, IMTL-G, Nash-MTL, IGBv2, Aligned-MTL-UB, and BayesAgg-MTL (Ours).]
5 Experiments
We evaluated BayesAgg-MTL on several MTL benchmarks differing in the number of tasks and their types. Unless specified otherwise, we report the average and standard deviation (std) of relevant metrics over random seeds. In all datasets, we pre-allocated a validation set from the training set for hyper-parameter tuning and early stopping for all methods. Throughout our experiments, we used the ADAM optimizer (Kingma & Ba, 2015), which was found to be effective for MTL due to partial loss-scale invariance (Elich et al., 2023). Full experimental details are given in Appendix B.
Compared methods. We compare BayesAgg-MTL with the following baseline methods: (1) Single Task Learning (STL), which learns each task independently under the same experimental setup as that of the MTL methods; (2) Linear Scalarization (LS), which assigns a uniform weight to all tasks, namely optimizing $\sum_k \ell_k$; (3) Scale-Invariant (SI) (Navon et al., 2022), which assigns a uniform weight to the log of all task losses, namely optimizing $\sum_k \log \ell_k$; (4) Random Loss Weighting (RLW) (Lin et al., 2022), which allocates random weights to the losses at each iteration; (5) Dynamic Weight Average (DWA) (Liu et al., 2019a), which allocates a weight based on the rate of change of the loss for each task; (6) Uncertainty Weighting (UW) (Kendall et al., 2018), which minimizes a scalar term corresponding to the aleatoric uncertainty for each task; (7) Multiple-Gradient Descent Algorithm (MGDA) (Désidéri, 2012; Sener & Koltun, 2018), which finds a minimum-norm solution for a convex combination of the losses; (8) Projecting Conflicting Gradients (PCGrad) (Yu et al., 2020), which projects the gradient of each task onto the normal plane of tasks they are in conflict with; (9) Conflict-Averse Grad (CAGrad) (Liu et al., 2021), which searches for an update direction centered at the LS solution while minimizing conflicts in gradients; (10) Impartial MTL-Grad (IMTL-G) (Liu et al., 2020), which finds an update vector whose projection on each of the task gradients is equal; (11) Nash-MTL (Navon et al., 2022), which derives task weights based on the Nash bargaining solution; (12) Improvable Gap Balancing (IGBv2) (Dai et al., 2023), which suggests a reinforcement learning procedure to balance the task losses; (13) Aligned-MTL-UB (Senushkin et al., 2023), which aligns the principal components of a gradient matrix.
Evaluation metric. Unless specified otherwise, we report the $\Delta m\%$ metric introduced in (Maninis et al., 2019). This metric measures the average relative difference between a method and the STL baseline according to some criterion of interest $M_k$. Namely, $\Delta m\% = \frac{100}{K} \sum_{k=1}^{K} (-1)^{\delta_k} \frac{M_{m,k} - M_{STL,k}}{M_{STL,k}}$, where $M_{m,k}$ is the criterion value for task $k$ under method $m$, $M_{STL,k}$ is the criterion value for task $k$ under the STL baseline, and $\delta_k \in \{0, 1\}$. If $\delta_k = 0$ then a lower value for $M_k$ is better (e.g., task loss), and if $\delta_k = 1$ then a higher value for $M_k$ is preferred (e.g., task accuracy). A lower $\Delta m\%$ indicates better performance.
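As a sketch, the metric can be transcribed directly (the function name is ours):

```python
import numpy as np

def delta_m(method_metrics, stl_metrics, higher_is_better):
    """Delta-m metric (Maninis et al., 2019): average relative difference
    of a method versus the single-task (STL) baseline, in percent.
    higher_is_better[k] is True when a larger value of criterion k is
    preferred (e.g., accuracy) and False otherwise (e.g., loss)."""
    m = np.asarray(method_metrics, dtype=float)
    stl = np.asarray(stl_metrics, dtype=float)
    signs = np.where(np.asarray(higher_is_better), -1.0, 1.0)  # (-1)^{delta_k}
    return 100.0 * np.mean(signs * (m - stl) / stl)
```

For example, a method that reaches 90% of the STL accuracy on a single higher-is-better task scores $\Delta m\% = 10$, i.e., a 10% relative degradation.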
Pre-training stage. To obtain meaningful features for the Bayesian layer, it is common practice to apply a pre-training step using standard NN training for several epochs (Wilson et al., 2016a, b). We follow the same path here and apply an initial pre-training step using linear scalarization. We stress that in all the experiments, the overall number of training steps for BayesAgg-MTL (including the pre-training) is the same as for all methods.
[Table 2: Test results on CIFAR-MTL (accuracy) and ChestX-ray14 for LS, SI, RLW, DWA, UW, MGDA, PCGrad, CAGrad, IMTL-G, Nash-MTL, IGBv2, Aligned-MTL-UB, and BayesAgg-MTL (Ours).]
[Table 3: Test results on UTKFace (age, gender, ethnicity) for STL, LS, SI, RLW, DWA, UW, MGDA, PCGrad, CAGrad, IMTL-G, Nash-MTL, IGBv2, Aligned-MTL-UB, and BayesAgg-MTL (Ours).]
5.1 BayesAgg-MTL for Regression
We first evaluated BayesAgg-MTL on an MTL problem with regression tasks only. We used the QM9 dataset, which contains stable small organic molecules represented as graphs having node and edge features (Ramakrishnan et al., 2014; Wu et al., 2018). The goal here is to predict chemical properties, such as geometric and energetic ones, that may vary in scale and difficulty. We follow the experimental protocol of Navon et al. (2022). Specifically, we allocate approximately examples for training, with separate validation and testing sets with examples each. Additionally, we employ the message-passing neural network architecture (Gilmer et al., 2017) in conjunction with the pooling operator described in (Vinyals et al., 2016).
The test results for this dataset are presented in Table 1. Baseline method results were taken from (Dai et al., 2023), except for Aligned-MTL-UB, which is included here for the first time. The criterion used here is the mean absolute error (MAE) of the losses. From the table, BayesAgg-MTL achieves the best test performance, with a significant improvement compared to most of the baseline methods.
To gain better intuition into the weights that BayesAgg-MTL assigns, recall the vector of weights per example and task from Eq. 7, $\boldsymbol{\pi}_k^i$. Figure 3 depicts, for all tasks, the average over dimensions of $\boldsymbol{\pi}_k^i$ for random examples at the start, middle, and end of training. The plot reveals an interesting pattern. Early in training, the average weights are distributed among the tasks without any specific pattern. As training progresses, larger weights are assigned to some of the tasks, while the remaining tasks receive smaller weights. At the end of the training, this pattern changes, and the latter tasks are assigned larger weights compared to the former.
5.2 BayesAgg-MTL for Binary Classification
Next, we evaluated BayesAgg-MTL on the MTL benchmarks CIFAR-MTL (Krizhevsky et al., 2009; Rosenbaum et al., 2018) and ChestX-ray14 (Wang et al., 2017). To the best of our knowledge, we are the first to evaluate MTL methods on the latter dataset. These datasets contain a large number of tasks with highly class-imbalanced distributions. This poses a significant challenge for current MTL methods.
CIFAR-MTL uses the coarse labels of the CIFAR-100 dataset to create an MTL benchmark having binary tasks. Classes from this dataset are grouped into super-classes (fish, flowers, trees, etc.), such that each example is given a one-hot encoding vector of labels indicating the super-class it belongs to. We use the official train-test split having examples and examples respectively. We allocate examples from the training set for a validation set. Our experiments on this dataset were conducted using a simple NN having convolution layers.
ChestX-ray14 contains chest X-ray images from patients. Each image has labels from binary classes corresponding to the presence or absence of thoracic diseases; multiple diseases can appear together in a patient. In our experiments, we mostly follow the training protocol suggested in (Taslimi et al., 2022), which uses ResNet-34 for the shared parameters. We use the official split of for training, validation, and test.
We present the test results for these datasets in Table 2. On CIFAR-MTL we report the accuracy of class assignment, and on ChestX-ray14 we report a metric based on the per-task AUC-ROC values. From the table, BayesAgg-MTL performs best on both datasets. Interestingly, on ChestX-ray14 almost all methods, except ours and DWA, under-perform the naive LS baseline. In Appendix C.2 we compare the run-time of all methods on this dataset and on QM9, and show that BayesAgg-MTL is substantially faster than baseline methods that use gradients w.r.t. the shared parameters to weigh the tasks.
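A hedged sketch of computing and averaging per-task AUC-ROC values (the exact aggregate metric notation did not survive extraction). The rank-sum formulation below assumes untied scores; all names are ours.

```python
import numpy as np

def auc_roc(scores, labels):
    """AUC-ROC via the rank-sum (Mann-Whitney U) formulation.

    Assumes no tied scores; a full implementation would average tied ranks.
    """
    scores, labels = np.asarray(scores), np.asarray(labels)
    order = np.argsort(scores)
    ranks = np.empty(len(scores))
    ranks[order] = np.arange(1, len(scores) + 1)
    pos = labels == 1
    n_pos, n_neg = pos.sum(), (~pos).sum()
    # Sum of positive ranks minus its minimum possible value, normalized
    # by the number of positive/negative pairs.
    return (ranks[pos].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

def mean_task_auc(score_matrix, label_matrix):
    """Average AUC-ROC over binary tasks stored as columns."""
    return np.mean([auc_roc(score_matrix[:, t], label_matrix[:, t])
                    for t in range(score_matrix.shape[1])])

s = np.array([0.1, 0.4, 0.35, 0.8])
y = np.array([0, 0, 1, 1])
```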
5.3 BayesAgg-MTL for Mixed Tasks
In the last set of experiments, we evaluated BayesAgg-MTL and the baseline methods on the UTKFace dataset (Zhang et al., 2017). This dataset contains over face images annotated with age, gender, and ethnicity. Age values range from to and are treated as a regression task. Gender is a binary classification task (male or female), while ethnicity is classified into five distinct categories, making it a multi-class classification task. We split the dataset according to into train, validation, and test sets. Here, we use ResNet-18 for the shared network.
Results for this dataset, based on random seeds, are presented in Table 3. Here as well, BayesAgg-MTL outperforms all methods, achieving the best results on out of tasks. Interestingly, our approach and MGDA were the only methods to improve upon the STL baseline on the regression task.
6 Conclusions
In this study, we presented BayesAgg-MTL, a novel method for aggregating task gradients in MTL. Instead of treating each task's gradient as a deterministic quantity, we advocate placing a probability distribution over it. The randomness arises from the observation that many configurations of the task-specific parameters work well; by tracking them with Bayesian tools we obtain a richer description of the gradient space. This, in turn, allows us to model the uncertainty in the gradients and derive an update direction for the shared parameters that takes this uncertainty into account. We demonstrated our method's effectiveness on several benchmark datasets against leading baseline methods. For future work, we would like to extend BayesAgg-MTL beyond linear task heads; the challenge there is to efficiently estimate the Bayesian posterior and the gradient moments. Another possible limitation of BayesAgg-MTL, shared with other popular MTL methods, is that it may fail on rare or atypical examples (Sagawa et al., 2019).
Acknowledgements
This study was funded by a grant to GC from the Israel Science Foundation (ISF 737/2018), and by an equipment grant to GC and Bar-Ilan University from the Israel Science Foundation (ISF 2332/18). IA is supported by a PhD fellowship from Bar-Ilan data science institute (BIU DSI).
Impact Statement
This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.
References
Achituve etal. (2021a)Achituve, I., Maron, H., and Chechik, G.Self-supervised learning for domain adaptation on point clouds.In Proceedings of the IEEE/CVF winter conference onapplications of computer vision, pp. 123–133, 2021a.
Achituve etal. (2021b)Achituve, I., Navon, A., Yemini, Y., Chechik, G., and Fetaya, E.GP-Tree: A Gaussian process classifier for few-shot incrementallearning.In International Conference on Machine Learning, pp. 54–65.PMLR, 2021b.
Achituve etal. (2021c)Achituve, I., Shamsian, A., Navon, A., Chechik, G., and Fetaya, E.Personalized federated learning with Gaussian processes.Advances in Neural Information Processing Systems,34:8392–8406, 2021c.
Achituve etal. (2023)Achituve, I., Chechik, G., and Fetaya, E.Guided deep kernel learning.In Uncertainty in Artificial Intelligence. PMLR, 2023.
Baxter (2000)Baxter, J.A model of inductive bias learning.Journal of artificial intelligence research, 12:149–198, 2000.
Bishop (2006)Bishop, C.Pattern recognition and machine learning.Springer, 2006.
Brier (1950)Brier, G.W.Verification of forecasts expressed in terms of probability.Monthly weather review, 78(1):1–3, 1950.
Calandra etal. (2016)Calandra, R., Peters, J., Rasmussen, C.E., and Deisenroth, M.P.Manifold Gaussian processes for regression.In 2016 International Joint Conference on Neural Networks(IJCNN), pp. 3338–3345. IEEE, 2016.
Chen etal. (2018)Chen, Z., Badrinarayanan, V., Lee, C.-Y., and Rabinovich, A.GradNorm: Gradient normalization for adaptive loss balancing indeep multitask networks.In Dy, J. and Krause, A. (eds.), Proceedings of the 35thInternational Conference on Machine Learning, volume80 of Proceedingsof Machine Learning Research, pp. 794–803. PMLR, 10–15 Jul 2018.
Chen etal. (2020)Chen, Z., Ngiam, J., Huang, Y., Luong, T., Kretzschmar, H., Chai, Y., andAnguelov, D.Just pick a sign: Optimizing deep multitask models with gradient signdropout.In Larochelle, H., Ranzato, M., Hadsell, R., Balcan, M., and Lin, H.(eds.), Advances in Neural Information Processing Systems, volume33,pp. 2039–2050. Curran Associates, Inc., 2020.
Chennupati etal. (2019)Chennupati, S., Sistu, G., Yogamani, S., and Rawashdeh, S.Multinet++: Multi-stream feature aggregation and geometric lossstrategy for multi-task learning.In Proceedings of the IEEE/CVF Conference on Computer Visionand Pattern Recognition (CVPR) Workshops, June 2019.
Daheim etal. (2023)Daheim, N., Möllenhoff, T., Ponti, E., Gurevych, I., and Khan, M.E.Model merging by uncertainty-based gradient matching.In The Twelfth International Conference on LearningRepresentations, 2023.
Dai etal. (2023)Dai, Y., Fei, N., and Lu, Z.Improvable gap balancing for multi-task learning.In Uncertainty in Artificial Intelligence, pp. 496–506.PMLR, 2023.
D’Angelo & Fortuin (2021)D’Angelo, F. and Fortuin, V.Repulsive deep ensembles are Bayesian.Advances in Neural Information Processing Systems,34:3451–3465, 2021.
Davies (1973)Davies, R.B.Numerical inversion of a characteristic function.Biometrika, 60(2):415–417, 1973.
Daxberger etal. (2021)Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., andHennig, P.Laplace redux-effortless bayesian deep learning.Advances in Neural Information Processing Systems,34:20089–20103, 2021.
Devin etal. (2017)Devin, C., Gupta, A., Darrell, T., Abbeel, P., and Levine, S.Learning modular neural network policies for multi-task andmulti-robot transfer.In 2017 IEEE international conference on robotics andautomation (ICRA), pp. 2169–2176. IEEE, 2017.
Dimitriadis etal. (2023)Dimitriadis, N., Frossard, P., and Fleuret, F.Pareto manifold learning: Tackling multiple tasks via ensembles ofsingle-task models.In International Conference on Machine Learning, pp.8015–8052. PMLR, 2023.
Elich etal. (2023)Elich, C., Kirchdorfer, L., Köhler, J.M., and Schott, L.Challenging common assumptions in multi-task learning.arXiv preprint arXiv:2311.04698, 2023.
Fernando etal. (2023)Fernando, H., Shen, H., Liu, M., Chaudhury, S., Murugesan, K., and Chen, T.Mitigating gradient bias in multi-objective learning: A provablyconvergent stochastic approach.In International Conference on Learning Representations, 2023.
Fey & Lenssen (2019)Fey, M. and Lenssen, J.E.Fast graph representation learning with pytorch geometric.In ICLR Workshop on Representation Learning on Graphs andManifolds, 2019.
Fortuin etal. (2021)Fortuin, V., Garriga-Alonso, A., Ober, S.W., Wenzel, F., Ratsch, G., Turner,R.E., vander Wilk, M., and Aitchison, L.Bayesian neural network priors revisited.In International Conference on Learning Representations, 2021.
Gilmer etal. (2017)Gilmer, J., Schoenholz, S.S., Riley, P.F., Vinyals, O., and Dahl, G.E.Neural message passing for quantum chemistry.In International conference on machine learning, pp.1263–1272. PMLR, 2017.
Guo etal. (2018)Guo, M., Haque, A., Huang, D.-A., Yeung, S., and Fei-Fei, L.Dynamic task prioritization for multitask learning.In Proceedings of the European Conference on Computer Vision(ECCV), September 2018.
Immer etal. (2021)Immer, A., Bauer, M., Fortuin, V., Rätsch, G., and Emtiyaz, K.M.Scalable marginal likelihood estimation for model selection in deeplearning.In International Conference on Machine Learning, pp.4563–4573. PMLR, 2021.
Javaloy & Valera (2022)Javaloy, A. and Valera, I.RotoGrad: Gradient homogenization in multitask learning.In International Conference on Learning Representations, 2022.
Kendall etal. (2018)Kendall, A., Gal, Y., and Cipolla, R.Multi-task learning using uncertainty to weigh losses for scenegeometry and semantics.In Proceedings of the IEEE conference on computer vision andpattern recognition, pp. 7482–7491, 2018.
Kingma & Ba (2014)Kingma, D.P. and Ba, J.ADAM: A method for stochastic optimization.In International Conference on Learning Representations, 2014.
Kingma & Ba (2015)Kingma, D.P. and Ba, J.Adam: A method for stochastic optimization.In Bengio, Y. and LeCun, Y. (eds.), 3rd InternationalConference on Learning Representations, 2015.
Kristiadi etal. (2020)Kristiadi, A., Hein, M., and Hennig, P.Being Bayesian, even just a bit, fixes overconfidence in Relunetworks.In International conference on machine learning, pp.5436–5446. PMLR, 2020.
Krizhevsky etal. (2009)Krizhevsky, A., Hinton, G., etal.Learning multiple layers of features from tiny images.Technical report, University of Toronto, 2009.
Kurin etal. (2022)Kurin, V., DePalma, A., Kostrikov, I., Whiteson, S., and Mudigonda, P.K.In defense of the unitary scalarization for deep multi-task learning.Advances in Neural Information Processing Systems,35:12169–12183, 2022.
Lakshminarayanan etal. (2017)Lakshminarayanan, B., Pritzel, A., and Blundell, C.Simple and scalable predictive uncertainty estimation using deepensembles.Advances in neural information processing systems, 30, 2017.
Lin etal. (2022)Lin, B., Ye, F., Zhang, Y., and Tsang, I.W.Reasonable effectiveness of random weighting: A litmus test formulti-task learning.Transactions on Machine Learning Research, 2022.
Liu etal. (2021)Liu, B., Liu, X., Jin, X., Stone, P., and Liu, Q.Conflict-averse gradient descent for multi-task learning.Advances in Neural Information Processing Systems,34:18878–18890, 2021.
Liu etal. (2023)Liu, B., Feng, Y., Stone, P., and Liu, Q.Famo: Fast adaptive multitask optimization, 2023.
Liu etal. (2020)Liu, L., Li, Y., Kuang, Z., Xue, J.-H., Chen, Y., Yang, W., Liao, Q., andZhang, W.Towards impartial multi-task learning.In International Conference on Learning Representations, 2020.
Liu etal. (2019a)Liu, S., Johns, E., and Davison, A.J.End-to-end multi-task learning with attention.In Proceedings of the IEEE/CVF conference on computer visionand pattern recognition, pp. 1871–1880, 2019a.
Liu etal. (2019b)Liu, X., He, P., Chen, W., and Gao, J.Multi-task deep neural networks for natural language understanding.In Proceedings of the 57th Annual Meeting of the Associationfor Computational Linguistics, pp. 4487–4496, 2019b.
Maninis etal. (2019)Maninis, K.-K., Radosavovic, I., and Kokkinos, I.Attentive single-tasking of multiple tasks.In Proceedings of the IEEE/CVF conference on computer visionand pattern recognition, pp. 1851–1860, 2019.
Martens & Sutskever (2011)Martens, J. and Sutskever, I.Learning recurrent neural networks with hessian-free optimization.In Proceedings of the 28th international conference on machinelearning (ICML-11), pp. 1033–1040, 2011.
Matena & Raffel (2022)Matena, M.S. and Raffel, C.A.Merging models with fisher-weighted averaging.Advances in Neural Information Processing Systems,35:17703–17716, 2022.
Michelsanti etal. (2021)Michelsanti, D., Tan, Z.-H., Zhang, S.-X., Xu, Y., Yu, M., Yu, D., and Jensen,J.An overview of deep-learning-based audio-visual speech enhancementand separation.IEEE/ACM Transactions on Audio, Speech, and LanguageProcessing, 29:1368–1396, 2021.
Minka (2001)Minka, T.P.Expectation propagation for approximate bayesian inference.In Proceedings of the Seventeenth conference on Uncertainty inartificial intelligence, pp. 362–369, 2001.
Misra etal. (2016)Misra, I., Shrivastava, A., Gupta, A., and Hebert, M.Cross-stitch networks for multi-task learning.In Proceedings of the IEEE/CVF Conference on Computer Visionand Pattern Recognition (CVPR), pp. 3994–4003, 06 2016.doi: 10.1109/CVPR.2016.433.
Naeini etal. (2015)Naeini, M.P., Cooper, G.F., and Hauskrecht, M.Obtaining well calibrated probabilities using Bayesian binning.In Proceedings of the Twenty-Ninth AAAI Conference onArtificial Intelligence, January 25-30, 2015, Austin, Texas, USA, pp.2901–2907. AAAI Press, 2015.
Navon etal. (2022)Navon, A., Shamsian, A., Achituve, I., Maron, H., Kawaguchi, K., Chechik, G.,and Fetaya, E.Multi-task learning as a bargaining game.In International Conference on Machine Learning, pp.16428–16446. PMLR, 2022.
Neal & Hinton (1998)Neal, R.M. and Hinton, G.E.A view of the EM algorithm that justifies incremental, sparse, andother variants.In Learning in graphical models, pp. 355–368. Springer,1998.
Ramakrishnan etal. (2014)Ramakrishnan, R., Dral, P.O., Rupp, M., and VonLilienfeld, O.A.Quantum chemistry structures and properties of 134 kilo molecules.Scientific data, 1(1):1–7, 2014.
Rosenbaum etal. (2018)Rosenbaum, C., Klinger, T., and Riemer, M.Routing networks: Adaptive selection of non-linear functions formulti-task learning.In International Conference on Learning Representations, 2018.
Ruder (2017)Ruder, S.An overview of multi-task learning in deep neural networks.arXiv preprint arXiv:1706.05098, 2017.
Russakovsky etal. (2015)Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z.,Karpathy, A., Khosla, A., Bernstein, M., etal.ImageNet large scale visual recognition challenge.International journal of computer vision, 115:211–252, 2015.
Sagawa etal. (2019)Sagawa, S., Koh, P.W., Hashimoto, T.B., and Liang, P.Distributionally robust neural networks.In International Conference on Learning Representations, 2019.
Särkkä (2013)Särkkä, S.Bayesian Filtering and Smoothing, volume3 of Instituteof Mathematical Statistics textbooks.Cambridge University Press, 2013.
Saul etal. (1996)Saul, L.K., Jaakkola, T., and Jordan, M.I.Mean field theory for Sigmoid Belief Networks.Journal of artificial intelligence research, 4:61–76, 1996.
Schaul etal. (2019)Schaul, T., Borsa, D., Modayil, J., and Pascanu, R.Ray interference: a source of plateaus in deep reinforcementlearning, 2019.
Sener & Koltun (2018)Sener, O. and Koltun, V.Multi-task learning as multi-objective optimization.Advances in neural information processing systems, 31, 2018.
Senushkin etal. (2023)Senushkin, D., Patakin, N., Kuznetsov, A., and Konushin, A.Independent component alignment for multi-task learning.In Proceedings of the IEEE/CVF Conference on Computer Visionand Pattern Recognition, pp. 20083–20093, 2023.
Shamshad etal. (2023)Shamshad, F., Khan, S., Zamir, S.W., Khan, M.H., Hayat, M., Khan, F.S., andFu, H.Transformers in medical imaging: A survey.Medical Image Analysis, pp. 102802, 2023.
Shamsian etal. (2023)Shamsian, A., Navon, A., Glazer, N., Kawaguchi, K., Chechik, G., and Fetaya, E.Auxiliary learning as an asymmetric bargaining game.arXiv preprint arXiv:2301.13501, 2023.
Shi etal. (2023)Shi, H., Ren, S., Zhang, T., and Pan, S.J.Deep multitask learning with progressive parameter sharing.In Proceedings of the IEEE/CVF International Conference onComputer Vision, pp. 19924–19935, 2023.
Shu etal. (2018)Shu, T., Xiong, C., and Socher, R.Hierarchical and interpretable skill acquisition in multi-taskreinforcement learning.In International Conference on Learning Representations, 2018.
Snoek etal. (2015)Snoek, J., Rippel, O., Swersky, K., Kiros, R., Satish, N., Sundaram, N.,Patwary, M., Prabhat, M., and Adams, R.Scalable Bayesian optimization using deep neural networks.In International conference on machine learning, pp.2171–2180. PMLR, 2015.
Standley etal. (2020)Standley, T., Zamir, A., Chen, D., Guibas, L., Malik, J., and Savarese, S.Which tasks should be learned together in multi-task learning?In International Conference on Machine Learning, pp.9120–9132. PMLR, 2020.
Taslimi etal. (2022)Taslimi, S., Taslimi, S., Fathi, N., Salehi, M., and Rohban, M.H.SwincheX: Multi-label classification on chest X-ray images withtransformers.arXiv preprint arXiv:2206.04246, 2022.
Vinyals etal. (2016)Vinyals, O., Bengio, S., and Kudlur, M.Order matters: Sequence to sequence for sets.In Bengio, Y. and LeCun, Y. (eds.), 4th InternationalConference on Learning Representations, ICLR, 2016.
Wang etal. (2017)Wang, X., Peng, Y., Lu, L., Lu, Z., Bagheri, M., and Summers, R.M.ChestX-ray8: Hospital-scale chest X-ray database and benchmarkson weakly-supervised classification and localization of common thoraxdiseases.In Proceedings of the IEEE conference on computer vision andpattern recognition, pp. 2097–2106, 2017.
Wang etal. (2020)Wang, Z., Tsvetkov, Y., Firat, O., and Cao, Y.Gradient vaccine: Investigating and improving multi-task optimizationin massively multilingual models.In International Conference on Learning Representations, 2020.
Wild etal. (2024)Wild, V.D., Ghalebikesabi, S., Sejdinovic, D., and Knoblauch, J.A rigorous link between deep ensembles and (variational) Bayesianmethods.Advances in Neural Information Processing Systems, 36, 2024.
Wilson & Izmailov (2020)Wilson, A.G. and Izmailov, P.Bayesian deep learning and a probabilistic perspective ofgeneralization.Advances in neural information processing systems,33:4697–4708, 2020.
Wilson etal. (2016a)Wilson, A.G., Hu, Z., Salakhutdinov, R., and Xing, E.P.Deep kernel learning.In Artificial intelligence and statistics, pp. 370–378.PMLR, 2016a.
Wilson etal. (2016b)Wilson, A.G., Hu, Z., Salakhutdinov, R.R., and Xing, E.P.Stochastic variational deep kernel learning.Advances in neural information processing systems, 29,2016b.
Wu etal. (2018)Wu, Z., Ramsundar, B., Feinberg, E.N., Gomes, J., Geniesse, C., Pappu, A.S.,Leswing, K., and Pande, V.MoleculeNet: a benchmark for molecular machine learning.Chemical science, 9(2):513–530, 2018.
Xin etal. (2022)Xin, D., Ghorbani, B., Gilmer, J., Garg, A., and Firat, O.Do current multi-task optimization methods in deep learning evenhelp?Advances in Neural Information Processing Systems,35:13597–13609, 2022.
Yu etal. (2020)Yu, T., Kumar, S., Gupta, A., Levine, S., Hausman, K., and Finn, C.Gradient surgery for multi-task learning.Advances in Neural Information Processing Systems,33:5824–5836, 2020.
Yun & Cho (2023)Yun, H. and Cho, H.Achievement-based training progress balancing for multi-tasklearning.In Proceedings of the IEEE/CVF International Conference onComputer Vision (ICCV), pp. 16935–16944, October 2023.
Zhang etal. (2017)Zhang, Z., Song, Y., and Qi, H.Age progression/regression by conditional adversarial autoencoder.In Proceedings of the IEEE conference on computer vision andpattern recognition, pp. 5810–5818, 2017.
Zheng etal. (2023)Zheng, C., Wu, W., Chen, C., Yang, T., Zhu, S., Shen, J., Kehtarnavaz, N., andShah, M.Deep learning-based human pose estimation: A survey.ACM Computing Surveys, 56(1):1–37, 2023.
Zhou etal. (2023)Zhou, C., Li, Q., Li, C., Yu, J., Liu, Y., Wang, G., Zhang, K., Ji, C., Yan,Q., He, L., etal.A comprehensive survey on pretrained foundation models: A historyfrom BERT to ChatGPT.arXiv preprint arXiv:2302.09419, 2023.
Appendix A Full Derivations
We now present the full derivations for Eqs. 4 and 9 from the main text. For clarity, we drop the task superscript here.
A.1 Regression Moments
Starting with the first moment,
(12)
where in the first step we made explicit the dependence on . To compute the second moment, we make use of the matrix reference manual (Brookes, 2020),
(13)
Solving each term separately, we obtain the result,
(14)
where .
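For intuition, the building block behind such moment computations is the standard Gaussian identity for a linear last layer; with a Gaussian posterior over the head weights and a fixed feature vector (generic symbols, not necessarily the paper's exact notation):

```latex
% Moments of a Gaussian linear form, w ~ N(mu, Sigma), fixed features h.
\mathbb{E}\left[w^\top h\right] = \mu^\top h, \qquad
\mathbb{E}\left[(w^\top h)^2\right] = (\mu^\top h)^2 + h^\top \Sigma h, \qquad
\operatorname{Var}\left(w^\top h\right) = h^\top \Sigma h.
```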
A.2 Second Order Posterior Approximation
We now derive the quadratic form of the log-posterior in Eq. 9. First, we recap some of our notation:
(15)
Using these constants in Eq.8 yields the following form:
(16)
The above is the quadratic form of a Gaussian with mean and covariance .
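The completing-the-square step uses the standard correspondence between a quadratic log-density and Gaussian parameters; for a generic positive-definite matrix $A$ and vector $b$ (placeholders, not the paper's exact symbols):

```latex
% Completing the square for a quadratic log-density.
\log p(\theta) = -\tfrac{1}{2}\,\theta^\top A\,\theta + b^\top \theta + \text{const}
\;\Longleftrightarrow\;
\theta \sim \mathcal{N}\!\left(A^{-1} b,\; A^{-1}\right).
```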
Appendix B Full Experimental Details
All the experiments were done using PyTorch on NVIDIA V100 and A100 GPUs having 32GB of memory.
QM9. On this dataset we followed the training protocol of Navon et al. (2022). Specifically, we allocated examples for training and examples for validation and testing. The task labels are normalized to have zero mean and unit standard deviation. We use the implementation of (Fey & Lenssen, 2019) for the message-passing NN presented in (Gilmer et al., 2017) as the base NN. Here, we trained only our method and the baseline Aligned-MTL-UB; all other results were taken from (Navon et al., 2022; Dai et al., 2023), and we used the same random seeds as in those studies. Each method was trained for epochs using the ADAM optimizer (Kingma & Ba, 2014) with an initial learning rate of . The batch size was set to . We use the ReduceLROnPlateau scheduler with the metric, computed on the validation set; this metric was also used for early stopping and model selection. For BayesAgg-MTL we set the number of pre-training epochs using linear scalarization to . In initial experiments, we found that relatively higher values of the hyper-parameter were preferred for regression tasks; hence, we searched over . For Aligned-MTL-UB we did a hyper-parameter search over the scale modes {min, median, rmse}, and over whether to apply the scaling to the task-specific parameters as well.
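For concreteness, the ADAM + ReduceLROnPlateau setup above can be sketched in PyTorch; `make_optimization`, the toy linear model, and the learning-rate value are illustrative assumptions, not the paper's exact configuration.

```python
import torch

def make_optimization(model, lr=1e-3):
    """ADAM plus a ReduceLROnPlateau scheduler driven by a validation metric.

    The lr value is illustrative; the paper's exact initial lr did not
    survive extraction.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min")
    return opt, sched

model = torch.nn.Linear(4, 1)
opt, sched = make_optimization(model)
sched.step(1.0)  # call once per epoch with the validation metric
```

The same scheduler value can also drive early stopping and model selection, as described above.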
CIFAR-MTL. Similarly to (Rosenbaum et al., 2018), we used the coarse labels of CIFAR-100 to form an MTL benchmark. Each example in the CIFAR-100 dataset belongs to one of super-classes. We use these super-classes as separate binary MTL tasks, where the task value is if the example belongs to the super-class and otherwise. We use the official CIFAR train-test split of and examples, respectively, and allocated examples from the training set for validation. To train the models we use a CNN with convolution layers of channels and a kernel of size . Each convolution is followed by an Exponential Linear Unit (ELU) activation and max-pooling of ; the final layer is a batch-normalization layer. All methods were trained for epochs using the ADAM optimizer, with an initial learning rate of and a scheduler that drops the learning rate by a factor of at and of the training. We set the batch size to and used a weight decay of . For all baseline methods, we did a hyper-parameter grid search over the most important hyper-parameters. In particular, we searched over additional weight-decay values for the LS, SI, and RLW baselines, as advocated by Kurin et al. (2022). As for BayesAgg-MTL, unlike the regression case, smaller values are preferred for classification; we searched over . We also searched over the number of pre-training epochs in . We set , the number of Monte-Carlo samples, to , although we could have used far fewer without performance degradation. We used the validation accuracy for early stopping and model selection.
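The CNN described above (convolutions with ELU activations and max-pooling, ending in a batch-normalization layer) can be sketched as follows. Since the exact depth, channel count, and kernel size did not survive extraction, the values here (4 conv layers, 64 channels, 3x3 kernels) are illustrative assumptions.

```python
import torch
import torch.nn as nn

class CIFARBackbone(nn.Module):
    """Simple conv backbone: [Conv -> ELU -> MaxPool] x N, then BatchNorm.

    Depth, channels, and kernel size are illustrative placeholders.
    """
    def __init__(self, channels=64, num_convs=4):
        super().__init__()
        layers, in_ch = [], 3
        for _ in range(num_convs):
            layers += [nn.Conv2d(in_ch, channels, kernel_size=3, padding=1),
                       nn.ELU(),
                       nn.MaxPool2d(2)]
            in_ch = channels
        layers.append(nn.BatchNorm2d(channels))
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        # Flatten spatial dimensions into the shared feature vector fed
        # to the task-specific heads.
        return self.body(x).flatten(1)
```

On 32x32 CIFAR inputs, four pooling steps leave a 2x2 spatial map, giving a 256-dimensional shared feature under these placeholder sizes.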
ChestX-ray14. This dataset records the presence or absence of types of chest diseases, which we view as an MTL problem. It contains approximately images from patients. We use the official data split of (Wang et al., 2017), with training, validation, and test examples. We follow the experimental setup of (Taslimi et al., 2022), which uses the publicly available PyTorch Image Models repository (Wightman, 2019) for data augmentation. We resize each image to , and apply augmentations such as color jitter with intensity and random erasing of pixels with probability . Images are normalized according to ImageNet statistics (Russakovsky et al., 2015). We use ResNet-34 pre-trained on ImageNet as the shared feature extractor, replacing the final classification layer with a fully connected layer of dimension followed by an ELU activation. Experimental details and hyper-parameter searches are similar to those described for CIFAR-MTL, with the following changes: here we trained for epochs, the batch size was set to , and we did not use weight decay. We use the metric for early stopping and model selection.
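The ImageNet normalization step can be sketched as follows; the mean/std values are the standard ImageNet statistics used by common vision pipelines, and the helper name is ours.

```python
import numpy as np

# Standard ImageNet per-channel statistics (RGB), as used by torchvision.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize(img):
    """Normalize an HWC float image in [0, 1] with ImageNet statistics."""
    return (img - IMAGENET_MEAN) / IMAGENET_STD
```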
UTKFace. This dataset contains approximately face images, each annotated with the age, gender, and ethnicity of the person. We remove examples with missing labels and split the dataset into train/validation/test according to the scheme, stratified by the age variable as it is the most diverse label. We treat age prediction as a regression task and normalize the labels to have zero mean and unit standard deviation. During training, images are resized to , randomly cropped to , and undergo a random horizontal flip; test images are resized and center-cropped. Here, we used ResNet-18 with the final classification layer replaced by a fully connected layer of size and an ELU activation. The experimental setup is similar to that described for CIFAR-MTL, except that here we trained for epochs. We performed a hyper-parameter grid search for all methods on this dataset as well; for our method, we set the number of pre-training epochs to and searched over for the regression task and for the classification tasks. We use the metric for early stopping and model selection. The regression task is optimized and evaluated with the MSE loss, and the classification tasks with the standard cross-entropy loss.
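A minimal sketch of an age-stratified split, assuming quantile bins and illustrative 70/15/15 fractions (the paper's exact split fractions did not survive extraction); all names are ours.

```python
import numpy as np

def stratified_split(ages, fracs=(0.7, 0.15, 0.15), n_bins=10, seed=0):
    """Split indices into train/val/test, stratified by binned age.

    The fractions and bin count are illustrative placeholders.
    """
    ages = np.asarray(ages)
    rng = np.random.default_rng(seed)
    # Quantile-based bin edges so each bin holds a similar number of examples.
    edges = np.quantile(ages, np.linspace(0.0, 1.0, n_bins + 1)[1:-1])
    bins = np.digitize(ages, edges)
    train, val, test = [], [], []
    for b in np.unique(bins):
        idx = rng.permutation(np.where(bins == b)[0])
        n_tr = int(fracs[0] * len(idx))
        n_va = int(fracs[1] * len(idx))
        train += idx[:n_tr].tolist()
        val += idx[n_tr:n_tr + n_va].tolist()
        test += idx[n_tr + n_va:].tolist()
    return train, val, test
```

Stratifying on the binned regression label keeps the age distribution similar across splits, which is the stated motivation above.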
[Table 4 (run time on CIFAR-MTL and QM9): LS, SI, RLW, DWA, UW, MGDA, PCGrad, CAGrad, IMTL-G, Nash-MTL, IGBv2; Gradient w.r.t Representation: MGDA-UB, IMTL-G, Aligned-MTL-UB; BayesAgg-MTL (Ours). Numerical values not recovered.]
Appendix C Additional Experiments
C.1 Calibration
A possible benefit of using a Bayesian last layer is enhanced uncertainty estimation. Here we compare BayesAgg-MTL to the baseline methods on that aspect by logging the expected calibration error (ECE) (Naeini et al., 2015) and Brier score (Brier, 1950) on the classification tasks of the UTKFace dataset. For ECE we first discretize the unit interval of confidence values into bins and then measure a weighted average difference between the classifier's confidence and its accuracy; we use interval bins in our comparison. The Brier score measures the mean squared error between the one-hot label encoding and the predicted probability vector. In both metrics, lower values are better. Results are presented in Figure 4. From the figure, BayesAgg-MTL is better calibrated than most methods on both tasks. On the gender task it is best calibrated according to both metrics; on the ethnicity task it has the best Brier score and second-best ECE. We stress that, for a fair comparison with the baseline methods, we did not use the Bayesian posterior of BayesAgg-MTL over the last layer to make test predictions, but rather the point estimate learned during training. Using the full posterior should yield even better results.
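The two calibration metrics can be sketched as follows; `n_bins=10` is a common ECE default, used here because the exact bin count above did not survive extraction.

```python
import numpy as np

def brier_score(probs, labels, num_classes):
    """Mean squared error between one-hot labels and predicted probabilities."""
    onehot = np.eye(num_classes)[labels]
    return np.mean(np.sum((probs - onehot) ** 2, axis=1))

def expected_calibration_error(probs, labels, n_bins=10):
    """Weighted |accuracy - confidence| gap over equal-width confidence bins."""
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    correct = (pred == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (conf > lo) & (conf <= hi)
        if mask.any():
            # Bin weight is its share of examples; the gap is |acc - conf|.
            ece += mask.mean() * abs(correct[mask].mean() - conf[mask].mean())
    return ece
```

A perfectly confident, perfectly correct classifier scores zero on both metrics; a single wrong prediction with confidence 0.6 contributes an ECE gap of 0.6.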
C.2 Training Time
Table 4 compares the run time of all methods on the CIFAR-MTL and QM9 datasets. We report the average processing time of a batch over epochs. For the comparison, we use the best hyper-parameter configuration (in terms of performance) from the CIFAR-MTL experiments. For MGDA and IMTL-G we present the run time under two settings: (1) aggregating the full gradients w.r.t. the shared parameters (top block); and (2) aggregating the gradients w.r.t. the hidden layer (bottom block), as BayesAgg-MTL does. For BayesAgg-MTL we do not include the pre-training steps in the time measurements. From the table, methods that do not rely on gradients for weighing the tasks are faster, as outlined in previous studies (Xin et al., 2022; Kurin et al., 2022); however, this often comes at a significant cost in performance. BayesAgg-MTL is almost as fast as those methods on regression problems, where everything is done in closed form, and slower on classification problems, partly due to the sampling process. Nevertheless, it is substantially faster than other gradient-balancing methods that use gradients w.r.t. the shared parameters.
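Measuring the average per-batch processing time can be sketched as follows; `step_fn` is a placeholder for one optimization step, and device synchronization (needed for accurate GPU timing) is omitted for brevity.

```python
import time

def mean_batch_time(step_fn, batches):
    """Average wall-clock time of one optimization step over the given batches.

    For GPU workloads one would synchronize the device before each clock
    read; this CPU-only sketch skips that.
    """
    times = []
    for batch in batches:
        t0 = time.perf_counter()
        step_fn(batch)
        times.append(time.perf_counter() - t0)
    return sum(times) / len(times)
```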
C.3 Comparison to Bayesian Training
[Table 5 (results on QM9 and UTKFace): Ensemble ( heads), Ensemble ( networks), BayesAgg-MTL (Ours). Numerical values not recovered.]
Given that we use a Bayesian inference procedure in our approach, a natural question is how standard approximate Bayesian training performs in MTL.
Recall that the goal of this paper is to use Bayesian inference on the last layer as a means of training deterministic MTL models using uncertainty estimates of the task gradients. We use these uncertainty estimates to derive an aggregation rule that combines the task gradients into a shared update direction. More precisely, our aim is to learn a deterministic MTL model better while keeping the computational overhead of training it as small as possible. In standard approximate Bayesian training, the gradient used in the backward pass is treated as a deterministic quantity, just as in non-Bayesian training. Hence, even when applying standard Bayesian inference to the task-specific parameters, the optimization issue of how to combine the task gradients effectively remains.
To demonstrate this, we compare BayesAgg-MTL to deep ensembles (Lakshminarayanan et al., 2017), which have a strong link to approximate Bayesian methods (Wilson & Izmailov, 2020; D'Angelo & Fortuin, 2021; Wild et al., 2024). We chose deep ensembles for their simplicity and predictive abilities. We show results on QM9 and UTKFace for two baselines: (1) using heads for each task and a shared backbone; (2) using networks, each with its own backbone and task heads. The latter is substantially more computationally demanding, as it also requires copies of the (usually large) backbone. We combine the tasks using linear scalarization (i.e., equal weighting of the tasks) and average over the ensemble members. We follow the same experimental protocol as in the paper and report the values for each method in Table 5. From the table, the ensemble model with the shared backbone performs roughly the same as standard linear scalarization, with a slight advantage on QM9. This result makes sense, as the uncertainty information is not taken into account when aggregating the gradients (i.e., only the mean values are used). Full ensemble training improves upon the shared-backbone ensemble, but comes with a substantial computational overhead. Finally, BayesAgg-MTL substantially outperforms both methods on both datasets.
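Baseline (1), an ensemble of task heads over a shared backbone whose member predictions are averaged, can be sketched as follows; the dimensions, member count, and two-task setup are illustrative assumptions, not the paper's configuration.

```python
import torch
import torch.nn as nn

class SharedBackboneEnsemble(nn.Module):
    """One shared feature extractor; an ensemble of heads per task.

    Feature sizes and the two-task regression setup are placeholders.
    """
    def __init__(self, in_dim=8, feat_dim=32, num_tasks=2, num_members=5):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, feat_dim), nn.ELU())
        self.heads = nn.ModuleList([
            nn.ModuleList([nn.Linear(feat_dim, 1) for _ in range(num_members)])
            for _ in range(num_tasks)
        ])

    def forward(self, x):
        z = self.backbone(x)
        # Average the member predictions per task; per-task losses are then
        # combined downstream with equal weights (linear scalarization).
        return [torch.stack([h(z) for h in heads], 0).mean(0)
                for heads in self.heads]
```

As the surrounding text notes, the gradient of this averaged prediction is still a single deterministic quantity per task, so the aggregation problem BayesAgg-MTL addresses remains.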
C.4 Full Results
In Tables 6 and 7 we present the per-task results of all methods on QM9 and ChestX-ray14, respectively. On QM9 we report the mean absolute error of each task, and on ChestX-ray14 the AUC-ROC of each task. Due to lack of space, we abbreviated several disease names from ChestX-ray14; the full names are: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, Fibrosis, Hernia, Infiltration, Mass, Nodule, Pleural_Thickening, Pneumonia, Pneumothorax.