sympy2jax 0.0.6__tar.gz → 0.0.7__tar.gz

This diff shows the content of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sympy2jax
3
- Version: 0.0.6
3
+ Version: 0.0.7
4
4
  Summary: Turn SymPy expressions into trainable JAX expressions.
5
5
  Project-URL: repository, https://github.com/google/sympy2jax
6
6
  Author-email: Patrick Kidger <contact@kidger.site>
@@ -260,7 +260,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
260
260
  ## Documentation
261
261
 
262
262
  ```python
263
- sympytorch.SymbolicModule(expressions, extra_funcs=None, make_array=True)
263
+ sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
264
264
  ```
265
265
 
266
266
  Where:
@@ -284,6 +284,7 @@ Instances have a `.sympy()` method that translates the module back into a PyTree
284
284
  [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
285
285
  [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
286
286
  [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
287
+ [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
287
288
 
288
289
  **Scientific computing**
289
290
  [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
@@ -37,7 +37,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
37
37
  ## Documentation
38
38
 
39
39
  ```python
40
- sympytorch.SymbolicModule(expressions, extra_funcs=None, make_array=True)
40
+ sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
41
41
  ```
42
42
 
43
43
  Where:
@@ -61,6 +61,7 @@ Instances have a `.sympy()` method that translates the module back into a PyTree
61
61
  [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
62
62
  [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
63
63
  [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
64
+ [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
64
65
 
65
66
  **Scientific computing**
66
67
  [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "sympy2jax"
3
- version = "0.0.6"
3
+ version = "0.0.7"
4
4
  description = "Turn SymPy expressions into trainable JAX expressions."
5
5
  readme = "README.md"
6
6
  requires-python ="~=3.9"
@@ -22,6 +22,7 @@ import equinox as eqx
22
22
  import jax
23
23
  import jax.numpy as jnp
24
24
  import jax.scipy as jsp
25
+ import jax.tree_util as jtu
25
26
  import sympy
26
27
 
27
28
 
@@ -307,7 +308,7 @@ class SymbolicModule(eqx.Module):
307
308
  func_lookup=lookup,
308
309
  make_array=make_array,
309
310
  )
310
- self.nodes = jax.tree_map(_convert, expressions)
311
+ self.nodes = jtu.tree_map(_convert, expressions)
311
312
 
312
313
  def sympy(self) -> sympy.Expr:
313
314
  if self.has_extra_funcs:
@@ -316,10 +317,10 @@ class SymbolicModule(eqx.Module):
316
317
  "is passed."
317
318
  )
318
319
  memodict = dict()
319
- return jax.tree_map(
320
+ return jtu.tree_map(
320
321
  lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node
321
322
  )
322
323
 
323
324
  def __call__(self, **symbols):
324
325
  memodict = symbols
325
- return jax.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
326
+ return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
File without changes
File without changes