PyTorch数据加载(一)

简介

PyTorch具有广泛的神经网络构建模块,具有简单、直观和稳定的API。PyTorch包含为模型准备和加载公共数据集的包。

PyTorch数据加载实用工具的核心是torch.utils.data.DataLoader类。它表示数据集上的Python迭代器。PyTorch在torch.utils.data.Dataset类中提供了内置的高质量数据集。它们允许使用预加载的数据集以及自己的数据。Dataset存储样本及其对应的标签,DataLoader在Dataset周围包装了一个可迭代对象,以方便访问样本。这些数据集目前包含:

使用torchaudio.datasets中的Yesno数据集演示如何高效地将数据从PyTorch Dataset 加载到PyTorch DataLoader 中。

安装

需要安装 torchaudio 以访问数据集,pip install torchaudio

要在谷歌Colab中运行,取消注释,!pip install torchaudio

步骤

  1. 导入加载数据所需的所有库
  2. 访问数据集中的数据
  3. 加载数据
  4. 遍历数据
  5. [可选]可视化数据

导入加载数据所需的所有库

在这个样例中,使用 torchtorchaudio。根据使用的内置数据集,还可以安装和导入torchvisiontorchtext

import torch
import torchaudio

访问数据集中的数据

torchaudio中的Yesno数据集有60段记录,记录的是一个人用希伯来语说“yes”或“no”;每段录音有8个单词长。torchaudio.datasets.YESNO为YesNo创建数据集

yesno_data = torchaudio.datasets.YESNO(
     root='./',
     url='http://www.openslr.org/resources/1/waves_yesno.tar.gz',
     folder_in_archive='waves_yesno',
     download=True)

数据集中的每一项都是元组:(waveform, sample_rate, labels)。对Yesno数据集必须设置root参数,用于存储训练和测试数据集。其他参数是可选的,显示了它们的默认值。

加载数据

现在可以访问数据集了,将它传递给torch.utils.data.DataLoader。DataLoader将数据集和采样器结合起来,返回一个遍历数据集的迭代器。

data_loader = torch.utils.data.DataLoader(yesno_data,
                                          batch_size=1,
                                          shuffle=True)

遍历数据

数据现在可以使用data_loader进行迭代了。当开始训练模型时,这将是必要的!现在data_loader对象中的每个数据条目都被转换为一个张量,其中包含表示waveform、sample rate和labels的张量。

for data in data_loader:
  print("Data: ", data)
  print("Waveform: {}\nSample rate: {}\nLabels: {}".format(data[0], data[1], data[2]))
  break

可视化数据

可以选择将数据可视化,以进一步理解DataLoader的输出

import matplotlib.pyplot as plt

print(data[0][0].numpy())

plt.figure()
plt.plot(waveform.t().numpy())

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注