brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115):
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/typing.py CHANGED
@@ -1,837 +1,837 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Comprehensive type annotations for BrainState.
18
-
19
- This module provides a collection of type aliases, protocols, and generic types
20
- specifically designed for scientific computing, neural network modeling, and
21
- array operations within the BrainState ecosystem.
22
-
23
- The type system is designed to be compatible with JAX, NumPy, and BrainUnit,
24
- providing comprehensive type hints for arrays, shapes, seeds, and PyTree structures.
25
-
26
- Examples
27
- --------
28
- Basic usage with array types:
29
-
30
- .. code-block:: python
31
-
32
- >>> import brainstate
33
- >>> from brainstate.typing import ArrayLike, Shape, DTypeLike
34
- >>>
35
- >>> def process_array(data: ArrayLike, shape: Shape, dtype: DTypeLike) -> brainstate.Array:
36
- ... return brainstate.asarray(data, dtype=dtype).reshape(shape)
37
-
38
- Using PyTree annotations:
39
-
40
- .. code-block:: python
41
-
42
- >>> from brainstate.typing import PyTree
43
- >>>
44
- >>> def tree_function(tree: PyTree[float, "T"]) -> PyTree[float, "T"]:
45
- ... return brainstate.tree_map(lambda x: x * 2, tree)
46
- """
47
-
48
- import builtins
49
- import functools
50
- import importlib
51
- import inspect
52
- from typing import (
53
- Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
54
- runtime_checkable, TYPE_CHECKING, Generic, Sequence
55
- )
56
-
57
- import brainunit as u
58
- import jax
59
- import numpy as np
60
-
61
# Handle on the ``typing`` module. Code later in this file probes it for the
# optional ``GENERATING_DOCUMENTATION`` attribute via ``getattr``.
tp = importlib.import_module("typing")

# Names exported by ``from brainstate.typing import *``.
__all__ = [
    # paths and filters
    'PathParts', 'Predicate', 'Filter', 'FilterLiteral',
    # arrays, shapes and dtypes
    'Array', 'ArrayLike', 'Shape', 'Size', 'Axes',
    'DType', 'DTypeLike', 'SupportsDType',
    # pytrees
    'PyTree',
    # random number generation
    'SeedOrKey',
    # utilities
    'Key', 'Missing',
    # type variables
    'K', '_T', '_Annotation',
]
95
-
96
- # ============================================================================
97
- # Type Variables
98
- # ============================================================================
99
-
100
- K = TypeVar('K', bound='Key')
101
- """Type variable for keys that must be comparable and hashable."""
102
-
103
- _T = TypeVar("_T")
104
- """Generic type variable for any type."""
105
-
106
- _Annotation = TypeVar("_Annotation")
107
- """Type variable for array annotations."""
108
-
109
-
110
- # ============================================================================
111
- # Key and Path Types
112
- # ============================================================================
113
-
114
@runtime_checkable
class Key(Hashable, Protocol):
    """Structural type for a single component of a PyTree path.

    An object satisfies this protocol when it is hashable and supports
    ``<``. It can therefore serve as a dictionary key and be sorted. The
    protocol is ``runtime_checkable``, so ``isinstance(obj, Key)`` may be
    used to test for the required methods.

    Examples
    --------
    .. code-block:: python

        >>> name_key: Key = "encoder"
        >>> index_key: Key = 3
        >>>
        >>> class NamedKey:
        ...     def __init__(self, label: str):
        ...         self.label = label
        ...
        ...     def __hash__(self) -> int:
        ...         return hash(self.label)
        ...
        ...     def __eq__(self, other) -> bool:
        ...         return isinstance(other, NamedKey) and other.label == self.label
        ...
        ...     def __lt__(self, other) -> bool:
        ...         return isinstance(other, NamedKey) and self.label < other.label
    """

    def __lt__(self: K, value: K, /) -> bool:
        """Report whether this key sorts before ``value``.

        Parameters
        ----------
        value : Key
            The key to compare with.

        Returns
        -------
        bool
            ``True`` when ``self`` orders before ``value``.
        """
        ...
162
-
163
-
164
# NOTE: this deliberately rebinds the module-level name ``Ellipsis``.
# ``builtins.ellipsis`` exists only for static type checkers; at runtime the
# permissive ``Any`` is used instead.
if TYPE_CHECKING:
    Ellipsis = builtins.ellipsis
else:
    Ellipsis = Any
"""Annotation standing in for the ``...`` literal inside filter expressions."""
166
-
167
PathParts = Tuple[Key, ...]
"""Location of a node inside a PyTree, given as a tuple of :class:`Key` entries.

The empty tuple addresses the root of the tree.

Examples
--------
.. code-block:: python

    >>> weights_path: PathParts = ("model", "layers", 0, "weights")
    >>> root: PathParts = ()
"""
180
-
181
Predicate = Callable[[PathParts, Any], bool]
"""Callable that decides whether a ``(path, value)`` pair is selected.

Parameters
----------
path : PathParts
    Where the value sits inside the PyTree.
value : Any
    The value found at ``path``.

Returns
-------
bool
    ``True`` to select the pair.

Examples
--------
.. code-block:: python

    >>> def is_weight_matrix(path: PathParts, value: Any) -> bool:
    ...     '''Select 2-D values whose last path entry mentions "weight".'''
    ...     return (
    ...         len(path) > 0
    ...         and "weight" in str(path[-1])
    ...         and hasattr(value, 'ndim')
    ...         and value.ndim == 2
    ...     )
    >>>
    >>> def is_bias_vector(path: PathParts, value: Any) -> bool:
    ...     '''Select 1-D values whose last path entry mentions "bias".'''
    ...     return (
    ...         len(path) > 0
    ...         and "bias" in str(path[-1])
    ...         and hasattr(value, 'ndim')
    ...         and value.ndim == 1
    ...     )
"""
208
-
209
FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
"""Atomic (non-compound) forms a filter may take.

How each form is interpreted is decided by the filtering utilities that
consume it; this alias only lists the accepted shapes.

Components
----------
type
    A class, e.g. ``float`` or ``jax.Array``.
str
    A string.
Predicate
    A callable ``(path, value) -> bool`` for arbitrary selection logic.
bool
    A constant ``True``/``False`` filter.
Ellipsis
    The ``...`` literal.
None
    ``None``.

Examples
--------
.. code-block:: python

    >>> by_type: FilterLiteral = float
    >>> by_name: FilterLiteral = "weight"
    >>> by_predicate: FilterLiteral = lambda path, x: hasattr(x, 'ndim') and x.ndim == 2
"""
240
-
241
Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
"""Recursive filter type: an atomic :data:`FilterLiteral` or a tuple/list of filters.

Tuples and lists may be nested arbitrarily to build compound selections.
How the members of a tuple or list are combined (for example whether one or
all must match) is decided by the code that consumes the filter. Check that
code before relying on a particular combination rule.

Examples
--------
.. code-block:: python

    >>> single: Filter = "weight"
    >>> as_tuple: Filter = (float, "weight")
    >>> as_list: Filter = [int, float, "bias"]
    >>> nested: Filter = [
    ...     ("weight", lambda path, x: x.ndim == 2),
    ...     ("bias", lambda path, x: x.ndim == 1),
    ... ]
"""
265
-
266
-
267
- # ============================================================================
268
- # Array Annotation Types
269
- # ============================================================================
270
-
271
class _Array(Generic[_Annotation]):
    """Private generic carrier that :class:`Array` subscripts to encode annotations."""

    # Reported as a builtin so that ``_Array[...]`` reprs carry no module prefix.
    __module__ = "builtins"
277
-
278
-
279
- def _item_to_str(item: Union[str, type, slice]) -> str:
280
- """Convert an array annotation item to its string representation.
281
-
282
- Parameters
283
- ----------
284
- item : Union[str, type, slice]
285
- The item to convert to string.
286
-
287
- Returns
288
- -------
289
- str
290
- String representation of the item.
291
-
292
- Raises
293
- ------
294
- NotImplementedError
295
- If slice has a step component.
296
- """
297
- if isinstance(item, slice):
298
- if item.step is not None:
299
- raise NotImplementedError("Slice steps are not supported in array annotations")
300
- return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
301
- elif item is ...:
302
- return "..."
303
- elif inspect.isclass(item):
304
- return item.__name__
305
- else:
306
- return repr(item)
307
-
308
-
309
- def _maybe_tuple_to_str(
310
- item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
311
- ) -> str:
312
- """Convert array annotation items (potentially in tuple) to string representation.
313
-
314
- Parameters
315
- ----------
316
- item : Union[str, type, slice, Tuple[...]]
317
- Single item or tuple of items to convert.
318
-
319
- Returns
320
- -------
321
- str
322
- String representation of the item(s).
323
- """
324
- if isinstance(item, tuple):
325
- if len(item) == 0:
326
- # Explicit brackets for empty tuple
327
- return "()"
328
- else:
329
- # No brackets for non-empty tuple
330
- return ", ".join([_item_to_str(i) for i in item])
331
- else:
332
- return _item_to_str(item)
333
-
334
-
335
class Array:
    """Array annotation that can optionally carry a free-form shape description.

    Used bare (``x: Array``), it simply marks an array-valued argument.
    Subscripting it, for example ``Array["batch, features"]``, renders the
    subscript to text and stores it as the ``__qualname__`` of a new helper
    class ``X``; the result is ``_Array[X]``. The annotation is descriptive
    only: apart from slice steps being rejected, the subscript is not
    validated.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.typing import Array
        >>>
        >>> def double(x: Array) -> Array:
        ...     return x * 2
        >>>
        >>> def matmul(a: Array["m, n"], b: Array["n, k"]) -> Array["m, k"]:
        ...     return a @ b
        >>>
        >>> def flatten_batch(x: Array["batch, ..."]) -> Array["batch, -1"]:
        ...     return x.reshape(x.shape[0], -1)
        >>>
        >>> def attention(
        ...     query: Array["batch, seq_len, d_model"],
        ...     key: Array["batch, seq_len, d_model"],
        ...     value: Array["batch, seq_len, d_model"],
        ... ) -> Array["batch, seq_len, d_model"]:
        ...     ...
    """

    def __class_getitem__(cls, item):
        """Return ``_Array[X]``, where ``X`` is a fresh class named after ``item``.

        Parameters
        ----------
        item : str, type, slice or tuple
            The subscript; it is rendered with ``_maybe_tuple_to_str``.

        Returns
        -------
        _Array
            A parametrised ``_Array``. A new ``X`` class is created on every
            call, so two identical subscripts produce distinct objects.
        """
        label = _maybe_tuple_to_str(item)

        class X:
            pass

        X.__module__ = "builtins"
        X.__qualname__ = label
        return _Array[X]


# Report ``Array`` as a builtin so annotations render without a module prefix.
Array.__module__ = "builtins"
403
-
404
-
405
- # ============================================================================
406
- # PyTree Types
407
- # ============================================================================
408
-
409
class _FakePyTree(Generic[_T]):
    """Private generic used only to produce ``PyTree[...]``-style reprs."""

    # Both are honoured when given in the class body. A builtin module means
    # generic-alias reprs show just the qualified name, ``PyTree``.
    __module__ = "builtins"
    __qualname__ = "PyTree"


# ``__name__`` comes from the ``class`` statement itself, so override it here.
_FakePyTree.__name__ = "PyTree"
417
-
418
-
419
- class _MetaPyTree(type):
420
- """Metaclass for PyTree type that prevents instantiation and handles subscripting."""
421
-
422
- def __call__(self, *args, **kwargs):
423
- """Prevent direct instantiation of PyTree type.
424
-
425
- Raises
426
- ------
427
- RuntimeError
428
- Always raised since PyTree is a type annotation only.
429
- """
430
- raise RuntimeError("PyTree cannot be instantiated")
431
-
432
- # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
433
- # the custom __instancecheck__ that we want.
434
- # We can't add that __instancecheck__ via subclassing, e.g.
435
- # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
436
- # isn't allowed.
437
- # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
438
- # has __module__ "types", e.g. we get types.PyTree[int].
439
- @functools.lru_cache(maxsize=None)
440
- def __getitem__(cls, item):
441
- if isinstance(item, tuple):
442
- if len(item) == 2:
443
-
444
- class X(PyTree):
445
- leaftype = item[0]
446
- structure = item[1].strip()
447
-
448
- if not isinstance(X.structure, str):
449
- raise ValueError(
450
- "The structure annotation `struct` in "
451
- "`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
452
- f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
453
- )
454
- pieces = X.structure.split()
455
- if len(pieces) == 0:
456
- raise ValueError(
457
- "The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
458
- "cannot be the empty string."
459
- )
460
- for piece_index, piece in enumerate(pieces):
461
- if (piece_index == 0) or (piece_index == len(pieces) - 1):
462
- if piece == "...":
463
- continue
464
- if not piece.isidentifier():
465
- raise ValueError(
466
- "The string `struct` in "
467
- "`brainstate.typing.PyTree[leaftype, struct]` must be be a "
468
- "whitespace-separated sequence of identifiers, e.g. "
469
- "`brainstate.typing.PyTree[leaftype, 'T']` or "
470
- "`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
471
- "(Here, 'identifier' is used in the same sense as in "
472
- "regular Python, i.e. a valid variable name.)\n"
473
- f"Got piece '{piece}' in overall structure '{X.structure}'."
474
- )
475
- name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
476
- else:
477
- raise ValueError(
478
- "The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
479
- "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
480
- "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
481
- f"{len(item)}."
482
- )
483
- else:
484
- name = str(_FakePyTree[item])
485
-
486
- class X(PyTree):
487
- leaftype = item
488
- structure = None
489
-
490
- X.__name__ = name
491
- X.__qualname__ = name
492
- if getattr(tp, "GENERATING_DOCUMENTATION", False):
493
- X.__module__ = "builtins"
494
- else:
495
- X.__module__ = "brainstate.typing"
496
- return X
497
-
498
-
499
- # Can't do `class PyTree(Generic[_T]): ...` because we need to override the
500
- # instancecheck for PyTree[foo], but subclassing
501
- # `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
502
PyTree = _MetaPyTree("PyTree", (), {})
# ``tp`` is the ``typing`` module. The optional ``GENERATING_DOCUMENTATION``
# attribute is presumably set by the documentation build (confirm). When it
# is set, the reported module becomes ``builtins`` so rendered annotations
# carry no module prefix.
if getattr(tp, "GENERATING_DOCUMENTATION", False):
    PyTree.__module__ = "builtins"
else:
    PyTree.__module__ = "brainstate.typing"
PyTree.__doc__ = """Represents a PyTree.

Annotations of the following sorts are supported:

.. code-block:: python

    >>> a: PyTree
    >>> b: PyTree[LeafType]
    >>> c: PyTree[LeafType, "T"]
    >>> d: PyTree[LeafType, "S T"]
    >>> e: PyTree[LeafType, "... T"]
    >>> f: PyTree[LeafType, "T ..."]

These correspond to:

a. A plain `PyTree` can be used as an annotation, in which case `PyTree` is simply a
   suggestively-named alternative to `Any`.
   (`By definition all types are PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`_.)

b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
   example, `PyTree[int]` or `PyTree[Union[str, Array["b c"]]]`.

c. A structure name can also be passed. In this case
   `jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
   This can be used to mark that multiple PyTrees all have the same structure:

   .. code-block:: python

       >>> def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
       ...     ...

d. A composite structure can be declared. In this case the variable must have a PyTree
   structure equal to the composition of multiple previously-bound PyTree structures.
   For example:

   .. code-block:: python

       >>> def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
       ...     ...
       >>>
       >>> x = (1, 2)
       >>> y = {"key": 3}
       >>> z = {"key": (4, 5)}  # structure is the composition of the structures of `y` and `x`
       >>> f(x, y, z)

   When performing runtime type-checking, all the individual pieces must have already
   been bound to structures, otherwise the composite structure check will throw an error.

e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
   must match the declared structure, but the upper levels can be arbitrary. As in the
   previous case, all named pieces must already have been seen and their structures
   bound.

f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
   declared structure, but the lower levels can be arbitrary. As in the previous two
   cases, all named pieces must already have been seen and their structures bound.

Note that ``brainstate.typing`` itself only records the requested leaf type and
structure, as the ``leaftype`` and ``structure`` attributes of the class returned by
``PyTree[...]``. The checks described in c.-f. happen only when the annotations are
consumed by a runtime type checker that understands them.
"""  # noqa: E501
564
-
565
- # ============================================================================
566
- # Shape and Size Types
567
- # ============================================================================
568
-
569
Size = Union[int, Sequence[int], np.integer, Sequence[np.integer]]
"""Array size or dimensions, given as one integer or a sequence of integers.

Both Python ``int`` and NumPy integer scalars are accepted, on their own or
inside a sequence.

Examples
--------
.. code-block:: python

    >>> length: Size = 10
    >>> dims: Size = (3, 4, 5)
    >>> numpy_length: Size = np.int32(8)
    >>> mixed_dims: Size = [np.int64(2), 3, np.int32(4)]
"""
591
-
592
Shape = Sequence[int]
"""Array shape as a sequence of dimension sizes.

Unlike :data:`Size`, a bare integer is not a ``Shape``; even a 1-D shape is a
sequence.

Examples
--------
.. code-block:: python

    >>> matrix_shape: Shape = (10, 20)
    >>> tensor_shape: Shape = (5, 10, 15)
    >>> vector_shape: Shape = (100,)
"""
611
-
612
Axes = Union[int, Sequence[int]]
"""Axis or axes an operation acts on: one integer or a sequence of integers.

Examples
--------
.. code-block:: python

    >>> import jax.numpy as jnp
    >>> first: Axes = 0
    >>> some: Axes = (0, 2)
    >>> every: Axes = tuple(range(3))  # all axes of a 3-D array
    >>>
    >>> def sum_along(x, axes: Axes):
    ...     return jnp.sum(x, axis=axes)
"""
634
-
635
- # ============================================================================
636
- # Array Types
637
- # ============================================================================
638
-
639
ArrayLike = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
    np.bool_, np.number,  # NumPy scalar types
    bool, int, float, complex,  # Python scalar types
    u.Quantity,  # BrainUnit quantity type
]
"""Union of the array, scalar and quantity types accepted as array-like input.

Unlike ``numpy.typing.ArrayLike``, this union contains no arbitrary sequences
(lists, tuples) and no strings. Only the members listed below belong to it.

Components
----------
jax.Array
    Native JAX arrays.
np.ndarray
    NumPy arrays.
np.bool_, np.number
    NumPy scalar types (bool, int8, float32, etc.).
bool, int, float, complex
    Python built-in scalar types.
u.Quantity
    BrainUnit quantities with physical units.

Examples
--------
.. code-block:: python

    >>> import jax.numpy as jnp
    >>> def process_data(data: ArrayLike) -> jax.Array:
    ...     '''Convert input to a JAX array and process it.'''
    ...     return jnp.asarray(data) * 2
    >>>
    >>> # Inputs covered by ArrayLike
    >>> process_data(jnp.array([1, 2, 3]))  # JAX array
    >>> process_data(np.array([1, 2, 3]))   # NumPy array
    >>> process_data(42)                    # Python scalar
    >>> process_data(np.float32(3.14))      # NumPy scalar
    >>>
    >>> # A plain Python list is NOT an ArrayLike; convert it first
    >>> process_data(np.asarray([1, 2, 3]))
"""
682
-
683
- # ============================================================================
684
- # Data Type Annotations
685
- # ============================================================================
686
-
687
DType = np.dtype
"""Alias of :class:`numpy.dtype`, used to annotate array data types.

Examples
--------
.. code-block:: python

    >>> import jax.numpy as jnp
    >>> def zeros(shape: Shape, dtype: DType) -> jax.Array:
    ...     return jnp.zeros(shape, dtype=dtype)
    >>>
    >>> arr = zeros((3, 4), np.dtype("float32"))
"""
702
-
703
-
704
class SupportsDType(Protocol):
    """Structural type for any object that exposes a ``dtype`` attribute.

    The protocol is not decorated with ``runtime_checkable``, so it is meant
    for static annotations rather than ``isinstance`` checks.

    Examples
    --------
    .. code-block:: python

        >>> import jax.numpy as jnp
        >>> def dtype_of(obj: SupportsDType) -> DType:
        ...     return obj.dtype
        >>>
        >>> dtype_of(jnp.array([1.0, 2.0]))
    """

    @property
    def dtype(self) -> DType:
        """The object's data type, as a NumPy ``dtype``."""
        ...
732
-
733
-
734
DTypeLike = Union[
    str,            # dtype names such as 'float32' or 'int32'
    type[Any],      # scalar type objects such as np.float32, float, int
    np.dtype,       # numpy dtype instances
    SupportsDType,  # any object exposing a ``dtype`` attribute
]
"""Values accepted as a dtype specification.

Narrower than ``numpy.typing.DTypeLike``. In particular ``None`` is not a
member, so an optional dtype must be spelled ``Optional[DTypeLike]``.

Components
----------
str
    Dtype names such as ``'float32'``, ``'int32'`` or ``'bool'``.
type[Any]
    Type objects such as ``np.float32``, ``float``, ``int`` or ``bool``.
np.dtype
    Instances created with ``np.dtype(...)``.
SupportsDType
    Any object with a ``dtype`` attribute, e.g. an existing array.

Examples
--------
.. code-block:: python

    >>> import jax.numpy as jnp
    >>> def cast(x, dtype: DTypeLike) -> jax.Array:
    ...     return jnp.asarray(x, dtype=dtype)
    >>>
    >>> cast(data, 'float32')
    >>> cast(data, np.float32)
    >>> cast(data, float)
    >>> cast(data, np.dtype('int32'))
    >>> cast(data, other_array)
"""
772
-
773
- # ============================================================================
774
- # Random Number Generation
775
- # ============================================================================
776
-
777
SeedOrKey = Union[int, jax.Array, np.ndarray]
"""Either an integer seed or an array usable as a JAX PRNG key.

Components
----------
int
    An integer seed.
jax.Array
    A JAX PRNG key, typically from ``jax.random.PRNGKey``.
np.ndarray
    A NumPy array holding key data.

Examples
--------
.. code-block:: python

    >>> def draw(key: SeedOrKey, shape: Shape) -> jax.Array:
    ...     '''Sample normals from a seed or an existing key.'''
    ...     if isinstance(key, int):
    ...         key = jax.random.PRNGKey(key)
    ...     return jax.random.normal(key, shape)
    >>>
    >>> draw(42, (3, 4))
    >>> draw(jax.random.PRNGKey(123), (5,))
    >>> draw(np.array([1, 2], dtype=np.uint32), (2, 2))
"""
807
-
808
-
809
- # ============================================================================
810
- # Utility Types
811
- # ============================================================================
812
-
813
- class Missing:
814
- """Sentinel class to represent missing or unspecified values.
815
-
816
- This class is used as a default value when None has semantic meaning
817
- and you need to distinguish between "None was passed" and "nothing was passed".
818
-
819
- Examples
820
- --------
821
- .. code-block:: python
822
-
823
- >>> _MISSING = Missing()
824
- >>>
825
- >>> def function_with_optional_param(value: Union[int, None, Missing] = _MISSING):
826
- ... if value is _MISSING:
827
- ... print("No value provided")
828
- ... elif value is None:
829
- ... print("None was explicitly provided")
830
- ... else:
831
- ... print(f"Value: {value}")
832
- >>>
833
- >>> function_with_optional_param() # "No value provided"
834
- >>> function_with_optional_param(None) # "None was explicitly provided"
835
- >>> function_with_optional_param(42) # "Value: 42"
836
- """
837
- pass
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive type annotations for BrainState.
18
+
19
+ This module provides a collection of type aliases, protocols, and generic types
20
+ specifically designed for scientific computing, neural network modeling, and
21
+ array operations within the BrainState ecosystem.
22
+
23
+ The type system is designed to be compatible with JAX, NumPy, and BrainUnit,
24
+ providing comprehensive type hints for arrays, shapes, seeds, and PyTree structures.
25
+
26
+ Examples
27
+ --------
28
+ Basic usage with array types:
29
+
30
+ .. code-block:: python
31
+
32
+ >>> import brainstate
33
+ >>> from brainstate.typing import ArrayLike, Shape, DTypeLike
34
+ >>>
35
+ >>> def process_array(data: ArrayLike, shape: Shape, dtype: DTypeLike) -> brainstate.Array:
36
+ ... return brainstate.asarray(data, dtype=dtype).reshape(shape)
37
+
38
+ Using PyTree annotations:
39
+
40
+ .. code-block:: python
41
+
42
+ >>> from brainstate.typing import PyTree
43
+ >>>
44
+ >>> def tree_function(tree: PyTree[float, "T"]) -> PyTree[float, "T"]:
45
+ ... return brainstate.tree_map(lambda x: x * 2, tree)
46
+ """
47
+
48
+ import builtins
49
+ import functools
50
+ import importlib
51
+ import inspect
52
+ from typing import (
53
+ Any, Callable, Hashable, List, Protocol, Tuple, TypeVar, Union,
54
+ runtime_checkable, TYPE_CHECKING, Generic, Sequence
55
+ )
56
+
57
+ import brainunit as u
58
+ import jax
59
+ import numpy as np
60
+
61
+ tp = importlib.import_module("typing")
62
+
63
+ __all__ = [
64
+ # Path and filter types
65
+ 'PathParts',
66
+ 'Predicate',
67
+ 'Filter',
68
+ 'FilterLiteral',
69
+
70
+ # Array and shape types
71
+ 'Array',
72
+ 'ArrayLike',
73
+ 'Shape',
74
+ 'Size',
75
+ 'Axes',
76
+ 'DType',
77
+ 'DTypeLike',
78
+ 'SupportsDType',
79
+
80
+ # PyTree types
81
+ 'PyTree',
82
+
83
+ # Random number generation
84
+ 'SeedOrKey',
85
+
86
+ # Utility types
87
+ 'Key',
88
+ 'Missing',
89
+
90
+ # Type variables
91
+ 'K',
92
+ '_T',
93
+ '_Annotation',
94
+ ]
95
+
96
+ # ============================================================================
97
+ # Type Variables
98
+ # ============================================================================
99
+
100
+ K = TypeVar('K', bound='Key')
101
+ """Type variable for keys that must be comparable and hashable."""
102
+
103
+ _T = TypeVar("_T")
104
+ """Generic type variable for any type."""
105
+
106
+ _Annotation = TypeVar("_Annotation")
107
+ """Type variable for array annotations."""
108
+
109
+
110
+ # ============================================================================
111
+ # Key and Path Types
112
+ # ============================================================================
113
+
114
+ @runtime_checkable
115
+ class Key(Hashable, Protocol):
116
+ """Protocol for keys that can be used in PyTree paths.
117
+
118
+ A Key must be both hashable and comparable, making it suitable
119
+ for use as dictionary keys and for ordering operations.
120
+
121
+ Examples
122
+ --------
123
+ Valid key types include:
124
+
125
+ .. code-block:: python
126
+
127
+ >>> # String keys
128
+ >>> key1: Key = "layer1"
129
+ >>>
130
+ >>> # Integer keys
131
+ >>> key2: Key = 42
132
+ >>>
133
+ >>> # Custom hashable objects
134
+ >>> class CustomKey:
135
+ ... def __init__(self, name: str):
136
+ ... self.name = name
137
+ ...
138
+ ... def __hash__(self) -> int:
139
+ ... return hash(self.name)
140
+ ...
141
+ ... def __eq__(self, other) -> bool:
142
+ ... return isinstance(other, CustomKey) and self.name == other.name
143
+ ...
144
+ ... def __lt__(self, other) -> bool:
145
+ ... return isinstance(other, CustomKey) and self.name < other.name
146
+ """
147
+
148
+ def __lt__(self: K, value: K, /) -> bool:
149
+ """Less than comparison for ordering keys.
150
+
151
+ Parameters
152
+ ----------
153
+ value : Key
154
+ The key to compare against.
155
+
156
+ Returns
157
+ -------
158
+ bool
159
+ True if this key is less than the other key.
160
+ """
161
+ ...
162
+
163
+
164
+ Ellipsis = builtins.ellipsis if TYPE_CHECKING else Any
165
+ """Type alias for ellipsis, used in filter expressions."""
166
+
167
+ PathParts = Tuple[Key, ...]
168
+ """Tuple of keys representing a path through a PyTree structure.
169
+
170
+ Examples
171
+ --------
172
+ .. code-block:: python
173
+
174
+ >>> # Path to a nested value in a PyTree
175
+ >>> path: PathParts = ("model", "layers", 0, "weights")
176
+ >>>
177
+ >>> # Empty path representing the root
178
+ >>> root_path: PathParts = ()
179
+ """
180
+
181
+ Predicate = Callable[[PathParts, Any], bool]
182
+ """Function that takes a path and value, returning whether it matches some condition.
183
+
184
+ Parameters
185
+ ----------
186
+ path : PathParts
187
+ The path to the value in the PyTree.
188
+ value : Any
189
+ The value at that path.
190
+
191
+ Returns
192
+ -------
193
+ bool
194
+ True if the path/value combination matches the predicate.
195
+
196
+ Examples
197
+ --------
198
+ .. code-block:: python
199
+
200
+ >>> def is_weight_matrix(path: PathParts, value: Any) -> bool:
201
+ ... '''Check if a value is a weight matrix (2D array).'''
202
+ ... return len(path) > 0 and "weight" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 2
203
+ >>>
204
+ >>> def is_bias_vector(path: PathParts, value: Any) -> bool:
205
+ ... '''Check if a value is a bias vector (1D array).'''
206
+ ... return len(path) > 0 and "bias" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 1
207
+ """
208
+
209
+ FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
210
+ """Basic filter types that can be used to select parts of a PyTree.
211
+
212
+ Components
213
+ ----------
214
+ type
215
+ Filter by type, e.g., `float`, `jax.Array`.
216
+ str
217
+ Filter by string matching in path keys.
218
+ Predicate
219
+ Custom function for complex filtering logic.
220
+ bool
221
+ Simple True/False filter.
222
+ Ellipsis
223
+ Wildcard filter that matches anything.
224
+ None
225
+ Filter that matches None values.
226
+
227
+ Examples
228
+ --------
229
+ .. code-block:: python
230
+
231
+ >>> # Filter by type
232
+ >>> float_filter: FilterLiteral = float
233
+ >>>
234
+ >>> # Filter by string pattern
235
+ >>> weight_filter: FilterLiteral = "weight"
236
+ >>>
237
+ >>> # Custom predicate filter
238
+ >>> matrix_filter: FilterLiteral = lambda path, x: hasattr(x, 'ndim') and x.ndim == 2
239
+ """
240
+
241
+ Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
242
+ """Flexible filter type that can be a single filter or combination of filters.
243
+
244
+ This allows for complex filtering patterns by combining multiple filter criteria.
245
+
246
+ Examples
247
+ --------
248
+ .. code-block:: python
249
+
250
+ >>> # Single filter
251
+ >>> simple_filter: Filter = "weight"
252
+ >>>
253
+ >>> # Tuple of filters (all must match)
254
+ >>> combined_filter: Filter = (float, "weight")
255
+ >>>
256
+ >>> # List of filters (any can match)
257
+ >>> alternative_filter: Filter = [int, float, "bias"]
258
+ >>>
259
+ >>> # Nested combinations
260
+ >>> complex_filter: Filter = [
261
+ ... ("weight", lambda p, x: x.ndim == 2), # 2D weight matrices
262
+ ... ("bias", lambda p, x: x.ndim == 1), # 1D bias vectors
263
+ ... ]
264
+ """
265
+
266
+
267
+ # ============================================================================
268
+ # Array Annotation Types
269
+ # ============================================================================
270
+
271
+ class _Array(Generic[_Annotation]):
272
+ """Internal generic array type for creating custom array annotations."""
273
+ pass
274
+
275
+
276
+ _Array.__module__ = "builtins"
277
+
278
+
279
+ def _item_to_str(item: Union[str, type, slice]) -> str:
280
+ """Convert an array annotation item to its string representation.
281
+
282
+ Parameters
283
+ ----------
284
+ item : Union[str, type, slice]
285
+ The item to convert to string.
286
+
287
+ Returns
288
+ -------
289
+ str
290
+ String representation of the item.
291
+
292
+ Raises
293
+ ------
294
+ NotImplementedError
295
+ If slice has a step component.
296
+ """
297
+ if isinstance(item, slice):
298
+ if item.step is not None:
299
+ raise NotImplementedError("Slice steps are not supported in array annotations")
300
+ return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
301
+ elif item is ...:
302
+ return "..."
303
+ elif inspect.isclass(item):
304
+ return item.__name__
305
+ else:
306
+ return repr(item)
307
+
308
+
309
+ def _maybe_tuple_to_str(
310
+ item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
311
+ ) -> str:
312
+ """Convert array annotation items (potentially in tuple) to string representation.
313
+
314
+ Parameters
315
+ ----------
316
+ item : Union[str, type, slice, Tuple[...]]
317
+ Single item or tuple of items to convert.
318
+
319
+ Returns
320
+ -------
321
+ str
322
+ String representation of the item(s).
323
+ """
324
+ if isinstance(item, tuple):
325
+ if len(item) == 0:
326
+ # Explicit brackets for empty tuple
327
+ return "()"
328
+ else:
329
+ # No brackets for non-empty tuple
330
+ return ", ".join([_item_to_str(i) for i in item])
331
+ else:
332
+ return _item_to_str(item)
333
+
334
+
335
+ class Array:
336
+ """Flexible array type annotation supporting shape and dtype specifications.
337
+
338
+ This class provides a convenient way to annotate arrays with shape information,
339
+ making code more self-documenting and enabling better static analysis.
340
+
341
+ Examples
342
+ --------
343
+ Basic array annotations:
344
+
345
+ .. code-block:: python
346
+
347
+ >>> from brainstate.typing import Array
348
+ >>>
349
+ >>> # Any array
350
+ >>> def process_array(x: Array) -> Array:
351
+ ... return x * 2
352
+ >>>
353
+ >>> # Array with specific shape annotation
354
+ >>> def matrix_multiply(a: Array["m, n"], b: Array["n, k"]) -> Array["m, k"]:
355
+ ... return a @ b
356
+ >>>
357
+ >>> # Array with named batch and feature dimensions
358
+ >>> def normalize_weights(weights: Array["batch, features"]) -> Array["batch, features"]:
359
+ ... return weights / weights.sum(axis=-1, keepdims=True)
360
+
361
+ Advanced shape annotations:
362
+
363
+ .. code-block:: python
364
+
365
+ >>> # Using ellipsis for flexible dimensions
366
+ >>> def flatten_batch(x: Array["batch, ..."]) -> Array["batch, -1"]:
367
+ ... return x.reshape(x.shape[0], -1)
368
+ >>>
369
+ >>> # Multiple shape constraints
370
+ >>> def attention(
371
+ ... query: Array["batch, seq_len, d_model"],
372
+ ... key: Array["batch, seq_len, d_model"],
373
+ ... value: Array["batch, seq_len, d_model"]
374
+ ... ) -> Array["batch, seq_len, d_model"]:
375
+ ... # Attention computation
376
+ ... pass
377
+ """
378
+
379
+ def __class_getitem__(cls, item):
380
+ """Create a specialized Array type with shape/dtype annotations.
381
+
382
+ Parameters
383
+ ----------
384
+ item : str, type, slice, or tuple
385
+ Shape specification, dtype, or combination thereof.
386
+
387
+ Returns
388
+ -------
389
+ _Array
390
+ Specialized array type with the given annotation.
391
+ """
392
+
393
+ class X:
394
+ pass
395
+
396
+ X.__module__ = "builtins"
397
+ X.__qualname__ = _maybe_tuple_to_str(item)
398
+ return _Array[X]
399
+
400
+
401
+ # Set module for proper display in type hints
402
+ Array.__module__ = "builtins"
403
+
404
+
405
+ # ============================================================================
406
+ # PyTree Types
407
+ # ============================================================================
408
+
409
+ class _FakePyTree(Generic[_T]):
410
+ """Internal generic PyTree type for creating specialized PyTree annotations."""
411
+ pass
412
+
413
+
414
+ _FakePyTree.__name__ = "PyTree"
415
+ _FakePyTree.__qualname__ = "PyTree"
416
+ _FakePyTree.__module__ = "builtins"
417
+
418
+
419
+ class _MetaPyTree(type):
420
+ """Metaclass for PyTree type that prevents instantiation and handles subscripting."""
421
+
422
+ def __call__(self, *args, **kwargs):
423
+ """Prevent direct instantiation of PyTree type.
424
+
425
+ Raises
426
+ ------
427
+ RuntimeError
428
+ Always raised since PyTree is a type annotation only.
429
+ """
430
+ raise RuntimeError("PyTree cannot be instantiated")
431
+
432
+ # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
433
+ # the custom __instancecheck__ that we want.
434
+ # We can't add that __instancecheck__ via subclassing, e.g.
435
+ # type("PyTree", (Generic[_T],), {}), because dynamic subclassing of typeforms
436
+ # isn't allowed.
437
+ # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
438
+ # has __module__ "types", e.g. we get types.PyTree[int].
439
+ @functools.lru_cache(maxsize=None)
440
+ def __getitem__(cls, item):
441
+ if isinstance(item, tuple):
442
+ if len(item) == 2:
443
+
444
+ class X(PyTree):
445
+ leaftype = item[0]
446
+ structure = item[1].strip()
447
+
448
+ if not isinstance(X.structure, str):
449
+ raise ValueError(
450
+ "The structure annotation `struct` in "
451
+ "`brainstate.typing.PyTree[leaftype, struct]` must be be a string, "
452
+ f"e.g. `brainstate.typing.PyTree[leaftype, 'T']`. Got '{X.structure}'."
453
+ )
454
+ pieces = X.structure.split()
455
+ if len(pieces) == 0:
456
+ raise ValueError(
457
+ "The string `struct` in `brainstate.typing.PyTree[leaftype, struct]` "
458
+ "cannot be the empty string."
459
+ )
460
+ for piece_index, piece in enumerate(pieces):
461
+ if (piece_index == 0) or (piece_index == len(pieces) - 1):
462
+ if piece == "...":
463
+ continue
464
+ if not piece.isidentifier():
465
+ raise ValueError(
466
+ "The string `struct` in "
467
+ "`brainstate.typing.PyTree[leaftype, struct]` must be be a "
468
+ "whitespace-separated sequence of identifiers, e.g. "
469
+ "`brainstate.typing.PyTree[leaftype, 'T']` or "
470
+ "`brainstate.typing.PyTree[leaftype, 'foo bar']`.\n"
471
+ "(Here, 'identifier' is used in the same sense as in "
472
+ "regular Python, i.e. a valid variable name.)\n"
473
+ f"Got piece '{piece}' in overall structure '{X.structure}'."
474
+ )
475
+ name = str(_FakePyTree[item[0]])[:-1] + ', "' + item[1].strip() + '"]'
476
+ else:
477
+ raise ValueError(
478
+ "The subscript `foo` in `brainstate.typing.PyTree[foo]` must either be a "
479
+ "leaf type, e.g. `PyTree[int]`, or a 2-tuple of leaf and "
480
+ "structure, e.g. `PyTree[int, 'T']`. Received a tuple of length "
481
+ f"{len(item)}."
482
+ )
483
+ else:
484
+ name = str(_FakePyTree[item])
485
+
486
+ class X(PyTree):
487
+ leaftype = item
488
+ structure = None
489
+
490
+ X.__name__ = name
491
+ X.__qualname__ = name
492
+ if getattr(tp, "GENERATING_DOCUMENTATION", False):
493
+ X.__module__ = "builtins"
494
+ else:
495
+ X.__module__ = "brainstate.typing"
496
+ return X
497
+
498
+
499
+ # Can't do `class PyTree(Generic[_T]): ...` because we need to override the
500
+ # instancecheck for PyTree[foo], but subclassing
501
+ # `type(Generic[int])`, i.e. `typing._GenericAlias` is disallowed.
502
+ PyTree = _MetaPyTree("PyTree", (), {})
503
+ if getattr(tp, "GENERATING_DOCUMENTATION", False):
504
+ PyTree.__module__ = "builtins"
505
+ else:
506
+ PyTree.__module__ = "brainstate.typing"
507
+ PyTree.__doc__ = """Represents a PyTree.
508
+
509
+ Annotations of the following sorts are supported:
510
+
511
+ .. code-block:: python
512
+
513
+ >>> a: PyTree
514
+ >>> b: PyTree[LeafType]
515
+ >>> c: PyTree[LeafType, "T"]
516
+ >>> d: PyTree[LeafType, "S T"]
517
+ >>> e: PyTree[LeafType, "... T"]
518
+ >>> f: PyTree[LeafType, "T ..."]
519
+
520
+ These correspond to:
521
+
522
+ a. A plain `PyTree` can be used as an annotation, in which case `PyTree` is simply a
523
+ suggestively-named alternative to `Any`.
524
+ ([By definition all types are PyTrees.](https://jax.readthedocs.io/en/latest/pytrees.html))
525
+
526
+ b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
527
+ example, `PyTree[int]` or `PyTree[Union[str, Float32[Array, "b c"]]]`.
528
+
529
+ c. A structure name can also be passed. In this case
530
+ `jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
531
+ This can be used to mark that multiple PyTrees all have the same structure:
532
+
533
+ .. code-block:: python
534
+
535
+ >>> def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
536
+ ... ...
537
+
538
+ d. A composite structure can be declared. In this case the variable must have a PyTree
539
+ structure equal to the composition of multiple previously-bound PyTree structures.
540
+ For example:
541
+
542
+ .. code-block:: python
543
+
544
+ >>> def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
545
+ ... ...
546
+ >>>
547
+ >>> x = (1, 2)
548
+ >>> y = {"key": 3}
549
+ >>> z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `x`
550
+ >>> f(x, y, z)
551
+
552
+ When performing runtime type-checking, all the individual pieces must have already
553
+ been bound to structures, otherwise the composite structure check will throw an error.
554
+
555
+ e. A structure can begin with a `...`, to denote that the lower levels of the PyTree
556
+ must match the declared structure, but the upper levels can be arbitrary. As in the
557
+ previous case, all named pieces must already have been seen and their structures
558
+ bound.
559
+
560
+ f. A structure can end with a `...`, to denote that the PyTree must be a prefix of the
561
+ declared structure, but the lower levels can be arbitrary. As in the previous two
562
+ cases, all named pieces must already have been seen and their structures bound.
563
+ """ # noqa: E501
564
+
565
+ # ============================================================================
566
+ # Shape and Size Types
567
+ # ============================================================================
568
+
569
+ Size = Union[int, Sequence[int], np.integer, Sequence[np.integer]]
570
+ """Type for specifying array sizes and dimensions.
571
+
572
+ Can be a single integer for 1D sizes, or a sequence of integers for multi-dimensional shapes.
573
+ Supports both Python integers and NumPy integer types for compatibility.
574
+
575
+ Examples
576
+ --------
577
+ .. code-block:: python
578
+
579
+ >>> # Single dimension
580
+ >>> size1: Size = 10
581
+ >>>
582
+ >>> # Multiple dimensions
583
+ >>> size2: Size = (3, 4, 5)
584
+ >>>
585
+ >>> # Using NumPy integers
586
+ >>> size3: Size = np.int32(8)
587
+ >>>
588
+ >>> # Mixed sequence
589
+ >>> size4: Size = [np.int64(2), 3, np.int32(4)]
590
+ """
591
+
592
+ Shape = Sequence[int]
593
+ """Type for array shapes as sequences of integers.
594
+
595
+ Represents the shape of an array as a sequence of dimension sizes.
596
+ More restrictive than Size as it requires a sequence.
597
+
598
+ Examples
599
+ --------
600
+ .. code-block:: python
601
+
602
+ >>> # 2D array shape
603
+ >>> matrix_shape: Shape = (10, 20)
604
+ >>>
605
+ >>> # 3D array shape
606
+ >>> tensor_shape: Shape = (5, 10, 15)
607
+ >>>
608
+ >>> # 1D array shape (note: still needs to be a sequence)
609
+ >>> vector_shape: Shape = (100,)
610
+ """
611
+
612
+ Axes = Union[int, Sequence[int]]
613
+ """Type for specifying axes along which operations should be performed.
614
+
615
+ Can be a single axis (integer) or multiple axes (sequence of integers).
616
+ Used in reduction operations, reshaping, and other array manipulations.
617
+
618
+ Examples
619
+ --------
620
+ .. code-block:: python
621
+
622
+ >>> # Single axis
623
+ >>> axis1: Axes = 0
624
+ >>>
625
+ >>> # Multiple axes
626
+ >>> axis2: Axes = (0, 2)
627
+ >>>
628
+ >>> # All axes for global operations
629
+ >>> axis3: Axes = tuple(range(ndim))
630
+ >>>
631
+ >>> def sum_along_axes(array: ArrayLike, axes: Axes) -> ArrayLike:
632
+ ... return jnp.sum(array, axis=axes)
633
+ """
634
+
635
+ # ============================================================================
636
+ # Array Types
637
+ # ============================================================================
638
+
639
+ ArrayLike = Union[
640
+ jax.Array, # JAX array type
641
+ np.ndarray, # NumPy array type
642
+ np.bool_, np.number, # NumPy scalar types
643
+ bool, int, float, complex, # Python scalar types
644
+ u.Quantity, # BrainUnit quantity type
645
+ ]
646
+ """Union of all objects that can be implicitly converted to a JAX array.
647
+
648
+ This type is designed for JAX compatibility and excludes arbitrary sequences
649
+ and string data that numpy.typing.ArrayLike would include. It represents
650
+ data that can be safely converted to arrays without ambiguity.
651
+
652
+ Components
653
+ ----------
654
+ jax.Array
655
+ Native JAX arrays.
656
+ np.ndarray
657
+ NumPy arrays that can be converted to JAX arrays.
658
+ np.bool_, np.number
659
+ NumPy scalar types (bool, int8, float32, etc.).
660
+ bool, int, float, complex
661
+ Python built-in scalar types.
662
+ u.Quantity
663
+ BrainUnit quantities with physical units.
664
+
665
+ Examples
666
+ --------
667
+ .. code-block:: python
668
+
669
+ >>> def process_data(data: ArrayLike) -> jax.Array:
670
+ ... '''Convert input to JAX array and process it.'''
671
+ ... array = jnp.asarray(data)
672
+ ... return array * 2
673
+ >>>
674
+ >>> # Valid inputs
675
+ >>> process_data(jnp.array([1, 2, 3])) # JAX array
676
+ >>> process_data(np.array([1, 2, 3])) # NumPy array
677
+ >>> process_data([1, 2, 3]) # Python list: converted at runtime by jnp.asarray, but not an ArrayLike per this annotation
678
+ >>> process_data(42) # Python scalar
679
+ >>> process_data(np.float32(3.14)) # NumPy scalar
680
+ >>> process_data(1.5 * u.second) # Quantity with units
681
+ """
682
+
683
+ # ============================================================================
684
+ # Data Type Annotations
685
+ # ============================================================================
686
+
687
+ DType = np.dtype
688
+ """Alias for NumPy's dtype type.
689
+
690
+ Used to represent data types of arrays in a clear and consistent manner.
691
+
692
+ Examples
693
+ --------
694
+ .. code-block:: python
695
+
696
+ >>> def create_array(shape: Shape, dtype: DType) -> jax.Array:
697
+ ... return jnp.zeros(shape, dtype=dtype)
698
+ >>>
699
+ >>> # Usage
700
+ >>> arr = create_array((3, 4), np.float32)
701
+ """
702
+
703
+
704
+ class SupportsDType(Protocol):
705
+ """Protocol for objects that have a dtype property.
706
+
707
+ This protocol defines the interface for any object that exposes
708
+ a dtype attribute, allowing for flexible type checking.
709
+
710
+ Examples
711
+ --------
712
+ .. code-block:: python
713
+
714
+ >>> def get_dtype(obj: SupportsDType) -> DType:
715
+ ... return obj.dtype
716
+ >>>
717
+ >>> # Works with arrays
718
+ >>> arr = jnp.array([1.0, 2.0])
719
+ >>> dtype = get_dtype(arr) # float32
720
+ """
721
+
722
+ @property
723
+ def dtype(self) -> DType:
724
+ """Return the data type of the object.
725
+
726
+ Returns
727
+ -------
728
+ DType
729
+ The NumPy dtype of the object.
730
+ """
731
+ ...
732
+
733
+
734
+ DTypeLike = Union[
735
+ str, # String representations like 'float32', 'int32'
736
+ type[Any], # Type objects like np.float32, np.int32, float, int
737
+ np.dtype, # NumPy dtype objects
738
+ SupportsDType, # Objects with a dtype property
739
+ ]
740
+ """Union of types that can be converted to a valid JAX dtype.
741
+
742
+ This is more restrictive than numpy.typing.DTypeLike as JAX doesn't support
743
+ object arrays or structured dtypes. It excludes None to require explicit
744
+ handling of optional dtypes.
745
+
746
+ Components
747
+ ----------
748
+ str
749
+ String representations like 'float32', 'int32', 'bool'.
750
+ type[Any]
751
+ Type objects like np.float32, float, int, bool.
752
+ np.dtype
753
+ NumPy dtype objects created with np.dtype().
754
+ SupportsDType
755
+ Any object with a .dtype property.
756
+
757
+ Examples
758
+ --------
759
+ .. code-block:: python
760
+
761
+ >>> def cast_array(array: ArrayLike, dtype: DTypeLike) -> jax.Array:
762
+ ... '''Cast array to specified dtype.'''
763
+ ... return jnp.asarray(array, dtype=dtype)
764
+ >>>
765
+ >>> # Valid dtype specifications
766
+ >>> cast_array(data, 'float32') # String
767
+ >>> cast_array(data, np.float32) # NumPy type
768
+ >>> cast_array(data, float) # Python type
769
+ >>> cast_array(data, np.dtype('int32')) # NumPy dtype object
770
+ >>> cast_array(data, other_array) # Object with dtype property
771
+ """
772
+
773
+ # ============================================================================
774
+ # Random Number Generation
775
+ # ============================================================================
776
+
777
+ SeedOrKey = Union[int, jax.Array, np.ndarray]
778
+ """Type for random number generator seeds or keys.
779
+
780
+ Represents values that can be used to seed random number generators
781
+ or serve as PRNG keys in JAX's random number generation system.
782
+
783
+ Components
784
+ ----------
785
+ int
786
+ Integer seeds for random number generators.
787
+ jax.Array
788
+ JAX PRNG keys (typically created with jax.random.PRNGKey).
789
+ np.ndarray
790
+ NumPy arrays that can serve as random keys.
791
+
792
+ Examples
793
+ --------
794
+ .. code-block:: python
795
+
796
+ >>> def generate_random(key: SeedOrKey, shape: Shape) -> jax.Array:
797
+ ... '''Generate random numbers using the provided seed or key.'''
798
+ ... if isinstance(key, int):
799
+ ... key = jax.random.PRNGKey(key)
800
+ ... return jax.random.normal(key, shape)
801
+ >>>
802
+ >>> # Valid seeds/keys
803
+ >>> generate_random(42, (3, 4)) # Integer seed
804
+ >>> generate_random(jax.random.PRNGKey(123), (5,)) # JAX PRNG key
805
+ >>> generate_random(np.array([1, 2], dtype=np.uint32), (2, 2)) # NumPy array
806
+ """
807
+
808
+
809
+ # ============================================================================
810
+ # Utility Types
811
+ # ============================================================================
812
+
813
+ class Missing:
814
+ """Sentinel class to represent missing or unspecified values.
815
+
816
+ This class is used as a default value when None has semantic meaning
817
+ and you need to distinguish between "None was passed" and "nothing was passed".
818
+
819
+ Examples
820
+ --------
821
+ .. code-block:: python
822
+
823
+ >>> _MISSING = Missing()
824
+ >>>
825
+ >>> def function_with_optional_param(value: Union[int, None, Missing] = _MISSING):
826
+ ... if value is _MISSING:
827
+ ... print("No value provided")
828
+ ... elif value is None:
829
+ ... print("None was explicitly provided")
830
+ ... else:
831
+ ... print(f"Value: {value}")
832
+ >>>
833
+ >>> function_with_optional_param() # "No value provided"
834
+ >>> function_with_optional_param(None) # "None was explicitly provided"
835
+ >>> function_with_optional_param(42) # "Value: 42"
836
+ """
837
+ pass