🔍 PyTorch Debugging & Profiling: Becoming a Code Detective
Imagine you’re a detective solving mysteries in a big factory. Sometimes machines break down, sometimes they run too slow, and sometimes they forget what they learned. Your job? Find out why and fix it!
🎭 The Detective’s Analogy
Think of training a neural network like running a chocolate factory:
- Memory = Your factory’s warehouse space
- Gradients = Instructions workers pass along the assembly line
- Profiler = A security camera that watches everything
- Hooks = Secret spies you place at specific machines
- Tensors = The chocolate bars moving through the factory
When something goes wrong, you need detective tools to find the problem!
🎥 torch.profiler: Your Security Camera System
What Is It?
The profiler is like having security cameras everywhere in your factory. It records:
- How long each machine takes
- How much power each machine uses
- Which machine is the slowest
Simple Example
import torch
from torch.profiler import profile

# Start recording everything
with profile() as prof:
    # Your model does its work
    output = model(input_data)
    loss = criterion(output, target)
    loss.backward()

# See what happened, slowest operations first
print(prof.key_averages().table(sort_by="cpu_time_total"))
What You’ll See
Name CPU Time GPU Time
-------------- -------- --------
forward 2.5ms 1.2ms
backward 4.1ms 2.8ms
optimizer.step 0.8ms 0.3ms
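To make sure GPU kernels are timed as well (the GPU Time column above), you can request CUDA activity explicitly and sort the report by GPU time. A minimal sketch, reusing the model, input_data, criterion, and target names from above:
from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
    output = model(input_data)
    loss = criterion(output, target)
    loss.backward()

# Sort by GPU time to find the slowest kernels
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))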
Pro Detective Tip 🕵️
# Record traces you can view later in TensorBoard
with profile(
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs')
) as prof:
    for step, data in enumerate(loader):
        model(data)
        prof.step()  # Mark the end of each step
Now you can view beautiful charts in TensorBoard!
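To open those charts, point TensorBoard at the log folder, for example with tensorboard --logdir ./logs. The PyTorch profiler view is provided by the torch-tb-profiler plugin, which may need to be installed separately.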
💾 Memory Profiling: Checking Your Warehouse
The Problem
Your factory warehouse (GPU memory) can overflow! When it does:
- Training crashes
- You see “CUDA out of memory”
- Everything stops
How to Check Memory
import torch
# See current warehouse usage
print(torch.cuda.memory_allocated())
# Output: 1073741824 (bytes used)
# See maximum ever used
print(torch.cuda.max_memory_allocated())
# See full memory report
print(torch.cuda.memory_summary())
Visual Memory Map
graph TD
    A["Total GPU Memory: 8GB"] --> B["Model Weights: 2GB"]
    A --> C["Activations: 3GB"]
    A --> D["Gradients: 2GB"]
    A --> E["Free Space: 1GB"]
    style A fill:#4ECDC4
    style B fill:#FF6B6B
    style C fill:#FFE66D
    style D fill:#95E1D3
    style E fill:#DFE6E9
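To compare this picture with your real device, you can ask the driver directly. A minimal sketch using torch.cuda.mem_get_info, which returns the free and total bytes for the current GPU (the count includes other processes and PyTorch's cache):
import torch

free_bytes, total_bytes = torch.cuda.mem_get_info()
used_bytes = total_bytes - free_bytes
print(f"GPU memory: {used_bytes/1e9:.2f} GB used "
      f"of {total_bytes/1e9:.2f} GB total")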
Finding Memory Leaks
# Before the training loop
torch.cuda.reset_peak_memory_stats()

for epoch in range(10):
    train_one_epoch()

    # Check after each epoch
    current = torch.cuda.memory_allocated()
    peak = torch.cuda.max_memory_allocated()
    print(f"Epoch {epoch}: {current/1e9:.2f}GB in use, "
          f"{peak/1e9:.2f}GB peak")
If memory keeps growing each epoch, you have a leak!
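One of the most common leaks is keeping a reference to a tensor that is still attached to the autograd graph, typically while logging the loss. A minimal sketch of the leak and the fix (running_loss is a hypothetical logging variable; model, criterion, and loader are reused from above):
running_loss = 0.0
for data, target in loader:
    loss = criterion(model(data), target)
    loss.backward()

    # Leak: storing the tensor keeps each iteration's graph alive
    # running_loss += loss

    # Fix: convert to a plain Python number before accumulating
    running_loss += loss.item()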
🕵️ Debugging with Hooks: Your Secret Spies
What Are Hooks?
Hooks are like placing spies at specific machines in your factory. They report back what they see!
Types of Hooks
| Hook Type | When It Activates | What It Sees |
|---|---|---|
| forward_hook | During forward pass | Input & output |
| backward_hook | During backward pass | Gradients |
| forward_pre_hook | Before forward pass | Input only |
Simple Forward Hook
def spy_function(module, input, output):
    print(f"Layer: {module.__class__.__name__}")
    print(f"Output shape: {output.shape}")
    print(f"Output mean: {output.mean():.4f}")

# Place your spy
hook = model.layer1.register_forward_hook(spy_function)

# Run the model - spy reports!
output = model(data)

# Remove spy when done
hook.remove()
Gradient Hook (Backward Spy)
def gradient_spy(module, grad_input, grad_output):
    if grad_output[0] is not None:
        print(f"Gradient magnitude: "
              f"{grad_output[0].abs().mean():.6f}")

# register_backward_hook is deprecated; use the "full" version
hook = model.fc.register_full_backward_hook(gradient_spy)
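Module hooks watch gradients flowing through a whole layer; to spy on a single tensor instead, Tensor.register_hook attaches the spy directly to that tensor. A minimal, self-contained sketch:
import torch

x = torch.randn(4, 10, requires_grad=True)

# The spy fires when x's gradient is computed during backward()
x.register_hook(lambda grad: print(f"grad norm for x: {grad.norm():.4f}"))

y = (x ** 2).sum()
y.backward()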
Finding Dead Layers
def check_if_alive(name):
    def hook(module, input, output):
        if torch.is_tensor(output) and output.abs().max() < 1e-6:
            print(f"⚠️ {name} might be dead!")
    return hook

# Check all layers, keeping the handles so the spies can be removed later
handles = [
    layer.register_forward_hook(check_if_alive(name))
    for name, layer in model.named_modules()
]
🍫 Debugging Tensor Issues: Checking Your Chocolate Bars
Common Tensor Problems
graph LR A["Tensor Problems"] --> B["Wrong Shape"] A --> C["NaN Values"] A --> D["Inf Values"] A --> E["Wrong Device"] B --> B1["Use .shape to check"] C --> C1["Use torch.isnan"] D --> D1["Use torch.isinf"] E --> E1["Use .device to check"]
Shape Detective
# Always check shapes!
print(f"Input: {x.shape}")
print(f"Weight: {model.fc.weight.shape}")
print(f"Output: {output.shape}")

# Quick shape assertion
assert x.shape == (batch, channels, h, w), \
    f"Expected (B,C,H,W), got {x.shape}"
NaN and Inf Detective
def check_tensor_health(tensor, name="tensor"):
    has_nan = torch.isnan(tensor).any().item()
    has_inf = torch.isinf(tensor).any().item()
    if has_nan:
        print(f"🚨 {name} has NaN values!")
    if has_inf:
        print(f"🚨 {name} has Inf values!")
    return not (has_nan or has_inf)

# Use it everywhere
check_tensor_health(loss, "loss")
check_tensor_health(output, "output")
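To find the first layer that produces a bad value, rather than only noticing it in the loss, you can wire this check into forward hooks on every layer. A sketch that reuses check_tensor_health and the model and data names from above:
def health_hook(name):
    def hook(module, input, output):
        if torch.is_tensor(output):
            check_tensor_health(output, name)
    return hook

handles = [
    layer.register_forward_hook(health_hook(name))
    for name, layer in model.named_modules()
]

output = model(data)   # hooks report the first unhealthy layer

for h in handles:      # remove the spies afterwards
    h.remove()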
Device Mismatch Detective
# Check where tensors live
print(f"Data on: {data.device}")
print(f"Model on: {next(model.parameters()).device}")

# Quick fix: move the data to wherever the model's weights live
data = data.to(next(model.parameters()).device)
🧹 Memory Management: Keeping Your Warehouse Clean
Why It Matters
A messy warehouse means:
- No room for new chocolate bars
- Workers trip over old boxes
- Factory shuts down!
Cleaning Techniques
# Delete references you no longer need
del old_tensor
del intermediate_results

# Return cached blocks to the GPU
# (this does NOT free tensors you still hold references to)
torch.cuda.empty_cache()

# Full cleanup: run Python's garbage collector first
import gc
gc.collect()
torch.cuda.empty_cache()
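If memory still refuses to come back, it helps to see which tensors are actually alive. A rough, debugging-only sketch that walks Python's garbage collector looking for CUDA tensors (slow, but revealing):
import gc
import torch

for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) and obj.is_cuda:
            print(type(obj).__name__, tuple(obj.shape), obj.dtype)
    except Exception:
        pass  # some objects complain about attribute access; skip them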
Gradient Accumulation (Save Space!)
Instead of one big batch, use many small ones:
accumulation_steps = 4
optimizer.zero_grad()

for i, (data, target) in enumerate(loader):
    output = model(data)
    loss = criterion(output, target)
    loss = loss / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
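The effective batch size here is the loader's batch size times accumulation_steps, but only one small batch's activations sit in the warehouse at a time. One detail the example above skips: if the number of batches is not a multiple of accumulation_steps, the last few gradients are never applied. A small, optional addition after the loop handles that:
# Flush any leftover accumulated gradients after the loop ends
if (i + 1) % accumulation_steps != 0:
    optimizer.step()
    optimizer.zero_grad()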
Checkpoint Magic (Trade Time for Space)
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MemoryEfficientModel(nn.Module):
    def forward(self, x):
        # Don't save activations; recompute them during backward
        # (use_reentrant=False is the recommended mode in recent PyTorch)
        x = checkpoint(self.layer1, x, use_reentrant=False)
        x = checkpoint(self.layer2, x, use_reentrant=False)
        return x
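Checkpointing is a trade: each checkpointed segment reruns its forward pass during backward, so you spend extra compute to store far fewer activations. For plain nn.Sequential stacks there is also checkpoint_sequential, which splits the stack into segments and checkpoints each one. A minimal sketch with made-up layer sizes (use_reentrant=False is accepted in recent PyTorch versions):
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 10),
)

x = torch.randn(32, 512, requires_grad=True)

# Split into 2 segments; only the segment boundaries keep activations
out = checkpoint_sequential(layers, 2, x, use_reentrant=False)
out.sum().backward()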
📉 Gradient Problems: When Instructions Get Lost
The Three Gradient Villains
graph LR A["Gradient Problems"] --> B["Vanishing 👻"] A --> C["Exploding 💥"] A --> D["None Gradients 🚫"] B --> B1["Gradients too tiny"] B --> B2["Network stops learning"] C --> C1["Gradients too huge"] C --> C2["NaN losses appear"] D --> D1["No path exists"] D --> D2["requires_grad=False"]
Detecting Vanishing Gradients
def check_gradients(model):
    # Call after loss.backward(), before optimizer.zero_grad()
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            if grad_norm < 1e-7:
                print(f"👻 {name}: gradient vanishing!")
            elif grad_norm > 1000:
                print(f"💥 {name}: gradient exploding!")
Gradient Clipping (Tame the Explosions)
# Clip gradients to a safe range
# (call after loss.backward(), before optimizer.step())
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0
)
print(f"Pre-clip gradient norm: {total_norm:.4f}")
optimizer.step()
Fix None Gradients
# Check if tensor needs gradients
print(f"requires_grad: {tensor.requires_grad}")
# Enable gradient tracking
tensor.requires_grad_(True)
# Or when creating
tensor = torch.tensor([1.0], requires_grad=True)
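Another frequent cause of None gradients is accidentally cutting a tensor out of the autograd graph, for example with detach() or a round trip through NumPy; the result looks like a normal tensor, but nothing upstream of the cut receives a gradient. A small sketch of the symptom:
import torch

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)

# detach() cuts the graph between x and the loss
loss = (x.detach() * w).sum()
loss.backward()

print(w.grad)  # a real gradient - w is still connected
print(x.grad)  # None - there is no path from the loss back to x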
Anomaly Detection Mode
# Turn on the detective mode!
torch.autograd.set_detect_anomaly(True)
# Now if gradients fail, you get details
loss.backward() # Will show exactly where it broke
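Anomaly mode adds real overhead, so rather than leaving it on for a whole run you can scope it to the suspicious part with the context-manager form:
with torch.autograd.detect_anomaly():
    output = model(data)
    loss = criterion(output, target)
    loss.backward()  # the error now points at the op whose backward made the NaN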
🎯 Quick Reference: Detective Toolkit
| Problem | Tool | Command |
|---|---|---|
| Slow training | Profiler | with profile() as p: |
| Out of memory | Memory check | torch.cuda.memory_allocated() |
| Strange outputs | Forward hook | register_forward_hook() |
| Bad gradients | Backward hook | register_full_backward_hook() |
| NaN values | Tensor check | torch.isnan(x).any() |
| Memory leak | Cleanup | torch.cuda.empty_cache() |
| Exploding grads | Clipping | clip_grad_norm_(params, 1.0) |
| Debug mode | Anomaly | set_detect_anomaly(True) |
🌟 You’re Now a PyTorch Detective!
Remember:
- Profile first - Know where the problem is
- Watch memory - Keep your warehouse clean
- Use hooks - Place spies where you need them
- Check tensors - Shape, device, NaN, Inf
- Manage gradients - Clip them, track them, love them
When your chocolate factory runs smoothly, you make the best AI chocolate in the world! 🍫🤖
💡 Final Tip: Start with torch.autograd.set_detect_anomaly(True) when debugging. It’s like turning on all the lights in the factory!
