brainstate 0.0.2.post20240825__py2.py3-none-any.whl → 0.0.2.post20240826__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,170 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ from typing import Any, Callable, TypeVar
20
+
21
+ import jax
22
+
23
+ from brainstate._utils import set_module_as
24
+ from ._conditions import _wrapped_fun
25
+ from ._loop_collect_return import _bounded_while_loop
26
+ from ._make_jaxpr import StatefulFunction, _assign_state_values
27
+
28
+ X = TypeVar('X')
29
+ Y = TypeVar('Y')
30
+ T = TypeVar('T')
31
+ Carry = TypeVar('Carry')
32
+ BooleanNumeric = Any # A bool, or a Boolean array.
33
+
34
+ __all__ = [
35
+ # while loop
36
+ 'while_loop', 'bounded_while_loop',
37
+ ]
38
+
39
+
40
+ @set_module_as('brainstate.transform')
41
+ def while_loop(
42
+ cond_fun: Callable[[T], BooleanNumeric],
43
+ body_fun: Callable[[T], T],
44
+ init_val: T
45
+ ) -> T:
46
+ """
47
+ Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.
48
+
49
+ The `Haskell-like type signature`_ in brief is
50
+
51
+ .. code-block:: haskell
52
+
53
+ while_loop :: (a -> Bool) -> (a -> a) -> a -> a
54
+
55
+ The semantics of ``while_loop`` are given by this Python implementation::
56
+
57
+ def while_loop(cond_fun, body_fun, init_val):
58
+ val = init_val
59
+ while cond_fun(val):
60
+ val = body_fun(val)
61
+ return val
62
+
63
+ Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
64
+ to a single WhileOp. That makes it useful for reducing compilation times
65
+ for jit-compiled functions, since native Python loop constructs in an ``@jit``
66
+ function are unrolled, leading to large XLA computations.
67
+
68
+ Also unlike the Python analogue, the loop-carried value ``val`` must hold a
69
+ fixed shape and dtype across all iterations (and not just be consistent up to
70
+ NumPy rank/shape broadcasting and dtype promotion rules, for example). In
71
+ other words, the type ``a`` in the type signature above represents an array
72
+ with a fixed shape and dtype (or a nested tuple/list/dict container data
73
+ structure with a fixed structure and arrays with fixed shape and dtype at the
74
+ leaves).
75
+
76
+ Another difference from using Python-native loop constructs is that
77
+ ``while_loop`` is not reverse-mode differentiable because XLA computations
78
+ require static bounds on memory requirements.
79
+
80
+ Args:
81
+ cond_fun: function of type ``a -> Bool``.
82
+ body_fun: function of type ``a -> a``.
83
+ init_val: value of type ``a``, a type that can be a scalar, array, or any
84
+ pytree (nested Python tuple/list/dict) thereof, representing the initial
85
+ loop carry value.
86
+
87
+ Returns:
88
+ The output from the final iteration of body_fun, of type ``a``.
89
+
90
+ .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
91
+ """
92
+ if not (callable(body_fun) and callable(cond_fun)):
93
+ raise TypeError("while_loop: body_fun and cond_fun arguments should be callable.")
94
+ if jax.config.jax_disable_jit:
95
+ try:
96
+ val = init_val
97
+ while cond_fun(val):
98
+ val = body_fun(val)
99
+ return val
100
+ except jax.core.ConcretizationTypeError:
101
+ # Can't run this while_loop in Python (e.g. because there's a vmap
102
+ # transformation on it), so we fall back to the primitive version.
103
+ pass
104
+
105
+ # evaluate jaxpr
106
+ stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
107
+ stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
108
+ all_states = tuple(set(stateful_cond.get_states() + stateful_body.get_states()))
109
+ new_cond_fun = _wrapped_fun(stateful_cond, all_states, return_states=False)
110
+ new_body_fun = _wrapped_fun(stateful_body, all_states, return_states=True)
111
+
112
+ # while_loop
113
+ state_vals, final_val = jax.lax.while_loop(new_cond_fun,
114
+ new_body_fun,
115
+ (tuple(st.value for st in all_states), init_val))
116
+ _assign_state_values(all_states, state_vals)
117
+ return final_val
118
+
119
+
120
+ def bounded_while_loop(
121
+ cond_fun: Callable[[T], BooleanNumeric],
122
+ body_fun: Callable[[T], T],
123
+ init_val: T,
124
+ *,
125
+ max_steps: int,
126
+ base: int = 16,
127
+ ):
128
+ """
129
+ While loop with a bound on the maximum number of steps.
130
+
131
+ This function is useful when you want to ensure that a while loop terminates
132
+ even if the condition function is never false. The function is implemented
133
+ using a scan operation, so it is reverse-mode differentiable.
134
+
135
+ Args:
136
+ cond_fun: A function of type ``a -> Bool``.
137
+ body_fun: A function of type ``a -> a``.
138
+ init_val: The initial value of type ``a``.
139
+ max_steps: A bound on the maximum number of steps, after which the loop
140
+ terminates unconditionally.
141
+ base: Run time will increase slightly as `base` increases. Compilation time will
142
+ decrease substantially as `math.ceil(math.log(max_steps, base))` decreases.
143
+ (Which happens as `base` increases.)
144
+
145
+ Returns:
146
+ The final value, as if computed by a `lax.while_loop`.
147
+ """
148
+
149
+ # checking
150
+ if not isinstance(max_steps, int) or max_steps < 0:
151
+ raise ValueError("max_steps must be a non-negative integer")
152
+ init_val = jax.tree.map(jax.numpy.asarray, init_val)
153
+ if max_steps == 0:
154
+ return init_val
155
+
156
+ # evaluate jaxpr
157
+ stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
158
+ stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
159
+ all_states = tuple(set(stateful_cond.get_states() + stateful_body.get_states()))
160
+ new_cond_fun = _wrapped_fun(stateful_cond, all_states, return_states=False)
161
+ new_body_fun = _wrapped_fun(stateful_body, all_states, return_states=True)
162
+
163
+ # initial value
164
+ init_val = (tuple(st.value for st in all_states), init_val)
165
+
166
+ # while_loop
167
+ rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base)))
168
+ state_vals, val = _bounded_while_loop(new_cond_fun, new_body_fun, init_val, rounded_max_steps, base, None)
169
+ _assign_state_values(all_states, state_vals)
170
+ return val
@@ -0,0 +1,109 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import jax
19
+
20
+ from ._loop_collect_return import scan
21
+
22
+ __all__ = [
23
+ 'map',
24
+ ]
25
+
26
+
27
+ def _batch_and_remainder(x, batch_size: int):
28
+ leaves, treedef = jax.tree.flatten(x)
29
+
30
+ scan_leaves = []
31
+ remainder_leaves = []
32
+
33
+ for leaf in leaves:
34
+ num_batches, _ = divmod(leaf.shape[0], batch_size)
35
+ total_batch_elems = num_batches * batch_size
36
+ scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
37
+ remainder_leaves.append(leaf[total_batch_elems:])
38
+
39
+ scan_tree = treedef.unflatten(scan_leaves)
40
+ remainder_tree = treedef.unflatten(remainder_leaves)
41
+ return scan_tree, remainder_tree
42
+
43
+
44
+ def map(
45
+ f,
46
+ xs,
47
+ *,
48
+ batch_size: int | None = None,
49
+ ):
50
+ """Map a function over leading array axes.
51
+
52
+ Like Python's builtin map, except inputs and outputs are in the form of
53
+ stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
54
+ need to apply a function element by element for reduced memory usage or
55
+ heterogeneous computation with other control flow primitives.
56
+
57
+ When ``xs`` is an array type, the semantics of :func:`~map` are given by this
58
+ Python implementation::
59
+
60
+ def map(f, xs):
61
+ return np.stack([f(x) for x in xs])
62
+
63
+ Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
64
+ many of the same advantages over a Python loop apply: ``xs`` may be an
65
+ arbitrary nested pytree type, and the mapped computation is compiled only
66
+ once.
67
+
68
+ If ``batch_size`` is provided, the computation is executed in batches of that size
69
+ and parallelized using :func:`~jax.vmap`. This can be used as either a more performant
70
+ version of ``map`` or as a memory-efficient version of ``vmap``. If the axis is not
71
+ divisible by the batch size, the remainder is processed in a separate ``vmap`` and
72
+ concatenated to the result.
73
+
74
+ >>> x = jax.numpy.ones((10, 3, 4))
75
+ >>> def f(x):
76
+ ... print('inner shape:', x.shape)
77
+ ... return x + 1
78
+ >>> y = map(f, x, batch_size=3)
79
+ inner shape: (3, 4)
80
+ inner shape: (3, 4)
81
+ >>> y.shape
82
+ (10, 3, 4)
83
+
84
+ In the example above, "inner shape" is printed twice, once while tracing the batched
85
+ computation and once while tracing the remainder computation.
86
+
87
+ Args:
88
+ f: a Python function to apply element-wise over the first axis or axes of
89
+ ``xs``.
90
+ xs: values over which to map along the leading axis.
91
+ batch_size: (optional) integer specifying the size of the batch for each step to execute
92
+ in parallel.
93
+
94
+ Returns:
95
+ Mapped values.
96
+ """
97
+ if batch_size is not None:
98
+ scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
99
+ g = lambda _, x: ((), jax.vmap(f)(x))
100
+ _, scan_ys = scan(g, (), scan_xs)
101
+ remainder_ys = jax.vmap(f)(remainder_xs)
102
+ flatten = lambda x: x.reshape(-1, *x.shape[2:])
103
+ ys = jax.tree.map(
104
+ lambda x, y: jax.numpy.concatenate([flatten(x), y], axis=0), scan_ys, remainder_ys,
105
+ )
106
+ else:
107
+ g = lambda _, x: ((), f(x))
108
+ _, ys = scan(g, (), xs)
109
+ return ys
@@ -0,0 +1,143 @@
1
+ import jax
2
+ import jax.core
3
+ import jax.interpreters.batching as batching
4
+ import jax.interpreters.mlir as mlir
5
+ import jax.numpy as jnp
6
+
7
+ from brainstate._utils import set_module_as
8
+
9
+ __all__ = [
10
+ "unvmap",
11
+ ]
12
+
13
+
14
+ @set_module_as('brainstate.transform')
15
+ def unvmap(x, op: str = 'any'):
16
+ if op == 'all':
17
+ return unvmap_all(x)
18
+ elif op == 'any':
19
+ return unvmap_any(x)
20
+ elif op == 'none':
21
+ return _without_vmap(x)
22
+ elif op == 'max':
23
+ return unvmap_max(x)
24
+ else:
25
+ raise ValueError(f'Do not support type: {op}')
26
+
27
+
28
+ # unvmap_all
29
+
30
+ unvmap_all_p = jax.core.Primitive("unvmap_all")
31
+
32
+
33
+ def unvmap_all(x):
34
+ """As `jnp.all`, but ignores batch dimensions."""
35
+ return unvmap_all_p.bind(x)
36
+
37
+
38
+ def _unvmap_all_impl(x):
39
+ return jnp.all(x)
40
+
41
+
42
+ def _unvmap_all_abstract_eval(x):
43
+ return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype) # pyright: ignore
44
+
45
+
46
+ def _unvmap_all_batch(x, batch_axes):
47
+ (x,) = x
48
+ return unvmap_all(x), batching.not_mapped
49
+
50
+
51
+ unvmap_all_p.def_impl(_unvmap_all_impl)
52
+ unvmap_all_p.def_abstract_eval(_unvmap_all_abstract_eval)
53
+ batching.primitive_batchers[unvmap_all_p] = _unvmap_all_batch # pyright: ignore
54
+ mlir.register_lowering(
55
+ unvmap_all_p,
56
+ mlir.lower_fun(_unvmap_all_impl, multiple_results=False),
57
+ )
58
+
59
+ # unvmap_any
60
+
61
+ unvmap_any_p = jax.core.Primitive("unvmap_any")
62
+
63
+
64
+ def unvmap_any(x):
65
+ """As `jnp.any`, but ignores batch dimensions."""
66
+ return unvmap_any_p.bind(x)
67
+
68
+
69
+ def _unvmap_any_impl(x):
70
+ return jnp.any(x)
71
+
72
+
73
+ def _unvmap_any_abstract_eval(x):
74
+ return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype) # pyright: ignore
75
+
76
+
77
+ def _unvmap_any_batch(x, batch_axes):
78
+ (x,) = x
79
+ return unvmap_any(x), batching.not_mapped
80
+
81
+
82
+ unvmap_any_p.def_impl(_unvmap_any_impl)
83
+ unvmap_any_p.def_abstract_eval(_unvmap_any_abstract_eval)
84
+ batching.primitive_batchers[unvmap_any_p] = _unvmap_any_batch # pyright: ignore
85
+ mlir.register_lowering(
86
+ unvmap_any_p,
87
+ mlir.lower_fun(_unvmap_any_impl, multiple_results=False),
88
+ )
89
+
90
+ # unvmap_max
91
+
92
+ unvmap_max_p = jax.core.Primitive("unvmap_max")
93
+
94
+
95
+ def unvmap_max(x):
96
+ """As `jnp.max`, but ignores batch dimensions."""
97
+ return unvmap_max_p.bind(x)
98
+
99
+
100
+ def _unvmap_max_impl(x):
101
+ return jnp.max(x)
102
+
103
+
104
+ def _unvmap_max_abstract_eval(x):
105
+ return jax.core.ShapedArray(shape=(), dtype=x.dtype)
106
+
107
+
108
+ def _unvmap_max_batch(x, batch_axes):
109
+ (x,) = x
110
+ return unvmap_max(x), batching.not_mapped
111
+
112
+
113
+ unvmap_max_p.def_impl(_unvmap_max_impl)
114
+ unvmap_max_p.def_abstract_eval(_unvmap_max_abstract_eval)
115
+ batching.primitive_batchers[unvmap_max_p] = _unvmap_max_batch # pyright: ignore
116
+ mlir.register_lowering(
117
+ unvmap_max_p,
118
+ mlir.lower_fun(_unvmap_max_impl, multiple_results=False),
119
+ )
120
+
121
+
122
+ def _without_vmap(x):
123
+ return _no_vmap_prim.bind(x)
124
+
125
+
126
+ def _without_vmap_imp(x):
127
+ return x
128
+
129
+
130
+ def _without_vmap_abs(x):
131
+ return x
132
+
133
+
134
+ def _without_vmap_batch(x, batch_axes):
135
+ (x,) = x
136
+ return _without_vmap(x), batching.not_mapped
137
+
138
+
139
+ _no_vmap_prim = jax.core.Primitive('no_vmap')
140
+ _no_vmap_prim.def_impl(_without_vmap_imp)
141
+ _no_vmap_prim.def_abstract_eval(_without_vmap_abs)
142
+ batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
143
+ mlir.register_lowering(_no_vmap_prim, mlir.lower_fun(_without_vmap_imp, multiple_results=False))
brainstate/typing.py CHANGED
@@ -15,8 +15,9 @@
15
15
 
16
16
 
17
17
  import functools as ft
18
+ import inspect
18
19
  import typing
19
- from typing import Sequence, Protocol, Union, Any, Generic, TypeVar
20
+ from typing import Sequence, Protocol, Union, Any, Generic, TypeVar, Tuple
20
21
 
21
22
  import brainunit as bu
22
23
  import jax
@@ -34,6 +35,59 @@ __all__ = [
34
35
 
35
36
  _T = TypeVar("_T")
36
37
 
38
+ _Annotation = TypeVar("_Annotation")
39
+
40
+
41
+ class _Array(Generic[_Annotation]):
42
+ pass
43
+
44
+
45
+ _Array.__module__ = "builtins"
46
+
47
+
48
+ def _item_to_str(item: Union[str, type, slice]) -> str:
49
+ if isinstance(item, slice):
50
+ if item.step is not None:
51
+ raise NotImplementedError
52
+ return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
53
+ elif item is ...:
54
+ return "..."
55
+ elif inspect.isclass(item):
56
+ return item.__name__
57
+ else:
58
+ return repr(item)
59
+
60
+
61
+ def _maybe_tuple_to_str(
62
+ item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
63
+ ) -> str:
64
+ if isinstance(item, tuple):
65
+ if len(item) == 0:
66
+ # Explicit brackets
67
+ return "()"
68
+ else:
69
+ # No brackets
70
+ return ", ".join([_item_to_str(i) for i in item])
71
+ else:
72
+ return _item_to_str(item)
73
+
74
+
75
+ class Array:
76
+ def __class_getitem__(cls, item):
77
+ class X:
78
+ pass
79
+
80
+ X.__module__ = "builtins"
81
+ X.__qualname__ = _maybe_tuple_to_str(item)
82
+ return _Array[X]
83
+
84
+
85
+ # Same __module__ trick here again. (So that we get the correct display when
86
+ # doing `def f(x: Array)` as well as `def f(x: Array["dim"])`.
87
+ #
88
+ # Don't need to set __qualname__ as that's already correct.
89
+ Array.__module__ = "builtins"
90
+
37
91
 
38
92
  class _FakePyTree(Generic[_T]):
39
93
  pass
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: brainstate
3
- Version: 0.0.2.post20240825
3
+ Version: 0.0.2.post20240826
4
4
  Summary: A State-based Transformation System for Brain Dynamics Programming.
5
5
  Home-page: https://github.com/brainpy/brainstate
6
6
  Author: BDP
@@ -31,7 +31,7 @@ License-File: LICENSE
31
31
  Requires-Dist: jax
32
32
  Requires-Dist: jaxlib
33
33
  Requires-Dist: numpy
34
- Requires-Dist: brainunit >=0.0.2
34
+ Requires-Dist: brainunit (>=0.0.2)
35
35
  Provides-Extra: cpu
36
36
  Requires-Dist: jaxlib ; extra == 'cpu'
37
37
  Provides-Extra: cuda12
@@ -1,5 +1,5 @@
1
1
  brainstate/__init__.py,sha256=zipNSih9Tyvi4-5cXqNPGsDF7VeestkLp-lcjJ4-dA0,1408
2
- brainstate/_module.py,sha256=YJDp9aD38wBa_lY6BojWjWV9LJ2aFMAMYh-KZe5a4eM,52443
2
+ brainstate/_module.py,sha256=L0cwF_6p9cSZlWGi33Mb5s4vUnbfSeQi2TPUbvsGCzo,52461
3
3
  brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
4
4
  brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd0,2007
5
5
  brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
@@ -8,10 +8,10 @@ brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
8
8
  brainstate/environ.py,sha256=k0p1oyi9jbsPfuvqrPL-_zgSd7VW3LRs0LboxlaaIfc,11806
9
9
  brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
10
10
  brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
11
- brainstate/random.py,sha256=BqEBYVD9TGe8dSzp8U0suK0O4r6Ox59GCq0mwfUndVQ,188073
11
+ brainstate/random.py,sha256=pTZvTH06hv08_TpwzAWCqAjy-8oNGmB6-Jp6MKfkLaY,188087
12
12
  brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
13
13
  brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
14
- brainstate/typing.py,sha256=6BlkLSN5TiaNO49q8b0OYyzcuSxmdoG3noIJTbyhE3s,7895
14
+ brainstate/typing.py,sha256=szCYee9R15YQfsEAQOx95_LqfrD9AYuE5dfTBTPd8sg,9165
15
15
  brainstate/util.py,sha256=y-6eX1z3EMyg6pfZt4YdDalOnJ3HDAT1IPBCJDp-gQI,19876
16
16
  brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
17
17
  brainstate/functional/_activations.py,sha256=gvZ9E1-TsEUlyO7Om0eYzlM9DF-14_A32-gta1mjGo4,17798
@@ -47,20 +47,24 @@ brainstate/optim/__init__.py,sha256=1L6x_qZprw3PJYddB1nX-uTFGUl6_Qt3PM0OdY6g968,
47
47
  brainstate/optim/_lr_scheduler.py,sha256=emKnA52UVqOfUcX7LJqwP-FVDVlGGzTQi2djYmbCWUo,15627
48
48
  brainstate/optim/_lr_scheduler_test.py,sha256=OwF8Iz-PorEbO0gO--A7IIgQEytqEfYWbPucAgzqL90,1598
49
49
  brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aan7b70,45826
50
- brainstate/transform/__init__.py,sha256=my2X4ZW0uKZRfN82zyGEPizWNJ0fsSP2akvmkjn43ck,1458
50
+ brainstate/transform/__init__.py,sha256=hqef3a4sLQ_Oihuqs8E5IghSLr9o2bS7CWmwRL8jX6E,1887
51
51
  brainstate/transform/_autograd.py,sha256=Pj_YxpU52guaxQs1NcB6qDtXgkvaPcoJbuvIF8T-Wmk,23964
52
52
  brainstate/transform/_autograd_test.py,sha256=RWriMemIF9FVFUjQh4IHzLhT9LGyd1JXpjXfFZKHn10,38654
53
- brainstate/transform/_control.py,sha256=0NFUGLIenqKuBhBiTmY0YgCrl2GI1ZbuWMW0DSOolpE,26874
54
- brainstate/transform/_controls_test.py,sha256=mPUa_qmXXVxDziAJrPWRBwsGnc3cHR9co08eJB_fJwA,7648
53
+ brainstate/transform/_conditions.py,sha256=bdqdHCPCJIpRJxNr0enO2u81924YoIuA8kS8GUGY98g,12970
54
+ brainstate/transform/_conditions_test.py,sha256=hg3gyOk4jn88F_ZYYqqwf6m87N3GlOUE9dC2V3BnMTA,7691
55
+ brainstate/transform/_error_if.py,sha256=0JThfFqt9B3K3H6mS84qecBS22yTi3-FPzviaYacaMY,2597
56
+ brainstate/transform/_error_if_test.py,sha256=kQZujlgr9bYnL-Vf7x4Zfc7jJk7rCLNVu-bsiry40dQ,1874
55
57
  brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
56
- brainstate/transform/_jit_error.py,sha256=8rGRx8dtvmPWmHVOsfz30EUMXSix-m2PKM3Ni_9-_7I,4829
57
- brainstate/transform/_jit_error_test.py,sha256=GAVGL0eNJ5Fu0lHABCGc-nLfa_0x0tw_VPfURB-nhLc,1862
58
58
  brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJRLb4,2868
59
+ brainstate/transform/_loop_collect_return.py,sha256=aUZSK5MX4stVP8Te4R8glm2SdP18rUUfHjcV4TXOPC8,20768
60
+ brainstate/transform/_loop_no_collection.py,sha256=p2vHoNNesDH2cM7b5LgLzSv90M8iNQPkRZEl0jhf7yA,6476
59
61
  brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
60
62
  brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
63
+ brainstate/transform/_mapping.py,sha256=G9XUsD1xKLCprwwE0wv0gSXS0NYZ-ZIsv-PKKRlOoTA,3821
61
64
  brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
62
- brainstate-0.0.2.post20240825.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
63
- brainstate-0.0.2.post20240825.dist-info/METADATA,sha256=COfpxoCL7w1xGa1OFYFeANFLhAKmSioWVtmF_i2st34,3849
64
- brainstate-0.0.2.post20240825.dist-info/WHEEL,sha256=GUeE9LxUgRABPG7YM0jCNs9cBsAIx0YAkzCB88PMLgc,109
65
- brainstate-0.0.2.post20240825.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
66
- brainstate-0.0.2.post20240825.dist-info/RECORD,,
65
+ brainstate/transform/_unvmap.py,sha256=8Se_23QrwDdcJpFcUnnMgD6EP-4XylbhP9K5TDhW358,3311
66
+ brainstate-0.0.2.post20240826.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
67
+ brainstate-0.0.2.post20240826.dist-info/METADATA,sha256=unkbYHiPHAtNHGENoSt47mA0N3cXWsCafRC8Fo2NPyk,3851
68
+ brainstate-0.0.2.post20240826.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
69
+ brainstate-0.0.2.post20240826.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
70
+ brainstate-0.0.2.post20240826.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (73.0.1)
2
+ Generator: bdist_wheel (0.38.4)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py2-none-any
5
5
  Tag: py3-none-any