Model Saving and Export


🎒 Training Workflow: Model Saving and Export

The Backpack of Your AI Journey

Imagine you’re a world explorer who has learned amazing skills on your adventure. What if you could pack ALL your skills into a magical backpack and share them with friends anywhere in the world? That’s exactly what model saving and export does for your PyTorch models!

Your model has learned patterns from thousands of examples. Now you need to save it so you don’t lose that learning, and export it so it can work anywhere—on phones, websites, or other computers.


🏠 Saving and Loading Models

The Magic Diary Analogy

Think of your trained model like a diary full of secrets. You want to:

  1. Save the diary so you can read it later
  2. Load the diary when you need those secrets again

Two Ways to Save

Method 1: Save Just the Secrets (Recommended)

# Save only the learned weights
torch.save(model.state_dict(), 'my_model_weights.pth')

# Load them back
model = MyModel()  # create the architecture first (weights start random)
model.load_state_dict(torch.load('my_model_weights.pth'))
model.eval()  # switch to inference mode before making predictions

Why this is better:

  • Smaller file size
  • Works even if you change your code slightly
  • Faster to save and load
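For example, if you later tweak the architecture, load_state_dict(strict=False) lets you load whatever weights still match. A small sketch (MyModelV2 is a hypothetical updated version of your model):

# Hypothetical: MyModelV2 keeps the old layers but adds a new one
model = MyModelV2()
result = model.load_state_dict(
    torch.load('my_model_weights.pth'),
    strict=False  # skip mismatched keys instead of raising an error
)
print(result.missing_keys, result.unexpected_keys)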

Method 2: Save the Whole Thing

# Save entire model (architecture + weights, pickled together)
torch.save(model, 'full_model.pth')

# Load it back — the class definition must still be importable.
# On PyTorch 2.6+, unpickling a full model also needs weights_only=False.
model = torch.load('full_model.pth', weights_only=False)

When to use this:

  • Quick experiments
  • Sharing with someone using the exact same code

Real Life Example

# After training your cat detector
torch.save(
    cat_detector.state_dict(),
    'cat_detector_v1.pth'
)
print("Model saved! Your learning is safe.")

💾 Checkpointing

The Video Game Save Point

Remember playing video games? You save your progress so if something goes wrong, you don’t start from the beginning. Checkpointing is exactly that for training!

What to Save in a Checkpoint

checkpoint = {
    'epoch': 10,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'loss': 0.05,
    'best_accuracy': 0.95
}
torch.save(checkpoint, 'checkpoint.pth')

Why Save All This?

| Item | Why It Matters |
| --- | --- |
| epoch | Know where you stopped |
| model_state | The brain's learning |
| optimizer_state | Training momentum |
| loss | Track progress |
| best_accuracy | Your best score |

Loading a Checkpoint

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(
    checkpoint['model_state']
)
optimizer.load_state_dict(
    checkpoint['optimizer_state']
)
start_epoch = checkpoint['epoch']
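
Putting it all together — a minimal resume-training sketch. Here train_one_epoch and num_epochs are hypothetical placeholders, and we assume 'epoch' stores the last completed epoch:

# Continue training from where you left off
for epoch in range(start_epoch + 1, num_epochs):
    loss = train_one_epoch(model, optimizer)  # hypothetical helper
    # Overwrite the checkpoint after each epoch
    torch.save({
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'loss': loss,
    }, 'checkpoint.pth')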

Pro Tip: Auto-Save the Best

if current_accuracy > best_accuracy:
    best_accuracy = current_accuracy
    torch.save(
        model.state_dict(),
        'best_model.pth'
    )
    print("New best model saved!")

🚀 torch.export

The Universal Translator

torch.export is like translating your diary into a universal language that computers everywhere can read. It creates a clean, optimized version of your model.

How It Works

import torch

# Your trained model
model = MyModel()
model.eval()

# Example input shape
example_input = torch.randn(1, 3, 224, 224)

# Export it!
exported = torch.export.export(
    model,
    (example_input,)
)
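
What you get back is an ExportedProgram. You can run it through its .module() view, and on recent PyTorch releases save and reload it with torch.export.save / torch.export.load — a quick sketch:

# Run the exported program like an ordinary module
output = exported.module()(example_input)

# Save and reload (available on recent PyTorch releases)
torch.export.save(exported, 'model.pt2')
reloaded = torch.export.load('model.pt2')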

Why Use torch.export?

  • Faster: Optimized for production
  • Portable: Works across different systems
  • Safe: Catches errors before deployment
Trained Model → torch.export → ExportedProgram → Deploy Anywhere

The Export is Smart

It captures your model’s logic and checks:

  • All operations are valid
  • Input shapes are correct
  • No Python tricks that can’t be translated
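
Because of these checks, anything that should vary at runtime (like batch size) must be declared up front. A sketch using torch.export's Dim, matching the positional input from the example above:

from torch.export import Dim

batch = Dim('batch')  # a symbolic dimension that may vary at runtime

exported = torch.export.export(
    model,
    (example_input,),
    dynamic_shapes=({0: batch},),  # dim 0 of the first input is dynamic
)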

🔍 TorchScript Tracing

The Detective Method

Tracing is like having a detective follow your model and write down everything it does. You give it an example input, and it records every step.

How to Trace

model.eval()

# Example input
example = torch.rand(1, 3, 224, 224)

# The detective watches and records
traced_model = torch.jit.trace(
    model,
    example
)

# Save the recording
traced_model.save('traced_model.pt')
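
A nice bonus: the saved file is self-contained, so you can load it back without the original Python class definition:

# Load anywhere — no MyModel class needed
loaded = torch.jit.load('traced_model.pt')
output = loaded(example)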

When Tracing Works Best

✅ Simple models with fixed paths
✅ No if/else based on input values
✅ Same operations every time

When Tracing Gets Confused

# This confuses tracing!
def forward(self, x):
    if x.sum() > 0:  # Changes per input
        return x * 2
    else:
        return x * 3

The detective only sees ONE path (the one your example takes), so it misses the other!
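
You can see the problem directly — a small sketch that traces the branchy model above with a positive example, then feeds it a negative input:

import torch
import torch.nn as nn

class Branchy(nn.Module):
    def forward(self, x):
        if x.sum() > 0:   # data-dependent branch
            return x * 2
        return x * 3

model = Branchy().eval()
# Tracing records only the x * 2 path (PyTorch emits a TracerWarning here)
traced = torch.jit.trace(model, torch.ones(3))

neg = -torch.ones(3)
print(model(neg))   # tensor([-3., -3., -3.]) — eager mode takes the else branch
print(traced(neg))  # tensor([-2., -2., -2.]) — the trace replays the recorded path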

Verify Your Traced Model

# Original output
original = model(test_input)

# Traced output
traced = traced_model(test_input)

# Should be the same!
print(torch.allclose(original, traced))

✍️ TorchScript Scripting

The Perfect Copy Method

Scripting is like making a perfect copy of your recipe book—it understands ALL the instructions, including “if this, do that.”

How to Script

model.eval()

# Script the entire model
scripted_model = torch.jit.script(model)

# Save it
scripted_model.save('scripted_model.pt')

Scripting Handles Logic

class SmartModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        else:
            return x * 3

# Scripting captures BOTH paths!
scripted = torch.jit.script(SmartModel())
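
And you can check that both branches really survived:

print(scripted(torch.ones(1)))   # tensor([2.]) — the if branch
print(scripted(-torch.ones(1)))  # tensor([-3.]) — the else branch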

Tracing vs Scripting

| Feature | Tracing | Scripting |
| --- | --- | --- |
| Speed | Faster | Slower |
| If/Else | ❌ Misses | ✅ Captures |
| Loops | ❌ Fixed | ✅ Dynamic |
| Best For | Simple CNNs | Complex logic |

Mix and Match

# Script just one function
@torch.jit.script
def special_function(x):
    if x.mean() > 0.5:
        return x ** 2
    return x

# Use it in traced model
class HybridModel(nn.Module):
    def forward(self, x):
        return special_function(x)
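
Tracing the hybrid then keeps the scripted function's if/else intact inside the trace — a quick sketch:

hybrid = HybridModel().eval()
# The @torch.jit.script function is preserved with its control flow
traced_hybrid = torch.jit.trace(hybrid, torch.rand(4))
print(traced_hybrid(torch.rand(4)))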

🌐 ONNX Export

The Universal Passport

ONNX (Open Neural Network Exchange) is like giving your model a passport that works in ANY country. TensorFlow, CoreML, TensorRT—they all speak ONNX!

PyTorch Model → ONNX Format → TensorFlow / CoreML (iOS) / TensorRT (NVIDIA) / ONNX Runtime

Export to ONNX

import torch.onnx

model.eval()
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=['image'],
    output_names=['prediction'],
    dynamic_axes={
        'image': {0: 'batch_size'},
        'prediction': {0: 'batch_size'}
    },
    opset_version=11  # any modern opset (11+) is widely supported
)

Key Parameters Explained

| Parameter | What It Does |
| --- | --- |
| input_names | Name your inputs |
| output_names | Name your outputs |
| dynamic_axes | Allow flexible batch sizes |
| opset_version | ONNX operator-set version (use 11+) |

Verify Your ONNX Model

import onnx

# Load and check
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")

Run with ONNX Runtime

import onnxruntime as ort
import numpy as np

session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name

# ONNX Runtime expects NumPy arrays, not tensors
image_array = np.random.rand(1, 3, 224, 224).astype(np.float32)

# Run inference (None = return all outputs)
result = session.run(
    None,
    {input_name: image_array}
)
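
As a sanity check, the ONNX Runtime result should match the original PyTorch output (this sketch assumes the model returns a single tensor):

# Compare against the original PyTorch model
torch_out = model(torch.from_numpy(image_array)).detach().numpy()
print(np.allclose(result[0], torch_out, atol=1e-5))  # True if they agree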

🎯 Quick Decision Guide

  • Same PyTorch version on both ends? → torch.save the state_dict
  • Different version, simple model? → TorchScript tracing
  • Different version, complex logic? → TorchScript scripting
  • Deploying to another framework? → Export to ONNX
  • Optimizing for production? → torch.export

🎒 Pack Your Model Checklist

✅ Training complete? Save a checkpoint!
✅ Best model found? Save the state_dict!
✅ Deploying to production? Use torch.export!
✅ Simple model, no if/else? Try tracing!
✅ Complex logic? Use scripting!
✅ Cross-platform needed? Export to ONNX!


🌟 Remember

Your model’s learning is precious. Saving protects it from crashes. Exporting lets it travel anywhere. Choose the right method for your journey, and your AI adventures will never be lost!

“A well-saved model is a journey that never ends.”
