@@ -321,13 +321,21 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
321321end
322322
323323"""
324- negative_sample(g::GNNGraph; num_neg_edges=g.num_edges)
324+ negative_sample(g::GNNGraph;
325+ num_neg_edges = g.num_edges,
326+ bidirected = is_bidirected(g))
325327
326328Return a graph containing random negative edges (i.e. non-edges) from graph `g` as edges.
329+
330+ Is `bidirected=true`, the output graph will be bidirected and there will be no
331+ leakage from the origin graph.
332+
333+ See also [`is_bidirected`](@ref).
327334"""
328335function negative_sample (g:: GNNGraph ;
329336 max_trials= 3 ,
330- num_neg_edges= g. num_edges)
337+ num_neg_edges= g. num_edges,
338+ bidirected = is_bidirected (g))
331339
332340 @assert g. num_graphs == 1
333341 # Consider self-loops as positive edges
@@ -344,8 +352,12 @@ function negative_sample(g::GNNGraph;
344352 device = Flux. cpu
345353 end
346354 idx_pos, maxid = edge_encoding (s, t, n)
347-
348- pneg = 1 - g. num_edges / maxid # prob of selecting negative edge
355+ if bidirected
356+ num_neg_edges = num_neg_edges ÷ 2
357+ pneg = 1 - g. num_edges / 2 maxid # prob of selecting negative edge
358+ else
359+ pneg = 1 - g. num_edges / 2 maxid # prob of selecting negative edge
360+ end
349361 # pneg * sample_prob * maxid == num_neg_edges
350362 sample_prob = min (1 , num_neg_edges / (pneg * maxid) * 1.1 )
351363 idx_neg = Int[]
@@ -359,6 +371,9 @@ function negative_sample(g::GNNGraph;
359371 end
360372 end
361373 s_neg, t_neg = edge_decoding (idx_neg, n)
374+ if bidirected
375+ s_neg, t_neg = [s_neg; t_neg], [t_neg; s_neg]
376+ end
362377 return GNNGraph (s_neg, t_neg, num_nodes= n) |> device
363378end
364379
@@ -372,6 +387,7 @@ while `g2` wil contain the rest.
372387Useful for train/test splits in link prediction tasks.
373388"""
374389function rand_edge_split (g:: GNNGraph , frac)
390+ # TODO add bidirected version
375391 s, t = edge_index (g)
376392 eids = randperm (g. num_edges)
377393 size1 = round (Int, g. num_edges * frac)
0 commit comments