brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_delay.py CHANGED
@@ -1,588 +1,575 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import numbers
17
- from functools import partial
18
- from typing import Optional, Dict, Callable, Union, Sequence
19
-
20
- import brainunit as u
21
- import jax
22
- import jax.numpy as jnp
23
- import numpy as np
24
-
25
- from brainstate import environ
26
- from brainstate._state import ShortTermState, State
27
- from brainstate.compile import jit_error_if
28
- from brainstate.graph import Node
29
- from brainstate.typing import ArrayLike, PyTree
30
- from ._collective_ops import call_order
31
- from ._module import Module
32
-
33
- __all__ = [
34
- 'Delay', 'DelayAccess', 'StateWithDelay',
35
- ]
36
-
37
- _DELAY_ROTATE = 'rotation'
38
- _DELAY_CONCAT = 'concat'
39
- _INTERP_LINEAR = 'linear_interp'
40
- _INTERP_ROUND = 'round'
41
-
42
-
43
- def _get_delay(delay_time):
44
- if delay_time is None:
45
- return 0. * environ.get_dt(), 0
46
- delay_step = delay_time / environ.get_dt()
47
- assert u.get_dim(delay_step) == u.DIMENSIONLESS
48
- delay_step = jnp.ceil(delay_step).astype(environ.ditype())
49
- return delay_time, delay_step
50
-
51
-
52
- class DelayAccess(Node):
53
- """
54
- Accessor node for a registered entry in a Delay instance.
55
-
56
- This node holds a reference to a Delay and a named entry that was
57
- registered on that Delay. It is used by graphs to query delayed
58
- values by delegating to the underlying Delay instance.
59
-
60
- Args:
61
- delay: The delay instance.
62
- time: The delay time.
63
- delay_entry: The delay entry.
64
- """
65
-
66
- __module__ = 'brainstate.nn'
67
-
68
- def __init__(
69
- self,
70
- delay: 'Delay',
71
- time: Union[None, int, float],
72
- delay_entry: str,
73
- ):
74
- super().__init__()
75
- self.refs = {'delay': delay}
76
- assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
77
- self._delay_entry = delay_entry
78
- self.delay_info = delay.register_entry(self._delay_entry, time)
79
-
80
- def update(self):
81
- return self.refs['delay'].at(self._delay_entry)
82
-
83
-
84
- class Delay(Module):
85
- """
86
- Generate Delays for the given :py:class:`~.State` instance.
87
-
88
- The data in this delay variable is arranged as::
89
-
90
- delay = 0 [ data
91
- delay = 1 data
92
- delay = 2 data
93
- ... ....
94
- ... ....
95
- delay = length-1 data
96
- delay = length data ]
97
-
98
- Args:
99
- time: int, float. The delay time.
100
- init: Any. The delay data. It can be a Python number, like float, int, boolean values.
101
- It can also be arrays. Or a callable function or instance of ``Connector``.
102
- Note that ``initial_delay_data`` should be arranged as the following way::
103
-
104
- delay = 1 [ data
105
- delay = 2 data
106
- ... ....
107
- ... ....
108
- delay = length-1 data
109
- delay = length data ]
110
- entries: optional, dict. The delay access entries.
111
- delay_method: str. The method used for updating delay. Default None.
112
- """
113
-
114
- __module__ = 'brainstate.nn'
115
-
116
- max_time: float #
117
- max_length: int
118
- history: Optional[ShortTermState]
119
-
120
- def __init__(
121
- self,
122
- target_info: PyTree,
123
- time: Optional[Union[int, float, u.Quantity]] = None, # delay time
124
- init: Optional[Union[ArrayLike, Callable]] = None, # delay data before t0
125
- entries: Optional[Dict] = None, # delay access entry
126
- delay_method: Optional[str] = _DELAY_ROTATE, # delay method
127
- interp_method: str = _INTERP_LINEAR, # interpolation method
128
- take_aware_unit: bool = False
129
- ):
130
- # target information
131
- self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
132
-
133
- # delay method
134
- assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (
135
- f'Un-supported delay method {delay_method}. '
136
- f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}'
137
- )
138
- self.delay_method = delay_method
139
-
140
- # interp method
141
- assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (
142
- f'Un-supported interpolation method {interp_method}. '
143
- f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}'
144
- )
145
- self.interp_method = interp_method
146
-
147
- # delay length and time
148
- with jax.ensure_compile_time_eval():
149
- self.max_time, delay_length = _get_delay(time)
150
- self.max_length = delay_length + 1
151
-
152
- super().__init__()
153
-
154
- # delay data
155
- if init is not None:
156
- if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
157
- raise TypeError(f'init should be Array, Callable, or None. But got {init}')
158
- self._init = init
159
- self._history = None
160
-
161
- # other info
162
- self._registered_entries = dict()
163
-
164
- # other info
165
- if entries is not None:
166
- for entry, delay_time in entries.items():
167
- if isinstance(delay_time, (tuple, list)):
168
- self.register_entry(entry, *delay_time)
169
- else:
170
- self.register_entry(entry, delay_time)
171
-
172
- self.take_aware_unit = take_aware_unit
173
- self._unit = None
174
-
175
- @property
176
- def history(self):
177
- return self._history
178
-
179
- @history.setter
180
- def history(self, value):
181
- self._history = value
182
-
183
- def _f_to_init(self, a, batch_size, length):
184
- shape = list(a.shape)
185
- if batch_size is not None:
186
- shape.insert(0, batch_size)
187
- shape.insert(0, length)
188
- if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
189
- data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
190
- elif callable(self._init):
191
- data = self._init(shape, dtype=a.dtype)
192
- else:
193
- assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
194
- data = jnp.zeros(shape, dtype=a.dtype)
195
- return data
196
-
197
- @call_order(3)
198
- def init_state(self, batch_size: int = None, **kwargs):
199
- fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
200
- self.history = ShortTermState(jax.tree.map(fun, self.target_info))
201
-
202
- def reset_state(self, batch_size: int = None, **kwargs):
203
- fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
204
- self.history.value = jax.tree.map(fun, self.target_info)
205
-
206
- def register_delay(self, *delay_time):
207
- """
208
- Register delay times and update the maximum delay configuration.
209
-
210
- This method processes one or more delay times, validates their format and consistency,
211
- and updates the delay buffer size if necessary. It handles both scalar and vector
212
- delay times, ensuring all vector delays have the same size.
213
-
214
- Args:
215
- *delay_time: Variable number of delay time arguments. The first argument should be
216
- the primary delay time (float, int, or array-like). Additional arguments are
217
- treated as indices or secondary delay parameters. All delay times should be
218
- non-negative numbers or arrays of the same size.
219
-
220
- Returns:
221
- tuple or None: If delay_time[0] is None, returns None. Otherwise, returns a tuple
222
- containing (delay_step, *delay_time[1:]) where delay_step is the computed
223
- delay step in integer time units, and the remaining elements are the
224
- additional delay parameters passed in.
225
-
226
- Raises:
227
- AssertionError: If no delay time is provided (empty delay_time).
228
- ValueError: If delay times have inconsistent sizes when using vector delays,
229
- or if delay times are not scalar or 1D arrays.
230
-
231
- Note:
232
- - The method updates self.max_time and self.max_length if the new delay
233
- requires a larger buffer size.
234
- - Delay steps are computed using the current environment time step (dt).
235
- - All delay indices (delay_time[1:]) must be integers.
236
- - Vector delays must all have the same size as the first delay time.
237
-
238
- Example:
239
- >>> delay_obj.register_delay(5.0) # Register 5ms delay
240
- >>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1) # Vector delay with indices
241
- """
242
- assert len(delay_time) >= 1, 'You should provide at least one delay time.'
243
- delay_size = u.math.size(delay_time[0])
244
- for dt in delay_time[1:]:
245
- assert jnp.issubdtype(u.math.get_dtype(dt), jnp.integer), f'The index should be integer. But got {dt}.'
246
- # delay_size = u.math.size(delay_time[0])
247
- # for dt in delay_time:
248
- # if u.math.ndim(dt) == 0:
249
- # pass
250
- # elif u.math.ndim(dt) == 1:
251
- # if u.math.size(dt) != delay_size:
252
- # raise ValueError(
253
- # f'The delay time should be a scalar or a vector with the same size. '
254
- # f'But got {delay_time}. The delay time {dt} has size {u.math.size(dt)}'
255
- # )
256
- # else:
257
- # raise ValueError(f'The delay time should be a scalar/vector. But got {dt}.')
258
- if delay_time[0] is None:
259
- return None
260
- with jax.ensure_compile_time_eval():
261
- time, delay_step = _get_delay(delay_time[0])
262
- max_delay_step = jnp.max(delay_step)
263
- self.max_time = u.math.max(time)
264
-
265
- # delay variable
266
- if self.max_length <= max_delay_step + 1:
267
- self.max_length = max_delay_step + 1
268
- return delay_step, *delay_time[1:]
269
-
270
- def register_entry(self, entry: str, *delay_time) -> 'Delay':
271
- """
272
- Register an entry to access the delay data.
273
-
274
- Args:
275
- entry: str. The entry to access the delay data.
276
- delay_time: The delay time of the entry, the first element is the delay time,
277
- the second and later element is the index.
278
- """
279
- if entry in self._registered_entries:
280
- raise KeyError(
281
- f'Entry {entry} has been registered. '
282
- f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
283
- f'The new delay for the key {entry} is {delay_time}. '
284
- f'You can use another key. '
285
- )
286
- delay_info = self.register_delay(*delay_time)
287
- self._registered_entries[entry] = delay_info
288
- return delay_info
289
-
290
- def access(self, entry: str, delay_time: Sequence) -> DelayAccess:
291
- """
292
- Create a DelayAccess object for a specific delay entry and delay time.
293
-
294
- Args:
295
- entry (str): The name of the delay entry to access.
296
- delay_time (Sequence): The delay time or parameters associated with the entry.
297
-
298
- Returns:
299
- DelayAccess: An object that provides access to the delay data for the specified entry and time.
300
- """
301
- return DelayAccess(self, delay_time, delay_entry=entry)
302
-
303
- def at(self, entry: str) -> ArrayLike:
304
- """
305
- Get the data at the given entry.
306
-
307
- Args:
308
- entry: str. The entry to access the data.
309
-
310
- Returns:
311
- The data.
312
- """
313
- assert isinstance(entry, str), (f'entry should be a string for describing the '
314
- f'entry of the delay data. But we got {entry}.')
315
- if entry not in self._registered_entries:
316
- raise KeyError(f'Does not find delay entry "{entry}".')
317
- delay_step = self._registered_entries[entry]
318
- if delay_step is None:
319
- delay_step = (0,)
320
- return self.retrieve_at_step(*delay_step)
321
-
322
- def retrieve_at_step(self, delay_step, *indices) -> PyTree:
323
- """
324
- Retrieve the delay data at the given delay time step (the integer to indicate the time step).
325
-
326
- Parameters
327
- ----------
328
- delay_step: int_like
329
- Retrieve the data at the given time step.
330
- indices: tuple
331
- The indices to slice the data.
332
-
333
- Returns
334
- -------
335
- delay_data: The delay data at the given delay step.
336
-
337
- """
338
- assert self.history is not None, 'The delay history is not initialized.'
339
- assert delay_step is not None, 'The delay step should be given.'
340
-
341
- if environ.get(environ.JIT_ERROR_CHECK, False):
342
- def _check_delay(delay_len):
343
- raise ValueError(
344
- f'The request delay length should be less than the '
345
- f'maximum delay {self.max_length - 1}. But we got {delay_len}'
346
- )
347
-
348
- jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)
349
-
350
- # rotation method
351
- with jax.ensure_compile_time_eval():
352
- if self.delay_method == _DELAY_ROTATE:
353
- i = environ.get(environ.I, desc='The time step index.')
354
- di = i - delay_step
355
- delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
356
- delay_idx = jax.lax.stop_gradient(delay_idx)
357
-
358
- elif self.delay_method == _DELAY_CONCAT:
359
- delay_idx = delay_step
360
-
361
- else:
362
- raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
363
-
364
- # the delay index
365
- if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
366
- raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
367
- indices = (delay_idx,) + indices
368
-
369
- # the delay data
370
- if self._unit is None:
371
- return jax.tree.map(lambda a: a[indices], self.history.value)
372
- else:
373
- return jax.tree.map(
374
- lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
375
- self.history.value,
376
- self._unit
377
- )
378
-
379
- def retrieve_at_time(self, delay_time, *indices) -> PyTree:
380
- """
381
- Retrieve the delay data at the given delay time step (the integer to indicate the time step).
382
-
383
- Parameters
384
- ----------
385
- delay_time: float
386
- Retrieve the data at the given time.
387
- indices: tuple
388
- The indices to slice the data.
389
-
390
- Returns
391
- -------
392
- delay_data: The delay data at the given delay step.
393
-
394
- """
395
- assert self.history is not None, 'The delay history is not initialized.'
396
- assert delay_time is not None, 'The delay time should be given.'
397
-
398
- current_time = environ.get(environ.T, desc='The current time.')
399
- dt = environ.get_dt()
400
-
401
- if environ.get(environ.JIT_ERROR_CHECK, False):
402
- def _check_delay(t_now, t_delay):
403
- raise ValueError(
404
- f'The request delay time should be within '
405
- f'[{t_now - self.max_time - dt}, {t_now}], '
406
- f'but we got {t_delay}'
407
- )
408
-
409
- jit_error_if(
410
- jnp.logical_or(
411
- delay_time > current_time,
412
- delay_time < current_time - self.max_time - dt
413
- ),
414
- _check_delay,
415
- current_time,
416
- delay_time
417
- )
418
-
419
- with jax.ensure_compile_time_eval():
420
- diff = current_time - delay_time
421
- float_time_step = diff / dt
422
-
423
- if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
424
- data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
425
- data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
426
- t_diff = float_time_step - jnp.floor(float_time_step)
427
- return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
428
-
429
- elif self.interp_method == _INTERP_ROUND: # "round" interpolation
430
- return self.retrieve_at_step(jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32), *indices)
431
-
432
- else: # raise error
433
- raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
434
- f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
435
-
436
- def update(self, current: PyTree) -> None:
437
- """
438
- Update delay variable with the new data.
439
- """
440
-
441
- with jax.ensure_compile_time_eval():
442
- assert self.history is not None, 'The delay history is not initialized.'
443
-
444
- if self.take_aware_unit and self._unit is None:
445
- self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)
446
-
447
- # update the delay data at the rotation index
448
- if self.delay_method == _DELAY_ROTATE:
449
- i = environ.get(environ.I)
450
- idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
451
- idx = jax.lax.stop_gradient(idx)
452
- self.history.value = jax.tree.map(
453
- lambda hist, cur: hist.at[idx].set(cur),
454
- self.history.value,
455
- current
456
- )
457
- # update the delay data at the first position
458
- elif self.delay_method == _DELAY_CONCAT:
459
- current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
460
- if self.max_length > 1:
461
- self.history.value = jax.tree.map(
462
- lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
463
- self.history.value,
464
- current
465
- )
466
- else:
467
- self.history.value = current
468
-
469
- else:
470
- raise ValueError(f'Unknown updating method "{self.delay_method}"')
471
-
472
-
473
- class StateWithDelay(Delay):
474
- """
475
- Delayed history buffer bound to a module state.
476
-
477
- StateWithDelay is a specialized :py:class:`~.Delay` that attaches to a
478
- concrete :py:class:`~brainstate._state.State` living on a target module
479
- (for example a membrane potential ``V`` on a neuron). It automatically
480
- maintains a rolling history of that state and exposes convenient helpers to
481
- retrieve the value at a given delay either by step or by time.
482
-
483
- In normal usage you rarely instantiate this class directly. It is created
484
- implicitly when using the prefetch-delay helpers on a Dynamics module, e.g.:
485
-
486
- - ``module.prefetch('V').delay.at(5.0 * u.ms)``
487
- - ``module.prefetch_delay('V', 5.0 * u.ms)``
488
-
489
- Both will construct a StateWithDelay bound to ``module.V`` under the hood
490
- and register the requested delay, so you can retrieve the delayed value
491
- inside your update rules.
492
-
493
- Parameters
494
- ----------
495
- target : :py:class:`~brainstate.graph.Node`
496
- The module object that owns the state to track.
497
- item : str
498
- The attribute name of the target state on ``target`` (must be a
499
- :py:class:`~brainstate._state.State`).
500
- init : Callable, optional
501
- Optional initializer used to fill the history buffer before ``t0``
502
- when delays request values from the past that hasn't been simulated yet.
503
- The callable receives ``(shape, dtype)`` and must return an array.
504
- If not provided, zeros are used. You may also pass a scalar/array
505
- literal via the underlying Delay API when constructing manually.
506
- delay_method : {"rotation", "concat"}, default "rotation"
507
- Internal buffering strategy (inherits behavior from :py:class:`~.Delay`).
508
- "rotation" keeps a ring buffer; "concat" shifts by concatenation.
509
-
510
- Attributes
511
- ----------
512
- state : :py:class:`~brainstate._state.State`
513
- The concrete state object being tracked.
514
- history : :py:class:`~brainstate._state.ShortTermState`
515
- Rolling time axis buffer with shape ``[length, *state.shape]``.
516
- max_time : float
517
- Maximum time span currently supported by the buffer.
518
- max_length : int
519
- Buffer length in steps (``ceil(max_time/dt)+1``).
520
-
521
- Notes
522
- -----
523
- - This class inherits all retrieval utilities from :py:class:`~.Delay`:
524
- use :py:meth:`retrieve_at_step` when you know the integer delay steps,
525
- or :py:meth:`retrieve_at_time` for continuous-time queries with optional
526
- linear/round interpolation.
527
- - It is registered as an "after-update" hook on the owning Dynamics so the
528
- buffer is updated automatically after each simulation step.
529
-
530
- Examples
531
- --------
532
- Access a neuron's membrane potential 5 ms in the past:
533
-
534
- >>> import brainunit as u
535
- >>> import brainstate as bst
536
- >>> lif = bst.nn.LIF(100)
537
- >>> # Create a delayed accessor to V(t-5ms)
538
- >>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
539
- >>> # Inside another module's update you can read the delayed value
540
- >>> v_t_minus_5ms = v_delay()
541
-
542
- Register multiple delay taps and index-specific delays:
543
-
544
- >>> # Under the hood, a StateWithDelay is created and you can register
545
- >>> # additional taps (in steps or time) via its Delay interface
546
- >>> _ = lif.prefetch('V').delay.at(2.0 * u.ms) # additional delay
547
- >>> # Direct access to buffer by steps (advanced)
548
- >>> # lif._get_after_update('V-prefetch-delay').retrieve_at_step(3)
549
- """
550
-
551
- __module__ = 'brainstate.nn'
552
-
553
- state: State # state
554
-
555
- def __init__(
556
- self,
557
- target: Node,
558
- item: str,
559
- init: Callable = None,
560
- delay_method: Optional[str] = _DELAY_ROTATE,
561
- ):
562
- super().__init__(None, init=init, delay_method=delay_method)
563
-
564
- self._target = target
565
- self._target_term = item
566
-
567
- @property
568
- def state(self) -> State:
569
- r = getattr(self._target, self._target_term)
570
- if not isinstance(r, State):
571
- raise TypeError(f'The term "{self._target_term}" in the module "{self._target}" is not a State.')
572
- return r
573
-
574
- @call_order(3)
575
- def init_state(self, *args, **kwargs):
576
- """
577
- State initialization function.
578
- """
579
- state = self.state
580
- self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), state.value)
581
- super().init_state(*args, **kwargs)
582
-
583
- def update(self, *args) -> None:
584
- """
585
- Update the delay variable with the new data.
586
- """
587
- value = self.state.value
588
- return super().update(value)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numbers
17
+ from functools import partial
18
+ from typing import Optional, Dict, Callable, Union, Sequence
19
+
20
+ import brainunit as u
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+
25
+ from brainstate import environ
26
+ from brainstate._state import ShortTermState, State
27
+ from brainstate.graph import Node
28
+ from brainstate.transform import jit_error_if
29
+ from brainstate.typing import ArrayLike, PyTree
30
+ from ._collective_ops import call_order
31
+ from ._module import Module
32
+
33
+ __all__ = [
34
+ 'Delay', 'DelayAccess', 'StateWithDelay',
35
+ ]
36
+
37
+ _DELAY_ROTATE = 'rotation'
38
+ _DELAY_CONCAT = 'concat'
39
+ _INTERP_LINEAR = 'linear_interp'
40
+ _INTERP_ROUND = 'round'
41
+
42
+
43
def _get_delay(delay_time):
    """
    Convert a delay time into the pair ``(delay_time, delay_step)``.

    The step count is ``ceil(delay_time / dt)``, where ``dt`` is the environment
    time step, cast to the environment's default integer dtype. ``None`` means
    "no delay" and yields ``(0. * dt, 0)``.

    Raises:
        AssertionError: If ``delay_time / dt`` is not dimensionless.
    """
    dt = environ.get_dt()
    if delay_time is None:
        return 0. * dt, 0
    ratio = delay_time / dt
    assert u.get_dim(ratio) == u.DIMENSIONLESS
    # NOTE(review): ``ceil`` of a floating-point ratio can overshoot by one step
    # when delay_time is meant to be an exact multiple of dt but the division is
    # inexact (e.g. 1.1 / 0.1 == 11.000000000000002 in float64) -- confirm
    # whether callers need a tolerance here.
    n_steps = jnp.ceil(ratio).astype(environ.ditype())
    return delay_time, n_steps
50
+
51
+
52
class DelayAccess(Node):
    """
    Graph node that reads one named entry of a :class:`Delay`.

    On construction the entry is registered on the given delay through
    ``Delay.register_entry``. Calling :meth:`update` afterwards returns the
    delayed value for that entry through ``Delay.at``.

    Args:
        delay: The delay instance to read from. Must be a :class:`Delay`.
        *time: The delay time of the entry, optionally followed by integer
            indices used to slice the delayed data.
        entry: The name under which the entry is registered.

    Attributes:
        delay: The wrapped delay instance.
        delay_info: The value returned by ``Delay.register_entry`` for this
            entry (``None`` or ``(delay_step, *indices)``).
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        delay: 'Delay',
        *time,
        entry: str,
    ):
        super().__init__()
        self.delay = delay
        assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
        self._delay_entry = entry
        self.delay_info = delay.register_entry(entry, *time)

    def update(self):
        """Return the current delayed value of the registered entry."""
        owner = self.delay
        return owner.at(self._delay_entry)
82
+
83
+
84
class Delay(Module):
    """
    Delay variable for storing short-term history data.

    The history buffer is arranged along its leading axis as::

        delay = 0             [ data
        delay = 1               data
        delay = 2               data
        ...                     ....
        ...                     ....
        delay = length-1        data
        delay = length          data ]

    Args:
        target_info: PyTree. A pytree whose leaves expose ``shape`` and ``dtype``.
            Only those two properties are kept (as ``jax.ShapeDtypeStruct``); the
            buffer itself is allocated in ``init_state``.
        time: int, float, or Quantity. The maximum delay time. ``None`` means no
            delay beyond the current step (a buffer of length 1).
        init: Any. The delay data before ``t0``. It can be a Python number, a
            ``jax.Array`` / ``np.ndarray`` (broadcast to the buffer shape), or a
            callable invoked as ``init(shape, dtype=...)``. If ``None``, zeros are
            used. Note that the initial delay data should be arranged as::

                delay = 1             [ data
                delay = 2               data
                ...                     ....
                ...                     ....
                delay = length-1        data
                delay = length          data ]

        entries: optional, dict. Delay access entries, mapping an entry name to
            either a delay time or a ``(delay_time, *indices)`` tuple/list.
        delay_method: str. How the buffer is updated: ``'rotation'`` (default,
            ring buffer) or ``'concat'`` (prepend newest, drop oldest).
        interp_method: str. How ``retrieve_at_time`` maps a continuous time onto
            buffer steps: ``'linear_interp'`` (default) or ``'round'``.
        take_aware_unit: bool. If True, the unit of the first value passed to
            ``update`` is recorded and re-attached to retrieved data.
    """

    __module__ = 'brainstate.nn'

    max_time: float  # largest registered delay time (same dimension as ``dt``)
    max_length: int  # number of history slots: largest delay step + 1
    history: Optional[ShortTermState]

    def __init__(
        self,
        target_info: PyTree,
        time: Optional[Union[int, float, u.Quantity]] = None,  # delay time
        init: Optional[Union[ArrayLike, Callable]] = None,  # delay data before t0
        entries: Optional[Dict] = None,  # delay access entry
        delay_method: Optional[str] = _DELAY_ROTATE,  # delay method
        interp_method: str = _INTERP_LINEAR,  # interpolation method
        take_aware_unit: bool = False
    ):
        # Keep only shape/dtype of each target leaf.
        self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)

        # delay method
        assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (
            f'Un-supported delay method {delay_method}. '
            f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}'
        )
        self.delay_method = delay_method

        # interp method
        assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (
            f'Un-supported interpolation method {interp_method}. '
            f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}'
        )
        self.interp_method = interp_method

        # delay length and time (evaluated eagerly so they stay concrete)
        with jax.ensure_compile_time_eval():
            self.max_time, delay_length = _get_delay(time)
            self.max_length = delay_length + 1

        super().__init__()

        # delay data
        if init is not None:
            if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
                raise TypeError(f'init should be Array, Callable, or None. But got {init}')
        self._init = init
        self._history = None

        # registered access entries: name -> None or (delay_step, *indices)
        self._registered_entries = dict()

        # pre-declared entries
        if entries is not None:
            for entry, delay_time in entries.items():
                if isinstance(delay_time, (tuple, list)):
                    self.register_entry(entry, *delay_time)
                else:
                    self.register_entry(entry, delay_time)

        self.take_aware_unit = take_aware_unit
        self._unit = None

    @property
    def history(self):
        """The ShortTermState holding the buffer, or ``None`` before ``init_state``."""
        return self._history

    @history.setter
    def history(self, value):
        self._history = value

    def _f_to_init(self, a, batch_size, length):
        """
        Build the initial buffer for one target leaf ``a``.

        The buffer shape is ``[length, *a.shape]``, or
        ``[length, batch_size, *a.shape]`` when ``batch_size`` is given.
        """
        shape = list(a.shape)
        if batch_size is not None:
            shape.insert(0, batch_size)
        shape.insert(0, length)
        if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
            data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
        elif callable(self._init):
            data = self._init(shape, dtype=a.dtype)
        else:
            assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
            data = jnp.zeros(shape, dtype=a.dtype)
        return data

    @call_order(3)
    def init_state(self, batch_size: int = None, **kwargs):
        """Allocate the history buffer as a new ShortTermState."""
        fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
        self.history = ShortTermState(jax.tree.map(fun, self.target_info))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Refill the existing history buffer with its initial contents."""
        fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
        self.history.value = jax.tree.map(fun, self.target_info)

    def register_delay(self, *delay_time):
        """
        Register a delay and grow the buffer if needed.

        Args:
            *delay_time: The first element is the delay time (a scalar or an array
                of times, in the same dimension as ``dt``), or ``None``. Any
                further elements are integer indices used later to slice the
                delayed data.

        Returns:
            ``None`` if ``delay_time[0]`` is ``None``. Otherwise the tuple
            ``(delay_step, *delay_time[1:])``, where ``delay_step`` is
            ``ceil(time / dt)`` as an integer array.

        Raises:
            AssertionError: If no delay time is given, if any index does not have
                an integer dtype, or if ``time / dt`` is not dimensionless.

        Note:
            ``max_length`` and ``max_time`` only ever grow. Registering a
            shorter delay therefore never invalidates previously registered
            entries.

        Example:
            >>> delay_obj.register_delay(5.0)  # delay of 5.0 time units
            >>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1)  # vector delay with indices
        """
        assert len(delay_time) >= 1, 'You should provide at least one delay time.'
        for index in delay_time[1:]:
            assert jnp.issubdtype(u.math.get_dtype(index), jnp.integer), f'The index should be integer. But got {index}.'
        if delay_time[0] is None:
            return None
        with jax.ensure_compile_time_eval():
            time, delay_step = _get_delay(delay_time[0])
            max_delay_step = jnp.max(delay_step)

            # Only grow ``max_time``. Overwriting it with a smaller value would make
            # ``retrieve_at_time`` reject delays that earlier entries still need,
            # while ``max_length`` (below) is never shrunk.
            new_max_time = u.math.max(time)
            if new_max_time > self.max_time:
                self.max_time = new_max_time

            # grow the buffer so that step ``max_delay_step`` is addressable
            if self.max_length <= max_delay_step + 1:
                self.max_length = max_delay_step + 1
        return delay_step, *delay_time[1:]

    def register_entry(self, entry: str, *delay_time) -> Optional[tuple]:
        """
        Register an entry to access the delay data.

        Args:
            entry: str. The entry name used to access the delay data.
            delay_time: The delay time of the entry. The first element is the
                delay time; any later elements are integer indices.

        Returns:
            The result of :meth:`register_delay`: ``None`` or
            ``(delay_step, *indices)``.

        Raises:
            KeyError: If ``entry`` has already been registered.
        """
        if entry in self._registered_entries:
            raise KeyError(
                f'Entry {entry} has been registered. '
                f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
                f'The new delay for the key {entry} is {delay_time}. '
                f'You can use another key. '
            )
        delay_info = self.register_delay(*delay_time)
        self._registered_entries[entry] = delay_info
        return delay_info

    def access(self, entry: str, *delay_time) -> DelayAccess:
        """
        Create a DelayAccess object for a specific delay entry and delay time.

        Args:
            entry (str): The name of the delay entry to register and access.
            *delay_time: The delay time of the entry, optionally followed by
                integer indices.

        Returns:
            DelayAccess: A node whose ``update()`` returns the delayed data for
            the entry.
        """
        # Unpack so DelayAccess (and hence register_entry) receives the delay
        # time as its first positional argument rather than a nested tuple.
        return DelayAccess(self, *delay_time, entry=entry)

    def at(self, entry: str) -> ArrayLike:
        """
        Get the data at the given entry.

        Args:
            entry: str. The entry to access the data.

        Returns:
            The delayed data. An entry registered with a ``None`` delay reads
            step 0.

        Raises:
            KeyError: If ``entry`` has not been registered.
        """
        assert isinstance(entry, str), (f'entry should be a string for describing the '
                                        f'entry of the delay data. But we got {entry}.')
        if entry not in self._registered_entries:
            raise KeyError(f'Does not find delay entry "{entry}".')
        delay_step = self._registered_entries[entry]
        if delay_step is None:
            delay_step = (0,)
        return self.retrieve_at_step(*delay_step)

    def retrieve_at_step(self, delay_step, *indices) -> PyTree:
        """
        Retrieve the delay data at the given delay time step (an integer number of steps).

        Parameters
        ----------
        delay_step: int_like
            Number of steps into the past to read.
        indices: tuple
            The indices to slice the data.

        Returns
        -------
        delay_data: The delay data at the given delay step.

        Raises
        ------
        ValueError
            If the delay method is unknown or the computed index is not an
            integer. When ``environ.JIT_ERROR_CHECK`` is enabled, also raised
            if ``delay_step >= max_length``.
        """
        assert self.history is not None, 'The delay history is not initialized.'
        assert delay_step is not None, 'The delay step should be given.'

        if environ.get(environ.JIT_ERROR_CHECK, False):
            def _check_delay(delay_len):
                raise ValueError(
                    f'The request delay length should be less than the '
                    f'maximum delay {self.max_length - 1}. But we got {delay_len}'
                )

            jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)

        with jax.ensure_compile_time_eval():
            if self.delay_method == _DELAY_ROTATE:
                # ring buffer: step ``i`` was written at ``i % max_length``
                i = environ.get(environ.I, desc='The time step index.')
                di = i - delay_step
                delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
                delay_idx = jax.lax.stop_gradient(delay_idx)

            elif self.delay_method == _DELAY_CONCAT:
                # newest value is always at row 0
                delay_idx = delay_step

            else:
                raise ValueError(f'Unknown delay updating method "{self.delay_method}"')

            # the delay index
            if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
                raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
            indices = (delay_idx,) + indices

            # the delay data
            if self._unit is None:
                return jax.tree.map(lambda a: a[indices], self.history.value)
            else:
                return jax.tree.map(
                    lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
                    self.history.value,
                    self._unit
                )

    def retrieve_at_time(self, delay_time, *indices) -> PyTree:
        """
        Retrieve the delay data at the given (absolute) time point.

        Parameters
        ----------
        delay_time: float
            The past time point to read. The number of steps back is
            ``(t - delay_time) / dt``, where ``t`` is the current environment
            time. It is mapped onto buffer steps by ``interp_method``.
        indices: tuple
            The indices to slice the data.

        Returns
        -------
        delay_data: The (possibly interpolated) delay data at ``delay_time``.

        Raises
        ------
        ValueError
            If ``interp_method`` is unknown. When ``environ.JIT_ERROR_CHECK`` is
            enabled, also raised if ``delay_time`` lies outside
            ``[t - max_time - dt, t]``.
        """
        assert self.history is not None, 'The delay history is not initialized.'
        assert delay_time is not None, 'The delay time should be given.'

        current_time = environ.get(environ.T, desc='The current time.')
        dt = environ.get_dt()

        if environ.get(environ.JIT_ERROR_CHECK, False):
            def _check_delay(t_now, t_delay):
                raise ValueError(
                    f'The request delay time should be within '
                    f'[{t_now - self.max_time - dt}, {t_now}], '
                    f'but we got {t_delay}'
                )

            jit_error_if(
                jnp.logical_or(
                    delay_time > current_time,
                    delay_time < current_time - self.max_time - dt
                ),
                _check_delay,
                current_time,
                delay_time
            )

        with jax.ensure_compile_time_eval():
            diff = current_time - delay_time
            float_time_step = diff / dt

            if self.interp_method == _INTERP_LINEAR:  # "linear" interpolation
                data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
                data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
                t_diff = float_time_step - jnp.floor(float_time_step)
                return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)

            elif self.interp_method == _INTERP_ROUND:  # "round" interpolation
                return self.retrieve_at_step(jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32), *indices)

            else:  # raise error
                raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
                                 f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')

    def update(self, current: PyTree) -> None:
        """
        Update delay variable with the new data.

        With ``'rotation'`` the value is written at index ``i % max_length``,
        where ``i`` is the environment step index. With ``'concat'`` the value
        is prepended and the oldest row dropped. If ``take_aware_unit`` is set,
        the unit of the first ``current`` is recorded for later retrieval.
        """

        with jax.ensure_compile_time_eval():
            assert self.history is not None, 'The delay history is not initialized.'

            if self.take_aware_unit and self._unit is None:
                self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)

            # update the delay data at the rotation index
            if self.delay_method == _DELAY_ROTATE:
                i = environ.get(environ.I)
                idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
                idx = jax.lax.stop_gradient(idx)
                self.history.value = jax.tree.map(
                    lambda hist, cur: hist.at[idx].set(cur),
                    self.history.value,
                    current
                )
            # update the delay data at the first position
            elif self.delay_method == _DELAY_CONCAT:
                current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
                if self.max_length > 1:
                    self.history.value = jax.tree.map(
                        lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
                        self.history.value,
                        current
                    )
                else:
                    self.history.value = current

            else:
                raise ValueError(f'Unknown updating method "{self.delay_method}"')
458
+
459
+
460
class StateWithDelay(Delay):
    """
    Delayed history buffer bound to a module state.

    StateWithDelay is a :py:class:`~.Delay` whose tracked value is read from a
    :py:class:`~brainstate._state.State` attribute of a target module, for
    example a membrane potential ``V`` on a neuron. Each call to
    :py:meth:`update` pushes the state's current value into the history buffer.
    All retrieval helpers of :py:class:`~.Delay` are available on it.

    In normal usage you rarely instantiate this class directly. It is intended
    to be created by the prefetch-delay helpers on a Dynamics module, e.g.:

    - ``module.prefetch('V').delay.at(5.0 * u.ms)``
    - ``module.prefetch_delay('V', 5.0 * u.ms)``

    Parameters
    ----------
    target : :py:class:`~brainstate.graph.Node`
        The module object that owns the state to track.
    item : str
        The attribute name of the target state on ``target``. It must resolve
        to a :py:class:`~brainstate._state.State`; otherwise accessing
        :py:attr:`state` raises ``TypeError``.
    init : Callable, optional
        Fills the history buffer before ``t0``. A callable is invoked as
        ``init(shape, dtype=...)`` and must return an array. A Python number or
        array is broadcast to the buffer shape. If not provided, zeros are used.
    delay_method : {"rotation", "concat"}, default "rotation"
        Buffering strategy. "rotation" writes into a ring buffer at index
        ``i % max_length``; "concat" prepends the newest value and drops the
        oldest.
    interp_method : {"linear_interp", "round"}, default "linear_interp"
        How :py:meth:`retrieve_at_time` maps a continuous time onto buffer steps.
    take_aware_unit : bool, default False
        If True, the unit of the first value written is recorded and
        re-attached to retrieved data.

    Attributes
    ----------
    state : :py:class:`~brainstate._state.State`
        The concrete state object being tracked (looked up on every access).
    history : :py:class:`~brainstate._state.ShortTermState`
        Rolling buffer of shape ``[max_length, *state.value.shape]``, or
        ``[max_length, batch_size, *state.value.shape]`` when ``init_state``
        receives a ``batch_size``.
    max_time : float
        Maximum time span currently supported by the buffer.
    max_length : int
        Buffer length in steps (``ceil(max_time/dt)+1``).

    Notes
    -----
    - Use :py:meth:`retrieve_at_step` when you know the integer delay steps.
      Use :py:meth:`retrieve_at_time` for continuous-time queries with
      linear or round interpolation.
    - The buffer only advances when :py:meth:`update` is called. Owning
      modules are expected to call it after each simulation step.

    Examples
    --------
    Access a neuron's membrane potential 5 ms in the past:

    >>> import brainunit as u
    >>> import brainstate as brainstate
    >>> lif = brainstate.nn.LIF(100)
    >>> # Create a delayed accessor to V(t-5ms)
    >>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
    >>> # Inside another module's update you can read the delayed value
    >>> v_t_minus_5ms = v_delay()

    Register multiple delay taps and index-specific delays:

    >>> # Under the hood, a StateWithDelay is created and you can register
    >>> # additional taps (in steps or time) via its Delay interface
    >>> _ = lif.prefetch('V').delay.at(2.0 * u.ms)  # additional delay
    >>> # Direct access to buffer by steps (advanced)
    >>> # lif._get_after_update('V-prefetch-delay').retrieve_at_step(3)
    """

    __module__ = 'brainstate.nn'

    state: State  # the tracked state (exposed through the property below)

    def __init__(
        self,
        target: Node,
        item: str,
        init: Callable = None,
        delay_method: Optional[str] = _DELAY_ROTATE,
        interp_method: str = _INTERP_LINEAR,
        take_aware_unit: bool = False,
    ):
        # ``target_info`` is not known yet. It is rebuilt from the tracked
        # state in ``init_state``, so ``None`` (an empty pytree) is passed here.
        super().__init__(
            None,
            init=init,
            delay_method=delay_method,
            interp_method=interp_method,
            take_aware_unit=take_aware_unit,
        )

        self._target = target
        self._target_term = item

    @property
    def state(self) -> State:
        """Return ``getattr(target, item)``, raising TypeError if it is not a State."""
        r = getattr(self._target, self._target_term)
        if not isinstance(r, State):
            raise TypeError(f'The term "{self._target_term}" in the module "{self._target}" is not a State.')
        return r

    @call_order(3)
    def init_state(self, *args, **kwargs):
        """
        State initialization function.

        Rebuilds ``target_info`` from the tracked state's current value (shape
        and dtype only), then allocates the buffer via ``Delay.init_state``.

        NOTE(review): the state's value may already carry a batch axis. If a
        ``batch_size`` is also forwarded here, ``Delay._f_to_init`` prepends a
        second one. Confirm the caller's convention.
        """
        state = self.state
        self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), state.value)
        super().init_state(*args, **kwargs)

    def update(self, *args) -> None:
        """
        Update the delay variable with the new data.

        Positional arguments are ignored. The value written is always the
        tracked state's current value.
        """
        value = self.state.value
        return super().update(value)