What Should "KL Divergence" Look Like Under GlobalPointer?

By 苏剑林 | April 15, 2022

Recently, a reader mentioned wanting to test the combined effect of GlobalPointer and R-Drop, but was unsure how to calculate the KL divergence under GlobalPointer. Regularization techniques like R-Drop and Virtual Adversarial Training require computing the KL divergence between probability distributions; however, since the output of GlobalPointer is not a probability distribution, the KL divergence cannot be computed directly.

After some exploration, I have identified a usable form and verified its feasibility through simple experiments. This post walks through my analysis.

Symmetric Divergence

KL divergence is a function of two probability distributions. It is asymmetric, meaning $KL(p\Vert q)$ is generally not equal to $KL(q\Vert p)$. In practical applications, we usually use the symmetrized KL divergence:

\begin{equation}D(p,q) = KL(p\Vert q) + KL(q\Vert p)\end{equation}

Substituting the definition of KL divergence $KL(p\Vert q)=\sum\limits_i p_i\log\frac{p_i}{q_i}$, we can simplify it to obtain:

\begin{equation}D(p,q) = \sum_i (p_i - q_i)(\log p_i - \log q_i)\end{equation}

Considering that $p,q$ are usually obtained via softmax, we define:

\begin{equation}p_i = \frac{e^{s_i}}{\sum\limits_j e^{s_j}},\quad q_i = \frac{e^{t_i}}{\sum\limits_j e^{t_j}}\end{equation}

Substituting these in, we get:

\begin{equation}\begin{aligned} D(p,q) =&\, \sum_i (p_i - q_i)(s_i - t_i) + \sum_i (p_i - q_i)\left(\log\sum_j e^{t_j} - \log\sum_j e^{s_j}\right) \\ =&\, \sum_i (p_i - q_i)(s_i - t_i) + \left(\sum_i p_i - \sum_i q_i\right)\left(\log\sum_j e^{t_j} - \log\sum_j e^{s_j}\right) \\ =&\, \sum_i (p_i - q_i)(s_i - t_i) \end{aligned}\label{eq:kl-0}\end{equation}
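
As a quick sanity check, here is a minimal NumPy snippet (my own illustration, not part of the original derivation; the logits are arbitrary made-up values) verifying that the symmetric KL divergence of two softmax distributions equals the dot product of $p - q$ with $s - t$:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Arbitrary example logits, purely for illustration
s = np.array([1.2, -0.3, 0.7, 2.1])
t = np.array([0.5,  0.9, -1.0, 1.4])
p, q = softmax(s), softmax(t)

sym_kl = np.sum(p * np.log(p / q)) + np.sum(q * np.log(q / p))  # KL(p||q) + KL(q||p)
via_logits = np.dot(p - q, s - t)                               # the logits form above

print(np.isclose(sym_kl, via_logits))  # True
```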

Analogous Result

As we can see, from the perspective of logits, symmetric KL divergence takes the following form:

\begin{equation}D(s, t) = \sum_i (f(s_i) - f(t_i))(s_i - t_i) = \langle f(s) - f(t), s - t \rangle\label{eq:kl}\end{equation}

where $f$ is the softmax operation, and $\langle\cdot,\cdot\rangle$ denotes the vector dot product. Formally, it is the dot product of two vectors: one is the difference of the logits, and the other is the difference of the logits after being transformed by $f$. What characterizes the transformation $f$? We know that softmax is a smooth approximation of $\text{onehot}(\text{argmax}(\cdot))$ (refer to "Talk on Function Smoothing: Differentiable Approximation of Non-differentiable Functions"). For classification, the class with the maximum value is the one to be output, so $f$ is ultimately a smooth approximation of "setting the target class to 1 and the non-target classes to 0."
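
As a small aside (my own illustration, not from the original text), the following snippet shows this smooth-approximation property numerically: scaling the logits up makes softmax concentrate on the argmax position.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

s = np.array([1.0, 2.5, 0.3])
for scale in (1, 5, 50):
    print(scale, np.round(softmax(scale * s), 3))
# scale = 1  -> a soft distribution over the classes
# scale = 50 -> essentially [0, 1, 0], i.e. onehot(argmax(s))
```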

With this abstract perspective, we can analogously construct a "KL divergence" for GlobalPointer. The output of GlobalPointer can also be understood as logits, but the loss function it uses is the multi-label cross-entropy proposed in "Generalizing 'Softmax + Cross Entropy' to Multi-label Classification Problems." The question therefore essentially becomes how to calculate a KL divergence under multi-label cross-entropy. In GlobalPointer, the target categories are not necessarily the classes with the largest logits, but rather all categories whose logits are greater than 0.

So, for GlobalPointer, the symmetric divergence can retain the form of Equation $\eqref{eq:kl}$, but $f$ should be replaced with a smooth approximation of "setting values greater than 0 to 1 and values less than 0 to 0." The sigmoid function $\sigma(x)=1/(1+e^{-x})$ is exactly such a function. Therefore, we can design the symmetric KL divergence for GlobalPointer as:

\begin{equation}D(s, t) = \sum_i (\sigma(s_i) - \sigma(t_i))(s_i - t_i) = \langle \sigma(s) - \sigma(t), s - t \rangle\label{eq:gp-kl}\end{equation}
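
In code, Equation $\eqref{eq:gp-kl}$ is essentially a one-liner. Here is a PyTorch sketch (the function name and the reduction over all logit positions are my own choices, not prescribed by the original post):

```python
import torch

def gp_symmetric_kl(s: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """<sigma(s) - sigma(t), s - t>, summed over every logit position."""
    return ((torch.sigmoid(s) - torch.sigmoid(t)) * (s - t)).sum()
```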

A Breakthrough

Interestingly, I later discovered that Equation $\eqref{eq:gp-kl}$ is actually equivalent to applying the $\sigma$ activation to each logit separately, calculating the symmetric KL divergence of each resulting binary probability distribution, and then summing over the logits.

Proving this is simple. Note that the binary distribution $[\sigma(s), 1 - \sigma(s)]$ constructed by the $\sigma$ function is equivalent to the binary distribution constructed by using $[s, 0]$ as logits with softmax, i.e., $[\sigma(s), 1 - \sigma(s)] = \text{softmax}([s, 0])$. Therefore, according to formula $\eqref{eq:kl-0}$, we directly have:

\begin{equation}\begin{aligned} &\,D\big([\sigma(s_i),1 - \sigma(s_i)],[\sigma(t_i),1 - \sigma(t_i)]\big) \\ =&\,(\sigma(s_i)-\sigma(t_i))(s_i - t_i) + \big((1-\sigma(s_i))-(1-\sigma(t_i))\big)(0 - 0)\\ =&\,(\sigma(s_i)-\sigma(t_i))(s_i - t_i) \end{aligned}\end{equation}

Summing over all components then gives Equation $\eqref{eq:gp-kl}$.
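
This equivalence is easy to confirm numerically. Below is a small NumPy check (my own illustration, with arbitrary made-up logits):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Arbitrary example logits, purely for illustration
s = np.array([ 1.3, -0.8, 0.2, -2.0])
t = np.array([-0.4,  0.6, 1.1, -1.5])
ps, pt = sigmoid(s), sigmoid(t)

# Symmetric KL between the binary distributions [ps, 1-ps] and [pt, 1-pt], per logit
binary_kl = (ps * np.log(ps / pt) + (1 - ps) * np.log((1 - ps) / (1 - pt))
             + pt * np.log(pt / ps) + (1 - pt) * np.log((1 - pt) / (1 - ps)))

via_sigmoid = (ps - pt) * (s - t)  # the GlobalPointer divergence, componentwise

print(np.allclose(binary_kl, via_sigmoid))  # True
```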

This equivalence shows that although treating multi-label classification as multiple binary classification problems introduces class-imbalance issues, no such problem arises when the goal is merely to evaluate the consistency (continuity) of two outputs, since no classification is actually being performed. For this purpose, we can therefore still view the problem as multiple binary classifications and compute the conventional KL divergence for each of them.

Experimental Results

Some readers and I conducted several simple comparative experiments. The results showed that using Equation $\eqref{eq:gp-kl}$ as the KL divergence to apply R-Drop to GlobalPointer indeed yields a slight improvement in performance. Conversely, directly applying softmax to GlobalPointer's logits and then computing the conventional KL divergence actually makes the results worse. This demonstrates the soundness of Equation $\eqref{eq:gp-kl}$.

However, it should be pointed out that Equation $\eqref{eq:gp-kl}$ merely provides a scheme for using R-Drop or Virtual Adversarial Training with GlobalPointer. Whether it improves results in a specific case is not guaranteed, just as R-Drop does not always improve conventional classification problems. This has to be determined experimentally, in particular by tuning the regularization weight coefficient.
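
For concreteness, a hypothetical PyTorch training-step sketch is given below. It assumes a `model` that maps inputs to GlobalPointer logits and a `multilabel_ce` implementing the multi-label cross-entropy mentioned above; neither is defined here, and the names are illustrative only.

```python
import torch

def rdrop_step(model, inputs, labels, multilabel_ce, alpha=1.0):
    # Two forward passes with dropout active give two different sets of logits.
    logits1 = model(inputs)
    logits2 = model(inputs)

    # Task loss: multi-label cross-entropy on both passes.
    ce = multilabel_ce(logits1, labels) + multilabel_ce(logits2, labels)

    # Consistency loss: the symmetric divergence between the two sets of logits.
    kl = ((torch.sigmoid(logits1) - torch.sigmoid(logits2))
          * (logits1 - logits2)).sum()

    # alpha is the regularization weight coefficient that usually needs tuning.
    return ce + alpha * kl
```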

Conclusion

This article primarily discussed the calculation of "KL divergence" under GlobalPointer, providing a usable KL divergence form for applying R-Drop or Virtual Adversarial Training to GlobalPointer.