brainstate 0.1.0.post20250422__py2.py3-none-any.whl → 0.1.0.post20250501__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,7 +18,8 @@
18
18
 
19
19
  import importlib.util
20
20
  from contextlib import contextmanager
21
- from typing import Iterable, Hashable
21
+ from functools import partial
22
+ from typing import Iterable, Hashable, TypeVar, Callable
22
23
 
23
24
  import jax
24
25
 
@@ -31,8 +32,18 @@ __all__ = [
31
32
  'Tracer',
32
33
  'to_concrete_aval',
33
34
  'brainevent',
35
+ 'safe_map',
36
+ 'safe_zip',
37
+ 'unzip2',
38
+ 'unzip3',
39
+ 'wraps',
34
40
  ]
35
41
 
42
+ T = TypeVar("T")
43
+ T1 = TypeVar("T1")
44
+ T2 = TypeVar("T2")
45
+ T3 = TypeVar("T3")
46
+
36
47
  brainevent_installed = importlib.util.find_spec('brainevent') is not None
37
48
 
38
49
  from jax.core import get_aval, Tracer
@@ -53,6 +64,93 @@ else:
53
64
  finally:
54
65
  trace_ctx.set_axis_env(prev)
55
66
 
67
+ if jax.__version_info__ < (0, 6, 0):
68
+ from jax.util import safe_map, safe_zip, unzip2, unzip3, wraps
69
+
70
+ else:
71
def safe_map(f, *args):
    """Apply ``f`` across several iterables after checking their lengths agree.

    Each iterable in ``args`` is first materialized into a list, so
    generators are consumed exactly once and their lengths can be compared.
    The result is built eagerly and returned as a list.

    Args:
        f: Callable invoked with one element taken from each iterable.
        *args: One or more iterables that must all have the same length.

    Returns:
        A list containing ``f(*items)`` for every position.

    Raises:
        AssertionError: If any iterable's length differs from the first one.
        IndexError: If no iterable is supplied at all.
    """
    args = list(map(list, args))
    n = len(args[0])
    for arg in args[1:]:
        if len(arg) != n:
            # Raised explicitly instead of through an ``assert`` statement so
            # the check is not stripped when Python runs with ``-O``; without
            # it ``map`` would silently truncate to the shortest input.
            raise AssertionError(f'length mismatch: {list(map(len, args))}')
    return list(map(f, *args))
77
+
78
+
79
def safe_zip(*args):
    """Zip several iterables after checking their lengths agree.

    Each iterable in ``args`` is first materialized into a list, so
    generators are consumed exactly once and their lengths can be compared.
    The result is built eagerly and returned as a list of tuples.

    Args:
        *args: One or more iterables that must all have the same length.

    Returns:
        A list of tuples, one per position, as produced by ``zip``.

    Raises:
        AssertionError: If any iterable's length differs from the first one.
        IndexError: If no iterable is supplied at all.
    """
    args = list(map(list, args))
    n = len(args[0])
    for arg in args[1:]:
        if len(arg) != n:
            # Raised explicitly instead of through an ``assert`` statement so
            # the check is not stripped when Python runs with ``-O``; without
            # it ``zip`` would silently truncate to the shortest input.
            raise AssertionError(f'length mismatch: {list(map(len, args))}')
    return list(zip(*args))
85
+
86
+
87
def unzip2(xys: Iterable[tuple[T1, T2]]
           ) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
    """Split an iterable of pairs into a tuple of firsts and a tuple of seconds."""
    # ``zip(*xys)`` is avoided on purpose: it evaluates lazily, tolerates
    # malformed input, and cannot promise exactly two outputs.
    firsts: list[T1] = []
    seconds: list[T2] = []
    for pair in xys:
        # Unpacking fails for any element that is not exactly length two.
        first, second = pair
        firsts.append(first)
        seconds.append(second)
    return tuple(firsts), tuple(seconds)
98
+
99
+
100
def unzip3(xyzs: Iterable[tuple[T1, T2, T3]]
           ) -> tuple[tuple[T1, ...], tuple[T2, ...], tuple[T3, ...]]:
    """Split an iterable of triples into three tuples, one per position."""
    # ``zip(*xyzs)`` is avoided on purpose: it evaluates lazily, tolerates
    # malformed input, and cannot promise exactly three outputs.
    firsts: list[T1] = []
    seconds: list[T2] = []
    thirds: list[T3] = []
    for triple in xyzs:
        # Unpacking fails for any element that is not exactly length three.
        first, second, third = triple
        firsts.append(first)
        seconds.append(second)
        thirds.append(third)
    return tuple(firsts), tuple(seconds), tuple(thirds)
113
+
114
+
115
def fun_name(fun: Callable):
    """Return a display name for ``fun``.

    The ``__name__`` attribute is used when it exists and is not ``None``.
    Otherwise, a ``functools.partial`` is resolved through the callable it
    wraps, and anything else yields the placeholder ``"<unnamed function>"``.
    """
    own_name = getattr(fun, "__name__", None)
    if own_name is not None:
        return own_name
    if not isinstance(fun, partial):
        return "<unnamed function>"
    # A partial carries no ``__name__`` of its own; look at what it wraps.
    return fun_name(fun.func)
123
+
124
+
125
def wraps(
    wrapped: Callable,
    namestr: str | None = None,
    docstr: str | None = None,
    **kwargs,
) -> Callable[[T], T]:
    """
    Build a decorator that copies metadata from ``wrapped`` onto a function.

    Behaves like functools.wraps, except that the resulting name and docstring
    can be produced from templates: ``namestr`` is formatted with
    ``fun=<name of wrapped>``, and ``docstr`` is formatted with
    ``fun=<name of wrapped>``, ``doc=<docstring of wrapped>`` and ``**kwargs``.
    Copying is best-effort: any exception raised while updating the attributes
    is swallowed and the function is returned in whatever state it reached.
    """

    def wrapper(fun: T) -> T:
        try:
            base_name = fun_name(wrapped)
            base_doc = getattr(wrapped, "__doc__", "") or ""
            fun.__dict__.update(getattr(wrapped, "__dict__", {}))
            fun.__annotations__ = getattr(wrapped, "__annotations__", {})
            if namestr is None:
                fun.__name__ = base_name
            else:
                fun.__name__ = namestr.format(fun=base_name)
            fun.__module__ = getattr(wrapped, "__module__", "<unknown module>")
            if docstr is None:
                fun.__doc__ = base_doc
            else:
                fun.__doc__ = docstr.format(fun=base_name, doc=base_doc, **kwargs)
            # Falls back to the name assigned just above when ``wrapped``
            # has no ``__qualname__``.
            fun.__qualname__ = getattr(wrapped, "__qualname__", fun.__name__)
            fun.__wrapped__ = wrapped
        except Exception:
            # Deliberately best-effort: metadata copying must never prevent
            # the decorated function from being returned.
            pass
        return fun

    return wrapper
153
+
56
154
 
57
155
  def to_concrete_aval(aval):
58
156
  aval = get_aval(aval)
@@ -67,9 +67,15 @@ from jax._src.traceback_util import api_boundary
67
67
  from jax.api_util import shaped_abstractify
68
68
  from jax.extend.linear_util import transformation_with_aux, wrap_init
69
69
  from jax.interpreters import partial_eval as pe
70
- from jax.util import wraps
71
70
 
72
- from brainstate._compatible_import import ClosedJaxpr, extend_axis_env_nd
71
+ from brainstate._compatible_import import (
72
+ ClosedJaxpr,
73
+ extend_axis_env_nd,
74
+ safe_map,
75
+ safe_zip,
76
+ unzip2,
77
+ wraps,
78
+ )
73
79
  from brainstate._state import State, StateTraceStack
74
80
  from brainstate._utils import set_module_as
75
81
  from brainstate.typing import PyTree
@@ -89,7 +95,7 @@ def _ensure_index_tuple(x: Any) -> tuple[int, ...]:
89
95
  try:
90
96
  return (operator.index(x),)
91
97
  except TypeError:
92
- return tuple(jax.util.safe_map(operator.index, x))
98
+ return tuple(safe_map(operator.index, x))
93
99
 
94
100
 
95
101
  def _new_arg_fn(frame, trace, aval):
@@ -733,7 +739,7 @@ def _make_jaxpr(
733
739
  else:
734
740
  axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
735
741
  in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
736
- in_avals, keep_inputs = jax.util.unzip2(in_type)
742
+ in_avals, keep_inputs = unzip2(in_type)
737
743
  return in_avals, in_tree, keep_inputs
738
744
 
739
745
  @wraps(fun)
@@ -744,7 +750,7 @@ def _make_jaxpr(
744
750
  dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
745
751
  f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
746
752
  in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
747
- in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
753
+ in_type = tuple(safe_zip(in_avals, keep_inputs))
748
754
  f, out_tree = _flatten_fun(f, in_tree)
749
755
  f = annotate(f, in_type)
750
756
  if jax.__version_info__ < (0, 5, 0):
@@ -758,7 +764,7 @@ def _make_jaxpr(
758
764
  jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f)
759
765
  closed_jaxpr = ClosedJaxpr(jaxpr, consts)
760
766
  if return_shape:
761
- out_avals, _ = jax.util.unzip2(out_type)
767
+ out_avals, _ = unzip2(out_type)
762
768
  out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals]
763
769
  return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
764
770
  return closed_jaxpr
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.1.0.post20250422
3
+ Version: 0.1.0.post20250501
4
4
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
5
5
  Home-page: https://github.com/chaobrain/brainstate
6
6
  Author: BrainState Developers
@@ -1,5 +1,5 @@
1
1
  brainstate/__init__.py,sha256=AkZyyFkn4fB8g2aT6Rc2MO1xICPpUZuDtdze-eUQNc0,1496
2
- brainstate/_compatible_import.py,sha256=ibxO3Zklu3fxdVsgP4RUe7FgQqs3cZkKbdRGo9GaZKA,2091
2
+ brainstate/_compatible_import.py,sha256=nEmjwX7aCH_kiXsVJXMjOd_SMj2etMiHxCF3z3H8GF0,5444
3
3
  brainstate/_state.py,sha256=Qvb6O3LFUcq_V8wG6GLAgkckImEwfzDi79JFrZ-lVWc,60753
4
4
  brainstate/_state_test.py,sha256=b6uvZdVRyC4n6-fYzmHNry1b-gJ6zE_kRSxGinqiHaw,1638
5
5
  brainstate/_utils.py,sha256=uJ6WWKq3yb05ZdktCQGLWOXsOJveL1H9pR7eev70Jes,1693
@@ -31,7 +31,7 @@ brainstate/compile/_loop_collect_return.py,sha256=lMw66pARbKmWQnNISkNjuCAzSrKF92
31
31
  brainstate/compile/_loop_collect_return_test.py,sha256=r-q-D2fYuO9oIGKCu7ZUZDlZj-RlPthdn7fmP04TxbQ,1804
32
32
  brainstate/compile/_loop_no_collection.py,sha256=ij1On2Rj1_Wf7ZixbePNCWJXOi2XqjVHyH31NejV5D8,7587
33
33
  brainstate/compile/_loop_no_collection_test.py,sha256=oStB1CSG_iLp9sHdXd1hJNFvlxbzjck9Iy4sABoJDj4,1419
34
- brainstate/compile/_make_jaxpr.py,sha256=gZdjXq9kd0G_R0Wz_ttiCXGGaKG3MpdJKMZ48xuIJzg,33081
34
+ brainstate/compile/_make_jaxpr.py,sha256=cYK6ReeKJANdoMv4TCuCKYdC6m9J_r-h3AGCmvbUNW0,33088
35
35
  brainstate/compile/_make_jaxpr_test.py,sha256=3rfb1oz0u1wHPmnppnmo2BlUBbEE02FlbvD6Rehub9I,4290
36
36
  brainstate/compile/_progress_bar.py,sha256=5pCMCEmbTO5XmKtzRUJGA178tuBznWKuh9Kw00wAL1I,7524
37
37
  brainstate/compile/_unvmap.py,sha256=oS_Pd2JgnMIUWfdsx_gkTLWnFAcbj8r5QM0D1JjgBT8,4156
@@ -126,8 +126,8 @@ brainstate/util/_pretty_table.py,sha256=c3c2UH8hIZ-lCas_KL462kLgsnBc6pjJiMuapPi6
126
126
  brainstate/util/_scaling.py,sha256=pc_eM_SZVwkY65I4tJh1ODiHNCoEhsfFXl2zBK0PLAg,7562
127
127
  brainstate/util/_struct.py,sha256=RNQewdpyNcfntOJYefsDVWQjZW3RLsISKmTHpBhzyc8,17517
128
128
  brainstate/util/filter.py,sha256=blTktYNaNgCsuwv7xABjvbWsoi4Fozov6C2cXX-ta2g,14124
129
- brainstate-0.1.0.post20250422.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
130
- brainstate-0.1.0.post20250422.dist-info/METADATA,sha256=oKYcNUrBWjw-Hp4-77lv6T2G18f883loEao6NEgHOyU,3655
131
- brainstate-0.1.0.post20250422.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
132
- brainstate-0.1.0.post20250422.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
133
- brainstate-0.1.0.post20250422.dist-info/RECORD,,
129
+ brainstate-0.1.0.post20250501.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
130
+ brainstate-0.1.0.post20250501.dist-info/METADATA,sha256=l6z_DHPZegmbLm2JoQT101SmHbULrVJ5yLn-Dd_NvhM,3655
131
+ brainstate-0.1.0.post20250501.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
132
+ brainstate-0.1.0.post20250501.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
133
+ brainstate-0.1.0.post20250501.dist-info/RECORD,,