brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_delay.py CHANGED
@@ -1,575 +1,575 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import numbers
17
- from functools import partial
18
- from typing import Optional, Dict, Callable, Union, Sequence
19
-
20
- import brainunit as u
21
- import jax
22
- import jax.numpy as jnp
23
- import numpy as np
24
-
25
- from brainstate import environ
26
- from brainstate._state import ShortTermState, State
27
- from brainstate.graph import Node
28
- from brainstate.transform import jit_error_if
29
- from brainstate.typing import ArrayLike, PyTree
30
- from ._collective_ops import call_order
31
- from ._module import Module
32
-
33
- __all__ = [
34
- 'Delay', 'DelayAccess', 'StateWithDelay',
35
- ]
36
-
37
- _DELAY_ROTATE = 'rotation'
38
- _DELAY_CONCAT = 'concat'
39
- _INTERP_LINEAR = 'linear_interp'
40
- _INTERP_ROUND = 'round'
41
-
42
-
43
- def _get_delay(delay_time):
44
- if delay_time is None:
45
- return 0. * environ.get_dt(), 0
46
- delay_step = delay_time / environ.get_dt()
47
- assert u.get_dim(delay_step) == u.DIMENSIONLESS
48
- delay_step = jnp.ceil(delay_step).astype(environ.ditype())
49
- return delay_time, delay_step
50
-
51
-
52
- class DelayAccess(Node):
53
- """
54
- Accessor node for a registered entry in a Delay instance.
55
-
56
- This node holds a reference to a Delay and a named entry that was
57
- registered on that Delay. It is used by graphs to query delayed
58
- values by delegating to the underlying Delay instance.
59
-
60
- Args:
61
- delay: The delay instance.
62
- *time: The delay time.
63
- entry: The delay entry.
64
- """
65
-
66
- __module__ = 'brainstate.nn'
67
-
68
- def __init__(
69
- self,
70
- delay: 'Delay',
71
- *time,
72
- entry: str,
73
- ):
74
- super().__init__()
75
- self.delay = delay
76
- assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
77
- self._delay_entry = entry
78
- self.delay_info = delay.register_entry(self._delay_entry, *time)
79
-
80
- def update(self):
81
- return self.delay.at(self._delay_entry)
82
-
83
-
84
- class Delay(Module):
85
- """
86
- Delay variable for storing short-term history data.
87
-
88
- The data in this delay variable is arranged as::
89
-
90
- delay = 0 [ data
91
- delay = 1 data
92
- delay = 2 data
93
- ... ....
94
- ... ....
95
- delay = length-1 data
96
- delay = length data ]
97
-
98
- Args:
99
- time: int, float. The delay time.
100
- init: Any. The delay data. It can be a Python number, like float, int, boolean values.
101
- It can also be arrays. Or a callable function or instance of ``Connector``.
102
- Note that ``initial_delay_data`` should be arranged as the following way::
103
-
104
- delay = 1 [ data
105
- delay = 2 data
106
- ... ....
107
- ... ....
108
- delay = length-1 data
109
- delay = length data ]
110
- entries: optional, dict. The delay access entries.
111
- delay_method: str. The method used for updating delay. Default None.
112
- """
113
-
114
- __module__ = 'brainstate.nn'
115
-
116
- max_time: float #
117
- max_length: int
118
- history: Optional[ShortTermState]
119
-
120
- def __init__(
121
- self,
122
- target_info: PyTree,
123
- time: Optional[Union[int, float, u.Quantity]] = None, # delay time
124
- init: Optional[Union[ArrayLike, Callable]] = None, # delay data before t0
125
- entries: Optional[Dict] = None, # delay access entry
126
- delay_method: Optional[str] = _DELAY_ROTATE, # delay method
127
- interp_method: str = _INTERP_LINEAR, # interpolation method
128
- take_aware_unit: bool = False
129
- ):
130
- # target information
131
- self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
132
-
133
- # delay method
134
- assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (
135
- f'Un-supported delay method {delay_method}. '
136
- f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}'
137
- )
138
- self.delay_method = delay_method
139
-
140
- # interp method
141
- assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (
142
- f'Un-supported interpolation method {interp_method}. '
143
- f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}'
144
- )
145
- self.interp_method = interp_method
146
-
147
- # delay length and time
148
- with jax.ensure_compile_time_eval():
149
- self.max_time, delay_length = _get_delay(time)
150
- self.max_length = delay_length + 1
151
-
152
- super().__init__()
153
-
154
- # delay data
155
- if init is not None:
156
- if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
157
- raise TypeError(f'init should be Array, Callable, or None. But got {init}')
158
- self._init = init
159
- self._history = None
160
-
161
- # other info
162
- self._registered_entries = dict()
163
-
164
- # other info
165
- if entries is not None:
166
- for entry, delay_time in entries.items():
167
- if isinstance(delay_time, (tuple, list)):
168
- self.register_entry(entry, *delay_time)
169
- else:
170
- self.register_entry(entry, delay_time)
171
-
172
- self.take_aware_unit = take_aware_unit
173
- self._unit = None
174
-
175
- @property
176
- def history(self):
177
- return self._history
178
-
179
- @history.setter
180
- def history(self, value):
181
- self._history = value
182
-
183
- def _f_to_init(self, a, batch_size, length):
184
- shape = list(a.shape)
185
- if batch_size is not None:
186
- shape.insert(0, batch_size)
187
- shape.insert(0, length)
188
- if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
189
- data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
190
- elif callable(self._init):
191
- data = self._init(shape, dtype=a.dtype)
192
- else:
193
- assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
194
- data = jnp.zeros(shape, dtype=a.dtype)
195
- return data
196
-
197
- @call_order(3)
198
- def init_state(self, batch_size: int = None, **kwargs):
199
- fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
200
- self.history = ShortTermState(jax.tree.map(fun, self.target_info))
201
-
202
- def reset_state(self, batch_size: int = None, **kwargs):
203
- fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
204
- self.history.value = jax.tree.map(fun, self.target_info)
205
-
206
- def register_delay(self, *delay_time):
207
- """
208
- Register delay times and update the maximum delay configuration.
209
-
210
- This method processes one or more delay times, validates their format and consistency,
211
- and updates the delay buffer size if necessary. It handles both scalar and vector
212
- delay times, ensuring all vector delays have the same size.
213
-
214
- Args:
215
- *delay_time: Variable number of delay time arguments. The first argument should be
216
- the primary delay time (float, int, or array-like). Additional arguments are
217
- treated as indices or secondary delay parameters. All delay times should be
218
- non-negative numbers or arrays of the same size.
219
-
220
- Returns:
221
- tuple or None: If delay_time[0] is None, returns None. Otherwise, returns a tuple
222
- containing (delay_step, *delay_time[1:]) where delay_step is the computed
223
- delay step in integer time units, and the remaining elements are the
224
- additional delay parameters passed in.
225
-
226
- Raises:
227
- AssertionError: If no delay time is provided (empty delay_time).
228
- ValueError: If delay times have inconsistent sizes when using vector delays,
229
- or if delay times are not scalar or 1D arrays.
230
-
231
- Note:
232
- - The method updates self.max_time and self.max_length if the new delay
233
- requires a larger buffer size.
234
- - Delay steps are computed using the current environment time step (dt).
235
- - All delay indices (delay_time[1:]) must be integers.
236
- - Vector delays must all have the same size as the first delay time.
237
-
238
- Example:
239
- >>> delay_obj.register_delay(5.0) # Register 5ms delay
240
- >>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1) # Vector delay with indices
241
- """
242
- assert len(delay_time) >= 1, 'You should provide at least one delay time.'
243
- for dt in delay_time[1:]:
244
- assert jnp.issubdtype(u.math.get_dtype(dt), jnp.integer), f'The index should be integer. But got {dt}.'
245
- if delay_time[0] is None:
246
- return None
247
- with jax.ensure_compile_time_eval():
248
- time, delay_step = _get_delay(delay_time[0])
249
- max_delay_step = jnp.max(delay_step)
250
- self.max_time = u.math.max(time)
251
-
252
- # delay variable
253
- if self.max_length <= max_delay_step + 1:
254
- self.max_length = max_delay_step + 1
255
- return delay_step, *delay_time[1:]
256
-
257
- def register_entry(self, entry: str, *delay_time) -> 'Delay':
258
- """
259
- Register an entry to access the delay data.
260
-
261
- Args:
262
- entry: str. The entry to access the delay data.
263
- delay_time: The delay time of the entry, the first element is the delay time,
264
- the second and later element is the index.
265
- """
266
- if entry in self._registered_entries:
267
- raise KeyError(
268
- f'Entry {entry} has been registered. '
269
- f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
270
- f'The new delay for the key {entry} is {delay_time}. '
271
- f'You can use another key. '
272
- )
273
- delay_info = self.register_delay(*delay_time)
274
- self._registered_entries[entry] = delay_info
275
- return delay_info
276
-
277
- def access(self, entry: str, *delay_time) -> DelayAccess:
278
- """
279
- Create a DelayAccess object for a specific delay entry and delay time.
280
-
281
- Args:
282
- entry (str): The name of the delay entry to access.
283
- delay_time (Sequence): The delay time or parameters associated with the entry.
284
-
285
- Returns:
286
- DelayAccess: An object that provides access to the delay data for the specified entry and time.
287
- """
288
- return DelayAccess(self, delay_time, entry=entry)
289
-
290
- def at(self, entry: str) -> ArrayLike:
291
- """
292
- Get the data at the given entry.
293
-
294
- Args:
295
- entry: str. The entry to access the data.
296
-
297
- Returns:
298
- The data.
299
- """
300
- assert isinstance(entry, str), (f'entry should be a string for describing the '
301
- f'entry of the delay data. But we got {entry}.')
302
- if entry not in self._registered_entries:
303
- raise KeyError(f'Does not find delay entry "{entry}".')
304
- delay_step = self._registered_entries[entry]
305
- if delay_step is None:
306
- delay_step = (0,)
307
- return self.retrieve_at_step(*delay_step)
308
-
309
- def retrieve_at_step(self, delay_step, *indices) -> PyTree:
310
- """
311
- Retrieve the delay data at the given delay time step (the integer to indicate the time step).
312
-
313
- Parameters
314
- ----------
315
- delay_step: int_like
316
- Retrieve the data at the given time step.
317
- indices: tuple
318
- The indices to slice the data.
319
-
320
- Returns
321
- -------
322
- delay_data: The delay data at the given delay step.
323
-
324
- """
325
- assert self.history is not None, 'The delay history is not initialized.'
326
- assert delay_step is not None, 'The delay step should be given.'
327
-
328
- if environ.get(environ.JIT_ERROR_CHECK, False):
329
- def _check_delay(delay_len):
330
- raise ValueError(
331
- f'The request delay length should be less than the '
332
- f'maximum delay {self.max_length - 1}. But we got {delay_len}'
333
- )
334
-
335
- jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)
336
-
337
- # rotation method
338
- with jax.ensure_compile_time_eval():
339
- if self.delay_method == _DELAY_ROTATE:
340
- i = environ.get(environ.I, desc='The time step index.')
341
- di = i - delay_step
342
- delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
343
- delay_idx = jax.lax.stop_gradient(delay_idx)
344
-
345
- elif self.delay_method == _DELAY_CONCAT:
346
- delay_idx = delay_step
347
-
348
- else:
349
- raise ValueError(f'Unknown delay updating method "{self.delay_method}"')
350
-
351
- # the delay index
352
- if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
353
- raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
354
- indices = (delay_idx,) + indices
355
-
356
- # the delay data
357
- if self._unit is None:
358
- return jax.tree.map(lambda a: a[indices], self.history.value)
359
- else:
360
- return jax.tree.map(
361
- lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
362
- self.history.value,
363
- self._unit
364
- )
365
-
366
- def retrieve_at_time(self, delay_time, *indices) -> PyTree:
367
- """
368
- Retrieve the delay data at the given delay time step (the integer to indicate the time step).
369
-
370
- Parameters
371
- ----------
372
- delay_time: float
373
- Retrieve the data at the given time.
374
- indices: tuple
375
- The indices to slice the data.
376
-
377
- Returns
378
- -------
379
- delay_data: The delay data at the given delay step.
380
-
381
- """
382
- assert self.history is not None, 'The delay history is not initialized.'
383
- assert delay_time is not None, 'The delay time should be given.'
384
-
385
- current_time = environ.get(environ.T, desc='The current time.')
386
- dt = environ.get_dt()
387
-
388
- if environ.get(environ.JIT_ERROR_CHECK, False):
389
- def _check_delay(t_now, t_delay):
390
- raise ValueError(
391
- f'The request delay time should be within '
392
- f'[{t_now - self.max_time - dt}, {t_now}], '
393
- f'but we got {t_delay}'
394
- )
395
-
396
- jit_error_if(
397
- jnp.logical_or(
398
- delay_time > current_time,
399
- delay_time < current_time - self.max_time - dt
400
- ),
401
- _check_delay,
402
- current_time,
403
- delay_time
404
- )
405
-
406
- with jax.ensure_compile_time_eval():
407
- diff = current_time - delay_time
408
- float_time_step = diff / dt
409
-
410
- if self.interp_method == _INTERP_LINEAR: # "linear" interpolation
411
- data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
412
- data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
413
- t_diff = float_time_step - jnp.floor(float_time_step)
414
- return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)
415
-
416
- elif self.interp_method == _INTERP_ROUND: # "round" interpolation
417
- return self.retrieve_at_step(jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32), *indices)
418
-
419
- else: # raise error
420
- raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
421
- f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
422
-
423
- def update(self, current: PyTree) -> None:
424
- """
425
- Update delay variable with the new data.
426
- """
427
-
428
- with jax.ensure_compile_time_eval():
429
- assert self.history is not None, 'The delay history is not initialized.'
430
-
431
- if self.take_aware_unit and self._unit is None:
432
- self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)
433
-
434
- # update the delay data at the rotation index
435
- if self.delay_method == _DELAY_ROTATE:
436
- i = environ.get(environ.I)
437
- idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
438
- idx = jax.lax.stop_gradient(idx)
439
- self.history.value = jax.tree.map(
440
- lambda hist, cur: hist.at[idx].set(cur),
441
- self.history.value,
442
- current
443
- )
444
- # update the delay data at the first position
445
- elif self.delay_method == _DELAY_CONCAT:
446
- current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
447
- if self.max_length > 1:
448
- self.history.value = jax.tree.map(
449
- lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
450
- self.history.value,
451
- current
452
- )
453
- else:
454
- self.history.value = current
455
-
456
- else:
457
- raise ValueError(f'Unknown updating method "{self.delay_method}"')
458
-
459
-
460
- class StateWithDelay(Delay):
461
- """
462
- Delayed history buffer bound to a module state.
463
-
464
- StateWithDelay is a specialized :py:class:`~.Delay` that attaches to a
465
- concrete :py:class:`~brainstate._state.State` living on a target module
466
- (for example a membrane potential ``V`` on a neuron). It automatically
467
- maintains a rolling history of that state and exposes convenient helpers to
468
- retrieve the value at a given delay either by step or by time.
469
-
470
- In normal usage you rarely instantiate this class directly. It is created
471
- implicitly when using the prefetch-delay helpers on a Dynamics module, e.g.:
472
-
473
- - ``module.prefetch('V').delay.at(5.0 * u.ms)``
474
- - ``module.prefetch_delay('V', 5.0 * u.ms)``
475
-
476
- Both will construct a StateWithDelay bound to ``module.V`` under the hood
477
- and register the requested delay, so you can retrieve the delayed value
478
- inside your update rules.
479
-
480
- Parameters
481
- ----------
482
- target : :py:class:`~brainstate.graph.Node`
483
- The module object that owns the state to track.
484
- item : str
485
- The attribute name of the target state on ``target`` (must be a
486
- :py:class:`~brainstate._state.State`).
487
- init : Callable, optional
488
- Optional initializer used to fill the history buffer before ``t0``
489
- when delays request values from the past that hasn't been simulated yet.
490
- The callable receives ``(shape, dtype)`` and must return an array.
491
- If not provided, zeros are used. You may also pass a scalar/array
492
- literal via the underlying Delay API when constructing manually.
493
- delay_method : {"rotation", "concat"}, default "rotation"
494
- Internal buffering strategy (inherits behavior from :py:class:`~.Delay`).
495
- "rotation" keeps a ring buffer; "concat" shifts by concatenation.
496
-
497
- Attributes
498
- ----------
499
- state : :py:class:`~brainstate._state.State`
500
- The concrete state object being tracked.
501
- history : :py:class:`~brainstate._state.ShortTermState`
502
- Rolling time axis buffer with shape ``[length, *state.shape]``.
503
- max_time : float
504
- Maximum time span currently supported by the buffer.
505
- max_length : int
506
- Buffer length in steps (``ceil(max_time/dt)+1``).
507
-
508
- Notes
509
- -----
510
- - This class inherits all retrieval utilities from :py:class:`~.Delay`:
511
- use :py:meth:`retrieve_at_step` when you know the integer delay steps,
512
- or :py:meth:`retrieve_at_time` for continuous-time queries with optional
513
- linear/round interpolation.
514
- - It is registered as an "after-update" hook on the owning Dynamics so the
515
- buffer is updated automatically after each simulation step.
516
-
517
- Examples
518
- --------
519
- Access a neuron's membrane potential 5 ms in the past:
520
-
521
- >>> import brainunit as u
522
- >>> import brainstate as brainstate
523
- >>> lif = brainstate.nn.LIF(100)
524
- >>> # Create a delayed accessor to V(t-5ms)
525
- >>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
526
- >>> # Inside another module's update you can read the delayed value
527
- >>> v_t_minus_5ms = v_delay()
528
-
529
- Register multiple delay taps and index-specific delays:
530
-
531
- >>> # Under the hood, a StateWithDelay is created and you can register
532
- >>> # additional taps (in steps or time) via its Delay interface
533
- >>> _ = lif.prefetch('V').delay.at(2.0 * u.ms) # additional delay
534
- >>> # Direct access to buffer by steps (advanced)
535
- >>> # lif._get_after_update('V-prefetch-delay').retrieve_at_step(3)
536
- """
537
-
538
- __module__ = 'brainstate.nn'
539
-
540
- state: State # state
541
-
542
- def __init__(
543
- self,
544
- target: Node,
545
- item: str,
546
- init: Callable = None,
547
- delay_method: Optional[str] = _DELAY_ROTATE,
548
- ):
549
- super().__init__(None, init=init, delay_method=delay_method)
550
-
551
- self._target = target
552
- self._target_term = item
553
-
554
- @property
555
- def state(self) -> State:
556
- r = getattr(self._target, self._target_term)
557
- if not isinstance(r, State):
558
- raise TypeError(f'The term "{self._target_term}" in the module "{self._target}" is not a State.')
559
- return r
560
-
561
- @call_order(3)
562
- def init_state(self, *args, **kwargs):
563
- """
564
- State initialization function.
565
- """
566
- state = self.state
567
- self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), state.value)
568
- super().init_state(*args, **kwargs)
569
-
570
- def update(self, *args) -> None:
571
- """
572
- Update the delay variable with the new data.
573
- """
574
- value = self.state.value
575
- return super().update(value)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import numbers
17
+ from functools import partial
18
+ from typing import Optional, Dict, Callable, Union, Sequence
19
+
20
+ import brainunit as u
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+
25
+ from brainstate import environ
26
+ from brainstate._state import ShortTermState, State
27
+ from brainstate.graph import Node
28
+ from brainstate.transform import jit_error_if
29
+ from brainstate.typing import ArrayLike, PyTree
30
+ from ._collective_ops import call_order
31
+ from ._module import Module
32
+
33
+ __all__ = [
34
+ 'Delay', 'DelayAccess', 'StateWithDelay',
35
+ ]
36
+
37
+ _DELAY_ROTATE = 'rotation'
38
+ _DELAY_CONCAT = 'concat'
39
+ _INTERP_LINEAR = 'linear_interp'
40
+ _INTERP_ROUND = 'round'
41
+
42
+
43
+ def _get_delay(delay_time):
44
+ if delay_time is None:
45
+ return 0. * environ.get_dt(), 0
46
+ delay_step = delay_time / environ.get_dt()
47
+ assert u.get_dim(delay_step) == u.DIMENSIONLESS
48
+ delay_step = jnp.ceil(delay_step).astype(environ.ditype())
49
+ return delay_time, delay_step
50
+
51
+
52
+ class DelayAccess(Node):
53
+ """
54
+ Accessor node for a registered entry in a Delay instance.
55
+
56
+ This node holds a reference to a Delay and a named entry that was
57
+ registered on that Delay. It is used by graphs to query delayed
58
+ values by delegating to the underlying Delay instance.
59
+
60
+ Args:
61
+ delay: The delay instance.
62
+ *time: The delay time.
63
+ entry: The delay entry.
64
+ """
65
+
66
+ __module__ = 'brainstate.nn'
67
+
68
+ def __init__(
69
+ self,
70
+ delay: 'Delay',
71
+ *time,
72
+ entry: str,
73
+ ):
74
+ super().__init__()
75
+ self.delay = delay
76
+ assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
77
+ self._delay_entry = entry
78
+ self.delay_info = delay.register_entry(self._delay_entry, *time)
79
+
80
+ def update(self):
81
+ return self.delay.at(self._delay_entry)
82
+
83
+
84
+ class Delay(Module):
85
+ """
86
+ Delay variable for storing short-term history data.
87
+
88
+ The data in this delay variable is arranged as::
89
+
90
+ delay = 0 [ data
91
+ delay = 1 data
92
+ delay = 2 data
93
+ ... ....
94
+ ... ....
95
+ delay = length-1 data
96
+ delay = length data ]
97
+
98
+ Args:
99
+ time: int, float. The delay time.
100
+ init: Any. The delay data. It can be a Python number, like float, int, boolean values.
101
+ It can also be arrays. Or a callable function or instance of ``Connector``.
102
+ Note that ``initial_delay_data`` should be arranged as the following way::
103
+
104
+ delay = 1 [ data
105
+ delay = 2 data
106
+ ... ....
107
+ ... ....
108
+ delay = length-1 data
109
+ delay = length data ]
110
+ entries: optional, dict. The delay access entries.
111
+ delay_method: str. The method used for updating delay. Default None.
112
+ """
113
+
114
+ __module__ = 'brainstate.nn'
115
+
116
+ max_time: float #
117
+ max_length: int
118
+ history: Optional[ShortTermState]
119
+
120
def __init__(
    self,
    target_info: PyTree,
    time: Optional[Union[int, float, u.Quantity]] = None,  # delay time
    init: Optional[Union[ArrayLike, Callable]] = None,  # delay data before t0
    entries: Optional[Dict] = None,  # delay access entry
    delay_method: Optional[str] = _DELAY_ROTATE,  # delay method
    interp_method: str = _INTERP_LINEAR,  # interpolation method
    take_aware_unit: bool = False
):
    """
    Configure the delay; the history buffer itself is allocated by ``init_state``.

    Args:
        target_info: PyTree whose leaves expose ``shape`` and ``dtype``; only
            these two properties are kept.
        time: Initial maximum delay time, converted to steps by ``_get_delay``.
        init: Data used to fill the buffer before t0 (number, array, or callable).
        entries: Mapping from entry name to a delay time, or to a tuple/list
            ``(delay_time, *indices)``; each item goes through ``register_entry``.
        delay_method: ``_DELAY_ROTATE`` or ``_DELAY_CONCAT``.
        interp_method: ``_INTERP_LINEAR`` or ``_INTERP_ROUND``.
        take_aware_unit: When True, ``update`` records the units of the first
            value it receives.

    Raises:
        AssertionError: If ``delay_method`` or ``interp_method`` is unsupported.
        TypeError: If ``init`` is not a number, array, callable, or None.
    """
    # Only abstract shape/dtype descriptors of the target are stored.
    self.target_info = jax.tree.map(
        lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype),
        target_info,
    )

    # Buffer-updating strategy.
    delay_choices = [_DELAY_ROTATE, _DELAY_CONCAT]
    assert delay_method in delay_choices, (
        f'Un-supported delay method {delay_method}. '
        f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}'
    )
    self.delay_method = delay_method

    # Interpolation strategy for ``retrieve_at_time``.
    interp_choices = [_INTERP_LINEAR, _INTERP_ROUND]
    assert interp_method in interp_choices, (
        f'Un-supported interpolation method {interp_method}. '
        f'we only support: {interp_choices}'
    )
    self.interp_method = interp_method

    # One buffer slot per delay step, plus one for the zero-delay value.
    with jax.ensure_compile_time_eval():
        self.max_time, n_steps = _get_delay(time)
        self.max_length = n_steps + 1

    super().__init__()

    # Pre-t0 data is validated here and materialized later in ``init_state``.
    accepted_init = (numbers.Number, jax.Array, np.ndarray, Callable)
    if init is not None and not isinstance(init, accepted_init):
        raise TypeError(f'init should be Array, Callable, or None. But got {init}')
    self._init = init
    self._history = None

    # Named access entries: entry name -> value returned by ``register_delay``.
    self._registered_entries = {}
    if entries is not None:
        for name, spec in entries.items():
            spec_args = tuple(spec) if isinstance(spec, (tuple, list)) else (spec,)
            self.register_entry(name, *spec_args)

    # Unit bookkeeping used by ``update`` / ``retrieve_at_step``.
    self.take_aware_unit = take_aware_unit
    self._unit = None
174
+
175
@property
def history(self):
    """The ``ShortTermState`` holding the delay buffer (``None`` before ``init_state``)."""
    buffer_state = self._history
    return buffer_state

@history.setter
def history(self, value):
    # Stored as-is; no validation is performed on assignment.
    self._history = value
182
+
183
def _f_to_init(self, a, batch_size, length):
    """
    Build the initial history array for one leaf of ``target_info``.

    The shape is ``[length, *a.shape]``, or ``[length, batch_size, *a.shape]``
    when ``batch_size`` is given, and the dtype is ``a.dtype``. The array is
    filled from ``self._init``:

    - number or array: broadcast to that shape;
    - callable: called as ``self._init(shape, dtype=...)``;
    - ``None``: filled with zeros.
    """
    leading_axes = [length] if batch_size is None else [length, batch_size]
    shape = leading_axes + list(a.shape)

    init = self._init
    if isinstance(init, (jax.Array, np.ndarray, numbers.Number)):
        return jnp.broadcast_to(jnp.asarray(init, a.dtype), shape)
    if callable(init):
        return init(shape, dtype=a.dtype)
    assert init is None, f'init should be Array, Callable, or None. but got {init}'
    return jnp.zeros(shape, dtype=a.dtype)
196
+
197
@call_order(3)
def init_state(self, batch_size: int = None, **kwargs):
    """Allocate the history buffer as a new ``ShortTermState`` (one array per target leaf)."""
    buffers = jax.tree.map(
        lambda info: self._f_to_init(info, batch_size, self.max_length),
        self.target_info,
    )
    self.history = ShortTermState(buffers)
201
+
202
def reset_state(self, batch_size: int = None, **kwargs):
    """Refill the existing history buffer with freshly initialized data."""
    fresh = jax.tree.map(
        lambda info: self._f_to_init(info, batch_size, self.max_length),
        self.target_info,
    )
    self.history.value = fresh
205
+
206
def register_delay(self, *delay_time):
    """
    Register a delay and grow the history buffer if it is too short.

    Args:
        *delay_time: The first element is the delay time (scalar or array-like,
            in the form accepted by ``_get_delay``). Any further elements are
            indices used to slice the delayed data on retrieval; each must have
            an integer dtype.

    Returns:
        tuple or None: ``None`` if ``delay_time[0]`` is ``None``. Otherwise
        ``(delay_step, *delay_time[1:])``, where ``delay_step`` is the delay
        converted to integer time steps.

    Raises:
        AssertionError: If no delay time is given, or if an index in
            ``delay_time[1:]`` is not of integer dtype.

    Note:
        ``max_length`` and ``max_time`` are updated only when the new delay
        needs at least as many buffer slots as are currently allocated. A
        shorter registration therefore never shrinks ``max_time`` below the
        span of the existing buffer.

    Example:
        >>> delay_obj.register_delay(5.0 * u.ms)
        >>> delay_obj.register_delay(jnp.array([2.0, 3.0]) * u.ms, 0, 1)  # with indices
    """
    assert len(delay_time) >= 1, 'You should provide at least one delay time.'
    for index in delay_time[1:]:
        assert jnp.issubdtype(u.math.get_dtype(index), jnp.integer), f'The index should be integer. But got {index}.'
    if delay_time[0] is None:
        return None
    with jax.ensure_compile_time_eval():
        time, delay_step = _get_delay(delay_time[0])
        max_delay_step = jnp.max(delay_step)

        # Grow the buffer (and the matching maximum time) only when this delay
        # needs at least the currently allocated number of slots. Setting
        # ``max_time`` unconditionally would let a shorter, later registration
        # shrink it, and ``retrieve_at_time`` would then reject requests for the
        # longer delays the buffer still holds.
        if self.max_length <= max_delay_step + 1:
            self.max_length = max_delay_step + 1
            self.max_time = u.math.max(time)
    return delay_step, *delay_time[1:]
256
+
257
def register_entry(self, entry: str, *delay_time) -> Optional[tuple]:
    """
    Register a named entry for accessing the delay data.

    Args:
        entry: str. Name under which the delay is stored; read it later with
            ``at(entry)``.
        *delay_time: The delay time of the entry, optionally followed by
            integer indices used to slice the delayed data.

    Returns:
        The value produced by ``register_delay``: ``None`` when the delay
        time is ``None``, otherwise ``(delay_step, *indices)``. The same value
        is stored for ``entry``.

    Raises:
        KeyError: If ``entry`` has already been registered.
    """
    if entry in self._registered_entries:
        raise KeyError(
            f'Entry {entry} has been registered. '
            f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
            f'The new delay for the key {entry} is {delay_time}. '
            f'You can use another key. '
        )
    delay_info = self.register_delay(*delay_time)
    self._registered_entries[entry] = delay_info
    return delay_info
276
+
277
def access(self, entry: str, *delay_time) -> DelayAccess:
    """
    Build a ``DelayAccess`` handle bound to this delay.

    Args:
        entry (str): Name of the delay entry.
        delay_time (Sequence): Delay time followed by optional indices; passed
            unchanged to ``DelayAccess``.

    Returns:
        DelayAccess: The handle ``DelayAccess(self, delay_time, entry=entry)``.
    """
    handle = DelayAccess(self, delay_time, entry=entry)
    return handle
289
+
290
def at(self, entry: str) -> ArrayLike:
    """
    Read the delayed data registered under ``entry``.

    Args:
        entry: str. Name previously passed to ``register_entry``.

    Returns:
        The result of ``retrieve_at_step`` for the stored ``(delay_step, *indices)``.
        An entry registered with a ``None`` delay reads step 0.

    Raises:
        AssertionError: If ``entry`` is not a string.
        KeyError: If ``entry`` has not been registered.
    """
    assert isinstance(entry, str), (
        f'entry should be a string for describing the '
        f'entry of the delay data. But we got {entry}.'
    )
    registry = self._registered_entries
    if entry not in registry:
        raise KeyError(f'Does not find delay entry "{entry}".')
    step_info = registry[entry]
    return self.retrieve_at_step(*((0,) if step_info is None else step_info))
308
+
309
def retrieve_at_step(self, delay_step, *indices) -> PyTree:
    """
    Retrieve the delay data ``delay_step`` integer time steps in the past.

    Parameters
    ----------
    delay_step: int_like
        Number of time steps to look back. Must lie in ``[0, max_length - 1]``.
        This range is checked at run time only when ``environ.JIT_ERROR_CHECK``
        is enabled.
    indices: tuple
        Extra indices used to slice each delayed leaf after the time axis.

    Returns
    -------
    delay_data: The delay data at the given delay step. If units were recorded
        by ``update``, each leaf is multiplied back by its unit.

    Raises
    ------
    ValueError
        If ``delay_method`` is unknown, if the resolved index is not an
        integer, or (with JIT error checking) if ``delay_step`` is out of range.
    """
    assert self.history is not None, 'The delay history is not initialized.'
    assert delay_step is not None, 'The delay step should be given.'

    if environ.get(environ.JIT_ERROR_CHECK, False):
        def _check_delay(delay_len):
            raise ValueError(
                f'The request delay length should be non-negative and no greater than the '
                f'maximum delay {self.max_length - 1}. But we got {delay_len}'
            )

        # A negative step would silently index a wrapped-around (stale) slot
        # instead of failing, so reject it together with too-long delays.
        jit_error_if(
            jnp.logical_or(delay_step < 0, delay_step >= self.max_length),
            _check_delay,
            delay_step
        )

    with jax.ensure_compile_time_eval():
        if self.delay_method == _DELAY_ROTATE:
            # Ring buffer: ``update`` writes step ``i`` into slot ``i % max_length``.
            i = environ.get(environ.I, desc='The time step index.')
            delay_idx = jnp.asarray((i - delay_step) % self.max_length, dtype=jnp.int32)
            delay_idx = jax.lax.stop_gradient(delay_idx)

        elif self.delay_method == _DELAY_CONCAT:
            # Newest data is prepended at position 0, so the step is the index.
            delay_idx = delay_step

        else:
            raise ValueError(f'Unknown delay updating method "{self.delay_method}"')

        # the delay index
        if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
            raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
        indices = (delay_idx,) + indices

        # the delay data
        if self._unit is None:
            return jax.tree.map(lambda a: a[indices], self.history.value)
        return jax.tree.map(
            lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
            self.history.value,
            self._unit
        )
365
+
366
def retrieve_at_time(self, delay_time, *indices) -> PyTree:
    """
    Retrieve the delay data at the time point ``delay_time``.

    ``delay_time`` is a time on the same axis as ``environ.T``, not a delay
    length. It is converted to a fractional number of steps back from the
    current time, ``(t - delay_time) / dt``, and resolved according to
    ``interp_method``:

    - linear: interpolate between the neighbouring floor and ceil steps;
    - round: read the nearest step.

    Parameters
    ----------
    delay_time: float
        The time point at which to read the data.
    indices: tuple
        The indices to slice the data.

    Returns
    -------
    delay_data: The (possibly interpolated) delay data at ``delay_time``.

    Raises
    ------
    ValueError
        If ``interp_method`` is unsupported. Also raised when JIT error checking
        is enabled and ``delay_time`` lies outside
        ``[t - max_time - dt, t]``.
    """
    assert self.history is not None, 'The delay history is not initialized.'
    assert delay_time is not None, 'The delay time should be given.'

    current_time = environ.get(environ.T, desc='The current time.')
    dt = environ.get_dt()

    if environ.get(environ.JIT_ERROR_CHECK, False):
        def _check_delay(t_now, t_delay):
            raise ValueError(
                f'The request delay time should be within '
                f'[{t_now - self.max_time - dt}, {t_now}], '
                f'but we got {t_delay}'
            )

        jit_error_if(
            jnp.logical_or(
                delay_time > current_time,
                delay_time < current_time - self.max_time - dt
            ),
            _check_delay,
            current_time,
            delay_time
        )

    with jax.ensure_compile_time_eval():
        # Fractional number of steps between "now" and the requested time point.
        float_time_step = (current_time - delay_time) / dt

        if self.interp_method == _INTERP_LINEAR:  # "linear" interpolation
            lower_step = jnp.floor(float_time_step)
            data_at_t0 = self.retrieve_at_step(jnp.asarray(lower_step, dtype=jnp.int32), *indices)
            data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
            t_diff = float_time_step - lower_step
            return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)

        elif self.interp_method == _INTERP_ROUND:  # "round" interpolation
            return self.retrieve_at_step(jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32), *indices)

        else:  # raise error
            raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
                             f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
422
+
423
def update(self, current: PyTree) -> None:
    """
    Write ``current`` into the history buffer.

    With the rotation method the value overwrites slot ``i % max_length``,
    where ``i`` is ``environ.I``. With the concat method it is prepended at
    position 0 and the last entry is dropped. If ``max_length`` is 1, the
    buffer is replaced entirely.

    Raises:
        AssertionError: If ``init_state`` has not created the history yet.
        ValueError: If ``delay_method`` is not a supported method.
    """
    with jax.ensure_compile_time_eval():
        assert self.history is not None, 'The delay history is not initialized.'

        # Record units from the first value seen, only when unit tracking is on.
        if self.take_aware_unit and self._unit is None:
            self._unit = jax.tree.map(u.get_unit, current, is_leaf=u.math.is_quantity)

        method = self.delay_method
        if method == _DELAY_ROTATE:
            slot = jnp.asarray(environ.get(environ.I) % self.max_length, dtype=environ.dutype())
            slot = jax.lax.stop_gradient(slot)
            new_history = jax.tree.map(
                lambda hist, cur: hist.at[slot].set(cur),
                self.history.value,
                current
            )
        elif method == _DELAY_CONCAT:
            stacked = jax.tree.map(lambda leaf: jnp.expand_dims(leaf, 0), current)
            if self.max_length > 1:
                new_history = jax.tree.map(
                    lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
                    self.history.value,
                    stacked
                )
            else:
                new_history = stacked
        else:
            raise ValueError(f'Unknown updating method "{self.delay_method}"')

        self.history.value = new_history
458
+
459
+
460
class StateWithDelay(Delay):
    """
    Delayed history buffer bound to a module state.

    StateWithDelay is a specialized :py:class:`~.Delay` that attaches to a
    concrete :py:class:`~brainstate._state.State` living on a target module
    (for example a membrane potential ``V`` on a neuron). It maintains a
    rolling history of that state and exposes helpers to retrieve its value
    at a given delay, either by step or by time.

    In normal usage you rarely instantiate this class directly. It is created
    implicitly when using the prefetch-delay helpers on a Dynamics module, e.g.:

    - ``module.prefetch('V').delay.at(5.0 * u.ms)``
    - ``module.prefetch_delay('V', 5.0 * u.ms)``

    Both construct a StateWithDelay bound to ``module.V`` under the hood and
    register the requested delay, so you can retrieve the delayed value
    inside your update rules.

    Parameters
    ----------
    target : :py:class:`~brainstate.graph.Node`
        The module object that owns the state to track.
    item : str
        The attribute name of the target state on ``target``. It must be a
        :py:class:`~brainstate._state.State`.
    init : Callable, optional
        Initializer used to fill the history buffer before ``t0``, for delays
        that request values from before the simulation started. It is called
        as ``init(shape, dtype=dtype)`` and must return an array. If not
        provided, zeros are used. A scalar or array literal is also accepted,
        as with the underlying :py:class:`~.Delay`.
    delay_method : {"rotation", "concat"}, default "rotation"
        Internal buffering strategy, inherited from :py:class:`~.Delay`.
        "rotation" keeps a ring buffer; "concat" shifts by concatenation.
    interp_method : str, optional
        Keyword-only. Interpolation used by :py:meth:`retrieve_at_time`
        ("linear" or "round"). It is forwarded to :py:class:`~.Delay`;
        the default is linear.

    Attributes
    ----------
    state : :py:class:`~brainstate._state.State`
        The concrete state object being tracked.
    history : :py:class:`~brainstate._state.ShortTermState`
        Rolling time-axis buffer with shape ``[length, *state.shape]``.
    max_time : float
        Maximum time span currently supported by the buffer.
    max_length : int
        Buffer length in steps (``ceil(max_time/dt)+1``).

    Notes
    -----
    - This class inherits all retrieval utilities from :py:class:`~.Delay`.
      Use :py:meth:`retrieve_at_step` when you know the integer delay steps,
      or :py:meth:`retrieve_at_time` for continuous-time queries with
      linear/round interpolation.
    - It is registered as an "after-update" hook on the owning Dynamics, so
      the buffer is updated automatically after each simulation step.

    Examples
    --------
    Access a neuron's membrane potential 5 ms in the past:

    >>> import brainunit as u
    >>> import brainstate
    >>> lif = brainstate.nn.LIF(100)
    >>> # Create a delayed accessor to V(t-5ms)
    >>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
    >>> # Inside another module's update you can read the delayed value
    >>> v_t_minus_5ms = v_delay()

    Register multiple delay taps and index-specific delays:

    >>> # Under the hood, a StateWithDelay is created and you can register
    >>> # additional taps (in steps or time) via its Delay interface
    >>> _ = lif.prefetch('V').delay.at(2.0 * u.ms)  # additional delay
    >>> # Direct access to buffer by steps (advanced)
    >>> # lif._get_after_update('V-prefetch-delay').retrieve_at_step(3)
    """

    __module__ = 'brainstate.nn'

    state: State  # state

    def __init__(
        self,
        target: Node,
        item: str,
        init: Optional[Callable] = None,
        delay_method: Optional[str] = _DELAY_ROTATE,
        *,
        interp_method: str = _INTERP_LINEAR,
    ):
        # The tracked state's shape/dtype is not known yet, so ``target_info``
        # starts as None and is filled in by ``init_state``.
        super().__init__(None, init=init, delay_method=delay_method, interp_method=interp_method)

        self._target = target
        self._target_term = item

    @property
    def state(self) -> State:
        """The tracked state, looked up on the target at every access."""
        r = getattr(self._target, self._target_term)
        if not isinstance(r, State):
            raise TypeError(f'The term "{self._target_term}" in the module "{self._target}" is not a State.')
        return r

    @call_order(3)
    def init_state(self, *args, **kwargs):
        """
        Derive ``target_info`` from the tracked state, then allocate the buffer.
        """
        state = self.state
        self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), state.value)
        super().init_state(*args, **kwargs)

    def update(self, *args) -> None:
        """
        Push the current value of the tracked state into the delay buffer.

        Any positional arguments are ignored; the value always comes from
        ``self.state``.
        """
        value = self.state.value
        return super().update(value)