brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/typing.py
CHANGED
@@ -1,304 +1,837 @@
|
|
1
|
-
# Copyright 2024
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
]
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
#
|
112
|
-
#
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
a
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
#
|
268
|
-
#
|
269
|
-
#
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Comprehensive type annotations for BrainState.
|
18
|
+
|
19
|
+
This module provides a collection of type aliases, protocols, and generic types
|
20
|
+
specifically designed for scientific computing, neural network modeling, and
|
21
|
+
array operations within the BrainState ecosystem.
|
22
|
+
|
23
|
+
The type system is designed to be compatible with JAX, NumPy, and BrainUnit,
|
24
|
+
providing comprehensive type hints for arrays, shapes, seeds, and PyTree structures.
|
25
|
+
|
26
|
+
Examples
|
27
|
+
--------
|
28
|
+
Basic usage with array types:
|
29
|
+
|
30
|
+
.. code-block:: python
|
31
|
+
|
32
|
+
>>> import brainstate
|
33
|
+
>>> from brainstate.typing import ArrayLike, Shape, DTypeLike
|
34
|
+
>>>
|
35
|
+
>>> def process_array(data: ArrayLike, shape: Shape, dtype: DTypeLike) -> brainstate.Array:
|
36
|
+
... return brainstate.asarray(data, dtype=dtype).reshape(shape)
|
37
|
+
|
38
|
+
Using PyTree annotations:
|
39
|
+
|
40
|
+
.. code-block:: python
|
41
|
+
|
42
|
+
>>> from brainstate.typing import PyTree
|
43
|
+
>>>
|
44
|
+
>>> def tree_function(tree: PyTree[float, "T"]) -> PyTree[float, "T"]:
|
45
|
+
... return brainstate.tree_map(lambda x: x * 2, tree)
|
46
|
+
"""
|
47
|
+
|
48
|
+
import builtins
|
49
|
+
import functools
|
50
|
+
import importlib
|
51
|
+
import inspect
|
52
|
+
from typing import (
|
53
|
+
Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
|
54
|
+
runtime_checkable, TYPE_CHECKING, Generic, Sequence
|
55
|
+
)
|
56
|
+
|
57
|
+
import brainunit as u
|
58
|
+
import jax
|
59
|
+
import numpy as np
|
60
|
+
|
61
|
+
# The stdlib ``typing`` module, loaded through ``importlib`` and bound to
# ``tp``. Later in this file ``getattr(tp, "GENERATING_DOCUMENTATION", False)``
# is read to choose the ``__module__`` reported by the PyTree types. That flag
# is not set anywhere in the visible code; presumably an external documentation
# build sets it (TODO confirm).
tp = importlib.import_module("typing")

# Public API of ``brainstate.typing``. Some of these names (e.g. ``ArrayLike``,
# ``Shape``, ``SeedOrKey``, ``Missing``) are not defined in this part of the
# file and are presumably defined later in the module.
__all__ = [
    # Path and filter types
    'PathParts',
    'Predicate',
    'Filter',
    'FilterLiteral',

    # Array and shape types
    'Array',
    'ArrayLike',
    'Shape',
    'Size',
    'Axes',
    'DType',
    'DTypeLike',
    'SupportsDType',

    # PyTree types
    'PyTree',

    # Random number generation
    'SeedOrKey',

    # Utility types
    'Key',
    'Missing',

    # Type variables
    # ``_T`` and ``_Annotation`` appear in ``__all__``, so a star-import
    # exports them even though their names start with an underscore.
    'K',
    '_T',
    '_Annotation',
]

# ============================================================================
# Type Variables
# ============================================================================

# Bound to the ``Key`` protocol declared below. The bound is a string because
# ``Key`` does not exist yet when this line runs (forward reference).
K = TypeVar('K', bound='Key')
"""Type variable for keys that must be comparable and hashable."""

# Unconstrained type variable; used as the parameter of ``_FakePyTree``.
_T = TypeVar("_T")
"""Generic type variable for any type."""

# Type parameter of the internal ``_Array`` generic used by ``Array[...]``.
_Annotation = TypeVar("_Annotation")
"""Type variable for array annotations."""
|
108
|
+
|
109
|
+
|
110
|
+
# ============================================================================
|
111
|
+
# Key and Path Types
|
112
|
+
# ============================================================================
|
113
|
+
|
114
|
+
@runtime_checkable
class Key(Hashable, Protocol):
    """Structural type describing one component of a PyTree path.

    An object satisfies ``Key`` when it can be hashed, so it works as a
    mapping key, and when it supports ``<`` against another key, so paths can
    be sorted. The protocol is ``runtime_checkable``, so ``isinstance``
    against it is a structural, presence-only check of the protocol members.

    Examples
    --------
    Strings, integers and user-defined hashable, orderable objects all qualify:

    .. code-block:: python

        >>> # Strings and integers already satisfy the protocol
        >>> name_key: Key = "layer1"
        >>> index_key: Key = 42
        >>>
        >>> # A custom key supplies __hash__, __eq__ and __lt__
        >>> class CustomKey:
        ...     def __init__(self, name: str):
        ...         self.name = name
        ...
        ...     def __hash__(self) -> int:
        ...         return hash(self.name)
        ...
        ...     def __eq__(self, other) -> bool:
        ...         return isinstance(other, CustomKey) and self.name == other.name
        ...
        ...     def __lt__(self, other) -> bool:
        ...         return isinstance(other, CustomKey) and self.name < other.name
    """

    def __lt__(self: K, value: K, /) -> bool:
        """Report whether this key sorts strictly before ``value``.

        Parameters
        ----------
        value : Key
            The key on the right-hand side of ``<``.

        Returns
        -------
        bool
            ``True`` when ``self < value``.
        """
        pass
|
162
|
+
|
163
|
+
|
164
|
+
# Alias for the type of the ``...`` literal, used in ``FilterLiteral``.
# Static type checkers see ``builtins.ellipsis``. That name exists only in
# typeshed stubs, not in the runtime ``builtins`` module, so it must stay
# behind ``TYPE_CHECKING``. At runtime the alias is simply ``typing.Any``.
# NOTE(review): this assignment rebinds the builtin name ``Ellipsis`` inside
# this module. Module code that writes ``Ellipsis`` gets this alias, not the
# ``...`` singleton. (``_item_to_str`` uses the literal ``...``, so it is
# unaffected.)
Ellipsis = builtins.ellipsis if TYPE_CHECKING else Any
"""Type alias for the ellipsis (``...``) type, used in filter expressions."""

# A path is an arbitrary-length tuple of ``Key`` objects; ``()`` is the root.
PathParts = Tuple[Key, ...]
"""Tuple of keys representing a path through a PyTree structure.

Examples
--------
.. code-block:: python

    >>> # Path to a nested value in a PyTree
    >>> path: PathParts = ("model", "layers", 0, "weights")
    >>>
    >>> # Empty path representing the root
    >>> root_path: PathParts = ()
"""

# Predicates are called as ``predicate(path, value)`` and return a bool.
Predicate = Callable[[PathParts, Any], bool]
"""Function that takes a path and value, returning whether it matches some condition.

Parameters
----------
path : PathParts
    The path to the value in the PyTree.
value : Any
    The value at that path.

Returns
-------
bool
    True if the path/value combination matches the predicate.

Examples
--------
.. code-block:: python

    >>> def is_weight_matrix(path: PathParts, value: Any) -> bool:
    ...     '''Check if a value is a weight matrix (2D array).'''
    ...     return len(path) > 0 and "weight" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 2
    >>>
    >>> def is_bias_vector(path: PathParts, value: Any) -> bool:
    ...     '''Check if a value is a bias vector (1D array).'''
    ...     return len(path) > 0 and "bias" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 1
"""

# At runtime ``Ellipsis`` above is ``typing.Any``, so this Union contains ``Any``.
FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
"""Basic filter types that can be used to select parts of a PyTree.

Components
----------
type
    Filter by type, e.g., `float`, `jax.Array`.
str
    Filter by string matching in path keys.
Predicate
    Custom function for complex filtering logic.
bool
    Simple True/False filter.
Ellipsis
    Wildcard filter that matches anything.
None
    Filter that matches None values.

NOTE(review): this alias only enumerates the accepted forms. The meanings
listed above describe intended usage. The authoritative interpretation lives
in the code that consumes filters (presumably ``brainstate.util.filter``) and
should be confirmed there.

Examples
--------
.. code-block:: python

    >>> # Filter by type
    >>> float_filter: FilterLiteral = float
    >>>
    >>> # Filter by string pattern
    >>> weight_filter: FilterLiteral = "weight"
    >>>
    >>> # Custom predicate filter
    >>> matrix_filter: FilterLiteral = lambda path, x: hasattr(x, 'ndim') and x.ndim == 2
"""

# Recursive alias: tuples and lists may nest further filters. The quoted
# 'Filter' strings are forward references to this alias itself.
Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
"""Flexible filter type that can be a single filter or combination of filters.

This allows for complex filtering patterns by combining multiple filter criteria.
How a tuple or list of sub-filters is combined (every one must match vs. any
one may match) is decided by the code that converts filters, not by this
alias. Confirm it there before relying on either reading.

Examples
--------
.. code-block:: python

    >>> # Single filter
    >>> simple_filter: Filter = "weight"
    >>>
    >>> # Tuple of filters
    >>> combined_filter: Filter = (float, "weight")
    >>>
    >>> # List of filters
    >>> alternative_filter: Filter = [int, float, "bias"]
    >>>
    >>> # Nested combinations
    >>> complex_filter: Filter = [
    ...     ("weight", lambda p, x: x.ndim == 2),  # 2D weight matrices
    ...     ("bias", lambda p, x: x.ndim == 1),    # 1D bias vectors
    ... ]
"""
|
265
|
+
|
266
|
+
|
267
|
+
# ============================================================================
|
268
|
+
# Array Annotation Types
|
269
|
+
# ============================================================================
|
270
|
+
|
271
|
+
class _Array(Generic[_Annotation]):
    """Generic carrier returned by ``Array[...]`` subscriptions.

    Its single type parameter is a placeholder class whose qualified name
    spells out the shape/dtype text the user wrote inside the brackets.
    """


# Claim the ``builtins`` module so ``typing`` renders the alias as a bare
# ``_Array[...]`` without a module prefix.
_Array.__module__ = "builtins"
|
277
|
+
|
278
|
+
|
279
|
+
def _item_to_str(item: Union[str, type, slice]) -> str:
    """Render one component of an ``Array[...]`` subscript as text.

    Slices become ``"start: stop"`` (each end rendered recursively), the
    ellipsis becomes ``"..."``, classes are shown by their ``__name__``, and
    anything else falls back to ``repr`` (so strings keep their quotes).

    Parameters
    ----------
    item : Union[str, type, slice]
        A single subscript element.

    Returns
    -------
    str
        Human-readable form of ``item``.

    Raises
    ------
    NotImplementedError
        If ``item`` is a slice with a step (``a:b:c``).
    """
    if item is ...:
        return "..."
    if isinstance(item, slice):
        if item.step is not None:
            raise NotImplementedError("Slice steps are not supported in array annotations")
        lower = _item_to_str(item.start)
        upper = _item_to_str(item.stop)
        return f"{lower}: {upper}"
    return item.__name__ if inspect.isclass(item) else repr(item)
|
307
|
+
|
308
|
+
|
309
|
+
def _maybe_tuple_to_str(
    item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
) -> str:
    """Render a whole ``Array[...]`` subscript, single element or tuple.

    Non-tuples are delegated to ``_item_to_str``. A non-empty tuple is shown
    as its elements joined by ``", "`` with no surrounding parentheses. The
    empty tuple is shown as ``"()"`` so it does not render as an empty string.

    Parameters
    ----------
    item : Union[str, type, slice, Tuple[...]]
        The object that was placed inside ``Array[...]``.

    Returns
    -------
    str
        Text used as the placeholder's qualified name.
    """
    if not isinstance(item, tuple):
        return _item_to_str(item)
    if not item:
        return "()"
    return ", ".join(map(_item_to_str, item))
|
333
|
+
|
334
|
+
|
335
|
+
class Array:
    """Descriptive array annotation that can carry shape and dtype notes.

    ``Array`` on its own is an ordinary class usable as a hint for "some
    array". Subscripting it (``Array["batch, features"]``, ``Array[float]``,
    ``Array[..., "n"]``) returns ``_Array[X]``. The placeholder ``X`` is named
    after the subscript text, so the hint reads back the way it was written.
    Nothing here validates shapes or dtypes at runtime.

    Examples
    --------
    Basic array annotations:

    .. code-block:: python

        >>> from brainstate.typing import Array
        >>>
        >>> # Any array
        >>> def process_array(x: Array) -> Array:
        ...     return x * 2
        >>>
        >>> # Named dimensions that line up between operands
        >>> def matrix_multiply(a: Array["m, n"], b: Array["n, k"]) -> Array["m, k"]:
        ...     return a @ b
        >>>
        >>> # Same named shape on input and output
        >>> def normalize_weights(weights: Array["batch, features"]) -> Array["batch, features"]:
        ...     return weights / weights.sum(axis=-1, keepdims=True)

    Advanced shape annotations:

    .. code-block:: python

        >>> # ``...`` inside the text to note "any further dimensions"
        >>> def flatten_batch(x: Array["batch, ..."]) -> Array["batch, -1"]:
        ...     return x.reshape(x.shape[0], -1)
        >>>
        >>> # Several arguments sharing one shape
        >>> def attention(
        ...     query: Array["batch, seq_len, d_model"],
        ...     key: Array["batch, seq_len, d_model"],
        ...     value: Array["batch, seq_len, d_model"]
        ... ) -> Array["batch, seq_len, d_model"]:
        ...     pass
    """

    def __class_getitem__(cls, item):
        """Return ``_Array`` parameterized by a placeholder that names ``item``.

        Parameters
        ----------
        item : str, type, slice, or tuple
            The shape and/or dtype description written inside the brackets.

        Returns
        -------
        _Array
            The generic alias ``_Array[X]``, where ``X.__qualname__`` is the
            rendered text of ``item``.

        Raises
        ------
        NotImplementedError
            Propagated when ``item`` contains a slice with a step.
        """
        rendered = _maybe_tuple_to_str(item)

        class X:
            pass

        # In ``builtins``, ``typing`` prints only the qualname, i.e. the text.
        X.__qualname__ = rendered
        X.__module__ = "builtins"
        return _Array[X]


# Report ``builtins`` as the module so hints print as a plain ``Array``.
Array.__module__ = "builtins"
|
403
|
+
|
404
|
+
|
405
|
+
# ============================================================================
|
406
|
+
# PyTree Types
|
407
|
+
# ============================================================================
|
408
|
+
|
409
|
+
class _FakePyTree(Generic[_T]):
    """Generic stand-in used only to render ``PyTree[...]`` names.

    ``_MetaPyTree.__getitem__`` formats ``str(_FakePyTree[leaf])`` to obtain
    a readable name such as ``PyTree[int]``.
    """


# Disguise the class as a builtin called ``PyTree`` so that ``typing``'s repr
# of ``_FakePyTree[...]`` prints as ``PyTree[...]`` with no module prefix.
_FakePyTree.__module__ = "builtins"
_FakePyTree.__qualname__ = "PyTree"
_FakePyTree.__name__ = "PyTree"
|
417
|
+
|
418
|
+
|
419
|
+
class _MetaPyTree(type):
    """Metaclass for ``PyTree``.

    It blocks instantiation, since ``PyTree`` is an annotation only. It also
    implements ``PyTree[...]`` subscription by building a fresh subclass of
    ``PyTree`` that records the requested leaf type and structure name.
    """

    def __call__(self, *args, **kwargs):
        """Reject any attempt to instantiate ``PyTree`` or a subscripted form.

        Subscripted forms are subclasses of ``PyTree`` and share this
        metaclass, so they are rejected too.

        Raises
        ------
        RuntimeError
            Always raised since PyTree is a type annotation only.
        """
        raise RuntimeError("PyTree cannot be instantiated")

    # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
    # the custom __instancecheck__ that we want.
    # We can't add that __instancecheck__ via subclassing, e.g.
    # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
    # isn't allowed.
    # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
    # has __module__ "types", e.g. we get types.PyTree[int].
    # NOTE(review): this metaclass defines no ``__instancecheck__`` in the code
    # shown, so ``isinstance`` against ``PyTree[...]`` presumably falls back to
    # ``type.__instancecheck__``. Confirm whether a custom check is still planned.
    #
    # The cache makes repeated identical subscripts return the same class
    # object. As with any ``lru_cache``, ``item`` must be hashable.
    @functools.lru_cache(maxsize=None)
    def __getitem__(cls, item):
        """Build the annotation class for ``PyTree[leaftype]`` or ``PyTree[leaftype, "struct"]``.

        Parameters
        ----------
        item : Any or tuple
            Either a leaf type, or a 2-tuple ``(leaftype, structure)``.
            ``structure`` is a whitespace-separated string of identifiers and
            may start or end with ``"..."``.

        Returns
        -------
        type
            A new subclass of ``PyTree`` with class attributes ``leaftype`` and
            ``structure``. ``structure`` is ``None`` when no structure was
            given. The class is named like ``PyTree[int]`` or
            ``PyTree[int, "T"]``.

        Raises
        ------
        ValueError
            If a tuple subscript does not have exactly two entries, or the
            structure is not a string, is blank, or contains a piece that is
            not a valid identifier (``"..."`` is accepted only as the first or
            last piece).
        TypeError
            If ``item`` is unhashable (raised by the ``lru_cache`` wrapper).
        """
        if isinstance(item, tuple):
            if len(item) != 2:
                raise ValueError(
                    "The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
                    "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
                    "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
                    f"{len(item)}."
                )
            leaf, raw_structure = item
            # Check the type *before* calling ``.strip()``. Otherwise a
            # non-string raises AttributeError and this message is unreachable.
            if not isinstance(raw_structure, str):
                raise ValueError(
                    "The structure annotation `struct` in "
                    "`brainstate.typing.PyTree[leaftype, struct]` must be a string, "
                    f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{raw_structure}'."
                )
            structure_str = raw_structure.strip()
            pieces = structure_str.split()
            if len(pieces) == 0:
                raise ValueError(
                    "The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
                    "cannot be the empty string."
                )
            last_index = len(pieces) - 1
            for piece_index, piece in enumerate(pieces):
                # "..." is accepted only as the first or the last piece.
                if piece == "..." and piece_index in (0, last_index):
                    continue
                if not piece.isidentifier():
                    raise ValueError(
                        "The string `struct` in "
                        "`brainstate.typing.PyTree[leaftype, struct]` must be a "
                        "whitespace-separated sequence of identifiers, e.g. "
                        "`brainstate.typing.PyTree[leaftype, 'T']` or "
                        "`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
                        "(Here, 'identifier' is used in the same sense as in "
                        "regular Python, i.e. a valid variable name.)\n"
                        f"Got piece '{piece}' in overall structure '{structure_str}'."
                    )

            # The class body reads ``leaf``/``structure_str`` from the enclosing
            # function; they must not share names with the class attributes,
            # or the lookup would skip the function scope.
            class X(PyTree):
                leaftype = leaf
                structure = structure_str

            # e.g. 'PyTree[int' + ', "T"]' -> 'PyTree[int, "T"]'
            name = str(_FakePyTree[leaf])[:-1] + ', "' + structure_str + '"]'
        else:
            name = str(_FakePyTree[item])

            class X(PyTree):
                leaftype = item
                structure = None

        X.__name__ = name
        X.__qualname__ = name
        if getattr(tp, "GENERATING_DOCUMENTATION", False):
            X.__module__ = "builtins"
        else:
            X.__module__ = "brainstate.typing"
        return X
|
497
|
+
|
498
|
+
|
499
|
+
# Can't do `class PyTree(Generic[_T]): ...` because we need to override the
# instancecheck for PyTree[foo], but subclassing
# `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
# Instead, PyTree is created directly as an instance of the `_MetaPyTree`
# metaclass.
PyTree = _MetaPyTree("PyTree", (), {})
# Report the module as `builtins` when the GENERATING_DOCUMENTATION flag on
# `tp` is set, otherwise as `brainstate.typing`.
if getattr(tp, "GENERATING_DOCUMENTATION", False):
    PyTree.__module__ = "builtins"
else:
    PyTree.__module__ = "brainstate.typing"
PyTree.__doc__ = """Represents a PyTree.

Annotations of the following sorts are supported:

.. code-block:: python

    >>> a: PyTree
    >>> b: PyTree[LeafType]
    >>> c: PyTree[LeafType, "T"]
    >>> d: PyTree[LeafType, "S T"]
    >>> e: PyTree[LeafType, "... T"]
    >>> f: PyTree[LeafType, "T ..."]

These correspond to:

a. A plain `PyTree` can be used as an annotation, in which case `PyTree` is simply a
    suggestively-named alternative to `Any`.
    ([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))

b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
    example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.

c. A structure name can also be passed. In this case
    `jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
    This can be used to mark that multiple PyTrees all have the same structure:

    .. code-block:: python

        >>> def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
        ...     ...

d. A composite structure can be declared. In this case the variable must have a PyTree
    structure equal to the composition of multiple previously-bound PyTree structures.
    For example:

    .. code-block:: python

        >>> def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
        ...     ...
        >>>
        >>> x = (1, 2)
        >>> y = {"key": 3}
        >>> z = {"key": (4, 5)}  # structure is the composition of the structures of `y` and `x`
        >>> f(x, y, z)

    When performing runtime type-checking, all the individual pieces must have already
    been bound to structures, otherwise the composite structure check will throw an error.

e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
    must match the declared structure, but the upper levels can be arbitrary. As in the
    previous case, all named pieces must already have been seen and their structures
    bound.

f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
    declared structure, but the lower levels can be arbitrary. As in the previous two
    cases, all named pieces must already have been seen and their structures bound.
"""  # noqa: E501
|
564
|
+
|
565
|
+
# ============================================================================
|
566
|
+
# Shape and Size Types
|
567
|
+
# ============================================================================
|
568
|
+
|
569
|
+
Size = Union[(int, Sequence[int], np.integer, Sequence[np.integer])]
"""Annotation for array sizes and dimension counts.

Either a lone integer (a one-dimensional size) or a sequence of integers
(one entry per dimension). NumPy integer scalars are accepted alongside
plain Python ``int`` in both forms.

Examples
--------
.. code-block:: python

    >>> n: Size = 10                                  # one dimension
    >>> dims: Size = (3, 4, 5)                        # several dimensions
    >>> n_np: Size = np.int32(8)                      # NumPy integer scalar
    >>> mixed: Size = [np.int64(2), 3, np.int32(4)]   # mixed integer kinds
"""
|
591
|
+
|
592
|
+
Shape = Sequence[int]
"""Annotation for an array shape: one ``int`` per dimension.

Stricter than ``Size``: a bare integer is not a ``Shape``, so even a
one-dimensional shape has to be spelled as a sequence such as ``(100,)``.

Examples
--------
.. code-block:: python

    >>> matrix_shape: Shape = (10, 20)      # 2-D
    >>> tensor_shape: Shape = (5, 10, 15)   # 3-D
    >>> vector_shape: Shape = (100,)        # 1-D, still a sequence
"""
|
611
|
+
|
612
|
+
Axes = Union[(int, Sequence[int])]
"""Annotation for the axis or axes an operation acts along.

A single integer selects one axis; a sequence of integers selects several.
Typical consumers are reductions, reshapes and other array manipulations.

Examples
--------
.. code-block:: python

    >>> one_axis: Axes = 0
    >>> two_axes: Axes = (0, 2)
    >>> every_axis: Axes = tuple(range(array.ndim))   # all axes of `array`
    >>>
    >>> def sum_along_axes(array: ArrayLike, axes: Axes) -> ArrayLike:
    ...     return jnp.sum(array, axis=axes)
"""
|
634
|
+
|
635
|
+
# ============================================================================
|
636
|
+
# Array Types
|
637
|
+
# ============================================================================
|
638
|
+
|
639
|
+
ArrayLike = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
    np.bool_, np.number,  # NumPy scalar types
    bool, int, float, complex,  # Python scalar types
    u.Quantity,  # BrainUnit quantity type
]
"""Union of all objects that can be implicitly converted to a JAX array.

This type is designed for JAX compatibility and excludes arbitrary sequences
and string data that ``numpy.typing.ArrayLike`` would include. It represents
data that can be safely converted to arrays without ambiguity.

Components
----------
jax.Array
    Native JAX arrays.
np.ndarray
    NumPy arrays that can be converted to JAX arrays.
np.bool_, np.number
    NumPy scalar types (bool, int8, float32, etc.).
bool, int, float, complex
    Python built-in scalar types.
u.Quantity
    BrainUnit quantities with physical units.

Examples
--------
.. code-block:: python

    >>> def process_data(data: ArrayLike) -> jax.Array:
    ...     '''Convert input to JAX array and process it.'''
    ...     array = jnp.asarray(data)
    ...     return array * 2
    >>>
    >>> # Valid inputs
    >>> process_data(jnp.array([1, 2, 3]))  # JAX array
    >>> process_data(np.array([1, 2, 3]))  # NumPy array
    >>> process_data(42)  # Python scalar
    >>> process_data(np.float32(3.14))  # NumPy scalar
    >>> process_data(1.5 * u.second)  # Quantity with units
    >>>
    >>> # Python lists/tuples are NOT members of ``ArrayLike`` (see above), so a
    >>> # static type checker rejects ``process_data([1, 2, 3])``.
    >>> # Convert such sequences explicitly first:
    >>> process_data(np.asarray([1, 2, 3]))
"""
|
682
|
+
|
683
|
+
# ============================================================================
|
684
|
+
# Data Type Annotations
|
685
|
+
# ============================================================================
|
686
|
+
|
687
|
+
DType = np.dtype
"""Alias for NumPy's ``np.dtype`` class.

Annotates values that are concrete dtype *objects*, i.e. instances of
``np.dtype``. Scalar classes such as ``np.float32`` are not ``np.dtype``
instances; parameters that should also accept those (or strings like
``'float32'``) are better annotated with ``DTypeLike``.

Examples
--------
.. code-block:: python

    >>> def create_array(shape: Shape, dtype: DType) -> jax.Array:
    ...     return jnp.zeros(shape, dtype=dtype)
    >>>
    >>> # Usage: pass an actual ``np.dtype`` instance to match the annotation
    >>> arr = create_array((3, 4), np.dtype(np.float32))
"""
|
702
|
+
|
703
|
+
|
704
|
+
class SupportsDType(Protocol):
    """Structural type for any object that exposes a ``dtype`` property.

    As a :class:`typing.Protocol`, no inheritance is needed: objects that
    provide ``dtype`` (JAX or NumPy arrays, for instance) satisfy it during
    static type checking. It is not decorated with ``runtime_checkable``, so
    it is not meant for ``isinstance`` checks.

    Examples
    --------
    .. code-block:: python

        >>> def get_dtype(obj: SupportsDType) -> DType:
        ...     return obj.dtype
        >>>
        >>> arr = jnp.array([1.0, 2.0])
        >>> get_dtype(arr)   # float32
    """

    @property
    def dtype(self) -> DType:
        """The data type carried by the object.

        Returns
        -------
        DType
            The NumPy dtype describing the object's elements.
        """
|
732
|
+
|
733
|
+
|
734
|
+
DTypeLike = Union[(
    str,            # dtype names such as 'float32' or 'int32'
    type[Any],      # scalar classes such as np.float32, float, int
    np.dtype,       # ready-made NumPy dtype objects
    SupportsDType,  # anything carrying a ``dtype`` property
)]
"""Anything that can be turned into a valid JAX dtype.

Narrower than ``numpy.typing.DTypeLike``: JAX has no object arrays or
structured dtypes, so those spellings are left out. ``None`` is deliberately
not a member, so callers that allow an optional dtype must say so explicitly.

Components
----------
str
    Dtype names, e.g. ``'float32'``, ``'int32'``, ``'bool'``.
type[Any]
    Scalar classes, e.g. ``np.float32``, ``float``, ``int``, ``bool``.
np.dtype
    Dtype objects built with ``np.dtype(...)``.
SupportsDType
    Any object exposing a ``.dtype`` property.

Examples
--------
.. code-block:: python

    >>> def cast_array(array: ArrayLike, dtype: DTypeLike) -> jax.Array:
    ...     '''Cast array to specified dtype.'''
    ...     return jnp.asarray(array, dtype=dtype)
    >>>
    >>> cast_array(data, 'float32')          # by name
    >>> cast_array(data, np.float32)         # NumPy scalar class
    >>> cast_array(data, float)              # Python class
    >>> cast_array(data, np.dtype('int32'))  # dtype object
    >>> cast_array(data, other_array)        # object with a dtype property
"""
|
772
|
+
|
773
|
+
# ============================================================================
|
774
|
+
# Random Number Generation
|
775
|
+
# ============================================================================
|
776
|
+
|
777
|
+
SeedOrKey = Union[(int, jax.Array, np.ndarray)]
"""Annotation for a random seed or a PRNG key.

Covers plain integer seeds as well as array-valued keys usable with JAX's
random-number machinery.

Components
----------
int
    An integer seed.
jax.Array
    A JAX PRNG key, typically made with ``jax.random.PRNGKey``.
np.ndarray
    A NumPy array used as a raw key.

Examples
--------
.. code-block:: python

    >>> def generate_random(key: SeedOrKey, shape: Shape) -> jax.Array:
    ...     '''Generate random numbers using the provided seed or key.'''
    ...     if isinstance(key, int):
    ...         key = jax.random.PRNGKey(key)
    ...     return jax.random.normal(key, shape)
    >>>
    >>> generate_random(42, (3, 4))                                  # int seed
    >>> generate_random(jax.random.PRNGKey(123), (5,))               # JAX key
    >>> generate_random(np.array([1, 2], dtype=np.uint32), (2, 2))   # NumPy key
"""
|
807
|
+
|
808
|
+
|
809
|
+
# ============================================================================
|
810
|
+
# Utility Types
|
811
|
+
# ============================================================================
|
812
|
+
|
813
|
+
class Missing:
    """Sentinel type meaning "no value was supplied at all".

    Useful as a default when ``None`` is itself a meaningful argument, so that
    "caller passed None" and "caller passed nothing" can be told apart. Each
    call to ``Missing()`` creates a distinct object; compare against one shared
    instance with ``is``.

    Examples
    --------
    .. code-block:: python

        >>> _MISSING = Missing()
        >>>
        >>> def function_with_optional_param(value: Union[int, None, Missing] = _MISSING):
        ...     if value is _MISSING:
        ...         print("No value provided")
        ...     elif value is None:
        ...         print("None was explicitly provided")
        ...     else:
        ...         print(f"Value: {value}")
        >>>
        >>> function_with_optional_param()      # prints "No value provided"
        >>> function_with_optional_param(None)  # prints "None was explicitly provided"
        >>> function_with_optional_param(42)    # prints "Value: 42"
    """
|