brainstate 0.1.2__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +0 -15
- brainstate/compile/_jit.py +14 -5
- brainstate/compile/_make_jaxpr.py +78 -22
- brainstate/compile/_make_jaxpr_test.py +13 -2
- brainstate/graph/_graph_node.py +1 -1
- brainstate/graph/_graph_operation.py +4 -4
- brainstate/mixin.py +30 -14
- brainstate/nn/__init__.py +84 -17
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +19 -3
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +6 -5
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +137 -21
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob.py} +96 -25
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +2 -2
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/_module.py +5 -5
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +1 -1
- brainstate/nn/_projection.py +486 -0
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +19 -212
- brainstate/nn/_synaptic_projection.py +423 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- brainstate/surrogate.py +1 -1
- brainstate/typing.py +1 -1
- brainstate/util/__init__.py +14 -14
- brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/RECORD +61 -63
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_dynamics/_projection_base.py +0 -362
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_interaction/__init__.py +0 -41
- /brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +0 -0
- /brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +0 -0
- /brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +0 -0
- /brainstate/nn/{_elementwise/_elementwise_test.py → _elementwise_test.py} +0 -0
- /brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
- /brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -0
- /brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +0 -0
- /brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +0 -0
- /brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +0 -0
- /brainstate/util/{_caller.py → caller.py} +0 -0
- /brainstate/util/{_error.py → error.py} +0 -0
- /brainstate/util/{_others.py → others.py} +0 -0
- /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
- /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
- /brainstate/util/{_scaling.py → scaling.py} +0 -0
- /brainstate/util/{_struct.py → struct.py} +0 -0
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
brainstate/nn/__init__.py
CHANGED
@@ -19,46 +19,113 @@ from ._collective_ops import *
|
|
19
19
|
from ._collective_ops import __all__ as collective_ops_all
|
20
20
|
from ._common import *
|
21
21
|
from ._common import __all__ as common_all
|
22
|
-
from .
|
23
|
-
from .
|
22
|
+
from ._conv import *
|
23
|
+
from ._conv import __all__ as conv_all
|
24
|
+
from ._delay import *
|
25
|
+
from ._delay import __all__ as state_delay_all
|
26
|
+
from ._dropout import *
|
27
|
+
from ._dropout import __all__ as dropout_all
|
24
28
|
from ._dynamics import *
|
25
|
-
from ._dynamics import __all__ as
|
29
|
+
from ._dynamics import __all__ as dyn_all
|
26
30
|
from ._elementwise import *
|
27
31
|
from ._elementwise import __all__ as elementwise_all
|
28
|
-
from .
|
29
|
-
from .
|
32
|
+
from ._embedding import *
|
33
|
+
from ._embedding import __all__ as embed_all
|
30
34
|
from ._exp_euler import *
|
31
35
|
from ._exp_euler import __all__ as exp_euler_all
|
32
|
-
from .
|
33
|
-
from
|
36
|
+
from ._fixedprob import *
|
37
|
+
from._fixedprob import __all__ as fixedprob_all
|
38
|
+
from ._inputs import *
|
39
|
+
from ._inputs import __all__ as inputs_all
|
40
|
+
from ._linear import *
|
41
|
+
from ._linear import __all__ as linear_all
|
42
|
+
from ._linear_mv import *
|
43
|
+
from ._linear_mv import __all__ as linear_mv_all
|
44
|
+
from ._ltp import *
|
45
|
+
from ._ltp import __all__ as ltp_all
|
34
46
|
from ._module import *
|
35
47
|
from ._module import __all__ as module_all
|
48
|
+
from ._neuron import *
|
49
|
+
from ._neuron import __all__ as dyn_neuron_all
|
50
|
+
from ._normalizations import *
|
51
|
+
from ._normalizations import __all__ as normalizations_all
|
52
|
+
from ._poolings import *
|
53
|
+
from ._poolings import __all__ as poolings_all
|
54
|
+
from ._projection import *
|
55
|
+
from ._projection import __all__ as projection_all
|
56
|
+
from ._rate_rnns import *
|
57
|
+
from ._rate_rnns import __all__ as rate_rnns
|
58
|
+
from ._readout import *
|
59
|
+
from ._readout import __all__ as readout_all
|
60
|
+
from ._stp import *
|
61
|
+
from ._stp import __all__ as stp_all
|
62
|
+
from ._synapse import *
|
63
|
+
from ._synapse import __all__ as dyn_synapse_all
|
64
|
+
from ._synaptic_projection import *
|
65
|
+
from ._synaptic_projection import __all__ as _syn_proj_all
|
66
|
+
from ._synouts import *
|
67
|
+
from ._synouts import __all__ as synouts_all
|
36
68
|
from ._utils import *
|
37
69
|
from ._utils import __all__ as utils_all
|
38
70
|
|
39
71
|
__all__ = (
|
40
|
-
[
|
72
|
+
[
|
73
|
+
'metrics',
|
74
|
+
]
|
41
75
|
+ collective_ops_all
|
42
76
|
+ common_all
|
43
|
-
+ dyn_impl_all
|
44
|
-
+ dynamics_all
|
45
77
|
+ elementwise_all
|
46
78
|
+ module_all
|
47
79
|
+ exp_euler_all
|
48
|
-
+ interaction_all
|
49
80
|
+ utils_all
|
50
|
-
+
|
81
|
+
+ dyn_all
|
82
|
+
+ projection_all
|
83
|
+
+ state_delay_all
|
84
|
+
+ synouts_all
|
85
|
+
+ conv_all
|
86
|
+
+ linear_all
|
87
|
+
+ normalizations_all
|
88
|
+
+ poolings_all
|
89
|
+
+ fixedprob_all
|
90
|
+
+ linear_mv_all
|
91
|
+
+ embed_all
|
92
|
+
+ dropout_all
|
93
|
+
+ elementwise_all
|
94
|
+
+ dyn_neuron_all
|
95
|
+
+ dyn_synapse_all
|
96
|
+
+ inputs_all
|
97
|
+
+ rate_rnns
|
98
|
+
+ readout_all
|
99
|
+
+ stp_all
|
100
|
+
+ ltp_all
|
101
|
+
+ _syn_proj_all
|
51
102
|
)
|
52
103
|
|
53
104
|
del (
|
54
105
|
collective_ops_all,
|
55
106
|
common_all,
|
56
|
-
dyn_impl_all,
|
57
|
-
dynamics_all,
|
58
|
-
elementwise_all,
|
59
107
|
module_all,
|
60
108
|
exp_euler_all,
|
61
|
-
interaction_all,
|
62
109
|
utils_all,
|
63
|
-
|
110
|
+
dyn_all,
|
111
|
+
projection_all,
|
112
|
+
state_delay_all,
|
113
|
+
synouts_all,
|
114
|
+
conv_all,
|
115
|
+
linear_all,
|
116
|
+
normalizations_all,
|
117
|
+
poolings_all,
|
118
|
+
embed_all,
|
119
|
+
fixedprob_all,
|
120
|
+
linear_mv_all,
|
121
|
+
dropout_all,
|
122
|
+
elementwise_all,
|
123
|
+
dyn_neuron_all,
|
124
|
+
dyn_synapse_all,
|
125
|
+
inputs_all,
|
126
|
+
readout_all,
|
127
|
+
rate_rnns,
|
128
|
+
stp_all,
|
129
|
+
ltp_all,
|
130
|
+
_syn_proj_all,
|
64
131
|
)
|
@@ -23,8 +23,8 @@ import jax.numpy as jnp
|
|
23
23
|
|
24
24
|
from brainstate import init, functional
|
25
25
|
from brainstate._state import ParamState
|
26
|
-
from brainstate.nn._module import Module
|
27
26
|
from brainstate.typing import ArrayLike
|
27
|
+
from ._module import Module
|
28
28
|
|
29
29
|
T = TypeVar('T')
|
30
30
|
|
@@ -27,9 +27,9 @@ from brainstate import environ
|
|
27
27
|
from brainstate._state import ShortTermState, State
|
28
28
|
from brainstate.compile import jit_error_if
|
29
29
|
from brainstate.graph import Node
|
30
|
-
from brainstate.nn._collective_ops import call_order
|
31
|
-
from brainstate.nn._module import Module
|
32
30
|
from brainstate.typing import ArrayLike, PyTree
|
31
|
+
from ._collective_ops import call_order
|
32
|
+
from ._module import Module
|
33
33
|
|
34
34
|
__all__ = [
|
35
35
|
'Delay', 'DelayAccess', 'StateWithDelay',
|
@@ -135,6 +135,7 @@ class Delay(Module):
|
|
135
135
|
entries: Optional[Dict] = None, # delay access entry
|
136
136
|
delay_method: Optional[str] = _DELAY_ROTATE, # delay method
|
137
137
|
interp_method: str = _INTERP_LINEAR, # interpolation method
|
138
|
+
take_aware_unit: bool = False
|
138
139
|
):
|
139
140
|
# target information
|
140
141
|
self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
|
@@ -170,6 +171,9 @@ class Delay(Module):
|
|
170
171
|
for entry, delay_time in entries.items():
|
171
172
|
self.register_entry(entry, delay_time)
|
172
173
|
|
174
|
+
self.take_aware_unit = take_aware_unit
|
175
|
+
self._unit = None
|
176
|
+
|
173
177
|
@property
|
174
178
|
def history(self):
|
175
179
|
return self._history
|
@@ -326,7 +330,14 @@ class Delay(Module):
|
|
326
330
|
indices = (delay_idx,) + indices
|
327
331
|
|
328
332
|
# the delay data
|
329
|
-
|
333
|
+
if self._unit is None:
|
334
|
+
return jax.tree.map(lambda a: a[indices], self.history.value)
|
335
|
+
else:
|
336
|
+
return jax.tree.map(
|
337
|
+
lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
|
338
|
+
self.history.value,
|
339
|
+
self._unit
|
340
|
+
)
|
330
341
|
|
331
342
|
def retrieve_at_time(self, delay_time, *indices) -> PyTree:
|
332
343
|
"""
|
@@ -389,6 +400,9 @@ class Delay(Module):
|
|
389
400
|
"""
|
390
401
|
assert self.history is not None, 'The delay history is not initialized.'
|
391
402
|
|
403
|
+
if self.take_aware_unit and self._unit is None:
|
404
|
+
self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)
|
405
|
+
|
392
406
|
# update the delay data at the rotation index
|
393
407
|
if self.delay_method == _DELAY_ROTATE:
|
394
408
|
i = environ.get(environ.I)
|
@@ -415,6 +429,8 @@ class Delay(Module):
|
|
415
429
|
raise ValueError(f'Unknown updating method "{self.delay_method}"')
|
416
430
|
|
417
431
|
|
432
|
+
|
433
|
+
|
418
434
|
class StateWithDelay(Delay):
|
419
435
|
"""
|
420
436
|
A ``State`` type that defines the state in a differential equation.
|
@@ -22,8 +22,8 @@ import jax.numpy as jnp
|
|
22
22
|
|
23
23
|
from brainstate import random, environ, init
|
24
24
|
from brainstate._state import ShortTermState
|
25
|
-
from brainstate.nn._module import ElementWiseBlock
|
26
25
|
from brainstate.typing import Size
|
26
|
+
from ._module import ElementWiseBlock
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
|
@@ -409,7 +409,8 @@ class DropoutFixed(ElementWiseBlock):
|
|
409
409
|
self.out_size = in_size
|
410
410
|
|
411
411
|
def init_state(self, batch_size=None, **kwargs):
|
412
|
-
|
412
|
+
if self.prob < 1.:
|
413
|
+
self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
|
413
414
|
|
414
415
|
def update(self, x):
|
415
416
|
dtype = u.math.get_dtype(x)
|
@@ -418,8 +419,8 @@ class DropoutFixed(ElementWiseBlock):
|
|
418
419
|
if self.mask.value.shape != x.shape:
|
419
420
|
raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
|
420
421
|
f"Please call `init_state()` method first.")
|
421
|
-
return
|
422
|
-
|
423
|
-
|
422
|
+
return u.math.where(self.mask.value,
|
423
|
+
u.math.asarray(x / self.prob, dtype=dtype),
|
424
|
+
u.math.asarray(0., dtype=dtype) * u.get_unit(x))
|
424
425
|
else:
|
425
426
|
return x
|
@@ -36,18 +36,20 @@ For handling the delays:
|
|
36
36
|
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
|
37
37
|
|
38
38
|
import brainunit as u
|
39
|
+
import jax
|
39
40
|
import numpy as np
|
40
41
|
|
41
42
|
from brainstate import environ
|
42
43
|
from brainstate._state import State
|
43
44
|
from brainstate.graph import Node
|
44
|
-
from brainstate.mixin import ParamDescriber
|
45
|
-
from brainstate.
|
46
|
-
from
|
47
|
-
from .
|
45
|
+
from brainstate.mixin import ParamDescriber, UpdateReturn
|
46
|
+
from brainstate.typing import Size, ArrayLike, PyTree
|
47
|
+
from ._delay import StateWithDelay, Delay
|
48
|
+
from ._module import Module
|
48
49
|
|
49
50
|
__all__ = [
|
50
|
-
'DynamicsGroup', 'Projection', 'Dynamics',
|
51
|
+
'DynamicsGroup', 'Projection', 'Dynamics',
|
52
|
+
'Prefetch', 'PrefetchDelay', 'PrefetchDelayAt', 'OutputDelayAt',
|
51
53
|
]
|
52
54
|
|
53
55
|
T = TypeVar('T')
|
@@ -99,7 +101,7 @@ class Projection(Module):
|
|
99
101
|
raise ValueError('Do not implement the update() function.')
|
100
102
|
|
101
103
|
|
102
|
-
class Dynamics(Module):
|
104
|
+
class Dynamics(Module, UpdateReturn):
|
103
105
|
"""
|
104
106
|
Base class for implementing neural dynamics models in BrainState.
|
105
107
|
|
@@ -810,17 +812,64 @@ class Dynamics(Module):
|
|
810
812
|
>>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1
|
811
813
|
"""
|
812
814
|
if isinstance(dyn, Dynamics):
|
813
|
-
self._add_after_update(dyn
|
815
|
+
self._add_after_update(id(dyn), dyn)
|
814
816
|
return dyn
|
815
817
|
elif isinstance(dyn, ParamDescriber):
|
816
818
|
if not issubclass(dyn.cls, Dynamics):
|
817
819
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
|
818
820
|
if not self._has_after_update(dyn.identifier):
|
819
|
-
self._add_after_update(
|
821
|
+
self._add_after_update(
|
822
|
+
dyn.identifier,
|
823
|
+
dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
|
824
|
+
)
|
820
825
|
return self._get_after_update(dyn.identifier)
|
821
826
|
else:
|
822
827
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
|
823
828
|
|
829
|
+
def prefetch_delay(
|
830
|
+
self,
|
831
|
+
state: str,
|
832
|
+
delay: Optional[ArrayLike] = None
|
833
|
+
) -> 'PrefetchDelayAt':
|
834
|
+
"""
|
835
|
+
Create a reference to a delayed state or variable in the module.
|
836
|
+
|
837
|
+
This method simplifies the process of accessing a delayed version of a state or variable
|
838
|
+
within the module. It first creates a prefetch reference to the specified state,
|
839
|
+
then specifies the delay time for accessing this state.
|
840
|
+
|
841
|
+
Args:
|
842
|
+
state (str): The name of the state or variable to reference.
|
843
|
+
delay (Optional[ArrayLike]): The amount of time to delay the variable access,
|
844
|
+
typically in time units (e.g., milliseconds). Defaults to None.
|
845
|
+
|
846
|
+
Returns:
|
847
|
+
PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
|
848
|
+
"""
|
849
|
+
return self.prefetch(state).delay.at(delay)
|
850
|
+
|
851
|
+
def output_delay(
|
852
|
+
self,
|
853
|
+
delay: Optional[ArrayLike] = None,
|
854
|
+
variable_like: PyTree = None
|
855
|
+
) -> 'OutputDelayAt':
|
856
|
+
"""
|
857
|
+
Create a reference to the delayed output of the module.
|
858
|
+
|
859
|
+
This method simplifies the process of accessing a delayed version of the module's output.
|
860
|
+
It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
|
861
|
+
at the specified delay time.
|
862
|
+
|
863
|
+
Args:
|
864
|
+
delay (Optional[ArrayLike]): The amount of time to delay the output access,
|
865
|
+
typically in time units (e.g., milliseconds). Defaults to None.
|
866
|
+
variable_like:
|
867
|
+
|
868
|
+
Returns:
|
869
|
+
OutputDelayAt: An object that provides access to the module's output at the specified delay time.
|
870
|
+
"""
|
871
|
+
return OutputDelayAt(self, delay)
|
872
|
+
|
824
873
|
|
825
874
|
class Prefetch(Node):
|
826
875
|
"""
|
@@ -885,6 +934,7 @@ class Prefetch(Node):
|
|
885
934
|
An object that provides access to delayed versions of the prefetched item.
|
886
935
|
"""
|
887
936
|
return PrefetchDelay(self.module, self.item)
|
937
|
+
# return PrefetchDelayAt(self.module, self.item, time)
|
888
938
|
|
889
939
|
def __call__(self, *args, **kwargs):
|
890
940
|
"""
|
@@ -1007,7 +1057,7 @@ class PrefetchDelayAt(Node):
|
|
1007
1057
|
self,
|
1008
1058
|
module: Dynamics,
|
1009
1059
|
item: str,
|
1010
|
-
time: ArrayLike
|
1060
|
+
time: ArrayLike = None,
|
1011
1061
|
):
|
1012
1062
|
"""
|
1013
1063
|
Initialize a PrefetchDelayAt object.
|
@@ -1026,14 +1076,16 @@ class PrefetchDelayAt(Node):
|
|
1026
1076
|
self.module = module
|
1027
1077
|
self.item = item
|
1028
1078
|
self.time = time
|
1029
|
-
self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
|
1030
1079
|
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1080
|
+
if time is not None:
|
1081
|
+
self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
|
1082
|
+
|
1083
|
+
# register the delay
|
1084
|
+
key = _get_prefetch_delay_key(item)
|
1085
|
+
if not module._has_after_update(key):
|
1086
|
+
module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
|
1087
|
+
self.state_delay: StateWithDelay = module._get_after_update(key)
|
1088
|
+
self.state_delay.register_delay(time)
|
1037
1089
|
|
1038
1090
|
def __call__(self, *args, **kwargs):
|
1039
1091
|
"""
|
@@ -1044,12 +1096,76 @@ class PrefetchDelayAt(Node):
|
|
1044
1096
|
Any
|
1045
1097
|
The value of the state or variable at the specified delay time.
|
1046
1098
|
"""
|
1047
|
-
|
1048
|
-
|
1099
|
+
if self.time is None:
|
1100
|
+
return _get_prefetch_item(self).value
|
1101
|
+
else:
|
1102
|
+
return self.state_delay.retrieve_at_step(self.step)
|
1103
|
+
|
1104
|
+
|
1105
|
+
class OutputDelayAt(Node):
|
1106
|
+
"""
|
1107
|
+
Provides access to a specific delayed state or variable value at the specific time.
|
1108
|
+
|
1109
|
+
This class represents the final step in the prefetch delay chain, providing
|
1110
|
+
actual access to state values at a specific delay time. It converts the
|
1111
|
+
specified time delay into steps and registers the delay with the appropriate
|
1112
|
+
StateWithDelay handler.
|
1113
|
+
|
1114
|
+
Parameters
|
1115
|
+
----------
|
1116
|
+
module : Dynamics
|
1117
|
+
The dynamics module that contains the referenced state or variable.
|
1118
|
+
time : ArrayLike
|
1119
|
+
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
1120
|
+
|
1121
|
+
Examples
|
1122
|
+
--------
|
1123
|
+
>>> import brainstate
|
1124
|
+
>>> import brainunit as u
|
1125
|
+
>>> neuron = brainstate.nn.LIF(10)
|
1126
|
+
>>> # Create a reference to voltage delayed by 5ms
|
1127
|
+
>>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
|
1128
|
+
>>> # Get the delayed value
|
1129
|
+
>>> v_value = delayed_spike()
|
1130
|
+
"""
|
1131
|
+
|
1132
|
+
def __init__(
|
1133
|
+
self,
|
1134
|
+
module: Dynamics,
|
1135
|
+
time: Optional[ArrayLike] = None,
|
1136
|
+
variable_like: Optional[PyTree] = None,
|
1137
|
+
):
|
1138
|
+
super().__init__()
|
1139
|
+
assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
|
1140
|
+
self.module = module
|
1141
|
+
dt = environ.get_dt()
|
1142
|
+
if time is None:
|
1143
|
+
time = u.math.zeros_like(dt)
|
1144
|
+
self.time = time
|
1145
|
+
self.step = u.math.asarray(time / dt, dtype=environ.ditype())
|
1146
|
+
|
1147
|
+
# register the delay
|
1148
|
+
key = _get_output_delay_key()
|
1149
|
+
if not module._has_after_update(key):
|
1150
|
+
delay = Delay(
|
1151
|
+
jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()),
|
1152
|
+
time,
|
1153
|
+
take_aware_unit=True
|
1154
|
+
)
|
1155
|
+
module._add_after_update(key, receive_update_output(delay))
|
1156
|
+
self.out_delay: Delay = module._get_after_update(key)
|
1157
|
+
self.out_delay.register_delay(time)
|
1158
|
+
|
1159
|
+
def __call__(self, *args, **kwargs):
|
1160
|
+
return self.out_delay.retrieve_at_step(self.step)
|
1161
|
+
|
1162
|
+
|
1163
|
+
def _get_prefetch_delay_key(item) -> str:
|
1164
|
+
return f'{item}-prefetch-delay'
|
1049
1165
|
|
1050
1166
|
|
1051
|
-
def
|
1052
|
-
return f'
|
1167
|
+
def _get_output_delay_key() -> str:
|
1168
|
+
return f'output-delay'
|
1053
1169
|
|
1054
1170
|
|
1055
1171
|
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
|
@@ -1064,7 +1180,7 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
|
|
1064
1180
|
f'The target module should be an instance '
|
1065
1181
|
f'of Dynamics. But got {target.module}.'
|
1066
1182
|
)
|
1067
|
-
delay = target.module._get_after_update(
|
1183
|
+
delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
|
1068
1184
|
if not isinstance(delay, StateWithDelay):
|
1069
1185
|
raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
|
1070
1186
|
f'its delay. But got {delay}.')
|
@@ -23,8 +23,8 @@ import jax.typing
|
|
23
23
|
|
24
24
|
from brainstate import random, functional as F
|
25
25
|
from brainstate._state import ParamState
|
26
|
-
from brainstate.nn._module import ElementWiseBlock
|
27
26
|
from brainstate.typing import ArrayLike
|
27
|
+
from ._module import ElementWiseBlock
|
28
28
|
|
29
29
|
__all__ = [
|
30
30
|
# activation functions
|
@@ -17,8 +17,8 @@ from typing import Optional, Callable, Union
|
|
17
17
|
|
18
18
|
from brainstate import init
|
19
19
|
from brainstate._state import ParamState
|
20
|
-
from brainstate.nn._module import Module
|
21
20
|
from brainstate.typing import ArrayLike
|
21
|
+
from ._module import Module
|
22
22
|
|
23
23
|
__all__ = [
|
24
24
|
'Embedding',
|