brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/transform.py
CHANGED
@@ -1,23 +1,23 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# alias for compilation and augmentation functions
|
17
|
-
|
18
|
-
from .augment import *
|
19
|
-
from .compile import *
|
20
|
-
|
21
|
-
if __name__ == '__main__':
|
22
|
-
ifelse
|
23
|
-
grad
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# alias for compilation and augmentation functions
|
17
|
+
|
18
|
+
from .augment import *
|
19
|
+
from .compile import *
|
20
|
+
|
21
|
+
if __name__ == '__main__':
|
22
|
+
ifelse
|
23
|
+
grad
|
brainstate/typing.py
CHANGED
@@ -1,304 +1,304 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import builtins
|
17
|
-
import functools as ft
|
18
|
-
import importlib
|
19
|
-
import inspect
|
20
|
-
from typing import (
|
21
|
-
Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
|
22
|
-
runtime_checkable, TYPE_CHECKING, Generic, Sequence
|
23
|
-
)
|
24
|
-
|
25
|
-
import brainunit as u
|
26
|
-
import jax
|
27
|
-
import numpy as np
|
28
|
-
|
29
|
-
tp = importlib.import_module("typing")
|
30
|
-
|
31
|
-
__all__ = [
|
32
|
-
'PathParts',
|
33
|
-
'Predicate',
|
34
|
-
'Filter',
|
35
|
-
'PyTree',
|
36
|
-
'Size',
|
37
|
-
'Shape',
|
38
|
-
'Axes',
|
39
|
-
'SeedOrKey',
|
40
|
-
'ArrayLike',
|
41
|
-
'DType',
|
42
|
-
'DTypeLike',
|
43
|
-
'Missing',
|
44
|
-
]
|
45
|
-
|
46
|
-
K = TypeVar('K')
|
47
|
-
|
48
|
-
|
49
|
-
@runtime_checkable
|
50
|
-
class Key(Hashable, Protocol):
|
51
|
-
def __lt__(self: K, value: K, /) -> bool:
|
52
|
-
...
|
53
|
-
|
54
|
-
|
55
|
-
Ellipsis = builtins.ellipsis if TYPE_CHECKING else Any
|
56
|
-
|
57
|
-
PathParts = Tuple[Key, ...]
|
58
|
-
Predicate = Callable[[PathParts, Any], bool]
|
59
|
-
FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
|
60
|
-
Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
|
61
|
-
|
62
|
-
_T = TypeVar("_T")
|
63
|
-
|
64
|
-
_Annotation = TypeVar("_Annotation")
|
65
|
-
|
66
|
-
|
67
|
-
class _Array(Generic[_Annotation]):
|
68
|
-
pass
|
69
|
-
|
70
|
-
|
71
|
-
_Array.__module__ = "builtins"
|
72
|
-
|
73
|
-
|
74
|
-
def _item_to_str(item: Union[str, type, slice]) -> str:
|
75
|
-
if isinstance(item, slice):
|
76
|
-
if item.step is not None:
|
77
|
-
raise NotImplementedError
|
78
|
-
return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
|
79
|
-
elif item is ...:
|
80
|
-
return "..."
|
81
|
-
elif inspect.isclass(item):
|
82
|
-
return item.__name__
|
83
|
-
else:
|
84
|
-
return repr(item)
|
85
|
-
|
86
|
-
|
87
|
-
def _maybe_tuple_to_str(
|
88
|
-
item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
|
89
|
-
) -> str:
|
90
|
-
if isinstance(item, tuple):
|
91
|
-
if len(item) == 0:
|
92
|
-
# Explicit brackets
|
93
|
-
return "()"
|
94
|
-
else:
|
95
|
-
# No brackets
|
96
|
-
return ", ".join([_item_to_str(i) for i in item])
|
97
|
-
else:
|
98
|
-
return _item_to_str(item)
|
99
|
-
|
100
|
-
|
101
|
-
class Array:
|
102
|
-
def __class_getitem__(cls, item):
|
103
|
-
class X:
|
104
|
-
pass
|
105
|
-
|
106
|
-
X.__module__ = "builtins"
|
107
|
-
X.__qualname__ = _maybe_tuple_to_str(item)
|
108
|
-
return _Array[X]
|
109
|
-
|
110
|
-
|
111
|
-
# Same __module__ trick here again. (So that we get the correct display when
|
112
|
-
# doing `def f(x: Array)` as well as `def f(x: Array["dim"])`.
|
113
|
-
#
|
114
|
-
# Don't need to set __qualname__ as that's already correct.
|
115
|
-
Array.__module__ = "builtins"
|
116
|
-
|
117
|
-
|
118
|
-
class _FakePyTree(Generic[_T]):
|
119
|
-
pass
|
120
|
-
|
121
|
-
|
122
|
-
_FakePyTree.__name__ = "PyTree"
|
123
|
-
_FakePyTree.__qualname__ = "PyTree"
|
124
|
-
_FakePyTree.__module__ = "builtins"
|
125
|
-
|
126
|
-
|
127
|
-
class _MetaPyTree(type):
|
128
|
-
def __call__(self, *args, **kwargs):
|
129
|
-
raise RuntimeError("PyTree cannot be instantiated")
|
130
|
-
|
131
|
-
# Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
|
132
|
-
# the custom __instancecheck__ that we want.
|
133
|
-
# We can't add that __instancecheck__ via subclassing, e.g.
|
134
|
-
# type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
|
135
|
-
# isn't allowed.
|
136
|
-
# Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
|
137
|
-
# has __module__ "types", e.g. we get types.PyTree[int].
|
138
|
-
@ft.lru_cache(maxsize=None)
|
139
|
-
def __getitem__(cls, item):
|
140
|
-
if isinstance(item, tuple):
|
141
|
-
if len(item) == 2:
|
142
|
-
|
143
|
-
class X(PyTree):
|
144
|
-
leaftype = item[0]
|
145
|
-
structure = item[1].strip()
|
146
|
-
|
147
|
-
if not isinstance(X.structure, str):
|
148
|
-
raise ValueError(
|
149
|
-
"The structure annotation `struct` in "
|
150
|
-
"`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
|
151
|
-
f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
|
152
|
-
)
|
153
|
-
pieces = X.structure.split()
|
154
|
-
if len(pieces) == 0:
|
155
|
-
raise ValueError(
|
156
|
-
"The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
|
157
|
-
"cannot be the empty string."
|
158
|
-
)
|
159
|
-
for piece_index, piece in enumerate(pieces):
|
160
|
-
if (piece_index == 0) or (piece_index == len(pieces) - 1):
|
161
|
-
if piece == "...":
|
162
|
-
continue
|
163
|
-
if not piece.isidentifier():
|
164
|
-
raise ValueError(
|
165
|
-
"The string `struct` in "
|
166
|
-
"`brainstate.typing.PyTree[leaftype, struct]` must be be a "
|
167
|
-
"whitespace-separated sequence of identifiers, e.g. "
|
168
|
-
"`brainstate.typing.PyTree[leaftype, 'T']` or "
|
169
|
-
"`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
|
170
|
-
"(Here, 'identifier' is used in the same sense as in "
|
171
|
-
"regular Python, i.e. a valid variable name.)\n"
|
172
|
-
f"Got piece '{piece}' in overall structure '{X.structure}'."
|
173
|
-
)
|
174
|
-
name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
|
175
|
-
else:
|
176
|
-
raise ValueError(
|
177
|
-
"The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
|
178
|
-
"leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
|
179
|
-
"structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
|
180
|
-
f"{len(item)}."
|
181
|
-
)
|
182
|
-
else:
|
183
|
-
name = str(_FakePyTree[item])
|
184
|
-
|
185
|
-
class X(PyTree):
|
186
|
-
leaftype = item
|
187
|
-
structure = None
|
188
|
-
|
189
|
-
X.__name__ = name
|
190
|
-
X.__qualname__ = name
|
191
|
-
if getattr(tp, "GENERATING_DOCUMENTATION", False):
|
192
|
-
X.__module__ = "builtins"
|
193
|
-
else:
|
194
|
-
X.__module__ = "brainstate.typing"
|
195
|
-
return X
|
196
|
-
|
197
|
-
|
198
|
-
# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
|
199
|
-
# instancecheck for PyTree[foo], but subclassing
|
200
|
-
# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
|
201
|
-
PyTree = _MetaPyTree("PyTree", (), {})
|
202
|
-
if getattr(tp, "GENERATING_DOCUMENTATION", False):
|
203
|
-
PyTree.__module__ = "builtins"
|
204
|
-
else:
|
205
|
-
PyTree.__module__ = "brainstate.typing"
|
206
|
-
PyTree.__doc__ = """Represents a PyTree.
|
207
|
-
|
208
|
-
Annotations of the following sorts are supported:
|
209
|
-
```python
|
210
|
-
a: PyTree
|
211
|
-
b: PyTree[LeafType]
|
212
|
-
c: PyTree[LeafType, "T"]
|
213
|
-
d: PyTree[LeafType, "S T"]
|
214
|
-
e: PyTree[LeafType, "... T"]
|
215
|
-
f: PyTree[LeafType, "T ..."]
|
216
|
-
```
|
217
|
-
|
218
|
-
These correspond to:
|
219
|
-
|
220
|
-
a. A plain `PyTree` can be used an annotation, in which case `PyTree` is simply a
|
221
|
-
suggestively-named alternative to `Any`.
|
222
|
-
([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))
|
223
|
-
|
224
|
-
b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
|
225
|
-
example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.
|
226
|
-
|
227
|
-
c. A structure name can also be passed. In this case
|
228
|
-
`jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
|
229
|
-
This can be used to mark that multiple PyTrees all have the same structure:
|
230
|
-
```python
|
231
|
-
def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
|
232
|
-
...
|
233
|
-
```
|
234
|
-
|
235
|
-
d. A composite structure can be declared. In this case the variable must have a PyTree
|
236
|
-
structure each to the composition of multiple previously-bound PyTree structures.
|
237
|
-
For example:
|
238
|
-
```python
|
239
|
-
def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
|
240
|
-
...
|
241
|
-
|
242
|
-
x = (1, 2)
|
243
|
-
y = {"key": 3}
|
244
|
-
z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `z`
|
245
|
-
f(x, y, z)
|
246
|
-
```
|
247
|
-
When performing runtime type-checking, all the individual pieces must have already
|
248
|
-
been bound to structures, otherwise the composite structure check will throw an error.
|
249
|
-
|
250
|
-
e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
|
251
|
-
must match the declared structure, but the upper levels can be arbitrary. As in the
|
252
|
-
previous case, all named pieces must already have been seen and their structures
|
253
|
-
bound.
|
254
|
-
|
255
|
-
f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
|
256
|
-
declared structure, but the lower levels can be arbitrary. As in the previous two
|
257
|
-
cases, all named pieces must already have been seen and their structures bound.
|
258
|
-
""" # noqa: E501
|
259
|
-
|
260
|
-
Size = Union[int, Sequence[int], np.integer, Sequence[np.integer]]
|
261
|
-
Axes = Union[int, Sequence[int]]
|
262
|
-
SeedOrKey = Union[int, jax.Array, np.ndarray]
|
263
|
-
Shape = Sequence[int]
|
264
|
-
|
265
|
-
# --- Array --- #
|
266
|
-
|
267
|
-
# ArrayLike is a Union of all objects that can be implicitly converted to a
|
268
|
-
# standard JAX array (i.e. not including future non-standard array types like
|
269
|
-
# KeyArray and BInt). It's different than np.typing.ArrayLike in that it doesn't
|
270
|
-
# accept arbitrary sequences, nor does it accept string data.
|
271
|
-
ArrayLike = Union[
|
272
|
-
jax.Array, # JAX array type
|
273
|
-
np.ndarray, # NumPy array type
|
274
|
-
np.bool_, np.number, # NumPy scalar types
|
275
|
-
bool, int, float, complex, # Python scalar types
|
276
|
-
u.Quantity, # Quantity
|
277
|
-
]
|
278
|
-
|
279
|
-
# --- Dtype --- #
|
280
|
-
|
281
|
-
|
282
|
-
DType = np.dtype
|
283
|
-
|
284
|
-
|
285
|
-
class SupportsDType(Protocol):
|
286
|
-
@property
|
287
|
-
def dtype(self) -> DType: ...
|
288
|
-
|
289
|
-
|
290
|
-
# DTypeLike is meant to annotate inputs to np.dtype that return
|
291
|
-
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
|
292
|
-
# because JAX doesn't support objects or structured dtypes.
|
293
|
-
# Unlike np.typing.DTypeLike, we exclude None, and instead require
|
294
|
-
# explicit annotations when None is acceptable.
|
295
|
-
DTypeLike = Union[
|
296
|
-
str, # like 'float32', 'int32'
|
297
|
-
type[Any], # like np.float32, np.int32, float, int
|
298
|
-
np.dtype, # like np.dtype('float32'), np.dtype('int32')
|
299
|
-
SupportsDType, # like jnp.float32, jnp.int32
|
300
|
-
]
|
301
|
-
|
302
|
-
|
303
|
-
class Missing:
|
304
|
-
pass
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import builtins
|
17
|
+
import functools as ft
|
18
|
+
import importlib
|
19
|
+
import inspect
|
20
|
+
from typing import (
|
21
|
+
Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
|
22
|
+
runtime_checkable, TYPE_CHECKING, Generic, Sequence
|
23
|
+
)
|
24
|
+
|
25
|
+
import brainunit as u
|
26
|
+
import jax
|
27
|
+
import numpy as np
|
28
|
+
|
29
|
+
tp = importlib.import_module("typing")
|
30
|
+
|
31
|
+
__all__ = [
|
32
|
+
'PathParts',
|
33
|
+
'Predicate',
|
34
|
+
'Filter',
|
35
|
+
'PyTree',
|
36
|
+
'Size',
|
37
|
+
'Shape',
|
38
|
+
'Axes',
|
39
|
+
'SeedOrKey',
|
40
|
+
'ArrayLike',
|
41
|
+
'DType',
|
42
|
+
'DTypeLike',
|
43
|
+
'Missing',
|
44
|
+
]
|
45
|
+
|
46
|
+
K = TypeVar('K')
|
47
|
+
|
48
|
+
|
49
|
+
@runtime_checkable
|
50
|
+
class Key(Hashable, Protocol):
|
51
|
+
def __lt__(self: K, value: K, /) -> bool:
|
52
|
+
...
|
53
|
+
|
54
|
+
|
55
|
+
Ellipsis = builtins.ellipsis if TYPE_CHECKING else Any
|
56
|
+
|
57
|
+
PathParts = Tuple[Key, ...]
|
58
|
+
Predicate = Callable[[PathParts, Any], bool]
|
59
|
+
FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
|
60
|
+
Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
|
61
|
+
|
62
|
+
_T = TypeVar("_T")
|
63
|
+
|
64
|
+
_Annotation = TypeVar("_Annotation")
|
65
|
+
|
66
|
+
|
67
|
+
class _Array(Generic[_Annotation]):
|
68
|
+
pass
|
69
|
+
|
70
|
+
|
71
|
+
_Array.__module__ = "builtins"
|
72
|
+
|
73
|
+
|
74
|
+
def _item_to_str(item: Union[str, type, slice]) -> str:
|
75
|
+
if isinstance(item, slice):
|
76
|
+
if item.step is not None:
|
77
|
+
raise NotImplementedError
|
78
|
+
return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
|
79
|
+
elif item is ...:
|
80
|
+
return "..."
|
81
|
+
elif inspect.isclass(item):
|
82
|
+
return item.__name__
|
83
|
+
else:
|
84
|
+
return repr(item)
|
85
|
+
|
86
|
+
|
87
|
+
def _maybe_tuple_to_str(
|
88
|
+
item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
|
89
|
+
) -> str:
|
90
|
+
if isinstance(item, tuple):
|
91
|
+
if len(item) == 0:
|
92
|
+
# Explicit brackets
|
93
|
+
return "()"
|
94
|
+
else:
|
95
|
+
# No brackets
|
96
|
+
return ", ".join([_item_to_str(i) for i in item])
|
97
|
+
else:
|
98
|
+
return _item_to_str(item)
|
99
|
+
|
100
|
+
|
101
|
+
class Array:
|
102
|
+
def __class_getitem__(cls, item):
|
103
|
+
class X:
|
104
|
+
pass
|
105
|
+
|
106
|
+
X.__module__ = "builtins"
|
107
|
+
X.__qualname__ = _maybe_tuple_to_str(item)
|
108
|
+
return _Array[X]
|
109
|
+
|
110
|
+
|
111
|
+
# Same __module__ trick here again. (So that we get the correct display when
|
112
|
+
# doing `def f(x: Array)` as well as `def f(x: Array["dim"])`.
|
113
|
+
#
|
114
|
+
# Don't need to set __qualname__ as that's already correct.
|
115
|
+
Array.__module__ = "builtins"
|
116
|
+
|
117
|
+
|
118
|
+
class _FakePyTree(Generic[_T]):
|
119
|
+
pass
|
120
|
+
|
121
|
+
|
122
|
+
_FakePyTree.__name__ = "PyTree"
|
123
|
+
_FakePyTree.__qualname__ = "PyTree"
|
124
|
+
_FakePyTree.__module__ = "builtins"
|
125
|
+
|
126
|
+
|
127
|
+
class _MetaPyTree(type):
|
128
|
+
def __call__(self, *args, **kwargs):
|
129
|
+
raise RuntimeError("PyTree cannot be instantiated")
|
130
|
+
|
131
|
+
# Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
|
132
|
+
# the custom __instancecheck__ that we want.
|
133
|
+
# We can't add that __instancecheck__ via subclassing, e.g.
|
134
|
+
# type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
|
135
|
+
# isn't allowed.
|
136
|
+
# Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
|
137
|
+
# has __module__ "types", e.g. we get types.PyTree[int].
|
138
|
+
@ft.lru_cache(maxsize=None)
|
139
|
+
def __getitem__(cls, item):
|
140
|
+
if isinstance(item, tuple):
|
141
|
+
if len(item) == 2:
|
142
|
+
|
143
|
+
class X(PyTree):
|
144
|
+
leaftype = item[0]
|
145
|
+
structure = item[1].strip()
|
146
|
+
|
147
|
+
if not isinstance(X.structure, str):
|
148
|
+
raise ValueError(
|
149
|
+
"The structure annotation `struct` in "
|
150
|
+
"`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
|
151
|
+
f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
|
152
|
+
)
|
153
|
+
pieces = X.structure.split()
|
154
|
+
if len(pieces) == 0:
|
155
|
+
raise ValueError(
|
156
|
+
"The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
|
157
|
+
"cannot be the empty string."
|
158
|
+
)
|
159
|
+
for piece_index, piece in enumerate(pieces):
|
160
|
+
if (piece_index == 0) or (piece_index == len(pieces) - 1):
|
161
|
+
if piece == "...":
|
162
|
+
continue
|
163
|
+
if not piece.isidentifier():
|
164
|
+
raise ValueError(
|
165
|
+
"The string `struct` in "
|
166
|
+
"`brainstate.typing.PyTree[leaftype, struct]` must be be a "
|
167
|
+
"whitespace-separated sequence of identifiers, e.g. "
|
168
|
+
"`brainstate.typing.PyTree[leaftype, 'T']` or "
|
169
|
+
"`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
|
170
|
+
"(Here, 'identifier' is used in the same sense as in "
|
171
|
+
"regular Python, i.e. a valid variable name.)\n"
|
172
|
+
f"Got piece '{piece}' in overall structure '{X.structure}'."
|
173
|
+
)
|
174
|
+
name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
|
175
|
+
else:
|
176
|
+
raise ValueError(
|
177
|
+
"The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
|
178
|
+
"leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
|
179
|
+
"structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
|
180
|
+
f"{len(item)}."
|
181
|
+
)
|
182
|
+
else:
|
183
|
+
name = str(_FakePyTree[item])
|
184
|
+
|
185
|
+
class X(PyTree):
|
186
|
+
leaftype = item
|
187
|
+
structure = None
|
188
|
+
|
189
|
+
X.__name__ = name
|
190
|
+
X.__qualname__ = name
|
191
|
+
if getattr(tp, "GENERATING_DOCUMENTATION", False):
|
192
|
+
X.__module__ = "builtins"
|
193
|
+
else:
|
194
|
+
X.__module__ = "brainstate.typing"
|
195
|
+
return X
|
196
|
+
|
197
|
+
|
198
|
+
# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
|
199
|
+
# instancecheck for PyTree[foo], but subclassing
|
200
|
+
# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
|
201
|
+
PyTree = _MetaPyTree("PyTree", (), {})
|
202
|
+
if getattr(tp, "GENERATING_DOCUMENTATION", False):
|
203
|
+
PyTree.__module__ = "builtins"
|
204
|
+
else:
|
205
|
+
PyTree.__module__ = "brainstate.typing"
|
206
|
+
PyTree.__doc__ = """Represents a PyTree.
|
207
|
+
|
208
|
+
Annotations of the following sorts are supported:
|
209
|
+
```python
|
210
|
+
a: PyTree
|
211
|
+
b: PyTree[LeafType]
|
212
|
+
c: PyTree[LeafType, "T"]
|
213
|
+
d: PyTree[LeafType, "S T"]
|
214
|
+
e: PyTree[LeafType, "... T"]
|
215
|
+
f: PyTree[LeafType, "T ..."]
|
216
|
+
```
|
217
|
+
|
218
|
+
These correspond to:
|
219
|
+
|
220
|
+
a. A plain `PyTree` can be used an annotation, in which case `PyTree` is simply a
|
221
|
+
suggestively-named alternative to `Any`.
|
222
|
+
([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))
|
223
|
+
|
224
|
+
b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
|
225
|
+
example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.
|
226
|
+
|
227
|
+
c. A structure name can also be passed. In this case
|
228
|
+
`jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
|
229
|
+
This can be used to mark that multiple PyTrees all have the same structure:
|
230
|
+
```python
|
231
|
+
def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
|
232
|
+
...
|
233
|
+
```
|
234
|
+
|
235
|
+
d. A composite structure can be declared. In this case the variable must have a PyTree
|
236
|
+
structure each to the composition of multiple previously-bound PyTree structures.
|
237
|
+
For example:
|
238
|
+
```python
|
239
|
+
def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
|
240
|
+
...
|
241
|
+
|
242
|
+
x = (1, 2)
|
243
|
+
y = {"key": 3}
|
244
|
+
z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `z`
|
245
|
+
f(x, y, z)
|
246
|
+
```
|
247
|
+
When performing runtime type-checking, all the individual pieces must have already
|
248
|
+
been bound to structures, otherwise the composite structure check will throw an error.
|
249
|
+
|
250
|
+
e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
|
251
|
+
must match the declared structure, but the upper levels can be arbitrary. As in the
|
252
|
+
previous case, all named pieces must already have been seen and their structures
|
253
|
+
bound.
|
254
|
+
|
255
|
+
f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
|
256
|
+
declared structure, but the lower levels can be arbitrary. As in the previous two
|
257
|
+
cases, all named pieces must already have been seen and their structures bound.
|
258
|
+
""" # noqa: E501
|
259
|
+
|
260
|
+
Size = Union[int, Sequence[int], np.integer, Sequence[np.integer]]
|
261
|
+
Axes = Union[int, Sequence[int]]
|
262
|
+
SeedOrKey = Union[int, jax.Array, np.ndarray]
|
263
|
+
Shape = Sequence[int]
|
264
|
+
|
265
|
+
# --- Array --- #
|
266
|
+
|
267
|
+
# ArrayLike is a Union of all objects that can be implicitly converted to a
|
268
|
+
# standard JAX array (i.e. not including future non-standard array types like
|
269
|
+
# KeyArray and BInt). It's different than np.typing.ArrayLike in that it doesn't
|
270
|
+
# accept arbitrary sequences, nor does it accept string data.
|
271
|
+
ArrayLike = Union[
|
272
|
+
jax.Array, # JAX array type
|
273
|
+
np.ndarray, # NumPy array type
|
274
|
+
np.bool_, np.number, # NumPy scalar types
|
275
|
+
bool, int, float, complex, # Python scalar types
|
276
|
+
u.Quantity, # Quantity
|
277
|
+
]
|
278
|
+
|
279
|
+
# --- Dtype --- #
|
280
|
+
|
281
|
+
|
282
|
+
DType = np.dtype
|
283
|
+
|
284
|
+
|
285
|
+
class SupportsDType(Protocol):
|
286
|
+
@property
|
287
|
+
def dtype(self) -> DType: ...
|
288
|
+
|
289
|
+
|
290
|
+
# DTypeLike is meant to annotate inputs to np.dtype that return
|
291
|
+
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
|
292
|
+
# because JAX doesn't support objects or structured dtypes.
|
293
|
+
# Unlike np.typing.DTypeLike, we exclude None, and instead require
|
294
|
+
# explicit annotations when None is acceptable.
|
295
|
+
DTypeLike = Union[
|
296
|
+
str, # like 'float32', 'int32'
|
297
|
+
type[Any], # like np.float32, np.int32, float, int
|
298
|
+
np.dtype, # like np.dtype('float32'), np.dtype('int32')
|
299
|
+
SupportsDType, # like jnp.float32, jnp.int32
|
300
|
+
]
|
301
|
+
|
302
|
+
|
303
|
+
class Missing:
|
304
|
+
pass
|