Thoughts on Dimension Averaging Strategies for Non-Square Matrices in Initialization Methods

By 苏剑林 | Oct 18, 2021

In articles such as "Understanding Model Parameter Initialization Strategies from a Geometric Perspective" and "A Brief Discussion on Transformer Initialization, Parameterization, and Normalization", we discussed initialization strategies for models. The general idea is: if an $n \times n$ square matrix is initialized with independent and identically distributed (i.i.d.) values with a mean of 0 and a variance of $1/n$, it approximates an orthogonal matrix, allowing the second moment (or variance) of the data to remain roughly constant during propagation.
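
As a quick sanity check (my own illustration, not from the original articles; the dimension $n=1024$ is chosen arbitrarily), the following NumPy sketch verifies that such a matrix roughly preserves vector norms, which is the sense in which it "approximates an orthogonal matrix":

```python
import numpy as np

np.random.seed(0)
n = 1024
# i.i.d. entries with mean 0 and variance 1/n
W = np.random.normal(0.0, np.sqrt(1.0 / n), size=(n, n))
x = np.random.normal(0.0, 1.0, size=n)

# A matrix close to orthogonal preserves norms, so this ratio should be near 1
print(np.linalg.norm(x @ W) / np.linalg.norm(x))
```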

But what if it is an $m \times n$ non-square matrix? The common approach (Xavier initialization) considers both forward and backward propagation together, thus using an i.i.d. initialization with mean 0 and variance $2/(m+n)$. However, this "averaging" is somewhat arbitrary. This article explores whether there might be better averaging schemes.

Basic Review

Xavier initialization considers a fully connected layer as follows (assuming $m$ input nodes and $n$ output nodes):

\begin{equation} y_j = b_j + \sum_i x_i w_{i,j}\end{equation}

where $b_j$ is typically initialized to 0, and the initial mean of $w_{i,j}$ is also generally 0. In "A Brief Discussion on Transformer Initialization, Parameterization, and Normalization", we calculated:

\begin{equation} \mathbb{E}[y_j^2] = \sum_{i} \mathbb{E}[x_i^2] \mathbb{E}[w_{i,j}^2]= m\mathbb{E}[x_i^2]\mathbb{E}[w_{i,j}^2]\end{equation}

To keep the second moment constant, we set the initialization variance of $w_{i,j}$ to $1/m$ (when the mean is 0, the variance equals the second moment).
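
A minimal numerical sketch of this (the dimensions and the unit-variance inputs are my own assumptions for illustration):

```python
import numpy as np

np.random.seed(0)
m, n = 768, 3072
x = np.random.normal(0.0, 1.0, size=(1000, m))             # E[x_i^2] = 1
W = np.random.normal(0.0, np.sqrt(1.0 / m), size=(m, n))   # variance 1/m
y = x @ W

# Both printed second moments should be close to 1: the forward pass is stable
print(np.mean(x ** 2), np.mean(y ** 2))
```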

However, this derivation only considers forward propagation. We also need to ensure that the model has reasonable gradients, which means the model must also remain stable during backward propagation. Let the model's loss function be $l$; according to the chain rule, we have:

\begin{equation}\frac{\partial l}{\partial x_i} = \sum_j \frac{\partial l}{\partial y_j} \frac{\partial y_j}{\partial x_i}=\sum_j \frac{\partial l}{\partial y_j} w_{i,j}\end{equation}

Note that the sum is now over $j$, which runs over the $n$ output dimensions. Under the same assumptions, we have:

\begin{equation} \mathbb{E}\left[\left(\frac{\partial l}{\partial x_i}\right)^2\right] = \sum_{j} \mathbb{E}\left[\left(\frac{\partial l}{\partial y_j}\right)^2\right] \mathbb{E}[w_{i,j}^2]= n \mathbb{E}\left[\left(\frac{\partial l}{\partial y_j}\right)^2\right]\mathbb{E}[w_{i,j}^2]\end{equation}

To keep the second moment constant during backward propagation, we set the initialization variance of $w_{i,j}$ to $1/n$.
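
The backward counterpart of the earlier sketch (again with assumed dimensions and unit-variance upstream gradients): the gradient with respect to $x$ is $(\partial l/\partial y)W^{\top}$, a sum over the $n$ output dimensions, which is why $1/n$ appears here instead of $1/m$.

```python
import numpy as np

np.random.seed(0)
m, n = 768, 3072
grad_y = np.random.normal(0.0, 1.0, size=(1000, n))         # E[(dl/dy_j)^2] = 1
W = np.random.normal(0.0, np.sqrt(1.0 / n), size=(m, n))    # variance 1/n
grad_x = grad_y @ W.T

# Both printed second moments should be close to 1: the backward pass is stable
print(np.mean(grad_y ** 2), np.mean(grad_x ** 2))
```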

One result calls for $1/m$ and the other for $1/n$; when $m \neq n$, the two conflict. Since forward and backward propagation are equally important, Xavier initialization simply takes the arithmetic mean of the two dimensions, $(m+n)/2$, and initializes with variance $2/(m+n)$.
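
In code, this is just the following (a minimal sketch; the helper name `xavier_normal` is mine, and PyTorch's `torch.nn.init.xavier_normal_` uses the same $\text{std}=\sqrt{2/(m+n)}$ at its default gain):

```python
import numpy as np

def xavier_normal(m, n, rng=None):
    """Sample an m x n weight matrix with i.i.d. N(0, 2/(m+n)) entries."""
    if rng is None:
        rng = np.random.default_rng(0)
    return rng.normal(0.0, np.sqrt(2.0 / (m + n)), size=(m, n))

W = xavier_normal(768, 3072)
print(W.var())  # close to 2/(768 + 3072) ≈ 5.2e-4
```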

Geometric Mean

Now let us consider the composition of two fully connected layers (ignoring bias terms for now):

\begin{equation} y = xW_1 W_2 \end{equation}

where $x \in \mathbb{R}^m, W_1 \in \mathbb{R}^{m \times n}, W_2 \in \mathbb{R}^{n \times m}$. That is to say, the input is $m$-dimensional, transformed to $n$ dimensions, and then transformed back to $m$ dimensions. Similar operations can be found in, for example, the FFN layer of BERT (though the FFN layer has an activation function in the middle).

For forward propagation to be stable, we should initialize $W_1$ with a variance of $1/m$ and $W_2$ with a variance of $1/n$. But suppose we require $W_1$ and $W_2$ to share the same initialization variance $\sigma^2$: the second moment is multiplied by $m\sigma^2$ in the first layer and by $n\sigma^2$ in the second, so keeping the second moments of $x$ and $y$ equal requires $mn\sigma^4 = 1$, i.e. both $W_1$ and $W_2$ should be initialized with a variance of $1/\sqrt{mn}$. Considering backward propagation leads to the same result.

In this way, we arrive at a new dimension averaging strategy: the geometric mean $\sqrt{mn}$. With this strategy, in a composition of layers whose overall input and output dimensions coincide, the second moment remains constant (for both forward and backward propagation). If we instead used the arithmetic mean $(m+n)/2$ with $m \neq n$, then since $(m+n)^2/4 > mn$, the combined factor $4mn/(m+n)^2$ is less than 1, and the second moment would shrink during forward/backward propagation.
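
The following sketch (my own comparison, with BERT-like dimensions $m=768$, $n=3072$) makes the contrast concrete: the geometric-mean variance $1/\sqrt{mn}$ keeps the second moment of $y = xW_1W_2$ close to that of $x$, while the Xavier variance $2/(m+n)$ shrinks it.

```python
import numpy as np

np.random.seed(0)
m, n = 768, 3072
x = np.random.normal(0.0, 1.0, size=(1000, m))  # E[x_i^2] = 1

for name, var in [("geometric mean, var = 1/sqrt(mn)", 1.0 / np.sqrt(m * n)),
                  ("arithmetic mean, var = 2/(m+n)",   2.0 / (m + n))]:
    W1 = np.random.normal(0.0, np.sqrt(var), size=(m, n))
    W2 = np.random.normal(0.0, np.sqrt(var), size=(n, m))
    y = x @ W1 @ W2
    # prints ≈ 1.0 for the geometric mean and ≈ 4mn/(m+n)^2 ≈ 0.64 for the arithmetic mean
    print(name, np.mean(y ** 2))
```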

Quadratic Mean

Another perspective is to treat the choice as a trade-off between two objectives. Suppose the chosen variance is $t$: for forward propagation we want $(mt-1)^2$ to be as small as possible, and for backward propagation we want $(nt-1)^2$ to be as small as possible, so we consider minimizing their sum:

\begin{equation}(mt-1)^2 + (nt-1)^2 \end{equation}

Setting the derivative to zero, $2m(mt-1) + 2n(nt-1) = 0$, gives $t = (m+n)/(m^2+n^2)$, at which the above expression attains its minimum. This yields a quadratic-fraction averaging scheme: the effective dimension is $(m^2+n^2)/(m+n)$.
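
A quick numerical confirmation of the minimizer (illustrative only; the grid and dimensions are mine):

```python
import numpy as np

m, n = 768, 3072
t_grid = np.linspace(0.0, 2.0 / m, 200001)
loss = (m * t_grid - 1) ** 2 + (n * t_grid - 1) ** 2
print(t_grid[np.argmin(loss)])      # grid-search minimizer
print((m + n) / (m ** 2 + n ** 2))  # closed form; agrees up to grid resolution
```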

It is easy to prove:

\begin{equation}\frac{m^2+n^2}{m+n} \geq \frac{m+n}{2}\geq \sqrt{mn}\end{equation}

From the derivations above, the quadratic-fraction average on the left aims to keep the second moment of each individual forward and backward step as constant as possible, so it can be viewed as a locally optimal solution. The geometric mean on the right aims to keep the second moments of the "initial input" and "final output" as constant as possible, so in some sense it can be viewed as a globally optimal solution. The arithmetic mean in the middle lies between the local and global optima.
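
A small numeric illustration of this trade-off (dimensions chosen by me, roughly a BERT-base FFN): for each average $d$, take $t = 1/d$ and look at the per-layer factors $mt$ and $nt$ and the composite two-layer factor $mnt^2$.

```python
m, n = 768, 3072
schemes = {
    "quadratic-fraction (m^2+n^2)/(m+n)": (m ** 2 + n ** 2) / (m + n),
    "arithmetic (m+n)/2":                 (m + n) / 2,
    "geometric sqrt(mn)":                 (m * n) ** 0.5,
}
for name, avg in schemes.items():
    t = 1.0 / avg
    # quadratic-fraction keeps m*t and n*t jointly closest to 1 (local view);
    # geometric keeps the composite m*n*t^2 exactly at 1 (global view)
    print(f"{name}: m*t={m * t:.3f}, n*t={n * t:.3f}, m*n*t^2={m * n * t * t:.3f}")
```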

Looking at it this way, it seems that the "arbitrary" arithmetic mean of Xavier initialization might actually be a practical choice following the "Middle Way"?

Summary

This article briefly considers dimension averaging schemes for non-square matrices in initialization methods. For a long time, it seems that no one has questioned the default arithmetic mean, but here I have derived different averaging strategies from two different perspectives. As for which strategy is better, I have not conducted detailed experiments; interested readers can explore this on their own. Of course, it is also possible that under current optimization strategies the default initialization scheme already works well enough, making further tweaking of the averaging rule unnecessary.