Pytorch深度学习框架-data

前言
处理数据样本的代码可能会变得混乱且难以维护;我们理想情况下希望我们的数据集代码与我们的模型训练代码解耦,以获得更好的可读性和模块化。
以上出自Pytorch官方文档的原语,更好的可读性和模块化更有利于程序员进行开发,因此今天我们尝试了解Pytorch框架下数据集和数据处理的一些方式。
加载数据集
数据集加载过程中需要指定下列参数:
root
-存储数据的路径train
是否为训练集,True表示是训练集,False表示是测试集。download
是否从互联网下载,若数据在root路径下不可用,设置downald为True可以从互联网上下载所需的数据。transform
或者target_transform
指定了特征和标签转换。transform 和 target_transform,分别用于转换输入和目标。
Warn
当使用 download=True 创建数据集对象时,文件首先在根目录中下载和解压缩。此下载逻辑不是多进程安全的,因此如果在分布式设置中运行,可能会导致冲突/竞争条件。在分布式模式下,我们建议在设置分布式模式之前创建一个虚拟数据集对象来触发下载逻辑。
这里作者也不是很理解,暂时先留在这里,留待以后继续解决。
仅从上面的解释还是不够清楚,下面是一个从TorchVision
加载Fashion-MNIST
数据集的一个示例:
1 | import torch |
Tip
TorchVision 在 torchvision.datasets 模块中提供了许多内置数据集,以及用于构建自定义数据集的实用工具类。TorchVision专门用来处理图像,通常用于计算机视觉领域。
而Fashion-MNIST 是 Zalando 的文章图像数据集,包含 60,000 个训练示例和 10,000 个测试示例。每个示例包含一个 28×28 灰度图像以及来自 10 个类别之一的相关标签。
可视化数据集
使用matplotlib库可以实现数据集的可视化。
1 | import matplotlib.pyplot as plt |
自定义数据集
自定义数据集需要从torch.utils.data.Dataset
也就是我们在前文导入的Dataset
继承得到一个新的子类,这个子类必须要实现以下三个函数:
__init__()
实例化 Dataset 对象时运行一次,初始化包含数据的目录、注释文件和两个转换__len__()
函数返回数据集中样本的数量。__getitem__()
函数加载并返回给定索引 idx 处数据集的样本。根据索引,它识别磁盘上数据的位置,将数据转换为张量,并在元组中返回张量数据和相应的标签。
下面以FashionMNIST 数据集为例,这些数据图像存储在目录 img_dir 中,它们的标签单独存储在 CSV 文件 annotations_file 中。
其中label.csv文件为如下格式:
1 | tshirt1.jpg, 0 |
此时实现的自定义数据集为:
1 | import os |
小批量传递数据
Dataset 一次检索我们数据集的特征和标签一个样本。在训练模型时,我们通常希望以“小批量”传递样本,在每个 epoch 重新打乱数据以减少模型过拟合,并使用 Python 的 multiprocessing 来加速数据检索。
为了实现上述的目的,需要使用torch.utils.data.DataLoader
也就是DataLoader
。
使用的方式如下:
1 | from torch.utils.data import DataLoader |
迭代DataLoader
DataLoader是一个可迭代的对象,我们可以迭代进行模型训练等操作,如下列示例:
1 | # Display image and label. |
后记
以上就是数据集的一些信息了,如果想要了解更多的内容,可以直接在官方文档寻找。
同系列
Pytorch深度学习框架-insatll
Pytorch深度学习框架-tensor
Pytorch深度学习框架-data
Pytorch深度学习框架-transform
Pytorch深度学习框架-build
Pytorch深度学习框架-train
Pytorch深度学习框架-save-load
- Title: Pytorch深度学习框架-data
- Author: 呆呆的猪胖胖
- Created at : 2025-03-20 15:55:00
- Updated at : 2025-05-13 15:34:43
- Link: https://blog.cflmy.cn/2025/03/20/Technology/AI/Pytorch-data/
- License: This work is licensed under CC BY-NC-SA 4.0.