brainstate 0.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +45 -0
- brainstate/_module.py +1466 -0
- brainstate/_module_test.py +133 -0
- brainstate/_state.py +378 -0
- brainstate/_state_test.py +41 -0
- brainstate/_utils.py +21 -0
- brainstate/environ.py +375 -0
- brainstate/functional/__init__.py +25 -0
- brainstate/functional/_activations.py +754 -0
- brainstate/functional/_normalization.py +69 -0
- brainstate/functional/_spikes.py +90 -0
- brainstate/init/__init__.py +26 -0
- brainstate/init/_base.py +36 -0
- brainstate/init/_generic.py +175 -0
- brainstate/init/_random_inits.py +489 -0
- brainstate/init/_regular_inits.py +109 -0
- brainstate/math/__init__.py +21 -0
- brainstate/math/_einops.py +787 -0
- brainstate/math/_einops_parsing.py +169 -0
- brainstate/math/_einops_parsing_test.py +126 -0
- brainstate/math/_einops_test.py +346 -0
- brainstate/math/_misc.py +298 -0
- brainstate/math/_misc_test.py +58 -0
- brainstate/mixin.py +373 -0
- brainstate/mixin_test.py +73 -0
- brainstate/nn/__init__.py +68 -0
- brainstate/nn/_base.py +248 -0
- brainstate/nn/_connections.py +686 -0
- brainstate/nn/_dynamics.py +406 -0
- brainstate/nn/_elementwise.py +1437 -0
- brainstate/nn/_misc.py +132 -0
- brainstate/nn/_normalizations.py +389 -0
- brainstate/nn/_others.py +100 -0
- brainstate/nn/_poolings.py +1228 -0
- brainstate/nn/_poolings_test.py +231 -0
- brainstate/nn/_projection/__init__.py +32 -0
- brainstate/nn/_projection/_align_post.py +528 -0
- brainstate/nn/_projection/_align_pre.py +599 -0
- brainstate/nn/_projection/_delta.py +241 -0
- brainstate/nn/_projection/_utils.py +17 -0
- brainstate/nn/_projection/_vanilla.py +101 -0
- brainstate/nn/_rate_rnns.py +393 -0
- brainstate/nn/_readout.py +130 -0
- brainstate/nn/_synouts.py +166 -0
- brainstate/nn/functional/__init__.py +25 -0
- brainstate/nn/functional/_activations.py +754 -0
- brainstate/nn/functional/_normalization.py +69 -0
- brainstate/nn/functional/_spikes.py +90 -0
- brainstate/nn/init/__init__.py +26 -0
- brainstate/nn/init/_base.py +36 -0
- brainstate/nn/init/_generic.py +175 -0
- brainstate/nn/init/_random_inits.py +489 -0
- brainstate/nn/init/_regular_inits.py +109 -0
- brainstate/nn/surrogate.py +1740 -0
- brainstate/optim/__init__.py +23 -0
- brainstate/optim/_lr_scheduler.py +486 -0
- brainstate/optim/_lr_scheduler_test.py +36 -0
- brainstate/optim/_sgd_optimizer.py +1148 -0
- brainstate/random.py +5148 -0
- brainstate/random_test.py +576 -0
- brainstate/surrogate.py +1740 -0
- brainstate/transform/__init__.py +36 -0
- brainstate/transform/_autograd.py +585 -0
- brainstate/transform/_autograd_test.py +1183 -0
- brainstate/transform/_control.py +665 -0
- brainstate/transform/_controls_test.py +220 -0
- brainstate/transform/_jit.py +239 -0
- brainstate/transform/_jit_error.py +158 -0
- brainstate/transform/_jit_test.py +102 -0
- brainstate/transform/_make_jaxpr.py +573 -0
- brainstate/transform/_make_jaxpr_test.py +133 -0
- brainstate/transform/_progress_bar.py +113 -0
- brainstate/typing.py +69 -0
- brainstate/util.py +747 -0
- brainstate-0.0.1.dist-info/LICENSE +202 -0
- brainstate-0.0.1.dist-info/METADATA +101 -0
- brainstate-0.0.1.dist-info/RECORD +79 -0
- brainstate-0.0.1.dist-info/WHEEL +6 -0
- brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/mixin.py
ADDED
@@ -0,0 +1,373 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from typing import Sequence, Optional, TypeVar, Any
|
19
|
+
from typing import (_SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
|
20
|
+
|
21
|
+
# Alias used to annotate pytree-valued returns; deliberately untyped (``Any``).
PyTree = Any

# Generic type variable for "whatever ``cls`` builds" in ``DelayedInitializer``.
T = TypeVar('T')

# Lazily-resolved placeholder for ``brainstate._state.State``; ``_get_state``
# fills it in on first use (presumably to avoid an import cycle).
State = None
|
25
|
+
|
26
|
+
__all__ = [
    # Mixins and deferred-construction helpers.
    'Mixin', 'DelayedInit', 'DelayedInitializer', 'AlignPost', 'BindCondData', 'UpdateReturn',
    # Type combinators.
    'AllOfTypes', 'OneOfTypes',
    # Computation-behavior modes.
    'Mode', 'JointMode', 'Batching', 'Training',
]
|
44
|
+
|
45
|
+
|
46
|
+
def _get_state():
    """Return the ``State`` class, importing it on the first call.

    The module-level name ``State`` starts out as ``None``. The first call
    imports ``brainstate._state.State`` straight into the module globals, so
    later calls simply return the cached class.
    """
    global State
    if State is not None:
        return State
    # Deferred import; binds the module-global ``State`` because of the
    # ``global`` declaration above.
    from brainstate._state import State
    return State
|
51
|
+
|
52
|
+
|
53
|
+
class Mixin:
    """Root class shared by every mixin in this module.

    A :py:class:`~.Mixin` contributes behavior only. It defines methods but
    no initializer of its own, so it can be combined freely with other bases.
    """
|
59
|
+
|
60
|
+
|
61
|
+
class DelayedInit(Mixin):
    """
    :py:class:`~.Mixin` that lets a class describe its construction arguments up front.

    Subclasses gain the classmethod :py:meth:`delayed`. Instead of building the
    object immediately, it returns a :py:class:`~.DelayedInitializer` that
    records the arguments and can build the object later.

    This mixin can be combined with any Python class.
    """

    # Optional sequence of parameter names. It is not referenced elsewhere in
    # this module; presumably consumed by callers (TODO confirm).
    non_hash_params: Optional[Sequence[str]] = None

    @classmethod
    def delayed(cls, *args, **kwargs) -> 'DelayedInitializer':
        """Record ``args``/``kwargs`` for a later ``cls(...)`` call."""
        initializer = DelayedInitializer(cls, *args, **kwargs)
        return initializer
|
76
|
+
|
77
|
+
|
78
|
+
class HashableDict(dict):
    """A ``dict`` subclass that can be hashed.

    It is used to embed keyword arguments in hashable identifiers. The hash is
    derived from the current items and does not depend on insertion order, so
    it agrees with ``dict`` equality. All values must be hashable.

    The dict itself stays mutable. Mutating it after it has been hashed (e.g.
    after using it as a key) breaks hash-based lookups, so treat instances as
    read-only.
    """

    def __hash__(self):
        # ``frozenset`` yields an order-independent hash without sorting.
        # ``sorted(self.items())`` raised TypeError for keys of mutually
        # non-comparable types (e.g. ``1`` and ``'a'``).
        return hash(frozenset(self.items()))
|
81
|
+
|
82
|
+
|
83
|
+
class NoSubclassMeta(type):
    """Metaclass that makes its classes final (non-subclassable).

    A class is created only if none of its bases was itself built by this
    metaclass. Otherwise ``TypeError`` names the first offending base.
    """

    def __new__(mcs, name, bases, classdict):
        offending = [base for base in bases if isinstance(base, NoSubclassMeta)]
        if offending:
            raise TypeError("type '{0}' is not an acceptable base type".format(offending[0].__name__))
        return type.__new__(mcs, name, bases, dict(classdict))
|
89
|
+
|
90
|
+
|
91
|
+
class DelayedInitializer(metaclass=NoSubclassMeta):
    """
    DelayedInit initialization for parameter describers.

    Stores a target ``cls`` together with positional and keyword arguments and
    builds ``cls(...)`` only when the initializer is called.

    ``DelayedInitializer[C]`` returns an initializer for ``C`` without stored
    arguments. It can be used as the second argument of ``isinstance`` to test
    whether another initializer targets ``C`` or a subclass of ``C``.

    The class cannot be subclassed because of its ``NoSubclassMeta`` metaclass.
    """

    def __init__(self, cls: T, *desc_tuple, **desc_dict):
        self.cls = cls

        # arguments captured for the deferred call
        self.args = desc_tuple
        self.kwargs = desc_dict

        # identifier: (target, positional args, keyword args); hashable only if
        # the stored arguments are hashable
        self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))

    def __call__(self, *args, **kwargs) -> T:
        """Build the object.

        Stored positional arguments come first, followed by ``args``. Stored
        and new keyword arguments are merged; a keyword given in both places
        raises ``TypeError`` (duplicate keyword argument).
        """
        return self.cls(*self.args, *args, **self.kwargs, **kwargs)

    def init(self, *args, **kwargs):
        """Same as calling the initializer."""
        return self.__call__(*args, **kwargs)

    def __instancecheck__(self, instance):
        """Support ``isinstance(obj, DelayedInitializer[C])``.

        Returns True when ``obj`` is a ``DelayedInitializer`` whose target is a
        class equal to or derived from ``C``.
        """
        if not isinstance(instance, DelayedInitializer):
            return False
        # ``issubclass`` raises TypeError when the stored target is not a class
        # (e.g. a factory function). Such an initializer simply does not match.
        if not isinstance(instance.cls, type):
            return False
        return issubclass(instance.cls, self.cls)

    @classmethod
    def __class_getitem__(cls, item: type):
        return DelayedInitializer(item)

    @property
    def identifier(self):
        """Read-only ``(cls, args, kwargs)`` tuple describing this initializer."""
        return self._identifier

    @identifier.setter
    def identifier(self, value):
        raise AttributeError('Cannot set the identifier.')
|
130
|
+
|
131
|
+
|
132
|
+
class AlignPost(Mixin):
    """
    Align-post mixin.

    Subclasses provide :py:meth:`align_post_input_add` for adding external
    currents.
    """

    def align_post_input_add(self, *args, **kwargs):
        """Add an external input. Must be overridden; the default raises ``NotImplementedError``."""
        raise NotImplementedError
|
142
|
+
|
143
|
+
|
144
|
+
class BindCondData(Mixin):
    """Mixin for temporarily binding conductance data to an object.

    :py:meth:`bind_cond` stores a value on ``self._conductance`` and
    :py:meth:`unbind_cond` resets it to ``None``. ``_conductance`` is only
    annotated here and has no class-level default, so reading it before the
    first ``bind_cond``/``unbind_cond`` raises ``AttributeError`` (unless a
    subclass sets it).
    """

    # Annotation only; the attribute is created by the methods below.
    _conductance: Optional

    def bind_cond(self, conductance):
        """Attach ``conductance`` to this object."""
        self._conductance = conductance

    def unbind_cond(self):
        """Detach the bound conductance by setting it to ``None``."""
        self._conductance = None
|
156
|
+
|
157
|
+
|
158
|
+
class UpdateReturn(Mixin):
    """Mixin declaring what a model's update function returns.

    Both methods must be overridden; the defaults raise ``NotImplementedError``.
    """

    def update_return(self) -> PyTree:
        """
        Describe the value returned by the model's update function.

        Returns:
          A pytree whose elements are ``jax.ShapeDtypeStruct`` or
          ``jax.core.ShapedArray``.
        """
        method_name = self.update_return.__name__
        raise NotImplementedError(f'Must implement the "{method_name}()" function.')

    def update_return_info(self) -> PyTree:
        """
        Describe the update-return information of the model.

        Returns:
          A pytree whose elements are ``jax.Array``.

        .. note::
           Do not include the batch axis or the batch size; this information
           is inferred from the ``mode`` attribute.
        """
        method_name = self.update_return_info.__name__
        raise NotImplementedError(f'Must implement the "{method_name}()" function.')
|
181
|
+
|
182
|
+
|
183
|
+
class _MetaUnionType(type):
|
184
|
+
def __new__(cls, name, bases, dct):
|
185
|
+
if isinstance(bases, type):
|
186
|
+
bases = (bases,)
|
187
|
+
elif isinstance(bases, (list, tuple)):
|
188
|
+
bases = tuple(bases)
|
189
|
+
for base in bases:
|
190
|
+
assert isinstance(base, type), f'Must be type. But got {base}'
|
191
|
+
else:
|
192
|
+
raise TypeError(f'Must be type. But got {bases}')
|
193
|
+
return super().__new__(cls, name, bases, dct)
|
194
|
+
|
195
|
+
def __instancecheck__(self, other):
|
196
|
+
cls_of_other = other.__class__
|
197
|
+
return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
|
198
|
+
|
199
|
+
def __subclasscheck__(self, subclass):
|
200
|
+
return all([issubclass(subclass, cls) for cls in self.__bases__])
|
201
|
+
|
202
|
+
|
203
|
+
class _JointGenericAlias(_UnionGenericAlias, _root=True):
    """Typing alias produced by :data:`AllOfTypes` (an intersection of types).

    It reuses the machinery of ``typing``'s union alias, but with *all-of*
    semantics: ``issubclass(S, AllOfTypes[X, Y])`` holds only when ``S`` is a
    subclass of both ``X`` and ``Y``.

    Equality, hashing, repr and type-variable substitution are overridden so
    that a joint alias is never confused with a union over the same
    arguments.
    """

    def copy_with(self, params):
        # The inherited version rebuilds a ``typing.Union``. Rebuild through our
        # own special form so that substitution keeps all-of semantics.
        return self.__origin__[params]

    def __eq__(self, other):
        if isinstance(other, _JointGenericAlias):
            return set(self.__args__) == set(other.__args__)
        if isinstance(other, _UnionGenericAlias):
            # A union over the same arguments has different (any-of) semantics.
            return False
        return NotImplemented

    def __hash__(self):
        # Must be defined alongside ``__eq__``; differs from a union's hash.
        return hash((_JointGenericAlias, frozenset(self.__args__)))

    def __repr__(self):
        # Skip ``_UnionGenericAlias.__repr__``, which prints two-argument
        # aliases containing ``None`` as ``typing.Optional[...]``.
        return super(_UnionGenericAlias, self).__repr__()

    def __subclasscheck__(self, subclass):
        return all(issubclass(subclass, arg) for arg in self.__args__)


@_SpecialForm
def AllOfTypes(self, parameters):
    """All of types; AllOfTypes[X, Y] means both X and Y.

    To express "either X or Y", use OneOfTypes[X, Y] or Union[X, Y] instead.

    Details:
    - The arguments must be types and there must be at least one.
    - None as an argument is a special case and is replaced by `type(None)`.
      AllOfTypes[X, None] therefore means "X and None"; it is NOT Optional[X].
    - Nested AllOfTypes are flattened, e.g.::

        AllOfTypes[AllOfTypes[int, str], float] == AllOfTypes[int, str, float]

      A union nested inside stays a single argument, so
      AllOfTypes[OneOfTypes[A, B], C] means "(A or B) and C".
    - A single argument vanishes, e.g.::

        AllOfTypes[int] == int  # The constructor actually returns int

    - Redundant arguments are skipped, e.g.::

        AllOfTypes[int, str, int] == AllOfTypes[int, str]

    - When comparing, the argument order is ignored, e.g.::

        AllOfTypes[int, str] == AllOfTypes[str, int]

    - You cannot subclass or instantiate an AllOfTypes.
    """
    if parameters == ():
        raise TypeError("Cannot take a Joint of no types.")
    if not isinstance(parameters, tuple):
        parameters = (parameters,)
    msg = "AllOfTypes[arg, ...]: each arg must be a type."
    parameters = tuple(_type_check(p, msg) for p in parameters)
    flat = []
    for p in parameters:
        # Only splice in nested *joint* aliases. ``typing._remove_dups_flatten``
        # would also splice in the members of a nested union, silently turning
        # "(A or B) and C" into "A and B and C".
        if isinstance(p, _JointGenericAlias):
            flat.extend(p.__args__)
        else:
            flat.append(p)
    # Deduplicate, keeping the first occurrence of each argument.
    parameters = tuple(dict.fromkeys(flat))
    if len(parameters) == 1:
        return parameters[0]
    return _JointGenericAlias(self, parameters)


@_SpecialForm
def OneOfTypes(self, parameters):
    """Sole type; OneOfTypes[X, Y] means either X or Y.

    To define a union, use e.g. OneOfTypes[int, str]. Details:
    - The arguments must be types and there must be at least one.
    - None as an argument is a special case and is replaced by
      type(None).
    - Unions of unions are flattened, e.g.::

        assert OneOfTypes[OneOfTypes[int, str], float] == OneOfTypes[int, str, float]

      A nested AllOfTypes stays a single argument, so
      OneOfTypes[AllOfTypes[A, B], C] means "(A and B) or C".
    - Unions of a single argument vanish, e.g.::

        assert OneOfTypes[int] == int  # The constructor actually returns int

    - Redundant arguments are skipped, e.g.::

        assert OneOfTypes[int, str, int] == OneOfTypes[int, str]

    - When comparing unions, the argument order is ignored, e.g.::

        assert OneOfTypes[int, str] == OneOfTypes[str, int]

    - You cannot subclass or instantiate a union.
    - You can use Optional[X] as a shorthand for OneOfTypes[X, None].
    """
    import types

    if parameters == ():
        raise TypeError("Cannot take a Sole of no types.")
    if not isinstance(parameters, tuple):
        parameters = (parameters,)
    msg = "OneOfTypes[arg, ...]: each arg must be a type."
    parameters = tuple(_type_check(p, msg) for p in parameters)
    # ``types.UnionType`` (``X | Y``) only exists on Python 3.10+.
    union_types = (_UnionGenericAlias, getattr(types, 'UnionType', _UnionGenericAlias))
    flat = []
    for p in parameters:
        # Splice in nested unions, but keep a nested AllOfTypes intact:
        # flattening it would turn "(A and B) or C" into "A or B or C".
        if isinstance(p, _JointGenericAlias):
            flat.append(p)
        elif isinstance(p, union_types):
            flat.extend(p.__args__)
        else:
            flat.append(p)
    # Deduplicate, keeping the first occurrence of each argument.
    parameters = tuple(dict.fromkeys(flat))
    if len(parameters) == 1:
        return parameters[0]
    if len(parameters) == 2 and type(None) in parameters:
        return _UnionGenericAlias(self, parameters, name="Optional")
    return _UnionGenericAlias(self, parameters)
|
289
|
+
|
290
|
+
|
291
|
+
class Mode(Mixin):
    """
    Base class for computation behaviors.

    Two modes are equal when they have exactly the same class; other
    attributes (such as a batch size) are not compared. ``__hash__`` follows
    the same rule, so modes can be used in sets and as dict keys.
    """

    def __repr__(self):
        return self.__class__.__name__

    def __eq__(self, other: 'Mode'):
        # Comparing with a non-mode reports "not equal" instead of failing an
        # assertion, so ``mode == None`` and ``mode in [...]`` work.
        if not isinstance(other, Mode):
            return NotImplemented
        return other.__class__ == self.__class__

    def __hash__(self):
        # Consistent with ``__eq__``: equal modes share the same class.
        return hash(self.__class__)

    def is_a(self, mode: type):
        """
        Check whether the mode is exactly the desired mode.

        Returns True only if this object's class is ``mode`` itself
        (subclasses do not count).

        Raises:
          TypeError: if ``mode`` is not a class.
        """
        if not isinstance(mode, type):
            raise TypeError('Must be a type.')
        return self.__class__ == mode

    def has(self, mode: type):
        """
        Check whether the mode is included in the desired mode.

        Returns True if this object is an instance of ``mode``, subclasses
        included.

        Raises:
          TypeError: if ``mode`` is not a class.
        """
        if not isinstance(mode, type):
            raise TypeError('Must be a type.')
        return isinstance(self, mode)
|
316
|
+
|
317
|
+
|
318
|
+
class JointMode(Mode):
    """
    Joint mode: a combination of several :py:class:`Mode` instances.

    Attribute lookups that fail on the joint mode are forwarded to the
    component modes, in order. Two joint modes of the same class are equal
    when they combine the same set of mode classes.
    """

    def __init__(self, *modes: Mode):
        for m_ in modes:
            if not isinstance(m_, Mode):
                raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
        self.modes = tuple(modes)
        self.types = {m.__class__ for m in modes}

    def __repr__(self):
        return f'{self.__class__.__name__}({", ".join([repr(m) for m in self.modes])})'

    def __eq__(self, other):
        # ``Mode.__eq__`` compares classes only, which made every JointMode
        # equal to every other JointMode regardless of its components.
        if isinstance(other, Mode):
            return other.__class__ == self.__class__ and other.types == self.types
        return NotImplemented

    def __hash__(self):
        # Consistent with ``__eq__``. ``self.types`` must not be mutated after
        # the object has been hashed.
        return hash((self.__class__, frozenset(self.types)))

    def has(self, mode: type):
        """
        Check whether the mode is included in the desired mode.

        Returns True if any component mode's class is ``mode`` or a subclass
        of it.

        Raises:
          TypeError: if ``mode`` is not a class.
        """
        if not isinstance(mode, type):
            raise TypeError('Must be a type.')
        return any(issubclass(cls, mode) for cls in self.types)

    def is_a(self, cls: type):
        """
        Check whether the mode is exactly the desired mode.

        ``cls`` is compared with ``AllOfTypes[...]`` built from the component
        classes; for a single component this is the class itself.
        """
        return AllOfTypes[tuple(self.types)] == cls

    def __getattr__(self, item):
        """
        Get the attribute from the mode.

        Only called when normal lookup fails. The component modes are searched
        in order; if none has the attribute, ``AttributeError`` is raised.
        """
        # ``modes``/``types`` may be missing (e.g. on an instance created
        # without ``__init__``). Resolve them directly so that a missing one
        # raises AttributeError instead of recursing through ``self.modes``.
        if item in ['modes', 'types']:
            return super().__getattribute__(item)
        for m in self.modes:
            if hasattr(m, item):
                return getattr(m, item)
        return super().__getattribute__(item)
|
358
|
+
|
359
|
+
|
360
|
+
class Batching(Mode):
    """Batching mode.

    Args:
      batch_size: number of samples per batch (default ``1``).
      batch_axis: index of the batch axis (default ``0``).
    """

    def __init__(self, batch_size: int = 1, batch_axis: int = 0):
        self.batch_axis = batch_axis
        self.batch_size = batch_size

    def __repr__(self):
        name = type(self).__name__
        return f'{name}(size={self.batch_size}, axis={self.batch_axis})'
|
369
|
+
|
370
|
+
|
371
|
+
class Training(Mode):
    """Training mode.

    A marker mode with no extra state; test for it with ``mode.has(Training)``.
    """
|
brainstate/mixin_test.py
ADDED
@@ -0,0 +1,73 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import brainstate as bc
|
19
|
+
|
20
|
+
|
21
|
+
class TestMixin(unittest.TestCase):
    """Smoke test: the public mixin helpers are exposed on ``bc.mixin``."""

    def test_mixin(self):
        exported = (
            'Mixin', 'DelayedInit', 'DelayedInitializer', 'AllOfTypes',
            'OneOfTypes', 'Mode', 'Batching', 'Training',
        )
        for attr in exported:
            self.assertTrue(getattr(bc.mixin, attr))
|
31
|
+
|
32
|
+
|
33
|
+
class TestMode(unittest.TestCase):
    """Checks ``is_a``/``has`` on single modes and on ``JointMode``."""

    def test_JointMode(self):
        mx = bc.mixin
        both = mx.JointMode(mx.Batching(), mx.Training())
        self.assertTrue(both.is_a(mx.AllOfTypes[mx.Batching, mx.Training]))
        self.assertTrue(both.has(mx.Batching))
        self.assertTrue(both.has(mx.Training))

        only_batching = mx.JointMode(mx.Batching())
        self.assertTrue(only_batching.is_a(mx.AllOfTypes[mx.Batching]))
        self.assertTrue(only_batching.is_a(mx.Batching))
        self.assertTrue(only_batching.has(mx.Batching))

    def test_Training(self):
        mx = bc.mixin
        mode = mx.Training()
        for target in (mx.Training, mx.AllOfTypes[mx.Training]):
            self.assertTrue(mode.is_a(target))
        for target in (mx.Training, mx.AllOfTypes[mx.Training]):
            self.assertTrue(mode.has(target))
        self.assertFalse(mode.is_a(mx.Batching))
        self.assertFalse(mode.has(mx.Batching))

    def test_Batching(self):
        mx = bc.mixin
        mode = mx.Batching()
        for target in (mx.Batching, mx.AllOfTypes[mx.Batching]):
            self.assertTrue(mode.is_a(target))
        for target in (mx.Batching, mx.AllOfTypes[mx.Batching]):
            self.assertTrue(mode.has(target))

        self.assertFalse(mode.is_a(mx.Training))
        self.assertFalse(mode.has(mx.Training))

    def test_Mode(self):
        mx = bc.mixin
        mode = mx.Mode()
        for target in (mx.Mode, mx.AllOfTypes[mx.Mode]):
            self.assertTrue(mode.is_a(target))
        for target in (mx.Mode, mx.AllOfTypes[mx.Mode]):
            self.assertTrue(mode.has(target))

        for other in (mx.Training, mx.Batching):
            self.assertFalse(mode.is_a(other))
            self.assertFalse(mode.has(other))
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from ._base import *
|
17
|
+
from ._base import __all__ as base_all
|
18
|
+
from ._connections import *
|
19
|
+
from ._connections import __all__ as connections_all
|
20
|
+
from ._dynamics import *
|
21
|
+
from ._dynamics import __all__ as dynamics_all
|
22
|
+
from ._elementwise import *
|
23
|
+
from ._elementwise import __all__ as elementwise_all
|
24
|
+
from ._misc import *
|
25
|
+
from ._misc import __all__ as _misc_all
|
26
|
+
from ._normalizations import *
|
27
|
+
from ._normalizations import __all__ as normalizations_all
|
28
|
+
from ._others import *
|
29
|
+
from ._others import __all__ as others_all
|
30
|
+
from ._poolings import *
|
31
|
+
from ._poolings import __all__ as poolings_all
|
32
|
+
from ._projection import *
|
33
|
+
from ._projection import __all__ as _projection_all
|
34
|
+
from ._rate_rnns import *
|
35
|
+
from ._rate_rnns import __all__ as rate_rnns
|
36
|
+
from ._readout import *
|
37
|
+
from ._readout import __all__ as readout_all
|
38
|
+
from ._synouts import *
|
39
|
+
from ._synouts import __all__ as synouts_all
|
40
|
+
|
41
|
+
# Public API of ``brainstate.nn``: the union of every submodule's ``__all__``.
__all__ = (
    base_all +
    connections_all +
    dynamics_all +
    elementwise_all +
    normalizations_all +
    others_all +
    poolings_all +
    rate_rnns +
    readout_all +
    synouts_all +
    _projection_all +
    _misc_all
)

# Remove the helper lists so they do not linger as package attributes.
# ``rate_rnns`` was previously missing here and leaked as ``brainstate.nn.rate_rnns``.
del (
    base_all,
    connections_all,
    dynamics_all,
    elementwise_all,
    normalizations_all,
    others_all,
    poolings_all,
    rate_rnns,
    readout_all,
    synouts_all,
    _projection_all,
    _misc_all
)
|