brainstate 0.1.0.post20250406__py2.py3-none-any.whl → 0.1.0.post20250420__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_compatible_import.py +58 -0
- brainstate/compile/_conditions.py +3 -1
- brainstate/compile/_make_jaxpr.py +11 -14
- brainstate/compile/_make_jaxpr_test.py +1 -5
- brainstate/compile/_unvmap.py +1 -4
- brainstate/surrogate.py +1 -4
- {brainstate-0.1.0.post20250406.dist-info → brainstate-0.1.0.post20250420.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.post20250406.dist-info → brainstate-0.1.0.post20250420.dist-info}/RECORD +11 -10
- {brainstate-0.1.0.post20250406.dist-info → brainstate-0.1.0.post20250420.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250406.dist-info → brainstate-0.1.0.post20250420.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250406.dist-info → brainstate-0.1.0.post20250420.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,58 @@
|
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
|
19
|
+
from contextlib import contextmanager
|
20
|
+
from typing import Iterable, Hashable
|
21
|
+
|
22
|
+
import jax
|
23
|
+
|
24
|
+
__all__ = [
|
25
|
+
'ClosedJaxpr',
|
26
|
+
'Primitive',
|
27
|
+
'extend_axis_env_nd',
|
28
|
+
'jaxpr_as_fun',
|
29
|
+
'get_aval',
|
30
|
+
'Tracer',
|
31
|
+
'to_concrete_aval',
|
32
|
+
]
|
33
|
+
|
34
|
+
from jax.core import get_aval, Tracer
|
35
|
+
|
36
|
+
if jax.__version_info__ < (0, 4, 38):
|
37
|
+
from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
|
38
|
+
else:
|
39
|
+
from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
|
40
|
+
from jax.core import trace_ctx
|
41
|
+
|
42
|
+
|
43
|
+
@contextmanager
|
44
|
+
def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
|
45
|
+
prev = trace_ctx.axis_env
|
46
|
+
try:
|
47
|
+
trace_ctx.set_axis_env(prev.extend_pure(name_size_pairs))
|
48
|
+
yield
|
49
|
+
finally:
|
50
|
+
trace_ctx.set_axis_env(prev)
|
51
|
+
|
52
|
+
|
53
|
+
def to_concrete_aval(aval):
|
54
|
+
aval = get_aval(aval)
|
55
|
+
if isinstance(aval, Tracer):
|
56
|
+
return aval.to_concrete_value()
|
57
|
+
return aval
|
58
|
+
|
@@ -24,6 +24,8 @@ from brainstate._utils import set_module_as
|
|
24
24
|
from ._error_if import jit_error_if
|
25
25
|
from ._make_jaxpr import StatefulFunction
|
26
26
|
from ._util import wrap_single_fun_in_multi_branches, write_back_state_values
|
27
|
+
from brainstate._compatible_import import to_concrete_aval, Tracer
|
28
|
+
|
27
29
|
|
28
30
|
__all__ = [
|
29
31
|
'cond', 'switch', 'ifelse',
|
@@ -86,7 +88,7 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
|
|
86
88
|
raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))
|
87
89
|
|
88
90
|
# not jit
|
89
|
-
if jax.config.jax_disable_jit and isinstance(
|
91
|
+
if jax.config.jax_disable_jit and not isinstance(to_concrete_aval(pred), Tracer):
|
90
92
|
if pred:
|
91
93
|
return true_fun(*operands)
|
92
94
|
else:
|
@@ -55,10 +55,12 @@ from __future__ import annotations
|
|
55
55
|
|
56
56
|
import functools
|
57
57
|
import inspect
|
58
|
-
import jax
|
59
58
|
import operator
|
60
59
|
from collections.abc import Hashable, Iterable, Sequence
|
61
|
-
from contextlib import ExitStack
|
60
|
+
from contextlib import ExitStack, contextmanager
|
61
|
+
from typing import Any, Callable, Tuple, Union, Dict, Optional
|
62
|
+
|
63
|
+
import jax
|
62
64
|
from jax._src import source_info_util
|
63
65
|
from jax._src.linear_util import annotate
|
64
66
|
from jax._src.traceback_util import api_boundary
|
@@ -66,17 +68,12 @@ from jax.api_util import shaped_abstractify
|
|
66
68
|
from jax.extend.linear_util import transformation_with_aux, wrap_init
|
67
69
|
from jax.interpreters import partial_eval as pe
|
68
70
|
from jax.util import wraps
|
69
|
-
from typing import Any, Callable, Tuple, Union, Dict, Optional
|
70
71
|
|
71
72
|
from brainstate._state import State, StateTraceStack
|
72
73
|
from brainstate._utils import set_module_as
|
73
74
|
from brainstate.typing import PyTree
|
74
75
|
from brainstate.util import PrettyObject
|
75
|
-
|
76
|
-
if jax.__version_info__ < (0, 4, 38):
|
77
|
-
from jax.core import ClosedJaxpr
|
78
|
-
else:
|
79
|
-
from jax.extend.core import ClosedJaxpr
|
76
|
+
from brainstate._compatible_import import ClosedJaxpr, extend_axis_env_nd
|
80
77
|
|
81
78
|
AxisName = Hashable
|
82
79
|
|
@@ -200,7 +197,7 @@ class StatefulFunction(PrettyObject):
|
|
200
197
|
|
201
198
|
# implicit parameters
|
202
199
|
self.cache_type = cache_type
|
203
|
-
self._cached_jaxpr: Dict[Any,
|
200
|
+
self._cached_jaxpr: Dict[Any, ClosedJaxpr] = dict()
|
204
201
|
self._cached_out_shapes: Dict[Any, PyTree] = dict()
|
205
202
|
self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
|
206
203
|
self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
|
@@ -210,7 +207,7 @@ class StatefulFunction(PrettyObject):
|
|
210
207
|
return None
|
211
208
|
return k, v
|
212
209
|
|
213
|
-
def get_jaxpr(self, cache_key: Hashable = ()) ->
|
210
|
+
def get_jaxpr(self, cache_key: Hashable = ()) -> ClosedJaxpr:
|
214
211
|
"""
|
215
212
|
Read the JAX Jaxpr representation of the function.
|
216
213
|
|
@@ -507,8 +504,8 @@ def make_jaxpr(
|
|
507
504
|
return_shape: bool = False,
|
508
505
|
abstracted_axes: Optional[Any] = None,
|
509
506
|
state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
|
510
|
-
) -> Callable[..., (Tuple[
|
511
|
-
Tuple[
|
507
|
+
) -> Callable[..., (Tuple[ClosedJaxpr, Tuple[State, ...]] |
|
508
|
+
Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])]:
|
512
509
|
"""
|
513
510
|
Creates a function that produces its jaxpr given example args.
|
514
511
|
|
@@ -754,12 +751,12 @@ def _make_jaxpr(
|
|
754
751
|
debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
|
755
752
|
with ExitStack() as stack:
|
756
753
|
if axis_env is not None:
|
757
|
-
stack.enter_context(
|
754
|
+
stack.enter_context(extend_axis_env_nd(axis_env))
|
758
755
|
if jax.__version_info__ < (0, 5, 0):
|
759
756
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
|
760
757
|
else:
|
761
758
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
|
762
|
-
closed_jaxpr =
|
759
|
+
closed_jaxpr = ClosedJaxpr(jaxpr, consts)
|
763
760
|
if return_shape:
|
764
761
|
out_avals, _ = jax.util.unzip2(out_type)
|
765
762
|
out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
|
@@ -21,11 +21,7 @@ import pytest
|
|
21
21
|
import unittest
|
22
22
|
|
23
23
|
import brainstate as bst
|
24
|
-
|
25
|
-
if jax.__version_info__ < (0, 4, 38):
|
26
|
-
from jax.core import jaxpr_as_fun
|
27
|
-
else:
|
28
|
-
from jax.extend.core import jaxpr_as_fun
|
24
|
+
from brainstate._compatible_import import jaxpr_as_fun
|
29
25
|
|
30
26
|
|
31
27
|
class TestMakeJaxpr(unittest.TestCase):
|
brainstate/compile/_unvmap.py
CHANGED
@@ -21,11 +21,8 @@ import jax.interpreters.mlir as mlir
|
|
21
21
|
import jax.numpy as jnp
|
22
22
|
|
23
23
|
from brainstate._utils import set_module_as
|
24
|
+
from brainstate._compatible_import import Primitive
|
24
25
|
|
25
|
-
if jax.__version_info__ < (0, 4, 38):
|
26
|
-
from jax.core import Primitive
|
27
|
-
else:
|
28
|
-
from jax.extend.core import Primitive
|
29
26
|
|
30
27
|
__all__ = [
|
31
28
|
"unvmap",
|
brainstate/surrogate.py
CHANGED
@@ -22,11 +22,8 @@ import jax.scipy as sci
|
|
22
22
|
from jax.interpreters import batching, ad, mlir
|
23
23
|
|
24
24
|
from brainstate.util._pretty_pytree import PrettyObject
|
25
|
+
from brainstate._compatible_import import Primitive
|
25
26
|
|
26
|
-
if jax.__version_info__ < (0, 4, 38):
|
27
|
-
from jax.core import Primitive
|
28
|
-
else:
|
29
|
-
from jax.extend.core import Primitive
|
30
27
|
|
31
28
|
__all__ = [
|
32
29
|
'Surrogate',
|
{brainstate-0.1.0.post20250406.dist-info → brainstate-0.1.0.post20250420.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.
|
3
|
+
Version: 0.1.0.post20250420
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -1,4 +1,5 @@
|
|
1
1
|
brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
|
2
|
+
brainstate/_compatible_import.py,sha256=rFpe8602qIzFTSkGJ5aDm-TaWAsmXGb1p2pphU1jzE8,1720
|
2
3
|
brainstate/_state.py,sha256=KJclcHKGrIt8K_rDW3E2dO8g_f_UMcZwGID4UUb9MBE,60751
|
3
4
|
brainstate/_state_test.py,sha256=UBbbGJ8cb9dJ3NeySf-TNs_nNP47Ax8CP7QL_b32MAA,1636
|
4
5
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
@@ -6,7 +7,7 @@ brainstate/environ.py,sha256=PllYYZKqany3G7NzIwoUPplLAePbyza6kJGXTPgJK-c,17698
|
|
6
7
|
brainstate/environ_test.py,sha256=khZ_-SUJL6rQCgndeYV98ruUIHGTwFDtITrOs_olmuo,2043
|
7
8
|
brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
|
8
9
|
brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
|
9
|
-
brainstate/surrogate.py,sha256=
|
10
|
+
brainstate/surrogate.py,sha256=nSv2oaO0HBfbsKECr8ctpQSKbZveQAzidGdUrMcZQYU,53508
|
10
11
|
brainstate/transform.py,sha256=OvshYpPnp3YXPG6riy15Ve7jcX8t0aaATl1xZnsFeic,858
|
11
12
|
brainstate/typing.py,sha256=988gX1tvwtyYnYjmej90OaRxoMoBIPO0-DSrXXGxojM,10523
|
12
13
|
brainstate/augment/__init__.py,sha256=Q9-JIwQ1FNn8VLS1MA9MrSylbvUjWSw98whrI3NIuKo,1229
|
@@ -20,7 +21,7 @@ brainstate/augment/_random.py,sha256=ikRzNoDDE2BkARajDsBhNlngCUrghzGSZUDmEGvVors
|
|
20
21
|
brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
|
21
22
|
brainstate/compile/_ad_checkpoint.py,sha256=3wv-f89oo94XeWwRV5LcRot0Nz7xTk5_PdjEDyUMsoo,9394
|
22
23
|
brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
|
23
|
-
brainstate/compile/_conditions.py,sha256=
|
24
|
+
brainstate/compile/_conditions.py,sha256=G-MioUGu2gjYV7wqulxTx8TfdVPx4AWbBi9V3sFw0uY,10221
|
24
25
|
brainstate/compile/_conditions_test.py,sha256=3Y1drxdNvTAm5CZu_ui00v9RTIaLDWoiRbzbmJjZrjQ,8387
|
25
26
|
brainstate/compile/_error_if.py,sha256=JIzvnAAW2Ze8aniw_CLZYktb1_WVtKLhP3eGCUhQNgg,2687
|
26
27
|
brainstate/compile/_error_if_test.py,sha256=j4x2bzWIWstwLzzt3R9hmencinSSveSrLt5AB7PJC1A,2042
|
@@ -30,10 +31,10 @@ brainstate/compile/_loop_collect_return.py,sha256=-LsP7fkHmAyGnDOKa3BxxYOEWe8M2J
|
|
30
31
|
brainstate/compile/_loop_collect_return_test.py,sha256=D9RQ5RyQHkqBr4nmSK-yM_uge3EC6uVm_Dzy42g3vtg,1802
|
31
32
|
brainstate/compile/_loop_no_collection.py,sha256=2OEVtv5XztOx-e0focZ1UnWkXmFzmDskjHJXuVXmuhA,7587
|
32
33
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
33
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
34
|
-
brainstate/compile/_make_jaxpr_test.py,sha256=
|
34
|
+
brainstate/compile/_make_jaxpr.py,sha256=kh2bX1kQL9VLBic8E6MNa-gBO_G-4tJuFKTP517_PYE,33097
|
35
|
+
brainstate/compile/_make_jaxpr_test.py,sha256=EYUDoVKm2lDQeVCC3rBElRzH61LO76C61XgmEFtKcbU,4288
|
35
36
|
brainstate/compile/_progress_bar.py,sha256=3Z3OVcc5sl9FK9Fkt813l20MNzEfa6UZ9lJrvSgXTCU,7522
|
36
|
-
brainstate/compile/_unvmap.py,sha256=
|
37
|
+
brainstate/compile/_unvmap.py,sha256=dHLA6jkZdZ9_hPZ-4ovlPmZpi7Fl6Z4P6bta_Zgo1e0,4158
|
37
38
|
brainstate/compile/_util.py,sha256=a_tunKZ1OzVowCI2JmcniQz5P6bqZ4BJkDKmA_h7s6Y,6313
|
38
39
|
brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
|
39
40
|
brainstate/functional/_activations.py,sha256=VmCU9HOKWbysxuJFBN-JsShS4loNMG_E6IXfky1tX-s,21724
|
@@ -121,8 +122,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
|
|
121
122
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
122
123
|
brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
|
123
124
|
brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
|
124
|
-
brainstate-0.1.0.
|
125
|
-
brainstate-0.1.0.
|
126
|
-
brainstate-0.1.0.
|
127
|
-
brainstate-0.1.0.
|
128
|
-
brainstate-0.1.0.
|
125
|
+
brainstate-0.1.0.post20250420.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
126
|
+
brainstate-0.1.0.post20250420.dist-info/METADATA,sha256=gHk2WgxJm1CirBO0QaiO85mcexK3eziLD9ahYST_RZ0,3689
|
127
|
+
brainstate-0.1.0.post20250420.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
128
|
+
brainstate-0.1.0.post20250420.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
129
|
+
brainstate-0.1.0.post20250420.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250406.dist-info → brainstate-0.1.0.post20250420.dist-info}/top_level.txt
RENAMED
File without changes
|