2024-09-07
Under the Hood of torch.compile
Let's look behind the scenes at what happens when we use the torch.compile API.
torch.compile is a way to move from eager to graph execution mode; essentially, it has a compiler that does the final work.
To produce the compiled model, torch.compile executes 3 steps:
- graph acquisition - captures the model definition and transforms it into a graph of ops. It does this by reading the Python bytecode just before it is executed.
- graph lowering - simplifies and optimizes the graph by fusing, combining and reducing ops.
- graph compilation - generates device code for different architectures, vendors and devices like GPU or even TPU.
Step 1 is executed by TorchDynamo, which hooks into CPython through its frame evaluation API (PEP 523).
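To get a feel for what "reading Python bytecode" means, the standard-library dis module can print the bytecode of any function. This is only an illustration of the raw material graph acquisition starts from, not how TorchDynamo itself is implemented; the function scale_and_shift is a made-up example, and the exact opcode names depend on the Python version.

```python
import dis


def scale_and_shift(x, y):
    # A tiny function standing in for a model's forward pass (hypothetical example).
    return x * y + 1


# Disassemble the function into its bytecode instructions.
instructions = list(dis.get_instructions(scale_and_shift))
opnames = [ins.opname for ins in instructions]

# Print one instruction per line: opcode name plus a readable argument.
for ins in instructions:
    print(ins.opname, ins.argrepr)
```

The multiply and add show up as separate bytecode instructions; a tracer that sees them on tensors can record them as ops in a graph instead of just running them.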
Steps 2 and 3 are executed by the backend compiler, TorchInductor by default, which uses the OpenMP framework (for CPU code) and the Triton compiler (for GPU code). The backend is configurable, btw, via the backend param of torch.compile; there are many available backends.
@Shashank Prasanna's super amazing post about torch.compile