brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff shows the contents of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
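The dominant pattern in this release is consolidation: the former compile/ and augment/ packages merge into a single transform/ package, and the init/ package folds into nn/. A minimal before/after import sketch, assuming the public names simply follow the module moves listed above (the exact re-exports in the released wheels are not verified here; `jit` and `grad` are illustrative):

# 0.1.10-era layout (hypothetical usage)
# from brainstate.compile import jit      # compile/_jit.py
# from brainstate.augment import grad     # augment/_autograd.py
# from brainstate import init             # init/_random_inits.py

# 0.2.1-era layout: compile/ and augment/ continue as transform/,
# and the initializers move under nn/
from brainstate import transform          # transform/_jit.py, transform/_autograd.py, ...
from brainstate.nn import init            # nn/init.py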
brainstate/mixin.py CHANGED
@@ -1,365 +1,1433 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from typing import (
19
- Sequence, Optional, TypeVar, _SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias
20
- )
21
-
22
- import jax
23
-
24
- T = TypeVar('T')
25
- ArrayLike = jax.typing.ArrayLike
26
-
27
- __all__ = [
28
- 'Mixin',
29
- 'ParamDesc',
30
- 'ParamDescriber',
31
- 'AlignPost',
32
- 'BindCondData',
33
-
34
- # types
35
- 'JointTypes',
36
- 'OneOfTypes',
37
-
38
- # behavior modes
39
- 'Mode',
40
- 'JointMode',
41
- 'Batching',
42
- 'Training',
43
- ]
44
-
45
-
46
- def hashable(x):
47
- try:
48
- hash(x)
49
- return True
50
- except TypeError:
51
- return False
52
-
53
-
54
- class Mixin(object):
55
- """Base Mixin object.
56
-
57
- The key for a :py:class:`~.Mixin` is that: no initialization function, only behavioral functions.
58
- """
59
- pass
60
-
61
-
62
- class ParamDesc(Mixin):
63
- """
64
- :py:class:`~.Mixin` indicates the function for describing initialization parameters.
65
-
66
- This mixin enables the subclass has a classmethod ``desc``, which
67
- produces an instance of :py:class:`~.ParamDescriber`.
68
-
69
- Note this Mixin can be applied in any Python object.
70
- """
71
-
72
- non_hashable_params: Optional[Sequence[str]] = None
73
-
74
- @classmethod
75
- def desc(cls, *args, **kwargs) -> 'ParamDescriber':
76
- return ParamDescriber(cls, *args, **kwargs)
77
-
78
-
79
- class HashableDict(dict):
80
- def __init__(self, the_dict: dict):
81
- out = dict()
82
- for k, v in the_dict.items():
83
- if not hashable(v):
84
- v = str(v) # convert to string if not hashable
85
- out[k] = v
86
- super().__init__(out)
87
-
88
- def __hash__(self):
89
- return hash(tuple(sorted(self.items())))
90
-
91
-
92
- class NoSubclassMeta(type):
93
- def __new__(cls, name, bases, classdict):
94
- for b in bases:
95
- if isinstance(b, NoSubclassMeta):
96
- raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
97
- return type.__new__(cls, name, bases, dict(classdict))
98
-
99
-
100
- class ParamDescriber(metaclass=NoSubclassMeta):
101
- """
102
- ParamDesc initialization for parameter describers.
103
- """
104
-
105
- def __init__(self, cls: T, *desc_tuple, **desc_dict):
106
- self.cls: type = cls
107
-
108
- # arguments
109
- self.args = desc_tuple
110
- self.kwargs = desc_dict
111
-
112
- # identifier
113
- self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
114
-
115
- def __call__(self, *args, **kwargs) -> T:
116
- return self.cls(*self.args, *args, **self.kwargs, **kwargs)
117
-
118
- def init(self, *args, **kwargs):
119
- return self.__call__(*args, **kwargs)
120
-
121
- def __instancecheck__(self, instance):
122
- if not isinstance(instance, ParamDescriber):
123
- return False
124
- if not issubclass(instance.cls, self.cls):
125
- return False
126
- return True
127
-
128
- @classmethod
129
- def __class_getitem__(cls, item: type):
130
- return ParamDescriber(item)
131
-
132
- @property
133
- def identifier(self):
134
- return self._identifier
135
-
136
- @identifier.setter
137
- def identifier(self, value: ArrayLike):
138
- raise AttributeError('Cannot set the identifier.')
139
-
140
-
141
- class AlignPost(Mixin):
142
- """
143
- Align post MixIn.
144
-
145
- This class provides a ``align_post_input_add()`` function for
146
- add external currents.
147
- """
148
-
149
- def align_post_input_add(self, *args, **kwargs):
150
- raise NotImplementedError
151
-
152
-
153
- class BindCondData(Mixin):
154
- """Bind temporary conductance data.
155
-
156
-
157
- """
158
- _conductance: Optional
159
-
160
- def bind_cond(self, conductance):
161
- self._conductance = conductance
162
-
163
- def unbind_cond(self):
164
- self._conductance = None
165
-
166
-
167
- def not_implemented(func):
168
- def wrapper(*args, **kwargs):
169
- raise NotImplementedError(f'{func.__name__} is not implemented.')
170
-
171
- wrapper.not_implemented = True
172
- return wrapper
173
-
174
-
175
- class _MetaUnionType(type):
176
- def __new__(cls, name, bases, dct):
177
- if isinstance(bases, type):
178
- bases = (bases,)
179
- elif isinstance(bases, (list, tuple)):
180
- bases = tuple(bases)
181
- for base in bases:
182
- assert isinstance(base, type), f'Must be type. But got {base}'
183
- else:
184
- raise TypeError(f'Must be type. But got {bases}')
185
- return super().__new__(cls, name, bases, dct)
186
-
187
- def __instancecheck__(self, other):
188
- cls_of_other = other.__class__
189
- return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
190
-
191
- def __subclasscheck__(self, subclass):
192
- return all([issubclass(subclass, cls) for cls in self.__bases__])
193
-
194
-
195
- class _JointGenericAlias(_UnionGenericAlias, _root=True):
196
- def __subclasscheck__(self, subclass):
197
- return all([issubclass(subclass, cls) for cls in set(self.__args__)])
198
-
199
-
200
- @_SpecialForm
201
- def JointTypes(self, parameters):
202
- """Joint types; JointTypes[X, Y] means both X and Y.
203
-
204
- To define a union, use e.g. Union[int, str].
205
-
206
- Details:
207
- - The arguments must be types and there must be at least one.
208
- - None as an argument is a special case and is replaced by `type(None)`.
209
- - Unions of unions are flattened, e.g.::
210
-
211
- JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
212
-
213
- - Unions of a single argument vanish, e.g.::
214
-
215
- JointTypes[int] == int # The constructor actually returns int
216
-
217
- - Redundant arguments are skipped, e.g.::
218
-
219
- JointTypes[int, str, int] == JointTypes[int, str]
220
-
221
- - When comparing unions, the argument order is ignored, e.g.::
222
-
223
- JointTypes[int, str] == JointTypes[str, int]
224
-
225
- - You cannot subclass or instantiate a JointTypes.
226
- - You can use Optional[X] as a shorthand for JointTypes[X, None].
227
- """
228
- if parameters == ():
229
- raise TypeError("Cannot take a Joint of no types.")
230
- if not isinstance(parameters, tuple):
231
- parameters = (parameters,)
232
- msg = "JointTypes[arg, ...]: each arg must be a type."
233
- parameters = tuple(_type_check(p, msg) for p in parameters)
234
- parameters = _remove_dups_flatten(parameters)
235
- if len(parameters) == 1:
236
- return parameters[0]
237
- if len(parameters) == 2 and type(None) in parameters:
238
- return _UnionGenericAlias(self, parameters, name="Optional")
239
- return _JointGenericAlias(self, parameters)
240
-
241
-
242
- @_SpecialForm
243
- def OneOfTypes(self, parameters):
244
- """Sole type; OneOfTypes[X, Y] means either X or Y.
245
-
246
- To define a union, use e.g. OneOfTypes[int, str]. Details:
247
- - The arguments must be types and there must be at least one.
248
- - None as an argument is a special case and is replaced by
249
- type(None).
250
- - Unions of unions are flattened, e.g.::
251
-
252
- assert OneOfTypes[OneOfTypes[int, str], float] == OneOfTypes[int, str, float]
253
-
254
- - Unions of a single argument vanish, e.g.::
255
-
256
- assert OneOfTypes[int] == int # The constructor actually returns int
257
-
258
- - Redundant arguments are skipped, e.g.::
259
-
260
- assert OneOfTypes[int, str, int] == OneOfTypes[int, str]
261
-
262
- - When comparing unions, the argument order is ignored, e.g.::
263
-
264
- assert OneOfTypes[int, str] == OneOfTypes[str, int]
265
-
266
- - You cannot subclass or instantiate a union.
267
- - You can use Optional[X] as a shorthand for OneOfTypes[X, None].
268
- """
269
- if parameters == ():
270
- raise TypeError("Cannot take a Sole of no types.")
271
- if not isinstance(parameters, tuple):
272
- parameters = (parameters,)
273
- msg = "OneOfTypes[arg, ...]: each arg must be a type."
274
- parameters = tuple(_type_check(p, msg) for p in parameters)
275
- parameters = _remove_dups_flatten(parameters)
276
- if len(parameters) == 1:
277
- return parameters[0]
278
- if len(parameters) == 2 and type(None) in parameters:
279
- return _UnionGenericAlias(self, parameters, name="Optional")
280
- return _UnionGenericAlias(self, parameters)
281
-
282
-
283
- class Mode(Mixin):
284
- """
285
- Base class for computation behaviors.
286
- """
287
-
288
- def __repr__(self):
289
- return self.__class__.__name__
290
-
291
- def __eq__(self, other: 'Mode'):
292
- assert isinstance(other, Mode)
293
- return other.__class__ == self.__class__
294
-
295
- def is_a(self, mode: type):
296
- """
297
- Check whether the mode is exactly the desired mode.
298
- """
299
- assert isinstance(mode, type), 'Must be a type.'
300
- return self.__class__ == mode
301
-
302
- def has(self, mode: type):
303
- """
304
- Check whether the mode is included in the desired mode.
305
- """
306
- assert isinstance(mode, type), 'Must be a type.'
307
- return isinstance(self, mode)
308
-
309
-
310
- class JointMode(Mode):
311
- """
312
- Joint mode.
313
- """
314
-
315
- def __init__(self, *modes: Mode):
316
- for m_ in modes:
317
- if not isinstance(m_, Mode):
318
- raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
319
- self.modes = tuple(modes)
320
- self.types = set([m.__class__ for m in modes])
321
-
322
- def __repr__(self):
323
- return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'
324
-
325
- def has(self, mode: type):
326
- """
327
- Check whether the mode is included in the desired mode.
328
- """
329
- assert isinstance(mode, type), 'Must be a type.'
330
- return any([issubclass(cls, mode) for cls in self.types])
331
-
332
- def is_a(self, cls: type):
333
- """
334
- Check whether the mode is exactly the desired mode.
335
- """
336
- return JointTypes[tuple(self.types)] == cls
337
-
338
- def __getattr__(self, item):
339
- """
340
- Get the attribute from the mode.
341
-
342
- If the attribute is not found in the mode, then it will be searched in the base class.
343
- """
344
- if item in ['modes', 'types']:
345
- return super().__getattribute__(item)
346
- for m in self.modes:
347
- if hasattr(m, item):
348
- return getattr(m, item)
349
- return super().__getattribute__(item)
350
-
351
-
352
- class Batching(Mode):
353
- """Batching mode."""
354
-
355
- def __init__(self, batch_size: int = 1, batch_axis: int = 0):
356
- self.batch_size = batch_size
357
- self.batch_axis = batch_axis
358
-
359
- def __repr__(self):
360
- return f'{self.__class__.__name__}(in_size={self.batch_size}, axis={self.batch_axis})'
361
-
362
-
363
- class Training(Mode):
364
- """Training mode."""
365
- pass
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ """
19
+ Mixin classes and utility types for brainstate.
20
+
21
+ This module provides various mixin classes and custom type definitions that
22
+ enhance the functionality of brainstate components. It includes parameter
23
+ description mixins, alignment interfaces, and custom type definitions for
24
+ expressing complex type requirements.
25
+ """
26
+
27
+ from typing import Sequence, Optional, TypeVar, Union, _GenericAlias
28
+
29
+ import jax
30
+
31
+ __all__ = [
32
+ 'Mixin',
33
+ 'ParamDesc',
34
+ 'ParamDescriber',
35
+ 'JointTypes',
36
+ 'OneOfTypes',
37
+ '_JointGenericAlias',
38
+ '_OneOfGenericAlias',
39
+ 'Mode',
40
+ 'JointMode',
41
+ 'Batching',
42
+ 'Training',
43
+ ]
44
+
45
+ T = TypeVar('T')
46
+ ArrayLike = jax.typing.ArrayLike
47
+
48
+
49
+ def hashable(x):
50
+ """
51
+ Check if an object is hashable.
52
+
53
+ Parameters
54
+ ----------
55
+ x : Any
56
+ The object to check for hashability.
57
+
58
+ Returns
59
+ -------
60
+ bool
61
+ True if the object is hashable, False otherwise.
62
+
63
+ Examples
64
+ --------
65
+ .. code-block:: python
66
+
67
+ >>> import brainstate
68
+ >>>
69
+ >>> # Hashable objects
70
+ >>> assert brainstate.mixin.hashable(42) == True
71
+ >>> assert brainstate.mixin.hashable("string") == True
72
+ >>> assert brainstate.mixin.hashable((1, 2, 3)) == True
73
+ >>>
74
+ >>> # Non-hashable objects
75
+ >>> assert brainstate.mixin.hashable([1, 2, 3]) == False
76
+ >>> assert brainstate.mixin.hashable({"key": "value"}) == False
77
+ """
78
+ try:
79
+ hash(x)
80
+ return True
81
+ except TypeError:
82
+ return False
83
+
84
+
85
+ class Mixin(object):
86
+ """
87
+ Base Mixin object for behavioral extensions.
88
+
89
+ The key characteristic of a :py:class:`~.Mixin` is that it provides only
90
+ behavioral functions without requiring initialization. Mixins are used to
91
+ add specific functionality to classes through multiple inheritance without
92
+ the complexity of a full base class.
93
+
94
+ Notes
95
+ -----
96
+ Mixins should not define ``__init__`` methods. They should only provide
97
+ methods that add specific behaviors to the classes that inherit from them.
98
+
99
+ Examples
100
+ --------
101
+ Creating a custom mixin:
102
+
103
+ .. code-block:: python
104
+
105
+ >>> import brainstate
106
+ >>>
107
+ >>> class LoggingMixin(brainstate.mixin.Mixin):
108
+ ... def log(self, message):
109
+ ... print(f"[{self.__class__.__name__}] {message}")
110
+
111
+ >>> class MyComponent(brainstate.nn.Module, LoggingMixin):
112
+ ... def __init__(self):
113
+ ... super().__init__()
114
+ ...
115
+ ... def process(self):
116
+ ... self.log("Processing data...")
117
+ ... return "Done"
118
+ >>>
119
+ >>> component = MyComponent()
120
+ >>> component.process() # Prints: [MyComponent] Processing data...
121
+ """
122
+ pass
123
+
124
+
125
+ class ParamDesc(Mixin):
126
+ """
127
+ Mixin for describing initialization parameters.
128
+
129
+ This mixin enables a class to have a ``desc`` classmethod, which produces
130
+ an instance of :py:class:`~.ParamDescriber`. This is useful for creating
131
+ parameter templates that can be reused to instantiate multiple objects
132
+ with the same configuration.
133
+
134
+ Attributes
135
+ ----------
136
+ non_hashable_params : sequence of str, optional
137
+ Names of parameters that are not hashable and should be handled specially.
138
+
139
+ Notes
140
+ -----
141
+ This mixin can be applied to any Python class, not just brainstate-specific classes.
142
+
143
+ Examples
144
+ --------
145
+ Basic usage of ParamDesc:
146
+
147
+ .. code-block:: python
148
+
149
+ >>> import brainstate
150
+ >>>
151
+ >>> class NeuronModel(brainstate.mixin.ParamDesc):
152
+ ... def __init__(self, size, tau=10.0, threshold=1.0):
153
+ ... self.size = size
154
+ ... self.tau = tau
155
+ ... self.threshold = threshold
156
+ >>>
157
+ >>> # Create a parameter descriptor
158
+ >>> neuron_desc = NeuronModel.desc(size=100, tau=20.0)
159
+ >>>
160
+ >>> # Use the descriptor to create instances
161
+ >>> neuron1 = neuron_desc(threshold=0.8) # Creates with threshold=0.8
162
+ >>> neuron2 = neuron_desc(threshold=1.2) # Creates with threshold=1.2
163
+ >>>
164
+ >>> # Both neurons share size=100, tau=20.0 but have different thresholds
165
+
166
+ Creating reusable templates:
167
+
168
+ .. code-block:: python
169
+
170
+ >>> # Define a template for excitatory neurons
171
+ >>> exc_neuron_template = NeuronModel.desc(size=1000, tau=10.0, threshold=1.0)
172
+ >>>
173
+ >>> # Define a template for inhibitory neurons
174
+ >>> inh_neuron_template = NeuronModel.desc(size=250, tau=5.0, threshold=0.5)
175
+ >>>
176
+ >>> # Create multiple instances from templates
177
+ >>> exc_population = [exc_neuron_template() for _ in range(5)]
178
+ >>> inh_population = [inh_neuron_template() for _ in range(2)]
179
+ """
180
+
181
+ # Optional list of parameter names that are not hashable
182
+ # These will be converted to strings for hashing purposes
183
+ non_hashable_params: Optional[Sequence[str]] = None
184
+
185
+ @classmethod
186
+ def desc(cls, *args, **kwargs) -> 'ParamDescriber':
187
+ """
188
+ Create a parameter describer for this class.
189
+
190
+ Parameters
191
+ ----------
192
+ *args
193
+ Positional arguments to be used in future instantiations.
194
+ **kwargs
195
+ Keyword arguments to be used in future instantiations.
196
+
197
+ Returns
198
+ -------
199
+ ParamDescriber
200
+ A descriptor that can be used to create instances with these parameters.
201
+ """
202
+ return ParamDescriber(cls, *args, **kwargs)
203
+
204
+
205
+ class HashableDict(dict):
206
+ """
207
+ A dictionary that can be hashed by converting non-hashable values to strings.
208
+
209
+ This is used internally to make parameter dictionaries hashable so they can
210
+ be used as part of cache keys or other contexts requiring hashability.
211
+
212
+ Parameters
213
+ ----------
214
+ the_dict : dict
215
+ The dictionary to make hashable.
216
+
217
+ Notes
218
+ -----
219
+ Non-hashable values in the dictionary are automatically converted to their
220
+ string representation.
221
+
222
+ Examples
223
+ --------
224
+ .. code-block:: python
225
+
226
+ >>> import brainstate
227
+ >>> import jax.numpy as jnp
228
+ >>>
229
+ >>> # Regular dict with non-hashable values cannot be hashed
230
+ >>> regular_dict = {"array": jnp.array([1, 2, 3]), "value": 42}
231
+ >>> # hash(regular_dict) # This would raise TypeError
232
+ >>>
233
+ >>> # HashableDict can be hashed
234
+ >>> hdict = brainstate.mixin.HashableDict(regular_dict)
235
+ >>> key = hash(hdict)  # This works!
236
+ >>>
237
+ >>> # Can be used in sets or as dict keys
238
+ >>> cache = {hdict: "result"}
239
+ """
240
+
241
+ def __init__(self, the_dict: dict):
242
+ # Process the dictionary to ensure all values are hashable
243
+ out = dict()
244
+ for k, v in the_dict.items():
245
+ if not hashable(v):
246
+ # Convert non-hashable values to their string representation
247
+ v = str(v)
248
+ out[k] = v
249
+ super().__init__(out)
250
+
251
+ def __hash__(self):
252
+ """
253
+ Compute hash from sorted items for consistent hashing regardless of insertion order.
254
+ """
255
+ return hash(tuple(sorted(self.items())))
256
+
257
+
258
+ class NoSubclassMeta(type):
259
+ """
260
+ Metaclass that prevents a class from being subclassed.
261
+
262
+ This is used to ensure that certain classes (like ParamDescriber) are used
263
+ as-is and not extended through inheritance, which could lead to unexpected
264
+ behavior.
265
+
266
+ Raises
267
+ ------
268
+ TypeError
269
+ If an attempt is made to subclass a class using this metaclass.
270
+ """
271
+
272
+ def __new__(cls, name, bases, classdict):
273
+ # Check if any base class uses NoSubclassMeta
274
+ for b in bases:
275
+ if isinstance(b, NoSubclassMeta):
276
+ raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
277
+ return type.__new__(cls, name, bases, dict(classdict))
278
+
279
+
280
+ class ParamDescriber(metaclass=NoSubclassMeta):
281
+ """
282
+ Parameter descriptor for deferred object instantiation.
283
+
284
+ This class stores a class reference along with arguments and keyword arguments,
285
+ allowing for deferred instantiation. It's useful for creating templates that
286
+ can be reused to create multiple instances with similar configurations.
287
+
288
+ Parameters
289
+ ----------
290
+ cls : type
291
+ The class to be instantiated.
292
+ *desc_tuple
293
+ Positional arguments to be stored and used during instantiation.
294
+ **desc_dict
295
+ Keyword arguments to be stored and used during instantiation.
296
+
297
+ Attributes
298
+ ----------
299
+ cls : type
300
+ The class that will be instantiated.
301
+ args : tuple
302
+ Stored positional arguments.
303
+ kwargs : dict
304
+ Stored keyword arguments.
305
+ identifier : tuple
306
+ A hashable identifier for this descriptor.
307
+
308
+ Notes
309
+ -----
310
+ ParamDescriber cannot be subclassed due to the NoSubclassMeta metaclass.
311
+ This ensures consistent behavior across the codebase.
312
+
313
+ Examples
314
+ --------
315
+ Manual creation of a descriptor:
316
+
317
+ .. code-block:: python
318
+
319
+ >>> import brainstate
320
+ >>>
321
+ >>> class Network:
322
+ ... def __init__(self, n_neurons, learning_rate=0.01):
323
+ ... self.n_neurons = n_neurons
324
+ ... self.learning_rate = learning_rate
325
+ >>>
326
+ >>> # Create a descriptor
327
+ >>> network_desc = brainstate.mixin.ParamDescriber(
328
+ ... Network, n_neurons=1000, learning_rate=0.001
329
+ ... )
330
+ >>>
331
+ >>> # Use the descriptor to create instances with additional args
332
+ >>> net1 = network_desc()
333
+ >>> net2 = network_desc() # Same configuration
334
+
335
+ Using with ParamDesc mixin:
336
+
337
+ .. code-block:: python
338
+
339
+ >>> class Network(brainstate.mixin.ParamDesc):
340
+ ... def __init__(self, n_neurons, learning_rate=0.01):
341
+ ... self.n_neurons = n_neurons
342
+ ... self.learning_rate = learning_rate
343
+ >>>
344
+ >>> # More concise syntax using the desc() classmethod
345
+ >>> network_desc = Network.desc(n_neurons=1000)
346
+ >>> net = network_desc(learning_rate=0.005) # Override learning_rate
347
+ """
348
+
349
+ def __init__(self, cls: T, *desc_tuple, **desc_dict):
350
+ # Store the class to be instantiated
351
+ self.cls: type = cls
352
+
353
+ # Store the arguments for later instantiation
354
+ self.args = desc_tuple
355
+ self.kwargs = desc_dict
356
+
357
+ # Create a hashable identifier for caching/comparison purposes
358
+ # This combines the class, args tuple, and hashable kwargs dict
359
+ self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
360
+
361
+ def __call__(self, *args, **kwargs) -> T:
362
+ """
363
+ Instantiate the class with stored and additional arguments.
364
+
365
+ Parameters
366
+ ----------
367
+ *args
368
+ Additional positional arguments to append.
369
+ **kwargs
370
+ Additional keyword arguments to merge (will override stored kwargs).
371
+
372
+ Returns
373
+ -------
374
+ T
375
+ An instance of the described class.
376
+ """
377
+ # Merge stored arguments with new arguments
378
+ # Stored args come first, then new args
379
+ # Merge kwargs with new kwargs overriding stored ones
380
+ merged_kwargs = {**self.kwargs, **kwargs}
381
+ return self.cls(*self.args, *args, **merged_kwargs)
382
+
383
+ def init(self, *args, **kwargs):
384
+ """
385
+ Alias for __call__, explicitly named for clarity.
386
+
387
+ Parameters
388
+ ----------
389
+ *args
390
+ Additional positional arguments.
391
+ **kwargs
392
+ Additional keyword arguments.
393
+
394
+ Returns
395
+ -------
396
+ T
397
+ An instance of the described class.
398
+ """
399
+ return self.__call__(*args, **kwargs)
400
+
401
+ def __instancecheck__(self, instance):
402
+ """
403
+ Check if an instance is compatible with this descriptor.
404
+
405
+ Parameters
406
+ ----------
407
+ instance : Any
408
+ The instance to check.
409
+
410
+ Returns
411
+ -------
412
+ bool
413
+ True if the instance is a ParamDescriber for a compatible class.
414
+ """
415
+ # Must be a ParamDescriber
416
+ if not isinstance(instance, ParamDescriber):
417
+ return False
418
+ # The described class must be a subclass of our class
419
+ if not issubclass(instance.cls, self.cls):
420
+ return False
421
+ return True
422
+
423
+ @classmethod
424
+ def __class_getitem__(cls, item: type):
425
+ """
426
+ Support for subscript notation: ParamDescriber[MyClass].
427
+
428
+ Parameters
429
+ ----------
430
+ item : type
431
+ The class to create a descriptor for.
432
+
433
+ Returns
434
+ -------
435
+ ParamDescriber
436
+ A descriptor for the given class.
437
+ """
438
+ return ParamDescriber(item)
439
+
440
+ @property
441
+ def identifier(self):
442
+ """
443
+ Get the unique identifier for this descriptor.
444
+
445
+ Returns
446
+ -------
447
+ tuple
448
+ A hashable identifier consisting of (class, args, kwargs).
449
+ """
450
+ return self._identifier
451
+
452
+ @identifier.setter
453
+ def identifier(self, value: ArrayLike):
454
+ """
455
+ Prevent modification of the identifier.
456
+
457
+ Raises
458
+ ------
459
+ AttributeError
460
+ Always, as the identifier is read-only.
461
+ """
462
+ raise AttributeError('Cannot set the identifier.')
463
+
464
+
465
+ def not_implemented(func):
466
+ """
467
+ Decorator to mark a function as not implemented.
468
+
469
+ This decorator wraps a function to raise NotImplementedError when called,
470
+ and adds a ``not_implemented`` attribute for checking.
471
+
472
+ Parameters
473
+ ----------
474
+ func : callable
475
+ The function to mark as not implemented.
476
+
477
+ Returns
478
+ -------
479
+ callable
480
+ A wrapper function that raises NotImplementedError.
481
+
482
+ Examples
483
+ --------
484
+ .. code-block:: python
485
+
486
+ >>> import brainstate
487
+ >>>
488
+ >>> class BaseModel:
489
+ ... @brainstate.mixin.not_implemented
490
+ ... def process(self, x):
491
+ ... pass
492
+ >>>
493
+ >>> model = BaseModel()
494
+ >>> # model.process(10) # Raises: NotImplementedError: process is not implemented.
495
+ >>>
496
+ >>> # Check if a method is not implemented
497
+ >>> assert hasattr(BaseModel.process, 'not_implemented')
498
+ """
499
+
500
+ def wrapper(*args, **kwargs):
501
+ raise NotImplementedError(f'{func.__name__} is not implemented.')
502
+
503
+ # Mark the wrapper so we can detect not-implemented methods
504
+ wrapper.not_implemented = True
505
+ return wrapper
506
+
507
+
508
+ class _JointGenericAlias(_GenericAlias, _root=True):
509
+ """
510
+ Generic alias for JointTypes (intersection types).
511
+
512
+ This class represents a type that requires all specified types to be satisfied.
513
+ Unlike _MetaUnionType which creates actual classes with metaclass conflicts,
514
+ this uses typing's generic alias system to avoid metaclass issues.
515
+ """
516
+
517
+ def __instancecheck__(self, obj):
518
+ """
519
+ Check if an instance is an instance of all component types.
520
+ """
521
+ return all(isinstance(obj, cls) for cls in self.__args__)
522
+
523
+ def __subclasscheck__(self, subclass):
524
+ """
525
+ Check if a class is a subclass of all component types.
526
+ """
527
+ return all(issubclass(subclass, cls) for cls in self.__args__)
528
+
529
+ def __eq__(self, other):
530
+ """
531
+ Check equality with another type.
532
+
533
+ Two JointTypes are equal if they have the same component types,
534
+ regardless of order.
535
+ """
536
+ if not isinstance(other, _JointGenericAlias):
537
+ return NotImplemented
538
+ return set(self.__args__) == set(other.__args__)
539
+
540
+ def __hash__(self):
541
+ """
542
+ Return hash of the JointType.
543
+
544
+ The hash is based on the frozenset of component types to ensure
545
+ that JointTypes with the same types (regardless of order) have
546
+ the same hash.
547
+ """
548
+ return hash(frozenset(self.__args__))
549
+
550
+ def __repr__(self):
551
+ """
552
+ Return string representation of the JointType.
553
+
554
+ Returns a readable representation showing all component types.
555
+ """
556
+ args_str = ', '.join(
557
+ arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
558
+ else str(arg)
559
+ for arg in self.__args__
560
+ )
561
+ return f'JointTypes[{args_str}]'
562
+
563
+ def __reduce__(self):
564
+ """
565
+ Support for pickling.
566
+
567
+ Returns the necessary information to reconstruct the JointType
568
+ when unpickling.
569
+ """
570
+ return (_JointGenericAlias, (self.__origin__, self.__args__))
571
+
572
+
573
+ class _OneOfGenericAlias(_GenericAlias, _root=True):
574
+ """
575
+ Generic alias for OneOfTypes (union types).
576
+
577
+ This class represents a type that requires at least one of the specified
578
+ types to be satisfied. It's similar to typing.Union but provides a consistent
579
+ interface with JointTypes and avoids potential metaclass conflicts.
580
+ """
581
+
582
+ def __instancecheck__(self, obj):
583
+ """
584
+ Check if an instance is an instance of any component type.
585
+ """
586
+ return any(isinstance(obj, cls) for cls in self.__args__)
587
+
588
+ def __subclasscheck__(self, subclass):
589
+ """
590
+ Check if a class is a subclass of any component type.
591
+ """
592
+ return any(issubclass(subclass, cls) for cls in self.__args__)
593
+
594
+ def __eq__(self, other):
595
+ """
596
+ Check equality with another type.
597
+
598
+ Two OneOfTypes are equal if they have the same component types,
599
+ regardless of order.
600
+ """
601
+ if not isinstance(other, _OneOfGenericAlias):
602
+ return NotImplemented
603
+ return set(self.__args__) == set(other.__args__)
604
+
605
+ def __hash__(self):
606
+ """
607
+ Return hash of the OneOfType.
608
+
609
+ The hash is based on the frozenset of component types to ensure
610
+ that OneOfTypes with the same types (regardless of order) have
611
+ the same hash.
612
+ """
613
+ return hash(frozenset(self.__args__))
614
+
615
+ def __repr__(self):
616
+ """
617
+ Return string representation of the OneOfType.
618
+
619
+ Returns a readable representation showing all component types.
620
+ """
621
+ args_str = ', '.join(
622
+ arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
623
+ else str(arg)
624
+ for arg in self.__args__
625
+ )
626
+ return f'OneOfTypes[{args_str}]'
627
+
628
+ def __reduce__(self):
629
+ """
630
+ Support for pickling.
631
+
632
+ Returns the necessary information to reconstruct the OneOfType
633
+ when unpickling.
634
+ """
635
+ return (_OneOfGenericAlias, (self.__origin__, self.__args__))
636
+
637
+
638
+ class _JointTypesClass:
639
+ """Helper class to enable subscript syntax for JointTypes."""
640
+
641
+ def __call__(self, *types):
642
+ """
643
+ Create a type that requires all specified types (intersection type).
644
+
645
+ This function creates a type hint that indicates a value must satisfy all
646
+ the specified types simultaneously. It's useful for expressing complex
647
+ type requirements where a single object must implement multiple interfaces.
648
+
649
+ Parameters
650
+ ----------
651
+ *types : type
652
+ The types that must all be satisfied.
653
+
654
+ Returns
655
+ -------
656
+ type
657
+ A type that checks for all specified types.
658
+
659
+ Notes
660
+ -----
661
+ - If only one type is provided, that type is returned directly.
662
+ - Redundant types are automatically removed.
663
+ - The order of types doesn't matter for equality checks.
664
+
665
+ Examples
666
+ --------
667
+ Basic usage with interfaces:
668
+
669
+ .. code-block:: python
670
+
671
+ >>> import brainstate
672
+ >>> from typing import Protocol
673
+ >>>
674
+ >>> class Trainable(Protocol):
675
+ ... def train(self): ...
676
+ >>>
677
+ >>> class Evaluable(Protocol):
678
+ ... def evaluate(self): ...
679
+ >>>
680
+ >>> # A model that is both trainable and evaluable
681
+ >>> TrainableEvaluableModel = brainstate.mixin.JointTypes(Trainable, Evaluable)
682
+ >>> # Or using subscript syntax
683
+ >>> TrainableEvaluableModel = brainstate.mixin.JointTypes[Trainable, Evaluable]
684
+ >>>
685
+ >>> class NeuralNetwork(Trainable, Evaluable):
686
+ ... def train(self):
687
+ ... return "Training..."
688
+ ...
689
+ ... def evaluate(self):
690
+ ... return "Evaluating..."
691
+ >>>
692
+ >>> model = NeuralNetwork()
693
+ >>> # model satisfies JointTypes(Trainable, Evaluable)
694
+
695
+ Using with mixin classes:
696
+
697
+ .. code-block:: python
698
+
699
+ >>> class Serializable:
700
+ ... def save(self): pass
701
+ >>>
702
+ >>> class Visualizable:
703
+ ... def plot(self): pass
704
+ >>>
705
+ >>> # Require both serialization and visualization
706
+ >>> FullFeaturedModel = brainstate.mixin.JointTypes[Serializable, Visualizable]
707
+ >>>
708
+ >>> class MyModel(Serializable, Visualizable):
709
+ ... def save(self):
710
+ ... return "Saved"
711
+ ...
712
+ ... def plot(self):
713
+ ... return "Plotted"
714
+ """
715
+ if len(types) == 0:
716
+ raise TypeError("Cannot create a JointTypes of no types.")
717
+
718
+ # Remove duplicates while preserving some order
719
+ seen = set()
720
+ unique_types = []
721
+ for t in types:
722
+ if t not in seen:
723
+ seen.add(t)
724
+ unique_types.append(t)
725
+
726
+ # If only one type, return it directly
727
+ if len(unique_types) == 1:
728
+ return unique_types[0]
729
+
730
+ # Create a generic alias for the joint type
731
+ # This avoids metaclass conflicts by using typing's generic alias system
732
+ return _JointGenericAlias(object, tuple(unique_types))
733
+
734
+ def __getitem__(self, item):
735
+ """Enable subscript syntax: JointTypes[Type1, Type2]."""
736
+ if isinstance(item, tuple):
737
+ return self(*item)
738
+ else:
739
+ return self(item)
740
+
741
+
742
+ # Create singleton instance that acts as both a callable and supports subscript
743
+ JointTypes = _JointTypesClass()
744
+
745
+
746
+ class _OneOfTypesClass:
747
+ """Helper class to enable subscript syntax for OneOfTypes."""
748
+
749
+ def __call__(self, *types):
750
+ """
751
+ Create a type that requires one of the specified types (union type).
752
+
753
+ This is similar to typing.Union but provides a more intuitive name and
754
+ consistent behavior with JointTypes. It indicates that a value must satisfy
755
+ at least one of the specified types.
756
+
757
+ Parameters
758
+ ----------
759
+ *types : type
760
+ The types, one of which must be satisfied.
761
+
762
+ Returns
763
+ -------
764
+ Union type
765
+ A union type of the specified types.
766
+
767
+ Notes
768
+ -----
769
+ - If only one type is provided, that type is returned directly.
770
+ - Redundant types are automatically removed.
771
+ - The order of types doesn't matter for equality checks.
772
+ - This is equivalent to typing.Union[...].
773
+
774
+ Examples
775
+ --------
776
+ Basic usage with different types:
777
+
778
+ .. code-block:: python
779
+
780
+ >>> import brainstate
781
+ >>>
782
+ >>> # A parameter that can be int or float
783
+ >>> NumericType = brainstate.mixin.OneOfTypes(int, float)
784
+ >>> # Or using subscript syntax
785
+ >>> NumericType = brainstate.mixin.OneOfTypes[int, float]
786
+ >>>
787
+ >>> def process_value(x: NumericType):
788
+ ... return x * 2
789
+ >>>
790
+ >>> # Both work
791
+ >>> result1 = process_value(5) # int
792
+ >>> result2 = process_value(3.14) # float
793
+
794
+ Using with class types:
795
+
796
+ .. code-block:: python
797
+
798
+ >>> class NumpyArray:
799
+ ... pass
800
+ >>>
801
+ >>> class JAXArray:
802
+ ... pass
803
+ >>>
804
+ >>> # Accept either numpy or JAX arrays
805
+ >>> ArrayType = brainstate.mixin.OneOfTypes[NumpyArray, JAXArray]
806
+ >>>
807
+ >>> def compute(arr: ArrayType):
808
+ ... if isinstance(arr, NumpyArray):
809
+ ... return "Processing numpy array"
810
+ ... elif isinstance(arr, JAXArray):
811
+ ... return "Processing JAX array"
812
+
813
+ Combining with None for optional types:
814
+
815
+ .. code-block:: python
816
+
817
+ >>> # Optional string (equivalent to Optional[str])
818
+ >>> MaybeString = brainstate.mixin.OneOfTypes[str, type(None)]
819
+ >>>
820
+ >>> def format_name(name: MaybeString) -> str:
821
+ ... if name is None:
822
+ ... return "Anonymous"
823
+ ... return name.title()
824
+ """
825
+ if len(types) == 0:
826
+ raise TypeError("Cannot create a OneOfTypes of no types.")
827
+
828
+ # Remove duplicates
829
+ seen = set()
830
+ unique_types = []
831
+ for t in types:
832
+ if t not in seen:
833
+ seen.add(t)
834
+ unique_types.append(t)
835
+
836
+ # If only one type, return it directly
837
+ if len(unique_types) == 1:
838
+ return unique_types[0]
839
+
840
+ # Create a generic alias for the union type
841
+ # This provides consistency with JointTypes and avoids metaclass conflicts
842
+ return _OneOfGenericAlias(Union, tuple(unique_types))
843
+
844
+ def __getitem__(self, item):
845
+ """Enable subscript syntax: OneOfTypes[Type1, Type2]."""
846
+ if isinstance(item, tuple):
847
+ return self(*item)
848
+ else:
849
+ return self(item)
850
+
851
+
852
+ # Create singleton instance that acts as both a callable and supports subscript
853
+ OneOfTypes = _OneOfTypesClass()
854
+
855
+
856
+ def __getattr__(name):
857
+ if name in [
858
+ 'Mode',
859
+ 'JointMode',
860
+ 'Batching',
861
+ 'Training',
862
+ 'AlignPost',
863
+ 'BindCondData',
864
+ ]:
865
+ import warnings
866
+ warnings.warn(
867
+ f"brainstate.mixin.{name} is deprecated and will be removed in a future version. "
868
+ f"Please use brainpy.mixin.{name} instead.",
869
+ DeprecationWarning,
870
+ stacklevel=2
871
+ )
872
+ import brainpy
873
+ return getattr(brainpy.mixin, name)
874
+ raise AttributeError(
875
+ f'module {__name__!r} has no attribute {name!r}'
876
+ )
877
+
878
+
879
+ class Mode(Mixin):
880
+ """
881
+ Base class for computation behavior modes.
882
+
883
+ Modes are used to represent different computational contexts or behaviors,
884
+ such as training vs evaluation, batched vs single-sample processing, etc.
885
+ They provide a flexible way to configure how models and components behave
886
+ in different scenarios.
887
+
888
+ Examples
889
+ --------
890
+ Creating a custom mode:
891
+
892
+ .. code-block:: python
893
+
894
+ >>> import brainstate
895
+ >>>
896
+ >>> class InferenceMode(brainstate.mixin.Mode):
897
+ ... def __init__(self, use_cache=True):
898
+ ... self.use_cache = use_cache
899
+ >>>
900
+ >>> # Create mode instances
901
+ >>> inference = InferenceMode(use_cache=True)
902
+ >>> print(inference) # Output: InferenceMode
903
+
904
+ Checking mode types:
905
+
906
+ .. code-block:: python
907
+
908
+ >>> class FastMode(brainstate.mixin.Mode):
909
+ ... pass
910
+ >>>
911
+ >>> class SlowMode(brainstate.mixin.Mode):
912
+ ... pass
913
+ >>>
914
+ >>> fast = FastMode()
915
+ >>> slow = SlowMode()
916
+ >>>
917
+ >>> # Check exact mode type
918
+ >>> assert fast.is_a(FastMode)
919
+ >>> assert not fast.is_a(SlowMode)
920
+ >>>
921
+ >>> # Check if mode is an instance of a type
922
+ >>> assert fast.has(brainstate.mixin.Mode)
923
+
924
+ Using modes in a model:
925
+
926
+ .. code-block:: python
927
+
928
+ >>> class Model:
929
+ ... def __init__(self):
930
+ ... self.mode = brainstate.mixin.Training()
931
+ ...
932
+ ... def forward(self, x):
933
+ ... if self.mode.has(brainstate.mixin.Training):
934
+ ... # Training-specific logic
935
+ ... return self.train_forward(x)
936
+ ... else:
937
+ ... # Inference logic
938
+ ... return self.eval_forward(x)
939
+ ...
940
+ ... def train_forward(self, x):
941
+ ... return x + 0.1 # Add noise during training
942
+ ...
943
+ ... def eval_forward(self, x):
944
+ ... return x # No noise during evaluation
945
+ """
946
+
947
+ def __repr__(self):
948
+ """
949
+ String representation of the mode.
950
+
951
+ Returns
952
+ -------
953
+ str
954
+ The class name of the mode.
955
+ """
956
+ return self.__class__.__name__
957
+
958
+ def __eq__(self, other: 'Mode'):
959
+ """
960
+ Check equality of modes based on their type.
961
+
962
+ Parameters
963
+ ----------
964
+ other : Mode
965
+ Another mode to compare with.
966
+
967
+ Returns
968
+ -------
969
+ bool
970
+ True if both modes are of the same class.
971
+ """
972
+ assert isinstance(other, Mode)
973
+ return other.__class__ == self.__class__
974
+
975
+ def is_a(self, mode: type):
976
+ """
977
+ Check whether the mode is exactly the desired mode type.
978
+
979
+ This performs an exact type match, not checking for subclasses.
980
+
981
+ Parameters
982
+ ----------
983
+ mode : type
984
+ The mode type to check against.
985
+
986
+ Returns
987
+ -------
988
+ bool
989
+ True if this mode is exactly of the specified type.
990
+
991
+ Examples
992
+ --------
993
+ .. code-block:: python
994
+
995
+ >>> import brainstate
996
+ >>>
997
+ >>> training_mode = brainstate.mixin.Training()
998
+ >>> assert training_mode.is_a(brainstate.mixin.Training)
999
+ >>> assert not training_mode.is_a(brainstate.mixin.Batching)
1000
+ """
1001
+ assert isinstance(mode, type), 'Must be a type.'
1002
+ return self.__class__ == mode
1003
+
1004
+ def has(self, mode: type):
1005
+ """
1006
+ Check whether the mode includes the desired mode type.
1007
+
1008
+ This checks if the current mode is an instance of the specified type,
1009
+ including checking for subclasses.
1010
+
1011
+ Parameters
1012
+ ----------
1013
+ mode : type
1014
+ The mode type to check for.
1015
+
1016
+ Returns
1017
+ -------
1018
+ bool
1019
+ True if this mode is an instance of the specified type.
1020
+
1021
+ Examples
1022
+ --------
1023
+ .. code-block:: python
1024
+
1025
+ >>> import brainstate
1026
+ >>>
1027
+ >>> # Create a custom mode that extends Training
1028
+ >>> class AdvancedTraining(brainstate.mixin.Training):
1029
+ ... pass
1030
+ >>>
1031
+ >>> advanced = AdvancedTraining()
1032
+ >>> assert advanced.has(brainstate.mixin.Training) # True (subclass)
1033
+ >>> assert advanced.has(brainstate.mixin.Mode) # True (base class)
1034
+ """
1035
+ assert isinstance(mode, type), 'Must be a type.'
1036
+ return isinstance(self, mode)
1037
+
1038
+
1039
+ class JointMode(Mode):
1040
+ """
1041
+ A mode that combines multiple modes simultaneously.
1042
+
1043
+ JointMode allows expressing that a computation is in multiple modes at once,
1044
+ such as being both in training mode and batching mode. This is useful for
1045
+ complex scenarios where multiple behavioral aspects need to be active.
1046
+
1047
+ Parameters
1048
+ ----------
1049
+ *modes : Mode
1050
+ The modes to combine.
1051
+
1052
+ Attributes
1053
+ ----------
1054
+ modes : tuple of Mode
1055
+ The individual modes that are combined.
1056
+ types : set of type
1057
+ The types of the combined modes.
1058
+
1059
+ Raises
1060
+ ------
1061
+ TypeError
1062
+ If any of the provided arguments is not a Mode instance.
1063
+
1064
+ Examples
1065
+ --------
1066
+ Combining training and batching modes:
1067
+
1068
+ .. code-block:: python
1069
+
1070
+ >>> import brainstate
1071
+ >>>
1072
+ >>> # Create individual modes
1073
+ >>> training = brainstate.mixin.Training()
1074
+ >>> batching = brainstate.mixin.Batching(batch_size=32)
1075
+ >>>
1076
+ >>> # Combine them
1077
+ >>> joint = brainstate.mixin.JointMode(training, batching)
1078
+ >>> print(joint) # JointMode(Training, Batching(in_size=32, axis=0))
1079
+ >>>
1080
+ >>> # Check if specific modes are present
1081
+ >>> assert joint.has(brainstate.mixin.Training)
1082
+ >>> assert joint.has(brainstate.mixin.Batching)
1083
+ >>>
1084
+ >>> # Access attributes from combined modes
1085
+ >>> print(joint.batch_size) # 32 (from Batching mode)
1086
+
1087
+ Using in model configuration:
1088
+
1089
+ .. code-block:: python
1090
+
1091
+ >>> class NeuralNetwork:
1092
+ ... def __init__(self):
1093
+ ... self.mode = None
1094
+ ...
1095
+ ... def set_train_mode(self, batch_size=1):
1096
+ ... # Set both training and batching modes
1097
+ ... training = brainstate.mixin.Training()
1098
+ ... batching = brainstate.mixin.Batching(batch_size=batch_size)
1099
+ ... self.mode = brainstate.mixin.JointMode(training, batching)
1100
+ ...
1101
+ ... def forward(self, x):
1102
+ ... if self.mode.has(brainstate.mixin.Training):
1103
+ ... x = self.apply_dropout(x)
1104
+ ...
1105
+ ... if self.mode.has(brainstate.mixin.Batching):
1106
+ ... # Process in batches
1107
+ ... batch_size = self.mode.batch_size
1108
+ ... return self.batch_process(x, batch_size)
1109
+ ...
1110
+ ... return self.process(x)
1111
+ >>>
1112
+ >>> model = NeuralNetwork()
1113
+ >>> model.set_train_mode(batch_size=64)
1114
+ """
1115
+
1116
+ def __init__(self, *modes: Mode):
1117
+ # Validate that all arguments are Mode instances
1118
+ for m_ in modes:
1119
+ if not isinstance(m_, Mode):
1120
+ raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
1121
+
1122
+ # Store the modes as a tuple
1123
+ self.modes = tuple(modes)
1124
+
1125
+ # Store the types of the modes for quick lookup
1126
+ self.types = set([m.__class__ for m in modes])
1127
+
1128
+ def __repr__(self):
1129
+ """
1130
+ String representation showing all combined modes.
1131
+
1132
+ Returns
1133
+ -------
1134
+ str
1135
+ A string showing the joint mode and its components.
1136
+ """
1137
+ return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'
1138
+
1139
+ def has(self, mode: type):
1140
+ """
1141
+ Check whether any of the combined modes includes the desired type.
1142
+
1143
+ Parameters
1144
+ ----------
1145
+ mode : type
1146
+ The mode type to check for.
1147
+
1148
+ Returns
1149
+ -------
1150
+ bool
1151
+ True if any of the combined modes is or inherits from the specified type.
1152
+
1153
+ Examples
1154
+ --------
1155
+ .. code-block:: python
1156
+
1157
+ >>> import brainstate
1158
+ >>>
1159
+ >>> training = brainstate.mixin.Training()
1160
+ >>> batching = brainstate.mixin.Batching(batch_size=16)
1161
+ >>> joint = brainstate.mixin.JointMode(training, batching)
1162
+ >>>
1163
+ >>> assert joint.has(brainstate.mixin.Training)
1164
+ >>> assert joint.has(brainstate.mixin.Batching)
1165
+ >>> assert joint.has(brainstate.mixin.Mode) # Base class
1166
+ """
1167
+ assert isinstance(mode, type), 'Must be a type.'
1168
+ # Check if any of the combined mode types is a subclass of the target mode
1169
+ return any([issubclass(cls, mode) for cls in self.types])
1170
+
1171
+ def is_a(self, cls: type):
1172
+ """
1173
+ Check whether the joint mode is exactly the desired combined type.
1174
+
1175
+ This is a complex check that verifies the joint mode matches a specific
1176
+ combination of types.
1177
+
1178
+ Parameters
1179
+ ----------
1180
+ cls : type
1181
+ The combined type to check against.
1182
+
1183
+ Returns
1184
+ -------
1185
+ bool
1186
+ True if the joint mode exactly matches the specified type combination.
1187
+ """
1188
+ # Use JointTypes to create the expected type from our mode types
1189
+ return JointTypes(*tuple(self.types)) == cls
1190
+
1191
+ def __getattr__(self, item):
1192
+ """
1193
+ Get attributes from the combined modes.
1194
+
1195
+ This method searches through all combined modes to find the requested
1196
+ attribute, allowing transparent access to properties of any of the
1197
+ combined modes.
1198
+
1199
+ Parameters
1200
+ ----------
1201
+ item : str
1202
+ The attribute name to search for.
1203
+
1204
+ Returns
1205
+ -------
1206
+ Any
1207
+ The attribute value from the first mode that has it.
1208
+
1209
+ Raises
1210
+ ------
1211
+ AttributeError
1212
+ If the attribute is not found in any of the combined modes.
1213
+
1214
+ Examples
1215
+ --------
1216
+ .. code-block:: python
1217
+
1218
+ >>> import brainstate
1219
+ >>>
1220
+ >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=1)
1221
+ >>> training = brainstate.mixin.Training()
1222
+ >>> joint = brainstate.mixin.JointMode(batching, training)
1223
+ >>>
1224
+ >>> # Access batching attributes directly
1225
+ >>> print(joint.batch_size) # 32
1226
+ >>> print(joint.batch_axis) # 1
1227
+ """
1228
+ # Don't interfere with accessing modes and types attributes
1229
+ if item in ['modes', 'types']:
1230
+ return super().__getattribute__(item)
1231
+
1232
+ # Search for the attribute in each combined mode
1233
+ for m in self.modes:
1234
+ if hasattr(m, item):
1235
+ return getattr(m, item)
1236
+
1237
+ # If not found, fall back to default behavior (will raise AttributeError)
1238
+ return super().__getattribute__(item)
1239
+
1240
+
1241
+ class Batching(Mode):
1242
+ """
1243
+ Mode indicating batched computation.
1244
+
1245
+ This mode specifies that computations should be performed on batches of data,
1246
+ including information about the batch size and which axis represents the batch
1247
+ dimension.
1248
+
1249
+ Parameters
1250
+ ----------
1251
+ batch_size : int, default 1
1252
+ The size of each batch.
1253
+ batch_axis : int, default 0
1254
+ The axis along which batching occurs.
1255
+
1256
+ Attributes
1257
+ ----------
1258
+ batch_size : int
1259
+ The number of samples in each batch.
1260
+ batch_axis : int
1261
+ The axis index representing the batch dimension.
1262
+
1263
+ Examples
1264
+ --------
1265
+ Basic batching configuration:
1266
+
1267
+ .. code-block:: python
1268
+
1269
+ >>> import brainstate
1270
+ >>>
1271
+ >>> # Create a batching mode
1272
+ >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=0)
1273
+ >>> print(batching) # Batching(in_size=32, axis=0)
1274
+ >>>
1275
+ >>> # Access batch parameters
1276
+ >>> print(f"Processing {batching.batch_size} samples at once")
1277
+ >>> print(f"Batch dimension is axis {batching.batch_axis}")
1278
+
1279
+ Using in a model:
1280
+
1281
+ .. code-block:: python
1282
+
1283
+ >>> import jax.numpy as jnp
1284
+ >>>
1285
+ >>> class BatchedModel:
1286
+ ... def __init__(self):
1287
+ ... self.mode = None
1288
+ ...
1289
+ ... def set_batch_mode(self, batch_size, batch_axis=0):
1290
+ ... self.mode = brainstate.mixin.Batching(batch_size, batch_axis)
1291
+ ...
1292
+ ... def process(self, x):
1293
+ ... if self.mode is not None and self.mode.has(brainstate.mixin.Batching):
1294
+ ... # Process in batches
1295
+ ... batch_size = self.mode.batch_size
1296
+ ... axis = self.mode.batch_axis
1297
+ ... return jnp.mean(x, axis=axis, keepdims=True)
1298
+ ... return x
1299
+ >>>
1300
+ >>> model = BatchedModel()
1301
+ >>> model.set_batch_mode(batch_size=64)
1302
+ >>>
1303
+ >>> # Process batched data
1304
+ >>> data = jnp.ones((64, 100))  # 64 samples, 100 features
1305
+ >>> result = model.process(data)
1306
+
1307
+ Combining with other modes:
1308
+
1309
+ .. code-block:: python
1310
+
1311
+ >>> # Combine batching with training mode
1312
+ >>> training = brainstate.mixin.Training()
1313
+ >>> batching = brainstate.mixin.Batching(batch_size=128)
1314
+ >>> combined = brainstate.mixin.JointMode(training, batching)
1315
+ >>>
1316
+ >>> # Use in a training loop
1317
+ >>> def train_step(model, data, mode):
1318
+ ... if mode.has(brainstate.mixin.Batching):
1319
+ ... # Split data into batches
1320
+ ... batch_size = mode.batch_size
1321
+ ... # ... batched processing ...
1322
+ ... if mode.has(brainstate.mixin.Training):
1323
+ ... # Apply training-specific operations
1324
+ ... # ... training logic ...
1325
+ ... pass
1326
+ """
1327
+
1328
+ def __init__(self, batch_size: int = 1, batch_axis: int = 0):
1329
+ self.batch_size = batch_size
1330
+ self.batch_axis = batch_axis
1331
+
1332
+ def __repr__(self):
1333
+ """
1334
+ String representation showing batch configuration.
1335
+
1336
+ Returns
1337
+ -------
1338
+ str
1339
+ A string showing the batch size and axis.
1340
+ """
1341
+ return f'{self.__class__.__name__}(in_size={self.batch_size}, axis={self.batch_axis})'
1342
+
1343
+
1344
+ class Training(Mode):
1345
+ """
1346
+ Mode indicating training computation.
1347
+
1348
+ This mode specifies that the model is in training mode, which typically
1349
+ enables behaviors like dropout, batch normalization in training mode,
1350
+ gradient computation, etc.
1351
+
1352
+ Examples
1353
+ --------
1354
+ Basic training mode:
1355
+
1356
+ .. code-block:: python
1357
+
1358
+ >>> import brainstate
1359
+ >>>
1360
+ >>> # Create training mode
1361
+ >>> training = brainstate.mixin.Training()
1362
+ >>> print(training) # Training
1363
+ >>>
1364
+ >>> # Check mode
1365
+ >>> assert training.is_a(brainstate.mixin.Training)
1366
+ >>> assert training.has(brainstate.mixin.Mode)
1367
+
1368
+ Using in a model with dropout:
1369
+
1370
+ .. code-block:: python
1371
+
1372
+ >>> import brainstate
1373
+ >>> import jax
1374
+ >>> import jax.numpy as jnp
1375
+ >>>
1376
+ >>> class ModelWithDropout:
1377
+ ... def __init__(self, dropout_rate=0.5):
1378
+ ... self.dropout_rate = dropout_rate
1379
+ ... self.mode = None
1380
+ ...
1381
+ ... def set_training(self, is_training=True):
1382
+ ... if is_training:
1383
+ ... self.mode = brainstate.mixin.Training()
1384
+ ... else:
1385
+ ... self.mode = brainstate.mixin.Mode() # Evaluation mode
1386
+ ...
1387
+ ... def forward(self, x, rng_key):
1388
+ ... # Apply dropout only during training
1389
+ ... if self.mode is not None and self.mode.has(brainstate.mixin.Training):
1390
+ ... keep_prob = 1.0 - self.dropout_rate
1391
+ ... mask = jax.random.bernoulli(rng_key, keep_prob, x.shape)
1392
+ ... x = jnp.where(mask, x / keep_prob, 0)
1393
+ ... return x
1394
+ >>>
1395
+ >>> model = ModelWithDropout()
1396
+ >>>
1397
+ >>> # Training mode
1398
+ >>> model.set_training(True)
1399
+ >>> key = jax.random.PRNGKey(0)
1400
+ >>> x_train = jnp.ones((10, 20))
1401
+ >>> out_train = model.forward(x_train, key) # Dropout applied
1402
+ >>>
1403
+ >>> # Evaluation mode
1404
+ >>> model.set_training(False)
1405
+ >>> out_eval = model.forward(x_train, key) # No dropout
1406
+
1407
+ Combining with batching:
1408
+
1409
+ .. code-block:: python
1410
+
1411
+ >>> # Create combined training and batching mode
1412
+ >>> training = brainstate.mixin.Training()
1413
+ >>> batching = brainstate.mixin.Batching(batch_size=32)
1414
+ >>> mode = brainstate.mixin.JointMode(training, batching)
1415
+ >>>
1416
+ >>> # Use in training configuration
1417
+ >>> class Trainer:
1418
+ ... def __init__(self, model, mode):
1419
+ ... self.model = model
1420
+ ... self.mode = mode
1421
+ ...
1422
+ ... def train_epoch(self, data):
1423
+ ... if self.mode.has(brainstate.mixin.Training):
1424
+ ... # Enable training-specific behaviors
1425
+ ... self.model.set_training(True)
1426
+ ...
1427
+ ... if self.mode.has(brainstate.mixin.Batching):
1428
+ ... # Process in batches
1429
+ ... batch_size = self.mode.batch_size
1430
+ ... # ... batched training loop ...
1431
+ ... pass
1432
+ """
1433
+ pass
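Taken together, the rewrite keeps the module's public surface (ParamDesc, ParamDescriber, JointTypes, OneOfTypes, Mode, JointMode, Batching, Training) while replacing the typing internals of 0.1.10 (_SpecialForm, _remove_dups_flatten, _UnionGenericAlias) with subscriptable singletons built on _GenericAlias. A short usage sketch against the 0.2.1 code above; the Serializable/Visualizable/Model classes are illustrative, not part of the package:

import brainstate

class Serializable:
    def save(self): return "saved"

class Visualizable:
    def plot(self): return "plotted"

class Model(Serializable, Visualizable, brainstate.mixin.ParamDesc):
    def __init__(self, n, lr=0.01):
        self.n, self.lr = n, lr

# Intersection type: isinstance/issubclass delegate to _JointGenericAlias,
# so an object must satisfy *all* component types.
Full = brainstate.mixin.JointTypes[Serializable, Visualizable]
assert isinstance(Model(10), Full)
# Equality is order-insensitive (set-based __eq__).
assert Full == brainstate.mixin.JointTypes[Visualizable, Serializable]

# Deferred instantiation: desc() stores the shared arguments, the later
# call supplies or overrides the rest.
template = Model.desc(n=100)
m1, m2 = template(lr=0.1), template(lr=0.5)
assert (m1.n, m1.lr, m2.lr) == (100, 0.1, 0.5)

# JointMode combines modes and forwards attributes from its components.
mode = brainstate.mixin.JointMode(
    brainstate.mixin.Training(),
    brainstate.mixin.Batching(batch_size=32),
)
assert mode.has(brainstate.mixin.Training)
assert mode.batch_size == 32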