Training Pipeline

🚂 Training Your Neural Network: The Chef’s Kitchen

Imagine you’re teaching a young chef to cook the perfect meal. They need to learn ingredients, practice recipes, taste their creations, and know when to stop cooking before burning everything!


🎯 What We’ll Learn

Training a neural network is like running a cooking school. We’ll explore:

  1. Weight Initialization – Setting up the kitchen
  2. Training Loop – The cooking practice routine
  3. Validation & Testing – Taste testing!
  4. Overfitting & Underfitting – Too much vs too little seasoning
  5. Early Stopping – Knowing when the dish is perfect
  6. Gradient Clipping – Preventing kitchen fires
  7. Weight & Spectral Norm – Keeping portions balanced
  8. Micro-batch Accumulation – Cooking in small batches

🍳 1. Weight Initialization

The Kitchen Setup Story

Before cooking begins, you need to organize your kitchen. Imagine if every pot started empty vs every pot started overflowing with random ingredients. Neither works well!

Weight initialization is about giving your network a fair starting point – random enough that neurons learn different things, but scaled so signals neither explode nor vanish.

Why It Matters

Bad Start                Good Start
All weights = 0          Weights vary slightly
Network learns nothing   Network learns smoothly
Like frozen chef         Like prepared chef

Common Methods

Xavier/Glorot Initialization:

“Give each neuron just enough starting power based on its connections”

import torch.nn as nn

layer = nn.Linear(100, 50)
nn.init.xavier_uniform_(layer.weight)

Kaiming/He Initialization:

“Designed for ReLU – it compensates for the activations that ReLU zeroes out, so the signal strength is preserved”

nn.init.kaiming_normal_(
    layer.weight,
    mode='fan_in',
    nonlinearity='relu'
)

Simple Rule

  • Using ReLU? → Use Kaiming
  • Using Sigmoid/Tanh? → Use Xavier
graph TD A["Choose Activation"] --> B{ReLU?} B -->|Yes| C["Kaiming Init"] B -->|No| D["Xavier Init"] C --> E["Ready to Train!"] D --> E

🔄 2. Training Loop Implementation

The Practice Routine Story

Learning to cook requires practice. Every day, the chef:

  1. Looks at a recipe (forward pass)
  2. Tastes the result (compute loss)
  3. Figures out what went wrong (backward pass)
  4. Adjusts technique (update weights)
  5. Repeat!

The Basic Loop

model.train()  # Chef enters kitchen

for epoch in range(num_epochs):
    for batch in train_loader:
        # Step 1: Get ingredients
        inputs, targets = batch

        # Step 2: Cook the dish
        outputs = model(inputs)

        # Step 3: Taste it
        loss = criterion(outputs, targets)

        # Step 4: Clear old notes
        optimizer.zero_grad()

        # Step 5: Figure out fixes
        loss.backward()

        # Step 6: Apply fixes
        optimizer.step()

Each Step Explained

Step        Code                     What Happens
Forward     model(inputs)            Make prediction
Loss        criterion(...)           Measure mistake
Zero Grad   optimizer.zero_grad()    Clear old gradients
Backward    loss.backward()          Calculate fixes
Update      optimizer.step()         Apply fixes
graph TD A["Get Batch"] --> B["Forward Pass"] B --> C["Compute Loss"] C --> D["Zero Gradients"] D --> E["Backward Pass"] E --> F["Update Weights"] F --> G{More Batches?} G -->|Yes| A G -->|No| H["Epoch Done!"]

🧪 3. Validation and Testing

The Taste Test Story

A chef can’t only eat their own food! They need:

  • Training – Daily practice cooking
  • Validation – Friend tastes during practice
  • Testing – Food critic’s final review

The Key Difference

Set          Purpose          When Used
Training     Learn patterns   Every epoch
Validation   Check progress   Every epoch
Testing      Final score      Once at end

Validation Code

model.eval()  # No learning mode

with torch.no_grad():  # No gradients!
    for batch in val_loader:
        inputs, targets = batch
        outputs = model(inputs)
        val_loss = criterion(
            outputs, targets
        )

Important: torch.no_grad()

This tells PyTorch: “We’re just looking, not learning!” It:

  • Saves memory
  • Speeds up computation
  • Prevents accidental learning
graph TD A["Training Set"] --> B["Model Learns"] B --> C["Validation Set"] C --> D{Good Enough?} D -->|No| A D -->|Yes| E["Test Set"] E --> F["Final Score!"]

⚖️ 4. Overfitting and Underfitting

The Seasoning Story

Imagine two chefs:

  • Chef Overfit memorizes exact salt amounts. New dish? Disaster!
  • Chef Underfit uses the same salt for everything. Boring!

Visual Guide

Underfitting          Just Right           Overfitting
     😴                   😊                   🤯

   Simple               Balanced            Complex
   model                 model               model

Can't learn          Learns well         Memorizes
 patterns             patterns            examples

How to Spot Them

Problem        Training Loss   Validation Loss
Underfitting   High            High
Just Right     Low             Low
Overfitting    Very Low        High

Fixes

For Underfitting:

  • Make model bigger
  • Train longer
  • Add more features

For Overfitting:

  • Add dropout
  • Use regularization
  • Get more data
  • Simplify model

# Dropout example
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(100, 50)
        self.dropout = nn.Dropout(0.5)   # randomly drops 50% of activations (train mode only)
        self.layer2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = self.dropout(x)  # random off
        return self.layer2(x)
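
For the “use regularization” fix, most built-in PyTorch optimizers accept a `weight_decay` argument that adds an L2 penalty; the strength below is only an illustrative starting point:

import torch
import torch.nn as nn

model = nn.Linear(100, 10)   # stand-in network

# L2 regularization via weight decay: discourages large weights at every update
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)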

🛑 5. Early Stopping

The Perfect Toast Story

When making toast, you watch carefully. The moment it’s golden – you stop! Wait too long? Burnt. Stop too early? Soggy bread.

Early stopping watches your validation loss. When it stops improving, training stops!

How It Works

Epoch 1: Val Loss = 0.9  ✓ Improving
Epoch 2: Val Loss = 0.7  ✓ Improving
Epoch 3: Val Loss = 0.5  ✓ Improving (Best!)
Epoch 4: Val Loss = 0.6  ⚠️ Worse
Epoch 5: Val Loss = 0.7  ⚠️ Worse (Patience = 2)
→ STOP! Go back to Epoch 3

Simple Implementation

best_loss = float('inf')
patience = 5
counter = 0

for epoch in range(1000):
    val_loss = validate(model)

    if val_loss < best_loss:
        best_loss = val_loss
        counter = 0
        # Save best model
        torch.save(
            model.state_dict(),
            'best_model.pt'
        )
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping!")
            break

# Load best model
model.load_state_dict(
    torch.load('best_model.pt')
)
graph TD A["Train Epoch"] --> B["Check Val Loss"] B --> C{Better?} C -->|Yes| D["Save Model"] D --> E["Reset Counter"] E --> A C -->|No| F["Increase Counter"] F --> G{Counter >= Patience?} G -->|No| A G -->|Yes| H["Stop &amp; Load Best"]

✂️ 6. Gradient Clipping

The Fire Prevention Story

Imagine a chef who gets so excited about seasoning, they dump the entire salt container! Gradient explosion is like that – updates become so huge, the model goes crazy.

Gradient clipping says: “No matter how excited you are, only add THIS much salt.”

Why Gradients Explode

In deep networks, gradients multiply through layers. Like compound interest, small numbers become HUGE:

  • After layer 1: gradient × 2
  • After layer 2: gradient × 4
  • ...
  • After layer 10: gradient × 1024!

The Fix

# Before optimizer.step()
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0  # Clip threshold
)
optimizer.step()

Two Clipping Methods

Method             What It Does
clip_grad_norm_    Scales all gradients together
clip_grad_value_   Clips each gradient separately

# Method 1: Norm clipping (recommended)
torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0
)

# Method 2: Value clipping
torch.nn.utils.clip_grad_value_(
    model.parameters(),
    clip_value=0.5
)

When to Use

  • Training RNNs/LSTMs → Always clip!
  • Very deep networks → Probably clip
  • Stable training → Usually okay without
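
A handy detail: clip_grad_norm_ returns the total gradient norm measured before clipping, so you can log it and pick a sensible max_norm. A minimal sketch with a stand-in model:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)                  # stand-in model
loss = model(torch.randn(8, 10)).sum()    # dummy loss just to create gradients
loss.backward()

# Clips in place and returns the pre-clip total norm
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"gradient norm before clipping: {float(total_norm):.3f}")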

⚖️ 7. Weight Norm and Spectral Norm

The Balanced Portions Story

In a good restaurant, portions are controlled. Too much pasta on one plate, not enough on another? Bad experience!

Weight Norm and Spectral Norm keep your network’s weights balanced.

Weight Normalization

Separates weight magnitude from direction:

from torch.nn.utils import weight_norm

# Wrap a layer with weight norm
layer = weight_norm(
    nn.Linear(100, 50),
    name='weight'
)

Why? Faster, more stable training. The optimizer can separately adjust “how much” vs “which direction.”
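
You can see that separation directly: after wrapping, the layer learns a magnitude parameter and a direction parameter instead of one weight matrix (a quick sketch; exact parameter names and ordering may vary with your PyTorch version):

import torch.nn as nn
from torch.nn.utils import weight_norm

layer = weight_norm(nn.Linear(100, 50))
print([name for name, _ in layer.named_parameters()])
# e.g. ['bias', 'weight_g', 'weight_v'] – magnitude (g) and direction (v) are separate parameters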

Spectral Normalization

Controls the “stretchiness” of each layer:

from torch.nn.utils import spectral_norm

# Wrap a layer with spectral norm
layer = spectral_norm(
    nn.Linear(100, 50)
)

Why? Essential for training GANs! Prevents the discriminator from becoming too powerful too fast.

Quick Comparison

Technique       Best For           Effect
Weight Norm     General training   Faster convergence
Spectral Norm   GANs               Stable adversarial training
graph TD A["Raw Weights"] --> B{Need Control?} B -->|Faster Training| C["Weight Norm"] B -->|GAN Stability| D["Spectral Norm"] C --> E["Separated Magnitude"] D --> F["Bounded Lipschitz"]

📦 8. Micro-batch Accumulation

The Small Kitchen Story

Your kitchen is tiny. You can only cook 4 meals at a time, but customers expect batches of 32!

Solution: Cook 4 meals, set aside. Cook 4 more, set aside. After 8 rounds, serve all 32!

This is gradient accumulation – simulate large batches with small memory.

The Problem

Want: batch_size = 64
Have: GPU memory for batch_size = 8

The Solution

accumulation_steps = 8  # 8 × 8 = 64
optimizer.zero_grad()

for i, batch in enumerate(train_loader):
    inputs, targets = batch
    outputs = model(inputs)

    # Scale loss by accumulation
    loss = criterion(outputs, targets)
    loss = loss / accumulation_steps
    loss.backward()  # Accumulates!

    # Update only every N steps
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Key Points

Aspect               Small Batch   With Accumulation
Memory per step      Low ✓         Low ✓
Effective batch      8             64 ✓
Training stability   Noisy         Smooth ✓

Why Divide Loss?

When you call loss.backward() multiple times, gradients add up. Dividing by accumulation_steps keeps the total gradient equivalent to a single large batch.
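
Here is a small sanity-check sketch (synthetic data and MSELoss are arbitrary choices) showing that two accumulated micro-batches, each with the loss divided by 2, produce the same gradient as one full batch:

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
x, y = torch.randn(8, 10), torch.randn(8, 1)

# One full batch of 8
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Two accumulated micro-batches of 4, loss divided by accumulation_steps = 2
model.zero_grad()
for xb, yb in [(x[:4], y[:4]), (x[4:], y[4:])]:
    (criterion(model(xb), yb) / 2).backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True (up to float error)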

graph TD A["Mini-batch 1"] --> B["Compute Gradients"] B --> C["Accumulate"] D["Mini-batch 2"] --> E["Compute Gradients"] E --> C F["..."] --> C G["Mini-batch N"] --> H["Compute Gradients"] H --> C C --> I{N steps done?} I -->|Yes| J["Update Weights"] J --> K["Zero Gradients"] K --> A I -->|No| A

🎓 Putting It All Together

Here’s a complete training script combining most of these concepts:

import torch
import torch.nn as nn

# Model with dropout for regularization
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, 10)
)

# Initialize weights (Kaiming for ReLU)
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(
            m.weight,
            nonlinearity='relu'
        )

# Setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters()
)

# Early stopping setup
best_loss = float('inf')
patience = 5
counter = 0

# Gradient accumulation setup
accum_steps = 4

# Training loop (assumes train_loader and val_loader are defined DataLoaders)
for epoch in range(100):
    model.train()
    optimizer.zero_grad()

    for i, (x, y) in enumerate(train_loader):
        # Forward
        out = model(x)
        loss = criterion(out, y)
        loss = loss / accum_steps

        # Backward
        loss.backward()

        # Accumulate then update
        if (i + 1) % accum_steps == 0:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=1.0
            )
            optimizer.step()
            optimizer.zero_grad()

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            out = model(x)
            val_loss += criterion(out, y).item()
    val_loss /= len(val_loader)

    # Early stopping check
    if val_loss < best_loss:
        best_loss = val_loss
        counter = 0
        torch.save(
            model.state_dict(),
            'best.pt'
        )
    else:
        counter += 1
        if counter >= patience:
            print("Stopping early!")
            break

# Load best and test
model.load_state_dict(
    torch.load('best.pt')
)

🏆 Quick Reference

Concept             One-Line Summary
Weight Init         Scale random starting weights sensibly – never all zero
Training Loop       Forward → Loss → Backward → Update
Validation          Check progress without learning
Overfitting         Memorized the training data, fails on new data
Underfitting        Too simple to learn the patterns
Early Stopping      Stop when validation stops improving
Gradient Clipping   Prevent exploding updates
Weight Norm         Separate magnitude from direction
Spectral Norm       Control layer “stretchiness”
Micro-batch         Fake big batches with small memory

🚀 You Did It!

You now understand the complete PyTorch training pipeline! Like a master chef who knows:

  • How to set up their kitchen (initialization)
  • The perfect cooking routine (training loop)
  • When to taste test (validation)
  • Signs of over/under seasoning (fitting problems)
  • When the dish is done (early stopping)
  • How to prevent fires (gradient clipping)
  • Keeping portions balanced (normalization)
  • Cooking in a small kitchen (accumulation)

Go forth and train amazing models! 🎉
