How to jointly tune learning rate and weight decay for AdamW
Published:
TL;DR: AdamW is often considered a method that decouples weight decay and learning rate. In this blog post, we show that this is not true for the specific way AdamW is implemented in Pytorch. We also show how to adapt the tuning strategy in order to fix this: when doubling the learning rate, the weight decay should be halved.
Introduction
Consider the training problem
$$\min_{x \in \mathbb{R}^n} f(x),$$
where $f$ denotes the training loss of a model with parameters $x \in \mathbb{R}^n$ (any weight decay will be handled by the optimizer and is therefore not part of $f$).
Suppose we want to solve this problem with stochastic gradient methods. Let us introduce some notation: we denote
- the initial learning rate by $\alpha > 0$,
- the multiplicative learning rate scheduler by $(\rho_t)_{t \geq 1}$ with $\rho_1 = 1$,
- the weight decay parameter by $\lambda \geq 0$.
The learning rate in iteration $t$ is therefore given by $\alpha \rho_t$.
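In Pytorch terms, $\alpha$ is the lr passed to the optimizer and $\rho_t$ is the multiplicative factor produced by a learning rate scheduler. A minimal sketch of this correspondence (the tiny model and the particular schedule are placeholders, not part of this post's experiments):

import torch

model = torch.nn.Linear(10, 1)
alpha = 1e-3  # initial learning rate
opt = torch.optim.AdamW(model.parameters(), lr=alpha, weight_decay=1e-2)
# rho_t is the factor returned by the scheduler; here an arbitrary example schedule.
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda t: 1.0 / (1 + t))

for t in range(3):
    # ... forward/backward pass and opt.step() would go here ...
    scheduler.step()
    print(opt.param_groups[0]["lr"])  # the learning rate alpha * rho_t used next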
The arguably most widely used method for training large-scale machine learning models is AdamW. It was proposed by Loshchilov and Hutter [1], and its main feature is that it handles weight decay separately from the loss function.
A short outline of this post:
- We explain the AdamW algorithm, and show that it is implemented in Pytorch slightly differently from the original paper.
- We explain the notion of decoupling of weight decay and learning rate, and why it is important in practice.
- We show why AdamW in Pytorch does not actually decouple weight decay and learning rate, because of the implementation difference mentioned in the first bullet.
- We then show how the tuning strategy needs to be adapted in order to fix this.
Disclaimer: I have previously written a blog post for the ICLR 2023 blog post track [2] that discusses the weight decay mechanism of AdamW and how it can be seen as a proximal version of Adam (that post is based on the paper by Zhuang et al. [4]). This post re-uses some of its figures and content. In fact, I stumbled upon the central question of this post while writing back then.
On the subtleties of implementing AdamW
The quantities involved in AdamW are mostly the same as in the original version of Adam [3]: let $g_t$ denote the (stochastic) gradient in iteration $t$, and let
$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$$
be the exponential moving averages of the gradients and the squared gradients. The bias-corrected quantities are then given by
$$\hat m_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t}.$$
Let us denote the Adam preconditioner by $D_t = \mathrm{diag}\big(\sqrt{\hat v_t} + \epsilon\big)$. With this notation, the update proposed by Loshchilov and Hutter [1], which we refer to as AdamW-LH, reads
$$x_{t+1} = (1 - \rho_t\lambda)\, x_t - \alpha\rho_t D_t^{-1}\hat m_t. \qquad \text{(AdamW-LH)}$$
In Pytorch, the method is implemented slightly differently, namely as
$$x_{t+1} = (1 - \alpha\rho_t\lambda)\, x_t - \alpha\rho_t D_t^{-1}\hat m_t. \qquad \text{(AdamW-PT)}$$
Note that the only difference lies in the coefficient of the weight decay term: it is $\rho_t\lambda$ for AdamW-LH, but $\alpha\rho_t\lambda$ for AdamW-PT. In other words, in Pytorch the weight decay is additionally multiplied by the initial learning rate $\alpha$.
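To make the difference concrete, here is a minimal sketch of a single update of both variants on a plain parameter tensor, assuming the bias-corrected moments $\hat m_t$ and $\hat v_t$ have already been computed (the function names are made up for illustration):

import torch

def adamw_lh_step(x, m_hat, v_hat, alpha, rho_t, lmbda, eps=1e-8):
    """One update of the original AdamW (AdamW-LH): decay coefficient rho_t * lmbda."""
    return (1 - rho_t * lmbda) * x - alpha * rho_t * m_hat / (v_hat.sqrt() + eps)

def adamw_pt_step(x, m_hat, v_hat, alpha, rho_t, lmbda, eps=1e-8):
    """One update of AdamW as in Pytorch (AdamW-PT): decay coefficient alpha * rho_t * lmbda."""
    return (1 - alpha * rho_t * lmbda) * x - alpha * rho_t * m_hat / (v_hat.sqrt() + eps)

# For alpha = 1e-3 and the same lmbda, AdamW-PT therefore decays the weights
# 1000 times less per step than AdamW-LH.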
Remark: The implementation of AdamW is the same in Optax as in Pytorch. Hence, what follows applies similarly to tuning AdamW in Optax.
The meaning of decoupling
So what do we mean when we say that the learning rate $\alpha$ and the weight decay $\lambda$ are decoupled? Roughly speaking, we mean that the best choice of $\lambda$ does not depend on the choice of $\alpha$ (and vice versa), so the two parameters can be tuned independently.
The graphic below illustrates this phenomenon: imagine we draw a heatmap of the validation loss over a 2D grid of $(\alpha, \lambda)$ values. If the two parameters are decoupled, the region of good performance forms an axis-parallel rectangle; if they are coupled, it lies along a diagonal.
Fig. 1: Model performance (bright = good) as a function of learning rate and weight decay parameters. Illustration taken from [2].
Note that in practice this can make a huge difference: in general, we need to tune over the 2D space of $(\alpha, \lambda)$ combinations, which quickly becomes expensive. If the two parameters are decoupled, however, we can tune them separately, for example with two 1D sweeps, which requires far fewer training runs.
This motivates why the decoupling property is important for practical use.
AdamW and its promise of decoupling
One of the main contributions of the AdamW paper [1] was that it showed how to treat weight decay separately from the loss. This is declared in the paper as follows:
The main contribution of this paper is to improve regularization in Adam by decoupling the weight decay from the gradient-based update.
The authors also claim that their method decouples the weight decay parameter $\lambda$ from the learning rate $\alpha$:
We provide empirical evidence that our proposed modification decouples the optimal choice of weight decay factor from the setting of the learning rate for […] Adam.
While this claim is supported by experiments in the paper, we next show an example where there is no decoupling when using AdamW from Pytorch. The reason is precisely the implementation subtlety described in the previous section.
A simple experiment
The experiment is as follows: we solve a ridge regression problem on synthetic data; the quadratic loss is minimized by the optimizer, and the $\ell_2$ regularization enters through its weight decay.
We run both AdamW-LH and AdamW-PT for a grid of learning rate values $\alpha$ and weight decay values $\lambda$, with a constant learning rate, i.e. $\rho_t = 1$ (see the appendix for a decaying schedule).
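For illustration, such a sweep can be set up roughly as follows. This is a hedged sketch, not the exact experiment code: the synthetic data, the grid values, the number of steps, and the use of the training loss instead of a held-out validation loss are all simplifications.

import itertools
import torch

torch.manual_seed(0)
# Synthetic least-squares data (a stand-in for the data used in the experiments).
A = torch.randn(200, 20)
b = A @ torch.randn(20, 1) + 0.1 * torch.randn(200, 1)

def final_loss(alpha, lmbda, n_steps=500):
    """Train a linear model with Pytorch AdamW and return the final loss."""
    x = torch.nn.Parameter(torch.zeros(20, 1))
    opt = torch.optim.AdamW([x], lr=alpha, weight_decay=lmbda)
    for _ in range(n_steps):
        opt.zero_grad()
        loss = 0.5 * ((A @ x - b) ** 2).mean()
        loss.backward()
        opt.step()
    return loss.item()

# One heatmap cell per (alpha, lambda) pair; an analogous loop is run with AdamW-LH.
lr_grid = [1e-3, 1e-2, 1e-1]
wd_grid = [1e-3, 1e-2, 1e-1]
heatmap = {(a, l): final_loss(a, l) for a, l in itertools.product(lr_grid, wd_grid)}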
Below is the final validation-set loss, plotted as a heatmap over $(\alpha, \lambda)$ for both methods.
Fig. 2: Final validation loss (bright = low) as a function of learning rate $\alpha$ and weight decay $\lambda$, for AdamW-PT and AdamW-LH.
This matches the previous illustrative picture in Figure 1 pretty well (it’s not a perfect rectangle for AdamW-LH, but I guess it proves the point)!
Conclusion 1: Using the Pytorch implementation AdamW-PT, the parameter choices for $\alpha$ and $\lambda$ are not decoupled: the best weight decay depends on the chosen learning rate, and the region of good combinations lies along a diagonal. For AdamW-LH, in contrast, the two parameters are (roughly) decoupled.
Based on this insight, the obvious question is: what is the best (joint) tuning strategy when using the Pytorch version AdamW-PT? We answer this next.
The right tuning strategy
Assume that we have already found a good candidate pair $(\alpha_0, \lambda_0)$, and that we now want to sweep over the learning rate, for example to check whether a larger or smaller $\alpha$ performs better.
Assume that our tuning budget only allows for one line search/sweep. We will present two options for tuning:
(S1) Keep the weight decay fixed at $\lambda = \lambda_0$, and sweep only over the learning rate $\alpha$.
If learning rate and weight decay were truly decoupled, this would be a perfectly sound strategy.
(S2) When sweeping over $\alpha$, adjust the weight decay such that the product $\alpha\lambda$ stays constant, i.e. set $\lambda = \alpha_0 \lambda_0 / \alpha$. In particular, doubling the learning rate means halving the weight decay (see the sketch below).
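In code, (S2) simply keeps the product $\alpha\lambda$ fixed at $\alpha_0\lambda_0$. A minimal sketch (the sweep grid is a placeholder):

# Strategy (S2): keep alpha * lambda constant at alpha0 * lambda0.
alpha0, lambda0 = 1e-2, 3.2e-1                      # a known good candidate pair
lr_sweep = [alpha0 * 2.0**k for k in range(-3, 4)]  # placeholder sweep over learning rates

for alpha in lr_sweep:
    lmbda = alpha0 * lambda0 / alpha                # doubling alpha halves lambda
    # ... run AdamW-PT with lr=alpha, weight_decay=lmbda and record the final loss ...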
Strategy (S1) is slightly easier to code; my conjecture is that (S1) is also employed more often than (S2) in practice.
However, and this is the main point: given how AdamW-PT is implemented, (S2) seems to be the more reasonable strategy. We verify this next.
We plot the same heatmaps as before, but now highlight the points that we would actually observe with tuning strategy (S1) or (S2). We show below
- left: (S1) for AdamW-PT,
- middle: (S2) for AdamW-PT,
- right: (S1) for AdamW-LH (as we had seen that AdamW-LH is indeed decoupled).
Here, we set $(\alpha_0, \lambda_0)$ to (1e-2, 3.2e-1) for AdamW-PT, and to (1e-2, 3.2e-3) for AdamW-LH. In the plots below, the circle-shaped markers highlight the sweep that corresponds to the tuning strategy (S1) or (S2). The bottom plot shows the validation loss as a curve over the highlighted markers.
Fig. 3: Three different tuning strategies: (S1) for AdamW-PT (left), (S2) for AdamW-PT (middle) and (S1) for AdamW-LH (right). Top: Heatmap of final validation loss where the highlighted points show the results of the respective sweep. Bottom: A curve of the final validation loss at each of the highlighted points (learning rate increases from left to right on x-axis).
Note that the bottom plot displays the final validation-loss values that a practitioner would observe for the sweep of each respective tuning strategy. What is important is the width of the valley of this curve, as it reflects how dense the sweep would need to be to obtain a low final loss. The main insight here is: for the middle and right plots, it would be much easier to obtain a low final loss than for the left one. This is important when the sweep has only a few trials, due to the high computational cost of a single run or other practical constraints.
Conclusion 2: when using the Pytorch version of AdamW (i.e. AdamW-PT), tuning strategy (S2) should be used. That is, when doubling the learning rate, the weight decay should be halved.
In fact, Figure 3 also shows that tuning strategy (S2) for AdamW-PT is essentially the same as strategy (S1) for AdamW-LH.
Summary and final remarks:
Implementation details can have an effect on hyperparameter tuning strategies. We showed this phenomenon for AdamW, where the tuning strategy should be a diagonal line search if the Pytorch implementation is used.
In the appendix, we show that the results are similar when using a square-root decaying scheduler for $\rho_t$ instead.
This blog post only covers a ridge regression problem, and one might argue that the results could be different for other tasks. However, the exercise certainly shows that there is no decoupling for AdamW-PT for one of the simplest possible problems, ridge regression. I also observed good performance of the (S2) strategy for AdamW-PT when training a vision transformer on Imagenet (with the timm library).
If you want to cite this post, please use
@misc{adamw-decoupling-blog,
title = {How to jointly tune learning rate and weight decay for {A}dam{W}},
author = {Schaipp, Fabian},
howpublished = {\url{https://fabian-sp.github.io/posts/2024/02/decoupling/}},
year = {2024}
}
References
[1] Loshchilov, I. and Hutter, F., Decoupled Weight Decay Regularization, ICLR 2019.
[2] Schaipp F., Decay No More, ICLR Blog Post Track 2023.
[3] Kingma, D. and Ba, J., Adam: A Method for Stochastic Optimization, ICLR 2015.
[4] Zhuang Z., Liu M., Cutkosky A., Orabona F., Understanding AdamW through Proximal Methods and Scale-Freeness, TMLR 2022.
Appendix
Results with square-root schedule
To validate that the effects are similar for non-constant learning rates, we run the same experiment but now with a square-root decaying learning rate schedule, that is, $\rho_t = 1/\sqrt{t}$.
Fig. 4: Same as Figure 3, but with a square-root decaying learning-rate schedule.
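For reference, such a schedule can be realized in Pytorch with a LambdaLR scheduler; a minimal sketch (model and hyperparameters are placeholders, and the exact offset convention may differ from the one used in the experiments):

import torch

model = torch.nn.Linear(10, 1)  # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-2)
# rho_t = 1/sqrt(t) with t starting at 1, so that rho_1 = 1.
scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda k: 1.0 / (k + 1) ** 0.5)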
Pytorch code for AdamW-LH
For completeness, this is the code we used for AdamW-LH. It is adapted from here.
import math

import torch
from torch.optim import Optimizer


class AdamLH(Optimizer):
    """AdamW with fully decoupled weight decay (AdamW-LH).

    The weight decay applied in iteration t is weight_decay * lr / initial_lr,
    i.e. it is scaled only by the schedule factor rho_t, not by the learning rate itself.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self._init_lr = lr
        super(AdamLH, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamLH, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters
        ----------
        closure : callable, optional
            A callable that evaluates the model (possibly with backprop) and
            returns the loss, by default None.

        Returns
        -------
        (Stochastic) loss function value, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            lmbda = group['weight_decay']
            eps = group['eps']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                # Decoupled weight decay: scaled only by the schedule factor
                # lr / self._init_lr (= rho_t), not by the learning rate itself.
                p.mul_(1 - lmbda * lr / self._init_lr)

                grad = p.grad
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Update the first and second moment running averages
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Adam step with bias-corrected moments
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                step_size = lr / bias_correction1
                update = -step_size * exp_avg / denom
                p.add_(update)

        return loss