Bayesian Uncertainty for Gradient Aggregation in Multi-Task Learning (2024)

Idan Achituve  Idit Diamant  Arnon Netzer  Gal Chechik  Ethan Fetaya

Abstract

As machine learning becomes more prominent, there is a growing demand to perform several inference tasks in parallel. Running a dedicated model for each task is computationally expensive, and therefore there is great interest in multi-task learning (MTL). MTL aims at learning a single model that solves several tasks efficiently. Optimizing MTL models is often achieved by first computing a single gradient per task and then aggregating the gradients to obtain a combined update direction. However, this approach does not consider an important aspect: the sensitivity in the gradient dimensions. Here, we introduce a novel gradient aggregation approach using Bayesian inference. We place a probability distribution over the task-specific parameters, which in turn induces a distribution over the gradients of the tasks. This additional information allows us to quantify the uncertainty in each of the gradients' dimensions, which can then be factored in when aggregating them. We empirically demonstrate the benefits of our approach on a variety of datasets, achieving state-of-the-art performance.

Machine Learning, ICML

1 Introduction

In many application domains, there is a need to perform several machine learning inference tasks simultaneously. For instance, an autonomous vehicle needs to identify and detect objects in its vicinity, perform lane detection, track the movements of other vehicles over time, and predict free space around it, all in parallel and in real-time. In deep Multi-Task Learning (MTL) the goal is to train a single neural network (NN) to solve several tasks simultaneously, thus avoiding the need to have one dedicated model for each task (Caruana, 1997). Besides reducing the computational demand at test time, MTL also has the potential to improve generalization (Baxter, 2000). It is therefore not surprising that applications of MTL are taking central roles in various domains, such as vision (Achituve et al., 2021a; Shamshad et al., 2023; Zheng et al., 2023), natural language processing (Liu et al., 2019b; Zhou et al., 2023), speech (Michelsanti et al., 2021), robotics (Devin et al., 2017; Shu et al., 2018), and general scientific problems (Wu et al., 2018), to name a few.

However, optimizing multiple tasks simultaneously is a challenging problem that may lead to degradation in performance compared to learning them individually (Standley et al., 2020; Yu et al., 2020). To address this issue, one basic formula that many MTL optimization algorithms follow is to first calculate the gradient of each task's loss, and then aggregate these gradients according to some specified scheme. For example, several studies focus on reducing conflicts between the gradients before averaging them (Yu et al., 2020; Wang et al., 2020), others find a convex combination with minimal norm (Sener & Koltun, 2018; Désidéri, 2012), and some use a game-theoretic approach (Navon et al., 2022). However, by relying only on the gradient, these methods miss an important aspect: the sensitivity of the gradient in each dimension.

Our approach builds on the following observation: for each task, there may be many "good" parameter configurations. Standard MTL optimization methods take only a single value into account, and as such lose information in the aggregation step. Hence, tracking all of the parameter configurations will yield a richer description of the gradient space that can be advantageous when finding an update direction. Specifically, to account for all parameter values, we propose to place a probability distribution over the task-specific parameters, which in turn induces a probability distribution over the gradients. As a result, we obtain uncertainty estimates for the gradients that reflect the sensitivity in each of their dimensions. High-uncertainty dimensions are more lenient to changes, while dimensions with lower uncertainty are stricter (see illustration in Figure 2).

To obtain a probability distribution over the task-specific parameters we take a Bayesian approach. According to the Bayesian view, a posterior distribution over parameters of interest can be derived through Bayes' rule. In MTL, it is common to use a shared feature extractor network with linear task-specific layers (Ruder, 2017). Hence, if we assume a Bayesian model over the last task-specific layer weights during the back-propagation process, we obtain the posterior distributions over them. The posterior is then used to compute a Gaussian distribution over the gradients by means of moment matching. Then, to derive an update direction for the shared network, we design a novel aggregation scheme that considers the full distributions of the gradients. We name our method BayesAgg-MTL. An important implication of our approach is that BayesAgg-MTL assigns weights to the gradients at a higher resolution compared to existing methods, allocating a specific weight for each dimension and datum in the batch. We demonstrate our method's effectiveness over baseline methods on the MTL benchmarks QM9 (Ramakrishnan et al., 2014), CIFAR-100 (Krizhevsky et al., 2009), ChestX-ray14 (Wang et al., 2017), and UTKFace (Zhang et al., 2017).

In summary, this paper makes the following novel contributions: (1) The first Bayesian formulation of gradient aggregation for Multi-Task Learning. (2) A novel posterior approximation based on a second-order Taylor expansion. (3) A new MTL optimization algorithm based on our posterior estimation. (4) New state-of-the-art results on several MTL benchmarks compared to leading methods. Our code is publicly available at https://github.com/ssi-research/BayesAgg_MTL.

2 Background

Notations. Scalars, vectors, and matrices are denoted with lower-case letters (e.g., $x$), bold lower-case letters (e.g., $\mathbf{x}$), and bold upper-case letters (e.g., $\mathbf{X}$), respectively. All vectors are treated as column vectors. Training samples are tuples consisting of features shared across all tasks and the labels of $K$ tasks, namely $(\mathbf{x}, \{\mathbf{y}^k\}_{k=1}^{K}) \sim \mathcal{D}$, where $\mathcal{D}$ denotes the training set. We denote the dimensionality of the input and of the output of task $k$ by $d_{\mathbf{x}}$ and $o_k$, respectively.

In this study, we focus on common NN architectures for MTL that have a shared feature extractor and linear task-specific heads (Kendall et al., 2018; Sener & Koltun, 2018). The model parameters are denoted by $\{\bm{\theta}, \{\mathbf{w}^k\}_{k=1}^{K}\}$, where $\bm{\theta} \in \mathbb{R}^{d_{\bm{\theta}}}$ is the vector of shared parameters and $\{\mathbf{w}^k\}_{k=1}^{K}$ are the task-specific parameter vectors, each lying in $\mathbb{R}^{d_k}$. The last shared feature representation is denoted by the vector $\mathbf{h}(\mathbf{x}; \bm{\theta}) \in \mathbb{R}^{d_{\mathbf{h}}}$. Hence, the output of the network for task $k$ can be written as $\mathbf{f}^k(\mathbf{h}(\mathbf{x}; \bm{\theta}); \mathbf{w}^k)$. The loss of task $k \in \{1, \dots, K\}$ is denoted by $\ell^k(\mathbf{x}, \mathbf{y}; \{\bm{\theta}, \mathbf{w}^k\})$. The gradient of the loss $\ell^k$ w.r.t. $\mathbf{h}(\mathbf{x}; \bm{\theta})$ is $\mathbf{g}^k \coloneqq \frac{\partial \ell^k}{\partial \mathbf{h}(\mathbf{x}; \bm{\theta})}(\mathbf{x}, \mathbf{y}; \{\bm{\theta}, \mathbf{w}^k\}) \in \mathbb{R}^{d_{\mathbf{h}}}$. For clarity of exposition, function dependence on input variables will be omitted from now on.


2.1 Multi-Task Learning

A prevailing approach to optimizing MTL models goes as follows. First, the gradient of each task loss is computed. Second, an aggregation rule is applied to combine the gradients according to some algorithm. Lastly, an update step is performed using the outcome of the aggregation step. Commonly, the aggregation rule operates on the gradients of the loss w.r.t. all parameters, or only the shared parameters (e.g., Yu et al., 2020; Navon et al., 2022; Shamsian et al., 2023). Alternatively, to avoid a costly full back-propagation process for each task, some methods suggest applying it on the last shared representation (e.g., Sener & Koltun, 2018; Liu et al., 2020; Senushkin et al., 2023). Here, to make our method fast and scalable, we take the latter approach and note that it could be extended to full gradient aggregation.
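The three-step recipe, applied at the last shared representation, can be sketched as follows. This is a minimal NumPy illustration and not the authors' implementation: the one-layer tanh feature extractor, the squared loss, the function names, and the plain averaging rule are all assumptions chosen for brevity.

```python
import numpy as np

def task_gradients_wrt_h(h, W, y):
    """Step 1: per-task gradients of the squared loss w.r.t. the shared features h.

    h: (d_h,) shared features; W: (K, d_h) linear heads; y: (K,) labels.
    Returns an array of shape (K, d_h), one gradient per task.
    """
    residual = W @ h - y                      # (K,) prediction errors
    return 2.0 * residual[:, None] * W

def aggregate_mean(grads):
    """Step 2: placeholder aggregation rule, a plain average over tasks (assumption)."""
    return grads.mean(axis=0)

def shared_update_direction(x, theta, W, y):
    """Steps 1-3 for a toy extractor h = tanh(theta @ x) (hypothetical architecture)."""
    h = np.tanh(theta @ x)                    # forward pass through the shared part
    g = aggregate_mean(task_gradients_wrt_h(h, W, y))
    # Step 3: one backward pass through the shared part via the chain rule.
    dh_dpre = 1.0 - h ** 2                    # derivative of tanh
    return np.outer(g * dh_dpre, x)           # gradient w.r.t. theta
```

Only one backward pass through the shared extractor is needed, regardless of the number of tasks, which is the cost advantage of aggregating at the last shared representation.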

2.2 Bayesian Inference

We wish to incorporate uncertainty estimates for the gradients into the aggregation procedure. Doing so will allow us to find an update direction that takes into account the importance of each gradient dimension for each task. A natural choice to model uncertainty is using Bayesian inference. Since we would like to get uncertainty estimates w.r.t. the last shared hidden layer, we treat only the last task-specific layer as a Bayesian layer. This "Bayesian last layer" approach is a common way to scale Bayesian inference to deep neural networks (Snoek et al., 2015; Calandra et al., 2016; Wilson et al., 2016a; Achituve et al., 2021c). We will now present some of the main concepts of Bayesian modeling that will be used as part of our method.

For simplicity, assume a single output variable; we also drop the task index for clarity. According to the Bayesian paradigm, instead of treating the parameters $\mathbf{w}$ as deterministic values that need to be optimized, they are treated as random variables, i.e., there is a distribution over the parameters. The posterior distribution for $\mathbf{w}$, after observing the data, is given by Bayes' rule as

$$\log p(\mathbf{w} \mid \mathcal{D}) \propto \log p(\mathbf{y} \mid \mathbf{X}, \mathbf{w}) + \log p(\mathbf{w}). \tag{1}$$

Predictions in Bayesian inference are given by taking the expected prediction with respect to the posterior distribution. In general, the Bayesian inference procedure for $\mathbf{w}$ is intractable. However, for some specific scenarios, there exists an analytic solution. For example, in linear regression, if we assume a Gaussian likelihood with a fixed scalar noise $\tau$ that is independent across observations, $p(\mathbf{y} \mid \{\mathbf{x}_i\}_{i=1}^{|\mathcal{D}|}, \mathbf{w}) = \prod_{i=1}^{|\mathcal{D}|} \mathcal{N}(y_i \mid \mathbf{x}_i^T \mathbf{w}, \tau^2)$, and a Gaussian prior $p(\mathbf{w}) = \mathcal{N}(\mathbf{w} \mid \mathbf{m}_p, \mathbf{S}_p)$, then

$$\begin{aligned}
p(\mathbf{w} \mid \mathcal{D}) &= \mathcal{N}(\mathbf{w} \mid \mathbf{m}, \mathbf{S}), \\
\mathbf{m} &= \mathbf{S}\big((\mathbf{S}_p)^{-1}\mathbf{m}_p + \tau^{-2}\mathbf{X}\mathbf{y}\big), \\
\mathbf{S} &= \big((\mathbf{S}_p)^{-1} + \tau^{-2}\mathbf{X}\mathbf{X}^T\big)^{-1}.
\end{aligned} \tag{2}$$

Here $\mathbf{X} \in \mathbb{R}^{d_{\mathbf{x}} \times |\mathcal{D}|}$ is the matrix that results from stacking the vectors $\{\mathbf{x}_i\}_{i=1}^{|\mathcal{D}|}$ as columns. Similarly, we denote by $\mathbf{H} \in \mathbb{R}^{d_{\mathbf{h}} \times |\mathcal{D}|}$ the matrix that results from stacking the vectors of hidden representations. In the specific case of deep NNs with a Bayesian last layer we get the same inference result, only with $\mathbf{H}$ replacing $\mathbf{X}$. Going beyond a single output variable entails defining a covariance matrix for the noise model. However, in this study we assume independence between the output variables in these cases.
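As an illustration of Eq. 2, the following NumPy sketch computes the posterior mean and covariance for a design matrix whose columns are the inputs (or hidden features). The function name and argument layout are assumptions made for this example.

```python
import numpy as np

def gaussian_posterior(X, y, m_p, S_p, tau=1.0):
    """Posterior N(w | m, S) for Bayesian linear regression, following Eq. 2.

    X: (d, n) matrix whose columns are inputs (or hidden features H).
    y: (n,) targets; m_p: (d,) prior mean; S_p: (d, d) prior covariance.
    """
    S_p_inv = np.linalg.inv(S_p)
    precision = S_p_inv + (X @ X.T) / tau**2   # posterior precision matrix
    S = np.linalg.inv(precision)
    m = S @ (S_p_inv @ m_p + (X @ y) / tau**2)
    return m, S
```

With a very broad prior and noiseless targets, the posterior mean approaches the least-squares solution, which is a convenient sanity check. The matrices involved are only of size $d \times d$, so for a last layer with a modest hidden dimension the inversion is cheap.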

Unlike regression, in classification the likelihood is not a Gaussian, and the posterior can only be approximated. The common choice is to use variational inference (Wilson et al., 2016b; Achituve et al., 2021b, 2023), although there are other alternatives as well (Kristiadi et al., 2020).

3 Method

We start with an outline of the problem and our approach. Consider a deep network for multi-task learning that has a shared feature extractor part and task-specific linear layers. We propose to use Bayesian inference on the last layer as a means to train deterministic MTL models. For each task $k$, we define a Bayesian probabilistic model representing the uncertainty over the linear weights of the last, task-specific layer $\mathbf{w}^k$. The distribution over weights induces a distribution over gradients of the loss with respect to the last shared hidden layer. Given these per-task distributions on a joint space, we propose an aggregation rule for combining the gradients of the tasks into a shared update direction that takes into account the uncertainty in the gradients (see illustration in Figure 1). Then, the back-propagation process can proceed as usual.

Since regression and classification setups yield different inference procedures according to our approach, albeit having the same general framework, we discuss the two setups separately, starting with regression.

3.1 BayesAgg-MTL for Regression Tasks

Consider a standard squared loss for task $k$, $\ell^k = (y^k - \hat{y}^k)^2$, between the label $y^k$ and the network output $\hat{y}^k$. Given a random batch of examples $\mathcal{B} \sim \mathcal{D}$, the gradient of the loss with respect to the hidden layer $\mathbf{h}$ for the $i^{th}$ example is

$$\mathbf{g}^k_i = \frac{\partial \ell^k_i}{\partial \hat{y}^k_i}\frac{\partial \hat{y}^k_i}{\partial \mathbf{h}_i} = 2\mathbf{w}^k(\mathbf{h}_i^T\mathbf{w}^k - y^k_i). \tag{3}$$

Our main observation is that $\mathbf{g}^k_i$ is a function of $\mathbf{w}^k$. Hence, if we view $\mathbf{w}^k$ in the back-propagation process as a random variable, then $\mathbf{g}^k_i$ will be a random variable as well. This view will allow us to capture the uncertainty in the task gradient. Since the dimension of the hidden layer is usually small compared to the dimension of all shared parameters, operations in this space, such as matrix inversion, should not be costly.
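To make the observation concrete, one can draw weight vectors from a Gaussian and push each sample through Eq. 3; the spread of the resulting gradients per dimension is the uncertainty the method aims to capture. The sketch below is illustrative only: the function names are hypothetical and the Gaussian parameters are placeholders for a posterior.

```python
import numpy as np

def grad_wrt_hidden(w, h, y):
    """Eq. 3: gradient of (y - h^T w)^2 with respect to h, for a single example."""
    return 2.0 * w * (h @ w - y)

def sample_gradients(m, S, h, y, n_samples=1000, seed=0):
    """Draw w ~ N(m, S) and return the induced gradients, shape (n_samples, d_h)."""
    rng = np.random.default_rng(seed)
    ws = rng.multivariate_normal(m, S, size=n_samples)
    # Vectorized Eq. 3: one residual per sampled weight vector.
    return 2.0 * ws * (ws @ h - y)[:, None]
```

Taking the empirical mean and covariance of the samples returned by `sample_gradients` gives a Monte Carlo estimate of the gradient distribution's first two moments.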

If we fix all the shared parameters, then the posterior over $\mathbf{w}^k$ has a Gaussian distribution with known parameters via Eq. 2. As $\mathbf{g}^k_i$ is quadratic in $\mathbf{w}^k$, it has a generalized chi-squared distribution (Davies, 1973). However, since this distribution does not admit a closed-form density function, and since the gradient aggregation needs to be efficient as we run it at each iteration, we approximate $\mathbf{g}^k_i$ as a Gaussian distribution. The optimal choice for the parameters of this Gaussian is given by matching its first two moments to those of the true density, as these parameters minimize the Kullback–Leibler divergence between the two distributions (Minka, 2001). Luckily, in the regression case, we can derive the first two moments from the posterior over $\mathbf{w}^k$,

$$\mathbb{E}[\mathbf{g}^k_i] = 2\big[\mathbf{S}^k\mathbf{h}_i + \mathbf{m}^k(\mathbf{h}_i^T\mathbf{m}^k - y^k_i)\big], \tag{4}$$

$$\begin{aligned}
\mathbb{E}[\mathbf{g}^k_i(\mathbf{g}^k_i)^T] = 4\Big[ &(y^k_i)^2(\mathbf{S}^k + \mathbf{M}^k) \\
&- 2y^k_i\big(\mathbf{m}^k\mathbf{h}_i^T(\mathbf{S}^k + \mathbf{M}^k) + (\mathbf{S}^k + \mathbf{M}^k)\mathbf{h}_i(\mathbf{m}^k)^T + \mathbf{h}_i^T\mathbf{m}^k(\mathbf{S}^k - \mathbf{M}^k)\big) \\
&+ (\mathbf{S}^k + \mathbf{M}^k)(\mathbf{A}_i + \mathbf{A}_i^T)(\mathbf{S}^k + \mathbf{M}^k) \\
&+ \mathrm{Tr}(\mathbf{A}_i\mathbf{S}^k)(\mathbf{S}^k + \mathbf{M}^k) + (\mathbf{m}^k)^T\mathbf{A}_i\mathbf{m}^k(\mathbf{S}^k - \mathbf{M}^k)\Big],
\end{aligned}$$

where $\mathbf{A}_i = \mathbf{h}_i\mathbf{h}_i^T$, $\mathbf{M}^k = \mathbf{m}^k(\mathbf{m}^k)^T$, we assumed $\tau = 1$, and $\mathrm{Tr}(\cdot)$ is the matrix trace. We emphasize that this approximation is for the gradient of a single data point and a single task, not for the gradient of the task with respect to the entire batch. The full derivation is presented in Appendix A.1.
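The two moment expressions translate directly into matrix code. The following NumPy sketch is one possible transcription for a single example and a single task with $\tau = 1$; the function names are hypothetical, and it is meant for checking the formulas rather than as an efficient batched implementation.

```python
import numpy as np

def gradient_moments(m, S, h, y):
    """First and second moments of g = 2 w (h^T w - y) for w ~ N(m, S), tau = 1."""
    M = np.outer(m, m)
    A = np.outer(h, h)
    P = S + M                                  # E[w w^T]
    hm = h @ m                                 # scalar h^T m
    mean = 2.0 * (S @ h + m * (hm - y))        # Eq. 4
    # E[w w^T (h^T w)]: the term multiplied by -2y in the second moment.
    third = np.outer(m, h) @ P + P @ np.outer(h, m) + hm * (S - M)
    # E[w w^T (w^T A w)]: the remaining terms of the second moment.
    fourth = P @ (A + A.T) @ P + np.trace(A @ S) * P + (m @ A @ m) * (S - M)
    second = 4.0 * (y**2 * P - 2.0 * y * third + fourth)
    return mean, second

def gradient_covariance(m, S, h, y):
    """Covariance of the moment-matched Gaussian: E[g g^T] - E[g] E[g]^T."""
    mean, second = gradient_moments(m, S, h, y)
    return second - np.outer(mean, mean)
```

In one dimension the result can be checked by hand against the raw moments of a scalar Gaussian. For example, with $m = 1$, $S = 0.5$, $h = 2$, and $y = 1$, the function returns a mean of 4 and a second moment of 42.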

Several points deserve attention here. First, note the similarity between the solution of the first moment and the gradient obtained via standard back-propagation. The two differences are that the last-layer parameters, $\mathbf{w}^k$, are replaced with the posterior mean, $\mathbf{m}^k$, and that an uncertainty term is added. In the extreme case of $\mathbf{S}^k \rightarrow 0$ and $\mathbf{m}^k \rightarrow \mathbf{w}^k$, the mean coincides with the gradient of standard back-propagation. Second, in the case of a multi-output task, following our independence assumption between output variables, we can obtain the moments for each output dimension separately using the same procedure, so de facto we treat each output as a different task. Finally, during training, the shared parameters are constantly being updated. Hence, computing the exact posterior distribution for $\mathbf{w}^k$ would require iterating over the entire dataset at each update step, which can make our method computationally expensive in practice. Therefore, we use only the current batch data to approximate the posterior over $\mathbf{w}^k$, and introduce information about the full dataset through the prior, as described next.
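The first point can be seen from a short derivation of the first moment. The steps below are a sketch that uses only Eq. 3 and the Gaussian identity $\mathbb{E}[\mathbf{w}\mathbf{w}^T] = \mathbf{S} + \mathbf{m}\mathbf{m}^T$; the complete derivation, including the second moment, is in the paper's appendix.

```latex
\begin{aligned}
\mathbb{E}[\mathbf{g}^k_i]
  &= 2\,\mathbb{E}\big[\mathbf{w}^k(\mathbf{w}^k)^T\big]\mathbf{h}_i
     - 2\,y^k_i\,\mathbb{E}[\mathbf{w}^k] \\
  &= 2\,\big(\mathbf{S}^k + \mathbf{m}^k(\mathbf{m}^k)^T\big)\mathbf{h}_i
     - 2\,y^k_i\,\mathbf{m}^k \\
  &= 2\,\big[\mathbf{S}^k\mathbf{h}_i
     + \mathbf{m}^k(\mathbf{h}_i^T\mathbf{m}^k - y^k_i)\big].
\end{aligned}
```

Setting $\mathbf{S}^k = 0$ and $\mathbf{m}^k = \mathbf{w}^k$ in the last line recovers the deterministic gradient $2\mathbf{w}^k(\mathbf{h}_i^T\mathbf{w}^k - y^k_i)$ of Eq. 3, so the term $2\mathbf{S}^k\mathbf{h}_i$ is exactly the correction contributed by the weight uncertainty.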

Prior selection. A common choice in Bayesian deep learning is to choose uninformative priors, such as a standard Gaussian, to let the data be the main influence on the posterior (Wilson & Izmailov, 2020; Fortuin et al., 2021). However, in our case, we found this prior to be too weak. Since the posterior depends only on a single batch, we opted to introduce information about the whole dataset through the prior. A natural choice is to use the posterior distribution of the previous batch as our prior (Särkkä, 2013, Chapter 3). However, this method did not work well in our experiments and we developed an alternative. During each epoch, we collect the feature representations and labels of all examples in the dataset. At the end of the epoch, we compute the posterior based on the full data (with an isotropic Gaussian prior) and use this posterior as the prior at each step in the subsequent epoch. Updating the full-data prior more frequently is likely to benefit the overall model; however, it would probably also lengthen training. Hence, we update once per epoch as a compromise between performance and training time.
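The per-epoch prior scheme can be sketched as a small bookkeeping class. This is an assumed structure written for illustration, not the authors' code: the class and method names are hypothetical, a single scalar-output task is assumed, and features are stored as columns to match the $\mathbf{H}$ convention above.

```python
import numpy as np

def posterior(H, y, m_p, S_p, tau=1.0):
    """Gaussian posterior over last-layer weights (Eq. 2 with H in place of X)."""
    S_p_inv = np.linalg.inv(S_p)
    S = np.linalg.inv(S_p_inv + (H @ H.T) / tau**2)
    m = S @ (S_p_inv @ m_p + (H @ y) / tau**2)
    return m, S

class EpochPrior:
    """Keeps the previous epoch's full-data posterior and uses it as the prior."""

    def __init__(self, d_h, prior_var=1.0):
        self.iso = (np.zeros(d_h), prior_var * np.eye(d_h))  # isotropic Gaussian
        self.prior = self.iso                                # used in the first epoch
        self._H, self._y = [], []

    def batch_posterior(self, H_batch, y_batch):
        # Record this batch's features and labels, then condition on the batch only.
        self._H.append(H_batch)
        self._y.append(y_batch)
        return posterior(H_batch, y_batch, *self.prior)

    def end_epoch(self):
        # The full-data posterior under the isotropic prior becomes the next prior.
        H = np.concatenate(self._H, axis=1)
        y = np.concatenate(self._y)
        self.prior = posterior(H, y, *self.iso)
        self._H, self._y = [], []
```

In a training loop, `batch_posterior` would be called once per step with the current hidden features and `end_epoch` once after the last batch. Storing the sufficient statistics $\mathbf{H}\mathbf{H}^T$ and $\mathbf{H}\mathbf{y}$ instead of the raw features would reduce memory, at the cost of a slightly less literal reading of the text.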

Aggregation step. Having an approximation for the gradient distribution of each task, we need to combine them to find an update direction for the shared parameters. Denote the mean of the gradient of the loss for task $k$ w.r.t. the hidden layer for the $i^{th}$ example by $\bm{\mu}^k_i \coloneqq \mathbb{E}[\mathbf{g}^k_i]$, and similarly the covariance matrix $\bm{\Sigma}^k_i \coloneqq (\mathbf{\Lambda}^k_i)^{-1} \coloneqq \mathbb{E}[\mathbf{g}^k_i(\mathbf{g}^k_i)^T] - \mathbb{E}[\mathbf{g}^k_i]\mathbb{E}[\mathbf{g}^k_i]^T$. We strive to find an update direction for the last shared layer, $\mathbf{g}_i$, that lies in a high-density region for all tasks. Hence, we pick $\mathbf{g}_i$ that maximizes the following likelihood:

$$
\begin{aligned}
&\operatorname*{arg\,max}_{\mathbf{g}_i} \prod_{k=1}^{K} \mathcal{N}(\mathbf{g}_i | \bm{\mu}^k_i, \bm{\Sigma}^k_i) = \\
&\operatorname*{arg\,min}_{\mathbf{g}_i} -\sum_{k=1}^{K} \log \mathcal{N}(\mathbf{g}_i | \bm{\mu}^k_i, \bm{\Sigma}^k_i).
\end{aligned} \tag{5}
$$

Thankfully, the above optimization problem can be solved in closed-form, yielding the following solution:

$$
\mathbf{g}_i = \left(\sum_{k=1}^{K} \mathbf{\Lambda}^k_i\right)^{-1} \left(\sum_{k=1}^{K} \mathbf{\Lambda}^k_i \bm{\mu}^k_i\right). \tag{6}
$$
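For completeness, the closed form follows from a short calculation, sketched here under the assumption that every $\mathbf{\Lambda}^k_i$ is symmetric positive-definite. Dropping the terms that do not depend on $\mathbf{g}_i$, the objective in Eq. 5 is a sum of quadratics, and setting its gradient to zero gives the solution:

```latex
\begin{align*}
-\sum_{k=1}^{K}\log\mathcal{N}(\mathbf{g}_i | \bm{\mu}^k_i, \bm{\Sigma}^k_i)
  &= \frac{1}{2}\sum_{k=1}^{K}(\mathbf{g}_i-\bm{\mu}^k_i)^T\mathbf{\Lambda}^k_i(\mathbf{g}_i-\bm{\mu}^k_i) + \text{const},\\
\sum_{k=1}^{K}\mathbf{\Lambda}^k_i(\mathbf{g}_i-\bm{\mu}^k_i) = \mathbf{0}
  &\;\Longrightarrow\;
  \mathbf{g}_i=\Big(\sum_{k=1}^{K}\mathbf{\Lambda}^k_i\Big)^{-1}\Big(\sum_{k=1}^{K}\mathbf{\Lambda}^k_i\bm{\mu}^k_i\Big).
\end{align*}
```

The objective is convex because each $\mathbf{\Lambda}^k_i$ is positive-definite, so the stationary point is the unique minimizer.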

However, we found that modeling the full covariance matrix can be numerically unstable and sensitive to noise in the gradient. Instead, we assume independence between the dimensions of $\mathbf{g}_i^k$ for all tasks, which results in diagonal covariance matrices with variances $(\bm{\sigma}^k_i)^2 \coloneqq 1/\bm{\lambda}^k_i$. The update direction now becomes:

$$
\mathbf{g}_i = \sum_{k=1}^{K} \frac{1/(\bm{\sigma}^k_i)^2}{\sum_{k=1}^{K} 1/(\bm{\sigma}^k_i)^2} \bm{\mu}^k_i = \sum_{k=1}^{K} \overbrace{\frac{\bm{\lambda}^k_i}{\sum_{k=1}^{K} \bm{\lambda}^k_i}}^{\bm{\alpha}_i^k} \bm{\mu}^k_i, \tag{7}
$$

where the division and multiplication are done element-wise. In Eq. 7 we intentionally denote by $\bm{\alpha}_i^k$ the vector of uncertainty-based weights that our method assigns to the mean gradient, to highlight that the weights are unique per task, dimension, and datum. The final modification to the method involves down-scaling the impact of the precision by a hyper-parameter $s \in (0, 1]$, namely, we take $(\bm{\lambda}^k_i)^s$. Empirically, the scaling parameter helped to achieve better performance, perhaps due to misspecifications in the model (such as the diagonal Gaussian assumption over $\mathbf{g}^k_i$).
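The element-wise rule with the scaling exponent can be sketched in a few lines of plain Python. This is a minimal illustration for a single example, not the authors' code; the function name `aggregate_gradients` and the two-task numbers are hypothetical.

```python
def aggregate_gradients(mus, lambdas, s=1.0):
    """Element-wise precision-weighted average of per-task mean gradients.

    mus[k][d]     : mean gradient of task k in dimension d (one example).
    lambdas[k][d] : precision (1 / variance) of task k in dimension d.
    s             : scaling exponent in (0, 1] applied to the precisions.
    """
    num_dims = len(mus[0])
    aggregated = []
    for d in range(num_dims):
        scaled = [lam[d] ** s for lam in lambdas]      # scaled precisions
        total = sum(scaled)
        alphas = [w / total for w in scaled]           # weights sum to 1 per dimension
        aggregated.append(sum(a * mu[d] for a, mu in zip(alphas, mus)))
    return aggregated


# Two hypothetical tasks in two dimensions.
mus = [[2.0, 1.0], [1.0, 3.0]]
lambdas = [[1.0, 9.0], [9.0, 1.0]]

g_full = aggregate_gradients(mus, lambdas, s=1.0)   # approx. [1.1, 1.2]
g_half = aggregate_gradients(mus, lambdas, s=0.5)   # approx. [1.25, 1.5]
```

In each dimension the result is pulled toward the task with the higher precision there. Lowering $s$ flattens the weights: as $s$ approaches 0 they approach the uniform $1/K$, which in this toy case would give the plain average $[1.5, 2.0]$.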

Input: $\mathcal{B}$ - a random batch of examples; $p(\mathbf{w}^k|\mathcal{D})~\forall k \in [1, ..., K]$ - posterior distributions over the task-specific parameters; $s$ - scaling hyper-parameter
For $i = 1, ..., |\mathcal{B}|$:
    For $k = 1, ..., K$:
        • Compute $\mathbb{E}[\mathbf{g}^k_i]$ and $\mathbb{E}[\mathbf{g}^k_i(\mathbf{g}^k_i)^T]$ as in Eq. 4 for regression or Eq. 11 for classification.
        • Set (operations are done element-wise),
            $\bm{\mu}^k_i \coloneqq \mathbb{E}[\mathbf{g}^k_i]$,
            $\bm{\lambda}^k_i \coloneqq (\mathbb{E}[(\mathbf{g}^k_i)^2] - \mathbb{E}[\mathbf{g}^k_i]\mathbb{E}[\mathbf{g}^k_i])^{-1}$.
    End for
    Compute $\mathbf{g}_i = \sum_{k=1}^{K} \frac{(\bm{\lambda}^k_i)^s}{\sum_{k=1}^{K} (\bm{\lambda}^k_i)^s} \bm{\mu}^k_i$.
End for
Compute gradient via matrix multiplication w.r.t. the shared parameters: $\frac{1}{|\mathcal{B}|}\sum_{i=1}^{|\mathcal{B}|} \mathbf{g}_i \frac{\partial \mathbf{h}_i}{\partial \bm{\theta}}$.

With the aggregated gradient for each example, the back-propagation procedure proceeds as usual by averaging over all examples in the batch and then back-propagating the result to the shared parameters. To gain a better intuition about the update rule of BayesAgg-MTL, consider the illustration in Figure 2. In the figure, we plot the mean update direction of two tasks along with the uncertainty in them. The first task is more sensitive to shifts in the vertical dimension and less so to shifts in the horizontal dimension, while for the second task it is the opposite. By taking the variance information into account, BayesAgg-MTL can find an update direction that works well for both, compared to a simple average of the gradient means. We summarize our method in Algorithm 1.

Making predictions. Since we have a closed-form solution for the posterior of the task-specific parameters, BayesAgg-MTL does not learn this layer during training. Therefore, when making predictions we use the posterior mean, $\mathbf{m}^k$, computed on the full training set. We do so, instead of using full Bayesian inference, for a fair comparison with alternative MTL approaches and to have identical run-time and memory requirements when making predictions.

Connection to Nash-MTL. Navon et al. (2022) proposed a cooperative bargaining game approach to the gradient aggregation step, with the directional derivative as the utility of each player (task). They then proposed using the Nash bargaining solution, the direction that maximizes the product of all the utilities. One can consider Eq. 5 as the Nash bargaining solution with the utility of each task being its likelihood. However, unlike Navon et al. (2022), we get an analytical formula for the bargaining solution since the Gaussian exponent and the logarithm cancel out.


3.2 BayesAgg-MTL for Classification Tasks

We now turn to present our approach for classification tasks. When dealing with classification there are two sources of intractability that we need to overcome. The first is the posterior of $\mathbf{w}^k$, and the second is estimating the moments of $\mathbf{g}^k_i$. We describe our solution to both challenges next.

Posterior approximation. In classification tasks the likelihood is not Gaussian and, in general, we cannot compute the posterior in closed form. One common option is to approximate it using a Gaussian distribution and learn its parameters using a variational inference (VI) scheme (Saul et al., 1996; Neal & Hinton, 1998; Bishop, 2006). However, in our early experiments, we did not find it to work well without using a computationally expensive VI optimization at each update step. As an alternative to VI, the Laplace approximation (MacKay, 1992) approximates the posterior as a Gaussian using a second-order Taylor expansion. Since the expansion is done at the optimal parameter values that are learned point-wise, the Jacobian term in the expansion vanishes. Here, we follow a similar path; however, we cannot assume that the Jacobian is zero as we are not near a stationary point during most of the training. Nevertheless, we can still find a Gaussian approximation. A similar derivation was proposed in (Immer et al., 2021), yet they eventually ignored the first-order term. Denote by $\hat{\mathbf{w}}^k$ the learned point estimate for the task parameters, and $\Delta\mathbf{w}^k \coloneqq \mathbf{w}^k - \hat{\mathbf{w}}^k$. Then, at each step of the training, by using Bayes' rule we can obtain a posterior approximation for $\mathbf{w}^k$ using the following:

$$
\begin{aligned}
\log p(\mathbf{w}^k|\mathcal{B}) \approx{}& \log p(\hat{\mathbf{w}}^k|\mathcal{B}) + \\
& \left(-\frac{\partial \log p(\mathbf{y}^k|\mathbf{X},\mathbf{w}^k)}{\partial \mathbf{w}^k} - \frac{\partial \log p(\mathbf{w}^k)}{\partial \mathbf{w}^k}\right)^T \Delta\mathbf{w}^k + \\
& \frac{1}{2}(\Delta\mathbf{w}^k)^T \left(-\frac{\partial^2 \log p(\mathbf{y}^k|\mathbf{X},\mathbf{w}^k)}{\partial (\mathbf{w}^k)^2} - \frac{\partial^2 \log p(\mathbf{w}^k)}{\partial (\mathbf{w}^k)^2}\right) \Delta\mathbf{w}^k.
\end{aligned} \tag{8}
$$

The above takes the following form: $c^k + (\mathbf{a}^k)^T(\mathbf{w}^k - \hat{\mathbf{w}}^k) + \frac{1}{2}(\mathbf{w}^k - \hat{\mathbf{w}}^k)^T \mathbf{B}^k (\mathbf{w}^k - \hat{\mathbf{w}}^k)$, where $\mathbf{a}^k \in \mathbb{R}^{d_k}$, $\mathbf{B}^k \in \mathbb{R}^{d_k \times d_k}$, and $c^k \in \mathbb{R}$ are known constants. We stress here again that, since we apply Bayesian inference to the last layer parameters only, computing and inverting $\mathbf{B}^k$ typically does not incur a large computational overhead.

After rearranging and completing the square we obtain a quadratic form corresponding to the following Gaussian distribution (see full derivation in Appendix A.2):

$$
p(\mathbf{w}^k|\mathcal{B}) \approx \mathcal{N}(\mathbf{w}^k | \hat{\mathbf{w}}^k - (\mathbf{B}^k)^{-1}\mathbf{a}^k, (\mathbf{B}^k)^{-1}). \tag{9}
$$

Examining the above posterior reveals several insights. First, the posterior mean corresponds to the Newton method update step. Second, the covariance of this posterior is the same as that of the Laplace approximation. Third, the Laplace approximation is recovered at a stationary point, where the gradient of the loss w.r.t. the parameters approaches zero and the first-order term vanishes.
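The completing-the-square step can be sketched as follows. This is a reading aid rather than the derivation of Appendix A.2: it assumes $\mathbf{B}^k$ is symmetric positive-definite and treats the quadratic form as the negative log-density up to an additive constant.

```latex
\begin{align*}
(\mathbf{a}^k)^T\Delta\mathbf{w}^k + \frac{1}{2}(\Delta\mathbf{w}^k)^T\mathbf{B}^k\Delta\mathbf{w}^k
 &= \frac{1}{2}\big(\Delta\mathbf{w}^k + (\mathbf{B}^k)^{-1}\mathbf{a}^k\big)^T\mathbf{B}^k\big(\Delta\mathbf{w}^k + (\mathbf{B}^k)^{-1}\mathbf{a}^k\big)
   - \frac{1}{2}(\mathbf{a}^k)^T(\mathbf{B}^k)^{-1}\mathbf{a}^k .
\end{align*}
```

The last term does not depend on $\mathbf{w}^k$, so exponentiating the negated expression gives a density proportional to a Gaussian with mean $\hat{\mathbf{w}}^k - (\mathbf{B}^k)^{-1}\mathbf{a}^k$ and covariance $(\mathbf{B}^k)^{-1}$. Setting $\mathbf{a}^k = \mathbf{0}$ leaves the mean at $\hat{\mathbf{w}}^k$.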

One limitation of the approximation in Eq. 9 is that the Hessian will not be positive-definite in most cases. Therefore, we replace it with the generalized Gauss-Newton (GGN) matrix (Schraudolph, 2002; Martens & Sutskever, 2011; Daxberger et al., 2021):

$$
\tilde{\mathbf{B}}^k = \sum_{i=1}^{|\mathcal{B}|} (\mathbf{J}^k_i)^T \mathbf{H}^k_i \mathbf{J}^k_i + (\mathbf{S}_p^k)^{-1}. \tag{10}
$$

Here, $\mathbf{J}^k_i = \partial \mathbf{f}^k(\mathbf{x}_i; \mathbf{w}^k) / \partial \mathbf{w}^k \in \mathbb{R}^{o_k \times d_k}$ is the Jacobian of the model output for task $k$ w.r.t. the last layer parameters of that task, $\mathbf{H}^k_i = -\partial^2 \log p(\mathbf{y}^k_i|\mathbf{x}_i, \mathbf{w}^k) / \partial (\mathbf{f}^k(\mathbf{x}_i; \mathbf{w}^k))^2 \in \mathbb{R}^{o_k \times o_k}$ is the Hessian of the negative log-likelihood w.r.t. the model outputs of task $k$, and $\mathbf{S}_p^k$ is the covariance of the Gaussian prior for $\mathbf{w}^k$. As in the regression case, we use here an informative prior based on the posterior from the full dataset at each training step.
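To illustrate why the GGN matrix is well behaved, consider a softmax likelihood. This choice is an assumption made here for concreteness, since the text above does not fix the likelihood. Writing $\mathbf{p}_i$ for the softmax of the model outputs $\mathbf{f}^k(\mathbf{x}_i; \mathbf{w}^k)$, and assuming a positive-definite prior covariance $\mathbf{S}_p^k$, we have:

```latex
\begin{align*}
\mathbf{H}^k_i &= \operatorname{diag}(\mathbf{p}_i) - \mathbf{p}_i\mathbf{p}_i^T \succeq 0, \\
\mathbf{v}^T\tilde{\mathbf{B}}^k\mathbf{v}
 &= \sum_{i=1}^{|\mathcal{B}|}(\mathbf{J}^k_i\mathbf{v})^T\mathbf{H}^k_i(\mathbf{J}^k_i\mathbf{v})
   + \mathbf{v}^T(\mathbf{S}_p^k)^{-1}\mathbf{v} > 0
 \quad \text{for all } \mathbf{v}\neq\mathbf{0}.
\end{align*}
```

Each summand is non-negative because $\mathbf{H}^k_i$ is positive semi-definite, and the prior term is strictly positive, so $\tilde{\mathbf{B}}^k$ is positive-definite and its inverse is a valid covariance matrix.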

Moments estimation. Unlike the regression case, in classification $\mathbf{g}^k_i$ will depend on $\mathbf{w}^k$ through some non-linear function. Hence, obtaining the moments as in Eq. 4 in closed form is more challenging. However, since we are estimating the parameters of the last layer only, which in many cases are relatively low-dimensional, we can efficiently approximate these moments with Monte-Carlo sampling:

$$
\begin{aligned}
\mathbb{E}[\mathbf{g}_i^k] &\approx \frac{1}{J}\sum_{j=1}^{J} \mathbf{g}_i^k(\mathbf{w}^k_j), \\
\mathbb{E}[\mathbf{g}_i^k(\mathbf{g}_i^k)^T] &\approx \frac{1}{J}\sum_{j=1}^{J} \mathbf{g}_i^k(\mathbf{w}^k_j)\,\mathbf{g}_i^k(\mathbf{w}^k_j)^T.
\end{aligned} \tag{11}
$$

Here, $\mathbf{w}_j^k$ are samples from $p(\mathbf{w}^k|\mathcal{B})$, and $J$ is the total number of samples. Effectively, this means that we need to back-propagate gradients w.r.t. the shared hidden layer $J$ times; however, since the task-specific layers are linear, this can be done cheaply and in parallel. Having the moment estimates, we proceed with the aggregation rule as described in Section 3.1.
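The Monte-Carlo estimate of the diagonal moments can be sketched in plain Python. The snippet is a toy illustration under several assumptions that are not from the paper: a single binary-classification task with a linear last layer, a diagonal Gaussian posterior with made-up parameters, and a small `eps` guard against a zero variance estimate. The names `grad_wrt_hidden` and `diagonal_moments` are hypothetical.

```python
import math
import random


def grad_wrt_hidden(w, h, y):
    """Gradient of the binary cross-entropy w.r.t. h for the logit z = w . h."""
    z = sum(wi * hi for wi, hi in zip(w, h))
    p = 1.0 / (1.0 + math.exp(-z))
    return [(p - y) * wi for wi in w]


def diagonal_moments(sampled_grads, eps=1e-12):
    """Monte-Carlo mean and per-dimension precision of sampled gradients.

    sampled_grads[j][d] is the gradient in dimension d evaluated at the
    j-th sampled task-specific weight vector.
    """
    J = len(sampled_grads)
    num_dims = len(sampled_grads[0])
    mean = [sum(g[d] for g in sampled_grads) / J for d in range(num_dims)]
    second = [sum(g[d] ** 2 for g in sampled_grads) / J for d in range(num_dims)]
    precision = [1.0 / (second[d] - mean[d] ** 2 + eps) for d in range(num_dims)]
    return mean, precision


rng = random.Random(0)
post_mean, post_std = [0.5, -1.0], [0.1, 0.4]   # hypothetical diagonal posterior
h, y, J = [1.0, 2.0], 1, 512                    # hidden features, label, sample count

weight_samples = [[rng.gauss(m, sd) for m, sd in zip(post_mean, post_std)]
                  for _ in range(J)]
mean, precision = diagonal_moments([grad_wrt_hidden(w, h, y) for w in weight_samples])
```

The resulting `mean` and `precision` play the roles of $\bm{\mu}^k_i$ and $\bm{\lambda}^k_i$ in the aggregation rule. In practice the $J$ gradient evaluations would be batched rather than looped over.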

Making predictions. Unlike the regression case, here we learn the parameters of the last layer as part of the posterior approximation. Therefore, making predictions is done as usual with a forward-pass through the network.

4 Related Work

Multi-task learning is an active research area that attempts to jointly learn multiple tasks, commonly using a shared representation (Ruder, 2017; Navon et al., 2022; Liu et al., 2023; Elich et al., 2023; Shi et al., 2023; Yun & Cho, 2023). Learning a shared representation for multiple tasks poses some challenges. One challenge is learning an architecture that can express both task-shared and task-specific features. Another challenge is finding the optimal balance between the tasks so that the different tasks are learned with equal importance. One line of research in MTL introduces novel MTL-friendly architectures, such as task-specific modules (Misra et al., 2016), attention-based networks (Liu et al., 2019a), and an ensemble of single-task models (Dimitriadis et al., 2023). Yet, a more common line of research focuses on the MTL optimization process, trying to explain the difficulties in the process by, e.g., conflicting gradients (Wang et al., 2020) or plateaus in the loss landscape (Schaul et al., 2019). Our method belongs to the latter line, improving the MTL optimization process.

Different strategies were proposed to address the MTL optimization challenge, aiming to balance the training of the different tasks and resolve their conflicts. The methods can broadly be categorized into two groups, loss-based and gradient-based (Dai et al., 2023). Loss-based approaches attempt to allocate weights for the tasks based on some criteria related to the loss, such as the difficulty of the task (Guo et al., 2018), random weights (Lin et al., 2022), geometric mean of the task losses (Chennupati et al., 2019; Yun & Cho, 2023), and task uncertainty (Kendall et al., 2018). The last of these weighs the tasks using the uncertainty in the observations only. This is very different from our approach, which weighs each dimension of the task gradients based on full Bayesian information.


Gradient-based methods attempt to balance the tasks by using the gradient information directly (Chen et al., 2018, 2020; Javaloy & Valera, 2022; Liu et al., 2020; Navon et al., 2022; Fernando et al., 2023; Senushkin et al., 2023). For example, GradNorm (Chen et al., 2018) dynamically tunes the gradient magnitudes to prevent imbalances between the tasks during training. PCGrad (Yu et al., 2020) identifies gradient conflicts as the main optimization issue in MTL, and attempts to reduce the conflicts by projecting each gradient onto the normal plane of the other tasks' gradients. Nash-MTL (Navon et al., 2022) suggests treating MTL as a bargaining game to find Pareto optimal solutions. Several studies suggested adaptations of the multiple-gradient descent algorithm (MGDA) (Désidéri, 2012; Sener & Koltun, 2018), such as CAGrad (Liu et al., 2021) and MoCo (Fernando et al., 2023). As opposed to previous methods, our approach considers both the mean and the variance of the gradients to derive an update direction.

Lastly, some studies recently suggested performing model merging based on the uncertainty of the parameters (Matena & Raffel, 2022; Daheim et al., 2023). The goal there is usually to combine models for various tasks, such as model ensembling, federated learning, and robust fine-tuning. Unlike these methods, we assume a Bayesian model on the last layer only and propagate the uncertainty to the gradients for gradient aggregation.

Table 1: QM9 test results.

Method                 $\Delta_m\%$ ($\downarrow$)
LS                     $177.6 \pm 3.4$
SI                     $77.8 \pm 9.2$
RLW                    $203.8 \pm 3.4$
DWA                    $175.3 \pm 6.3$
UW                     $108.0 \pm 22.5$
MGDA                   $120.5 \pm 2.0$
PCGrad                 $125.7 \pm 10.3$
CAGrad                 $112.8 \pm 4.0$
IMTL-G                 $77.2 \pm 9.3$
Nash-MTL               $62.0 \pm 1.4$
IGBv2                  $67.7 \pm 8.1$
Aligned-MTL-UB         $71.0 \pm 9.6$
BayesAgg-MTL (Ours)    $\mathbf{53.2 \pm 7.1}$

5 Experiments

We evaluated BayesAgg-MTL on several MTL benchmarks differing in the number of tasks and their types. Unless specified otherwise, we report the average and standard deviation (std) of relevant metrics over 3 random seeds. In all datasets, we pre-allocated a validation set from the training set for hyper-parameter tuning and early stopping for all methods. Throughout our experiments, we used the Adam optimizer (Kingma & Ba, 2015), which was found to be effective for MTL due to partial loss-scale invariance (Elich et al., 2023). Full experimental details are given in Appendix B.

Compared methods. We compare BayesAgg-MTL with the following baseline methods: (1) Single Task Learning (STL), which learns each task independently under the same experimental setup as that of the MTL methods; (2) Linear Scalarization (LS), which assigns a uniform weight to all tasks, namely $\sum_{k=1}^{K}\ell^{k}$; (3) Scale-Invariant (SI) (Navon et al., 2022), which assigns a uniform weight to the log of all task losses, namely $\sum_{k=1}^{K}\log \ell^{k}$; (4) Random Loss Weighting (RLW) (Lin et al., 2022), which allocates random weights to the losses at each iteration; (5) Dynamic Weight Average (DWA) (Liu et al., 2019a), which allocates a weight based on the rate of change of the loss for each task; (6) Uncertainty Weighting (UW) (Kendall et al., 2018), which minimizes a scalar term corresponding to the aleatoric uncertainty for each task; (7) Multiple-Gradient Descent Algorithm (MGDA) (Désidéri, 2012; Sener & Koltun, 2018), which finds a minimum norm solution for a convex combination of the losses; (8) Projecting Conflicting Gradients (PCGrad) (Yu et al., 2020), which projects the gradient of each task onto the normal plane of the gradients of the tasks it is in conflict with; (9) Conflict-Averse Grad (CAGrad) (Liu et al., 2021), which searches for an update direction centered at the LS solution while minimizing conflicts in gradients; (10) Impartial MTL-Grad (IMTL-G) (Liu et al., 2020), which finds an update vector whose projection on each of the task gradients is equal; (11) Nash-MTL (Navon et al., 2022), which derives task weights based on the Nash bargaining solution; (12) Improvable Gap Balancing (IGBv2) (Dai et al., 2023), which uses a reinforcement learning procedure to balance the task losses; (13) Aligned-MTL-UB (Senushkin et al., 2023), which aligns the principal components of a gradient matrix.
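To make three of the baselines above concrete, the following is a minimal sketch in plain Python of the LS and SI objectives and of a single PCGrad projection step. The function names are hypothetical and chosen for illustration only. The sketch operates on lists of floats rather than network gradients, and it shows only one pairwise projection, not the full PCGrad procedure (which processes all task pairs and sums the results).

```python
import math


def ls_objective(losses):
    """Linear Scalarization: uniform weight on every task loss."""
    return sum(losses)


def si_objective(losses):
    """Scale-Invariant: uniform weight on the log of every task loss.

    Assumes every loss is strictly positive so that the log is defined.
    """
    return sum(math.log(loss) for loss in losses)


def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def pcgrad_project(g_i, g_j):
    """Project g_i onto the normal plane of g_j if the two conflict.

    Two gradients are treated as conflicting when their inner product
    is negative; otherwise g_i is returned unchanged.
    """
    inner = dot(g_i, g_j)
    if inner >= 0:
        return list(g_i)
    scale = inner / dot(g_j, g_j)
    return [a - scale * b for a, b in zip(g_i, g_j)]
```

Multiplying one task loss by a constant $c > 0$ shifts the SI objective by the constant $\log c$ and therefore leaves its gradient unchanged, whereas the LS objective and its gradient both change. After a projection triggered by a conflict, the returned vector is orthogonal to `g_j`.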

Evaluation metric. Unless specified otherwise, we report the $\Delta_m\%$ metric introduced in (Maninis et al., 2019). This metric measures the average relative difference between a method $m$ and the STL baseline according to some criterion of interest $M^k$. Namely, $\Delta_m = \frac{1}{K}\sum_{k=1}^{K}(-1)^{\delta_k}(M_m^k - M_s^k)/M_s^k$, where $M_m^k$ is the criterion value for task $k$ under method $m$, $M_s^k$ is the criterion value for task $k$ under the STL baseline, and $\delta_k \in \{0,1\}$. If $\delta_k = 0$ then a lower value for $M^k$ is better (e.g., task loss), and if $\delta_k = 1$ then a higher value for $M^k$ is preferred (e.g., task accuracy). Lower $\Delta_m\%$ indicates better performance.
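The metric can be computed directly from per-task criterion values. The following is a minimal sketch in plain Python. The function name and argument layout are assumptions made for illustration, and the result is multiplied by 100 on the assumption that the reported value is a percentage.

```python
def delta_m_percent(method_values, stl_values, higher_is_better):
    """Average relative difference (in percent) of a method vs. the STL baseline.

    method_values[k]    -- criterion value M_m^k of the method on task k
    stl_values[k]       -- criterion value M_s^k of the STL baseline on task k
    higher_is_better[k] -- True when delta_k = 1, False when delta_k = 0
    """
    if not (len(method_values) == len(stl_values) == len(higher_is_better)):
        raise ValueError("all inputs must have one entry per task")
    total = 0.0
    for m, s, hib in zip(method_values, stl_values, higher_is_better):
        sign = -1.0 if hib else 1.0  # (-1)^{delta_k}
        total += sign * (m - s) / s
    return 100.0 * total / len(method_values)


# Toy example with two tasks: a loss that drops from 1.0 to 0.9 and an
# accuracy that rises from 0.5 to 0.8. Both are improvements, so the
# result is negative (approximately -35).
example = delta_m_percent([0.9, 0.8], [1.0, 0.5], [False, True])
```

Because the sign flips for criteria where higher is better, an improvement over STL always contributes a negative term, which is why lower values indicate better performance.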

Pre-training stage. To obtain meaningful features for the Bayesian layer, it is common practice to apply a pre-training step using standard NN training for several epochs (Wilson et al., 2016a, b). We follow the same path here and apply an initial pre-training step using linear scalarization. We stress that in all experiments, the overall number of training steps for BayesAgg-MTL (including the pre-training) is the same as for all other methods.

Table 2: Test results on CIFAR-MTL and ChestX-ray14.

Method                 CIFAR (Acc.) [$\uparrow$]    CX-ray ($\Delta_m\%$) [$\downarrow$]
LS                     $56.96 \pm 0.06$             $-14.62 \pm 0.2$
SI                     $55.75 \pm 0.3$              $-10.94 \pm 0.4$
RLW                    $59.30 \pm 0.08$             $-11.69 \pm 0.1$
DWA                    $58.44 \pm 0.5$              $\mathbf{-14.79 \pm 0.07}$
UW                     $56.63 \pm 0.5$              $-13.95 \pm 0.2$
MGDA                   $\mathbf{59.74 \pm 0.07}$    $-14.44 \pm 0.4$
PCGrad                 $56.32 \pm 0.2$              $-13.43 \pm 0.5$
CAGrad                 $56.59 \pm 0.2$              $-14.49 \pm 0.1$
IMTL-G                 $57.09 \pm 0.3$              $-8.23 \pm 1.8$
Nash-MTL               $56.59 \pm 0.2$              $-13.23 \pm 0.5$
IGBv2                  $56.61 \pm 0.2$              $-2.82 \pm 0.6$
Aligned-MTL-UB         $56.57 \pm 0.7$              $-14.14 \pm 0.2$
BayesAgg-MTL (Ours)    $\mathbf{59.97 \pm 0.4}$     $\mathbf{-14.96 \pm 0.1}$

Table 3: UTKFace test results.

Method                 Age ($\times 10^{1}$) ($\downarrow$)    Gender ($\uparrow$)          Ethnicity ($\uparrow$)       $\Delta_m\%$ ($\downarrow$)
STL                    $1.40 \pm 0.03$                         $92.32 \pm 0.35$             $82.42 \pm 0.42$
LS                     $1.46 \pm 0.02$                         $92.92 \pm 0.24$             $83.98 \pm 0.43$             $0.69 \pm 0.59$
SI                     $1.42 \pm 0.03$                         $93.05 \pm 0.29$             $83.40 \pm 0.27$             $0.11 \pm 0.89$
RLW                    $1.44 \pm 0.03$                         $92.89 \pm 0.25$             $83.70 \pm 0.49$             $-0.31 \pm 0.76$
DWA                    $1.44 \pm 0.02$                         $92.90 \pm 0.16$             $83.55 \pm 0.33$             $0.35 \pm 0.60$
UW                     $1.43 \pm 0.00$                         $92.99 \pm 0.24$             $83.09 \pm 0.39$             $0.15 \pm 0.24$
MGDA                   $1.38 \pm 0.02$                         $\mathbf{93.29 \pm 0.31}$    $83.51 \pm 0.30$             $-1.39 \pm 0.50$
PCGrad                 $1.47 \pm 0.03$                         $92.92 \pm 0.28$             $83.28 \pm 0.38$             $1.13 \pm 0.57$
CAGrad                 $1.40 \pm 0.02$                         $93.06 \pm 0.26$             $83.28 \pm 0.46$             $-0.58 \pm 0.59$
IMTL-G                 $1.41 \pm 0.03$                         $93.10 \pm 0.16$             $83.78 \pm 0.47$             $-0.50 \pm 0.89$
Nash-MTL               $1.42 \pm 0.02$                         $92.89 \pm 0.10$             $83.19 \pm 0.50$             $-0.17 \pm 0.71$
IGBv2                  $1.42 \pm 0.02$                         $93.09 \pm 0.22$             $83.34 \pm 0.33$             $-0.21 \pm 0.50$
Aligned-MTL-UB         $1.45 \pm 0.02$                         $93.00 \pm 0.24$             $83.36 \pm 0.43$             $0.66 \pm 0.50$
BayesAgg-MTL (Ours)    $\mathbf{1.35 \pm 0.03}$                $93.01 \pm 0.17$             $\mathbf{84.25 \pm 0.35}$    $\mathbf{-2.23 \pm 0.76}$

5.1 BayesAgg-MTL for Regression

We first evaluated BayesAgg-MTL on an MTL problem with regression tasks only. We used the QM9 dataset, which contains $\sim$130,000 stable small organic molecules represented as graphs with node and edge features (Ramakrishnan et al., 2014; Wu et al., 2018). The goal here is to predict 11 chemical properties, such as geometric and energetic ones, that may vary in scale and difficulty. We follow the experimental protocol of Navon et al. (2022). Specifically, we allocate approximately 110,000 examples for training, with separate validation and test sets of 10,000 examples each. Additionally, we employ the message-passing neural network architecture (Gilmer et al., 2017) in conjunction with the pooling operator described in (Vinyals et al., 2016).

The test results for this dataset are presented in Table 1. Baseline method results were taken from (Dai et al., 2023), except for Aligned-MTL-UB, which is included here for the first time. The criterion used in $\Delta_m$ here is the mean absolute error (MAE). From the table, BayesAgg-MTL achieves the best test performance, with a significant improvement compared to most of the baseline methods.

To gain a better intuition into the weights that BayesAgg-MTL assigns, we restate here the vector of weights per example and task from Eq. 7, $\bm{\alpha}^k_i \coloneqq \bm{\lambda}^k_i / (\sum_{k=1}^{K}\bm{\lambda}^k_i)$. Figure 3 depicts, for all tasks, the average over dimensions of $\bm{\alpha}^k_i$ for 20 random examples at the start, middle, and end of training. The plot reveals an interesting pattern. Early in training, the average weights are distributed among the tasks without any specific pattern. In the middle of training, larger weights are assigned to tasks 4-10, while tasks 0-3 receive smaller weights. At the end of training, this pattern reverses, and tasks 0-3 are assigned larger weights than tasks 4-10.
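As an illustration of the normalization behind these weights and of the per-task average discussed for Figure 3, here is a minimal sketch in plain Python for a single example. It assumes that the division in the definition is element-wise over gradient dimensions and that all $\bm{\lambda}$ entries are positive; how the $\bm{\lambda}$ vectors themselves are obtained is defined by Eq. 7 and is not reproduced here. The function names are hypothetical.

```python
def normalized_weights(lambdas):
    """Normalize per-task weight vectors across tasks, dimension by dimension.

    lambdas[k][d] is the (assumed positive) weight of task k in dimension d
    for one example. The returned alphas satisfy
    alphas[k][d] = lambdas[k][d] / sum over tasks of lambdas[.][d].
    """
    num_dims = len(lambdas[0])
    totals = [sum(task[d] for task in lambdas) for d in range(num_dims)]
    return [[task[d] / totals[d] for d in range(num_dims)] for task in lambdas]


def mean_weight_per_task(alphas):
    """Average each task's weight vector over its dimensions."""
    return [sum(task) / len(task) for task in alphas]


# Two tasks, two dimensions: task 0 dominates dimension 1, task 1 dimension 0.
alphas = normalized_weights([[1.0, 3.0], [3.0, 1.0]])
per_task = mean_weight_per_task(alphas)
```

In every dimension the normalized weights sum to one across tasks, so the per-task averages also sum to one and can be compared across tasks.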

5.2 BayesAgg-MTL for Binary Classification

Next, we evaluated BayesAgg-MTL on the MTL benchmarks CIFAR-MTL (Krizhevsky et al., 2009; Rosenbaum et al., 2018) and ChestX-ray14 (Wang et al., 2017). To the best of our knowledge, we are the first to evaluate MTL methods on the latter dataset. These datasets contain a large number of tasks, 20 and 14 respectively, with a high class-imbalance distribution. This poses a significant challenge for current MTL methods.

CIFAR-MTL uses the coarse labels of the CIFAR-100 dataset to create an MTL benchmark with 20 binary tasks. Classes from this dataset are grouped into super-classes (fish, flowers, trees, etc.), such that each example is given a one-hot encoding vector of labels indicating the super-class it belongs to. We use the official train-test split of 50,000 and 10,000 examples respectively. We allocate 5,000 examples from the training set for a validation set. Our experiments on this dataset were conducted using a simple NN with 3 convolution layers.
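The construction of the binary tasks from a coarse label can be sketched as follows in plain Python. The function name is hypothetical, and the sketch assumes coarse labels are integers in the range 0 to 19, one per super-class.

```python
NUM_SUPERCLASSES = 20  # CIFAR-100 groups its classes into 20 super-classes


def coarse_label_to_binary_targets(coarse_label, num_tasks=NUM_SUPERCLASSES):
    """Return one binary target per task as a one-hot list.

    Task k asks "does this example belong to super-class k?", so exactly
    one entry is 1 and the remaining entries are 0.
    """
    if not 0 <= coarse_label < num_tasks:
        raise ValueError("coarse_label must be in [0, num_tasks)")
    return [1 if k == coarse_label else 0 for k in range(num_tasks)]


targets = coarse_label_to_binary_targets(3)
```

Since each example is positive for exactly one of the 20 tasks, every binary task sees far more negatives than positives, which is one source of the class imbalance mentioned above.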

ChestX-ray14 contains $\sim$112,000 X-ray images of chests from 32,717 patients. Each image has labels from 14 binary classes corresponding to the occurrence or absence of thoracic diseases. Multiple diseases can appear together in a patient. In our experiments, we mostly follow the training protocol suggested in (Taslimi et al., 2022), which used ResNet-34 for the shared parameters. We use the official split of 70%-10%-20% for training, validation, and test.

We present the test results for these datasets in Table 2. On CIFAR-MTL we report the accuracy in class assignment, and on ChestX-ray14 we report $\Delta_m$ based on the AUC-ROC values per task. From the table, BayesAgg-MTL performs best on both datasets. Interestingly, on the ChestX-ray14 dataset all methods except ours and DWA under-perform the naive LS baseline. In Appendix C.2 we compare the run-time of all methods on this dataset and on QM9. We show that BayesAgg-MTL is substantially faster than other baseline methods that use gradients w.r.t. the shared parameters to weigh the tasks.

5.3 BayesAgg-MTL for Mixed Tasks

In the last set of experiments, we evaluated BayesAgg-MTL and baseline methods on the UTKFace dataset (Zhang et al., 2017). This dataset contains over 20,000 face images with annotations of age, gender, and ethnicity. The age values range from 0 to 116 and are treated as a regression task. Gender is classified into binary categories, either male or female, while ethnicity is classified into five distinct categories, making it a multi-class classification task. We split the dataset into train, validation, and test sets according to a 70%-10%-20% ratio. Here, we use ResNet-18 for the shared network.

Results for this dataset based on 8 random seeds are presented in Table 3. Here as well, BayesAgg-MTL outperforms all methods, having the best results on 2 out of 3 tasks. Interestingly, our approach and MGDA were the only methods to improve upon the STL baseline on the regression task.

6 Conclusions

In this study, we present BayesAgg-MTL, a novel method for aggregating the task gradients in MTL. Instead of treating the gradient of each task as a deterministic quantity, we advocate assigning a probability distribution over it. The randomness arises from the observation that many possible configurations of the task-specific parameters work well. Hence, by tracking all of them using Bayesian tools we can obtain a richer description of the gradient space. This in turn allows us to model the uncertainty in the gradients and derive an update direction for the shared parameters that takes it into account. We demonstrate our method's effectiveness on several benchmark datasets compared with leading baseline methods. For future work, we would like to extend BayesAgg-MTL beyond linear task heads. The challenge here would be to efficiently estimate the Bayesian posterior and the gradient moments. Another possible limitation of BayesAgg-MTL, in common with other popular MTL methods, is that it may fail on rare or atypical examples (Sagawa et al., 2019).

Acknowledgements

This study was funded by a grant to GC from the Israel Science Foundation (ISF 737/2018), and by an equipment grant to GC and Bar-Ilan University from the Israel Science Foundation (ISF 2332/18). IA is supported by a PhD fellowship from Bar-Ilan data science institute (BIU DSI).

Impact Statement

This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none which we feel must be specifically highlighted here.

References

  • Achituve et al. (2021a) Achituve, I., Maron, H., and Chechik, G. Self-supervised learning for domain adaptation on point clouds. In Proceedings of the IEEE/CVF winter conference on applications of computer vision, pp. 123–133, 2021a.
  • Achituve et al. (2021b) Achituve, I., Navon, A., Yemini, Y., Chechik, G., and Fetaya, E. GP-Tree: A Gaussian process classifier for few-shot incremental learning. In International Conference on Machine Learning, pp. 54–65. PMLR, 2021b.
  • Achituve et al. (2021c) Achituve, I., Shamsian, A., Navon, A., Chechik, G., and Fetaya, E. Personalized federated learning with Gaussian processes. Advances in Neural Information Processing Systems, 34:8392–8406, 2021c.
  • Achituve et al. (2023) Achituve, I., Chechik, G., and Fetaya, E. Guided deep kernel learning. In Uncertainty in Artificial Intelligence. PMLR, 2023.
  • Baxter (2000) Baxter, J. A model of inductive bias learning. Journal of artificial intelligence research, 12:149–198, 2000.
  • Bishop (2006) Bishop, C. Pattern recognition and machine learning. Springer, 2006.
  • Brier (1950) Brier, G. W. Verification of forecasts expressed in terms of probability. Monthly weather review, 78(1):1–3, 1950.
  • Brookes (2020) Brookes, M. The matrix reference manual. http://www.ee.imperial.ac.uk/hp/staff/dmb/matrix/intro.html, 2020.
  • Calandra et al. (2016) Calandra, R., Peters, J., Rasmussen, C. E., and Deisenroth, M. P. Manifold Gaussian processes for regression. In 2016 International Joint Conference on Neural Networks (IJCNN), pp. 3338–3345. IEEE, 2016.
  • Caruana (1997) Caruana, R. Multitask learning. Machine learning, 28:41–75, 1997.
  • Chen et al. (2018) Chen, Z., Badrinarayanan, V., Lee, C.-Y., and Rabinovich, A. GradNorm: Gradient normalization for adaptive loss balancing in deep multitask networks. In Dy, J. and Krause, A. (eds.), Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pp. 794–803. PMLR, 10–15 Jul 2018.
  • Chen et al. (2020) Chen, Z., Ngiam, J., Huang, Y., Luong, T., Kretzschmar, H., Chai, Y., and Anguelov, D. Just pick a sign: Optimizing deep multitask models with gradient sign dropout. In Larochelle, H., Ranzato, M., Hadsell, R., Balcan, M., and Lin, H. (eds.), Advances in Neural Information Processing Systems, volume 33, pp. 2039–2050. Curran Associates, Inc., 2020.
  • Chennupati et al. (2019) Chennupati, S., Sistu, G., Yogamani, S., and Rawashdeh, S. Multinet++: Multi-stream feature aggregation and geometric loss strategy for multi-task learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops, June 2019.
  • Daheim et al. (2023) Daheim, N., Möllenhoff, T., Ponti, E., Gurevych, I., and Khan, M. E. Model merging by uncertainty-based gradient matching. In The Twelfth International Conference on Learning Representations, 2023.
  • Dai et al. (2023) Dai, Y., Fei, N., and Lu, Z. Improvable gap balancing for multi-task learning. In Uncertainty in Artificial Intelligence, pp. 496–506. PMLR, 2023.
  • D’Angelo & Fortuin (2021) D’Angelo, F. and Fortuin, V. Repulsive deep ensembles are Bayesian. Advances in Neural Information Processing Systems, 34:3451–3465, 2021.
  • Davies (1973) Davies, R. B. Numerical inversion of a characteristic function. Biometrika, 60(2):415–417, 1973.
  • Daxberger et al. (2021) Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., and Hennig, P. Laplace redux - effortless Bayesian deep learning. Advances in Neural Information Processing Systems, 34:20089–20103, 2021.
  • Désidéri (2012) Désidéri, J.-A. Multiple-gradient descent algorithm (MGDA) for multiobjective optimization. Comptes Rendus Mathematique, 350(5-6):313–318, 2012.
  • Devin et al. (2017) Devin, C., Gupta, A., Darrell, T., Abbeel, P., and Levine, S. Learning modular neural network policies for multi-task and multi-robot transfer. In 2017 IEEE international conference on robotics and automation (ICRA), pp. 2169–2176. IEEE, 2017.
  • Dimitriadis et al. (2023) Dimitriadis, N., Frossard, P., and Fleuret, F. Pareto manifold learning: Tackling multiple tasks via ensembles of single-task models. In International Conference on Machine Learning, pp. 8015–8052. PMLR, 2023.
  • Elich et al. (2023) Elich, C., Kirchdorfer, L., Köhler, J. M., and Schott, L. Challenging common assumptions in multi-task learning. arXiv preprint arXiv:2311.04698, 2023.
  • Fernando et al. (2023) Fernando, H., Shen, H., Liu, M., Chaudhury, S., Murugesan, K., and Chen, T. Mitigating gradient bias in multi-objective learning: A provably convergent stochastic approach. In International Conference on Learning Representations, 2023.
  • Fey & Lenssen (2019) Fey, M. and Lenssen, J. E. Fast graph representation learning with PyTorch Geometric. In ICLR Workshop on Representation Learning on Graphs and Manifolds, 2019.
  • Fortuin et al. (2021) Fortuin, V., Garriga-Alonso, A., Ober, S. W., Wenzel, F., Ratsch, G., Turner, R. E., van der Wilk, M., and Aitchison, L. Bayesian neural network priors revisited. In International Conference on Learning Representations, 2021.
  • Gilmer et al. (2017) Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., and Dahl, G. E. Neural message passing for quantum chemistry. In International conference on machine learning, pp. 1263–1272. PMLR, 2017.
  • Guo et al. (2018) Guo, M., Haque, A., Huang, D.-A., Yeung, S., and Fei-Fei, L. Dynamic task prioritization for multitask learning. In Proceedings of the European Conference on Computer Vision (ECCV), September 2018.
  • Immer et al. (2021) Immer, A., Bauer, M., Fortuin, V., Rätsch, G., and Emtiyaz, K. M. Scalable marginal likelihood estimation for model selection in deep learning. In International Conference on Machine Learning, pp. 4563–4573. PMLR, 2021.
  • Javaloy & Valera (2022) Javaloy, A. and Valera, I. RotoGrad: Gradient homogenization in multitask learning. In International Conference on Learning Representations, 2022.
  • Kendall et al. (2018) Kendall, A., Gal, Y., and Cipolla, R. Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 7482–7491, 2018.
  • Kingma & Ba (2014) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2014.
  • Kingma & Ba (2015) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In Bengio, Y. and LeCun, Y. (eds.), 3rd International Conference on Learning Representations, 2015.
  • Kristiadi et al. (2020) Kristiadi, A., Hein, M., and Hennig, P. Being Bayesian, even just a bit, fixes overconfidence in ReLU networks. In International conference on machine learning, pp. 5436–5446. PMLR, 2020.
  • Krizhevsky et al. (2009) Krizhevsky, A., Hinton, G., et al. Learning multiple layers of features from tiny images. Technical report, University of Toronto, 2009.
  • Kurin et al. (2022) Kurin, V., De Palma, A., Kostrikov, I., Whiteson, S., and Mudigonda, P. K. In defense of the unitary scalarization for deep multi-task learning. Advances in Neural Information Processing Systems, 35:12169–12183, 2022.
  • Lakshminarayanan et al. (2017) Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. Advances in neural information processing systems, 30, 2017.
  • Lin et al. (2022) Lin, B., Ye, F., Zhang, Y., and Tsang, I. W. Reasonable effectiveness of random weighting: A litmus test for multi-task learning. Transactions on Machine Learning Research, 2022.
  • Liu et al. (2021) Liu, B., Liu, X., Jin, X., Stone, P., and Liu, Q. Conflict-averse gradient descent for multi-task learning. Advances in Neural Information Processing Systems, 34:18878–18890, 2021.
  • Liu et al. (2023) Liu, B., Feng, Y., Stone, P., and Liu, Q. FAMO: Fast adaptive multitask optimization, 2023.
  • Liu et al. (2020) Liu, L., Li, Y., Kuang, Z., Xue, J.-H., Chen, Y., Yang, W., Liao, Q., and Zhang, W. Towards impartial multi-task learning. In International Conference on Learning Representations, 2020.
  • Liu et al. (2019a) Liu, S., Johns, E., and Davison, A. J. End-to-end multi-task learning with attention. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 1871–1880, 2019a.
  • Liu et al. (2019b) Liu, X., He, P., Chen, W., and Gao, J. Multi-task deep neural networks for natural language understanding. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pp. 4487–4496, 2019b.
  • MacKay (1992) MacKay, D. J. Bayesian interpolation. Neural computation, 4(3):415–447, 1992.
  • Maninis et al. (2019) Maninis, K.-K., Radosavovic, I., and Kokkinos, I. Attentive single-tasking of multiple tasks. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pp. 1851–1860, 2019.
  • Martens & Sutskever (2011) Martens, J. and Sutskever, I. Learning recurrent neural networks with hessian-free optimization. In Proceedings of the 28th international conference on machine learning (ICML-11), pp. 1033–1040, 2011.
  • Matena & Raffel (2022) Matena, M. S. and Raffel, C. A. Merging models with fisher-weighted averaging. Advances in Neural Information Processing Systems, 35:17703–17716, 2022.
  • Michelsanti et al. (2021) Michelsanti, D., Tan, Z.-H., Zhang, S.-X., Xu, Y., Yu, M., Yu, D., and Jensen, J. An overview of deep-learning-based audio-visual speech enhancement and separation. IEEE/ACM Transactions on Audio, Speech, and Language Processing, 29:1368–1396, 2021.
  • Minka (2001) Minka, T. P. Expectation propagation for approximate bayesian inference. In Proceedings of the Seventeenth conference on Uncertainty in artificial intelligence, pp. 362–369, 2001.
  • Misra et al. (2016) Misra, I., Shrivastava, A., Gupta, A., and Hebert, M. Cross-stitch networks for multi-task learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 3994–4003, 06 2016. doi: 10.1109/CVPR.2016.433.
  • Naeini et al. (2015) Naeini, M. P., Cooper, G. F., and Hauskrecht, M. Obtaining well calibrated probabilities using Bayesian binning. In Proceedings of the Twenty-Ninth AAAI Conference on Artificial Intelligence, January 25-30, 2015, Austin, Texas, USA, pp. 2901–2907. AAAI Press, 2015.
  • Navon et al. (2022) Navon, A., Shamsian, A., Achituve, I., Maron, H., Kawaguchi, K., Chechik, G., and Fetaya, E. Multi-task learning as a bargaining game. In International Conference on Machine Learning, pp. 16428–16446. PMLR, 2022.
  • Neal & Hinton (1998) Neal, R. M. and Hinton, G. E. A view of the EM algorithm that justifies incremental, sparse, and other variants. In Learning in graphical models, pp. 355–368. Springer, 1998.
  • Ramakrishnan et al. (2014) Ramakrishnan, R., Dral, P. O., Rupp, M., and Von Lilienfeld, O. A. Quantum chemistry structures and properties of 134 kilo molecules. Scientific data, 1(1):1–7, 2014.
  • Rosenbaum et al. (2018) Rosenbaum, C., Klinger, T., and Riemer, M. Routing networks: Adaptive selection of non-linear functions for multi-task learning. In International Conference on Learning Representations, 2018.
  • Ruder (2017) Ruder, S. An overview of multi-task learning in deep neural networks. arXiv preprint arXiv:1706.05098, 2017.
  • Russakovsky et al. (2015) Russakovsky, O., Deng, J., Su, H., Krause, J., Satheesh, S., Ma, S., Huang, Z., Karpathy, A., Khosla, A., Bernstein, M., et al. ImageNet large scale visual recognition challenge. International journal of computer vision, 115:211–252, 2015.
  • Sagawa et al. (2019) Sagawa, S., Koh, P. W., Hashimoto, T. B., and Liang, P. Distributionally robust neural networks. In International Conference on Learning Representations, 2019.
  • Särkkä (2013) Särkkä, S. Bayesian Filtering and Smoothing, volume 3 of Institute of Mathematical Statistics textbooks. Cambridge University Press, 2013.
  • Saul et al. (1996) Saul, L. K., Jaakkola, T., and Jordan, M. I. Mean field theory for Sigmoid Belief Networks. Journal of artificial intelligence research, 4:61–76, 1996.
  • Schaul et al. (2019) Schaul, T., Borsa, D., Modayil, J., and Pascanu, R. Ray interference: a source of plateaus in deep reinforcement learning, 2019.
  • Schraudolph (2002) Schraudolph, N. N. Fast curvature matrix-vector products for second-order gradient descent. Neural computation, 14(7):1723–1738, 2002.
  • Sener & Koltun (2018) Sener, O. and Koltun, V. Multi-task learning as multi-objective optimization. Advances in neural information processing systems, 31, 2018.
  • Senushkin et al. (2023) Senushkin, D., Patakin, N., Kuznetsov, A., and Konushin, A. Independent component alignment for multi-task learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 20083–20093, 2023.
  • Shamshad et al. (2023) Shamshad, F., Khan, S., Zamir, S. W., Khan, M. H., Hayat, M., Khan, F. S., and Fu, H. Transformers in medical imaging: A survey. Medical Image Analysis, pp. 102802, 2023.
  • Shamsian et al. (2023) Shamsian, A., Navon, A., Glazer, N., Kawaguchi, K., Chechik, G., and Fetaya, E. Auxiliary learning as an asymmetric bargaining game. arXiv preprint arXiv:2301.13501, 2023.
  • Shi et al. (2023) Shi, H., Ren, S., Zhang, T., and Pan, S. J. Deep multitask learning with progressive parameter sharing. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 19924–19935, 2023.
  • Shu et al. (2018) Shu, T., Xiong, C., and Socher, R. Hierarchical and interpretable skill acquisition in multi-task reinforcement learning. In International Conference on Learning Representations, 2018.
  • Snoek et al. (2015) Snoek, J., Rippel, O., Swersky, K., Kiros, R., Satish, N., Sundaram, N., Patwary, M., Prabhat, M., and Adams, R. Scalable Bayesian optimization using deep neural networks. In International conference on machine learning, pp. 2171–2180. PMLR, 2015.
  • Standley et al. (2020) Standley, T., Zamir, A., Chen, D., Guibas, L., Malik, J., and Savarese, S. Which tasks should be learned together in multi-task learning? In International Conference on Machine Learning, pp. 9120–9132. PMLR, 2020.
  • Taslimi et al. (2022) Taslimi, S., Taslimi, S., Fathi, N., Salehi, M., and Rohban, M. H. SwincheX: Multi-label classification on chest X-ray images with transformers. arXiv preprint arXiv:2206.04246, 2022.
  • Vinyals et al. (2016) Vinyals, O., Bengio, S., and Kudlur, M. Order matters: Sequence to sequence for sets. In Bengio, Y. and LeCun, Y. (eds.), 4th International Conference on Learning Representations, ICLR, 2016.
  • Wang et al. (2017) Wang, X., Peng, Y., Lu, L., Lu, Z., Bagheri, M., and Summers, R. M. ChestX-ray8: Hospital-scale chest X-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2097–2106, 2017.
  • Wang et al. (2020) Wang, Z., Tsvetkov, Y., Firat, O., and Cao, Y. Gradient vaccine: Investigating and improving multi-task optimization in massively multilingual models. In International Conference on Learning Representations, 2020.
  • Wightman (2019) Wightman, R. Pytorch image models. https://github.com/rwightman/pytorch-image-models, 2019.
  • Wild et al. (2024) Wild, V. D., Ghalebikesabi, S., Sejdinovic, D., and Knoblauch, J. A rigorous link between deep ensembles and (variational) Bayesian methods. Advances in Neural Information Processing Systems, 36, 2024.
  • Wilson & Izmailov (2020) Wilson, A. G. and Izmailov, P. Bayesian deep learning and a probabilistic perspective of generalization. Advances in neural information processing systems, 33:4697–4708, 2020.
  • Wilson et al. (2016a) Wilson, A. G., Hu, Z., Salakhutdinov, R., and Xing, E. P. Deep kernel learning. In Artificial intelligence and statistics, pp. 370–378. PMLR, 2016a.
  • Wilson et al. (2016b) Wilson, A. G., Hu, Z., Salakhutdinov, R. R., and Xing, E. P. Stochastic variational deep kernel learning. Advances in neural information processing systems, 29, 2016b.
  • Wu et al. (2018) Wu, Z., Ramsundar, B., Feinberg, E. N., Gomes, J., Geniesse, C., Pappu, A. S., Leswing, K., and Pande, V. MoleculeNet: a benchmark for molecular machine learning. Chemical science, 9(2):513–530, 2018.
  • Xin et al. (2022) Xin, D., Ghorbani, B., Gilmer, J., Garg, A., and Firat, O. Do current multi-task optimization methods in deep learning even help? Advances in Neural Information Processing Systems, 35:13597–13609, 2022.
  • Yu et al. (2020) Yu, T., Kumar, S., Gupta, A., Levine, S., Hausman, K., and Finn, C. Gradient surgery for multi-task learning. Advances in Neural Information Processing Systems, 33:5824–5836, 2020.
  • Yun & Cho (2023) Yun, H. and Cho, H. Achievement-based training progress balancing for multi-task learning. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), pp. 16935–16944, October 2023.
  • Zhang et al. (2017) Zhang, Z., Song, Y., and Qi, H. Age progression/regression by conditional adversarial autoencoder. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 5810–5818, 2017.
  • Zheng et al. (2023) Zheng, C., Wu, W., Chen, C., Yang, T., Zhu, S., Shen, J., Kehtarnavaz, N., and Shah, M. Deep learning-based human pose estimation: A survey. ACM Computing Surveys, 56(1):1–37, 2023.
  • Zhou et al. (2023) Zhou, C., Li, Q., Li, C., Yu, J., Liu, Y., Wang, G., Zhang, K., Ji, C., Yan, Q., He, L., et al. A comprehensive survey on pretrained foundation models: A history from BERT to ChatGPT. arXiv preprint arXiv:2302.09419, 2023.

Appendix A Full Derivations

We now present the full derivations of Eq. 4 and Eq. 9 from the main text. For clarity, we drop the task superscript here.

A.1 Regression Moments

Starting with the first moment,

$$
\begin{aligned}
\mathbb{E}[\mathbf{g}_i] &= \int \mathbf{g}_i\, p(\mathbf{g}_i)\, d\mathbf{g}_i = \int \mathbf{g}_i(\mathbf{w})\, p(\mathbf{w}|\mathcal{D})\, d\mathbf{w} \\
&= 2\int \mathbf{w}\,(\mathbf{h}_i^T\mathbf{w} - y_i)\, p(\mathbf{w}|\mathcal{D})\, d\mathbf{w} \\
&= 2\int \left(\mathbf{w}\mathbf{w}^T\mathbf{h}_i - y_i\mathbf{w}\right) p(\mathbf{w}|\mathcal{D})\, d\mathbf{w} \\
&= 2\left([\mathbf{S} + \mathbf{m}\mathbf{m}^T]\,\mathbf{h}_i - y_i\mathbf{m}\right).
\end{aligned}
\tag{12}
$$

Here, we made the dependence on $\mathbf{w}$ explicit in the first step. To compute the second moment, we relied on the Matrix Reference Manual (Brookes, 2020):

$$
\begin{aligned}
\mathbb{E}[\mathbf{g}_i\mathbf{g}_i^T] &= \int \mathbf{g}_i\mathbf{g}_i^T\, p(\mathbf{g}_i)\, d\mathbf{g}_i = \int \mathbf{g}_i(\mathbf{w})\,\mathbf{g}_i^T(\mathbf{w})\, p(\mathbf{w}|\mathcal{D})\, d\mathbf{w} \\
&= 4\int \mathbf{w}\,(\mathbf{h}_i^T\mathbf{w} - y_i)(\mathbf{h}_i^T\mathbf{w} - y_i)\,\mathbf{w}^T p(\mathbf{w}|\mathcal{D})\, d\mathbf{w} \\
&= 4\int \left(y_i^2\,\mathbf{w}\mathbf{w}^T - 2y_i\,\mathbf{w}\mathbf{h}_i^T\mathbf{w}\mathbf{w}^T + \mathbf{w}\mathbf{w}^T\mathbf{h}_i\mathbf{h}_i^T\mathbf{w}\mathbf{w}^T\right) p(\mathbf{w}|\mathcal{D})\, d\mathbf{w}.
\end{aligned}
\tag{13}
$$

We now solve each term separately and obtain the result,

$$
\begin{aligned}
\int \mathbf{w}\mathbf{w}^T p(\mathbf{w}|\mathcal{D})\, d\mathbf{w} &= \mathbf{S} + \mathbf{m}\mathbf{m}^T, \\
\int \mathbf{w}\mathbf{h}_i^T\mathbf{w}\mathbf{w}^T p(\mathbf{w}|\mathcal{D})\, d\mathbf{w} &= \mathbf{m}\mathbf{h}_i^T(\mathbf{S} + \mathbf{m}\mathbf{m}^T) + (\mathbf{S} + \mathbf{m}\mathbf{m}^T)\,\mathbf{h}_i\mathbf{m}^T + \mathbf{h}_i^T\mathbf{m}\,(\mathbf{S} - \mathbf{m}\mathbf{m}^T), \\
\int \mathbf{w}\mathbf{w}^T\mathbf{h}_i\mathbf{h}_i^T\mathbf{w}\mathbf{w}^T p(\mathbf{w}|\mathcal{D})\, d\mathbf{w} &= (\mathbf{S} + \mathbf{m}\mathbf{m}^T)(\mathbf{A}_i + \mathbf{A}_i^T)(\mathbf{S} + \mathbf{m}\mathbf{m}^T) + \mathbf{m}^T\mathbf{A}_i\mathbf{m}\,(\mathbf{S} - \mathbf{m}\mathbf{m}^T) + \mathrm{Tr}(\mathbf{A}_i\mathbf{S})(\mathbf{S} + \mathbf{m}\mathbf{m}^T).
\end{aligned}
\tag{14}
$$

Here, $\mathbf{A}_i = \mathbf{h}_i\mathbf{h}_i^T$.
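As a sanity check on the moments above, the following NumPy sketch evaluates the closed-form expressions and compares them with Monte-Carlo estimates. It assumes a Gaussian posterior $\mathcal{N}(\mathbf{m}, \mathbf{S})$ over $\mathbf{w}$ and the gradient $\mathbf{g}_i(\mathbf{w}) = 2\mathbf{w}(\mathbf{h}_i^T\mathbf{w} - y_i)$ used in the derivation. The $y_i^2$ term is taken as $y_i^2\,\mathbb{E}[\mathbf{w}\mathbf{w}^T]$, which is what expanding the product of the two residual factors yields. The function names and test values are illustrative and are not taken from the authors' code.

```python
import numpy as np


def regression_gradient_moments(h, y, m, S):
    """Closed-form E[g] and E[g g^T] for g(w) = 2 w (h^T w - y), with w ~ N(m, S)."""
    mmT = np.outer(m, m)
    M2 = S + mmT                       # E[w w^T]
    A = np.outer(h, h)                 # A_i = h_i h_i^T
    first = 2.0 * (M2 @ h - y * m)
    # E[w h^T w w^T]
    third = np.outer(m, h) @ M2 + M2 @ np.outer(h, m) + (h @ m) * (S - mmT)
    # E[w w^T A w w^T]
    fourth = M2 @ (A + A.T) @ M2 + (m @ A @ m) * (S - mmT) + np.trace(A @ S) * M2
    second = 4.0 * (y ** 2 * M2 - 2.0 * y * third + fourth)
    return first, second


def monte_carlo_moments(h, y, m, S, num_samples=200_000, seed=0):
    """Sample-based estimates of the same two moments, for comparison only."""
    rng = np.random.default_rng(seed)
    w = rng.multivariate_normal(m, S, size=num_samples)    # (J, d)
    g = 2.0 * w * (w @ h - y)[:, None]                      # (J, d)
    return g.mean(axis=0), (g.T @ g) / num_samples


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    d = 3
    h, m = rng.normal(size=d), rng.normal(size=d)
    L = rng.normal(size=(d, d))
    S = 0.1 * (L @ L.T) + 0.05 * np.eye(d)                  # arbitrary SPD covariance
    exact = regression_gradient_moments(h, 0.7, m, S)
    approx = monte_carlo_moments(h, 0.7, m, S)
    print("max |first-moment gap| :", np.abs(exact[0] - approx[0]).max())
    print("max |second-moment gap|:", np.abs(exact[1] - approx[1]).max())
```

The covariance of the gradient then follows as `second - np.outer(first, first)`; the gaps printed by the demo should shrink as the number of samples grows.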

A.2 Second Order Posterior Approximation

We now present the quadratic form of the log-posterior in Eq. 9. First, we recap some of our notation:

$$
c = \log p(\hat{\mathbf{w}}|\mathcal{B}); \quad
\mathbf{a} = \left. -\frac{\partial \log p(\mathbf{y}|\mathbf{X},\mathbf{w})}{\partial \mathbf{w}} - \frac{\partial \log p(\mathbf{w})}{\partial \mathbf{w}} \right|_{\mathbf{w}=\hat{\mathbf{w}}}; \quad
\mathbf{B} = \left. -\frac{\partial^2 \log p(\mathbf{y}|\mathbf{X},\mathbf{w})}{\partial \mathbf{w}^2} - \frac{\partial^2 \log p(\mathbf{w})}{\partial \mathbf{w}^2} \right|_{\mathbf{w}=\hat{\mathbf{w}}}.
\tag{15}
$$

Using these constants in Eq. 8 yields the following form:

$$
\begin{aligned}
& c + \mathbf{a}^T(\mathbf{w} - \hat{\mathbf{w}}) + \frac{1}{2}(\mathbf{w} - \hat{\mathbf{w}})^T\mathbf{B}(\mathbf{w} - \hat{\mathbf{w}}) \\
={}& \underbrace{c - \mathbf{a}^T\hat{\mathbf{w}} + \frac{1}{2}\left[\hat{\mathbf{w}}^T\mathbf{B}\hat{\mathbf{w}} - (\mathbf{B}^T\hat{\mathbf{w}} - \mathbf{a})^T\mathbf{B}^{-1}(\mathbf{B}^T\hat{\mathbf{w}} - \mathbf{a})\right]}_{const.} + \frac{1}{2}\left(\mathbf{w} - (\hat{\mathbf{w}} - \mathbf{B}^{-1}\mathbf{a})\right)^T\mathbf{B}\left(\mathbf{w} - (\hat{\mathbf{w}} - \mathbf{B}^{-1}\mathbf{a})\right).
\end{aligned}
\tag{16}
$$

The above takes the quadratic form of a Gaussian with mean $(\hat{\mathbf{w}} - \mathbf{B}^{-1}\mathbf{a})$ and covariance $\mathbf{B}^{-1}$.
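The completing-the-square step can be checked numerically. The sketch below, which assumes a symmetric positive-definite $\mathbf{B}$, reads off the mean $\hat{\mathbf{w}} - \mathbf{B}^{-1}\mathbf{a}$ and covariance $\mathbf{B}^{-1}$ and verifies that the second-order expansion and the centered quadratic differ only by a term that does not depend on $\mathbf{w}$. All function names are hypothetical.

```python
import numpy as np


def gaussian_from_quadratic(w_hat, a, B):
    """Mean and covariance read off from c + a^T (w - w_hat) + 0.5 (w - w_hat)^T B (w - w_hat)."""
    mean = w_hat - np.linalg.solve(B, a)    # w_hat - B^{-1} a, without forming the inverse
    cov = np.linalg.inv(B)                  # B^{-1}
    return mean, cov


def expansion(w, w_hat, a, B, c=0.0):
    """Second-order expansion around w_hat."""
    d = w - w_hat
    return c + a @ d + 0.5 * d @ B @ d


def centered_quadratic(w, mean, B):
    """Quadratic term after completing the square."""
    d = w - mean
    return 0.5 * d @ B @ d


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    dim = 4
    L = rng.normal(size=(dim, dim))
    B = L @ L.T + dim * np.eye(dim)         # symmetric positive definite (assumed)
    a, w_hat = rng.normal(size=dim), rng.normal(size=dim)
    mean, cov = gaussian_from_quadratic(w_hat, a, B)
    # The two forms should differ by the same constant at every w.
    gaps = [expansion(w, w_hat, a, B) - centered_quadratic(w, mean, B)
            for w in rng.normal(size=(5, dim))]
    print("spread of the gap across w:", max(gaps) - min(gaps))
```

Using `np.linalg.solve` for the mean avoids an explicit inverse; the covariance is inverted here only because the sketch returns it in full.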

Appendix B Full Experimental Details

All the experiments were done using PyTorch on NVIDIA V100 and A100 GPUs having 32GB of memory.

QM9. On this dataset we followed the training protocol of Navon et al. (2022). Specifically, we allocated 110,000 examples for training and 10,000 examples each for validation and testing. The task labels are normalized to have zero mean and unit standard deviation. We use the implementation of Fey & Lenssen (2019) for the message-passing NN presented in Gilmer et al. (2017) as the base NN. Here, we trained only our method and the baseline method Aligned-MTL-UB. All other results were taken from (Navon et al., 2022; Dai et al., 2023). We used the same random seeds as in those studies. Each method was trained for 300 epochs using the ADAM optimizer (Kingma & Ba, 2014) with an initial learning rate of $10^{-3}$. The batch size was set to 120. We use the ReduceLROnPlateau scheduler with the $\Delta_m$ metric, computed on the validation set. This metric was also used for early stopping and model selection. For BayesAgg-MTL we set the number of pre-training epochs using linear scalarization to 50. In initial experiments, we found that regression tasks favor relatively high values of the hyper-parameter $s$. Hence, we searched over $s \in \{0.75, 0.85, 0.95\}$. For Aligned-MTL-UB we ran a hyper-parameter search over the scale modes in {min, median, rmse}, and over whether to apply that scale to the task-specific parameters as well.
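The label normalization step can be sketched as follows. This is a minimal illustration, assuming the per-task statistics are estimated on the training split only and then reused for validation and test (the text does not state which split the statistics come from); the function names are hypothetical.

```python
import numpy as np


def fit_label_scaler(train_labels):
    """Per-task mean and std of a (num_examples, num_tasks) label array.

    Assumption: statistics come from the training split only.
    """
    mean = train_labels.mean(axis=0)
    std = train_labels.std(axis=0)
    return mean, std


def standardize(labels, mean, std):
    """Map raw labels to zero mean and unit std, task by task."""
    return (labels - mean) / std


def unstandardize(values, mean, std):
    """Map standardized values back to the original units, e.g. before computing per-task errors."""
    return values * std + mean
```

Applying `standardize` with the training statistics to the validation and test labels keeps all three splits on the same scale.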

CIFAR-MTL. Similarly to Rosenbaum et al. (2018), we form an MTL benchmark from the coarse labels of CIFAR-100. Each example in the CIFAR-100 dataset belongs to one of 20 super-classes. We use these super-classes as separate binary MTL tasks, where the task value is 1 if the example belongs to the super-class and 0 otherwise. We use the official CIFAR train-test split of 50,000 and 10,000 examples, respectively, and allocate 5,000 examples from the training set to validation. To train the models we use a CNN with 3 convolution layers, each with 160 channels and a kernel of size 3. Each convolution is followed by an Exponential Linear Unit (ELU) activation and $3 \times 3$ max-pooling. The final layer is a batch normalization layer. All methods were trained for 50 epochs using the ADAM optimizer, with an initial learning rate of $10^{-3}$ and a scheduler that drops the learning rate by a factor of 0.1 at 60% and 80% of the training. We set the batch size to 128 and used a weight decay of $10^{-4}$. For all baseline methods, we ran a hyper-parameter grid search over the 2-3 most important hyper-parameters. In particular, we searched over additional weight decay values for the LS, SI, and RLW baselines, as advocated by Kurin et al. (2022). As for BayesAgg-MTL, unlike the regression case, smaller $s$ values are preferred for classification. We searched over $s \in \{5 \times 10^{-2}, 5 \times 10^{-3}, 5 \times 10^{-4}\}$ and over the number of pre-training epochs in $\{1, 3\}$. We set $J$, the number of Monte-Carlo samples, to 1024, although far fewer samples could have been used without performance degradation. We used the validation accuracy for early stopping and model selection.
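Two pieces of this setup are easy to make concrete: turning a coarse label into 20 binary task targets, and the step schedule for the learning rate. The sketch below is one possible reading, assuming epochs are counted from 0 and that the drops take effect at the start of epochs 30 and 40 (60% and 80% of 50 epochs). The function names are hypothetical.

```python
def coarse_to_binary_tasks(coarse_labels, num_tasks=20):
    """One binary target per super-class: 1 if the example belongs to it, else 0."""
    return [[1 if label == task else 0 for task in range(num_tasks)]
            for label in coarse_labels]


def step_lr(epoch, total_epochs=50, base_lr=1e-3, drop=0.1, fractions=(0.6, 0.8)):
    """Learning rate at a 0-indexed epoch; multiplied by `drop` at each milestone."""
    milestones = [round(f * total_epochs) for f in fractions]
    num_drops = sum(epoch >= m for m in milestones)
    return base_lr * drop ** num_drops


if __name__ == "__main__":
    targets = coarse_to_binary_tasks([0, 7, 19])
    print([row.index(1) for row in targets])          # positive task of each example
    print([step_lr(e) for e in (0, 29, 30, 40, 49)])  # learning rate at selected epochs
```

Each example is positive for exactly one of the 20 tasks, so every task is heavily imbalanced toward the negative class.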

ChestX-ray14. This dataset reports the presence or absence of 14 types of chest diseases, which we view as an MTL problem. It contains approximately 112,000 images from 32,717 patients. We use the official data split presented in Wang et al. (2017), with 70% training examples, 10% validation examples, and 20% test examples. We follow the experimental setup of Taslimi et al. (2022), which uses PyTorch Image Models (Wightman, 2019), a publicly available repository, for data augmentations. We resize each image to $224 \times 224$ and use data augmentations such as color jitter with intensity 0.4 and random erasing of pixels with probability 0.25. Images are normalized according to ImageNet statistics (Russakovsky et al., 2015). We use ResNet-34 pre-trained on ImageNet as the shared feature extractor. We replaced the final classification layer with a fully connected layer of dimension 256 followed by an ELU activation. Experimental details and hyper-parameter searches are similar to those described for CIFAR-MTL, except for the following changes: we trained for 100 epochs, set the batch size to 256, and did not use weight decay. We use the $\Delta_m$ metric for early stopping and model selection.

UTKFace. This dataset contains approximately 23,700 images of faces, each associated with the age, gender, and ethnicity of the person. We remove 3 examples that have missing labels. We split the dataset into train/validation/test sets according to a 70-10-20 scheme. The split was stratified by the age variable, as it is the most diverse label. We treat age prediction as a regression task, and we normalize the age to have zero mean and unit standard deviation. During training, images are resized to $140 \times 140$, randomly cropped to size 128, and randomly flipped horizontally. Test images are resized and center-cropped. Here, we used ResNet-18 with the final classification layer replaced by a fully connected layer of size 256 and an ELU activation. The experimental setup is similar to that described for CIFAR-MTL, except that here we trained for 100 epochs. We perform a hyper-parameter grid search for all methods on this dataset as well. For our method, we set the number of pre-training epochs to 10 and searched over $s \in \{0.3, 0.5, 0.8\}$ for the regression task and $s \in \{0.005, 0.05, 0.1\}$ for the classification tasks. We use the $\Delta_m$ metric for early stopping and model selection. The regression task is optimized and evaluated with the MSE loss, and the classification tasks with the standard cross-entropy loss.
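A stratified 70-10-20 split can be sketched in a few lines of standard-library Python. The text does not say how ages were grouped for stratification; the sketch below assumes equal-width age bins (`bin_width` is a hypothetical parameter) and splits each bin separately so that all three sets cover the age range in similar proportions.

```python
import random
from collections import defaultdict


def stratified_split(ages, fractions=(0.7, 0.1, 0.2), bin_width=10, seed=0):
    """Return (train, val, test) index lists, stratified by binned age.

    Assumption: ages are grouped into equal-width bins before splitting.
    """
    rng = random.Random(seed)
    by_bin = defaultdict(list)
    for idx, age in enumerate(ages):
        by_bin[int(age) // bin_width].append(idx)

    train, val, test = [], [], []
    for key in sorted(by_bin):
        members = by_bin[key]
        rng.shuffle(members)
        n_train = round(fractions[0] * len(members))
        n_val = round(fractions[1] * len(members))
        train += members[:n_train]
        val += members[n_train:n_train + n_val]
        test += members[n_train + n_val:]       # remainder goes to the test set
    return train, val, test
```

Because the test set receives the remainder of each bin, very small bins can deviate from the nominal 20% share.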

Table 4: Average processing time of a batch on CIFAR-MTL and QM9.

| Method | CIFAR-MTL | QM9 |
|---|---|---|
| LS | 1.551 | 3.932 |
| SI | 1.600 | 3.957 |
| RLW | 1.555 | 3.910 |
| DWA | 1.571 | 3.937 |
| UW | 1.740 | 4.033 |
| MGDA | 15.34 | 39.57 |
| PCGrad | 19.99 | 29.89 |
| CAGrad | 12.19 | 26.16 |
| IMTL-G | 10.12 | 27.19 |
| Nash-MTL | 22.47 | 45.32 |
| IGBv2 | 3.651 | 4.723 |
| Gradient w.r.t. Representation | | |
| MGDA-UB | 5.522 | 9.352 |
| IMTL-G | 2.969 | 4.426 |
| Aligned-MTL-UB | 2.988 | 4.428 |
| BayesAgg-MTL (Ours) | 5.558 | 4.177 |

Appendix C Additional Experiments

C.1 Calibration

A possible benefit of using a Bayesian layer as the last layer is improved uncertainty estimation. Here we compare BayesAgg-MTL to baseline methods in that respect. To do so, we log the expected calibration error (ECE) (Naeini et al., 2015) and the Brier score (Brier, 1950) for all methods on the classification tasks of the UTKFace dataset. For ECE, we first discretize the $[0, 1]$ line segment into bins and then measure a weighted average of the difference between the classifier's confidence and its accuracy in each bin. We use 15 interval bins in our comparison. The Brier score measures the mean squared error between the one-hot label encoding and the predicted probability vector. For both metrics, lower values are better. Results are presented in Figure 4. BayesAgg-MTL is better calibrated than most methods on both tasks. On the gender task, it is the best calibrated according to both metrics. On the ethnicity task, it has the best Brier score and the second-best ECE. We stress that, for a fair comparison with baseline methods, we did not use the Bayesian posterior of BayesAgg-MTL on the last layer to make test predictions, but rather its point estimate learned during training. We expect that using the full posterior would improve calibration further.
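Both calibration metrics are short to implement. The sketch below is a plain-Python illustration, not the evaluation code used for Figure 4. It assumes equal-width confidence bins, takes the confidence to be the largest predicted probability, and computes the Brier score by summing the squared error over classes and averaging over examples (some definitions also divide by the number of classes).

```python
def expected_calibration_error(probs, labels, n_bins=15):
    """ECE with equal-width confidence bins over [0, 1].

    probs: list of per-class probability lists; labels: list of integer class ids.
    """
    bins = [[0, 0.0, 0.0] for _ in range(n_bins)]    # count, sum of confidences, number correct
    for p, y in zip(probs, labels):
        conf = max(p)
        pred = p.index(conf)
        b = min(int(conf * n_bins), n_bins - 1)      # confidence 1.0 falls in the last bin
        bins[b][0] += 1
        bins[b][1] += conf
        bins[b][2] += float(pred == y)
    n = len(labels)
    # Weighted average of |accuracy - mean confidence| over the non-empty bins.
    return sum(abs(correct / count - conf_sum / count) * count / n
               for count, conf_sum, correct in bins if count > 0)


def brier_score(probs, labels):
    """Mean over examples of the squared error between the probability vector and the one-hot label."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
    return total / len(labels)
```

A classifier that is always fully confident and always correct scores 0 on both metrics; any mismatch between confidence and accuracy raises the ECE, and any probability mass placed on the wrong class raises the Brier score.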

C.2 Training Time

Table 4 compares the run time of all methods on the CIFAR-MTL and QM9 datasets. We report the average processing time of a batch based on 10 epochs. For the comparison, we use the best hyper-parameter configuration (in terms of performance) according to the CIFAR-MTL experiments. For MGDA and IMTL-G we present the run time under two settings: (1) using the full gradients w.r.t. the shared parameters in the aggregation scheme (top block); (2) using the gradients w.r.t. the hidden layer in the aggregation scheme (bottom block), as BayesAgg-MTL does. For BayesAgg-MTL we do not include the pre-training steps in the time measurements. The table shows that methods that do not rely on gradients for weighting the tasks are faster, as noted in previous studies (Xin et al., 2022; Kurin et al., 2022); however, this often comes with a significant reduction in performance. BayesAgg-MTL is almost as fast as those methods on regression problems, where everything is computed in closed form, and slower on classification problems, partly due to the sampling process. Nevertheless, it is substantially faster than other gradient balancing methods that use gradients w.r.t. the shared parameters.

C.3 Comparison to Bayesian Training

Table 5: $\Delta_m$ values on QM9 and UTKFace.

| Method | QM9 | UTKFace |
|---|---|---|
| Ensemble (1024 heads) | $161.4 \pm 13.1$ | $0.99 \pm 0.62$ |
| Ensemble (10 networks) | $144.5 \pm 0.3$ | $-0.13 \pm 0.39$ |
| BayesAgg-MTL (Ours) | $\mathbf{53.2 \pm 7.1}$ | $\mathbf{-2.23 \pm 0.76}$ |

Given that our approach uses a Bayesian inference procedure, a natural question is how standard approximate Bayesian training performs in MTL.

Recall that the goal of this paper is to use Bayesian inference on the last layer as a means to train deterministic MTL models using the uncertainty estimates in the gradients of the tasks. We use these uncertainty estimates to derive an aggregation rule for combining the gradients of the tasks into a shared update direction. More concisely, our aim is to better learn a deterministic MTL model while keeping the computational overhead of training it as small as possible. In standard approximate Bayesian training, the gradient used in the backward pass is treated as a deterministic quantity, as in non-Bayesian training. Hence, even when applying standard Bayesian inference to the task-specific parameters, the optimization issues regarding how to combine the gradients of the tasks effectively remain.

To demonstrate this, we compare BayesAgg-MTL to deep ensembles (Lakshminarayanan et al., 2017), which have a strong link to approximate Bayesian methods (Wilson & Izmailov, 2020; D’Angelo & Fortuin, 2021; Wild et al., 2024). We chose deep ensembles because of their simplicity and predictive abilities. We show results on QM9 and UTKFace for two baselines: (1) using 1024 heads for each task and a shared backbone; (2) using 10 networks, each with its own backbone and task heads. The latter is substantially more computationally demanding, as it also requires separate copies of the backbone, which is usually large. We combine the tasks using linear scalarization (i.e., equal weighting of the tasks) and average over the ensemble members. We follow the experimental protocol of the paper and report the $\Delta_m$ values for each method in Table 5. The table shows that the ensemble model with the shared backbone performs roughly the same as standard linear scalarization, with a slight advantage on QM9. This result makes sense, as the uncertainty information is not taken into account when aggregating the gradients (i.e., only the mean values are used). Full ensemble training improves upon the shared-backbone ensemble baseline, but it comes with a substantial computational overhead. Finally, BayesAgg-MTL substantially outperforms both methods on both datasets.
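
To make the shared-backbone baseline concrete, here is a minimal sketch of an equal-weight (linear scalarization) objective over a multi-head ensemble. It rests on several assumptions that the text does not specify: heads are linear, the per-task loss is the squared error, and "averaging over the ensemble members" is taken to mean averaging the member losses within each task. The names `linear_head`, `squared_error`, and `ensemble_ls_loss` are hypothetical.

```python
def squared_error(prediction, target):
    """Squared error for a single scalar regression target."""
    return (prediction - target) ** 2


def linear_head(weights, bias, features):
    """Apply one linear task head to the shared-backbone features."""
    return sum(w * f for w, f in zip(weights, features)) + bias


def ensemble_ls_loss(features, targets, heads):
    """Equal-weight multi-task loss for an ensemble of task heads.

    features: shared-backbone representation of one example.
    targets:  one regression target per task.
    heads:    heads[k] is a list of (weights, bias) pairs, one per
              ensemble member of task k.
    """
    total = 0.0
    for task_heads, target in zip(heads, targets):
        member_losses = [
            squared_error(linear_head(w, b, features), target)
            for w, b in task_heads
        ]
        # Assumption: average the losses of the ensemble members.
        total += sum(member_losses) / len(member_losses)
    # Linear scalarization: every task contributes with weight 1.
    return total


features = [1.0, 2.0]
targets = [1.0, 3.0]
heads = [
    [([1.0, 0.0], 0.0), ([0.0, 1.0], 0.0)],  # task 0, two members
    [([1.0, 1.0], 0.0), ([1.0, 1.0], 1.0)],  # task 1, two members
]
loss = ensemble_ls_loss(features, targets, heads)
print(loss)
```

Only the mean loss over members enters this objective, so the spread across heads never influences the update to the shared backbone. This is consistent with the observation above that the shared-backbone ensemble behaves much like plain linear scalarization.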

C.4 Full Results

In Tables 6 and 7 we present the per-task results for all methods on QM9 and ChestX-ray14, respectively. On QM9 we report the mean absolute error (MAE) of each task, and on ChestX-ray14 the AUC-ROC of each task. Due to lack of space, we abbreviated several disease names from ChestX-ray14. The full names of all diseases are: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, Fibrosis, Hernia, Infiltration, Mass, Nodule, Pleural_Thickening, Pneumonia, Pneumothorax.
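
The $\Delta_m\%$ column summarizes the per-task numbers relative to the single-task (STL) baseline. The sketch below assumes the definition commonly used in the MTL literature, namely the average relative change per task, in percent, with the sign flipped for metrics where higher is better; this section does not restate the formula, so treat it as an assumption. The function name `delta_m` is hypothetical, and the inputs are the STL and LS rows of the QM9 table.

```python
def delta_m(method_scores, baseline_scores, higher_is_better):
    """Average relative change (in percent) of a method vs. a baseline.

    Negative values mean the method improves on the baseline on average.
    higher_is_better[k] is True for metrics such as AUC-ROC and False
    for error metrics such as MAE.
    """
    if not (len(method_scores) == len(baseline_scores) == len(higher_is_better)):
        raise ValueError("all inputs must have one entry per task")
    total = 0.0
    for method, baseline, hib in zip(method_scores, baseline_scores, higher_is_better):
        sign = -1.0 if hib else 1.0
        total += sign * (method - baseline) / baseline
    return 100.0 * total / len(method_scores)


# QM9 per-task MAE (lower is better), taken from the STL and LS rows.
stl_mae = [0.067, 0.181, 60.57, 53.91, 0.503, 4.53, 58.80, 64.20, 63.80, 66.20, 0.072]
ls_mae = [0.106, 0.326, 73.57, 89.67, 5.197, 14.06, 143.4, 144.2, 144.6, 140.3, 0.129]

qm9_ls_delta = delta_m(ls_mae, stl_mae, [False] * len(stl_mae))
print(f"LS on QM9: {qm9_ls_delta:.1f}%")
```

The value computed from the rounded table entries comes out close to, but not exactly equal to, the 177.6 reported for LS. A small gap is expected if the reported number is averaged over runs rather than computed from the averaged per-task values.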

Table 6: Per-task results on QM9. Task columns report MAE (lower is better); the last column reports $\Delta_m\%$ (lower is better).

| Method | $\mu$ | $\alpha$ | $\epsilon_{\text{HOMO}}$ | $\epsilon_{\text{LUMO}}$ | $\langle R^2 \rangle$ | ZPVE | $U_0$ | $U$ | $H$ | $G$ | $c_v$ | $\Delta_m\%$ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| STL | 0.067 | 0.181 | 60.57 | 53.91 | 0.503 | 4.53 | 58.80 | 64.20 | 63.80 | 66.20 | 0.072 | |
| LS | 0.106 | 0.326 | 73.57 | 89.67 | 5.197 | 14.06 | 143.4 | 144.2 | 144.6 | 140.3 | 0.129 | 177.6 |
| SI | 0.309 | 0.346 | 149.8 | 135.7 | 1.003 | 4.51 | 55.32 | 55.75 | 55.82 | 55.27 | 0.112 | 77.8 |
| RLW | 0.113 | 0.340 | 76.95 | 92.76 | 5.869 | 15.47 | 156.3 | 157.1 | 157.6 | 153.0 | 0.137 | 203.8 |
| DWA | 0.107 | 0.325 | 74.06 | 90.61 | 5.091 | 13.99 | 142.3 | 143.0 | 143.4 | 139.3 | 0.125 | 175.3 |
| UW | 0.387 | 0.425 | 166.2 | 155.8 | 1.065 | 5.00 | 66.42 | 66.78 | 66.80 | 66.24 | 0.123 | 108.0 |
| MGDA | 0.217 | 0.368 | 126.8 | 104.6 | 3.227 | 5.69 | 88.37 | 89.41 | 89.32 | 88.01 | 0.120 | 120.5 |
| PCGrad | 0.106 | 0.293 | 75.85 | 88.33 | 3.940 | 9.15 | 116.4 | 116.8 | 117.2 | 114.5 | 0.110 | 125.7 |
| CAGrad | 0.118 | 0.321 | 83.51 | 94.81 | 3.219 | 6.93 | 114.0 | 114.3 | 114.5 | 112.3 | 0.116 | 112.8 |
| IMTL-G | 0.136 | 0.288 | 98.31 | 93.96 | 1.753 | 5.70 | 101.4 | 102.4 | 102.0 | 100.1 | 0.097 | 77.2 |
| Nash-MTL | 0.103 | 0.249 | 82.95 | 81.89 | 2.426 | 5.38 | 74.52 | 75.02 | 75.10 | 74.16 | 0.093 | 62.0 |
| IGBv2 | 0.251 | 0.333 | 149.1 | 130.2 | 0.956 | 4.39 | 56.75 | 57.19 | 57.25 | 56.73 | 0.110 | 67.7 |
| Aligned-MTL-UB | 0.172 | 0.350 | 117.3 | 109.0 | 1.520 | 5.23 | 76.13 | 76.58 | 76.62 | 75.71 | 0.980 | 71.0 |
| BayesAgg-MTL (Ours) | 0.122 | 0.280 | 87.78 | 90.44 | 1.776 | 5.31 | 63.33 | 64.91 | 66.71 | 81.91 | 0.093 | **53.2** |

Table 7: Per-task results on ChestX-ray14. Task columns report AUC-ROC (higher is better); the last column reports $\Delta_m\%$ (lower is better).

| Method | Atel. | Card. | Cons. | Edema | Effusion | Emphysema | Fibrosis | Hernia | Infi. | Mass | Nodule | Pleu. | Pneumonia | Pneu. | $\Delta_m\%$ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| STL | .7543 | .8615 | .7132 | .8212 | .8224 | .6333 | .7357 | .7647 | .6830 | .6208 | .5894 | .6389 | .5710 | .7701 | |
| LS | .7744 | .8804 | .7477 | .8457 | .8273 | .8798 | .8250 | .9129 | .7013 | .8209 | .7593 | .7660 | .7235 | .8525 | -14.62 |
| SI | .7457 | .8739 | .7289 | .8426 | .8152 | .8593 | .7903 | .8045 | .6996 | .7971 | .7268 | .7353 | .6993 | .8389 | -1.94 |
| RLW | .7596 | .8704 | .7389 | .8385 | .8218 | .8390 | .7956 | .8646 | .6991 | .7933 | .7340 | .7362 | .7101 | .8345 | -11.69 |
| DWA | .7734 | .8847 | .7503 | .8482 | .8267 | .8768 | .8185 | .9410 | .6977 | .8175 | .7590 | .7739 | .7240 | .8440 | -14.79 |
| UW | .7600 | .8870 | .7437 | .8464 | .8221 | .8768 | .8176 | .9434 | .7012 | .8049 | .7426 | .7608 | .7057 | .8498 | -13.95 |
| MGDA | .7720 | .8857 | .7473 | .8454 | .8260 | .8762 | .8181 | .9290 | .6961 | .8141 | .7570 | .7661 | .7213 | .8479 | -14.44 |
| PCGrad | .7678 | .8793 | .7461 | .8432 | .8266 | .8721 | .8165 | .8565 | .6991 | .8123 | .7499 | .7629 | .7203 | .8451 | -13.43 |
| CAGrad | .7744 | .8823 | .7489 | .8464 | .8269 | .8756 | .8199 | .9201 | .6998 | .8158 | .7567 | .7702 | .7207 | .8482 | -14.50 |
| IMTL-G | .7395 | .8533 | .7229 | .8235 | .8023 | .7692 | .7538 | .8973 | .6903 | .7543 | .7052 | .7221 | .6758 | .8026 | -8.24 |
| Nash-MTL | .7623 | .8774 | .7445 | .8420 | .8206 | .8627 | .8214 | .8997 | .6999 | .8035 | .7412 | .7553 | .7117 | .8447 | -13.24 |
| IGBv2 | .7189 | .8354 | .7049 | .8097 | .7865 | .7360 | .7160 | .7053 | .6858 | .6828 | .6647 | .7038 | .6512 | .7783 | -2.83 |
| Aligned-MTL-UB | .7689 | .8801 | .7491 | .8456 | .8245 | .8772 | .8221 | .8992 | .6997 | .8115 | .7543 | .7674 | .7208 | .8497 | -14.15 |
| BayesAgg-MTL (Ours) | .7761 | .8836 | .7511 | .8487 | .8293 | .8863 | .8289 | .9121 | .6967 | .8220 | .7622 | .7762 | .7214 | .8545 | **-14.96** |