一文轻松看懂生成对抗网络(GAN)—— 原理、实现与应用

宇宙微尘
2025-08-25 18:57:43
人工智能
算法解析
本帖最后由 宇宙微尘 于 2025-8-25 21:34 编辑


本文用通俗的语言描述了生成对抗网络(GAN)的核心原理与应用。通过详细介绍GAN的生成器与判别器的对抗训练过程,阐述了它们如何相互博弈,不断优化生成样本的质量。通过在PyTorch中实现一个简单的GAN模型,展示了其在生成手写数字图像方面的应用。文章还进一步探讨了GAN在多个领域的广泛应用,包括图像生成、数据增强、医学成像等,展示了其巨大的潜力与前景。



在金庸武侠小说《射雕英雄传》中,周伯通因为和东邪黄药师打赌而输,困于桃花岛十余年。


周伯通虽然心性顽劣,但在师兄的教导之下却极重信义,多年来一直没有越雷池半步。周伯通天性爱玩,在漫漫长夜中为了打发无聊时光,遂萌生“自己左手与右手”打架的想法,继而创造出金庸武学体系中的绝顶功夫-左右手互搏术。


公元2014年,在深度学习领域,Ian Goodfellow等人提出了生成对抗网络(GAN)模型。模型通过框架中两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生高质量的输出。生成对抗网络被认为是当前最具前景、最具活跃度的模型之一,目前主要应用于样本数据生成、图像生成、图像修复、图像转换、文本生成等方向。标志着生成式人工智能的关键突破。GAN通过生成器和判别器的对抗训练,能够生成高度真实的图像和数据,是现代生成式AI的重要模型之一 。



生成对抗网络(GAN),它的核心思想是利用生成器和判别器进行博弈,逐步优化生成效果,使生成数据的分布逐渐逼近真实数据的分布。生成器负责从随机噪声中生成样本,尽量模仿真实数据的特征,以骗过判别器。而判别器则承担辨别真伪的角色,它接收生成器的样本和真实样本,不断提升判断的准确性。


在这一过程中,生成器与判别器处于不断对抗的状态。生成器每次生成样本后,判别器都通过判断结果来反馈生成器的生成质量。判别器不断更新参数以增强自身识别能力,迫使生成器在不断优化中生成更加逼真的样本。通过这样的对抗训练,生成器逐渐学会模仿真实数据的特征分布,最终生成的样本在判别器面前接近于真实数据,使得判别器难以分辨。


这种对抗性的架构使得GAN具备了强大的生成能力, GAN的出现推动了图像和视频生成的显著进展。


GAN基本原理


1. 核心构成


GAN由两个重要的部分构成:生成器(Generator,简写作G)和判别器(Discriminator,简写作D)


生成器:通过机器生成数据,目的是尽可能“骗过”判别器,生成的数据记做G(z);


判别器:判断数据是真实数据还是「生成器」生成的数据,目的是尽可能找出「生成器」造的“假数据”。它的输入参数是x,x代表数据,输出D(x)代表x为真实数据的概率,如果为1,就代表100%是真实的数据,而输出为0,就代表不可能是真实的数据。


这样,G和D构成了一个动态对抗(或博弈过程),随着训练(对抗)的进行,G生成的数据越来越接近真实数据,D鉴别数据的水平越来越高。在理想的状态下,G可以生成足以“以假乱真”的数据;而对于D来说,它难以判定生成器生成的数据究竟是不是真实的,因此D(G(z)) = 0.5。训练完成后,我们得到了一个生成模型G,它可以用来生成以假乱真的数据。



2. 训练过程


第一阶段固定「判别器D」,训练「生成器G」。使用一个性能不错的判别器,G不断生成“假数据”,然后给这个D去判断。开始时候,G还很弱,所以很容易被判别出来。但随着训练不断进行,G技能不断提升,最终骗过了D。这个时候,D基本属于“瞎猜”的状态,判断是否为假数据的概率为50%。


第二阶段固定「生成器G」,训练「判别器D」。当通过了第一阶段,继续训练G就没有意义了。这时候我们固定G,然后开始训练D。通过不断训练,D提高了自己的鉴别能力,最终他可以准确判断出假数据。


重复第一阶段、第二阶段。通过不断的循环,「生成器G」和「判别器D」的能力都越来越强。最终我们得到了一个效果非常好的「生成器G」,就可以用它来生成数据。



基于PyTorch实现一个简单的GAN生成对抗网络模型


为了更好地理解生成对抗网络(GAN)的工作原理,我们将通过一个简单的例子进行实践。接下来,我们将展示如何使用PyTorch实现一个基本的GAN模型,目标是生成手写数字图像。这个过程将帮助我们进一步理解GAN中的生成器和判别器如何在对抗训练中不断优化,直到生成器能够生成足够真实的数据,使得判别器无法区分它们与真实数据之间的差异。通过实现一个简单的GAN,我们可以更直观地看到GAN模型的训练过程,并为后续深入研究更复杂的GAN模型奠定基础。


1. MNIST手写数字数据集简介


MNIST数据集总共包含两个子数据集:一个训练数据集和一个测试数据集。它们分别包含了60K和10K的28×28的灰度图像。




附下载链接


链接:https://pan.baidu.com/s/1TaL3dCHxAj17LgvSSd_eTA?pwd=xl8n 


提取码:xl8n



以下是使用PyTorch实现一个简单的GAN模型,用于生成手写数字的示例。


2. 训练过程


训练过程可以分为两个主要阶段。首先,在第一阶段,我们固定判别器D,专注于训练生成器G。在此阶段,生成器G从随机噪声中生成假数据,并交给判别器D进行判断。最开始,生成器的能力较弱,生成的数据很容易被判别器识别为假数据。然而,随着训练的进行,生成器逐步提高生成数据的质量,最终使判别器难以辨别其生成的数据。此时,判别器基本处于猜测状态,无法准确判断数据的真伪。


接下来进入第二阶段,固定生成器G,开始训练判别器D。在这一阶段,判别器D通过学习如何识别生成器所生成的假数据来提升自己的鉴别能力,最终能够较为准确地区分真实数据与假数据。


这两个阶段交替进行,生成器和判别器在对抗过程中不断优化。通过反复训练,生成器和判别器的能力不断提升,最终生成器能够生成几乎无法与真实数据区分的高质量样本。


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
# 生成器
class Generator ( nn.Module ) :
    def __init__ ( self, noise_dim ) :
        super ( Generator, self ) .__init__ (  )
        self.model = nn.Sequential (
            nn.Linear ( noise_dim, 256 ) ,
            nn.ReLU (  ) ,
            nn.Linear ( 256, 512 ) ,
            nn.ReLU (  ) ,
            nn.Linear ( 512, 1024 ) ,
            nn.ReLU (  ) ,
            nn.Linear ( 1024, 28 * 28 ) ,
            nn.Tanh (  )   # 输出范围 [ -1, 1 ]
         )
    def forward ( self, z ) :
        return self.model ( z ) .reshape ( -1, 1, 28, 28 )
        
# 判别器
class Discriminator ( nn.Module ) :
    def __init__ ( self ) :
        super ( Discriminator, self ) .__init__ (  )
        self.model = nn.Sequential (
            nn.Flatten (  ) ,
            nn.Linear ( 28 * 28, 512 ) ,
            nn.LeakyReLU ( 0.2 ) ,
            nn.Linear ( 512, 256 ) ,
            nn.LeakyReLU ( 0.2 ) ,
            nn.Linear ( 256, 1 ) ,
            nn.Sigmoid (  )   # 输出范围 [ 0, 1 ]
         )
    def forward ( self, img ) :
        return self.model ( img )

def draw_images ( generator, epoch
, examples=16, dim= ( 4,4 ) , figsize= ( 10,10 )  ) :
    noise= np.random.normal ( loc=0, scale=1, size= [ examples, 100 ]  )
    generated_images = generator ( noise )
    generated_images = generated_images.reshape ( 25,28,28 )
    plt.figure ( figsize=figsize )
    for i in range ( generated_images.shape [ 0 ]  ) :
        plt.subplot ( dim [ 0 ] , dim [ 1 ] , i+1 )
        plt.imshow ( generated_images [ i ] , interpolation='nearest', cmap='Greys' )
        plt.axis ( 'off' )
    plt.tight_layout (  )
    plt.savefig ( 'Generated_images %d.png' %epoch )
        
# 训练GAN
def train_gan ( epochs, noise_dim, batch_size ) :
    # 加载MNIST手写数字数据集
    dataset = datasets.MNIST ( root='.', train=True, download=True, transform=transforms.ToTensor (  )  )
    dataloader = DataLoader ( dataset, batch_size=batch_size, shuffle=True )
    generator = Generator ( noise_dim )
    discriminator = Discriminator (  )
   
    # 优化器
    optimizer_G = optim.Adam ( generator.parameters (  ) , lr=0.0002 )
    optimizer_D = optim.Adam ( discriminator.parameters (  ) , lr=0.0002 )
   
    # 迭代训练
    for epoch in range ( epochs ) :
        for real_imgs, _ in dataloader:
            ### 训练判别器
            optimizer_D.zero_grad (  )
              # 由随机噪声,生成“假”图像
            noise = torch.randn ( real_imgs.size ( 0 ) , noise_dim )
            fake_imgs = generator ( noise )
            
            d_loss_real = nn.BCELoss (  )  ( discriminator ( real_imgs ) , torch.ones ( real_imgs.size ( 0 ) , 1 )  )
            d_loss_fake = nn.BCELoss (  )  ( discriminator ( fake_imgs.detach (  )  ) , torch.zeros ( real_imgs.size ( 0 ) , 1 )  )
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward (  )
            optimizer_D.step (  )
            
            ### 训练生成器
            optimizer_G.zero_grad (  )
            g_loss = nn.BCELoss (  )  ( discriminator ( fake_imgs ) , torch.ones ( real_imgs.size ( 0 ) , 1 )  )
            g_loss.backward (  )
            optimizer_G.step (  )
            
        # 显示迭代过程中生成器生成的图像   
        if epoch  == 1 or epoch  % 10 == 0:
            draw_images ( generator, epoch )     
            
        print ( f'Epoch {epoch}/{epochs}, D Loss: {d_loss.item (  ) }, G Loss: {g_loss.item (  ) }' )

# 调用训练函数
train_gan ( epochs=1000, noise_dim=100, batch_size=64 )


通过上述步骤,我们成功实现了一个简单的GAN网络。尽管这里的实现比较基础,但它为理解GAN的运作原理提供了良好的基础。随着对GAN的深入学习,大家可以尝试更多复杂的网络结构和优化策略,以提升生成样本的质量和多样性。


GAN的应用


随着技术的进步,GAN已被广泛应用于多个领域,展示了其强大的生成能力和灵活性。它们可以用于图像合成、语义图像编辑、风格迁移、图像超分辨率和分类等。它们也被用于医学领域的药物发现和医学成像。


1. 生成数据集


人工智能的训练是需要大量的数据集,可以通过GAN自动生成低成本的数据集。


2. 人脸生成



3.  物品生成



4. 图像转换




5. 图像修复




总结与展望


GAN不仅在图像生成和数据增强方面展现了巨大潜力,也在医学、艺术等多个领域产生了深远的影响。随着技术的不断进步,GAN的应用前景广阔,未来有望在更复杂的任务中发挥更大作用,推动各行各业的创新与发展。


 




文章改编转载自微信公众号:E等于mc平方


原文链接:


https://mp.weixin.qq.com/s/oy4VL7xBmNXAwhKFFM7tKA?scene=1


https://mp.weixin.qq.com/s/0DjW70jdoiRV1nGZ7lStiA?scene=1&click_id=20

80
0
0
0
关于作者
相关文章
  • CD-RBM+BM-ILM:破解人脸识别梯度消失难题的混合技术 ...
    《Face Recognition Based on CD-RBM and BM-ILM》发表于《Journal of Physics: Conference Seri ...
    了解详情 
  • 一文学会9种主流GAN损失函数及其PyTorch实现:从经典模型到现代 ...
    生成对抗网络(GAN)依赖于其损失函数来优化生成器和判别器的训练过程。本文首先介绍了经典GAN的 ...
    了解详情 
  • 相干伊辛机在生命科学基础研究和药物发现中的应用 ...
    本文发表在《信息通信技术与政策》2025 年第 7 期:http://ictp.caict.ac.cn/CN/10.12267/j.issn ...
    了解详情 
  • 量子计算变革金融衍生品定价:复杂期权与风险分析的新突破 ...
    本文提出量子并行蒙特卡洛(MCQP)算法,解决金融衍生品定价中经典方法在高维问题的局限。该算法 ...
    了解详情 
联系我们
二维码
在本版发帖返回顶部
快速回复 返回顶部 返回列表
玻色有奖小调研
填写问卷,将免费赠送您5个100bit真机配额
(单选) 您是从哪个渠道得知我们的?*
您是从哪个社交媒体得知我们的?*
您是通过哪个学校的校园宣讲得知我们的呢?
取消

提交成功

真机配额已发放到您的账户,可前往【云平台】查看