36
правок
Изменения
→Примеры кода
==Примеры кода==
Опишем здесь пример построения сети, опустив построение дерева.
[https://github.com/yc930401/RecNN-pytorch Полный листинг кода для]анализа тональности текста на PyTorch (из статьи [https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf Socher et al.(2013c)])
class RNTN(nn.Module):
def __init__(self, word2index, hidden_size, output_size):
super(RNTN,self).__init__()
# Для рекурсивной нейронной сети обязательно нужно для векторное представление слов
self.word2index = word2index
self.embed = nn.Embedding(len(word2index), hidden_size)
self.V = nn.ParameterList([nn.Parameter(torch.randn(hidden_size * 2, hidden_size * 2)) for _ in range(hidden_size)]) # Тензор
self.W = nn.Parameter(torch.randn(hidden_size * 2, hidden_size))
self.b = nn.Parameter(torch.randn(1, hidden_size)) # bias
self.W_out = nn.Linear(hidden_size, output_size)
<font color="green"># инициализация весов</font>
def init_weight(self):
nn.init.xavier_uniform(self.embed.state_dict()['weight'])
nn.init.xavier_uniform(self.W_out.state_dict()['weight'])
for param in self.V.parameters():
nn.init.xavier_uniform(param)
nn.init.xavier_uniform(self.W)
self.b.data.fill_(0)
def tree_propagation(self, node):
recursive_tensor = OrderedDict()
current = None
if node.isLeaf:
tensor = Variable(LongTensor([self.word2index[node.word]])) if node.word in self.word2index.keys() \
else Variable(LongTensor([self.word2index['<UNK>']]))
current = self.embed(tensor) # 1xD
else:
recursive_tensor.update(self.tree_propagation(node.left))
recursive_tensor.update(self.tree_propagation(node.right))
concated = torch.cat([recursive_tensor[node.left], recursive_tensor[node.right]], 1) # 1x2D
xVx = []
for i, v in enumerate(self.V):
xVx.append(torch.matmul(torch.matmul(concated, v), concated.transpose(0, 1)))
xVx = torch.cat(xVx, 1) # 1xD
Wx = torch.matmul(concated, self.W) # 1xD
current = F.tanh(xVx + Wx + self.b) # 1xD
recursive_tensor[node] = current
return recursive_tensor
def forward(self, Trees, root_only=False):
propagated = []
if not isinstance(Trees, list):
Trees = [Trees]
for Tree in Trees:
recursive_tensor = self.tree_propagation(Tree.root)
if root_only:
recursive_tensor = recursive_tensor[Tree.root]
propagated.append(recursive_tensor)
else:
recursive_tensor = [tensor for node,tensor in recursive_tensor.items()]
propagated.extend(recursive_tensor)
propagated = torch.cat(propagated) # (num_of_node in batch, D)
return F.log_softmax(self.W_out(propagated),1)
*https://github.com/bogatyy/cs224d/tree/master/assignment3
*https://gist.github.com/anj1/504768e05fda49a6e3338e798ae1cddd