简介
PyTorch具有广泛的神经网络构建模块,具有简单、直观和稳定的API。PyTorch包含为模型准备和加载公共数据集的包。
PyTorch数据加载实用工具的核心是torch.utils.data.DataLoader类。它表示数据集上的Python迭代器。PyTorch在torch.utils.data.Dataset类中提供了内置的高质量数据集。它们允许使用预加载的数据集以及自己的数据。Dataset存储样本及其对应的标签,DataLoader在Dataset周围包装了一个可迭代对象,以方便访问样本。这些数据集目前包含:
使用torchaudio.datasets中的Yesno数据集演示如何高效地将数据从PyTorch Dataset 加载到PyTorch DataLoader 中。
安装
需要安装 torchaudio 以访问数据集,pip install torchaudio
要在谷歌Colab中运行,取消注释,!pip install torchaudio
步骤
- 导入加载数据所需的所有库
- 访问数据集中的数据
- 加载数据
- 遍历数据
- [可选]可视化数据
导入加载数据所需的所有库
在这个样例中,使用 torch 和 torchaudio。根据使用的内置数据集,还可以安装和导入torchvision 或 torchtext
import torch
import torchaudio
访问数据集中的数据
torchaudio中的Yesno数据集有60段记录,记录的是一个人用希伯来语说“yes”或“no”;每段录音有8个单词长。torchaudio.datasets.YESNO为YesNo创建数据集
yesno_data = torchaudio.datasets.YESNO(
root='./',
url='http://www.openslr.org/resources/1/waves_yesno.tar.gz',
folder_in_archive='waves_yesno',
download=True)
数据集中的每一项都是元组:(waveform, sample_rate, labels)。对Yesno数据集必须设置root参数,用于存储训练和测试数据集。其他参数是可选的,显示了它们的默认值。
加载数据
现在可以访问数据集了,将它传递给torch.utils.data.DataLoader。DataLoader将数据集和采样器结合起来,返回一个遍历数据集的迭代器。
data_loader = torch.utils.data.DataLoader(yesno_data,
batch_size=1,
shuffle=True)
遍历数据
数据现在可以使用data_loader进行迭代了。当开始训练模型时,这将是必要的!现在data_loader对象中的每个数据条目都被转换为一个张量,其中包含表示waveform、sample rate和labels的张量。
for data in data_loader:
print("Data: ", data)
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(data[0], data[1], data[2]))
break
可视化数据
可以选择将数据可视化,以进一步理解DataLoader的输出
import matplotlib.pyplot as plt
print(data[0][0].numpy())
plt.figure()
plt.plot(waveform.t().numpy())