关键词:条件生成对抗网络、CGAN、GAN、标签生成图片、深度学习、PyTorch、手写数字、生成式 AI
前言:为什么我们需要 CGAN?
传统 GAN 只能随机生成一批“似真似假”的图片——无法控制猫、狗还是字母 A。
条件生成对抗网络(Conditional GAN, CGAN) 在 GAN 的基础上把一个“标签”塞进生成器和判别器,让输入什么,就输出什么。今天这篇文章将带你一文吃透这篇发表于 2014 年的论文,提供可直接运行的 PyTorch 代码,并总结常用于生产环境的优化技巧。
👉 想立刻体验 28×28 手写字指定数字生成吗?点这里进入即用模块!
1. GAN 速刷回顾
1.1 一句话总结生成对抗网络
- 生成器 G:把噪声
z变“假图” - 判别器 D:给这幅图打 0(假)或 1(真)
- 二者互搏,纳什均衡时 G 骗过 D,即得到以假乱真的模型。
1.2 目标函数(简写版)
[
\min_G \max_D \;
\mathbb{E}_{x \sim P_{data}}\!\left[\log D(x)\right] \;+\;
\mathbb{E}_{z \sim P_z}\!\left[\log\!\bigl(1 - D\bigl(G(z)\bigr)\bigr)\right]
]
1.3 遗留问题
GAN 只能随机 “开盲盒”,无法控制类别或风格——这就轮到 CGAN 出马了。
2. CGAN 的核心思路:把标签喂给网络
2.1 训练时的三件套
- 真实图像
x - 真实标签
y - 先验噪声
z
判别器不仅要判断图像真不真,还得看标签对不对。
因此损失函数变成 带条件的期望:
[
\min_G \max_D \;
\mathbb{E}_{x,y \sim P_{data}}\!\left[\log D(x|y)\right] \;+\;
\mathbb{E}_{z,y}\!\left[\log\!\bigl(1 - D\bigl(G(z|y)\bigr)\bigr)\right]
]
只要让 y 作为附加通道(如拼接、拼接后再全连接),网络就能利用类别信息进行训练。
2.2 标签的 3 种玩法
2.2.1 独热编码独狼
传统 MNIST 分类,每张图仅一个数字,10 维 one-hot 足以。[0,0,1,…,0] 代表数字“2”。
2.2.2 多重标签组合跑龙套
在 MIRFlickr 这类多标注数据集中,一张“三明治”照片被标记:{chicken, bread, butter, homemade}。
用 Skip-gram 把关键词转成多热或多维向量,再塞进 CGAN,让模型一次性思考多个属性,从而提前实现文本到图像的雏形。
👉 动手试试把“夏季+海边+夕阳”一次性输进去,观察 CGAN 怎样作画!
3. 代码实战:PyTorch 从零运行手写数字 CGAN
下面提供经过精简、可直接复制的 PyTorch 工程模板。
主要配置:
- 数据集:MNIST(==> 0-9 手写体)
- 噪声维度:100
- 条件维度:10(对应 one-hot)
import torch, torch.nn as nn, torch.optim as optim
from torchvision import datasets, transforms
# 超参
batch_size = 128
lr = 2e-4
epochs = 50
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 数据加载
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('mnist', train=True, download=True, transform=transform),
batch_size=batch_size, shuffle=True)
# 生成器:输入 [z(100) + y(10)] => 784(28*28 图像)
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.fc_z = nn.Sequential(nn.Linear(100, 256), nn.BatchNorm1d(256), nn.ReLU(True))
self.fc_y = nn.Sequential(nn.Linear(10, 256), nn.BatchNorm1d(256), nn.ReLU(True))
self.fc1 = nn.Sequential(nn.Linear(512, 512), nn.BatchNorm1d(512), nn.ReLU(True))
self.fc2 = nn.Sequential(nn.Linear(512, 1024), nn.BatchNorm1d(1024),nn.ReLU(True))
self.fc_out = nn.Sequential(nn.Linear(1024, 784), nn.Tanh())
def forward(self, z, y):
h_z = self.fc_z(z)
h_y = self.fc_y(y)
h = torch.cat([h_z, h_y], 1)
return self.fc_out(self.fc2(self.fc1(h))).view(-1,1,28,28)
# 判别器:输入 [x(784) + y(10)] => 概率
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.fc_x = nn.Sequential(nn.Linear(784, 1024), nn.LeakyReLU(0.2))
self.fc_y = nn.Sequential(nn.Linear(10, 1024), nn.LeakyReLU(0.2))
self.fc1 = nn.Sequential(nn.Linear(2048,512), nn.LeakyReLU(0.2))
self.fc2 = nn.Sequential(nn.Linear(512, 256), nn.LeakyReLU(0.2))
self.fc_out = nn.Sequential(nn.Linear(256,1), nn.Sigmoid())
def forward(self, x, y):
x = x.view(-1,784)
h_x = self.fc_x(x)
h_y = self.fc_y(y)
h = torch.cat([h_x, h_y], 1)
return self.fc_out(self.fc2(self.fc1(h))).squeeze()
# 模型、损失、优化器
G = Generator().to(device)
D = Discriminator().to(device)
bce = nn.BCELoss()
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5,0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5,0.999))
# 训练循环简略版
for epoch in range(epochs):
for real_imgs, labels in train_loader:
bs = real_imgs.size(0)
# 标注
y_real = torch.ones(bs).to(device)
y_fake = torch.zeros(bs).to(device)
y_onehot = torch.zeros(bs,10).scatter_(1, labels.view(-1,1), 1).to(device)
real_imgs = real_imgs.to(device)
# 训练 D
opt_D.zero_grad()
real_loss = bce(D(real_imgs, y_onehot), y_real)
z = torch.randn(bs, 100).to(device)
fake_imgs = G(z, y_onehot)
fake_loss = bce(D(fake_imgs.detach(), y_onehot), y_fake)
(real_loss + fake_loss).backward()
opt_D.step()
# 训练 G
opt_G.zero_grad()
loss_G = bce(D(fake_imgs, y_onehot), y_real)
loss_G.backward()
opt_G.step()3.1 训练效果快照
- 第 5 epoch:模糊数字轮廓
- 第 20 epoch:边缘锋利,形状变清晰
- 第 50 epoch:人类已难区分是否来自 MNIST 训练集
4. 常见疑难 FAQ
Q1:给标签也加入噪声可不可以?
可以,适当做随机翻转或 MixUp 可避免过拟合并提升泛化。
Q2:生成图像仍然颜色失真?
换用 BatchNorm + MSELoss 可选,有时比交叉熵收敛更稳。
Q3:如何支持 256×256 大图的多个类别?
- 把全连接改成 转置卷积 (ConvTranspose2d)。
- Projection 标签:用 Embedding 把类别映射为 32 维向量,再在空间维度复制到与特征图尺寸一致。
Q4:训练费时怎么“SSR 骚折加速”?
- 小批量高梯度累加:mini-batch=64,梯度累积 4 步,效果近似 batch=256。
- 利用 混合精度(apex/amp),在保持数值精度的情况下节省显存。
Q5:工业生产时常内存爆?
采用 累积判别、累积生成 的双缓存写法,保证两者不会同时跑满显存。
Q6:如何评价“指定类别”质量?
常用 Inception Score、FID + 类别准确率 的加权指标,FID 越低、类别准确率越高越好。
5. 进阶玩法 & 展望
- StyleGAN3 + 条件注入:将 CGAN 思想嫁接到 StyleGAN3,可用
class conditioning得到极高清结果。 - 扩散模型时代的“条件版”同样沿用 CGAN 思想:把 UNet 内参外插类别 Embedding,控制性更上一层楼。
- 商业落地:游戏皮肤预览、电商虚拟服饰试穿、新闻智能配图均可作为落地场景。
结语
CGAN 用极简思路一举打破 GAN “随机生成” 魔咒,奠定了现代文生图的理论雏形。
记得多调超参、多攒高质量标签,方能复现出媲美论文的高保真结果。祝你玩得开心!
torch.save(G.state_dict(), 'generator_cgan_final.pth')