LCOV - code coverage report
Current view: top level - src/Solver/Nonlocal - Linear.H (source / functions) Coverage Total Hit
Test: coverage_merged.info Lines: 71.1 % 114 81
Test Date: 2026-02-24 04:46:08 Functions: 87.5 % 8 7

            Line data    Source code
       1              : #ifndef SOLVER_NONLOCAL_LINEAR
       2              : #define SOLVER_NONLOCAL_LINEAR
       3              : #include "Operator/Operator.H"
       4              : #include <AMReX_MLMG.H>
       5              : 
       6              : namespace Solver
       7              : {
       8              : namespace Nonlocal
       9              : {
      10              : /// \brief Multigrid Linear solver for multicomponent, multi-level operators
      11              : /// 
      12              : /// This class is a thin wrapper for the `amrex::MLMG` solver.
      13              : /// It exists to set a range of default MLMG settings automatically, for instance,
      14              : /// `setCFStrategy`, which may not be obvious to the user.
      15              : ///
      16              : /// It also exists as a compatibility layer so that future fixes for compatibility
      17              : /// with AMReX can be implemented here.
      18              : class Linear  // : public amrex::MLMG
      19              : {
      20              : public:
      21              : 
      22           36 :     Linear() //Operator::Operator<Grid::Node>& a_lp) : MLMG(a_lp), linop(a_lp)
      23           36 :     {
      24           36 :     }
      25              : 
      26              :     Linear(Operator::Operator<Grid::Node>& a_lp)
      27              :     {
      28              :         this->Define(a_lp);
      29              :     }
      30              : 
      31           36 :     ~Linear()
      32              :     {
      33           36 :         if (m_defined) Clear();
      34           36 :     }
      35              : 
      36          734 :     void Define(Operator::Operator<Grid::Node>& a_lp)
      37              :     {
      38          734 :         if (m_defined) Util::Abort(INFO, "Solver cannot be re-defined");
      39          734 :         this->linop = &a_lp;
      40          734 :         this->mlmg = new amrex::MLMG(a_lp);
      41          734 :         m_defined = true;
      42          734 :         PrepareMLMG(*mlmg);
      43          734 :     }
      44          734 :     void Clear()
      45              :     {
      46          734 :         if (!m_defined) Util::Abort(INFO, "Solver cannot be cleared if not defined");
      47          734 :         this->linop = nullptr;
      48          734 :         if (this->mlmg) delete this->mlmg;
      49          734 :         this->mlmg = nullptr;
      50          734 :         m_defined = false;
      51          734 :     }
      52              : 
      53              :     Set::Scalar solveaffine(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
      54              :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
      55              :         Real a_tol_rel, Real a_tol_abs, bool copyrhs = false,
      56              :         const char* checkpoint_file = nullptr)
      57              :     {
      58              :         if (!m_defined) Util::Abort(INFO, "Solver not defined");
      59              :         amrex::Vector<amrex::MultiFab*> rhs_tmp(a_rhs.size());
      60              :         amrex::Vector<amrex::MultiFab*> zero_tmp(a_rhs.size());
      61              :         for (int i = 0; i < rhs_tmp.size(); i++)
      62              :         {
      63              :             rhs_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
      64              :             zero_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
      65              :             rhs_tmp[i]->setVal(0.0);
      66              :             zero_tmp[i]->setVal(0.0);
      67              :             Util::Message(INFO, rhs_tmp[i]->norm0());
      68              :         }
      69              : 
      70              :         linop->SetHomogeneous(false);
      71              :         mlmg->apply(rhs_tmp, zero_tmp);
      72              : 
      73              :         for (int lev = 0; lev < rhs_tmp.size(); lev++)
      74              :         {
      75              :             amrex::Box domain = linop->Geom(lev).Domain();
      76              :             domain.convert(amrex::IntVect::TheNodeVector());
      77              :             const Dim3 lo = amrex::lbound(domain), hi = amrex::ubound(domain);
      78              :             for (MFIter mfi(*rhs_tmp[lev], amrex::TilingIfNotGPU());mfi.isValid();++mfi)
      79              :             {
      80              :                 amrex::Box bx = mfi.growntilebox(rhs_tmp[lev]->nGrow());
      81              :                 bx = bx & domain;
      82              :                 amrex::Array4<amrex::Real> const& rhstmp = rhs_tmp[lev]->array(mfi);
      83              :                 for (int n = 0; n < rhs_tmp[lev]->nComp(); n++)
      84              :                 {
      85              :                     amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
      86              :                     {
      87              :                         bool    AMREX_D_DECL(xmin = (i == lo.x), ymin = (j == lo.y), zmin = (k == lo.z)),
      88              :                             AMREX_D_DECL(xmax = (i == hi.x), ymax = (j == hi.y), zmax = (k == hi.z));
      89              :                         if (AMREX_D_TERM(xmax || xmin, || ymax || ymin, || zmax || zmin))
      90              :                             rhstmp(i, j, k, n) = 0.0;
      91              :                         else
      92              :                             rhstmp(i, j, k, n) *= -1.0;
      93              :                     });
      94              :                 }
      95              :             }
      96              :             Util::Message(INFO, rhs_tmp[lev]->norm0());
      97              :             rhs_tmp[lev]->setMultiGhost(true);
      98              :             rhs_tmp[lev]->FillBoundaryAndSync(linop->Geom(lev).periodicity());
      99              :         }
     100              : 
     101              :         for (int lev = 0; lev < rhs_tmp.size(); lev++)
     102              :         {
     103              :             Util::Message(INFO, rhs_tmp[lev]->norm0());
     104              :             amrex::Add(*rhs_tmp[lev], *a_rhs[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
     105              :             if (copyrhs)
     106              :                 amrex::Copy(*a_rhs[lev], *rhs_tmp[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
     107              :             Util::Message(INFO, rhs_tmp[lev]->norm0());
     108              :         }
     109              : 
     110              :         linop->SetHomogeneous(true);
     111              :         PrepareMLMG(*mlmg);
     112              :         Set::Scalar retval = NAN;
     113              :         try
     114              :         {
     115              :             retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp), a_tol_rel, a_tol_abs, checkpoint_file);
     116              :         }
     117              :         catch (const std::exception& e)
     118              :         {
     119              :             if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
     120              :             if (m_abort_on_fail) Util::Abort(INFO, e.what());
     121              :         }
     122              :         if (a_sol[0]->contains_nan()) 
     123              :         {
     124              :             dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
     125              :             Util::Abort(INFO);
     126              :         }
     127              : 
     128              :         return retval;
     129              :     };
     130              : 
     131         1634 :     Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
     132              :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
     133              :         Real a_tol_rel, Real a_tol_abs, const char* checkpoint_file = nullptr)
     134              :     {
     135         1634 :         PrepareMLMG(*mlmg);
     136         1634 :         Set::Scalar retval = NAN;
     137              :         try
     138              :         {
     139         1634 :             retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), a_tol_rel, a_tol_abs, checkpoint_file);
     140              :         }
     141            0 :         catch (const std::exception& e)
     142              :         {
     143            0 :             if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
     144            0 :             if (m_abort_on_fail) Util::Abort(INFO, e.what());
     145            0 :         }
     146         1634 :         return retval;
     147              :     };
     148              :     Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
     149              :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs)
     150              :     {
     151              :         PrepareMLMG(*mlmg);
     152              :         Set::Scalar retval = NAN;
     153              :         try
     154              :         {
     155              :             retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), tol_rel, tol_abs);
     156              :         }
     157              :         catch (const std::exception& e)
     158              :         {
     159              :             if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
     160              :             if (m_abort_on_fail) Util::Abort(INFO, e.what());
     161              :         }
     162              :         if (a_sol[0]->contains_nan()) 
     163              :         {
     164              :             dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
     165              :             Util::Abort(INFO);
     166              :         }
     167              :         return retval;
     168              :     };
     169              :     void apply(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
     170              :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol)
     171              :     {
     172              :         PrepareMLMG(*mlmg);
     173              :         mlmg->apply(GetVecOfPtrs(a_rhs), GetVecOfPtrs(a_sol));
     174              :     };
     175              : 
     176              : 
     177              : 
     178              :     void setMaxIter(const int a_max_iter) { max_iter = a_max_iter; }
     179              :     void setBottomMaxIter(const int a_bottom_max_iter) { bottom_max_iter = a_bottom_max_iter; }
     180              :     void setMaxFmgIter(const int a_max_fmg_iter) { max_fmg_iter = a_max_fmg_iter; }
     181              :     void setFixedIter(const int a_fixed_iter) { fixed_iter = a_fixed_iter; }
     182              :     void setVerbose(const int a_verbose) { verbose = a_verbose; }
     183              :     void setPreSmooth(const int a_pre_smooth) { pre_smooth = a_pre_smooth; }
     184              :     void setPostSmooth(const int a_post_smooth) { post_smooth = a_post_smooth; }
     185              : 
     186            0 :     void dumpOnConvergenceFail(const amrex::Vector<amrex::MultiFab*>& a_sol_mf,
     187              :         const amrex::Vector<amrex::MultiFab const*>& a_rhs_mf)
     188              :     {
     189            0 :         int nlevs = a_sol_mf.size();
     190            0 :         int ncomps = a_sol_mf[0]->nComp();
     191              : 
     192            0 :         amrex::Vector<amrex::Geometry> geom;
     193            0 :         amrex::Vector<int> iter;
     194            0 :         amrex::Vector<amrex::IntVect> refratio;
     195            0 :         amrex::Vector<std::string> names;
     196            0 :         for (int i = 0; i < nlevs; i++)
     197              :         {
     198            0 :             geom.push_back(linop->Geom(i));
     199            0 :             iter.push_back(0);
     200            0 :             if (i > 0) refratio.push_back(amrex::IntVect(2));
     201              :         }
     202            0 :         for (int n = 0; n < ncomps; n++)
     203              :         {
     204            0 :             names.push_back("var" + std::to_string(n));
     205              :         }
     206              : 
     207            0 :         std::string outputdir = Util::GetFileName();
     208            0 :         WriteMultiLevelPlotfile(outputdir + "/mlmg_sol", nlevs,
     209            0 :             amrex::GetVecOfConstPtrs(a_sol_mf),
     210              :             names, geom, 0, iter, refratio);
     211            0 :         WriteMultiLevelPlotfile(outputdir + "/mlmg_rhs", nlevs,
     212            0 :             amrex::GetVecOfConstPtrs(a_rhs_mf),
     213              :             names, geom, 0, iter, refratio);
     214              : 
     215            0 :         Set::Field<Set::Scalar> res_mf(nlevs);
     216            0 :         for (int lev = 0; lev < nlevs; lev++)
     217              :         {
     218            0 :             res_mf.Define(lev, a_sol_mf[lev]->boxArray(), a_sol_mf[lev]->DistributionMap(),
     219            0 :                 ncomps, a_sol_mf[lev]->nGrow());
     220              :         }
     221              : 
     222            0 :         mlmg->compResidual(amrex::GetVecOfPtrs(res_mf), a_sol_mf, a_rhs_mf);
     223              : 
     224            0 :         WriteMultiLevelPlotfile(outputdir + "/mlmg_res", nlevs,
     225            0 :             amrex::GetVecOfConstPtrs(res_mf),
     226              :             names, geom, 0, iter, refratio);
     227              : 
     228            0 :     }
     229              : 
     230              :     //using MLMG::solve;
     231              : protected:
     232              :     int max_iter = -1;
     233              :     int bottom_max_iter = -1;
     234              :     int max_fmg_iter = -1;
     235              :     int fixed_iter = -1;
     236              :     int verbose = -1;
     237              :     int pre_smooth = -1;
     238              :     int post_smooth = -1;
     239              :     int final_smooth = -1;
     240              :     int bottom_smooth = -1;
     241              :     std::string bottom_solver;
     242              :     Set::Scalar cg_tol_rel = -1.0;
     243              :     Set::Scalar cg_tol_abs = -1.0;
     244              :     Set::Scalar bottom_tol_rel = -1.0;
     245              :     Set::Scalar bottom_tol_abs = -1.0;
     246              :     Set::Scalar tol_rel = -1.0;
     247              :     Set::Scalar tol_abs = -1.0;
     248              :     Set::Scalar omega = -1.0;
     249              :     bool average_down_coeffs = false;
     250              :     bool normalize_ddw = false;
     251              : 
     252              :     Operator::Operator<Grid::Node>* linop;
     253              :     amrex::MLMG* mlmg;
     254              : 
     255         2368 :     void PrepareMLMG(amrex::MLMG& mlmg)
     256              :     {
     257         2368 :         if (!m_defined) Util::Message(INFO, "Solver not defined");
     258         2368 :         mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
     259         2368 :         mlmg.setCFStrategy(MLMG::CFStrategy::ghostnodes);
     260         2368 :         mlmg.setFinalFillBC(false);
     261         2368 :         mlmg.setMaxFmgIter(100000000);
     262              : 
     263              : 
     264         2368 :         if (max_iter >= 0)        mlmg.setMaxIter(max_iter);
     265         2368 :         if (bottom_max_iter >= 0) mlmg.setBottomMaxIter(bottom_max_iter);
     266         2368 :         if (max_fmg_iter >= 0)    mlmg.setMaxFmgIter(max_fmg_iter);
     267         2368 :         if (fixed_iter >= 0)      mlmg.setFixedIter(fixed_iter);
     268         2368 :         if (verbose >= 0)
     269              :         {
     270         2368 :             mlmg.setVerbose(verbose - 1);
     271         2368 :             if (verbose > 4)      mlmg.setBottomVerbose(verbose);
     272         2368 :             else                  mlmg.setBottomVerbose(0);
     273              :         }
     274              : 
     275         2368 :         if (pre_smooth >= 0)      mlmg.setPreSmooth(pre_smooth);
     276         2368 :         if (post_smooth >= 0)     mlmg.setPostSmooth(post_smooth);
     277         2368 :         if (final_smooth >= 0)    mlmg.setFinalSmooth(final_smooth);
     278         2368 :         if (bottom_smooth >= 0)   mlmg.setBottomSmooth(bottom_smooth);
     279              : 
     280         2368 :         if (bottom_solver == "cg")       mlmg.setBottomSolver(MLMG::BottomSolver::cg);
     281         2368 :         else if (bottom_solver == "bicgstab") mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
     282         2368 :         else if (bottom_solver == "smoother") mlmg.setBottomSolver(MLMG::BottomSolver::smoother);
     283              : 
     284         2368 :         if (bottom_tol_rel >= 0) mlmg.setBottomTolerance(bottom_tol_rel);
     285         2368 :         if (bottom_tol_abs >= 0) mlmg.setBottomToleranceAbs(bottom_tol_abs);
     286              : 
     287         2368 :         if (omega >= 0) this->linop->SetOmega(omega);
     288         2368 :         if (average_down_coeffs) this->linop->SetAverageDownCoeffs(true);
     289         2368 :         if (normalize_ddw) this->linop->SetNormalizeDDW(true);
     290         2368 :     }
     291              : 
     292              : 
     293              : public:
     294              :     // These are the parameters that are read in for a standard 
     295              :     // multigrid linear solve.
     296           30 :     static void Parse(Linear& value, amrex::ParmParse& pp)
     297              :     {
     298              :         // Max number of iterations to perform before erroring out
     299           30 :         pp_query("max_iter", value.max_iter);
     300              : 
     301              :         // Max number of iterations on the bottom solver
     302           30 :         pp_query("bottom_max_iter", value.bottom_max_iter);
     303              : 
     304              :         // Max number of F-cycle iterations to perform
     305           30 :         pp_query("max_fmg_iter", value.max_fmg_iter);
     306              : 
     307              :         // DEPRICATED - do not use
     308           30 :         if (pp.contains("max_fixed_iter"))
     309            0 :             Util::Abort(INFO, "max_fixed_iter is depricated. Use fixed_iter instead.");
     310              : 
     311              :         // Number of fixed iterations to perform before exiting gracefully
     312           30 :         pp_query("fixed_iter", value.fixed_iter);
     313              : 
     314              :         // Verbosity of the solver (1-5)
     315           30 :         pp_query("verbose", value.verbose);
     316              : 
     317              :         // Number of smoothing operations before bottom solve (2)
     318           30 :         pp_query("pre_smooth", value.pre_smooth);
     319              : 
     320              :         // Number of smoothing operations after bottom solve (2)
     321           30 :         pp_query("post_smooth", value.post_smooth);
     322              : 
     323              :         // Number of final smoothing operations when smoother is used as bottom solver (8)
     324           30 :         pp_query("final_smooth", value.final_smooth);
     325              : 
     326              :         // Additional smoothing after bottom CG solver (0)
     327           30 :         pp_query("bottom_smooth", value.bottom_smooth);
     328              : 
     329              :         // The method that is used for the multigrid bottom solve (cg, bicgstab, smoother)
     330           30 :         pp_query("bottom_solver", value.bottom_solver);
     331              : 
     332           30 :         if (pp.contains("cg_tol_rel"))
     333            0 :             Util::Abort(INFO, "cg_tol_rel is depricated. Use bottom_tol_rel instead.");
     334           30 :         if (pp.contains("cg_tol_abs"))
     335            0 :             Util::Abort(INFO, "cg_tol_abs is depricated. Use bottom_tol_abs instead.");
     336              : 
     337              :         // Relative tolerance on bottom solver
     338           30 :         pp_query("bottom_tol_rel", value.bottom_tol_rel);
     339              : 
     340              :         // Absolute tolerance on bottom solver
     341           30 :         pp_query("bottom_tol_abs", value.bottom_tol_abs);
     342              : 
     343              :         // Relative tolerance
     344           30 :         pp_query("tol_rel", value.tol_rel);
     345              : 
     346              :         // Absolute tolerance
     347           30 :         pp_query("tol_abs", value.tol_abs);
     348              : 
     349              :         // Omega (used in gauss-seidel solver)
     350           30 :         pp_query("omega", value.omega);
     351              : 
     352              :         // Whether to average down coefficients or use the ones given.
     353              :         // (Setting this to true is important for fracture.)
     354           30 :         pp_query("average_down_coeffs", value.average_down_coeffs);
     355              : 
     356              :         // Whether to normalize DDW when calculating the diagonal.
     357              :         // This is primarily used when DDW is near-singular - like when there
     358              :         // is a "void" region or when doing phase field fracture.
     359           30 :         pp_query("normalize_ddw", value.normalize_ddw);
     360              : 
     361              :         // [false] 
     362              :         // If set to true, output diagnostic multifab information 
     363              :         // whenever the MLMG solver fails to converge.
     364              :         // (Note: you must also set :code:`amrex.signalhandling=0`
     365              :         // and :code:`amrex.throw_exception=1` for this to work.)
     366           30 :         pp_query("dump_on_fail", value.m_dump_on_fail);
     367              : 
     368              :         // [true]
     369              :         // If set to false, MLMG will not die if convergence criterion
     370              :         // is not reached.
     371              :         // (Note: you must also set :code:`amrex.signalhandling=0`
     372              :         // and :code:`amrex.throw_exception=1` for this to work.)
     373           30 :         pp_query("abort_on_fail", value.m_abort_on_fail);
     374              : 
     375           30 :         if (value.m_dump_on_fail || !value.m_abort_on_fail)
     376              :         {
     377            2 :             IO::ParmParse pp("amrex");
     378            2 :             pp.add("signal_handling", 0);
     379            2 :             pp.add("throw_exception", 1);
     380            2 :         }
     381              : 
     382           30 :     }
     383              : protected:
     384              :     bool m_defined = false;
     385              : 
     386              :     bool m_dump_on_fail = false;
     387              :     bool m_abort_on_fail = true;
     388              : 
     389              : };
     390              : }
     391              : }
     392              : #endif
        

Generated by: LCOV version 2.0-1