LCOV - code coverage report
Current view: top level - src/IC - Expression.H (source / functions) Coverage Total Hit
Test: coverage_merged.info Lines: 55.1 % 89 49
Test Date: 2025-08-12 17:45:17 Functions: 44.4 % 9 4

            Line data    Source code
       1              : //
       2              : // Initialize a field using a mathematical expression.
       3              : // Expressions are imported as strings and are compiled real-time using the
       4              : // `AMReX Parser <https://amrex-codes.github.io/amrex/docs_html/Basics.html#parser>`_.
       5              : //
       6              : // Works for single or multiple-component fields.
       7              : // Use :code:`regionN` (N = 0, 1, 2, ..., one entry per component) to pass the expression for component N.
       8              : // For example:
       9              : //
      10              : // .. code-block:: bash
      11              : // 
      12              : //    ic.region0 = "sin(x*y*z)"
      13              : //    ic.region1 = "3.0*(x > 0.5 and y > 0.5)"
      14              : //
      15              : // for a two-component field. It is up to you to make sure your expressions are parsed
      16              : // correctly; otherwise you will get undefined behavior.
      17              : //
      18              : // :bdg-primary-line:`Constants`
      19              : // You can add constants to your expressions using the :code:`constant` directive.
      20              : // For instance, in the following code
      21              : // 
      22              : // .. code-block:: bash
      23              : //
      24              : //    psi.ic.type=expression
      25              : //    psi.ic.expression.constant.eps = 0.05
      26              : //    psi.ic.expression.constant.R   = 0.25
      27              : //    psi.ic.expression.region0 = "0.5 + 0.5*tanh((x^2 + y^2 - R)/eps)"
      28              : //    
      29              : // the constants :code:`eps` and :code:`R` are defined by the user and then used
      30              : // in the subsequent expression.
      31              : // Constant names may be made up of any characters, as long as they do not clash with a reserved name.
      32              : // Constants are scoped to a single IC: if multiple ICs are used, the constants must be defined separately for each one.
      33              : //
      34              : 
      35              : #ifndef IC_EXPRESSION_H_
      36              : #define IC_EXPRESSION_H_
      37              : #include "IC/IC.H"
      38              : #include "Util/Util.H"
      39              : #include "IO/ParmParse.H"
      40              : #include "AMReX_Parser.H"
      41              : #include <stdexcept>
      42              : 
      43              : namespace IC
      44              : {
      45              : class Expression : public IC<Set::Scalar>, public IC<Set::Vector>
      46              : {
      47              : private:
      48              :     enum CoordSys { Cartesian, Polar, Spherical };
      49              :     std::vector<amrex::Parser> parser;
      50              :     std::vector<amrex::ParserExecutor<4>> f;
      51              :     Expression::CoordSys coord = Expression::CoordSys::Cartesian;
      52              : public:
      53              :     static constexpr const char* name = "expression";
      54              :     Expression(amrex::Vector<amrex::Geometry>& _geom) : 
      55              :         IC<Set::Scalar>(_geom), IC<Set::Vector>(_geom) {}
      56           23 :     Expression(amrex::Vector<amrex::Geometry>& _geom, IO::ParmParse& pp, std::string name) : 
      57           23 :         IC<Set::Scalar>(_geom), IC<Set::Vector>(_geom)
      58              :     {
      59           23 :         pp_queryclass(name, *this);
      60           23 :     }
      61            0 :     Expression(amrex::Vector<amrex::Geometry>& _geom, Unit a_unit, IO::ParmParse& pp, std::string name) : 
      62            0 :         IC<Set::Scalar>(_geom), IC<Set::Vector>(_geom), unit(a_unit)
      63              :     {
      64            0 :         pp_queryclass(name, *this);
      65            0 :     }
      66          303 :     virtual void Add(const int& lev, Set::Field<Set::Scalar>& a_field, Set::Scalar a_time = 0.0) override
      67              :     {
      68         2121 :         Util::Assert(INFO, TEST(a_field[lev]->nComp() == (int)f.size()));
      69          671 :         for (amrex::MFIter mfi(*a_field[lev], amrex::TilingIfNotGPU()); mfi.isValid(); ++mfi)
      70              :         {
      71          368 :             amrex::Box bx;// = mfi.tilebox();
      72              :             //bx.grow(a_field[lev]->nGrow());
      73          368 :             amrex::IndexType type = a_field[lev]->ixType();
      74          736 :             if (type == amrex::IndexType::TheCellType())      bx = mfi.growntilebox();
      75            0 :             else if (type == amrex::IndexType::TheNodeType()) bx = mfi.grownnodaltilebox();
      76            0 :             else Util::Abort(INFO, "Unkonwn index type");
      77              : 
      78          368 :             amrex::Array4<Set::Scalar> const& field = a_field[lev]->array(mfi);
      79          743 :             for (unsigned int n = 0; n < f.size(); n++)
      80              :             {
      81          375 :                 amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
      82              :                 {
      83       278264 :                     Set::Vector x = Set::Position(i, j, k, IC<Set::Scalar>::geom[lev], type);
      84       278264 :                     if (coord == Expression::CoordSys::Cartesian)
      85              :                     {
      86              : #if AMREX_SPACEDIM == 1
      87              :                         field(i, j, k, n) = f[n](x(0), 0.0, 0.0, a_time) * unitfactor;
      88              : #elif AMREX_SPACEDIM == 2
      89       834792 :                         field(i, j, k, n) = f[n](x(0), x(1), 0.0, a_time) * unitfactor;
      90              : #elif AMREX_SPACEDIM == 3
      91            0 :                         field(i, j, k, n) = f[n](x(0), x(1), x(2), a_time) * unitfactor;
      92              : #endif
      93              :                     }
      94              : #if AMREX_SPACEDIM>1
      95            0 :                     else if (coord == Expression::CoordSys::Polar)
      96              :                     {
      97            0 :                         field(i, j, k, n) = f[n](sqrt(x(0)* x(0) + x(1) * x(1)), std::atan2(x(1), x(0)), x(2), a_time)  * unitfactor;
      98              :                     }
      99              : #endif
     100       278264 :                 });
     101              :             }
     102          303 :         }
     103          303 :         a_field[lev]->FillBoundary();
     104          303 :     };
     105              : 
     106            0 :     virtual void Add(const int& lev, Set::Field<Set::Vector>& a_field, Set::Scalar a_time = 0.0) override
     107              :     {
     108            0 :         Util::Assert(INFO, TEST(a_field[lev]->nComp() == 1));
     109            0 :         Util::Assert(INFO, TEST(f.size() >= AMREX_SPACEDIM));
     110            0 :         for (amrex::MFIter mfi(*a_field[lev], amrex::TilingIfNotGPU()); mfi.isValid(); ++mfi)
     111              :         {
     112            0 :             amrex::Box bx;
     113            0 :             amrex::IndexType type = a_field[lev]->ixType();
     114            0 :             if (type == amrex::IndexType::TheCellType())      bx = mfi.growntilebox();
     115            0 :             else if (type == amrex::IndexType::TheNodeType()) bx = mfi.grownnodaltilebox();
     116            0 :             else Util::Abort(INFO, "Unkonwn index type");
     117              : 
     118            0 :             Set::Patch<Set::Vector> field = a_field.Patch(lev,mfi);
     119            0 :             for (unsigned int n = 0; n < AMREX_SPACEDIM; n++)
     120              :             {
     121            0 :                 amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
     122              :                 {
     123            0 :                     Set::Vector x = Set::Position(i, j, k, IC<Set::Vector>::geom[lev], type);
     124            0 :                     if (coord == Expression::CoordSys::Cartesian)
     125              :                     {
     126              : #if AMREX_SPACEDIM == 1
     127              :                         field(i, j, k)(n) = f[n](x(0), 0.0, 0.0, a_time) * unitfactor;
     128              : #elif AMREX_SPACEDIM == 2
     129            0 :                         field(i, j, k)(n) = f[n](x(0), x(1), 0.0, a_time) * unitfactor;
     130              : #elif AMREX_SPACEDIM == 3
     131            0 :                         field(i, j, k)(n) = f[n](x(0), x(1), x(2), a_time) * unitfactor;
     132              : #endif
     133              :                     }
     134              : #if AMREX_SPACEDIM>1
     135            0 :                     else if (coord == Expression::CoordSys::Polar)
     136              :                     {
     137            0 :                         field(i, j, k)(n) = f[n](sqrt(x(0)* x(0) + x(1) * x(1)), std::atan2(x(1), x(0)), x(2), a_time)  * unitfactor;
     138              :                     }
     139              : #endif
     140            0 :                 });
     141              :             }
     142            0 :         }
     143            0 :         a_field[lev]->FillBoundary();
     144            0 :     };
     145              : 
     146           23 :     static void Parse(Expression& value, IO::ParmParse& pp)
     147              :     {
     148           23 :         std::string coordstr = "";
     149              :         // coordinate system to use
     150           92 :         pp_query_validate("coord", coordstr, {"cartesian","polar"}); 
     151           23 :         if (coordstr == "cartesian") value.coord = Expression::CoordSys::Cartesian;
     152            0 :         else if (coordstr == "polar") value.coord = Expression::CoordSys::Polar;
     153            0 :         else Util::Exception(INFO, "unsupported coordinates ", coordstr);
     154              : 
     155           23 :         std::string unitstr;
     156              :         // Units of the value that is returned by the expression
     157           23 :         pp.query_default("unit",unitstr,"");
     158              :         try
     159              :         {
     160           23 :             Unit unit_spec = Unit::Parse(unitstr);
     161           23 :             if (!unit_spec.isType(value.unit) && !unit_spec.isType(Unit::Less()))
     162            0 :                 Util::Exception(INFO, "Incompatible unit specified: ", unitstr);
     163           23 :             value.unitfactor = unit_spec.normalized_value();
     164              :         }
     165            0 :         catch (std::runtime_error &e)
     166              :         {
     167            0 :             Util::Exception(INFO,e.what());
     168            0 :         }
     169              : 
     170           23 :         std::vector<std::string> expression_strs;
     171              :         // Mathematical expression in terms of x,y,z,t (if coord=cartesian)
     172              :         // or r,theta,z,t (if coord=polar) and any defined constants.
     173           23 :         pp.query_enumerate("region", expression_strs);
     174              : 
     175           53 :         for (unsigned int i = 0; i < expression_strs.size(); i++)
     176              :         {
     177           30 :             value.parser.push_back(amrex::Parser(expression_strs[i]));
     178              : 
     179              :             //
     180              :             // Read in user-defined constants and add them to the parser
     181              :             //
     182           30 :             std::string prefix = pp.getPrefix();
     183           30 :             std::set<std::string> entries = pp.getEntries(prefix + ".constant");//"constant");
     184           30 :             std::set<std::string>::iterator entry;
     185           77 :             for (entry = entries.begin(); entry != entries.end(); entry++)
     186              :             {
     187           47 :                 IO::ParmParse pp;
     188           47 :                 std::string fullname = *entry;
     189           47 :                 Unit val;
     190           94 :                 pp.queryunit(fullname.data(),val);
     191           47 :                 std::string name = Util::String::Split(fullname,'.').back();
     192           47 :                 value.parser.back().setConstant(name,val.normalized_value());
     193           47 :             }
     194              : 
     195           30 :             if (value.coord == Expression::CoordSys::Cartesian)
     196              :             {
     197          210 :                 value.parser.back().registerVariables({ "x","y","z","t" });
     198           30 :                 value.f.push_back(value.parser.back().compile<4>());
     199              :             }
     200            0 :             else if (value.coord == Expression::CoordSys::Polar)
     201              :             {
     202            0 :                 value.parser.back().registerVariables({ "r","theta","z","t" });
     203            0 :                 value.f.push_back(value.parser.back().compile<4>());
     204              :             }
     205           30 :         }
     206          113 :     };
     207              : private:
     208              :     Unit unit = Unit::Less();
     209              :     Set::Scalar unitfactor = NAN;
     210              : };
     211              : }
     212              : 
     213              : #endif
        

Generated by: LCOV version 2.0-1