🍽️ PyTorch Data Loading: The Restaurant Kitchen Story
Imagine you’re running a busy restaurant kitchen. You have a giant warehouse full of ingredients (your data), and you need to get them to the chefs (your neural network) in the right way. Too slow? Customers wait. Wrong portions? Disaster!
PyTorch’s data loading system is like having a super-organized kitchen manager who knows exactly how to fetch, prepare, and serve ingredients perfectly.
🏪 The Dataset Class: Your Recipe Book
Think of the Dataset class as your master recipe book. It tells PyTorch two things:
- How many recipes do you have? (`__len__`)
- What's in recipe #5? (`__getitem__`)
```python
from torch.utils.data import Dataset

class MyRecipeBook(Dataset):
    def __init__(self, recipes):
        self.recipes = recipes

    def __len__(self):
        return len(self.recipes)

    def __getitem__(self, idx):
        return self.recipes[idx]
```
That’s it! Just tell PyTorch how many items you have and how to get one item. Simple as counting books on a shelf and picking one out!
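Here's a quick taste of how it behaves once defined (the recipe names below are just made-up examples):

```python
book = MyRecipeBook(["pasta", "soup", "salad"])

print(len(book))  # 3 -> PyTorch calls __len__
print(book[1])    # 'soup' -> PyTorch calls __getitem__(1)
```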
graph TD A["Dataset Class"] --> B["__len__"] A --> C["__getitem__"] B --> D["How many items?"] C --> E["Get item at index"]
📦 Built-in Datasets: Ready-Made Ingredients
PyTorch comes with pre-packaged datasets like a supermarket with ready-to-cook meals!
Popular ones include:
| Dataset | What it is |
|---|---|
| MNIST | Handwritten digits (0-9) |
| CIFAR-10 | Tiny colorful pictures |
| ImageNet | Millions of real photos |
| FashionMNIST | Clothes images |
```python
from torchvision import datasets

# Download MNIST - like ordering groceries online!
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True
)

print(f"We have {len(mnist)} images!")
```
Why use them? You skip the boring prep work and jump straight to the fun part - training your model!
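Once downloaded, each item is an (image, label) pair you can index directly (assuming the download above succeeded; with no transform, the image comes back as a PIL image):

```python
image, label = mnist[0]   # a PIL image and its digit label
print(label)              # e.g. 5
print(image.size)         # (28, 28)
```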
🔧 Custom Datasets: Your Special Recipe
Sometimes you need your own special recipe. Maybe you have photos of your cat, or sales data from your shop.
Here’s how to make your own Dataset:
```python
import os
from PIL import Image

class CatPhotoDataset(Dataset):
    def __init__(self, folder_path):
        self.images = os.listdir(folder_path)
        self.folder = folder_path

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(
            self.folder,
            self.images[idx]
        )
        image = Image.open(img_path)
        return image
```
The secret? Just implement those two magic methods:
- `__len__` = "How many items do I have?"
- `__getitem__` = "Give me item number X"
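Using it then looks exactly like the built-in datasets (the folder path below is a hypothetical example):

```python
cats = CatPhotoDataset('./cat_photos')

print(len(cats))      # number of photos in the folder
first_cat = cats[0]   # loads and returns the first image
```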
🌊 IterableDataset: The Streaming Buffet
What if your data is SO BIG it won’t fit in memory? Like an endless buffet that keeps bringing new dishes!
IterableDataset is perfect for:
- Live data streams (Twitter feeds, sensors)
- Files too large to load at once
- Data that generates on-the-fly
```python
from torch.utils.data import IterableDataset

class EndlessBuffet(IterableDataset):
    def __iter__(self):
        while True:
            # Keep generating data
            dish = make_new_dish()
            yield dish
```
Key difference:
- Regular Dataset: “Give me item #42” ✓
- IterableDataset: “Just keep giving me the next one” ✓
graph TD A["Data Source"] --> B["IterableDataset"] B --> C["Item 1"] B --> D["Item 2"] B --> E["Item 3..."] B --> F["Endless stream"]
🚚 DataLoader: The Delivery Truck
Now for the star of the show!
The DataLoader is your delivery truck. It takes items from your Dataset and delivers them to your model in organized batches.
```python
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset=my_dataset,
    batch_size=32,   # 32 items per trip
    shuffle=True,    # Mix it up!
    num_workers=4    # 4 helper trucks
)

for batch in loader:
    # Each batch has 32 items!
    train_on_batch(batch)
```
Why batches?
- Training on one item = too slow
- Training on everything = crashes memory
- Batches = just right! 🎯
Key settings:
| Setting | What it does |
|---|---|
| `batch_size` | Items per batch |
| `shuffle` | Randomize order |
| `num_workers` | Parallel loading |
| `drop_last` | Drop incomplete batch |
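To see `drop_last` in action, here's a tiny sketch with 100 items and a batch size of 32, so the leftover batch of 4 gets dropped:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(100))              # 100 items
loader = DataLoader(data, batch_size=32, drop_last=True)

print(len(loader))  # 3 -> the incomplete batch of 4 items is skipped
```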
🍳 Collate Functions: The Prep Chef
Sometimes your data pieces don’t fit together neatly. Collate functions are like a prep chef who arranges everything perfectly on the plate.
Default behavior: Stack tensors into batches.
Custom collate: Handle weird data shapes!
```python
import torch

def my_collate(batch):
    # batch = list of items
    images = [item['image'] for item in batch]
    labels = [item['label'] for item in batch]

    # Pad images to same size
    images = pad_to_same_size(images)

    return {
        'images': torch.stack(images),
        'labels': torch.tensor(labels)
    }

loader = DataLoader(
    dataset,
    collate_fn=my_collate
)
```
When you need custom collate:
- Variable-length sequences (sentences)
- Different image sizes
- Special preprocessing
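For the classic variable-length case, PyTorch ships a padding helper you can call inside your collate function. A minimal sketch with made-up token IDs (a plain Python list works as a dataset here because it supports `len()` and indexing):

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Three "sentences" of different lengths
sentences = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]

def pad_collate(batch):
    # Pad every sequence in the batch to the length of the longest one
    return pad_sequence(batch, batch_first=True, padding_value=0)

loader = DataLoader(sentences, batch_size=3, collate_fn=pad_collate)
print(next(iter(loader)))
# tensor([[1, 2, 3],
#         [4, 5, 0],
#         [6, 0, 0]])
```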
🎲 Samplers: The Order Decider
Samplers decide which items to pick and in what order. Like a lottery machine choosing numbers!
Types of Samplers
SequentialSampler - Picks 0, 1, 2, 3… in order
```python
from torch.utils.data import SequentialSampler

sampler = SequentialSampler(dataset)
```
RandomSampler - Picks randomly
```python
from torch.utils.data import RandomSampler

sampler = RandomSampler(dataset)
```
WeightedRandomSampler - Some items more likely
```python
from torch.utils.data import WeightedRandomSampler

# Make rare items appear more often
weights = [1.0, 1.0, 10.0, 1.0]  # Item 2 = 10x more likely
sampler = WeightedRandomSampler(
    weights,
    num_samples=100
)
```
graph TD A["Sampler Types"] --> B["Sequential"] A --> C["Random"] A --> D["Weighted"] B --> E["0,1,2,3..."] C --> F["3,7,1,9..."] D --> G["Rare items more often"]
Why use weighted sampling? If you have 1000 cats and 10 dogs, normal training sees mostly cats. Weighted sampling ensures your model sees enough dogs too!
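One common way to build those weights is from the class counts themselves: weight each sample by the inverse of how common its class is. A sketch with the cat/dog numbers above (the `labels` tensor is made up, and `my_dataset` is a placeholder for your own Dataset):

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# 0 = cat (1000 samples), 1 = dog (only 10 samples)
labels = torch.tensor([0] * 1000 + [1] * 10)

class_counts = torch.bincount(labels).float()   # tensor([1000., 10.])
sample_weights = 1.0 / class_counts[labels]     # cats: 0.001, dogs: 0.1

sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(labels),
    replacement=True,
)
# loader = DataLoader(my_dataset, batch_size=32, sampler=sampler)
```

Note that when you pass a `sampler`, you leave `shuffle` off; the sampler already decides the order.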
✂️ Dataset Splitting: Dividing the Pie
You never test on training data. That’s like grading a test with the answer key visible!
Three pieces of the pie:
- Training set (70-80%) - Learn from this
- Validation set (10-15%) - Tune settings
- Test set (10-15%) - Final grade
```python
from torch.utils.data import random_split

dataset = MyDataset()  # 1000 items

train_set, val_set, test_set = random_split(
    dataset,
    [700, 150, 150]  # 70%, 15%, 15%
)

print(f"Train: {len(train_set)}")  # 700
print(f"Val: {len(val_set)}")      # 150
print(f"Test: {len(test_set)}")    # 150
```
Another way - by percentage:
```python
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_set, test_set = random_split(
    dataset,
    [train_size, test_size]
)
```
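If you want the exact same split every run (so results stay comparable), you can pass a seeded generator. This sketch reuses `dataset`, `train_size`, and `test_size` from the snippet above:

```python
import torch
from torch.utils.data import random_split

train_set, test_set = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),  # same seed -> same split
)
```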
graph TD A["Full Dataset"] --> B["random_split"] B --> C["Training 70%"] B --> D["Validation 15%"] B --> E["Test 15%"] C --> F["Model learns"] D --> G["Tune hyperparams"] E --> H["Final evaluation"]
🎯 Putting It All Together
Here’s the complete kitchen workflow:
```python
from torch.utils.data import (
    Dataset, DataLoader, random_split
)

# 1. Create your Dataset (recipe book)
class MyData(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# 2. Split it up
full = MyData(all_your_data)
train, val, test = random_split(
    full, [800, 100, 100]
)

# 3. Create DataLoaders (delivery trucks)
train_loader = DataLoader(
    train,
    batch_size=32,
    shuffle=True
)
val_loader = DataLoader(
    val,
    batch_size=32
)

# 4. Train!
for epoch in range(10):
    for batch in train_loader:
        # Your training code here
        pass
```
🌟 Quick Reference
| Component | Purpose | Analogy |
|---|---|---|
| Dataset | Store & access data | Recipe book |
| Built-in Datasets | Pre-made data | Supermarket meals |
| Custom Dataset | Your own data | Your recipe |
| IterableDataset | Streaming data | Endless buffet |
| DataLoader | Batch & deliver | Delivery truck |
| Collate Function | Format batches | Prep chef |
| Sampler | Choose order | Lottery machine |
| random_split | Divide data | Cutting pie |
🚀 You Did It!
You now understand PyTorch’s entire data loading pipeline! From storing data in Datasets, to streaming with IterableDataset, to batching with DataLoaders, to smart sampling and proper splitting.
Remember: Good data loading = fast training = happy models!
Your neural network is hungry. Now you know exactly how to feed it! 🍽️
