brainstate 0.1.2__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +0 -15
- brainstate/compile/_jit.py +14 -5
- brainstate/compile/_make_jaxpr.py +78 -22
- brainstate/compile/_make_jaxpr_test.py +13 -2
- brainstate/graph/_graph_node.py +1 -1
- brainstate/graph/_graph_operation.py +4 -4
- brainstate/mixin.py +30 -14
- brainstate/nn/__init__.py +84 -17
- brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
- brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +19 -3
- brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +6 -5
- brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +137 -21
- brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
- brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
- brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob.py} +96 -25
- brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
- brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
- brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +2 -2
- brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
- brainstate/nn/_module.py +5 -5
- brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
- brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
- brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
- brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +1 -1
- brainstate/nn/_projection.py +486 -0
- brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
- brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
- brainstate/nn/_stp.py +236 -0
- brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +19 -212
- brainstate/nn/_synaptic_projection.py +423 -0
- brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
- brainstate/surrogate.py +1 -1
- brainstate/typing.py +1 -1
- brainstate/util/__init__.py +14 -14
- brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/RECORD +61 -63
- brainstate/nn/_dyn_impl/__init__.py +0 -42
- brainstate/nn/_dynamics/__init__.py +0 -37
- brainstate/nn/_dynamics/_projection_base.py +0 -362
- brainstate/nn/_elementwise/__init__.py +0 -22
- brainstate/nn/_interaction/__init__.py +0 -41
- /brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +0 -0
- /brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +0 -0
- /brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +0 -0
- /brainstate/nn/{_elementwise/_elementwise_test.py → _elementwise_test.py} +0 -0
- /brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
- /brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -0
- /brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +0 -0
- /brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +0 -0
- /brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +0 -0
- /brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +0 -0
- /brainstate/util/{_caller.py → caller.py} +0 -0
- /brainstate/util/{_error.py → error.py} +0 -0
- /brainstate/util/{_others.py → others.py} +0 -0
- /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
- /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
- /brainstate/util/{_scaling.py → scaling.py} +0 -0
- /brainstate/util/{_struct.py → struct.py} +0 -0
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
- {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
@@ -16,19 +16,20 @@
|
|
16
16
|
|
17
17
|
from typing import Union, Callable, Optional
|
18
18
|
|
19
|
+
import brainevent
|
19
20
|
import brainunit as u
|
20
21
|
import jax
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import numpy as np
|
23
24
|
|
24
25
|
from brainstate import random, augment, environ, init
|
25
|
-
from brainstate.
|
26
|
-
from brainstate._state import ParamState
|
26
|
+
from brainstate._state import ParamState, FakeState
|
27
27
|
from brainstate.compile import for_loop
|
28
|
-
from brainstate.nn._module import Module
|
29
28
|
from brainstate.typing import Size, ArrayLike
|
29
|
+
from ._module import Module
|
30
30
|
|
31
31
|
__all__ = [
|
32
|
+
'FixedNumConn',
|
32
33
|
'EventFixedNumConn',
|
33
34
|
'EventFixedProb',
|
34
35
|
]
|
@@ -44,12 +45,11 @@ def init_indices_without_replace(
|
|
44
45
|
rng = random.default_rng(seed)
|
45
46
|
|
46
47
|
if method == 'vmap':
|
47
|
-
@augment.vmap
|
48
|
-
def rand_indices(
|
49
|
-
rng.set_key(key)
|
48
|
+
@augment.vmap(axis_size=n_pre)
|
49
|
+
def rand_indices():
|
50
50
|
return rng.choice(n_post, size=(conn_num,), replace=False)
|
51
51
|
|
52
|
-
return rand_indices(
|
52
|
+
return rand_indices()
|
53
53
|
|
54
54
|
elif method == 'for_loop':
|
55
55
|
return for_loop(
|
@@ -61,9 +61,9 @@ def init_indices_without_replace(
|
|
61
61
|
raise ValueError(f"Unknown method: {method}")
|
62
62
|
|
63
63
|
|
64
|
-
class
|
64
|
+
class FixedNumConn(Module):
|
65
65
|
"""
|
66
|
-
The
|
66
|
+
The ``FixedNumConn`` module implements a fixed probability connection with CSR sparse data structure.
|
67
67
|
|
68
68
|
Parameters
|
69
69
|
----------
|
@@ -77,7 +77,7 @@ class EventFixedNumConn(Module):
|
|
77
77
|
If it is an integer, representing the number of connections.
|
78
78
|
conn_weight : float or callable or jax.Array or brainunit.Quantity
|
79
79
|
Maximum synaptic conductance, i.e., synaptic weight.
|
80
|
-
|
80
|
+
efferent_target : str, optional
|
81
81
|
The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
|
82
82
|
a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
|
83
83
|
|
@@ -104,7 +104,8 @@ class EventFixedNumConn(Module):
|
|
104
104
|
out_size: Size,
|
105
105
|
conn_num: Union[int, float],
|
106
106
|
conn_weight: Union[Callable, ArrayLike],
|
107
|
-
|
107
|
+
efferent_target: str = 'post', # 'pre' or 'post'
|
108
|
+
afferent_ratio: Union[int, float] = 1.,
|
108
109
|
allow_multi_conn: bool = True,
|
109
110
|
seed: Optional[int] = None,
|
110
111
|
name: Optional[str] = None,
|
@@ -116,11 +117,14 @@ class EventFixedNumConn(Module):
|
|
116
117
|
# network parameters
|
117
118
|
self.in_size = in_size
|
118
119
|
self.out_size = out_size
|
119
|
-
self.
|
120
|
-
assert
|
120
|
+
self.efferent_target = efferent_target
|
121
|
+
assert efferent_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
|
122
|
+
assert 0. <= afferent_ratio <= 1., 'Afferent ratio must be in [0, 1].'
|
121
123
|
if isinstance(conn_num, float):
|
122
124
|
assert 0. <= conn_num <= 1., 'Connection probability must be in [0, 1].'
|
123
|
-
conn_num = int(self.out_size[-1] * conn_num)
|
125
|
+
conn_num = (int(self.out_size[-1] * conn_num)
|
126
|
+
if efferent_target == 'post' else
|
127
|
+
int(self.in_size[-1] * conn_num))
|
124
128
|
assert isinstance(conn_num, int), 'Connection number must be an integer.'
|
125
129
|
self.conn_num = conn_num
|
126
130
|
self.seed = seed
|
@@ -128,14 +132,13 @@ class EventFixedNumConn(Module):
|
|
128
132
|
|
129
133
|
# connections
|
130
134
|
if self.conn_num >= 1:
|
131
|
-
if self.
|
135
|
+
if self.efferent_target == 'post':
|
132
136
|
n_post = self.out_size[-1]
|
133
137
|
n_pre = self.in_size[-1]
|
134
138
|
else:
|
135
139
|
n_post = self.in_size[-1]
|
136
140
|
n_pre = self.out_size[-1]
|
137
141
|
|
138
|
-
# indices of post connected neurons
|
139
142
|
with jax.ensure_compile_time_eval():
|
140
143
|
if allow_multi_conn:
|
141
144
|
rng = np.random if seed is None else np.random.RandomState(seed)
|
@@ -143,15 +146,83 @@ class EventFixedNumConn(Module):
|
|
143
146
|
else:
|
144
147
|
indices = init_indices_without_replace(self.conn_num, n_pre, n_post, seed, conn_init)
|
145
148
|
indices = u.math.asarray(indices, dtype=environ.ditype())
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
149
|
+
|
150
|
+
if afferent_ratio == 1.:
|
151
|
+
conn_weight = u.math.asarray(init.param(conn_weight, (n_pre, self.conn_num), allow_none=False))
|
152
|
+
self.weight = param_type(conn_weight)
|
153
|
+
csr = (
|
154
|
+
brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
|
155
|
+
if self.efferent_target == 'post' else
|
156
|
+
brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
|
157
|
+
)
|
158
|
+
self.conn = csr
|
159
|
+
|
160
|
+
else:
|
161
|
+
self.pre_selected = np.random.random(n_pre) < afferent_ratio
|
162
|
+
indices = indices[self.pre_selected].flatten()
|
163
|
+
conn_weight = u.math.asarray(init.param(conn_weight, (indices.size,), allow_none=False))
|
164
|
+
self.weight = param_type(conn_weight)
|
165
|
+
indptr = (jnp.arange(1, n_pre + 1) * self.conn_num -
|
166
|
+
jnp.cumsum(~self.pre_selected) * self.conn_num)
|
167
|
+
indptr = jnp.insert(indptr, 0, 0) # insert 0 at the beginning
|
168
|
+
csr = (
|
169
|
+
brainevent.CSR((conn_weight, indices, indptr), shape=(n_pre, n_post))
|
170
|
+
if self.efferent_target == 'post' else
|
171
|
+
brainevent.CSC((conn_weight, indices, indptr), shape=(n_pre, n_post))
|
172
|
+
)
|
173
|
+
self.conn = csr
|
174
|
+
|
175
|
+
else:
|
176
|
+
conn_weight = u.math.asarray(init.param(conn_weight, (), allow_none=False))
|
177
|
+
self.weight = FakeState(conn_weight)
|
178
|
+
|
179
|
+
def update(self, x: jax.Array) -> Union[jax.Array, u.Quantity]:
|
180
|
+
if self.conn_num >= 1:
|
181
|
+
csr = self.conn.with_data(self.weight.value)
|
182
|
+
return x @ csr
|
183
|
+
else:
|
184
|
+
weight = self.weight.value
|
185
|
+
r = u.math.zeros(x.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
|
186
|
+
r = u.maybe_decimal(u.Quantity(r, unit=u.get_unit(weight)))
|
187
|
+
return u.math.asarray(r, dtype=environ.dftype())
|
188
|
+
|
189
|
+
|
190
|
+
class EventFixedNumConn(FixedNumConn):
|
191
|
+
"""
|
192
|
+
The FixedProb module implements a fixed probability connection with CSR sparse data structure.
|
193
|
+
|
194
|
+
Parameters
|
195
|
+
----------
|
196
|
+
in_size : Size
|
197
|
+
Number of pre-synaptic neurons, i.e., input size.
|
198
|
+
out_size : Size
|
199
|
+
Number of post-synaptic neurons, i.e., output size.
|
200
|
+
conn_num : float, int
|
201
|
+
If it is a float, representing the probability of connection, i.e., connection probability.
|
202
|
+
|
203
|
+
If it is an integer, representing the number of connections.
|
204
|
+
conn_weight : float or callable or jax.Array or brainunit.Quantity
|
205
|
+
Maximum synaptic conductance, i.e., synaptic weight.
|
206
|
+
conn_target : str, optional
|
207
|
+
The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
|
208
|
+
a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
|
209
|
+
|
210
|
+
If 'pre', each post-synaptic neuron connects to a fixed number of pre-synaptic neurons.
|
211
|
+
conn_init : str, optional
|
212
|
+
The initialization method of the connection weight. Default is 'vmap', meaning that the connection weight
|
213
|
+
is initialized by parallelized across multiple threads.
|
214
|
+
|
215
|
+
If 'for_loop', the connection weight is initialized by a for loop.
|
216
|
+
allow_multi_conn : bool, optional
|
217
|
+
Whether multiple connections are allowed from a single pre-synaptic neuron.
|
218
|
+
Default is True, meaning that a value of ``a`` can be selected multiple times.
|
219
|
+
seed: int, optional
|
220
|
+
Random seed. Default is None. If None, the default random seed will be used.
|
221
|
+
name : str, optional
|
222
|
+
Name of the module.
|
223
|
+
"""
|
224
|
+
|
225
|
+
__module__ = 'brainstate.nn'
|
155
226
|
|
156
227
|
def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
|
157
228
|
if self.conn_num >= 1:
|
@@ -20,12 +20,11 @@ import jax
|
|
20
20
|
import numpy as np
|
21
21
|
|
22
22
|
from brainstate import environ, init, random
|
23
|
-
from brainstate._state import ShortTermState
|
24
|
-
from brainstate._state import State, maybe_state
|
23
|
+
from brainstate._state import ShortTermState, State, maybe_state
|
25
24
|
from brainstate.compile import while_loop
|
26
|
-
from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
|
27
|
-
from brainstate.nn._module import Module
|
28
25
|
from brainstate.typing import ArrayLike, Size, DTypeLike
|
26
|
+
from ._dynamics import Dynamics, Prefetch
|
27
|
+
from ._module import Module
|
29
28
|
|
30
29
|
__all__ = [
|
31
30
|
'SpikeTime',
|
@@ -134,7 +133,7 @@ class PoissonSpike(Dynamics):
|
|
134
133
|
self.freqs = init.param(freqs, self.varshape, allow_none=False)
|
135
134
|
|
136
135
|
def update(self):
|
137
|
-
spikes = random.rand(self.varshape) <= (self.freqs * environ.get_dt())
|
136
|
+
spikes = random.rand(*self.varshape) <= (self.freqs * environ.get_dt())
|
138
137
|
spikes = u.math.asarray(spikes, dtype=self.spk_type)
|
139
138
|
return spikes
|
140
139
|
|
@@ -22,8 +22,8 @@ import jax.numpy as jnp
|
|
22
22
|
|
23
23
|
from brainstate import init, functional
|
24
24
|
from brainstate._state import ParamState
|
25
|
-
from brainstate.nn._module import Module
|
26
25
|
from brainstate.typing import ArrayLike, Size
|
26
|
+
from ._module import Module
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'Linear',
|
@@ -350,10 +350,7 @@ class OneToOne(Module):
|
|
350
350
|
self.weight = param_type(param)
|
351
351
|
|
352
352
|
def update(self, pre_val):
|
353
|
-
|
354
|
-
w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
|
355
|
-
post_val = pre_val * w_val
|
356
|
-
post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
|
353
|
+
post_val = pre_val * self.weight.value['weight']
|
357
354
|
if 'bias' in self.weight.value:
|
358
355
|
post_val = post_val + self.weight.value['bias']
|
359
356
|
return post_val
|
@@ -19,10 +19,10 @@ import brainunit as u
|
|
19
19
|
import jax
|
20
20
|
|
21
21
|
from brainstate import init
|
22
|
-
|
22
|
+
import brainevent
|
23
23
|
from brainstate._state import ParamState
|
24
|
-
from brainstate.nn._module import Module
|
25
24
|
from brainstate.typing import Size, ArrayLike
|
25
|
+
from ._module import Module
|
26
26
|
|
27
27
|
__all__ = [
|
28
28
|
'EventLinear',
|
@@ -16,11 +16,13 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
|
18
18
|
|
19
|
-
from .
|
20
|
-
from ._linear_mv import EventLinear
|
19
|
+
from ._synapse import Synapse
|
21
20
|
|
22
21
|
__all__ = [
|
23
|
-
'
|
24
|
-
'EventFixedProb',
|
25
|
-
'EventFixedNumConn',
|
22
|
+
'LongTermPlasticity',
|
26
23
|
]
|
24
|
+
|
25
|
+
|
26
|
+
class LongTermPlasticity(Synapse):
|
27
|
+
pass
|
28
|
+
|
brainstate/nn/_module.py
CHANGED
@@ -34,7 +34,7 @@ import numpy as np
|
|
34
34
|
from brainstate._state import State
|
35
35
|
from brainstate.graph import Node, states, nodes, flatten
|
36
36
|
from brainstate.mixin import ParamDescriber, ParamDesc
|
37
|
-
from brainstate.typing import PathParts
|
37
|
+
from brainstate.typing import PathParts, Size
|
38
38
|
from brainstate.util import FlattedDict, NestedDict, BrainStateError
|
39
39
|
|
40
40
|
# maximum integer
|
@@ -62,8 +62,8 @@ class Module(Node, ParamDesc):
|
|
62
62
|
|
63
63
|
__module__ = 'brainstate.nn'
|
64
64
|
|
65
|
-
_in_size: Optional[
|
66
|
-
_out_size: Optional[
|
65
|
+
_in_size: Optional[Size]
|
66
|
+
_out_size: Optional[Size]
|
67
67
|
_name: Optional[str]
|
68
68
|
|
69
69
|
if not TYPE_CHECKING:
|
@@ -87,7 +87,7 @@ class Module(Node, ParamDesc):
|
|
87
87
|
raise AttributeError('The name of the model is read-only.')
|
88
88
|
|
89
89
|
@property
|
90
|
-
def in_size(self) ->
|
90
|
+
def in_size(self) -> Size:
|
91
91
|
return self._in_size
|
92
92
|
|
93
93
|
@in_size.setter
|
@@ -98,7 +98,7 @@ class Module(Node, ParamDesc):
|
|
98
98
|
self._in_size = tuple(in_size)
|
99
99
|
|
100
100
|
@property
|
101
|
-
def out_size(self) ->
|
101
|
+
def out_size(self) -> Size:
|
102
102
|
return self._out_size
|
103
103
|
|
104
104
|
@out_size.setter
|
@@ -22,9 +22,9 @@ import jax
|
|
22
22
|
|
23
23
|
from brainstate import init, surrogate, environ
|
24
24
|
from brainstate._state import HiddenState, ShortTermState
|
25
|
-
from brainstate.nn._dynamics._dynamics_base import Dynamics
|
26
|
-
from brainstate.nn._exp_euler import exp_euler_step
|
27
25
|
from brainstate.typing import ArrayLike, Size
|
26
|
+
from ._dynamics import Dynamics
|
27
|
+
from ._exp_euler import exp_euler_step
|
28
28
|
|
29
29
|
__all__ = [
|
30
30
|
'Neuron', 'IF', 'LIF', 'LIFRef', 'ALIF',
|
@@ -22,8 +22,8 @@ import jax.numpy as jnp
|
|
22
22
|
|
23
23
|
from brainstate import environ, init
|
24
24
|
from brainstate._state import ParamState, BatchState
|
25
|
-
from brainstate.nn._module import Module
|
26
25
|
from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
|
26
|
+
from ._module import Module
|
27
27
|
|
28
28
|
__all__ = [
|
29
29
|
'BatchNorm0d',
|
@@ -103,7 +103,7 @@ class TestPool(parameterized.TestCase):
|
|
103
103
|
for target_size in [10, 9, 8, 7, 6]
|
104
104
|
)
|
105
105
|
def test_adaptive_pool1d(self, target_size):
|
106
|
-
from brainstate.nn.
|
106
|
+
from brainstate.nn._poolings import _adaptive_pool1d
|
107
107
|
|
108
108
|
arr = brainstate.random.rand(100)
|
109
109
|
op = jax.numpy.mean
|