2.3.4. 采用条件式GAN生成手写数字#

我们之前构建的MNIST GAN可以生成各种不同的手写数字输出图像,同时,也很好地避免了单一化和模式崩溃。

如果能通过某种方式引导GAN生成多样化的图像,同时又仅限于生成训练数据中的一类图像,例如,我们可以要求GAN生成不同的、但都代表数字8的图像,满足我们的特定需求,从而实现真正意义上的人机交互,那将是非常有价值的。

  1. 条件式GAN架构

为了让训练后的GAN生成器输出指定类型的图像,需要输入我们希望的输出类型。也就是说,我们需要将类型作为生成器输入的一部分,如同随机种子一样。

对于判别器,情况会更加复杂。我们现在希望判别器学习将类型标签与图像关联起来,而不仅仅是尝试将真实的图像和生成的图像分开。因此,我们需要将类型标签与图像一起输入判别器。

下图显示的架构是条件式(conditional)GAN。与GAN的主要区别在于,现在生成器和判别器的输入都在图像数据的基础上加入了类型标签。

图2-68 条件GAN架构图

图2-68 条件GAN架构图#

  1. 判别器

我们在之前的MNIST GAN基础上,实现这个架构。

首先,我们需要更新判别器,使它可以同时接收输入图像的像素数据和标签信息。一种简单的方法是扩展forward() 函数,使它可以同时接收图像张量和标签张量为输入变量,再直接将它们拼接起来。标签张量就是我们之前在Dataset类中创建的独热张量。

chapter_2_3_4_01.py#
 1import torch
 2from chapter_2_3_2_09 import Discriminator as Gan
 3
 4
 5class Discriminator(Gan):
 6    """
 7    此时的模型,只是修改了forward,网络主干部分未修改,需注意
 8    """
 9
10    def forward(self, image_tensor, label_tensor):
11        inputs = torch.cat((image_tensor, label_tensor))
12        return self.model(inputs)

通过torch.cat() 函数可以方便地将两个张量拼接起来。从Dataset类中返回的图像张量长度为784,标签张量的长度为10,所以拼接起来后的长度为794。

由于我们扩展了输入的大小,因此需要更改第一层神经网络的定义,将预期输入的大小改为784+10,即794。

chapter_2_3_4_02.py#
 1from torch import nn
 2from chapter_2_3_4_01 import Discriminator as Gan
 3
 4
 5class Discriminator(Gan):
 6    """
 7    此时的模型,只是修改了forward,网络主干部分未修改,需注意
 8    """
 9
10    def __init__(self):
11        nn.Module.__init__(self)
12        self.model = nn.Sequential(
13            # 考虑了标签张量的影响
14            nn.Linear(784 + 10, 200),
15            nn.LayerNorm(200),
16            nn.LeakyReLU(0.02),
17            nn.Linear(200, 1),
18            nn.Sigmoid())
19

我们还需要为随机生成的图像搭配一个随机类别标签,我们创建了一个函数generate_random_one_hot() ,来生成一个随机的独热标签向量。

chapter_2_3_4_03.py#
 1import torch
 2import numpy as np
 3
 4
 5def generate_random_one_hot(size):
 6    label_tensor = torch.zeros(size)
 7    random_idx = np.random.randint(0, size)
 8    # 随机令一位为1
 9    label_tensor[random_idx] = 1
10    return label_tensor
  1. 生成器

对于生成器,需要修改forward() 函数,把种子和标签张量输入生成器。因此。我们需要把输入参数拼接起来,再输入神经网络,仍需用到torch.cat() 函数。

chapter_2_3_4_04.py#
1import torch
2
3from chapter_2_3_2_09 import Generator as Gen
4
5
6class Generator(Gen):
7    def forward(self, seed_tensor, label_tensor):
8        inputs = torch.cat((seed_tensor, label_tensor))
9        return self.model(inputs)

网络的第一层需要修改,以便接收10个额外标签张量,变为100+10。

chapter_2_3_4_05.py#
 1from torch import nn
 2
 3from chapter_2_3_4_04 import Generator as Gen
 4
 5
 6class Generator(Gen):
 7
 8    def __init__(self):
 9        nn.Module.__init__(self)
10        self.model = nn.Sequential(
11            nn.Linear(100 + 10, 200),
12            nn.LayerNorm(200),
13            nn.LeakyReLU(0.02),
14            nn.Linear(200, 784),
15            nn.Sigmoid()
16        )
  1. 训练

训练循环同样需要修改,随机生成一个类别标签张量,在相应的位置输入给判别器和生成器。以下代码只显示了周期循环内的内容。我们在这里总共对条件式GAN训练10轮。

chapter_2_3_4_06.py#
 1import torch
 2
 3from chapter_2_3_4_05 import Generator
 4from chapter_2_3_4_02 import Discriminator
 5from chapter_2_3_2_02 import mnist_dataset
 6from chapter_2_3_2_09 import generate_random
 7from chapter_2_3_4_03 import generate_random_one_hot
 8
 9DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
10
11discriminator_net = Discriminator().to(DEVICE)
12generator_net = Generator().to(DEVICE)
13loss_function = torch.nn.BCELoss()
14
15optimizer_d = torch.optim.Adam(discriminator_net.parameters())
16optimizer_g = torch.optim.Adam(generator_net.parameters())
17progress_d_real = []
18progress_d_fake = []
19progress_g = []
20counter = 0
21# 真假标签
22real_label = torch.FloatTensor([1.0]).to(DEVICE)
23fake_label = torch.FloatTensor([0.0]).to(DEVICE)
24
25for i in range(10):
26    for label, real_data, target in mnist_dataset:
27        discriminator_net.zero_grad()
28        # 真实数据训练判别器
29        output = discriminator_net(real_data.to(DEVICE), target.to(DEVICE))
30        loss_d_real = loss_function(output, real_label)
31
32        # 生成数据训练判别器
33        random_label = generate_random_one_hot(10).to(DEVICE)
34        gen_img = generator_net(generate_random(100).to(DEVICE),
35                                random_label)
36        output = discriminator_net(gen_img.detach(), random_label)
37        loss_d_fake = loss_function(output, fake_label)
38        loss_d = loss_d_real + loss_d_fake
39        optimizer_d.zero_grad()
40        loss_d.backward()
41        optimizer_d.step()
42
43        # 训练生成器,使生成器生成的图像更真实
44        generator_net.zero_grad()
45        gen_img = generator_net(generate_random(100).to(DEVICE),
46                                random_label)
47        output = discriminator_net(gen_img,
48                                   random_label)
49        loss_g = loss_function(output, real_label)
50        optimizer_g.zero_grad()
51        loss_g.backward()
52        optimizer_g.step()
53
54        counter += 1
55        if counter % 500 == 0:
56            progress_d_real.append(loss_d_real.item())
57            progress_d_fake.append(loss_d_fake.item())
58            progress_g.append(loss_g.item())
59        if counter % 10000 == 0:
60            print(f'epoch = {i + 1}, counter = {counter}')
  1. 条件式GAN的结果

我们定义plot_conditional_image函数,它实现了生成并绘制指定标签的图像。

chapter_2_3_4_07.py#
 1import matplotlib.pyplot as plt
 2
 3from chapter_2_3_4_06 import *
 4
 5
 6def plot_conditional_images(label):
 7    label_tensor = torch.zeros(10)
 8    label_tensor[label] = 1.0
 9    f, ax_arr = plt.subplots(2, 3, figsize=(16, 8))
10    for i in range(2):
11        for j in range(3):
12            output = generator_net(generate_random(100).to(DEVICE),
13                                   label_tensor.to(DEVICE))
14            img = output.detach().cpu().numpy().reshape(28, 28)
15            ax_arr[i, j].imshow(img, interpolation='None', cmap='Blues')
16    plt.show()
17
18
19plot_conditional_images(9)

我们将9作为参数传给plot_conditional_image,查看条件式GAN生成的数字9的图像。

图2-69 条件GAN生成器输出

图2-69 条件GAN生成器输出#

通过仅仅10轮的训练,我们的条件式GAN不仅生成了几幅数字9的图像,而且这些图像都不一样。条件式GAN的完整代码可参考附录或扫描二维码下载。

待处理

待完善

生成指定类型的多样化图像具有很多应用场景,比如生成具有特定情绪表情的人像、具有指定风格的建筑等。而实现这一功能的关键在于,训练数据需要根据我们希望生成的类别进行标记。