Line data Source code
1 : #ifndef SOLVER_NONLOCAL_LINEAR
2 : #define SOLVER_NONLOCAL_LINEAR
3 : #include "Operator/Operator.H"
4 : #include <AMReX_MLMG.H>
5 : #include "IC/Trig.H"
6 :
7 : namespace Solver
8 : {
9 : namespace Nonlocal
10 : {
11 : /// \brief Multigrid Linear solver for multicomponent, multi-level operators
12 : ///
13 : /// This class is a thin wrapper for the `amrex::MLMG` solver.
14 : /// It exists to set a range of default MLMG settings automatically, for instance,
15 : /// `setCFStrategy`, which may not be obvious to the user.
16 : ///
17 : /// It also exists as a compatibility layer so that future fixes for compatibility
18 : /// with AMReX can be implemented here.
19 : class Linear // : public amrex::MLMG
20 : {
21 : public:
22 :
23 25 : Linear() //Operator::Operator<Grid::Node>& a_lp) : MLMG(a_lp), linop(a_lp)
24 25 : {
25 25 : }
26 :
27 : Linear(Operator::Operator<Grid::Node>& a_lp)
28 : {
29 : this->Define(a_lp);
30 : }
31 :
32 25 : ~Linear()
33 25 : {
34 25 : if (m_defined) Clear();
35 25 : }
36 :
37 426 : void Define(Operator::Operator<Grid::Node>& a_lp)
38 : {
39 426 : if (m_defined) Util::Abort(INFO, "Solver cannot be re-defined");
40 426 : this->linop = &a_lp;
41 426 : this->mlmg = new amrex::MLMG(a_lp);
42 426 : m_defined = true;
43 426 : PrepareMLMG(*mlmg);
44 426 : }
45 426 : void Clear()
46 : {
47 426 : if (!m_defined) Util::Abort(INFO, "Solver cannot be cleared if not defined");
48 426 : this->linop = nullptr;
49 426 : if (this->mlmg) delete this->mlmg;
50 426 : this->mlmg = nullptr;
51 426 : m_defined = false;
52 426 : }
53 :
54 : Set::Scalar solveaffine(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
55 : amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
56 : Real a_tol_rel, Real a_tol_abs, bool copyrhs = false,
57 : const char* checkpoint_file = nullptr)
58 : {
59 : if (!m_defined) Util::Abort(INFO, "Solver not defined");
60 : amrex::Vector<amrex::MultiFab*> rhs_tmp(a_rhs.size());
61 : amrex::Vector<amrex::MultiFab*> zero_tmp(a_rhs.size());
62 : for (int i = 0; i < rhs_tmp.size(); i++)
63 : {
64 : rhs_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
65 : zero_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
66 : rhs_tmp[i]->setVal(0.0);
67 : zero_tmp[i]->setVal(0.0);
68 : Util::Message(INFO, rhs_tmp[i]->norm0());
69 : }
70 :
71 : linop->SetHomogeneous(false);
72 : mlmg->apply(rhs_tmp, zero_tmp);
73 :
74 : for (int lev = 0; lev < rhs_tmp.size(); lev++)
75 : {
76 : amrex::Box domain = linop->Geom(lev).Domain();
77 : domain.convert(amrex::IntVect::TheNodeVector());
78 : const Dim3 lo = amrex::lbound(domain), hi = amrex::ubound(domain);
79 : for (MFIter mfi(*rhs_tmp[lev], amrex::TilingIfNotGPU());mfi.isValid();++mfi)
80 : {
81 : amrex::Box bx = mfi.growntilebox(rhs_tmp[lev]->nGrow());
82 : bx = bx & domain;
83 : amrex::Array4<amrex::Real> const& rhstmp = rhs_tmp[lev]->array(mfi);
84 : for (int n = 0; n < rhs_tmp[lev]->nComp(); n++)
85 : {
86 : amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
87 : {
88 : bool AMREX_D_DECL(xmin = (i == lo.x), ymin = (j == lo.y), zmin = (k == lo.z)),
89 : AMREX_D_DECL(xmax = (i == hi.x), ymax = (j == hi.y), zmax = (k == hi.z));
90 : if (AMREX_D_TERM(xmax || xmin, || ymax || ymin, || zmax || zmin))
91 : rhstmp(i, j, k, n) = 0.0;
92 : else
93 : rhstmp(i, j, k, n) *= -1.0;
94 : });
95 : }
96 : }
97 : Util::Message(INFO, rhs_tmp[lev]->norm0());
98 : linop->realFillBoundary(*rhs_tmp[lev], linop->Geom(lev));
99 : Util::Message(INFO, rhs_tmp[lev]->norm0());
100 : //rhs_tmp[lev]->FillBoundary();
101 : }
102 :
103 : for (int lev = 0; lev < rhs_tmp.size(); lev++)
104 : {
105 : Util::Message(INFO, rhs_tmp[lev]->norm0());
106 : amrex::Add(*rhs_tmp[lev], *a_rhs[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
107 : if (copyrhs)
108 : amrex::Copy(*a_rhs[lev], *rhs_tmp[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
109 : Util::Message(INFO, rhs_tmp[lev]->norm0());
110 : }
111 :
112 : linop->SetHomogeneous(true);
113 : PrepareMLMG(*mlmg);
114 : Set::Scalar retval = NAN;
115 : try
116 : {
117 : retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp), a_tol_rel, a_tol_abs, checkpoint_file);
118 : }
119 : catch (const std::exception& e)
120 : {
121 : if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
122 : if (m_abort_on_fail) Util::Abort(INFO, e.what());
123 : }
124 : if (a_sol[0]->contains_nan())
125 : {
126 : dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
127 : Util::Abort(INFO);
128 : }
129 :
130 : return retval;
131 : };
132 :
133 1326 : Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
134 : amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
135 : Real a_tol_rel, Real a_tol_abs, const char* checkpoint_file = nullptr)
136 : {
137 1326 : PrepareMLMG(*mlmg);
138 1326 : Set::Scalar retval = NAN;
139 : try
140 : {
141 1326 : retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), a_tol_rel, a_tol_abs, checkpoint_file);
142 : }
143 0 : catch (const std::exception& e)
144 : {
145 0 : if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
146 0 : if (m_abort_on_fail) Util::Abort(INFO, e.what());
147 : }
148 1326 : return retval;
149 : };
150 : Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
151 : amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs)
152 : {
153 : PrepareMLMG(*mlmg);
154 : Set::Scalar retval = NAN;
155 : try
156 : {
157 : retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), tol_rel, tol_abs);
158 : }
159 : catch (const std::exception& e)
160 : {
161 : if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
162 : if (m_abort_on_fail) Util::Abort(INFO, e.what());
163 : }
164 : if (a_sol[0]->contains_nan())
165 : {
166 : dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
167 : Util::Abort(INFO);
168 : }
169 : return retval;
170 : };
171 : void apply(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
172 : amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol)
173 : {
174 : PrepareMLMG(*mlmg);
175 : mlmg->apply(GetVecOfPtrs(a_rhs), GetVecOfPtrs(a_sol));
176 : };
177 :
178 :
179 :
180 : void setMaxIter(const int a_max_iter) { max_iter = a_max_iter; }
181 : void setBottomMaxIter(const int a_bottom_max_iter) { bottom_max_iter = a_bottom_max_iter; }
182 : void setMaxFmgIter(const int a_max_fmg_iter) { max_fmg_iter = a_max_fmg_iter; }
183 : void setFixedIter(const int a_fixed_iter) { fixed_iter = a_fixed_iter; }
184 : void setVerbose(const int a_verbose) { verbose = a_verbose; }
185 : void setPreSmooth(const int a_pre_smooth) { pre_smooth = a_pre_smooth; }
186 : void setPostSmooth(const int a_post_smooth) { post_smooth = a_post_smooth; }
187 :
188 0 : void dumpOnConvergenceFail(const amrex::Vector<amrex::MultiFab*>& a_sol_mf,
189 : const amrex::Vector<amrex::MultiFab const*>& a_rhs_mf)
190 : {
191 0 : int nlevs = a_sol_mf.size();
192 0 : int ncomps = a_sol_mf[0]->nComp();
193 :
194 0 : amrex::Vector<amrex::Geometry> geom;
195 0 : amrex::Vector<int> iter;
196 0 : amrex::Vector<amrex::IntVect> refratio;
197 0 : amrex::Vector<std::string> names;
198 0 : for (int i = 0; i < nlevs; i++)
199 : {
200 0 : geom.push_back(linop->Geom(i));
201 0 : iter.push_back(0);
202 0 : if (i > 0) refratio.push_back(amrex::IntVect(2));
203 : }
204 0 : for (int n = 0; n < ncomps; n++)
205 : {
206 0 : names.push_back("var" + std::to_string(n));
207 : }
208 :
209 0 : std::string outputdir = Util::GetFileName();
210 0 : WriteMultiLevelPlotfile(outputdir + "/mlmg_sol", nlevs,
211 0 : amrex::GetVecOfConstPtrs(a_sol_mf),
212 : names, geom, 0, iter, refratio);
213 0 : WriteMultiLevelPlotfile(outputdir + "/mlmg_rhs", nlevs,
214 0 : amrex::GetVecOfConstPtrs(a_rhs_mf),
215 : names, geom, 0, iter, refratio);
216 :
217 0 : Set::Field<Set::Scalar> res_mf(nlevs);
218 0 : for (int lev = 0; lev < nlevs; lev++)
219 : {
220 0 : res_mf.Define(lev, a_sol_mf[lev]->boxArray(), a_sol_mf[lev]->DistributionMap(),
221 0 : ncomps, a_sol_mf[lev]->nGrow());
222 : }
223 :
224 0 : mlmg->compResidual(amrex::GetVecOfPtrs(res_mf), a_sol_mf, a_rhs_mf);
225 :
226 0 : WriteMultiLevelPlotfile(outputdir + "/mlmg_res", nlevs,
227 0 : amrex::GetVecOfConstPtrs(res_mf),
228 : names, geom, 0, iter, refratio);
229 :
230 0 : }
231 :
232 : //using MLMG::solve;
233 : protected:
234 : int max_iter = -1;
235 : int bottom_max_iter = -1;
236 : int max_fmg_iter = -1;
237 : int fixed_iter = -1;
238 : int verbose = -1;
239 : int pre_smooth = -1;
240 : int post_smooth = -1;
241 : int final_smooth = -1;
242 : int bottom_smooth = -1;
243 : std::string bottom_solver;
244 : Set::Scalar cg_tol_rel = -1.0;
245 : Set::Scalar cg_tol_abs = -1.0;
246 : Set::Scalar bottom_tol_rel = -1.0;
247 : Set::Scalar bottom_tol_abs = -1.0;
248 : Set::Scalar tol_rel = -1.0;
249 : Set::Scalar tol_abs = -1.0;
250 : Set::Scalar omega = -1.0;
251 : bool average_down_coeffs = false;
252 : bool normalize_ddw = false;
253 :
254 : Operator::Operator<Grid::Node>* linop;
255 : amrex::MLMG* mlmg;
256 :
257 1752 : void PrepareMLMG(amrex::MLMG& mlmg)
258 : {
259 1752 : if (!m_defined) Util::Message(INFO, "Solver not defined");
260 1752 : mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
261 1752 : mlmg.setCFStrategy(MLMG::CFStrategy::ghostnodes);
262 1752 : mlmg.setFinalFillBC(false);
263 1752 : mlmg.setMaxFmgIter(100000000);
264 :
265 :
266 1752 : if (max_iter >= 0) mlmg.setMaxIter(max_iter);
267 1752 : if (bottom_max_iter >= 0) mlmg.setBottomMaxIter(bottom_max_iter);
268 1752 : if (max_fmg_iter >= 0) mlmg.setMaxFmgIter(max_fmg_iter);
269 1752 : if (fixed_iter >= 0) mlmg.setFixedIter(fixed_iter);
270 1752 : if (verbose >= 0)
271 : {
272 1752 : mlmg.setVerbose(verbose - 1);
273 1752 : if (verbose > 4) mlmg.setBottomVerbose(verbose);
274 1752 : else mlmg.setBottomVerbose(0);
275 : }
276 :
277 1752 : if (pre_smooth >= 0) mlmg.setPreSmooth(pre_smooth);
278 1752 : if (post_smooth >= 0) mlmg.setPostSmooth(post_smooth);
279 1752 : if (final_smooth >= 0) mlmg.setFinalSmooth(final_smooth);
280 1752 : if (bottom_smooth >= 0) mlmg.setBottomSmooth(bottom_smooth);
281 :
282 1752 : if (bottom_solver == "cg") mlmg.setBottomSolver(MLMG::BottomSolver::cg);
283 1752 : else if (bottom_solver == "bicgstab") mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
284 1752 : else if (bottom_solver == "smoother") mlmg.setBottomSolver(MLMG::BottomSolver::smoother);
285 :
286 1752 : if (bottom_tol_rel >= 0) mlmg.setBottomTolerance(bottom_tol_rel);
287 1752 : if (bottom_tol_abs >= 0) mlmg.setBottomToleranceAbs(bottom_tol_abs);
288 :
289 1752 : if (omega >= 0) this->linop->SetOmega(omega);
290 1752 : if (average_down_coeffs) this->linop->SetAverageDownCoeffs(true);
291 1752 : if (normalize_ddw) this->linop->SetNormalizeDDW(true);
292 1752 : }
293 :
294 :
295 : public:
296 : // These are the parameters that are read in for a standard
297 : // multigrid linear solve.
298 19 : static void Parse(Linear& value, amrex::ParmParse& pp)
299 : {
300 : // Max number of iterations to perform before erroring out
301 19 : pp_query("max_iter", value.max_iter);
302 :
303 : // Max number of iterations on the bottom solver
304 19 : pp_query("bottom_max_iter", value.bottom_max_iter);
305 :
306 : // Max number of F-cycle iterations to perform
307 19 : pp_query("max_fmg_iter", value.max_fmg_iter);
308 :
309 : // DEPRICATED - do not use
310 19 : if (pp.contains("max_fixed_iter"))
311 0 : Util::Abort(INFO, "max_fixed_iter is depricated. Use fixed_iter instead.");
312 :
313 : // Number of fixed iterations to perform before exiting gracefully
314 19 : pp_query("fixed_iter", value.fixed_iter);
315 :
316 : // Verbosity of the solver (1-5)
317 19 : pp_query("verbose", value.verbose);
318 :
319 : // Number of smoothing operations before bottom solve (2)
320 19 : pp_query("pre_smooth", value.pre_smooth);
321 :
322 : // Number of smoothing operations after bottom solve (2)
323 19 : pp_query("post_smooth", value.post_smooth);
324 :
325 : // Number of final smoothing operations when smoother is used as bottom solver (8)
326 19 : pp_query("final_smooth", value.final_smooth);
327 :
328 : // Additional smoothing after bottom CG solver (0)
329 19 : pp_query("bottom_smooth", value.bottom_smooth);
330 :
331 : // The method that is used for the multigrid bottom solve (cg, bicgstab, smoother)
332 19 : pp_query("bottom_solver", value.bottom_solver);
333 :
334 19 : if (pp.contains("cg_tol_rel"))
335 0 : Util::Abort(INFO, "cg_tol_rel is depricated. Use bottom_tol_rel instead.");
336 19 : if (pp.contains("cg_tol_abs"))
337 0 : Util::Abort(INFO, "cg_tol_abs is depricated. Use bottom_tol_abs instead.");
338 :
339 : // Relative tolerance on bottom solver
340 19 : pp_query("bottom_tol_rel", value.bottom_tol_rel);
341 :
342 : // Absolute tolerance on bottom solver
343 19 : pp_query("bottom_tol_abs", value.bottom_tol_abs);
344 :
345 : // Relative tolerance
346 19 : pp_query("tol_rel", value.tol_rel);
347 :
348 : // Absolute tolerance
349 19 : pp_query("tol_abs", value.tol_abs);
350 :
351 : // Omega (used in gauss-seidel solver)
352 19 : pp_query("omega", value.omega);
353 :
354 : // Whether to average down coefficients or use the ones given.
355 : // (Setting this to true is important for fracture.)
356 19 : pp_query("average_down_coeffs", value.average_down_coeffs);
357 :
358 : // Whether to normalize DDW when calculating the diagonal.
359 : // This is primarily used when DDW is near-singular - like when there
360 : // is a "void" region or when doing phase field fracture.
361 19 : pp_query("normalize_ddw", value.normalize_ddw);
362 :
363 : // [false]
364 : // If set to true, output diagnostic multifab information
365 : // whenever the MLMG solver fails to converge.
366 : // (Note: you must also set :code:`amrex.signalhandling=0`
367 : // and :code:`amrex.throw_exception=1` for this to work.)
368 19 : pp_query("dump_on_fail", value.m_dump_on_fail);
369 :
370 : // [true]
371 : // If set to false, MLMG will not die if convergence criterion
372 : // is not reached.
373 : // (Note: you must also set :code:`amrex.signalhandling=0`
374 : // and :code:`amrex.throw_exception=1` for this to work.)
375 19 : pp_query("abort_on_fail", value.m_abort_on_fail);
376 :
377 19 : if (value.m_dump_on_fail || !value.m_abort_on_fail)
378 : {
379 6 : IO::ParmParse pp("amrex");
380 2 : pp.add("signal_handling", 0);
381 2 : pp.add("throw_exception", 1);
382 : }
383 :
384 19 : }
385 : protected:
386 : bool m_defined = false;
387 :
388 : bool m_dump_on_fail = false;
389 : bool m_abort_on_fail = true;
390 :
391 : };
392 : }
393 : }
394 : #endif
|