sindy-exp 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,452 @@
1
+ """Utilities to convert a `dysts` dynamical system object's rhs to SymPy.
2
+
3
+ This module inspects the source of an object's RHS method (by default
4
+ named ``_rhs``), parses the function using ``ast``, and converts the
5
+ returned expression(s) into SymPy expressions.
6
+
7
+ The conversion is intentionally conservative and aims to handle common
8
+ patterns used in simple rhs implementations, e.g. returning a tuple/list
9
+ of arithmetic expressions, using indexing into a state vector (``x[0]``),
10
+ and calls to common ``numpy``/``math`` functions (``np.sin``, ``math.exp``, ...).
11
+
12
+ Limitations:
13
+ - Building the expressions does not execute the inspected function; only the final numeric consistency check calls the original method once with random inputs.
14
+ - Complex control flow, loops, or non-trivial Python constructs may not
15
+ be fully supported.
16
+
17
+ Example
18
+ -------
19
+ from dysts.flows import Lorenz
20
+ from inspect_to_sympy import dynsys_to_sympy
21
+
22
+ lor = Lorenz()
23
+ symbols, exprs, lambda_rhs = dynsys_to_sympy(lor)
24
+ # `symbols` is a list of SymPy symbols for the state vector
25
+ # `exprs` is a list of SymPy expressions for the RHS
26
+ # `lambda_rhs` is a SymPy Lambda mapping state symbols -> rhs expressions
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import ast
32
+ import inspect
33
+ import textwrap
34
+ from typing import Any, Callable, Dict, List, Tuple
35
+
36
+ import numpy as np
37
+ import sympy as sp
38
+ from dysts.base import BaseDyn
39
+
40
+
41
+ def _is_name(node: ast.AST, name: str) -> bool:
42
+ return isinstance(node, ast.Name) and node.id == name
43
+
44
+
45
+ class _ASTToSympy(ast.NodeVisitor):
46
+ def __init__(
47
+ self,
48
+ state_name: str,
49
+ state_symbols: List[sp.Symbol],
50
+ locals_map: Dict[str, Any],
51
+ ):
52
+ self.state_name = state_name
53
+ self.state_symbols = state_symbols
54
+ self.locals = dict(locals_map)
55
+
56
+ def generic_visit(self, node):
57
+ raise NotImplementedError(f"AST node not supported: {node!r}")
58
+
59
+ def visit_Constant(self, node: ast.Constant):
60
+ return sp.sympify(node.value)
61
+
62
+ def visit_Num(self, node: ast.Num):
63
+ return sp.sympify(node.n)
64
+
65
+ def visit_Name(self, node: ast.Name):
66
+ if node.id in self.locals:
67
+ return self.locals[node.id]
68
+ return sp.Symbol(node.id)
69
+
70
+ def visit_Tuple(self, node: ast.Tuple):
71
+ elems = []
72
+ for elt in node.elts:
73
+ val = self.visit(elt)
74
+ if isinstance(val, (list, tuple)):
75
+ elems.extend(list(val))
76
+ else:
77
+ elems.append(val)
78
+ return tuple(elems)
79
+
80
+ def visit_List(self, node: ast.List):
81
+ elems = []
82
+ for elt in node.elts:
83
+ val = self.visit(elt)
84
+ if isinstance(val, (list, tuple)):
85
+ elems.extend(list(val))
86
+ else:
87
+ elems.append(val)
88
+ return elems
89
+
90
+ def visit_Starred(self, node: ast.Starred):
91
+ # Handle starred expressions like `*x` in list/tuple literals.
92
+ # If the starred value is the state vector name, expand to state symbols.
93
+ if isinstance(node.value, ast.Name) and node.value.id == self.state_name:
94
+ return tuple(self.state_symbols)
95
+ # Otherwise, evaluate the value and if it is a sequence, return its items
96
+ val = self.visit(node.value)
97
+ if isinstance(val, (list, tuple)):
98
+ return tuple(val)
99
+ raise NotImplementedError(
100
+ "Unsupported starred expression; cannot expand non-iterable"
101
+ )
102
+
103
+ def visit_BinOp(self, node: ast.BinOp):
104
+ left = self.visit(node.left)
105
+ right = self.visit(node.right)
106
+ if isinstance(node.op, ast.Add):
107
+ return left + right
108
+ if isinstance(node.op, ast.Sub):
109
+ return left - right
110
+ if isinstance(node.op, ast.Mult):
111
+ return left * right
112
+ if isinstance(node.op, ast.Div):
113
+ return left / right
114
+ if isinstance(node.op, ast.Pow):
115
+ return left**right
116
+ if isinstance(node.op, ast.Mod):
117
+ return left % right
118
+ raise NotImplementedError(f"Binary op not supported: {node.op!r}")
119
+
120
+ def visit_UnaryOp(self, node: ast.UnaryOp):
121
+ operand = self.visit(node.operand)
122
+ if isinstance(node.op, ast.USub):
123
+ return -operand
124
+ if isinstance(node.op, ast.UAdd):
125
+ return +operand
126
+ raise NotImplementedError(f"Unary op not supported: {node.op!r}")
127
+
128
+ def visit_Call(self, node: ast.Call):
129
+ # Determine function name
130
+ func = node.func
131
+ func_name = None
132
+ mod_name = None
133
+
134
+ if isinstance(func, ast.Name):
135
+ func_name = func.id
136
+ elif isinstance(func, ast.Attribute):
137
+ # e.g. np.sin or math.exp
138
+ if isinstance(func.value, ast.Name):
139
+ mod_name = func.value.id
140
+ func_name = func.attr
141
+ else:
142
+ raise NotImplementedError(
143
+ f"Call to unsupported func node: {ast.dump(func)}"
144
+ )
145
+
146
+ # Map common numpy/math functions to sympy
147
+ func_map = {
148
+ "sin": sp.sin,
149
+ "cos": sp.cos,
150
+ "tan": sp.tan,
151
+ "exp": sp.exp,
152
+ "log": sp.log,
153
+ "sqrt": sp.sqrt,
154
+ "abs": sp.Abs,
155
+ "atan": sp.atan,
156
+ "asin": sp.asin,
157
+ "acos": sp.acos,
158
+ }
159
+
160
+ args = [self.visit(a) for a in node.args]
161
+
162
+ # Special-case array constructors: return underlying list/tuple
163
+ if func_name in ("array", "asarray") and mod_name in ("np", "numpy"):
164
+ # expect a single positional arg that's a list/tuple
165
+ if len(args) == 1:
166
+ return args[0]
167
+
168
+ if func_name in func_map:
169
+ return func_map[func_name](*args)
170
+
171
+ # Unknown function: create a Sympy Function
172
+ symf = sp.Function(func_name)
173
+ return symf(*args)
174
+
175
+ def visit_Subscript(self, node: ast.Subscript):
176
+ # Support patterns like x[0] where x is the state vector name
177
+ value = node.value
178
+ # handle simple constant index
179
+ if _is_name(value, self.state_name):
180
+ # Python >=3.9: slice is directly the node.slice
181
+ idx_node = node.slice
182
+ if isinstance(idx_node, ast.Constant):
183
+ idx = idx_node.value
184
+ else:
185
+ raise NotImplementedError(
186
+ "Only constant indices into state vector supported"
187
+ )
188
+ return self.state_symbols[idx]
189
+
190
+ # If it's something else, try to evaluate generically
191
+ base = self.visit(value)
192
+ # slice may be constant
193
+ if isinstance(node.slice, ast.Constant):
194
+ key = node.slice.value
195
+ return base[key]
196
+ raise NotImplementedError("Unsupported subscript pattern")
197
+
198
+
199
+ def _numeric_consistency_check(
200
+ dysts_flow: BaseDyn,
201
+ rhsfunc: Callable,
202
+ arg_names: List[str],
203
+ state_names: List[str],
204
+ vector_mode: bool,
205
+ sys_dim: int,
206
+ lambda_rhs: sp.Lambda,
207
+ ) -> None:
208
+ """Compare the original dysts rhs function to the SymPy-derived lambda.
209
+
210
+ Raises a RuntimeError if they disagree.
211
+ """
212
+ # default to nonnegative support (e.g. Lotka volterra)
213
+ random_state = np.random.standard_exponential(size=sys_dim)
214
+
215
+ # Construct call arguments for the original function (bound method).
216
+ call_args = []
217
+ for name in arg_names:
218
+ if name == "self":
219
+ continue
220
+ if name in state_names and not vector_mode:
221
+ idx = state_names.index(name)
222
+ call_args.append(random_state[idx])
223
+ elif name in state_names and vector_mode:
224
+ call_args.append(np.asarray(random_state, dtype=float))
225
+ elif name == "t":
226
+ call_args.append(float(np.random.standard_normal(size=())))
227
+ else:
228
+ call_args.append(dysts_flow.params[name])
229
+
230
+ dysts_val = rhsfunc(*call_args)
231
+ orig_arr = np.asarray(dysts_val, dtype=float).ravel()
232
+
233
+ sym_val = lambda_rhs(*tuple(random_state))
234
+ sym_arr = np.asarray(sym_val, dtype=float).ravel()
235
+
236
+ if orig_arr.shape != sym_arr.shape:
237
+ raise RuntimeError(
238
+ f"_rhs shape {orig_arr.shape} != sympy shape {sym_arr.shape}"
239
+ )
240
+
241
+ if not np.allclose(orig_arr, sym_arr, rtol=1e-6, atol=1e-9):
242
+ raise RuntimeError("Numeric mismatch between original and sympy conversion.")
243
+
244
+
245
def _find_function_def(src: str) -> ast.FunctionDef:
    """Parse *src* and return its first top-level function definition.

    Raises:
        RuntimeError: if *src* contains no top-level ``def``.
    """
    parsed = ast.parse(textwrap.dedent(src))
    for node in parsed.body:
        if isinstance(node, ast.FunctionDef):
            return node
    raise RuntimeError("No function definition found in source")


def _locate_state_args(
    obj: Any, arg_names: List[str], start_idx: int, t_idx: int | None
) -> Tuple[List[str], bool]:
    """Pick the state argument name(s) and report whether they form one vector.

    Returns:
        ``(state_args, vector_mode)``. In vector mode ``state_args`` holds the
        single argument treated as the whole state vector.

    Raises:
        RuntimeError: if a state-vector argument is needed but the function
            has no argument after ``self``.
    """

    def vector_fallback() -> Tuple[List[str], bool]:
        # Treat the first non-``self`` argument as the whole state vector.
        if start_idx >= len(arg_names):
            raise RuntimeError("Function has no state arguments")
        return [arg_names[start_idx]], True

    dimension = getattr(obj, "dimension", None)
    if isinstance(dimension, int):
        # Common dysts signature ``(self, *states, t, *parameters)``: take the
        # first ``dimension`` arguments before ``t`` as scalar states.
        stop = len(arg_names) if t_idx is None else t_idx
        candidates = arg_names[start_idx:stop]
        if len(candidates) >= dimension:
            return candidates[:dimension], False
        return vector_fallback()

    if t_idx is None:
        return vector_fallback()
    candidates = arg_names[start_idx:t_idx]
    if not candidates:
        # NOTE(review): when ``t`` comes first this selects ``t`` itself as the
        # state vector (unchanged behaviour) -- confirm this is intended.
        return vector_fallback()
    # A single name before ``t`` may be scalar or vector; assume vector.
    return candidates, len(candidates) == 1


def _infer_vector_size(obj: Any, fndef: ast.FunctionDef, state_name: str) -> int:
    """Infer the length of the state vector *state_name* from the body.

    Precedence: tuple unpacking ``a, b, c = state`` (the last one met in
    ``ast.walk`` order), then one plus the largest constant non-negative index
    ``state[i]``, then ``obj.dimension`` (default 3).
    """
    max_index = -1
    unpack_size = None
    for node in ast.walk(fndef):
        if isinstance(node, ast.Subscript) and _is_name(node.value, state_name):
            sl = node.slice
            if isinstance(sl, ast.Constant) and isinstance(sl.value, int):
                max_index = max(max_index, sl.value)
        elif (
            isinstance(node, ast.Assign)
            and _is_name(node.value, state_name)
            and len(node.targets) == 1
            and isinstance(node.targets[0], (ast.Tuple, ast.List))
        ):
            unpack_size = len(node.targets[0].elts)

    if unpack_size is not None:
        return unpack_size
    if max_index >= 0:
        return max_index + 1
    return int(getattr(obj, "dimension", 3))


def _parameter_locals(
    obj: Any, arg_names: List[str], state_args: List[str]
) -> Dict[str, Any]:
    """Bind non-state, non-``t`` arguments to values from the object's parameters.

    Reads ``obj.parameters`` when it is a dict, otherwise ``obj.params``; if
    neither is a dict, no bindings are made. Arguments absent from the dict
    are bound to a free symbol of the same name.
    """
    params = getattr(obj, "parameters", None)
    if not isinstance(params, dict):
        # The numeric consistency check reads ``obj.params``; use it too so
        # both sides agree when ``parameters`` is unavailable.
        params = getattr(obj, "params", None)
    if not isinstance(params, dict):
        return {}

    bindings: Dict[str, Any] = {}
    for pname in arg_names:
        if pname in state_args or pname == "t":
            continue
        if pname in params:
            bindings[pname] = sp.sympify(params[pname])
        else:
            bindings[pname] = sp.Symbol(pname)
    return bindings


def _bind_body(
    fndef: ast.FunctionDef,
    converter: _ASTToSympy,
    state_name: str | None,
    state_symbols: List[sp.Symbol],
) -> ast.expr:
    """Evaluate top-level assignments into ``converter.locals``; return the result node.

    Statements are processed in order up to the first top-level ``return``.
    Supported forms:
    - ``name = expr``, ``name: T = expr`` and ``name op= expr``.
    - Unpacking of the state vector, ``a, b = state`` (only when
      *state_name* is given).
    - Unpacking of a sequence-valued expression, ``a, b = e1, e2``.

    Other top-level statements (``if``, ``for``, bare expressions, ...) are
    skipped.

    Raises:
        ValueError: for chained assignments ``a = b = expr``.
        NotImplementedError: for assignment targets that cannot be bound.
        RuntimeError: if no ``return <expr>`` exists anywhere in the body.
    """
    # Bind directly into the converter's own dict so that later statements
    # see earlier assignments (the converter copies its initial mapping).
    env = converter.locals

    def bind(target: ast.expr, value: ast.expr) -> None:
        if isinstance(target, ast.Name):
            env[target.id] = converter.visit(value)
            return
        if not isinstance(target, (ast.Tuple, ast.List)):
            raise NotImplementedError(
                f"Unsupported assignment target: {ast.dump(target)}"
            )
        if state_name is not None and _is_name(value, state_name):
            # ``a, b, c = state`` -> one state symbol per name.
            for i, elt in enumerate(target.elts):
                if isinstance(elt, ast.Name):
                    env[elt.id] = state_symbols[i]
            return
        values = converter.visit(value)
        if not isinstance(values, (list, tuple)) or len(values) != len(
            target.elts
        ):
            raise NotImplementedError(
                f"Cannot unpack {ast.dump(value)} into {len(target.elts)} names"
            )
        for elt, val in zip(target.elts, values):
            if not isinstance(elt, ast.Name):
                raise NotImplementedError(
                    f"Unsupported unpacking target: {ast.dump(elt)}"
                )
            env[elt.id] = val

    return_expr = None
    for stmt in fndef.body:
        if isinstance(stmt, ast.Assign):
            if len(stmt.targets) != 1:
                raise ValueError("Only single-target assignments supported")
            bind(stmt.targets[0], stmt.value)
        elif isinstance(stmt, ast.AnnAssign):
            if stmt.value is not None:
                bind(stmt.target, stmt.value)
        elif isinstance(stmt, ast.AugAssign):
            if not isinstance(stmt.target, ast.Name):
                raise NotImplementedError(
                    f"Unsupported augmented target: {ast.dump(stmt.target)}"
                )
            # Rewrite ``name op= value`` as ``name = name op value``.
            current = ast.Name(id=stmt.target.id, ctx=ast.Load())
            combined = ast.BinOp(left=current, op=stmt.op, right=stmt.value)
            env[stmt.target.id] = converter.visit(combined)
        elif isinstance(stmt, ast.Return):
            # Later top-level statements are unreachable.
            return_expr = stmt.value
            break

    if return_expr is None:
        # Fall back to the first ``return`` nested anywhere (e.g. inside an
        # ``if``); assignments inside such nested blocks are not evaluated.
        for node in ast.walk(fndef):
            if isinstance(node, ast.Return):
                return_expr = node.value
                break
    if return_expr is None:
        raise RuntimeError("No return expression found in function body")
    return return_expr


def dynsys_to_sympy(
    obj: Any, func_name: str = "_rhs"
) -> Tuple[List[sp.Symbol], List[sp.Expr], sp.Lambda]:
    """Inspect ``obj`` for a method named ``func_name`` and return a SymPy
    representation of its RHS.

    The expressions are built by parsing the method's source, not by running
    it. Afterwards the original method is called once at a random state, to
    check that it agrees numerically with the SymPy result.

    State arguments are located as follows. If ``obj.dimension`` is an integer
    and at least that many arguments precede ``t``, they are scalar states and
    their symbols carry the argument names. Otherwise the first argument is a
    state vector whose components become symbols ``x0, x1, ...``.

    The remaining arguments (other than ``t``) are bound to numbers from
    ``obj.parameters`` (or ``obj.params``) when such a dict exists, and are
    left as free symbols otherwise.

    Args:
        obj: Object exposing the RHS method, typically a dysts flow.
        func_name: Attribute name of the RHS method.

    Returns:
        a tuple ``(state_symbols, exprs, lambda_rhs)`` where ``state_symbols``
        is a list of SymPy symbols for the state vector, ``exprs`` is a list of
        SymPy expressions for the RHS components, and ``lambda_rhs`` is a SymPy
        Lambda mapping the state symbols to the RHS vector.

    Raises:
        AttributeError: if ``obj`` has no attribute ``func_name``.
        RuntimeError: if no function definition, arguments or return
            expression is found, or if the numeric consistency check reports
            a mismatch.
        ValueError: if the body contains a chained assignment.
        NotImplementedError: if the body uses unsupported syntax.

    Example:

        >>> from dysts.flows import Lorenz
        >>> from inspect_to_sympy import dynsys_to_sympy
        >>> lor = Lorenz()
        >>> symbols, exprs, lambda_rhs = dynsys_to_sympy(lor)
        >>> print(lor._rhs(1, 2, 3, t=0.0, **lor.params))
        (10, 23, -6.0009999999999994)

        >>> print(tuple(lambda_rhs(1, 2, 3)))
        (10, 23, -6.00100000000000)

    """
    if not hasattr(obj, func_name):
        raise AttributeError(f"Object has no attribute {func_name!r}")

    func = getattr(obj, func_name)
    fndef = _find_function_def(inspect.getsource(func))

    arg_names = [a.arg for a in fndef.args.args]
    if not arg_names:
        raise RuntimeError("Function has no arguments")
    start_idx = 1 if arg_names[0] == "self" else 0
    t_idx = arg_names.index("t") if "t" in arg_names else None

    state_args, vector_mode = _locate_state_args(obj, arg_names, start_idx, t_idx)

    locals_map: Dict[str, Any] = {}
    if vector_mode:
        state_name = state_args[0]
        n_state = _infer_vector_size(obj, fndef, state_name)
        state_symbols = [sp.Symbol(f"x{i}") for i in range(n_state)]
        # The bare vector name stands for all components, not only the first.
        locals_map[state_name] = tuple(state_symbols)
        primary_state_name = state_name
    else:
        state_name = None
        state_symbols = [sp.Symbol(n) for n in state_args]
        locals_map.update(zip(state_args, state_symbols))
        primary_state_name = state_args[0] if state_args else "x"

    locals_map.update(_parameter_locals(obj, arg_names[start_idx:], state_args))

    converter = _ASTToSympy(primary_state_name, state_symbols, locals_map)
    return_expr = _bind_body(fndef, converter, state_name, state_symbols)
    rhs_val = converter.visit(return_expr)

    # A single (non-sequence) return value is treated as a 1-dim RHS.
    exprs = list(rhs_val) if isinstance(rhs_val, (list, tuple)) else [rhs_val]

    lambda_rhs = sp.Lambda(tuple(state_symbols), sp.Matrix(exprs))

    # Numeric consistency guard (raises on mismatch).
    _numeric_consistency_check(
        obj,
        func,
        arg_names,
        state_args,
        vector_mode,
        len(state_symbols),
        lambda_rhs,
    )

    return state_symbols, exprs, lambda_rhs
450
+
451
+
452
# Public API of this module: the only name exported by ``import *``.
__all__ = [
    "dynsys_to_sympy",
]