Understanding Model Parameter Initialization Strategies from a Geometric Perspective

By 苏剑林 | January 16, 2020

For complex models, parameter initialization is particularly important. A poor initialization is often not merely a matter of degraded performance; more likely, the model will simply fail to train or converge at all. A common adaptive initialization strategy in deep learning is Xavier initialization, which samples the weights from a normal distribution $\mathcal{N}\left(0, \frac{2}{fan_{in} + fan_{out}}\right)$, where $fan_{in}$ is the input dimension and $fan_{out}$ is the output dimension. Other initialization strategies are generally similar; they differ in their assumptions, leading to slight variations in their final forms.
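
As a quick illustration, here is a minimal NumPy sketch of this sampling rule (the layer sizes below are chosen purely for illustration):

import numpy as np

fan_in, fan_out = 256, 128  # example layer sizes, assumed for illustration
std = np.sqrt(2.0 / (fan_in + fan_out))     # Xavier/Glorot standard deviation
W = np.random.randn(fan_out, fan_in) * std  # weights ~ N(0, 2/(fan_in + fan_out))
print(W.var())  # empirical variance, close to 2/(fan_in + fan_out)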

The derivation of standard initialization strategies is based on probability and statistics. The general idea is to assume that the input data has a mean of 0 and a variance of 1, and then require the output data to also maintain a mean of 0 and a variance of 1, thereby deriving the conditions on the mean and variance that the initial transformation should satisfy. Theoretically, there is nothing wrong with this process, but in the author's view it is still not intuitive enough, and the derivation involves quite a few assumptions. This article aims to understand model initialization methods from a geometric perspective, offering a more direct derivation.

Orthogonality at Your Fingertips

Some time ago, the author wrote "Distribution of the Angle Between Two Random Vectors in n-Dimensional Space". One of the corollaries is:

Corollary 1: Any two random vectors in high-dimensional space are almost always perpendicular.

In fact, Corollary 1 is the starting point for the entire geometric perspective of this article! A further corollary of it is:

Corollary 2: If we randomly select $n^2$ numbers from $\mathcal{N}(0, 1/n)$ to form an $n \times n$ matrix, this matrix is approximately an orthogonal matrix, and the larger $n$ is, the better the approximation.

Readers who find this hard to believe can verify it numerically:

import numpy as np

n = 100
W = np.random.randn(n, n) / np.sqrt(n)
X = np.dot(W.T, W) # Compute W^T W
print(X) # Check if it is close to the identity matrix
print(np.square(X - np.eye(n)).mean()) # Calculate the MSE with the identity matrix

I believe that for most readers, seeing Corollary 2 for the first time will be somewhat surprising. An orthogonal matrix is a matrix that satisfies $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$, meaning its inverse is equal to its transpose. Solving for the inverse of a general matrix is significantly more difficult than finding its transpose, so we tend to feel that the condition "inverse = transpose" should be very stringent. However, Corollary 2 tells us that a matrix sampled randomly is already close to being an orthogonal matrix, which is admittedly a bit counter-intuitive. When the author first realized this, it was also quite a surprise.

Actually, It's Not That Hard to Understand

However, once we get used to the fact in Corollary 1 that "any two random vectors in high-dimensional space are almost always perpendicular," we can quickly understand and derive this result. For a quick derivation, first consider the standard normal distribution $\mathcal{N}(0,1)$. Note that Corollary 1 requires the sampling direction to be uniform, and the standard normal distribution satisfies this requirement exactly. If we sample an $n \times n$ matrix from $\mathcal{N}(0,1)$, we can view it as $n$ vectors of dimension $n$. Since these $n$ vectors are all random, they are naturally close to being pairwise orthogonal.

Of course, being pairwise orthogonal does not yet make an orthogonal matrix, because an orthogonal matrix also requires each vector to have a length (norm) of 1. Since $\mathbb{E}_{x\sim \mathcal{N}(0,1)}\left[x^2\right]=1$, an $n$-dimensional vector sampled from $\mathcal{N}(0,1)$ has a squared length of approximately $n$, i.e. a length of approximately $\sqrt{n}$. To approach orthogonality, we need to divide each element by $\sqrt{n}$, which is equivalent to changing the sampling variance from 1 to $1/n$.
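
A quick numerical check of this length estimate (the dimension below is chosen arbitrarily):

import numpy as np

n = 10000
x = np.random.randn(n)  # an n-dimensional vector sampled from N(0, 1)
print(np.linalg.norm(x), np.sqrt(n))  # the two values should be close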

Furthermore, the sampling distribution does not necessarily have to be a normal distribution; for instance, a uniform distribution $U\left[-\sqrt{3/n}, \sqrt{3/n}\right]$ also works. In fact, we have:

Corollary 3: An $n \times n$ matrix independently and repeatedly sampled from any distribution $p(x)$ with a mean of 0 and a variance of $1/n$ will approach an orthogonal matrix.

We can understand Corollary 3 from a more mathematical perspective: suppose $\boldsymbol{x}=(x_1,x_2,\dots,x_n)$ and $\boldsymbol{y}=(y_1,y_2,\dots,y_n)$ are both sampled from $p(x)$. Thus, we have:

\begin{equation} \begin{aligned} \langle \boldsymbol{x}, \boldsymbol{y}\rangle =&\, n\times \frac{1}{n}\sum_{k=1}^n x_k y_k\\ \approx&\, n\times \mathbb{E}_{x\sim p(x),y\sim p(x)}[xy]\\ =&\, n\times \mathbb{E}_{x\sim p(x)}[x]\times \mathbb{E}_{y\sim p(x)}[y]\\ =&\,0 \end{aligned} \end{equation}

and

\begin{equation} \begin{aligned} \Vert\boldsymbol{x}\Vert^2 =&\, n\times \frac{1}{n}\sum_{k=1}^n x_k^2\\ \approx&\, n\times \mathbb{E}_{x\sim p(x)}\left[x^2\right]\\ =&\, n\times \left(\mu^2 + \sigma^2\right)\\ =&\,1 \end{aligned} \end{equation}

Therefore, any two vectors are approximately orthonormal, and consequently, the sampled matrix is also close to an orthogonal matrix.
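
This can also be checked numerically. Below is a minimal sketch that repeats the earlier experiment with the uniform distribution $U\left[-\sqrt{3/n}, \sqrt{3/n}\right]$ (the dimension is chosen arbitrarily):

import numpy as np

n = 1000
a = np.sqrt(3.0 / n)  # U[-a, a] has mean 0 and variance a^2/3 = 1/n
x = np.random.uniform(-a, a, n)
y = np.random.uniform(-a, a, n)
print(np.dot(x, y))  # approximately 0: the two vectors are nearly orthogonal
print(np.dot(x, x))  # approximately 1: the squared norm is nearly 1

W = np.random.uniform(-a, a, (n, n))
print(np.square(np.dot(W.T, W) - np.eye(n)).mean())  # MSE with the identity matrix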

Now We Can Discuss Initialization

After discussing orthogonal matrices at length, the essence is to lay the groundwork for understanding the geometric significance of initialization methods. If readers still remember linear algebra, they should recall that the critical significance of an orthogonal matrix is that it preserves the norm of a vector during the transformation process. Expressed mathematically, if $\boldsymbol{W}\in \mathbb{R}^{n\times n}$ is an orthogonal matrix and $\boldsymbol{x}\in\mathbb{R}^n$ is any vector, then the norm of $\boldsymbol{x}$ is equal to the norm of $\boldsymbol{W}\boldsymbol{x}$:

\begin{equation}\Vert\boldsymbol{W}\boldsymbol{x}\Vert^2 = \boldsymbol{x}^{\top}\boldsymbol{W}^{\top}\boldsymbol{W}\boldsymbol{x}=\boldsymbol{x}^{\top}\boldsymbol{x}=\Vert\boldsymbol{x}\Vert^2\end{equation}
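
This fact is easy to check numerically; here is a minimal sketch using an exactly orthogonal matrix obtained from a QR decomposition (the dimension is chosen arbitrarily):

import numpy as np

n = 100
W, _ = np.linalg.qr(np.random.randn(n, n))  # Q from a QR decomposition is exactly orthogonal
x = np.random.randn(n)
print(np.linalg.norm(x), np.linalg.norm(np.dot(W, x)))  # the two norms coincide (up to rounding)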

Consider a fully connected layer:

\begin{equation}\boldsymbol{y}=\boldsymbol{W}\boldsymbol{x} + \boldsymbol{b}\end{equation}

A deep learning model is essentially a nesting of fully connected layers. Therefore, to prevent the final output of the model from "exploding" or "vanishing" during the initialization phase, one idea is to let the model maintain the norm during initialization.

The natural initialization strategy derived from this idea is: "initialize $\boldsymbol{b}$ as all zeros, and initialize $\boldsymbol{W}$ with a random orthogonal matrix." Corollary 2 has already told us that an $n \times n$ matrix sampled from $\mathcal{N}(0, 1/n)$ is close to being an orthogonal matrix, so we can sample from $\mathcal{N}(0, 1/n)$ to initialize $\boldsymbol{W}$. This is the Xavier initialization strategy, also called Glorot initialization in some frameworks after its author, Xavier Glorot. Additionally, the sampling distribution does not strictly have to be $\mathcal{N}(0, 1/n)$; as Corollary 3 states, any distribution with a mean of 0 and a variance of $1/n$ will do.
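
A minimal sketch of this initialization for a square fully connected layer (the dimension below is assumed purely for illustration):

import numpy as np

n = 512                                 # input and output dimension, chosen for illustration
W = np.random.randn(n, n) / np.sqrt(n)  # weights ~ N(0, 1/n), approximately orthogonal
b = np.zeros(n)                         # bias initialized to all zeros

x = np.random.randn(n)
y = np.dot(W, x) + b
print(np.linalg.norm(x), np.linalg.norm(y))  # the norms should be close at initialization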

The above discussion pertains to cases where both the input and output dimensions are $n$. What if the input is $n$-dimensional and the output is $m$-dimensional? In this case, $\boldsymbol{W}\in\mathbb{R}^{m\times n}$, and the condition for preserving the norm of $\boldsymbol{W}\boldsymbol{x}$ is still $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$. However, when $m < n$, this is impossible (the rank of $\boldsymbol{W}^{\top}\boldsymbol{W}$ is at most $m$, so it cannot equal the $n \times n$ identity); when $m \geq n$, it is possible, and following a similar derivation as before, we obtain:

Corollary 4: When $m \geq n$, an $m \times n$ matrix sampled independently and repeatedly from any distribution $p(x)$ with a mean of 0 and a variance of $1/m$ approximately satisfies $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$.

Thus, if $m > n$, one simply needs to change the sampling variance to $1/m$. As for the case $m < n$, although there is no direct derivation, the same practice can still be followed, since a reasonable strategy should be universal. Note that this modification differs slightly from the original design of Xavier initialization; it is the dual of "LeCun initialization" (whose variance is $1/n$), whereas the Xavier variance of $2/(m+n)$ is a compromise that averages the requirements of forward and backward propagation. Here we mainly consider forward propagation.
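
A numerical check of Corollary 4 for the case $m > n$ (the dimensions below are chosen arbitrarily):

import numpy as np

m, n = 2000, 500                        # output and input dimensions, m >= n
W = np.random.randn(m, n) / np.sqrt(m)  # entries ~ N(0, 1/m)
print(np.square(np.dot(W.T, W) - np.eye(n)).mean())  # small: W^T W is close to the identity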

Some readers might still wonder: you have only considered scenarios without activation functions; even if the norm of $\boldsymbol{y}$ is the same as that of $\boldsymbol{x}$, it will change after passing through an activation function. This is indeed the case, and such situations must be analyzed on a case-by-case basis. For example, since $\tanh(x) \approx x$ when $x$ is small, Xavier initialization can be considered directly applicable to $\tanh$ activation. For $\text{relu}$, we can assume that approximately half of the elements in $\text{relu}(\boldsymbol{y})$ will be set to zero, so the norm will be approximately $1/\sqrt{2}$ of the original. To keep the norm unchanged, $\boldsymbol{W}$ can be multiplied by $\sqrt{2}$, meaning the initialization variance changes from $1/m$ to $2/m$. This is the initialization strategy proposed by the great Kaiming He specifically for $\text{relu}$.
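
A rough numerical illustration of the $\sqrt{2}$ correction (only a sketch: the dimensions are chosen arbitrarily, and the relation holds only approximately, in expectation):

import numpy as np

m, n = 1000, 1000
x = np.random.randn(n)
relu = lambda z: np.maximum(z, 0)

W1 = np.random.randn(m, n) / np.sqrt(m)        # variance 1/m, no correction
W2 = np.random.randn(m, n) * np.sqrt(2.0 / m)  # variance 2/m, the sqrt(2) correction for relu
print(np.linalg.norm(x))                    # the original norm
print(np.linalg.norm(relu(np.dot(W1, x))))  # roughly norm(x) / sqrt(2)
print(np.linalg.norm(relu(np.dot(W2, x))))  # roughly norm(x)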

Of course, it is practically difficult to adjust the variance perfectly for every single activation function. A more general approach is to add an operation similar to Layer Normalization directly after the activation function to explicitly restore the norm. This is where various Normalization techniques come into play (readers are welcome to continue with the previous work "What Exactly Does BN Do? A Behind-the-Scenes Analysis").

A Final Summary

This article primarily derives the conclusion that "any $n \times n$ matrix sampled from a distribution with mean 0 and variance $1/n$ is close to being an orthogonal matrix" from the premise that "any two random vectors in high-dimensional space are almost always perpendicular," thereby providing a geometric perspective on the related initialization strategies. I believe this geometric perspective is more intuitive and easier to understand than a purely statistical one.