В этом туториале я расскажу о генеративно-состязательных нейронных сетях (GAN) не прибегая к математическим деталям модели. Далее будет показано, как написать собственную простую GAN на Python с Keras, которая сможет генерировать знаки.
Перед вам перевод статьи Demystifying Generative Adversarial Nets (GANs), опубликованной на Datacamp, автор — Stefan Hosien. Ссылка на оригинал — в подвале статьи.
Аналогия
Проще всего понять, что такое GAN, обратившись к следующей аналогии. Представьте, что есть магазин, который покупает определенные сорта вина у своих поставщиков, которые он затем будет перепродавать.
Есть нечестные поставщики, которые продают поддельное вино, чтобы получить деньги. В таком случае руководство магазина должно уметь различать поддельные и подлинные вина.
Можно предположить, что изначально мошенники могли сделать много ошибок при попытке продать поддельное вино, а руководство магазина с легкостью определяло фальшивые экземпляры. Методом проб и ошибок мошенники пробовали разные техники, чтобы имитировать подлинное вино, и в конечном счете им это удалось. Теперь когда мошенники знают, как сделать так, чтобы вино прошло контроль в магазине, они начинают дальше улучшать свой продукт.
В то же время руководство магазина может получать фидбэк от других магазинов или экспертов о том, что некоторые их вина неоригинальные. Поэтому магазину приходится улучшать свои методи определения поддельных вин. Цель мошенников — создание неотличимых от оригинала вин, руководство магазина же стремится точно определить подлинность вина.
Такое взаимное состязание является идеей, лежащей в основе GAN.
Архитектура генеративно-состязательной сети
Используя пример, о котором было сказано выше, можно прийти к архитектуре GAN.
Очевидно, что в GAN должны быть две основные части — генератор и, так называемый, дискриминатор. Руководство магазина в примере сверху — дискриминаторная сеть, которая обычно представляет из себя сверточную нейросеть, CNN, (так как сети GAN в основном используются для задач, связанных с изображениями), которая приписывает изображению процент соответствия подлинности.
Мошенником в GAN выступает генеративная сеть, которая также является сверточной сетью со слоем развертки (deconvolution layer). Эта сеть накладывает шум на изображение (использую вектор шума) и выводит его. Во время тренировки генеративная сеть изучает, какие области изображения необходимо изменить или улучшить, чтобы дискриминатору понадобилось больше времени для определения подлинности сгенерированного изображения.
Генеративная сеть с каждым разом производит изображение, которое все больше походит на реальное, в то время как дискриминативная сеть пытается найти различия между реальным и искусственным изображением. Главная цель — сделать такую генеративную сеть, которая сможет воспроизводить неотличимые от реальных изображения.
Простая генеративно-состязательная сеть в Keras
Теперь когда вы поняли, что такое GAN, какие компоненты у нее есть, начнем писать код. Будем использовать Keras, а если вы не знакомы с этим фреймворком Python, перед началом работы посмотрите этот туториал. В основе этого туториала лежит простая и понятная GAN, разработанная здесь.
Для начала вам необходимо с помощью pip установить следующие пакеты:
- keras - matplotlib - tensorflow - tqdm
Мы будем использовать matplotlib для отрисовки графиков, tensorflow в качестве необходимого для Keras бэкграунда, tqdm для красивой визуализации прогресса с каждой эпохой, итерацией.
Следующий шаг — создание скрипта на Python. В этом скрипте сначала необходимо импортировать все модули и функции для работы. Объяснение работы каждого модуля будет дано позже.
import os import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm from keras.layers import Input from keras.models import Model, Sequential from keras.layers.core import Dense, Dropout from keras.layers.advanced_activations import LeakyReLU from keras.datasets import mnist from keras.optimizers import Adam from keras import initializers
Теперь определим некоторые переменные:
# Let Keras know that we are using tensorflow as our backend engine os.environ["KERAS_BACKEND"] = "tensorflow" # To make sure that we can reproduce the experiment and get the same results np.random.seed(10) # The dimension of our random noise vector. random_dim = 100
Перед тем как начать строить дискриминатор и генератор, нужно собрать данные и сделать их предварительную обработку. Будем использовать известный датасет MNIST, который представляет из себя набор изображений цифр от 0 до 9.
def load_minst_data(): # load the data (x_train, y_train), (x_test, y_test) = mnist.load_data() # normalize our inputs to be in the range[-1, 1] x_train = (x_train.astype(np.float32) - 127.5)/127.5 # convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have # 784 columns per row x_train = x_train.reshape(60000, 784) return (x_train, y_train, x_test, y_test)
Заметим, что команда mnist.load_data() является частью Keras и позволяет легко импортировать датасет в рабочее пространство.
Теперь мы можем создать сети генератора и дискриминатора. Для обеих сетей используем оптимизатор Adam. В обоих случаях сеть будет состоять из трех скрытых слоев с активационной функцией Leaky Relu. Также следует добавить в дискриминатор dropout слои, чтобы улучшить его надежность, качество (robustness) на изображениях, которые не были показаны.
def get_optimizer(): return Adam(lr=0.0002, beta_1=0.5) def get_generator(optimizer): generator = Sequential() generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02))) generator.add(LeakyReLU(0.2)) generator.add(Dense(512)) generator.add(LeakyReLU(0.2)) generator.add(Dense(1024)) generator.add(LeakyReLU(0.2)) generator.add(Dense(784, activation='tanh')) generator.compile(loss='binary_crossentropy', optimizer=optimizer) return generator def get_discriminator(optimizer): discriminator = Sequential() discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02))) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(512)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(256)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(1, activation='sigmoid')) discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
Осталась только соединить генератор с дискриминатором.
def get_gan_network(discriminator, random_dim, generator, optimizer): # We initially set trainable to False since we only want to train either the # generator or discriminator at a time discriminator.trainable = False # gan input (noise) will be 100-dimensional vectors gan_input = Input(shape=(random_dim,)) # the output of the generator (an image) x = generator(gan_input) # get the output of the discriminator (probability if the image is real or not) gan_output = discriminator(x) gan = Model(inputs=gan_input, outputs=gan_output) gan.compile(loss='binary_crossentropy', optimizer=optimizer) return gan
Дополнительно можно создать функцию, сохраняющую сгенерированные изображения через каждые 20 эпох. Так как этот шаг не является основным в туториале, вам не обязательно полностью понимать выводящую изображение функцию.
def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)): noise = np.random.normal(0, 1, size=[examples, random_dim]) generated_images = generator.predict(noise) generated_images = generated_images.reshape(examples, 28, 28) plt.figure(figsize=figsize) for i in range(generated_images.shape[0]): plt.subplot(dim[0], dim[1], i+1) plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r') plt.axis('off') plt.tight_layout() plt.savefig('gan_generated_image_epoch_%d.png' % epoch)
Мы написали большую часть нашей сети. Осталось только обучить нейросеть и посмотреть на результаты — изображения.
def train(epochs=1, batch_size=128): # Get the training and testing data x_train, y_train, x_test, y_test = load_minst_data() # Split the training data into batches of size 128 batch_count = x_train.shape[0] / batch_size # Build our GAN netowrk adam = get_optimizer() generator = get_generator(adam) discriminator = get_discriminator(adam) gan = get_gan_network(discriminator, random_dim, generator, adam) for e in xrange(1, epochs+1): print '-'*15, 'Epoch %d' % e, '-'*15 for _ in tqdm(xrange(batch_count)): # Get a random set of input noise and images noise = np.random.normal(0, 1, size=[batch_size, random_dim]) image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)] # Generate fake MNIST images generated_images = generator.predict(noise) X = np.concatenate([image_batch, generated_images]) # Labels for generated and real data y_dis = np.zeros(2*batch_size) # One-sided label smoothing y_dis[:batch_size] = 0.9 # Train discriminator discriminator.trainable = True discriminator.train_on_batch(X, y_dis) # Train generator noise = np.random.normal(0, 1, size=[batch_size, random_dim]) y_gen = np.ones(batch_size) discriminator.trainable = False gan.train_on_batch(noise, y_gen) if e == 1 or e % 20 == 0: plot_generated_images(e, generator) if __name__ == '__main__': train(400, 128)
После обучения на 400 эпохах, можем посмотреть сгенерированные изображения. Глядя на произведенные после первой эпохи изображения, вы можете заметить, что они не имеют реальную структуру. После 40 эпох изображения приобретают нужную форму, а после 400 эпох изображения четкие и почти неотличимые от настоящих, за исключением пары штук.
Главная причина, по которой был выбран этот код, это скорость выполнения. Во время тренировки на CPU для каждой эпохи требуется примерно 2 минуты. Вы можете сами поэкспериментировать с кодом, добавляя эпохи или слои (не обязательно такие же) в генератор и дискриминатор. Однако, если вы работаете с CPU, использование более сложных и глубоких архитектур потребует большего времени на тренировку. Но этот факт не должен останавливать, экспериментируйте!
Заключение
Поздравляю, вы дошли до конца туториала и получили интуитивное понимание генеративно-состязательных сетей GAN. Помимо понимания вы реализовали свою собственную сеть с помощью библиотеки Keras.
+
У меня не получается повторить. ValueError: Input 0 of layer sequential_3 is incompatible with the layer: expected axis -1 of input shape to have value 784 but received input with… Подробнее »
Из-за вас я потратил пол месяца на это, мне пришлось переустанавливать малину, а все из-за того, что у вас ког нерабочий, AttributeError: ‘NoneType’ object has no attribute ‘trainable’ 10 горящих… Подробнее »
Там всё ещё хуже. Используется функция xrange, хотя такой функции вообще нету
Она есть. Ищи в другом месте))0
Если конкретнее — ищи в python 2.X)
А мозгов не хватило в функции get_discriminator добавить в конце «return discriminator» ?)))))