brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250422__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_compatible_import.py +15 -0
- brainstate/_state.py +5 -4
- brainstate/_state_test.py +2 -1
- brainstate/augment/_autograd_test.py +3 -2
- brainstate/augment/_eval_shape.py +2 -1
- brainstate/augment/_mapping.py +0 -1
- brainstate/augment/_mapping_test.py +1 -0
- brainstate/compile/_ad_checkpoint.py +2 -1
- brainstate/compile/_conditions.py +3 -3
- brainstate/compile/_conditions_test.py +2 -1
- brainstate/compile/_error_if.py +2 -1
- brainstate/compile/_error_if_test.py +2 -1
- brainstate/compile/_jit.py +3 -2
- brainstate/compile/_jit_test.py +2 -1
- brainstate/compile/_loop_collect_return.py +2 -2
- brainstate/compile/_loop_collect_return_test.py +2 -1
- brainstate/compile/_loop_no_collection.py +1 -1
- brainstate/compile/_make_jaxpr.py +2 -2
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +2 -1
- brainstate/compile/_unvmap.py +1 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +2 -1
- brainstate/functional/_activations.py +2 -1
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +2 -1
- brainstate/functional/_others.py +2 -1
- brainstate/graph/_graph_operation.py +3 -2
- brainstate/graph/_graph_operation_test.py +4 -3
- brainstate/init/_base.py +2 -1
- brainstate/init/_generic.py +2 -1
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +1 -0
- brainstate/nn/_collective_ops_test.py +0 -4
- brainstate/nn/_common.py +0 -1
- brainstate/nn/_dyn_impl/__init__.py +0 -4
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
- brainstate/nn/_dyn_impl/_inputs.py +236 -29
- brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
- brainstate/nn/_dyn_impl/_readout.py +91 -8
- brainstate/nn/_dyn_impl/_readout_test.py +2 -1
- brainstate/nn/_dynamics/_dynamics_base.py +676 -96
- brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
- brainstate/nn/_dynamics/_projection_base.py +29 -30
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +2 -1
- brainstate/nn/_elementwise/_dropout.py +3 -2
- brainstate/nn/_elementwise/_dropout_test.py +2 -1
- brainstate/nn/_elementwise/_elementwise.py +2 -1
- brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
- brainstate/nn/_event/_fixedprob_mv.py +169 -0
- brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
- brainstate/nn/_event/_linear_mv.py +85 -0
- brainstate/nn/_event/_linear_mv_test.py +121 -0
- brainstate/nn/_exp_euler.py +2 -1
- brainstate/nn/_exp_euler_test.py +2 -1
- brainstate/nn/_interaction/_conv.py +2 -1
- brainstate/nn/_interaction/_linear.py +2 -1
- brainstate/nn/_interaction/_linear_test.py +2 -1
- brainstate/nn/_interaction/_normalizations.py +3 -2
- brainstate/nn/_interaction/_poolings.py +4 -3
- brainstate/nn/_module_test.py +2 -1
- brainstate/nn/metrics.py +4 -3
- brainstate/optim/_lr_scheduler.py +2 -1
- brainstate/optim/_lr_scheduler_test.py +2 -1
- brainstate/optim/_optax_optimizer_test.py +2 -1
- brainstate/optim/_sgd_optimizer.py +3 -2
- brainstate/random/_rand_funs.py +2 -1
- brainstate/random/_rand_funs_test.py +3 -2
- brainstate/random/_rand_seed.py +3 -2
- brainstate/random/_rand_seed_test.py +2 -1
- brainstate/random/_rand_state.py +4 -3
- brainstate/surrogate.py +1 -2
- brainstate/typing.py +4 -4
- brainstate/util/_caller.py +2 -1
- brainstate/util/_others.py +4 -4
- brainstate/util/_pretty_pytree.py +1 -1
- brainstate/util/_pretty_pytree_test.py +2 -1
- brainstate/util/_pretty_table.py +43 -43
- brainstate/util/_struct.py +2 -1
- brainstate/util/filter.py +0 -1
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/METADATA +3 -3
- brainstate-0.1.0.post20250422.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/top_level.txt +0 -0
@@ -57,9 +57,38 @@ _max_order = 10
|
|
57
57
|
|
58
58
|
class Projection(Module):
|
59
59
|
"""
|
60
|
-
Base class
|
60
|
+
Base class for synaptic projection modules in neural network modeling.
|
61
|
+
|
62
|
+
This class defines the interface for modules that handle projections between
|
63
|
+
neural populations. Projections process input signals and transform them
|
64
|
+
before they reach the target neurons, implementing the connectivity patterns
|
65
|
+
in neural networks.
|
66
|
+
|
67
|
+
In the BrainState execution order, Projection modules are updated before
|
68
|
+
Dynamics modules, following the natural information flow in neural systems:
|
69
|
+
1. Projections process inputs (synaptic transmission)
|
70
|
+
2. Dynamics update neuron states (neural integration)
|
71
|
+
|
72
|
+
The Projection class does not implement the update logic directly but delegates
|
73
|
+
to its child nodes. If no child nodes exist, it raises a ValueError.
|
74
|
+
|
75
|
+
Parameters
|
76
|
+
----------
|
77
|
+
*args : Any
|
78
|
+
Arguments passed to the parent Module class.
|
79
|
+
**kwargs : Any
|
80
|
+
Keyword arguments passed to the parent Module class.
|
81
|
+
|
82
|
+
Raises
|
83
|
+
------
|
84
|
+
ValueError
|
85
|
+
If the update() method is called but no child nodes are defined.
|
86
|
+
|
87
|
+
Notes
|
88
|
+
-----
|
89
|
+
Derived classes should implement specific projection behaviors, such as
|
90
|
+
dense connectivity, sparse connectivity, or specific weight update rules.
|
61
91
|
"""
|
62
|
-
|
63
92
|
__module__ = 'brainstate.nn'
|
64
93
|
|
65
94
|
def update(self, *args, **kwargs):
|
@@ -73,24 +102,48 @@ class Projection(Module):
|
|
73
102
|
|
74
103
|
class Dynamics(Module):
|
75
104
|
"""
|
76
|
-
Base class
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
105
|
+
Base class for implementing neural dynamics models in BrainState.
|
106
|
+
|
107
|
+
Dynamics classes represent the core computational units in neural simulations,
|
108
|
+
implementing the differential equations or update rules that govern neural activity.
|
109
|
+
This class provides infrastructure for managing neural populations, handling inputs,
|
110
|
+
and coordinating updates within the simulation framework.
|
111
|
+
|
112
|
+
The Dynamics class serves several key purposes:
|
113
|
+
1. Managing neuron population geometry and size information
|
114
|
+
2. Handling current and delta (instantaneous change) inputs to neurons
|
115
|
+
3. Supporting before/after update hooks for computational dependencies
|
116
|
+
4. Providing access to delayed state variables through the prefetch mechanism
|
117
|
+
5. Establishing the execution order in neural network simulations
|
118
|
+
|
119
|
+
Parameters
|
120
|
+
----------
|
121
|
+
in_size : Size
|
122
|
+
The geometry of the neuron population. Can be an integer (e.g., 10) for
|
123
|
+
1D neuron arrays, or a tuple (e.g., (10, 10)) for multi-dimensional populations.
|
124
|
+
name : Optional[str], default=None
|
125
|
+
Optional name identifier for this dynamics module.
|
126
|
+
|
127
|
+
Attributes
|
128
|
+
----------
|
129
|
+
in_size : tuple
|
130
|
+
The shape/geometry of the neuron population.
|
131
|
+
out_size : tuple
|
132
|
+
The output shape, typically matches in_size.
|
133
|
+
current_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
|
134
|
+
Dictionary of registered current input functions or arrays.
|
135
|
+
delta_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
|
136
|
+
Dictionary of registered delta input functions or arrays.
|
137
|
+
before_updates : Optional[Dict[Hashable, Callable]]
|
138
|
+
Dictionary of functions to call before the main update.
|
139
|
+
after_updates : Optional[Dict[Hashable, Callable]]
|
140
|
+
Dictionary of functions to call after the main update.
|
141
|
+
|
142
|
+
Notes
|
143
|
+
-----
|
144
|
+
In the BrainState execution sequence, Dynamics modules are updated after
|
145
|
+
Projection modules and before other module types, reflecting the natural
|
146
|
+
flow of information in neural systems.
|
94
147
|
|
95
148
|
There are several essential attributes:
|
96
149
|
|
@@ -100,9 +153,12 @@ class Dynamics(Module):
|
|
100
153
|
- ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
|
101
154
|
`num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
|
102
155
|
|
103
|
-
|
104
|
-
|
105
|
-
|
156
|
+
|
157
|
+
See Also
|
158
|
+
--------
|
159
|
+
Module : Parent class providing base module functionality
|
160
|
+
Projection : Class for handling synaptic projections between neural populations
|
161
|
+
DynamicsGroup : Container for organizing multiple dynamics modules
|
106
162
|
"""
|
107
163
|
|
108
164
|
__module__ = 'brainstate.nn'
|
@@ -158,26 +214,72 @@ class Dynamics(Module):
|
|
158
214
|
self.out_size = self.in_size
|
159
215
|
|
160
216
|
def __pretty_repr_item__(self, name, value):
|
161
|
-
if name in [
|
162
|
-
|
217
|
+
if name in [
|
218
|
+
'_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
|
219
|
+
'_in_size', '_out_size', '_name', '_mode',
|
220
|
+
]:
|
221
|
+
return (name, value) if value is None else (name[1:], value) # skip the first `_`
|
163
222
|
return super().__pretty_repr_item__(name, value)
|
164
223
|
|
165
224
|
@property
|
166
225
|
def varshape(self):
|
167
|
-
"""
|
226
|
+
"""
|
227
|
+
Get the shape of variables in the neuron group.
|
228
|
+
|
229
|
+
This property provides access to the geometry (shape) of the neuron population,
|
230
|
+
which determines how variables and states are structured.
|
231
|
+
|
232
|
+
Returns
|
233
|
+
-------
|
234
|
+
tuple
|
235
|
+
A tuple representing the dimensional shape of the neuron group,
|
236
|
+
matching the in_size parameter provided during initialization.
|
237
|
+
|
238
|
+
See Also
|
239
|
+
--------
|
240
|
+
in_size : The input geometry specification for the neuron group
|
241
|
+
"""
|
168
242
|
return self.in_size
|
169
243
|
|
170
244
|
@property
|
171
245
|
def current_inputs(self):
|
172
246
|
"""
|
173
|
-
|
247
|
+
Get the dictionary of current inputs registered with this dynamics model.
|
248
|
+
|
249
|
+
Current inputs represent direct input currents that flow into the model.
|
250
|
+
|
251
|
+
Returns
|
252
|
+
-------
|
253
|
+
dict or None
|
254
|
+
A dictionary mapping keys to current input functions or values,
|
255
|
+
or None if no current inputs have been registered.
|
256
|
+
|
257
|
+
See Also
|
258
|
+
--------
|
259
|
+
add_current_input : Register a new current input
|
260
|
+
sum_current_inputs : Apply and sum all current inputs
|
261
|
+
delta_inputs : Dictionary of instantaneous change inputs
|
174
262
|
"""
|
175
263
|
return self._current_inputs
|
176
264
|
|
177
265
|
@property
|
178
266
|
def delta_inputs(self):
|
179
267
|
"""
|
180
|
-
|
268
|
+
Get the dictionary of delta inputs registered with this dynamics model.
|
269
|
+
|
270
|
+
Delta inputs represent instantaneous changes (increments) applied directly to state variables, rather than rates of change (dX/dt).
|
271
|
+
|
272
|
+
Returns
|
273
|
+
-------
|
274
|
+
dict or None
|
275
|
+
A dictionary mapping keys to delta input functions or values,
|
276
|
+
or None if no delta inputs have been registered.
|
277
|
+
|
278
|
+
See Also
|
279
|
+
--------
|
280
|
+
add_delta_input : Register a new delta input
|
281
|
+
sum_delta_inputs : Apply and sum all delta inputs
|
282
|
+
current_inputs : Dictionary of direct current inputs
|
181
283
|
"""
|
182
284
|
return self._delta_inputs
|
183
285
|
|
@@ -188,12 +290,40 @@ class Dynamics(Module):
|
|
188
290
|
label: Optional[str] = None
|
189
291
|
):
|
190
292
|
"""
|
191
|
-
Add a current input function.
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
293
|
+
Add a current input function or array to the dynamics model.
|
294
|
+
|
295
|
+
Current inputs represent direct input currents that can be accessed during
|
296
|
+
model updates through the `sum_current_inputs()` method.
|
297
|
+
|
298
|
+
Parameters
|
299
|
+
----------
|
300
|
+
key : str
|
301
|
+
Unique identifier for this current input. Used to retrieve or reference
|
302
|
+
the input later.
|
303
|
+
inp : Union[Callable, ArrayLike]
|
304
|
+
The input data or function that generates input data.
|
305
|
+
- If callable: Will be called during updates with arguments passed to `sum_current_inputs()`
|
306
|
+
- If array-like: Will be applied once and then automatically removed from available inputs
|
307
|
+
label : Optional[str], default=None
|
308
|
+
Optional grouping label for the input. When provided, allows selective
|
309
|
+
processing of inputs by label in `sum_current_inputs()`.
|
310
|
+
|
311
|
+
Raises
|
312
|
+
------
|
313
|
+
ValueError
|
314
|
+
If the key has already been used for a different current input.
|
315
|
+
|
316
|
+
Notes
|
317
|
+
-----
|
318
|
+
- Inputs with the same label can be processed together using the `label`
|
319
|
+
parameter in `sum_current_inputs()`.
|
320
|
+
- Non-callable inputs are consumed when used (removed after first use).
|
321
|
+
- Callable inputs persist and can be called repeatedly.
|
322
|
+
|
323
|
+
See Also
|
324
|
+
--------
|
325
|
+
sum_current_inputs : Sum all current inputs matching a given label
|
326
|
+
add_delta_input : Add a delta input function or array
|
197
327
|
"""
|
198
328
|
key = _input_label_repr(key, label)
|
199
329
|
if self._current_inputs is None:
|
@@ -210,12 +340,41 @@ class Dynamics(Module):
|
|
210
340
|
label: Optional[str] = None
|
211
341
|
):
|
212
342
|
"""
|
213
|
-
Add a delta input function.
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
343
|
+
Add a delta input function or array to the dynamics model.
|
344
|
+
|
345
|
+
Delta inputs represent instantaneous changes to the model state (i.e., increments added directly to a state value rather than dX/dt contributions).
|
346
|
+
This method registers a function or array that provides delta inputs which will be
|
347
|
+
accessible during model updates through the `sum_delta_inputs()` method.
|
348
|
+
|
349
|
+
Parameters
|
350
|
+
----------
|
351
|
+
key : str
|
352
|
+
Unique identifier for this delta input. Used to retrieve or reference
|
353
|
+
the input later.
|
354
|
+
inp : Union[Callable, ArrayLike]
|
355
|
+
The input data or function that generates input data.
|
356
|
+
- If callable: Will be called during updates with arguments passed to `sum_delta_inputs()`
|
357
|
+
- If array-like: Will be applied once and then automatically removed from available inputs
|
358
|
+
label : Optional[str], default=None
|
359
|
+
Optional grouping label for the input. When provided, allows selective
|
360
|
+
processing of inputs by label in `sum_delta_inputs()`.
|
361
|
+
|
362
|
+
Raises
|
363
|
+
------
|
364
|
+
ValueError
|
365
|
+
If the key has already been used for a different delta input.
|
366
|
+
|
367
|
+
Notes
|
368
|
+
-----
|
369
|
+
- Inputs with the same label can be processed together using the `label`
|
370
|
+
parameter in `sum_delta_inputs()`.
|
371
|
+
- Non-callable inputs are consumed when used (removed after first use).
|
372
|
+
- Callable inputs persist and can be called repeatedly.
|
373
|
+
|
374
|
+
See Also
|
375
|
+
--------
|
376
|
+
sum_delta_inputs : Sum all delta inputs matching a given label
|
377
|
+
add_current_input : Add a current input function or array
|
219
378
|
"""
|
220
379
|
key = _input_label_repr(key, label)
|
221
380
|
if self._delta_inputs is None:
|
@@ -226,13 +385,40 @@ class Dynamics(Module):
|
|
226
385
|
self._delta_inputs[key] = inp
|
227
386
|
|
228
387
|
def get_input(self, key: str):
|
229
|
-
"""
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
388
|
+
"""
|
389
|
+
Get a registered input function by its key.
|
390
|
+
|
391
|
+
Retrieves either a current input or a delta input function that was previously
|
392
|
+
registered with the given key. This method checks both current_inputs and
|
393
|
+
delta_inputs dictionaries for the specified key.
|
394
|
+
|
395
|
+
Parameters
|
396
|
+
----------
|
397
|
+
key : str
|
398
|
+
The unique identifier used when the input function was registered.
|
399
|
+
|
400
|
+
Returns
|
401
|
+
-------
|
402
|
+
Callable or ArrayLike
|
403
|
+
The input function or array associated with the given key.
|
404
|
+
|
405
|
+
Raises
|
406
|
+
------
|
407
|
+
ValueError
|
408
|
+
If no input function is found with the specified key in either
|
409
|
+
current_inputs or delta_inputs.
|
410
|
+
|
411
|
+
See Also
|
412
|
+
--------
|
413
|
+
add_current_input : Register a current input function
|
414
|
+
add_delta_input : Register a delta input function
|
415
|
+
|
416
|
+
Examples
|
417
|
+
--------
|
418
|
+
>>> model = Dynamics(10)
|
419
|
+
>>> model.add_current_input('stimulus', lambda t: np.sin(t))
|
420
|
+
>>> input_func = model.get_input('stimulus')
|
421
|
+
>>> input_func(0.5) # Returns sin(0.5)
|
236
422
|
"""
|
237
423
|
if self._current_inputs is not None and key in self._current_inputs:
|
238
424
|
return self._current_inputs[key]
|
@@ -249,16 +435,36 @@ class Dynamics(Module):
|
|
249
435
|
**kwargs
|
250
436
|
):
|
251
437
|
"""
|
252
|
-
Summarize all current inputs by
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
438
|
+
Sum all current inputs by applying every registered current input function and adding the results.
|
439
|
+
|
440
|
+
This method iterates through all registered current input functions (from `.current_inputs`)
|
441
|
+
and applies them to calculate the total input current for the dynamics model. It adds all results
|
442
|
+
to the initial value provided.
|
443
|
+
|
444
|
+
Parameters
|
445
|
+
----------
|
446
|
+
init : Any
|
447
|
+
The initial value to which all current inputs will be added.
|
448
|
+
*args : tuple
|
449
|
+
Variable length argument list passed to each current input function.
|
450
|
+
label : Optional[str], default=None
|
451
|
+
If provided, only process current inputs with this label prefix.
|
452
|
+
When None, process all current inputs regardless of label.
|
453
|
+
**kwargs : dict
|
454
|
+
Arbitrary keyword arguments passed to each current input function.
|
455
|
+
|
456
|
+
Returns
|
457
|
+
-------
|
458
|
+
Any
|
459
|
+
The initial value plus all applicable current inputs summed together.
|
460
|
+
|
461
|
+
Notes
|
462
|
+
-----
|
463
|
+
- Non-callable current inputs are applied once and then automatically removed from
|
464
|
+
the current_inputs dictionary.
|
465
|
+
- Callable current inputs remain registered for subsequent calls.
|
466
|
+
- When a label is provided, only current inputs with keys starting with that label
|
467
|
+
are applied.
|
262
468
|
"""
|
263
469
|
if self._current_inputs is None:
|
264
470
|
return init
|
@@ -288,16 +494,36 @@ class Dynamics(Module):
|
|
288
494
|
**kwargs
|
289
495
|
):
|
290
496
|
"""
|
291
|
-
Summarize all delta inputs by
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
497
|
+
Sum all delta inputs by applying every registered delta input function and adding the results.
|
498
|
+
|
499
|
+
This method iterates through all registered delta input functions (from `.delta_inputs`)
|
500
|
+
and applies them to calculate instantaneous changes to model states. It adds all results
|
501
|
+
to the initial value provided.
|
502
|
+
|
503
|
+
Parameters
|
504
|
+
----------
|
505
|
+
init : Any
|
506
|
+
The initial value to which all delta inputs will be added.
|
507
|
+
*args : tuple
|
508
|
+
Variable length argument list passed to each delta input function.
|
509
|
+
label : Optional[str], default=None
|
510
|
+
If provided, only process delta inputs with this label prefix.
|
511
|
+
When None, process all delta inputs regardless of label.
|
512
|
+
**kwargs : dict
|
513
|
+
Arbitrary keyword arguments passed to each delta input function.
|
514
|
+
|
515
|
+
Returns
|
516
|
+
-------
|
517
|
+
Any
|
518
|
+
The initial value plus all applicable delta inputs summed together.
|
519
|
+
|
520
|
+
Notes
|
521
|
+
-----
|
522
|
+
- Non-callable delta inputs are applied once and then automatically removed from
|
523
|
+
the delta_inputs dictionary.
|
524
|
+
- Callable delta inputs remain registered for subsequent calls.
|
525
|
+
- When a label is provided, only delta inputs with keys starting with that label
|
526
|
+
are applied.
|
301
527
|
"""
|
302
528
|
if self._delta_inputs is None:
|
303
529
|
return init
|
@@ -322,20 +548,59 @@ class Dynamics(Module):
|
|
322
548
|
@property
|
323
549
|
def before_updates(self):
|
324
550
|
"""
|
325
|
-
|
551
|
+
Get the dictionary of functions to execute before the module's update.
|
552
|
+
|
553
|
+
Returns
|
554
|
+
-------
|
555
|
+
dict or None
|
556
|
+
Dictionary mapping keys to callable functions that will be executed
|
557
|
+
before the main update, or None if no before updates are registered.
|
558
|
+
|
559
|
+
Notes
|
560
|
+
-----
|
561
|
+
Before updates are executed in the order they were registered whenever
|
562
|
+
the module is called via __call__.
|
326
563
|
"""
|
327
564
|
return self._before_updates
|
328
565
|
|
329
566
|
@property
|
330
567
|
def after_updates(self):
|
331
568
|
"""
|
332
|
-
|
569
|
+
Get the dictionary of functions to execute after the module's update.
|
570
|
+
|
571
|
+
Returns
|
572
|
+
-------
|
573
|
+
dict or None
|
574
|
+
Dictionary mapping keys to callable functions that will be executed
|
575
|
+
after the main update, or None if no after updates are registered.
|
576
|
+
|
577
|
+
Notes
|
578
|
+
-----
|
579
|
+
After updates are executed in the order they were registered whenever
|
580
|
+
the module is called via __call__, and may optionally receive the return
|
581
|
+
value from the update method.
|
333
582
|
"""
|
334
583
|
return self._after_updates
|
335
584
|
|
336
585
|
def _add_before_update(self, key: Any, fun: Callable):
|
337
586
|
"""
|
338
|
-
|
587
|
+
Register a function to be executed before the module's update.
|
588
|
+
|
589
|
+
Parameters
|
590
|
+
----------
|
591
|
+
key : Any
|
592
|
+
A unique identifier for the update function.
|
593
|
+
fun : Callable
|
594
|
+
The function to execute before the module's update.
|
595
|
+
|
596
|
+
Raises
|
597
|
+
------
|
598
|
+
KeyError
|
599
|
+
If the key is already registered in before_updates.
|
600
|
+
|
601
|
+
Notes
|
602
|
+
-----
|
603
|
+
Internal method used by the module system to register dependencies.
|
339
604
|
"""
|
340
605
|
if self._before_updates is None:
|
341
606
|
self._before_updates = dict()
|
@@ -344,7 +609,25 @@ class Dynamics(Module):
|
|
344
609
|
self.before_updates[key] = fun
|
345
610
|
|
346
611
|
def _add_after_update(self, key: Any, fun: Callable):
|
347
|
-
"""
|
612
|
+
"""
|
613
|
+
Register a function to be executed after the module's update.
|
614
|
+
|
615
|
+
Parameters
|
616
|
+
----------
|
617
|
+
key : Any
|
618
|
+
A unique identifier for the update function.
|
619
|
+
fun : Callable
|
620
|
+
The function to execute after the module's update.
|
621
|
+
|
622
|
+
Raises
|
623
|
+
------
|
624
|
+
KeyError
|
625
|
+
If the key is already registered in after_updates.
|
626
|
+
|
627
|
+
Notes
|
628
|
+
-----
|
629
|
+
Internal method used by the module system to register dependencies.
|
630
|
+
"""
|
348
631
|
if self._after_updates is None:
|
349
632
|
self._after_updates = dict()
|
350
633
|
if key in self.after_updates:
|
@@ -352,7 +635,24 @@ class Dynamics(Module):
|
|
352
635
|
self.after_updates[key] = fun
|
353
636
|
|
354
637
|
def _get_before_update(self, key: Any):
|
355
|
-
"""
|
638
|
+
"""
|
639
|
+
Retrieve a registered before-update function by its key.
|
640
|
+
|
641
|
+
Parameters
|
642
|
+
----------
|
643
|
+
key : Any
|
644
|
+
The identifier of the before-update function to retrieve.
|
645
|
+
|
646
|
+
Returns
|
647
|
+
-------
|
648
|
+
Callable
|
649
|
+
The registered before-update function.
|
650
|
+
|
651
|
+
Raises
|
652
|
+
------
|
653
|
+
KeyError
|
654
|
+
If the key is not registered in before_updates or if before_updates is None.
|
655
|
+
"""
|
356
656
|
if self._before_updates is None:
|
357
657
|
raise KeyError(f'{key} is not registered in before_updates of {self}')
|
358
658
|
if key not in self.before_updates:
|
@@ -360,7 +660,24 @@ class Dynamics(Module):
|
|
360
660
|
return self.before_updates.get(key)
|
361
661
|
|
362
662
|
def _get_after_update(self, key: Any):
|
363
|
-
"""
|
663
|
+
"""
|
664
|
+
Retrieve a registered after-update function by its key.
|
665
|
+
|
666
|
+
Parameters
|
667
|
+
----------
|
668
|
+
key : Any
|
669
|
+
The identifier of the after-update function to retrieve.
|
670
|
+
|
671
|
+
Returns
|
672
|
+
-------
|
673
|
+
Callable
|
674
|
+
The registered after-update function.
|
675
|
+
|
676
|
+
Raises
|
677
|
+
------
|
678
|
+
KeyError
|
679
|
+
If the key is not registered in after_updates or if after_updates is None.
|
680
|
+
"""
|
364
681
|
if self._after_updates is None:
|
365
682
|
raise KeyError(f'{key} is not registered in after_updates of {self}')
|
366
683
|
if key not in self.after_updates:
|
@@ -368,13 +685,37 @@ class Dynamics(Module):
|
|
368
685
|
return self.after_updates.get(key)
|
369
686
|
|
370
687
|
def _has_before_update(self, key: Any):
|
371
|
-
"""
|
688
|
+
"""
|
689
|
+
Check if a before-update function is registered with the given key.
|
690
|
+
|
691
|
+
Parameters
|
692
|
+
----------
|
693
|
+
key : Any
|
694
|
+
The identifier to check for in the before_updates dictionary.
|
695
|
+
|
696
|
+
Returns
|
697
|
+
-------
|
698
|
+
bool
|
699
|
+
True if the key is registered in before_updates, False otherwise.
|
700
|
+
"""
|
372
701
|
if self._before_updates is None:
|
373
702
|
return False
|
374
703
|
return key in self.before_updates
|
375
704
|
|
376
705
|
def _has_after_update(self, key: Any):
|
377
|
-
"""
|
706
|
+
"""
|
707
|
+
Check if an after-update function is registered with the given key.
|
708
|
+
|
709
|
+
Parameters
|
710
|
+
----------
|
711
|
+
key : Any
|
712
|
+
The identifier to check for in the after_updates dictionary.
|
713
|
+
|
714
|
+
Returns
|
715
|
+
-------
|
716
|
+
bool
|
717
|
+
True if the key is registered in after_updates, False otherwise.
|
718
|
+
"""
|
378
719
|
if self._after_updates is None:
|
379
720
|
return False
|
380
721
|
return key in self.after_updates
|
@@ -405,19 +746,75 @@ class Dynamics(Module):
|
|
405
746
|
return ret
|
406
747
|
|
407
748
|
def prefetch(self, item: str) -> 'Prefetch':
|
749
|
+
"""
|
750
|
+
Create a reference to a state or variable that may not be initialized yet.
|
751
|
+
|
752
|
+
This method allows accessing module attributes or states before they are
|
753
|
+
fully defined, acting as a placeholder that will be resolved when called.
|
754
|
+
Particularly useful for creating references to variables that will be defined
|
755
|
+
during initialization or runtime.
|
756
|
+
|
757
|
+
Parameters
|
758
|
+
----------
|
759
|
+
item : str
|
760
|
+
The name of the attribute or state to reference.
|
761
|
+
|
762
|
+
Returns
|
763
|
+
-------
|
764
|
+
Prefetch
|
765
|
+
A Prefetch object that provides access to the referenced item.
|
766
|
+
|
767
|
+
Examples
|
768
|
+
--------
|
769
|
+
>>> import brainstate
|
770
|
+
>>> import brainunit as u
|
771
|
+
>>> neuron = brainstate.nn.LIF(...)
|
772
|
+
>>> v_ref = neuron.prefetch('V') # Reference to voltage
|
773
|
+
>>> v_value = v_ref() # Get current value
|
774
|
+
>>> delayed_v = v_ref.delay.at(5.0 * u.ms) # Get delayed value
|
775
|
+
"""
|
408
776
|
return Prefetch(self, item)
|
409
777
|
|
410
778
|
def align_pre(
|
411
|
-
self,
|
779
|
+
self,
|
780
|
+
dyn: Union[ParamDescriber[T], T]
|
412
781
|
) -> T:
|
413
782
|
"""
|
414
|
-
|
783
|
+
Registers a dynamics module to execute after this module.
|
784
|
+
|
785
|
+
This method establishes a sequential execution relationship where the specified
|
786
|
+
dynamics module will be called after this module completes its update. This
|
787
|
+
creates a feed-forward connection in the computational graph.
|
788
|
+
|
789
|
+
Parameters
|
790
|
+
----------
|
791
|
+
dyn : Union[ParamDescriber[T], T]
|
792
|
+
The dynamics module to be executed after this module. Can be either:
|
793
|
+
- An instance of Dynamics
|
794
|
+
- A ParamDescriber that can instantiate a Dynamics object
|
795
|
+
|
796
|
+
Returns
|
797
|
+
-------
|
798
|
+
T
|
799
|
+
The dynamics module that was registered, allowing for method chaining.
|
800
|
+
|
801
|
+
Raises
|
802
|
+
------
|
803
|
+
TypeError
|
804
|
+
If the input is not a Dynamics instance or a ParamDescriber that creates
|
805
|
+
a Dynamics instance.
|
806
|
+
|
807
|
+
Examples
|
808
|
+
--------
|
809
|
+
>>> import brainstate
|
810
|
+
>>> n1 = brainstate.nn.LIF(10)
|
811
|
+
>>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # the Expon module will run after n1
|
415
812
|
"""
|
416
813
|
if isinstance(dyn, Dynamics):
|
417
814
|
self._add_after_update(dyn.name, dyn)
|
418
815
|
return dyn
|
419
816
|
elif isinstance(dyn, ParamDescriber):
|
420
|
-
if not
|
817
|
+
if not issubclass(dyn.cls, Dynamics):
|
421
818
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
|
422
819
|
if not self._has_after_update(dyn.identifier):
|
423
820
|
self._add_after_update(dyn.identifier, dyn())
|
@@ -425,62 +822,206 @@ class Dynamics(Module):
|
|
425
822
|
else:
|
426
823
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
|
427
824
|
|
428
|
-
def __pretty_repr_item__(self, name, value):
|
429
|
-
if name in ['_in_size', '_out_size', '_name', '_mode',
|
430
|
-
'_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
|
431
|
-
return (name, value) if value is None else (name[1:], value) # skip the first `_`
|
432
|
-
return name, value
|
433
|
-
|
434
825
|
|
435
826
|
class Prefetch(Node):
|
436
827
|
"""
|
437
|
-
Prefetch a variable
|
828
|
+
Prefetch a state or variable in a module before it is initialized.
|
829
|
+
|
830
|
+
|
831
|
+
This class provides a mechanism to reference a module's state or attribute
|
832
|
+
that may not have been initialized yet. It acts as a placeholder or reference
|
833
|
+
that will be resolved when called.
|
834
|
+
|
835
|
+
Use cases:
|
836
|
+
- Access variables within dynamics modules that will be defined later
|
837
|
+
- Create references to states across module boundaries
|
838
|
+
- Enable access to delayed states through the `.delay` property
|
839
|
+
|
840
|
+
Parameters
|
841
|
+
----------
|
842
|
+
module : Dynamics
|
843
|
+
The module that contains or will contain the referenced item.
|
844
|
+
item : str
|
845
|
+
The attribute name of the state or variable to prefetch.
|
846
|
+
|
847
|
+
Examples
|
848
|
+
--------
|
849
|
+
>>> import brainstate
|
850
|
+
>>> import brainunit as u
|
851
|
+
>>> neuron = brainstate.nn.LIF(...)
|
852
|
+
>>> v_reference = neuron.prefetch('V') # Reference to voltage before initialization
|
853
|
+
>>> v_value = v_reference() # Get the current value
|
854
|
+
>>> delay_ref = v_reference.delay.at(5.0 * u.ms) # Reference voltage delayed by 5ms
|
855
|
+
|
856
|
+
Notes
|
857
|
+
-----
|
858
|
+
When called, this class retrieves the current value of the referenced item.
|
859
|
+
Use the `.delay` property to access delayed versions of the state.
|
860
|
+
|
438
861
|
"""
|
439
862
|
|
440
|
-
def __init__(self, module:
|
863
|
+
def __init__(self, module: Dynamics, item: str):
|
864
|
+
"""
|
865
|
+
Initialize a Prefetch object.
|
866
|
+
|
867
|
+
Parameters
|
868
|
+
----------
|
869
|
+
module : Dynamics
|
870
|
+
The module that contains or will contain the referenced item.
|
871
|
+
item : str
|
872
|
+
The attribute name of the state or variable to prefetch.
|
873
|
+
"""
|
441
874
|
super().__init__()
|
442
875
|
self.module = module
|
443
876
|
self.item = item
|
444
877
|
|
445
878
|
@property
|
446
879
|
def delay(self):
|
880
|
+
"""
|
881
|
+
Access delayed versions of the prefetched item.
|
882
|
+
|
883
|
+
Returns
|
884
|
+
-------
|
885
|
+
PrefetchDelay
|
886
|
+
An object that provides access to delayed versions of the prefetched item.
|
887
|
+
"""
|
447
888
|
return PrefetchDelay(self.module, self.item)
|
448
889
|
|
449
890
|
def __call__(self, *args, **kwargs):
|
891
|
+
"""
|
892
|
+
Get the current value of the prefetched item.
|
893
|
+
|
894
|
+
Returns
|
895
|
+
-------
|
896
|
+
Any
|
897
|
+
The current value of the referenced item. If the item is a State object,
|
898
|
+
returns its value attribute, otherwise returns the item itself.
|
899
|
+
"""
|
450
900
|
item = _get_prefetch_item(self)
|
451
901
|
return item.value if isinstance(item, State) else item
|
452
902
|
|
453
903
|
def get_item_value(self):
|
904
|
+
"""
|
905
|
+
Get the current value of the prefetched item.
|
906
|
+
|
907
|
+
Equivalent to ``__call__``, but explicitly named for clarity.
|
908
|
+
|
909
|
+
Returns
|
910
|
+
-------
|
911
|
+
Any
|
912
|
+
The current value of the referenced item. If the item is a State object,
|
913
|
+
returns its value attribute, otherwise returns the item itself.
|
914
|
+
"""
|
454
915
|
item = _get_prefetch_item(self)
|
455
916
|
return item.value if isinstance(item, State) else item
|
456
917
|
|
457
918
|
def get_item(self):
|
458
919
|
"""
|
459
|
-
Get
|
920
|
+
Get the referenced item object itself, not its value.
|
921
|
+
|
922
|
+
Returns
|
923
|
+
-------
|
924
|
+
Any
|
925
|
+
The actual referenced item from the module, which could be a State
|
926
|
+
object or any other attribute.
|
460
927
|
"""
|
461
928
|
return _get_prefetch_item(self)
|
462
929
|
|
463
930
|
|
464
931
|
class PrefetchDelay(Node):
|
932
|
+
"""
|
933
|
+
Provides access to delayed versions of a prefetched state or variable.
|
934
|
+
|
935
|
+
This class acts as an intermediary for accessing delayed values of module variables.
|
936
|
+
It doesn't retrieve values directly but provides methods to specify the delay time
|
937
|
+
via the `at()` method.
|
938
|
+
|
939
|
+
Parameters
|
940
|
+
----------
|
941
|
+
module : Dynamics
|
942
|
+
The dynamics module that contains the referenced state or variable.
|
943
|
+
item : str
|
944
|
+
The name of the state or variable to access with delay.
|
945
|
+
|
946
|
+
Examples
|
947
|
+
--------
|
948
|
+
>>> import brainstate
|
949
|
+
>>> import brainunit as u
|
950
|
+
>>> neuron = brainstate.nn.LIF(10)
|
951
|
+
>>> # Access voltage delayed by 5ms
|
952
|
+
>>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
|
953
|
+
>>> delayed_value = delayed_v() # Get the delayed value
|
954
|
+
"""
|
955
|
+
|
465
956
|
def __init__(self, module: Dynamics, item: str):
|
466
957
|
self.module = module
|
467
958
|
self.item = item
|
468
959
|
|
469
960
|
def at(self, time: ArrayLike):
|
961
|
+
"""
|
962
|
+
Specifies the delay time for accessing the variable.
|
963
|
+
|
964
|
+
Parameters
|
965
|
+
----------
|
966
|
+
time : ArrayLike
|
967
|
+
The amount of time to delay the variable access, typically in time units
|
968
|
+
(e.g., milliseconds).
|
969
|
+
|
970
|
+
Returns
|
971
|
+
-------
|
972
|
+
PrefetchDelayAt
|
973
|
+
An object that provides access to the variable at the specified delay time.
|
974
|
+
"""
|
470
975
|
return PrefetchDelayAt(self.module, self.item, time)
|
471
976
|
|
472
977
|
|
473
978
|
class PrefetchDelayAt(Node):
|
474
979
|
"""
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
980
|
+
Provides access to a delayed state or variable value at a specified delay time.
|
981
|
+
|
982
|
+
This class represents the final step in the prefetch delay chain, providing
|
983
|
+
actual access to state values at a specific delay time. It converts the
|
984
|
+
specified time delay into steps and registers the delay with the appropriate
|
985
|
+
StateWithDelay handler.
|
986
|
+
|
987
|
+
Parameters
|
988
|
+
----------
|
989
|
+
module : Dynamics
|
990
|
+
The dynamics module that contains the referenced state or variable.
|
991
|
+
item : str
|
992
|
+
The name of the state or variable to access with delay.
|
993
|
+
time : ArrayLike
|
994
|
+
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
995
|
+
|
996
|
+
Examples
|
997
|
+
--------
|
998
|
+
>>> import brainstate
|
999
|
+
>>> import brainunit as u
|
1000
|
+
>>> neuron = brainstate.nn.LIF(10)
|
1001
|
+
>>> # Create a reference to voltage delayed by 5ms
|
1002
|
+
>>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
|
1003
|
+
>>> # Get the delayed value
|
1004
|
+
>>> v_value = delayed_v()
|
481
1005
|
"""
|
482
1006
|
|
483
|
-
def __init__(
|
1007
|
+
def __init__(
|
1008
|
+
self,
|
1009
|
+
module: Dynamics,
|
1010
|
+
item: str,
|
1011
|
+
time: ArrayLike
|
1012
|
+
):
|
1013
|
+
"""
|
1014
|
+
Initialize a PrefetchDelayAt object.
|
1015
|
+
|
1016
|
+
Parameters
|
1017
|
+
----------
|
1018
|
+
module : Dynamics
|
1019
|
+
The dynamics module that contains the referenced state or variable.
|
1020
|
+
item : str
|
1021
|
+
The name of the state or variable to access with delay.
|
1022
|
+
time : ArrayLike
|
1023
|
+
The amount of time to delay access by, typically in time units.
|
1024
|
+
"""
|
484
1025
|
super().__init__()
|
485
1026
|
assert isinstance(module, Dynamics), ''
|
486
1027
|
self.module = module
|
@@ -496,6 +1037,14 @@ class PrefetchDelayAt(Node):
|
|
496
1037
|
self.state_delay.register_delay(time)
|
497
1038
|
|
498
1039
|
def __call__(self, *args, **kwargs):
|
1040
|
+
"""
|
1041
|
+
Retrieve the value of the state at the specified delay time.
|
1042
|
+
|
1043
|
+
Returns
|
1044
|
+
-------
|
1045
|
+
Any
|
1046
|
+
The value of the state or variable at the specified delay time.
|
1047
|
+
"""
|
499
1048
|
# return self.state_delay.retrieve_at_time(self.time)
|
500
1049
|
return self.state_delay.retrieve_at_step(self.step)
|
501
1050
|
|
@@ -512,8 +1061,10 @@ def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
|
|
512
1061
|
|
513
1062
|
|
514
1063
|
def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
|
515
|
-
assert isinstance(target.module, Dynamics), (
|
516
|
-
|
1064
|
+
assert isinstance(target.module, Dynamics), (
|
1065
|
+
f'The target module should be an instance '
|
1066
|
+
f'of Dynamics. But got {target.module}.'
|
1067
|
+
)
|
517
1068
|
delay = target.module._get_after_update(_get_delay_key(target.item))
|
518
1069
|
if not isinstance(delay, StateWithDelay):
|
519
1070
|
raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
|
@@ -522,6 +1073,35 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
|
|
522
1073
|
|
523
1074
|
|
524
1075
|
def maybe_init_prefetch(target, *args, **kwargs):
|
1076
|
+
"""
|
1077
|
+
Initialize a prefetch target if needed, based on its type.
|
1078
|
+
|
1079
|
+
This function ensures that prefetch references are properly initialized
|
1080
|
+
and ready to use. It handles different types of prefetch objects by
|
1081
|
+
performing the appropriate initialization action:
|
1082
|
+
- For :py:class:`Prefetch` objects: retrieves the referenced item
|
1083
|
+
- For :py:class:`PrefetchDelay` objects: retrieves the delay handler
|
1084
|
+
- For :py:class:`PrefetchDelayAt` objects: registers the specified delay
|
1085
|
+
|
1086
|
+
Parameters
|
1087
|
+
----------
|
1088
|
+
target : Union[Prefetch, PrefetchDelay, PrefetchDelayAt]
|
1089
|
+
The prefetch target to initialize.
|
1090
|
+
*args : Any
|
1091
|
+
Additional positional arguments (unused).
|
1092
|
+
**kwargs : Any
|
1093
|
+
Additional keyword arguments (unused).
|
1094
|
+
|
1095
|
+
Returns
|
1096
|
+
-------
|
1097
|
+
None
|
1098
|
+
This function performs initialization side effects only.
|
1099
|
+
|
1100
|
+
Notes
|
1101
|
+
-----
|
1102
|
+
This function is typically called internally when prefetched references
|
1103
|
+
are used to ensure they are properly set up before access.
|
1104
|
+
"""
|
525
1105
|
if isinstance(target, Prefetch):
|
526
1106
|
_get_prefetch_item(target)
|
527
1107
|
|