brainstate 0.0.2.post20240825__py2.py3-none-any.whl → 0.0.2.post20240826__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_module.py +6 -6
- brainstate/random.py +5 -5
- brainstate/transform/__init__.py +16 -6
- brainstate/transform/_conditions.py +334 -0
- brainstate/transform/{_controls_test.py → _conditions_test.py} +35 -35
- brainstate/transform/_error_if.py +94 -0
- brainstate/transform/{_jit_error_test.py → _error_if_test.py} +4 -4
- brainstate/transform/_loop_collect_return.py +502 -0
- brainstate/transform/_loop_no_collection.py +170 -0
- brainstate/transform/_mapping.py +109 -0
- brainstate/transform/_unvmap.py +143 -0
- brainstate/typing.py +55 -1
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/METADATA +2 -2
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/RECORD +17 -13
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/WHEEL +1 -1
- brainstate/transform/_control.py +0 -665
- brainstate/transform/_jit_error.py +0 -180
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/top_level.txt +0 -0
brainstate/transform/_loop_no_collection.py
ADDED
@@ -0,0 +1,170 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import math
+from typing import Any, Callable, TypeVar
+
+import jax
+
+from brainstate._utils import set_module_as
+from ._conditions import _wrapped_fun
+from ._loop_collect_return import _bounded_while_loop
+from ._make_jaxpr import StatefulFunction, _assign_state_values
+
+X = TypeVar('X')
+Y = TypeVar('Y')
+T = TypeVar('T')
+Carry = TypeVar('Carry')
+BooleanNumeric = Any  # A bool, or a Boolean array.
+
+__all__ = [
+  # while loop
+  'while_loop', 'bounded_while_loop',
+]
+
+
+@set_module_as('brainstate.transform')
+def while_loop(
+    cond_fun: Callable[[T], BooleanNumeric],
+    body_fun: Callable[[T], T],
+    init_val: T
+) -> T:
+  """
+  Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.
+
+  The `Haskell-like type signature`_ in brief is
+
+  .. code-block:: haskell
+
+    while_loop :: (a -> Bool) -> (a -> a) -> a -> a
+
+  The semantics of ``while_loop`` are given by this Python implementation::
+
+    def while_loop(cond_fun, body_fun, init_val):
+      val = init_val
+      while cond_fun(val):
+        val = body_fun(val)
+      return val
+
+  Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
+  to a single WhileOp. That makes it useful for reducing compilation times
+  for jit-compiled functions, since native Python loop constructs in an ``@jit``
+  function are unrolled, leading to large XLA computations.
+
+  Also unlike the Python analogue, the loop-carried value ``val`` must hold a
+  fixed shape and dtype across all iterations (and not just be consistent up to
+  NumPy rank/shape broadcasting and dtype promotion rules, for example). In
+  other words, the type ``a`` in the type signature above represents an array
+  with a fixed shape and dtype (or a nested tuple/list/dict container data
+  structure with a fixed structure and arrays with fixed shape and dtype at the
+  leaves).
+
+  Another difference from using Python-native loop constructs is that
+  ``while_loop`` is not reverse-mode differentiable because XLA computations
+  require static bounds on memory requirements.
+
+  Args:
+    cond_fun: function of type ``a -> Bool``.
+    body_fun: function of type ``a -> a``.
+    init_val: value of type ``a``, a type that can be a scalar, array, or any
+      pytree (nested Python tuple/list/dict) thereof, representing the initial
+      loop carry value.
+
+  Returns:
+    The output from the final iteration of body_fun, of type ``a``.
+
+  .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
+  """
+  if not (callable(body_fun) and callable(cond_fun)):
+    raise TypeError("while_loop: body_fun and cond_fun arguments should be callable.")
+  if jax.config.jax_disable_jit:
+    try:
+      val = init_val
+      while cond_fun(val):
+        val = body_fun(val)
+      return val
+    except jax.core.ConcretizationTypeError:
+      # Can't run this while_loop in Python (e.g. because there's a vmap
+      # transformation on it), so we fall back to the primitive version.
+      pass
+
+  # evaluate jaxpr
+  stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
+  stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
+  all_states = tuple(set(stateful_cond.get_states() + stateful_body.get_states()))
+  new_cond_fun = _wrapped_fun(stateful_cond, all_states, return_states=False)
+  new_body_fun = _wrapped_fun(stateful_body, all_states, return_states=True)
+
+  # while_loop
+  state_vals, final_val = jax.lax.while_loop(new_cond_fun,
+                                             new_body_fun,
+                                             (tuple(st.value for st in all_states), init_val))
+  _assign_state_values(all_states, state_vals)
+  return final_val
+
+
+def bounded_while_loop(
+    cond_fun: Callable[[T], BooleanNumeric],
+    body_fun: Callable[[T], T],
+    init_val: T,
+    *,
+    max_steps: int,
+    base: int = 16,
+):
+  """
+  While loop with a bound on the maximum number of steps.
+
+  This function is useful when you want to ensure that a while loop terminates
+  even if the condition function is never false. The function is implemented
+  using a scan operation, so it is reverse-mode differentiable.
+
+  Args:
+    cond_fun: A function of type ``a -> Bool``.
+    body_fun: A function of type ``a -> a``.
+    init_val: The initial value of type ``a``.
+    max_steps: A bound on the maximum number of steps, after which the loop
+      terminates unconditionally.
+    base: Run time will increase slightly as `base` increases. Compilation time will
+      decrease substantially as `math.ceil(math.log(max_steps, base))` decreases.
+      (Which happens as `base` increases.)

+  Returns:
+    The final value, as if computed by a `lax.while_loop`.
+  """
+
+  # checking
+  if not isinstance(max_steps, int) or max_steps < 0:
+    raise ValueError("max_steps must be a non-negative integer")
+  init_val = jax.tree.map(jax.numpy.asarray, init_val)
+  if max_steps == 0:
+    return init_val
+
+  # evaluate jaxpr
+  stateful_cond = StatefulFunction(cond_fun).make_jaxpr(init_val)
+  stateful_body = StatefulFunction(body_fun).make_jaxpr(init_val)
+  all_states = tuple(set(stateful_cond.get_states() + stateful_body.get_states()))
+  new_cond_fun = _wrapped_fun(stateful_cond, all_states, return_states=False)
+  new_body_fun = _wrapped_fun(stateful_body, all_states, return_states=True)
+
+  # initial value
+  init_val = (tuple(st.value for st in all_states), init_val)
+
+  # while_loop
+  rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base)))
+  state_vals, val = _bounded_while_loop(new_cond_fun, new_body_fun, init_val, rounded_max_steps, base, None)
+  _assign_state_values(all_states, state_vals)
+  return val
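The new `while_loop` mirrors `jax.lax.while_loop` but additionally threads brainstate `State` values through the carry. A minimal usage sketch (hypothetical example, not taken from the package; it assumes the function is exported as `brainstate.transform.while_loop`, consistent with the `set_module_as` decorator above):

import jax.numpy as jnp
import brainstate

# Step the carry by 2 until it reaches 10. cond_fun and body_fun each take
# and return the loop carry, exactly as in the docstring semantics.
final = brainstate.transform.while_loop(
  lambda val: val < 10,  # cond_fun: a -> Bool
  lambda val: val + 2,   # body_fun: a -> a
  jnp.asarray(0),        # init_val: fixed shape and dtype across iterations
)
# final == Array(10, dtype=int32)

For `bounded_while_loop`, note the rounding at the end: with the default `base=16`, `max_steps=100` runs as `16 ** math.ceil(math.log(100, 16)) == 256` underlying scan steps, while the result remains "as if computed by a `lax.while_loop`", per the docstring.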
brainstate/transform/_mapping.py
ADDED
@@ -0,0 +1,109 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+import jax
+
+from ._loop_collect_return import scan
+
+__all__ = [
+  'map',
+]
+
+
+def _batch_and_remainder(x, batch_size: int):
+  leaves, treedef = jax.tree.flatten(x)
+
+  scan_leaves = []
+  remainder_leaves = []
+
+  for leaf in leaves:
+    num_batches, _ = divmod(leaf.shape[0], batch_size)
+    total_batch_elems = num_batches * batch_size
+    scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
+    remainder_leaves.append(leaf[total_batch_elems:])
+
+  scan_tree = treedef.unflatten(scan_leaves)
+  remainder_tree = treedef.unflatten(remainder_leaves)
+  return scan_tree, remainder_tree
+
+
+def map(
+    f,
+    xs,
+    *,
+    batch_size: int | None = None,
+):
+  """Map a function over leading array axes.
+
+  Like Python's builtin map, except inputs and outputs are in the form of
+  stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
+  need to apply a function element by element for reduced memory usage or
+  heterogeneous computation with other control flow primitives.
+
+  When ``xs`` is an array type, the semantics of :func:`~map` are given by this
+  Python implementation::
+
+    def map(f, xs):
+      return np.stack([f(x) for x in xs])
+
+  Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
+  many of the same advantages over a Python loop apply: ``xs`` may be an
+  arbitrary nested pytree type, and the mapped computation is compiled only
+  once.
+
+  If ``batch_size`` is provided, the computation is executed in batches of that size
+  and parallelized using :func:`~jax.vmap`. This can be used as either a more performant
+  version of ``map`` or as a memory-efficient version of ``vmap``. If the axis is not
+  divisible by the batch size, the remainder is processed in a separate ``vmap`` and
+  concatenated to the result.
+
+  >>> x = jax.numpy.ones((10, 3, 4))
+  >>> def f(x):
+  ...   print('inner shape:', x.shape)
+  ...   return x + 1
+  >>> y = map(f, x, batch_size=3)
+  inner shape: (3, 4)
+  inner shape: (3, 4)
+  >>> y.shape
+  (10, 3, 4)
+
+  In the example above, "inner shape" is printed twice, once while tracing the batched
+  computation and once while tracing the remainder computation.
+
+  Args:
+    f: a Python function to apply element-wise over the first axis or axes of
+      ``xs``.
+    xs: values over which to map along the leading axis.
+    batch_size: (optional) integer specifying the size of the batch for each step to execute
+      in parallel.
+
+  Returns:
+    Mapped values.
+  """
+  if batch_size is not None:
+    scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
+    g = lambda _, x: ((), jax.vmap(f)(x))
+    _, scan_ys = scan(g, (), scan_xs)
+    remainder_ys = jax.vmap(f)(remainder_xs)
+    flatten = lambda x: x.reshape(-1, *x.shape[2:])
+    ys = jax.tree.map(
+      lambda x, y: jax.numpy.concatenate([flatten(x), y], axis=0), scan_ys, remainder_ys,
+    )
+  else:
+    g = lambda _, x: ((), f(x))
+    _, ys = scan(g, (), xs)
+  return ys
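`_batch_and_remainder` is pure shape arithmetic, so its effect can be shown directly. A small sketch (hypothetical values, following the code above) of the split it performs on a leading axis of 10 with `batch_size=3`:

import jax.numpy as jnp

leaf = jnp.arange(10)
batch_size = 3
num_batches = leaf.shape[0] // batch_size                            # 3 full batches
total_batch_elems = num_batches * batch_size                         # 9 elements get scanned
scanned = leaf[:total_batch_elems].reshape(num_batches, batch_size)  # shape (3, 3)
remainder = leaf[total_batch_elems:]                                 # shape (1,)

`map` then scans `jax.vmap(f)` over the `(num_batches, batch_size, ...)` blocks, applies one extra `vmap(f)` to the remainder, and concatenates the flattened results, which is why the doctest above traces `f` exactly twice.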
brainstate/transform/_unvmap.py
ADDED
@@ -0,0 +1,143 @@
+import jax
+import jax.core
+import jax.interpreters.batching as batching
+import jax.interpreters.mlir as mlir
+import jax.numpy as jnp
+
+from brainstate._utils import set_module_as
+
+__all__ = [
+  "unvmap",
+]
+
+
+@set_module_as('brainstate.transform')
+def unvmap(x, op: str = 'any'):
+  if op == 'all':
+    return unvmap_all(x)
+  elif op == 'any':
+    return unvmap_any(x)
+  elif op == 'none':
+    return _without_vmap(x)
+  elif op == 'max':
+    return unvmap_max(x)
+  else:
+    raise ValueError(f'Do not support type: {op}')
+
+
+# unvmap_all
+
+unvmap_all_p = jax.core.Primitive("unvmap_all")
+
+
+def unvmap_all(x):
+  """As `jnp.all`, but ignores batch dimensions."""
+  return unvmap_all_p.bind(x)
+
+
+def _unvmap_all_impl(x):
+  return jnp.all(x)
+
+
+def _unvmap_all_abstract_eval(x):
+  return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype)  # pyright: ignore
+
+
+def _unvmap_all_batch(x, batch_axes):
+  (x,) = x
+  return unvmap_all(x), batching.not_mapped
+
+
+unvmap_all_p.def_impl(_unvmap_all_impl)
+unvmap_all_p.def_abstract_eval(_unvmap_all_abstract_eval)
+batching.primitive_batchers[unvmap_all_p] = _unvmap_all_batch  # pyright: ignore
+mlir.register_lowering(
+  unvmap_all_p,
+  mlir.lower_fun(_unvmap_all_impl, multiple_results=False),
+)
+
+# unvmap_any
+
+unvmap_any_p = jax.core.Primitive("unvmap_any")
+
+
+def unvmap_any(x):
+  """As `jnp.any`, but ignores batch dimensions."""
+  return unvmap_any_p.bind(x)
+
+
+def _unvmap_any_impl(x):
+  return jnp.any(x)
+
+
+def _unvmap_any_abstract_eval(x):
+  return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype)  # pyright: ignore
+
+
+def _unvmap_any_batch(x, batch_axes):
+  (x,) = x
+  return unvmap_any(x), batching.not_mapped
+
+
+unvmap_any_p.def_impl(_unvmap_any_impl)
+unvmap_any_p.def_abstract_eval(_unvmap_any_abstract_eval)
+batching.primitive_batchers[unvmap_any_p] = _unvmap_any_batch  # pyright: ignore
+mlir.register_lowering(
+  unvmap_any_p,
+  mlir.lower_fun(_unvmap_any_impl, multiple_results=False),
+)
+
+# unvmap_max
+
+unvmap_max_p = jax.core.Primitive("unvmap_max")
+
+
+def unvmap_max(x):
+  """As `jnp.max`, but ignores batch dimensions."""
+  return unvmap_max_p.bind(x)
+
+
+def _unvmap_max_impl(x):
+  return jnp.max(x)
+
+
+def _unvmap_max_abstract_eval(x):
+  return jax.core.ShapedArray(shape=(), dtype=x.dtype)
+
+
+def _unvmap_max_batch(x, batch_axes):
+  (x,) = x
+  return unvmap_max(x), batching.not_mapped
+
+
+unvmap_max_p.def_impl(_unvmap_max_impl)
+unvmap_max_p.def_abstract_eval(_unvmap_max_abstract_eval)
+batching.primitive_batchers[unvmap_max_p] = _unvmap_max_batch  # pyright: ignore
+mlir.register_lowering(
+  unvmap_max_p,
+  mlir.lower_fun(_unvmap_max_impl, multiple_results=False),
+)
+
+
+def _without_vmap(x):
+  return _no_vmap_prim.bind(x)
+
+
+def _without_vmap_imp(x):
+  return x
+
+
+def _without_vmap_abs(x):
+  return x
+
+
+def _without_vmap_batch(x, batch_axes):
+  (x,) = x
+  return _without_vmap(x), batching.not_mapped
+
+
+_no_vmap_prim = jax.core.Primitive('no_vmap')
+_no_vmap_prim.def_impl(_without_vmap_imp)
+_no_vmap_prim.def_abstract_eval(_without_vmap_abs)
+batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
+mlir.register_lowering(_no_vmap_prim, mlir.lower_fun(_without_vmap_imp, multiple_results=False))
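The point of these primitives is that their batching rules return `batching.not_mapped`, so the reduction escapes `vmap` instead of being applied per-example. A hedged sketch (hypothetical example; it assumes `unvmap` is importable from `brainstate.transform`, matching the `set_module_as` decorator above):

import jax
import jax.numpy as jnp
from brainstate.transform import unvmap

flags = jnp.array([False, True, False])

# jnp.any under vmap reduces each (scalar) element separately: shape (3,).
per_example = jax.vmap(jnp.any)(flags)

# unvmap(..., op='any') also collapses the batch dimension; out_axes=None
# keeps the result as a single unbatched scalar (True here), which is what
# jit-compatible error checks inside vmapped code need.
whole_batch = jax.vmap(lambda x: unvmap(x, op='any'), out_axes=None)(flags)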
brainstate/typing.py
CHANGED
@@ -15,8 +15,9 @@
 
 
 import functools as ft
+import inspect
 import typing
-from typing import Sequence, Protocol, Union, Any, Generic, TypeVar
+from typing import Sequence, Protocol, Union, Any, Generic, TypeVar, Tuple
 
 import brainunit as bu
 import jax
@@ -34,6 +35,59 @@ __all__ = [
 
 _T = TypeVar("_T")
 
+_Annotation = TypeVar("_Annotation")
+
+
+class _Array(Generic[_Annotation]):
+  pass
+
+
+_Array.__module__ = "builtins"
+
+
+def _item_to_str(item: Union[str, type, slice]) -> str:
+  if isinstance(item, slice):
+    if item.step is not None:
+      raise NotImplementedError
+    return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
+  elif item is ...:
+    return "..."
+  elif inspect.isclass(item):
+    return item.__name__
+  else:
+    return repr(item)
+
+
+def _maybe_tuple_to_str(
+    item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
+) -> str:
+  if isinstance(item, tuple):
+    if len(item) == 0:
+      # Explicit brackets
+      return "()"
+    else:
+      # No brackets
+      return ", ".join([_item_to_str(i) for i in item])
+  else:
+    return _item_to_str(item)
+
+
+class Array:
+  def __class_getitem__(cls, item):
+    class X:
+      pass
+
+    X.__module__ = "builtins"
+    X.__qualname__ = _maybe_tuple_to_str(item)
+    return _Array[X]
+
+
+# Same __module__ trick here again. (So that we get the correct display when
+# doing `def f(x: Array)` as well as `def f(x: Array["dim"])`.
+#
+# Don't need to set __qualname__ as that's already correct.
+Array.__module__ = "builtins"
+
 
 class _FakePyTree(Generic[_T]):
   pass
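The helpers added to `typing.py` are pure string formatting, so their behavior follows directly from the code above (illustrative calls, not taken from the package):

# How the pretty-printing helpers render annotation items:
_item_to_str(int)                      # 'int'      (class name)
_item_to_str(...)                      # '...'
_maybe_tuple_to_str(())                # '()'       (explicit brackets)
_maybe_tuple_to_str(("batch", "dim"))  # "'batch', 'dim'"

Together with the `__module__ = "builtins"` trick, this makes an annotation like `Array["batch", "dim"]` display as a compact pseudo-type in signatures and error messages rather than as a fully qualified generic.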
{brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: brainstate
-Version: 0.0.2.post20240825
+Version: 0.0.2.post20240826
 Summary: A State-based Transformation System for Brain Dynamics Programming.
 Home-page: https://github.com/brainpy/brainstate
 Author: BDP
@@ -31,7 +31,7 @@ License-File: LICENSE
 Requires-Dist: jax
 Requires-Dist: jaxlib
 Requires-Dist: numpy
-Requires-Dist: brainunit >=0.0.2
+Requires-Dist: brainunit (>=0.0.2)
 Provides-Extra: cpu
 Requires-Dist: jaxlib ; extra == 'cpu'
 Provides-Extra: cuda12
{brainstate-0.0.2.post20240825.dist-info → brainstate-0.0.2.post20240826.dist-info}/RECORD
RENAMED
@@ -1,5 +1,5 @@
 brainstate/__init__.py,sha256=zipNSih9Tyvi4-5cXqNPGsDF7VeestkLp-lcjJ4-dA0,1408
-brainstate/_module.py,sha256=
+brainstate/_module.py,sha256=L0cwF_6p9cSZlWGi33Mb5s4vUnbfSeQi2TPUbvsGCzo,52461
 brainstate/_module_test.py,sha256=oQaoaZBTo1o3wHrMEJTInQCc7RdcVs1gcfQGvdSb1SI,7843
 brainstate/_random_for_unit.py,sha256=1rHr7gfH_bYrJfpxbDhQUk_j00Yosx-GzyZCXrLxsd0,2007
 brainstate/_state.py,sha256=C0widCOj_ca6zfqh95jzFXf_G5vi0hJyuQ5GIqEqOUs,12102
@@ -8,10 +8,10 @@ brainstate/_utils.py,sha256=RLorgGJkt2BhbX4C-ygd-PPG0wfcGCghjSP93sRvzqM,833
 brainstate/environ.py,sha256=k0p1oyi9jbsPfuvqrPL-_zgSd7VW3LRs0LboxlaaIfc,11806
 brainstate/mixin.py,sha256=OumTTSVyYSbtudjfS_MRThsBaeVJ_0JggeMClY7xtBA,10758
 brainstate/mixin_test.py,sha256=-Ej9oUOu8O1M4oy37SVMj7xNRYhHHyAHwrjS_aISayo,2923
-brainstate/random.py,sha256=
+brainstate/random.py,sha256=pTZvTH06hv08_TpwzAWCqAjy-8oNGmB6-Jp6MKfkLaY,188087
 brainstate/random_test.py,sha256=cCeuYvlZkCS2_RgG0vipZFNSHG8b-uJ7SXM9SZDCYQM,17866
 brainstate/surrogate.py,sha256=1kgbn82GSlpReIytIVl29yozk75gkdZv0gTBlixQ4C4,43798
-brainstate/typing.py,sha256=
+brainstate/typing.py,sha256=szCYee9R15YQfsEAQOx95_LqfrD9AYuE5dfTBTPd8sg,9165
 brainstate/util.py,sha256=y-6eX1z3EMyg6pfZt4YdDalOnJ3HDAT1IPBCJDp-gQI,19876
 brainstate/functional/__init__.py,sha256=j6-3Er4fgqWpvntzYCZVB3e5hoz-Z3aqvapITCuDri0,1107
 brainstate/functional/_activations.py,sha256=gvZ9E1-TsEUlyO7Om0eYzlM9DF-14_A32-gta1mjGo4,17798
@@ -47,20 +47,24 @@ brainstate/optim/__init__.py,sha256=1L6x_qZprw3PJYddB1nX-uTFGUl6_Qt3PM0OdY6g968,
 brainstate/optim/_lr_scheduler.py,sha256=emKnA52UVqOfUcX7LJqwP-FVDVlGGzTQi2djYmbCWUo,15627
 brainstate/optim/_lr_scheduler_test.py,sha256=OwF8Iz-PorEbO0gO--A7IIgQEytqEfYWbPucAgzqL90,1598
 brainstate/optim/_sgd_optimizer.py,sha256=JiK_AVGregL0wn8uHhRQvK9Qq7Qja7dEyLW6Aan7b70,45826
-brainstate/transform/__init__.py,sha256=
+brainstate/transform/__init__.py,sha256=hqef3a4sLQ_Oihuqs8E5IghSLr9o2bS7CWmwRL8jX6E,1887
 brainstate/transform/_autograd.py,sha256=Pj_YxpU52guaxQs1NcB6qDtXgkvaPcoJbuvIF8T-Wmk,23964
 brainstate/transform/_autograd_test.py,sha256=RWriMemIF9FVFUjQh4IHzLhT9LGyd1JXpjXfFZKHn10,38654
-brainstate/transform/
-brainstate/transform/
+brainstate/transform/_conditions.py,sha256=bdqdHCPCJIpRJxNr0enO2u81924YoIuA8kS8GUGY98g,12970
+brainstate/transform/_conditions_test.py,sha256=hg3gyOk4jn88F_ZYYqqwf6m87N3GlOUE9dC2V3BnMTA,7691
+brainstate/transform/_error_if.py,sha256=0JThfFqt9B3K3H6mS84qecBS22yTi3-FPzviaYacaMY,2597
+brainstate/transform/_error_if_test.py,sha256=kQZujlgr9bYnL-Vf7x4Zfc7jJk7rCLNVu-bsiry40dQ,1874
 brainstate/transform/_jit.py,sha256=sjQHFV8Tt75fpdl12jjPRDPT92_IZxBBJAG4gapdbNQ,11471
-brainstate/transform/_jit_error.py,sha256=8rGRx8dtvmPWmHVOsfz30EUMXSix-m2PKM3Ni_9-_7I,4829
-brainstate/transform/_jit_error_test.py,sha256=GAVGL0eNJ5Fu0lHABCGc-nLfa_0x0tw_VPfURB-nhLc,1862
 brainstate/transform/_jit_test.py,sha256=5ltT7izh_OS9dcHnRymmVhq01QomjwZGdA8XzwJRLb4,2868
+brainstate/transform/_loop_collect_return.py,sha256=aUZSK5MX4stVP8Te4R8glm2SdP18rUUfHjcV4TXOPC8,20768
+brainstate/transform/_loop_no_collection.py,sha256=p2vHoNNesDH2cM7b5LgLzSv90M8iNQPkRZEl0jhf7yA,6476
 brainstate/transform/_make_jaxpr.py,sha256=ZkrOZu4_0xcILuPUA3RFEkorJ-xbDuDtXorJI_qVThE,30450
 brainstate/transform/_make_jaxpr_test.py,sha256=K3vRUBroDTCCx0lnmhgHtgrlWvWglJO2f1K2phTvU70,3819
+brainstate/transform/_mapping.py,sha256=G9XUsD1xKLCprwwE0wv0gSXS0NYZ-ZIsv-PKKRlOoTA,3821
 brainstate/transform/_progress_bar.py,sha256=VGoRZPRBmB8ELNwLc6c7S8QhUUTvn0FY46IbBm9cuYM,3502
-brainstate
-brainstate-0.0.2.
-brainstate-0.0.2.
-brainstate-0.0.2.
-brainstate-0.0.2.
+brainstate/transform/_unvmap.py,sha256=8Se_23QrwDdcJpFcUnnMgD6EP-4XylbhP9K5TDhW358,3311
+brainstate-0.0.2.post20240826.dist-info/LICENSE,sha256=VZe9u1jgUL2eCY6ZPOYgdb8KCblCHt8ECdbtJid6e1s,11550
+brainstate-0.0.2.post20240826.dist-info/METADATA,sha256=unkbYHiPHAtNHGENoSt47mA0N3cXWsCafRC8Fo2NPyk,3851
+brainstate-0.0.2.post20240826.dist-info/WHEEL,sha256=bb2Ot9scclHKMOLDEHY6B2sicWOgugjFKaJsT7vwMQo,110
+brainstate-0.0.2.post20240826.dist-info/top_level.txt,sha256=eQbGgKn0ptx7FDWuua0V0wr4K1VHi2iOUCYo3fUQBRA,11
+brainstate-0.0.2.post20240826.dist-info/RECORD,,