By 苏剑林 | September 19, 2024
Softmax, as the name suggests, is a "soft max"—a smooth approximation of the $\max$ operator (more precisely, the $\text{argmax}$). It transforms any vector $\boldsymbol{x} \in \mathbb{R}^n$ into a new vector with non-negative components that sum to 1 through exponential normalization, allowing us to adjust its degree of approximation to $\text{argmax}$ (in its one-hot form) via a temperature parameter. In addition to exponential normalization, we previously introduced other schemes achieving similar effects in "Path to Probability Distributions: A Survey of Softmax and Its Alternatives".
We know that the maximum value is often referred to as Top-1, and its smooth approximation schemes seem quite mature. However, has the reader ever wondered what a smooth approximation for a general Top-$k$ would look like? Let us explore this problem together below.
Let $\boldsymbol{x}=(x_1,x_2,\cdots,x_n)\in\mathbb{R}^n$. For simplicity, we assume they are pairwise distinct, i.e., $i\neq j \Leftrightarrow x_i\neq x_j$. Let $\Omega_k(\boldsymbol{x})$ denote the set of indices of the $k$ largest components of $\boldsymbol{x}$, such that $|\Omega_k(\boldsymbol{x})|=k$ and $\forall i\in \Omega_k(\boldsymbol{x}), j \not\in \Omega_k(\boldsymbol{x})\Rightarrow x_i > x_j$. We define the Top-$k$ operator $\mathcal{T}_k$ as a mapping from $\mathbb{R}^n \mapsto \{0,1\}^n$:
\begin{equation} [\mathcal{T}_k(\boldsymbol{x})]_i = \left\{\begin{aligned}1,&\,\, i\in \Omega_k(\boldsymbol{x}) \\ 0,&\,\, i \not\in \Omega_k(\boldsymbol{x})\end{aligned}\right. \end{equation}In simple terms, if $x_i$ is among the $k$ largest elements, the corresponding position becomes 1; otherwise, it becomes 0. The final result is a multi-hot vector, e.g., $\mathcal{T}_2([3,2,1,4]) = [1,0,0,1]$.
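For concreteness, the hard operator itself takes only a couple of lines of NumPy (an illustrative sketch; the name hard_topk is mine):

import numpy as np

def hard_topk(x, k):
    # T_k(x): multi-hot indicator of the k largest components
    out = np.zeros_like(x)
    out[np.argsort(x)[-k:]] = 1.0
    return out

# hard_topk(np.array([3., 2., 1., 4.]), 2) -> array([1., 0., 0., 1.])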
The operation from $\boldsymbol{x}$ to $\mathcal{T}_k(\boldsymbol{x})$ is essentially a "hard assignment," which is inherently discontinuous and retains no effective gradient information regarding $\boldsymbol{x}$, making it impossible to integrate into a model for end-to-end training. To solve this, we need to construct a smooth approximation of $\mathcal{T}_k(\boldsymbol{x})$—also called a "Differentiable Top-$k$ Operator" in some literature—that provides effective gradient information.
Specifically, we define the set
\begin{equation}\Delta_k^{n-1} = \left\{\boldsymbol{p}=(p_1,p_2,\cdots,p_n)\left|\, p_1,p_2,\cdots,p_n\in[0,1],\sum_{i=1}^n p_i = k\right.\right\}\end{equation}Our task is to construct a mapping $\mathcal{ST}_k(\boldsymbol{x})$ from $\mathbb{R}^n \mapsto \Delta_k^{n-1}$ that as much as possible satisfies the following properties:
\begin{align} &{\color{red}{\text{Monotonicity}}}:\quad [\mathcal{ST}_k(\boldsymbol{x})]_i \geq [\mathcal{ST}_k(\boldsymbol{x})]_j \,\,\Leftrightarrow\,\, x_i \geq x_j \\[8pt] &{\color{red}{\text{Invariance}}}:\quad \mathcal{ST}_k(\boldsymbol{x}) = \mathcal{ST}_k(\boldsymbol{x} + c),\,\,\forall c\in\mathbb{R} \\[8pt] &{\color{red}{\text{Convergence}}}:\quad \lim_{\tau\to 0^+}\mathcal{ST}_k(\boldsymbol{x}/\tau) = \mathcal{T}_k(\boldsymbol{x}) \\ \end{align}One can verify that Softmax, as $\mathcal{ST}_1(\boldsymbol{x})$, satisfies these properties. Thus, proposing these properties is essentially an attempt to construct $\mathcal{ST}_k(\boldsymbol{x})$ as a natural generalization of Softmax. Of course, constructing a smooth approximation for Top-$k$ is inherently harder than for Top-1, so if difficulties arise, we do not strictly need to follow every property, as long as the mapping exhibits the characteristics of a smooth approximation of $\mathcal{T}_k(\boldsymbol{x})$.
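As a quick numerical illustration of the convergence property in the familiar $k=1$ case (this softmax helper will also be reused below):

def softmax(x):
    e = np.exp(x - x.max())   # shift by the max for numerical stability
    return e / e.sum()

# softmax(np.array([3., 2., 1., 4.]) / 0.05) -> approx [0., 0., 0., 1.] = T_1(x)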
In fact, this problem has been on my mind for a long time. I first discussed it in the 2019 article "Trifles on Smoothing Functions: Differentiable Approximations of Non-differentiable Functions", where I called it $\text{soft-}k\text{-max}$ and gave an iterative construction scheme:
Input is $\boldsymbol{x}$, initialize $\boldsymbol{p}^{(0)}$ as a zero vector;
Execute $\boldsymbol{x} = \boldsymbol{x} - \min(\boldsymbol{x})$ (ensure all elements are non-negative)
For $i=1,2,\dots,k$, execute:
$\boldsymbol{y} = (1 - \boldsymbol{p}^{(i-1)})\otimes\boldsymbol{x}$;
$\boldsymbol{p}^{(i)} = \boldsymbol{p}^{(i-1)} + \text{softmax}(\boldsymbol{y})$
Return $\boldsymbol{p}^{(k)}$.
Actually, the idea behind this iterative construction is simple. To understand it, first imagine replacing $\text{softmax}(\boldsymbol{y})$ with $\mathcal{T}_1(\boldsymbol{y})$. In that case, the algorithm first makes all components non-negative, identifies the Top-1, zeroes it out (so the current maximum becomes the minimum), then identifies the Top-1 of what remains, and so on; the final $\boldsymbol{p}^{(k)}$ is exactly $\mathcal{T}_k(\boldsymbol{x})$. Using $\text{softmax}(\boldsymbol{y})$ as a smooth approximation of $\mathcal{T}_1(\boldsymbol{y})$ in each iteration then naturally yields a smooth approximation of $\mathcal{T}_k(\boldsymbol{x})$.
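In NumPy, this iteration is only a few lines (a sketch; the name soft_k_max is mine, and softmax is the helper defined above):

def soft_k_max(x, k):
    x = x - x.min()                 # ensure all components are non-negative
    p = np.zeros_like(x)
    for _ in range(k):
        y = (1 - p) * x             # damp the components already (softly) selected
        p = p + softmax(y)          # accumulate a soft Top-1
    return p

# soft_k_max(np.array([3., 2., 1., 4.]) / 0.1, 2) -> approx [1., 0., 0., 1.]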
Coincidentally, I found a reply on Stack Exchange to the question "Is there something like softmax but for top k values?" that proposed a similar iterative scheme. It first defines a weighted Softmax:
\begin{equation}[\text{softmax}(\boldsymbol{x};\boldsymbol{w})]_i = \frac{w_i e^{x_i}}{\sum\limits_{j=1}^n w_j e^{x_j}}\end{equation}Then the constructed iterative process is:
Input is $\boldsymbol{x}$, initialize $\boldsymbol{p}^{(0)}$ as a zero vector;
For $i=1,2,\dots,k$, execute:
$\boldsymbol{p}^{(i)} = \boldsymbol{p}^{(i-1)} + \text{softmax}(\boldsymbol{x}; 1 - \boldsymbol{p}^{(i-1)})$
Return $\boldsymbol{p}^{(k)}$.
This follows the exact same logic as my iterative process, except that I multiplied $1 - \boldsymbol{p}^{(i-1)}$ onto $\boldsymbol{x}$, whereas they multiplied it onto $e^{\boldsymbol{x}}$, leveraging the non-negativity of $e^{\boldsymbol{x}}$ to simplify the procedure. However, this iteration is actually incorrect: it does not satisfy "Convergence". For example, when $k=2$, taking the limit $\tau\to 0^+$ of the output for $\boldsymbol{x}/\tau$ does not yield a multi-hot vector; instead, the position of the maximum tends to 1.5, that of the second-largest to 0.5, and the rest to 0. The reason is that $1-p_{\max}$ is of the same order as $e^{x_{\text{2nd}}-x_{\max}}$ (with $x_{\text{2nd}}$ the second-largest component), so multiplying it onto $e^{x_{\max}}$ leaves a term of the same order as $e^{x_{\text{2nd}}}$ rather than eliminating the maximum.
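To see this concretely, here is a quick numerical check of the claim (a sketch; the name weighted_softmax_topk is mine):

def weighted_softmax_topk(x, k):
    p = np.zeros_like(x)
    for _ in range(k):
        e = (1 - p) * np.exp(x - x.max())   # weighted softmax with w = 1 - p
        p = p + e / e.sum()
    return p

# weighted_softmax_topk(np.array([3., 2., 1., 4.]) / 0.1, 2)
# -> approx [0.5, 0., 0., 1.5], not the multi-hot [1., 0., 0., 1.]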
Iterative constructions rely entirely on experience and may hide issues that are hard to detect, as with the weighted-Softmax iteration above, which looks simpler but is logically flawed. Without a guiding principle closer to the essence of the problem, such schemes are also hard to analyze theoretically. For instance, although my iterative construction works well in tests, it is difficult to prove that the components of $\boldsymbol{p}^{(k)}$ stay within $[0,1]$, or that monotonicity holds.
Thus, we seek a guiding principle from a higher perspective to design this smooth approximation. Just a few days ago, I suddenly realized a crucial fact:
\begin{equation}\mathcal{T}_k(\boldsymbol{x}) = \nabla_{\boldsymbol{x}} \sum_{i\in\Omega_k(\boldsymbol{x})} x_i\end{equation}In other words, the gradient of the sum of the $k$ largest components is exactly $\mathcal{T}_k(\boldsymbol{x})$. Therefore, we can instead look for a smooth approximation of $\sum_{i\in\Omega_k(\boldsymbol{x})} x_i$ and take its gradient to obtain the smooth approximation of $\mathcal{T}_k(\boldsymbol{x})$. The former is a scalar, which is easier to approximate smoothly. For example, using the identity:
\begin{equation}\sum_{i\in\Omega_k(\boldsymbol{x})} x_i = \max_{i_1 < \cdots < i_k} (x_{i_1} + \cdots + x_{i_k})\end{equation}That is, take the maximum over the sum of all possible $k$-combinations of components. Thus, the problem becomes finding a smooth approximation of $\max$, which we have already solved (refer to "Seeking a Smooth Maximum Function"); the answer is $\text{logsumexp}$:
\begin{equation}\max_{i_1 < \cdots < i_k} (x_{i_1} + \cdots + x_{i_k})\approx \log \sum_{i_1 < \cdots < i_k} e^{x_{i_1} + \cdots + x_{i_k}} \triangleq \log Z_k\end{equation}Taking the gradient, we obtain a form for $\mathcal{ST}_k(\boldsymbol{x})$:
\begin{equation}[\mathcal{ST}_k(\boldsymbol{x})]_i = \frac{\sum\limits_{i_2 < \cdots < i_k} e^{x_i+x_{i_2} + \cdots + x_{i_k}}}{\sum\limits_{i_1 < \cdots < i_k} e^{x_{i_1} +x_{i_2}+ \cdots + x_{i_k}}} \triangleq \frac{Z_{k,i}}{Z_k} \label{eq:k-max-grad}\end{equation}The denominator is the sum of exponents of all $k$-component sums, while the numerator is the sum of exponents of all $k$-component sums that include $x_i$. Based on this form, we can easily prove:
\begin{equation}0 < [\mathcal{ST}_k(\boldsymbol{x})]_i < 1,\quad \sum_{i=1}^n [\mathcal{ST}_k(\boldsymbol{x})]_i = k\end{equation}So the $\mathcal{ST}_k(\boldsymbol{x})$ defined this way indeed belongs to $\Delta_k^{n-1}$. In fact, we can also prove it satisfies monotonicity, invariance, and convergence, and that $\mathcal{ST}_1(\boldsymbol{x})$ is simply Softmax. These characteristics show it is indeed a natural generalization of Softmax for the Top-$k$ operator. We shall call it "**GradTopK** (Gradient-guided Soft Top-k operator)."
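As a sanity check (and to make the combinatorial cost explicit), here is a brute-force evaluation of this expression that simply enumerates all $\binom{n}{k}$ subsets (a sketch, feasible only for small $n$; the name grad_topk_bruteforce is mine):

from itertools import combinations

def grad_topk_bruteforce(x, k):
    n = len(x)
    Z_k = 0.0
    Z_ki = np.zeros(n)
    for subset in combinations(range(n), k):
        term = np.exp(sum(x[j] for j in subset))   # e^{x_{i_1} + ... + x_{i_k}}
        Z_k += term
        for j in subset:
            Z_ki[j] += term                        # subsets containing index j
    return Z_ki / Z_k

# grad_topk_bruteforce(np.array([3., 2., 1., 4.]), 1) matches softmax([3, 2, 1, 4])
# grad_topk_bruteforce(np.array([3., 2., 1., 4.]) / 0.1, 2) -> approx [1., 0., 0., 1.]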
However, it is not yet time to celebrate, as the numerical computation of Equation \eqref{eq:k-max-grad} is not resolved. If calculated directly, the denominator involves $C_n^k$ exponential terms, which is computationally expensive. We must find an efficient calculation method. Having denoted the numerator and denominator as $Z_{k,i}$ and $Z_k$, we can observe that the numerator $Z_{k,i}$ satisfies the recurrence:
\begin{equation}Z_{k,i} = e^{x_i}(Z_{k-1} - Z_{k-1,i})\end{equation}Combined with the fact that the sum of $Z_{k,i}$ over $i$ equals $k Z_k$, we can construct a recursive calculation process:
\begin{equation}\begin{aligned} \log Z_{k,i} =&\, x_i + \log(e^{\log Z_{k-1}} - e^{\log Z_{k-1,i}}) \\ \log Z_k =&\, \left(\log\sum_{i=1}^n e^{\log Z_{k,i}}\right) - \log k \\ \end{aligned}\end{equation}where $\log Z_{1,i} = x_i$. To reduce overflow risk, we have taken logarithms on both sides. Now, calculating $\mathcal{ST}_k(\boldsymbol{x})$ only requires $k$ iterations, which is efficient enough. However, even with logarithmic handling, the above recursion only works for $\boldsymbol{x}$ with small variance or small $k$. Otherwise, $\log Z_{k-1}$ and the largest $\log Z_{k-1,i}$ become extremely close; when they are numerically indistinguishable, a $\log 0$ bug occurs. Personally, I believe this is a fundamental difficulty of such recursive transformations.
A very crude reference implementation:
def grad_topk(x, k, tau=1.0):
    # GradTopK: k iterations of the recursion above; x is a 1-D array
    x = x / tau
    log_Z_ki = x  # log Z_{1,i} = x_i
    log_Z_k = np.logaddexp.reduce(log_Z_ki, axis=-1, keepdims=True)  # log Z_1
    for i in range(2, k + 1):
        # log Z_{i,j} = x_j + log(Z_{i-1} - Z_{i-1,j}); clamp to dodge log(0)
        log_Z_ki = x + np.log(np.maximum(np.exp(log_Z_k) - np.exp(log_Z_ki), 1e-12))
        # log Z_i = log(sum_j Z_{i,j}) - log(i)
        log_Z_k = np.logaddexp.reduce(log_Z_ki, axis=-1, keepdims=True) - np.log(i)
    return np.exp(log_Z_ki - log_Z_k)
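A quick comparison with the brute-force version above (illustrative values for a small example):

x = np.array([3., 2., 1., 4.])
grad_topk(x, 2)                                           # -> approx [0.694, 0.306, 0.119, 0.881]
np.allclose(grad_topk(x, 2), grad_topk_bruteforce(x, 2))  # -> True
grad_topk(x, 2, tau=0.05)                                 # -> approx [1., 0., 0., 1.] = T_2(x)

For this particular $\boldsymbol{x}$, pushing $\tau$ much lower (say to $0.02$) already triggers the cancellation problem described above.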
The gradient-based route to a smooth Top-$k$ approximation has a certain top-down elegance, but some readers may find it too abstract and short on intuition. Moreover, its numerical instability when $\boldsymbol{x}$ has large variance or $k$ is large leaves us somewhat unsatisfied. We therefore next explore a bottom-up construction.
This idea comes from a reply to another Stack Exchange post, "Differentiable top-k function". Let $f(x)$ be any monotonically increasing function from $\mathbb{R} \mapsto [0,1]$ (ideally smooth) satisfying $\lim_{x\to\infty}f(x) = 1$ and $\lim_{x\to-\infty}f(x) = 0$. This may seem like a lot of conditions, but such functions are not hard to construct: the classic Sigmoid function $\sigma(x)=1/(1+e^{-x})$, as well as $\text{clip}(x,0,1)$, $\min(1, e^x)$, and so on. We then consider:
\begin{equation}f(\boldsymbol{x}) = [f(x_1),f(x_2),\cdots,f(x_n)]\end{equation}How far is $f(\boldsymbol{x})$ from the $\mathcal{ST}_k(\boldsymbol{x})$ we want? Every component is already in $[0,1]$, but the components do not sum to $k$. Thus, we introduce an undetermined constant $\lambda(\boldsymbol{x})$ to ensure this:
\begin{equation}\mathcal{ST}_k(\boldsymbol{x}) \triangleq f(\boldsymbol{x} - \lambda(\boldsymbol{x})),\quad \sum_{i=1}^n f(x_i - \lambda(\boldsymbol{x})) = k\end{equation}That is, we solve for $\lambda(\boldsymbol{x})$ from the requirement that the sum of components equals $k$. We can call this "**ThreTopK** (Threshold-adjusted Soft Top-k operator)." Readers who have read "Path to Probability Distributions: A Survey of Softmax and Its Alternatives" will notice that this approach is identical to Sparsemax and Entmax-$\alpha$.
Is ThreTopK our ideal $\mathcal{ST}_k(\boldsymbol{x})$? Indeed it is! First, because we assumed $f$ is monotonic, monotonicity is satisfied. Second, $f(\boldsymbol{x} - \lambda(\boldsymbol{x}))=f(\boldsymbol{x}+c - (c+\lambda(\boldsymbol{x})))$, meaning any added constant can be absorbed into $\lambda(\boldsymbol{x})$, so invariance is satisfied. Finally, as $\tau\to 0^+$, the threshold $\lambda(\boldsymbol{x}/\tau)$ is such that the $k$ largest components of $\boldsymbol{x}/\tau-\lambda(\boldsymbol{x}/\tau)$ tend to $+\infty$ and the rest tend to $-\infty$, so $f(\boldsymbol{x}/\tau-\lambda(\boldsymbol{x}/\tau))$ tends to $\mathcal{T}_k(\boldsymbol{x})$, satisfying convergence.
With ThreTopK's theoretical soundness established, the next task is to compute $\lambda(\boldsymbol{x})$, which in most cases requires numerical methods. For $f(x)=\min(1, e^x)$, however, we can obtain an analytical solution.
The solution logic is the same as for Sparsemax. Without loss of generality, assume the components of $\boldsymbol{x}$ are sorted in descending order: $x_1 > x_2 > \cdots > x_n$. Suppose we already know that $x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1}$; then:
\begin{equation}k = \sum_{i=1}^n \min(1, e^{x_i - \lambda(\boldsymbol{x})}) = m + \sum_{i=m+1}^n e^{x_i - \lambda(\boldsymbol{x})}\end{equation}Solving for $\lambda$:
\begin{equation}\lambda(\boldsymbol{x})=\log\left(\sum_{i=m+1}^n e^{x_i}\right) - \log(k-m)\end{equation}From this, we can see that when $k=1$, $m$ can only be $0$, and ThreTopK becomes Softmax. When $k > 1$, we cannot determine $m$ beforehand, so we must iterate $m=0,1,\cdots,k-1$ to calculate $\lambda(\boldsymbol{x})$ and find the one satisfying $x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1}$. Below is a crude reference implementation:
def thretopk_exp(x, k):
    x_sort = np.sort(x)[::-1]   # descending order
    for m in range(k):
        # assume the top m components saturate at 1: lambda = log(sum_{i>m} e^{x_i}) - log(k - m)
        lam = np.logaddexp.reduce(x_sort[m:]) - np.log(k - m)
        # accept this m if x_m >= lambda >= x_{m+1} (1-based indices)
        if (m == 0 or x_sort[m - 1] >= lam) and lam >= x_sort[m]:
            return np.minimum(1.0, np.exp(x - lam))
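A quick check (illustrative):

x = np.array([3., 2., 1., 4.])
thretopk_exp(x, 2)                             # -> approx [0.665, 0.245, 0.090, 1.0], sums to 2
np.allclose(thretopk_exp(x, 1), softmax(x))    # k = 1 recovers Softmax -> True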
Both the derivation and the code show that ThreTopK with $f(x)=\min(1, e^x)$ almost never runs into numerical stability issues and reduces to Softmax when $k=1$. These are its advantages. However, $\min(1, e^x)$ is not perfectly smooth: it is non-differentiable at $x=0$ (except in the case $k=1$, where $\lambda(\boldsymbol{x})\geq x_{\max}$ and the $\min$ never takes effect). If one cares about this, one needs a $C^\infty$ choice of $f(x)$, such as $\sigma(x)$.
Taking $f(x)=\sigma(x)$ as an example, we cannot derive an analytical solution for $\lambda(\boldsymbol{x})$. However, due to the monotonic increase of $\sigma(x)$, the function
\begin{equation}F(\lambda)\triangleq \sum_{i=1}^n \sigma(x_i - \lambda)\end{equation}is monotonically decreasing with respect to $\lambda$. Thus, solving $F(\lambda(\boldsymbol{x}))=k$ numerically is not difficult using binary search or Newton's method. Using binary search, it is easy to see that $\lambda(\boldsymbol{x})\in[x_{\min} - \sigma^{-1}(k/n), x_{\max} - \sigma^{-1}(k/n)]$, where $\sigma^{-1}$ is the inverse of $\sigma$. Starting from this interval, we can bisect to the specified precision:
def thretopk_sigmoid(x, k, eps=1e-5):
    sigma = lambda t: 1 / (1 + np.exp(-t))
    inv_sigma = lambda p: -np.log(1 / p - 1)
    # initial bracket: lambda lies in [min(x) - sigma^{-1}(k/n), max(x) - sigma^{-1}(k/n)]
    low = np.min(x) - inv_sigma(k / len(x))
    high = np.max(x) - inv_sigma(k / len(x))
    while high - low > eps:   # bisect to the specified precision
        mid = (low + high) / 2
        if np.sum(sigma(x - mid)) > k:
            low = mid         # F(mid) > k: the threshold must increase
        else:
            high = mid
    return sigma(x - (low + high) / 2)
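For example (illustrative):

x = np.array([3., 2., 1., 4.])
thretopk_sigmoid(x, 2).sum()       # -> approx 2.0
thretopk_sigmoid(x / 0.05, 2)      # -> approx [1., 0., 0., 1.] = T_2(x)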
Thus, the numerical calculation of $\lambda(\boldsymbol{x})$ is not a major obstacle. The real challenge is that when we use numerical methods to compute $\lambda(\boldsymbol{x})$, we often lose the gradient of $\lambda(\boldsymbol{x})$ with respect to $\boldsymbol{x}$, which affects end-to-end training. To address this, we can manually calculate $\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})$ and customize the backpropagation process. Specifically, for
\begin{equation}\sum_{i=1}^n \sigma(x_i - \lambda(\boldsymbol{x})) = k\end{equation}taking the partial derivative with respect to some $x_j$, we get:
\begin{equation}\sigma'(x_j - \lambda(\boldsymbol{x}))-\sum_{i=1}^n \sigma'(x_i - \lambda(\boldsymbol{x}))\frac{\partial\lambda(\boldsymbol{x})}{\partial x_j} = 0\end{equation}Then:
\begin{equation}\frac{\partial\lambda(\boldsymbol{x})}{\partial x_j} = \frac{\sigma'(x_j - \lambda(\boldsymbol{x}))}{\sum\limits_{i=1}^n \sigma'(x_i - \lambda(\boldsymbol{x}))}\end{equation}where $\sigma'$ is the derivative of $\sigma$. We now have an expression for $\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})$, where every term is computable (since $\lambda(\boldsymbol{x})$ was found numerically). We can directly specify this as the backpropagation result. A simple and universal implementation trick is to use $\text{stop\_gradient}$ (hereafter $\text{sg}$), substituting $\lambda(\boldsymbol{x})$ in the model with:
\begin{equation}\boldsymbol{x}\cdot\text{sg}[\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})] + \text{sg}[\lambda(\boldsymbol{x}) - \boldsymbol{x}\cdot\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})]\end{equation}where $\cdot$ is the vector dot product. In this way, during forward propagation, it is equivalent to $\lambda(\boldsymbol{x})$ because the $\text{sg}$ terms treat their contents as constants. During backpropagation, the gradient of the $\text{sg}$ parts is zero, leaving only the desired $\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})$. This allows us to customize the gradient of $\lambda(\boldsymbol{x})$, regardless of how it was computed.
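To make this concrete, here is a minimal sketch of the whole pipeline for $f=\sigma$, assuming PyTorch as the autodiff framework (the framework choice and the name thretopk_sigmoid_torch are mine): $\lambda(\boldsymbol{x})$ is found by bisection under no_grad, its gradient is formed from the expression above, and the $\text{sg}$ substitution is implemented with detach.

import math
import torch

def thretopk_sigmoid_torch(x, k, eps=1e-6):
    # sketch only: assumes a 1-D tensor x and 0 < k < n
    n = x.numel()
    with torch.no_grad():
        # bisection for lambda (non-differentiable on its own)
        offset = math.log(k / (n - k))                 # sigma^{-1}(k/n)
        low, high = x.min() - offset, x.max() - offset
        while high - low > eps:
            mid = (low + high) / 2
            if torch.sigmoid(x - mid).sum() > k:
                low = mid
            else:
                high = mid
        lam = (low + high) / 2
        # implicit-function gradient: sigma'(x_j - lam) / sum_i sigma'(x_i - lam)
        s = torch.sigmoid(x - lam)
        grad_lam = s * (1 - s) / (s * (1 - s)).sum()
    # stop-gradient substitution: forward value = lam, backward gradient = grad_lam
    xg = (x * grad_lam).sum()
    lam_st = xg + (lam - xg).detach()
    return torch.sigmoid(x - lam_st)

# x = torch.tensor([3., 2., 1., 4.], requires_grad=True)
# p = thretopk_sigmoid_torch(x, 2); p.sum().backward()
# x.grad is ~0, as expected: the components of p are constrained to sum to k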
We see that $f(x)=\min(1,e^x)$ has an analytical solution but is not globally smooth, while $f(x)=\sigma(x)$ is sufficiently smooth but complex to solve. Is there a choice that combines both advantages? Indeed, I found that the following $f(x)$ is globally smooth, and $\lambda(\boldsymbol{x})$ can be solved analytically:
\begin{equation}f(x) = \left\{\begin{aligned}1 - e^{-x}/2,&\quad x\geq 0 \\ e^x / 2,&\quad x < 0\end{aligned}\right.\end{equation}It can also be written as $f(x) = (1 - e^{-|x|})\text{sign}(x)/2+1/2$. One can verify that $f(x)$ is an S-shaped function; while it is piecewise, both the function itself and its derivative are continuous at $x=0$, making it sufficiently smooth.
The solution logic remains the same. Assume $x_1 > x_2 > \cdots > x_n$, and suppose we know $x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1}$. Then:
\begin{equation}\begin{aligned} k =&\, \sum_{i=1}^m (1 - e^{-(x_i - \lambda(\boldsymbol{x}))}/2) + \sum_{i=m+1}^n e^{x_i - \lambda(\boldsymbol{x})}/2 \\ =&\, m - \frac{1}{2}e^{\lambda(\boldsymbol{x})}\sum_{i=1}^m e^{-x_i} + \frac{1}{2}e^{-\lambda(\boldsymbol{x})}\sum_{i=m+1}^n e^{x_i} \end{aligned}\end{equation}Solving for $\lambda$ yields:
\begin{equation}\lambda(\boldsymbol{x})=\log\sum_{i=m+1}^n e^{x_i} - \log\left(\sqrt{(k-m)^2 + \left(\sum_{i=1}^m e^{-x_i}\right)\left(\sum_{i=m+1}^n e^{x_i}\right)}+(k-m)\right)\end{equation}Then one iterates $m=0,1,\cdots,n-1$ and keeps the $\lambda(\boldsymbol{x})$ satisfying $x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1}$ (there is one remaining corner case, possible only when $k > n/2$, in which $\lambda(\boldsymbol{x})$ falls below every component; it is obtained by solving $k = n - \frac{1}{2}e^{\lambda(\boldsymbol{x})}\sum_{i=1}^n e^{-x_i}$ directly). Readers can also verify that when $k=1$ and the threshold satisfies $\lambda(\boldsymbol{x}) \geq x_1$ (the $m=0$ case), ThreTopK with this $f(x)$ again reduces exactly to Softmax.
Reference implementation:
def thretopk_smooth(x, k):
    x_sort = np.sort(x)[::-1]   # descending order
    n = x.shape[-1]
    for m in range(n):
        sum1 = np.sum(np.exp(-x_sort[:m]))   # sum_{i<=m} e^{-x_i} (0 when m = 0)
        sum2 = np.sum(np.exp(x_sort[m:]))    # sum_{i>m} e^{x_i}
        a = k - m
        lam = np.log(sum2) - np.log(np.sqrt(a**2 + sum1 * sum2) + a)
        # accept this m if x_m >= lambda >= x_{m+1} (1-based indices)
        if (m == 0 or x_sort[m - 1] >= lam) and lam >= x_sort[m]:
            return np.where(x >= lam, 1 - 0.5 * np.exp(-(x - lam)), 0.5 * np.exp(x - lam))
    # corner case (lambda below every component; only possible when k > n/2)
    lam = np.log(2 * (n - k)) - np.log(np.sum(np.exp(-x_sort)))
    return 1 - 0.5 * np.exp(-(x - lam))
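And a final quick check (illustrative):

x = np.array([3., 2., 1., 4.])
thretopk_smooth(x, 2)              # -> approx [0.697, 0.303, 0.112, 0.888], sums to 2
thretopk_smooth(x / 0.05, 2)       # -> approx [1., 0., 0., 1.] = T_2(x)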
This article discussed the problem of smoothly approximating the Top-$k$ operator, which generalizes Top-1 smooth approximations such as Softmax. We walked through three construction approaches (iterative construction, gradient guidance, and an undetermined threshold constant) and analyzed their respective pros and cons.