# Ported from https://docs.dgl.ai/tutorials/blitz/4_link_predict.html#sphx-glr-tutorials-blitz-4-link-predict-py
# Link prediction task
# https://arxiv.org/pdf/2102.12557.pdf

using Flux
using Flux: onecold, onehotbatch
using Flux.Losses: logitbinarycrossentropy
using GraphNeuralNetworks
using MLDatasets: PubMed, Cora
using Statistics, Random, LinearAlgebra
using CUDA
# using MLJBase: AreaUnderCurve
CUDA.allowscalar(false)   # error out on (slow) scalar indexing of gpu arrays

# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3             # learning rate
    # The remaining fields fall outside this excerpt; they are reconstructed here
    # from their uses later in the script, with placeholder default values.
    epochs = 200         # number of training epochs (placeholder)
    nhidden = 64         # dimension of hidden features (placeholder)
    infotime = 10        # report every `infotime` epochs (placeholder)
    usecuda = false      # if true, use the gpu when available (placeholder)
    seed = 17            # set seed > 0 for reproducibility (placeholder)
end

struct DotPredictor end

function (::DotPredictor)(g, x)
    # Score each edge with the dot product of its two endpoint embeddings.
    # The body falls outside this excerpt; computing the scores edge-wise with
    # `apply_edges` is the standard way to do it.
    z = apply_edges((xi, xj, e) -> sum(xi .* xj, dims = 1), g, xi = x, xj = x)
    return vec(z)
end
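
# Illustration (not part of the original script; it assumes the dot-product body
# sketched above): with one embedding column per node, the predictor returns one
# score per edge, dot(x[:, src], x[:, dst]).
let
    toy_g = GNNGraph([1, 2], [2, 3], num_nodes = 3)   # edges 1→2 and 2→3
    toy_x = Float32[1 0 2;
                    0 1 1]                            # 2-dimensional node embeddings
    @assert DotPredictor()(toy_g, toy_x) ≈ Float32[0, 1]
end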

using ChainRulesCore

function train(; kws...)
    args = Args(; kws...)

    # The next few lines fall outside this excerpt and are reconstructed here:
    # seed the rng, select the device, and load the dataset.
    args.seed > 0 && Random.seed!(args.seed)

    if args.usecuda && CUDA.functional()
        device = gpu
        args.seed > 0 && CUDA.seed!(args.seed)
    else
        device = cpu
    end

    ### LOAD DATA
    data = Cora.dataset()   # assumption: PubMed.dataset() may be intended instead, both are imported

    g = GNNGraph(data.adjacency_list) |> device
    X = data.node_features |> device

    ### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
    s, t = edge_index(g)
    eids = randperm(g.num_edges)
    test_size = round(Int, g.num_edges * 0.1)

    test_pos_s, test_pos_t = s[eids[1:test_size]], t[eids[1:test_size]]
    test_pos_g = GNNGraph(test_pos_s, test_pos_t, num_nodes = g.num_nodes)

    train_pos_s, train_pos_t = s[eids[test_size+1:end]], t[eids[test_size+1:end]]
    train_pos_g = GNNGraph(train_pos_s, train_pos_t, num_nodes = g.num_nodes)

    # negative (non-existent) test edges, sampled once and kept fixed
    test_neg_g = negative_sample(g, num_neg_edges = test_size)
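
    # `negative_sample(g; num_neg_edges = k)` draws k source-target pairs from the
    # non-edges of `g`, i.e. "fake" edges the model should learn to score low.
    # Sanity-check sketch (illustrative, not in the original; kept commented out
    # because iterating gpu arrays would hit the scalar-indexing guard above):
    # let (ns, nt) = edge_index(test_neg_g)
    #     @assert isempty(intersect(Set(zip(ns, nt)), Set(zip(s, t))))
    # end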

    ### DEFINE MODEL
    nin, nhidden = size(X, 1), args.nhidden

    # WithGraph stores train_pos_g inside the model, so that below it can be
    # called as model(X) without passing the graph each time
    model = WithGraph(GNNChain(GCNConv(nin => nhidden, relu),
                               GCNConv(nhidden => nhidden)),
                      train_pos_g) |> device

    pred = DotPredictor()

    ps = Flux.params(model)
    opt = ADAM(args.η)

    ### LOSS FUNCTION

    function loss(pos_g, neg_g = nothing)
        h = model(X)   # node embeddings, computed on the training graph stored in `model`
        if neg_g === nothing
            # we sample a fresh negative graph at each training step
            neg_g = negative_sample(pos_g)
        end
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        scores = [pos_score; neg_score]
        labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
        return logitbinarycrossentropy(scores, labels)
    end
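
    # Note (illustrative, not in the original): `logitbinarycrossentropy(ŷ, y)`
    # fuses the sigmoid with the cross-entropy, i.e. it agrees with
    # `Flux.binarycrossentropy(sigmoid.(ŷ), y)` but is numerically more stable.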

    # function accuracy(pos_g, neg_g)
    #     h = model(train_pos_g, X)
    #     pos_score = pred(pos_g, h)
    #     neg_score = pred(neg_g, h)
    #     scores = [pos_score; neg_score]
    #     labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
    #     return logitbinarycrossentropy(scores, labels)
    # end
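
    # A working accuracy helper (a sketch, not in the original script; note that
    # the commented-out version above returned the loss rather than an accuracy):
    # threshold the sigmoid of each score at 0.5 and compare with the labels.
    function accuracy(pos_g, neg_g)
        h = model(X)
        pos_score = pred(pos_g, h)
        neg_score = pred(neg_g, h)
        scores = [pos_score; neg_score]
        labels = [fill!(similar(pos_score), 1); fill!(similar(neg_score), 0)]
        return mean((sigmoid.(scores) .> 0.5) .== labels)
    end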

    ### LOGGING FUNCTION
    function report(epoch)
        train_loss = loss(train_pos_g)
        test_loss = loss(test_pos_g, test_neg_g)
        println("Epoch: $epoch   Train: $(train_loss)   Test: $(test_loss)")
    end

    ### TRAINING
    report(0)
    for epoch in 1:args.epochs
        gs = Flux.gradient(() -> loss(train_pos_g), ps)
        Flux.Optimise.update!(opt, ps, gs)
        epoch % args.infotime == 0 && report(epoch)
    end
end
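
# Evaluation sketch (an assumption, not part of the original script; the
# commented-out MLJBase import above hints at AUC evaluation). AUC equals the
# probability that a randomly chosen positive edge is scored higher than a
# randomly chosen negative one, so it can be computed directly from the scores:
function auc_score(pos_score, neg_score)
    pos_score, neg_score = cpu(pos_score), cpu(neg_score)   # scalar iteration needs cpu arrays
    mean(p > n for p in pos_score, n in neg_score)
end
# e.g., inside `train`: auc_score(pred(test_pos_g, h), pred(test_neg_g, h))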