brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_dynamics.py
CHANGED
@@ -33,7 +33,7 @@ For handling the delays:
|
|
33
33
|
|
34
34
|
"""
|
35
35
|
|
36
|
-
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar,
|
36
|
+
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, Tuple
|
37
37
|
|
38
38
|
import jax
|
39
39
|
import numpy as np
|
@@ -41,23 +41,26 @@ import numpy as np
|
|
41
41
|
from brainstate import environ
|
42
42
|
from brainstate._state import State
|
43
43
|
from brainstate.graph import Node
|
44
|
-
from brainstate.mixin import ParamDescriber
|
45
44
|
from brainstate.typing import Size, ArrayLike
|
46
45
|
from ._delay import StateWithDelay, Delay
|
47
46
|
from ._module import Module
|
48
47
|
|
48
|
+
T = TypeVar('T')
|
49
|
+
|
49
50
|
__all__ = [
|
50
|
-
'DynamicsGroup',
|
51
51
|
'Dynamics',
|
52
|
+
|
53
|
+
'receive_update_output',
|
54
|
+
'not_receive_update_output',
|
55
|
+
'receive_update_input',
|
56
|
+
'not_receive_update_input',
|
57
|
+
|
52
58
|
'Prefetch',
|
53
59
|
'PrefetchDelay',
|
54
60
|
'PrefetchDelayAt',
|
55
61
|
'OutputDelayAt',
|
56
62
|
]
|
57
63
|
|
58
|
-
T = TypeVar('T')
|
59
|
-
_max_order = 10
|
60
|
-
|
61
64
|
|
62
65
|
class Dynamics(Module):
|
63
66
|
"""
|
@@ -122,7 +125,7 @@ class Dynamics(Module):
|
|
122
125
|
|
123
126
|
__module__ = 'brainstate.nn'
|
124
127
|
|
125
|
-
graph_invisible_attrs = (
|
128
|
+
graph_invisible_attrs = ()
|
126
129
|
|
127
130
|
# before updates
|
128
131
|
_before_updates: Optional[Dict[Hashable, Callable]]
|
@@ -136,11 +139,7 @@ class Dynamics(Module):
|
|
136
139
|
# delta inputs
|
137
140
|
_delta_inputs: Optional[Dict[str, ArrayLike | Callable]]
|
138
141
|
|
139
|
-
def __init__(
|
140
|
-
self,
|
141
|
-
in_size: Size,
|
142
|
-
name: Optional[str] = None,
|
143
|
-
):
|
142
|
+
def __init__(self, in_size: Size, name: Optional[str] = None):
|
144
143
|
# initialize
|
145
144
|
super().__init__(name=name)
|
146
145
|
|
@@ -157,12 +156,6 @@ class Dynamics(Module):
|
|
157
156
|
raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
|
158
157
|
self.in_size = in_size
|
159
158
|
|
160
|
-
# current inputs
|
161
|
-
self._current_inputs = None
|
162
|
-
|
163
|
-
# delta inputs
|
164
|
-
self._delta_inputs = None
|
165
|
-
|
166
159
|
# before updates
|
167
160
|
self._before_updates = None
|
168
161
|
|
@@ -172,14 +165,6 @@ class Dynamics(Module):
|
|
172
165
|
# in-/out- size of neuron population
|
173
166
|
self.out_size = self.in_size
|
174
167
|
|
175
|
-
# def __pretty_repr_item__(self, name, value):
|
176
|
-
# if name in [
|
177
|
-
# '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
|
178
|
-
# '_in_size', '_out_size', '_name', '_mode',
|
179
|
-
# ]:
|
180
|
-
# return (name, value) if value is None else (name[1:], value) # skip the first `_`
|
181
|
-
# return super().__pretty_repr_item__(name, value)
|
182
|
-
|
183
168
|
@property
|
184
169
|
def varshape(self):
|
185
170
|
"""
|
@@ -200,327 +185,72 @@ class Dynamics(Module):
|
|
200
185
|
"""
|
201
186
|
return self.in_size
|
202
187
|
|
203
|
-
|
204
|
-
def current_inputs(self):
|
205
|
-
"""
|
206
|
-
Get the dictionary of current inputs registered with this dynamics model.
|
207
|
-
|
208
|
-
Current inputs represent direct input currents that flow into the model.
|
209
|
-
|
210
|
-
Returns
|
211
|
-
-------
|
212
|
-
dict or None
|
213
|
-
A dictionary mapping keys to current input functions or values,
|
214
|
-
or None if no current inputs have been registered.
|
215
|
-
|
216
|
-
See Also
|
217
|
-
--------
|
218
|
-
add_current_input : Register a new current input
|
219
|
-
sum_current_inputs : Apply and sum all current inputs
|
220
|
-
delta_inputs : Dictionary of instantaneous change inputs
|
221
|
-
"""
|
222
|
-
return self._current_inputs
|
223
|
-
|
224
|
-
@property
|
225
|
-
def delta_inputs(self):
|
226
|
-
"""
|
227
|
-
Get the dictionary of delta inputs registered with this dynamics model.
|
228
|
-
|
229
|
-
Delta inputs represent instantaneous changes to state variables (dX/dt).
|
230
|
-
|
231
|
-
Returns
|
232
|
-
-------
|
233
|
-
dict or None
|
234
|
-
A dictionary mapping keys to delta input functions or values,
|
235
|
-
or None if no delta inputs have been registered.
|
236
|
-
|
237
|
-
See Also
|
238
|
-
--------
|
239
|
-
add_delta_input : Register a new delta input
|
240
|
-
sum_delta_inputs : Apply and sum all delta inputs
|
241
|
-
current_inputs : Dictionary of direct current inputs
|
242
|
-
"""
|
243
|
-
return self._delta_inputs
|
244
|
-
|
245
|
-
def add_current_input(
|
246
|
-
self,
|
247
|
-
key: str,
|
248
|
-
inp: Union[Callable, ArrayLike],
|
249
|
-
label: Optional[str] = None
|
250
|
-
):
|
251
|
-
"""
|
252
|
-
Add a current input function or array to the dynamics model.
|
253
|
-
|
254
|
-
Current inputs represent direct input currents that can be accessed during
|
255
|
-
model updates through the `sum_current_inputs()` method.
|
256
|
-
|
257
|
-
Parameters
|
258
|
-
----------
|
259
|
-
key : str
|
260
|
-
Unique identifier for this current input. Used to retrieve or reference
|
261
|
-
the input later.
|
262
|
-
inp : Union[Callable, ArrayLike]
|
263
|
-
The input data or function that generates input data.
|
264
|
-
- If callable: Will be called during updates with arguments passed to `sum_current_inputs()`
|
265
|
-
- If array-like: Will be applied once and then automatically removed from available inputs
|
266
|
-
label : Optional[str], default=None
|
267
|
-
Optional grouping label for the input. When provided, allows selective
|
268
|
-
processing of inputs by label in `sum_current_inputs()`.
|
269
|
-
|
270
|
-
Raises
|
271
|
-
------
|
272
|
-
ValueError
|
273
|
-
If the key has already been used for a different current input.
|
274
|
-
|
275
|
-
Notes
|
276
|
-
-----
|
277
|
-
- Inputs with the same label can be processed together using the `label`
|
278
|
-
parameter in `sum_current_inputs()`.
|
279
|
-
- Non-callable inputs are consumed when used (removed after first use).
|
280
|
-
- Callable inputs persist and can be called repeatedly.
|
281
|
-
|
282
|
-
See Also
|
283
|
-
--------
|
284
|
-
sum_current_inputs : Sum all current inputs matching a given label
|
285
|
-
add_delta_input : Add a delta input function or array
|
286
|
-
"""
|
287
|
-
key = _input_label_repr(key, label)
|
288
|
-
if self._current_inputs is None:
|
289
|
-
self._current_inputs = dict()
|
290
|
-
if key in self._current_inputs:
|
291
|
-
if id(self._current_inputs[key]) != id(inp):
|
292
|
-
raise ValueError(f'Key "{key}" has been defined and used in the current inputs of {self}.')
|
293
|
-
self._current_inputs[key] = inp
|
294
|
-
|
295
|
-
def add_delta_input(
|
296
|
-
self,
|
297
|
-
key: str,
|
298
|
-
inp: Union[Callable, ArrayLike],
|
299
|
-
label: Optional[str] = None
|
300
|
-
):
|
301
|
-
"""
|
302
|
-
Add a delta input function or array to the dynamics model.
|
303
|
-
|
304
|
-
Delta inputs represent instantaneous changes to the model state (i.e., dX/dt contributions).
|
305
|
-
This method registers a function or array that provides delta inputs which will be
|
306
|
-
accessible during model updates through the `sum_delta_inputs()` method.
|
307
|
-
|
308
|
-
Parameters
|
309
|
-
----------
|
310
|
-
key : str
|
311
|
-
Unique identifier for this delta input. Used to retrieve or reference
|
312
|
-
the input later.
|
313
|
-
inp : Union[Callable, ArrayLike]
|
314
|
-
The input data or function that generates input data.
|
315
|
-
- If callable: Will be called during updates with arguments passed to `sum_delta_inputs()`
|
316
|
-
- If array-like: Will be applied once and then automatically removed from available inputs
|
317
|
-
label : Optional[str], default=None
|
318
|
-
Optional grouping label for the input. When provided, allows selective
|
319
|
-
processing of inputs by label in `sum_delta_inputs()`.
|
320
|
-
|
321
|
-
Raises
|
322
|
-
------
|
323
|
-
ValueError
|
324
|
-
If the key has already been used for a different delta input.
|
325
|
-
|
326
|
-
Notes
|
327
|
-
-----
|
328
|
-
- Inputs with the same label can be processed together using the `label`
|
329
|
-
parameter in `sum_delta_inputs()`.
|
330
|
-
- Non-callable inputs are consumed when used (removed after first use).
|
331
|
-
- Callable inputs persist and can be called repeatedly.
|
332
|
-
|
333
|
-
See Also
|
334
|
-
--------
|
335
|
-
sum_delta_inputs : Sum all delta inputs matching a given label
|
336
|
-
add_current_input : Add a current input function or array
|
337
|
-
"""
|
338
|
-
key = _input_label_repr(key, label)
|
339
|
-
if self._delta_inputs is None:
|
340
|
-
self._delta_inputs = dict()
|
341
|
-
if key in self._delta_inputs:
|
342
|
-
if id(self._delta_inputs[key]) != id(inp):
|
343
|
-
raise ValueError(f'Key "{key}" has been defined and used.')
|
344
|
-
self._delta_inputs[key] = inp
|
345
|
-
|
346
|
-
def get_input(self, key: str):
|
188
|
+
def prefetch(self, item: str) -> 'Prefetch':
|
347
189
|
"""
|
348
|
-
|
190
|
+
Create a reference to a state or variable that may not be initialized yet.
|
349
191
|
|
350
|
-
|
351
|
-
|
352
|
-
|
192
|
+
This method allows accessing module attributes or states before they are
|
193
|
+
fully defined, acting as a placeholder that will be resolved when called.
|
194
|
+
Particularly useful for creating references to variables that will be defined
|
195
|
+
during initialization or runtime.
|
353
196
|
|
354
197
|
Parameters
|
355
198
|
----------
|
356
|
-
|
357
|
-
The
|
199
|
+
item : str
|
200
|
+
The name of the attribute or state to reference.
|
358
201
|
|
359
202
|
Returns
|
360
203
|
-------
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
Raises
|
365
|
-
------
|
366
|
-
ValueError
|
367
|
-
If no input function is found with the specified key in either
|
368
|
-
current_inputs or delta_inputs.
|
369
|
-
|
370
|
-
See Also
|
371
|
-
--------
|
372
|
-
add_current_input : Register a current input function
|
373
|
-
add_delta_input : Register a delta input function
|
204
|
+
Prefetch
|
205
|
+
A Prefetch object that provides access to the referenced item.
|
374
206
|
|
375
207
|
Examples
|
376
208
|
--------
|
377
|
-
>>>
|
378
|
-
>>>
|
379
|
-
>>>
|
380
|
-
>>>
|
209
|
+
>>> import brainstate
|
210
|
+
>>> import brainunit as u
|
211
|
+
>>> neuron = brainstate.nn.LIF(...)
|
212
|
+
>>> v_ref = neuron.prefetch('V') # Reference to voltage
|
213
|
+
>>> v_value = v_ref() # Get current value
|
214
|
+
>>> delayed_v = v_ref.delay.at(5.0 * u.ms) # Get delayed value
|
381
215
|
"""
|
382
|
-
|
383
|
-
return self._current_inputs[key]
|
384
|
-
elif self._delta_inputs is not None and key in self._delta_inputs:
|
385
|
-
return self._delta_inputs[key]
|
386
|
-
else:
|
387
|
-
raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')
|
216
|
+
return Prefetch(self, item)
|
388
217
|
|
389
|
-
def
|
390
|
-
self,
|
391
|
-
init: Any,
|
392
|
-
*args,
|
393
|
-
label: Optional[str] = None,
|
394
|
-
**kwargs
|
395
|
-
):
|
218
|
+
def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
|
396
219
|
"""
|
397
|
-
|
398
|
-
|
399
|
-
This method iterates through all registered current input functions (from `.current_inputs`)
|
400
|
-
and applies them to calculate the total input current for the dynamics model. It adds all results
|
401
|
-
to the initial value provided.
|
220
|
+
Create a reference to a delayed state or variable in the module.
|
402
221
|
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
The initial value to which all current inputs will be added.
|
407
|
-
*args
|
408
|
-
Variable length argument list passed to each current input function.
|
409
|
-
label : Optional[str], default=None
|
410
|
-
If provided, only process current inputs with this label prefix.
|
411
|
-
When None, process all current inputs regardless of label.
|
412
|
-
**kwargs
|
413
|
-
Arbitrary keyword arguments passed to each current input function.
|
222
|
+
This method simplifies the process of accessing a delayed version of a state or variable
|
223
|
+
within the module. It first creates a prefetch reference to the specified state,
|
224
|
+
then specifies the delay time for accessing this state.
|
414
225
|
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
226
|
+
Args:
|
227
|
+
state (str): The name of the state or variable to reference.
|
228
|
+
delay_time (ArrayLike): The amount of time to delay the variable access,
|
229
|
+
typically in time units (e.g., milliseconds).
|
230
|
+
init (Callable, optional): An optional initialization function to provide
|
231
|
+
a default value if the delayed state is not yet available.
|
419
232
|
|
420
|
-
|
421
|
-
|
422
|
-
- Non-callable current inputs are applied once and then automatically removed from
|
423
|
-
the current_inputs dictionary.
|
424
|
-
- Callable current inputs remain registered for subsequent calls.
|
425
|
-
- When a label is provided, only current inputs with keys starting with that label
|
426
|
-
are applied.
|
427
|
-
"""
|
428
|
-
if self._current_inputs is None:
|
429
|
-
return init
|
430
|
-
if label is None:
|
431
|
-
filter_fn = lambda k: True
|
432
|
-
else:
|
433
|
-
label_repr = _input_label_start(label)
|
434
|
-
filter_fn = lambda k: k.startswith(label_repr)
|
435
|
-
for key in tuple(self._current_inputs.keys()):
|
436
|
-
if filter_fn(key):
|
437
|
-
out = self._current_inputs[key]
|
438
|
-
if callable(out):
|
439
|
-
try:
|
440
|
-
init = init + out(*args, **kwargs)
|
441
|
-
except Exception as e:
|
442
|
-
raise ValueError(
|
443
|
-
f'Error in current input value {key}: {out}\n'
|
444
|
-
f'Error: {e}'
|
445
|
-
) from e
|
446
|
-
else:
|
447
|
-
try:
|
448
|
-
init = init + out
|
449
|
-
except Exception as e:
|
450
|
-
raise ValueError(
|
451
|
-
f'Error in current input value {key}: {out}\n'
|
452
|
-
f'Error: {e}'
|
453
|
-
) from e
|
454
|
-
self._current_inputs.pop(key)
|
455
|
-
return init
|
456
|
-
|
457
|
-
def sum_delta_inputs(
|
458
|
-
self,
|
459
|
-
init: Any,
|
460
|
-
*args,
|
461
|
-
label: Optional[str] = None,
|
462
|
-
**kwargs
|
463
|
-
):
|
233
|
+
Returns:
|
234
|
+
PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
|
464
235
|
"""
|
465
|
-
|
236
|
+
return PrefetchDelayAt(self, state, delay_time, init=init)
|
466
237
|
|
467
|
-
|
468
|
-
|
469
|
-
to the
|
238
|
+
def output_delay(self, *delay_time) -> 'OutputDelayAt':
|
239
|
+
"""
|
240
|
+
Create a reference to the delayed output of the module.
|
470
241
|
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
The initial value to which all delta inputs will be added.
|
475
|
-
*args
|
476
|
-
Variable length argument list passed to each delta input function.
|
477
|
-
label : Optional[str], default=None
|
478
|
-
If provided, only process delta inputs with this label prefix.
|
479
|
-
When None, process all delta inputs regardless of label.
|
480
|
-
**kwargs
|
481
|
-
Arbitrary keyword arguments passed to each delta input function.
|
242
|
+
This method simplifies the process of accessing a delayed version of the module's output.
|
243
|
+
It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
|
244
|
+
at the specified delay time.
|
482
245
|
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
The initial value plus all applicable delta inputs summed together.
|
246
|
+
Args:
|
247
|
+
delay (Optional[ArrayLike]): The amount of time to delay the output access,
|
248
|
+
typically in time units (e.g., milliseconds). Defaults to None.
|
487
249
|
|
488
|
-
|
489
|
-
|
490
|
-
- Non-callable delta inputs are applied once and then automatically removed from
|
491
|
-
the delta_inputs dictionary.
|
492
|
-
- Callable delta inputs remain registered for subsequent calls.
|
493
|
-
- When a label is provided, only delta inputs with keys starting with that label
|
494
|
-
are applied.
|
250
|
+
Returns:
|
251
|
+
OutputDelayAt: An object that provides access to the module's output at the specified delay time.
|
495
252
|
"""
|
496
|
-
|
497
|
-
return init
|
498
|
-
if label is None:
|
499
|
-
filter_fn = lambda k: True
|
500
|
-
else:
|
501
|
-
label_repr = _input_label_start(label)
|
502
|
-
filter_fn = lambda k: k.startswith(label_repr)
|
503
|
-
for key in tuple(self._delta_inputs.keys()):
|
504
|
-
if filter_fn(key):
|
505
|
-
out = self._delta_inputs[key]
|
506
|
-
if callable(out):
|
507
|
-
try:
|
508
|
-
init = init + out(*args, **kwargs)
|
509
|
-
except Exception as e:
|
510
|
-
raise ValueError(
|
511
|
-
f'Error in delta input function {key}: {out}\n'
|
512
|
-
f'Error: {e}'
|
513
|
-
) from e
|
514
|
-
else:
|
515
|
-
try:
|
516
|
-
init = init + out
|
517
|
-
except Exception as e:
|
518
|
-
raise ValueError(
|
519
|
-
f'Error in delta input value {key}: {out}\n'
|
520
|
-
f'Error: {e}'
|
521
|
-
) from e
|
522
|
-
self._delta_inputs.pop(key)
|
523
|
-
return init
|
253
|
+
return OutputDelayAt(self, delay_time)
|
524
254
|
|
525
255
|
@property
|
526
256
|
def before_updates(self):
|
@@ -559,7 +289,7 @@ class Dynamics(Module):
|
|
559
289
|
"""
|
560
290
|
return self._after_updates
|
561
291
|
|
562
|
-
def
|
292
|
+
def add_before_update(self, key: Any, fun: Callable):
|
563
293
|
"""
|
564
294
|
Register a function to be executed before the module's update.
|
565
295
|
|
@@ -585,7 +315,7 @@ class Dynamics(Module):
|
|
585
315
|
raise KeyError(f'{key} has been registered in before_updates of {self}')
|
586
316
|
self.before_updates[key] = fun
|
587
317
|
|
588
|
-
def
|
318
|
+
def add_after_update(self, key: Any, fun: Callable):
|
589
319
|
"""
|
590
320
|
Register a function to be executed after the module's update.
|
591
321
|
|
@@ -611,7 +341,7 @@ class Dynamics(Module):
|
|
611
341
|
raise KeyError(f'{key} has been registered in after_updates of {self}')
|
612
342
|
self.after_updates[key] = fun
|
613
343
|
|
614
|
-
def
|
344
|
+
def get_before_update(self, key: Any):
|
615
345
|
"""
|
616
346
|
Retrieve a registered before-update function by its key.
|
617
347
|
|
@@ -636,7 +366,7 @@ class Dynamics(Module):
|
|
636
366
|
raise KeyError(f'{key} is not registered in before_updates of {self}')
|
637
367
|
return self.before_updates.get(key)
|
638
368
|
|
639
|
-
def
|
369
|
+
def get_after_update(self, key: Any):
|
640
370
|
"""
|
641
371
|
Retrieve a registered after-update function by its key.
|
642
372
|
|
@@ -661,7 +391,7 @@ class Dynamics(Module):
|
|
661
391
|
raise KeyError(f'{key} is not registered in after_updates of {self}')
|
662
392
|
return self.after_updates.get(key)
|
663
393
|
|
664
|
-
def
|
394
|
+
def has_before_update(self, key: Any):
|
665
395
|
"""
|
666
396
|
Check if a before-update function is registered with the given key.
|
667
397
|
|
@@ -679,7 +409,7 @@ class Dynamics(Module):
|
|
679
409
|
return False
|
680
410
|
return key in self.before_updates
|
681
411
|
|
682
|
-
def
|
412
|
+
def has_after_update(self, key: Any):
|
683
413
|
"""
|
684
414
|
Check if an after-update function is registered with the given key.
|
685
415
|
|
@@ -722,120 +452,6 @@ class Dynamics(Module):
|
|
722
452
|
model(ret)
|
723
453
|
return ret
|
724
454
|
|
725
|
-
def prefetch(self, item: str) -> 'Prefetch':
|
726
|
-
"""
|
727
|
-
Create a reference to a state or variable that may not be initialized yet.
|
728
|
-
|
729
|
-
This method allows accessing module attributes or states before they are
|
730
|
-
fully defined, acting as a placeholder that will be resolved when called.
|
731
|
-
Particularly useful for creating references to variables that will be defined
|
732
|
-
during initialization or runtime.
|
733
|
-
|
734
|
-
Parameters
|
735
|
-
----------
|
736
|
-
item : str
|
737
|
-
The name of the attribute or state to reference.
|
738
|
-
|
739
|
-
Returns
|
740
|
-
-------
|
741
|
-
Prefetch
|
742
|
-
A Prefetch object that provides access to the referenced item.
|
743
|
-
|
744
|
-
Examples
|
745
|
-
--------
|
746
|
-
>>> import brainstate
|
747
|
-
>>> import brainunit as u
|
748
|
-
>>> neuron = brainstate.nn.LIF(...)
|
749
|
-
>>> v_ref = neuron.prefetch('V') # Reference to voltage
|
750
|
-
>>> v_value = v_ref() # Get current value
|
751
|
-
>>> delayed_v = v_ref.delay.at(5.0 * u.ms) # Get delayed value
|
752
|
-
"""
|
753
|
-
return Prefetch(self, item)
|
754
|
-
|
755
|
-
def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T:
|
756
|
-
"""
|
757
|
-
Registers a dynamics module to execute after this module.
|
758
|
-
|
759
|
-
This method establishes a sequential execution relationship where the specified
|
760
|
-
dynamics module will be called after this module completes its update. This
|
761
|
-
creates a feed-forward connection in the computational graph.
|
762
|
-
|
763
|
-
Parameters
|
764
|
-
----------
|
765
|
-
dyn : Union[ParamDescriber[T], T]
|
766
|
-
The dynamics module to be executed after this module. Can be either:
|
767
|
-
- An instance of Dynamics
|
768
|
-
- A ParamDescriber that can instantiate a Dynamics object
|
769
|
-
|
770
|
-
Returns
|
771
|
-
-------
|
772
|
-
T
|
773
|
-
The dynamics module that was registered, allowing for method chaining.
|
774
|
-
|
775
|
-
Raises
|
776
|
-
------
|
777
|
-
TypeError
|
778
|
-
If the input is not a Dynamics instance or a ParamDescriber that creates
|
779
|
-
a Dynamics instance.
|
780
|
-
|
781
|
-
Examples
|
782
|
-
--------
|
783
|
-
>>> import brainstate
|
784
|
-
>>> n1 = brainstate.nn.LIF(10)
|
785
|
-
>>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1
|
786
|
-
"""
|
787
|
-
if isinstance(dyn, Dynamics):
|
788
|
-
self._add_after_update(id(dyn), dyn)
|
789
|
-
return dyn
|
790
|
-
elif isinstance(dyn, ParamDescriber):
|
791
|
-
if not issubclass(dyn.cls, Dynamics):
|
792
|
-
raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
|
793
|
-
if not self._has_after_update(dyn.identifier):
|
794
|
-
self._add_after_update(
|
795
|
-
dyn.identifier,
|
796
|
-
dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
|
797
|
-
)
|
798
|
-
return self._get_after_update(dyn.identifier)
|
799
|
-
else:
|
800
|
-
raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
|
801
|
-
|
802
|
-
def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
|
803
|
-
"""
|
804
|
-
Create a reference to a delayed state or variable in the module.
|
805
|
-
|
806
|
-
This method simplifies the process of accessing a delayed version of a state or variable
|
807
|
-
within the module. It first creates a prefetch reference to the specified state,
|
808
|
-
then specifies the delay time for accessing this state.
|
809
|
-
|
810
|
-
Args:
|
811
|
-
state (str): The name of the state or variable to reference.
|
812
|
-
delay_time (ArrayLike): The amount of time to delay the variable access,
|
813
|
-
typically in time units (e.g., milliseconds).
|
814
|
-
init (Callable, optional): An optional initialization function to provide
|
815
|
-
a default value if the delayed state is not yet available.
|
816
|
-
|
817
|
-
Returns:
|
818
|
-
PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
|
819
|
-
"""
|
820
|
-
return PrefetchDelayAt(self, state, delay_time, init=init)
|
821
|
-
|
822
|
-
def output_delay(self, *delay_time) -> 'OutputDelayAt':
|
823
|
-
"""
|
824
|
-
Create a reference to the delayed output of the module.
|
825
|
-
|
826
|
-
This method simplifies the process of accessing a delayed version of the module's output.
|
827
|
-
It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
|
828
|
-
at the specified delay time.
|
829
|
-
|
830
|
-
Args:
|
831
|
-
delay (Optional[ArrayLike]): The amount of time to delay the output access,
|
832
|
-
typically in time units (e.g., milliseconds). Defaults to None.
|
833
|
-
|
834
|
-
Returns:
|
835
|
-
OutputDelayAt: An object that provides access to the module's output at the specified delay time.
|
836
|
-
"""
|
837
|
-
return OutputDelayAt(self, delay_time)
|
838
|
-
|
839
455
|
|
840
456
|
class Prefetch(Node):
|
841
457
|
"""
|
@@ -1047,14 +663,14 @@ class PrefetchDelayAt(Node):
|
|
1047
663
|
self.delay_time = delay_time
|
1048
664
|
if len(delay_time) > 0:
|
1049
665
|
key = _get_prefetch_delay_key(item)
|
1050
|
-
if not module.
|
1051
|
-
module.
|
666
|
+
if not module.has_after_update(key):
|
667
|
+
module.add_after_update(
|
1052
668
|
key,
|
1053
669
|
not_receive_update_output(
|
1054
670
|
StateWithDelay(module, item, init=init)
|
1055
671
|
)
|
1056
672
|
)
|
1057
|
-
self.state_delay: StateWithDelay = module.
|
673
|
+
self.state_delay: StateWithDelay = module.get_after_update(key)
|
1058
674
|
self.delay_info = self.state_delay.register_delay(*delay_time)
|
1059
675
|
|
1060
676
|
def __call__(self, *args, **kwargs):
|
@@ -1108,10 +724,10 @@ class OutputDelayAt(Node):
|
|
1108
724
|
assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
|
1109
725
|
self.module = module
|
1110
726
|
key = _get_output_delay_key()
|
1111
|
-
if not module.
|
727
|
+
if not module.has_after_update(key):
|
1112
728
|
delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
|
1113
|
-
module.
|
1114
|
-
self.out_delay: Delay = module.
|
729
|
+
module.add_after_update(key, receive_update_output(delay))
|
730
|
+
self.out_delay: Delay = module.get_after_update(key)
|
1115
731
|
self.delay_info = self.out_delay.register_delay(*delay_time)
|
1116
732
|
|
1117
733
|
def __call__(self, *args, **kwargs):
|
@@ -1138,7 +754,7 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
|
|
1138
754
|
f'The target module should be an instance '
|
1139
755
|
f'of Dynamics. But got {target.module}.'
|
1140
756
|
)
|
1141
|
-
delay = target.module.
|
757
|
+
delay = target.module.get_after_update(_get_prefetch_delay_key(target.item))
|
1142
758
|
if not isinstance(delay, StateWithDelay):
|
1143
759
|
raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
|
1144
760
|
f'its delay. But got {delay}.')
|
@@ -1187,9 +803,6 @@ def maybe_init_prefetch(target, *args, **kwargs):
|
|
1187
803
|
# delay.register_delay(*target.delay_time)
|
1188
804
|
|
1189
805
|
|
1190
|
-
DynamicsGroup = Module
|
1191
|
-
|
1192
|
-
|
1193
806
|
def receive_update_output(cls: object):
|
1194
807
|
"""
|
1195
808
|
The decorator to mark the object (as the after updates) to receive the output of the update function.
|
@@ -1255,13 +868,3 @@ def not_receive_update_input(cls: object):
|
|
1255
868
|
if hasattr(cls, '_receive_update_input'):
|
1256
869
|
delattr(cls, '_receive_update_input')
|
1257
870
|
return cls
|
1258
|
-
|
1259
|
-
|
1260
|
-
def _input_label_start(label: str):
|
1261
|
-
# unify the input label repr.
|
1262
|
-
return f'{label} // '
|
1263
|
-
|
1264
|
-
|
1265
|
-
def _input_label_repr(name: str, label: Optional[str] = None):
|
1266
|
-
# unify the input label repr.
|
1267
|
-
return name if label is None else (_input_label_start(label) + str(name))
|