zoomy-core 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of zoomy-core might be problematic. Click here for more details.

Files changed (57)
  1. zoomy_core/decorators/decorators.py +25 -0
  2. zoomy_core/fvm/flux.py +97 -0
  3. zoomy_core/fvm/nonconservative_flux.py +97 -0
  4. zoomy_core/fvm/ode.py +55 -0
  5. zoomy_core/fvm/solver_numpy.py +305 -0
  6. zoomy_core/fvm/timestepping.py +13 -0
  7. zoomy_core/mesh/gmsh_loader.py +301 -0
  8. zoomy_core/mesh/mesh.py +1192 -0
  9. zoomy_core/mesh/mesh_extrude.py +168 -0
  10. zoomy_core/mesh/mesh_util.py +487 -0
  11. zoomy_core/misc/custom_types.py +6 -0
  12. zoomy_core/misc/gui.py +61 -0
  13. zoomy_core/misc/interpolation.py +140 -0
  14. zoomy_core/misc/io.py +401 -0
  15. zoomy_core/misc/logger_config.py +18 -0
  16. zoomy_core/misc/misc.py +216 -0
  17. zoomy_core/misc/static_class.py +94 -0
  18. zoomy_core/model/analysis.py +147 -0
  19. zoomy_core/model/basefunction.py +113 -0
  20. zoomy_core/model/basemodel.py +512 -0
  21. zoomy_core/model/boundary_conditions.py +193 -0
  22. zoomy_core/model/initial_conditions.py +171 -0
  23. zoomy_core/model/model.py +63 -0
  24. zoomy_core/model/models/GN.py +70 -0
  25. zoomy_core/model/models/advection.py +53 -0
  26. zoomy_core/model/models/basisfunctions.py +181 -0
  27. zoomy_core/model/models/basismatrices.py +377 -0
  28. zoomy_core/model/models/core.py +564 -0
  29. zoomy_core/model/models/coupled_constrained.py +60 -0
  30. zoomy_core/model/models/old_smm copy.py +867 -0
  31. zoomy_core/model/models/poisson.py +41 -0
  32. zoomy_core/model/models/shallow_moments.py +757 -0
  33. zoomy_core/model/models/shallow_moments_sediment.py +378 -0
  34. zoomy_core/model/models/shallow_moments_topo.py +423 -0
  35. zoomy_core/model/models/shallow_moments_variants.py +1509 -0
  36. zoomy_core/model/models/shallow_water.py +266 -0
  37. zoomy_core/model/models/shallow_water_topo.py +111 -0
  38. zoomy_core/model/models/shear_shallow_flow.py +594 -0
  39. zoomy_core/model/models/sme_turbulent.py +613 -0
  40. zoomy_core/model/models/swe_old.py +1018 -0
  41. zoomy_core/model/models/vam.py +455 -0
  42. zoomy_core/postprocessing/postprocessing.py +72 -0
  43. zoomy_core/preprocessing/openfoam_moments.py +452 -0
  44. zoomy_core/transformation/helpers.py +25 -0
  45. zoomy_core/transformation/to_amrex.py +238 -0
  46. zoomy_core/transformation/to_c.py +181 -0
  47. zoomy_core/transformation/to_jax.py +14 -0
  48. zoomy_core/transformation/to_numpy.py +115 -0
  49. zoomy_core/transformation/to_openfoam.py +254 -0
  50. zoomy_core/transformation/to_ufl.py +67 -0
  51. {zoomy_core-0.1.1.dist-info → zoomy_core-0.1.2.dist-info}/METADATA +1 -1
  52. zoomy_core-0.1.2.dist-info/RECORD +55 -0
  53. zoomy_core-0.1.2.dist-info/top_level.txt +1 -0
  54. zoomy_core-0.1.1.dist-info/RECORD +0 -5
  55. zoomy_core-0.1.1.dist-info/top_level.txt +0 -1
  56. {zoomy_core-0.1.1.dist-info → zoomy_core-0.1.2.dist-info}/WHEEL +0 -0
  57. {zoomy_core-0.1.1.dist-info → zoomy_core-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,216 @@
1
+ import os
2
+ import numpy as np
3
+
4
+ # import scipy.interpolate as interp
5
+ # from functools import wraps
6
+
7
+ from attr import define
8
+ from typing import Callable, Optional, Any
9
+ from types import SimpleNamespace
10
+
11
+ from sympy import MatrixSymbol
12
+ from sympy import MutableDenseNDimArray as ZArray
13
+
14
+ from library.zoomy_core.misc.custom_types import FArray
15
+ from library.zoomy_core.misc.static_class import register_static_pytree
16
+ from library.zoomy_core.misc.logger_config import logger
17
+
18
+
19
+
20
+
21
+
22
+
23
@register_static_pytree
@define(slots=True, frozen=False, kw_only=True)
class Zstruct(SimpleNamespace):
    """
    Attribute container (a ``SimpleNamespace``) with dict- and list-like helpers.

    Attributes keep their insertion order; ``struct[i]`` and ``get_list()``
    are positional views over the attribute values.

    NOTE(review): ``@define`` declares no attrs fields here, so the
    attrs-generated ``__eq__``/``__repr__`` presumably ignore the stored
    attributes — confirm before relying on ``==`` between instances.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __getitem__(self, key):
        """Return the value at position ``key`` (insertion order), not by name."""
        return self.values()[key]

    def length(self):
        """Return the number of attributes."""
        return len(self.values())

    def get_list(self, recursive: bool = True):
        """
        Return the attribute values as a list.

        With ``recursive=True`` values that provide ``get_list`` (e.g. nested
        Zstructs) are expanded into nested lists.
        """
        if not recursive:
            return self.values()
        return [
            item.get_list(recursive=True) if hasattr(item, 'get_list') else item
            for item in self.values()
        ]

    def as_dict(self, recursive: bool = True):
        """
        Return the attributes as a dict.

        With ``recursive=True`` a new dict is built and values that provide
        ``as_dict`` are converted too. With ``recursive=False`` the live
        ``__dict__`` is returned, so mutating it mutates this struct.
        """
        if not recursive:
            return self.__dict__
        return {
            key: value.as_dict(recursive=True) if hasattr(value, 'as_dict') else value
            for key, value in self.items()
        }

    def items(self, resursive: bool = False, *, recursive: Optional[bool] = None):
        """
        Return ``as_dict(...).items()``.

        ``resursive`` is the historical (misspelled) parameter name and is kept
        for backward compatibility; ``recursive`` takes precedence when given.
        """
        if recursive is None:
            recursive = resursive
        return self.as_dict(recursive=recursive).items()

    def keys(self):
        """Return the attribute names as a list."""
        return list(self.as_dict(recursive=False).keys())

    def values(self):
        """Return the attribute values as a list (insertion order)."""
        return list(self.as_dict(recursive=False).values())

    def contains(self, key):
        """Return True if ``key`` is set to a value other than ``None``."""
        return self.as_dict(recursive=False).get(key) is not None

    def update(self, zstruct, recursive: bool = True):
        """
        Update the current Zstruct with another Zstruct or dictionary.

        Args:
            zstruct: a ``Zstruct`` or a ``dict`` (converted with ``from_dict``).
            recursive: if True, nested Zstructs present on both sides are merged
                recursively; otherwise values simply overwrite.

        Raises:
            TypeError: if ``zstruct`` is neither a Zstruct nor a dict.
        """
        if isinstance(zstruct, dict):
            zstruct = Zstruct.from_dict(zstruct)
        elif not isinstance(zstruct, Zstruct):
            raise TypeError("zstruct must be a Zstruct or a dictionary.")

        own = self.as_dict(recursive=False)
        for key, value in zstruct.as_dict(recursive=False).items():
            current_value = own.get(key)
            if recursive and isinstance(current_value, Zstruct) and isinstance(value, Zstruct):
                # Both sides are Zstructs: merge instead of replacing.
                current_value.update(value, recursive=True)
            else:
                setattr(self, key, value)

    @classmethod
    def from_dict(cls, d):
        """
        Create a Zstruct recursively from a dictionary.

        Nested dictionaries become nested ``Zstruct`` instances. The input
        dictionary is not modified.

        Args:
            d (dict): Dictionary.

        Returns:
            Zstruct: An instance of ``cls``.

        Raises:
            TypeError: if ``d`` is not a dict.
        """
        if not isinstance(d, dict):
            raise TypeError("Input must be a dictionary.")

        # Build a new mapping rather than writing the converted values back into d.
        converted = {
            k: Zstruct.from_dict(v) if isinstance(v, dict) else v
            for k, v in d.items()
        }
        return cls(**converted)
119
+
120
+
121
@register_static_pytree
@define(slots=True, frozen=False, kw_only=True)
class Settings(Zstruct):
    """
    Settings class for the application.

    A ``Zstruct`` that always carries an ``output`` sub-struct with the
    entries ``directory``, ``filename`` and ``clean_directory``. Entries that
    are missing (absent or ``None``) are filled with the defaults
    ``'output'``, ``'simulation'`` and ``False``, and a warning is logged for
    each one.

    Args:
        **kwargs: Arbitrary keyword arguments to set as attributes.

    Returns:
        A ``Settings`` instance.
    """

    def __init__(self, **kwargs):
        output = kwargs.get('output')
        if not isinstance(output, Zstruct):
            logger.warning("No 'output' Zstruct found in Settings. Default: Zstruct(directory='output', filename='simulation', clean_directory=False)")
            # Use exactly the default the warning announces (and that
            # Settings.default() uses): never clean the directory implicitly.
            output = Zstruct(directory='output', filename='simulation', clean_directory=False)
            kwargs['output'] = output

        # (key, default, warning) for every required output entry, in check order.
        required = (
            ('directory', 'output', "No 'directory' attribute found in output Zstruct. Default: 'output'"),
            ('filename', 'simulation', "No 'filename' attribute found in output Zstruct. Default: 'simulation'"),
            ('clean_directory', False, "No 'clean_directory' attribute found in output Zstruct. Default: False"),
        )
        missing = {}
        for key, default, message in required:
            if not output.contains(key):
                logger.warning(message)
                missing[key] = default

        if missing:
            # Merge all defaults at once so none is lost. Keys present with a
            # None value are replaced, not passed twice (which would raise a
            # duplicate-keyword TypeError). Defaults go first, as before.
            merged = dict(missing)
            merged.update(
                (key, value)
                for key, value in output.as_dict(recursive=False).items()
                if key not in missing
            )
            kwargs['output'] = Zstruct(**merged)
        super().__init__(**kwargs)

    @classmethod
    def default(cls):
        """
        Returns a default Settings instance.
        """
        return cls(
            output=Zstruct(directory='output', filename='simulation', clean_directory=False)
        )
159
+
160
+
161
def compute_transverse_direction(normal):
    """
    Return a direction orthogonal to ``normal``.

    * 1D: there is no transverse direction, so a zero vector is returned.
    * 2D: ``normal`` rotated by +90 degrees, i.e. ``(-n_y, n_x)``.
    * 3D: ``normal x e_x``, or ``normal x e_y`` when ``normal`` is parallel to
      ``e_x`` (crossing with a parallel axis would give the zero vector).
      The result is rescaled to the length of ``normal``, matching the 2D
      case: a unit normal gives a unit transverse direction. Only one of the
      two tangential directions is returned.

    Args:
        normal: 1D array with 1, 2 or 3 components.

    Returns:
        numpy array with as many components as ``normal``.

    Raises:
        ValueError: if ``normal`` does not have 1, 2 or 3 components.
    """
    dim = normal.shape[0]
    if dim == 1:
        return np.zeros_like(normal)
    if dim == 2:
        transverse = np.zeros((2), dtype=float)
        transverse[0] = -normal[1]
        transverse[1] = normal[0]
        return transverse
    if dim == 3:
        normal_length = np.linalg.norm(normal)
        if normal_length == 0.0:
            return np.zeros(3, dtype=float)
        transverse = np.cross(normal, np.array([1.0, 0.0, 0.0]))
        if np.linalg.norm(transverse) <= 1e-12 * normal_length:
            # normal is (numerically) parallel to e_x: use e_y instead.
            transverse = np.cross(normal, np.array([0.0, 1.0, 0.0]))
        return transverse * (normal_length / np.linalg.norm(transverse))
    raise ValueError(
        f"compute_transverse_direction supports 1, 2 or 3 dimensions, got {dim}."
    )
176
+
177
+
178
def extract_momentum_fields_as_vectors(Q, momentum_fields, dim):
    """
    Gather the momentum entries of ``Q`` into a float array of shape
    ``(len(momentum_fields) // dim, dim)``.

    ``momentum_fields`` is read direction-major: index ``d * n + i`` (with
    ``n = len(momentum_fields) // dim``) holds the ``d``-th Cartesian
    component of momentum equation ``i``. Indices beyond ``n * dim`` are
    ignored.
    """
    n_eqns = len(momentum_fields) // dim
    gathered = np.array(
        [Q[momentum_fields[k]] for k in range(n_eqns * dim)], dtype=float
    )
    # Row d of the reshaped array holds component d of every equation, so
    # transposing yields one row per equation.
    return np.ascontiguousarray(gathered.reshape(dim, n_eqns).T)
186
+
187
+
188
def projection_in_normal_and_transverse_direction(Q, momentum_fields, normal):
    """
    Project the momentum components of ``Q`` onto ``normal`` and onto the
    direction returned by ``compute_transverse_direction(normal)``.

    Args:
        Q: state vector indexed by field number.
        momentum_fields: indices of the momentum fields in ``Q``, laid out
            direction-major (see ``extract_momentum_fields_as_vectors``).
        normal: array whose length defines the spatial dimension.

    Returns:
        Tuple ``(Q_normal, Q_transverse)`` of float arrays with one entry per
        momentum equation.
    """
    dim = normal.shape[0]
    tangent = compute_transverse_direction(normal)
    momenta = extract_momentum_fields_as_vectors(Q, momentum_fields, dim)
    num_eqns = momenta.shape[0]
    normal_part = np.zeros(num_eqns, dtype=float)
    tangent_part = np.zeros(num_eqns, dtype=float)
    for axis in range(dim):
        column = momenta[:, axis]
        normal_part += column * normal[axis]
        tangent_part += column * tangent[axis]
    return normal_part, tangent_part
198
+
199
+
200
def projection_in_x_y_direction(Qn, Qt, normal):
    """
    Rebuild Cartesian momentum components from normal (``Qn``) and
    transverse (``Qt``) components.

    Returns:
        Flat float array of length ``len(Qn) * dim``, laid out direction-major:
        entry ``i + d * len(Qn)`` equals ``Qn[i] * normal[d] + Qt[i] * t[d]``,
        where ``t = compute_transverse_direction(normal)``.
    """
    dim = normal.shape[0]
    count = Qn.shape[0]
    tangent = compute_transverse_direction(normal)
    result = np.empty((count * dim), dtype=float)
    for d in range(dim):
        offset = d * count
        for i in range(count):
            result[offset + i] = Qn[i] * normal[d] + Qt[i] * tangent[d]
    return result
211
+
212
+
213
def project_in_x_y_and_recreate_Q(Qn, Qt, Qorig, momentum_eqns, normal):
    """
    Return a copy of ``Qorig`` in which the entries at ``momentum_eqns`` are
    replaced by the Cartesian components rebuilt from ``Qn``/``Qt`` by
    ``projection_in_x_y_direction``. ``Qorig`` itself is left untouched.
    """
    cartesian = projection_in_x_y_direction(Qn, Qt, normal)
    recreated = np.array(Qorig)
    recreated[momentum_eqns] = cartesian
    return recreated
@@ -0,0 +1,94 @@
1
+
2
+ import attrs
3
+ from functools import partial
4
+ from typing import Any, Tuple, Type
5
+ import attr
6
+
7
+ try:
8
+ import jax
9
+ import jax.numpy as jnp
10
+ _HAVE_JAX = True
11
+ except ImportError:
12
+ _HAVE_JAX = False
13
+
14
+
15
def register_static_pytree(cls: Type[Any]) -> Type[Any]:
    """
    Class decorator that registers an attrs class as a JAX pytree node,
    treating all member variables as static.

    The node has no dynamic children. Its auxiliary data holds:
    * all attrs field values, and
    * any extra attributes stored in the instance ``__dict__``. These matter
      for ``SimpleNamespace``-based classes such as ``Zstruct``, which keep
      their data outside attrs fields.

    Reconstruction bypasses ``__init__``. Keyword-only and frozen classes, and
    classes whose ``__init__`` has side effects, therefore round-trip
    unchanged.

    NOTE(review): JAX compares (and, for jit caching, presumably hashes) the
    auxiliary data, so the stored values should be hashable. Confirm this for
    the classes the decorator is applied to.

    Without JAX installed this is a no-op.

    Parameters:
        cls (Type[Any]): The class to register.

    Returns:
        Type[Any]: The registered class.

    Raises:
        TypeError: if JAX is available and ``cls`` is not an attrs class.
    """
    if not _HAVE_JAX:
        # no-op decorator
        return cls
    if not attrs.has(cls):
        raise TypeError(
            "register_static_pytree can only be applied to classes decorated with @attrs.define or @attr.s"
        )

    field_names = tuple(field.name for field in attrs.fields(cls))

    def flatten(instance: Any) -> Tuple[Tuple, Tuple]:
        field_values = tuple(getattr(instance, name) for name in field_names)
        # Attributes living in __dict__ but not declared as attrs fields;
        # insertion order is kept because Zstruct indexes values positionally.
        extra = tuple(
            (name, value)
            for name, value in getattr(instance, "__dict__", {}).items()
            if name not in field_names
        )
        children = ()  # No dynamic children since all are static
        return children, (field_values, extra)

    def unflatten(aux_data: Tuple, children: Tuple) -> Any:
        field_values, extra = aux_data
        # cls.__new__ (not object.__new__) so SimpleNamespace subclasses work;
        # object.__setattr__ bypasses frozen-class __setattr__ guards.
        instance = cls.__new__(cls)
        for name, value in zip(field_names, field_values):
            object.__setattr__(instance, name, value)
        for name, value in extra:
            object.__setattr__(instance, name, value)
        return instance

    # Register the class as a pytree node with JAX
    jax.tree_util.register_pytree_node(cls, flatten, unflatten)

    return cls
51
+
52
+
53
+ # 2. Define the Mesh class using attrs and the decorator
54
@register_static_pytree
@attrs.define(frozen=True)
class Mesh:
    """
    Minimal example mesh used by the demo at the bottom of this module.

    The array annotations are strings, so defining the class does not
    evaluate ``jnp`` (which is only bound when JAX is installed). The default
    factories reference ``jnp`` only when an instance is created.
    """

    x: "jnp.ndarray"
    y: "jnp.ndarray"
    # Use factory for mutable default fields
    z: "jnp.ndarray" = attr.field(factory=lambda: jnp.array([0.0]))
    # Example additional fields
    w: "jnp.ndarray" = attr.field(factory=lambda: jnp.array([[1.0, 2.0], [3.0, 4.0]]))
    description: str = "Default Mesh Description"
64
+
65
+
66
+ # 3. Define the SpaceOperator class with JAX's jit
67
class SpaceOperator:
    """
    Example operator whose ``solve`` is JIT-compiled when JAX is available.

    Without JAX, ``solve`` stays a plain method, so importing the module does
    not fail on the undefined ``jax`` name.
    """

    def solve(self, q: "jnp.ndarray", mesh: Mesh) -> "jnp.ndarray":
        """
        Example operation: Element-wise multiplication of mesh.x with q,
        then add mesh.y for demonstration.
        """
        return mesh.x * q + mesh.y

    if _HAVE_JAX:
        # Marks 'self' (argument 0) as static for jax.jit.
        solve = partial(jax.jit, static_argnums=(0,))(solve)
75
+
76
+
77
if __name__ == "__main__":
    # Demo (requires JAX): build a small mesh from two coordinate arrays,
    # then apply SpaceOperator.solve to a sample state and print the result.
    num_points = 10
    # z and w are filled by their default factories.
    mesh = Mesh(jnp.linspace(0, 1, num_points), jnp.linspace(1, 2, num_points))
    Q = jnp.linspace(0, 1, num_points)
    Q_result = SpaceOperator().solve(Q, mesh)
    print("Q_result:", Q_result)
@@ -0,0 +1,147 @@
1
+ from sympy import Matrix, diff, exp, I, linear_eq_to_matrix, solve , Eq, zeros, simplify, nsimplify, latex, symbols, Function, together, Symbol
2
+ from IPython.display import display, Latex
3
+
4
+
5
class ModelAnalyser():
    """
    Symbolic helper for analysing a model's equations with SymPy.

    Provides linearisation (``linearize_system``), substitution of a
    plane-wave ansatz (``insert_plane_wave_ansatz``, ``remove_exponential``)
    and computation of the dispersion relation
    (``solve_for_dispersion_relation``). The working list of equations is
    stored in ``self.equations``, and most methods replace it.
    """
    def __init__(self, model):
        # Used attributes of ``model``: time, position (unpacked into exactly
        # three symbols x, y, z), and later dimension, variables,
        # aux_variables, quasilinear_matrix() and residual().
        self.model = model
        self.t = model.time
        x, y, z = model.position
        self.x = x
        self.y = y
        self.z = z
        # List of sympy equations; None until populated (e.g. by linearize_system()).
        self.equations = None
        # Amplitude symbols \bar{f} created by insert_plane_wave_ansatz().
        self.plane_wave_symbols = []


    def get_equations(self):
        """Return the current equation list (may be None)."""
        return self.equations

    def print_equations(self):
        """Display the current equations as a LaTeX ``align*`` block via IPython."""
        latex_lines = " \\\\\n".join([f"& {latex(eq)}" for eq in self.equations])
        latex_block = r"$$\begin{align*}" + "\n" + latex_lines + r"\end{align*}$$"
        display(Latex(latex_block))

    def get_time_space(self):
        """Return the model's time and position symbols as ``(t, x, y, z)``."""
        x, y, z = self.model.position
        t = self.model.time
        return t, x, y, z

    def _get_omega_k(self):
        """Return the symbols ``(omega, k_x, k_y, k_z)`` (frequency, wave numbers)."""
        omega, kx, ky, kz = symbols('omega k_x k_y k_z')
        return omega, kx, ky, kz

    def _get_exponential(self):
        """Return the plane-wave factor ``exp(i (k_x x + k_y y + k_z z - omega t))``."""
        omega, kx, ky, kz = self._get_omega_k()
        t, x, y, z = self.get_time_space()
        exponential = exp(I * (kx * x + ky * y + kz * z - omega * t))
        return exponential

    def get_eps(self):
        """Return the symbol ``eps`` used as the linearisation parameter."""
        eps = symbols('eps')
        return eps

    def create_functions_from_list(self, names):
        """Return undefined sympy functions ``name(t, x, y, z)``, one per name."""
        t, x, y, z = self.get_time_space()
        return [Function(name)(t, x, y, z) for name in names]

    def delete_equations(self, indices):
        """Remove the equations at the given positions from ``self.equations``."""
        self.equations = [self.equations[i] for i in range(len(self.equations)) if i not in indices]


    def solve_for_constraints(self, list_of_selected_equations, list_of_variables):
        """
        Solve the selected equations for the given variables, substitute the
        solution into all equations, and drop the selected ones.

        NOTE(review): the result of ``solve`` is passed to ``xreplace``, which
        assumes ``solve`` returned a dict (e.g. a linear system) — confirm.

        Returns:
            The solution returned by ``sympy.solve``.
        """
        equations = self.equations
        sol = solve([equations[i] for i in list_of_selected_equations], list_of_variables)
        equations = [eq.xreplace(sol).doit() for eq in equations]
        # delete used equations from equation system
        equations = [equations[i] for i in range(len(equations)) if i not in list_of_selected_equations]
        self.equations = equations
        return sol

    def insert_plane_wave_ansatz(self, functions_to_replace):
        """
        Replace each function ``f`` by ``\\bar{f} * exp(i (k.x - omega t))`` in
        all equations and record the amplitude symbols ``\\bar{f}`` in
        ``self.plane_wave_symbols`` (appended on every call).
        """
        exponential = self._get_exponential()
        f_bar_dict = {}
        for f in functions_to_replace:
            # Create the base name (e.g., 'f0')
            f_name = str(f.func) # Get the function name (e.g., 'f0')

            # Create a new symbol representing \bar{f0}
            f_bar = Symbol(r'\bar{' + f_name + '}')
            f_bar_dict[f] = f_bar * exponential
            self.plane_wave_symbols.append(f_bar)
        self.equations = [eq.xreplace(f_bar_dict).doit() for eq in self.equations ]

    def solve_for_dispersion_relation(self):
        """
        Write the equations as a linear system in the plane-wave amplitudes
        and solve ``det(A) = 0`` for ``omega``.

        Returns:
            List of solutions for ``omega`` from ``sympy.solve``.
        """
        assert self.equations is not None, "No equations available to solve for dispersion relation."
        assert self.plane_wave_symbols, "No plane wave symbols available to solve for dispersion relation. Use insert_plane_wave_ansatz first."
        # Only the coefficient matrix is needed; rhs is unused.
        A, rhs = linear_eq_to_matrix(self.equations, self.plane_wave_symbols)
        omega, kx, ky, kz = self._get_omega_k()
        sol = solve(A.det(), omega)
        return sol

    def remove_exponential(self):
        """Divide both sides of every equation by the plane-wave factor and simplify."""
        exponential = self._get_exponential()
        equations = self.equations
        equations = [simplify(Eq(eq.lhs / exponential, eq.rhs / exponential)) for eq in equations]
        self.equations = equations

    def linearize_system(self, q, qaux, constraints=None):
        """
        Build the linearised system for the state ``q`` and store it in
        ``self.equations`` as ``Eq(expr, 0)`` entries.

        The model variables/aux variables are replaced by ``q``/``qaux``.
        The residual ``dq/dt + sum_d A_d dq/dx_d - S`` is stacked on top of
        the optional ``constraints``, and each row keeps only its
        first-order ``eps`` coefficient, multiplied by its denominator.

        Args:
            q: expressions substituted for the model variables; presumably a
                column sympy Matrix depending on ``eps`` (e.g. base state +
                eps * perturbation) — confirm with callers.
            qaux: expressions substituted for the auxiliary variables.
            constraints: optional sympy column Matrix of extra equations.
        """
        model = self.model
        t, x, y, z = self.get_time_space()
        dim = model.dimension
        X = [x, y, z]

        Q = Matrix(model.variables.get_list())
        Qaux = Matrix(model.aux_variables.get_list())


        # Map model symbols -> user-provided expressions.
        substitutions = {Q[i]: q[i] for i in range(len(q))}
        substitutions.update({Qaux[i]: qaux[i] for i in range(len(qaux))})


        A = model.quasilinear_matrix()
        S = model.residual()
        if constraints is not None:
            C = constraints
        else:
            # No constraints: empty matrix so vstack below adds no rows.
            C = zeros(0, 1)


        # Q is not used after this point.
        Q = Q.xreplace(substitutions)
        # NOTE(review): this assigns into the object returned by
        # model.quasilinear_matrix(); if the model caches it, the
        # substitution leaks back into the model — confirm.
        for d in range(dim):
            A[d] = A[d].xreplace(substitutions)
        S = S.xreplace(substitutions)
        C = C.xreplace(substitutions)

        C = C.doit()


        # gradQ[i, j] = d q_i / d X_j, shape (len(q), dim).
        gradQ = Matrix([diff(q[i], X[j]) for i in range(len(q)) for j in range(dim)]).reshape(len(q), dim)

        # sum_d A_d * dq/dx_d
        AgradQ = A[0] * gradQ[:, 0]
        for d in range(1, dim):
            AgradQ += A[d] * gradQ[:, d]


        expr = list(Matrix.vstack((diff(q, t) + AgradQ - S) , C))
        # Convert floats to rationals so the series expansion stays exact.
        for i in range(len(expr)):
            expr[i] = nsimplify(expr[i], rational=True)
        expr = Matrix(expr)
        eps = self.get_eps()
        res = expr.copy()
        # Keep only the O(eps) coefficient of each row (series about eps = 0).
        for i, e in enumerate(expr):
            collected = e
            collected = collected.series(eps, 0, 2).removeO()
            order_1_term = collected.coeff(eps, 1)
            res[i] = order_1_term

        # Multiply each row by its denominator to clear fractions.
        for r in range(res.shape[0]):
            denom = together(res[r]).as_numer_denom()[1]
            res[r] *= denom
            res[r] = simplify(res[r])

        linearized_system = [Eq((res[i]),0) for i in range(res.shape[0])]


        self.equations = linearized_system
147
+
@@ -0,0 +1,113 @@
1
+
2
+ import sympy
3
+ import sympy as sp
4
+ from attrs import define, field
5
+ from sympy import lambdify
6
+
7
+ from library.zoomy_core.misc.misc import Zstruct
8
+
9
def listify(expr):
    """
    Return ``expr`` converted with ``list(...)``, except that an object whose
    type is exactly ``sympy.Piecewise`` is returned unchanged.
    """
    return expr if type(expr) is sp.Piecewise else list(expr)
14
+
15
def vectorize_constant_sympy_expressions(expr, Q, Qaux):
    """
    Replace entries of ``expr`` that are constant w.r.t. ``Q`` and ``Qaux``
    so that NumPy/JAX vectorisation works.

    Entries are rewritten as follows:
    * an exact zero becomes ``zeros_like(Q[0])``;
    * any other entry that does not depend on a symbol of ``Q``/``Qaux``
      becomes ``entry * ones_like(Q[0])``.

    ``ones_like``/``zeros_like`` are undefined sympy functions. They are
    placeholders, presumably mapped to the numerical backend through the
    lambdify ``modules`` — confirm with the caller.

    Handles scalars, nested lists, sympy matrices, sympy N-dim arrays
    (immutable and mutable, any rank) and ``Piecewise``. For ``Piecewise``,
    only the branch expressions are processed; the conditions are kept as
    they are.

    Args:
        expr: expression or container of expressions.
        Q: Zstruct of state symbols (uses ``get_list()`` and positional ``[0]``).
        Qaux: Zstruct of auxiliary symbols (uses ``get_list()``).

    Returns:
        A ``sympy.Array`` when the processed result is a list that can be
        converted to one. Otherwise the processed object itself: a list with
        inconsistent shapes, a Matrix, an Array, a Piecewise or a scalar.
    """
    symbol_list = set(Q.get_list() + Qaux.get_list())
    q0 = Q[0]
    ones_like = sp.Function("ones_like")  # symbolic placeholder
    zeros_like = sp.Function("zeros_like")  # symbolic placeholder

    # Matrices become nested lists, so the final result can be normalised
    # into a single N-dimensional sp.Array below.
    if isinstance(expr, sp.MatrixBase):
        expr = expr.tolist()

    def vectorize_entry(entry):
        """Return ``entry`` rewritten as above if it is constant w.r.t. Q/Qaux."""
        # numeric zero
        if entry == 0:
            return zeros_like(q0)

        # numeric constant (int, float, Rational, pi, etc.)
        if getattr(entry, "is_number", False):
            return entry * ones_like(q0)

        # symbolic constant independent of Q and Qaux
        if hasattr(entry, "free_symbols") and entry.free_symbols.isdisjoint(symbol_list):
            return entry * ones_like(q0)

        # otherwise, depends on variables: leave unchanged
        return entry

    def recurse(e):
        """Recursively handle list, Matrix, N-dim array, Piecewise or scalar."""
        if isinstance(e, list):
            return [recurse(sub) for sub in e]

        if isinstance(e, sp.MatrixBase):
            return sp.Matrix([[recurse(sub) for sub in row] for row in e.tolist()])

        # NDimArray covers both sp.Array (immutable) and mutable arrays such as
        # MutableDenseNDimArray. Checking only sp.Array let mutable arrays
        # bypass element-wise processing.
        if isinstance(e, sp.NDimArray):
            return sp.Array([recurse(sub) for sub in e])

        if isinstance(e, sp.Piecewise):
            return sp.Piecewise(*[(recurse(expr_i), cond_i) for expr_i, cond_i in e.args])

        # Scalar or atomic expression
        return vectorize_entry(e)

    result = recurse(expr)

    # Convert to Array if possible (this ensures uniform output type)
    if isinstance(result, list):
        try:
            return sp.Array(result)
        except Exception:
            # fall back if shapes inconsistent
            return result
    return result
84
+
85
+
86
+
87
+
88
+
89
@define(frozen=True, slots=True, kw_only=True)
class Function:
    """
    Generic (virtual) function: a named symbolic ``definition`` together with
    the ``args`` it is evaluated on.

    Attributes:
        name: label of the function.
        args: Zstruct of argument symbols. ``lambdify`` reads its
            ``variables`` and ``aux_variables`` entries.
        definition: sympy object returned by ``__call__``.
    """

    name: str = field(default="Function")
    # Factories instead of shared default objects: Zstruct and
    # sympy.zeros(...) are mutable, so a single default instance would be
    # shared (and mutable) across every Function built without these args.
    args: Zstruct = field(factory=Zstruct)
    definition = field(factory=lambda: sympy.zeros(1, 1))

    def __call__(self):
        """Allow calling the instance to get its symbolic definition."""
        return self.definition

    def lambdify(self, modules=None):
        """
        Return a lambdified version of the function.

        The argument list is ``self.args.get_list()``, so nested Zstructs
        become nested argument lists. Constant entries of the definition are
        made array-valued via ``vectorize_constant_sympy_expressions``.

        Args:
            modules: forwarded to ``sympy.lambdify``.
        """
        # ``lambdify`` below resolves to the module-level sympy function; the
        # method name lives in the class namespace, not in this scope.
        func = lambdify(
            self.args.get_list(),
            vectorize_constant_sympy_expressions(
                listify(self.definition), self.args.variables, self.args.aux_variables
            ),
            modules=modules,
        )
        return func