brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/mixin.py
CHANGED
@@ -1,365 +1,1433 @@
|
|
1
|
-
# Copyright 2024
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
'
|
33
|
-
|
34
|
-
|
35
|
-
'JointTypes',
|
36
|
-
'OneOfTypes',
|
37
|
-
|
38
|
-
|
39
|
-
'Mode',
|
40
|
-
'JointMode',
|
41
|
-
'Batching',
|
42
|
-
'Training',
|
43
|
-
]
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
class
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
self.
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
class
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
return
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
class
|
364
|
-
|
365
|
-
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
"""
|
19
|
+
Mixin classes and utility types for brainstate.
|
20
|
+
|
21
|
+
This module provides various mixin classes and custom type definitions that
|
22
|
+
enhance the functionality of brainstate components. It includes parameter
|
23
|
+
description mixins, alignment interfaces, and custom type definitions for
|
24
|
+
expressing complex type requirements.
|
25
|
+
"""
|
26
|
+
|
27
|
+
import functools
from typing import Sequence, Optional, TypeVar, Union, _GenericAlias

import jax
|
30
|
+
|
31
|
+
__all__ = [
|
32
|
+
'Mixin',
|
33
|
+
'ParamDesc',
|
34
|
+
'ParamDescriber',
|
35
|
+
'JointTypes',
|
36
|
+
'OneOfTypes',
|
37
|
+
'_JointGenericAlias',
|
38
|
+
'_OneOfGenericAlias',
|
39
|
+
'Mode',
|
40
|
+
'JointMode',
|
41
|
+
'Batching',
|
42
|
+
'Training',
|
43
|
+
]
|
44
|
+
|
45
|
+
T = TypeVar('T')
|
46
|
+
ArrayLike = jax.typing.ArrayLike
|
47
|
+
|
48
|
+
|
49
|
+
def hashable(x):
    """
    Determine whether ``hash(x)`` succeeds.

    Parameters
    ----------
    x : Any
        Object to probe for hashability.

    Returns
    -------
    bool
        ``True`` when ``x`` can be hashed, ``False`` otherwise.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> assert brainstate.mixin.hashable(42) == True
        >>> assert brainstate.mixin.hashable("string") == True
        >>> assert brainstate.mixin.hashable([1, 2, 3]) == False
        >>> assert brainstate.mixin.hashable({"key": "value"}) == False
    """
    # EAFP: simply attempt the hash and report whether it raised.
    try:
        hash(x)
    except TypeError:
        return False
    return True
|
83
|
+
|
84
|
+
|
85
|
+
class Mixin(object):
    """
    Root class for all behavioral mixins in brainstate.

    A mixin contributes methods only: it defines no ``__init__`` and holds no
    state of its own, so it can be combined freely through multiple
    inheritance to graft extra behavior onto other classes without the
    complexity of a full base class.

    Notes
    -----
    Subclasses intended as mixins should not define ``__init__``; they should
    only provide methods that add specific behaviors to the classes that
    inherit from them.

    Examples
    --------
    Creating a custom mixin:

    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class LoggingMixin(brainstate.mixin.Mixin):
        ...     def log(self, message):
        ...         print(f"[{self.__class__.__name__}] {message}")

        >>> class MyComponent(brainstate.nn.Module, LoggingMixin):
        ...     def __init__(self):
        ...         super().__init__()
        ...
        ...     def process(self):
        ...         self.log("Processing data...")
        ...         return "Done"
        >>>
        >>> component = MyComponent()
        >>> component.process()  # Prints: [MyComponent] Processing data...
    """
    pass
|
123
|
+
|
124
|
+
|
125
|
+
class ParamDesc(Mixin):
    """
    Mixin that adds a ``desc`` classmethod for deferred parameterization.

    Classes inheriting this mixin gain :meth:`desc`, which captures
    constructor arguments in a :py:class:`~.ParamDescriber` so that the same
    configuration can be replayed later to build any number of instances.

    Attributes
    ----------
    non_hashable_params : sequence of str, optional
        Names of constructor parameters known to be unhashable; such values
        receive special handling when hashable identifiers are built.

    Notes
    -----
    This mixin works with any Python class, not only brainstate classes.

    Examples
    --------
    Basic usage:

    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class NeuronModel(brainstate.mixin.ParamDesc):
        ...     def __init__(self, size, tau=10.0, threshold=1.0):
        ...         self.size = size
        ...         self.tau = tau
        ...         self.threshold = threshold
        >>>
        >>> # Capture a reusable parameter template
        >>> neuron_desc = NeuronModel.desc(size=100, tau=20.0)
        >>>
        >>> # Instantiate from the template, overriding per call
        >>> neuron1 = neuron_desc(threshold=0.8)
        >>> neuron2 = neuron_desc(threshold=1.2)

    Reusable population templates:

    .. code-block:: python

        >>> exc_neuron_template = NeuronModel.desc(size=1000, tau=10.0, threshold=1.0)
        >>> inh_neuron_template = NeuronModel.desc(size=250, tau=5.0, threshold=0.5)
        >>>
        >>> exc_population = [exc_neuron_template() for _ in range(5)]
        >>> inh_population = [inh_neuron_template() for _ in range(2)]
    """

    # Parameter names whose values cannot be hashed; they are stringified
    # whenever a hashable identifier must be derived.
    non_hashable_params: Optional[Sequence[str]] = None

    @classmethod
    def desc(cls, *args, **kwargs) -> 'ParamDescriber':
        """
        Capture ``*args`` and ``**kwargs`` into a reusable describer.

        Parameters
        ----------
        *args
            Positional arguments to replay at instantiation time.
        **kwargs
            Keyword arguments to replay at instantiation time.

        Returns
        -------
        ParamDescriber
            A describer that instantiates ``cls`` with the stored arguments.
        """
        return ParamDescriber(cls, *args, **kwargs)
|
203
|
+
|
204
|
+
|
205
|
+
class HashableDict(dict):
    """
    Dictionary variant that is always hashable.

    Values that cannot be hashed are replaced by their string representation
    at construction time, so the resulting mapping can participate in cache
    keys, sets, and other contexts that require hashability.

    Parameters
    ----------
    the_dict : dict
        The dictionary whose contents should be made hashable.

    Notes
    -----
    Only values are sanitized; keys are assumed hashable already (they must
    be, to live in the input dict).

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> # A plain dict holding an array cannot be hashed ...
        >>> regular_dict = {"array": jnp.array([1, 2, 3]), "value": 42}
        >>>
        >>> # ... but its HashableDict counterpart can
        >>> hashable = brainstate.mixin.HashableDict(regular_dict)
        >>> key = hash(hashable)
        >>>
        >>> # Usable as a dict key
        >>> cache = {hashable: "result"}
    """

    def __init__(self, the_dict: dict):
        # Sanitize: stringify any value that hash() rejects.
        sanitized = {}
        for key, value in the_dict.items():
            try:
                hash(value)
            except TypeError:
                value = str(value)
            sanitized[key] = value
        super().__init__(sanitized)

    def __hash__(self):
        """
        Order-independent hash derived from the sorted ``(key, value)`` pairs.
        """
        return hash(tuple(sorted(self.items())))
|
256
|
+
|
257
|
+
|
258
|
+
class NoSubclassMeta(type):
    """
    Metaclass that makes its classes final (non-subclassable).

    Classes created with this metaclass (such as ``ParamDescriber``) are
    meant to be used as-is; forbidding inheritance prevents extensions that
    could silently change their behavior.

    Raises
    ------
    TypeError
        When a new class lists a ``NoSubclassMeta``-built class among its
        bases.
    """

    def __new__(cls, name, bases, classdict):
        # Refuse construction if any base was itself produced by this metaclass.
        offender = next((b for b in bases if isinstance(b, NoSubclassMeta)), None)
        if offender is not None:
            raise TypeError("type '{0}' is not an acceptable base type".format(offender.__name__))
        return type.__new__(cls, name, bases, dict(classdict))
|
278
|
+
|
279
|
+
|
280
|
+
class ParamDescriber(metaclass=NoSubclassMeta):
    """
    Deferred-instantiation wrapper around a class and its arguments.

    A ``ParamDescriber`` remembers a target class together with positional
    and keyword arguments. Calling the describer instantiates the class,
    merging the stored arguments with any supplied at call time (call-time
    keywords win). Each describer also exposes a hashable ``identifier`` so
    it can serve as a cache key.

    Parameters
    ----------
    cls : type
        The class to instantiate later.
    *desc_tuple
        Positional arguments captured for future instantiations.
    **desc_dict
        Keyword arguments captured for future instantiations.

    Attributes
    ----------
    cls : type
        The wrapped class.
    args : tuple
        Captured positional arguments.
    kwargs : dict
        Captured keyword arguments.
    identifier : tuple
        Read-only hashable ``(cls, args, kwargs)`` triple.

    Notes
    -----
    ``NoSubclassMeta`` makes this class final; attempting to subclass it
    raises ``TypeError``.

    Examples
    --------
    Manual construction:

    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class Network:
        ...     def __init__(self, n_neurons, learning_rate=0.01):
        ...         self.n_neurons = n_neurons
        ...         self.learning_rate = learning_rate
        >>>
        >>> network_desc = brainstate.mixin.ParamDescriber(
        ...     Network, n_neurons=1000, learning_rate=0.001
        ... )
        >>> net1 = network_desc()
        >>> net2 = network_desc()  # Same configuration

    Through the ``ParamDesc`` mixin:

    .. code-block:: python

        >>> class Network(brainstate.mixin.ParamDesc):
        ...     def __init__(self, n_neurons, learning_rate=0.01):
        ...         self.n_neurons = n_neurons
        ...         self.learning_rate = learning_rate
        >>>
        >>> network_desc = Network.desc(n_neurons=1000)
        >>> net = network_desc(learning_rate=0.005)  # Override learning_rate
    """

    def __init__(self, cls: T, *desc_tuple, **desc_dict):
        # The class to build, plus the captured call arguments.
        self.cls: type = cls
        self.args = desc_tuple
        self.kwargs = desc_dict
        # Hashable fingerprint for caching/comparison; unhashable kwarg
        # values are stringified inside HashableDict.
        self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))

    def __call__(self, *args, **kwargs) -> T:
        """
        Build an instance of ``self.cls``.

        Stored positionals come first, followed by call-time positionals;
        call-time keywords override stored keywords.

        Returns
        -------
        T
            An instance of the described class.
        """
        return self.cls(*self.args, *args, **{**self.kwargs, **kwargs})

    def init(self, *args, **kwargs):
        """
        Explicitly-named alias of :meth:`__call__`.

        Returns
        -------
        T
            An instance of the described class.
        """
        return self(*args, **kwargs)

    def __instancecheck__(self, instance):
        """
        ``isinstance(x, describer)`` holds when ``x`` is itself a
        ``ParamDescriber`` whose target class subclasses this one's.
        """
        return isinstance(instance, ParamDescriber) and issubclass(instance.cls, self.cls)

    @classmethod
    def __class_getitem__(cls, item: type):
        """
        Allow ``ParamDescriber[SomeClass]`` as shorthand for
        ``ParamDescriber(SomeClass)``.
        """
        return ParamDescriber(item)

    @property
    def identifier(self):
        """
        Hashable ``(cls, args, kwargs)`` fingerprint of this describer.
        """
        return self._identifier

    @identifier.setter
    def identifier(self, value: ArrayLike):
        # The identifier is immutable by design.
        raise AttributeError('Cannot set the identifier.')
|
463
|
+
|
464
|
+
|
465
|
+
def not_implemented(func):
    """
    Decorator to mark a function as not implemented.

    The returned wrapper raises ``NotImplementedError`` when called and
    carries a ``not_implemented`` attribute so callers can detect stub
    methods via ``hasattr``. The wrapped function's metadata (``__name__``,
    ``__doc__``, etc.) is preserved with :func:`functools.wraps` so
    introspection and error messages stay accurate.

    Parameters
    ----------
    func : callable
        The function to mark as not implemented.

    Returns
    -------
    callable
        A wrapper that raises ``NotImplementedError`` when invoked.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class BaseModel:
        ...     @brainstate.mixin.not_implemented
        ...     def process(self, x):
        ...         pass
        >>>
        >>> model = BaseModel()
        >>> # model.process(10)  # Raises: NotImplementedError: process is not implemented.
        >>>
        >>> # Check if a method is not implemented
        >>> assert hasattr(BaseModel.process, 'not_implemented')
    """

    # functools.wraps keeps func's metadata on the wrapper (fixes the
    # original defect where the stub masqueraded as a function named
    # 'wrapper' in tracebacks and help()).
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        raise NotImplementedError(f'{func.__name__} is not implemented.')

    # Marker attribute used to detect not-implemented methods.
    wrapper.not_implemented = True
    return wrapper
|
506
|
+
|
507
|
+
|
508
|
+
class _JointGenericAlias(_GenericAlias, _root=True):
|
509
|
+
"""
|
510
|
+
Generic alias for JointTypes (intersection types).
|
511
|
+
|
512
|
+
This class represents a type that requires all specified types to be satisfied.
|
513
|
+
Unlike _MetaUnionType which creates actual classes with metaclass conflicts,
|
514
|
+
this uses typing's generic alias system to avoid metaclass issues.
|
515
|
+
"""
|
516
|
+
|
517
|
+
def __instancecheck__(self, obj):
|
518
|
+
"""
|
519
|
+
Check if an instance is an instance of all component types.
|
520
|
+
"""
|
521
|
+
return all(isinstance(obj, cls) for cls in self.__args__)
|
522
|
+
|
523
|
+
def __subclasscheck__(self, subclass):
|
524
|
+
"""
|
525
|
+
Check if a class is a subclass of all component types.
|
526
|
+
"""
|
527
|
+
return all(issubclass(subclass, cls) for cls in self.__args__)
|
528
|
+
|
529
|
+
def __eq__(self, other):
|
530
|
+
"""
|
531
|
+
Check equality with another type.
|
532
|
+
|
533
|
+
Two JointTypes are equal if they have the same component types,
|
534
|
+
regardless of order.
|
535
|
+
"""
|
536
|
+
if not isinstance(other, _JointGenericAlias):
|
537
|
+
return NotImplemented
|
538
|
+
return set(self.__args__) == set(other.__args__)
|
539
|
+
|
540
|
+
def __hash__(self):
|
541
|
+
"""
|
542
|
+
Return hash of the JointType.
|
543
|
+
|
544
|
+
The hash is based on the frozenset of component types to ensure
|
545
|
+
that JointTypes with the same types (regardless of order) have
|
546
|
+
the same hash.
|
547
|
+
"""
|
548
|
+
return hash(frozenset(self.__args__))
|
549
|
+
|
550
|
+
def __repr__(self):
|
551
|
+
"""
|
552
|
+
Return string representation of the JointType.
|
553
|
+
|
554
|
+
Returns a readable representation showing all component types.
|
555
|
+
"""
|
556
|
+
args_str = ', '.join(
|
557
|
+
arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
|
558
|
+
else str(arg)
|
559
|
+
for arg in self.__args__
|
560
|
+
)
|
561
|
+
return f'JointTypes[{args_str}]'
|
562
|
+
|
563
|
+
def __reduce__(self):
|
564
|
+
"""
|
565
|
+
Support for pickling.
|
566
|
+
|
567
|
+
Returns the necessary information to reconstruct the JointType
|
568
|
+
when unpickling.
|
569
|
+
"""
|
570
|
+
return (_JointGenericAlias, (self.__origin__, self.__args__))
|
571
|
+
|
572
|
+
|
573
|
+
class _OneOfGenericAlias(_GenericAlias, _root=True):
|
574
|
+
"""
|
575
|
+
Generic alias for OneOfTypes (union types).
|
576
|
+
|
577
|
+
This class represents a type that requires at least one of the specified
|
578
|
+
types to be satisfied. It's similar to typing.Union but provides a consistent
|
579
|
+
interface with JointTypes and avoids potential metaclass conflicts.
|
580
|
+
"""
|
581
|
+
|
582
|
+
def __instancecheck__(self, obj):
|
583
|
+
"""
|
584
|
+
Check if an instance is an instance of any component type.
|
585
|
+
"""
|
586
|
+
return any(isinstance(obj, cls) for cls in self.__args__)
|
587
|
+
|
588
|
+
def __subclasscheck__(self, subclass):
|
589
|
+
"""
|
590
|
+
Check if a class is a subclass of any component type.
|
591
|
+
"""
|
592
|
+
return any(issubclass(subclass, cls) for cls in self.__args__)
|
593
|
+
|
594
|
+
def __eq__(self, other):
|
595
|
+
"""
|
596
|
+
Check equality with another type.
|
597
|
+
|
598
|
+
Two OneOfTypes are equal if they have the same component types,
|
599
|
+
regardless of order.
|
600
|
+
"""
|
601
|
+
if not isinstance(other, _OneOfGenericAlias):
|
602
|
+
return NotImplemented
|
603
|
+
return set(self.__args__) == set(other.__args__)
|
604
|
+
|
605
|
+
def __hash__(self):
|
606
|
+
"""
|
607
|
+
Return hash of the OneOfType.
|
608
|
+
|
609
|
+
The hash is based on the frozenset of component types to ensure
|
610
|
+
that OneOfTypes with the same types (regardless of order) have
|
611
|
+
the same hash.
|
612
|
+
"""
|
613
|
+
return hash(frozenset(self.__args__))
|
614
|
+
|
615
|
+
def __repr__(self):
|
616
|
+
"""
|
617
|
+
Return string representation of the OneOfType.
|
618
|
+
|
619
|
+
Returns a readable representation showing all component types.
|
620
|
+
"""
|
621
|
+
args_str = ', '.join(
|
622
|
+
arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
|
623
|
+
else str(arg)
|
624
|
+
for arg in self.__args__
|
625
|
+
)
|
626
|
+
return f'OneOfTypes[{args_str}]'
|
627
|
+
|
628
|
+
def __reduce__(self):
|
629
|
+
"""
|
630
|
+
Support for pickling.
|
631
|
+
|
632
|
+
Returns the necessary information to reconstruct the OneOfType
|
633
|
+
when unpickling.
|
634
|
+
"""
|
635
|
+
return (_OneOfGenericAlias, (self.__origin__, self.__args__))
|
636
|
+
|
637
|
+
|
638
|
+
class _JointTypesClass:
|
639
|
+
"""Helper class to enable subscript syntax for JointTypes."""
|
640
|
+
|
641
|
+
def __call__(self, *types):
|
642
|
+
"""
|
643
|
+
Create a type that requires all specified types (intersection type).
|
644
|
+
|
645
|
+
This function creates a type hint that indicates a value must satisfy all
|
646
|
+
the specified types simultaneously. It's useful for expressing complex
|
647
|
+
type requirements where a single object must implement multiple interfaces.
|
648
|
+
|
649
|
+
Parameters
|
650
|
+
----------
|
651
|
+
*types : type
|
652
|
+
The types that must all be satisfied.
|
653
|
+
|
654
|
+
Returns
|
655
|
+
-------
|
656
|
+
type
|
657
|
+
A type that checks for all specified types.
|
658
|
+
|
659
|
+
Notes
|
660
|
+
-----
|
661
|
+
- If only one type is provided, that type is returned directly.
|
662
|
+
- Redundant types are automatically removed.
|
663
|
+
- The order of types doesn't matter for equality checks.
|
664
|
+
|
665
|
+
Examples
|
666
|
+
--------
|
667
|
+
Basic usage with interfaces:
|
668
|
+
|
669
|
+
.. code-block:: python
|
670
|
+
|
671
|
+
>>> import brainstate
|
672
|
+
>>> from typing import Protocol
|
673
|
+
>>>
|
674
|
+
>>> class Trainable(Protocol):
|
675
|
+
... def train(self): ...
|
676
|
+
>>>
|
677
|
+
>>> class Evaluable(Protocol):
|
678
|
+
... def evaluate(self): ...
|
679
|
+
>>>
|
680
|
+
>>> # A model that is both trainable and evaluable
|
681
|
+
>>> TrainableEvaluableModel = brainstate.mixin.JointTypes(Trainable, Evaluable)
|
682
|
+
>>> # Or using subscript syntax
|
683
|
+
>>> TrainableEvaluableModel = brainstate.mixin.JointTypes[Trainable, Evaluable]
|
684
|
+
>>>
|
685
|
+
>>> class NeuralNetwork(Trainable, Evaluable):
|
686
|
+
... def train(self):
|
687
|
+
... return "Training..."
|
688
|
+
...
|
689
|
+
... def evaluate(self):
|
690
|
+
... return "Evaluating..."
|
691
|
+
>>>
|
692
|
+
>>> model = NeuralNetwork()
|
693
|
+
>>> # model satisfies JointTypes(Trainable, Evaluable)
|
694
|
+
|
695
|
+
Using with mixin classes:
|
696
|
+
|
697
|
+
.. code-block:: python
|
698
|
+
|
699
|
+
>>> class Serializable:
|
700
|
+
... def save(self): pass
|
701
|
+
>>>
|
702
|
+
>>> class Visualizable:
|
703
|
+
... def plot(self): pass
|
704
|
+
>>>
|
705
|
+
>>> # Require both serialization and visualization
|
706
|
+
>>> FullFeaturedModel = brainstate.mixin.JointTypes[Serializable, Visualizable]
|
707
|
+
>>>
|
708
|
+
>>> class MyModel(Serializable, Visualizable):
|
709
|
+
... def save(self):
|
710
|
+
... return "Saved"
|
711
|
+
...
|
712
|
+
... def plot(self):
|
713
|
+
... return "Plotted"
|
714
|
+
"""
|
715
|
+
if len(types) == 0:
|
716
|
+
raise TypeError("Cannot create a JointTypes of no types.")
|
717
|
+
|
718
|
+
# Remove duplicates while preserving some order
|
719
|
+
seen = set()
|
720
|
+
unique_types = []
|
721
|
+
for t in types:
|
722
|
+
if t not in seen:
|
723
|
+
seen.add(t)
|
724
|
+
unique_types.append(t)
|
725
|
+
|
726
|
+
# If only one type, return it directly
|
727
|
+
if len(unique_types) == 1:
|
728
|
+
return unique_types[0]
|
729
|
+
|
730
|
+
# Create a generic alias for the joint type
|
731
|
+
# This avoids metaclass conflicts by using typing's generic alias system
|
732
|
+
return _JointGenericAlias(object, tuple(unique_types))
|
733
|
+
|
734
|
+
def __getitem__(self, item):
|
735
|
+
"""Enable subscript syntax: JointTypes[Type1, Type2]."""
|
736
|
+
if isinstance(item, tuple):
|
737
|
+
return self(*item)
|
738
|
+
else:
|
739
|
+
return self(item)
|
740
|
+
|
741
|
+
|
742
|
+
# Module-level singleton: usable as ``JointTypes(A, B)`` or ``JointTypes[A, B]``.
JointTypes = _JointTypesClass()
|
744
|
+
|
745
|
+
|
746
|
+
class _OneOfTypesClass:
|
747
|
+
"""Helper class to enable subscript syntax for OneOfTypes."""
|
748
|
+
|
749
|
+
def __call__(self, *types):
|
750
|
+
"""
|
751
|
+
Create a type that requires one of the specified types (union type).
|
752
|
+
|
753
|
+
This is similar to typing.Union but provides a more intuitive name and
|
754
|
+
consistent behavior with JointTypes. It indicates that a value must satisfy
|
755
|
+
at least one of the specified types.
|
756
|
+
|
757
|
+
Parameters
|
758
|
+
----------
|
759
|
+
*types : type
|
760
|
+
The types, one of which must be satisfied.
|
761
|
+
|
762
|
+
Returns
|
763
|
+
-------
|
764
|
+
Union type
|
765
|
+
A union type of the specified types.
|
766
|
+
|
767
|
+
Notes
|
768
|
+
-----
|
769
|
+
- If only one type is provided, that type is returned directly.
|
770
|
+
- Redundant types are automatically removed.
|
771
|
+
- The order of types doesn't matter for equality checks.
|
772
|
+
- This is equivalent to typing.Union[...].
|
773
|
+
|
774
|
+
Examples
|
775
|
+
--------
|
776
|
+
Basic usage with different types:
|
777
|
+
|
778
|
+
.. code-block:: python
|
779
|
+
|
780
|
+
>>> import brainstate
|
781
|
+
>>>
|
782
|
+
>>> # A parameter that can be int or float
|
783
|
+
>>> NumericType = brainstate.mixin.OneOfTypes(int, float)
|
784
|
+
>>> # Or using subscript syntax
|
785
|
+
>>> NumericType = brainstate.mixin.OneOfTypes[int, float]
|
786
|
+
>>>
|
787
|
+
>>> def process_value(x: NumericType):
|
788
|
+
... return x * 2
|
789
|
+
>>>
|
790
|
+
>>> # Both work
|
791
|
+
>>> result1 = process_value(5) # int
|
792
|
+
>>> result2 = process_value(3.14) # float
|
793
|
+
|
794
|
+
Using with class types:
|
795
|
+
|
796
|
+
.. code-block:: python
|
797
|
+
|
798
|
+
>>> class NumpyArray:
|
799
|
+
... pass
|
800
|
+
>>>
|
801
|
+
>>> class JAXArray:
|
802
|
+
... pass
|
803
|
+
>>>
|
804
|
+
>>> # Accept either numpy or JAX arrays
|
805
|
+
>>> ArrayType = brainstate.mixin.OneOfTypes[NumpyArray, JAXArray]
|
806
|
+
>>>
|
807
|
+
>>> def compute(arr: ArrayType):
|
808
|
+
... if isinstance(arr, NumpyArray):
|
809
|
+
... return "Processing numpy array"
|
810
|
+
... elif isinstance(arr, JAXArray):
|
811
|
+
... return "Processing JAX array"
|
812
|
+
|
813
|
+
Combining with None for optional types:
|
814
|
+
|
815
|
+
.. code-block:: python
|
816
|
+
|
817
|
+
>>> # Optional string (equivalent to Optional[str])
|
818
|
+
>>> MaybeString = brainstate.mixin.OneOfTypes[str, type(None)]
|
819
|
+
>>>
|
820
|
+
>>> def format_name(name: MaybeString) -> str:
|
821
|
+
... if name is None:
|
822
|
+
... return "Anonymous"
|
823
|
+
... return name.title()
|
824
|
+
"""
|
825
|
+
if len(types) == 0:
|
826
|
+
raise TypeError("Cannot create a OneOfTypes of no types.")
|
827
|
+
|
828
|
+
# Remove duplicates
|
829
|
+
seen = set()
|
830
|
+
unique_types = []
|
831
|
+
for t in types:
|
832
|
+
if t not in seen:
|
833
|
+
seen.add(t)
|
834
|
+
unique_types.append(t)
|
835
|
+
|
836
|
+
# If only one type, return it directly
|
837
|
+
if len(unique_types) == 1:
|
838
|
+
return unique_types[0]
|
839
|
+
|
840
|
+
# Create a generic alias for the union type
|
841
|
+
# This provides consistency with JointTypes and avoids metaclass conflicts
|
842
|
+
return _OneOfGenericAlias(Union, tuple(unique_types))
|
843
|
+
|
844
|
+
def __getitem__(self, item):
|
845
|
+
"""Enable subscript syntax: OneOfTypes[Type1, Type2]."""
|
846
|
+
if isinstance(item, tuple):
|
847
|
+
return self(*item)
|
848
|
+
else:
|
849
|
+
return self(item)
|
850
|
+
|
851
|
+
|
852
|
+
# Module-level singleton: usable as ``OneOfTypes(A, B)`` or ``OneOfTypes[A, B]``.
OneOfTypes = _OneOfTypesClass()
|
854
|
+
|
855
|
+
|
856
|
+
def __getattr__(name):
    """
    Module-level fallback (PEP 562) resolving deprecated mixin names.

    A handful of names have moved to ``brainpy.mixin``.  Accessing them
    through this module still works but emits a ``DeprecationWarning``
    and forwards the lookup to ``brainpy``.  Note that this hook only
    fires when normal module lookup fails, so names also defined in this
    module are served locally.

    Raises
    ------
    AttributeError
        For any other unknown attribute name.
    """
    _moved = (
        'Mode',
        'JointMode',
        'Batching',
        'Training',
        'AlignPost',
        'BindCondData',
    )
    if name in _moved:
        import warnings
        warnings.warn(
            f"brainstate.mixin.{name} is deprecated and will be removed in a future version. "
            f"Please use brainpy.mixin.{name} instead.",
            DeprecationWarning,
            stacklevel=2
        )
        import brainpy
        return getattr(brainpy.mixin, name)
    raise AttributeError(
        f'module {__name__!r} has no attribute {name!r}'
    )
|
877
|
+
|
878
|
+
|
879
|
+
class Mode(Mixin):
    """
    Base class for computation behavior modes.

    Modes represent computational contexts or behaviors — training vs.
    evaluation, batched vs. single-sample processing — and give models a
    uniform way to adapt what they do to the current scenario.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class FastMode(brainstate.mixin.Mode):
        ...     pass
        >>>
        >>> fast = FastMode()
        >>> assert fast.is_a(FastMode)              # exact type match
        >>> assert fast.has(brainstate.mixin.Mode)  # subclass-aware check
    """

    def __repr__(self):
        """Return the class name of the mode."""
        return self.__class__.__name__

    def __eq__(self, other):
        """
        Two modes are equal when they are instances of the same class.

        Returns ``NotImplemented`` for non-Mode operands so Python can
        fall back to the reflected comparison.  (The previous
        ``assert isinstance(other, Mode)`` raised ``AssertionError`` on
        foreign operands and was silently stripped under ``python -O``.)
        """
        if not isinstance(other, Mode):
            return NotImplemented
        return other.__class__ == self.__class__

    def __hash__(self):
        """
        Hash by class, consistently with ``__eq__``.

        Defining ``__eq__`` without ``__hash__`` would set ``__hash__``
        to ``None`` and make Mode instances unhashable; hashing by class
        keeps them usable in sets and as dict keys.
        """
        return hash(self.__class__)

    def is_a(self, mode: type) -> bool:
        """
        Check whether this mode is exactly the desired mode type.

        Performs an exact type match — subclasses do not count.

        Parameters
        ----------
        mode : type
            The mode type to check against.

        Returns
        -------
        bool
            True if this mode's class is exactly ``mode``.
        """
        assert isinstance(mode, type), 'Must be a type.'
        return self.__class__ == mode

    def has(self, mode: type) -> bool:
        """
        Check whether this mode includes the desired mode type.

        Subclass-aware: returns True if this mode is an instance of
        ``mode``, including through inheritance.

        Parameters
        ----------
        mode : type
            The mode type to check for.

        Returns
        -------
        bool
            True if this mode is an instance of ``mode``.
        """
        assert isinstance(mode, type), 'Must be a type.'
        return isinstance(self, mode)
|
1037
|
+
|
1038
|
+
|
1039
|
+
class JointMode(Mode):
    """
    A mode combining several modes at once.

    Useful when a computation is simultaneously in multiple modes (e.g.
    training *and* batching).  Attribute access transparently falls
    through to the first combined mode that defines the attribute.

    Parameters
    ----------
    *modes : Mode
        The modes to combine.

    Attributes
    ----------
    modes : tuple of Mode
        The combined mode instances.
    types : set of type
        The classes of the combined modes.

    Raises
    ------
    TypeError
        If any argument is not a Mode instance.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> joint = brainstate.mixin.JointMode(
        ...     brainstate.mixin.Training(),
        ...     brainstate.mixin.Batching(batch_size=32),
        ... )
        >>> joint.has(brainstate.mixin.Training)
        True
        >>> joint.batch_size   # delegated to the Batching mode
        32
    """

    def __init__(self, *modes: Mode):
        # Reject anything that is not a Mode up front.
        for candidate in modes:
            if not isinstance(candidate, Mode):
                raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {candidate}')

        # Keep the instances, plus a set of their classes for fast checks.
        self.modes = tuple(modes)
        self.types = {candidate.__class__ for candidate in modes}

    def __repr__(self):
        """Return the joint mode showing the repr of each component."""
        inner = ", ".join(repr(component) for component in self.modes)
        return f'{self.__class__.__name__}({inner})'

    def has(self, mode: type) -> bool:
        """
        Check whether any combined mode includes the desired type.

        Parameters
        ----------
        mode : type
            The mode type to check for.

        Returns
        -------
        bool
            True if any combined mode's class is ``mode`` or derives
            from it.
        """
        assert isinstance(mode, type), 'Must be a type.'
        return any(issubclass(cls, mode) for cls in self.types)

    def is_a(self, cls: type) -> bool:
        """
        Check whether the joint mode exactly matches a combined type.

        The combined classes are folded through ``JointTypes`` and
        compared (order-insensitively) against ``cls``.
        """
        return JointTypes(*tuple(self.types)) == cls

    def __getattr__(self, item):
        """
        Delegate unknown attributes to the combined modes.

        Each combined mode is searched in order and the first match is
        returned, so e.g. ``joint.batch_size`` works whenever a
        ``Batching`` mode is part of the combination.

        Raises
        ------
        AttributeError
            If no combined mode provides the attribute.
        """
        # Never search through ourselves for the two core attributes —
        # that would recurse before __init__ has set them.
        if item in ('modes', 'types'):
            return super().__getattribute__(item)

        for component in self.modes:
            if hasattr(component, item):
                return getattr(component, item)

        # Default lookup raises the usual AttributeError.
        return super().__getattribute__(item)
|
1239
|
+
|
1240
|
+
|
1241
|
+
class Batching(Mode):
    """
    Mode indicating batched computation.

    Carries the batch size and the axis holding the batch dimension so
    components can adapt their processing to batched inputs.

    Parameters
    ----------
    batch_size : int, default 1
        The number of samples per batch.
    batch_axis : int, default 0
        The axis index of the batch dimension.

    Attributes
    ----------
    batch_size : int
        The number of samples in each batch.
    batch_axis : int
        The axis index representing the batch dimension.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=0)
        >>> print(batching)          # Batching(in_size=32, axis=0)
        >>> batching.batch_size
        32
        >>> # Frequently combined with training:
        >>> mode = brainstate.mixin.JointMode(
        ...     brainstate.mixin.Training(), batching
        ... )
    """

    def __init__(self, batch_size: int = 1, batch_axis: int = 0):
        # Stored verbatim; consumers read these to shape their computation.
        self.batch_size = batch_size
        self.batch_axis = batch_axis

    def __repr__(self):
        """Return the batch configuration, e.g. ``Batching(in_size=32, axis=0)``."""
        return f'{self.__class__.__name__}(in_size={self.batch_size}, axis={self.batch_axis})'
|
1342
|
+
|
1343
|
+
|
1344
|
+
class Training(Mode):
    """
    Mode indicating training computation.

    Marks that a model is being trained, typically enabling behaviors
    such as dropout, train-mode normalization, and gradient computation.
    The class carries no state of its own; its identity alone conveys
    the mode.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> training = brainstate.mixin.Training()
        >>> training.is_a(brainstate.mixin.Training)
        True
        >>> training.has(brainstate.mixin.Mode)
        True
        >>> # Frequently combined with batching:
        >>> mode = brainstate.mixin.JointMode(
        ...     training, brainstate.mixin.Batching(batch_size=32)
        ... )
    """
    pass
|