Run this notebook online:\ |Binder| or Colab: |Colab|

.. |Binder| image:: https://mybinder.org/badge_logo.svg
   :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_optimization/convexity.ipynb

.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_optimization/convexity.ipynb

.. _sec_convexity:

Convexity
=========

Convexity plays a vital role in the design of optimization algorithms. This is largely because it is much easier to analyze and test algorithms in this context. In other words, if an algorithm performs poorly even in the convex setting, we should not hope to see great results otherwise. Furthermore, even though the optimization problems in deep learning are generally nonconvex, they often exhibit some properties of convex ones near local minima. This can lead to exciting new optimization variants such as :cite:`Izmailov.Podoprikhin.Garipov.ea.2018`.

Basics
------

Let us begin with the basics.

.. _fig_pacman:
.. _fig_convex_intersect:
.. _fig_nonconvex:

Sets
~~~~

Sets are the basis of convexity. Simply put, a set :math:`X` in a vector space is convex if for any :math:`a, b \in X` the line segment connecting :math:`a` and :math:`b` is also in :math:`X`. In mathematical terms this means that for all :math:`\lambda \in [0, 1]` we have

.. math:: \lambda \cdot a + (1-\lambda) \cdot b \in X \text{ whenever } a, b \in X.

This sounds a bit abstract. Consider the picture :numref:`fig_pacman`. The first set is not convex since there are line segments that are not contained in it. The other two sets suffer no such problem.

|Three shapes, the left one is nonconvex, the others are convex|

Definitions on their own are not particularly useful unless you can do something with them. In this case we can look at unions and intersections as shown in :numref:`fig_convex_intersect`. Assume that :math:`X` and :math:`Y` are convex sets. Then :math:`X \cap Y` is also convex. To see this, consider any :math:`a, b \in X \cap Y`. Since :math:`X` and :math:`Y` are convex, the line segment connecting :math:`a` and :math:`b` is contained in both :math:`X` and :math:`Y`. Given that, it also needs to be contained in :math:`X \cap Y`, thus proving our first theorem.

|The intersection between two convex sets is convex|

We can strengthen this result with little effort: given convex sets :math:`X_i`, their intersection :math:`\cap_{i} X_i` is convex. To see that the analogous statement for unions fails, consider two disjoint sets :math:`X \cap Y = \emptyset`. Now pick :math:`a \in X` and :math:`b \in Y`. The line segment in :numref:`fig_nonconvex` connecting :math:`a` and :math:`b` needs to contain some part that is neither in :math:`X` nor in :math:`Y`, since we assumed that :math:`X \cap Y = \emptyset`. Hence the line segment is not entirely in :math:`X \cup Y` either, thus proving that in general unions of convex sets need not be convex.

|The union of two convex sets need not be convex|

Typically the problems in deep learning are defined on convex domains. For instance, :math:`\mathbb{R}^d` is a convex set (after all, the line between any two points in :math:`\mathbb{R}^d` remains in :math:`\mathbb{R}^d`). In some cases we work with variables of bounded length, such as balls of radius :math:`r` as defined by :math:`\{\mathbf{x} | \mathbf{x} \in \mathbb{R}^d \text{ and } \|\mathbf{x}\|_2 \leq r\}`.
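As a quick numerical sanity check of this definition, we can sample pairs of points inside such a ball and verify that randomly chosen convex combinations of them stay inside. This is a minimal illustrative sketch; the ``insideBall`` helper below is made up just for this example.

.. code:: java

    import java.util.Random;

    // Returns true if the point p lies in the L2 ball of radius r
    // (with a tiny tolerance for floating point rounding).
    public boolean insideBall(double[] p, double r) {
        double sq = 0;
        for (double v : p) {
            sq += v * v;
        }
        return Math.sqrt(sq) <= r + 1e-9;
    }

    Random rng = new Random(0);
    double r = 1.0;
    boolean allInside = true;
    for (int trial = 0; trial < 1000; trial++) {
        // Rejection-sample two points a and b inside the ball of radius r
        double[] a;
        double[] b;
        do {
            a = new double[]{2 * rng.nextDouble() - 1, 2 * rng.nextDouble() - 1};
        } while (!insideBall(a, r));
        do {
            b = new double[]{2 * rng.nextDouble() - 1, 2 * rng.nextDouble() - 1};
        } while (!insideBall(b, r));
        double lambda = rng.nextDouble();
        // Convex combination lambda * a + (1 - lambda) * b
        double[] mix = {lambda * a[0] + (1 - lambda) * b[0],
                        lambda * a[1] + (1 - lambda) * b[1]};
        allInside &= insideBall(mix, r);
    }
    System.out.println("All sampled convex combinations stayed in the ball: " + allInside);

Such sampling only gathers evidence rather than a proof, but a single violating pair would already certify that a set is not convex.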
Functions
~~~~~~~~~

Now that we have convex sets we can introduce convex functions :math:`f`. Given a convex set :math:`X`, a function :math:`f: X \to \mathbb{R}` defined on it is convex if for all :math:`x, x' \in X` and for all :math:`\lambda \in [0, 1]` we have

.. math:: \lambda f(x) + (1-\lambda) f(x') \geq f(\lambda x + (1-\lambda) x').

To illustrate this let us plot a few functions and check which ones satisfy the requirement. We need to import a few libraries.

.. |Three shapes, the left one is nonconvex, the others are convex| image:: https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/pacman.svg
.. |The intersection between two convex sets is convex| image:: https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/convex-intersect.svg
.. |The union of two convex sets need not be convex| image:: https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/nonconvex.svg

.. code:: java

    %load ../utils/djl-imports
    %load ../utils/plot-utils
    %load ../utils/Functions.java

.. code:: java

    import ai.djl.ndarray.*;

Let us define a few functions, both convex and nonconvex.

.. code:: java

    import tech.tablesaw.plotly.traces.ScatterTrace;
    import tech.tablesaw.plotly.components.Axis.Spikes;

    // ScatterTrace.builder() does not support float[],
    // so we must convert to a double array first
    public double[] floatToDoubleArray(float[] x) {
        double[] ret = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            ret[i] = x[i];
        }
        return ret;
    }

    // Plot the graph of func together with the chord over the given segment
    public Figure plotLineAndSegment(float[] x, float[] y, float[] segment, Function<Float, Float> func,
                                     int width, int height) {
        ScatterTrace trace = ScatterTrace.builder(floatToDoubleArray(x), floatToDoubleArray(y))
            .mode(ScatterTrace.Mode.LINE)
            .build();

        ScatterTrace trace2 = ScatterTrace.builder(floatToDoubleArray(segment),
                                                   new double[]{func.apply(segment[0]),
                                                                func.apply(segment[1])})
            .mode(ScatterTrace.Mode.LINE)
            .build();

        Layout layout = Layout.builder()
            .height(height)
            .width(width)
            .showLegend(false)
            .build();

        return new Figure(layout, trace, trace2);
    }

.. code:: java

    Function<Float, Float> f = x -> 0.5f * x * x; // Convex
    Function<Float, Float> g = x -> (float) Math.cos(Math.PI * x); // Nonconvex
    Function<Float, Float> h = x -> (float) Math.exp(0.5f * x); // Convex

    NDManager manager = NDManager.newBaseManager();
    NDArray X = manager.arange(-2f, 2f, 0.01f);
    float[] x = X.toFloatArray();
    float[] segment = new float[]{-1.5f, 1f};

    float[] fx = Functions.callFunc(x, f);
    float[] gx = Functions.callFunc(x, g);
    float[] hx = Functions.callFunc(x, h);

    display(plotLineAndSegment(x, fx, segment, f, 350, 300));
    display(plotLineAndSegment(x, gx, segment, g, 350, 300));
    display(plotLineAndSegment(x, hx, segment, h, 350, 300));
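In addition to inspecting the plots, we can spot-check the defining inequality numerically on the same segment endpoints used above. This is a small illustrative sketch; the ``violatesConvexity`` helper is made up just for this example.

.. code:: java

    import java.util.function.Function;

    // Checks whether lambda * func(a) + (1 - lambda) * func(b)
    // < func(lambda * a + (1 - lambda) * b), i.e., whether the convexity
    // inequality is violated at this particular choice of a, b and lambda.
    public boolean violatesConvexity(Function<Float, Float> func, float a, float b, float lambda) {
        float lhs = lambda * func.apply(a) + (1 - lambda) * func.apply(b);
        float rhs = func.apply(lambda * a + (1 - lambda) * b);
        return lhs < rhs - 1e-6f;
    }

    for (float lambda : new float[]{0.25f, 0.5f, 0.75f}) {
        System.out.printf("lambda=%.2f  f violated: %b  g violated: %b  h violated: %b%n",
                lambda,
                violatesConvexity(f, -1.5f, 1f, lambda),
                violatesConvexity(g, -1.5f, 1f, lambda),
                violatesConvexity(h, -1.5f, 1f, lambda));
    }

A single violation certifies that a function is nonconvex (as happens for :math:`g`), whereas the absence of violations at a few sampled points is merely consistent with convexity, not a proof of it.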
As expected, the cosine function is nonconvex, whereas the parabola and the exponential function are convex. Note that the requirement that :math:`X` be a convex set is necessary for the condition to make sense. Otherwise the outcome of :math:`f(\lambda x + (1-\lambda) x')` might not be well defined. Convex functions have a number of desirable properties.

Jensen's Inequality
~~~~~~~~~~~~~~~~~~~

One of the most useful tools is Jensen's inequality. It amounts to a generalization of the definition of convexity:

.. math::

   \sum_i \alpha_i f(x_i) \geq f\left(\sum_i \alpha_i x_i\right)
   \text{ and }
   E_x[f(x)] \geq f\left(E_x[x]\right),

where :math:`\alpha_i` are nonnegative real numbers such that :math:`\sum_i \alpha_i = 1`. In other words, the expectation of a convex function is at least as large as the convex function of the expectation. To prove the first inequality we repeatedly apply the definition of convexity to one term in the sum at a time. The inequality for the expectation can then be obtained by taking the limit over finite sums.

One of the common applications of Jensen's inequality is with regard to the log-likelihood of partially observed random variables. That is, we use

.. math:: E_{y \sim P(y)}[-\log P(x \mid y)] \geq -\log P(x).

This follows since :math:`\int P(y) P(x \mid y) dy = P(x)`. This is used in variational methods. Here :math:`y` is typically the unobserved random variable, :math:`P(y)` is the best guess of how it might be distributed, and :math:`P(x)` is the distribution with :math:`y` integrated out. For instance, in clustering :math:`y` might be the cluster labels and :math:`P(x \mid y)` is the generative model when applying cluster labels.

Properties
----------

Convex functions have a few useful properties. We describe them below.

No Local Minima
~~~~~~~~~~~~~~~

In particular, convex functions do not have local minima that are not also global minima. Let us assume the contrary and prove it wrong. If :math:`x \in X` is a local minimum, there exists some neighborhood of :math:`x` on which :math:`f(x)` is the smallest value. Since :math:`x` is only a local minimum, there has to be another :math:`x' \in X` for which :math:`f(x') < f(x)`. However, by convexity the function values on the entire line segment :math:`\lambda x + (1-\lambda) x'` have to be less than :math:`f(x)`, since for :math:`\lambda \in [0, 1)`

.. math:: f(x) > \lambda f(x) + (1-\lambda) f(x') \geq f(\lambda x + (1-\lambda) x').

In particular, points on this segment arbitrarily close to :math:`x` (take :math:`\lambda` close to :math:`1`) have strictly smaller function values. This contradicts the assumption that :math:`x` is a local minimum. For instance, the nonconvex function :math:`f(x) = (x+1) (x-1)^2` has a local minimum at :math:`x=1`. However, it is not a global minimum.

.. code:: java

    Function<Float, Float> f = x -> (x - 1) * (x - 1) * (x + 1);

    float[] fx = Functions.callFunc(x, f);
    plotLineAndSegment(x, fx, segment, f, 400, 350);
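To see how such a local minimum can trap an optimizer, consider the following minimal sketch, which runs plain gradient descent on this function with a hand-coded derivative, starting to the right of the local minimum:

.. code:: java

    // Gradient descent on f(x) = (x + 1) * (x - 1)^2, whose derivative is
    // f'(x) = (x - 1) * (3x + 1). Started at x0 = 2, the iterates converge to
    // the local minimum at x = 1, even though f is unbounded below as x -> -infinity.
    double xt = 2.0;
    double lr = 0.01;
    for (int t = 0; t < 1000; t++) {
        double grad = (xt - 1) * (3 * xt + 1);
        xt -= lr * grad;
    }
    System.out.printf("Gradient descent ended at x = %.4f (local minimum at x = 1)%n", xt);

For a convex function this cannot happen: any point where the gradient vanishes is already a global minimum.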
The fact that convex functions have no local minima is very convenient. It means that if we minimize functions we cannot "get stuck". Note, though, that this does not mean that there cannot be more than one global minimum, or that a global minimum even exists. For instance, the function :math:`f(x) = \mathrm{max}(|x|-1, 0)` attains its minimum value over the entire interval :math:`[-1, 1]`. Conversely, the function :math:`f(x) = \exp(x)` does not attain a minimum value on :math:`\mathbb{R}`. For :math:`x \to -\infty` it asymptotes to :math:`0`, but there is no :math:`x` for which :math:`f(x) = 0`.

Convex Functions and Sets
~~~~~~~~~~~~~~~~~~~~~~~~~

Convex functions define convex sets as *below-sets*. They are defined as

.. math:: S_b := \{x | x \in X \text{ and } f(x) \leq b\}.

Such sets are convex. Let us prove this quickly. Remember that for any :math:`x, x' \in S_b` we need to show that :math:`\lambda x + (1-\lambda) x' \in S_b` as long as :math:`\lambda \in [0, 1]`. But this follows directly from the definition of convexity since :math:`f(\lambda x + (1-\lambda) x') \leq \lambda f(x) + (1-\lambda) f(x') \leq b`.

Have a look at the function :math:`f(x, y) = 0.5 x^2 + \cos(2 \pi y)` below. It is clearly nonconvex. The level sets are correspondingly nonconvex. In fact, they are typically composed of disjoint sets.

TODO: Same issue as 11.1, tablesaw doesn't support mesh grids.

.. figure:: https://d2l-java-resources.s3.amazonaws.com/img/high_dim_nonconvex.svg

   Nonconvex Function.

Derivatives and Convexity
~~~~~~~~~~~~~~~~~~~~~~~~~

Whenever the second derivative of a function exists, it is very easy to check for convexity. All we need to do is check whether :math:`\partial_x^2 f(x) \succeq 0`, i.e., whether all of its eigenvalues are nonnegative. For instance, the function :math:`f(\mathbf{x}) = \frac{1}{2} \|\mathbf{x}\|^2_2` is convex since :math:`\partial_{\mathbf{x}}^2 f = \mathbf{1}`, i.e., its Hessian is the identity matrix.

The first thing to realize is that we only need to prove this property for one-dimensional functions. After all, in general we can always define some function :math:`g(z) = f(\mathbf{x} + z \cdot \mathbf{v})`. This function has the first and second derivatives :math:`g' = (\partial_{\mathbf{x}} f)^\top \mathbf{v}` and :math:`g'' = \mathbf{v}^\top (\partial^2_{\mathbf{x}} f) \mathbf{v}`, respectively. In particular, :math:`g'' \geq 0` for all :math:`\mathbf{v}` whenever the Hessian of :math:`f` is positive semidefinite, i.e., whenever all of its eigenvalues are nonnegative. Hence we are back to the scalar case.

To see that :math:`f''(x) \geq 0` for convex functions we use the fact that

.. math:: \frac{1}{2} f(x + \epsilon) + \frac{1}{2} f(x - \epsilon) \geq f\left(\frac{x + \epsilon}{2} + \frac{x - \epsilon}{2}\right) = f(x).

Since the second derivative is given by the limit over finite differences it follows that

.. math:: f''(x) = \lim_{\epsilon \to 0} \frac{f(x+\epsilon) + f(x - \epsilon) - 2f(x)}{\epsilon^2} \geq 0.
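The finite-difference expression above also suggests a simple numerical check. The following minimal sketch approximates :math:`f''` on a grid for a convex and a nonconvex function; the ``secondDerivative`` helper is made up just for this illustration.

.. code:: java

    import java.util.function.Function;

    // Finite-difference approximation of the second derivative:
    // f''(x) ~ (f(x + eps) + f(x - eps) - 2 f(x)) / eps^2
    public double secondDerivative(Function<Double, Double> func, double x, double eps) {
        return (func.apply(x + eps) + func.apply(x - eps) - 2 * func.apply(x)) / (eps * eps);
    }

    Function<Double, Double> parabola = z -> 0.5 * z * z;          // convex
    Function<Double, Double> cosine = z -> Math.cos(Math.PI * z);  // nonconvex

    for (double z = -2.0; z <= 2.0; z += 0.5) {
        System.out.printf("x=%5.2f  parabola f'' ~ %7.3f  cosine f'' ~ %7.3f%n",
                z, secondDerivative(parabola, z, 1e-3),
                secondDerivative(cosine, z, 1e-3));
    }
    // The estimate for the parabola stays nonnegative everywhere,
    // while the estimate for the cosine changes sign.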
To see that the converse is true we use the fact that :math:`f'' \geq 0` implies that :math:`f'` is a monotonically increasing function. Let :math:`a < x < b` be three points in :math:`\mathbb{R}`. We use the mean value theorem to express

.. math::

   \begin{aligned}
   f(x) - f(a) & = (x-a) f'(\alpha) \text{ for some } \alpha \in [a, x] \text{ and } \\
   f(b) - f(x) & = (b-x) f'(\beta) \text{ for some } \beta \in [x, b].
   \end{aligned}

By monotonicity :math:`f'(\beta) \geq f'(\alpha)`, hence

.. math::

   \begin{aligned}
   f(b) - f(a) & = f(b) - f(x) + f(x) - f(a) \\
   & = (b-x) f'(\beta) + (x-a) f'(\alpha) \\
   & \geq (b-a) f'(\alpha).
   \end{aligned}

In other words, :math:`f'(\alpha) \leq \frac{f(b) - f(a)}{b - a}`, and hence :math:`f(x) = f(a) + (x-a) f'(\alpha)` lies on or below the line connecting :math:`(a, f(a))` and :math:`(b, f(b))`, thus proving convexity. We omit a more formal derivation in favor of a graph below. Note: tablesaw currently offers no good way of adding annotations, which this graph requires to make sense, so we show a static figure instead.

.. figure:: https://d2l-java-resources.s3.amazonaws.com/img/convexity_check.svg

   Check Convexity.

Constraints
-----------

One of the nice properties of convex optimization is that it allows us to handle constraints efficiently. That is, it allows us to solve problems of the form

.. math::

   \begin{aligned}
   \mathop{\mathrm{minimize~}}_{\mathbf{x}} & f(\mathbf{x}) \\
   \text{ subject to } & c_i(\mathbf{x}) \leq 0 \text{ for all } i \in \{1, \ldots, N\}.
   \end{aligned}

Here :math:`f` is the objective and the functions :math:`c_i` are constraint functions. To see what this does, consider the case where :math:`c_1(\mathbf{x}) = \|\mathbf{x}\|_2 - 1`. In this case the parameters :math:`\mathbf{x}` are constrained to the unit ball. If a second constraint is :math:`c_2(\mathbf{x}) = \mathbf{v}^\top \mathbf{x} + b`, then this corresponds to all :math:`\mathbf{x}` lying in a half-space. Satisfying both constraints simultaneously amounts to selecting a slice of a ball as the constraint set.

Lagrange Function
~~~~~~~~~~~~~~~~~

In general, solving a constrained optimization problem is difficult. One way of addressing it stems from physics with a rather simple intuition. Imagine a ball inside a box. The ball will roll to the place that is lowest and the forces of gravity will be balanced out by the forces that the sides of the box can impose on the ball. In short, the gradient of the objective function (i.e., gravity) will be offset by the gradient of the constraint function (the ball needs to remain inside the box by virtue of the walls "pushing back"). Note that any constraint that is not active (i.e., the ball does not touch the wall) will not be able to exert any force on the ball.

Skipping over the derivation of the Lagrange function :math:`L` (see e.g., the book by Boyd and Vandenberghe for details :cite:`Boyd.Vandenberghe.2004`), the above reasoning can be expressed via the following saddlepoint optimization problem:

.. math:: L(\mathbf{x},\alpha) = f(\mathbf{x}) + \sum_i \alpha_i c_i(\mathbf{x}) \text{ where } \alpha_i \geq 0.

Here the variables :math:`\alpha_i` are the so-called *Lagrange multipliers* that ensure that the constraints are properly enforced. They are chosen just large enough to ensure that :math:`c_i(\mathbf{x}) \leq 0` for all :math:`i`. For instance, for any :math:`\mathbf{x}` for which :math:`c_i(\mathbf{x}) < 0` naturally, we'd end up picking :math:`\alpha_i = 0`. Moreover, this is a *saddlepoint* optimization problem where one wants to *maximize* :math:`L` with respect to :math:`\alpha` and simultaneously *minimize* it with respect to :math:`\mathbf{x}`. There is a rich body of literature explaining how to arrive at the function :math:`L(\mathbf{x}, \alpha)`. For our purposes it is sufficient to know that the saddlepoint of :math:`L` is where the original constrained optimization problem is solved optimally.
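To make the saddlepoint property concrete, consider the toy problem of minimizing :math:`f(x) = x^2` subject to the single constraint :math:`c(x) = 1 - x \leq 0`, whose solution is :math:`x^* = 1` with multiplier :math:`\alpha^* = 2`. The minimal sketch below (a made-up example, not a general-purpose solver) runs gradient descent on :math:`x` and projected gradient ascent on :math:`\alpha` over the Lagrangian and recovers this saddlepoint numerically:

.. code:: java

    // Lagrangian of the toy problem: L(x, alpha) = x^2 + alpha * (1 - x), alpha >= 0.
    // Descend in x, ascend in alpha, and keep alpha nonnegative.
    double xk = 0.0;      // primal variable
    double alphaK = 0.0;  // Lagrange multiplier
    double eta = 0.05;    // step size

    for (int t = 0; t < 2000; t++) {
        double gradX = 2 * xk - alphaK;   // dL/dx
        double gradAlpha = 1 - xk;        // dL/dalpha = c(x)
        xk -= eta * gradX;
        alphaK = Math.max(0, alphaK + eta * gradAlpha);
    }
    System.out.printf("x ~ %.4f (expected 1), alpha ~ %.4f (expected 2)%n", xk, alphaK);

Note that the multiplier settles at a strictly positive value precisely because the constraint is active at the optimum; for an inactive constraint it would be driven to zero.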
Penalties
~~~~~~~~~

One way of satisfying constrained optimization problems at least approximately is to adapt the Lagrange function :math:`L`. Rather than satisfying :math:`c_i(\mathbf{x}) \leq 0` we simply add :math:`\alpha_i c_i(\mathbf{x})` to the objective function :math:`f(\mathbf{x})`. This ensures that the constraints will not be violated too badly.

In fact, we have been using this trick all along. Consider weight decay in :numref:`sec_weight_decay`. In it we add :math:`\frac{\lambda}{2} \|\mathbf{w}\|^2` to the objective function to ensure that :math:`\mathbf{w}` does not grow too large. Using the constrained optimization point of view we can see that this will ensure that :math:`\|\mathbf{w}\|^2 - r^2 \leq 0` for some radius :math:`r`. Adjusting the value of :math:`\lambda` allows us to vary the size of :math:`\mathbf{w}`.

In general, adding penalties is a good way of ensuring approximate constraint satisfaction. In practice this turns out to be much more robust than exact satisfaction. Furthermore, for nonconvex problems many of the properties that make the exact approach so appealing in the convex case (e.g., optimality) no longer hold.

Projections
~~~~~~~~~~~

An alternative strategy for satisfying constraints is projection. Again, we encountered projections before, e.g., when dealing with gradient clipping in :numref:`sec_rnn_scratch`. There we ensured that a gradient has length bounded by :math:`c` via

.. math:: \mathbf{g} \leftarrow \mathbf{g} \cdot \mathrm{min}(1, c/\|\mathbf{g}\|).

This turns out to be a *projection* of :math:`\mathbf{g}` onto the ball of radius :math:`c`. More generally, a projection onto a (convex) set :math:`X` is defined as

.. math:: \mathrm{Proj}_X(\mathbf{x}) = \mathop{\mathrm{argmin}}_{\mathbf{x}' \in X} \|\mathbf{x} - \mathbf{x}'\|_2.

It is thus the closest point in :math:`X` to :math:`\mathbf{x}`. This sounds a bit abstract. :numref:`fig_projections` explains it somewhat more clearly. In it we have two convex sets, a circle and a diamond. Points inside the set (yellow) remain unchanged. Points outside the set (black) are mapped to the closest point inside the set (red). While for :math:`\ell_2` balls this leaves the direction unchanged, this need not be the case in general, as can be seen in the case of the diamond.

|Convex Projections|

.. _fig_projections:

One of the uses for convex projections is to compute sparse weight vectors. In this case we project :math:`\mathbf{w}` onto an :math:`\ell_1` ball (the latter is a generalized version of the diamond in the picture above).

Summary
-------

In the context of deep learning the main purpose of convex functions is to motivate optimization algorithms and help us understand them in detail. In the following we will see how gradient descent and stochastic gradient descent can be derived accordingly.

- Intersections of convex sets are convex. Unions are not.
- The expectation of a convex function is at least as large as the convex function of an expectation (Jensen's inequality).
- A twice-differentiable function is convex if and only if its second derivative (Hessian) has only nonnegative eigenvalues throughout.
- Convex constraints can be added via the Lagrange function. In practice, simply add them to the objective function with a penalty.
- Projections map to the points in the (convex) set that are closest to the original points.

Exercises
---------

1. Assume that we want to verify convexity of a set by drawing all lines between points within the set and checking whether the lines are contained.

   - Prove that it is sufficient to check only the points on the boundary.
   - Prove that it is sufficient to check only the vertices of the set.
2. Denote by :math:`B_p[r] := \{\mathbf{x} | \mathbf{x} \in \mathbb{R}^d \text{ and } \|\mathbf{x}\|_p \leq r\}` the ball of radius :math:`r` using the :math:`p`-norm. Prove that :math:`B_p[r]` is convex for all :math:`p \geq 1`.
3. Given convex functions :math:`f` and :math:`g`, show that :math:`\mathrm{max}(f, g)` is convex, too. Prove that :math:`\mathrm{min}(f, g)` is not convex.
4. Prove that the normalization of the softmax function is convex. More specifically, prove the convexity of :math:`f(x) = \log \sum_i \exp(x_i)`.
5. Prove that linear subspaces are convex sets, i.e., :math:`X = \{\mathbf{x} | \mathbf{W} \mathbf{x} = \mathbf{b}\}`.
6. Prove that in the case of linear subspaces with :math:`\mathbf{b} = 0` the projection :math:`\mathrm{Proj}_X` can be written as :math:`\mathbf{M} \mathbf{x}` for some matrix :math:`\mathbf{M}`.
7. Show that for convex twice differentiable functions :math:`f` we can write :math:`f(x + \epsilon) = f(x) + \epsilon f'(x) + \frac{1}{2} \epsilon^2 f''(x + \xi)` for some :math:`\xi \in [0, \epsilon]`.
8. Given a vector :math:`\mathbf{w} \in \mathbb{R}^d` with :math:`\|\mathbf{w}\|_1 > 1`, compute the projection on the :math:`\ell_1` unit ball.

   - As an intermediate step, write out the penalized objective :math:`\|\mathbf{w} - \mathbf{w}'\|_2^2 + \lambda \|\mathbf{w}'\|_1` and compute the solution for a given :math:`\lambda > 0`.
   - Can you find the "right" value of :math:`\lambda` without a lot of trial and error?

9. Given a convex set :math:`X` and two vectors :math:`\mathbf{x}` and :math:`\mathbf{y}`, prove that projections never increase distances, i.e., :math:`\|\mathbf{x} - \mathbf{y}\| \geq \|\mathrm{Proj}_X(\mathbf{x}) - \mathrm{Proj}_X(\mathbf{y})\|`.

.. |Convex Projections| image:: https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/projections.svg