4.2.6. DCGAN判别器的构建#

DCGAN的判别器对生成的图像和真实的图像进行判别。对于一张输入到判别器的图像,只有真和假,这是一个二分类问题,因此判别器相当于一个二分类网络。对于生成的图像,应将其分类为假,对于真实的图像,应将其分类为真。判别器的输入为一张图像,经过判别器的内部模块逐步处理后,输出为1或0。与生成器相对的,判别器借助Conv2d的下采样功能,来实现对图像的逐级缩小,达到特征提取的目的。最后,我们使用Sigmoid函数将提取的特征映射到0-1之间,作为网络的输出。

chapter_4_2_6_01.py#
 1from torch import nn
 2from chapter_4_2_3_01 import nc, ndf
 3
 4
 5class Discriminator(nn.Module):
 6    """
 7    DCGAN判别器的构建
 8    """
 9
10    def __init__(self):
11        super(Discriminator, self).__init__()
12        self.main = nn.Sequential(
13            # (nc) x 64 x 64
14            nn.Conv2d(in_channels=nc, out_channels=ndf,
15                      kernel_size=4, stride=2, padding=1, bias=False),
16            nn.LeakyReLU(0.2, inplace=True),
17            # (ndf) x 32 x 32,输入—输出计算过程:
18            # O=(I-K+2P)/S+1=(64-4+2*1)/2+1=32
19            nn.Conv2d(in_channels=ndf, out_channels=ndf * 2,
20                      kernel_size=4, stride=2, padding=1, bias=False),
21            nn.BatchNorm2d(ndf * 2),
22            nn.LeakyReLU(0.2, inplace=True),
23            # (ndf*2) x 16 x 16,输入—输出计算过程:
24            # O=(I-K+2P)/S+1=(32-4+2*1)/2+1=16
25            nn.Conv2d(in_channels=ndf * 2, out_channels=ndf * 4,
26                      kernel_size=4, stride=2, padding=1, bias=False),
27            nn.BatchNorm2d(ndf * 4),
28            nn.LeakyReLU(0.2, inplace=True),
29            # (ndf*4) x 8 x 8,输入—输出计算过程:
30            # O=(I-K+2P)/S+1=(16-4+2*1)/2+1=8
31            nn.Conv2d(in_channels=ndf * 4, out_channels=ndf * 8,
32                      kernel_size=4, stride=2, padding=1, bias=False),
33            nn.BatchNorm2d(ndf * 8),
34            nn.LeakyReLU(0.2, inplace=True),
35            # (ndf*8) x 4 x 4,输入—输出计算过程:
36            # O=(I-K+2P)/S+1=(8-4+2*1)/2+1=4
37            nn.Conv2d(in_channels=ndf * 8, out_channels=1,
38                      kernel_size=4, stride=1, padding=0, bias=False),
39            nn.Sigmoid()
40            # 1 x 1 x 1,输入—输出计算过程:
41            # O=(I-K+2P)/S+1=(4-4+2*0)/2+1=1
42        )
43
44    def forward(self, inputs):
45        return self.main(inputs)