Decoupling weight decay with AdamW
In this blog post we are going to explore how AdamW improves upon the standard Adam optimizer by decoupling weight decay from the optimization step.
Readers already familiar with the background of Adam can skip ahead to the AdamW section.
Adam
Gradient Descent is a popular optimization algorithm used when the optimum cannot be calculated in closed form. Typically, the learning algorithm can be summarized as follows:
- Initialize the weights \(w^0\) and pick a learning rate \(\alpha\).
- For all data points \(i \in \{ 1, \ldots, N\} \): forward propagate \(x_i\) through the network to get the prediction \(\hat{y}_i\), and compute the loss \(\mathcal{L}_i(\hat{y}_i, y_i, w^t)\).
- Update the weights: \(w^{t+1} = w^t - \alpha \frac{1}{N}\sum_{i=1}^N \nabla_w \mathcal{L}_i \left(w^t\right) \).
- If the validation error is still decreasing, go back to step 2; otherwise, stop.
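As a concrete (if simplified) illustration, here is a minimal sketch of this loop for linear regression with a mean-squared-error loss; the model, data, and fixed epoch count are placeholders standing in for the generic network and validation-based stopping rule above:

```python
import numpy as np

def gradient_descent(X, y, alpha=0.1, epochs=100):
    """Full-batch gradient descent for linear regression with an MSE loss (illustrative sketch)."""
    N, d = X.shape
    w = np.zeros(d)                    # initialize the weights w^0
    for _ in range(epochs):            # stand-in for the validation-error check
        y_hat = X @ w                  # forward pass over ALL N data points
        grad = X.T @ (y_hat - y) / N   # average gradient of the loss
        w = w - alpha * grad           # a single weight update per full pass
    return w
```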
Since pure Gradient Descent requires calculating the gradient for each data point in the dataset before taking even a single step, it can be slow to converge.
To speed up the learning process, Stochastic Gradient Descent (SGD) was introduced. Instead of using the full dataset, SGD computes the gradient on a small random mini-batch of data points and updates the weights based on it. The idea is that the mini-batch gradient approximates the true gradient:
\[ \frac{1}{N}\sum_{i=1}^N \nabla_w \mathcal{L}_i \left(w^t\right) \approx \frac{1}{B}\sum_{b=1}^B \nabla_w \mathcal{L}_b \left(w^t\right) \]such that \(B \ll N \).
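In code, the only change compared to the sketch above is that each update uses a random mini-batch instead of the full dataset (again just an illustrative sketch):

```python
import numpy as np

def sgd(X, y, alpha=0.1, batch_size=32, steps=1000, seed=0):
    """Mini-batch SGD: each step uses B << N points to approximate the full gradient."""
    N, d = X.shape
    w = np.zeros(d)
    rng = np.random.default_rng(seed)
    for _ in range(steps):
        idx = rng.choice(N, size=batch_size, replace=False)  # sample a mini-batch
        y_hat = X[idx] @ w
        grad = X[idx].T @ (y_hat - y[idx]) / batch_size      # mini-batch gradient estimate
        w = w - alpha * grad
    return w
```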
But even SGD is not perfect, since it scales the gradient equally in all directions. This can lead to slow convergence, as highly oscillating gradient directions slow down the learning process. A potential solution would dampen the directions that oscillate the most and amplify the directions that are stable across different batches.
In order to tackle this issue, two solutions have been proposed:
- SGD with Momentum (Sutskever et al., 2013) computes a moving average of the gradients, which helps to dampen the oscillations: \[ \begin{align*} m^{t+1} &= \beta_1 m^t + (1-\beta_1) \nabla_w \mathcal{L}_i \left(w^t\right) \\ w^{t+1} &= w^t - \alpha \cdot m^{t+1} \end{align*} \]
- RMSProp (Tieleman and Hinton, 2012) scales the learning rate for each parameter by dividing it by the square root of a running average of the squared gradients: \[ \begin{align*} v^{t+1} &= \beta_2 v^t + (1-\beta_2) \left(\nabla_w \mathcal{L}_i \left(w^t\right) \odot \nabla_w \mathcal{L}_i \left(w^t\right) \right) \\ w^{t+1} &= w^t - \alpha \cdot \frac{\nabla_w \mathcal{L}_i \left(w^t\right)}{\sqrt{v^{t+1} + \epsilon}} \end{align*} \] where \(\odot\) denotes the element-wise (Hadamard) product, and \(\epsilon\) is a small constant to avoid division by zero.
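Both update rules are straightforward to write down in code; the sketch below assumes the parameters, optimizer state, and gradient are already available as NumPy arrays:

```python
import numpy as np

def momentum_step(w, m, grad, alpha=0.01, beta1=0.9):
    """One SGD-with-Momentum step (exponential moving average form used in this post)."""
    m = beta1 * m + (1 - beta1) * grad         # moving average of the gradients
    w = w - alpha * m
    return w, m

def rmsprop_step(w, v, grad, alpha=0.01, beta2=0.999, eps=1e-8):
    """One RMSProp step: per-parameter scaling by a running average of squared gradients."""
    v = beta2 * v + (1 - beta2) * grad * grad  # element-wise squared gradients
    w = w - alpha * grad / np.sqrt(v + eps)    # sqrt(v + eps) as in the formula above
    return w, v
```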
Adam (Kingma and Ba, 2015) combines the benefits of both SGD with Momentum and RMSProp: it keeps a moving average of the gradients and scales the learning rate for each parameter by the square root of a running average of the squared gradients. The update rule for Adam is as follows:
\[ \begin{align*} m^{t+1} &= \beta_1 m^t + (1-\beta_1) \nabla_w \mathcal{L}_i \left(w^t\right) &\text{(SGD with Momentum)} \\ v^{t+1} &= \beta_2 v^t + (1-\beta_2) \left(\nabla_w \mathcal{L}_i \left(w^t\right) \odot \nabla_w \mathcal{L}_i \left(w^t\right) \right) &\text{(RMSProp)} \\ w^{t+1} &= w^t - \alpha \cdot \frac{m^{t+1}}{\sqrt{v^{t+1} + \epsilon}} \end{align*} \]One caveat is that we have to perform bias correction for the moving averages, as they are initialized to zero.
To see why, define \(g_t\) as shorthand and unroll the first few momentum updates (the same argument applies to \(v^t\)):\[ \begin{align*} g_t &= \nabla_w \mathcal{L} \left(w^t\right) &\text{(shorthand for better readability)}\\ m^0 &= 0 &\text{(initialize momentum to zero)} \end{align*} \]\[ \begin{align*} m^1 &= \beta_1 m^0 + (1-\beta_1) g_0 \\ &= (1-\beta_1) g_0 \\ m^2 &= \beta_1 m^1 + (1-\beta_1) g_1 \\ &= \beta_1 (1-\beta_1) g_0 + (1-\beta_1) g_1 \\ m^3 &= \beta_1 m^2 + (1-\beta_1) g_2 \\ &= \beta_1^2 (1-\beta_1) g_0 + \beta_1 (1-\beta_1) g_1 + (1-\beta_1) g_2 \\ \implies m^t &= (1-\beta_1) \sum_{i=0}^{t-1} \beta_1^{t-1-i} g_i \end{align*} \]Taking the expectation and approximating \(\mathbb{E}[g_i] \approx \mathbb{E}[g_t]\) (i.e. assuming the gradient statistics change slowly):\[ \begin{align*} \mathbb{E} \left[m^t \right] &= \mathbb{E} \left[ (1-\beta_1) \sum_{i=0}^{t-1} \beta_1^{t-1-i} g_i \right] \\ &= (1-\beta_1) \sum_{i=0}^{t-1} \beta_1^{t-1-i} \mathbb{E} \left[ g_i \right] \\ &\approx \mathbb{E}\left[g_t \right] (1-\beta_1) \sum_{i=0}^{t-1} \beta_1^{t-1-i} \\ &= \mathbb{E}\left[g_t \right] (1-\beta_1) \sum_{i=0}^{t-1} \beta_1^{i} \\ &= \mathbb{E}\left[g_t \right] (1-\beta_1) \frac{1-\beta_1^t}{1-\beta_1} &\text{(geometric series sum)} \\ &= \mathbb{E}\left[g_t \right] (1-\beta_1^t) \end{align*} \]where the geometric series sum used above is\[ \begin{align*} s_n &= ar^0 + ar^1 + ar^2 + \ldots + ar^{n-1} \\ &= \sum_{i=0}^{n-1} ar^i \\ &= \begin{cases} a \frac{1-r^n}{1-r} & r \neq 1 \\ na & r = 1 \end{cases} \end{align*} \]with, in our case, \(a = 1\) and \(r = \beta_1 \neq 1\).
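A quick numerical sanity check of this derivation (a throwaway sketch): with a constant gradient \(g\), the raw moving average equals \((1-\beta_1^t)\,g\) exactly, so dividing by \(1-\beta_1^t\) recovers the gradient.

```python
import numpy as np

beta1, g, m = 0.9, 1.0, 0.0
for t in range(1, 11):
    m = beta1 * m + (1 - beta1) * g
    assert np.isclose(m, (1 - beta1**t) * g)   # m^t = (1 - beta1^t) * g
    print(t, round(m, 4), m / (1 - beta1**t))  # the corrected estimate is 1.0 at every step
```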
Since \(\mathbb{E}[m^t] \approx \mathbb{E}[g_t](1-\beta_1^t)\), dividing by \(1-\beta_1^t\) removes the bias:\[ \begin{align*} \hat{m}^t &= \frac{m^t}{1-\beta_1^t} \end{align*} \]and analogously for the second moment:\[ \begin{align*} \hat{v}^t &= \frac{v^t}{1-\beta_2^t} \end{align*} \]Putting everything together, the full Adam update becomes:\[ \begin{align*} m^{t+1} &= \beta_1 m^t + (1-\beta_1) \nabla_w \mathcal{L}_i \left(w^t\right) &\text{(SGD with Momentum)} \\ v^{t+1} &= \beta_2 v^t + (1-\beta_2) \left(\nabla_w \mathcal{L}_i \left(w^t\right) \odot \nabla_w \mathcal{L}_i \left(w^t\right) \right) &\text{(RMSProp)} \\ \hat{m}^{t+1} &= \frac{m^{t+1}}{1-\beta_1^{t+1}} &\text{(bias correction for momentum)} \\ \hat{v}^{t+1} &= \frac{v^{t+1}}{1-\beta_2^{t+1}} &\text{(bias correction for variance)} \\ w^{t+1} &= w^t - \alpha \cdot \frac{\hat{m}^{t+1}}{\sqrt{\hat{v}^{t+1} + \epsilon}} \end{align*} \]
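A single Adam update can be sketched as follows (a minimal NumPy version with the usual default hyperparameters, not a production implementation):

```python
import numpy as np

def adam_step(w, m, v, grad, t, alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; t is the 1-based step counter used for bias correction."""
    m = beta1 * m + (1 - beta1) * grad             # momentum (first moment)
    v = beta2 * v + (1 - beta2) * grad * grad      # RMSProp-style second moment
    m_hat = m / (1 - beta1**t)                     # bias correction for momentum
    v_hat = v / (1 - beta2**t)                     # bias correction for variance
    w = w - alpha * m_hat / np.sqrt(v_hat + eps)   # sqrt(v_hat + eps), matching the formulas above
    return w, m, v
```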
AdamW
AdamW (Loshchilov and Hutter, 2019) is a variant of Adam that decouples weight decay from the optimization step.
Weight decay shrinks the weights by a constant factor at every update step:\[ w^{t+1} = (1-\lambda) w^t - \alpha \nabla_w \mathcal{L} \left( w^t \right) \]where \(\lambda \) defines the rate of weight decay per step.
Interestingly, in standard SGD, weight decay is equivalent to adding the \(L_2\) regularization term \(\frac{\lambda'}{2} \Vert w \Vert_2^2 \) to the optimization objective, where \(\lambda' = \frac{\lambda}{\alpha}\).
Consider the following (total) loss function:
\[ \mathcal{L}_{\text{total}} = \mathcal{L}(w) + \frac{\lambda'}{2} \Vert w \Vert_2^2 \]\[ g_t = \nabla_w \mathcal{L}_{\text{total}} = \nabla_w \mathcal{L}(w) + \lambda' w \]A plain SGD step on \(\mathcal{L}_{\text{total}}\) then gives\[ \begin{align*} w^{t+1} &= w^t - \alpha \nabla_w \mathcal{L}_{\text{total}} \left( w^t \right) \\ &= w^t - \alpha \left( \nabla_w \mathcal{L}(w^t) + \lambda' w^t \right) \\ &= w^t - \alpha \nabla_w \mathcal{L}(w^t) - \alpha \frac{\lambda}{\alpha} w^t \\ &= (1-\lambda) w^t - \alpha \nabla_w \mathcal{L}(w^t) \end{align*} \]which is exactly the weight decay update. Note, however, that \(\lambda' = \frac{\lambda}{\alpha}\) couples the weight decay rate \(\lambda\) to the learning rate \(\alpha\).
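A small numerical check of this equivalence (a sketch with a random weight vector and a stand-in gradient):

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=5)
grad = rng.normal(size=5)        # stands in for the gradient of L at w
alpha, lam = 0.1, 0.01
lam_prime = lam / alpha

w_l2    = w - alpha * (grad + lam_prime * w)  # SGD step on L + (lam'/2) * ||w||^2
w_decay = (1 - lam) * w - alpha * grad        # explicit weight decay step
assert np.allclose(w_l2, w_decay)             # identical for plain SGD
```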
However, the equivalence between \(L_2\) regularization and weight decay does not hold for adaptive gradient methods!
Let us consider the update step for SGD with Momentum:
\[ \begin{align*} m^{t+1} &= \beta_1 m^t + (1-\beta_1) \nabla_w \mathcal{L}_{\text{total}} \left( w^t \right) \\ &= \beta_1 m^t + (1-\beta_1) \left( \nabla_w \mathcal{L}(w^t) + \lambda' w^t \right) \\ &= \beta_1 m^t + (1-\beta_1) \nabla_w \mathcal{L}(w^t) + (1-\beta_1) \frac{\lambda}{\alpha} w^t \\ \\ w^{t+1} &= w^t - \alpha m^{t+1} \\ &= w^t - \alpha \left( \beta_1 m^t + (1-\beta_1) \nabla_w \mathcal{L}(w^t) + (1-\beta_1) \frac{\lambda}{\alpha} w^t \right) \\ &= w^t -\alpha \left( \beta_1 m^t + (1-\beta_1) \nabla_w \mathcal{L}(w^t) \right) - (1-\beta_1) \lambda w^t \\ &= (1 - \underbrace{(1-\beta_1)}_{\text{extra factor}} \lambda) w^t - \alpha \left( \beta_1 m^t + (1-\beta_1) \nabla_w \mathcal{L}(w^t) \right) \end{align*} \]As we can see, we pick up an extra (unexpected) factor \((1-\beta_1)\), so the update is no longer equivalent to the pure weight decay form. The same holds for Adam, where the regularization gradient \(\lambda' w\) additionally enters the second-moment estimate \(v\) and is therefore rescaled by the adaptive denominator.
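Repeating the earlier numerical check with a momentum buffer in place shows that the two formulations now disagree (again a throwaway sketch):

```python
import numpy as np

rng = np.random.default_rng(0)
w, m, grad = rng.normal(size=5), rng.normal(size=5), rng.normal(size=5)
alpha, lam, beta1 = 0.1, 0.01, 0.9
lam_prime = lam / alpha

# (a) L2 regularization folded into the gradient, then a momentum step
m_l2 = beta1 * m + (1 - beta1) * (grad + lam_prime * w)
w_l2 = w - alpha * m_l2

# (b) decoupled weight decay applied directly to the weights
m_wd = beta1 * m + (1 - beta1) * grad
w_wd = (1 - lam) * w - alpha * m_wd

print(np.max(np.abs(w_l2 - w_wd)))  # non-zero: in (a) the decay is scaled by (1 - beta1)
```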
AdamW avoids this issue by not adding the regularization term to the loss at all; instead, the weight decay is applied directly to the weights as a separate term in the update step:\[ \begin{align*} m^{t+1} &= \beta_1 m^t + (1-\beta_1) \nabla_w \mathcal{L}_i \left(w^t\right) &\text{(SGD with Momentum)} \\ v^{t+1} &= \beta_2 v^t + (1-\beta_2) \left(\nabla_w \mathcal{L}_i \left(w^t\right) \odot \nabla_w \mathcal{L}_i \left(w^t\right) \right) &\text{(RMSProp)} \\ \hat{m}^{t+1} &= \frac{m^{t+1}}{1-\beta_1^{t+1}} &\text{(bias correction for momentum)} \\ \hat{v}^{t+1} &= \frac{v^{t+1}}{1-\beta_2^{t+1}} &\text{(bias correction for variance)} \\ w^{t+1} &= w^t - \alpha \cdot \frac{\hat{m}^{t+1}}{\sqrt{\hat{v}^{t+1} + \epsilon}} - \lambda w^t &\text{(add the $-\lambda w^t$ term)}\\ &= (1 - \lambda) w^t - \alpha \cdot \frac{\hat{m}^{t+1}}{\sqrt{\hat{v}^{t+1} + \epsilon}} \end{align*} \]where \(\lambda\) is not coupled with the learning rate \(\alpha\), and the decay term is not rescaled by the adaptive denominator \(\sqrt{\hat{v}^{t+1}}\).
The update above matches the full AdamW algorithm as laid out in the paper. Note that the paper's algorithm additionally includes a schedule multiplier \(\eta_t\), which scales both the gradient-based step and the weight decay term, so that \(\alpha\) and \(\lambda\) can be scheduled together.
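In code, AdamW only changes the final line of the Adam step sketched earlier; below is a minimal NumPy version in which `eta` plays the role of the schedule multiplier \(\eta_t\) (defaulting to 1). This is a sketch, not the paper's reference implementation:

```python
import numpy as np

def adamw_step(w, m, v, grad, t, alpha=0.001, beta1=0.9, beta2=0.999,
               eps=1e-8, lam=0.01, eta=1.0):
    """One AdamW update: weight decay is applied to w directly, outside the adaptive step."""
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    # eta scales both the adaptive step and the decay term, as in the paper's algorithm.
    w = w - eta * (alpha * m_hat / np.sqrt(v_hat + eps) + lam * w)
    return w, m, v
```

In practice you would typically reach for an existing implementation such as `torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)` rather than writing the step yourself; note that implementations can differ slightly in exactly how the decay factor interacts with the learning rate.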