Generative Adversarial Nets (GAN) — различия между версиями

Материал из Викиконспекты
Перейти к: навигация, поиск
(Улучшение обучения GAN)
(Оригинальный алгоритм обучения GAN)
(не показано 27 промежуточных версий 8 участников)
Строка 1: Строка 1:
 
[[File:Арх_ган.png|450px|thumb|Оригинальная архитектура GAN]]
 
[[File:Арх_ган.png|450px|thumb|Оригинальная архитектура GAN]]
'''Порождающие состязательные сети''' (англ. ''Generative Adversarial Nets, GAN'') {{---}} это алгоритм машинного обучения, входящий в семейство [[:Порождающие модели|порождающих моделей]] и построенный на комбинации из двух нейронных сетей, одна из которых генерирует образцы, а другая же пытается отличить настоящие образцы от сгенерированных. Впервые такие сети были представлены Иэном Гудфеллоу в 2014 году.  
+
 
 +
'''Порождающие состязательные сети''' (англ. ''Generative Adversarial Nets, GAN'') {{---}} алгоритм машинного обучения, входящий в семейство [[:Порождающие модели|порождающих моделей]]<sup>[на 28.01.19 не создан]</sup> и построенный на комбинации из двух нейронных сетей, одна из которых генерирует образцы, другая пытается отличить настоящие образцы от сгенерированных. Впервые такие сети были представлены Иэном Гудфеллоу в 2014 году.  
  
 
==Постановка задачи и метод==
 
==Постановка задачи и метод==
Имеется множество образцов <tex>X</tex> из распределения <tex>p_{data}</tex>, заданного на <tex> \mathbb R^n </tex>, а также некоторое пространство латентных факторов <tex>Z</tex> из распределения <tex>p_{z}</tex>, например, случайные вектора из равномерного распределения <tex> \mathbb U^p(0,1) </tex>.
+
Имеется множество образцов <tex>X</tex> из распределения <tex>p_{data}</tex>, заданного на <tex> \mathbb R^n </tex>, а также некоторое пространство латентных факторов <tex>Z</tex> из распределения <tex>p_{z}</tex>, например, случайные вектора из равномерного распределения <tex> \mathbb U^t(0,1) </tex>.
  
Рассмотрим две нейронные сети: первая {{---}} ''генератор'' <tex> G: Z \rightarrow \mathbb R^n </tex> с параметрами <tex>\theta</tex>, цель которой сгенерировать похожий образец из <tex>p_{data}</tex>, и вторая {{---}} ''дискриминатор'' <tex>D: \mathbb R^n \rightarrow \mathbb [0,1] </tex> с параметрами <tex>\gamma</tex>, цель которой выдавать максимальную оценку на образцах из <tex>X</tex> и минимальную на сгенерированных образцах из <tex>G</tex>. Распределение, порождаемое генератором будем обозначать <tex>p_{gen}</tex>. Так же заметим, что в текущем изложении не принципиальны архитектуры нейронных сетей, поэтому можно считать, что параметры <tex>\theta</tex> и <tex>\gamma</tex> являются просто параметрами многослойных персептронов.
+
Рассмотрим две нейронные сети: первая $-$ ''генератор'' <tex> G: Z \rightarrow \mathbb R^n </tex> с параметрами <tex>\theta</tex>, цель которой сгенерировать похожий образец из <tex>p_{data}</tex>, и вторая $-$ ''дискриминатор'' <tex>D: \mathbb R^n \rightarrow \mathbb [0,1] </tex> с параметрами <tex>\gamma</tex>, цель которой выдавать максимальную оценку на образцах из <tex>X</tex> и минимальную на сгенерированных образцах из <tex>G</tex>. Распределение, порождаемое генератором будем обозначать <tex>p_{gen}</tex>. Так же заметим, что в текущем изложении не принципиальны архитектуры нейронных сетей, поэтому можно считать, что параметры <tex>\theta</tex> и <tex>\gamma</tex> являются просто параметрами многослойных персептронов.
  
 
В качестве примера можно рассматривать генерацию реалистичных фотографий: в этом случае, входом для генератора может быть случайный многомерный шум, а выходом генератора (и входом для дискриминатора) RGB-изображение; выходом же для дискриминатора будет вероятность, что фотография настоящая, т.е число от 0 до 1.  
 
В качестве примера можно рассматривать генерацию реалистичных фотографий: в этом случае, входом для генератора может быть случайный многомерный шум, а выходом генератора (и входом для дискриминатора) RGB-изображение; выходом же для дискриминатора будет вероятность, что фотография настоящая, т.е число от 0 до 1.  
Строка 11: Строка 12:
 
Наша задача выучить распределение <tex>p_{gen}</tex> так, чтобы оно как можно лучше описывало <tex>p_{data}</tex>. Зададим функцию ошибки для получившейся модели. Со стороны дискриминатора мы хотим распознавать образцы из <tex>X</tex> как правильные, т.е в сторону единицы, и образцы из <tex>G</tex> как неправильные, т.е в сторону нуля, таким образом нужно максимизировать следующую величину:
 
Наша задача выучить распределение <tex>p_{gen}</tex> так, чтобы оно как можно лучше описывало <tex>p_{data}</tex>. Зададим функцию ошибки для получившейся модели. Со стороны дискриминатора мы хотим распознавать образцы из <tex>X</tex> как правильные, т.е в сторону единицы, и образцы из <tex>G</tex> как неправильные, т.е в сторону нуля, таким образом нужно максимизировать следующую величину:
  
<center> <tex>\mathop{E}\limits_{x \sim p_{data}}[logD(x)] + \mathop{E}\limits_{x \sim p_{gen}}[log(1-D(x))]</tex>, где <tex>\mathop{E}\limits_{x \sim p_{gen}}[log(1-D(x))] = \mathop{E}\limits_{z \sim p_{z}}[log(1-D(G(z))]</tex> </center>
+
<center> <tex>\mathop{E}\limits_{x \sim p_{data}}[logD(x)] + \mathop{E}\limits_{x \sim p_{gen}}[log(1-D(x))]</tex>, где <tex>\mathop{E}\limits_{x \sim p_{gen}}[log(1-D(x))] = \mathop{E}\limits_{z \sim p_{z}}[log(1-D(G(z))]</tex> </center>,
  
 
Со стороны же генератора требуется научиться "обманывать" дискриминатор, т.е минимизировать по <tex>p_{gen}</tex> второе слагаемое предыдущего выражения. Другими словами, <tex>G</tex> и <tex>D</tex> играют в так называемую ''минимаксную игру'', решая следующую задачу оптимизации:
 
Со стороны же генератора требуется научиться "обманывать" дискриминатор, т.е минимизировать по <tex>p_{gen}</tex> второе слагаемое предыдущего выражения. Другими словами, <tex>G</tex> и <tex>D</tex> играют в так называемую ''минимаксную игру'', решая следующую задачу оптимизации:
  
<center> <tex> \min\limits_{G}\max\limits_{D} \mathop{E}\limits_{x \sim p_{data}}[logD(x)]  + \mathop{E}\limits_{z \sim p_{z}}[log(1-D(G(z))]  </tex> </center>  
+
<center> <tex> \min\limits_{G}\max\limits_{D} \mathop{E}\limits_{x \sim p_{data}}[logD(x)]  + \mathop{E}\limits_{z \sim p_{z}}[log(1-D(G(z))]  </tex> </center>,
  
 
Теоретическое обоснование того, что такой метод заставляет <tex>p_{gen}</tex> сходится к <tex>p_{data}</tex> описано в исходной статье. <ref> [https://arxiv.org/pdf/1406.2661.pdf  Ian J. Goodfellow {{---}} Generative Adversarial Nets]</ref>
 
Теоретическое обоснование того, что такой метод заставляет <tex>p_{gen}</tex> сходится к <tex>p_{data}</tex> описано в исходной статье. <ref> [https://arxiv.org/pdf/1406.2661.pdf  Ian J. Goodfellow {{---}} Generative Adversarial Nets]</ref>
  
 
==Оригинальный алгоритм обучения GAN==
 
==Оригинальный алгоритм обучения GAN==
 +
[[File:Обучение_ган.png|450px|thumb|right|Визуализация генерирования фотографии с помощью DCGAN по одному и тому же шуму в зависимости от итерации обучения. Источник: https://arxiv.org/pdf/1701.07875.pdf]]
 +
 
В процессе обучения требуется делать два шага оптимизации поочередно: сначала обновлять веса генератора <tex>\theta</tex> при фиксированном <tex>\gamma</tex>, а затем веса дискриминатора <tex>\gamma</tex> при фиксированном <tex>\theta</tex>. На практике дискриминатор обновляется <tex>k</tex> раз вместо одного; <tex>k</tex> является гиперпараметром.
 
В процессе обучения требуется делать два шага оптимизации поочередно: сначала обновлять веса генератора <tex>\theta</tex> при фиксированном <tex>\gamma</tex>, а затем веса дискриминатора <tex>\gamma</tex> при фиксированном <tex>\theta</tex>. На практике дискриминатор обновляется <tex>k</tex> раз вместо одного; <tex>k</tex> является гиперпараметром.
  
 
  <font color=green>// num_iteration {{---}} число итераций обучения </font>
 
  <font color=green>// num_iteration {{---}} число итераций обучения </font>
  '''for''' i = 1..num_iteration '''do'''
+
  '''function''' GAN:
  '''for''' j = 1..k '''do'''
+
  '''for''' i = 1..num_iteration '''do'''
    Сэмплируем мини-батч $\{z_1, . . . , z_m\}$ из распределения $p_z$.
+
    '''for''' j = 1..k '''do'''
    Сэмплируем мини-батч $\{x_1, . . . , x_m\}$ из распределения $p_{data}$.
+
      <font color=green>//Получаем мини-батч $\{z_1, . . . , z_m\}$ из распределения $p_z$</font>
    Обновляем дискриминатор в сторону возрастания его градиента:
+
      $z$ = getBatchFromNoisePrior($p_z$) 
    <tex>\mathop{\nabla}_{\gamma} { \frac{1}{m} \sum_{t = 1}^m \limits} [logD(x_t)]  + [log(1-D(G(z_t))] </tex>
+
      <font color=green>//Получаем мини-батч $\{x_1, . . . , x_m\}$ из распределения $p_{data}$ </font>
  '''end''' '''for'''  
+
      $x$ = getBatchFromDataGeneratingDistribution($p_{data}$)
  Сэмплируем мини-батч $\{z_1, . . . , z_m\}$ из распределения $p_z$
+
      <font color=green>//Обновляем дискриминатор в сторону возрастания его градиента</font>
  Обновляем генератор в сторону убывания его градиента:
+
      <tex>d_w \leftarrow \mathop{\nabla}_{\gamma} { \frac{1}{m} \sum_{t = 1}^m \limits} [logD(x_t)]  + [log(1-D(G(z_t))] </tex>
  <tex>\mathop{\nabla}_{\theta}  { \frac{1}{m} \sum_{t = 1}^m \limits} [log(1-D(G(z_t))] </tex>
+
    '''end''' '''for'''
'''end''' '''for'''
+
    <font color=green>//Получаем мини-батч $\{z_1, . . . , z_m\}$ из распределения $p_z$ </font>
 
+
    $z$ = getBatchFromNoisePrior($p_z$)
Обновления на основе градиента могут быть сделаны любым стандартным способом, например, [[:Cтохастический градиентный спуск|стохастическим градиентным спуском]] (SGD). В оригинальной статье использовался SGD с импульсом.
+
    <font color=green>//Обновляем генератор в сторону убывания его градиента </font>
 +
    <tex>g_w \leftarrow \mathop{\nabla}_{\theta}  { \frac{1}{m} \sum_{t = 1}^m \limits} [log(1-D(G(z_t))] </tex>
 +
  '''end''' '''for'''
 +
Обновления на основе градиента могут быть сделаны любым стандартным способом, например, в оригинальной статье использовался [[:Cтохастический градиентный спуск|стохастический градиентный спуск]]<sup>[на 28.01.19 не создан]</sup> с импульсом.
  
 
==Улучшение обучения GAN==
 
==Улучшение обучения GAN==
  
 
Большинство GAN'ов подвержено следующим проблемам:
 
Большинство GAN'ов подвержено следующим проблемам:
* Несходимость (non-convergence): параметры модели дестабилизируются и не сходятся,
+
* Несходимость (non-convergence): параметры модели дестабилизируются и не сходятся;
* Схлопывание мод распределения (mode collapse): генератор коллапсирует, т.е выдает ограниченное количество разных образцов,
+
* Схлопывание мод распределения (mode collapse): генератор коллапсирует, т.е выдает ограниченное количество разных образцов;
* Исчезающий градиент (diminished gradient): дискриминатор становится слишком "сильным", а градиент генератора исчезает и обучение не происходит,
+
* Исчезающий градиент (diminished gradient): дискриминатор становится слишком "сильным", а градиент генератора исчезает и обучение не происходит;
 
* Высокая чувствительность к гиперпараметрам.
 
* Высокая чувствительность к гиперпараметрам.
  
 
Универсального подхода к их решению нет, но существуют практические советы<ref> [https://github.com/soumith/ganhacks  How to Train a GAN? Tips and tricks to make GANs work]</ref>, которые могут помочь. Основными из них являются:
 
Универсального подхода к их решению нет, но существуют практические советы<ref> [https://github.com/soumith/ganhacks  How to Train a GAN? Tips and tricks to make GANs work]</ref>, которые могут помочь. Основными из них являются:
# Нормализация данных. Все признаки в диапазоне $[-1; 1]$.
+
# Нормализация данных. Все признаки в диапазоне $[-1; 1]$;
# Замена функции ошибки для $G$ с $\min log (1-D)$ на $\max log D$, потому что исходный вариант имеет маленький градиент на раннем этапе обучения и большой градиент при сходимости, а предложенный наоборот.
+
# Замена функции ошибки для $G$ с $\min log (1-D)$ на $\max log D$, потому что исходный вариант имеет маленький градиент на раннем этапе обучения и большой градиент при сходимости, а предложенный наоборот;
# Сэмплирование из многомерного нормального распределения вместо равномерного.  
+
# Сэмплирование из многомерного нормального распределения вместо равномерного;  
# Использовать нормализационные слои (например, batch normalization или layer normalization) в $G$ и $D$.
+
# Использовать нормализационные слои (например, batch normalization или layer normalization) в $G$ и $D$;
# Использовать метки для данных, если они имеются, т.е обучать дискриминатор еще и классифицировать образцы.<ref> [https://arxiv.org/pdf/1610.09585.pdf Augustus Odena {{---}} Conditional Image Synthesis with Auxiliary Classifier GANs]</ref>
+
# Использовать метки для данных, если они имеются, т.е обучать дискриминатор еще и классифицировать образцы.
  
 
==Применение==
 
==Применение==
Чаще всего GAN'ы используются для генерации реалистичных изображений, однако существуют достаточно необычные применения, дающие впечатляющие результаты. Рассмотрим несколько из них:
 
  
* CycleGAN<ref> [https://junyanz.github.io/CycleGAN/ Jun-Yan Zhu & Taesung Park {{---}} Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks]</ref>: меняет изображения из одного домена в другой. Например, меняет на фотографии лошадей на зебр,
+
[[File:прогресс_ганов.jpg|450px|thumb|right|Прогресс в генерации фотографий с помощью GAN. Источник: https://twitter.com/goodfellow_ian]]
* SRGAN<ref> [https://arxiv.org/abs/1609.04802 Christian Ledig {{---}} Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network]</ref>: создает изображения с высоким разрешением из более низкого разрешения,
+
 
* Pix2Pix<ref> [https://phillipi.github.io/pix2pix/ Phillip Isola {{---}} Image-to-Image Translation with Conditional Adversarial Nets]</ref>: создает изображения по семантической окраске,
+
Чаще всего GAN'ы используются для генерации реалистичных фотографий. Серьезные улучшения в этом направлении были сделаны следующими работами:
* StackGAN<ref> [https://arxiv.org/abs/1612.03242 Han Zhang {{---}} StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks]</ref>: создает изображения по заданному тексту,
+
 
 +
* Auxiliary GAN<ref> [https://arxiv.org/pdf/1610.09585.pdf Augustus Odena {{---}} Conditional Image Synthesis with Auxiliary Classifier GANs]</ref>: вариант GAN-архитектуры, использующий метки данных;
 +
* SN-GAN<ref> [https://arxiv.org/pdf/1802.05957.pdf Takeru Miyato {{---}} SPECTRAL NORMALIZATION FOR GENERATIVE ADVERSARIAL NETWORKS]</ref>: GAN с новым подходом решения проблемы нестабильного обучения через спектральную нормализацию;
 +
* SAGAN<ref> [https://arxiv.org/pdf/1805.08318.pdf Han Zhang {{---}} Self-Attention Generative Adversarial Networks]</ref>: GAN, основанный на механизме внимания;
 +
* BigGAN<ref> [https://arxiv.org/pdf/1809.11096.pdf Andrew Brock {{---}} LARGE SCALE GAN TRAINING FOR HIGH FIDELITY NATURAL IMAGE SYNTHESIS]</ref>: GAN с ортогональной регуляризацией, позволившей разрешить проблему коллапсирования при долгом обучении;
 +
 
 +
Кроме простой генерации изображений, существуют достаточно необычные применения, дающие впечатляющие результаты не только на картинках, но и на звуке:
 +
 
 +
* CycleGAN<ref> [https://junyanz.github.io/CycleGAN/ Jun-Yan Zhu & Taesung Park {{---}} Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks]</ref>: меняет изображения c одного домена на другой, например, лошадей на зебр;
 +
* SRGAN<ref> [https://arxiv.org/abs/1609.04802 Christian Ledig {{---}} Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network]</ref>: создает изображения с высоким разрешением из более низкого разрешения;
 +
* Pix2Pix<ref> [https://phillipi.github.io/pix2pix/ Phillip Isola {{---}} Image-to-Image Translation with Conditional Adversarial Nets]</ref>: создает изображения по семантической окраске;
 +
* StackGAN<ref> [https://arxiv.org/abs/1612.03242 Han Zhang {{---}} StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks]</ref>: создает изображения по заданному тексту;
 
* MidiNet<ref> [https://arxiv.org/abs/1703.10847 Li-Chia Yang {{---}} MIDINET: A CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORK FOR SYMBOLIC-DOMAIN MUSIC GENERATION]</ref>: генерирует последовательность нот, таким образом, создает мелодию.
 
* MidiNet<ref> [https://arxiv.org/abs/1703.10847 Li-Chia Yang {{---}} MIDINET: A CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORK FOR SYMBOLIC-DOMAIN MUSIC GENERATION]</ref>: генерирует последовательность нот, таким образом, создает мелодию.
 +
 +
==CGAN (Conditional Generative Adversarial Nets)==
 +
 +
[[File:CGAN_architecture.png|450px|thumb|Архитектура CGAN. Источник: https://arxiv.org/pdf/1411.1784.pdf]]
 +
 +
'''Условные порождающие состязательные сети''' (англ. ''Conditional Generative Adversarial Nets, CGAN'') $-$ это модифицированная версия алгоритма GAN, которая позволяет
 +
генерировать объекты с дополнительными условиями '''y'''. '''y''' может быть любой дополнительной информацией, например, меткой класса или данными из других моделей. Добавление данных условий в существующую архитектуру осуществляется с помощью расширения вектором '''y''' входных данных генератора и дискриминатора.
 +
 +
В таком случае задача оптимизации будет выглядеть следующим образом:
 +
 +
<center> <tex> \min\limits_{G}\max\limits_{D} \mathop{E}\limits_{x \sim p_{data}}[logD(x|y)]  + \mathop{E}\limits_{z \sim p_{z}}[log(1-D(G(z|y))]  </tex> </center>
 +
 +
В качестве примера использования данного алгоритма можно рассмотреть задачу генерации рукописных цифр. ''CGAN'' был натренирован на датасете ''MNIST'' с метками классов представленных в виде ''one-hot'' векторов.
 +
 +
[[File:CGAN_generated.png|450px|thumb|Цифры, сгенерированные с помощью CGAN. Источник: https://arxiv.org/pdf/1411.1784.pdf]]
 +
 +
==DCGAN (Deep Convolutional Generative Adversarial Nets)==
 +
 +
[[File:DCGAN_generator.png|450px|thumb|Архитектура генератора в DCGAN. Источник: https://arxiv.org/pdf/1511.06434.pdf]]
 +
 +
'''DCGAN''' $-$ модификация алгоритма ''GAN'', основными архитектурными изменениями которой являются:
 +
* Замена всех пулинговых слоев на страйдинговые свертки (''strided convolutions'') в дискриминаторе и частично-страйдинговые свертки (''fractional-strided''
 +
''convolutions'') в генераторе;
 +
* Использование батчинговой нормализации для генератора и дискриминатора;
 +
* Удаление всех полносвязных скрытых уровней для более глубоких архитектур;
 +
* Использование ''ReLU'' в качестве функции активации в генераторе для всех слоев, кроме последнего, где используется ''tanh'';
 +
* Использование ''LeakyReLU'' в качестве функции активации в дискриминаторе для всех слоев.
 +
 +
Помимо задачи генерации объектов, данный алгоритм хорошо показывает себя в качестве ''feature extractor'''а.
 +
Данный алгоритм был натренирован на датасете ''Imagenet-1k'', после чего были использованы значения со сверточных слоев дискриминатора, подвергнутые ''max-pooling'''у, чтобы образовать матрицы
 +
<tex> 4 \times 4 </tex> и получить общий вектор признаков на их основе. ''L2-SVM'' с данным ''feature extractor'''ом на датасете ''CIFAR-10'' превосходит по точности решения, основанные на алгоритме
 +
''K-Means''. Более подробно об этом вы можете прочитать в статье. <ref> [https://arxiv.org/pdf/1511.06434.pdf  Alec Radford, Luke Metz, Soumith Chintala {{---}} Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks]</ref>
 +
  
 
==См. также==
 
==См. также==
*[[:Порождающие модели|Порождающие модели]]
+
*[[:Порождающие модели|Порождающие модели]]<sup>[на 28.01.19 не создан]</sup>
*[[:Variational autoencoder (VAE)|Variational autoencoder (VAE)]]
+
*[[:Variational autoencoder (VAE)|Variational autoencoder (VAE)]]<sup>[на 28.01.19 не создан]</sup>
 +
 
 
==Примечания==
 
==Примечания==
 
<references/>
 
<references/>
Строка 69: Строка 119:
 
* Сергей Николенко, Артур Кадурин, Екатерина Архангельская. Глубокое обучение. Погружение в мир нейронных сетей. — «Питер», 2018. — С. 348-360.
 
* Сергей Николенко, Артур Кадурин, Екатерина Архангельская. Глубокое обучение. Погружение в мир нейронных сетей. — «Питер», 2018. — С. 348-360.
 
* [https://medium.com/@jonathan_hui/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b Medium | GAN — Why it is so hard to train Generative Adversarial Networks! ]
 
* [https://medium.com/@jonathan_hui/gan-why-it-is-so-hard-to-train-generative-advisory-networks-819a86b3750b Medium | GAN — Why it is so hard to train Generative Adversarial Networks! ]
 +
* [https://arxiv.org/pdf/1411.1784.pdf CGAN Paper]
 +
* [https://arxiv.org/pdf/1511.06434.pdf DCGAN Paper]
 
[[Категория: Машинное обучение]]
 
[[Категория: Машинное обучение]]
 
[[Категория: Порождающие модели]]
 
[[Категория: Порождающие модели]]

Версия 17:16, 25 ноября 2019

Оригинальная архитектура GAN

Порождающие состязательные сети (англ. Generative Adversarial Nets, GAN) — алгоритм машинного обучения, входящий в семейство порождающих моделей[на 28.01.19 не создан] и построенный на комбинации из двух нейронных сетей, одна из которых генерирует образцы, другая пытается отличить настоящие образцы от сгенерированных. Впервые такие сети были представлены Иэном Гудфеллоу в 2014 году.

Постановка задачи и метод

Имеется множество образцов [math]X[/math] из распределения [math]p_{data}[/math], заданного на [math] \mathbb R^n [/math], а также некоторое пространство латентных факторов [math]Z[/math] из распределения [math]p_{z}[/math], например, случайные вектора из равномерного распределения [math] \mathbb U^t(0,1) [/math].

Рассмотрим две нейронные сети: первая $-$ генератор [math] G: Z \rightarrow \mathbb R^n [/math] с параметрами [math]\theta[/math], цель которой сгенерировать похожий образец из [math]p_{data}[/math], и вторая $-$ дискриминатор [math]D: \mathbb R^n \rightarrow \mathbb [0,1] [/math] с параметрами [math]\gamma[/math], цель которой выдавать максимальную оценку на образцах из [math]X[/math] и минимальную на сгенерированных образцах из [math]G[/math]. Распределение, порождаемое генератором будем обозначать [math]p_{gen}[/math]. Так же заметим, что в текущем изложении не принципиальны архитектуры нейронных сетей, поэтому можно считать, что параметры [math]\theta[/math] и [math]\gamma[/math] являются просто параметрами многослойных персептронов.

В качестве примера можно рассматривать генерацию реалистичных фотографий: в этом случае, входом для генератора может быть случайный многомерный шум, а выходом генератора (и входом для дискриминатора) RGB-изображение; выходом же для дискриминатора будет вероятность, что фотография настоящая, т.е число от 0 до 1.

Наша задача выучить распределение [math]p_{gen}[/math] так, чтобы оно как можно лучше описывало [math]p_{data}[/math]. Зададим функцию ошибки для получившейся модели. Со стороны дискриминатора мы хотим распознавать образцы из [math]X[/math] как правильные, т.е в сторону единицы, и образцы из [math]G[/math] как неправильные, т.е в сторону нуля, таким образом нужно максимизировать следующую величину:

[math]\mathop{E}\limits_{x \sim p_{data}}[logD(x)] + \mathop{E}\limits_{x \sim p_{gen}}[log(1-D(x))][/math], где [math]\mathop{E}\limits_{x \sim p_{gen}}[log(1-D(x))] = \mathop{E}\limits_{z \sim p_{z}}[log(1-D(G(z))][/math]
,

Со стороны же генератора требуется научиться "обманывать" дискриминатор, т.е минимизировать по [math]p_{gen}[/math] второе слагаемое предыдущего выражения. Другими словами, [math]G[/math] и [math]D[/math] играют в так называемую минимаксную игру, решая следующую задачу оптимизации:

[math] \min\limits_{G}\max\limits_{D} \mathop{E}\limits_{x \sim p_{data}}[logD(x)] + \mathop{E}\limits_{z \sim p_{z}}[log(1-D(G(z))] [/math]
,

Теоретическое обоснование того, что такой метод заставляет [math]p_{gen}[/math] сходится к [math]p_{data}[/math] описано в исходной статье. [1]

Оригинальный алгоритм обучения GAN

Визуализация генерирования фотографии с помощью DCGAN по одному и тому же шуму в зависимости от итерации обучения. Источник: https://arxiv.org/pdf/1701.07875.pdf

В процессе обучения требуется делать два шага оптимизации поочередно: сначала обновлять веса генератора [math]\theta[/math] при фиксированном [math]\gamma[/math], а затем веса дискриминатора [math]\gamma[/math] при фиксированном [math]\theta[/math]. На практике дискриминатор обновляется [math]k[/math] раз вместо одного; [math]k[/math] является гиперпараметром.

// num_iteration — число итераций обучения 
function GAN:
  for i = 1..num_iteration do
    for j = 1..k do
      //Получаем мини-батч $\{z_1, . . . , z_m\}$ из распределения $p_z$
      $z$ = getBatchFromNoisePrior($p_z$)  
      //Получаем мини-батч $\{x_1, . . . , x_m\}$ из распределения $p_{data}$ 
      $x$ = getBatchFromDataGeneratingDistribution($p_{data}$)
      //Обновляем дискриминатор в сторону возрастания его градиента
      [math]d_w \leftarrow \mathop{\nabla}_{\gamma} { \frac{1}{m} \sum_{t = 1}^m \limits} [logD(x_t)]  + [log(1-D(G(z_t))] [/math]
    end for
    //Получаем мини-батч $\{z_1, . . . , z_m\}$ из распределения $p_z$ 
    $z$ = getBatchFromNoisePrior($p_z$)
    //Обновляем генератор в сторону убывания его градиента 
    [math]g_w \leftarrow \mathop{\nabla}_{\theta}  { \frac{1}{m} \sum_{t = 1}^m \limits} [log(1-D(G(z_t))] [/math]
  end for

Обновления на основе градиента могут быть сделаны любым стандартным способом, например, в оригинальной статье использовался стохастический градиентный спуск[на 28.01.19 не создан] с импульсом.

Улучшение обучения GAN

Большинство GAN'ов подвержено следующим проблемам:

  • Несходимость (non-convergence): параметры модели дестабилизируются и не сходятся;
  • Схлопывание мод распределения (mode collapse): генератор коллапсирует, т.е выдает ограниченное количество разных образцов;
  • Исчезающий градиент (diminished gradient): дискриминатор становится слишком "сильным", а градиент генератора исчезает и обучение не происходит;
  • Высокая чувствительность к гиперпараметрам.

Универсального подхода к их решению нет, но существуют практические советы[2], которые могут помочь. Основными из них являются:

  1. Нормализация данных. Все признаки в диапазоне $[-1; 1]$;
  2. Замена функции ошибки для $G$ с $\min log (1-D)$ на $\max log D$, потому что исходный вариант имеет маленький градиент на раннем этапе обучения и большой градиент при сходимости, а предложенный наоборот;
  3. Сэмплирование из многомерного нормального распределения вместо равномерного;
  4. Использовать нормализационные слои (например, batch normalization или layer normalization) в $G$ и $D$;
  5. Использовать метки для данных, если они имеются, т.е обучать дискриминатор еще и классифицировать образцы.

Применение

Прогресс в генерации фотографий с помощью GAN. Источник: https://twitter.com/goodfellow_ian

Чаще всего GAN'ы используются для генерации реалистичных фотографий. Серьезные улучшения в этом направлении были сделаны следующими работами:

  • Auxiliary GAN[3]: вариант GAN-архитектуры, использующий метки данных;
  • SN-GAN[4]: GAN с новым подходом решения проблемы нестабильного обучения через спектральную нормализацию;
  • SAGAN[5]: GAN, основанный на механизме внимания;
  • BigGAN[6]: GAN с ортогональной регуляризацией, позволившей разрешить проблему коллапсирования при долгом обучении;

Кроме простой генерации изображений, существуют достаточно необычные применения, дающие впечатляющие результаты не только на картинках, но и на звуке:

  • CycleGAN[7]: меняет изображения c одного домена на другой, например, лошадей на зебр;
  • SRGAN[8]: создает изображения с высоким разрешением из более низкого разрешения;
  • Pix2Pix[9]: создает изображения по семантической окраске;
  • StackGAN[10]: создает изображения по заданному тексту;
  • MidiNet[11]: генерирует последовательность нот, таким образом, создает мелодию.

CGAN (Conditional Generative Adversarial Nets)

Архитектура CGAN. Источник: https://arxiv.org/pdf/1411.1784.pdf

Условные порождающие состязательные сети (англ. Conditional Generative Adversarial Nets, CGAN) $-$ это модифицированная версия алгоритма GAN, которая позволяет генерировать объекты с дополнительными условиями y. y может быть любой дополнительной информацией, например, меткой класса или данными из других моделей. Добавление данных условий в существующую архитектуру осуществляется с помощью расширения вектором y входных данных генератора и дискриминатора.

В таком случае задача оптимизации будет выглядеть следующим образом:

[math] \min\limits_{G}\max\limits_{D} \mathop{E}\limits_{x \sim p_{data}}[logD(x|y)] + \mathop{E}\limits_{z \sim p_{z}}[log(1-D(G(z|y))] [/math]

В качестве примера использования данного алгоритма можно рассмотреть задачу генерации рукописных цифр. CGAN был натренирован на датасете MNIST с метками классов представленных в виде one-hot векторов.

Цифры, сгенерированные с помощью CGAN. Источник: https://arxiv.org/pdf/1411.1784.pdf

DCGAN (Deep Convolutional Generative Adversarial Nets)

Архитектура генератора в DCGAN. Источник: https://arxiv.org/pdf/1511.06434.pdf

DCGAN $-$ модификация алгоритма GAN, основными архитектурными изменениями которой являются:

  • Замена всех пулинговых слоев на страйдинговые свертки (strided convolutions) в дискриминаторе и частично-страйдинговые свертки (fractional-strided

convolutions) в генераторе;

  • Использование батчинговой нормализации для генератора и дискриминатора;
  • Удаление всех полносвязных скрытых уровней для более глубоких архитектур;
  • Использование ReLU в качестве функции активации в генераторе для всех слоев, кроме последнего, где используется tanh;
  • Использование LeakyReLU в качестве функции активации в дискриминаторе для всех слоев.

Помимо задачи генерации объектов, данный алгоритм хорошо показывает себя в качестве feature extractor'а. Данный алгоритм был натренирован на датасете Imagenet-1k, после чего были использованы значения со сверточных слоев дискриминатора, подвергнутые max-pooling'у, чтобы образовать матрицы [math] 4 \times 4 [/math] и получить общий вектор признаков на их основе. L2-SVM с данным feature extractor'ом на датасете CIFAR-10 превосходит по точности решения, основанные на алгоритме K-Means. Более подробно об этом вы можете прочитать в статье. [12]


См. также

Примечания

Источники информации