1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
| from d2l import torch as d2l import torch import torchvision from matplotlib_inline import backend_inline from torch.utils import data from torchvision import transforms import matplotlib.pyplot as plt
trans = transforms.ToTensor() mnist_train = torchvision.datasets.FashionMNIST( root="../../data", train=True, transform=trans, download=True) mnist_test = torchvision.datasets.FashionMNIST( root="../../data", train=False, transform=trans, download=True) """ Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset) 中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。 """ print('mnist_train_len', len(mnist_train)) print('mnist_test_len', len(mnist_test))
""" 每个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成,其通道数为1。 """ print('数据的shape', mnist_train[0][0].shape) """ Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、 pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、 shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。 """
def get_fashion_mnist_labels(labels): """返回Fashion-MNIST数据集的文本标签""" text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): """绘制图像列表""" figsize = (num_cols * scale, num_rows * scale) _, axes = plt.subplots(num_rows, num_cols, figsize = figsize) axes = axes.flatten() for i, (ax, img) in enumerate(zip(axes, imgs)): if torch.is_tensor(img): ax.imshow(img.numpy()) else: ax.imshow(img) ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) if titles: ax.set_title(titles[i]) return axes
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
batch_size = 256
def get_dataloader_workers(): """使用4个进程来读取数据""" return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())
if __name__ == "__main__": ''' 李沐老师的课程展示中使用的环境是linux没有报错,但win10下在dataloader使用前需要加上: if __name__ == "__main__": ''' timer = d2l.Timer() for X, y in train_iter: continue print(f'{timer.stop():.2f} sec')
|