brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/mixin.py ADDED
@@ -0,0 +1,373 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Sequence, Optional, TypeVar, Any
19
+ from typing import (_SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
20
+
21
# Type variable used to annotate the class wrapped by ``DelayedInitializer``.
T = TypeVar('T')
# Annotation alias for pytree-like return values; currently just ``Any``.
PyTree = Any

# Placeholder that ``_get_state()`` replaces with ``brainstate._state.State``
# on first use.
State = None

# Names exported by ``from brainstate.mixin import *``.
__all__ = [
    'Mixin',
    'DelayedInit',
    'DelayedInitializer',
    'AlignPost',
    'BindCondData',
    'UpdateReturn',

    # types
    'AllOfTypes',
    'OneOfTypes',

    # behavior modes
    'Mode',
    'JointMode',
    'Batching',
    'Training',
]
44
+
45
+
46
def _get_state():
    """Return ``brainstate._state.State``, importing it on the first call.

    The imported class is stored in the module-level ``State`` name so later
    calls skip the import.
    """
    global State
    if State is not None:
        return State
    # Imported here rather than at module import time; the result is bound
    # to the global ``State`` declared above.
    from brainstate._state import State
    return State
51
+
52
+
53
class Mixin:
    """Root class of every mixin defined in this module.

    A :py:class:`~.Mixin` defines no initialization function; it only
    contributes behavioral functions.
    """
59
+
60
+
61
class DelayedInit(Mixin):
    """
    :py:class:`~.Mixin` that lets a class describe its initialization parameters ahead of time.

    Subclasses gain the classmethod ``delayed``, which captures constructor
    arguments in a :py:class:`~.DelayedInitializer` instead of building the
    object immediately.

    This mixin can be combined with any Python class.
    """

    # Optional names of parameters to treat as non-hashable; not used within this class.
    non_hash_params: Optional[Sequence[str]] = None

    @classmethod
    def delayed(cls, *args, **kwargs) -> 'DelayedInitializer':
        """Wrap ``cls`` and the given arguments in a :py:class:`~.DelayedInitializer`."""
        initializer = DelayedInitializer(cls, *args, **kwargs)
        return initializer
76
+
77
+
78
class HashableDict(dict):
    """A ``dict`` subclass that can be hashed from its current items.

    Two instances holding equal items hash equally, regardless of insertion
    order. Hashing raises ``TypeError`` if any value is itself unhashable.
    """

    def __hash__(self):
        # A frozenset is order-independent, so the items do not need to be
        # sorted; sorting would fail for keys that cannot be compared with
        # each other (e.g. ``1`` and ``'a'``).
        return hash(frozenset(self.items()))
81
+
82
+
83
class NoSubclassMeta(type):
    """Metaclass that forbids subclassing of the classes it creates."""

    def __new__(cls, name, bases, classdict):
        # Reject the first base that was itself created by this metaclass.
        forbidden = next((base for base in bases if isinstance(base, NoSubclassMeta)), None)
        if forbidden is not None:
            raise TypeError("type '{0}' is not an acceptable base type".format(forbidden.__name__))
        return type.__new__(cls, name, bases, dict(classdict))
89
+
90
+
91
class DelayedInitializer(metaclass=NoSubclassMeta):
    """
    Deferred constructor call for a class.

    Holds a target class together with positional and keyword arguments, so
    the instance can be created later, optionally with further arguments.
    """

    def __init__(self, cls: T, *desc_tuple, **desc_dict):
        # target class and the arguments captured for it
        self.cls = cls
        self.args = desc_tuple
        self.kwargs = desc_dict

        # triple describing this initializer; read-only through ``identifier``
        self._identifier = (cls, tuple(desc_tuple), HashableDict(desc_dict))

    def __call__(self, *args, **kwargs) -> T:
        """Build ``self.cls`` from the stored arguments followed by the given ones."""
        return self.cls(*self.args, *args, **self.kwargs, **kwargs)

    def init(self, *args, **kwargs):
        """Alias of calling the initializer directly."""
        return self.__call__(*args, **kwargs)

    def __instancecheck__(self, instance):
        """Match other initializers whose target class is a subclass of ``self.cls``."""
        return isinstance(instance, DelayedInitializer) and issubclass(instance.cls, self.cls)

    @classmethod
    def __class_getitem__(cls, item: type):
        """``DelayedInitializer[X]`` yields an initializer for ``X`` with no stored arguments."""
        return DelayedInitializer(item)

    @property
    def identifier(self):
        """The ``(cls, args, kwargs)`` triple recorded at construction."""
        return self._identifier

    @identifier.setter
    def identifier(self, value):
        raise AttributeError('Cannot set the identifier.')
130
+
131
+
132
class AlignPost(Mixin):
    """
    Mixin for the align-post interface.

    Declares ``align_post_input_add()``, the hook for adding external
    currents; the base version is a stub.
    """

    def align_post_input_add(self, *args, **kwargs):
        """Add external input; raises ``NotImplementedError`` unless overridden."""
        raise NotImplementedError
142
+
143
+
144
class BindCondData(Mixin):
    """Mixin for temporarily attaching conductance data to an object."""

    # The currently bound conductance, or ``None`` after ``unbind_cond()``.
    _conductance: Optional

    def bind_cond(self, conductance):
        """Store ``conductance`` on the instance."""
        self._conductance = conductance

    def unbind_cond(self):
        """Reset the stored conductance to ``None``."""
        self._conductance = None
156
+
157
+
158
class UpdateReturn(Mixin):
    """Mixin declaring hooks that describe what a model's update function returns."""

    def update_return(self) -> PyTree:
        """
        Describe the return value of the model's update function.

        Implementations should return a pytree whose leaves are
        ``jax.ShapeDtypeStruct`` or ``jax.core.ShapedArray`` objects.
        The base version always raises ``NotImplementedError``.
        """
        raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')

    def update_return_info(self) -> PyTree:
        """
        Describe the update return information of the model.

        Implementations should return a pytree whose leaves are ``jax.Array``
        objects. The base version always raises ``NotImplementedError``.

        .. note::
           The batch axis and batch size should not be included;
           they are inferred from the ``mode`` attribute.
        """
        raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
181
+
182
+
183
+ class _MetaUnionType(type):
184
+ def __new__(cls, name, bases, dct):
185
+ if isinstance(bases, type):
186
+ bases = (bases,)
187
+ elif isinstance(bases, (list, tuple)):
188
+ bases = tuple(bases)
189
+ for base in bases:
190
+ assert isinstance(base, type), f'Must be type. But got {base}'
191
+ else:
192
+ raise TypeError(f'Must be type. But got {bases}')
193
+ return super().__new__(cls, name, bases, dct)
194
+
195
+ def __instancecheck__(self, other):
196
+ cls_of_other = other.__class__
197
+ return all([issubclass(cls_of_other, cls) for cls in self.__bases__])
198
+
199
+ def __subclasscheck__(self, subclass):
200
+ return all([issubclass(subclass, cls) for cls in self.__bases__])
201
+
202
+
203
+ class _JointGenericAlias(_UnionGenericAlias, _root=True):
204
+ def __subclasscheck__(self, subclass):
205
+ return all([issubclass(subclass, cls) for cls in set(self.__args__)])
206
+
207
+
208
+ @_SpecialForm
209
+ def AllOfTypes(self, parameters):
210
+ """All of types; AllOfTypes[X, Y] means both X and Y.
211
+
212
+ To define a union, use e.g. Union[int, str].
213
+
214
+ Details:
215
+ - The arguments must be types and there must be at least one.
216
+ - None as an argument is a special case and is replaced by `type(None)`.
217
+ - Unions of unions are flattened, e.g.::
218
+
219
+ AllOfTypes[AllOfTypes[int, str], float] == AllOfTypes[int, str, float]
220
+
221
+ - Unions of a single argument vanish, e.g.::
222
+
223
+ AllOfTypes[int] == int # The constructor actually returns int
224
+
225
+ - Redundant arguments are skipped, e.g.::
226
+
227
+ AllOfTypes[int, str, int] == AllOfTypes[int, str]
228
+
229
+ - When comparing unions, the argument order is ignored, e.g.::
230
+
231
+ AllOfTypes[int, str] == AllOfTypes[str, int]
232
+
233
+ - You cannot subclass or instantiate a AllOfTypes.
234
+ - You can use Optional[X] as a shorthand for AllOfTypes[X, None].
235
+ """
236
+ if parameters == ():
237
+ raise TypeError("Cannot take a Joint of no types.")
238
+ if not isinstance(parameters, tuple):
239
+ parameters = (parameters,)
240
+ msg = "AllOfTypes[arg, ...]: each arg must be a type."
241
+ parameters = tuple(_type_check(p, msg) for p in parameters)
242
+ parameters = _remove_dups_flatten(parameters)
243
+ if len(parameters) == 1:
244
+ return parameters[0]
245
+ if len(parameters) == 2 and type(None) in parameters:
246
+ return _UnionGenericAlias(self, parameters, name="Optional")
247
+ return _JointGenericAlias(self, parameters)
248
+
249
+
250
+ @_SpecialForm
251
+ def OneOfTypes(self, parameters):
252
+ """Sole type; OneOfTypes[X, Y] means either X or Y.
253
+
254
+ To define a union, use e.g. OneOfTypes[int, str]. Details:
255
+ - The arguments must be types and there must be at least one.
256
+ - None as an argument is a special case and is replaced by
257
+ type(None).
258
+ - Unions of unions are flattened, e.g.::
259
+
260
+ assert OneOfTypes[OneOfTypes[int, str], float] == OneOfTypes[int, str, float]
261
+
262
+ - Unions of a single argument vanish, e.g.::
263
+
264
+ assert OneOfTypes[int] == int # The constructor actually returns int
265
+
266
+ - Redundant arguments are skipped, e.g.::
267
+
268
+ assert OneOfTypes[int, str, int] == OneOfTypes[int, str]
269
+
270
+ - When comparing unions, the argument order is ignored, e.g.::
271
+
272
+ assert OneOfTypes[int, str] == OneOfTypes[str, int]
273
+
274
+ - You cannot subclass or instantiate a union.
275
+ - You can use Optional[X] as a shorthand for OneOfTypes[X, None].
276
+ """
277
+ if parameters == ():
278
+ raise TypeError("Cannot take a Sole of no types.")
279
+ if not isinstance(parameters, tuple):
280
+ parameters = (parameters,)
281
+ msg = "OneOfTypes[arg, ...]: each arg must be a type."
282
+ parameters = tuple(_type_check(p, msg) for p in parameters)
283
+ parameters = _remove_dups_flatten(parameters)
284
+ if len(parameters) == 1:
285
+ return parameters[0]
286
+ if len(parameters) == 2 and type(None) in parameters:
287
+ return _UnionGenericAlias(self, parameters, name="Optional")
288
+ return _UnionGenericAlias(self, parameters)
289
+
290
+
291
class Mode(Mixin):
    """
    Base class for computation behaviors.

    Two modes compare equal when they are instances of exactly the same
    class; the hash follows the same rule so modes can be used in sets and
    as dictionary keys.
    """

    def __repr__(self):
        return self.__class__.__name__

    def __eq__(self, other: 'Mode'):
        # Returning NotImplemented (rather than asserting) lets comparisons
        # against unrelated objects evaluate to False instead of crashing.
        if not isinstance(other, Mode):
            return NotImplemented
        return other.__class__ == self.__class__

    def __hash__(self):
        # Defining __eq__ disables the inherited hash; restore one that is
        # consistent with class-based equality.
        return hash(self.__class__)

    def is_a(self, mode: type):
        """
        Check whether the mode is exactly the desired mode.

        ``mode`` must be a class; the check is an exact class match, so
        subclasses of ``mode`` do not qualify.
        """
        assert isinstance(mode, type), 'Must be a type.'
        return self.__class__ == mode

    def has(self, mode: type):
        """
        Check whether the mode is included in the desired mode.

        ``mode`` must be a class; returns ``True`` when this object is an
        instance of ``mode`` (subclasses included).
        """
        assert isinstance(mode, type), 'Must be a type.'
        return isinstance(self, mode)
316
+
317
+
318
class JointMode(Mode):
    """
    A mode made of several :py:class:`~.Mode` instances.
    """

    def __init__(self, *modes: Mode):
        # Every component must itself be a Mode; fail on the first that is not.
        for m_ in modes:
            if isinstance(m_, Mode):
                continue
            raise TypeError(f'The supported type must be a tuple/list of Mode. But we got {m_}')
        self.modes = tuple(modes)
        self.types = {m.__class__ for m in modes}

    def __repr__(self):
        inner = ", ".join([repr(m) for m in self.modes])
        return f'{self.__class__.__name__}({inner})'

    def has(self, mode: type):
        """
        Check whether any component mode's class is a subclass of ``mode``.
        """
        assert isinstance(mode, type), 'Must be a type.'
        return any(issubclass(component_type, mode) for component_type in self.types)

    def is_a(self, cls: type):
        """
        Check whether the joint of all component classes equals ``cls``.
        """
        joint = AllOfTypes[tuple(self.types)]
        return joint == cls

    def __getattr__(self, item):
        """
        Resolve attributes missing on this object through the component modes.

        The first component that has the attribute supplies the value; if
        none does, the lookup falls back to the default attribute access.
        """
        # ``modes``/``types`` must never be delegated to the components.
        if item in ['modes', 'types']:
            return super().__getattribute__(item)
        for component in self.modes:
            if hasattr(component, item):
                return getattr(component, item)
        return super().__getattribute__(item)
357
+ return super().__getattribute__(item)
358
+
359
+
360
class Batching(Mode):
    """Batching mode: records a batch size and the axis it lives on."""

    def __init__(self, batch_size: int = 1, batch_axis: int = 0):
        self.batch_size, self.batch_axis = batch_size, batch_axis

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(size={self.batch_size}, axis={self.batch_axis})'
369
+
370
+
371
class Training(Mode):
    """Training mode: a marker mode with no state of its own."""
@@ -0,0 +1,73 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import brainstate as bc
19
+
20
+
21
class TestMixin(unittest.TestCase):
    """Smoke test: the public mixin names exist and are truthy."""

    def test_mixin(self):
        exported = (
            'Mixin',
            'DelayedInit',
            'DelayedInitializer',
            'AllOfTypes',
            'OneOfTypes',
            'Mode',
            'Batching',
            'Training',
        )
        for name in exported:
            self.assertTrue(getattr(bc.mixin, name))
31
+
32
+
33
class TestMode(unittest.TestCase):
    """Checks ``is_a`` / ``has`` on the mode classes."""

    def test_JointMode(self):
        mx = bc.mixin
        joint = mx.JointMode(mx.Batching(), mx.Training())
        self.assertTrue(joint.is_a(mx.AllOfTypes[mx.Batching, mx.Training]))
        self.assertTrue(joint.has(mx.Batching))
        self.assertTrue(joint.has(mx.Training))
        # a joint of a single mode behaves like that mode
        single = mx.JointMode(mx.Batching())
        self.assertTrue(single.is_a(mx.AllOfTypes[mx.Batching]))
        self.assertTrue(single.is_a(mx.Batching))
        self.assertTrue(single.has(mx.Batching))

    def test_Training(self):
        mx = bc.mixin
        mode = mx.Training()
        self.assertTrue(mode.is_a(mx.Training))
        self.assertTrue(mode.is_a(mx.AllOfTypes[mx.Training]))
        self.assertTrue(mode.has(mx.Training))
        self.assertTrue(mode.has(mx.AllOfTypes[mx.Training]))
        self.assertFalse(mode.is_a(mx.Batching))
        self.assertFalse(mode.has(mx.Batching))

    def test_Batching(self):
        mx = bc.mixin
        mode = mx.Batching()
        self.assertTrue(mode.is_a(mx.Batching))
        self.assertTrue(mode.is_a(mx.AllOfTypes[mx.Batching]))
        self.assertTrue(mode.has(mx.Batching))
        self.assertTrue(mode.has(mx.AllOfTypes[mx.Batching]))

        self.assertFalse(mode.is_a(mx.Training))
        self.assertFalse(mode.has(mx.Training))

    def test_Mode(self):
        mx = bc.mixin
        mode = mx.Mode()
        self.assertTrue(mode.is_a(mx.Mode))
        self.assertTrue(mode.is_a(mx.AllOfTypes[mx.Mode]))
        self.assertTrue(mode.has(mx.Mode))
        self.assertTrue(mode.has(mx.AllOfTypes[mx.Mode]))

        # the base mode is neither of the concrete modes
        self.assertFalse(mode.is_a(mx.Training))
        self.assertFalse(mode.has(mx.Training))
        self.assertFalse(mode.is_a(mx.Batching))
        self.assertFalse(mode.has(mx.Batching))
@@ -0,0 +1,68 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from ._base import *
17
+ from ._base import __all__ as base_all
18
+ from ._connections import *
19
+ from ._connections import __all__ as connections_all
20
+ from ._dynamics import *
21
+ from ._dynamics import __all__ as dynamics_all
22
+ from ._elementwise import *
23
+ from ._elementwise import __all__ as elementwise_all
24
+ from ._misc import *
25
+ from ._misc import __all__ as _misc_all
26
+ from ._normalizations import *
27
+ from ._normalizations import __all__ as normalizations_all
28
+ from ._others import *
29
+ from ._others import __all__ as others_all
30
+ from ._poolings import *
31
+ from ._poolings import __all__ as poolings_all
32
+ from ._projection import *
33
+ from ._projection import __all__ as _projection_all
34
+ from ._rate_rnns import *
35
+ from ._rate_rnns import __all__ as rate_rnns
36
+ from ._readout import *
37
+ from ._readout import __all__ as readout_all
38
+ from ._synouts import *
39
+ from ._synouts import __all__ as synouts_all
40
+
41
# Public API of the package: the concatenation of every submodule's ``__all__``.
__all__ = (
    base_all +
    connections_all +
    dynamics_all +
    elementwise_all +
    normalizations_all +
    others_all +
    poolings_all +
    rate_rnns +
    readout_all +
    synouts_all +
    _projection_all +
    _misc_all
)

# Remove the temporary per-submodule lists so they do not linger in the
# package namespace (``rate_rnns`` included).
del (
    base_all,
    connections_all,
    dynamics_all,
    elementwise_all,
    normalizations_all,
    others_all,
    poolings_all,
    rate_rnns,
    readout_all,
    synouts_all,
    _projection_all,
    _misc_all
)