brainstate 0.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +45 -0
- brainstate/_module.py +1466 -0
- brainstate/_module_test.py +133 -0
- brainstate/_state.py +378 -0
- brainstate/_state_test.py +41 -0
- brainstate/_utils.py +21 -0
- brainstate/environ.py +375 -0
- brainstate/functional/__init__.py +25 -0
- brainstate/functional/_activations.py +754 -0
- brainstate/functional/_normalization.py +69 -0
- brainstate/functional/_spikes.py +90 -0
- brainstate/init/__init__.py +26 -0
- brainstate/init/_base.py +36 -0
- brainstate/init/_generic.py +175 -0
- brainstate/init/_random_inits.py +489 -0
- brainstate/init/_regular_inits.py +109 -0
- brainstate/math/__init__.py +21 -0
- brainstate/math/_einops.py +787 -0
- brainstate/math/_einops_parsing.py +169 -0
- brainstate/math/_einops_parsing_test.py +126 -0
- brainstate/math/_einops_test.py +346 -0
- brainstate/math/_misc.py +298 -0
- brainstate/math/_misc_test.py +58 -0
- brainstate/mixin.py +373 -0
- brainstate/mixin_test.py +73 -0
- brainstate/nn/__init__.py +68 -0
- brainstate/nn/_base.py +248 -0
- brainstate/nn/_connections.py +686 -0
- brainstate/nn/_dynamics.py +406 -0
- brainstate/nn/_elementwise.py +1437 -0
- brainstate/nn/_misc.py +132 -0
- brainstate/nn/_normalizations.py +389 -0
- brainstate/nn/_others.py +100 -0
- brainstate/nn/_poolings.py +1228 -0
- brainstate/nn/_poolings_test.py +231 -0
- brainstate/nn/_projection/__init__.py +32 -0
- brainstate/nn/_projection/_align_post.py +528 -0
- brainstate/nn/_projection/_align_pre.py +599 -0
- brainstate/nn/_projection/_delta.py +241 -0
- brainstate/nn/_projection/_utils.py +17 -0
- brainstate/nn/_projection/_vanilla.py +101 -0
- brainstate/nn/_rate_rnns.py +393 -0
- brainstate/nn/_readout.py +130 -0
- brainstate/nn/_synouts.py +166 -0
- brainstate/nn/functional/__init__.py +25 -0
- brainstate/nn/functional/_activations.py +754 -0
- brainstate/nn/functional/_normalization.py +69 -0
- brainstate/nn/functional/_spikes.py +90 -0
- brainstate/nn/init/__init__.py +26 -0
- brainstate/nn/init/_base.py +36 -0
- brainstate/nn/init/_generic.py +175 -0
- brainstate/nn/init/_random_inits.py +489 -0
- brainstate/nn/init/_regular_inits.py +109 -0
- brainstate/nn/surrogate.py +1740 -0
- brainstate/optim/__init__.py +23 -0
- brainstate/optim/_lr_scheduler.py +486 -0
- brainstate/optim/_lr_scheduler_test.py +36 -0
- brainstate/optim/_sgd_optimizer.py +1148 -0
- brainstate/random.py +5148 -0
- brainstate/random_test.py +576 -0
- brainstate/surrogate.py +1740 -0
- brainstate/transform/__init__.py +36 -0
- brainstate/transform/_autograd.py +585 -0
- brainstate/transform/_autograd_test.py +1183 -0
- brainstate/transform/_control.py +665 -0
- brainstate/transform/_controls_test.py +220 -0
- brainstate/transform/_jit.py +239 -0
- brainstate/transform/_jit_error.py +158 -0
- brainstate/transform/_jit_test.py +102 -0
- brainstate/transform/_make_jaxpr.py +573 -0
- brainstate/transform/_make_jaxpr_test.py +133 -0
- brainstate/transform/_progress_bar.py +113 -0
- brainstate/typing.py +69 -0
- brainstate/util.py +747 -0
- brainstate-0.0.1.dist-info/LICENSE +202 -0
- brainstate-0.0.1.dist-info/METADATA +101 -0
- brainstate-0.0.1.dist-info/RECORD +79 -0
- brainstate-0.0.1.dist-info/WHEEL +6 -0
- brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/_module.py
ADDED
@@ -0,0 +1,1466 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


"""

All the basic classes for the ``brainstate``.

The basic classes include:

- ``Module``: The base class for all the objects in the ecosystem.
- ``Sequential``: The class for a sequential of modules, which update the modules sequentially.
- ``ModuleGroup``: The class for a group of modules, which update ``Projection`` first,
  then ``Dynamics``, finally others.

and:

- ``visible_module_list``: A list to represent a sequence of :py:class:`~.Module`
  that can be visible by the ``.nodes()`` extractor.
- ``visible_module_dict``: A dict to represent a dictionary of :py:class:`~.Module`
  that can be visible by the ``.nodes()`` extractor.

For handling dynamical systems:

- ``Projection``: The class for the synaptic projection.
- ``Dynamics``: The class for the dynamical system.

For handling the delays:

- ``Delay``: The class for all delays.
- ``DelayAccess``: The class for the delay access.

"""

import inspect
import math
import numbers
from collections import namedtuple
from functools import partial
from typing import Sequence, Any, Tuple, Union, Dict, Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np

from . import environ
from ._utils import set_module_as
from ._state import State, StateDictManager, visible_state_dict
from .util import unique_name, DictManager, get_unique_name, DotDict
from .math import get_dtype
from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
from .transform._jit_error import jit_error

Shape = Union[int, Sequence[int]]
PyTree = Any
ArrayLike = jax.typing.ArrayLike

delay_identifier = '_*_delay_of_'
ROTATE_UPDATE = 'rotation'
CONCAT_UPDATE = 'concat'

StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])

# the maximum order
_max_order = 10

__all__ = [
  # basic classes
  'Module', 'visible_module_list', 'visible_module_dict', 'ModuleGroup',

  # dynamical systems
  'Projection', 'Dynamics',

  # delay handling
  'Delay', 'DelayAccess',

  # helper functions
  'call_order',

  # state processing
  'init_states', 'load_states', 'save_states', 'assign_state_values',
]


class Module(object):
  """
  The Module class for the whole ecosystem.

  The ``Module`` is the base class for all the objects in the ecosystem. It
  provides the basic functionalities for the objects, including:

  - ``states()``: Collect all states in this node and the children nodes.
  - ``nodes()``: Collect all children nodes.
  - ``update()``: The function to specify the updating rule.
  - ``init_state()``: State initialization function.
  - ``save_state()``: Save states as a dictionary.
  - ``load_state()``: Load states from the external objects.

  """

  __module__ = 'brainstate'

  # the excluded states
  _invisible_states: Tuple[str, ...] = ()

  # the excluded nodes
  _invisible_nodes: Tuple[str, ...] = ()

  # # the supported computing modes
  # supported_modes: Optional[Sequence[Mode]] = None

  def __init__(self, name: str = None, mode: Mode = None):
    super().__init__()

    # check whether the object has a unique name.
    self._name = unique_name(self=self, name=name)

    # mode setting
    self._mode = None
    self.mode = mode if mode is not None else environ.get('mode')

  def __repr__(self):
    return f'{self.__class__.__name__}'

  @property
  def name(self):
    """Name of the model."""
    return self._name

  @name.setter
  def name(self, name: str = None):
    raise AttributeError('The name of the model is read-only.')

  @property
  def mode(self):
    """Mode of the model, which is useful to control the multiple behaviors of the model."""
    return self._mode

  @mode.setter
  def mode(self, value):
    if not isinstance(value, Mode):
      raise ValueError(f'Must be instance of {Mode.__name__}, '
                       f'but we got {type(value)}: {value}')
    self._mode = value

  def states(
      self,
      method: str = 'absolute',
      level: int = -1,
      include_self: bool = True,
      unique: bool = True,
  ) -> StateDictManager:
    """
    Collect all states in this node and the children nodes.

    Parameters
    ----------
    method : str
      The method to access the variables.
    level: int
      The hierarchy level to find variables.
    include_self: bool
      Whether include the variables in the self.
    unique: bool
      Whether return the unique variables.

    Returns
    -------
    states : StateDictManager
      The collection contained (the path, the variable).
    """

    # find the nodes
    nodes = self.nodes(method=method, level=level, include_self=include_self)

    # get the state stack
    states = StateDictManager()
    _state_id = set()
    for node_path, node in nodes.items():
      for k in node.__dict__.keys():
        if k in node._invisible_states:
          continue
        v = getattr(node, k)
        if isinstance(v, State):
          if unique and id(v) in _state_id:
            continue
          _state_id.add(id(v))
          states[f'{node_path}.{k}' if node_path else k] = v
        elif isinstance(v, visible_state_dict):
          for k2, v2 in v.items():
            if unique and id(v2) in _state_id:
              continue
            _state_id.add(id(v2))
            states[f'{node_path}.{k}.{k2}'] = v2

    return states

  def nodes(
      self,
      method: str = 'absolute',
      level: int = -1,
      include_self: bool = True,
      unique: bool = True,
  ) -> DictManager:
    """
    Collect all children nodes.

    Parameters
    ----------
    method : str
      The method to access the nodes.
    level: int
      The hierarchy level to find nodes.
    include_self: bool
      Whether include the self.
    unique: bool
      Whether return the unique variables.

    Returns
    -------
    gather : DictManager
      The collection contained (the path, the node).
    """
    nodes = _find_nodes(self, method=method, level=level, include_self=include_self)
    if unique:
      nodes = nodes.unique()
    return nodes

  def update(self, *args, **kwargs):
    """
    The function to specify the updating rule.
    """
    raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
                              f'implement "update" function.')

  def __call__(self, *args, **kwargs):
    return self.update(*args, **kwargs)

  def __rrshift__(self, other):
    """
    Support using right shift operator to call modules.

    Examples
    --------

    >>> import brainstate as bst
    >>> import brainscale as nn # noqa
    >>> x = bst.random.rand((10, 10))
    >>> l = nn.Activation(jax.numpy.tanh)
    >>> y = x >> l
    """
    return self.__call__(other)

  def init_state(self, *args, **kwargs):
    """
    State initialization function.
    """
    pass

  def save_state(self, **kwargs) -> Dict:
    """Save states as a dictionary. """
    return self.states(include_self=True, level=0, method='absolute')

  def load_state(self, state_dict: Dict, **kwargs) -> Optional[Tuple[Sequence[str], Sequence[str]]]:
    """Load states from the external objects."""
    variables = self.states(include_self=True, level=0, method='absolute')
    keys1 = set(state_dict.keys())
    keys2 = set(variables.keys())
    for key in keys2.intersection(keys1):
      variables[key].value = jax.numpy.asarray(state_dict[key])
    unexpected_keys = list(keys1 - keys2)
    missing_keys = list(keys2 - keys1)
    return unexpected_keys, missing_keys


def _find_nodes(self, method: str = 'absolute', level=-1, include_self=True, _lid=0, _edges=None) -> DictManager:
  if _edges is None:
    _edges = set()
  gather = DictManager()
  if include_self:
    if method == 'absolute':
      gather[self.name] = self
    elif method == 'relative':
      gather[''] = self
    else:
      raise ValueError(f'No support for the method of "{method}".')
  if (level > -1) and (_lid >= level):
    return gather
  if method == 'absolute':
    nodes = []
    for k, v in self.__dict__.items():
      if k in self._invisible_nodes:
        continue
      if isinstance(v, Module):
        _add_node_absolute(self, v, _edges, gather, nodes)
      elif isinstance(v, visible_module_list):
        for v2 in v:
          _add_node_absolute(self, v2, _edges, gather, nodes)
      elif isinstance(v, visible_module_dict):
        for v2 in v.values():
          if isinstance(v2, Module):
            _add_node_absolute(self, v2, _edges, gather, nodes)

    # finding nodes recursively
    for v in nodes:
      gather.update(_find_nodes(v,
                                method=method,
                                level=level,
                                _lid=_lid + 1,
                                _edges=_edges,
                                include_self=include_self))

  elif method == 'relative':
    nodes = []
    for k, v in self.__dict__.items():
      if v in self._invisible_nodes:
        continue
      if isinstance(v, Module):
        _add_node_relative(self, k, v, _edges, gather, nodes)
      elif isinstance(v, visible_module_list):
        for i, v2 in enumerate(v):
          _add_node_relative(self, f'{k}-list:{i}', v2, _edges, gather, nodes)
      elif isinstance(v, visible_module_dict):
        for k2, v2 in v.items():
          if isinstance(v2, Module):
            _add_node_relative(self, f'{k}-dict:{k2}', v2, _edges, gather, nodes)

    # finding nodes recursively
    for k1, v1 in nodes:
      for k2, v2 in _find_nodes(v1,
                                method=method,
                                _edges=_edges,
                                _lid=_lid + 1,
                                level=level,
                                include_self=include_self).items():
        if k2:
          gather[f'{k1}.{k2}'] = v2

  else:
    raise ValueError(f'No support for the method of "{method}".')
  return gather


def _add_node_absolute(self, v, _paths, gather, nodes):
  path = (id(self), id(v))
  if path not in _paths:
    _paths.add(path)
    gather[v.name] = v
    nodes.append(v)


def _add_node_relative(self, k, v, _paths, gather, nodes):
  path = (id(self), id(v))
  if path not in _paths:
    _paths.add(path)
    gather[k] = v
    nodes.append((k, v))


class Projection(Module):
  """
  Base class to model synaptic projections.
  """

  __module__ = 'brainstate'

  def update(self, *args, **kwargs):
    nodes = tuple(self.nodes(level=1, include_self=False).values())
    if len(nodes):
      for node in nodes:
        node(*args, **kwargs)
    else:
      raise ValueError('Do not implement the update() function.')


class visible_module_list(list):
  """
  A sequence of :py:class:`~.Module`, which is compatible with
  :py:func:`~.vars()` and :py:func:`~.nodes()` operations in a :py:class:`~.Module`.

  That is to say, any nodes that are wrapped into :py:class:`~.NodeList` will be automatically
  retieved when using :py:func:`~.nodes()` function.

  >>> import brainstate as bst
  >>> l = bst.visible_module_list([bp.dnn.Dense(1, 2),
  >>>                              bp.dnn.LSTMCell(2, 3)])
  """

  __module__ = 'brainstate'

  def __init__(self, seq=()):
    super().__init__()
    self.extend(seq)

  def append(self, element) -> 'visible_module_list':
    if isinstance(element, State):
      raise TypeError(f'Cannot append a state into a node list. ')
    super().append(element)
    return self

  def extend(self, iterable) -> 'visible_module_list':
    for element in iterable:
      self.append(element)
    return self


class visible_module_dict(dict):
  """
  A dictionary of :py:class:`~.Module`, which is compatible with
  :py:func:`.vars()` operation in a :py:class:`~.Module`.

  """

  __module__ = 'brainstate'

  def __init__(self, *args, check_unique: bool = False, **kwargs):
    super().__init__()
    self.check_unique = check_unique
    self.update(*args, **kwargs)

  def update(self, *args, **kwargs) -> 'visible_module_dict':
    for arg in args:
      if isinstance(arg, dict):
        for k, v in arg.items():
          self[k] = v
      elif isinstance(arg, tuple):
        assert len(arg) == 2
        self[arg[0]] = args[1]
    for k, v in kwargs.items():
      self[k] = v
    return self

  def __setitem__(self, key, value) -> 'visible_module_dict':
    if self.check_unique:
      exist = self.get(key, None)
      if id(exist) != id(value):
        raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {value}.')
    super().__setitem__(key, value)
    return self


class ReceiveInputProj(Mixin):
  """
  The :py:class:`~.Mixin` that receives the input projections.

  Note that the subclass should define a ``cur_inputs`` attribute. Otherwise,
  the input function utilities cannot be used.

  """
  _current_inputs: Optional[visible_module_dict]
  _delta_inputs: Optional[visible_module_dict]

  @property
  def current_inputs(self):
    """
    The current inputs of the model. It should be a dictionary of the input data.
    """
    return self._current_inputs

  @property
  def delta_inputs(self):
    """
    The delta inputs of the model. It should be a dictionary of the input data.
    """

    return self._delta_inputs

  def add_input_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'):
    """Add an input function.

    Args:
      key: str. The dict key.
      fun: Callable. The function to generate inputs.
      label: str. The input label.
      category: str. The input category, should be ``current`` (the current) or
        ``delta`` (the delta synapse, indicating the delta function).
    """
    if not callable(fun):
      raise TypeError('Must be a function.')

    key = _input_label_repr(key, label)
    if category == 'current':
      if self._current_inputs is None:
        self._current_inputs = visible_module_dict()
      if key in self._current_inputs:
        raise ValueError(f'Key "{key}" has been defined and used.')
      self._current_inputs[key] = fun

    elif category == 'delta':
      if self._delta_inputs is None:
        self._delta_inputs = visible_module_dict()
      if key in self._delta_inputs:
        raise ValueError(f'Key "{key}" has been defined and used.')
      self._delta_inputs[key] = fun

    else:
      raise NotImplementedError(f'Unknown category: {category}. Only support "current" and "delta".')

  def get_input_fun(self, key: str):
    """Get the input function.

    Args:
      key: str. The key.

    Returns:
      The input function which generates currents.
    """
    if self._current_inputs is not None and key in self._current_inputs:
      return self._current_inputs[key]

    elif self._delta_inputs is not None and key in self._delta_inputs:
      return self._delta_inputs[key]

    else:
      raise ValueError(f'Unknown key: {key}')

  def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
    """
    Summarize all current inputs by the defined input functions ``.current_inputs``.

    Args:
      *args: The arguments for input functions.
      init: The initial input data.
      label: str. The input label.
      **kwargs: The arguments for input functions.

    Returns:
      The total currents.
    """
    if self._current_inputs is None:
      return init
    if label is None:
      for key, out in self._current_inputs.items():
        init = init + out(*args, **kwargs)
    else:
      label_repr = _input_label_start(label)
      for key, out in self._current_inputs.items():
        if key.startswith(label_repr):
          init = init + out(*args, **kwargs)
    return init

  def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
    """
    Summarize all delta inputs by the defined input functions ``.delta_inputs``.

    Args:
      *args: The arguments for input functions.
      init: The initial input data.
      label: str. The input label.
      **kwargs: The arguments for input functions.

    Returns:
      The total currents.
    """
    if self._delta_inputs is None:
      return init
    if label is None:
      for key, out in self._delta_inputs.items():
        init = init + out(*args, **kwargs)
    else:
      label_repr = _input_label_start(label)
      for key, out in self._delta_inputs.items():
        if key.startswith(label_repr):
          init = init + out(*args, **kwargs)
    return init


class Container(Mixin):
  """Container :py:class:`~.MixIn` which wrap a group of objects.
  """
  children: visible_module_dict

  def __getitem__(self, item):
    """Overwrite the slice access (`self['']`). """
    if item in self.children:
      return self.children[item]
    else:
      raise ValueError(f'Unknown item {item}, we only found {list(self.children.keys())}')

  def __getattr__(self, item):
    """Overwrite the dot access (`self.`). """
    children = super().__getattribute__('children')
    if item == 'children':
      return children
    else:
      if item in children:
        return children[item]
      else:
        return super().__getattribute__(item)

  def __repr__(self):
    cls_name = self.__class__.__name__
    indent = ' ' * len(cls_name)
    child_str = [_repr_context(repr(val), indent) for val in self.children.values()]
    string = ", \n".join(child_str)
    return f'{cls_name}({string})'

  def __get_elem_name(self, elem):
    if isinstance(elem, Module):
      return elem.name
    else:
      return get_unique_name('ContainerElem')

  def format_elements(self, child_type: type, *children_as_tuple, **children_as_dict):
    res = dict()

    # add tuple-typed components
    for module in children_as_tuple:
      if isinstance(module, child_type):
        res[self.__get_elem_name(module)] = module
      elif isinstance(module, (list, tuple)):
        for m in module:
          if not isinstance(m, child_type):
            raise TypeError(f'Should be instance of {child_type.__name__}. '
                            f'But we got {type(m)}')
          res[self.__get_elem_name(m)] = m
      elif isinstance(module, dict):
        for k, v in module.items():
          if not isinstance(v, child_type):
            raise TypeError(f'Should be instance of {child_type.__name__}. '
                            f'But we got {type(v)}')
          res[k] = v
      else:
        raise TypeError(f'Cannot parse sub-systems. They should be {child_type.__name__} '
                        f'or a list/tuple/dict of {child_type.__name__}.')
    # add dict-typed components
    for k, v in children_as_dict.items():
      if not isinstance(v, child_type):
        raise TypeError(f'Should be instance of {child_type.__name__}. '
                        f'But we got {type(v)}')
      res[k] = v
    return res

  def add_elem(self, *elems, **elements):
    """
    Add new elements.

    >>> obj = Container()
    >>> obj.add_elem(a=1.)

    Args:
      elements: children objects.
    """
    self.children.update(self.format_elements(object, *elems, **elements))


class ExtendedUpdateWithBA(Module):
  """
  The extended update with before and after updates.
  """

  _before_updates: Optional[visible_module_dict]
  _after_updates: Optional[visible_module_dict]

  def __init__(self, *args, **kwargs):

    # -- Attribute for "BeforeAfterMixIn" -- #
    # the before- / after-updates used for computing
    self._before_updates: Optional[Dict[str, Callable]] = None
    self._after_updates: Optional[Dict[str, Callable]] = None

    super().__init__(*args, **kwargs)

  @property
  def before_updates(self):
    """
    The before updates of the model. It should be a dictionary of the updating functions.
    """
    return self._before_updates

  @property
  def after_updates(self):
    """
    The after updates of the model. It should be a dictionary of the updating functions.
    """
    return self._after_updates

  def add_before_update(self, key: Any, fun: Callable):
    """
    Add the before update into this node.
    """
    if self._before_updates is None:
      self._before_updates = visible_module_dict()
    if key in self.before_updates:
      raise KeyError(f'{key} has been registered in before_updates of {self}')
    self.before_updates[key] = fun

  def add_after_update(self, key: Any, fun: Callable):
    """Add the after update into this node"""
    if self._after_updates is None:
      self._after_updates = visible_module_dict()
    if key in self.after_updates:
      raise KeyError(f'{key} has been registered in after_updates of {self}')
    self.after_updates[key] = fun

  def get_before_update(self, key: Any):
    """Get the before update of this node by the given ``key``."""
    if self._before_updates is None:
      raise KeyError(f'{key} is not registered in before_updates of {self}')
    if key not in self.before_updates:
      raise KeyError(f'{key} is not registered in before_updates of {self}')
    return self.before_updates.get(key)

  def get_after_update(self, key: Any):
    """Get the after update of this node by the given ``key``."""
    if self._after_updates is None:
      raise KeyError(f'{key} is not registered in after_updates of {self}')
    if key not in self.after_updates:
      raise KeyError(f'{key} is not registered in after_updates of {self}')
    return self.after_updates.get(key)

  def has_before_update(self, key: Any):
    """Whether this node has the before update of the given ``key``."""
    if self._before_updates is None:
      return False
    return key in self.before_updates

  def has_after_update(self, key: Any):
    """Whether this node has the after update of the given ``key``."""
    if self._after_updates is None:
      return False
    return key in self.after_updates

  def __call__(self, *args, **kwargs):
    """The shortcut to call ``update`` methods."""

    # ``before_updates``
    if self.before_updates is not None:
      for model in self.before_updates.values():
        if hasattr(model, '_receive_update_input'):
          model(*args, **kwargs)
        else:
          model()

    # update the model self
    ret = self.update(*args, **kwargs)

    # ``after_updates``
    if self.after_updates is not None:
      for model in self.after_updates.values():
        if hasattr(model, '_not_receive_update_output'):
          model()
        else:
          model(ret)
    return ret


class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
  """
  Dynamical System class.

  .. note::
     In general, every instance of :py:class:`~.Module` implemented in
     BrainPy only defines the evolving function at each time step :math:`t`.

     If users want to define the logic of running models across multiple steps,
     we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`,
     :py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`.

     To be compatible with previous APIs, :py:class:`~.Module` inherits
     from the :py:class:`~.DelayRegister`. It's worthy to note that the methods of
     :py:class:`~.DelayRegister` will be removed in the future, including:

     - ``.register_delay()``
     - ``.get_delay_data()``
     - ``.update_local_delays()``
     - ``.reset_local_delays()``


  There are several essential attributes:

  - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of
    neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes
    a 3-dimensional neuron group.
  - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
    `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.

  Args:
    size: The neuron group geometry.
    name: The name of the dynamic system.
    keep_size: Whether keep the geometry information.
    mode: The computing mode.
  """

  __module__ = 'brainstate'

  def __init__(
      self,
      size: Shape,
      keep_size: bool = False,
      name: Optional[str] = None,
      mode: Optional[Mode] = None,
      method: str = 'exp_auto'
  ):
    # size
    if isinstance(size, (list, tuple)):
      if len(size) <= 0:
        raise ValueError(f'size must be int, or a tuple/list of int. '
                         f'But we got {type(size)}')
      if not isinstance(size[0], (int, np.integer)):
        raise ValueError('size must be int, or a tuple/list of int.'
                         f'But we got {type(size)}')
      size = tuple(size)
    elif isinstance(size, (int, np.integer)):
      size = (size,)
    else:
      raise ValueError('size must be int, or a tuple/list of int.'
                       f'But we got {type(size)}')
    self.size = size
    self.keep_size = keep_size

    # number of neurons
    self.num = np.prod(size)

    # integration method
    self.method = method

    # -- Attribute for "InputProjMixIn" -- #
    # each instance of "SupportInputProj" should have
    # "_current_inputs" and "_delta_inputs" attributes
    self._current_inputs: Optional[Dict[str, Callable]] = None
    self._delta_inputs: Optional[Dict[str, Callable]] = None

    # initialize
    super().__init__(name=name, mode=mode)

  @property
  def varshape(self):
    """The shape of variables in the neuron group."""
    return self.size if self.keep_size else (self.num,)

  def __repr__(self):
    return f'{self.name}(mode={self.mode}, size={self.size})'

  def update_return_info(self) -> PyTree:
    raise NotImplementedError(f'Subclass of {self.__class__.__name__}'
                              'must implement "update_return_info" function.')

  def update_return(self) -> PyTree:
    raise NotImplementedError(f'Subclass of {self.__class__.__name__}'
                              'must implement "update_return" function.')

  def register_return_delay(
      self,
      delay_name: str,
      delay_time: ArrayLike = None,
      delay_step: ArrayLike = None,
  ):
    """Register local relay at the given delay time.

    Args:
      delay_name: str. The name of the current delay data.
      delay_time: The delay time. Float.
      delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``.
    """
    if not self.has_after_update(delay_identifier):
      # add a model to receive the return of the target model
      model = Delay(self.update_return_info())
      # register the model
      self.add_after_update(delay_identifier, model)
    delay_cls: Delay = self.get_after_update(delay_identifier)
    delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step)
    return delay_cls

  def get_return_delay_at(self, delay_name):
    """Get the state delay at the given identifier (`name`).

    See also :py:meth:`~.Module.register_state_delay`.

    Args:
      delay_name: The identifier of the delay.

    Returns:
      The delayed data at the given delay position.
    """
    return self.get_after_update(delay_identifier).at(delay_name)


class ModuleGroup(Module, Container):
  """A group of :py:class:`~.Module` in which the updating order does not matter.

  Args:
    children_as_tuple: The children objects.
    children_as_dict: The children objects.
    name: The object name.
    mode: The mode which controls the model computation.
    child_type: The type of the children object. Default is :py:class:`Module`.
  """

  __module__ = 'brainstate'

  def __init__(
      self,
      *children_as_tuple,
      name: Optional[str] = None,
      mode: Optional[Mode] = None,
      child_type: type = Module,
      **children_as_dict
  ):
    super().__init__(name=name, mode=mode)

    # Attribute of "Container"
    self.children = visible_module_dict(self.format_elements(child_type, *children_as_tuple, **children_as_dict))

  def update(self, *args, **kwargs):
    """
    Step function of a network.

    In this update function, the update functions in children systems are
    iteratively called.
    """
    projs, dyns, others = self.nodes(level=1, include_self=False).split(Projection, Dynamics)

    # update nodes of projections
    for node in projs.values():
      node()

    # update nodes of dynamics
    for node in dyns.values():
      node()

    # update nodes with other types, including delays, ...
    for node in others.values():
      node()


def receive_update_output(cls: object):
  """
  The decorator to mark the object (as the after updates) to receive the output of the update function.

  That is, the `aft_update` will receive the return of the update function::

    ret = model.update(*args, **kwargs)
    for fun in model.aft_updates:
      fun(ret)

  """
  # assert isinstance(cls, Module), 'The input class should be instance of Module.'
  if hasattr(cls, '_not_receive_update_output'):
    delattr(cls, '_not_receive_update_output')
  return cls


def not_receive_update_output(cls: object):
  """
  The decorator to mark the object (as the after updates) to not receive the output of the update function.

  That is, the `aft_update` will not receive the return of the update function::

    ret = model.update(*args, **kwargs)
    for fun in model.aft_updates:
      fun()

  """
  # assert isinstance(cls, Module), 'The input class should be instance of Module.'
  cls._not_receive_update_output = True
  return cls


def receive_update_input(cls: object):
  """
  The decorator to mark the object (as the before updates) to receive the input of the update function.

  That is, the `bef_update` will receive the input of the update function::


    for fun in model.bef_updates:
      fun(*args, **kwargs)
    model.update(*args, **kwargs)

  """
  # assert isinstance(cls, Module), 'The input class should be instance of Module.'
  cls._receive_update_input = True
  return cls


def not_receive_update_input(cls: object):
  """
  The decorator to mark the object (as the before updates) to not receive the input of the update function.

  That is, the `bef_update` will not receive the input of the update function::

    for fun in model.bef_updates:
      fun()
    model.update()

  """
  # assert isinstance(cls, Module), 'The input class should be instance of Module.'
  if hasattr(cls, '_receive_update_input'):
    delattr(cls, '_receive_update_input')
  return cls


class Delay(ExtendedUpdateWithBA, DelayedInit):
  """
  Generate Delays for the given :py:class:`~.State` instance.

  The data in this delay variable is arranged as::

    delay = 0             [ data
    delay = 1               data
    delay = 2               data
    ...                     ....
    ...                     ....
    delay = length-1        data
    delay = length          data ]

  Args:
    time: int, float. The delay time.
    init: Any. The delay data. It can be a Python number, like float, int, boolean values.
      It can also be arrays. Or a callable function or instance of ``Connector``.
      Note that ``initial_delay_data`` should be arranged as the following way::

        delay = 1             [ data
        delay = 2               data
        ...                     ....
        ...                     ....
        delay = length-1        data
        delay = length          data ]
    entries: optional, dict. The delay access entries.
    name: str. The delay name.
    method: str. The method used for updating delay. Default None.
    mode: Mode. The computing mode. Default None.
  """

  __module__ = 'brainstate'

  non_hash_params = ('time', 'entries', 'name')
  max_time: float
  max_length: int
  history: Optional[State]

  def __init__(
      self,
      target_info: PyTree,
      time: Optional[Union[int, float]] = None,  # delay time
      init: Optional[Union[ArrayLike, Callable]] = None,  # delay data init
      entries: Optional[Dict] = None,  # delay access entry
      method: Optional[str] = ROTATE_UPDATE,  # delay method
      # others
      name: Optional[str] = None,
      mode: Optional[Mode] = None,
  ):

    # target information
    self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, get_dtype(a)), target_info)

    # delay method
    assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
    self.method = method

    # delay length and time
    self.max_time, delay_length = _get_delay(time, None)
    self.max_length = delay_length + 1

    super().__init__(name=name, mode=mode)

    # delay data
    if init is not None:
      assert isinstance(init, (numbers.Number, jax.Array, Callable))
    self._init = init
    self._history = None

    # other info
    self._registered_entries = dict()

    # other info
    if entries is not None:
      for entry, delay_time in entries.items():
        self.register_entry(entry, delay_time)

  def __repr__(self):
    name = self.__class__.__name__
    return f'{name}(delay_length={self.max_length}, target_info={self.target_info}, method="{self.method}")'

  @property
  def history(self):
    return self._history

  @history.setter
  def history(self, value):
    self._history = value

  def _f_to_init(self, a, batch_size, length):
    shape = list(a.shape)
    if batch_size is not None:
      shape.insert(self.mode.batch_axis, batch_size)
    shape.insert(0, length)
    if isinstance(self._init, (jax.Array, numbers.Number)):
      data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
    elif callable(self._init):
      data = self._init(shape, dtype=a.dtype)
    else:
      assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
      data = jnp.zeros(shape, dtype=a.dtype)
    return data

  def init_state(self, batch_size: int = None, **kwargs):
    if batch_size is not None:
      assert self.mode.has(Batching), 'The mode should have Batching behavior when batch_size is not None.'
    fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
    self.history = State(jax.tree.map(fun, self.target_info))

  def register_entry(
      self,
      entry: str,
      delay_time: Optional[Union[int, float]] = None,
      delay_step: Optional[int] = None,
  ) -> 'Delay':
    """Register an entry to access the data.

    Args:
      entry: str. The entry to access the delay data.
      delay_time: The delay time of the entry (can be a float).
      delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``.

    Returns:
      Return the self.
    """
    if entry in self._registered_entries:
      raise KeyError(f'Entry {entry} has been registered. '
                     f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
                     f'The new delay for the key {entry} is {delay_time}. '
                     f'You can use another key. ')

    if isinstance(delay_time, (np.ndarray, jax.Array)):
      assert delay_time.size == 1 and delay_time.ndim == 0
      delay_time = delay_time.item()

    _, delay_step = _get_delay(delay_time, delay_step)

    # delay variable
    if self.max_length <= delay_step + 1:
      self.max_length = delay_step + 1
      self.max_time = delay_time
    self._registered_entries[entry] = delay_step
    return self

  def at(self, entry: str, *indices) -> ArrayLike:
    """Get the data at the given entry.

    Args:
      entry: str. The entry to access the data.
      *indices: The slicing indices. Not include the slice at the batch dimension.

    Returns:
      The data.
    """
    assert isinstance(entry, str), (f'entry should be a string for describing the '
                                    f'entry of the delay data. But we got {entry}.')
    if entry not in self._registered_entries:
      raise KeyError(f'Does not find delay entry "{entry}".')
    delay_step = self._registered_entries[entry]
    if delay_step is None:
      delay_step = 0
    return self.retrieve(delay_step, *indices)

  def retrieve(self, delay_step, *indices):
    """Retrieve the delay data according to the delay length.

    Parameters
    ----------
    delay_step: int
      The delay length used to retrieve the data.
    """
    assert self.history is not None, 'The delay history is not initialized.'
    assert delay_step is not None, 'The delay step should be given.'

    if environ.get(environ.JIT_ERROR_CHECK, False):
      def _check_delay(delay_len):
        raise ValueError(f'The request delay length should be less than the '
                         f'maximum delay {self.max_length}. But we got {delay_len}')

      jit_error(delay_step >= self.max_length, _check_delay, delay_step)

    # rotation method
    if self.method == ROTATE_UPDATE:
      i = environ.get(environ.I)
      di = i - delay_step
      delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
      delay_idx = jax.lax.stop_gradient(delay_idx)

    elif self.method == CONCAT_UPDATE:
      delay_idx = delay_step

    else:
      raise ValueError(f'Unknown updating method "{self.method}"')

    # the delay index
    if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
      raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
    indices = (delay_idx,) + indices

    # the delay data
    return jax.tree.map(lambda a: a[indices], self.history.value)

  def update(self, current: PyTree) -> None:
    """
    Update delay variable with the new data.
    """
    assert self.history is not None, 'The delay history is not initialized.'

    # update the delay data at the rotation index
    if self.method == ROTATE_UPDATE:
      i = environ.get(environ.I)
      idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
      idx = jax.lax.stop_gradient(idx)
      self.history.value = jax.tree.map(lambda hist, cur: hist.at[idx].set(cur),
                                        self.history.value,
                                        current)
    # update the delay data at the first position
    elif self.method == CONCAT_UPDATE:
      current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
      if self.max_length > 1:
        self.history.value = jax.tree.map(lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
                                          self.history.value,
                                          current)
      else:
        self.history.value = current

    else:
      raise ValueError(f'Unknown updating method "{self.method}"')


class _StateDelay(Delay):
  """
  The state delay class.

  Args:
    target: The target state instance.
    init: The initial delay data.
  """

  __module__ = 'brainstate'
  _invisible_states = ('target',)

  def __init__(
      self,
      target: State,
      time: Optional[Union[int, float]] = None,  # delay time
      init: Optional[Union[ArrayLike, Callable]] = None,  # delay data init
      entries: Optional[Dict] = None,  # delay access entry
      method: Optional[str] = ROTATE_UPDATE,  # delay method
      # others
      name: Optional[str] = None,
      mode: Optional[Mode] = None,
  ):
    super().__init__(target_info=target.value,
                     time=time, init=init, entries=entries,
                     method=method, name=name, mode=mode)
    self.target = target

  def update(self, *args, **kwargs):
    super().update(self.target.value)


class DelayAccess(Module):
  """
  The delay access class.

  Args:
    delay: The delay instance.
    time: The delay time.
    indices: The indices of the delay data.
    delay_entry: The delay entry.
  """

  __module__ = 'brainstate'

  def __init__(
      self,
      delay: Delay,
      time: Union[None, int, float],
      *indices,
      delay_entry: str = None
  ):
    super().__init__(mode=delay.mode)
    self.refs = {'delay': delay}
    assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
    self._delay_entry = delay_entry or self.name
    delay.register_entry(self._delay_entry, time)
    self.indices = indices

  def update(self):
    return self.refs['delay'].at(self._delay_entry, *self.indices)


def register_delay_of_target(target: AllOfTypes[ExtendedUpdateWithBA, UpdateReturn]):
  """Register delay class for the given target.

  Args:
    target: The target class to register delay.

  Returns:
    The delay registered for the given target.
  """
  if not target.has_after_update(delay_identifier):
    assert isinstance(target, AllOfTypes[ExtendedUpdateWithBA, UpdateReturn])
    target.add_after_update(delay_identifier, Delay(target.update_return_info()))
  delay_cls = target.get_after_update(delay_identifier)
  return delay_cls


@set_module_as('brainstate')
def call_order(level: int = 0):
  """The decorator for indicating the resetting level.

  The function takes an optional integer argument level with a default value of 0.

  The lower the level, the earlier the function is called.

  >>> import brainstate as bst
  >>> bst.call_order(0)
  >>> bst.call_order(-1)
  >>> bst.call_order(-2)

  """
  if level < 0:
    level = _max_order + level
  if level < 0 or level >= _max_order:
    raise ValueError(f'"call_order" must be an integer in [0, {_max_order}). but we got {level}.')

  def wrap(fun: Callable):
    fun.call_order = level
    return fun

  return wrap


@set_module_as('brainstate')
def init_states(target: Module, *args, **kwargs) -> Module:
  """
  Reset states of all children nodes in the given target.

  Args:
    target: The target Module.

  Returns:
    The target Module.
  """
  nodes_with_order = []

  # reset node whose `init_state` has no `call_order`
  for node in list(target.nodes().values()):
    if not hasattr(node.init_state, 'call_order'):
      node.init_state(*args, **kwargs)
    else:
      nodes_with_order.append(node)

  # reset the node's states
  for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order):
    node.init_state(*args, **kwargs)

  return target


@set_module_as('brainstate')
def load_states(target: Module, state_dict: Dict, **kwargs):
  """Copy parameters and buffers from :attr:`state_dict` into
  this module and its descendants.

  Args:
    target: Module. The dynamical system to load its states.
    state_dict: dict. A dict containing parameters and persistent buffers.

  Returns:
    ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:

    * **missing_keys** is a list of str containing the missing keys
    * **unexpected_keys** is a list of str containing the unexpected keys
  """
  missing_keys = []
  unexpected_keys = []
  for name, node in target.nodes().items():
    r = node.load_state(state_dict[name], **kwargs)
    if r is not None:
      missing, unexpected = r
      missing_keys.extend([f'{name}.{key}' for key in missing])
      unexpected_keys.extend([f'{name}.{key}' for key in unexpected])
  return StateLoadResult(missing_keys, unexpected_keys)


@set_module_as('brainstate')
def save_states(target: Module, **kwargs) -> Dict:
  """Save all states in the ``target`` as a dictionary for later disk serialization.

  Args:
    target: Module. The node to save its states.

  Returns:
    Dict. The state dict for serialization.
  """
  return {key: node.save_state(**kwargs) for key, node in target.nodes().items()}


@set_module_as('brainstate')
def assign_state_values(target: Module, *state_by_abs_path: Dict):
  """
  Assign state values according to the given state dictionary.

  Parameters
  ----------
  target: Module
    The target module.
  state_by_abs_path: dict
    The state dictionary which is accessed by the "absolute" accessing method.

  """
  all_states = dict()
  for state in state_by_abs_path:
    all_states.update(state)
  variables = target.states(include_self=True, method='absolute')
  keys1 = set(all_states.keys())
  keys2 = set(variables.keys())
  for key in keys2.intersection(keys1):
    variables[key].value = jax.numpy.asarray(all_states[key])
  unexpected_keys = list(keys1 - keys2)
  missing_keys = list(keys2 - keys1)
  return unexpected_keys, missing_keys


def _input_label_start(label: str):
  # unify the input label repr.
  return f'{label} // '


def _input_label_repr(name: str, label: Optional[str] = None):
  # unify the input label repr.
  return name if label is None else (_input_label_start(label) + str(name))


def _repr_context(repr_str, indent):
  splits = repr_str.split('\n')
  splits = [(s if i == 0 else (indent + s)) for i, s in enumerate(splits)]
  return '\n'.join(splits)


def _get_delay(delay_time, delay_step):
  if delay_time is None:
    if delay_step is None:
      return 0., 0
    else:
      assert isinstance(delay_step, int), '"delay_step" should be an integer.'
      if delay_step == 0:
        return 0., 0
      delay_time = delay_step * environ.get_dt()
  else:
    assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
    assert isinstance(delay_time, (int, float))
    delay_step = math.ceil(delay_time / environ.get_dt())
  return delay_time, delay_step
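
The following is a minimal usage sketch, not part of the wheel contents above, illustrating how the pieces defined in _module.py are intended to fit together: a Module subclass defines init_state() and update(), init_states() walks the node tree to initialize State objects, and calling the module dispatches to update(). It assumes that Module, State, and init_states are re-exported from the package root (as suggested by __all__ here and brainstate/__init__.py) and that environ provides a default computing mode.

# Hypothetical usage sketch; the Counter class and its attributes are
# illustrative and not part of the released package.
import jax.numpy as jnp
import brainstate as bst  # assumes Module, State, init_states are exported at the root


class Counter(bst.Module):
  def init_state(self, *args, **kwargs):
    # `init_state` is the hook that `init_states()` calls on every node.
    self.count = bst.State(jnp.zeros(()))

  def update(self, inc):
    # `update` is what `Module.__call__` dispatches to.
    self.count.value = self.count.value + inc
    return self.count.value


counter = Counter()
bst.init_states(counter)  # initializes `count` on this node and any children
print(counter(1.))        # 1.0
print(counter.states())   # collects the registered State objects by path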