Pytorch图像去噪实战(三十三):梯度累积训练大模型,小显存也能稳定训练大Batch
一、问题场景:显存太小,batch size只能设成1
图像去噪模型越做越大后,显存问题会越来越明显。
特别是训练:
- RGB UNet
- Restormer
- SwinIR
- Diffusion UNet
- 大 patch 图像
- 多尺度模型
经常会遇到:
CUDA out of memory最直接的做法是把 batch size 改小。
但 batch size 太小会带来问题:
- loss 抖动明显
- 梯度噪声大
- 训练不稳定
- BatchNorm 统计不准
- 指标提升慢
如果显存不够,但又想获得更大的等效 batch,就可以使用:
梯度累积 Gradient Accumulation。
二、梯度累积是什么?
普通训练:
一个 batch ->