fbpx
  • Туториал: создание простой GAN на Python с библиотекой Keras

    gan python keras tutorial

    В этом туториале я расскажу о генеративно-состязательных нейронных сетях (GAN) не прибегая к математическим деталям модели. Далее будет показано, как написать собственную простую GAN на Python с Keras, которая сможет генерировать знаки.

    Перед вам перевод статьи Demystifying Generative Adversarial Nets (GANs), опубликованной на Datacamp, автор — Stefan Hosien. Ссылка на оригинал — в подвале статьи.

    Аналогия

    Проще всего понять, что такое GAN, обратившись к следующей аналогии. Представьте, что есть магазин, который покупает определенные сорта вина у своих поставщиков, которые он затем будет перепродавать.

    gan туториал

    Есть нечестные поставщики, которые продают поддельное вино, чтобы получить деньги. В таком случае руководство магазина должно уметь различать поддельные и подлинные вина.

    gan на python туториал

    Можно предположить, что изначально мошенники могли сделать много ошибок при попытке продать поддельное вино, а руководство магазина с легкостью определяло фальшивые экземпляры. Методом проб и ошибок мошенники пробовали разные техники, чтобы имитировать подлинное вино, и в конечном счете им это удалось. Теперь когда мошенники знают, как сделать так, чтобы вино прошло контроль в магазине, они начинают дальше улучшать свой продукт.

    В то же время руководство магазина может получать фидбэк от других магазинов или экспертов о том, что некоторые их вина неоригинальные. Поэтому магазину приходится улучшать свои методи определения поддельных вин. Цель мошенников — создание неотличимых от оригинала вин, руководство магазина же стремится точно определить подлинность вина.

    Такое взаимное состязание является идеей, лежащей в основе GAN.

    Архитектура генеративно-состязательной сети

    Используя пример, о котором было сказано выше, можно прийти к архитектуре GAN.

    как работает gan

    Очевидно, что в GAN должны быть две основные части — генератор и, так называемый, дискриминатор. Руководство магазина в примере сверху — дискриминаторная сеть, которая обычно представляет из себя сверточную нейросеть, CNN, (так как сети GAN в основном используются для задач, связанных с изображениями), которая приписывает изображению процент соответствия подлинности.

    Мошенником в GAN выступает генеративная сеть, которая также является сверточной сетью со слоем развертки (deconvolution layer). Эта сеть накладывает шум на изображение (использую вектор шума) и выводит его. Во время тренировки генеративная сеть изучает, какие области изображения необходимо изменить или улучшить, чтобы дискриминатору понадобилось больше времени для определения подлинности сгенерированного изображения.

    Генеративная сеть с каждым разом производит изображение, которое все больше походит на реальное, в то время как дискриминативная сеть пытается найти различия между реальным и искусственным изображением. Главная цель — сделать такую генеративную сеть, которая сможет воспроизводить неотличимые от реальных изображения.

    Простая генеративно-состязательная сеть в Keras

    Теперь когда вы поняли, что такое GAN,  какие компоненты у нее есть, начнем писать код. Будем использовать Keras, а если вы не знакомы с этим фреймворком Python, перед началом работы посмотрите этот туториал. В основе этого туториала лежит простая и понятная GAN, разработанная здесь.

    Для начала вам необходимо с помощью pip установить следующие пакеты:

    - keras
    - matplotlib
    - tensorflow
    - tqdm

    Мы будем использовать matplotlib для отрисовки графиков, tensorflow в качестве необходимого для Keras бэкграунда, tqdm для красивой визуализации прогресса с каждой эпохой, итерацией.

    Следующий шаг — создание скрипта на Python. В этом скрипте сначала необходимо импортировать все модули и функции для работы. Объяснение работы каждого модуля будет дано позже.

    import os
    import numpy as np
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    
    from keras.layers import Input
    from keras.models import Model, Sequential
    from keras.layers.core import Dense, Dropout
    from keras.layers.advanced_activations import LeakyReLU
    from keras.datasets import mnist
    from keras.optimizers import Adam
    from keras import initializers

    Теперь определим некоторые переменные:

    # Let Keras know that we are using tensorflow as our backend engine
    os.environ["KERAS_BACKEND"] = "tensorflow"
    # To make sure that we can reproduce the experiment and get the same results
    np.random.seed(10)
    # The dimension of our random noise vector.
    random_dim = 100

    Перед тем как начать строить дискриминатор и генератор, нужно собрать данные и сделать их предварительную обработку. Будем использовать известный датасет MNIST, который представляет из себя набор изображений цифр от 0 до 9.

    генерация символов с gan
    Пример символов из датасета MNIST
    def load_minst_data():
        # load the data
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        # normalize our inputs to be in the range[-1, 1] 
        x_train = (x_train.astype(np.float32) - 127.5)/127.5
        # convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have
        # 784 columns per row
        x_train = x_train.reshape(60000, 784)
        return (x_train, y_train, x_test, y_test)

    Заметим, что команда mnist.load_data() является частью Keras и позволяет легко импортировать датасет в рабочее пространство.

    Теперь мы можем создать сети генератора и дискриминатора. Для обеих сетей используем оптимизатор Adam. В обоих случаях сеть будет состоять из трех скрытых слоев с активационной функцией Leaky Relu. Также следует добавить в дискриминатор dropout слои, чтобы улучшить его надежность, качество (robustness) на изображениях, которые не были показаны.

    def get_optimizer():
        return Adam(lr=0.0002, beta_1=0.5)
    
    def get_generator(optimizer):
        generator = Sequential()
        generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
        generator.add(LeakyReLU(0.2))
    
        generator.add(Dense(512))
        generator.add(LeakyReLU(0.2))
    
        generator.add(Dense(1024))
        generator.add(LeakyReLU(0.2))
    
        generator.add(Dense(784, activation='tanh'))
        generator.compile(loss='binary_crossentropy', optimizer=optimizer)
        return generator
    
    def get_discriminator(optimizer):
        discriminator = Sequential()
        discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))
    
        discriminator.add(Dense(512))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))
    
        discriminator.add(Dense(256))
        discriminator.add(LeakyReLU(0.2))
        discriminator.add(Dropout(0.3))
    
        discriminator.add(Dense(1, activation='sigmoid'))
        discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)

    Осталась только соединить генератор с дискриминатором.

    def get_gan_network(discriminator, random_dim, generator, optimizer):
        # We initially set trainable to False since we only want to train either the 
        # generator or discriminator at a time
        discriminator.trainable = False
        # gan input (noise) will be 100-dimensional vectors
        gan_input = Input(shape=(random_dim,))
        # the output of the generator (an image)
        x = generator(gan_input)
        # get the output of the discriminator (probability if the image is real or not)
        gan_output = discriminator(x)
        gan = Model(inputs=gan_input, outputs=gan_output)
        gan.compile(loss='binary_crossentropy', optimizer=optimizer)
        return gan

    Дополнительно можно создать функцию, сохраняющую сгенерированные изображения через каждые 20 эпох. Так как этот шаг не является основным в туториале, вам не обязательно полностью понимать выводящую изображение функцию.

    def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):
        noise = np.random.normal(0, 1, size=[examples, random_dim])
        generated_images = generator.predict(noise)
        generated_images = generated_images.reshape(examples, 28, 28)
    
        plt.figure(figsize=figsize)
        for i in range(generated_images.shape[0]):
            plt.subplot(dim[0], dim[1], i+1)
            plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
            plt.axis('off')
        plt.tight_layout()
        plt.savefig('gan_generated_image_epoch_%d.png' % epoch)

    Мы написали большую часть нашей сети. Осталось только обучить нейросеть и посмотреть на результаты — изображения.

    def train(epochs=1, batch_size=128):
        # Get the training and testing data
        x_train, y_train, x_test, y_test = load_minst_data()
        # Split the training data into batches of size 128
        batch_count = x_train.shape[0] / batch_size
    
        # Build our GAN netowrk
        adam = get_optimizer()
        generator = get_generator(adam)
        discriminator = get_discriminator(adam)
        gan = get_gan_network(discriminator, random_dim, generator, adam)
    
        for e in xrange(1, epochs+1):
            print '-'*15, 'Epoch %d' % e, '-'*15
            for _ in tqdm(xrange(batch_count)):
                # Get a random set of input noise and images
                noise = np.random.normal(0, 1, size=[batch_size, random_dim])
                image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
    
                # Generate fake MNIST images
                generated_images = generator.predict(noise)
                X = np.concatenate([image_batch, generated_images])
    
                # Labels for generated and real data
                y_dis = np.zeros(2*batch_size)
                # One-sided label smoothing
                y_dis[:batch_size] = 0.9
    
                # Train discriminator
                discriminator.trainable = True
                discriminator.train_on_batch(X, y_dis)
    
                # Train generator
                noise = np.random.normal(0, 1, size=[batch_size, random_dim])
                y_gen = np.ones(batch_size)
                discriminator.trainable = False
                gan.train_on_batch(noise, y_gen)
    
            if e == 1 or e % 20 == 0:
                plot_generated_images(e, generator)
    
    if __name__ == '__main__':
        train(400, 128)

    После обучения на 400 эпохах, можем посмотреть сгенерированные изображения. Глядя на произведенные после первой эпохи изображения, вы можете заметить, что они не имеют реальную структуру. После 40 эпох изображения приобретают нужную форму, а после 400 эпох изображения четкие и почти неотличимые от настоящих, за исключением пары штук.

    gan на python туториал генерация цифр
    Результат после 1 эпохи
    результат работы генеративной сети
    Результат после 40 эпох
    результат gan
    Результат после 400 эпох

    Главная причина, по которой был выбран этот код, это скорость выполнения. Во время тренировки на CPU для каждой эпохи требуется примерно 2 минуты. Вы можете сами поэкспериментировать с кодом, добавляя эпохи или слои (не обязательно такие же) в генератор и дискриминатор. Однако, если вы работаете с CPU, использование более сложных и глубоких архитектур потребует большего времени на тренировку. Но этот факт не должен останавливать, экспериментируйте!

    Заключение

    Поздравляю, вы дошли до конца туториала и получили интуитивное понимание генеративно-состязательных сетей GAN. Помимо понимания вы реализовали свою собственную сеть с помощью библиотеки Keras.