FlatNCE: Is the Reason for Poor Small-Batch Contrastive Learning Actually Floating-Point Error?

By 苏剑林 | July 26, 2021

Since SimCLR brought unsupervised learning to the forefront of computer vision, contrastive learning has gradually become popular in CV and even NLP, with more and more related research and work appearing. A well-known drawback of standard contrastive learning is the requirement for a relatively large batch_size (SimCLR performs best when batch_size=4096); with small batch_sizes, performance drops significantly. Consequently, one of the improvement directions for subsequent work has been to reduce the reliance on large batch_sizes. So, a natural question arises: What exactly is the reason standard contrastive learning performs poorly at small batch_sizes?

Recently, a paper titled "Simpler, Faster, Stronger: Breaking The log-K Curse On Contrastive Learners With FlatNCE" answered this question: floating-point error. It sounds almost unbelievable, but the paper's analysis is quite reasonable, and the proposed improvement, FlatNCE, indeed works better, making it hard to ignore.

Subtleties

Next, I will introduce the main content of the original paper using my own understanding and notation. I won't give a detailed review of contrastive learning here; broadly speaking, for a given sample $x$, we need to construct $K$ paired samples $y_1, y_2, \cdots, y_K$, where $y_t$ is the positive sample and the rest are negative samples. We then score each sample pair $(x, y_i)$, with the scores denoted $s_1, s_2, \cdots, s_K$. Contrastive learning aims to widen the score gap between positive and negative pairs, usually using cross-entropy as the loss:

\begin{equation}-\log \frac{e^{s_t}}{\sum\limits_i e^{s_i}} = \log \left(\sum_i e^{s_i}\right) - s_t = \log \left(1 + \sum_{i\neq t} e^{s_i - s_t}\right)\end{equation}

For simplicity, let $\xi = \sum\limits_{i\neq t} e^{s_i - s_t}$. In practice, positive samples are usually high-similarity samples generated through data augmentation, while negative samples include all other samples in the batch; thus, negative samples can roughly be considered $K-1$ randomly selected samples. This means the gap between positive and negative pairs is already quite distinct, so it is easy for the model to achieve $s_t \gg s_i (i\neq t)$, i.e., $e^{s_i - s_t} \approx 0$. Thus, when the batch_size is relatively small (equivalent to $K$ being small), $\xi$ will be quite close to 0, which means the above loss function will also be very close to 0.
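
For concreteness, here is a minimal sketch (my own toy numbers, not from the paper) of the standard contrastive cross-entropy loss computed from one row of scores, together with $\xi$; with well-separated scores and a small $K$, both come out very close to 0.

```python
# Minimal sketch of the standard contrastive cross-entropy (InfoNCE) loss for a
# single anchor x: `scores` holds s_1..s_K and `pos_idx` marks the positive s_t.
import torch
import torch.nn.functional as F

K = 8                                    # a small "batch" of candidates
pos_idx = 0
scores = torch.full((K,), -5.0)          # negatives score low...
scores[pos_idx] = 5.0                    # ...the positive scores high

# cross-entropy form: -log softmax(s)_t = log(sum_i e^{s_i}) - s_t
loss = F.cross_entropy(scores.unsqueeze(0), torch.tensor([pos_idx]))

# xi = sum_{i != t} e^{s_i - s_t}; with well-separated scores it is ~0
xi = (scores - scores[pos_idx]).exp().sum() - 1.0
print(loss.item(), xi.item())            # both are on the order of 1e-4
```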

A loss function close to 0 usually means the gradient is also close to 0. This does not, however, mean the model's update becomes very small, because contrastive learning is typically trained with adaptive optimizers such as Adam, whose update is roughly of the form $\frac{\text{gradient}}{\sqrt{\text{gradient} \otimes \text{gradient}}} \times \text{learning rate}$. In other words, no matter how small the gradient is, as long as it is stable the update stays at the scale of the $\text{learning rate}$. Contrastive learning is exactly in this regime: to make $e^{s_i - s_t} \to 0$ we would need $s_i - s_t \to -\infty$, but scores in contrastive learning are usually cosine similarities divided by a temperature parameter and are therefore bounded, so $s_i - s_t \to -\infty$ is unachievable. As a result, after a certain number of training steps, the loss stays very close to 0, yet remains greater than 0, for a long time.
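
To make the "update stays at the scale of the learning rate" claim concrete, recall the standard Adam update (spelled out here for reference; this step is my own addition, not from the paper). With moment estimates $\hat{m} \approx \mathbb{E}[g]$ and $\hat{v} \approx \mathbb{E}[g^2]$, the per-coordinate update is

\begin{equation}\Delta\theta = -\eta\,\frac{\hat{m}}{\sqrt{\hat{v}} + \epsilon} \approx -\eta\,\frac{\mathbb{E}[g]}{\sqrt{\mathbb{E}[g^2]}}\end{equation}

whose magnitude stays near the learning rate $\eta$ whenever the gradient direction is stable, no matter how small $g$ itself is, until $g$ drops to the level of $\epsilon$ or of floating-point noise.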

However, the calculation of $\xi$ itself involves floating-point errors. When $\xi$ is very close to 0, the floating-point error might be larger than the exact value. Then the calculation of $\log(1+\xi)$ also has floating-point errors, and subsequently, the calculation of the gradient also has floating-point errors. This series of accumulated errors might result in the final calculated gradient being close to random noise, failing to provide effective update guidance. This is what the original paper considers to be the reason why contrastive learning performance degrades significantly at small batch_sizes.
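
A quick numerical illustration of the issue (my own toy example, not from the paper): in float32, once $\xi$ falls below machine epsilon (about $1.2 \times 10^{-7}$), the expression $1 + \xi$ rounds back to exactly $1$, so a naive $\log(1+\xi)$, and everything computed from it, carries no signal at all.

```python
# Toy illustration (not from the paper): float32 rounding wipes out a tiny xi.
import numpy as np

xi = np.float32(1e-8)                # well-separated pairs can easily make xi this small
print(np.float32(1.0) + xi)          # 1.0 exactly: xi is below float32 machine epsilon
print(np.log(np.float32(1.0) + xi))  # 0.0: the loss value carries no information
print(np.log1p(xi))                  # ~1e-8: the mathematically correct value
```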

Turning the Subtle into the Significant

Once this reason is understood, it is not difficult to propose a targeted solution. By performing a first-order expansion of the loss function, we have:

\begin{equation}\log \left(1 + \sum_{i\neq t} e^{s_i - s_t}\right) \approx \sum_{i\neq t} e^{s_i - s_t}\end{equation}

In other words, after a certain number of training steps, the model is effectively using $\xi$ as its loss function. And since $\log(1+\xi) \leq \xi$ (i.e., $\xi$ is an upper bound of $\log(1+\xi)$), the result would not differ much even if we used $\xi$ as the loss from the start. The remaining problem is to deal with the floating-point error caused by $\xi$ being too small. As mentioned, the update of adaptive optimizers is roughly of the form $\frac{\text{gradient}}{\sqrt{\text{gradient} \otimes \text{gradient}}} \times \text{learning rate}$, which implies that if we multiply the loss function by a constant, the update theoretically does not change. So, since $\xi$ is too small, we can simply amplify it by multiplying it by a constant.
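
This scale-invariance is easy to check numerically. Below is a small sanity check (my own sketch, using a toy quadratic objective rather than a contrastive loss) showing that Adam's first update is essentially unchanged when the loss is multiplied by 1000.

```python
# Sanity check (toy example): under Adam, scaling the loss by a constant barely
# changes the parameter update, since the first- and second-moment estimates
# scale by the same factor and cancel (up to Adam's small eps term).
import torch

def adam_first_update(scale: float) -> torch.Tensor:
    torch.manual_seed(0)
    w = torch.nn.Parameter(torch.randn(5))
    opt = torch.optim.Adam([w], lr=1e-3)
    loss = scale * (w ** 2).sum()        # same objective, scaled by `scale`
    loss.backward()
    before = w.detach().clone()
    opt.step()
    return w.detach() - before

delta_1 = adam_first_update(1.0)
delta_1000 = adam_first_update(1000.0)
print(torch.allclose(delta_1, delta_1000, atol=1e-6))  # True
```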

What should it be multiplied by? A direct idea is that the loss function should neither be too small nor too large; keeping it at the $\mathcal{O}(1)$ level is best. So we might as well multiply it by the reciprocal of $\xi$, essentially using:

\begin{equation}\frac{\xi}{\text{sg}(\xi)} = \frac{\sum\limits_{i\neq t} e^{s_i - s_t}}{\text{sg}\left(\sum\limits_{i\neq t} e^{s_i - s_t}\right)} \label{eq:flatnce-1}\end{equation}

as the loss function. Here $\text{sg}$ stands for stop_gradient (referred to as detach in the original paper), which means treating the denominator purely as a constant; when calculating the gradient, we only need to differentiate the numerator. This is the alternative scheme proposed by the original paper, called FlatNCE.
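
In code, $\text{sg}$ is simply `detach()` in PyTorch (or `stop_gradient` in TensorFlow). Below is a minimal sketch of this form of FlatNCE for a batch of score rows; the function name, tensor layout, and masking details are my own choices, not the paper's reference implementation.

```python
import torch

def flatnce_loss_sg(scores: torch.Tensor, pos_idx: torch.Tensor) -> torch.Tensor:
    """FlatNCE in the xi / sg(xi) form.

    scores:  (B, K) score matrix, scores[b, i] = s_i for anchor b
    pos_idx: (B,) long tensor holding the index t of each anchor's positive
    """
    B, K = scores.shape
    s_t = scores.gather(1, pos_idx.unsqueeze(1))          # (B, 1) positive scores
    diff = scores - s_t                                    # s_i - s_t
    mask = torch.ones(B, K, dtype=torch.bool, device=scores.device)
    mask[torch.arange(B, device=scores.device), pos_idx] = False  # exclude the positive itself
    xi = (diff.exp() * mask).sum(dim=1)                    # xi = sum_{i != t} e^{s_i - s_t}
    loss = xi / xi.detach()                                # value is always 1; gradient is that of log(xi)
    return loss.mean()
```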

However, a loss function written with the $\text{sg}$ operator is not a form we are used to, so let us transform it. Observe that:

\begin{equation}\nabla_{\theta}\left(\frac{\xi}{\text{sg}(\xi)}\right) = \frac{\nabla_{\theta}\xi}{\xi} = \nabla_{\theta}\log \xi \end{equation}

In other words, the gradient provided by using $\frac{\xi}{\text{sg}(\xi)}$ as the loss function is exactly the same as the gradient provided by $\log \xi$. Therefore, we can replace the loss function with $\log \xi$ without the $\text{sg}$ operator:

\begin{equation}\log\left(\sum\limits_{i\neq t} e^{s_i - s_t}\right) = \log\left(\sum\limits_{i\neq t} e^{s_i}\right) - s_t \label{eq:flatnce-2}\end{equation}

Compared to cross-entropy, this loss simply removes the positive pair score $s_t$ from the $\text{logsumexp}$ operation. Note that $\text{logsumexp}$ has a standard numerically stable implementation (subtracting the maximum before exponentiating), so floating-point error no longer dominates. Therefore, replacing cross-entropy with the above loss function is theoretically equivalent, and in practice it works better than cross-entropy at small batch_sizes. Additionally, it should be pointed out that the value of the above expression is not necessarily non-negative, so it is normal to see negative loss values during training after switching to this loss function.
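
The logsumexp form is the convenient one to implement. Here is a hedged sketch (again with my own illustrative names), relying on `torch.logsumexp`, which handles the stability internally; compared with cross-entropy, the only change is that the positive score is excluded from the logsumexp.

```python
import torch

def flatnce_loss(scores: torch.Tensor, pos_idx: torch.Tensor) -> torch.Tensor:
    """FlatNCE in the form log(sum_{i != t} e^{s_i}) - s_t.

    scores:  (B, K) score matrix; pos_idx: (B,) long tensor of positive indices.
    """
    B, K = scores.shape
    rows = torch.arange(B, device=scores.device)
    s_t = scores[rows, pos_idx]                             # (B,) positive scores
    mask = torch.zeros(B, K, dtype=torch.bool, device=scores.device)
    mask[rows, pos_idx] = True
    neg_only = scores.masked_fill(mask, float("-inf"))      # drop the positive from the logsumexp
    loss = torch.logsumexp(neg_only, dim=1) - s_t           # may be negative; that is expected
    return loss.mean()
```

In a SimCLR-style training loop, this would simply take the place of `F.cross_entropy(scores, pos_idx)`.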

Practical Knowledge

The analysis seems to make some sense, but is it actually effective? This naturally relies on experimental evidence. Unsurprisingly, FlatNCE works remarkably well.

The experiments in the original paper are mainly in CV, specifically replacing the loss of SimCLR with FlatNCE; the resulting method is called FlatCLR. What we care about most is whether FlatNCE really removes the dependency on large batch_sizes, and the following figure gives an affirmative answer:

[Figure: Comparison of SimCLR and FlatCLR under different batch_sizes]

Below is a comparison of SimCLR and FlatCLR across various tasks, showing the superior performance of FlatCLR:

[Figure: Comparison of SimCLR and FlatCLR across various tasks]

Finding Faults

Overall, the original paper's result is very creative. The "floating-point error" perspective is rather unconventional, yet the reasoning behind it is quite sound; one can't help but admire it.

Intuitively, the original goal of cross-entropy was to make the "difference between positive and negative sample scores as large as possible." This is fine for standard classification problems, but it is not enough for contrastive learning. Since the goal of contrastive learning is to learn features, in addition to the "coarse" feature that positive samples should score higher than negative samples, negative samples must also be compared with each other to learn finer features. The goal of FlatNCE is to make "positive sample scores as large as possible and negative sample scores as small as possible," shifting from relative value learning to absolute value learning. This allows optimization to continue even after positive and negative samples have been pulled apart by a certain distance, rather than stopping prematurely (for non-adaptive optimizers) or letting noise from floating-point errors dominate (for adaptive optimizers).

However, certain parts of the original paper are worth criticizing. For instance, the paper spends a significant amount of space discussing mutual information estimation, which has no substantial connection to the main body of the paper and only makes it harder for the reader to follow. Of course, a paper is not popular science; adding extra theoretical derivation to make the article more substantial is understandable, though it would have been better to highlight the analysis of the floating-point error. What I find hardest to understand is that the paper presents $\eqref{eq:flatnce-1}$ as its final result. While the form with stop_gradient is not hard to implement, it is not reader-friendly; usually one resorts to such a trick only when a loss function with the desired gradient cannot easily be written down explicitly, which is clearly not the case for FlatNCE, since $\eqref{eq:flatnce-2}$ is readily available.

Conclusion

This article introduced a new work in contrastive learning. The work analyzes the floating-point error problem of cross-entropy during small-batch contrastive learning, suggesting this might be the primary reason for poor performance at small batch_sizes. It specifically proposes the improved loss function FlatNCE. Experiments show that contrastive learning based on FlatNCE indeed alleviates the dependency on large batch_sizes and achieves better results.