LCOV - code coverage report
Current view: top level - src/Solver/Nonlocal - Linear.H (source / functions) Coverage Total Hit
Test: coverage_merged.info Lines: 71.1 % 114 81
Test Date: 2025-04-03 04:02:21 Functions: 87.5 % 8 7

            Line data    Source code
       1              : #ifndef SOLVER_NONLOCAL_LINEAR
       2              : #define SOLVER_NONLOCAL_LINEAR
       3              : #include "Operator/Operator.H"
       4              : #include <AMReX_MLMG.H>
       5              : 
       6              : namespace Solver
       7              : {
       8              : namespace Nonlocal
       9              : {
      10              : /// \brief Multigrid Linear solver for multicomponent, multi-level operators
      11              : /// 
      12              : /// This class is a thin wrapper for the `amrex::MLMG` solver.
      13              : /// It exists to set a range of default MLMG settings automatically, for instance,
      14              : /// `setCFStrategy`, which may not be obvious to the user.
      15              : ///
      16              : /// It also exists as a compatibility layer so that future fixes for compatibility
      17              : /// with AMReX can be implemented here.
      18              : class Linear  // : public amrex::MLMG
      19              : {
      20              : public:
      21              : 
      22           25 :     Linear() //Operator::Operator<Grid::Node>& a_lp) : MLMG(a_lp), linop(a_lp)
      23           25 :     {
      24           25 :     }
      25              : 
      26              :     Linear(Operator::Operator<Grid::Node>& a_lp)
      27              :     {
      28              :         this->Define(a_lp);
      29              :     }
      30              : 
      31           25 :     ~Linear()
      32              :     {
      33           25 :         if (m_defined) Clear();
      34           25 :     }
      35              : 
      36          426 :     void Define(Operator::Operator<Grid::Node>& a_lp)
      37              :     {
      38          426 :         if (m_defined) Util::Abort(INFO, "Solver cannot be re-defined");
      39          426 :         this->linop = &a_lp;
      40          426 :         this->mlmg = new amrex::MLMG(a_lp);
      41          426 :         m_defined = true;
      42          426 :         PrepareMLMG(*mlmg);
      43          426 :     }
      44          426 :     void Clear()
      45              :     {
      46          426 :         if (!m_defined) Util::Abort(INFO, "Solver cannot be cleared if not defined");
      47          426 :         this->linop = nullptr;
      48          426 :         if (this->mlmg) delete this->mlmg;
      49          426 :         this->mlmg = nullptr;
      50          426 :         m_defined = false;
      51          426 :     }
      52              : 
      53              :     Set::Scalar solveaffine(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
      54              :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
      55              :         Real a_tol_rel, Real a_tol_abs, bool copyrhs = false,
      56              :         const char* checkpoint_file = nullptr)
      57              :     {
      58              :         if (!m_defined) Util::Abort(INFO, "Solver not defined");
      59              :         amrex::Vector<amrex::MultiFab*> rhs_tmp(a_rhs.size());
      60              :         amrex::Vector<amrex::MultiFab*> zero_tmp(a_rhs.size());
      61              :         for (int i = 0; i < rhs_tmp.size(); i++)
      62              :         {
      63              :             rhs_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
      64              :             zero_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
      65              :             rhs_tmp[i]->setVal(0.0);
      66              :             zero_tmp[i]->setVal(0.0);
      67              :             Util::Message(INFO, rhs_tmp[i]->norm0());
      68              :         }
      69              : 
      70              :         linop->SetHomogeneous(false);
      71              :         mlmg->apply(rhs_tmp, zero_tmp);
      72              : 
      73              :         for (int lev = 0; lev < rhs_tmp.size(); lev++)
      74              :         {
      75              :             amrex::Box domain = linop->Geom(lev).Domain();
      76              :             domain.convert(amrex::IntVect::TheNodeVector());
      77              :             const Dim3 lo = amrex::lbound(domain), hi = amrex::ubound(domain);
      78              :             for (MFIter mfi(*rhs_tmp[lev], amrex::TilingIfNotGPU());mfi.isValid();++mfi)
      79              :             {
      80              :                 amrex::Box bx = mfi.growntilebox(rhs_tmp[lev]->nGrow());
      81              :                 bx = bx & domain;
      82              :                 amrex::Array4<amrex::Real> const& rhstmp = rhs_tmp[lev]->array(mfi);
      83              :                 for (int n = 0; n < rhs_tmp[lev]->nComp(); n++)
      84              :                 {
      85              :                     amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
      86              :                     {
      87              :                         bool    AMREX_D_DECL(xmin = (i == lo.x), ymin = (j == lo.y), zmin = (k == lo.z)),
      88              :                             AMREX_D_DECL(xmax = (i == hi.x), ymax = (j == hi.y), zmax = (k == hi.z));
      89              :                         if (AMREX_D_TERM(xmax || xmin, || ymax || ymin, || zmax || zmin))
      90              :                             rhstmp(i, j, k, n) = 0.0;
      91              :                         else
      92              :                             rhstmp(i, j, k, n) *= -1.0;
      93              :                     });
      94              :                 }
      95              :             }
      96              :             Util::Message(INFO, rhs_tmp[lev]->norm0());
      97              :             linop->realFillBoundary(*rhs_tmp[lev], linop->Geom(lev));
      98              :             Util::Message(INFO, rhs_tmp[lev]->norm0());
      99              :             //rhs_tmp[lev]->FillBoundary();
     100              :         }
     101              : 
     102              :         for (int lev = 0; lev < rhs_tmp.size(); lev++)
     103              :         {
     104              :             Util::Message(INFO, rhs_tmp[lev]->norm0());
     105              :             amrex::Add(*rhs_tmp[lev], *a_rhs[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
     106              :             if (copyrhs)
     107              :                 amrex::Copy(*a_rhs[lev], *rhs_tmp[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
     108              :             Util::Message(INFO, rhs_tmp[lev]->norm0());
     109              :         }
     110              : 
     111              :         linop->SetHomogeneous(true);
     112              :         PrepareMLMG(*mlmg);
     113              :         Set::Scalar retval = NAN;
     114              :         try
     115              :         {
     116              :             retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp), a_tol_rel, a_tol_abs, checkpoint_file);
     117              :         }
     118              :         catch (const std::exception& e)
     119              :         {
     120              :             if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
     121              :             if (m_abort_on_fail) Util::Abort(INFO, e.what());
     122              :         }
     123              :         if (a_sol[0]->contains_nan()) 
     124              :         {
     125              :             dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
     126              :             Util::Abort(INFO);
     127              :         }
     128              : 
     129              :         return retval;
     130              :     };
     131              : 
     132         1326 :     Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
     133              :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
     134              :         Real a_tol_rel, Real a_tol_abs, const char* checkpoint_file = nullptr)
     135              :     {
     136         1326 :         PrepareMLMG(*mlmg);
     137         1326 :         Set::Scalar retval = NAN;
     138              :         try
     139              :         {
     140         1326 :             retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), a_tol_rel, a_tol_abs, checkpoint_file);
     141              :         }
     142            0 :         catch (const std::exception& e)
     143              :         {
     144            0 :             if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
     145            0 :             if (m_abort_on_fail) Util::Abort(INFO, e.what());
     146            0 :         }
     147         1326 :         return retval;
     148              :     };
     149              :     Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
     150              :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs)
     151              :     {
     152              :         PrepareMLMG(*mlmg);
     153              :         Set::Scalar retval = NAN;
     154              :         try
     155              :         {
     156              :             retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), tol_rel, tol_abs);
     157              :         }
     158              :         catch (const std::exception& e)
     159              :         {
     160              :             if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
     161              :             if (m_abort_on_fail) Util::Abort(INFO, e.what());
     162              :         }
     163              :         if (a_sol[0]->contains_nan()) 
     164              :         {
     165              :             dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
     166              :             Util::Abort(INFO);
     167              :         }
     168              :         return retval;
     169              :     };
     170              :     void apply(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
     171              :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol)
     172              :     {
     173              :         PrepareMLMG(*mlmg);
     174              :         mlmg->apply(GetVecOfPtrs(a_rhs), GetVecOfPtrs(a_sol));
     175              :     };
     176              : 
     177              : 
     178              : 
     179              :     void setMaxIter(const int a_max_iter) { max_iter = a_max_iter; }
     180              :     void setBottomMaxIter(const int a_bottom_max_iter) { bottom_max_iter = a_bottom_max_iter; }
     181              :     void setMaxFmgIter(const int a_max_fmg_iter) { max_fmg_iter = a_max_fmg_iter; }
     182              :     void setFixedIter(const int a_fixed_iter) { fixed_iter = a_fixed_iter; }
     183              :     void setVerbose(const int a_verbose) { verbose = a_verbose; }
     184              :     void setPreSmooth(const int a_pre_smooth) { pre_smooth = a_pre_smooth; }
     185              :     void setPostSmooth(const int a_post_smooth) { post_smooth = a_post_smooth; }
     186              : 
     187            0 :     void dumpOnConvergenceFail(const amrex::Vector<amrex::MultiFab*>& a_sol_mf,
     188              :         const amrex::Vector<amrex::MultiFab const*>& a_rhs_mf)
     189              :     {
     190            0 :         int nlevs = a_sol_mf.size();
     191            0 :         int ncomps = a_sol_mf[0]->nComp();
     192              : 
     193            0 :         amrex::Vector<amrex::Geometry> geom;
     194            0 :         amrex::Vector<int> iter;
     195            0 :         amrex::Vector<amrex::IntVect> refratio;
     196            0 :         amrex::Vector<std::string> names;
     197            0 :         for (int i = 0; i < nlevs; i++)
     198              :         {
     199            0 :             geom.push_back(linop->Geom(i));
     200            0 :             iter.push_back(0);
     201            0 :             if (i > 0) refratio.push_back(amrex::IntVect(2));
     202              :         }
     203            0 :         for (int n = 0; n < ncomps; n++)
     204              :         {
     205            0 :             names.push_back("var" + std::to_string(n));
     206              :         }
     207              : 
     208            0 :         std::string outputdir = Util::GetFileName();
     209            0 :         WriteMultiLevelPlotfile(outputdir + "/mlmg_sol", nlevs,
     210            0 :             amrex::GetVecOfConstPtrs(a_sol_mf),
     211              :             names, geom, 0, iter, refratio);
     212            0 :         WriteMultiLevelPlotfile(outputdir + "/mlmg_rhs", nlevs,
     213            0 :             amrex::GetVecOfConstPtrs(a_rhs_mf),
     214              :             names, geom, 0, iter, refratio);
     215              : 
     216            0 :         Set::Field<Set::Scalar> res_mf(nlevs);
     217            0 :         for (int lev = 0; lev < nlevs; lev++)
     218              :         {
     219            0 :             res_mf.Define(lev, a_sol_mf[lev]->boxArray(), a_sol_mf[lev]->DistributionMap(),
     220            0 :                 ncomps, a_sol_mf[lev]->nGrow());
     221              :         }
     222              : 
     223            0 :         mlmg->compResidual(amrex::GetVecOfPtrs(res_mf), a_sol_mf, a_rhs_mf);
     224              : 
     225            0 :         WriteMultiLevelPlotfile(outputdir + "/mlmg_res", nlevs,
     226            0 :             amrex::GetVecOfConstPtrs(res_mf),
     227              :             names, geom, 0, iter, refratio);
     228              : 
     229            0 :     }
     230              : 
     231              :     //using MLMG::solve;
     232              : protected:
     233              :     int max_iter = -1;
     234              :     int bottom_max_iter = -1;
     235              :     int max_fmg_iter = -1;
     236              :     int fixed_iter = -1;
     237              :     int verbose = -1;
     238              :     int pre_smooth = -1;
     239              :     int post_smooth = -1;
     240              :     int final_smooth = -1;
     241              :     int bottom_smooth = -1;
     242              :     std::string bottom_solver;
     243              :     Set::Scalar cg_tol_rel = -1.0;
     244              :     Set::Scalar cg_tol_abs = -1.0;
     245              :     Set::Scalar bottom_tol_rel = -1.0;
     246              :     Set::Scalar bottom_tol_abs = -1.0;
     247              :     Set::Scalar tol_rel = -1.0;
     248              :     Set::Scalar tol_abs = -1.0;
     249              :     Set::Scalar omega = -1.0;
     250              :     bool average_down_coeffs = false;
     251              :     bool normalize_ddw = false;
     252              : 
     253              :     Operator::Operator<Grid::Node>* linop;
     254              :     amrex::MLMG* mlmg;
     255              : 
     256         1752 :     void PrepareMLMG(amrex::MLMG& mlmg)
     257              :     {
     258         1752 :         if (!m_defined) Util::Message(INFO, "Solver not defined");
     259         1752 :         mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
     260         1752 :         mlmg.setCFStrategy(MLMG::CFStrategy::ghostnodes);
     261         1752 :         mlmg.setFinalFillBC(false);
     262         1752 :         mlmg.setMaxFmgIter(100000000);
     263              : 
     264              : 
     265         1752 :         if (max_iter >= 0)        mlmg.setMaxIter(max_iter);
     266         1752 :         if (bottom_max_iter >= 0) mlmg.setBottomMaxIter(bottom_max_iter);
     267         1752 :         if (max_fmg_iter >= 0)    mlmg.setMaxFmgIter(max_fmg_iter);
     268         1752 :         if (fixed_iter >= 0)      mlmg.setFixedIter(fixed_iter);
     269         1752 :         if (verbose >= 0)
     270              :         {
     271         1752 :             mlmg.setVerbose(verbose - 1);
     272         1752 :             if (verbose > 4)      mlmg.setBottomVerbose(verbose);
     273         1752 :             else                  mlmg.setBottomVerbose(0);
     274              :         }
     275              : 
     276         1752 :         if (pre_smooth >= 0)      mlmg.setPreSmooth(pre_smooth);
     277         1752 :         if (post_smooth >= 0)     mlmg.setPostSmooth(post_smooth);
     278         1752 :         if (final_smooth >= 0)    mlmg.setFinalSmooth(final_smooth);
     279         1752 :         if (bottom_smooth >= 0)   mlmg.setBottomSmooth(bottom_smooth);
     280              : 
     281         1752 :         if (bottom_solver == "cg")       mlmg.setBottomSolver(MLMG::BottomSolver::cg);
     282         1752 :         else if (bottom_solver == "bicgstab") mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
     283         1752 :         else if (bottom_solver == "smoother") mlmg.setBottomSolver(MLMG::BottomSolver::smoother);
     284              : 
     285         1752 :         if (bottom_tol_rel >= 0) mlmg.setBottomTolerance(bottom_tol_rel);
     286         1752 :         if (bottom_tol_abs >= 0) mlmg.setBottomToleranceAbs(bottom_tol_abs);
     287              : 
     288         1752 :         if (omega >= 0) this->linop->SetOmega(omega);
     289         1752 :         if (average_down_coeffs) this->linop->SetAverageDownCoeffs(true);
     290         1752 :         if (normalize_ddw) this->linop->SetNormalizeDDW(true);
     291         1752 :     }
     292              : 
     293              : 
     294              : public:
     295              :     // These are the parameters that are read in for a standard 
     296              :     // multigrid linear solve.
     297           19 :     static void Parse(Linear& value, amrex::ParmParse& pp)
     298              :     {
     299              :         // Max number of iterations to perform before erroring out
     300           19 :         pp_query("max_iter", value.max_iter);
     301              : 
     302              :         // Max number of iterations on the bottom solver
     303           19 :         pp_query("bottom_max_iter", value.bottom_max_iter);
     304              : 
     305              :         // Max number of F-cycle iterations to perform
     306           19 :         pp_query("max_fmg_iter", value.max_fmg_iter);
     307              : 
     308              :         // DEPRICATED - do not use
     309           19 :         if (pp.contains("max_fixed_iter"))
     310            0 :             Util::Abort(INFO, "max_fixed_iter is depricated. Use fixed_iter instead.");
     311              : 
     312              :         // Number of fixed iterations to perform before exiting gracefully
     313           19 :         pp_query("fixed_iter", value.fixed_iter);
     314              : 
     315              :         // Verbosity of the solver (1-5)
     316           19 :         pp_query("verbose", value.verbose);
     317              : 
     318              :         // Number of smoothing operations before bottom solve (2)
     319           19 :         pp_query("pre_smooth", value.pre_smooth);
     320              : 
     321              :         // Number of smoothing operations after bottom solve (2)
     322           19 :         pp_query("post_smooth", value.post_smooth);
     323              : 
     324              :         // Number of final smoothing operations when smoother is used as bottom solver (8)
     325           19 :         pp_query("final_smooth", value.final_smooth);
     326              : 
     327              :         // Additional smoothing after bottom CG solver (0)
     328           19 :         pp_query("bottom_smooth", value.bottom_smooth);
     329              : 
     330              :         // The method that is used for the multigrid bottom solve (cg, bicgstab, smoother)
     331           19 :         pp_query("bottom_solver", value.bottom_solver);
     332              : 
     333           19 :         if (pp.contains("cg_tol_rel"))
     334            0 :             Util::Abort(INFO, "cg_tol_rel is depricated. Use bottom_tol_rel instead.");
     335           19 :         if (pp.contains("cg_tol_abs"))
     336            0 :             Util::Abort(INFO, "cg_tol_abs is depricated. Use bottom_tol_abs instead.");
     337              : 
     338              :         // Relative tolerance on bottom solver
     339           19 :         pp_query("bottom_tol_rel", value.bottom_tol_rel);
     340              : 
     341              :         // Absolute tolerance on bottom solver
     342           19 :         pp_query("bottom_tol_abs", value.bottom_tol_abs);
     343              : 
     344              :         // Relative tolerance
     345           19 :         pp_query("tol_rel", value.tol_rel);
     346              : 
     347              :         // Absolute tolerance
     348           19 :         pp_query("tol_abs", value.tol_abs);
     349              : 
     350              :         // Omega (used in gauss-seidel solver)
     351           19 :         pp_query("omega", value.omega);
     352              : 
     353              :         // Whether to average down coefficients or use the ones given.
     354              :         // (Setting this to true is important for fracture.)
     355           19 :         pp_query("average_down_coeffs", value.average_down_coeffs);
     356              : 
     357              :         // Whether to normalize DDW when calculating the diagonal.
     358              :         // This is primarily used when DDW is near-singular - like when there
     359              :         // is a "void" region or when doing phase field fracture.
     360           19 :         pp_query("normalize_ddw", value.normalize_ddw);
     361              : 
     362              :         // [false] 
     363              :         // If set to true, output diagnostic multifab information 
     364              :         // whenever the MLMG solver fails to converge.
     365              :         // (Note: you must also set :code:`amrex.signalhandling=0`
     366              :         // and :code:`amrex.throw_exception=1` for this to work.)
     367           19 :         pp_query("dump_on_fail", value.m_dump_on_fail);
     368              : 
     369              :         // [true]
     370              :         // If set to false, MLMG will not die if convergence criterion
     371              :         // is not reached.
     372              :         // (Note: you must also set :code:`amrex.signalhandling=0`
     373              :         // and :code:`amrex.throw_exception=1` for this to work.)
     374           19 :         pp_query("abort_on_fail", value.m_abort_on_fail);
     375              : 
     376           19 :         if (value.m_dump_on_fail || !value.m_abort_on_fail)
     377              :         {
     378            2 :             IO::ParmParse pp("amrex");
     379            2 :             pp.add("signal_handling", 0);
     380            2 :             pp.add("throw_exception", 1);
     381            2 :         }
     382              : 
     383           19 :     }
     384              : protected:
     385              :     bool m_defined = false;
     386              : 
     387              :     bool m_dump_on_fail = false;
     388              :     bool m_abort_on_fail = true;
     389              : 
     390              : };
     391              : }
     392              : }
     393              : #endif
        

Generated by: LCOV version 2.0-1