@@ -100,13 +100,11 @@ function broadcast_edges(g::GNNGraph, x)
100100 return gather (x, gi)
101101end
102102
103- # return a permuted matrix according to the sorting of the sortby column
104103function _sort_col (matrix:: AbstractArray ; rev:: Bool = true , sortby:: Int = 1 )
105104 index = sortperm (view (matrix, sortby, :); rev)
106105 return matrix[:, index]
107106end
108107
109- # sort and reshape matrix
110108function _sort_matrix (matrix:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
111109 if sortby === nothing
112110 return sort (matrix, dims = 2 ; rev)[:, 1 : k]
@@ -115,12 +113,10 @@ function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby =
115113 end
116114end
117115
118- # sort the iterator of batch matrices
119116function _sort_batch (matrices, k:: Int ; rev:: Bool = true , sortby = nothing )
120117 return map (x -> _sort_matrix (x, k; rev, sortby), matrices)
121118end
122119
123- # sort and reshape batch matrix
124120function _topk_batch (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
125121 sortby = nothing )
126122 tensor_matrix = reshape (matrix, size (matrix, 1 ), size (matrix, 2 ) ÷ number_graphs,
@@ -129,7 +125,6 @@ function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Boo
129125 return reduce (hcat, sorted_matrix)
130126end
131127
132- # topk for a feature matrix
133128function _topk (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
134129 sortby = nothing )
135130 if number_graphs == 1
0 commit comments