brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -1,599 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
from typing import Optional, Union
|
18
|
-
|
19
|
-
from brainstate._module import (Module, DelayAccess, Projection,
|
20
|
-
ExtendedUpdateWithBA, ReceiveInputProj,
|
21
|
-
register_delay_of_target)
|
22
|
-
from brainstate._utils import set_module_as
|
23
|
-
from brainstate.mixin import (DelayedInitializer, BindCondData, UpdateReturn, Mode, JointTypes)
|
24
|
-
from ._utils import is_instance
|
25
|
-
|
26
|
-
__all__ = [
    'FullProjAlignPreSDMg',
    'FullProjAlignPreDSMg',
    'FullProjAlignPreSD',
    'FullProjAlignPreDS',
]
|
30
|
-
|
31
|
-
|
32
|
-
def align_pre_add_bef_update(
    syn_desc: DelayedInitializer,
    delay_at,
    delay_cls: ExtendedUpdateWithBA,
    proj_name: Optional[str] = None
):
    """Get, or lazily create, a synapse that is driven by the output of a delay.

    The synapse is registered as a *before-update* of ``delay_cls`` under the key
    ``'Delay(<delay_at>) // <syn_desc.identifier>'``. The first call with a given key
    builds the synapse from ``syn_desc``. It wraps the synapse, together with a
    ``DelayAccess`` into ``delay_cls``, in an ``_AlignPreMg`` module and registers
    it. Later calls with the same key reuse the already-registered synapse. This
    is how identical synapses reading the same delay are merged.

    Args:
        syn_desc: The delayed initializer of the synapse. Calling it builds the
            synapse instance; its ``identifier`` is part of the registration key.
        delay_at: The delay at which the synapse reads. It is formatted with ``str``
            into the registration key and forwarded to ``DelayAccess``.
        delay_cls: The delay module the synapse is attached to.
        proj_name: Forwarded to ``DelayAccess`` as ``delay_entry``.

    Returns:
        The synapse instance registered in ``delay_cls`` under the computed key.
    """
    _syn_id = f'Delay({str(delay_at)}) // {syn_desc.identifier}'
    if not delay_cls.has_before_update(_syn_id):
        # delay accessor for this entry
        delay_access = DelayAccess(delay_cls, delay_at, delay_entry=proj_name)
        # synapse instance
        syn_cls = syn_desc()
        # add to the "before updates" of the delay module
        delay_cls.add_before_update(_syn_id, _AlignPreMg(delay_access, syn_cls))
    syn = delay_cls.get_before_update(_syn_id).syn
    return syn
|
48
|
-
|
49
|
-
|
50
|
-
class _AlignPreMg(Module):
    """Apply a synapse to the value obtained from a delay accessor.

    ``align_pre_add_bef_update`` registers instances of this module as
    before-updates of a delay module.

    Args:
        access: A zero-argument callable that returns the delayed value.
        syn: The synapse module applied to that delayed value.
    """

    def __init__(self, access, syn):
        super().__init__()
        self.access = access
        self.syn = syn

    def update(self, *args, **kwargs):
        # Any arguments are accepted for call compatibility but are not used.
        delayed_value = self.access()
        return self.syn(delayed_value)
|
58
|
-
|
59
|
-
|
60
|
-
@set_module_as('brainstate.nn')
class FullProjAlignPreSDMg(Projection):
    """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging.

    The ``full-chain`` means that the model needs to provide all information needed for a projection,
    including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.

    The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the
    synapse states to the delay model, and finally computes the synaptic current.

    The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same
    parameters (such as time constants) will also share the same synaptic variables.

    Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg`` facilitates the event-driven computation.
    This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
    than the spiking. To facilitate the event-driven computation, please use align post projections.

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
                                                       syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                       delay=0.1,
                                                       comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                                                       out=bp.dyn.COBA(E=0.),
                                                       post=self.E)
                self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E,
                                                       syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                       delay=0.1,
                                                       comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                                                       out=bp.dyn.COBA(E=0.),
                                                       post=self.I)
                self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
                                                       syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                       delay=0.1,
                                                       comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                                                       out=bp.dyn.COBA(E=-80.),
                                                       post=self.E)
                self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I,
                                                       syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                       delay=0.1,
                                                       comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                                                       out=bp.dyn.COBA(E=-80.),
                                                       post=self.I)

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
        pre: The pre-synaptic neuron group. The synapse is attached to it as an after-update.
        syn: The delayed initializer of the synaptic dynamics. Synapses sharing the same
            ``identifier`` are built once per ``pre`` and reused by every projection.
        delay: The synaptic delay. ``None`` or a value that is not ``> 0`` disables the delay.
        comm: The synaptic communication.
        out: The synaptic output.
        post: The post-synaptic neuron group.
        out_label: The label under which ``out`` is registered in ``post``.
        name: str. The projection name.
        mode: Mode. The computing mode.
    """

    _invisible_nodes = ['pre', 'syn', 'delay', 'post']

    def __init__(
        self,
        pre: ExtendedUpdateWithBA,
        syn: DelayedInitializer[UpdateReturn],
        delay: Union[None, int, float],
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # synaptic models
        is_instance(pre, ExtendedUpdateWithBA)
        is_instance(syn, DelayedInitializer[UpdateReturn])
        is_instance(comm, Module)
        is_instance(out, BindCondData)
        is_instance(post, ReceiveInputProj)
        self.comm = comm
        self.out = out

        # synapse initialization: one synapse per identifier, shared by all projections from ``pre``
        if not pre.has_after_update(syn.identifier):
            syn_cls = syn()
            pre.add_after_update(syn.identifier, syn_cls)
        self.syn = pre.get_after_update(syn.identifier)

        # delay initialization
        if delay is not None and delay > 0.:
            delay_cls = register_delay_of_target(self.syn)
            # Register this projection's own delay entry. ``update`` reads it back via
            # ``self.delay.at(self.name)``; without this the delay length is never recorded.
            delay_cls.register_entry(self.name, delay)
            self.has_delay = True
            self.delay = delay_cls
        else:
            self.has_delay = False
            self.delay = None

        # output initialization
        post.add_input_fun(self.name, out, label=out_label)

        # references
        self.pre = pre
        self.post = post

    def update(self, x=None):
        """Compute the synaptic current and bind it to the synaptic output.

        Args:
            x: Optional input to ``comm``. When ``None``, the delayed synapse value
                (if a delay is configured) or the synapse's current return value is used.

        Returns:
            The output of ``comm``.
        """
        if x is None:
            if self.has_delay:
                x = self.delay.at(self.name)
            else:
                x = self.syn.update_return()
        current = self.comm(x)
        self.out.bind_cond(current)
        return current
|
198
|
-
|
199
|
-
|
200
|
-
@set_module_as('brainstate.nn')
class FullProjAlignPreDSMg(Projection):
    """Full-chain align-pre projection that updates delay first, then synapse, with merging.

    A ``full-chain`` projection is given every stage of the pathway:
    ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``.
    Compared with ``FullProjAlignPreSDMg``, the positions of ``delay`` and ``syn`` are swapped.

    ``align-pre`` means the synaptic variables have the same dimension as the pre-synaptic group.

    ``delay+synapse updating`` means the pre-synaptic output (typically spikes) is first pushed
    into the delay, the synapse states are then computed from the delayed signal, and the
    synaptic current is computed last.

    ``merging`` means all synapses share one delay model, and synapse models with identical
    parameters (e.g. time constants) share the same synaptic variables.

    Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` supports event-driven computation,
    because ``comm`` operates on the floating-point synapse state rather than on spikes.
    Use align-post projections when event-driven computation is needed.

    Example of an E/I balanced network:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E,
                                                       delay=0.1,
                                                       syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                       comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                                                       out=bp.dyn.COBA(E=0.),
                                                       post=self.E)
                self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E,
                                                       delay=0.1,
                                                       syn=bp.dyn.Expon.desc(size=ne, tau=5.),
                                                       comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                                                       out=bp.dyn.COBA(E=0.),
                                                       post=self.I)
                self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I,
                                                       delay=0.1,
                                                       syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                       comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                                                       out=bp.dyn.COBA(E=-80.),
                                                       post=self.E)
                self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I,
                                                       delay=0.1,
                                                       syn=bp.dyn.Expon.desc(size=ni, tau=10.),
                                                       comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                                                       out=bp.dyn.COBA(E=-80.),
                                                       post=self.I)

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
        pre: The pre-synaptic neuron group.
        delay: The synaptic delay. ``None`` or a value that is not ``> 0`` disables the delay.
        syn: The delayed initializer of the synaptic dynamics.
        comm: The synaptic communication.
        out: The synaptic output.
        post: The post-synaptic neuron group.
        out_label: The label under which ``out`` is registered in ``post``.
        name: str. The projection name.
        mode: Mode. The computing mode.
    """
    _invisible_nodes = ['pre', 'syn', 'delay', 'post']

    def __init__(
        self,
        pre: JointTypes[ExtendedUpdateWithBA, UpdateReturn],
        delay: Union[None, int, float],
        syn: DelayedInitializer[UpdateReturn],
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # validate every component of the chain
        is_instance(pre, JointTypes[ExtendedUpdateWithBA, UpdateReturn])
        # NOTE(review): the annotation is DelayedInitializer[UpdateReturn] and ``update`` calls
        # ``update_return`` on the built synapse, but this check only requires DelayedInitializer[Module].
        is_instance(syn, DelayedInitializer[Module])
        is_instance(comm, Module)
        is_instance(out, BindCondData)
        is_instance(post, ReceiveInputProj)
        self.comm = comm
        self.out = out

        if delay is None or not (delay > 0.):
            # No delay: the synapse is fed by ``pre`` directly and shared per identifier.
            self.has_delay = False
            self.delay = None
            syn_key = syn.identifier
            if not pre.has_after_update(syn_key):
                pre.add_after_update(syn_key, syn())
            self.syn = pre.get_after_update(syn_key)
        else:
            # With delay: the synapse reads the delayed ``pre`` output and is shared per (delay, identifier).
            shared_delay = register_delay_of_target(pre)
            self.has_delay = True
            self.delay = shared_delay
            self.syn = align_pre_add_bef_update(syn, delay, shared_delay, self.name)

        # hook the synaptic output into the post-synaptic group
        post.add_input_fun(self.name, out, label=out_label)

        # keep references to both ends of the projection
        self.pre = pre
        self.post = post

    def update(self):
        """Read the synapse output, project it through ``comm`` and bind it to ``out``."""
        current = self.comm(self.syn.update_return())
        self.out.bind_cond(current)
        return current
|
335
|
-
|
336
|
-
|
337
|
-
@set_module_as('brainstate.nn')
class FullProjAlignPreSD(Projection):
    """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating.

    The ``full-chain`` means that the model needs to provide all information needed for a projection,
    including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.

    The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the
    synapse states to the delay model, and finally computes the synaptic current.

    Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS`` facilitates the event-driven computation.
    This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
    than the spiking. To facilitate the event-driven computation, please use align post projections.

    Unlike the merging variants, ``syn`` here is a synapse *instance* (not a ``.desc(...)`` initializer).

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreSD(pre=self.E,
                                                     syn=bp.dyn.Expon(size=ne, tau=5.),
                                                     delay=0.1,
                                                     comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                                                     out=bp.dyn.COBA(E=0.),
                                                     post=self.E)
                self.E2I = bp.dyn.FullProjAlignPreSD(pre=self.E,
                                                     syn=bp.dyn.Expon(size=ne, tau=5.),
                                                     delay=0.1,
                                                     comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                                                     out=bp.dyn.COBA(E=0.),
                                                     post=self.I)
                self.I2E = bp.dyn.FullProjAlignPreSD(pre=self.I,
                                                     syn=bp.dyn.Expon(size=ni, tau=10.),
                                                     delay=0.1,
                                                     comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                                                     out=bp.dyn.COBA(E=-80.),
                                                     post=self.E)
                self.I2I = bp.dyn.FullProjAlignPreSD(pre=self.I,
                                                     syn=bp.dyn.Expon(size=ni, tau=10.),
                                                     delay=0.1,
                                                     comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                                                     out=bp.dyn.COBA(E=-80.),
                                                     post=self.I)

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
        pre: The pre-synaptic neuron group; its ``update_return()`` feeds the synapse.
        syn: The synaptic dynamics (an instance), applied to the pre-synaptic output.
        delay: The synaptic delay. ``None`` or a value that is not ``> 0`` disables the delay.
        comm: The synaptic communication.
        out: The synaptic output.
        post: The post-synaptic neuron group.
        out_label: The label under which ``out`` is registered in ``post``.
        name: str. The projection name.
        mode: Mode. The computing mode.
    """

    _invisible_nodes = ['pre', 'post']

    def __init__(
        self,
        pre: ExtendedUpdateWithBA,
        syn: UpdateReturn,
        delay: Union[None, int, float],
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # synaptic models
        is_instance(pre, ExtendedUpdateWithBA)
        is_instance(syn, UpdateReturn)
        is_instance(comm, Module)
        is_instance(out, BindCondData)
        is_instance(post, ReceiveInputProj)
        self.comm = comm
        self.syn = syn
        self.out = out

        # delay initialization: the delay stores the synapse output
        if delay is not None and delay > 0.:
            delay_cls = register_delay_of_target(syn)
            delay_cls.register_entry(self.name, delay)
            self.delay = delay_cls
        else:
            self.delay = None

        # output initialization
        post.add_input_fun(self.name, out, label=out_label)

        # references
        self.pre = pre
        self.post = post

    def update(self):
        """Advance the synapse, optionally pass it through the delay, and bind the current.

        Returns:
            The output of ``comm``.
        """
        syn_out = self.syn(self.pre.update_return())
        if self.delay is not None:
            # push the fresh synapse state into the delay, then read this projection's entry
            self.delay(syn_out)
            x = self.delay.at(self.name)
        else:
            x = syn_out
        current = self.comm(x)
        self.out.bind_cond(current)
        return current
|
467
|
-
|
468
|
-
|
469
|
-
@set_module_as('brainstate.nn')
class FullProjAlignPreDS(Projection):
    """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating.

    The ``full-chain`` means that the model needs to provide all information needed for a projection,
    including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``.
    Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged.

    The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.

    The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the
    spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current.

    Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates the event-driven computation.
    This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
    than the spiking. To facilitate the event-driven computation, please use align post projections.

    Unlike the merging variants, ``syn`` here is a synapse *instance* (not a ``.desc(...)`` initializer).

    To simulate an E/I balanced network model:

    .. code-block:: python

        class EINet(bp.DynSysGroup):
            def __init__(self):
                super().__init__()
                ne, ni = 3200, 800
                self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
                                       V_initializer=bp.init.Normal(-55., 2.))
                self.E2E = bp.dyn.FullProjAlignPreDS(pre=self.E,
                                                     delay=0.1,
                                                     syn=bp.dyn.Expon(size=ne, tau=5.),
                                                     comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6),
                                                     out=bp.dyn.COBA(E=0.),
                                                     post=self.E)
                self.E2I = bp.dyn.FullProjAlignPreDS(pre=self.E,
                                                     delay=0.1,
                                                     syn=bp.dyn.Expon(size=ne, tau=5.),
                                                     comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6),
                                                     out=bp.dyn.COBA(E=0.),
                                                     post=self.I)
                self.I2E = bp.dyn.FullProjAlignPreDS(pre=self.I,
                                                     delay=0.1,
                                                     syn=bp.dyn.Expon(size=ni, tau=10.),
                                                     comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7),
                                                     out=bp.dyn.COBA(E=-80.),
                                                     post=self.E)
                self.I2I = bp.dyn.FullProjAlignPreDS(pre=self.I,
                                                     delay=0.1,
                                                     syn=bp.dyn.Expon(size=ni, tau=10.),
                                                     comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7),
                                                     out=bp.dyn.COBA(E=-80.),
                                                     post=self.I)

            def update(self, inp):
                self.E2E()
                self.E2I()
                self.I2E()
                self.I2I()
                self.E(inp)
                self.I(inp)
                return self.E.spike

        model = EINet()
        indices = bm.arange(1000)
        spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
        bp.visualize.raster_plot(indices, spks, show=True)

    Args:
        pre: The pre-synaptic neuron group; its ``update_return()`` is the (possibly delayed) input.
        delay: The synaptic delay. ``None`` or a value that is not ``> 0`` disables the delay.
        syn: The synaptic dynamics (an instance), applied to the (possibly delayed) pre output.
        comm: The synaptic communication.
        out: The synaptic output.
        post: The post-synaptic neuron group.
        out_label: The label under which ``out`` is registered in ``post``.
        name: str. The projection name.
        mode: Mode. The computing mode.
    """

    _invisible_nodes = ['pre', 'post', 'delay']

    def __init__(
        self,
        pre: UpdateReturn,
        delay: Union[None, int, float],
        syn: Module,
        comm: Module,
        out: BindCondData,
        post: ReceiveInputProj,
        out_label: Optional[str] = None,
        name: Optional[str] = None,
        mode: Optional[Mode] = None,
    ):
        super().__init__(name=name, mode=mode)

        # synaptic models
        is_instance(pre, UpdateReturn)
        is_instance(syn, Module)
        is_instance(comm, Module)
        is_instance(out, BindCondData)
        is_instance(post, ReceiveInputProj)
        self.comm = comm
        self.syn = syn
        self.out = out

        # delay initialization: the delay stores the pre-synaptic output
        if delay is not None and delay > 0.:
            delay_cls = register_delay_of_target(pre)
            delay_cls.register_entry(self.name, delay)
            self.delay = delay_cls
        else:
            self.delay = None

        # output initialization
        post.add_input_fun(self.name, out, label=out_label)

        # references
        self.pre = pre
        self.post = post

    def update(self, x=None):
        """Compute the synaptic current from the (delayed) pre output and bind it.

        Args:
            x: Optional input to ``syn``. When ``None``, this projection's delay entry
                (if a delay is configured) or ``pre.update_return()`` is used.

        Returns:
            The output of ``comm``.
        """
        if x is None:
            if self.delay is not None:
                x = self.delay.at(self.name)
            else:
                x = self.pre.update_return()
        g = self.comm(self.syn(x))
        self.out.bind_cond(g)
        return g
|