2024-09-17

Automatic Mixed Precision Training in PyTorch

Understanding how we can use PyTorch AMP


⚠️ AMP is a different concept from quantization

we generally use 32 bits to represent integers (the INT32 format), with 1 bit for the sign and the remaining 31 bits for the magnitude

an INT32 value therefore takes 4 bytes of memory, so wider formats demand more memory and more compute

to represent floating point numbers we use FP32 (single precision) and FP64 (double precision); like above they have a sign bit, plus an exponent (which sets the range) and a fraction (which sets the decimal precision)
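as a quick illustration, here is a minimal sketch (plain Python with the standard struct module, nothing PyTorch-specific) that splits an FP32 value into its sign, exponent and fraction bits:

```python
import struct

def fp32_fields(x: float) -> str:
    # reinterpret the IEEE 754 single-precision encoding of x as raw bits
    raw = struct.unpack(">I", struct.pack(">f", x))[0]
    bits = f"{raw:032b}"
    # 1 sign bit | 8 exponent bits | 23 fraction bits
    return f"{bits[0]} | {bits[1:9]} | {bits[9:]}"

print(fp32_fields(3.14))  # 0 | 10000000 | 10010001111010111000011
```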

model training usually runs in FP32 by default

there are other formats like:

  1. FP16 - has 16 bits and is also called half precision

  2. bfloat16 - has 16 bits with the same range as FP32 but lower precision; it was created at Google and is the low-precision dtype PyTorch's autocast uses on CPUs

  3. TF32 - has 19 bits; the TensorFloat format was created at NVIDIA
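a quick way to compare these formats from PyTorch itself is torch.finfo (TF32 is a compute mode rather than a storage dtype, so it has no finfo entry):

```python
import torch

# compare width (bits), range (max) and precision (eps) of the formats above
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(dtype, "bits:", info.bits, "max:", info.max, "eps:", info.eps)
```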

refer to this 👉 https://lnkd.in/g97uGsgx post to understand these formats in detail

to save memory and compute during model training (with some compromise on precision), we can use cheaper formats like FP16 or even FP8

training with mixed precision means using low-precision formats wherever possible while keeping high-precision as the default

this approach not only saves memory but also avoids the loss of information (and hence accuracy) that running everything in low precision would cause

PyTorch also maintains a list of operations that are safe to run at lower precision (and others that should stay in FP32)

automatic mixed precision in PyTorch automatically runs an operation in lower precision when it is on that list
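a minimal sketch of what that looks like, using torch.autocast (introduced in the steps below) on an assumed CUDA device; matmul lands on the lower-precision list while softmax is kept in FP32:

```python
import torch
import torch.nn.functional as F

a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    c = a @ b                 # matmul is on the lower-precision list -> float16
    p = F.softmax(c, dim=-1)  # softmax is kept in float32 for numerical stability

print(c.dtype, p.dtype)  # torch.float16 torch.float32
```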

we need to use torch.autocast, but to use it on GPUs we need to do a bit more (a full sketch follows the steps below)

  1. enable the backend flags for CUDA and cuDNN
  • torch.backends.cudnn.benchmark

  • torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction

  • torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction

  • torch.backends.cudnn.allow_tf32

  2. wrap the forward pass of the training loop in torch.autocast

use it as a context manager or a decorator around the forward pass and the loss calculation. It takes device_type and dtype as its key arguments.

  3. use a gradient scaler

in low precision, small gradient values can underflow to zero; to prevent this loss of information we use a gradient scaler

gradient scaling improves convergence for networks whose FP16 gradients would otherwise flush to zero

we route the backward pass (loss.backward) and the optimizer step (optimizer.step) through torch.cuda.amp.GradScaler: scale the loss before calling backward, then step and update via the scaler
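putting the three steps together, here is a minimal training-loop sketch; the model, optimizer and data below are just placeholders to keep it self-contained:

```python
import torch
import torch.nn as nn

device = "cuda"

# placeholder model, data and optimizer, just to make the sketch runnable
model = nn.Linear(128, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
inputs = torch.randn(32, 128, device=device)
targets = torch.randint(0, 10, (32,), device=device)

# step 1: backend flags
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True

# step 3 setup: gradient scaler (needed for float16; bfloat16 usually does not need it)
scaler = torch.cuda.amp.GradScaler()

for step in range(100):
    optimizer.zero_grad(set_to_none=True)

    # step 2: forward pass and loss calculation under autocast
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

    # step 3: scale the loss, backprop, then step and update through the scaler
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```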

refer to this 👉 https://lnkd.in/g_EyxFdY post to understand automatic mixed precision in detail