diff --git a/asQ/diag_preconditioner.py b/asQ/diag_preconditioner.py
index 22b20236..004bde66 100644
--- a/asQ/diag_preconditioner.py
+++ b/asQ/diag_preconditioner.py
@@ -14,6 +14,64 @@
 import asQ.complex_proxy.vector as cpx
 
 
+def construct_riesz_map(W, prefix, fieldsplit=False, riesz_options=None):
+    """
+    Construct a Riesz map into W, assuming W is a complex-proxy
+    FunctionSpace for some real FunctionSpace V.
+
+    :arg W: a complex-proxy FunctionSpace for V.
+    :arg prefix: the prefix for the PETSc options for the projection solve.
+    :arg fieldsplit: whether to solve each component of a mixed space separately.
+    :arg riesz_options: PETSc options for the projection solve. Defaults to a direct solve.
+    """
+    # default is to solve directly
+    if riesz_options is None:
+        riesz_options = {
+            'ksp_type': 'preonly',
+            'pc_type': 'lu',
+            'pc_factor_mat_solver_type': 'mumps',
+            'mat_type': 'aij'
+        }
+
+    # mixed mass matrices are decoupled so solve separately
+    if fieldsplit:
+        full_riesz_options = {
+            'ksp_type': 'preonly',
+            'mat_type': 'nest',
+            'pc_type': 'fieldsplit',
+            'pc_field_split_type': 'additive',
+            'fieldsplit': riesz_options
+        }
+    else:
+        full_riesz_options = riesz_options
+
+    # mat types
+    mat_type = PETSc.Options().getString(
+        f"{prefix}mat_type",
+        default=full_riesz_options['mat_type'])
+
+    sub_mat_type = PETSc.Options().getString(
+        f"{prefix}fieldsplit_mat_type",
+        default=riesz_options['mat_type'])
+
+    # input for riesz map
+    rhs = fd.Function(W)
+
+    # construct forms
+    v = fd.TestFunction(W)
+    u = fd.TrialFunction(W)
+
+    a = fd.assemble(fd.inner(u, v)*fd.dx,
+                    mat_type=mat_type,
+                    sub_mat_type=sub_mat_type)
+
+    # create LinearSolver
+    rmap = fd.LinearSolver(a, solver_parameters=full_riesz_options,
+                           options_prefix=f"{prefix}")
+
+    return rmap, rhs
+
+
 class DiagFFTPC(object):
     """
     PETSc options:
@@ -109,23 +167,19 @@ def initialize(self, pc):
         self.w_all = self.aaos.w_all
 
         # basic model function space
-        self.blockV = self.aaos.function_space
+        self.function_space = self.aaos.function_space
         W_all = self.aaos.function_space_all
 
         # sanity check
-        assert (self.blockV.dim()*paradiag.nlocal_timesteps == W_all.dim())
-
-        # Input/Output wrapper Functions for all-at-once residual being acted on
-        self.xf = fd.Function(W_all)  # input
-        self.yf = fd.Function(W_all)  # output
+        assert (self.function_space.dim()*paradiag.nlocal_timesteps == W_all.dim())
 
         # Gamma coefficients
         exponents = np.arange(self.ntimesteps)/self.ntimesteps
-        self.Gam = paradiag.alpha**exponents
+        self.gamma = paradiag.alpha**exponents
 
         slice_begin = self.aaos.transform_index(0, from_range='slice', to_range='window')
         slice_end = slice_begin + self.nlocal_timesteps
-        self.Gam_slice = self.Gam[slice_begin:slice_end]
+        self.gamma_slice = self.gamma[slice_begin:slice_end]
 
         # circulant eigenvalues
         C1col = np.zeros(self.ntimesteps)
@@ -137,37 +191,37 @@ def initialize(self, pc):
         C1col[:2] = np.array([1, -1])/dt
         C2col[:2] = np.array([theta, 1-theta])
 
-        self.D1 = np.sqrt(self.ntimesteps)*fft(self.Gam*C1col)
-        self.D2 = np.sqrt(self.ntimesteps)*fft(self.Gam*C2col)
+        self.D1 = fft(self.gamma*C1col, norm='backward')
+        self.D2 = fft(self.gamma*C2col, norm='backward')
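+        # note: norm='backward' is the default unnormalised forward FFT, so
+        # these eigenvalues match the unscaled transform in to_eigenbasis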
 
         # Block system setup
-        # First need to build the vector function space version of blockV
-        self.CblockV = cpx.FunctionSpace(self.blockV)
+        # First need to build the vector function space version of function_space
+        self.cpx_function_space = cpx.FunctionSpace(self.function_space)
 
         # set the boundary conditions to zero for the residual
-        self.CblockV_bcs = tuple((cb
-                                  for bc in self.aaos.boundary_conditions
-                                  for cb in cpx.DirichletBC(self.CblockV, self.blockV,
-                                                            bc, 0*bc.function_arg)))
+        self.block_bcs = tuple((cb
+                                for bc in self.aaos.boundary_conditions
+                                for cb in cpx.DirichletBC(self.cpx_function_space, self.function_space,
+                                                          bc, 0*bc.function_arg)))
 
         # function to do global reduction into for average block jacobian
         if jac_state in ('window', 'slice'):
-            self.ureduce = fd.Function(self.blockV)
-            self.uwrk = fd.Function(self.blockV)
+            self.ureduce = fd.Function(self.function_space)
+            self.uwrk = fd.Function(self.function_space)
 
         # input and output functions to the block solve
-        self.Jprob_in = fd.Function(self.CblockV)
-        self.Jprob_out = fd.Function(self.CblockV)
+        self.block_rhs = fd.Function(self.cpx_function_space)
+        self.block_sol = fd.Function(self.cpx_function_space)
 
-        # A place to store the real/imag components of the all-at-once residual after fft
-        self.xfi = fd.Function(W_all)
-        self.xfr = fd.Function(W_all)
+        # A place to store the real/imag components of the all-at-once residual
+        self.xreal = fd.Function(W_all)
+        self.ximag = fd.Function(W_all)
 
         # setting up the FFT stuff
         # construct simply dist array and 1d fftn:
         subcomm = Subcomm(self.ensemble.ensemble_comm, [0, 1])
         # dimensions of space-time data in this ensemble_comm
-        nlocal = self.blockV.node_set.size
+        nlocal = self.function_space.node_set.size
         NN = np.array([self.ntimesteps, nlocal], dtype=int)
         # transfer pencil is aligned along axis 1
         self.p0 = Pencil(subcomm, NN, axis=1)
@@ -179,49 +233,27 @@ def initialize(self, pc):
         self.a1 = np.zeros(self.p1.subshape, complex)
         self.transfer = self.p0.transfer(self.p1, complex)
 
-        # setting up the Riesz map
-        default_riesz_method = {
-            'ksp_type': 'preonly',
-            'pc_type': 'lu',
-            'pc_factor_mat_solver_type': 'mumps',
-            'mat_type': 'aij'
-        }
+        # setting up the Riesz map to project residual into complex space
+        is_mixed = isinstance(self.function_space.ufl_element(), fd.MixedElement)
+        rmap_rhs = construct_riesz_map(self.cpx_function_space,
+                                       prefix=f"{prefix}{self.prefix}mass_",
+                                       fieldsplit=is_mixed)
+        self.riesz_proj, self.riesz_rhs = rmap_rhs
+
+        # Now need to build the block solvers
 
-        # mixed mass matrices are decoupled so solve seperately
-        if isinstance(self.blockV.ufl_element(), fd.MixedElement):
-            default_riesz_parameters = {
-                'ksp_type': 'preonly',
-                'mat_type': 'nest',
-                'pc_type': 'fieldsplit',
-                'pc_field_split_type': 'additive',
-                'fieldsplit': default_riesz_method
-            }
-        else:
-            default_riesz_parameters = default_riesz_method
-
-        # we need to pass the mat_types to assemble directly because
-        # it won't pick them up from Options
-
-        riesz_mat_type = PETSc.Options().getString(
-            f"{prefix}{self.prefix}mass_mat_type",
-            default=default_riesz_parameters['mat_type'])
-
-        riesz_sub_mat_type = PETSc.Options().getString(
-            f"{prefix}{self.prefix}mass_fieldsplit_mat_type",
-            default=default_riesz_method['mat_type'])
-
-        # input for the Riesz map
-        self.xtemp = fd.Function(self.CblockV)
-        v = fd.TestFunction(self.CblockV)
-        u = fd.TrialFunction(self.CblockV)
-
-        a = fd.assemble(fd.inner(u, v)*fd.dx,
-                        mat_type=riesz_mat_type,
-                        sub_mat_type=riesz_sub_mat_type)
-
-        self.Proj = fd.LinearSolver(a, solver_parameters=default_riesz_parameters,
-                                    options_prefix=f"{prefix}{self.prefix}mass_")
+        # time-average function to linearise around
+        self.u0 = fd.Function(self.cpx_function_space)
+        self.block_solvers = tuple((self._make_block(i, f"{prefix}{self.prefix}")
+                                    for i in range(self.nlocal_timesteps)))
+
+        self.initialized = True
+
+    def _make_block(self, i, prefix):
+        """
+        Construct the LinearVariationalSolver for block index i.
+        """
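+        # the complex-valued block operator is A = d1*M + d2*K, where M is
+        # the mass form, K the Jacobian of form_function, and d1 = D1[ii],
+        # d2 = D2[ii] are the circulant eigenvalues for this block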
 
         # building the Jacobian of the nonlinear term
         # what we want is a block diagonal matrix in the 2x2 system
         # coupling the real and imaginary parts.
@@ -231,7 +263,6 @@ def initialize(self, pc):
         # This is constructed by cpx.derivative
 
         # Building the nonlinear operator
-        self.Jsolvers = []
 
         # which form to linearise around
         valid_linearisations = ['consistent', 'user']
@@ -248,65 +279,48 @@ def initialize(self, pc):
 
             form_function = partial(form_function, t=self.t_average)
 
-        # Now need to build the block solver
-        self.u0 = fd.Function(self.CblockV)  # time average to linearise around
-
-        # building the block problem solvers
-        for i in range(self.nlocal_timesteps):
-            ii = self.aaos.transform_index(i, from_range='slice', to_range='window')
-            d1 = self.D1[ii]
-            d2 = self.D2[ii]
-
-            M, D1r, D1i = cpx.BilinearForm(self.CblockV, d1, form_mass, return_z=True)
-            K, D2r, D2i = cpx.derivative(d2, form_function, self.u0, return_z=True)
-
-            A = M + K
-
-            # The rhs
-            v = fd.TestFunction(self.CblockV)
-            L = fd.inner(v, self.Jprob_in)*fd.dx
-
-            # pass sigma into PC:
-            sigma = self.D1[ii]**2/self.D2[ii]
-            sigma_inv = self.D2[ii]**2/self.D1[ii]
-            appctx_h = {}
-            appctx_h["sr"] = fd.Constant(np.real(sigma))
-            appctx_h["si"] = fd.Constant(np.imag(sigma))
-            appctx_h["sinvr"] = fd.Constant(np.real(sigma_inv))
-            appctx_h["sinvi"] = fd.Constant(np.imag(sigma_inv))
-            appctx_h["D2r"] = D2r
-            appctx_h["D2i"] = D2i
-            appctx_h["D1r"] = D1r
-            appctx_h["D1i"] = D1i
-
-            # Options with prefix 'diagfft_block_' apply to all blocks by default
-            # If any options with prefix 'diagfft_block_{i}' exist, where i is the
-            # block number, then this prefix is used instead (like pc fieldsplit)
-
-            block_prefix = f"{prefix}{self.prefix}block_"
-            for k, v in PETSc.Options().getAll().items():
-                if k.startswith(f"{block_prefix}{str(ii)}_"):
-                    block_prefix = f"{block_prefix}{str(ii)}_"
-                    break
-
-            jprob = fd.LinearVariationalProblem(A, L, self.Jprob_out,
-                                                bcs=self.CblockV_bcs)
-            Jsolver = fd.LinearVariationalSolver(jprob,
-                                                 appctx=appctx_h,
-                                                 options_prefix=block_prefix)
-            # multigrid transfer manager
-            if f'{prefix}transfer_managers' in paradiag.block_ctx:
-                # Jsolver.set_transfer_manager(paradiag.block_ctx['diag_transfer_managers'][ii])
-                tm = paradiag.block_ctx[f'{prefix}transfer_managers'][i]
-                Jsolver.set_transfer_manager(tm)
-                tm_set = (Jsolver._ctx.transfer_manager is tm)
-
-                if tm_set is False:
-                    print(f"transfer manager not set on Jsolvers[{ii}]")
-
-            self.Jsolvers.append(Jsolver)
+        ii = self.aaos.transform_index(i, from_range='slice', to_range='window')
+        d1 = self.D1[ii]
+        d2 = self.D2[ii]
 
-        self.initialized = True
+        M, D1r, D1i = cpx.BilinearForm(self.cpx_function_space, d1, form_mass, return_z=True)
+        K, D2r, D2i = cpx.derivative(d2, form_function, self.u0, return_z=True)
+
+        A = M + K
+
+        # The rhs
+        v = fd.TestFunction(self.cpx_function_space)
+        L = fd.inner(v, self.block_rhs)*fd.dx
+
+        # application context for the block solver
+        appctx_h = {}
+
+        # Options with prefix 'diagfft_block_' apply to all blocks by default
+        # If any options with prefix 'diagfft_block_{i}' exist, where i is the
+        # block number, then this prefix is used instead (like pc fieldsplit)
+
+        block_prefix = f"{prefix}block_"
+        for k, v in PETSc.Options().getAll().items():
+            if k.startswith(f"{block_prefix}{str(ii)}_"):
+                block_prefix = f"{block_prefix}{str(ii)}_"
+                break
+
+        block_prob = fd.LinearVariationalProblem(A, L, self.block_sol,
+                                                 bcs=self.block_bcs)
+        block_solver = fd.LinearVariationalSolver(block_prob,
+                                                  appctx=appctx_h,
+                                                  options_prefix=block_prefix)
+        # multigrid transfer manager
+        if 'diag_transfer_managers' in self.paradiag.block_ctx:
+            # block_solver.set_transfer_manager(self.paradiag.block_ctx['diag_transfer_managers'][ii])
+            tm = self.paradiag.block_ctx['diag_transfer_managers'][i]
+            block_solver.set_transfer_manager(tm)
+            tm_set = (block_solver._ctx.transfer_manager is tm)
+
+            if tm_set is False:
+                print(f"transfer manager not set on block_solvers[{ii}]")
+
+        return block_solver
 
     def _record_diagnostics(self):
         """
@@ -315,7 +329,7 @@ def _record_diagnostics(self):
         Must be called exactly once at the end of each apply()
         """
         for i in range(self.aaos.nlocal_timesteps):
-            its = self.Jsolvers[i].snes.getLinearSolveIterations()
+            its = self.block_solvers[i].snes.getLinearSolveIterations()
             self.paradiag.block_iterations.dlocal[i] += its
 
     @PETSc.Log.EventDecorator()
@@ -351,87 +365,129 @@ def update(self, pc):
 
     @PETSc.Log.EventDecorator()
    @memprofile
-    def apply(self, pc, x, y):
-
-        # copy petsc vec into Function
-        # hopefully this works
-        with self.xf.dat.vec_wo as v:
-            x.copy(v)
-
-        # get array of basis coefficients
-        with self.xf.dat.vec_ro as v:
-            parray = v.array_r.reshape((self.aaos.nlocal_timesteps,
-                                        self.blockV.node_set.size))
-        # This produces an array whose rows are time slices
-        # and columns are finite element basis coefficients
-
-        ######################
-        # Diagonalise - scale, transfer, FFT, transfer, Copy
-        # Scale
-        # is there a better way to do this with broadcasting?
-        parray = (1.0+0.j)*(self.Gam_slice*parray.T).T*np.sqrt(self.ntimesteps)
-        # transfer forward
-        self.a0[:] = parray[:]
+    def to_eigenbasis(self, xreal, ximag, output='real,imag'):
+        """
+        In-place transform of the complex vector (xreal, ximag) to the preconditioner (block-)eigenbasis.
+        :arg xreal: real part of input and output
+        :arg ximag: imaginary part of input and output
+        :arg output: which parts of the result to copy back into xreal and/or ximag.
+        """
+        # copy data into working array
+        with xreal.dat.vec_ro as v:
+            self.a0.real[:] = v.array_r.reshape((self.aaos.nlocal_timesteps,
+                                                 self.function_space.node_set.size))
+        with ximag.dat.vec_ro as v:
+            self.a0.imag[:] = v.array_r.reshape((self.aaos.nlocal_timesteps,
+                                                 self.function_space.node_set.size))
+
+        # alpha-weighting (applied to the full complex array)
+        self.a0[:] = (self.gamma_slice*self.a0.T).T
+
+        # transpose forward
         self.transfer.forward(self.a0, self.a1)
+
         # FFT
         self.a1[:] = fft(self.a1, axis=0)
-        # transfer backward
+        # transpose backward
         self.transfer.backward(self.a1, self.a0)
-        # Copy into xfi, xfr
-        parray[:] = self.a0[:]
-        with self.xfr.dat.vec_wo as v:
-            v.array[:] = parray.real.reshape(-1)
-        with self.xfi.dat.vec_wo as v:
-            v.array[:] = parray.imag.reshape(-1)
-        #####################
-        # Do the block solves
+        # copy back into output
+        if 'real' in output:
+            with xreal.dat.vec_wo as v:
+                v.array[:] = self.a0.real.reshape(-1)
+        if 'imag' in output:
+            with ximag.dat.vec_wo as v:
+                v.array[:] = self.a0.imag.reshape(-1)
+
+    @PETSc.Log.EventDecorator()
+    @memprofile
+    def from_eigenbasis(self, xreal, ximag, output='real,imag'):
+        """
+        In-place transform of the complex vector (xreal, ximag) from the preconditioner (block-)eigenbasis.
+        :arg xreal: real part of input and output
+        :arg ximag: imaginary part of input and output
+        :arg output: which parts of the result to copy back into xreal and/or ximag.
+        """
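+        # inverse of to_eigenbasis: transpose, IFFT, transpose back, then
+        # divide out the alpha-weighting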
+ """ + # copy data into working array + with xreal.dat.vec_ro as v: + self.a0.real[:] = v.array_r.reshape((self.aaos.nlocal_timesteps, + self.function_space.node_set.size)) + with ximag.dat.vec_ro as v: + self.a0.imag[:] = v.array_r.reshape((self.aaos.nlocal_timesteps, + self.function_space.node_set.size)) + + # transpose forward + self.transfer.forward(self.a0, self.a1) + + # IFFT + self.a1[:] = ifft(self.a1, axis=0) + + # transpose backward + self.transfer.backward(self.a1, self.a0) + + # alpha-weighting + self.a0[:] = ((1.0/self.gamma_slice)*self.a0.T).T + + # copy back into output + if 'real' in output: + with xreal.dat.vec_wo as v: + v.array[:] = self.a0.real.reshape(-1) + if 'imag' in output: + with ximag.dat.vec_wo as v: + v.array[:] = self.a0.imag.reshape(-1) + + @PETSc.Log.EventDecorator() + @memprofile + def solve_blocks(self, xreal, ximag): + """ + Solve each of the blocks in the diagonalised preconditioner with + complex vector (xreal,ximag) as the right-hand-sides. + :arg xreal: real part of input and output + :arg ximag: real part of input and output + """ + def get_field(i, x): + return self.aaos.get_field_components(i, f_alls=x.subfunctions) for i in range(self.aaos.nlocal_timesteps): - # copy the data into solver input - self.xtemp.assign(0.) + self.block_rhs.assign(0.) + self.block_sol.assign(0.) - cpx.set_real(self.xtemp, self.aaos.get_field_components(i, f_alls=self.xfr.subfunctions)) - cpx.set_imag(self.xtemp, self.aaos.get_field_components(i, f_alls=self.xfi.subfunctions)) + # copy the data into solver input + cpx.set_real(self.riesz_rhs, get_field(i, xreal)) + cpx.set_imag(self.riesz_rhs, get_field(i, ximag)) - # Do a project for Riesz map, to be superceded - # when we get Cofunction - self.Proj.solve(self.Jprob_in, self.xtemp) + # Do a project for Riesz map, to be superceded when we get Cofunction + self.riesz_proj.solve(self.block_rhs, self.riesz_rhs) # solve the block system - self.Jprob_out.assign(0.) - self.Jsolvers[i].solve() + self.block_solvers[i].solve() # copy the data from solver output - cpx.get_real(self.Jprob_out, self.aaos.get_field_components(i, f_alls=self.xfr.subfunctions)) - cpx.get_imag(self.Jprob_out, self.aaos.get_field_components(i, f_alls=self.xfi.subfunctions)) - - ###################### - # Undiagonalise - Copy, transfer, IFFT, transfer, scale, copy - # get array of basis coefficients - with self.xfi.dat.vec_ro as v: - parray = 1j*v.array_r.reshape((self.aaos.nlocal_timesteps, - self.blockV.node_set.size)) - with self.xfr.dat.vec_ro as v: - parray += v.array_r.reshape((self.aaos.nlocal_timesteps, - self.blockV.node_set.size)) - # transfer forward - self.a0[:] = parray[:] - self.transfer.forward(self.a0, self.a1) - # IFFT - self.a1[:] = ifft(self.a1, axis=0) - # transfer backward - self.transfer.backward(self.a1, self.a0) - parray[:] = self.a0[:] - # scale - parray = ((1.0/self.Gam_slice)*parray.T).T - # Copy into xfi, xfr - with self.yf.dat.vec_wo as v: - v.array[:] = parray.reshape(-1).real - with self.yf.dat.vec_ro as v: + cpx.get_real(self.block_sol, get_field(i, xreal)) + cpx.get_imag(self.block_sol, get_field(i, ximag)) + + @PETSc.Log.EventDecorator() + @memprofile + def apply(self, pc, x, y): + + # copy input Vec into Function + with self.xreal.dat.vec_wo as v: + x.copy(v) + self.ximag.assign(0.) 
+        self.ximag.assign(0.)
+
+        # forward FFT
+        self.to_eigenbasis(self.xreal, self.ximag)
+
+        # Do the block solves
+        self.solve_blocks(self.xreal, self.ximag)
+
+        # backward IFFT
+        self.from_eigenbasis(self.xreal, self.ximag, output='real')
+
+        # copy solution into output Vec
+        with self.xreal.dat.vec_ro as v:
             v.copy(y)
-        ################
 
         self._record_diagnostics()
diff --git a/examples/advection/periodic_dg_advection.py b/examples/advection/periodic_dg_advection.py
index b1bcae69..8326698c 100644
--- a/examples/advection/periodic_dg_advection.py
+++ b/examples/advection/periodic_dg_advection.py
@@ -111,8 +111,8 @@ def form_function(q, phi, t):
 # The PETSc solver parameters used to solve the
 # blocks in step (b) of inverting the ParaDiag matrix.
 block_parameters = {
-    'ksp_type': 'gmres',
-    'pc_type': 'bjacobi',
+    'ksp_type': 'preonly',
+    'pc_type': 'lu',
 }
 
 # The PETSc solver parameters for solving the all-at-once system.
@@ -130,22 +130,18 @@ def form_function(q, phi, t):
 #    'ksp_type': 'preonly'
 
 paradiag_parameters = {
+    'snes_type': 'ksponly',
     'snes': {
-        'linesearch_type': 'basic',
         'monitor': None,
         'converged_reason': None,
         'rtol': 1e-10,
-        'atol': 1e-12,
-        'stol': 1e-12,
     },
     'mat_type': 'matfree',
-    'ksp_type': 'preonly',
+    'ksp_type': 'richardson',
     'ksp': {
         'monitor': None,
         'converged_reason': None,
         'rtol': 1e-10,
-        'atol': 1e-12,
-        'stol': 1e-12,
     },
     'pc_type': 'python',
    'pc_python_type': 'asQ.DiagFFTPC'
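Note on the transform pair above: to_eigenbasis, solve_blocks and from_eigenbasis implement the standard alpha-circulant diagonalisation, C = diag(gamma)^-1 . IFFT . diag(D) . FFT . diag(gamma), with eigenvalues D = fft(gamma*col) for first column col (cf. self.D1 and self.D2). Below is a minimal numpy sketch of the scalar version of this solve; the values n, alpha, col and b are illustrative only, and in the preconditioner the pointwise division is replaced by the complex-proxy block solves:

    import numpy as np
    from numpy.fft import fft, ifft

    n, alpha = 8, 1e-2                   # illustrative sizes
    rng = np.random.default_rng(0)
    col = rng.standard_normal(n)         # first column, cf. gamma*C1col
    b = rng.standard_normal(n)           # right-hand side

    # alpha-circulant matrix: circulant, but entries above the
    # diagonal pick up a factor of alpha
    C = np.array([[col[(j - k) % n]*(alpha if k > j else 1.0)
                   for k in range(n)] for j in range(n)])

    gamma = alpha**(np.arange(n)/n)      # cf. self.gamma
    d = fft(gamma*col, norm='backward')  # eigenvalues, cf. self.D1/self.D2

    # weight, FFT, divide by the eigenvalues, IFFT, unweight
    y = ifft(fft(gamma*b)/d)/gamma

    assert np.allclose(C @ y.real, b)    # y solves the alpha-circulant system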