brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/mixin.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -15,35 +15,66 @@
|
|
15
15
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
|
-
|
19
|
-
|
20
|
-
)
|
18
|
+
"""
|
19
|
+
Mixin classes and utility types for brainstate.
|
21
20
|
|
22
|
-
|
21
|
+
This module provides various mixin classes and custom type definitions that
|
22
|
+
enhance the functionality of brainstate components. It includes parameter
|
23
|
+
description mixins, alignment interfaces, and custom type definitions for
|
24
|
+
expressing complex type requirements.
|
25
|
+
"""
|
23
26
|
|
24
|
-
|
25
|
-
|
27
|
+
from typing import Sequence, Optional, TypeVar, Union, _GenericAlias
|
28
|
+
|
29
|
+
import jax
|
26
30
|
|
27
31
|
__all__ = [
|
28
32
|
'Mixin',
|
29
33
|
'ParamDesc',
|
30
34
|
'ParamDescriber',
|
31
|
-
'AlignPost',
|
32
|
-
'BindCondData',
|
33
|
-
|
34
|
-
# types
|
35
35
|
'JointTypes',
|
36
36
|
'OneOfTypes',
|
37
|
-
|
38
|
-
|
37
|
+
'_JointGenericAlias',
|
38
|
+
'_OneOfGenericAlias',
|
39
39
|
'Mode',
|
40
40
|
'JointMode',
|
41
41
|
'Batching',
|
42
42
|
'Training',
|
43
43
|
]
|
44
44
|
|
45
|
+
T = TypeVar('T')
|
46
|
+
ArrayLike = jax.typing.ArrayLike
|
47
|
+
|
45
48
|
|
46
49
|
def hashable(x):
|
50
|
+
"""
|
51
|
+
Check if an object is hashable.
|
52
|
+
|
53
|
+
Parameters
|
54
|
+
----------
|
55
|
+
x : Any
|
56
|
+
The object to check for hashability.
|
57
|
+
|
58
|
+
Returns
|
59
|
+
-------
|
60
|
+
bool
|
61
|
+
True if the object is hashable, False otherwise.
|
62
|
+
|
63
|
+
Examples
|
64
|
+
--------
|
65
|
+
.. code-block:: python
|
66
|
+
|
67
|
+
>>> import brainstate
|
68
|
+
>>>
|
69
|
+
>>> # Hashable objects
|
70
|
+
>>> assert brainstate.mixin.hashable(42) == True
|
71
|
+
>>> assert brainstate.mixin.hashable("string") == True
|
72
|
+
>>> assert brainstate.mixin.hashable((1, 2, 3)) == True
|
73
|
+
>>>
|
74
|
+
>>> # Non-hashable objects
|
75
|
+
>>> assert brainstate.mixin.hashable([1, 2, 3]) == False
|
76
|
+
>>> assert brainstate.mixin.hashable({"key": "value"}) == False
|
77
|
+
"""
|
47
78
|
try:
|
48
79
|
hash(x)
|
49
80
|
return True
|
@@ -52,45 +83,194 @@ def hashable(x):
|
|
52
83
|
|
53
84
|
|
54
85
|
class Mixin(object):
|
55
|
-
"""
|
56
|
-
|
57
|
-
|
86
|
+
"""
|
87
|
+
Base Mixin object for behavioral extensions.
|
88
|
+
|
89
|
+
The key characteristic of a :py:class:`~.Mixin` is that it provides only
|
90
|
+
behavioral functions without requiring initialization. Mixins are used to
|
91
|
+
add specific functionality to classes through multiple inheritance without
|
92
|
+
the complexity of a full base class.
|
93
|
+
|
94
|
+
Notes
|
95
|
+
-----
|
96
|
+
Mixins should not define ``__init__`` methods. They should only provide
|
97
|
+
methods that add specific behaviors to the classes that inherit from them.
|
98
|
+
|
99
|
+
Examples
|
100
|
+
--------
|
101
|
+
Creating a custom mixin:
|
102
|
+
|
103
|
+
.. code-block:: python
|
104
|
+
|
105
|
+
>>> import brainstate
|
106
|
+
>>>
|
107
|
+
>>> class LoggingMixin(brainstate.mixin.Mixin):
|
108
|
+
... def log(self, message):
|
109
|
+
... print(f"[{self.__class__.__name__}] {message}")
|
110
|
+
|
111
|
+
>>> class MyComponent(brainstate.nn.Module, LoggingMixin):
|
112
|
+
... def __init__(self):
|
113
|
+
... super().__init__()
|
114
|
+
...
|
115
|
+
... def process(self):
|
116
|
+
... self.log("Processing data...")
|
117
|
+
... return "Done"
|
118
|
+
>>>
|
119
|
+
>>> component = MyComponent()
|
120
|
+
>>> component.process() # Prints: [MyComponent] Processing data...
|
58
121
|
"""
|
59
122
|
pass
|
60
123
|
|
61
124
|
|
62
125
|
class ParamDesc(Mixin):
|
63
126
|
"""
|
64
|
-
|
65
|
-
|
66
|
-
This mixin enables
|
67
|
-
|
68
|
-
|
69
|
-
|
127
|
+
Mixin for describing initialization parameters.
|
128
|
+
|
129
|
+
This mixin enables a class to have a ``desc`` classmethod, which produces
|
130
|
+
an instance of :py:class:`~.ParamDescriber`. This is useful for creating
|
131
|
+
parameter templates that can be reused to instantiate multiple objects
|
132
|
+
with the same configuration.
|
133
|
+
|
134
|
+
Attributes
|
135
|
+
----------
|
136
|
+
non_hashable_params : sequence of str, optional
|
137
|
+
Names of parameters that are not hashable and should be handled specially.
|
138
|
+
|
139
|
+
Notes
|
140
|
+
-----
|
141
|
+
This mixin can be applied to any Python class, not just brainstate-specific classes.
|
142
|
+
|
143
|
+
Examples
|
144
|
+
--------
|
145
|
+
Basic usage of ParamDesc:
|
146
|
+
|
147
|
+
.. code-block:: python
|
148
|
+
|
149
|
+
>>> import brainstate
|
150
|
+
>>>
|
151
|
+
>>> class NeuronModel(brainstate.mixin.ParamDesc):
|
152
|
+
... def __init__(self, size, tau=10.0, threshold=1.0):
|
153
|
+
... self.size = size
|
154
|
+
... self.tau = tau
|
155
|
+
... self.threshold = threshold
|
156
|
+
>>>
|
157
|
+
>>> # Create a parameter descriptor
|
158
|
+
>>> neuron_desc = NeuronModel.desc(size=100, tau=20.0)
|
159
|
+
>>>
|
160
|
+
>>> # Use the descriptor to create instances
|
161
|
+
>>> neuron1 = neuron_desc(threshold=0.8) # Creates with threshold=0.8
|
162
|
+
>>> neuron2 = neuron_desc(threshold=1.2) # Creates with threshold=1.2
|
163
|
+
>>>
|
164
|
+
>>> # Both neurons share size=100, tau=20.0 but have different thresholds
|
165
|
+
|
166
|
+
Creating reusable templates:
|
167
|
+
|
168
|
+
.. code-block:: python
|
169
|
+
|
170
|
+
>>> # Define a template for excitatory neurons
|
171
|
+
>>> exc_neuron_template = NeuronModel.desc(size=1000, tau=10.0, threshold=1.0)
|
172
|
+
>>>
|
173
|
+
>>> # Define a template for inhibitory neurons
|
174
|
+
>>> inh_neuron_template = NeuronModel.desc(size=250, tau=5.0, threshold=0.5)
|
175
|
+
>>>
|
176
|
+
>>> # Create multiple instances from templates
|
177
|
+
>>> exc_population = [exc_neuron_template() for _ in range(5)]
|
178
|
+
>>> inh_population = [inh_neuron_template() for _ in range(2)]
|
70
179
|
"""
|
71
180
|
|
181
|
+
# Optional list of parameter names that are not hashable
|
182
|
+
# These will be converted to strings for hashing purposes
|
72
183
|
non_hashable_params: Optional[Sequence[str]] = None
|
73
184
|
|
74
185
|
@classmethod
|
75
186
|
def desc(cls, *args, **kwargs) -> 'ParamDescriber':
|
187
|
+
"""
|
188
|
+
Create a parameter describer for this class.
|
189
|
+
|
190
|
+
Parameters
|
191
|
+
----------
|
192
|
+
*args
|
193
|
+
Positional arguments to be used in future instantiations.
|
194
|
+
**kwargs
|
195
|
+
Keyword arguments to be used in future instantiations.
|
196
|
+
|
197
|
+
Returns
|
198
|
+
-------
|
199
|
+
ParamDescriber
|
200
|
+
A descriptor that can be used to create instances with these parameters.
|
201
|
+
"""
|
76
202
|
return ParamDescriber(cls, *args, **kwargs)
|
77
203
|
|
78
204
|
|
79
205
|
class HashableDict(dict):
|
206
|
+
"""
|
207
|
+
A dictionary that can be hashed by converting non-hashable values to strings.
|
208
|
+
|
209
|
+
This is used internally to make parameter dictionaries hashable so they can
|
210
|
+
be used as part of cache keys or other contexts requiring hashability.
|
211
|
+
|
212
|
+
Parameters
|
213
|
+
----------
|
214
|
+
the_dict : dict
|
215
|
+
The dictionary to make hashable.
|
216
|
+
|
217
|
+
Notes
|
218
|
+
-----
|
219
|
+
Non-hashable values in the dictionary are automatically converted to their
|
220
|
+
string representation.
|
221
|
+
|
222
|
+
Examples
|
223
|
+
--------
|
224
|
+
.. code-block:: python
|
225
|
+
|
226
|
+
>>> import brainstate
|
227
|
+
>>> import jax.numpy as jnp
|
228
|
+
>>>
|
229
|
+
>>> # Regular dict with non-hashable values cannot be hashed
|
230
|
+
>>> regular_dict = {"array": jnp.array([1, 2, 3]), "value": 42}
|
231
|
+
>>> # hash(regular_dict) # This would raise TypeError
|
232
|
+
>>>
|
233
|
+
>>> # HashableDict can be hashed
|
234
|
+
>>> hashable = brainstate.mixin.HashableDict(regular_dict)
|
235
|
+
>>> key = hash(hashable) # This works!
|
236
|
+
>>>
|
237
|
+
>>> # Can be used in sets or as dict keys
|
238
|
+
>>> cache = {hashable: "result"}
|
239
|
+
"""
|
240
|
+
|
80
241
|
def __init__(self, the_dict: dict):
|
242
|
+
# Process the dictionary to ensure all values are hashable
|
81
243
|
out = dict()
|
82
244
|
for k, v in the_dict.items():
|
83
245
|
if not hashable(v):
|
84
|
-
|
246
|
+
# Convert non-hashable values to their string representation
|
247
|
+
v = str(v)
|
85
248
|
out[k] = v
|
86
249
|
super().__init__(out)
|
87
250
|
|
88
251
|
def __hash__(self):
|
252
|
+
"""
|
253
|
+
Compute hash from sorted items for consistent hashing regardless of insertion order.
|
254
|
+
"""
|
89
255
|
return hash(tuple(sorted(self.items())))
|
90
256
|
|
91
257
|
|
92
258
|
class NoSubclassMeta(type):
|
259
|
+
"""
|
260
|
+
Metaclass that prevents a class from being subclassed.
|
261
|
+
|
262
|
+
This is used to ensure that certain classes (like ParamDescriber) are used
|
263
|
+
as-is and not extended through inheritance, which could lead to unexpected
|
264
|
+
behavior.
|
265
|
+
|
266
|
+
Raises
|
267
|
+
------
|
268
|
+
TypeError
|
269
|
+
If an attempt is made to subclass a class using this metaclass.
|
270
|
+
"""
|
271
|
+
|
93
272
|
def __new__(cls, name, bases, classdict):
|
273
|
+
# Check if any base class uses NoSubclassMeta
|
94
274
|
for b in bases:
|
95
275
|
if isinstance(b, NoSubclassMeta):
|
96
276
|
raise TypeError("type '{0}' is not an acceptable base type".format(b.__name__))
|
@@ -99,209 +279,758 @@ class NoSubclassMeta(type):
|
|
99
279
|
|
100
280
|
class ParamDescriber(metaclass=NoSubclassMeta):
|
101
281
|
"""
|
102
|
-
|
282
|
+
Parameter descriptor for deferred object instantiation.
|
283
|
+
|
284
|
+
This class stores a class reference along with arguments and keyword arguments,
|
285
|
+
allowing for deferred instantiation. It's useful for creating templates that
|
286
|
+
can be reused to create multiple instances with similar configurations.
|
287
|
+
|
288
|
+
Parameters
|
289
|
+
----------
|
290
|
+
cls : type
|
291
|
+
The class to be instantiated.
|
292
|
+
*desc_tuple
|
293
|
+
Positional arguments to be stored and used during instantiation.
|
294
|
+
**desc_dict
|
295
|
+
Keyword arguments to be stored and used during instantiation.
|
296
|
+
|
297
|
+
Attributes
|
298
|
+
----------
|
299
|
+
cls : type
|
300
|
+
The class that will be instantiated.
|
301
|
+
args : tuple
|
302
|
+
Stored positional arguments.
|
303
|
+
kwargs : dict
|
304
|
+
Stored keyword arguments.
|
305
|
+
identifier : tuple
|
306
|
+
A hashable identifier for this descriptor.
|
307
|
+
|
308
|
+
Notes
|
309
|
+
-----
|
310
|
+
ParamDescriber cannot be subclassed due to the NoSubclassMeta metaclass.
|
311
|
+
This ensures consistent behavior across the codebase.
|
312
|
+
|
313
|
+
Examples
|
314
|
+
--------
|
315
|
+
Manual creation of a descriptor:
|
316
|
+
|
317
|
+
.. code-block:: python
|
318
|
+
|
319
|
+
>>> import brainstate
|
320
|
+
>>>
|
321
|
+
>>> class Network:
|
322
|
+
... def __init__(self, n_neurons, learning_rate=0.01):
|
323
|
+
... self.n_neurons = n_neurons
|
324
|
+
... self.learning_rate = learning_rate
|
325
|
+
>>>
|
326
|
+
>>> # Create a descriptor
|
327
|
+
>>> network_desc = brainstate.mixin.ParamDescriber(
|
328
|
+
... Network, n_neurons=1000, learning_rate=0.001
|
329
|
+
... )
|
330
|
+
>>>
|
331
|
+
>>> # Use the descriptor to create instances with the stored configuration
|
332
|
+
>>> net1 = network_desc()
|
333
|
+
>>> net2 = network_desc() # Same configuration
|
334
|
+
|
335
|
+
Using with ParamDesc mixin:
|
336
|
+
|
337
|
+
.. code-block:: python
|
338
|
+
|
339
|
+
>>> class Network(brainstate.mixin.ParamDesc):
|
340
|
+
... def __init__(self, n_neurons, learning_rate=0.01):
|
341
|
+
... self.n_neurons = n_neurons
|
342
|
+
... self.learning_rate = learning_rate
|
343
|
+
>>>
|
344
|
+
>>> # More concise syntax using the desc() classmethod
|
345
|
+
>>> network_desc = Network.desc(n_neurons=1000)
|
346
|
+
>>> net = network_desc(learning_rate=0.005) # Override learning_rate
|
103
347
|
"""
|
104
348
|
|
105
349
|
def __init__(self, cls: T, *desc_tuple, **desc_dict):
|
350
|
+
# Store the class to be instantiated
|
106
351
|
self.cls: type = cls
|
107
352
|
|
108
|
-
# arguments
|
353
|
+
# Store the arguments for later instantiation
|
109
354
|
self.args = desc_tuple
|
110
355
|
self.kwargs = desc_dict
|
111
356
|
|
112
|
-
# identifier
|
357
|
+
# Create a hashable identifier for caching/comparison purposes
|
358
|
+
# This combines the class, args tuple, and hashable kwargs dict
|
113
359
|
self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))
|
114
360
|
|
115
361
|
def __call__(self, *args, **kwargs) -> T:
|
116
|
-
|
362
|
+
"""
|
363
|
+
Instantiate the class with stored and additional arguments.
|
364
|
+
|
365
|
+
Parameters
|
366
|
+
----------
|
367
|
+
*args
|
368
|
+
Additional positional arguments to append.
|
369
|
+
**kwargs
|
370
|
+
Additional keyword arguments to merge (will override stored kwargs).
|
371
|
+
|
372
|
+
Returns
|
373
|
+
-------
|
374
|
+
T
|
375
|
+
An instance of the described class.
|
376
|
+
"""
|
377
|
+
# Merge stored arguments with new arguments
|
378
|
+
# Stored args come first, then new args
|
379
|
+
# Merge kwargs with new kwargs overriding stored ones
|
380
|
+
merged_kwargs = {**self.kwargs, **kwargs}
|
381
|
+
return self.cls(*self.args, *args, **merged_kwargs)
|
117
382
|
|
118
383
|
def init(self, *args, **kwargs):
|
384
|
+
"""
|
385
|
+
Alias for __call__, explicitly named for clarity.
|
386
|
+
|
387
|
+
Parameters
|
388
|
+
----------
|
389
|
+
*args
|
390
|
+
Additional positional arguments.
|
391
|
+
**kwargs
|
392
|
+
Additional keyword arguments.
|
393
|
+
|
394
|
+
Returns
|
395
|
+
-------
|
396
|
+
T
|
397
|
+
An instance of the described class.
|
398
|
+
"""
|
119
399
|
return self.__call__(*args, **kwargs)
|
120
400
|
|
121
401
|
def __instancecheck__(self, instance):
|
402
|
+
"""
|
403
|
+
Check if an instance is compatible with this descriptor.
|
404
|
+
|
405
|
+
Parameters
|
406
|
+
----------
|
407
|
+
instance : Any
|
408
|
+
The instance to check.
|
409
|
+
|
410
|
+
Returns
|
411
|
+
-------
|
412
|
+
bool
|
413
|
+
True if the instance is a ParamDescriber whose described class is a subclass of this descriptor's class.
|
414
|
+
"""
|
415
|
+
# Must be a ParamDescriber
|
122
416
|
if not isinstance(instance, ParamDescriber):
|
123
417
|
return False
|
418
|
+
# The described class must be a subclass of our class
|
124
419
|
if not issubclass(instance.cls, self.cls):
|
125
420
|
return False
|
126
421
|
return True
|
127
422
|
|
128
423
|
@classmethod
|
129
424
|
def __class_getitem__(cls, item: type):
|
425
|
+
"""
|
426
|
+
Support for subscript notation: ParamDescriber[MyClass].
|
427
|
+
|
428
|
+
Parameters
|
429
|
+
----------
|
430
|
+
item : type
|
431
|
+
The class to create a descriptor for.
|
432
|
+
|
433
|
+
Returns
|
434
|
+
-------
|
435
|
+
ParamDescriber
|
436
|
+
A descriptor for the given class.
|
437
|
+
"""
|
130
438
|
return ParamDescriber(item)
|
131
439
|
|
132
440
|
@property
|
133
441
|
def identifier(self):
|
442
|
+
"""
|
443
|
+
Get the unique identifier for this descriptor.
|
444
|
+
|
445
|
+
Returns
|
446
|
+
-------
|
447
|
+
tuple
|
448
|
+
A hashable identifier consisting of (class, args, kwargs).
|
449
|
+
"""
|
134
450
|
return self._identifier
|
135
451
|
|
136
452
|
@identifier.setter
|
137
453
|
def identifier(self, value: ArrayLike):
|
454
|
+
"""
|
455
|
+
Prevent modification of the identifier.
|
456
|
+
|
457
|
+
Raises
|
458
|
+
------
|
459
|
+
AttributeError
|
460
|
+
Always, as the identifier is read-only.
|
461
|
+
"""
|
138
462
|
raise AttributeError('Cannot set the identifier.')
|
139
463
|
|
140
464
|
|
141
|
-
|
465
|
+
def not_implemented(func):
|
142
466
|
"""
|
143
|
-
|
144
|
-
|
145
|
-
This
|
146
|
-
|
467
|
+
Decorator to mark a function as not implemented.
|
468
|
+
|
469
|
+
This decorator wraps a function to raise NotImplementedError when called,
|
470
|
+
and adds a ``not_implemented`` attribute for checking.
|
471
|
+
|
472
|
+
Parameters
|
473
|
+
----------
|
474
|
+
func : callable
|
475
|
+
The function to mark as not implemented.
|
476
|
+
|
477
|
+
Returns
|
478
|
+
-------
|
479
|
+
callable
|
480
|
+
A wrapper function that raises NotImplementedError.
|
481
|
+
|
482
|
+
Examples
|
483
|
+
--------
|
484
|
+
.. code-block:: python
|
485
|
+
|
486
|
+
>>> import brainstate
|
487
|
+
>>>
|
488
|
+
>>> class BaseModel:
|
489
|
+
... @brainstate.mixin.not_implemented
|
490
|
+
... def process(self, x):
|
491
|
+
... pass
|
492
|
+
>>>
|
493
|
+
>>> model = BaseModel()
|
494
|
+
>>> # model.process(10) # Raises: NotImplementedError: process is not implemented.
|
495
|
+
>>>
|
496
|
+
>>> # Check if a method is not implemented
|
497
|
+
>>> assert hasattr(BaseModel.process, 'not_implemented')
|
147
498
|
"""
|
148
499
|
|
149
|
-
def align_post_input_add(self, *args, **kwargs):
|
150
|
-
raise NotImplementedError
|
151
|
-
|
152
|
-
|
153
|
-
class BindCondData(Mixin):
|
154
|
-
"""Bind temporary conductance data.
|
155
|
-
|
156
|
-
|
157
|
-
"""
|
158
|
-
_conductance: Optional
|
159
|
-
|
160
|
-
def bind_cond(self, conductance):
|
161
|
-
self._conductance = conductance
|
162
|
-
|
163
|
-
def unbind_cond(self):
|
164
|
-
self._conductance = None
|
165
|
-
|
166
|
-
|
167
|
-
def not_implemented(func):
|
168
500
|
def wrapper(*args, **kwargs):
|
169
501
|
raise NotImplementedError(f'{func.__name__} is not implemented.')
|
170
502
|
|
503
|
+
# Mark the wrapper so we can detect not-implemented methods
|
171
504
|
wrapper.not_implemented = True
|
172
505
|
return wrapper
|
173
506
|
|
174
507
|
|
175
|
-
class
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
else:
|
184
|
-
raise TypeError(f'Must be type. But got {bases}')
|
185
|
-
return super().__new__(cls, name, bases, dct)
|
508
|
+
class _JointGenericAlias(_GenericAlias, _root=True):
|
509
|
+
"""
|
510
|
+
Generic alias for JointTypes (intersection types).
|
511
|
+
|
512
|
+
This class represents a type that requires all specified types to be satisfied.
|
513
|
+
Unlike the former _MetaUnionType implementation, which created actual classes and could cause metaclass conflicts,
|
514
|
+
this uses typing's generic alias system to avoid metaclass issues.
|
515
|
+
"""
|
186
516
|
|
187
|
-
def __instancecheck__(self,
|
188
|
-
|
189
|
-
|
517
|
+
def __instancecheck__(self, obj):
|
518
|
+
"""
|
519
|
+
Check if an instance is an instance of all component types.
|
520
|
+
"""
|
521
|
+
return all(isinstance(obj, cls) for cls in self.__args__)
|
190
522
|
|
191
523
|
def __subclasscheck__(self, subclass):
|
192
|
-
|
524
|
+
"""
|
525
|
+
Check if a class is a subclass of all component types.
|
526
|
+
"""
|
527
|
+
return all(issubclass(subclass, cls) for cls in self.__args__)
|
528
|
+
|
529
|
+
def __eq__(self, other):
|
530
|
+
"""
|
531
|
+
Check equality with another type.
|
193
532
|
|
533
|
+
Two JointTypes are equal if they have the same component types,
|
534
|
+
regardless of order.
|
535
|
+
"""
|
536
|
+
if not isinstance(other, _JointGenericAlias):
|
537
|
+
return NotImplemented
|
538
|
+
return set(self.__args__) == set(other.__args__)
|
194
539
|
|
195
|
-
|
196
|
-
|
197
|
-
|
540
|
+
def __hash__(self):
|
541
|
+
"""
|
542
|
+
Return hash of the JointType.
|
198
543
|
|
544
|
+
The hash is based on the frozenset of component types to ensure
|
545
|
+
that JointTypes with the same types (regardless of order) have
|
546
|
+
the same hash.
|
547
|
+
"""
|
548
|
+
return hash(frozenset(self.__args__))
|
199
549
|
|
200
|
-
|
201
|
-
|
202
|
-
|
550
|
+
def __repr__(self):
|
551
|
+
"""
|
552
|
+
Return string representation of the JointType.
|
203
553
|
|
204
|
-
|
554
|
+
Returns a readable representation showing all component types.
|
555
|
+
"""
|
556
|
+
args_str = ', '.join(
|
557
|
+
arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
|
558
|
+
else str(arg)
|
559
|
+
for arg in self.__args__
|
560
|
+
)
|
561
|
+
return f'JointTypes[{args_str}]'
|
562
|
+
|
563
|
+
def __reduce__(self):
|
564
|
+
"""
|
565
|
+
Support for pickling.
|
205
566
|
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
567
|
+
Returns the necessary information to reconstruct the JointType
|
568
|
+
when unpickling.
|
569
|
+
"""
|
570
|
+
return (_JointGenericAlias, (self.__origin__, self.__args__))
|
210
571
|
|
211
|
-
JointTypes[JointTypes[int, str], float] == JointTypes[int, str, float]
|
212
572
|
|
213
|
-
|
573
|
+
class _OneOfGenericAlias(_GenericAlias, _root=True):
|
574
|
+
"""
|
575
|
+
Generic alias for OneOfTypes (union types).
|
214
576
|
|
215
|
-
|
577
|
+
This class represents a type that requires at least one of the specified
|
578
|
+
types to be satisfied. It's similar to typing.Union but provides a consistent
|
579
|
+
interface with JointTypes and avoids potential metaclass conflicts.
|
580
|
+
"""
|
216
581
|
|
217
|
-
|
582
|
+
def __instancecheck__(self, obj):
|
583
|
+
"""
|
584
|
+
Check if an instance is an instance of any component type.
|
585
|
+
"""
|
586
|
+
return any(isinstance(obj, cls) for cls in self.__args__)
|
218
587
|
|
219
|
-
|
588
|
+
def __subclasscheck__(self, subclass):
|
589
|
+
"""
|
590
|
+
Check if a class is a subclass of any component type.
|
591
|
+
"""
|
592
|
+
return any(issubclass(subclass, cls) for cls in self.__args__)
|
220
593
|
|
221
|
-
|
594
|
+
def __eq__(self, other):
|
595
|
+
"""
|
596
|
+
Check equality with another type.
|
222
597
|
|
223
|
-
|
598
|
+
Two OneOfTypes are equal if they have the same component types,
|
599
|
+
regardless of order.
|
600
|
+
"""
|
601
|
+
if not isinstance(other, _OneOfGenericAlias):
|
602
|
+
return NotImplemented
|
603
|
+
return set(self.__args__) == set(other.__args__)
|
224
604
|
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
if parameters == ():
|
229
|
-
raise TypeError("Cannot take a Joint of no types.")
|
230
|
-
if not isinstance(parameters, tuple):
|
231
|
-
parameters = (parameters,)
|
232
|
-
msg = "JointTypes[arg, ...]: each arg must be a type."
|
233
|
-
parameters = tuple(_type_check(p, msg) for p in parameters)
|
234
|
-
parameters = _remove_dups_flatten(parameters)
|
235
|
-
if len(parameters) == 1:
|
236
|
-
return parameters[0]
|
237
|
-
if len(parameters) == 2 and type(None) in parameters:
|
238
|
-
return _UnionGenericAlias(self, parameters, name="Optional")
|
239
|
-
return _JointGenericAlias(self, parameters)
|
605
|
+
def __hash__(self):
|
606
|
+
"""
|
607
|
+
Return hash of the OneOfType.
|
240
608
|
|
609
|
+
The hash is based on the frozenset of component types to ensure
|
610
|
+
that OneOfTypes with the same types (regardless of order) have
|
611
|
+
the same hash.
|
612
|
+
"""
|
613
|
+
return hash(frozenset(self.__args__))
|
241
614
|
|
242
|
-
|
243
|
-
|
244
|
-
|
615
|
+
def __repr__(self):
|
616
|
+
"""
|
617
|
+
Return string representation of the OneOfType.
|
245
618
|
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
619
|
+
Returns a readable representation showing all component types.
|
620
|
+
"""
|
621
|
+
args_str = ', '.join(
|
622
|
+
arg.__module__ + '.' + arg.__name__ if hasattr(arg, '__module__') and hasattr(arg, '__name__')
|
623
|
+
else str(arg)
|
624
|
+
for arg in self.__args__
|
625
|
+
)
|
626
|
+
return f'OneOfTypes[{args_str}]'
|
627
|
+
|
628
|
+
def __reduce__(self):
|
629
|
+
"""
|
630
|
+
Support for pickling.
|
251
631
|
|
252
|
-
|
632
|
+
Returns the necessary information to reconstruct the OneOfType
|
633
|
+
when unpickling.
|
634
|
+
"""
|
635
|
+
return (_OneOfGenericAlias, (self.__origin__, self.__args__))
|
253
636
|
|
254
|
-
- Unions of a single argument vanish, e.g.::
|
255
637
|
|
256
|
-
|
638
|
+
class _JointTypesClass:
    """
    Factory behind the module-level ``JointTypes`` object.

    Instances are callable (``JointTypes(A, B)``) and subscriptable
    (``JointTypes[A, B]``); both spellings produce the same result.
    """

    def __call__(self, *types):
        """
        Build a type hint requiring a value to satisfy *all* of ``types``.

        This expresses an intersection-style requirement: one object that
        must implement several interfaces at the same time.

        Parameters
        ----------
        *types : type
            The types that must all be satisfied.

        Returns
        -------
        type
            The single type itself when only one distinct type is given.
            Otherwise, a ``_JointGenericAlias`` over the distinct types.

        Raises
        ------
        TypeError
            If no types are given.

        Notes
        -----
        - Repeated types are dropped; the first occurrence decides ordering.
        - With a single distinct type, that type is returned unchanged.
        - Equality of the returned alias is decided by ``_JointGenericAlias``.

        Examples
        --------
        .. code-block:: python

            >>> import brainstate
            >>> from typing import Protocol
            >>>
            >>> class Trainable(Protocol):
            ...     def train(self): ...
            >>>
            >>> class Evaluable(Protocol):
            ...     def evaluate(self): ...
            >>>
            >>> # Call syntax and subscript syntax are equivalent
            >>> Model = brainstate.mixin.JointTypes(Trainable, Evaluable)
            >>> Model = brainstate.mixin.JointTypes[Trainable, Evaluable]
            >>>
            >>> class NeuralNetwork(Trainable, Evaluable):
            ...     def train(self):
            ...         return "Training..."
            ...
            ...     def evaluate(self):
            ...         return "Evaluating..."
            >>>
            >>> model = NeuralNetwork()  # satisfies JointTypes(Trainable, Evaluable)

        Requiring several plain mixin classes:

        .. code-block:: python

            >>> class Serializable:
            ...     def save(self): pass
            >>>
            >>> class Visualizable:
            ...     def plot(self): pass
            >>>
            >>> FullFeaturedModel = brainstate.mixin.JointTypes[Serializable, Visualizable]
        """
        if not types:
            raise TypeError("Cannot create a JointTypes of no types.")

        # dict.fromkeys drops repeats while keeping first-occurrence order.
        distinct = tuple(dict.fromkeys(types))

        # A single surviving type needs no wrapper.
        if len(distinct) == 1:
            return distinct[0]

        # Use the typing-style generic alias instead of synthesizing a new
        # class, which avoids metaclass conflicts between the member types.
        return _JointGenericAlias(object, distinct)

    def __getitem__(self, item):
        """Subscript form: ``JointTypes[A, B]`` is the same as ``JointTypes(A, B)``."""
        members = item if isinstance(item, tuple) else (item,)
        return self(*members)


# Module-level singleton usable as either ``JointTypes(A, B)`` or ``JointTypes[A, B]``.
JointTypes = _JointTypesClass()
|
263
744
|
|
264
|
-
assert OneOfTypes[int, str] == OneOfTypes[str, int]
|
265
745
|
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
746
|
+
class _OneOfTypesClass:
    """
    Factory behind the module-level ``OneOfTypes`` object.

    Instances are callable (``OneOfTypes(A, B)``) and subscriptable
    (``OneOfTypes[A, B]``); both spellings produce the same result.
    """

    def __call__(self, *types):
        """
        Build a type hint accepting a value of *any one* of ``types``.

        This is the union counterpart of ``JointTypes``. It is built on
        ``typing.Union`` but has a more descriptive name.

        Parameters
        ----------
        *types : type
            The candidate types; a value must satisfy at least one of them.

        Returns
        -------
        type
            The single type itself when only one distinct type is given.
            Otherwise, a ``_OneOfGenericAlias`` over ``Union`` and the
            distinct types.

        Raises
        ------
        TypeError
            If no types are given.

        Notes
        -----
        - Repeated types are dropped; the first occurrence decides ordering.
        - With a single distinct type, that type is returned unchanged.
        - Equality of the returned alias is decided by ``_OneOfGenericAlias``.

        Examples
        --------
        .. code-block:: python

            >>> import brainstate
            >>>
            >>> # int or float; call and subscript syntax are equivalent
            >>> NumericType = brainstate.mixin.OneOfTypes(int, float)
            >>> NumericType = brainstate.mixin.OneOfTypes[int, float]
            >>>
            >>> def process_value(x: NumericType):
            ...     return x * 2
            >>>
            >>> result1 = process_value(5)     # int
            >>> result2 = process_value(3.14)  # float

        Accepting one of several classes:

        .. code-block:: python

            >>> class NumpyArray:
            ...     pass
            >>>
            >>> class JAXArray:
            ...     pass
            >>>
            >>> ArrayType = brainstate.mixin.OneOfTypes[NumpyArray, JAXArray]

        An optional value (analogous to ``Optional[str]``):

        .. code-block:: python

            >>> MaybeString = brainstate.mixin.OneOfTypes[str, type(None)]
            >>>
            >>> def format_name(name: MaybeString) -> str:
            ...     if name is None:
            ...         return "Anonymous"
            ...     return name.title()
        """
        if not types:
            raise TypeError("Cannot create a OneOfTypes of no types.")

        # dict.fromkeys drops repeats while keeping first-occurrence order.
        distinct = tuple(dict.fromkeys(types))

        # A single surviving type needs no wrapper.
        if len(distinct) == 1:
            return distinct[0]

        # Mirror JointTypes: wrap in a generic alias rather than a new class.
        return _OneOfGenericAlias(Union, distinct)

    def __getitem__(self, item):
        """Subscript form: ``OneOfTypes[A, B]`` is the same as ``OneOfTypes(A, B)``."""
        members = item if isinstance(item, tuple) else (item,)
        return self(*members)


# Module-level singleton usable as either ``OneOfTypes(A, B)`` or ``OneOfTypes[A, B]``.
OneOfTypes = _OneOfTypesClass()
|
854
|
+
|
855
|
+
|
856
|
+
def __getattr__(name):
    """
    Module-level attribute hook (PEP 562) for names moved to ``brainpy``.

    Python calls this only after normal lookup on the module fails. For a
    deprecated name, a ``DeprecationWarning`` is emitted and the object is
    fetched from ``brainpy.mixin``. Any other name raises ``AttributeError``.

    Note that ``Mode``, ``JointMode``, ``Batching`` and ``Training`` are still
    defined in this module. Regular lookup finds them first, so this hook
    never runs (and never warns) for those four names.

    Parameters
    ----------
    name : str
        The attribute being looked up.

    Returns
    -------
    Any
        The object of the same name from ``brainpy.mixin``.

    Raises
    ------
    AttributeError
        If ``name`` is unknown, or if it is a deprecated name but ``brainpy``
        cannot be imported. Raising ``AttributeError`` here instead of letting
        ``ImportError`` escape keeps ``hasattr`` and ``getattr(..., default)``
        working as Python expects.
    """
    if name in (
        'Mode',
        'JointMode',
        'Batching',
        'Training',
        'AlignPost',
        'BindCondData',
    ):
        import warnings
        warnings.warn(
            f"brainstate.mixin.{name} is deprecated and will be removed in a future version. "
            f"Please use brainpy.mixin.{name} instead.",
            DeprecationWarning,
            stacklevel=2
        )
        try:
            import brainpy
        except ImportError as err:
            raise AttributeError(
                f'module {__name__!r} has no attribute {name!r} '
                f'(it moved to brainpy.mixin.{name}, but brainpy could not be imported)'
            ) from err
        return getattr(brainpy.mixin, name)
    raise AttributeError(
        f'module {__name__!r} has no attribute {name!r}'
    )
|
281
877
|
|
282
878
|
|
283
879
|
class Mode(Mixin):
    """
    Base class for computation behavior modes.

    A mode describes the context a computation runs in, such as training
    vs. evaluation or batched vs. single-sample processing. Code inspects the
    active mode with :meth:`is_a` (exact class) or :meth:`has` (class or
    subclass) to decide how to behave.

    Equality and hashing depend only on the mode's class. Modes can therefore
    be compared, stored in sets and used as dictionary keys.

    Examples
    --------
    Creating a custom mode:

    .. code-block:: python

        >>> import brainstate
        >>>
        >>> class InferenceMode(brainstate.mixin.Mode):
        ...     def __init__(self, use_cache=True):
        ...         self.use_cache = use_cache
        >>>
        >>> inference = InferenceMode(use_cache=True)
        >>> print(inference)  # Output: InferenceMode

    Checking mode types:

    .. code-block:: python

        >>> class FastMode(brainstate.mixin.Mode):
        ...     pass
        >>>
        >>> class SlowMode(brainstate.mixin.Mode):
        ...     pass
        >>>
        >>> fast = FastMode()
        >>> assert fast.is_a(FastMode)             # exact type match
        >>> assert not fast.is_a(SlowMode)
        >>> assert fast.has(brainstate.mixin.Mode)  # isinstance-style match
        >>> assert fast == FastMode() and fast != SlowMode()

    Using modes in a model:

    .. code-block:: python

        >>> class Model:
        ...     def __init__(self):
        ...         self.mode = brainstate.mixin.Training()
        ...
        ...     def forward(self, x):
        ...         if self.mode.has(brainstate.mixin.Training):
        ...             return x + 0.1  # training-specific logic
        ...         return x            # inference logic
    """

    def __repr__(self):
        """
        String representation of the mode.

        Returns
        -------
        str
            The class name of the mode.
        """
        return self.__class__.__name__

    def __eq__(self, other: object):
        """
        Check equality of modes based on their class.

        Only the class is compared, so for example two ``Batching`` modes with
        different batch sizes are equal.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        bool or NotImplemented
            True if both are modes of the same class, False for modes of
            different classes. ``NotImplemented`` for non-mode objects, so
            Python falls back to its default comparison (e.g. ``mode == None``
            is simply False).
        """
        if not isinstance(other, Mode):
            # Previously an assert, which made ``mode == None`` or
            # ``mode in [None, ...]`` raise AssertionError.
            return NotImplemented
        return other.__class__ == self.__class__

    def __hash__(self):
        """
        Hash consistent with :meth:`__eq__`: modes of the same class hash equally.

        Defining ``__eq__`` without ``__hash__`` would make every mode
        unhashable, because Python sets ``__hash__ = None`` in that case.
        """
        return hash(self.__class__)

    def is_a(self, mode: type):
        """
        Check whether the mode is exactly the desired mode type.

        This is an exact class match; subclasses do not count.

        Parameters
        ----------
        mode : type
            The mode type to check against.

        Returns
        -------
        bool
            True if this mode is exactly of the specified type.

        Examples
        --------
        .. code-block:: python

            >>> import brainstate
            >>>
            >>> training_mode = brainstate.mixin.Training()
            >>> assert training_mode.is_a(brainstate.mixin.Training)
            >>> assert not training_mode.is_a(brainstate.mixin.Batching)
        """
        assert isinstance(mode, type), 'Must be a type.'
        return self.__class__ == mode

    def has(self, mode: type):
        """
        Check whether the mode includes the desired mode type.

        This is an ``isinstance`` check, so subclasses count.

        Parameters
        ----------
        mode : type
            The mode type to check for.

        Returns
        -------
        bool
            True if this mode is an instance of the specified type.

        Examples
        --------
        .. code-block:: python

            >>> import brainstate
            >>>
            >>> class AdvancedTraining(brainstate.mixin.Training):
            ...     pass
            >>>
            >>> advanced = AdvancedTraining()
            >>> assert advanced.has(brainstate.mixin.Training)  # subclass
            >>> assert advanced.has(brainstate.mixin.Mode)      # base class
        """
        assert isinstance(mode, type), 'Must be a type.'
        return isinstance(self, mode)
|
@@ -309,57 +1038,396 @@ class Mode(Mixin):
|
|
309
1038
|
|
310
1039
|
class JointMode(Mode):
    """
    A mode made of several component modes that are active together.

    Use it when a computation is, for example, both training *and* batching.
    :meth:`has` succeeds if any component matches. Unknown attributes are
    looked up on the components in order (see :meth:`__getattr__`).

    NOTE(review): ``JointMode`` does not override ``__eq__``, so it inherits
    the class-only equality of :class:`Mode`. Any two ``JointMode``
    instances compare equal regardless of their components.

    Parameters
    ----------
    *modes : Mode
        The component modes.

    Attributes
    ----------
    modes : tuple of Mode
        The component modes, in the order given.
    types : set of type
        The classes of the component modes.

    Raises
    ------
    TypeError
        If any argument is not a :class:`Mode` instance.

    Examples
    --------
    Combining training and batching:

    .. code-block:: python

        >>> import brainstate
        >>>
        >>> training = brainstate.mixin.Training()
        >>> batching = brainstate.mixin.Batching(batch_size=32)
        >>> joint = brainstate.mixin.JointMode(training, batching)
        >>> print(joint)  # JointMode(Training, Batching(in_size=32, axis=0))
        >>>
        >>> assert joint.has(brainstate.mixin.Training)
        >>> assert joint.has(brainstate.mixin.Batching)
        >>>
        >>> # Component attributes are reachable directly
        >>> print(joint.batch_size)  # 32

    Switching a model into a combined configuration:

    .. code-block:: python

        >>> class NeuralNetwork:
        ...     def __init__(self):
        ...         self.mode = None
        ...
        ...     def set_train_mode(self, batch_size=1):
        ...         self.mode = brainstate.mixin.JointMode(
        ...             brainstate.mixin.Training(),
        ...             brainstate.mixin.Batching(batch_size=batch_size),
        ...         )
        >>>
        >>> model = NeuralNetwork()
        >>> model.set_train_mode(batch_size=64)
        >>> assert model.mode.batch_size == 64
    """

    def __init__(self, *modes: Mode):
        # Reject anything that is not a Mode before storing state.
        for candidate in modes:
            if not isinstance(candidate, Mode):
                raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {candidate}')

        self.modes = tuple(modes)
        # Component classes, kept as a set for the subclass checks in ``has``.
        self.types = {component.__class__ for component in modes}

    def __repr__(self):
        """
        Show the joint mode together with the repr of each component.

        Returns
        -------
        str
            For example ``JointMode(Training, Batching(in_size=32, axis=0))``.
        """
        inner = ", ".join(repr(component) for component in self.modes)
        return f'{self.__class__.__name__}({inner})'

    def has(self, mode: type):
        """
        Check whether any component mode is (a subclass of) ``mode``.

        Parameters
        ----------
        mode : type
            The mode type to look for.

        Returns
        -------
        bool
            True if at least one component class is ``mode`` or derives from it.

        Examples
        --------
        .. code-block:: python

            >>> import brainstate
            >>>
            >>> joint = brainstate.mixin.JointMode(
            ...     brainstate.mixin.Training(),
            ...     brainstate.mixin.Batching(batch_size=16),
            ... )
            >>> assert joint.has(brainstate.mixin.Training)
            >>> assert joint.has(brainstate.mixin.Batching)
            >>> assert joint.has(brainstate.mixin.Mode)  # common base class
        """
        assert isinstance(mode, type), 'Must be a type.'
        # Evaluate every component (not short-circuited), then combine.
        matches = [issubclass(component_cls, mode) for component_cls in self.types]
        return any(matches)

    def is_a(self, cls: type):
        """
        Check whether the joint mode matches an exact combination of types.

        ``JointTypes`` is built from the component classes and compared with
        ``cls``. How that comparison behaves is defined by ``JointTypes``'
        return value.

        Parameters
        ----------
        cls : type
            The combined type to compare against.

        Returns
        -------
        bool
            True if the combination built from the component classes equals ``cls``.
        """
        expected = JointTypes(*self.types)
        return expected == cls

    def __getattr__(self, item):
        """
        Resolve attributes not found on the joint mode from its components.

        The components are searched in order, and the first one that has
        ``item`` provides the value.

        Parameters
        ----------
        item : str
            The attribute name.

        Returns
        -------
        Any
            The attribute value from the first component that defines it.

        Raises
        ------
        AttributeError
            If no component defines ``item``.

        Examples
        --------
        .. code-block:: python

            >>> import brainstate
            >>>
            >>> joint = brainstate.mixin.JointMode(
            ...     brainstate.mixin.Batching(batch_size=32, batch_axis=1),
            ...     brainstate.mixin.Training(),
            ... )
            >>> print(joint.batch_size)  # 32
            >>> print(joint.batch_axis)  # 1
        """
        # The storage attributes themselves must never be searched for in the
        # components; this avoids recursion before ``__init__`` has set them.
        if item in ('modes', 'types'):
            return super().__getattribute__(item)

        for component in self.modes:
            if hasattr(component, item):
                return getattr(component, item)

        # Nothing matched: defer to the default lookup, which raises AttributeError.
        return super().__getattribute__(item)
|
350
1239
|
|
351
1240
|
|
352
1241
|
class Batching(Mode):
    """
    Mode indicating batched computation.

    Records the size of each batch and which array axis holds the batch
    dimension.

    Parameters
    ----------
    batch_size : int, default 1
        The size of each batch.
    batch_axis : int, default 0
        The axis along which batching occurs.

    Attributes
    ----------
    batch_size : int
        The number of samples in each batch.
    batch_axis : int
        The axis index representing the batch dimension.

    Notes
    -----
    The repr labels ``batch_size`` as ``in_size`` and ``batch_axis`` as
    ``axis``, e.g. ``Batching(in_size=32, axis=0)``.

    Examples
    --------
    Basic batching configuration:

    .. code-block:: python

        >>> import brainstate
        >>>
        >>> batching = brainstate.mixin.Batching(batch_size=32, batch_axis=0)
        >>> print(batching)  # Batching(in_size=32, axis=0)
        >>>
        >>> print(f"Processing {batching.batch_size} samples at once")
        >>> print(f"Batch dimension is axis {batching.batch_axis}")

    Using in a model:

    .. code-block:: python

        >>> import jax.numpy as jnp
        >>>
        >>> class BatchedModel:
        ...     def __init__(self):
        ...         self.mode = None
        ...
        ...     def set_batch_mode(self, batch_size, batch_axis=0):
        ...         self.mode = brainstate.mixin.Batching(batch_size, batch_axis)
        ...
        ...     def process(self, x):
        ...         if self.mode is not None and self.mode.has(brainstate.mixin.Batching):
        ...             # Reduce over the configured batch axis
        ...             axis = self.mode.batch_axis
        ...             return jnp.mean(x, axis=axis, keepdims=True)
        ...         return x
        >>>
        >>> model = BatchedModel()
        >>> model.set_batch_mode(batch_size=64)
        >>>
        >>> data = jnp.ones((64, 100))  # 64 samples, 100 features
        >>> result = model.process(data)

    Combining with other modes:

    .. code-block:: python

        >>> training = brainstate.mixin.Training()
        >>> batching = brainstate.mixin.Batching(batch_size=128)
        >>> combined = brainstate.mixin.JointMode(training, batching)
        >>>
        >>> def train_step(model, data, mode):
        ...     if mode.has(brainstate.mixin.Batching):
        ...         batch_size = mode.batch_size
        ...         # ... batched processing ...
        ...     if mode.has(brainstate.mixin.Training):
        ...         # ... training logic ...
        ...         pass
    """

    def __init__(self, batch_size: int = 1, batch_axis: int = 0):
        self.batch_size = batch_size
        self.batch_axis = batch_axis

    def __repr__(self):
        """
        String representation showing the batch configuration.

        Returns
        -------
        str
            For example ``Batching(in_size=32, axis=0)``, where ``in_size`` is
            ``batch_size`` and ``axis`` is ``batch_axis``.
        """
        return f'{self.__class__.__name__}(in_size={self.batch_size}, axis={self.batch_axis})'
|
361
1342
|
|
362
1343
|
|
363
1344
|
class Training(Mode):
    """
    Mode indicating training computation.

    ``Training`` holds no state of its own. It acts as a flag that code can
    test with ``mode.has(Training)`` to switch on training-only behavior,
    such as the dropout shown below.

    Examples
    --------
    Basic use:

    .. code-block:: python

        >>> import brainstate
        >>>
        >>> training = brainstate.mixin.Training()
        >>> print(training)  # Training
        >>> assert training.is_a(brainstate.mixin.Training)
        >>> assert training.has(brainstate.mixin.Mode)

    Gating dropout on the mode:

    .. code-block:: python

        >>> import jax
        >>> import jax.numpy as jnp
        >>>
        >>> class ModelWithDropout:
        ...     def __init__(self, dropout_rate=0.5):
        ...         self.dropout_rate = dropout_rate
        ...         self.mode = None
        ...
        ...     def set_training(self, is_training=True):
        ...         if is_training:
        ...             self.mode = brainstate.mixin.Training()
        ...         else:
        ...             self.mode = brainstate.mixin.Mode()  # evaluation
        ...
        ...     def forward(self, x, rng_key):
        ...         if self.mode is not None and self.mode.has(brainstate.mixin.Training):
        ...             keep_prob = 1.0 - self.dropout_rate
        ...             mask = jax.random.bernoulli(rng_key, keep_prob, x.shape)
        ...             x = jnp.where(mask, x / keep_prob, 0)
        ...         return x
        >>>
        >>> model = ModelWithDropout()
        >>> key = jax.random.PRNGKey(0)
        >>> x_train = jnp.ones((10, 20))
        >>>
        >>> model.set_training(True)
        >>> out_train = model.forward(x_train, key)  # dropout applied
        >>>
        >>> model.set_training(False)
        >>> out_eval = model.forward(x_train, key)   # no dropout

    Combining with batching:

    .. code-block:: python

        >>> mode = brainstate.mixin.JointMode(
        ...     brainstate.mixin.Training(),
        ...     brainstate.mixin.Batching(batch_size=32),
        ... )
        >>>
        >>> class Trainer:
        ...     def __init__(self, model, mode):
        ...         self.model = model
        ...         self.mode = mode
        ...
        ...     def train_epoch(self, data):
        ...         if self.mode.has(brainstate.mixin.Training):
        ...             self.model.set_training(True)
        ...         if self.mode.has(brainstate.mixin.Batching):
        ...             batch_size = self.mode.batch_size
        ...             # ... batched training loop ...
    """
|