sympy2jax 0.0.7__tar.gz → 0.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sympy2jax
3
- Version: 0.0.7
3
+ Version: 0.0.8
4
4
  Summary: Turn SymPy expressions into trainable JAX expressions.
5
5
  Project-URL: repository, https://github.com/google/sympy2jax
6
6
  Author-email: Patrick Kidger <contact@kidger.site>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "sympy2jax"
3
- version = "0.0.7"
3
+ version = "0.0.8"
4
4
  description = "Turn SymPy expressions into trainable JAX expressions."
5
5
  readme = "README.md"
6
6
  requires-python ="~=3.9"
@@ -31,13 +31,13 @@ include = ["sympy2jax/*"]
31
31
  [tool.pytest.ini_options]
32
32
  addopts = "--jaxtyping-packages=sympy2jax,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"
33
33
 
34
- [tool.ruff]
34
+ [tool.ruff.lint]
35
35
  select = ["E", "F", "I001"]
36
36
  ignore = ["E402", "E721", "E731", "E741", "F722"]
37
37
  ignore-init-module-imports = true
38
38
  fixable = ["I001", "F401"]
39
39
 
40
- [tool.ruff.isort]
40
+ [tool.ruff.lint.isort]
41
41
  combine-as-imports = true
42
42
  lines-after-imports = 2
43
43
  extra-standard-library = ["typing_extensions"]
@@ -23,13 +23,13 @@ import jax
23
23
  import jax.numpy as jnp
24
24
  import jax.scipy as jsp
25
25
  import jax.tree_util as jtu
26
- import sympy
26
+ import sympy as sympy_module
27
27
 
28
28
 
29
29
  PyTree = Any
30
30
 
31
- concatenate: Callable = sympy.Function("concatenate") # pyright: ignore
32
- stack: Callable = sympy.Function("stack") # pyright: ignore
31
+ concatenate: Callable = sympy_module.Function("concatenate") # pyright: ignore
32
+ stack: Callable = sympy_module.Function("stack") # pyright: ignore
33
33
 
34
34
 
35
35
  def _reduce(fn):
@@ -49,56 +49,56 @@ def _single_args(fn):
49
49
  _lookup = {
50
50
  concatenate: _single_args(jnp.concatenate),
51
51
  stack: _single_args(jnp.stack),
52
- sympy.Mul: _reduce(jnp.multiply),
53
- sympy.Add: _reduce(jnp.add),
54
- sympy.div: jnp.divide,
55
- sympy.Abs: jnp.abs,
56
- sympy.sign: jnp.sign,
57
- sympy.ceiling: jnp.ceil,
58
- sympy.floor: jnp.floor,
59
- sympy.log: jnp.log,
60
- sympy.exp: jnp.exp,
61
- sympy.sqrt: jnp.sqrt,
62
- sympy.cos: jnp.cos,
63
- sympy.acos: jnp.arccos,
64
- sympy.sin: jnp.sin,
65
- sympy.asin: jnp.arcsin,
66
- sympy.tan: jnp.tan,
67
- sympy.atan: jnp.arctan,
68
- sympy.atan2: jnp.arctan2,
69
- sympy.cosh: jnp.cosh,
70
- sympy.acosh: jnp.arccosh,
71
- sympy.sinh: jnp.sinh,
72
- sympy.asinh: jnp.arcsinh,
73
- sympy.tanh: jnp.tanh,
74
- sympy.atanh: jnp.arctanh,
75
- sympy.Pow: jnp.power,
76
- sympy.re: jnp.real,
77
- sympy.im: jnp.imag,
78
- sympy.arg: jnp.angle,
79
- sympy.erf: jsp.special.erf,
80
- sympy.Eq: jnp.equal,
81
- sympy.Ne: jnp.not_equal,
82
- sympy.StrictGreaterThan: jnp.greater,
83
- sympy.StrictLessThan: jnp.less,
84
- sympy.LessThan: jnp.less_equal,
85
- sympy.GreaterThan: jnp.greater_equal,
86
- sympy.And: jnp.logical_and,
87
- sympy.Or: jnp.logical_or,
88
- sympy.Not: jnp.logical_not,
89
- sympy.Xor: jnp.logical_xor,
90
- sympy.Max: _reduce(jnp.maximum),
91
- sympy.Min: _reduce(jnp.minimum),
92
- sympy.MatAdd: _reduce(jnp.add),
93
- sympy.Trace: jnp.trace,
94
- sympy.Determinant: jnp.linalg.det,
52
+ sympy_module.Mul: _reduce(jnp.multiply),
53
+ sympy_module.Add: _reduce(jnp.add),
54
+ sympy_module.div: jnp.divide,
55
+ sympy_module.Abs: jnp.abs,
56
+ sympy_module.sign: jnp.sign,
57
+ sympy_module.ceiling: jnp.ceil,
58
+ sympy_module.floor: jnp.floor,
59
+ sympy_module.log: jnp.log,
60
+ sympy_module.exp: jnp.exp,
61
+ sympy_module.sqrt: jnp.sqrt,
62
+ sympy_module.cos: jnp.cos,
63
+ sympy_module.acos: jnp.arccos,
64
+ sympy_module.sin: jnp.sin,
65
+ sympy_module.asin: jnp.arcsin,
66
+ sympy_module.tan: jnp.tan,
67
+ sympy_module.atan: jnp.arctan,
68
+ sympy_module.atan2: jnp.arctan2,
69
+ sympy_module.cosh: jnp.cosh,
70
+ sympy_module.acosh: jnp.arccosh,
71
+ sympy_module.sinh: jnp.sinh,
72
+ sympy_module.asinh: jnp.arcsinh,
73
+ sympy_module.tanh: jnp.tanh,
74
+ sympy_module.atanh: jnp.arctanh,
75
+ sympy_module.Pow: jnp.power,
76
+ sympy_module.re: jnp.real,
77
+ sympy_module.im: jnp.imag,
78
+ sympy_module.arg: jnp.angle,
79
+ sympy_module.erf: jsp.special.erf,
80
+ sympy_module.Eq: jnp.equal,
81
+ sympy_module.Ne: jnp.not_equal,
82
+ sympy_module.StrictGreaterThan: jnp.greater,
83
+ sympy_module.StrictLessThan: jnp.less,
84
+ sympy_module.LessThan: jnp.less_equal,
85
+ sympy_module.GreaterThan: jnp.greater_equal,
86
+ sympy_module.And: jnp.logical_and,
87
+ sympy_module.Or: jnp.logical_or,
88
+ sympy_module.Not: jnp.logical_not,
89
+ sympy_module.Xor: jnp.logical_xor,
90
+ sympy_module.Max: _reduce(jnp.maximum),
91
+ sympy_module.Min: _reduce(jnp.minimum),
92
+ sympy_module.MatAdd: _reduce(jnp.add),
93
+ sympy_module.Trace: jnp.trace,
94
+ sympy_module.Determinant: jnp.linalg.det,
95
95
  }
96
96
 
97
97
  _constant_lookup = {
98
- sympy.E: jnp.e,
99
- sympy.pi: jnp.pi,
100
- sympy.EulerGamma: jnp.euler_gamma,
101
- sympy.I: 1j,
98
+ sympy_module.E: jnp.e,
99
+ sympy_module.pi: jnp.pi,
100
+ sympy_module.EulerGamma: jnp.euler_gamma,
101
+ sympy_module.I: 1j,
102
102
  }
103
103
 
104
104
  _reverse_lookup = {v: k for k, v in _lookup.items()}
@@ -118,7 +118,7 @@ class _AbstractNode(eqx.Module):
118
118
  ...
119
119
 
120
120
  @abc.abstractmethod
121
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
121
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
122
122
  ...
123
123
 
124
124
  # Comparisons based on identity
@@ -129,7 +129,7 @@ class _AbstractNode(eqx.Module):
129
129
  class _Symbol(_AbstractNode):
130
130
  _name: str
131
131
 
132
- def __init__(self, expr: sympy.Expr):
132
+ def __init__(self, expr: sympy_module.Expr):
133
133
  self._name = str(expr.name) # pyright: ignore
134
134
 
135
135
  def __call__(self, memodict: dict):
@@ -138,9 +138,9 @@ class _Symbol(_AbstractNode):
138
138
  except KeyError as e:
139
139
  raise KeyError(f"Missing input for symbol {self._name}") from e
140
140
 
141
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
141
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
142
142
  # memodict not needed as sympy deduplicates internally
143
- return sympy.Symbol(self._name)
143
+ return sympy_module.Symbol(self._name)
144
144
 
145
145
 
146
146
  def _maybe_array(val, make_array):
@@ -153,39 +153,39 @@ def _maybe_array(val, make_array):
153
153
  class _Integer(_AbstractNode):
154
154
  _value: jax.typing.ArrayLike
155
155
 
156
- def __init__(self, expr: sympy.Expr, make_array: bool):
157
- assert isinstance(expr, sympy.Integer)
156
+ def __init__(self, expr: sympy_module.Expr, make_array: bool):
157
+ assert isinstance(expr, sympy_module.Integer)
158
158
  self._value = _maybe_array(int(expr), make_array)
159
159
 
160
160
  def __call__(self, memodict: dict):
161
161
  return self._value
162
162
 
163
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
163
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
164
164
  # memodict not needed as sympy deduplicates internally
165
- return sympy.Integer(_item(self._value))
165
+ return sympy_module.Integer(_item(self._value))
166
166
 
167
167
 
168
168
  class _Float(_AbstractNode):
169
169
  _value: jax.typing.ArrayLike
170
170
 
171
- def __init__(self, expr: sympy.Expr, make_array: bool):
172
- assert isinstance(expr, sympy.Float)
171
+ def __init__(self, expr: sympy_module.Expr, make_array: bool):
172
+ assert isinstance(expr, sympy_module.Float)
173
173
  self._value = _maybe_array(float(expr), make_array)
174
174
 
175
175
  def __call__(self, memodict: dict):
176
176
  return self._value
177
177
 
178
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
178
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
179
179
  # memodict not needed as sympy deduplicates internally
180
- return sympy.Float(_item(self._value))
180
+ return sympy_module.Float(_item(self._value))
181
181
 
182
182
 
183
183
  class _Rational(_AbstractNode):
184
184
  _numerator: jax.typing.ArrayLike
185
185
  _denominator: jax.typing.ArrayLike
186
186
 
187
- def __init__(self, expr: sympy.Expr, make_array: bool):
188
- assert isinstance(expr, sympy.Rational)
187
+ def __init__(self, expr: sympy_module.Expr, make_array: bool):
188
+ assert isinstance(expr, sympy_module.Rational)
189
189
  numerator = expr.numerator
190
190
  denominator = expr.denominator
191
191
  if callable(numerator):
@@ -199,18 +199,18 @@ class _Rational(_AbstractNode):
199
199
  def __call__(self, memodict: dict):
200
200
  return self._numerator / self._denominator
201
201
 
202
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
202
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
203
203
  # memodict not needed as sympy deduplicates internally
204
- return sympy.Integer(_item(self._numerator)) / sympy.Integer(
204
+ return sympy_module.Integer(_item(self._numerator)) / sympy_module.Integer(
205
205
  _item(self._denominator)
206
206
  )
207
207
 
208
208
 
209
209
  class _Constant(_AbstractNode):
210
210
  _value: jnp.ndarray
211
- _expr: sympy.Expr
211
+ _expr: sympy_module.Expr
212
212
 
213
- def __init__(self, expr: sympy.Expr, make_array: bool):
213
+ def __init__(self, expr: sympy_module.Expr, make_array: bool):
214
214
  assert expr in _constant_lookup
215
215
  self._value = _maybe_array(_constant_lookup[expr], make_array)
216
216
  self._expr = expr
@@ -218,7 +218,7 @@ class _Constant(_AbstractNode):
218
218
  def __call__(self, memodict: dict):
219
219
  return self._value
220
220
 
221
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
221
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
222
222
  return self._expr
223
223
 
224
224
 
@@ -227,14 +227,20 @@ class _Func(_AbstractNode):
227
227
  _args: list
228
228
 
229
229
  def __init__(
230
- self, expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
230
+ self,
231
+ expr: sympy_module.Expr,
232
+ memodict: dict,
233
+ func_lookup: Mapping,
234
+ make_array: bool,
231
235
  ):
232
236
  try:
233
237
  self._func = func_lookup[expr.func]
234
238
  except KeyError as e:
235
239
  raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
236
240
  self._args = [
237
- _sympy_to_node(cast(sympy.Expr, arg), memodict, func_lookup, make_array)
241
+ _sympy_to_node(
242
+ cast(sympy_module.Expr, arg), memodict, func_lookup, make_array
243
+ )
238
244
  for arg in expr.args
239
245
  ]
240
246
 
@@ -249,7 +255,7 @@ class _Func(_AbstractNode):
249
255
  args.append(arg_call)
250
256
  return self._func(*args)
251
257
 
252
- def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
258
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy_module.Expr:
253
259
  try:
254
260
  return memodict[self]
255
261
  except KeyError:
@@ -261,20 +267,25 @@ class _Func(_AbstractNode):
261
267
 
262
268
 
263
269
  def _sympy_to_node(
264
- expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
270
+ expr: sympy_module.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
265
271
  ) -> _AbstractNode:
266
272
  try:
267
273
  return memodict[expr]
268
274
  except KeyError:
269
- if isinstance(expr, sympy.Symbol):
275
+ if isinstance(expr, sympy_module.Symbol):
270
276
  out = _Symbol(expr)
271
- elif isinstance(expr, sympy.Integer):
277
+ elif isinstance(expr, sympy_module.Integer):
272
278
  out = _Integer(expr, make_array)
273
- elif isinstance(expr, sympy.Float):
279
+ elif isinstance(expr, sympy_module.Float):
274
280
  out = _Float(expr, make_array)
275
- elif isinstance(expr, sympy.Rational):
281
+ elif isinstance(expr, sympy_module.Rational):
276
282
  out = _Rational(expr, make_array)
277
- elif expr in (sympy.E, sympy.pi, sympy.EulerGamma, sympy.I):
283
+ elif expr in (
284
+ sympy_module.E,
285
+ sympy_module.pi,
286
+ sympy_module.EulerGamma,
287
+ sympy_module.I,
288
+ ):
278
289
  out = _Constant(expr, make_array)
279
290
  else:
280
291
  out = _Func(expr, memodict, func_lookup, make_array)
@@ -288,7 +299,7 @@ def _is_node(x):
288
299
 
289
300
  class SymbolicModule(eqx.Module):
290
301
  nodes: PyTree
291
- has_extra_funcs: bool = eqx.static_field()
302
+ has_extra_funcs: bool = eqx.field(static=True)
292
303
 
293
304
  def __init__(
294
305
  self,
@@ -310,7 +321,7 @@ class SymbolicModule(eqx.Module):
310
321
  )
311
322
  self.nodes = jtu.tree_map(_convert, expressions)
312
323
 
313
- def sympy(self) -> sympy.Expr:
324
+ def sympy(self) -> sympy_module.Expr:
314
325
  if self.has_extra_funcs:
315
326
  raise NotImplementedError(
316
327
  "SymbolicModule cannot be converted back to SymPy if `extra_funcs` "
File without changes
File without changes
File without changes