主题
字号
CHAPTER 04 ≈ 25 MIN READ

数据增强与数据集加载

数据增强是提升模型泛化能力的关键技术。torchvision.transforms提供了丰富的图像变换函数。

图像变换与数据增强

transforms.Compose:组合多个变换

Compose用于将多个变换组合成管道。

from torchvision import transforms

# 创建变换管道
transform = transforms.Compose([
    transforms.ToTensor(),           # 1. 转为张量
    transforms.Normalize([0.5], [0.5]),  # 2. 标准化
    transforms.RandomHorizontalFlip(),   # 3. 随机翻转
])

# 应用变换
from PIL import Image
img = Image.open('cat.jpg')
img_tensor = transform(img)
print(img_tensor.shape)  # torch.Size([3, H, W])

transforms.ToTensor:图像转张量

ToTensor将PIL Image或NumPy数组转换为PyTorch张量,并自动归一化到[0, 1]。

from torchvision import transforms
from PIL import Image
import numpy as np

# 从PIL Image转换
img_pil = Image.open('cat.jpg')
print("PIL Image:", img_pil.mode, img_pil.size)

to_tensor = transforms.ToTensor()
tensor = to_tensor(img_pil)
print("Tensor:", tensor.shape, tensor.min(), tensor.max())
# 自动转换为 (C, H, W),值在 [0, 1]

# 从NumPy数组转换
img_np = np.array(img_pil)
print("NumPy:", img_np.shape)
tensor_np = to_tensor(img_np)
print("Tensor from NumPy:", tensor_np.shape)

# 转换前后对比
print("\n=== 转换说明 ===")
# PIL Image (H, W, C) [0-255]
# NumPy   (H, W, C) [0-255]
# Tensor  (C, H, W) [0, 1]

transforms.Normalize:标准化

Normalize使用均值和标准差对图像进行标准化。

数学公式

normalized = (input - mean) / std
# ImageNet常用均值和标准差
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # RGB通道的均值
    std=[0.229, 0.224, 0.225]     # RGB通道的标准差
)

# 标准化流程
x = torch.rand(3, 224, 224)  # [0, 1]
x_normalized = normalize(x)

print("标准化前: min={:.3f}, max={:.3f}".format(x.min(), x.max()))
print("标准化后: min={:.3f}, max={:.3f}".format(x_normalized.min(), x_normalized.max()))
# 标准化后值通常在 [-2, 2] 范围内

训练和测试使用相同的标准化

# 定义标准化参数(ImageNet)
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# 训练变换
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# 测试变换(验证集、测试集)
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# 使用变换
train_dataset = datasets.FashionMNIST(
    root='./data',
    train=True,
    transform=train_transform,
    download=True
)

test_dataset = datasets.FashionMNIST(
    root='./data',
    train=False,
    transform=test_transform,
    download=True
)

transforms.Resize与transforms.CenterCrop

调整图像尺寸并裁剪。

# Resize:调整图像大小
resize = transforms.Resize(size=(224, 224))  # 调整为指定尺寸
# 或者
resize = transforms.Resize(size=224)  # 短边调整为224,保持比例

# CenterCrop:中心裁剪
center_crop = transforms.CenterCrop(size=(224, 224))

# FiveCrop:四个角和中心裁剪(返回5个图像)
five_crop = transforms.FiveCrop(size=(224, 224))

# TenCrop:水平翻转后裁剪(返回10个图像)
ten_crop = transforms.TenCrop(size=(224, 224), vertical_flip=False)

# 组合使用
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

transforms.RandomHorizontalFlip:随机水平翻转

随机水平翻转是数据增强中最常用、最有效的技术之一。

# 随机水平翻转
hflip = transforms.RandomHorizontalFlip(p=0.5)  # p: 翻转概率

# 随机垂直翻转
vflip = transforms.RandomVerticalFlip(p=0.5)

# 实际应用
img = Image.open('cat.jpg')
flipped = hflip(img)

transforms.RandomRotation:随机旋转

# 随机旋转
rotation = transforms.RandomRotation(degrees=15)  # -15到15度之间随机旋转
rotation10 = transforms.RandomRotation(degrees=(10, 30))  # 10到30度之间随机旋转

# 随机旋转(包含填充)
rotation_fill = transforms.RandomRotation(
    degrees=30,
    fill=(255, 255, 255)  # 填充颜色
)

transforms.ColorJitter:颜色抖动

# 颜色抖动
color_jitter = transforms.ColorJitter(
    brightness=0.2,    # 亮度调整范围
    contrast=0.2,      # 对比度调整范围
    saturation=0.2,    # 饱和度调整范围
    hue=0.1           # 色调调整范围
)

# 单独使用
brightness = transforms.ColorJitter(brightness=0.3)
contrast = transforms.ColorJitter(contrast=0.3)
saturation = transforms.ColorJitter(saturation=0.3)
hue = transforms.ColorJitter(hue=0.1)

transforms.RandomAffine:随机仿射变换

# 随机仿射变换
affine = transforms.RandomAffine(
    degrees=15,              # 旋转角度
    translate=(0.1, 0.1),   # 平移范围(相对于尺寸的比例)
    scale=(0.9, 1.1),        # 缩放范围
    shear=15                # 剪切角度
)

transforms.RandomErasing:随机擦除

RandomErasing(随机擦除)是一种有效的正则化技术,模拟遮挡。

# 随机擦除
random_erase = transforms.RandomErasing(
    p=0.5,                  # 擦除概率
    scale=(0.02, 0.33),     # 擦除区域相对于图像的比例范围
    ratio=(0.3, 3.3),       # 擦除区域宽高比范围
    value=0                 # 擦除区域的值
)

# 在张量上应用
to_tensor = transforms.ToTensor()
img_tensor = to_tensor(img)
erased = random_erase(img_tensor)

综合数据增强示例

# 完整的数据增强策略
train_transform = transforms.Compose([
    # 1. 随机大小裁剪并调整到224x224
    transforms.RandomResizedCrop(
        224,
        scale=(0.8, 1.0),    # 裁剪区域为原图的80%-100%
        ratio=(0.9, 1.1)     # 宽高比范围
    ),

    # 2. 随机水平翻转
    transforms.RandomHorizontalFlip(p=0.5),

    # 3. 随机旋转(-15到15度)
    transforms.RandomRotation(15),

    # 4. 颜色抖动
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1
    ),

    # 5. 转换为张量
    transforms.ToTensor(),

    # 6. 标准化
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),

    # 7. 随机擦除
    transforms.RandomErasing(p=0.3),
])

# 测试变换
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

torchvision.datasets:内置数据集

PyTorch提供了常用的图像数据集,可以直接下载使用。

from torchvision import datasets, transforms

# Fashion-MNIST
fashion_train = datasets.FashionMNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

fashion_test = datasets.FashionMNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

# CIFAR-10
cifar10_train = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

cifar10_test = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

# ImageNet(需要手动下载)
# imagenet_train = datasets.ImageNet(
#     root='./data/imagenet',
#     split='train',
#     transform=transforms.ToTensor()
# )

# 查看数据集信息
print("Fashion-MNIST训练集:", len(fashion_train))
print("Fashion-MNIST测试集:", len(fashion_test))
print("CIFAR-10类别:", datasets.CIFAR10.classes)

常用数据集速查

数据集 类别数 训练集大小 图像大小 说明
MNIST 10 60,000 28x28 手写数字
Fashion-MNIST 10 60,000 28x28 服装分类
CIFAR-10 10 50,000 32x32 通用物体
CIFAR-100 100 50,000 32x32 100类物体
ImageNet 1000 1.2M 可变 大规模图像

Dataset类:自定义数据集

当需要加载自己的数据时,需要自定义Dataset类。

from torch.utils.data import Dataset
from PIL import Image
import os

class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # 读取数据目录
        for label_name in os.listdir(root_dir):
            label_path = os.path.join(root_dir, label_name)
            if os.path.isdir(label_path):
                for img_name in os.listdir(label_path):
                    if img_name.endswith(('.jpg', '.png', '.jpeg')):
                        self.image_paths.append(os.path.join(label_path, img_name))
                        self.labels.append(label_name)

        # 创建标签到索引的映射
        self.classes = sorted(list(set(self.labels)))
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # 加载图像
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')

        # 应用变换
        if self.transform:
            image = self.transform(image)

        # 获取标签
        label = self.labels[idx]
        label_idx = self.class_to_idx[label]

        return image, label_idx

# 使用自定义数据集
custom_dataset = CustomDataset(
    root_dir='./my_data/train',
    transform=train_transform
)

print("数据集大小:", len(custom_dataset))
print("类别:", custom_dataset.classes)

处理文件夹结构

my_data/
├── train/
│   ├── cat/
│   │   ├── cat001.jpg
│   │   ├── cat002.jpg
│   │   └── ...
│   ├── dog/
│   │   ├── dog001.jpg
│   │   └── ...
│   └── ...
└── val/
    ├── cat/
    └── dog/

DataLoader:批量数据加载

DataLoader是PyTorch中最重要的数据加载工具。

from torch.utils.data import DataLoader

# 基础用法
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

# 迭代数据
for batch_idx, (images, labels) in enumerate(train_loader):
    print("批次数:", batch_idx)
    print("图像形状:", images.shape)    # [32, 3, 224, 224]
    print("标签形状:", labels.shape)    # [32]
    break

DataLoader参数详解

参数 说明 常用值
dataset 数据集 -
batch_size 批大小 32, 64, 128
shuffle 是否打乱 True(训练), False(测试)
num_workers 数据加载进程数 4, 8
pin_memory 锁页内存,加快GPU传输 True
drop_last 丢弃最后一个不完整batch True(训练), False(测试)
collate_fn 自定义批处理函数 自定义

多GPU数据加载

# 多GPU时使用DistributedSampler
from torch.utils.data.distributed import DistributedSampler

train_sampler = DistributedSampler(
    dataset,
    num_replicas=num_gpus,
    rank=rank,
    shuffle=True
)

train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=4,
    pin_memory=True
)