GAN生成对抗网络原理推导与Python实现

一、引言

生成对抗网络(GAN,Generative Adversarial Network)是一种重要的深度学习模型,首次由Ian Goodfellow等人在2014年提出。其核心思想是通过两个神经网络的对抗训练,即生成器(Generator)和判别器(Discriminator),使得生成器能够生成与真实数据相似的伪造数据。GAN在图像生成、图像修复、图像转换等任务中得到了广泛应用。

二、原理推导

1. 基本概念

GAN的目标是通过一个对抗过程,使生成器生成的样本分布逼近真实样本的分布。生成器试图学习真实数据的分布 ( p_{data}(x) ),而判别器则判断输入样本 ( x ) 是真实样本还是生成样本。

2. 损失函数

GAN的损失函数可以用以下公式表示:

[ \text{min}G \text{max}_D V(D,G) = E{x \sim p_{data}(x)}[\log D(x)] + E_{z \sim p_z(z)}[\log(1 - D(G(z)))] ]

其中: - ( D(x) ) 是判别器对输入样本 ( x ) 预测为真实的概率。 - ( G(z) ) 是生成器生成的样本,( z ) 是从噪声分布 ( p_z(z) ) 中采样得到的噪声。

3. 对抗训练过程

在训练过程中,首先固定生成器 ( G ) 来训练判别器 ( D ) :

  1. 从真实样本中取样本 ( x ),求 ( D(x) )。
  2. 从随机噪声 ( z ) 中生成样本 ( G(z) ),求 ( D(G(z)) )。
  3. 计算损失并更新判别器的参数,使得最大化 ( V(D, G) )。

接着,固定判别器 ( D ) 来训练生成器 ( G ) :

  1. 生成样本 ( G(z) )。
  2. 计算损失 ( -\log(D(G(z))) ),更新生成器的参数,使得最小化损失。

4. 收敛条件

在理想情况下,GAN的收敛条件是生成器生成的样本分布与真实样本分布一致,即 ( p_g(x) = p_{data}(x) )。但在实际训练中,由于网络结构、超参数等因素的影响,GAN的训练可能会不稳定。

三、Python代码实现

接下来,我们用Python实现一个简单的GAN。我们将使用Keras库,生成手写数字(MNIST数据集)。

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Reshape, Flatten, Dropout
from keras.layers import LeakyReLU
from keras.optimizers import Adam

# 加载数据
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 255.0  # 归一化到[0, 1]区间
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)

# 定义生成器
def build_generator():
    model = Sequential()
    model.add(Dense(256, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model

# 定义判别器
def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
    model.add(Dense(1, activation='sigmoid'))
    return model

# 编译模型
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])

# 构建GAN模型
z = Sequential()
z.add(generator)
discriminator.trainable = False
z.add(discriminator)
z.compile(loss='binary_crossentropy', optimizer=Adam())

# 训练
def train_gan(epochs, batch_size):
    for epoch in range(epochs):
        # 训练判别器
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, 100))
        fake_imgs = generator.predict(noise)
        d_loss_real = discriminator.train_on_batch(real_imgs, np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch(fake_imgs, np.zeros((batch_size, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = z.train_on_batch(noise, np.ones((batch_size, 1)))

        if epoch % 1000 == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100 * d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
            save_generated_images(epoch)

def save_generated_images(epoch):
    noise = np.random.normal(0, 1, (25, 100))
    fake_imgs = generator.predict(noise)
    fake_imgs = 0.5 * fake_imgs + 0.5  # 反归一化到[0, 1]区间
    fig, axs = plt.subplots(5, 5)
    cnt = 0
    for i in range(5):
        for j in range(5):
            axs[i, j].imshow(fake_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    plt.show()

# 开始训练
train_gan(epochs=10000, batch_size=64)

四、总结

GAN通过生成器和判别器的对抗训练,可以生成高质量的伪造数据。尽管训练过程可能不稳定,但GAN在图像生成等领域展现了强大的潜力。以上代码实现了一个简单的GAN,用于生成手写数字,希望能引发你对生成对抗网络更深入的研究和应用。

点赞(0) 打赏

微信小程序

微信扫一扫体验

微信公众账号

微信扫一扫加关注

发表
评论
返回
顶部