@@ -144,7 +144,8 @@ If `weighted=true`, the `A` will contain the edge weights if any, otherwise the
144144function Graphs. adjacency_matrix (g:: GNNGraph{<:COO_T} , T:: DataType = eltype (g); dir = :out ,
145145 weighted = true )
146146 if g. graph[1 ] isa CuVector
147- # TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
147+ # Revisit after
148+ # https://github.com/JuliaGPU/CUDA.jl/issues/1113
148149 A, n, m = to_dense (g. graph, T; num_nodes = g. num_nodes, weighted)
149150 else
150151 A, n, m = to_sparse (g. graph, T; num_nodes = g. num_nodes, weighted)
@@ -164,63 +165,101 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
164165 return dir == :out ? A : A'
165166end
166167
167- function _get_edge_weight (g, edge_weight)
168- if edge_weight === true || edge_weight === nothing
169- ew = get_edge_weight (g)
170- elseif edge_weight === false
171- ew = nothing
172- elseif edge_weight isa AbstractVector
173- ew = edge_weight
168+ function ChainRulesCore. rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ;
169+ dir = :out , weighted = true ) where {G <: GNNGraph{<:ADJMAT_T} }
170+ A = adjacency_matrix (g, T; dir, weighted)
171+ if ! weighted
172+ function adjacency_matrix_pullback_noweight (Δ)
173+ return (NoTangent (), ZeroTangent (), NoTangent ())
174+ end
175+ return A, adjacency_matrix_pullback_noweight
174176 else
175- error (" Invalid edge_weight argument." )
177+ function adjacency_matrix_pullback_weighted (Δ)
178+ dg = Tangent {G} (; graph = Δ .* binarize (A))
179+ return (NoTangent (), dg, NoTangent ())
180+ end
181+ return A, adjacency_matrix_pullback_weighted
182+ end
183+ end
184+
185+ function ChainRulesCore. rrule (:: typeof (adjacency_matrix), g:: G , T:: DataType ;
186+ dir = :out , weighted = true ) where {G <: GNNGraph{<:COO_T} }
187+ A = adjacency_matrix (g, T; dir, weighted)
188+ w = get_edge_weight (g)
189+ if ! weighted || w === nothing
190+ function adjacency_matrix_pullback_noweight (Δ)
191+ return (NoTangent (), ZeroTangent (), NoTangent ())
192+ end
193+ return A, adjacency_matrix_pullback_noweight
194+ else
195+ function adjacency_matrix_pullback_weighted (Δ)
196+ s, t = edge_index (g)
197+ dg = Tangent {G} (; graph = (NoTangent (), NoTangent (), NNlib. gather (Δ, s, t)))
198+ return (NoTangent (), dg, NoTangent ())
199+ end
200+ return A, adjacency_matrix_pullback_weighted
201+ end
202+ end
203+
204+ function _get_edge_weight (g, edge_weight:: Bool )
205+ if edge_weight === true
206+ return get_edge_weight (g)
207+ elseif edge_weight === false
208+ return nothing
176209 end
177- return ew
178210end
179211
212+ _get_edge_weight (g, edge_weight:: AbstractVector ) = edge_weight
213+
180214"""
181215 degree(g::GNNGraph, T=nothing; dir=:out, edge_weight=true)
182216
183217Return a vector containing the degrees of the nodes in `g`.
184218
219+ The gradient is propagated through this function only if `edge_weight` is `true`
220+ or a vector.
221+
185222# Arguments
223+
186224- `g`: A graph.
187225- `T`: Element type of the returned vector. If `nothing`, is
188226 chosen based on the graph type and will be an integer
189- if `edge_weight=false`.
227+ if `edge_weight=false`. Default `nothing`.
190228- `dir`: For `dir=:out` the degree of a node is counted based on the outgoing edges.
191229 For `dir=:in`, the ingoing edges are used. If `dir=:both` we have the sum of the two.
192230- `edge_weight`: If `true` and the graph contains weighted edges, the degree will
193231 be weighted. Set to `false` instead to just count the number of
194- outgoing/ingoing edges.
195- In alternative , you can also pass a vector of weights to be used
232+ outgoing/ingoing edges.
233+ Finally , you can also pass a vector of weights to be used
196234 instead of the graph's own weights.
235+ Default `true`.
236+
197237"""
198238function Graphs. degree (g:: GNNGraph{<:COO_T} , T:: TT = nothing ; dir = :out ,
199239 edge_weight = true ) where {
200240 TT <: Union{Nothing, Type{<:Number}} }
201241 s, t = edge_index (g)
202242
203- edge_weight = _get_edge_weight (g, edge_weight)
204- edge_weight = edge_weight === nothing ? ones_like (s) : edge_weight
205-
206- T = isnothing (T) ? eltype (edge_weight) : T
207- degs = fill! (similar (s, T, g. num_nodes), 0 )
208-
209- if dir ∈ [:out , :both ]
210- degs = degs .+ NNlib. scatter (+ , edge_weight, s, dstsize = (g. num_nodes,))
211- end
212- if dir ∈ [:in , :both ]
213- degs = degs .+ NNlib. scatter (+ , edge_weight, t, dstsize = (g. num_nodes,))
214- end
215- return degs
243+ ew = _get_edge_weight (g, edge_weight)
244+
245+ T = if isnothing (T)
246+ if ! isnothing (ew)
247+ eltype (ew)
248+ else
249+ eltype (s)
250+ end
251+ else
252+ T
253+ end
254+ return _degree ((s, t), T, dir, ew, g. num_nodes)
216255end
217256
218257# TODO :: Make efficient
219258Graphs. degree (g:: GNNGraph , i:: Union{Int, AbstractVector} ; dir = :out ) = degree (g; dir)[i]
220259
221260function Graphs. degree (g:: GNNGraph{<:ADJMAT_T} , T:: TT = nothing ; dir = :out ,
222- edge_weight = true ) where {TT}
223- TT <: Union{Nothing, Type{<:Number}}
261+ edge_weight = true ) where {TT<: Union{Nothing, Type{<:Number}} }
262+
224263 # edge_weight=true or edge_weight=nothing act the same here
225264 @assert ! (edge_weight isa AbstractArray) " passing the edge weights is not support by adjacency matrix representations"
226265 @assert dir ∈ (:in , :out , :both )
@@ -234,6 +273,26 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
234273 end
235274 end
236275 A = adjacency_matrix (g)
276+ return _degree (A, T, dir, edge_weight, g. num_nodes)
277+ end
278+
279+ function _degree ((s, t):: Tuple , T:: Type , dir:: Symbol , edge_weight:: Nothing , num_nodes:: Int )
280+ _degree ((s, t), T, dir, ones_like (s, T), num_nodes)
281+ end
282+
283+ function _degree ((s, t):: Tuple , T:: Type , dir:: Symbol , edge_weight:: AbstractVector , num_nodes:: Int )
284+ degs = fill! (similar (s, T, num_nodes), 0 )
285+
286+ if dir ∈ [:out , :both ]
287+ degs = degs .+ NNlib. scatter (+ , edge_weight, s, dstsize = (num_nodes,))
288+ end
289+ if dir ∈ [:in , :both ]
290+ degs = degs .+ NNlib. scatter (+ , edge_weight, t, dstsize = (num_nodes,))
291+ end
292+ return degs
293+ end
294+
295+ function _degree (A:: AbstractMatrix , T:: Type , dir:: Symbol , edge_weight:: Bool , num_nodes:: Int )
237296 if edge_weight === false
238297 A = binarize (A)
239298 end
@@ -243,6 +302,40 @@ function Graphs.degree(g::GNNGraph{<:ADJMAT_T}, T::TT = nothing; dir = :out,
243302 vec (sum (A, dims = 1 )) .+ vec (sum (A, dims = 2 ))
244303end
245304
305+ function ChainRulesCore. rrule (:: typeof (_degree), graph, T, dir, edge_weight:: Nothing , num_nodes)
306+ degs = _degree (graph, T, dir, edge_weight, num_nodes)
307+ function _degree_pullback (Δ)
308+ return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent ())
309+ end
310+ return degs, _degree_pullback
311+ end
312+
313+ function ChainRulesCore. rrule (:: typeof (_degree), A:: ADJMAT_T , T, dir, edge_weight:: Bool , num_nodes)
314+ degs = _degree (A, T, dir, edge_weight, num_nodes)
315+ if edge_weight === false
316+ function _degree_pullback_noweights (Δ)
317+ return (NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent (), NoTangent ())
318+ end
319+ return degs, _degree_pullback_noweights
320+ else
321+ function _degree_pullback_weights (Δ)
322+ # We propagate the gradient only to the non-zero elements
323+ # of the adjacency matrix.
324+ bA = binarize (A)
325+ if dir == :in
326+ dA = bA .* Δ'
327+ elseif dir == :out
328+ dA = Δ .* bA
329+ else # dir == :both
330+ dA = Δ .* bA + Δ' .* bA
331+ end
332+ return (NoTangent (), dA, NoTangent (), NoTangent (), NoTangent (), NoTangent ())
333+ end
334+ return degs, _degree_pullback_weights
335+ end
336+ end
337+
338+
246339"""
247340 has_isolated_nodes(g::GNNGraph; dir=:out)
248341
0 commit comments