brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_neuron_test.py
DELETED
@@ -1,161 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
|
19
|
-
import unittest
|
20
|
-
|
21
|
-
import brainunit as u
|
22
|
-
import jax
|
23
|
-
import jax.numpy as jnp
|
24
|
-
|
25
|
-
import brainstate
|
26
|
-
from brainstate.nn import IF, LIF, ALIF
|
27
|
-
|
28
|
-
|
29
|
-
class TestNeuron(unittest.TestCase):
|
30
|
-
def setUp(self):
|
31
|
-
self.in_size = 10
|
32
|
-
self.batch_size = 5
|
33
|
-
self.time_steps = 100
|
34
|
-
|
35
|
-
def test_neuron_base_class(self):
|
36
|
-
with self.assertRaises(NotImplementedError):
|
37
|
-
brainstate.nn.Neuron(self.in_size).get_spike() # Neuron is an abstract base class
|
38
|
-
|
39
|
-
def generate_input(self):
|
40
|
-
return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mA
|
41
|
-
|
42
|
-
def test_if_neuron(self):
|
43
|
-
with brainstate.environ.context(dt=0.1 * u.ms):
|
44
|
-
neuron = IF(self.in_size)
|
45
|
-
inputs = self.generate_input()
|
46
|
-
|
47
|
-
# Test initialization
|
48
|
-
self.assertEqual(neuron.in_size, (self.in_size,))
|
49
|
-
self.assertEqual(neuron.out_size, (self.in_size,))
|
50
|
-
|
51
|
-
# Test forward pass
|
52
|
-
state = neuron.init_state(self.batch_size)
|
53
|
-
|
54
|
-
for t in range(self.time_steps):
|
55
|
-
out = neuron(inputs[t])
|
56
|
-
self.assertEqual(out.shape, (self.batch_size, self.in_size))
|
57
|
-
|
58
|
-
# Test spike generation
|
59
|
-
v = jnp.linspace(-1, 1, 100) * u.mV
|
60
|
-
spikes = neuron.get_spike(v)
|
61
|
-
self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1)))
|
62
|
-
|
63
|
-
def test_lif_neuron(self):
|
64
|
-
with brainstate.environ.context(dt=0.1 * u.ms):
|
65
|
-
tau = 20.0 * u.ms
|
66
|
-
neuron = LIF(self.in_size, tau=tau)
|
67
|
-
inputs = self.generate_input()
|
68
|
-
|
69
|
-
# Test initialization
|
70
|
-
self.assertEqual(neuron.in_size, (self.in_size,))
|
71
|
-
self.assertEqual(neuron.out_size, (self.in_size,))
|
72
|
-
self.assertEqual(neuron.tau, tau)
|
73
|
-
|
74
|
-
# Test forward pass
|
75
|
-
state = neuron.init_state(self.batch_size)
|
76
|
-
call = brainstate.compile.jit(neuron)
|
77
|
-
|
78
|
-
for t in range(self.time_steps):
|
79
|
-
out = call(inputs[t])
|
80
|
-
self.assertEqual(out.shape, (self.batch_size, self.in_size))
|
81
|
-
|
82
|
-
def test_alif_neuron(self):
|
83
|
-
tau = 20.0 * u.ms
|
84
|
-
tau_ada = 100.0 * u.ms
|
85
|
-
neuron = ALIF(self.in_size, tau=tau, tau_a=tau_ada)
|
86
|
-
inputs = self.generate_input()
|
87
|
-
|
88
|
-
# Test initialization
|
89
|
-
self.assertEqual(neuron.in_size, (self.in_size,))
|
90
|
-
self.assertEqual(neuron.out_size, (self.in_size,))
|
91
|
-
self.assertEqual(neuron.tau, tau)
|
92
|
-
self.assertEqual(neuron.tau_a, tau_ada)
|
93
|
-
|
94
|
-
# Test forward pass
|
95
|
-
neuron.init_state(self.batch_size)
|
96
|
-
call = brainstate.compile.jit(neuron)
|
97
|
-
with brainstate.environ.context(dt=0.1 * u.ms):
|
98
|
-
for t in range(self.time_steps):
|
99
|
-
out = call(inputs[t])
|
100
|
-
self.assertEqual(out.shape, (self.batch_size, self.in_size))
|
101
|
-
|
102
|
-
def test_spike_function(self):
|
103
|
-
for NeuronClass in [IF, LIF, ALIF]:
|
104
|
-
neuron = NeuronClass(self.in_size)
|
105
|
-
neuron.init_state()
|
106
|
-
v = jnp.linspace(-1, 1, self.in_size) * u.mV
|
107
|
-
spikes = neuron.get_spike(v)
|
108
|
-
self.assertTrue(jnp.all((spikes >= 0) & (spikes <= 1)))
|
109
|
-
|
110
|
-
def test_soft_reset(self):
|
111
|
-
for NeuronClass in [IF, LIF, ALIF]:
|
112
|
-
neuron = NeuronClass(self.in_size, spk_reset='soft')
|
113
|
-
inputs = self.generate_input()
|
114
|
-
state = neuron.init_state(self.batch_size)
|
115
|
-
call = brainstate.compile.jit(neuron)
|
116
|
-
with brainstate.environ.context(dt=0.1 * u.ms):
|
117
|
-
for t in range(self.time_steps):
|
118
|
-
out = call(inputs[t])
|
119
|
-
self.assertTrue(jnp.all(neuron.V.value <= neuron.V_th))
|
120
|
-
|
121
|
-
def test_hard_reset(self):
|
122
|
-
for NeuronClass in [IF, LIF, ALIF]:
|
123
|
-
neuron = NeuronClass(self.in_size, spk_reset='hard')
|
124
|
-
inputs = self.generate_input()
|
125
|
-
state = neuron.init_state(self.batch_size)
|
126
|
-
call = brainstate.compile.jit(neuron)
|
127
|
-
with brainstate.environ.context(dt=0.1 * u.ms):
|
128
|
-
for t in range(self.time_steps):
|
129
|
-
out = call(inputs[t])
|
130
|
-
self.assertTrue(jnp.all((neuron.V.value < neuron.V_th) | (neuron.V.value == 0. * u.mV)))
|
131
|
-
|
132
|
-
def test_detach_spike(self):
|
133
|
-
for NeuronClass in [IF, LIF, ALIF]:
|
134
|
-
neuron = NeuronClass(self.in_size)
|
135
|
-
inputs = self.generate_input()
|
136
|
-
state = neuron.init_state(self.batch_size)
|
137
|
-
call = brainstate.compile.jit(neuron)
|
138
|
-
with brainstate.environ.context(dt=0.1 * u.ms):
|
139
|
-
for t in range(self.time_steps):
|
140
|
-
out = call(inputs[t])
|
141
|
-
self.assertFalse(jax.tree_util.tree_leaves(out)[0].aval.weak_type)
|
142
|
-
|
143
|
-
def test_keep_size(self):
|
144
|
-
in_size = (2, 3)
|
145
|
-
for NeuronClass in [IF, LIF, ALIF]:
|
146
|
-
neuron = NeuronClass(in_size)
|
147
|
-
self.assertEqual(neuron.in_size, in_size)
|
148
|
-
self.assertEqual(neuron.out_size, in_size)
|
149
|
-
|
150
|
-
inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mA
|
151
|
-
state = neuron.init_state(self.batch_size)
|
152
|
-
call = brainstate.compile.jit(neuron)
|
153
|
-
with brainstate.environ.context(dt=0.1 * u.ms):
|
154
|
-
for t in range(self.time_steps):
|
155
|
-
out = call(inputs[t])
|
156
|
-
self.assertEqual(out.shape, (self.batch_size, *in_size))
|
157
|
-
|
158
|
-
|
159
|
-
if __name__ == '__main__':
|
160
|
-
with brainstate.environ.context(dt=0.1):
|
161
|
-
unittest.main()
|
brainstate/nn/_others.py
DELETED
@@ -1,46 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from functools import partial
|
17
|
-
|
18
|
-
import jax
|
19
|
-
import jax.numpy as jnp
|
20
|
-
|
21
|
-
from brainstate.typing import PyTree
|
22
|
-
|
23
|
-
__all__ = [
|
24
|
-
'clip_grad_norm',
|
25
|
-
]
|
26
|
-
|
27
|
-
|
28
|
-
def clip_grad_norm(
|
29
|
-
grad: PyTree,
|
30
|
-
max_norm: float | jax.Array,
|
31
|
-
norm_type: int | str | None = None
|
32
|
-
):
|
33
|
-
"""
|
34
|
-
Clips gradient norm of an iterable of parameters.
|
35
|
-
|
36
|
-
The norm is computed over all gradients together, as if they were
|
37
|
-
concatenated into a single vector. Gradients are modified in-place.
|
38
|
-
|
39
|
-
Args:
|
40
|
-
grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
|
41
|
-
max_norm (float): max norm of the gradients.
|
42
|
-
norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
|
43
|
-
"""
|
44
|
-
norm_fn = partial(jnp.linalg.norm, ord=norm_type)
|
45
|
-
norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
|
46
|
-
return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
|
brainstate/nn/_projection.py
DELETED
@@ -1,486 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from typing import Callable, Union
|
17
|
-
from typing import Optional
|
18
|
-
|
19
|
-
import brainevent
|
20
|
-
import brainunit as u
|
21
|
-
|
22
|
-
from brainstate._state import State
|
23
|
-
from brainstate.mixin import BindCondData, JointTypes
|
24
|
-
from brainstate.mixin import ParamDescriber, AlignPost
|
25
|
-
from brainstate.util.others import get_unique_name
|
26
|
-
from ._collective_ops import call_order
|
27
|
-
from ._dynamics import Dynamics, Projection, maybe_init_prefetch, Prefetch, PrefetchDelayAt
|
28
|
-
from ._module import Module
|
29
|
-
from ._stp import ShortTermPlasticity
|
30
|
-
from ._synapse import Synapse
|
31
|
-
from ._synouts import SynOut
|
32
|
-
|
33
|
-
__all__ = [
|
34
|
-
'AlignPostProj',
|
35
|
-
'DeltaProj',
|
36
|
-
'CurrentProj',
|
37
|
-
|
38
|
-
'align_pre_projection',
|
39
|
-
'align_post_projection',
|
40
|
-
]
|
41
|
-
|
42
|
-
|
43
|
-
def _check_modules(*modules):
|
44
|
-
# checking modules
|
45
|
-
for module in modules:
|
46
|
-
if not callable(module) and not isinstance(module, State):
|
47
|
-
raise TypeError(
|
48
|
-
f'The module should be a callable function or a brainstate.State, but got {module}.'
|
49
|
-
)
|
50
|
-
return tuple(modules)
|
51
|
-
|
52
|
-
|
53
|
-
def call_module(module, *args, **kwargs):
|
54
|
-
if callable(module):
|
55
|
-
return module(*args, **kwargs)
|
56
|
-
elif isinstance(module, State):
|
57
|
-
return module.value
|
58
|
-
else:
|
59
|
-
raise TypeError(
|
60
|
-
f'The module should be a callable function or a brainstate.State, but got {module}.'
|
61
|
-
)
|
62
|
-
|
63
|
-
|
64
|
-
def is_instance(x, cls) -> bool:
|
65
|
-
return isinstance(x, cls)
|
66
|
-
|
67
|
-
|
68
|
-
def get_post_repr(label, syn, out):
|
69
|
-
if label is None:
|
70
|
-
return f'{syn.identifier} // {out.identifier}'
|
71
|
-
else:
|
72
|
-
return f'{label}{syn.identifier} // {out.identifier}'
|
73
|
-
|
74
|
-
|
75
|
-
def align_post_add_bef_update(
|
76
|
-
syn_desc: ParamDescriber[AlignPost],
|
77
|
-
out_desc: ParamDescriber[BindCondData],
|
78
|
-
post: Dynamics,
|
79
|
-
proj_name: str,
|
80
|
-
label: str,
|
81
|
-
):
|
82
|
-
# synapse and output initialization
|
83
|
-
_post_repr = get_post_repr(label, syn_desc, out_desc)
|
84
|
-
if not post._has_before_update(_post_repr):
|
85
|
-
syn_cls = syn_desc()
|
86
|
-
out_cls = out_desc()
|
87
|
-
|
88
|
-
# synapse and output initialization
|
89
|
-
post.add_current_input(proj_name, out_cls, label=label)
|
90
|
-
post._add_before_update(_post_repr, _AlignPost(syn_cls, out_cls))
|
91
|
-
syn = post._get_before_update(_post_repr).syn
|
92
|
-
out = post._get_before_update(_post_repr).out
|
93
|
-
return syn, out
|
94
|
-
|
95
|
-
|
96
|
-
class _AlignPost(Module):
|
97
|
-
def __init__(
|
98
|
-
self,
|
99
|
-
syn: Dynamics,
|
100
|
-
out: BindCondData
|
101
|
-
):
|
102
|
-
super().__init__()
|
103
|
-
self.syn = syn
|
104
|
-
self.out = out
|
105
|
-
|
106
|
-
def update(self, *args, **kwargs):
|
107
|
-
self.out.bind_cond(self.syn(*args, **kwargs))
|
108
|
-
|
109
|
-
|
110
|
-
class AlignPostProj(Projection):
|
111
|
-
"""
|
112
|
-
Align-post projection of the neural network.
|
113
|
-
|
114
|
-
|
115
|
-
Examples
|
116
|
-
--------
|
117
|
-
|
118
|
-
Here is an example of using the `AlignPostProj` to create a synaptic projection.
|
119
|
-
Note that this projection needs the manual input of pre-synaptic spikes.
|
120
|
-
|
121
|
-
>>> import brainstate
|
122
|
-
>>> import brainunit as u
|
123
|
-
>>> n_exc = 3200
|
124
|
-
>>> n_inh = 800
|
125
|
-
>>> num = n_exc + n_inh
|
126
|
-
>>> pop = brainstate.nn.LIFRef(
|
127
|
-
... num,
|
128
|
-
... V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
|
129
|
-
... tau=20. * u.ms, tau_ref=5. * u.ms,
|
130
|
-
... V_initializer=brainstate.init.Normal(-55., 2., unit=u.mV)
|
131
|
-
... )
|
132
|
-
>>> pop.init_state()
|
133
|
-
>>> E = brainstate.nn.AlignPostProj(
|
134
|
-
... comm=brainstate.nn.FixedNumConn(n_exc, num, prob=80 / num, weight=1.62 * u.mS),
|
135
|
-
... syn=brainstate.nn.Expon.desc(num, tau=5. * u.ms),
|
136
|
-
... out=brainstate.nn.CUBA.desc(scale=u.volt),
|
137
|
-
... post=pop
|
138
|
-
... )
|
139
|
-
>>> exe_current = E(pop.get_spike())
|
140
|
-
|
141
|
-
"""
|
142
|
-
__module__ = 'brainstate.nn'
|
143
|
-
|
144
|
-
def __init__(
|
145
|
-
self,
|
146
|
-
*modules,
|
147
|
-
comm: Callable,
|
148
|
-
syn: Union[ParamDescriber[AlignPost], AlignPost],
|
149
|
-
out: Union[ParamDescriber[SynOut], SynOut],
|
150
|
-
post: Dynamics,
|
151
|
-
label: Optional[str] = None,
|
152
|
-
):
|
153
|
-
super().__init__(name=get_unique_name(self.__class__.__name__))
|
154
|
-
|
155
|
-
# checking modules
|
156
|
-
self.modules = _check_modules(*modules)
|
157
|
-
|
158
|
-
# checking communication model
|
159
|
-
if not callable(comm):
|
160
|
-
raise TypeError(
|
161
|
-
f'The communication should be an instance of callable function, but got {comm}.'
|
162
|
-
)
|
163
|
-
|
164
|
-
# checking synapse and output models
|
165
|
-
if is_instance(syn, ParamDescriber[AlignPost]):
|
166
|
-
if not is_instance(out, ParamDescriber[SynOut]):
|
167
|
-
if is_instance(out, ParamDescriber):
|
168
|
-
raise TypeError(
|
169
|
-
f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
|
170
|
-
f'the synapse is an instance of {AlignPost}, but got {out}.'
|
171
|
-
)
|
172
|
-
raise TypeError(
|
173
|
-
f'The output should be an instance of describer {ParamDescriber[SynOut]} when '
|
174
|
-
f'the synapse is a describer, but we got {out}.'
|
175
|
-
)
|
176
|
-
merging = True
|
177
|
-
else:
|
178
|
-
if is_instance(syn, ParamDescriber):
|
179
|
-
raise TypeError(
|
180
|
-
f'The synapse should be an instance of describer {ParamDescriber[AlignPost]}, but got {syn}.'
|
181
|
-
)
|
182
|
-
if not is_instance(out, SynOut):
|
183
|
-
raise TypeError(
|
184
|
-
f'The output should be an instance of {SynOut} when the synapse is '
|
185
|
-
f'not a describer, but we got {out}.'
|
186
|
-
)
|
187
|
-
merging = False
|
188
|
-
self.merging = merging
|
189
|
-
|
190
|
-
# checking post model
|
191
|
-
if not is_instance(post, Dynamics):
|
192
|
-
raise TypeError(
|
193
|
-
f'The post should be an instance of {Dynamics}, but got {post}.'
|
194
|
-
)
|
195
|
-
|
196
|
-
if merging:
|
197
|
-
# synapse and output initialization
|
198
|
-
syn, out = align_post_add_bef_update(syn_desc=syn,
|
199
|
-
out_desc=out,
|
200
|
-
post=post,
|
201
|
-
proj_name=self.name,
|
202
|
-
label=label)
|
203
|
-
else:
|
204
|
-
post.add_current_input(self.name, out)
|
205
|
-
|
206
|
-
# references
|
207
|
-
self.comm = comm
|
208
|
-
self.syn: JointTypes[Dynamics, AlignPost] = syn
|
209
|
-
self.out: BindCondData = out
|
210
|
-
self.post: Dynamics = post
|
211
|
-
|
212
|
-
@call_order(2)
|
213
|
-
def init_state(self, *args, **kwargs):
|
214
|
-
for module in self.modules:
|
215
|
-
maybe_init_prefetch(module, *args, **kwargs)
|
216
|
-
|
217
|
-
def update(self, *args):
|
218
|
-
# call all modules
|
219
|
-
for module in self.modules:
|
220
|
-
x = call_module(module, *args)
|
221
|
-
args = (x,)
|
222
|
-
# communication module
|
223
|
-
x = self.comm(*args)
|
224
|
-
# add synapse input
|
225
|
-
self.syn.add_delta_input(self.name, x)
|
226
|
-
if not self.merging:
|
227
|
-
# synapse and output interaction
|
228
|
-
conductance = self.syn()
|
229
|
-
self.out.bind_cond(conductance)
|
230
|
-
|
231
|
-
|
232
|
-
class DeltaProj(Projection):
|
233
|
-
"""
|
234
|
-
Delta-based projection of the neural network.
|
235
|
-
|
236
|
-
This projection directly applies delta inputs to post-synaptic neurons without intervening
|
237
|
-
synaptic dynamics. It processes inputs through optional prefetch modules, applies a communication model,
|
238
|
-
and adds the result directly as a delta input to the post-synaptic population.
|
239
|
-
|
240
|
-
Parameters
|
241
|
-
----------
|
242
|
-
*prefetch : State or callable
|
243
|
-
Optional prefetch modules to process input before communication.
|
244
|
-
comm : callable
|
245
|
-
Communication model that determines how signals are transmitted.
|
246
|
-
post : Dynamics
|
247
|
-
Post-synaptic neural population to receive the delta inputs.
|
248
|
-
label : Optional[str], default=None
|
249
|
-
Optional label for the projection to identify it in the post-synaptic population.
|
250
|
-
|
251
|
-
Examples
|
252
|
-
--------
|
253
|
-
>>> import brainstate
|
254
|
-
>>> import brainunit as u
|
255
|
-
>>> n_neurons = 100
|
256
|
-
>>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
|
257
|
-
>>> pop.init_state()
|
258
|
-
>>> delta_input = brainstate.nn.DeltaProj(
|
259
|
-
... comm=lambda x: x * 10.0*u.mV,
|
260
|
-
... post=pop
|
261
|
-
... )
|
262
|
-
>>> delta_input(1.0) # Apply voltage increment directly
|
263
|
-
"""
|
264
|
-
__module__ = 'brainstate.nn'
|
265
|
-
|
266
|
-
def __init__(self, *prefetch, comm: Callable, post: Dynamics, label=None):
|
267
|
-
super().__init__(name=get_unique_name(self.__class__.__name__))
|
268
|
-
|
269
|
-
self.label = label
|
270
|
-
|
271
|
-
# checking modules
|
272
|
-
self.prefetches = _check_modules(*prefetch)
|
273
|
-
|
274
|
-
# checking communication model
|
275
|
-
if not callable(comm):
|
276
|
-
raise TypeError(
|
277
|
-
f'The communication should be an instance of callable function, but got {comm}.'
|
278
|
-
)
|
279
|
-
self.comm = comm
|
280
|
-
|
281
|
-
# post model
|
282
|
-
if not isinstance(post, Dynamics):
|
283
|
-
raise TypeError(
|
284
|
-
f'The post should be an instance of {Dynamics}, but got {post}.'
|
285
|
-
)
|
286
|
-
self.post = post
|
287
|
-
|
288
|
-
@call_order(2)
|
289
|
-
def init_state(self, *args, **kwargs):
|
290
|
-
for prefetch in self.prefetches:
|
291
|
-
maybe_init_prefetch(prefetch, *args, **kwargs)
|
292
|
-
|
293
|
-
def update(self, *x):
|
294
|
-
for module in self.prefetches:
|
295
|
-
x = (call_module(module, *x),)
|
296
|
-
assert len(x) == 1, f'The output of the modules should be a single value, but got {x}.'
|
297
|
-
x = self.comm(x[0])
|
298
|
-
self.post.add_delta_input(self.name, x, label=self.label)
|
299
|
-
|
300
|
-
|
301
|
-
class CurrentProj(Projection):
|
302
|
-
"""
|
303
|
-
Current-based projection of the neural network.
|
304
|
-
|
305
|
-
This projection directly modulates post-synaptic currents without separate synaptic dynamics.
|
306
|
-
It processes inputs through optional prefetch modules, applies a communication model,
|
307
|
-
and binds the result to the output model which is then added as a current input to the post-synaptic population.
|
308
|
-
|
309
|
-
Parameters
|
310
|
-
----------
|
311
|
-
*prefetch : State or callable
|
312
|
-
Optional prefetch modules to process input before communication.
|
313
|
-
The last element must be an instance of Prefetch or PrefetchDelayAt if any are provided.
|
314
|
-
comm : callable
|
315
|
-
Communication model that determines how signals are transmitted.
|
316
|
-
out : SynOut
|
317
|
-
Output model that converts communication results to post-synaptic currents.
|
318
|
-
post : Dynamics
|
319
|
-
Post-synaptic neural population to receive the currents.
|
320
|
-
|
321
|
-
Examples
|
322
|
-
--------
|
323
|
-
>>> import brainstate
|
324
|
-
>>> import brainunit as u
|
325
|
-
>>> n_neurons = 100
|
326
|
-
>>> pop = brainstate.nn.LIF(n_neurons, V_rest=-70*u.mV, V_threshold=-50*u.mV)
|
327
|
-
>>> pop.init_state()
|
328
|
-
>>> current_input = brainstate.nn.CurrentProj(
|
329
|
-
... comm=lambda x: x * 0.5,
|
330
|
-
... out=brainstate.nn.CUBA(scale=1.0*u.nA),
|
331
|
-
... post=pop
|
332
|
-
... )
|
333
|
-
>>> current_input(0.2) # Apply external current
|
334
|
-
"""
|
335
|
-
__module__ = 'brainstate.nn'
|
336
|
-
|
337
|
-
def __init__(
|
338
|
-
self,
|
339
|
-
*prefetch,
|
340
|
-
comm: Callable,
|
341
|
-
out: SynOut,
|
342
|
-
post: Dynamics,
|
343
|
-
):
|
344
|
-
super().__init__(name=get_unique_name(self.__class__.__name__))
|
345
|
-
|
346
|
-
# check prefetch
|
347
|
-
self.prefetch = prefetch
|
348
|
-
if len(self.prefetch) > 0 and not isinstance(prefetch[-1], (Prefetch, PrefetchDelayAt)):
|
349
|
-
raise TypeError(
|
350
|
-
f'The last element of prefetch should be an instance of {Prefetch} or {PrefetchDelayAt}, '
|
351
|
-
f'but got {prefetch[-1]}.'
|
352
|
-
)
|
353
|
-
|
354
|
-
# check out
|
355
|
-
if not isinstance(out, SynOut):
|
356
|
-
raise TypeError(f'The out should be a SynOut, but got {out}.')
|
357
|
-
self.out = out
|
358
|
-
|
359
|
-
# check post
|
360
|
-
if not isinstance(post, Dynamics):
|
361
|
-
raise TypeError(f'The post should be a Dynamics, but got {post}.')
|
362
|
-
self.post = post
|
363
|
-
post.add_current_input(self.name, out)
|
364
|
-
|
365
|
-
# output initialization
|
366
|
-
self.comm = comm
|
367
|
-
|
368
|
-
@call_order(2)
|
369
|
-
def init_state(self, *args, **kwargs):
|
370
|
-
for prefetch in self.prefetch:
|
371
|
-
maybe_init_prefetch(prefetch, *args, **kwargs)
|
372
|
-
|
373
|
-
def update(self, *x):
|
374
|
-
for prefetch in self.prefetch:
|
375
|
-
x = (call_module(prefetch, *x),)
|
376
|
-
x = self.comm(*x)
|
377
|
-
self.out.bind_cond(x)
|
378
|
-
|
379
|
-
|
380
|
-
class align_pre_projection(Projection):
|
381
|
-
"""
|
382
|
-
Represents a pre-synaptic alignment projection mechanism.
|
383
|
-
|
384
|
-
This class inherits from the `Projection` base class and is designed to
|
385
|
-
manage the pre-synaptic alignment process in neural network simulations.
|
386
|
-
It takes into account pre-synaptic dynamics, synaptic properties, delays,
|
387
|
-
communication functions, synaptic outputs, post-synaptic dynamics, and
|
388
|
-
short-term plasticity.
|
389
|
-
|
390
|
-
Attributes:
|
391
|
-
pre (Dynamics): The pre-synaptic dynamics object.
|
392
|
-
syn (Synapse): The synaptic object after pre-synaptic alignment.
|
393
|
-
delay (u.Quantity[u.second]): The output delay from the synapse.
|
394
|
-
projection (CurrentProj): The current projection object handling communication,
|
395
|
-
output, and post-synaptic dynamics.
|
396
|
-
stp (ShortTermPlasticity, optional): The short-term plasticity object,
|
397
|
-
defaults to None.
|
398
|
-
"""
|
399
|
-
|
400
|
-
def __init__(
|
401
|
-
self,
|
402
|
-
*spike_generator,
|
403
|
-
syn: Dynamics,
|
404
|
-
comm: Callable,
|
405
|
-
out: SynOut,
|
406
|
-
post: Dynamics,
|
407
|
-
stp: ShortTermPlasticity = None,
|
408
|
-
):
|
409
|
-
super().__init__()
|
410
|
-
|
411
|
-
self.spike_generator = _check_modules(*spike_generator)
|
412
|
-
self.projection = CurrentProj(comm=comm, out=out, post=post)
|
413
|
-
self.syn = syn
|
414
|
-
self.stp = stp
|
415
|
-
|
416
|
-
@call_order(2)
|
417
|
-
def init_state(self, *args, **kwargs):
|
418
|
-
for module in self.spike_generator:
|
419
|
-
maybe_init_prefetch(module, *args, **kwargs)
|
420
|
-
|
421
|
-
def update(self, *x):
|
422
|
-
for fun in self.spike_generator:
|
423
|
-
x = fun(*x)
|
424
|
-
if isinstance(x, (tuple, list)):
|
425
|
-
x = tuple(x)
|
426
|
-
else:
|
427
|
-
x = (x,)
|
428
|
-
assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
|
429
|
-
x = brainevent.BinaryArray(x[0]) # Ensure input is a BinaryFloat for spike generation
|
430
|
-
if self.stp is not None:
|
431
|
-
x = brainevent.MaskedFloat(self.stp(x)) # Ensure STP output is a MaskedFloat
|
432
|
-
x = self.syn(x) # Apply pre-synaptic alignment
|
433
|
-
return self.projection(x)
|
434
|
-
|
435
|
-
|
436
|
-
class align_post_projection(Projection):
|
437
|
-
"""
|
438
|
-
Represents a post-synaptic alignment projection mechanism.
|
439
|
-
|
440
|
-
This class inherits from the `Projection` base class and is designed to
|
441
|
-
manage the post-synaptic alignment process in neural network simulations.
|
442
|
-
It takes into account spike generators, communication functions, synaptic
|
443
|
-
properties, synaptic outputs, post-synaptic dynamics, and short-term plasticity.
|
444
|
-
|
445
|
-
Args:
|
446
|
-
*spike_generator: Callable(s) that generate spike events or transform input spikes.
|
447
|
-
comm (Callable): Communication function for the projection.
|
448
|
-
syn (Union[AlignPost, ParamDescriber[AlignPost]]): The post-synaptic alignment object or its parameter describer.
|
449
|
-
out (Union[SynOut, ParamDescriber[SynOut]]): The synaptic output object or its parameter describer.
|
450
|
-
post (Dynamics): The post-synaptic dynamics object.
|
451
|
-
stp (ShortTermPlasticity, optional): The short-term plasticity object, defaults to None.
|
452
|
-
|
453
|
-
"""
|
454
|
-
|
455
|
-
def __init__(
|
456
|
-
self,
|
457
|
-
*spike_generator,
|
458
|
-
comm: Callable,
|
459
|
-
syn: Union[AlignPost, ParamDescriber[AlignPost]],
|
460
|
-
out: Union[SynOut, ParamDescriber[SynOut]],
|
461
|
-
post: Dynamics,
|
462
|
-
stp: ShortTermPlasticity = None,
|
463
|
-
):
|
464
|
-
super().__init__()
|
465
|
-
|
466
|
-
self.spike_generator = _check_modules(*spike_generator)
|
467
|
-
self.projection = AlignPostProj(comm=comm, syn=syn, out=out, post=post)
|
468
|
-
self.stp = stp
|
469
|
-
|
470
|
-
@call_order(2)
|
471
|
-
def init_state(self, *args, **kwargs):
|
472
|
-
for module in self.spike_generator:
|
473
|
-
maybe_init_prefetch(module, *args, **kwargs)
|
474
|
-
|
475
|
-
def update(self, *x):
|
476
|
-
for fun in self.spike_generator:
|
477
|
-
x = fun(*x)
|
478
|
-
if isinstance(x, (tuple, list)):
|
479
|
-
x = tuple(x)
|
480
|
-
else:
|
481
|
-
x = (x,)
|
482
|
-
assert len(x) == 1, "Spike generator must return a single value or a tuple/list of values"
|
483
|
-
x = brainevent.BinaryArray(x[0]) # Ensure input is a BinaryFloat for spike generation
|
484
|
-
if self.stp is not None:
|
485
|
-
x = brainevent.MaskedFloat(self.stp(x)) # Ensure STP output is a MaskedFloat
|
486
|
-
return self.projection(x)
|