2024-09-17
Automatic Mixed Precision Training in PyTorch
Understand how we can use pytorch amp
⚠️ amp is different from the concept of quantization
we generally use 32 bits to represent integers (INT32 format), with 1 bit for the sign and the remaining 31 bits for the value
an INT32 value takes 4 bytes, so bigger tensors need more memory and compute
to represent floating-point numbers we use FP32 (single precision) and FP64 (double precision); these also have a sign bit, plus an exponent (range) and a fraction (decimal places)
model training usually defaults to FP32
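a quick check of this in pytorch (a minimal sketch; the shapes here are arbitrary):

```python
import torch

x = torch.randn(1024, 1024)       # floating-point tensors default to FP32
print(x.dtype)                    # torch.float32
print(x.element_size())           # 4 bytes per element

x_half = x.half()                 # cast to FP16
print(x_half.element_size())      # 2 bytes per element, half the memory
```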
there are other formats like:
- FP16 - has 16 bits and is also called half-precision
- bfloat16 - has 16 bits but the same range as FP32 at lower precision; it was created at Google, and pytorch's autocast on CPU supports only bfloat16
- TF32 - has 19 bits; the tensorfloat format was created at Nvidia
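torch.finfo shows the range vs. precision trade-off between these formats (TF32 is a compute mode, not a storage dtype, so it is not listed here):

```python
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    # bits = total width, max = largest representable value (range),
    # eps = smallest step above 1.0 (precision)
    print(f"{str(dtype):>16} bits={info.bits:2d} max={info.max:.3e} eps={info.eps:.3e}")
```

the output shows bfloat16 covering roughly the same range as FP32 (max around 3.4e38) while FP16 tops out at 65504, in exchange for much coarser precision.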
refer to this 👉 https://lnkd.in/g97uGsgx post to understand these formats in detail
to save memory and compute during model training (with some compromise on precision) we can use these cheaper formats, or go even lower with FP8
training with mixed precision means using low-precision formats wherever possible while keeping high-precision as the default
this approach saves memory while avoiding the loss of information (and the accuracy hit) that pure low-precision training can cause
pytorch maintains lists of operations that are safe to run at lower precision
automatic mixed precision in pytorch automatically casts those operations to the lower-precision dtype while keeping the rest in FP32
we need to use torch.autocast for this
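a minimal sketch of that substitution, assuming a CUDA GPU is available; matmul is on the low-precision list while softmax is kept in FP32:

```python
import torch

a = torch.randn(16, 16, device="cuda")   # created in FP32
b = torch.randn(16, 16, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    c = a @ b                      # matmul autocasts to float16
    d = torch.softmax(c, dim=-1)   # softmax runs in float32 for stability

print(c.dtype)  # torch.float16
print(d.dtype)  # torch.float32
```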
but to use it on GPUs we need to do a bit more:
- enable backend flags for CUDA and cuDNN (see the full sketch after this list)
  - torch.backends.cudnn.benchmark
  - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
  - torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
  - torch.backends.cudnn.allow_tf32
- wrap the forward pass in torch.autocast
use it as a context manager or a decorator and include the forward pass and the loss calculation inside it; it takes device_type and dtype as its main arguments
- use a gradient scaler
with low precision, small gradient values can underflow and lose information, so we need a gradient scaler that scales the loss up before the backward pass
gradient scaling improves convergence for networks with float16 gradients
wrap the backward pass (loss.backward) and the optimizer step (optimizer.step) with torch.cuda.amp.GradScaler, as shown in the sketch below
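putting the three steps together, a minimal training-loop sketch (the model, optimizer, and data here are placeholders, assuming a CUDA GPU):

```python
import torch

# step 1: backend flags for CUDA and cuDNN
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True

device = "cuda"
model = torch.nn.Linear(128, 10).to(device)                 # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# step 3: gradient scaler for float16 gradients
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):                                         # placeholder data
    x = torch.randn(32, 128, device=device)
    y = torch.randint(0, 10, (32,), device=device)

    optimizer.zero_grad()

    # step 2: autocast wraps the forward pass and the loss calculation
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        logits = model(x)
        loss = loss_fn(logits, y)

    # scale the loss, backprop, then step and update through the scaler
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

note: when training in bfloat16 instead of float16 the scaler is generally not needed, since bfloat16 keeps the FP32 range.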
refer to this 👉 https://lnkd.in/g_EyxFdY post to understand automatic mixed precision in detail