@@ -13,7 +13,7 @@ abstract type GNNLayer end
1313
1414
1515"""
16- WithGraph(model, g::GNNGraph; traingraph=false)
16+ WithGraph(model, g::GNNGraph; traingraph=false)
1717
1818A type wrapping the `model` and tying it to the graph `g`.
1919In the forward pass, can only take feature arrays as inputs,
@@ -38,17 +38,31 @@ x2 = rand(Float32, 2, 4)
3838@assert wg(g2, x2) == model(g2, x2)
3939```
4040"""
41- struct WithGraph{M}
42- model:: M
43- g:: GNNGraph
44- traingraph:: Bool
41+ struct WithGraph{M, G <: GNNGraph }
42+ model:: M
43+ g:: G
44+ traingraph:: Bool
4545end
4646
4747WithGraph (model, g:: GNNGraph ; traingraph= false ) = WithGraph (model, g, traingraph)
4848
4949@functor WithGraph
5050Flux. trainable (l:: WithGraph ) = l. traingraph ? (l. model, l. g) : (l. model,)
5151
52+ # Work around
53+ # https://github.com/FluxML/Flux.jl/issues/1733
54+ # Revisit after
55+ # https://github.com/FluxML/Flux.jl/pull/1742
56+ function Flux. destructure (m:: WithGraph )
57+ @assert m. traingraph == false # TODO
58+ p, re = Flux. destructure (m. model)
59+ function re_withgraph (x)
60+ WithGraph (re (x), m. g, m. traingraph)
61+ end
62+
63+ return p, re_withgraph
64+ end
65+
5266(l:: WithGraph )(g:: GNNGraph , x... ; kws... ) = l. model (g, x... ; kws... )
5367(l:: WithGraph )(x... ; kws... ) = l. model (l. g, x... ; kws... )
5468
@@ -86,15 +100,15 @@ julia> m(g, x)
86100```
87101"""
88102struct GNNChain{T} <: GNNLayer
89- layers:: T
90-
91- GNNChain (xs... ) = new {typeof(xs)} (xs)
92-
93- function GNNChain (; kw... )
94- :layers in Base. keys (kw) && throw (ArgumentError (" a GNNChain cannot have a named layer called `layers`" ))
95- isempty (kw) && return new {Tuple{}} (())
96- new {typeof(values(kw))} (values (kw))
97- end
103+ layers:: T
104+
105+ GNNChain (xs... ) = new {typeof(xs)} (xs)
106+
107+ function GNNChain (; kw... )
108+ :layers in Base. keys (kw) && throw (ArgumentError (" a GNNChain cannot have a named layer called `layers`" ))
109+ isempty (kw) && return new {Tuple{}} (())
110+ new {typeof(values(kw))} (values (kw))
111+ end
98112end
99113
100114@forward GNNChain. layers Base. getindex, Base. length, Base. first, Base. last,
0 commit comments