当前位置:首页 > python > 正文内容

基于PyTorch的图像分类任务实战:从数据预处理到模型优化

zhangsir4个月前 (07-07)python96

一、数据准备与预处理

1. 数据集选择与加载

以CIFAR-10数据集为例,该数据集包含60,000张32×32彩色图像,分为10个类别(飞机、汽车、鸟等),其中50,000张用于训练,10,000张用于测试。PyTorch通过torchvision.datasets.CIFAR10实现一键加载:

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为Tensor并归一化至[0,1]
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))  # 均值方差标准化
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

2. 数据增强技术

为提升模型泛化能力,需对训练数据进行随机变换:

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # 水平翻转
    transforms.RandomRotation(15),          # 随机旋转±15度
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # 亮度/对比度调整
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

效果验证:在Kaggle细胞分类竞赛中,采用数据增强后模型准确率从89%提升至94%。

二、模型构建与优化

1. 基础CNN模型实现

以3层卷积网络为例,展示从输入到输出的完整流程:

import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 输出尺寸: [batch,32,16,16]
        x = self.pool(F.relu(self.conv2(x)))  # 输出尺寸: [batch,64,8,8]
        x = x.view(-1, 64 * 8 * 8)            # 展平操作
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

关键点

  • 卷积层负责特征提取,全连接层完成分类

  • ReLU激活函数引入非线性,避免梯度消失

  • 池化层降低特征维度,减少计算量

2. 迁移学习实战

利用预训练的ResNet50模型进行微调(Fine-tuning):

import torchvision.models as models

model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # 替换最后的全连接层

# 冻结前4个ResNet块参数(可选)
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

优势

  • 在ImageNet上预训练的模型已学习到通用特征(如边缘、纹理)

  • 仅需少量数据即可达到高精度,尤其适合医学影像等标注成本高的领域

三、训练策略与调优

1. 损失函数与优化器选择

  • 交叉熵损失:适用于多分类任务,自动处理Softmax概率分布

  • Adam优化器:结合动量与自适应学习率,收敛速度快于SGD

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)  # L2正则化

2. 学习率调度

采用余弦退火策略动态调整学习率:

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

for epoch in range(100):
    # 训练代码...
    scheduler.step()

效果:在CIFAR-10实验中,该策略使模型在后期训练中跳出局部最优,最终准确率提升2.3%。

四、评估与部署

1. 评估指标

  • Top-1准确率:预测概率最高的类别是否正确

  • 混淆矩阵:分析各类别的误分类情况

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

def plot_confusion_matrix(y_true, y_pred, classes):
    cm = confusion_matrix(y_true, y_pred)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.xticks(range(len(classes)), classes, rotation=45)
    plt.yticks(range(len(classes)), classes)
    plt.show()

2. 模型导出与部署

将训练好的模型转换为TorchScript格式,支持C++/Java等语言调用:

traced_script_module = torch.jit.trace(model, torch.rand(1, 3, 32, 32))
traced_script_module.save("model.pt")


zhangsir版权k3防采集https://mianka.xyz

扫描二维码推送至手机访问。

版权声明:本文由zhangsir or zhangmaam发布,如需转载请注明出处。

本文链接:https://mianka.xyz/post/190.html

分享给朋友:

“基于PyTorch的图像分类任务实战:从数据预处理到模型优化” 的相关文章

Python爬虫xpath详解

一、xpath介绍xpath是一门在 XML 文档中查找信息的语言。最初是用来搜寻 XML 文档的,但同样适用于 HTML 文档的搜索。所以在做爬虫时完全可以使用 XPath 做相应的信息抽取。二、安装lxmllxml是Python的一个第三方解析库,支持HTML和XML解析,而且效率非常高,弥补了...

如何用python获取一个网页的所有连接

如何用python获取一个网页的所有连接很简单直接上代码:# -*- coding: utf-8 -*- ''' 如何用python获取一个网页的所有连接 author:zhangsir ''' imp...

Python三方库ddddocr实现验证码识别

Python三方库ddddocr实现验证码识别环境要求python >= 3.8安装三方库pip install ddddocr -i https://pypi.tuna.tsinghua.edu.cn/simple参数说明:参数名参数类型默认值说明us...

python 爬虫 报错:UnicodeDecodeError: ‘utf-8‘ codec can‘t decode byte 0x8b in position”解决方案

发现报错“UnicodeDecodeError: 'utf-8' codec can't decode byte 0x8b in position 1:invalid start byte”,方法一:根据报错提示,错误原因有一条是这样的:“'Accept-Encodi...

python 实现彩色图转素描图

python可以把彩色图片转化为铅笔素描草图,对人像、景色都有很好的效果。而且只需几行代码就可以一键生成,适合批量操作,非常的快捷。需要的第三方库:Opencv - 计算机视觉工具,可以实现多元化的图像视频处理,有Python接口""" Photo ...

python selenium 使用代理ip

代码如下:from selenium import webdriver chromeOptions = webdriver.ChromeOptions() chromeOptions.add_argument("--proxy-serv...