Изменения

Перейти к: навигация, поиск

Neural Style Transfer

2194 байта добавлено, 04:33, 18 апреля 2019
Пример кода на PyTorch
== Пример кода на PyTorch ==
# Content Loss
class ContentLoss(nn.Module):
    """Transparent layer that records the content loss of the features passing through it.

    Args:
        target: feature tensor the input is compared against. It is stored
            detached, so it is treated as a fixed value rather than as part
            of the autograd graph.

    After every ``forward`` call, ``self.loss`` holds the mean squared error
    between that call's input and ``target``. The input itself is returned
    unchanged.
    """

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # Detach the target content from the graph used to compute gradients
        # dynamically: it is a constant value, not a variable to optimize.
        self.target = target.detach()

    def forward(self, input):
        # Record the loss as a side effect ...
        self.loss = F.mse_loss(input, self.target)
        # ... and pass the input through untouched so the module can be
        # inserted between layers of a network.
        return input
# Style Loss
def gram_matrix(input):
    """Return the normalized Gram matrix of a batch of feature maps.

    Args:
        input: 4-D tensor of size (a, b, c, d), where a is the batch size,
            b the number of feature maps and (c, d) the dimensions of each
            feature map (N = c * d).

    Returns:
        Tensor of shape (a * b, a * b): the Gram product of the flattened
        feature maps, divided by the total number of elements a * b * c * d.
    """
    a, b, c, d = input.size()  # a = batch size (= 1), b = number of feature maps,
    # (c, d) = dimensions of a feature map (N = c * d)

    # Flatten each feature map into one row (F_XL -> \hat F_XL).
    # reshape (unlike view) also accepts non-contiguous tensors.
    features = input.reshape(a * b, c * d)

    G = torch.mm(features, features.t())  # Gram product

    # Normalize the values of the Gram matrix by dividing by the
    # number of elements in the feature maps.
    return G.div(a * b * c * d)
class StyleLoss(nn.Module):
    """Transparent layer that records the style loss of the features passing through it.

    Args:
        target_feature: feature tensor whose Gram matrix (computed once,
            detached) serves as the style target.

    After every ``forward`` call, ``self.loss`` holds the mean squared error
    between the Gram matrix of that call's input and the target Gram matrix.
    The input itself is returned unchanged.
    """

    def __init__(self, target_feature):
        super(StyleLoss, self).__init__()
        # The target Gram matrix is a constant, so keep it out of the graph.
        self.target = gram_matrix(target_feature).detach()

    def forward(self, input):
        G = gram_matrix(input)
        # Compare the Gram matrix of the input (not the raw input) with the
        # target Gram matrix; the scraped text had fused these into "inputG".
        self.loss = F.mse_loss(G, self.target)
        # Pass the input through untouched so the module can sit in a network.
        return input


# Importing the Model
# NOTE(review): `models` and `device` are not defined in the visible text;
# newer torchvision releases also deprecate `pretrained=True` in favour of
# the `weights=` argument — confirm against the installed version.
cnn = models.vgg19(pretrained=True).features.to(device).eval()

# Split the dataset into training and test sets:
# NOTE(review): `diabetes_X`, `diabetes` and `digits` are not defined anywhere
# in the visible text; this code looks like a leftover from another article.
# Split the data into training/testing sets
x_train = diabetes_X[:-20]
x_test = diabetes_X[-20:]
# Split the targets into training/testing sets
y_train = diabetes.target[:-20]
y_test = diabetes.target[-20:]

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(20, 4))
for index, (image, label) in enumerate(zip(digits.data[0:3], digits.target[0:3])):
    plt.subplot(1, 3, index + 1)
    plt.imshow(np.reshape(image, (8, 8)), cmap=plt.cm.gray)
    plt.title('Training: %i\n' % label, fontsize=20)
==См. также==
Анонимный участник

Навигация