brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -13,5 +13,4 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
|
17
|
-
assert isinstance(x, cls), 'The input should be an instance of {}!'.format(cls)
|
16
|
+
# This module is going to be deleted in the future (near 2025-06).
|
brainstate/typing.py
CHANGED
@@ -13,73 +13,96 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
from __future__ import annotations
|
16
17
|
|
18
|
+
import builtins
|
17
19
|
import functools as ft
|
20
|
+
import importlib
|
18
21
|
import inspect
|
19
|
-
import typing
|
20
|
-
from typing import Sequence, Protocol, Union, Any, Generic, TypeVar, Tuple
|
21
22
|
|
22
|
-
import brainunit as
|
23
|
+
import brainunit as u
|
23
24
|
import jax
|
24
25
|
import numpy as np
|
25
26
|
|
27
|
+
tp = importlib.import_module("typing")
|
28
|
+
|
26
29
|
__all__ = [
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
30
|
+
'PathParts',
|
31
|
+
'Predicate',
|
32
|
+
'Filter',
|
33
|
+
'PyTree',
|
34
|
+
'Size',
|
35
|
+
'Axes',
|
36
|
+
'SeedOrKey',
|
37
|
+
'ArrayLike',
|
38
|
+
'DType',
|
39
|
+
'DTypeLike',
|
40
|
+
'Missing',
|
34
41
|
]
|
35
42
|
|
36
|
-
|
43
|
+
K = tp.TypeVar('K')
|
44
|
+
|
45
|
+
|
46
|
+
@tp.runtime_checkable
|
47
|
+
class Key(tp.Hashable, tp.Protocol):
|
48
|
+
def __lt__(self: K, value: K, /) -> bool:
|
49
|
+
...
|
50
|
+
|
37
51
|
|
38
|
-
|
52
|
+
Ellipsis = builtins.ellipsis if tp.TYPE_CHECKING else tp.Any
|
39
53
|
|
54
|
+
PathParts = tp.Tuple[Key, ...]
|
55
|
+
Predicate = tp.Callable[[PathParts, tp.Any], bool]
|
56
|
+
FilterLiteral = tp.Union[type, str, Predicate, bool, Ellipsis, None]
|
57
|
+
Filter = tp.Union[FilterLiteral, tp.Tuple['Filter', ...], tp.List['Filter']]
|
40
58
|
|
41
|
-
|
42
|
-
|
59
|
+
_T = tp.TypeVar("_T")
|
60
|
+
|
61
|
+
_Annotation = tp.TypeVar("_Annotation")
|
62
|
+
|
63
|
+
|
64
|
+
class _Array(tp.Generic[_Annotation]):
|
65
|
+
pass
|
43
66
|
|
44
67
|
|
45
68
|
_Array.__module__ = "builtins"
|
46
69
|
|
47
70
|
|
48
|
-
def _item_to_str(item: Union[str, type, slice]) -> str:
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
71
|
+
def _item_to_str(item: tp.Union[str, type, slice]) -> str:
|
72
|
+
if isinstance(item, slice):
|
73
|
+
if item.step is not None:
|
74
|
+
raise NotImplementedError
|
75
|
+
return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
|
76
|
+
elif item is ...:
|
77
|
+
return "..."
|
78
|
+
elif inspect.isclass(item):
|
79
|
+
return item.__name__
|
80
|
+
else:
|
81
|
+
return repr(item)
|
59
82
|
|
60
83
|
|
61
84
|
def _maybe_tuple_to_str(
|
62
|
-
item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
|
85
|
+
item: tp.Union[str, type, slice, tp.Tuple[tp.Union[str, type, slice], ...]]
|
63
86
|
) -> str:
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
87
|
+
if isinstance(item, tuple):
|
88
|
+
if len(item) == 0:
|
89
|
+
# Explicit brackets
|
90
|
+
return "()"
|
91
|
+
else:
|
92
|
+
# No brackets
|
93
|
+
return ", ".join([_item_to_str(i) for i in item])
|
68
94
|
else:
|
69
|
-
|
70
|
-
return ", ".join([_item_to_str(i) for i in item])
|
71
|
-
else:
|
72
|
-
return _item_to_str(item)
|
95
|
+
return _item_to_str(item)
|
73
96
|
|
74
97
|
|
75
98
|
class Array:
|
76
|
-
|
77
|
-
|
78
|
-
|
99
|
+
def __class_getitem__(cls, item):
|
100
|
+
class X:
|
101
|
+
pass
|
79
102
|
|
80
|
-
|
81
|
-
|
82
|
-
|
103
|
+
X.__module__ = "builtins"
|
104
|
+
X.__qualname__ = _maybe_tuple_to_str(item)
|
105
|
+
return _Array[X]
|
83
106
|
|
84
107
|
|
85
108
|
# Same __module__ trick here again. (So that we get the correct display when
|
@@ -89,8 +112,8 @@ class Array:
|
|
89
112
|
Array.__module__ = "builtins"
|
90
113
|
|
91
114
|
|
92
|
-
class _FakePyTree(Generic[_T]):
|
93
|
-
|
115
|
+
class _FakePyTree(tp.Generic[_T]):
|
116
|
+
pass
|
94
117
|
|
95
118
|
|
96
119
|
_FakePyTree.__name__ = "PyTree"
|
@@ -99,84 +122,84 @@ _FakePyTree.__module__ = "builtins"
|
|
99
122
|
|
100
123
|
|
101
124
|
class _MetaPyTree(type):
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
125
|
+
def __call__(self, *args, **kwargs):
|
126
|
+
raise RuntimeError("PyTree cannot be instantiated")
|
127
|
+
|
128
|
+
# Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
|
129
|
+
# the custom __instancecheck__ that we want.
|
130
|
+
# We can't add that __instancecheck__ via subclassing, e.g.
|
131
|
+
# type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
|
132
|
+
# isn't allowed.
|
133
|
+
# Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
|
134
|
+
# has __module__ "types", e.g. we get types.PyTree[int].
|
135
|
+
@ft.lru_cache(maxsize=None)
|
136
|
+
def __getitem__(cls, item):
|
137
|
+
if isinstance(item, tuple):
|
138
|
+
if len(item) == 2:
|
139
|
+
|
140
|
+
class X(PyTree):
|
141
|
+
leaftype = item[0]
|
142
|
+
structure = item[1].strip()
|
143
|
+
|
144
|
+
if not isinstance(X.structure, str):
|
145
|
+
raise ValueError(
|
146
|
+
"The structure annotation `struct` in "
|
147
|
+
"`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
|
148
|
+
f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
|
149
|
+
)
|
150
|
+
pieces = X.structure.split()
|
151
|
+
if len(pieces) == 0:
|
152
|
+
raise ValueError(
|
153
|
+
"The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
|
154
|
+
"cannot be the empty string."
|
155
|
+
)
|
156
|
+
for piece_index, piece in enumerate(pieces):
|
157
|
+
if (piece_index == 0) or (piece_index == len(pieces) - 1):
|
158
|
+
if piece == "...":
|
159
|
+
continue
|
160
|
+
if not piece.isidentifier():
|
161
|
+
raise ValueError(
|
162
|
+
"The string `struct` in "
|
163
|
+
"`brainstate.typing.PyTree[leaftype, struct]` must be be a "
|
164
|
+
"whitespace-separated sequence of identifiers, e.g. "
|
165
|
+
"`brainstate.typing.PyTree[leaftype, 'T']` or "
|
166
|
+
"`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
|
167
|
+
"(Here, 'identifier' is used in the same sense as in "
|
168
|
+
"regular Python, i.e. a valid variable name.)\n"
|
169
|
+
f"Got piece '{piece}' in overall structure '{X.structure}'."
|
170
|
+
)
|
171
|
+
name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
|
172
|
+
else:
|
173
|
+
raise ValueError(
|
174
|
+
"The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
|
175
|
+
"leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
|
176
|
+
"structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
|
177
|
+
f"{len(item)}."
|
178
|
+
)
|
179
|
+
else:
|
180
|
+
name = str(_FakePyTree[item])
|
181
|
+
|
182
|
+
class X(PyTree):
|
183
|
+
leaftype = item
|
184
|
+
structure = None
|
185
|
+
|
186
|
+
X.__name__ = name
|
187
|
+
X.__qualname__ = name
|
188
|
+
if getattr(tp, "GENERATING_DOCUMENTATION", False):
|
189
|
+
X.__module__ = "builtins"
|
190
|
+
else:
|
191
|
+
X.__module__ = "brainstate.typing"
|
192
|
+
return X
|
170
193
|
|
171
194
|
|
172
195
|
# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
|
173
196
|
# instancecheck for PyTree[foo], but subclassing
|
174
197
|
# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
|
175
198
|
PyTree = _MetaPyTree("PyTree", (), {})
|
176
|
-
if getattr(
|
177
|
-
|
199
|
+
if getattr(tp, "GENERATING_DOCUMENTATION", False):
|
200
|
+
PyTree.__module__ = "builtins"
|
178
201
|
else:
|
179
|
-
|
202
|
+
PyTree.__module__ = "brainstate.typing"
|
180
203
|
PyTree.__doc__ = """Represents a PyTree.
|
181
204
|
|
182
205
|
Annotations of the following sorts are supported:
|
@@ -231,9 +254,9 @@ f. A structure can end with a `...`, to denote that the PyTree must be a prefix
|
|
231
254
|
cases, all named pieces must already have been seen and their structures bound.
|
232
255
|
""" # noqa: E501
|
233
256
|
|
234
|
-
Size = Union[int, Sequence[int]]
|
235
|
-
Axes = Union[int, Sequence[int]]
|
236
|
-
SeedOrKey = Union[int, jax.Array, np.ndarray]
|
257
|
+
Size = tp.Union[int, tp.Sequence[int]]
|
258
|
+
Axes = tp.Union[int, tp.Sequence[int]]
|
259
|
+
SeedOrKey = tp.Union[int, jax.Array, np.ndarray]
|
237
260
|
|
238
261
|
# --- Array --- #
|
239
262
|
|
@@ -241,12 +264,12 @@ SeedOrKey = Union[int, jax.Array, np.ndarray]
|
|
241
264
|
# standard JAX array (i.e. not including future non-standard array types like
|
242
265
|
# KeyArray and BInt). It's different than np.typing.ArrayLike in that it doesn't
|
243
266
|
# accept arbitrary sequences, nor does it accept string data.
|
244
|
-
ArrayLike = Union[
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
267
|
+
ArrayLike = tp.Union[
|
268
|
+
jax.Array, # JAX array type
|
269
|
+
np.ndarray, # NumPy array type
|
270
|
+
np.bool_, np.number, # NumPy scalar types
|
271
|
+
bool, int, float, complex, # Python scalar types
|
272
|
+
u.Quantity, # Quantity
|
250
273
|
]
|
251
274
|
|
252
275
|
# --- Dtype --- #
|
@@ -255,9 +278,9 @@ ArrayLike = Union[
|
|
255
278
|
DType = np.dtype
|
256
279
|
|
257
280
|
|
258
|
-
class SupportsDType(Protocol):
|
259
|
-
|
260
|
-
|
281
|
+
class SupportsDType(tp.Protocol):
|
282
|
+
@property
|
283
|
+
def dtype(self) -> DType: ...
|
261
284
|
|
262
285
|
|
263
286
|
# DTypeLike is meant to annotate inputs to np.dtype that return
|
@@ -265,9 +288,13 @@ class SupportsDType(Protocol):
|
|
265
288
|
# because JAX doesn't support objects or structured dtypes.
|
266
289
|
# Unlike np.typing.DTypeLike, we exclude None, and instead require
|
267
290
|
# explicit annotations when None is acceptable.
|
268
|
-
DTypeLike = Union[
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
291
|
+
DTypeLike = tp.Union[
|
292
|
+
str, # like 'float32', 'int32'
|
293
|
+
type[tp.Any], # like np.float32, np.int32, float, int
|
294
|
+
np.dtype, # like np.dtype('float32'), np.dtype('int32')
|
295
|
+
SupportsDType, # like jnp.float32, jnp.int32
|
273
296
|
]
|
297
|
+
|
298
|
+
|
299
|
+
class Missing:
|
300
|
+
pass
|
@@ -0,0 +1,52 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from ._dict import *
|
17
|
+
from ._dict import __all__ as _mapping_all
|
18
|
+
from ._error import *
|
19
|
+
from ._error import __all__ as _error_all
|
20
|
+
from ._filter import *
|
21
|
+
from ._filter import __all__ as _filter_all
|
22
|
+
from ._others import *
|
23
|
+
from ._others import __all__ as _others_all
|
24
|
+
from ._pretty_repr import *
|
25
|
+
from ._pretty_repr import __all__ as _pretty_repr_all
|
26
|
+
from ._scaling import *
|
27
|
+
from ._scaling import __all__ as _mem_scale_all
|
28
|
+
from ._struct import *
|
29
|
+
from ._struct import __all__ as _struct_all
|
30
|
+
from ._visualization import *
|
31
|
+
from ._visualization import __all__ as _visualization_all
|
32
|
+
|
33
|
+
__all__ = (
|
34
|
+
_others_all
|
35
|
+
+ _mem_scale_all
|
36
|
+
+ _filter_all
|
37
|
+
+ _pretty_repr_all
|
38
|
+
+ _struct_all
|
39
|
+
+ _error_all
|
40
|
+
+ _mapping_all
|
41
|
+
+ _visualization_all
|
42
|
+
)
|
43
|
+
del (
|
44
|
+
_others_all,
|
45
|
+
_mem_scale_all,
|
46
|
+
_filter_all,
|
47
|
+
_pretty_repr_all,
|
48
|
+
_struct_all,
|
49
|
+
_error_all,
|
50
|
+
_mapping_all,
|
51
|
+
_visualization_all,
|
52
|
+
)
|
@@ -0,0 +1,100 @@
|
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
import dataclasses
|
21
|
+
from typing import Any, TypeVar, Protocol, Generic
|
22
|
+
|
23
|
+
import jax
|
24
|
+
|
25
|
+
__all__ = [
|
26
|
+
'DelayedAccessor',
|
27
|
+
'CallableProxy',
|
28
|
+
'ApplyCaller',
|
29
|
+
]
|
30
|
+
|
31
|
+
A = TypeVar('A', covariant=True) # type: ignore[not-supported-yet]
|
32
|
+
|
33
|
+
|
34
|
+
def _identity(x):
|
35
|
+
return x
|
36
|
+
|
37
|
+
|
38
|
+
@dataclasses.dataclass(frozen=True)
|
39
|
+
class GetItem:
|
40
|
+
key: Any
|
41
|
+
|
42
|
+
|
43
|
+
@dataclasses.dataclass(frozen=True)
|
44
|
+
class GetAttr:
|
45
|
+
name: str
|
46
|
+
|
47
|
+
|
48
|
+
@dataclasses.dataclass(frozen=True)
|
49
|
+
class DelayedAccessor:
|
50
|
+
actions: tuple[GetItem | GetAttr, ...] = ()
|
51
|
+
|
52
|
+
def __call__(self, x):
|
53
|
+
for action in self.actions:
|
54
|
+
if isinstance(action, GetItem):
|
55
|
+
x = x[action.key]
|
56
|
+
elif isinstance(action, GetAttr):
|
57
|
+
x = getattr(x, action.name)
|
58
|
+
return x
|
59
|
+
|
60
|
+
def __getattr__(self, name):
|
61
|
+
return DelayedAccessor(self.actions + (GetAttr(name),))
|
62
|
+
|
63
|
+
def __getitem__(self, key):
|
64
|
+
return DelayedAccessor(self.actions + (GetItem(key),))
|
65
|
+
|
66
|
+
|
67
|
+
jax.tree_util.register_static(DelayedAccessor)
|
68
|
+
|
69
|
+
|
70
|
+
class _AccessorCall(Protocol):
|
71
|
+
def __call__(self, accessor: DelayedAccessor, /, *args, **kwargs) -> Any:
|
72
|
+
...
|
73
|
+
|
74
|
+
|
75
|
+
class CallableProxy:
|
76
|
+
def __init__(
|
77
|
+
self, fun: _AccessorCall, accessor: DelayedAccessor | None = None
|
78
|
+
):
|
79
|
+
self._callable = fun
|
80
|
+
self._accessor = DelayedAccessor() if accessor is None else accessor
|
81
|
+
|
82
|
+
def __call__(self, *args, **kwargs):
|
83
|
+
return self._callable(self._accessor, *args, **kwargs)
|
84
|
+
|
85
|
+
def __getattr__(self, name) -> CallableProxy:
|
86
|
+
return CallableProxy(self._callable, getattr(self._accessor, name))
|
87
|
+
|
88
|
+
def __getitem__(self, key) -> CallableProxy:
|
89
|
+
return CallableProxy(self._callable, self._accessor[key])
|
90
|
+
|
91
|
+
|
92
|
+
class ApplyCaller(Protocol, Generic[A]):
|
93
|
+
def __getattr__(self, __name) -> ApplyCaller[A]:
|
94
|
+
...
|
95
|
+
|
96
|
+
def __getitem__(self, __name) -> ApplyCaller[A]:
|
97
|
+
...
|
98
|
+
|
99
|
+
def __call__(self, *args, **kwargs) -> tuple[Any, A]:
|
100
|
+
...
|