brainstate 0.1.0.post20250406__py2.py3-none-any.whl → 0.1.0.post20250420__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,58 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ from contextlib import contextmanager
20
+ from typing import Iterable, Hashable
21
+
22
+ import jax
23
+
24
# Public surface of the JAX compatibility shim.  Order is kept stable.
__all__ = [
    'ClosedJaxpr',         # closed jaxpr container (import location depends on JAX version)
    'Primitive',           # JAX primitive base class (import location depends on JAX version)
    'extend_axis_env_nd',  # context manager extending the named-axis environment
    'jaxpr_as_fun',        # turns a ClosedJaxpr into a Python callable
    'get_aval',            # re-exported from jax.core
    'Tracer',              # re-exported from jax.core
    'to_concrete_aval',    # helper defined in this module
]
33
+
34
+ from jax.core import get_aval, Tracer
35
+
36
if jax.__version_info__ >= (0, 4, 38):
    # From JAX 0.4.38 on, these objects are taken from ``jax.extend.core``, and
    # ``extend_axis_env_nd`` is defined locally (below) on top of the tracing
    # context rather than imported from ``jax.core``.
    from jax.extend.core import ClosedJaxpr, Primitive, jaxpr_as_fun
    from jax.core import trace_ctx

    @contextmanager
    def extend_axis_env_nd(name_size_pairs: Iterable[tuple[Hashable, int]]):
        """Temporarily add named axes to the current tracing axis environment.

        Parameters
        ----------
        name_size_pairs : Iterable[tuple[Hashable, int]]
            ``(axis_name, axis_size)`` pairs added on top of the current
            environment for the duration of the ``with`` block.

        Notes
        -----
        The environment that was active on entry is always reinstated on exit,
        including when the body raises.
        """
        saved_env = trace_ctx.axis_env
        try:
            trace_ctx.set_axis_env(saved_env.extend_pure(name_size_pairs))
            yield
        finally:
            # Restore unconditionally so an exception cannot leak the axes.
            trace_ctx.set_axis_env(saved_env)
else:
    # Older JAX releases still expose everything directly from ``jax.core``.
    from jax.core import ClosedJaxpr, extend_axis_env_nd, Primitive, jaxpr_as_fun
51
+
52
+
53
+ def to_concrete_aval(aval):
54
+ aval = get_aval(aval)
55
+ if isinstance(aval, Tracer):
56
+ return aval.to_concrete_value()
57
+ return aval
58
+
@@ -24,6 +24,8 @@ from brainstate._utils import set_module_as
24
24
  from ._error_if import jit_error_if
25
25
  from ._make_jaxpr import StatefulFunction
26
26
  from ._util import wrap_single_fun_in_multi_branches, write_back_state_values
27
+ from brainstate._compatible_import import to_concrete_aval, Tracer
28
+
27
29
 
28
30
  __all__ = [
29
31
  'cond', 'switch', 'ifelse',
@@ -86,7 +88,7 @@ def cond(pred, true_fun: Callable, false_fun: Callable, *operands):
86
88
  raise TypeError("Pred type must be either boolean or number, got {}.".format(pred_dtype))
87
89
 
88
90
  # not jit
89
- if jax.config.jax_disable_jit and isinstance(jax.core.get_aval(pred), jax.core.ConcreteArray):
91
+ if jax.config.jax_disable_jit and not isinstance(to_concrete_aval(pred), Tracer):
90
92
  if pred:
91
93
  return true_fun(*operands)
92
94
  else:
@@ -55,10 +55,12 @@ from __future__ import annotations
55
55
 
56
56
  import functools
57
57
  import inspect
58
- import jax
59
58
  import operator
60
59
  from collections.abc import Hashable, Iterable, Sequence
61
- from contextlib import ExitStack
60
+ from contextlib import ExitStack, contextmanager
61
+ from typing import Any, Callable, Tuple, Union, Dict, Optional
62
+
63
+ import jax
62
64
  from jax._src import source_info_util
63
65
  from jax._src.linear_util import annotate
64
66
  from jax._src.traceback_util import api_boundary
@@ -66,17 +68,12 @@ from jax.api_util import shaped_abstractify
66
68
  from jax.extend.linear_util import transformation_with_aux, wrap_init
67
69
  from jax.interpreters import partial_eval as pe
68
70
  from jax.util import wraps
69
- from typing import Any, Callable, Tuple, Union, Dict, Optional
70
71
 
71
72
  from brainstate._state import State, StateTraceStack
72
73
  from brainstate._utils import set_module_as
73
74
  from brainstate.typing import PyTree
74
75
  from brainstate.util import PrettyObject
75
-
76
- if jax.__version_info__ < (0, 4, 38):
77
- from jax.core import ClosedJaxpr
78
- else:
79
- from jax.extend.core import ClosedJaxpr
76
+ from brainstate._compatible_import import ClosedJaxpr, extend_axis_env_nd
80
77
 
81
78
  AxisName = Hashable
82
79
 
@@ -200,7 +197,7 @@ class StatefulFunction(PrettyObject):
200
197
 
201
198
  # implicit parameters
202
199
  self.cache_type = cache_type
203
- self._cached_jaxpr: Dict[Any, jax.core.ClosedJaxpr] = dict()
200
+ self._cached_jaxpr: Dict[Any, ClosedJaxpr] = dict()
204
201
  self._cached_out_shapes: Dict[Any, PyTree] = dict()
205
202
  self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
206
203
  self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
@@ -210,7 +207,7 @@ class StatefulFunction(PrettyObject):
210
207
  return None
211
208
  return k, v
212
209
 
213
- def get_jaxpr(self, cache_key: Hashable = ()) -> jax.core.ClosedJaxpr:
210
+ def get_jaxpr(self, cache_key: Hashable = ()) -> ClosedJaxpr:
214
211
  """
215
212
  Read the JAX Jaxpr representation of the function.
216
213
 
@@ -507,8 +504,8 @@ def make_jaxpr(
507
504
  return_shape: bool = False,
508
505
  abstracted_axes: Optional[Any] = None,
509
506
  state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
510
- ) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
511
- Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
507
+ ) -> Callable[..., (Tuple[ClosedJaxpr, Tuple[State, ...]] |
508
+ Tuple[ClosedJaxpr, Tuple[State, ...], PyTree])]:
512
509
  """
513
510
  Creates a function that produces its jaxpr given example args.
514
511
 
@@ -754,12 +751,12 @@ def _make_jaxpr(
754
751
  debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
755
752
  with ExitStack() as stack:
756
753
  if axis_env is not None:
757
- stack.enter_context(jax.core.extend_axis_env_nd(axis_env))
754
+ stack.enter_context(extend_axis_env_nd(axis_env))
758
755
  if jax.__version_info__ < (0, 5, 0):
759
756
  jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
760
757
  else:
761
758
  jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
762
- closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
759
+ closed_jaxpr = ClosedJaxpr(jaxpr, consts)
763
760
  if return_shape:
764
761
  out_avals, _ = jax.util.unzip2(out_type)
765
762
  out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
@@ -21,11 +21,7 @@ import pytest
21
21
  import unittest
22
22
 
23
23
  import brainstate as bst
24
-
25
- if jax.__version_info__ < (0, 4, 38):
26
- from jax.core import jaxpr_as_fun
27
- else:
28
- from jax.extend.core import jaxpr_as_fun
24
+ from brainstate._compatible_import import jaxpr_as_fun
29
25
 
30
26
 
31
27
  class TestMakeJaxpr(unittest.TestCase):
@@ -21,11 +21,8 @@ import jax.interpreters.mlir as mlir
21
21
  import jax.numpy as jnp
22
22
 
23
23
  from brainstate._utils import set_module_as
24
+ from brainstate._compatible_import import Primitive
24
25
 
25
- if jax.__version_info__ < (0, 4, 38):
26
- from jax.core import Primitive
27
- else:
28
- from jax.extend.core import Primitive
29
26
 
30
27
  __all__ = [
31
28
  "unvmap",
brainstate/surrogate.py CHANGED
@@ -22,11 +22,8 @@ import jax.scipy as sci
22
22
  from jax.interpreters import batching, ad, mlir
23
23
 
24
24
  from brainstate.util._pretty_pytree import PrettyObject
25
+ from brainstate._compatible_import import Primitive
25
26
 
26
- if jax.__version_info__ < (0, 4, 38):
27
- from jax.core import Primitive
28
- else:
29
- from jax.extend.core import Primitive
30
27
 
31
28
  __all__ = [
32
29
  'Surrogate',
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250406
3
+ Version: 0.1.0.post20250420
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -1,4 +1,5 @@
1
1
  brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
2
+ brainstate/_compatible_import.py,sha256=rFpe8602qIzFTSkGJ5aDm-TaWAsmXGb1p2pphU1jzE8,1720
2
3
  brainstate/_state.py,sha256=KJclcHKGrIt8K_rDW3E2dO8g_f_UMcZwGID4UUb9MBE,60751
3
4
  brainstate/_state_test.py,sha256=UBbbGJ8cb9dJ3NeySf-TNs_nNP47Ax8CP7QL_b32MAA,1636
4
5
  brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
@@ -6,7 +7,7 @@ brainstate/environ.py,sha256=PllYYZKqany3G7NzIwoUPplLAePbyza6kJGXTPgJK-c,17698
6
7
  brainstate/environ_test.py,sha256=khZ_-SUJL6rQCgndeYV98ruUIHGTwFDtITrOs_olmuo,2043
7
8
  brainstate/mixin.py,sha256=g7uVUwZphZWsNs9pb48ozG2cDGaj0hs0g3lq8tDk-Sg,11310
8
9
  brainstate/mixin_test.py,sha256=Oq_0fwC9vpXDN4t4dTBhWzLdFDNlcYsrcip14F1yECI,3079
9
- brainstate/surrogate.py,sha256=wWYw-TxaFxHVneXuHjWD1UtTcOTk3XRSnhRtUkt_Hb8,53580
10
+ brainstate/surrogate.py,sha256=nSv2oaO0HBfbsKECr8ctpQSKbZveQAzidGdUrMcZQYU,53508
10
11
  brainstate/transform.py,sha256=OvshYpPnp3YXPG6riy15Ve7jcX8t0aaATl1xZnsFeic,858
11
12
  brainstate/typing.py,sha256=988gX1tvwtyYnYjmej90OaRxoMoBIPO0-DSrXXGxojM,10523
12
13
  brainstate/augment/__init__.py,sha256=Q9-JIwQ1FNn8VLS1MA9MrSylbvUjWSw98whrI3NIuKo,1229
@@ -20,7 +21,7 @@ brainstate/augment/_random.py,sha256=ikRzNoDDE2BkARajDsBhNlngCUrghzGSZUDmEGvVors
20
21
  brainstate/compile/__init__.py,sha256=fQtG316MLkeeu1Ssp54Kghw1PwbGK5gNq9yRVJu0wjA,1474
21
22
  brainstate/compile/_ad_checkpoint.py,sha256=3wv-f89oo94XeWwRV5LcRot0Nz7xTk5_PdjEDyUMsoo,9394
22
23
  brainstate/compile/_ad_checkpoint_test.py,sha256=R1I76nG4zIqb6g3M_VxWts7rUC1OHJCjtQhPkcbXodk,1746
23
- brainstate/compile/_conditions.py,sha256=POPdqRvFbd0dETkShPbsXUvhRQI8MhgU71ijnvtcqP0,10164
24
+ brainstate/compile/_conditions.py,sha256=G-MioUGu2gjYV7wqulxTx8TfdVPx4AWbBi9V3sFw0uY,10221
24
25
  brainstate/compile/_conditions_test.py,sha256=3Y1drxdNvTAm5CZu_ui00v9RTIaLDWoiRbzbmJjZrjQ,8387
25
26
  brainstate/compile/_error_if.py,sha256=JIzvnAAW2Ze8aniw_CLZYktb1_WVtKLhP3eGCUhQNgg,2687
26
27
  brainstate/compile/_error_if_test.py,sha256=j4x2bzWIWstwLzzt3R9hmencinSSveSrLt5AB7PJC1A,2042
@@ -30,10 +31,10 @@ brainstate/compile/_loop_collect_return.py,sha256=-LsP7fkHmAyGnDOKa3BxxYOEWe8M2J
30
31
  brainstate/compile/_loop_collect_return_test.py,sha256=D9RQ5RyQHkqBr4nmSK-yM_uge3EC6uVm_Dzy42g3vtg,1802
31
32
  brainstate/compile/_loop_no_collection.py,sha256=2OEVtv5XztOx-e0focZ1UnWkXmFzmDskjHJXuVXmuhA,7587
32
33
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
33
- brainstate/compile/_make_jaxpr.py,sha256=8iV8XyvkMH3n3wbEWZAgZtbrUxryljwQJD6o5DMW9Lc,33189
34
- brainstate/compile/_make_jaxpr_test.py,sha256=fZe3K4RHFLmMAeXZoFZ5RyxgXvncTcuMQdjmOROJtKU,4365
34
+ brainstate/compile/_make_jaxpr.py,sha256=kh2bX1kQL9VLBic8E6MNa-gBO_G-4tJuFKTP517_PYE,33097
35
+ brainstate/compile/_make_jaxpr_test.py,sha256=EYUDoVKm2lDQeVCC3rBElRzH61LO76C61XgmEFtKcbU,4288
35
36
  brainstate/compile/_progress_bar.py,sha256=3Z3OVcc5sl9FK9Fkt813l20MNzEfa6UZ9lJrvSgXTCU,7522
36
- brainstate/compile/_unvmap.py,sha256=uCvQjvb8J7kT0kalX576mrAPvQuCh_W76EPdgZ53kTM,4230
37
+ brainstate/compile/_unvmap.py,sha256=dHLA6jkZdZ9_hPZ-4ovlPmZpi7Fl6Z4P6bta_Zgo1e0,4158
37
38
  brainstate/compile/_util.py,sha256=a_tunKZ1OzVowCI2JmcniQz5P6bqZ4BJkDKmA_h7s6Y,6313
38
39
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
39
40
  brainstate/functional/_activations.py,sha256=VmCU9HOKWbysxuJFBN-JsShS4loNMG_E6IXfky1tX-s,21724
@@ -121,8 +122,8 @@ brainstate/util/_pretty_table.py,sha256=NM_6VAW6oL9jojsK0-RkQGHnDzLy_fn_hgzl5R8o
121
122
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
122
123
  brainstate/util/_struct.py,sha256=F5GfFURITAIYTwf17_xypkZU1wvoL4dUCviPnr_eCtw,17515
123
124
  brainstate/util/filter.py,sha256=Zw0H42NwAi2P7dBr3ISv2VpkB5jqoWnV4Kpd61gq66o,14126
124
- brainstate-0.1.0.post20250406.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
125
- brainstate-0.1.0.post20250406.dist-info/METADATA,sha256=VK2aJl8p_7ET5VL8_UcAQTTOg2ynbbs5msPCZiFEjpE,3689
126
- brainstate-0.1.0.post20250406.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
127
- brainstate-0.1.0.post20250406.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
128
- brainstate-0.1.0.post20250406.dist-info/RECORD,,
125
+ brainstate-0.1.0.post20250420.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
126
+ brainstate-0.1.0.post20250420.dist-info/METADATA,sha256=gHk2WgxJm1CirBO0QaiO85mcexK3eziLD9ahYST_RZ0,3689
127
+ brainstate-0.1.0.post20250420.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
128
+ brainstate-0.1.0.post20250420.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
129
+ brainstate-0.1.0.post20250420.dist-info/RECORD,,