Line data Source code
1 : #ifndef SOLVER_NONLOCAL_LINEAR
2 : #define SOLVER_NONLOCAL_LINEAR
3 : #include "Operator/Operator.H"
4 : #include <AMReX_MLMG.H>
5 :
6 : namespace Solver
7 : {
8 : namespace Nonlocal
9 : {
10 : /// \brief Multigrid Linear solver for multicomponent, multi-level operators
11 : ///
12 : /// This class is a thin wrapper for the `amrex::MLMG` solver.
13 : /// It exists to set a range of default MLMG settings automatically, for instance,
14 : /// `setCFStrategy`, which may not be obvious to the user.
15 : ///
16 : /// It also exists as a compatibility layer so that future fixes for compatibility
17 : /// with AMReX can be implemented here.
18 : class Linear // : public amrex::MLMG
19 : {
20 : public:
21 :
22 36 : Linear() //Operator::Operator<Grid::Node>& a_lp) : MLMG(a_lp), linop(a_lp)
23 36 : {
24 36 : }
25 :
26 : Linear(Operator::Operator<Grid::Node>& a_lp)
27 : {
28 : this->Define(a_lp);
29 : }
30 :
31 36 : ~Linear()
32 : {
33 36 : if (m_defined) Clear();
34 36 : }
35 :
36 734 : void Define(Operator::Operator<Grid::Node>& a_lp)
37 : {
38 734 : if (m_defined) Util::Abort(INFO, "Solver cannot be re-defined");
39 734 : this->linop = &a_lp;
40 734 : this->mlmg = new amrex::MLMG(a_lp);
41 734 : m_defined = true;
42 734 : PrepareMLMG(*mlmg);
43 734 : }
44 734 : void Clear()
45 : {
46 734 : if (!m_defined) Util::Abort(INFO, "Solver cannot be cleared if not defined");
47 734 : this->linop = nullptr;
48 734 : if (this->mlmg) delete this->mlmg;
49 734 : this->mlmg = nullptr;
50 734 : m_defined = false;
51 734 : }
52 :
53 : Set::Scalar solveaffine(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
54 : amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
55 : Real a_tol_rel, Real a_tol_abs, bool copyrhs = false,
56 : const char* checkpoint_file = nullptr)
57 : {
58 : if (!m_defined) Util::Abort(INFO, "Solver not defined");
59 : amrex::Vector<amrex::MultiFab*> rhs_tmp(a_rhs.size());
60 : amrex::Vector<amrex::MultiFab*> zero_tmp(a_rhs.size());
61 : for (int i = 0; i < rhs_tmp.size(); i++)
62 : {
63 : rhs_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
64 : zero_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
65 : rhs_tmp[i]->setVal(0.0);
66 : zero_tmp[i]->setVal(0.0);
67 : Util::Message(INFO, rhs_tmp[i]->norm0());
68 : }
69 :
70 : linop->SetHomogeneous(false);
71 : mlmg->apply(rhs_tmp, zero_tmp);
72 :
73 : for (int lev = 0; lev < rhs_tmp.size(); lev++)
74 : {
75 : amrex::Box domain = linop->Geom(lev).Domain();
76 : domain.convert(amrex::IntVect::TheNodeVector());
77 : const Dim3 lo = amrex::lbound(domain), hi = amrex::ubound(domain);
78 : for (MFIter mfi(*rhs_tmp[lev], amrex::TilingIfNotGPU());mfi.isValid();++mfi)
79 : {
80 : amrex::Box bx = mfi.growntilebox(rhs_tmp[lev]->nGrow());
81 : bx = bx & domain;
82 : amrex::Array4<amrex::Real> const& rhstmp = rhs_tmp[lev]->array(mfi);
83 : for (int n = 0; n < rhs_tmp[lev]->nComp(); n++)
84 : {
85 : amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
86 : {
87 : bool AMREX_D_DECL(xmin = (i == lo.x), ymin = (j == lo.y), zmin = (k == lo.z)),
88 : AMREX_D_DECL(xmax = (i == hi.x), ymax = (j == hi.y), zmax = (k == hi.z));
89 : if (AMREX_D_TERM(xmax || xmin, || ymax || ymin, || zmax || zmin))
90 : rhstmp(i, j, k, n) = 0.0;
91 : else
92 : rhstmp(i, j, k, n) *= -1.0;
93 : });
94 : }
95 : }
96 : Util::Message(INFO, rhs_tmp[lev]->norm0());
97 : rhs_tmp[lev]->setMultiGhost(true);
98 : rhs_tmp[lev]->FillBoundaryAndSync(linop->Geom(lev).periodicity());
99 : }
100 :
101 : for (int lev = 0; lev < rhs_tmp.size(); lev++)
102 : {
103 : Util::Message(INFO, rhs_tmp[lev]->norm0());
104 : amrex::Add(*rhs_tmp[lev], *a_rhs[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
105 : if (copyrhs)
106 : amrex::Copy(*a_rhs[lev], *rhs_tmp[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
107 : Util::Message(INFO, rhs_tmp[lev]->norm0());
108 : }
109 :
110 : linop->SetHomogeneous(true);
111 : PrepareMLMG(*mlmg);
112 : Set::Scalar retval = NAN;
113 : try
114 : {
115 : retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp), a_tol_rel, a_tol_abs, checkpoint_file);
116 : }
117 : catch (const std::exception& e)
118 : {
119 : if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
120 : if (m_abort_on_fail) Util::Abort(INFO, e.what());
121 : }
122 : if (a_sol[0]->contains_nan())
123 : {
124 : dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
125 : Util::Abort(INFO);
126 : }
127 :
128 : return retval;
129 : };
130 :
131 1634 : Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
132 : amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
133 : Real a_tol_rel, Real a_tol_abs, const char* checkpoint_file = nullptr)
134 : {
135 1634 : PrepareMLMG(*mlmg);
136 1634 : Set::Scalar retval = NAN;
137 : try
138 : {
139 1634 : retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), a_tol_rel, a_tol_abs, checkpoint_file);
140 : }
141 0 : catch (const std::exception& e)
142 : {
143 0 : if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
144 0 : if (m_abort_on_fail) Util::Abort(INFO, e.what());
145 0 : }
146 1634 : return retval;
147 : };
148 : Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
149 : amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs)
150 : {
151 : PrepareMLMG(*mlmg);
152 : Set::Scalar retval = NAN;
153 : try
154 : {
155 : retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), tol_rel, tol_abs);
156 : }
157 : catch (const std::exception& e)
158 : {
159 : if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
160 : if (m_abort_on_fail) Util::Abort(INFO, e.what());
161 : }
162 : if (a_sol[0]->contains_nan())
163 : {
164 : dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
165 : Util::Abort(INFO);
166 : }
167 : return retval;
168 : };
169 : void apply(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
170 : amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol)
171 : {
172 : PrepareMLMG(*mlmg);
173 : mlmg->apply(GetVecOfPtrs(a_rhs), GetVecOfPtrs(a_sol));
174 : };
175 :
176 :
177 :
178 : void setMaxIter(const int a_max_iter) { max_iter = a_max_iter; }
179 : void setBottomMaxIter(const int a_bottom_max_iter) { bottom_max_iter = a_bottom_max_iter; }
180 : void setMaxFmgIter(const int a_max_fmg_iter) { max_fmg_iter = a_max_fmg_iter; }
181 : void setFixedIter(const int a_fixed_iter) { fixed_iter = a_fixed_iter; }
182 : void setVerbose(const int a_verbose) { verbose = a_verbose; }
183 : void setPreSmooth(const int a_pre_smooth) { pre_smooth = a_pre_smooth; }
184 : void setPostSmooth(const int a_post_smooth) { post_smooth = a_post_smooth; }
185 :
186 0 : void dumpOnConvergenceFail(const amrex::Vector<amrex::MultiFab*>& a_sol_mf,
187 : const amrex::Vector<amrex::MultiFab const*>& a_rhs_mf)
188 : {
189 0 : int nlevs = a_sol_mf.size();
190 0 : int ncomps = a_sol_mf[0]->nComp();
191 :
192 0 : amrex::Vector<amrex::Geometry> geom;
193 0 : amrex::Vector<int> iter;
194 0 : amrex::Vector<amrex::IntVect> refratio;
195 0 : amrex::Vector<std::string> names;
196 0 : for (int i = 0; i < nlevs; i++)
197 : {
198 0 : geom.push_back(linop->Geom(i));
199 0 : iter.push_back(0);
200 0 : if (i > 0) refratio.push_back(amrex::IntVect(2));
201 : }
202 0 : for (int n = 0; n < ncomps; n++)
203 : {
204 0 : names.push_back("var" + std::to_string(n));
205 : }
206 :
207 0 : std::string outputdir = Util::GetFileName();
208 0 : WriteMultiLevelPlotfile(outputdir + "/mlmg_sol", nlevs,
209 0 : amrex::GetVecOfConstPtrs(a_sol_mf),
210 : names, geom, 0, iter, refratio);
211 0 : WriteMultiLevelPlotfile(outputdir + "/mlmg_rhs", nlevs,
212 0 : amrex::GetVecOfConstPtrs(a_rhs_mf),
213 : names, geom, 0, iter, refratio);
214 :
215 0 : Set::Field<Set::Scalar> res_mf(nlevs);
216 0 : for (int lev = 0; lev < nlevs; lev++)
217 : {
218 0 : res_mf.Define(lev, a_sol_mf[lev]->boxArray(), a_sol_mf[lev]->DistributionMap(),
219 0 : ncomps, a_sol_mf[lev]->nGrow());
220 : }
221 :
222 0 : mlmg->compResidual(amrex::GetVecOfPtrs(res_mf), a_sol_mf, a_rhs_mf);
223 :
224 0 : WriteMultiLevelPlotfile(outputdir + "/mlmg_res", nlevs,
225 0 : amrex::GetVecOfConstPtrs(res_mf),
226 : names, geom, 0, iter, refratio);
227 :
228 0 : }
229 :
230 : //using MLMG::solve;
231 : protected:
232 : int max_iter = -1;
233 : int bottom_max_iter = -1;
234 : int max_fmg_iter = -1;
235 : int fixed_iter = -1;
236 : int verbose = -1;
237 : int pre_smooth = -1;
238 : int post_smooth = -1;
239 : int final_smooth = -1;
240 : int bottom_smooth = -1;
241 : std::string bottom_solver;
242 : Set::Scalar cg_tol_rel = -1.0;
243 : Set::Scalar cg_tol_abs = -1.0;
244 : Set::Scalar bottom_tol_rel = -1.0;
245 : Set::Scalar bottom_tol_abs = -1.0;
246 : Set::Scalar tol_rel = -1.0;
247 : Set::Scalar tol_abs = -1.0;
248 : Set::Scalar omega = -1.0;
249 : bool average_down_coeffs = false;
250 : bool normalize_ddw = false;
251 :
252 : Operator::Operator<Grid::Node>* linop;
253 : amrex::MLMG* mlmg;
254 :
255 2368 : void PrepareMLMG(amrex::MLMG& mlmg)
256 : {
257 2368 : if (!m_defined) Util::Message(INFO, "Solver not defined");
258 2368 : mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
259 2368 : mlmg.setCFStrategy(MLMG::CFStrategy::ghostnodes);
260 2368 : mlmg.setFinalFillBC(false);
261 2368 : mlmg.setMaxFmgIter(100000000);
262 :
263 :
264 2368 : if (max_iter >= 0) mlmg.setMaxIter(max_iter);
265 2368 : if (bottom_max_iter >= 0) mlmg.setBottomMaxIter(bottom_max_iter);
266 2368 : if (max_fmg_iter >= 0) mlmg.setMaxFmgIter(max_fmg_iter);
267 2368 : if (fixed_iter >= 0) mlmg.setFixedIter(fixed_iter);
268 2368 : if (verbose >= 0)
269 : {
270 2368 : mlmg.setVerbose(verbose - 1);
271 2368 : if (verbose > 4) mlmg.setBottomVerbose(verbose);
272 2368 : else mlmg.setBottomVerbose(0);
273 : }
274 :
275 2368 : if (pre_smooth >= 0) mlmg.setPreSmooth(pre_smooth);
276 2368 : if (post_smooth >= 0) mlmg.setPostSmooth(post_smooth);
277 2368 : if (final_smooth >= 0) mlmg.setFinalSmooth(final_smooth);
278 2368 : if (bottom_smooth >= 0) mlmg.setBottomSmooth(bottom_smooth);
279 :
280 2368 : if (bottom_solver == "cg") mlmg.setBottomSolver(MLMG::BottomSolver::cg);
281 2368 : else if (bottom_solver == "bicgstab") mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
282 2368 : else if (bottom_solver == "smoother") mlmg.setBottomSolver(MLMG::BottomSolver::smoother);
283 :
284 2368 : if (bottom_tol_rel >= 0) mlmg.setBottomTolerance(bottom_tol_rel);
285 2368 : if (bottom_tol_abs >= 0) mlmg.setBottomToleranceAbs(bottom_tol_abs);
286 :
287 2368 : if (omega >= 0) this->linop->SetOmega(omega);
288 2368 : if (average_down_coeffs) this->linop->SetAverageDownCoeffs(true);
289 2368 : if (normalize_ddw) this->linop->SetNormalizeDDW(true);
290 2368 : }
291 :
292 :
293 : public:
294 : // These are the parameters that are read in for a standard
295 : // multigrid linear solve.
296 30 : static void Parse(Linear& value, amrex::ParmParse& pp)
297 : {
298 : // Max number of iterations to perform before erroring out
299 30 : pp_query("max_iter", value.max_iter);
300 :
301 : // Max number of iterations on the bottom solver
302 30 : pp_query("bottom_max_iter", value.bottom_max_iter);
303 :
304 : // Max number of F-cycle iterations to perform
305 30 : pp_query("max_fmg_iter", value.max_fmg_iter);
306 :
307 : // DEPRICATED - do not use
308 30 : if (pp.contains("max_fixed_iter"))
309 0 : Util::Abort(INFO, "max_fixed_iter is depricated. Use fixed_iter instead.");
310 :
311 : // Number of fixed iterations to perform before exiting gracefully
312 30 : pp_query("fixed_iter", value.fixed_iter);
313 :
314 : // Verbosity of the solver (1-5)
315 30 : pp_query("verbose", value.verbose);
316 :
317 : // Number of smoothing operations before bottom solve (2)
318 30 : pp_query("pre_smooth", value.pre_smooth);
319 :
320 : // Number of smoothing operations after bottom solve (2)
321 30 : pp_query("post_smooth", value.post_smooth);
322 :
323 : // Number of final smoothing operations when smoother is used as bottom solver (8)
324 30 : pp_query("final_smooth", value.final_smooth);
325 :
326 : // Additional smoothing after bottom CG solver (0)
327 30 : pp_query("bottom_smooth", value.bottom_smooth);
328 :
329 : // The method that is used for the multigrid bottom solve (cg, bicgstab, smoother)
330 30 : pp_query("bottom_solver", value.bottom_solver);
331 :
332 30 : if (pp.contains("cg_tol_rel"))
333 0 : Util::Abort(INFO, "cg_tol_rel is depricated. Use bottom_tol_rel instead.");
334 30 : if (pp.contains("cg_tol_abs"))
335 0 : Util::Abort(INFO, "cg_tol_abs is depricated. Use bottom_tol_abs instead.");
336 :
337 : // Relative tolerance on bottom solver
338 30 : pp_query("bottom_tol_rel", value.bottom_tol_rel);
339 :
340 : // Absolute tolerance on bottom solver
341 30 : pp_query("bottom_tol_abs", value.bottom_tol_abs);
342 :
343 : // Relative tolerance
344 30 : pp_query("tol_rel", value.tol_rel);
345 :
346 : // Absolute tolerance
347 30 : pp_query("tol_abs", value.tol_abs);
348 :
349 : // Omega (used in gauss-seidel solver)
350 30 : pp_query("omega", value.omega);
351 :
352 : // Whether to average down coefficients or use the ones given.
353 : // (Setting this to true is important for fracture.)
354 30 : pp_query("average_down_coeffs", value.average_down_coeffs);
355 :
356 : // Whether to normalize DDW when calculating the diagonal.
357 : // This is primarily used when DDW is near-singular - like when there
358 : // is a "void" region or when doing phase field fracture.
359 30 : pp_query("normalize_ddw", value.normalize_ddw);
360 :
361 : // [false]
362 : // If set to true, output diagnostic multifab information
363 : // whenever the MLMG solver fails to converge.
364 : // (Note: you must also set :code:`amrex.signalhandling=0`
365 : // and :code:`amrex.throw_exception=1` for this to work.)
366 30 : pp_query("dump_on_fail", value.m_dump_on_fail);
367 :
368 : // [true]
369 : // If set to false, MLMG will not die if convergence criterion
370 : // is not reached.
371 : // (Note: you must also set :code:`amrex.signalhandling=0`
372 : // and :code:`amrex.throw_exception=1` for this to work.)
373 30 : pp_query("abort_on_fail", value.m_abort_on_fail);
374 :
375 30 : if (value.m_dump_on_fail || !value.m_abort_on_fail)
376 : {
377 2 : IO::ParmParse pp("amrex");
378 2 : pp.add("signal_handling", 0);
379 2 : pp.add("throw_exception", 1);
380 2 : }
381 :
382 30 : }
383 : protected:
384 : bool m_defined = false;
385 :
386 : bool m_dump_on_fail = false;
387 : bool m_abort_on_fail = true;
388 :
389 : };
390 : }
391 : }
392 : #endif
|