sympy2jax 0.0.7__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sympy2jax-0.0.7 → sympy2jax-0.0.8}/PKG-INFO +1 -1
- {sympy2jax-0.0.7 → sympy2jax-0.0.8}/pyproject.toml +3 -3
- {sympy2jax-0.0.7 → sympy2jax-0.0.8}/sympy2jax/sympy_module.py +91 -80
- {sympy2jax-0.0.7 → sympy2jax-0.0.8}/.gitignore +0 -0
- {sympy2jax-0.0.7 → sympy2jax-0.0.8}/LICENSE +0 -0
- {sympy2jax-0.0.7 → sympy2jax-0.0.8}/README.md +0 -0
- {sympy2jax-0.0.7 → sympy2jax-0.0.8}/sympy2jax/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "sympy2jax"
|
|
3
|
-
version = "0.0.7"
|
|
3
|
+
version = "0.0.8"
|
|
4
4
|
description = "Turn SymPy expressions into trainable JAX expressions."
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python ="~=3.9"
|
|
@@ -31,13 +31,13 @@ include = ["sympy2jax/*"]
|
|
|
31
31
|
[tool.pytest.ini_options]
|
|
32
32
|
addopts = "--jaxtyping-packages=symyp2jax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"
|
|
33
33
|
|
|
34
|
-
[tool.ruff]
|
|
34
|
+
[tool.ruff.lint]
|
|
35
35
|
select = ["E", "F", "I001"]
|
|
36
36
|
ignore = ["E402", "E721", "E731", "E741", "F722"]
|
|
37
37
|
ignore-init-module-imports = true
|
|
38
38
|
fixable = ["I001", "F401"]
|
|
39
39
|
|
|
40
|
-
[tool.ruff.isort]
|
|
40
|
+
[tool.ruff.lint.isort]
|
|
41
41
|
combine-as-imports = true
|
|
42
42
|
lines-after-imports = 2
|
|
43
43
|
extra-standard-library = ["typing_extensions"]
|
|
@@ -23,13 +23,13 @@ import jax
|
|
|
23
23
|
import jax.numpy as jnp
|
|
24
24
|
import jax.scipy as jsp
|
|
25
25
|
import jax.tree_util as jtu
|
|
26
|
-
import sympy
|
|
26
|
+
import sympy as sympy_module
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
PyTree = Any
|
|
30
30
|
|
|
31
|
-
concatenate: Callable =
|
|
32
|
-
stack: Callable =
|
|
31
|
+
concatenate: Callable = sympy_module.Function("concatenate") # pyright: ignore
|
|
32
|
+
stack: Callable = sympy_module.Function("stack") # pyright: ignore
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
def _reduce(fn):
|
|
@@ -49,56 +49,56 @@ def _single_args(fn):
|
|
|
49
49
|
_lookup = {
|
|
50
50
|
concatenate: _single_args(jnp.concatenate),
|
|
51
51
|
stack: _single_args(jnp.stack),
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
52
|
+
sympy_module.Mul: _reduce(jnp.multiply),
|
|
53
|
+
sympy_module.Add: _reduce(jnp.add),
|
|
54
|
+
sympy_module.div: jnp.divide,
|
|
55
|
+
sympy_module.Abs: jnp.abs,
|
|
56
|
+
sympy_module.sign: jnp.sign,
|
|
57
|
+
sympy_module.ceiling: jnp.ceil,
|
|
58
|
+
sympy_module.floor: jnp.floor,
|
|
59
|
+
sympy_module.log: jnp.log,
|
|
60
|
+
sympy_module.exp: jnp.exp,
|
|
61
|
+
sympy_module.sqrt: jnp.sqrt,
|
|
62
|
+
sympy_module.cos: jnp.cos,
|
|
63
|
+
sympy_module.acos: jnp.arccos,
|
|
64
|
+
sympy_module.sin: jnp.sin,
|
|
65
|
+
sympy_module.asin: jnp.arcsin,
|
|
66
|
+
sympy_module.tan: jnp.tan,
|
|
67
|
+
sympy_module.atan: jnp.arctan,
|
|
68
|
+
sympy_module.atan2: jnp.arctan2,
|
|
69
|
+
sympy_module.cosh: jnp.cosh,
|
|
70
|
+
sympy_module.acosh: jnp.arccosh,
|
|
71
|
+
sympy_module.sinh: jnp.sinh,
|
|
72
|
+
sympy_module.asinh: jnp.arcsinh,
|
|
73
|
+
sympy_module.tanh: jnp.tanh,
|
|
74
|
+
sympy_module.atanh: jnp.arctanh,
|
|
75
|
+
sympy_module.Pow: jnp.power,
|
|
76
|
+
sympy_module.re: jnp.real,
|
|
77
|
+
sympy_module.im: jnp.imag,
|
|
78
|
+
sympy_module.arg: jnp.angle,
|
|
79
|
+
sympy_module.erf: jsp.special.erf,
|
|
80
|
+
sympy_module.Eq: jnp.equal,
|
|
81
|
+
sympy_module.Ne: jnp.not_equal,
|
|
82
|
+
sympy_module.StrictGreaterThan: jnp.greater,
|
|
83
|
+
sympy_module.StrictLessThan: jnp.less,
|
|
84
|
+
sympy_module.LessThan: jnp.less_equal,
|
|
85
|
+
sympy_module.GreaterThan: jnp.greater_equal,
|
|
86
|
+
sympy_module.And: jnp.logical_and,
|
|
87
|
+
sympy_module.Or: jnp.logical_or,
|
|
88
|
+
sympy_module.Not: jnp.logical_not,
|
|
89
|
+
sympy_module.Xor: jnp.logical_xor,
|
|
90
|
+
sympy_module.Max: _reduce(jnp.maximum),
|
|
91
|
+
sympy_module.Min: _reduce(jnp.minimum),
|
|
92
|
+
sympy_module.MatAdd: _reduce(jnp.add),
|
|
93
|
+
sympy_module.Trace: jnp.trace,
|
|
94
|
+
sympy_module.Determinant: jnp.linalg.det,
|
|
95
95
|
}
|
|
96
96
|
|
|
97
97
|
_constant_lookup = {
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
98
|
+
sympy_module.E: jnp.e,
|
|
99
|
+
sympy_module.pi: jnp.pi,
|
|
100
|
+
sympy_module.EulerGamma: jnp.euler_gamma,
|
|
101
|
+
sympy_module.I: 1j,
|
|
102
102
|
}
|
|
103
103
|
|
|
104
104
|
_reverse_lookup = {v: k for k, v in _lookup.items()}
|
|
@@ -118,7 +118,7 @@ class _AbstractNode(eqx.Module):
|
|
|
118
118
|
...
|
|
119
119
|
|
|
120
120
|
@abc.abstractmethod
|
|
121
|
-
def sympy(self, memodict: dict, func_lookup: dict) ->
|
|
121
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
122
122
|
...
|
|
123
123
|
|
|
124
124
|
# Comparisons based on identity
|
|
@@ -129,7 +129,7 @@ class _AbstractNode(eqx.Module):
|
|
|
129
129
|
class _Symbol(_AbstractNode):
|
|
130
130
|
_name: str
|
|
131
131
|
|
|
132
|
-
def __init__(self, expr:
|
|
132
|
+
def __init__(self, expr: sympy_module.Expr):
|
|
133
133
|
self._name = str(expr.name) # pyright: ignore
|
|
134
134
|
|
|
135
135
|
def __call__(self, memodict: dict):
|
|
@@ -138,9 +138,9 @@ class _Symbol(_AbstractNode):
|
|
|
138
138
|
except KeyError as e:
|
|
139
139
|
raise KeyError(f"Missing input for symbol {self._name}") from e
|
|
140
140
|
|
|
141
|
-
def sympy(self, memodict: dict, func_lookup: dict) ->
|
|
141
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
142
142
|
# memodict not needed as sympy deduplicates internally
|
|
143
|
-
return
|
|
143
|
+
return sympy_module.Symbol(self._name)
|
|
144
144
|
|
|
145
145
|
|
|
146
146
|
def _maybe_array(val, make_array):
|
|
@@ -153,39 +153,39 @@ def _maybe_array(val, make_array):
|
|
|
153
153
|
class _Integer(_AbstractNode):
|
|
154
154
|
_value: jax.typing.ArrayLike
|
|
155
155
|
|
|
156
|
-
def __init__(self, expr:
|
|
157
|
-
assert isinstance(expr,
|
|
156
|
+
def __init__(self, expr: sympy_module.Expr, make_array: bool):
|
|
157
|
+
assert isinstance(expr, sympy_module.Integer)
|
|
158
158
|
self._value = _maybe_array(int(expr), make_array)
|
|
159
159
|
|
|
160
160
|
def __call__(self, memodict: dict):
|
|
161
161
|
return self._value
|
|
162
162
|
|
|
163
|
-
def sympy(self, memodict: dict, func_lookup: dict) ->
|
|
163
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
164
164
|
# memodict not needed as sympy deduplicates internally
|
|
165
|
-
return
|
|
165
|
+
return sympy_module.Integer(_item(self._value))
|
|
166
166
|
|
|
167
167
|
|
|
168
168
|
class _Float(_AbstractNode):
|
|
169
169
|
_value: jax.typing.ArrayLike
|
|
170
170
|
|
|
171
|
-
def __init__(self, expr:
|
|
172
|
-
assert isinstance(expr,
|
|
171
|
+
def __init__(self, expr: sympy_module.Expr, make_array: bool):
|
|
172
|
+
assert isinstance(expr, sympy_module.Float)
|
|
173
173
|
self._value = _maybe_array(float(expr), make_array)
|
|
174
174
|
|
|
175
175
|
def __call__(self, memodict: dict):
|
|
176
176
|
return self._value
|
|
177
177
|
|
|
178
|
-
def sympy(self, memodict: dict, func_lookup: dict) ->
|
|
178
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
179
179
|
# memodict not needed as sympy deduplicates internally
|
|
180
|
-
return
|
|
180
|
+
return sympy_module.Float(_item(self._value))
|
|
181
181
|
|
|
182
182
|
|
|
183
183
|
class _Rational(_AbstractNode):
|
|
184
184
|
_numerator: jax.typing.ArrayLike
|
|
185
185
|
_denominator: jax.typing.ArrayLike
|
|
186
186
|
|
|
187
|
-
def __init__(self, expr:
|
|
188
|
-
assert isinstance(expr,
|
|
187
|
+
def __init__(self, expr: sympy_module.Expr, make_array: bool):
|
|
188
|
+
assert isinstance(expr, sympy_module.Rational)
|
|
189
189
|
numerator = expr.numerator
|
|
190
190
|
denominator = expr.denominator
|
|
191
191
|
if callable(numerator):
|
|
@@ -199,18 +199,18 @@ class _Rational(_AbstractNode):
|
|
|
199
199
|
def __call__(self, memodict: dict):
|
|
200
200
|
return self._numerator / self._denominator
|
|
201
201
|
|
|
202
|
-
def sympy(self, memodict: dict, func_lookup: dict) ->
|
|
202
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
203
203
|
# memodict not needed as sympy deduplicates internally
|
|
204
|
-
return
|
|
204
|
+
return sympy_module.Integer(_item(self._numerator)) / sympy_module.Integer(
|
|
205
205
|
_item(self._denominator)
|
|
206
206
|
)
|
|
207
207
|
|
|
208
208
|
|
|
209
209
|
class _Constant(_AbstractNode):
|
|
210
210
|
_value: jnp.ndarray
|
|
211
|
-
_expr:
|
|
211
|
+
_expr: sympy_module.Expr
|
|
212
212
|
|
|
213
|
-
def __init__(self, expr:
|
|
213
|
+
def __init__(self, expr: sympy_module.Expr, make_array: bool):
|
|
214
214
|
assert expr in _constant_lookup
|
|
215
215
|
self._value = _maybe_array(_constant_lookup[expr], make_array)
|
|
216
216
|
self._expr = expr
|
|
@@ -218,7 +218,7 @@ class _Constant(_AbstractNode):
|
|
|
218
218
|
def __call__(self, memodict: dict):
|
|
219
219
|
return self._value
|
|
220
220
|
|
|
221
|
-
def sympy(self, memodict: dict, func_lookup: dict) ->
|
|
221
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
222
222
|
return self._expr
|
|
223
223
|
|
|
224
224
|
|
|
@@ -227,14 +227,20 @@ class _Func(_AbstractNode):
|
|
|
227
227
|
_args: list
|
|
228
228
|
|
|
229
229
|
def __init__(
|
|
230
|
-
self,
|
|
230
|
+
self,
|
|
231
|
+
expr: sympy_module.Expr,
|
|
232
|
+
memodict: dict,
|
|
233
|
+
func_lookup: Mapping,
|
|
234
|
+
make_array: bool,
|
|
231
235
|
):
|
|
232
236
|
try:
|
|
233
237
|
self._func = func_lookup[expr.func]
|
|
234
238
|
except KeyError as e:
|
|
235
239
|
raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
|
|
236
240
|
self._args = [
|
|
237
|
-
_sympy_to_node(
|
|
241
|
+
_sympy_to_node(
|
|
242
|
+
cast(sympy_module.Expr, arg), memodict, func_lookup, make_array
|
|
243
|
+
)
|
|
238
244
|
for arg in expr.args
|
|
239
245
|
]
|
|
240
246
|
|
|
@@ -249,7 +255,7 @@ class _Func(_AbstractNode):
|
|
|
249
255
|
args.append(arg_call)
|
|
250
256
|
return self._func(*args)
|
|
251
257
|
|
|
252
|
-
def sympy(self, memodict: dict, func_lookup: dict) ->
|
|
258
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
253
259
|
try:
|
|
254
260
|
return memodict[self]
|
|
255
261
|
except KeyError:
|
|
@@ -261,20 +267,25 @@ class _Func(_AbstractNode):
|
|
|
261
267
|
|
|
262
268
|
|
|
263
269
|
def _sympy_to_node(
|
|
264
|
-
expr:
|
|
270
|
+
expr: sympy_module.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
|
|
265
271
|
) -> _AbstractNode:
|
|
266
272
|
try:
|
|
267
273
|
return memodict[expr]
|
|
268
274
|
except KeyError:
|
|
269
|
-
if isinstance(expr,
|
|
275
|
+
if isinstance(expr, sympy_module.Symbol):
|
|
270
276
|
out = _Symbol(expr)
|
|
271
|
-
elif isinstance(expr,
|
|
277
|
+
elif isinstance(expr, sympy_module.Integer):
|
|
272
278
|
out = _Integer(expr, make_array)
|
|
273
|
-
elif isinstance(expr,
|
|
279
|
+
elif isinstance(expr, sympy_module.Float):
|
|
274
280
|
out = _Float(expr, make_array)
|
|
275
|
-
elif isinstance(expr,
|
|
281
|
+
elif isinstance(expr, sympy_module.Rational):
|
|
276
282
|
out = _Rational(expr, make_array)
|
|
277
|
-
elif expr in (
|
|
283
|
+
elif expr in (
|
|
284
|
+
sympy_module.E,
|
|
285
|
+
sympy_module.pi,
|
|
286
|
+
sympy_module.EulerGamma,
|
|
287
|
+
sympy_module.I,
|
|
288
|
+
):
|
|
278
289
|
out = _Constant(expr, make_array)
|
|
279
290
|
else:
|
|
280
291
|
out = _Func(expr, memodict, func_lookup, make_array)
|
|
@@ -288,7 +299,7 @@ def _is_node(x):
|
|
|
288
299
|
|
|
289
300
|
class SymbolicModule(eqx.Module):
|
|
290
301
|
nodes: PyTree
|
|
291
|
-
has_extra_funcs: bool = eqx.
|
|
302
|
+
has_extra_funcs: bool = eqx.field(static=True)
|
|
292
303
|
|
|
293
304
|
def __init__(
|
|
294
305
|
self,
|
|
@@ -310,7 +321,7 @@ class SymbolicModule(eqx.Module):
|
|
|
310
321
|
)
|
|
311
322
|
self.nodes = jtu.tree_map(_convert, expressions)
|
|
312
323
|
|
|
313
|
-
def sympy(self) ->
|
|
324
|
+
def sympy(self) -> sympy_module.Expr:
|
|
314
325
|
if self.has_extra_funcs:
|
|
315
326
|
raise NotImplementedError(
|
|
316
327
|
"SymbolicModule cannot be converted back to SymPy if `extra_funcs` "
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|