LCOV - code coverage report
Current view: top level - src/IC - Expression.H (source / functions) Coverage Total Hit
Test: coverage_merged.info Lines: 57.9 % 76 44
Test Date: 2025-04-03 04:02:21 Functions: 50.0 % 8 4

            Line data    Source code
       1              : //
       2              : // Initialize a field using a mathematical expression.
       3              : // Expressions are imported as strings and are compiled real-time using the
       4              : // `AMReX Parser <https://amrex-codes.github.io/amrex/docs_html/Basics.html#parser>`_.
       5              : //
       6              : // Works for single or multiple-component fields.
       7              : // Use :code:`regionN` (N = 0, 1, 2, ..., one entry per component) to pass the expression for each component.
       8              : // For example:
       9              : //
      10              : // .. code-block:: bash
      11              : // 
      12              : //    ic.region0 = "sin(x*y*z)"
      13              : //    ic.region1 = "3.0*(x > 0.5 and y > 0.5)"
      14              : //
      15              : // for a two-component field. It is up to you to make sure your expressions are parsed
      16              : // correctly; otherwise you will get undefined behavior.
      17              : //
      18              : // :bdg-primary-line:`Constants`
      19              : // You can add constants to your expressions using the :code:`constant` directive.
      20              : // For instance, in the following code
      21              : // 
      22              : // .. code-block:: bash
      23              : //
      24              : //    psi.ic.type=expression
      25              : //    psi.ic.expression.constant.eps = 0.05
      26              : //    psi.ic.expression.constant.R   = 0.25
      27              : //    psi.ic.expression.region0 = "0.5 + 0.5*tanh((x^2 + y^2 - R)/eps)"
      28              : //    
      29              : // the constants :code:`eps` and :code:`R` are defined by the user and then used
      30              : // in the subsequent expression.
      31              : // Constants may have any name that is not reserved (e.g. a coordinate variable such as x, y, z, t, or a built-in function name).
      32              : // However, if multiple ICs are used, the constants must be defined separately for each IC.
      33              : //
      34              : 
      35              : #ifndef IC_EXPRESSION_H_
      36              : #define IC_EXPRESSION_H_
      37              : #include "IC/IC.H"
      38              : #include "Util/Util.H"
      39              : #include "IO/ParmParse.H"
      40              : #include "AMReX_Parser.H"
      41              : 
      42              : namespace IC
      43              : {
      44              : class Expression : public IC<Set::Scalar>, public IC<Set::Vector>
      45              : {
      46              : private:
      47              :     enum CoordSys { Cartesian, Polar, Spherical };
      48              :     std::vector<amrex::Parser> parser;
      49              :     std::vector<amrex::ParserExecutor<4>> f;
      50              :     Expression::CoordSys coord = Expression::CoordSys::Cartesian;
      51              : public:
      52              :     static constexpr const char* name = "expression";
      53              :     Expression(amrex::Vector<amrex::Geometry>& _geom) : 
      54              :         IC<Set::Scalar>(_geom), IC<Set::Vector>(_geom) {}
      55           16 :     Expression(amrex::Vector<amrex::Geometry>& _geom, IO::ParmParse& pp, std::string name) : 
      56           16 :         IC<Set::Scalar>(_geom), IC<Set::Vector>(_geom)
      57              :     {
      58           48 :         pp_queryclass(name, *this);
      59           16 :     }
      60          287 :     virtual void Add(const int& lev, Set::Field<Set::Scalar>& a_field, Set::Scalar a_time = 0.0) override
      61              :     {
      62         2009 :         Util::Assert(INFO, TEST(a_field[lev]->nComp() == (int)f.size()));
      63          639 :         for (amrex::MFIter mfi(*a_field[lev], amrex::TilingIfNotGPU()); mfi.isValid(); ++mfi)
      64              :         {
      65          352 :             amrex::Box bx;// = mfi.tilebox();
      66              :             //bx.grow(a_field[lev]->nGrow());
      67          352 :             amrex::IndexType type = a_field[lev]->ixType();
      68          704 :             if (type == amrex::IndexType::TheCellType())      bx = mfi.growntilebox();
      69            0 :             else if (type == amrex::IndexType::TheNodeType()) bx = mfi.grownnodaltilebox();
      70            0 :             else Util::Abort(INFO, "Unkonwn index type");
      71              : 
      72          352 :             amrex::Array4<Set::Scalar> const& field = a_field[lev]->array(mfi);
      73          709 :             for (unsigned int n = 0; n < f.size(); n++)
      74              :             {
      75          357 :                 amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
      76              :                 {
      77       300144 :                     Set::Vector x = Set::Position(i, j, k, IC<Set::Scalar>::geom[lev], type);
      78       300144 :                     if (coord == Expression::CoordSys::Cartesian)
      79              :                     {
      80              : #if AMREX_SPACEDIM == 1
      81              :                         field(i, j, k, n) = f[n](x(0), 0.0, 0.0, a_time);
      82              : #elif AMREX_SPACEDIM == 2
      83       900432 :                         field(i, j, k, n) = f[n](x(0), x(1), 0.0, a_time);
      84              : #elif AMREX_SPACEDIM == 3
      85            0 :                         field(i, j, k, n) = f[n](x(0), x(1), x(2), a_time);
      86              : #endif
      87              :                     }
      88              : #if AMREX_SPACEDIM>1
      89            0 :                     else if (coord == Expression::CoordSys::Polar)
      90              :                     {
      91            0 :                         field(i, j, k, n) = f[n](sqrt(x(0)* x(0) + x(1) * x(1)), std::atan2(x(1), x(0)), x(2), a_time);
      92              :                     }
      93              : #endif
      94       300144 :                 });
      95              :             }
      96          287 :         }
      97          287 :         a_field[lev]->FillBoundary();
      98          287 :     };
      99              : 
     100            0 :     virtual void Add(const int& lev, Set::Field<Set::Vector>& a_field, Set::Scalar a_time = 0.0) override
     101              :     {
     102            0 :         Util::Assert(INFO, TEST(a_field[lev]->nComp() == 1));
     103            0 :         Util::Assert(INFO, TEST(f.size() >= AMREX_SPACEDIM));
     104            0 :         for (amrex::MFIter mfi(*a_field[lev], amrex::TilingIfNotGPU()); mfi.isValid(); ++mfi)
     105              :         {
     106            0 :             amrex::Box bx;
     107            0 :             amrex::IndexType type = a_field[lev]->ixType();
     108            0 :             if (type == amrex::IndexType::TheCellType())      bx = mfi.growntilebox();
     109            0 :             else if (type == amrex::IndexType::TheNodeType()) bx = mfi.grownnodaltilebox();
     110            0 :             else Util::Abort(INFO, "Unkonwn index type");
     111              : 
     112            0 :             Set::Patch<Set::Vector> field = a_field.Patch(lev,mfi);
     113            0 :             for (unsigned int n = 0; n < AMREX_SPACEDIM; n++)
     114              :             {
     115            0 :                 amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE(int i, int j, int k)
     116              :                 {
     117            0 :                     Set::Vector x = Set::Position(i, j, k, IC<Set::Vector>::geom[lev], type);
     118            0 :                     if (coord == Expression::CoordSys::Cartesian)
     119              :                     {
     120              : #if AMREX_SPACEDIM == 1
     121              :                         field(i, j, k)(n) = f[n](x(0), 0.0, 0.0, a_time);
     122              : #elif AMREX_SPACEDIM == 2
     123            0 :                         field(i, j, k)(n) = f[n](x(0), x(1), 0.0, a_time);
     124              : #elif AMREX_SPACEDIM == 3
     125            0 :                         field(i, j, k)(n) = f[n](x(0), x(1), x(2), a_time);
     126              : #endif
     127              :                     }
     128              : #if AMREX_SPACEDIM>1
     129            0 :                     else if (coord == Expression::CoordSys::Polar)
     130              :                     {
     131            0 :                         field(i, j, k)(n) = f[n](sqrt(x(0)* x(0) + x(1) * x(1)), std::atan2(x(1), x(0)), x(2), a_time);
     132              :                     }
     133              : #endif
     134            0 :                 });
     135              :             }
     136            0 :         }
     137            0 :         a_field[lev]->FillBoundary();
     138            0 :     };
     139              : 
     140           16 :     static void Parse(Expression& value, IO::ParmParse& pp)
     141              :     {
     142              : 
     143           32 :         std::string coordstr = "";
     144              :         // coordinate system to use
     145          112 :         pp_query_validate("coord", coordstr, {"cartesian","polar"}); 
     146           16 :         if (coordstr == "cartesian") value.coord = Expression::CoordSys::Cartesian;
     147            0 :         else if (coordstr == "polar") value.coord = Expression::CoordSys::Polar;
     148            0 :         else Util::Abort(INFO, "unsupported coordinates ", coordstr);
     149              : 
     150           16 :         std::vector<std::string> expression_strs;
     151              :         // Mathematical expression in terms of x,y,z,t (if coord=cartesian)
     152              :         // or r,theta,z,t (if coord=polar) and any defined constants.
     153           80 :         pp.query_enumerate("region", expression_strs);
     154              : 
     155           37 :         for (unsigned int i = 0; i < expression_strs.size(); i++)
     156              :         {
     157           21 :             value.parser.push_back(amrex::Parser(expression_strs[i]));
     158              : 
     159              :             //
     160              :             // Read in user-defined constants and add them to the parser
     161              :             //
     162           21 :             std::string prefix = pp.getPrefix();
     163           21 :             std::set<std::string> entries = pp.getEntries(prefix + ".constant");//"constant");
     164           21 :             std::set<std::string>::iterator entry;
     165           52 :             for (entry = entries.begin(); entry != entries.end(); entry++)
     166              :             {
     167           31 :                 IO::ParmParse pp;
     168           31 :                 std::string fullname = *entry;
     169           31 :                 Set::Scalar val  = NAN;
     170           31 :                 pp_query(fullname.data(),val);
     171           31 :                 std::string name = Util::String::Split(fullname,'.').back();
     172           31 :                 value.parser.back().setConstant(name,val);
     173           31 :             }
     174              : 
     175           21 :             if (value.coord == Expression::CoordSys::Cartesian)
     176              :             {
     177          147 :                 value.parser.back().registerVariables({ "x","y","z","t" });
     178           21 :                 value.f.push_back(value.parser.back().compile<4>());
     179              :             }
     180            0 :             else if (value.coord == Expression::CoordSys::Polar)
     181              :             {
     182            0 :                 value.parser.back().registerVariables({ "r","theta","z","t" });
     183            0 :                 value.f.push_back(value.parser.back().compile<4>());
     184              :             }
     185           21 :         }
     186              : 
     187           79 :     };
     188              : };
     189              : }
     190              : 
     191              : #endif
        

Generated by: LCOV version 2.0-1