

Production Deep Learning: Scalable Training

The Factory Analogy 🏭

Imagine you’re running a toy factory. You want to make millions of toys, but your factory is small. How do you scale up? You use clever tricks: batch your work, use faster machines, save your progress, and get help from friends. That’s exactly what we do with deep learning!


1. Gradient Accumulation

The Problem: Your Backpack is Too Small

Think about carrying groceries home. Your backpack can only hold 4 items at a time. But you need to bring home 32 items!

Old way: Make 8 trips (slow and tiring!)

Smart way: Write down what you picked up on each trip, then unpack everything at once at home!

What is Gradient Accumulation?

When training neural networks, we process data in batches. But sometimes our GPU memory is too small for big batches.

Gradient Accumulation = Process small mini-batches, but add up (accumulate) the gradients before updating the model.

# Without accumulation (needs lots of memory)
batch_size = 32  # GPU might crash!

# With accumulation (memory-friendly)
mini_batch = 4
accumulation_steps = 8
# 4 × 8 = 32 effective batch size!

Simple Code Example

optimizer.zero_grad()

for i, (data, target) in enumerate(loader):
    output = model(data)
    loss = criterion(output, target)
    loss = loss / accumulation_steps
    loss.backward()  # Accumulate gradients

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()  # Update once!
        optimizer.zero_grad()

Why It Works

  • Same learning signal as one big batch
  • Much less memory needed
  • Even a small GPU can train with a large effective batch size (see the quick check below)
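
Here's a quick sketch you can run to convince yourself: the gradients from two accumulated mini-batches of 4 match the gradients from one batch of 8 (a toy linear model; the numbers are only illustrative).

import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
data = torch.randn(8, 4)
target = torch.randn(8)

# One big batch of 8
loss = ((data @ w - target) ** 2).mean()
loss.backward()
big_batch_grad = w.grad.clone()

# Two accumulated mini-batches of 4
w.grad = None
for x_chunk, y_chunk in zip(data.split(4), target.split(4)):
    loss = ((x_chunk @ w - y_chunk) ** 2).mean() / 2  # divide by accumulation steps
    loss.backward()  # gradients add up in w.grad

print(torch.allclose(big_batch_grad, w.grad, atol=1e-6))  # True

Dividing each mini-batch loss by the number of accumulation steps is what makes the accumulated sum equal the average over the full batch.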

2. Mixed Precision Training

The Art Store Analogy 🎨

You’re an artist with two types of paint:

  • Expensive paint (32-bit): Super precise colors, costs a lot
  • Budget paint (16-bit): Good enough for most things, half the price!

Smart artist: Use expensive paint only for tiny details. Use budget paint for everything else!

What is Mixed Precision?

Computers store numbers in different sizes:

  • FP32 (32 bits): Very precise, uses more memory
  • FP16 (16 bits): Less precise, uses half the memory!

Mixed Precision = Use FP16 for most calculations, FP32 only where needed.
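
You can check the memory claim directly in PyTorch (a small sketch; the tensor shape is arbitrary):

import torch

x32 = torch.randn(1024, 1024, dtype=torch.float32)
x16 = x32.half()  # convert a copy to FP16

print(x32.element_size())  # 4 bytes per number
print(x16.element_size())  # 2 bytes per number
print(x32.nelement() * x32.element_size())  # 4,194,304 bytes (~4 MB)
print(x16.nelement() * x16.element_size())  # 2,097,152 bytes (~2 MB)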

graph TD A["Input Data"] --> B["FP16: Forward Pass"] B --> C["FP16: Compute Loss"] C --> D["FP32: Scale Loss"] D --> E["FP16: Backward Pass"] E --> F["FP32: Update Weights"] F --> G["Trained Model"]

Simple Code Example

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for data, target in loader:
    optimizer.zero_grad()

    with autocast():  # Magic happens here!
        output = model(data)
        loss = criterion(output, target)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Benefits

Benefit  | Improvement
Memory   | ~2× less
Speed    | 2-3× faster
Accuracy | Virtually the same

3. Gradient Checkpointing

The Video Game Save Point 🎮

Playing a long video game? You don’t save every second. You save at checkpoints. If you fail, you restart from the last checkpoint—not the beginning!

What is Gradient Checkpointing?

During training, the computer normally stores every intermediate activation from the forward pass so it can compute gradients later. This uses lots of memory!

Checkpointing = Only save some activations. Recompute the rest during the backward pass, when they're needed.

graph TD A["Layer 1"] -->|Save| B["Checkpoint"] B --> C["Layer 2-3"] C -->|Save| D["Checkpoint"] D --> E["Layer 4-5"] E --> F["Output"] F -->|Backward| G["Recompute if needed"]

Trade-off

  • Less memory: Save fewer activations
  • More time: Recompute when needed
  • Net win: Train models that wouldn’t fit otherwise!

Simple Code Example

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class BigModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Example layer sizes, just for illustration
        self.layer1 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
        self.layer2 = nn.Sequential(nn.Linear(512, 512), nn.ReLU())
        self.final_layer = nn.Linear(512, 10)

    def forward(self, x):
        # Checkpointed layers: activations are recomputed during the backward pass
        x = checkpoint(self.layer1, x, use_reentrant=False)
        x = checkpoint(self.layer2, x, use_reentrant=False)
        x = self.final_layer(x)
        return x

4. Model Saving and Loading

The Recipe Book 📖

A chef writes down recipes so they can:

  • Remember them later
  • Share with other chefs
  • Continue cooking tomorrow

Your model’s learned knowledge = Your recipe!

What to Save?

# The essentials
checkpoint = {
    'model': model.state_dict(),      # Weights
    'optimizer': optimizer.state_dict(),  # Training state
    'epoch': current_epoch,           # Progress
    'loss': best_loss,                # Best score
}

# Save it!
torch.save(checkpoint, 'my_model.pt')

Loading Your Model

# Load checkpoint
checkpoint = torch.load('my_model.pt')

# Restore everything
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']

Best Practices

Practice         | Why
Save regularly   | Don’t lose hours of work!
Save best model  | Keep your champion
Include metadata | Know what you saved
Use .pt or .pth  | Standard PyTorch format
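
“Save best model” usually means tracking a validation metric and writing a separate file whenever it improves. A minimal sketch (train_one_epoch, validate, and the file names are hypothetical placeholders):

best_loss = float('inf')

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer)  # hypothetical helper
    val_loss = validate(model, val_loader)           # hypothetical helper

    # Always save the latest state so you can resume after a crash
    torch.save({'model': model.state_dict(), 'epoch': epoch}, 'last.pt')

    # Keep a separate copy of your champion
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save({'model': model.state_dict(), 'loss': best_loss}, 'best_model.pt')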

5. Distributed Training

The Pizza Party Analogy 🍕

One person making 100 pizzas = Forever!

10 people, each making 10 pizzas = 10× faster!

What is Distributed Training?

Use multiple GPUs (or computers) to train faster!

graph TD A["Training Data"] --> B["Split Data"] B --> C["GPU 1: Batch 1"] B --> D["GPU 2: Batch 2"] B --> E["GPU 3: Batch 3"] B --> F["GPU 4: Batch 4"] C --> G["Combine Gradients"] D --> G E --> G F --> G G --> H["Update Model"]

Types of Distributed Training

Data Parallel: Same model, different data

  • Each GPU processes different batches
  • Most common approach

Model Parallel: Different parts of model on different GPUs

  • For HUGE models that don’t fit on one GPU
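
There is no single wrap-and-go call for model parallelism like there is for data parallelism below, but a minimal sketch, assuming two GPUs and an arbitrary split point, could look like this:

import torch
import torch.nn as nn

class TwoGPUModel(nn.Module):
    def __init__(self):
        super().__init__()
        # First half of the network lives on GPU 0, second half on GPU 1
        self.part1 = nn.Sequential(nn.Linear(1024, 2048), nn.ReLU()).to('cuda:0')
        self.part2 = nn.Linear(2048, 10).to('cuda:1')

    def forward(self, x):
        x = self.part1(x.to('cuda:0'))
        # Move the intermediate activations to the second GPU
        x = self.part2(x.to('cuda:1'))
        return x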

Simple Code Example

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Initialize the process group (one process per GPU)
dist.init_process_group(backend='nccl')

# Each process gets its own GPU; torchrun sets LOCAL_RANK
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
model = model.to(local_rank)

# Wrap your model
model = DistributedDataParallel(model, device_ids=[local_rank])

# Train normally - PyTorch handles the gradient synchronization!
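
In practice, a script like this is launched with one process per GPU, e.g. torchrun --nproc_per_node=4 train.py (the script name is just a placeholder); every process runs the same code on its own GPU.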

6. Gradient Synchronization

The Dance Team 💃

Imagine a dance team where everyone needs to do the same move at the same time. They need to synchronize!

What is Gradient Synchronization?

When multiple GPUs train together, they each compute different gradients. Before updating, they must share and combine their gradients.

How It Works

graph TD A["GPU 1 Gradient"] --> E["All-Reduce"] B["GPU 2 Gradient"] --> E C["GPU 3 Gradient"] --> E D["GPU 4 Gradient"] --> E E --> F["Average Gradient"] F --> G["Same Update on All GPUs"]

The All-Reduce Operation

All-Reduce = Everyone sends their gradient, everyone gets the average!

# This happens automatically with DDP,
# but conceptually it looks like this:

import torch.distributed as dist

world_size = dist.get_world_size()

# Each process (GPU) holds its own gradients
for param in model.parameters():
    if param.grad is not None:
        # Sum the gradient across all GPUs...
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        # ...then divide to get the average
        param.grad /= world_size

Sync Strategies

Strategy   | When          | Use Case
Every step | Always        | Standard training
Periodic   | Large batches | Save communication
Async      | Many nodes    | Faster but noisier
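
With DDP, the “Periodic” strategy usually means skipping the all-reduce on gradient-accumulation steps. DistributedDataParallel provides a no_sync() context manager for exactly this; here is a sketch, assuming model is already wrapped in DDP and loader, criterion, optimizer, and accumulation_steps are defined as before:

import contextlib

for i, (data, target) in enumerate(loader):
    is_update_step = (i + 1) % accumulation_steps == 0

    # Skip the gradient all-reduce on accumulation steps; sync only when updating
    sync_context = contextlib.nullcontext() if is_update_step else model.no_sync()
    with sync_context:
        loss = criterion(model(data), target) / accumulation_steps
        loss.backward()

    if is_update_step:
        optimizer.step()
        optimizer.zero_grad()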

Putting It All Together 🚀

Here’s how a production training script combines everything:

# Setup distributed training (launched with one process per GPU)
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
model = DistributedDataParallel(model)  # model already on this process's GPU

# Mixed precision
scaler = GradScaler()

# Training loop
for epoch in range(num_epochs):
    for i, (data, target) in enumerate(loader):

        # Mixed precision forward
        with autocast():
            output = model(data)
            loss = criterion(output, target)
            loss = loss / accumulation_steps

        # Accumulate gradients
        scaler.scale(loss).backward()

        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    # Save checkpoint
    if rank == 0:  # Only the main process saves
        torch.save({
            'model': model.module.state_dict(),  # Unwrap DDP before saving
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }, f'checkpoint_{epoch}.pt')

Quick Reference

Technique             | Problem Solved   | Key Benefit
Gradient Accumulation | Small GPU memory | Train with big batches
Mixed Precision       | Slow training    | 2-3× speedup
Checkpointing         | Model too big    | Fit huge models
Save/Load             | Lose progress    | Resume anytime
Distributed           | One GPU too slow | Scale to many GPUs
Gradient Sync         | GPUs out of sync | Consistent updates

You Did It! 🎉

You now understand the 6 pillars of scalable deep learning:

  1. Accumulate gradients to simulate big batches
  2. Mix precision for speed and memory
  3. Checkpoint to fit bigger models
  4. Save/Load to never lose progress
  5. Distribute across multiple GPUs
  6. Synchronize to keep everyone aligned

These techniques power models like GPT, BERT, and Stable Diffusion. Now you know their secrets!

“The best way to scale is to work smarter, not just harder.” 🧠
