Изменения

Перейти к: навигация, поиск

PixelRNN и PixelCNN

4265 байт добавлено, 14:05, 31 марта 2020
Тире
[[File:pixel-1.png|450px|thumb|Рисунок 1. Пример использования PixelRNN/PixelCNN сетей]]
'''''PixelRNN/''''' и '''''PixelCNN''' '' {{-- -}} алгоритмы машинного обучения, входящие в семейство авторегрессивных моделей. Используются для генерации и дополнения изображений. Алгоритмы были представлены в 2016 году компанией ''DeepMind ''<ref name=PixelNet>[https://arxiv.org/abs/1601.06759 Pixel Recurrent Neural Networks]</ref> и являются предшественниками алгоритма ''WaveNet''<ref name=WaveNet>[https://deepmind.com/blog/article/wavenet-generative-model-raw-audio WaveNet: A generative model for raw audio]</ref>, который используется в голосовом помощнике ''Google''.
Основным преимуществом ''PixelRNN/'' и ''PixelCNN '' является уменьшение времени обучения, по сравнению с наивными способами попиксельной генерации изображений.
== Постановка задачи ==
Пусть дано черно-белое изображение <tex>X</tex> размером <tex>N\times N</tex>. Построчно преобразуем картинку в вектор <tex>V_X = \{x_1, x_2, \dots, x_{N^2} \}</tex>, соединяя конец текущей строки с началом следующей. В таком представлении изображения можно предположить, что значение любого пикселя <tex>x_i\in V_X</tex> может зависеть от значений предыдущих пикселей <tex>x_j, j = 1,2,\dots i-1</tex>.
Тогда значение пикселя <tex>x_i\in V_X</tex> можно выразить через условную вероятность <tex>p(x_i|x_1, x_2, \dots x_{i-1})</tex>, и, используя цепное правило для вероятностей<ref name=ChainRule>[https://en.wikipedia.org/wiki/Chain_rule_(probability) Chain rule (probability)]</ref>, оценка совместного распределения всех пикселей будет записываться в следующем виде: <tex>p(X)=\prod_{i=1}^{N^2}p(x_i|x_1, x_2, \dots x_{i-1})</tex>.
Задача алгоритма - восстановить данное распределение. Учитывая тот факт, что любой пиксель принимает значение <tex>0<=x_i<=255</tex>, необходимо восстановить лишь дискретное распределение.
== Идея ==
Т.к. утверждается, что значение текущего пикселя зависит от значений предыдущего, то уместно использовать [[:Рекуррентные_нейронные_сети|''RNN'']], а точнее [[Долгая краткосрочная память|''LSTM'']]. В ранних работах <ref name=SpatialLSTM>[https://arxiv.org/abs/1506.03478 Generative Image Modeling Using Spatial LSTMs]</ref> уже использовался данный подход, и вычисление скрытого состояния происходило следующим образом: <tex>h_{i,j}=f(h_{i-1,j}, h_{i,j-1}, x_{i,j})</tex>, т.е. для того, чтобы вычислить текущее скрытое состояние, нужно было подсчитать все предыдущие, что занимает достаточно много времени.
Авторы алгоритма модернизировали [[Долгая краткосрочная память|''LSTM'']] в '''''RowLSTM''''' и '''''Diagonal BiLSTM''''' таким образом, чтобы стало возможным распараллеливание вычислений, что в итоге положительно сказывается на времени обучения модели.
=== RowLSTM ===
[[File:pixel-2.png|350px|thumb|Рисунок 2. Визуализация работы модификаций ''LSTM''. Снизу кружками обозначены пиксели, сверху - состояния на каждом пикселе. Синим обозначено то, что влияет на текущее скрытое состояние. Пустые кружки не принимают участие в вычислениях для данного скрытого состояния]]В данной модификации [[Долгая краткосрочная память|''LSTM'']] предлагается рассчитывать скрытое состояние следующим образом: <tex>h_{i,j}=f(h_{i-1,j-1}, h_{i-1,j}, h_{i-1,j+1}, x_{i,j})</tex>.
Как видно из формулы и Рисунка 2, значение текущего скрытого состояния не зависит от предыдущего слева, а зависит от предыдущих сверху, которые можно параллельно рассчитать.
Из плюсов данного алгоритма можно отметить его быстродействие {{- --}} модель обучается быстрее, нежели наивный [[Долгая краткосрочная память|''LSTM'']]. Из минусов - относительно плохое качество получаемых изображений. Это связанно как минимум с тем, что мы используем контекст пикселей с предыдущей строки, но никак не используем контекст соседнего слева пикселя, которые является достаточно важным, т.к. является ближайшим с точки зрения построчной генерации изображения.
Отсюда напрашивается идея каким-то образом найти скрытое состояние пикселя слева, но при этом не потерять в производительности.
=== Diagonal BiLSTM ===
[[File:pixel-3.png|350px|thumb|Рисунок 3. Операция сдвига в ''Diagonal BiLSTM''. Параллелизация происходит по диагоналям.]]В данной версии скрытое состояние считается таким же образом, как и в наивном подходе: <tex>h_{i,j}=f(h_{i-1,j}, h_{i,j-1}, x_{i,j})</tex>, но при этом есть хитрость в самом вычислении. Построчно сдвинем строки вправо на один пиксель относительно предыдущей, а затем вычислим скрытые состояния в каждом столбце, как показано на рисунке Рисунке 3.
Данная версия позволяет учитывать контекст более качественно, но при этом занимает больше времени, чем ''RowLSTM''.
=== PixelCNN ===
Идея в том, что обычно соседние пиксели (в рамках ядра 9x9) хранят самый важный контекст для пикселя. Поэтому предлагается просто использовать известные пиксели для вычисления нового, как показано на рисунке 2.
 
== Архитектура ==
В алгоритмах ''PixelRNN/'' и ''PixelCNN '' используются много несколько архитектурных трюков, позволяющих сделать производить вычисления быстрыми и надежными.
=== Маскированные сверточные слои (Mask) ===В описаниях алгоритмов фигурируют два типа маскированных сверточных слоя {{- --}} '''''MaskA''''', '''''MaskB'''''. Они необходимы для сокрытия от алгоритма лишней информации и учета контекста - чтобы не обрабатывать изображение после каждого подсчета, удаляя значения пикселей, можно применить маску к изображению, что является более быстрой операцией.
Для каждого пикселя в цветном изображении в порядке очереди существуют 3 три контекста: красный канал, зеленый и синий. В данном алгоритме очередь важна, т.е. если сейчас обрабатывается красный канал, то контекст только от предыдущих значений красного канала, если зеленый {{- --}} то от всех значений на красном канале и предыдущих значениях на зеленом и т.д.
'''''MaskA''''' используется для того, чтобы учитывать контекст предыдущих каналов, но при этом не учитывать контекст от предыдущих значений текущего канала и следующих каналов. '''''MaskB''''' выполняет ту же функцию, что и '''''MaskA''''', но при этом учитывает контекст от предыдущих значений текущего канала.
=== Уменьшение размерности ===
[[File:pixel-4.png|350px|thumb|Рисунок 4. Блоки уменьшения размерности. Слева - блок для ''PixelCNN'', справа - ''PixelRNN''. ]]На вход в любой их указанных выше алгоритмов (''PixelCNN'', Row LSTM''RowLSTM'', ''Diagonal BiLSTM'') подается большое количество объектов. Поэтому внутри каждого из них сначала происходит уменьшение их количества в 2 два раза, а затем обратное увеличение в 2 разадо исходного размера. Структура алгоритма с учетом уменьшения размерности показана на рисунке 4.
=== Внутреннее устройство LSTM === Внутреннее устройство ''RowLSTM'' и ''Diagonal BiLSTM'' блоков одинаково, за исключением того, что во втором случае добавляется операция сдвига в начале и возврат к исходной структуре изображения в конце.  Структура ''LSTM'' блока: # ''MaskB'' слой ''input-to-state'' <tex>K_{is}</tex> учитывает контекст из входа. # Сверточный слой ''state-to-state'' <tex>K_{ss}</tex> учитывает контекст из предыдущих скрытых слоев.  Используя эти два сверточных слоя формально вычисление ''LSTM'' блока можно записать следующим образом:  <tex>[o_i, f_i, i_i, g_i] = \sigma (K_{ss}\circledast h_{i-1} + K_{is}\circledast x_{i}) \\c_i=f_i\odot c_{i-1} + i_i\odot g_i\\h_i = o_i\odot tanh(c_i)</tex> где <tex>\sigma</tex> {{---}} функция активации, <tex>\circledast</tex> {{---}} операция свертки, <tex>\odot</tex> {{---}} поэлементное умножение. === Архитектура PixelRNN ===# ''MaskA'' размером <tex>7\times 7</tex> # Блоки уменьшения размеренности с ''RowLSTM'' блоком, в котором <tex>K_{is}</tex> имеет размер <tex>3\times 1</tex>, <tex>K_{ss}</tex> {{---}} <tex>3\times 2</tex>. Для ''Diagonal BiLSTM'' <tex>K_{is}</tex> имеет размер <tex>1\times 1</tex>, <tex>K_{ss}</tex> {{---}} <tex>1\times 2</tex>. Количество блоков варьируется. # ''ReLU'' активация # Сверточный слой размером <tex>1\times 1</tex> # ''Softmax'' слой === Архитектура PixelCNN ===# ''MaskA'' размером <tex>7\times 7</tex> # Блоки уменьшения размеренности для ''PixelCNN''. # ''ReLU'' активация # Сверточный слой размером <tex>1\times 1</tex> # ''Softmax'' слой == Сравнение с GAN подходов =={| class="wikitable"! style="font-weight:bold;" | Критерий\название! style="font-weight:bold;" | PixelCNN! style="font-weight:bold;" | PixelRNN(Row LSTM)! style="font-weight:bold;" | PixelRNN(Diagonal BiLSTM)|-| Время обучения| Быстрый| Средний| Медленный|-| Качество генерируемых изображений| Наихудшее| Средне-низкое| Средне-высокое|}
== Примеры реализации ==
* [https://github.com/singh-hrituraj/PixelCNN-Pytorch PixelCNN на Pytorch]
* [https://github.com/ardapekis/pixel-rnn PixelRNN на Pytorch]
* [https://github.com/shirgur/PixelRNN PixelRNN на Keras]
 
==Примечания==
<references/>
== Источники информации ==
* [https://towardsdatascience.com/auto-regressive-generative-models-pixelrnn-pixelcnn-32d192911173 Auto-Regressive Generative Models]
* [http://slazebni.cs.illinois.edu/spring17/lec13_advanced.pdf Advanced Generation Methods]
* [https://github.com/tensorflow/magenta/blob/master/magenta/reviews/pixelrnn.md Pixel Recurrent Neural Networks]
* [http://bjlkeng.github.io/posts/pixelcnn/ PixelCNN]
[[Категория: Машинное обучение]]
39
правок

Навигация