brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. brainstate/__init__.py +2 -4
  2. brainstate/_deprecation_test.py +2 -24
  3. brainstate/_state.py +540 -35
  4. brainstate/_state_test.py +1085 -8
  5. brainstate/graph/_operation.py +1 -5
  6. brainstate/mixin.py +14 -0
  7. brainstate/nn/__init__.py +42 -33
  8. brainstate/nn/_collective_ops.py +2 -0
  9. brainstate/nn/_common_test.py +0 -20
  10. brainstate/nn/_delay.py +1 -1
  11. brainstate/nn/_dropout_test.py +9 -6
  12. brainstate/nn/_dynamics.py +67 -464
  13. brainstate/nn/_dynamics_test.py +0 -14
  14. brainstate/nn/_embedding.py +7 -7
  15. brainstate/nn/_exp_euler.py +9 -9
  16. brainstate/nn/_linear.py +21 -21
  17. brainstate/nn/_module.py +25 -18
  18. brainstate/nn/_normalizations.py +27 -27
  19. brainstate/random/__init__.py +6 -6
  20. brainstate/random/{_rand_funs.py → _fun.py} +1 -1
  21. brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
  22. brainstate/random/_impl.py +672 -0
  23. brainstate/random/{_rand_seed.py → _seed.py} +1 -1
  24. brainstate/random/{_rand_state.py → _state.py} +121 -418
  25. brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
  26. brainstate/transform/__init__.py +6 -9
  27. brainstate/transform/_conditions.py +2 -2
  28. brainstate/transform/_find_state.py +200 -0
  29. brainstate/transform/_find_state_test.py +84 -0
  30. brainstate/transform/_make_jaxpr.py +221 -61
  31. brainstate/transform/_make_jaxpr_test.py +125 -1
  32. brainstate/transform/_mapping.py +287 -209
  33. brainstate/transform/_mapping_test.py +94 -184
  34. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
  35. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
  36. brainstate/transform/_eval_shape.py +0 -145
  37. brainstate/transform/_eval_shape_test.py +0 -38
  38. brainstate/transform/_random.py +0 -171
  39. /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
  40. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  41. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
  42. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -33,7 +33,7 @@ For handling the delays:
33
33
 
34
34
  """
35
35
 
36
- from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING, Tuple
36
+ from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, Tuple
37
37
 
38
38
  import jax
39
39
  import numpy as np
@@ -41,23 +41,26 @@ import numpy as np
41
41
  from brainstate import environ
42
42
  from brainstate._state import State
43
43
  from brainstate.graph import Node
44
- from brainstate.mixin import ParamDescriber
45
44
  from brainstate.typing import Size, ArrayLike
46
45
  from ._delay import StateWithDelay, Delay
47
46
  from ._module import Module
48
47
 
48
+ T = TypeVar('T')
49
+
49
50
  __all__ = [
50
- 'DynamicsGroup',
51
51
  'Dynamics',
52
+
53
+ 'receive_update_output',
54
+ 'not_receive_update_output',
55
+ 'receive_update_input',
56
+ 'not_receive_update_input',
57
+
52
58
  'Prefetch',
53
59
  'PrefetchDelay',
54
60
  'PrefetchDelayAt',
55
61
  'OutputDelayAt',
56
62
  ]
57
63
 
58
- T = TypeVar('T')
59
- _max_order = 10
60
-
61
64
 
62
65
  class Dynamics(Module):
63
66
  """
@@ -122,7 +125,7 @@ class Dynamics(Module):
122
125
 
123
126
  __module__ = 'brainstate.nn'
124
127
 
125
- graph_invisible_attrs = ('_before_updates', '_after_updates', '_current_inputs', '_delta_inputs')
128
+ graph_invisible_attrs = ()
126
129
 
127
130
  # before updates
128
131
  _before_updates: Optional[Dict[Hashable, Callable]]
@@ -136,11 +139,7 @@ class Dynamics(Module):
136
139
  # delta inputs
137
140
  _delta_inputs: Optional[Dict[str, ArrayLike | Callable]]
138
141
 
139
- def __init__(
140
- self,
141
- in_size: Size,
142
- name: Optional[str] = None,
143
- ):
142
+ def __init__(self, in_size: Size, name: Optional[str] = None):
144
143
  # initialize
145
144
  super().__init__(name=name)
146
145
 
@@ -157,12 +156,6 @@ class Dynamics(Module):
157
156
  raise ValueError(f'"in_size" must be int, or a tuple/list of int. But we got {type(in_size)}')
158
157
  self.in_size = in_size
159
158
 
160
- # current inputs
161
- self._current_inputs = None
162
-
163
- # delta inputs
164
- self._delta_inputs = None
165
-
166
159
  # before updates
167
160
  self._before_updates = None
168
161
 
@@ -172,14 +165,6 @@ class Dynamics(Module):
172
165
  # in-/out- size of neuron population
173
166
  self.out_size = self.in_size
174
167
 
175
- # def __pretty_repr_item__(self, name, value):
176
- # if name in [
177
- # '_before_updates', '_after_updates', '_current_inputs', '_delta_inputs',
178
- # '_in_size', '_out_size', '_name', '_mode',
179
- # ]:
180
- # return (name, value) if value is None else (name[1:], value) # skip the first `_`
181
- # return super().__pretty_repr_item__(name, value)
182
-
183
168
  @property
184
169
  def varshape(self):
185
170
  """
@@ -200,327 +185,72 @@ class Dynamics(Module):
200
185
  """
201
186
  return self.in_size
202
187
 
203
- @property
204
- def current_inputs(self):
205
- """
206
- Get the dictionary of current inputs registered with this dynamics model.
207
-
208
- Current inputs represent direct input currents that flow into the model.
209
-
210
- Returns
211
- -------
212
- dict or None
213
- A dictionary mapping keys to current input functions or values,
214
- or None if no current inputs have been registered.
215
-
216
- See Also
217
- --------
218
- add_current_input : Register a new current input
219
- sum_current_inputs : Apply and sum all current inputs
220
- delta_inputs : Dictionary of instantaneous change inputs
221
- """
222
- return self._current_inputs
223
-
224
- @property
225
- def delta_inputs(self):
226
- """
227
- Get the dictionary of delta inputs registered with this dynamics model.
228
-
229
- Delta inputs represent instantaneous changes to state variables (dX/dt).
230
-
231
- Returns
232
- -------
233
- dict or None
234
- A dictionary mapping keys to delta input functions or values,
235
- or None if no delta inputs have been registered.
236
-
237
- See Also
238
- --------
239
- add_delta_input : Register a new delta input
240
- sum_delta_inputs : Apply and sum all delta inputs
241
- current_inputs : Dictionary of direct current inputs
242
- """
243
- return self._delta_inputs
244
-
245
- def add_current_input(
246
- self,
247
- key: str,
248
- inp: Union[Callable, ArrayLike],
249
- label: Optional[str] = None
250
- ):
251
- """
252
- Add a current input function or array to the dynamics model.
253
-
254
- Current inputs represent direct input currents that can be accessed during
255
- model updates through the `sum_current_inputs()` method.
256
-
257
- Parameters
258
- ----------
259
- key : str
260
- Unique identifier for this current input. Used to retrieve or reference
261
- the input later.
262
- inp : Union[Callable, ArrayLike]
263
- The input data or function that generates input data.
264
- - If callable: Will be called during updates with arguments passed to `sum_current_inputs()`
265
- - If array-like: Will be applied once and then automatically removed from available inputs
266
- label : Optional[str], default=None
267
- Optional grouping label for the input. When provided, allows selective
268
- processing of inputs by label in `sum_current_inputs()`.
269
-
270
- Raises
271
- ------
272
- ValueError
273
- If the key has already been used for a different current input.
274
-
275
- Notes
276
- -----
277
- - Inputs with the same label can be processed together using the `label`
278
- parameter in `sum_current_inputs()`.
279
- - Non-callable inputs are consumed when used (removed after first use).
280
- - Callable inputs persist and can be called repeatedly.
281
-
282
- See Also
283
- --------
284
- sum_current_inputs : Sum all current inputs matching a given label
285
- add_delta_input : Add a delta input function or array
286
- """
287
- key = _input_label_repr(key, label)
288
- if self._current_inputs is None:
289
- self._current_inputs = dict()
290
- if key in self._current_inputs:
291
- if id(self._current_inputs[key]) != id(inp):
292
- raise ValueError(f'Key "{key}" has been defined and used in the current inputs of {self}.')
293
- self._current_inputs[key] = inp
294
-
295
- def add_delta_input(
296
- self,
297
- key: str,
298
- inp: Union[Callable, ArrayLike],
299
- label: Optional[str] = None
300
- ):
301
- """
302
- Add a delta input function or array to the dynamics model.
303
-
304
- Delta inputs represent instantaneous changes to the model state (i.e., dX/dt contributions).
305
- This method registers a function or array that provides delta inputs which will be
306
- accessible during model updates through the `sum_delta_inputs()` method.
307
-
308
- Parameters
309
- ----------
310
- key : str
311
- Unique identifier for this delta input. Used to retrieve or reference
312
- the input later.
313
- inp : Union[Callable, ArrayLike]
314
- The input data or function that generates input data.
315
- - If callable: Will be called during updates with arguments passed to `sum_delta_inputs()`
316
- - If array-like: Will be applied once and then automatically removed from available inputs
317
- label : Optional[str], default=None
318
- Optional grouping label for the input. When provided, allows selective
319
- processing of inputs by label in `sum_delta_inputs()`.
320
-
321
- Raises
322
- ------
323
- ValueError
324
- If the key has already been used for a different delta input.
325
-
326
- Notes
327
- -----
328
- - Inputs with the same label can be processed together using the `label`
329
- parameter in `sum_delta_inputs()`.
330
- - Non-callable inputs are consumed when used (removed after first use).
331
- - Callable inputs persist and can be called repeatedly.
332
-
333
- See Also
334
- --------
335
- sum_delta_inputs : Sum all delta inputs matching a given label
336
- add_current_input : Add a current input function or array
337
- """
338
- key = _input_label_repr(key, label)
339
- if self._delta_inputs is None:
340
- self._delta_inputs = dict()
341
- if key in self._delta_inputs:
342
- if id(self._delta_inputs[key]) != id(inp):
343
- raise ValueError(f'Key "{key}" has been defined and used.')
344
- self._delta_inputs[key] = inp
345
-
346
- def get_input(self, key: str):
188
+ def prefetch(self, item: str) -> 'Prefetch':
347
189
  """
348
- Get a registered input function by its key.
190
+ Create a reference to a state or variable that may not be initialized yet.
349
191
 
350
- Retrieves either a current input or a delta input function that was previously
351
- registered with the given key. This method checks both current_inputs and
352
- delta_inputs dictionaries for the specified key.
192
+ This method allows accessing module attributes or states before they are
193
+ fully defined, acting as a placeholder that will be resolved when called.
194
+ Particularly useful for creating references to variables that will be defined
195
+ during initialization or runtime.
353
196
 
354
197
  Parameters
355
198
  ----------
356
- key : str
357
- The unique identifier used when the input function was registered.
199
+ item : str
200
+ The name of the attribute or state to reference.
358
201
 
359
202
  Returns
360
203
  -------
361
- Callable or ArrayLike
362
- The input function or array associated with the given key.
363
-
364
- Raises
365
- ------
366
- ValueError
367
- If no input function is found with the specified key in either
368
- current_inputs or delta_inputs.
369
-
370
- See Also
371
- --------
372
- add_current_input : Register a current input function
373
- add_delta_input : Register a delta input function
204
+ Prefetch
205
+ A Prefetch object that provides access to the referenced item.
374
206
 
375
207
  Examples
376
208
  --------
377
- >>> model = Dynamics(10)
378
- >>> model.add_current_input('stimulus', lambda t: np.sin(t))
379
- >>> input_func = model.get_input('stimulus')
380
- >>> input_func(0.5) # Returns sin(0.5)
209
+ >>> import brainstate
210
+ >>> import brainunit as u
211
+ >>> neuron = brainstate.nn.LIF(...)
212
+ >>> v_ref = neuron.prefetch('V') # Reference to voltage
213
+ >>> v_value = v_ref() # Get current value
214
+ >>> delayed_v = v_ref.delay.at(5.0 * u.ms) # Get delayed value
381
215
  """
382
- if self._current_inputs is not None and key in self._current_inputs:
383
- return self._current_inputs[key]
384
- elif self._delta_inputs is not None and key in self._delta_inputs:
385
- return self._delta_inputs[key]
386
- else:
387
- raise ValueError(f'Input key {key} is not in current/delta inputs of the module {self}.')
216
+ return Prefetch(self, item)
388
217
 
389
- def sum_current_inputs(
390
- self,
391
- init: Any,
392
- *args,
393
- label: Optional[str] = None,
394
- **kwargs
395
- ):
218
+ def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
396
219
  """
397
- Summarize all current inputs by applying and summing all registered current input functions.
398
-
399
- This method iterates through all registered current input functions (from `.current_inputs`)
400
- and applies them to calculate the total input current for the dynamics model. It adds all results
401
- to the initial value provided.
220
+ Create a reference to a delayed state or variable in the module.
402
221
 
403
- Parameters
404
- ----------
405
- init : Any
406
- The initial value to which all current inputs will be added.
407
- *args
408
- Variable length argument list passed to each current input function.
409
- label : Optional[str], default=None
410
- If provided, only process current inputs with this label prefix.
411
- When None, process all current inputs regardless of label.
412
- **kwargs
413
- Arbitrary keyword arguments passed to each current input function.
222
+ This method simplifies the process of accessing a delayed version of a state or variable
223
+ within the module. It first creates a prefetch reference to the specified state,
224
+ then specifies the delay time for accessing this state.
414
225
 
415
- Returns
416
- -------
417
- Any
418
- The initial value plus all applicable current inputs summed together.
226
+ Args:
227
+ state (str): The name of the state or variable to reference.
228
+ delay_time (ArrayLike): The amount of time to delay the variable access,
229
+ typically in time units (e.g., milliseconds).
230
+ init (Callable, optional): An optional initialization function to provide
231
+ a default value if the delayed state is not yet available.
419
232
 
420
- Notes
421
- -----
422
- - Non-callable current inputs are applied once and then automatically removed from
423
- the current_inputs dictionary.
424
- - Callable current inputs remain registered for subsequent calls.
425
- - When a label is provided, only current inputs with keys starting with that label
426
- are applied.
427
- """
428
- if self._current_inputs is None:
429
- return init
430
- if label is None:
431
- filter_fn = lambda k: True
432
- else:
433
- label_repr = _input_label_start(label)
434
- filter_fn = lambda k: k.startswith(label_repr)
435
- for key in tuple(self._current_inputs.keys()):
436
- if filter_fn(key):
437
- out = self._current_inputs[key]
438
- if callable(out):
439
- try:
440
- init = init + out(*args, **kwargs)
441
- except Exception as e:
442
- raise ValueError(
443
- f'Error in current input value {key}: {out}\n'
444
- f'Error: {e}'
445
- ) from e
446
- else:
447
- try:
448
- init = init + out
449
- except Exception as e:
450
- raise ValueError(
451
- f'Error in current input value {key}: {out}\n'
452
- f'Error: {e}'
453
- ) from e
454
- self._current_inputs.pop(key)
455
- return init
456
-
457
- def sum_delta_inputs(
458
- self,
459
- init: Any,
460
- *args,
461
- label: Optional[str] = None,
462
- **kwargs
463
- ):
233
+ Returns:
234
+ PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
464
235
  """
465
- Summarize all delta inputs by applying and summing all registered delta input functions.
236
+ return PrefetchDelayAt(self, state, delay_time, init=init)
466
237
 
467
- This method iterates through all registered delta input functions (from `.delta_inputs`)
468
- and applies them to calculate instantaneous changes to model states. It adds all results
469
- to the initial value provided.
238
+ def output_delay(self, *delay_time) -> 'OutputDelayAt':
239
+ """
240
+ Create a reference to the delayed output of the module.
470
241
 
471
- Parameters
472
- ----------
473
- init : Any
474
- The initial value to which all delta inputs will be added.
475
- *args
476
- Variable length argument list passed to each delta input function.
477
- label : Optional[str], default=None
478
- If provided, only process delta inputs with this label prefix.
479
- When None, process all delta inputs regardless of label.
480
- **kwargs
481
- Arbitrary keyword arguments passed to each delta input function.
242
+ This method simplifies the process of accessing a delayed version of the module's output.
243
+ It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
244
+ at the specified delay time.
482
245
 
483
- Returns
484
- -------
485
- Any
486
- The initial value plus all applicable delta inputs summed together.
246
+ Args:
247
+ delay (Optional[ArrayLike]): The amount of time to delay the output access,
248
+ typically in time units (e.g., milliseconds). Defaults to None.
487
249
 
488
- Notes
489
- -----
490
- - Non-callable delta inputs are applied once and then automatically removed from
491
- the delta_inputs dictionary.
492
- - Callable delta inputs remain registered for subsequent calls.
493
- - When a label is provided, only delta inputs with keys starting with that label
494
- are applied.
250
+ Returns:
251
+ OutputDelayAt: An object that provides access to the module's output at the specified delay time.
495
252
  """
496
- if self._delta_inputs is None:
497
- return init
498
- if label is None:
499
- filter_fn = lambda k: True
500
- else:
501
- label_repr = _input_label_start(label)
502
- filter_fn = lambda k: k.startswith(label_repr)
503
- for key in tuple(self._delta_inputs.keys()):
504
- if filter_fn(key):
505
- out = self._delta_inputs[key]
506
- if callable(out):
507
- try:
508
- init = init + out(*args, **kwargs)
509
- except Exception as e:
510
- raise ValueError(
511
- f'Error in delta input function {key}: {out}\n'
512
- f'Error: {e}'
513
- ) from e
514
- else:
515
- try:
516
- init = init + out
517
- except Exception as e:
518
- raise ValueError(
519
- f'Error in delta input value {key}: {out}\n'
520
- f'Error: {e}'
521
- ) from e
522
- self._delta_inputs.pop(key)
523
- return init
253
+ return OutputDelayAt(self, delay_time)
524
254
 
525
255
  @property
526
256
  def before_updates(self):
@@ -559,7 +289,7 @@ class Dynamics(Module):
559
289
  """
560
290
  return self._after_updates
561
291
 
562
- def _add_before_update(self, key: Any, fun: Callable):
292
+ def add_before_update(self, key: Any, fun: Callable):
563
293
  """
564
294
  Register a function to be executed before the module's update.
565
295
 
@@ -585,7 +315,7 @@ class Dynamics(Module):
585
315
  raise KeyError(f'{key} has been registered in before_updates of {self}')
586
316
  self.before_updates[key] = fun
587
317
 
588
- def _add_after_update(self, key: Any, fun: Callable):
318
+ def add_after_update(self, key: Any, fun: Callable):
589
319
  """
590
320
  Register a function to be executed after the module's update.
591
321
 
@@ -611,7 +341,7 @@ class Dynamics(Module):
611
341
  raise KeyError(f'{key} has been registered in after_updates of {self}')
612
342
  self.after_updates[key] = fun
613
343
 
614
- def _get_before_update(self, key: Any):
344
+ def get_before_update(self, key: Any):
615
345
  """
616
346
  Retrieve a registered before-update function by its key.
617
347
 
@@ -636,7 +366,7 @@ class Dynamics(Module):
636
366
  raise KeyError(f'{key} is not registered in before_updates of {self}')
637
367
  return self.before_updates.get(key)
638
368
 
639
- def _get_after_update(self, key: Any):
369
+ def get_after_update(self, key: Any):
640
370
  """
641
371
  Retrieve a registered after-update function by its key.
642
372
 
@@ -661,7 +391,7 @@ class Dynamics(Module):
661
391
  raise KeyError(f'{key} is not registered in after_updates of {self}')
662
392
  return self.after_updates.get(key)
663
393
 
664
- def _has_before_update(self, key: Any):
394
+ def has_before_update(self, key: Any):
665
395
  """
666
396
  Check if a before-update function is registered with the given key.
667
397
 
@@ -679,7 +409,7 @@ class Dynamics(Module):
679
409
  return False
680
410
  return key in self.before_updates
681
411
 
682
- def _has_after_update(self, key: Any):
412
+ def has_after_update(self, key: Any):
683
413
  """
684
414
  Check if an after-update function is registered with the given key.
685
415
 
@@ -722,120 +452,6 @@ class Dynamics(Module):
722
452
  model(ret)
723
453
  return ret
724
454
 
725
- def prefetch(self, item: str) -> 'Prefetch':
726
- """
727
- Create a reference to a state or variable that may not be initialized yet.
728
-
729
- This method allows accessing module attributes or states before they are
730
- fully defined, acting as a placeholder that will be resolved when called.
731
- Particularly useful for creating references to variables that will be defined
732
- during initialization or runtime.
733
-
734
- Parameters
735
- ----------
736
- item : str
737
- The name of the attribute or state to reference.
738
-
739
- Returns
740
- -------
741
- Prefetch
742
- A Prefetch object that provides access to the referenced item.
743
-
744
- Examples
745
- --------
746
- >>> import brainstate
747
- >>> import brainunit as u
748
- >>> neuron = brainstate.nn.LIF(...)
749
- >>> v_ref = neuron.prefetch('V') # Reference to voltage
750
- >>> v_value = v_ref() # Get current value
751
- >>> delayed_v = v_ref.delay.at(5.0 * u.ms) # Get delayed value
752
- """
753
- return Prefetch(self, item)
754
-
755
- def align_pre(self, dyn: Union[ParamDescriber[T], T]) -> T:
756
- """
757
- Registers a dynamics module to execute after this module.
758
-
759
- This method establishes a sequential execution relationship where the specified
760
- dynamics module will be called after this module completes its update. This
761
- creates a feed-forward connection in the computational graph.
762
-
763
- Parameters
764
- ----------
765
- dyn : Union[ParamDescriber[T], T]
766
- The dynamics module to be executed after this module. Can be either:
767
- - An instance of Dynamics
768
- - A ParamDescriber that can instantiate a Dynamics object
769
-
770
- Returns
771
- -------
772
- T
773
- The dynamics module that was registered, allowing for method chaining.
774
-
775
- Raises
776
- ------
777
- TypeError
778
- If the input is not a Dynamics instance or a ParamDescriber that creates
779
- a Dynamics instance.
780
-
781
- Examples
782
- --------
783
- >>> import brainstate
784
- >>> n1 = brainstate.nn.LIF(10)
785
- >>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1
786
- """
787
- if isinstance(dyn, Dynamics):
788
- self._add_after_update(id(dyn), dyn)
789
- return dyn
790
- elif isinstance(dyn, ParamDescriber):
791
- if not issubclass(dyn.cls, Dynamics):
792
- raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
793
- if not self._has_after_update(dyn.identifier):
794
- self._add_after_update(
795
- dyn.identifier,
796
- dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
797
- )
798
- return self._get_after_update(dyn.identifier)
799
- else:
800
- raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
801
-
802
- def prefetch_delay(self, state: str, delay_time, init: Callable = None) -> 'PrefetchDelayAt':
803
- """
804
- Create a reference to a delayed state or variable in the module.
805
-
806
- This method simplifies the process of accessing a delayed version of a state or variable
807
- within the module. It first creates a prefetch reference to the specified state,
808
- then specifies the delay time for accessing this state.
809
-
810
- Args:
811
- state (str): The name of the state or variable to reference.
812
- delay_time (ArrayLike): The amount of time to delay the variable access,
813
- typically in time units (e.g., milliseconds).
814
- init (Callable, optional): An optional initialization function to provide
815
- a default value if the delayed state is not yet available.
816
-
817
- Returns:
818
- PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
819
- """
820
- return PrefetchDelayAt(self, state, delay_time, init=init)
821
-
822
- def output_delay(self, *delay_time) -> 'OutputDelayAt':
823
- """
824
- Create a reference to the delayed output of the module.
825
-
826
- This method simplifies the process of accessing a delayed version of the module's output.
827
- It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
828
- at the specified delay time.
829
-
830
- Args:
831
- delay (Optional[ArrayLike]): The amount of time to delay the output access,
832
- typically in time units (e.g., milliseconds). Defaults to None.
833
-
834
- Returns:
835
- OutputDelayAt: An object that provides access to the module's output at the specified delay time.
836
- """
837
- return OutputDelayAt(self, delay_time)
838
-
839
455
 
840
456
  class Prefetch(Node):
841
457
  """
@@ -1047,14 +663,14 @@ class PrefetchDelayAt(Node):
1047
663
  self.delay_time = delay_time
1048
664
  if len(delay_time) > 0:
1049
665
  key = _get_prefetch_delay_key(item)
1050
- if not module._has_after_update(key):
1051
- module._add_after_update(
666
+ if not module.has_after_update(key):
667
+ module.add_after_update(
1052
668
  key,
1053
669
  not_receive_update_output(
1054
670
  StateWithDelay(module, item, init=init)
1055
671
  )
1056
672
  )
1057
- self.state_delay: StateWithDelay = module._get_after_update(key)
673
+ self.state_delay: StateWithDelay = module.get_after_update(key)
1058
674
  self.delay_info = self.state_delay.register_delay(*delay_time)
1059
675
 
1060
676
  def __call__(self, *args, **kwargs):
@@ -1108,10 +724,10 @@ class OutputDelayAt(Node):
1108
724
  assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
1109
725
  self.module = module
1110
726
  key = _get_output_delay_key()
1111
- if not module._has_after_update(key):
727
+ if not module.has_after_update(key):
1112
728
  delay = Delay(jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()), take_aware_unit=True)
1113
- module._add_after_update(key, receive_update_output(delay))
1114
- self.out_delay: Delay = module._get_after_update(key)
729
+ module.add_after_update(key, receive_update_output(delay))
730
+ self.out_delay: Delay = module.get_after_update(key)
1115
731
  self.delay_info = self.out_delay.register_delay(*delay_time)
1116
732
 
1117
733
  def __call__(self, *args, **kwargs):
@@ -1138,7 +754,7 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
1138
754
  f'The target module should be an instance '
1139
755
  f'of Dynamics. But got {target.module}.'
1140
756
  )
1141
- delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
757
+ delay = target.module.get_after_update(_get_prefetch_delay_key(target.item))
1142
758
  if not isinstance(delay, StateWithDelay):
1143
759
  raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
1144
760
  f'its delay. But got {delay}.')
@@ -1187,9 +803,6 @@ def maybe_init_prefetch(target, *args, **kwargs):
1187
803
  # delay.register_delay(*target.delay_time)
1188
804
 
1189
805
 
1190
- DynamicsGroup = Module
1191
-
1192
-
1193
806
  def receive_update_output(cls: object):
1194
807
  """
1195
808
  The decorator to mark the object (as the after updates) to receive the output of the update function.
@@ -1255,13 +868,3 @@ def not_receive_update_input(cls: object):
1255
868
  if hasattr(cls, '_receive_update_input'):
1256
869
  delattr(cls, '_receive_update_input')
1257
870
  return cls
1258
-
1259
-
1260
- def _input_label_start(label: str):
1261
- # unify the input label repr.
1262
- return f'{label} // '
1263
-
1264
-
1265
- def _input_label_repr(name: str, label: Optional[str] = None):
1266
- # unify the input label repr.
1267
- return name if label is None else (_input_label_start(label) + str(name))