1. GAN基础概念与核心思想
生成对抗网络(GAN)由Ian Goodfellow在2014年提出,其核心思想是通过两个神经网络相互对抗来学习数据分布。想象一下艺术品鉴定师和赝品制造者的博弈过程:鉴定师不断学习识别真伪的技巧,而赝品制造者则持续改进伪造技术。这种动态平衡最终会使赝品达到以假乱真的程度。
GAN由两个关键组件构成:
- 生成器(Generator):接收随机噪声作为输入,输出合成数据。它的目标是生成足以欺骗判别器的样本。
- 判别器(Discriminator):接收真实数据和生成数据,输出其为真实样本的概率。它的目标是准确区分输入来源。
在实际训练中,这两个网络会进行交替优化。我用PyTorch实现时发现,这种对抗过程需要精细平衡两者的训练强度。如果判别器太强,生成器梯度会消失;反之则会导致模式崩溃(即生成器只产出有限的几种样本)。
2. 极小极大博弈的数学本质
2.1 目标函数推导
GAN的原始论文给出了经典的极小极大目标函数:
min_G max_D V(D,G) = E_{x∼p_data(x)}[log D(x)] + E_{z∼p_z(z)}[log(1-D(G(z)))]
这个公式包含两个关键期望:
- 第一项鼓励判别器对真实样本给出高概率
- 第二项使判别器对生成样本给出低概率
我在复现论文时注意到,生成器实际上只影响第二项。通过展开推导可以发现,对于固定生成器G,最优判别器D*满足:
D*(x) = p_data(x) / [p_data(x) + p_g(x)]
这个结论非常直观:判别器的最优表现取决于真实数据分布与生成分布的比值。
2.2 JS散度的出现
将最优判别器D*代回目标函数后,可以得到:
C(G) = -log(4) + 2*JSD(p_data||p_g)
其中JSD表示Jensen-Shannon散度。这说明GAN的训练本质上是在最小化真实分布与生成分布之间的JS散度。不过实际应用中,我们更常用的是改进的损失函数设计。
3. 生成器损失函数详解
3.1 原始形式的问题
原始GAN的生成器损失为:
L_G = E_{z∼p_z(z)}[log(1-D(G(z)))]
但在早期训练阶段,由于生成样本质量差,D(G(z))接近0,导致梯度非常小。我在MNIST实验中就遇到过生成器训练停滞的情况。
3.2 改进的损失函数
更实用的替代方案是最大化log(D(G(z))):
L_G = -E_{z∼p_z(z)}[log(D(G(z)))]
这相当于最小化生成分布与真实分布之间的KL散度。PyTorch实现时可以用BCELoss:
adversarial_loss = nn.BCELoss() g_loss = adversarial_loss(discriminator(fake_imgs), valid_labels)其中valid_labels是全1张量,因为生成器希望判别器将生成样本判定为真实。
4. 判别器损失函数解析
4.1 二分类交叉熵损失
判别器的目标函数包含两部分:
L_D = -E_{x∼p_data}[log D(x)] - E_{z∼p_z}[log(1-D(G(z)))]
对应PyTorch实现:
real_loss = adversarial_loss(discriminator(real_imgs), valid_labels) fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels) d_loss = (real_loss + fake_loss) / 2这里需要注意:
- fake_imgs要用detach()切断梯度反传
- 使用mean而非sum防止batch size影响损失尺度
4.2 梯度惩罚技巧
原始GAN容易导致梯度消失或爆炸。我的解决方案是加入梯度惩罚:
alpha = torch.rand(real_imgs.size(0), 1, 1, 1) interpolates = (alpha * real_imgs + (1-alpha) * fake_imgs).requires_grad_(True) d_interpolates = discriminator(interpolates) gradients = torch.autograd.grad( outputs=d_interpolates, inputs=interpolates, grad_outputs=torch.ones_like(d_interpolates), create_graph=True )[0] gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() d_loss += lambda_gp * gradient_penalty5. PyTorch实战MNIST生成
5.1 网络架构设计
对于MNIST生成任务,我推荐以下配置:
class Generator(nn.Module): def __init__(self, latent_dim): super().__init__() self.model = nn.Sequential( nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2), nn.Linear(256, 512), nn.LeakyReLU(0.2), nn.Linear(512, 1024), nn.LeakyReLU(0.2), nn.Linear(1024, 784), nn.Tanh() ) def forward(self, z): return self.model(z).view(-1, 1, 28, 28)5.2 训练技巧分享
经过多次实验,我总结了几个关键点:
- 学习率设为2e-4通常效果不错
- 使用Adam优化器,β1=0.5
- 每训练判别器k次(k=1~5),训练生成器1次
- 定期保存生成样本可视化训练进度
完整的训练循环如下:
for epoch in range(epochs): for i, (real_imgs, _) in enumerate(dataloader): # 训练判别器 optimizer_D.zero_grad() z = torch.randn(batch_size, latent_dim) fake_imgs = generator(z) d_loss = compute_d_loss(real_imgs, fake_imgs) d_loss.backward() optimizer_D.step() # 训练生成器 if i % k == 0: optimizer_G.zero_grad() z = torch.randn(batch_size, latent_dim) fake_imgs = generator(z) g_loss = compute_g_loss(fake_imgs) g_loss.backward() optimizer_G.step()6. 常见问题与解决方案
在实现GAN的过程中,我遇到过几个典型问题:
模式崩溃:生成器只产生几种固定模式。解决方法包括:
- 使用minibatch discrimination
- 尝试不同的损失函数如Wasserstein loss
- 增加噪声多样性
训练不稳定:可以尝试:
- 梯度裁剪
- 学习率衰减
- 两时间尺度更新规则(TTUR)
评估困难:建议结合:
- 人工检查生成样本
- 计算IS(Inception Score)或FID
- 监控损失函数曲线
记得保存多个时间点的模型快照,这样当训练出现问题时可以回退到之前的稳定状态。我在实际项目中会每50个epoch保存一次模型和生成样本。