brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/typing.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -13,8 +13,40 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ """
17
+ Comprehensive type annotations for BrainState.
18
+
19
+ This module provides a collection of type aliases, protocols, and generic types
20
+ specifically designed for scientific computing, neural network modeling, and
21
+ array operations within the BrainState ecosystem.
22
+
23
+ The type system is designed to be compatible with JAX, NumPy, and BrainUnit,
24
+ providing comprehensive type hints for arrays, shapes, seeds, and PyTree structures.
25
+
26
+ Examples
27
+ --------
28
+ Basic usage with array types:
29
+
30
+ .. code-block:: python
31
+
32
+ >>> import brainstate
33
+ >>> from brainstate.typing import ArrayLike, Shape, DTypeLike
34
+ >>>
35
+ >>> def process_array(data: ArrayLike, shape: Shape, dtype: DTypeLike) -> brainstate.Array:
36
+ ... return brainstate.asarray(data, dtype=dtype).reshape(shape)
37
+
38
+ Using PyTree annotations:
39
+
40
+ .. code-block:: python
41
+
42
+ >>> from brainstate.typing import PyTree
43
+ >>>
44
+ >>> def tree_function(tree: PyTree[float, "T"]) -> PyTree[float, "T"]:
45
+ ... return brainstate.tree_map(lambda x: x * 2, tree)
46
+ """
47
+
16
48
  import builtins
17
- import functools as ft
49
+ import functools
18
50
  import importlib
19
51
  import inspect
20
52
  from typing import (
@@ -29,42 +61,215 @@ import numpy as np
29
61
  tp = importlib.import_module("typing")
30
62
 
31
63
  __all__ = [
64
+ # Path and filter types
32
65
  'PathParts',
33
66
  'Predicate',
34
67
  'Filter',
35
- 'PyTree',
36
- 'Size',
68
+ 'FilterLiteral',
69
+
70
+ # Array and shape types
71
+ 'Array',
72
+ 'ArrayLike',
37
73
  'Shape',
74
+ 'Size',
38
75
  'Axes',
39
- 'SeedOrKey',
40
- 'ArrayLike',
41
76
  'DType',
42
77
  'DTypeLike',
78
+ 'SupportsDType',
79
+
80
+ # PyTree types
81
+ 'PyTree',
82
+
83
+ # Random number generation
84
+ 'SeedOrKey',
85
+
86
+ # Utility types
87
+ 'Key',
43
88
  'Missing',
89
+
90
+ # Type variables
91
+ 'K',
92
+ '_T',
93
+ '_Annotation',
44
94
  ]
45
95
 
46
- K = TypeVar('K')
96
+ # ============================================================================
97
+ # Type Variables
98
+ # ============================================================================
99
+
100
+ K = TypeVar('K', bound='Key')
101
+ """Type variable for keys that must be comparable and hashable."""
102
+
103
+ _T = TypeVar("_T")
104
+ """Generic type variable for any type."""
105
+
106
+ _Annotation = TypeVar("_Annotation")
107
+ """Type variable for array annotations."""
108
+
47
109
 
110
+ # ============================================================================
111
+ # Key and Path Types
112
+ # ============================================================================
48
113
 
49
114
  @runtime_checkable
50
115
  class Key(Hashable, Protocol):
116
+ """Protocol for keys that can be used in PyTree paths.
117
+
118
+ A Key must be both hashable and comparable, making it suitable
119
+ for use as dictionary keys and for ordering operations.
120
+
121
+ Examples
122
+ --------
123
+ Valid key types include:
124
+
125
+ .. code-block:: python
126
+
127
+ >>> # String keys
128
+ >>> key1: Key = "layer1"
129
+ >>>
130
+ >>> # Integer keys
131
+ >>> key2: Key = 42
132
+ >>>
133
+ >>> # Custom hashable objects
134
+ >>> class CustomKey:
135
+ ... def __init__(self, name: str):
136
+ ... self.name = name
137
+ ...
138
+ ... def __hash__(self) -> int:
139
+ ... return hash(self.name)
140
+ ...
141
+ ... def __eq__(self, other) -> bool:
142
+ ... return isinstance(other, CustomKey) and self.name == other.name
143
+ ...
144
+ ... def __lt__(self, other) -> bool:
145
+ ... return isinstance(other, CustomKey) and self.name < other.name
146
+ """
147
+
51
148
  def __lt__(self: K, value: K, /) -> bool:
149
+ """Less than comparison for ordering keys.
150
+
151
+ Parameters
152
+ ----------
153
+ value : Key
154
+ The key to compare against.
155
+
156
+ Returns
157
+ -------
158
+ bool
159
+ True if this key is less than the other key.
160
+ """
52
161
  ...
53
162
 
54
163
 
55
164
  Ellipsis = builtins.ellipsis if TYPE_CHECKING else Any
165
+ """Type alias for ellipsis, used in filter expressions."""
56
166
 
57
167
  PathParts = Tuple[Key, ...]
58
- Predicate = Callable[[PathParts, Any], bool]
59
- FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
60
- Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
168
+ """Tuple of keys representing a path through a PyTree structure.
61
169
 
62
- _T = TypeVar("_T")
170
+ Examples
171
+ --------
172
+ .. code-block:: python
63
173
 
64
- _Annotation = TypeVar("_Annotation")
174
+ >>> # Path to a nested value in a PyTree
175
+ >>> path: PathParts = ("model", "layers", 0, "weights")
176
+ >>>
177
+ >>> # Empty path representing the root
178
+ >>> root_path: PathParts = ()
179
+ """
65
180
 
181
+ Predicate = Callable[[PathParts, Any], bool]
182
+ """Function that takes a path and value, returning whether it matches some condition.
183
+
184
+ Parameters
185
+ ----------
186
+ path : PathParts
187
+ The path to the value in the PyTree.
188
+ value : Any
189
+ The value at that path.
190
+
191
+ Returns
192
+ -------
193
+ bool
194
+ True if the path/value combination matches the predicate.
195
+
196
+ Examples
197
+ --------
198
+ .. code-block:: python
199
+
200
+ >>> def is_weight_matrix(path: PathParts, value: Any) -> bool:
201
+ ... '''Check if a value is a weight matrix (2D array).'''
202
+ ... return len(path) > 0 and "weight" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 2
203
+ >>>
204
+ >>> def is_bias_vector(path: PathParts, value: Any) -> bool:
205
+ ... '''Check if a value is a bias vector (1D array).'''
206
+ ... return len(path) > 0 and "bias" in str(path[-1]) and hasattr(value, 'ndim') and value.ndim == 1
207
+ """
208
+
209
+ FilterLiteral = Union[type, str, Predicate, bool, Ellipsis, None]
210
+ """Basic filter types that can be used to select parts of a PyTree.
211
+
212
+ Components
213
+ ----------
214
+ type
215
+ Filter by type, e.g., `float`, `jax.Array`.
216
+ str
217
+ Filter by string matching in path keys.
218
+ Predicate
219
+ Custom function for complex filtering logic.
220
+ bool
221
+ Simple True/False filter.
222
+ Ellipsis
223
+ Wildcard filter that matches anything.
224
+ None
225
+ Filter that matches None values.
226
+
227
+ Examples
228
+ --------
229
+ .. code-block:: python
230
+
231
+ >>> # Filter by type
232
+ >>> float_filter: FilterLiteral = float
233
+ >>>
234
+ >>> # Filter by string pattern
235
+ >>> weight_filter: FilterLiteral = "weight"
236
+ >>>
237
+ >>> # Custom predicate filter
238
+ >>> matrix_filter: FilterLiteral = lambda path, x: hasattr(x, 'ndim') and x.ndim == 2
239
+ """
240
+
241
+ Filter = Union[FilterLiteral, Tuple['Filter', ...], List['Filter']]
242
+ """Flexible filter type that can be a single filter or combination of filters.
243
+
244
+ This allows for complex filtering patterns by combining multiple filter criteria.
245
+
246
+ Examples
247
+ --------
248
+ .. code-block:: python
249
+
250
+ >>> # Single filter
251
+ >>> simple_filter: Filter = "weight"
252
+ >>>
253
+ >>> # Tuple of filters (all must match)
254
+ >>> combined_filter: Filter = (float, "weight")
255
+ >>>
256
+ >>> # List of filters (any can match)
257
+ >>> alternative_filter: Filter = [int, float, "bias"]
258
+ >>>
259
+ >>> # Nested combinations
260
+ >>> complex_filter: Filter = [
261
+ ... ("weight", lambda p, x: x.ndim == 2), # 2D weight matrices
262
+ ... ("bias", lambda p, x: x.ndim == 1), # 1D bias vectors
263
+ ... ]
264
+ """
265
+
266
+
267
+ # ============================================================================
268
+ # Array Annotation Types
269
+ # ============================================================================
66
270
 
67
271
  class _Array(Generic[_Annotation]):
272
+ """Internal generic array type for creating custom array annotations."""
68
273
  pass
69
274
 
70
275
 
@@ -72,9 +277,26 @@ _Array.__module__ = "builtins"
72
277
 
73
278
 
74
279
  def _item_to_str(item: Union[str, type, slice]) -> str:
280
+ """Convert an array annotation item to its string representation.
281
+
282
+ Parameters
283
+ ----------
284
+ item : Union[str, type, slice]
285
+ The item to convert to string.
286
+
287
+ Returns
288
+ -------
289
+ str
290
+ String representation of the item.
291
+
292
+ Raises
293
+ ------
294
+ NotImplementedError
295
+ If slice has a step component.
296
+ """
75
297
  if isinstance(item, slice):
76
298
  if item.step is not None:
77
- raise NotImplementedError
299
+ raise NotImplementedError("Slice steps are not supported in array annotations")
78
300
  return _item_to_str(item.start) + ": " + _item_to_str(item.stop)
79
301
  elif item is ...:
80
302
  return "..."
@@ -87,19 +309,87 @@ def _item_to_str(item: Union[str, type, slice]) -> str:
87
309
  def _maybe_tuple_to_str(
88
310
  item: Union[str, type, slice, Tuple[Union[str, type, slice], ...]]
89
311
  ) -> str:
312
+ """Convert array annotation items (potentially in tuple) to string representation.
313
+
314
+ Parameters
315
+ ----------
316
+ item : Union[str, type, slice, Tuple[...]]
317
+ Single item or tuple of items to convert.
318
+
319
+ Returns
320
+ -------
321
+ str
322
+ String representation of the item(s).
323
+ """
90
324
  if isinstance(item, tuple):
91
325
  if len(item) == 0:
92
- # Explicit brackets
326
+ # Explicit brackets for empty tuple
93
327
  return "()"
94
328
  else:
95
- # No brackets
329
+ # No brackets for non-empty tuple
96
330
  return ", ".join([_item_to_str(i) for i in item])
97
331
  else:
98
332
  return _item_to_str(item)
99
333
 
100
334
 
101
335
  class Array:
336
+ """Flexible array type annotation supporting shape and dtype specifications.
337
+
338
+ This class provides a convenient way to annotate arrays with shape information,
339
+ making code more self-documenting and enabling better static analysis.
340
+
341
+ Examples
342
+ --------
343
+ Basic array annotations:
344
+
345
+ .. code-block:: python
346
+
347
+ >>> from brainstate.typing import Array
348
+ >>>
349
+ >>> # Any array
350
+ >>> def process_array(x: Array) -> Array:
351
+ ... return x * 2
352
+ >>>
353
+ >>> # Array with specific shape annotation
354
+ >>> def matrix_multiply(a: Array["m, n"], b: Array["n, k"]) -> Array["m, k"]:
355
+ ... return a @ b
356
+ >>>
357
+ >>> # Array with named batch and feature dimensions
358
+ >>> def normalize_weights(weights: Array["batch, features"]) -> Array["batch, features"]:
359
+ ... return weights / weights.sum(axis=-1, keepdims=True)
360
+
361
+ Advanced shape annotations:
362
+
363
+ .. code-block:: python
364
+
365
+ >>> # Using ellipsis for flexible dimensions
366
+ >>> def flatten_batch(x: Array["batch, ..."]) -> Array["batch, -1"]:
367
+ ... return x.reshape(x.shape[0], -1)
368
+ >>>
369
+ >>> # Multiple shape constraints
370
+ >>> def attention(
371
+ ... query: Array["batch, seq_len, d_model"],
372
+ ... key: Array["batch, seq_len, d_model"],
373
+ ... value: Array["batch, seq_len, d_model"]
374
+ ... ) -> Array["batch, seq_len, d_model"]:
375
+ ... # Attention computation
376
+ ... pass
377
+ """
378
+
102
379
  def __class_getitem__(cls, item):
380
+ """Create a specialized Array type with shape/dtype annotations.
381
+
382
+ Parameters
383
+ ----------
384
+ item : str, type, slice, or tuple
385
+ Shape specification, dtype, or combination thereof.
386
+
387
+ Returns
388
+ -------
389
+ _Array
390
+ Specialized array type with the given annotation.
391
+ """
392
+
103
393
  class X:
104
394
  pass
105
395
 
@@ -108,14 +398,16 @@ class Array:
108
398
  return _Array[X]
109
399
 
110
400
 
111
- # Same __module__ trick here again. (So that we get the correct display when
112
- # doing `def f(x: Array)` as well as `def f(x: Array["dim"])`.
113
- #
114
- # Don't need to set __qualname__ as that's already correct.
401
+ # Set module for proper display in type hints
115
402
  Array.__module__ = "builtins"
116
403
 
117
404
 
405
+ # ============================================================================
406
+ # PyTree Types
407
+ # ============================================================================
408
+
118
409
  class _FakePyTree(Generic[_T]):
410
+ """Internal generic PyTree type for creating specialized PyTree annotations."""
119
411
  pass
120
412
 
121
413
 
@@ -125,7 +417,16 @@ _FakePyTree.__module__ = "builtins"
125
417
 
126
418
 
127
419
  class _MetaPyTree(type):
420
+ """Metaclass for PyTree type that prevents instantiation and handles subscripting."""
421
+
128
422
  def __call__(self, *args, **kwargs):
423
+ """Prevent direct instantiation of PyTree type.
424
+
425
+ Raises
426
+ ------
427
+ RuntimeError
428
+ Always raised since PyTree is a type annotation only.
429
+ """
129
430
  raise RuntimeError("PyTree cannot be instantiated")
130
431
 
131
432
  # Can't return a generic (e.g. _FakePyTree[item]) because generic aliases don't do
@@ -135,7 +436,7 @@ class _MetaPyTree(type):
135
436
  # isn't allowed.
136
437
  # Likewise we can't do types.new_class("PyTree", (Generic[_T],), {}) because that
137
438
  # has __module__ "types", e.g. we get types.PyTree[int].
138
- @ft.lru_cache(maxsize=None)
439
+ @functools.lru_cache(maxsize=None)
139
440
  def __getitem__(cls, item):
140
441
  if isinstance(item, tuple):
141
442
  if len(item) == 2:
@@ -206,14 +507,15 @@ else:
206
507
  PyTree.__doc__ = """Represents a PyTree.
207
508
 
208
509
  Annotations of the following sorts are supported:
209
- ```python
210
- a: PyTree
211
- b: PyTree[LeafType]
212
- c: PyTree[LeafType, "T"]
213
- d: PyTree[LeafType, "S T"]
214
- e: PyTree[LeafType, "... T"]
215
- f: PyTree[LeafType, "T ..."]
216
- ```
510
+
511
+ .. code-block:: python
512
+
513
+ >>> a: PyTree
514
+ >>> b: PyTree[LeafType]
515
+ >>> c: PyTree[LeafType, "T"]
516
+ >>> d: PyTree[LeafType, "S T"]
517
+ >>> e: PyTree[LeafType, "... T"]
518
+ >>> f: PyTree[LeafType, "T ..."]
217
519
 
218
520
  These correspond to:
219
521
 
@@ -227,23 +529,26 @@ b. `PyTree[LeafType]` denotes a PyTree all of whose leaves match `LeafType`. For
227
529
  c. A structure name can also be passed. In this case
228
530
  `jax.tree_util.tree_structure(...)` will be called, and bound to the structure name.
229
531
  This can be used to mark that multiple PyTrees all have the same structure:
230
- ```python
231
- def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
232
- ...
233
- ```
532
+
533
+ .. code-block:: python
534
+
535
+ >>> def f(x: PyTree[int, "T"], y: PyTree[int, "T"]):
536
+ ... ...
234
537
 
235
538
  d. A composite structure can be declared. In this case the variable must have a PyTree
236
539
  structure each to the composition of multiple previously-bound PyTree structures.
237
540
  For example:
238
- ```python
239
- def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
240
- ...
241
-
242
- x = (1, 2)
243
- y = {"key": 3}
244
- z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `z`
245
- f(x, y, z)
246
- ```
541
+
542
+ .. code-block:: python
543
+
544
+ >>> def f(x: PyTree[int, "T"], y: PyTree[int, "S"], z: PyTree[int, "S T"]):
545
+ ... ...
546
+ >>>
547
+ >>> x = (1, 2)
548
+ >>> y = {"key": 3}
549
+ >>> z = {"key": (4, 5)} # structure is the composition of the structures of `y` and `x`
550
+ >>> f(x, y, z)
551
+
247
552
  When performing runtime type-checking, all the individual pieces must have already
248
553
  been bound to structures, otherwise the composite structure check will throw an error.
249
554
 
@@ -257,48 +562,276 @@ f. A structure can end with a `...`, to denote that the PyTree must be a prefix
257
562
  cases, all named pieces must already have been seen and their structures bound.
258
563
  """ # noqa: E501
259
564
 
565
+ # ============================================================================
566
+ # Shape and Size Types
567
+ # ============================================================================
568
+
260
569
  Size = Union[int, Sequence[int], np.integer, Sequence[np.integer]]
261
- Axes = Union[int, Sequence[int]]
262
- SeedOrKey = Union[int, jax.Array, np.ndarray]
570
+ """Type for specifying array sizes and dimensions.
571
+
572
+ Can be a single integer for 1D sizes, or a sequence of integers for multi-dimensional shapes.
573
+ Supports both Python integers and NumPy integer types for compatibility.
574
+
575
+ Examples
576
+ --------
577
+ .. code-block:: python
578
+
579
+ >>> # Single dimension
580
+ >>> size1: Size = 10
581
+ >>>
582
+ >>> # Multiple dimensions
583
+ >>> size2: Size = (3, 4, 5)
584
+ >>>
585
+ >>> # Using NumPy integers
586
+ >>> size3: Size = np.int32(8)
587
+ >>>
588
+ >>> # Mixed sequence
589
+ >>> size4: Size = [np.int64(2), 3, np.int32(4)]
590
+ """
591
+
263
592
  Shape = Sequence[int]
593
+ """Type for array shapes as sequences of integers.
264
594
 
265
- # --- Array --- #
595
+ Represents the shape of an array as a sequence of dimension sizes.
596
+ More restrictive than Size as it requires a sequence.
597
+
598
+ Examples
599
+ --------
600
+ .. code-block:: python
601
+
602
+ >>> # 2D array shape
603
+ >>> matrix_shape: Shape = (10, 20)
604
+ >>>
605
+ >>> # 3D array shape
606
+ >>> tensor_shape: Shape = (5, 10, 15)
607
+ >>>
608
+ >>> # 1D array shape (note: still needs to be a sequence)
609
+ >>> vector_shape: Shape = (100,)
610
+ """
611
+
612
+ Axes = Union[int, Sequence[int]]
613
+ """Type for specifying axes along which operations should be performed.
614
+
615
+ Can be a single axis (integer) or multiple axes (sequence of integers).
616
+ Used in reduction operations, reshaping, and other array manipulations.
617
+
618
+ Examples
619
+ --------
620
+ .. code-block:: python
621
+
622
+ >>> # Single axis
623
+ >>> axis1: Axes = 0
624
+ >>>
625
+ >>> # Multiple axes
626
+ >>> axis2: Axes = (0, 2)
627
+ >>>
628
+ >>> # All axes for global operations
629
+ >>> axis3: Axes = tuple(range(ndim))
630
+ >>>
631
+ >>> def sum_along_axes(array: ArrayLike, axes: Axes) -> ArrayLike:
632
+ ... return jnp.sum(array, axis=axes)
633
+ """
634
+
635
+ # ============================================================================
636
+ # Array Types
637
+ # ============================================================================
266
638
 
267
- # ArrayLike is a Union of all objects that can be implicitly converted to a
268
- # standard JAX array (i.e. not including future non-standard array types like
269
- # KeyArray and BInt). It's different than np.typing.ArrayLike in that it doesn't
270
- # accept arbitrary sequences, nor does it accept string data.
271
639
  ArrayLike = Union[
272
640
  jax.Array, # JAX array type
273
641
  np.ndarray, # NumPy array type
274
642
  np.bool_, np.number, # NumPy scalar types
275
643
  bool, int, float, complex, # Python scalar types
276
- u.Quantity, # Quantity
644
+ u.Quantity, # BrainUnit quantity type
277
645
  ]
646
+ """Union of all objects that can be implicitly converted to a JAX array.
647
+
648
+ This type is designed for JAX compatibility and excludes arbitrary sequences
649
+ and string data that numpy.typing.ArrayLike would include. It represents
650
+ data that can be safely converted to arrays without ambiguity.
651
+
652
+ Components
653
+ ----------
654
+ jax.Array
655
+ Native JAX arrays.
656
+ np.ndarray
657
+ NumPy arrays that can be converted to JAX arrays.
658
+ np.bool_, np.number
659
+ NumPy scalar types (bool, int8, float32, etc.).
660
+ bool, int, float, complex
661
+ Python built-in scalar types.
662
+ u.Quantity
663
+ BrainUnit quantities with physical units.
664
+
665
+ Examples
666
+ --------
667
+ .. code-block:: python
668
+
669
+ >>> def process_data(data: ArrayLike) -> jax.Array:
670
+ ... '''Convert input to JAX array and process it.'''
671
+ ... array = jnp.asarray(data)
672
+ ... return array * 2
673
+ >>>
674
+ >>> # Valid inputs
675
+ >>> process_data(jnp.array([1, 2, 3])) # JAX array
676
+ >>> process_data(np.array([1, 2, 3])) # NumPy array
677
+ >>> process_data([1, 2, 3]) # Python list: accepted by jnp.asarray at runtime, but not part of ArrayLike
678
+ >>> process_data(42) # Python scalar
679
+ >>> process_data(np.float32(3.14)) # NumPy scalar
680
+ >>> process_data(1.5 * u.second) # Quantity with units
681
+ """
682
+
683
+ # ============================================================================
684
+ # Data Type Annotations
685
+ # ============================================================================
278
686
 
279
- # --- Dtype --- #
687
+ DType = np.dtype
688
+ """Alias for NumPy's dtype type.
280
689
 
690
+ Used to represent data types of arrays in a clear and consistent manner.
281
691
 
282
- DType = np.dtype
692
+ Examples
693
+ --------
694
+ .. code-block:: python
695
+
696
+ >>> def create_array(shape: Shape, dtype: DType) -> jax.Array:
697
+ ... return jnp.zeros(shape, dtype=dtype)
698
+ >>>
699
+ >>> # Usage
700
+ >>> arr = create_array((3, 4), np.dtype('float32'))
701
+ """
283
702
 
284
703
 
285
704
  class SupportsDType(Protocol):
705
+ """Protocol for objects that have a dtype property.
706
+
707
+ This protocol defines the interface for any object that exposes
708
+ a dtype attribute, allowing for flexible type checking.
709
+
710
+ Examples
711
+ --------
712
+ .. code-block:: python
713
+
714
+ >>> def get_dtype(obj: SupportsDType) -> DType:
715
+ ... return obj.dtype
716
+ >>>
717
+ >>> # Works with arrays
718
+ >>> arr = jnp.array([1.0, 2.0])
719
+ >>> dtype = get_dtype(arr) # float32
720
+ """
721
+
286
722
  @property
287
- def dtype(self) -> DType: ...
723
+ def dtype(self) -> DType:
724
+ """Return the data type of the object.
725
+
726
+ Returns
727
+ -------
728
+ DType
729
+ The NumPy dtype of the object.
730
+ """
731
+ ...
288
732
 
289
733
 
290
- # DTypeLike is meant to annotate inputs to np.dtype that return
291
- # a valid JAX dtype. It's different than numpy.typing.DTypeLike
292
- # because JAX doesn't support objects or structured dtypes.
293
- # Unlike np.typing.DTypeLike, we exclude None, and instead require
294
- # explicit annotations when None is acceptable.
295
734
  DTypeLike = Union[
296
- str, # like 'float32', 'int32'
297
- type[Any], # like np.float32, np.int32, float, int
298
- np.dtype, # like np.dtype('float32'), np.dtype('int32')
299
- SupportsDType, # like jnp.float32, jnp.int32
735
+ str, # String representations like 'float32', 'int32'
736
+ type[Any], # Type objects like np.float32, np.int32, float, int
737
+ np.dtype, # NumPy dtype objects
738
+ SupportsDType, # Objects with a dtype property
300
739
  ]
740
+ """Union of types that can be converted to a valid JAX dtype.
741
+
742
+ This is more restrictive than numpy.typing.DTypeLike as JAX doesn't support
743
+ object arrays or structured dtypes. It excludes None to require explicit
744
+ handling of optional dtypes.
745
+
746
+ Components
747
+ ----------
748
+ str
749
+ String representations like 'float32', 'int32', 'bool'.
750
+ type[Any]
751
+ Type objects like np.float32, float, int, bool.
752
+ np.dtype
753
+ NumPy dtype objects created with np.dtype().
754
+ SupportsDType
755
+ Any object with a .dtype property.
756
+
757
+ Examples
758
+ --------
759
+ .. code-block:: python
760
+
761
+ >>> def cast_array(array: ArrayLike, dtype: DTypeLike) -> jax.Array:
762
+ ... '''Cast array to specified dtype.'''
763
+ ... return jnp.asarray(array, dtype=dtype)
764
+ >>>
765
+ >>> # Valid dtype specifications
766
+ >>> cast_array(data, 'float32') # String
767
+ >>> cast_array(data, np.float32) # NumPy type
768
+ >>> cast_array(data, float) # Python type
769
+ >>> cast_array(data, np.dtype('int32')) # NumPy dtype object
770
+ >>> cast_array(data, other_array) # Object with dtype property
771
+ """
772
+
773
+ # ============================================================================
774
+ # Random Number Generation
775
+ # ============================================================================
301
776
 
777
+ SeedOrKey = Union[int, jax.Array, np.ndarray]
778
+ """Type for random number generator seeds or keys.
779
+
780
+ Represents values that can be used to seed random number generators
781
+ or serve as PRNG keys in JAX's random number generation system.
782
+
783
+ Components
784
+ ----------
785
+ int
786
+ Integer seeds for random number generators.
787
+ jax.Array
788
+ JAX PRNG keys (typically created with jax.random.PRNGKey).
789
+ np.ndarray
790
+ NumPy arrays that can serve as random keys.
791
+
792
+ Examples
793
+ --------
794
+ .. code-block:: python
795
+
796
+ >>> def generate_random(key: SeedOrKey, shape: Shape) -> jax.Array:
797
+ ... '''Generate random numbers using the provided seed or key.'''
798
+ ... if isinstance(key, int):
799
+ ... key = jax.random.PRNGKey(key)
800
+ ... return jax.random.normal(key, shape)
801
+ >>>
802
+ >>> # Valid seeds/keys
803
+ >>> generate_random(42, (3, 4)) # Integer seed
804
+ >>> generate_random(jax.random.PRNGKey(123), (5,)) # JAX PRNG key
805
+ >>> generate_random(np.array([1, 2], dtype=np.uint32), (2, 2)) # NumPy array
806
+ """
807
+
808
+
809
+ # ============================================================================
810
+ # Utility Types
811
+ # ============================================================================
302
812
 
303
813
  class Missing:
814
+ """Sentinel class to represent missing or unspecified values.
815
+
816
+ This class is used as a default value when None has semantic meaning
817
+ and you need to distinguish between "None was passed" and "nothing was passed".
818
+
819
+ Examples
820
+ --------
821
+ .. code-block:: python
822
+
823
+ >>> _MISSING = Missing()
824
+ >>>
825
+ >>> def function_with_optional_param(value: Union[int, None, Missing] = _MISSING):
826
+ ... if value is _MISSING:
827
+ ... print("No value provided")
828
+ ... elif value is None:
829
+ ... print("None was explicitly provided")
830
+ ... else:
831
+ ... print(f"Value: {value}")
832
+ >>>
833
+ >>> function_with_optional_param() # "No value provided"
834
+ >>> function_with_optional_param(None) # "None was explicitly provided"
835
+ >>> function_with_optional_param(42) # "Value: 42"
836
+ """
304
837
  pass