brainstate 0.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +45 -0
- brainstate/_module.py +1466 -0
- brainstate/_module_test.py +133 -0
- brainstate/_state.py +378 -0
- brainstate/_state_test.py +41 -0
- brainstate/_utils.py +21 -0
- brainstate/environ.py +375 -0
- brainstate/functional/__init__.py +25 -0
- brainstate/functional/_activations.py +754 -0
- brainstate/functional/_normalization.py +69 -0
- brainstate/functional/_spikes.py +90 -0
- brainstate/init/__init__.py +26 -0
- brainstate/init/_base.py +36 -0
- brainstate/init/_generic.py +175 -0
- brainstate/init/_random_inits.py +489 -0
- brainstate/init/_regular_inits.py +109 -0
- brainstate/math/__init__.py +21 -0
- brainstate/math/_einops.py +787 -0
- brainstate/math/_einops_parsing.py +169 -0
- brainstate/math/_einops_parsing_test.py +126 -0
- brainstate/math/_einops_test.py +346 -0
- brainstate/math/_misc.py +298 -0
- brainstate/math/_misc_test.py +58 -0
- brainstate/mixin.py +373 -0
- brainstate/mixin_test.py +73 -0
- brainstate/nn/__init__.py +68 -0
- brainstate/nn/_base.py +248 -0
- brainstate/nn/_connections.py +686 -0
- brainstate/nn/_dynamics.py +406 -0
- brainstate/nn/_elementwise.py +1437 -0
- brainstate/nn/_misc.py +132 -0
- brainstate/nn/_normalizations.py +389 -0
- brainstate/nn/_others.py +100 -0
- brainstate/nn/_poolings.py +1228 -0
- brainstate/nn/_poolings_test.py +231 -0
- brainstate/nn/_projection/__init__.py +32 -0
- brainstate/nn/_projection/_align_post.py +528 -0
- brainstate/nn/_projection/_align_pre.py +599 -0
- brainstate/nn/_projection/_delta.py +241 -0
- brainstate/nn/_projection/_utils.py +17 -0
- brainstate/nn/_projection/_vanilla.py +101 -0
- brainstate/nn/_rate_rnns.py +393 -0
- brainstate/nn/_readout.py +130 -0
- brainstate/nn/_synouts.py +166 -0
- brainstate/nn/functional/__init__.py +25 -0
- brainstate/nn/functional/_activations.py +754 -0
- brainstate/nn/functional/_normalization.py +69 -0
- brainstate/nn/functional/_spikes.py +90 -0
- brainstate/nn/init/__init__.py +26 -0
- brainstate/nn/init/_base.py +36 -0
- brainstate/nn/init/_generic.py +175 -0
- brainstate/nn/init/_random_inits.py +489 -0
- brainstate/nn/init/_regular_inits.py +109 -0
- brainstate/nn/surrogate.py +1740 -0
- brainstate/optim/__init__.py +23 -0
- brainstate/optim/_lr_scheduler.py +486 -0
- brainstate/optim/_lr_scheduler_test.py +36 -0
- brainstate/optim/_sgd_optimizer.py +1148 -0
- brainstate/random.py +5148 -0
- brainstate/random_test.py +576 -0
- brainstate/surrogate.py +1740 -0
- brainstate/transform/__init__.py +36 -0
- brainstate/transform/_autograd.py +585 -0
- brainstate/transform/_autograd_test.py +1183 -0
- brainstate/transform/_control.py +665 -0
- brainstate/transform/_controls_test.py +220 -0
- brainstate/transform/_jit.py +239 -0
- brainstate/transform/_jit_error.py +158 -0
- brainstate/transform/_jit_test.py +102 -0
- brainstate/transform/_make_jaxpr.py +573 -0
- brainstate/transform/_make_jaxpr_test.py +133 -0
- brainstate/transform/_progress_bar.py +113 -0
- brainstate/typing.py +69 -0
- brainstate/util.py +747 -0
- brainstate-0.0.1.dist-info/LICENSE +202 -0
- brainstate-0.0.1.dist-info/METADATA +101 -0
- brainstate-0.0.1.dist-info/RECORD +79 -0
- brainstate-0.0.1.dist-info/WHEEL +6 -0
- brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/nn/_base.py
ADDED
@@ -0,0 +1,248 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from __future__ import annotations

import inspect
from typing import Sequence, Optional, Tuple, Union

from .._module import Module, UpdateReturn, Container, visible_module_dict
from ..mixin import Mixin, DelayedInitializer, DelayedInit

__all__ = [
  'ExplicitInOutSize',
  'ElementWiseBlock',
  'Sequential',
  'DnnLayer',
]


# -------------------------------------------------------------------------------------- #
# Network Related Concepts
# -------------------------------------------------------------------------------------- #


class ExplicitInOutSize(Mixin):
  """
  Mix-in class with explicit input and output shapes.

  Attributes
  ----------
  in_size: tuple[int]
    The input shape, without the batch size. This argument is important, since it is
    used to evaluate the shape of the output.
  out_size: tuple[int]
    The output shape, without the batch size.
  """
  __module__ = 'brainstate.nn'

  _in_size: Optional[Tuple[int, ...]] = None
  _out_size: Optional[Tuple[int, ...]] = None

  @property
  def in_size(self) -> Tuple[int, ...]:
    if self._in_size is None:
      raise ValueError(f"The input shape is not set in this node: {self}")
    return self._in_size

  @in_size.setter
  def in_size(self, in_size: Sequence[int]):
    self._in_size = tuple(in_size)

  @property
  def out_size(self) -> Tuple[int, ...]:
    if self._out_size is None:
      raise ValueError(f"The output shape is not set in this node: {self}")
    return self._out_size

  @out_size.setter
  def out_size(self, out_size: Sequence[int]):
    self._out_size = tuple(out_size)


class ElementWiseBlock(Mixin):
  """
  Mix-in class for element-wise modules.
  """
  __module__ = 'brainstate.nn'

class Sequential(Module, UpdateReturn, Container, ExplicitInOutSize):
  """
  A sequential `input-output` module.

  Modules will be added to it in the order they are passed to the
  constructor. Alternatively, a ``dict`` of modules can be
  passed in. The ``update()`` method of ``Sequential`` accepts any
  input and forwards it to the first module it contains. It then
  "chains" outputs to inputs sequentially for each subsequent module,
  finally returning the output of the last module.

  The value a ``Sequential`` provides over manually calling a sequence
  of modules is that it allows treating the whole container as a
  single module, such that performing a transformation on the
  ``Sequential`` applies to each of the modules it stores (which are
  each a registered submodule of the ``Sequential``).

  What's the difference between a ``Sequential`` and a
  :py:class:`Container`? A ``Container`` is exactly what it
  sounds like: a container to store :py:class:`DynamicalSystem` s!
  On the other hand, the layers in a ``Sequential`` are connected
  in a cascading way.

  Examples
  --------

  >>> import jax
  >>> import brainstate as bst
  >>> import brainstate.nn as nn
  >>>
  >>> # composing ANN models
  >>> l = nn.Sequential(nn.Linear(100, 10),
  ...                   jax.nn.relu,
  ...                   nn.Linear(10, 2))
  >>> l(bst.random.random((256, 100)))
  >>>
  >>> # Using Sequential with a dict. This is functionally the
  >>> # same as the above code.
  >>> l = nn.Sequential(l1=nn.Linear(100, 10),
  ...                   l2=jax.nn.relu,
  ...                   l3=nn.Linear(10, 2))
  >>> l(bst.random.random((256, 100)))

  Args:
    first: The first module, which must have explicit input and output shapes.
    modules_as_tuple: The children modules, chained in the given order.
    modules_as_dict: The children modules, chained in the given keyword order.
  """
  __module__ = 'brainstate.nn'

  def __init__(self, first: ExplicitInOutSize, *modules_as_tuple, **modules_as_dict):
    super().__init__()

    assert isinstance(first, ExplicitInOutSize)
    in_size = first.out_size

    tuple_modules = []
    for module in modules_as_tuple:
      module, in_size = self._format_module(module, in_size)
      tuple_modules.append(module)

    dict_modules = dict()
    for key, module in modules_as_dict.items():
      module, in_size = self._format_module(module, in_size)
      dict_modules[key] = module

    # Attribute of "Container"
    self.children = visible_module_dict(self.format_elements(object, first, *tuple_modules, **dict_modules))

    # the input and output shape
    self.in_size = tuple(first.in_size)
    self.out_size = tuple(in_size)

  def _format_module(self, module, in_size):
    if isinstance(module, DelayedInitializer):
      module = module(in_size=in_size)
      assert isinstance(module, ExplicitInOutSize)
      out_size = module.out_size
    elif isinstance(module, ElementWiseBlock):
      out_size = in_size
    elif isinstance(module, ExplicitInOutSize):
      out_size = module.out_size
    else:
      raise TypeError(f"Unsupported type {type(module)}.")
    return module, out_size

  def update(self, x):
    """Update function of a sequential model."""
    for m in self.children.values():
      x = m(x)
    return x

  def update_return(self):
    """
    The return information of the sequence, taken from the final model.
    """
    last = self[-1]
    if not isinstance(last, UpdateReturn):
      raise NotImplementedError(f'The last element in the sequence is not an instance of {UpdateReturn.__name__}')
    return last.update_return()

  def update_return_info(self):
    """
    The return information of the sequence, taken from the final model.
    """
    last = self[-1]
    if not isinstance(last, UpdateReturn):
      raise NotImplementedError(f'The last element in the sequence is not an instance of {UpdateReturn.__name__}')
    return last.update_return_info()

  def __getitem__(self, key: Union[int, slice, str]):
    if isinstance(key, str):
      if key in self.children:
        return self.children[key]
      else:
        raise KeyError(f'Cannot find a component named {key} in\n {str(self)}')
    elif isinstance(key, slice):
      return Sequential(**dict(tuple(self.children.items())[key]))
    elif isinstance(key, int):
      return tuple(self.children.values())[key]
    elif isinstance(key, (tuple, list)):
      _all_nodes = tuple(self.children.items())
      return Sequential(**dict(_all_nodes[k] for k in key))
    else:
      raise KeyError(f'Unknown type of key: {type(key)}')

  def __repr__(self):
    nodes = self.children.values()
    entries = '\n'.join(f'  [{i}] {_repr_object(x)}' for i, x in enumerate(nodes))
    return f'{self.__class__.__name__}(\n{entries}\n)'


def _repr_object(x):
  if isinstance(x, Module):
    return repr(x)
  elif callable(x):
    signature = inspect.signature(x)
    args = [f'{k}={v.default}' for k, v in signature.parameters.items()
            if v.default is not inspect.Parameter.empty]
    args = ', '.join(args)
    while not hasattr(x, '__name__'):
      if not hasattr(x, 'func'):
        break
      x = x.func  # Handle functools.partial
    if not hasattr(x, '__name__') and hasattr(x, '__class__'):
      return x.__class__.__name__
    if args:
      return f'{x.__name__}(*, {args})'
    return x.__name__
  else:
    x = repr(x).split('\n')
    x = [x[0]] + ['  ' + y for y in x[1:]]
    return '\n'.join(x)


class DnnLayer(Module, ExplicitInOutSize, DelayedInit):
  """
  A DNN layer.
  """
  __module__ = 'brainstate.nn'

  def __repr__(self):
    return f"{self.__class__.__name__}(in_size={self.in_size}, out_size={self.out_size})"
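
A short usage sketch of the pieces above (not part of the package itself): ``Sequential`` threads each module's ``out_size`` into the next module via ``_format_module``, ``update()`` chains the forward calls, and ``__getitem__`` exposes the children. The sketch assumes ``nn.Linear`` behaves as in the ``Sequential`` docstring example and records 1-D ``in_size``/``out_size`` tuples.

import brainstate as bst
import brainstate.nn as nn

# Chain two layers; the shapes noted below are assumptions, valid only if
# Linear(100, 10) exposes in_size == (100,) and out_size == (10,).
net = nn.Sequential(nn.Linear(100, 10),
                    nn.Linear(10, 2))

net.in_size    # taken from the first module, e.g. (100,)
net.out_size   # propagated through _format_module, e.g. (2,)

# update() forwards the input through each child in order.
y = net(bst.random.random((256, 100)))   # expected shape: (256, 2)

# Integer indexing returns a child module; __repr__ lists the children
# via _repr_object.
first_layer = net[0]
print(net)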