brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,1267 +1,870 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- """
20
- All the basic dynamics class for the ``brainstate``.
21
-
22
- For handling dynamical systems:
23
-
24
- - ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
25
- then ``Dynamics``, finally others.
26
- - ``Projection``: The class for the synaptic projection.
27
- - ``Dynamics``: The class for the dynamical system.
28
-
29
- For handling the delays:
30
-
31
- - ``Delay``: The class for all delays.
32
- - ``DelayAccess``: The class for the delay access.
33
-
34
- """
35
-
36
- from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING, Tuple
37
-
38
- import jax
39
- import numpy as np
40
-
41
- from brainstate import environ
42
- from brainstate._state import State
43
- from brainstate.graph import Node
44
- from brainstate.mixin import ParamDescriber
45
- from brainstate.typing import Size, ArrayLike
46
- from ._delay import StateWithDelay, Delay
47
- from ._module import Module
48
-
49
# Public API of this module. Some of these names are not defined in the visible
# portion of the file; ``Prefetch`` is instantiated by ``Dynamics.prefetch``
# without being imported, so it (and presumably the others) is defined later
# in this module.
__all__ = [
    'DynamicsGroup',
    'Dynamics',
    'Prefetch',
    'PrefetchDelay',
    'PrefetchDelayAt',
    'OutputDelayAt',
]

# Generic type variable used in annotations such as ``Dynamics.align_pre``,
# whose return type matches the (possibly described) module it receives.
T = TypeVar('T')

# NOTE(review): not referenced in the visible part of this module; presumably
# a maximum order/depth limit used further down — confirm against its users.
_max_order = 10
60
-
61
-
62
- class Dynamics(Module):
63
- """
64
- Base class for implementing neural dynamics models in BrainState.
65
-
66
- Dynamics classes represent the core computational units in neural simulations,
67
- implementing the differential equations or update rules that govern neural activity.
68
- This class provides infrastructure for managing neural populations, handling inputs,
69
- and coordinating updates within the simulation framework.
70
-
71
- The Dynamics class serves several key purposes:
72
- 1. Managing neuron population geometry and size information
73
- 2. Handling current and delta (instantaneous change) inputs to neurons
74
- 3. Supporting before/after update hooks for computational dependencies
75
- 4. Providing access to delayed state variables through the prefetch mechanism
76
- 5. Establishing the execution order in neural network simulations
77
-
78
- Parameters
79
- ----------
80
- in_size : Size
81
- The geometry of the neuron population. Can be an integer (e.g., 10) for
82
- 1D neuron arrays, or a tuple (e.g., (10, 10)) for multi-dimensional populations.
83
- name : Optional[str], default=None
84
- Optional name identifier for this dynamics module.
85
-
86
- Attributes
87
- ----------
88
- in_size : tuple
89
- The shape/geometry of the neuron population.
90
- out_size : tuple
91
- The output shape, typically matches in_size.
92
- current_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
93
- Dictionary of registered current input functions or arrays.
94
- delta_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
95
- Dictionary of registered delta input functions or arrays.
96
- before_updates : Optional[Dict[Hashable, Callable]]
97
- Dictionary of functions to call before the main update.
98
- after_updates : Optional[Dict[Hashable, Callable]]
99
- Dictionary of functions to call after the main update.
100
-
101
- Notes
102
- -----
103
- In the BrainState execution sequence, Dynamics modules are updated after
104
- Projection modules and before other module types, reflecting the natural
105
- flow of information in neural systems.
106
-
107
- There are several essential attributes:
108
-
109
- - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of
110
- neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes
111
- a 3-dimensional neuron group.
112
- - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
113
- `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
114
-
115
-
116
- See Also
117
- --------
118
- Module : Parent class providing base module functionality
119
- Projection : Class for handling synaptic projections between neural populations
120
- DynamicsGroup : Container for organizing multiple dynamics modules
121
- """
122
-
123
- __module__ = 'brainstate.nn'
124
-
125
- graph_invisible_attrs = ()
126
-
127
- # before updates
128
- _before_updates: Optional[Dict[Hashable, Callable]]
129
-
130
- # after updates
131
- _after_updates: Optional[Dict[Hashable, Callable]]
132
-
133
- # current inputs
134
- _current_inputs: Optional[Dict[str, ArrayLike | Callable]]
135
-
136
- # delta inputs
137
- _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]
138
-
139
- def __init__(
140
- self,
141
- in_size: Size,
142
- name: Optional[str] = None,
143
- ):
144
- # initialize
145
- super().__init__(name=name)
146
-
147
- # geometry size of neuron population
148
- if isinstance(in_size, (list, tuple)):
149
- if len(in_size) <= 0:
150
- raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
151
- if not isinstance(in_size[0], (int, np.integer)):
152
- raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
153
- in_size = tuple(in_size)
154
- elif isinstance(in_size, (int, np.integer)):
155
- in_size = (in_size,)
156
- else:
157
- raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
158
- self.in_size = in_size
159
-
160
- # current inputs
161
- self._current_inputs = None
162
-
163
- # delta inputs
164
- self._delta_inputs = None
165
-
166
- # before updates
167
- self._before_updates = None
168
-
169
- # after updates
170
- self._after_updates = None
171
-
172
- # in-/out- size of neuron population
173
- self.out_size = self.in_size
174
-
175
- # def __pretty_repr_item__(self, name, value):
176
- # if name in [
177
- # '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
178
- # '_in_size', '_out_size', '_name', '_mode',
179
- # ]:
180
- # return (name, value) if value is None else (name[1:], value) # skip the first `_`
181
- # return super().__pretty_repr_item__(name, value)
182
-
183
- @property
184
- def varshape(self):
185
- """
186
- Get the shape of variables in the neuron group.
187
-
188
- This property provides access to the geometry (shape) of the neuron population,
189
- which determines how variables and states are structured.
190
-
191
- Returns
192
- -------
193
- tuple
194
- A tuple representing the dimensional shape of the neuron group,
195
- matching the in_size parameter provided during initialization.
196
-
197
- See Also
198
- --------
199
- in_size : The input geometry specification for the neuron group
200
- """
201
- return self.in_size
202
-
203
- @property
204
- def current_inputs(self):
205
- """
206
- Get the dictionary of current inputs registered with this dynamics model.
207
-
208
- Current inputs represent direct input currents that flow into the model.
209
-
210
- Returns
211
- -------
212
- dict or None
213
- A dictionary mapping keys to current input functions or values,
214
- or None if no current inputs have been registered.
215
-
216
- See Also
217
- --------
218
- add_current_input : Register a new current input
219
- sum_current_inputs : Apply and sum all current inputs
220
- delta_inputs : Dictionary of instantaneous change inputs
221
- """
222
- return self._current_inputs
223
-
224
- @property
225
- def delta_inputs(self):
226
- """
227
- Get the dictionary of delta inputs registered with this dynamics model.
228
-
229
- Delta inputs represent instantaneous changes to state variables (dX/dt).
230
-
231
- Returns
232
- -------
233
- dict or None
234
- A dictionary mapping keys to delta input functions or values,
235
- or None if no delta inputs have been registered.
236
-
237
- See Also
238
- --------
239
- add_delta_input : Register a new delta input
240
- sum_delta_inputs : Apply and sum all delta inputs
241
- current_inputs : Dictionary of direct current inputs
242
- """
243
- return self._delta_inputs
244
-
245
- def add_current_input(
246
- self,
247
- key: str,
248
- inp: Union[Callable, ArrayLike],
249
- label: Optional[str] = None
250
- ):
251
- """
252
- Add a current input function or array to the dynamics model.
253
-
254
- Current inputs represent direct input currents that can be accessed during
255
- model updates through the `sum_current_inputs()` method.
256
-
257
- Parameters
258
- ----------
259
- key : str
260
- Unique identifier for this current input. Used to retrieve or reference
261
- the input later.
262
- inp : Union[Callable, ArrayLike]
263
- The input data or function that generates input data.
264
- - If callable: Will be called during updates with arguments passed to `sum_current_inputs()`
265
- - If array-like: Will be applied once and then automatically removed from available inputs
266
- label : Optional[str], default=None
267
- Optional grouping label for the input. When provided, allows selective
268
- processing of inputs by label in `sum_current_inputs()`.
269
-
270
- Raises
271
- ------
272
- ValueError
273
- If the key has already been used for a different current input.
274
-
275
- Notes
276
- -----
277
- - Inputs with the same label can be processed together using the `label`
278
- parameter in `sum_current_inputs()`.
279
- - Non-callable inputs are consumed when used (removed after first use).
280
- - Callable inputs persist and can be called repeatedly.
281
-
282
- See Also
283
- --------
284
- sum_current_inputs : Sum all current inputs matching a given label
285
- add_delta_input : Add a delta input function or array
286
- """
287
- key = _input_label_repr(key, label)
288
- if self._current_inputs is None:
289
- self._current_inputs = dict()
290
- if key in self._current_inputs:
291
- if id(self._current_inputs[key]) != id(inp):
292
- raise ValueError(f'Key "{key}" has been defined and used in the current inputs of {self}.')
293
- self._current_inputs[key] = inp
294
-
295
- def add_delta_input(
296
- self,
297
- key: str,
298
- inp: Union[Callable, ArrayLike],
299
- label: Optional[str] = None
300
- ):
301
- """
302
- Add a delta input function or array to the dynamics model.
303
-
304
- Delta inputs represent instantaneous changes to the model state (i.e., dX/dt contributions).
305
- This method registers a function or array that provides delta inputs which will be
306
- accessible during model updates through the `sum_delta_inputs()` method.
307
-
308
- Parameters
309
- ----------
310
- key : str
311
- Unique identifier for this delta input. Used to retrieve or reference
312
- the input later.
313
- inp : Union[Callable, ArrayLike]
314
- The input data or function that generates input data.
315
- - If callable: Will be called during updates with arguments passed to `sum_delta_inputs()`
316
- - If array-like: Will be applied once and then automatically removed from available inputs
317
- label : Optional[str], default=None
318
- Optional grouping label for the input. When provided, allows selective
319
- processing of inputs by label in `sum_delta_inputs()`.
320
-
321
- Raises
322
- ------
323
- ValueError
324
- If the key has already been used for a different delta input.
325
-
326
- Notes
327
- -----
328
- - Inputs with the same label can be processed together using the `label`
329
- parameter in `sum_delta_inputs()`.
330
- - Non-callable inputs are consumed when used (removed after first use).
331
- - Callable inputs persist and can be called repeatedly.
332
-
333
- See Also
334
- --------
335
- sum_delta_inputs : Sum all delta inputs matching a given label
336
- add_current_input : Add a current input function or array
337
- """
338
- key = _input_label_repr(key, label)
339
- if self._delta_inputs is None:
340
- self._delta_inputs = dict()
341
- if key in self._delta_inputs:
342
- if id(self._delta_inputs[key]) != id(inp):
343
- raise ValueError(f'Key "{key}" has been defined and used.')
344
- self._delta_inputs[key] = inp
345
-
346
- def get_input(self, key: str):
347
- """
348
- Get a registered input function by its key.
349
-
350
- Retrieves either a current input or a delta input function that was previously
351
- registered with the given key. This method checks both current_inputs and
352
- delta_inputs dictionaries for the specified key.
353
-
354
- Parameters
355
- ----------
356
- key : str
357
- The unique identifier used when the input function was registered.
358
-
359
- Returns
360
- -------
361
- Callable or ArrayLike
362
- The input function or array associated with the given key.
363
-
364
- Raises
365
- ------
366
- ValueError
367
- If no input function is found with the specified key in either
368
- current_inputs or delta_inputs.
369
-
370
- See Also
371
- --------
372
- add_current_input : Register a current input function
373
- add_delta_input : Register a delta input function
374
-
375
- Examples
376
- --------
377
- >>> model = Dynamics(10)
378
- >>> model.add_current_input('stimulus', lambda t: np.sin(t))
379
- >>> input_func = model.get_input('stimulus')
380
- >>> input_func(0.5) # Returns sin(0.5)
381
- """
382
- if self._current_inputs is not None and key in self._current_inputs:
383
- return self._current_inputs[key]
384
- elif self._delta_inputs is not None and key in self._delta_inputs:
385
- return self._delta_inputs[key]
386
- else:
387
- raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')
388
-
389
- def sum_current_inputs(
390
- self,
391
- init: Any,
392
- *args,
393
- label: Optional[str] = None,
394
- **kwargs
395
- ):
396
- """
397
- Summarize all current inputs by applying and summing all registered current input functions.
398
-
399
- This method iterates through all registered current input functions (from `.current_inputs`)
400
- and applies them to calculate the total input current for the dynamics model. It adds all results
401
- to the initial value provided.
402
-
403
- Parameters
404
- ----------
405
- init : Any
406
- The initial value to which all current inputs will be added.
407
- *args
408
- Variable length argument list passed to each current input function.
409
- label : Optional[str], default=None
410
- If provided, only process current inputs with this label prefix.
411
- When None, process all current inputs regardless of label.
412
- **kwargs
413
- Arbitrary keyword arguments passed to each current input function.
414
-
415
- Returns
416
- -------
417
- Any
418
- The initial value plus all applicable current inputs summed together.
419
-
420
- Notes
421
- -----
422
- - Non-callable current inputs are applied once and then automatically removed from
423
- the current_inputs dictionary.
424
- - Callable current inputs remain registered for subsequent calls.
425
- - When a label is provided, only current inputs with keys starting with that label
426
- are applied.
427
- """
428
- if self._current_inputs is None:
429
- return init
430
- if label is None:
431
- filter_fn = lambda k: True
432
- else:
433
- label_repr = _input_label_start(label)
434
- filter_fn = lambda k: k.startswith(label_repr)
435
- for key in tuple(self._current_inputs.keys()):
436
- if filter_fn(key):
437
- out = self._current_inputs[key]
438
- if callable(out):
439
- try:
440
- init = init + out(*args, **kwargs)
441
- except Exception as e:
442
- raise ValueError(
443
- f'Error in current input value {key}: {out}\n'
444
- f'Error: {e}'
445
- ) from e
446
- else:
447
- try:
448
- init = init + out
449
- except Exception as e:
450
- raise ValueError(
451
- f'Error in current input value {key}: {out}\n'
452
- f'Error: {e}'
453
- ) from e
454
- self._current_inputs.pop(key)
455
- return init
456
-
457
- def sum_delta_inputs(
458
- self,
459
- init: Any,
460
- *args,
461
- label: Optional[str] = None,
462
- **kwargs
463
- ):
464
- """
465
- Summarize all delta inputs by applying and summing all registered delta input functions.
466
-
467
- This method iterates through all registered delta input functions (from `.delta_inputs`)
468
- and applies them to calculate instantaneous changes to model states. It adds all results
469
- to the initial value provided.
470
-
471
- Parameters
472
- ----------
473
- init : Any
474
- The initial value to which all delta inputs will be added.
475
- *args
476
- Variable length argument list passed to each delta input function.
477
- label : Optional[str], default=None
478
- If provided, only process delta inputs with this label prefix.
479
- When None, process all delta inputs regardless of label.
480
- **kwargs
481
- Arbitrary keyword arguments passed to each delta input function.
482
-
483
- Returns
484
- -------
485
- Any
486
- The initial value plus all applicable delta inputs summed together.
487
-
488
- Notes
489
- -----
490
- - Non-callable delta inputs are applied once and then automatically removed from
491
- the delta_inputs dictionary.
492
- - Callable delta inputs remain registered for subsequent calls.
493
- - When a label is provided, only delta inputs with keys starting with that label
494
- are applied.
495
- """
496
- if self._delta_inputs is None:
497
- return init
498
- if label is None:
499
- filter_fn = lambda k: True
500
- else:
501
- label_repr = _input_label_start(label)
502
- filter_fn = lambda k: k.startswith(label_repr)
503
- for key in tuple(self._delta_inputs.keys()):
504
- if filter_fn(key):
505
- out = self._delta_inputs[key]
506
- if callable(out):
507
- try:
508
- init = init + out(*args, **kwargs)
509
- except Exception as e:
510
- raise ValueError(
511
- f'Error in delta input function {key}: {out}\n'
512
- f'Error: {e}'
513
- ) from e
514
- else:
515
- try:
516
- init = init + out
517
- except Exception as e:
518
- raise ValueError(
519
- f'Error in delta input value {key}: {out}\n'
520
- f'Error: {e}'
521
- ) from e
522
- self._delta_inputs.pop(key)
523
- return init
524
-
525
- @property
526
- def before_updates(self):
527
- """
528
- Get the dictionary of functions to execute before the module's update.
529
-
530
- Returns
531
- -------
532
- dict or None
533
- Dictionary mapping keys to callable functions that will be executed
534
- before the main update, or None if no before updates are registered.
535
-
536
- Notes
537
- -----
538
- Before updates are executed in the order they were registered whenever
539
- the module is called via __call__.
540
- """
541
- return self._before_updates
542
-
543
- @property
544
- def after_updates(self):
545
- """
546
- Get the dictionary of functions to execute after the module's update.
547
-
548
- Returns
549
- -------
550
- dict or None
551
- Dictionary mapping keys to callable functions that will be executed
552
- after the main update, or None if no after updates are registered.
553
-
554
- Notes
555
- -----
556
- After updates are executed in the order they were registered whenever
557
- the module is called via __call__, and may optionally receive the return
558
- value from the update method.
559
- """
560
- return self._after_updates
561
-
562
- def _add_before_update(self, key: Any, fun: Callable):
563
- """
564
- Register a function to be executed before the module's update.
565
-
566
- Parameters
567
- ----------
568
- key : Any
569
- A unique identifier for the update function.
570
- fun : Callable
571
- The function to execute before the module's update.
572
-
573
- Raises
574
- ------
575
- KeyError
576
- If the key is already registered in before_updates.
577
-
578
- Notes
579
- -----
580
- Internal method used by the module system to register dependencies.
581
- """
582
- if self._before_updates is None:
583
- self._before_updates = dict()
584
- if key in self.before_updates:
585
- raise KeyError(f'{key} has been registered in before_updates of {self}')
586
- self.before_updates[key] = fun
587
-
588
- def _add_after_update(self, key: Any, fun: Callable):
589
- """
590
- Register a function to be executed after the module's update.
591
-
592
- Parameters
593
- ----------
594
- key : Any
595
- A unique identifier for the update function.
596
- fun : Callable
597
- The function to execute after the module's update.
598
-
599
- Raises
600
- ------
601
- KeyError
602
- If the key is already registered in after_updates.
603
-
604
- Notes
605
- -----
606
- Internal method used by the module system to register dependencies.
607
- """
608
- if self._after_updates is None:
609
- self._after_updates = dict()
610
- if key in self.after_updates:
611
- raise KeyError(f'{key} has been registered in after_updates of {self}')
612
- self.after_updates[key] = fun
613
-
614
- def _get_before_update(self, key: Any):
615
- """
616
- Retrieve a registered before-update function by its key.
617
-
618
- Parameters
619
- ----------
620
- key : Any
621
- The identifier of the before-update function to retrieve.
622
-
623
- Returns
624
- -------
625
- Callable
626
- The registered before-update function.
627
-
628
- Raises
629
- ------
630
- KeyError
631
- If the key is not registered in before_updates or if before_updates is None.
632
- """
633
- if self._before_updates is None:
634
- raise KeyError(f'{key} is not registered in before_updates of {self}')
635
- if key not in self.before_updates:
636
- raise KeyError(f'{key} is not registered in before_updates of {self}')
637
- return self.before_updates.get(key)
638
-
639
- def _get_after_update(self, key: Any):
640
- """
641
- Retrieve a registered after-update function by its key.
642
-
643
- Parameters
644
- ----------
645
- key : Any
646
- The identifier of the after-update function to retrieve.
647
-
648
- Returns
649
- -------
650
- Callable
651
- The registered after-update function.
652
-
653
- Raises
654
- ------
655
- KeyError
656
- If the key is not registered in after_updates or if after_updates is None.
657
- """
658
- if self._after_updates is None:
659
- raise KeyError(f'{key} is not registered in after_updates of {self}')
660
- if key not in self.after_updates:
661
- raise KeyError(f'{key} is not registered in after_updates of {self}')
662
- return self.after_updates.get(key)
663
-
664
- def _has_before_update(self, key: Any):
665
- """
666
- Check if a before-update function is registered with the given key.
667
-
668
- Parameters
669
- ----------
670
- key : Any
671
- The identifier to check for in the before_updates dictionary.
672
-
673
- Returns
674
- -------
675
- bool
676
- True if the key is registered in before_updates, False otherwise.
677
- """
678
- if self._before_updates is None:
679
- return False
680
- return key in self.before_updates
681
-
682
- def _has_after_update(self, key: Any):
683
- """
684
- Check if an after-update function is registered with the given key.
685
-
686
- Parameters
687
- ----------
688
- key : Any
689
- The identifier to check for in the after_updates dictionary.
690
-
691
- Returns
692
- -------
693
- bool
694
- True if the key is registered in after_updates, False otherwise.
695
- """
696
- if self._after_updates is None:
697
- return False
698
- return key in self.after_updates
699
-
700
- def __call__(self, *args, **kwargs):
701
- """
702
- The shortcut to call ``update`` methods.
703
- """
704
-
705
- # ``before_updates``
706
- if self.before_updates is not None:
707
- for model in self.before_updates.values():
708
- if hasattr(model, '_receive_update_input'):
709
- model(*args, **kwargs)
710
- else:
711
- model()
712
-
713
- # update the model self
714
- ret = self.update(*args, **kwargs)
715
-
716
- # ``after_updates``
717
- if self.after_updates is not None:
718
- for model in self.after_updates.values():
719
- if hasattr(model, '_not_receive_update_output'):
720
- model()
721
- else:
722
- model(ret)
723
- return ret
724
-
725
- def prefetch(self, item: str) -> 'Prefetch':
726
- """
727
- Create a reference to a state or variable that may not be initialized yet.
728
-
729
- This method allows accessing module attributes or states before they are
730
- fully defined, acting as a placeholder that will be resolved when called.
731
- Particularly useful for creating references to variables that will be defined
732
- during initialization or runtime.
733
-
734
- Parameters
735
- ----------
736
- item : str
737
- The name of the attribute or state to reference.
738
-
739
- Returns
740
- -------
741
- Prefetch
742
- A Prefetch object that provides access to the referenced item.
743
-
744
- Examples
745
- --------
746
- >>> import brainstate
747
- >>> import brainunit as u
748
- >>> neuron = brainstate.nn.LIF(...)
749
- >>> v_ref = neuron.prefetch('V') # Reference to voltage
750
- >>> v_value = v_ref() # Get current value
751
- >>> delayed_v = v_ref.delay.at(5.0 * u.ms) # Get delayed value
752
- """
753
- return Prefetch(self, item)
754
-
755
def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T:
    """
    Register a dynamics module to execute after this module's update.

    Despite its name, the module is stored among this module's *after*
    updates. Every time this module is called, the registered module is
    invoked after ``update`` returns. It receives ``update``'s return value
    unless it is marked with ``not_receive_update_output``. This creates a
    feed-forward connection in the computational graph.

    Registration is idempotent in both branches. A ``Dynamics`` instance is
    keyed by its ``id``, and a ``ParamDescriber`` is keyed by its
    ``identifier``. Aligning the same object twice returns the already
    registered module.

    Parameters
    ----------
    dyn : Union[ParamDescriber[T], T]
        The dynamics module to be executed after this module. Can be either:

        - An instance of Dynamics.
        - A ParamDescriber that can instantiate a Dynamics object. If it was
          given neither positional arguments nor an ``in_size`` keyword, it
          is instantiated with ``in_size=self.varshape``.

    Returns
    -------
    T
        The registered dynamics module, allowing for method chaining.

    Raises
    ------
    TypeError
        If the input is not a Dynamics instance or a ParamDescriber that creates
        a Dynamics instance.

    Examples
    --------
    >>> import brainstate
    >>> n1 = brainstate.nn.LIF(10)
    >>> syn = n1.align_pre(brainstate.nn.Expon.desc(n1.varshape))  # syn runs after n1
    """
    if isinstance(dyn, Dynamics):
        key = id(dyn)
        # Mirror the ParamDescriber branch: aligning an already registered
        # instance again is a no-op instead of a duplicate registration.
        if not self._has_after_update(key):
            self._add_after_update(key, dyn)
        return dyn
    elif isinstance(dyn, ParamDescriber):
        if not issubclass(dyn.cls, Dynamics):
            raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
        if not self._has_after_update(dyn.identifier):
            self._add_after_update(
                dyn.identifier,
                # Default the population size to this module's geometry when
                # the descriptor carries no explicit size information.
                dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
            )
        return self._get_after_update(dyn.identifier)
    else:
        raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
801
-
802
def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
    """
    Build a reference to a delayed version of one of this module's states.

    This is a shortcut for constructing a ``PrefetchDelayAt`` bound to this
    module.

    Parameters
    ----------
    state : str
        Name of the state or variable to reference.
    delay_time : ArrayLike or tuple
        How far back in time to read the variable, typically a time quantity
        (e.g. milliseconds). A non-tuple value is wrapped into a 1-tuple by
        ``PrefetchDelayAt``.
    init : Callable, optional
        Initializer handed to the underlying ``StateWithDelay`` when it is
        created.

    Returns
    -------
    PrefetchDelayAt
        Accessor for the variable at the requested delay.
    """
    return PrefetchDelayAt(module=self, item=state, delay_time=delay_time, init=init)
821
-
822
def output_delay(self, *delay_time) -> 'OutputDelayAt':
    """
    Create a reference to the delayed output of the module.

    This instantiates an ``OutputDelayAt`` object, which can be called to
    retrieve this module's output (the return value of ``update``) at the
    specified delay.

    Parameters
    ----------
    *delay_time
        Delay specification, collected into a tuple and forwarded unchanged
        to ``OutputDelayAt``. That class star-unpacks it into
        ``Delay.register_delay``. Typically this is a single time value such
        as ``5.0 * u.ms``.

    Returns
    -------
    OutputDelayAt
        An object that provides access to the module's output at the
        specified delay time.
    """
    return OutputDelayAt(self, delay_time)
838
-
839
-
840
class Prefetch(Node):
    """
    Prefetch a state or variable in a module before it is initialized.

    This class provides a mechanism to reference a module's state or attribute
    that may not have been initialized yet. It acts as a placeholder
    that is resolved (via ``getattr`` on the module) only when called.

    Use cases:
    - Access variables within dynamics modules that will be defined later
    - Create references to states across module boundaries
    - Enable access to delayed states through the `.delay` property

    Parameters
    ----------
    module : Dynamics
        The module that contains or will contain the referenced item.
    item : str
        The attribute name of the state or variable to prefetch.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(...)
    >>> v_reference = neuron.prefetch('V')  # Reference to voltage before initialization
    >>> v_value = v_reference()  # Get the current value
    >>> delay_ref = v_reference.delay.at(5.0 * u.ms)  # Reference voltage delayed by 5ms

    Notes
    -----
    When called, this class retrieves the current value of the referenced item.
    Use the `.delay` property to access delayed versions of the state.
    """

    def __init__(self, module: Dynamics, item: str):
        """
        Initialize a Prefetch object.

        Parameters
        ----------
        module : Dynamics
            The module that contains or will contain the referenced item.
        item : str
            The attribute name of the state or variable to prefetch.
        """
        super().__init__()
        self.module = module
        self.item = item

    @property
    def delay(self):
        """
        Access delayed versions of the prefetched item.

        Returns
        -------
        PrefetchDelay
            An object whose ``at(...)`` method selects the delay time.
        """
        return PrefetchDelay(self.module, self.item)

    def __call__(self, *args, **kwargs):
        """
        Get the current value of the prefetched item.

        Any arguments are ignored.

        Returns
        -------
        Any
            The current value of the referenced item. If the item is a State
            object, its ``value`` attribute is returned, otherwise the item itself.

        Raises
        ------
        AttributeError
            If the module has no such attribute, or the attribute is None.
        """
        return self.get_item_value()

    def get_item_value(self):
        """
        Get the current value of the prefetched item.

        Same as calling the object, but explicitly named for clarity.

        Returns
        -------
        Any
            The current value of the referenced item. If the item is a State
            object, its ``value`` attribute is returned, otherwise the item itself.

        Raises
        ------
        AttributeError
            If the module has no such attribute, or the attribute is None.
        """
        item = _get_prefetch_item(self)
        return item.value if isinstance(item, State) else item

    def get_item(self):
        """
        Get the referenced item object itself, not its value.

        Returns
        -------
        Any
            The actual referenced item from the module, which could be a State
            object or any other attribute.

        Raises
        ------
        AttributeError
            If the module has no such attribute, or the attribute is None.
        """
        return _get_prefetch_item(self)
944
-
945
-
946
class PrefetchDelay(Node):
    """
    Provides access to delayed versions of a prefetched state or variable.

    This class acts as an intermediary for accessing delayed values of module variables.
    It doesn't retrieve values directly; the delay time is chosen with the
    `at()` method, which returns a ``PrefetchDelayAt``.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Access voltage delayed by 5ms
    >>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
    >>> delayed_value = delayed_v()  # Get the delayed value
    """

    def __init__(self, module: Dynamics, item: str):
        # Initialise the Node base like the sibling Prefetch/PrefetchDelayAt
        # classes do; previously this call was missing.
        super().__init__()
        self.module = module
        self.item = item

    def at(self, *delay_time):
        """
        Specifies the delay time for accessing the variable.

        Parameters
        ----------
        *delay_time
            Delay specification, typically a single time quantity (e.g.
            ``5.0 * u.ms``). It is forwarded as a tuple to
            ``PrefetchDelayAt``, which registers it with the module's
            ``StateWithDelay`` for this item.

        Returns
        -------
        PrefetchDelayAt
            An object that provides access to the variable at the specified delay time.
        """
        return PrefetchDelayAt(self.module, self.item, delay_time)
991
-
992
-
993
class PrefetchDelayAt(Node):
    """
    Provides access to a specific delayed state or variable value at the specific time.

    This class represents the final step in the prefetch delay chain, providing
    actual access to state values at a specific delay time. On construction it
    ensures a ``StateWithDelay`` handler for the item is registered as an
    after-update of the module, and registers the requested delay with it.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.
    delay_time : ArrayLike or tuple
        The amount of time to delay access by, typically in time units (e.g., milliseconds).
        A non-tuple value is wrapped into a 1-tuple; an empty tuple means "no delay".
    init : Callable, optional
        Initializer forwarded to the ``StateWithDelay`` handler.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to voltage delayed by 5ms
    >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
    >>> # Get the delayed value
    >>> v_value = delayed_v()
    """

    def __init__(
        self,
        module: Dynamics,
        item: str,
        delay_time: Tuple,
        init: Callable = None
    ):
        """
        Initialize a PrefetchDelayAt object.

        Parameters
        ----------
        module : Dynamics
            The dynamics module that contains the referenced state or variable.
        item : str
            The name of the state or variable to access with delay.
        delay_time : Tuple
            The amount of time to delay access by, typically in time units (e.g., milliseconds).
            A single non-tuple/list value is accepted and wrapped into a tuple.
        init : Callable, optional
            Initializer passed to the ``StateWithDelay`` handler. It is only
            used when that handler is created, i.e. by the first delayed
            reference to ``item`` on ``module``. Later references reuse the
            existing handler, and their ``init`` is ignored.
        """
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        self.item = item
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        self.delay_time = delay_time
        if len(delay_time) > 0:
            key = _get_prefetch_delay_key(item)
            # One StateWithDelay per (module, item) is shared by all delayed
            # references; it is run after each update without the output.
            if not module._has_after_update(key):
                module._add_after_update(
                    key,
                    not_receive_update_output(
                        StateWithDelay(module, item, init=init)
                    )
                )
            self.state_delay: StateWithDelay = module._get_after_update(key)
            self.delay_info = self.state_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """
        Retrieve the value of the state at the specified delay time.

        Any arguments are ignored.

        Returns
        -------
        Any
            The value of the state or variable at the specified delay time.
            With an empty delay, the current value is returned. For a State,
            that is its ``value``; any other attribute is returned as-is,
            consistent with ``Prefetch.__call__``.
        """
        if len(self.delay_time) == 0:
            item = _get_prefetch_item(self)
            return item.value if isinstance(item, State) else item
        return self.state_delay.retrieve_at_step(*self.delay_info)
1073
-
1074
-
1075
class OutputDelayAt(Node):
    """
    Provides access to the output of a dynamics module at a specific delay.

    On construction, a ``Delay`` buffer for the module's output is created if
    it does not exist yet. It is keyed by ``_get_output_delay_key()``, so there
    is one buffer per module. It is declared with shape ``module.out_size`` and
    dtype ``environ.dftype()``, and registered as an after-update that receives
    the return value of ``update``. The requested delay is then registered with
    that buffer.

    Parameters
    ----------
    module : Dynamics
        The dynamics module whose output is delayed.
    delay_time : ArrayLike or tuple
        The amount of time to delay access by, typically in time units (e.g., milliseconds).
        A single non-tuple/list value is accepted and wrapped into a tuple.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to the neuron's output delayed by 5ms
    >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
    >>> # Get the delayed value
    >>> spike_value = delayed_spike()
    """

    def __init__(
        self,
        module: Dynamics,
        delay_time: Tuple,
    ):
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        # Accept a bare delay value (as in the example above), mirroring
        # PrefetchDelayAt; star-unpacking a scalar would otherwise fail.
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        key = _get_output_delay_key()
        if not module._has_after_update(key):
            delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
            module._add_after_update(key, receive_update_output(delay))
        self.out_delay: Delay = module._get_after_update(key)
        self.delay_info = self.out_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """Return the module's output at the registered delay; arguments are ignored."""
        return self.out_delay.retrieve_at_step(*self.delay_info)
1119
-
1120
-
1121
- def _get_prefetch_delay_key(item) -> str:
1122
- return f'{item}-prefetch-delay'
1123
-
1124
-
1125
- def _get_output_delay_key() -> str:
1126
- return f'output-delay'
1127
-
1128
-
1129
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
    """
    Resolve the attribute a prefetch reference points to.

    Raises ``AttributeError`` when ``target.module`` has no attribute named
    ``target.item``, and also when that attribute is ``None``.
    """
    owner, attr_name = target.module, target.item
    resolved = getattr(owner, attr_name, None)
    if resolved is not None:
        return resolved
    raise AttributeError(f'The target {owner} should have an `{attr_name}` attribute.')
1134
-
1135
-
1136
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
    """
    Fetch the ``StateWithDelay`` handler registered for a prefetch target.

    The handler is looked up among the module's after-updates under the key
    built by ``_get_prefetch_delay_key``. A ``TypeError`` is raised when the
    object found there is not a ``StateWithDelay``. What happens when the key
    is absent depends on ``_get_after_update`` (not visible here).
    """
    owner = target.module
    assert isinstance(owner, Dynamics), (
        f'The target module should be an instance '
        f'of Dynamics. But got {owner}.'
    )
    handler = owner._get_after_update(_get_prefetch_delay_key(target.item))
    if isinstance(handler, StateWithDelay):
        return handler
    raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
                    f'its delay. But got {handler}.')
1146
-
1147
-
1148
def maybe_init_prefetch(target, *args, **kwargs):
    """
    Eagerly validate a prefetch reference, depending on its type.

    The action taken depends on the type of ``target``:

    - :py:class:`Prefetch`: looks up the referenced item. This raises
      ``AttributeError`` if the attribute is missing or ``None``.
    - :py:class:`PrefetchDelay`: looks up the ``StateWithDelay`` handler
      registered for the item. This raises ``TypeError`` if the registered
      object is not a ``StateWithDelay``.
    - :py:class:`PrefetchDelayAt`: does nothing. The delay is already
      registered in ``PrefetchDelayAt.__init__``.

    Any other object is ignored.

    Parameters
    ----------
    target : Any
        The object to check; typically a Prefetch, PrefetchDelay or PrefetchDelayAt.
    *args : Any
        Additional positional arguments (unused).
    **kwargs : Any
        Additional keyword arguments (unused).

    Returns
    -------
    None
        The lookups are performed only for their validation side effect.
    """
    if isinstance(target, Prefetch):
        _get_prefetch_item(target)

    elif isinstance(target, PrefetchDelay):
        _get_prefetch_item_delay(target)

    elif isinstance(target, PrefetchDelayAt):
        # Nothing to do: PrefetchDelayAt registers its delay on construction.
        pass
1188
-
1189
-
1190
# Alias kept for the public name: in this module a "dynamics group" is
# simply a ``Module`` (no extra ordering logic is defined here).
DynamicsGroup = Module
1191
-
1192
-
1193
def receive_update_output(cls: object):
    """
    Mark ``cls`` so that, as an after-update, it is called with ``update``'s result.

    This reverses ``not_receive_update_output`` by removing the
    ``_not_receive_update_output`` marker attribute if it is present. The
    owner's ``__call__`` then invokes the object as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun(ret)

    Returns ``cls`` unchanged otherwise, so it can be used as a decorator.
    """
    marker = '_not_receive_update_output'
    if hasattr(cls, marker):
        delattr(cls, marker)
    return cls
1208
-
1209
-
1210
def not_receive_update_output(cls: T) -> T:
    """
    Mark ``cls`` so that, as an after-update, it is called without arguments.

    This sets the ``_not_receive_update_output`` attribute to ``True``. The
    owner's ``__call__`` then invokes the object as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun()

    The same object is returned, so this can be used as a decorator.
    """
    setattr(cls, '_not_receive_update_output', True)
    return cls
1224
-
1225
-
1226
def receive_update_input(cls: object):
    """
    Mark ``cls`` so that, as a before-update, it receives ``update``'s arguments.

    This sets the ``_receive_update_input`` attribute to ``True``. The owner's
    ``__call__`` then invokes the object as::

        for fun in model.before_updates.values():
            fun(*args, **kwargs)
        model.update(*args, **kwargs)

    The same object is returned, so this can be used as a decorator.
    """
    setattr(cls, '_receive_update_input', True)
    return cls
1241
-
1242
-
1243
def not_receive_update_input(cls: object):
    """
    Mark ``cls`` so that, as a before-update, it is called without arguments.

    This reverses ``receive_update_input`` by removing the
    ``_receive_update_input`` marker attribute if it is present. The owner's
    ``__call__`` then invokes the object as::

        for fun in model.before_updates.values():
            fun()
        model.update(*args, **kwargs)

    Returns ``cls`` unchanged otherwise, so it can be used as a decorator.
    """
    marker = '_receive_update_input'
    if hasattr(cls, marker):
        delattr(cls, marker)
    return cls
1258
-
1259
-
1260
- def _input_label_start(label: str):
1261
- # unify the input label repr.
1262
- return f'{label} // '
1263
-
1264
-
1265
- def _input_label_repr(name: str, label: Optional[str] = None):
1266
- # unify the input label repr.
1267
- return name if label is None else (_input_label_start(label) + str(name))
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ """
20
+ All the basic dynamics classes for ``brainstate``.
21
+
22
+ For handling dynamical systems:
23
+
24
+ - ``DynamicsGroup``: The class for a group of modules, which update ``Projection`` first,
25
+ then ``Dynamics``, finally others.
26
+ - ``Projection``: The class for the synaptic projection.
27
+ - ``Dynamics``: The class for the dynamical system.
28
+
29
+ For handling the delays:
30
+
31
+ - ``Delay``: The class for all delays.
32
+ - ``DelayAccess``: The class for the delay access.
33
+
34
+ """
35
+
36
+ from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, Tuple
37
+
38
+ import jax
39
+ import numpy as np
40
+
41
+ from brainstate import environ
42
+ from brainstate._state import State
43
+ from brainstate.graph import Node
44
+ from brainstate.typing import Size, ArrayLike
45
+ from ._delay import StateWithDelay, Delay
46
+ from ._module import Module
47
+
48
# Generic type variable available for type annotations in this module.
T = TypeVar('T')

# Public names exported by ``from <this module> import *``.
__all__ = [
    # base class for dynamical systems
    'Dynamics',

    # decorators controlling how before/after updates are invoked
    'receive_update_output',
    'not_receive_update_output',
    'receive_update_input',
    'not_receive_update_input',

    # lazy (and delayed) references to module states / outputs
    'Prefetch',
    'PrefetchDelay',
    'PrefetchDelayAt',
    'OutputDelayAt',
]
63
+
64
+
65
+ class Dynamics(Module):
66
+ """
67
+ Base class for implementing neural dynamics models in BrainState.
68
+
69
+ Dynamics classes represent the core computational units in neural simulations,
70
+ implementing the differential equations or update rules that govern neural activity.
71
+ This class provides infrastructure for managing neural populations, handling inputs,
72
+ and coordinating updates within the simulation framework.
73
+
74
+ The Dynamics class serves several key purposes:
75
+ 1. Managing neuron population geometry and size information
76
+ 2. Handling current and delta (instantaneous change) inputs to neurons
77
+ 3. Supporting before/after update hooks for computational dependencies
78
+ 4. Providing access to delayed state variables through the prefetch mechanism
79
+ 5. Establishing the execution order in neural network simulations
80
+
81
+ Parameters
82
+ ----------
83
+ in_size : Size
84
+ The geometry of the neuron population. Can be an integer (e.g., 10) for
85
+ 1D neuron arrays, or a tuple (e.g., (10, 10)) for multi-dimensional populations.
86
+ name : Optional[str], default=None
87
+ Optional name identifier for this dynamics module.
88
+
89
+ Attributes
90
+ ----------
91
+ in_size : tuple
92
+ The shape/geometry of the neuron population.
93
+ out_size : tuple
94
+ The output shape, typically matches in_size.
95
+ current_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
96
+ Dictionary of registered current input functions or arrays.
97
+ delta_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
98
+ Dictionary of registered delta input functions or arrays.
99
+ before_updates : Optional[Dict[Hashable, Callable]]
100
+ Dictionary of functions to call before the main update.
101
+ after_updates : Optional[Dict[Hashable, Callable]]
102
+ Dictionary of functions to call after the main update.
103
+
104
+ Notes
105
+ -----
106
+ In the BrainState execution sequence, Dynamics modules are updated after
107
+ Projection modules and before other module types, reflecting the natural
108
+ flow of information in neural systems.
109
+
110
+ There are several essential attributes:
111
+
112
+ - ``size``: the geometry of the neuron group. For example, `(10, )` denotes a line of
113
+ neurons, `(10, 10)` denotes a neuron group aligned in a 2D space, `(10, 15, 4)` denotes
114
+ a 3-dimensional neuron group.
115
+ - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
116
+ `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
117
+
118
+
119
+ See Also
120
+ --------
121
+ Module : Parent class providing base module functionality
122
+ Projection : Class for handling synaptic projections between neural populations
123
+ DynamicsGroup : Container for organizing multiple dynamics modules
124
+ """
125
+
126
+ __module__ = 'brainstate.nn'
127
+
128
+ graph_invisible_attrs = ()
129
+
130
+ # before updates
131
+ _before_updates: Optional[Dict[Hashable, Callable]]
132
+
133
+ # after updates
134
+ _after_updates: Optional[Dict[Hashable, Callable]]
135
+
136
+ # current inputs
137
+ _current_inputs: Optional[Dict[str, ArrayLike | Callable]]
138
+
139
+ # delta inputs
140
+ _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]
141
+
142
+ def __init__(self, in_size: Size, name: Optional[str] = None):
143
+ # initialize
144
+ super().__init__(name=name)
145
+
146
+ # geometry size of neuron population
147
+ if isinstance(in_size, (list, tuple)):
148
+ if len(in_size) <= 0:
149
+ raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
150
+ if not isinstance(in_size[0], (int, np.integer)):
151
+ raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
152
+ in_size = tuple(in_size)
153
+ elif isinstance(in_size, (int, np.integer)):
154
+ in_size = (in_size,)
155
+ else:
156
+ raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
157
+ self.in_size = in_size
158
+
159
+ # before updates
160
+ self._before_updates = None
161
+
162
+ # after updates
163
+ self._after_updates = None
164
+
165
+ # in-/out- size of neuron population
166
+ self.out_size = self.in_size
167
+
168
+ @property
169
+ def varshape(self):
170
+ """
171
+ Get the shape of variables in the neuron group.
172
+
173
+ This property provides access to the geometry (shape) of the neuron population,
174
+ which determines how variables and states are structured.
175
+
176
+ Returns
177
+ -------
178
+ tuple
179
+ A tuple representing the dimensional shape of the neuron group,
180
+ matching the in_size parameter provided during initialization.
181
+
182
+ See Also
183
+ --------
184
+ in_size : The input geometry specification for the neuron group
185
+ """
186
+ return self.in_size
187
+
188
+ def prefetch(self, item: str) -> 'Prefetch':
189
+ """
190
+ Create a reference to a state or variable that may not be initialized yet.
191
+
192
+ This method allows accessing module attributes or states before they are
193
+ fully defined, acting as a placeholder that will be resolved when called.
194
+ Particularly useful for creating references to variables that will be defined
195
+ during initialization or runtime.
196
+
197
+ Parameters
198
+ ----------
199
+ item : str
200
+ The name of the attribute or state to reference.
201
+
202
+ Returns
203
+ -------
204
+ Prefetch
205
+ A Prefetch object that provides access to the referenced item.
206
+
207
+ Examples
208
+ --------
209
+ >>> import brainstate
210
+ >>> import brainunit as u
211
+ >>> neuron = brainstate.nn.LIF(...)
212
+ >>> v_ref = neuron.prefetch('V') # Reference to voltage
213
+ >>> v_value = v_ref() # Get current value
214
+ >>> delayed_v = v_ref.delay.at(5.0 * u.ms) # Get delayed value
215
+ """
216
+ return Prefetch(self, item)
217
+
218
+ def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
219
+ """
220
+ Create a reference to a delayed state or variable in the module.
221
+
222
+ This method simplifies the process of accessing a delayed version of a state or variable
223
+ within the module. It first creates a prefetch reference to the specified state,
224
+ then specifies the delay time for accessing this state.
225
+
226
+ Args:
227
+ state (str): The name of the state or variable to reference.
228
+ delay_time (ArrayLike): The amount of time to delay the variable access,
229
+ typically in time units (e.g., milliseconds).
230
+ init (Callable, optional): An optional initialization function to provide
231
+ a default value if the delayed state is not yet available.
232
+
233
+ Returns:
234
+ PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
235
+ """
236
+ return PrefetchDelayAt(self, state, delay_time, init=init)
237
+
238
+ def output_delay(self, *delay_time) -> 'OutputDelayAt':
239
+ """
240
+ Create a reference to the delayed output of the module.
241
+
242
+ This method simplifies the process of accessing a delayed version of the module's output.
243
+ It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
244
+ at the specified delay time.
245
+
246
+ Args:
247
+ delay (Optional[ArrayLike]): The amount of time to delay the output access,
248
+ typically in time units (e.g., milliseconds). Defaults to None.
249
+
250
+ Returns:
251
+ OutputDelayAt: An object that provides access to the module's output at the specified delay time.
252
+ """
253
+ return OutputDelayAt(self, delay_time)
254
+
255
+ @property
256
+ def before_updates(self):
257
+ """
258
+ Get the dictionary of functions to execute before the module's update.
259
+
260
+ Returns
261
+ -------
262
+ dict or None
263
+ Dictionary mapping keys to callable functions that will be executed
264
+ before the main update, or None if no before updates are registered.
265
+
266
+ Notes
267
+ -----
268
+ Before updates are executed in the order they were registered whenever
269
+ the module is called via __call__.
270
+ """
271
+ return self._before_updates
272
+
273
+ @property
274
+ def after_updates(self):
275
+ """
276
+ Get the dictionary of functions to execute after the module's update.
277
+
278
+ Returns
279
+ -------
280
+ dict or None
281
+ Dictionary mapping keys to callable functions that will be executed
282
+ after the main update, or None if no after updates are registered.
283
+
284
+ Notes
285
+ -----
286
+ After updates are executed in the order they were registered whenever
287
+ the module is called via __call__, and may optionally receive the return
288
+ value from the update method.
289
+ """
290
+ return self._after_updates
291
+
292
+ def add_before_update(self, key: Any, fun: Callable):
293
+ """
294
+ Register a function to be executed before the module's update.
295
+
296
+ Parameters
297
+ ----------
298
+ key : Any
299
+ A unique identifier for the update function.
300
+ fun : Callable
301
+ The function to execute before the module's update.
302
+
303
+ Raises
304
+ ------
305
+ KeyError
306
+ If the key is already registered in before_updates.
307
+
308
+ Notes
309
+ -----
310
+ Internal method used by the module system to register dependencies.
311
+ """
312
+ if self._before_updates is None:
313
+ self._before_updates = dict()
314
+ if key in self.before_updates:
315
+ raise KeyError(f'{key} has been registered in before_updates of {self}')
316
+ self.before_updates[key] = fun
317
+
318
+ def add_after_update(self, key: Any, fun: Callable):
319
+ """
320
+ Register a function to be executed after the module's update.
321
+
322
+ Parameters
323
+ ----------
324
+ key : Any
325
+ A unique identifier for the update function.
326
+ fun : Callable
327
+ The function to execute after the module's update.
328
+
329
+ Raises
330
+ ------
331
+ KeyError
332
+ If the key is already registered in after_updates.
333
+
334
+ Notes
335
+ -----
336
+ Internal method used by the module system to register dependencies.
337
+ """
338
+ if self._after_updates is None:
339
+ self._after_updates = dict()
340
+ if key in self.after_updates:
341
+ raise KeyError(f'{key} has been registered in after_updates of {self}')
342
+ self.after_updates[key] = fun
343
+
344
+ def get_before_update(self, key: Any):
345
+ """
346
+ Retrieve a registered before-update function by its key.
347
+
348
+ Parameters
349
+ ----------
350
+ key : Any
351
+ The identifier of the before-update function to retrieve.
352
+
353
+ Returns
354
+ -------
355
+ Callable
356
+ The registered before-update function.
357
+
358
+ Raises
359
+ ------
360
+ KeyError
361
+ If the key is not registered in before_updates or if before_updates is None.
362
+ """
363
+ if self._before_updates is None:
364
+ raise KeyError(f'{key} is not registered in before_updates of {self}')
365
+ if key not in self.before_updates:
366
+ raise KeyError(f'{key} is not registered in before_updates of {self}')
367
+ return self.before_updates.get(key)
368
+
369
+ def get_after_update(self, key: Any):
370
+ """
371
+ Retrieve a registered after-update function by its key.
372
+
373
+ Parameters
374
+ ----------
375
+ key : Any
376
+ The identifier of the after-update function to retrieve.
377
+
378
+ Returns
379
+ -------
380
+ Callable
381
+ The registered after-update function.
382
+
383
+ Raises
384
+ ------
385
+ KeyError
386
+ If the key is not registered in after_updates or if after_updates is None.
387
+ """
388
+ if self._after_updates is None:
389
+ raise KeyError(f'{key} is not registered in after_updates of {self}')
390
+ if key not in self.after_updates:
391
+ raise KeyError(f'{key} is not registered in after_updates of {self}')
392
+ return self.after_updates.get(key)
393
+
394
+ def has_before_update(self, key: Any):
395
+ """
396
+ Check if a before-update function is registered with the given key.
397
+
398
+ Parameters
399
+ ----------
400
+ key : Any
401
+ The identifier to check for in the before_updates dictionary.
402
+
403
+ Returns
404
+ -------
405
+ bool
406
+ True if the key is registered in before_updates, False otherwise.
407
+ """
408
+ if self._before_updates is None:
409
+ return False
410
+ return key in self.before_updates
411
+
412
def has_after_update(self, key: Any):
    """
    Tell whether an after-update hook is registered under ``key``.

    Parameters
    ----------
    key : Any
        Identifier to look for among the after-update hooks.

    Returns
    -------
    bool
        False when no after-update registry exists yet (``_after_updates``
        is None) or ``key`` is absent; True otherwise.
    """
    registry = self._after_updates
    if registry is not None:
        return key in self.after_updates
    return False
429
+
430
def __call__(self, *args, **kwargs):
    """
    Run ``update`` together with its registered hooks.

    The call proceeds in three phases:

    1. Every before-update hook is invoked. Hooks carrying the
       ``_receive_update_input`` marker are called with ``(*args, **kwargs)``;
       all other hooks are called with no arguments.
    2. ``self.update(*args, **kwargs)`` is executed.
    3. Every after-update hook is invoked. Hooks carrying the
       ``_not_receive_update_output`` marker are called with no arguments;
       all other hooks receive the value returned by ``update``.

    Returns
    -------
    Any
        Whatever ``self.update`` returned.
    """
    before_hooks = self.before_updates
    if before_hooks is not None:
        for hook in before_hooks.values():
            if not hasattr(hook, '_receive_update_input'):
                hook()
            else:
                hook(*args, **kwargs)

    output = self.update(*args, **kwargs)

    # The after-update registry is read only after ``update`` has run.
    after_hooks = self.after_updates
    if after_hooks is not None:
        for hook in after_hooks.values():
            if not hasattr(hook, '_not_receive_update_output'):
                hook(output)
            else:
                hook()
    return output
454
+
455
+
456
class Prefetch(Node):
    """
    Prefetch a state or variable in a module before it is initialized.

    This class provides a mechanism to reference a module's state or attribute
    that may not have been initialized yet. It acts as a lazy reference: the
    attribute is looked up on ``module`` only when the reference is resolved
    (by calling it, or via ``get_item``/``get_item_value``).

    Use cases:

    - Access variables within dynamics modules that will be defined later
    - Create references to states across module boundaries
    - Enable access to delayed states through the `.delay` property

    Parameters
    ----------
    module : Dynamics
        The module that contains or will contain the referenced item.
    item : str
        The attribute name of the state or variable to prefetch.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(...)
    >>> v_reference = neuron.prefetch('V')  # Reference to voltage before initialization
    >>> v_value = v_reference()  # Get the current value
    >>> delay_ref = v_reference.delay.at(5.0 * u.ms)  # Reference voltage delayed by 5ms

    Notes
    -----
    When called, this class retrieves the current value of the referenced item.
    Use the `.delay` property to access delayed versions of the state.
    """

    def __init__(self, module: Dynamics, item: str):
        """
        Initialize a Prefetch object.

        Parameters
        ----------
        module : Dynamics
            The module that contains or will contain the referenced item.
        item : str
            The attribute name of the state or variable to prefetch.
        """
        super().__init__()
        self.module = module
        self.item = item

    @property
    def delay(self):
        """
        Access delayed versions of the prefetched item.

        Returns
        -------
        PrefetchDelay
            An object whose ``at(...)`` method selects the delay to access.
        """
        return PrefetchDelay(self.module, self.item)

    def __call__(self, *args, **kwargs):
        """
        Get the current value of the prefetched item.

        Any positional or keyword arguments are ignored. This is equivalent
        to :meth:`get_item_value`.

        Returns
        -------
        Any
            The current value of the referenced item. If the item is a State
            object, its ``value`` attribute is returned; otherwise the item
            itself is returned.

        Raises
        ------
        AttributeError
            If the module has no such attribute, or the attribute is None.
        """
        return self.get_item_value()

    def get_item_value(self):
        """
        Get the current value of the prefetched item.

        Returns
        -------
        Any
            The current value of the referenced item. If the item is a State
            object, its ``value`` attribute is returned; otherwise the item
            itself is returned.

        Raises
        ------
        AttributeError
            If the module has no such attribute, or the attribute is None.
        """
        item = _get_prefetch_item(self)
        return item.value if isinstance(item, State) else item

    def get_item(self):
        """
        Get the referenced item object itself, not its value.

        Returns
        -------
        Any
            The actual referenced item from the module, which could be a State
            object or any other attribute.

        Raises
        ------
        AttributeError
            If the module has no such attribute, or the attribute is None.
        """
        return _get_prefetch_item(self)
560
+
561
+
562
class PrefetchDelay(Node):
    """
    Provides access to delayed versions of a prefetched state or variable.

    This class acts as an intermediary for accessing delayed values of module
    variables. It doesn't retrieve values directly; the delay is selected via
    the `at()` method, which returns a :class:`PrefetchDelayAt`.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Access voltage delayed by 5ms
    >>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
    >>> delayed_value = delayed_v()  # Get the delayed value
    """

    def __init__(self, module: Dynamics, item: str):
        # Initialise the Node base like the sibling Node subclasses
        # (Prefetch, PrefetchDelayAt, OutputDelayAt) do.
        super().__init__()
        self.module = module
        self.item = item

    def at(self, *delay_time):
        """
        Specify the delay at which to access the variable.

        Parameters
        ----------
        *delay_time : ArrayLike
            The delay specification, collected into a tuple and forwarded
            unchanged to :class:`PrefetchDelayAt` (and from there to
            ``StateWithDelay.register_delay``). This is typically a single
            time value such as ``5.0 * u.ms``. Calling ``at()`` with no
            arguments yields an accessor for the current, undelayed value.

        Returns
        -------
        PrefetchDelayAt
            An object that provides access to the variable at the specified
            delay.
        """
        return PrefetchDelayAt(self.module, self.item, delay_time)
607
+
608
+
609
class PrefetchDelayAt(Node):
    """
    Provide access to a prefetched state or variable at a specific delay.

    This class is the final step of the prefetch delay chain
    (``Prefetch(...).delay.at(...)``). When constructed with a non-empty
    delay, it ensures that a ``StateWithDelay`` handler for ``item`` is
    registered as an after-update hook of ``module`` under the key
    ``'{item}-prefetch-delay'``. All accessors of the same item share that
    handler. It then registers the requested delay with the handler.

    Parameters
    ----------
    module : Dynamics
        The dynamics module that contains the referenced state or variable.
    item : str
        The name of the state or variable to access with delay.
    delay_time : Tuple or ArrayLike
        The delay to access, typically in time units (e.g., milliseconds).
        A single value that is not a tuple or list is wrapped into a
        one-element tuple. An empty tuple means "no delay": calling the
        object then returns the current value of the item.
    init : Callable, optional
        Passed as ``init`` to the ``StateWithDelay`` handler when this object
        creates it. Ignored if a handler for ``item`` already exists.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to voltage delayed by 5ms
    >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
    >>> # Get the delayed value
    >>> v_value = delayed_v()
    """

    def __init__(
        self,
        module: Dynamics,
        item: str,
        delay_time: Tuple,
        init: Callable = None
    ):
        """
        Initialize a PrefetchDelayAt object.

        Parameters
        ----------
        module : Dynamics
            The dynamics module that contains the referenced state or variable.
        item : str
            The name of the state or variable to access with delay.
        delay_time : Tuple or ArrayLike
            The delay to access, typically in time units (e.g., milliseconds).
        init : Callable, optional
            Initializer forwarded to a newly created ``StateWithDelay`` handler.
        """
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        self.item = item
        # Accept a bare delay value as a convenience.
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        self.delay_time = delay_time
        if len(delay_time) > 0:
            key = _get_prefetch_delay_key(item)
            # Create the shared handler only once per item. It is marked so
            # that it is called without the output of ``update``.
            if not module.has_after_update(key):
                module.add_after_update(
                    key,
                    not_receive_update_output(
                        StateWithDelay(module, item, init=init)
                    )
                )
            self.state_delay: StateWithDelay = module.get_after_update(key)
            self.delay_info = self.state_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """
        Retrieve the value of the item at the configured delay.

        Any positional or keyword arguments are ignored.

        Returns
        -------
        Any
            With no delay: the current value of the item. For a State object
            this is its ``value`` attribute; any other attribute is returned
            as-is, matching ``Prefetch``. With a delay: the value retrieved
            from the ``StateWithDelay`` handler at the registered delay.
        """
        if len(self.delay_time) == 0:
            item = _get_prefetch_item(self)
            # Plain (non-State) attributes have no ``.value``; return them directly.
            return item.value if isinstance(item, State) else item
        return self.state_delay.retrieve_at_step(*self.delay_info)
689
+
690
+
691
class OutputDelayAt(Node):
    """
    Provide access to a module's update output at a specific delay.

    On construction, a ``Delay`` buffer shaped like ``module.out_size`` is
    registered as an after-update hook of ``module``. This happens once per
    module, under the shared key ``'output-delay'``. The buffer is marked to
    receive the return value of every ``update`` call. Calling this object
    retrieves the buffered output at the registered delay.

    Parameters
    ----------
    module : Dynamics
        The dynamics module whose update output is delayed.
    delay_time : Tuple or ArrayLike
        The delay to access, typically in time units (e.g., milliseconds).
        A single value that is not a tuple or list is wrapped into a
        one-element tuple. A tuple or list is unpacked into
        ``Delay.register_delay``.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
    >>> # Create a reference to the neuron output delayed by 5ms
    >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
    >>> # Get the delayed value
    >>> spike_value = delayed_spike()
    """

    def __init__(
        self,
        module: Dynamics,
        delay_time: Tuple,
    ):
        super().__init__()
        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
        self.module = module
        # Accept a bare delay value, as PrefetchDelayAt does. Unpacking a
        # scalar delay with ``*`` below would otherwise fail.
        if not isinstance(delay_time, (tuple, list)):
            delay_time = (delay_time,)
        key = _get_output_delay_key()
        if not module.has_after_update(key):
            delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
            module.add_after_update(key, receive_update_output(delay))
        self.out_delay: Delay = module.get_after_update(key)
        self.delay_info = self.out_delay.register_delay(*delay_time)

    def __call__(self, *args, **kwargs):
        """Return the module output at the registered delay (arguments are ignored)."""
        return self.out_delay.retrieve_at_step(*self.delay_info)
735
+
736
+
737
+ def _get_prefetch_delay_key(item) -> str:
738
+ return f'{item}-prefetch-delay'
739
+
740
+
741
+ def _get_output_delay_key() -> str:
742
+ return f'output-delay'
743
+
744
+
745
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
    """
    Resolve the attribute referenced by a prefetch object.

    Looks up ``target.item`` on ``target.module``. A missing attribute and an
    attribute whose value is None are treated the same way.

    Raises
    ------
    AttributeError
        If the attribute is missing or None.
    """
    resolved = getattr(target.module, target.item, None)
    if resolved is not None:
        return resolved
    raise AttributeError(f'The target {target.module} should have an `{target.item}` attribute.')
750
+
751
+
752
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
    """
    Fetch the ``StateWithDelay`` handler registered for a prefetch target's item.

    The handler is looked up among the module's after-update hooks under the
    key built by ``_get_prefetch_delay_key(target.item)``. Note that the
    value returned is a ``StateWithDelay``, despite the ``Delay`` annotation.

    Raises
    ------
    AssertionError
        If ``target.module`` is not a ``Dynamics`` instance.
    KeyError
        If no hook is registered under that key (raised by
        ``get_after_update``).
    TypeError
        If the registered hook is not a ``StateWithDelay``.
    """
    owner = target.module
    assert isinstance(owner, Dynamics), (
        f'The target module should be an instance '
        f'of Dynamics. But got {target.module}.'
    )
    handler = owner.get_after_update(_get_prefetch_delay_key(target.item))
    if isinstance(handler, StateWithDelay):
        return handler
    raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
                    f'its delay. But got {handler}.')
762
+
763
+
764
def maybe_init_prefetch(target, *args, **kwargs):
    """
    Check that a prefetch target can be resolved, depending on its type.

    The function returns nothing. It only performs lookups, and any lookup
    that cannot be resolved raises:

    - :py:class:`Prefetch`: looks up the referenced attribute on its module.
    - :py:class:`PrefetchDelay`: looks up the ``StateWithDelay`` handler
      registered for the item.
    - :py:class:`PrefetchDelayAt`: nothing to do, because its delay is
      already registered when the object is constructed.

    Any other object is ignored.

    Parameters
    ----------
    target : Any
        The object to check. Only ``Prefetch``, ``PrefetchDelay`` and
        ``PrefetchDelayAt`` instances are acted upon.
    *args : Any
        Additional positional arguments (unused).
    **kwargs : Any
        Additional keyword arguments (unused).

    Returns
    -------
    None

    Raises
    ------
    AttributeError
        If a ``Prefetch`` target's attribute is missing or None.
    AssertionError
        If a ``PrefetchDelay`` target's module is not a ``Dynamics`` instance.
    KeyError
        If a ``PrefetchDelay`` target has no delay handler registered.
    TypeError
        If the handler registered for a ``PrefetchDelay`` target is not a
        ``StateWithDelay``.
    """
    if isinstance(target, Prefetch):
        _get_prefetch_item(target)

    elif isinstance(target, PrefetchDelay):
        _get_prefetch_item_delay(target)

    elif isinstance(target, PrefetchDelayAt):
        # Nothing to do: PrefetchDelayAt.__init__ already registers its delay
        # with the StateWithDelay handler.
        return
804
+
805
+
806
def receive_update_output(cls: object):
    """
    Mark an after-update hook so that it is called with the output of ``update``.

    This removes the ``_not_receive_update_output`` marker (set by
    ``not_receive_update_output``) when present; otherwise ``cls`` is left
    untouched. A hook without the marker is dispatched as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun(ret)

    Parameters
    ----------
    cls : object
        The hook (class or instance) to mark.

    Returns
    -------
    object
        ``cls`` itself, so the function can be used as a decorator.

    Notes
    -----
    NOTE(review): if the marker is only inherited from the class of an
    instance, ``delattr`` on the instance raises ``AttributeError``.
    """
    marker = '_not_receive_update_output'
    if not hasattr(cls, marker):
        return cls
    delattr(cls, marker)
    return cls
821
+
822
+
823
def not_receive_update_output(cls: T) -> T:
    """
    Mark an after-update hook so that it is called without arguments.

    This sets the ``_not_receive_update_output`` attribute to True. Hooks
    carrying that attribute are dispatched as::

        ret = model.update(*args, **kwargs)
        for fun in model.after_updates.values():
            fun()

    Parameters
    ----------
    cls : T
        The hook (class or instance) to mark; it is modified in place.

    Returns
    -------
    T
        ``cls`` itself, so the function can be used as a decorator.
    """
    setattr(cls, '_not_receive_update_output', True)
    return cls
837
+
838
+
839
def receive_update_input(cls: object):
    """
    Mark a before-update hook so that it receives the arguments of ``update``.

    This sets the ``_receive_update_input`` attribute to True. Hooks carrying
    that attribute are dispatched as::

        for fun in model.before_updates.values():
            fun(*args, **kwargs)
        model.update(*args, **kwargs)

    Parameters
    ----------
    cls : object
        The hook (class or instance) to mark; it is modified in place.

    Returns
    -------
    object
        ``cls`` itself, so the function can be used as a decorator.
    """
    setattr(cls, '_receive_update_input', True)
    return cls
854
+
855
+
856
def not_receive_update_input(cls: object):
    """
    Mark a before-update hook so that it is called without arguments.

    This removes the ``_receive_update_input`` marker (set by
    ``receive_update_input``) when present; otherwise ``cls`` is left
    untouched. A hook without the marker is dispatched as::

        for fun in model.before_updates.values():
            fun()
        model.update(*args, **kwargs)

    Parameters
    ----------
    cls : object
        The hook (class or instance) to mark.

    Returns
    -------
    object
        ``cls`` itself, so the function can be used as a decorator.

    Notes
    -----
    NOTE(review): if the marker is only inherited from the class of an
    instance, ``delattr`` on the instance raises ``AttributeError``.
    """
    marker = '_receive_update_input'
    if not hasattr(cls, marker):
        return cls
    delattr(cls, marker)
    return cls