4.2.5. DCGAN生成器的构建#

DCGAN的生成器主要通过转置卷积的上采样功能将输入的噪声数据逐步转换为图像数据。DCGAN的生成器由ConvTranspose2、BatchNorm2d、ReLU和Tanh组成。生成器借助ConvTranspose2d的上采样功能,实现对输入的噪声信号,即隐变量z的逐级放大。隐变量z的尺寸为(nz,1,1),经过生成器内部模块逐步处理后,最终生成的图像格式为(nc, image_size, image_size)。用我们设定的参数来举例,即从(100, 1, 1)的尺寸到(3, 64, 64)的尺寸。

生成器模型的实际代码如下所示。隐变量z的维度为100,此参数可按需设置,无其他特殊意义。与transforms.Compose类似,nn.Sequential也是一个序列容器,可按照一定顺序将各种神经网络模块装入其中。nn.Sequential把多个模块封装为一个模块,比逐层定义网络层更加方便。在forward方法接收到输入后,nn.Sequential便按照装入模块的顺序,依次计算并输出结果。

chapter_4_2_5_01.py#
 1from torch import nn
 2from chapter_4_2_3_01 import ngf, nz, nc
 3
 4
 5class Generator(nn.Module):
 6    def __init__(self):
 7        super(Generator, self).__init__()
 8        self.main = nn.Sequential(
 9            # 100 x 1 x 1
10            nn.ConvTranspose2d(in_channels=nz, out_channels=ngf * 8,
11                               kernel_size=4, stride=1, padding=0, bias=False),
12            nn.BatchNorm2d(ngf * 8),
13            nn.ReLU(True),
14            # (ngf*8) x 4 x 4,输入—输出计算过程:O=(I-1)*S+K-2P=(1-1)*1+4-2*0=4
15            nn.ConvTranspose2d(in_channels=ngf * 8, out_channels=ngf * 4,
16                               kernel_size=4, stride=2, padding=1, bias=False),
17            nn.BatchNorm2d(ngf * 4),
18            nn.ReLU(True),
19            # (ngf*4) x 8 x 8,输入—输出计算过程:O=(I-1)*S+K-2P=(4-1)*2+4-2*1=8
20            nn.ConvTranspose2d(in_channels=ngf * 4, out_channels=ngf * 2,
21                               kernel_size=4, stride=2, padding=1, bias=False),
22            nn.BatchNorm2d(ngf * 2),
23            nn.ReLU(True),
24            # (ngf*2) x 16 x 16,输入—输出计算过程:O=(I-1)*S+K-2P=(8-1)*2+4-2*1=16
25            nn.ConvTranspose2d(in_channels=ngf * 2, out_channels=ngf,
26                               kernel_size=4, stride=2, padding=1, bias=False),
27            nn.BatchNorm2d(ngf),
28            nn.ReLU(True),
29            # (ngf) x 32 x 32,输入—输出计算过程:O=(I-1)*S+K-2P=(16-1)*2+4-2*1=32
30            nn.ConvTranspose2d(in_channels=ngf, out_channels=nc,
31                               kernel_size=4, stride=2, padding=1, bias=False),
32            nn.Tanh()
33            # (nc) x 64 x 64,输入—输出计算过程:O=(I-1)*S+K-2P=(32-1)*2+4-2*1=64
34        )
35
36    def forward(self, inputs):
37        return self.main(inputs)