5454"""
5555 add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
5656
57- Add to graph `g` the edges with source nodes `s` and target nodes `t`.
58-
57+ Add to graph `g` the edges with source nodes `s` and target nodes `t`.
5958"""
6059function add_edges (g:: GNNGraph{<:COO_T} ,
6160 snew:: AbstractVector{<:Integer} ,
@@ -79,6 +78,25 @@ function add_edges(g::GNNGraph{<:COO_T},
7978 g. ndata, edata, g. gdata)
8079end
8180
81+
82+ """
83+ add_nodes(g::GNNGraph, n; [ndata])
84+
85+ Add `n` new nodes to graph `g`. In the
86+ new graph, these nodes will have indexes from `g.num_nodes + 1`
87+ to `g.num_nodes + n`.
88+ """
89+ function add_nodes (g:: GNNGraph{<:COO_T} , n:: Integer ; ndata= (;))
90+ ndata = normalize_graphdata (ndata, default_name= :x , n= n)
91+ ndata = cat_features (g. ndata, ndata)
92+
93+ GNNGraph (g. graph,
94+ g. num_nodes + n, g. num_edges, g. num_graphs,
95+ g. graph_indicator,
96+ ndata, g. edata, g. gdata)
97+ end
98+
99+
82100function SparseArrays. blockdiag (g1:: GNNGraph , g2:: GNNGraph )
83101 nv1, nv2 = g1. num_nodes, g2. num_nodes
84102 if g1. graph isa COO_T
@@ -117,8 +135,6 @@ function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix)
117135 O2 A2]
118136end
119137
120- # ## Cat public interfaces #############
121-
122138"""
123139 blockdiag(xs::GNNGraph...)
124140
@@ -133,14 +149,115 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
133149end
134150
135151"""
136- batch(xs ::Vector{<:GNNGraph})
152+ batch(gs ::Vector{<:GNNGraph})
137153
138154Batch together multiple `GNNGraph`s into a single one
139155containing the total number of original nodes and edges.
140156
141157Equivalent to [`SparseArrays.blockdiag`](@ref).
158+ See also [`Flux.unbatch`](@ref).
159+
160+ # Usage
161+
162+ ```juliarepl
163+ julia> g1 = rand_graph(4, 6, ndata=ones(8, 4))
164+ GNNGraph:
165+ num_nodes = 4
166+ num_edges = 6
167+ num_graphs = 1
168+ ndata:
169+ x => (8, 4)
170+ edata:
171+ gdata:
172+
173+
174+ julia> g2 = rand_graph(7, 4, ndata=zeros(8, 7))
175+ GNNGraph:
176+ num_nodes = 7
177+ num_edges = 4
178+ num_graphs = 1
179+ ndata:
180+ x => (8, 7)
181+ edata:
182+ gdata:
183+
184+
185+ julia> g12 = Flux.batch([g1, g2])
186+ GNNGraph:
187+ num_nodes = 11
188+ num_edges = 10
189+ num_graphs = 2
190+ ndata:
191+ x => (8, 11)
192+ edata:
193+ gdata:
194+
195+
196+ julia> g12.ndata.x
197+ 8×11 Matrix{Float64}:
198+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
199+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
200+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
201+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
202+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
203+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
204+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
205+ 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
206+ ```
207+ """
208+ Flux. batch (gs:: Vector{<:GNNGraph} ) = blockdiag (gs... )
209+
210+
142211"""
143- Flux. batch (xs:: Vector{<:GNNGraph} ) = blockdiag (xs... )
212+ unbatch(g::GNNGraph)
213+
214+ Opposite of the [`Flux.batch`](@ref) operation, returns
215+ an array of the individual graphs batched together in `g`.
216+
217+ See also [`Flux.batch`](@ref) and [`getgraph`](@ref).
218+
219+ # Usage
220+
221+ ```juliarepl
222+ julia> gbatched = Flux.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)])
223+ GNNGraph:
224+ num_nodes = 19
225+ num_edges = 16
226+ num_graphs = 3
227+ ndata:
228+ edata:
229+ gdata:
230+
231+ julia> Flux.unbatch(gbatched)
232+ 3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
233+ GNNGraph:
234+ num_nodes = 5
235+ num_edges = 6
236+ num_graphs = 1
237+ ndata:
238+ edata:
239+ gdata:
240+
241+ GNNGraph:
242+ num_nodes = 10
243+ num_edges = 8
244+ num_graphs = 1
245+ ndata:
246+ edata:
247+ gdata:
248+
249+ GNNGraph:
250+ num_nodes = 4
251+ num_edges = 2
252+ num_graphs = 1
253+ ndata:
254+ edata:
255+ gdata:
256+ ```
257+ """
258+ function Flux. unbatch (g:: GNNGraph )
259+ [getgraph (g, i) for i in 1 : g. num_graphs]
260+ end
144261
145262
146263"""
0 commit comments