class GraphAttentionLayer(nn.Module):
    """Simple dense GAT layer, similar to https://arxiv.org/abs/1710.10903.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        dropout: Dropout probability applied to the attention coefficients
            (only active in training mode).
        alpha: Negative slope of the LeakyReLU applied to the attention logits.
        concat: If True, apply ELU to the aggregated output (used for hidden
            heads whose outputs are concatenated); if False, return the raw
            aggregation.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        """Aggregate neighbour features weighted by masked attention.

        Args:
            h: 2-D node-feature matrix of shape (N, in_features).
            adj: Adjacency matrix broadcastable to (N, N); entry (i, j) > 0
                means node i attends to node j.

        Returns:
            Tensor of shape (N, out_features); ELU-activated when
            ``self.concat`` is True.
        """
        Wh = torch.mm(h, self.W)  # (N, out_features)
        e = self._prepare_attentional_mechanism_input(Wh)  # (N, N) logits

        # Non-edges get a huge negative logit so softmax assigns them ~0
        # weight. A finite value (not -inf) keeps rows with no edges from
        # turning into NaN in float32.
        zero_vec = torch.full_like(e, -9e15)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)

        return F.elu(h_prime) if self.concat else h_prime

    def _prepare_attentional_mechanism_input(self, Wh):
        """Return the (N, N) attention logits LeakyReLU(a^T [Wh_i || Wh_j]).

        Splitting ``a`` into halves a1 (first out_features rows) and a2 (last
        out_features rows) gives a^T [Wh_i || Wh_j] = Wh_i . a1 + Wh_j . a2.
        The full matrix is therefore a broadcast sum of two column vectors.
        This avoids materialising the (N, N, 2 * out_features) tensor of all
        concatenated pairs.
        """
        Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])  # (N, 1): source term
        Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])  # (N, 1): target term
        e = Wh1 + Wh2.T  # broadcast to (N, N); e[i, j] = Wh1[i] + Wh2[j]
        return self.leakyrelu(e)


class GAT(nn.Module):
    """Dense version of GAT.

    ``nheads`` hidden attention heads are applied in parallel and their
    outputs concatenated. A single output head then maps to ``nclass``
    scores, and the result is returned as log-probabilities of shape
    (N, nclass).
    """

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super().__init__()
        self.dropout = dropout

        # Plain list + add_module (rather than nn.ModuleList) keeps the
        # state_dict keys as "attention_<i>.*", so existing checkpoints load.
        self.attentions = [
            GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            for _ in range(nheads)
        ]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        self.out_att = GraphAttentionLayer(
            nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False
        )

    def forward(self, x, adj):
        """Return per-node log-probabilities of shape (N, nclass)."""
        x = F.dropout(x, self.dropout, training=self.training)
        # Concatenate all hidden heads along the feature axis: (N, nhid * nheads).
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)