@@ -378,19 +378,24 @@ function negative_sample(g::GNNGraph;
378378end
379379
380380"""
381- rand_edge_split(g::GNNGraph, frac) -> g1, g2
381+ rand_edge_split(g::GNNGraph, frac; bidirected=is_bidirected(g) ) -> g1, g2
382382
383383Randomly partition the edges in `g` to from two graphs, `g1`
384384and `g2`. Both will have the same number of nodes as `g`.
385385`g1` will contain a fraction `frac` of the original edges,
386386while `g2` wil contain the rest.
387+
388+ If `bidirected = true` makes sure that an edge and its reverse go into the same split.
389+
387390Useful for train/test splits in link prediction tasks.
388391"""
389- function rand_edge_split (g:: GNNGraph , frac)
390- # TODO add bidirected version
392+ function rand_edge_split (g:: GNNGraph , frac; bidirected= is_bidirected (g))
391393 s, t = edge_index (g)
392- eids = randperm (g. num_edges)
393- size1 = round (Int, g. num_edges * frac)
394+ idx, idmax = edge_encoding (s, t, g. num_nodes, directed= ! bidirected)
395+ uidx = union (idx) # So that multi-edges (and reverse edges in the bidir case) go in the same split
396+ nu = length (uidx)
397+ eids = randperm (nu)
398+ size1 = round (Int, nu * frac)
394399
395400 s1, t1 = s[eids[1 : size1]], t[eids[1 : size1]]
396401 g1 = GNNGraph (s1, t1, num_nodes= g. num_nodes)
0 commit comments