brainstate-0.1.9-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/mixin.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,35 +15,66 @@
 
 # -*- coding: utf-8 -*-
 
-from typing import (
-    Sequence, Optional, TypeVar, _SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias
-)
+"""
+Mixin classes and utility types for brainstate.
 
-import jax
+This module provides various mixin classes and custom type definitions that
+enhance the functionality of brainstate components. It includes parameter
+description mixins, alignment interfaces, and custom type definitions for
+expressing complex type requirements.
+"""
 
-T = TypeVar('T')
-ArrayLike = jax.typing.ArrayLike
+from typing import Sequence, Optional, TypeVar, Union, _GenericAlias
+
+import jax
 
 __all__ = [
     'Mixin',
     'ParamDesc',
     'ParamDescriber',
-    'AlignPost',
-    'BindCondData',
-
-    # types
     'JointTypes',
     'OneOfTypes',
-
-    # behavior modes
+    '_JointGenericAlias',
+    '_OneOfGenericAlias',
     'Mode',
     'JointMode',
     'Batching',
     'Training',
 ]
 
+T = TypeVar('T')
+ArrayLike = jax.typing.ArrayLike
+
 
 def hashable(x):
+    """
+    Check if an object is hashable.
+
+    Parameters
+    ----------
+    x : Any
+        The object to check for hashability.
+
+    Returns
+    -------
+    bool
+        True if the object is hashable, False otherwise.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> # Hashable objects
+        >>> assert brainstate.mixin.hashable(42) == True
+        >>> assert brainstate.mixin.hashable("string") == True
+        >>> assert brainstate.mixin.hashable((1, 2, 3)) == True
+        >>>
+        >>> # Non-hashable objects
+        >>> assert brainstate.mixin.hashable([1, 2, 3]) == False
+        >>> assert brainstate.mixin.hashable({"key": "value"}) == False
+    """
     try:
         hash(x)
         return True
@@ -52,45 +83,194 @@ def hashable(x):
 
 
 class Mixin(object):
-    """Base Mixin object.
-
-    The key for a :py:class:`~.Mixin` is that: no initialization function, only behavioral functions.
+    """
+    Base Mixin object for behavioral extensions.
+
+    The key characteristic of a :py:class:`~.Mixin` is that it provides only
+    behavioral functions without requiring initialization. Mixins are used to
+    add specific functionality to classes through multiple inheritance without
+    the complexity of a full base class.
+
+    Notes
+    -----
+    Mixins should not define ``__init__`` methods. They should only provide
+    methods that add specific behaviors to the classes that inherit from them.
+
+    Examples
+    --------
+    Creating a custom mixin:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> class LoggingMixin(brainstate.mixin.Mixin):
+        ...     def log(self, message):
+        ...         print(f"[{self.__class__.__name__}] {message}")
+
+        >>> class MyComponent(brainstate.nn.Module, LoggingMixin):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...
+        ...     def process(self):
+        ...         self.log("Processing data...")
+        ...         return "Done"
+        >>>
+        >>> component = MyComponent()
+        >>> component.process()  # Prints: [MyComponent] Processing data...
     """
     pass
 
 
 class ParamDesc(Mixin):
     """
-    :py:class:`~.Mixin` indicates the function for describing initialization parameters.
-
-    This mixin enables the subclass has a classmethod ``desc``, which
-    produces an instance of :py:class:`~.ParamDescriber`.
-
-    Note this Mixin can be applied in any Python object.
+    Mixin for describing initialization parameters.
+
+    This mixin enables a class to have a ``desc`` classmethod, which produces
+    an instance of :py:class:`~.ParamDescriber`. This is useful for creating
+    parameter templates that can be reused to instantiate multiple objects
+    with the same configuration.
+
+    Attributes
+    ----------
+    non_hashable_params : sequence of str, optional
+        Names of parameters that are not hashable and should be handled specially.
+
+    Notes
+    -----
+    This mixin can be applied to any Python class, not just brainstate-specific classes.
+
+    Examples
+    --------
+    Basic usage of ParamDesc:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> class NeuronModel(brainstate.mixin.ParamDesc):
+        ...     def __init__(self, size, tau=10.0, threshold=1.0):
+        ...         self.size = size
+        ...         self.tau = tau
+        ...         self.threshold = threshold
+        >>>
+        >>> # Create a parameter descriptor
+        >>> neuron_desc = NeuronModel.desc(size=100, tau=20.0)
+        >>>
+        >>> # Use the descriptor to create instances
+        >>> neuron1 = neuron_desc(threshold=0.8)  # Creates with threshold=0.8
+        >>> neuron2 = neuron_desc(threshold=1.2)  # Creates with threshold=1.2
+        >>>
+        >>> # Both neurons share size=100, tau=20.0 but have different thresholds
+
+    Creating reusable templates:
+
+    .. code-block:: python
+
+        >>> # Define a template for excitatory neurons
+        >>> exc_neuron_template = NeuronModel.desc(size=1000, tau=10.0, threshold=1.0)
+        >>>
+        >>> # Define a template for inhibitory neurons
+        >>> inh_neuron_template = NeuronModel.desc(size=250, tau=5.0, threshold=0.5)
+        >>>
+        >>> # Create multiple instances from templates
+        >>> exc_population = [exc_neuron_template() for _ in range(5)]
+        >>> inh_population = [inh_neuron_template() for _ in range(2)]
     """
 
+    # Optional list of parameter names that are not hashable
+    # These will be converted to strings for hashing purposes
     non_hashable_params: Optional[Sequence[str]] = None
 
     @classmethod
     def desc(cls, *args, **kwargs) -> 'ParamDescriber':
+        """
+        Create a parameter describer for this class.
+
+        Parameters
+        ----------
+        *args
+            Positional arguments to be used in future instantiations.
+        **kwargs
+            Keyword arguments to be used in future instantiations.
+
+        Returns
+        -------
+        ParamDescriber
+            A descriptor that can be used to create instances with these parameters.
+        """
        return ParamDescriber(cls, *args, **kwargs)
 
 
 class HashableDict(dict):
+    """
+    A dictionary that can be hashed by converting non-hashable values to strings.
+
+    This is used internally to make parameter dictionaries hashable so they can
+    be used as part of cache keys or other contexts requiring hashability.
+
+    Parameters
+    ----------
+    the_dict : dict
+        The dictionary to make hashable.
+
+    Notes
+    -----
+    Non-hashable values in the dictionary are automatically converted to their
+    string representation.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Regular dict with non-hashable values cannot be hashed
+        >>> regular_dict = {"array": jnp.array([1, 2, 3]), "value": 42}
+        >>> # hash(regular_dict)  # This would raise TypeError
+        >>>
+        >>> # HashableDict can be hashed
+        >>> hashable = brainstate.mixin.HashableDict(regular_dict)
+        >>> key = hash(hashable)  # This works!
+        >>>
+        >>> # Can be used in sets or as dict keys
+        >>> cache = {hashable: "result"}
+    """
+
     def __init__(self, the_dict: dict):
+        # Process the dictionary to ensure all values are hashable
         out = dict()
         for k, v in the_dict.items():
             if not hashable(v):
-                v = str(v)  # convert to string if not hashable
+                # Convert non-hashable values to their string representation
+                v = str(v)
             out[k] = v
         super().__init__(out)
 
     def __hash__(self):
+        """
+        Compute hash from sorted items for consistent hashing regardless of insertion order.
+        """
         return hash(tuple(sorted(self.items())))
 
 
 class NoSubclassMeta(type):
+    """
+    Metaclass that prevents a class from being subclassed.
+
+    This is used to ensure that certain classes (like ParamDescriber) are used
+    as-is and not extended through inheritance, which could lead to unexpected
+    behavior.
+
+    Raises
+    ------
+    TypeError
+        If an attempt is made to subclass a class using this metaclass.
+    """
+
     def __new__(cls, name, bases, classdict):
+        # Check if any base class uses NoSubclassMeta
        for b in bases:
            if isinstance(b, NoSubclassMeta):
                raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
@@ -99,209 +279,758 @@ class NoSubclassMeta(type):
 
 class ParamDescriber(metaclass=NoSubclassMeta):
     """
-    ParamDesc initialization for parameter describers.
+    Parameter descriptor for deferred object instantiation.
+
+    This class stores a class reference along with arguments and keyword arguments,
+    allowing for deferred instantiation. It's useful for creating templates that
+    can be reused to create multiple instances with similar configurations.
+
+    Parameters
+    ----------
+    cls : type
+        The class to be instantiated.
+    *desc_tuple
+        Positional arguments to be stored and used during instantiation.
+    **desc_dict
+        Keyword arguments to be stored and used during instantiation.
+
+    Attributes
+    ----------
+    cls : type
+        The class that will be instantiated.
+    args : tuple
+        Stored positional arguments.
+    kwargs : dict
+        Stored keyword arguments.
+    identifier : tuple
+        A hashable identifier for this descriptor.
+
+    Notes
+    -----
+    ParamDescriber cannot be subclassed due to the NoSubclassMeta metaclass.
+    This ensures consistent behavior across the codebase.
+
+    Examples
+    --------
+    Manual creation of a descriptor:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> class Network:
+        ...     def __init__(self, n_neurons, learning_rate=0.01):
+        ...         self.n_neurons = n_neurons
+        ...         self.learning_rate = learning_rate
+        >>>
+        >>> # Create a descriptor
+        >>> network_desc = brainstate.mixin.ParamDescriber(
+        ...     Network, n_neurons=1000, learning_rate=0.001
+        ... )
+        >>>
+        >>> # Use the descriptor to create instances with additional args
+        >>> net1 = network_desc()
+        >>> net2 = network_desc()  # Same configuration
+
+    Using with ParamDesc mixin:
+
+    .. code-block:: python
+
+        >>> class Network(brainstate.mixin.ParamDesc):
+        ...     def __init__(self, n_neurons, learning_rate=0.01):
+        ...         self.n_neurons = n_neurons
+        ...         self.learning_rate = learning_rate
+        >>>
+        >>> # More concise syntax using the desc() classmethod
+        >>> network_desc = Network.desc(n_neurons=1000)
+        >>> net = network_desc(learning_rate=0.005)  # Override learning_rate
     """
 
     def __init__(self, cls: T, *desc_tuple, **desc_dict):
+        # Store the class to be instantiated
         self.cls: type = cls
 
-        # arguments
+        # Store the arguments for later instantiation
         self.args = desc_tuple
         self.kwargs = desc_dict
 
-        # identifier
+        # Create a hashable identifier for caching/comparison purposes
+        # This combines the class, args tuple, and hashable kwargs dict
         self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
 
     def __call__(self, *args, **kwargs) -> T:
-        return self.cls(*self.args, *args, **self.kwargs, **kwargs)
+        """
+        Instantiate the class with stored and additional arguments.
+
+        Parameters
+        ----------
+        *args
+            Additional positional arguments to append.
+        **kwargs
+            Additional keyword arguments to merge (will override stored kwargs).
+
+        Returns
+        -------
+        T
+            An instance of the described class.
+        """
+        # Merge stored arguments with new arguments
+        # Stored args come first, then new args
+        # Merge kwargs with new kwargs overriding stored ones
+        merged_kwargs = {**self.kwargs, **kwargs}
+        return self.cls(*self.args, *args, **merged_kwargs)
 
     def init(self, *args, **kwargs):
+        """
+        Alias for __call__, explicitly named for clarity.
+
+        Parameters
+        ----------
+        *args
+            Additional positional arguments.
+        **kwargs
+            Additional keyword arguments.
+
+        Returns
+        -------
+        T
+            An instance of the described class.
+        """
         return self.__call__(*args, **kwargs)
 
     def __instancecheck__(self, instance):
+        """
+        Check if an instance is compatible with this descriptor.
+
+        Parameters
+        ----------
+        instance : Any
+            The instance to check.
+
+        Returns
+        -------
+        bool
+            True if the instance is a ParamDescriber for a compatible class.
+        """
+        # Must be a ParamDescriber
         if not isinstance(instance, ParamDescriber):
             return False
+        # The described class must be a subclass of our class
         if not issubclass(instance.cls, self.cls):
             return False
         return True
 
     @classmethod
     def __class_getitem__(cls, item: type):
+        """
+        Support for subscript notation: ParamDescriber[MyClass].
+
+        Parameters
+        ----------
+        item : type
+            The class to create a descriptor for.
+
+        Returns
+        -------
+        ParamDescriber
+            A descriptor for the given class.
+        """
         return ParamDescriber(item)
 
     @property
     def identifier(self):
+        """
+        Get the unique identifier for this descriptor.
+
+        Returns
+        -------
+        tuple
+            A hashable identifier consisting of (class, args, kwargs).
+        """
         return self._identifier
 
     @identifier.setter
     def identifier(self, value: ArrayLike):
+        """
+        Prevent modification of the identifier.
+
+        Raises
+        ------
+        AttributeError
+            Always, as the identifier is read-only.
+        """
         raise AttributeError('Cannot set the identifier.')
 
 
-class AlignPost(Mixin):
+def not_implemented(func):
     """
-    Align post MixIn.
-
-    This class provides a ``align_post_input_add()`` function for
-    add external currents.
+    Decorator to mark a function as not implemented.
+
+    This decorator wraps a function to raise NotImplementedError when called,
+    and adds a ``not_implemented`` attribute for checking.
+
+    Parameters
+    ----------
+    func : callable
+        The function to mark as not implemented.
+
+    Returns
+    -------
+    callable
+        A wrapper function that raises NotImplementedError.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> class BaseModel:
+        ...     @brainstate.mixin.not_implemented
+        ...     def process(self, x):
+        ...         pass
+        >>>
+        >>> model = BaseModel()
+        >>> # model.process(10)  # Raises: NotImplementedError: process is not implemented.
+        >>>
+        >>> # Check if a method is not implemented
+        >>> assert hasattr(BaseModel.process, 'not_implemented')
     """
 
-    def align_post_input_add(self, *args, **kwargs):
-        raise NotImplementedError
-
-
-class BindCondData(Mixin):
-    """Bind temporary conductance data.
-
-
-    """
-    _conductance: Optional
-
-    def bind_cond(self, conductance):
-        self._conductance = conductance
-
-    def unbind_cond(self):
-        self._conductance = None
-
-
-def not_implemented(func):
     def wrapper(*args, **kwargs):
         raise NotImplementedError(f'{func.__name__} is not implemented.')
 
+    # Mark the wrapper so we can detect not-implemented methods
     wrapper.not_implemented = True
     return wrapper
 
 
-class _MetaUnionType(type):
-    def __new__(cls, name, bases, dct):
-        if isinstance(bases, type):
-            bases = (bases,)
-        elif isinstance(bases, (list, tuple)):
-            bases = tuple(bases)
-            for base in bases:
-                assert isinstance(base, type), f'Must be type. But got {base}'
-        else:
-            raise TypeError(f'Must be type. But got {bases}')
-        return super().__new__(cls, name, bases, dct)
+class _JointGenericAlias(_GenericAlias, _root=True):
+    """
+    Generic alias for JointTypes (intersection types).
+
+    This class represents a type that requires all specified types to be satisfied.
+    Unlike _MetaUnionType which creates actual classes with metaclass conflicts,
+    this uses typing's generic alias system to avoid metaclass issues.
+    """
 
-    def __instancecheck__(self, other):
-        cls_of_other = other.__class__
-        return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
+    def __instancecheck__(self, obj):
+        """
+        Check if an instance is an instance of all component types.
+        """
+        return all(isinstance(obj, cls) for cls in self.__args__)
 
     def __subclasscheck__(self, subclass):
-        return all([issubclass(subclass, cls) for cls in self.__bases__])
+        """
+        Check if a class is a subclass of all component types.
+        """
+        return all(issubclass(subclass, cls) for cls in self.__args__)
+
+    def __eq__(self, other):
+        """
+        Check equality with another type.
 
+        Two JointTypes are equal if they have the same component types,
+        regardless of order.
+        """
+        if not isinstance(other, _JointGenericAlias):
+            return NotImplemented
+        return set(self.__args__) == set(other.__args__)
 
-class _JointGenericAlias(_UnionGenericAlias, _root=True):
-    def __subclasscheck__(self, subclass):
-        return all([issubclass(subclass, cls) for cls in set(self.__args__)])
+    def __hash__(self):
+        """
+        Return hash of the JointType.
 
+        The hash is based on the frozenset of component types to ensure
+        that JointTypes with the same types (regardless of order) have
+        the same hash.
+        """
+        return hash(frozenset(self.__args__))
 
-@_SpecialForm
-def JointTypes(self, parameters):
-    """Joint types; JointTypes[X, Y] means both X and Y.
+    def __repr__(self):
+        """
+        Return string representation of the JointType.
 
-    To define a union, use e.g. Union[int, str].
+        Returns a readable representation showing all component types.
+        """
+        args_str = ', '.join(
+            arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
+            else str(arg)
+            for arg in self.__args__
+        )
+        return f'JointTypes[{args_str}]'
+
+    def __reduce__(self):
+        """
+        Support for pickling.
 
-    Details:
-    - The arguments must be types and there must be at least one.
-    - None as an argument is a special case and is replaced by `type(None)`.
-    - Unions of unions are flattened, e.g.::
+        Returns the necessary information to reconstruct the JointType
+        when unpickling.
+        """
+        return (_JointGenericAlias, (self.__origin__, self.__args__))
 
-        JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
 
-    - Unions of a single argument vanish, e.g.::
+class _OneOfGenericAlias(_GenericAlias, _root=True):
+    """
+    Generic alias for OneOfTypes (union types).
 
-        JointTypes[int] == int  # The constructor actually returns int
+    This class represents a type that requires at least one of the specified
+    types to be satisfied. It's similar to typing.Union but provides a consistent
+    interface with JointTypes and avoids potential metaclass conflicts.
+    """
 
-    - Redundant arguments are skipped, e.g.::
+    def __instancecheck__(self, obj):
+        """
+        Check if an instance is an instance of any component type.
+        """
+        return any(isinstance(obj, cls) for cls in self.__args__)
 
-        JointTypes[int, str, int] == JointTypes[int, str]
+    def __subclasscheck__(self, subclass):
+        """
+        Check if a class is a subclass of any component type.
+        """
+        return any(issubclass(subclass, cls) for cls in self.__args__)
 
-    - When comparing unions, the argument order is ignored, e.g.::
+    def __eq__(self, other):
+        """
+        Check equality with another type.
 
-        JointTypes[int, str] == JointTypes[str, int]
+        Two OneOfTypes are equal if they have the same component types,
+        regardless of order.
+        """
+        if not isinstance(other, _OneOfGenericAlias):
+            return NotImplemented
+        return set(self.__args__) == set(other.__args__)
 
-    - You cannot subclass or instantiate a JointTypes.
-    - You can use Optional[X] as a shorthand for JointTypes[X, None].
-    """
-    if parameters == ():
-        raise TypeError("Cannot take a Joint of no types.")
-    if not isinstance(parameters, tuple):
-        parameters = (parameters,)
-    msg = "JointTypes[arg, ...]: each arg must be a type."
-    parameters = tuple(_type_check(p, msg) for p in parameters)
-    parameters = _remove_dups_flatten(parameters)
-    if len(parameters) == 1:
-        return parameters[0]
-    if len(parameters) == 2 and type(None) in parameters:
-        return _UnionGenericAlias(self, parameters, name="Optional")
-    return _JointGenericAlias(self, parameters)
+    def __hash__(self):
+        """
+        Return hash of the OneOfType.
 
+        The hash is based on the frozenset of component types to ensure
+        that OneOfTypes with the same types (regardless of order) have
+        the same hash.
+        """
+        return hash(frozenset(self.__args__))
 
-@_SpecialForm
-def OneOfTypes(self, parameters):
-    """Sole type; OneOfTypes[X, Y] means either X or Y.
+    def __repr__(self):
+        """
+        Return string representation of the OneOfType.
 
-    To define a union, use e.g. OneOfTypes[int, str]. Details:
-    - The arguments must be types and there must be at least one.
-    - None as an argument is a special case and is replaced by
-      type(None).
-    - Unions of unions are flattened, e.g.::
+        Returns a readable representation showing all component types.
+        """
+        args_str = ', '.join(
+            arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
+            else str(arg)
+            for arg in self.__args__
+        )
+        return f'OneOfTypes[{args_str}]'
+
+    def __reduce__(self):
+        """
+        Support for pickling.
 
-        assert OneOfTypes[OneOfTypes[int, str], float] == OneOfTypes[int, str, float]
+        Returns the necessary information to reconstruct the OneOfType
+        when unpickling.
+        """
+        return (_OneOfGenericAlias, (self.__origin__, self.__args__))
 
-    - Unions of a single argument vanish, e.g.::
 
-        assert OneOfTypes[int] == int  # The constructor actually returns int
+class _JointTypesClass:
+    """Helper class to enable subscript syntax for JointTypes."""
 
-    - Redundant arguments are skipped, e.g.::
+    def __call__(self, *types):
+        """
+        Create a type that requires all specified types (intersection type).
+
+        This function creates a type hint that indicates a value must satisfy all
+        the specified types simultaneously. It's useful for expressing complex
+        type requirements where a single object must implement multiple interfaces.
+
+        Parameters
+        ----------
+        *types : type
+            The types that must all be satisfied.
+
+        Returns
+        -------
+        type
+            A type that checks for all specified types.
+
+        Notes
+        -----
+        - If only one type is provided, that type is returned directly.
+        - Redundant types are automatically removed.
+        - The order of types doesn't matter for equality checks.
+
+        Examples
+        --------
+        Basic usage with interfaces:
+
+        .. code-block:: python
+
+            >>> import brainstate
+            >>> from typing import Protocol
+            >>>
+            >>> class Trainable(Protocol):
+            ...     def train(self): ...
+            >>>
+            >>> class Evaluable(Protocol):
+            ...     def evaluate(self): ...
+            >>>
+            >>> # A model that is both trainable and evaluable
+            >>> TrainableEvaluableModel = brainstate.mixin.JointTypes(Trainable, Evaluable)
+            >>> # Or using subscript syntax
+            >>> TrainableEvaluableModel = brainstate.mixin.JointTypes[Trainable, Evaluable]
+            >>>
+            >>> class NeuralNetwork(Trainable, Evaluable):
+            ...     def train(self):
+            ...         return "Training..."
+            ...
+            ...     def evaluate(self):
+            ...         return "Evaluating..."
+            >>>
+            >>> model = NeuralNetwork()
+            >>> # model satisfies JointTypes(Trainable, Evaluable)
+
+        Using with mixin classes:
+
+        .. code-block:: python
+
+            >>> class Serializable:
+            ...     def save(self): pass
+            >>>
+            >>> class Visualizable:
+            ...     def plot(self): pass
+            >>>
+            >>> # Require both serialization and visualization
+            >>> FullFeaturedModel = brainstate.mixin.JointTypes[Serializable, Visualizable]
+            >>>
+            >>> class MyModel(Serializable, Visualizable):
+            ...     def save(self):
+            ...         return "Saved"
+            ...
+            ...     def plot(self):
+            ...         return "Plotted"
+        """
+        if len(types) == 0:
+            raise TypeError("Cannot create a JointTypes of no types.")
+
+        # Remove duplicates while preserving some order
+        seen = set()
+        unique_types = []
+        for t in types:
+            if t not in seen:
+                seen.add(t)
+                unique_types.append(t)
+
+        # If only one type, return it directly
+        if len(unique_types) == 1:
+            return unique_types[0]
+
+        # Create a generic alias for the joint type
+        # This avoids metaclass conflicts by using typing's generic alias system
+        return _JointGenericAlias(object, tuple(unique_types))
+
+    def __getitem__(self, item):
+        """Enable subscript syntax: JointTypes[Type1, Type2]."""
+        if isinstance(item, tuple):
+            return self(*item)
+        else:
+            return self(item)
 
-        assert OneOfTypes[int, str, int] == OneOfTypes[int, str]
 
-    - When comparing unions, the argument order is ignored, e.g.::
+# Create singleton instance that acts as both a callable and supports subscript
+JointTypes = _JointTypesClass()
 
-        assert OneOfTypes[int, str] == OneOfTypes[str, int]
 
-    - You cannot subclass or instantiate a union.
-    - You can use Optional[X] as a shorthand for OneOfTypes[X, None].
-    """
-    if parameters == ():
-        raise TypeError("Cannot take a Sole of no types.")
-    if not isinstance(parameters, tuple):
-        parameters = (parameters,)
-    msg = "OneOfTypes[arg, ...]: each arg must be a type."
-    parameters = tuple(_type_check(p, msg) for p in parameters)
-    parameters = _remove_dups_flatten(parameters)
-    if len(parameters) == 1:
-        return parameters[0]
-    if len(parameters) == 2 and type(None) in parameters:
-        return _UnionGenericAlias(self, parameters, name="Optional")
-    return _UnionGenericAlias(self, parameters)
+class _OneOfTypesClass:
+    """Helper class to enable subscript syntax for OneOfTypes."""
+
+    def __call__(self, *types):
+        """
+        Create a type that requires one of the specified types (union type).
+
+        This is similar to typing.Union but provides a more intuitive name and
+        consistent behavior with JointTypes. It indicates that a value must satisfy
+        at least one of the specified types.
+
+        Parameters
+        ----------
+        *types : type
+            The types, one of which must be satisfied.
+
+        Returns
+        -------
+        Union type
+            A union type of the specified types.
+
+        Notes
+        -----
+        - If only one type is provided, that type is returned directly.
+        - Redundant types are automatically removed.
+        - The order of types doesn't matter for equality checks.
+        - This is equivalent to typing.Union[...].
+
+        Examples
+        --------
+        Basic usage with different types:
+
+        .. code-block:: python
+
+            >>> import brainstate
+            >>>
+            >>> # A parameter that can be int or float
+            >>> NumericType = brainstate.mixin.OneOfTypes(int, float)
+            >>> # Or using subscript syntax
+            >>> NumericType = brainstate.mixin.OneOfTypes[int, float]
+            >>>
+            >>> def process_value(x: NumericType):
+            ...     return x * 2
+            >>>
+            >>> # Both work
+            >>> result1 = process_value(5)     # int
+            >>> result2 = process_value(3.14)  # float
+
+        Using with class types:
+
+        .. code-block:: python
+
+            >>> class NumpyArray:
+            ...     pass
+            >>>
+            >>> class JAXArray:
+            ...     pass
+            >>>
+            >>> # Accept either numpy or JAX arrays
+            >>> ArrayType = brainstate.mixin.OneOfTypes[NumpyArray, JAXArray]
+            >>>
+            >>> def compute(arr: ArrayType):
+            ...     if isinstance(arr, NumpyArray):
+            ...         return "Processing numpy array"
+            ...     elif isinstance(arr, JAXArray):
+            ...         return "Processing JAX array"
+
+        Combining with None for optional types:
+
+        .. code-block:: python
+
+            >>> # Optional string (equivalent to Optional[str])
+            >>> MaybeString = brainstate.mixin.OneOfTypes[str, type(None)]
+            >>>
+            >>> def format_name(name: MaybeString) -> str:
+            ...     if name is None:
+            ...         return "Anonymous"
+            ...     return name.title()
+        """
+        if len(types) == 0:
+            raise TypeError("Cannot create a OneOfTypes of no types.")
+
+        # Remove duplicates
+        seen = set()
+        unique_types = []
+        for t in types:
+            if t not in seen:
+                seen.add(t)
+                unique_types.append(t)
+
+        # If only one type, return it directly
+        if len(unique_types) == 1:
+            return unique_types[0]
+
+        # Create a generic alias for the union type
+        # This provides consistency with JointTypes and avoids metaclass conflicts
+        return _OneOfGenericAlias(Union, tuple(unique_types))
+
+    def __getitem__(self, item):
+        """Enable subscript syntax: OneOfTypes[Type1, Type2]."""
+        if isinstance(item, tuple):
+            return self(*item)
+        else:
+            return self(item)
+
+
+# Create singleton instance that acts as both a callable and supports subscript
+OneOfTypes = _OneOfTypesClass()
+
+
+def __getattr__(name):
+    if name in [
+        'Mode',
+        'JointMode',
+        'Batching',
+        'Training',
+        'AlignPost',
+        'BindCondData',
+    ]:
+        import warnings
+        warnings.warn(
+            f"brainstate.mixin.{name} is deprecated and will be removed in a future version. "
+            f"Please use brainpy.mixin.{name} instead.",
+            DeprecationWarning,
+            stacklevel=2
+        )
+        import brainpy
+        return getattr(brainpy.mixin, name)
+    raise AttributeError(
+        f'module {__name__!r} has no attribute {name!r}'
+    )
 
 
 class Mode(Mixin):
     """
-    Base class for computation behaviors.
+    Base class for computation behavior modes.
+
+    Modes are used to represent different computational contexts or behaviors,
+    such as training vs evaluation, batched vs single-sample processing, etc.
+    They provide a flexible way to configure how models and components behave
+    in different scenarios.
+
+    Examples
+    --------
+    Creating a custom mode:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> class InferenceMode(brainstate.mixin.Mode):
+        ...     def __init__(self, use_cache=True):
+        ...         self.use_cache = use_cache
+        >>>
+        >>> # Create mode instances
+        >>> inference = InferenceMode(use_cache=True)
+        >>> print(inference)  # Output: InferenceMode
+
+    Checking mode types:
+
+    .. code-block:: python
+
+        >>> class FastMode(brainstate.mixin.Mode):
+        ...     pass
+        >>>
+        >>> class SlowMode(brainstate.mixin.Mode):
+        ...     pass
+        >>>
+        >>> fast = FastMode()
+        >>> slow = SlowMode()
+        >>>
+        >>> # Check exact mode type
+        >>> assert fast.is_a(FastMode)
+        >>> assert not fast.is_a(SlowMode)
+        >>>
+        >>> # Check if mode is an instance of a type
+        >>> assert fast.has(brainstate.mixin.Mode)
+
+    Using modes in a model:
+
+    .. code-block:: python
+
+        >>> class Model:
+        ...     def __init__(self):
+        ...         self.mode = brainstate.mixin.Training()
+        ...
+        ...     def forward(self, x):
+        ...         if self.mode.has(brainstate.mixin.Training):
+        ...             # Training-specific logic
+        ...             return self.train_forward(x)
+        ...         else:
+        ...             # Inference logic
+        ...             return self.eval_forward(x)
+        ...
+        ...     def train_forward(self, x):
+        ...         return x + 0.1  # Add noise during training
+        ...
+        ...     def eval_forward(self, x):
+        ...         return x  # No noise during evaluation
     """
 
     def __repr__(self):
+        """
+        String representation of the mode.
+
+        Returns
+        -------
+        str
+            The class name of the mode.
+        """
         return self.__class__.__name__
 
     def __eq__(self, other: 'Mode'):
+        """
+        Check equality of modes based on their type.
+
+        Parameters
+        ----------
+        other : Mode
+            Another mode to compare with.
+
+        Returns
+        -------
+        bool
+            True if both modes are of the same class.
+        """
         assert isinstance(other, Mode)
         return other.__class__ == self.__class__
 
     def is_a(self, mode: type):
         """
-        Check whether the mode is exactly the desired mode.
+        Check whether the mode is exactly the desired mode type.
+
+        This performs an exact type match, not checking for subclasses.
+
+        Parameters
+        ----------
+        mode : type
+            The mode type to check against.
+
+        Returns
+        -------
+        bool
+            True if this mode is exactly of the specified type.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            >>> import brainstate
+            >>>
+            >>> training_mode = brainstate.mixin.Training()
+            >>> assert training_mode.is_a(brainstate.mixin.Training)
+            >>> assert not training_mode.is_a(brainstate.mixin.Batching)
        """
        assert isinstance(mode, type), 'Must be a type.'
        return self.__class__ == mode
 
     def has(self, mode: type):
         """
-        Check whether the mode is included in the desired mode.
+        Check whether the mode includes the desired mode type.
+
+        This checks if the current mode is an instance of the specified type,
+        including checking for subclasses.
+
+        Parameters
+        ----------
+        mode : type
+            The mode type to check for.
+
+        Returns
+        -------
+        bool
+            True if this mode is an instance of the specified type.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            >>> import brainstate
+            >>>
+            >>> # Create a custom mode that extends Training
+            >>> class AdvancedTraining(brainstate.mixin.Training):
+            ...     pass
+            >>>
+            >>> advanced = AdvancedTraining()
+            >>> assert advanced.has(brainstate.mixin.Training)  # True (subclass)
+            >>> assert advanced.has(brainstate.mixin.Mode)      # True (base class)
        """
        assert isinstance(mode, type), 'Must be a type.'
        return isinstance(self, mode)
@@ -309,57 +1038,396 @@ class Mode(Mixin):
 
 class JointMode(Mode):
     """
-    Joint mode.
+    A mode that combines multiple modes simultaneously.
+
+    JointMode allows expressing that a computation is in multiple modes at once,
+    such as being both in training mode and batching mode. This is useful for
+    complex scenarios where multiple behavioral aspects need to be active.
+
+    Parameters
+    ----------
+    *modes : Mode
+        The modes to combine.
+
+    Attributes
+    ----------
+    modes : tuple of Mode
+        The individual modes that are combined.
+    types : set of type
+        The types of the combined modes.
+
+    Raises
+    ------
+    TypeError
+        If any of the provided arguments is not a Mode instance.
+
+    Examples
+    --------
+    Combining training and batching modes:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> # Create individual modes
+        >>> training = brainstate.mixin.Training()
+        >>> batching = brainstate.mixin.Batching(batch_size=32)
+        >>>
+        >>> # Combine them
+        >>> joint = brainstate.mixin.JointMode(training, batching)
+        >>> print(joint)  # JointMode(Training, Batching(in_size=32, axis=0))
+        >>>
+        >>> # Check if specific modes are present
+        >>> assert joint.has(brainstate.mixin.Training)
+        >>> assert joint.has(brainstate.mixin.Batching)
+        >>>
+        >>> # Access attributes from combined modes
+        >>> print(joint.batch_size)  # 32 (from Batching mode)
+
+    Using in model configuration:
+
+    .. code-block:: python
+
+        >>> class NeuralNetwork:
+        ...     def __init__(self):
+        ...         self.mode = None
+        ...
+        ...     def set_train_mode(self, batch_size=1):
+        ...         # Set both training and batching modes
+        ...         training = brainstate.mixin.Training()
+        ...         batching = brainstate.mixin.Batching(batch_size=batch_size)
+        ...         self.mode = brainstate.mixin.JointMode(training, batching)
+        ...
+        ...     def forward(self, x):
+        ...         if self.mode.has(brainstate.mixin.Training):
+        ...             x = self.apply_dropout(x)
+        ...
+        ...         if self.mode.has(brainstate.mixin.Batching):
+        ...             # Process in batches
+        ...             batch_size = self.mode.batch_size
+        ...             return self.batch_process(x, batch_size)
+        ...
+        ...         return self.process(x)
+        >>>
+        >>> model = NeuralNetwork()
+        >>> model.set_train_mode(batch_size=64)
     """
 
     def __init__(self, *modes: Mode):
+        # Validate that all arguments are Mode instances
         for m_ in modes:
             if not isinstance(m_, Mode):
                 raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
+
+        # Store the modes as a tuple
         self.modes = tuple(modes)
+
+        # Store the types of the modes for quick lookup
         self.types = set([m.__class__ for m in modes])
 
     def __repr__(self):
+        """
+        String representation showing all combined modes.
+
+        Returns
+        -------
+        str
+            A string showing the joint mode and its components.
+        """
         return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'
 
     def has(self, mode: type):
         """
-        Check whether the mode is included in the desired mode.
+        Check whether any of the combined modes includes the desired type.
+
+        Parameters
+        ----------
+        mode : type
+            The mode type to check for.
+
+        Returns
+        -------
+        bool
+            True if any of the combined modes is or inherits from the specified type.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            >>> import brainstate
+            >>>
+            >>> training = brainstate.mixin.Training()
+            >>> batching = brainstate.mixin.Batching(batch_size=16)
+            >>> joint = brainstate.mixin.JointMode(training, batching)
+            >>>
+            >>> assert joint.has(brainstate.mixin.Training)
+            >>> assert joint.has(brainstate.mixin.Batching)
+            >>> assert joint.has(brainstate.mixin.Mode)  # Base class
        """
        assert isinstance(mode, type), 'Must be a type.'
+        # Check if any of the combined mode types is a subclass of the target mode
        return any([issubclass(cls, mode) for cls in self.types])
 
     def is_a(self, cls: type):
         """
-        Check whether the mode is exactly the desired mode.
+        Check whether the joint mode is exactly the desired combined type.
+
+        This is a complex check that verifies the joint mode matches a specific
+        combination of types.
+
+        Parameters
+        ----------
+        cls : type
+            The combined type to check against.
+
+        Returns
+        -------
+        bool
+            True if the joint mode exactly matches the specified type combination.
        """
-        return JointTypes[tuple(self.types)] == cls
+        # Use JointTypes to create the expected type from our mode types
+        return JointTypes(*tuple(self.types)) == cls
 
     def __getattr__(self, item):
         """
-        Get the attribute from the mode.
-
-        If the attribute is not found in the mode, then it will be searched in the base class.
+        Get attributes from the combined modes.
+
+        This method searches through all combined modes to find the requested
+        attribute, allowing transparent access to properties of any of the
+        combined modes.
+
+        Parameters
+        ----------
+        item : str
+            The attribute name to search for.
+
+        Returns
+        -------
+        Any
+            The attribute value from the first mode that has it.
+
+        Raises
+        ------
+        AttributeError
+            If the attribute is not found in any of the combined modes.
+
+        Examples
+        --------
+        .. code-block:: python
+
+            >>> import brainstate
+            >>>
+            >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
+            >>> training = brainstate.mixin.Training()
+            >>> joint = brainstate.mixin.JointMode(batching, training)
+            >>>
+            >>> # Access batching attributes directly
+            >>> print(joint.batch_size)  # 32
+            >>> print(joint.batch_axis)  # 1
        """
+        # Don't interfere with accessing modes and types attributes
        if item in ['modes', 'types']:
            return super().__getattribute__(item)
+
+        # Search for the attribute in each combined mode
        for m in self.modes:
            if hasattr(m, item):
                return getattr(m, item)
+
+        # If not found, fall back to default behavior (will raise AttributeError)
        return super().__getattribute__(item)
 
 
 class Batching(Mode):
-    """Batching mode."""
+    """
+    Mode indicating batched computation.
+
+    This mode specifies that computations should be performed on batches of data,
+    including information about the batch size and which axis represents the batch
+    dimension.
+
+    Parameters
+    ----------
+    batch_size : int, default 1
+        The size of each batch.
+    batch_axis : int, default 0
+        The axis along which batching occurs.
+
+    Attributes
+    ----------
+    batch_size : int
+        The number of samples in each batch.
+    batch_axis : int
+        The axis index representing the batch dimension.
+
+    Examples
+    --------
+    Basic batching configuration:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> # Create a batching mode
+        >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=0)
+        >>> print(batching)  # Batching(in_size=32, axis=0)
+        >>>
+        >>> # Access batch parameters
+        >>> print(f"Processing {batching.batch_size} samples at once")
+        >>> print(f"Batch dimension is axis {batching.batch_axis}")
+
+    Using in a model:
+
+    .. code-block:: python
+
+        >>> import jax.numpy as jnp
+        >>>
+        >>> class BatchedModel:
+        ...     def __init__(self):
+        ...         self.mode = None
+        ...
+        ...     def set_batch_mode(self, batch_size, batch_axis=0):
+        ...         self.mode = brainstate.mixin.Batching(batch_size, batch_axis)
+        ...
+        ...     def process(self, x):
+        ...         if self.mode is not None and self.mode.has(brainstate.mixin.Batching):
+        ...             # Process in batches
+        ...             batch_size = self.mode.batch_size
+        ...             axis = self.mode.batch_axis
+        ...             return jnp.mean(x, axis=axis, keepdims=True)
+        ...         return x
+        >>>
+        >>> model = BatchedModel()
+        >>> model.set_batch_mode(batch_size=64)
+        >>>
+        >>> # Process batched data
+        >>> data = jnp.random.randn(64, 100)  # 64 samples, 100 features
+        >>> result = model.process(data)
+
+    Combining with other modes:
+
+    .. code-block:: python
+
+        >>> # Combine batching with training mode
+        >>> training = brainstate.mixin.Training()
+        >>> batching = brainstate.mixin.Batching(batch_size=128)
+        >>> combined = brainstate.mixin.JointMode(training, batching)
+        >>>
+        >>> # Use in a training loop
+        >>> def train_step(model, data, mode):
+        ...     if mode.has(brainstate.mixin.Batching):
+        ...         # Split data into batches
+        ...         batch_size = mode.batch_size
+        ...         # ... batched processing ...
+        ...     if mode.has(brainstate.mixin.Training):
+        ...         # Apply training-specific operations
+        ...         # ... training logic ...
+        ...     pass
+    """
 
     def __init__(self, batch_size: int = 1, batch_axis: int = 0):
         self.batch_size = batch_size
         self.batch_axis = batch_axis
 
     def __repr__(self):
+        """
+        String representation showing batch configuration.
+
+        Returns
+        -------
+        str
+            A string showing the batch size and axis.
+        """
         return f'{self.__class__.__name__}(in_size={self.batch_size}, axis={self.batch_axis})'
 
 
 class Training(Mode):
-    """Training mode."""
+    """
+    Mode indicating training computation.
+
+    This mode specifies that the model is in training mode, which typically
+    enables behaviors like dropout, batch normalization in training mode,
+    gradient computation, etc.
+
+    Examples
+    --------
+    Basic training mode:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>>
+        >>> # Create training mode
+        >>> training = brainstate.mixin.Training()
+        >>> print(training)  # Training
+        >>>
+        >>> # Check mode
+        >>> assert training.is_a(brainstate.mixin.Training)
+        >>> assert training.has(brainstate.mixin.Mode)
+
+    Using in a model with dropout:
+
+    .. code-block:: python
+
+        >>> import brainstate
+        >>> import jax
+        >>> import jax.numpy as jnp
+        >>>
+        >>> class ModelWithDropout:
+        ...     def __init__(self, dropout_rate=0.5):
+        ...         self.dropout_rate = dropout_rate
+        ...         self.mode = None
+        ...
+        ...     def set_training(self, is_training=True):
+        ...         if is_training:
+        ...             self.mode = brainstate.mixin.Training()
+        ...         else:
+        ...             self.mode = brainstate.mixin.Mode()  # Evaluation mode
+        ...
+        ...     def forward(self, x, rng_key):
+        ...         # Apply dropout only during training
+        ...         if self.mode is not None and self.mode.has(brainstate.mixin.Training):
+        ...             keep_prob = 1.0 - self.dropout_rate
+        ...             mask = jax.random.bernoulli(rng_key, keep_prob, x.shape)
+        ...             x = jnp.where(mask, x / keep_prob, 0)
+        ...         return x
+        >>>
+        >>> model = ModelWithDropout()
+        >>>
+        >>> # Training mode
+        >>> model.set_training(True)
+        >>> key = jax.random.PRNGKey(0)
+        >>> x_train = jnp.ones((10, 20))
+        >>> out_train = model.forward(x_train, key)  # Dropout applied
+        >>>
+        >>> # Evaluation mode
+        >>> model.set_training(False)
+        >>> out_eval = model.forward(x_train, key)  # No dropout
+
+    Combining with batching:
+
+    .. code-block:: python
+
+        >>> # Create combined training and batching mode
+        >>> training = brainstate.mixin.Training()
+        >>> batching = brainstate.mixin.Batching(batch_size=32)
+        >>> mode = brainstate.mixin.JointMode(training, batching)
+        >>>
+        >>> # Use in training configuration
+        >>> class Trainer:
+        ...     def __init__(self, model, mode):
+        ...         self.model = model
+        ...         self.mode = mode
+        ...
+        ...     def train_epoch(self, data):
+        ...         if self.mode.has(brainstate.mixin.Training):
+        ...             # Enable training-specific behaviors
+        ...             self.model.set_training(True)
+        ...
+        ...         if self.mode.has(brainstate.mixin.Batching):
+        ...             # Process in batches
+        ...             batch_size = self.mode.batch_size
+        ...             # ... batched training loop ...
+        ...         pass
+    """
     pass
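
A quick way to sanity-check the reworked JointTypes/OneOfTypes singletons after upgrading is the doctest-style sketch below. It is not part of the diff: it assumes the 0.2.0 wheel is installed, and the Serializable/Plottable class names are illustrative only.

>>> import brainstate
>>>
>>> class Serializable: ...
>>> class Plottable: ...
>>> class Model(Serializable, Plottable): ...
>>>
>>> Both = brainstate.mixin.JointTypes[Serializable, Plottable]    # all component types required
>>> Either = brainstate.mixin.OneOfTypes[Serializable, Plottable]  # any one component type suffices
>>>
>>> assert isinstance(Model(), Both) and issubclass(Model, Both)
>>> assert isinstance(Model(), Either)
>>> # Equality ignores argument order, per the __eq__/__hash__ implementations above
>>> assert Both == brainstate.mixin.JointTypes[Plottable, Serializable]

Note also that names removed from the module, such as AlignPost and BindCondData, are now routed through the module-level __getattr__ shown above: accessing brainstate.mixin.AlignPost emits a DeprecationWarning and forwards to brainpy.mixin, so that path additionally assumes brainpy is importable.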