Consider the training problem
\[\nonumber \min_{w\in \mathbb{R}^d} \ell(w),\]where $\ell: \mathbb{R}^d\to \mathbb{R}$ is a loss function, and $w$ are learnable parameters. Assume that the loss is given as $\ell(w) = \mathbb{E}_{x\sim \mathcal{P}} [\ell(w,x)]$, where $x$ is a batch of data, sampled from the training data distribution $\mathcal{P}$.
Suppose we want to solve this problem with stochastic gradient methods. Let us first introduce some notation.
The learning rate in iteration $t$ will be given by $\alpha_t := \alpha \eta_t$. We will often refer to $\alpha$ as the learning-rate parameter, which is slightly imprecise, but for most of this post the schedule $\eta_t$ will be constant anyway.
Arguably the most widely used method for training large-scale machine learning models is AdamW. It was proposed by Loshchilov and Hutter [1], and its main feature is that it handles weight decay separately from the loss $\ell$ (as opposed to the original Adam [3]). For readers not familiar with AdamW, we refer to [1] and briefly explain the AdamW update formula below.
A short outline of this post:
Disclaimer: I have previously written a blog post for the ICLR 2023 blog post track [2], that discusses the weight decay mechanism of AdamW, and how it can be seen as a proximal version of Adam (the blog post is based on the paper by Zhuang et al [4]). This post will re-use some of the figures and contents. In fact, I stumbled upon the central question of this blog post during writing back then.
The quantities involved in AdamW are mostly the same as in the original version of Adam: let $g_t=\nabla \ell(w_t,x_t)$ be the stochastic gradient in iteration $t$. Then, for $\beta_1,\beta_2\in[0,1)$, we compute
\[m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t, \nonumber \\ v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t \odot g_t. \nonumber \\\]The bias-corrected quantities are then given by
\[\hat m_t = m_{t}/(1-\beta_1^{t}), \nonumber \\ \hat v_t = v_{t}/(1-\beta_2^{t}) . \nonumber \\\]Let us denote the Adam preconditioner by $D_t = \mathrm{diag}(\epsilon + \sqrt{\hat v_t})$. The way that AdamW was proposed originally in the paper by Loshchilov and Hutter [1] is
\[w_{t+1} = (1-\lambda\eta_t)w_t - \alpha_t D_t^{-1}\hat m_t. \tag{AdamW-LH}\]In Pytorch, the method is implemented slightly differently:
\[w_{t+1} = (1-\lambda\alpha_t)w_t - \alpha_t D_t^{-1}\hat m_t. \tag{AdamW-PT}\]Note that the only difference lies in the coefficient $1-\lambda\alpha_t= 1-\lambda\alpha\eta_t$ instead of $1-\lambda\eta_t$. While this seems like a triviality at first sight (one could easily reparametrize $\lambda$ by $\lambda \alpha$), we will show that it has an important practical implication for tuning.
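To make the difference concrete, the two decay coefficients can be written as small helper functions (the function names are hypothetical, chosen only for illustration):

```python
def decay_coeff_lh(lmbda, alpha, eta_t):
    # AdamW-LH: the decay coefficient does not involve alpha
    # (alpha is unused here, kept only for a parallel signature)
    return 1 - lmbda * eta_t

def decay_coeff_pt(lmbda, alpha, eta_t):
    # AdamW-PT (Pytorch): the decay coefficient is scaled by alpha
    return 1 - lmbda * alpha * eta_t

# Doubling alpha leaves the LH coefficient unchanged,
# but changes the PT coefficient:
print(decay_coeff_lh(0.1, 0.001, 1.0))  # 0.9
print(decay_coeff_lh(0.1, 0.002, 1.0))  # 0.9
print(decay_coeff_pt(0.1, 0.001, 1.0))  # 0.9999
print(decay_coeff_pt(0.1, 0.002, 1.0))  # 0.9998
```

In other words, changing $\alpha$ in AdamW-PT silently changes the effective weight decay as well.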
Remark: The implementation of AdamW is the same in Optax as in Pytorch. Hence, what follows applies similarly to tuning AdamW in Optax.
So what do we mean when we say that learning rate $\alpha$ and weight decay $\lambda$ are decoupled? We will work with the following (approximate) definition: we say that $\alpha$ and $\lambda$ are decoupled if the optimal choice for $\lambda$ does not depend on the choice of $\alpha$. Here, we mean optimal with respect to some metric of interest - for the rest of the blog post, this metric will be the loss $\ell$ computed over a validation dataset.
The graphic below illustrates this phenomenon: imagine we draw a heatmap of the validation loss over a $(\alpha,\lambda)$ grid, where bright values indicate better model performance. In a coupled scenario (left), the bright valley could have a diagonal shape, while for the decoupled scenario (right) the valley is more rectangular.
Fig. 1: Model performance (bright = good) as a function of learning rate and weight decay parameters. Illustration taken from [2].
Note that in practice this can make a huge difference: in general, we need to tune over the 2D-space of $(\alpha,\lambda)$, assuming that all other hyperparameters are already set. The naive way to do this is a grid search. However, if we know that $\alpha$ and $\lambda$ are decoupled, then it would be sufficient to do two separate line searches for $\alpha$ and $\lambda$, followed by combining the best values from each line search. For example, if for each parameter we have $N$ candidate values, this reduces the tuning effort from $N^2$ (naive grid search) to $2N$ (two line searches).
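As a toy illustration of this bookkeeping (the candidate grids below are made up), the two approaches differ only in which $(\alpha, \lambda)$ pairs get evaluated:

```python
# Hypothetical candidate values for a sweep; N = 8 for each parameter.
N = 8
alphas = [2.0 ** -k for k in range(N)]
lmbdas = [10.0 ** -k for k in range(N)]

# naive grid search: evaluate all N^2 combinations
grid = [(a, l) for a in alphas for l in lmbdas]

# two line searches: fix one parameter at a default value, sweep the other
line_searches = ([(a, lmbdas[0]) for a in alphas]
                 + [(alphas[0], l) for l in lmbdas])

print(len(grid), len(line_searches))  # 64 16
```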
This motivates why the decoupling property is important for practical use.
One of the main contributions of the AdamW paper [1] was that it showed how to treat weight decay separately from the loss. This is declared in the paper as follows:
The main contribution of this paper is to improve regularization in Adam by decoupling the weight decay from the gradient-based update.
The authors also claim that their method decouples the weight decay parameter $\lambda$ and the learning rate $\alpha$ (which goes beyond decoupling weight decay and loss).
We provide empirical evidence that our proposed modification decouples the optimal choice of weight decay factor from the setting of the learning rate for […] Adam.
While this claim is supported by experiments in the paper, we next show an example where there is no decoupling when using AdamW from Pytorch. The reason for this is, as we will see, the implementation subtlety described in the previous section.
The experiment is as follows: we solve a ridge regression problem for some synthetic data $A \in \mathbb{R}^{n \times d},~b\in \mathbb{R}^{n}$ with $n=200,~d=1000$. Hence $\ell$ is the squared loss, given by $\ell(w) = \Vert Aw-b \Vert^2$.
We run both AdamW-LH and AdamW-PT, for a grid of learning-rate values $\alpha$ and weight-decay values $\lambda$. For now, we set the scheduler to be constant, that is $\eta_t=1$. We run everything for 50 epochs, with batch size 20, and average all results over five seeds.
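The experiment can be sketched in a few lines of numpy. This is a simplified reconstruction (full-batch gradients, made-up data generation), not the exact code behind the figures:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 1000
A = rng.standard_normal((n, d))
b = rng.standard_normal(n)

def adamw_pt(alpha, lmbda, steps=500, beta1=0.9, beta2=0.999, eps=1e-8):
    """Run the AdamW-PT update w_{t+1} = (1 - lmbda*alpha) w_t - alpha * D^{-1} m_hat
    on the squared loss ||Aw - b||^2, with a constant schedule eta_t = 1."""
    w = np.zeros(d)
    m, v = np.zeros(d), np.zeros(d)
    for t in range(1, steps + 1):
        g = 2 * A.T @ (A @ w - b)           # full-batch gradient
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        w = (1 - lmbda * alpha) * w - alpha * m_hat / (eps + np.sqrt(v_hat))
    return w

w = adamw_pt(alpha=1e-3, lmbda=1e-2)
print(np.linalg.norm(A @ w - b))
```

Sweeping `alpha` and `lmbda` over a logarithmic grid and recording the validation loss then produces heatmaps as below.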
Below is the final validation-set loss, plotted as heatmap over $\alpha$ and $\lambda$. Again, brighter values indicate lower loss values.
Fig. 2: Final validation loss (bright = low) as a function of learning rate $\alpha$ and weight decay parameter $\lambda$.
This matches the previous illustrative picture in Figure 1 pretty well (it’s not a perfect rectangle for AdamW-LH, but I guess it proves the point)!
Conclusion 1: Using the Pytorch implementation AdamW-PT, the parameter choices for $\alpha$ and $\lambda$ are not decoupled in general. However, the originally proposed method AdamW-LH indeed shows decoupling for the above example.
Based on this insight, the obvious question is: what is the best (joint) tuning strategy when using the Pytorch version AdamW-PT? We answer this next.
Assume that we have already found a good candidate value $\bar \lambda$ for the weight-decay parameter; for example, we obtained $\bar \lambda$ by tuning for a fixed (initial) learning rate $\bar \alpha$. Now we also want to tune the (initial) learning-rate value $\alpha$.
Assume that our tuning budget only allows for one line search/sweep. We will present two options for tuning:
(S1) Keep $\lambda = \bar \lambda$ fixed, and simply sweep over a range of values for $\alpha$.
If $\alpha$ and $\lambda$ are decoupled, then (S1) should work fine. However, as we saw before, $\alpha$ and $\lambda$ are not decoupled in the Pytorch version AdamW-PT: the decay coefficient in each iteration is given by $1 - \alpha \lambda \eta_t$. Thus, it seems intuitive to keep the quantity $\alpha \lambda$ fixed, which leads to the following strategy:
(S2) When sweeping over $\alpha$, adapt $\lambda$ accordingly such that the product $\alpha \lambda$ stays fixed. For example, if $\alpha = 2\bar \alpha$, then set $\lambda=\frac12 \bar \lambda$.
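In code, the two sweeps could be generated as follows (a hypothetical sketch; the reference values match the ones used for AdamW-PT below):

```python
bar_alpha, bar_lmbda = 3.2e-1, 1e-2          # reference values from prior tuning
factors = [0.25, 0.5, 1.0, 2.0, 4.0]

# (S1): sweep alpha, keep lambda fixed
sweep_s1 = [(bar_alpha * c, bar_lmbda) for c in factors]

# (S2): sweep alpha, rescale lambda so that alpha * lambda stays constant
sweep_s2 = [(bar_alpha * c, bar_lmbda / c) for c in factors]

for a, l in sweep_s2:
    assert abs(a * l - bar_alpha * bar_lmbda) < 1e-12
```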
Strategy (S1) is slightly easier to code; my conjecture is that (S1) is also employed more often than (S2) in practice.
However, and this is the main argument, from the way AdamW-PT is implemented, (S2) seems to be more reasonable. We verify this next.
We plot again the heatmaps as before, but now highlighting the points that we would actually observe with tuning strategy (S1) or (S2). Here, we set $(\bar \lambda, \bar \alpha) =$ (1e-2, 3.2e-1) for AdamW-PT, and $(\bar \lambda, \bar \alpha) =$ (1e-2, 3.2e-3) for AdamW-LH.
In the below plots, the circle-shaped markers highlight the sweep that corresponds to the tuning strategy (S1) or (S2). The bottom plot shows the validation loss as a curve over the highlighted markers.
Fig. 3: Three different tuning strategies: (S1) for AdamW-PT (left), (S2) for AdamW-PT (middle) and (S1) for AdamW-LH (right). Top: Heatmap of final validation loss where the highlighted points show the results of the respective sweep. Bottom: A curve of the final validation loss at each of the highlighted points (learning rate increases from left to right on x-axis).
Note that the bottom plot displays the final validation-loss values that a practitioner would observe for the sweep of each respective tuning strategy. What matters is the width of the valley of this curve, as it reflects how dense the sweep needs to be in order to obtain a low final loss. The main insight here is: for the middle and right strategies, it would be much easier to obtain a low final loss than for the left one. This is important when the sweep has only a few trials, for example due to high computational costs of a single run or other practical constraints.
Conclusion 2: When using the Pytorch version of AdamW (i.e. AdamW-PT), tuning strategy (S2) should be used. That is, when doubling the learning rate, the weight decay should be halved.
In fact, Figure 3 also shows that tuning strategy (S2) for AdamW-PT is essentially the same as strategy (S1) for AdamW-LH.
Summary and final remarks:
Implementation details can have an effect on hyperparameter tuning strategies. We showed this phenomenon for AdamW, where the tuning strategy should be a diagonal line search if the Pytorch implementation is used.
In the appendix, we show that the results are similar when using a square-root decaying scheduler for $\eta_t$ instead.
This blog post only covers a ridge regression problem, and one might argue that the results could be different for other tasks. However, the exercise certainly shows there is no decoupling for AdamW-PT for one of the simplest possible problems, ridge regression. I also observed good performance of the (S2) strategy for AdamW-PT when training a vision transformer on Imagenet (with the timm library).
[1] Loshchilov, I. and Hutter, F., Decoupled Weight Decay Regularization, ICLR 2019.
[2] Schaipp F., Decay No More, ICLR Blog Post Track 2023.
[3] Kingma, D. and Ba, J., Adam: A Method for Stochastic Optimization, ICLR 2015.
[4] Zhuang Z., Liu M., Cutkosky A., Orabona F., Understanding AdamW through Proximal Methods and Scale-Freeness, TMLR 2022.
To validate that the effects are similar for non-constant learning rates, we run the same experiment but now with a square-root decaying learning rate schedule. That is $\eta_t = 1/\sqrt{\text{epoch of iteration } t}$. We sweep again over the initial learning rate $\alpha$ and weight decay parameter $\lambda$. The results are plotted below:
Fig. 4: Same as Figure 3, but with a square-root decaying learning-rate schedule.
For completeness, this is the code we used for AdamW-LH. It is adapted from here.
import math

import torch
from torch.optim import Optimizer


class AdamLH(Optimizer):
    """AdamW with fully decoupled weight decay."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self._init_lr = lr
        super(AdamLH, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamLH, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Parameters
        ----------
        closure : callable, optional
            A callable that evaluates the model (possibly with backprop)
            and returns the loss, by default None.

        Returns
        -------
        (Stochastic) loss function value.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            lmbda = group['weight_decay']
            eps = group['eps']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                # decay with coefficient 1 - lmbda * eta_t, where eta_t = lr / init_lr
                p.mul_(1 - lmbda * lr / self._init_lr)

                grad = p.grad
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running averages
                exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                step_size = lr / bias_correction1
                update = -step_size * exp_avg / denom
                p.add_(update)

        return loss
It is very hard to reach all three goals - speed, generality, and readability - at the same time. Typically, compiled languages such as C++ offer speed but lack the other two (at least for my taste). Some concrete examples of this tradeoff:

- In scikit-learn, many classical solvers are written in C++ or Cython, for example the logistic regression solvers.
- In Julia, proximal algorithms are implemented for instance in the ProximalOperators.jl and ProximalAlgorithms.jl packages.

However, the goal of this article is to present one approach of reaching all three goals in Python.
Consider problems of the form
\[\min_x f(x) + r(x)\]where we assume $f$ to be continuously differentiable and $r$ is a (closed, convex) regularizer. For a step size $\alpha>0$, the proximal gradient algorithm for such problems is given by the iterates
\[x^{k+1} = \mathrm{prox}_{\alpha r}(x^k- \alpha \nabla f(x^k)),\]where $\mathrm{prox}$ is the proximal operator of a closed, convex function.
If we implement an algorithm for problems of the above type, it would be favourable to have code that works for any functions f and r fulfilling the respective assumptions. Moreover, as we have a composite objective, we would prefer a solver which we can call for any combination of f and r we like - without adapting the code of the solver.
An obvious approach to achieve this is handling both f and r as instances of classes with the following methods:

- f needs a method grad which computes a gradient at a specific point,
- r needs a method prox which computes the proximal operator of $\alpha\cdot r$ at a specific point.

Let us show the implementation for f being a quadratic function and r being the 1-norm.
class Quadratic:
    def __init__(self, A, b):
        self.A = A
        self.b = b

    def grad(self, x):
        g = self.A @ x + self.b
        return g
The formula below for the proximal operator (soft-thresholding) is well known, but it is not so important for the understanding here.
import numpy as np

class L1Norm:
    def __init__(self, l):
        self.l = l

    def prox(self, x, alpha):
        return np.sign(x) * np.maximum(np.abs(x) - alpha*self.l, 0.)
Now, proximal gradient descent can be implemented generally with the following simple function:
def prox_gd(f, r, x0, alpha=0.1, max_iter=50):
    x = x0.copy()
    for i in range(max_iter):
        y = x - alpha * f.grad(x)
        x = r.prox(y, alpha)
    return x
This is general and very simple to read. To apply the algorithm to a different objective, one only needs to write the respective f and/or r. With this, a library of functions can be built and used modularly.
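Putting the pieces together, a minimal self-contained run could look as follows. The class definitions are repeated from above, and the toy data is chosen such that the solution is known in closed form (soft-thresholding of $c$):

```python
import numpy as np

class Quadratic:
    # f(x) = 0.5 * x.T @ A @ x + b.T @ x, so grad f(x) = A @ x + b
    def __init__(self, A, b):
        self.A = A
        self.b = b
    def grad(self, x):
        return self.A @ x + self.b

class L1Norm:
    def __init__(self, l):
        self.l = l
    def prox(self, x, alpha):
        return np.sign(x) * np.maximum(np.abs(x) - alpha * self.l, 0.)

def prox_gd(f, r, x0, alpha=0.1, max_iter=50):
    x = x0.copy()
    for i in range(max_iter):
        x = r.prox(x - alpha * f.grad(x), alpha)
    return x

# solve min 0.5*||x - c||^2 + 0.5*||x||_1; the solution is soft-thresholding of c
c = np.array([2.0, 0.1, -3.0])
f = Quadratic(np.eye(3), -c)
r = L1Norm(0.5)
x = prox_gd(f, r, x0=np.zeros(3), alpha=1.0)
print(x)  # the soft-thresholded vector [1.5, 0, -2.5]
```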
However, all of the above is pure Python code and will therefore be pretty slow. Our goal is to use Numba in order to accelerate the implementation while keeping generality and readability.
Numba is a package for just-in-time (JIT) compilation. It is designed to speed up pure Python code using the decorator @njit.
Numba supports many functions built into numpy. A detailed list is here.
The speedup typically comes from for-loops - which naturally appear in optimization algorithms. Thus, we want to write our solver as a JIT-compiled numba function. The problem: everything that happens inside a JIT-compiled function must itself be compiled. Thus, if we want to use class methods inside the solver, every method of the class must be JIT-compiled. Luckily, numba offers this possibility with @jitclass.
When using @jitclass, it is important to specify the type of every attribute of the class. See the example below or the docs for all details. Our quadratic function class can be implemented as follows:
from numba.experimental import jitclass
from numba import float64, njit

spec = [
    ('b', float64[:]),
    ('A', float64[:,:])
]

@jitclass(spec)
class Quadratic:
    def __init__(self, A, b):
        self.A = A
        self.b = b

    def grad(self, x):
        g = self.A @ x + self.b
        return g
Same with the 1-norm:
spec_l1 = [('l', float64)]

@jitclass(spec_l1)
class L1Norm:
    def __init__(self, l):
        self.l = l

    def prox(self, x, alpha):
        return np.sign(x) * np.maximum(np.abs(x) - alpha*self.l, 0.)
Remark: @jitclass alone does not necessarily speed up the code. The main speedup will come from for-loops, typically appearing in the solver.
After implementing Quadratic and L1Norm as specific examples for f and r, we can now implement a numba version of proximal gradient descent. We can pretty much copy the code and simply add the @njit decorator.
@njit()
def fast_prox_gd(f, r, x0, alpha=0.1, max_iter=50):
    x = x0.copy()
    for i in range(max_iter):
        y = x - alpha * f.grad(x)
        x = r.prox(y, alpha)
    return x
Some remarks on the @njit decorator (mainly a reminder to myself):

- If a function is called with arguments of a new type (e.g. float instead of int), numba will recompile the function (which takes longer).

I implemented the pure Python and the Numba version of proximal gradient descent in this notebook.
For a simple 50-dimensional example with f being a quadratic function and r the 1-norm, we get the following result:
# Python version
%timeit prox_gd(f, r, x0, alpha=0.001, max_iter=20000)
10 loops, best of 5: 164 ms per loop
# Numba version
%timeit fast_prox_gd(f, r, x0, alpha=0.001, max_iter=20000)
10 loops, best of 5: 54.2 ms per loop
Even for this simple example, we already get a speedup factor over 3. Of course, how much speedup is possible depends on how much of the computation is due to the loop rather than due to numerical heaviness (e.g. matrix-vector multiplication in high dimensions).
If the gradient or prox computation involves for-loops (e.g. Condat's algorithm for total variation regularization), using numba will result in significant speedups in my experience.
The outlined approach can also be applied to stochastic algorithms where the number of iterations and thus the speedup is typically large. You can find some standard algorithms such as SGD, SAGA or SVRG in this repository.
Thanks for reading!
- A comparison of numba and cython: http://gouthamanbalaraman.com/blog/optimizing-python-numba-vs-cython.html

This article serves as a short checklist - mainly as a reminder to myself - for converting your research code into an open-source, distributable and well documented package. Some, but not all, steps might only apply to Python projects. Most of the individual steps are very well documented, so you can see this as a collection of websites/tutorials that helped me for my own projects.
Many of the following steps are much simplified if your code is already a Github repository.
If you aim to make your package available to others, it should have a license. While there are many standard open-source licenses around, be aware that your choice can make a difference in how others can use or redistribute your package. You can add a license directly via the Github page of your repository (link to docu).
A great introductory article on the legal background of open-source licenses is here.
When your project grows, at some point you might need to use some of your functions across multiple other scripts. In order to import from your module, you only need one additional file, a setup.py file, and to install the module locally as a package in your (virtual) environment. Fortunately, a setup file is basically all you need in order to make your package distributable - for example with pip or conda.

A useful guide on how to create a setup file and make your package distributable with pip is here.
Other great resources with many details are this packaging guide and this introduction from the Python packaging authority.
If you ever had to become familiar with a code repository you did not write yourself, you will understand the importance of proper documentation. Apart from the standard advice of using docstrings and comments where needed, you can also create and publish documentation for your package as a whole. Typically, this could be included in the README of your repository. However, if your package becomes more complex and needs more explaining, you might consider creating documentation on Read the Docs. I will list the steps on how to achieve this (obviously other tools could be used, but I will describe the ones I used myself).
Create the documentation using Sphinx. This mainly involves writing .md or .rst files where you explain everything which is needed. Here is a guide on how to get started.
One of the great features of Sphinx is that it can parse the docstrings of your functions into nicely-looking and readable websites (as you might know from the docs of numpy). Moreover, you can include math formulas, cross-references or links in the docstrings. This way, if you change the source code you only need to update the docstrings and the documentation will be up to date automatically (see the section Autodocs below).
Build the documentation locally (see below how to do that).
If you created the documentation files within the subfolder docs, the commands for this are as simple as

cd docs/
make html
Publish on Read the Docs: add a .readthedocs.yaml file on the top level of your repository (a minimal file example is here), then import your repository on Read the Docs and choose the branch from which the documentation is built (e.g. main or master). Readthedocs will install all dependencies from requirements.txt (whereas Pypi uses the ones from setup.py).

The Readthedocs documentation provides you with all the details about importing and building your documentation.
If your package gets more involved, having some text with pictures and automatically generated class and function documentation might not satisfy you. Fortunately, there are numerous ways to bring your documentation to the next level:
Explaining by example is often much more effective. Thus, show what your package can do by simply setting up small example scripts. Sphinx Gallery offers an easy way to include showcase examples in your documentation and beautifully embed plots, visualizations and code snippets.
The main idea is very simple: every Python script in a subfolder of your repository (e.g. called examples) will be parsed. Every script with a filename starting with plot_ will be executed and all plots are shown.
As the gallery is an extension of Sphinx, it can be easily integrated into your configuration from step 3. All essential info for getting started can be found in the documentation. Advanced configuration options, e.g. ignoring some of the files in the examples directory, are here.
As an alternative to an example gallery, you can also use Jupyter notebooks to create tutorials on how to use your package, or showcase some of its features.
Like on Github, where Jupyter notebooks are rendered directly in your browser when you open them, the nbsphinx package allows you to do the same on your readthedocs page. Simply create your tutorial notebooks, save them to your docs subfolder, and add the notebooks to your documentation index as described here.
From there, the possibilities are almost endless. You can even do things like linking to an interactive version of the notebook on Google colab or similar.
For projects with many submodules, it can be tedious to manually write a file that links to all the documentation pages of different classes, functions etc. Fortunately, the autodoc extension for Sphinx can generate these pages for you, automatically extracting the docstrings of an entire source file, class or function in your package.
Note: As you might have guessed, getting all docstrings rendered properly can be a little tricky at times, so try this out locally first before including it in your readthedocs page.
Even if you do not aim to open-source your code, you should include (unit) tests. This means writing test functions which ensure that your code is a) running without errors and b) giving the correct result.
Remark: even though it is the last point in this checklist, writing tests is best done while developing the package.
For example, if you wrote a function my_sqrt which should always return a non-negative result, you could add a test like this:

def test_my_sqrt():
    a = my_sqrt(2.0)
    assert a >= 0
Often, you want to assert that two numbers or arrays are equal up to some numerical inaccuracy. For this, numpy provides useful functionality.
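For example, numpy ships assertion helpers in numpy.testing; a small sketch of comparing floats up to a tolerance:

```python
import numpy as np

def test_sum_is_close():
    a = np.array([0.1, 0.2]) + np.array([0.2, 0.1])
    # exact comparison fails because of floating-point round-off
    assert not (a == np.array([0.3, 0.3])).all()
    # assert_allclose passes, since the values agree up to the tolerance
    np.testing.assert_allclose(a, np.array([0.3, 0.3]), rtol=1e-12)

test_sum_is_close()
```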
Using pytest, all files with a filename starting or ending with test will be scanned for functions which start with the prefix test (as in the example above).
If you want to compute a coverage report (i.e. how many lines of code are included in one of the tests), use
pytest --cov=my_package my_package/
Make it easy for others to cite your software. A citation snippet can be added directly via Github.
Automate testing and building using Github actions. For example, you can create automatic coverage reports with Codecov.
Thank you for reading! Many thanks to Johannes Ostner for giving feedback and adding some resources.