brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,1343 +1,1343 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- """
20
- All the basic dynamics class for the ``brainstate``.
21
-
22
- For handling dynamical systems:
23
-
24
- - ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
25
- then ``Dynamics``, finally others.
26
- - ``Projection``: The class for the synaptic projection.
27
- - ``Dynamics``: The class for the dynamical system.
28
-
29
- For handling the delays:
30
-
31
- - ``Delay``: The class for all delays.
32
- - ``DelayAccess``: The class for the delay access.
33
-
34
- """
35
-
36
- from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING, Tuple
37
-
38
- import jax
39
- import numpy as np
40
-
41
- from brainstate import environ
42
- from brainstate._state import State
43
- from brainstate.graph import Node
44
- from brainstate.mixin import ParamDescriber
45
- from brainstate.typing import Size, ArrayLike
46
- from ._delay import StateWithDelay, Delay
47
- from ._module import Module
48
-
49
# Public names exported by ``from <this module> import *``.
__all__ = [
    'DynamicsGroup',
    'Projection',
    'Dynamics',
    'Prefetch',
    'PrefetchDelay',
    'PrefetchDelayAt',
    'OutputDelayAt',
]

# Generic type variable used in type signatures such as ``align_pre``.
T = TypeVar('T')

# Module-level constant; its consumer is not visible in this excerpt.
# NOTE(review): presumably an upper bound on some nesting/order — confirm usage.
_max_order = 10
56
-
57
-
58
class Projection(Module):
    """
    Base class for synaptic projections between neural populations.

    A projection transforms the signal leaving a source population before it
    reaches its target. In BrainState's execution order, projections are
    updated before ``Dynamics`` modules, mirroring synaptic transmission
    preceding neural integration.

    ``Projection`` carries no update logic of its own: :meth:`update` simply
    forwards the call to the child nodes returned by
    ``self.nodes(allowed_hierarchy=(1, 1))`` (presumably the direct children).
    Subclasses may override :meth:`update` to implement concrete projection
    behaviour (dense or sparse connectivity, plasticity rules, ...).

    No ``__init__`` is defined here, so the constructor signature is that of
    :class:`Module`.

    Raises
    ------
    ValueError
        If :meth:`update` is called on a projection that has no child nodes.
    """
    __module__ = 'brainstate.nn'

    def update(self, *args, **kwargs):
        """
        Call every child node with the given arguments, in iteration order.

        Raises
        ------
        ValueError
            If there are no child nodes to delegate to.
        """
        children = tuple(self.nodes(allowed_hierarchy=(1, 1)).values())
        if not children:
            raise ValueError('Do not implement the update() function.')
        for child in children:
            child(*args, **kwargs)
101
-
102
-
103
- class Dynamics(Module):
104
- """
105
- Base class for implementing neural dynamics models in BrainState.
106
-
107
- Dynamics classes represent the core computational units in neural simulations,
108
- implementing the differential equations or update rules that govern neural activity.
109
- This class provides infrastructure for managing neural populations, handling inputs,
110
- and coordinating updates within the simulation framework.
111
-
112
- The Dynamics class serves several key purposes:
113
- 1. Managing neuron population geometry and size information
114
- 2. Handling current and delta (instantaneous change) inputs to neurons
115
- 3. Supporting before/after update hooks for computational dependencies
116
- 4. Providing access to delayed state variables through the prefetch mechanism
117
- 5. Establishing the execution order in neural network simulations
118
-
119
- Parameters
120
- ----------
121
- in_size : Size
122
- The geometry of the neuron population. Can be an integer (e.g., 10) for
123
- 1D neuron arrays, or a tuple (e.g., (10, 10)) for multi-dimensional populations.
124
- name : Optional[str], default=None
125
- Optional name identifier for this dynamics module.
126
-
127
- Attributes
128
- ----------
129
- in_size : tuple
130
- The shape/geometry of the neuron population.
131
- out_size : tuple
132
- The output shape, typically matches in_size.
133
- current_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
134
- Dictionary of registered current input functions or arrays.
135
- delta_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
136
- Dictionary of registered delta input functions or arrays.
137
- before_updates : Optional[Dict[Hashable, Callable]]
138
- Dictionary of functions to call before the main update.
139
- after_updates : Optional[Dict[Hashable, Callable]]
140
- Dictionary of functions to call after the main update.
141
-
142
- Notes
143
- -----
144
- In the BrainState execution sequence, Dynamics modules are updated after
145
- Projection modules and before other module types, reflecting the natural
146
- flow of information in neural systems.
147
-
148
- There are several essential attributes:
149
-
150
- - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of
151
- neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes
152
- a 3-dimensional neuron group.
153
- - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
154
- `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
155
-
156
-
157
- See Also
158
- --------
159
- Module : Parent class providing base module functionality
160
- Projection : Class for handling synaptic projections between neural populations
161
- DynamicsGroup : Container for organizing multiple dynamics modules
162
- """
163
-
164
- __module__ = 'brainstate.nn'
165
-
166
- graph_invisible_attrs = ('_before_updates', '_after_updates', '_current_inputs', '_delta_inputs')
167
-
168
- # before updates
169
- _before_updates: Optional[Dict[Hashable, Callable]]
170
-
171
- # after updates
172
- _after_updates: Optional[Dict[Hashable, Callable]]
173
-
174
- # current inputs
175
- _current_inputs: Optional[Dict[str, ArrayLike | Callable]]
176
-
177
- # delta inputs
178
- _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]
179
-
180
def __init__(
    self,
    in_size: Size,
    name: Optional[str] = None,
):
    """
    Initialize a dynamics module for a population of the given geometry.

    Parameters
    ----------
    in_size : Size
        Geometry of the population: a single integer (e.g. ``10``) or a
        non-empty tuple/list of integers (e.g. ``(10, 10)``). It is stored
        normalized to a tuple.
    name : Optional[str], default=None
        Optional name forwarded to :class:`Module`.

    Raises
    ------
    ValueError
        If ``in_size`` is neither an integer nor a non-empty tuple/list whose
        elements are all integers.
    """
    super().__init__(name=name)

    # Normalize the population geometry to a tuple of integers.
    if isinstance(in_size, (list, tuple)):
        # Validate every dimension, not only the first one, so malformed
        # shapes such as ``(10, 2.5)`` or ``(10, 'a')`` are rejected here
        # instead of failing later in an unrelated place.
        if len(in_size) == 0 or not all(isinstance(dim, (int, np.integer)) for dim in in_size):
            raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
        in_size = tuple(in_size)
    elif isinstance(in_size, (int, np.integer)):
        in_size = (in_size,)
    else:
        raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
    self.in_size = in_size

    # Input and hook registries start empty (None) and are created lazily
    # by the corresponding ``add_*`` / ``_add_*`` methods.
    self._current_inputs = None
    self._delta_inputs = None
    self._before_updates = None
    self._after_updates = None

    # The output geometry defaults to the input geometry.
    self.out_size = self.in_size
215
-
216
- # def __pretty_repr_item__(self, name, value):
217
- # if name in [
218
- # '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
219
- # '_in_size', '_out_size', '_name', '_mode',
220
- # ]:
221
- # return (name, value) if value is None else (name[1:], value) # skip the first `_`
222
- # return super().__pretty_repr_item__(name, value)
223
-
224
@property
def varshape(self):
    """
    Shape of the per-neuron state variables of this population.

    Returns
    -------
    tuple
        The normalized population geometry; an alias of ``in_size``
        (for example ``(10,)`` or ``(10, 10)``).
    """
    shape = self.in_size
    return shape
243
-
244
@property
def current_inputs(self):
    """
    Registry of current inputs attached to this module.

    Returns
    -------
    dict or None
        Mapping from stored (label-qualified) key to a callable or value,
        as registered via ``add_current_input``; ``None`` until the first
        registration.

    See Also
    --------
    add_current_input : Register a current input.
    sum_current_inputs : Evaluate and sum the registered current inputs.
    delta_inputs : Registry of delta inputs.
    """
    registry = self._current_inputs
    return registry
264
-
265
@property
def delta_inputs(self):
    """
    Registry of delta inputs attached to this module.

    Returns
    -------
    dict or None
        Mapping from stored (label-qualified) key to a callable or value,
        as registered via ``add_delta_input``; ``None`` until the first
        registration.

    See Also
    --------
    add_delta_input : Register a delta input.
    sum_delta_inputs : Evaluate and sum the registered delta inputs.
    current_inputs : Registry of current inputs.
    """
    registry = self._delta_inputs
    return registry
285
-
286
def add_current_input(
    self,
    key: str,
    inp: Union[Callable, ArrayLike],
    label: Optional[str] = None
):
    """
    Register a current input under ``key``, optionally grouped by ``label``.

    The key actually stored is ``_input_label_repr(key, label)``. Registered
    inputs are consumed by ``sum_current_inputs``: callables are invoked on
    every call and stay registered, while non-callable values are added once
    and then removed.

    Parameters
    ----------
    key : str
        Identifier of the input; combined with ``label`` to form the stored key.
    inp : Union[Callable, ArrayLike]
        A callable evaluated with the arguments passed to
        ``sum_current_inputs``, or a one-shot value.
    label : Optional[str], default=None
        Optional group label, allowing ``sum_current_inputs`` to process
        only the inputs of one group.

    Raises
    ------
    ValueError
        If the stored key is already bound to a different object.
        Registering the very same object again under the same key is allowed.

    See Also
    --------
    sum_current_inputs : Sum the registered current inputs.
    add_delta_input : Register a delta input.
    """
    full_key = _input_label_repr(key, label)
    if self._current_inputs is None:
        self._current_inputs = dict()
    registry = self._current_inputs
    if full_key in registry and registry[full_key] is not inp:
        raise ValueError(f'Key "{full_key}" has been defined and used in the current inputs of {self}.')
    registry[full_key] = inp
335
-
336
def add_delta_input(
    self,
    key: str,
    inp: Union[Callable, ArrayLike],
    label: Optional[str] = None
):
    """
    Register a delta input under ``key``, optionally grouped by ``label``.

    The key actually stored is ``_input_label_repr(key, label)``. Registered
    inputs are consumed by ``sum_delta_inputs``: callables are invoked on
    every call and stay registered, while non-callable values are added once
    and then removed.

    Parameters
    ----------
    key : str
        Identifier of the input; combined with ``label`` to form the stored key.
    inp : Union[Callable, ArrayLike]
        A callable evaluated with the arguments passed to
        ``sum_delta_inputs``, or a one-shot value.
    label : Optional[str], default=None
        Optional group label, allowing ``sum_delta_inputs`` to process only
        the inputs of one group.

    Raises
    ------
    ValueError
        If the stored key is already bound to a different object.
        Registering the very same object again under the same key is allowed.

    See Also
    --------
    sum_delta_inputs : Sum the registered delta inputs.
    add_current_input : Register a current input.
    """
    key = _input_label_repr(key, label)
    if self._delta_inputs is None:
        self._delta_inputs = dict()
    if key in self._delta_inputs and self._delta_inputs[key] is not inp:
        # Same wording as add_current_input, so the error identifies both
        # the registry and the module that owns it.
        raise ValueError(f'Key "{key}" has been defined and used in the delta inputs of {self}.')
    self._delta_inputs[key] = inp
386
-
387
def get_input(self, key: str):
    """
    Look up a registered input by its stored key.

    The current-input registry is searched first, then the delta-input
    registry. Note that ``key`` must be the key as stored, i.e. the value
    produced by ``_input_label_repr(key, label)`` at registration time.

    Parameters
    ----------
    key : str
        Stored key of the input.

    Returns
    -------
    Callable or ArrayLike
        The registered callable or value.

    Raises
    ------
    ValueError
        If ``key`` is in neither registry.

    See Also
    --------
    add_current_input : Register a current input.
    add_delta_input : Register a delta input.
    """
    for registry in (self._current_inputs, self._delta_inputs):
        if registry is not None and key in registry:
            return registry[key]
    raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')
429
-
430
def sum_current_inputs(
    self,
    init: Any,
    *args,
    label: Optional[str] = None,
    **kwargs
):
    """
    Add every registered current input (optionally filtered by label) to ``init``.

    Parameters
    ----------
    init : Any
        Starting value onto which the inputs are added.
    *args : tuple
        Positional arguments passed to each callable input.
    label : Optional[str], default=None
        When given, only inputs whose stored key starts with
        ``_input_label_start(label)`` are processed; otherwise all are.
    **kwargs : dict
        Keyword arguments passed to each callable input.

    Returns
    -------
    Any
        ``init`` plus the contributions of all processed inputs.

    Raises
    ------
    ValueError
        If evaluating or adding an input fails. The message names the
        offending current input, and the original exception is chained.

    Notes
    -----
    Callable inputs stay registered. Non-callable inputs are one-shot: they
    are removed from the registry after being added successfully.
    """
    if self._current_inputs is None:
        return init
    if label is None:
        filter_fn = lambda k: True
    else:
        label_repr = _input_label_start(label)
        filter_fn = lambda k: k.startswith(label_repr)
    # Iterate over a snapshot of the keys because one-shot values are popped.
    for key in tuple(self._current_inputs.keys()):
        if filter_fn(key):
            out = self._current_inputs[key]
            if callable(out):
                try:
                    init = init + out(*args, **kwargs)
                except Exception as e:
                    raise ValueError(
                        f'Error in current input function {key}: {out}\n'
                        f'Error: {e}'
                    ) from e
            else:
                try:
                    init = init + out
                except Exception as e:
                    raise ValueError(
                        f'Error in current input value {key}: {out}\n'
                        f'Error: {e}'
                    ) from e
                self._current_inputs.pop(key)
    return init
497
-
498
def sum_delta_inputs(
    self,
    init: Any,
    *args,
    label: Optional[str] = None,
    **kwargs
):
    """
    Add every registered delta input (optionally filtered by label) to ``init``.

    Parameters
    ----------
    init : Any
        Starting value onto which the inputs are added.
    *args : tuple
        Positional arguments passed to each callable input.
    label : Optional[str], default=None
        When given, only inputs whose stored key starts with
        ``_input_label_start(label)`` are processed; otherwise all are.
    **kwargs : dict
        Keyword arguments passed to each callable input.

    Returns
    -------
    Any
        ``init`` plus the contributions of all processed inputs.

    Raises
    ------
    ValueError
        If evaluating or adding an input fails; the original exception is chained.

    Notes
    -----
    Callable inputs stay registered. Non-callable inputs are one-shot: they
    are removed from the registry after being added successfully.
    """
    if self._delta_inputs is None:
        return init

    prefix = None if label is None else _input_label_start(label)
    total = init
    # Snapshot the keys: one-shot values are popped while iterating.
    for name in tuple(self._delta_inputs):
        if prefix is not None and not name.startswith(prefix):
            continue
        item = self._delta_inputs[name]
        if not callable(item):
            try:
                total = total + item
            except Exception as err:
                raise ValueError(
                    f'Error in delta input value {name}: {item}\n'
                    f'Error: {err}'
                ) from err
            self._delta_inputs.pop(name)
            continue
        try:
            total = total + item(*args, **kwargs)
        except Exception as err:
            raise ValueError(
                f'Error in delta input function {name}: {item}\n'
                f'Error: {err}'
            ) from err
    return total
565
-
566
@property
def before_updates(self):
    """
    Hooks run by ``__call__`` before ``update``.

    Returns
    -------
    dict or None
        Mapping from identifier to hook, or ``None`` if nothing has been
        registered. ``__call__`` runs the hooks in insertion order. A hook
        that has a ``_receive_update_input`` attribute receives the call's
        arguments; other hooks are called with no arguments.
    """
    hooks = self._before_updates
    return hooks
583
-
584
@property
def after_updates(self):
    """
    Hooks run by ``__call__`` after ``update``.

    Returns
    -------
    dict or None
        Mapping from identifier to hook, or ``None`` if nothing has been
        registered. ``__call__`` runs the hooks in insertion order and passes
        each one the return value of ``update``, unless the hook has a
        ``_not_receive_update_output`` attribute, in which case it is called
        with no arguments.
    """
    hooks = self._after_updates
    return hooks
602
-
603
- def _add_before_update(self, key: Any, fun: Callable):
604
- """
605
- Register a function to be executed before the module's update.
606
-
607
- Parameters
608
- ----------
609
- key : Any
610
- A unique identifier for the update function.
611
- fun : Callable
612
- The function to execute before the module's update.
613
-
614
- Raises
615
- ------
616
- KeyError
617
- If the key is already registered in before_updates.
618
-
619
- Notes
620
- -----
621
- Internal method used by the module system to register dependencies.
622
- """
623
- if self._before_updates is None:
624
- self._before_updates = dict()
625
- if key in self.before_updates:
626
- raise KeyError(f'{key} has been registered in before_updates of {self}')
627
- self.before_updates[key] = fun
628
-
629
- def _add_after_update(self, key: Any, fun: Callable):
630
- """
631
- Register a function to be executed after the module's update.
632
-
633
- Parameters
634
- ----------
635
- key : Any
636
- A unique identifier for the update function.
637
- fun : Callable
638
- The function to execute after the module's update.
639
-
640
- Raises
641
- ------
642
- KeyError
643
- If the key is already registered in after_updates.
644
-
645
- Notes
646
- -----
647
- Internal method used by the module system to register dependencies.
648
- """
649
- if self._after_updates is None:
650
- self._after_updates = dict()
651
- if key in self.after_updates:
652
- raise KeyError(f'{key} has been registered in after_updates of {self}')
653
- self.after_updates[key] = fun
654
-
655
- def _get_before_update(self, key: Any):
656
- """
657
- Retrieve a registered before-update function by its key.
658
-
659
- Parameters
660
- ----------
661
- key : Any
662
- The identifier of the before-update function to retrieve.
663
-
664
- Returns
665
- -------
666
- Callable
667
- The registered before-update function.
668
-
669
- Raises
670
- ------
671
- KeyError
672
- If the key is not registered in before_updates or if before_updates is None.
673
- """
674
- if self._before_updates is None:
675
- raise KeyError(f'{key} is not registered in before_updates of {self}')
676
- if key not in self.before_updates:
677
- raise KeyError(f'{key} is not registered in before_updates of {self}')
678
- return self.before_updates.get(key)
679
-
680
- def _get_after_update(self, key: Any):
681
- """
682
- Retrieve a registered after-update function by its key.
683
-
684
- Parameters
685
- ----------
686
- key : Any
687
- The identifier of the after-update function to retrieve.
688
-
689
- Returns
690
- -------
691
- Callable
692
- The registered after-update function.
693
-
694
- Raises
695
- ------
696
- KeyError
697
- If the key is not registered in after_updates or if after_updates is None.
698
- """
699
- if self._after_updates is None:
700
- raise KeyError(f'{key} is not registered in after_updates of {self}')
701
- if key not in self.after_updates:
702
- raise KeyError(f'{key} is not registered in after_updates of {self}')
703
- return self.after_updates.get(key)
704
-
705
- def _has_before_update(self, key: Any):
706
- """
707
- Check if a before-update function is registered with the given key.
708
-
709
- Parameters
710
- ----------
711
- key : Any
712
- The identifier to check for in the before_updates dictionary.
713
-
714
- Returns
715
- -------
716
- bool
717
- True if the key is registered in before_updates, False otherwise.
718
- """
719
- if self._before_updates is None:
720
- return False
721
- return key in self.before_updates
722
-
723
def _has_after_update(self, key: Any):
    """
    Report whether an after-update hook is registered under ``key``.

    Parameters
    ----------
    key : Any
        Identifier to look for.

    Returns
    -------
    bool
        ``False`` when no after-update registry exists yet; otherwise whether
        ``key`` is present in it.
    """
    return self._after_updates is not None and key in self.after_updates
740
-
741
def __call__(self, *args, **kwargs):
    """
    Run the before-update hooks, then ``update``, then the after-update hooks.

    Before-update hooks that carry a ``_receive_update_input`` attribute are
    called with the same ``*args, **kwargs``; all others are called with no
    arguments. After-update hooks that carry ``_not_receive_update_output``
    are called with no arguments; all others receive the return value of
    ``update``.

    Returns
    -------
    Any
        Whatever ``self.update(*args, **kwargs)`` returned.
    """
    pre_hooks = self.before_updates
    if pre_hooks is not None:
        for hook in pre_hooks.values():
            if not hasattr(hook, '_receive_update_input'):
                hook()
                continue
            hook(*args, **kwargs)

    output = self.update(*args, **kwargs)

    post_hooks = self.after_updates
    if post_hooks is not None:
        for hook in post_hooks.values():
            if hasattr(hook, '_not_receive_update_output'):
                hook()
            else:
                hook(output)
    return output
765
-
766
def prefetch(self, item: str) -> 'Prefetch':
    """
    Return a lazy reference to the attribute ``item`` of this module.

    The attribute is not looked up now. It is resolved each time the
    returned reference is used, so it may be defined after this call.

    Parameters
    ----------
    item : str
        Name of the attribute or state to reference.

    Returns
    -------
    Prefetch
        Reference object bound to ``(self, item)``.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(...)
    >>> v_ref = neuron.prefetch('V')               # reference to the voltage
    >>> v_value = v_ref()                          # current value
    >>> delayed_v = v_ref.delay.at(5.0 * u.ms)     # value 5 ms in the past
    """
    reference = Prefetch(self, item)
    return reference
795
-
796
def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T:
    """
    Register a dynamics module to run right after this module's update.

    The registered module is stored as an after-update hook, so it is
    invoked by ``__call__`` once ``update`` has returned.

    Parameters
    ----------
    dyn : Union[ParamDescriber[T], T]
        Either a ready-made ``Dynamics`` instance, or a ``ParamDescriber``
        whose ``cls`` is a ``Dynamics`` subclass. A descriptor is instantiated
        at most once per ``identifier``. If it carries neither positional
        arguments nor an ``in_size`` keyword, it is created with
        ``in_size=self.varshape``.

    Returns
    -------
    T
        The registered dynamics instance. For a descriptor this is the
        instance stored under its identifier.

    Raises
    ------
    TypeError
        If ``dyn`` is neither a ``Dynamics`` instance nor a ``ParamDescriber``
        of a ``Dynamics`` subclass.

    Examples
    --------
    >>> import brainstate
    >>> n1 = brainstate.nn.LIF(10)
    >>> syn = n1.align_pre(brainstate.nn.Expon.desc(n1.varshape))  # syn runs after n1
    """
    if isinstance(dyn, Dynamics):
        self._add_after_update(id(dyn), dyn)
        return dyn

    if not isinstance(dyn, ParamDescriber):
        raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')

    if not issubclass(dyn.cls, Dynamics):
        raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')

    if not self._has_after_update(dyn.identifier):
        # Only inject this module's shape when the descriptor has no size of its own.
        size_given = 'in_size' in dyn.kwargs or len(dyn.args) > 0
        instance = dyn() if size_given else dyn(in_size=self.varshape)
        self._add_after_update(dyn.identifier, instance)
    return self._get_after_update(dyn.identifier)
842
-
843
def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
    """
    Return a reference to ``state`` as it was ``delay_time`` ago.

    This is shorthand for building a ``PrefetchDelayAt`` on this module.

    Parameters
    ----------
    state : str
        Name of the state or variable to reference.
    delay_time : ArrayLike
        Delay to apply, typically a time quantity (e.g. milliseconds).
    init : Callable, optional
        Forwarded to ``PrefetchDelayAt``, which hands it to the
        ``StateWithDelay`` it creates.

    Returns
    -------
    PrefetchDelayAt
        Callable object that yields the delayed value.
    """
    delayed_ref = PrefetchDelayAt(self, state, delay_time, init=init)
    return delayed_ref
862
-
863
def output_delay(self, *delay_time) -> 'OutputDelayAt':
    """
    Return a reference to this module's output as it was some time ago.

    Parameters
    ----------
    *delay_time : ArrayLike
        Delay specification, typically a single time value such as
        ``5.0 * u.ms``. All positional arguments are packed into a tuple and
        passed to ``OutputDelayAt``, which unpacks them into
        ``register_delay``.

    Returns
    -------
    OutputDelayAt
        Callable object returning the module's output at the requested delay.
    """
    delayed_output = OutputDelayAt(self, delay_time)
    return delayed_output
879
-
880
-
881
class Prefetch(Node):
    """
    Lazy reference to an attribute (typically a ``State``) of a module.

    The attribute is looked up by name only when the reference is used. A
    ``Prefetch`` can therefore be created before the target module defines
    the attribute, e.g. to wire modules together during construction.

    Parameters
    ----------
    module : Dynamics
        The module that owns, or will own, the referenced attribute.
    item : str
        Name of the attribute to look up on ``module``.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(...)
    >>> v_reference = neuron.prefetch('V')            # nothing resolved yet
    >>> v_value = v_reference()                       # current value of ``neuron.V``
    >>> delay_ref = v_reference.delay.at(5.0 * u.ms)  # ``V`` delayed by 5 ms

    Notes
    -----
    Lookups go through ``_get_prefetch_item``, which raises ``AttributeError``
    when the attribute is missing or ``None``. Use ``.delay`` to build delayed
    references.
    """

    def __init__(self, module: Dynamics, item: str):
        """
        Bind the reference to ``module`` and the attribute name ``item``.

        Parameters
        ----------
        module : Dynamics
            The module that owns, or will own, the referenced attribute.
        item : str
            Name of the attribute to look up.
        """
        super().__init__()
        self.module = module
        self.item = item

    @property
    def delay(self):
        """
        Entry point for delayed access to the referenced attribute.

        Returns
        -------
        PrefetchDelay
            Object whose ``at(...)`` method selects the delay.
        """
        return PrefetchDelay(self.module, self.item)

    def __call__(self, *args, **kwargs):
        """
        Resolve the reference and return its current value.

        Returns
        -------
        Any
            ``item.value`` if the attribute is a ``State``, else the attribute itself.
        """
        target = _get_prefetch_item(self)
        if isinstance(target, State):
            return target.value
        return target

    def get_item_value(self):
        """
        Resolve the reference and return its current value.

        This is an explicitly named equivalent of calling the object.

        Returns
        -------
        Any
            ``item.value`` if the attribute is a ``State``, else the attribute itself.
        """
        target = _get_prefetch_item(self)
        if isinstance(target, State):
            return target.value
        return target

    def get_item(self):
        """
        Resolve the reference and return the attribute object itself.

        Returns
        -------
        Any
            The attribute as stored on the module (a ``State`` is not unwrapped).
        """
        return _get_prefetch_item(self)
985
-
986
-
987
class PrefetchDelay(Node):
    """
    Intermediate object for building delayed references to a module attribute.

    Instances are normally obtained from ``Prefetch.delay``. They hold only
    the target module and attribute name; the delay is chosen with ``at()``.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Access voltage delayed by 5ms
    >>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
    >>> delayed_value = delayed_v()  # Get the delayed value
    """

    def __init__(self, module: Dynamics, item: str):
        # Initialize the Node base like every other Node subclass in this
        # module (Prefetch, PrefetchDelayAt, OutputDelayAt) does.
        super().__init__()
        self.module = module
        self.item = item

    def at(self, *delay_time):
        """
        Choose the delay at which the attribute will be read.

        Parameters
        ----------
        *delay_time : ArrayLike
            Delay specification, typically a single time value (e.g.
            ``5.0 * u.ms``). The arguments are forwarded as a tuple to
            ``PrefetchDelayAt``. Calling ``at()`` with no arguments yields a
            reference that returns the current, undelayed value.

        Returns
        -------
        PrefetchDelayAt
            An object that provides access to the variable at the specified delay time.
        """
        return PrefetchDelayAt(self.module, self.item, delay_time)
1032
-
1033
-
1034
class PrefetchDelayAt(Node):
    """
    Provides access to a module attribute's value at a specific delay.

    This is the last step of the prefetch-delay chain. On construction with a
    non-empty delay, it does three things:

    - creates a ``StateWithDelay`` handler for ``item``, once per ``item``;
    - stores that handler as an after-update hook of ``module``, marked with
      ``not_receive_update_output``;
    - registers the requested delay with the handler.

    Calling the object retrieves the delayed value.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.
    delay_time : ArrayLike or Tuple
        Delay specification. A bare value is wrapped into a 1-tuple. An empty
        tuple means "no delay": calling then returns the current value.
    init : Callable, optional
        Forwarded as ``init`` to the ``StateWithDelay`` handler when it is created.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to voltage delayed by 5ms
    >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
    >>> # Get the delayed value
    >>> v_value = delayed_v()
    """

    def __init__(
        self,
        module: Dynamics,
        item: str,
        delay_time: Tuple,
        init: Callable = None
    ):
        """
        Initialize a PrefetchDelayAt object.

        Parameters
        ----------
        module : Dynamics
            The dynamics module that contains the referenced state or variable.
        item : str
            The name of the state or variable to access with delay.
        delay_time : ArrayLike or Tuple
            Delay specification; a non-tuple/list value is wrapped into a 1-tuple.
        init : Callable, optional
            Initializer forwarded to the ``StateWithDelay`` handler.
        """
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        self.item = item
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        self.delay_time = delay_time
        if len(delay_time) > 0:
            key = _get_prefetch_delay_key(item)
            # One StateWithDelay per attribute is shared by all delayed references.
            if not module._has_after_update(key):
                module._add_after_update(
                    key,
                    not_receive_update_output(
                        StateWithDelay(module, item, init=init)
                    )
                )
            self.state_delay: StateWithDelay = module._get_after_update(key)
            self.delay_info = self.state_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """
        Retrieve the value of the attribute at the configured delay.

        Returns
        -------
        Any
            The delayed value. With an empty delay, returns the current
            value: ``item.value`` for a ``State``, otherwise the attribute
            itself.
        """
        if len(self.delay_time) == 0:
            # Unwrap the same way Prefetch.__call__ does instead of assuming
            # every attribute is a State (plain attributes have no `.value`).
            item = _get_prefetch_item(self)
            return item.value if isinstance(item, State) else item
        return self.state_delay.retrieve_at_step(*self.delay_info)
1114
-
1115
-
1116
class OutputDelayAt(Node):
    """
    Provides access to a module's output at a specific delay.

    On construction, the first ``OutputDelayAt`` for a module does two things:

    - creates a ``Delay`` buffer shaped like ``module.out_size`` (with
      dtype ``environ.dftype()`` and ``take_aware_unit=True``);
    - stores it as an after-update hook of the module, marked with
      ``receive_update_output``, so it is fed the return value of ``update``.

    Later instances reuse that buffer. Each instance registers its own delay.

    Parameters
    ----------
    module : Dynamics
        The dynamics module whose output is delayed.
    delay_time : ArrayLike or Tuple
        Delay specification. A bare value (e.g. ``5.0 * u.ms``) is wrapped
        into a 1-tuple; a tuple/list is passed through unchanged.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to the neuron's output delayed by 5ms
    >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
    >>> # Get the delayed output
    >>> spike_value = delayed_spike()
    """

    def __init__(
        self,
        module: Dynamics,
        delay_time: Tuple,
    ):
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        # Accept a bare delay (as in the example above) as well as a tuple/list,
        # mirroring PrefetchDelayAt; otherwise `*delay_time` below would try to
        # iterate a scalar.
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        key = _get_output_delay_key()
        if not module._has_after_update(key):
            delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
            module._add_after_update(key, receive_update_output(delay))
        self.out_delay: Delay = module._get_after_update(key)
        self.delay_info = self.out_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """Return the module's output at the registered delay."""
        return self.out_delay.retrieve_at_step(*self.delay_info)
1160
-
1161
-
1162
def _get_prefetch_delay_key(item) -> str:
    """Return the after-update key under which the delay handler of ``item`` is stored."""
    return '{}-prefetch-delay'.format(item)
1164
-
1165
-
1166
def _get_output_delay_key() -> str:
    """Return the after-update key under which a module's output ``Delay`` is stored."""
    # Plain literal: the original f-string had no placeholders (ruff F541).
    return 'output-delay'
1168
-
1169
-
1170
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
    """
    Resolve the attribute a prefetch reference points to.

    Parameters
    ----------
    target : Union[Prefetch, PrefetchDelayAt]
        Reference exposing ``module`` and ``item`` attributes.

    Returns
    -------
    Any
        ``getattr(target.module, target.item)``.

    Raises
    ------
    AttributeError
        If the attribute is missing, or is present but set to ``None``.
    """
    resolved = getattr(target.module, target.item, None)
    if resolved is not None:
        return resolved
    raise AttributeError(f'The target {target.module} should have an `{target.item}` attribute.')
1175
-
1176
-
1177
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
    """
    Fetch the ``StateWithDelay`` handler registered for a prefetch target.

    Parameters
    ----------
    target : Union[Prefetch, PrefetchDelay, PrefetchDelayAt]
        Reference exposing ``module`` (a ``Dynamics``) and ``item``.

    Returns
    -------
    StateWithDelay
        The handler stored under ``_get_prefetch_delay_key(target.item)``.

    Raises
    ------
    AssertionError
        If ``target.module`` is not a ``Dynamics`` instance.
    KeyError
        Propagated from ``_get_after_update`` when no handler is registered.
    TypeError
        If the stored object is not a ``StateWithDelay``.
    """
    owner = target.module
    assert isinstance(owner, Dynamics), (
        f'The target module should be an instance '
        f'of Dynamics. But got {owner}.'
    )
    handler = owner._get_after_update(_get_prefetch_delay_key(target.item))
    if isinstance(handler, StateWithDelay):
        return handler
    raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
                    f'its delay. But got {handler}.')
1187
-
1188
-
1189
def maybe_init_prefetch(target, *args, **kwargs):
    """
    Eagerly resolve a prefetch target so that misconfiguration fails early.

    The action depends on the type of ``target``:

    - :py:class:`Prefetch`: resolves the referenced attribute, which raises
      ``AttributeError`` if it is missing or ``None``.
    - :py:class:`PrefetchDelay`: fetches the ``StateWithDelay`` handler
      registered for the attribute.
    - :py:class:`PrefetchDelayAt`: nothing to do, because its constructor
      already registers the delay.
    - Any other object is ignored.

    Parameters
    ----------
    target : Any
        The object to initialize.
    *args : Any
        Unused.
    **kwargs : Any
        Unused.

    Returns
    -------
    None
    """
    if isinstance(target, Prefetch):
        _get_prefetch_item(target)
    elif isinstance(target, PrefetchDelay):
        _get_prefetch_item_delay(target)
1229
-
1230
-
1231
class DynamicsGroup(Module):
    """
    A container of :py:class:`~.Module` children updated in a fixed phase order.

    On ``update``, the direct children are split into projections, dynamics,
    and everything else (e.g. delays). They are called in that order, and each
    child is called with no arguments.

    Args:
      children_as_tuple: Positional children, stored as ``layers_tuple``.
      children_as_dict: Keyword children, stored as ``layers_dict``.
    """

    __module__ = 'brainstate.nn'

    if not TYPE_CHECKING:
        def __init__(self, *children_as_tuple, **children_as_dict):
            super().__init__()
            self.layers_tuple = tuple(children_as_tuple)
            self.layers_dict = dict(children_as_dict)

    def update(self, *args, **kwargs):
        """
        Call every direct child: projections first, then dynamics, then the rest.

        The arguments of this method are not forwarded to the children.
        """
        projections, dynamics, remaining = self.nodes(allowed_hierarchy=(1, 1)).split(Projection, Dynamics)
        for phase in (projections, dynamics, remaining):
            for child in phase.values():
                child()
1267
-
1268
-
1269
def receive_update_output(cls: object):
    """
    Mark an after-update hook as one that receives the return value of ``update``.

    The ``_not_receive_update_output`` marker is removed if present, so that
    ``Dynamics.__call__`` calls the hook as::

        ret = model.update(*args, **kwargs)
        for fun in model.aft_updates:
            fun(ret)

    Returns the same object, so it can be used as a decorator.
    """
    marker = '_not_receive_update_output'
    if not hasattr(cls, marker):
        return cls
    delattr(cls, marker)
    return cls
1284
-
1285
-
1286
def not_receive_update_output(cls: T) -> T:
    """
    Mark an after-update hook as one that is called without ``update``'s result.

    Sets ``_not_receive_update_output = True`` on the object, so that
    ``Dynamics.__call__`` calls the hook as::

        ret = model.update(*args, **kwargs)
        for fun in model.aft_updates:
            fun()

    Returns the same object, so it can be used as a decorator.
    """
    setattr(cls, '_not_receive_update_output', True)
    return cls
1300
-
1301
-
1302
def receive_update_input(cls: object):
    """
    Mark a before-update hook as one that receives the arguments of ``update``.

    Sets ``_receive_update_input = True`` on the object, so that
    ``Dynamics.__call__`` calls the hook as::

        for fun in model.bef_updates:
            fun(*args, **kwargs)
        model.update(*args, **kwargs)

    Returns the same object, so it can be used as a decorator.
    """
    setattr(cls, '_receive_update_input', True)
    return cls
1317
-
1318
-
1319
def not_receive_update_input(cls: object):
    """
    Mark a before-update hook as one that is called without arguments.

    The ``_receive_update_input`` marker is removed if present, so that
    ``Dynamics.__call__`` calls the hook as::

        for fun in model.bef_updates:
            fun()
        model.update()

    Returns the same object, so it can be used as a decorator.
    """
    marker = '_receive_update_input'
    if not hasattr(cls, marker):
        return cls
    delattr(cls, marker)
    return cls
1334
-
1335
-
1336
def _input_label_start(label: str):
    """Return the prefix that groups input keys under ``label`` (``'<label> // '``)."""
    return '{} // '.format(label)
1339
-
1340
-
1341
def _input_label_repr(name: str, label: Optional[str] = None):
    """
    Build the registry key for an input.

    Returns ``name`` unchanged when ``label`` is ``None``; otherwise returns
    the label prefix from ``_input_label_start`` followed by ``str(name)``.
    """
    if label is None:
        return name
    return _input_label_start(label) + str(name)
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ """
20
+ All the basic dynamics class for the ``brainstate``.
21
+
22
+ For handling dynamical systems:
23
+
24
+ - ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
25
+ then ``Dynamics``, finally others.
26
+ - ``Projection``: The class for the synaptic projection.
27
+ - ``Dynamics``: The class for the dynamical system.
28
+
29
+ For handling the delays:
30
+
31
+ - ``Delay``: The class for all delays.
32
+ - ``DelayAccess``: The class for the delay access.
33
+
34
+ """
35
+
36
+ from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING, Tuple
37
+
38
+ import jax
39
+ import numpy as np
40
+
41
+ from brainstate import environ
42
+ from brainstate._state import State
43
+ from brainstate.graph import Node
44
+ from brainstate.mixin import ParamDescriber
45
+ from brainstate.typing import Size, ArrayLike
46
+ from ._delay import StateWithDelay, Delay
47
+ from ._module import Module
48
+
49
__all__ = [
    # containers / base classes
    'DynamicsGroup',
    'Projection',
    'Dynamics',
    # lazy references to states and to delayed values
    'Prefetch',
    'PrefetchDelay',
    'PrefetchDelayAt',
    'OutputDelayAt',
]

# Generic type variable for annotations in this module.
T = TypeVar('T')

# NOTE(review): not referenced in the visible part of this module; confirm it
# is still used before relying on it.
_max_order = 10
56
+
57
+
58
class Projection(Module):
    """
    Base class for synaptic projection modules.

    A projection sits between neural populations and transforms signals
    before they reach the target neurons. ``DynamicsGroup.update`` calls
    projections before ``Dynamics`` modules, matching the flow of
    information:

    1. projections process inputs (synaptic transmission);
    2. dynamics update neuron states (neural integration).

    This base class has no update logic of its own. ``update`` forwards its
    arguments to every direct child node. Subclasses either add children or
    override ``update`` to implement dense, sparse, plastic, etc. connectivity.

    Parameters
    ----------
    *args : Any
        Arguments passed to the parent Module class.
    **kwargs : Any
        Keyword arguments passed to the parent Module class.

    Raises
    ------
    ValueError
        If ``update()`` is called on a projection that has no child nodes.
    """
    __module__ = 'brainstate.nn'

    def update(self, *args, **kwargs):
        children = tuple(self.nodes(allowed_hierarchy=(1, 1)).values())
        if not children:
            raise ValueError('Do not implement the update() function.')
        for child in children:
            child(*args, **kwargs)
101
+
102
+
103
+ class Dynamics(Module):
104
+ """
105
+ Base class for implementing neural dynamics models in BrainState.
106
+
107
+ Dynamics classes represent the core computational units in neural simulations,
108
+ implementing the differential equations or update rules that govern neural activity.
109
+ This class provides infrastructure for managing neural populations, handling inputs,
110
+ and coordinating updates within the simulation framework.
111
+
112
+ The Dynamics class serves several key purposes:
113
+ 1. Managing neuron population geometry and size information
114
+ 2. Handling current and delta (instantaneous change) inputs to neurons
115
+ 3. Supporting before/after update hooks for computational dependencies
116
+ 4. Providing access to delayed state variables through the prefetch mechanism
117
+ 5. Establishing the execution order in neural network simulations
118
+
119
+ Parameters
120
+ ----------
121
+ in_size : Size
122
+ The geometry of the neuron population. Can be an integer (e.g., 10) for
123
+ 1D neuron arrays, or a tuple (e.g., (10, 10)) for multi-dimensional populations.
124
+ name : Optional[str], default=None
125
+ Optional name identifier for this dynamics module.
126
+
127
+ Attributes
128
+ ----------
129
+ in_size : tuple
130
+ The shape/geometry of the neuron population.
131
+ out_size : tuple
132
+ The output shape, typically matches in_size.
133
+ current_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
134
+ Dictionary of registered current input functions or arrays.
135
+ delta_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
136
+ Dictionary of registered delta input functions or arrays.
137
+ before_updates : Optional[Dict[Hashable, Callable]]
138
+ Dictionary of functions to call before the main update.
139
+ after_updates : Optional[Dict[Hashable, Callable]]
140
+ Dictionary of functions to call after the main update.
141
+
142
+ Notes
143
+ -----
144
+ In the BrainState execution sequence, Dynamics modules are updated after
145
+ Projection modules and before other module types, reflecting the natural
146
+ flow of information in neural systems.
147
+
148
+ There are several essential attributes:
149
+
150
+ - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of
151
+ neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes
152
+ a 3-dimensional neuron group.
153
+ - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
154
+ `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
155
+
156
+
157
+ See Also
158
+ --------
159
+ Module : Parent class providing base module functionality
160
+ Projection : Class for handling synaptic projections between neural populations
161
+ DynamicsGroup : Container for organizing multiple dynamics modules
162
+ """
163
+
164
+ __module__ = 'brainstate.nn'
165
+
166
+ graph_invisible_attrs = ('_before_updates', '_after_updates', '_current_inputs', '_delta_inputs')
167
+
168
+ # before updates
169
+ _before_updates: Optional[Dict[Hashable, Callable]]
170
+
171
+ # after updates
172
+ _after_updates: Optional[Dict[Hashable, Callable]]
173
+
174
+ # current inputs
175
+ _current_inputs: Optional[Dict[str, ArrayLike | Callable]]
176
+
177
+ # delta inputs
178
+ _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]
179
+
180
def __init__(
    self,
    in_size: Size,
    name: Optional[str] = None,
):
    """
    Initialize the dynamics module.

    Parameters
    ----------
    in_size : Size
        Population geometry: an ``int``, or a non-empty tuple/list whose
        elements are all ints (Python ``int`` or ``numpy.integer``). It is
        normalized to a tuple.
    name : Optional[str], default=None
        Optional name, forwarded to ``Module.__init__``.

    Raises
    ------
    ValueError
        If ``in_size`` is not an int, or is an empty sequence, or contains a
        non-integer element.
    """
    super().__init__(name=name)

    # geometry size of neuron population
    if isinstance(in_size, (list, tuple)):
        if len(in_size) <= 0:
            raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
        # Validate every dimension; previously only in_size[0] was checked, so
        # e.g. (10, 'a') or (10, 2.5) slipped through and failed much later.
        if not all(isinstance(dim, (int, np.integer)) for dim in in_size):
            raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
        in_size = tuple(in_size)
    elif isinstance(in_size, (int, np.integer)):
        in_size = (in_size,)
    else:
        raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
    self.in_size = in_size

    # Input and hook registries start empty (None); they are created on
    # first registration (see e.g. add_current_input).
    self._current_inputs = None
    self._delta_inputs = None
    self._before_updates = None
    self._after_updates = None

    # in-/out- size of neuron population: output shape defaults to the input shape
    self.out_size = self.in_size
215
+
216
+ # def __pretty_repr_item__(self, name, value):
217
+ # if name in [
218
+ # '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
219
+ # '_in_size', '_out_size', '_name', '_mode',
220
+ # ]:
221
+ # return (name, value) if value is None else (name[1:], value) # skip the first `_`
222
+ # return super().__pretty_repr_item__(name, value)
223
+
224
@property
def varshape(self):
    """
    Shape of the per-neuron variables of this population.

    Returns
    -------
    tuple
        The normalized ``in_size`` tuple stored during initialization.

    See Also
    --------
    in_size : The input geometry specification for the neuron group
    """
    return self.in_size
243
+
244
@property
def current_inputs(self):
    """
    Registry of current inputs attached to this model.

    Returns
    -------
    dict or None
        Mapping from (possibly label-prefixed) keys to input callables or
        values, or ``None`` if nothing has been registered yet.

    See Also
    --------
    add_current_input : Register a new current input
    sum_current_inputs : Apply and sum all current inputs
    delta_inputs : Dictionary of instantaneous change inputs
    """
    return self._current_inputs
264
+
265
@property
def delta_inputs(self):
    """
    Registry of delta inputs attached to this model.

    Returns
    -------
    dict or None
        Mapping from (possibly label-prefixed) keys to delta-input callables
        or values, or ``None`` if nothing has been registered yet.

    See Also
    --------
    add_delta_input : Register a new delta input
    sum_delta_inputs : Apply and sum all delta inputs
    current_inputs : Dictionary of direct current inputs
    """
    return self._delta_inputs
285
+
286
def add_current_input(
    self,
    key: str,
    inp: Union[Callable, ArrayLike],
    label: Optional[str] = None
):
    """
    Register a current input under ``key``, optionally grouped by ``label``.

    Parameters
    ----------
    key : str
        Identifier of the input. When ``label`` is given, the stored key is
        ``_input_label_repr(key, label)`` (the label prefix followed by ``key``).
    inp : Union[Callable, ArrayLike]
        The input function or value. How callables and non-callables are
        consumed is decided by ``sum_current_inputs()``. Per this method's
        original documentation, arrays are used once and then removed, while
        callables persist.
    label : Optional[str], default=None
        Optional group label, used by ``sum_current_inputs()`` to select inputs.

    Raises
    ------
    ValueError
        If the resulting key is already registered with a different object.
        Re-registering the very same object is a no-op.

    See Also
    --------
    sum_current_inputs : Sum all current inputs matching a given label
    add_delta_input : Add a delta input function or array
    """
    full_key = _input_label_repr(key, label)
    if self._current_inputs is None:
        # Registry is created lazily on first registration.
        self._current_inputs = dict()
    registry = self._current_inputs
    if full_key in registry and registry[full_key] is not inp:
        raise ValueError(f'Key "{full_key}" has been defined and used in the current inputs of {self}.')
    registry[full_key] = inp
335
+
336
def add_delta_input(
    self,
    key: str,
    inp: Union[Callable, ArrayLike],
    label: Optional[str] = None
):
    """
    Add a delta input function or array to the dynamics model.

    Delta inputs represent instantaneous changes to the model state (i.e., dX/dt
    contributions). The registered entries are consumed by ``sum_delta_inputs()``
    during model updates.

    Parameters
    ----------
    key : str
        Identifier of the input. It is combined with ``label`` through
        ``_input_label_repr`` to form the actual dictionary key.
    inp : Union[Callable, ArrayLike]
        The input data or function that generates input data.
        - If callable: called on every ``sum_delta_inputs()`` call with the
          arguments given to that call; it stays registered.
        - If array-like: applied once, then removed from the registry.
    label : Optional[str], default=None
        Optional grouping label for the input. When provided, allows selective
        processing of inputs by label in ``sum_delta_inputs()``.

    Raises
    ------
    ValueError
        If the resulting key is already bound to a *different* object.
        Re-registering the very same object under the same key is allowed.

    See Also
    --------
    sum_delta_inputs : Sum all delta inputs matching a given label
    add_current_input : Add a current input function or array
    """
    key = _input_label_repr(key, label)
    if self._delta_inputs is None:
        self._delta_inputs = dict()
    # Identity check (same semantics as the previous ``id(a) != id(b)``).
    if key in self._delta_inputs and self._delta_inputs[key] is not inp:
        # Message now names the registry and module, matching add_current_input.
        raise ValueError(f'Key "{key}" has been defined and used in the delta inputs of {self}.')
    self._delta_inputs[key] = inp
386
+
387
def get_input(self, key: str):
    """
    Look up a registered input by key.

    The current inputs are searched first, then the delta inputs.

    Parameters
    ----------
    key : str
        The full key the input was stored under.

    Returns
    -------
    Callable or ArrayLike
        The registered input object.

    Raises
    ------
    ValueError
        If the key is found in neither registry.

    See Also
    --------
    add_current_input : Register a current input function
    add_delta_input : Register a delta input function

    Examples
    --------
    >>> model = Dynamics(10)
    >>> model.add_current_input('stimulus', lambda t: np.sin(t))
    >>> input_func = model.get_input('stimulus')
    >>> input_func(0.5)  # Returns sin(0.5)
    """
    current = self._current_inputs
    if current is not None and key in current:
        return current[key]
    delta = self._delta_inputs
    if delta is not None and key in delta:
        return delta[key]
    raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')
429
+
430
def sum_current_inputs(
    self,
    init: Any,
    *args,
    label: Optional[str] = None,
    **kwargs
):
    """
    Apply and sum all registered current inputs on top of ``init``.

    Each selected entry of ``.current_inputs`` is added to the running total.
    Callables are invoked with ``*args`` and ``**kwargs``. Array-like values are
    added directly.

    Parameters
    ----------
    init : Any
        The initial value to which all current inputs will be added.
    *args : tuple
        Positional arguments forwarded to each callable current input.
    label : Optional[str], default=None
        If provided, only keys starting with ``_input_label_start(label)`` are
        processed. When None, every current input is processed.
    **kwargs : dict
        Keyword arguments forwarded to each callable current input.

    Returns
    -------
    Any
        ``init`` plus all applicable current inputs.

    Raises
    ------
    ValueError
        Wrapping (and chained from) any exception raised while evaluating or
        adding an input.

    Notes
    -----
    - Non-callable current inputs are applied once and then removed from the
      current_inputs dictionary.
    - Callable current inputs remain registered for subsequent calls.
    """
    if self._current_inputs is None:
        return init
    if label is None:
        filter_fn = lambda k: True
    else:
        label_repr = _input_label_start(label)
        filter_fn = lambda k: k.startswith(label_repr)
    # Iterate over a snapshot because one-shot inputs are popped inside the loop.
    for key in tuple(self._current_inputs.keys()):
        if not filter_fn(key):
            continue
        out = self._current_inputs[key]
        if callable(out):
            try:
                init = init + out(*args, **kwargs)
            except Exception as e:
                raise ValueError(
                    f'Error in current input function {key}: {out}\n'
                    f'Error: {e}'
                ) from e
        else:
            try:
                init = init + out
            except Exception as e:
                raise ValueError(
                    f'Error in current input value {key}: {out}\n'
                    f'Error: {e}'
                ) from e
            # One-shot (non-callable) inputs are consumed once applied.
            self._current_inputs.pop(key)
    return init
497
+
498
def sum_delta_inputs(
    self,
    init: Any,
    *args,
    label: Optional[str] = None,
    **kwargs
):
    """
    Apply and sum all registered delta inputs on top of ``init``.

    Each selected entry of ``.delta_inputs`` contributes to the running total.
    Callables are invoked with the forwarded arguments. Plain values are added
    once and then dropped from the registry.

    Parameters
    ----------
    init : Any
        Starting value of the sum.
    *args : tuple
        Positional arguments forwarded to each callable delta input.
    label : Optional[str], default=None
        When given, restrict processing to keys starting with
        ``_input_label_start(label)``. When None, all delta inputs are used.
    **kwargs : dict
        Keyword arguments forwarded to each callable delta input.

    Returns
    -------
    Any
        ``init`` plus every applicable delta input.

    Raises
    ------
    ValueError
        Wrapping (and chained from) any exception raised while evaluating or
        adding an input.
    """
    if self._delta_inputs is None:
        return init
    prefix = None
    if label is not None:
        prefix = _input_label_start(label)
    # Snapshot of the keys: consumed one-shot inputs are popped while looping.
    for key in tuple(self._delta_inputs):
        if label is not None and not key.startswith(prefix):
            continue
        out = self._delta_inputs[key]
        if not callable(out):
            try:
                init = init + out
            except Exception as e:
                raise ValueError(
                    f'Error in delta input value {key}: {out}\n'
                    f'Error: {e}'
                ) from e
            self._delta_inputs.pop(key)
            continue
        try:
            init = init + out(*args, **kwargs)
        except Exception as e:
            raise ValueError(
                f'Error in delta input function {key}: {out}\n'
                f'Error: {e}'
            ) from e
    return init
565
+
566
@property
def before_updates(self):
    """
    Functions run before this module's ``update``.

    Returns
    -------
    dict or None
        Mapping from registration keys to callables, or ``None`` when nothing
        has been registered.

    Notes
    -----
    ``__call__`` iterates this mapping's values in insertion (registration)
    order before calling ``update``.
    """
    hooks = self._before_updates
    return hooks
583
+
584
@property
def after_updates(self):
    """
    Functions run after this module's ``update``.

    Returns
    -------
    dict or None
        Mapping from registration keys to callables, or ``None`` when nothing
        has been registered.

    Notes
    -----
    ``__call__`` iterates this mapping's values in insertion order after
    ``update``. Each callable receives ``update``'s return value unless it
    carries the ``_not_receive_update_output`` marker.
    """
    hooks = self._after_updates
    return hooks
602
+
603
def _add_before_update(self, key: Any, fun: Callable):
    """
    Register ``fun`` under ``key`` to run before this module's update.

    Parameters
    ----------
    key : Any
        Unique identifier of the hook.
    fun : Callable
        The hook to call.

    Raises
    ------
    KeyError
        If ``key`` is already present in ``before_updates``.
    """
    # Create the registry lazily on first registration.
    if self._before_updates is None:
        self._before_updates = dict()
    registered = self.before_updates
    if key in registered:
        raise KeyError(f'{key} has been registered in before_updates of {self}')
    registered[key] = fun
628
+
629
def _add_after_update(self, key: Any, fun: Callable):
    """
    Register ``fun`` under ``key`` to run after this module's update.

    Parameters
    ----------
    key : Any
        Unique identifier of the hook.
    fun : Callable
        The hook to call.

    Raises
    ------
    KeyError
        If ``key`` is already present in ``after_updates``.
    """
    # Create the registry lazily on first registration.
    if self._after_updates is None:
        self._after_updates = dict()
    registered = self.after_updates
    if key in registered:
        raise KeyError(f'{key} has been registered in after_updates of {self}')
    registered[key] = fun
654
+
655
def _get_before_update(self, key: Any):
    """
    Return the before-update hook registered under ``key``.

    Parameters
    ----------
    key : Any
        Identifier used at registration time.

    Returns
    -------
    Callable
        The registered hook.

    Raises
    ------
    KeyError
        If no before-update registry exists or ``key`` is not in it.
    """
    if self._before_updates is None or key not in self.before_updates:
        raise KeyError(f'{key} is not registered in before_updates of {self}')
    return self.before_updates[key]
679
+
680
def _get_after_update(self, key: Any):
    """
    Return the after-update hook registered under ``key``.

    Parameters
    ----------
    key : Any
        Identifier used at registration time.

    Returns
    -------
    Callable
        The registered hook.

    Raises
    ------
    KeyError
        If no after-update registry exists or ``key`` is not in it.
    """
    if self._after_updates is None or key not in self.after_updates:
        raise KeyError(f'{key} is not registered in after_updates of {self}')
    return self.after_updates[key]
704
+
705
def _has_before_update(self, key: Any):
    """
    Tell whether a before-update hook is registered under ``key``.

    Parameters
    ----------
    key : Any
        Identifier to look for.

    Returns
    -------
    bool
        ``False`` if there is no registry yet or the key is absent, else ``True``.
    """
    return self._before_updates is not None and key in self.before_updates
722
+
723
def _has_after_update(self, key: Any):
    """
    Tell whether an after-update hook is registered under ``key``.

    Parameters
    ----------
    key : Any
        Identifier to look for.

    Returns
    -------
    bool
        ``False`` if there is no registry yet or the key is absent, else ``True``.
    """
    return self._after_updates is not None and key in self.after_updates
740
+
741
def __call__(self, *args, **kwargs):
    """
    Run before-update hooks, then ``update``, then after-update hooks.

    Before-update hooks that carry a ``_receive_update_input`` attribute are
    called with ``*args, **kwargs``; the others are called with no arguments.
    After-update hooks that carry ``_not_receive_update_output`` are called
    with no arguments; the others receive ``update``'s return value.

    Returns
    -------
    Any
        Whatever ``self.update(*args, **kwargs)`` returned.
    """
    pre_hooks = self.before_updates
    if pre_hooks is not None:
        for hook in pre_hooks.values():
            if hasattr(hook, '_receive_update_input'):
                hook(*args, **kwargs)
            else:
                hook()

    result = self.update(*args, **kwargs)

    post_hooks = self.after_updates
    if post_hooks is not None:
        for hook in post_hooks.values():
            if hasattr(hook, '_not_receive_update_output'):
                hook()
            else:
                hook(result)
    return result
765
+
766
def prefetch(self, item: str) -> 'Prefetch':
    """
    Create a lazy reference to an attribute or state of this module.

    The attribute does not need to exist yet. It is looked up only when the
    returned :class:`Prefetch` is used.

    Parameters
    ----------
    item : str
        Name of the attribute or state to reference.

    Returns
    -------
    Prefetch
        A reference object bound to this module and ``item``.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(...)
    >>> v_ref = neuron.prefetch('V')  # Reference to voltage
    >>> v_value = v_ref()  # Get current value
    >>> delayed_v = v_ref.delay.at(5.0 * u.ms)  # Get delayed value
    """
    reference = Prefetch(self, item)
    return reference
795
+
796
def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T:
    """
    Register a dynamics module as an after-update of this module.

    The registered module is called after every ``update`` of this module by
    ``__call__``.

    Parameters
    ----------
    dyn : Union[ParamDescriber[T], T]
        Either:
        - a ``Dynamics`` instance, registered under ``id(dyn)``; registering
          the same instance twice raises ``KeyError`` from ``_add_after_update``;
        - a ``ParamDescriber`` whose ``cls`` is a ``Dynamics`` subclass. It is
          instantiated at most once per ``dyn.identifier``. When it has no
          positional args and no ``in_size`` kwarg, it is built with
          ``in_size=self.varshape``.

    Returns
    -------
    T
        The registered module (for descriptors, the instance stored under
        ``dyn.identifier``, which may have been created by an earlier call).

    Raises
    ------
    TypeError
        If ``dyn`` is neither a ``Dynamics`` nor a ``ParamDescriber`` of one.

    Examples
    --------
    >>> import brainstate
    >>> n1 = brainstate.nn.LIF(10)
    >>> syn = n1.align_pre(brainstate.nn.Expon.desc(n1.varshape))  # syn runs after n1
    """
    if isinstance(dyn, Dynamics):
        self._add_after_update(id(dyn), dyn)
        return dyn

    if not isinstance(dyn, ParamDescriber):
        raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')

    if not issubclass(dyn.cls, Dynamics):
        raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')

    if not self._has_after_update(dyn.identifier):
        size_given = 'in_size' in dyn.kwargs or len(dyn.args) > 0
        if size_given:
            instance = dyn()
        else:
            instance = dyn(in_size=self.varshape)
        self._add_after_update(dyn.identifier, instance)
    return self._get_after_update(dyn.identifier)
842
+
843
def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
    """
    Create a reference to ``state`` read at a fixed delay.

    This is a shortcut for constructing :class:`PrefetchDelayAt` on this module.

    Parameters
    ----------
    state : str
        Name of the state or attribute to reference.
    delay_time : ArrayLike or tuple
        Delay of the read, typically a time quantity (e.g. ``5.0 * u.ms``).
        It is passed unchanged to ``PrefetchDelayAt``.
    init : Callable, optional
        Initializer forwarded to ``PrefetchDelayAt``. It is used for the delay
        buffer's initial content.

    Returns
    -------
    PrefetchDelayAt
        An object that returns the delayed value when called.
    """
    delayed_ref = PrefetchDelayAt(self, state, delay_time, init=init)
    return delayed_ref
862
+
863
def output_delay(self, *delay_time) -> 'OutputDelayAt':
    """
    Create a reference to this module's ``update`` output at a delay.

    Parameters
    ----------
    *delay_time : ArrayLike
        One or more delay specifications, typically a single time quantity
        (e.g. ``5.0 * u.ms``). They are collected into a tuple and passed to
        :class:`OutputDelayAt`.

    Returns
    -------
    OutputDelayAt
        An object that returns the delayed output when called.
    """
    delayed_output = OutputDelayAt(self, delay_time)
    return delayed_output
879
+
880
+
881
class Prefetch(Node):
    """
    Lazy reference to an attribute or state of a module.

    The referenced attribute is looked up only when the reference is used, so
    it may be created after the ``Prefetch`` itself.

    Use cases:
    - Access variables within dynamics modules that will be defined later
    - Create references to states across module boundaries
    - Reach delayed versions of a state through the ``.delay`` property

    Parameters
    ----------
    module : Dynamics
        The module that holds (or will hold) the referenced item.
    item : str
        Attribute name of the referenced state or variable.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(...)
    >>> v_reference = neuron.prefetch('V')  # Reference to voltage before initialization
    >>> v_value = v_reference()  # Get the current value
    >>> delay_ref = v_reference.delay.at(5.0 * u.ms)  # Reference voltage delayed by 5ms
    """

    def __init__(self, module: Dynamics, item: str):
        """
        Bind the reference to ``module`` and the attribute name ``item``.

        Parameters
        ----------
        module : Dynamics
            The module that holds (or will hold) the referenced item.
        item : str
            Attribute name of the referenced state or variable.
        """
        super().__init__()
        self.module = module
        self.item = item

    @property
    def delay(self):
        """
        Entry point for delayed access to the referenced item.

        Returns
        -------
        PrefetchDelay
            Object whose ``at(...)`` method selects the delay.
        """
        return PrefetchDelay(self.module, self.item)

    def __call__(self, *args, **kwargs):
        """
        Return the current value of the referenced item.

        Returns
        -------
        Any
            ``item.value`` when the item is a ``State``, otherwise the item itself.
        """
        target = _get_prefetch_item(self)
        if isinstance(target, State):
            return target.value
        return target

    def get_item_value(self):
        """
        Return the current value of the referenced item (same as calling it).

        Returns
        -------
        Any
            ``item.value`` when the item is a ``State``, otherwise the item itself.
        """
        target = _get_prefetch_item(self)
        if isinstance(target, State):
            return target.value
        return target

    def get_item(self):
        """
        Return the referenced object itself rather than its value.

        Returns
        -------
        Any
            The attribute fetched from the module (a ``State`` or any other object).
        """
        return _get_prefetch_item(self)
985
+
986
+
987
class PrefetchDelay(Node):
    """
    Provides access to delayed versions of a prefetched state or variable.

    This class does not read values itself. Its ``at()`` method selects the
    delay and returns a :class:`PrefetchDelayAt` that does the reading.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Access voltage delayed by 5ms
    >>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
    >>> delayed_value = delayed_v()  # Get the delayed value
    """

    def __init__(self, module: Dynamics, item: str):
        """
        Bind to ``module`` and the attribute name ``item``.

        Parameters
        ----------
        module : Dynamics
            The dynamics module that contains the referenced item.
        item : str
            The name of the state or variable to access with delay.
        """
        # Initialize the Node base like the sibling Prefetch / PrefetchDelayAt do.
        super().__init__()
        self.module = module
        self.item = item

    def at(self, *delay_time):
        """
        Specify the delay for accessing the variable.

        Parameters
        ----------
        *delay_time : ArrayLike
            One or more delay specifications, typically a single time quantity
            (e.g. ``5.0 * u.ms``). They are passed as a tuple to
            :class:`PrefetchDelayAt`.

        Returns
        -------
        PrefetchDelayAt
            An object that provides access to the variable at the specified delay.
        """
        return PrefetchDelayAt(self.module, self.item, delay_time)
1032
+
1033
+
1034
class PrefetchDelayAt(Node):
    """
    Access a module's state or variable at a specific delay.

    This is the last step of the prefetch-delay chain. When at least one delay
    is given, a shared ``StateWithDelay`` after-update is registered on the
    module (once per item name). The requested delay is then registered on it.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.
    delay_time : ArrayLike or tuple
        The delay, typically a time quantity (e.g. ``5.0 * u.ms``). A single
        value is wrapped into a tuple. An empty tuple means "no delay".
    init : Callable, optional
        Initializer passed to the ``StateWithDelay`` when it is first created.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to voltage delayed by 5ms
    >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
    >>> # Get the delayed value
    >>> v_value = delayed_v()
    """

    def __init__(
        self,
        module: Dynamics,
        item: str,
        delay_time: Tuple,
        init: Callable = None
    ):
        """
        Initialize a PrefetchDelayAt object.

        Parameters
        ----------
        module : Dynamics
            The dynamics module that contains the referenced state or variable.
        item : str
            The name of the state or variable to access with delay.
        delay_time : Tuple
            The amount of time to delay access by. A non-tuple/list value is
            wrapped into a 1-tuple.
        init : Callable, optional
            Initializer for the ``StateWithDelay`` created on first use. It is
            ignored if that delay handler already exists or no delay is given.
        """
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        self.item = item
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        self.delay_time = delay_time
        if len(delay_time) > 0:
            # One StateWithDelay per item is shared by all delayed reads of it.
            key = _get_prefetch_delay_key(item)
            if not module._has_after_update(key):
                module._add_after_update(
                    key,
                    not_receive_update_output(
                        StateWithDelay(module, item, init=init)
                    )
                )
            self.state_delay: StateWithDelay = module._get_after_update(key)
            self.delay_info = self.state_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """
        Retrieve the value of the state at the specified delay time.

        Returns
        -------
        Any
            The delayed value. With no delay, this is the current value:
            ``item.value`` for a ``State``, otherwise the attribute itself
            (consistent with ``Prefetch.__call__``).
        """
        if len(self.delay_time) == 0:
            item = _get_prefetch_item(self)
            return item.value if isinstance(item, State) else item
        return self.state_delay.retrieve_at_step(*self.delay_info)
1114
+
1115
+
1116
class OutputDelayAt(Node):
    """
    Access a module's ``update`` output at a specific delay.

    On first use for a given module, a ``Delay`` buffer is registered as an
    after-update that receives ``update``'s return value. It is built from
    ``jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype())`` with
    ``take_aware_unit=True``. The requested delay is then registered on that
    shared buffer.

    Parameters
    ----------
    module : Dynamics
        The dynamics module whose output is delayed.
    delay_time : ArrayLike or tuple
        The delay, typically a time quantity (e.g. ``5.0 * u.ms``). A single
        value is wrapped into a tuple, as in ``PrefetchDelayAt``.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to the neuron's output delayed by 5ms
    >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
    >>> # Get the delayed value
    >>> spike_value = delayed_spike()
    """

    def __init__(
        self,
        module: Dynamics,
        delay_time: Tuple,
    ):
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        # Previously a scalar delay was splatted directly and failed; normalize
        # the same way PrefetchDelayAt does (tuples from output_delay pass through).
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        self.delay_time = delay_time
        key = _get_output_delay_key()
        if not module._has_after_update(key):
            delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
            module._add_after_update(key, receive_update_output(delay))
        self.out_delay: Delay = module._get_after_update(key)
        self.delay_info = self.out_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """Return the module output at the registered delay."""
        return self.out_delay.retrieve_at_step(*self.delay_info)
1160
+
1161
+
1162
def _get_prefetch_delay_key(item) -> str:
    """Return the after-update key under which ``item``'s StateWithDelay is stored."""
    return '{}-prefetch-delay'.format(item)
1164
+
1165
+
1166
def _get_output_delay_key() -> str:
    """Return the after-update key under which a module's output Delay is stored."""
    # Plain literal: the original f-string had no placeholders (ruff F541).
    return 'output-delay'
1168
+
1169
+
1170
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
    """
    Fetch ``target.item`` from ``target.module``.

    Raises
    ------
    AttributeError
        If the attribute is missing or is ``None``.
    """
    item = getattr(target.module, target.item, None)
    if item is not None:
        return item
    raise AttributeError(f'The target {target.module} should have an `{target.item}` attribute.')
1175
+
1176
+
1177
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
    """
    Return the StateWithDelay after-update registered for ``target.item``.

    Raises
    ------
    AssertionError
        If ``target.module`` is not a ``Dynamics``.
    KeyError
        Propagated from ``_get_after_update`` if no delay handler is registered.
    TypeError
        If the registered after-update is not a ``StateWithDelay``.
    """
    assert isinstance(target.module, Dynamics), (
        f'The target module should be an instance '
        f'of Dynamics. But got {target.module}.'
    )
    delay_key = _get_prefetch_delay_key(target.item)
    delay = target.module._get_after_update(delay_key)
    if isinstance(delay, StateWithDelay):
        return delay
    raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
                    f'its delay. But got {delay}.')
1187
+
1188
+
1189
def maybe_init_prefetch(target, *args, **kwargs):
    """
    Eagerly resolve a prefetch reference, depending on its type.

    - :py:class:`Prefetch`: looks up the referenced item, raising
      ``AttributeError`` if it is missing.
    - :py:class:`PrefetchDelay`: looks up the item's ``StateWithDelay`` handler.
    - :py:class:`PrefetchDelayAt` and any other object: nothing is done here.
      A ``PrefetchDelayAt`` already registers its delay in its constructor.

    Parameters
    ----------
    target : Any
        The object to initialize.
    *args, **kwargs : Any
        Accepted for interface compatibility and ignored.

    Returns
    -------
    None
    """
    if isinstance(target, Prefetch):
        _get_prefetch_item(target)
    elif isinstance(target, PrefetchDelay):
        _get_prefetch_item_delay(target)
1229
+
1230
+
1231
class DynamicsGroup(Module):
    """
    A container of :py:class:`~.Module` children that are updated together.

    On every ``update`` the nodes returned by
    ``self.nodes(allowed_hierarchy=(1, 1))`` (presumably the direct children)
    are called in three passes: first :py:class:`Projection` nodes, then
    :py:class:`Dynamics` nodes, then all remaining nodes (e.g. delays).

    Args:
      children_as_tuple: Positional children objects.
      children_as_dict: Keyword children objects.
    """

    __module__ = 'brainstate.nn'

    if not TYPE_CHECKING:
        def __init__(self, *children_as_tuple, **children_as_dict):
            super().__init__()
            self.layers_tuple = tuple(children_as_tuple)
            self.layers_dict = dict(children_as_dict)

    def update(self, *args, **kwargs):
        """
        Call every child once, projections first, then dynamics, then others.

        ``args`` and ``kwargs`` are accepted but not forwarded; each child is
        called with no arguments.
        """
        projections, dynamics, remaining = self.nodes(allowed_hierarchy=(1, 1)).split(Projection, Dynamics)
        for group in (projections, dynamics, remaining):
            for child in group.values():
                child()
1267
+
1268
+
1269
def receive_update_output(cls: object):
    """
    Mark an after-update object so that it is called with the update result.

    The ``_not_receive_update_output`` marker is removed from ``cls`` when
    present, which reverts :func:`not_receive_update_output`. The intended
    call pattern is::

        ret = model.update(*args, **kwargs)
        for fun in model.aft_updates:
            fun(ret)

    NOTE: the marker is removed with ``delattr`` on ``cls`` itself. If the
    marker is only inherited (for example, set on a parent class or on the
    class of an instance), ``delattr`` raises ``AttributeError``.

    Returns the same object, so this can be used as a decorator.
    """
    marker = '_not_receive_update_output'
    if hasattr(cls, marker):
        delattr(cls, marker)
    return cls
1284
+
1285
+
1286
def not_receive_update_output(cls: T) -> T:
    """
    Mark an after-update object so that it is called without the update result.

    Sets ``_not_receive_update_output = True`` on ``cls``. The intended call
    pattern is::

        ret = model.update(*args, **kwargs)
        for fun in model.aft_updates:
            fun()

    Returns the same object, so this can be used as a decorator.
    """
    setattr(cls, '_not_receive_update_output', True)
    return cls
1300
+
1301
+
1302
def receive_update_input(cls: object):
    """
    Mark a before-update object so that it is called with the update inputs.

    Sets ``_receive_update_input = True`` on ``cls``. The intended call
    pattern is::

        for fun in model.bef_updates:
            fun(*args, **kwargs)
        model.update(*args, **kwargs)

    Returns the same object, so this can be used as a decorator.
    """
    setattr(cls, '_receive_update_input', True)
    return cls
1317
+
1318
+
1319
def not_receive_update_input(cls: object):
    """
    The decorator to mark the object (as the before updates) to not receive the input of the update function.

    It removes the ``_receive_update_input`` marker from ``cls`` when present,
    reverting :func:`receive_update_input`. Only the marked before-update
    object is affected; the model's own ``update`` still receives its
    arguments. That is, the ``bef_update`` will not receive the input of the
    update function::

        for fun in model.bef_updates:
            fun()
        model.update(*args, **kwargs)

    Returns the same object, so this can be used as a decorator.
    """
    # assert isinstance(cls, Module), 'The input class should be instance of Module.'
    # NOTE: delattr only removes an attribute stored on ``cls`` itself; an
    # inherited marker would make delattr raise AttributeError.
    if hasattr(cls, '_receive_update_input'):
        delattr(cls, '_receive_update_input')
    return cls
1334
+
1335
+
1336
+ def _input_label_start(label: str):
1337
+ # unify the input label repr.
1338
+ return f'{label} // '
1339
+
1340
+
1341
+ def _input_label_repr(name: str, label: Optional[str] = None):
1342
+ # unify the input label repr.
1343
+ return name if label is None else (_input_label_start(label) + str(name))