sympy2jax 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sympy2jax/sympy_module.py CHANGED
@@ -22,13 +22,14 @@ import equinox as eqx
22
22
  import jax
23
23
  import jax.numpy as jnp
24
24
  import jax.scipy as jsp
25
- import sympy
25
+ import jax.tree_util as jtu
26
+ import sympy as sympy_module
26
27
 
27
28
 
28
29
  PyTree = Any
29
30
 
30
- concatenate: Callable = sympy.Function("concatenate") # pyright: ignore
31
- stack: Callable = sympy.Function("stack") # pyright: ignore
31
+ concatenate: Callable = sympy_module.Function("concatenate") # pyright: ignore
32
+ stack: Callable = sympy_module.Function("stack") # pyright: ignore
32
33
 
33
34
 
34
35
  def _reduce(fn):
@@ -48,56 +49,56 @@ def _single_args(fn):
48
49
  _lookup = {
49
50
  concatenate: _single_args(jnp.concatenate),
50
51
  stack: _single_args(jnp.stack),
51
- sympy.Mul: _reduce(jnp.multiply),
52
- sympy.Add: _reduce(jnp.add),
53
- sympy.div: jnp.divide,
54
- sympy.Abs: jnp.abs,
55
- sympy.sign: jnp.sign,
56
- sympy.ceiling: jnp.ceil,
57
- sympy.floor: jnp.floor,
58
- sympy.log: jnp.log,
59
- sympy.exp: jnp.exp,
60
- sympy.sqrt: jnp.sqrt,
61
- sympy.cos: jnp.cos,
62
- sympy.acos: jnp.arccos,
63
- sympy.sin: jnp.sin,
64
- sympy.asin: jnp.arcsin,
65
- sympy.tan: jnp.tan,
66
- sympy.atan: jnp.arctan,
67
- sympy.atan2: jnp.arctan2,
68
- sympy.cosh: jnp.cosh,
69
- sympy.acosh: jnp.arccosh,
70
- sympy.sinh: jnp.sinh,
71
- sympy.asinh: jnp.arcsinh,
72
- sympy.tanh: jnp.tanh,
73
- sympy.atanh: jnp.arctanh,
74
- sympy.Pow: jnp.power,
75
- sympy.re: jnp.real,
76
- sympy.im: jnp.imag,
77
- sympy.arg: jnp.angle,
78
- sympy.erf: jsp.special.erf,
79
- sympy.Eq: jnp.equal,
80
- sympy.Ne: jnp.not_equal,
81
- sympy.StrictGreaterThan: jnp.greater,
82
- sympy.StrictLessThan: jnp.less,
83
- sympy.LessThan: jnp.less_equal,
84
- sympy.GreaterThan: jnp.greater_equal,
85
- sympy.And: jnp.logical_and,
86
- sympy.Or: jnp.logical_or,
87
- sympy.Not: jnp.logical_not,
88
- sympy.Xor: jnp.logical_xor,
89
- sympy.Max: _reduce(jnp.maximum),
90
- sympy.Min: _reduce(jnp.minimum),
91
- sympy.MatAdd: _reduce(jnp.add),
92
- sympy.Trace: jnp.trace,
93
- sympy.Determinant: jnp.linalg.det,
52
+ sympy_module.Mul: _reduce(jnp.multiply),
53
+ sympy_module.Add: _reduce(jnp.add),
54
+ sympy_module.div: jnp.divide,
55
+ sympy_module.Abs: jnp.abs,
56
+ sympy_module.sign: jnp.sign,
57
+ sympy_module.ceiling: jnp.ceil,
58
+ sympy_module.floor: jnp.floor,
59
+ sympy_module.log: jnp.log,
60
+ sympy_module.exp: jnp.exp,
61
+ sympy_module.sqrt: jnp.sqrt,
62
+ sympy_module.cos: jnp.cos,
63
+ sympy_module.acos: jnp.arccos,
64
+ sympy_module.sin: jnp.sin,
65
+ sympy_module.asin: jnp.arcsin,
66
+ sympy_module.tan: jnp.tan,
67
+ sympy_module.atan: jnp.arctan,
68
+ sympy_module.atan2: jnp.arctan2,
69
+ sympy_module.cosh: jnp.cosh,
70
+ sympy_module.acosh: jnp.arccosh,
71
+ sympy_module.sinh: jnp.sinh,
72
+ sympy_module.asinh: jnp.arcsinh,
73
+ sympy_module.tanh: jnp.tanh,
74
+ sympy_module.atanh: jnp.arctanh,
75
+ sympy_module.Pow: jnp.power,
76
+ sympy_module.re: jnp.real,
77
+ sympy_module.im: jnp.imag,
78
+ sympy_module.arg: jnp.angle,
79
+ sympy_module.erf: jsp.special.erf,
80
+ sympy_module.Eq: jnp.equal,
81
+ sympy_module.Ne: jnp.not_equal,
82
+ sympy_module.StrictGreaterThan: jnp.greater,
83
+ sympy_module.StrictLessThan: jnp.less,
84
+ sympy_module.LessThan: jnp.less_equal,
85
+ sympy_module.GreaterThan: jnp.greater_equal,
86
+ sympy_module.And: jnp.logical_and,
87
+ sympy_module.Or: jnp.logical_or,
88
+ sympy_module.Not: jnp.logical_not,
89
+ sympy_module.Xor: jnp.logical_xor,
90
+ sympy_module.Max: _reduce(jnp.maximum),
91
+ sympy_module.Min: _reduce(jnp.minimum),
92
+ sympy_module.MatAdd: _reduce(jnp.add),
93
+ sympy_module.Trace: jnp.trace,
94
+ sympy_module.Determinant: jnp.linalg.det,
94
95
  }
95
96
 
96
97
  _constant_lookup = {
97
- sympy.E: jnp.e,
98
- sympy.pi: jnp.pi,
99
- sympy.EulerGamma: jnp.euler_gamma,
100
- sympy.I: 1j,
98
+ sympy_module.E: jnp.e,
99
+ sympy_module.pi: jnp.pi,
100
+ sympy_module.EulerGamma: jnp.euler_gamma,
101
+ sympy_module.I: 1j,
101
102
  }
102
103
 
103
104
  _reverse_lookup = {v: k for k, v in _lookup.items()}
@@ -117,7 +118,7 @@ class _AbstractNode(eqx.Module):
117
118
  ...
118
119
 
119
120
  @abc.abstractmethod
120
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
121
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
121
122
  ...
122
123
 
123
124
  # Comparisons based on identity
@@ -128,7 +129,7 @@ class _AbstractNode(eqx.Module):
128
129
  class _Symbol(_AbstractNode):
129
130
  _name: str
130
131
 
131
- def __init__(self, expr: sympy.Expr):
132
+ def __init__(self, expr: sympy_module.Expr):
132
133
  self._name = str(expr.name) # pyright: ignore
133
134
 
134
135
  def __call__(self, memodict: dict):
@@ -137,9 +138,9 @@ class _Symbol(_AbstractNode):
137
138
  except KeyError as e:
138
139
  raise KeyError(f"Missing input for symbol {self._name}") from e
139
140
 
140
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
141
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
141
142
  # memodict not needed as sympy deduplicates internally
142
- return sympy.Symbol(self._name)
143
+ return sympy_module.Symbol(self._name)
143
144
 
144
145
 
145
146
  def _maybe_array(val, make_array):
@@ -152,39 +153,39 @@ def _maybe_array(val, make_array):
152
153
  class _Integer(_AbstractNode):
153
154
  _value: jax.typing.ArrayLike
154
155
 
155
- def __init__(self, expr: sympy.Expr, make_array: bool):
156
- assert isinstance(expr, sympy.Integer)
156
+ def __init__(self, expr: sympy_module.Expr, make_array: bool):
157
+ assert isinstance(expr, sympy_module.Integer)
157
158
  self._value = _maybe_array(int(expr), make_array)
158
159
 
159
160
  def __call__(self, memodict: dict):
160
161
  return self._value
161
162
 
162
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
163
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
163
164
  # memodict not needed as sympy deduplicates internally
164
- return sympy.Integer(_item(self._value))
165
+ return sympy_module.Integer(_item(self._value))
165
166
 
166
167
 
167
168
  class _Float(_AbstractNode):
168
169
  _value: jax.typing.ArrayLike
169
170
 
170
- def __init__(self, expr: sympy.Expr, make_array: bool):
171
- assert isinstance(expr, sympy.Float)
171
+ def __init__(self, expr: sympy_module.Expr, make_array: bool):
172
+ assert isinstance(expr, sympy_module.Float)
172
173
  self._value = _maybe_array(float(expr), make_array)
173
174
 
174
175
  def __call__(self, memodict: dict):
175
176
  return self._value
176
177
 
177
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
178
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
178
179
  # memodict not needed as sympy deduplicates internally
179
- return sympy.Float(_item(self._value))
180
+ return sympy_module.Float(_item(self._value))
180
181
 
181
182
 
182
183
  class _Rational(_AbstractNode):
183
184
  _numerator: jax.typing.ArrayLike
184
185
  _denominator: jax.typing.ArrayLike
185
186
 
186
- def __init__(self, expr: sympy.Expr, make_array: bool):
187
- assert isinstance(expr, sympy.Rational)
187
+ def __init__(self, expr: sympy_module.Expr, make_array: bool):
188
+ assert isinstance(expr, sympy_module.Rational)
188
189
  numerator = expr.numerator
189
190
  denominator = expr.denominator
190
191
  if callable(numerator):
@@ -198,18 +199,18 @@ class _Rational(_AbstractNode):
198
199
  def __call__(self, memodict: dict):
199
200
  return self._numerator / self._denominator
200
201
 
201
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
202
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
202
203
  # memodict not needed as sympy deduplicates internally
203
- return sympy.Integer(_item(self._numerator)) / sympy.Integer(
204
+ return sympy_module.Integer(_item(self._numerator)) / sympy_module.Integer(
204
205
  _item(self._denominator)
205
206
  )
206
207
 
207
208
 
208
209
  class _Constant(_AbstractNode):
209
210
  _value: jnp.ndarray
210
- _expr: sympy.Expr
211
+ _expr: sympy_module.Expr
211
212
 
212
- def __init__(self, expr: sympy.Expr, make_array: bool):
213
+ def __init__(self, expr: sympy_module.Expr, make_array: bool):
213
214
  assert expr in _constant_lookup
214
215
  self._value = _maybe_array(_constant_lookup[expr], make_array)
215
216
  self._expr = expr
@@ -217,7 +218,7 @@ class _Constant(_AbstractNode):
217
218
  def __call__(self, memodict: dict):
218
219
  return self._value
219
220
 
220
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
221
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
221
222
  return self._expr
222
223
 
223
224
 
@@ -226,14 +227,20 @@ class _Func(_AbstractNode):
226
227
  _args: list
227
228
 
228
229
  def __init__(
229
- self, expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
230
+ self,
231
+ expr: sympy_module.Expr,
232
+ memodict: dict,
233
+ func_lookup: Mapping,
234
+ make_array: bool,
230
235
  ):
231
236
  try:
232
237
  self._func = func_lookup[expr.func]
233
238
  except KeyError as e:
234
239
  raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
235
240
  self._args = [
236
- _sympy_to_node(cast(sympy.Expr, arg), memodict, func_lookup, make_array)
241
+ _sympy_to_node(
242
+ cast(sympy_module.Expr, arg), memodict, func_lookup, make_array
243
+ )
237
244
  for arg in expr.args
238
245
  ]
239
246
 
@@ -248,7 +255,7 @@ class _Func(_AbstractNode):
248
255
  args.append(arg_call)
249
256
  return self._func(*args)
250
257
 
251
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
258
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
252
259
  try:
253
260
  return memodict[self]
254
261
  except KeyError:
@@ -260,20 +267,25 @@ class _Func(_AbstractNode):
260
267
 
261
268
 
262
269
  def _sympy_to_node(
263
- expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
270
+ expr: sympy_module.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
264
271
  ) -> _AbstractNode:
265
272
  try:
266
273
  return memodict[expr]
267
274
  except KeyError:
268
- if isinstance(expr, sympy.Symbol):
275
+ if isinstance(expr, sympy_module.Symbol):
269
276
  out = _Symbol(expr)
270
- elif isinstance(expr, sympy.Integer):
277
+ elif isinstance(expr, sympy_module.Integer):
271
278
  out = _Integer(expr, make_array)
272
- elif isinstance(expr, sympy.Float):
279
+ elif isinstance(expr, sympy_module.Float):
273
280
  out = _Float(expr, make_array)
274
- elif isinstance(expr, sympy.Rational):
281
+ elif isinstance(expr, sympy_module.Rational):
275
282
  out = _Rational(expr, make_array)
276
- elif expr in (sympy.E, sympy.pi, sympy.EulerGamma, sympy.I):
283
+ elif expr in (
284
+ sympy_module.E,
285
+ sympy_module.pi,
286
+ sympy_module.EulerGamma,
287
+ sympy_module.I,
288
+ ):
277
289
  out = _Constant(expr, make_array)
278
290
  else:
279
291
  out = _Func(expr, memodict, func_lookup, make_array)
@@ -287,7 +299,7 @@ def _is_node(x):
287
299
 
288
300
  class SymbolicModule(eqx.Module):
289
301
  nodes: PyTree
290
- has_extra_funcs: bool = eqx.static_field()
302
+ has_extra_funcs: bool = eqx.field(static=True)
291
303
 
292
304
  def __init__(
293
305
  self,
@@ -307,19 +319,19 @@ class SymbolicModule(eqx.Module):
307
319
  func_lookup=lookup,
308
320
  make_array=make_array,
309
321
  )
310
- self.nodes = jax.tree_map(_convert, expressions)
322
+ self.nodes = jtu.tree_map(_convert, expressions)
311
323
 
312
- def sympy(self) -> sympy.Expr:
324
+ def sympy(self) -> sympy_module.Expr:
313
325
  if self.has_extra_funcs:
314
326
  raise NotImplementedError(
315
327
  "SymbolicModule cannot be converted back to SymPy if `extra_funcs` "
316
328
  "is passed."
317
329
  )
318
330
  memodict = dict()
319
- return jax.tree_map(
331
+ return jtu.tree_map(
320
332
  lambda n: n.sympy(memodict, _reverse_lookup), self.nodes, is_leaf=_is_node
321
333
  )
322
334
 
323
335
  def __call__(self, **symbols):
324
336
  memodict = symbols
325
- return jax.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
337
+ return jtu.tree_map(lambda n: n(memodict), self.nodes, is_leaf=_is_node)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sympy2jax
3
- Version: 0.0.6
3
+ Version: 0.0.8
4
4
  Summary: Turn SymPy expressions into trainable JAX expressions.
5
5
  Project-URL: repository, https://github.com/google/sympy2jax
6
6
  Author-email: Patrick Kidger <contact@kidger.site>
@@ -260,7 +260,7 @@ params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
260
260
  ## Documentation
261
261
 
262
262
  ```python
263
- sympytorch.SymbolicModule(expressions, extra_funcs=None, make_array=True)
263
+ sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
264
264
  ```
265
265
 
266
266
  Where:
@@ -284,6 +284,7 @@ Instances have a `.sympy()` method that translates the module back into a PyTree
284
284
  [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
285
285
  [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
286
286
  [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
287
+ [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.
287
288
 
288
289
  **Scientific computing**
289
290
  [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
@@ -0,0 +1,6 @@
1
+ sympy2jax/__init__.py,sha256=KqjLNlIiDATWsPhjPLpG-Hud4HMMuq_DP6sISDWVYQg,720
2
+ sympy2jax/sympy_module.py,sha256=lQV08kzpo-odJ-6stXAcuDAAxeJ3eznj-IrtAfYC73E,10127
3
+ sympy2jax-0.0.8.dist-info/METADATA,sha256=T6-6aOck1OKOrnwFrDwbZQDooqxNVF6ySI7ahH_2xnU,16659
4
+ sympy2jax-0.0.8.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ sympy2jax-0.0.8.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
6
+ sympy2jax-0.0.8.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,6 +0,0 @@
1
- sympy2jax/__init__.py,sha256=KqjLNlIiDATWsPhjPLpG-Hud4HMMuq_DP6sISDWVYQg,720
2
- sympy2jax/sympy_module.py,sha256=zQwxSeW2GbNWQbvqwYaDq_IFx3gVw0Jwe5d6h_YDJO4,9383
3
- sympy2jax-0.0.6.dist-info/METADATA,sha256=RG9im4iI-5vriRimeDW8wH3aAiy7V1Gcx2FLOpHfsYk,16559
4
- sympy2jax-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- sympy2jax-0.0.6.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
6
- sympy2jax-0.0.6.dist-info/RECORD,,