LCOV - code coverage report
Current view: top level - src/Solver/Nonlocal - Linear.H (source / functions) Hit Total Coverage
Test: coverage_merged.info Lines: 81 113 71.7 %
Date: 2025-01-16 18:33:59 Functions: 7 8 87.5 %

          Line data    Source code
       1             : #ifndef SOLVER_NONLOCAL_LINEAR
       2             : #define SOLVER_NONLOCAL_LINEAR
#include <cmath>
#include <exception>
#include <memory>
#include <string>

#include <AMReX_MLMG.H>
#include "Operator/Operator.H"
#include "IC/Trig.H"
       6             : 
       7             : namespace Solver
       8             : {
       9             : namespace Nonlocal
      10             : {
      11             : /// \brief Multigrid Linear solver for multicomponent, multi-level operators
      12             : /// 
      13             : /// This class is a thin wrapper for the `amrex::MLMG` solver.
      14             : /// It exists to set a range of default MLMG settings automatically, for instance,
      15             : /// `setCFStrategy`, which may not be obvious to the user.
      16             : ///
      17             : /// It also exists as a compatibility layer so that future fixes for compatibility
      18             : /// with AMReX can be implemented here.
      19             : class Linear  // : public amrex::MLMG
      20             : {
      21             : public:
      22             : 
      23          25 :     Linear() //Operator::Operator<Grid::Node>& a_lp) : MLMG(a_lp), linop(a_lp)
      24          25 :     {
      25          25 :     }
      26             : 
      27             :     Linear(Operator::Operator<Grid::Node>& a_lp)
      28             :     {
      29             :         this->Define(a_lp);
      30             :     }
      31             : 
      32          25 :     ~Linear()
      33          25 :     {
      34          25 :         if (m_defined) Clear();
      35          25 :     }
      36             : 
      37         426 :     void Define(Operator::Operator<Grid::Node>& a_lp)
      38             :     {
      39         426 :         if (m_defined) Util::Abort(INFO, "Solver cannot be re-defined");
      40         426 :         this->linop = &a_lp;
      41         426 :         this->mlmg = new amrex::MLMG(a_lp);
      42         426 :         m_defined = true;
      43         426 :         PrepareMLMG(*mlmg);
      44         426 :     }
      45         426 :     void Clear()
      46             :     {
      47         426 :         if (!m_defined) Util::Abort(INFO, "Solver cannot be cleared if not defined");
      48         426 :         this->linop = nullptr;
      49         426 :         if (this->mlmg) delete this->mlmg;
      50         426 :         this->mlmg = nullptr;
      51         426 :         m_defined = false;
      52         426 :     }
      53             : 
      54             :     Set::Scalar solveaffine(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
      55             :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
      56             :         Real a_tol_rel, Real a_tol_abs, bool copyrhs = false,
      57             :         const char* checkpoint_file = nullptr)
      58             :     {
      59             :         if (!m_defined) Util::Abort(INFO, "Solver not defined");
      60             :         amrex::Vector<amrex::MultiFab*> rhs_tmp(a_rhs.size());
      61             :         amrex::Vector<amrex::MultiFab*> zero_tmp(a_rhs.size());
      62             :         for (int i = 0; i < rhs_tmp.size(); i++)
      63             :         {
      64             :             rhs_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
      65             :             zero_tmp[i] = new amrex::MultiFab(a_rhs[i]->boxArray(), a_rhs[i]->DistributionMap(), a_rhs[i]->nComp(), a_rhs[i]->nGrow());
      66             :             rhs_tmp[i]->setVal(0.0);
      67             :             zero_tmp[i]->setVal(0.0);
      68             :             Util::Message(INFO, rhs_tmp[i]->norm0());
      69             :         }
      70             : 
      71             :         linop->SetHomogeneous(false);
      72             :         mlmg->apply(rhs_tmp, zero_tmp);
      73             : 
      74             :         for (int lev = 0; lev < rhs_tmp.size(); lev++)
      75             :         {
      76             :             amrex::Box domain = linop->Geom(lev).Domain();
      77             :             domain.convert(amrex::IntVect::TheNodeVector());
      78             :             const Dim3 lo = amrex::lbound(domain), hi = amrex::ubound(domain);
      79             :             for (MFIter mfi(*rhs_tmp[lev], amrex::TilingIfNotGPU());mfi.isValid();++mfi)
      80             :             {
      81             :                 amrex::Box bx = mfi.growntilebox(rhs_tmp[lev]->nGrow());
      82             :                 bx = bx & domain;
      83             :                 amrex::Array4<amrex::Real> const& rhstmp = rhs_tmp[lev]->array(mfi);
      84             :                 for (int n = 0; n < rhs_tmp[lev]->nComp(); n++)
      85             :                 {
      86             :                     amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
      87             :                     {
      88             :                         bool    AMREX_D_DECL(xmin = (i == lo.x), ymin = (j == lo.y), zmin = (k == lo.z)),
      89             :                             AMREX_D_DECL(xmax = (i == hi.x), ymax = (j == hi.y), zmax = (k == hi.z));
      90             :                         if (AMREX_D_TERM(xmax || xmin, || ymax || ymin, || zmax || zmin))
      91             :                             rhstmp(i, j, k, n) = 0.0;
      92             :                         else
      93             :                             rhstmp(i, j, k, n) *= -1.0;
      94             :                     });
      95             :                 }
      96             :             }
      97             :             Util::Message(INFO, rhs_tmp[lev]->norm0());
      98             :             linop->realFillBoundary(*rhs_tmp[lev], linop->Geom(lev));
      99             :             Util::Message(INFO, rhs_tmp[lev]->norm0());
     100             :             //rhs_tmp[lev]->FillBoundary();
     101             :         }
     102             : 
     103             :         for (int lev = 0; lev < rhs_tmp.size(); lev++)
     104             :         {
     105             :             Util::Message(INFO, rhs_tmp[lev]->norm0());
     106             :             amrex::Add(*rhs_tmp[lev], *a_rhs[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
     107             :             if (copyrhs)
     108             :                 amrex::Copy(*a_rhs[lev], *rhs_tmp[lev], 0, 0, rhs_tmp[lev]->nComp(), rhs_tmp[lev]->nGrow());
     109             :             Util::Message(INFO, rhs_tmp[lev]->norm0());
     110             :         }
     111             : 
     112             :         linop->SetHomogeneous(true);
     113             :         PrepareMLMG(*mlmg);
     114             :         Set::Scalar retval = NAN;
     115             :         try
     116             :         {
     117             :             retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp), a_tol_rel, a_tol_abs, checkpoint_file);
     118             :         }
     119             :         catch (const std::exception& e)
     120             :         {
     121             :             if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
     122             :             if (m_abort_on_fail) Util::Abort(INFO, e.what());
     123             :         }
     124             :         if (a_sol[0]->contains_nan()) 
     125             :         {
     126             :             dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(rhs_tmp));
     127             :             Util::Abort(INFO);
     128             :         }
     129             : 
     130             :         return retval;
     131             :     };
     132             : 
     133        1326 :     Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
     134             :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
     135             :         Real a_tol_rel, Real a_tol_abs, const char* checkpoint_file = nullptr)
     136             :     {
     137        1326 :         PrepareMLMG(*mlmg);
     138        1326 :         Set::Scalar retval = NAN;
     139             :         try
     140             :         {
     141        1326 :             retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), a_tol_rel, a_tol_abs, checkpoint_file);
     142             :         }
     143           0 :         catch (const std::exception& e)
     144             :         {
     145           0 :             if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
     146           0 :             if (m_abort_on_fail) Util::Abort(INFO, e.what());
     147             :         }
     148        1326 :         return retval;
     149             :     };
     150             :     Set::Scalar solve(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol,
     151             :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs)
     152             :     {
     153             :         PrepareMLMG(*mlmg);
     154             :         Set::Scalar retval = NAN;
     155             :         try
     156             :         {
     157             :             retval = mlmg->solve(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs), tol_rel, tol_abs);
     158             :         }
     159             :         catch (const std::exception& e)
     160             :         {
     161             :             if (m_dump_on_fail) dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
     162             :             if (m_abort_on_fail) Util::Abort(INFO, e.what());
     163             :         }
     164             :         if (a_sol[0]->contains_nan()) 
     165             :         {
     166             :             dumpOnConvergenceFail(GetVecOfPtrs(a_sol), GetVecOfConstPtrs(a_rhs));
     167             :             Util::Abort(INFO);
     168             :         }
     169             :         return retval;
     170             :     };
     171             :     void apply(amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_rhs,
     172             :         amrex::Vector<std::unique_ptr<amrex::MultiFab> >& a_sol)
     173             :     {
     174             :         PrepareMLMG(*mlmg);
     175             :         mlmg->apply(GetVecOfPtrs(a_rhs), GetVecOfPtrs(a_sol));
     176             :     };
     177             : 
     178             : 
     179             : 
     180             :     void setMaxIter(const int a_max_iter) { max_iter = a_max_iter; }
     181             :     void setBottomMaxIter(const int a_bottom_max_iter) { bottom_max_iter = a_bottom_max_iter; }
     182             :     void setMaxFmgIter(const int a_max_fmg_iter) { max_fmg_iter = a_max_fmg_iter; }
     183             :     void setFixedIter(const int a_fixed_iter) { fixed_iter = a_fixed_iter; }
     184             :     void setVerbose(const int a_verbose) { verbose = a_verbose; }
     185             :     void setPreSmooth(const int a_pre_smooth) { pre_smooth = a_pre_smooth; }
     186             :     void setPostSmooth(const int a_post_smooth) { post_smooth = a_post_smooth; }
     187             : 
     188           0 :     void dumpOnConvergenceFail(const amrex::Vector<amrex::MultiFab*>& a_sol_mf,
     189             :         const amrex::Vector<amrex::MultiFab const*>& a_rhs_mf)
     190             :     {
     191           0 :         int nlevs = a_sol_mf.size();
     192           0 :         int ncomps = a_sol_mf[0]->nComp();
     193             : 
     194           0 :         amrex::Vector<amrex::Geometry> geom;
     195           0 :         amrex::Vector<int> iter;
     196           0 :         amrex::Vector<amrex::IntVect> refratio;
     197           0 :         amrex::Vector<std::string> names;
     198           0 :         for (int i = 0; i < nlevs; i++)
     199             :         {
     200           0 :             geom.push_back(linop->Geom(i));
     201           0 :             iter.push_back(0);
     202           0 :             if (i > 0) refratio.push_back(amrex::IntVect(2));
     203             :         }
     204           0 :         for (int n = 0; n < ncomps; n++)
     205             :         {
     206           0 :             names.push_back("var" + std::to_string(n));
     207             :         }
     208             : 
     209           0 :         std::string outputdir = Util::GetFileName();
     210           0 :         WriteMultiLevelPlotfile(outputdir + "/mlmg_sol", nlevs,
     211           0 :             amrex::GetVecOfConstPtrs(a_sol_mf),
     212             :             names, geom, 0, iter, refratio);
     213           0 :         WriteMultiLevelPlotfile(outputdir + "/mlmg_rhs", nlevs,
     214           0 :             amrex::GetVecOfConstPtrs(a_rhs_mf),
     215             :             names, geom, 0, iter, refratio);
     216             : 
     217           0 :         Set::Field<Set::Scalar> res_mf(nlevs);
     218           0 :         for (int lev = 0; lev < nlevs; lev++)
     219             :         {
     220           0 :             res_mf.Define(lev, a_sol_mf[lev]->boxArray(), a_sol_mf[lev]->DistributionMap(),
     221           0 :                 ncomps, a_sol_mf[lev]->nGrow());
     222             :         }
     223             : 
     224           0 :         mlmg->compResidual(amrex::GetVecOfPtrs(res_mf), a_sol_mf, a_rhs_mf);
     225             : 
     226           0 :         WriteMultiLevelPlotfile(outputdir + "/mlmg_res", nlevs,
     227           0 :             amrex::GetVecOfConstPtrs(res_mf),
     228             :             names, geom, 0, iter, refratio);
     229             : 
     230           0 :     }
     231             : 
     232             :     //using MLMG::solve;
     233             : protected:
     234             :     int max_iter = -1;
     235             :     int bottom_max_iter = -1;
     236             :     int max_fmg_iter = -1;
     237             :     int fixed_iter = -1;
     238             :     int verbose = -1;
     239             :     int pre_smooth = -1;
     240             :     int post_smooth = -1;
     241             :     int final_smooth = -1;
     242             :     int bottom_smooth = -1;
     243             :     std::string bottom_solver;
     244             :     Set::Scalar cg_tol_rel = -1.0;
     245             :     Set::Scalar cg_tol_abs = -1.0;
     246             :     Set::Scalar bottom_tol_rel = -1.0;
     247             :     Set::Scalar bottom_tol_abs = -1.0;
     248             :     Set::Scalar tol_rel = -1.0;
     249             :     Set::Scalar tol_abs = -1.0;
     250             :     Set::Scalar omega = -1.0;
     251             :     bool average_down_coeffs = false;
     252             :     bool normalize_ddw = false;
     253             : 
     254             :     Operator::Operator<Grid::Node>* linop;
     255             :     amrex::MLMG* mlmg;
     256             : 
     257        1752 :     void PrepareMLMG(amrex::MLMG& mlmg)
     258             :     {
     259        1752 :         if (!m_defined) Util::Message(INFO, "Solver not defined");
     260        1752 :         mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
     261        1752 :         mlmg.setCFStrategy(MLMG::CFStrategy::ghostnodes);
     262        1752 :         mlmg.setFinalFillBC(false);
     263        1752 :         mlmg.setMaxFmgIter(100000000);
     264             : 
     265             : 
     266        1752 :         if (max_iter >= 0)        mlmg.setMaxIter(max_iter);
     267        1752 :         if (bottom_max_iter >= 0) mlmg.setBottomMaxIter(bottom_max_iter);
     268        1752 :         if (max_fmg_iter >= 0)    mlmg.setMaxFmgIter(max_fmg_iter);
     269        1752 :         if (fixed_iter >= 0)      mlmg.setFixedIter(fixed_iter);
     270        1752 :         if (verbose >= 0)
     271             :         {
     272        1752 :             mlmg.setVerbose(verbose - 1);
     273        1752 :             if (verbose > 4)      mlmg.setBottomVerbose(verbose);
     274        1752 :             else                  mlmg.setBottomVerbose(0);
     275             :         }
     276             : 
     277        1752 :         if (pre_smooth >= 0)      mlmg.setPreSmooth(pre_smooth);
     278        1752 :         if (post_smooth >= 0)     mlmg.setPostSmooth(post_smooth);
     279        1752 :         if (final_smooth >= 0)    mlmg.setFinalSmooth(final_smooth);
     280        1752 :         if (bottom_smooth >= 0)   mlmg.setBottomSmooth(bottom_smooth);
     281             : 
     282        1752 :         if (bottom_solver == "cg")       mlmg.setBottomSolver(MLMG::BottomSolver::cg);
     283        1752 :         else if (bottom_solver == "bicgstab") mlmg.setBottomSolver(MLMG::BottomSolver::bicgstab);
     284        1752 :         else if (bottom_solver == "smoother") mlmg.setBottomSolver(MLMG::BottomSolver::smoother);
     285             : 
     286        1752 :         if (bottom_tol_rel >= 0) mlmg.setBottomTolerance(bottom_tol_rel);
     287        1752 :         if (bottom_tol_abs >= 0) mlmg.setBottomToleranceAbs(bottom_tol_abs);
     288             : 
     289        1752 :         if (omega >= 0) this->linop->SetOmega(omega);
     290        1752 :         if (average_down_coeffs) this->linop->SetAverageDownCoeffs(true);
     291        1752 :         if (normalize_ddw) this->linop->SetNormalizeDDW(true);
     292        1752 :     }
     293             : 
     294             : 
     295             : public:
     296             :     // These are the parameters that are read in for a standard 
     297             :     // multigrid linear solve.
     298          19 :     static void Parse(Linear& value, amrex::ParmParse& pp)
     299             :     {
     300             :         // Max number of iterations to perform before erroring out
     301          19 :         pp_query("max_iter", value.max_iter);
     302             : 
     303             :         // Max number of iterations on the bottom solver
     304          19 :         pp_query("bottom_max_iter", value.bottom_max_iter);
     305             : 
     306             :         // Max number of F-cycle iterations to perform
     307          19 :         pp_query("max_fmg_iter", value.max_fmg_iter);
     308             : 
     309             :         // DEPRICATED - do not use
     310          19 :         if (pp.contains("max_fixed_iter"))
     311           0 :             Util::Abort(INFO, "max_fixed_iter is depricated. Use fixed_iter instead.");
     312             : 
     313             :         // Number of fixed iterations to perform before exiting gracefully
     314          19 :         pp_query("fixed_iter", value.fixed_iter);
     315             : 
     316             :         // Verbosity of the solver (1-5)
     317          19 :         pp_query("verbose", value.verbose);
     318             : 
     319             :         // Number of smoothing operations before bottom solve (2)
     320          19 :         pp_query("pre_smooth", value.pre_smooth);
     321             : 
     322             :         // Number of smoothing operations after bottom solve (2)
     323          19 :         pp_query("post_smooth", value.post_smooth);
     324             : 
     325             :         // Number of final smoothing operations when smoother is used as bottom solver (8)
     326          19 :         pp_query("final_smooth", value.final_smooth);
     327             : 
     328             :         // Additional smoothing after bottom CG solver (0)
     329          19 :         pp_query("bottom_smooth", value.bottom_smooth);
     330             : 
     331             :         // The method that is used for the multigrid bottom solve (cg, bicgstab, smoother)
     332          19 :         pp_query("bottom_solver", value.bottom_solver);
     333             : 
     334          19 :         if (pp.contains("cg_tol_rel"))
     335           0 :             Util::Abort(INFO, "cg_tol_rel is depricated. Use bottom_tol_rel instead.");
     336          19 :         if (pp.contains("cg_tol_abs"))
     337           0 :             Util::Abort(INFO, "cg_tol_abs is depricated. Use bottom_tol_abs instead.");
     338             : 
     339             :         // Relative tolerance on bottom solver
     340          19 :         pp_query("bottom_tol_rel", value.bottom_tol_rel);
     341             : 
     342             :         // Absolute tolerance on bottom solver
     343          19 :         pp_query("bottom_tol_abs", value.bottom_tol_abs);
     344             : 
     345             :         // Relative tolerance
     346          19 :         pp_query("tol_rel", value.tol_rel);
     347             : 
     348             :         // Absolute tolerance
     349          19 :         pp_query("tol_abs", value.tol_abs);
     350             : 
     351             :         // Omega (used in gauss-seidel solver)
     352          19 :         pp_query("omega", value.omega);
     353             : 
     354             :         // Whether to average down coefficients or use the ones given.
     355             :         // (Setting this to true is important for fracture.)
     356          19 :         pp_query("average_down_coeffs", value.average_down_coeffs);
     357             : 
     358             :         // Whether to normalize DDW when calculating the diagonal.
     359             :         // This is primarily used when DDW is near-singular - like when there
     360             :         // is a "void" region or when doing phase field fracture.
     361          19 :         pp_query("normalize_ddw", value.normalize_ddw);
     362             : 
     363             :         // [false] 
     364             :         // If set to true, output diagnostic multifab information 
     365             :         // whenever the MLMG solver fails to converge.
     366             :         // (Note: you must also set :code:`amrex.signalhandling=0`
     367             :         // and :code:`amrex.throw_exception=1` for this to work.)
     368          19 :         pp_query("dump_on_fail", value.m_dump_on_fail);
     369             : 
     370             :         // [true]
     371             :         // If set to false, MLMG will not die if convergence criterion
     372             :         // is not reached.
     373             :         // (Note: you must also set :code:`amrex.signalhandling=0`
     374             :         // and :code:`amrex.throw_exception=1` for this to work.)
     375          19 :         pp_query("abort_on_fail", value.m_abort_on_fail);
     376             : 
     377          19 :         if (value.m_dump_on_fail || !value.m_abort_on_fail)
     378             :         {
     379           6 :             IO::ParmParse pp("amrex");
     380           2 :             pp.add("signal_handling", 0);
     381           2 :             pp.add("throw_exception", 1);
     382             :         }
     383             : 
     384          19 :     }
     385             : protected:
     386             :     bool m_defined = false;
     387             : 
     388             :     bool m_dump_on_fail = false;
     389             :     bool m_abort_on_fail = true;
     390             : 
     391             : };
     392             : }
     393             : }
     394             : #endif

Generated by: LCOV version 1.14