brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1360 -1318
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,631 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ """
20
+ All the basic dynamics classes for ``brainstate``.
21
+
22
+ For handling dynamical systems:
23
+
24
+ - ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
25
+ then ``Dynamics``, finally others.
26
+ - ``Projection``: The class for the synaptic projection.
27
+ - ``Dynamics``: The class for the dynamical system.
28
+
29
+ For handling the delays:
30
+
31
+ - ``Delay``: The class for all delays.
32
+ - ``DelayAccess``: The class for the delay access.
33
+
34
+ """
35
+ from __future__ import annotations
36
+
37
+ from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
38
+
39
+ import brainunit as u
40
+ import numpy as np
41
+
42
+ from brainstate import environ
43
+ from brainstate._state import State
44
+ from brainstate.graph import Node
45
+ from brainstate.mixin import ParamDescriber
46
+ from brainstate.nn._module import Module
47
+ from brainstate.typing import Size, ArrayLike
48
+ from ._state_delay import StateWithDelay, Delay
49
+
50
# Public names exported by this module.
__all__ = [
    'DynamicsGroup', 'Projection', 'Dynamics', 'Prefetch',
]

# Generic type variable used in the signatures of ``Dynamics.align_pre`` and
# ``not_receive_update_output``.
T = TypeVar('T')
# NOTE(review): ``_max_order`` is not referenced anywhere in this file — confirm
# whether other modules still depend on it.
_max_order = 10
56
+
57
+
58
class Projection(Module):
    """
    Base class to model synaptic projections.

    The default :py:meth:`update` forwards its arguments to every node returned
    by ``self.nodes(allowed_hierarchy=(1, 1))``. A projection without such
    nodes must override :py:meth:`update`; otherwise calling it raises
    ``ValueError``.
    """

    __module__ = 'brainstate.nn'

    def update(self, *args, **kwargs):
        children = tuple(self.nodes(allowed_hierarchy=(1, 1)).values())
        if not children:
            raise ValueError('Do not implement the update() function.')
        for child in children:
            child(*args, **kwargs)
72
+
73
+
74
class Dynamics(Module):
    """
    Base class to model dynamics.

    .. note::
       A :py:class:`Dynamics` instance only defines the evolving function at a
       single time step :math:`t` (its :py:meth:`update` method). To run a model
       across multiple steps, call it inside a loop such as :py:func:`~.for_loop`.

    Calling the instance runs, in order:

    1. every function in :py:attr:`before_updates` — called with the call
       arguments if it carries the ``_receive_update_input`` marker (see
       ``receive_update_input``), otherwise with no arguments;
    2. ``self.update(*args, **kwargs)``;
    3. every function in :py:attr:`after_updates` — called with the return value
       of ``update`` unless it carries the ``_not_receive_update_output`` marker
       (see ``not_receive_update_output``), in which case it is called with no
       arguments.

    There are several essential attributes:

    - ``in_size``: the geometry of the neuron group, always stored as a tuple.
      For example, ``10`` is stored as ``(10,)``; ``(10, 10)`` denotes a neuron
      group aligned in a 2D space; ``(10, 15, 4)`` denotes a 3-dimensional group.
    - ``out_size``: initialized to ``in_size``.
    - ``varshape``: read-only alias of ``in_size``.

    Args:
        in_size: The neuron group geometry, an int or a non-empty tuple/list of int.
        name: The name of the dynamic system.
    """

    __module__ = 'brainstate.nn'

    # before updates
    _before_updates: Optional[Dict[Hashable, Callable]]

    # after updates
    _after_updates: Optional[Dict[Hashable, Callable]]

    # current inputs
    _current_inputs: Optional[Dict[str, ArrayLike | Callable]]

    # delta inputs
    _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
    ):
        # initialize
        super().__init__(name=name)

        # Normalize the geometry of the neuron population to a tuple.
        # Only the first element of a tuple/list is type-checked here.
        if isinstance(in_size, (list, tuple)):
            if len(in_size) <= 0:
                raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
            if not isinstance(in_size[0], (int, np.integer)):
                raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
            in_size = tuple(in_size)
        elif isinstance(in_size, (int, np.integer)):
            in_size = (in_size,)
        else:
            raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
        self.in_size = in_size

        # The input and update registries are created lazily on first registration.
        # current inputs
        self._current_inputs = None

        # delta inputs
        self._delta_inputs = None

        # before updates
        self._before_updates = None

        # after updates
        self._after_updates = None

        # in-/out- size of neuron population
        self.out_size = self.in_size

    @property
    def varshape(self):
        """The shape of variables in the neuron group (an alias of ``in_size``)."""
        return self.in_size

    @property
    def current_inputs(self):
        """
        The current inputs of the model: ``None`` or a dict mapping keys to
        arrays or callables.
        """
        return self._current_inputs

    @property
    def delta_inputs(self):
        """
        The delta inputs of the model: ``None`` or a dict mapping keys to
        arrays or callables.
        """
        return self._delta_inputs

    def add_current_input(
        self,
        key: str,
        inp: Union[Callable, ArrayLike],
        label: Optional[str] = None
    ):
        """
        Add a current input function.

        Args:
            key: str. The dict key.
            inp: Callable, ArrayLike. The currents or the function to generate currents.
            label: str. The input label. When given, the stored key is prefixed
                with ``'<label> // '``.

        Raises:
            ValueError: If the (labelled) key is already bound to a different object.
        """
        key = _input_label_repr(key, label)
        if self._current_inputs is None:
            self._current_inputs = dict()
        if key in self._current_inputs:
            # Re-registering the very same object is allowed; a different one is not.
            if self._current_inputs[key] is not inp:
                raise ValueError(f'Key "{key}" has been defined and used in the current inputs of {self}.')
        self._current_inputs[key] = inp

    def add_delta_input(
        self,
        key: str,
        inp: Union[Callable, ArrayLike],
        label: Optional[str] = None
    ):
        """
        Add a delta input function.

        Args:
            key: str. The dict key.
            inp: Callable, ArrayLike. The currents or the function to generate currents.
            label: str. The input label. When given, the stored key is prefixed
                with ``'<label> // '``.

        Raises:
            ValueError: If the (labelled) key is already bound to a different object.
        """
        key = _input_label_repr(key, label)
        if self._delta_inputs is None:
            self._delta_inputs = dict()
        if key in self._delta_inputs:
            # Re-registering the very same object is allowed; a different one is not.
            if self._delta_inputs[key] is not inp:
                raise ValueError(f'Key "{key}" has been defined and used.')
        self._delta_inputs[key] = inp

    def get_input(self, key: str):
        """Get the input function.

        Current inputs are searched before delta inputs.

        Args:
            key: str. The key.

        Returns:
            The input function which generates currents.

        Raises:
            ValueError: If ``key`` is in neither registry.
        """
        if self._current_inputs is not None and key in self._current_inputs:
            return self._current_inputs[key]
        elif self._delta_inputs is not None and key in self._delta_inputs:
            return self._delta_inputs[key]
        else:
            raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')

    @staticmethod
    def _sum_inputs(inputs, init, args, kwargs, label):
        """Add up ``inputs`` (restricted to ``label`` if given) onto ``init``.

        Callable entries are evaluated with ``*args, **kwargs`` and kept;
        non-callable entries are added as-is and then removed from ``inputs``,
        so they contribute to a single summation only.
        """
        if inputs is None:
            return init
        prefix = None if label is None else _input_label_start(label)
        # Iterate over a snapshot of the keys because entries may be popped.
        for key in tuple(inputs.keys()):
            if prefix is not None and not key.startswith(prefix):
                continue
            out = inputs[key]
            if callable(out):
                init = init + out(*args, **kwargs)
            else:
                init = init + out
                inputs.pop(key)
        return init

    def sum_current_inputs(
        self,
        init: Any,
        *args,
        label: Optional[str] = None,
        **kwargs
    ):
        """
        Summarize all current inputs by the defined input functions ``.current_inputs``.

        Non-callable (constant) inputs are consumed: they are removed after
        being added once.

        Args:
            init: The initial input data.
            *args: The arguments for input functions.
            **kwargs: The arguments for input functions.
            label: str. The input label. Only keys registered with this label are summed.

        Returns:
            The total currents.
        """
        return self._sum_inputs(self._current_inputs, init, args, kwargs, label)

    def sum_delta_inputs(
        self,
        init: Any,
        *args,
        label: Optional[str] = None,
        **kwargs
    ):
        """
        Summarize all delta inputs by the defined input functions ``.delta_inputs``.

        Non-callable (constant) inputs are consumed: they are removed after
        being added once.

        Args:
            init: The initial input data.
            *args: The arguments for input functions.
            **kwargs: The arguments for input functions.
            label: str. The input label. Only keys registered with this label are summed.

        Returns:
            The total currents.
        """
        return self._sum_inputs(self._delta_inputs, init, args, kwargs, label)

    @property
    def before_updates(self):
        """
        The before updates of the model: ``None`` or a dict of the updating functions.
        """
        return self._before_updates

    @property
    def after_updates(self):
        """
        The after updates of the model: ``None`` or a dict of the updating functions.
        """
        return self._after_updates

    def _add_before_update(self, key: Any, fun: Callable):
        """
        Add the before update into this node.

        Raises:
            KeyError: If ``key`` is already registered.
        """
        if self._before_updates is None:
            self._before_updates = dict()
        if key in self.before_updates:
            raise KeyError(f'{key} has been registered in before_updates of {self}')
        self.before_updates[key] = fun

    def _add_after_update(self, key: Any, fun: Callable):
        """
        Add the after update into this node.

        Raises:
            KeyError: If ``key`` is already registered.
        """
        if self._after_updates is None:
            self._after_updates = dict()
        if key in self.after_updates:
            raise KeyError(f'{key} has been registered in after_updates of {self}')
        self.after_updates[key] = fun

    def _get_before_update(self, key: Any):
        """Get the before update of this node by the given ``key`` (``KeyError`` if absent)."""
        if self._before_updates is None or key not in self.before_updates:
            raise KeyError(f'{key} is not registered in before_updates of {self}')
        return self.before_updates.get(key)

    def _get_after_update(self, key: Any):
        """Get the after update of this node by the given ``key`` (``KeyError`` if absent)."""
        if self._after_updates is None or key not in self.after_updates:
            raise KeyError(f'{key} is not registered in after_updates of {self}')
        return self.after_updates.get(key)

    def _has_before_update(self, key: Any):
        """Whether this node has the before update of the given ``key``."""
        if self._before_updates is None:
            return False
        return key in self.before_updates

    def _has_after_update(self, key: Any):
        """Whether this node has the after update of the given ``key``."""
        if self._after_updates is None:
            return False
        return key in self.after_updates

    def __call__(self, *args, **kwargs):
        """
        The shortcut to call ``update`` methods, wrapped by the before/after updates.
        """

        # ``before_updates``
        if self.before_updates is not None:
            for model in self.before_updates.values():
                if hasattr(model, '_receive_update_input'):
                    model(*args, **kwargs)
                else:
                    model()

        # update the model self
        ret = self.update(*args, **kwargs)

        # ``after_updates``
        if self.after_updates is not None:
            for model in self.after_updates.values():
                if hasattr(model, '_not_receive_update_output'):
                    model()
                else:
                    model(ret)
        return ret

    def prefetch(self, item: str) -> 'Prefetch':
        """Return a :py:class:`Prefetch` handle to the attribute ``item`` of this module."""
        return Prefetch(self, item)

    def align_pre(
        self, dyn: Union[ParamDescriber[T], T]
    ) -> T:
        """
        Align the dynamics before the interaction.

        ``dyn`` is registered as an after update of this module, so it runs
        after every ``update`` call.

        Args:
            dyn: Either a :py:class:`Dynamics` instance (registered under
                ``dyn.name``), or a :py:class:`ParamDescriber` of a
                :py:class:`Dynamics` subclass. A descriptor is instantiated only
                the first time its ``identifier`` is seen; afterwards the cached
                instance is reused.

        Returns:
            The registered dynamics instance.

        Raises:
            TypeError: If ``dyn`` is neither a :py:class:`Dynamics` nor a
                descriptor of a :py:class:`Dynamics` subclass.
        """
        if isinstance(dyn, Dynamics):
            self._add_after_update(dyn.name, dyn)
            return dyn
        elif isinstance(dyn, ParamDescriber):
            # ``dyn.cls`` is the class to be instantiated, so it must be checked
            # with ``issubclass`` (``isinstance`` on a class object is always False here).
            if not (isinstance(dyn.cls, type) and issubclass(dyn.cls, Dynamics)):
                raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
            if not self._has_after_update(dyn.identifier):
                self._add_after_update(dyn.identifier, dyn())
            return self._get_after_update(dyn.identifier)
        else:
            raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')

    def __leaf_fn__(self, name, value):
        if name in ['_in_size', '_out_size', '_name', '_mode',
                    '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
            return (name, value) if value is None else (name[1:], value)  # skip the first `_`
        return name, value
426
+
427
+
428
class Prefetch(Node):
    """
    Prefetch a variable of the given module.

    Calling the instance reads ``getattr(module, item)``. If the result is a
    :py:class:`State`, its ``.value`` is returned; otherwise the attribute
    itself is returned.

    Args:
        module: The module that owns the attribute.
        item: The name of the attribute to fetch.
    """

    def __init__(self, module: Module, item: str):
        super().__init__()
        self.module = module
        self.item = item

    @property
    def delay(self):
        """A fresh :py:class:`PrefetchDelay` bound to the same module and item."""
        return PrefetchDelay(self.module, self.item)

    def __call__(self, *args, **kwargs):
        target = _get_prefetch_item(self)
        if isinstance(target, State):
            return target.value
        return target
445
+
446
+
447
class PrefetchDelay(Node):
    """
    Handle to the delayed history of the attribute ``item`` of ``module``.

    Use :py:meth:`at` to obtain a :py:class:`PrefetchDelayAt` that reads the
    delayed value at a specific time.

    Args:
        module: The dynamics module that owns the attribute.
        item: The name of the attribute whose delay is accessed.
    """

    def __init__(self, module: Dynamics, item: str):
        # Initialize the ``Node`` base, as the sibling prefetch classes do.
        super().__init__()
        self.module = module
        self.item = item

    def at(self, time: ArrayLike):
        """Return a :py:class:`PrefetchDelayAt` reading the delay at ``time``."""
        return PrefetchDelayAt(self.module, self.item, time)
454
+
455
+
456
class PrefetchDelayAt(Node):
    """
    Prefetch the delay of a variable in the given module at a specific time.

    On construction, a :py:class:`StateWithDelay` for ``item`` is registered in
    the module's after updates under the key ``f'{item}-delay'`` (only if that
    key is not registered yet). The requested ``time`` is then registered on it.

    Args:
        module: The module that has the item with the name specified by ``item`` argument.
        item: The item that has the delay.
        time: The time to retrieve the delay.
    """

    def __init__(self, module: Dynamics, item: str, time: ArrayLike):
        super().__init__()
        assert isinstance(module, Dynamics), (
            f'The module should be an instance of Dynamics. But got {module}.'
        )
        self.module = module
        self.item = item
        self.time = time
        # The delay expressed in integration steps.
        # NOTE(review): casting the quotient to the integer dtype truncates it
        # (e.g. 2.9999999 -> 2). Confirm this matches how
        # ``StateWithDelay.register_delay`` converts ``time``.
        self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())

        # Register the delay buffer once per item, then register this delay time on it.
        key = _get_delay_key(item)
        if not module._has_after_update(key):
            module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
        self.state_delay: StateWithDelay = module._get_after_update(key)
        self.state_delay.register_delay(time)

    def __call__(self, *args, **kwargs):
        """Return the delayed value of the item, retrieved at ``self.step``."""
        return self.state_delay.retrieve_at_step(self.step)
484
+
485
+
486
+ def _get_delay_key(item) -> str:
487
+ return f'{item}-delay'
488
+
489
+
490
+ def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
491
+ item = getattr(target.module, target.item, None)
492
+ if item is None:
493
+ raise AttributeError(f'The target {target.module} should have an `{target.item}` attribute.')
494
+ return item
495
+
496
+
497
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
    """
    Return the :py:class:`StateWithDelay` registered for ``target.item``.

    Raises:
        AssertionError: If ``target.module`` is not a :py:class:`Dynamics`.
        KeyError: If no delay is registered for the item.
        TypeError: If the registered after update is not a :py:class:`StateWithDelay`.
    """
    owner = target.module
    assert isinstance(owner, Dynamics), (f'The target module should be an instance '
                                         f'of Dynamics. But got {target.module}.')
    registered = owner._get_after_update(_get_delay_key(target.item))
    if isinstance(registered, StateWithDelay):
        return registered
    raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
                    f'its delay. But got {registered}.')
505
+
506
+
507
def maybe_init_prefetch(target, *args, **kwargs):
    """
    Validate a prefetch handle eagerly; other objects are ignored.

    - ``Prefetch``: checks that the attribute exists.
    - ``PrefetchDelay``: checks that the delay is registered.
    - ``PrefetchDelayAt``: checks the delay and registers ``target.time`` on it.
    """
    if isinstance(target, Prefetch):
        _get_prefetch_item(target)
        return
    if isinstance(target, PrefetchDelay):
        _get_prefetch_item_delay(target)
        return
    if isinstance(target, PrefetchDelayAt):
        state_delay = _get_prefetch_item_delay(target)
        state_delay.register_delay(target.time)
517
+
518
+
519
class DynamicsGroup(Module):
    """
    A group of :py:class:`~.Module` objects that are updated together.

    On each :py:meth:`update`, the nodes returned by
    ``self.nodes(allowed_hierarchy=(1, 1))`` are called without arguments in
    three passes. :py:class:`Projection` nodes go first, then
    :py:class:`Dynamics` nodes, and finally all remaining nodes.

    Args:
        children_as_tuple: The children objects, stored in ``layers_tuple``.
        children_as_dict: The children objects, stored in ``layers_dict``.
    """

    __module__ = 'brainstate.nn'

    # NOTE(review): defined only outside static type checking, presumably so
    # type checkers fall back to the base-class ``__init__`` signature — confirm.
    if not TYPE_CHECKING:
        def __init__(self, *children_as_tuple, **children_as_dict):
            super().__init__()
            self.layers_tuple = tuple(children_as_tuple)
            self.layers_dict = dict(children_as_dict)

    def update(self, *args, **kwargs):
        """
        Update function of a network.

        The children are called in order: projections, dynamics, then others
        (including delays). The arguments of this method are not forwarded to them.
        """
        projections, dynamics, remaining = self.nodes(allowed_hierarchy=(1, 1)).split(Projection, Dynamics)
        for group in (projections, dynamics, remaining):
            for child in group.values():
                child()
555
+
556
+
557
def receive_update_output(cls: object):
    """
    Mark an after-update object so that it receives the output of ``update``.

    ``Dynamics.__call__`` then invokes it as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun(ret)

    This removes the ``_not_receive_update_output`` marker if present. Note
    that ``delattr`` only succeeds when the marker is set on ``cls`` itself,
    not when it is inherited from its class.

    Returns:
        ``cls`` itself, so the function can be used as a decorator.
    """
    marker = '_not_receive_update_output'
    if hasattr(cls, marker):
        delattr(cls, marker)
    return cls
572
+
573
+
574
def not_receive_update_output(cls: T) -> T:
    """
    Mark an after-update object so that it is called without the output of ``update``.

    ``Dynamics.__call__`` then invokes it as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun()

    Returns:
        ``cls`` itself (with ``_not_receive_update_output = True`` set on it).
    """
    setattr(cls, '_not_receive_update_output', True)
    return cls
588
+
589
+
590
def receive_update_input(cls: object):
    """
    Mark a before-update object so that it receives the inputs of the call.

    ``Dynamics.__call__`` then invokes it as::

        for fun in model.before_updates.values():
            fun(*args, **kwargs)
        model.update(*args, **kwargs)

    Returns:
        ``cls`` itself (with ``_receive_update_input = True`` set on it).
    """
    setattr(cls, '_receive_update_input', True)
    return cls
605
+
606
+
607
def not_receive_update_input(cls: object):
    """
    Mark a before-update object so that it is called without the call inputs.

    ``Dynamics.__call__`` then invokes it as::

        for fun in model.before_updates.values():
            fun()
        model.update(*args, **kwargs)

    This removes the ``_receive_update_input`` marker if present. Note that
    ``delattr`` only succeeds when the marker is set on ``cls`` itself, not
    when it is inherited from its class.

    Returns:
        ``cls`` itself, so the function can be used as a decorator.
    """
    marker = '_receive_update_input'
    if hasattr(cls, marker):
        delattr(cls, marker)
    return cls
622
+
623
+
624
+ def _input_label_start(label: str):
625
+ # unify the input label repr.
626
+ return f'{label} // '
627
+
628
+
629
+ def _input_label_repr(name: str, label: Optional[str] = None):
630
+ # unify the input label repr.
631
+ return name if label is None else (_input_label_start(label) + str(name))
@@ -0,0 +1,79 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from __future__ import annotations
19
+
20
+ import unittest
21
+
22
+ import numpy as np
23
+
24
+ import brainstate as bst
25
+
26
+
27
class TestModuleGroup(unittest.TestCase):
    """Smoke tests for ``brainstate.nn.DynamicsGroup``."""

    def test_initialization(self):
        self.assertIsInstance(bst.nn.DynamicsGroup(), bst.nn.DynamicsGroup)
31
+
32
+
33
class TestProjection(unittest.TestCase):
    """Tests for the default behavior of ``brainstate.nn.Projection``."""

    def test_initialization(self):
        self.assertIsInstance(bst.nn.Projection(), bst.nn.Projection)

    def test_update_not_implemented(self):
        # A projection without child nodes has nothing to forward its update to.
        empty_projection = bst.nn.Projection()
        self.assertRaises(ValueError, empty_projection.update)
42
+
43
+
44
class TestDynamics(unittest.TestCase):
    """Tests for ``brainstate.nn.Dynamics`` construction and input handling."""

    def test_initialization(self):
        dyn = bst.nn.Dynamics(in_size=10)
        self.assertIsInstance(dyn, bst.nn.Dynamics)
        self.assertEqual(dyn.in_size, (10,))
        self.assertEqual(dyn.out_size, (10,))

    def test_size_validation(self):
        with self.assertRaises(ValueError):
            bst.nn.Dynamics(in_size=[])
        with self.assertRaises(ValueError):
            bst.nn.Dynamics(in_size="invalid")

    def test_input_handling(self):
        dyn = bst.nn.Dynamics(in_size=10)
        dyn.add_current_input("test_current", lambda: np.random.rand(10))
        dyn.add_delta_input("test_delta", lambda: np.random.rand(10))

        self.assertIn("test_current", dyn.current_inputs)
        self.assertIn("test_delta", dyn.delta_inputs)

    def test_duplicate_input_key(self):
        dyn = bst.nn.Dynamics(in_size=10)

        def fn():
            return np.random.rand(10)

        dyn.add_current_input("test", fn)
        # Re-registering the very same object under the same key is allowed.
        dyn.add_current_input("test", fn)
        with self.assertRaises(ValueError):
            dyn.add_current_input("test", lambda: np.random.rand(10))

    def test_sum_inputs_consumes_constants(self):
        dyn = bst.nn.Dynamics(in_size=10)
        dyn.add_current_input("const", 1.0)
        dyn.add_current_input("fn", lambda: 2.0)
        self.assertEqual(dyn.sum_current_inputs(0.0), 3.0)
        # Constant inputs contribute only once; callables are kept.
        self.assertEqual(dyn.sum_current_inputs(0.0), 2.0)

    def test_sum_inputs_with_label(self):
        dyn = bst.nn.Dynamics(in_size=10)
        dyn.add_delta_input("a", 1.0, label="exc")
        dyn.add_delta_input("b", 10.0, label="inh")
        self.assertEqual(dyn.sum_delta_inputs(0.0, label="exc"), 1.0)
        # Only the "inh" input is left after the labelled constant was consumed.
        self.assertEqual(dyn.sum_delta_inputs(0.0), 10.0)

    def test_varshape(self):
        dyn = bst.nn.Dynamics(in_size=(2, 3))
        self.assertEqual(dyn.varshape, (2, 3))
        # Lists are normalized to tuples.
        dyn = bst.nn.Dynamics(in_size=[2, 3])
        self.assertEqual(dyn.varshape, (2, 3))
        # A scalar size becomes a one-element tuple.
        dyn = bst.nn.Dynamics(in_size=5)
        self.assertEqual(dyn.varshape, (5,))
76
+
77
+
78
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()