brainstate 0.1.2__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +0 -15
  3. brainstate/compile/_jit.py +14 -5
  4. brainstate/compile/_make_jaxpr.py +78 -22
  5. brainstate/compile/_make_jaxpr_test.py +13 -2
  6. brainstate/graph/_graph_node.py +1 -1
  7. brainstate/graph/_graph_operation.py +4 -4
  8. brainstate/mixin.py +30 -14
  9. brainstate/nn/__init__.py +84 -17
  10. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  11. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +19 -3
  12. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +6 -5
  13. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +137 -21
  14. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  15. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  16. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob.py} +96 -25
  17. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  18. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  19. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +2 -2
  20. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  21. brainstate/nn/_module.py +5 -5
  22. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  23. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  24. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  25. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +1 -1
  26. brainstate/nn/_projection.py +486 -0
  27. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  28. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  29. brainstate/nn/_stp.py +236 -0
  30. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +19 -212
  31. brainstate/nn/_synaptic_projection.py +423 -0
  32. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  33. brainstate/surrogate.py +1 -1
  34. brainstate/typing.py +1 -1
  35. brainstate/util/__init__.py +14 -14
  36. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  37. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
  38. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/RECORD +61 -63
  39. brainstate/nn/_dyn_impl/__init__.py +0 -42
  40. brainstate/nn/_dynamics/__init__.py +0 -37
  41. brainstate/nn/_dynamics/_projection_base.py +0 -362
  42. brainstate/nn/_elementwise/__init__.py +0 -22
  43. brainstate/nn/_interaction/__init__.py +0 -41
  44. /brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +0 -0
  45. /brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +0 -0
  46. /brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +0 -0
  47. /brainstate/nn/{_elementwise/_elementwise_test.py → _elementwise_test.py} +0 -0
  48. /brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  49. /brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -0
  50. /brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +0 -0
  51. /brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +0 -0
  52. /brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +0 -0
  53. /brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +0 -0
  54. /brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +0 -0
  55. /brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +0 -0
  56. /brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +0 -0
  57. /brainstate/util/{_caller.py → caller.py} +0 -0
  58. /brainstate/util/{_error.py → error.py} +0 -0
  59. /brainstate/util/{_others.py → others.py} +0 -0
  60. /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  61. /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  62. /brainstate/util/{_scaling.py → scaling.py} +0 -0
  63. /brainstate/util/{_struct.py → struct.py} +0 -0
  64. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
  65. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
  66. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
brainstate/nn/__init__.py CHANGED
@@ -19,46 +19,113 @@ from ._collective_ops import *
19
19
  from ._collective_ops import __all__ as collective_ops_all
20
20
  from ._common import *
21
21
  from ._common import __all__ as common_all
22
- from ._dyn_impl import *
23
- from ._dyn_impl import __all__ as dyn_impl_all
22
+ from ._conv import *
23
+ from ._conv import __all__ as conv_all
24
+ from ._delay import *
25
+ from ._delay import __all__ as state_delay_all
26
+ from ._dropout import *
27
+ from ._dropout import __all__ as dropout_all
24
28
  from ._dynamics import *
25
- from ._dynamics import __all__ as dynamics_all
29
+ from ._dynamics import __all__ as dyn_all
26
30
  from ._elementwise import *
27
31
  from ._elementwise import __all__ as elementwise_all
28
- from ._event import *
29
- from ._event import __all__ as _event_all
32
+ from ._embedding import *
33
+ from ._embedding import __all__ as embed_all
30
34
  from ._exp_euler import *
31
35
  from ._exp_euler import __all__ as exp_euler_all
32
- from ._interaction import *
33
- from ._interaction import __all__ as interaction_all
36
+ from ._fixedprob import *
37
+ from._fixedprob import __all__ as fixedprob_all
38
+ from ._inputs import *
39
+ from ._inputs import __all__ as inputs_all
40
+ from ._linear import *
41
+ from ._linear import __all__ as linear_all
42
+ from ._linear_mv import *
43
+ from ._linear_mv import __all__ as linear_mv_all
44
+ from ._ltp import *
45
+ from ._ltp import __all__ as ltp_all
34
46
  from ._module import *
35
47
  from ._module import __all__ as module_all
48
+ from ._neuron import *
49
+ from ._neuron import __all__ as dyn_neuron_all
50
+ from ._normalizations import *
51
+ from ._normalizations import __all__ as normalizations_all
52
+ from ._poolings import *
53
+ from ._poolings import __all__ as poolings_all
54
+ from ._projection import *
55
+ from ._projection import __all__ as projection_all
56
+ from ._rate_rnns import *
57
+ from ._rate_rnns import __all__ as rate_rnns
58
+ from ._readout import *
59
+ from ._readout import __all__ as readout_all
60
+ from ._stp import *
61
+ from ._stp import __all__ as stp_all
62
+ from ._synapse import *
63
+ from ._synapse import __all__ as dyn_synapse_all
64
+ from ._synaptic_projection import *
65
+ from ._synaptic_projection import __all__ as _syn_proj_all
66
+ from ._synouts import *
67
+ from ._synouts import __all__ as synouts_all
36
68
  from ._utils import *
37
69
  from ._utils import __all__ as utils_all
38
70
 
39
71
  __all__ = (
40
- ['metrics']
72
+ [
73
+ 'metrics',
74
+ ]
41
75
  + collective_ops_all
42
76
  + common_all
43
- + dyn_impl_all
44
- + dynamics_all
45
77
  + elementwise_all
46
78
  + module_all
47
79
  + exp_euler_all
48
- + interaction_all
49
80
  + utils_all
50
- + _event_all
81
+ + dyn_all
82
+ + projection_all
83
+ + state_delay_all
84
+ + synouts_all
85
+ + conv_all
86
+ + linear_all
87
+ + normalizations_all
88
+ + poolings_all
89
+ + fixedprob_all
90
+ + linear_mv_all
91
+ + embed_all
92
+ + dropout_all
93
+ + elementwise_all
94
+ + dyn_neuron_all
95
+ + dyn_synapse_all
96
+ + inputs_all
97
+ + rate_rnns
98
+ + readout_all
99
+ + stp_all
100
+ + ltp_all
101
+ + _syn_proj_all
51
102
  )
52
103
 
53
104
  del (
54
105
  collective_ops_all,
55
106
  common_all,
56
- dyn_impl_all,
57
- dynamics_all,
58
- elementwise_all,
59
107
  module_all,
60
108
  exp_euler_all,
61
- interaction_all,
62
109
  utils_all,
63
- _event_all,
110
+ dyn_all,
111
+ projection_all,
112
+ state_delay_all,
113
+ synouts_all,
114
+ conv_all,
115
+ linear_all,
116
+ normalizations_all,
117
+ poolings_all,
118
+ embed_all,
119
+ fixedprob_all,
120
+ linear_mv_all,
121
+ dropout_all,
122
+ elementwise_all,
123
+ dyn_neuron_all,
124
+ dyn_synapse_all,
125
+ inputs_all,
126
+ readout_all,
127
+ rate_rnns,
128
+ stp_all,
129
+ ltp_all,
130
+ _syn_proj_all,
64
131
  )
@@ -23,8 +23,8 @@ import jax.numpy as jnp
23
23
 
24
24
  from brainstate import init, functional
25
25
  from brainstate._state import ParamState
26
- from brainstate.nn._module import Module
27
26
  from brainstate.typing import ArrayLike
27
+ from ._module import Module
28
28
 
29
29
  T = TypeVar('T')
30
30
 
@@ -27,9 +27,9 @@ from brainstate import environ
27
27
  from brainstate._state import ShortTermState, State
28
28
  from brainstate.compile import jit_error_if
29
29
  from brainstate.graph import Node
30
- from brainstate.nn._collective_ops import call_order
31
- from brainstate.nn._module import Module
32
30
  from brainstate.typing import ArrayLike, PyTree
31
+ from ._collective_ops import call_order
32
+ from ._module import Module
33
33
 
34
34
  __all__ = [
35
35
  'Delay', 'DelayAccess', 'StateWithDelay',
@@ -135,6 +135,7 @@ class Delay(Module):
135
135
  entries: Optional[Dict] = None, # delay access entry
136
136
  delay_method: Optional[str] = _DELAY_ROTATE, # delay method
137
137
  interp_method: str = _INTERP_LINEAR, # interpolation method
138
+ take_aware_unit: bool = False
138
139
  ):
139
140
  # target information
140
141
  self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
@@ -170,6 +171,9 @@ class Delay(Module):
170
171
  for entry, delay_time in entries.items():
171
172
  self.register_entry(entry, delay_time)
172
173
 
174
+ self.take_aware_unit = take_aware_unit
175
+ self._unit = None
176
+
173
177
  @property
174
178
  def history(self):
175
179
  return self._history
@@ -326,7 +330,14 @@ class Delay(Module):
326
330
  indices = (delay_idx,) + indices
327
331
 
328
332
  # the delay data
329
- return jax.tree.map(lambda a: a[indices], self.history.value)
333
+ if self._unit is None:
334
+ return jax.tree.map(lambda a: a[indices], self.history.value)
335
+ else:
336
+ return jax.tree.map(
337
+ lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
338
+ self.history.value,
339
+ self._unit
340
+ )
330
341
 
331
342
  def retrieve_at_time(self, delay_time, *indices) -> PyTree:
332
343
  """
@@ -389,6 +400,9 @@ class Delay(Module):
389
400
  """
390
401
  assert self.history is not None, 'The delay history is not initialized.'
391
402
 
403
+ if self.take_aware_unit and self._unit is None:
404
+ self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)
405
+
392
406
  # update the delay data at the rotation index
393
407
  if self.delay_method == _DELAY_ROTATE:
394
408
  i = environ.get(environ.I)
@@ -415,6 +429,8 @@ class Delay(Module):
415
429
  raise ValueError(f'Unknown updating method "{self.delay_method}"')
416
430
 
417
431
 
432
+
433
+
418
434
  class StateWithDelay(Delay):
419
435
  """
420
436
  A ``State`` type that defines the state in a differential equation.
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import random, environ, init
24
24
  from brainstate._state import ShortTermState
25
- from brainstate.nn._module import ElementWiseBlock
26
25
  from brainstate.typing import Size
26
+ from ._module import ElementWiseBlock
27
27
 
28
28
  __all__ = [
29
29
  'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
@@ -409,7 +409,8 @@ class DropoutFixed(ElementWiseBlock):
409
409
  self.out_size = in_size
410
410
 
411
411
  def init_state(self, batch_size=None, **kwargs):
412
- self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
412
+ if self.prob < 1.:
413
+ self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
413
414
 
414
415
  def update(self, x):
415
416
  dtype = u.math.get_dtype(x)
@@ -418,8 +419,8 @@ class DropoutFixed(ElementWiseBlock):
418
419
  if self.mask.value.shape != x.shape:
419
420
  raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
420
421
  f"Please call `init_state()` method first.")
421
- return jnp.where(self.mask.value,
422
- jnp.asarray(x / self.prob, dtype=dtype),
423
- jnp.asarray(0., dtype=dtype))
422
+ return u.math.where(self.mask.value,
423
+ u.math.asarray(x / self.prob, dtype=dtype),
424
+ u.math.asarray(0., dtype=dtype) * u.get_unit(x))
424
425
  else:
425
426
  return x
@@ -36,18 +36,20 @@ For handling the delays:
36
36
  from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
37
37
 
38
38
  import brainunit as u
39
+ import jax
39
40
  import numpy as np
40
41
 
41
42
  from brainstate import environ
42
43
  from brainstate._state import State
43
44
  from brainstate.graph import Node
44
- from brainstate.mixin import ParamDescriber
45
- from brainstate.nn._module import Module
46
- from brainstate.typing import Size, ArrayLike
47
- from ._state_delay import StateWithDelay, Delay
45
+ from brainstate.mixin import ParamDescriber, UpdateReturn
46
+ from brainstate.typing import Size, ArrayLike, PyTree
47
+ from ._delay import StateWithDelay, Delay
48
+ from ._module import Module
48
49
 
49
50
  __all__ = [
50
- 'DynamicsGroup', 'Projection', 'Dynamics', 'Prefetch',
51
+ 'DynamicsGroup', 'Projection', 'Dynamics',
52
+ 'Prefetch', 'PrefetchDelay', 'PrefetchDelayAt', 'OutputDelayAt',
51
53
  ]
52
54
 
53
55
  T = TypeVar('T')
@@ -99,7 +101,7 @@ class Projection(Module):
99
101
  raise ValueError('Do not implement the update() function.')
100
102
 
101
103
 
102
- class Dynamics(Module):
104
+ class Dynamics(Module, UpdateReturn):
103
105
  """
104
106
  Base class for implementing neural dynamics models in BrainState.
105
107
 
@@ -810,17 +812,64 @@ class Dynamics(Module):
810
812
  >>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1
811
813
  """
812
814
  if isinstance(dyn, Dynamics):
813
- self._add_after_update(dyn.name, dyn)
815
+ self._add_after_update(id(dyn), dyn)
814
816
  return dyn
815
817
  elif isinstance(dyn, ParamDescriber):
816
818
  if not issubclass(dyn.cls, Dynamics):
817
819
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
818
820
  if not self._has_after_update(dyn.identifier):
819
- self._add_after_update(dyn.identifier, dyn())
821
+ self._add_after_update(
822
+ dyn.identifier,
823
+ dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
824
+ )
820
825
  return self._get_after_update(dyn.identifier)
821
826
  else:
822
827
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
823
828
 
829
+ def prefetch_delay(
830
+ self,
831
+ state: str,
832
+ delay: Optional[ArrayLike] = None
833
+ ) -> 'PrefetchDelayAt':
834
+ """
835
+ Create a reference to a delayed state or variable in the module.
836
+
837
+ This method simplifies the process of accessing a delayed version of a state or variable
838
+ within the module. It first creates a prefetch reference to the specified state,
839
+ then specifies the delay time for accessing this state.
840
+
841
+ Args:
842
+ state (str): The name of the state or variable to reference.
843
+ delay (Optional[ArrayLike]): The amount of time to delay the variable access,
844
+ typically in time units (e.g., milliseconds). Defaults to None.
845
+
846
+ Returns:
847
+ PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
848
+ """
849
+ return self.prefetch(state).delay.at(delay)
850
+
851
+ def output_delay(
852
+ self,
853
+ delay: Optional[ArrayLike] = None,
854
+ variable_like: PyTree = None
855
+ ) -> 'OutputDelayAt':
856
+ """
857
+ Create a reference to the delayed output of the module.
858
+
859
+ This method simplifies the process of accessing a delayed version of the module's output.
860
+ It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
861
+ at the specified delay time.
862
+
863
+ Args:
864
+ delay (Optional[ArrayLike]): The amount of time to delay the output access,
865
+ typically in time units (e.g., milliseconds). Defaults to None.
866
+ variable_like:
867
+
868
+ Returns:
869
+ OutputDelayAt: An object that provides access to the module's output at the specified delay time.
870
+ """
871
+ return OutputDelayAt(self, delay)
872
+
824
873
 
825
874
  class Prefetch(Node):
826
875
  """
@@ -885,6 +934,7 @@ class Prefetch(Node):
885
934
  An object that provides access to delayed versions of the prefetched item.
886
935
  """
887
936
  return PrefetchDelay(self.module, self.item)
937
+ # return PrefetchDelayAt(self.module, self.item, time)
888
938
 
889
939
  def __call__(self, *args, **kwargs):
890
940
  """
@@ -1007,7 +1057,7 @@ class PrefetchDelayAt(Node):
1007
1057
  self,
1008
1058
  module: Dynamics,
1009
1059
  item: str,
1010
- time: ArrayLike
1060
+ time: ArrayLike = None,
1011
1061
  ):
1012
1062
  """
1013
1063
  Initialize a PrefetchDelayAt object.
@@ -1026,14 +1076,16 @@ class PrefetchDelayAt(Node):
1026
1076
  self.module = module
1027
1077
  self.item = item
1028
1078
  self.time = time
1029
- self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1030
1079
 
1031
- # register the delay
1032
- key = _get_delay_key(item)
1033
- if not module._has_after_update(key):
1034
- module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
1035
- self.state_delay: StateWithDelay = module._get_after_update(key)
1036
- self.state_delay.register_delay(time)
1080
+ if time is not None:
1081
+ self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1082
+
1083
+ # register the delay
1084
+ key = _get_prefetch_delay_key(item)
1085
+ if not module._has_after_update(key):
1086
+ module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
1087
+ self.state_delay: StateWithDelay = module._get_after_update(key)
1088
+ self.state_delay.register_delay(time)
1037
1089
 
1038
1090
  def __call__(self, *args, **kwargs):
1039
1091
  """
@@ -1044,12 +1096,76 @@ class PrefetchDelayAt(Node):
1044
1096
  Any
1045
1097
  The value of the state or variable at the specified delay time.
1046
1098
  """
1047
- # return self.state_delay.retrieve_at_time(self.time)
1048
- return self.state_delay.retrieve_at_step(self.step)
1099
+ if self.time is None:
1100
+ return _get_prefetch_item(self).value
1101
+ else:
1102
+ return self.state_delay.retrieve_at_step(self.step)
1103
+
1104
+
1105
+ class OutputDelayAt(Node):
1106
+ """
1107
+ Provides access to a specific delayed state or variable value at the specific time.
1108
+
1109
+ This class represents the final step in the prefetch delay chain, providing
1110
+ actual access to state values at a specific delay time. It converts the
1111
+ specified time delay into steps and registers the delay with the appropriate
1112
+ StateWithDelay handler.
1113
+
1114
+ Parameters
1115
+ ----------
1116
+ module : Dynamics
1117
+ The dynamics module that contains the referenced state or variable.
1118
+ time : ArrayLike
1119
+ The amount of time to delay access by, typically in time units (e.g., milliseconds).
1120
+
1121
+ Examples
1122
+ --------
1123
+ >>> import brainstate
1124
+ >>> import brainunit as u
1125
+ >>> neuron = brainstate.nn.LIF(10)
1126
+ >>> # Create a reference to voltage delayed by 5ms
1127
+ >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
1128
+ >>> # Get the delayed value
1129
+ >>> v_value = delayed_spike()
1130
+ """
1131
+
1132
+ def __init__(
1133
+ self,
1134
+ module: Dynamics,
1135
+ time: Optional[ArrayLike] = None,
1136
+ variable_like: Optional[PyTree] = None,
1137
+ ):
1138
+ super().__init__()
1139
+ assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
1140
+ self.module = module
1141
+ dt = environ.get_dt()
1142
+ if time is None:
1143
+ time = u.math.zeros_like(dt)
1144
+ self.time = time
1145
+ self.step = u.math.asarray(time / dt, dtype=environ.ditype())
1146
+
1147
+ # register the delay
1148
+ key = _get_output_delay_key()
1149
+ if not module._has_after_update(key):
1150
+ delay = Delay(
1151
+ jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()),
1152
+ time,
1153
+ take_aware_unit=True
1154
+ )
1155
+ module._add_after_update(key, receive_update_output(delay))
1156
+ self.out_delay: Delay = module._get_after_update(key)
1157
+ self.out_delay.register_delay(time)
1158
+
1159
+ def __call__(self, *args, **kwargs):
1160
+ return self.out_delay.retrieve_at_step(self.step)
1161
+
1162
+
1163
+ def _get_prefetch_delay_key(item) -> str:
1164
+ return f'{item}-prefetch-delay'
1049
1165
 
1050
1166
 
1051
- def _get_delay_key(item) -> str:
1052
- return f'{item}-delay'
1167
+ def _get_output_delay_key() -> str:
1168
+ return f'output-delay'
1053
1169
 
1054
1170
 
1055
1171
  def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
@@ -1064,7 +1180,7 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
1064
1180
  f'The target module should be an instance '
1065
1181
  f'of Dynamics. But got {target.module}.'
1066
1182
  )
1067
- delay = target.module._get_after_update(_get_delay_key(target.item))
1183
+ delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
1068
1184
  if not isinstance(delay, StateWithDelay):
1069
1185
  raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
1070
1186
  f'its delay. But got {delay}.')
@@ -23,8 +23,8 @@ import jax.typing
23
23
 
24
24
  from brainstate import random, functional as F
25
25
  from brainstate._state import ParamState
26
- from brainstate.nn._module import ElementWiseBlock
27
26
  from brainstate.typing import ArrayLike
27
+ from ._module import ElementWiseBlock
28
28
 
29
29
  __all__ = [
30
30
  # activation functions
@@ -17,8 +17,8 @@ from typing import Optional, Callable, Union
17
17
 
18
18
  from brainstate import init
19
19
  from brainstate._state import ParamState
20
- from brainstate.nn._module import Module
21
20
  from brainstate.typing import ArrayLike
21
+ from ._module import Module
22
22
 
23
23
  __all__ = [
24
24
  'Embedding',