sympy2jax 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sympy2jax/sympy_module.py +5 -4
- {sympy2jax-0.0.5.dist-info → sympy2jax-0.0.7.dist-info}/METADATA +20 -28
- sympy2jax-0.0.7.dist-info/RECORD +6 -0
- {sympy2jax-0.0.5.dist-info → sympy2jax-0.0.7.dist-info}/WHEEL +1 -1
- sympy2jax-0.0.5.dist-info/RECORD +0 -6
- {sympy2jax-0.0.5.dist-info → sympy2jax-0.0.7.dist-info}/licenses/LICENSE +0 -0
sympy2jax/sympy_module.py
CHANGED
|
@@ -22,6 +22,7 @@ import equinox as eqx
|
|
|
22
22
|
import jax
|
|
23
23
|
import jax.numpy as jnp
|
|
24
24
|
import jax.scipy as jsp
|
|
25
|
+
import jax.tree_util as jtu
|
|
25
26
|
import sympy
|
|
26
27
|
|
|
27
28
|
|
|
@@ -129,7 +130,7 @@ class _Symbol(_AbstractNode):
|
|
|
129
130
|
_name: str
|
|
130
131
|
|
|
131
132
|
def __init__(self, expr: sympy.Expr):
|
|
132
|
-
self._name = expr.name # pyright: ignore
|
|
133
|
+
self._name = str(expr.name) # pyright: ignore
|
|
133
134
|
|
|
134
135
|
def __call__(self, memodict: dict):
|
|
135
136
|
try:
|
|
@@ -307,7 +308,7 @@ class SymbolicModule(eqx.Module):
|
|
|
307
308
|
func_lookup=lookup,
|
|
308
309
|
make_array=make_array,
|
|
309
310
|
)
|
|
310
|
-
self.nodes =
|
|
311
|
+
self.nodes = jtu.tree_map(_convert, expressions)
|
|
311
312
|
|
|
312
313
|
def sympy(self) -> sympy.Expr:
|
|
313
314
|
if self.has_extra_funcs:
|
|
@@ -316,10 +317,10 @@ class SymbolicModule(eqx.Module):
|
|
|
316
317
|
"is passed."
|
|
317
318
|
)
|
|
318
319
|
memodict = dict()
|
|
319
|
-
return
|
|
320
|
+
return jtu.tree_map(
|
|
320
321
|
lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node
|
|
321
322
|
)
|
|
322
323
|
|
|
323
324
|
def __call__(self, **symbols):
|
|
324
325
|
memodict = symbols
|
|
325
|
-
return
|
|
326
|
+
return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: sympy2jax
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.7
|
|
4
4
|
Summary: Turn SymPy expressions into trainable JAX expressions.
|
|
5
5
|
Project-URL: repository, https://github.com/google/sympy2jax
|
|
6
6
|
Author-email: Patrick Kidger <contact@kidger.site>
|
|
@@ -260,7 +260,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
|
|
|
260
260
|
## Documentation
|
|
261
261
|
|
|
262
262
|
```python
|
|
263
|
-
|
|
263
|
+
sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
|
|
264
264
|
```
|
|
265
265
|
|
|
266
266
|
Where:
|
|
@@ -274,32 +274,24 @@ Instances have a `.sympy()` method that translates the module back into a PyTree
|
|
|
274
274
|
|
|
275
275
|
(That's literally the entire documentation, it's super easy.)
|
|
276
276
|
|
|
277
|
-
##
|
|
277
|
+
## See also: other libraries in the JAX ecosystem
|
|
278
278
|
|
|
279
|
-
|
|
279
|
+
**Always useful**
|
|
280
|
+
[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
|
|
281
|
+
[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.
|
|
280
282
|
|
|
281
|
-
|
|
283
|
+
**Deep learning**
|
|
284
|
+
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
|
|
285
|
+
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
|
|
286
|
+
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
|
|
287
|
+
[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
|
|
282
288
|
|
|
283
|
-
|
|
289
|
+
**Scientific computing**
|
|
290
|
+
[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
|
|
291
|
+
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
|
|
292
|
+
[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
|
|
293
|
+
[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
|
|
294
|
+
[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
|
|
284
295
|
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
|
|
288
|
-
|
|
289
|
-
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
|
|
290
|
-
|
|
291
|
-
[Lineax](https://github.com/google/lineax): linear solvers.
|
|
292
|
-
|
|
293
|
-
[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
|
|
294
|
-
|
|
295
|
-
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
|
|
296
|
-
|
|
297
|
-
[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
|
|
298
|
-
|
|
299
|
-
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
|
|
300
|
-
|
|
301
|
-
[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
|
|
302
|
-
|
|
303
|
-
### Disclaimer
|
|
304
|
-
|
|
305
|
-
This is not an official Google product.
|
|
296
|
+
**Awesome JAX**
|
|
297
|
+
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
sympy2jax/__init__.py,sha256=KqjLNlIiDATWsPhjPLpG-Hud4HMMuq_DP6sISDWVYQg,720
|
|
2
|
+
sympy2jax/sympy_module.py,sha256=R1RjJ57jvLcSqLJZNu20BE3XdYHMi2UuNGmZslifYyA,9411
|
|
3
|
+
sympy2jax-0.0.7.dist-info/METADATA,sha256=lkRfJ18NFu93bdodMkm7DrqLYYLayJh-qwce0rR5C8o,16659
|
|
4
|
+
sympy2jax-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
+
sympy2jax-0.0.7.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
6
|
+
sympy2jax-0.0.7.dist-info/RECORD,,
|
sympy2jax-0.0.5.dist-info/RECORD
DELETED
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
sympy2jax/__init__.py,sha256=KqjLNlIiDATWsPhjPLpG-Hud4HMMuq_DP6sISDWVYQg,720
|
|
2
|
-
sympy2jax/sympy_module.py,sha256=t1Z8yTRCDgDY_CcG8TF19aAJivKSdz__Ej_OdQ0272g,9378
|
|
3
|
-
sympy2jax-0.0.5.dist-info/METADATA,sha256=iJitCmILWFrGKzgHMNxSIUKeX20FrD4alddTdNNHrKc,16459
|
|
4
|
-
sympy2jax-0.0.5.dist-info/WHEEL,sha256=9QBuHhg6FNW7lppboF2vKVbCGTVzsFykgRQjjlajrhA,87
|
|
5
|
-
sympy2jax-0.0.5.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
6
|
-
sympy2jax-0.0.5.dist-info/RECORD,,
|
|
File without changes
|