brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_delay.py
CHANGED
@@ -1,575 +1,575 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import numbers
|
17
|
-
from functools import partial
|
18
|
-
from typing import Optional, Dict, Callable, Union, Sequence
|
19
|
-
|
20
|
-
import brainunit as u
|
21
|
-
import jax
|
22
|
-
import jax.numpy as jnp
|
23
|
-
import numpy as np
|
24
|
-
|
25
|
-
from brainstate import environ
|
26
|
-
from brainstate._state import ShortTermState, State
|
27
|
-
from brainstate.graph import Node
|
28
|
-
from brainstate.transform import jit_error_if
|
29
|
-
from brainstate.typing import ArrayLike, PyTree
|
30
|
-
from ._collective_ops import call_order
|
31
|
-
from ._module import Module
|
32
|
-
|
33
|
-
__all__ = ['Delay', 'DelayAccess', 'StateWithDelay']

# Buffer-update strategies accepted by ``Delay(delay_method=...)``.
_DELAY_ROTATE = 'rotation'  # ring buffer written at ``i % max_length``
_DELAY_CONCAT = 'concat'  # newest data concatenated at row 0, oldest row dropped

# Interpolation strategies accepted by ``Delay(interp_method=...)``.
_INTERP_LINEAR = 'linear_interp'
_INTERP_ROUND = 'round'
|
41
|
-
|
42
|
-
|
43
|
-
def _get_delay(delay_time):
    """
    Convert a delay time into an integer number of simulation steps.

    Args:
        delay_time: The delay time (scalar or array, optionally a ``brainunit``
            quantity), or ``None`` for "no delay".

    Returns:
        A tuple ``(delay_time, delay_step)``. For ``None`` this is
        ``(0. * dt, 0)``; otherwise ``delay_step`` is ``ceil(delay_time / dt)``
        (with floating-point noise removed, see below) cast to
        ``environ.ditype()``.

    Raises:
        AssertionError: If ``delay_time / dt`` is not dimensionless.
    """
    if delay_time is None:
        return 0. * environ.get_dt(), 0
    delay_step = delay_time / environ.get_dt()
    assert u.get_dim(delay_step) == u.DIMENSIONLESS, (
        f'The delay time {delay_time} must be dimensionally compatible with dt.'
    )
    delay_step = jnp.asarray(u.maybe_decimal(delay_step))
    # ``delay_time / dt`` carries floating-point error: e.g. 1.1 / 0.1 evaluates
    # to 11.000000000000002, which a bare ``ceil`` turns into 12 steps. Ratios
    # that are numerically integral are snapped to that integer before rounding
    # up, so exact multiples of ``dt`` map to the expected number of steps.
    nearest = jnp.round(delay_step)
    is_integral = jnp.isclose(delay_step, nearest, rtol=1e-6, atol=1e-6)
    delay_step = jnp.where(is_integral, nearest, jnp.ceil(delay_step))
    delay_step = delay_step.astype(environ.ditype())
    return delay_time, delay_step
|
50
|
-
|
51
|
-
|
52
|
-
class DelayAccess(Node):
    """
    Node that reads one named entry of a :class:`Delay`.

    Constructing the node registers ``entry`` on the given delay through
    ``Delay.register_entry``; :meth:`update` then returns the delayed data for
    that entry through ``Delay.at``.

    Args:
        delay: The :class:`Delay` instance to read from.
        *time: The delay time of the entry, optionally followed by integer
            indices; forwarded unchanged to ``Delay.register_entry``.
        entry: The name under which the entry is registered.
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        delay: 'Delay',
        *time,
        entry: str,
    ):
        super().__init__()
        self.delay = delay
        assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
        self._delay_entry = entry
        # The returned info is ``(delay_step, *indices)`` or ``None``.
        registered = delay.register_entry(entry, *time)
        self.delay_info = registered

    def update(self):
        """Return the delayed data stored for this accessor's entry."""
        source = self.delay
        return source.at(self._delay_entry)
|
82
|
-
|
83
|
-
|
84
|
-
class Delay(Module):
    """
    Delay variable for storing short-term history data.

    The history buffer has ``max_length`` rows along its leading axis. With the
    ``'concat'`` method, row ``k`` holds the data from ``k`` steps ago::

        delay = 0               [ data
        delay = 1                 data
        delay = 2                 data
        ...                       ....
        delay = max_length-1      data ]

    With the ``'rotation'`` method the same data lives in a ring buffer that is
    written at row ``i % max_length``, where ``i`` is the current step index.

    Args:
        target_info: PyTree of arrays (or objects exposing ``shape``/``dtype``)
            describing the data to delay. Only shapes and dtypes are kept.
        time: int, float or Quantity. The maximum delay time. ``None`` means a
            buffer that only holds the current step.
        init: The data used to fill the history before ``t0``. It can be a
            Python number or an array (broadcast to the buffer shape), or a
            callable invoked as ``init(shape, dtype=...)``. ``None`` fills the
            buffer with zeros.
        entries: optional, dict. Mapping ``entry name -> delay time`` (or a
            tuple/list ``(delay time, *indices)``) registered at construction.
        delay_method: str. The buffer update method, ``'rotation'`` (default)
            or ``'concat'``.
        interp_method: str. Interpolation used by :meth:`retrieve_at_time`,
            ``'linear_interp'`` (default) or ``'round'``.
        take_aware_unit: bool. If True, the units of the first data passed to
            :meth:`update` are recorded and re-applied on retrieval.
    """

    __module__ = 'brainstate.nn'

    max_time: float  # largest delay time registered so far
    max_length: int  # number of rows in the history buffer
    history: Optional[ShortTermState]

    def __init__(
        self,
        target_info: PyTree,
        time: Optional[Union[int, float, u.Quantity]] = None,  # delay time
        init: Optional[Union[ArrayLike, Callable]] = None,  # delay data before t0
        entries: Optional[Dict] = None,  # delay access entry
        delay_method: Optional[str] = _DELAY_ROTATE,  # delay method
        interp_method: str = _INTERP_LINEAR,  # interpolation method
        take_aware_unit: bool = False
    ):
        # target information (only shapes and dtypes are retained)
        self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)

        # delay method
        assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (
            f'Un-supported delay method {delay_method}. '
            f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}'
        )
        self.delay_method = delay_method

        # interp method
        assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (
            f'Un-supported interpolation method {interp_method}. '
            f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}'
        )
        self.interp_method = interp_method

        # delay length and time
        with jax.ensure_compile_time_eval():
            self.max_time, delay_length = _get_delay(time)
            self.max_length = delay_length + 1

        super().__init__()

        # delay data
        if init is not None:
            if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
                raise TypeError(f'init should be Array, Callable, or None. But got {init}')
        self._init = init
        self._history = None

        # registered entries: name -> ``(delay_step, *indices)`` or ``None``
        self._registered_entries = dict()

        # register the entries given at construction
        if entries is not None:
            for entry, delay_time in entries.items():
                if isinstance(delay_time, (tuple, list)):
                    self.register_entry(entry, *delay_time)
                else:
                    self.register_entry(entry, delay_time)

        self.take_aware_unit = take_aware_unit
        self._unit = None

    @property
    def history(self):
        """The ``ShortTermState`` holding the delay buffer (``None`` before ``init_state``)."""
        return self._history

    @history.setter
    def history(self, value):
        self._history = value

    def _f_to_init(self, a, batch_size, length):
        """Build the initial buffer for one leaf: shape ``[length, (batch_size,) *a.shape]``."""
        shape = list(a.shape)
        if batch_size is not None:
            shape.insert(0, batch_size)
        shape.insert(0, length)
        if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
            data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
        elif callable(self._init):
            data = self._init(shape, dtype=a.dtype)
        else:
            assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
            data = jnp.zeros(shape, dtype=a.dtype)
        return data

    @call_order(3)
    def init_state(self, batch_size: int = None, **kwargs):
        """Create the history buffer with ``max_length`` rows."""
        fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
        self.history = ShortTermState(jax.tree.map(fun, self.target_info))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Refill the existing history buffer with its initial data."""
        fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
        self.history.value = jax.tree.map(fun, self.target_info)

    def register_delay(self, *delay_time):
        """
        Register a delay time and grow the buffer if necessary.

        Args:
            *delay_time: The first element is the delay time (scalar or array,
                optionally a Quantity, or ``None``). The remaining elements are
                integer indices used to slice the retrieved data.

        Returns:
            ``None`` if ``delay_time[0]`` is ``None``. Otherwise the tuple
            ``(delay_step, *delay_time[1:])``, where ``delay_step`` is the delay
            converted to integer time steps.

        Raises:
            AssertionError: If no delay time is provided, or if an index is not
                an integer.

        Note:
            - ``self.max_time`` keeps the largest delay time registered so far.
            - ``self.max_length`` grows to ``max(delay_step) + 1`` when the new
              delay needs a larger buffer; it never shrinks.
            - Delay steps are computed with the current environment ``dt``.

        Example:
            >>> delay_obj.register_delay(5.0)  # delay of 5.0 (same units as dt)
            >>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1)  # vector delay with indices
        """
        assert len(delay_time) >= 1, 'You should provide at least one delay time.'
        for dt in delay_time[1:]:
            assert jnp.issubdtype(u.math.get_dtype(dt), jnp.integer), f'The index should be integer. But got {dt}.'
        if delay_time[0] is None:
            return None
        with jax.ensure_compile_time_eval():
            time, delay_step = _get_delay(delay_time[0])
            max_delay_step = jnp.max(delay_step)
            # Keep the maximum: registering a shorter delay after a longer one
            # must not shrink ``max_time``, which bounds the valid query window
            # checked in ``retrieve_at_time``.
            self.max_time = u.math.maximum(self.max_time, u.math.max(time))

            # delay variable
            if self.max_length <= max_delay_step + 1:
                self.max_length = max_delay_step + 1
        return delay_step, *delay_time[1:]

    def register_entry(self, entry: str, *delay_time) -> Optional[tuple]:
        """
        Register a named entry to access the delay data.

        Args:
            entry: str. The name of the entry.
            delay_time: The delay time of the entry; the first element is the
                delay time, the second and later elements are indices.

        Returns:
            The delay info returned by :meth:`register_delay`:
            ``(delay_step, *indices)`` or ``None``.

        Raises:
            KeyError: If ``entry`` has already been registered.
        """
        if entry in self._registered_entries:
            raise KeyError(
                f'Entry {entry} has been registered. '
                f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
                f'The new delay for the key {entry} is {delay_time}. '
                f'You can use another key. '
            )
        delay_info = self.register_delay(*delay_time)
        self._registered_entries[entry] = delay_info
        return delay_info

    def access(self, entry: str, *delay_time) -> DelayAccess:
        """
        Create a DelayAccess object for a specific delay entry and delay time.

        Args:
            entry (str): The name of the delay entry to access.
            *delay_time: The delay time, optionally followed by integer indices.

        Returns:
            DelayAccess: An object that provides access to the delay data for the specified entry and time.
        """
        # Unpack so that DelayAccess (and register_delay) receive the delay time
        # and indices as separate positional arguments, not one tuple.
        return DelayAccess(self, *delay_time, entry=entry)

    def at(self, entry: str) -> ArrayLike:
        """
        Get the data at the given entry.

        Args:
            entry: str. The name of a registered entry.

        Returns:
            The delayed data. An entry registered with no delay time reads
            step ``0``.

        Raises:
            KeyError: If the entry has not been registered.
        """
        assert isinstance(entry, str), (f'entry should be a string for describing the '
                                        f'entry of the delay data. But we got {entry}.')
        if entry not in self._registered_entries:
            raise KeyError(f'Does not find delay entry "{entry}".')
        delay_step = self._registered_entries[entry]
        if delay_step is None:
            delay_step = (0,)
        return self.retrieve_at_step(*delay_step)

    def retrieve_at_step(self, delay_step, *indices) -> PyTree:
        """
        Retrieve the delay data at the given delay time step (the integer to indicate the time step).

        Parameters
        ----------
        delay_step: int_like
            Number of steps into the past to read.
        indices: tuple
            The indices to slice the data.

        Returns
        -------
        delay_data: The delay data at the given delay step.
        """
        assert self.history is not None, 'The delay history is not initialized.'
        assert delay_step is not None, 'The delay step should be given.'

        if environ.get(environ.JIT_ERROR_CHECK, False):
            def _check_delay(delay_len):
                raise ValueError(
                    f'The request delay length should be less than the '
                    f'maximum delay {self.max_length - 1}. But we got {delay_len}'
                )

            jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)

        with jax.ensure_compile_time_eval():
            # rotation method: the row written ``delay_step`` steps ago
            if self.delay_method == _DELAY_ROTATE:
                i = environ.get(environ.I, desc='The time step index.')
                di = i - delay_step
                delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
                delay_idx = jax.lax.stop_gradient(delay_idx)

            # concat method: row ``k`` already holds the data from ``k`` steps ago
            elif self.delay_method == _DELAY_CONCAT:
                delay_idx = delay_step

            else:
                raise ValueError(f'Unknown delay updating method "{self.delay_method}"')

            # the delay index
            if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
                raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
            indices = (delay_idx,) + indices

            # the delay data
            if self._unit is None:
                return jax.tree.map(lambda a: a[indices], self.history.value)
            else:
                return jax.tree.map(
                    lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
                    self.history.value,
                    self._unit
                )

    def retrieve_at_time(self, delay_time, *indices) -> PyTree:
        """
        Retrieve the delay data at the given (absolute) time point.

        Parameters
        ----------
        delay_time: float
            The time point to read. It is compared with the current time
            ``environ.T``; the step offset is ``(current_time - delay_time) / dt``.
        indices: tuple
            The indices to slice the data.

        Returns
        -------
        delay_data: The delay data at the given time. It is linearly
            interpolated between the two neighbouring steps
            (``'linear_interp'``) or taken at the nearest step (``'round'``).
        """
        assert self.history is not None, 'The delay history is not initialized.'
        assert delay_time is not None, 'The delay time should be given.'

        current_time = environ.get(environ.T, desc='The current time.')
        dt = environ.get_dt()

        if environ.get(environ.JIT_ERROR_CHECK, False):
            def _check_delay(t_now, t_delay):
                raise ValueError(
                    f'The request delay time should be within '
                    f'[{t_now - self.max_time - dt}, {t_now}], '
                    f'but we got {t_delay}'
                )

            jit_error_if(
                jnp.logical_or(
                    delay_time > current_time,
                    delay_time < current_time - self.max_time - dt
                ),
                _check_delay,
                current_time,
                delay_time
            )

        with jax.ensure_compile_time_eval():
            diff = current_time - delay_time
            float_time_step = diff / dt

            if self.interp_method == _INTERP_LINEAR:  # "linear" interpolation
                data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
                data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
                t_diff = float_time_step - jnp.floor(float_time_step)
                return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)

            elif self.interp_method == _INTERP_ROUND:  # "round" interpolation
                return self.retrieve_at_step(jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32), *indices)

            else:  # raise error
                raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
                                 f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')

    def update(self, current: PyTree) -> None:
        """
        Update delay variable with the new data.

        With ``'rotation'`` the data is written at row ``i % max_length``; with
        ``'concat'`` it is prepended at row 0 and the oldest row is dropped.
        """
        with jax.ensure_compile_time_eval():
            assert self.history is not None, 'The delay history is not initialized.'

            # record units once, on the first update
            if self.take_aware_unit and self._unit is None:
                self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)

            # update the delay data at the rotation index
            if self.delay_method == _DELAY_ROTATE:
                i = environ.get(environ.I)
                idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
                idx = jax.lax.stop_gradient(idx)
                self.history.value = jax.tree.map(
                    lambda hist, cur: hist.at[idx].set(cur),
                    self.history.value,
                    current
                )
            # update the delay data at the first position
            elif self.delay_method == _DELAY_CONCAT:
                current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
                if self.max_length > 1:
                    self.history.value = jax.tree.map(
                        lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
                        self.history.value,
                        current
                    )
                else:
                    self.history.value = current

            else:
                raise ValueError(f'Unknown updating method "{self.delay_method}"')
|
458
|
-
|
459
|
-
|
460
|
-
class StateWithDelay(Delay):
    """
    Delay buffer that tracks the history of a state stored on a module.

    A ``StateWithDelay`` is a :py:class:`~.Delay` whose data source is a
    :py:class:`~brainstate._state.State` attribute of a target module (for
    example the membrane potential ``V`` of a neuron). On every :meth:`update`
    it records the current value of that state, so past values can be read by
    step or by time with the inherited :py:class:`~.Delay` helpers.

    It is usually created implicitly through the prefetch-delay helpers of a
    Dynamics module, e.g.:

    - ``module.prefetch('V').delay.at(5.0 * u.ms)``
    - ``module.prefetch_delay('V', 5.0 * u.ms)``

    Parameters
    ----------
    target : :py:class:`~brainstate.graph.Node`
        The module that owns the tracked state.
    item : str
        Attribute name of the state on ``target``; the attribute must be a
        :py:class:`~brainstate._state.State`.
    init : Callable, optional
        Initializer for the history before ``t0``, called as
        ``init(shape, dtype=...)``. Zeros are used when omitted.
    delay_method : {"rotation", "concat"}, default "rotation"
        Buffer update strategy, forwarded to :py:class:`~.Delay`.

    Attributes
    ----------
    state : :py:class:`~brainstate._state.State`
        The tracked state, looked up on ``target`` at every access.
    history : :py:class:`~brainstate._state.ShortTermState`
        Buffer whose leading axis is time, created in :meth:`init_state`.
    max_time : float
        Largest delay time registered on the buffer.
    max_length : int
        Number of rows in the buffer.

    Notes
    -----
    - Use :py:meth:`retrieve_at_step` for integer step delays and
      :py:meth:`retrieve_at_time` for time-based queries with linear/round
      interpolation.
    - When created through the prefetch helpers, it is registered as an
      "after-update" hook on the owning Dynamics so the buffer follows the
      simulation.

    Examples
    --------
    Access a neuron's membrane potential 5 ms in the past:

    >>> import brainunit as u
    >>> import brainstate as brainstate
    >>> lif = brainstate.nn.LIF(100)
    >>> # Create a delayed accessor to V(t-5ms)
    >>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
    >>> # Inside another module's update you can read the delayed value
    >>> v_t_minus_5ms = v_delay()

    Register multiple delay taps and index-specific delays:

    >>> # Under the hood, a StateWithDelay is created and you can register
    >>> # additional taps (in steps or time) via its Delay interface
    >>> _ = lif.prefetch('V').delay.at(2.0 * u.ms)  # additional delay
    >>> # Direct access to buffer by steps (advanced)
    >>> # lif._get_after_update('V-prefetch-delay').retrieve_at_step(3)
    """

    __module__ = 'brainstate.nn'

    state: State  # the tracked state

    def __init__(
        self,
        target: Node,
        item: str,
        init: Callable = None,
        delay_method: Optional[str] = _DELAY_ROTATE,
    ):
        # The target shape/dtype is unknown until ``init_state``; start empty.
        super().__init__(None, init=init, delay_method=delay_method)

        self._target = target
        self._target_term = item

    @property
    def state(self) -> State:
        """Look up the tracked state on the target module."""
        candidate = getattr(self._target, self._target_term)
        if isinstance(candidate, State):
            return candidate
        raise TypeError(f'The term "{self._target_term}" in the module "{self._target}" is not a State.')

    @call_order(3)
    def init_state(self, *args, **kwargs):
        """
        Derive the buffer layout from the tracked state, then build the buffer.
        """
        tracked = self.state
        self.target_info = jax.tree.map(
            lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype),
            tracked.value,
        )
        super().init_state(*args, **kwargs)

    def update(self, *args) -> None:
        """
        Push the current value of the tracked state into the buffer.

        Positional arguments are accepted but ignored.
        """
        return super().update(self.state.value)
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import numbers
|
17
|
+
from functools import partial
|
18
|
+
from typing import Optional, Dict, Callable, Union, Sequence
|
19
|
+
|
20
|
+
import brainunit as u
|
21
|
+
import jax
|
22
|
+
import jax.numpy as jnp
|
23
|
+
import numpy as np
|
24
|
+
|
25
|
+
from brainstate import environ
|
26
|
+
from brainstate._state import ShortTermState, State
|
27
|
+
from brainstate.graph import Node
|
28
|
+
from brainstate.transform import jit_error_if
|
29
|
+
from brainstate.typing import ArrayLike, PyTree
|
30
|
+
from ._collective_ops import call_order
|
31
|
+
from ._module import Module
|
32
|
+
|
33
|
+
__all__ = ['Delay', 'DelayAccess', 'StateWithDelay']

# Buffer-update strategies accepted by ``Delay(delay_method=...)``.
_DELAY_ROTATE = 'rotation'
_DELAY_CONCAT = 'concat'

# Interpolation strategies accepted by ``Delay(interp_method=...)``.
_INTERP_LINEAR = 'linear_interp'
_INTERP_ROUND = 'round'
|
41
|
+
|
42
|
+
|
43
|
+
def _get_delay(delay_time):
    """
    Convert a delay time into an integer number of simulation steps.

    Args:
        delay_time: The delay time (scalar or array, optionally a ``brainunit``
            quantity), or ``None`` for "no delay".

    Returns:
        A tuple ``(delay_time, delay_step)``. For ``None`` this is
        ``(0. * dt, 0)``; otherwise ``delay_step`` is ``ceil(delay_time / dt)``
        (with floating-point noise removed, see below) cast to
        ``environ.ditype()``.

    Raises:
        AssertionError: If ``delay_time / dt`` is not dimensionless.
    """
    if delay_time is None:
        return 0. * environ.get_dt(), 0
    delay_step = delay_time / environ.get_dt()
    assert u.get_dim(delay_step) == u.DIMENSIONLESS, (
        f'The delay time {delay_time} must be dimensionally compatible with dt.'
    )
    delay_step = jnp.asarray(u.maybe_decimal(delay_step))
    # ``delay_time / dt`` carries floating-point error: e.g. 1.1 / 0.1 evaluates
    # to 11.000000000000002, which a bare ``ceil`` turns into 12 steps. Ratios
    # that are numerically integral are snapped to that integer before rounding
    # up, so exact multiples of ``dt`` map to the expected number of steps.
    nearest = jnp.round(delay_step)
    is_integral = jnp.isclose(delay_step, nearest, rtol=1e-6, atol=1e-6)
    delay_step = jnp.where(is_integral, nearest, jnp.ceil(delay_step))
    delay_step = delay_step.astype(environ.ditype())
    return delay_time, delay_step
|
50
|
+
|
51
|
+
|
52
|
+
class DelayAccess(Node):
    """
    Node that reads one named entry of a :class:`Delay`.

    Constructing the node registers ``entry`` on the given delay through
    ``Delay.register_entry``; :meth:`update` then returns the delayed data for
    that entry through ``Delay.at``.

    Args:
        delay: The :class:`Delay` instance to read from.
        *time: The delay time of the entry, optionally followed by integer
            indices; forwarded unchanged to ``Delay.register_entry``.
        entry: The name under which the entry is registered.
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        delay: 'Delay',
        *time,
        entry: str,
    ):
        super().__init__()
        self.delay = delay
        assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
        self._delay_entry = entry
        # The returned info is ``(delay_step, *indices)`` or ``None``.
        registered = delay.register_entry(entry, *time)
        self.delay_info = registered

    def update(self):
        """Return the delayed data stored for this accessor's entry."""
        source = self.delay
        return source.at(self._delay_entry)
|
82
|
+
|
83
|
+
|
84
|
+
class Delay(Module):
    """
    Delay variable for storing short-term history data.

    The data in this delay variable is arranged as::

        delay = 0             [ data
        delay = 1               data
        delay = 2               data
        ...                     ....
        ...                     ....
        delay = length-1        data
        delay = length          data ]

    Args:
      target_info: PyTree. A pytree whose leaves expose ``shape`` and ``dtype``;
        it describes a single time slice of the delayed data. Only the
        shape/dtype of each leaf is kept.
      time: int, float, Quantity or None. The delay time. ``None`` means a
        buffer holding only the zero-delay slice.
      init: Any. The delay data before ``t0``. It can be a Python number (float,
        int, bool), an array, or a callable invoked as ``init(shape, dtype=...)``.
        Numbers and arrays are broadcast to the full buffer shape
        ``(length, [batch_size,] *leaf.shape)``. ``None`` fills with zeros.
      entries: optional, dict. The delay access entries, mapping an entry name
        to a delay time or to a tuple/list ``(delay_time, *indices)``.
      delay_method: str. The method used for updating the delay buffer, either
        ``'rotation'`` (ring buffer, the default) or ``'concat'``.
      interp_method: str. The interpolation used by :meth:`retrieve_at_time`,
        either ``'linear_interp'`` (the default) or ``'round'``.
      take_aware_unit: bool. If True, the unit of the first value passed to
        :meth:`update` is recorded and multiplied back onto retrieved data.
    """

    __module__ = 'brainstate.nn'

    max_time: float  # largest delay time registered so far
    max_length: int  # number of time slices held in the history buffer
    history: Optional[ShortTermState]

    def __init__(
        self,
        target_info: PyTree,
        time: Optional[Union[int, float, u.Quantity]] = None,  # delay time
        init: Optional[Union[ArrayLike, Callable]] = None,  # delay data before t0
        entries: Optional[Dict] = None,  # delay access entry
        delay_method: Optional[str] = _DELAY_ROTATE,  # delay method
        interp_method: str = _INTERP_LINEAR,  # interpolation method
        take_aware_unit: bool = False
    ):
        # target information: keep only the shape/dtype of each leaf
        self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)

        # delay method
        assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (
            f'Un-supported delay method {delay_method}. '
            f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}'
        )
        self.delay_method = delay_method

        # interp method
        assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (
            f'Un-supported interpolation method {interp_method}. '
            f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}'
        )
        self.interp_method = interp_method

        # delay length and time
        with jax.ensure_compile_time_eval():
            self.max_time, delay_length = _get_delay(time)
            # delay steps 0..delay_length inclusive need delay_length + 1 slots
            self.max_length = delay_length + 1

        super().__init__()

        # delay data
        if init is not None:
            if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
                raise TypeError(f'init should be Array, Callable, or None. But got {init}')
        self._init = init
        self._history = None

        # registered entries: name -> (delay_step, *indices) or None
        self._registered_entries = dict()

        # register the entries given at construction time
        if entries is not None:
            for entry, delay_time in entries.items():
                if isinstance(delay_time, (tuple, list)):
                    self.register_entry(entry, *delay_time)
                else:
                    self.register_entry(entry, delay_time)

        self.take_aware_unit = take_aware_unit
        self._unit = None

    @property
    def history(self):
        """The :class:`ShortTermState` holding the delay buffer (``None`` before init)."""
        return self._history

    @history.setter
    def history(self, value):
        self._history = value

    def _f_to_init(self, a, batch_size, length):
        # Build the initial buffer for one leaf: shape (length, [batch_size,] *a.shape).
        shape = list(a.shape)
        if batch_size is not None:
            shape.insert(0, batch_size)
        shape.insert(0, length)
        if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
            data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
        elif callable(self._init):
            data = self._init(shape, dtype=a.dtype)
        else:
            assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
            data = jnp.zeros(shape, dtype=a.dtype)
        return data

    @call_order(3)
    def init_state(self, batch_size: int = None, **kwargs):
        """Allocate the history buffer with ``self.max_length`` slices."""
        fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
        self.history = ShortTermState(jax.tree.map(fun, self.target_info))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Re-fill the existing history state with initial data of ``self.max_length`` slices."""
        fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
        self.history.value = jax.tree.map(fun, self.target_info)

    def register_delay(self, *delay_time):
        """
        Register a delay time and enlarge the buffer configuration if needed.

        Args:
          *delay_time: The first element is the delay time (scalar or
            array-like, dimensionally compatible with ``environ.get_dt()``), or
            ``None``. The remaining elements are integer indices used to slice
            the delayed data when it is retrieved.

        Returns:
          ``None`` if ``delay_time[0]`` is None (any indices are dropped in that
          case). Otherwise a tuple ``(delay_step, *delay_time[1:])`` where
          ``delay_step`` is ``ceil(delay_time[0] / dt)`` as an integer array.

        Raises:
          AssertionError: If no delay time is given, if any index is not of an
            integer dtype, or if ``delay_time[0] / dt`` is not dimensionless.

        Note:
          ``max_time`` and ``max_length`` only ever grow: registering a delay
          shorter than one already registered leaves both unchanged. This
          method does not resize an already-allocated history buffer; the
          buffer size is taken from ``max_length`` in :meth:`init_state` /
          :meth:`reset_state`.

        Example:
          >>> delay_obj.register_delay(5.0)  # scalar delay
          >>> delay_obj.register_delay(jnp.array([2.0, 3.0]), 0, 1)  # vector delay with indices
        """
        assert len(delay_time) >= 1, 'You should provide at least one delay time.'
        for index in delay_time[1:]:
            assert jnp.issubdtype(u.math.get_dtype(index), jnp.integer), f'The index should be integer. But got {index}.'
        if delay_time[0] is None:
            return None
        with jax.ensure_compile_time_eval():
            time, delay_step = _get_delay(delay_time[0])
            max_delay_step = jnp.max(delay_step)
            new_max_time = u.math.max(time)
            # Only grow ``max_time``. Overwriting it with a shorter delay would
            # leave ``max_length`` sized for the longer one while the time-range
            # check in ``retrieve_at_time`` rejects still-valid requests.
            if new_max_time > self.max_time:
                self.max_time = new_max_time

        # delay variable
        if self.max_length <= max_delay_step + 1:
            self.max_length = max_delay_step + 1
        return delay_step, *delay_time[1:]

    def register_entry(self, entry: str, *delay_time) -> Optional[tuple]:
        """
        Register an entry to access the delay data.

        Args:
          entry: str. The entry to access the delay data.
          delay_time: The delay time of the entry; the first element is the
            delay time, the second and later elements are indices.

        Returns:
          The value returned by :meth:`register_delay` for ``delay_time``
          (``(delay_step, *indices)`` or ``None``).

        Raises:
          KeyError: If ``entry`` has already been registered.
        """
        if entry in self._registered_entries:
            raise KeyError(
                f'Entry {entry} has been registered. '
                f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
                f'The new delay for the key {entry} is {delay_time}. '
                f'You can use another key. '
            )
        delay_info = self.register_delay(*delay_time)
        self._registered_entries[entry] = delay_info
        return delay_info

    def access(self, entry: str, *delay_time) -> DelayAccess:
        """
        Create a DelayAccess object for a specific delay entry and delay time.

        Args:
            entry (str): The name of the delay entry to access.
            delay_time (Sequence): The delay time or parameters associated with the entry.

        Returns:
            DelayAccess: An object that registers ``entry`` on this delay and
            returns its delayed data when updated.
        """
        return DelayAccess(self, *delay_time, entry=entry)

    def at(self, entry: str) -> ArrayLike:
        """
        Get the data at the given entry.

        Args:
          entry: str. The entry to access the data.

        Returns:
          The data. An entry registered with a ``None`` delay reads step 0.

        Raises:
          KeyError: If ``entry`` was never registered.
        """
        assert isinstance(entry, str), (f'entry should be a string for describing the '
                                        f'entry of the delay data. But we got {entry}.')
        if entry not in self._registered_entries:
            raise KeyError(f'Does not find delay entry "{entry}".')
        delay_step = self._registered_entries[entry]
        if delay_step is None:
            delay_step = (0,)
        return self.retrieve_at_step(*delay_step)

    def retrieve_at_step(self, delay_step, *indices) -> PyTree:
        """
        Retrieve the delay data at the given delay time step (the integer to indicate the time step).

        Parameters
        ----------
        delay_step: int_like
            Retrieve the data at the given time step.
        indices: tuple
            The indices to slice the data.

        Returns
        -------
        delay_data: The delay data at the given delay step.
        """
        assert self.history is not None, 'The delay history is not initialized.'
        assert delay_step is not None, 'The delay step should be given.'

        if environ.get(environ.JIT_ERROR_CHECK, False):
            def _check_delay(delay_len):
                raise ValueError(
                    f'The request delay length should be less than the '
                    f'maximum delay {self.max_length - 1}. But we got {delay_len}'
                )

            jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)

        # rotation method
        with jax.ensure_compile_time_eval():
            if self.delay_method == _DELAY_ROTATE:
                # slot written ``delay_step`` steps ago in the ring buffer
                i = environ.get(environ.I, desc='The time step index.')
                di = i - delay_step
                delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
                delay_idx = jax.lax.stop_gradient(delay_idx)

            elif self.delay_method == _DELAY_CONCAT:
                # newest slice is stored at position 0
                delay_idx = delay_step

            else:
                raise ValueError(f'Unknown delay updating method "{self.delay_method}"')

            # the delay index
            if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
                raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
            indices = (delay_idx,) + indices

        # the delay data
        if self._unit is None:
            return jax.tree.map(lambda a: a[indices], self.history.value)
        else:
            return jax.tree.map(
                lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
                self.history.value,
                self._unit
            )

    def retrieve_at_time(self, delay_time, *indices) -> PyTree:
        """
        Retrieve the delay data at the given time point.

        ``delay_time`` is an absolute time comparable with ``environ.T``, not a
        delay duration: the number of steps back is ``(t - delay_time) / dt``,
        where ``t`` is the current time. A fractional step count is resolved
        with ``self.interp_method``: linear interpolation between the two
        neighbouring steps, or rounding to the nearest step.

        Parameters
        ----------
        delay_time: float
            The time point at which to read the data; expected to lie within
            ``[t - max_time - dt, t]``.
        indices: tuple
            The indices to slice the data.

        Returns
        -------
        delay_data: The delay data at the given time point.
        """
        assert self.history is not None, 'The delay history is not initialized.'
        assert delay_time is not None, 'The delay time should be given.'

        current_time = environ.get(environ.T, desc='The current time.')
        dt = environ.get_dt()

        if environ.get(environ.JIT_ERROR_CHECK, False):
            def _check_delay(t_now, t_delay):
                raise ValueError(
                    f'The request delay time should be within '
                    f'[{t_now - self.max_time - dt}, {t_now}], '
                    f'but we got {t_delay}'
                )

            jit_error_if(
                jnp.logical_or(
                    delay_time > current_time,
                    delay_time < current_time - self.max_time - dt
                ),
                _check_delay,
                current_time,
                delay_time
            )

        with jax.ensure_compile_time_eval():
            diff = current_time - delay_time
            float_time_step = diff / dt

            if self.interp_method == _INTERP_LINEAR:  # "linear" interpolation
                data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
                data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
                t_diff = float_time_step - jnp.floor(float_time_step)
                return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)

            elif self.interp_method == _INTERP_ROUND:  # "round" interpolation
                return self.retrieve_at_step(jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32), *indices)

            else:  # raise error
                raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
                                 f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')

    def update(self, current: PyTree) -> None:
        """
        Update delay variable with the new data.

        With ``'rotation'`` the slot ``i % max_length`` is overwritten, where
        ``i`` is ``environ.I``. With ``'concat'`` the new slice is prepended and
        the oldest slice is dropped.
        """

        with jax.ensure_compile_time_eval():
            assert self.history is not None, 'The delay history is not initialized.'

            if self.take_aware_unit and self._unit is None:
                self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)

            # update the delay data at the rotation index
            if self.delay_method == _DELAY_ROTATE:
                i = environ.get(environ.I)
                idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
                idx = jax.lax.stop_gradient(idx)
                self.history.value = jax.tree.map(
                    lambda hist, cur: hist.at[idx].set(cur),
                    self.history.value,
                    current
                )
            # update the delay data at the first position
            elif self.delay_method == _DELAY_CONCAT:
                current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
                if self.max_length > 1:
                    self.history.value = jax.tree.map(
                        lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
                        self.history.value,
                        current
                    )
                else:
                    self.history.value = current

            else:
                raise ValueError(f'Unknown updating method "{self.delay_method}"')
|
458
|
+
|
459
|
+
|
460
|
+
class StateWithDelay(Delay):
    """
    Delay buffer that tracks a :py:class:`~brainstate._state.State` owned by
    another module.

    Instead of receiving data explicitly, this delay looks up the attribute
    ``item`` on ``target`` (which must be a
    :py:class:`~brainstate._state.State`, e.g. a membrane potential ``V`` on a
    neuron). It records that state's value into its rolling history each time
    :meth:`update` is called. Values can then be retrieved by step or by time
    through the inherited :py:class:`~.Delay` API.

    Instances are usually created indirectly by the prefetch-delay helpers on a
    Dynamics module, for example:

    - ``module.prefetch('V').delay.at(5.0 * u.ms)``
    - ``module.prefetch_delay('V', 5.0 * u.ms)``

    Both build a StateWithDelay bound to ``module.V`` and register the
    requested delay on it.

    Parameters
    ----------
    target : :py:class:`~brainstate.graph.Node`
        The module that owns the tracked state.
    item : str
        Name of the state attribute on ``target``; the attribute must be a
        :py:class:`~brainstate._state.State`.
    init : Callable, optional
        Initializer for the history buffer before ``t0``. A callable is invoked
        as ``init(shape, dtype=...)`` and must return an array; when omitted the
        buffer is filled with zeros. Scalars/arrays are also accepted by the
        underlying Delay API.
    delay_method : {"rotation", "concat"}, default "rotation"
        Buffering strategy inherited from :py:class:`~.Delay`: "rotation" uses a
        ring buffer, "concat" shifts by concatenation.

    Attributes
    ----------
    state : :py:class:`~brainstate._state.State`
        The tracked state, looked up on ``target`` on every access.
    history : :py:class:`~brainstate._state.ShortTermState`
        Time-axis buffer of shape ``[length, *state.shape]``.
    max_time : float
        Largest delay time the buffer currently supports.
    max_length : int
        Number of slices in the buffer (``ceil(max_time/dt)+1``).

    Notes
    -----
    - All retrieval helpers come from :py:class:`~.Delay`:
      :py:meth:`retrieve_at_step` for integer step delays and
      :py:meth:`retrieve_at_time` for continuous-time queries with
      linear/round interpolation.
    - It is intended to be registered as an "after-update" hook on the owning
      Dynamics so the buffer advances after each simulation step.

    Examples
    --------
    Read a neuron's membrane potential 5 ms in the past:

    >>> import brainunit as u
    >>> import brainstate as brainstate
    >>> lif = brainstate.nn.LIF(100)
    >>> # Create a delayed accessor to V(t-5ms)
    >>> v_delay = lif.prefetch_delay('V', 5.0 * u.ms)
    >>> # Inside another module's update you can read the delayed value
    >>> v_t_minus_5ms = v_delay()

    Register further delay taps:

    >>> # Under the hood, a StateWithDelay is created and you can register
    >>> # additional taps (in steps or time) via its Delay interface
    >>> _ = lif.prefetch('V').delay.at(2.0 * u.ms)  # additional delay
    >>> # Direct access to buffer by steps (advanced)
    >>> # lif._get_after_update('V-prefetch-delay').retrieve_at_step(3)
    """

    __module__ = 'brainstate.nn'

    state: State  # state

    def __init__(
        self,
        target: Node,
        item: str,
        init: Callable = None,
        delay_method: Optional[str] = _DELAY_ROTATE,
    ):
        # target_info is unknown until init_state reads the tracked state
        super().__init__(None, init=init, delay_method=delay_method)

        self._target = target
        self._target_term = item

    @property
    def state(self) -> State:
        found = getattr(self._target, self._target_term)
        if isinstance(found, State):
            return found
        raise TypeError(f'The term "{self._target_term}" in the module "{self._target}" is not a State.')

    @call_order(3)
    def init_state(self, *args, **kwargs):
        """
        Derive ``target_info`` from the tracked state, then allocate the buffer.
        """
        current_value = self.state.value
        self.target_info = jax.tree.map(
            lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype),
            current_value,
        )
        super().init_state(*args, **kwargs)

    def update(self, *args) -> None:
        """
        Push the tracked state's current value into the delay buffer.

        Positional arguments are accepted and ignored.
        """
        return super().update(self.state.value)
|