reaxion 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reaxion/__init__.py +5 -0
- reaxion/data/__init__.py +1 -0
- reaxion/data/atomic_weights.py +123 -0
- reaxion/data/solar_abundances.py +49 -0
- reaxion/eos/__init__.py +1 -0
- reaxion/eos/eos.py +3 -0
- reaxion/equation.py +41 -0
- reaxion/equation_system.py +380 -0
- reaxion/localstate.py +21 -0
- reaxion/misc.py +66 -0
- reaxion/networks/__init__.py +0 -0
- reaxion/numerics/__init__.py +1 -0
- reaxion/numerics/solvers.py +98 -0
- reaxion/numerics/tests/__init__.py +0 -0
- reaxion/numerics/tests/test_linear.py +33 -0
- reaxion/numerics/tests/test_newton_rootsolve.py +34 -0
- reaxion/process.py +126 -0
- reaxion/processes/__init__.py +7 -0
- reaxion/processes/freefree_emission.py +32 -0
- reaxion/processes/ionization.py +95 -0
- reaxion/processes/line_cooling.py +56 -0
- reaxion/processes/nbody_process.py +71 -0
- reaxion/processes/recombination.py +112 -0
- reaxion/processes/thermal_process.py +22 -0
- reaxion/symbols.py +57 -0
- reaxion-0.1.1.dist-info/METADATA +411 -0
- reaxion-0.1.1.dist-info/RECORD +30 -0
- reaxion-0.1.1.dist-info/WHEEL +5 -0
- reaxion-0.1.1.dist-info/licenses/LICENSE +21 -0
- reaxion-0.1.1.dist-info/top_level.txt +1 -0
reaxion/__init__.py
ADDED
reaxion/data/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .solar_abundances import SolarAbundances
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
# Atomic weights keyed by element symbol, in order of atomic number (source: ``bibliography`` below).
# "Z" (the key used elsewhere for the total-metal mass fraction) has no single atomic weight, hence NaN.
# Integer entries presumably are mass numbers for elements without a standard atomic weight — TODO confirm.
atomic_weights = dict(
    Z=float("nan"),
    # period 1
    H=1.008, He=4.002602,
    # period 2
    Li=6.94, Be=9.0121831, B=10.81, C=12.011, N=14.007, O=15.999, F=18.998403163, Ne=20.1797,
    # period 3
    Na=22.98976928, Mg=24.305, Al=26.9815384, Si=28.085, P=30.973761998, S=32.06, Cl=35.45, Ar=39.95,
    # period 4
    K=39.0983, Ca=40.078, Sc=44.955907, Ti=47.867, V=50.9415, Cr=51.9961, Mn=54.938043, Fe=55.845,
    Co=58.933194, Ni=58.6934, Cu=63.546, Zn=65.38, Ga=69.723, Ge=72.630, As=74.921595, Se=78.971,
    Br=79.904, Kr=83.798,
    # period 5
    Rb=85.4678, Sr=87.62, Y=88.905838, Zr=91.224, Nb=92.90637, Mo=95.95, Tc=97, Ru=101.07,
    Rh=102.90549, Pd=106.42, Ag=107.8682, Cd=112.414, In=114.818, Sn=118.710, Sb=121.760, Te=127.60,
    I=126.90447, Xe=131.293,
    # period 6 (including the lanthanides La-Lu)
    Cs=132.90545196, Ba=137.327, La=138.90547, Ce=140.116, Pr=140.90766, Nd=144.242, Pm=145, Sm=150.36,
    Eu=151.964, Gd=157.25, Tb=158.925354, Dy=162.500, Ho=164.930329, Er=167.259, Tm=168.934219, Yb=173.045,
    Lu=174.9668, Hf=178.486, Ta=180.94788, W=183.84, Re=186.207, Os=190.23, Ir=192.217, Pt=195.084,
    Au=196.966570, Hg=200.592, Tl=204.38, Pb=207.2, Bi=208.98040, Po=209, At=210, Rn=222,
    # period 7 (including the actinides Ac-Lr)
    Fr=223, Ra=226, Ac=227, Th=232.0377, Pa=231.03588, U=238.02891, Np=237, Pu=244,
    Am=243, Cm=247, Bk=247, Cf=251, Es=252, Fm=257, Md=258, No=259,
    Lr=262, Rf=267, Db=270, Sg=269, Bh=270, Hs=270, Mt=278, Ds=281,
    Rg=281, Cn=285, Nh=286, Fl=289, Mc=289, Lv=293, Ts=293, Og=294,
)

bibliography = "https://iupac.qmul.ac.uk/AtWt/"
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from .atomic_weights import atomic_weights
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SolarAbundancesClass:
    """Container for solar abundances with methods to convert between mass fraction and abundance per H"""

    bibliography = ["2009ARA&A..47..481A"]

    @property
    def mass_fraction(self):
        """Returns a hard-coded dict of Solar abundance mass fractions.

        "He" is the helium mass fraction Y and "Z" the total metal mass fraction; the other entries are individual
        metals. Hydrogen is not tabulated here: see ``hydrogen_mass_fraction``.
        """
        return {
            "Z": 0.0142,
            "He": 0.27030,
            "C": 2.53e-3,
            "N": 7.41e-4,
            "O": 6.13e-3,
            "Ne": 1.34e-3,
            "Mg": 7.57e-4,
            "Si": 7.12e-4,
            "S": 3.31e-4,
            "Ca": 6.87e-5,
            "Fe": 1.38e-3,
        }

    @property
    def hydrogen_mass_fraction(self) -> float:
        """Hydrogen mass fraction X = 1 - Y - Z implied by the tabulated He and total-metal mass fractions."""
        fractions = self.mass_fraction
        return 1.0 - fractions["He"] - fractions["Z"]

    @property
    def abundance_per_H(self):
        """Returns dictionary of number abundances per H nucleon.

        n_X / n_H = (f_X / A_X) / (X / A_H), where f_X is the mass fraction of X, A_X its atomic weight, X the
        hydrogen mass fraction and A_H the atomic weight of hydrogen. The "Z" entry is NaN because
        ``atomic_weights["Z"]`` is NaN.
        """
        # moles of H nuclei per unit mass of gas
        hydrogen_per_mass = self.hydrogen_mass_fraction / atomic_weights["H"]
        return {
            species: f / atomic_weights[species] / hydrogen_per_mass for species, f in self.mass_fraction.items()
        }

    def x(self, species):
        """Shorthand for ``get_abundance``: abundance of ``species`` per H nucleon."""
        return self.get_abundance(species)

    def get_mass_fraction(self, species: str) -> float:
        """Returns the mass fraction of a given species.

        Raises
        ------
        NotImplementedError
            If no solar value is available for ``species``.
        """
        if species == "H":
            return self.hydrogen_mass_fraction
        fractions = self.mass_fraction
        if species in fractions:
            return fractions[species]
        raise NotImplementedError(f"Solar abundance of {species} not available.")

    def get_abundance(self, species: str) -> float:
        """Returns the abundance per H nucleon of an input species.

        Raises
        ------
        NotImplementedError
            If no solar value is available for ``species``.
        """
        if species == "H":
            return 1.0  # hydrogen relative to itself
        if species in self.mass_fraction:
            return self.abundance_per_H[species]
        raise NotImplementedError(f"Solar abundance of {species} not available.")


SolarAbundances = SolarAbundancesClass()
|
reaxion/eos/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# from .eos import u, rho
|
reaxion/eos/eos.py
ADDED
reaxion/equation.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Implementation of Equation class for representing conservation laws"""
|
|
2
|
+
|
|
3
|
+
import sympy as sp
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Equation(sp.core.relational.Equality):
    """Sympy equality whose ``+`` and ``-`` act on the right-hand side only.

    This allows rate equations sharing a left-hand side to be accumulated term by term: ``eq + other`` keeps
    ``eq.lhs`` and adds the summand derived from ``other`` (see ``get_summand``) to ``eq.rhs``.
    """

    def get_summand(self, other):
        """Value-check ``other`` and return the quantity to add to / subtract from the RHS.

        - a sympy BooleanAtom contributes 0;
        - any other non-Equality operand is used as-is;
        - an Equality contributes its RHS, provided its LHS matches ours.

        Raises
        ------
        ValueError
            If ``other`` is an Equality with a different LHS.
        """
        if isinstance(other, sp.logic.boolalg.BooleanAtom):
            return 0
        if not isinstance(other, sp.core.relational.Equality):
            return other
        if other.lhs == self.lhs:
            return other.rhs
        raise ValueError(
            "Tried to sum incompatible equations. Equation summation only defined for differential equations with the same LHS."
        )

    def __add__(self, other):
        return Equation(self.lhs, self.rhs + self.get_summand(other))

    def __sub__(self, other):
        return Equation(self.lhs, self.rhs - self.get_summand(other))

    def __radd__(self, other):
        # addition on the RHS is commutative, so other + self == self + other
        return self + other

    def __iadd__(self, other):
        # sympy objects are immutable: in-place addition just produces a new Equation
        return self + other

    def __isub__(self, other):
        return self - other
|
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
"""Implementation of EquationSystem for representing, manipulating, and constructing systems of conservation laws"""
|
|
2
|
+
|
|
3
|
+
import sympy as sp
|
|
4
|
+
from .symbols import d_dt, n_, x_, t, BDF, n_Htot, internal_energy
|
|
5
|
+
from .data import SolarAbundances
|
|
6
|
+
from jax import numpy as jnp
|
|
7
|
+
import numpy as np
|
|
8
|
+
from .numerics import newton_rootsolve
|
|
9
|
+
from astropy import units
|
|
10
|
+
from .equation import Equation
|
|
11
|
+
from sympy.codegen.ast import Assignment
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class EquationSystem(dict):
    """Dict of symbolic equations with certain superpowers for manipulating sets of conservation equations.

    Keys name the conserved quantity (a species such as ``"H+"``, or special entries such as ``"heat"`` and
    ``"u"``); values are ``Equation`` objects whose RHS collects the source and sink terms.
    """

    def copy(self):
        """Return a shallow copy that is still an ``EquationSystem`` (``dict.copy`` would return a plain dict)."""
        return EquationSystem(self)

    def __getitem__(self, __key: str):
        """Dict getitem method where we initialize a differential equation for the conservation of a species if the key
        does not exist. Note that the new equation ``d n_key/dt = 0`` is inserted into the system."""
        if __key not in self:
            # technically should only be n_ if this is a species
            # need to make sure that d/dt's don't add up when composing equations
            self.__setitem__(__key, Equation(d_dt(n_(__key)), 0))
        return super().__getitem__(__key)

    def __add__(self, other):
        """Return a new system whose equations are the sums of the operands' equations.

        Keys present in only one operand are carried over. Neither operand is modified: indexing a missing key
        would otherwise insert an empty equation into it.
        """

        def as_equation(key, value):
            """Normalize a single operand's entry to an Equation (bare terms go on a fresh d n_key/dt equation)."""
            if isinstance(value, sp.core.relational.Equality):
                return Equation(value.lhs, value.rhs)
            return Equation(d_dt(n_(key)), 0) + value

        new = EquationSystem()
        for k in self.keys() | other.keys():
            if k in self and k in other:
                new[k] = self[k] + other[k]
            else:
                new[k] = as_equation(k, self[k] if k in self else other[k])
        return new

    @property
    def symbols(self):
        """Returns the set of all free symbols in the equations, excluding time ``t``."""
        free = set()
        for e in self.values():
            free.update(e.free_symbols)
        free.discard(t)  # leave time out
        return free

    @property
    def jacobian(self):
        """Returns a dict of dicts representing the Jacobian of the RHS of the system. Keys are the names of the
        conserved quantities and subkeys are the variable of differentiation.
        """
        return {k: {s: sp.diff(e.rhs, s) for s in self.symbols} for k, e in self.items()}

    def subs(self, expr, replacement):
        """Substitute ``replacement`` for ``expr`` in every equation of the system, in place."""
        for k, e in list(self.items()):
            self[k] = e.subs(expr, replacement)

    def reduced(self, knowns, time_dependent=()):
        """Return a reduced copy of the system, ready to be solved.

        Applies ``set_time_dependence`` and ``do_conservation_reductions`` to a copy, and drops the "heat" equation
        when T is among ``knowns`` (an iterable of names or symbols) and is not time-dependent.
        """
        subsystem = self.copy()
        subsystem.set_time_dependence(time_dependent)
        subsystem.do_conservation_reductions(time_dependent)
        if "T" in (str(k) for k in knowns) and "T" not in time_dependent:
            # temperature is given, so the energy equation is not needed
            subsystem.pop("heat", None)
        return subsystem

    def set_time_dependence(self, time_dependent_vars):
        """Insert backward-difference formulae or set to steady state.

        Every equation whose key is in ``time_dependent_vars`` gets ``BDF(key)`` as its LHS; all others get LHS 0.
        If "T" is time-dependent, the "heat" equation gets ``BDF("T")`` and a "u" equation relating ``u`` to
        ``internal_energy`` is added if missing.
        """
        for q in self:
            lhs = BDF(q) if q in time_dependent_vars else 0
            self[q] = Equation(lhs, self[q].rhs)
        if "T" in time_dependent_vars:  # special behaviour
            self["heat"] = Equation(BDF("T"), self["heat"].rhs)
            if "u" not in self:
                self["u"] = Equation(0, sp.Symbol("u") - internal_energy)

    def do_conservation_reductions(self, time_dependent_vars):
        """Eliminate equations from the system using known conservation laws.

        The substitutions made are recorded, in application order, in ``self.substitutions`` so that eliminated
        quantities can be reconstructed from a solution. Quantities in ``time_dependent_vars`` are not eliminated.
        """
        self.substitutions = []

        # since we have n_Htot let's convert all other n's to x's
        for s in self.symbols:
            name = str(s)
            if name.startswith("n_") and "Htot" not in name:
                # keep underscores inside the species name, e.g. n_H_2 -> H_2
                species = name.removeprefix("n_")
                self.substitutions.append((s, n_Htot * x_(species)))

        # charge neutrality
        if "e-" not in time_dependent_vars:
            self.substitutions.append((x_("e-"), x_("H+") + x_("He+") + 2 * x_("He++")))
            self.pop("e-", None)

        # hydrogen conservation: x_H+ = 1 - x_H
        # general: sum(n_(species containing H) / (number of H in species)) - n_("H_2") / 2 #
        if "H+" not in time_dependent_vars:
            self.substitutions.append((x_("H+"), 1 - x_("H")))
            self.pop("H+", None)

        # helium conservation: x_He++ = y - x_He - x_He+, where y is the He abundance per H
        if "He++" not in time_dependent_vars:
            y = sp.Symbol("y")
            self.substitutions.append((x_("He++"), y - x_("He") - x_("He+")))
            self.pop("He++", None)

        for expr, sub in self.substitutions:
            self.subs(expr, sub)

        # general: substitute highest ionization state with n_Htot * x_element - sum of lower ionization states

    @property
    def rhs(self):
        """Return a dict mapping each key to ``rhs - lhs`` of its equation (zero at the solution)."""
        return {k: e.rhs - e.lhs for k, e in self.items()}

    @property
    def rhs_scaled(self):
        """Returns the list of ``rhs - lhs`` residuals in key order.

        Intended to pull out the usual factors affecting collision rates; currently no scaling is applied.
        """
        return list(self.rhs.values())  # / (T**0.5 * n_Htot * n_Htot * 1e-12)

    def solve(
        self,
        knowns,
        guesses,
        time_dependent=(),
        dt=None,
        verbose=False,
        tol=1e-3,
        careful_steps=10,
        symbolic_keys=False,
    ):
        """
        Solves for equilibrium after substituting a set of known quantities, e.g. temperature, metallicity,
        etc.

        Parameters
        ----------
        knowns: dict
            Dict of symbolic quantities and their values that will be plugged into the network solve as known quantities.
            Can be arrays if you want to substitute multiple values. If T is included here, we solve for chemical
            equilibrium. If T is not included, solve for thermochemical equilibrium. The dict is not modified.
        guesses: dict
            Dict of symbolic quantities and their values that will be plugged into the network solve as guesses for the
            unknown quantities. Can be arrays if you want to substitute multiple values. Will default to trying sensible
            guesses for recognized quantities (NOT IMPLEMENTED YET)
        time_dependent: iterable of str, optional
            Quantities evolved with a backward-difference formula over ``dt`` instead of set to steady state.
        dt: quantity with time units, optional
            Timestep for the time-dependent quantities; converted with ``dt.to(units.s)``.
        verbose: bool, optional
            Print diagnostic information about the set-up of the solve.
        tol: float, optional
            Desired relative error in chemical abundances (default: 1e-3)
        careful_steps: int, optional
            Number of careful initial steps in the Newton solve before full step size is used - try increasing this if
            your solve has trouble converging.
        symbolic_keys: bool, optional
            If True, key the result by sympy symbols instead of strings.

        Returns
        -------
        soldict: dict
            Dict of solved quantities, plus quantities reconstructed from the conservation substitutions, with one
            value per input parameter set. Keys follow the format used in ``guesses`` ("H" or "x_H") unless
            ``symbolic_keys`` is True.

        Raises
        ------
        ValueError
            If inputs have inconsistent lengths, or the numbers of symbols, knowns and equations do not match.
        """

        def printv(*a, **k):
            """Print only if locally verbose=True"""
            if verbose:
                print(*a, **k)

        knowns = dict(knowns)  # work on a copy so the caller's dict is not modified

        # first: check knowns and guesses are all same size
        num_params = np.array(
            [len(np.array(v)) for v in guesses.values()] + [len(np.array(v)) for v in knowns.values()]
        )
        if not np.all(num_params == num_params[0]):
            raise ValueError("Input parameters and initial guesses must all have the same shape.")
        num_params = num_params[0]

        if dt is not None:
            knowns["Δt"] = np.repeat(dt.to(units.s), num_params)

        system = self
        if "u" in guesses or "T" in time_dependent:
            # add the internal-energy relation to a copy, so it does not leak into later solves on this system
            system = self.copy()
            system["u"] = Equation(0, internal_energy - sp.Symbol("u"))
        subsystem = system.reduced(knowns, time_dependent)
        symbols = subsystem.symbols
        num_equations = len(subsystem)

        # are there any symbols for which we can make a reasonable assumption or directly solve the steady-state approximation?
        prescriptions = {"y": SolarAbundances.x("He"), "Y": SolarAbundances.mass_fraction["He"], "Z": 1.0}
        assumed_values = {}
        if len(symbols) > num_equations + len(knowns):
            undetermined_symbols = symbols.difference(sp.Symbol(g) for g in guesses)
            printv(f"Undetermined symbols: {undetermined_symbols}")
            for s in undetermined_symbols:
                name = str(s)
                # if we have a prescription for this quantity, plug it in here - but never override a user-given value.
                # This should eventually be specified at the model level.
                if name in prescriptions and name not in knowns:
                    # case 1: we have given a value, which we should add to the list of knowns
                    assumed_values[name] = np.repeat(prescriptions[name], num_params)
                    printv(f"{s} not specified; assuming {s}={prescriptions[str(s)]}.")
                # case 2 (not implemented): we have given an expression in terms of the other available quantities

        # ok now we should have number of symbols unknowns + knowns
        given = knowns | assumed_values
        printv(
            f"Free symbols: {symbols}\nKnown values: {list(knowns)}\nAssumed values: {list(assumed_values)}\nEquations solved: {list(subsystem.rhs)}"
        )
        if len(symbols) != len(given) + len(subsystem):
            raise ValueError(
                f"Number of free symbols is {len(symbols)} != number of knowns {len(knowns)} + number of assumptions {len(assumed_values)} + number of equations {len(subsystem)}\n"
            )
        printv(
            f"It's solvin time. Solving for {set(guesses)} based on input {set(knowns)} and assumptions about {set(assumed_values)}"
        )

        # map each free symbol to its initial guess or its known value (inputs may be keyed "H" or "x_H")
        guessvals = {}
        paramvals = {}
        for s in symbols:
            name = str(s)
            for g in guesses:
                if g == name or f"x_{g}" == name:
                    guessvals[s] = guesses[g]
            for k in given:
                if k == name or f"x_{k}" == name:
                    paramvals[s] = given[k]

        lambda_args = [list(guessvals), list(paramvals)]
        func = sp.lambdify(lambda_args, subsystem.rhs_scaled, modules="jax", cse=True)

        tolerance_vars = [x_("H"), x_("He+") + x_("He"), 1 - x_("H")]
        if "T" in guesses:
            tolerance_vars.append(sp.Symbol("T"))
        if "u" in guesses:
            # converge on the internal energy and cooling rate
            tolerance_vars += [sp.Symbol("u"), subsystem["heat"].rhs]
        tolfunc = sp.lambdify(lambda_args, tolerance_vars, modules="jax", cse=True)

        def f_numerical(X, *params):
            """JAX function to rootfind"""
            return jnp.array(func(X, params))

        def tolerance_func(X, *params):
            """Solution will terminate if the relative change in this quantity is < tol"""
            return jnp.array(tolfunc(X, params))

        sol, _num_iter = newton_rootsolve(
            f_numerical,
            jnp.array(list(guessvals.values())).T,
            jnp.array(list(paramvals.values())).T,
            tolfunc=tolerance_func,
            rtol=tol,
            careful_steps=careful_steps,
            nonnegative=True,
            return_num_iter=True,
        )

        return self.package_solution(sol, guessvals, guesses, paramvals, subsystem, symbolic_keys)

    def package_solution(self, sol, guessvals, guesses, paramvals, subsystem, symbolic_keys):
        """Repack the raw Newton solution into a dict, reconstructing the quantities eliminated by substitution.

        Parameters
        ----------
        sol: array
            Solution with one row per parameter set and one column per key of ``guessvals`` (in that order).
        guessvals, paramvals: dict
            Symbol -> value maps for the solved-for and the known quantities.
        guesses: dict
            The user's guesses, used only to choose the key format of the result.
        subsystem: EquationSystem
            The reduced system; its ``substitutions`` are undone in reverse order.
        symbolic_keys: bool
            If True, return sympy symbols as keys.
        """
        soldict = {g: sol[:, i] for i, g in enumerate(guessvals)}

        # do a reverse-pass on the substitutions we made to get all quantities
        values_to_subs = soldict | paramvals
        for expr, sub in reversed(subsystem.substitutions):
            if expr in soldict or str(expr).startswith("n_"):
                continue
            args = list(sub.free_symbols)  # one fixed argument order for both the signature and the call
            soldict[expr] = sp.lambdify(args, sub)(*[values_to_subs[s] for s in args])
            values_to_subs |= soldict  # later substitutions may depend on quantities just reconstructed

        if symbolic_keys:
            return soldict

        soldict = {str(k): v for k, v in soldict.items()}
        # if we specified abundances with x_ notation, return same
        if any("x_" in g for g in guesses):
            return soldict
        # otherwise return with input format where keys are simple species strings, e.g. x_H -> H
        return {(k.replace("x_", "") if "x_" in k else k): v for k, v in soldict.items()}

    def solver_functions(self, solve_vars, time_dependent=(), return_jac=False, return_dict=False):
        """Returns the RHS of the system to solve and optionally its Jacobian, applying simplifications.

        Parameters
        ----------
        solve_vars: iterable of str
            Names of the quantities to solve for (species as "H" or "x_H", or "T"/"u"). "u" is added automatically,
            with its defining equation, when "u" is solved for or "T" is time-dependent.
        time_dependent: iterable of str, optional
            Quantities treated with a backward-difference formula instead of steady state.
        return_jac: bool
            Also return the Jacobian of the residuals with respect to the solved-for symbols.
        return_dict: bool
            Return dicts keyed by symbol instead of lists.

        Returns
        -------
        With ``return_dict``: ``rhs`` (symbol -> residual), plus ``jac`` (symbol -> symbol -> derivative) if
        ``return_jac``. Otherwise: the list of residuals, the nested-list Jacobian if ``return_jac``, and a dict mapping
        each solved-for symbol to its index. Symbols are ordered by name so the index convention is reproducible.
        """
        solve_vars = list(solve_vars)
        system = self
        if "u" in solve_vars or "T" in time_dependent:
            # add the internal-energy relation to a copy so that this system is not modified
            system = self.copy()
            system["u"] = Equation(0, internal_energy - sp.Symbol("u"))
            if "u" not in solve_vars:
                solve_vars.append("u")

        # compare by name: solve_vars holds strings while the system's symbols are sympy Symbols
        solve_names = {str(v) for v in solve_vars}
        knowns = {s for s in system.symbols if str(s) not in solve_names}
        subsystem = system.reduced(knowns, time_dependent)
        residuals = subsystem.rhs

        rhs = {}
        for s in sorted(subsystem.symbols, key=str):
            name = str(s)
            if name == "T" and "T" in solve_vars:
                rhs[s] = residuals["heat"]  # temperature is determined by the energy equation
                continue
            for g in solve_vars:
                if str(g) == name or f"x_{g}" == name:
                    rhs[s] = residuals[g]

        indices = {s: i for i, s in enumerate(rhs)}

        if return_jac:
            jac = {s: {s2: sp.diff(expr, s2) for s2 in rhs} for s, expr in rhs.items()}
            if return_dict:
                return rhs, jac
            return list(rhs.values()), [[jac[s1][s2] for s2 in rhs] for s1 in jac], indices

        if return_dict:
            return rhs
        return list(rhs.values()), indices

    def generate_code(self, solve_vars, time_dependent=(), language="Fortran", jac=True, cse=True, sanitize=True):
        """Generates numerical code that implements the system RHS and/or Jacobian in the specified language.

        Parameters
        ----------
        solve_vars, time_dependent:
            As for ``solver_functions``.
        language: str
            One of "Fortran", "C", "C++" or "Python" (case-insensitive).
        jac: bool
            Also emit the Jacobian, assigned to ``jac_result``.
        cse: bool
            Apply common-subexpression elimination, emitting the intermediate assignments first.
        sanitize: bool
            Replace a "+" glued to a preceding non-whitespace character with "plus", so that species names become
            valid identifiers (e.g. x_H+ -> x_Hplus). Exponent signs of float literals (1.0e+20, 1.0d+20) are kept.

        Returns
        -------
        str
            Generated code; a header comment, then the RHS assigned to ``rhs_result`` (and Jacobian to ``jac_result``).

        Raises
        ------
        ValueError
            If ``language`` is not supported.
        """
        import re

        # language -> (printer, line-comment prefix)
        printers = {
            "fortran": (lambda x: sp.fcode(x, standard=2008), "!"),
            "c": (lambda x: sp.ccode(x, standard="c99"), "//"),
            "python": (sp.pycode, "#"),
            "c++": (lambda x: sp.cxxcode(x, standard="c++11"), "//"),
        }
        if language.lower() not in printers:
            raise ValueError(f"Unsupported language {language!r}; expected one of {list(printers)}.")
        printer, comment = printers[language.lower()]

        if jac:
            func, jac_exprs, indices = self.solver_functions(solve_vars, time_dependent, return_jac=True)
            jac_exprs = sp.Matrix(jac_exprs)
        else:
            func, indices = self.solver_functions(solve_vars, time_dependent)
            jac_exprs = None
        func = sp.Matrix(func)

        header = f"{comment} Computes the RHS function "
        if jac:
            header += "and Jacobian "
        header += f"to solve for {list(indices.keys())}\n\n"
        header += f"{comment} INDEX CONVENTION: " + " ".join(f"({i}: {s})" for s, i in indices.items())
        codeblocks = [header]

        if cse:
            exprs = [func] if jac_exprs is None else [func, jac_exprs]
            replacements, reduced_exprs = sp.cse(exprs)
            codeblocks.append(" \n".join(printer(Assignment(*pair)) for pair in replacements))
            func = reduced_exprs[0]
            if jac_exprs is not None:
                jac_exprs = reduced_exprs[1]

        num_rows = len(func)
        rhs_result = sp.MatrixSymbol("rhs_result", num_rows, 1)
        codeblocks.append(printer(Assignment(rhs_result, func)))

        if jac_exprs is not None:
            jac_result = sp.MatrixSymbol("jac_result", num_rows, num_rows)
            codeblocks.append(printer(Assignment(jac_result, jac_exprs)))

        code = "\n\n".join(codeblocks)
        if sanitize:
            # skip "+" preceded by whitespace (binary plus) or by a float exponent marker such as "0e" / ".d"
            code = re.sub(r"(?<=\S)(?<!\d[eEdD])(?<!\.[eEdD])\+", "plus", code)
        return code
|
reaxion/localstate.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Specifies class for the local thermal and chemical state"""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class State:
    """Local thermal and chemical state.

    Meant to hold the internal energy, abundances, dust mass fraction, radiation field, cosmic ray ionization rate
    and ortho-para ratio; so far only a few class-level defaults and a placeholder ``density`` are defined.
    """

    f_ortho: float = 0.75  # ortho fraction (presumably of H2 — TODO confirm)
    hydrogen_massfrac: float = 0.7381
    metallicity: float = 0.0134

    def __init__(self):
        pass

    def density(self, species):
        """Return the density of ``species``.

        Placeholder implementation: always returns 0. This is a regular method (not a property) because it takes
        a species argument.
        """
        return 0
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# @property
|
|
21
|
+
# def
|