4.2.4. 数据集加载与可视化#

我们使用的数据集是Kaggle中的Architecture Styles数据集[11]。此数据集除了能用于建筑风格分类外,也可用于其他深度学习任务,本章我们将其用于不同风格的建筑立面生成。我们对原数据集进行了简单处理,删除了一些质量不高的图片,得到的数据集包含24种风格的建筑立面,共9636张图像。处理后的数据集可扫描二维码下载。(二维码)

待处理

提供数据下载链接

现在我们已经有了所需的数据集,如何把这些数据进行加载,并处理成网络需要的格式呢?PyTorch中的ImageFolder类可以完成我们的需求。ImageFolder类默认数据集已经按照类型分成了不同的文件夹,一种类型的文件夹下只存放一种类型的图片。ImageFolder有两个主要参数:

  1. root:数据集的保存路径。

  2. transform:对数据进行转换操作,由一个transforms.Compose对象的实例表示。transforms.Compose对象可以看作一个容器,它能够装入多种数据变换操作。在这里,我们对数据进行了四种变换:transforms.Resize将输入图片按照我们设定的尺寸进行缩放;transforms.CenterCrop以输入图片的中心点为参考点,剪裁出一张长宽为设定值的图片;transforms.ToTensor将PIL图片格式转换为Tensor格式,以便PyTorch处理;transforms.Normalize将图像的每个通道按照设定的均值和标准差进行标准化。

通过ImageFolder,我们已将图片加载并处理为需要的数据格式,同时获得了所有数据的集合dataset。在训练模型时,需要将数据样本载入内存,并在每次迭代中对内存中的所有样本进行计算。因此,出于内存空间的限制以及对训练时间的考虑,我们不能将所有数据一次性载入模型,需要将dataset划分为多个批量(batch)输入到模型进行训练。DataLoader类可以完成这个任务。DataLoader起到一个采样器的作用,按照设定的batch_size将dataset划分为多个批量,其中shuffle参数决定对数据进行采样时是否打乱顺序。实际代码如下所示。

chapter_4_2_4_01.py#
 1import os
 2
 3import torchvision
 4from torchvision import transforms
 5
 6from chapter_4_2_3_01 import *
 7
 8# 数据存放路径
 9data_path = os.path.join("..", "data", "dataset")
10# 创建数据集dataset
11dataset = torchvision.datasets.ImageFolder(root=data_path,
12                                           transform=transforms.Compose([
13                                               transforms.Resize(image_size),
14                                               transforms.CenterCrop(
15                                                   image_size),
16                                               transforms.ToTensor(),
17                                               transforms.Normalize(
18                                                   (0.5, 0.5, 0.5),
19                                                   (0.5, 0.5, 0.5)),
20                                           ]))
21# 创建dataloader
22dataloader = torch.utils.data.DataLoader(dataset,
23                                         batch_size=batch_size,
24                                         shuffle=True,
25                                         num_workers=workers)

接下来,我们运行以下代码对一个批次的前64张图片进行可视化。首先,我们通过iter函数将dataloader转换为一个迭代器,以便用next函数取出其中的一个批次。同时,由于dataset和dataloader中的图片数据均是由PIL图像变换而成的,各个通道是按照R、G、B的顺序而存储的,而matplotlib中显示的图片通道需要按照G、B、R的顺序存储,因此我们采用numpy.transpose函数将三个通道的位置重新排布。可视化的结果如图4-7所示。

chapter_4_2_4_02.py#
 1import numpy as np
 2from matplotlib import pyplot as plt
 3from torchvision import utils as v_utils
 4
 5from chapter_4_2_3_01 import device
 6from chapter_4_2_4_01 import dataloader
 7
 8real_batch = next(iter(dataloader))
 9plt.figure(figsize=(8, 8))
10plt.axis("off")
11plt.title("Training Images")
12img_show = v_utils.make_grid(real_batch[0].to(device)[:64],
13                             padding=2,
14                             normalize=True)
15
16plt.imshow(np.transpose(img_show.cpu(), (1, 2, 0)))
17plt.show()
../../_images/4-7.png

图4-7 训练图片#