brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/typing.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -13,8 +13,40 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
"""
|
17
|
+
Comprehensive type annotations for BrainState.
|
18
|
+
|
19
|
+
This module provides a collection of type aliases, protocols, and generic types
|
20
|
+
specifically designed for scientific computing, neural network modeling, and
|
21
|
+
array operations within the BrainState ecosystem.
|
22
|
+
|
23
|
+
The type system is designed to be compatible with JAX, NumPy, and BrainUnit,
|
24
|
+
providing comprehensive type hints for arrays, shapes, seeds, and PyTree structures.
|
25
|
+
|
26
|
+
Examples
|
27
|
+
--------
|
28
|
+
Basic usage with array types:
|
29
|
+
|
30
|
+
.. code-block:: python
|
31
|
+
|
32
|
+
>>> import brainstate
|
33
|
+
>>> from brainstate.typing import ArrayLike, Shape, DTypeLike
|
34
|
+
>>>
|
35
|
+
>>> def process_array(data: ArrayLike, shape: Shape, dtype: DTypeLike) -> brainstate.Array:
|
36
|
+
... return brainstate.asarray(data, dtype=dtype).reshape(shape)
|
37
|
+
|
38
|
+
Using PyTree annotations:
|
39
|
+
|
40
|
+
.. code-block:: python
|
41
|
+
|
42
|
+
>>> from brainstate.typing import PyTree
|
43
|
+
>>>
|
44
|
+
>>> def tree_function(tree: PyTree[float, "T"]) -> PyTree[float, "T"]:
|
45
|
+
... return brainstate.tree_map(lambda x: x * 2, tree)
|
46
|
+
"""
|
47
|
+
|
16
48
|
import builtins
|
17
|
-
import functools
|
49
|
+
import functools
|
18
50
|
import importlib
|
19
51
|
import inspect
|
20
52
|
from typing import (
|
@@ -29,42 +61,215 @@ import numpy as np
|
|
29
61
|
tp = importlib.import_module("typing")
|
30
62
|
|
31
63
|
__all__ = [
|
64
|
+
# Path and filter types
|
32
65
|
'PathParts',
|
33
66
|
'Predicate',
|
34
67
|
'Filter',
|
35
|
-
'
|
36
|
-
|
68
|
+
'FilterLiteral',
|
69
|
+
|
70
|
+
# Array and shape types
|
71
|
+
'Array',
|
72
|
+
'ArrayLike',
|
37
73
|
'Shape',
|
74
|
+
'Size',
|
38
75
|
'Axes',
|
39
|
-
'SeedOrKey',
|
40
|
-
'ArrayLike',
|
41
76
|
'DType',
|
42
77
|
'DTypeLike',
|
78
|
+
'SupportsDType',
|
79
|
+
|
80
|
+
# PyTree types
|
81
|
+
'PyTree',
|
82
|
+
|
83
|
+
# Random number generation
|
84
|
+
'SeedOrKey',
|
85
|
+
|
86
|
+
# Utility types
|
87
|
+
'Key',
|
43
88
|
'Missing',
|
89
|
+
|
90
|
+
# Type variables
|
91
|
+
'K',
|
92
|
+
'_T',
|
93
|
+
'_Annotation',
|
44
94
|
]
|
45
95
|
|
46
|
-
|
96
|
+
# ============================================================================
|
97
|
+
# Type Variables
|
98
|
+
# ============================================================================
|
99
|
+
|
100
|
+
K = TypeVar('K', bound='Key')
|
101
|
+
"""Type variable for keys that must be comparable and hashable."""
|
102
|
+
|
103
|
+
_T = TypeVar("_T")
|
104
|
+
"""Generic type variable for any type."""
|
105
|
+
|
106
|
+
_Annotation = TypeVar("_Annotation")
|
107
|
+
"""Type variable for array annotations."""
|
108
|
+
|
47
109
|
|
110
|
+
# ============================================================================
|
111
|
+
# Key and Path Types
|
112
|
+
# ============================================================================
|
48
113
|
|
49
114
|
@runtime_checkable
|
50
115
|
class Key(Hashable, Protocol):
|
116
|
+
"""Protocol for keys that can be used in PyTree paths.
|
117
|
+
|
118
|
+
A Key must be both hashable and comparable, making it suitable
|
119
|
+
for use as dictionary keys and for ordering operations.
|
120
|
+
|
121
|
+
Examples
|
122
|
+
--------
|
123
|
+
Valid key types include:
|
124
|
+
|
125
|
+
.. code-block:: python
|
126
|
+
|
127
|
+
>>> # String keys
|
128
|
+
>>> key1: Key = "layer1"
|
129
|
+
>>>
|
130
|
+
>>> # Integer keys
|
131
|
+
>>> key2: Key = 42
|
132
|
+
>>>
|
133
|
+
>>> # Custom hashable objects
|
134
|
+
>>> class CustomKey:
|
135
|
+
... def __init__(self, name: str):
|
136
|
+
... self.name = name
|
137
|
+
...
|
138
|
+
... def __hash__(self) -> int:
|
139
|
+
... return hash(self.name)
|
140
|
+
...
|
141
|
+
... def __eq__(self, other) -> bool:
|
142
|
+
... return isinstance(other, CustomKey) and self.name == other.name
|
143
|
+
...
|
144
|
+
... def __lt__(self, other) -> bool:
|
145
|
+
... return isinstance(other, CustomKey) and self.name < other.name
|
146
|
+
"""
|
147
|
+
|
51
148
|
def __lt__(self: K, value: K, /) -> bool:
|
149
|
+
"""Less than comparison for ordering keys.
|
150
|
+
|
151
|
+
Parameters
|
152
|
+
----------
|
153
|
+
value : Key
|
154
|
+
The key to compare against.
|
155
|
+
|
156
|
+
Returns
|
157
|
+
-------
|
158
|
+
bool
|
159
|
+
True if this key is less than the other key.
|
160
|
+
"""
|
52
161
|
...
|
53
162
|
|
54
163
|
|
55
164
|
Ellipsis = builtins.ellipsis if TYPE_CHECKING else Any
|
165
|
+
"""Type alias for ellipsis, used in filter expressions."""
|
56
166
|
|
57
167
|
PathParts = Tuple[Key, ...]
|
58
|
-
|
59
|
-
FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
|
60
|
-
Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
|
168
|
+
"""Tuple of keys representing a path through a PyTree structure.
|
61
169
|
|
62
|
-
|
170
|
+
Examples
|
171
|
+
--------
|
172
|
+
.. code-block:: python
|
63
173
|
|
64
|
-
|
174
|
+
>>> # Path to a nested value in a PyTree
|
175
|
+
>>> path: PathParts = ("model", "layers", 0, "weights")
|
176
|
+
>>>
|
177
|
+
>>> # Empty path representing the root
|
178
|
+
>>> root_path: PathParts = ()
|
179
|
+
"""
|
65
180
|
|
181
|
+
Predicate = Callable[[PathParts, Any], bool]
|
182
|
+
"""Function that takes a path and value, returning whether it matches some condition.
|
183
|
+
|
184
|
+
Parameters
|
185
|
+
----------
|
186
|
+
path : PathParts
|
187
|
+
The path to the value in the PyTree.
|
188
|
+
value : Any
|
189
|
+
The value at that path.
|
190
|
+
|
191
|
+
Returns
|
192
|
+
-------
|
193
|
+
bool
|
194
|
+
True if the path/value combination matches the predicate.
|
195
|
+
|
196
|
+
Examples
|
197
|
+
--------
|
198
|
+
.. code-block:: python
|
199
|
+
|
200
|
+
>>> def is_weight_matrix(path: PathParts, value: Any) -> bool:
|
201
|
+
... '''Check if a value is a weight matrix (2D array).'''
|
202
|
+
... return len(path) > 0 and "weight" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 2
|
203
|
+
>>>
|
204
|
+
>>> def is_bias_vector(path: PathParts, value: Any) -> bool:
|
205
|
+
... '''Check if a value is a bias vector (1D array).'''
|
206
|
+
... return len(path) > 0 and "bias" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 1
|
207
|
+
"""
|
208
|
+
|
209
|
+
FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
|
210
|
+
"""Basic filter types that can be used to select parts of a PyTree.
|
211
|
+
|
212
|
+
Components
|
213
|
+
----------
|
214
|
+
type
|
215
|
+
Filter by type, e.g., `float`, `jax.Array`.
|
216
|
+
str
|
217
|
+
Filter by string matching in path keys.
|
218
|
+
Predicate
|
219
|
+
Custom function for complex filtering logic.
|
220
|
+
bool
|
221
|
+
Simple True/False filter.
|
222
|
+
Ellipsis
|
223
|
+
Wildcard filter that matches anything.
|
224
|
+
None
|
225
|
+
Filter that matches None values.
|
226
|
+
|
227
|
+
Examples
|
228
|
+
--------
|
229
|
+
.. code-block:: python
|
230
|
+
|
231
|
+
>>> # Filter by type
|
232
|
+
>>> float_filter: FilterLiteral = float
|
233
|
+
>>>
|
234
|
+
>>> # Filter by string pattern
|
235
|
+
>>> weight_filter: FilterLiteral = "weight"
|
236
|
+
>>>
|
237
|
+
>>> # Custom predicate filter
|
238
|
+
>>> matrix_filter: FilterLiteral = lambda path, x: hasattr(x, 'ndim') and x.ndim == 2
|
239
|
+
"""
|
240
|
+
|
241
|
+
Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
|
242
|
+
"""Flexible filter type that can be a single filter or combination of filters.
|
243
|
+
|
244
|
+
This allows for complex filtering patterns by combining multiple filter criteria.
|
245
|
+
|
246
|
+
Examples
|
247
|
+
--------
|
248
|
+
.. code-block:: python
|
249
|
+
|
250
|
+
>>> # Single filter
|
251
|
+
>>> simple_filter: Filter = "weight"
|
252
|
+
>>>
|
253
|
+
>>> # Tuple of filters (all must match)
|
254
|
+
>>> combined_filter: Filter = (float, "weight")
|
255
|
+
>>>
|
256
|
+
>>> # List of filters (any can match)
|
257
|
+
>>> alternative_filter: Filter = [int, float, "bias"]
|
258
|
+
>>>
|
259
|
+
>>> # Nested combinations
|
260
|
+
>>> complex_filter: Filter = [
|
261
|
+
... ("weight", lambda p, x: x.ndim == 2), # 2D weight matrices
|
262
|
+
... ("bias", lambda p, x: x.ndim == 1), # 1D bias vectors
|
263
|
+
... ]
|
264
|
+
"""
|
265
|
+
|
266
|
+
|
267
|
+
# ============================================================================
|
268
|
+
# Array Annotation Types
|
269
|
+
# ============================================================================
|
66
270
|
|
67
271
|
class _Array(Generic[_Annotation]):
|
272
|
+
"""Internal generic array type for creating custom array annotations."""
|
68
273
|
pass
|
69
274
|
|
70
275
|
|
@@ -72,9 +277,26 @@ _Array.__module__ = "builtins"
|
|
72
277
|
|
73
278
|
|
74
279
|
def _item_to_str(item: Union[str, type, slice]) -> str:
|
280
|
+
"""Convert an array annotation item to its string representation.
|
281
|
+
|
282
|
+
Parameters
|
283
|
+
----------
|
284
|
+
item : Union[str, type, slice]
|
285
|
+
The item to convert to string.
|
286
|
+
|
287
|
+
Returns
|
288
|
+
-------
|
289
|
+
str
|
290
|
+
String representation of the item.
|
291
|
+
|
292
|
+
Raises
|
293
|
+
------
|
294
|
+
NotImplementedError
|
295
|
+
If slice has a step component.
|
296
|
+
"""
|
75
297
|
if isinstance(item, slice):
|
76
298
|
if item.step is not None:
|
77
|
-
raise NotImplementedError
|
299
|
+
raise NotImplementedError("Slice steps are not supported in array annotations")
|
78
300
|
return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
|
79
301
|
elif item is ...:
|
80
302
|
return "..."
|
@@ -87,19 +309,87 @@ def _item_to_str(item: Union[str, type, slice]) -> str:
|
|
87
309
|
def _maybe_tuple_to_str(
|
88
310
|
item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
|
89
311
|
) -> str:
|
312
|
+
"""Convert array annotation items (potentially in tuple) to string representation.
|
313
|
+
|
314
|
+
Parameters
|
315
|
+
----------
|
316
|
+
item : Union[str, type, slice, Tuple[...]]
|
317
|
+
Single item or tuple of items to convert.
|
318
|
+
|
319
|
+
Returns
|
320
|
+
-------
|
321
|
+
str
|
322
|
+
String representation of the item(s).
|
323
|
+
"""
|
90
324
|
if isinstance(item, tuple):
|
91
325
|
if len(item) == 0:
|
92
|
-
# Explicit brackets
|
326
|
+
# Explicit brackets for empty tuple
|
93
327
|
return "()"
|
94
328
|
else:
|
95
|
-
# No brackets
|
329
|
+
# No brackets for non-empty tuple
|
96
330
|
return ", ".join([_item_to_str(i) for i in item])
|
97
331
|
else:
|
98
332
|
return _item_to_str(item)
|
99
333
|
|
100
334
|
|
101
335
|
class Array:
|
336
|
+
"""Flexible array type annotation supporting shape and dtype specifications.
|
337
|
+
|
338
|
+
This class provides a convenient way to annotate arrays with shape information,
|
339
|
+
making code more self-documenting and enabling better static analysis.
|
340
|
+
|
341
|
+
Examples
|
342
|
+
--------
|
343
|
+
Basic array annotations:
|
344
|
+
|
345
|
+
.. code-block:: python
|
346
|
+
|
347
|
+
>>> from brainstate.typing import Array
|
348
|
+
>>>
|
349
|
+
>>> # Any array
|
350
|
+
>>> def process_array(x: Array) -> Array:
|
351
|
+
... return x * 2
|
352
|
+
>>>
|
353
|
+
>>> # Array with specific shape annotation
|
354
|
+
>>> def matrix_multiply(a: Array["m, n"], b: Array["n, k"]) -> Array["m, k"]:
|
355
|
+
... return a @ b
|
356
|
+
>>>
|
357
|
+
>>> # Array with named batch and feature dimensions
|
358
|
+
>>> def normalize_weights(weights: Array["batch, features"]) -> Array["batch, features"]:
|
359
|
+
... return weights / weights.sum(axis=-1, keepdims=True)
|
360
|
+
|
361
|
+
Advanced shape annotations:
|
362
|
+
|
363
|
+
.. code-block:: python
|
364
|
+
|
365
|
+
>>> # Using ellipsis for flexible dimensions
|
366
|
+
>>> def flatten_batch(x: Array["batch, ..."]) -> Array["batch, -1"]:
|
367
|
+
... return x.reshape(x.shape[0], -1)
|
368
|
+
>>>
|
369
|
+
>>> # Multiple shape constraints
|
370
|
+
>>> def attention(
|
371
|
+
... query: Array["batch, seq_len, d_model"],
|
372
|
+
... key: Array["batch, seq_len, d_model"],
|
373
|
+
... value: Array["batch, seq_len, d_model"]
|
374
|
+
... ) -> Array["batch, seq_len, d_model"]:
|
375
|
+
... # Attention computation
|
376
|
+
... pass
|
377
|
+
"""
|
378
|
+
|
102
379
|
def __class_getitem__(cls, item):
|
380
|
+
"""Create a specialized Array type with shape/dtype annotations.
|
381
|
+
|
382
|
+
Parameters
|
383
|
+
----------
|
384
|
+
item : str, type, slice, or tuple
|
385
|
+
Shape specification, dtype, or combination thereof.
|
386
|
+
|
387
|
+
Returns
|
388
|
+
-------
|
389
|
+
_Array
|
390
|
+
Specialized array type with the given annotation.
|
391
|
+
"""
|
392
|
+
|
103
393
|
class X:
|
104
394
|
pass
|
105
395
|
|
@@ -108,14 +398,16 @@ class Array:
|
|
108
398
|
return _Array[X]
|
109
399
|
|
110
400
|
|
111
|
-
#
|
112
|
-
# doing `def f(x: Array)` as well as `def f(x: Array["dim"])`.
|
113
|
-
#
|
114
|
-
# Don't need to set __qualname__ as that's already correct.
|
401
|
+
# Set module for proper display in type hints
|
115
402
|
Array.__module__ = "builtins"
|
116
403
|
|
117
404
|
|
405
|
+
# ============================================================================
|
406
|
+
# PyTree Types
|
407
|
+
# ============================================================================
|
408
|
+
|
118
409
|
class _FakePyTree(Generic[_T]):
|
410
|
+
"""Internal generic PyTree type for creating specialized PyTree annotations."""
|
119
411
|
pass
|
120
412
|
|
121
413
|
|
@@ -125,7 +417,16 @@ _FakePyTree.__module__ = "builtins"
|
|
125
417
|
|
126
418
|
|
127
419
|
class _MetaPyTree(type):
|
420
|
+
"""Metaclass for PyTree type that prevents instantiation and handles subscripting."""
|
421
|
+
|
128
422
|
def __call__(self, *args, **kwargs):
|
423
|
+
"""Prevent direct instantiation of PyTree type.
|
424
|
+
|
425
|
+
Raises
|
426
|
+
------
|
427
|
+
RuntimeError
|
428
|
+
Always raised since PyTree is a type annotation only.
|
429
|
+
"""
|
129
430
|
raise RuntimeError("PyTree cannot be instantiated")
|
130
431
|
|
131
432
|
# Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
|
@@ -135,7 +436,7 @@ class _MetaPyTree(type):
|
|
135
436
|
# isn't allowed.
|
136
437
|
# Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
|
137
438
|
# has __module__ "types", e.g. we get types.PyTree[int].
|
138
|
-
@
|
439
|
+
@functools.lru_cache(maxsize=None)
|
139
440
|
def __getitem__(cls, item):
|
140
441
|
if isinstance(item, tuple):
|
141
442
|
if len(item) == 2:
|
@@ -206,14 +507,15 @@ else:
|
|
206
507
|
PyTree.__doc__ = """Represents a PyTree.
|
207
508
|
|
208
509
|
Annotations of the following sorts are supported:
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
510
|
+
|
511
|
+
.. code-block:: python
|
512
|
+
|
513
|
+
>>> a: PyTree
|
514
|
+
>>> b: PyTree[LeafType]
|
515
|
+
>>> c: PyTree[LeafType, "T"]
|
516
|
+
>>> d: PyTree[LeafType, "S T"]
|
517
|
+
>>> e: PyTree[LeafType, "... T"]
|
518
|
+
>>> f: PyTree[LeafType, "T ..."]
|
217
519
|
|
218
520
|
These correspond to:
|
219
521
|
|
@@ -227,23 +529,26 @@ b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
|
|
227
529
|
c. A structure name can also be passed. In this case
|
228
530
|
`jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
|
229
531
|
This can be used to mark that multiple PyTrees all have the same structure:
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
532
|
+
|
533
|
+
.. code-block:: python
|
534
|
+
|
535
|
+
>>> def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
|
536
|
+
... ...
|
234
537
|
|
235
538
|
d. A composite structure can be declared. In this case the variable must have a PyTree
|
236
539
|
structure equal to the composition of multiple previously-bound PyTree structures.
|
237
540
|
For example:
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
541
|
+
|
542
|
+
.. code-block:: python
|
543
|
+
|
544
|
+
>>> def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
|
545
|
+
... ...
|
546
|
+
>>>
|
547
|
+
>>> x = (1, 2)
|
548
|
+
>>> y = {"key": 3}
|
549
|
+
>>> z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `x`
|
550
|
+
>>> f(x, y, z)
|
551
|
+
|
247
552
|
When performing runtime type-checking, all the individual pieces must have already
|
248
553
|
been bound to structures, otherwise the composite structure check will throw an error.
|
249
554
|
|
@@ -257,48 +562,276 @@ f. A structure can end with a `...`, to denote that the PyTree must be a prefix
|
|
257
562
|
cases, all named pieces must already have been seen and their structures bound.
|
258
563
|
""" # noqa: E501
|
259
564
|
|
565
|
+
# ============================================================================
|
566
|
+
# Shape and Size Types
|
567
|
+
# ============================================================================
|
568
|
+
|
260
569
|
Size = Union[int, Sequence[int], np.integer, Sequence[np.integer]]
|
261
|
-
|
262
|
-
|
570
|
+
"""Type for specifying array sizes and dimensions.
|
571
|
+
|
572
|
+
Can be a single integer for 1D sizes, or a sequence of integers for multi-dimensional shapes.
|
573
|
+
Supports both Python integers and NumPy integer types for compatibility.
|
574
|
+
|
575
|
+
Examples
|
576
|
+
--------
|
577
|
+
.. code-block:: python
|
578
|
+
|
579
|
+
>>> # Single dimension
|
580
|
+
>>> size1: Size = 10
|
581
|
+
>>>
|
582
|
+
>>> # Multiple dimensions
|
583
|
+
>>> size2: Size = (3, 4, 5)
|
584
|
+
>>>
|
585
|
+
>>> # Using NumPy integers
|
586
|
+
>>> size3: Size = np.int32(8)
|
587
|
+
>>>
|
588
|
+
>>> # Mixed sequence
|
589
|
+
>>> size4: Size = [np.int64(2), 3, np.int32(4)]
|
590
|
+
"""
|
591
|
+
|
263
592
|
Shape = Sequence[int]
|
593
|
+
"""Type for array shapes as sequences of integers.
|
264
594
|
|
265
|
-
|
595
|
+
Represents the shape of an array as a sequence of dimension sizes.
|
596
|
+
More restrictive than Size: it must be a sequence of Python ints (bare scalars and NumPy integers are not accepted).
|
597
|
+
|
598
|
+
Examples
|
599
|
+
--------
|
600
|
+
.. code-block:: python
|
601
|
+
|
602
|
+
>>> # 2D array shape
|
603
|
+
>>> matrix_shape: Shape = (10, 20)
|
604
|
+
>>>
|
605
|
+
>>> # 3D array shape
|
606
|
+
>>> tensor_shape: Shape = (5, 10, 15)
|
607
|
+
>>>
|
608
|
+
>>> # 1D array shape (note: still needs to be a sequence)
|
609
|
+
>>> vector_shape: Shape = (100,)
|
610
|
+
"""
|
611
|
+
|
612
|
+
Axes = Union[int, Sequence[int]]
|
613
|
+
"""Type for specifying axes along which operations should be performed.
|
614
|
+
|
615
|
+
Can be a single axis (integer) or multiple axes (sequence of integers).
|
616
|
+
Used in reduction operations, reshaping, and other array manipulations.
|
617
|
+
|
618
|
+
Examples
|
619
|
+
--------
|
620
|
+
.. code-block:: python
|
621
|
+
|
622
|
+
>>> # Single axis
|
623
|
+
>>> axis1: Axes = 0
|
624
|
+
>>>
|
625
|
+
>>> # Multiple axes
|
626
|
+
>>> axis2: Axes = (0, 2)
|
627
|
+
>>>
|
628
|
+
>>> # All axes for global operations
|
629
|
+
>>> axis3: Axes = tuple(range(3)) # every axis of a 3D array
|
630
|
+
>>>
|
631
|
+
>>> def sum_along_axes(array: ArrayLike, axes: Axes) -> ArrayLike:
|
632
|
+
... return jnp.sum(array, axis=axes)
|
633
|
+
"""
|
634
|
+
|
635
|
+
# ============================================================================
|
636
|
+
# Array Types
|
637
|
+
# ============================================================================
|
266
638
|
|
267
|
-
# ArrayLike is a Union of all objects that can be implicitly converted to a
|
268
|
-
# standard JAX array (i.e. not including future non-standard array types like
|
269
|
-
# KeyArray and BInt). It's different than np.typing.ArrayLike in that it doesn't
|
270
|
-
# accept arbitrary sequences, nor does it accept string data.
|
271
639
|
ArrayLike = Union[
|
272
640
|
jax.Array, # JAX array type
|
273
641
|
np.ndarray, # NumPy array type
|
274
642
|
np.bool_, np.number, # NumPy scalar types
|
275
643
|
bool, int, float, complex, # Python scalar types
|
276
|
-
u.Quantity, #
|
644
|
+
u.Quantity, # BrainUnit quantity type
|
277
645
|
]
|
646
|
+
"""Union of all objects that can be implicitly converted to a JAX array.
|
647
|
+
|
648
|
+
This type is designed for JAX compatibility and excludes arbitrary sequences
|
649
|
+
and string data that numpy.typing.ArrayLike would include. It represents
|
650
|
+
data that can be safely converted to arrays without ambiguity.
|
651
|
+
|
652
|
+
Components
|
653
|
+
----------
|
654
|
+
jax.Array
|
655
|
+
Native JAX arrays.
|
656
|
+
np.ndarray
|
657
|
+
NumPy arrays that can be converted to JAX arrays.
|
658
|
+
np.bool_, np.number
|
659
|
+
NumPy scalar types (bool, int8, float32, etc.).
|
660
|
+
bool, int, float, complex
|
661
|
+
Python built-in scalar types.
|
662
|
+
u.Quantity
|
663
|
+
BrainUnit quantities with physical units.
|
664
|
+
|
665
|
+
Examples
|
666
|
+
--------
|
667
|
+
.. code-block:: python
|
668
|
+
|
669
|
+
>>> def process_data(data: ArrayLike) -> jax.Array:
|
670
|
+
... '''Convert input to JAX array and process it.'''
|
671
|
+
... array = jnp.asarray(data)
|
672
|
+
... return array * 2
|
673
|
+
>>>
|
674
|
+
>>> # Valid inputs
|
675
|
+
>>> process_data(jnp.array([1, 2, 3])) # JAX array
|
676
|
+
>>> process_data(np.array([1, 2, 3])) # NumPy array
|
677
|
+
>>> process_data([1, 2, 3]) # works at runtime via jnp.asarray, but a list is not an ArrayLike for type checkers
|
678
|
+
>>> process_data(42) # Python scalar
|
679
|
+
>>> process_data(np.float32(3.14)) # NumPy scalar
|
680
|
+
>>> process_data(1.5 * u.second) # Quantity with units
|
681
|
+
"""
|
682
|
+
|
683
|
+
# ============================================================================
|
684
|
+
# Data Type Annotations
|
685
|
+
# ============================================================================
|
278
686
|
|
279
|
-
|
687
|
+
DType = np.dtype
|
688
|
+
"""Alias for NumPy's dtype type.
|
280
689
|
|
690
|
+
Used to represent data types of arrays in a clear and consistent manner.
|
281
691
|
|
282
|
-
|
692
|
+
Examples
|
693
|
+
--------
|
694
|
+
.. code-block:: python
|
695
|
+
|
696
|
+
>>> def create_array(shape: Shape, dtype: DType) -> jax.Array:
|
697
|
+
... return jnp.zeros(shape, dtype=dtype)
|
698
|
+
>>>
|
699
|
+
>>> # Usage
|
700
|
+
>>> arr = create_array((3, 4), np.float32)
|
701
|
+
"""
|
283
702
|
|
284
703
|
|
285
704
|
class SupportsDType(Protocol):
|
705
|
+
"""Protocol for objects that have a dtype property.
|
706
|
+
|
707
|
+
This protocol defines the interface for any object that exposes
|
708
|
+
a dtype attribute, allowing for flexible type checking.
|
709
|
+
|
710
|
+
Examples
|
711
|
+
--------
|
712
|
+
.. code-block:: python
|
713
|
+
|
714
|
+
>>> def get_dtype(obj: SupportsDType) -> DType:
|
715
|
+
... return obj.dtype
|
716
|
+
>>>
|
717
|
+
>>> # Works with arrays
|
718
|
+
>>> arr = jnp.array([1.0, 2.0])
|
719
|
+
>>> dtype = get_dtype(arr) # float32
|
720
|
+
"""
|
721
|
+
|
286
722
|
@property
|
287
|
-
def dtype(self) -> DType:
|
723
|
+
def dtype(self) -> DType:
|
724
|
+
"""Return the data type of the object.
|
725
|
+
|
726
|
+
Returns
|
727
|
+
-------
|
728
|
+
DType
|
729
|
+
The NumPy dtype of the object.
|
730
|
+
"""
|
731
|
+
...
|
288
732
|
|
289
733
|
|
290
|
-
# DTypeLike is meant to annotate inputs to np.dtype that return
|
291
|
-
# a valid JAX dtype. It's different than numpy.typing.DTypeLike
|
292
|
-
# because JAX doesn't support objects or structured dtypes.
|
293
|
-
# Unlike np.typing.DTypeLike, we exclude None, and instead require
|
294
|
-
# explicit annotations when None is acceptable.
|
295
734
|
DTypeLike = Union[
|
296
|
-
str, # like 'float32', 'int32'
|
297
|
-
type[Any], # like np.float32, np.int32, float, int
|
298
|
-
np.dtype, #
|
299
|
-
SupportsDType, #
|
735
|
+
str, # String representations like 'float32', 'int32'
|
736
|
+
type[Any], # Type objects like np.float32, np.int32, float, int
|
737
|
+
np.dtype, # NumPy dtype objects
|
738
|
+
SupportsDType, # Objects with a dtype property
|
300
739
|
]
|
740
|
+
"""Union of types that can be converted to a valid JAX dtype.
|
741
|
+
|
742
|
+
This is more restrictive than numpy.typing.DTypeLike as JAX doesn't support
|
743
|
+
object arrays or structured dtypes. It excludes None to require explicit
|
744
|
+
handling of optional dtypes.
|
745
|
+
|
746
|
+
Components
|
747
|
+
----------
|
748
|
+
str
|
749
|
+
String representations like 'float32', 'int32', 'bool'.
|
750
|
+
type[Any]
|
751
|
+
Type objects like np.float32, float, int, bool.
|
752
|
+
np.dtype
|
753
|
+
NumPy dtype objects created with np.dtype().
|
754
|
+
SupportsDType
|
755
|
+
Any object with a .dtype property.
|
756
|
+
|
757
|
+
Examples
|
758
|
+
--------
|
759
|
+
.. code-block:: python
|
760
|
+
|
761
|
+
>>> def cast_array(array: ArrayLike, dtype: DTypeLike) -> jax.Array:
|
762
|
+
... '''Cast array to specified dtype.'''
|
763
|
+
... return jnp.asarray(array, dtype=dtype)
|
764
|
+
>>>
|
765
|
+
>>> # Valid dtype specifications
|
766
|
+
>>> cast_array(data, 'float32') # String
|
767
|
+
>>> cast_array(data, np.float32) # NumPy type
|
768
|
+
>>> cast_array(data, float) # Python type
|
769
|
+
>>> cast_array(data, np.dtype('int32')) # NumPy dtype object
|
770
|
+
>>> cast_array(data, other_array) # Object with dtype property
|
771
|
+
"""
|
772
|
+
|
773
|
+
# ============================================================================
|
774
|
+
# Random Number Generation
|
775
|
+
# ============================================================================
|
301
776
|
|
777
|
+
SeedOrKey = Union[int, jax.Array, np.ndarray]
|
778
|
+
"""Type for random number generator seeds or keys.
|
779
|
+
|
780
|
+
Represents values that can be used to seed random number generators
|
781
|
+
or serve as PRNG keys in JAX's random number generation system.
|
782
|
+
|
783
|
+
Components
|
784
|
+
----------
|
785
|
+
int
|
786
|
+
Integer seeds for random number generators.
|
787
|
+
jax.Array
|
788
|
+
JAX PRNG keys (typically created with jax.random.PRNGKey).
|
789
|
+
np.ndarray
|
790
|
+
NumPy arrays that can serve as random keys.
|
791
|
+
|
792
|
+
Examples
|
793
|
+
--------
|
794
|
+
.. code-block:: python
|
795
|
+
|
796
|
+
>>> def generate_random(key: SeedOrKey, shape: Shape) -> jax.Array:
|
797
|
+
... '''Generate random numbers using the provided seed or key.'''
|
798
|
+
... if isinstance(key, int):
|
799
|
+
... key = jax.random.PRNGKey(key)
|
800
|
+
... return jax.random.normal(key, shape)
|
801
|
+
>>>
|
802
|
+
>>> # Valid seeds/keys
|
803
|
+
>>> generate_random(42, (3, 4)) # Integer seed
|
804
|
+
>>> generate_random(jax.random.PRNGKey(123), (5,)) # JAX PRNG key
|
805
|
+
>>> generate_random(np.array([1, 2], dtype=np.uint32), (2, 2)) # NumPy array
|
806
|
+
"""
|
807
|
+
|
808
|
+
|
809
|
+
# ============================================================================
|
810
|
+
# Utility Types
|
811
|
+
# ============================================================================
|
302
812
|
|
303
813
|
class Missing:
|
814
|
+
"""Sentinel class to represent missing or unspecified values.
|
815
|
+
|
816
|
+
This class is used as a default value when None has semantic meaning
|
817
|
+
and you need to distinguish between "None was passed" and "nothing was passed".
|
818
|
+
|
819
|
+
Examples
|
820
|
+
--------
|
821
|
+
.. code-block:: python
|
822
|
+
|
823
|
+
>>> _MISSING = Missing()
|
824
|
+
>>>
|
825
|
+
>>> def function_with_optional_param(value: Union[int, None, Missing] = _MISSING):
|
826
|
+
... if value is _MISSING:
|
827
|
+
... print("No value provided")
|
828
|
+
... elif value is None:
|
829
|
+
... print("None was explicitly provided")
|
830
|
+
... else:
|
831
|
+
... print(f"Value: {value}")
|
832
|
+
>>>
|
833
|
+
>>> function_with_optional_param() # "No value provided"
|
834
|
+
>>> function_with_optional_param(None) # "None was explicitly provided"
|
835
|
+
>>> function_with_optional_param(42) # "Value: 42"
|
836
|
+
"""
|
304
837
|
pass
|