brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,1343 +1,1267 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- """
20
- All the basic dynamics class for the ``brainstate``.
21
-
22
- For handling dynamical systems:
23
-
24
- - ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
25
- then ``Dynamics``, finally others.
26
- - ``Projection``: The class for the synaptic projection.
27
- - ``Dynamics``: The class for the dynamical system.
28
-
29
- For handling the delays:
30
-
31
- - ``Delay``: The class for all delays.
32
- - ``DelayAccess``: The class for the delay access.
33
-
34
- """
35
-
36
- from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING, Tuple
37
-
38
- import jax
39
- import numpy as np
40
-
41
- from brainstate import environ
42
- from brainstate._state import State
43
- from brainstate.graph import Node
44
- from brainstate.mixin import ParamDescriber
45
- from brainstate.typing import Size, ArrayLike
46
- from ._delay import StateWithDelay, Delay
47
- from ._module import Module
48
-
49
# Names exported by ``from ... import *``.
__all__ = [
    'DynamicsGroup',
    'Projection',
    'Dynamics',
    'Prefetch',
    'PrefetchDelay',
    'PrefetchDelayAt',
    'OutputDelayAt',
]

# Module-level generic type variable.
T = TypeVar('T')

# NOTE(review): not referenced in the visible part of this module; the name
# suggests an upper bound on some "order" value -- confirm against its users.
_max_order = 10
56
-
57
-
58
class Projection(Module):
    """
    Base class for synaptic projection modules.

    A projection bundles the components that carry and transform signals from
    one neural population towards another. ``DynamicsGroup`` updates
    ``Projection`` modules first, then ``Dynamics`` modules, and finally any
    other modules.

    This base class performs no computation of its own. Its :meth:`update`
    forwards the call to the sub-nodes returned by
    ``self.nodes(allowed_hierarchy=(1, 1))`` (presumably the immediate
    children). Subclasses either provide such children or override
    :meth:`update` with a concrete projection rule, such as dense or sparse
    connectivity.

    Raises
    ------
    ValueError
        From :meth:`update`, when there are no sub-nodes to delegate to.
    """
    __module__ = 'brainstate.nn'

    def update(self, *args, **kwargs):
        """
        Call every sub-node with the given positional and keyword arguments.

        Raises
        ------
        ValueError
            If this projection has no sub-nodes, i.e. neither an ``update``
            override nor child modules were provided.
        """
        children = tuple(self.nodes(allowed_hierarchy=(1, 1)).values())
        if not children:
            raise ValueError('Do not implement the update() function.')
        for child in children:
            child(*args, **kwargs)
101
-
102
-
103
- class Dynamics(Module):
104
- """
105
- Base class for implementing neural dynamics models in BrainState.
106
-
107
- Dynamics classes represent the core computational units in neural simulations,
108
- implementing the differential equations or update rules that govern neural activity.
109
- This class provides infrastructure for managing neural populations, handling inputs,
110
- and coordinating updates within the simulation framework.
111
-
112
- The Dynamics class serves several key purposes:
113
- 1. Managing neuron population geometry and size information
114
- 2. Handling current and delta (instantaneous change) inputs to neurons
115
- 3. Supporting before/after update hooks for computational dependencies
116
- 4. Providing access to delayed state variables through the prefetch mechanism
117
- 5. Establishing the execution order in neural network simulations
118
-
119
- Parameters
120
- ----------
121
- in_size : Size
122
- The geometry of the neuron population. Can be an integer (e.g., 10) for
123
- 1D neuron arrays, or a tuple (e.g., (10, 10)) for multi-dimensional populations.
124
- name : Optional[str], default=None
125
- Optional name identifier for this dynamics module.
126
-
127
- Attributes
128
- ----------
129
- in_size : tuple
130
- The shape/geometry of the neuron population.
131
- out_size : tuple
132
- The output shape, typically matches in_size.
133
- current_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
134
- Dictionary of registered current input functions or arrays.
135
- delta_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
136
- Dictionary of registered delta input functions or arrays.
137
- before_updates : Optional[Dict[Hashable, Callable]]
138
- Dictionary of functions to call before the main update.
139
- after_updates : Optional[Dict[Hashable, Callable]]
140
- Dictionary of functions to call after the main update.
141
-
142
- Notes
143
- -----
144
- In the BrainState execution sequence, Dynamics modules are updated after
145
- Projection modules and before other module types, reflecting the natural
146
- flow of information in neural systems.
147
-
148
- There are several essential attributes:
149
-
150
- - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of
151
- neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes
152
- a 3-dimensional neuron group.
153
- - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
154
- `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
155
-
156
-
157
- See Also
158
- --------
159
- Module : Parent class providing base module functionality
160
- Projection : Class for handling synaptic projections between neural populations
161
- DynamicsGroup : Container for organizing multiple dynamics modules
162
- """
163
-
164
- __module__ = 'brainstate.nn'
165
-
166
- graph_invisible_attrs = ('_before_updates', '_after_updates', '_current_inputs', '_delta_inputs')
167
-
168
- # before updates
169
- _before_updates: Optional[Dict[Hashable, Callable]]
170
-
171
- # after updates
172
- _after_updates: Optional[Dict[Hashable, Callable]]
173
-
174
- # current inputs
175
- _current_inputs: Optional[Dict[str, ArrayLike | Callable]]
176
-
177
- # delta inputs
178
- _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]
179
-
180
def __init__(
    self,
    in_size: Size,
    name: Optional[str] = None,
):
    """
    Initialize the dynamics module and normalize the population geometry.

    Parameters
    ----------
    in_size : Size
        Geometry of the population. This is either an integer (Python ``int``
        or ``numpy.integer``) or a non-empty tuple/list whose elements are all
        integers. It is stored as a tuple in ``self.in_size``.
    name : Optional[str], default=None
        Optional name forwarded to the parent module's constructor.

    Raises
    ------
    ValueError
        If ``in_size`` is neither an integer nor a non-empty tuple/list of
        integers.
    """
    super().__init__(name=name)

    # Normalize ``in_size`` to a tuple. Every dimension is validated, not only
    # the first, so malformed shapes such as ``(10, 2.5)`` are rejected here
    # instead of failing later in an obscure place.
    if isinstance(in_size, (list, tuple)):
        if len(in_size) == 0 or not all(isinstance(dim, (int, np.integer)) for dim in in_size):
            raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
        in_size = tuple(in_size)
    elif isinstance(in_size, (int, np.integer)):
        in_size = (in_size,)
    else:
        raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
    self.in_size = in_size

    # Input and hook registries start out as None; the add_* helpers create
    # the dictionaries on first use.
    self._current_inputs = None
    self._delta_inputs = None
    self._before_updates = None
    self._after_updates = None

    # The output geometry mirrors the input geometry.
    self.out_size = self.in_size
215
-
216
- # def __pretty_repr_item__(self, name, value):
217
- # if name in [
218
- # '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
219
- # '_in_size', '_out_size', '_name', '_mode',
220
- # ]:
221
- # return (name, value) if value is None else (name[1:], value) # skip the first `_`
222
- # return super().__pretty_repr_item__(name, value)
223
-
224
@property
def varshape(self):
    """
    Shape used for the variables of this population.

    Returns
    -------
    tuple
        The same tuple as ``self.in_size``, i.e. the normalized geometry given
        at construction time.
    """
    shape = self.in_size
    return shape
243
-
244
@property
def current_inputs(self):
    """
    Registered current inputs of this module.

    Returns
    -------
    dict or None
        Mapping from (label-qualified) key to a callable or a value, or None
        if ``add_current_input`` has never been called.

    See Also
    --------
    add_current_input : Register a new current input
    sum_current_inputs : Apply and sum all current inputs
    delta_inputs : Registry of delta inputs
    """
    registered = self._current_inputs
    return registered
264
-
265
@property
def delta_inputs(self):
    """
    Registered delta inputs of this module.

    Returns
    -------
    dict or None
        Mapping from (label-qualified) key to a callable or a value, or None
        if ``add_delta_input`` has never been called.

    See Also
    --------
    add_delta_input : Register a new delta input
    sum_delta_inputs : Apply and sum all delta inputs
    current_inputs : Registry of current inputs
    """
    registered = self._delta_inputs
    return registered
285
-
286
def add_current_input(
    self,
    key: str,
    inp: Union[Callable, ArrayLike],
    label: Optional[str] = None
):
    """
    Register a current input (a callable or a value) on this module.

    Parameters
    ----------
    key : str
        Identifier of the input. It is combined with ``label`` through
        ``_input_label_repr`` before it is stored, so the stored key may
        differ from the one passed here.
    inp : Union[Callable, ArrayLike]
        The input source.

        - Callable: ``sum_current_inputs`` calls it with the arguments it
          receives, and the callable stays registered.
        - Otherwise: ``sum_current_inputs`` adds the value once and then
          removes it from the registry.
    label : Optional[str], default=None
        Optional group label. ``sum_current_inputs(..., label=...)`` only
        processes keys belonging to that label.

    Raises
    ------
    ValueError
        If the resulting key is already registered with a different object.
        Registering the very same object again is a no-op.

    See Also
    --------
    sum_current_inputs : Sum all current inputs matching a given label
    add_delta_input : Register a delta input
    """
    key = _input_label_repr(key, label)
    if self._current_inputs is None:
        self._current_inputs = dict()
    # Identity (not equality) decides conflicts: re-adding the same object is
    # allowed, while a different object under the same key is an error.
    if key in self._current_inputs and self._current_inputs[key] is not inp:
        raise ValueError(f'Key "{key}" has been defined and used in the current inputs of {self}.')
    self._current_inputs[key] = inp
335
-
336
def add_delta_input(
    self,
    key: str,
    inp: Union[Callable, ArrayLike],
    label: Optional[str] = None
):
    """
    Register a delta input (a callable or a value) on this module.

    Delta inputs are contributions that ``sum_delta_inputs`` adds on top of
    an initial value.

    Parameters
    ----------
    key : str
        Identifier of the input. It is combined with ``label`` through
        ``_input_label_repr`` before it is stored, so the stored key may
        differ from the one passed here.
    inp : Union[Callable, ArrayLike]
        The input source.

        - Callable: ``sum_delta_inputs`` calls it with the arguments it
          receives, and the callable stays registered.
        - Otherwise: ``sum_delta_inputs`` adds the value once and then removes
          it from the registry.
    label : Optional[str], default=None
        Optional group label. ``sum_delta_inputs(..., label=...)`` only
        processes keys belonging to that label.

    Raises
    ------
    ValueError
        If the resulting key is already registered with a different object.
        Registering the very same object again is a no-op.

    See Also
    --------
    sum_delta_inputs : Sum all delta inputs matching a given label
    add_current_input : Register a current input
    """
    key = _input_label_repr(key, label)
    if self._delta_inputs is None:
        self._delta_inputs = dict()
    # Identity (not equality) decides conflicts: re-adding the same object is
    # allowed, while a different object under the same key is an error.
    if key in self._delta_inputs and self._delta_inputs[key] is not inp:
        # The message mirrors add_current_input so the failing module is named.
        raise ValueError(f'Key "{key}" has been defined and used in the delta inputs of {self}.')
    self._delta_inputs[key] = inp
386
-
387
def get_input(self, key: str):
    """
    Look up a registered input by its stored key.

    The current-input registry is searched first, then the delta-input
    registry. The key must match the stored key exactly, i.e. the value
    produced by ``_input_label_repr`` when the input was added.

    Parameters
    ----------
    key : str
        Stored key of the input.

    Returns
    -------
    Callable or ArrayLike
        The registered callable or value.

    Raises
    ------
    ValueError
        If neither registry contains ``key``.

    See Also
    --------
    add_current_input : Register a current input
    add_delta_input : Register a delta input
    """
    for registry in (self._current_inputs, self._delta_inputs):
        if registry is not None and key in registry:
            return registry[key]
    raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')
429
-
430
def sum_current_inputs(
    self,
    init: Any,
    *args,
    label: Optional[str] = None,
    **kwargs
):
    """
    Add every applicable current input to ``init`` and return the total.

    Parameters
    ----------
    init : Any
        Starting value; each input is added to it with ``+``.
    *args : tuple
        Positional arguments passed to every callable input.
    label : Optional[str], default=None
        If given, only keys starting with ``_input_label_start(label)`` are
        processed. If None, every registered current input is processed.
    **kwargs : dict
        Keyword arguments passed to every callable input.

    Returns
    -------
    Any
        ``init`` plus all processed inputs. If no registry exists yet, ``init``
        is returned unchanged.

    Raises
    ------
    ValueError
        If calling an input or adding its result fails. The original exception
        is chained as ``__cause__``.

    Notes
    -----
    - Callable inputs stay registered.
    - Non-callable inputs are consumed: each is removed from the registry
      right after it has been added successfully.
    """
    if self._current_inputs is None:
        return init
    if label is None:
        filter_fn = lambda k: True
    else:
        label_repr = _input_label_start(label)
        filter_fn = lambda k: k.startswith(label_repr)
    # Iterate over a snapshot of the keys because one-shot values are popped.
    for key in tuple(self._current_inputs.keys()):
        if not filter_fn(key):
            continue
        out = self._current_inputs[key]
        if callable(out):
            try:
                init = init + out(*args, **kwargs)
            except Exception as e:
                raise ValueError(
                    f'Error in current input function {key}: {out}\n'
                    f'Error: {e}'
                ) from e
        else:
            try:
                init = init + out
            except Exception as e:
                raise ValueError(
                    f'Error in current input value {key}: {out}\n'
                    f'Error: {e}'
                ) from e
            # One-shot value: drop it once it has been applied.
            self._current_inputs.pop(key)
    return init
497
-
498
def sum_delta_inputs(
    self,
    init: Any,
    *args,
    label: Optional[str] = None,
    **kwargs
):
    """
    Add every applicable delta input to ``init`` and return the total.

    Parameters
    ----------
    init : Any
        Starting value; each input is added to it with ``+``.
    *args : tuple
        Positional arguments passed to every callable input.
    label : Optional[str], default=None
        If given, only keys starting with ``_input_label_start(label)`` are
        processed. If None, every registered delta input is processed.
    **kwargs : dict
        Keyword arguments passed to every callable input.

    Returns
    -------
    Any
        ``init`` plus all processed inputs. If no registry exists yet, ``init``
        is returned unchanged.

    Raises
    ------
    ValueError
        If calling an input or adding its result fails. The original exception
        is chained as ``__cause__``.

    Notes
    -----
    - Callable inputs stay registered.
    - Non-callable inputs are consumed: each is removed from the registry
      right after it has been added successfully.
    """
    if self._delta_inputs is None:
        return init
    prefix = None if label is None else _input_label_start(label)
    # Walk a snapshot of the keys: one-shot values are popped during the loop.
    for key in tuple(self._delta_inputs):
        if prefix is not None and not key.startswith(prefix):
            continue
        entry = self._delta_inputs[key]
        if not callable(entry):
            try:
                init = init + entry
            except Exception as err:
                raise ValueError(
                    f'Error in delta input value {key}: {entry}\n'
                    f'Error: {err}'
                ) from err
            self._delta_inputs.pop(key)
            continue
        try:
            init = init + entry(*args, **kwargs)
        except Exception as err:
            raise ValueError(
                f'Error in delta input function {key}: {entry}\n'
                f'Error: {err}'
            ) from err
    return init
565
-
566
@property
def before_updates(self):
    """
    Registry of hooks meant to run before this module's update.

    Returns
    -------
    dict or None
        Mapping from hook key to callable, or None if ``_add_before_update``
        has never been called.
    """
    hooks = self._before_updates
    return hooks
583
-
584
@property
def after_updates(self):
    """
    Registry of hooks meant to run after this module's update.

    Returns
    -------
    dict or None
        Mapping from hook key to callable, or None if ``_add_after_update``
        has never been called.
    """
    hooks = self._after_updates
    return hooks
602
-
603
def _add_before_update(self, key: Any, fun: Callable):
    """
    Register ``fun`` under ``key`` in the before-update registry.

    The registry dictionary is created on first use.

    Parameters
    ----------
    key : Any
        Unique (hashable) identifier of the hook.
    fun : Callable
        Hook to store.

    Raises
    ------
    KeyError
        If ``key`` is already present in the registry.
    """
    if self._before_updates is None:
        self._before_updates = {}
    registry = self.before_updates
    if key in registry:
        raise KeyError(f'{key} has been registered in before_updates of {self}')
    registry[key] = fun
628
-
629
def _add_after_update(self, key: Any, fun: Callable):
    """
    Register ``fun`` under ``key`` in the after-update registry.

    The registry dictionary is created on first use.

    Parameters
    ----------
    key : Any
        Unique (hashable) identifier of the hook.
    fun : Callable
        Hook to store.

    Raises
    ------
    KeyError
        If ``key`` is already present in the registry.
    """
    if self._after_updates is None:
        self._after_updates = {}
    registry = self.after_updates
    if key in registry:
        raise KeyError(f'{key} has been registered in after_updates of {self}')
    registry[key] = fun
654
-
655
def _get_before_update(self, key: Any):
    """
    Return the before-update hook stored under ``key``.

    Parameters
    ----------
    key : Any
        Identifier the hook was registered with.

    Returns
    -------
    Callable
        The stored hook.

    Raises
    ------
    KeyError
        If no registry exists yet or ``key`` is not in it.
    """
    missing = self._before_updates is None or key not in self.before_updates
    if missing:
        raise KeyError(f'{key} is not registered in before_updates of {self}')
    return self.before_updates[key]
679
-
680
def _get_after_update(self, key: Any):
    """
    Look up a registered after-update hook by its key.

    Parameters
    ----------
    key : Any
        Identifier the hook was registered under.

    Returns
    -------
    Callable
        The hook stored under ``key``.

    Raises
    ------
    KeyError
        If no after-update hooks exist yet, or ``key`` is not among them.
    """
    registered = self._after_updates is not None and key in self.after_updates
    if not registered:
        raise KeyError(f'{key} is not registered in after_updates of {self}')
    return self.after_updates.get(key)
704
-
705
def _has_before_update(self, key: Any):
    """
    Tell whether a before-update hook is registered under ``key``.

    Parameters
    ----------
    key : Any
        Identifier to look for.

    Returns
    -------
    bool
        ``False`` when no before-update hooks exist yet or ``key`` is absent,
        ``True`` otherwise.
    """
    return self._before_updates is not None and key in self.before_updates
722
-
723
def _has_after_update(self, key: Any):
    """
    Tell whether an after-update hook is registered under ``key``.

    Parameters
    ----------
    key : Any
        Identifier to look for.

    Returns
    -------
    bool
        ``False`` when no after-update hooks exist yet or ``key`` is absent,
        ``True`` otherwise.
    """
    return self._after_updates is not None and key in self.after_updates
740
-
741
def __call__(self, *args, **kwargs):
    """
    Run the module: before-update hooks, then ``update``, then after-update hooks.

    Before-update hooks carrying a ``_receive_update_input`` attribute are
    called with the same ``*args``/``**kwargs`` as ``update``. All other
    before-update hooks are called without arguments.

    After-update hooks receive the value returned by ``update``, unless they
    carry a ``_not_receive_update_output`` attribute. In that case they are
    called without arguments.

    Returns
    -------
    Any
        Whatever ``update`` returned.
    """
    pre_hooks = self.before_updates
    if pre_hooks is not None:
        for hook in pre_hooks.values():
            if hasattr(hook, '_receive_update_input'):
                hook(*args, **kwargs)
            else:
                hook()

    output = self.update(*args, **kwargs)

    # Read the after-update registry only once ``update`` has finished.
    post_hooks = self.after_updates
    if post_hooks is not None:
        for hook in post_hooks.values():
            if hasattr(hook, '_not_receive_update_output'):
                hook()
            else:
                hook(output)
    return output
765
-
766
def prefetch(self, item: str) -> 'Prefetch':
    """
    Build a lazy reference to attribute ``item`` of this module.

    The attribute is not looked up here; it is resolved each time the returned
    :class:`Prefetch` is called. This makes it possible to reference states
    that will only be created later, e.g. during state initialisation.

    Parameters
    ----------
    item : str
        Name of the attribute or state to reference.

    Returns
    -------
    Prefetch
        Reference bound to this module and ``item``.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> v_ref = neuron.prefetch('V')              # reference to the voltage
    >>> v_value = v_ref()                         # current value
    >>> delayed_v = v_ref.delay.at(5.0 * u.ms)    # value 5 ms ago
    """
    reference = Prefetch(self, item)
    return reference
795
-
796
def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T:
    """
    Register a dynamics module as an after-update hook of this module.

    The registered module is called after this module's ``update`` completes,
    which creates a feed-forward link in the computation graph.

    Parameters
    ----------
    dyn : Union[ParamDescriber[T], T]
        The dynamics module to run after this module. It can be either:

        - a ``Dynamics`` instance, which is registered under its ``id()``;
        - a ``ParamDescriber`` whose ``cls`` is a ``Dynamics`` subclass. It is
          instantiated at most once per ``dyn.identifier``. When the
          descriptor supplies neither ``in_size`` nor positional arguments,
          it is instantiated with ``in_size=self.varshape``.

    Returns
    -------
    T
        The registered dynamics instance, allowing method chaining.

    Raises
    ------
    TypeError
        If ``dyn`` is neither a ``Dynamics`` instance nor a ``ParamDescriber``,
        or if the descriptor's ``cls`` is not a ``Dynamics`` subclass.

    Examples
    --------
    >>> import brainstate
    >>> n1 = brainstate.nn.LIF(10)
    >>> n2 = n1.align_pre(brainstate.nn.Expon.desc(n1.varshape))  # n2 runs after n1
    """
    if isinstance(dyn, Dynamics):
        self._add_after_update(id(dyn), dyn)
        return dyn

    if not isinstance(dyn, ParamDescriber):
        raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
    if not issubclass(dyn.cls, Dynamics):
        raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')

    key = dyn.identifier
    if not self._has_after_update(key):
        # Fall back to this module's shape only if the descriptor gave no size hint.
        explicit_args = 'in_size' in dyn.kwargs or len(dyn.args) > 0
        instance = dyn() if explicit_args else dyn(in_size=self.varshape)
        self._add_after_update(key, instance)
    return self._get_after_update(key)
842
-
843
def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
    """
    Reference attribute ``state`` of this module as it was ``delay_time`` ago.

    This is a shortcut for constructing a :class:`PrefetchDelayAt` bound to
    this module.

    Parameters
    ----------
    state : str
        Name of the state or variable to reference.
    delay_time : ArrayLike
        How far back to read, typically a time quantity (e.g. ``5.0 * u.ms``).
    init : Callable, optional
        Forwarded to :class:`PrefetchDelayAt`. It is described as providing a
        default value while the delayed state is not yet available.

    Returns
    -------
    PrefetchDelayAt
        Accessor for the variable at the given delay.
    """
    return PrefetchDelayAt(module=self, item=state, delay_time=delay_time, init=init)
862
-
863
def output_delay(self, *delay_time) -> 'OutputDelayAt':
    """
    Reference this module's ``update`` output as it was some time ago.

    Parameters
    ----------
    *delay_time
        Delay specification, typically a single time quantity such as
        ``5.0 * u.ms``. It is forwarded as a tuple to :class:`OutputDelayAt`.

    Returns
    -------
    OutputDelayAt
        Accessor for the module output at the given delay.
    """
    return OutputDelayAt(module=self, delay_time=delay_time)
879
-
880
-
881
class Prefetch(Node):
    """
    Lazy reference to an attribute (often a ``State``) of a module.

    Only the module and attribute name are stored. The attribute is resolved
    each time the reference is used, so it may be created after this object.

    Typical uses:

    - reading variables of a dynamics module that are defined later;
    - referencing states across module boundaries;
    - reaching delayed versions of a state through :attr:`delay`.

    Parameters
    ----------
    module : Module
        Module that holds, or will hold, the referenced attribute.
    item : str
        Name of the attribute to reference.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> v_reference = neuron.prefetch('V')
    >>> v_value = v_reference()
    >>> delay_ref = v_reference.delay.at(5.0 * u.ms)

    Notes
    -----
    Calling the object returns the current value. Use :attr:`delay` for past
    values.
    """

    def __init__(self, module: Dynamics, item: str):
        """
        Parameters
        ----------
        module : Module
            Module that holds, or will hold, the referenced attribute.
        item : str
            Name of the attribute to reference.
        """
        super().__init__()
        self.module = module
        self.item = item

    @property
    def delay(self):
        """
        Entry point for reading past values of the referenced attribute.

        Returns
        -------
        PrefetchDelay
            Object whose ``at(...)`` selects the delay.
        """
        return PrefetchDelay(self.module, self.item)

    def __call__(self, *args, **kwargs):
        """
        Return the current value of the referenced attribute.

        Any arguments are ignored.
        """
        return self.get_item_value()

    def get_item_value(self):
        """
        Return the current value of the referenced attribute.

        Returns
        -------
        Any
            ``attr.value`` when the attribute is a ``State``; otherwise the
            attribute itself.
        """
        target = _get_prefetch_item(self)
        if isinstance(target, State):
            return target.value
        return target

    def get_item(self):
        """
        Return the referenced attribute object itself, not its value.

        Returns
        -------
        Any
            The attribute found on the module (a ``State`` or anything else).
        """
        return _get_prefetch_item(self)
985
-
986
-
987
class PrefetchDelay(Node):
    """
    Intermediate handle for reading a module attribute with a delay.

    Only the module and attribute name are stored. Call :meth:`at` to choose
    the delay and obtain a :class:`PrefetchDelayAt` that actually reads values.

    Parameters
    ----------
    module : Dynamics
        Dynamics module holding the referenced state or variable.
    item : str
        Name of the state or variable to read with a delay.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
    >>> delayed_value = delayed_v()
    """

    def __init__(self, module: Dynamics, item: str):
        # Initialise the ``Node`` base class. ``Prefetch``, ``PrefetchDelayAt``
        # and ``OutputDelayAt`` all do this, so base-class setup also runs here.
        super().__init__()
        self.module = module
        self.item = item

    def at(self, *delay_time):
        """
        Select the delay at which the variable is read.

        Parameters
        ----------
        *delay_time
            Delay specification, typically a single time quantity such as
            ``5.0 * u.ms``. It is forwarded as a tuple to
            :class:`PrefetchDelayAt`.

        Returns
        -------
        PrefetchDelayAt
            Accessor for the variable at the given delay.
        """
        return PrefetchDelayAt(self.module, self.item, delay_time)
1032
-
1033
-
1034
class PrefetchDelayAt(Node):
    """
    Accessor for a module's state or variable at a specific delay.

    This is the final step of the prefetch-delay chain. On construction with
    a non-empty delay, a ``StateWithDelay`` handler for ``item`` is attached
    to the module as an after-update hook. The handler is created only once
    per module/item and is called without the update output. The delay is
    then registered on that handler.

    Parameters
    ----------
    module : Dynamics
        Dynamics module holding the referenced state or variable.
    item : str
        Name of the state or variable to read.
    delay_time : tuple or ArrayLike
        Delay specification, typically a time quantity (e.g. ``5.0 * u.ms``).
        A value that is not a tuple/list is wrapped into a 1-tuple. An empty
        tuple means "no delay": calls return the current value.
    init : Callable, optional
        Passed to ``StateWithDelay`` when the handler is first created. It is
        ignored if a handler for ``item`` already exists on the module.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
    >>> v_value = delayed_v()
    """

    def __init__(
        self,
        module: Dynamics,
        item: str,
        delay_time: Tuple,
        init: Callable = None
    ):
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        self.item = item
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        self.delay_time = delay_time
        if len(delay_time) > 0:
            key = _get_prefetch_delay_key(item)
            if not module._has_after_update(key):
                # One shared delay handler per (module, item); it does not need update's output.
                module._add_after_update(
                    key,
                    not_receive_update_output(
                        StateWithDelay(module, item, init=init)
                    )
                )
            self.state_delay: StateWithDelay = module._get_after_update(key)
            self.delay_info = self.state_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """
        Return the referenced variable at the configured delay.

        Any arguments are ignored. With an empty delay, the current value is
        returned: ``attr.value`` for a ``State``, otherwise the attribute itself
        (matching ``Prefetch.__call__``).
        """
        if len(self.delay_time) == 0:
            item = _get_prefetch_item(self)
            # Non-State attributes have no ``.value``; return them as-is.
            return item.value if isinstance(item, State) else item
        return self.state_delay.retrieve_at_step(*self.delay_info)
1114
-
1115
-
1116
class OutputDelayAt(Node):
    """
    Accessor for a module's ``update`` output at a specific delay.

    The first ``OutputDelayAt`` created for a module attaches a ``Delay`` to
    the module as an after-update hook. The delay has shape
    ``module.out_size`` and dtype ``environ.dftype()``, and it is fed the
    return value of ``update``. Later instances for the same module reuse
    that delay and only register their own delay time on it.

    Parameters
    ----------
    module : Dynamics
        Dynamics module whose output is delayed.
    delay_time : tuple or ArrayLike
        Delay specification forwarded to ``Delay.register_delay``. A value
        that is not a tuple/list (e.g. ``5.0 * u.ms``) is wrapped into a
        1-tuple.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Reference the neuron's output delayed by 5 ms
    >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
    >>> spike_value = delayed_spike()
    """

    def __init__(
        self,
        module: Dynamics,
        delay_time: Tuple,
    ):
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        # Accept a bare delay as well as a tuple/list, as PrefetchDelayAt does.
        # Otherwise ``register_delay(*delay_time)`` would try to unpack the scalar.
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        key = _get_output_delay_key()
        if not module._has_after_update(key):
            delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
            module._add_after_update(key, receive_update_output(delay))
        self.out_delay: Delay = module._get_after_update(key)
        self.delay_info = self.out_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """Return the module output at the registered delay. Arguments are ignored."""
        return self.out_delay.retrieve_at_step(*self.delay_info)
1160
-
1161
-
1162
def _get_prefetch_delay_key(item) -> str:
    """Return the ``after_updates`` key under which ``item``'s delay handler is stored."""
    return '{}-prefetch-delay'.format(item)
1164
-
1165
-
1166
def _get_output_delay_key() -> str:
    """Return the ``after_updates`` key under which a module's output delay is stored."""
    # Plain literal: the previous f-string had no placeholders (ruff F541).
    return 'output-delay'
1168
-
1169
-
1170
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
    """
    Resolve the attribute a prefetch object refers to.

    Parameters
    ----------
    target : Prefetch or PrefetchDelayAt
        Object exposing ``module`` and ``item`` (the attribute name).

    Returns
    -------
    Any
        ``getattr(target.module, target.item)``.

    Raises
    ------
    AttributeError
        If the attribute is missing. An attribute that exists but is ``None``
        is treated the same way.
    """
    value = getattr(target.module, target.item, None)
    if value is not None:
        return value
    raise AttributeError(f'The target {target.module} should have an `{target.item}` attribute.')
1175
-
1176
-
1177
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> StateWithDelay:
    """
    Fetch the delay handler registered for a prefetch target's attribute.

    Parameters
    ----------
    target : Prefetch, PrefetchDelay or PrefetchDelayAt
        Object exposing ``module`` (must be a ``Dynamics``) and ``item``.

    Returns
    -------
    StateWithDelay
        The after-update hook stored under the item's prefetch-delay key.
        The return annotation now names this type, since it is enforced below.

    Raises
    ------
    AssertionError
        If ``target.module`` is not a ``Dynamics`` instance.
    KeyError
        If no delay handler is registered for the item (raised by
        ``_get_after_update``).
    TypeError
        If the registered hook is not a ``StateWithDelay``.
    """
    assert isinstance(target.module, Dynamics), (
        f'The target module should be an instance '
        f'of Dynamics. But got {target.module}.'
    )
    delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
    if not isinstance(delay, StateWithDelay):
        raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
                        f'its delay. But got {delay}.')
    return delay
1187
-
1188
-
1189
def maybe_init_prefetch(target, *args, **kwargs):
    """
    Validate a prefetch target so that problems surface early.

    - :class:`Prefetch`: the referenced attribute is resolved, which raises if
      it is missing.
    - :class:`PrefetchDelay`: the item's delay handler is looked up, which
      raises if it is not registered or has the wrong type.
    - :class:`PrefetchDelayAt` and any other object: nothing is done. A
      ``PrefetchDelayAt`` already registers its delay in its constructor.

    Parameters
    ----------
    target : Any
        Object to check.
    *args, **kwargs
        Accepted for interface compatibility and ignored.

    Returns
    -------
    None
    """
    if isinstance(target, Prefetch):
        _get_prefetch_item(target)
    elif isinstance(target, PrefetchDelay):
        _get_prefetch_item_delay(target)
1229
-
1230
-
1231
class DynamicsGroup(Module):
    """
    A group of :py:class:`~.Module` objects whose relative update order does not matter.

    On ``update``, the nodes returned by
    ``self.nodes(allowed_hierarchy=(1, 1))`` are called in three passes:

    1. ``Projection`` nodes;
    2. ``Dynamics`` nodes;
    3. every other node (including delays).

    Args:
        children_as_tuple: Children given positionally; stored in ``layers_tuple``.
        children_as_dict: Children given by keyword; stored in ``layers_dict``.
    """

    __module__ = 'brainstate.nn'

    if not TYPE_CHECKING:
        def __init__(self, *children_as_tuple, **children_as_dict):
            super().__init__()
            self.layers_tuple = tuple(children_as_tuple)
            self.layers_dict = dict(children_as_dict)

    def update(self, *args, **kwargs):
        """
        Call each child once: projections first, then dynamics, then the rest.

        ``args``/``kwargs`` are accepted but not forwarded. Every child is
        called with no arguments.
        """
        projections, dynamics, rest = self.nodes(allowed_hierarchy=(1, 1)).split(Projection, Dynamics)
        for group in (projections, dynamics, rest):
            for child in group.values():
                child()
1267
-
1268
-
1269
def receive_update_output(cls: object):
    """
    Mark an after-update hook so that it receives the output of ``update``.

    This removes the ``_not_receive_update_output`` flag if present. As an
    entry of ``after_updates``, the object is then called as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun(ret)

    Returns the same object, so it can be used as a decorator.

    NOTE(review): if the flag exists only on the object's class (not on the
    object itself), ``delattr`` raises ``AttributeError``.
    """
    flag = '_not_receive_update_output'
    if hasattr(cls, flag):
        delattr(cls, flag)
    return cls
1284
-
1285
-
1286
def not_receive_update_output(cls: T) -> T:
    """
    Mark an after-update hook so that it is called without the ``update`` output.

    This sets ``_not_receive_update_output = True`` on the object. As an entry
    of ``after_updates``, the object is then called as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun()

    Returns the same object, so it can be used as a decorator.
    """
    setattr(cls, '_not_receive_update_output', True)
    return cls
1300
-
1301
-
1302
def receive_update_input(cls: object):
    """
    Mark a before-update hook so that it receives the inputs of ``update``.

    This sets ``_receive_update_input = True`` on the object. As an entry of
    ``before_updates``, the object is then called as::

        for fun in model.before_updates.values():
            fun(*args, **kwargs)
        model.update(*args, **kwargs)

    Returns the same object, so it can be used as a decorator.
    """
    setattr(cls, '_receive_update_input', True)
    return cls
1317
-
1318
-
1319
def not_receive_update_input(cls: object):
    """
    Mark a before-update hook so that it is called without the inputs of ``update``.

    This removes the ``_receive_update_input`` flag if present. As an entry of
    ``before_updates``, the object is then called as::

        for fun in model.before_updates.values():
            fun()
        model.update(*args, **kwargs)

    Returns the same object, so it can be used as a decorator.

    NOTE(review): if the flag exists only on the object's class (not on the
    object itself), ``delattr`` raises ``AttributeError``.
    """
    flag = '_receive_update_input'
    if hasattr(cls, flag):
        delattr(cls, flag)
    return cls
1334
-
1335
-
1336
def _input_label_start(label: str):
    """Return the prefix ``'<label> // '`` that namespaces labelled input keys."""
    return '{} // '.format(label)
1339
-
1340
-
1341
def _input_label_repr(name: str, label: Optional[str] = None):
    """
    Build the dictionary key for an input, optionally namespaced by ``label``.

    Returns ``name`` unchanged when ``label`` is ``None``. Otherwise returns
    ``'<label> // <name>'``.
    """
    if label is None:
        return name
    return _input_label_start(label) + str(name)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ """
20
+ All the basic dynamics classes for ``brainstate``.
21
+
22
+ For handling dynamical systems:
23
+
24
+ - ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
25
+ then ``Dynamics``, finally others.
26
+ - ``Projection``: The class for the synaptic projection.
27
+ - ``Dynamics``: The class for the dynamical system.
28
+
29
+ For handling the delays:
30
+
31
+ - ``Delay``: The class for all delays.
32
+ - ``DelayAccess``: The class for the delay access.
33
+
34
+ """
35
+
36
+ from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING, Tuple
37
+
38
+ import jax
39
+ import numpy as np
40
+
41
+ from brainstate import environ
42
+ from brainstate._state import State
43
+ from brainstate.graph import Node
44
+ from brainstate.mixin import ParamDescriber
45
+ from brainstate.typing import Size, ArrayLike
46
+ from ._delay import StateWithDelay, Delay
47
+ from ._module import Module
48
+
49
# Public names exported by ``from ... import *``.
__all__ = [
    'DynamicsGroup', 'Dynamics', 'Prefetch',
    'PrefetchDelay', 'PrefetchDelayAt', 'OutputDelayAt',
]

# Generic type variable for the type signatures in this module.
T = TypeVar('T')

# NOTE(review): not referenced in the visible part of this module; presumably a
# maximum order used elsewhere. Confirm before relying on it.
_max_order = 10
60
+
61
+
62
+ class Dynamics(Module):
63
+ """
64
+ Base class for implementing neural dynamics models in BrainState.
65
+
66
+ Dynamics classes represent the core computational units in neural simulations,
67
+ implementing the differential equations or update rules that govern neural activity.
68
+ This class provides infrastructure for managing neural populations, handling inputs,
69
+ and coordinating updates within the simulation framework.
70
+
71
+ The Dynamics class serves several key purposes:
72
+ 1. Managing neuron population geometry and size information
73
+ 2. Handling current and delta (instantaneous change) inputs to neurons
74
+ 3. Supporting before/after update hooks for computational dependencies
75
+ 4. Providing access to delayed state variables through the prefetch mechanism
76
+ 5. Establishing the execution order in neural network simulations
77
+
78
+ Parameters
79
+ ----------
80
+ in_size : Size
81
+ The geometry of the neuron population. Can be an integer (e.g., 10) for
82
+ 1D neuron arrays, or a tuple (e.g., (10, 10)) for multi-dimensional populations.
83
+ name : Optional[str], default=None
84
+ Optional name identifier for this dynamics module.
85
+
86
+ Attributes
87
+ ----------
88
+ in_size : tuple
89
+ The shape/geometry of the neuron population.
90
+ out_size : tuple
91
+ The output shape, typically matches in_size.
92
+ current_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
93
+ Dictionary of registered current input functions or arrays.
94
+ delta_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
95
+ Dictionary of registered delta input functions or arrays.
96
+ before_updates : Optional[Dict[Hashable, Callable]]
97
+ Dictionary of functions to call before the main update.
98
+ after_updates : Optional[Dict[Hashable, Callable]]
99
+ Dictionary of functions to call after the main update.
100
+
101
+ Notes
102
+ -----
103
+ In the BrainState execution sequence, Dynamics modules are updated after
104
+ Projection modules and before other module types, reflecting the natural
105
+ flow of information in neural systems.
106
+
107
+ There are several essential attributes:
108
+
109
+ - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of
110
+ neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes
111
+ a 3-dimensional neuron group.
112
+ - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
113
+ `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
114
+
115
+
116
+ See Also
117
+ --------
118
+ Module : Parent class providing base module functionality
119
+ Projection : Class for handling synaptic projections between neural populations
120
+ DynamicsGroup : Container for organizing multiple dynamics modules
121
+ """
122
+
123
+ __module__ = 'brainstate.nn'
124
+
125
+ graph_invisible_attrs = ()
126
+
127
+ # before updates
128
+ _before_updates: Optional[Dict[Hashable, Callable]]
129
+
130
+ # after updates
131
+ _after_updates: Optional[Dict[Hashable, Callable]]
132
+
133
+ # current inputs
134
+ _current_inputs: Optional[Dict[str, ArrayLike | Callable]]
135
+
136
+ # delta inputs
137
+ _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]
138
+
139
+ def __init__(
140
+ self,
141
+ in_size: Size,
142
+ name: Optional[str] = None,
143
+ ):
144
+ # initialize
145
+ super().__init__(name=name)
146
+
147
+ # geometry size of neuron population
148
+ if isinstance(in_size, (list, tuple)):
149
+ if len(in_size) <= 0:
150
+ raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
151
+ if not isinstance(in_size[0], (int, np.integer)):
152
+ raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
153
+ in_size = tuple(in_size)
154
+ elif isinstance(in_size, (int, np.integer)):
155
+ in_size = (in_size,)
156
+ else:
157
+ raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
158
+ self.in_size = in_size
159
+
160
+ # current inputs
161
+ self._current_inputs = None
162
+
163
+ # delta inputs
164
+ self._delta_inputs = None
165
+
166
+ # before updates
167
+ self._before_updates = None
168
+
169
+ # after updates
170
+ self._after_updates = None
171
+
172
+ # in-/out- size of neuron population
173
+ self.out_size = self.in_size
174
+
175
+ # def __pretty_repr_item__(self, name, value):
176
+ # if name in [
177
+ # '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
178
+ # '_in_size', '_out_size', '_name', '_mode',
179
+ # ]:
180
+ # return (name, value) if value is None else (name[1:], value) # skip the first `_`
181
+ # return super().__pretty_repr_item__(name, value)
182
+
183
+ @property
184
+ def varshape(self):
185
+ """
186
+ Get the shape of variables in the neuron group.
187
+
188
+ This property provides access to the geometry (shape) of the neuron population,
189
+ which determines how variables and states are structured.
190
+
191
+ Returns
192
+ -------
193
+ tuple
194
+ A tuple representing the dimensional shape of the neuron group,
195
+ matching the in_size parameter provided during initialization.
196
+
197
+ See Also
198
+ --------
199
+ in_size : The input geometry specification for the neuron group
200
+ """
201
+ return self.in_size
202
+
203
+ @property
204
+ def current_inputs(self):
205
+ """
206
+ Get the dictionary of current inputs registered with this dynamics model.
207
+
208
+ Current inputs represent direct input currents that flow into the model.
209
+
210
+ Returns
211
+ -------
212
+ dict or None
213
+ A dictionary mapping keys to current input functions or values,
214
+ or None if no current inputs have been registered.
215
+
216
+ See Also
217
+ --------
218
+ add_current_input : Register a new current input
219
+ sum_current_inputs : Apply and sum all current inputs
220
+ delta_inputs : Dictionary of instantaneous change inputs
221
+ """
222
+ return self._current_inputs
223
+
224
+ @property
225
+ def delta_inputs(self):
226
+ """
227
+ Get the dictionary of delta inputs registered with this dynamics model.
228
+
229
+ Delta inputs represent instantaneous changes to state variables (dX/dt).
230
+
231
+ Returns
232
+ -------
233
+ dict or None
234
+ A dictionary mapping keys to delta input functions or values,
235
+ or None if no delta inputs have been registered.
236
+
237
+ See Also
238
+ --------
239
+ add_delta_input : Register a new delta input
240
+ sum_delta_inputs : Apply and sum all delta inputs
241
+ current_inputs : Dictionary of direct current inputs
242
+ """
243
+ return self._delta_inputs
244
+
245
+ def add_current_input(
246
+ self,
247
+ key: str,
248
+ inp: Union[Callable, ArrayLike],
249
+ label: Optional[str] = None
250
+ ):
251
+ """
252
+ Add a current input function or array to the dynamics model.
253
+
254
+ Current inputs represent direct input currents that can be accessed during
255
+ model updates through the `sum_current_inputs()` method.
256
+
257
+ Parameters
258
+ ----------
259
+ key : str
260
+ Unique identifier for this current input. Used to retrieve or reference
261
+ the input later.
262
+ inp : Union[Callable, ArrayLike]
263
+ The input data or function that generates input data.
264
+ - If callable: Will be called during updates with arguments passed to `sum_current_inputs()`
265
+ - If array-like: Will be applied once and then automatically removed from available inputs
266
+ label : Optional[str], default=None
267
+ Optional grouping label for the input. When provided, allows selective
268
+ processing of inputs by label in `sum_current_inputs()`.
269
+
270
+ Raises
271
+ ------
272
+ ValueError
273
+ If the key has already been used for a different current input.
274
+
275
+ Notes
276
+ -----
277
+ - Inputs with the same label can be processed together using the `label`
278
+ parameter in `sum_current_inputs()`.
279
+ - Non-callable inputs are consumed when used (removed after first use).
280
+ - Callable inputs persist and can be called repeatedly.
281
+
282
+ See Also
283
+ --------
284
+ sum_current_inputs : Sum all current inputs matching a given label
285
+ add_delta_input : Add a delta input function or array
286
+ """
287
+ key = _input_label_repr(key, label)
288
+ if self._current_inputs is None:
289
+ self._current_inputs = dict()
290
+ if key in self._current_inputs:
291
+ if id(self._current_inputs[key]) != id(inp):
292
+ raise ValueError(f'Key "{key}" has been defined and used in the current inputs of {self}.')
293
+ self._current_inputs[key] = inp
294
+
295
+ def add_delta_input(
296
+ self,
297
+ key: str,
298
+ inp: Union[Callable, ArrayLike],
299
+ label: Optional[str] = None
300
+ ):
301
+ """
302
+ Add a delta input function or array to the dynamics model.
303
+
304
+ Delta inputs represent instantaneous changes to the model state (i.e., dX/dt contributions).
305
+ This method registers a function or array that provides delta inputs which will be
306
+ accessible during model updates through the `sum_delta_inputs()` method.
307
+
308
+ Parameters
309
+ ----------
310
+ key : str
311
+ Unique identifier for this delta input. Used to retrieve or reference
312
+ the input later.
313
+ inp : Union[Callable, ArrayLike]
314
+ The input data or function that generates input data.
315
+ - If callable: Will be called during updates with arguments passed to `sum_delta_inputs()`
316
+ - If array-like: Will be applied once and then automatically removed from available inputs
317
+ label : Optional[str], default=None
318
+ Optional grouping label for the input. When provided, allows selective
319
+ processing of inputs by label in `sum_delta_inputs()`.
320
+
321
+ Raises
322
+ ------
323
+ ValueError
324
+ If the key has already been used for a different delta input.
325
+
326
+ Notes
327
+ -----
328
+ - Inputs with the same label can be processed together using the `label`
329
+ parameter in `sum_delta_inputs()`.
330
+ - Non-callable inputs are consumed when used (removed after first use).
331
+ - Callable inputs persist and can be called repeatedly.
332
+
333
+ See Also
334
+ --------
335
+ sum_delta_inputs : Sum all delta inputs matching a given label
336
+ add_current_input : Add a current input function or array
337
+ """
338
+ key = _input_label_repr(key, label)
339
+ if self._delta_inputs is None:
340
+ self._delta_inputs = dict()
341
+ if key in self._delta_inputs:
342
+ if id(self._delta_inputs[key]) != id(inp):
343
+ raise ValueError(f'Key "{key}" has been defined and used.')
344
+ self._delta_inputs[key] = inp
345
+
346
+ def get_input(self, key: str):
347
+ """
348
+ Get a registered input function by its key.
349
+
350
+ Retrieves either a current input or a delta input function that was previously
351
+ registered with the given key. This method checks both current_inputs and
352
+ delta_inputs dictionaries for the specified key.
353
+
354
+ Parameters
355
+ ----------
356
+ key : str
357
+ The unique identifier used when the input function was registered.
358
+
359
+ Returns
360
+ -------
361
+ Callable or ArrayLike
362
+ The input function or array associated with the given key.
363
+
364
+ Raises
365
+ ------
366
+ ValueError
367
+ If no input function is found with the specified key in either
368
+ current_inputs or delta_inputs.
369
+
370
+ See Also
371
+ --------
372
+ add_current_input : Register a current input function
373
+ add_delta_input : Register a delta input function
374
+
375
+ Examples
376
+ --------
377
+ >>> model = Dynamics(10)
378
+ >>> model.add_current_input('stimulus', lambda t: np.sin(t))
379
+ >>> input_func = model.get_input('stimulus')
380
+ >>> input_func(0.5) # Returns sin(0.5)
381
+ """
382
+ if self._current_inputs is not None and key in self._current_inputs:
383
+ return self._current_inputs[key]
384
+ elif self._delta_inputs is not None and key in self._delta_inputs:
385
+ return self._delta_inputs[key]
386
+ else:
387
+ raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')
388
+
389
+ def sum_current_inputs(
390
+ self,
391
+ init: Any,
392
+ *args,
393
+ label: Optional[str] = None,
394
+ **kwargs
395
+ ):
396
+ """
397
+ Summarize all current inputs by applying and summing all registered current input functions.
398
+
399
+ This method iterates through all registered current input functions (from `.current_inputs`)
400
+ and applies them to calculate the total input current for the dynamics model. It adds all results
401
+ to the initial value provided.
402
+
403
+ Parameters
404
+ ----------
405
+ init : Any
406
+ The initial value to which all current inputs will be added.
407
+ *args
408
+ Variable length argument list passed to each current input function.
409
+ label : Optional[str], default=None
410
+ If provided, only process current inputs with this label prefix.
411
+ When None, process all current inputs regardless of label.
412
+ **kwargs
413
+ Arbitrary keyword arguments passed to each current input function.
414
+
415
+ Returns
416
+ -------
417
+ Any
418
+ The initial value plus all applicable current inputs summed together.
419
+
420
+ Notes
421
+ -----
422
+ - Non-callable current inputs are applied once and then automatically removed from
423
+ the current_inputs dictionary.
424
+ - Callable current inputs remain registered for subsequent calls.
425
+ - When a label is provided, only current inputs with keys starting with that label
426
+ are applied.
427
+ """
428
+ if self._current_inputs is None:
429
+ return init
430
+ if label is None:
431
+ filter_fn = lambda k: True
432
+ else:
433
+ label_repr = _input_label_start(label)
434
+ filter_fn = lambda k: k.startswith(label_repr)
435
+ for key in tuple(self._current_inputs.keys()):
436
+ if filter_fn(key):
437
+ out = self._current_inputs[key]
438
+ if callable(out):
439
+ try:
440
+ init = init + out(*args, **kwargs)
441
+ except Exception as e:
442
+ raise ValueError(
443
+ f'Error in current input value {key}: {out}\n'
444
+ f'Error: {e}'
445
+ ) from e
446
+ else:
447
+ try:
448
+ init = init + out
449
+ except Exception as e:
450
+ raise ValueError(
451
+ f'Error in current input value {key}: {out}\n'
452
+ f'Error: {e}'
453
+ ) from e
454
+ self._current_inputs.pop(key)
455
+ return init
456
+
457
+ def sum_delta_inputs(
458
+ self,
459
+ init: Any,
460
+ *args,
461
+ label: Optional[str] = None,
462
+ **kwargs
463
+ ):
464
+ """
465
+ Summarize all delta inputs by applying and summing all registered delta input functions.
466
+
467
+ This method iterates through all registered delta input functions (from `.delta_inputs`)
468
+ and applies them to calculate instantaneous changes to model states. It adds all results
469
+ to the initial value provided.
470
+
471
+ Parameters
472
+ ----------
473
+ init : Any
474
+ The initial value to which all delta inputs will be added.
475
+ *args
476
+ Variable length argument list passed to each delta input function.
477
+ label : Optional[str], default=None
478
+ If provided, only process delta inputs with this label prefix.
479
+ When None, process all delta inputs regardless of label.
480
+ **kwargs
481
+ Arbitrary keyword arguments passed to each delta input function.
482
+
483
+ Returns
484
+ -------
485
+ Any
486
+ The initial value plus all applicable delta inputs summed together.
487
+
488
+ Notes
489
+ -----
490
+ - Non-callable delta inputs are applied once and then automatically removed from
491
+ the delta_inputs dictionary.
492
+ - Callable delta inputs remain registered for subsequent calls.
493
+ - When a label is provided, only delta inputs with keys starting with that label
494
+ are applied.
495
+ """
496
+ if self._delta_inputs is None:
497
+ return init
498
+ if label is None:
499
+ filter_fn = lambda k: True
500
+ else:
501
+ label_repr = _input_label_start(label)
502
+ filter_fn = lambda k: k.startswith(label_repr)
503
+ for key in tuple(self._delta_inputs.keys()):
504
+ if filter_fn(key):
505
+ out = self._delta_inputs[key]
506
+ if callable(out):
507
+ try:
508
+ init = init + out(*args, **kwargs)
509
+ except Exception as e:
510
+ raise ValueError(
511
+ f'Error in delta input function {key}: {out}\n'
512
+ f'Error: {e}'
513
+ ) from e
514
+ else:
515
+ try:
516
+ init = init + out
517
+ except Exception as e:
518
+ raise ValueError(
519
+ f'Error in delta input value {key}: {out}\n'
520
+ f'Error: {e}'
521
+ ) from e
522
+ self._delta_inputs.pop(key)
523
+ return init
524
+
525
+ @property
526
+ def before_updates(self):
527
+ """
528
+ Get the dictionary of functions to execute before the module's update.
529
+
530
+ Returns
531
+ -------
532
+ dict or None
533
+ Dictionary mapping keys to callable functions that will be executed
534
+ before the main update, or None if no before updates are registered.
535
+
536
+ Notes
537
+ -----
538
+ Before updates are executed in the order they were registered whenever
539
+ the module is called via __call__.
540
+ """
541
+ return self._before_updates
542
+
543
+ @property
544
+ def after_updates(self):
545
+ """
546
+ Get the dictionary of functions to execute after the module's update.
547
+
548
+ Returns
549
+ -------
550
+ dict or None
551
+ Dictionary mapping keys to callable functions that will be executed
552
+ after the main update, or None if no after updates are registered.
553
+
554
+ Notes
555
+ -----
556
+ After updates are executed in the order they were registered whenever
557
+ the module is called via __call__, and may optionally receive the return
558
+ value from the update method.
559
+ """
560
+ return self._after_updates
561
+
562
+ def _add_before_update(self, key: Any, fun: Callable):
563
+ """
564
+ Register a function to be executed before the module's update.
565
+
566
+ Parameters
567
+ ----------
568
+ key : Any
569
+ A unique identifier for the update function.
570
+ fun : Callable
571
+ The function to execute before the module's update.
572
+
573
+ Raises
574
+ ------
575
+ KeyError
576
+ If the key is already registered in before_updates.
577
+
578
+ Notes
579
+ -----
580
+ Internal method used by the module system to register dependencies.
581
+ """
582
+ if self._before_updates is None:
583
+ self._before_updates = dict()
584
+ if key in self.before_updates:
585
+ raise KeyError(f'{key} has been registered in before_updates of {self}')
586
+ self.before_updates[key] = fun
587
+
588
+ def _add_after_update(self, key: Any, fun: Callable):
589
+ """
590
+ Register a function to be executed after the module's update.
591
+
592
+ Parameters
593
+ ----------
594
+ key : Any
595
+ A unique identifier for the update function.
596
+ fun : Callable
597
+ The function to execute after the module's update.
598
+
599
+ Raises
600
+ ------
601
+ KeyError
602
+ If the key is already registered in after_updates.
603
+
604
+ Notes
605
+ -----
606
+ Internal method used by the module system to register dependencies.
607
+ """
608
+ if self._after_updates is None:
609
+ self._after_updates = dict()
610
+ if key in self.after_updates:
611
+ raise KeyError(f'{key} has been registered in after_updates of {self}')
612
+ self.after_updates[key] = fun
613
+
614
+ def _get_before_update(self, key: Any):
615
+ """
616
+ Retrieve a registered before-update function by its key.
617
+
618
+ Parameters
619
+ ----------
620
+ key : Any
621
+ The identifier of the before-update function to retrieve.
622
+
623
+ Returns
624
+ -------
625
+ Callable
626
+ The registered before-update function.
627
+
628
+ Raises
629
+ ------
630
+ KeyError
631
+ If the key is not registered in before_updates or if before_updates is None.
632
+ """
633
+ if self._before_updates is None:
634
+ raise KeyError(f'{key} is not registered in before_updates of {self}')
635
+ if key not in self.before_updates:
636
+ raise KeyError(f'{key} is not registered in before_updates of {self}')
637
+ return self.before_updates.get(key)
638
+
639
+ def _get_after_update(self, key: Any):
640
+ """
641
+ Retrieve a registered after-update function by its key.
642
+
643
+ Parameters
644
+ ----------
645
+ key : Any
646
+ The identifier of the after-update function to retrieve.
647
+
648
+ Returns
649
+ -------
650
+ Callable
651
+ The registered after-update function.
652
+
653
+ Raises
654
+ ------
655
+ KeyError
656
+ If the key is not registered in after_updates or if after_updates is None.
657
+ """
658
+ if self._after_updates is None:
659
+ raise KeyError(f'{key} is not registered in after_updates of {self}')
660
+ if key not in self.after_updates:
661
+ raise KeyError(f'{key} is not registered in after_updates of {self}')
662
+ return self.after_updates.get(key)
663
+
664
+ def _has_before_update(self, key: Any):
665
+ """
666
+ Check if a before-update function is registered with the given key.
667
+
668
+ Parameters
669
+ ----------
670
+ key : Any
671
+ The identifier to check for in the before_updates dictionary.
672
+
673
+ Returns
674
+ -------
675
+ bool
676
+ True if the key is registered in before_updates, False otherwise.
677
+ """
678
+ if self._before_updates is None:
679
+ return False
680
+ return key in self.before_updates
681
+
682
+ def _has_after_update(self, key: Any):
683
+ """
684
+ Check if an after-update function is registered with the given key.
685
+
686
+ Parameters
687
+ ----------
688
+ key : Any
689
+ The identifier to check for in the after_updates dictionary.
690
+
691
+ Returns
692
+ -------
693
+ bool
694
+ True if the key is registered in after_updates, False otherwise.
695
+ """
696
+ if self._after_updates is None:
697
+ return False
698
+ return key in self.after_updates
699
+
700
+ def __call__(self, *args, **kwargs):
701
+ """
702
+ The shortcut to call ``update`` methods.
703
+ """
704
+
705
+ # ``before_updates``
706
+ if self.before_updates is not None:
707
+ for model in self.before_updates.values():
708
+ if hasattr(model, '_receive_update_input'):
709
+ model(*args, **kwargs)
710
+ else:
711
+ model()
712
+
713
+ # update the model self
714
+ ret = self.update(*args, **kwargs)
715
+
716
+ # ``after_updates``
717
+ if self.after_updates is not None:
718
+ for model in self.after_updates.values():
719
+ if hasattr(model, '_not_receive_update_output'):
720
+ model()
721
+ else:
722
+ model(ret)
723
+ return ret
724
+
725
+ def prefetch(self, item: str) -> 'Prefetch':
726
+ """
727
+ Create a reference to a state or variable that may not be initialized yet.
728
+
729
+ This method allows accessing module attributes or states before they are
730
+ fully defined, acting as a placeholder that will be resolved when called.
731
+ Particularly useful for creating references to variables that will be defined
732
+ during initialization or runtime.
733
+
734
+ Parameters
735
+ ----------
736
+ item : str
737
+ The name of the attribute or state to reference.
738
+
739
+ Returns
740
+ -------
741
+ Prefetch
742
+ A Prefetch object that provides access to the referenced item.
743
+
744
+ Examples
745
+ --------
746
+ >>> import brainstate
747
+ >>> import brainunit as u
748
+ >>> neuron = brainstate.nn.LIF(...)
749
+ >>> v_ref = neuron.prefetch('V') # Reference to voltage
750
+ >>> v_value = v_ref() # Get current value
751
+ >>> delayed_v = v_ref.delay.at(5.0 * u.ms) # Get delayed value
752
+ """
753
+ return Prefetch(self, item)
754
+
755
+ def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T:
756
+ """
757
+ Registers a dynamics module to execute after this module.
758
+
759
+ This method establishes a sequential execution relationship where the specified
760
+ dynamics module will be called after this module completes its update. This
761
+ creates a feed-forward connection in the computational graph.
762
+
763
+ Parameters
764
+ ----------
765
+ dyn : Union[ParamDescriber[T], T]
766
+ The dynamics module to be executed after this module. Can be either:
767
+ - An instance of Dynamics
768
+ - A ParamDescriber that can instantiate a Dynamics object
769
+
770
+ Returns
771
+ -------
772
+ T
773
+ The dynamics module that was registered, allowing for method chaining.
774
+
775
+ Raises
776
+ ------
777
+ TypeError
778
+ If the input is not a Dynamics instance or a ParamDescriber that creates
779
+ a Dynamics instance.
780
+
781
+ Examples
782
+ --------
783
+ >>> import brainstate
784
+ >>> n1 = brainstate.nn.LIF(10)
785
+ >>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1
786
+ """
787
+ if isinstance(dyn, Dynamics):
788
+ self._add_after_update(id(dyn), dyn)
789
+ return dyn
790
+ elif isinstance(dyn, ParamDescriber):
791
+ if not issubclass(dyn.cls, Dynamics):
792
+ raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
793
+ if not self._has_after_update(dyn.identifier):
794
+ self._add_after_update(
795
+ dyn.identifier,
796
+ dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
797
+ )
798
+ return self._get_after_update(dyn.identifier)
799
+ else:
800
+ raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
801
+
802
+ def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
803
+ """
804
+ Create a reference to a delayed state or variable in the module.
805
+
806
+ This method simplifies the process of accessing a delayed version of a state or variable
807
+ within the module. It first creates a prefetch reference to the specified state,
808
+ then specifies the delay time for accessing this state.
809
+
810
+ Args:
811
+ state (str): The name of the state or variable to reference.
812
+ delay_time (ArrayLike): The amount of time to delay the variable access,
813
+ typically in time units (e.g., milliseconds).
814
+ init (Callable, optional): An optional initialization function to provide
815
+ a default value if the delayed state is not yet available.
816
+
817
+ Returns:
818
+ PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
819
+ """
820
+ return PrefetchDelayAt(self, state, delay_time, init=init)
821
+
822
+ def output_delay(self, *delay_time) -> 'OutputDelayAt':
823
+ """
824
+ Create a reference to the delayed output of the module.
825
+
826
+ This method simplifies the process of accessing a delayed version of the module's output.
827
+ It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
828
+ at the specified delay time.
829
+
830
+ Args:
831
+ delay (Optional[ArrayLike]): The amount of time to delay the output access,
832
+ typically in time units (e.g., milliseconds). Defaults to None.
833
+
834
+ Returns:
835
+ OutputDelayAt: An object that provides access to the module's output at the specified delay time.
836
+ """
837
+ return OutputDelayAt(self, delay_time)
838
+
839
+
840
+ class Prefetch(Node):
841
+ """
842
+ Prefetch a state or variable in a module before it is initialized.
843
+
844
+
845
+ This class provides a mechanism to reference a module's state or attribute
846
+ that may not have been initialized yet. It acts as a placeholder or reference
847
+ that will be resolved when called.
848
+
849
+ Use cases:
850
+ - Access variables within dynamics modules that will be defined later
851
+ - Create references to states across module boundaries
852
+ - Enable access to delayed states through the `.delay` property
853
+
854
+ Parameters
855
+ ----------
856
+ module : Module
857
+ The module that contains or will contain the referenced item.
858
+ item : str
859
+ The attribute name of the state or variable to prefetch.
860
+
861
+ Examples
862
+ --------
863
+ >>> import brainstate
864
+ >>> import brainunit as u
865
+ >>> neuron = brainstate.nn.LIF(...)
866
+ >>> v_reference = neuron.prefetch('V') # Reference to voltage before initialization
867
+ >>> v_value = v_reference() # Get the current value
868
+ >>> delay_ref = v_reference.delay.at(5.0 * u.ms) # Reference voltage delayed by 5ms
869
+
870
+ Notes
871
+ -----
872
+ When called, this class retrieves the current value of the referenced item.
873
+ Use the `.delay` property to access delayed versions of the state.
874
+
875
+ """
876
+
877
+ def __init__(self, module: Dynamics, item: str):
878
+ """
879
+ Initialize a Prefetch object.
880
+
881
+ Parameters
882
+ ----------
883
+ module : Module
884
+ The module that contains or will contain the referenced item.
885
+ item : str
886
+ The attribute name of the state or variable to prefetch.
887
+ """
888
+ super().__init__()
889
+ self.module = module
890
+ self.item = item
891
+
892
+ @property
893
+ def delay(self):
894
+ """
895
+ Access delayed versions of the prefetched item.
896
+
897
+ Returns
898
+ -------
899
+ PrefetchDelay
900
+ An object that provides access to delayed versions of the prefetched item.
901
+ """
902
+ return PrefetchDelay(self.module, self.item)
903
+ # return PrefetchDelayAt(self.module, self.item, time)
904
+
905
+ def __call__(self, *args, **kwargs):
906
+ """
907
+ Get the current value of the prefetched item.
908
+
909
+ Returns
910
+ -------
911
+ Any
912
+ The current value of the referenced item. If the item is a State object,
913
+ returns its value attribute, otherwise returns the item itself.
914
+ """
915
+ item = _get_prefetch_item(self)
916
+ return item.value if isinstance(item, State) else item
917
+
918
+ def get_item_value(self):
919
+ """
920
+ Get the current value of the prefetched item.
921
+
922
+ Similar to __call__, but explicitly named for clarity.
923
+
924
+ Returns
925
+ -------
926
+ Any
927
+ The current value of the referenced item. If the item is a State object,
928
+ returns its value attribute, otherwise returns the item itself.
929
+ """
930
+ item = _get_prefetch_item(self)
931
+ return item.value if isinstance(item, State) else item
932
+
933
+ def get_item(self):
934
+ """
935
+ Get the referenced item object itself, not its value.
936
+
937
+ Returns
938
+ -------
939
+ Any
940
+ The actual referenced item from the module, which could be a State
941
+ object or any other attribute.
942
+ """
943
+ return _get_prefetch_item(self)
944
+
945
+
946
+ class PrefetchDelay(Node):
947
+ """
948
+ Provides access to delayed versions of a prefetched state or variable.
949
+
950
+ This class acts as an intermediary for accessing delayed values of module variables.
951
+ It doesn't retrieve values directly but provides methods to specify the delay time
952
+ via the `at()` method.
953
+
954
+ Parameters
955
+ ----------
956
+ module : Dynamics
957
+ The dynamics module that contains the referenced state or variable.
958
+ item : str
959
+ The name of the state or variable to access with delay.
960
+
961
+ Examples
962
+ --------
963
+ >>> import brainstate
964
+ >>> import brainunit as u
965
+ >>> neuron = brainstate.nn.LIF(10)
966
+ >>> # Access voltage delayed by 5ms
967
+ >>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
968
+ >>> delayed_value = delayed_v() # Get the delayed value
969
+ """
970
+
971
+ def __init__(self, module: Dynamics, item: str):
972
+ self.module = module
973
+ self.item = item
974
+
975
+ def at(self, *delay_time):
976
+ """
977
+ Specifies the delay time for accessing the variable.
978
+
979
+ Parameters
980
+ ----------
981
+ time : ArrayLike
982
+ The amount of time to delay the variable access, typically in time units
983
+ (e.g., milliseconds).
984
+
985
+ Returns
986
+ -------
987
+ PrefetchDelayAt
988
+ An object that provides access to the variable at the specified delay time.
989
+ """
990
+ return PrefetchDelayAt(self.module, self.item, delay_time)
991
+
992
+
993
+ class PrefetchDelayAt(Node):
994
+ """
995
+ Provides access to a specific delayed state or variable value at the specific time.
996
+
997
+ This class represents the final step in the prefetch delay chain, providing
998
+ actual access to state values at a specific delay time. It converts the
999
+ specified time delay into steps and registers the delay with the appropriate
1000
+ StateWithDelay handler.
1001
+
1002
+ Parameters
1003
+ ----------
1004
+ module : Dynamics
1005
+ The dynamics module that contains the referenced state or variable.
1006
+ item : str
1007
+ The name of the state or variable to access with delay.
1008
+ time : ArrayLike
1009
+ The amount of time to delay access by, typically in time units (e.g., milliseconds).
1010
+
1011
+ Examples
1012
+ --------
1013
+ >>> import brainstate
1014
+ >>> import brainunit as u
1015
+ >>> neuron = brainstate.nn.LIF(10)
1016
+ >>> # Create a reference to voltage delayed by 5ms
1017
+ >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
1018
+ >>> # Get the delayed value
1019
+ >>> v_value = delayed_v()
1020
+ """
1021
+
1022
+ def __init__(
1023
+ self,
1024
+ module: Dynamics,
1025
+ item: str,
1026
+ delay_time: Tuple,
1027
+ init: Callable = None
1028
+ ):
1029
+ """
1030
+ Initialize a PrefetchDelayAt object.
1031
+
1032
+ Parameters
1033
+ ----------
1034
+ module : Dynamics
1035
+ The dynamics module that contains the referenced state or variable.
1036
+ item : str
1037
+ The name of the state or variable to access with delay.
1038
+ delay_time : Tuple
1039
+ The amount of time to delay access by, typically in time units (e.g., milliseconds).
1040
+ """
1041
+ super().__init__()
1042
+ assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
1043
+ self.module = module
1044
+ self.item = item
1045
+ if not isinstance(delay_time, (tuple, list)):
1046
+ delay_time = (delay_time,)
1047
+ self.delay_time = delay_time
1048
+ if len(delay_time) > 0:
1049
+ key = _get_prefetch_delay_key(item)
1050
+ if not module._has_after_update(key):
1051
+ module._add_after_update(
1052
+ key,
1053
+ not_receive_update_output(
1054
+ StateWithDelay(module, item, init=init)
1055
+ )
1056
+ )
1057
+ self.state_delay: StateWithDelay = module._get_after_update(key)
1058
+ self.delay_info = self.state_delay.register_delay(*delay_time)
1059
+
1060
+ def __call__(self, *args, **kwargs):
1061
+ """
1062
+ Retrieve the value of the state at the specified delay time.
1063
+
1064
+ Returns
1065
+ -------
1066
+ Any
1067
+ The value of the state or variable at the specified delay time.
1068
+ """
1069
+ if len(self.delay_time) == 0:
1070
+ return _get_prefetch_item(self).value
1071
+ else:
1072
+ return self.state_delay.retrieve_at_step(*self.delay_info)
1073
+
1074
+
1075
+ class OutputDelayAt(Node):
1076
+ """
1077
+ Provides access to a specific delayed state or variable value at the specific time.
1078
+
1079
+ This class represents the final step in the prefetch delay chain, providing
1080
+ actual access to state values at a specific delay time. It converts the
1081
+ specified time delay into steps and registers the delay with the appropriate
1082
+ StateWithDelay handler.
1083
+
1084
+ Parameters
1085
+ ----------
1086
+ module : Dynamics
1087
+ The dynamics module that contains the referenced state or variable.
1088
+ time : ArrayLike
1089
+ The amount of time to delay access by, typically in time units (e.g., milliseconds).
1090
+
1091
+ Examples
1092
+ --------
1093
+ >>> import brainstate
1094
+ >>> import brainunit as u
1095
+ >>> neuron = brainstate.nn.LIF(10)
1096
+ >>> # Create a reference to voltage delayed by 5ms
1097
+ >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
1098
+ >>> # Get the delayed value
1099
+ >>> v_value = delayed_spike()
1100
+ """
1101
+
1102
+ def __init__(
1103
+ self,
1104
+ module: Dynamics,
1105
+ delay_time: Tuple,
1106
+ ):
1107
+ super().__init__()
1108
+ assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
1109
+ self.module = module
1110
+ key = _get_output_delay_key()
1111
+ if not module._has_after_update(key):
1112
+ delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
1113
+ module._add_after_update(key, receive_update_output(delay))
1114
+ self.out_delay: Delay = module._get_after_update(key)
1115
+ self.delay_info = self.out_delay.register_delay(*delay_time)
1116
+
1117
+ def __call__(self, *args, **kwargs):
1118
+ return self.out_delay.retrieve_at_step(*self.delay_info)
1119
+
1120
+
1121
+ def _get_prefetch_delay_key(item) -> str:
1122
+ return f'{item}-prefetch-delay'
1123
+
1124
+
1125
+ def _get_output_delay_key() -> str:
1126
+ return f'output-delay'
1127
+
1128
+
1129
+ def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
1130
+ item = getattr(target.module, target.item, None)
1131
+ if item is None:
1132
+ raise AttributeError(f'The target {target.module} should have an `{target.item}` attribute.')
1133
+ return item
1134
+
1135
+
1136
+ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
1137
+ assert isinstance(target.module, Dynamics), (
1138
+ f'The target module should be an instance '
1139
+ f'of Dynamics. But got {target.module}.'
1140
+ )
1141
+ delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
1142
+ if not isinstance(delay, StateWithDelay):
1143
+ raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
1144
+ f'its delay. But got {delay}.')
1145
+ return delay
1146
+
1147
+
1148
+ def maybe_init_prefetch(target, *args, **kwargs):
1149
+ """
1150
+ Initialize a prefetch target if needed, based on its type.
1151
+
1152
+ This function ensures that prefetch references are properly initialized
1153
+ and ready to use. It handles different types of prefetch objects by
1154
+ performing the appropriate initialization action:
1155
+ - For :py:class:`Prefetch` objects: retrieves the referenced item
1156
+ - For :py:class:`PrefetchDelay` objects: retrieves the delay handler
1157
+ - For :py:class:`PrefetchDelayAt` objects: registers the specified delay
1158
+
1159
+ Parameters
1160
+ ----------
1161
+ target : Union[Prefetch, PrefetchDelay, PrefetchDelayAt]
1162
+ The prefetch target to initialize.
1163
+ *args : Any
1164
+ Additional positional arguments (unused).
1165
+ **kwargs : Any
1166
+ Additional keyword arguments (unused).
1167
+
1168
+ Returns
1169
+ -------
1170
+ None
1171
+ This function performs initialization side effects only.
1172
+
1173
+ Notes
1174
+ -----
1175
+ This function is typically called internally when prefetched references
1176
+ are used to ensure they are properly set up before access.
1177
+ """
1178
+ if isinstance(target, Prefetch):
1179
+ _get_prefetch_item(target)
1180
+
1181
+ elif isinstance(target, PrefetchDelay):
1182
+ _get_prefetch_item_delay(target)
1183
+
1184
+ elif isinstance(target, PrefetchDelayAt):
1185
+ pass
1186
+ # delay = _get_prefetch_item_delay(target)
1187
+ # delay.register_delay(*target.delay_time)
1188
+
1189
+
1190
+ DynamicsGroup = Module
1191
+
1192
+
1193
def receive_update_output(cls: object):
    """
    Decorator that lets an "after update" callable receive ``update``'s output.

    This undoes :func:`not_receive_update_output` by removing the
    ``_not_receive_update_output`` marker attribute from ``cls`` when the
    marker is visible on it. That is, ``aft_update`` callables are meant to
    receive the return value of ``update``::

        ret = model.update(*args, **kwargs)
        for fun in model.aft_updates:
            fun(ret)

    Parameters
    ----------
    cls : object
        The class or instance to un-mark; modified in place.

    Returns
    -------
    object
        ``cls`` itself, so the function works as a decorator.

    Notes
    -----
    ``delattr`` only removes an attribute stored on ``cls`` itself. If the
    marker is merely inherited (for example from a decorated base class),
    ``AttributeError`` is raised.
    """
    marker = '_not_receive_update_output'
    if not hasattr(cls, marker):
        # No marker present: nothing to remove.
        return cls
    delattr(cls, marker)
    return cls
+
1209
+
1210
def not_receive_update_output(cls: T) -> T:
    """
    Decorator that stops an "after update" callable from receiving ``update``'s output.

    Sets the ``_not_receive_update_output`` marker attribute on ``cls`` to
    ``True``. That is, ``aft_update`` callables marked this way are meant to
    be called without arguments::

        ret = model.update(*args, **kwargs)
        for fun in model.aft_updates:
            fun()

    Parameters
    ----------
    cls : T
        The class or instance to mark; modified in place.

    Returns
    -------
    T
        ``cls`` itself, so the function works as a decorator.
    """
    setattr(cls, '_not_receive_update_output', True)
    return cls
+
1225
+
1226
def receive_update_input(cls: object):
    """
    Decorator that lets a "before update" callable receive ``update``'s input.

    Sets the ``_receive_update_input`` marker attribute on ``cls`` to
    ``True``. That is, ``bef_update`` callables marked this way are meant to
    receive the same arguments as ``update``::

        for fun in model.bef_updates:
            fun(*args, **kwargs)
        model.update(*args, **kwargs)

    Parameters
    ----------
    cls : object
        The class or instance to mark; modified in place.

    Returns
    -------
    object
        ``cls`` itself, so the function works as a decorator.
    """
    setattr(cls, '_receive_update_input', True)
    return cls
+
1242
+
1243
def not_receive_update_input(cls: object):
    """
    Decorator that stops a "before update" callable from receiving ``update``'s input.

    This undoes :func:`receive_update_input` by removing the
    ``_receive_update_input`` marker attribute from ``cls`` when the marker
    is visible on it. That is, ``bef_update`` callables are meant to be
    called without arguments::

        for fun in model.bef_updates:
            fun()
        model.update(*args, **kwargs)

    Parameters
    ----------
    cls : object
        The class or instance to un-mark; modified in place.

    Returns
    -------
    object
        ``cls`` itself, so the function works as a decorator.

    Notes
    -----
    ``delattr`` only removes an attribute stored on ``cls`` itself. If the
    marker is merely inherited (for example from a decorated base class),
    ``AttributeError`` is raised.
    """
    marker = '_receive_update_input'
    if not hasattr(cls, marker):
        # No marker present: nothing to remove.
        return cls
    delattr(cls, marker)
    return cls
+
1259
+
1260
+ def _input_label_start(label: str):
1261
+ # unify the input label repr.
1262
+ return f'{label} // '
1263
+
1264
+
1265
def _input_label_repr(name: str, label: Optional[str] = None):
    # Build an input key: the bare ``name`` when no label is given, otherwise
    # "<label> // <name>", using the shared prefix from ``_input_label_start``.
    if label is None:
        return name
    prefix = _input_label_start(label)
    return prefix + str(name)