2024-09-07

Under the Hood of torch.compile

Let's look behind the scenes at what happens when we use the torch.compile API.



torch.compile is a way to move from eager execution to graph execution mode: it captures the model as a graph and hands that graph to a compiler, which does the final work of generating optimized code.

To get the compiled model, torch.compile runs 3 steps:

  1. graph acquisition - this step captures the model definition and transforms it into a graph of ops. It does this by reading the Python bytecode just before it is executed.

  2. graph lowering - now it's time to simplify and optimize the graph by fusing, combining and reducing ops.

  3. graph compilation - in this step device code is generated for the target hardware, which can span different architectures, vendors and devices such as GPUs or even TPUs.
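
To make the "reads Python bytecode" part of step 1 a bit more concrete, here is a small sketch that uses only the standard library dis module to print the bytecode of a function. The function name scale_and_shift is made up for illustration, and this only shows the kind of input the graph acquisition step works from, not how the capture itself is implemented. The exact opcode names vary between Python versions.

```python
import dis


def scale_and_shift(x):
    # Hypothetical example function; any Python function would do.
    return x * 2 + 1


# Collect the bytecode instructions CPython would execute for this function.
instructions = list(dis.get_instructions(scale_and_shift))
opnames = [ins.opname for ins in instructions]

# Print one instruction per line: offset, opcode name, and argument.
for ins in instructions:
    print(ins.offset, ins.opname, ins.argrepr)
```

Each printed line is one bytecode instruction, and it is this instruction stream, rather than the source text, that gets analyzed to build the graph of ops.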

Step 1 is executed by TorchDynamo, which hooks into CPython through its frame evaluation API (PEP 523).
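
One way to peek at what step 1 produces is to pass a callable as the backend argument of torch.compile. The sketch below assumes PyTorch 2.x. The names inspecting_backend, captured_ops and f are made up for illustration. This backend only records the ops in the captured graph and then runs the graph as-is, so no lowering or compilation happens here.

```python
import torch
import torch.fx

captured_ops = []  # filled in by the backend below


def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    # The captured graph arrives as an FX GraphModule; record each node.
    for node in gm.graph.nodes:
        captured_ops.append((node.op, str(node.target)))
    # Return a callable that simply runs the captured graph unchanged.
    return gm.forward


def f(x):
    return torch.relu(x) + 1.0


compiled_f = torch.compile(f, backend=inspecting_backend)

x = torch.randn(4)
out = compiled_f(x)  # the first call triggers graph capture

for op, target in captured_ops:
    print(op, target)
```

The printed list should contain placeholder nodes for the inputs, call_function nodes for the ops, and an output node. A real backend would optimize this graph and generate code for it instead of returning gm.forward.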

Steps 2 and 3 are executed by the backend compiler. The default is TorchInductor, which generates C++ with OpenMP for CPUs and Triton kernels for GPUs. This is configurable, btw, using the backend param of torch.compile, and there are many available backends.

@Shashank Prasanna's super amazing post about torch.compile