PyTorch 2.x and Backends

Disclaimer: Human “generated” text as a labor of love.

The Great ML Framework Debate

Centuries ago (2020), I contributed to the PyTorch libraries, specifically TorchServe, which used to be the default model serving library for PyTorch models.

Back then, ML frameworks were split into two factions:

  • Eager Mode
  • Graph Mode

Graph Mode was a strict programming model that required models to be expressible as compute DAGs, with each node of the graph expressed as a single operation whose inputs and outputs were statically sized. This made life easier for the frameworks to optimize with a compiler, since graph optimizations are readily achievable offline when shapes are static, in a way they are not with dynamic shapes. This can be seen as analogous to programming in statically typed languages such as C/C++. The ML frameworks in question were TensorFlow 1.x, MXNet, etc.
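For a flavor of what that looked like, here is a minimal TensorFlow 1.x-style sketch: the graph is declared first with statically sized nodes, and only then executed.

import tensorflow as tf  # TensorFlow 1.x API

# Build the graph: nothing is computed yet, we only declare statically sized nodes.
a = tf.placeholder(tf.float32, shape=(2, 2))
b = tf.placeholder(tf.float32, shape=(2, 2))
c = tf.matmul(a, b)  # adds a node to the compute DAG

# Only now does the (offline-optimizable) graph actually run, with concrete inputs.
with tf.Session() as sess:
    result = sess.run(c, feed_dict={a: [[1.0, 2.0], [3.0, 4.0]],
                                    b: [[5.0, 6.0], [7.0, 8.0]]})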

Eager Mode was easier to code and debug with, preferred by users who just wanted to get on with it. If you are concerned more with the outcome of the model you are building, then this is how you would be thinking. It was analogous to programming in Python. PyTorch was the flagbearer for eager mode; funnily enough, MXNet also supported it.
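The eager equivalent in PyTorch needs no separate graph-building step; every line executes immediately:

import torch

a = torch.randn(2, 2)
b = torch.randn(2, 2)
c = a @ b        # runs right away; c holds real values you can print or pdb into
if c.sum() > 0:  # plain Python control flow just works
    c = c * 2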

PyTorch - Graph Modes

PyTorch wanted in, so it made several attempts at also becoming a Graph Mode execution framework (a quick sketch of the first two follows the list):

  • torch.jit.trace
  • torch.jit.script
  • Lazy Tensors
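For reference, here is a minimal sketch of how the first two were used:

import torch

def f(x):
    return torch.relu(x) + 1

# trace: runs f once with an example input and records the ops it executed;
# Python control flow is baked in for that single path
traced = torch.jit.trace(f, torch.randn(4))

# script: compiles the Python source itself, so data-dependent
# control flow survives into the graph
@torch.jit.script
def g(x):
    if x.sum() > 0:
        return x * 2
    return x - 1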

But the users were not interested.

PyTorch Compile

In PyTorch 2.x, we have the option to compile the model and make it faster where we can. Not every part of the model can necessarily be captured as a graph, but whatever can be captured gets optimized.

import torch

# model is your torch.nn.Module
compiled_model = torch.compile(model, backend="inductor", mode="max-autotune", dynamic=True)

Here, backend selects the compiler backend ("inductor" is the default), mode="max-autotune" lets Inductor spend extra compile time searching for the fastest kernels, and dynamic=True asks the compiler to build kernels that work across input shapes instead of specializing on the first shapes it sees.
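As a self-contained sketch (the two-layer model below is made up purely for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

compiled_model = torch.compile(model)  # backend defaults to "inductor"

x = torch.randn(32, 64)
out = compiled_model(x)  # first call triggers capture + compilation; later calls reuse it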

TorchDynamo:

Graph capture for PyTorch inside torch.compile. TorchDynamo hooks into CPython's frame evaluation to rewrite Python bytecode just before it runs, extracting sequences of tensor operations into FX graphs.
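To see what Dynamo captures, and where it gives up, torch._dynamo.explain is handy. A sketch against recent 2.x releases (the explain call signature has shifted between versions):

import torch

def fn(a, b):
    x = a + b
    if x.sum() > 0:  # data-dependent branch forces a graph break
        x = x * 2
    return x

explanation = torch._dynamo.explain(fn)(torch.randn(8), torch.randn(8))
print(explanation)  # reports graph count, graph break count, and break reasons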

TorchInductor:

The default compiler backend for torch.compile. It takes the FX graphs that Dynamo captures and generates Triton kernels for GPUs and C++/OpenMP code for CPUs.
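Inductor is the default, but other registered backends can be plugged in; you can list what your install knows about (the output below is illustrative, it depends on your environment):

import torch

print(torch._dynamo.list_backends())
# e.g. ['cudagraphs', 'inductor', 'onnxrt', 'openxla', 'tvm']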

Defining a Custom Backend for torch.compile()

In Tesla Autopilot, we use a custom ASIC / inference computer on board. We always wanted to define a backend so that the model compilation stack could be abstracted away from the model developers. But the constraints of the hardware required operations (model partitioning, sometimes by hand) that were done better when compiled statically offline.

Here, we will see what a simple user-defined backend looks like. In a real-world backend, the sample inputs would be used to drive optimizations. Again, notice that such optimizations are possible only when input shapes and sizes are known.

from typing import List
import torch
 
def my_compiler(gm: torch.fx.GraphModule,
                sample_inputs: List[torch.Tensor]):
    # A backend receives the captured FX graph plus example inputs;
    # a real backend would lower/optimize gm here. We just inspect it.
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm  # must return a callable; the GraphModule itself is one
 
@torch.compile(backend=my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
 
for _ in range(1000):
    toy_example(torch.randn(10), torch.randn(10))
 
my_compiler() called with FX graph:
opcode         name    target                                                  args         kwargs
-------------  ------  ------------------------------------------------------  -----------  --------
placeholder    l_a_    L_a_                                                    ()           {}
placeholder    l_b_    L_b_                                                    ()           {}
call_function  abs_1   <built-in method abs of type object at 0x7fdce4beff00>  (l_a_,)      {}
call_function  add     <built-in function add>                                 (abs_1, 1)   {}
call_function  x       <built-in function truediv>                             (l_a_, add)  {}
call_method    sum_1   sum                                                     (l_b_,)      {}
call_function  lt      <built-in function lt>                                  (sum_1, 0)   {}
output         output  output                                                  ((x, lt),)   {}


my_compiler() called with FX graph:
opcode         name    target                   args         kwargs
-------------  ------  -----------------------  -----------  --------
placeholder    l_b_    L_b_                     ()           {}
placeholder    l_x_    L_x_                     ()           {}
call_function  b       <built-in function mul>  (l_b_, -1)   {}
call_function  mul_1   <built-in function mul>  (l_x_, b)    {}
output         output  output                   ((mul_1,),)  {}


my_compiler() called with FX graph:
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    l_x_    L_x_                     ()            {}
placeholder    l_b_    L_b_                     ()            {}
call_function  mul     <built-in function mul>  (l_x_, l_b_)  {}
output         output  output                   ((mul,),)     {}

Let’s break this down:

Our toy_example function is run with two tensors of 10 random values each. Simple as it is, it can be broadly divided into three parts:

COMPUTE 1

IF COND:
    COMPUTE 2

COMPUTE 3

The three graphs printed in tabular form map onto these pieces. TorchDynamo cannot capture the data-dependent if b.sum() < 0 into a single graph, so it inserts a graph break there: the first graph is COMPUTE 1 plus the branch condition, the second is the path where the condition is true (COMPUTE 2 followed by COMPUTE 3), and the third is the false path (COMPUTE 3 alone). Note also that although the loop calls toy_example 1000 times, my_compiler runs only three times; every later call is served from the compiled-code cache as long as Dynamo's guards still hold.
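If a graph break is unacceptable, say you need one whole graph to hand to an accelerator, fullgraph=True turns the break into a hard error instead. A sketch, using an undecorated copy of toy_example:

import torch

def toy_example_fullgraph(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

compiled = torch.compile(toy_example_fullgraph, fullgraph=True)
# Raises an error: the data-dependent `if b.sum() < 0` cannot be captured
# into a single graph, so Dynamo refuses instead of silently splitting.
compiled(torch.randn(10), torch.randn(10))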

import torch
print(torch.__version__)
2.6.0+cu124
This script tests the dynamic compilation feature of PyTorch with a large tensor: a torch.compile-decorated function chains sin and cos over its input, and we call it with an 8192x8192 tensor on a CUDA device. Setting TORCH_COMPILE_DEBUG=1 makes Inductor write its debug artifacts to disk.

import os
os.environ['TORCH_COMPILE_DEBUG'] = '1'
import torch
 
@torch.compile(dynamic=True)
def foo(x):
    y = x.sin()
    z = y.cos()
    return y, z
 
foo(torch.randn([8192, 8192], device='cuda'))
/home/mlaidev/software/jsr_gdm/.venv/lib/python3.13/site-packages/torch/utils/_config_module.py:342: UserWarning: Skipping serialization of skipfiles_inline_module_allowlist value {}
  warnings.warn(
W0526 09:03:48.447000 189635 /home/mlaidev/software/jsr_gdm/.venv/lib/python3.13/site-packages/torch/_inductor/debug.py:435] [1/0] model__1_inference_1 debug trace: /home/mlaidev/software/jsr_gdm/pytorch_backend/torch_compile_debug/run_2025_05_26_09_03_06_493183-pid_189635/torchinductor/model__1_inference_1.1





(tensor([[-0.5081,  0.8742, -0.7563,  ...,  0.7116,  0.6417,  0.8982],
         [-0.1069, -0.9184, -0.6707,  ..., -0.3978, -0.6187,  0.8878],
         [ 0.6196, -0.6739, -0.4469,  ...,  0.7490, -0.9443,  0.7898],
         ...,
         [ 0.8925, -0.9631, -0.9991,  ...,  0.3309, -0.1943, -0.1613],
         [-0.9386, -0.8649, -0.9941,  ...,  0.4611, -0.2239,  0.1002],
         [ 0.6346, -0.2102,  0.8530,  ..., -0.9979, -0.7121, -0.8892]],
        device='cuda:0'),
 tensor([[0.8737, 0.6416, 0.7274,  ..., 0.7573, 0.8011, 0.6230],
         [0.9943, 0.6071, 0.7834,  ..., 0.9219, 0.8146, 0.6312],
         [0.8141, 0.7814, 0.9018,  ..., 0.7324, 0.5863, 0.7040],
         ...,
         [0.6274, 0.5709, 0.5410,  ..., 0.9458, 0.9812, 0.9870],
         [0.5909, 0.6487, 0.5453,  ..., 0.8956, 0.9750, 0.9950],
         [0.8053, 0.9780, 0.6577,  ..., 0.5421, 0.7570, 0.6300]],
        device='cuda:0'))
!ls  /home/mlaidev/software/jsr_gdm/pytorch_backend/torch_compile_debug/run_2025_05_26_09_03_06_493183-pid_189635/torchinductor/model__1_inference_1.1
fx_graph_readable.py  fx_graph_transformed.py  ir_pre_fusion.txt
fx_graph_runnable.py  ir_post_fusion.txt       output_code.py
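Roughly what these artifacts are (naming follows Inductor's debug dump and may shift between releases): fx_graph_readable.py is a human-readable print of the captured FX graph, fx_graph_runnable.py is a standalone script that reproduces the graph, fx_graph_transformed.py is the graph after Inductor's pre-lowering passes, ir_pre_fusion.txt and ir_post_fusion.txt show Inductor's internal IR before and after kernel fusion, and output_code.py is the final generated code (Triton kernels, in this CUDA run) that actually executes.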