首页 科学教育文章正文

【零基础学AI】生成对抗网络(GAN)实战 - 手写数字生成(【新手入门AI】实战GAN:制作手写数字生成器)

科学教育 2025年07月10日 15:10 2 aaron
  【零基础学AI】生成对抗网络(GAN)实战 - 手写数字生成   一、前言   生成对抗网络(GAN)是一种强大的深度学习模型,可以生成与真实数据分布相似的样本。本文将带领初学者通过一个简单的手写数字生成任务,了解GAN的基本原理和实战操作。   二、准备工作   环境配置 安装Python 3.6及以上版本 安装TensorFlow库:pip install tensorflow   数据集 下载MNIST手写数字数据集:https://www.tensorflow.org/datasets/catalog/mnist   三、实战步骤 导入库和加载数据 import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers import tensorflow_datasets as tfds # 加载MNIST数据集 mnist = tfds.load('mnist', split='train', shuffle_files=True) 定义生成器和判别器 def make_generator_model(): model = keras.Sequential() model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) model.add(layers.LeakyReLU()) model.add(layers.Reshape((7, 7, 256))) model.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same', use_bias=False)) model.add(layers.LeakyReLU()) model.add(layers.Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same', use_bias=False)) model.add(layers.LeakyReLU()) model.add(layers.Conv2DTranspose(1, (4, 4), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) return model def make_discriminator_model(): model = keras.Sequential() model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1])) model.add(layers.LeakyReLU()) model.add(layers.Dropout(0.3)) model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) model.add(layers.LeakyReLU()) model.add(layers.Dropout(0.3)) model.add(layers.Flatten()) model.add(layers.Dense(1)) return model 编译和训练模型 generator = make_generator_model() discriminator = make_discriminator_model() # 编译模型 discriminator.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0001), metrics=['accuracy']) generator.compile(loss='binary_crossentropy', optimizer=keras.optimizers.Adam(0.0001)) # 训练模型 for epoch in range(50): for real_images, _ in mnist: real_images = real_images.reshape(len(real_images), 28, 28, 1) real_labels = np.ones((len(real_images), 1)) fake_images = generator.predict(np.random.normal(0, 1, (len(real_images), 100))) fake_labels = np.zeros((len(fake_images), 1)) discriminator.train_on_batch(real_images, real_labels) discriminator.train_on_batch(fake_images, fake_labels) # 生成并保存生成的手写数字图片 generated_images = generator.predict(np.random.normal(0, 1, (25, 100))) generated_images = generated_images * 255 generated_images = np.clip(generated_images, 0, 255).astype('uint8') for i in range(25): plt.imshow(generated_images[i, :, :, 0], cmap='gray') plt.show()   四、总结   通过以上步骤,我们成功实现了使用GAN生成手写数字图片。希望本文能帮助初学者快速入门GAN,为后续学习打下基础。

标签: in 生成

智杖百科 备案号:皖ICP备2023023635号 智杖百科 xml | txt