GNN Study Notes

Notes from the course "GNN: From Beginner to Expert"

2.1 DeepWalk (Code-Application)

DeepWalk: Online Learning of Social Representations (KDD '14)
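
As background, the paper samples truncated random walks from the graph and treats each walk as a "sentence" for a SkipGram model. A minimal sketch of only the walk-sampling step might look like the following. The toy adjacency list `toy_adj` and the helper names `random_walk` and `generate_walks` are hypothetical and chosen for illustration; the karateclub implementation used in these notes may differ in its details.

```python
import random


def random_walk(adj, start, walk_length, rng):
    """Return one uniform random walk with at most `walk_length` nodes."""
    walk = [start]
    while len(walk) < walk_length:
        neighbors = adj[walk[-1]]
        if not neighbors:  # isolated node or dead end: stop early
            break
        walk.append(rng.choice(neighbors))
    return walk


def generate_walks(adj, walks_per_node, walk_length, seed=0):
    """Start `walks_per_node` walks from every node, visiting nodes in shuffled order."""
    rng = random.Random(seed)
    nodes = list(adj)
    walks = []
    for _ in range(walks_per_node):
        rng.shuffle(nodes)
        for node in nodes:
            walks.append(random_walk(adj, node, walk_length, rng))
    return walks


# Hypothetical toy undirected graph, stored as an adjacency list.
toy_adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
walks = generate_walks(toy_adj, walks_per_node=2, walk_length=5)
for walk in walks:
    print(walk)
```

Each printed walk would then play the role of a sentence whose "words" are node IDs; the embedding step itself is left to the library in the script that follows.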

Build node embeddings on Zachary's Karate Club network with DeepWalk, then classify the nodes with the logistic regression (LR) model provided by sklearn.

import networkx as nx
from karateclub import DeepWalk

# Load Karate Club Graph
G = nx.karate_club_graph()


# Generate DeepWalk Embeddings
model = DeepWalk(walk_length=5, dimensions=128, window_size=5, epochs=20, workers=1)
model.fit(G)
embeddings = model.get_embedding()

# Generate X and Y for training
X = []
Y = []
for node in G.nodes():
    X.append(embeddings[node])
    Y.append(G.nodes[node]["club"])

# Split Train dataset and Test dataset
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y)

# Train a Logistic Regression Classifier
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(random_state=0).fit(X_train, Y_train)

# Evaluate the Classifier
accuracy = clf.score(X_test, Y_test)
print("Accuracy:", accuracy)

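Output (from a single run; train_test_split is called without a random_state, so the split and the accuracy can differ between runs)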

Accuracy: 1.0
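
With only 34 nodes, the default split leaves just a handful of test nodes, so a single accuracy value is a noisy estimate. One way to get a steadier number is to average over repeated stratified splits. The sketch below assumes that approach; the helper `mean_test_accuracy` and the synthetic `X_demo` / `Y_demo` data are hypothetical stand-ins so the block runs on its own, and they are not results from the Karate Club graph.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score


def mean_test_accuracy(X, Y, n_splits=20, test_size=0.25, seed=0):
    """Average logistic-regression accuracy over repeated stratified splits."""
    splitter = StratifiedShuffleSplit(
        n_splits=n_splits, test_size=test_size, random_state=seed
    )
    clf = LogisticRegression(random_state=0)
    # With a classifier and no `scoring` argument, the score is accuracy.
    scores = cross_val_score(clf, np.asarray(X), np.asarray(Y), cv=splitter)
    return scores.mean(), scores.std()


# Synthetic stand-in data: two well-separated clusters of 17 points each.
rng = np.random.default_rng(0)
X_demo = np.vstack([
    rng.normal(-2.0, 0.5, size=(17, 4)),
    rng.normal(2.0, 0.5, size=(17, 4)),
])
Y_demo = ["Mr. Hi"] * 17 + ["Officer"] * 17

mean_acc, std_acc = mean_test_accuracy(X_demo, Y_demo)
print("Mean accuracy:", mean_acc, "Std:", std_acc)
```

To apply this to the script above, pass the `X` and `Y` lists built from the DeepWalk embeddings: `mean_test_accuracy(X, Y)`.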