brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/_module.py ADDED
@@ -0,0 +1,1466 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ """
20
+
21
+ All the basic classes for the ``brainstate``.
22
+
23
+ The basic classes include:
24
+
25
+ - ``Module``: The base class for all the objects in the ecosystem.
26
+ - ``Sequential``: The class for a sequential of modules, which update the modules sequentially.
27
+ - ``ModuleGroup``: The class for a group of modules, which update ``Projection`` first,
28
+ then ``Dynamics``, finally others.
29
+
30
+ and:
31
+
32
+ - ``visible_module_list``: A list to represent a sequence of :py:class:`~.Module`
33
+ that can be visible by the ``.nodes()`` extractor.
34
+ - ``visible_module_dict``: A dict to represent a dictionary of :py:class:`~.Module`
35
+ that can be visible by the ``.nodes()`` extractor.
36
+
37
+ For handling dynamical systems:
38
+
39
+ - ``Projection``: The class for the synaptic projection.
40
+ - ``Dynamics``: The class for the dynamical system.
41
+
42
+ For handling the delays:
43
+
44
+ - ``Delay``: The class for all delays.
45
+ - ``DelayAccess``: The class for the delay access.
46
+
47
+ """
48
+
49
+ import inspect
50
+ import math
51
+ import numbers
52
+ from collections import namedtuple
53
+ from functools import partial
54
+ from typing import Sequence, Any, Tuple, Union, Dict, Callable, Optional
55
+
56
+ import jax
57
+ import jax.numpy as jnp
58
+ import numpy as np
59
+
60
+ from . import environ
61
+ from ._utils import set_module_as
62
+ from ._state import State, StateDictManager, visible_state_dict
63
+ from .util import unique_name, DictManager, get_unique_name, DotDict
64
+ from .math import get_dtype
65
+ from .mixin import Mixin, Mode, DelayedInit, AllOfTypes, Batching, UpdateReturn
66
+ from .transform._jit_error import jit_error
67
+
68
# ----------------------------------------------------------------------------
# Type aliases shared by the classes of this module.
# ----------------------------------------------------------------------------
Shape = Union[int, Sequence[int]]
PyTree = Any
ArrayLike = jax.typing.ArrayLike

# ----------------------------------------------------------------------------
# Delay-related constants.
# ``delay_identifier`` is the key under which ``Dynamics.register_return_delay``
# stores its ``Delay`` in ``after_updates``. The two update-strategy names are
# presumably consumed by ``Delay`` further down the file — TODO confirm.
# ----------------------------------------------------------------------------
delay_identifier = '_*_delay_of_'
ROTATE_UPDATE = 'rotation'
CONCAT_UPDATE = 'concat'

# Named result for state loading: fields are (missing_keys, unexpected_keys).
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])

# the maximum order (usage not visible in this section)
_max_order = 10

__all__ = [
    # basic classes
    'Module',
    'visible_module_list',
    'visible_module_dict',
    'ModuleGroup',
    # dynamical systems
    'Projection',
    'Dynamics',
    # delay handling
    'Delay',
    'DelayAccess',
    # helper functions
    'call_order',
    # state processing
    'init_states',
    'load_states',
    'save_states',
    'assign_state_values',
]
97
+
98
+
99
class Module(object):
    """
    The Module class for the whole ecosystem.

    The ``Module`` is the base class for all the objects in the ecosystem. It
    provides the basic functionalities for the objects, including:

    - ``states()``: Collect all states in this node and the children nodes.
    - ``nodes()``: Collect all children nodes.
    - ``update()``: The function to specify the updating rule.
    - ``init_state()``: State initialization function.
    - ``save_state()``: Collect the states owned directly by this node.
    - ``load_state()``: Load states from the external objects.

    Args:
        name: Optional name, passed to ``unique_name`` to obtain this node's name.
        mode: The computing mode. When ``None``, ``environ.get('mode')`` is used.
    """

    __module__ = 'brainstate'

    # Attribute names skipped by ``states()`` when scanning a node's ``__dict__``.
    _invisible_states: Tuple[str, ...] = ()

    # Attribute names skipped by ``nodes()`` when scanning a node's ``__dict__``.
    _invisible_nodes: Tuple[str, ...] = ()

    def __init__(self, name: str = None, mode: Mode = None):
        super().__init__()

        # check whether the object has a unique name.
        self._name = unique_name(self=self, name=name)

        # mode setting; fall back to the globally configured mode.
        self._mode = None
        self.mode = mode if mode is not None else environ.get('mode')

    def __repr__(self):
        return f'{self.__class__.__name__}'

    @property
    def name(self):
        """Name of the model (read-only)."""
        return self._name

    @name.setter
    def name(self, name: str = None):
        raise AttributeError('The name of the model is read-only.')

    @property
    def mode(self):
        """Mode of the model, which is useful to control the multiple behaviors of the model."""
        return self._mode

    @mode.setter
    def mode(self, value):
        if not isinstance(value, Mode):
            raise ValueError(f'Must be instance of {Mode.__name__}, '
                             f'but we got {type(value)}: {value}')
        self._mode = value

    def states(
        self,
        method: str = 'absolute',
        level: int = -1,
        include_self: bool = True,
        unique: bool = True,
    ) -> StateDictManager:
        """
        Collect all states in this node and the children nodes.

        Each node found by :py:meth:`nodes` contributes its ``State`` attributes
        (keyed ``"<node_path>.<attr>"``) and the entries of its
        ``visible_state_dict`` attributes (keyed ``"<node_path>.<attr>.<key>"``).
        When ``node_path`` is empty (the root in ``'relative'`` mode), the
        leading ``"<node_path>."`` is omitted.

        Parameters
        ----------
        method : str
          The method to access the variables (``'absolute'`` or ``'relative'``).
        level: int
          The hierarchy level to find variables; ``-1`` means unlimited.
        include_self: bool
          Whether include the variables in the self.
        unique: bool
          Whether return the unique variables (each ``State`` object once).

        Returns
        -------
        states : StateDictManager
          The collection contained (the path, the variable).
        """
        # find the nodes
        nodes = self.nodes(method=method, level=level, include_self=include_self)

        # get the state stack
        states = StateDictManager()
        _state_id = set()
        for node_path, node in nodes.items():
            for k in node.__dict__.keys():
                if k in node._invisible_states:
                    continue
                v = getattr(node, k)
                # Same prefix rule for plain states and state dicts, so the
                # relative root (empty path) never yields a leading '.'.
                prefix = f'{node_path}.{k}' if node_path else k
                if isinstance(v, State):
                    if unique and id(v) in _state_id:
                        continue
                    _state_id.add(id(v))
                    states[prefix] = v
                elif isinstance(v, visible_state_dict):
                    for k2, v2 in v.items():
                        if unique and id(v2) in _state_id:
                            continue
                        _state_id.add(id(v2))
                        states[f'{prefix}.{k2}'] = v2

        return states

    def nodes(
        self,
        method: str = 'absolute',
        level: int = -1,
        include_self: bool = True,
        unique: bool = True,
    ) -> DictManager:
        """
        Collect all children nodes.

        Parameters
        ----------
        method : str
          The method to access the nodes (``'absolute'`` or ``'relative'``).
        level: int
          The hierarchy level to find nodes; ``-1`` means unlimited.
        include_self: bool
          Whether include the self.
        unique: bool
          Whether to deduplicate the result via ``DictManager.unique()``.

        Returns
        -------
        gather : DictManager
          The collection contained (the path, the node).
        """
        nodes = _find_nodes(self, method=method, level=level, include_self=include_self)
        if unique:
            nodes = nodes.unique()
        return nodes

    def update(self, *args, **kwargs):
        """
        The function to specify the updating rule. Must be overridden by subclasses.
        """
        raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
                                  f'implement "update" function.')

    def __call__(self, *args, **kwargs):
        return self.update(*args, **kwargs)

    def __rrshift__(self, other):
        """
        Support using right shift operator to call modules: ``x >> m`` is ``m(x)``.

        Examples
        --------

        >>> import jax
        >>> import brainstate as bst
        >>> import brainscale as nn  # noqa
        >>> x = bst.random.rand((10, 10))
        >>> l = nn.Activation(jax.numpy.tanh)
        >>> y = x >> l
        """
        return self.__call__(other)

    def init_state(self, *args, **kwargs):
        """
        State initialization function. The base implementation does nothing.
        """
        pass

    def save_state(self, **kwargs) -> Dict:
        """Return the states owned directly by this node (``level=0``), keyed by absolute path."""
        return self.states(include_self=True, level=0, method='absolute')

    def load_state(self, state_dict: Dict, **kwargs) -> Optional[Tuple[Sequence[str], Sequence[str]]]:
        """Load values from ``state_dict`` into the states owned directly by this node.

        Only keys present both in ``state_dict`` and in
        ``self.states(level=0)`` are assigned; values are converted with
        ``jax.numpy.asarray``.

        Returns:
            A tuple ``(unexpected_keys, missing_keys)``: keys of ``state_dict``
            that are not states of this node, and state keys of this node absent
            from ``state_dict``. Note this positional order is the reverse of the
            field order of ``StateLoadResult``.
        """
        variables = self.states(include_self=True, level=0, method='absolute')
        keys1 = set(state_dict.keys())
        keys2 = set(variables.keys())
        for key in keys2.intersection(keys1):
            variables[key].value = jax.numpy.asarray(state_dict[key])
        unexpected_keys = list(keys1 - keys2)
        missing_keys = list(keys2 - keys1)
        return unexpected_keys, missing_keys
288
+
289
+
290
def _find_nodes(self, method: str = 'absolute', level=-1, include_self=True, _lid=0, _edges=None) -> DictManager:
    """
    Recursively collect the :py:class:`Module` nodes reachable from ``self``.

    Children are discovered by scanning ``self.__dict__`` for attributes that are
    a :py:class:`Module`, a :py:class:`visible_module_list` or a
    :py:class:`visible_module_dict`. Attribute names listed in
    ``self._invisible_nodes`` are skipped in both modes, and entries of the
    visible containers that are not ``Module`` instances are ignored.

    Args:
        self: The root module.
        method: ``'absolute'`` keys every node by its ``name``; ``'relative'``
            keys nodes by their attribute path from the root (``''`` for the root).
        level: Maximum recursion depth; ``-1`` means unlimited.
        include_self: Whether the root is included in the result.
        _lid: Current depth (internal).
        _edges: Set of visited ``(id(parent), id(child))`` pairs shared across the
            recursion, so each parent/child edge is traversed once (internal).

    Returns:
        A ``DictManager`` mapping keys to nodes.

    Raises:
        ValueError: If ``method`` is neither ``'absolute'`` nor ``'relative'``.
    """
    if _edges is None:
        _edges = set()
    gather = DictManager()
    if include_self:
        if method == 'absolute':
            gather[self.name] = self
        elif method == 'relative':
            gather[''] = self
        else:
            raise ValueError(f'No support for the method of "{method}".')
    if (level > -1) and (_lid >= level):
        return gather

    if method == 'absolute':
        nodes = []
        for k, v in self.__dict__.items():
            if k in self._invisible_nodes:
                continue
            if isinstance(v, Module):
                _add_node_absolute(self, v, _edges, gather, nodes)
            elif isinstance(v, visible_module_list):
                for v2 in v:
                    # Non-module entries have no ``name`` / node structure.
                    if isinstance(v2, Module):
                        _add_node_absolute(self, v2, _edges, gather, nodes)
            elif isinstance(v, visible_module_dict):
                for v2 in v.values():
                    if isinstance(v2, Module):
                        _add_node_absolute(self, v2, _edges, gather, nodes)

        # finding nodes recursively; children are already in ``gather``.
        for v in nodes:
            gather.update(_find_nodes(v,
                                      method=method,
                                      level=level,
                                      _lid=_lid + 1,
                                      _edges=_edges,
                                      include_self=include_self))

    elif method == 'relative':
        nodes = []
        for k, v in self.__dict__.items():
            # Filter by attribute *name*, as in the absolute branch.
            if k in self._invisible_nodes:
                continue
            if isinstance(v, Module):
                _add_node_relative(self, k, v, _edges, gather, nodes)
            elif isinstance(v, visible_module_list):
                for i, v2 in enumerate(v):
                    if isinstance(v2, Module):
                        _add_node_relative(self, f'{k}-list:{i}', v2, _edges, gather, nodes)
            elif isinstance(v, visible_module_dict):
                for k2, v2 in v.items():
                    if isinstance(v2, Module):
                        _add_node_relative(self, f'{k}-dict:{k2}', v2, _edges, gather, nodes)

        # finding nodes recursively; prefix descendant paths with the child's path
        # and skip the child's own '' entry (it is already in ``gather``).
        for k1, v1 in nodes:
            for k2, v2 in _find_nodes(v1,
                                      method=method,
                                      _edges=_edges,
                                      _lid=_lid + 1,
                                      level=level,
                                      include_self=include_self).items():
                if k2:
                    gather[f'{k1}.{k2}'] = v2

    else:
        raise ValueError(f'No support for the method of "{method}".')
    return gather
356
+
357
+
358
+ def _add_node_absolute(self, v, _paths, gather, nodes):
359
+ path = (id(self), id(v))
360
+ if path not in _paths:
361
+ _paths.add(path)
362
+ gather[v.name] = v
363
+ nodes.append(v)
364
+
365
+
366
+ def _add_node_relative(self, k, v, _paths, gather, nodes):
367
+ path = (id(self), id(v))
368
+ if path not in _paths:
369
+ _paths.add(path)
370
+ gather[k] = v
371
+ nodes.append((k, v))
372
+
373
+
374
class Projection(Module):
    """
    Base class to model synaptic projections.

    The default :py:meth:`update` forwards its arguments to every direct child
    node (``nodes(level=1, include_self=False)``). A projection without child
    nodes must override :py:meth:`update`; otherwise calling it raises
    ``ValueError``.
    """

    __module__ = 'brainstate'

    def update(self, *args, **kwargs):
        children = tuple(self.nodes(level=1, include_self=False).values())
        if not children:
            raise ValueError('Do not implement the update() function.')
        for child in children:
            child(*args, **kwargs)
388
+
389
+
390
class visible_module_list(list):
    """
    A list of :py:class:`~.Module` objects that is visible to
    :py:meth:`~.Module.nodes`.

    When a module stores a ``visible_module_list`` as an attribute, the
    :py:class:`~.Module` elements of the list are automatically retrieved by
    :py:meth:`~.Module.nodes` (and therefore their states by
    :py:meth:`~.Module.states`). A plain ``list`` attribute is not scanned.

    :py:class:`~.State` objects cannot be added through :py:meth:`append` or
    :py:meth:`extend`. Unlike ``list``, both methods return the list itself so
    calls can be chained.

    Example::

        self.layers = visible_module_list([layer_a, layer_b])  # any Module instances
    """

    __module__ = 'brainstate'

    def __init__(self, seq=()):
        super().__init__()
        self.extend(seq)

    def append(self, element) -> 'visible_module_list':
        """Append ``element``; raise ``TypeError`` if it is a ``State``."""
        if isinstance(element, State):
            raise TypeError(f'Cannot append a state into a node list. ')
        super().append(element)
        return self

    def extend(self, iterable) -> 'visible_module_list':
        """Append every element of ``iterable`` through :py:meth:`append`."""
        for element in iterable:
            self.append(element)
        return self
419
+
420
+
421
class visible_module_dict(dict):
    """
    A dictionary of :py:class:`~.Module` objects that is visible to
    :py:meth:`~.Module.nodes`.

    Entries may be given as mappings, as single ``(key, value)`` tuples, or as
    keyword arguments, both at construction time and through :py:meth:`update`.

    Args:
        *args: Mappings (``dict`` instances) or 2-tuples ``(key, value)``.
        check_unique: If ``True``, rebinding an existing key to a *different*
            object raises ``KeyError``. Re-assigning the same object is allowed.
        **kwargs: Additional key/value pairs.
    """

    __module__ = 'brainstate'

    def __init__(self, *args, check_unique: bool = False, **kwargs):
        super().__init__()
        # Must be set before ``update`` because ``__setitem__`` reads it.
        self.check_unique = check_unique
        self.update(*args, **kwargs)

    def update(self, *args, **kwargs) -> 'visible_module_dict':
        """Insert entries from mappings, ``(key, value)`` tuples and keywords.

        Every insertion goes through :py:meth:`__setitem__`, so ``check_unique``
        is honored.

        Raises:
            ValueError: If a tuple argument does not have exactly two items.
            TypeError: If a positional argument is neither a ``dict`` nor a tuple.
        """
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v
            elif isinstance(arg, tuple):
                if len(arg) != 2:
                    raise ValueError(f'A tuple entry must be a (key, value) pair, but got {arg}.')
                self[arg[0]] = arg[1]
            else:
                raise TypeError(f'Expected a dict or a (key, value) tuple, but got {type(arg)}.')
        for k, v in kwargs.items():
            self[k] = v
        return self

    def __setitem__(self, key, value) -> 'visible_module_dict':
        # Only an *existing* key bound to a different object is a conflict;
        # inserting a new key is always allowed.
        if self.check_unique and key in self:
            exist = super().__getitem__(key)
            if exist is not value:
                raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {exist}.')
        super().__setitem__(key, value)
        return self
454
+
455
+
456
class ReceiveInputProj(Mixin):
    """
    The :py:class:`~.Mixin` that receives the input projections.

    Input functions are kept in two registries:

    - ``_current_inputs`` (exposed as :py:attr:`current_inputs`) for the
      ``'current'`` category;
    - ``_delta_inputs`` (exposed as :py:attr:`delta_inputs`) for the
      ``'delta'`` category.

    Both default to ``None`` at class level. The first time
    :py:meth:`add_input_fun` registers a function in a category, that
    category's registry is created as a :py:class:`visible_module_dict`.
    Host classes may also initialize the registries in ``__init__``.
    """
    _current_inputs: Optional[visible_module_dict] = None
    _delta_inputs: Optional[visible_module_dict] = None

    @property
    def current_inputs(self):
        """
        The registered current-input functions (a dict keyed by input key), or ``None``.
        """
        return self._current_inputs

    @property
    def delta_inputs(self):
        """
        The registered delta-input functions (a dict keyed by input key), or ``None``.
        """
        return self._delta_inputs

    def add_input_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'):
        """Add an input function.

        The function is stored under ``_input_label_repr(key, label)``.

        Args:
          key: str. The dict key.
          fun: Callable. The function to generate inputs.
          label: str. The input label.
          category: str. The input category, should be ``current`` (the current) or
            ``delta`` (the delta synapse, indicating the delta function).

        Raises:
          TypeError: If ``fun`` is not callable.
          ValueError: If the resulting key is already registered in the category.
          NotImplementedError: If ``category`` is unknown.
        """
        if not callable(fun):
            raise TypeError('Must be a function.')

        key = _input_label_repr(key, label)
        if category == 'current':
            if self._current_inputs is None:
                self._current_inputs = visible_module_dict()
            registry = self._current_inputs
        elif category == 'delta':
            if self._delta_inputs is None:
                self._delta_inputs = visible_module_dict()
            registry = self._delta_inputs
        else:
            raise NotImplementedError(f'Unknown category: {category}. Only support "current" and "delta".')

        if key in registry:
            raise ValueError(f'Key "{key}" has been defined and used.')
        registry[key] = fun

    def get_input_fun(self, key: str):
        """Get the input function.

        ``key`` is matched verbatim against the stored keys. The current inputs
        are searched first, then the delta inputs. Functions added with a
        ``label`` are stored under ``_input_label_repr(key, label)``.

        Args:
          key: str. The key.

        Returns:
          The input function which generates currents.

        Raises:
          ValueError: If ``key`` is registered in neither category.
        """
        if self._current_inputs is not None and key in self._current_inputs:
            return self._current_inputs[key]

        elif self._delta_inputs is not None and key in self._delta_inputs:
            return self._delta_inputs[key]

        else:
            raise ValueError(f'Unknown key: {key}')

    @staticmethod
    def _accumulate_inputs(inputs, args, kwargs, init, label):
        # Shared implementation of ``sum_current_inputs`` / ``sum_delta_inputs``:
        # add ``out(*args, **kwargs)`` of every registered function (optionally
        # only those whose key starts with the label prefix) onto ``init``.
        if inputs is None:
            return init
        if label is None:
            for out in inputs.values():
                init = init + out(*args, **kwargs)
        else:
            label_repr = _input_label_start(label)
            for key, out in inputs.items():
                if key.startswith(label_repr):
                    init = init + out(*args, **kwargs)
        return init

    def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
        """
        Summarize all current inputs by the defined input functions ``.current_inputs``.

        Args:
          *args: The arguments for input functions.
          init: The initial input data.
          label: str. The input label; if given, only keys starting with
            ``_input_label_start(label)`` are summed.
          **kwargs: The arguments for input functions.

        Returns:
          The total currents (``init`` unchanged if no current inputs exist).
        """
        return self._accumulate_inputs(self._current_inputs, args, kwargs, init, label)

    def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
        """
        Summarize all delta inputs by the defined input functions ``.delta_inputs``.

        Args:
          *args: The arguments for input functions.
          init: The initial input data.
          label: str. The input label; if given, only keys starting with
            ``_input_label_start(label)`` are summed.
          **kwargs: The arguments for input functions.

        Returns:
          The total delta inputs (``init`` unchanged if no delta inputs exist).
        """
        return self._accumulate_inputs(self._delta_inputs, args, kwargs, init, label)
580
+
581
+
582
class Container(Mixin):
    """Container :py:class:`~.MixIn` wrapping a group of named child objects.

    Children live in ``self.children``. They are reachable both by subscript
    (``obj['name']``) and by attribute access (``obj.name``) when no regular
    attribute of that name exists.
    """
    children: visible_module_dict

    def __getitem__(self, item):
        """Return the child registered under ``item``; raise ``ValueError`` if absent."""
        if item not in self.children:
            raise ValueError(f'Unknown item {item}, we only found {list(self.children.keys())}')
        return self.children[item]

    def __getattr__(self, item):
        """Fallback attribute lookup that resolves names to children first."""
        children = super().__getattribute__('children')
        if item == 'children':
            return children
        if item in children:
            return children[item]
        return super().__getattribute__(item)

    def __repr__(self):
        name = self.__class__.__name__
        pad = ' ' * len(name)
        body = ", \n".join(_repr_context(repr(child), pad) for child in self.children.values())
        return f'{name}({body})'

    def __get_elem_name(self, elem):
        # Modules keep their own name; anything else gets a generated one.
        return elem.name if isinstance(elem, Module) else get_unique_name('ContainerElem')

    def format_elements(self, child_type: type, *children_as_tuple, **children_as_dict):
        """Normalize positional and keyword children into one ``{name: child}`` dict.

        Positional entries may be a single ``child_type`` instance (named
        automatically), a list/tuple of them (each named automatically), or a
        dict of them (keys kept). Keyword entries keep their keyword as the name.

        Raises:
            TypeError: If a child is not a ``child_type`` instance, or a
                positional entry has an unsupported type.
        """

        def _ensure_type(obj):
            if not isinstance(obj, child_type):
                raise TypeError(f'Should be instance of {child_type.__name__}. '
                                f'But we got {type(obj)}')
            return obj

        formatted = dict()

        # positional components
        for entry in children_as_tuple:
            if isinstance(entry, child_type):
                formatted[self.__get_elem_name(entry)] = entry
            elif isinstance(entry, (list, tuple)):
                for elem in entry:
                    _ensure_type(elem)
                    formatted[self.__get_elem_name(elem)] = elem
            elif isinstance(entry, dict):
                for key, val in entry.items():
                    formatted[key] = _ensure_type(val)
            else:
                raise TypeError(f'Cannot parse sub-systems. They should be {child_type.__name__} '
                                f'or a list/tuple/dict of {child_type.__name__}.')

        # keyword components
        for key, val in children_as_dict.items():
            formatted[key] = _ensure_type(val)
        return formatted

    def add_elem(self, *elems, **elements):
        """
        Add new elements (of any type) to ``self.children``.

        >>> obj = Container()
        >>> obj.add_elem(a=1.)

        Args:
          elements: children objects.
        """
        self.children.update(self.format_elements(object, *elems, **elements))
659
+
660
+
661
class ExtendedUpdateWithBA(Module):
    """
    A :py:class:`Module` whose call runs registered hooks before and after ``update``.

    Calling the instance runs, in order:

    1. every before-update hook: with the call arguments if the hook has a
       ``_receive_update_input`` attribute, otherwise with no arguments;
    2. ``self.update(*args, **kwargs)``;
    3. every after-update hook: with no arguments if it has a
       ``_not_receive_update_output`` attribute, otherwise with the return
       value of ``update``.

    The return value of ``update`` is returned.
    """

    _before_updates: Optional[visible_module_dict]
    _after_updates: Optional[visible_module_dict]

    def __init__(self, *args, **kwargs):
        # Hook registries are created lazily by ``add_before_update`` /
        # ``add_after_update``; they are set before ``super().__init__``.
        self._before_updates: Optional[Dict[str, Callable]] = None
        self._after_updates: Optional[Dict[str, Callable]] = None

        super().__init__(*args, **kwargs)

    @property
    def before_updates(self):
        """The before-update hooks (a dict keyed by the registration key), or ``None``."""
        return self._before_updates

    @property
    def after_updates(self):
        """The after-update hooks (a dict keyed by the registration key), or ``None``."""
        return self._after_updates

    def add_before_update(self, key: Any, fun: Callable):
        """Register ``fun`` as a before-update hook; ``KeyError`` if ``key`` is taken."""
        if self._before_updates is None:
            self._before_updates = visible_module_dict()
        registry = self.before_updates
        if key in registry:
            raise KeyError(f'{key} has been registered in before_updates of {self}')
        registry[key] = fun

    def add_after_update(self, key: Any, fun: Callable):
        """Register ``fun`` as an after-update hook; ``KeyError`` if ``key`` is taken."""
        if self._after_updates is None:
            self._after_updates = visible_module_dict()
        registry = self.after_updates
        if key in registry:
            raise KeyError(f'{key} has been registered in after_updates of {self}')
        registry[key] = fun

    def get_before_update(self, key: Any):
        """Return the before-update hook for ``key``; ``KeyError`` if it is not registered."""
        if self._before_updates is None or key not in self.before_updates:
            raise KeyError(f'{key} is not registered in before_updates of {self}')
        return self.before_updates.get(key)

    def get_after_update(self, key: Any):
        """Return the after-update hook for ``key``; ``KeyError`` if it is not registered."""
        if self._after_updates is None or key not in self.after_updates:
            raise KeyError(f'{key} is not registered in after_updates of {self}')
        return self.after_updates.get(key)

    def has_before_update(self, key: Any):
        """Tell whether a before-update hook is registered under ``key``."""
        return self._before_updates is not None and key in self.before_updates

    def has_after_update(self, key: Any):
        """Tell whether an after-update hook is registered under ``key``."""
        return self._after_updates is not None and key in self.after_updates

    def __call__(self, *args, **kwargs):
        """Run the before hooks, ``update`` and the after hooks; return ``update``'s result."""
        if self.before_updates is not None:
            for hook in self.before_updates.values():
                if hasattr(hook, '_receive_update_input'):
                    hook(*args, **kwargs)
                else:
                    hook()

        ret = self.update(*args, **kwargs)

        if self.after_updates is not None:
            for hook in self.after_updates.values():
                if hasattr(hook, '_not_receive_update_output'):
                    hook()
                else:
                    hook(ret)
        return ret
760
+
761
+
762
class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
    """
    Dynamical System class.

    A ``Dynamics`` models a group of units whose evolution over a single time
    step is defined by ``update``. Running a model across multiple steps is
    left to the caller (for example, with the control-flow utilities of
    ``brainstate.transform``).

    Essential attributes:

    - ``size``: the geometry of the group, normalized to a tuple. For example,
      ``(10,)`` denotes a line of neurons, ``(10, 10)`` a group aligned in a 2D
      space, and ``(10, 15, 4)`` a 3-dimensional group.
    - ``num``: the flattened number of units. For example, ``size=(10,)`` =>
      ``num=10``, ``size=(10, 10)`` => ``num=100``, ``size=(10, 15, 4)`` => ``num=600``.
    - ``varshape``: ``size`` when ``keep_size`` is true, otherwise ``(num,)``.

    Args:
      size: The group geometry: an int, or a non-empty tuple/list of ints.
      keep_size: Whether variables keep the full geometry (see ``varshape``).
      name: The name of the dynamic system.
      mode: The computing mode.
      method: The integration method name, stored as ``self.method``.

    Raises:
      ValueError: If ``size`` is not an int or a non-empty tuple/list of ints.
    """

    __module__ = 'brainstate'

    def __init__(
        self,
        size: Shape,
        keep_size: bool = False,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
        method: str = 'exp_auto'
    ):
        # size
        if isinstance(size, (list, tuple)):
            if len(size) <= 0:
                raise ValueError(f'size must be int, or a tuple/list of int. '
                                 f'But we got {type(size)}')
            # Every dimension must be an integer, not only the first one.
            if not all(isinstance(s, (int, np.integer)) for s in size):
                raise ValueError('size must be int, or a tuple/list of int. '
                                 f'But we got {size}')
            size = tuple(size)
        elif isinstance(size, (int, np.integer)):
            size = (size,)
        else:
            raise ValueError('size must be int, or a tuple/list of int. '
                             f'But we got {type(size)}')
        self.size = size
        self.keep_size = keep_size

        # number of neurons
        self.num = np.prod(size)

        # integration method
        self.method = method

        # -- Attribute for "ReceiveInputProj" -- #
        # registries of the input functions, created lazily by ``add_input_fun``
        self._current_inputs: Optional[Dict[str, Callable]] = None
        self._delta_inputs: Optional[Dict[str, Callable]] = None

        # initialize
        super().__init__(name=name, mode=mode)

    @property
    def varshape(self):
        """The shape of variables in the neuron group."""
        return self.size if self.keep_size else (self.num,)

    def __repr__(self):
        return f'{self.name}(mode={self.mode}, size={self.size})'

    def update_return_info(self) -> PyTree:
        """Describe the return value of ``update``; must be implemented by subclasses."""
        raise NotImplementedError(f'Subclass of {self.__class__.__name__} '
                                  'must implement "update_return_info" function.')

    def update_return(self) -> PyTree:
        """Return the value produced by ``update``; must be implemented by subclasses."""
        raise NotImplementedError(f'Subclass of {self.__class__.__name__} '
                                  'must implement "update_return" function.')

    def register_return_delay(
        self,
        delay_name: str,
        delay_time: ArrayLike = None,
        delay_step: ArrayLike = None,
    ):
        """Register a delay entry on the return value of this model.

        On first use, a ``Delay`` built from ``self.update_return_info()`` is
        registered as an after-update hook under ``delay_identifier``. The entry
        ``delay_name`` is then registered on that delay.

        Args:
          delay_name: str. The name of the current delay data.
          delay_time: The delay time. Float.
          delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive.
            ``delay_step = delay_time / dt``.

        Returns:
          The ``Delay`` instance holding the entry.
        """
        if not self.has_after_update(delay_identifier):
            # add a model to receive the return of the target model
            model = Delay(self.update_return_info())
            # register the model
            self.add_after_update(delay_identifier, model)
        delay_cls: Delay = self.get_after_update(delay_identifier)
        delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step)
        return delay_cls

    def get_return_delay_at(self, delay_name):
        """Get the delayed return value registered under ``delay_name``.

        See also :py:meth:`~.Dynamics.register_return_delay`.

        Args:
          delay_name: The identifier of the delay.

        Returns:
          The delayed data at the given delay position.
        """
        return self.get_after_update(delay_identifier).at(delay_name)
891
+
892
+
893
+ class ModuleGroup(Module, Container):
894
+ """A group of :py:class:`~.Module` in which the updating order does not matter.
895
+
896
+ Args:
897
+ children_as_tuple: The children objects.
898
+ children_as_dict: The children objects.
899
+ name: The object name.
900
+ mode: The mode which controls the model computation.
901
+ child_type: The type of the children object. Default is :py:class:`Module`.
902
+ """
903
+
904
+ __module__ = 'brainstate'
905
+
906
+ def __init__(
907
+ self,
908
+ *children_as_tuple,
909
+ name: Optional[str] = None,
910
+ mode: Optional[Mode] = None,
911
+ child_type: type = Module,
912
+ **children_as_dict
913
+ ):
914
+ super().__init__(name=name, mode=mode)
915
+
916
+ # Attribute of "Container"
917
+ self.children = visible_module_dict(self.format_elements(child_type, *children_as_tuple, **children_as_dict))
918
+
919
+ def update(self, *args, **kwargs):
920
+ """
921
+ Step function of a network.
922
+
923
+ In this update function, the update functions in children systems are
924
+ iteratively called.
925
+ """
926
+ projs, dyns, others = self.nodes(level=1, include_self=False).split(Projection, Dynamics)
927
+
928
+ # update nodes of projections
929
+ for node in projs.values():
930
+ node()
931
+
932
+ # update nodes of dynamics
933
+ for node in dyns.values():
934
+ node()
935
+
936
+ # update nodes with other types, including delays, ...
937
+ for node in others.values():
938
+ node()
939
+
940
+
941
+ def receive_update_output(cls: object):
942
+ """
943
+ The decorator to mark the object (as the after updates) to receive the output of the update function.
944
+
945
+ That is, the `aft_update` will receive the return of the update function::
946
+
947
+ ret = model.update(*args, **kwargs)
948
+ for fun in model.aft_updates:
949
+ fun(ret)
950
+
951
+ """
952
+ # assert isinstance(cls, Module), 'The input class should be instance of Module.'
953
+ if hasattr(cls, '_not_receive_update_output'):
954
+ delattr(cls, '_not_receive_update_output')
955
+ return cls
956
+
957
+
958
+ def not_receive_update_output(cls: object):
959
+ """
960
+ The decorator to mark the object (as the after updates) to not receive the output of the update function.
961
+
962
+ That is, the `aft_update` will not receive the return of the update function::
963
+
964
+ ret = model.update(*args, **kwargs)
965
+ for fun in model.aft_updates:
966
+ fun()
967
+
968
+ """
969
+ # assert isinstance(cls, Module), 'The input class should be instance of Module.'
970
+ cls._not_receive_update_output = True
971
+ return cls
972
+
973
+
974
+ def receive_update_input(cls: object):
975
+ """
976
+ The decorator to mark the object (as the before updates) to receive the input of the update function.
977
+
978
+ That is, the `bef_update` will receive the input of the update function::
979
+
980
+
981
+ for fun in model.bef_updates:
982
+ fun(*args, **kwargs)
983
+ model.update(*args, **kwargs)
984
+
985
+ """
986
+ # assert isinstance(cls, Module), 'The input class should be instance of Module.'
987
+ cls._receive_update_input = True
988
+ return cls
989
+
990
+
991
+ def not_receive_update_input(cls: object):
992
+ """
993
+ The decorator to mark the object (as the before updates) to not receive the input of the update function.
994
+
995
+ That is, the `bef_update` will not receive the input of the update function::
996
+
997
+ for fun in model.bef_updates:
998
+ fun()
999
+ model.update()
1000
+
1001
+ """
1002
+ # assert isinstance(cls, Module), 'The input class should be instance of Module.'
1003
+ if hasattr(cls, '_receive_update_input'):
1004
+ delattr(cls, '_receive_update_input')
1005
+ return cls
1006
+
1007
+
1008
+ class Delay(ExtendedUpdateWithBA, DelayedInit):
1009
+ """
1010
+ Generate Delays for the given :py:class:`~.State` instance.
1011
+
1012
+ The data in this delay variable is arranged as::
1013
+
1014
+ delay = 0 [ data
1015
+ delay = 1 data
1016
+ delay = 2 data
1017
+ ... ....
1018
+ ... ....
1019
+ delay = length-1 data
1020
+ delay = length data ]
1021
+
1022
+ Args:
1023
+ time: int, float. The delay time.
1024
+ init: Any. The delay data. It can be a Python number, like float, int, boolean values.
1025
+ It can also be arrays. Or a callable function or instance of ``Connector``.
1026
+ Note that ``initial_delay_data`` should be arranged as the following way::
1027
+
1028
+ delay = 1 [ data
1029
+ delay = 2 data
1030
+ ... ....
1031
+ ... ....
1032
+ delay = length-1 data
1033
+ delay = length data ]
1034
+ entries: optional, dict. The delay access entries.
1035
+ name: str. The delay name.
1036
+ method: str. The method used for updating delay. Default None.
1037
+ mode: Mode. The computing mode. Default None.
1038
+ """
1039
+
1040
+ __module__ = 'brainstate'
1041
+
1042
+ non_hash_params = ('time', 'entries', 'name')
1043
+ max_time: float
1044
+ max_length: int
1045
+ history: Optional[State]
1046
+
1047
+ def __init__(
1048
+ self,
1049
+ target_info: PyTree,
1050
+ time: Optional[Union[int, float]] = None, # delay time
1051
+ init: Optional[Union[ArrayLike, Callable]] = None, # delay data init
1052
+ entries: Optional[Dict] = None, # delay access entry
1053
+ method: Optional[str] = ROTATE_UPDATE, # delay method
1054
+ # others
1055
+ name: Optional[str] = None,
1056
+ mode: Optional[Mode] = None,
1057
+ ):
1058
+
1059
+ # target information
1060
+ self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, get_dtype(a)), target_info)
1061
+
1062
+ # delay method
1063
+ assert method in [ROTATE_UPDATE, CONCAT_UPDATE]
1064
+ self.method = method
1065
+
1066
+ # delay length and time
1067
+ self.max_time, delay_length = _get_delay(time, None)
1068
+ self.max_length = delay_length + 1
1069
+
1070
+ super().__init__(name=name, mode=mode)
1071
+
1072
+ # delay data
1073
+ if init is not None:
1074
+ assert isinstance(init, (numbers.Number, jax.Array, Callable))
1075
+ self._init = init
1076
+ self._history = None
1077
+
1078
+ # other info
1079
+ self._registered_entries = dict()
1080
+
1081
+ # other info
1082
+ if entries is not None:
1083
+ for entry, delay_time in entries.items():
1084
+ self.register_entry(entry, delay_time)
1085
+
1086
+ def __repr__(self):
1087
+ name = self.__class__.__name__
1088
+ return f'{name}(delay_length={self.max_length}, target_info={self.target_info}, method="{self.method}")'
1089
+
1090
+ @property
1091
+ def history(self):
1092
+ return self._history
1093
+
1094
+ @history.setter
1095
+ def history(self, value):
1096
+ self._history = value
1097
+
1098
+ def _f_to_init(self, a, batch_size, length):
1099
+ shape = list(a.shape)
1100
+ if batch_size is not None:
1101
+ shape.insert(self.mode.batch_axis, batch_size)
1102
+ shape.insert(0, length)
1103
+ if isinstance(self._init, (jax.Array, numbers.Number)):
1104
+ data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
1105
+ elif callable(self._init):
1106
+ data = self._init(shape, dtype=a.dtype)
1107
+ else:
1108
+ assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
1109
+ data = jnp.zeros(shape, dtype=a.dtype)
1110
+ return data
1111
+
1112
+ def init_state(self, batch_size: int = None, **kwargs):
1113
+ if batch_size is not None:
1114
+ assert self.mode.has(Batching), 'The mode should have Batching behavior when batch_size is not None.'
1115
+ fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
1116
+ self.history = State(jax.tree.map(fun, self.target_info))
1117
+
1118
+ def register_entry(
1119
+ self,
1120
+ entry: str,
1121
+ delay_time: Optional[Union[int, float]] = None,
1122
+ delay_step: Optional[int] = None,
1123
+ ) -> 'Delay':
1124
+ """Register an entry to access the data.
1125
+
1126
+ Args:
1127
+ entry: str. The entry to access the delay data.
1128
+ delay_time: The delay time of the entry (can be a float).
1129
+ delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``.
1130
+
1131
+ Returns:
1132
+ Return the self.
1133
+ """
1134
+ if entry in self._registered_entries:
1135
+ raise KeyError(f'Entry {entry} has been registered. '
1136
+ f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
1137
+ f'The new delay for the key {entry} is {delay_time}. '
1138
+ f'You can use another key. ')
1139
+
1140
+ if isinstance(delay_time, (np.ndarray, jax.Array)):
1141
+ assert delay_time.size == 1 and delay_time.ndim == 0
1142
+ delay_time = delay_time.item()
1143
+
1144
+ _, delay_step = _get_delay(delay_time, delay_step)
1145
+
1146
+ # delay variable
1147
+ if self.max_length <= delay_step + 1:
1148
+ self.max_length = delay_step + 1
1149
+ self.max_time = delay_time
1150
+ self._registered_entries[entry] = delay_step
1151
+ return self
1152
+
1153
+ def at(self, entry: str, *indices) -> ArrayLike:
1154
+ """Get the data at the given entry.
1155
+
1156
+ Args:
1157
+ entry: str. The entry to access the data.
1158
+ *indices: The slicing indices. Not include the slice at the batch dimension.
1159
+
1160
+ Returns:
1161
+ The data.
1162
+ """
1163
+ assert isinstance(entry, str), (f'entry should be a string for describing the '
1164
+ f'entry of the delay data. But we got {entry}.')
1165
+ if entry not in self._registered_entries:
1166
+ raise KeyError(f'Does not find delay entry "{entry}".')
1167
+ delay_step = self._registered_entries[entry]
1168
+ if delay_step is None:
1169
+ delay_step = 0
1170
+ return self.retrieve(delay_step, *indices)
1171
+
1172
+ def retrieve(self, delay_step, *indices):
1173
+ """Retrieve the delay data according to the delay length.
1174
+
1175
+ Parameters
1176
+ ----------
1177
+ delay_step: int
1178
+ The delay length used to retrieve the data.
1179
+ """
1180
+ assert self.history is not None, 'The delay history is not initialized.'
1181
+ assert delay_step is not None, 'The delay step should be given.'
1182
+
1183
+ if environ.get(environ.JIT_ERROR_CHECK, False):
1184
+ def _check_delay(delay_len):
1185
+ raise ValueError(f'The request delay length should be less than the '
1186
+ f'maximum delay {self.max_length}. But we got {delay_len}')
1187
+
1188
+ jit_error(delay_step >= self.max_length, _check_delay, delay_step)
1189
+
1190
+ # rotation method
1191
+ if self.method == ROTATE_UPDATE:
1192
+ i = environ.get(environ.I)
1193
+ di = i - delay_step
1194
+ delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1195
+ delay_idx = jax.lax.stop_gradient(delay_idx)
1196
+
1197
+ elif self.method == CONCAT_UPDATE:
1198
+ delay_idx = delay_step
1199
+
1200
+ else:
1201
+ raise ValueError(f'Unknown updating method "{self.method}"')
1202
+
1203
+ # the delay index
1204
+ if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
1205
+ raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
1206
+ indices = (delay_idx,) + indices
1207
+
1208
+ # the delay data
1209
+ return jax.tree.map(lambda a: a[indices], self.history.value)
1210
+
1211
+ def update(self, current: PyTree) -> None:
1212
+ """
1213
+ Update delay variable with the new data.
1214
+ """
1215
+ assert self.history is not None, 'The delay history is not initialized.'
1216
+
1217
+ # update the delay data at the rotation index
1218
+ if self.method == ROTATE_UPDATE:
1219
+ i = environ.get(environ.I)
1220
+ idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
1221
+ idx = jax.lax.stop_gradient(idx)
1222
+ self.history.value = jax.tree.map(lambda hist, cur: hist.at[idx].set(cur),
1223
+ self.history.value,
1224
+ current)
1225
+ # update the delay data at the first position
1226
+ elif self.method == CONCAT_UPDATE:
1227
+ current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
1228
+ if self.max_length > 1:
1229
+ self.history.value = jax.tree.map(lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
1230
+ self.history.value,
1231
+ current)
1232
+ else:
1233
+ self.history.value = current
1234
+
1235
+ else:
1236
+ raise ValueError(f'Unknown updating method "{self.method}"')
1237
+
1238
+
1239
+ class _StateDelay(Delay):
1240
+ """
1241
+ The state delay class.
1242
+
1243
+ Args:
1244
+ target: The target state instance.
1245
+ init: The initial delay data.
1246
+ """
1247
+
1248
+ __module__ = 'brainstate'
1249
+ _invisible_states = ('target',)
1250
+
1251
+ def __init__(
1252
+ self,
1253
+ target: State,
1254
+ time: Optional[Union[int, float]] = None, # delay time
1255
+ init: Optional[Union[ArrayLike, Callable]] = None, # delay data init
1256
+ entries: Optional[Dict] = None, # delay access entry
1257
+ method: Optional[str] = ROTATE_UPDATE, # delay method
1258
+ # others
1259
+ name: Optional[str] = None,
1260
+ mode: Optional[Mode] = None,
1261
+ ):
1262
+ super().__init__(target_info=target.value,
1263
+ time=time, init=init, entries=entries,
1264
+ method=method, name=name, mode=mode)
1265
+ self.target = target
1266
+
1267
+ def update(self, *args, **kwargs):
1268
+ super().update(self.target.value)
1269
+
1270
+
1271
+ class DelayAccess(Module):
1272
+ """
1273
+ The delay access class.
1274
+
1275
+ Args:
1276
+ delay: The delay instance.
1277
+ time: The delay time.
1278
+ indices: The indices of the delay data.
1279
+ delay_entry: The delay entry.
1280
+ """
1281
+
1282
+ __module__ = 'brainstate'
1283
+
1284
+ def __init__(
1285
+ self,
1286
+ delay: Delay,
1287
+ time: Union[None, int, float],
1288
+ *indices,
1289
+ delay_entry: str = None
1290
+ ):
1291
+ super().__init__(mode=delay.mode)
1292
+ self.refs = {'delay': delay}
1293
+ assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
1294
+ self._delay_entry = delay_entry or self.name
1295
+ delay.register_entry(self._delay_entry, time)
1296
+ self.indices = indices
1297
+
1298
+ def update(self):
1299
+ return self.refs['delay'].at(self._delay_entry, *self.indices)
1300
+
1301
+
1302
+ def register_delay_of_target(target: AllOfTypes[ExtendedUpdateWithBA, UpdateReturn]):
1303
+ """Register delay class for the given target.
1304
+
1305
+ Args:
1306
+ target: The target class to register delay.
1307
+
1308
+ Returns:
1309
+ The delay registered for the given target.
1310
+ """
1311
+ if not target.has_after_update(delay_identifier):
1312
+ assert isinstance(target, AllOfTypes[ExtendedUpdateWithBA, UpdateReturn])
1313
+ target.add_after_update(delay_identifier, Delay(target.update_return_info()))
1314
+ delay_cls = target.get_after_update(delay_identifier)
1315
+ return delay_cls
1316
+
1317
+
1318
+ @set_module_as('brainstate')
1319
+ def call_order(level: int = 0):
1320
+ """The decorator for indicating the resetting level.
1321
+
1322
+ The function takes an optional integer argument level with a default value of 0.
1323
+
1324
+ The lower the level, the earlier the function is called.
1325
+
1326
+ >>> import brainstate as bst
1327
+ >>> bst.call_order(0)
1328
+ >>> bst.call_order(-1)
1329
+ >>> bst.call_order(-2)
1330
+
1331
+ """
1332
+ if level < 0:
1333
+ level = _max_order + level
1334
+ if level < 0 or level >= _max_order:
1335
+ raise ValueError(f'"call_order" must be an integer in [0, {_max_order}). but we got {level}.')
1336
+
1337
+ def wrap(fun: Callable):
1338
+ fun.call_order = level
1339
+ return fun
1340
+
1341
+ return wrap
1342
+
1343
+
1344
+ @set_module_as('brainstate')
1345
+ def init_states(target: Module, *args, **kwargs) -> Module:
1346
+ """
1347
+ Reset states of all children nodes in the given target.
1348
+
1349
+ Args:
1350
+ target: The target Module.
1351
+
1352
+ Returns:
1353
+ The target Module.
1354
+ """
1355
+ nodes_with_order = []
1356
+
1357
+ # reset node whose `init_state` has no `call_order`
1358
+ for node in list(target.nodes().values()):
1359
+ if not hasattr(node.init_state, 'call_order'):
1360
+ node.init_state(*args, **kwargs)
1361
+ else:
1362
+ nodes_with_order.append(node)
1363
+
1364
+ # reset the node's states
1365
+ for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order):
1366
+ node.init_state(*args, **kwargs)
1367
+
1368
+ return target
1369
+
1370
+
1371
+ @set_module_as('brainstate')
1372
+ def load_states(target: Module, state_dict: Dict, **kwargs):
1373
+ """Copy parameters and buffers from :attr:`state_dict` into
1374
+ this module and its descendants.
1375
+
1376
+ Args:
1377
+ target: Module. The dynamical system to load its states.
1378
+ state_dict: dict. A dict containing parameters and persistent buffers.
1379
+
1380
+ Returns:
1381
+ -------
1382
+ ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
1383
+
1384
+ * **missing_keys** is a list of str containing the missing keys
1385
+ * **unexpected_keys** is a list of str containing the unexpected keys
1386
+ """
1387
+ missing_keys = []
1388
+ unexpected_keys = []
1389
+ for name, node in target.nodes().items():
1390
+ r = node.load_state(state_dict[name], **kwargs)
1391
+ if r is not None:
1392
+ missing, unexpected = r
1393
+ missing_keys.extend([f'{name}.{key}' for key in missing])
1394
+ unexpected_keys.extend([f'{name}.{key}' for key in unexpected])
1395
+ return StateLoadResult(missing_keys, unexpected_keys)
1396
+
1397
+
1398
+ @set_module_as('brainstate')
1399
+ def save_states(target: Module, **kwargs) -> Dict:
1400
+ """Save all states in the ``target`` as a dictionary for later disk serialization.
1401
+
1402
+ Args:
1403
+ target: Module. The node to save its states.
1404
+
1405
+ Returns:
1406
+ Dict. The state dict for serialization.
1407
+ """
1408
+ return {key: node.save_state(**kwargs) for key, node in target.nodes().items()}
1409
+
1410
+
1411
+ @set_module_as('brainstate')
1412
+ def assign_state_values(target: Module, *state_by_abs_path: Dict):
1413
+ """
1414
+ Assign state values according to the given state dictionary.
1415
+
1416
+ Parameters
1417
+ ----------
1418
+ target: Module
1419
+ The target module.
1420
+ state_by_abs_path: dict
1421
+ The state dictionary which is accessed by the "absolute" accessing method.
1422
+
1423
+ """
1424
+ all_states = dict()
1425
+ for state in state_by_abs_path:
1426
+ all_states.update(state)
1427
+ variables = target.states(include_self=True, method='absolute')
1428
+ keys1 = set(all_states.keys())
1429
+ keys2 = set(variables.keys())
1430
+ for key in keys2.intersection(keys1):
1431
+ variables[key].value = jax.numpy.asarray(all_states[key])
1432
+ unexpected_keys = list(keys1 - keys2)
1433
+ missing_keys = list(keys2 - keys1)
1434
+ return unexpected_keys, missing_keys
1435
+
1436
+
1437
+ def _input_label_start(label: str):
1438
+ # unify the input label repr.
1439
+ return f'{label} // '
1440
+
1441
+
1442
+ def _input_label_repr(name: str, label: Optional[str] = None):
1443
+ # unify the input label repr.
1444
+ return name if label is None else (_input_label_start(label) + str(name))
1445
+
1446
+
1447
+ def _repr_context(repr_str, indent):
1448
+ splits = repr_str.split('\n')
1449
+ splits = [(s if i == 0 else (indent + s)) for i, s in enumerate(splits)]
1450
+ return '\n'.join(splits)
1451
+
1452
+
1453
+ def _get_delay(delay_time, delay_step):
1454
+ if delay_time is None:
1455
+ if delay_step is None:
1456
+ return 0., 0
1457
+ else:
1458
+ assert isinstance(delay_step, int), '"delay_step" should be an integer.'
1459
+ if delay_step == 0:
1460
+ return 0., 0
1461
+ delay_time = delay_step * environ.get_dt()
1462
+ else:
1463
+ assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
1464
+ assert isinstance(delay_time, (int, float))
1465
+ delay_step = math.ceil(delay_time / environ.get_dt())
1466
+ return delay_time, delay_step