当前位置:首页 > python > 正文内容

PyTorch自定义Dataset全解析:从理论到实战的完整指南

zhangsir6个月前 (06-30)python133

一、Dataset的核心机制

PyTorch的数据加载体系基于两大核心组件:

  1. Dataset:定义数据集的抽象接口,负责索引到样本的映射。

  2. DataLoader:封装Dataset,提供批量加载、多线程加速等功能。

1.1 数据传递流程

# 典型流程示例
dataset = CustomDataset(...)  # 创建自定义数据集
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # 包装为DataLoader
for batch in dataloader:  # 迭代获取批量数据
    inputs, labels = batch

1.2 Map式 vs Iterable式数据集

  • Map式数据集(常用):通过__getitem__实现索引访问,支持随机打乱(shuffle)。

  • Iterable式数据集:适用于流式数据(如实时传感器数据),按顺序迭代。

二、自定义Dataset的实现范式

继承torch.utils.data.Dataset需实现三个核心方法:

2.1 基础模板

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.samples = self._load_data()  # 加载数据列表

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # 1. 读取原始数据
        sample = self._read_sample(index)
        
        # 2. 应用预处理
        if self.transform:
            sample = self.transform(sample)
            
        # 3. 返回样本(如图像+标签)
        return sample

2.2 关键方法详解

  • __init__:初始化数据路径、预处理变换,并加载数据元信息。

  • __len__:返回数据集大小,用于DataLoader的进度控制。

  • __getitem__:核心方法,需完成数据读取、预处理和返回。

三、实战案例:图像分类数据集

以Kaggle的Dogs vs Cats数据集为例,实现自定义Dataset:

3.1 数据准备

dataset/
├── train/
│   ├── cat.0.jpg
│   ├── dog.0.jpg
│   └── ...
└── annotations.txt  # 格式:filename label

3.2 完整实现

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class DogCatDataset(Dataset):
    def __init__(self, root_dir, ann_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = self._load_annotations(ann_file)

    def _load_annotations(self, ann_file):
        samples = []
        with open(ann_file) as f:
            for line in f:
                filename, label = line.strip().split()
                samples.append((filename, int(label)))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        filename, label = self.samples[index]
        img_path = os.path.join(self.root_dir, filename)
        image = Image.open(img_path)
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

# 使用示例
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = DogCatDataset(
    root_dir='dataset/train',
    ann_file='dataset/annotations.txt',
    transform=transform
)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)


四、高级技巧与优化

4.1 多线程加速

通过num_workers参数启用多进程加载:

pythondataloader = DataLoader(dataset, batch_size=64, num_workers=8)

4.2 自定义Collate函数

处理变长数据(如NLP序列):

def collate_fn(batch):
    # batch: List[Tuple(image, label)]
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    # 堆叠图像并转换为Tensor
    images = torch.stack(images, dim=0)
    labels = torch.LongTensor(labels)
    return images, labels

dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)


4.3 内存映射优化

对于大型数据集,可使用内存映射文件(如HDF5)减少I/O开销:

import h5py

class HDF5Dataset(Dataset):
    def __init__(self, h5_path):
        self.h5_file = h5py.File(h5_path, 'r')
        self.length = len(self.h5_file['images'])

    def __getitem__(self, index):
        image = self.h5_file['images'][index]
        label = self.h5_file['labels'][index]
        return image, label

五、常见问题与解决方案

5.1 数据路径错误

  • 问题FileNotFoundError

  • 解决:使用os.path.join构建跨平台路径,检查文件权限。

5.2 内存不足

  • 问题:加载大型数据集时OOM

  • 解决

    • 使用生成器(IterableDataset)

    • 分批加载数据

    • 降低batch_size

5.3 数据类型不匹配

  • 问题RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor)

  • 解决:显式调用.to(device)或使用pin_memory=True加速GPU传输。

六、总结

自定义Dataset是PyTorch数据流水线的核心技能,通过继承Dataset类并实现__len____getitem__方法,开发者可以灵活处理各类数据格式。结合DataLoader的多线程加速和自定义collate_fn,可构建高效的数据加载管道。实际应用中需注意路径处理、内存优化和设备兼容性,以确保训练过程的稳定性。


zhangsir版权k3防采集https://mianka.xyz

扫描二维码推送至手机访问。

版权声明:本文由zhangsir or zhangmaam发布,如需转载请注明出处。

本文链接:https://mianka.xyz/post/186.html

分享给朋友:

“PyTorch自定义Dataset全解析:从理论到实战的完整指南” 的相关文章

如何向python 列表中添加元素

Python添加元素有三种方法:append、extend、insertappend:向列表添加元素,添加到尾部实例:list=[“my”,“name”,“is”,“mark”,“age”,18] print(“添加前:”,list) list.append(“test”) print(“添加...

如何用python获取一个网页的所有连接

如何用python获取一个网页的所有连接很简单直接上代码:# -*- coding: utf-8 -*- ''' 如何用python获取一个网页的所有连接 author:zhangsir ''' imp...

pip安装三方库 国内的一些镜像站点推荐

pip 国内的一些镜像站点推荐镜像套路:使用cmd;输入命令pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple 包名 即可开始安装。清华:https://pypi.tuna.tsinghua.edu.cn/simple 阿里云:http...

解决Django的request.POST获取不到请求参数的问题

这个是Django自身的问题:只要在请求头的添加"content-type":'application/x-www-form-urlencoded'就行。...

python 爬虫 报错:UnicodeDecodeError: ‘utf-8‘ codec can‘t decode byte 0x8b in position”解决方案

发现报错“UnicodeDecodeError: 'utf-8' codec can't decode byte 0x8b in position 1:invalid start byte”,方法一:根据报错提示,错误原因有一条是这样的:“'Accept-Encodi...

python 实现彩色图转素描图

python可以把彩色图片转化为铅笔素描草图,对人像、景色都有很好的效果。而且只需几行代码就可以一键生成,适合批量操作,非常的快捷。需要的第三方库:Opencv - 计算机视觉工具,可以实现多元化的图像视频处理,有Python接口""" Photo ...