sympy2jax 0.0.6__tar.gz → 0.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sympy2jax-0.0.6 → sympy2jax-0.0.7}/PKG-INFO +3 -2
- {sympy2jax-0.0.6 → sympy2jax-0.0.7}/README.md +2 -1
- {sympy2jax-0.0.6 → sympy2jax-0.0.7}/pyproject.toml +1 -1
- {sympy2jax-0.0.6 → sympy2jax-0.0.7}/sympy2jax/sympy_module.py +4 -3
- {sympy2jax-0.0.6 → sympy2jax-0.0.7}/.gitignore +0 -0
- {sympy2jax-0.0.6 → sympy2jax-0.0.7}/LICENSE +0 -0
- {sympy2jax-0.0.6 → sympy2jax-0.0.7}/sympy2jax/__init__.py +0 -0
{sympy2jax-0.0.6 → sympy2jax-0.0.7}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sympy2jax
-Version: 0.0.6
+Version: 0.0.7
 Summary: Turn SymPy expressions into trainable JAX expressions.
 Project-URL: repository, https://github.com/google/sympy2jax
 Author-email: Patrick Kidger <contact@kidger.site>
@@ -260,7 +260,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
 ## Documentation
 
 ```python
-
+sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
 ```
 
 Where:
@@ -284,6 +284,7 @@ Instances have a `.sympy()` method that translates the module back into a PyTree
 [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
 [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
 [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
+[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
 
 **Scientific computing**
 [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
{sympy2jax-0.0.6 → sympy2jax-0.0.7}/README.md

@@ -37,7 +37,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
 ## Documentation
 
 ```python
-
+sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
 ```
 
 Where:
@@ -61,6 +61,7 @@ Instances have a `.sympy()` method that translates the module back into a PyTree
 [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
 [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
 [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
+[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
 
 **Scientific computing**
 [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
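For context, the line added in both documentation hunks above is the package's public constructor. The sketch below shows typical usage, following the README example quoted in the hunk headers (`params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.`); the symbol and variable names are illustrative, not part of this diff.

```python
import jax.numpy as jnp
import jax.tree_util as jtu
import sympy
import sympy2jax

# A PyTree of SymPy expressions; the float constants 1.0 and 2.0 become leaves.
x_sym = sympy.symbols("x_sym")
expressions = [1.0 * sympy.cos(x_sym), 2.0 * sympy.sin(x_sym)]

# extra_funcs=None and make_array=True are the defaults shown in the documented signature.
mod = sympy2jax.SymbolicModule(expressions)

out = mod(x_sym=jnp.zeros(3))  # evaluate: symbols are passed by name as keyword arguments
params = jtu.tree_leaves(mod)  # the 1.0 and 2.0 constants show up as trainable parameters
back = mod.sympy()             # the .sympy() method mentioned in the hunk header
```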
{sympy2jax-0.0.6 → sympy2jax-0.0.7}/sympy2jax/sympy_module.py

@@ -22,6 +22,7 @@ import equinox as eqx
 import jax
 import jax.numpy as jnp
 import jax.scipy as jsp
+import jax.tree_util as jtu
 import sympy
 
 
@@ -307,7 +308,7 @@ class SymbolicModule(eqx.Module):
             func_lookup=lookup,
             make_array=make_array,
         )
-        self.nodes =
+        self.nodes = jtu.tree_map(_convert, expressions)
 
     def sympy(self) -> sympy.Expr:
         if self.has_extra_funcs:
@@ -316,10 +317,10 @@ class SymbolicModule(eqx.Module):
                 "is passed."
             )
         memodict = dict()
-        return
+        return jtu.tree_map(
             lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node
         )
 
     def __call__(self, **symbols):
         memodict = symbols
-        return
+        return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
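The source changes above route PyTree traversal through `jax.tree_util.tree_map`, applying a per-node function across `expressions`/`self.nodes` and using `is_leaf` to stop recursion at individual nodes. A minimal, self-contained sketch of that pattern (independent of sympy2jax's internal `_convert`, `_is_node`, and `_reverse_lookup` helpers, whose definitions are not part of this diff):

```python
import jax.tree_util as jtu
import sympy

# A PyTree (dict and list containers) whose leaves are SymPy expressions.
x, y = sympy.symbols("x y")
exprs = {"a": x + 1, "b": [y**2]}

# tree_map applies the function to every leaf; is_leaf controls where the
# traversal stops recursing, mirroring the is_leaf=_is_node usage above.
as_strings = jtu.tree_map(str, exprs, is_leaf=lambda e: isinstance(e, sympy.Expr))
print(as_strings)  # {'a': 'x + 1', 'b': ['y**2']}
```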
{sympy2jax-0.0.6 → sympy2jax-0.0.7}/.gitignore: file without changes
{sympy2jax-0.0.6 → sympy2jax-0.0.7}/LICENSE: file without changes
{sympy2jax-0.0.6 → sympy2jax-0.0.7}/sympy2jax/__init__.py: file without changes
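Finally, the `extra_funcs` argument in the documented signature, and the `has_extra_funcs` guard touched in the `sympy()` hunk, cover custom SymPy functions. A hedged sketch of that hook, assuming the usual pattern of mapping a SymPy function class to a JAX callable (the `Softplus` name and the `jax.nn.softplus` choice are illustrative, not taken from this diff):

```python
import jax.nn
import sympy
import sympy2jax


class Softplus(sympy.Function):
    # Hypothetical custom SymPy function with no built-in translation rule.
    pass


x = sympy.symbols("x")
mod = sympy2jax.SymbolicModule(
    Softplus(2.0 * x),
    extra_funcs={Softplus: jax.nn.softplus},  # tell sympy2jax how to evaluate Softplus
)
print(mod(x=1.0))  # evaluates roughly as softplus(2.0 * 1.0)

# Note: per the sympy() hunk above, converting back to SymPy is not supported
# when extra_funcs is passed.
```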