2.3.2. 采用GAN 生成手写数字#
我们从架构图入手,构建一个GAN。真实图像由我们在第2.2节中使用过的MNIST数据集提供。生成器的任务是生成相同大小的手写数字图像。随着训练的进展,我们希望生成的图像越来越真实,并可以骗过判别器。首先,让我们创建一个新的Note并导入所需的库。
1import random
2
3import pandas
4import numpy as np
5import torch
6import torch.nn as nn
7import matplotlib.pyplot as plt
8from torch.utils.data import Dataset
数据类
我们将使用之前创建的MnistDataset类加载数据集,对于数据集中的每个样本,我们将获得一个代表实际数字的标签、一个归一化的图像像素值张量,以及一个独热目标张量。另外,我们为MnistDataset类添加一个plot_image方法,它将对数据集中的图像进行可视化。我们可以通过绘制样本图像,测试Dataset类是否可以正常工作。
1import matplotlib.pyplot as plt
2
3from chapter_2_2_4_03 import MnistDataset as Dataset
4
5
6class MnistDataset(Dataset):
7 """
8 通过继承,复用2.2 章节中重复的代码
9 """
10
11 def plot_image(self, index):
12 img = self.data_df.iloc[index, 1:].values.reshape(28, 28)
13 plt.title("label = " + str(self.data_df.iloc[index, 0]))
14 plt.imshow(img, interpolation='none', cmap='Blues')
15 plt.show()
16
17
18mnist_dataset = MnistDataset('../data/mnist_train.csv')
19if __name__ == '__main__':
20 mnist_dataset.plot_image(0)
图2-56 可视化数据集中的第一个样本结果#
如图2-56所示,我们成功绘制了数据集中第一个样本的图像,它的标签是5。下面让我们开始用PyTorch搭建生成对抗网络模型吧!
MNIST判别器
我们先编辑判别器,GAN里面的判别器其实也是一个分类器。跟之前一样,它是一个继承自nn.Module的神经网络。我们按照PyTorch所需要的方式初始化网络,并创建一个forward() 函数。以下是判别器的构造函数。
1from torch import nn
2
3
4class Discriminator(nn.Module):
5 """
6 判别器
7 """
8
9 def __init__(self):
10 super().__init__()
11 self.model = nn.Sequential(
12 nn.Linear(784, 200),
13 nn.Sigmoid(),
14 nn.Linear(200, 1),
15 nn.Sigmoid()
16 )
17
18 def forward(self, inputs):
19 return self.model(inputs)
网络本身很简单。它在输入层有784个节点,因为输入是由28×28=784个像素组成的。在最后一层,其输出是单个值。当该值为1表示为真,该值为0则表示为伪。隐藏的中间层有200个节点,我们采用nn.Sequential将这些网络层按顺序堆叠起来。
测试判别器
在任何机器学习架构中,对重要组件的测试都很有必要。在构建生成器之前,我们先测试判别器,确保它至少能将真实图像与随机噪声区分开。我们定义一个生成随机噪声的函数generate_random() ,它将生成指定尺寸的0-1之间的随机数张量。
1import torch
2
3
4def generate_random(size):
5 random_data = torch.rand(size)
6 return random_data
和之前一样,我们需要定义损失函数和优化器,在这里我们采用MSELoss损失函数和SGD优化器,同时也创建counter和progress用于记录和输出训练进程。以下为测试判别器的代码。对于训练集中的真实图像,奖励判别器将训练数据判别为真,也就是目标输出1.0。对于每个生成数据样本,我们使用generate_random(
生成一幅由随机像素值组成的反例图像。我们训练判别器识别这些伪造数据,目标输出为0.0。
1import torch
2from torch import nn
3
4from chapter_2_2_4_03 import train_dataset
5from chapter_2_3_2_03 import Discriminator
6from chapter_2_3_2_04 import generate_random
7
8# 判别器
9discriminator = Discriminator()
10# 损失函数
11loss_function = nn.MSELoss()
12# 优化器
13optimizer = torch.optim.SGD(discriminator.parameters(), lr=0.01)
14counter = 0
15progress = []
16
17for label, image_data_tensor, target in train_dataset:
18 # 用真实数据、标签1.0训练判别器
19 output = discriminator(image_data_tensor)
20 real_loss = loss_function(output, torch.FloatTensor([1.0]))
21
22 # 用随机数据、标签0.0训练判别器
23 output = discriminator(generate_random(784))
24 fake_loss = loss_function(output, torch.FloatTensor([0.0]))
25
26 # 反向传播,更新参数
27 loss = real_loss + fake_loss
28 optimizer.zero_grad()
29 loss.backward()
30 optimizer.step()
31
32 counter += 1
33 if counter % 500 == 0:
34 progress.append(real_loss.item() + fake_loss.item())
35 if counter % 10000 == 0:
36 print('counter = ', counter)
训练过程中的损失值变化图如下所示,损失值下降并一直保持接近0的值,符合我们的预期。
图2-57 判别器损失值变化图#
MNIST生成器
下面我们开始搭建生成器网络。我们希望它的输出能骗过判别器,生成跟MNIST数据集中图像格式相同的、包含28×28=784像素的图像,这意味着输出层需要有784个节点。
生成器的隐藏层不需要局限于一个特定的大小,不过这个大小应该满足学习的需要,同时,需要配合判别器的学习速度。基于这些考量,许多人从反转判别器的构造入手来设计生成器。
反转后的网络的输出层有784个节点,隐含层有200个节点,输入层有1个节点,如图2-58所示,生成器所输出的784个像素值正是判别器所期待的输入。
图2-58 通过反转判别器得到生成器架构#
我们知道,对于给定的输入,一个神经网络的输出是不变的。然而,我们希望神经网络每次输出不同的、代表训练数据中所有数字的图像。例如,我们希望它生成的图像看起来像1、5、4、9等。为了实现这一设想,需要我们改变生成器的输入,在每个训练循环中,将一个随机值输入生成器。我们更新架构图,加入这个随机种子(random seed)。
图2-59 GAN架构图#
以下是生成器的代码。
1from torch import nn
2
3
4class Generator(nn.Module):
5 """
6 生成器
7 """
8
9 def __init__(self):
10 super().__init__()
11 self.model = nn.Sequential(
12 nn.Linear(1, 200),
13 nn.Sigmoid(),
14 nn.Linear(200, 784),
15 nn.Sigmoid()
16 )
17
18 def forward(self, inputs):
19 return self.model(inputs)
检查生成器输出
在正式训练GAN之前,我们需要检查生成器的输出格式是否正确。我们创建一个新的生成器对象,并输入一个随机种子,得到一个输出张量。我们可以通过output.shape来确认该张量有784个值。作为一幅图像,我们可以看到它是相当无规律的。这也符合我们的预期,因为这时生成器还没有经过训练。
1import matplotlib.pyplot as plt
2
3from chapter_2_3_2_06 import Generator
4from chapter_2_3_2_04 import generate_random
5
6# 测试生成器
7gen = Generator()
8# 绘制6张结果图 —— 未经过训练的生成器
9f, ax_arr = plt.subplots(2, 3, figsize=(16, 8))
10for i in range(2):
11 for j in range(3):
12 outputs = gen(generate_random(1))
13 img = outputs.detach().numpy().reshape(28, 28)
14 ax_arr[i, j].imshow(img, interpolation='None', cmap='Blues')
15plt.show()
图2-60 生成器的输出(0 epoch)#
训练GAN
让我们先看一下训练GAN的代码。从代码中可以看出,生成器类和判别器类的定义最明显的区别在于神经网络层的定义。
1import torch
2
3from chapter_2_3_2_06 import Generator
4from chapter_2_3_2_03 import Discriminator
5from chapter_2_2_4_03 import train_dataset
6from chapter_2_3_2_04 import generate_random
7
8discriminator = Discriminator()
9gen = Generator()
10loss_function = torch.nn.BCELoss()
11optimizer_d = torch.optim.Adam(discriminator.parameters())
12optimizer_g = torch.optim.Adam(gen.parameters())
13progress_d = []
14progress_g = []
15epoch_s = 10
16
17for i in range(epoch_s):
18 counter = 0
19 for label, real_data, target in train_dataset:
20 # (1) 用真实数据、1.0训练判别器
21 output = discriminator(real_data)
22 loss_d_real = loss_function(output, torch.FloatTensor([1.0]))
23 optimizer_d.zero_grad()
24 loss_d_real.backward()
25 optimizer_d.step()
26 # (2) 用生成数据、0.0训练判别器
27 output = discriminator(gen(generate_random(1)).detach())
28 loss_d_fake = loss_function(output, torch.FloatTensor([0.0]))
29 optimizer_d.zero_grad()
30 loss_d_fake.backward()
31 optimizer_d.step()
32 # (3) 训练生成器
33 output = discriminator(gen(generate_random(1)))
34 loss_g = loss_function(output, torch.FloatTensor([0.5]))
35 optimizer_g.zero_grad()
36 loss_g.backward()
37 optimizer_g.step()
38 counter += 1
39 # 保存loss,输出训练进度
40 if counter % 500 == 0:
41 progress_d.append(loss_d_fake.item() + loss_d_real.item())
42 progress_g.append(loss_g.item())
43 if counter % 10000 == 0:
44 print('epoch = {}, counter = {}'.format(i + 1, counter))
首先,我们创建了新的判别器和生成器对象。接着,我们运行训练循环1次。每次循环都重复训练GAN的3个步骤。
第1步,我们用真实的数据训练判别器。
第2步,我们使用一组生成数据来训练判别器。对于生成器输出,detach() 的作用是将其从计算图中分离出来。通常,对判别器损失直接调用backwards() 函数会计算整个计算图路径的所有误差梯度。这个路径从判别器损失开始,经过判别器,最后返回生成器。但我们只希望训练判别器,这么做可以明显地节省大网络的计算成本,因此不需要计算生成器的梯度。生成器的detach() 可以在该点切断计算图。图2-61更直观地解释了这一点。
图2-61 使用detach()函数切断生成器梯度传播#
第3步,我们输入判别器对象和单数值0.5训练生成器。生成器的训练与判别器的训练稍有不同。对于判别器,我们知道目标输出是什么。而对于生成器,我们不知道目标输出,但我们训练生成器的目标很明确:生成能够骗过判别器的图片。这意味着生成器所生成的图片在经过判别器后的输出,需要最大限度地接近真实标签。因此,我们将根据判别器的损失值计算的误差梯度来更新生成器。这里没有使用detach() ,是因为我们希望误差梯度从判别器损失传回生成器。生成器的train() 函数只更新生成器的链接权重,因此我们不需要防止判别器被更新。
完成训练需要几分钟的时间。让我们查看一个轮次的训练后所得到的判别器和生成器的损失图。
图2-62 判别器、生成器损失图(1 epoch)#
首先观察判别器的损失图,我们分别绘制了判别器对于真实数据和生成数据的损失值。可以看到,对于两种数据输入,损失值都先下降到0,并在一段时间内保持在较低水平,表明判别器领先于生成器。接着,损失值上升到0.25左右的位置,这表明判别器和生成器旗鼓相当。不过,判别器随后再次发力,损失值下降并趋近于0,表明该生成器没能学会骗过判别器。
对比观察生成器损失图。起初,判别器能够正确识别生成的图像,这是损失值偏高的原因。接着,生成器和判别器达到一些平衡,损失值下降到0.25上方并保持一段时间。随着训练的继续,判别器再次超过生成器,生成器损失值再度升高。
图2-63 生成器输出(1 epoch)#
此时,我们查看生成器生成的图像,可以发现生成的图像不是随机噪声,而是有某种形状。生成器最终能否学会生成手写数字呢?让我们继续运行代码,再训练9个轮次。
完成总共10个轮次的训练后,我们再次查看生成器输出的多幅图像,它们与真实的手写数字图像很像!然而,不难发现,这些图像显示的内容几乎都是相同的,像是在显示着同一个数字9。
即使图中显示的数字并不完美,但生成器却已经学会了创建类似的图像,我们用相对简单的代码实现了一个重要的工作!此部分完整代码可参考附录或扫描二维码下载。
待处理
待开源后提供下载github链接
图2-64 生成器输出(10 epochs)#
改良GAN
刚刚看到的现象,在GAN训练中非常常见,称为模式崩溃(mode collapse)。在MNIST的案例中,我们希望生成器能够创建代表所有10个数字的图像。当模式崩溃发生时,生成器只能生成10个数字中的一个或部分,无法达到要求。
图2-65 GAN的模式崩溃#
发生模式崩溃的原因还在分析研究中,许多相关的研究正在进行。下面的一些方法可以用来提高判别器对生成器反馈的质量。
(1)使用二元交叉熵BCELoss()代替损失函数中的均方误差MSELoss()。在神经网络执行分类任务时,二元交叉熵更适用。相比于均方误差,它更大程度地奖励正确的分类结果,同时惩罚错误的结果。
(2)将神经网络中的信号采用LayerNorm()进行归一化,以确保它们的均值为0。同时,归一化也可以有效地限制信号的方差,避免较大值引起的网络饱和。
(3)使用Adam优化器代替SGD优化器,并同时用于判别器和生成器。
(4)在生成过程的起始点提供更多的输入种子,且都是随机值。
(5)根据判别器和生成器的特点,输入不同的随机种子。对MNIST数据集来说,目前的测试是将判别器的性能与随机判断进行对比,输入判别器的随机值需要在0~1的范围内均匀抽取,对应真实数据集中图像像素的范围;输入生成器的随机值不需要符合0~1的范围,从一个平均值为0、方差为1的正态分布中抽取种子更加合理。
我们根据上述的方法对程序进行改进,选用BCELoss替代MSELoss,Adam代替SGD优化器。然后我们对神经网络进行修改,换用LeakyReLU作为激活函数,并在激活函数前使用LayerNorm对数据进行归一化,让数据尽可能集中在激活函数的敏感区域。同时,我们将生成器输入的随机噪声改为100个,并修改生成随机噪声的函数,换用torch.randn()函数,它将生成符合均值为0的标准正态分布的随机数。经过10轮的训练,生成器的生成结果如下图所示,它成功地生成了多种数字!
图2-66 改良GAN的生成器输出(10 epochs)#
改良GAN训练的完整代码可参考附录或扫描二维码下载。
待处理
后续附链接, chapter_2_3_2_09.py
除了以上提到的方法外,还有更多改良方法有待我们继续探索,请读者大胆尝试。
备注
如果达到了平衡:生成器的BCELoss和MSELoss应该分别是什么?
理论上,一个经过完美训练的GAN的最优MSELoss为0.25,最优BCELoss为ln 2。
回顾2.1.2小节中对损失函数的介绍,MSELoss的计算公式为:
对于一个经过完美训练的GAN,它的生成器能够实现以假乱真,鉴别器也能将真实模式与生成的假数据区别开来。在接收到以假乱真的生成数据时,鉴别器无法辨别输入的数据是真实的还是生成的,因此将输出0.5。代入公式(1)即可得到结果为0.25。
BCELoss是二分类专用的交叉熵计算函数。它的计算公式跟CELoss相同,但由于二分类的真实标签只有0、1两种可能,因此对于每个标签有:
其中yi是真实标签,p(yi)是模型输出的yi的概率。对于生成器生成的以假乱真的数据,标签y = 0,而鉴别器难以作出判断,因此会输出0.5,代入即可得到BCELoss = log 2。