Debugging and Profiling


🔍 PyTorch Debugging & Profiling: Becoming a Code Detective

Imagine you’re a detective solving mysteries in a big factory. Sometimes machines break down, sometimes they run too slow, and sometimes they forget what they learned. Your job? Find out why and fix it!


🎭 The Detective’s Analogy

Think of training a neural network like running a chocolate factory:

  • Memory = Your factory’s warehouse space
  • Gradients = Instructions workers pass along the assembly line
  • Profiler = A security camera that watches everything
  • Hooks = Secret spies you place at specific machines
  • Tensors = The chocolate bars moving through the factory

When something goes wrong, you need detective tools to find the problem!


🎥 torch.profiler: Your Security Camera System

What Is It?

The profiler is like having security cameras everywhere in your factory. It records:

  • How long each machine takes
  • How much power each machine uses
  • Which machine is the slowest

Simple Example

import torch
from torch.profiler import profile, ProfilerActivity

# Start recording everything
# (by default only CPU activity is recorded; pass
#  activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
#  to also capture GPU kernel times)
with profile() as prof:
    # Your model does its work
    output = model(input_data)
    loss = criterion(output, target)
    loss.backward()

# See what happened
print(prof.key_averages())

What You’ll See

Name              CPU Time    GPU Time
--------------    --------    --------
forward           2.5ms       1.2ms
backward          4.1ms       2.8ms
optimizer.step    0.8ms       0.3ms
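In practice you usually want the slowest operations at the top. Here is a minimal, self-contained sketch (the tiny model and input stand in for your own) that sorts the averaged results by total CPU time and trims the table:

```python
import torch
from torch import nn
from torch.profiler import profile

# A toy stand-in model; swap in your own model and data
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
data = torch.randn(8, 32)

with profile() as prof:
    model(data)

# Sort so the most expensive operations appear first,
# and keep the printout short with row_limit
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=5)
print(table)
```

`sort_by` also accepts keys like `"cuda_time_total"` when GPU activity was recorded.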

Pro Detective Tip 🕵️

# Record to a file you can view later
with profile(
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs')
) as prof:
    for step, data in enumerate(loader):
        model(data)
        prof.step()  # Mark each step

Now you can view beautiful charts in TensorBoard!


💾 Memory Profiling: Checking Your Warehouse

The Problem

Your factory warehouse (GPU memory) can overflow! When it does:

  • Training crashes
  • You see “CUDA out of memory”
  • Everything stops

How to Check Memory

import torch

# See current warehouse usage (in bytes)
print(torch.cuda.memory_allocated())
# e.g. 1073741824 bytes = 1 GB

# See maximum ever used
print(torch.cuda.max_memory_allocated())

# See full memory report
print(torch.cuda.memory_summary())

Visual Memory Map

graph TD
    A["Total GPU Memory: 8GB"] --> B["Model Weights: 2GB"]
    A --> C["Activations: 3GB"]
    A --> D["Gradients: 2GB"]
    A --> E["Free Space: 1GB"]
    style A fill:#4ECDC4
    style B fill:#FF6B6B
    style C fill:#FFE66D
    style D fill:#95E1D3
    style E fill:#DFE6E9
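To see how much of the warehouse your model's weights alone occupy, you can sum the parameter sizes. A minimal sketch (the helper name and toy model are illustrative; this needs no GPU):

```python
import torch
from torch import nn

def model_weight_bytes(model: nn.Module) -> int:
    """Total bytes occupied by a model's parameters."""
    return sum(p.numel() * p.element_size() for p in model.parameters())

# Toy model stands in for yours
model = nn.Linear(10, 10)  # weight: 100 floats, bias: 10 floats
print(model_weight_bytes(model))  # 110 params * 4 bytes = 440
```

Remember that weights are only one slice of the pie above; optimizers like Adam keep extra per-parameter state, and activations usually dominate during training.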

Finding Memory Leaks

# Before training loop
torch.cuda.reset_peak_memory_stats()

for epoch in range(10):
    train_one_epoch()

    # Check after each epoch
    current = torch.cuda.memory_allocated()
    peak = torch.cuda.max_memory_allocated()
    print(f"Epoch {epoch}: "
          f"current {current/1e9:.2f}GB, peak {peak/1e9:.2f}GB")

If memory keeps growing each epoch, you have a leak!
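A classic cause of this kind of leak is storing loss tensors that still carry their autograd graph. Calling .item() extracts a plain Python number and lets the graph be freed. A toy sketch (the model and data are stand-ins):

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
data, target = torch.randn(8, 4), torch.randn(8, 1)
losses = []

for _ in range(3):
    loss = nn.functional.mse_loss(model(data), target)
    # BAD: losses.append(loss) would keep each iteration's entire
    # computation graph (and its activations) alive forever.
    losses.append(loss.item())  # GOOD: stores a plain float

print(all(isinstance(x, float) for x in losses))  # True
```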


🕵️ Debugging with Hooks: Your Secret Spies

What Are Hooks?

Hooks are like placing spies at specific machines in your factory. They report back what they see!

Types of Hooks

Hook Type           When It Activates       What It Sees
----------------    --------------------    --------------
forward_hook        During forward pass     Input & Output
backward_hook       During backward pass    Gradients
forward_pre_hook    Before forward          Input only

Simple Forward Hook

def spy_function(module, input, output):
    print(f"Layer: {module.__class__.__name__}")
    print(f"Output shape: {output.shape}")
    print(f"Output mean: {output.mean():.4f}")

# Place your spy
hook = model.layer1.register_forward_hook(
    spy_function
)

# Run the model - spy reports!
output = model(data)

# Remove spy when done
hook.remove()
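A common variant of the spy above records what it sees instead of printing, so you can inspect activations after the run. A self-contained sketch with a toy model (the names are illustrative):

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
activations = {}

def make_spy(name):
    # Closure captures the layer name for the report
    def spy(module, input, output):
        activations[name] = output.detach()
    return spy

# Place a spy at every named layer, keeping the handles
handles = [layer.register_forward_hook(make_spy(name))
           for name, layer in model.named_modules() if name]

model(torch.randn(2, 8))
print({k: tuple(v.shape) for k, v in activations.items()})
# {'0': (2, 16), '1': (2, 16), '2': (2, 4)}

# Remove all spies when done
for h in handles:
    h.remove()
```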

Gradient Hook (Backward Spy)

def gradient_spy(module, grad_input, grad_output):
    if grad_output[0] is not None:
        print(f"Gradient magnitude: "
              f"{grad_output[0].abs().mean():.6f}")

# register_backward_hook is deprecated; use the full version
hook = model.fc.register_full_backward_hook(
    gradient_spy
)

Finding Dead Layers

def check_if_alive(name):
    def hook(module, input, output):
        if output.abs().max() < 1e-6:
            print(f"⚠️ {name} might be dead!")
    return hook

# Check all layers (keep the handles so you can remove them later)
handles = [
    layer.register_forward_hook(check_if_alive(name))
    for name, layer in model.named_modules()
]
# Later: for h in handles: h.remove()

🍫 Debugging Tensor Issues: Checking Your Chocolate Bars

Common Tensor Problems

graph LR
    A["Tensor Problems"] --> B["Wrong Shape"]
    A --> C["NaN Values"]
    A --> D["Inf Values"]
    A --> E["Wrong Device"]
    B --> B1["Use .shape to check"]
    C --> C1["Use torch.isnan"]
    D --> D1["Use torch.isinf"]
    E --> E1["Use .device to check"]

Shape Detective

# Always check shapes!
print(f"Input: {x.shape}")
print(f"Weight: {model.fc.weight.shape}")
print(f"Output: {output.shape}")

# Quick shape assertion
assert x.shape == (batch, channels, h, w), \
    f"Expected (B,C,H,W), got {x.shape}"

NaN and Inf Detective

def check_tensor_health(tensor, name="tensor"):
    has_nan = torch.isnan(tensor).any()
    has_inf = torch.isinf(tensor).any()

    if has_nan:
        print(f"🚨 {name} has NaN values!")
    if has_inf:
        print(f"🚨 {name} has Inf values!")

    return not (has_nan or has_inf)

# Use it everywhere
check_tensor_health(loss, "loss")
check_tensor_health(output, "output")

Device Mismatch Detective

# Check where tensors live
print(f"Data on: {data.device}")
print(f"Model on: {next(model.parameters()).device}")

# Quick fix (nn.Module has no .device attribute;
# take the device from one of its parameters)
data = data.to(next(model.parameters()).device)

🧹 Memory Management: Keeping Your Warehouse Clean

Why It Matters

A messy warehouse means:

  • No room for new chocolate bars
  • Workers trip over old boxes
  • Factory shuts down!

Cleaning Techniques

# Delete variables you don't need
del old_tensor
del intermediate_results

# Empty the trash
torch.cuda.empty_cache()

# Full cleanup
import gc
gc.collect()
torch.cuda.empty_cache()

Gradient Accumulation (Save Space!)

Instead of one big batch, use many small ones:

accumulation_steps = 4
optimizer.zero_grad()

for i, (data, target) in enumerate(loader):
    output = model(data)
    loss = criterion(output, target)
    loss = loss / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Checkpoint Magic (Trade Time for Space)

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MemoryEfficientModel(nn.Module):
    def forward(self, x):
        # Don't save activations; recompute them during backward
        x = checkpoint(self.layer1, x, use_reentrant=False)
        x = checkpoint(self.layer2, x, use_reentrant=False)
        return x

📉 Gradient Problems: When Instructions Get Lost

The Three Gradient Villains

graph LR
    A["Gradient Problems"] --> B["Vanishing 👻"]
    A --> C["Exploding 💥"]
    A --> D["None Gradients 🚫"]
    B --> B1["Gradients too tiny"]
    B --> B2["Network stops learning"]
    C --> C1["Gradients too huge"]
    C --> C2["NaN losses appear"]
    D --> D1["No path exists"]
    D --> D2["requires_grad=False"]

Detecting Vanishing Gradients

def check_gradients(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            if grad_norm < 1e-7:
                print(f"👻 {name}: gradient vanishing!")
            elif grad_norm > 1000:
                print(f"💥 {name}: gradient exploding!")

Gradient Clipping (Tame the Explosions)

# Clip gradients to a safe range
# (call after loss.backward(), before optimizer.step())
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0
)
# total_norm is the gradient norm *before* clipping
optimizer.step()

Fix None Gradients

# Check if tensor needs gradients
print(f"requires_grad: {tensor.requires_grad}")

# Enable gradient tracking
tensor.requires_grad_(True)

# Or when creating
tensor = torch.tensor([1.0], requires_grad=True)

Anomaly Detection Mode

# Turn on the detective mode!
torch.autograd.set_detect_anomaly(True)

# Now if gradients fail, you get details
loss.backward()  # Will show exactly where it broke
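Here is a toy reproduction (not from the original text) where anomaly mode names the exact backward function that produced a NaN. Taking the square root of a negative number yields NaN in the forward pass, and anomaly mode turns the resulting NaN gradient into a RuntimeError:

```python
import torch

torch.autograd.set_detect_anomaly(True)

x = torch.tensor([1.0], requires_grad=True)
y = torch.sqrt(x - 2.0)  # sqrt(-1) -> NaN in the forward pass

try:
    y.backward()
except RuntimeError as e:
    # The error names the function (SqrtBackward) that returned
    # NaN, plus a traceback to the forward call that created it
    print(f"Caught: {e}")
finally:
    torch.autograd.set_detect_anomaly(False)
```

Anomaly mode slows training noticeably, so turn it on only while hunting the bug.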

🎯 Quick Reference: Detective Toolkit

Problem            Tool             Command
---------------    -------------    -------------------------------
Slow training      Profiler         with profile() as p:
Out of memory      Memory check     torch.cuda.memory_allocated()
Strange outputs    Forward hook     register_forward_hook()
Bad gradients      Backward hook    register_full_backward_hook()
NaN values         Tensor check     torch.isnan(x).any()
Memory leak        Cleanup          torch.cuda.empty_cache()
Exploding grads    Clipping         clip_grad_norm_(params, 1.0)
Debug mode         Anomaly          set_detect_anomaly(True)

🌟 You’re Now a PyTorch Detective!

Remember:

  1. Profile first - Know where the problem is
  2. Watch memory - Keep your warehouse clean
  3. Use hooks - Place spies where you need them
  4. Check tensors - Shape, device, NaN, Inf
  5. Manage gradients - Clip them, track them, love them

When your chocolate factory runs smoothly, you make the best AI chocolate in the world! 🍫🤖


💡 Final Tip: Start with torch.autograd.set_detect_anomaly(True) when debugging. It’s like turning on all the lights in the factory!
