brainstate 0.1.0.post20250423__py2.py3-none-any.whl → 0.1.0.post20250503__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_compatible_import.py +84 -1
- brainstate/compile/_make_jaxpr.py +12 -6
- {brainstate-0.1.0.post20250423.dist-info → brainstate-0.1.0.post20250503.dist-info}/METADATA +6 -2
- {brainstate-0.1.0.post20250423.dist-info → brainstate-0.1.0.post20250503.dist-info}/RECORD +7 -7
- {brainstate-0.1.0.post20250423.dist-info → brainstate-0.1.0.post20250503.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250423.dist-info → brainstate-0.1.0.post20250503.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250423.dist-info → brainstate-0.1.0.post20250503.dist-info}/top_level.txt +0 -0
brainstate/_compatible_import.py
CHANGED
@@ -18,7 +18,8 @@
|
|
18
18
|
|
19
19
|
import importlib.util
|
20
20
|
from contextlib import contextmanager
|
21
|
-
from
|
21
|
+
from functools import partial
|
22
|
+
from typing import Iterable, Hashable, TypeVar, Callable
|
22
23
|
|
23
24
|
import jax
|
24
25
|
|
@@ -31,8 +32,18 @@ __all__ = [
|
|
31
32
|
'Tracer',
|
32
33
|
'to_concrete_aval',
|
33
34
|
'brainevent',
|
35
|
+
'safe_map',
|
36
|
+
'safe_zip',
|
37
|
+
'unzip2',
|
38
|
+
'unzip3',
|
39
|
+
'wraps',
|
34
40
|
]
|
35
41
|
|
42
|
+
T = TypeVar("T")
|
43
|
+
T1 = TypeVar("T1")
|
44
|
+
T2 = TypeVar("T2")
|
45
|
+
T3 = TypeVar("T3")
|
46
|
+
|
36
47
|
brainevent_installed = importlib.util.find_spec('brainevent') is not None
|
37
48
|
|
38
49
|
from jax.core import get_aval, Tracer
|
@@ -53,6 +64,78 @@ else:
|
|
53
64
|
finally:
|
54
65
|
trace_ctx.set_axis_env(prev)
|
55
66
|
|
67
|
+
# Compatibility shims for helpers that older JAX releases expose from
# ``jax.util``. Below JAX 0.6.0 the upstream implementations are re-exported
# unchanged; from 0.6.0 on they are re-implemented locally (presumably because
# ``jax.util`` no longer provides them there -- confirm against the JAX
# changelog). Both branches must define the same public names.
if jax.__version_info__ < (0, 6, 0):
    from jax.util import safe_map, safe_zip, unzip2, wraps

else:
    def safe_map(f, *args):
        """Eagerly apply ``f`` element-wise across ``args`` and return a list.

        Unlike the built-in ``map``, every input iterable must have the same
        length; a mismatch raises ``AssertionError`` instead of silently
        truncating to the shortest input.
        """
        # Materialize every input so the lengths can be compared up front.
        args = list(map(list, args))
        n = len(args[0])
        for arg in args[1:]:
            # NOTE: this is an ``assert``, so it is stripped under ``python -O``.
            assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
        return list(map(f, *args))


    def safe_zip(*args):
        """Eagerly zip ``args`` into a list of tuples.

        Unlike the built-in ``zip``, every input iterable must have the same
        length; a mismatch raises ``AssertionError`` instead of silently
        truncating to the shortest input.
        """
        # Materialize every input so the lengths can be compared up front.
        args = list(map(list, args))
        n = len(args[0])
        for arg in args[1:]:
            # NOTE: this is an ``assert``, so it is stripped under ``python -O``.
            assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
        return list(zip(*args))


    def unzip2(xys: Iterable[tuple[T1, T2]]
               ) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
        """Unzip sequence of length-2 tuples into two tuples."""
        # Note: we deliberately don't use zip(*xys) because it is lazily evaluated,
        # is too permissive about inputs, and does not guarantee a length-2 output.
        xs: list[T1] = []
        ys: list[T2] = []
        for x, y in xys:
            xs.append(x)
            ys.append(y)
        return tuple(xs), tuple(ys)


    def fun_name(fun: Callable):
        """Return a printable name for ``fun``.

        Uses ``fun.__name__`` when present; otherwise unwraps
        ``functools.partial`` objects recursively, and falls back to
        ``"<unnamed function>"``.
        """
        name = getattr(fun, "__name__", None)
        if name is not None:
            return name
        if isinstance(fun, partial):
            return fun_name(fun.func)
        else:
            return "<unnamed function>"


    def wraps(
        wrapped: Callable,
        namestr: str | None = None,
        docstr: str | None = None,
        **kwargs,
    ) -> Callable[[T], T]:
        """
        Like functools.wraps, but with finer-grained control over the name and docstring
        of the resulting function.

        ``namestr`` and ``docstr``, when given, are format strings. ``namestr``
        receives ``fun`` (the wrapped function's name). ``docstr`` receives
        ``fun``, ``doc`` (the wrapped function's docstring), and any extra
        ``**kwargs``.
        """

        def wrapper(fun: T) -> T:
            try:
                name = fun_name(wrapped)
                doc = getattr(wrapped, "__doc__", "") or ""
                fun.__dict__.update(getattr(wrapped, "__dict__", {}))
                fun.__annotations__ = getattr(wrapped, "__annotations__", {})
                fun.__name__ = name if namestr is None else namestr.format(fun=name)
                fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
                fun.__doc__ = (doc if docstr is None
                               else docstr.format(fun=name, doc=doc, **kwargs))
                fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
                fun.__wrapped__ = wrapped
            except Exception:
                # Best-effort metadata copy: if any step fails (e.g. ``fun``
                # rejects attribute assignment), ``fun`` is returned with
                # whatever attributes were already set before the failure.
                pass
            return fun

        return wrapper


def unzip3(xyzs: Iterable[tuple[T1, T2, T3]]
           ) -> tuple[tuple[T1, ...], tuple[T2, ...], tuple[T3, ...]]:
    """Unzip sequence of length-3 tuples into three tuples.

    ``unzip3`` is exported via ``__all__``. It is defined here for every JAX
    version, in pure Python, so it never depends on ``jax.util``.
    """
    # As in ``unzip2``: avoid zip(*xyzs), which is lazy, too permissive about
    # its inputs, and does not guarantee a length-3 result.
    xs: list[T1] = []
    ys: list[T2] = []
    zs: list[T3] = []
    for x, y, z in xyzs:
        xs.append(x)
        ys.append(y)
        zs.append(z)
    return tuple(xs), tuple(ys), tuple(zs)
|
138
|
+
|
56
139
|
|
57
140
|
def to_concrete_aval(aval):
|
58
141
|
aval = get_aval(aval)
|
@@ -67,9 +67,15 @@ from jax._src.traceback_util import api_boundary
|
|
67
67
|
from jax.api_util import shaped_abstractify
|
68
68
|
from jax.extend.linear_util import transformation_with_aux, wrap_init
|
69
69
|
from jax.interpreters import partial_eval as pe
|
70
|
-
from jax.util import wraps
|
71
70
|
|
72
|
-
from brainstate._compatible_import import
|
71
|
+
from brainstate._compatible_import import (
|
72
|
+
ClosedJaxpr,
|
73
|
+
extend_axis_env_nd,
|
74
|
+
safe_map,
|
75
|
+
safe_zip,
|
76
|
+
unzip2,
|
77
|
+
wraps,
|
78
|
+
)
|
73
79
|
from brainstate._state import State, StateTraceStack
|
74
80
|
from brainstate._utils import set_module_as
|
75
81
|
from brainstate.typing import PyTree
|
@@ -89,7 +95,7 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
|
|
89
95
|
try:
|
90
96
|
return (operator.index(x),)
|
91
97
|
except TypeError:
|
92
|
-
return tuple(
|
98
|
+
return tuple(safe_map(operator.index, x))
|
93
99
|
|
94
100
|
|
95
101
|
def _new_arg_fn(frame, trace, aval):
|
@@ -733,7 +739,7 @@ def _make_jaxpr(
|
|
733
739
|
else:
|
734
740
|
axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
|
735
741
|
in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
|
736
|
-
in_avals, keep_inputs =
|
742
|
+
in_avals, keep_inputs = unzip2(in_type)
|
737
743
|
return in_avals, in_tree, keep_inputs
|
738
744
|
|
739
745
|
@wraps(fun)
|
@@ -744,7 +750,7 @@ def _make_jaxpr(
|
|
744
750
|
dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
|
745
751
|
f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
|
746
752
|
in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
|
747
|
-
in_type = tuple(
|
753
|
+
in_type = tuple(safe_zip(in_avals, keep_inputs))
|
748
754
|
f, out_tree = _flatten_fun(f, in_tree)
|
749
755
|
f = annotate(f, in_type)
|
750
756
|
if jax.__version_info__ < (0, 5, 0):
|
@@ -758,7 +764,7 @@ def _make_jaxpr(
|
|
758
764
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
|
759
765
|
closed_jaxpr = ClosedJaxpr(jaxpr, consts)
|
760
766
|
if return_shape:
|
761
|
-
out_avals, _ =
|
767
|
+
out_avals, _ = unzip2(out_type)
|
762
768
|
out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
|
763
769
|
return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
|
764
770
|
return closed_jaxpr
|
{brainstate-0.1.0.post20250423.dist-info → brainstate-0.1.0.post20250503.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: brainstate
|
3
|
-
Version: 0.1.0.post20250423
|
3
|
+
Version: 0.1.0.post20250503
|
4
4
|
Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
|
5
5
|
Home-page: https://github.com/chaobrain/brainstate
|
6
6
|
Author: BrainState Developers
|
@@ -48,7 +48,7 @@ Requires-Dist: jaxlib[tpu] ; extra == 'tpu'
|
|
48
48
|
|
49
49
|
|
50
50
|
<p align="center">
|
51
|
-
<img alt="Header image of brainstate." src="https://
|
51
|
+
<img alt="Header image of brainstate." src="https://raw.githubusercontent.com/chaobrain/brainstate/main/docs/_static/brainstate.png" width=40%>
|
52
52
|
</p>
|
53
53
|
|
54
54
|
|
@@ -74,7 +74,11 @@ You can install ``brainstate`` via pip:
|
|
74
74
|
pip install brainstate --upgrade
|
75
75
|
```
|
76
76
|
|
77
|
+
Alternatively, you can install `BrainX`, which bundles `brainstate` with other compatible packages for a comprehensive brain modeling ecosystem:
|
77
78
|
|
79
|
+
```bash
|
80
|
+
pip install BrainX -U
|
81
|
+
```
|
78
82
|
|
79
83
|
## Documentation
|
80
84
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
|
2
|
-
brainstate/_compatible_import.py,sha256=
|
2
|
+
brainstate/_compatible_import.py,sha256=7YGGnkoHMSV4hrTUTN0pTH8bmBCnu0L6E-eB8EdUokQ,4830
|
3
3
|
brainstate/_state.py,sha256=Qvb6O3LFUcq_V8wG6GLAgkckImEwfzDi79JFrZ-lVWc,60753
|
4
4
|
brainstate/_state_test.py,sha256=b6uvZdVRyC4n6-fYzmHNry1b-gJ6zE_kRSxGinqiHaw,1638
|
5
5
|
brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
|
@@ -31,7 +31,7 @@ brainstate/compile/_loop_collect_return.py,sha256=lMw66pARbKmWQnNISkNjuCAzSrKF92
|
|
31
31
|
brainstate/compile/_loop_collect_return_test.py,sha256=r-q-D2fYuO9oIGKCu7ZUZDlZj-RlPthdn7fmP04TxbQ,1804
|
32
32
|
brainstate/compile/_loop_no_collection.py,sha256=ij1On2Rj1_Wf7ZixbePNCWJXOi2XqjVHyH31NejV5D8,7587
|
33
33
|
brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
|
34
|
-
brainstate/compile/_make_jaxpr.py,sha256=
|
34
|
+
brainstate/compile/_make_jaxpr.py,sha256=cYK6ReeKJANdoMv4TCuCKYdC6m9J_r-h3AGCmvbUNW0,33088
|
35
35
|
brainstate/compile/_make_jaxpr_test.py,sha256=3rfb1oz0u1wHPmnppnmo2BlUBbEE02FlbvD6Rehub9I,4290
|
36
36
|
brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
|
37
37
|
brainstate/compile/_unvmap.py,sha256=oS_Pd2JgnMIUWfdsx_gkTLWnFAcbj8r5QM0D1JjgBT8,4156
|
@@ -126,8 +126,8 @@ brainstate/util/_pretty_table.py,sha256=c3c2UH8hIZ-lCas_KL462kLgsnBc6pjJiMuapPi6
|
|
126
126
|
brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
|
127
127
|
brainstate/util/_struct.py,sha256=RNQewdpyNcfntOJYefsDVWQjZW3RLsISKmTHpBhzyc8,17517
|
128
128
|
brainstate/util/filter.py,sha256=blTktYNaNgCsuwv7xABjvbWsoi4Fozov6C2cXX-ta2g,14124
|
129
|
-
brainstate-0.1.0.
|
130
|
-
brainstate-0.1.0.
|
131
|
-
brainstate-0.1.0.
|
132
|
-
brainstate-0.1.0.
|
133
|
-
brainstate-0.1.0.
|
129
|
+
brainstate-0.1.0.post20250503.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
|
130
|
+
brainstate-0.1.0.post20250503.dist-info/METADATA,sha256=OSUkTgf1YBIYUojDxojQcil4Gwj5pC2pEa8URkbxm7Y,3848
|
131
|
+
brainstate-0.1.0.post20250503.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
|
132
|
+
brainstate-0.1.0.post20250503.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
|
133
|
+
brainstate-0.1.0.post20250503.dist-info/RECORD,,
|
File without changes
|
File without changes
|
{brainstate-0.1.0.post20250423.dist-info → brainstate-0.1.0.post20250503.dist-info}/top_level.txt
RENAMED
File without changes
|