Figure 1: Illustration of the meta-learning process as applied to the task of personalized next-word prediction. Here each mobile device corresponds to a different next-word prediction task, with the test-task not seen during meta-training (Step 1).
In the classical machine learning setup, we aim to learn a single model for a single task given many training samples from the same distribution. However, in many practical applications, we are in fact exposed to several distinct yet related tasks that have only a few examples each. Because the data now come from different training distributions, simply learning a single global model, e.g., via stochastic gradient descent (SGD), may result in poor performance on each task. As a result, designing algorithms for learning-to-learn from multiple tasks has become a major area of study in machine learning, with the promise of improving performance on a variety of tasks ranging from personalized next-character prediction on your smartphone to fast robotic adaptation in diverse environments.
While learning-to-learn is a well-studied area, it has received renewed attention in recent years. Often referred to as meta-learning, modern methods share the goal of improving sample complexity using multi-task data while also scaling to many more tasks and much larger models. Starting with MAML, several methods have been proposed in which samples from the related tasks are used to “meta-learn” — learn using multi-task data — an initialization \(\phi\) for SGD, with the hope that, for single-task learning, a few within-task gradient updates starting from \(\phi\) will suffice to learn a good task-specific model \(\hat\theta\). The procedures for meta-learning this initialization are quite simple; for example, the popular Reptile algorithm returns an initialization \(\phi\) by iterating through the tasks \(t=1,\dots,T\), running SGD starting from \(\phi\) on each one to obtain task-parameters \(\hat\theta_t\), and moving the initialization slightly closer to the obtained model: \(\phi\gets\phi+\alpha_t(\hat\theta_t-\phi)\) for some learning rate \(\alpha_t>0\). This simplicity, combined with their inherent flexibility (they can be used with any differentiable parametric model), lets gradient-based meta-learning methods achieve strong performance in many areas, including few-shot image classification, meta-RL, and federated learning.
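To make this concrete, here is a minimal Python sketch of a Reptile-style training loop under the update rule above; the `sample_batch` and `grad` methods on each task object are hypothetical stand-ins for whatever data access and gradient computation a particular model provides, and the step counts and learning rates are placeholders.

```python
import numpy as np

def reptile(tasks, phi, inner_steps=32, inner_lr=0.02):
    """Minimal sketch of a Reptile-style meta-training loop.

    Each task is assumed to expose sample_batch() and grad(theta, batch)
    methods (hypothetical interfaces used only for illustration).
    """
    for t, task in enumerate(tasks, start=1):
        # Run a few steps of within-task SGD starting from the initialization phi.
        theta = phi.copy()
        for _ in range(inner_steps):
            batch = task.sample_batch()
            theta -= inner_lr * task.grad(theta, batch)
        # Move the initialization slightly toward the task-specific model:
        # phi <- phi + alpha_t * (theta_hat_t - phi).
        alpha_t = 1.0 / t  # a simple decaying schedule; a constant rate is also common
        phi += alpha_t * (theta - phi)
    return phi
```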
The success of such simple, initialization-based methods raises a natural question: why do they work so well? An intuitive answer is that in many settings, most tasks can be solved by model parameters that lie fairly close together. In my recent work with Nina Balcan and Ameet Talwalkar, we formalize this intuition and show how these methods exploit it as the natural outcome of solving a sequence of learning tasks using SGD. This post takes a look at this result and discusses how it leads to a theoretical framework, ARUBA, for developing and analyzing new meta-learning methods by studying upper bounds on the performance of repeated within-task algorithms. At a high level, ARUBA uses the fact that algorithms such as MAML, Reptile, and FedAvg run gradient descent both within-task and for the meta-update; we can thus apply a vast array of existing low-regret and stochastic approximation results to prove meta-learning bounds for these methods and derive new algorithmic variants. We will highlight an especially interesting application of the ARUBA framework: a new adaptive learning rate for meta-learning that can informally be thought of as a multi-task AdaGrad.
The limitations of single-task learning
To keep things both simple and general, we will work in the setting of distribution-free online learning in this post (though our results do imply bounds on the excess risk in the distributional setting): on each round \(i=1,\dots,m\) of task \(t\) the meta-learner picks \(\theta_{t,i}\in\Theta\) and suffers loss \(\ell_{t,i}(\theta_{t,i})\), where we make no stochastic assumptions on the differentiable loss function \(\ell_{t,i}:\Theta\to\mathbb R\). In the case of supervised prediction, \(\ell_{t,i}(\theta)=L(f_\theta(x_{t,i}),y_{t,i})\) measures the performance of a parameterized model \(f_\theta\) on a data-point \((x_{t,i},y_{t,i})\) under some loss function \(L\). Within-task, the goal of the learner is to minimize regret, defined as the total loss compared to that of the best fixed parameter in hindsight:
$$
R_t
=\sum_{i=1}^m\ell_{t,i}(\theta_{t,i})-\min_{\theta\in\Theta}\sum_{i=1}^m\ell_{t,i}(\theta)
$$
For simplicity, assume the optimum in hindsight \(\theta_t^\ast=\arg\min_{\theta\in\Theta}\sum_{i=1}^m\ell_{t,i}(\theta)\) is unique for each task.
What can we do in single-task online learning? A simple thing to try is online gradient descent: for initialization \(\phi\in\Theta\) and learning rate \(\eta>0\), take action \(\theta_{t,1}=\phi\) on round 1 and on subsequent rounds determine actions via the iteration \(\theta_{t,i+1}=\theta_{t,i}-\eta\nabla\ell_{t,i}(\theta_{t,i})\). If we assume that the loss functions are Lipschitz and the domain \(\Theta\) has bounded radius \(D\) (this covers problems such as constrained linear and logistic regression), then a classical result in online convex optimization (see the excellent survey by Shai Shalev-Shwartz) gives the following upper bound on the regret:
$$
R_t=
\sum_{i=1}^m\ell_{t,i}(\theta_{t,i})-\ell_{t,i}(\theta_t^\ast)\le\frac{\|\phi-\theta_t^\ast\|_2^2}{2\eta}+\eta m
$$
Setting the learning rate to \(\eta=D/\sqrt m\) gives sublinear regret \(R_t=\mathcal O(D\sqrt m)\), for which there exists a tight lower bound. This is great when we have many rounds, since the average regret per round is then proportional to \(D/\sqrt m\); but when \(m\) is small this does not decrease fast enough to ensure low error.
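For concreteness, here is a small sketch of the within-task procedure just described; the per-round gradient oracles and the projection onto \(\Theta\) are stand-ins introduced for illustration.

```python
import numpy as np

def online_gradient_descent(grads, phi, eta, project=lambda theta: theta):
    """Sketch of within-task online gradient descent.

    grads:   list of m gradient oracles; grads[i](theta) returns the gradient
             of the round-i loss at theta (a stand-in for the real losses)
    phi:     the initialization (later, the meta-learned quantity)
    eta:     within-task learning rate, e.g. D / sqrt(m) in the single-task analysis
    project: optional projection back onto the bounded domain Theta
    """
    theta = phi.copy()
    iterates = []
    for grad in grads:
        iterates.append(theta.copy())                # play theta_{t,i}, suffer the loss
        theta = project(theta - eta * grad(theta))   # theta_{t,i+1} = theta_{t,i} - eta * grad
    return iterates
```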
Running gradient descent on top of gradient descent
The goal of meta-learning is to circumvent the single-task lower bound by exploiting multi-task information to do well on-average across tasks, i.e. by ensuring low task-averaged regret \(\bar R\):
$$
\bar R
=\frac1T\sum_{t=1}^TR_t
=\frac1T\sum_{t=1}^T\sum_{i=1}^m\ell_{t,i}(\theta_{t,i})-\ell_{t,i}(\theta_t^\ast)
$$
To do so we need to make some assumption about task-similarity; otherwise the lower bound ensures \(\bar R=\Omega(D\sqrt m)\). A simple yet natural notion of task-similarity is closeness of the optimal parameters \(\theta_t^\ast\), which we can measure using their empirical variance \(V^2=\min_\phi\frac1T\sum_{t=1}^T\|\theta_t^\ast-\phi\|_2^2\). It turns out that for a Reptile-like online algorithm we can show that \(\bar R\to\mathcal O(V\sqrt m)\) as \(T\to\infty\), which can be a significant improvement over the best possible single-task bound \(\mathcal O(D\sqrt m)\) when \(V\ll D\) and the number of samples \(m\) is small. The analysis that follows thus formalizes an intuitive understanding of meta-learning methods for learning the initialization: their success depends directly on the closeness of high-accuracy parameters across the tasks.
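Note that the minimizing \(\phi\) in this definition is just the mean of the optimal task-parameters, so \(V^2\) is their average squared distance to that mean; the following quick numpy sketch (with a function name of my choosing) makes the quantity explicit.

```python
import numpy as np

def task_similarity(theta_stars):
    """Compute V^2 = min_phi (1/T) * sum_t ||theta*_t - phi||_2^2.

    theta_stars: array of shape (T, d); row t holds the optimal parameters of task t.
    For squared Euclidean distance the minimizing phi is the mean of the rows.
    """
    phi = theta_stars.mean(axis=0)
    V2 = np.mean(np.sum((theta_stars - phi) ** 2, axis=1))
    return V2, phi
```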
We will consider an algorithm in which on each task \(t\) we will run online gradient descent with learning rate \(\eta>0\) initialized at \(\phi_t\). All that is left is to specify an update rule for this initialization. One thing we can do is try to set \(\phi_t\) so as to minimize the within-task regret-upper-bound \(U_t(\phi)=\frac{\|\phi-\theta_t^\ast\|_2^2}{2\eta}+\eta m\ge R_t\) that we saw earlier. This can also be done using gradient descent, now at the meta-level. Taking derivatives, we have the following update rule for learning rate \(\tilde\alpha_t>0\):
$$
\phi_{t+1}
=\phi_t-\tilde\alpha_t\nabla U_t(\phi_t)
=\phi_t+\frac{\tilde\alpha_t}\eta(\theta_t^\ast-\phi_t)
$$
For \(\alpha_t=\tilde\alpha_t/\eta\) this is almost exactly the Reptile update rule given in the introduction, except there we had the last iterate \(\hat\theta_t\) of within-task SGD instead of the optimum in hindsight \(\theta_t^\ast\) (under a reasonable non-degeneracy assumption, we can also avoid having to find \(\theta_t^\ast\)). Furthermore, since the regret-upper-bound \(U_t\) is \(\frac1\eta\)-strongly-convex, setting \(\alpha_t=1/t\) yields the regret guarantee
$$\sum_{t=1}^TU_t(\phi_t)-\min_{\phi\in\Theta}\sum_{t=1}^TU_t(\phi)=O\left(\frac{\log T}\eta\right)$$
This can then be substituted into the following short proof:
$$
\begin{aligned}
\bar R
=\frac1T\sum_{t=1}^TR_t
&\le\frac1T\sum_{t=1}^TU_t(\phi_t)\\
&=\frac1T\left(\sum_{t=1}^TU_t(\phi_t)-\min_{\phi\in\Theta}\sum_{t=1}^TU_t(\phi)\right)\qquad+\qquad\quad\min_{\phi\in\Theta}\frac1T\sum_{t=1}^TU_t(\phi)\\
&=\qquad\qquad\mathcal O\left(\frac{\log T}{\eta T}\right)\qquad\qquad+\qquad\qquad\min_{\phi\in\Theta}\frac1T\sum_{t=1}^T\frac{\|\theta_t^\ast-\phi\|_2^2}{2\eta}+\eta m\\
&=\qquad\qquad\mathcal O\left(\frac{\log T}{\eta T}\right)\qquad\qquad+\qquad\qquad\qquad\mathcal O\left(\frac{V^2}\eta+\eta m\right)
\end{aligned}
$$
Here the last line follows by the definition of \(V^2\). By setting \(\eta=V/\sqrt m\) we get \(\bar R=\mathcal O\left(\frac{\sqrt m}{VT}\log T+V\sqrt m\right)\to\mathcal O(V\sqrt m)\) as \(T\to\infty\), as desired.
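Putting the pieces together, the analysis corresponds to the following sketch of a meta-algorithm: online gradient descent within each task, with the meta-level gradient step on \(U_t\) reducing to a Reptile-like update with \(\alpha_t=1/t\). As in practice, the sketch uses the last within-task iterate in place of \(\theta_t^\ast\); the gradient-oracle interface is a stand-in for illustration.

```python
import numpy as np

def meta_ogd(task_grad_oracles, phi0, eta):
    """Sketch of gradient descent on top of gradient descent.

    task_grad_oracles: list over tasks t = 1..T; entry t is a list of m gradient
                       oracles for that task's within-task rounds (hypothetical interface)
    phi0:              initial initialization
    eta:               within-task learning rate (V / sqrt(m) in the theory above)
    """
    phi = phi0.copy()
    for t, grads in enumerate(task_grad_oracles, start=1):
        # Within-task online gradient descent started at the current initialization.
        theta = phi.copy()
        for grad in grads:
            theta -= eta * grad(theta)
        # Meta-update: a gradient step on U_t with alpha_t = 1/t, using the last
        # within-task iterate in place of the optimum in hindsight.
        phi += (theta - phi) / t
    return phi
```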
Average Regret-Upper-Bound Analysis
The above analysis depends crucially on two properties of the regret-upper-bounds \(U_t(\phi)\):
- Data-dependence: the fact that \(U_t\) encodes the distance of the optimal parameter \(\theta_t^\ast\) from the initialization \(\phi_t\) allowed us to bound the average regret using a natural notion of task-similarity.
- Niceness: the strong-convexity of \(U_t\) allowed us to directly apply logarithmic regret guarantees from the online convex optimization literature.
Our main insight with the ARUBA framework is that these properties are not unique to learning a fixed initialization for simple gradient descent. Rather, whenever they hold, they make available a vast arsenal of tools from online learning and sequential prediction that can be applied on the meta-level, the within-task level, or both. For example, on the meta-level we can apply dynamic regret guarantees to handle the case when the task-environment is non-stationary, or apply specialized online-to-batch conversion results for strongly-convex functions to obtain sharper statistical rates on the task-averaged regret.
On the within-task level, we can reason about algorithms other than online gradient descent that come with their own bounds \(U_t\). This has the potential to greatly expand the scope of meta-learning theory and practice; for example, if we have a nicely-parameterized guarantee on the reward of some reinforcement learning procedure we could provably optimize it using gradient descent to obtain both a theorem and an algorithm for meta-RL. In the paper, we generalize the results above to algorithms in the online mirror descent family such as exponentiated gradient descent, where \(U_t\) depends on the KL-divergence from the initialization (in this case a prior). One can even apply the analysis when \(U_t\) is not a regret-upper-bound but a different performance guarantee such as a bound on the excess population risk (expected loss) of task \(t\). In the following section we highlight an application to the case of gradient descent whose step-size varies across dimensions (per-coordinate), which leads to a practical way of meta-learning an adaptive learning rate for gradient descent that yields improvements on standard few-shot image classification benchmarks.
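As one illustration of swapping in a different within-task algorithm, here is a rough sketch of exponentiated gradient over the probability simplex started from a meta-learned prior; in its regret-upper-bound the role of \(\|\theta_t^\ast-\phi\|_2^2\) is played by the KL-divergence from that prior. The gradient-oracle interface is again a hypothetical stand-in.

```python
import numpy as np

def exponentiated_gradient(grads, prior, eta):
    """Sketch of within-task exponentiated gradient over the probability simplex.

    grads: list of gradient oracles for the within-task rounds
    prior: initial distribution phi (nonnegative entries summing to one),
           here playing the role of the meta-learned initialization
    eta:   within-task learning rate
    """
    theta = prior.copy()
    iterates = []
    for grad in grads:
        iterates.append(theta.copy())
        theta = theta * np.exp(-eta * grad(theta))  # multiplicative (mirror-descent) update
        theta /= theta.sum()                        # re-normalize onto the simplex
    return iterates
```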
Using ARUBA to adapt to the inter-task geometry
In many settings, it is beneficial to precondition the gradient \(\nabla_{t,i}=\nabla\ell_{t,i}(\theta_{t,i})\) via element-wise multiplication by a per-coordinate learning rate \(\eta\in\mathbb R_+^d\) before performing the update:
$$\theta_{t,i+1}=\theta_{t,i}-\eta\odot\nabla_{t,i}$$
In the convex case, a per-coordinate learning rate improves performance when the distribution of optimal task parameters \(\theta_t^\ast\) is highly non-isotropic (specifically, when it deviates strongly away from the mean along certain coordinates) or when the magnitude of the gradient varies significantly across dimensions. For meta-learning and transfer learning with deep networks, we expect certain model parameters, such as lower-level feature extractors, to remain the same across tasks and other parameters, such as final layer classification weights, to change significantly.
Below, we highlight how we can use ARUBA to derive a modification of Reptile that simultaneously meta-learns a per-coordinate learning rate. While the analysis is for the convex case, when applying this new meta-training procedure to few-shot image classification with a convolutional network we see a clear pattern emerge in the learned rate that follows the intuition described above: the shared lower-level weights do not vary strongly from task to task and thus are not updated as much as the higher-level filters and some of the classification parameters.
More empirical results, including several settings where we use our new learning rate to improve upon Reptile and FedAvg, are given in the paper. For now I will sketch the derivation and intuition behind our multi-task learning rate. The main item we need is the regret-upper-bound of preconditioned online gradient descent:
$$
\begin{aligned}
U_t(\phi,\eta)
\qquad
&=\qquad~~\frac12\|\theta_t^\ast-\phi\|_{1/\eta}^2\qquad~+\qquad~~\sum_{i=1}^m\|\nabla_{t,i}\|_\eta^2\\
&=\qquad\sum_{j=1}^d\frac{(\theta_{t,j}^\ast-\phi_j)^2}{2\eta_j}\qquad+\qquad\sum_{j=1}^d\eta_j\sum_{i=1}^m\nabla_{t,i,j}^2
\end{aligned}
$$
This regret-upper-bound exemplifies the two properties — niceness and data-dependence — that make ARUBA useful and widely applicable:
Data-dependence: \(U_t(\phi,\eta)\) is similar in form to the regret-upper-bound of online gradient descent, except now we take a summation over coordinates \(j=1,\dots,d\). If the optimal parameter is far away from the initialization along coordinate \(j\), a high learning rate \(\eta_j\) will reduce the negative effect of this due to the first term; however, if the gradients along coordinate \(j\) are large, we want a small \(\eta_j\) due to the second term. The optimal choice of \(\eta_j\) will find the best tradeoff between these two effects.
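To see the tradeoff concretely, minimizing the \(j\)th summand of \(U_t\) over \(\eta_j>0\) (a short calculus step added here for intuition) gives
$$
\eta_j^\ast=\arg\min_{\eta_j>0}\left(\frac{(\theta_{t,j}^\ast-\phi_j)^2}{2\eta_j}+\eta_j\sum_{i=1}^m\nabla_{t,i,j}^2\right)=\frac{|\theta_{t,j}^\ast-\phi_j|}{\sqrt{2\sum_{i=1}^m\nabla_{t,i,j}^2}}
$$
a ratio between a squared distance from the initialization and a sum of squared gradients, which is exactly the type of quantity the meta-learner accumulates across tasks below.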
Niceness: Since the regret-upper-bound \(U_t(\phi,\eta)\) is convex in both variables we can apply tools from online convex optimization. The analysis is more involved because \(U_t\) is neither strongly-convex nor Lipschitz, but we can still asymptotically achieve the average regret bound of the optimal initialization and learning rate by setting the task-\(t\) initialization in the same way as before — \(\phi_{t+1}=\phi_t+\alpha_t(\theta_t^\ast-\phi_t)\) — and setting the \(j\)th coordinate of the task-\(t\) learning rate as follows:
$$
\begin{aligned}
\eta_j
=\sqrt{\frac{B_{t,j}+\varepsilon_t}{G_{t,j}+\zeta_t}}
\qquad\textrm{for}&\qquad B_{t,j}=\frac12\sum_{s<t}(\theta_{s,j}^\ast-\phi_{s,j})^2\qquad\textrm{(sum of squared distances)}\\
&\qquad G_{t,j}=\sum_{s<t}\sum_{i=1}^m\nabla_{s,i,j}^2\qquad\qquad\textrm{ (sum of squared gradients)}\\\\
&\qquad\textrm{and}~~o(t)~~\textrm{smoothing terms}~~\varepsilon_t,~\zeta_t
\end{aligned}
$$
As before, to obtain a practical algorithm we can replace \(\theta_t^\ast\) by \(\hat\theta_t\), the last iterate on task \(t\). Note that the denominator of \(\eta_j\) is very similar to that of the popular AdaGrad learning rate, which sets
$$\eta_j=\frac{\eta_0}{\sqrt{G_{t,j}+\delta}}$$
for some base rate \(\eta_0>0\) and smoothing term \(\delta=O(1)\). This correspondence makes clear why our learning rate is helpful, as in meta-learning we do not want to down-weight a coordinate’s learning rate just because it has a history of large gradients — it could be the case that the optimal task-parameters vary strongly along that coordinate and so we want to use the large gradient to move further away from the initialization. Instead, our learning rate corrects the numerator of AdaGrad so that it is higher if within-task we tend to travel far away from the initialization along that coordinate. Thus we will only down-weight a coordinate if its gradients tend to be large but noisy and thus do not cumulatively lead to progress in that direction.
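To make the rate concrete, here is a small numpy sketch of computing it from past-task statistics; the function name and array layout are mine, and the smoothing terms are simply passed in.

```python
import numpy as np

def multitask_adagrad_lr(theta_hats, phis, sq_grad_sums, eps_t, zeta_t):
    """Sketch of the per-coordinate meta-learned learning rate
    eta_j = sqrt((B_{t,j} + eps_t) / (G_{t,j} + zeta_t)), using last within-task
    iterates theta_hat_s in place of the optima in hindsight.

    theta_hats:    array (t-1, d); row s is the last within-task iterate of past task s
    phis:          array (t-1, d); row s is the initialization used on task s
    sq_grad_sums:  array (t-1, d); row s holds sum_i grad_{s,i,j}^2 for task s
    eps_t, zeta_t: o(t) smoothing terms keeping the rate well-defined early on
    """
    B = 0.5 * np.sum((theta_hats - phis) ** 2, axis=0)  # per-coordinate squared distances
    G = np.sum(sq_grad_sums, axis=0)                    # per-coordinate squared gradients
    return np.sqrt((B + eps_t) / (G + zeta_t))
```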
Conclusion
As machine learning moves further beyond single-task supervised learning, understanding scalable methods for transfer learning will become more important in both research and applications. Thanks to its simplicity and interoperability with many types of within-task learning algorithms and meta-level training procedures, our framework can be broadly useful in the effort to address these questions. Indeed, we have already run experiments showing that the above adaptive multi-task learning rate yields an effective modification of FedAvg that requires no tuning for personalization. Ameet and I, together with Jeff Li and Sebastian Caldas, have also applied ARUBA to give provable guarantees for a practical algorithm in the setting of meta-learning under privacy constraints. Nevertheless, I believe there remain many promising theoretical directions, in both these areas and others such as meta-RL and continual learning, that are yet to be explored.
DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of Carnegie Mellon University.