4.2.3. 模型参数设置#

在这一小节中,我们将对一些需要设置的核心参数进行详细说明。在加载数据时,需要设置workers参数,DataLoader会创建workers参数数量的子进程用于数据加载。值大于1的workers可以加快数据读取速度,但会加重CPU负担。当内存有限时,设置过大的workers则容易导致内存溢出。因此workers参数与实验时的计算机硬件有较强相关性。简单起见,我们将workers参数设为0,读者可按需设置。batch_size参数用来设置每个批量中数据的多少,例如本例中我们共有9636张图片,如果batch_size设为128,则每个批量中有128张图片,一共有76个批量。image_size 参数将图片调整为统一的长宽尺寸,例如我们输入的图像分辨率为100×100,如果将image_size设为64,则调整后的图像分辨率为64×64。nc参数是指输入图片的通道数,对于普通的彩色图片,输入的通道数为3。nz参数用来设置输入的噪声信号即隐变量z的通道数,隐变量z通过torch.randn(B, nz, 1, 1)创建,它会输入到生成器,经过生成器内部模块处理,生成伪造图像。ngf设置生成器中特征图的基本维度。ndf设置判别器中特征图的基本维度。epochs为训练的次数。lr为优化器的学习率。

chapter_4_2_3_01.py#
 1import torch
 2
 3workers = 0
 4batch_size = 128
 5image_size = 64
 6nc = 3
 7nz = 100
 8ngf = 64
 9ndf = 64
10epochs = 100
11lr = 0.0002
12beta = 0.5
13device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")