伪造指定图像:CGAN 原理、实战代码与生成式 AI 优化指南

·

关键词:条件生成对抗网络、CGAN、GAN、标签生成图片、深度学习、PyTorch、手写数字、生成式 AI

前言:为什么我们需要 CGAN?

传统 GAN 只能随机生成一批“似真似假”的图片——无法控制猫、狗还是字母 A。
条件生成对抗网络(Conditional GAN, CGAN) 在 GAN 的基础上把一个“标签”塞进生成器和判别器,让输入什么,就输出什么。今天这篇文章将带你一文吃透这篇发表于 2014 年的论文,提供可直接运行的 PyTorch 代码,并总结常用于生产环境的优化技巧。
👉 想立刻体验 28×28 手写字指定数字生成吗?点这里进入即用模块!


1. GAN 速刷回顾

1.1 一句话总结生成对抗网络

1.2 目标函数(简写版)

[
\min_G \max_D \;
\mathbb{E}_{x \sim P_{data}}\!\left[\log D(x)\right] \;+\;
\mathbb{E}_{z \sim P_z}\!\left[\log\!\bigl(1 - D\bigl(G(z)\bigr)\bigr)\right]
]

1.3 遗留问题

GAN 只能随机 “开盲盒”,无法控制类别或风格——这就轮到 CGAN 出马了。


2. CGAN 的核心思路:把标签喂给网络

2.1 训练时的三件套

  1. 真实图像 x
  2. 真实标签 y
  3. 先验噪声 z

判别器不仅要判断图像真不真,还得看标签对不对
因此损失函数变成 带条件的期望

[
\min_G \max_D \;
\mathbb{E}_{x,y \sim P_{data}}\!\left[\log D(x|y)\right] \;+\;
\mathbb{E}_{z,y}\!\left[\log\!\bigl(1 - D\bigl(G(z|y)\bigr)\bigr)\right]
]

只要让 y 作为附加通道(如拼接、拼接后再全连接),网络就能利用类别信息进行训练。

2.2 标签的 3 种玩法

2.2.1 独热编码独狼

传统 MNIST 分类,每张图仅一个数字,10 维 one-hot 足以。[0,0,1,…,0] 代表数字“2”。

2.2.2 多重标签组合跑龙套

在 MIRFlickr 这类多标注数据集中,一张“三明治”照片被标记:{chicken, bread, butter, homemade}
用 Skip-gram 把关键词转成多热或多维向量,再塞进 CGAN,让模型一次性思考多个属性,从而提前实现文本到图像的雏形。

👉 动手试试把“夏季+海边+夕阳”一次性输进去,观察 CGAN 怎样作画!


3. 代码实战:PyTorch 从零运行手写数字 CGAN

下面提供经过精简、可直接复制的 PyTorch 工程模板。
主要配置:

import torch, torch.nn as nn, torch.optim as optim
from torchvision import datasets, transforms

# 超参
batch_size = 128
lr         = 2e-4
epochs     = 50
device     = 'cuda' if torch.cuda.is_available() else 'cpu'

# 数据加载
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('mnist', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

# 生成器:输入 [z(100) + y(10)] => 784(28*28 图像)
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_z = nn.Sequential(nn.Linear(100, 256), nn.BatchNorm1d(256), nn.ReLU(True))
        self.fc_y = nn.Sequential(nn.Linear(10, 256),   nn.BatchNorm1d(256), nn.ReLU(True))
        self.fc1  = nn.Sequential(nn.Linear(512, 512),  nn.BatchNorm1d(512), nn.ReLU(True))
        self.fc2  = nn.Sequential(nn.Linear(512, 1024), nn.BatchNorm1d(1024),nn.ReLU(True))
        self.fc_out = nn.Sequential(nn.Linear(1024, 784), nn.Tanh())

    def forward(self, z, y):
        h_z = self.fc_z(z)
        h_y = self.fc_y(y)
        h = torch.cat([h_z, h_y], 1)
        return self.fc_out(self.fc2(self.fc1(h))).view(-1,1,28,28)

# 判别器:输入 [x(784) + y(10)] => 概率
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_x = nn.Sequential(nn.Linear(784, 1024), nn.LeakyReLU(0.2))
        self.fc_y = nn.Sequential(nn.Linear(10,  1024), nn.LeakyReLU(0.2))
        self.fc1  = nn.Sequential(nn.Linear(2048,512),  nn.LeakyReLU(0.2))
        self.fc2  = nn.Sequential(nn.Linear(512, 256),   nn.LeakyReLU(0.2))
        self.fc_out = nn.Sequential(nn.Linear(256,1), nn.Sigmoid())

    def forward(self, x, y):
        x = x.view(-1,784)
        h_x = self.fc_x(x)
        h_y = self.fc_y(y)
        h = torch.cat([h_x, h_y], 1)
        return self.fc_out(self.fc2(self.fc1(h))).squeeze()

# 模型、损失、优化器
G = Generator().to(device)
D = Discriminator().to(device)
bce = nn.BCELoss()
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5,0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5,0.999))

# 训练循环简略版
for epoch in range(epochs):
    for real_imgs, labels in train_loader:
        bs = real_imgs.size(0)
        # 标注
        y_real = torch.ones(bs).to(device)
        y_fake = torch.zeros(bs).to(device)
        y_onehot = torch.zeros(bs,10).scatter_(1, labels.view(-1,1), 1).to(device)
        real_imgs = real_imgs.to(device)
        # 训练 D
        opt_D.zero_grad()
        real_loss = bce(D(real_imgs, y_onehot), y_real)
        z = torch.randn(bs, 100).to(device)
        fake_imgs = G(z, y_onehot)
        fake_loss = bce(D(fake_imgs.detach(), y_onehot), y_fake)
        (real_loss + fake_loss).backward()
        opt_D.step()
        # 训练 G
        opt_G.zero_grad()
        loss_G = bce(D(fake_imgs, y_onehot), y_real)
        loss_G.backward()
        opt_G.step()

3.1 训练效果快照


4. 常见疑难 FAQ

Q1:给标签也加入噪声可不可以?

可以,适当做随机翻转或 MixUp 可避免过拟合并提升泛化。

Q2:生成图像仍然颜色失真?

换用 BatchNorm + MSELoss 可选,有时比交叉熵收敛更稳。

Q3:如何支持 256×256 大图的多个类别?

  1. 把全连接改成 转置卷积 (ConvTranspose2d)
  2. Projection 标签:用 Embedding 把类别映射为 32 维向量,再在空间维度复制到与特征图尺寸一致。

Q4:训练费时怎么“SSR 骚折加速”?

Q5:工业生产时常内存爆?

采用 累积判别、累积生成 的双缓存写法,保证两者不会同时跑满显存。

Q6:如何评价“指定类别”质量?

常用 Inception Score、FID + 类别准确率 的加权指标,FID 越低、类别准确率越高越好。


5. 进阶玩法 & 展望

  1. StyleGAN3 + 条件注入:将 CGAN 思想嫁接到 StyleGAN3,可用 class conditioning 得到极高清结果。
  2. 扩散模型时代的“条件版”同样沿用 CGAN 思想:把 UNet 内参外插类别 Embedding,控制性更上一层楼。
  3. 商业落地:游戏皮肤预览、电商虚拟服饰试穿、新闻智能配图均可作为落地场景。

结语

CGAN 用极简思路一举打破 GAN “随机生成” 魔咒,奠定了现代文生图的理论雏形。
记得多调超参、多攒高质量标签,方能复现出媲美论文的高保真结果。祝你玩得开心!

torch.save(G.state_dict(), 'generator_cgan_final.pth')