By 苏剑林 | December 11, 2021
On this blog, we have discussed gradient penalties on several occasions. Formally, gradient penalty terms fall into two types. One is the gradient penalty with respect to the input, $\Vert\nabla_{\boldsymbol{x}} f(\boldsymbol{x};\boldsymbol{\theta})\Vert^2$, which we discussed in articles such as "A Brief Talk on Adversarial Training: Meaning, Methods, and Reflections (with Keras Implementation)" and "Random Thoughts on Generalization: From Random Noise and Gradient Penalty to Virtual Adversarial Training". The other is the gradient penalty with respect to the parameters, $\Vert\nabla_{\boldsymbol{\theta}} f(\boldsymbol{x};\boldsymbol{\theta})\Vert^2$, which we discussed in articles such as "Optimization Algorithms from a Dynamical Perspective (V): Why the Learning Rate Should Not Be Too Small?" and "Do We Really Need to Reduce Training Set Loss to Zero?".
In those articles, both types of gradient penalty were claimed to improve the generalization performance of the model. So, is there any connection between the two? I learned of an inequality relating them from a recent Google paper, "The Geometric Occam's Razor Implicit in Deep Learning", which partially answers this question. Since I feel it might be useful in the future, I am recording some notes on it here.
Final Result
Assume there is an $l$-layer MLP model, denoted as
\begin{equation}\boldsymbol{h}^{(t+1)} = g^{(t)}(\boldsymbol{W}^{(t)}\boldsymbol{h}^{(t)}+\boldsymbol{b}^{(t)})\end{equation}
where $g^{(t)}$ is the activation function of the current layer, $t\in\{1,2,\cdots,l\}$, and let $\boldsymbol{h}^{(1)}$ be $\boldsymbol{x}$, i.e., the original input of the model. For the convenience of the derivation below, we denote $\boldsymbol{z}^{(t+1)}=\boldsymbol{W}^{(t)}\boldsymbol{h}^{(t)}+\boldsymbol{b}^{(t)}$. The set of all parameters is $\boldsymbol{\theta}=\{\boldsymbol{W}^{(1)},\boldsymbol{b}^{(1)},\boldsymbol{W}^{(2)},\boldsymbol{b}^{(2)},\cdots,\boldsymbol{W}^{(l)},\boldsymbol{b}^{(l)}\}$. If $f$ is any scalar function of $\boldsymbol{h}^{(l+1)}$, then the following inequality holds:
\begin{equation}\Vert\nabla_{\boldsymbol{x}} f\Vert^2\left(\frac{1 + \Vert \boldsymbol{h}^{(1)}\Vert^2}{\Vert\boldsymbol{W}^{(1)}\Vert^2 \Vert\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(1)}\Vert^2}+\cdots+\frac{1 + \Vert \boldsymbol{h}^{(l)}\Vert^2}{\Vert\boldsymbol{W}^{(l)}\Vert^2 \Vert\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(l)}\Vert^2}\right)\leq \Vert\nabla_{\boldsymbol{\theta}} f\Vert^2\label{eq:f}\end{equation}
In the above equation, $\Vert\nabla_{\boldsymbol{x}} f\Vert$, $\Vert\nabla_{\boldsymbol{\theta}} f\Vert$, and $\Vert \boldsymbol{h}^{(t)}\Vert$ are ordinary $l_2$ norms, i.e., the square root of the sum of the squares of all elements, while $\Vert\boldsymbol{W}^{(t)}\Vert$ and $\Vert\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(t)}\Vert$ are matrix "spectral norms" (refer to "Lipschitz Continuity in Deep Learning: Generalization and Generative Models"). This inequality shows that the parameter gradient penalty, to some extent, already contains the input gradient penalty.
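To see the inequality in action, below is a minimal numerical spot-check of $\eqref{eq:f}$ on a small, randomly initialized MLP. The layer widths, the tanh activation, the choice $f=\sum_i h^{(l+1)}_i$, the helper names, and the finite-difference gradients are all illustrative assumptions of mine rather than details from the original paper; the snippet only verifies that the left-hand side indeed stays below the right-hand side.

```python
import numpy as np

rng = np.random.default_rng(0)

# A tiny l = 3 layer tanh MLP: h^(1) = x, h^(t+1) = tanh(W^(t) h^(t) + b^(t)).
dims = [4, 5, 5, 3]  # dims[i] is the width of h^(i+1); there are l = 3 weight matrices
Ws = [rng.normal(size=(dims[t + 1], dims[t])) / np.sqrt(dims[t]) for t in range(3)]
bs = [0.1 * rng.normal(size=dims[t + 1]) for t in range(3)]

def forward(x, Ws, bs):
    """Return the list of activations [h^(1), ..., h^(l+1)]."""
    hs = [x]
    for W, b in zip(Ws, bs):
        hs.append(np.tanh(W @ hs[-1] + b))
    return hs

def f(x, Ws, bs):
    """An arbitrary scalar function of the last layer (here: the sum of its components)."""
    return forward(x, Ws, bs)[-1].sum()

def num_grad(fun, v, eps=1e-5):
    """Central finite-difference gradient of the scalar function `fun` w.r.t. the array v."""
    g = np.zeros_like(v)
    for i in range(v.size):
        dv = np.zeros_like(v)
        dv.flat[i] = eps
        g.flat[i] = (fun(v + dv) - fun(v - dv)) / (2 * eps)
    return g

x = rng.normal(size=dims[0])
hs = forward(x, Ws, bs)

# Squared input gradient penalty ||∇_x f||^2
sq_grad_x = (num_grad(lambda v: f(v, Ws, bs), x) ** 2).sum()

# Squared parameter gradient penalty ||∇_θ f||^2, summed over every W^(t) and b^(t)
sq_grad_theta = 0.0
for t in range(3):
    gW = num_grad(lambda v, t=t: f(x, [v.reshape(Ws[s].shape) if s == t else Ws[s]
                                       for s in range(3)], bs), Ws[t].ravel())
    gb = num_grad(lambda v, t=t: f(x, Ws, [v if s == t else bs[s] for s in range(3)]), bs[t])
    sq_grad_theta += (gW ** 2).sum() + (gb ** 2).sum()

# Bracketed factor: sum over layers of (1 + ||h^(t)||^2) / (||W^(t)||_2^2 ||∇_x h^(t)||_2^2)
bracket = 0.0
for t in range(3):
    # Jacobian ∇_x h^(t) (the identity matrix for the first layer), built row by row
    J = np.stack([num_grad(lambda v, t=t, i=i: forward(v, Ws, bs)[t][i], x)
                  for i in range(dims[t])])
    bracket += (1 + (hs[t] ** 2).sum()) / (np.linalg.norm(Ws[t], 2) ** 2
                                           * np.linalg.norm(J, 2) ** 2)

lhs, rhs = sq_grad_x * bracket, sq_grad_theta
print(f"LHS = {lhs:.6f} <= RHS = {rhs:.6f}: {lhs <= rhs}")
```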
Derivation Process
Obviously, to prove inequality $\eqref{eq:f}$, we only need to prove, for each layer's parameters $\boldsymbol{W}^{(t)}$ and $\boldsymbol{b}^{(t)}$:
\begin{align}\Vert\nabla_{\boldsymbol{x}} f\Vert^2\left(\frac{\Vert \boldsymbol{h}^{(t)}\Vert^2}{\Vert\boldsymbol{W}^{(t)}\Vert^2 \Vert\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(t)}\Vert^2}\right)\leq&\, \Vert\nabla_{\boldsymbol{W}^{(t)}} f\Vert^2 \label{eq:w}\\
\Vert\nabla_{\boldsymbol{x}} f\Vert^2\left(\frac{1}{\Vert\boldsymbol{W}^{(t)}\Vert^2 \Vert\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(t)}\Vert^2}\right)\leq&\, \Vert\nabla_{\boldsymbol{b}^{(t)}} f\Vert^2 \label{eq:b}
\end{align}
Then, summing both inequalities over all $t$ gives the result (the decomposition used in this summation is spelled out below). Proving these two inequalities is essentially a matrix calculus problem; however, readers who, like me, are not very familiar with matrix calculus can simply write everything out in component form, which turns it into a scalar calculus problem.
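To make the summation step explicit: since $\boldsymbol{\theta}$ is just the collection of all the $\boldsymbol{W}^{(t)}$ and $\boldsymbol{b}^{(t)}$, the squared parameter gradient norm decomposes as
\begin{equation}\Vert\nabla_{\boldsymbol{\theta}} f\Vert^2 = \sum_{t=1}^{l}\left(\Vert\nabla_{\boldsymbol{W}^{(t)}} f\Vert^2 + \Vert\nabla_{\boldsymbol{b}^{(t)}} f\Vert^2\right)\end{equation}
so adding $\eqref{eq:w}$ and $\eqref{eq:b}$ over $t=1,2,\cdots,l$ yields exactly $\eqref{eq:f}$.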
Specifically, $\boldsymbol{z}^{(t+1)}=\boldsymbol{W}^{(t)}\boldsymbol{h}^{(t)}+\boldsymbol{b}^{(t)}$ written in component form is:
\begin{equation}z^{(t+1)}_i = \sum_j w^{(t)}_{i,j} h_j^{(t)} + b^{(t)}_i\end{equation}
Then, by the chain rule:
\begin{equation}\frac{\partial f}{\partial x_i} = \sum_{j,k} \frac{\partial f}{\partial z^{(t+1)}_j} \frac{\partial z^{(t+1)}_j}{\partial h^{(t)}_k} \frac{\partial h^{(t)}_k}{\partial x_i} = \sum_{j,k} \frac{\partial f}{\partial z^{(t+1)}_j} w^{(t)}_{j,k} \frac{\partial h^{(t)}_k}{\partial x_i}\label{eq:l}\end{equation}
And
\begin{equation}\frac{\partial z^{(t+1)}_j}{\partial w^{(t)}_{m,n}} = \delta_{j,m}h^{(t)}_n\end{equation}
Here $\delta_{j,m}$ is the Kronecker delta. Then, for any fixed index $n$ such that $h^{(t)}_n\neq 0$, we can write
\begin{equation}w^{(t)}_{j,k} = \sum_m \delta_{j,m}w^{(t)}_{m,k} = \sum_m \frac{\partial z^{(t+1)}_j}{\partial w^{(t)}_{m,n}} (h^{(t)}_n)^{-1} w^{(t)}_{m,k}\end{equation}
Substituting this into $\eqref{eq:l}$ and using the chain rule $\frac{\partial f}{\partial w^{(t)}_{m,n}} = \sum_j \frac{\partial f}{\partial z^{(t+1)}_j}\frac{\partial z^{(t+1)}_j}{\partial w^{(t)}_{m,n}}$ for the second equality, we get
\begin{equation}\frac{\partial f}{\partial x_i} = \sum_{j,k,m} \frac{\partial f}{\partial z^{(t+1)}_j} \frac{\partial z^{(t+1)}_j}{\partial w^{(t)}_{m,n}} (h^{(t)}_n)^{-1} w^{(t)}_{m,k} \frac{\partial h^{(t)}_k}{\partial x_i}=\sum_{k,m} \frac{\partial f}{\partial w^{(t)}_{m,n}} (h^{(t)}_n)^{-1} w^{(t)}_{m,k} \frac{\partial h^{(t)}_k}{\partial x_i}\end{equation}
Multiplying both sides by $h^{(t)}_n$ gives
\begin{equation}h^{(t)}_n\frac{\partial f}{\partial x_i} = \sum_{k,m} \frac{\partial f}{\partial w^{(t)}_{m,n}} w^{(t)}_{m,k} \frac{\partial h^{(t)}_k}{\partial x_i}\end{equation}
Taking all vectors to be column vectors and adopting the convention that the gradient with respect to a matrix has the transposed shape (i.e., $(\nabla_{\boldsymbol{W}^{(t)}} f)_{n,m} = \partial f/\partial w^{(t)}_{m,n}$), the above can be written in matrix form as:
\begin{equation}\boldsymbol{h}^{(t)}(\nabla_{\boldsymbol{x}} f)^\top = (\nabla_{\boldsymbol{W}^{(t)}} f ) \boldsymbol{W}^{(t)}(\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(t)})\end{equation}
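Before continuing, here is a quick numerical sanity check of this identity (and of the transposed-gradient convention) for the second layer of a tiny two-layer tanh network, with the gradients written out by hand. The layer sizes and the choice $f=\sum_i h^{(3)}_i$ are illustrative assumptions of mine.

```python
import numpy as np

rng = np.random.default_rng(1)
dx, d2, d3 = 4, 5, 3

# Two tanh layers: h1 = x, h2 = tanh(W1 h1 + b1), h3 = tanh(W2 h2 + b2), and f = sum(h3).
W1, b1 = rng.normal(size=(d2, dx)), rng.normal(size=d2)
W2, b2 = rng.normal(size=(d3, d2)), rng.normal(size=d3)
x = rng.normal(size=dx)

z2 = W1 @ x + b1
h2 = np.tanh(z2)
z3 = W2 @ h2 + b2

g3 = 1 - np.tanh(z3) ** 2                   # ∂f/∂z^(3) for f = sum(h^(3))
J2 = (1 - np.tanh(z2) ** 2)[:, None] * W1   # ∇_x h^(2), shape (d2, dx)
grad_x_f = J2.T @ W2.T @ g3                 # ∇_x f, shape (dx,)
grad_W2_f = np.outer(h2, g3)                # transposed convention: entry (n, m) = ∂f/∂(W2)_{m,n}

lhs = np.outer(h2, grad_x_f)                # h^(2) (∇_x f)^T
rhs = grad_W2_f @ W2 @ J2                   # (∇_{W^(2)} f) W^(2) (∇_x h^(2))
print(np.allclose(lhs, rhs))                # expected output: True
```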
Returning to the derivation, multiplying both sides of this identity on the left by $(\boldsymbol{h}^{(t)})^{\top}$ gives
\begin{equation}\Vert\boldsymbol{h}^{(t)}\Vert^2(\nabla_{\boldsymbol{x}} f)^\top = (\boldsymbol{h}^{(t)})^{\top}(\nabla_{\boldsymbol{W}^{(t)}} f ) \boldsymbol{W}^{(t)}(\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(t)})\end{equation}
Taking the norm of both sides yields
\begin{equation}\Vert\boldsymbol{h}^{(t)}\Vert^2 \Vert\nabla_{\boldsymbol{x}} f\Vert = \Vert (\boldsymbol{h}^{(t)})^{\top}(\nabla_{\boldsymbol{W}^{(t)}} f ) \boldsymbol{W}^{(t)}(\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(t)})\Vert \leq \Vert\boldsymbol{h}^{(t)}\Vert \Vert\nabla_{\boldsymbol{W}^{(t)}} f \Vert \Vert \boldsymbol{W}^{(t)}\Vert \Vert\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(t)}\Vert\end{equation}
The inequality sign above holds whether the matrix norms are taken to be Frobenius norms ($l_2$ norms) or spectral norms. Choosing the norms required by $\eqref{eq:w}$ (the Frobenius norm for $\nabla_{\boldsymbol{W}^{(t)}} f$, the spectral norm for $\boldsymbol{W}^{(t)}$ and $\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(t)}$), dividing both sides by $\Vert\boldsymbol{h}^{(t)}\Vert$, squaring, and rearranging then gives formula $\eqref{eq:w}$. The proof of formula $\eqref{eq:b}$ is similar and will not be repeated here.
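The bound above only uses norm submultiplicativity together with the fact that the spectral norm never exceeds the Frobenius norm. Here is a quick randomized check of these standard facts (my own illustrative snippet, using NumPy):

```python
import numpy as np

rng = np.random.default_rng(2)

def spec(M):
    return np.linalg.norm(M, 2)  # spectral norm (largest singular value)

def fro(M):
    return np.linalg.norm(M)     # Frobenius norm

for _ in range(1000):
    A, B = rng.normal(size=(4, 6)), rng.normal(size=(6, 5))
    assert spec(A @ B) <= spec(A) * spec(B) + 1e-12  # submultiplicativity in the spectral norm
    assert fro(A @ B) <= fro(A) * spec(B) + 1e-12    # mixed Frobenius/spectral version
    assert spec(A) <= fro(A) + 1e-12                 # spectral norm never exceeds the Frobenius norm
print("all norm checks passed")
```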
Simple Analysis
Some readers may wonder how exactly formula $\eqref{eq:f}$ should be interpreted. In fact, I mainly find formula $\eqref{eq:f}$ interesting in its own right and suspect it may come in handy in some future scenario, so this article is primarily a "note" on it; I do not intend to provide a definitive interpretation.
As for the original paper, its chain of reasoning goes like this: in "Optimization Algorithms from a Dynamical Perspective (V): Why the Learning Rate Should Not Be Too Small?", we introduced "Implicit Gradient Regularization" (by the same authors as this paper), which pointed out that SGD implicitly contains a gradient penalty on the parameters. Formula $\eqref{eq:f}$ shows that the parameter gradient penalty in turn implicitly contains a gradient penalty on the inputs. Furthermore, the input gradient penalty is related to the Dirichlet energy, which can serve as a measure of model complexity. Following this chain of reasoning, the conclusion is that SGD itself tends to select models of relatively low complexity.
However, the original paper made a small mistake when interpreting formula $\eqref{eq:f}$: it argued that in the initial stage $\Vert \boldsymbol{W}^{(t)}\Vert$ is very close to 0, so the bracketed terms in formula $\eqref{eq:f}$ are very large, and hence keeping the parameter gradient penalty on the right-hand side small forces the input gradient penalty on the left-hand side to be small. However, as we know from "Understanding Model Parameter Initialization Strategies from a Geometric Perspective", commonly used initialization methods are actually close to orthogonal initialization, and the spectral norm of an orthogonal matrix is 1; if the effect of the activation functions is also taken into account, the relevant spectral norm at initialization is even greater than 1. So the assumption that $\Vert \boldsymbol{W}^{(t)}\Vert$ is very close to 0 during the initialization phase does not hold.
In fact, for a network that does not collapse during training, the model parameters and the inputs and outputs of each layer generally remain stable, so throughout training $\Vert \boldsymbol{h}^{(t)}\Vert$, $\Vert\boldsymbol{W}^{(t)}\Vert$, and $\Vert\nabla_{\boldsymbol{x}}\boldsymbol{h}^{(t)}\Vert$ do not actually fluctuate much. Thus, the parameter gradient penalty on the right-hand side is roughly equivalent to the input gradient penalty on the left-hand side. This is my interpretation, and it does not require the assumption that "$\Vert \boldsymbol{W}^{(t)}\Vert$ is very close to 0".
Summary
This article introduced an inequality relating the two types of gradient penalty terms, gave my own proof of it, and offered a brief analysis.