demathpy 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
demathpy/__init__.py CHANGED
@@ -8,16 +8,16 @@ from .pde import (
8
8
  parse_pde,
9
9
  step_pdes,
10
10
  )
11
- from .ode import robust_parse, parse_odes_to_function
11
+ from .ode import ODE, parse_ode
12
12
 
13
13
  __all__ = [
14
14
  "PDE",
15
+ "ODE",
15
16
  "normalize_symbols",
16
17
  "normalize_lhs",
17
18
  "normalize_pde",
18
19
  "init_grid",
19
20
  "parse_pde",
20
21
  "step_pdes",
21
- "robust_parse",
22
- "parse_odes_to_function",
22
+ "parse_ode",
23
23
  ]
demathpy/ode.py CHANGED
@@ -1,131 +1,427 @@
1
- import json
1
+ """
2
+ Ordinary Differential Equation (ODE) utilities.
3
+
4
+ Provides a class-based ODE solver similar to the PDE module,
5
+ but for systems dependent only on time (t).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Dict, List, Tuple, Any, Optional
11
+
2
12
  import re
3
- import sympy
13
+ import json
4
14
  import numpy as np
5
- from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication_application, convert_xor
6
15
 
16
+ from .symbols import normalize_symbols, normalize_lhs
7
17
 
8
def _preprocess_expr(expr: str) -> str:
    """Clean a raw expression string before it is evaluated.

    Strips surrounding whitespace, removes ``(approx)`` and bare ``approx``
    annotations (case-insensitive), then applies ``normalize_symbols``.
    ``None`` or blank input yields an empty string.
    """
    cleaned = (expr or "").strip()
    # The parenthesised form goes first so the bare-word pass cannot leave "()" behind.
    for annotation in (r"\(\s*approx\s*\)", r"\bapprox\b"):
        cleaned = re.sub(annotation, "", cleaned, flags=re.IGNORECASE)
    return normalize_symbols(cleaned).strip()
25
+
26
+
27
def normalize_ode(ode: str) -> str:
    """Return *ode* without surrounding whitespace; falsy input gives ``""``."""
    if not ode:
        return ""
    return ode.strip()
29
+
30
+
31
def parse_ode(ode: str) -> Tuple[str, int, str, str]:
    """Split an ODE string into ``(var, order, lhs_coeff_expr, rhs_expr)``.

    Recognised left-hand sides include ``dy/dt``, ``d^2y/dt^2``, ``d2y/dt2``
    and ``∂u/∂t``, optionally multiplied by a coefficient (``m*d^2x/dt^2``).

    Returns:
        var: the differentiated variable, or ``"u"`` when no derivative is found.
        order: derivative order taken from ``d^n`` (else ``dt^n``, else 1).
        lhs_coeff_expr: the factor multiplying the derivative on the left-hand
            side; ``"1"`` when there is none, ``"-1"`` for a bare minus sign.
        rhs_expr: the preprocessed right-hand side.

    A string without ``=`` is read as the right-hand side of a first-order
    equation in ``u``.
    """
    ode = normalize_ode(ode)
    if "=" not in ode:
        # No "lhs = rhs" split: default to du/dt = <whole string>, preprocessed
        # like every other returned right-hand side.
        return "u", 1, "1", _preprocess_expr(ode)

    lhs, rhs = ode.split("=", 1)
    lhs = normalize_lhs(lhs.strip())
    rhs_expr = _preprocess_expr(rhs.strip())

    # "d" or "∂", optional order (d^2 / d2), variable, "/dt" or "/∂t", optional order.
    pattern = r"(?:∂|d)\s*(?:\^?(\d+))?\s*([a-zA-Z_]\w*)\s*/\s*(?:∂|d)t\s*(?:\^?(\d+))?"
    m = re.search(pattern, lhs)
    if not m:
        # Forms such as "y' = ..." are not recognised; fall back to du/dt.
        return "u", 1, "1", rhs_expr

    ord_num, var, ord_den = m.groups()
    order = int(ord_num or ord_den or 1)

    # Whatever remains on the lhs once the derivative is removed is its coefficient.
    coeff = lhs.replace(m.group(0), "").strip().strip("*").strip()
    if coeff in ("", "+"):
        coeff_expr = "1"
    elif coeff == "-":
        # "-dy/dt = f" means dy/dt = -f; a bare "-" is not an evaluable expression.
        coeff_expr = "-1"
    else:
        coeff_expr = _preprocess_expr(coeff) or "1"

    return var, order, coeff_expr, rhs_expr
73
+
74
+
75
class ODE:
    """Explicit time stepper for an ODE (or a batch of identical ODEs) given as text.

    The equation is parsed with ``parse_ode`` (e.g. ``"dy/dt = -y"`` or
    ``"d^2x/dt^2 = -x"``). Its right-hand side is evaluated with ``eval``
    against the NumPy-backed namespace built by ``_build_eval_env`` plus the
    current state.

    Attributes:
        equation: the ODE source string.
        desc: free-form description.
        u: state array, shaped as described in ``init_state``.
        u_shape: component names, e.g. ``["y"]`` or ``["x", "v"]``.
        initial: initial-condition strings such as ``"y = 1.0"``.
        external_variables: extra names visible to evaluated expressions.
        time: current simulation time.
    """

    equation: str
    desc: str

    u: np.ndarray
    u_shape: List[str]  # ["x"] or ["x", "y"] for vector systems

    initial: List[str]  # ["x=1", "y=0"]

    external_variables: Dict[str, float]
    time: float

    _u_t: Optional[np.ndarray] = None  # first time derivative, used by 2nd-order equations

    def __init__(self, equation: str = "", desc: str = "", u_shape: Optional[List[str]] = None):
        """Create a solver for *equation*.

        When *u_shape* is not given, the component name is inferred from the
        equation's derivative (``"dy/dt = ..."`` gives ``["y"]``), defaulting
        to ``["u"]``.
        """
        self.equation = equation
        self.desc = desc
        self.u = np.array([])

        if u_shape:
            # Copy so later mutation of the caller's list cannot alter this system.
            self.u_shape = list(u_shape)
        elif equation:
            var, _, _, _ = parse_ode(equation)
            self.u_shape = [var or "u"]
        else:
            self.u_shape = ["u"]

        self.initial = []
        self.external_variables = {}
        self.time = 0.0
        self._u_t = None

    def init_state(self, shape: tuple = (1,)):
        """Allocate a zero state and reset the clock to 0.

        With one component (``"du/dt = ..."``), ``self.u`` has shape *shape*,
        i.e. *shape* is the batch of independent systems. With several
        components (``u_shape=["x", "v"]``), ``self.u`` has shape
        ``(len(u_shape), *shape)``. An int *shape* is treated as ``(shape,)``.
        ``self._u_t`` is reset to zeros of the same shape.
        """
        self.time = 0.0

        if not self.u_shape:
            self.u_shape = ["u"]

        if isinstance(shape, int):
            shape = (shape,)

        num_components = len(self.u_shape)
        if num_components > 1:
            self.u = np.zeros((num_components, *shape), dtype=float)
        else:
            self.u = np.zeros(shape, dtype=float)

        self._u_t = np.zeros_like(self.u)

    def set_initial_state(self):
        """Apply each ``"name = expr"`` entry of ``self.initial`` to ``self.u``.

        ``expr`` is evaluated at ``t = 0`` in the ``_build_eval_env`` namespace.
        The state itself is not available there.

        Target selection:
            * ``name`` in ``u_shape`` of a multi-component system selects that component.
            * ``u``, or any name in a single-component system, targets the whole state.

        Skipped entries:
            * Entries without ``=`` are skipped silently.
            * Entries naming an unknown component of a multi-component system
              are reported with ``print`` and skipped.
            * Entries that fail to evaluate are reported with ``print`` and skipped.
        """
        if not self.initial:
            return

        env = self._build_eval_env()
        env["t"] = 0.0
        num_components = len(self.u_shape)

        for ic_eqn in self.initial:
            if "=" not in ic_eqn:
                continue
            lhs, rhs = ic_eqn.split("=", 1)

            m = re.match(r"^([a-zA-Z_]\w*)", lhs.strip())
            target_name = m.group(1) if m else "u"

            if num_components > 1 and target_name in self.u_shape:
                target_idx = self.u_shape.index(target_name)
            elif target_name == "u" or num_components == 1:
                target_idx = None  # whole state
            else:
                # Previously this fell through to "whole state" and overwrote every component.
                print(f"Failed to set IC '{ic_eqn}': unknown component '{target_name}'")
                continue

            try:
                # NOTE(review): eval of user/LLM-supplied text; with {} as globals
                # Python still exposes builtins (e.g. __import__). Do not feed untrusted input.
                val = eval(_preprocess_expr(rhs.strip()), {}, env)
                if target_idx is None:
                    self.u[:] = val  # scalars / batch arrays broadcast over the state
                elif target_idx < len(self.u):
                    self.u[target_idx] = val
            except Exception as e:
                print(f"Failed to set IC '{ic_eqn}': {e}")

    def _build_eval_env(self) -> Dict[str, object]:
        """Return the namespace used to evaluate expressions.

        The namespace contains:
            * NumPy math functions and constants.
            * Elementwise helpers: ``pos``, ``step``, ``heaviside``, ``clamp``,
              variadic ``min``/``max``, and ``piecewise``.
            * ``self.external_variables``.
            * ``t`` (``self.time`` unless an external variable named ``t`` exists).
        """

        def pos(u): return np.maximum(u, 0.0)
        def sign(u): return np.sign(u)
        def step_fn(u): return np.heaviside(u, 1.0)
        def heaviside_fn(u, h0=1.0): return np.heaviside(u, h0)
        def clamp(u, lower, upper): return np.clip(u, lower, upper)

        def elementwise_min(*values):
            if not values:
                raise ValueError("min requires at least one value")
            out = values[0]
            for value in values[1:]:
                out = np.minimum(out, value)
            return out

        def elementwise_max(*values):
            if not values:
                raise ValueError("max requires at least one value")
            out = values[0]
            for value in values[1:]:
                out = np.maximum(out, value)
            return out

        def piecewise(*cases):
            if not cases:
                raise ValueError("Piecewise requires at least one case")

            # Also support numpy-style piecewise(cond, if_true, if_false).
            if len(cases) == 3 and not isinstance(cases[0], (tuple, list)):
                cond, if_true, if_false = cases
                return np.where(cond, if_true, if_false)

            condlist = []
            choicelist = []
            default = 0.0

            for case in cases:
                if not isinstance(case, (tuple, list)) or len(case) != 2:
                    raise ValueError("Piecewise cases must be (expr, condition) pairs")
                expr, cond = case
                # A literal True condition (SymPy-style "otherwise") becomes the default.
                if isinstance(cond, (bool, np.bool_)):
                    if bool(cond):
                        default = expr
                    continue
                condlist.append(cond)
                choicelist.append(expr)

            if not condlist:
                return np.asarray(default)
            return np.select(condlist, choicelist, default=default)

        env = {
            "np": np,
            "sin": np.sin, "cos": np.cos, "tan": np.tan,
            "sinh": np.sinh, "cosh": np.cosh, "tanh": np.tanh,
            "arcsin": np.arcsin, "arccos": np.arccos, "arctan": np.arctan,
            "log": np.log, "log10": np.log10, "log2": np.log2,
            "exp": np.exp, "sqrt": np.sqrt, "abs": np.abs,
            "pi": np.pi, "inf": np.inf,
            "pos": pos, "sign": sign, "step": step_fn,
            "heaviside": heaviside_fn, "Heaviside": heaviside_fn,
            "clamp": clamp, "clip": np.clip,
            "where": np.where,
            "min": elementwise_min, "max": elementwise_max,
            "minimum": np.minimum, "maximum": np.maximum,
            "piecewise": piecewise, "Piecewise": piecewise,
        }

        env.update(self.external_variables)

        if "t" not in env:
            env["t"] = self.time

        return env

    def evaluate_rhs(self, rhs_expr: str, env: Dict[str, Any]) -> np.ndarray:
        """Preprocess *rhs_expr* and evaluate it in *env*."""
        rhs_expr = _preprocess_expr(rhs_expr)
        # NOTE(review): eval of user/LLM-supplied text; builtins remain reachable
        # because globals lacks "__builtins__". Do not feed untrusted input.
        return eval(rhs_expr, {}, env)

    def evaluate_scalar(self, expr: str, env: Dict[str, Any]) -> float:
        """Evaluate *expr* to a float; array results are averaged. ``""``/``"1"`` give 1.0."""
        if expr in ("", "1"):
            return 1.0
        val = self.evaluate_rhs(expr, env)
        if np.isscalar(val):
            return float(val)
        return float(np.mean(val))

    def _state_env(self) -> Dict[str, Any]:
        """Return the evaluation namespace with the current state injected.

        Injected names:
            * Each component name maps to its slice of ``self.u``.
            * ``<name>_t`` maps to the matching slice of ``self._u_t``
              (zeros when it is unset).
            * With several components, ``u`` and ``u_t`` also expose the whole arrays.
            * ``t`` is ``self.time``.
        """
        env = self._build_eval_env()
        env["t"] = self.time

        u_t = self._u_t if self._u_t is not None else np.zeros_like(self.u)
        if len(self.u_shape) == 1:
            name = self.u_shape[0]
            env[name] = self.u
            env[f"{name}_t"] = u_t
        else:
            for i, name in enumerate(self.u_shape):
                env[name] = self.u[i]
                env[f"{name}_t"] = u_t[i]
            env["u"] = self.u
            env["u_t"] = u_t
        return env

    def _forcing(self, rhs_expr: str, coeff_expr: str, env: Dict[str, Any]) -> Any:
        """Evaluate ``rhs / coeff`` (the highest derivative) in *env*.

        A list or tuple right-hand side (vector syntax ``[a, b]``) is converted
        to an array. A zero (falsy) coefficient falls back to 1 instead of
        dividing by zero.
        """
        rhs = self.evaluate_rhs(rhs_expr, env)
        if isinstance(rhs, (list, tuple)):
            rhs = np.array(rhs)
        coeff = self.evaluate_scalar(coeff_expr, env)
        return rhs / (coeff if coeff else 1.0)

    def _apply_update(self, var_name: str, order: int, forcing: Any, dt: float) -> None:
        """Advance the state in place by one explicit step of size *dt*.

        Integration scheme:
            * Order 1: forward Euler, ``y += dt * f``.
            * Order 2: semi-implicit Euler, ``y_t += dt * f`` then ``y += dt * y_t``.

        Target selection:
            * ``var_name == "u"`` in a multi-component system, or the sole
              component of a one-component system, updates the whole state.
            * Otherwise only the named component is updated.

        Raises:
            ValueError: for an order other than 1 or 2, or a variable that
                matches no state component.
        """
        if order not in (1, 2):
            raise ValueError(f"Only 1st and 2nd order ODEs are supported (got order {order})")

        num_components = len(self.u_shape)
        if var_name == "u" and num_components > 1:
            key = Ellipsis  # "u" in a vector system means every component
        elif num_components == 1 and var_name in ("u", self.u_shape[0]):
            key = Ellipsis
        elif var_name in self.u_shape:
            key = self.u_shape.index(var_name)
        else:
            raise ValueError(
                f"ODE variable '{var_name}' is not one of the state components {self.u_shape}"
            )

        # Align lower-rank forcing on the target's leading axis (e.g. a constant
        # "[1, 2]" against a (2, N) state). The rank is compared with the updated
        # target, not the whole state, so a (N,) forcing for one component of a
        # (2, N) state is left as is.
        target_ndim = np.ndim(self.u[key])
        if isinstance(forcing, np.ndarray) and forcing.ndim < target_ndim:
            forcing = forcing.reshape(forcing.shape + (1,) * (target_ndim - forcing.ndim))

        if order == 2:
            if self._u_t is None:
                self._u_t = np.zeros_like(self.u)
            self._u_t[key] += dt * forcing
            self.u[key] += dt * self._u_t[key]
        else:
            self.u[key] += dt * forcing

    def step(self, dt: float):
        """Advance the system by one explicit step of size *dt*.

        The right-hand side is evaluated at the current time and state, the
        state is updated (see ``_apply_update``), and only then is
        ``self.time`` advanced. A step that raises therefore does not advance
        the clock. Errors are printed and re-raised.
        """
        var_name, order, coeff_expr, rhs_expr = parse_ode(self.equation)

        try:
            forcing = self._forcing(rhs_expr, coeff_expr, self._state_env())
            self._apply_update(var_name, order, forcing, dt)
        except Exception as e:
            print(f"Error stepping ODE: {e}")
            raise

        self.time += dt

    def get_grid(self, u_state: Optional[np.ndarray] = None, dt: float = 0.0) -> np.ndarray:
        """Evaluate the equation at a state without advancing the solver.

        Args:
            u_state: optional state used instead of ``self.u`` for this call;
                ``self._u_t`` is replaced by zeros if it is unset or its shape
                differs.
            dt: if > 0, return the increment over *dt*: ``f * dt`` for first
                order, ``dt * u_t + 0.5 * dt**2 * f`` for second order.
                Otherwise return the rate ``f`` itself.

        ``self.u`` and ``self._u_t`` are restored before returning, also when
        parsing or evaluation raises.
        """
        original_u = self.u
        original_u_t = self._u_t

        if u_state is not None:
            self.u = u_state
            if self._u_t is None or self._u_t.shape != u_state.shape:
                self._u_t = np.zeros_like(u_state)

        try:
            _, order, coeff_expr, rhs_expr = parse_ode(self.equation)
            forcing = self._forcing(rhs_expr, coeff_expr, self._state_env())

            if dt > 0:
                if order == 2:
                    v = self._u_t if self._u_t is not None else np.zeros_like(self.u)
                    return dt * v + 0.5 * (dt**2) * forcing
                return forcing * dt
            return forcing
        finally:
            self.u = original_u
            self._u_t = original_u_t

    def to_json(self) -> str:
        """Serialise the configuration (not the state array) as indented JSON."""
        return json.dumps({
            "equation": self.equation,
            "desc": self.desc,
            "u_shape": self.u_shape,
            "initial": self.initial,
            "variables": self.external_variables,
            "time": self.time
        }, indent=2)

    def __str__(self):
        return self.to_json()
demathpy/pde.py CHANGED
@@ -20,8 +20,6 @@ def _preprocess_expr(expr: str) -> str:
20
20
  expr = (expr or "").strip()
21
21
  expr = re.sub(r"\(\s*approx\s*\)", "", expr, flags=re.IGNORECASE)
22
22
  expr = re.sub(r"\bapprox\b", "", expr, flags=re.IGNORECASE)
23
- if "=" in expr:
24
- expr = expr.split("=")[-1]
25
23
  expr = normalize_symbols(expr)
26
24
  return expr.strip()
27
25
 
@@ -651,6 +649,53 @@ class PDE:
651
649
  def pos(u): return np.maximum(u, 0.0)
652
650
  def sech(u): return 1.0 / np.cosh(u)
653
651
  def sign(u): return np.sign(u)
652
+ def step_fn(u): return np.heaviside(u, 1.0)
653
+ def heaviside_fn(u, h0=1.0): return np.heaviside(u, h0)
654
+ def clamp(u, lower, upper): return np.clip(u, lower, upper)
655
+
656
+ def elementwise_min(*values):
657
+ if not values:
658
+ raise ValueError("min requires at least one value")
659
+ out = values[0]
660
+ for value in values[1:]:
661
+ out = np.minimum(out, value)
662
+ return out
663
+
664
+ def elementwise_max(*values):
665
+ if not values:
666
+ raise ValueError("max requires at least one value")
667
+ out = values[0]
668
+ for value in values[1:]:
669
+ out = np.maximum(out, value)
670
+ return out
671
+
672
+ def piecewise(*cases):
673
+ if not cases:
674
+ raise ValueError("Piecewise requires at least one case")
675
+
676
+ # Also support numpy-style piecewise(cond, if_true, if_false).
677
+ if len(cases) == 3 and not isinstance(cases[0], (tuple, list)):
678
+ cond, if_true, if_false = cases
679
+ return np.where(cond, if_true, if_false)
680
+
681
+ condlist = []
682
+ choicelist = []
683
+ default = 0.0
684
+
685
+ for case in cases:
686
+ if not isinstance(case, (tuple, list)) or len(case) != 2:
687
+ raise ValueError("Piecewise cases must be (expr, condition) pairs")
688
+ expr, cond = case
689
+ if isinstance(cond, (bool, np.bool_)):
690
+ if bool(cond):
691
+ default = expr
692
+ continue
693
+ condlist.append(cond)
694
+ choicelist.append(expr)
695
+
696
+ if not condlist:
697
+ return np.asarray(default)
698
+ return np.select(condlist, choicelist, default=default)
654
699
 
655
700
  env = {
656
701
  "np": np,
@@ -666,6 +711,12 @@ class PDE:
666
711
  "gradmag": gradmag, "gradl1": gradl1,
667
712
  "grad": grad, "div": div, "advect": advect,
668
713
  "pos": pos, "sech": sech, "sign": sign,
714
+ "step": step_fn, "heaviside": heaviside_fn, "Heaviside": heaviside_fn,
715
+ "clamp": clamp, "clip": np.clip,
716
+ "where": np.where,
717
+ "min": elementwise_min, "max": elementwise_max,
718
+ "minimum": np.minimum, "maximum": np.maximum,
719
+ "piecewise": piecewise, "Piecewise": piecewise,
669
720
  }
670
721
 
671
722
  env.update(self.external_variables)
demathpy/symbols.py CHANGED
@@ -377,13 +377,17 @@ def normalize_symbols(expr: str) -> str:
377
377
  expr = re.sub(r"(\))\s*([a-zA-Z_])", r"\1*\2", expr)
378
378
  # Symbol followed by operator function: eta lap(u) -> eta*lap(u)
379
379
  expr = re.sub(
380
- r"([a-zA-Z_][a-zA-Z0-9_]*)\s*(lap|grad|div|dx|dz|dxx|dzz|pos|gradl1|gradmag)\s*\(",
380
+ r"([a-zA-Z_][a-zA-Z0-9_]*)\s*(lap|grad|div|dx|dz|dxx|dzz|pos|gradl1|gradmag|clamp|clip|piecewise|Piecewise|min|max|minimum|maximum|where)\s*\(",
381
381
  r"\1*\2(",
382
382
  expr,
383
383
  )
384
384
 
385
385
  # Fix accidental insertions like gradl1*(u) -> gradl1(u)
386
- expr = re.sub(r"\b(gradl1|gradmag|lap|dx|dz|dxx|dzz|grad|div|pos)\s*\*\s*\(", r"\1(", expr)
386
+ expr = re.sub(
387
+ r"\b(gradl1|gradmag|lap|dx|dz|dxx|dzz|grad|div|pos|clamp|clip|piecewise|Piecewise|min|max|minimum|maximum|where)\s*\*\s*\(",
388
+ r"\1(",
389
+ expr,
390
+ )
387
391
  # Avoid inserting * for known functions
388
392
  def _fn_mul(match):
389
393
  name = match.group(1)
@@ -393,8 +397,12 @@ def normalize_symbols(expr: str) -> str:
393
397
  "arcsin", "arccos", "arctan",
394
398
  "exp", "sqrt",
395
399
  "log", "log10", "log2",
396
- "abs", "sech", "sign",
400
+ "abs", "sech", "sign", "step", "heaviside",
397
401
  "lap", "dx", "dz", "dxx", "dzz", "grad", "div", "advect", "gradmag", "gradl1", "pos",
402
+ "clamp", "clip", "where",
403
+ "piecewise", "Piecewise",
404
+ "min", "max", "minimum", "maximum",
405
+ "Heaviside",
398
406
  }:
399
407
  return f"{name}("
400
408
  return f"{name}*("
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: demathpy
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: PDE/ODE math backend
5
5
  Author: Misekai
6
6
  Author-email: Misekai <mcore-us@misekai.net>
@@ -155,6 +155,53 @@ To integrate `Demathpy` into visualization software or interactive notebooks, yo
155
155
  # plt.show()
156
156
  ```
157
157
 
158
+ ### The ODE Class
159
+
160
+ Demathpy also efficiently solves Ordinary Differential Equations (ODEs), in which the state depends only on time $t$. The API closely mirrors the PDE class.
161
+
162
+ ```python
163
+ from demathpy import ODE
164
+
165
+ # 1. EXPONENTIAL DECAY
166
+ # dy/dt = -y
167
+ o = ODE("dy/dt = -y", u_shape=["y"])
168
+ o.initial = ["y = 1.0"]
169
+ o.init_state(shape=(1,)) # scalar system
170
+ o.set_initial_state()
171
+
172
+ for _ in range(100):
173
+ o.step(dt=0.01)
174
+ print(o.u) # Should be close to exp(-1)
175
+
176
+ # 2. VECTOR SYSTEMS (Predator-Prey)
177
+ # du/dt = u - u*v
178
+ # dv/dt = u*v - v
179
+ pp = ODE("du/dt = [u[0] - u[0]*u[1], u[0]*u[1] - u[1]]", u_shape=["u", "v"])
180
+ pp.initial = ["u = 1.1", "v = 1.0"]
181
+ pp.init_state(shape=(1,)) # Single ecosystem
182
+ pp.set_initial_state()
183
+
184
+ pp.step(dt=0.1)
185
+
186
+ # 3. BATCHED EXECUTION
187
+ # Simulating 1000 identical particles with different initial conditions
188
+ particles = ODE("dx/dt = -x + noise") # `noise` is not built in; define it via particles.external_variables before stepping
189
+ # Or just decay
190
+ batch = ODE("dy/dt = -y")
191
+ batch.init_state(shape=(1000,)) # 1000 systems
192
+ # Set random initial states directly (or use equation if supported)
193
+ import numpy as np
194
+ batch.u[:] = np.random.rand(1000)
195
+
196
+ batch.step(0.1)
197
+ ```
198
+
199
+ **Key ODE Features:**
200
+ - **Equation Parsing:** Supports `dy/dt`, `d^2y/dt^2`, vector syntax `[a, b]`.
201
+ - **Initialization:** Use `init_state(shape=...)` where shape defines the batch size (independent systems).
202
+ - **Probing:** `get_grid(u_state=..., dt=0)` works exactly like PDE for generating phase portraits (return vector field at state).
203
+ - **Functions:** Includes `sin, cos, exp, step, heaviside, sign, abs` ...
204
+
158
205
  ### License
159
206
 
160
207
  MIT
@@ -0,0 +1,8 @@
1
+ demathpy/__init__.py,sha256=gtvKvQKWns_OkAb9uyPmX8jIys0o62ccE_lMbi1_YQE,404
2
+ demathpy/ode.py,sha256=AKjNMWCQxKVCuT3f2Zl3x7ZfD0Aem-Rvw9TaNvwHq4Y,14875
3
+ demathpy/pde.py,sha256=pEVdtLXIMa14Vs89bCCg26k7CsJgXrWCJAM9L_i-Y-Q,33862
4
+ demathpy/symbols.py,sha256=XDCsyDg_fs1kgt5dRnme0UdWXH8e0_vL1t0Y0gY-rLc,14078
5
+ demathpy-0.1.2.dist-info/licenses/LICENSE,sha256=Vofo4OGPZaOEnCixsvF2MaZSGdpDI4uVMm9GNEjSGBM,1064
6
+ demathpy-0.1.2.dist-info/WHEEL,sha256=iHtWm8nRfs0VRdCYVXocAWFW8ppjHL-uTJkAdZJKOBM,80
7
+ demathpy-0.1.2.dist-info/METADATA,sha256=0gxgIbQjsd3mTT80W--tGyz9m5bbqt0jvw34KEdDdGY,6574
8
+ demathpy-0.1.2.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: uv 0.9.28
2
+ Generator: uv 0.9.30
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,8 +0,0 @@
1
- demathpy/__init__.py,sha256=O2m_YqVOiPaPVC7b9jJdObU9mFmPVZuvIvHf5JY2bKQ,448
2
- demathpy/ode.py,sha256=3SgeiYbHUov9NtrKCOhCo2I4DFLLdh5iIqkhPxwt254,4188
3
- demathpy/pde.py,sha256=v8tMQGKyelmYaQI4mMJmge4WW8KK6NGTCmZlmZ8XI-I,31788
4
- demathpy/symbols.py,sha256=kdONgSB_-9jL6DEGx9GyJMorYY1wTOZ6C-GV18_rMd0,13755
5
- demathpy-0.1.0.dist-info/licenses/LICENSE,sha256=Vofo4OGPZaOEnCixsvF2MaZSGdpDI4uVMm9GNEjSGBM,1064
6
- demathpy-0.1.0.dist-info/WHEEL,sha256=fAguSjoiATBe7TNBkJwOjyL1Tt4wwiaQGtNtjRPNMQA,80
7
- demathpy-0.1.0.dist-info/METADATA,sha256=AUpR7HK1cHuDJMWQNYubUs6VN-bnMCHcB60HprMWdLc,5018
8
- demathpy-0.1.0.dist-info/RECORD,,