sympy2jax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sympy2jax/__init__.py CHANGED
@@ -12,7 +12,11 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .sympy_module import SymbolicModule
15
+ from .sympy_module import (
16
+ concatenate as concatenate,
17
+ stack as stack,
18
+ SymbolicModule as SymbolicModule,
19
+ )
16
20
 
17
21
 
18
22
  __version__ = "0.0.4"
sympy2jax/sympy_module.py CHANGED
@@ -15,7 +15,8 @@
15
15
  import abc
16
16
  import collections as co
17
17
  import functools as ft
18
- from typing import Any, Callable, Optional
18
+ from collections.abc import Callable, Mapping
19
+ from typing import Any, cast, Optional
19
20
 
20
21
  import equinox as eqx
21
22
  import jax
@@ -26,6 +27,9 @@ import sympy
26
27
 
27
28
  PyTree = Any
28
29
 
30
+ concatenate: Callable = sympy.Function("concatenate") # pyright: ignore
31
+ stack: Callable = sympy.Function("stack") # pyright: ignore
32
+
29
33
 
30
34
  def _reduce(fn):
31
35
  def fn_(*args):
@@ -34,7 +38,16 @@ def _reduce(fn):
34
38
  return fn_
35
39
 
36
40
 
41
+ def _single_args(fn):
42
+ def fn_(*args):
43
+ return fn(args)
44
+
45
+ return fn_
46
+
47
+
37
48
  _lookup = {
49
+ concatenate: _single_args(jnp.concatenate),
50
+ stack: _single_args(jnp.stack),
38
51
  sympy.Mul: _reduce(jnp.multiply),
39
52
  sympy.Add: _reduce(jnp.add),
40
53
  sympy.div: jnp.divide,
@@ -80,13 +93,27 @@ _lookup = {
80
93
  sympy.Determinant: jnp.linalg.det,
81
94
  }
82
95
 
96
+ _constant_lookup = {
97
+ sympy.E: jnp.e,
98
+ sympy.pi: jnp.pi,
99
+ sympy.EulerGamma: jnp.euler_gamma,
100
+ sympy.I: 1j,
101
+ }
102
+
83
103
  _reverse_lookup = {v: k for k, v in _lookup.items()}
84
104
  assert len(_reverse_lookup) == len(_lookup)
85
105
 
86
106
 
107
+ def _item(x):
108
+ if eqx.is_array(x):
109
+ return x.item()
110
+ else:
111
+ return x
112
+
113
+
87
114
  class _AbstractNode(eqx.Module):
88
115
  @abc.abstractmethod
89
- def __call__(self, memodict: dict):
116
+ def __call__(self, memodict: dict) -> jax.typing.ArrayLike:
90
117
  ...
91
118
 
92
119
  @abc.abstractmethod
@@ -95,14 +122,14 @@ class _AbstractNode(eqx.Module):
95
122
 
96
123
  # Comparisons based on identity
97
124
  __hash__ = object.__hash__
98
- __eq__ = object.__eq__
125
+ __eq__ = object.__eq__ # pyright: ignore
99
126
 
100
127
 
101
128
  class _Symbol(_AbstractNode):
102
129
  _name: str
103
130
 
104
131
  def __init__(self, expr: sympy.Expr):
105
- self._name = expr.name
132
+ self._name = str(expr.name) # pyright: ignore
106
133
 
107
134
  def __call__(self, memodict: dict):
108
135
  try:
@@ -123,7 +150,7 @@ def _maybe_array(val, make_array):
123
150
 
124
151
 
125
152
  class _Integer(_AbstractNode):
126
- _value: jnp.ndarray
153
+ _value: jax.typing.ArrayLike
127
154
 
128
155
  def __init__(self, expr: sympy.Expr, make_array: bool):
129
156
  assert isinstance(expr, sympy.Integer)
@@ -134,11 +161,11 @@ class _Integer(_AbstractNode):
134
161
 
135
162
  def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
136
163
  # memodict not needed as sympy deduplicates internally
137
- return sympy.Integer(self._value.item())
164
+ return sympy.Integer(_item(self._value))
138
165
 
139
166
 
140
167
  class _Float(_AbstractNode):
141
- _value: jnp.ndarray
168
+ _value: jax.typing.ArrayLike
142
169
 
143
170
  def __init__(self, expr: sympy.Expr, make_array: bool):
144
171
  assert isinstance(expr, sympy.Float)
@@ -149,24 +176,49 @@ class _Float(_AbstractNode):
149
176
 
150
177
  def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
151
178
  # memodict not needed as sympy deduplicates internally
152
- return sympy.Float(self._value.item())
179
+ return sympy.Float(_item(self._value))
153
180
 
154
181
 
155
182
  class _Rational(_AbstractNode):
156
- _numerator: jnp.ndarray
157
- _denominator: jnp.ndarray
183
+ _numerator: jax.typing.ArrayLike
184
+ _denominator: jax.typing.ArrayLike
158
185
 
159
186
  def __init__(self, expr: sympy.Expr, make_array: bool):
160
187
  assert isinstance(expr, sympy.Rational)
161
- self._numerator = _maybe_array(int(expr.numerator), make_array)
162
- self._denominator = _maybe_array(int(expr.denominator), make_array)
188
+ numerator = expr.numerator
189
+ denominator = expr.denominator
190
+ if callable(numerator):
191
+ # Support SymPy < 1.10
192
+ numerator = numerator()
193
+ if callable(denominator):
194
+ denominator = denominator()
195
+ self._numerator = _maybe_array(int(numerator), make_array)
196
+ self._denominator = _maybe_array(int(denominator), make_array)
163
197
 
164
198
  def __call__(self, memodict: dict):
165
199
  return self._numerator / self._denominator
166
200
 
167
201
  def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
168
202
  # memodict not needed as sympy deduplicates internally
169
- return sympy.Integer(self._numerator) / sympy.Integer(self._denominator)
203
+ return sympy.Integer(_item(self._numerator)) / sympy.Integer(
204
+ _item(self._denominator)
205
+ )
206
+
207
+
208
+ class _Constant(_AbstractNode):
209
+ _value: jnp.ndarray
210
+ _expr: sympy.Expr
211
+
212
+ def __init__(self, expr: sympy.Expr, make_array: bool):
213
+ assert expr in _constant_lookup
214
+ self._value = _maybe_array(_constant_lookup[expr], make_array)
215
+ self._expr = expr
216
+
217
+ def __call__(self, memodict: dict):
218
+ return self._value
219
+
220
+ def sympy(self, memodict: dict, func_lookup: dict) -> sympy.Expr:
221
+ return self._expr
170
222
 
171
223
 
172
224
  class _Func(_AbstractNode):
@@ -174,14 +226,15 @@ class _Func(_AbstractNode):
174
226
  _args: list
175
227
 
176
228
  def __init__(
177
- self, expr: sympy.Expr, memodict: dict, func_lookup: dict, make_array: bool
229
+ self, expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
178
230
  ):
179
231
  try:
180
232
  self._func = func_lookup[expr.func]
181
233
  except KeyError as e:
182
234
  raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
183
235
  self._args = [
184
- _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
236
+ _sympy_to_node(cast(sympy.Expr, arg), memodict, func_lookup, make_array)
237
+ for arg in expr.args
185
238
  ]
186
239
 
187
240
  def __call__(self, memodict: dict):
@@ -207,7 +260,7 @@ class _Func(_AbstractNode):
207
260
 
208
261
 
209
262
  def _sympy_to_node(
210
- expr: sympy.Expr, memodict: dict, func_lookup: dict, make_array: bool
263
+ expr: sympy.Expr, memodict: dict, func_lookup: Mapping, make_array: bool
211
264
  ) -> _AbstractNode:
212
265
  try:
213
266
  return memodict[expr]
@@ -220,6 +273,8 @@ def _sympy_to_node(
220
273
  out = _Float(expr, make_array)
221
274
  elif isinstance(expr, sympy.Rational):
222
275
  out = _Rational(expr, make_array)
276
+ elif expr in (sympy.E, sympy.pi, sympy.EulerGamma, sympy.I):
277
+ out = _Constant(expr, make_array)
223
278
  else:
224
279
  out = _Func(expr, memodict, func_lookup, make_array)
225
280
  memodict[expr] = out
@@ -239,9 +294,7 @@ class SymbolicModule(eqx.Module):
239
294
  expressions: PyTree,
240
295
  extra_funcs: Optional[dict] = None,
241
296
  make_array: bool = True,
242
- **kwargs,
243
297
  ):
244
- super().__init__(**kwargs)
245
298
  if extra_funcs is None:
246
299
  lookup = _lookup
247
300
  self.has_extra_funcs = False
@@ -259,7 +312,8 @@ class SymbolicModule(eqx.Module):
259
312
  def sympy(self) -> sympy.Expr:
260
313
  if self.has_extra_funcs:
261
314
  raise NotImplementedError(
262
- "SymbolicModule cannot be converted back to SymPy if `extra_funcs` is passed"
315
+ "SymbolicModule cannot be converted back to SymPy if `extra_funcs` "
316
+ "is passed."
263
317
  )
264
318
  memodict = dict()
265
319
  return jax.tree_map(
@@ -0,0 +1,296 @@
1
+ Metadata-Version: 2.4
2
+ Name: sympy2jax
3
+ Version: 0.0.6
4
+ Summary: Turn SymPy expressions into trainable JAX expressions.
5
+ Project-URL: repository, https://github.com/google/sympy2jax
6
+ Author-email: Patrick Kidger <contact@kidger.site>
7
+ License:
8
+ Apache License
9
+ Version 2.0, January 2004
10
+ http://www.apache.org/licenses/
11
+
12
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
13
+
14
+ 1. Definitions.
15
+
16
+ "License" shall mean the terms and conditions for use, reproduction,
17
+ and distribution as defined by Sections 1 through 9 of this document.
18
+
19
+ "Licensor" shall mean the copyright owner or entity authorized by
20
+ the copyright owner that is granting the License.
21
+
22
+ "Legal Entity" shall mean the union of the acting entity and all
23
+ other entities that control, are controlled by, or are under common
24
+ control with that entity. For the purposes of this definition,
25
+ "control" means (i) the power, direct or indirect, to cause the
26
+ direction or management of such entity, whether by contract or
27
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
28
+ outstanding shares, or (iii) beneficial ownership of such entity.
29
+
30
+ "You" (or "Your") shall mean an individual or Legal Entity
31
+ exercising permissions granted by this License.
32
+
33
+ "Source" form shall mean the preferred form for making modifications,
34
+ including but not limited to software source code, documentation
35
+ source, and configuration files.
36
+
37
+ "Object" form shall mean any form resulting from mechanical
38
+ transformation or translation of a Source form, including but
39
+ not limited to compiled object code, generated documentation,
40
+ and conversions to other media types.
41
+
42
+ "Work" shall mean the work of authorship, whether in Source or
43
+ Object form, made available under the License, as indicated by a
44
+ copyright notice that is included in or attached to the work
45
+ (an example is provided in the Appendix below).
46
+
47
+ "Derivative Works" shall mean any work, whether in Source or Object
48
+ form, that is based on (or derived from) the Work and for which the
49
+ editorial revisions, annotations, elaborations, or other modifications
50
+ represent, as a whole, an original work of authorship. For the purposes
51
+ of this License, Derivative Works shall not include works that remain
52
+ separable from, or merely link (or bind by name) to the interfaces of,
53
+ the Work and Derivative Works thereof.
54
+
55
+ "Contribution" shall mean any work of authorship, including
56
+ the original version of the Work and any modifications or additions
57
+ to that Work or Derivative Works thereof, that is intentionally
58
+ submitted to Licensor for inclusion in the Work by the copyright owner
59
+ or by an individual or Legal Entity authorized to submit on behalf of
60
+ the copyright owner. For the purposes of this definition, "submitted"
61
+ means any form of electronic, verbal, or written communication sent
62
+ to the Licensor or its representatives, including but not limited to
63
+ communication on electronic mailing lists, source code control systems,
64
+ and issue tracking systems that are managed by, or on behalf of, the
65
+ Licensor for the purpose of discussing and improving the Work, but
66
+ excluding communication that is conspicuously marked or otherwise
67
+ designated in writing by the copyright owner as "Not a Contribution."
68
+
69
+ "Contributor" shall mean Licensor and any individual or Legal Entity
70
+ on behalf of whom a Contribution has been received by Licensor and
71
+ subsequently incorporated within the Work.
72
+
73
+ 2. Grant of Copyright License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ copyright license to reproduce, prepare Derivative Works of,
77
+ publicly display, publicly perform, sublicense, and distribute the
78
+ Work and such Derivative Works in Source or Object form.
79
+
80
+ 3. Grant of Patent License. Subject to the terms and conditions of
81
+ this License, each Contributor hereby grants to You a perpetual,
82
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
83
+ (except as stated in this section) patent license to make, have made,
84
+ use, offer to sell, sell, import, and otherwise transfer the Work,
85
+ where such license applies only to those patent claims licensable
86
+ by such Contributor that are necessarily infringed by their
87
+ Contribution(s) alone or by combination of their Contribution(s)
88
+ with the Work to which such Contribution(s) was submitted. If You
89
+ institute patent litigation against any entity (including a
90
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
91
+ or a Contribution incorporated within the Work constitutes direct
92
+ or contributory patent infringement, then any patent licenses
93
+ granted to You under this License for that Work shall terminate
94
+ as of the date such litigation is filed.
95
+
96
+ 4. Redistribution. You may reproduce and distribute copies of the
97
+ Work or Derivative Works thereof in any medium, with or without
98
+ modifications, and in Source or Object form, provided that You
99
+ meet the following conditions:
100
+
101
+ (a) You must give any other recipients of the Work or
102
+ Derivative Works a copy of this License; and
103
+
104
+ (b) You must cause any modified files to carry prominent notices
105
+ stating that You changed the files; and
106
+
107
+ (c) You must retain, in the Source form of any Derivative Works
108
+ that You distribute, all copyright, patent, trademark, and
109
+ attribution notices from the Source form of the Work,
110
+ excluding those notices that do not pertain to any part of
111
+ the Derivative Works; and
112
+
113
+ (d) If the Work includes a "NOTICE" text file as part of its
114
+ distribution, then any Derivative Works that You distribute must
115
+ include a readable copy of the attribution notices contained
116
+ within such NOTICE file, excluding those notices that do not
117
+ pertain to any part of the Derivative Works, in at least one
118
+ of the following places: within a NOTICE text file distributed
119
+ as part of the Derivative Works; within the Source form or
120
+ documentation, if provided along with the Derivative Works; or,
121
+ within a display generated by the Derivative Works, if and
122
+ wherever such third-party notices normally appear. The contents
123
+ of the NOTICE file are for informational purposes only and
124
+ do not modify the License. You may add Your own attribution
125
+ notices within Derivative Works that You distribute, alongside
126
+ or as an addendum to the NOTICE text from the Work, provided
127
+ that such additional attribution notices cannot be construed
128
+ as modifying the License.
129
+
130
+ You may add Your own copyright statement to Your modifications and
131
+ may provide additional or different license terms and conditions
132
+ for use, reproduction, or distribution of Your modifications, or
133
+ for any such Derivative Works as a whole, provided Your use,
134
+ reproduction, and distribution of the Work otherwise complies with
135
+ the conditions stated in this License.
136
+
137
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
138
+ any Contribution intentionally submitted for inclusion in the Work
139
+ by You to the Licensor shall be under the terms and conditions of
140
+ this License, without any additional terms or conditions.
141
+ Notwithstanding the above, nothing herein shall supersede or modify
142
+ the terms of any separate license agreement you may have executed
143
+ with Licensor regarding such Contributions.
144
+
145
+ 6. Trademarks. This License does not grant permission to use the trade
146
+ names, trademarks, service marks, or product names of the Licensor,
147
+ except as required for reasonable and customary use in describing the
148
+ origin of the Work and reproducing the content of the NOTICE file.
149
+
150
+ 7. Disclaimer of Warranty. Unless required by applicable law or
151
+ agreed to in writing, Licensor provides the Work (and each
152
+ Contributor provides its Contributions) on an "AS IS" BASIS,
153
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
154
+ implied, including, without limitation, any warranties or conditions
155
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
156
+ PARTICULAR PURPOSE. You are solely responsible for determining the
157
+ appropriateness of using or redistributing the Work and assume any
158
+ risks associated with Your exercise of permissions under this License.
159
+
160
+ 8. Limitation of Liability. In no event and under no legal theory,
161
+ whether in tort (including negligence), contract, or otherwise,
162
+ unless required by applicable law (such as deliberate and grossly
163
+ negligent acts) or agreed to in writing, shall any Contributor be
164
+ liable to You for damages, including any direct, indirect, special,
165
+ incidental, or consequential damages of any character arising as a
166
+ result of this License or out of the use or inability to use the
167
+ Work (including but not limited to damages for loss of goodwill,
168
+ work stoppage, computer failure or malfunction, or any and all
169
+ other commercial damages or losses), even if such Contributor
170
+ has been advised of the possibility of such damages.
171
+
172
+ 9. Accepting Warranty or Additional Liability. While redistributing
173
+ the Work or Derivative Works thereof, You may choose to offer,
174
+ and charge a fee for, acceptance of support, warranty, indemnity,
175
+ or other liability obligations and/or rights consistent with this
176
+ License. However, in accepting such obligations, You may act only
177
+ on Your own behalf and on Your sole responsibility, not on behalf
178
+ of any other Contributor, and only if You agree to indemnify,
179
+ defend, and hold each Contributor harmless for any liability
180
+ incurred by, or claims asserted against, such Contributor by reason
181
+ of your accepting any such warranty or additional liability.
182
+
183
+ END OF TERMS AND CONDITIONS
184
+
185
+ APPENDIX: How to apply the Apache License to your work.
186
+
187
+ To apply the Apache License to your work, attach the following
188
+ boilerplate notice, with the fields enclosed by brackets "[]"
189
+ replaced with your own identifying information. (Don't include
190
+ the brackets!) The text should be enclosed in the appropriate
191
+ comment syntax for the file format. We also recommend that a
192
+ file or class name and description of purpose be included on the
193
+ same "printed page" as the copyright notice for easier
194
+ identification within third-party archives.
195
+
196
+ Copyright [yyyy] [name of copyright owner]
197
+
198
+ Licensed under the Apache License, Version 2.0 (the "License");
199
+ you may not use this file except in compliance with the License.
200
+ You may obtain a copy of the License at
201
+
202
+ http://www.apache.org/licenses/LICENSE-2.0
203
+
204
+ Unless required by applicable law or agreed to in writing, software
205
+ distributed under the License is distributed on an "AS IS" BASIS,
206
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
207
+ See the License for the specific language governing permissions and
208
+ limitations under the License.
209
+ License-File: LICENSE
210
+ Keywords: equinox,jax,sympy
211
+ Classifier: Development Status :: 3 - Alpha
212
+ Classifier: Intended Audience :: Science/Research
213
+ Classifier: License :: OSI Approved :: Apache Software License
214
+ Classifier: Natural Language :: English
215
+ Classifier: Programming Language :: Python :: 3
216
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
217
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
218
+ Requires-Python: ~=3.9
219
+ Requires-Dist: equinox>=0.5.3
220
+ Requires-Dist: jax>=0.3.4
221
+ Requires-Dist: sympy>=1.7.1
222
+ Description-Content-Type: text/markdown
223
+
224
+ <h1 align="center">sympy2jax</h1>
225
+
226
+ Turn SymPy expressions into trainable JAX expressions. The output will be an [Equinox](https://github.com/patrick-kidger/equinox) module with all SymPy floats (integers, rationals, ...) as leaves. SymPy symbols will be inputs.
227
+
228
+ Optimise your symbolic expressions via gradient descent!
229
+
230
+ ## Installation
231
+
232
+ ```bash
233
+ pip install sympy2jax
234
+ ```
235
+
236
+ Requires:
237
+ Python 3.9+
238
+ JAX 0.3.4+
239
+ Equinox 0.5.3+
240
+ SymPy 1.7.1+.
241
+
242
+ ## Example
243
+
244
+ ```python
245
+ import jax
246
+ import sympy
247
+ import sympy2jax
248
+
249
+ x_sym = sympy.symbols("x_sym")
250
+ cosx = 1.0 * sympy.cos(x_sym)
251
+ sinx = 2.0 * sympy.sin(x_sym)
252
+ mod = sympy2jax.SymbolicModule([cosx, sinx]) # PyTree of input expressions
253
+
254
+ x = jax.numpy.zeros(3)
255
+ out = mod(x_sym=x) # PyTree of results.
256
+ params = jax.tree_util.tree_leaves(mod) # 1.0 and 2.0 are parameters.
257
+ # (Which may be trained in the usual way for Equinox.)
258
+ ```
259
+
260
+ ## Documentation
261
+
262
+ ```python
263
+ sympy2jax.SymbolicModule(expressions, extra_funcs=None, make_array=True)
264
+ ```
265
+
266
+ Where:
267
+ - `expressions` is a PyTree of SymPy expressions.
268
+ - `extra_funcs` is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.
269
+ - `make_array` controls how numeric constants (integers/floats/rationals) are stored: as JAX arrays if `True` (the default), or as plain Python numbers if `False`.
270
+
271
+ Instances can be called with key-value pairs of symbol-value, as in the above example.
272
+
273
+ Instances have a `.sympy()` method that translates the module back into a PyTree of SymPy expressions.
274
+
275
+ (That's literally the entire documentation, it's super easy.)
276
+
277
+ ## See also: other libraries in the JAX ecosystem
278
+
279
+ **Always useful**
280
+ [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
281
+ [jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.
282
+
283
+ **Deep learning**
284
+ [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
285
+ [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
286
+ [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
287
+
288
+ **Scientific computing**
289
+ [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
290
+ [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
291
+ [Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
292
+ [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
293
+ [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
294
+
295
+ **Awesome JAX**
296
+ [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
@@ -0,0 +1,6 @@
1
+ sympy2jax/__init__.py,sha256=KqjLNlIiDATWsPhjPLpG-Hud4HMMuq_DP6sISDWVYQg,720
2
+ sympy2jax/sympy_module.py,sha256=zQwxSeW2GbNWQbvqwYaDq_IFx3gVw0Jwe5d6h_YDJO4,9383
3
+ sympy2jax-0.0.6.dist-info/METADATA,sha256=RG9im4iI-5vriRimeDW8wH3aAiy7V1Gcx2FLOpHfsYk,16559
4
+ sympy2jax-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ sympy2jax-0.0.6.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
6
+ sympy2jax-0.0.6.dist-info/RECORD,,
@@ -1,5 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.37.1)
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
-
@@ -1,92 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: sympy2jax
3
- Version: 0.0.4
4
- Summary: Turn SymPy expressions into trainable JAX expressions.
5
- Home-page: https://github.com/google/sympy2jax
6
- Author: Patrick Kidger
7
- Author-email: contact@kidger.site
8
- Maintainer: Patrick Kidger
9
- Maintainer-email: contact@kidger.site
10
- License: Apache-2.0
11
- Platform: UNKNOWN
12
- Classifier: Development Status :: 3 - Alpha
13
- Classifier: Intended Audience :: Science/Research
14
- Classifier: License :: OSI Approved :: Apache Software License
15
- Classifier: Natural Language :: English
16
- Classifier: Programming Language :: Python :: 3
17
- Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
- Classifier: Topic :: Scientific/Engineering :: Mathematics
19
- Requires-Python: ~=3.7
20
- Description-Content-Type: text/markdown
21
- Requires-Dist: equinox (>=0.5.3)
22
- Requires-Dist: jax (>=0.3.4)
23
- Requires-Dist: sympy (>=1.7.1)
24
-
25
- <h1 align="center">sympy2jax</h1>
26
-
27
- Turn SymPy expressions into trainable JAX expressions. The output will be an [Equinox](https://github.com/patrick-kidger/equinox) module with all SymPy floats (integers, rationals, ...) as leaves. SymPy symbols will be inputs.
28
-
29
- Optimise your symbolic expressions via gradient descent!
30
-
31
- ## Installation
32
-
33
- ```bash
34
- pip install sympy2jax
35
- ```
36
-
37
- Requires:
38
- Python 3.7+
39
- JAX 0.3.4+
40
- Equinox 0.5.3+
41
- SymPy 1.7.1+.
42
-
43
- ## Example
44
-
45
- ```python
46
- import jax
47
- import sympy
48
- import sympy2jax
49
-
50
- x_sym = sympy.symbols("x_sym")
51
- cosx = 1.0 * sympy.cos(x_sym)
52
- sinx = 2.0 * sympy.sin(x_sym)
53
- mod = sympy2jax.SymbolicModule([cosx, sinx]) # PyTree of input expressions
54
-
55
- x = jax.numpy.zeros(3)
56
- out = mod(x_sym=x) # PyTree of results.
57
- params = jax.tree_leaves(mod) # 1.0 and 2.0 are parameters.
58
- # (Which may be trained in the usual way for Equinox.)
59
- ```
60
-
61
- ## Documentation
62
-
63
- ```python
64
- sympytorch.SymbolicModule(expressions, extra_funcs=None, make_array=True)
65
- ```
66
-
67
- Where:
68
- - `expressions` is a PyTree of SymPy expressions.
69
- - `extra_funcs` is an optional dictionary from SymPy functions to JAX operations, to extend the built-in translation rules.
70
- - `make_array` is whether integers/floats/rationals should be stored as Python integers/etc., or as JAX arrays.
71
-
72
- Instances can be called with key-value pairs of symbol-value, as in the above example.
73
-
74
- Instances have a `.sympy()` method that translates the module back into a PyTree of SymPy expressions.
75
-
76
- (That's literally the entire documentation, it's super easy.)
77
-
78
- ## Finally
79
-
80
- ### See also: other tools in the JAX ecosystem
81
-
82
- Neural networks: [Equinox](https://github.com/patrick-kidger/equinox).
83
-
84
- Numerical differential equation solvers: [Diffrax](https://github.com/patrick-kidger/diffrax).
85
-
86
- Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: [jaxtyping](https://github.com/google/jaxtyping).
87
-
88
- ### Disclaimer
89
-
90
- This is not an official Google product.
91
-
92
-
@@ -1,7 +0,0 @@
1
- sympy2jax/__init__.py,sha256=s7exLvdhnQjBFcOz5Qs-K5DpaYjeunNwSK0fm-6dHew,641
2
- sympy2jax/sympy_module.py,sha256=9rJxrUmkVbj3Rhz4dacOzzhxmc4rQVScIvrtypu1Xok,7884
3
- sympy2jax-0.0.4.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
4
- sympy2jax-0.0.4.dist-info/METADATA,sha256=mfHkzIfF9rVMuscQv4aaeQGA5RO76b36K-9tb4TqZnU,2823
5
- sympy2jax-0.0.4.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
6
- sympy2jax-0.0.4.dist-info/top_level.txt,sha256=90sePp5fhiDuG5-Q1ocMRtZDImhpqSRr_KiV8Kp8TJA,10
7
- sympy2jax-0.0.4.dist-info/RECORD,,
@@ -1 +0,0 @@
1
- sympy2jax