sympy2jax 0.0.5.tar.gz → 0.0.7.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.4
 Name: sympy2jax
-Version: 0.0.5
+Version: 0.0.7
 Summary: Turn SymPy expressions into trainable JAX expressions.
 Project-URL: repository, https://github.com/google/sympy2jax
 Author-email: Patrick Kidger <contact@kidger.site>
@@ -260,7 +260,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
 ## Documentation

 ```python
-sympytorch.SymbolicModule(expressions, extra_funcs=None, make_array=True)
+sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
 ```

 Where:
@@ -274,32 +274,24 @@ Instances have a `.sympy()` method that translates the module back into a PyTree

 (That's literally the entire documentation, it's super easy.)

-## Finally
+## See also: other libraries in the JAX ecosystem

-### See also: other libraries in the JAX ecosystem
+**Always useful**
+[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
+[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.

-[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
+**Deep learning**
+[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
+[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
+[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
+[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.

-[Equinox](https://github.com/patrick-kidger/equinox): neural networks.
+**Scientific computing**
+[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
+[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
+[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
+[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
+[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

-[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
-
-[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
-
-[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
-
-[Lineax](https://github.com/google/lineax): linear solvers.
-
-[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
-
-[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
-
-[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
-
-[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
-
-[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
-
-### Disclaimer
-
-This is not an official Google product.
+**Awesome JAX**
+[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
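The fixed documentation line corrects a leftover `sympytorch` namespace to `sympy2jax`. For context, a minimal sketch of the corrected call, following the README quickstart quoted in the hunk context above (`params = jax.tree_leaves(mod)`); the specific expression and shapes here are illustrative assumptions, not taken from the diff:

```python
import jax.numpy as jnp
import jax.tree_util as jtu
import sympy
import sympy2jax

# Illustrative expression: the 1.0 and 2.0 constants become trainable
# parameters of the module (leaves of its PyTree).
x = sympy.symbols("x")
expr = 1.0 * sympy.cos(x) + 2.0 * sympy.sin(x)

mod = sympy2jax.SymbolicModule(expr)  # the documented signature, as fixed above
out = mod(x=jnp.zeros(3))             # symbols are passed as keyword arguments
params = jtu.tree_leaves(mod)         # 1.0 and 2.0 are parameters
```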
@@ -37,7 +37,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
 ## Documentation

 ```python
-sympytorch.SymbolicModule(expressions, extra_funcs=None, make_array=True)
+sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
 ```

 Where:
@@ -51,32 +51,24 @@ Instances have a `.sympy()` method that translates the module back into a PyTree

 (That's literally the entire documentation, it's super easy.)

-## Finally
+## See also: other libraries in the JAX ecosystem

-### See also: other libraries in the JAX ecosystem
+**Always useful**
+[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
+[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.

-[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
+**Deep learning**
+[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
+[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
+[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
+[paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.

-[Equinox](https://github.com/patrick-kidger/equinox): neural networks.
+**Scientific computing**
+[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
+[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
+[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
+[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
+[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

-[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
-
-[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
-
-[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
-
-[Lineax](https://github.com/google/lineax): linear solvers.
-
-[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
-
-[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
-
-[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
-
-[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
-
-[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
-
-### Disclaimer
-
-This is not an official Google product.
+**Awesome JAX**
+[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
@@ -1,6 +1,6 @@
 [project]
 name = "sympy2jax"
-version = "0.0.5"
+version = "0.0.7"
 description = "Turn SymPy expressions into trainable JAX expressions."
 readme = "README.md"
 requires-python ="~=3.9"
@@ -22,6 +22,7 @@ import equinox as eqx
 import jax
 import jax.numpy as jnp
 import jax.scipy as jsp
+import jax.tree_util as jtu
 import sympy


@@ -129,7 +130,7 @@ class _Symbol(_AbstractNode):
     _name: str

     def __init__(self, expr: sympy.Expr):
-        self._name = expr.name  # pyright: ignore
+        self._name = str(expr.name)  # pyright: ignore

     def __call__(self, memodict: dict):
        try:
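The `str(...)` wrapper presumably guarantees that `_name` matches its `str` annotation even when a `Symbol` subclass supplies a name-like object rather than a plain string; for an ordinary `sympy.Symbol` it is a no-op. A tiny sketch (the motivation stated in the comments is our assumption):

```python
import sympy

x = sympy.Symbol("x")
print(type(x.name))  # <class 'str'> -- already a plain str for a Symbol
print(str(x.name))   # "x"; str() normalises name-like objects that a
                     # Symbol subclass might return (assumed intent)
```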
@@ -307,7 +308,7 @@ class SymbolicModule(eqx.Module):
             func_lookup=lookup,
             make_array=make_array,
         )
-        self.nodes = jax.tree_map(_convert, expressions)
+        self.nodes = jtu.tree_map(_convert, expressions)

     def sympy(self) -> sympy.Expr:
         if self.has_extra_funcs:
@@ -316,10 +317,10 @@ class SymbolicModule(eqx.Module):
                 "is passed."
             )
         memodict = dict()
-        return jax.tree_map(
+        return jtu.tree_map(
             lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node
         )

     def __call__(self, **symbols):
         memodict = symbols
-        return jax.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
+        return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
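The `jax.tree_map` → `jtu.tree_map` swaps throughout these hunks track JAX's deprecation of the top-level `jax.tree_map` alias in favour of `jax.tree_util.tree_map` (the likely motivation for this release). A minimal sketch of the call pattern used in `__call__` above, with toy stand-in nodes — the real node classes and the `_is_node` predicate are internal to the library, so plain lambdas and `callable` substitute for them here:

```python
import jax.tree_util as jtu

# Toy stand-ins for the module's node callables.  is_leaf marks the
# node objects themselves as leaves, so tree_map applies the function
# to each node rather than recursing into it (the real nodes are
# eqx.Module instances, which are themselves PyTrees).
nodes = {"f": lambda d: d["x"] + 1.0, "g": lambda d: d["x"] * 2.0}
memodict = {"x": 3.0}

out = jtu.tree_map(lambda n: n(memodict), nodes, is_leaf=callable)
print(out)  # {'f': 4.0, 'g': 6.0}
```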