brainstate 0.1.2__py2.py3-none-any.whl → 0.1.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +10 -10
- brainstate/mixin.py +1 -14
- brainstate/nn/__init__.py +81 -17
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +6 -2
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +1 -1
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +139 -18
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob_mv.py} +1 -1
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +1 -1
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +1 -1
- brainstate/nn/{_dynamics/_projection_base.py → _projection.py} +35 -3
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +17 -206
- brainstate/nn/_synaptic_projection.py +133 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/METADATA +1 -1
- {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/RECORD +44 -46
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_interaction/__init__.py +0 -41
- /brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +0 -0
- /brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +0 -0
- /brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +0 -0
- /brainstate/nn/{_elementwise/_elementwise_test.py → _elementwise_test.py} +0 -0
- /brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_mv_test.py} +0 -0
- /brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -0
- /brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +0 -0
- /brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +0 -0
- /brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +0 -0
- {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/LICENSE +0 -0
- {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/WHEEL +0 -0
- {brainstate-0.1.2.dist-info → brainstate-0.1.3.dist-info}/top_level.txt +0 -0
brainstate/__init__.py
CHANGED
brainstate/_compatible_import.py
CHANGED
@@ -19,7 +19,7 @@
|
|
19
19
|
import importlib.util
|
20
20
|
from contextlib import contextmanager
|
21
21
|
from functools import partial
|
22
|
-
from typing import Iterable, Hashable, TypeVar, Callable
|
22
|
+
from typing import Iterable, Hashable, TypeVar, Callable, TYPE_CHECKING
|
23
23
|
|
24
24
|
import jax
|
25
25
|
|
@@ -45,8 +45,8 @@ T1 = TypeVar("T1")
|
|
45
45
|
T2 = TypeVar("T2")
|
46
46
|
T3 = TypeVar("T3")
|
47
47
|
|
48
|
-
|
49
48
|
from saiunit._compatible_import import wrap_init
|
49
|
+
|
50
50
|
brainevent_installed = importlib.util.find_spec('brainevent') is not None
|
51
51
|
|
52
52
|
from jax.core import get_aval, Tracer
|
@@ -151,13 +151,13 @@ def to_concrete_aval(aval):
|
|
151
151
|
return aval
|
152
152
|
|
153
153
|
|
154
|
-
if brainevent_installed:
|
155
|
-
|
156
|
-
|
154
|
+
if not brainevent_installed:
|
155
|
+
if not TYPE_CHECKING:
|
156
|
+
class BrainEvent:
|
157
|
+
def __getattr__(self, item):
|
158
|
+
raise ImportError('brainevent is not installed, please install brainevent first.')
|
157
159
|
|
158
|
-
|
159
|
-
def __getattr__(self, item):
|
160
|
-
raise ImportError('brainevent is not installed, please install brainevent first.')
|
160
|
+
brainevent = BrainEvent()
|
161
161
|
|
162
|
-
|
163
|
-
brainevent
|
162
|
+
else:
|
163
|
+
import brainevent
|
brainstate/mixin.py
CHANGED
@@ -132,6 +132,7 @@ class AlignPost(Mixin):
|
|
132
132
|
raise NotImplementedError
|
133
133
|
|
134
134
|
|
135
|
+
|
135
136
|
class BindCondData(Mixin):
|
136
137
|
"""Bind temporary conductance data.
|
137
138
|
|
@@ -147,7 +148,6 @@ class BindCondData(Mixin):
|
|
147
148
|
|
148
149
|
|
149
150
|
class UpdateReturn(Mixin):
|
150
|
-
|
151
151
|
def update_return(self) -> PyTree:
|
152
152
|
"""
|
153
153
|
The update function return of the model.
|
@@ -157,19 +157,6 @@ class UpdateReturn(Mixin):
|
|
157
157
|
"""
|
158
158
|
raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
|
159
159
|
|
160
|
-
def update_return_info(self) -> PyTree:
|
161
|
-
"""
|
162
|
-
The update return information of the model.
|
163
|
-
|
164
|
-
It should be a pytree, with each element as a ``jax.Array``.
|
165
|
-
|
166
|
-
.. note::
|
167
|
-
Should not include the batch axis and batch in_size.
|
168
|
-
These information will be inferred from the ``mode`` attribute.
|
169
|
-
|
170
|
-
"""
|
171
|
-
raise NotImplementedError(f'Must implement the "{self.update_return_info.__name__}()" function.')
|
172
|
-
|
173
160
|
|
174
161
|
class _MetaUnionType(type):
|
175
162
|
def __new__(cls, name, bases, dct):
|
brainstate/nn/__init__.py
CHANGED
@@ -19,46 +19,110 @@ from ._collective_ops import *
|
|
19
19
|
from ._collective_ops import __all__ as collective_ops_all
|
20
20
|
from ._common import *
|
21
21
|
from ._common import __all__ as common_all
|
22
|
-
from .
|
23
|
-
from .
|
22
|
+
from ._conv import *
|
23
|
+
from ._conv import __all__ as conv_all
|
24
|
+
from ._delay import *
|
25
|
+
from ._delay import __all__ as state_delay_all
|
26
|
+
from ._dropout import *
|
27
|
+
from ._dropout import __all__ as dropout_all
|
24
28
|
from ._dynamics import *
|
25
|
-
from ._dynamics import __all__ as
|
29
|
+
from ._dynamics import __all__ as dyn_all
|
26
30
|
from ._elementwise import *
|
27
31
|
from ._elementwise import __all__ as elementwise_all
|
28
|
-
from .
|
29
|
-
from .
|
32
|
+
from ._embedding import *
|
33
|
+
from ._embedding import __all__ as embed_all
|
30
34
|
from ._exp_euler import *
|
31
35
|
from ._exp_euler import __all__ as exp_euler_all
|
32
|
-
from .
|
33
|
-
from .
|
36
|
+
from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
|
37
|
+
from ._inputs import *
|
38
|
+
from ._inputs import __all__ as inputs_all
|
39
|
+
from ._linear import *
|
40
|
+
from ._linear import __all__ as linear_all
|
41
|
+
from ._linear_mv import EventLinear
|
42
|
+
from ._ltp import *
|
43
|
+
from ._ltp import __all__ as ltp_all
|
34
44
|
from ._module import *
|
35
45
|
from ._module import __all__ as module_all
|
46
|
+
from ._neuron import *
|
47
|
+
from ._neuron import __all__ as dyn_neuron_all
|
48
|
+
from ._normalizations import *
|
49
|
+
from ._normalizations import __all__ as normalizations_all
|
50
|
+
from ._poolings import *
|
51
|
+
from ._poolings import __all__ as poolings_all
|
52
|
+
from ._projection import *
|
53
|
+
from ._projection import __all__ as projection_all
|
54
|
+
from ._rate_rnns import *
|
55
|
+
from ._rate_rnns import __all__ as rate_rnns
|
56
|
+
from ._readout import *
|
57
|
+
from ._readout import __all__ as readout_all
|
58
|
+
from ._stp import *
|
59
|
+
from ._stp import __all__ as stp_all
|
60
|
+
from ._synapse import *
|
61
|
+
from ._synapse import __all__ as dyn_synapse_all
|
62
|
+
from ._synaptic_projection import *
|
63
|
+
from ._synaptic_projection import __all__ as _syn_proj_all
|
64
|
+
from ._synouts import *
|
65
|
+
from ._synouts import __all__ as synouts_all
|
36
66
|
from ._utils import *
|
37
67
|
from ._utils import __all__ as utils_all
|
38
68
|
|
39
69
|
__all__ = (
|
40
|
-
[
|
70
|
+
[
|
71
|
+
'metrics',
|
72
|
+
'EventLinear',
|
73
|
+
'EventFixedProb',
|
74
|
+
'EventFixedNumConn',
|
75
|
+
]
|
41
76
|
+ collective_ops_all
|
42
77
|
+ common_all
|
43
|
-
+ dyn_impl_all
|
44
|
-
+ dynamics_all
|
45
78
|
+ elementwise_all
|
46
79
|
+ module_all
|
47
80
|
+ exp_euler_all
|
48
|
-
+ interaction_all
|
49
81
|
+ utils_all
|
50
|
-
+
|
82
|
+
+ dyn_all
|
83
|
+
+ projection_all
|
84
|
+
+ state_delay_all
|
85
|
+
+ synouts_all
|
86
|
+
+ conv_all
|
87
|
+
+ linear_all
|
88
|
+
+ normalizations_all
|
89
|
+
+ poolings_all
|
90
|
+
+ embed_all
|
91
|
+
+ dropout_all
|
92
|
+
+ elementwise_all
|
93
|
+
+ dyn_neuron_all
|
94
|
+
+ dyn_synapse_all
|
95
|
+
+ inputs_all
|
96
|
+
+ rate_rnns
|
97
|
+
+ readout_all
|
98
|
+
+ stp_all
|
99
|
+
+ ltp_all
|
100
|
+
+ _syn_proj_all
|
51
101
|
)
|
52
102
|
|
53
103
|
del (
|
54
104
|
collective_ops_all,
|
55
105
|
common_all,
|
56
|
-
dyn_impl_all,
|
57
|
-
dynamics_all,
|
58
|
-
elementwise_all,
|
59
106
|
module_all,
|
60
107
|
exp_euler_all,
|
61
|
-
interaction_all,
|
62
108
|
utils_all,
|
63
|
-
|
109
|
+
dyn_all,
|
110
|
+
projection_all,
|
111
|
+
state_delay_all,
|
112
|
+
synouts_all,
|
113
|
+
conv_all,
|
114
|
+
linear_all,
|
115
|
+
normalizations_all,
|
116
|
+
poolings_all,
|
117
|
+
embed_all,
|
118
|
+
dropout_all,
|
119
|
+
elementwise_all,
|
120
|
+
dyn_neuron_all,
|
121
|
+
dyn_synapse_all,
|
122
|
+
inputs_all,
|
123
|
+
readout_all,
|
124
|
+
rate_rnns,
|
125
|
+
stp_all,
|
126
|
+
ltp_all,
|
127
|
+
_syn_proj_all,
|
64
128
|
)
|
@@ -23,8 +23,8 @@ import jax.numpy as jnp
|
|
23
23
|
|
24
24
|
from brainstate import init, functional
|
25
25
|
from brainstate._state import ParamState
|
26
|
-
from brainstate.nn._module import Module
|
27
26
|
from brainstate.typing import ArrayLike
|
27
|
+
from ._module import Module
|
28
28
|
|
29
29
|
T = TypeVar('T')
|
30
30
|
|
@@ -27,9 +27,9 @@ from brainstate import environ
|
|
27
27
|
from brainstate._state import ShortTermState, State
|
28
28
|
from brainstate.compile import jit_error_if
|
29
29
|
from brainstate.graph import Node
|
30
|
-
from brainstate.nn._collective_ops import call_order
|
31
|
-
from brainstate.nn._module import Module
|
32
30
|
from brainstate.typing import ArrayLike, PyTree
|
31
|
+
from ._collective_ops import call_order
|
32
|
+
from ._module import Module
|
33
33
|
|
34
34
|
__all__ = [
|
35
35
|
'Delay', 'DelayAccess', 'StateWithDelay',
|
@@ -135,6 +135,7 @@ class Delay(Module):
|
|
135
135
|
entries: Optional[Dict] = None, # delay access entry
|
136
136
|
delay_method: Optional[str] = _DELAY_ROTATE, # delay method
|
137
137
|
interp_method: str = _INTERP_LINEAR, # interpolation method
|
138
|
+
take_aware_unit: bool = False
|
138
139
|
):
|
139
140
|
# target information
|
140
141
|
self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)
|
@@ -170,6 +171,9 @@ class Delay(Module):
|
|
170
171
|
for entry, delay_time in entries.items():
|
171
172
|
self.register_entry(entry, delay_time)
|
172
173
|
|
174
|
+
self.take_aware_unit = take_aware_unit
|
175
|
+
self._unit = None
|
176
|
+
|
173
177
|
@property
|
174
178
|
def history(self):
|
175
179
|
return self._history
|
@@ -22,8 +22,8 @@ import jax.numpy as jnp
|
|
22
22
|
|
23
23
|
from brainstate import random, environ, init
|
24
24
|
from brainstate._state import ShortTermState
|
25
|
-
from brainstate.nn._module import ElementWiseBlock
|
26
25
|
from brainstate.typing import Size
|
26
|
+
from ._module import ElementWiseBlock
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'DropoutFixed', 'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
|
@@ -41,13 +41,14 @@ import numpy as np
|
|
41
41
|
from brainstate import environ
|
42
42
|
from brainstate._state import State
|
43
43
|
from brainstate.graph import Node
|
44
|
-
from brainstate.mixin import ParamDescriber
|
45
|
-
from brainstate.nn._module import Module
|
44
|
+
from brainstate.mixin import ParamDescriber, UpdateReturn
|
46
45
|
from brainstate.typing import Size, ArrayLike
|
47
|
-
from .
|
46
|
+
from ._delay import StateWithDelay, Delay
|
47
|
+
from ._module import Module
|
48
48
|
|
49
49
|
__all__ = [
|
50
|
-
'DynamicsGroup', 'Projection', 'Dynamics',
|
50
|
+
'DynamicsGroup', 'Projection', 'Dynamics',
|
51
|
+
'Prefetch', 'PrefetchDelay', 'PrefetchDelayAt', 'OutputDelayAt',
|
51
52
|
]
|
52
53
|
|
53
54
|
T = TypeVar('T')
|
@@ -99,7 +100,7 @@ class Projection(Module):
|
|
99
100
|
raise ValueError('Do not implement the update() function.')
|
100
101
|
|
101
102
|
|
102
|
-
class Dynamics(Module):
|
103
|
+
class Dynamics(Module, UpdateReturn):
|
103
104
|
"""
|
104
105
|
Base class for implementing neural dynamics models in BrainState.
|
105
106
|
|
@@ -821,6 +822,41 @@ class Dynamics(Module):
|
|
821
822
|
else:
|
822
823
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
|
823
824
|
|
825
|
+
def prefetch_delay(self, state: str, delay: Optional[ArrayLike] = None) -> 'PrefetchDelayAt':
|
826
|
+
"""
|
827
|
+
Create a reference to a delayed state or variable in the module.
|
828
|
+
|
829
|
+
This method simplifies the process of accessing a delayed version of a state or variable
|
830
|
+
within the module. It first creates a prefetch reference to the specified state,
|
831
|
+
then specifies the delay time for accessing this state.
|
832
|
+
|
833
|
+
Args:
|
834
|
+
state (str): The name of the state or variable to reference.
|
835
|
+
delay (Optional[ArrayLike]): The amount of time to delay the variable access,
|
836
|
+
typically in time units (e.g., milliseconds). Defaults to None.
|
837
|
+
|
838
|
+
Returns:
|
839
|
+
PrefetchDelayAt: An object that provides access to the variable at the specified delay time.
|
840
|
+
"""
|
841
|
+
return self.prefetch(state).delay.at(delay)
|
842
|
+
|
843
|
+
def output_delay(self, delay: Optional[ArrayLike] = None) -> 'OutputDelayAt':
|
844
|
+
"""
|
845
|
+
Create a reference to the delayed output of the module.
|
846
|
+
|
847
|
+
This method simplifies the process of accessing a delayed version of the module's output.
|
848
|
+
It instantiates an `OutputDelayAt` object, which can be used to retrieve the output value
|
849
|
+
at the specified delay time.
|
850
|
+
|
851
|
+
Args:
|
852
|
+
delay (Optional[ArrayLike]): The amount of time to delay the output access,
|
853
|
+
typically in time units (e.g., milliseconds). Defaults to None.
|
854
|
+
|
855
|
+
Returns:
|
856
|
+
OutputDelayAt: An object that provides access to the module's output at the specified delay time.
|
857
|
+
"""
|
858
|
+
return OutputDelayAt(self, delay)
|
859
|
+
|
824
860
|
|
825
861
|
class Prefetch(Node):
|
826
862
|
"""
|
@@ -885,6 +921,7 @@ class Prefetch(Node):
|
|
885
921
|
An object that provides access to delayed versions of the prefetched item.
|
886
922
|
"""
|
887
923
|
return PrefetchDelay(self.module, self.item)
|
924
|
+
# return PrefetchDelayAt(self.module, self.item, time)
|
888
925
|
|
889
926
|
def __call__(self, *args, **kwargs):
|
890
927
|
"""
|
@@ -1007,7 +1044,7 @@ class PrefetchDelayAt(Node):
|
|
1007
1044
|
self,
|
1008
1045
|
module: Dynamics,
|
1009
1046
|
item: str,
|
1010
|
-
time: ArrayLike
|
1047
|
+
time: ArrayLike = None,
|
1011
1048
|
):
|
1012
1049
|
"""
|
1013
1050
|
Initialize a PrefetchDelayAt object.
|
@@ -1026,14 +1063,16 @@ class PrefetchDelayAt(Node):
|
|
1026
1063
|
self.module = module
|
1027
1064
|
self.item = item
|
1028
1065
|
self.time = time
|
1029
|
-
self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
|
1030
1066
|
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1067
|
+
if time is not None:
|
1068
|
+
self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
|
1069
|
+
|
1070
|
+
# register the delay
|
1071
|
+
key = _get_prefetch_delay_key(item)
|
1072
|
+
if not module._has_after_update(key):
|
1073
|
+
module._add_after_update(key, not_receive_update_output(StateWithDelay(module, item)))
|
1074
|
+
self.state_delay: StateWithDelay = module._get_after_update(key)
|
1075
|
+
self.state_delay.register_delay(time)
|
1037
1076
|
|
1038
1077
|
def __call__(self, *args, **kwargs):
|
1039
1078
|
"""
|
@@ -1044,12 +1083,94 @@ class PrefetchDelayAt(Node):
|
|
1044
1083
|
Any
|
1045
1084
|
The value of the state or variable at the specified delay time.
|
1046
1085
|
"""
|
1047
|
-
|
1048
|
-
|
1086
|
+
if self.time is None:
|
1087
|
+
return _get_prefetch_item(self).value
|
1088
|
+
else:
|
1089
|
+
return self.state_delay.retrieve_at_step(self.step)
|
1090
|
+
|
1091
|
+
|
1092
|
+
class OutputDelayAt(Node):
|
1093
|
+
"""
|
1094
|
+
Provides access to a specific delayed state or variable value at the specific time.
|
1095
|
+
|
1096
|
+
This class represents the final step in the prefetch delay chain, providing
|
1097
|
+
actual access to state values at a specific delay time. It converts the
|
1098
|
+
specified time delay into steps and registers the delay with the appropriate
|
1099
|
+
StateWithDelay handler.
|
1100
|
+
|
1101
|
+
Parameters
|
1102
|
+
----------
|
1103
|
+
module : Dynamics
|
1104
|
+
The dynamics module that contains the referenced state or variable.
|
1105
|
+
item : str
|
1106
|
+
The name of the state or variable to access with delay.
|
1107
|
+
time : ArrayLike
|
1108
|
+
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
1109
|
+
|
1110
|
+
Examples
|
1111
|
+
--------
|
1112
|
+
>>> import brainstate
|
1113
|
+
>>> import brainunit as u
|
1114
|
+
>>> neuron = brainstate.nn.LIF(10)
|
1115
|
+
>>> # Create a reference to voltage delayed by 5ms
|
1116
|
+
>>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
|
1117
|
+
>>> # Get the delayed value
|
1118
|
+
>>> v_value = delayed_v()
|
1119
|
+
"""
|
1120
|
+
|
1121
|
+
def __init__(
|
1122
|
+
self,
|
1123
|
+
module: Dynamics,
|
1124
|
+
time: Optional[ArrayLike] = None,
|
1125
|
+
):
|
1126
|
+
"""
|
1127
|
+
Initialize a PrefetchDelayAt object.
|
1128
|
+
|
1129
|
+
Parameters
|
1130
|
+
----------
|
1131
|
+
module : AlignPre, Module
|
1132
|
+
The dynamics module that contains the referenced state or variable.
|
1133
|
+
time : ArrayLike
|
1134
|
+
The amount of time to delay access by, typically in time units.
|
1135
|
+
"""
|
1136
|
+
super().__init__()
|
1137
|
+
assert isinstance(module, UpdateReturn), 'The module should implement the `update_return` method.'
|
1138
|
+
assert isinstance(module, Module), 'The module should be an instance of Module.'
|
1139
|
+
self.module = module
|
1140
|
+
self.time = time
|
1141
|
+
if time is not None:
|
1142
|
+
self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
|
1143
|
+
|
1144
|
+
# register the delay
|
1145
|
+
key = _get_output_delay_key()
|
1146
|
+
if not module._has_after_update(key):
|
1147
|
+
# TODO: unit processing
|
1148
|
+
delay = Delay(module.update_return(), time)
|
1149
|
+
module._add_after_update(key, receive_update_output(delay))
|
1150
|
+
self.out_delay: Delay = module._get_after_update(key)
|
1151
|
+
self.out_delay.register_delay(time)
|
1152
|
+
|
1153
|
+
def __call__(self, *args, **kwargs):
|
1154
|
+
"""
|
1155
|
+
Retrieve the value of the state at the specified delay time.
|
1156
|
+
|
1157
|
+
Returns
|
1158
|
+
-------
|
1159
|
+
Any
|
1160
|
+
The value of the state or variable at the specified delay time.
|
1161
|
+
"""
|
1162
|
+
if self.time is None:
|
1163
|
+
return self.module.update_return()
|
1164
|
+
else:
|
1165
|
+
return self.out_delay.retrieve_at_step(self.step)
|
1166
|
+
|
1167
|
+
|
1168
|
+
def _get_prefetch_delay_key(item) -> str:
|
1169
|
+
return f'{item}-prefetch-delay'
|
1049
1170
|
|
1050
1171
|
|
1051
|
-
def
|
1052
|
-
return f'
|
1172
|
+
def _get_output_delay_key() -> str:
|
1173
|
+
return f'output-delay'
|
1053
1174
|
|
1054
1175
|
|
1055
1176
|
def _get_prefetch_item(target: Union[Prefetch, PrefetchDelayAt]) -> Any:
|
@@ -1064,7 +1185,7 @@ def _get_prefetch_item_delay(target: Union[Prefetch, PrefetchDelay, PrefetchDela
|
|
1064
1185
|
f'The target module should be an instance '
|
1065
1186
|
f'of Dynamics. But got {target.module}.'
|
1066
1187
|
)
|
1067
|
-
delay = target.module._get_after_update(
|
1188
|
+
delay = target.module._get_after_update(_get_prefetch_delay_key(target.item))
|
1068
1189
|
if not isinstance(delay, StateWithDelay):
|
1069
1190
|
raise TypeError(f'The prefetch target should be a {StateWithDelay.__name__} when accessing '
|
1070
1191
|
f'its delay. But got {delay}.')
|
@@ -23,8 +23,8 @@ import jax.typing
|
|
23
23
|
|
24
24
|
from brainstate import random, functional as F
|
25
25
|
from brainstate._state import ParamState
|
26
|
-
from brainstate.nn._module import ElementWiseBlock
|
27
26
|
from brainstate.typing import ArrayLike
|
27
|
+
from ._module import ElementWiseBlock
|
28
28
|
|
29
29
|
__all__ = [
|
30
30
|
# activation functions
|
@@ -17,8 +17,8 @@ from typing import Optional, Callable, Union
|
|
17
17
|
|
18
18
|
from brainstate import init
|
19
19
|
from brainstate._state import ParamState
|
20
|
-
from brainstate.nn._module import Module
|
21
20
|
from brainstate.typing import ArrayLike
|
21
|
+
from ._module import Module
|
22
22
|
|
23
23
|
__all__ = [
|
24
24
|
'Embedding',
|
@@ -25,8 +25,8 @@ from brainstate import random, augment, environ, init
|
|
25
25
|
from brainstate._compatible_import import brainevent
|
26
26
|
from brainstate._state import ParamState
|
27
27
|
from brainstate.compile import for_loop
|
28
|
-
from brainstate.nn._module import Module
|
29
28
|
from brainstate.typing import Size, ArrayLike
|
29
|
+
from ._module import Module
|
30
30
|
|
31
31
|
__all__ = [
|
32
32
|
'EventFixedNumConn',
|
@@ -20,12 +20,11 @@ import jax
|
|
20
20
|
import numpy as np
|
21
21
|
|
22
22
|
from brainstate import environ, init, random
|
23
|
-
from brainstate._state import ShortTermState
|
24
|
-
from brainstate._state import State, maybe_state
|
23
|
+
from brainstate._state import ShortTermState, State, maybe_state
|
25
24
|
from brainstate.compile import while_loop
|
26
|
-
from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
|
27
|
-
from brainstate.nn._module import Module
|
28
25
|
from brainstate.typing import ArrayLike, Size, DTypeLike
|
26
|
+
from ._dynamics import Dynamics, Prefetch
|
27
|
+
from ._module import Module
|
29
28
|
|
30
29
|
__all__ = [
|
31
30
|
'SpikeTime',
|
@@ -134,7 +133,7 @@ class PoissonSpike(Dynamics):
|
|
134
133
|
self.freqs = init.param(freqs, self.varshape, allow_none=False)
|
135
134
|
|
136
135
|
def update(self):
|
137
|
-
spikes = random.rand(self.varshape) <= (self.freqs * environ.get_dt())
|
136
|
+
spikes = random.rand(*self.varshape) <= (self.freqs * environ.get_dt())
|
138
137
|
spikes = u.math.asarray(spikes, dtype=self.spk_type)
|
139
138
|
return spikes
|
140
139
|
|
@@ -22,8 +22,8 @@ import jax.numpy as jnp
|
|
22
22
|
|
23
23
|
from brainstate import init, functional
|
24
24
|
from brainstate._state import ParamState
|
25
|
-
from brainstate.nn._module import Module
|
26
25
|
from brainstate.typing import ArrayLike, Size
|
26
|
+
from ._module import Module
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'Linear',
|
@@ -350,10 +350,7 @@ class OneToOne(Module):
|
|
350
350
|
self.weight = param_type(param)
|
351
351
|
|
352
352
|
def update(self, pre_val):
|
353
|
-
|
354
|
-
w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
|
355
|
-
post_val = pre_val * w_val
|
356
|
-
post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
|
353
|
+
post_val = pre_val * self.weight.value['weight']
|
357
354
|
if 'bias' in self.weight.value:
|
358
355
|
post_val = post_val + self.weight.value['bias']
|
359
356
|
return post_val
|
@@ -21,8 +21,8 @@ import jax
|
|
21
21
|
from brainstate import init
|
22
22
|
from brainstate._compatible_import import brainevent
|
23
23
|
from brainstate._state import ParamState
|
24
|
-
from brainstate.nn._module import Module
|
25
24
|
from brainstate.typing import Size, ArrayLike
|
25
|
+
from ._module import Module
|
26
26
|
|
27
27
|
__all__ = [
|
28
28
|
'EventLinear',
|
@@ -16,11 +16,13 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
18
|
|
19
|
-
from .
|
20
|
-
from ._linear_mv import EventLinear
|
19
|
+
from ._synapse import Synapse
|
21
20
|
|
22
21
|
__all__ = [
|
23
|
-
'
|
24
|
-
'EventFixedProb',
|
25
|
-
'EventFixedNumConn',
|
22
|
+
'LongTermPlasticity',
|
26
23
|
]
|
24
|
+
|
25
|
+
|
26
|
+
class LongTermPlasticity(Synapse):
|
27
|
+
pass
|
28
|
+
|
@@ -22,9 +22,9 @@ import jax
|
|
22
22
|
|
23
23
|
from brainstate import init, surrogate, environ
|
24
24
|
from brainstate._state import HiddenState, ShortTermState
|
25
|
-
from brainstate.nn._dynamics._dynamics_base import Dynamics
|
26
|
-
from brainstate.nn._exp_euler import exp_euler_step
|
27
25
|
from brainstate.typing import ArrayLike, Size
|
26
|
+
from ._dynamics import Dynamics
|
27
|
+
from ._exp_euler import exp_euler_step
|
28
28
|
|
29
29
|
__all__ = [
|
30
30
|
'Neuron', 'IF', 'LIF', 'LIFRef', 'ALIF',
|
@@ -22,8 +22,8 @@ import jax.numpy as jnp
|
|
22
22
|
|
23
23
|
from brainstate import environ, init
|
24
24
|
from brainstate._state import ParamState, BatchState
|
25
|
-
from brainstate.nn._module import Module
|
26
25
|
from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
|
26
|
+
from ._module import Module
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'BatchNorm0d',
|
@@ -103,7 +103,7 @@ class TestPool(parameterized.TestCase):
|
|
103
103
|
for target_size in [10, 9, 8, 7, 6]
|
104
104
|
)
|
105
105
|
def test_adaptive_pool1d(self, target_size):
|
106
|
-
from brainstate.nn.
|
106
|
+
from brainstate.nn._poolings import _adaptive_pool1d
|
107
107
|
|
108
108
|
arr = brainstate.random.rand(100)
|
109
109
|
op = jax.numpy.mean
|