Part 12: XLA and Compilation
Turning Python Into Optimized GPU Code
Python code runs one operation at a time, with overhead between each step. Compilation analyzes the entire computation graph, fuses operations, and generates optimized GPU code. The result: faster training on the same hardware, typically 10-40% and up to 2-3× for fusion-friendly models.
This part explains XLA, torch.compile, and the tradeoffs of compilation (speed vs. debugging difficulty).
Why Should a Leader Care?
When engineers mention:
“We’re enabling XLA for faster training”
“torch.compile gave us a 2× speedup”
“JIT compilation is causing issues”
...they’re talking about a technique that can significantly speed up training. Understanding compilation helps you:
Know when it helps and when it doesn’t
Understand debugging tradeoffs
Appreciate why some code runs faster than similar-looking code
The One Concept: Compile Instead of Interpret
Normal Python runs operation by operation: do step 1, return to Python, do step 2, return to Python, etc.
Compilation analyzes the entire computation, optimizes it, and runs everything as one fused unit.
Example without compilation:
1. Launch matmul kernel on GPU → wait → return to Python
2. Launch add kernel on GPU → wait → return to Python
3. Launch relu kernel on GPU → wait → return to Python
Overhead: 3 kernel launches, 3 Python round-trips
Example with compilation:
1. Compile: fuse matmul + add + relu into one kernel
2. Launch fused kernel on GPU → wait → return to Python
Overhead: 1 kernel launch, 1 Python round-trip
Result: Less overhead, better memory usage, faster execution.
What Is a Kernel?
A kernel is a program that runs on the GPU. Every operation (matmul, add, relu) is implemented as a kernel.
Without fusion: Each operation = separate kernel
Kernel 1: Read inputs from HBM, compute matmul, write output to HBM
Kernel 2: Read output from HBM, compute add, write to HBM
Kernel 3: Read output from HBM, compute relu, write to HBM
With fusion: One combined kernel
Kernel: Read inputs from HBM, compute matmul + add + relu in registers, write final output to HBM
Why fusion is faster: Intermediate results stay in fast registers, avoiding slow HBM reads/writes.
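A rough way to see why fusion saves memory traffic is to count the bytes each kernel moves to and from HBM. The sketch below assumes fp32 (4 bytes per element), counts only traffic on the matmul output and its elementwise follow-ups (reads of the matmul inputs and the bias are the same either way), and borrows the 1000×4096 output shape from the XLA example below. The function names are hypothetical.

```python
# Rough HBM-traffic model for "matmul, then N elementwise ops" (illustrative).
# Assumes fp32 and counts only traffic on the matmul output and its intermediates.

BYTES_PER_ELEMENT = 4  # fp32


def unfused_traffic(num_elements, num_elementwise_ops):
    """Matmul writes its output once; each elementwise kernel reads it and writes a new copy."""
    tensor_bytes = num_elements * BYTES_PER_ELEMENT
    writes = 1 + num_elementwise_ops  # matmul output + one output per elementwise op
    reads = num_elementwise_ops       # each elementwise op re-reads the previous result
    return (reads + writes) * tensor_bytes


def fused_traffic(num_elements):
    """Intermediates stay in registers; only the final result is written to HBM."""
    return num_elements * BYTES_PER_ELEMENT


rows, cols = 1000, 4096  # output shape of matmul(x, W)
n = rows * cols          # 4,096,000 elements
ops_after_matmul = 2     # add, relu

unfused = unfused_traffic(n, ops_after_matmul)
fused = fused_traffic(n)
print(f"one output copy: {n * BYTES_PER_ELEMENT / 1e6:.1f} MB")
print(f"unfused: {unfused / 1e6:.1f} MB, fused: {fused / 1e6:.1f} MB")
```

Under these assumptions one copy of the output is about 16.4 MB, so the unfused chain moves about 81.9 MB of intermediate traffic versus 16.4 MB for the fused kernel.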
XLA (Accelerated Linear Algebra)
XLA is Google’s compiler for ML workloads. Used by:
JAX (by default, always on)
TensorFlow (optional, via @tf.function(jit_compile=True))
PyTorch/XLA (for TPU support)
What XLA Does
1. Operation fusion: Combine multiple operations into one kernel.
Example: y = relu(matmul(x, W) + b)
Without XLA:
matmul: read x (1000×10000), read W (10000×4096), compute, write output (1000×4096)
add: read output (1000×4096), read b (4096), compute, write output
relu: read output (1000×4096), compute, write output
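Data moved for the intermediate result: the 1000×4096 output (~16 MB in fp32) is written 3 times and read 2 times, roughly 5 × 16 MB ≈ 82 MB (reads of x, W, and b are the same either way)
With XLA fusion:
Fused kernel: read x, read W, read b, compute all operations, write final output
Data moved for the intermediate result: 1× write of the final output (~16 MB)
Speedup: up to 2-3× for memory-bound workloads. Elementwise ops and normalizations in ML models are usually memory-bound, while large matmuls like this one tend to be compute-bound, so end-to-end gains are often smaller.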
2. Memory planning: Allocate exactly the memory needed, reuse buffers when safe.
3. Layout optimization: Rearrange tensor storage for faster access patterns.
4. Target-specific optimization: Generate code optimized for specific hardware (GPU vs TPU).
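Memory planning can be pictured as a liveness problem: once a tensor's last reader has run, its buffer can be handed to a later tensor of the same or smaller size. The planner below is a simplified greedy scheme in plain Python, not XLA's actual algorithm, and the sizes in the usage example are made up.

```python
# Toy greedy buffer planner (illustrative; not XLA's real algorithm).
# Op i produces tensor i; last_use[i] is the index of the last op that reads tensor i.


def plan_buffers(sizes, last_use):
    """Assign each tensor a buffer id, reusing freed buffers that are large enough.

    Returns (assignment, total_bytes_allocated).
    """
    buffer_sizes = []  # size of each allocated buffer, indexed by buffer id
    free = []          # buffer ids currently available for reuse
    assignment = []
    for i, size in enumerate(sizes):
        # Reuse the smallest free buffer that fits; otherwise allocate a new one.
        candidates = [k for k in free if buffer_sizes[k] >= size]
        if candidates:
            buf = min(candidates, key=lambda k: buffer_sizes[k])
            free.remove(buf)
        else:
            buf = len(buffer_sizes)
            buffer_sizes.append(size)
        assignment.append(buf)
        # After op i runs, release every tensor whose last reader was op i.
        for j in range(i + 1):
            if last_use[j] == i:
                free.append(assignment[j])
    return assignment, sum(buffer_sizes)


# Four same-sized ops in a chain, each output read only by the next op.
assignment, total = plan_buffers([16, 16, 16, 16], [1, 2, 3, 3])
print(assignment, total)
```

For this chain the planner alternates between two buffers, allocating 32 units instead of the 64 a naive one-buffer-per-tensor scheme would use. Note that an op's output is allocated before its inputs are released, since the op still reads them while writing.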
PyTorch Compilation: torch.compile
For years, PyTorch prioritized eager execution (easy debugging) over compilation.
In PyTorch 2.0 (2023), torch.compile was introduced:
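import torch

model = MyModel()                  # your nn.Module (defined elsewhere)
model = torch.compile(model)       # That's it
optimizer = torch.optim.AdamW(model.parameters())

# Training loop runs with the compiled model
for batch in dataloader:           # your DataLoader (defined elsewhere)
    optimizer.zero_grad()
    loss = model(batch)            # assumes forward() returns the loss
    loss.backward()
    optimizer.step()
Under the Hood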
TorchDynamo: Captures the computation graph from Python execution (bytecode analysis)
TorchInductor: Compiler backend that generates optimized CUDA/CPU code
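As a very rough illustration of what "bytecode analysis" starts from, the standard-library `dis` module can list the global names a function loads, without running it. This only shows the raw material: TorchDynamo itself hooks CPython's frame evaluation (PEP 523) and symbolically executes the bytecode to build a graph, which this sketch does not attempt. `matmul`, `add`, and `relu` here are placeholder functions, not PyTorch ops.

```python
import dis


# Placeholder "ops" so forward() below is valid Python (illustrative only).
def matmul(x, w):
    return x * w


def add(x, b):
    return x + b


def relu(x):
    return max(x, 0)


def forward(x, w, b):
    return relu(add(matmul(x, w), b))


# Walk forward()'s bytecode without executing it and collect the global names it loads.
called = [ins.argval for ins in dis.get_instructions(forward)
          if ins.opname == "LOAD_GLOBAL"]
print(called)
```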
Speedups
Typical: 10-40% faster training
Best case: 2-3× faster (for fusion-friendly models)
Worst case: No speedup or slower (e.g., frequent graph breaks or recompilation from dynamic shapes)
First Run Overhead
Step 1: 30 seconds (compiling)
Step 2: 1 second
Step 3: 1 second
...
Step 1000: 1 second
Compile time: 10-60 seconds (depends on model size)
Amortized: Over thousands of steps, negligible
Compilation Tradeoffs
Compile Time Cost
First execution is slow: the compiler analyzes and optimizes the graph.
For training loops: Pay once (30s), benefit for millions of steps (weeks). Worth it.
For one-off inference: Might not be worth it if only running once.
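Whether compilation pays off comes down to a break-even step count: the compile cost divided by the time saved per step. The timings below are hypothetical: a 30-second compile, as in the example above, and a step that drops from 1.3 s eager to 1.0 s compiled, inside the 10-40% range quoted earlier.

```python
# Break-even estimate for JIT compilation (hypothetical timings, in milliseconds).


def break_even_steps(compile_ms, eager_step_ms, compiled_step_ms):
    """Smallest step count at which time saved is at least the compile cost."""
    saved_per_step = eager_step_ms - compiled_step_ms
    if saved_per_step <= 0:
        return None  # compilation never pays off
    return -(-compile_ms // saved_per_step)  # ceiling division on integers


steps = break_even_steps(compile_ms=30_000, eager_step_ms=1_300, compiled_step_ms=1_000)
print(steps)  # 100
```

A training run of thousands of steps clears 100 easily; a script that handles only a few requests may never reach it.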
Typical compile times:
Small model (GPT-2): 5-15 seconds
Medium model (7B): 30-60 seconds
Large model (70B): 1-3 minutes
Dynamic Shapes
If your input sizes change, the compiler might need to recompile.
Static shapes (batch=32, seq=2048 every time):
Compile once, run forever. Ideal.
Dynamic shapes (batch varies, seq varies):
Multiple compilations (one per shape combination)
Or fall back to eager execution
PyTorch 2.x has improved dynamic shape support, but code compiled for dynamic shapes can still run slower than code specialized to fixed shapes
Real-world: Most training uses static shapes. Inference often has dynamic shapes (variable-length inputs).
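Recompilation under dynamic shapes behaves roughly like a cache keyed on input shape: a new shape is a cache miss and triggers another compile. The toy `ShapeCache` below is hypothetical and far simpler than PyTorch's real guard mechanism; it counts compiles for variable-length sequences, with and without padding each length up to a fixed set of bucket sizes (the bucket values are assumptions for illustration).

```python
# Toy shape-keyed compile cache (illustrative; real compilers use guards, not just a dict).


class ShapeCache:
    def __init__(self):
        self.compiled = {}  # shape -> stand-in for a compiled kernel
        self.compiles = 0

    def run(self, batch, seq_len):
        shape = (batch, seq_len)
        if shape not in self.compiled:  # cache miss: pay the compile cost again
            self.compiles += 1
            self.compiled[shape] = f"kernel{shape}"
        return self.compiled[shape]


def pad_to_bucket(seq_len, buckets=(128, 256, 512, 1024)):
    """Round seq_len up to the smallest bucket that fits."""
    for b in buckets:
        if seq_len <= b:
            return b
    raise ValueError(f"sequence length {seq_len} exceeds largest bucket")


seq_lens = [97, 130, 255, 300, 512, 77, 1000, 640]

dynamic = ShapeCache()
for s in seq_lens:
    dynamic.run(batch=32, seq_len=s)  # every distinct length triggers a compile

bucketed = ShapeCache()
for s in seq_lens:
    bucketed.run(batch=32, seq_len=pad_to_bucket(s))

print(dynamic.compiles, bucketed.compiles)
```

Here eight distinct lengths cause eight compiles, while bucketing reduces them to four. The price is wasted compute on padding tokens.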
Debugging Difficulty
Compiled code is a black box. Can’t easily:
Print intermediate values
Set breakpoints inside fused kernels
See which operation caused NaN
Best practice:
Development: Disable compilation for debugging
Production: Enable for performance
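One reason eager mode is easier to debug is that every intermediate value exists and can be checked. The stdlib sketch below applies a hypothetical list of named steps one at a time and reports the first step that produced a NaN; inside a fused kernel those intermediates never exist as separate values you could inspect this way.

```python
import math


def run_eagerly_with_checks(x, steps):
    """Apply named steps one at a time; return (values, name of first step producing NaN)."""
    for name, fn in steps:
        x = [fn(v) for v in x]
        if any(math.isnan(v) for v in x):
            return x, name  # stop at the first offending step
    return x, None


steps = [
    ("scale_up", lambda v: v * 1e308),   # 2.0 * 1e308 overflows to inf
    ("subtract_self", lambda v: v - v),  # inf - inf = nan
    ("relu", lambda v: max(v, 0.0)),
]

_, culprit = run_eagerly_with_checks([1.0, 2.0], steps)
print(culprit)
```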
Compatibility
Not all operations compile cleanly.
Compile well:
Standard ops: matmul, conv, relu, softmax
Transformers: attention, MLP
Problematic:
Custom Python code (arbitrary Python in forward pass)
Dynamic control flow: if/else based on tensor values
Unsupported ops: some third-party libraries
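Graph breaks: the PyTorch compiler stops capturing at an unsupported construct, runs that part as regular Python, and starts a new graph afterward. The result is several smaller graphs instead of one, which hurts performance.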
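Graph breaks can be pictured as a partition of the forward pass: runs of supported ops become compiled subgraphs, and each unsupported op runs eagerly in between. This is a simplified model with a hypothetical `SUPPORTED` set and a hypothetical `print_debug` op; TorchDynamo's real rules are far more detailed.

```python
# Toy model of graph breaks: split an op sequence into compiled segments and eager ops.

SUPPORTED = {"matmul", "add", "relu", "softmax", "layernorm"}  # hypothetical


def partition(ops):
    """Return a list of ("compiled", [ops...]) and ("eager", op) segments."""
    segments, current = [], []
    for op in ops:
        if op in SUPPORTED:
            current.append(op)
        else:
            if current:  # close the current graph at the break
                segments.append(("compiled", current))
                current = []
            segments.append(("eager", op))  # unsupported op runs in plain Python
    if current:
        segments.append(("compiled", current))
    return segments


forward_ops = ["matmul", "add", "relu", "print_debug", "matmul", "softmax"]
segments = partition(forward_ops)
num_graphs = sum(1 for kind, _ in segments if kind == "compiled")
print(segments)
print(num_graphs)  # 2
```

One unsupported op in the middle turns a single graph into two compiled pieces plus an eager step, and each boundary forces a round-trip through Python.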
JIT vs AOT Compilation
JIT (Just-In-Time): Compile when you first run the code.
Used by:
torch.compile, JAX’s jit()
Pro: Adapts to actual input shapes
Con: Compilation happens during first training step (30s delay)
AOT (Ahead-Of-Time): Compile before running.
Used in: Deployment (ONNX, TensorRT)
Pro: No runtime compilation overhead
Con: Must know shapes and structure ahead of time
ML training: Almost always JIT (the compiler specializes to the shapes it sees at runtime, and the one-time cost is amortized over the run)
ML inference: Often AOT (compile once, deploy to millions of devices)
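The practical difference between JIT and AOT comes down to when the compile step runs. In the sketch below, a JIT wrapper compiles lazily on the first call with each new shape, while an AOT artifact is built up front for declared shapes and rejects anything else. Both classes are hypothetical illustrations, not any real library's API.

```python
# Hypothetical illustration of JIT vs AOT compilation timing.


class JitFunction:
    """Compiles on the first call for each new input shape."""

    def __init__(self):
        self.compiled_shapes = set()

    def __call__(self, shape):
        if shape not in self.compiled_shapes:
            self.compiled_shapes.add(shape)  # compile cost is paid here, at runtime
        return f"ran {shape}"


class AotArtifact:
    """Compiled ahead of time for a fixed set of shapes; no compilation at runtime."""

    def __init__(self, shapes):
        self.compiled_shapes = frozenset(shapes)  # cost was paid before deployment

    def __call__(self, shape):
        if shape not in self.compiled_shapes:
            raise ValueError(f"shape {shape} was not compiled ahead of time")
        return f"ran {shape}"


jit_fn = JitFunction()
jit_fn((32, 2048))
jit_fn((32, 1024))  # new shape: the JIT simply compiles again
print(sorted(jit_fn.compiled_shapes))

aot_fn = AotArtifact([(1, 512)])
aot_fn((1, 512))  # declared shape: runs with no compile step
try:
    aot_fn((1, 1024))  # undeclared shape: the AOT artifact cannot adapt
except ValueError as err:
    print(err)
```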
Leader Implications
“We’re enabling torch.compile”
Good for performance (10-40% speedup expected). Expect some initial debugging for compatibility issues.
“XLA is giving us 2× speedup on TPUs”
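Plausible. XLA is the native compiler for TPUs and the XLA + TPU stack is heavily optimized by Google; 2-3× over unfused execution is realistic for fusion-friendly models. JAX uses XLA by default.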
“Compilation is failing for this model”
Some operation isn’t supported by the compiler. They’ll need to work around it or fall back to eager execution.
“We’re seeing graph breaks”
PyTorch compiler couldn’t capture the entire graph. Part runs compiled, part runs eagerly. Not ideal—investigate why.
“The first step took 90 seconds”
Normal. JIT compilation overhead. Subsequent steps are fast (1-2 seconds).
“Dynamic shapes are causing recompilation”
Input dimensions vary, forcing multiple compilations. Solutions: pad to fixed size, use dynamic shape mode (slower), or accept recompilation cost.
“Compiled models are harder to debug”
True. Disable compilation during debugging (model = torch.compile(model, disable=True) or just don’t compile).
Vocabulary Checkpoint
XLA: Accelerated Linear Algebra — Google’s ML compiler
torch.compile: PyTorch 2.0’s compilation system
Kernel: GPU program that executes an operation
Fusion: Combining multiple operations into one kernel
JIT compilation: Compiling at runtime (first execution)
AOT compilation: Compiling before runtime (deployment)
Graph break: When compiler can’t capture complete graph (PyTorch)
TorchDynamo: PyTorch’s graph capture mechanism
TorchInductor: PyTorch’s code generation backend
What’s Next?
You now have all the pieces: training fundamentals, hardware, parallelism, networks, frameworks, and compilation. Let’s put it all together. What does a real large-scale training job actually look like end-to-end?
Next time: a complete system design for training a 70B parameter model from scratch.