brainstate 0.1.3__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +1 -16
- brainstate/compile/_jit.py +14 -5
- brainstate/compile/_make_jaxpr.py +78 -22
- brainstate/compile/_make_jaxpr_test.py +13 -2
- brainstate/graph/_graph_node.py +1 -1
- brainstate/graph/_graph_operation.py +4 -4
- brainstate/mixin.py +31 -2
- brainstate/nn/__init__.py +8 -5
- brainstate/nn/_delay.py +13 -1
- brainstate/nn/_dropout.py +5 -4
- brainstate/nn/_dynamics.py +39 -44
- brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
- brainstate/nn/_linear_mv.py +1 -1
- brainstate/nn/_module.py +5 -5
- brainstate/nn/_projection.py +190 -98
- brainstate/nn/_synapse.py +5 -9
- brainstate/nn/_synaptic_projection.py +376 -86
- brainstate/surrogate.py +1 -1
- brainstate/typing.py +1 -1
- brainstate/util/__init__.py +14 -14
- brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/RECORD +35 -35
- /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
- /brainstate/util/{_caller.py → caller.py} +0 -0
- /brainstate/util/{_error.py → error.py} +0 -0
- /brainstate/util/{_others.py → others.py} +0 -0
- /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
- /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
- /brainstate/util/{_scaling.py → scaling.py} +0 -0
- /brainstate/util/{_struct.py → struct.py} +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
brainstate/nn/_dynamics.py
CHANGED
@@ -36,13 +36,14 @@ For handling the delays:
|
|
36
36
|
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
|
37
37
|
|
38
38
|
import brainunit as u
|
39
|
+
import jax
|
39
40
|
import numpy as np
|
40
41
|
|
41
42
|
from brainstate import environ
|
42
43
|
from brainstate._state import State
|
43
44
|
from brainstate.graph import Node
|
44
45
|
from brainstate.mixin import ParamDescriber, UpdateReturn
|
45
|
-
from brainstate.typing import Size, ArrayLike
|
46
|
+
from brainstate.typing import Size, ArrayLike, PyTree
|
46
47
|
from ._delay import StateWithDelay, Delay
|
47
48
|
from ._module import Module
|
48
49
|
|
@@ -811,18 +812,25 @@ class Dynamics(Module, UpdateReturn):
|
|
811
812
|
>>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1
|
812
813
|
"""
|
813
814
|
if isinstance(dyn, Dynamics):
|
814
|
-
self._add_after_update(dyn
|
815
|
+
self._add_after_update(id(dyn), dyn)
|
815
816
|
return dyn
|
816
817
|
elif isinstance(dyn, ParamDescriber):
|
817
818
|
if not issubclass(dyn.cls, Dynamics):
|
818
819
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
|
819
820
|
if not self._has_after_update(dyn.identifier):
|
820
|
-
self._add_after_update(
|
821
|
+
self._add_after_update(
|
822
|
+
dyn.identifier,
|
823
|
+
dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
|
824
|
+
)
|
821
825
|
return self._get_after_update(dyn.identifier)
|
822
826
|
else:
|
823
827
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
|
824
828
|
|
825
|
-
def prefetch_delay(
|
829
|
+
def prefetch_delay(
|
830
|
+
self,
|
831
|
+
state: str,
|
832
|
+
delay: Optional[ArrayLike] = None
|
833
|
+
) -> 'PrefetchDelayAt':
|
826
834
|
"""
|
827
835
|
Create a reference to a delayed state or variable in the module.
|
828
836
|
|
@@ -840,7 +848,11 @@ class Dynamics(Module, UpdateReturn):
|
|
840
848
|
"""
|
841
849
|
return self.prefetch(state).delay.at(delay)
|
842
850
|
|
843
|
-
def output_delay(
|
851
|
+
def output_delay(
|
852
|
+
self,
|
853
|
+
delay: Optional[ArrayLike] = None,
|
854
|
+
variable_like: PyTree = None
|
855
|
+
) -> 'OutputDelayAt':
|
844
856
|
"""
|
845
857
|
Create a reference to the delayed output of the module.
|
846
858
|
|
@@ -851,6 +863,7 @@ class Dynamics(Module, UpdateReturn):
|
|
851
863
|
Args:
|
852
864
|
delay (Optional[ArrayLike]): The amount of time to delay the output access,
|
853
865
|
typically in time units (e.g., milliseconds). Defaults to None.
|
866
|
+
variable_like:
|
854
867
|
|
855
868
|
Returns:
|
856
869
|
OutputDelayAt: An object that provides access to the module's output at the specified delay time.
|
@@ -1102,8 +1115,6 @@ class OutputDelayAt(Node):
|
|
1102
1115
|
----------
|
1103
1116
|
module : Dynamics
|
1104
1117
|
The dynamics module that contains the referenced state or variable.
|
1105
|
-
item : str
|
1106
|
-
The name of the state or variable to access with delay.
|
1107
1118
|
time : ArrayLike
|
1108
1119
|
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
1109
1120
|
|
@@ -1113,56 +1124,40 @@ class OutputDelayAt(Node):
|
|
1113
1124
|
>>> import brainunit as u
|
1114
1125
|
>>> neuron = brainstate.nn.LIF(10)
|
1115
1126
|
>>> # Create a reference to voltage delayed by 5ms
|
1116
|
-
>>>
|
1127
|
+
>>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
|
1117
1128
|
>>> # Get the delayed value
|
1118
|
-
>>> v_value =
|
1129
|
+
>>> v_value = delayed_spike()
|
1119
1130
|
"""
|
1120
1131
|
|
1121
1132
|
def __init__(
|
1122
1133
|
self,
|
1123
1134
|
module: Dynamics,
|
1124
1135
|
time: Optional[ArrayLike] = None,
|
1136
|
+
variable_like: Optional[PyTree] = None,
|
1125
1137
|
):
|
1126
|
-
"""
|
1127
|
-
Initialize a PrefetchDelayAt object.
|
1128
|
-
|
1129
|
-
Parameters
|
1130
|
-
----------
|
1131
|
-
module : AlignPre, Module
|
1132
|
-
The dynamics module that contains the referenced state or variable.
|
1133
|
-
time : ArrayLike
|
1134
|
-
The amount of time to delay access by, typically in time units.
|
1135
|
-
"""
|
1136
1138
|
super().__init__()
|
1137
|
-
assert isinstance(module,
|
1138
|
-
assert isinstance(module, Module), 'The module should be an instance of Module.'
|
1139
|
+
assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
|
1139
1140
|
self.module = module
|
1141
|
+
dt = environ.get_dt()
|
1142
|
+
if time is None:
|
1143
|
+
time = u.math.zeros_like(dt)
|
1140
1144
|
self.time = time
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1145
|
+
self.step = u.math.asarray(time / dt, dtype=environ.ditype())
|
1146
|
+
|
1147
|
+
# register the delay
|
1148
|
+
key = _get_output_delay_key()
|
1149
|
+
if not module._has_after_update(key):
|
1150
|
+
delay = Delay(
|
1151
|
+
jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()),
|
1152
|
+
time,
|
1153
|
+
take_aware_unit=True
|
1154
|
+
)
|
1155
|
+
module._add_after_update(key, receive_update_output(delay))
|
1156
|
+
self.out_delay: Delay = module._get_after_update(key)
|
1157
|
+
self.out_delay.register_delay(time)
|
1152
1158
|
|
1153
1159
|
def __call__(self, *args, **kwargs):
|
1154
|
-
|
1155
|
-
Retrieve the value of the state at the specified delay time.
|
1156
|
-
|
1157
|
-
Returns
|
1158
|
-
-------
|
1159
|
-
Any
|
1160
|
-
The value of the state or variable at the specified delay time.
|
1161
|
-
"""
|
1162
|
-
if self.time is None:
|
1163
|
-
return self.module.update_return()
|
1164
|
-
else:
|
1165
|
-
return self.out_delay.retrieve_at_step(self.step)
|
1160
|
+
return self.out_delay.retrieve_at_step(self.step)
|
1166
1161
|
|
1167
1162
|
|
1168
1163
|
def _get_prefetch_delay_key(item) -> str:
|
@@ -16,19 +16,20 @@
|
|
16
16
|
|
17
17
|
from typing import Union, Callable, Optional
|
18
18
|
|
19
|
+
import brainevent
|
19
20
|
import brainunit as u
|
20
21
|
import jax
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import numpy as np
|
23
24
|
|
24
25
|
from brainstate import random, augment, environ, init
|
25
|
-
from brainstate.
|
26
|
-
from brainstate._state import ParamState
|
26
|
+
from brainstate._state import ParamState, FakeState
|
27
27
|
from brainstate.compile import for_loop
|
28
28
|
from brainstate.typing import Size, ArrayLike
|
29
29
|
from ._module import Module
|
30
30
|
|
31
31
|
__all__ = [
|
32
|
+
'FixedNumConn',
|
32
33
|
'EventFixedNumConn',
|
33
34
|
'EventFixedProb',
|
34
35
|
]
|
@@ -44,12 +45,11 @@ def init_indices_without_replace(
|
|
44
45
|
rng = random.default_rng(seed)
|
45
46
|
|
46
47
|
if method == 'vmap':
|
47
|
-
@augment.vmap
|
48
|
-
def rand_indices(
|
49
|
-
rng.set_key(key)
|
48
|
+
@augment.vmap(axis_size=n_pre)
|
49
|
+
def rand_indices():
|
50
50
|
return rng.choice(n_post, size=(conn_num,), replace=False)
|
51
51
|
|
52
|
-
return rand_indices(
|
52
|
+
return rand_indices()
|
53
53
|
|
54
54
|
elif method == 'for_loop':
|
55
55
|
return for_loop(
|
@@ -61,9 +61,9 @@ def init_indices_without_replace(
|
|
61
61
|
raise ValueError(f"Unknown method: {method}")
|
62
62
|
|
63
63
|
|
64
|
-
class
|
64
|
+
class FixedNumConn(Module):
|
65
65
|
"""
|
66
|
-
The
|
66
|
+
The ``FixedNumConn`` module implements a fixed probability connection with CSR sparse data structure.
|
67
67
|
|
68
68
|
Parameters
|
69
69
|
----------
|
@@ -77,7 +77,7 @@ class EventFixedNumConn(Module):
|
|
77
77
|
If it is an integer, representing the number of connections.
|
78
78
|
conn_weight : float or callable or jax.Array or brainunit.Quantity
|
79
79
|
Maximum synaptic conductance, i.e., synaptic weight.
|
80
|
-
|
80
|
+
efferent_target : str, optional
|
81
81
|
The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
|
82
82
|
a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
|
83
83
|
|
@@ -104,7 +104,8 @@ class EventFixedNumConn(Module):
|
|
104
104
|
out_size: Size,
|
105
105
|
conn_num: Union[int, float],
|
106
106
|
conn_weight: Union[Callable, ArrayLike],
|
107
|
-
|
107
|
+
efferent_target: str = 'post', # 'pre' or 'post'
|
108
|
+
afferent_ratio: Union[int, float] = 1.,
|
108
109
|
allow_multi_conn: bool = True,
|
109
110
|
seed: Optional[int] = None,
|
110
111
|
name: Optional[str] = None,
|
@@ -116,11 +117,14 @@ class EventFixedNumConn(Module):
|
|
116
117
|
# network parameters
|
117
118
|
self.in_size = in_size
|
118
119
|
self.out_size = out_size
|
119
|
-
self.
|
120
|
-
assert
|
120
|
+
self.efferent_target = efferent_target
|
121
|
+
assert efferent_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
|
122
|
+
assert 0. <= afferent_ratio <= 1., 'Afferent ratio must be in [0, 1].'
|
121
123
|
if isinstance(conn_num, float):
|
122
124
|
assert 0. <= conn_num <= 1., 'Connection probability must be in [0, 1].'
|
123
|
-
conn_num = int(self.out_size[-1] * conn_num)
|
125
|
+
conn_num = (int(self.out_size[-1] * conn_num)
|
126
|
+
if efferent_target == 'post' else
|
127
|
+
int(self.in_size[-1] * conn_num))
|
124
128
|
assert isinstance(conn_num, int), 'Connection number must be an integer.'
|
125
129
|
self.conn_num = conn_num
|
126
130
|
self.seed = seed
|
@@ -128,14 +132,13 @@ class EventFixedNumConn(Module):
|
|
128
132
|
|
129
133
|
# connections
|
130
134
|
if self.conn_num >= 1:
|
131
|
-
if self.
|
135
|
+
if self.efferent_target == 'post':
|
132
136
|
n_post = self.out_size[-1]
|
133
137
|
n_pre = self.in_size[-1]
|
134
138
|
else:
|
135
139
|
n_post = self.in_size[-1]
|
136
140
|
n_pre = self.out_size[-1]
|
137
141
|
|
138
|
-
# indices of post connected neurons
|
139
142
|
with jax.ensure_compile_time_eval():
|
140
143
|
if allow_multi_conn:
|
141
144
|
rng = np.random if seed is None else np.random.RandomState(seed)
|
@@ -143,15 +146,83 @@ class EventFixedNumConn(Module):
|
|
143
146
|
else:
|
144
147
|
indices = init_indices_without_replace(self.conn_num, n_pre, n_post, seed, conn_init)
|
145
148
|
indices = u.math.asarray(indices, dtype=environ.ditype())
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
149
|
+
|
150
|
+
if afferent_ratio == 1.:
|
151
|
+
conn_weight = u.math.asarray(init.param(conn_weight, (n_pre, self.conn_num), allow_none=False))
|
152
|
+
self.weight = param_type(conn_weight)
|
153
|
+
csr = (
|
154
|
+
brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
|
155
|
+
if self.efferent_target == 'post' else
|
156
|
+
brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
|
157
|
+
)
|
158
|
+
self.conn = csr
|
159
|
+
|
160
|
+
else:
|
161
|
+
self.pre_selected = np.random.random(n_pre) < afferent_ratio
|
162
|
+
indices = indices[self.pre_selected].flatten()
|
163
|
+
conn_weight = u.math.asarray(init.param(conn_weight, (indices.size,), allow_none=False))
|
164
|
+
self.weight = param_type(conn_weight)
|
165
|
+
indptr = (jnp.arange(1, n_pre + 1) * self.conn_num -
|
166
|
+
jnp.cumsum(~self.pre_selected) * self.conn_num)
|
167
|
+
indptr = jnp.insert(indptr, 0, 0) # insert 0 at the beginning
|
168
|
+
csr = (
|
169
|
+
brainevent.CSR((conn_weight, indices, indptr), shape=(n_pre, n_post))
|
170
|
+
if self.efferent_target == 'post' else
|
171
|
+
brainevent.CSC((conn_weight, indices, indptr), shape=(n_pre, n_post))
|
172
|
+
)
|
173
|
+
self.conn = csr
|
174
|
+
|
175
|
+
else:
|
176
|
+
conn_weight = u.math.asarray(init.param(conn_weight, (), allow_none=False))
|
177
|
+
self.weight = FakeState(conn_weight)
|
178
|
+
|
179
|
+
def update(self, x: jax.Array) -> Union[jax.Array, u.Quantity]:
|
180
|
+
if self.conn_num >= 1:
|
181
|
+
csr = self.conn.with_data(self.weight.value)
|
182
|
+
return x @ csr
|
183
|
+
else:
|
184
|
+
weight = self.weight.value
|
185
|
+
r = u.math.zeros(x.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
|
186
|
+
r = u.maybe_decimal(u.Quantity(r, unit=u.get_unit(weight)))
|
187
|
+
return u.math.asarray(r, dtype=environ.dftype())
|
188
|
+
|
189
|
+
|
190
|
+
class EventFixedNumConn(FixedNumConn):
|
191
|
+
"""
|
192
|
+
The FixedProb module implements a fixed probability connection with CSR sparse data structure.
|
193
|
+
|
194
|
+
Parameters
|
195
|
+
----------
|
196
|
+
in_size : Size
|
197
|
+
Number of pre-synaptic neurons, i.e., input size.
|
198
|
+
out_size : Size
|
199
|
+
Number of post-synaptic neurons, i.e., output size.
|
200
|
+
conn_num : float, int
|
201
|
+
If it is a float, representing the probability of connection, i.e., connection probability.
|
202
|
+
|
203
|
+
If it is an integer, representing the number of connections.
|
204
|
+
conn_weight : float or callable or jax.Array or brainunit.Quantity
|
205
|
+
Maximum synaptic conductance, i.e., synaptic weight.
|
206
|
+
conn_target : str, optional
|
207
|
+
The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
|
208
|
+
a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
|
209
|
+
|
210
|
+
If 'pre', each post-synaptic neuron connects to a fixed number of pre-synaptic neurons.
|
211
|
+
conn_init : str, optional
|
212
|
+
The initialization method of the connection weight. Default is 'vmap', meaning that the connection weight
|
213
|
+
is initialized by parallelized across multiple threads.
|
214
|
+
|
215
|
+
If 'for_loop', the connection weight is initialized by a for loop.
|
216
|
+
allow_multi_conn : bool, optional
|
217
|
+
Whether multiple connections are allowed from a single pre-synaptic neuron.
|
218
|
+
Default is True, meaning that a value of ``a`` can be selected multiple times.
|
219
|
+
seed: int, optional
|
220
|
+
Random seed. Default is None. If None, the default random seed will be used.
|
221
|
+
name : str, optional
|
222
|
+
Name of the module.
|
223
|
+
"""
|
224
|
+
|
225
|
+
__module__ = 'brainstate.nn'
|
155
226
|
|
156
227
|
def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
|
157
228
|
if self.conn_num >= 1:
|
brainstate/nn/_linear_mv.py
CHANGED
@@ -19,7 +19,7 @@ import brainunit as u
|
|
19
19
|
import jax
|
20
20
|
|
21
21
|
from brainstate import init
|
22
|
-
|
22
|
+
import brainevent
|
23
23
|
from brainstate._state import ParamState
|
24
24
|
from brainstate.typing import Size, ArrayLike
|
25
25
|
from ._module import Module
|
brainstate/nn/_module.py
CHANGED
@@ -34,7 +34,7 @@ import numpy as np
|
|
34
34
|
from brainstate._state import State
|
35
35
|
from brainstate.graph import Node, states, nodes, flatten
|
36
36
|
from brainstate.mixin import ParamDescriber, ParamDesc
|
37
|
-
from brainstate.typing import PathParts
|
37
|
+
from brainstate.typing import PathParts, Size
|
38
38
|
from brainstate.util import FlattedDict, NestedDict, BrainStateError
|
39
39
|
|
40
40
|
# maximum integer
|
@@ -62,8 +62,8 @@ class Module(Node, ParamDesc):
|
|
62
62
|
|
63
63
|
__module__ = 'brainstate.nn'
|
64
64
|
|
65
|
-
_in_size: Optional[
|
66
|
-
_out_size: Optional[
|
65
|
+
_in_size: Optional[Size]
|
66
|
+
_out_size: Optional[Size]
|
67
67
|
_name: Optional[str]
|
68
68
|
|
69
69
|
if not TYPE_CHECKING:
|
@@ -87,7 +87,7 @@ class Module(Node, ParamDesc):
|
|
87
87
|
raise AttributeError('The name of the model is read-only.')
|
88
88
|
|
89
89
|
@property
|
90
|
-
def in_size(self) ->
|
90
|
+
def in_size(self) -> Size:
|
91
91
|
return self._in_size
|
92
92
|
|
93
93
|
@in_size.setter
|
@@ -98,7 +98,7 @@ class Module(Node, ParamDesc):
|
|
98
98
|
self._in_size = tuple(in_size)
|
99
99
|
|
100
100
|
@property
|
101
|
-
def out_size(self) ->
|
101
|
+
def out_size(self) -> Size:
|
102
102
|
return self._out_size
|
103
103
|
|
104
104
|
@out_size.setter
|