@@ -50,7 +50,7 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
     # @assert all(>(0), degree(g, T, dir=:in))
     c = 1 ./ sqrt.(degree(g, T, dir=:in))
     x = x .* c'
-    x = propagate(copyxj, g, +, xj=x)
+    x = propagate(copy_xj, g, +, xj=x)
     x = x .* c'
     if Dout >= Din
         x = l.weight * x
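
Throughout this commit the typo `copyxj` is corrected to `copy_xj`, the library's built-in message function that simply forwards the source-node features. As a rough standalone sketch of what `propagate(copy_xj, g, +, xj=x)` computes (plain Julia over hypothetical `src`/`dst` edge-index vectors, not the library's internal layout):

```julia
# Minimal sketch: each destination node sums the features of its source neighbors.
# src/dst are hypothetical edge-index vectors; n is the number of nodes.
function sum_neighbors(x::AbstractMatrix, src::Vector{Int}, dst::Vector{Int}, n::Int)
    out = zeros(eltype(x), size(x, 1), n)
    for (s, d) in zip(src, dst)
        out[:, d] .+= x[:, s]    # copy_xj: the message is just xj
    end
    return out
end
```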
@@ -179,7 +179,7 @@
 
 function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    m = propagate(copyxj, g, l.aggr, xj=x)
+    m = propagate(copy_xj, g, l.aggr, xj=x)
     x = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)
     return x
 end
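
For context, a hedged usage sketch of `GraphConv`. It assumes the package's `rand_graph` helper, the `GraphConv(in => out, σ)` constructor form, and `relu` from a recent Flux; the sizes are illustrative:

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(10, 30)           # 10 nodes, 30 random edges
x = randn(Float32, 3, 10)        # 3 features per node
l = GraphConv(3 => 8, relu)      # σ.(weight1*x .+ weight2*m .+ bias), m = Σⱼ xⱼ
y = l(g, x)                      # 8×10 node embeddings
```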
@@ -206,7 +206,7 @@ Graph attentional layer from the paper [Graph Attention Networks](https://arxiv.
 
 Implements the operation
 ```math
-\mathbf{x}_i' = \sum_{j \in N(i)} \alpha_{ij} W \mathbf{x}_j
+\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} W \mathbf{x}_j
 ```
 where the attention coefficients ``\alpha_{ij}`` are given by
 ```math
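
The corrected sum runs over ``N(i) \cup \{i\}``, i.e. each node also attends to itself. In code this is typically realized by adding self-loops before the per-neighbor softmax, as the `AGNNConv` forward pass added later in this commit does. A hedged sketch with a stand-in dot-product score (not GATConv's actual LeakyReLU attention; `rand_graph` and the feature sizes are illustrative):

```julia
using GraphNeuralNetworks

g = rand_graph(6, 12)
h = randn(Float32, 4, 6)                 # assume already-projected features W*x
g = add_self_loops(g)                    # j now ranges over N(i) ∪ {i}
e = apply_edges((xi, xj, _) -> sum(xi .* xj, dims=1), g; xi=h, xj=h)
α = softmax_edge_neighbors(g, e)         # softmax over each node's in-edges
```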
@@ -338,7 +338,7 @@ function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S<:Real}
     end
     for i = 1:l.num_layers
         M = view(l.weight, :, :, i) * H
-        M = propagate(copyxj, g, l.aggr; xj=M)
+        M = propagate(copy_xj, g, l.aggr; xj=M)
         H, _ = l.gru(H, M)
     end
     H
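
The `GatedGraphConv` loop alternates message passing with a shared GRU cell: `l.gru(H, M)` treats the node states `H` as the recurrent state and the aggregated messages `M` as the input. A hedged sketch of that single step, assuming a Flux 0.13-era `GRUCell` whose call returns a `(state, output)` pair, as the forward pass above relies on:

```julia
using Flux

gru = Flux.GRUCell(8 => 8)       # message dim => state dim
H = randn(Float32, 8, 10)        # hidden state, one column per node
M = randn(Float32, 8, 10)        # aggregated messages (stand-in for propagate's output)
H, _ = gru(H, M)                 # one recurrent update per node, as in the loop above
```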
@@ -420,7 +420,7 @@ GINConv(nn, ϵ; aggr=+) = GINConv(nn, ϵ, aggr)
 
 function (l::GINConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    m = propagate(copyxj, g, l.aggr, xj=x)
+    m = propagate(copy_xj, g, l.aggr, xj=x)
     l.nn((1 + ofeltype(x, l.ϵ)) * x + m)
 end
 
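
A hedged usage sketch of `GINConv`, following the constructor shown in the hunk header (`GINConv(nn, ϵ; aggr=+)`); the MLP, `rand_graph`, and sizes are illustrative and assume a recent Flux:

```julia
using Flux, GraphNeuralNetworks

nn = Chain(Dense(3 => 8, relu), Dense(8 => 8))
l = GINConv(nn, 0.1f0)           # ϵ = 0.1, default aggr = +
g = rand_graph(10, 30)
x = randn(Float32, 3, 10)
y = l(g, x)                      # nn((1 + ϵ) * x + m), m = Σⱼ xⱼ
```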
@@ -542,7 +542,7 @@
 
 function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    m = propagate(copyxj, g, l.aggr, xj=x)
+    m = propagate(copy_xj, g, l.aggr, xj=x)
     x = l.σ.(l.weight * vcat(x, m) .+ l.bias)
     return x
 end
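
The SAGE update concatenates each node's own features with the aggregated neighborhood before the linear map. A plain-Julia sketch of just that update (all names and sizes hypothetical):

```julia
W = randn(Float32, 8, 2 * 3)     # out × 2in: acts on vcat(x, m)
b = zeros(Float32, 8)
act(z) = max(z, 0f0)             # stand-in for the layer's activation σ
sage_update(x, m) = act.(W * vcat(x, m) .+ b)

x = randn(Float32, 3, 10)        # node features
m = randn(Float32, 3, 10)        # aggregated neighbor features from propagate
h = sage_update(x, m)            # 8×10
```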
@@ -711,3 +711,56 @@ function Base.show(io::IO, l::CGConv)
     print(io, ", residual=$(l.residual)")
     print(io, ")")
 end
+
+
+@doc raw"""
+    AGNNConv(init_beta=1f0)
+
+Attention-based Graph Neural Network layer from the paper [Attention-based
+Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735).
+
+The forward pass is given by
+```math
+\mathbf{x}_i' = \sum_{j \in N(i) \cup \{i\}} \alpha_{ij} \mathbf{x}_j
+```
+where the attention coefficients ``\alpha_{ij}`` are given by
+```math
+\alpha_{ij} = \frac{e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_j)}}
+                   {\sum_{j' \in N(i) \cup \{i\}} e^{\beta \cos(\mathbf{x}_i, \mathbf{x}_{j'})}}
+```
+with the cosine similarity defined by
+```math
+\cos(\mathbf{x}_i, \mathbf{x}_j) =
+  \frac{\mathbf{x}_i \cdot \mathbf{x}_j}{\lVert\mathbf{x}_i\rVert \, \lVert\mathbf{x}_j\rVert}
+```
+and ``\beta`` a trainable parameter.
+
+# Arguments
+
+- `init_beta`: The initial value of ``\beta``.
+"""
+struct AGNNConv{A<:AbstractVector} <: GNNLayer
+    β::A
+end
+
+@functor AGNNConv
+
+function AGNNConv(init_beta = 1f0)
+    AGNNConv([init_beta])
+end
+
+function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix)
+    check_num_nodes(g, x)
+    g = add_self_loops(g)
+
+    xn = x ./ sqrt.(sum(x .^ 2, dims=1))
+    cos_dist = apply_edges(xi_dot_xj, g, xi=xn, xj=xn)
+    α = softmax_edge_neighbors(g, l.β .* cos_dist)
+
+    x = propagate(g, +; xj=x, e=α) do xi, xj, α
+        α .* xj
+    end
+
+    return x
+end
+
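
A hedged usage sketch of the new layer (`rand_graph` and the sizes are illustrative). Since `AGNNConv` carries no projection matrix, the feature dimension is preserved:

```julia
using GraphNeuralNetworks

g = rand_graph(10, 30)
x = randn(Float32, 3, 10)
l = AGNNConv()                   # β initialized to 1f0
y = l(g, x)                      # 3×10: attention-weighted sum over N(i) ∪ {i}
```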