brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250422__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. brainstate/_compatible_import.py +15 -0
  2. brainstate/_state.py +5 -4
  3. brainstate/_state_test.py +2 -1
  4. brainstate/augment/_autograd_test.py +3 -2
  5. brainstate/augment/_eval_shape.py +2 -1
  6. brainstate/augment/_mapping.py +0 -1
  7. brainstate/augment/_mapping_test.py +1 -0
  8. brainstate/compile/_ad_checkpoint.py +2 -1
  9. brainstate/compile/_conditions.py +3 -3
  10. brainstate/compile/_conditions_test.py +2 -1
  11. brainstate/compile/_error_if.py +2 -1
  12. brainstate/compile/_error_if_test.py +2 -1
  13. brainstate/compile/_jit.py +3 -2
  14. brainstate/compile/_jit_test.py +2 -1
  15. brainstate/compile/_loop_collect_return.py +2 -2
  16. brainstate/compile/_loop_collect_return_test.py +2 -1
  17. brainstate/compile/_loop_no_collection.py +1 -1
  18. brainstate/compile/_make_jaxpr.py +2 -2
  19. brainstate/compile/_make_jaxpr_test.py +2 -1
  20. brainstate/compile/_progress_bar.py +2 -1
  21. brainstate/compile/_unvmap.py +1 -2
  22. brainstate/environ.py +4 -4
  23. brainstate/environ_test.py +2 -1
  24. brainstate/functional/_activations.py +2 -1
  25. brainstate/functional/_activations_test.py +1 -1
  26. brainstate/functional/_normalization.py +2 -1
  27. brainstate/functional/_others.py +2 -1
  28. brainstate/graph/_graph_operation.py +3 -2
  29. brainstate/graph/_graph_operation_test.py +4 -3
  30. brainstate/init/_base.py +2 -1
  31. brainstate/init/_generic.py +2 -1
  32. brainstate/nn/__init__.py +4 -0
  33. brainstate/nn/_collective_ops.py +1 -0
  34. brainstate/nn/_collective_ops_test.py +0 -4
  35. brainstate/nn/_common.py +0 -1
  36. brainstate/nn/_dyn_impl/__init__.py +0 -4
  37. brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
  38. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
  39. brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
  40. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
  41. brainstate/nn/_dyn_impl/_inputs.py +236 -29
  42. brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
  43. brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
  44. brainstate/nn/_dyn_impl/_readout.py +91 -8
  45. brainstate/nn/_dyn_impl/_readout_test.py +2 -1
  46. brainstate/nn/_dynamics/_dynamics_base.py +676 -96
  47. brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
  48. brainstate/nn/_dynamics/_projection_base.py +29 -30
  49. brainstate/nn/_dynamics/_state_delay.py +3 -3
  50. brainstate/nn/_dynamics/_synouts_test.py +2 -1
  51. brainstate/nn/_elementwise/_dropout.py +3 -2
  52. brainstate/nn/_elementwise/_dropout_test.py +2 -1
  53. brainstate/nn/_elementwise/_elementwise.py +2 -1
  54. brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
  55. brainstate/nn/_event/_fixedprob_mv.py +169 -0
  56. brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
  57. brainstate/nn/_event/_linear_mv.py +85 -0
  58. brainstate/nn/_event/_linear_mv_test.py +121 -0
  59. brainstate/nn/_exp_euler.py +2 -1
  60. brainstate/nn/_exp_euler_test.py +2 -1
  61. brainstate/nn/_interaction/_conv.py +2 -1
  62. brainstate/nn/_interaction/_linear.py +2 -1
  63. brainstate/nn/_interaction/_linear_test.py +2 -1
  64. brainstate/nn/_interaction/_normalizations.py +3 -2
  65. brainstate/nn/_interaction/_poolings.py +4 -3
  66. brainstate/nn/_module_test.py +2 -1
  67. brainstate/nn/metrics.py +4 -3
  68. brainstate/optim/_lr_scheduler.py +2 -1
  69. brainstate/optim/_lr_scheduler_test.py +2 -1
  70. brainstate/optim/_optax_optimizer_test.py +2 -1
  71. brainstate/optim/_sgd_optimizer.py +3 -2
  72. brainstate/random/_rand_funs.py +2 -1
  73. brainstate/random/_rand_funs_test.py +3 -2
  74. brainstate/random/_rand_seed.py +3 -2
  75. brainstate/random/_rand_seed_test.py +2 -1
  76. brainstate/random/_rand_state.py +4 -3
  77. brainstate/surrogate.py +1 -2
  78. brainstate/typing.py +4 -4
  79. brainstate/util/_caller.py +2 -1
  80. brainstate/util/_others.py +4 -4
  81. brainstate/util/_pretty_pytree.py +1 -1
  82. brainstate/util/_pretty_pytree_test.py +2 -1
  83. brainstate/util/_pretty_table.py +43 -43
  84. brainstate/util/_struct.py +2 -1
  85. brainstate/util/filter.py +0 -1
  86. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/METADATA +3 -3
  87. brainstate-0.1.0.post20250422.dist-info/RECORD +133 -0
  88. brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
  89. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/LICENSE +0 -0
  90. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/WHEEL +0 -0
  91. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/top_level.txt +0 -0
@@ -57,9 +57,38 @@ _max_order = 10
57
57
 
58
58
  class Projection(Module):
59
59
  """
60
- Base class to model synaptic projections.
60
+ Base class for synaptic projection modules in neural network modeling.
61
+
62
+ This class defines the interface for modules that handle projections between
63
+ neural populations. Projections process input signals and transform them
64
+ before they reach the target neurons, implementing the connectivity patterns
65
+ in neural networks.
66
+
67
+ In the BrainState execution order, Projection modules are updated before
68
+ Dynamics modules, following the natural information flow in neural systems:
69
+ 1. Projections process inputs (synaptic transmission)
70
+ 2. Dynamics update neuron states (neural integration)
71
+
72
+ The Projection class does not implement the update logic directly but delegates
73
+ to its child nodes. If no child nodes exist, it raises a ValueError.
74
+
75
+ Parameters
76
+ ----------
77
+ *args : Any
78
+ Arguments passed to the parent Module class.
79
+ **kwargs : Any
80
+ Keyword arguments passed to the parent Module class.
81
+
82
+ Raises
83
+ ------
84
+ ValueError
85
+ If the update() method is called but no child nodes are defined.
86
+
87
+ Notes
88
+ -----
89
+ Derived classes should implement specific projection behaviors, such as
90
+ dense connectivity, sparse connectivity, or specific weight update rules.
61
91
  """
62
-
63
92
  __module__ = 'brainstate.nn'
64
93
 
65
94
  def update(self, *args, **kwargs):
@@ -73,24 +102,48 @@ class Projection(Module):
73
102
 
74
103
  class Dynamics(Module):
75
104
  """
76
- Base class to model dynamics.
77
-
78
- .. note::
79
- In general, every instance of :py:class:`~.Module` implemented in
80
- BrainPy only defines the evolving function at each time step :math:`t`.
81
-
82
- If users want to define the logic of running models across multiple steps,
83
- we recommend users to use :py:func:`~.for_loop`, :py:class:`~.LoopOverTime`,
84
- :py:class:`~.DSRunner`, or :py:class:`~.DSTrainer`.
85
-
86
- To be compatible with previous APIs, :py:class:`~.Module` inherits
87
- from the :py:class:`~.DelayRegister`. It's worthy to note that the methods of
88
- :py:class:`~.DelayRegister` will be removed in the future, including:
89
-
90
- - ``.register_delay()``
91
- - ``.get_delay_data()``
92
- - ``.update_local_delays()``
93
- - ``.reset_local_delays()``
105
+ Base class for implementing neural dynamics models in BrainState.
106
+
107
+ Dynamics classes represent the core computational units in neural simulations,
108
+ implementing the differential equations or update rules that govern neural activity.
109
+ This class provides infrastructure for managing neural populations, handling inputs,
110
+ and coordinating updates within the simulation framework.
111
+
112
+ The Dynamics class serves several key purposes:
113
+ 1. Managing neuron population geometry and size information
114
+ 2. Handling current and delta (instantaneous change) inputs to neurons
115
+ 3. Supporting before/after update hooks for computational dependencies
116
+ 4. Providing access to delayed state variables through the prefetch mechanism
117
+ 5. Establishing the execution order in neural network simulations
118
+
119
+ Parameters
120
+ ----------
121
+ in_size : Size
122
+ The geometry of the neuron population. Can be an integer (e.g., 10) for
123
+ 1D neuron arrays, or a tuple (e.g., (10, 10)) for multi-dimensional populations.
124
+ name : Optional[str], default=None
125
+ Optional name identifier for this dynamics module.
126
+
127
+ Attributes
128
+ ----------
129
+ in_size : tuple
130
+ The shape/geometry of the neuron population.
131
+ out_size : tuple
132
+ The output shape, typically matches in_size.
133
+ current_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
134
+ Dictionary of registered current input functions or arrays.
135
+ delta_inputs : Optional[Dict[str, Union[Callable, ArrayLike]]]
136
+ Dictionary of registered delta input functions or arrays.
137
+ before_updates : Optional[Dict[Hashable, Callable]]
138
+ Dictionary of functions to call before the main update.
139
+ after_updates : Optional[Dict[Hashable, Callable]]
140
+ Dictionary of functions to call after the main update.
141
+
142
+ Notes
143
+ -----
144
+ In the BrainState execution sequence, Dynamics modules are updated after
145
+ Projection modules and before other module types, reflecting the natural
146
+ flow of information in neural systems.
94
147
 
95
148
  There are several essential attributes:
96
149
 
@@ -100,9 +153,12 @@ class Dynamics(Module):
100
153
  - ``num``: the flattened number of neurons in the group. For example, `size=(10, )` => \
101
154
  `num=10`, `size=(10, 10)` => `num=100`, `size=(10, 15, 4)` => `num=600`.
102
155
 
103
- Args:
104
- in_size: The neuron group geometry.
105
- name: The name of the dynamic system.
156
+
157
+ See Also
158
+ --------
159
+ Module : Parent class providing base module functionality
160
+ Projection : Class for handling synaptic projections between neural populations
161
+ DynamicsGroup : Container for organizing multiple dynamics modules
106
162
  """
107
163
 
108
164
  __module__ = 'brainstate.nn'
@@ -158,26 +214,72 @@ class Dynamics(Module):
158
214
  self.out_size = self.in_size
159
215
 
160
216
  def __pretty_repr_item__(self, name, value):
161
- if name in ['_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
162
- return None if value is None else (name[1:], value) # skip the first `_`
217
+ if name in [
218
+ '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
219
+ '_in_size', '_out_size', '_name', '_mode',
220
+ ]:
221
+ return (name, value) if value is None else (name[1:], value) # skip the first `_`
163
222
  return super().__pretty_repr_item__(name, value)
164
223
 
165
224
  @property
166
225
  def varshape(self):
167
- """The shape of variables in the neuron group."""
226
+ """
227
+ Get the shape of variables in the neuron group.
228
+
229
+ This property provides access to the geometry (shape) of the neuron population,
230
+ which determines how variables and states are structured.
231
+
232
+ Returns
233
+ -------
234
+ tuple
235
+ A tuple representing the dimensional shape of the neuron group,
236
+ matching the in_size parameter provided during initialization.
237
+
238
+ See Also
239
+ --------
240
+ in_size : The input geometry specification for the neuron group
241
+ """
168
242
  return self.in_size
169
243
 
170
244
  @property
171
245
  def current_inputs(self):
172
246
  """
173
- The current inputs of the model. It should be a dictionary of the input data.
247
+ Get the dictionary of current inputs registered with this dynamics model.
248
+
249
+ Current inputs represent direct input currents that flow into the model.
250
+
251
+ Returns
252
+ -------
253
+ dict or None
254
+ A dictionary mapping keys to current input functions or values,
255
+ or None if no current inputs have been registered.
256
+
257
+ See Also
258
+ --------
259
+ add_current_input : Register a new current input
260
+ sum_current_inputs : Apply and sum all current inputs
261
+ delta_inputs : Dictionary of instantaneous change inputs
174
262
  """
175
263
  return self._current_inputs
176
264
 
177
265
  @property
178
266
  def delta_inputs(self):
179
267
  """
180
- The delta inputs of the model. It should be a dictionary of the input data.
268
+ Get the dictionary of delta inputs registered with this dynamics model.
269
+
270
+ Delta inputs represent instantaneous changes to state variables (dX/dt).
271
+
272
+ Returns
273
+ -------
274
+ dict or None
275
+ A dictionary mapping keys to delta input functions or values,
276
+ or None if no delta inputs have been registered.
277
+
278
+ See Also
279
+ --------
280
+ add_delta_input : Register a new delta input
281
+ sum_delta_inputs : Apply and sum all delta inputs
282
+ current_inputs : Dictionary of direct current inputs
181
283
  """
182
284
  return self._delta_inputs
183
285
 
@@ -188,12 +290,40 @@ class Dynamics(Module):
188
290
  label: Optional[str] = None
189
291
  ):
190
292
  """
191
- Add a current input function.
192
-
193
- Args:
194
- key: str. The dict key.
195
- inp: Callable, ArrayLike. The currents or the function to generate currents.
196
- label: str. The input label.
293
+ Add a current input function or array to the dynamics model.
294
+
295
+ Current inputs represent direct input currents that can be accessed during
296
+ model updates through the `sum_current_inputs()` method.
297
+
298
+ Parameters
299
+ ----------
300
+ key : str
301
+ Unique identifier for this current input. Used to retrieve or reference
302
+ the input later.
303
+ inp : Union[Callable, ArrayLike]
304
+ The input data or function that generates input data.
305
+ - If callable: Will be called during updates with arguments passed to `sum_current_inputs()`
306
+ - If array-like: Will be applied once and then automatically removed from available inputs
307
+ label : Optional[str], default=None
308
+ Optional grouping label for the input. When provided, allows selective
309
+ processing of inputs by label in `sum_current_inputs()`.
310
+
311
+ Raises
312
+ ------
313
+ ValueError
314
+ If the key has already been used for a different current input.
315
+
316
+ Notes
317
+ -----
318
+ - Inputs with the same label can be processed together using the `label`
319
+ parameter in `sum_current_inputs()`.
320
+ - Non-callable inputs are consumed when used (removed after first use).
321
+ - Callable inputs persist and can be called repeatedly.
322
+
323
+ See Also
324
+ --------
325
+ sum_current_inputs : Sum all current inputs matching a given label
326
+ add_delta_input : Add a delta input function or array
197
327
  """
198
328
  key = _input_label_repr(key, label)
199
329
  if self._current_inputs is None:
@@ -210,12 +340,41 @@ class Dynamics(Module):
210
340
  label: Optional[str] = None
211
341
  ):
212
342
  """
213
- Add a delta input function.
214
-
215
- Args:
216
- key: str. The dict key.
217
- inp: Callable, ArrayLike. The currents or the function to generate currents.
218
- label: str. The input label.
343
+ Add a delta input function or array to the dynamics model.
344
+
345
+ Delta inputs represent instantaneous changes to the model state (i.e., dX/dt contributions).
346
+ This method registers a function or array that provides delta inputs which will be
347
+ accessible during model updates through the `sum_delta_inputs()` method.
348
+
349
+ Parameters
350
+ ----------
351
+ key : str
352
+ Unique identifier for this delta input. Used to retrieve or reference
353
+ the input later.
354
+ inp : Union[Callable, ArrayLike]
355
+ The input data or function that generates input data.
356
+ - If callable: Will be called during updates with arguments passed to `sum_delta_inputs()`
357
+ - If array-like: Will be applied once and then automatically removed from available inputs
358
+ label : Optional[str], default=None
359
+ Optional grouping label for the input. When provided, allows selective
360
+ processing of inputs by label in `sum_delta_inputs()`.
361
+
362
+ Raises
363
+ ------
364
+ ValueError
365
+ If the key has already been used for a different delta input.
366
+
367
+ Notes
368
+ -----
369
+ - Inputs with the same label can be processed together using the `label`
370
+ parameter in `sum_delta_inputs()`.
371
+ - Non-callable inputs are consumed when used (removed after first use).
372
+ - Callable inputs persist and can be called repeatedly.
373
+
374
+ See Also
375
+ --------
376
+ sum_delta_inputs : Sum all delta inputs matching a given label
377
+ add_current_input : Add a current input function or array
219
378
  """
220
379
  key = _input_label_repr(key, label)
221
380
  if self._delta_inputs is None:
@@ -226,13 +385,40 @@ class Dynamics(Module):
226
385
  self._delta_inputs[key] = inp
227
386
 
228
387
  def get_input(self, key: str):
229
- """Get the input function.
230
-
231
- Args:
232
- key: str. The key.
233
-
234
- Returns:
235
- The input function which generates currents.
388
+ """
389
+ Get a registered input function by its key.
390
+
391
+ Retrieves either a current input or a delta input function that was previously
392
+ registered with the given key. This method checks both current_inputs and
393
+ delta_inputs dictionaries for the specified key.
394
+
395
+ Parameters
396
+ ----------
397
+ key : str
398
+ The unique identifier used when the input function was registered.
399
+
400
+ Returns
401
+ -------
402
+ Callable or ArrayLike
403
+ The input function or array associated with the given key.
404
+
405
+ Raises
406
+ ------
407
+ ValueError
408
+ If no input function is found with the specified key in either
409
+ current_inputs or delta_inputs.
410
+
411
+ See Also
412
+ --------
413
+ add_current_input : Register a current input function
414
+ add_delta_input : Register a delta input function
415
+
416
+ Examples
417
+ --------
418
+ >>> model = Dynamics(10)
419
+ >>> model.add_current_input('stimulus', lambda t: np.sin(t))
420
+ >>> input_func = model.get_input('stimulus')
421
+ >>> input_func(0.5) # Returns sin(0.5)
236
422
  """
237
423
  if self._current_inputs is not None and key in self._current_inputs:
238
424
  return self._current_inputs[key]
@@ -249,16 +435,36 @@ class Dynamics(Module):
249
435
  **kwargs
250
436
  ):
251
437
  """
252
- Summarize all current inputs by the defined input functions ``.current_inputs``.
253
-
254
- Args:
255
- init: The initial input data.
256
- *args: The arguments for input functions.
257
- **kwargs: The arguments for input functions.
258
- label: str. The input label.
259
-
260
- Returns:
261
- The total currents.
438
+ Summarize all current inputs by applying and summing all registered current input functions.
439
+
440
+ This method iterates through all registered current input functions (from `.current_inputs`)
441
+ and applies them to calculate the total input current for the dynamics model. It adds all results
442
+ to the initial value provided.
443
+
444
+ Parameters
445
+ ----------
446
+ init : Any
447
+ The initial value to which all current inputs will be added.
448
+ *args : tuple
449
+ Variable length argument list passed to each current input function.
450
+ label : Optional[str], default=None
451
+ If provided, only process current inputs with this label prefix.
452
+ When None, process all current inputs regardless of label.
453
+ **kwargs : dict
454
+ Arbitrary keyword arguments passed to each current input function.
455
+
456
+ Returns
457
+ -------
458
+ Any
459
+ The initial value plus all applicable current inputs summed together.
460
+
461
+ Notes
462
+ -----
463
+ - Non-callable current inputs are applied once and then automatically removed from
464
+ the current_inputs dictionary.
465
+ - Callable current inputs remain registered for subsequent calls.
466
+ - When a label is provided, only current inputs with keys starting with that label
467
+ are applied.
262
468
  """
263
469
  if self._current_inputs is None:
264
470
  return init
@@ -288,16 +494,36 @@ class Dynamics(Module):
288
494
  **kwargs
289
495
  ):
290
496
  """
291
- Summarize all delta inputs by the defined input functions ``.delta_inputs``.
292
-
293
- Args:
294
- init: The initial input data.
295
- *args: The arguments for input functions.
296
- **kwargs: The arguments for input functions.
297
- label: str. The input label.
298
-
299
- Returns:
300
- The total currents.
497
+ Summarize all delta inputs by applying and summing all registered delta input functions.
498
+
499
+ This method iterates through all registered delta input functions (from `.delta_inputs`)
500
+ and applies them to calculate instantaneous changes to model states. It adds all results
501
+ to the initial value provided.
502
+
503
+ Parameters
504
+ ----------
505
+ init : Any
506
+ The initial value to which all delta inputs will be added.
507
+ *args : tuple
508
+ Variable length argument list passed to each delta input function.
509
+ label : Optional[str], default=None
510
+ If provided, only process delta inputs with this label prefix.
511
+ When None, process all delta inputs regardless of label.
512
+ **kwargs : dict
513
+ Arbitrary keyword arguments passed to each delta input function.
514
+
515
+ Returns
516
+ -------
517
+ Any
518
+ The initial value plus all applicable delta inputs summed together.
519
+
520
+ Notes
521
+ -----
522
+ - Non-callable delta inputs are applied once and then automatically removed from
523
+ the delta_inputs dictionary.
524
+ - Callable delta inputs remain registered for subsequent calls.
525
+ - When a label is provided, only delta inputs with keys starting with that label
526
+ are applied.
301
527
  """
302
528
  if self._delta_inputs is None:
303
529
  return init
@@ -322,20 +548,59 @@ class Dynamics(Module):
322
548
  @property
323
549
  def before_updates(self):
324
550
  """
325
- The before updates of the model. It should be a dictionary of the updating functions.
551
+ Get the dictionary of functions to execute before the module's update.
552
+
553
+ Returns
554
+ -------
555
+ dict or None
556
+ Dictionary mapping keys to callable functions that will be executed
557
+ before the main update, or None if no before updates are registered.
558
+
559
+ Notes
560
+ -----
561
+ Before updates are executed in the order they were registered whenever
562
+ the module is called via __call__.
326
563
  """
327
564
  return self._before_updates
328
565
 
329
566
  @property
330
567
  def after_updates(self):
331
568
  """
332
- The after updates of the model. It should be a dictionary of the updating functions.
569
+ Get the dictionary of functions to execute after the module's update.
570
+
571
+ Returns
572
+ -------
573
+ dict or None
574
+ Dictionary mapping keys to callable functions that will be executed
575
+ after the main update, or None if no after updates are registered.
576
+
577
+ Notes
578
+ -----
579
+ After updates are executed in the order they were registered whenever
580
+ the module is called via __call__, and may optionally receive the return
581
+ value from the update method.
333
582
  """
334
583
  return self._after_updates
335
584
 
336
585
  def _add_before_update(self, key: Any, fun: Callable):
337
586
  """
338
- Add the before update into this node.
587
+ Register a function to be executed before the module's update.
588
+
589
+ Parameters
590
+ ----------
591
+ key : Any
592
+ A unique identifier for the update function.
593
+ fun : Callable
594
+ The function to execute before the module's update.
595
+
596
+ Raises
597
+ ------
598
+ KeyError
599
+ If the key is already registered in before_updates.
600
+
601
+ Notes
602
+ -----
603
+ Internal method used by the module system to register dependencies.
339
604
  """
340
605
  if self._before_updates is None:
341
606
  self._before_updates = dict()
@@ -344,7 +609,25 @@ class Dynamics(Module):
344
609
  self.before_updates[key] = fun
345
610
 
346
611
  def _add_after_update(self, key: Any, fun: Callable):
347
- """Add the after update into this node"""
612
+ """
613
+ Register a function to be executed after the module's update.
614
+
615
+ Parameters
616
+ ----------
617
+ key : Any
618
+ A unique identifier for the update function.
619
+ fun : Callable
620
+ The function to execute after the module's update.
621
+
622
+ Raises
623
+ ------
624
+ KeyError
625
+ If the key is already registered in after_updates.
626
+
627
+ Notes
628
+ -----
629
+ Internal method used by the module system to register dependencies.
630
+ """
348
631
  if self._after_updates is None:
349
632
  self._after_updates = dict()
350
633
  if key in self.after_updates:
@@ -352,7 +635,24 @@ class Dynamics(Module):
352
635
  self.after_updates[key] = fun
353
636
 
354
637
  def _get_before_update(self, key: Any):
355
- """Get the before update of this node by the given ``key``."""
638
+ """
639
+ Retrieve a registered before-update function by its key.
640
+
641
+ Parameters
642
+ ----------
643
+ key : Any
644
+ The identifier of the before-update function to retrieve.
645
+
646
+ Returns
647
+ -------
648
+ Callable
649
+ The registered before-update function.
650
+
651
+ Raises
652
+ ------
653
+ KeyError
654
+ If the key is not registered in before_updates or if before_updates is None.
655
+ """
356
656
  if self._before_updates is None:
357
657
  raise KeyError(f'{key} is not registered in before_updates of {self}')
358
658
  if key not in self.before_updates:
@@ -360,7 +660,24 @@ class Dynamics(Module):
360
660
  return self.before_updates.get(key)
361
661
 
362
662
  def _get_after_update(self, key: Any):
363
- """Get the after update of this node by the given ``key``."""
663
+ """
664
+ Retrieve a registered after-update function by its key.
665
+
666
+ Parameters
667
+ ----------
668
+ key : Any
669
+ The identifier of the after-update function to retrieve.
670
+
671
+ Returns
672
+ -------
673
+ Callable
674
+ The registered after-update function.
675
+
676
+ Raises
677
+ ------
678
+ KeyError
679
+ If the key is not registered in after_updates or if after_updates is None.
680
+ """
364
681
  if self._after_updates is None:
365
682
  raise KeyError(f'{key} is not registered in after_updates of {self}')
366
683
  if key not in self.after_updates:
@@ -368,13 +685,37 @@ class Dynamics(Module):
368
685
  return self.after_updates.get(key)
369
686
 
370
687
  def _has_before_update(self, key: Any):
371
- """Whether this node has the before update of the given ``key``."""
688
+ """
689
+ Check if a before-update function is registered with the given key.
690
+
691
+ Parameters
692
+ ----------
693
+ key : Any
694
+ The identifier to check for in the before_updates dictionary.
695
+
696
+ Returns
697
+ -------
698
+ bool
699
+ True if the key is registered in before_updates, False otherwise.
700
+ """
372
701
  if self._before_updates is None:
373
702
  return False
374
703
  return key in self.before_updates
375
704
 
376
705
  def _has_after_update(self, key: Any):
377
- """Whether this node has the after update of the given ``key``."""
706
+ """
707
+ Check if an after-update function is registered with the given key.
708
+
709
+ Parameters
710
+ ----------
711
+ key : Any
712
+ The identifier to check for in the after_updates dictionary.
713
+
714
+ Returns
715
+ -------
716
+ bool
717
+ True if the key is registered in after_updates, False otherwise.
718
+ """
378
719
  if self._after_updates is None:
379
720
  return False
380
721
  return key in self.after_updates
@@ -405,19 +746,75 @@ class Dynamics(Module):
405
746
  return ret
406
747
 
407
748
  def prefetch(self, item: str) -> 'Prefetch':
749
+ """
750
+ Create a reference to a state or variable that may not be initialized yet.
751
+
752
+ This method allows accessing module attributes or states before they are
753
+ fully defined, acting as a placeholder that will be resolved when called.
754
+ Particularly useful for creating references to variables that will be defined
755
+ during initialization or runtime.
756
+
757
+ Parameters
758
+ ----------
759
+ item : str
760
+ The name of the attribute or state to reference.
761
+
762
+ Returns
763
+ -------
764
+ Prefetch
765
+ A Prefetch object that provides access to the referenced item.
766
+
767
+ Examples
768
+ --------
769
+ >>> import brainstate
770
+ >>> import brainunit as u
771
+ >>> neuron = brainstate.nn.LIF(...)
772
+ >>> v_ref = neuron.prefetch('V') # Reference to voltage
773
+ >>> v_value = v_ref() # Get current value
774
+ >>> delayed_v = v_ref.delay.at(5.0 * u.ms) # Get delayed value
775
+ """
408
776
  return Prefetch(self, item)
409
777
 
410
778
  def align_pre(
411
- self, dyn: Union[ParamDescriber[T], T]
779
+ self,
780
+ dyn: Union[ParamDescriber[T], T]
412
781
  ) -> T:
413
782
  """
414
- Align the dynamics before the interaction.
783
+ Registers a dynamics module to execute after this module.
784
+
785
+ This method establishes a sequential execution relationship where the specified
786
+ dynamics module will be called after this module completes its update. This
787
+ creates a feed-forward connection in the computational graph.
788
+
789
+ Parameters
790
+ ----------
791
+ dyn : Union[ParamDescriber[T], T]
792
+ The dynamics module to be executed after this module. Can be either:
793
+ - An instance of Dynamics
794
+ - A ParamDescriber that can instantiate a Dynamics object
795
+
796
+ Returns
797
+ -------
798
+ T
799
+ The dynamics module that was registered, allowing for method chaining.
800
+
801
+ Raises
802
+ ------
803
+ TypeError
804
+ If the input is not a Dynamics instance or a ParamDescriber that creates
805
+ a Dynamics instance.
806
+
807
+ Examples
808
+ --------
809
+ >>> import brainstate
810
+ >>> n1 = brainstate.nn.LIF(10)
811
+ >>> n2 = n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1
415
812
  """
416
813
  if isinstance(dyn, Dynamics):
417
814
  self._add_after_update(dyn.name, dyn)
418
815
  return dyn
419
816
  elif isinstance(dyn, ParamDescriber):
420
- if not isinstance(dyn.cls, Dynamics):
817
+ if not issubclass(dyn.cls, Dynamics):
421
818
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
422
819
  if not self._has_after_update(dyn.identifier):
423
820
  self._add_after_update(dyn.identifier, dyn())
@@ -425,62 +822,206 @@ class Dynamics(Module):
425
822
  else:
426
823
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
427
824
 
428
- def __pretty_repr_item__(self, name, value):
429
- if name in ['_in_size', '_out_size', '_name', '_mode',
430
- '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs']:
431
- return (name, value) if value is None else (name[1:], value) # skip the first `_`
432
- return name, value
433
-
434
825
 
435
826
  class Prefetch(Node):
436
827
  """
437
- Prefetch a variable of the given module.
828
+ Prefetch a state or variable in a module before it is initialized.
829
+
830
+
831
+ This class provides a mechanism to reference a module's state or attribute
832
+ that may not have been initialized yet. It acts as a placeholder or reference
833
+ that will be resolved when called.
834
+
835
+ Use cases:
836
+ - Access variables within dynamics modules that will be defined later
837
+ - Create references to states across module boundaries
838
+ - Enable access to delayed states through the `.delay` property
839
+
840
+ Parameters
841
+ ----------
842
+ module : Module
843
+ The module that contains or will contain the referenced item.
844
+ item : str
845
+ The attribute name of the state or variable to prefetch.
846
+
847
+ Examples
848
+ --------
849
+ >>> import brainstate
850
+ >>> import brainunit as u
851
+ >>> neuron = brainstate.nn.LIF(...)
852
+ >>> v_reference = neuron.prefetch('V') # Reference to voltage before initialization
853
+ >>> v_value = v_reference() # Get the current value
854
+ >>> delay_ref = v_reference.delay.at(5.0 * u.ms) # Reference voltage delayed by 5ms
855
+
856
+ Notes
857
+ -----
858
+ When called, this class retrieves the current value of the referenced item.
859
+ Use the `.delay` property to access delayed versions of the state.
860
+
438
861
  """
439
862
 
440
- def __init__(self, module: Module, item: str):
863
+ def __init__(self, module: Dynamics, item: str):
864
+ """
865
+ Initialize a Prefetch object.
866
+
867
+ Parameters
868
+ ----------
869
+ module : Module
870
+ The module that contains or will contain the referenced item.
871
+ item : str
872
+ The attribute name of the state or variable to prefetch.
873
+ """
441
874
  super().__init__()
442
875
  self.module = module
443
876
  self.item = item
444
877
 
445
878
  @property
446
879
  def delay(self):
880
+ """
881
+ Access delayed versions of the prefetched item.
882
+
883
+ Returns
884
+ -------
885
+ PrefetchDelay
886
+ An object that provides access to delayed versions of the prefetched item.
887
+ """
447
888
  return PrefetchDelay(self.module, self.item)
448
889
 
449
890
  def __call__(self, *args, **kwargs):
891
+ """
892
+ Get the current value of the prefetched item.
893
+
894
+ Returns
895
+ -------
896
+ Any
897
+ The current value of the referenced item. If the item is a State object,
898
+ returns its value attribute, otherwise returns the item itself.
899
+ """
450
900
  item = _get_prefetch_item(self)
451
901
  return item.value if isinstance(item, State) else item
452
902
 
453
903
  def get_item_value(self):
904
+ """
905
+ Get the current value of the prefetched item.
906
+
907
+ Similar to __call__, but explicitly named for clarity.
908
+
909
+ Returns
910
+ -------
911
+ Any
912
+ The current value of the referenced item. If the item is a State object,
913
+ returns its value attribute, otherwise returns the item itself.
914
+ """
454
915
  item = _get_prefetch_item(self)
455
916
  return item.value if isinstance(item, State) else item
456
917
 
457
918
  def get_item(self):
458
919
  """
459
- Get
920
+ Get the referenced item object itself, not its value.
921
+
922
+ Returns
923
+ -------
924
+ Any
925
+ The actual referenced item from the module, which could be a State
926
+ object or any other attribute.
460
927
  """
461
928
  return _get_prefetch_item(self)
462
929
 
463
930
 
464
931
  class PrefetchDelay(Node):
932
+ """
933
+ Provides access to delayed versions of a prefetched state or variable.
934
+
935
+ This class acts as an intermediary for accessing delayed values of module variables.
936
+ It doesn't retrieve values directly but provides methods to specify the delay time
937
+ via the `at()` method.
938
+
939
+ Parameters
940
+ ----------
941
+ module : Dynamics
942
+ The dynamics module that contains the referenced state or variable.
943
+ item : str
944
+ The name of the state or variable to access with delay.
945
+
946
+ Examples
947
+ --------
948
+ >>> import brainstate
949
+ >>> import brainunit as u
950
+ >>> neuron = brainstate.nn.LIF(10)
951
+ >>> # Access voltage delayed by 5ms
952
+ >>> delayed_v = neuron.prefetch('V').delay.at(5.0 * u.ms)
953
+ >>> delayed_value = delayed_v() # Get the delayed value
954
+ """
955
+
465
956
  def __init__(self, module: Dynamics, item: str):
466
957
  self.module = module
467
958
  self.item = item
468
959
 
469
960
  def at(self, time: ArrayLike):
961
+ """
962
+ Specifies the delay time for accessing the variable.
963
+
964
+ Parameters
965
+ ----------
966
+ time : ArrayLike
967
+ The amount of time to delay the variable access, typically in time units
968
+ (e.g., milliseconds).
969
+
970
+ Returns
971
+ -------
972
+ PrefetchDelayAt
973
+ An object that provides access to the variable at the specified delay time.
974
+ """
470
975
  return PrefetchDelayAt(self.module, self.item, time)
471
976
 
472
977
 
473
978
  class PrefetchDelayAt(Node):
474
979
  """
475
- Prefetch the delay of a variable in the given module at a specific time.
476
-
477
- Args:
478
- module: The module that has the item with the name specified by ``item`` argument.
479
- item: The item that has the delay.
480
- time: The time to retrieve the delay.
980
+ Provides access to a specific delayed state or variable value at the specific time.
981
+
982
+ This class represents the final step in the prefetch delay chain, providing
983
+ actual access to state values at a specific delay time. It converts the
984
+ specified time delay into steps and registers the delay with the appropriate
985
+ StateWithDelay handler.
986
+
987
+ Parameters
988
+ ----------
989
+ module : Dynamics
990
+ The dynamics module that contains the referenced state or variable.
991
+ item : str
992
+ The name of the state or variable to access with delay.
993
+ time : ArrayLike
994
+ The amount of time to delay access by, typically in time units (e.g., milliseconds).
995
+
996
+ Examples
997
+ --------
998
+ >>> import brainstate
999
+ >>> import brainunit as u
1000
+ >>> neuron = brainstate.nn.LIF(10)
1001
+ >>> # Create a reference to voltage delayed by 5ms
1002
+ >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
1003
+ >>> # Get the delayed value
1004
+ >>> v_value = delayed_v()
481
1005
  """
482
1006
 
483
- def __init__(self, module: Dynamics, item: str, time: ArrayLike):
1007
+ def __init__(
1008
+ self,
1009
+ module: Dynamics,
1010
+ item: str,
1011
+ time: ArrayLike
1012
+ ):
1013
+ """
1014
+ Initialize a PrefetchDelayAt object.
1015
+
1016
+ Parameters
1017
+ ----------
1018
+ module : Dynamics
1019
+ The dynamics module that contains the referenced state or variable.
1020
+ item : str
1021
+ The name of the state or variable to access with delay.
1022
+ time : ArrayLike
1023
+ The amount of time to delay access by, typically in time units.
1024
+ """
484
1025
  super().__init__()
485
1026
  assert isinstance(module, Dynamics), ''
486
1027
  self.module = module
@@ -496,6 +1037,14 @@ class PrefetchDelayAt(Node):
496
1037
  self.state_delay.register_delay(time)
497
1038
 
498
1039
  def __call__(self, *args, **kwargs):
1040
+ """
1041
+ Retrieve the value of the state at the specified delay time.
1042
+
1043
+ Returns
1044
+ -------
1045
+ Any
1046
+ The value of the state or variable at the specified delay time.
1047
+ """
499
1048
  # return self.state_delay.retrieve_at_time(self.time)
500
1049
  return self.state_delay.retrieve_at_step(self.step)
501
1050
 
@@ -512,8 +1061,10 @@ def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
512
1061
 
513
1062
 
514
1063
  def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDelayAt]) -> Delay:
515
- assert isinstance(target.module, Dynamics), (f'The target module should be an instance '
516
- f'of Dynamics. But got {target.module}.')
1064
+ assert isinstance(target.module, Dynamics), (
1065
+ f'The target module should be an instance '
1066
+ f'of Dynamics. But got {target.module}.'
1067
+ )
517
1068
  delay = target.module._get_after_update(_get_delay_key(target.item))
518
1069
  if not isinstance(delay, StateWithDelay):
519
1070
  raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
@@ -522,6 +1073,35 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
522
1073
 
523
1074
 
524
1075
  def maybe_init_prefetch(target, *args, **kwargs):
1076
+ """
1077
+ Initialize a prefetch target if needed, based on its type.
1078
+
1079
+ This function ensures that prefetch references are properly initialized
1080
+ and ready to use. It handles different types of prefetch objects by
1081
+ performing the appropriate initialization action:
1082
+ - For :py:class:`Prefetch` objects: retrieves the referenced item
1083
+ - For :py:class:`PrefetchDelay` objects: retrieves the delay handler
1084
+ - For :py:class:`PrefetchDelayAt` objects: registers the specified delay
1085
+
1086
+ Parameters
1087
+ ----------
1088
+ target : Union[Prefetch, PrefetchDelay, PrefetchDelayAt]
1089
+ The prefetch target to initialize.
1090
+ *args : Any
1091
+ Additional positional arguments (unused).
1092
+ **kwargs : Any
1093
+ Additional keyword arguments (unused).
1094
+
1095
+ Returns
1096
+ -------
1097
+ None
1098
+ This function performs initialization side effects only.
1099
+
1100
+ Notes
1101
+ -----
1102
+ This function is typically called internally when prefetched references
1103
+ are used to ensure they are properly set up before access.
1104
+ """
525
1105
  if isinstance(target, Prefetch):
526
1106
  _get_prefetch_item(target)
527
1107