@@ -49,23 +49,19 @@ function train(; kws...)
4949 # ## LOAD DATA
5050 data = Cora. dataset ()
5151 # data = PubMed.dataset()
52- g = GNNGraph (data. adjacency_list) |> device
52+ g = GNNGraph (data. adjacency_list)
53+ @info g
5354 @show is_bidirected (g)
55+ @show has_self_loops (g)
56+ @show has_multi_edges (g)
57+ @show mean (degree (g))
58+
59+ g = g |> device
5460 X = data. node_features |> device
5561
5662 # ### SPLIT INTO NEGATIVE AND POSITIVE SAMPLES
57- s, t = edge_index (g)
58- eids = randperm (g. num_edges)
59- test_size = round (Int, g. num_edges * 0.1 )
60-
61- test_pos_s, test_pos_t = s[eids[1 : test_size]], t[eids[1 : test_size]]
62- test_pos_g = GNNGraph (test_pos_s, test_pos_t, num_nodes= g. num_nodes)
63-
64- train_pos_s, train_pos_t = s[eids[test_size+ 1 : end ]], t[eids[test_size+ 1 : end ]]
65- train_pos_g = GNNGraph (train_pos_s, train_pos_t, num_nodes= g. num_nodes)
66-
67- test_neg_g = negative_sample (g, num_neg_edges= test_size)
68-
63+ train_pos_g, test_pos_g = rand_edge_split (g, 0.9 )
64+ test_neg_g = negative_sample (g, num_neg_edges= test_pos_g. num_edges)
6965
7066 # ## DEFINE MODEL #########
7167 nin, nhidden = size (X,1 ), args. nhidden
@@ -82,7 +78,7 @@ function train(; kws...)
8278
8379 # ## LOSS FUNCTION ############
8480
85- function loss (pos_g, neg_g = nothing )
81+ function loss (pos_g, neg_g = nothing ; with_accuracy = false )
8682 h = model (X)
8783 if neg_g === nothing
8884 # We sample a negative graph at each training step
@@ -92,14 +88,20 @@ function train(; kws...)
9288 neg_score = pred (neg_g, h)
9389 scores = [pos_score; neg_score]
9490 labels = [fill! (similar (pos_score), 1 ); fill! (similar (neg_score), 0 )]
95- return logitbinarycrossentropy (scores, labels)
91+ l = logitbinarycrossentropy (scores, labels)
92+ if with_accuracy
93+ acc = 0.5 * mean (pos_score .>= 0 ) + 0.5 * mean (neg_score .< 0 )
94+ return l, acc
95+ else
96+ return l
97+ end
9698 end
9799
98100 # ## LOGGING FUNCTION
99101 function report (epoch)
100- train_loss = loss (train_pos_g)
101- test_loss = loss (test_pos_g, test_neg_g)
102- println (" Epoch: $epoch Train: $( train_loss) Test: $( test_loss)" )
102+ train_loss, train_acc = loss (train_pos_g, with_accuracy = true )
103+ test_loss, test_acc = loss (test_pos_g, test_neg_g, with_accuracy = true )
104+ println (" Epoch: $epoch $((; train_loss, train_acc)) $((; test_loss, test_acc) )" )
103105 end
104106
105107 # ## TRAINING
0 commit comments