brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
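The file list documents a broad reorganization of the package: brainstate/transform/ is renamed to brainstate/compile/, a new brainstate/augment/ package appears for autograd and mapping transforms, brainstate/nn/event/ is promoted to a top-level brainstate/event/ package, and the monolithic brainstate/_module.py and brainstate/util.py are removed while brainstate/graph/, brainstate/util/, brainstate/nn/_module.py, and the nn/_dyn_impl, nn/_dynamics, nn/_elementwise, and nn/_interaction subpackages are added. The sketch below is a hypothetical illustration of the new import layout; it only exercises subpackage paths that appear in the list above, since the names each __init__.py re-exports are not visible in this diff:

import brainstate.compile   # formerly brainstate.transform: jit, conditions, loops, make_jaxpr
import brainstate.augment   # new: autograd, mapping, eval_shape, random augmentation
import brainstate.event     # formerly brainstate.nn.event: CSR, fixed-probability, and linear operators
import brainstate.graph     # new: graph node and graph operation utilities
import brainstate.util      # the former single-module brainstate/util.py, now a package

Code importing from the removed paths (for example brainstate.transform or brainstate.nn.event) would presumably need to target the corresponding new subpackage when upgrading to 0.1.0.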
brainstate/_module.py DELETED
@@ -1,1637 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- """
20
-
21
- All the basic classes for the ``brainstate``.
22
-
23
- The basic classes include:
24
-
25
- - ``Module``: The base class for all the objects in the ecosystem.
26
- - ``Sequential``: The class for a sequential of modules, which update the modules sequentially.
27
- - ``ModuleGroup``: The class for a group of modules, which update ``Projection`` first,
28
- then ``Dynamics``, finally others.
29
-
30
- and:
31
-
32
- - ``visible_module_list``: A list to represent a sequence of :py:class:`~.Module`
33
- that can be visible by the ``.nodes()`` extractor.
34
- - ``visible_module_dict``: A dict to represent a dictionary of :py:class:`~.Module`
35
- that can be visible by the ``.nodes()`` extractor.
36
-
37
- For handling dynamical systems:
38
-
39
- - ``Projection``: The class for the synaptic projection.
40
- - ``Dynamics``: The class for the dynamical system.
41
-
42
- For handling the delays:
43
-
44
- - ``Delay``: The class for all delays.
45
- - ``DelayAccess``: The class for the delay access.
46
-
47
- """
48
-
49
- import math
50
- import numbers
51
- from collections import namedtuple
52
- from functools import partial
53
- from typing import Sequence, Any, Tuple, Union, Dict, Callable, Optional
54
-
55
- import jax
56
- import jax.numpy as jnp
57
- import numpy as np
58
-
59
- from brainstate import environ
60
- from brainstate._state import State, StateDictManager, visible_state_dict
61
- from brainstate._utils import set_module_as
62
- from brainstate.mixin import Mixin, Mode, DelayedInit, JointTypes, Batching, UpdateReturn
63
- from brainstate.transform import jit_error_if
64
- from brainstate.typing import Size, ArrayLike, PyTree
65
- from brainstate.util import unique_name, DictManager, get_unique_name
66
-
67
- delay_identifier = '_*_delay_of_'
68
- _DELAY_ROTATE = 'rotation'
69
- _DELAY_CONCAT = 'concat'
70
- _INTERP_LINEAR = 'linear_interp'
71
- _INTERP_ROUND = 'round'
72
-
73
- StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
74
-
75
- # the maximum order
76
- _max_order = 10
77
-
78
- __all__ = [
79
- # basic classes
80
- 'Module', 'visible_module_list', 'visible_module_dict', 'ModuleGroup',
81
-
82
- # dynamical systems
83
- 'Projection', 'Dynamics',
84
-
85
- # delay handling
86
- 'Delay', 'DelayAccess',
87
-
88
- # helper functions
89
- 'call_order',
90
-
91
- # state processing
92
- 'init_states', 'reset_states', 'load_states', 'save_states', 'assign_state_values',
93
- ]
94
-
95
-
96
- class Object:
97
- """
98
- The Module class for the whole ecosystem.
99
-
100
- The ``Module`` is the base class for all the objects in the ecosystem. It
101
- provides the basic functionalities for the objects, including:
102
-
103
- - ``states()``: Collect all states in this node and the children nodes.
104
- - ``nodes()``: Collect all children nodes.
105
- - ``update()``: The function to specify the updating rule.
106
- - ``init_state()``: State initialization function.
107
- - ``save_state()``: Save states as a dictionary.
108
- - ``load_state()``: Load states from the external objects.
109
-
110
- """
111
-
112
- __module__ = 'brainstate'
113
-
114
- # the excluded states
115
- _invisible_states: Tuple[str, ...] = ()
116
-
117
- # the excluded nodes
118
- _invisible_nodes: Tuple[str, ...] = ()
119
-
120
- def __repr__(self):
121
- return f'{self.__class__.__name__}'
122
-
123
- def states(
124
- self,
125
- method: str = 'absolute',
126
- level: int = -1,
127
- include_self: bool = True,
128
- unique: bool = True,
129
- ) -> StateDictManager:
130
- """
131
- Collect all states in this node and the children nodes.
132
-
133
- Parameters
134
- ----------
135
- method : str
136
- The method to access the variables.
137
- level: int
138
- The hierarchy level to find variables.
139
- include_self: bool
140
- Whether to include the variables defined in self.
141
- unique: bool
142
- Whether to return only the unique variables.
143
-
144
- Returns
145
- -------
146
- states : StateDictManager
147
- The collection contained (the path, the variable).
148
- """
149
-
150
- # find the nodes
151
- nodes = self.nodes(method=method, level=level, include_self=include_self)
152
-
153
- # get the state stack
154
- states = StateDictManager()
155
- _state_id = set()
156
- for node_path, node in nodes.items():
157
- for k in node.__dict__.keys():
158
- if k in node._invisible_states:
159
- continue
160
- v = getattr(node, k)
161
- if isinstance(v, State):
162
- if unique and id(v) in _state_id:
163
- continue
164
- _state_id.add(id(v))
165
- states[f'{node_path}.{k}' if node_path else k] = v
166
- elif isinstance(v, visible_state_dict):
167
- for k2, v2 in v.items():
168
- if unique and id(v2) in _state_id:
169
- continue
170
- _state_id.add(id(v2))
171
- states[f'{node_path}.{k}.{k2}'] = v2
172
-
173
- return states
174
-
175
- def nodes(
176
- self,
177
- method: str = 'absolute',
178
- level: int = -1,
179
- include_self: bool = True,
180
- unique: bool = True,
181
- ) -> DictManager:
182
- """
183
- Collect all children nodes.
184
-
185
- Parameters
186
- ----------
187
- method : str
188
- The method to access the nodes.
189
- level: int
190
- The hierarchy level to find nodes.
191
- include_self: bool
192
- Whether to include self.
193
- unique: bool
194
- Whether to return only the unique variables.
195
-
196
- Returns
197
- -------
198
- gather : DictManager
199
- The collection contained (the path, the node).
200
- """
201
- nodes = _find_nodes(self, method=method, level=level, include_self=include_self)
202
- if unique:
203
- nodes = nodes.unique()
204
- return nodes
205
-
206
- def init_state(self, *args, **kwargs):
207
- """
208
- State initialization function.
209
- """
210
- pass
211
-
212
- def reset_state(self, *args, **kwargs):
213
- """
214
- State resetting function.
215
- """
216
- pass
217
-
218
- def save_state(self, **kwargs) -> Dict:
219
- """Save states as a dictionary. """
220
- return self.states(include_self=True, level=0, method='absolute')
221
-
222
- def load_state(self, state_dict: Dict, **kwargs) -> Optional[Tuple[Sequence[str], Sequence[str]]]:
223
- """Load states from the external objects."""
224
- variables = self.states(include_self=True, level=0, method='absolute')
225
- keys1 = set(state_dict.keys())
226
- keys2 = set(variables.keys())
227
- for key in keys2.intersection(keys1):
228
- variables[key].value = jax.numpy.asarray(state_dict[key])
229
- unexpected_keys = list(keys1 - keys2)
230
- missing_keys = list(keys2 - keys1)
231
- return unexpected_keys, missing_keys
232
-
233
- def __treescope_repr__(self, path, subtree_renderer):
234
- import treescope # type: ignore[import-not-found,import-untyped]
235
- children = {}
236
- for name, value in vars(self).items():
237
- if name.startswith('_'):
238
- continue
239
- children[name] = value
240
- return treescope.repr_lib.render_object_constructor(
241
- object_type=type(self),
242
- attributes=children,
243
- path=path,
244
- subtree_renderer=subtree_renderer,
245
- color=treescope.formatting_util.color_from_string(
246
- type(self).__qualname__
247
- )
248
- )
249
-
250
-
251
- def _find_nodes(self, method: str = 'absolute', level=-1, include_self=True, _lid=0, _edges=None) -> DictManager:
252
- if _edges is None:
253
- _edges = set()
254
- gather = DictManager()
255
- if include_self:
256
- if method == 'absolute':
257
- gather[self.name] = self
258
- elif method == 'relative':
259
- gather[''] = self
260
- else:
261
- raise ValueError(f'No support for the method of "{method}".')
262
- if (level > -1) and (_lid >= level):
263
- return gather
264
- if method == 'absolute':
265
- nodes = []
266
- for k, v in self.__dict__.items():
267
- if k in self._invisible_nodes:
268
- continue
269
- if isinstance(v, Module):
270
- _add_node_absolute(self, v, _edges, gather, nodes)
271
- elif isinstance(v, visible_module_list):
272
- for v2 in v:
273
- _add_node_absolute(self, v2, _edges, gather, nodes)
274
- elif isinstance(v, visible_module_dict):
275
- for v2 in v.values():
276
- if isinstance(v2, Module):
277
- _add_node_absolute(self, v2, _edges, gather, nodes)
278
-
279
- # finding nodes recursively
280
- for v in nodes:
281
- gather.update(_find_nodes(v,
282
- method=method,
283
- level=level,
284
- _lid=_lid + 1,
285
- _edges=_edges,
286
- include_self=include_self))
287
-
288
- elif method == 'relative':
289
- nodes = []
290
- for k, v in self.__dict__.items():
291
- if v in self._invisible_nodes:
292
- continue
293
- if isinstance(v, Module):
294
- _add_node_relative(self, k, v, _edges, gather, nodes)
295
- elif isinstance(v, visible_module_list):
296
- for i, v2 in enumerate(v):
297
- _add_node_relative(self, f'{k}-list:{i}', v2, _edges, gather, nodes)
298
- elif isinstance(v, visible_module_dict):
299
- for k2, v2 in v.items():
300
- if isinstance(v2, Module):
301
- _add_node_relative(self, f'{k}-dict:{k2}', v2, _edges, gather, nodes)
302
-
303
- # finding nodes recursively
304
- for k1, v1 in nodes:
305
- for k2, v2 in _find_nodes(v1,
306
- method=method,
307
- _edges=_edges,
308
- _lid=_lid + 1,
309
- level=level,
310
- include_self=include_self).items():
311
- if k2:
312
- gather[f'{k1}.{k2}'] = v2
313
-
314
- else:
315
- raise ValueError(f'No support for the method of "{method}".')
316
- return gather
317
-
318
-
319
- def _add_node_absolute(self, v, _paths, gather, nodes):
320
- path = (id(self), id(v))
321
- if path not in _paths:
322
- _paths.add(path)
323
- gather[v.name] = v
324
- nodes.append(v)
325
-
326
-
327
- def _add_node_relative(self, k, v, _paths, gather, nodes):
328
- path = (id(self), id(v))
329
- if path not in _paths:
330
- _paths.add(path)
331
- gather[k] = v
332
- nodes.append((k, v))
333
-
334
-
335
- class Module(Object):
336
- """
337
- The Module class for the whole ecosystem.
338
-
339
- The ``Module`` is the base class for all the objects in the ecosystem. It
340
- provides the basic functionalities for the objects, including:
341
-
342
- - ``states()``: Collect all states in this node and the children nodes.
343
- - ``nodes()``: Collect all children nodes.
344
- - ``update()``: The function to specify the updating rule.
345
- - ``init_state()``: State initialization function.
346
- - ``save_state()``: Save states as a dictionary.
347
- - ``load_state()``: Load states from the external objects.
348
-
349
- """
350
-
351
- __module__ = 'brainstate'
352
-
353
- def __init__(self, name: str = None, mode: Mode = None):
354
- super().__init__()
355
-
356
- # check whether the object has a unique name.
357
- self._name = unique_name(self=self, name=name)
358
-
359
- # mode setting
360
- self._mode = None
361
- self.mode = mode if mode is not None else environ.get('mode')
362
-
363
- def __repr__(self):
364
- return f'{self.__class__.__name__}'
365
-
366
- @property
367
- def name(self):
368
- """Name of the model."""
369
- return self._name
370
-
371
- @name.setter
372
- def name(self, name: str = None):
373
- raise AttributeError('The name of the model is read-only.')
374
-
375
- @property
376
- def mode(self):
377
- """Mode of the model, which is useful to control the multiple behaviors of the model."""
378
- return self._mode
379
-
380
- @mode.setter
381
- def mode(self, value):
382
- if not isinstance(value, Mode):
383
- raise ValueError(f'Must be instance of {Mode.__name__}, '
384
- f'but we got {type(value)}: {value}')
385
- self._mode = value
386
-
387
- def update(self, *args, **kwargs):
388
- """
389
- The function to specify the updating rule.
390
- """
391
- raise NotImplementedError(f'Subclass of {self.__class__.__name__} must '
392
- f'implement "update" function.')
393
-
394
- def __call__(self, *args, **kwargs):
395
- return self.update(*args, **kwargs)
396
-
397
- def __rrshift__(self, other):
398
- """
399
- Support using right shift operator to call modules.
400
-
401
- Examples
402
- --------
403
-
404
- >>> import brainstate as bst
405
- >>> x = bst.random.rand((10, 10))
406
- >>> l = bst.nn.Dropout(0.5)
407
- >>> y = x >> l
408
- """
409
- return self.__call__(other)
410
-
411
-
412
- class Projection(Module):
413
- """
414
- Base class to model synaptic projections.
415
- """
416
-
417
- __module__ = 'brainstate'
418
-
419
- def update(self, *args, **kwargs):
420
- nodes = tuple(self.nodes(level=1, include_self=False).values())
421
- if len(nodes):
422
- for node in nodes:
423
- node(*args, **kwargs)
424
- else:
425
- raise ValueError('Do not implement the update() function.')
426
-
427
-
428
- class visible_module_list(list):
429
- """
430
- A sequence of :py:class:`~.Module`, which is compatible with
431
- :py:func:`~.vars()` and :py:func:`~.nodes()` operations in a :py:class:`~.Module`.
432
-
433
- That is to say, any nodes that are wrapped into :py:class:`~.NodeList` will be automatically
434
- retieved when using :py:func:`~.nodes()` function.
435
-
436
- >>> import brainstate as bst
437
- >>> l = bst.visible_module_list([bst.nn.Linear(1, 2),
438
- >>> bst.nn.LSTMCell(2, 3)])
439
- """
440
-
441
- __module__ = 'brainstate'
442
-
443
- def __init__(self, seq=()):
444
- super().__init__()
445
- self.extend(seq)
446
-
447
- def append(self, element) -> 'visible_module_list':
448
- if isinstance(element, State):
449
- raise TypeError(f'Cannot append a state into a node list. ')
450
- super().append(element)
451
- return self
452
-
453
- def extend(self, iterable) -> 'visible_module_list':
454
- for element in iterable:
455
- self.append(element)
456
- return self
457
-
458
-
459
- class visible_module_dict(dict):
460
- """
461
- A dictionary of :py:class:`~.Module`, which is compatible with
462
- :py:func:`.vars()` operation in a :py:class:`~.Module`.
463
-
464
- """
465
-
466
- __module__ = 'brainstate'
467
-
468
- def __init__(self, *args, check_unique: bool = False, **kwargs):
469
- super().__init__()
470
- self.check_unique = check_unique
471
- self.update(*args, **kwargs)
472
-
473
- def update(self, *args, **kwargs) -> 'visible_module_dict':
474
- for arg in args:
475
- if isinstance(arg, dict):
476
- for k, v in arg.items():
477
- self[k] = v
478
- elif isinstance(arg, tuple):
479
- assert len(arg) == 2
480
- self[arg[0]] = arg[1]
481
- for k, v in kwargs.items():
482
- self[k] = v
483
- return self
484
-
485
- def __setitem__(self, key, value) -> 'visible_module_dict':
486
- if self.check_unique:
487
- exist = self.get(key, None)
488
- if id(exist) != id(value):
489
- raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {value}.')
490
- super().__setitem__(key, value)
491
- return self
492
-
493
-
494
- class ReceiveInputProj(Mixin):
495
- """
496
- The :py:class:`~.Mixin` that receives the input projections.
497
-
498
- Note that the subclass should define a ``cur_inputs`` attribute. Otherwise,
499
- the input function utilities cannot be used.
500
-
501
- """
502
- _current_inputs: Optional[visible_module_dict]
503
- _delta_inputs: Optional[visible_module_dict]
504
-
505
- @property
506
- def current_inputs(self):
507
- """
508
- The current inputs of the model. It should be a dictionary of the input data.
509
- """
510
- return self._current_inputs
511
-
512
- @property
513
- def delta_inputs(self):
514
- """
515
- The delta inputs of the model. It should be a dictionary of the input data.
516
- """
517
-
518
- return self._delta_inputs
519
-
520
- def add_input_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'):
521
- """Add an input function.
522
-
523
- Args:
524
- key: str. The dict key.
525
- fun: Callable. The function to generate inputs.
526
- label: str. The input label.
527
- category: str. The input category, should be ``current`` (the current) or
528
- ``delta`` (the delta synapse, indicating the delta function).
529
- """
530
- if not callable(fun):
531
- raise TypeError('Must be a function.')
532
-
533
- key = _input_label_repr(key, label)
534
- if category == 'current':
535
- if self._current_inputs is None:
536
- self._current_inputs = visible_module_dict()
537
- if key in self._current_inputs:
538
- raise ValueError(f'Key "{key}" has been defined and used.')
539
- self._current_inputs[key] = fun
540
-
541
- elif category == 'delta':
542
- if self._delta_inputs is None:
543
- self._delta_inputs = visible_module_dict()
544
- if key in self._delta_inputs:
545
- raise ValueError(f'Key "{key}" has been defined and used.')
546
- self._delta_inputs[key] = fun
547
-
548
- else:
549
- raise NotImplementedError(f'Unknown category: {category}. Only support "current" and "delta".')
550
-
551
- def get_input_fun(self, key: str):
552
- """Get the input function.
553
-
554
- Args:
555
- key: str. The key.
556
-
557
- Returns:
558
- The input function which generates currents.
559
- """
560
- if self._current_inputs is not None and key in self._current_inputs:
561
- return self._current_inputs[key]
562
-
563
- elif self._delta_inputs is not None and key in self._delta_inputs:
564
- return self._delta_inputs[key]
565
-
566
- else:
567
- raise ValueError(f'Unknown key: {key}')
568
-
569
- def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
570
- """
571
- Summarize all current inputs by the defined input functions ``.current_inputs``.
572
-
573
- Args:
574
- *args: The arguments for input functions.
575
- init: The initial input data.
576
- label: str. The input label.
577
- **kwargs: The arguments for input functions.
578
-
579
- Returns:
580
- The total currents.
581
- """
582
- if self._current_inputs is None:
583
- return init
584
- if label is None:
585
- for key, out in self._current_inputs.items():
586
- init = init + out(*args, **kwargs)
587
- else:
588
- label_repr = _input_label_start(label)
589
- for key, out in self._current_inputs.items():
590
- if key.startswith(label_repr):
591
- init = init + out(*args, **kwargs)
592
- return init
593
-
594
- def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs):
595
- """
596
- Summarize all delta inputs by the defined input functions ``.delta_inputs``.
597
-
598
- Args:
599
- *args: The arguments for input functions.
600
- init: The initial input data.
601
- label: str. The input label.
602
- **kwargs: The arguments for input functions.
603
-
604
- Returns:
605
- The total currents.
606
- """
607
- if self._delta_inputs is None:
608
- return init
609
- if label is None:
610
- for key, out in self._delta_inputs.items():
611
- init = init + out(*args, **kwargs)
612
- else:
613
- label_repr = _input_label_start(label)
614
- for key, out in self._delta_inputs.items():
615
- if key.startswith(label_repr):
616
- init = init + out(*args, **kwargs)
617
- return init
618
-
619
-
620
- class Container(Mixin):
621
- """Container :py:class:`~.MixIn` which wrap a group of objects.
622
- """
623
- children: visible_module_dict
624
-
625
- def __getitem__(self, item):
626
- """Overwrite the slice access (`self['']`). """
627
- if item in self.children:
628
- return self.children[item]
629
- else:
630
- raise ValueError(f'Unknown item {item}, we only found {list(self.children.keys())}')
631
-
632
- def __getattr__(self, item):
633
- """Overwrite the dot access (`self.`). """
634
- children = super().__getattribute__('children')
635
- if item == 'children':
636
- return children
637
- else:
638
- if item in children:
639
- return children[item]
640
- else:
641
- return super().__getattribute__(item)
642
-
643
- def __repr__(self):
644
- cls_name = self.__class__.__name__
645
- indent = ' ' * len(cls_name)
646
- child_str = [_repr_context(repr(val), indent) for val in self.children.values()]
647
- string = ", \n".join(child_str)
648
- return f'{cls_name}({string})'
649
-
650
- def __get_elem_name(self, elem):
651
- if isinstance(elem, Module):
652
- return elem.name
653
- else:
654
- return get_unique_name('ContainerElem')
655
-
656
- def format_elements(self, child_type: type, *children_as_tuple, **children_as_dict):
657
- res = dict()
658
-
659
- # add tuple-typed components
660
- for module in children_as_tuple:
661
- if isinstance(module, child_type):
662
- res[self.__get_elem_name(module)] = module
663
- elif isinstance(module, (list, tuple)):
664
- for m in module:
665
- if not isinstance(m, child_type):
666
- raise TypeError(f'Should be instance of {child_type.__name__}. '
667
- f'But we got {type(m)}')
668
- res[self.__get_elem_name(m)] = m
669
- elif isinstance(module, dict):
670
- for k, v in module.items():
671
- if not isinstance(v, child_type):
672
- raise TypeError(f'Should be instance of {child_type.__name__}. '
673
- f'But we got {type(v)}')
674
- res[k] = v
675
- else:
676
- raise TypeError(f'Cannot parse sub-systems. They should be {child_type.__name__} '
677
- f'or a list/tuple/dict of {child_type.__name__}.')
678
- # add dict-typed components
679
- for k, v in children_as_dict.items():
680
- if not isinstance(v, child_type):
681
- raise TypeError(f'Should be instance of {child_type.__name__}. '
682
- f'But we got {type(v)}')
683
- res[k] = v
684
- return res
685
-
686
- def add_elem(self, *elems, **elements):
687
- """
688
- Add new elements.
689
-
690
- >>> obj = Container()
691
- >>> obj.add_elem(a=1.)
692
-
693
- Args:
694
- elements: children objects.
695
- """
696
- self.children.update(self.format_elements(object, *elems, **elements))
697
-
698
-
699
- class ExtendedUpdateWithBA(Module):
700
- """
701
- The extended update with before and after updates.
702
- """
703
-
704
- _before_updates: Optional[visible_module_dict]
705
- _after_updates: Optional[visible_module_dict]
706
-
707
- def __init__(self, *args, **kwargs):
708
-
709
- # -- Attribute for "BeforeAfterMixIn" -- #
710
- # the before- / after-updates used for computing
711
- self._before_updates: Optional[Dict[str, Callable]] = None
712
- self._after_updates: Optional[Dict[str, Callable]] = None
713
-
714
- super().__init__(*args, **kwargs)
715
-
716
- @property
717
- def before_updates(self):
718
- """
719
- The before updates of the model. It should be a dictionary of the updating functions.
720
- """
721
- return self._before_updates
722
-
723
- @property
724
- def after_updates(self):
725
- """
726
- The after updates of the model. It should be a dictionary of the updating functions.
727
- """
728
- return self._after_updates
729
-
730
- def add_before_update(self, key: Any, fun: Callable):
731
- """
732
- Add the before update into this node.
733
- """
734
- if self._before_updates is None:
735
- self._before_updates = visible_module_dict()
736
- if key in self.before_updates:
737
- raise KeyError(f'{key} has been registered in before_updates of {self}')
738
- self.before_updates[key] = fun
739
-
740
- def add_after_update(self, key: Any, fun: Callable):
741
- """Add the after update into this node"""
742
- if self._after_updates is None:
743
- self._after_updates = visible_module_dict()
744
- if key in self.after_updates:
745
- raise KeyError(f'{key} has been registered in after_updates of {self}')
746
- self.after_updates[key] = fun
747
-
748
- def get_before_update(self, key: Any):
749
- """Get the before update of this node by the given ``key``."""
750
- if self._before_updates is None:
751
- raise KeyError(f'{key} is not registered in before_updates of {self}')
752
- if key not in self.before_updates:
753
- raise KeyError(f'{key} is not registered in before_updates of {self}')
754
- return self.before_updates.get(key)
755
-
756
- def get_after_update(self, key: Any):
757
- """Get the after update of this node by the given ``key``."""
758
- if self._after_updates is None:
759
- raise KeyError(f'{key} is not registered in after_updates of {self}')
760
- if key not in self.after_updates:
761
- raise KeyError(f'{key} is not registered in after_updates of {self}')
762
- return self.after_updates.get(key)
763
-
764
- def has_before_update(self, key: Any):
765
- """Whether this node has the before update of the given ``key``."""
766
- if self._before_updates is None:
767
- return False
768
- return key in self.before_updates
769
-
770
- def has_after_update(self, key: Any):
771
- """Whether this node has the after update of the given ``key``."""
772
- if self._after_updates is None:
773
- return False
774
- return key in self.after_updates
775
-
776
- def __call__(self, *args, **kwargs):
777
- """The shortcut to call ``update`` methods."""
778
-
779
- # ``before_updates``
780
- if self.before_updates is not None:
781
- for model in self.before_updates.values():
782
- if hasattr(model, '_receive_update_input'):
783
- model(*args, **kwargs)
784
- else:
785
- model()
786
-
787
- # update the model self
788
- ret = self.update(*args, **kwargs)
789
-
790
- # ``after_updates``
791
- if self.after_updates is not None:
792
- for model in self.after_updates.values():
793
- if hasattr(model, '_not_receive_update_output'):
794
- model()
795
- else:
796
- model(ret)
797
- return ret
798
-
799
-
800
- class Dynamics(ExtendedUpdateWithBA, ReceiveInputProj, UpdateReturn):
801
- """
802
- Dynamical System class.
803
-
804
- .. note::
805
- In general, every instance of :py:class:`~.Module` implemented in
806
- BrainPy only defines the evolving function at each time step :math:`t`.
807
-
808
- If users want to define the logic of running models across multiple steps,
809
- we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`,
810
- :py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`.
811
-
812
- To be compatible with previous APIs, :py:class:`~.Module` inherits
813
- from the :py:class:`~.DelayRegister`. It's worthy to note that the methods of
814
- :py:class:`~.DelayRegister` will be removed in the future, including:
815
-
816
- - ``.register_delay()``
817
- - ``.get_delay_data()``
818
- - ``.update_local_delays()``
819
- - ``.reset_local_delays()``
820
-
821
-
822
- There are several essential attributes:
823
-
824
- - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of
825
- neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes
826
- a 3-dimensional neuron group.
827
- - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
828
- `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
829
-
830
- Args:
831
- size: The neuron group geometry.
832
- name: The name of the dynamic system.
833
- keep_size: Whether keep the geometry information.
834
- mode: The computing mode.
835
- """
836
-
837
- __module__ = 'brainstate'
838
-
839
- def __init__(
840
- self,
841
- size: Size,
842
- keep_size: bool = False,
843
- name: Optional[str] = None,
844
- mode: Optional[Mode] = None,
845
- ):
846
- # size
847
- if isinstance(size, (list, tuple)):
848
- if len(size) <= 0:
849
- raise ValueError(f'size must be int, or a tuple/list of int. '
850
- f'But we got {type(size)}')
851
- if not isinstance(size[0], (int, np.integer)):
852
- raise ValueError('size must be int, or a tuple/list of int.'
853
- f'But we got {type(size)}')
854
- size = tuple(size)
855
- elif isinstance(size, (int, np.integer)):
856
- size = (size,)
857
- else:
858
- raise ValueError('size must be int, or a tuple/list of int.'
859
- f'But we got {type(size)}')
860
- self.size = size
861
- self.keep_size = keep_size
862
-
863
- # number of neurons
864
- self.num = np.prod(size)
865
-
866
- # -- Attribute for "InputProjMixIn" -- #
867
- # each instance of "SupportInputProj" should have
868
- # "_current_inputs" and "_delta_inputs" attributes
869
- self._current_inputs: Optional[Dict[str, Callable]] = None
870
- self._delta_inputs: Optional[Dict[str, Callable]] = None
871
-
872
- # initialize
873
- super().__init__(name=name, mode=mode)
874
-
875
- @property
876
- def varshape(self):
877
- """The shape of variables in the neuron group."""
878
- return self.size if self.keep_size else (self.num,)
879
-
880
- def __repr__(self):
881
- return f'{self.name}(mode={self.mode}, size={self.size})'
882
-
883
- def update_return_info(self) -> PyTree:
884
- raise NotImplementedError(f'Subclass of {self.__class__.__name__}'
885
- 'must implement "update_return_info" function.')
886
-
887
- def update_return(self) -> PyTree:
888
- raise NotImplementedError(f'Subclass of {self.__class__.__name__}'
889
- 'must implement "update_return" function.')
890
-
891
- def register_return_delay(
892
- self,
893
- delay_name: str,
894
- delay_time: ArrayLike = None,
895
- delay_step: ArrayLike = None,
896
- ):
897
- """Register local relay at the given delay time.
898
-
899
- Args:
900
- delay_name: str. The name of the current delay data.
901
- delay_time: The delay time. Float.
902
- delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``.
903
- """
904
- if not self.has_after_update(delay_identifier):
905
- # add a model to receive the return of the target model
906
- model = Delay(self.update_return_info())
907
- # register the model
908
- self.add_after_update(delay_identifier, model)
909
- delay_cls: Delay = self.get_after_update(delay_identifier)
910
- delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step)
911
- return delay_cls
912
-
913
- def get_return_delay_at(self, delay_name):
914
- """Get the state delay at the given identifier (`name`).
915
-
916
- See also :py:meth:`~.Module.register_state_delay`.
917
-
918
- Args:
919
- delay_name: The identifier of the delay.
920
-
921
- Returns:
922
- The delayed data at the given delay position.
923
- """
924
- return self.get_after_update(delay_identifier).at(delay_name)
925
-
926
-
927
- class ModuleGroup(Module, Container):
928
- """A group of :py:class:`~.Module` in which the updating order does not matter.
929
-
930
- Args:
931
- children_as_tuple: The children objects.
932
- children_as_dict: The children objects.
933
- name: The object name.
934
- mode: The mode which controls the model computation.
935
- child_type: The type of the children object. Default is :py:class:`Module`.
936
- """
937
-
938
- __module__ = 'brainstate'
939
-
940
- def __init__(
941
- self,
942
- *children_as_tuple,
943
- name: Optional[str] = None,
944
- mode: Optional[Mode] = None,
945
- child_type: type = Module,
946
- **children_as_dict
947
- ):
948
- super().__init__(name=name, mode=mode)
949
-
950
- # Attribute of "Container"
951
- self.children = visible_module_dict(self.format_elements(child_type, *children_as_tuple, **children_as_dict))
952
-
953
- def update(self, *args, **kwargs):
954
- """
955
- Step function of a network.
956
-
957
- In this update function, the update functions in children systems are
958
- iteratively called.
959
- """
960
- projs, dyns, others = self.nodes(level=1, include_self=False).split(Projection, Dynamics)
961
-
962
- # update nodes of projections
963
- for node in projs.values():
964
- node()
965
-
966
- # update nodes of dynamics
967
- for node in dyns.values():
968
- node()
969
-
970
- # update nodes with other types, including delays, ...
971
- for node in others.values():
972
- node()
973
-
974
-
975
- def receive_update_output(cls: object):
976
- """
977
- The decorator to mark the object (as the after updates) to receive the output of the update function.
978
-
979
- That is, the `aft_update` will receive the return of the update function::
980
-
981
- ret = model.update(*args, **kwargs)
982
- for fun in model.aft_updates:
983
- fun(ret)
984
-
985
- """
986
- # assert isinstance(cls, Module), 'The input class should be instance of Module.'
987
- if hasattr(cls, '_not_receive_update_output'):
988
- delattr(cls, '_not_receive_update_output')
989
- return cls
990
-
991
-
992
- def not_receive_update_output(cls: object):
993
- """
994
- The decorator to mark the object (as the after updates) to not receive the output of the update function.
995
-
996
- That is, the `aft_update` will not receive the return of the update function::
997
-
998
- ret = model.update(*args, **kwargs)
999
- for fun in model.aft_updates:
1000
- fun()
1001
-
1002
- """
1003
- # assert isinstance(cls, Module), 'The input class should be instance of Module.'
1004
- cls._not_receive_update_output = True
1005
- return cls
1006
-
1007
-
1008
- def receive_update_input(cls: object):
1009
- """
1010
- The decorator to mark the object (as the before updates) to receive the input of the update function.
1011
-
1012
- That is, the `bef_update` will receive the input of the update function::
1013
-
1014
-
1015
- for fun in model.bef_updates:
1016
- fun(*args, **kwargs)
1017
- model.update(*args, **kwargs)
1018
-
1019
- """
1020
- # assert isinstance(cls, Module), 'The input class should be instance of Module.'
1021
- cls._receive_update_input = True
1022
- return cls
1023
-
1024
-
1025
- def not_receive_update_input(cls: object):
1026
- """
1027
- The decorator to mark the object (as the before updates) to not receive the input of the update function.
1028
-
1029
- That is, the `bef_update` will not receive the input of the update function::
1030
-
1031
- for fun in model.bef_updates:
1032
- fun()
1033
- model.update()
1034
-
1035
- """
1036
- # assert isinstance(cls, Module), 'The input class should be instance of Module.'
1037
- if hasattr(cls, '_receive_update_input'):
1038
- delattr(cls, '_receive_update_input')
1039
- return cls
1040
-
1041
-
1042
- class Delay(ExtendedUpdateWithBA, DelayedInit):
1043
- """
1044
- Generate Delays for the given :py:class:`~.State` instance.
1045
-
1046
- The data in this delay variable is arranged as::
1047
-
1048
- delay = 0 [ data
1049
- delay = 1 data
1050
- delay = 2 data
1051
- ... ....
1052
- ... ....
1053
- delay = length-1 data
1054
- delay = length data ]
1055
-
1056
- Args:
1057
- time: int, float. The delay time.
1058
- init: Any. The delay data. It can be a Python number, like float, int, boolean values.
1059
- It can also be arrays. Or a callable function or instance of ``Connector``.
1060
- Note that ``initial_delay_data`` should be arranged as the following way::
1061
-
1062
- delay = 1 [ data
1063
- delay = 2 data
1064
- ... ....
1065
- ... ....
1066
- delay = length-1 data
1067
- delay = length data ]
1068
- entries: optional, dict. The delay access entries.
1069
- name: str. The delay name.
1070
- delay_method: str. The method used for updating delay. Default None.
1071
- mode: Mode. The computing mode. Default None.
1072
- """
1073
-
1074
- __module__ = 'brainstate'
1075
-
1076
- non_hashable_params = ('time', 'entries', 'name')
1077
- max_time: float #
1078
- max_length: int
1079
- history: Optional[State]
1080
-
1081
- def __init__(
1082
- self,
1083
- target_info: PyTree,
1084
- time: Optional[Union[int, float]] = None, # delay time
1085
- init: Optional[Union[ArrayLike, Callable]] = None, # delay data before t0
1086
- entries: Optional[Dict] = None, # delay access entry
1087
- delay_method: Optional[str] = _DELAY_ROTATE, # delay method
1088
- interp_method: str = _INTERP_LINEAR, # interpolation method
1089
- # others
1090
- name: Optional[str] = None,
1091
- mode: Optional[Mode] = None,
1092
- ):
1093
-
1094
- # target information
1095
- self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
1096
-
1097
- # delay method
1098
- assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
1099
- f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
1100
- self.delay_method = delay_method
1101
-
1102
- # interp method
1103
- assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
1104
- f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
1105
- self.interp_method = interp_method
1106
-
1107
- # delay length and time
1108
- self.max_time, delay_length = _get_delay(time, None)
1109
- self.max_length = delay_length + 1
1110
-
1111
- super().__init__(name=name, mode=mode)
1112
-
1113
- # delay data
1114
- if init is not None:
1115
- if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
1116
- raise TypeError(f'init should be Array, Callable, or None. But got {init}')
1117
- self._init = init
1118
- self._history = None
1119
-
1120
- # other info
1121
- self._registered_entries = dict()
1122
-
1123
- # other info
1124
- if entries is not None:
1125
- for entry, delay_time in entries.items():
1126
- self.register_entry(entry, delay_time)
1127
-
1128
- def __repr__(self):
1129
- name = self.__class__.__name__
1130
- return (f'{name}('
1131
- f'delay_length={self.max_length}, '
1132
- f'target_info={self.target_info}, '
1133
- f'delay_method="{self.delay_method}", '
1134
- f'interp_method="{self.interp_method}")')
1135
-
1136
- @property
1137
- def history(self):
1138
- return self._history
1139
-
1140
- @history.setter
1141
- def history(self, value):
1142
- self._history = value
1143
-
1144
- def _f_to_init(self, a, batch_size, length):
1145
- shape = list(a.shape)
1146
- if batch_size is not None:
1147
- shape.insert(self.mode.batch_axis, batch_size)
1148
- shape.insert(0, length)
1149
- if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
1150
- data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
1151
- elif callable(self._init):
1152
- data = self._init(shape, dtype=a.dtype)
1153
- else:
1154
- assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
1155
- data = jnp.zeros(shape, dtype=a.dtype)
1156
- return data
1157
-
1158
- def init_state(self, batch_size: int = None, **kwargs):
1159
- if batch_size is not None:
1160
- assert self.mode.has(Batching), 'The mode should have Batching behavior when batch_size is not None.'
1161
- fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
1162
- self.history = State(jax.tree.map(fun, self.target_info))
1163
-
1164
- def reset_state(self, batch_size: int = None, **kwargs):
1165
- if batch_size is not None:
1166
- assert self.mode.has(Batching), 'The mode should have Batching behavior when batch_size is not None.'
1167
- fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
1168
- self.history.value = jax.tree.map(fun, self.target_info)
1169
-
1170
- def register_entry(
1171
- self,
1172
- entry: str,
1173
- delay_time: Optional[Union[int, float]] = None,
1174
- delay_step: Optional[int] = None,
1175
- ) -> 'Delay':
1176
- """
1177
- Register an entry to access the delay data.
1178
-
1179
- Args:
1180
- entry: str. The entry to access the delay data.
1181
- delay_time: The delay time of the entry (can be a float).
1182
- delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``.
1183
-
1184
- Returns:
1185
- Return the self.
1186
- """
1187
- if entry in self._registered_entries:
1188
- raise KeyError(f'Entry {entry} has been registered. '
1189
- f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
1190
- f'The new delay for the key {entry} is {delay_time}. '
1191
- f'You can use another key. ')
1192
-
1193
- if isinstance(delay_time, (np.ndarray, jax.Array)):
1194
- assert delay_time.size == 1 and delay_time.ndim == 0
1195
- delay_time = delay_time.item()
1196
-
1197
- _, delay_step = _get_delay(delay_time, delay_step)
1198
-
1199
- # delay variable
1200
- if self.max_length <= delay_step + 1:
1201
- self.max_length = delay_step + 1
1202
- self.max_time = delay_time
1203
- self._registered_entries[entry] = delay_step
1204
- return self
1205
-
1206
- def at(self, entry: str, *indices) -> ArrayLike:
1207
- """
1208
- Get the data at the given entry.
1209
-
1210
- Args:
1211
- entry: str. The entry to access the data.
1212
- *indices: The slicing indices. Not include the slice at the batch dimension.
1213
-
1214
- Returns:
1215
- The data.
1216
- """
1217
- assert isinstance(entry, str), (f'entry should be a string for describing the '
1218
- f'entry of the delay data. But we got {entry}.')
1219
- if entry not in self._registered_entries:
1220
- raise KeyError(f'Does not find delay entry "{entry}".')
1221
- delay_step = self._registered_entries[entry]
1222
- if delay_step is None:
1223
- delay_step = 0
1224
- return self.retrieve_at_step(delay_step, *indices)
1225
-
1226
- def retrieve_at_step(self, delay_step, *indices) -> PyTree:
1227
- """
1228
- Retrieve the delay data at the given delay time step (the integer to indicate the time step).
1229
-
1230
- Parameters
1231
- ----------
1232
- delay_step: int_like
1233
- Retrieve the data at the given time step.
1234
- indices: tuple
1235
- The indices to slice the data.
1236
-
1237
- Returns
1238
- -------
1239
- delay_data: The delay data at the given delay step.
1240
-
1241
- """
1242
- assert self.history is not None, 'The delay history is not initialized.'
1243
- assert delay_step is not None, 'The delay step should be given.'
1244
-
1245
- if environ.get(environ.JIT_ERROR_CHECK, False):
1246
- def _check_delay(delay_len):
1247
- raise ValueError(f'The request delay length should be less than the '
1248
- f'maximum delay {self.max_length - 1}. But we got {delay_len}')
1249
-
1250
- jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)
1251
-
1252
- # rotation method
1253
- if self.delay_method == _DELAY_ROTATE:
1254
- i = environ.get(environ.I, desc='The time step index.')
1255
- di = i - delay_step
1256
- delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1257
- delay_idx = jax.lax.stop_gradient(delay_idx)
1258
-
1259
- elif self.delay_method == _DELAY_CONCAT:
1260
- delay_idx = delay_step
1261
-
1262
- else:
1263
- raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1264
-
1265
- # the delay index
1266
- if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
1267
- raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
1268
- indices = (delay_idx,) + indices
1269
-
1270
- # the delay data
1271
- return jax.tree.map(lambda a: a[indices], self.history.value)
1272
-
1273
- def retrieve_at_time(self, delay_time, *indices) -> PyTree:
1274
- """
1275
- Retrieve the delay data at the given delay time step (the integer to indicate the time step).
1276
-
1277
- Parameters
1278
- ----------
1279
- delay_time: float
1280
- Retrieve the data at the given time.
1281
- indices: tuple
1282
- The indices to slice the data.
1283
-
1284
- Returns
1285
- -------
1286
- delay_data: The delay data at the given delay step.
1287
-
1288
- """
1289
- assert self.history is not None, 'The delay history is not initialized.'
1290
- assert delay_time is not None, 'The delay time should be given.'
1291
-
1292
- current_time = environ.get(environ.T, desc='The current time.')
1293
- dt = environ.get_dt()
1294
-
1295
- if environ.get(environ.JIT_ERROR_CHECK, False):
1296
- def _check_delay(t_now, t_delay):
1297
- raise ValueError(f'The request delay time should be within '
1298
- f'[{t_now - self.max_time - dt}, {t_now}], '
1299
- f'but we got {t_delay}')
1300
-
1301
- jit_error_if(jnp.logical_or(delay_time > current_time,
1302
- delay_time < current_time - self.max_time - dt),
1303
- _check_delay,
1304
- current_time, delay_time)
1305
-
1306
- diff = current_time - delay_time
1307
- float_time_step = diff / dt
1308
-
1309
- if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
1310
- # def _interp(target):
1311
- # if len(indices) > 0:
1312
- # raise NotImplementedError('The slicing indices are not supported in the linear interpolation.')
1313
- # if self.delay_method == _DELAY_ROTATE:
1314
- # i = environ.get(environ.I, desc='The time step index.')
1315
- # _interp_fun = partial(jnp.interp, period=self.max_length)
1316
- # for dim in range(1, target.ndim, 1):
1317
- # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1318
- # di = i - jnp.arange(self.max_length)
1319
- # delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
1320
- # return _interp_fun(float_time_step, delay_idx, target)
1321
- #
1322
- # elif self.delay_method == _DELAY_CONCAT:
1323
- # _interp_fun = partial(jnp.interp, period=self.max_length)
1324
- # for dim in range(1, target.ndim, 1):
1325
- # _interp_fun = jax.vmap(_interp_fun, in_axes=(None, None, dim), out_axes=dim - 1)
1326
- # return _interp_fun(float_time_step, jnp.arange(self.max_length), target)
1327
- #
1328
- # else:
1329
- # raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
1330
- # return jax.tree.map(_interp, self.history.value)
1331
-
1332
- data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
1333
- data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
1334
- t_diff = float_time_step - jnp.floor(float_time_step)
1335
- return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
1336
-
1337
- elif self.interp_method == _INTERP_ROUND: # "round" interpolation
1338
- return self.retrieve_at_step(
1339
- jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
1340
- *indices
1341
- )
1342
-
1343
- else: # raise error
1344
- raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
1345
- f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
1346
-
-     def update(self, current: PyTree) -> None:
-         """
-         Update delay variable with the new data.
-         """
-         assert self.history is not None, 'The delay history is not initialized.'
- 
-         # update the delay data at the rotation index
-         if self.delay_method == _DELAY_ROTATE:
-             i = environ.get(environ.I)
-             idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
-             idx = jax.lax.stop_gradient(idx)
-             self.history.value = jax.tree.map(
-                 lambda hist, cur: hist.at[idx].set(cur),
-                 self.history.value,
-                 current
-             )
-         # update the delay data at the first position
-         elif self.delay_method == _DELAY_CONCAT:
-             current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
-             if self.max_length > 1:
-                 self.history.value = jax.tree.map(
-                     lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
-                     self.history.value,
-                     current
-                 )
-             else:
-                 self.history.value = current
- 
-         else:
-             raise ValueError(f'Unknown updating method "{self.delay_method}"')
- 
-
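The two update strategies differ only in how the history buffer is written: rotation overwrites the slot for the current step index, while concatenation pushes the newest value to the front and drops the oldest. A small self-contained comparison on plain arrays (buffer size, step index, and values are hypothetical):

```python
import jax.numpy as jnp

max_length = 4
history = jnp.zeros((max_length, 3))
current = jnp.ones(3)

# "rotate": overwrite the slot belonging to the current step index
i = 6                                                     # hypothetical global step counter
rotated = history.at[i % max_length].set(current)

# "concat": push the newest value to the front and drop the oldest
concatenated = jnp.concatenate([current[None], history[:-1]], axis=0)
```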
- class _StateDelay(Delay):
-     """
-     The state delay class.
- 
-     Args:
-         target: The target state instance.
-         init: The initial delay data.
-     """
- 
-     __module__ = 'brainstate'
-     _invisible_states = ('target',)
- 
-     def __init__(
-         self,
-         target: State,
-         time: Optional[Union[int, float]] = None,  # delay time
-         init: Optional[Union[ArrayLike, Callable]] = None,  # delay data init
-         entries: Optional[Dict] = None,  # delay access entry
-         delay_method: Optional[str] = _DELAY_ROTATE,  # delay method
-         # others
-         name: Optional[str] = None,
-         mode: Optional[Mode] = None,
-     ):
-         super().__init__(target_info=target.value,
-                          time=time,
-                          init=init,
-                          entries=entries,
-                          delay_method=delay_method,
-                          name=name,
-                          mode=mode)
-         self.target = target
- 
-     def update(self, *args, **kwargs):
-         super().update(self.target.value)
- 
- 
- class DelayAccess(Module):
-     """
-     The delay access class.
- 
-     Args:
-         delay: The delay instance.
-         time: The delay time.
-         indices: The indices of the delay data.
-         delay_entry: The delay entry.
-     """
- 
-     __module__ = 'brainstate'
- 
-     def __init__(
-         self,
-         delay: Delay,
-         time: Union[None, int, float],
-         *indices,
-         delay_entry: str = None
-     ):
-         super().__init__(mode=delay.mode)
-         self.refs = {'delay': delay}
-         assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
-         self._delay_entry = delay_entry or self.name
-         delay.register_entry(self._delay_entry, time)
-         self.indices = indices
- 
-     def update(self):
-         return self.refs['delay'].at(self._delay_entry, *self.indices)
- 
- 
- def register_delay_of_target(target: JointTypes[ExtendedUpdateWithBA, UpdateReturn]):
-     """Register delay class for the given target.
- 
-     Args:
-         target: The target class to register delay.
- 
-     Returns:
-         The delay registered for the given target.
-     """
-     if not target.has_after_update(delay_identifier):
-         assert isinstance(target, JointTypes[ExtendedUpdateWithBA, UpdateReturn])
-         target.add_after_update(delay_identifier, Delay(target.update_return_info()))
-     delay_cls = target.get_after_update(delay_identifier)
-     return delay_cls
- 
- 
- @set_module_as('brainstate')
- def call_order(level: int = 0):
-     """The decorator for indicating the resetting level.
- 
-     The function takes an optional integer argument level with a default value of 0.
- 
-     The lower the level, the earlier the function is called.
- 
-     >>> import brainstate as bst
-     >>> bst.call_order(0)
-     >>> bst.call_order(-1)
-     >>> bst.call_order(-2)
- 
-     """
-     if level < 0:
-         level = _max_order + level
-     if level < 0 or level >= _max_order:
-         raise ValueError(f'"call_order" must be an integer in [0, {_max_order}). but we got {level}.')
- 
-     def wrap(fun: Callable):
-         fun.call_order = level
-         return fun
- 
-     return wrap
- 
-
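The decorator itself only attaches a `call_order` attribute; nothing is reordered until `init_states` / `reset_states` below sort by it. A short sketch, assuming an older release where the helper is importable as `bst.call_order` (as its own docstring shows) and taking `_max_order` to be 10 purely for illustration:

```python
import brainstate as bst

@bst.call_order(-1)              # negative levels count back from _max_order, so this runs late
def my_reset():
    print('reset noise sources')

# the decorator only tags the function; with _max_order == 10 this would print 9
print(my_reset.call_order)
```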
- @set_module_as('brainstate')
- def init_states(target: Module, *args, **kwargs) -> Module:
-     """
-     Initialize states of all children nodes in the given target.
- 
-     Args:
-         target: The target Module.
- 
-     Returns:
-         The target Module.
-     """
-     nodes_with_order = []
- 
-     # reset node whose `init_state` has no `call_order`
-     for node in list(target.nodes().values()):
-         if not hasattr(node.init_state, 'call_order'):
-             node.init_state(*args, **kwargs)
-         else:
-             nodes_with_order.append(node)
- 
-     # reset the node's states
-     for node in sorted(nodes_with_order, key=lambda x: x.init_state.call_order):
-         node.init_state(*args, **kwargs)
- 
-     return target
- 
- 
- @set_module_as('brainstate')
- def reset_states(target: Module, *args, **kwargs) -> Module:
-     """
-     Reset states of all children nodes in the given target.
- 
-     Args:
-         target: The target Module.
- 
-     Returns:
-         The target Module.
-     """
-     nodes_with_order = []
- 
-     # reset node whose `init_state` has no `call_order`
-     for node in list(target.nodes().values()):
-         if not hasattr(node.reset_state, 'call_order'):
-             node.reset_state(*args, **kwargs)
-         else:
-             nodes_with_order.append(node)
- 
-     # reset the node's states
-     for node in sorted(nodes_with_order, key=lambda x: x.reset_state.call_order):
-         node.reset_state(*args, **kwargs)
- 
-     return target
- 
-
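Both helpers share a two-pass pattern: nodes whose `init_state` / `reset_state` carries no `call_order` attribute run immediately in traversal order, and only the decorated ones are deferred and then sorted. A plain-Python illustration with hypothetical stand-ins for the nodes:

```python
def make_node(name, order=None):
    # stands in for a Module whose init_state may or may not be decorated
    def init_state():
        print(f'init {name}')
    if order is not None:
        init_state.call_order = order   # what @call_order(order) would attach
    return init_state

inits = [make_node('a'), make_node('b', order=1), make_node('c', order=0)]

deferred = []
for init in inits:
    if not hasattr(init, 'call_order'):
        init()                          # undecorated nodes run first: 'a'
    else:
        deferred.append(init)

for init in sorted(deferred, key=lambda f: f.call_order):
    init()                              # then 'c' (level 0), then 'b' (level 1)
```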
- @set_module_as('brainstate')
- def load_states(target: Module, state_dict: Dict, **kwargs):
-     """Copy parameters and buffers from :attr:`state_dict` into
-     this module and its descendants.
- 
-     Args:
-         target: Module. The dynamical system to load its states.
-         state_dict: dict. A dict containing parameters and persistent buffers.
- 
-     Returns:
-     -------
-     ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
- 
-     * **missing_keys** is a list of str containing the missing keys
-     * **unexpected_keys** is a list of str containing the unexpected keys
-     """
-     missing_keys = []
-     unexpected_keys = []
-     for name, node in target.nodes().items():
-         r = node.load_state(state_dict[name], **kwargs)
-         if r is not None:
-             missing, unexpected = r
-             missing_keys.extend([f'{name}.{key}' for key in missing])
-             unexpected_keys.extend([f'{name}.{key}' for key in unexpected])
-     return StateLoadResult(missing_keys, unexpected_keys)
- 
- 
- @set_module_as('brainstate')
- def save_states(target: Module, **kwargs) -> Dict:
-     """Save all states in the ``target`` as a dictionary for later disk serialization.
- 
-     Args:
-         target: Module. The node to save its states.
- 
-     Returns:
-         Dict. The state dict for serialization.
-     """
-     return {key: node.save_state(**kwargs) for key, node in target.nodes().items()}
- 
-
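`save_states` keys the result by node name, and `load_states` hands each per-node sub-dict back to `node.load_state`, prefixing the node name onto whatever missing or unexpected keys come back. A plain-Python sketch of that bookkeeping, with made-up node names and keys:

```python
# per-node results as node.load_state might report them: (missing, unexpected)
per_node_results = {
    'lif': (['V'], []),            # the checkpoint lacked 'V' for node 'lif'
    'syn': ([], ['old_weight']),   # the checkpoint carried an extra key for 'syn'
}

missing_keys, unexpected_keys = [], []
for name, (missing, unexpected) in per_node_results.items():
    missing_keys.extend(f'{name}.{key}' for key in missing)
    unexpected_keys.extend(f'{name}.{key}' for key in unexpected)

print(missing_keys)      # ['lif.V']
print(unexpected_keys)   # ['syn.old_weight']
```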
- @set_module_as('brainstate')
- def assign_state_values(target: Module, *state_by_abs_path: Dict):
-     """
-     Assign state values according to the given state dictionary.
- 
-     Parameters
-     ----------
-     target: Module
-         The target module.
-     state_by_abs_path: dict
-         The state dictionary which is accessed by the "absolute" accessing method.
- 
-     """
-     all_states = dict()
-     for state in state_by_abs_path:
-         all_states.update(state)
-     variables = target.states(include_self=True, method='absolute')
-     keys1 = set(all_states.keys())
-     keys2 = set(variables.keys())
-     for key in keys2.intersection(keys1):
-         variables[key].value = jax.numpy.asarray(all_states[key])
-     unexpected_keys = list(keys1 - keys2)
-     missing_keys = list(keys2 - keys1)
-     return unexpected_keys, missing_keys
- 
-
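`assign_state_values` merges any number of absolute-path dicts, writes values only for keys the module actually exposes, and reports everything else. The set arithmetic in isolation, with hypothetical paths:

```python
provided = {'net.lif.V': 0.0, 'net.syn.g': 0.1, 'net.typo': 1.0}   # merged from *state_by_abs_path
available = {'net.lif.V', 'net.syn.g', 'net.lif.spike'}            # keys the module exposes

assigned = {k: provided[k] for k in available & provided.keys()}   # only these get written
unexpected_keys = sorted(provided.keys() - available)              # ['net.typo']
missing_keys = sorted(available - provided.keys())                 # ['net.lif.spike']
```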
- def _input_label_start(label: str):
-     # unify the input label repr.
-     return f'{label} // '
- 
- 
- def _input_label_repr(name: str, label: Optional[str] = None):
-     # unify the input label repr.
-     return name if label is None else (_input_label_start(label) + str(name))
- 
- 
- def _repr_context(repr_str, indent):
-     splits = repr_str.split('\n')
-     splits = [(s if i == 0 else (indent + s)) for i, s in enumerate(splits)]
-     return '\n'.join(splits)
- 
- 
- def _get_delay(delay_time, delay_step):
-     if delay_time is None:
-         if delay_step is None:
-             return 0., 0
-         else:
-             assert isinstance(delay_step, int), '"delay_step" should be an integer.'
-             if delay_step == 0:
-                 return 0., 0
-             delay_time = delay_step * environ.get_dt()
-     else:
-         assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
-         # assert isinstance(delay_time, (int, float))
-         delay_step = math.ceil(delay_time / environ.get_dt())
-     return delay_time, delay_step
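`_get_delay` normalizes the two ways of specifying a delay (a time or a step count) into a consistent (time, step) pair, rounding times up to a whole number of steps so the buffer always covers the request. The conversion on its own, with a hypothetical dt standing in for `environ.get_dt()`:

```python
import math

dt = 0.1                                  # hypothetical time step

# time given, steps derived: round up so the buffer is long enough
delay_time = 0.25
delay_step = math.ceil(delay_time / dt)   # 3

# steps given, time derived
delay_step2 = 4
delay_time2 = delay_step2 * dt            # 0.4

# neither given: the delay degenerates to (0., 0)
```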