# An example of link prediction using negative and positive samples.
# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py
# For an introduction to the link prediction task, see https://arxiv.org/pdf/2102.12557.pdf

using Flux
using Flux: onecold, onehotbatch
using Flux.Losses: logitbinarycrossentropy
using GraphNeuralNetworks
using MLDatasets: PubMed, Cora
using Statistics, Random, LinearAlgebra
using CUDA
using ChainRulesCore
# using MLJBase: AreaUnderCurve
CUDA.allowscalar(false)

# Arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3             # learning rate
    epochs = 200         # number of epochs
    seed = 17            # set seed > 0 for reproducibility
    usecuda = false      # if true use cuda (if available)
    nhidden = 64         # dimension of hidden features
    infotime = 10        # report every `infotime` epochs
end
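
# `Base.@kwdef` generates a keyword constructor, so individual fields can be
# overridden while keeping the remaining defaults, e.g. `Args(epochs = 50)`.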

struct DotPredictor end

# Score each edge (i, j) of `g` with the dot product of the embeddings of
# its endpoint nodes i and j.
function (::DotPredictor)(g, x)
    z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims = 1), g, xi = x, xj = x)
    return vec(z)
end
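
# A toy illustration of DotPredictor (a hypothetical check, not part of the
# pipeline): on a graph with the single edge 1 -> 2, the score of that edge
# is the dot product of the embeddings of nodes 1 and 2.
#
#   g_toy = GNNGraph([1], [2], num_nodes = 2)
#   x_toy = Float32[1 0; 0 2]     # 2 features × 2 nodes, one column per node
#   DotPredictor()(g_toy, x_toy)  # ≈ [0f0], the two columns are orthogonal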
function train(; kws...)
    args = Args(; kws...)

    args.seed > 0 && Random.seed!(args.seed)

    if args.usecuda && CUDA.functional()
        device = gpu
        args.seed > 0 && CUDA.seed!(args.seed)
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ### LOAD DATA
    data = Cora.dataset()
    g = GNNGraph(data.adjacency_list) |> device
    X = data.node_features |> device
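    # Cora is a citation network with 2708 nodes and 1433-dimensional
    # bag-of-words node features; `X` is a num_features × num_nodes matrix.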

    ### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
    s, t = edge_index(g)
    eids = randperm(g.num_edges)
    test_size = round(Int, g.num_edges * 0.1)

    test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
    test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes = g.num_nodes)

    train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
    train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes = g.num_nodes)

    test_neg_g = negative_sample(g, num_neg_edges = test_size)
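    # `negative_sample` draws `num_neg_edges` node pairs that are not edges of
    # `g`; they act as the "false" examples for the binary edge classifier.
    # A rough sketch of the idea (the library call above also takes care of
    # details such as avoiding duplicates):
    #
    #   neg_s = rand(1:g.num_nodes, test_size)
    #   neg_t = rand(1:g.num_nodes, test_size)
    #   # ... keep only the pairs (neg_s[i], neg_t[i]) not already in `g`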

    ### DEFINE MODEL
    nin, nhidden = size(X, 1), args.nhidden

    # Two-layer GCN encoder producing one embedding per node.
    model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu),
                               GCNConv(nhidden => nhidden)),
                      train_pos_g) |> device
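
    # `WithGraph` ties the chain to `train_pos_g`, so the model is called as
    # `model(X)` below. The call is equivalent to running the chain explicitly
    # on the training graph, e.g.
    #
    #   gnn = GNNChain(GCNConv(nin => nhidden, relu), GCNConv(nhidden => nhidden))
    #   h = gnn(train_pos_g, X)    # what WithGraph(gnn, train_pos_g)(X) computes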

    pred = DotPredictor()

    ps = Flux.params(model)
    opt = ADAM(args.η)

    ### LOSS FUNCTION

    function loss(pos_g, neg_g = nothing)
        h = model(X)
        if neg_g === nothing
            # sample a fresh negative graph at each training step
            neg_g = negative_sample(pos_g)
        end
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        scores = [pos_score; neg_score]
        # label 1 for positive (true) edges, 0 for negative (sampled) ones
        labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
        return logitbinarycrossentropy(scores, labels)
    end
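
    # `logitbinarycrossentropy(s, y)` is the numerically stable fusion of
    # `binarycrossentropy(sigmoid.(s), y)`, i.e. it minimizes
    #
    #   mean(@. -y * log(σ(s)) - (1 - y) * log(1 - σ(s)))
    #
    # pushing the scores of true edges up and those of sampled non-edges down.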

    # Accuracy of the binary edge classifier, thresholding the scores at 0:
    # function accuracy(pos_g, neg_g)
    #     h = model(X)
    #     pos_score = pred(pos_g, h)
    #     neg_score = pred(neg_g, h)
    #     scores = [pos_score; neg_score]
    #     labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
    #     return mean((scores .> 0) .== labels)
    # end

    ### LOGGING FUNCTION
    function report(epoch)
        train_loss = loss(train_pos_g)
        test_loss = loss(test_pos_g, test_neg_g)
        println("Epoch: $epoch   Train: $(train_loss)   Test: $(test_loss)")
    end

    ### TRAINING
    report(0)
    for epoch in 1:args.epochs
        gs = Flux.gradient(() -> loss(train_pos_g), ps)
        Flux.Optimise.update!(opt, ps, gs)
        epoch % args.infotime == 0 && report(epoch)
    end
end

# train()
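
# Keyword arguments are forwarded to `Args`, so hyperparameters can be
# overridden at the call site, e.g.
#
#   train(usecuda = true, epochs = 100, η = 5f-4)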