sympy2jax 0.0.6__tar.gz → 0.0.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sympy2jax-0.0.6 → sympy2jax-0.0.8}/PKG-INFO +3 -2
- {sympy2jax-0.0.6 → sympy2jax-0.0.8}/README.md +2 -1
- {sympy2jax-0.0.6 → sympy2jax-0.0.8}/pyproject.toml +3 -3
- {sympy2jax-0.0.6 → sympy2jax-0.0.8}/sympy2jax/sympy_module.py +95 -83
- {sympy2jax-0.0.6 → sympy2jax-0.0.8}/.gitignore +0 -0
- {sympy2jax-0.0.6 → sympy2jax-0.0.8}/LICENSE +0 -0
- {sympy2jax-0.0.6 → sympy2jax-0.0.8}/sympy2jax/__init__.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sympy2jax
|
|
3
|
-
Version: 0.0.6
|
|
3
|
+
Version: 0.0.8
|
|
4
4
|
Summary: Turn SymPy expressions into trainable JAX expressions.
|
|
5
5
|
Project-URL: repository, https://github.com/google/sympy2jax
|
|
6
6
|
Author-email: Patrick Kidger <contact@kidger.site>
|
|
@@ -260,7 +260,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
|
|
|
260
260
|
## Documentation
|
|
261
261
|
|
|
262
262
|
```python
|
|
263
|
-
|
|
263
|
+
sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
|
|
264
264
|
```
|
|
265
265
|
|
|
266
266
|
Where:
|
|
@@ -284,6 +284,7 @@ Instances have a `.sympy()` method that translates the module back into a PyTree
|
|
|
284
284
|
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
|
|
285
285
|
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
|
|
286
286
|
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
|
|
287
|
+
[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
|
|
287
288
|
|
|
288
289
|
**Scientific computing**
|
|
289
290
|
[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
|
|
@@ -37,7 +37,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
|
|
|
37
37
|
## Documentation
|
|
38
38
|
|
|
39
39
|
```python
|
|
40
|
-
|
|
40
|
+
sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
|
|
41
41
|
```
|
|
42
42
|
|
|
43
43
|
Where:
|
|
@@ -61,6 +61,7 @@ Instances have a `.sympy()` method that translates the module back into a PyTree
|
|
|
61
61
|
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
|
|
62
62
|
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
|
|
63
63
|
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
|
|
64
|
+
[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
|
|
64
65
|
|
|
65
66
|
**Scientific computing**
|
|
66
67
|
[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "sympy2jax"
|
|
3
|
-
version = "0.0.6"
|
|
3
|
+
version = "0.0.8"
|
|
4
4
|
description = "Turn SymPy expressions into trainable JAX expressions."
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python ="~=3.9"
|
|
@@ -31,13 +31,13 @@ include = ["sympy2jax/*"]
|
|
|
31
31
|
[tool.pytest.ini_options]
|
|
32
32
|
addopts = "--jaxtyping-packages=symyp2jax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"
|
|
33
33
|
|
|
34
|
-
[tool.ruff]
|
|
34
|
+
[tool.ruff.lint]
|
|
35
35
|
select = ["E", "F", "I001"]
|
|
36
36
|
ignore = ["E402", "E721", "E731", "E741", "F722"]
|
|
37
37
|
ignore-init-module-imports = true
|
|
38
38
|
fixable = ["I001", "F401"]
|
|
39
39
|
|
|
40
|
-
[tool.ruff.isort]
|
|
40
|
+
[tool.ruff.lint.isort]
|
|
41
41
|
combine-as-imports = true
|
|
42
42
|
lines-after-imports = 2
|
|
43
43
|
extra-standard-library = ["typing_extensions"]
|
|
@@ -22,13 +22,14 @@ import equinox as eqx
|
|
|
22
22
|
import jax
|
|
23
23
|
import jax.numpy as jnp
|
|
24
24
|
import jax.scipy as jsp
|
|
25
|
-
import sympy
|
|
25
|
+
import jax.tree_util as jtu
|
|
26
|
+
import sympy as sympy_module
|
|
26
27
|
|
|
27
28
|
|
|
28
29
|
PyTree = Any
|
|
29
30
|
|
|
30
|
-
concatenate: Callable = sympy.Function("concatenate") # pyright: ignore
|
|
31
|
-
stack: Callable = sympy.Function("stack") # pyright: ignore
|
|
31
|
+
concatenate: Callable = sympy_module.Function("concatenate") # pyright: ignore
|
|
32
|
+
stack: Callable = sympy_module.Function("stack") # pyright: ignore
|
|
32
33
|
|
|
33
34
|
|
|
34
35
|
def _reduce(fn):
|
|
@@ -48,56 +49,56 @@ def _single_args(fn):
|
|
|
48
49
|
_lookup = {
|
|
49
50
|
concatenate: _single_args(jnp.concatenate),
|
|
50
51
|
stack: _single_args(jnp.stack),
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
52
|
+
sympy_module.Mul: _reduce(jnp.multiply),
|
|
53
|
+
sympy_module.Add: _reduce(jnp.add),
|
|
54
|
+
sympy_module.div: jnp.divide,
|
|
55
|
+
sympy_module.Abs: jnp.abs,
|
|
56
|
+
sympy_module.sign: jnp.sign,
|
|
57
|
+
sympy_module.ceiling: jnp.ceil,
|
|
58
|
+
sympy_module.floor: jnp.floor,
|
|
59
|
+
sympy_module.log: jnp.log,
|
|
60
|
+
sympy_module.exp: jnp.exp,
|
|
61
|
+
sympy_module.sqrt: jnp.sqrt,
|
|
62
|
+
sympy_module.cos: jnp.cos,
|
|
63
|
+
sympy_module.acos: jnp.arccos,
|
|
64
|
+
sympy_module.sin: jnp.sin,
|
|
65
|
+
sympy_module.asin: jnp.arcsin,
|
|
66
|
+
sympy_module.tan: jnp.tan,
|
|
67
|
+
sympy_module.atan: jnp.arctan,
|
|
68
|
+
sympy_module.atan2: jnp.arctan2,
|
|
69
|
+
sympy_module.cosh: jnp.cosh,
|
|
70
|
+
sympy_module.acosh: jnp.arccosh,
|
|
71
|
+
sympy_module.sinh: jnp.sinh,
|
|
72
|
+
sympy_module.asinh: jnp.arcsinh,
|
|
73
|
+
sympy_module.tanh: jnp.tanh,
|
|
74
|
+
sympy_module.atanh: jnp.arctanh,
|
|
75
|
+
sympy_module.Pow: jnp.power,
|
|
76
|
+
sympy_module.re: jnp.real,
|
|
77
|
+
sympy_module.im: jnp.imag,
|
|
78
|
+
sympy_module.arg: jnp.angle,
|
|
79
|
+
sympy_module.erf: jsp.special.erf,
|
|
80
|
+
sympy_module.Eq: jnp.equal,
|
|
81
|
+
sympy_module.Ne: jnp.not_equal,
|
|
82
|
+
sympy_module.StrictGreaterThan: jnp.greater,
|
|
83
|
+
sympy_module.StrictLessThan: jnp.less,
|
|
84
|
+
sympy_module.LessThan: jnp.less_equal,
|
|
85
|
+
sympy_module.GreaterThan: jnp.greater_equal,
|
|
86
|
+
sympy_module.And: jnp.logical_and,
|
|
87
|
+
sympy_module.Or: jnp.logical_or,
|
|
88
|
+
sympy_module.Not: jnp.logical_not,
|
|
89
|
+
sympy_module.Xor: jnp.logical_xor,
|
|
90
|
+
sympy_module.Max: _reduce(jnp.maximum),
|
|
91
|
+
sympy_module.Min: _reduce(jnp.minimum),
|
|
92
|
+
sympy_module.MatAdd: _reduce(jnp.add),
|
|
93
|
+
sympy_module.Trace: jnp.trace,
|
|
94
|
+
sympy_module.Determinant: jnp.linalg.det,
|
|
94
95
|
}
|
|
95
96
|
|
|
96
97
|
_constant_lookup = {
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
98
|
+
sympy_module.E: jnp.e,
|
|
99
|
+
sympy_module.pi: jnp.pi,
|
|
100
|
+
sympy_module.EulerGamma: jnp.euler_gamma,
|
|
101
|
+
sympy_module.I: 1j,
|
|
101
102
|
}
|
|
102
103
|
|
|
103
104
|
_reverse_lookup = {v: k for k, v in _lookup.items()}
|
|
@@ -117,7 +118,7 @@ class _AbstractNode(eqx.Module):
|
|
|
117
118
|
...
|
|
118
119
|
|
|
119
120
|
@abc.abstractmethod
|
|
120
|
-
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
|
|
121
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
121
122
|
...
|
|
122
123
|
|
|
123
124
|
# Comparisons based on identity
|
|
@@ -128,7 +129,7 @@ class _AbstractNode(eqx.Module):
|
|
|
128
129
|
class _Symbol(_AbstractNode):
|
|
129
130
|
_name: str
|
|
130
131
|
|
|
131
|
-
def __init__(self, expr: sympy.Expr):
|
|
132
|
+
def __init__(self, expr: sympy_module.Expr):
|
|
132
133
|
self._name = str(expr.name) # pyright: ignore
|
|
133
134
|
|
|
134
135
|
def __call__(self, memodict: dict):
|
|
@@ -137,9 +138,9 @@ class _Symbol(_AbstractNode):
|
|
|
137
138
|
except KeyError as e:
|
|
138
139
|
raise KeyError(f"Missing input for symbol {self._name}") from e
|
|
139
140
|
|
|
140
|
-
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
|
|
141
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
141
142
|
# memodict not needed as sympy deduplicates internally
|
|
142
|
-
return sympy.Symbol(self._name)
|
|
143
|
+
return sympy_module.Symbol(self._name)
|
|
143
144
|
|
|
144
145
|
|
|
145
146
|
def _maybe_array(val, make_array):
|
|
@@ -152,39 +153,39 @@ def _maybe_array(val, make_array):
|
|
|
152
153
|
class _Integer(_AbstractNode):
|
|
153
154
|
_value: jax.typing.ArrayLike
|
|
154
155
|
|
|
155
|
-
def __init__(self, expr: sympy.Expr, make_array: bool):
|
|
156
|
-
assert isinstance(expr, sympy.Integer)
|
|
156
|
+
def __init__(self, expr: sympy_module.Expr, make_array: bool):
|
|
157
|
+
assert isinstance(expr, sympy_module.Integer)
|
|
157
158
|
self._value = _maybe_array(int(expr), make_array)
|
|
158
159
|
|
|
159
160
|
def __call__(self, memodict: dict):
|
|
160
161
|
return self._value
|
|
161
162
|
|
|
162
|
-
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
|
|
163
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
163
164
|
# memodict not needed as sympy deduplicates internally
|
|
164
|
-
return sympy.Integer(_item(self._value))
|
|
165
|
+
return sympy_module.Integer(_item(self._value))
|
|
165
166
|
|
|
166
167
|
|
|
167
168
|
class _Float(_AbstractNode):
|
|
168
169
|
_value: jax.typing.ArrayLike
|
|
169
170
|
|
|
170
|
-
def __init__(self, expr: sympy.Expr, make_array: bool):
|
|
171
|
-
assert isinstance(expr, sympy.Float)
|
|
171
|
+
def __init__(self, expr: sympy_module.Expr, make_array: bool):
|
|
172
|
+
assert isinstance(expr, sympy_module.Float)
|
|
172
173
|
self._value = _maybe_array(float(expr), make_array)
|
|
173
174
|
|
|
174
175
|
def __call__(self, memodict: dict):
|
|
175
176
|
return self._value
|
|
176
177
|
|
|
177
|
-
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
|
|
178
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
178
179
|
# memodict not needed as sympy deduplicates internally
|
|
179
|
-
return sympy.Float(_item(self._value))
|
|
180
|
+
return sympy_module.Float(_item(self._value))
|
|
180
181
|
|
|
181
182
|
|
|
182
183
|
class _Rational(_AbstractNode):
|
|
183
184
|
_numerator: jax.typing.ArrayLike
|
|
184
185
|
_denominator: jax.typing.ArrayLike
|
|
185
186
|
|
|
186
|
-
def __init__(self, expr: sympy.Expr, make_array: bool):
|
|
187
|
-
assert isinstance(expr, sympy.Rational)
|
|
187
|
+
def __init__(self, expr: sympy_module.Expr, make_array: bool):
|
|
188
|
+
assert isinstance(expr, sympy_module.Rational)
|
|
188
189
|
numerator = expr.numerator
|
|
189
190
|
denominator = expr.denominator
|
|
190
191
|
if callable(numerator):
|
|
@@ -198,18 +199,18 @@ class _Rational(_AbstractNode):
|
|
|
198
199
|
def __call__(self, memodict: dict):
|
|
199
200
|
return self._numerator / self._denominator
|
|
200
201
|
|
|
201
|
-
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
|
|
202
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
202
203
|
# memodict not needed as sympy deduplicates internally
|
|
203
|
-
return sympy.Integer(_item(self._numerator)) / sympy.Integer(
|
|
204
|
+
return sympy_module.Integer(_item(self._numerator)) / sympy_module.Integer(
|
|
204
205
|
_item(self._denominator)
|
|
205
206
|
)
|
|
206
207
|
|
|
207
208
|
|
|
208
209
|
class _Constant(_AbstractNode):
|
|
209
210
|
_value: jnp.ndarray
|
|
210
|
-
_expr: sympy.Expr
|
|
211
|
+
_expr: sympy_module.Expr
|
|
211
212
|
|
|
212
|
-
def __init__(self, expr: sympy.Expr, make_array: bool):
|
|
213
|
+
def __init__(self, expr: sympy_module.Expr, make_array: bool):
|
|
213
214
|
assert expr in _constant_lookup
|
|
214
215
|
self._value = _maybe_array(_constant_lookup[expr], make_array)
|
|
215
216
|
self._expr = expr
|
|
@@ -217,7 +218,7 @@ class _Constant(_AbstractNode):
|
|
|
217
218
|
def __call__(self, memodict: dict):
|
|
218
219
|
return self._value
|
|
219
220
|
|
|
220
|
-
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
|
|
221
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
221
222
|
return self._expr
|
|
222
223
|
|
|
223
224
|
|
|
@@ -226,14 +227,20 @@ class _Func(_AbstractNode):
|
|
|
226
227
|
_args: list
|
|
227
228
|
|
|
228
229
|
def __init__(
|
|
229
|
-
self,
|
|
230
|
+
self,
|
|
231
|
+
expr: sympy_module.Expr,
|
|
232
|
+
memodict: dict,
|
|
233
|
+
func_lookup: Mapping,
|
|
234
|
+
make_array: bool,
|
|
230
235
|
):
|
|
231
236
|
try:
|
|
232
237
|
self._func = func_lookup[expr.func]
|
|
233
238
|
except KeyError as e:
|
|
234
239
|
raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
|
|
235
240
|
self._args = [
|
|
236
|
-
_sympy_to_node(
|
|
241
|
+
_sympy_to_node(
|
|
242
|
+
cast(sympy_module.Expr, arg), memodict, func_lookup, make_array
|
|
243
|
+
)
|
|
237
244
|
for arg in expr.args
|
|
238
245
|
]
|
|
239
246
|
|
|
@@ -248,7 +255,7 @@ class _Func(_AbstractNode):
|
|
|
248
255
|
args.append(arg_call)
|
|
249
256
|
return self._func(*args)
|
|
250
257
|
|
|
251
|
-
def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
|
|
258
|
+
def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
|
|
252
259
|
try:
|
|
253
260
|
return memodict[self]
|
|
254
261
|
except KeyError:
|
|
@@ -260,20 +267,25 @@ class _Func(_AbstractNode):
|
|
|
260
267
|
|
|
261
268
|
|
|
262
269
|
def _sympy_to_node(
|
|
263
|
-
expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
|
|
270
|
+
expr: sympy_module.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
|
|
264
271
|
) -> _AbstractNode:
|
|
265
272
|
try:
|
|
266
273
|
return memodict[expr]
|
|
267
274
|
except KeyError:
|
|
268
|
-
if isinstance(expr, sympy.Symbol):
|
|
275
|
+
if isinstance(expr, sympy_module.Symbol):
|
|
269
276
|
out = _Symbol(expr)
|
|
270
|
-
elif isinstance(expr, sympy.Integer):
|
|
277
|
+
elif isinstance(expr, sympy_module.Integer):
|
|
271
278
|
out = _Integer(expr, make_array)
|
|
272
|
-
elif isinstance(expr, sympy.Float):
|
|
279
|
+
elif isinstance(expr, sympy_module.Float):
|
|
273
280
|
out = _Float(expr, make_array)
|
|
274
|
-
elif isinstance(expr, sympy.Rational):
|
|
281
|
+
elif isinstance(expr, sympy_module.Rational):
|
|
275
282
|
out = _Rational(expr, make_array)
|
|
276
|
-
elif expr in (
|
|
283
|
+
elif expr in (
|
|
284
|
+
sympy_module.E,
|
|
285
|
+
sympy_module.pi,
|
|
286
|
+
sympy_module.EulerGamma,
|
|
287
|
+
sympy_module.I,
|
|
288
|
+
):
|
|
277
289
|
out = _Constant(expr, make_array)
|
|
278
290
|
else:
|
|
279
291
|
out = _Func(expr, memodict, func_lookup, make_array)
|
|
@@ -287,7 +299,7 @@ def _is_node(x):
|
|
|
287
299
|
|
|
288
300
|
class SymbolicModule(eqx.Module):
|
|
289
301
|
nodes: PyTree
|
|
290
|
-
has_extra_funcs: bool = eqx.
|
|
302
|
+
has_extra_funcs: bool = eqx.field(static=True)
|
|
291
303
|
|
|
292
304
|
def __init__(
|
|
293
305
|
self,
|
|
@@ -307,19 +319,19 @@ class SymbolicModule(eqx.Module):
|
|
|
307
319
|
func_lookup=lookup,
|
|
308
320
|
make_array=make_array,
|
|
309
321
|
)
|
|
310
|
-
self.nodes =
|
|
322
|
+
self.nodes = jtu.tree_map(_convert, expressions)
|
|
311
323
|
|
|
312
|
-
def sympy(self) -> sympy.Expr:
|
|
324
|
+
def sympy(self) -> sympy_module.Expr:
|
|
313
325
|
if self.has_extra_funcs:
|
|
314
326
|
raise NotImplementedError(
|
|
315
327
|
"SymbolicModule cannot be converted back to SymPy if `extra_funcs` "
|
|
316
328
|
"is passed."
|
|
317
329
|
)
|
|
318
330
|
memodict = dict()
|
|
319
|
-
return
|
|
331
|
+
return jtu.tree_map(
|
|
320
332
|
lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node
|
|
321
333
|
)
|
|
322
334
|
|
|
323
335
|
def __call__(self, **symbols):
|
|
324
336
|
memodict = symbols
|
|
325
|
-
return
|
|
337
|
+
return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|