2.3.4. 采用条件式GAN生成手写数字#
我们之前构建的MNIST GAN可以生成各种不同的手写数字输出图像,同时,也很好地避免了单一化和模式崩溃。
如果能通过某种方式引导GAN生成多样化的图像,同时又仅限于生成训练数据中的一类图像,例如,我们可以要求GAN生成不同的、但都代表数字8的图像,满足我们的特定需求,从而实现真正意义上的人机交互,那将是非常有价值的。
条件式GAN架构
为了让训练后的GAN生成器输出指定类型的图像,需要输入我们希望的输出类型。也就是说,我们需要将类型作为生成器输入的一部分,如同随机种子一样。
对于判别器,情况会更加复杂。我们现在希望判别器学习将类型标签与图像关联起来,而不仅仅是尝试将真实的图像和生成的图像分开。因此,我们需要将类型标签与图像一起输入判别器。
下图显示的架构是条件式(conditional)GAN。与GAN的主要区别在于,现在生成器和判别器的输入都在图像数据的基础上加入了类型标签。
图2-68 条件GAN架构图#
判别器
我们在之前的MNIST GAN基础上,实现这个架构。
首先,我们需要更新判别器,使它可以同时接收输入图像的像素数据和标签信息。一种简单的方法是扩展forward() 函数,使它可以同时接收图像张量和标签张量为输入变量,再直接将它们拼接起来。标签张量就是我们之前在Dataset类中创建的独热张量。
1import torch
2from chapter_2_3_2_09 import Discriminator as Gan
3
4
5class Discriminator(Gan):
6 """
7 此时的模型,只是修改了forward,网络主干部分未修改,需注意
8 """
9
10 def forward(self, image_tensor, label_tensor):
11 inputs = torch.cat((image_tensor, label_tensor))
12 return self.model(inputs)
通过torch.cat() 函数可以方便地将两个张量拼接起来。从Dataset类中返回的图像张量长度为784,标签张量的长度为10,所以拼接起来后的长度为794。
由于我们扩展了输入的大小,因此需要更改第一层神经网络的定义,将预期输入的大小改为784+10,即794。
1from torch import nn
2from chapter_2_3_4_01 import Discriminator as Gan
3
4
5class Discriminator(Gan):
6 """
7 此时的模型,只是修改了forward,网络主干部分未修改,需注意
8 """
9
10 def __init__(self):
11 nn.Module.__init__(self)
12 self.model = nn.Sequential(
13 # 考虑了标签张量的影响
14 nn.Linear(784 + 10, 200),
15 nn.LayerNorm(200),
16 nn.LeakyReLU(0.02),
17 nn.Linear(200, 1),
18 nn.Sigmoid())
19
我们还需要为随机生成的图像搭配一个随机类别标签,我们创建了一个函数generate_random_one_hot() ,来生成一个随机的独热标签向量。
1import torch
2import numpy as np
3
4
5def generate_random_one_hot(size):
6 label_tensor = torch.zeros(size)
7 random_idx = np.random.randint(0, size)
8 # 随机令一位为1
9 label_tensor[random_idx] = 1
10 return label_tensor
生成器
对于生成器,需要修改forward() 函数,把种子和标签张量输入生成器。因此。我们需要把输入参数拼接起来,再输入神经网络,仍需用到torch.cat() 函数。
1import torch
2
3from chapter_2_3_2_09 import Generator as Gen
4
5
6class Generator(Gen):
7 def forward(self, seed_tensor, label_tensor):
8 inputs = torch.cat((seed_tensor, label_tensor))
9 return self.model(inputs)
网络的第一层需要修改,以便接收10个额外标签张量,变为100+10。
1from torch import nn
2
3from chapter_2_3_4_04 import Generator as Gen
4
5
6class Generator(Gen):
7
8 def __init__(self):
9 nn.Module.__init__(self)
10 self.model = nn.Sequential(
11 nn.Linear(100 + 10, 200),
12 nn.LayerNorm(200),
13 nn.LeakyReLU(0.02),
14 nn.Linear(200, 784),
15 nn.Sigmoid()
16 )
训练
训练循环同样需要修改,随机生成一个类别标签张量,在相应的位置输入给判别器和生成器。以下代码只显示了周期循环内的内容。我们在这里总共对条件式GAN训练10轮。
1import torch
2
3from chapter_2_3_4_05 import Generator
4from chapter_2_3_4_02 import Discriminator
5from chapter_2_3_2_02 import mnist_dataset
6from chapter_2_3_2_09 import generate_random
7from chapter_2_3_4_03 import generate_random_one_hot
8
9DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
10
11discriminator_net = Discriminator().to(DEVICE)
12generator_net = Generator().to(DEVICE)
13loss_function = torch.nn.BCELoss()
14
15optimizer_d = torch.optim.Adam(discriminator_net.parameters())
16optimizer_g = torch.optim.Adam(generator_net.parameters())
17progress_d_real = []
18progress_d_fake = []
19progress_g = []
20counter = 0
21# 真假标签
22real_label = torch.FloatTensor([1.0]).to(DEVICE)
23fake_label = torch.FloatTensor([0.0]).to(DEVICE)
24
25for i in range(10):
26 for label, real_data, target in mnist_dataset:
27 discriminator_net.zero_grad()
28 # 真实数据训练判别器
29 output = discriminator_net(real_data.to(DEVICE), target.to(DEVICE))
30 loss_d_real = loss_function(output, real_label)
31
32 # 生成数据训练判别器
33 random_label = generate_random_one_hot(10).to(DEVICE)
34 gen_img = generator_net(generate_random(100).to(DEVICE),
35 random_label)
36 output = discriminator_net(gen_img.detach(), random_label)
37 loss_d_fake = loss_function(output, fake_label)
38 loss_d = loss_d_real + loss_d_fake
39 optimizer_d.zero_grad()
40 loss_d.backward()
41 optimizer_d.step()
42
43 # 训练生成器,使生成器生成的图像更真实
44 generator_net.zero_grad()
45 gen_img = generator_net(generate_random(100).to(DEVICE),
46 random_label)
47 output = discriminator_net(gen_img,
48 random_label)
49 loss_g = loss_function(output, real_label)
50 optimizer_g.zero_grad()
51 loss_g.backward()
52 optimizer_g.step()
53
54 counter += 1
55 if counter % 500 == 0:
56 progress_d_real.append(loss_d_real.item())
57 progress_d_fake.append(loss_d_fake.item())
58 progress_g.append(loss_g.item())
59 if counter % 10000 == 0:
60 print(f'epoch = {i + 1}, counter = {counter}')
条件式GAN的结果
我们定义plot_conditional_image函数,它实现了生成并绘制指定标签的图像。
1import matplotlib.pyplot as plt
2
3from chapter_2_3_4_06 import *
4
5
6def plot_conditional_images(label):
7 label_tensor = torch.zeros(10)
8 label_tensor[label] = 1.0
9 f, ax_arr = plt.subplots(2, 3, figsize=(16, 8))
10 for i in range(2):
11 for j in range(3):
12 output = generator_net(generate_random(100).to(DEVICE),
13 label_tensor.to(DEVICE))
14 img = output.detach().cpu().numpy().reshape(28, 28)
15 ax_arr[i, j].imshow(img, interpolation='None', cmap='Blues')
16 plt.show()
17
18
19plot_conditional_images(9)
我们将9作为参数传给plot_conditional_image,查看条件式GAN生成的数字9的图像。
图2-69 条件GAN生成器输出#
通过仅仅10轮的训练,我们的条件式GAN不仅生成了几幅数字9的图像,而且这些图像都不一样。条件式GAN的完整代码可参考附录或扫描二维码下载。
待处理
待完善
生成指定类型的多样化图像具有很多应用场景,比如生成具有特定情绪表情的人像、具有指定风格的建筑等。而实现这一功能的关键在于,训练数据需要根据我们希望生成的类别进行标记。