4.3.3. 特征匹配损失#
训练DCGAN时,判别器输出一个0到1的值表示输入样本是真实样本的概率,生成器的目标是使该判别值最大。可见,我们只是依据判别值指导生成器更新参数,对于生成器来说可参考的信息太少。为了让生成器参考更多的信息,我们尝试使用特征匹配损失。
特征匹配损失可使生成样本与真实样本在判别器中间层输出的特征互相匹配,来改善对抗训练的不稳定性。从判别器的结构来看,整个网络由特征提取层和分类层组成。特征提取层提取输入样本的抽象特征,分类层则将抽象特征映射为一个判别值,完成对输入样本的分类。前面的训练方法我们只是参考了分类层的输出,而使用特征匹配损失可参考特征提取层的输出。分别截取生成样本与真实样本在判别器某一特征提取层的输出特征,并度量这两个抽象特征之间的差距,通过这个差距来指导生成器学习更多信息,使得生成的样本更加真实可靠。
特征匹配损失表示为生成样本在判别器中间层上特征的期望值与真实样本在判别器中间层上特征的期望值的均方误差。具体实现方面,因为特征匹配损失是为了更新生成器的参数,所以仅用于生成器,判别器仍使用原来的二分类交叉熵损失。
特征匹配损失的表达式如下:
其中f ( )表示判别器中间层的映射,表示求取期望。
基于PyTorch实现的代码如下:
1from torch import nn
2
3from chapter_4_2_3_01 import nc, ndf
4
5
6class Discriminator(nn.Module):
7 def __init__(self):
8 super(Discriminator, self).__init__()
9 self.layer1 = nn.Sequential(
10 nn.Conv2d(in_channels=nc, out_channels=ndf,
11 kernel_size=4, stride=2, padding=1, bias=False),
12 nn.LeakyReLU(0.2, inplace=True),
13 )
14 self.layer2 = nn.Sequential(
15 nn.Conv2d(in_channels=ndf, out_channels=ndf * 2,
16 kernel_size=4, stride=2, padding=1, bias=False),
17 nn.BatchNorm2d(ndf * 2),
18 nn.LeakyReLU(0.2, inplace=True),
19 )
20
21 self.layer3 = nn.Sequential(
22 nn.Conv2d(in_channels=ndf * 2, out_channels=ndf * 4,
23 kernel_size=4, stride=2, padding=1, bias=False),
24 nn.BatchNorm2d(ndf * 4),
25 nn.LeakyReLU(0.2, inplace=True),
26 )
27 self.layer4 = nn.Sequential(
28 nn.Conv2d(in_channels=ndf * 4, out_channels=ndf * 8,
29 kernel_size=4, stride=2, padding=1, bias=False),
30 nn.BatchNorm2d(ndf * 8),
31 nn.LeakyReLU(0.2, inplace=True),
32 )
33 self.layer5 = nn.Sequential(
34 nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
35 nn.Sigmoid()
36 )
37
38 def forward(self, inputs):
39 # 获取第一层输出的特征
40 out1 = self.layer1(inputs)
41 # 获取第二层输出的特征
42 out2 = self.layer2(out1)
43 # 获取第三层输出的特征
44 out3 = self.layer3(out2)
45 # 获取第四层输出的特征
46 out4 = self.layer4(out3)
47 # 获取模型输出
48 out5 = self.layer5(out4)
49 return out5, [out1, out2, out3, out4]
我们使用单独的层来表示判别器的网络结构,方便取出每一层的输出,进行特征匹配损失的计算。因为我们更改了判别器的输出,在对判别器进行训练时,只需要out5,所以我们使用NULLD表示判别器用不到的输出。对生成器进行训练时,先使用netD(real_cpu.detach())获得真实样本的中间特征,再使用netD(fake)获得生成样本的中间特征。这里我们分别取真实样本与生成样本在判别器self.layer4层的输出特征out4。然后对这两个特征求取对应的期望,并计算获得特征匹配损失。最后利用特征匹配损失更新模型。
基于PyTorch实现的代码如下,完整代码可扫描二维码下载。
待处理
待适配
1import torch
2from torch import nn, optim
3
4from chapter_4_2_3_01 import epochs, device, nz, beta
5from chapter_4_2_4_01 import dataloader
6from chapter_4_3_1_01 import D_real_label, D_fake_label
7from chapter_4_2_7_01 import netG
8from chapter_4_3_2_01 import optimizerG
9from chapter_4_3_3_01 import Discriminator
10
11# 使用新的Discriminator
12netD = Discriminator().to(device)
13optimizerD = optim.Adam(netD.parameters(), lr=0.0003, betas=(beta, 0.999))
14# 特征匹配损失
15criterion = nn.BCELoss()
16criterionG = nn.MSELoss()
17
18for epoch in range(epochs):
19 for i, data in enumerate(dataloader, 0):
20 # 用真实数据、real_label训练判别器
21 netD.zero_grad()
22 real_cpu = data[0].to(device)
23 b_size = real_cpu.size(0)
24 # 真标签
25 label = torch.full((b_size,), D_real_label, dtype=torch.float,
26 device=device)
27 # 使用NULLD变量表示不使用的输出
28 output, NULLD = netD(real_cpu)
29 output = output.view(-1)
30 # 计算判别器对真实数据的损失
31 errD_real = criterion(output, label)
32 errD_real.backward()
33 D_x = output.mean().item()
34
35 # 用噪声生成的数据、fake_label训练判别器
36 # 生成一个批次的噪声
37 noise = torch.randn(b_size, nz, 1, 1, device=device)
38 # 噪声输入生成器,得到生成图片
39 fake = netG(noise)
40 # 假标签
41 label.fill_(D_fake_label)
42 # 使用NULLD变量表示不使用的输出
43 output, NULLD = netD(fake.detach())
44 output = output.view(-1)
45 # 计算判别器对生成数据的损失
46 errD_fake = criterion(output, label)
47 errD_fake.backward()
48 D_G_z1 = output.mean().item()
49 # 计算判别器的损失
50 errD = errD_real + errD_fake
51 optimizerD.step()
52
53 # 训练生成器
54 netG.zero_grad()
55 # 基于真实数据,获取判别器特征
56 _, feature_real = netD(real_cpu.detach())
57 # 基于生成数据,获取判别器特征
58 output, feature_fake = netD(fake)
59 # 计算真实数据在判别器最后一层特征的期望
60 feature_real_last = torch.mean(feature_real[-1], 0)
61 # 计算生成数据在判别器最后一层特征的期望
62 feature_fake_last = torch.mean(feature_fake[-1], 0)
63 # 使用均方误差计算真实样本与生成样本的特征的损失
64 errG = criterionG(feature_fake_last, feature_real_last.detach())
65 errG.backward()
66 D_G_z2 = output.mean().item()
67 optimizerG.step()
采用特征匹配损失的训练时间约为189分钟,训练结果为图4-11右图,图4-11左图为未使用改进方法的训练结果。使用特征匹配损失后,部分生成样本有所改进,但一部分样本存在模式奔溃的现象。