zoomy-core 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of zoomy-core might be problematic. Click here for more details.
- zoomy_core/__init__.py +7 -0
- zoomy_core/decorators/decorators.py +25 -0
- zoomy_core/fvm/flux.py +52 -0
- zoomy_core/fvm/nonconservative_flux.py +97 -0
- zoomy_core/fvm/ode.py +55 -0
- zoomy_core/fvm/solver_numpy.py +297 -0
- zoomy_core/fvm/timestepping.py +13 -0
- zoomy_core/mesh/mesh.py +1236 -0
- zoomy_core/mesh/mesh_extrude.py +168 -0
- zoomy_core/mesh/mesh_util.py +487 -0
- zoomy_core/misc/custom_types.py +6 -0
- zoomy_core/misc/interpolation.py +140 -0
- zoomy_core/misc/io.py +439 -0
- zoomy_core/misc/logger_config.py +18 -0
- zoomy_core/misc/misc.py +213 -0
- zoomy_core/model/analysis.py +147 -0
- zoomy_core/model/basefunction.py +113 -0
- zoomy_core/model/basemodel.py +512 -0
- zoomy_core/model/boundary_conditions.py +193 -0
- zoomy_core/model/initial_conditions.py +171 -0
- zoomy_core/model/model.py +63 -0
- zoomy_core/model/models/GN.py +70 -0
- zoomy_core/model/models/advection.py +53 -0
- zoomy_core/model/models/basisfunctions.py +181 -0
- zoomy_core/model/models/basismatrices.py +377 -0
- zoomy_core/model/models/core.py +564 -0
- zoomy_core/model/models/coupled_constrained.py +60 -0
- zoomy_core/model/models/poisson.py +41 -0
- zoomy_core/model/models/shallow_moments.py +757 -0
- zoomy_core/model/models/shallow_moments_sediment.py +378 -0
- zoomy_core/model/models/shallow_moments_topo.py +423 -0
- zoomy_core/model/models/shallow_moments_variants.py +1509 -0
- zoomy_core/model/models/shallow_water.py +266 -0
- zoomy_core/model/models/shallow_water_topo.py +111 -0
- zoomy_core/model/models/shear_shallow_flow.py +594 -0
- zoomy_core/model/models/sme_turbulent.py +613 -0
- zoomy_core/model/models/vam.py +455 -0
- zoomy_core/postprocessing/postprocessing.py +72 -0
- zoomy_core/preprocessing/openfoam_moments.py +452 -0
- zoomy_core/transformation/helpers.py +25 -0
- zoomy_core/transformation/to_amrex.py +238 -0
- zoomy_core/transformation/to_c.py +181 -0
- zoomy_core/transformation/to_jax.py +14 -0
- zoomy_core/transformation/to_numpy.py +115 -0
- zoomy_core/transformation/to_openfoam.py +254 -0
- zoomy_core/transformation/to_ufl.py +67 -0
- zoomy_core-0.1.11.dist-info/METADATA +225 -0
- zoomy_core-0.1.11.dist-info/RECORD +51 -0
- zoomy_core-0.1.11.dist-info/WHEEL +5 -0
- zoomy_core-0.1.11.dist-info/licenses/LICENSE +674 -0
- zoomy_core-0.1.11.dist-info/top_level.txt +1 -0
zoomy_core/misc/misc.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
# import scipy.interpolate as interp
|
|
5
|
+
# from functools import wraps
|
|
6
|
+
|
|
7
|
+
from attr import define
|
|
8
|
+
from typing import Callable, Optional, Any
|
|
9
|
+
from types import SimpleNamespace
|
|
10
|
+
|
|
11
|
+
from sympy import MatrixSymbol
|
|
12
|
+
from sympy import MutableDenseNDimArray as ZArray
|
|
13
|
+
|
|
14
|
+
from zoomy_core.misc.custom_types import FArray
|
|
15
|
+
from zoomy_core.misc.logger_config import logger
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# eq/repr are disabled explicitly: with attrs' auto_detect, a class body that
# does not define __eq__/__repr__ gets generated ones over its (zero) attrs
# fields, which made every two Zstructs compare equal and repr as "Zstruct()".
# Disabling them keeps SimpleNamespace's dict-based __eq__ and __repr__.
@define(slots=True, frozen=False, kw_only=True, eq=False, repr=False)
class Zstruct(SimpleNamespace):
    """
    Attribute namespace with dict-like and list-like access.

    Attributes live in the instance ``__dict__`` inherited from
    ``SimpleNamespace``; their insertion order is the order used by
    ``keys``, ``values``, ``get_list`` and positional ``__getitem__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __getitem__(self, key):
        """Return the value at position ``key`` (positional, not by name)."""
        return self.values()[key]

    def length(self):
        """Return the number of attributes."""
        return len(self.values())

    def get_list(self, recursive: bool = True):
        """
        Return the attribute values as a list.

        Args:
            recursive: if True, values that provide ``get_list`` (e.g. nested
                Zstructs) are converted to nested lists as well.
        """
        if not recursive:
            return self.values()
        return [
            item.get_list(recursive=True) if hasattr(item, 'get_list') else item
            for item in self.values()
        ]

    def as_dict(self, recursive: bool = True):
        """
        Return the attributes as a dictionary.

        Args:
            recursive: if True, build a new dict in which values providing
                ``as_dict`` (e.g. nested Zstructs) are converted recursively.
                If False, the live instance ``__dict__`` is returned, so
                mutating the result mutates this Zstruct.
        """
        if not recursive:
            return self.__dict__
        return {
            key: value.as_dict(recursive=True) if hasattr(value, 'as_dict') else value
            for key, value in self.items()
        }

    def items(self, resursive: bool = False, *, recursive: Optional[bool] = None):
        """
        Return ``(key, value)`` pairs, like ``dict.items``.

        Args:
            resursive: legacy (misspelled) name of ``recursive``; kept for
                backward compatibility.
            recursive: if given, takes precedence over ``resursive``.
        """
        if recursive is not None:
            resursive = recursive
        return self.as_dict(recursive=resursive).items()

    def keys(self):
        """Return the attribute names as a list (insertion order)."""
        return list(self.as_dict(recursive=False).keys())

    def values(self):
        """Return the attribute values as a list (insertion order)."""
        return list(self.as_dict(recursive=False).values())

    def contains(self, key):
        """
        Return True if attribute ``key`` exists and is not None.

        Note: an attribute explicitly set to None is reported as absent.
        """
        return self.as_dict(recursive=False).get(key) is not None

    def update(self, zstruct, recursive: bool = True):
        """
        Update the current Zstruct with another Zstruct or dictionary.

        Args:
            zstruct: Zstruct or dict holding the new values. A dict is
                converted with ``Zstruct.from_dict`` (nested dicts become
                Zstructs) and is not modified.
            recursive: if True, attributes that are Zstructs on both sides
                are merged recursively; otherwise values are overwritten.

        Raises:
            TypeError: if ``zstruct`` is neither a Zstruct nor a dict.
        """
        if isinstance(zstruct, dict):
            zstruct = Zstruct.from_dict(zstruct)
        if not isinstance(zstruct, Zstruct):
            raise TypeError("zstruct must be a Zstruct or a dictionary.")

        for key, value in zstruct.as_dict(recursive=False).items():
            # Absent attributes yield None here and are simply set.
            current_value = getattr(self, key, None)
            if recursive and isinstance(current_value, Zstruct) and isinstance(value, Zstruct):
                # If both are Zstructs, update recursively
                current_value.update(value, recursive=True)
            else:
                setattr(self, key, value)

    @classmethod
    def from_dict(cls, d):
        """
        Create a Zstruct recursively from a dictionary.

        Nested dictionaries are converted to (plain) ``Zstruct`` instances.
        The input dictionary is not modified.

        Args:
            d (dict): Dictionary.

        Returns:
            Zstruct: An instance of ``cls``.

        Raises:
            TypeError: if ``d`` is not a dictionary.
        """
        if not isinstance(d, dict):
            raise TypeError("Input must be a dictionary.")
        # Build a new mapping instead of writing back into the caller's dict.
        converted = {
            k: Zstruct.from_dict(v) if isinstance(v, dict) else v
            for k, v in d.items()
        }
        return cls(**converted)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# eq/repr disabled so attrs' auto_detect does not generate field-less
# __eq__/__repr__ for this subclass (which would make all instances equal).
@define(slots=True, frozen=False, kw_only=True, eq=False, repr=False)
class Settings(Zstruct):
    """
    Settings class for the application.

    A Zstruct that always carries an ``output`` Zstruct with the entries
    ``directory``, ``filename`` and ``clean_directory``. Missing entries are
    filled with the defaults ``'output'``, ``'simulation'`` and ``False``,
    and a warning is logged for each of them.

    Args:
        **kwargs: Arbitrary keyword arguments to set as attributes.

    Returns:
        A `Settings` instance.
    """

    def __init__(self, **kwargs):
        # Ensure an `output` Zstruct exists; anything else is replaced by the defaults.
        if 'output' not in kwargs or not isinstance(kwargs['output'], Zstruct):
            logger.warning("No 'output' Zstruct found in Settings. Default: Zstruct(directory='output', filename='simulation', clean_directory=False)")
            # clean_directory=False, as announced in the warning above and as
            # used by Settings.default(): never wipe a directory unless asked.
            kwargs['output'] = Zstruct(directory='output', filename='simulation', clean_directory=False)
        output = kwargs['output']

        # Collect defaults for *all* missing entries first, so several missing
        # entries do not overwrite each other's defaults.
        # Note: Zstruct.contains treats a value of None as missing.
        defaults = {}
        if not output.contains('directory'):
            logger.warning("No 'directory' attribute found in output Zstruct. Default: 'output'")
            defaults['directory'] = 'output'
        if not output.contains('filename'):
            logger.warning("No 'filename' attribute found in output Zstruct. Default: 'simulation'")
            defaults['filename'] = 'simulation'
        if not output.contains('clean_directory'):
            logger.warning("No 'clean_directory' attribute found in output Zstruct. Default: False")
            defaults['clean_directory'] = False

        if defaults:
            # Build a new Zstruct (the caller's object is not modified). Keys
            # present with value None are replaced rather than duplicated,
            # which previously raised a duplicate-keyword TypeError.
            present = {
                key: value
                for key, value in output.as_dict(recursive=False).items()
                if key not in defaults
            }
            kwargs['output'] = Zstruct(**defaults, **present)
        super().__init__(**kwargs)

    @classmethod
    def default(cls):
        """
        Returns a default Settings instance.
        """
        return cls(
            output=Zstruct(directory='output', filename='simulation', clean_directory=False)
        )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def compute_transverse_direction(normal):
    """
    Return a direction perpendicular to ``normal``.

    * 1D: no transverse direction exists; a zero vector is returned.
    * 2D: ``normal`` rotated by +90 degrees, ``(-n_y, n_x)``.
    * 3D: ``normal x e_x``, rescaled to the length of ``normal`` (a unit
      vector for a unit normal, as in 2D). If ``normal`` is parallel to
      ``e_x``, that product vanishes and ``normal x e_y`` is used instead.
      Only one of the two transverse directions of 3D space is returned.

    Args:
        normal: vector with 1, 2 or 3 components (array-like).

    Returns:
        ndarray with the same number of components as ``normal``.

    Raises:
        ValueError: if ``normal`` does not have 1, 2 or 3 components.
    """
    normal = np.asarray(normal)
    dim = normal.shape[0]
    if dim == 1:
        return np.zeros_like(normal)
    if dim == 2:
        transverse = np.zeros(2, dtype=float)
        transverse[0] = -normal[1]
        transverse[1] = normal[0]
        return transverse
    if dim == 3:
        normal_length = np.linalg.norm(normal)
        transverse = np.cross(normal, np.array([1.0, 0.0, 0.0]))
        length = np.linalg.norm(transverse)
        if length <= 1e-12 * normal_length:
            # normal is parallel to e_x, so n x e_x is (numerically) zero.
            transverse = np.cross(normal, np.array([0.0, 1.0, 0.0]))
            length = np.linalg.norm(transverse)
        if length > 0.0:
            # n x e_x has length |n| * sin(angle); rescale it to |n| so the
            # normal/transverse pair is orthonormal for a unit normal.
            transverse = transverse * (normal_length / length)
        # A zero normal stays a zero transverse vector (no division by zero).
        return transverse
    raise ValueError(f"normal must have 1, 2 or 3 components, got {dim}")
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def extract_momentum_fields_as_vectors(Q, momentum_fields, dim):
    """
    Gather the momentum components of ``Q`` into an ``(n_eqs, dim)`` float array.

    ``momentum_fields`` is read as ``dim`` consecutive groups of
    ``n_eqs = len(momentum_fields) // dim`` indices. Group ``d`` holds the
    index of the ``d``-th component of every momentum equation, so row ``eq``
    of the result is ``[Q[momentum_fields[d * n_eqs + eq]] for d in range(dim)]``.
    Trailing indices that do not fill a complete group are ignored.
    """
    n_eqs = len(momentum_fields) // dim
    rows = [
        [Q[momentum_fields[d * n_eqs + eq]] for d in range(dim)]
        for eq in range(n_eqs)
    ]
    # reshape keeps the (0, dim) shape when there are no momentum equations
    return np.array(rows, dtype=float).reshape(n_eqs, dim)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def projection_in_normal_and_transverse_direction(Q, momentum_fields, normal):
    """
    Project the momentum components of ``Q`` onto ``normal`` and its transverse direction.

    The momentum vectors are gathered with ``extract_momentum_fields_as_vectors``
    (one row per momentum equation). For row ``i`` the results are
    ``sum_d m[i, d] * normal[d]`` and ``sum_d m[i, d] * t[d]``, where ``t`` is
    ``compute_transverse_direction(normal)``.

    Returns:
        Tuple ``(Q_normal, Q_transverse)`` of 1D float arrays, one entry per
        momentum equation.
    """
    n_dims = normal.shape[0]
    tangent = compute_transverse_direction(normal)
    momenta = extract_momentum_fields_as_vectors(Q, momentum_fields, n_dims)
    n_eqs = momenta.shape[0]
    normal_part = np.zeros(n_eqs, dtype=float)
    transverse_part = np.zeros(n_eqs, dtype=float)
    for d, (n_d, t_d) in enumerate(zip(normal, tangent)):
        column = momenta[:, d]
        normal_part += column * n_d
        transverse_part += column * t_d
    return normal_part, transverse_part
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def projection_in_x_y_direction(Qn, Qt, normal):
    """
    Map normal/transverse momentum components back to Cartesian components.

    This is the counterpart of ``projection_in_normal_and_transverse_direction``.
    For momentum equation ``i`` and axis ``d``, the Cartesian component is
    ``Qn[i] * normal[d] + Qt[i] * t[d]``, with ``t`` from
    ``compute_transverse_direction(normal)``. The flat result is ordered
    axis-major (entry ``d * len(Qn) + i``), which is the layout read by
    ``extract_momentum_fields_as_vectors``.

    Returns:
        1D float array of length ``len(Qn) * len(normal)``.
    """
    axes = normal.shape[0]
    n_eqs = Qn.shape[0]
    tangent = compute_transverse_direction(normal)
    # Axis-major ordering: outer loop over axes, inner loop over equations.
    components = [
        Qn[i] * normal[d] + Qt[i] * tangent[d]
        for d in range(axes)
        for i in range(n_eqs)
    ]
    return np.array(components, dtype=float)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def project_in_x_y_and_recreate_Q(Qn, Qt, Qorig, momentum_eqns, normal):
    """
    Return a copy of ``Qorig`` with its momentum entries rebuilt from
    normal/transverse components.

    The entries at ``momentum_eqns`` are overwritten with
    ``projection_in_x_y_direction(Qn, Qt, normal)``. ``Qorig`` itself is not
    modified. A bool/integer-typed ``Qorig`` is promoted to float so the
    projected momenta are not truncated.
    """
    Qnew = np.array(Qorig)  # copy, so the caller's state is untouched
    if Qnew.dtype.kind in "biu":
        # bool/(un)signed integer storage would silently truncate the
        # floating-point momenta on assignment; promote to float instead.
        Qnew = Qnew.astype(float)
    Qnew[momentum_eqns] = projection_in_x_y_direction(Qn, Qt, normal)
    return Qnew
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from sympy import Matrix, diff, exp, I, linear_eq_to_matrix, solve , Eq, zeros, simplify, nsimplify, latex, symbols, Function, together, Symbol
|
|
2
|
+
from IPython.display import display, Latex
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ModelAnalyser():
    """
    Symbolic helper for the linear (plane-wave) analysis of a model's equations.

    The working system is held in ``self.equations``. Most methods replace it
    with a new list rather than mutating it in place.

    NOTE(review): the intended call order appears to be ``linearize_system``
    -> ``insert_plane_wave_ansatz`` -> (``remove_exponential`` /
    ``solve_for_constraints``) -> ``solve_for_dispersion_relation``. This is
    inferred from the assertion messages and method names; confirm against
    callers.
    """

    def __init__(self, model):
        """
        Args:
            model: model object. This class reads ``time``, ``position``
                (unpacked into x, y, z), ``dimension``, ``variables``,
                ``aux_variables``, ``quasilinear_matrix()`` and ``residual()``
                from it.
        """
        self.model = model
        self.t = model.time
        x, y, z = model.position
        self.x = x
        self.y = y
        self.z = z
        # Current working list of sympy equations; set by linearize_system.
        self.equations = None
        # Amplitude symbols \bar{f} introduced by insert_plane_wave_ansatz.
        self.plane_wave_symbols = []


    def get_equations(self):
        """Return the current working equations (None before linearize_system)."""
        return self.equations

    def print_equations(self):
        """Render the current equations as a LaTeX ``align*`` block via IPython display."""
        latex_lines = " \\\\\n".join([f"& {latex(eq)}" for eq in self.equations])
        latex_block = r"$$\begin{align*}" + "\n" + latex_lines + r"\end{align*}$$"
        display(Latex(latex_block))

    def get_time_space(self):
        """Return the model's ``(t, x, y, z)`` symbols (re-read from the model on each call)."""
        x, y, z = self.model.position
        t = self.model.time
        return t, x, y, z

    def _get_omega_k(self):
        """Return the symbols ``(omega, k_x, k_y, k_z)`` (frequency and wave numbers)."""
        # Symbols are created by name on every call; sympy Symbols with equal
        # names and assumptions compare equal, so repeated calls are consistent.
        omega, kx, ky, kz = symbols('omega k_x k_y k_z')
        return omega, kx, ky, kz

    def _get_exponential(self):
        """Return the plane-wave factor ``exp(i (k_x x + k_y y + k_z z - omega t))``."""
        omega, kx, ky, kz = self._get_omega_k()
        t, x, y, z = self.get_time_space()
        exponential = exp(I * (kx * x + ky * y + kz * z - omega * t))
        return exponential

    def get_eps(self):
        """Return the symbol ``eps`` in which ``linearize_system`` expands to first order."""
        eps = symbols('eps')
        return eps

    def create_functions_from_list(self, names):
        """Return undefined sympy functions ``name(t, x, y, z)`` for each name in ``names``."""
        t, x, y, z = self.get_time_space()
        return [Function(name)(t, x, y, z) for name in names]

    def delete_equations(self, indices):
        """Drop the equations at the given positions from ``self.equations``."""
        self.equations = [self.equations[i] for i in range(len(self.equations)) if i not in indices]


    def solve_for_constraints(self, list_of_selected_equations, list_of_variables):
        """
        Solve selected equations for variables and eliminate them from the system.

        The solution is substituted into all equations (via ``xreplace``), and
        the selected equations are then removed.

        Args:
            list_of_selected_equations: positions in ``self.equations`` to solve.
            list_of_variables: unknowns passed to ``sympy.solve``.

        Returns:
            The result of ``sympy.solve``. NOTE(review): ``xreplace`` needs a
            mapping, so this assumes ``solve`` returns a dict (a single
            solution); confirm for nonlinear cases.
        """
        equations = self.equations
        sol = solve([equations[i] for i in list_of_selected_equations], list_of_variables)
        equations = [eq.xreplace(sol).doit() for eq in equations]
        # delete used equations from equation system
        # (positions are unchanged by the substitution above)
        equations = [equations[i] for i in range(len(equations)) if i not in list_of_selected_equations]
        self.equations = equations
        return sol

    def insert_plane_wave_ansatz(self, functions_to_replace):
        """
        Replace each function ``f`` by ``\\bar{f} * exp(i(k.x - omega t))``.

        The amplitude symbols ``\\bar{f}`` are appended to
        ``self.plane_wave_symbols``. ``doit()`` then evaluates the derivatives
        of the substituted equations.

        Args:
            functions_to_replace: applied sympy functions, e.g. as returned by
                ``create_functions_from_list``.
        """
        exponential = self._get_exponential()
        f_bar_dict = {}
        for f in functions_to_replace:
            # Create the base name (e.g., 'f0')
            f_name = str(f.func) # Get the function name (e.g., 'f0')

            # Create a new symbol representing \bar{f0}
            f_bar = Symbol(r'\bar{' + f_name + '}')
            f_bar_dict[f] = f_bar * exponential
            self.plane_wave_symbols.append(f_bar)
        self.equations = [eq.xreplace(f_bar_dict).doit() for eq in self.equations ]

    def solve_for_dispersion_relation(self):
        """
        Solve ``det(A) = 0`` for ``omega``.

        ``A`` is the coefficient matrix of the current equations with respect
        to the plane-wave amplitude symbols (``linear_eq_to_matrix``).

        Returns:
            The result of ``sympy.solve(A.det(), omega)``.
        """
        assert self.equations is not None, "No equations available to solve for dispersion relation."
        assert self.plane_wave_symbols, "No plane wave symbols available to solve for dispersion relation. Use insert_plane_wave_ansatz first."
        # rhs is not used: only the coefficient matrix enters the determinant.
        A, rhs = linear_eq_to_matrix(self.equations, self.plane_wave_symbols)
        omega, kx, ky, kz = self._get_omega_k()
        sol = solve(A.det(), omega)
        return sol

    def remove_exponential(self):
        """Divide both sides of every equation by the plane-wave factor and simplify (expects ``Eq`` objects)."""
        exponential = self._get_exponential()
        equations = self.equations
        equations = [simplify(Eq(eq.lhs / exponential, eq.rhs / exponential)) for eq in equations]
        self.equations = equations

    def linearize_system(self, q, qaux, constraints=None):
        """
        Build the first-order (in ``eps``) system of the model equations.

        The steps are:
          1. Substitute the model variables/aux variables by ``q``/``qaux``.
          2. Form ``dq/dt + sum_d A[d] * dq/dx_d - S``, stacked with the
             optional ``constraints``.
          3. Convert floats to rationals.
          4. Keep only the coefficient of ``eps**1`` in each row's series
             expansion in ``eps``.
          5. Multiply each row by its denominator and simplify.

        The result is stored in ``self.equations`` as a list of ``Eq(row, 0)``.

        Args:
            q: replacement for ``model.variables``, matched by position.
                NOTE(review): presumably a sympy Matrix of perturbed fields
                such as ``q0 + eps * q1``; ``diff(q, t)`` is applied to it
                directly, so it must be differentiable as a whole.
            qaux: replacement for ``model.aux_variables``, matched by position.
            constraints: optional column of extra expressions appended below
                the model equations; must support ``xreplace`` and ``doit``.
        """
        model = self.model
        t, x, y, z = self.get_time_space()
        dim = model.dimension
        # Only the first `dim` coordinates are used for spatial derivatives.
        X = [x, y, z]

        Q = Matrix(model.variables.get_list())
        Qaux = Matrix(model.aux_variables.get_list())


        # Map model symbols -> supplied (perturbed) expressions, by position.
        substitutions = {Q[i]: q[i] for i in range(len(q))}
        substitutions.update({Qaux[i]: qaux[i] for i in range(len(qaux))})


        # A is indexed per spatial direction below (A[d]).
        A = model.quasilinear_matrix()
        S = model.residual()
        if constraints is not None:
            C = constraints
        else:
            C = zeros(0, 1)


        # NOTE(review): the substituted Q is not used further in this method.
        Q = Q.xreplace(substitutions)
        for d in range(dim):
            A[d] = A[d].xreplace(substitutions)
        S = S.xreplace(substitutions)
        C = C.xreplace(substitutions)

        C = C.doit()


        # gradQ[i, j] = d q[i] / d X[j]
        gradQ = Matrix([diff(q[i], X[j]) for i in range(len(q)) for j in range(dim)]).reshape(len(q), dim)

        AgradQ = A[0] * gradQ[:, 0]
        for d in range(1, dim):
            AgradQ += A[d] * gradQ[:, d]


        expr = list(Matrix.vstack((diff(q, t) + AgradQ - S) , C))
        # Rationalize floats so the series expansion and simplification are exact.
        for i in range(len(expr)):
            expr[i] = nsimplify(expr[i], rational=True)
        expr = Matrix(expr)
        eps = self.get_eps()
        res = expr.copy()
        for i, e in enumerate(expr):
            collected = e
            collected = collected.series(eps, 0, 2).removeO()
            # Keep only the first-order (linear) part in eps.
            order_1_term = collected.coeff(eps, 1)
            res[i] = order_1_term

        # Clear denominators row by row (valid because each row is set to 0).
        for r in range(res.shape[0]):
            denom = together(res[r]).as_numer_denom()[1]
            res[r] *= denom
            res[r] = simplify(res[r])

        linearized_system = [Eq((res[i]),0) for i in range(res.shape[0])]


        self.equations = linearized_system
|
|
147
|
+
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
|
|
2
|
+
import sympy
|
|
3
|
+
import sympy as sp
|
|
4
|
+
from attrs import define, field
|
|
5
|
+
from sympy import lambdify
|
|
6
|
+
|
|
7
|
+
from zoomy_core.misc.misc import Zstruct
|
|
8
|
+
|
|
9
|
+
def listify(expr):
    """
    Prepare ``expr`` for element-wise processing.

    A ``sympy.Piecewise`` (exact type, not subclasses) is passed through
    unchanged. Anything else is materialized into a Python list with
    ``list(expr)``.
    """
    return expr if type(expr) is sp.Piecewise else list(expr)
|
|
14
|
+
|
|
15
|
+
def vectorize_constant_sympy_expressions(expr, Q, Qaux):
    """
    Replace entries in `expr` that are constant w.r.t. Q and Qaux
    by entry * ones_like(Q[0]) so NumPy/JAX vectorization works.

    Zero entries become ``zeros_like(Q[0])``. Entries that depend on any
    symbol of Q or Qaux are left unchanged. ``ones_like`` / ``zeros_like``
    are undefined sympy Functions used as placeholders.
    NOTE(review): presumably mapped to backend array functions at
    lambdify time; confirm.

    Handles scalars (sympy objects and plain Python numbers), (nested) lists,
    sympy matrices, sympy N-dimensional arrays of any rank (mutable or
    immutable) and sympy.Piecewise.

    Args:
        expr: expression(s) to process.
        Q: Zstruct of state symbols; ``Q[0]`` is the placeholders' argument.
        Qaux: Zstruct of auxiliary symbols.

    Returns:
        A sympy Array when the processed result is a list that sympy can
        shape into one. Otherwise the processed object (list, Matrix,
        Array, Piecewise or scalar).
    """
    import numbers  # local: only used to recognise plain Python numbers

    symbol_list = set(Q.get_list() + Qaux.get_list())
    q0 = Q[0]
    ones_like = sp.Function("ones_like")  # symbolic placeholder
    zeros_like = sp.Function("zeros_like")  # symbolic placeholder

    # convert matrices to nested lists (Array handles lists better)
    # (MatrixBase covers both mutable and immutable dense matrices)
    if isinstance(expr, sp.MatrixBase):
        expr = expr.tolist()

    def vectorize_entry(entry):
        """Return entry multiplied by ones_like(q0) if it is constant."""
        # numeric zero
        if entry == 0:
            return zeros_like(q0)

        # numeric constant: plain Python numbers (which have no `is_number`
        # attribute; bool deliberately excluded) and sympy numbers
        # (Integer, Rational, Float, pi, ...)
        is_python_number = isinstance(entry, numbers.Number) and not isinstance(entry, bool)
        if is_python_number or getattr(entry, "is_number", False):
            return entry * ones_like(q0)

        # symbolic constant independent of Q and Qaux
        if hasattr(entry, "free_symbols") and entry.free_symbols.isdisjoint(symbol_list):
            return entry * ones_like(q0)

        # otherwise, depends on variables
        return entry

    def recurse(e):
        """Recursively handle Array, Matrix, Piecewise, list, or scalar."""
        # Handle lists (possibly nested)
        if isinstance(e, list):
            return [recurse(sub) for sub in e]

        # Handle Matrices
        if isinstance(e, sp.MatrixBase):
            return sp.Matrix([[recurse(sub) for sub in row] for row in e.tolist()])

        # Handle N-dimensional arrays of any rank, including mutable ones
        # (e.g. MutableDenseNDimArray); applyfunc maps element-wise and
        # keeps the shape.
        if isinstance(e, sp.NDimArray):
            return sp.Array(e.applyfunc(recurse))

        # Handle Piecewise
        if isinstance(e, sp.Piecewise):
            # Recursively process all (expr, cond) pairs; conditions are kept.
            new_args = []
            for expr_i, cond_i in e.args:
                new_args.append((recurse(expr_i), cond_i))
            return sp.Piecewise(*new_args)

        # Scalar or atomic expression
        return vectorize_entry(e)

    # Recurse and then normalize into an N-dimensional array (if possible)
    result = recurse(expr)

    # Convert to Array if possible (this ensures uniform output type)
    if isinstance(result, list):
        try:
            return sp.Array(result)
        except Exception:
            # deliberate best-effort: fall back if shapes are inconsistent
            return result
    return result
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@define(frozen=True, slots=True, kw_only=True)
class Function:
    """
    Generic (virtual) function implementation.

    Bundles a symbolic ``definition`` with the argument structure ``args``
    it is expressed in. :meth:`lambdify` turns it into a numerical callable.

    Attributes:
        name: Name of the function.
        args: Zstruct of argument symbols. ``lambdify`` passes
            ``args.get_list()`` as the argument list and reads
            ``args.variables`` and ``args.aux_variables``.
        definition: Symbolic definition (default: a 1x1 zero matrix).
    """

    name: str = field(default="Function")
    # Factories give each instance its own Zstruct / Matrix. A single default
    # object would be shared (and is mutable) across all instances.
    args: Zstruct = field(factory=Zstruct)
    definition = field(factory=lambda: sympy.zeros(1, 1))

    def __call__(self):
        """Allow calling the instance to get its symbolic definition."""
        return self.definition

    def lambdify(self, modules=None):
        """
        Return a lambdified version of the function.

        Constant entries of the definition are first wrapped by
        ``vectorize_constant_sympy_expressions`` to support NumPy/JAX
        vectorization.

        Args:
            modules: forwarded to ``sympy.lambdify``.
        """
        # `lambdify` below resolves to the module-level sympy.lambdify:
        # class-body names are not visible inside method bodies.
        func = lambdify(
            self.args.get_list(),
            vectorize_constant_sympy_expressions(
                listify(self.definition), self.args.variables, self.args.aux_variables
            ),
            modules=modules,
        )
        return func
|