The curious case of the first bad batch

Recently I've been looking into training latent diffusion models, and while experimenting I came across a strange artifact: the first batch always sends the loss shooting up. "Could be a bad batch," I thought; the random sample drawn from the dataset can occasionally have much higher variance than the rest of the dataset.

It kept happening. "Hmm, so maybe not bad sampling, but bad model initialization?" Diffusion models typically use a specific initialization that I didn't bother to add. Training was going well otherwise, so I ignored it.
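For what it's worth, one common example of such an initialization is zero-initializing the final output layer of the denoising network so it starts out predicting zeros; a rough sketch (the layer name here is a placeholder, not my actual code):

```python
import torch.nn as nn

def zero_init_output(model: nn.Module, layer_name: str = "out_conv") -> nn.Module:
    # Zero the final projection so the network initially predicts zeros,
    # which keeps the loss small and stable at the start of training.
    final = getattr(model, layer_name)  # placeholder attribute name
    nn.init.zeros_(final.weight)
    if final.bias is not None:
        nn.init.zeros_(final.bias)
    return model
```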

Finally I wanted to speed up experimentation with different settings, which is a lot easier if you start from a checkpoint that is already partly trained, since you can compare against the same "seed" (for the model parameters, at least). The spike still happened, and this time I couldn't ignore it: it makes the model regress to effectively its initial state, losing hours or days of progress. Loading the checkpoint and evaluating without any optimization gives the expected loss, so something was very wrong in the first optimization step.

Green: training from scratch; blue: loading from checkpoint
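The check itself boils down to comparing the loss before and after exactly one optimizer step. A rough sketch of that comparison (the model, loss, and checkpoint path are toy placeholders; this toy won't reproduce the spike, it just shows the shape of the check):

```python
import torch
import torch.nn as nn

# Toy stand-ins; the real objects are a latent-diffusion UNet and its dataloader.
model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
batch = torch.randn(32, 16)

def loss_fn(m, x):
    return ((m(x) - x) ** 2).mean()

model.load_state_dict(torch.load("checkpoint.pt"))  # placeholder checkpoint path

with torch.no_grad():
    print("loss before any step:", loss_fn(model, batch).item())  # matches the checkpoint

optimizer.zero_grad()
loss_fn(model, batch).backward()
optimizer.step()

with torch.no_grad():
    print("loss after one step:", loss_fn(model, batch).item())  # in my actual runs, this is where it explodes
```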

I checked everything that could cause a re-initialization, or the lack thereof: calling zero_grad before training, gradient accumulation, the randomness of the dataset sampling, using a stateless optimizer, disabling autocast, loss rescaling, and gradient clipping. I ran each for a few batches, and every one of them failed. I even clipped gradients to 0.01, but the model still self-destructs with the same magnitude.
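For concreteness, most of these checks were variations of a loop like the following (a toy sketch with a placeholder model and loss, not the real training code):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)                           # toy stand-in for the UNet
opt = torch.optim.SGD(model.parameters(), lr=4e-4)  # stateless: no momentum, no moving averages

for step in range(4):
    batch = torch.randn(32, 16)
    loss = ((model(batch) - batch) ** 2).mean()
    opt.zero_grad()                                  # rule out stale / accumulated gradients
    loss.backward()
    # Aggressive clipping; even max_norm=0.01 did not remove the first-step jump in my runs.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.01)
    opt.step()
    print(step, loss.item())
```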

I lowered the learning rate: the jump was still there, but smaller, and this time the recovery was much faster:

Green: lr=4e-4; blue: lr=1e-4

Finally a lead, but the plot thickens: it can't be the optimizer state, because one of the optimizers I tried earlier was stateless, and it can't be large gradients, because of the gradient clipping we did before.

Now, something like OneCycle would likely paper over this issue, but I still didn't know why the jump existed; it did not make sense to me. So I went further, going back to regular ol' SGD to check that it wasn't optimizer dependent, and it still happened, just less severely.
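For reference, a OneCycle-style warmup in PyTorch would look roughly like this (toy model and step counts are placeholders); it keeps the learning rate tiny for the first steps after loading, which would mask the jump rather than explain it:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)  # stand-in for the actual network
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)
# OneCycleLR starts far below max_lr and ramps up, so the first steps
# after loading a checkpoint barely move the weights.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=4e-4, total_steps=10_000, pct_start=0.1
)

for step in range(10_000):
    batch = torch.randn(32, 16)
    loss = ((model(batch) - batch) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # OneCycleLR is stepped once per optimizer step
```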

As it turns out, part of the inconsistency is due to the smoothing applied to the plots, which is typically an Exponential Moving Average. Since the first few samples have little to no history behind them, they are barely smoothed at all, making the spikes show up much worse than when continuing from the original values.
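A minimal version of that kind of smoothing makes the effect obvious:

```python
def ema_smooth(values, alpha=0.9):
    """TensorBoard-style exponential moving average of a logged metric."""
    smoothed, avg = [], None
    for v in values:
        # With no history, the first value passes through unsmoothed,
        # so an early spike shows up at full magnitude.
        avg = v if avg is None else alpha * avg + (1 - alpha) * v
        smoothed.append(avg)
    return smoothed

print(ema_smooth([0.1, 0.1, 5.0, 0.1]))  # spike lands on existing history and gets damped
print(ema_smooth([5.0, 0.1, 0.1, 0.1]))  # spike at the start is shown at full height
```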

It's possible that weight decay is the cause, since it is scaled by the learning rate but usually assumes the optimizer step will compensate; in this case we have a high learning rate while the moving averages inside AdamW are still 0, meaning the effective update would be just the weight_decay term. However, when testing with weight_decay=0 and weight_decay=0.03 the difference is negligible, so it cannot be weight decay either.
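For context, AdamW applies weight decay decoupled from the adaptive step, scaled directly by the learning rate; in isolation the decay term looks like this (a simplified sketch, not PyTorch's full implementation):

```python
import torch

@torch.no_grad()
def adamw_decay_only(param: torch.Tensor, lr: float, weight_decay: float) -> None:
    # The decoupled weight-decay part of AdamW: p <- p - lr * wd * p,
    # applied regardless of what the moment estimates contribute.
    param.mul_(1.0 - lr * weight_decay)

p = torch.ones(3)
adamw_decay_only(p, lr=4e-4, weight_decay=0.03)
print(p)  # each weight shrinks by a factor of (1 - 4e-4 * 0.03), i.e. barely at all
```

With lr=4e-4 and weight_decay=0.03 the per-step shrink factor is about 1 - 1.2e-5, which lines up with the negligible difference observed.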