Figure 1: An example of federated learning for the task of next-word prediction on mobile phones. Devices communicate with a central server periodically to learn a global model. Federated learning helps preserve user privacy and reduce strain on the network by keeping data localized.
What is federated learning? How does it differ from traditional large-scale machine learning, distributed optimization, and privacy-preserving data analysis? What do we understand currently about federated learning, and what problems are left to explore? In this post, we briefly answer these questions, and describe ongoing work in federated learning at CMU. For a more comprehensive and technical discussion, please see our recent white paper.
Federated Learning is privacy-preserving model training in heterogeneous, distributed networks.
Mobile phones, wearable devices, and autonomous vehicles are just a few of the modern distributed networks generating a wealth of data each day. Due to the growing computational power of these devices—coupled with concerns about transmitting private information—it is increasingly attractive to store data locally and push network computation to the edge devices. Federated learning has emerged as a training paradigm in such settings. As we discuss in this post, federated learning requires fundamental advances in areas such as privacy, large-scale machine learning, and distributed optimization, and raises new questions at the intersection of machine learning and systems.
Potential applications of federated learning may include tasks such as learning the activities of mobile phone users, adapting to pedestrian behavior in autonomous vehicles, or predicting health events like heart attack risk from wearable devices. We discuss two canonical applications in more detail below.
Federated learning has been deployed in practice by major companies, and plays a critical role in supporting privacy-sensitive applications where the training data are distributed at the edge. Next, we formalize the problem of federated learning and describe some of the fundamental challenges associated with this setting.
The canonical federated learning problem involves learning a single, global statistical model from data stored on tens to potentially millions of remote devices. In particular, the goal is typically to minimize the following objective function:
$$\min_w F(w), ~\text{where}~ F(w) := \sum_{k=1}^m p_k F_k(w)$$
Here \(m\) is the total number of devices, \(F_k\) is the local objective function for the \(k\)th device, and \(p_k\) specifies the relative impact of each device, with \(p_k \geq 0\) and \( \sum_{k=1}^m p_k = 1 \).
The local objective function \(F_k\) is often defined as the empirical risk over local data. The relative impact of each device \(p_k\) is user-defined, with two natural settings being \(p_k = \frac{1}{m}\) or \(p_k = \frac{n_k}{n}\), where \(n_k\) is the number of samples on device \(k\) and \(n\) is the total number of samples over all devices. Although this is a common federated learning objective, there exist alternatives, such as simultaneously learning distinct but related local models via multi-task learning (Smith et al., 2017), where each device corresponds to a task. Such multi-task and related meta-learning perspectives enable personalized or device-specific modeling, which can be a natural way to handle the statistical heterogeneity of the data.
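As a concrete (and entirely toy) illustration of this objective, the sketch below simulates a tiny federated network in plain Python: each "device" holds its own 1-D least-squares data, the weights are set to \(p_k = n_k/n\), and a FedAvg-style server loop (a few local gradient steps per device, followed by weighted averaging) drives down the global objective \(F(w)\). All data, names, and hyperparameters here are made up for illustration; this is not a production algorithm.

```python
# Toy federated network: 2 "devices", each holding 1-D least-squares data
# (x, y); the local objective is F_k(w) = mean over local data of (w*x - y)^2.
devices = [
    [(1.0, 2.0)],                 # device 1: local optimum at w = 2
    [(1.0, 0.0), (1.0, 0.0)],     # device 2: local optimum at w = 0
]

def local_loss(w, data):
    return sum((w * x - y) ** 2 for x, y in data) / len(data)

def local_grad(w, data):
    return sum(2 * (w * x - y) * x for x, y in data) / len(data)

n_total = sum(len(d) for d in devices)
p = [len(d) / n_total for d in devices]      # p_k = n_k / n

def global_loss(w):
    # F(w) = sum_k p_k F_k(w)
    return sum(pk * local_loss(w, d) for pk, d in zip(p, devices))

def fedavg_round(w, local_steps=5, lr=0.1):
    # Each device takes a few local gradient steps; the server then
    # aggregates the local models with weights p_k (FedAvg-style).
    updates = []
    for d in devices:
        wk = w
        for _ in range(local_steps):
            wk -= lr * local_grad(wk, d)
        updates.append(wk)
    return sum(pk * wk for pk, wk in zip(p, updates))

w = 0.0
for _ in range(20):
    w = fedavg_round(w)
```

For this particular dataset the server iterate converges to the minimizer of \(F\) (here \(w = 2/3\)); with heterogeneous local curvatures the FedAvg fixed point can differ from the global optimum, which is one of the issues discussed below.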
At first sight, federated learning and classical distributed learning share a similar goal of minimizing the empirical risk over distributed entities. However, there are fundamental challenges associated with solving the above objective in the federated setting, as we describe below.
These challenges resemble classical problems in areas such as privacy, large-scale machine learning, and distributed optimization. For instance, numerous methods have been proposed to tackle expensive communication in the machine learning, optimization, and signal processing communities. However, prior methods are typically unable to fully handle the scale of federated networks, much less the challenges of systems and statistical heterogeneity. Similarly, while privacy is an important aspect of many machine learning applications, privacy guarantees for federated learning can be challenging to rigorously establish due to the statistical variation in the data, and may be difficult to implement due to systems constraints on each device and across the potentially massive network. Our white paper provides more details about each of these core challenges, including a review of recent work and connections with classical results.
Federated learning is an active area of research across CMU. Below, we highlight a sample of recent projects by our group and close collaborators that address some of the unique challenges in federated learning.
[S. Caldas, S. Duddu, P. Wu, T. Li, J. Konečný, B. McMahan, V. Smith, A. Talwalkar]
The field of federated learning is in its nascency, and we are at a pivotal time to shape the developments made in this area and ensure that they are grounded in real-world settings, assumptions, and datasets. LEAF is a modular benchmarking framework for learning in federated settings. It includes a suite of open-source federated datasets, a rigorous evaluation framework, and a set of reference implementations to facilitate both the reproducibility of empirical results and the dissemination of new solutions for federated learning.
[T. Li, A. K. Sahu, M. Sanjabi, M. Zaheer, A. Talwalkar, V. Smith]
While the current state-of-the-art method for federated learning, FedAvg (McMahan et al, 2017), has demonstrated empirical success, it does not fully address the underlying challenges associated with heterogeneity, and can diverge in practice. This work introduces an optimization framework, FedProx, to tackle systems and statistical heterogeneity. FedProx adds a proximal term to the local loss functions, which (1) makes the method more amenable to theoretical analysis in the presence of heterogeneity, and (2) demonstrates more robust empirical performance than FedAvg.
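For intuition, here is a minimal sketch of the FedProx local step in plain Python (a 1-D parameter and made-up hyperparameters, not the authors' implementation): each device inexactly minimizes its local loss plus a proximal term \(\frac{\mu}{2}\|w - w^t\|^2\) that keeps the local solution anchored to the current global model \(w^t\).

```python
def fedprox_local_update(w_global, grad_fk, mu=0.1, lr=0.05, steps=10):
    # Inexactly minimize h_k(w) = F_k(w) + (mu/2) * (w - w_global)^2
    # with a few gradient steps; mu = 0 recovers a FedAvg-style local step.
    w = w_global
    for _ in range(steps):
        g = grad_fk(w) + mu * (w - w_global)  # proximal term pulls w back
        w -= lr * g
    return w

# A device whose local loss is F_k(w) = (w - 3)^2 (local optimum at w = 3):
w_prox = fedprox_local_update(0.0, lambda w: 2 * (w - 3.0), mu=0.5)
w_free = fedprox_local_update(0.0, lambda w: 2 * (w - 3.0), mu=0.0)
```

With the proximal term active, the local update moves toward the device's optimum but stays closer to the global model than the unconstrained update, which is what limits client drift under heterogeneity.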
[T. Li, M. Sanjabi, A. Beirami, V. Smith]
Naively minimizing an aggregate loss function in a heterogeneous federated network may disproportionately advantage or disadvantage some of the devices. This work proposes q-Fair Federated Learning (q-FFL), a novel and flexible optimization objective inspired by fair resource allocation in wireless networks that encourages a more fair accuracy distribution by adaptively imposing higher weight to devices with higher loss. To solve q-FFL, the authors devise a communication-efficient method that can be implemented at scale across hundreds to millions of devices.
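A rough sketch of the reweighting idea (a simplified illustration, not the authors' exact algorithm): in q-FFL, devices with larger loss \(F_k\) receive relatively more weight, roughly proportional to \(F_k^q\). Setting \(q = 0\) recovers uniform weighting, while larger \(q\) pushes the accuracy distribution toward uniformity.

```python
def qffl_weights(losses, q):
    # Relative weight of device k is ~ F_k^q (normalized). This is a
    # simplified illustration of how q trades off average loss vs. fairness.
    raw = [fk ** q for fk in losses]
    total = sum(raw)
    return [r / total for r in raw]

losses = [0.5, 1.0, 2.0]          # hypothetical per-device losses
uniform = qffl_weights(losses, q=0.0)   # q = 0: every device weighted equally
fair = qffl_weights(losses, q=2.0)      # larger q: high-loss devices up-weighted
```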
[M. Khodak, M. Balcan, A. Talwalkar]
Gradient-based meta learning (GBML) aims to learn good meta-initializations for gradient descent methods for new tasks, and can be applied to federated settings where each device can be viewed as a task. The authors propose a theoretical framework, Average Regret-Upper-Bound Analysis (ARUBA), that connects GBML to online convex optimization, and improves the existing transfer-risk bounds in meta-learning. In applications, ARUBA leads to better methods for federated learning and few-shot learning via learning a per-coordinate learning rate. It also serves as the foundation for the differentially private meta learning work (described below).
[J. Li, M. Khodak, S. Caldas, A. Talwalkar]
Parameter-transfer is a well-known approach for meta-learning, with applications including federated learning. However, parameter-transfer algorithms often require sharing models that have been trained on the samples from specific tasks, thus leaving the task-owners susceptible to privacy leakage. This work formalizes the notion of task-global differential privacy as a practical relaxation of more commonly studied threat models. The authors then propose a new differentially private algorithm for gradient-based parameter transfer that retains provable transfer learning guarantees in convex settings.
[T. Li, Z. Liu, V. Sekar, V. Smith]
Communication and privacy are two critical concerns in distributed learning (particularly federated learning), but are typically treated separately. This work argues that a natural connection exists between methods for communication reduction and privacy preservation. In particular, the authors prove that Count Sketch (a method for data stream summarization) has inherent differential privacy properties without additional mechanisms. Using these guarantees, they propose a sketch-based framework for distributed learning, where the transmitted messages are compressed via sketches to simultaneously achieve communication efficiency and provable privacy benefits.
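To make the compression side concrete, here is a minimal Count Sketch in plain Python (an illustrative toy, not the authors' implementation, and without the privacy analysis): each coordinate of a vector is hashed into a small table with a random sign, and values are recovered by taking a median across rows.

```python
import random

class CountSketch:
    # Minimal Count Sketch: `rows` independent (bucket, sign) hash rows of
    # size `width`. Names and structure are illustrative.
    def __init__(self, rows=5, width=64, seed=0):
        rng = random.Random(seed)
        self.rows, self.width = rows, width
        self.seeds = [(rng.randrange(1 << 30), rng.randrange(1 << 30))
                      for _ in range(rows)]
        self.table = [[0.0] * width for _ in range(rows)]

    def _bucket(self, r, i):
        b = hash((self.seeds[r][0], i)) % self.width
        s = 1.0 if hash((self.seeds[r][1], i)) % 2 == 0 else -1.0
        return b, s

    def add(self, i, value):
        # Accumulate coordinate i of the vector into every row.
        for r in range(self.rows):
            b, s = self._bucket(r, i)
            self.table[r][b] += s * value

    def estimate(self, i):
        # Median across rows is robust to a minority of hash collisions.
        ests = []
        for r in range(self.rows):
            b, s = self._bucket(r, i)
            ests.append(s * self.table[r][b])
        ests.sort()
        return ests[len(ests) // 2]

sk = CountSketch(rows=5, width=64, seed=1)
sparse_grad = {3: 10.0, 7: -4.0, 11: 0.5}   # a sparse "gradient" to compress
for i, v in sparse_grad.items():
    sk.add(i, v)
```

The table (`rows * width` numbers) is what would be transmitted instead of the full vector; heavy coordinates such as index 3 above are recovered approximately on the other side.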
Although recent work has begun to address the challenges discussed in this post, there are a number of critical open directions in federated learning that are yet to be explored. We briefly list some open problems below.
These challenging problems (and more) will require collaborative efforts from a wide range of research communities.
See our recent white paper: Federated Learning: Challenges, Methods, and Future Directions
DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of Carnegie Mellon University.
Figure 1: A two-dimensional convex function represented via contour lines. Each ellipse denotes a contour line. The function value is constant on the boundary of each such ellipse, and decreases as the ellipse becomes smaller and smaller. Let us assume we want to minimize this function starting from a point \(A\). The red line shows the path followed by a gradient descent optimizer converging to the minimum point \(B\), while the green dashed line represents the direct line joining \(A\) and \(B\).
In today’s post, we will discuss an interesting property concerning the trajectory of gradient descent iterates, namely the length of the Gradient Descent curve. Let us assume we want to minimize the function shown in Figure 1 starting from a point \(A\). We deploy gradient descent (GD) to this end, and converge to the minimum point \(B\). The green dashed line represents the direct line joining \(A\) and \(B\). GD however does not initially know the location of \(B\). Instead, it performs local updates that direct it towards \(B\), as represented by the red solid line. Consequently, the length of the red path that GD follows is longer than the direct green path. How long will the red path be? We studied this question and arrived at some interesting bounds on the path length (both upper and lower) in various settings (check out our preprint!). We discuss some of these bounds in this post. But first, we motivate why understanding path lengths could be useful for understanding nonconvex optimization.
Consider the following two-dimensional loss surface with multiple global minima, all equivalent in function value. Over-parameterized systems such as deep neural networks may have losses that exhibit a similar structure.
Suppose we were to run GD to minimize the loss function, starting from two different initialization points \(A\) and \(B\). The GD paths in the two cases converge to different minima, which have the same function value. Intuitively, convergence is fast here because the slope is steep along each path. Indeed, on the path of interest followed by the GD iterates, the function behaves like a convex function (in fact, a strongly convex function). Yet our traditional convergence bounds fail here, since they require such properties to hold globally. Check out this blog by Prof. Sanjeev Arora for more insight into why the conventional view of optimization is insufficient to explain optimization in some modern machine learning settings.
By proving path length bounds for GD curves, we hope to establish that Gradient Descent explores only a small region of the parameter space. Thus desirable properties such as convexity need only hold in a small region. A number of recent advances toward proving convergence and generalization bounds for deep neural networks depend on path length bounds (for example, this paper). In this post we will discuss some interesting upper and lower bounds on the path length of GD curves under a variety of assumptions on the loss surface. We first begin with a meta-theorem that connects fast convergence with short paths.
We know that GD exhibits fast convergence in many situations. Does fast convergence imply that the path followed must be short? Perhaps unsurprisingly, the answer is yes for linear convergence. If the iterates are \(\mathbf{x}_0, \mathbf{x}_1, \mathbf{x}_2, \ldots \) and the optimal point is \(\mathbf{x}^*\), then linear convergence refers to the following recurrent inequality for every \(k\) (assume all norms, denoted by \(\|\cdot\|\), in this post refer to the Euclidean \(\ell_2\) norm):
$$\|\mathbf{x}_k - \mathbf{x}^*\| \leq (1 - c) \|\mathbf{x}_{k-1} - \mathbf{x}^*\|.$$
If linear convergence holds for any iterative optimization algorithm (not just GD), we can derive a path length bound by using the triangle inequality:
Consider the triangle \(\mathbf{x}_k \mathbf{x}_{k+1} \mathbf{x}^*\) above. Assuming \(\|\mathbf{x}_k - \mathbf{x}^*\| = \varphi\), we have the bound \(\|\mathbf{x}_{k+1} - \mathbf{x}^*\| \leq (1 - c)\varphi\) because of linear convergence. Then by the triangle inequality, \(\|\mathbf{x}_{k} - \mathbf{x}_{k+1}\| \leq (2-c) \varphi\). We use the same technique for the next triangle \(\mathbf{x}_{k+1} \mathbf{x}_{k+2} \mathbf{x}^*\), but the length \(\|\mathbf{x}_{k+1} - \mathbf{x}^*\|\) is now at most \((1-c)\varphi\) instead of just \(\varphi\). Thus, \(\|\mathbf{x}_{k+1} - \mathbf{x}_{k+2}\| \leq (2-c)(1-c) \varphi\). In this manner, we add up all the contributions from \(k = 0\) to \(k = \infty\). This leads to the following geometric series:
$$\begin{align*}\sum_{k=0}^\infty \|\mathbf{x}_{k} - \mathbf{x}_{k+1}\| &\leq \sum_{k=0}^\infty (2-c)(1-c)^k\|\mathbf{x}_{0} - \mathbf{x}^*\| \\ &= \left( \frac{2-c}{c}\right) \|\mathbf{x}_{0} - \mathbf{x}^*\|. \end{align*}$$
Thus the total path of GD, even if we were to run it forever, would still be bounded by the distance between the initial point and the convergent point, times a number that depends on the linear convergence constant \(c\).
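This bound is easy to sanity-check numerically. The toy script below (plain Python, purely illustrative) runs GD on the strongly convex quadratic \(f(x, y) = (x^2 + 10y^2)/2\), for which the per-step contraction factor is \(1 - c = \max_\lambda |1 - \eta\lambda|\), and verifies that the accumulated path length respects \(\frac{2-c}{c}\|\mathbf{x}_0 - \mathbf{x}^*\|\).

```python
# GD on f(x, y) = (x^2 + 10*y^2) / 2, whose minimum is x* = (0, 0).
def grad(p):
    return (p[0], 10.0 * p[1])

def norm(p):
    return (p[0] ** 2 + p[1] ** 2) ** 0.5

eta = 0.05          # contraction factor: max(|1 - 0.05*1|, |1 - 0.05*10|) = 0.95
c = 0.05            # so linear convergence holds with c = 0.05
x = (1.0, 1.0)
x0_dist = norm(x)   # ||x_0 - x*||
path_len = 0.0
for _ in range(2000):
    g = grad(x)
    nxt = (x[0] - eta * g[0], x[1] - eta * g[1])
    path_len += norm((nxt[0] - x[0], nxt[1] - x[1]))   # accumulate step lengths
    x = nxt

bound = (2 - c) / c * x0_dist    # the (2-c)/c * ||x_0 - x*|| path length bound
```

On this example the actual path length is far below the bound, which foreshadows the tighter analyses discussed next.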
The bound we introduced above is very general: it applies to any iterative algorithm as long as linear convergence holds. However, with generality often comes a lack of precision. Let us understand the implications for an important class of functions: strongly convex and smooth (SCS) functions. SCS functions satisfy the following two properties, for some \(L \geq \mu > 0\): \(\mu\)-strong convexity, \(f(\mathbf{y}) \geq f(\mathbf{x}) + \nabla f(\mathbf{x})^T(\mathbf{y} - \mathbf{x}) + \frac{\mu}{2}\|\mathbf{y} - \mathbf{x}\|^2\), and \(L\)-smoothness, \(f(\mathbf{y}) \leq f(\mathbf{x}) + \nabla f(\mathbf{x})^T(\mathbf{y} - \mathbf{x}) + \frac{L}{2}\|\mathbf{y} - \mathbf{x}\|^2\), for all \(\mathbf{x}, \mathbf{y}\).
The condition number of \(f\) is defined as \(\kappa := L/\mu\). This number is important because it governs the convergence speed of first-order methods such as GD on SCS functions. Linear convergence holds for SCS functions with \(c = 1/\kappa\). Thus the path length bound is roughly \(2\kappa\|\mathbf{x}_0 - \mathbf{x}^*\|\). It turns out that this bound is not tight, and a better upper bound of \(2\sqrt\kappa\|\mathbf{x}_0 - \mathbf{x}^*\|\) can be achieved for SCS functions (as shown in this paper and this paper). In the rest of this blog, we absorb the \(\|\mathbf{x}_0 - \mathbf{x}^*\|\) factor into the \(\mathcal{O}\) notation, since all bounds are multiplicative in that term.
Is the \(\mathcal{O}(\sqrt\kappa)\) bound tight or can it be improved further? Are there interesting special cases? We studied this problem and found answers to the following two questions:
PLS functions are a generalization of SCS functions. The smoothness condition for PLS functions is the same as in the SCS case, but we replace the strong convexity assumption with the following (weaker) Polyak-Łojasiewicz (PL) condition: \(\frac{1}{2}\|\nabla f(\mathbf{x})\|^2 \geq \mu (f(\mathbf{x}) - f(\mathbf{x}^*))\) for all \(\mathbf{x}\).
PLS functions are a superset of SCS functions (PLS functions need not even be convex). If the strong convexity assumption is satisfied for some \(\mu > 0\), then so is the PL assumption for the same \(\mu\). In this case too, we define the condition number as \(\kappa = L/\mu\). GD on PLS functions continues to exhibit linear convergence. Furthermore, the \(\mathcal{O}(\sqrt\kappa)\) path length bound is true for PLS functions as well. (If you are interested in knowing more about GD on PLS functions, check out this paper!)
Although a lower bound for SCS functions remains an important open question, we were in fact able to show a \(\mathcal{\widetilde{\Omega}}(\min(\sqrt{d}, \kappa^{1/4}))\) worst case lower bound for PLS functions (the \(\mathcal{\widetilde{\Omega}}\) hides \(\log\) factors). To prove this lower bound, we construct a PLS function \(f\) with condition number \(\kappa\) and identify a particular initialization point \(\mathbf{x}_0\) such that the GD path length from this point is lower bounded as claimed.
Suppose our parameter space is \(\mathbb{R}^d\) and \(\mathbf{x}\) is a point in this space. The lower bound function \(f\) we construct is defined as the sum of \(d\) identical scalar functions over each component of the parameter space. That is, \(f(\mathbf{x}) = \sum_{i=1}^d g(\mathbf{x}_i)\), where \(g\) is a scalar PLS function with condition number \(\kappa\) (thus the condition number of \(f\) is also \(\kappa\)). This is a plot of \(g\) over the domain \([0, 1.2]\):
\(g\) is designed so that it is equal to \(x^2\) on the interval \([0, 0.5]\) and then gradually tapers off while maintaining the PL curvature condition. We stagger the \(d\) components of the initial point \(\mathbf{x}_0\) so that at every consecutive iterate, a new component enters the interval \([0, 0.5]\). This component exhibits a large decrease of approximately \(0.5\). This way, at every time interval, an additional component is ‘optimized’ from \(0.5\) to almost \(0\). The following figures illustrate this for iterates \(k = 0, 1, 2\). Every colored circle denotes \(\mathbf{x}_i\) for different indices \(i\).
Our construction thus ensures that the additional path travelled at each iterate is about \(0.5\), and so the total path length is at least about \(0.5d\). However, the shortest path has length \(\approx \sqrt{1^2 + 1^2 + \ldots + 1^2} = \sqrt{d}\). We conclude that the path followed has length about \(\sqrt{d}\) times the shortest path. Finally, we show that for \(f\), \(\kappa = \Theta(d^2)\). This relates the path length to \(\kappa\) and completes the proof. The broad idea is that instead of moving from the point \((1, 1, \ldots, 1)\) to the point \((0, 0, \ldots, 0)\) directly along the shortest path, we follow a path aligned with the edges of a \(d\)-dimensional hypercube, which is a factor of \(\sqrt{d}\) longer than the diagonal path.
The linear regression objective is described as follows. Given an \(n \times d\) matrix \(\mathbf{A}\) and a \(n \times 1\) output vector \(\mathbf{y}\), we wish to solve for the vector \(\mathbf{x}\) that minimizes the least squares objective:
$$\min_\mathbf{x} \|\mathbf{Ax} - \mathbf{y}\|^2.$$
The linear regression objective is strongly convex if the matrix \(\mathbf{A}\) has full column rank. Thus the \(\mathcal{O}(\sqrt\kappa)\) bound for strongly convex functions applies here too (with some modification a similar bound applies even if \(\mathbf{A}\) does not have full column rank). In fact, this was the best known bound prior to our work.
It turns out that this agnostic bound is wildly loose. The right bound is not \(\mathcal{O}(\sqrt{\kappa})\), but instead \(\mathcal{O}(\min(\sqrt{d}, \sqrt{\log{\kappa}}))\)! We call it the right bound since we can also prove a \(\mathcal{\Omega}(\min(\sqrt{d}, \sqrt{\log{\kappa}}))\) worst case lower bound for convex quadratic functions. We will give some intuition for the proofs here, but for a complete treatment, do look at our paper.
The lower bound can be obtained by considering the behavior of GD on a function whose Hessian \(\mathbf{A}^T\mathbf{A}\) has geometrically increasing singular values \(\omega^0, \omega^1, \omega^2, \ldots, \omega^{d-1}\) for a constant \(\omega > 1\). We can then use a similar idea as in the PL lower bound to show that at every consecutive iterate a single additional component exhibits a large decrease, leading to an \(\Omega(\sqrt{d})\) lower bound. Since \(\kappa = \omega^{d-1}\), we get the \(\mathcal{\Omega}(\sqrt{\log{\kappa}})\) lower bound as well.
The matching upper bound proof is involved, but the key to the analysis is that we can explicitly write down the GD iterates for quadratic objectives. We discuss this first key step here. For simplicity, assume that \(\mathbf{A}\) has full column rank (the bound holds even without this assumption). The GD update with step size \(\eta\) is given by
$$\mathbf{x}_{k+1} = \mathbf{x}_{k} - 2\eta\mathbf{A}^T(\mathbf{A}\mathbf{x}_k - \mathbf{y}).$$
If \(\eta\) is appropriately small, the GD updates will converge to the optimal solution \(\mathbf{x}^* = (\mathbf{A}^T\mathbf{A})^{-1}\mathbf{A}^T\mathbf{y}\). (The solution attained if \(\mathbf{A}\) does not have full column rank can also be written explicitly using the Moore-Penrose pseudoinverse). Using the fact that \(\mathbf{A}^T\mathbf{y} = \mathbf{A}^T\mathbf{A}\mathbf{x}^*\), the above recurrence unravels to give us an explicit form for \(\mathbf{x}_k\):
$$\mathbf{x}_{k} = \mathbf{x}^* + (\mathbf{I}_d - 2\eta\mathbf{A}^T\mathbf{A})^k(\mathbf{x}_0 - \mathbf{x}^*),$$
where \(\mathbf{I}_d\) is the \(d \times d\) identity matrix. The path length can explicitly be written as
$$\begin{align*}\sum_{k = 0}^\infty \|2\eta(\mathbf{A}^T\mathbf{A}) (\mathbf{I}_d - 2\eta\mathbf{A}^T\mathbf{A})^k(\mathbf{x}_0 - \mathbf{x}^*)\|.\end{align*}$$
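The closed form for the iterates is easy to verify numerically before continuing. The toy check below (plain Python, with made-up numbers) takes a diagonal \(\mathbf{A}^T\mathbf{A}\), so that GD decouples across coordinates, and compares the iterate after \(K\) steps against \(\mathbf{x}^* + (\mathbf{I}_d - 2\eta\mathbf{A}^T\mathbf{A})^K(\mathbf{x}_0 - \mathbf{x}^*)\).

```python
# Diagonal A^T A with eigenvalues lam; x_star is the least-squares solution.
lam = [1.0, 4.0]
x_star = [2.0, -1.0]
eta = 0.1
x0 = [0.0, 0.0]
K = 7

# Run GD: x_{k+1} = x_k - 2*eta*A^T A (x_k - x*)  (using A^T y = A^T A x*).
x = list(x0)
for _ in range(K):
    x = [xi - 2 * eta * li * (xi - si)
         for xi, li, si in zip(x, lam, x_star)]

# Closed form, coordinatewise: (x_K)_i = x*_i + (1 - 2*eta*lam_i)^K (x0_i - x*_i).
closed = [si + (1 - 2 * eta * li) ** K * (x0i - si)
          for si, li, x0i in zip(x_star, lam, x0)]
```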
A standard choice for \(\eta\) ensures that the singular values of \(2\eta\mathbf{A}^T\mathbf{A}\) satisfy \(1 \geq \sigma_1 \geq \sigma_2 \geq \ldots \geq \sigma_d > 0\). Now let \(\alpha_i\) denote the component of \(\mathbf{x}_0 - \mathbf{x}^*\) along the singular vector corresponding to \(\sigma_i\). Then we can rewrite the path length as
$$\begin{align*}\sum_{k = 0}^\infty \sqrt{\sum_{i=1}^d \sigma_i^2 (1 - \sigma_i)^{2k}\alpha_i^{2}}.\end{align*}$$
This immediately leads to the \(\mathcal{O}(\sqrt{d})\) bound. For non-negative \(a_1, a_2, \ldots, a_d\), \(\sqrt{a_1^2 + a_2^2 + \ldots + a_d^2} \leq a_1 + a_2 + \ldots + a_d\), and so we obtain
$$\begin{align*}\sum_{k = 0}^\infty \sqrt{\sum_{i=1}^d \sigma_i^2 (1 - \sigma_i)^{2k}\alpha_i^{2}} &\leq \sum_{k = 0}^\infty \sum_{i=1}^d \sigma_i (1 - \sigma_i)^{k}|\alpha_i|
\\ &= \sum_{i=1}^d |\alpha_i|
\\ &\leq \sqrt{d} \sqrt{\sum_{i=1}^d \alpha_i^2}
\\ &= \sqrt{d}\|\mathbf{x}_0 - \mathbf{x}^*\| .\end{align*}$$
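This series, and the \(\sqrt{d}\,\|\mathbf{x}_0 - \mathbf{x}^*\|\) bound on it, can also be checked numerically. In the toy script below (plain Python, illustrative numbers), \(2\eta\mathbf{A}^T\mathbf{A}\) has singular values \(\sigma_i = 0.9^i\), and the accumulated series stays below the bound.

```python
d = 10
sig = [0.9 ** i for i in range(d)]   # singular values of 2*eta*A^T A, in (0, 1]
alpha = [1.0] * d                    # components of x0 - x* in the eigenbasis

# path length = sum_k sqrt( sum_i sigma_i^2 (1 - sigma_i)^{2k} alpha_i^2 )
path = 0.0
resid = list(alpha)                  # holds (1 - sigma_i)^k * alpha_i
for _ in range(5000):                # truncate the infinite series
    path += sum((s * r) ** 2 for s, r in zip(sig, resid)) ** 0.5
    resid = [(1 - s) * r for s, r in zip(sig, resid)]

x0_dist = sum(a * a for a in alpha) ** 0.5
bound = d ** 0.5 * x0_dist           # sqrt(d) * ||x0 - x*||
```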
The proof for the \(\mathcal{O}(\sqrt{\log \kappa})\) bound is also elementary but we shall skip the details here.
The story so far summarizes what we know about function classes that exhibit linear convergence (i.e., SCS, PLS, and quadratic objectives). Now, we will discuss convex functions, where GD exhibits sub-linear convergence rates. Here, even the finiteness of GD paths is a non-trivial matter. In particular, there exist convex functions for which GD paths spiral indefinitely around the minimum (as shown here). Yet it can indeed be shown that the total path length is finite. Before discussing the results in this domain, we introduce a cousin of GD that is actually much easier to work with: Gradient Flow (GF). This is the continuous limit of GD obtained by making the step-size infinitesimally small. Instead of an iterate \(k\), suppose we are at some real-valued time \(t \geq 0\) and the current parameter value is \(\mathbf{x}_t\). For some small \(\epsilon > 0\), the parameter update is given as \(\mathbf{x}_{t+\epsilon} = \mathbf{x}_t - \epsilon \nabla f(\mathbf{x}_t)\). By letting \(\epsilon \to 0\), we can represent the instantaneous change in \(\mathbf{x}_t\) as:
$$\begin{align*}\frac{d \mathbf{x}_t}{dt} &= \lim_{\epsilon \to 0} \frac{(\mathbf{x}_{t} - \epsilon \nabla f(\mathbf{x}_t)) - \mathbf{x}_{t}}{\epsilon} \\ &= -\nabla f(\mathbf{x}_t).\end{align*}$$
Evidently, unlike GD, GF has no step-size! This turns out to be massively useful. While analyzing GD we have to deal with the step-size carefully. However, we can study GF in the absence of the discretization noise induced by the step-size in GD. (If you’re wondering why we don’t use GF instead of GD in practice, that’s because GF is a time-dependent differential equation and not an algorithm that can be implemented on a finite state machine). Yet conclusions made for GF continue to generalize to GD. In particular, all the path length bounds we discussed so far hold for both GD and GF.
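As a quick illustration of the relationship, on the 1-D quadratic \(f(x) = \frac{\lambda}{2}x^2\) the GF equation \(d x_t/dt = -\lambda x_t\) has the exact solution \(x_t = x_0 e^{-\lambda t}\), and GD with a tiny step size \(\epsilon\) (i.e., the Euler discretization of GF) tracks it closely. A toy check in plain Python:

```python
import math

lam, x0, T = 2.0, 1.0, 1.0
eps = 1e-4                       # tiny GD step size = Euler step for GF
steps = int(T / eps)             # integrate up to time T

x = x0
for _ in range(steps):
    x -= eps * lam * x           # one GD step on f(x) = lam/2 * x^2

exact = x0 * math.exp(-lam * T)  # gradient-flow solution at time T
```

Shrinking `eps` further drives the discrepancy toward zero, which is the sense in which GF is the limit of GD.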
To illustrate that GF proofs are clean, in this section we work with GF and prove an interesting fact about the GF curve (we don’t go over the GD proof but you can take our word for it—it’s more involved): the self-contractedness of GF curves of quasiconvex functions, a property that leads to path length bounds in this setting. Specifically, a GD or GF curve \(g\) is self-contracted if for any \(s_1 \leq s_2 \leq s_3\):
$$\|g(s_2) - g(s_3)\| \leq \|g(s_1) - g(s_3)\|. \tag{1}$$
Loosely speaking, the above says that for any point \(s \geq 0\), the curve \(g([0, s])\) converges to \(g(s)\). Below are two-dimensional self-contracted curves (left – GF curve, right – GD curve):
Let us prove the self-contractedness of GF curves of quasiconvex functions. Recall that the GF differential equation is \(\frac{d \mathbf{x}_t}{dt}= -\nabla f(\mathbf{x}_t) \). Thus:
$$\begin{align*} \frac{d f(\mathbf{x}_t)}{dt} &= \left( \frac{d f(\mathbf{x}_t)}{d\mathbf{x}_t} \right)^T \left(\frac{d \mathbf{x}_t}{dt}\right) \\ &= (\nabla f(\mathbf{x}_t))^T(-\nabla f(\mathbf{x}_t)) \leq 0. \end{align*}$$
This shows that the function value is non-increasing with \(t\). Now, fix \(t \geq 0\), and let \(s \leq t\) be a free variable. Then:
$$\begin{align*} \frac{d\|\mathbf{x}_s -\mathbf{x}_t\|^2}{ds} &= 2(\mathbf{x}_s -\mathbf{x}_t)^T\left( \frac{d (\mathbf{x}_s -\mathbf{x}_t)}{ds}\right) \\ &= 2(\mathbf{x}_s -\mathbf{x}_t)^T(-\nabla f(\mathbf{x}_s)) \\ &\leq 0. \end{align*}$$
The final inequality holds by observing that: (a) \(f(\mathbf{x}_t) \leq f(\mathbf{x}_s)\) as proved earlier, and then (b) using the equivalent definition of quasiconvexity for differentiable functions—if \(f(\mathbf{x}) \leq f(\mathbf{y})\), then \(\nabla f(\mathbf{y})^T(\mathbf{x} - \mathbf{y}) \leq 0.\) This implies that \(\|\mathbf{x}_s -\mathbf{x}_t\|\) is non-increasing in \(s\), which is equivalent to Equation (1).
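Self-contractedness can also be observed numerically. The toy check below (plain Python) runs GD on the convex quadratic \(f(x, y) = (x^2 + 10y^2)/2\), with a step size small enough that each coordinate shrinks monotonically, and verifies that for \(k_1 \leq k_2 \leq k_3\), \(\|\mathbf{x}_{k_2} - \mathbf{x}_{k_3}\| \leq \|\mathbf{x}_{k_1} - \mathbf{x}_{k_3}\|\) over a grid of triples of iterates.

```python
# GD on f(x, y) = (x^2 + 10*y^2) / 2; with eta = 0.05 the per-coordinate
# factors are 0.95 and 0.5, so both coordinates decay monotonically.
eta = 0.05
pts = [(1.0, 1.0)]
for _ in range(100):
    x, y = pts[-1]
    pts.append((x - eta * x, y - eta * 10.0 * y))

def dist(a, b):
    return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5

# Self-contraction: for k1 <= k2 <= k3, later points are no farther
# from x_{k3} than earlier ones.
self_contracted = all(
    dist(pts[k2], pts[k3]) <= dist(pts[k1], pts[k3]) + 1e-12
    for k1 in range(0, 101, 5)
    for k2 in range(k1, 101, 5)
    for k3 in range(k2, 101, 5)
)
```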
It seems that self-contracted curves have a natural shrinkage property that should lead to a path length bound. This intuition was nailed down using a very interesting geometric proof in this paper and this paper. It was shown that certain smooth self-contracted curves (which include GF curves) have a path length bound of \(2^{\mathcal{O}(d \log d)}\). In our paper we show that a similar bound also holds for GD curves of smooth convex functions.
One of the most important questions we have yet to answer is identifying lower and upper bounds in the strongly convex case. It would also be interesting to extend the ideas discussed for GD/GF to other iterative algorithms like Accelerated Gradient Descent, Polyak’s Heavy Ball method, Projected Gradient Descent, etc. If you liked this post, we encourage you to go over our paper and perhaps think about some of these open problems yourselves!
(Crossposted at Off the Convex Path)
Traditional wisdom in machine learning holds that there is a careful trade-off between training error and generalization gap. There is a “sweet spot” for the model complexity such that the model (i) is big enough to achieve reasonably good training error, and (ii) is small enough so that the generalization gap – the difference between test error and training error – can be controlled. A smaller model would give a larger training error while making the model bigger would result in a larger generalization gap, both leading to larger test errors. This is described by the classical U-shaped curve for the test error when the model complexity varies (see Figure 1(a)).
However, it is common nowadays to use highly complex over-parameterized models like deep neural networks. These models are usually trained to achieve near-zero error on the training data, and yet they still have remarkable performance on test data. Belkin et al. (2018) characterized this phenomenon by a “double descent” curve which extends the classical U-shaped curve. It was observed that, as one increases the model complexity past the point where the model can perfectly fit the training data (i.e., interpolation regime is reached), test error continues to drop! Interestingly, the best test error is often achieved by the largest model, which goes against the classical intuition about the “sweet spot.” Belkin et al. (2018) illustrates this phenomenon in Figure 1.
Consequently one suspects that the training algorithms used in deep learning – (stochastic) gradient descent and its variants – somehow implicitly constrain the complexity of trained networks (i.e., “true number” of parameters), thus leading to a small generalization gap.
Since larger models often give better performance in practice, one may naturally wonder:
How does an infinitely wide net perform?
The answer to this question corresponds to the right end of Figure 1 (b). This blog post is about a model that has attracted a lot of attention in the past year: deep learning in the regime where the width – namely, the number of channels in convolutional filters, or the number of neurons in fully-connected internal layers – goes to infinity. At first glance this approach may seem hopeless for both practitioners and theorists: all the computing power in the world is insufficient to train an infinite network, and theorists already have their hands full trying to figure out finite ones. But in math/physics, there is a tradition of deriving insights into questions by studying them in the infinite limit, and indeed here too, the infinite limit becomes easier for theory.
Experts may recall the connection between infinitely wide neural networks and kernel methods from 25 years ago by Neal (1994) as well as the recent extensions by Lee et al. (2018) and Matthews et al. (2018). These kernels correspond to infinitely wide deep networks whose parameters are chosen randomly, and only the top (classification) layer is trained by gradient descent. Specifically, if \(f(\theta,x)\) denotes the output of the network on input \(x\), where \(\theta\) denotes the parameters in the network, and \(\mathcal{W}\) is an initialization distribution over \(\theta\) (usually Gaussian with proper scaling), then the corresponding kernel is $$\mathrm{ker} \left(x,x'\right) = \mathbb{E}_{\theta \sim \mathcal{W}}[f\left(\theta,x\right)\cdot f\left(\theta,x'\right)],$$ where \(x,x'\) are two inputs.
What about the more usual scenario when all layers are trained? Recently, Jacot et al. (2018) first observed that this is also related to a kernel, named the neural tangent kernel (NTK), which has the form $$\mathrm{ker} \left(x,x'\right) = \mathbb{E}_{\theta \sim \mathcal{W}}\left[\left\langle \frac{\partial f\left(\theta,x\right)}{\partial \theta}, \frac{\partial f\left(\theta,x'\right)}{\partial \theta}\right\rangle\right].$$
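To make this expectation concrete, here is a toy Monte Carlo estimate of it in plain Python. To keep the limit computable by hand, the sketch uses a two-layer network with a *linear* activation, \(f(\theta, z) = \frac{1}{\sqrt m}\sum_{r=1}^m a_r (w_r^\top z)\) with \(a_r \in \{\pm 1\}\) and \(w_r \sim \mathcal{N}(0, I)\); for this toy network the gradients with respect to \(a_r\) contribute \(x^\top x'\) and those with respect to \(w_r\) contribute \(\mathbb{E}[(w^\top x)(w^\top x')] = x^\top x'\), so the infinite-width NTK is exactly \(2\,x^\top x'\). Real NTK results use nonlinear activations such as ReLU, whose limiting kernel also has a closed form, just a more involved one.

```python
import random

random.seed(0)
m = 20000                         # width; larger m -> closer to the limit
x = (1.0, 0.0)
xp = (0.6, 0.8)

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

ntk = 0.0
for _ in range(m):
    w = [random.gauss(0.0, 1.0) for _ in range(len(x))]
    a = random.choice((-1.0, 1.0))
    # df/dw_r = a * z / sqrt(m),  df/da_r = (w_r . z) / sqrt(m); summing the
    # inner products of these gradients over r gives the empirical NTK.
    ntk += (a * a * dot(x, xp) + dot(w, x) * dot(w, xp)) / m

target = 2.0 * dot(x, xp)         # infinite-width NTK for this toy linear net
```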
The key difference between the NTK and previously proposed kernels is that the NTK is defined through the inner product between the gradients of the network outputs with respect to the network parameters. This gradient arises from the use of the gradient descent algorithm. Roughly speaking, the following conclusion can be made for a sufficiently wide deep neural network trained by gradient descent:
A properly randomly initialized sufficiently wide deep neural network trained by gradient descent with infinitesimal step size (a.k.a. gradient flow) is equivalent to a kernel regression predictor with a deterministic kernel called neural tangent kernel (NTK).
This was more or less established in the original paper of Jacot et al. (2018), but they required the width of every layer to go to infinity in sequential order. In our recent paper with Sanjeev Arora, Zhiyuan Li, Ruslan Salakhutdinov and Ruosong Wang, we improve this result to the non-asymptotic setting where the width of every layer only needs to be greater than a certain finite threshold.
In the rest of this post, we will first explain how NTK arises and the idea behind the proof of the equivalence between wide neural networks and NTKs. Then we will present experimental results showing how well infinitely wide neural networks perform in practice.
Now we describe how training an ultra-wide fully-connected neural network leads to kernel regression with respect to the NTK. A more detailed treatment is given in our paper. We first specify our setup. We consider the standard supervised learning setting, in which we are given \(n\) training data points drawn from some underlying distribution and wish to find a function that, given the input, predicts the label well on the data distribution. We consider a fully-connected neural network defined by \(f(\theta, x)\) , where \(\theta\) is the collection of all the parameters in the network and \(x\) is the input. For simplicity, we only consider neural networks with a single output, i.e., \(f(\theta, x) \in \mathbb{R}\) , but the generalization to multiple outputs is straightforward.
We consider training the neural network by minimizing the quadratic loss over training data: \(\ell(\theta) = \frac{1}{2}\sum_{i=1}^n (f(\theta,x_i)-y_i)^2.\) Gradient descent with infinitesimally small learning rate (a.k.a. gradient flow) is applied on this loss function \(\ell(\theta)\): $$\frac{d \theta(t)}{dt} = - \nabla \ell(\theta(t))$$ where \(\theta(t)\) denotes the parameters at time \(t\).
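Numerically, gradient flow is simply gradient descent with a vanishingly small step size (an Euler discretization of the ODE above). A toy sketch on a linear model \(f(\theta, x) = \theta^\top x\), chosen for brevity:

```python
import numpy as np

rng = np.random.default_rng(2)
n, d = 20, 5
X = rng.standard_normal((n, d))
y = rng.standard_normal(n)

def loss(theta):
    # Quadratic loss 0.5 * sum_i (f(theta, x_i) - y_i)^2 for f(theta, x) = theta . x
    return 0.5 * np.sum((X @ theta - y) ** 2)

theta = rng.standard_normal(d)
eta = 5e-3  # small step size: Euler discretization of d(theta)/dt = -grad loss

losses = [loss(theta)]
for _ in range(5000):
    theta = theta - eta * X.T @ (X @ theta - y)  # one Euler step of gradient flow
    losses.append(loss(theta))

print(losses[0], losses[-1])  # the loss decreases monotonically for small enough eta
```

For this quadratic loss the iterates converge to the least-squares solution, mimicking the \(t \to \infty\) limit of gradient flow.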
Let us define some useful notation. Denote \(u_i = f(\theta, x_i)\) , which is the network’s output on \(x_i\). We let \(u=(u_1, \ldots, u_n)^\top \in \mathbb{R}^n\) be the collection of the network outputs on all training inputs. We use the time index \(t\) for all variables that depend on time, e.g. \(u_i(t), u(t)\), etc. With this notation, the training objective can be conveniently written as \(\ell(\theta) = \frac12 \|u-y\|_2^2\).
Using simple differentiation, one can obtain the dynamics of \(u(t)\) as follows (see our paper for a proof): $$\frac{du(t)}{dt} = -H(t)\cdot(u(t)-y),$$ where \(H(t)\) is an \(n \times n\) positive semidefinite matrix whose \(i,j\)-th entry is \(\left\langle \frac{\partial f(\theta(t), x_i)}{\partial\theta}, \frac{\partial f(\theta(t), x_j)}{\partial\theta} \right\rangle\) .
Note that \(H(t)\) is the kernel matrix of the following (time-varying) kernel evaluated on the training data: $$ker_t(x,x’) = \left\langle \frac{\partial f(\theta(t), x)}{\partial\theta}, \frac{\partial f(\theta(t), x’)}{\partial\theta} \right\rangle, \quad \forall x, x’ \in \mathbb{R}^{d}.$$ In this kernel, an input \(x\) is mapped to a feature vector \(\phi_t(x) = \frac{\partial f(\theta(t), x)}{\partial\theta}\) defined through the gradient of the network output with respect to the parameters at time \(t\).
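Both facts can be checked numerically on a toy one-hidden-layer ReLU network (an illustrative stand-in for the general architecture): \(H(t)\) is a Gram matrix of the feature vectors \(\phi_t(x_i)\), hence positive semidefinite, and a small gradient step moves \(u\) by approximately \(-\eta H(t)(u-y)\):

```python
import numpy as np

rng = np.random.default_rng(3)

# Tiny one-hidden-layer ReLU net f(theta, x) = a . relu(W x) / sqrt(m),
# used here as an illustrative stand-in for the general network in the text.
m, d, n = 256, 3, 8
W = rng.standard_normal((m, d))
a = rng.standard_normal(m)
X = rng.standard_normal((n, d))
y = rng.standard_normal(n)

def f(W, a, x):
    return a @ np.maximum(W @ x, 0.0) / np.sqrt(m)

def grad(W, a, x):
    # df/dtheta for theta = (W, a), flattened into one vector phi(x)
    pre = W @ x
    gW = (a * (pre > 0))[:, None] * x[None, :] / np.sqrt(m)
    ga = np.maximum(pre, 0.0) / np.sqrt(m)
    return np.concatenate([gW.ravel(), ga])

Phi = np.stack([grad(W, a, x) for x in X])  # rows are feature vectors phi(x_i)
H = Phi @ Phi.T                             # H_ij = <phi(x_i), phi(x_j)>

# H is a Gram matrix, hence positive semidefinite:
print(np.min(np.linalg.eigvalsh(H)))        # >= 0 up to numerical error

# One small gradient step changes u by approximately -eta * H * (u - y):
u = np.array([f(W, a, x) for x in X])
eta = 1e-4
step = eta * (Phi.T @ (u - y))              # eta * grad of 0.5 * ||u - y||^2
W2 = W - step[:m * d].reshape(m, d)
a2 = a - step[m * d:]
u2 = np.array([f(W2, a2, x) for x in X])
print(np.max(np.abs((u2 - u) + eta * H @ (u - y))))  # small (first-order match)
```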
Up to this point we haven’t used the property that the neural network is very wide. The formula for the evolution of \(u(t)\) is valid in general. In the large width limit, it turns out that the time-varying kernel \(ker_t(\cdot,\cdot)\) is (with high probability) always close to a deterministic fixed kernel \(ker_{\mathsf{NTK}}(\cdot,\cdot)\), which is the neural tangent kernel (NTK). This property is proved in two steps, both requiring the large width assumption:

1. At initialization (\(t=0\)), the random kernel \(ker_0(\cdot,\cdot)\) concentrates around the deterministic kernel \(ker_{\mathsf{NTK}}(\cdot,\cdot)\).
2. During training, the parameters move so little that \(ker_t(\cdot,\cdot)\) stays close to \(ker_0(\cdot,\cdot)\) for all \(t\).
Combining the above two steps, we conclude that for any two inputs \(x,x’\), with high probability we have $$ker_t(x,x’) \approx ker_0(x,x’) \approx ker_{\mathsf{NTK}}(x,x’), \quad \forall t.$$ As we have seen, the dynamics of gradient descent are closely related to the time-varying kernel \(ker_t(\cdot,\cdot)\). Now that we know that \(ker_t(\cdot,\cdot)\) is essentially the same as the NTK, with a few more steps, we can eventually establish the equivalence between the trained neural network and the NTK: the final learned neural network at time \(t=\infty\), denoted by \(f_{\mathsf{NN}}(x) = f(\theta(\infty), x)\), is equivalent to the kernel regression solution with respect to the NTK. Namely, for any input \(x\) we have $$f_{\mathsf{NN}}(x) \approx f_{\mathsf{NTK}}(x) = ker_{\mathsf{NTK}}(x, X)^\top \cdot ker_{\mathsf{NTK}}(X, X)^{-1} \cdot y,$$ where \(ker_{\mathsf{NTK}}(x, X) = (ker_{\mathsf{NTK}}(x, x_1), \ldots, ker_{\mathsf{NTK}}(x, x_n))^\top \in \mathbb{R}^n\), and \(ker_{\mathsf{NTK}}(X, X)\) is an \(n\times n\) matrix whose \((i, j)\)-th entry is \(ker_{\mathsf{NTK}}(x_i, x_j)\).
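The kernel regression predictor appearing above is only a few lines of linear algebra. A generic sketch, using a Gaussian RBF kernel on made-up 1-D data as a stand-in (the actual experiments would plug in the closed-form NTK/CNTK instead):

```python
import numpy as np

def kernel_regression(ker, X_train, y_train, X_test):
    # Kernel regression: f(x) = ker(x, X)^T ker(X, X)^{-1} y
    K = np.array([[ker(xi, xj) for xj in X_train] for xi in X_train])
    alpha = np.linalg.solve(K, y_train)          # ker(X, X)^{-1} y
    k_test = np.array([[ker(x, xj) for xj in X_train] for x in X_test])
    return k_test @ alpha

# Toy demo: 1-D training points with a Gaussian RBF kernel as the stand-in.
X_train = np.linspace(-3.0, 3.0, 15).reshape(-1, 1)
y_train = np.sin(X_train[:, 0])
rbf = lambda x, xp: np.exp(-np.sum((x - xp) ** 2) / (2 * 0.2 ** 2))

preds = kernel_regression(rbf, X_train, y_train, X_train)
print(np.max(np.abs(preds - y_train)))  # ~0: the predictor interpolates the training data
```

Note that on the training points the predictor interpolates the labels exactly (when the kernel matrix is invertible), matching the \(t=\infty\) behavior of gradient flow on the quadratic loss.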
(In order to not have a bias term in the kernel regression solution we also assume that the network output at initialization is small: \(f(\theta(0), x)\approx0\); this can be ensured by e.g. scaling down the initialization magnitude by a large constant, or replicating a network with opposite signs on the top layer at initialization.)
Having established this equivalence, we can now address the question of how well infinitely wide neural networks perform in practice — we can just evaluate the kernel regression predictors using the NTKs! We test NTKs on a standard image classification dataset, CIFAR-10. Note that for image datasets, one needs to use convolutional neural networks (CNNs) to achieve good performance. Therefore, we derive an extension of NTK, convolutional neural tangent kernels (CNTKs) and test their performance on CIFAR-10. In the table below, we report the classification accuracies of different CNNs and CNTKs:
Depth | CNN-V | CNTK-V | CNN-GAP | CNTK-GAP |
---|---|---|---|---|
3 | 61.97% | 64.67% | 57.96% | 70.47% |
4 | 62.12% | 65.52% | 80.58% | 75.93% |
6 | 64.03% | 66.03% | 80.97% | 76.73% |
11 | 70.97% | 65.90% | 75.45% | 77.43% |
21 | 80.56% | 64.09% | 81.23% | 77.08% |
Here CNN-Vs are vanilla practically-wide CNNs (without pooling), and CNTK-Vs are their NTK counterparts. We also test CNNs with global average pooling (GAP), denoted above as CNN-GAPs, and their NTK counterparts, CNTK-GAPs. For all experiments, we turn off batch normalization, data augmentation, etc., and use only SGD to train the CNNs (for CNTKs, we use the closed-form formula of kernel regression).
We find that CNTKs are actually very powerful kernels. The best kernel we find, the 11-layer CNTK with GAP, achieves 77.43% classification accuracy on CIFAR-10. This sets a significant new benchmark for the performance of a pure kernel-based method on CIFAR-10, about 10% higher than the methods reported by Novak et al. (2019). The CNTKs also perform similarly to their CNN counterparts. This means that ultra-wide CNNs can achieve reasonable test performance on CIFAR-10.
It is also interesting to see that the global average pooling operation can significantly increase the classification accuracy for both CNNs and CNTKs. From this observation, we suspect that many techniques that improve the performance of neural networks are in some sense universal, i.e., these techniques might benefit kernel methods as well.
Understanding the surprisingly good performance of over-parameterized deep neural networks is definitely a challenging theoretical question. Now, at least we have a better understanding of a class of ultra-wide neural networks (the right end of Figure 1(b)): they are captured by neural tangent kernels! A hurdle that remains is that the classic generalization theory for kernels is still incapable of giving realistic bounds for generalization. But at least we now know that better understanding of kernels can lead to better understanding of deep nets.
Another fruitful direction is to “translate” different architectures/tricks of neural networks to kernels and to check their practical performance. We have found that global average pooling can significantly boost the performance of kernels, so we hope other tricks like batch normalization, dropout, max-pooling, etc. can also benefit kernels. Similarly, one can try to translate other architectures like recurrent neural networks, graph neural networks, and transformers, to kernels as well.
Our study also shows that there is a performance gap between infinitely wide networks and finite ones. How to explain this gap is an important theoretical question.
DISCLAIMER: All opinions expressed in this post are those of the authors and do not represent the views of Carnegie Mellon University.
Figure 1: Overview of the setting of unsupervised domain adaptation and its difference with the standard setting of supervised learning. In domain adaptation the source (training) domain is related to but different from the target (testing) domain. During training, the algorithm can only have access to labeled samples from source domain and unlabeled samples from target domain. The goal is to generalize on the target domain.
One of the backbone assumptions underpinning the generalization theory of supervised learning algorithms is that the test distribution should be the same as the training distribution. However, in many real-world applications it is usually time-consuming or even infeasible to collect labeled data from all the possible scenarios where our learning system is going to be deployed. For example, consider a typical application of vehicle counting, where we would like to count how many cars there are in a given image captured by a camera. There are over 200 cameras with different calibrations, perspectives, lighting conditions, etc. at different locations in Manhattan. In this case, it is very costly to collect labeled images from all the cameras. Ideally, we would collect labeled images for a subset of the cameras and still be able to train a counting system that works well for all cameras.
Domain adaptation deals with the setting where we only have access to labeled data from the training distribution (a.k.a., source domain) and unlabeled data from the testing distribution (a.k.a., target domain). The setting is complicated by the fact that the source domain can be different from the target domain — just like the above example where different images taken from different cameras usually have different pixel distributions due to different perspectives, lighting, calibrations, etc. The goal of an adaptation algorithm is then to generalize to the target domain without seeing labeled samples from it.
In this blog post, we will first review a common technique to achieve this goal based on the idea of finding a domain-invariant representation. Then we will construct a simple example to show that this technique alone does not necessarily lead to good generalization on the target domain. To understand the failure mode, we give a generalization upper bound that decomposes into terms measuring the difference in input and label distributions between the source and target domains. Crucially, this bound allows us to provide a sufficient condition for good generalizations on the target domain. We also complement the generalization upper bound with an information-theoretic lower bound to characterize the trade-off in learning domain-invariant representations. Intuitively, this result says that when the marginal label distributions differ across domains, one cannot hope to simultaneously minimize both source and target errors by learning invariant representations; this provides a necessary condition for the success of methods based on learning invariant representations. All the material presented here is based on our recent work published at ICML 2019.
The central idea behind learning invariant representations is quite simple and intuitive: we want to find a representation that is insensitive to the domain shift while still capturing rich information for the target task. Such a representation would allow us to generalize to the target domain by only training with data from the source domain. The pipeline for learning domain invariant representations is illustrated in Figure 3.
Note that in the framework above we can use different transformation functions \(g_S/g_T\) on the source/target domain to align the distributions. This powerful framework is also very flexible: by using different measures to align the feature distributions, we recover several of the existing approaches, e.g., DANN (Ganin et al.’ 15), DAN (Long et al.’ 15) and WDGRL (Shen et al.’ 18).
A theoretical justification for the above framework is the following generalization bound by Ben-David et al.’ 10: Let \(\mathcal{H}\) be a hypothesis class and \(\mathcal{D}_S/\mathcal{D}_T\) be the marginal data distributions of source/target domains, respectively. For any \(h\in\mathcal{H}\), the following generalization bound holds: $$\varepsilon_T(h) \leq \varepsilon_S(h) + d(\mathcal{D}_S, \mathcal{D}_T) + \lambda^*,$$ where \(\lambda^* = \inf_{h\in\mathcal{H}} \varepsilon_S(h) + \varepsilon_T(h)\) is the optimal joint error achievable on both domains. At a colloquial level, the above generalization bound shows that the target risk could essentially be bounded by three terms:

1. the source risk \(\varepsilon_S(h)\);
2. the distance \(d(\mathcal{D}_S, \mathcal{D}_T)\) between the marginal data distributions of the two domains;
3. the optimal joint error \(\lambda^*\) achievable on both domains.
The interpretation of the bound is as follows. If there exists a hypothesis that works well on both domains, then in order to minimize the target risk, one should choose a hypothesis that minimizes the source risk while at the same time aligning the source and target data distributions.
The above framework for domain adaptation has generated a surge of interest in recent years and we have seen many interesting variants and applications based on the general idea of learning domain-invariant representations. Yet it is not clear whether such methods are guaranteed to succeed when the following conditions are met:

1. the feature transformation \(g\) aligns the feature distributions of the source and target domains;
2. the hypothesis \(h\) achieves a small risk on the source domain.
Since we can only train with labeled data from the source domain, ideally we would hope that when the above two conditions are met, the composition function \(h\circ g\) also achieves a small risk on the target domain because these two domains are close to each other in the feature space. Perhaps somewhat surprisingly, this is not the case as we demonstrate from the following simple example illustrated in Figure 4.
Consider an adaptation problem where we have input space and feature space \(\mathcal{X} = \mathcal{Z} = \mathbb{R}\) with source domain \(\mathcal{D}_S = U(-1,0)\) and target domain \(\mathcal{D}_T = U(1,2)\), respectively, where we use \(U(a,b)\) to mean a uniform distribution in the interval \((a, b)\). In this example, the two domains are so far away from each other that their supports are disjoint! Now, let’s try to align them so that they are closer to each other. We can do this by shifting the source domain to the right by one unit and then shifting the target domain to the left by one unit.
As shown in Figure 4, after adaptation both domains have distribution \(U(0, 1)\), i.e., they are perfectly aligned by our simple translation transformation. However, due to our construction, now the labels are flipped between the two domains: for every \(x\in (0, 1)\), exactly one of the domains has label 1 and the other has label 0. This implies that if a hypothesis achieves perfect classification on the source domain, it will also incur the maximum risk of 1 on the target domain! In fact, in this case we have \(\varepsilon_S(h) + \varepsilon_T(h) = 1\) after adaptation for any classifier \(h\). As a comparison, before adaptation, a simple interval hypothesis \(h^*(x) = 1\) iff \(x\in (-1/2, 3/2)\) attains perfect classification on both domains.
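This counter-example is easy to verify numerically; the sample sizes and the particular classifiers below are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(4)

# Source D_S = U(-1, 0), target D_T = U(1, 2); labels chosen so that
# h*(x) = 1 iff x in (-1/2, 3/2) is perfect on both domains.
xs = rng.uniform(-1.0, 0.0, 100000)
xt = rng.uniform(1.0, 2.0, 100000)
ys = (xs > -0.5).astype(int)
yt = (xt < 1.5).astype(int)

# "Adaptation": shift the source right and the target left by one unit,
# so both feature distributions become U(0, 1).
zs, zt = xs + 1.0, xt - 1.0

# After alignment the labels are flipped (source: 1 on (1/2, 1); target:
# 1 on (0, 1/2)), so every classifier h has eps_S(h) + eps_T(h) = 1.
for h in [lambda z: (z > 0.5).astype(int),
          lambda z: (z > 0.2).astype(int),
          lambda z: np.ones_like(z, dtype=int)]:
    eps_s = np.mean(h(zs) != ys)
    eps_t = np.mean(h(zt) != yt)
    print(eps_s + eps_t)  # ~1.0 in every case
```

In particular, the first classifier above is perfect on the (shifted) source domain, yet incurs error close to 1 on the target domain.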
So what insights can we gain from the previous counter-example? Why do we incur a large target error despite perfectly aligning the marginal distributions of the two domains and minimizing the source error? Does this contradict Ben-David et al.’s generalization bound?
The caveat here is that while the distance between the two domains becomes 0 after the adaptation, the optimal joint error on both domains becomes large. In the counter-example above, this means that after adaptation \(\lambda^{*} = 1\), which further implies \(\varepsilon_T(h) = 1\) if \(\varepsilon_S(h) = 0\). Intuitively, from Figure 4 we can see that the labeling functions of the two domains are “maximally different” from each other after adaptation, but during adaptation we are only aligning the marginal distributions in the feature space. Since the optimal joint error \(\lambda^*\) is often unknown and intractable to compute, could we construct a generalization upper bound that is free of the constant \(\lambda^*\) and takes into account the conditional shift?
Here is an informal description of what we show in our paper: Let \(f_S\) and \(f_T\) be the labeling functions of the source and target domains. Then for any hypothesis class \(\mathcal{H}\) and any \(h\in\mathcal{H}\), the following inequality holds: $$\varepsilon_T(h) \leq \varepsilon_S(h) + d(\mathcal{D}_S, \mathcal{D}_T) + \min\{\mathbb{E}_S[|f_S - f_T|], \mathbb{E}_T[|f_S - f_T|]\}.$$
Roughly speaking, the above bound gives a decomposition of the difference of errors between the source and target domains. Again, the second term on the RHS measures the difference between the marginal data distributions. But, in place of the optimal joint error term, the third term now measures the discrepancy between the labeling functions of the two domains. Hence, this bound says that just aligning the marginal data distributions is not sufficient for adaptation; we also need to ensure that the labeling functions (conditional distributions) are close to each other after adaptation.
In the counter-example above, we demonstrated that aligning the marginal distributions and achieving a small source error is not sufficient to guarantee a small target error. But in this example, it is actually possible to find another feature transformation that jointly aligns both the marginal data distributions and the labeling functions. Specifically, let the feature transformation \(g(x) = \mathbb{I}_{x\leq 0}(x)(x+1) + \mathbb{I}_{x > 0}(x)(2-x)\). Then, it is straightforward to verify that the source and target domains perfectly align with each other after adaptation. Furthermore, we also have \(\varepsilon_T(h) = 0\) if \(\varepsilon_S(h) = 0\).
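We can check this transformation numerically as well, using the same source/target construction as in the counter-example:

```python
import numpy as np

rng = np.random.default_rng(5)

# g(x) = (x + 1) on x <= 0 and (2 - x) on x > 0, as in the text
g = lambda x: np.where(x <= 0, x + 1.0, 2.0 - x)

xs = rng.uniform(-1.0, 0.0, 100000)  # source U(-1, 0), label 1 on (-1/2, 0)
xt = rng.uniform(1.0, 2.0, 100000)   # target U(1, 2),  label 1 on (1, 3/2)
ys = (xs > -0.5).astype(int)
yt = (xt < 1.5).astype(int)

zs, zt = g(xs), g(xt)  # both feature distributions become U(0, 1)

# Now the labeling functions agree as well: in both domains the label is 1
# exactly on (1/2, 1), so h(z) = 1{z > 1/2} is perfect on source AND target.
h = lambda z: (z > 0.5).astype(int)
print(np.mean(h(zs) != ys), np.mean(h(zt) != yt))  # both errors are 0.0
```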
Consequently, it is natural to wonder whether it is always possible to find a feature transformation and a hypothesis that align the marginal data distributions and minimize the source error, such that the composite function of the two also achieves a small target error. Quite surprisingly, we show that this is not always possible. In fact, finding a feature transformation to align the marginal distributions can provably increase the joint error on both domains. With such a transformation, minimizing the source error will only increase the target error!
More formally, let \(\mathcal{D}_S^Y/\mathcal{D}_T^Y\) be the marginal label distribution of the source/target domain. For any feature transformation \(g: X\to Z\), let \(\mathcal{D}_S^Z/\mathcal{D}_T^Z\) be the resulting feature distribution by applying \(g(\cdot)\) to \(\mathcal{D}_S/\mathcal{D}_T\) respectively. Furthermore, define \(d_{\text{JS}}(\cdot, \cdot)\) to be the Jensen-Shannon distance between a pair of distributions. Then, for any hypothesis \(h: Z\to\{0, 1\}\), if \(d_{\text{JS}}(\mathcal{D}_S^Y, \mathcal{D}_T^Y) \geq d_{\text{JS}}(\mathcal{D}_S^Z, \mathcal{D}_T^Z)\), the following inequality holds: $$\varepsilon_S(h\circ g) + \varepsilon_T(h\circ g)\geq \frac{1}{2}\left(d_{\text{JS}}(\mathcal{D}_S^Y, \mathcal{D}_T^Y) - d_{\text{JS}}(\mathcal{D}_S^Z, \mathcal{D}_T^Z)\right)^2.$$
Let’s parse the above lower bound step by step. The LHS corresponds to the joint error achievable by the composite function \(h\circ g\) on both the source and the target domains. The RHS contains the distance between the marginal label distributions and the distance between the feature distributions. Hence, when the marginal label distributions \(\mathcal{D}_S^Y/\mathcal{D}_T^Y\) differ between two domains, i.e., \(d_{\text{JS}}(\mathcal{D}_S^Y, \mathcal{D}_T^Y) > 0\), aligning the marginal data distributions by learning \(g(\cdot)\) will only increase the lower bound. In particular, for domain-invariant representations where \(d_{\text{JS}}(\mathcal{D}_S^Z, \mathcal{D}_T^Z) = 0\), the lower bound attains its maximum value of \(\frac{1}{2}d^2_{\text{JS}}(\mathcal{D}_S^Y, \mathcal{D}_T^Y)\). Since in domain adaptation we only have access to labeled data from the source domain, minimizing the source error will only lead to an increase of the target error. In a nutshell, this lower bound can be understood as an uncertainty principle: when the marginal label distributions differ across domains, one has to incur large error in either the source domain or the target domain when using domain-invariant representations.
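To make the bound concrete, here is how it can be evaluated for a pair of hypothetical binary label marginals. The class frequencies below are made up, and the base-2 convention for the JS divergence (giving a distance in \([0, 1]\)) is our choice:

```python
import numpy as np

def js_distance(p, q):
    # Jensen-Shannon distance: the square root of the JS divergence
    # (computed with base-2 logs, so the distance lies in [0, 1]).
    p, q = np.asarray(p, float), np.asarray(q, float)
    m = 0.5 * (p + q)
    def kl(a, b):
        mask = a > 0
        return np.sum(a[mask] * np.log2(a[mask] / b[mask]))
    return np.sqrt(0.5 * kl(p, m) + 0.5 * kl(q, m))

# Hypothetical binary label marginals for the two domains
p_source = [0.7, 0.3]
p_target = [0.3, 0.7]

d_y = js_distance(p_source, p_target)
# For a perfectly domain-invariant representation, d_JS(D_S^Z, D_T^Z) = 0 and
# the lower bound on eps_S + eps_T becomes 0.5 * d_y^2:
print(d_y, 0.5 * d_y ** 2)
```

The more the label marginals differ, the larger \(d_y\) and hence the larger the unavoidable joint error under an invariant representation.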
One implication of our lower bound is that when two domains have different marginal label distributions, minimizing the source error while aligning the two domains can lead to increased target error. To verify this, let us consider the task of digit classification on the MNIST, SVHN and USPS datasets. The label distributions of these three datasets are shown in Figure 5.
From Figure 5, it is clear to see that these three datasets have quite different label distributions. Now let’s use DANN (Ganin et al., 2015) to classify on the target domain by learning a domain invariant representation while training to minimize error on the source domain.
We plot four adaptation trajectories for DANN in Figure 6. Across the four adaptation tasks, we can observe the following pattern: the test domain accuracy rapidly grows within the first 10 iterations before gradually decreasing from its peak, despite consistently increasing source training accuracy. These phase transitions can be verified from the negative slopes of the least squares fit of the adaptation curves (dashed lines in Figure 6). The above experimental results are consistent with our theoretical findings: over-training on the source task can indeed hurt generalization to the target domain when the label distributions differ.
Note that the failure mode in the above counter-example is due to the increase of the distance between the labeling functions during adaptation. One interesting direction for future work is then to characterize what properties the feature transformation function should have in order to decrease the shift between labeling functions. Of course domain adaptation would not be possible without proper assumptions on the underlying source/target domains. It would be nice to establish some realistic assumptions under which we can develop effective adaptation algorithms that align both the marginal distributions and the labeling functions. Feel free to get in touch if you’d like to talk more!
DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.
We are proud to present the following papers at the 33rd Conference on Neural Information Processing Systems (NeurIPS) in Vancouver, Canada. Check back for an update with poster numbers and links once the camera-ready papers become available.
If you are attending NeurIPS 2019, please stop by to say hello and hear more about what we are doing!
Joint-task Self-supervised Learning for Temporal Correspondence
Xueting Li (uc merced) · Sifei Liu (NVIDIA) · Shalini De Mello (NVIDIA) · Xiaolong Wang (CMU) · Jan Kautz (NVIDIA) · Ming-Hsuan Yang (UC Merced / Google)
Deep Equilibrium Models
Shaojie Bai (Carnegie Mellon University) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI) · Vladlen Koltun (Intel Labs)
Volumetric Correspondence Networks for Optical Flow
Gengshan Yang (Carnegie Mellon University) · Deva Ramanan (Carnegie Mellon University)
Efficient Symmetric Norm Regression via Linear Sketching
Zhao Song (University of Washington) · Ruosong Wang (Carnegie Mellon University) · Lin Yang (Johns Hopkins University) · Hongyang Zhang (Carnegie Mellon University) · Peilin Zhong (Columbia University)
Envy-Free Classification
Maria-Florina Balcan (Carnegie Mellon University) · Travis Dick (Carnegie Mellon University) · Ritesh Noothigattu (Carnegie Mellon University) · Ariel D Procaccia (Carnegie Mellon University)
Twin Auxilary Classifiers GAN
Mingming Gong (University of Melbourne) · Yanwu Xu (University of Pittsburgh) · Chunyuan Li (Microsoft Research) · Kun Zhang (CMU) · Kayhan Batmanghelich (University of Pittsburgh)
Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift
Stephan Rabanser (Amazon) · Stephan Günnemann (Technical University of Munich) · Zachary Lipton (Carnegie Mellon University)
Backprop with Approximate Activations for Memory-efficient Network Training
Ayan Chakrabarti (Washington University in St. Louis) · Benjamin Moseley (Carnegie Mellon University)
Total Least Squares Regression in Input Sparsity Time
Huaian Diao (Northeast Normal University) · Zhao Song (Harvard University & University of Washington) · David Woodruff (Carnegie Mellon University) · Xin Yang (University of Washington)
Conformal Prediction Under Covariate Shift
Rina Foygel Barber (University of Chicago) · Emmanuel Candes (Stanford University) · Aaditya Ramdas (CMU) · Ryan Tibshirani (Carnegie Mellon University)
Optimal Analysis of Subset-Selection Based L_p Low-Rank Approximation
Chen Dan (Carnegie Mellon University) · Hong Wang (Massachusetts Institute of Technology) · Hongyang Zhang (Carnegie Mellon University) · Yuchen Zhou (University of Wisconsin, Madison) · Pradeep Ravikumar (Carnegie Mellon University)
Third-Person Visual Imitation Learning via Decoupled Hierarchical Control
Pratyusha Sharma (Carnegie Mellon University) · Deepak Pathak (UC Berkeley) · Abhinav Gupta (Facebook AI Research/CMU)
Visual Sequence Learning in Hierarchical Prediction Networks and Primate Visual Cortex
Jielin Qiu (Shanghai Jiao Tong University) · Ge Huang (Carnegie Mellon University) · Tai Sing Lee (Carnegie Mellon University)
Optimal Decision Tree with Noisy Outcomes
Su Jia (CMU) · Viswanath Nagarajan (University of Michigan, Ann Arbor) · Fatemeh Navidi (University of Michigan) · R Ravi (CMU)
Learning Sample-Specific Models with Low-Rank Personalized Regression
Ben Lengerich (Carnegie Mellon University) · Bryon Aragam (University of Chicago) · Eric Xing (Petuum Inc. / Carnegie Mellon University)
A Normative Theory for Causal Inference and Bayes Factor Computation in Neural Circuits
Wenhao Zhang (Carnegie Mellon & U. of Pittsburgh) · Si Wu (Peking University) · Brent Doiron (University of Pittsburgh) · Tai Sing Lee (Carnegie Mellon University)
Regularized Weighted Low Rank Approximation
Frank Ban (UC Berkeley) · David Woodruff (Carnegie Mellon University) · Richard Zhang (UC Berkeley)
Partially Encrypted Deep Learning using Functional Encryption
Theo Ryffel (École Normale Supérieure) · David Pointcheval (École Normale Supérieure) · Francis Bach (INRIA – Ecole Normale Superieure) · Edouard Dufour-Sans (Carnegie Mellon University) · Romain Gay (UC Berkeley)
Learning low-dimensional state embeddings and metastable clusters from time series data
Yifan Sun (Carnegie Mellon University) · Yaqi Duan (Princeton University) · Hao Gong (Princeton University) · Mengdi Wang (Princeton University)
Offline Contextual Bayesian Optimization
Ian Char (Carnegie Mellon University) · Youngseog Chung (Carnegie Mellon University) · Willie Neiswanger (Carnegie Mellon University) · Kirthevasan Kandasamy (Carnegie Mellon University) · Oak Nelson (Princeton Plasma Physics Lab) · Mark Boyer (Princeton Plasma Physics Lab) · Egemen Kolemen (Princeton Plasma Physics Lab) · Jeff Schneider (Carnegie Mellon University)
Game Design for Eliciting Distinguishable Behavior
Fan Yang (Carnegie Mellon University) · Liu Leqi (Carnegie Mellon University) · Yifan Wu (Carnegie Mellon University) · Zachary Lipton (Carnegie Mellon University) · Pradeep Ravikumar (Carnegie Mellon University) · Tom M Mitchell (Carnegie Mellon University) · William Cohen (Google AI)
Optimal Sketching for Kronecker Product Regression and Low Rank Approximation
Huaian Diao (Northeast Normal University) · Rajesh Jayaram (Carnegie Mellon University) · Zhao Song (UT-Austin) · Wen Sun (Microsoft Research) · David Woodruff (Carnegie Mellon University)
Online Learning for Auxiliary Task Weighting for Reinforcement Learning
Xingyu Lin (Carnegie Mellon University) · Harjatin Baweja (CMU) · George Kantor (CMU) · David Held (CMU)
Cost Effective Active Search
Shali Jiang (Washington University in St. Louis) · Roman Garnett (Washington University in St. Louis) · Benjamin Moseley (Carnegie Mellon University)
Mutually Regressive Point Processes
Ifigeneia Apostolopoulou (Carnegie Mellon University) · Scott Linderman (Stanford University) · Kyle Miller (Carnegie Mellon University) · Artur Dubrawski (Carnegie Mellon University)
Efficient Regret Minimization Algorithm for Extensive-Form Correlated Equilibrium
Gabriele Farina (Carnegie Mellon University) · Chun Kai Ling (Carnegie Mellon University) · Fei Fang (Carnegie Mellon University) · Tuomas Sandholm (Carnegie Mellon University)
Optimistic Regret Minimization for Extensive-Form Games via Dilated Distance-Generating Functions
Gabriele Farina (Carnegie Mellon University) · Christian Kroer (Columbia University) · Tuomas Sandholm (Carnegie Mellon University)
Face Reconstruction from Voice using Generative Adversarial Networks
Yandong Wen (Carnegie Mellon University) · Bhiksha Raj (Carnegie Mellon University) · Rita Singh (Carnegie Mellon University)
On Testing for Biases in Peer Review
Ivan Stelmakh (Carnegie Mellon University) · Nihar Shah (CMU) · Aarti Singh (CMU)
Graph Neural Tangent Kernel: Fusing Graph Neural Networks with Graph Kernels
Simon Du (Carnegie Mellon University) · Kangcheng Hou (Zhejiang University) · Ruslan Salakhutdinov (Carnegie Mellon University) · Barnabas Poczos (Carnegie Mellon University) · Ruosong Wang (Carnegie Mellon University) · Keyulu Xu (MIT)
Acceleration via Symplectic Discretization of High-Resolution Differential Equations
Bin Shi (UC Berkeley) · Simon Du (Carnegie Mellon University) · Weijie Su (University of Pennsylvania) · Michael Jordan (UC Berkeley)
XLNet: Generalized Autoregressive Pretraining for Language Understanding
Zhilin Yang (Tsinghua University) · Zihang Dai (Carnegie Mellon University) · Yiming Yang (CMU) · Jaime Carbonell (CMU) · Ruslan Salakhutdinov (Carnegie Mellon University) · Quoc V Le (Google)
Mixtape: Breaking the Softmax Bottleneck Efficiently
Zhilin Yang (Tsinghua University) · Thang Luong (Google) · Ruslan Salakhutdinov (Carnegie Mellon University) · Quoc V Le (Google)
MaCow: Masked Convolutional Generative Flow
Xuezhe Ma (Carnegie Mellon University) · Xiang Kong (Carnegie Mellon University) · Shanghang Zhang (Carnegie Mellon University) · Eduard Hovy (Carnegie Mellon University)
Adaptive Gradient-Based Meta-Learning Methods
Mikhail Khodak (CMU) · Maria-Florina Balcan (Carnegie Mellon University) · Ameet Talwalkar (CMU)
Towards a Zero-One Law for Column Subset Selection
Zhao Song (University of Washington) · David Woodruff (Carnegie Mellon University) · Peilin Zhong (Columbia University)
Dual Adversarial Semantics-Consistent Network for Generalized Zero-Shot Learning
Jian Ni (University of Science and Technology of China) · Shanghang Zhang (Carnegie Mellon University) · Haiyong Xie (University of Science and Technology of China)
Likelihood-Free Overcomplete ICA and Applications in Causal Discovery
Chenwei DING (The University of Sydney) · Mingming Gong (University of Melbourne) · Kun Zhang (CMU) · Dacheng Tao (University of Sydney)
The bias of the sample mean in multi-armed bandits can be positive or negative
Jaehyeok Shin (Carnegie Mellon University) · Aaditya Ramdas (Carnegie Mellon University) · Alessandro Rinaldo (CMU)
Efficient and Thrifty Voting by Any Means Necessary
Debmalya Mandal (Columbia University) · Ariel D Procaccia (Carnegie Mellon University) · Nisarg Shah (University of Toronto) · David Woodruff (Carnegie Mellon University)
Re-examination of the Role of Latent Variables in Sequence Modeling
Guokun Lai (Carnegie Mellon University) · Zihang Dai (Carnegie Mellon University)
Towards Understanding the Importance of Shortcut Connections in Residual Networks
Tianyi Liu (Georgia Institute of Technology) · Minshuo Chen (Georgia Tech) · Mo Zhou (Duke University) · Simon Du (Carnegie Mellon University) · Enlu Zhou (Georgia Institute of Technology) · Tuo Zhao (Gatech)
Learning Local Search Heuristics for Boolean Satisfiability
Emre Yolcu (Carnegie Mellon University) · Barnabas Poczos (Carnegie Mellon University)
Difference Maximization Q-learning: Provably Efficient Q-learning with Function Approximation
Simon Du (Carnegie Mellon University) · Yuping Luo (Princeton University) · Ruosong Wang (Carnegie Mellon University) · Hanrui Zhang (Duke University)
On Exact Computation with an Infinitely Wide Neural Net
Sanjeev Arora (Princeton University) · Simon Du (Carnegie Mellon University) · Wei Hu (Princeton University) · Zhiyuan Li (Princeton University) · Ruslan Salakhutdinov (Carnegie Mellon University) · Ruosong Wang (Carnegie Mellon University)
Paradoxes in Fair Machine Learning
Paul Goelz (Carnegie Mellon University) · Anson Kahng (Carnegie Mellon University) · Ariel D Procaccia (Carnegie Mellon University)
Graph Agreement Models for Semi-Supervised Learning
Otilia Stretcu (Carnegie Mellon University) · Krishnamurthy Viswanathan (Google Research) · Dana Movshovitz-Attias (Google) · Emmanouil Platanios (Carnegie Mellon University) · Sujith Ravi (Google Research) · Andrew Tomkins (Google)
Nonparametric Density Estimation & Convergence Rates for GANs under Besov IPM Losses
Ananya Uppal (Carnegie Mellon University) · Shashank Singh (Carnegie Mellon University) · Barnabas Poczos (Carnegie Mellon University)
Correlation in Extensive-Form Games: Saddle-Point Formulation and Benchmarks
Gabriele Farina (Carnegie Mellon University) · Chun Kai Ling (Carnegie Mellon University) · Fei Fang (Carnegie Mellon University) · Tuomas Sandholm (Carnegie Mellon University)
ADDIS: an adaptive discarding algorithm for online FDR control with conservative nulls
Jinjin Tian (Carnegie Mellon University) · Aaditya Ramdas (Carnegie Mellon University)
Tight Dimensionality Reduction for Sketching Low Degree Polynomial Kernels
Michela Meister (Google) · Tamas Sarlos (Google Research) · David Woodruff (Carnegie Mellon University)
Differentiable Convex Optimization Layers
Akshay Agrawal (Stanford University) · Brandon Amos (Facebook) · Shane Barratt (Stanford University) · Stephen Boyd (Stanford University) · Steven Diamond (Stanford University) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI)
Average Case Column Subset Selection for Entrywise $\ell_1$-Norm Loss
Zhao Song (University of Washington) · David Woodruff (Carnegie Mellon University) · Peilin Zhong (Columbia University)
Efficient Forward Architecture Search
Hanzhang Hu (Carnegie Mellon University) · John Langford (Microsoft Research New York) · Rich Caruana (Microsoft) · Saurajit Mukherjee (Microsoft) · Eric J Horvitz (Microsoft Research) · Debadeepta Dey (Microsoft Research AI)
Efficient Near-Optimal Testing of Community Changes in Balanced Stochastic Block Models
Aditya Gangrade (Boston University) · Praveen Venkatesh (Carnegie Mellon University) · Bobak Nazer (Boston University) · Venkatesh Saligrama (Boston University)
Learning Robust Global Representations by Penalizing Local Predictive Power
Haohan Wang (Carnegie Mellon University) · Songwei Ge (Carnegie Mellon University) · Zachary Lipton (Carnegie Mellon University) · Eric Xing (Petuum Inc. / Carnegie Mellon University)
Unsupervised Curricula for Visual Meta-Reinforcement Learning
Allan Jabri (UC Berkeley) · Kyle Hsu (University of Toronto) · Ben Eysenbach (Carnegie Mellon University) · Abhishek Gupta (University of California, Berkeley) · Alexei Efros (UC Berkeley) · Sergey Levine (UC Berkeley) · Chelsea Finn (Stanford University)
Deep Gamblers: Learning to Abstain with Portfolio Theory
Ziyin Liu (University of Tokyo) · Zhikang Wang (University of Tokyo) · Paul Pu Liang (Carnegie Mellon University) · Ruslan Salakhutdinov (Carnegie Mellon University) · Louis-Philippe Morency (Carnegie Mellon University) · Masahito Ueda (University of Tokyo)
Statistical Analysis of Nearest Neighbor Methods for Anomaly Detection
Xiaoyi Gu (Carnegie Mellon University) · Leman Akoglu (CMU) · Alessandro Rinaldo (CMU)
On the (in)fidelity and sensitivity of explanations
Chih-Kuan Yeh (Carnegie Mellon University) · Cheng-Yu Hsieh (National Taiwan University) · Arun Suggala (Carnegie Mellon University) · David Inouye (Carnegie Mellon University) · Pradeep Ravikumar (Carnegie Mellon University)
Learning Stable Deep Dynamics Models
J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI) · Gaurav Manek (Carnegie Mellon University)
Learning Neural Networks with Adaptive Regularization
Han Zhao (Carnegie Mellon University) · Yao-Hung Tsai (Carnegie Mellon University) · Ruslan Salakhutdinov (Carnegie Mellon University) · Geoffrey Gordon (MSR Montréal & CMU)
Uniform convergence may be unable to explain generalization in deep learning
Vaishnavh Nagarajan (Carnegie Mellon University) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI)
Adversarial Music: Real world Audio Adversary against Wake-word Detection System
Juncheng Li (Carnegie Mellon University) · Shuhui Qu (Stanford University) · Xinjian Li (Carnegie Mellon University) · Joseph Szurley (Bosch Center for Artificial Intelligence) · J. Zico Kolter (Carnegie Mellon University / Bosch Center for AI) · Florian Metze (Carnegie Mellon University)
Neuropathic Pain Diagnosis Simulator for Causal Discovery Algorithm Evaluation
Ruibo Tu (KTH Royal Institute of Technology) · Kun Zhang (CMU) · Bo Bertilson (KI Karolinska Institutet) · Hedvig Kjellstrom (KTH Royal Institute of Technology) · Cheng Zhang (Microsoft)
Triad Constraints for Learning Causal Structure of Latent Variables
Ruichu Cai (Guangdong University of Technology) · Feng Xie (Guangdong University of Technology) · Clark Glymour (Carnegie Mellon University) · Zhifeng Hao (Guangdong University of Technology) · Kun Zhang (CMU)
Kalman Filter, Sensor Fusion, and Constrained Regression: Equivalences and Insights
David Farrow (Carnegie Mellon University) · Maria Jahja (Carnegie Mellon University) · Roni Rosenfeld (Carnegie Mellon University) · Ryan Tibshirani (Carnegie Mellon University)
Specific and Shared Causal Relation Modeling and Mechanism-based Clustering
Biwei Huang (Carnegie Mellon University) · Kun Zhang (CMU) · Pengtao Xie (Petuum / CMU) · Mingming Gong (University of Melbourne) · Eric Xing (Petuum Inc.) · Clark Glymour (Carnegie Mellon University)
Towards modular and programmable architecture search
Renato Negrinho (Carnegie Mellon University) · Matthew Gormley (Carnegie Mellon University) · Geoffrey Gordon (MSR Montréal & CMU) · Darshan Patil (Carnegie Mellon University) · Nghia Le (Carnegie Mellon University) · Daniel Ferreira (TU Wien)
Are Sixteen Heads Really Better than One?
Paul Michel (Carnegie Mellon University, Language Technologies Institute) · Omer Levy (Facebook) · Graham Neubig (Carnegie Mellon University)
Inducing brain-relevant bias in natural language processing models
Dan Schwartz (Carnegie Mellon University) · Mariya Toneva (Carnegie Mellon University) · Leila Wehbe (Carnegie Mellon University)
Differentially Private Covariance Estimation
Kareem Amin (Google Research) · Travis Dick (Carnegie Mellon University) · Alex Kulesza (Google) · Andres Munoz (Google) · Sergei Vassilvitskii (Google)
Missing Not at Random in Matrix Completion: The Effectiveness of Estimating Missingness Probabilities Under a Low Nuclear Norm Assumption
Wei Ma (Carnegie Mellon University) · George Chen (Carnegie Mellon University)
Interpreting and improving natural-language processing (in machines) with natural language-processing (in the brain)
Mariya Toneva (Carnegie Mellon University) · Leila Wehbe (Carnegie Mellon University)
On Human-Aligned Risk Minimization
Liu Leqi (Carnegie Mellon University) · Adarsh Prasad (Carnegie Mellon University) · Pradeep Ravikumar (Carnegie Mellon University)
Search on the Replay Buffer: Bridging Planning and Reinforcement Learning
Ben Eysenbach (Carnegie Mellon University) · Ruslan Salakhutdinov (Carnegie Mellon University) · Sergey Levine (UC Berkeley)
Multiple Futures Prediction
Charlie Tang (Apple Inc.) · Ruslan Salakhutdinov (Carnegie Mellon University)
Neural Taskonomy: Inferring the Similarity of Task-Derived Representations from Brain Activity
Aria Y Wang (Carnegie Mellon University) · Leila Wehbe (Carnegie Mellon University) · Michael J Tarr (Carnegie Mellon University)
Inherent Tradeoffs in Learning Fair Representation
Han Zhao (Carnegie Mellon University) · Geoff Gordon (Microsoft)
Learning Data Manipulation for Augmentation and Weighting
Zhiting Hu (Carnegie Mellon University) · Bowen Tan (CMU) · Ruslan Salakhutdinov (Carnegie Mellon University) · Tom Mitchell (Carnegie Mellon University) · Eric Xing (Petuum Inc. / Carnegie Mellon University)
Figure 1: Comparison of existing algorithms without policy certificates (top) and with our proposed policy certificates (bottom). While in existing reinforcement learning the user has no information about how well the algorithm will perform in the next episode, we propose that algorithms output policy certificates before playing an episode to allow users to intervene if necessary.
Designing reinforcement learning methods that find a good policy with as few samples as possible is a key goal of both empirical and theoretical research. On the theoretical side, there are two main ways to measure and guarantee the sample efficiency of a method: regret bounds and PAC (probably approximately correct) bounds. Ideally, we would like algorithms that perform well according to both criteria, as they measure different aspects of sample efficiency, and we have shown previously that one cannot simply convert one into the other. In a specific setting called tabular episodic MDPs, a recent algorithm achieved close-to-optimal regret bounds, but no method was known to be close to optimal according to the PAC criterion, despite a long line of research. In our work presented at ICML 2019, we close this gap with a new method that achieves minimax-optimal PAC (and regret) bounds, matching the statistical worst-case lower bounds in the dominating terms.
Interestingly, we achieve this by addressing a general issue of PAC and regret bounds: they do not reveal when an algorithm will potentially take bad actions, only (for example) how often. This issue leads to a lack of accountability that could be particularly problematic in high-stakes applications (see a motivating scenario in Figure 2).
Besides being sample-efficient, our algorithm does not suffer from this lack of accountability because it outputs what we call policy certificates. A policy certificate is a confidence interval, output by the algorithm before each episode, that contains both the expected return of the algorithm’s policy in that episode and the optimal return (see Figure 1). This information allows users of our algorithm to intervene if the certified performance is not deemed adequate. We accompany the algorithm with a new type of learning guarantee called IPOC, which is stronger than PAC, regret, and the recent Uniform-PAC guarantees, as it ensures not only sample efficiency but also the tightness of policy certificates. We primarily consider the simple tabular episodic setting, where there are only a small number of possible states and actions. While this is often not the case in practical applications, we believe the insights developed in this work can be used to design more sample-efficient and accountable reinforcement learning methods for challenging real-world problems with rich observations such as images or text.
We propose to make methods for episodic reinforcement learning more accountable by having them output a policy certificate before each episode. A policy certificate is a confidence interval \([l_k, u_k]\) where \(k\) is the episode index. This interval contains both the expected sum of rewards of the algorithm’s policy in the next episode and the optimal expected sum of rewards in the next episode (see Figure 1 for an illustration). As such, a policy certificate helps answer two questions which are of interest in many applications:
Policy certificates are only useful if these confidence intervals are not too loose. To ensure this, we introduce a type of guarantee for algorithms with policy certificates: IPOC (Individual POlicy Certificates) bounds. These bounds guarantee that all certificates are valid confidence intervals and bound the number of times their length can exceed any given threshold. IPOC bounds guarantee both the sample efficiency of policy learning and the accuracy of policy certificates. That means the algorithm has to play better and better policies but also needs to tell us more accurately how good these policies are. IPOC bounds are stronger than existing learning bounds such as PAC or regret (see Figure 3) and imply that the algorithm is anytime interruptible (see paper for details).
Policy certificates are not limited to specific types of algorithms, but optimistic algorithms are particularly natural to extend with policy certificates. These methods give us the upper end of certificates “for free”: they maintain an upper confidence bound \(\tilde Q(s,a)\) on the optimal value function \(Q^\star(s,a)\) and follow the greedy policy \(\pi\) with respect to this upper confidence bound. In similar fashion, we can compute a lower confidence bound \(\underset{\sim}{Q}(s,a)\) on the Q-function \(Q^\pi(s,a)\) of this greedy policy. The certificate for this policy is then just these confidence bounds evaluated at the initial state \(s_1\) of the episode: \([l_k, u_k] = \left[ \underset{\sim}{Q}(s_1, \pi(s_1)), \tilde Q(s_1, \pi(s_1))\right]\).
We demonstrate this principle with a new algorithm called ORLC (Optimistic RL with Certificates) for tabular MDPs. Similar to existing optimistic algorithms like UCBVI and UBEV, it computes the confidence bounds \(\tilde Q\) by optimistic value iteration on an estimated model but also computes lower confidence bounds \(\underset{\sim}{Q}\) with a pessimistic version of value iteration. These procedures are similar to vanilla value iteration but add optimism bonuses or subtract pessimism bonuses in each time step respectively to ensure high confidence bounds.
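As a minimal sketch of this idea (not the paper's exact algorithm: we substitute simplified Hoeffding-style bonuses for ORLC's tighter bonus terms, and the function names are ours), the two value-iteration passes might look like:

```python
import numpy as np

def confidence_bounds(P_hat, r_hat, n, H, delta=0.05):
    """Optimistic and pessimistic value iteration on an estimated tabular MDP.

    P_hat: (S, A, S) estimated transition probabilities
    r_hat: (S, A) estimated mean rewards in [0, 1]
    n:     (S, A) visit counts, used to size the confidence bonuses
    Returns stage-wise upper bounds on Q* and lower bounds on Q^pi,
    where pi is the greedy policy w.r.t. the upper bounds.
    """
    S, A = r_hat.shape
    Q_up = np.zeros((H + 1, S, A))   # optimistic bounds (value 0 after episode end)
    Q_lo = np.zeros((H + 1, S, A))   # pessimistic bounds
    for h in range(H - 1, -1, -1):
        # simplified Hoeffding-style bonus; shrinks as visit counts grow
        bonus = (H - h) * np.sqrt(np.log(2.0 / delta) / np.maximum(n, 1))
        V_up = Q_up[h + 1].max(axis=1)              # optimistic next-state values
        pi_next = Q_up[h + 1].argmax(axis=1)        # greedy policy w.r.t. upper bound
        V_lo = Q_lo[h + 1][np.arange(S), pi_next]   # pessimistic value of the SAME policy
        Q_up[h] = np.clip(r_hat + P_hat @ V_up + bonus, 0.0, H - h)
        Q_lo[h] = np.clip(r_hat + P_hat @ V_lo - bonus, 0.0, H - h)
    return Q_up, Q_lo

def certificate(Q_up, Q_lo, s1):
    """Policy certificate [l_k, u_k] for an episode starting in state s1."""
    a = Q_up[0, s1].argmax()                        # greedy first action
    return Q_lo[0, s1, a], Q_up[0, s1, a]
```

Note that the lower-bound pass evaluates the same greedy policy that the upper-bound pass induces, so the resulting interval brackets that policy's return.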
Interestingly, we found that computing lower confidence bounds for policy certificates can also improve sample-efficiency of policy learning. More concretely, we could tighten the optimism bonuses in our tabular method ORLC using the lower bounds \(\underset{\sim}{Q}\). This makes the algorithm less conservative and able to adjust more quickly to observed data. As a result, we were able to prove the first PAC bounds for tabular MDPs that are minimax-optimal in the dominating term:
Theorem: Minimax IPOC Mistake, PAC and regret bound of ORLC
In any episodic MDP with S states, A actions and an episode length H, the algorithm ORLC satisfies the IPOC Mistake bound below. That is, with probability at least \(1-\delta\), all certificates are valid confidence intervals and for all \(\epsilon > 0\) ORLC outputs certificates larger than \(\epsilon\) in at most
$$\tilde O\left( \frac{S A H^2}{\epsilon^2}\ln \frac 1 \delta + \frac{S^2 A H^3}{\epsilon}\ln \frac 1 \delta \right)$$
episodes. This immediately implies that the bound above is a (Uniform-)PAC bound and that ORLC satisfies, for any number of episodes \(T\), a high-probability regret bound of
$$\tilde O\left( \sqrt{SAH^2 T} \ln(1/\delta) + S^2 A H^3 \ln(T / \delta) \right).$$
Comparing the order of our PAC bounds against the statistical lower bounds and prior state-of-the-art PAC and regret bounds in the table below, this is the first time the optimal polynomial dependency \(SAH^2\) has been achieved in the dominating \(\epsilon^{-2}\) term. Our bounds also improve the prior regret bounds of UCBVI-BF by avoiding their \(\sqrt{H^3T}\) terms, making our bounds minimax-optimal even when the episode length \(H\) is large.
| Algorithm | (mistake) PAC bound | Regret bound | IPOC Mistake bound |
| --- | --- | --- | --- |
| Lower bounds | \( \frac{SAH^2}{\epsilon^2} \) | \( \sqrt{H^2 S A T} \) | \( \frac{SAH^2}{\epsilon^2} \) |
| ORLC (ours) | \( \frac{SAH^2}{\epsilon^2} + \frac{S^2 A H^3}{\epsilon} \) | \( \sqrt{H^2 S A T} + S^2 A H^3 \) | \( \frac{SAH^2}{\epsilon^2} + \frac{S^2 A H^3}{\epsilon} \) |
| UCBVI | – | \( \sqrt{H^2 S A T} + \sqrt{H^3 T} + S^2 A H^2 \) | – |
| UBEV | \( \frac{SAH^3}{\epsilon^2} + \frac{S^2 A H^3}{\epsilon} \) | \( \sqrt{H^3 S A T} + S^2 A H^3 \) | – |
| UCFH | \( \frac{S^2AH^2}{\epsilon^2} \) | – | – |
As mentioned above, our algorithm achieves this new IPOC guarantee and improved PAC bounds by maintaining a lower confidence bound \(\underset{\sim}{Q}(s,a)\) on the Q-function \(Q^\pi(s,a)\) of its policy at all times, in addition to the usual upper confidence bound \(\tilde Q(s,a)\) on the optimal value function \(Q^\star(s,a)\). Deriving tight lower confidence bounds \(\underset{\sim}{Q}(s,a)\) requires new techniques compared to those for upper confidence bounds. All recent optimistic algorithms for tabular MDPs leverage, for their upper confidence bounds, the fact that \(Q^\star\) does not depend on the samples: the optimal Q-function is always the same, no matter what samples the algorithm saw. We cannot leverage the same insight for our lower confidence bounds, because the Q-function of the current policy, \(Q^\pi\), does depend on the samples the algorithm saw; after all, the policy \(\pi\) is computed as a function of these samples. We develop a technique that deals with this challenge by explicitly incorporating both upper and lower confidence bounds in our bonus terms. It turns out that this technique yields not only tighter lower confidence bounds but also tighter upper confidence bounds, which is the key to our improved PAC and regret bounds.
Our work provided the final ingredient for PAC bounds for episodic tabular MDPs that are minimax-optimal up to lower-order terms and also established the foundation for policy certificates. In the full paper, we also considered more general MDPs and designed a policy certificate algorithm for so-called finite MDPs with linear side information. This is a generalization of the popular linear contextual bandit setting and requires function approximation. In the future, we plan to investigate policy certificates as a useful empirical tool for deep reinforcement learning techniques and examine whether the specific form of optimism bonuses derived in this work can inspire more sample-efficient exploration bonuses in deep RL methods.
This post is also featured on the Stanford AIforHI blog and is based on work in the following paper:
Christoph Dann, Lihong Li, Wei Wei, Emma Brunskill
Policy Certificates: Towards Accountable Reinforcement Learning
International Conference on Machine Learning (ICML) 2019
DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.
Automated decision-making is one of the core objectives of artificial intelligence. Not surprisingly, over the past few years, entire new research fields have emerged to tackle this task. This blog post is concerned with regret minimization, one of the central tools in online learning. Regret minimization models the problem of repeated online decision making: an agent must make a sequence of decisions under unknown (and potentially adversarial) loss functions. Regret minimization is a versatile mathematical abstraction that has found a plethora of practical applications: portfolio optimization, computation of Nash equilibria, applications to markets and auctions, submodular function optimization, and more.
In this blog post, we will be interested in showing how one can compose regret-minimizing agents—or regret minimizers for short. In other words, suppose that you are given a regret minimizer that can output good decisions on a set \(\mathcal{X}\) and another regret minimizer that can output good decisions on a set \(\mathcal{Y}\). We show how you can combine them to build a good regret minimizer for a composite set obtained from \(\mathcal{X}\) and \(\mathcal{Y}\)—for example their Cartesian product, their convex hull, their intersection, and so on. Our approach treats the two regret minimizers, one for \(\mathcal{X}\) and one for \(\mathcal{Y}\), as black boxes. This is tricky: we combine them without opening the box, so we must account for the possibility of having to combine very different regret minimizers. On the other hand, the benefit is that we are free to pick the best regret minimizer for each individual set. This is important: in an extensive-form game, for example, we might know how to build specialized regret minimizers for different parts of the game, and our framework shows how to combine them into a composite regret minimizer that can handle the whole game. All material is based on a recent paper that appeared at ICML 2019.
Toward the end of the blog post, I will give several applications of this calculus. It enables several things that were not possible before. It also gives a significantly simpler proof of counterfactual regret minimization (CFR), the state-of-the-art scalable method for computing Nash equilibria in large extensive-form games. The whole exact CFR algorithm falls out naturally, almost trivially, from our calculus.
A regret minimizer is an abstraction of a repeated decision-maker. One way to think about a regret minimizer is as a device that supports two operations:
The decision making is online, in the sense that each decision is made by only taking into account the past decisions and their corresponding loss functions; no information about future losses is available to the regret minimizer at any time. For the rest of the post, we focus on linear losses, that is \(\mathcal{F} = \mathcal{L}\) where \(\mathcal{L}\) denotes the set of all linear functions with domain \(\mathcal{X}\).
The quality metric for a regret minimizer is its cumulative regret. Intuitively, it measures how well the regret minimizer did against the best fixed decision in hindsight. We can formalize this idea mathematically as the difference between the loss that was accumulated, \(\sum_{t=1}^T \ell^t(\mathbf{x}^t)\), and the minimum possible cumulative loss, \(\min_{\hat{\mathbf{x}}\in\mathcal{X}} \sum_{t=1}^T \ell^t(\hat{\mathbf{x}})\). In formulas, the cumulative regret up to time \(T\) is defined as $$\displaystyle R^T := \sum_{t=1}^T \ell^t(\mathbf{x}^t) - \min_{\hat{\mathbf{x}} \in \mathcal{X}} \sum_{t=1}^T \ell^t(\hat{\mathbf{x}}).$$
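To make the abstraction concrete, here is a sketch of one well-known regret minimizer, regret matching on the probability simplex; the interface names (`next_decision`, `observe_loss`) are our own:

```python
import numpy as np

class RegretMatching:
    """Regret matching on the probability simplex: play each action with
    probability proportional to its positive cumulative regret."""

    def __init__(self, dim):
        self.cum_regret = np.zeros(dim)       # per-action cumulative regret
        self.last = np.full(dim, 1.0 / dim)   # most recent decision

    def next_decision(self):
        pos = np.maximum(self.cum_regret, 0.0)
        total = pos.sum()
        # if no action has positive regret, play uniformly
        self.last = pos / total if total > 0 else np.full(len(pos), 1.0 / len(pos))
        return self.last

    def observe_loss(self, loss):
        # regret of each pure action relative to the mixed decision just played
        self.cum_regret += self.last @ loss - loss
```

With linear losses represented as vectors, the best fixed decision in hindsight is always a vertex of the simplex, so the cumulative regret \(R^T\) defined above reduces to a comparison against the best single action; for regret matching it grows as \(O(\sqrt{T})\).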
“Good” regret minimizers, also called Hannan-consistent regret minimizers, are those whose cumulative regret grows sublinearly as a function of \(T\). Several good, general-purpose regret minimizers are known in the literature. Some of them, like follow-the-regularized-leader, online mirror descent, and online (projected) gradient descent, work for any convex domain \(\mathcal{X}\). Others are tailored to specific domains, such as regret matching and regret matching plus, both of which are specifically designed for the case in which \(\mathcal{X}\) is a (probability) simplex. However, these general-purpose regret minimizers typically come with two drawbacks:
Given the drawbacks of the traditional approaches, we started to wonder about different ways to construct regret minimizers, until we stumbled upon an intriguing thought: can we construct regret minimizers for composite sets by combining regret minimizers for the individual atoms? The answer is yes.
Let’s start from a simple example. Suppose we have a regret minimizer that outputs decisions on a convex set \(\mathcal{X}\), and another regret minimizer that outputs decisions on a convex set \(\mathcal{Y}\). How can we combine them to obtain a regret minimizer for their Cartesian product \(\mathcal{X} \times \mathcal{Y}\)? The natural idea, in this case, is to let the two regret minimizers operate independently:
This process is represented pictorially in Figure 2. We coin this type of pictorial representation a “regret circuit”.
Some simple algebra shows that, at all times \(T\), our strategy guarantees that the cumulative regret \(R^T\) of the composite regret minimizer (as seen from outside of the gray dashed box) satisfies \(R^T = R_\mathcal{X}^T + R_\mathcal{Y}^T\), where \(R_\mathcal{X}^T\) and \(R_\mathcal{Y}^T\) are the cumulative regrets of the regret minimizers for domains \(\mathcal{X}\) and \(\mathcal{Y}\), respectively. Hence, if both of those regret minimizers are “good” (Hannan consistent), then so is the composite regret minimizer.
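The Cartesian-product circuit can be sketched in a few lines, treating the two inner regret minimizers as black boxes (the interface names are assumed, not prescribed by the paper):

```python
class CartesianProductRM:
    """Regret circuit for the Cartesian product X x Y: the two inner regret
    minimizers simply operate independently."""

    def __init__(self, rm_x, rm_y):
        self.rm_x, self.rm_y = rm_x, rm_y

    def next_decision(self):
        # the composite decision is just the pair of inner decisions
        return self.rm_x.next_decision(), self.rm_y.next_decision()

    def observe_loss(self, loss_x, loss_y):
        # a linear loss on X x Y decomposes into its X- and Y-components,
        # which are forwarded unchanged to the inner regret minimizers
        self.rm_x.observe_loss(loss_x)
        self.rm_y.observe_loss(loss_y)
```

Because each inner regret minimizer sees exactly the losses it would see in isolation, the composite regret is the sum of the inner regrets.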
What about convex hulls? It turns out that this is much trickier! We can try to reuse the same approach as before: we ask the two regret minimizers, one for \(\mathcal{X}\) and one for \(\mathcal{Y}\), to independently output decisions. But now we face a dilemma: how should we form a convex combination of the two decisions?
In this case, the regret circuit is shown in Figure 3.
If the loss function \(\ell_{\lambda}^{t-1}\) that enters the extra regret minimizer is set up correctly, and if all three internal regret minimizers are good, one can prove that the composite regret minimizer, as seen from outside of the gray dashed box, is also a good regret minimizer. In particular, a natural way to define \(\ell_{\lambda}^{t}\) is as
\[
\ell^t_\lambda : \Delta^{2} \ni (\lambda_1,\lambda_2) \mapsto \lambda_1 \ell^t(\mathbf{x}^t) + \lambda_2\ell^t(\mathbf{y}^t),
\]
which can be seen as a form of counterfactual loss function.
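A sketch of the convex-hull circuit: a third regret minimizer on the 2-simplex learns the mixing weights \(\lambda\), and it is fed exactly the counterfactual loss \(\ell^t_\lambda\) defined above (the interface names are assumed; any simplex regret minimizer can play the role of the \(\lambda\)-minimizer):

```python
import numpy as np

class ConvexHullRM:
    """Regret circuit for the convex hull of X and Y: two inner regret
    minimizers propose decisions, and a simplex regret minimizer mixes them."""

    def __init__(self, rm_x, rm_y, rm_lambda):
        self.rm_x, self.rm_y, self.rm_lambda = rm_x, rm_y, rm_lambda

    def next_decision(self):
        self.x = np.asarray(self.rm_x.next_decision(), dtype=float)
        self.y = np.asarray(self.rm_y.next_decision(), dtype=float)
        self.lam = np.asarray(self.rm_lambda.next_decision(), dtype=float)
        # output the convex combination chosen by the lambda-minimizer
        return self.lam[0] * self.x + self.lam[1] * self.y

    def observe_loss(self, ell):
        # ell is a linear loss represented as a vector: ell(z) = <ell, z>;
        # the inner minimizers see it unchanged...
        self.rm_x.observe_loss(ell)
        self.rm_y.observe_loss(ell)
        # ...while the lambda-minimizer sees the counterfactual loss
        # (lambda_1, lambda_2) -> lambda_1 * ell(x^t) + lambda_2 * ell(y^t)
        self.rm_lambda.observe_loss(np.array([ell @ self.x, ell @ self.y]))
```

The \(\lambda\)-minimizer thus competes against the better of the two inner decisions in hindsight, which is what makes the composite regret sublinear when all three inner regret minimizers are good.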
It turns out that the two regret circuits we’ve seen so far—one for the Cartesian product and one for the convex hull of two sets—are already enough to give a very natural proof of the counterfactual regret minimization (CFR) framework, a family of regret minimizers specifically tailored to extensive-form games. CFR has been the de facto state of the art for the past 10+ years for computing approximate Nash equilibria in large games, and has been one of the key technologies that enabled solving large Heads-Up Limit and No-Limit Texas Hold’em. The basic intuition is as follows (all details are in our paper). Consider, for example, the sequential action space of the first player in the game of Kuhn poker (Figure 4, left):
In other words, we can represent the strategy of the player by composing convex hulls and Cartesian products, following the structure of the game (Figure 4, right).
Since we can express the set of strategies in the game by composing convex hulls and Cartesian products, it should now be clear how our framework assists us in constructing a regret minimizer for this domain.
After having seen Cartesian products and convex hulls, a natural question is: what about intersections and constraint satisfaction? In this case, we assume access to a good regret minimizer for a domain \(\mathcal{X}\), and we want to construct a good regret minimizer for the curtailed set \(\mathcal{X} \cap \mathcal{Y}\).
It turns out that, in general, these constraining operations are more costly than “enlarging” operations such as convex hull, Minkowski sums, Cartesian products, etc. In the paper, we show two different circuits:
The main idea for both circuits is to let the regret minimizer for \(\mathcal{X}\) output decisions, and then penalize infeasible choices by injecting extra penalization terms into the loss functions that enter the regret minimizer for \(\mathcal{X}\). In the case of the circuit that guarantees feasibility, the decisions are also projected onto \(\mathcal{X} \cap \mathcal{Y}\) before they are output by the composite regret minimizer. Figure 5 shows the resulting regret circuits, where \(d\) is the distance-generating function used in the projection (for example, a good choice could be \(d(\mathbf{x}) = \|\mathbf{x}\|^2_2\)), and \(\alpha^t\) is a penalization coefficient.
While we are not interested here in all the details of this circuit, we point out an interesting observation: the regret circuit is a constructive proof that we can always turn an infeasible regret minimizer into a feasible one by projecting onto the feasible set, outside the loop!
Armed with these new intersection circuits, we can show that the recent Constrained CFR algorithm can be constructed as a special example via our framework. Our exact (feasible) intersection construction leads to a new algorithm for the same problem as well.
Another application is in the realm of optimistic/predictive regret minimization. This is a recent subfield of online learning, whose techniques can be used to break the learning-theoretic barrier \(O(T^{-1/2})\) on the convergence rate of regret-based approaches to saddle points (for example, Nash equilibria). In a different ICML 2019 paper, we used our calculus to prove that, under certain hypotheses, CFR can be modified to have a convergence rate of \(O(T^{-3/4})\) to Nash equilibrium, instead of \(O(T^{-1/2})\) as in the original (non-optimistic) version.
Regret circuits have already proved useful in several applications, mostly in game theory. The fact that we can combine potentially very different regret minimizers as black boxes is very appealing, because it lets us choose the best algorithm for each set being composed and conquer different parts of the design space with different techniques. In the paper, we show regret circuits for several convexity-preserving operations, including convex hull, Cartesian product, affine transformations, intersections, and Minkowski sums. However, several questions remain open:
Figure 1: Visualization of supervised neighborhoods for local explanation with MAPLE. When seeing the new point \(X = (1, 0, 1)\), this tree determines that \(X_2\) and \(X_6\) are its neighbors and gives them weight 1 and gives all other points weight 0. MAPLE averages these weights across all the trees in the ensemble.
Machine learning is increasingly used to make critical decisions such as a doctor’s diagnosis, a biologist’s experimental design, and a lender’s loan decision. In these areas, mistakes can be the difference between life and death, can lead to wasted time and money, and can have serious legal consequences.
Because of the serious potential ramifications of using machine learning in these domains, it falls to machine learning practitioners to ensure that their models are robust and to foster trust with the people who interact with them. Broadly speaking, meeting these two goals is the objective of interpretability, achieved by iteratively: explaining both the global and local behavior of a model (increasing understanding), checking that these explanations make sense (developing trust), and fixing any identified problems (preventing bad failures).
Meeting these two goals is a very difficult task and interpretability faces many challenges, but we will be focusing on two in particular:
Our proposed method, MAPLE, couples classical local linear modeling techniques with a dual interpretation of tree ensembles (which aggregate the predictions of multiple decision trees), both as a supervised neighborhood approach and as a feature selection method (see Fig. 1). By doing this, we are able to slightly improve accuracy while producing multiple types of explanations.
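A numpy-only sketch of the supervised-neighborhood step (simplified: the full method also performs feature selection, and in practice the leaf assignments would come from a trained tree ensemble, e.g. scikit-learn's `RandomForestRegressor.apply`; the function names here are ours):

```python
import numpy as np

def maple_weights(train_leaves, x_leaves):
    """MAPLE-style supervised neighborhood weights (simplified sketch).

    train_leaves: (n_train, n_trees) leaf id of each training point in each tree
    x_leaves:     (n_trees,) leaf id of the query point in each tree
    A training point's weight is the fraction of trees in which it shares
    a leaf with the query point (averaged exactly as in Fig. 1).
    """
    return (train_leaves == x_leaves).mean(axis=1)

def local_linear_explanation(X, y, w):
    """Weighted least squares on the supervised neighborhood: the returned
    coefficients form the local linear explanation at the query point."""
    Xb = np.hstack([np.ones((len(X), 1)), X])   # add intercept column
    sw = np.sqrt(w)[:, None]                    # sqrt weights for least squares
    coef, *_ = np.linalg.lstsq(Xb * sw, y * np.sqrt(w), rcond=None)
    return coef                                 # [intercept, slopes...]
```

The weighted linear fit plays the role of the local model whose coefficients are reported as the explanation.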
Before diving into the technical details of MAPLE and how it works as an interpretability system (both for explaining its own predictions and for explaining the predictions of another model), we provide an overview and comparison of the main types of explanations.
At a high level, there are three main types of explanations: local explanations, global explanations, and example-based explanations.
Example-based explanations are clearly distinct from the other two explanation types, as the former relies on sample data points and the latter two on features. Furthermore, local and global explanations themselves capture fundamentally different characteristics of the predictive model. To see this, consider the toy datasets in Fig. 2 generated from three univariate functions.
Figure 2: Toy datasets from left to right (a) Linear (b) Shifted Logistic (c) Step Function.
Generally, local explanations are better suited for modeling smooth continuous effects (Fig. 2a). For discontinuous effects (Fig. 2c) or effects that are very strong in a small region (Fig. 2b), local explanations either fail to detect the effect or make unusual predictions, depending on how the local neighborhood is defined (i.e., whether or not it is defined in a supervised manner, more on this in the ‘Supervised vs Unsupervised Neighborhood’ section). We will call such effects global patterns because they are difficult to detect or model with local explanations.
Conversely, global explanations are less effective at explaining continuous effects and more effective at explaining global patterns. This is because they tend to be rule-based models that use feature discretization or binning. This processing doesn’t lend itself easily to modeling continuous effects (you need many small steps to approximate a linear model well) but does lend itself towards modeling the abrupt changes around global patterns (because those effects create natural cut-offs for the feature discretization or binning).
Most real datasets have both continuous and discontinuous effects and, therefore, it is crucial to devise explanation systems that can capture, or are at least aware of, both types of effects.
We focus on local explanations in this work because they are both actionable and relevant to the people impacted by machine learning systems: they answer the question “what could I have done differently to get the desired outcome?”, and it is not particularly helpful to a person to know how the model behaves for an entirely different person.
The goal of a local explanation, \(g\), is to approximate our learned model, \(f\), well across some neighborhood of the input space, \(N_x\). Naturally, this leads to the fidelity metric: \(E_{x' \sim N_x}[(g(x') - f(x'))^2]\). The choices of \(g\) and \(N_x\) are important and should often be problem specific. Similar to previous work, we assume that \(g\) is a linear function.
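For concreteness, here is a minimal numpy sketch of estimating this fidelity metric with a Gaussian neighborhood centered on \(x\) (one common choice); the function names, the neighborhood width, and the toy functions are illustrative, not from the paper:

```python
import numpy as np

def fidelity(f, g, x, sigma=0.1, n_samples=1000, rng=None):
    """Monte Carlo estimate of E_{x' ~ N_x}[(g(x') - f(x'))^2]
    with N_x a Gaussian neighborhood centered on x."""
    rng = np.random.default_rng(rng)
    x_prime = x + sigma * rng.standard_normal((n_samples, x.shape[0]))
    return np.mean((g(x_prime) - f(x_prime)) ** 2)

# Example: explain f(x) = x^2 near x = 1 with the tangent line g(x) = 2x - 1.
f = lambda X: (X ** 2).sum(axis=1)
g = lambda X: 2 * X.sum(axis=1) - 1
score = fidelity(f, g, np.array([1.0]), sigma=0.1)
# A small value indicates the linear explanation is locally faithful.
```

A lower score means the linear explanation tracks the model more closely over the chosen neighborhood.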
Figure 3: A simple way of generating a local explanation that is very similar to LIME. From left to right, 1) Start with a point that you want to explain, 2) Define a neighborhood around that point, 3) Sample points from that neighborhood, and 4) Fit a linear model to the model’s predictions at those sampled points
MAPLE (Plumb et al. 2018) modifies tree ensembles to produce local explanations that are able to detect global patterns and to produce example-based explanations; these modifications are built on work from (A. Bloniarz et al. 2016) and (S. Kazemitabar et al. 2017). Importantly, we find that doing this typically improves the predictive accuracy of the model and that the resulting local explanations have high fidelity.
At a high level, MAPLE uses the tree ensemble to identify which training points are most relevant to a new prediction and uses those points to fit a linear model that is used both to make a prediction and as a local explanation. We will now make this more precise.
Given training data \((x_i, y_i)\) for \(i= 1, \ldots, n\), we start by training an ensemble of trees on this data, \(T_i\) for \(i= 1, \ldots, K\). For a point \(x\), let \(T_k(x)\) be the index of the leaf node of \(T_k\) that contains \(x\). Suppose that we want to make a prediction at \(x\) and also give an explanation for that prediction.
To do this, we start by assigning a similarity weight to each training point, \(x_i\), based on how often the trees put \(x_i\) and \(x\) in the same leaf node. So we define \(w_i = \frac{1}{K} \sum_{j=1}^K \mathbb{I}[T_j(x_i) = T_j(x)]\). This is how MAPLE produces example-based explanations; training points with a larger \(w_i\) will be more relevant to the prediction/explanation at \(x\) than training points with smaller weights. An example of this process for a single tree is shown in Fig. 1.
To actually make a prediction/explanation, we solve the weighted linear regression problem \(\hat\beta_x = \text{argmin}_\beta \sum_{i=1}^n w_i (\beta^T x_i - y_i)^2\). Then MAPLE makes the prediction \(f_{MAPLE}(x) = \hat\beta_x^T x\) and gives the local explanation \(\hat\beta_x\). Because the \(w_i\) depend on the training data (i.e., the most relevant points depend on the labels \(y_i\)), we say that \(\hat\beta_x\) uses a supervised neighborhood.
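The two steps above (leaf co-membership weights, then a weighted least-squares fit) can be sketched in a few lines of Python. This is an illustrative re-implementation using scikit-learn's `RandomForestRegressor`, not the authors' released code; the dataset and names are made up:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

# Toy data: a linear function of two features plus noise.
rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(200, 2))
y = 3 * X[:, 0] - 2 * X[:, 1] + 0.1 * rng.standard_normal(200)

forest = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)
train_leaves = forest.apply(X)            # shape (n, K): leaf index per tree

def maple_explain(x):
    x = np.asarray(x).reshape(1, -1)
    leaves = forest.apply(x)              # shape (1, K)
    # w_i = fraction of trees placing x_i in the same leaf as x
    w = (train_leaves == leaves).mean(axis=1)
    X1 = np.hstack([np.ones((len(X), 1)), X])   # add intercept column
    # Weighted least squares: argmin_b sum_i w_i (b^T x_i - y_i)^2
    beta, *_ = np.linalg.lstsq(np.sqrt(w)[:, None] * X1,
                               np.sqrt(w) * y, rcond=None)
    prediction = beta[0] + beta[1:] @ x.ravel()
    return prediction, beta               # beta is the local explanation

pred, beta = maple_explain([0.5, -0.5])
```

The coefficients in `beta` serve as the local explanation, and the training points with the largest `w` serve as the example-based explanation.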
When LIME defines its local explanations, it optimizes for the fidelity-metric with \(N_x\) set as a probability distribution centered on \(x\). So we say it uses an unsupervised neighborhood. As mentioned earlier, the behavior of local explanations around global patterns depends on whether or not they use a supervised or unsupervised neighborhood.
Why don’t unsupervised neighborhoods detect global patterns? Near a global pattern, an unsupervised neighborhood will sample points on either side of it. Consequently, if the explanation is linear, it will smooth the global pattern (i.e., fail to detect it). Importantly, the only indication that something might be awry is that the explanation will have lower fidelity.
Although sometimes this smoothing is a good enough approximation, it would be better if the explanation detected the global pattern. For example, if we interpret Fig. 2b as the probability of giving someone a loan as their income increases, we can see that smoothing the global effect causes the explanation to give overly optimistic advice.
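As a toy numpy illustration of this smoothing (not from the paper), fitting a single linear model across a step function reports a moderate positive slope everywhere, rather than detecting a jump at the discontinuity and a slope of zero elsewhere:

```python
import numpy as np

# Step function like Fig. 2c: 0 for x < 0, 1 for x > 0.
rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, 200)
f = (x > 0).astype(float)

# An unsupervised neighborhood around x0 = 0 samples both sides of the
# step, so a linear fit "smooths" the discontinuity away.
slope, intercept = np.polyfit(x, f, 1)
# slope is a moderate positive number -- the global pattern is lost.
```

The only hint that something is wrong is that this fit has high error (low fidelity) near the step.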
How are supervised neighborhoods different? On the other hand, supervised neighborhoods will tend to sample points only on one side of the global pattern and consequently will not smooth it. For example, in Fig. 2c, MAPLE will predict a slope of zero at almost all points because the function is flat across each one of its three learned neighborhoods.
But this clearly is also not a desirable behavior since it would imply that this feature does not matter for the prediction. Consequently, we introduce a technique to determine if a coefficient is zero/small because it does not matter or if it is zero/small because it is near a global pattern.
We do this by examining the probability distribution over the features induced by the weights, \(w_i\), and training points, \(x_i\), and determining where the explanation can be applied. Note that this distribution is defined using the weights learned by MAPLE. When a point is near a global pattern, this distribution becomes skewed and we can detect it. A brief example is shown below in Fig. 4 (see the paper for complete details).
Figure 4: An example of the local neighborhoods learned by MAPLE as we perform a grid search across the active feature of each of the toy datasets from Fig. 2. Notice that we can detect the strong effect by the small neighborhood in the steep region of the logistic curve (middle) and the discontinuities in the step function (right).
In summary, by using the local training distribution that MAPLE learns around a point, we can determine whether or not that point is near a global pattern.
When evaluating the effectiveness of MAPLE, there are three main questions:
We evaluated these questions on several UCI datasets [Dheeru 2017] and will summarize our results here (for full details, see the paper).
1. Do we sacrifice accuracy to gain interpretability? No, in fact MAPLE is almost always more accurate than the tree ensemble it is built on.
2. How well do its local explanations explain its own predictions? When comparing MAPLE’s local explanation to an explanation fit by LIME to explain the predictions made by MAPLE, MAPLE produces substantially better explanations (as measured by the fidelity metric).
This is not surprising since this is asking MAPLE to explain itself, but it does indicate that MAPLE is an improvement on tree ensembles in terms of both accuracy and interpretability.
3. How well can it explain a black-box model? When we use MAPLE or LIME to explain a black-box model (in this case a Support Vector Regression model), MAPLE often produces better explanations (again, measured by the fidelity metric).
By using leaf node membership as a form of supervised neighborhood selection, MAPLE is able to modify tree ensembles to be substantially more interpretable without the typical accuracy-interpretability trade-off. Additionally, it is able to provide feedback for all three types of explanations: local explanations via training a linear model, example-based explanations via highly weighted neighbors, and finally, detection of global patterns by using the supervised neighborhoods.
Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. “Why Should I Trust You?: Explaining the Predictions of Any Classifier.” Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. ACM, 2016.
Marco Tulio Ribeiro, Sameer Singh, and Carlos Guestrin. “Anchors: High-Precision Model-Agnostic Explanations.” Thirty-Second AAAI Conference on Artificial Intelligence. 2018.
Gregory Plumb, Denali Molitor, and Ameet S. Talwalkar. “Model Agnostic Supervised Local Explanations.” Advances in Neural Information Processing Systems. 2018.
A. Bloniarz, C. Wu, B. Yu, and A. Talwalkar. “Supervised Neighborhoods for Distributed Nonparametric Regression.” AISTATS, 2016.
S. Kazemitabar, A. Amini, A. Bloniarz, and A. Talwalkar. “Variable Importance Using Decision Trees.” NIPS, 2017.
Dua Dheeru and Efi Karra Taniskidou. UCI Machine Learning Repository, 2017.
DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.
Figure 1. Fine-tuning a language model to predict EEG data. The encoder is pretrained on Wikipedia to predict the next word in a sequence (or previous word for the backward LSTM). We use the contextualized embeddings from the encoder as input to a decoder. The decoder uses a convolution to create embeddings for each pair of words, which, along with word-frequency and word-length become the basis for a linear layer to predict EEG responses. The model is fine-tuned to predict this EEG data. In this example the model is jointly trained to predict the N400 and P600 EEG responses.
Imagine for a moment that we can take snippets of text, give them to a computational model, and that the model can perfectly predict some of the brain activity recorded from a person who was reading the same text. Can we learn anything about how the brain works from this model? If we trust the model, then we can at least identify which parts of the brain activity are related to the text. Beyond this though, what we learn from the model depends on how much we know about the mechanisms it uses to make its predictions. Given that we want to understand these mechanisms, and that models produced by deep learning can be difficult to interpret, deep learning seems at first glance not to be a good candidate for analyzing language processing in the brain. However, deep learning has proven to be amazingly effective at capturing statistical regularities in language (and other domains). This effectiveness motivated us to see whether a deep learning model is able to predict brain activity from text well, and importantly, whether we can gain any understanding about the brain activity from the predictions. It turns out that the answer to both questions is yes.
One of the open questions in the study of how the brain processes language is how word meanings are integrated together to form the meanings of sentences, passages and dialogues. Electroencephalography (EEG) is a tool that is commonly used to study those integrative processes. In a recent paper, we propose to use fine-tuning of a language model and multitask learning to better understand how various language-elicited EEG responses are related to each other. If we can better understand these EEG responses and what drives them, then we can use that understanding to better study language processing in people.
In our analysis, we use EEG observations of brain activity recorded as people read sentences. Several different kinds of deviations from baseline measurements of activity occur as people read text. The most well studied of these is called the N400 response. It is a Negative deflection in the electrical activity (relative to a baseline) that occurs around 400 milliseconds after the onset of a word (thus N400), and it has been associated with semantic effort. If a word is expected in context — for example “I like peanut butter and jelly” versus “I like peanut butter and roller skates” — then the expected word jelly will elicit a reduced N400 response compared to the unexpected roller skates.
In the data we analyze (made available by Stefan Frank and colleagues) six different language-associated responses are considered. Three of these — the N400, PNP, and EPNP responses — are generally considered markers for semantic processes in the brain while the other three — the P600, LAN, and ELAN — are generally considered markers for syntactic processes in the brain. The division of these EEG responses into indicators for syntactic and semantic processes is controversial, and there is considerable debate about what each of the responses signifies. The P600, for example, is thought by some researchers to be triggered by syntactic violations, as in “The plane took we to paradise and back”, while others have noted that it can also be triggered by semantic role violations, as in “Every morning at breakfast the eggs would eat …”, and still others have questioned whether the P600 is language specific or rather a marker for any kind of rare event. One possibility is that it is associated with an attentive process invoked to reconcile conflicting information from lower level language processing. In any case, a clearer picture of the relationship between all of the EEG responses and between text and the EEG responses would make them better tools for investigating language processing in the brain.
Rather than having discrete labeling of whether each of the six EEG responses occurred as a participant read a given word, in this dataset the EEG responses are defined continuously as the average potential of a predefined set of EEG sensors during a predefined time-window (relative to when a word appears). This gives us six scalar values per word per experiment participant, and we average the values across the participants to give six final scalar values per word.
To predict these six scalar values for each word, we use a pretrained bidirectional LSTM as an encoder. We anticipate that the EEG responses occur in part as a function of a shift in the meaning or structure of the incoming language. For example, the N400 is associated with semantic effort and surprisal, so we might expect that the N400 would be some function of a difference between adjacent word embeddings. Because of this intuition, we pair up the embeddings output from the encoder by putting them through a convolutional layer that can learn functions on adjacent word embeddings. We use the pair embeddings output by the convolution, along with the word length and log-probability of the word as the basis for predicting the EEG responses. The EEG responses are predicted from this basis using a linear layer. The forward and backward LSTMs are pretrained independently on the WikiText-103 dataset to predict the next and previous words respectively from a snippet of text. We fine-tune the model by training the decoder first and keeping the encoder parameters fixed, and then after that we continue training by also modifying the final layer of the LSTM for a few epochs.
A natural question is whether these EEG measures of brain activity can be predicted from the text at all, and whether all of this deep learning machinery actually improves the prediction compared to a simpler model. As our measure of accuracy, we use the proportion of variance explained — i.e. we normalize the mean squared error on the validation set by the variance on the validation set and subtract that number from 1: \(\mathrm{POVE} = 1 - \frac{\mathrm{MSE}}{\mathrm{variance}}\). We compare the accuracy of using the decoder on top of three different encoders: an encoder which completely bypasses the LSTM (i.e. the output embeddings are the same as the input embeddings to the encoder), an encoder which is a forward-only LSTM, and an encoder which is the full bidirectional LSTM.
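The accuracy measure is straightforward to compute; here is a small numpy sketch (the function name is ours). Note that predicting the validation-set mean gives a POVE of exactly 0, which is why 0 is the chance level:

```python
import numpy as np

def pove(y_true, y_pred):
    """Proportion of variance explained: 1 - MSE / variance."""
    y_true = np.asarray(y_true, dtype=float)
    mse = np.mean((y_true - np.asarray(y_pred, dtype=float)) ** 2)
    return 1.0 - mse / np.var(y_true)

# Predicting the mean of the data gives POVE = 0 (chance level),
# and a perfect prediction gives POVE = 1.
y = np.array([1.0, 2.0, 3.0, 4.0])
baseline = pove(y, np.full_like(y, y.mean()))
```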
Surprisingly, we see that all six of the EEG measures can be predicted at above chance levels (\(0\) is chance here since guessing the mean would give us \(\mathrm{POVE}\) of \(0\)). Previous work (here and here) has found that only some of the EEG measures are predictable, but that work did not directly try to predict the brain activity from the text. Instead, it used an estimate of the surprisal (the negative log-probability of the word in context), and an estimate of the syntactic complexity to predict the EEG data. Those intermediate values have the benefit of being interpretable, but they lose a lot of the pertinent information.
We also see that the full bidirectional encoder is better able to predict the brain activity than the other encoders. The comparison between encoders is not completely fair because there are more parameters in the forward-only encoder than the embedding-only encoder, and more parameters than both of those in the bidirectional encoder, so part of the reason that the bidirectional encoder might be better is simply that it has more degrees of freedom to work with. Nonetheless, this result suggests that the context matters for the prediction of the EEG signals, which means that there is opportunity to learn about the features in the language stream that drive the EEG responses.
It’s good to see that the deep learning model can predict all of the EEG responses, but we also want to learn something about those responses. We use multitask learning to accomplish that here. We train our network using \(63 = \binom{6}{1} + \binom{6}{2} + \ldots + \binom{6}{6}\) variations of our loss function. In each variation, we choose a subset of the six EEG signals and include a mean squared error term for the prediction of each signal in that subset. For example, one of the variations includes just the N400 and the P600 responses, so there are mean squared error terms for the prediction of the N400 and the prediction of the P600 in the loss function for that variation, but no terms for the other four responses. We only make predictions for content words (adjectives, adverbs, auxiliary verbs, nouns, pronouns, proper nouns, and verbs), so if there are \(B\) examples in a mini-batch, and example \(b\) has \(W_b\) content words, and if we let the superscripts \(p,a\) denote the predicted and actual values for an EEG signal respectively, then the loss function for the N400 and P600 variation can be written as:
$$\frac{1}{\sum_{b=1}^B W_b} \sum_{b=1}^B \sum_{w=1}^{W_b} \left[ (\mathrm{P600}^{p}_{b,w} - \mathrm{P600}^{a}_{b,w})^2 + (\mathrm{N400}^{p}_{b,w} - \mathrm{N400}^{a}_{b,w})^2 \right]$$
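A numpy sketch of one such loss variation is below; shapes and names are illustrative (in the actual model this loss is minimized by backpropagation through the decoder and encoder):

```python
import numpy as np

def multitask_mse(pred, actual, subset):
    """Mean squared error summed over a chosen subset of EEG signals.

    pred, actual: arrays of shape (n_content_words, 6), one column per
    EEG response, with rows pooled over all content words in the batch.
    subset: column indices of the signals included in this variation.
    """
    err = (pred[:, subset] - actual[:, subset]) ** 2
    return err.sum(axis=1).mean()   # sum over signals, average over words

# One of the 63 variations: train on, say, the N400 (column 0) and
# P600 (column 1) only, ignoring the other four responses.
rng = np.random.default_rng(0)
pred, actual = rng.standard_normal((2, 10, 6))
loss = multitask_mse(pred, actual, subset=[0, 1])
```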
The premise of this method is that if two or more EEG signals are related to each other, then including all of the related signals as prediction tasks should create a helpful inductive bias. With this bias, the function that the deep learning model learns between the text and an EEG signal of interest should be a better approximation of the true function, and therefore it should generalize better to unseen examples.
We filter the results to keep (i) the variations that include just a single EEG response in the loss function (the top bar in each group below), (ii) the variations that best explain each EEG response (the bottom bar in each group below), and (iii) the variations which are not significantly different from the best variations and which include no more EEG responses in the loss function than the best variation, i.e. all simpler combinations of EEG responses which perform as well as the best combination (all the other bars). For the N400, where the best variation does not include any other EEG signals, we also show how the proportion of variance explained changes when we include each of the other EEG signals.
For each target EEG signal other than the N400, it is possible to improve prediction by using multitask learning. As Rich Caruana points out in his work on multitask learning, a target task can be improved by auxiliary tasks even when the tasks are unrelated. However, our results are suggestive of relationships between the EEG signals. It’s not the case that training with more EEG signals is always better, and the pattern of improvements for different variations doesn’t look random. The improvements also don’t follow the pattern of raw correlations between the EEG signals (see our paper for the correlations).
Some of the relationships we see here are expected from current theories of how each EEG response relates to language processing. The LAN/P600 and ELAN/P600 relationship is expected based both on prior studies where they have been observed together and theory that the ELAN/LAN responses occur during syntactic violations and the P600 occurs during increased syntactic effort. Our results also suggest some relationships which are not as expected, but which have plausible explanations. For example, some researchers believe that the ELAN and LAN responses mark working memory demands, and if this is so, then those responses might be expected to be related to the other signals that track language processing demands of any kind. That could explain why they seem to widely benefit (and benefit from) the prediction of other signals. However, the apparent isolation of the N400 from this benefit would be surprising in that case.
We need to be a little careful about over-interpreting the results here; the way that the EEG responses are defined in this dataset means that several of them are spatially overlapping and close to each other temporally, so some signals may spill-over into others. Future studies will be required to tease apart the possibilities suggested by this analysis, but we believe that this methodology is a promising direction. Multitask learning can help us understand complex relationships between EEG signals. We can also partially address the concern about signal spill-over by including other prediction tasks.
Two additional tasks we can include are prediction of self-paced reading times (in which words are shown one-by-one and the experiment participant presses a button to advance to the next word) and eye-tracking data. Both are available from different experiment participants for the sentences that the EEG signals were collected on. Self-paced reading times and eye-tracking data can both be thought of as measures of reading comprehension difficulty, so we expect that they should be related to the EEG data. Indeed, we see that when these tasks are used in training, both benefit the prediction of the EEG data compared to training on the target EEG signal alone. This result is really interesting because it cannot be explained by any spill-over effect. It suggests that the model might really be learning about some of the latent factors that underlie both EEG responses and behavior (for the detailed results and further discussion of the behavioral data, please see our paper).
It’s really exciting to see how well the EEG signals can be predicted using one of the latest language models, and multitask learning gives us some insight into how the EEG signals relate to each other and to behavioral data. While this analysis method is for now largely exploratory and suggestive, we hope to extend it over time to gain more and more understanding of how the brain processes language. If you’re interested in more information about the method or further discussion of the results, please check out our paper here.
DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.
Over the past decade, artificial intelligence (AI) has achieved remarkable success in many fields such as healthcare, automotive, and marketing. The capabilities of sophisticated, autonomous decision systems driven by AI keep evolving and moving from lab to reality. Many of these systems are black-box, which means we don’t really understand how they work and why they reach such decisions.
As black-box decision systems have come into greater use, they have also come under greater criticism. One of the main concerns is that it is dangerous to rely on black-box decisions without knowing the way they are made. Here is an example of why they can be dangerous.
Risk-assessment tools have been widely used in the federal and state courts to facilitate and improve judges’ decisions in the criminal justice process. They estimate defendants’ future criminal risk based on socio-economic status, family background, and other factors. In May 2016, ProPublica claimed that one of the most widely used risk-assessment tools, COMPAS, was biased against black defendants while being more generous to white defendants [link]. Northpointe, the for-profit company that provides the software, disputed the analysis but refused to disclose the software’s decision mechanism. So it is not possible for either stakeholders or the public to see what might actually be creating the disparity.
Here, we raise a question: how can we go about resolving this concern? Explaining how a black-box decision system works, or why it reaches a given decision, helps us decide whether or not to follow its decisions. The need for interpretability is especially urgent in fields where black-box decisions can be life-changing and have significant consequences, such as disease diagnosis, criminal justice, and self-driving cars.
What makes a ‘good’ explanation for a black-box? Assume that you give a black-box predictive model an image of an apple. You open the black-box and explain why it believes the image indeed shows an apple. Simply saying that “it is red, so this is an apple” is not sufficient justification, but you should also avoid redundant explanations. It is important to give enough information concisely when explaining a black-box decision system. In other words, explanations should be brief but comprehensive.
How can we take into account both briefness and comprehensiveness for explaining a black-box? Our work uses an information theoretic perspective to quantify the idea of briefness and comprehensiveness.
The information bottleneck principle (Tishby et al., 2000) provides an appealing information theoretic view for learning supervised models by defining what we mean by a ‘good’ representation. The principle says that the optimal model transmits as much information as possible from its input to its output through a compressed representation called the information bottleneck. And the information bottleneck is a good representation that is maximally informative about the output while compressive about a given input. Recently, Shwartz-Ziv et al. (2017) and Tishby et al. (2015) showed that the principle also applies to deep neural networks and each layer of the networks can work as an information bottleneck.
We adopt the information bottleneck principle as a criterion for finding a ‘good’ explanation. In the information theoretic view, we define a brief but comprehensive explanation as maximally informative about the black-box decision while compressive about a given input. In other words, the explanation should maximally compress the mutual information regarding an input while preserving as much as possible mutual information regarding its output.
We introduce the variational information bottleneck for interpretation (VIBI), a system-agnostic information bottleneck model that provides a brief but comprehensive explanation for every single decision made by a black-box.
VIBI is composed of two parts, an explainer and an approximator, each of which is modeled by a deep neural network. Using the information bottleneck principle, VIBI learns an explainer that favors brief explanations while enforcing that the explanations alone suffice for an accurate approximation of the black-box. See the figure below for an illustration of VIBI.
For each instance, the explainer returns the probability that a chunk of features, called a cognitive chunk, will be selected as an explanation. A cognitive chunk is defined as a group of raw features that works as a unit to be explained and whose identity is recognizable to a human, such as a word, phrase, sentence, or group of pixels. The selected chunks act as an information bottleneck that is maximally compressive about the input and informative about the decision made by the black-box system on that input.
Now, we formulate the following optimization problem inspired by the information bottleneck principle to learn the explainer and approximator:
$$ p(\mathbf{z} | \mathbf{x}) = \mathrm{argmax}_{p(\mathbf{z} | \mathbf{x}), p(\mathbf{y} | \mathbf{t})} ~~\mathrm{I} ( \mathbf{t}, \mathbf{y} ) - \beta~\mathrm{I} ( \mathbf{x}, \mathbf{t} )$$ where \( \mathrm{I} ( \mathbf{t}, \mathbf{y} ) \) represents the sufficiency of information retained for explaining the black-box output \( \mathbf{y} \), \(-\mathrm{I} ( \mathbf{x}, \mathbf{t} ) \) represents the briefness of the explanation \( \mathbf{t} \), and \( \beta \) is a Lagrange multiplier representing a trade-off between the two.
The current form of the information bottleneck objective is intractable due to the mutual information terms and the non-differentiable sample \( \mathbf{z} \). We address these challenges as follows.
Variational Approximation to Information Bottleneck Objective
The mutual information terms \( \mathrm{I} ( \mathbf{t}, \mathbf{y} ) \) and \( \mathrm{I} ( \mathbf{x}, \mathbf{t} ) \) are computationally expensive to quantify (Tishby et al., 2000; Chechik et al., 2005). In order to reduce the computational burden, we use a variational lower bound on our information bottleneck objective: $$\mathrm{I} ( \mathbf{t}, \mathbf{y} )~-~\beta~\mathrm{I} ( \mathbf{x}, \mathbf{t} ) \geq \mathbb{E}_{\mathbf{x} \sim p(\mathbf{x})} \mathbb{E}_{\mathbf{y} | \mathbf{x} \sim p(\mathbf{y} | \mathbf{x})} \mathbb{E}_{\mathbf{t} | \mathbf{x} \sim p(\mathbf{t} | \mathbf{x})} \left[ \log q(\mathbf{y} | \mathbf{t}) \right] ~-~\beta~\mathbb{E}_{\mathbf{x}\sim p(\mathbf{x})} \mathrm{KL} (p(\mathbf{z}| \mathbf{x}), r(\mathbf{z})) $$
Now, we can integrate the Kullback-Leibler divergence \( \mathrm{KL} (p(\mathbf{z}| \mathbf{x}), r(\mathbf{z})) \) analytically with proper choices of \( r(\mathbf{z}) \) and \( p(\mathbf{z}|\mathbf{x}) \). We also use the empirical data distribution to approximate \( p(\mathbf{x}, \mathbf{y}) = p(\mathbf{x})p(\mathbf{y}|\mathbf{x}) \).
Continuous Relaxation and Re-parameterization
We use the generalized Gumbel-softmax trick (Jang et al., 2017; Chen et al., 2018), which approximates the non-differentiable categorical subset sampling with Gumbel-softmax samples that are differentiable. This trick allows using standard backpropagation to compute the gradients of the parameters via reparameterization.
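For intuition, here is a minimal numpy sketch of the basic (single-sample) Gumbel-softmax relaxation; VIBI uses a generalized version that samples subsets of cognitive chunks, and the logits below are made up for illustration:

```python
import numpy as np

def gumbel_softmax(logits, tau=0.5, rng=None):
    """Differentiable relaxation of categorical sampling (Jang et al., 2017):
    softmax((logits + Gumbel noise) / tau)."""
    rng = np.random.default_rng(rng)
    u = rng.uniform(1e-12, 1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))            # Gumbel(0, 1) noise
    z = (logits + gumbel) / tau
    z = z - z.max()                         # for numerical stability
    probs = np.exp(z)
    return probs / probs.sum()

# Relaxed "selection" over 5 cognitive chunks. As tau -> 0 the sample
# approaches a one-hot vector while remaining differentiable in the logits.
sample = gumbel_softmax(np.array([2.0, 0.5, 0.1, -1.0, 0.3]), tau=0.1)
```

Because the sample is a smooth function of the logits, gradients can flow back into the explainer's parameters during training.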
VIBI provides instance-specific keywords to explain an LSTM sentiment prediction model using Large Movie Review Dataset, IMDB.
Keywords such as “waste” and “horrible” are selected for a negative-predicted movie review, while keywords such as “most fascinating” explain a positive-predicted movie review. We can also see that the LSTM sentiment prediction model makes a wrong prediction for a negative review because the review includes several positive words such as “enjoyable” and “exciting”.
VIBI also provides instance-specific key patches of \( 4 \times 4 \) pixels to explain a CNN digit recognition model trained on the MNIST image dataset.
The first two examples show that the CNN recognizes digits using both shape and angle. In the first example, the CNN characterizes ‘1’s by patches aligned in a straight line along the activated regions, even though the ‘1’s in the left and right panels are written at different angles. In contrast, the second example shows that the CNN distinguishes ‘9’ from ‘6’ by their difference in angle. The last two examples show that the CNN separates ‘7’s from ‘1’s by patches located on the activated horizontal stroke of the ‘7’ (see the cyan circle), and recognizes ‘8’s by two patches at the top of the digit and another two at the bottom circle.
We assume that a better explanation allows humans to better infer the black-box output given the explanation. Therefore, we asked humans to infer the output of the black-box system (Positive/Negative/Neutral) given five keywords as an explanation, generated by VIBI and the competing methods (Saliency, LIME, and L2X). Each method was evaluated by workers on Amazon Mechanical Turk who hold the Masters Qualification (i.e., high-performance workers who have demonstrated excellence across a wide range of tasks). We also evaluated interpretability for the CNN digit recognition model on MNIST. Here we asked humans to directly score each explanation on a 0 to 5 scale (0 for no explanation, 1–4 for insufficient or redundant explanation, and 5 for concise explanation). Each method was evaluated by 16 graduate students at the School of Computer Science, Carnegie Mellon University, each of whom has taken at least one graduate-level machine learning class.
We assessed the fidelity of the approximator by its prediction performance with respect to the black-box output. We introduce two formalized metrics to quantitatively evaluate fidelity: approximator fidelity and rationale fidelity.
Approximator fidelity measures the ability of the approximator to imitate the behaviour of the black-box. As shown above, VIBI and L2X outperform the other methods in approximating the black-box models. However, this does not mean the two approximators are equal in fidelity. See below.
Rationale fidelity measures how much the selected chunks contribute to the approximator fidelity. As shown above, the selected chunks of VIBI account for more of the approximator fidelity than those of L2X. Note that L2X is a special case of VIBI with the information bottleneck trade-off parameter \( \beta = 0 \) (i.e., dropping the compressiveness term \( -\mathrm{I} ( \mathbf{x}, \mathbf{t} ) \)). Therefore, compressing information through the explainer achieves not only conciseness of the explanation but also better fidelity of the explanation to the black-box.
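The two metrics can be sketched as follows. This is a simplified illustration (function names and the zero-masking convention are our assumptions; the paper's exact definitions may differ): approximator fidelity compares the approximator's predictions to the black-box labels, while rationale fidelity lets the approximator see only the selected chunks.

```python
import numpy as np

def approximator_fidelity(blackbox_labels, approx_labels):
    # Fraction of instances where the approximator reproduces the
    # black-box prediction (accuracy w.r.t. the black-box output).
    return float(np.mean(np.asarray(blackbox_labels) == np.asarray(approx_labels)))

def rationale_fidelity(blackbox_labels, approx_predict, inputs, masks):
    # Same accuracy, but the approximator only sees the selected chunks:
    # features outside the binary mask are zeroed out.
    preds = approx_predict(np.asarray(inputs) * np.asarray(masks))
    return float(np.mean(np.asarray(blackbox_labels) == preds))
```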
Note that the number of cognitive chunks to select, \( k \), must be given in advance. It also affects the conciseness of the overall explanation and should be chosen carefully. In our analysis, we choose \( k \) as the minimum value that achieves a target fidelity.
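This selection rule can be sketched as follows, where `fidelity_for_k` is a hypothetical helper that trains and evaluates the model for a given \( k \) (both names are ours, for illustration):

```python
def choose_k(candidate_ks, fidelity_for_k, threshold):
    # Return the smallest k whose measured fidelity meets the threshold;
    # fall back to the largest candidate if none does.
    for k in sorted(candidate_ks):
        if fidelity_for_k(k) >= threshold:
            return k
    return max(candidate_ks)
```

In practice `fidelity_for_k` would retrain VIBI with that chunk budget and measure approximator fidelity on held-out data.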
Further details can be found here. The code is publicly available here.
DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.