The case for unit-normalized loss (with back-propagation)

Imagine you're me - struggling to debug unstable training[1] - then while taking a detour[2] into optimizers you find out that not only does batch accumulation affect the step-size (the effective learning-rate), but the scalar loss value you end up with also scales the gradients.
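
Concretely, by the chain rule, scaling the loss by a constant $c$ scales every gradient by that same constant: $\nabla_\theta (c \cdot L) = c \cdot \nabla_\theta L$. Whatever magnitude your loss happens to have flows straight into the parameter update.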

I had assumed that gradients were automagically normalized by either the framework or the optimizer; neither is the case. This was obvious in hindsight given the chain rule but, to me, coming from evolutionary strategies, it seemed like absolute insanity. How do you test optimizers if the step-size entirely depends on their progress during training? How do you develop new optimizers, given that gradients can be anything from 1e-8 to 1e+8 depending on the loss function?

With this in mind I took a detour from my detour to test a new method: what if gradients, but normalized?

Brief intermezzo

Since I was already comparing various optimizers for my previous detour, one optimizer came to mind which does have 'normalized' step-sizes (generously put): SignSGD.

SignSGD is a bit of an odd case in the optimizer world: it effectively does params = params - sign(grad)*alpha, which is about as simple as it gets, but because sign is always in $\{-1, 0, 1\}$ the step-size is (mostly) constant.

SignSGD in action, provided by yours truly as there aren't any good existing plots
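
To make the update rule concrete, here's a minimal sketch of a SignSGD step (my own illustration, not any library's official implementation; the function name and alpha default are just placeholders):

import torch

def signsgd_step(params, alpha=1e-3):
    # SignSGD update: every parameter moves by exactly alpha,
    # only the direction comes from the gradient
    with torch.no_grad():
        for p in params:
            if p.grad is not None:
                p -= alpha * torch.sign(p.grad)

# usage after a backward pass:
# loss.backward()
# signsgd_step(model.parameters())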

Maybe you can also see the issue with this: it's terrible and it's optimizer-specific. We only need the norm to be equal between steps, not for each individual parameter, i.e. a part of the model that doesn't need to move much shouldn't have to, but the aggregate amount of movement (the norm of the gradients) should be approximately the same from step to step.

Unit-normalization

To make sure I wasn't fooling myself, I first verified the existing belief that scaling the loss does in fact scale the gradients through back-propagation:

import torch
import torch.nn as nn

model = nn.Linear(10, 10)
x = torch.randn(10, 10)
y = torch.randn(10, 10)

print("MSE with batchsize=10")

criterion = nn.MSELoss(reduction='mean')
out = model(x)
loss = criterion(out, y)
loss.backward()
print('mean batch\t\t', torch.norm(model.weight.grad))

model.zero_grad()
criterion = nn.MSELoss(reduction='sum')
out = model(x)
loss = criterion(out, y)
loss.backward()
print('sum batch\t\t', torch.norm(model.weight.grad))

model.zero_grad()
out = model(x)
loss = criterion(out, y)  # still reduction='sum'
(loss / 100).backward()   # 100 = batchsize * 10 output values
print('sum batch/(batchsize*10)\t', torch.norm(model.weight.grad))
MSE with batchsize=10
mean batch			tensor(0.6698)
sum batch			tensor(66.9827)
sum batch/(batchsize*10)	tensor(0.6698)

Working as expected

You may notice that we are dividing by 10*batchsize (this holds for other shapes too); that's because MSE with reduction='mean' averages over every value being compared, so the sum-reduced loss is batchsize*10 times larger here (10 samples, 10 output values each). This makes it very annoying to normalize, since we'd need contextual information about what the loss function is doing, unless...

If we divide the loss by the loss (i.e. make loss=1), then regardless of how the batch is accumulated (pre-backprop), the gradients end up the same:

model = nn.Linear(10, 10)

for i in range(10):
    if i % 2 == 0:
        # new batch every two iterations, scaled up so the raw error grows
        x = torch.randn(10, 10)*(i+1)
        y = torch.randn(10, 10)*(i+1)

    model.zero_grad()
    # alternate 'sum' and 'mean' reduction on the same batch
    criterion = nn.MSELoss(reduction='sum' if i % 2 == 0 else 'mean')
    out = model(x)
    loss = criterion(out, y)
    (loss / loss.item()).backward()
    print(f'unit-norm {"sum" if i%2==0 else "mean"}\t\t', torch.norm(model.weight.grad), f'\terror={loss.item()}')
unit-norm sum		tensor(0.6942) 	error=98.28943634033203
unit-norm mean		tensor(0.6942) 	error=0.9828943610191345
unit-norm sum		tensor(0.5292) 	error=1277.2415771484375
unit-norm mean		tensor(0.5292) 	error=12.772416114807129
unit-norm sum		tensor(0.5460) 	error=3134.31298828125
unit-norm mean		tensor(0.5460) 	error=31.343130111694336
unit-norm sum		tensor(0.7055) 	error=4419.689453125
unit-norm mean		tensor(0.7055) 	error=44.196895599365234
unit-norm sum		tensor(0.5809) 	error=12567.6796875
unit-norm mean		tensor(0.5809) 	error=125.67679595947266

For each batch the gradients are equivalent regardless of accumulation

We have effectively normalized the gradients. Even as we increase the underlying error, the gradient norm stays stable.
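
One way to read this: since loss.item() is just a constant from autograd's point of view, (loss / loss.item()).backward() yields $\nabla_\theta L / L = \nabla_\theta \log L$, i.e. the same gradients you would get from back-propagating the log of the loss.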

Note: if you are doing batch accumulation you still need to divide by the number of steps you accumulate over: (loss/loss.item()/grad_accum).backward(); likewise, if you're aggregating within a batch with different expected loss magnitudes it may help to use reduction='none' and do (loss/loss.detach()).mean().backward()
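
As a sketch of the accumulation case (assuming a generic model, criterion, optimizer and loader are already set up; grad_accum and the loop structure are only illustrative):

grad_accum = 4  # micro-batches accumulated per optimizer step

optimizer.zero_grad()
for i, (x, y) in enumerate(loader):
    loss = criterion(model(x), y)
    # unit-normalize, then divide by grad_accum so the summed
    # gradients keep the same scale as a single normalized step
    (loss / loss.item() / grad_accum).backward()
    if (i + 1) % grad_accum == 0:
        optimizer.step()
        optimizer.zero_grad()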

But why?

You may ask, and it's a fair question if you're not in the trenches of optimizer and ML research. My response would be "we need to get away from the alchemy": training can currently destabilize very easily, and as previously mentioned I identified loss scaling and batch scaling as key factors preventing us from doing sound training comparisons and exploration.

If you have ever read an optimizer paper you'll notice the proposed method always outperforms SGD and Adam on paper, while in actual comparisons it typically underperforms. Here's AdaBelief (watch until the end):

AdaBelief Headline animation; notice Adam actually getting to the same point just slower.

There are arguments for certain optimizers being 'allowed' larger step-sizes - like being able to better handle stochastic noise - but from the video AdaBelief looks like it's just Adam sped up. So we can't tell if AdaBelief is just Adam with 10x learning-rate or if it's actually being intelligent about its speed.

For real applications this problem only gets worse, as there are many more things to factor in: batch size, batch accumulation, and the outsized effect of 'bad' samples in the dataset due to their higher loss values. You really do not want to have to untangle all of that just to train models or to tune/develop/test optimizers.

Wait, doesn't this mess up current training schemes?

If you've been paying attention and thinking along you may realize an issue with normalizing the gradients compared to how back-prop is currently used. As the loss gets lower the step-size automatically decreases (for better or worse); but with unit-norm it doesn't necessarily decrease. Wouldn't this mess up long-tail convergence? And since optimizers currently assume this decay, does this mean we need different optimizers?

To be honest I don't know, I haven't run enough experiments with this method; but regardless, it's not something that's hard to overcome. We've used ReduceLROnPlateau and learning-rate decays forever; the difference is that the schedule now actually does what you'd expect rather than being modulated by the scaling of your loss function.
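
For instance, a sketch of how this might look with PyTorch's ReduceLROnPlateau (model, optimizer, criterion, train_loader, num_epochs and the validate() helper are assumed/hypothetical here):

from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=10)

for epoch in range(num_epochs):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        (loss / loss.item()).backward()  # unit-normalized loss
        optimizer.step()
    # the learning-rate schedule now sets the step-size directly,
    # instead of being modulated by the raw loss magnitude
    scheduler.step(validate(model))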

References?

I didn't actually reference many papers directly (only AdaBelief, SignSGD and Adam), but I do have some interesting things I found afterwards: