brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/mixin.py CHANGED
@@ -1,1433 +1,1447 @@
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- """
- Mixin classes and utility types for brainstate.
-
- This module provides various mixin classes and custom type definitions that
- enhance the functionality of brainstate components. It includes parameter
- description mixins, alignment interfaces, and custom type definitions for
- expressing complex type requirements.
- """
-
- from typing import Sequence, Optional, TypeVar, Union, _GenericAlias
-
- import jax
-
- __all__ = [
-     'Mixin',
-     'ParamDesc',
-     'ParamDescriber',
-     'JointTypes',
-     'OneOfTypes',
-     '_JointGenericAlias',
-     '_OneOfGenericAlias',
-     'Mode',
-     'JointMode',
-     'Batching',
-     'Training',
- ]
-
- T = TypeVar('T')
- ArrayLike = jax.typing.ArrayLike
-
-
- def hashable(x):
-     """
-     Check if an object is hashable.
-
-     Parameters
-     ----------
-     x : Any
-         The object to check for hashability.
-
-     Returns
-     -------
-     bool
-         True if the object is hashable, False otherwise.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>>
-         >>> # Hashable objects
-         >>> assert brainstate.mixin.hashable(42) == True
-         >>> assert brainstate.mixin.hashable("string") == True
-         >>> assert brainstate.mixin.hashable((1, 2, 3)) == True
-         >>>
-         >>> # Non-hashable objects
-         >>> assert brainstate.mixin.hashable([1, 2, 3]) == False
-         >>> assert brainstate.mixin.hashable({"key": "value"}) == False
-     """
-     try:
-         hash(x)
-         return True
-     except TypeError:
-         return False
-
-
- class Mixin(object):
-     """
-     Base Mixin object for behavioral extensions.
-
-     The key characteristic of a :py:class:`~.Mixin` is that it provides only
-     behavioral functions without requiring initialization. Mixins are used to
-     add specific functionality to classes through multiple inheritance without
-     the complexity of a full base class.
-
-     Notes
-     -----
-     Mixins should not define ``__init__`` methods. They should only provide
-     methods that add specific behaviors to the classes that inherit from them.
-
-     Examples
-     --------
-     Creating a custom mixin:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>>
-         >>> class LoggingMixin(brainstate.mixin.Mixin):
-         ...     def log(self, message):
-         ...         print(f"[{self.__class__.__name__}] {message}")
-
-         >>> class MyComponent(brainstate.nn.Module, LoggingMixin):
-         ...     def __init__(self):
-         ...         super().__init__()
-         ...
-         ...     def process(self):
-         ...         self.log("Processing data...")
-         ...         return "Done"
-         >>>
-         >>> component = MyComponent()
-         >>> component.process()  # Prints: [MyComponent] Processing data...
-     """
-     pass
-
-
- class ParamDesc(Mixin):
-     """
-     Mixin for describing initialization parameters.
-
-     This mixin enables a class to have a ``desc`` classmethod, which produces
-     an instance of :py:class:`~.ParamDescriber`. This is useful for creating
-     parameter templates that can be reused to instantiate multiple objects
-     with the same configuration.
-
-     Attributes
-     ----------
-     non_hashable_params : sequence of str, optional
-         Names of parameters that are not hashable and should be handled specially.
-
-     Notes
-     -----
-     This mixin can be applied to any Python class, not just brainstate-specific classes.
-
-     Examples
-     --------
-     Basic usage of ParamDesc:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>>
-         >>> class NeuronModel(brainstate.mixin.ParamDesc):
-         ...     def __init__(self, size, tau=10.0, threshold=1.0):
-         ...         self.size = size
-         ...         self.tau = tau
-         ...         self.threshold = threshold
-         >>>
-         >>> # Create a parameter descriptor
-         >>> neuron_desc = NeuronModel.desc(size=100, tau=20.0)
-         >>>
-         >>> # Use the descriptor to create instances
-         >>> neuron1 = neuron_desc(threshold=0.8)  # Creates with threshold=0.8
-         >>> neuron2 = neuron_desc(threshold=1.2)  # Creates with threshold=1.2
-         >>>
-         >>> # Both neurons share size=100, tau=20.0 but have different thresholds
-
-     Creating reusable templates:
-
-     .. code-block:: python
-
-         >>> # Define a template for excitatory neurons
-         >>> exc_neuron_template = NeuronModel.desc(size=1000, tau=10.0, threshold=1.0)
-         >>>
-         >>> # Define a template for inhibitory neurons
-         >>> inh_neuron_template = NeuronModel.desc(size=250, tau=5.0, threshold=0.5)
-         >>>
-         >>> # Create multiple instances from templates
-         >>> exc_population = [exc_neuron_template() for _ in range(5)]
-         >>> inh_population = [inh_neuron_template() for _ in range(2)]
-     """
-
-     # Optional list of parameter names that are not hashable
-     # These will be converted to strings for hashing purposes
-     non_hashable_params: Optional[Sequence[str]] = None
-
-     @classmethod
-     def desc(cls, *args, **kwargs) -> 'ParamDescriber':
-         """
-         Create a parameter describer for this class.
-
-         Parameters
-         ----------
-         *args
-             Positional arguments to be used in future instantiations.
-         **kwargs
-             Keyword arguments to be used in future instantiations.
-
-         Returns
-         -------
-         ParamDescriber
-             A descriptor that can be used to create instances with these parameters.
-         """
-         return ParamDescriber(cls, *args, **kwargs)
-
-
- class HashableDict(dict):
-     """
-     A dictionary that can be hashed by converting non-hashable values to strings.
-
-     This is used internally to make parameter dictionaries hashable so they can
-     be used as part of cache keys or other contexts requiring hashability.
-
-     Parameters
-     ----------
-     the_dict : dict
-         The dictionary to make hashable.
-
-     Notes
-     -----
-     Non-hashable values in the dictionary are automatically converted to their
-     string representation.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> import jax.numpy as jnp
-         >>>
-         >>> # Regular dict with non-hashable values cannot be hashed
-         >>> regular_dict = {"array": jnp.array([1, 2, 3]), "value": 42}
-         >>> # hash(regular_dict)  # This would raise TypeError
-         >>>
-         >>> # HashableDict can be hashed
-         >>> hashable = brainstate.mixin.HashableDict(regular_dict)
-         >>> key = hash(hashable)  # This works!
-         >>>
-         >>> # Can be used in sets or as dict keys
-         >>> cache = {hashable: "result"}
-     """
-
-     def __init__(self, the_dict: dict):
-         # Process the dictionary to ensure all values are hashable
-         out = dict()
-         for k, v in the_dict.items():
-             if not hashable(v):
-                 # Convert non-hashable values to their string representation
-                 v = str(v)
-             out[k] = v
-         super().__init__(out)
-
-     def __hash__(self):
-         """
-         Compute hash from sorted items for consistent hashing regardless of insertion order.
-         """
-         return hash(tuple(sorted(self.items())))
-
-
- class NoSubclassMeta(type):
-     """
-     Metaclass that prevents a class from being subclassed.
-
-     This is used to ensure that certain classes (like ParamDescriber) are used
-     as-is and not extended through inheritance, which could lead to unexpected
-     behavior.
-
-     Raises
-     ------
-     TypeError
-         If an attempt is made to subclass a class using this metaclass.
-     """
-
-     def __new__(cls, name, bases, classdict):
-         # Check if any base class uses NoSubclassMeta
-         for b in bases:
-             if isinstance(b, NoSubclassMeta):
-                 raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
-         return type.__new__(cls, name, bases, dict(classdict))
-
-
- class ParamDescriber(metaclass=NoSubclassMeta):
-     """
-     Parameter descriptor for deferred object instantiation.
-
-     This class stores a class reference along with arguments and keyword arguments,
-     allowing for deferred instantiation. It's useful for creating templates that
-     can be reused to create multiple instances with similar configurations.
-
-     Parameters
-     ----------
-     cls : type
-         The class to be instantiated.
-     *desc_tuple
-         Positional arguments to be stored and used during instantiation.
-     **desc_dict
-         Keyword arguments to be stored and used during instantiation.
-
-     Attributes
-     ----------
-     cls : type
-         The class that will be instantiated.
-     args : tuple
-         Stored positional arguments.
-     kwargs : dict
-         Stored keyword arguments.
-     identifier : tuple
-         A hashable identifier for this descriptor.
-
-     Notes
-     -----
-     ParamDescriber cannot be subclassed due to the NoSubclassMeta metaclass.
-     This ensures consistent behavior across the codebase.
-
-     Examples
-     --------
-     Manual creation of a descriptor:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>>
-         >>> class Network:
-         ...     def __init__(self, n_neurons, learning_rate=0.01):
-         ...         self.n_neurons = n_neurons
-         ...         self.learning_rate = learning_rate
-         >>>
-         >>> # Create a descriptor
-         >>> network_desc = brainstate.mixin.ParamDescriber(
-         ...     Network, n_neurons=1000, learning_rate=0.001
-         ... )
-         >>>
-         >>> # Use the descriptor to create instances with additional args
-         >>> net1 = network_desc()
-         >>> net2 = network_desc()  # Same configuration
-
-     Using with ParamDesc mixin:
-
-     .. code-block:: python
-
-         >>> class Network(brainstate.mixin.ParamDesc):
-         ...     def __init__(self, n_neurons, learning_rate=0.01):
-         ...         self.n_neurons = n_neurons
-         ...         self.learning_rate = learning_rate
-         >>>
-         >>> # More concise syntax using the desc() classmethod
-         >>> network_desc = Network.desc(n_neurons=1000)
-         >>> net = network_desc(learning_rate=0.005)  # Override learning_rate
-     """
-
-     def __init__(self, cls: T, *desc_tuple, **desc_dict):
-         # Store the class to be instantiated
-         self.cls: type = cls
-
-         # Store the arguments for later instantiation
-         self.args = desc_tuple
-         self.kwargs = desc_dict
-
-         # Create a hashable identifier for caching/comparison purposes
-         # This combines the class, args tuple, and hashable kwargs dict
-         self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
-
-     def __call__(self, *args, **kwargs) -> T:
-         """
-         Instantiate the class with stored and additional arguments.
-
-         Parameters
-         ----------
-         *args
-             Additional positional arguments to append.
-         **kwargs
-             Additional keyword arguments to merge (will override stored kwargs).
-
-         Returns
-         -------
-         T
-             An instance of the described class.
-         """
-         # Merge stored arguments with new arguments
-         # Stored args come first, then new args
-         # Merge kwargs with new kwargs overriding stored ones
-         merged_kwargs = {**self.kwargs, **kwargs}
-         return self.cls(*self.args, *args, **merged_kwargs)
-
-     def init(self, *args, **kwargs):
-         """
-         Alias for __call__, explicitly named for clarity.
-
-         Parameters
-         ----------
-         *args
-             Additional positional arguments.
-         **kwargs
-             Additional keyword arguments.
-
-         Returns
-         -------
-         T
-             An instance of the described class.
-         """
-         return self.__call__(*args, **kwargs)
-
-     def __instancecheck__(self, instance):
-         """
-         Check if an instance is compatible with this descriptor.
-
-         Parameters
-         ----------
-         instance : Any
-             The instance to check.
-
-         Returns
-         -------
-         bool
-             True if the instance is a ParamDescriber for a compatible class.
-         """
-         # Must be a ParamDescriber
-         if not isinstance(instance, ParamDescriber):
-             return False
-         # The described class must be a subclass of our class
-         if not issubclass(instance.cls, self.cls):
-             return False
-         return True
-
-     @classmethod
-     def __class_getitem__(cls, item: type):
-         """
-         Support for subscript notation: ParamDescriber[MyClass].
-
-         Parameters
-         ----------
-         item : type
-             The class to create a descriptor for.
-
-         Returns
-         -------
-         ParamDescriber
-             A descriptor for the given class.
-         """
-         return ParamDescriber(item)
-
-     @property
-     def identifier(self):
-         """
-         Get the unique identifier for this descriptor.
-
-         Returns
-         -------
-         tuple
-             A hashable identifier consisting of (class, args, kwargs).
-         """
-         return self._identifier
-
-     @identifier.setter
-     def identifier(self, value: ArrayLike):
-         """
-         Prevent modification of the identifier.
-
-         Raises
-         ------
-         AttributeError
-             Always, as the identifier is read-only.
-         """
-         raise AttributeError('Cannot set the identifier.')
-
-
- def not_implemented(func):
-     """
-     Decorator to mark a function as not implemented.
-
-     This decorator wraps a function to raise NotImplementedError when called,
-     and adds a ``not_implemented`` attribute for checking.
-
-     Parameters
-     ----------
-     func : callable
-         The function to mark as not implemented.
-
-     Returns
-     -------
-     callable
-         A wrapper function that raises NotImplementedError.
-
-     Examples
-     --------
-     .. code-block:: python
-
-         >>> import brainstate
-         >>>
-         >>> class BaseModel:
-         ...     @brainstate.mixin.not_implemented
-         ...     def process(self, x):
-         ...         pass
-         >>>
-         >>> model = BaseModel()
-         >>> # model.process(10)  # Raises: NotImplementedError: process is not implemented.
-         >>>
-         >>> # Check if a method is not implemented
-         >>> assert hasattr(BaseModel.process, 'not_implemented')
-     """
-
-     def wrapper(*args, **kwargs):
-         raise NotImplementedError(f'{func.__name__} is not implemented.')
-
-     # Mark the wrapper so we can detect not-implemented methods
-     wrapper.not_implemented = True
-     return wrapper
-
-
- class _JointGenericAlias(_GenericAlias, _root=True):
-     """
-     Generic alias for JointTypes (intersection types).
-
-     This class represents a type that requires all specified types to be satisfied.
-     Unlike _MetaUnionType which creates actual classes with metaclass conflicts,
-     this uses typing's generic alias system to avoid metaclass issues.
-     """
-
-     def __instancecheck__(self, obj):
-         """
-         Check if an instance is an instance of all component types.
-         """
-         return all(isinstance(obj, cls) for cls in self.__args__)
-
-     def __subclasscheck__(self, subclass):
-         """
-         Check if a class is a subclass of all component types.
-         """
-         return all(issubclass(subclass, cls) for cls in self.__args__)
-
-     def __eq__(self, other):
-         """
-         Check equality with another type.
-
-         Two JointTypes are equal if they have the same component types,
-         regardless of order.
-         """
-         if not isinstance(other, _JointGenericAlias):
-             return NotImplemented
-         return set(self.__args__) == set(other.__args__)
-
-     def __hash__(self):
-         """
-         Return hash of the JointType.
-
-         The hash is based on the frozenset of component types to ensure
-         that JointTypes with the same types (regardless of order) have
-         the same hash.
-         """
-         return hash(frozenset(self.__args__))
-
-     def __repr__(self):
-         """
-         Return string representation of the JointType.
-
-         Returns a readable representation showing all component types.
-         """
-         args_str = ', '.join(
-             arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
-             else str(arg)
-             for arg in self.__args__
-         )
-         return f'JointTypes[{args_str}]'
-
-     def __reduce__(self):
-         """
-         Support for pickling.
-
-         Returns the necessary information to reconstruct the JointType
-         when unpickling.
-         """
-         return (_JointGenericAlias, (self.__origin__, self.__args__))
-
-
- class _OneOfGenericAlias(_GenericAlias, _root=True):
-     """
-     Generic alias for OneOfTypes (union types).
-
-     This class represents a type that requires at least one of the specified
-     types to be satisfied. It's similar to typing.Union but provides a consistent
-     interface with JointTypes and avoids potential metaclass conflicts.
-     """
-
-     def __instancecheck__(self, obj):
-         """
-         Check if an instance is an instance of any component type.
-         """
-         return any(isinstance(obj, cls) for cls in self.__args__)
-
-     def __subclasscheck__(self, subclass):
-         """
-         Check if a class is a subclass of any component type.
-         """
-         return any(issubclass(subclass, cls) for cls in self.__args__)
-
-     def __eq__(self, other):
-         """
-         Check equality with another type.
-
-         Two OneOfTypes are equal if they have the same component types,
-         regardless of order.
-         """
-         if not isinstance(other, _OneOfGenericAlias):
-             return NotImplemented
-         return set(self.__args__) == set(other.__args__)
-
-     def __hash__(self):
-         """
-         Return hash of the OneOfType.
-
-         The hash is based on the frozenset of component types to ensure
-         that OneOfTypes with the same types (regardless of order) have
-         the same hash.
-         """
-         return hash(frozenset(self.__args__))
-
-     def __repr__(self):
-         """
-         Return string representation of the OneOfType.
-
-         Returns a readable representation showing all component types.
-         """
-         args_str = ', '.join(
-             arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
-             else str(arg)
-             for arg in self.__args__
-         )
-         return f'OneOfTypes[{args_str}]'
-
-     def __reduce__(self):
-         """
-         Support for pickling.
-
-         Returns the necessary information to reconstruct the OneOfType
-         when unpickling.
-         """
-         return (_OneOfGenericAlias, (self.__origin__, self.__args__))
-
-
- class _JointTypesClass:
-     """Helper class to enable subscript syntax for JointTypes."""
-
-     def __call__(self, *types):
-         """
-         Create a type that requires all specified types (intersection type).
-
-         This function creates a type hint that indicates a value must satisfy all
-         the specified types simultaneously. It's useful for expressing complex
-         type requirements where a single object must implement multiple interfaces.
-
-         Parameters
-         ----------
-         *types : type
-             The types that must all be satisfied.
-
-         Returns
-         -------
-         type
-             A type that checks for all specified types.
-
-         Notes
-         -----
-         - If only one type is provided, that type is returned directly.
-         - Redundant types are automatically removed.
-         - The order of types doesn't matter for equality checks.
-
-         Examples
-         --------
-         Basic usage with interfaces:
-
-         .. code-block:: python
-
-             >>> import brainstate
-             >>> from typing import Protocol
-             >>>
-             >>> class Trainable(Protocol):
-             ...     def train(self): ...
-             >>>
-             >>> class Evaluable(Protocol):
-             ...     def evaluate(self): ...
-             >>>
-             >>> # A model that is both trainable and evaluable
-             >>> TrainableEvaluableModel = brainstate.mixin.JointTypes(Trainable, Evaluable)
-             >>> # Or using subscript syntax
-             >>> TrainableEvaluableModel = brainstate.mixin.JointTypes[Trainable, Evaluable]
-             >>>
-             >>> class NeuralNetwork(Trainable, Evaluable):
-             ...     def train(self):
-             ...         return "Training..."
-             ...
-             ...     def evaluate(self):
-             ...         return "Evaluating..."
-             >>>
-             >>> model = NeuralNetwork()
-             >>> # model satisfies JointTypes(Trainable, Evaluable)
-
-         Using with mixin classes:
-
-         .. code-block:: python
-
-             >>> class Serializable:
-             ...     def save(self): pass
-             >>>
-             >>> class Visualizable:
-             ...     def plot(self): pass
-             >>>
-             >>> # Require both serialization and visualization
-             >>> FullFeaturedModel = brainstate.mixin.JointTypes[Serializable, Visualizable]
-             >>>
-             >>> class MyModel(Serializable, Visualizable):
-             ...     def save(self):
-             ...         return "Saved"
-             ...
-             ...     def plot(self):
-             ...         return "Plotted"
-         """
-         if len(types) == 0:
-             raise TypeError("Cannot create a JointTypes of no types.")
-
-         # Remove duplicates while preserving some order
-         seen = set()
-         unique_types = []
-         for t in types:
-             if t not in seen:
-                 seen.add(t)
-                 unique_types.append(t)
-
-         # If only one type, return it directly
-         if len(unique_types) == 1:
-             return unique_types[0]
-
-         # Create a generic alias for the joint type
-         # This avoids metaclass conflicts by using typing's generic alias system
-         return _JointGenericAlias(object, tuple(unique_types))
-
-     def __getitem__(self, item):
-         """Enable subscript syntax: JointTypes[Type1, Type2]."""
-         if isinstance(item, tuple):
-             return self(*item)
-         else:
-             return self(item)
-
-
- # Create singleton instance that acts as both a callable and supports subscript
- JointTypes = _JointTypesClass()
-
-
- class _OneOfTypesClass:
-     """Helper class to enable subscript syntax for OneOfTypes."""
-
-     def __call__(self, *types):
-         """
-         Create a type that requires one of the specified types (union type).
-
-         This is similar to typing.Union but provides a more intuitive name and
-         consistent behavior with JointTypes. It indicates that a value must satisfy
-         at least one of the specified types.
-
-         Parameters
-         ----------
-         *types : type
-             The types, one of which must be satisfied.
-
-         Returns
-         -------
-         Union type
-             A union type of the specified types.
-
-         Notes
-         -----
-         - If only one type is provided, that type is returned directly.
-         - Redundant types are automatically removed.
-         - The order of types doesn't matter for equality checks.
-         - This is equivalent to typing.Union[...].
-
-         Examples
-         --------
-         Basic usage with different types:
-
-         .. code-block:: python
-
-             >>> import brainstate
-             >>>
-             >>> # A parameter that can be int or float
-             >>> NumericType = brainstate.mixin.OneOfTypes(int, float)
-             >>> # Or using subscript syntax
-             >>> NumericType = brainstate.mixin.OneOfTypes[int, float]
-             >>>
-             >>> def process_value(x: NumericType):
-             ...     return x * 2
-             >>>
-             >>> # Both work
-             >>> result1 = process_value(5)  # int
-             >>> result2 = process_value(3.14)  # float
-
-         Using with class types:
-
-         .. code-block:: python
-
-             >>> class NumpyArray:
-             ...     pass
-             >>>
-             >>> class JAXArray:
-             ...     pass
-             >>>
-             >>> # Accept either numpy or JAX arrays
-             >>> ArrayType = brainstate.mixin.OneOfTypes[NumpyArray, JAXArray]
-             >>>
-             >>> def compute(arr: ArrayType):
-             ...     if isinstance(arr, NumpyArray):
-             ...         return "Processing numpy array"
-             ...     elif isinstance(arr, JAXArray):
-             ...         return "Processing JAX array"
-
-         Combining with None for optional types:
-
-         .. code-block:: python
-
-             >>> # Optional string (equivalent to Optional[str])
-             >>> MaybeString = brainstate.mixin.OneOfTypes[str, type(None)]
-             >>>
-             >>> def format_name(name: MaybeString) -> str:
-             ...     if name is None:
-             ...         return "Anonymous"
-             ...     return name.title()
-         """
-         if len(types) == 0:
-             raise TypeError("Cannot create a OneOfTypes of no types.")
-
-         # Remove duplicates
-         seen = set()
-         unique_types = []
-         for t in types:
-             if t not in seen:
-                 seen.add(t)
-                 unique_types.append(t)
-
-         # If only one type, return it directly
-         if len(unique_types) == 1:
-             return unique_types[0]
-
-         # Create a generic alias for the union type
-         # This provides consistency with JointTypes and avoids metaclass conflicts
-         return _OneOfGenericAlias(Union, tuple(unique_types))
-
-     def __getitem__(self, item):
-         """Enable subscript syntax: OneOfTypes[Type1, Type2]."""
-         if isinstance(item, tuple):
-             return self(*item)
-         else:
-             return self(item)
-
-
- # Create singleton instance that acts as both a callable and supports subscript
- OneOfTypes = _OneOfTypesClass()
-
-
- def __getattr__(name):
-     if name in [
-         'Mode',
-         'JointMode',
-         'Batching',
-         'Training',
-         'AlignPost',
-         'BindCondData',
-     ]:
-         import warnings
-         warnings.warn(
-             f"brainstate.mixin.{name} is deprecated and will be removed in a future version. "
-             f"Please use brainpy.mixin.{name} instead.",
-             DeprecationWarning,
-             stacklevel=2
-         )
-         import brainpy
-         return getattr(brainpy.mixin, name)
-     raise AttributeError(
-         f'module {__name__!r} has no attribute {name!r}'
-     )
-
-
- class Mode(Mixin):
-     """
-     Base class for computation behavior modes.
-
-     Modes are used to represent different computational contexts or behaviors,
-     such as training vs evaluation, batched vs single-sample processing, etc.
-     They provide a flexible way to configure how models and components behave
-     in different scenarios.
-
-     Examples
-     --------
-     Creating a custom mode:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>>
-         >>> class InferenceMode(brainstate.mixin.Mode):
-         ...     def __init__(self, use_cache=True):
-         ...         self.use_cache = use_cache
-         >>>
-         >>> # Create mode instances
-         >>> inference = InferenceMode(use_cache=True)
-         >>> print(inference)  # Output: InferenceMode
-
-     Checking mode types:
-
-     .. code-block:: python
-
-         >>> class FastMode(brainstate.mixin.Mode):
-         ...     pass
-         >>>
-         >>> class SlowMode(brainstate.mixin.Mode):
-         ...     pass
-         >>>
-         >>> fast = FastMode()
-         >>> slow = SlowMode()
-         >>>
-         >>> # Check exact mode type
-         >>> assert fast.is_a(FastMode)
-         >>> assert not fast.is_a(SlowMode)
-         >>>
-         >>> # Check if mode is an instance of a type
-         >>> assert fast.has(brainstate.mixin.Mode)
-
-     Using modes in a model:
-
-     .. code-block:: python
-
-         >>> class Model:
-         ...     def __init__(self):
-         ...         self.mode = brainstate.mixin.Training()
-         ...
-         ...     def forward(self, x):
-         ...         if self.mode.has(brainstate.mixin.Training):
-         ...             # Training-specific logic
-         ...             return self.train_forward(x)
-         ...         else:
-         ...             # Inference logic
-         ...             return self.eval_forward(x)
-         ...
-         ...     def train_forward(self, x):
-         ...         return x + 0.1  # Add noise during training
-         ...
-         ...     def eval_forward(self, x):
-         ...         return x  # No noise during evaluation
-     """
-
-     def __repr__(self):
-         """
-         String representation of the mode.
-
-         Returns
-         -------
-         str
-             The class name of the mode.
-         """
-         return self.__class__.__name__
-
-     def __eq__(self, other: 'Mode'):
-         """
-         Check equality of modes based on their type.
-
-         Parameters
-         ----------
-         other : Mode
-             Another mode to compare with.
-
-         Returns
-         -------
-         bool
-             True if both modes are of the same class.
-         """
-         assert isinstance(other, Mode)
-         return other.__class__ == self.__class__
-
-     def is_a(self, mode: type):
-         """
-         Check whether the mode is exactly the desired mode type.
-
-         This performs an exact type match, not checking for subclasses.
-
-         Parameters
-         ----------
-         mode : type
-             The mode type to check against.
-
-         Returns
-         -------
-         bool
-             True if this mode is exactly of the specified type.
-
-         Examples
-         --------
-         .. code-block:: python
-
-             >>> import brainstate
-             >>>
-             >>> training_mode = brainstate.mixin.Training()
-             >>> assert training_mode.is_a(brainstate.mixin.Training)
-             >>> assert not training_mode.is_a(brainstate.mixin.Batching)
-         """
-         assert isinstance(mode, type), 'Must be a type.'
-         return self.__class__ == mode
-
-     def has(self, mode: type):
-         """
-         Check whether the mode includes the desired mode type.
-
-         This checks if the current mode is an instance of the specified type,
-         including checking for subclasses.
-
-         Parameters
-         ----------
-         mode : type
-             The mode type to check for.
-
-         Returns
-         -------
-         bool
-             True if this mode is an instance of the specified type.
-
-         Examples
-         --------
-         .. code-block:: python
-
-             >>> import brainstate
-             >>>
-             >>> # Create a custom mode that extends Training
-             >>> class AdvancedTraining(brainstate.mixin.Training):
-             ...     pass
-             >>>
-             >>> advanced = AdvancedTraining()
-             >>> assert advanced.has(brainstate.mixin.Training)  # True (subclass)
-             >>> assert advanced.has(brainstate.mixin.Mode)  # True (base class)
-         """
-         assert isinstance(mode, type), 'Must be a type.'
-         return isinstance(self, mode)
-
-
- class JointMode(Mode):
-     """
-     A mode that combines multiple modes simultaneously.
-
-     JointMode allows expressing that a computation is in multiple modes at once,
-     such as being both in training mode and batching mode. This is useful for
-     complex scenarios where multiple behavioral aspects need to be active.
-
-     Parameters
-     ----------
-     *modes : Mode
-         The modes to combine.
-
-     Attributes
-     ----------
-     modes : tuple of Mode
-         The individual modes that are combined.
-     types : set of type
-         The types of the combined modes.
-
-     Raises
-     ------
-     TypeError
-         If any of the provided arguments is not a Mode instance.
-
-     Examples
-     --------
-     Combining training and batching modes:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>>
-         >>> # Create individual modes
-         >>> training = brainstate.mixin.Training()
-         >>> batching = brainstate.mixin.Batching(batch_size=32)
-         >>>
-         >>> # Combine them
-         >>> joint = brainstate.mixin.JointMode(training, batching)
-         >>> print(joint)  # JointMode(Training, Batching(in_size=32, axis=0))
-         >>>
-         >>> # Check if specific modes are present
-         >>> assert joint.has(brainstate.mixin.Training)
-         >>> assert joint.has(brainstate.mixin.Batching)
-         >>>
-         >>> # Access attributes from combined modes
-         >>> print(joint.batch_size)  # 32 (from Batching mode)
-
-     Using in model configuration:
-
-     .. code-block:: python
-
-         >>> class NeuralNetwork:
-         ...     def __init__(self):
-         ...         self.mode = None
-         ...
-         ...     def set_train_mode(self, batch_size=1):
-         ...         # Set both training and batching modes
-         ...         training = brainstate.mixin.Training()
-         ...         batching = brainstate.mixin.Batching(batch_size=batch_size)
-         ...         self.mode = brainstate.mixin.JointMode(training, batching)
-         ...
-         ...     def forward(self, x):
-         ...         if self.mode.has(brainstate.mixin.Training):
-         ...             x = self.apply_dropout(x)
-         ...
-         ...         if self.mode.has(brainstate.mixin.Batching):
-         ...             # Process in batches
-         ...             batch_size = self.mode.batch_size
-         ...             return self.batch_process(x, batch_size)
-         ...
-         ...         return self.process(x)
-         >>>
-         >>> model = NeuralNetwork()
-         >>> model.set_train_mode(batch_size=64)
-     """
-
-     def __init__(self, *modes: Mode):
-         # Validate that all arguments are Mode instances
-         for m_ in modes:
-             if not isinstance(m_, Mode):
-                 raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
-
-         # Store the modes as a tuple
-         self.modes = tuple(modes)
-
-         # Store the types of the modes for quick lookup
-         self.types = set([m.__class__ for m in modes])
-
-     def __repr__(self):
-         """
-         String representation showing all combined modes.
-
-         Returns
-         -------
-         str
-             A string showing the joint mode and its components.
-         """
-         return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'
-
-     def has(self, mode: type):
-         """
-         Check whether any of the combined modes includes the desired type.
-
-         Parameters
-         ----------
-         mode : type
-             The mode type to check for.
-
-         Returns
-         -------
-         bool
-             True if any of the combined modes is or inherits from the specified type.
-
-         Examples
-         --------
-         .. code-block:: python
-
-             >>> import brainstate
-             >>>
-             >>> training = brainstate.mixin.Training()
-             >>> batching = brainstate.mixin.Batching(batch_size=16)
-             >>> joint = brainstate.mixin.JointMode(training, batching)
-             >>>
-             >>> assert joint.has(brainstate.mixin.Training)
-             >>> assert joint.has(brainstate.mixin.Batching)
-             >>> assert joint.has(brainstate.mixin.Mode)  # Base class
-         """
-         assert isinstance(mode, type), 'Must be a type.'
-         # Check if any of the combined mode types is a subclass of the target mode
-         return any([issubclass(cls, mode) for cls in self.types])
-
-     def is_a(self, cls: type):
-         """
-         Check whether the joint mode is exactly the desired combined type.
-
-         This is a complex check that verifies the joint mode matches a specific
-         combination of types.
-
-         Parameters
-         ----------
-         cls : type
-             The combined type to check against.
-
-         Returns
-         -------
-         bool
-             True if the joint mode exactly matches the specified type combination.
-         """
-         # Use JointTypes to create the expected type from our mode types
-         return JointTypes(*tuple(self.types)) == cls
-
-     def __getattr__(self, item):
-         """
-         Get attributes from the combined modes.
-
-         This method searches through all combined modes to find the requested
-         attribute, allowing transparent access to properties of any of the
-         combined modes.
-
-         Parameters
-         ----------
-         item : str
-             The attribute name to search for.
-
-         Returns
-         -------
-         Any
-             The attribute value from the first mode that has it.
-
-         Raises
-         ------
-         AttributeError
-             If the attribute is not found in any of the combined modes.
-
-         Examples
-         --------
-         .. code-block:: python
-
-             >>> import brainstate
-             >>>
-             >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
-             >>> training = brainstate.mixin.Training()
-             >>> joint = brainstate.mixin.JointMode(batching, training)
-             >>>
-             >>> # Access batching attributes directly
-             >>> print(joint.batch_size)  # 32
-             >>> print(joint.batch_axis)  # 1
-         """
-         # Don't interfere with accessing modes and types attributes
-         if item in ['modes', 'types']:
-             return super().__getattribute__(item)
-
-         # Search for the attribute in each combined mode
-         for m in self.modes:
-             if hasattr(m, item):
-                 return getattr(m, item)
-
-         # If not found, fall back to default behavior (will raise AttributeError)
-         return super().__getattribute__(item)
-
-
- class Batching(Mode):
-     """
-     Mode indicating batched computation.
-
-     This mode specifies that computations should be performed on batches of data,
-     including information about the batch size and which axis represents the batch
-     dimension.
-
-     Parameters
-     ----------
-     batch_size : int, default 1
-         The size of each batch.
-     batch_axis : int, default 0
-         The axis along which batching occurs.
-
-     Attributes
-     ----------
-     batch_size : int
-         The number of samples in each batch.
-     batch_axis : int
-         The axis index representing the batch dimension.
-
-     Examples
-     --------
-     Basic batching configuration:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>>
-         >>> # Create a batching mode
-         >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=0)
-         >>> print(batching)  # Batching(in_size=32, axis=0)
-         >>>
-         >>> # Access batch parameters
-         >>> print(f"Processing {batching.batch_size} samples at once")
-         >>> print(f"Batch dimension is axis {batching.batch_axis}")
-
-     Using in a model:
-
-     .. code-block:: python
-
-         >>> import jax.numpy as jnp
-         >>>
-         >>> class BatchedModel:
-         ...     def __init__(self):
-         ...         self.mode = None
-         ...
-         ...     def set_batch_mode(self, batch_size, batch_axis=0):
-         ...         self.mode = brainstate.mixin.Batching(batch_size, batch_axis)
-         ...
-         ...     def process(self, x):
-         ...         if self.mode is not None and self.mode.has(brainstate.mixin.Batching):
-         ...             # Process in batches
-         ...             batch_size = self.mode.batch_size
-         ...             axis = self.mode.batch_axis
-         ...             return jnp.mean(x, axis=axis, keepdims=True)
-         ...         return x
-         >>>
-         >>> model = BatchedModel()
-         >>> model.set_batch_mode(batch_size=64)
-         >>>
-         >>> # Process batched data
-         >>> data = jnp.ones((64, 100))  # 64 samples, 100 features
-         >>> result = model.process(data)
-
-     Combining with other modes:
-
-     .. code-block:: python
-
-         >>> # Combine batching with training mode
-         >>> training = brainstate.mixin.Training()
-         >>> batching = brainstate.mixin.Batching(batch_size=128)
-         >>> combined = brainstate.mixin.JointMode(training, batching)
-         >>>
-         >>> # Use in a training loop
-         >>> def train_step(model, data, mode):
-         ...     if mode.has(brainstate.mixin.Batching):
-         ...         # Split data into batches
-         ...         batch_size = mode.batch_size
-         ...         # ... batched processing ...
-         ...     if mode.has(brainstate.mixin.Training):
-         ...         # Apply training-specific operations
-         ...         # ... training logic ...
-         ...         pass
-     """
-
-     def __init__(self, batch_size: int = 1, batch_axis: int = 0):
-         self.batch_size = batch_size
-         self.batch_axis = batch_axis
-
-     def __repr__(self):
-         """
-         String representation showing batch configuration.
-
-         Returns
-         -------
-         str
-             A string showing the batch size and axis.
-         """
-         return f'{self.__class__.__name__}(in_size={self.batch_size}, axis={self.batch_axis})'
-
-
- class Training(Mode):
-     """
-     Mode indicating training computation.
-
-     This mode specifies that the model is in training mode, which typically
-     enables behaviors like dropout, batch normalization in training mode,
-     gradient computation, etc.
-
-     Examples
-     --------
-     Basic training mode:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>>
-         >>> # Create training mode
-         >>> training = brainstate.mixin.Training()
-         >>> print(training)  # Training
-         >>>
-         >>> # Check mode
-         >>> assert training.is_a(brainstate.mixin.Training)
-         >>> assert training.has(brainstate.mixin.Mode)
-
-     Using in a model with dropout:
-
-     .. code-block:: python
-
-         >>> import brainstate
-         >>> import jax
-         >>> import jax.numpy as jnp
-         >>>
-         >>> class ModelWithDropout:
-         ...     def __init__(self, dropout_rate=0.5):
-         ...         self.dropout_rate = dropout_rate
-         ...         self.mode = None
-         ...
-         ...     def set_training(self, is_training=True):
-         ...         if is_training:
-         ...             self.mode = brainstate.mixin.Training()
-         ...         else:
-         ...             self.mode = brainstate.mixin.Mode()  # Evaluation mode
-         ...
-         ...     def forward(self, x, rng_key):
-         ...         # Apply dropout only during training
-         ...         if self.mode is not None and self.mode.has(brainstate.mixin.Training):
-         ...             keep_prob = 1.0 - self.dropout_rate
-         ...             mask = jax.random.bernoulli(rng_key, keep_prob, x.shape)
-         ...             x = jnp.where(mask, x / keep_prob, 0)
-         ...         return x
-         >>>
-         >>> model = ModelWithDropout()
-         >>>
-         >>> # Training mode
-         >>> model.set_training(True)
-         >>> key = jax.random.PRNGKey(0)
-         >>> x_train = jnp.ones((10, 20))
-         >>> out_train = model.forward(x_train, key)  # Dropout applied
-         >>>
-         >>> # Evaluation mode
-         >>> model.set_training(False)
-         >>> out_eval = model.forward(x_train, key)  # No dropout
-
-     Combining with batching:
-
-     .. code-block:: python
-
-         >>> # Create combined training and batching mode
-         >>> training = brainstate.mixin.Training()
-         >>> batching = brainstate.mixin.Batching(batch_size=32)
-         >>> mode = brainstate.mixin.JointMode(training, batching)
-         >>>
-         >>> # Use in training configuration
-         >>> class Trainer:
-         ...     def __init__(self, model, mode):
-         ...         self.model = model
-         ...         self.mode = mode
-         ...
-         ...     def train_epoch(self, data):
-         ...         if self.mode.has(brainstate.mixin.Training):
-         ...             # Enable training-specific behaviors
-         ...             self.model.set_training(True)
-         ...
-         ...         if self.mode.has(brainstate.mixin.Batching):
-         ...             # Process in batches
-         ...             batch_size = self.mode.batch_size
-         ...             # ... batched training loop ...
-         ...             pass
-     """
-     pass
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ """
19
+ Mixin classes and utility types for brainstate.
20
+
21
+ This module provides various mixin classes and custom type definitions that
22
+ enhance the functionality of brainstate components. It includes parameter
23
+ description mixins, alignment interfaces, and custom type definitions for
24
+ expressing complex type requirements.
25
+ """
26
+
27
+ from typing import Sequence, Optional, TypeVar, Union, _GenericAlias
28
+
29
+ import jax
30
+
31
+ __all__ = [
32
+ 'Mixin',
33
+ 'ParamDesc',
34
+ 'ParamDescriber',
35
+ 'JointTypes',
36
+ 'OneOfTypes',
37
+ '_JointGenericAlias',
38
+ '_OneOfGenericAlias',
39
+ 'Mode',
40
+ 'JointMode',
41
+ 'Batching',
42
+ 'Training',
43
+ ]
44
+
45
+ T = TypeVar('T')
46
+ ArrayLike = jax.typing.ArrayLike
47
+
48
+
49
+ def hashable(x):
50
+ """
51
+ Check if an object is hashable.
52
+
53
+ Parameters
54
+ ----------
55
+ x : Any
56
+ The object to check for hashability.
57
+
58
+ Returns
59
+ -------
60
+ bool
61
+ True if the object is hashable, False otherwise.
62
+
63
+ Examples
64
+ --------
65
+ .. code-block:: python
66
+
67
+ >>> import brainstate
68
+ >>>
69
+ >>> # Hashable objects
70
+ >>> assert brainstate.mixin.hashable(42) == True
71
+ >>> assert brainstate.mixin.hashable("string") == True
72
+ >>> assert brainstate.mixin.hashable((1, 2, 3)) == True
73
+ >>>
74
+ >>> # Non-hashable objects
75
+ >>> assert brainstate.mixin.hashable([1, 2, 3]) == False
76
+ >>> assert brainstate.mixin.hashable({"key": "value"}) == False
77
+ """
78
+ try:
79
+ hash(x)
80
+ return True
81
+ except TypeError:
82
+ return False
83
+
84
+
85
+ class Mixin(object):
86
+ """
87
+ Base Mixin object for behavioral extensions.
88
+
89
+ The key characteristic of a :py:class:`~.Mixin` is that it provides only
90
+ behavioral functions without requiring initialization. Mixins are used to
91
+ add specific functionality to classes through multiple inheritance without
92
+ the complexity of a full base class.
93
+
94
+ Notes
95
+ -----
96
+ Mixins should not define ``__init__`` methods. They should only provide
97
+ methods that add specific behaviors to the classes that inherit from them.
98
+
99
+ Examples
100
+ --------
101
+ Creating a custom mixin:
102
+
103
+ .. code-block:: python
104
+
105
+ >>> import brainstate
106
+ >>>
107
+ >>> class LoggingMixin(brainstate.mixin.Mixin):
108
+ ... def log(self, message):
109
+ ... print(f"[{self.__class__.__name__}] {message}")
110
+
111
+ >>> class MyComponent(brainstate.nn.Module, LoggingMixin):
112
+ ... def __init__(self):
113
+ ... super().__init__()
114
+ ...
115
+ ... def process(self):
116
+ ... self.log("Processing data...")
117
+ ... return "Done"
118
+ >>>
119
+ >>> component = MyComponent()
120
+ >>> component.process() # Prints: [MyComponent] Processing data...
121
+ """
122
+ pass
123
+
124
+
125
+ class ParamDesc(Mixin):
126
+ """
127
+ Mixin for describing initialization parameters.
128
+
129
+ This mixin enables a class to have a ``desc`` classmethod, which produces
130
+ an instance of :py:class:`~.ParamDescriber`. This is useful for creating
131
+ parameter templates that can be reused to instantiate multiple objects
132
+ with the same configuration.
133
+
134
+ Attributes
135
+ ----------
136
+ non_hashable_params : sequence of str, optional
137
+ Names of parameters that are not hashable and should be handled specially.
138
+
139
+ Notes
140
+ -----
141
+ This mixin can be applied to any Python class, not just brainstate-specific classes.
142
+
143
+ Examples
144
+ --------
145
+ Basic usage of ParamDesc:
146
+
147
+ .. code-block:: python
148
+
149
+ >>> import brainstate
150
+ >>>
151
+ >>> class NeuronModel(brainstate.mixin.ParamDesc):
152
+ ... def __init__(self, size, tau=10.0, threshold=1.0):
153
+ ... self.size = size
154
+ ... self.tau = tau
155
+ ... self.threshold = threshold
156
+ >>>
157
+ >>> # Create a parameter descriptor
158
+ >>> neuron_desc = NeuronModel.desc(size=100, tau=20.0)
159
+ >>>
160
+ >>> # Use the descriptor to create instances
161
+ >>> neuron1 = neuron_desc(threshold=0.8) # Creates with threshold=0.8
162
+ >>> neuron2 = neuron_desc(threshold=1.2) # Creates with threshold=1.2
163
+ >>>
164
+ >>> # Both neurons share size=100, tau=20.0 but have different thresholds
165
+
166
+ Creating reusable templates:
167
+
168
+ .. code-block:: python
169
+
170
+ >>> # Define a template for excitatory neurons
171
+ >>> exc_neuron_template = NeuronModel.desc(size=1000, tau=10.0, threshold=1.0)
172
+ >>>
173
+ >>> # Define a template for inhibitory neurons
174
+ >>> inh_neuron_template = NeuronModel.desc(size=250, tau=5.0, threshold=0.5)
175
+ >>>
176
+ >>> # Create multiple instances from templates
177
+ >>> exc_population = [exc_neuron_template() for _ in range(5)]
178
+ >>> inh_population = [inh_neuron_template() for _ in range(2)]
179
+ """
180
+
181
+ # Optional list of parameter names that are not hashable
182
+ # These will be converted to strings for hashing purposes
183
+ non_hashable_params: Optional[Sequence[str]] = None
184
+
185
+ @classmethod
186
+ def desc(cls, *args, **kwargs) -> 'ParamDescriber':
187
+ """
188
+ Create a parameter describer for this class.
189
+
190
+ Parameters
191
+ ----------
192
+ *args
193
+ Positional arguments to be used in future instantiations.
194
+ **kwargs
195
+ Keyword arguments to be used in future instantiations.
196
+
197
+ Returns
198
+ -------
199
+ ParamDescriber
200
+ A descriptor that can be used to create instances with these parameters.
201
+ """
202
+ return ParamDescriber(cls, *args, **kwargs)
203
+
204
+
205
+ class HashableDict(dict):
+     """
+     A dictionary that can be hashed by converting non-hashable values to strings.
+
+     This is used internally to make parameter dictionaries hashable so they can
+     be used as part of cache keys or other contexts requiring hashability.
+
+     Parameters
+     ----------
+     the_dict : dict
+         The dictionary to make hashable.
+
+     Notes
+     -----
+     Non-hashable values in the dictionary are automatically converted to their
+     string representation.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax.numpy as jnp
+         >>>
+         >>> # A plain dict cannot be hashed (and this one also holds an unhashable array)
+         >>> regular_dict = {"array": jnp.array([1, 2, 3]), "value": 42}
+         >>> # hash(regular_dict)  # This would raise TypeError
+         >>>
+         >>> # HashableDict can be hashed
+         >>> hashable = brainstate.mixin.HashableDict(regular_dict)
+         >>> key = hash(hashable)  # This works!
+         >>>
+         >>> # Can be used in sets or as dict keys
+         >>> cache = {hashable: "result"}
+     """
+
+     def __init__(self, the_dict: dict):
+         # Process the dictionary to ensure all values are hashable
+         out = dict()
+         for k, v in the_dict.items():
+             if not hashable(v):
+                 # Convert non-hashable values to their string representation
+                 v = str(v)
+             out[k] = v
+         super().__init__(out)
+
+     def __hash__(self):
+         """
+         Compute hash from sorted items for consistent hashing regardless of insertion order.
+         """
+         return hash(tuple(sorted(self.items())))
+
+
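+ # A minimal sketch of the order-insensitive hashing above: two HashableDicts
+ # built from the same items in different insertion orders hash identically
+ # (the list value is unhashable, so it is stringified first).
+ #
+ # >>> a = HashableDict({'tau': 10.0, 'init': [1, 2]})
+ # >>> b = HashableDict({'init': [1, 2], 'tau': 10.0})
+ # >>> assert hash(a) == hash(b)
+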
+ class NoSubclassMeta(type):
+     """
+     Metaclass that prevents a class from being subclassed.
+
+     This is used to ensure that certain classes (like ParamDescriber) are used
+     as-is and not extended through inheritance, which could lead to unexpected
+     behavior.
+
+     Raises
+     ------
+     TypeError
+         If an attempt is made to subclass a class using this metaclass.
+     """
+
+     def __new__(cls, name, bases, classdict):
+         # Check if any base class uses NoSubclassMeta
+         for b in bases:
+             if isinstance(b, NoSubclassMeta):
+                 raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
+         return type.__new__(cls, name, bases, dict(classdict))
+
+
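+ # A small illustration of what the metaclass enforces (class names are
+ # hypothetical): the first definition succeeds, the attempted subclass raises.
+ #
+ # >>> class Locked(metaclass=NoSubclassMeta):
+ # ...     pass
+ # >>> class Child(Locked):  # raises TypeError:
+ # ...     pass              # "type 'Locked' is not an acceptable base type"
+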
+ class ParamDescriber(metaclass=NoSubclassMeta):
+     """
+     Parameter describer for deferred object instantiation.
+
+     This class stores a class reference along with arguments and keyword arguments,
+     allowing for deferred instantiation. It's useful for creating templates that
+     can be reused to create multiple instances with similar configurations.
+
+     Parameters
+     ----------
+     cls : type
+         The class to be instantiated.
+     *desc_tuple
+         Positional arguments to be stored and used during instantiation.
+     **desc_dict
+         Keyword arguments to be stored and used during instantiation.
+
+     Attributes
+     ----------
+     cls : type
+         The class that will be instantiated.
+     args : tuple
+         Stored positional arguments.
+     kwargs : dict
+         Stored keyword arguments.
+     identifier : tuple
+         A hashable identifier for this describer.
+
+     Notes
+     -----
+     ParamDescriber cannot be subclassed due to the NoSubclassMeta metaclass.
+     This ensures consistent behavior across the codebase.
+
+     Examples
+     --------
+     Manual creation of a describer:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> class Network:
+         ...     def __init__(self, n_neurons, learning_rate=0.01):
+         ...         self.n_neurons = n_neurons
+         ...         self.learning_rate = learning_rate
+         >>>
+         >>> # Create a describer
+         >>> network_desc = brainstate.mixin.ParamDescriber(
+         ...     Network, n_neurons=1000, learning_rate=0.001
+         ... )
+         >>>
+         >>> # Use the describer to create instances
+         >>> net1 = network_desc()
+         >>> net2 = network_desc()  # Same configuration
+
+     Using with ParamDesc mixin:
+
+     .. code-block:: python
+
+         >>> class Network(brainstate.mixin.ParamDesc):
+         ...     def __init__(self, n_neurons, learning_rate=0.01):
+         ...         self.n_neurons = n_neurons
+         ...         self.learning_rate = learning_rate
+         >>>
+         >>> # More concise syntax using the desc() classmethod
+         >>> network_desc = Network.desc(n_neurons=1000)
+         >>> net = network_desc(learning_rate=0.005)  # Override learning_rate
+     """
+
+     def __init__(self, cls: type, *desc_tuple, **desc_dict):
+         # Store the class to be instantiated
+         self.cls: type = cls
+
+         # Store the arguments for later instantiation
+         self.args = desc_tuple
+         self.kwargs = desc_dict
+
+         # Create a hashable identifier for caching/comparison purposes.
+         # This combines the class, the args tuple, and a hashable kwargs dict.
+         self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
+
+     def __call__(self, *args, **kwargs) -> T:
+         """
+         Instantiate the class with stored and additional arguments.
+
+         Parameters
+         ----------
+         *args
+             Additional positional arguments to append.
+         **kwargs
+             Additional keyword arguments to merge (will override stored kwargs).
+
+         Returns
+         -------
+         T
+             An instance of the described class.
+         """
+         # Merge stored arguments with new arguments: stored args come first,
+         # then new args; new kwargs override stored ones.
+         merged_kwargs = {**self.kwargs, **kwargs}
+         return self.cls(*self.args, *args, **merged_kwargs)
+
+     def __repr__(self):
+         """
+         Return a string representation of the ParamDescriber.
+
+         Returns
+         -------
+         str
+             A string showing the class and stored parameters.
+         """
+         args_str = ', '.join(repr(a) for a in self.args)
+         kwargs_str = ', '.join(f'{k}={v!r}' for k, v in self.kwargs.items())
+         all_params = ', '.join(filter(None, [args_str, kwargs_str]))
+         return f'ParamDescriber({self.cls.__name__}({all_params}))'
+
+     def init(self, *args, **kwargs):
+         """
+         Alias for __call__, explicitly named for clarity.
+
+         Parameters
+         ----------
+         *args
+             Additional positional arguments.
+         **kwargs
+             Additional keyword arguments.
+
+         Returns
+         -------
+         T
+             An instance of the described class.
+         """
+         return self.__call__(*args, **kwargs)
+
+     def __instancecheck__(self, instance):
+         """
+         Check if an instance is compatible with this describer.
+
+         Parameters
+         ----------
+         instance : Any
+             The instance to check.
+
+         Returns
+         -------
+         bool
+             True if the instance is a ParamDescriber for a compatible class.
+         """
+         # Must be a ParamDescriber
+         if not isinstance(instance, ParamDescriber):
+             return False
+         # The described class must be a subclass of our class
+         if not issubclass(instance.cls, self.cls):
+             return False
+         return True
+
+     @classmethod
+     def __class_getitem__(cls, item: type):
+         """
+         Support for subscript notation: ParamDescriber[MyClass].
+
+         Parameters
+         ----------
+         item : type
+             The class to create a describer for.
+
+         Returns
+         -------
+         ParamDescriber
+             A describer for the given class.
+         """
+         return ParamDescriber(item)
+
+     @property
+     def identifier(self):
+         """
+         Get the unique identifier for this describer.
+
+         Returns
+         -------
+         tuple
+             A hashable identifier consisting of (class, args, kwargs).
+         """
+         return self._identifier
+
+     @identifier.setter
+     def identifier(self, value):
+         """
+         Prevent modification of the identifier.
+
+         Raises
+         ------
+         AttributeError
+             Always, as the identifier is read-only.
+         """
+         raise AttributeError('Cannot set the identifier.')
+
+
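+ # A short sketch of the subscript / instance-check pair above (class names are
+ # illustrative): ParamDescriber[Network] builds a describer for Network, and
+ # __instancecheck__ makes isinstance() test whether another describer targets
+ # Network or one of its subclasses.
+ #
+ # >>> class Network: ...
+ # >>> class RecurrentNetwork(Network): ...
+ # >>> desc = ParamDescriber(RecurrentNetwork, n_neurons=10)
+ # >>> assert isinstance(desc, ParamDescriber[Network])
+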
+ def not_implemented(func):
+     """
+     Decorator to mark a function as not implemented.
+
+     This decorator wraps a function to raise NotImplementedError when called,
+     and adds a ``not_implemented`` attribute for checking.
+
+     Parameters
+     ----------
+     func : callable
+         The function to mark as not implemented.
+
+     Returns
+     -------
+     callable
+         A wrapper function that raises NotImplementedError.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> class BaseModel:
+         ...     @brainstate.mixin.not_implemented
+         ...     def process(self, x):
+         ...         pass
+         >>>
+         >>> model = BaseModel()
+         >>> # model.process(10)  # Raises: NotImplementedError: process is not implemented.
+         >>>
+         >>> # Check if a method is not implemented
+         >>> assert hasattr(BaseModel.process, 'not_implemented')
+     """
+
+     def wrapper(*args, **kwargs):
+         raise NotImplementedError(f'{func.__name__} is not implemented.')
+
+     # Mark the wrapper so we can detect not-implemented methods
+     wrapper.not_implemented = True
+     return wrapper
+
+
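+ # A small sketch of how the marker attribute can be consumed (the helper name
+ # is illustrative, not part of this module): list the methods of a class that
+ # are still marked as not implemented.
+ #
+ # >>> def missing_methods(cls):
+ # ...     return [name for name, fn in vars(cls).items()
+ # ...             if callable(fn) and getattr(fn, 'not_implemented', False)]
+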
+ class _JointGenericAlias(_GenericAlias, _root=True):
+     """
+     Generic alias for JointTypes (intersection types).
+
+     This class represents a type that requires all specified types to be
+     satisfied. Unlike _MetaUnionType, which creates actual classes with
+     metaclass conflicts, this uses typing's generic alias system to avoid
+     metaclass issues.
+     """
+
+     def __instancecheck__(self, obj):
+         """
+         Check if an instance is an instance of all component types.
+         """
+         return all(isinstance(obj, cls) for cls in self.__args__)
+
+     def __subclasscheck__(self, subclass):
+         """
+         Check if a class is a subclass of all component types.
+         """
+         return all(issubclass(subclass, cls) for cls in self.__args__)
+
+     def __eq__(self, other):
+         """
+         Check equality with another type.
+
+         Two JointTypes are equal if they have the same component types,
+         regardless of order.
+         """
+         if not isinstance(other, _JointGenericAlias):
+             return NotImplemented
+         return set(self.__args__) == set(other.__args__)
+
+     def __hash__(self):
+         """
+         Return the hash of the JointType.
+
+         The hash is based on the frozenset of component types to ensure
+         that JointTypes with the same types (regardless of order) have
+         the same hash.
+         """
+         return hash(frozenset(self.__args__))
+
+     def __repr__(self):
+         """
+         Return the string representation of the JointType.
+
+         Returns a readable representation showing all component types.
+         """
+         args_str = ', '.join(
+             arg.__module__ + '.' + arg.__name__
+             if hasattr(arg, '__module__') and hasattr(arg, '__name__')
+             else str(arg)
+             for arg in self.__args__
+         )
+         return f'JointTypes[{args_str}]'
+
+     def __reduce__(self):
+         """
+         Support for pickling.
+
+         Returns the information needed to reconstruct the JointType
+         when unpickling.
+         """
+         return (_JointGenericAlias, (self.__origin__, self.__args__))
+
+
+ class _OneOfGenericAlias(_GenericAlias, _root=True):
+     """
+     Generic alias for OneOfTypes (union types).
+
+     This class represents a type that requires at least one of the specified
+     types to be satisfied. It's similar to typing.Union but provides a
+     consistent interface with JointTypes and avoids potential metaclass
+     conflicts.
+     """
+
+     def __instancecheck__(self, obj):
+         """
+         Check if an instance is an instance of any component type.
+         """
+         return any(isinstance(obj, cls) for cls in self.__args__)
+
+     def __subclasscheck__(self, subclass):
+         """
+         Check if a class is a subclass of any component type.
+         """
+         return any(issubclass(subclass, cls) for cls in self.__args__)
+
+     def __eq__(self, other):
+         """
+         Check equality with another type.
+
+         Two OneOfTypes are equal if they have the same component types,
+         regardless of order.
+         """
+         if not isinstance(other, _OneOfGenericAlias):
+             return NotImplemented
+         return set(self.__args__) == set(other.__args__)
+
+     def __hash__(self):
+         """
+         Return the hash of the OneOfType.
+
+         The hash is based on the frozenset of component types to ensure
+         that OneOfTypes with the same types (regardless of order) have
+         the same hash.
+         """
+         return hash(frozenset(self.__args__))
+
+     def __repr__(self):
+         """
+         Return the string representation of the OneOfType.
+
+         Returns a readable representation showing all component types.
+         """
+         args_str = ', '.join(
+             arg.__module__ + '.' + arg.__name__
+             if hasattr(arg, '__module__') and hasattr(arg, '__name__')
+             else str(arg)
+             for arg in self.__args__
+         )
+         return f'OneOfTypes[{args_str}]'
+
+     def __reduce__(self):
+         """
+         Support for pickling.
+
+         Returns the information needed to reconstruct the OneOfType
+         when unpickling.
+         """
+         return (_OneOfGenericAlias, (self.__origin__, self.__args__))
+
+
+ class _JointTypesClass:
+     """Helper class to enable subscript syntax for JointTypes."""
+
+     def __call__(self, *types):
+         """
+         Create a type that requires all specified types (intersection type).
+
+         This function creates a type hint that indicates a value must satisfy all
+         the specified types simultaneously. It's useful for expressing complex
+         type requirements where a single object must implement multiple interfaces.
+
+         Parameters
+         ----------
+         *types : type
+             The types that must all be satisfied.
+
+         Returns
+         -------
+         type
+             A type that checks for all specified types.
+
+         Notes
+         -----
+         - If only one type is provided, that type is returned directly.
+         - Redundant types are automatically removed.
+         - The order of types doesn't matter for equality checks.
+
+         Examples
+         --------
+         Basic usage with interfaces:
+
+         .. code-block:: python
+
+             >>> import brainstate
+             >>> from typing import Protocol
+             >>>
+             >>> class Trainable(Protocol):
+             ...     def train(self): ...
+             >>>
+             >>> class Evaluable(Protocol):
+             ...     def evaluate(self): ...
+             >>>
+             >>> # A model that is both trainable and evaluable
+             >>> TrainableEvaluableModel = brainstate.mixin.JointTypes(Trainable, Evaluable)
+             >>> # Or using subscript syntax
+             >>> TrainableEvaluableModel = brainstate.mixin.JointTypes[Trainable, Evaluable]
+             >>>
+             >>> class NeuralNetwork(Trainable, Evaluable):
+             ...     def train(self):
+             ...         return "Training..."
+             ...
+             ...     def evaluate(self):
+             ...         return "Evaluating..."
+             >>>
+             >>> model = NeuralNetwork()
+             >>> # model satisfies JointTypes(Trainable, Evaluable)
+
+         Using with mixin classes:
+
+         .. code-block:: python
+
+             >>> class Serializable:
+             ...     def save(self): pass
+             >>>
+             >>> class Visualizable:
+             ...     def plot(self): pass
+             >>>
+             >>> # Require both serialization and visualization
+             >>> FullFeaturedModel = brainstate.mixin.JointTypes[Serializable, Visualizable]
+             >>>
+             >>> class MyModel(Serializable, Visualizable):
+             ...     def save(self):
+             ...         return "Saved"
+             ...
+             ...     def plot(self):
+             ...         return "Plotted"
+         """
+         if len(types) == 0:
+             raise TypeError("Cannot create a JointTypes of no types.")
+
+         # Remove duplicates while preserving order
+         seen = set()
+         unique_types = []
+         for t in types:
+             if t not in seen:
+                 seen.add(t)
+                 unique_types.append(t)
+
+         # If only one type remains, return it directly
+         if len(unique_types) == 1:
+             return unique_types[0]
+
+         # Create a generic alias for the joint type.
+         # This avoids metaclass conflicts by using typing's generic alias system.
+         return _JointGenericAlias(object, tuple(unique_types))
+
+     def __getitem__(self, item):
+         """Enable subscript syntax: JointTypes[Type1, Type2]."""
+         if isinstance(item, tuple):
+             return self(*item)
+         else:
+             return self(item)
+
+
+ # Singleton instance that is both callable and subscriptable
+ JointTypes = _JointTypesClass()
+
+
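+ # Two small consequences of the construction above, as a sketch: duplicates
+ # collapse before the alias is built, and a single surviving type is returned
+ # unchanged rather than wrapped.
+ #
+ # >>> assert JointTypes[int, int] is int
+ # >>> JointTypes[ParamDesc, Mixin] == JointTypes[Mixin, ParamDesc]  # order-free
+ # True
+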
+ class _OneOfTypesClass:
+     """Helper class to enable subscript syntax for OneOfTypes."""
+
+     def __call__(self, *types):
+         """
+         Create a type that requires one of the specified types (union type).
+
+         This is similar to typing.Union but provides a more intuitive name and
+         consistent behavior with JointTypes. It indicates that a value must satisfy
+         at least one of the specified types.
+
+         Parameters
+         ----------
+         *types : type
+             The types, one of which must be satisfied.
+
+         Returns
+         -------
+         Union type
+             A union type of the specified types.
+
+         Notes
+         -----
+         - If only one type is provided, that type is returned directly.
+         - Redundant types are automatically removed.
+         - The order of types doesn't matter for equality checks.
+         - This is equivalent to typing.Union[...].
+
+         Examples
+         --------
+         Basic usage with different types:
+
+         .. code-block:: python
+
+             >>> import brainstate
+             >>>
+             >>> # A parameter that can be int or float
+             >>> NumericType = brainstate.mixin.OneOfTypes(int, float)
+             >>> # Or using subscript syntax
+             >>> NumericType = brainstate.mixin.OneOfTypes[int, float]
+             >>>
+             >>> def process_value(x: NumericType):
+             ...     return x * 2
+             >>>
+             >>> # Both work
+             >>> result1 = process_value(5)     # int
+             >>> result2 = process_value(3.14)  # float
+
+         Using with class types:
+
+         .. code-block:: python
+
+             >>> class NumpyArray:
+             ...     pass
+             >>>
+             >>> class JAXArray:
+             ...     pass
+             >>>
+             >>> # Accept either numpy or JAX arrays
+             >>> ArrayType = brainstate.mixin.OneOfTypes[NumpyArray, JAXArray]
+             >>>
+             >>> def compute(arr: ArrayType):
+             ...     if isinstance(arr, NumpyArray):
+             ...         return "Processing numpy array"
+             ...     elif isinstance(arr, JAXArray):
+             ...         return "Processing JAX array"
+
+         Combining with None for optional types:
+
+         .. code-block:: python
+
+             >>> # Optional string (equivalent to Optional[str])
+             >>> MaybeString = brainstate.mixin.OneOfTypes[str, type(None)]
+             >>>
+             >>> def format_name(name: MaybeString) -> str:
+             ...     if name is None:
+             ...         return "Anonymous"
+             ...     return name.title()
+         """
+         if len(types) == 0:
+             raise TypeError("Cannot create a OneOfTypes of no types.")
+
+         # Remove duplicates while preserving order
+         seen = set()
+         unique_types = []
+         for t in types:
+             if t not in seen:
+                 seen.add(t)
+                 unique_types.append(t)
+
+         # If only one type remains, return it directly
+         if len(unique_types) == 1:
+             return unique_types[0]
+
+         # Create a generic alias for the union type.
+         # This provides consistency with JointTypes and avoids metaclass conflicts.
+         return _OneOfGenericAlias(Union, tuple(unique_types))
+
+     def __getitem__(self, item):
+         """Enable subscript syntax: OneOfTypes[Type1, Type2]."""
+         if isinstance(item, tuple):
+             return self(*item)
+         else:
+             return self(item)
+
+
+ # Singleton instance that is both callable and subscriptable
+ OneOfTypes = _OneOfTypesClass()
+
+
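+ # A quick sketch of the union semantics at runtime: the alias answers
+ # isinstance() and issubclass() through any-of checks over its components.
+ #
+ # >>> Numeric = OneOfTypes[int, float]
+ # >>> assert isinstance(3, Numeric) and isinstance(3.14, Numeric)
+ # >>> assert not isinstance('3', Numeric)
+ # >>> assert issubclass(bool, Numeric)  # bool subclasses int
+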
+ def __getattr__(name):
+     if name in [
+         'Mode',
+         'JointMode',
+         'Batching',
+         'Training',
+         'AlignPost',
+         'BindCondData',
+     ]:
+         import warnings
+         warnings.warn(
+             f"brainstate.mixin.{name} is deprecated and will be removed in a future version. "
+             f"Please use brainpy.mixin.{name} instead.",
+             DeprecationWarning,
+             stacklevel=2
+         )
+         import brainpy
+         return getattr(brainpy.mixin, name)
+     raise AttributeError(
+         f'module {__name__!r} has no attribute {name!r}'
+     )
+
+
+ class Mode(Mixin):
+     """
+     Base class for computation behavior modes.
+
+     Modes are used to represent different computational contexts or behaviors,
+     such as training vs. evaluation, batched vs. single-sample processing, etc.
+     They provide a flexible way to configure how models and components behave
+     in different scenarios.
+
+     Examples
+     --------
+     Creating a custom mode:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> class InferenceMode(brainstate.mixin.Mode):
+         ...     def __init__(self, use_cache=True):
+         ...         self.use_cache = use_cache
+         >>>
+         >>> # Create mode instances
+         >>> inference = InferenceMode(use_cache=True)
+         >>> print(inference)  # Output: InferenceMode
+
+     Checking mode types:
+
+     .. code-block:: python
+
+         >>> class FastMode(brainstate.mixin.Mode):
+         ...     pass
+         >>>
+         >>> class SlowMode(brainstate.mixin.Mode):
+         ...     pass
+         >>>
+         >>> fast = FastMode()
+         >>> slow = SlowMode()
+         >>>
+         >>> # Check exact mode type
+         >>> assert fast.is_a(FastMode)
+         >>> assert not fast.is_a(SlowMode)
+         >>>
+         >>> # Check if mode is an instance of a type
+         >>> assert fast.has(brainstate.mixin.Mode)
+
+     Using modes in a model:
+
+     .. code-block:: python
+
+         >>> class Model:
+         ...     def __init__(self):
+         ...         self.mode = brainstate.mixin.Training()
+         ...
+         ...     def forward(self, x):
+         ...         if self.mode.has(brainstate.mixin.Training):
+         ...             # Training-specific logic
+         ...             return self.train_forward(x)
+         ...         else:
+         ...             # Inference logic
+         ...             return self.eval_forward(x)
+         ...
+         ...     def train_forward(self, x):
+         ...         return x + 0.1  # Add noise during training
+         ...
+         ...     def eval_forward(self, x):
+         ...         return x  # No noise during evaluation
+     """
+
+     def __repr__(self):
+         """
+         String representation of the mode.
+
+         Returns
+         -------
+         str
+             The class name of the mode.
+         """
+         return self.__class__.__name__
+
+     def __eq__(self, other):
+         """
+         Check equality of modes based on their type.
+
+         Parameters
+         ----------
+         other : Mode
+             Another mode to compare with.
+
+         Returns
+         -------
+         bool
+             True if both modes are of the same class.
+         """
+         if not isinstance(other, Mode):
+             return NotImplemented
+         return other.__class__ == self.__class__
+
+     def is_a(self, mode: type):
+         """
+         Check whether the mode is exactly the desired mode type.
+
+         This performs an exact type match, without accepting subclasses.
+
+         Parameters
+         ----------
+         mode : type
+             The mode type to check against.
+
+         Returns
+         -------
+         bool
+             True if this mode is exactly of the specified type.
+
+         Examples
+         --------
+         .. code-block:: python
+
+             >>> import brainstate
+             >>>
+             >>> training_mode = brainstate.mixin.Training()
+             >>> assert training_mode.is_a(brainstate.mixin.Training)
+             >>> assert not training_mode.is_a(brainstate.mixin.Batching)
+         """
+         assert isinstance(mode, type), 'Must be a type.'
+         return self.__class__ == mode
+
+     def has(self, mode: type):
+         """
+         Check whether the mode includes the desired mode type.
+
+         This checks if the current mode is an instance of the specified type,
+         including subclasses.
+
+         Parameters
+         ----------
+         mode : type
+             The mode type to check for.
+
+         Returns
+         -------
+         bool
+             True if this mode is an instance of the specified type.
+
+         Examples
+         --------
+         .. code-block:: python
+
+             >>> import brainstate
+             >>>
+             >>> # Create a custom mode that extends Training
+             >>> class AdvancedTraining(brainstate.mixin.Training):
+             ...     pass
+             >>>
+             >>> advanced = AdvancedTraining()
+             >>> assert advanced.has(brainstate.mixin.Training)  # True (subclass)
+             >>> assert advanced.has(brainstate.mixin.Mode)      # True (base class)
+         """
+         assert isinstance(mode, type), 'Must be a type.'
+         return isinstance(self, mode)
+
+
+ class JointMode(Mode):
+     """
+     A mode that combines multiple modes simultaneously.
+
+     JointMode allows expressing that a computation is in multiple modes at once,
+     such as being both in training mode and batching mode. This is useful for
+     complex scenarios where multiple behavioral aspects need to be active.
+
+     Parameters
+     ----------
+     *modes : Mode
+         The modes to combine.
+
+     Attributes
+     ----------
+     modes : tuple of Mode
+         The individual modes that are combined.
+     types : set of type
+         The types of the combined modes.
+
+     Raises
+     ------
+     TypeError
+         If any of the provided arguments is not a Mode instance.
+
+     Examples
+     --------
+     Combining training and batching modes:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> # Create individual modes
+         >>> training = brainstate.mixin.Training()
+         >>> batching = brainstate.mixin.Batching(batch_size=32)
+         >>>
+         >>> # Combine them
+         >>> joint = brainstate.mixin.JointMode(training, batching)
+         >>> print(joint)  # JointMode(Training, Batching(batch_size=32, batch_axis=0))
+         >>>
+         >>> # Check if specific modes are present
+         >>> assert joint.has(brainstate.mixin.Training)
+         >>> assert joint.has(brainstate.mixin.Batching)
+         >>>
+         >>> # Access attributes from combined modes
+         >>> print(joint.batch_size)  # 32 (from Batching mode)
+
+     Using in model configuration:
+
+     .. code-block:: python
+
+         >>> class NeuralNetwork:
+         ...     def __init__(self):
+         ...         self.mode = None
+         ...
+         ...     def set_train_mode(self, batch_size=1):
+         ...         # Set both training and batching modes
+         ...         training = brainstate.mixin.Training()
+         ...         batching = brainstate.mixin.Batching(batch_size=batch_size)
+         ...         self.mode = brainstate.mixin.JointMode(training, batching)
+         ...
+         ...     def forward(self, x):
+         ...         if self.mode.has(brainstate.mixin.Training):
+         ...             x = self.apply_dropout(x)
+         ...
+         ...         if self.mode.has(brainstate.mixin.Batching):
+         ...             # Process in batches
+         ...             batch_size = self.mode.batch_size
+         ...             return self.batch_process(x, batch_size)
+         ...
+         ...         return self.process(x)
+         >>>
+         >>> model = NeuralNetwork()
+         >>> model.set_train_mode(batch_size=64)
+     """
+
+     def __init__(self, *modes: Mode):
+         # Validate that all arguments are Mode instances
+         for m_ in modes:
+             if not isinstance(m_, Mode):
+                 raise TypeError(f'All arguments must be Mode instances, but got {m_}.')
+
+         # Store the modes as a tuple
+         self.modes = tuple(modes)
+
+         # Store the types of the modes for quick lookup
+         self.types = set([m.__class__ for m in modes])
+
+     def __repr__(self):
+         """
+         String representation showing all combined modes.
+
+         Returns
+         -------
+         str
+             A string showing the joint mode and its components.
+         """
+         return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'
+
+     def has(self, mode: type):
+         """
+         Check whether any of the combined modes includes the desired type.
+
+         Parameters
+         ----------
+         mode : type
+             The mode type to check for.
+
+         Returns
+         -------
+         bool
+             True if any of the combined modes is or inherits from the specified type.
+
+         Examples
+         --------
+         .. code-block:: python
+
+             >>> import brainstate
+             >>>
+             >>> training = brainstate.mixin.Training()
+             >>> batching = brainstate.mixin.Batching(batch_size=16)
+             >>> joint = brainstate.mixin.JointMode(training, batching)
+             >>>
+             >>> assert joint.has(brainstate.mixin.Training)
+             >>> assert joint.has(brainstate.mixin.Batching)
+             >>> assert joint.has(brainstate.mixin.Mode)  # Base class
+         """
+         assert isinstance(mode, type), 'Must be a type.'
+         # Check if any of the combined mode types is a subclass of the target mode
+         return any([issubclass(cls, mode) for cls in self.types])
+
+     def is_a(self, cls: type):
+         """
+         Check whether the joint mode is exactly the desired combined type.
+
+         This check verifies that the joint mode matches a specific
+         combination of types.
+
+         Parameters
+         ----------
+         cls : type
+             The combined type to check against.
+
+         Returns
+         -------
+         bool
+             True if the joint mode exactly matches the specified type combination.
+         """
+         # Use JointTypes to create the expected type from our mode types
+         return JointTypes(*tuple(self.types)) == cls
+
+     def __getattr__(self, item):
+         """
+         Get attributes from the combined modes.
+
+         This method searches through all combined modes to find the requested
+         attribute, allowing transparent access to properties of any of the
+         combined modes.
+
+         Parameters
+         ----------
+         item : str
+             The attribute name to search for.
+
+         Returns
+         -------
+         Any
+             The attribute value from the first mode that has it.
+
+         Raises
+         ------
+         AttributeError
+             If the attribute is not found in any of the combined modes.
+
+         Examples
+         --------
+         .. code-block:: python
+
+             >>> import brainstate
+             >>>
+             >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
+             >>> training = brainstate.mixin.Training()
+             >>> joint = brainstate.mixin.JointMode(batching, training)
+             >>>
+             >>> # Access batching attributes directly
+             >>> print(joint.batch_size)  # 32
+             >>> print(joint.batch_axis)  # 1
+         """
+         # Don't interfere with access to the modes and types attributes
+         if item in ['modes', 'types']:
+             return super().__getattribute__(item)
+
+         # Search for the attribute in each combined mode
+         for m in self.modes:
+             if hasattr(m, item):
+                 return getattr(m, item)
+
+         # If not found, fall back to the default behavior (raises AttributeError)
+         return super().__getattribute__(item)
+
+
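+ # A sketch connecting JointMode.is_a with JointTypes (Training and Batching
+ # are the mode classes defined below): the exact-match check compares against
+ # the intersection type of the combined mode classes, order-insensitively.
+ #
+ # >>> joint = JointMode(Training(), Batching(batch_size=8))
+ # >>> assert joint.is_a(JointTypes[Training, Batching])
+ # >>> assert joint.is_a(JointTypes[Batching, Training])  # order ignored
+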
+ class Batching(Mode):
+     """
+     Mode indicating batched computation.
+
+     This mode specifies that computations should be performed on batches of data,
+     including information about the batch size and which axis represents the batch
+     dimension.
+
+     Parameters
+     ----------
+     batch_size : int, default 1
+         The size of each batch.
+     batch_axis : int, default 0
+         The axis along which batching occurs.
+
+     Attributes
+     ----------
+     batch_size : int
+         The number of samples in each batch.
+     batch_axis : int
+         The axis index representing the batch dimension.
+
+     Examples
+     --------
+     Basic batching configuration:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> # Create a batching mode
+         >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=0)
+         >>> print(batching)  # Batching(batch_size=32, batch_axis=0)
+         >>>
+         >>> # Access batch parameters
+         >>> print(f"Processing {batching.batch_size} samples at once")
+         >>> print(f"Batch dimension is axis {batching.batch_axis}")
+
+     Using in a model:
+
+     .. code-block:: python
+
+         >>> import jax
+         >>> import jax.numpy as jnp
+         >>>
+         >>> class BatchedModel:
+         ...     def __init__(self):
+         ...         self.mode = None
+         ...
+         ...     def set_batch_mode(self, batch_size, batch_axis=0):
+         ...         self.mode = brainstate.mixin.Batching(batch_size, batch_axis)
+         ...
+         ...     def process(self, x):
+         ...         if self.mode is not None and self.mode.has(brainstate.mixin.Batching):
+         ...             # Process in batches
+         ...             batch_size = self.mode.batch_size
+         ...             axis = self.mode.batch_axis
+         ...             return jnp.mean(x, axis=axis, keepdims=True)
+         ...         return x
+         >>>
+         >>> model = BatchedModel()
+         >>> model.set_batch_mode(batch_size=64)
+         >>>
+         >>> # Process batched data: 64 samples, 100 features
+         >>> data = jax.random.normal(jax.random.PRNGKey(0), (64, 100))
+         >>> result = model.process(data)
+
+     Combining with other modes:
+
+     .. code-block:: python
+
+         >>> # Combine batching with training mode
+         >>> training = brainstate.mixin.Training()
+         >>> batching = brainstate.mixin.Batching(batch_size=128)
+         >>> combined = brainstate.mixin.JointMode(training, batching)
+         >>>
+         >>> # Use in a training loop
+         >>> def train_step(model, data, mode):
+         ...     if mode.has(brainstate.mixin.Batching):
+         ...         # Split data into batches
+         ...         batch_size = mode.batch_size
+         ...         # ... batched processing ...
+         ...     if mode.has(brainstate.mixin.Training):
+         ...         # Apply training-specific operations
+         ...         # ... training logic ...
+         ...         pass
+     """
+
+     def __init__(self, batch_size: int = 1, batch_axis: int = 0):
+         self.batch_size = batch_size
+         self.batch_axis = batch_axis
+
+     def __repr__(self):
+         """
+         String representation showing the batch configuration.
+
+         Returns
+         -------
+         str
+             A string showing the batch size and axis.
+         """
+         return f'{self.__class__.__name__}(batch_size={self.batch_size}, batch_axis={self.batch_axis})'
+
+
+ class Training(Mode):
+     """
+     Mode indicating training computation.
+
+     This mode specifies that the model is in training mode, which typically
+     enables behaviors like dropout, batch normalization in training mode,
+     gradient computation, etc.
+
+     Examples
+     --------
+     Basic training mode:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>>
+         >>> # Create training mode
+         >>> training = brainstate.mixin.Training()
+         >>> print(training)  # Training
+         >>>
+         >>> # Check mode
+         >>> assert training.is_a(brainstate.mixin.Training)
+         >>> assert training.has(brainstate.mixin.Mode)
+
+     Using in a model with dropout:
+
+     .. code-block:: python
+
+         >>> import brainstate
+         >>> import jax
+         >>> import jax.numpy as jnp
+         >>>
+         >>> class ModelWithDropout:
+         ...     def __init__(self, dropout_rate=0.5):
+         ...         self.dropout_rate = dropout_rate
+         ...         self.mode = None
+         ...
+         ...     def set_training(self, is_training=True):
+         ...         if is_training:
+         ...             self.mode = brainstate.mixin.Training()
+         ...         else:
+         ...             self.mode = brainstate.mixin.Mode()  # Evaluation mode
+         ...
+         ...     def forward(self, x, rng_key):
+         ...         # Apply dropout only during training
+         ...         if self.mode is not None and self.mode.has(brainstate.mixin.Training):
+         ...             keep_prob = 1.0 - self.dropout_rate
+         ...             mask = jax.random.bernoulli(rng_key, keep_prob, x.shape)
+         ...             x = jnp.where(mask, x / keep_prob, 0)
+         ...         return x
+         >>>
+         >>> model = ModelWithDropout()
+         >>>
+         >>> # Training mode
+         >>> model.set_training(True)
+         >>> key = jax.random.PRNGKey(0)
+         >>> x_train = jnp.ones((10, 20))
+         >>> out_train = model.forward(x_train, key)  # Dropout applied
+         >>>
+         >>> # Evaluation mode
+         >>> model.set_training(False)
+         >>> out_eval = model.forward(x_train, key)  # No dropout
+
+     Combining with batching:
+
+     .. code-block:: python
+
+         >>> # Create combined training and batching mode
+         >>> training = brainstate.mixin.Training()
+         >>> batching = brainstate.mixin.Batching(batch_size=32)
+         >>> mode = brainstate.mixin.JointMode(training, batching)
+         >>>
+         >>> # Use in training configuration
+         >>> class Trainer:
+         ...     def __init__(self, model, mode):
+         ...         self.model = model
+         ...         self.mode = mode
+         ...
+         ...     def train_epoch(self, data):
+         ...         if self.mode.has(brainstate.mixin.Training):
+         ...             # Enable training-specific behaviors
+         ...             self.model.set_training(True)
+         ...
+         ...         if self.mode.has(brainstate.mixin.Batching):
+         ...             # Process in batches
+         ...             batch_size = self.mode.batch_size
+         ...             # ... batched training loop ...
+         ...             pass
+     """
+     pass