Вариационный автокодировщик
Вариационный автокодировщик (англ. Variational Autoencoder, VAE) — это автокодировщик (a.k.a. генеративная модель, которая учится отображать объекты в заданное скрытое пространство (и обратно)) основанный на вариационном выводе.
Содержание
Предпосылки
При попытке использования обыкновенного автокодировщика для генерации новых объектов (желательно из того же априорного распределения, что и датасет) возникает следующая проблема. Случайной величиной с каким распределением проинициализировать скрытые векторы, для того, чтобы картинка, после применения декодера, стала похожа на картинки из датасета, но при этом не совпадала ни с одной из них? Ответ на этот вопрос не ясен, в связи с тем, что обыкновенный автокодировщик не может ничего утверждать про распределение скрытого вектора и даже про его область определения. В частности, область определения может быть даже дискретной.
Вариационный автокодировщик в свою очередь предлагает пользователю самому определить распределение скрытого вектора.
Описание
Генеративное моделирование (англ. Generative modelling) — область машинного обучения, имеющая дело с распределением P(X), определенном на датасете X из пространства (возможно многомерного)
. Так, например, популярные задачи генерации картинок имеют дело с огромным количеством измерений (пикселей).Также как и в обыкновенных кодировщиках у нас имеется скрытое вероятностное пространство Z соответствующее случайной величине (z, P(z))(распределенной как-нибудь фиксированно, здесь ~N(0, 1)). И мы хотим иметь декодер
. При этом мы хотим найти такие , чтобы после разыгрывания z по P(z) мы получили <<что-то похожее>> на элементы X. Вообще, мы хотим, чтобы для любого мы хотим считать здесь мы заменили на , чтобы явно сделать зависимость между x и z и после этого применить формулу полной вероятности. Обычно около нуля почти для всех пар (x, z). Основная идея в том, что мы хотим теперь генерировать z, который бы давали что-то около x и только их суммировать в P(x). Для этого нам требуется ввести еще одно распределение Q(z|X), которое будет получать x и говорить распределение на z которое наиболее вероятно будет генерировать нам такой x. Теперь нам нужно как-то сделать похожими распределения и P(X). Рассмотрим следующую дивергенцию Кульбака-Лейблера
Распишем P(z|X) как P(X|z) * P(z) / P(X):
Что эквивалентно:
Рассмотрим эту штуку для Q(z|X), тогда:
Посмотрим, на это равенство. Правую часть мы можем оптимизировать градиентным спуском (пусть пока и не совсем понятно как). В левой же части первое слагаемое -- то, что мы хотим максимизировать. В то же время
мы хотим минимизировать. Если у нас Q(z|X) -- достаточно сильная модель, то в какой-то модель она будет хорошо матчить P(z|X), а значит их дивергенция Кульбака Лейблера будет почти 0. А значит на это слагаемое можно забить. И стараться максимизировать правую часть. В качестве бонуса мы еще получили более "поддатливую" P(z|X), вместо нее можно смотреть на Q(z|X).Теперь разберемся как оптимизировать правую часть. Сначала нужно определиться с
обычно ее берут равной . Где и какие-то детерминированные функции на X с обучаемыми параметрами , которые мы впредь будем опускать (ага, нейронки)).
Нетрудно проверить, что для двух нормальных распределений с параметрами , KLD есть .
Это значит, что
. Теперь здесь можно считать градиенты, для BackProp-a. С первым слагаемым в правой части, все немного сложнее. E_{z∼Q}[log P(X|z)] мы можем считать методом монте-карло(МК), но тогда такая штука (из-за того, что переменные спрятаны в распределении, из которого мы генерируем себе выборку, для МК) не является гладкой, относительно их, а значит непонятно как проталкивать через это градиент. Для того, чтобы все-таки можно было протолкнуть градиент, применяется так называемый reparametrization trick, который базируется на простой формуле . . В такой форме мы уже можем использовать BackPropagation для переменных из функций и .Следующая картинка лучше поможет осознать структуру VAE и, в частности, зачем нужен (и как работает) reparametrization trick.
Пример реализации
Ниже приведена реализация частного случая VAE на языке Python с использованием библиотеки Pytorch. Эта реализация работает с датасетом MNIST. Размерность скрытого слоя — 2. Координаты в нем считаются независимыми (из-за этого, например, матрица
диагональная, и формула для расчета KLD немного другая).class VariationalAutoencoder(nn.Module): def __init__(self): super().__init__() self.mu = nn.Linear(32, 2) self.gamma = nn.Linear(32, 2) self.encoder = nn.Sequential(nn.Linear(784, 32), nn.ReLU(True)) self.decoder = nn.Sequential(nn.Linear(2, 32), nn.ReLU(True), nn.Linear(32, 784), nn.Sigmoid()) def forward(self, x): mu, gamma = self.encode(x) encoding = self.reparameterize(mu, gamma) x = self.decoder(encoding) return x, mu, gamma def reparameterize(self, mu, gamma): if self.training: sigma = torch.exp(0.5*gamma) std_z = Variable(torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()) encoding = std_z.mul(sigma).add(mu) return encoding else: return mu def encode(self, x): x = self.encoder(x) mu = self.mu(x) gamma = self.gamma(x) return mu, gamma def decode(self, x): return self.decoder(x) def latent(self, x): mu, gamma = self.encode(x) encoding = self.reparameterize(mu, gamma) return encoding def loss_function(input, output, mu, gamma, batch_size=batch_size): BCE = F.binary_cross_entropy(output, input) KLD = -0.5*torch.sum(1 + gamma - mu.pow(2) - gamma.exp()) KLD /= batch_size*784 return BCE + KLD
Применение
Область применения вариационных автокодировщиков совпадает с областью применения обыкновенных автокодировщиков. А именно:
- Каскадное обучение глубоких сетей (хотя сейчас применяется все реже, в связи с появлением новых методов инициализации весов)
- Уменьшение шума в данных
- Уменьшение размерности данных (иногда работает лучше, чем метод главных компонент)
Благодаря тому, что пользователь сам устанавливает нужное распределение скрытого вектора, вариационный кодировщик хорошо подходит для генерации новых объектов (например, картинок). Для этого достаточно разыграть скрытый вектор согласно его распределению и скормить ее в декодер. Получится объект из того же распределения, что и датасет.