How the Two Elementary Function Approximations of GELU Came to Be

By 苏剑林 | March 26, 2020

GELU, short for Gaussian Error Linear Unit, is a smooth variant of the ReLU activation function whose exact form is not an elementary function. It was introduced in the paper "Gaussian Error Linear Units (GELUs)", later used in GPT and BERT, and subsequently adopted by many other pre-trained language models. With the rise of BERT and other pre-trained models, GELU's popularity surged, making it a trendy activation function almost overnight.

[Figure: graph of the GELU function]

In the original GELU paper, the authors proposed not only the exact form of GELU but also provided two elementary function approximations. This article discusses how those approximations were derived.

The GELU Function

The form of the GELU function is:

\begin{equation}\text{GELU}(x)=x \Phi(x)\end{equation}

where $\Phi(x)$ is the cumulative distribution function of the standard normal distribution, i.e.,

\begin{equation}\Phi(x)=\int_{-\infty}^x \frac{e^{-t^2/2}}{\sqrt{2\pi}}dt=\frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]\end{equation}

Here $\text{erf}(x)=\frac{2}{\sqrt{\pi}}\int_0^x e^{-t^2}dt$. The original paper then mentions two approximations:

\begin{equation}x\Phi(x)\approx x\sigma(1.702 x)\label{eq:x-sigma}\end{equation}

and

\begin{equation}x\Phi(x)\approx \frac{1}{2} x \left[1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.044715 x^3\right)\right)\right]\label{eq:x-phi}\end{equation}

Many implementations of Transformer-based models still use the approximation \eqref{eq:x-phi} for the GELU function. However, since most frameworks now provide an accurate $\text{erf}$ implementation, these elementary function approximations are of diminishing practical value; treat what follows, then, as a mathematical analysis exercise.
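Before diving into the derivations, a quick numerical comparison may help fix ideas. The sketch below (function names are mine, chosen for illustration) evaluates the exact erf-based GELU against the two approximations:

```python
import numpy as np
from scipy.special import erf

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi written via erf
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

def gelu_tanh(x):
    # The tanh-based approximation with coefficient 0.044715
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def gelu_sigmoid(x):
    # The sigmoid-based approximation x * sigma(1.702 x)
    return x / (1.0 + np.exp(-1.702 * x))

xs = np.linspace(-6, 6, 2001)
# Worst-case deviation of each approximation from the exact GELU
print(np.max(np.abs(gelu_exact(xs) - gelu_tanh(xs))))
print(np.max(np.abs(gelu_exact(xs) - gelu_sigmoid(xs))))
```

As expected, the tanh form is markedly more accurate than the sigmoid form, at the cost of a more expensive expression.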

What to Use for Approximation

Finding an approximation for GELU is equivalent to finding an approximation for $\Phi(x)$, which in turn is equivalent to finding an approximation for $\text{erf}\left(\frac{x}{\sqrt{2}}\right)$.

[Figure: graph of the erf function]

First, we need to address the question of what function to use for the approximation. From the graph of $\text{erf}(x)$, we can observe its characteristics:

1. It is an odd function, i.e., $\text{erf}(x)=-\text{erf}(-x)$;

2. It is monotonically increasing, with $\lim\limits_{x\to -\infty}\text{erf}(x)=-1$ and $\lim\limits_{x\to +\infty}\text{erf}(x)=1$.

There are many familiar odd functions, such as $x^{2n+1}$, $\sin x$, $\tan x$, $\tanh x$, etc. Moreover, sums and compositions of odd functions are again odd, e.g. $\sin\left(x + x^3\right)$. Among functions that are odd, monotonically increasing, and bounded, $\tanh x$ is perhaps the most natural choice. In fact, $\tanh x$ is very similar to $\text{erf}(x)$.
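To make "very similar" concrete, here is a quick numerical check (an illustration, not part of the original derivation) of the worst-case gap between $\tanh x$ and $\text{erf}(x)$:

```python
import numpy as np
from scipy.special import erf

# Both functions are odd, so comparing on x >= 0 is enough
xs = np.linspace(0, 5, 5001)
gap = np.max(np.abs(erf(xs) - np.tanh(xs)))
print(gap)  # worst-case gap is under 0.1, so the two curves track each other closely
```

The raw gap is not tiny, which is exactly why the fitting forms below rescale and reshape the argument of $\tanh$ rather than using $\tanh x$ directly.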

Therefore, we can start from $\tanh x$ to construct possible fitting forms, such as:

\begin{equation}\left\{\begin{aligned} &\tanh\left(a x + b x^3 + c x^5\right)\\ &a\tanh x + b \tanh^3 x + c \tanh^5 x\\ &a\tanh bx + c \tanh dx + e \tanh fx\\ &\vdots \end{aligned}\right.\end{equation}

How to Approximate

Once we have the form to be fitted, the next concern is how to perform the fitting and what criteria to use. Generally speaking, there are two approaches: local fitting and global fitting.

Local Fitting

Local fitting is based on Taylor expansion. For example, considering the approximation form $\tanh\left(a x + b x^3\right)$, we expand it around $x=0$ to get:

\begin{equation}\text{erf}\left(\frac{x}{\sqrt{2}}\right) - \tanh\left(a x + b x^3\right)=\left(\sqrt{\frac{2}{\pi }}-a\right) x + \left(\frac{a^3}{3}-b-\frac{1}{3 \sqrt{2 \pi }}\right)x^3 + \dots\end{equation}

By setting the first two terms to zero, we obtain two equations. Solving them yields:

\begin{equation}a=\sqrt{\frac{2}{\pi}},\quad b=\frac{4-\pi }{3 \sqrt{2} \pi ^{3/2}}\end{equation}

Substituting these back into $x\Phi(x)$ and converting them to numerical form gives:

\begin{equation}x\Phi(x)\approx \frac{1}{2} x\left[1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.0455399 x^3\right)\right)\right]\label{eq:x-phi-local}\end{equation}
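One can check symbolically (a sketch assuming sympy is installed) that this choice of $a$ and $b$ really does cancel the $x$ and $x^3$ terms of the expansion:

```python
import sympy as sp

x = sp.symbols('x')
a = sp.sqrt(2 / sp.pi)
b = (4 - sp.pi) / (3 * sp.sqrt(2) * sp.pi ** sp.Rational(3, 2))

# Taylor-expand the residual around x = 0, keeping terms below x^5
diff = sp.erf(x / sp.sqrt(2)) - sp.tanh(a * x + b * x**3)
series = sp.series(diff, x, 0, 5).removeO()

# Both low-order coefficients vanish for this choice of a and b
print(sp.simplify(series.coeff(x, 1)), sp.simplify(series.coeff(x, 3)))
```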

Global Fitting

Equation \eqref{eq:x-phi-local} is already quite close to equation \eqref{eq:x-phi}, but the second coefficient is still slightly off. This is because \eqref{eq:x-phi-local} is purely a result of local approximation. As the name suggests, local approximation is very accurate locally; for instance, the derivation above is based on the Taylor expansion at $x=0$, so it is very accurate near $x=0$, but the error increases as we move further away from $0$. Thus, we also need to consider global error.

A common global error measure is the integral form. For example, when approximating $f(x)$ with $g(x,\theta)$, we minimize:

\begin{equation}\min_{\theta} \int [f(x)-g(x,\theta)]^2 dx \quad\text{or}\quad \min_{\theta} \int |f(x)-g(x,\theta)| dx \end{equation}

However, the importance of error at each $x$ might vary. Therefore, to ensure generality, one might multiply by a weight $\lambda(x)$:

\begin{equation}\min_{\theta} \int \lambda(x)[f(x)-g(x,\theta)]^2 dx \quad\text{or}\quad \min_{\theta} \int \lambda(x)|f(x)-g(x,\theta)| dx \end{equation}

Different choices of $\lambda(x)$ lead to different solutions, and choosing the most suitable $\lambda(x)$ is not straightforward.

Instead of optimizing this integral error, we optimize a more intuitive min-max error:

\begin{equation}\min_{\theta} \max_x |f(x)-g(x,\theta)|\end{equation}

This expression is easy to understand: "Find an appropriate $\theta$ such that the maximum $|f(x)-g(x,\theta)|$ is as small as possible." Such a goal aligns with our intuitive understanding and avoids the selection of weights.

Hybrid Fitting

Based on this idea, we fix $a=\sqrt{\frac{2}{\pi}}$ and then re-solve for $\tanh\left(a x + b x^3\right)$. We fix this $a$ because it represents the first-order local approximation; we want to preserve some local accuracy while letting $b$ help us minimize the global error as much as possible, thus achieving a hybrid of local and global approximation. So, we now solve:

\begin{equation}\min_{b} \max_x \left|\text{erf}\left(\frac{x}{\sqrt{2}}\right)-\tanh\left(a x + b x^3\right)\right|\end{equation}

Using scipy, this can be easily solved:

import numpy as np
from scipy.special import erf
from scipy.optimize import minimize

def f(x, b):
    # Pointwise absolute error of tanh(a*x + b*x^3) against erf(x / sqrt(2))
    a = np.sqrt(2 / np.pi)
    return np.abs(erf(x / np.sqrt(2)) - np.tanh(a * x + b * x**3))

def g(b):
    # Worst-case error on a grid; x >= 0 suffices because both functions are odd
    return np.max([f(x, b) for x in np.arange(0, 4, 0.001)])

options = {'xtol': 1e-10, 'ftol': 1e-10, 'maxiter': 100000}
result = minimize(g, 0, method='Powell', options=options)
print(result.x)

Finally, we obtain $b=0.035677337314877385$, which corresponds to the form:

\begin{equation}x\Phi(x)\approx \frac{1}{2} x\left[1 + \tanh\left(\sqrt{\frac{2}{\pi}}\left(x + 0.04471491123850965 x^3\right)\right)\right]\label{eq:x-phi-global}\end{equation}

The trailing significant digits carry some numerical error, but rounded to the precision of equation \eqref{eq:x-phi}, the two agree exactly. As a supplementary note, equation \eqref{eq:x-phi} was proposed in the paper "Approximations to the Cumulative Normal Function and its Inverse for Use on a Pocket Calculator", a result from over 40 years ago.
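As a numerical illustration of the local-vs-global trade-off (a sketch, using the same grid as the fitting script above), one can compare the worst-case residual of the Taylor-derived coefficient 0.0455399 with that of the min-max coefficient 0.044715:

```python
import numpy as np
from scipy.special import erf

a = np.sqrt(2 / np.pi)

def max_residual(coef, xs):
    # Worst-case gap between erf(x/sqrt(2)) and tanh(sqrt(2/pi)*(x + coef*x^3))
    return np.max(np.abs(erf(xs / np.sqrt(2)) - np.tanh(a * (xs + coef * xs**3))))

xs = np.arange(0, 4, 0.001)
print(max_residual(0.0455399, xs))  # local (Taylor) coefficient
print(max_residual(0.044715, xs))   # global (min-max) coefficient: smaller worst case
```

The global coefficient wins on worst-case error, while the local one is more accurate in a small neighborhood of $0$, which is precisely what the hybrid construction trades off.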

As for the first approximation, it comes from the paper "A logistic approximation to the cumulative normal distribution". It is the result of using $\sigma(\lambda x)$ to globally approximate $\Phi(x)$ directly, i.e.,

\begin{equation}\min_{\lambda}\max_{x}\left|\Phi(x) - \sigma(\lambda x)\right|\end{equation}

Solving this yields $\lambda=1.7017449256323682$, which means:

\begin{equation}\Phi(x)\approx \sigma(1.7017449256323682 x)\end{equation}

This is also very consistent with equation \eqref{eq:x-sigma}.
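The same min-max recipe recovers this $\lambda$. A minimal sketch (assuming scipy; `norm.cdf` supplies $\Phi$, and the bounded scalar search is a convenience choice here, not necessarily the original authors' method):

```python
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.stats import norm

xs = np.arange(0, 6, 0.001)  # |Phi(x) - sigma(lam*x)| is even in x, so x >= 0 suffices

def max_err(lam):
    # Worst-case gap between Phi(x) and the logistic sigmoid sigma(lam*x)
    sig = 1.0 / (1.0 + np.exp(-lam * xs))
    return np.max(np.abs(norm.cdf(xs) - sig))

res = minimize_scalar(max_err, bounds=(1.0, 2.5), method='bounded')
print(res.x)  # close to 1.7017
```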

Article Summary

In this article, we worked through a small mathematical analysis problem together: introducing the GELU activation function, tracing the origins of its two elementary-function approximations, and, in the process, padding out one more blog post.