本帖最后由 宇宙微尘 于 2025-8-25 21:34 编辑
本文用通俗的语言描述了生成对抗网络(GAN)的核心原理与应用。通过详细介绍GAN的生成器与判别器的对抗训练过程,阐述了它们如何相互博弈,不断优化生成样本的质量。通过在PyTorch中实现一个简单的GAN模型,展示了其在生成手写数字图像方面的应用。文章还进一步探讨了GAN在多个领域的广泛应用,包括图像生成、数据增强、医学成像等,展示了其巨大的潜力与前景。
在金庸武侠小说《射雕英雄传》中,周伯通因为和东邪黄药师打赌而输,困于桃花岛十余年。
周伯通虽然心性顽劣,但在师兄的教导之下却极重信义,多年来一直没有越雷池半步。周伯通天性爱玩,在漫漫长夜中为了打发无聊时光,遂萌生“自己左手与右手”打架的想法,继而创造出金庸武学体系中的绝顶功夫-左右手互搏术。
公元2014年,在深度学习领域,Ian Goodfellow等人提出了生成对抗网络(GAN)模型。模型通过框架中两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生高质量的输出。生成对抗网络被认为是当前最具前景、最具活跃度的模型之一,目前主要应用于样本数据生成、图像生成、图像修复、图像转换、文本生成等方向。标志着生成式人工智能的关键突破。GAN通过生成器和判别器的对抗训练,能够生成高度真实的图像和数据,是现代生成式AI的重要模型之一 。

生成对抗网络(GAN),它的核心思想是利用生成器和判别器进行博弈,逐步优化生成效果,使生成数据的分布逐渐逼近真实数据的分布。生成器负责从随机噪声中生成样本,尽量模仿真实数据的特征,以骗过判别器。而判别器则承担辨别真伪的角色,它接收生成器的样本和真实样本,不断提升判断的准确性。
在这一过程中,生成器与判别器处于不断对抗的状态。生成器每次生成样本后,判别器都通过判断结果来反馈生成器的生成质量。判别器不断更新参数以增强自身识别能力,迫使生成器在不断优化中生成更加逼真的样本。通过这样的对抗训练,生成器逐渐学会模仿真实数据的特征分布,最终生成的样本在判别器面前接近于真实数据,使得判别器难以分辨。
这种对抗性的架构使得GAN具备了强大的生成能力, GAN的出现推动了图像和视频生成的显著进展。
GAN基本原理
1. 核心构成
GAN由两个重要的部分构成:生成器(Generator,简写作G)和判别器(Discriminator,简写作D)。
生成器:通过机器生成数据,目的是尽可能“骗过”判别器,生成的数据记做G(z);
判别器:判断数据是真实数据还是「生成器」生成的数据,目的是尽可能找出「生成器」造的“假数据”。它的输入参数是x,x代表数据,输出D(x)代表x为真实数据的概率,如果为1,就代表100%是真实的数据,而输出为0,就代表不可能是真实的数据。
这样,G和D构成了一个动态对抗(或博弈过程),随着训练(对抗)的进行,G生成的数据越来越接近真实数据,D鉴别数据的水平越来越高。在理想的状态下,G可以生成足以“以假乱真”的数据;而对于D来说,它难以判定生成器生成的数据究竟是不是真实的,因此D(G(z)) = 0.5。训练完成后,我们得到了一个生成模型G,它可以用来生成以假乱真的数据。

2. 训练过程
第一阶段:固定「判别器D」,训练「生成器G」。使用一个性能不错的判别器,G不断生成“假数据”,然后给这个D去判断。开始时候,G还很弱,所以很容易被判别出来。但随着训练不断进行,G技能不断提升,最终骗过了D。这个时候,D基本属于“瞎猜”的状态,判断是否为假数据的概率为50%。
第二阶段:固定「生成器G」,训练「判别器D」。当通过了第一阶段,继续训练G就没有意义了。这时候我们固定G,然后开始训练D。通过不断训练,D提高了自己的鉴别能力,最终他可以准确判断出假数据。
重复第一阶段、第二阶段。通过不断的循环,「生成器G」和「判别器D」的能力都越来越强。最终我们得到了一个效果非常好的「生成器G」,就可以用它来生成数据。

基于PyTorch实现一个简单的GAN生成对抗网络模型
为了更好地理解生成对抗网络(GAN)的工作原理,我们将通过一个简单的例子进行实践。接下来,我们将展示如何使用PyTorch实现一个基本的GAN模型,目标是生成手写数字图像。这个过程将帮助我们进一步理解GAN中的生成器和判别器如何在对抗训练中不断优化,直到生成器能够生成足够真实的数据,使得判别器无法区分它们与真实数据之间的差异。通过实现一个简单的GAN,我们可以更直观地看到GAN模型的训练过程,并为后续深入研究更复杂的GAN模型奠定基础。
1. MNIST手写数字数据集简介
MNIST数据集总共包含两个子数据集:一个训练数据集和一个测试数据集。它们分别包含了60K和10K的28×28的灰度图像。

附下载链接
链接:https://pan.baidu.com/s/1TaL3dCHxAj17LgvSSd_eTA?pwd=xl8n
提取码:xl8n
以下是使用PyTorch实现一个简单的GAN模型,用于生成手写数字的示例。
2. 训练过程
训练过程可以分为两个主要阶段。首先,在第一阶段,我们固定判别器D,专注于训练生成器G。在此阶段,生成器G从随机噪声中生成假数据,并交给判别器D进行判断。最开始,生成器的能力较弱,生成的数据很容易被判别器识别为假数据。然而,随着训练的进行,生成器逐步提高生成数据的质量,最终使判别器难以辨别其生成的数据。此时,判别器基本处于猜测状态,无法准确判断数据的真伪。
接下来进入第二阶段,固定生成器G,开始训练判别器D。在这一阶段,判别器D通过学习如何识别生成器所生成的假数据来提升自己的鉴别能力,最终能够较为准确地区分真实数据与假数据。
这两个阶段交替进行,生成器和判别器在对抗过程中不断优化。通过反复训练,生成器和判别器的能力不断提升,最终生成器能够生成几乎无法与真实数据区分的高质量样本。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
# 生成器
class Generator ( nn.Module ) :
def __init__ ( self, noise_dim ) :
super ( Generator, self ) .__init__ ( )
self.model = nn.Sequential (
nn.Linear ( noise_dim, 256 ) ,
nn.ReLU ( ) ,
nn.Linear ( 256, 512 ) ,
nn.ReLU ( ) ,
nn.Linear ( 512, 1024 ) ,
nn.ReLU ( ) ,
nn.Linear ( 1024, 28 * 28 ) ,
nn.Tanh ( ) # 输出范围 [ -1, 1 ]
)
def forward ( self, z ) :
return self.model ( z ) .reshape ( -1, 1, 28, 28 )
# 判别器
class Discriminator ( nn.Module ) :
def __init__ ( self ) :
super ( Discriminator, self ) .__init__ ( )
self.model = nn.Sequential (
nn.Flatten ( ) ,
nn.Linear ( 28 * 28, 512 ) ,
nn.LeakyReLU ( 0.2 ) ,
nn.Linear ( 512, 256 ) ,
nn.LeakyReLU ( 0.2 ) ,
nn.Linear ( 256, 1 ) ,
nn.Sigmoid ( ) # 输出范围 [ 0, 1 ]
)
def forward ( self, img ) :
return self.model ( img )
def draw_images ( generator, epoch
, examples=16, dim= ( 4,4 ) , figsize= ( 10,10 ) ) :
noise= np.random.normal ( loc=0, scale=1, size= [ examples, 100 ] )
generated_images = generator ( noise )
generated_images = generated_images.reshape ( 25,28,28 )
plt.figure ( figsize=figsize )
for i in range ( generated_images.shape [ 0 ] ) :
plt.subplot ( dim [ 0 ] , dim [ 1 ] , i+1 )
plt.imshow ( generated_images [ i ] , interpolation='nearest', cmap='Greys' )
plt.axis ( 'off' )
plt.tight_layout ( )
plt.savefig ( 'Generated_images %d.png' %epoch )
# 训练GAN
def train_gan ( epochs, noise_dim, batch_size ) :
# 加载MNIST手写数字数据集
dataset = datasets.MNIST ( root='.', train=True, download=True, transform=transforms.ToTensor ( ) )
dataloader = DataLoader ( dataset, batch_size=batch_size, shuffle=True )
generator = Generator ( noise_dim )
discriminator = Discriminator ( )
# 优化器
optimizer_G = optim.Adam ( generator.parameters ( ) , lr=0.0002 )
optimizer_D = optim.Adam ( discriminator.parameters ( ) , lr=0.0002 )
# 迭代训练
for epoch in range ( epochs ) :
for real_imgs, _ in dataloader:
### 训练判别器
optimizer_D.zero_grad ( )
# 由随机噪声,生成“假”图像
noise = torch.randn ( real_imgs.size ( 0 ) , noise_dim )
fake_imgs = generator ( noise )
d_loss_real = nn.BCELoss ( ) ( discriminator ( real_imgs ) , torch.ones ( real_imgs.size ( 0 ) , 1 ) )
d_loss_fake = nn.BCELoss ( ) ( discriminator ( fake_imgs.detach ( ) ) , torch.zeros ( real_imgs.size ( 0 ) , 1 ) )
d_loss = d_loss_real + d_loss_fake
d_loss.backward ( )
optimizer_D.step ( )
### 训练生成器
optimizer_G.zero_grad ( )
g_loss = nn.BCELoss ( ) ( discriminator ( fake_imgs ) , torch.ones ( real_imgs.size ( 0 ) , 1 ) )
g_loss.backward ( )
optimizer_G.step ( )
# 显示迭代过程中生成器生成的图像
if epoch == 1 or epoch % 10 == 0:
draw_images ( generator, epoch )
print ( f'Epoch {epoch}/{epochs}, D Loss: {d_loss.item ( ) }, G Loss: {g_loss.item ( ) }' )
# 调用训练函数
train_gan ( epochs=1000, noise_dim=100, batch_size=64 )

通过上述步骤,我们成功实现了一个简单的GAN网络。尽管这里的实现比较基础,但它为理解GAN的运作原理提供了良好的基础。随着对GAN的深入学习,大家可以尝试更多复杂的网络结构和优化策略,以提升生成样本的质量和多样性。
GAN的应用
随着技术的进步,GAN已被广泛应用于多个领域,展示了其强大的生成能力和灵活性。它们可以用于图像合成、语义图像编辑、风格迁移、图像超分辨率和分类等。它们也被用于医学领域的药物发现和医学成像。
1. 生成数据集
人工智能的训练是需要大量的数据集,可以通过GAN自动生成低成本的数据集。
2. 人脸生成

3. 物品生成

4. 图像转换


5. 图像修复


总结与展望
GAN不仅在图像生成和数据增强方面展现了巨大潜力,也在医学、艺术等多个领域产生了深远的影响。随着技术的不断进步,GAN的应用前景广阔,未来有望在更复杂的任务中发挥更大作用,推动各行各业的创新与发展。
文章改编转载自微信公众号:E等于mc平方
原文链接:
https://mp.weixin.qq.com/s/oy4VL7xBmNXAwhKFFM7tKA?scene=1
https://mp.weixin.qq.com/s/0DjW70jdoiRV1nGZ7lStiA?scene=1&click_id=20 |