brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,315 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
from typing import Optional
|
21
|
+
|
22
|
+
import brainunit as u
|
23
|
+
|
24
|
+
from brainstate import init, environ
|
25
|
+
from brainstate._state import ShortTermState, HiddenState
|
26
|
+
from brainstate.mixin import AlignPost
|
27
|
+
from brainstate.nn._dynamics._dynamics_base import Dynamics
|
28
|
+
from brainstate.nn._exp_euler import exp_euler_step
|
29
|
+
from brainstate.typing import ArrayLike, Size
|
30
|
+
|
31
|
+
__all__ = [
    'Synapse',
    'Expon',
    'STP',
    'STD',
    'AMPA',
    'GABAa',
]
|
34
|
+
|
35
|
+
|
36
|
+
class Synapse(Dynamics):
    """Common base class for the synapse models defined in this module.

    It adds no behavior of its own on top of ``Dynamics``; it only groups the
    synapse classes and reports them under the public ``brainstate.nn`` path.
    """
    __module__ = 'brainstate.nn'
|
41
|
+
|
42
|
+
|
43
|
+
class Expon(Synapse, AlignPost):
    r"""Exponential decay synapse model.

    The conductance ``g`` is integrated with ``exp_euler_step`` using the
    derivative ``sum_current_inputs(-g) / tau``, i.e. a decay toward zero with
    time constant ``tau`` plus whatever current inputs the module collects.
    Delta inputs are then applied through ``sum_delta_inputs``, and the optional
    ``x`` argument of ``update`` is added directly to ``g``.

    Args:
        in_size: Size. The shape of the synapse population.
        name: str, optional. The module name.
        tau: ArrayLike. The time constant of decay. Default 8 [ms].
        g_initializer: ArrayLike, Callable. Initializer of the conductance ``g``.
            Default zeros in [mS].
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        tau: ArrayLike = 8.0 * u.ms,
        g_initializer: ArrayLike = init.ZeroInit(unit=u.mS),
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau = init.param(tau, self.varshape)
        self.g_initializer = g_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        initial_g = init.param(self.g_initializer, self.varshape, batch_size)
        self.g = HiddenState(initial_g)

    def reset_state(self, batch_size: int = None, **kwargs):
        fresh_g = init.param(self.g_initializer, self.varshape, batch_size)
        self.g.value = fresh_g

    def update(self, x=None):
        def dg_dt(g):
            return self.sum_current_inputs(-g) / self.tau

        g_next = exp_euler_step(dg_dt, self.g.value)
        self.g.value = self.sum_delta_inputs(g_next)
        if x is not None:
            self.g.value += x
        return self.g.value
|
76
|
+
|
77
|
+
|
78
|
+
class STP(Synapse):
    r"""Synaptic output with short-term plasticity.

    Two per-synapse state variables are tracked: the utilization ``u`` and the
    fraction of available resources ``x``. Between spikes they relax as

    .. math::

        \frac{du}{dt} = \frac{U - u}{\tau_f}, \qquad
        \frac{dx}{dt} = \frac{1 - x}{\tau_d},

    so ``u`` rests at ``U`` (its initial value) and ``x`` rests at 1. On a
    pre-synaptic spike, ``u`` is increased by :math:`U(1 - u)` and then ``x``
    is decreased by :math:`u x` (the jump factors use the values held before
    the current step). ``update`` returns ``u * x * pre_spike``.

    Args:
        in_size: Size. The shape of the synapse population.
        name: str, optional. The module name.
        U: float, ArrayType, Callable. The fraction of resources used per action potential.
            Default 0.15.
        tau_f: float, ArrayType, Callable. The time constant of short-term facilitation.
            Default 1500 [ms].
        tau_d: float, ArrayType, Callable. The time constant of short-term depression.
            Default 200 [ms].
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        U: ArrayLike = 0.15,
        tau_f: ArrayLike = 1500. * u.ms,
        tau_d: ArrayLike = 200. * u.ms,
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau_f = init.param(tau_f, self.varshape)
        self.tau_d = init.param(tau_d, self.varshape)
        self.U = init.param(U, self.varshape)

    def init_state(self, batch_size: int = None, **kwargs):
        self.x = HiddenState(init.param(init.Constant(1.), self.varshape, batch_size))
        self.u = HiddenState(init.param(init.Constant(self.U), self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
        self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)

    def update(self, pre_spike):
        # Continuous relaxation between spikes. The facilitation term must be
        # ``(U - u) / tau_f``: the previous ``U - u / tau_f`` subtracted a
        # 1/time quantity from the dimensionless ``U`` (a unit error under
        # brainunit) and did not relax ``u`` toward its initial value ``U``.
        # Local names avoid shadowing the module-level ``u`` (brainunit).
        util = exp_euler_step(lambda uu: (self.U - uu) / self.tau_f, self.u.value)
        res = exp_euler_step(lambda xx: (1 - xx) / self.tau_d, self.x.value)

        # Spike-triggered jumps. Multiplying by ``pre_spike`` applies the jump
        # only where a spike occurred for boolean spikes, and scales it
        # linearly for real-valued spikes. The jump factors use the values
        # stored before this step (``self.u.value`` / ``self.x.value``).
        util = util + pre_spike * self.U * (1 - self.u.value)
        res = res - pre_spike * util * self.x.value

        self.u.value = util
        self.x.value = res
        return util * res * pre_spike
|
133
|
+
|
134
|
+
|
135
|
+
class STD(Synapse):
    r"""Synaptic output with short-term depression.

    The fraction of available resources ``x`` recovers toward 1 between spikes,

    .. math::

        \frac{dx}{dt} = \frac{1 - x}{\tau},

    and each pre-synaptic spike removes ``U`` times the resources that were
    available before the step. ``update`` returns ``x * pre_spike``.

    Args:
        in_size: Size. The shape of the synapse population.
        name: str, optional. The module name.
        tau: float, ArrayType, Callable. The time constant of recovery of the synaptic vesicles.
            Default 200 [ms].
        U: float, ArrayType, Callable. The fraction of resources used per action potential.
            Default 0.07.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        # synapse parameters
        tau: ArrayLike = 200. * u.ms,
        U: ArrayLike = 0.07,
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau = init.param(tau, self.varshape)
        self.U = init.param(U, self.varshape)

    def init_state(self, batch_size: int = None, **kwargs):
        full = init.param(init.Constant(1.), self.varshape, batch_size)
        self.x = HiddenState(full)

    def reset_state(self, batch_size: int = None, **kwargs):
        full = init.param(init.Constant(1.), self.varshape, batch_size)
        self.x.value = full

    def update(self, pre_spike):
        # Recovery toward 1, integrated from the stored value.
        recovered = exp_euler_step(lambda x: (1 - x) / self.tau, self.x.value)
        # Depletion uses the value stored before this step.
        depletion = pre_spike * self.U * self.x.value
        self.x.value = recovered - depletion
        return self.x.value * pre_spike
|
177
|
+
|
178
|
+
|
179
|
+
class AMPA(Synapse):
    r"""AMPA synapse model.

    **Model Descriptions**

    AMPA receptor is an ionotropic receptor, which is an ion channel.
    When it is bound by neurotransmitters, it will immediately open the
    ion channel, causing the change of membrane potential of postsynaptic neurons.

    A classical model is to use the Markov process to model ion channel switch.
    Here :math:`g` represents the probability of channel opening, :math:`1-g`
    represents the probability of ion channel closing, and :math:`\alpha` and
    :math:`\beta` are the transition rates. Because neurotransmitters can
    open ion channels, the transfer rate from :math:`1-g` to :math:`g`
    is affected by the concentration of neurotransmitters. We denote the concentration
    of neurotransmitters as :math:`[T]` and get the following Markov process.

    .. image:: ../../_static/synapse_markov.png
        :align: center

    Describing this process with a differential equation gives

    .. math::

        \frac{dg}{dt} = \alpha[T](1-g) - \beta g

    where :math:`\alpha [T]` denotes the transition rate from state :math:`(1-g)`
    to state :math:`(g)`, and :math:`\beta` represents the transition rate of
    the other direction. :math:`\alpha` is the binding constant and :math:`\beta`
    is the unbinding constant. :math:`[T]` is the neurotransmitter concentration:
    after a pre-synaptic spike it equals ``T`` for a duration of ``T_dur``
    (0.5 ms by default) and is zero otherwise.

    Moreover, the post-synaptic current on the post-synaptic neuron is formulated as

    .. math::

        I_{syn} = g_{max} g (V-E)

    where :math:`g_{max}` is the maximum conductance, and :math:`E` is the reversal
    potential. This class integrates only :math:`g`; ``update`` returns :math:`g`.

    .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations
           and implications for stimulus processing[J]. Proceedings of the
           National Academy of Sciences, 2012, 109(45): 18553-18558.

    Args:
        in_size: Size. The shape of the synapse population.
        name: str, optional. The module name.
        alpha: float, ArrayType, Callable. Binding constant. Default 0.98 [ms^-1 mM^-1].
        beta: float, ArrayType, Callable. Unbinding constant. Default 0.18 [ms^-1].
        T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by
            a pre-synaptic spike. Default 0.5 [mM].
        T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being
            triggered. Default 0.5 [ms].
        g_initializer: ArrayLike, Callable. Initializer of the gating variable ``g``.
            Default zeros.
    """

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        alpha: ArrayLike = 0.98 / (u.ms * u.mM),
        beta: ArrayLike = 0.18 / u.ms,
        T: ArrayLike = 0.5 * u.mM,
        T_dur: ArrayLike = 0.5 * u.ms,
        g_initializer: ArrayLike = init.ZeroInit(),
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.alpha = init.param(alpha, self.varshape)
        self.beta = init.param(beta, self.varshape)
        self.T = init.param(T, self.varshape)
        self.T_duration = init.param(T_dur, self.varshape)
        self.g_initializer = g_initializer

    def init_state(self, batch_size=None, **kwargs):
        # ``**kwargs`` matches the other synapse classes so that generic
        # state-initialization calls passing extra keywords do not fail here.
        self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
        # A far-past arrival time means "no recent spike", so [T] starts at 0.
        self.spike_arrival_time = ShortTermState(init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_size))

    def reset_state(self, batch_or_mode=None, **kwargs):
        # The other synapse classes name this argument ``batch_size``; honor it
        # when given by keyword instead of silently resetting to the unbatched shape.
        if batch_or_mode is None:
            batch_or_mode = kwargs.get('batch_size', None)
        self.g.value = init.param(self.g_initializer, self.varshape, batch_or_mode)
        self.spike_arrival_time.value = init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_or_mode)

    def dg(self, g, t, TT):
        # Right-hand side of dg/dt; ``t`` is accepted for the integrator's
        # calling convention but is not used.
        return self.alpha * TT * (1 - g) - self.beta * g

    def update(self, pre_spike):
        t = environ.get('t')
        # Record the current time wherever a spike arrives.
        self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
        # Transmitter concentration is T within T_duration of the last spike, else 0.
        TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
        self.g.value = exp_euler_step(self.dg, self.g.value, t, TT)
        return self.g.value
|
269
|
+
|
270
|
+
|
271
|
+
class GABAa(AMPA):
    r"""GABAa synapse model.

    **Model Descriptions**

    The GABAa synapse model uses the same kinetic equation as :class:`AMPA`,

    .. math::

        \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\
        I_{syn}&= - g_{max} g (V - E)

    but with the difference of:

    - Reversal potential of synapse :math:`E` is usually low, typically -80. mV
    - Activating rate constant :math:`\alpha=0.53\,\mathrm{ms^{-1}\,mM^{-1}}`
    - De-activating rate constant :math:`\beta=0.18\,\mathrm{ms^{-1}}`
    - Transmitter concentration :math:`[T]=1\,\mathrm{mM}` when synapse is
      triggered by a pre-synaptic spike, with the duration of 1 ms.

    .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity
           on the integrative properties of neocortical pyramidal neurons
           in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547.

    Args:
        in_size: Size. The shape of the synapse population.
        name: str, optional. The module name.
        alpha: float, ArrayType, Callable. Binding constant. Default 0.53 [ms^-1 mM^-1].
        beta: float, ArrayType, Callable. Unbinding constant. Default 0.18 [ms^-1].
        T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by
            a pre-synaptic spike. Default 1 [mM].
        T_dur: float, ArrayType, Callable. Transmitter concentration duration time
            after being triggered. Default 1 [ms].
        g_initializer: ArrayLike, Callable. Initializer of the gating variable ``g``.
            Default zeros (forwarded to :class:`AMPA`).
    """

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        alpha: ArrayLike = 0.53 / (u.ms * u.mM),
        beta: ArrayLike = 0.18 / u.ms,
        T: ArrayLike = 1.0 * u.mM,
        T_dur: ArrayLike = 1.0 * u.ms,
        g_initializer: ArrayLike = init.ZeroInit(),
    ):
        super().__init__(
            alpha=alpha,
            beta=beta,
            T=T,
            T_dur=T_dur,
            name=name,
            in_size=in_size,
            g_initializer=g_initializer,
        )
|
@@ -0,0 +1,132 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import unittest
|
19
|
+
|
20
|
+
import brainunit as u
|
21
|
+
import jax.numpy as jnp
|
22
|
+
import pytest
|
23
|
+
|
24
|
+
import brainstate as bst
|
25
|
+
from brainstate.nn import Expon, STP, STD
|
26
|
+
|
27
|
+
|
28
|
+
class TestSynapse(unittest.TestCase):
    def setUp(self):
        self.in_size = 10
        self.batch_size = 5
        self.time_steps = 100

    def generate_input(self):
        return bst.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS

    def test_expon_synapse(self):
        tau = 20.0 * u.ms
        synapse = Expon(self.in_size, tau=tau)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau, tau)

        # Test forward pass
        synapse.init_state(self.batch_size)
        call = bst.compile.jit(synapse)
        with bst.environ.context(dt=0.1 * u.ms):
            for t in range(self.time_steps):
                out = call(inputs[t])
                self.assertEqual(out.shape, (self.batch_size, self.in_size))

            # Test exponential decay. Keep these calls inside the same ``dt``
            # context as the forward pass rather than depending on whatever
            # environment happens to be active when the test runs.
            constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
            out1 = call(constant_input)
            out2 = call(constant_input)
            self.assertTrue(jnp.all(out2 > out1))  # Output should increase with constant input

    @pytest.mark.skip(reason="Not implemented yet")
    def test_stp_synapse(self):
        tau_d = 200.0 * u.ms
        tau_f = 20.0 * u.ms
        U = 0.2
        synapse = STP(self.in_size, tau_d=tau_d, tau_f=tau_f, U=U)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau_d, tau_d)
        self.assertEqual(synapse.tau_f, tau_f)
        self.assertEqual(synapse.U, U)

        # Test forward pass
        synapse.init_state(self.batch_size)
        call = bst.compile.jit(synapse)
        with bst.environ.context(dt=0.1 * u.ms):
            for t in range(self.time_steps):
                out = call(inputs[t])
                self.assertEqual(out.shape, (self.batch_size, self.in_size))

            # Test short-term plasticity
            constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
            out1 = call(constant_input)
            out2 = call(constant_input)
            self.assertTrue(jnp.any(out2 != out1))  # Output should change due to STP

    @pytest.mark.skip(reason="Not implemented yet")
    def test_std_synapse(self):
        # ``tau`` carries a time unit, like the ``STD`` default (200 ms).
        tau = 200.0 * u.ms
        U = 0.2
        synapse = STD(self.in_size, tau=tau, U=U)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau, tau)
        self.assertEqual(synapse.U, U)

        # Test forward pass
        synapse.init_state(self.batch_size)
        with bst.environ.context(dt=0.1 * u.ms):
            for t in range(self.time_steps):
                out = synapse(inputs[t])
                self.assertEqual(out.shape, (self.batch_size, self.in_size))

            # Test short-term depression
            constant_input = jnp.ones((self.batch_size, self.in_size))
            out1 = synapse(constant_input)
            out2 = synapse(constant_input)
            self.assertTrue(jnp.all(out2 < out1))  # Output should decrease due to STD

    def test_keep_size(self):
        in_size = (2, 3)
        for SynapseClass in [Expon, ]:
            synapse = SynapseClass(in_size)
            self.assertEqual(synapse.in_size, in_size)
            self.assertEqual(synapse.out_size, in_size)

            inputs = bst.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
            synapse.init_state(self.batch_size)
            call = bst.compile.jit(synapse)
            with bst.environ.context(dt=0.1 * u.ms):
                for t in range(self.time_steps):
                    out = call(inputs[t])
                    self.assertEqual(out.shape, (self.batch_size, *in_size))
|
128
|
+
|
129
|
+
|
130
|
+
if __name__ == '__main__':
    # Use a unit-carrying time step, consistent with the ``dt=0.1 * u.ms``
    # contexts used inside the tests (synapse time constants carry ms units).
    with bst.environ.context(dt=0.1 * u.ms):
        unittest.main()
|
@@ -0,0 +1,154 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
from typing import Union, Optional, Sequence, Callable
|
18
|
+
|
19
|
+
import brainunit as u
|
20
|
+
|
21
|
+
from brainstate import environ, init, random
|
22
|
+
from brainstate._state import ShortTermState
|
23
|
+
from brainstate.compile import while_loop
|
24
|
+
from brainstate.nn._dynamics._dynamics_base import Dynamics
|
25
|
+
from brainstate.typing import ArrayLike, Size, DTypeLike
|
26
|
+
|
27
|
+
__all__ = ['SpikeTime', 'PoissonSpike', 'PoissonEncoder']
|
32
|
+
|
33
|
+
|
34
|
+
class SpikeTime(Dynamics):
    """The input neuron group characterized by spikes emitting at given times.

    Each event ``k`` makes neuron ``indices[k]`` fire at time ``times[k]``.
    ``times`` are compared directly against ``environ.get('t')``, so they must
    be expressed in the same unit as the simulation time.

    >>> # Get 2 neurons; neuron 0 fires at 10 ms and neuron 1 fires at 20 ms.
    >>> SpikeTime(2, indices=[0, 1], times=[10, 20])
    >>> # or
    >>> # Get 2 neurons, the neuron 0 fires spikes at 10 ms and 20 ms.
    >>> SpikeTime(2, times=[10, 20], indices=[0, 0])
    >>> # or
    >>> # Get 2 neurons, neuron 0 fires at 10 ms and 30 ms, neuron 1 fires at 20 ms.
    >>> SpikeTime(2, times=[10, 20, 30], indices=[0, 1, 0])
    >>> # or
    >>> # Get 2 neurons; at 10 ms, neuron 0 fires; at 20 ms, neuron 0 and 1 fire;
    >>> # at 30 ms, neuron 1 fires.
    >>> SpikeTime(2, times=[10, 20, 20, 30], indices=[0, 0, 1, 1])

    Parameters
    ----------
    in_size : int, tuple, list
        The neuron group geometry.
    indices : list, tuple, ArrayType
        The neuron indices at each time point to emit spikes.
    times : list, tuple, ArrayType
        The time points which generate the spikes.
    spk_type : DTypeLike
        The dtype of the emitted spike array. Default ``bool``.
    name : str, optional
        The name of the dynamic system.
    need_sort : bool
        Whether to sort the events by time. Events must be in ascending time
        order for ``update`` to emit them; pass ``False`` only if they already are.
    """

    def __init__(
        self,
        in_size: Size,
        indices: Union[Sequence, ArrayLike],
        times: Union[Sequence, ArrayLike],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
        need_sort: bool = True,
    ):
        super().__init__(in_size=in_size, name=name)

        # parameters
        if len(indices) != len(times):
            raise ValueError(f'The length of "indices" and "times" must be the same. '
                             f'However, we got {len(indices)} != {len(times)}.')
        self.num_times = len(times)
        self.spk_type = spk_type

        # data about times and indices
        self.times = u.math.asarray(times)
        self.indices = u.math.asarray(indices, dtype=environ.ditype())
        if need_sort:
            sort_idx = u.math.argsort(self.times)
            self.indices = self.indices[sort_idx]
            self.times = self.times[sort_idx]

    def init_state(self, *args, **kwargs):
        # Pointer to the next event that has not been emitted yet. It must
        # start at 0: with -1, ``self.times[-1]`` wraps around to the LAST
        # event time, so nothing fires before that time and the first emitted
        # spike lands on ``indices[-1]``.
        self.i = ShortTermState(0)

    def reset_state(self, batch_size=None, **kwargs):
        self.i.value = 0

    def update(self):
        spike = u.math.zeros(self.varshape, dtype=self.spk_type)
        if self.num_times == 0:
            # No events are scheduled; also avoids indexing an empty array
            # inside the loop condition.
            return spike

        t = environ.get('t')

        def _cond_fun(spikes):
            i = self.i.value
            # When ``i == num_times`` the index into ``self.times`` is out of
            # range, but the first operand is already False, so the value
            # read there does not affect the result.
            return u.math.logical_and(i < self.num_times, t >= self.times[i])

        def _body_fun(spikes):
            i = self.i.value
            spikes = spikes.at[..., self.indices[i]].set(True)
            self.i.value += 1
            return spikes

        # Emit every event whose time has been reached, advancing the pointer.
        return while_loop(_cond_fun, _body_fun, spike)
|
110
|
+
|
111
|
+
|
112
|
+
class PoissonSpike(Dynamics):
    """
    Poisson Neuron Group with fixed firing rates.

    In every time step, each neuron fires independently with probability
    ``freqs * dt``, where ``dt`` comes from ``environ.get_dt()``.

    Args:
        in_size: Size. The neuron group geometry.
        freqs: ArrayLike, Callable. The firing rates, initialized to ``in_size``.
        spk_type: DTypeLike. The dtype of the emitted spike array. Default ``bool``.
        name: str, optional. The name of the dynamic system.
    """

    def __init__(
        self,
        in_size: Size,
        freqs: Union[ArrayLike, Callable],
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)

        self.spk_type = spk_type

        # parameters
        self.freqs = init.param(freqs, self.varshape, allow_none=False)

    def update(self):
        # ``random.rand`` takes the dimensions as separate arguments (as in
        # ``PoissonEncoder``); passing the shape tuple as a single argument is
        # wrong. A draw from [0, 1) is ``< p`` with probability exactly ``p``,
        # so a rate of 0 never fires (``<=`` would fire on an exact 0 draw).
        spikes = random.rand(*self.varshape) < (self.freqs * environ.get_dt())
        spikes = u.math.asarray(spikes, dtype=self.spk_type)
        return spikes
|
135
|
+
|
136
|
+
|
137
|
+
class PoissonEncoder(Dynamics):
    """
    Poisson encoder that turns per-step firing rates into spikes.

    Unlike ``PoissonSpike``, the rates are supplied on every ``update`` call.
    Each element fires independently with probability ``freqs * dt``, where
    ``dt`` comes from ``environ.get_dt()``.

    Args:
        in_size: Size. The neuron group geometry.
        spk_type: DTypeLike. The dtype of the emitted spike array. Default ``bool``.
        name: str, optional. The name of the dynamic system.
    """

    def __init__(
        self,
        in_size: Size,
        spk_type: DTypeLike = bool,
        name: Optional[str] = None,
    ):
        super().__init__(in_size=in_size, name=name)
        self.spk_type = spk_type

    def update(self, freqs: ArrayLike):
        import numpy as np  # static shape arithmetic only

        # Draw one uniform number per output element. ``freqs`` may carry
        # extra (e.g. batch) dimensions; drawing only ``varshape`` numbers
        # would reuse the same noise for every batch element. For ``freqs``
        # broadcastable to ``varshape`` this is exactly ``varshape``.
        shape = np.broadcast_shapes(tuple(self.varshape), np.shape(freqs))
        # ``<`` gives probability exactly ``freqs * dt`` for draws in [0, 1),
        # so a zero rate never fires.
        spikes = random.rand(*shape) < (freqs * environ.get_dt())
        spikes = u.math.asarray(spikes, dtype=self.spk_type)
        return spikes
|
@@ -13,13 +13,13 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
from __future__ import annotations
|
16
17
|
|
17
|
-
from .
|
18
|
-
from .csr import __all__ as __all_csr
|
19
|
-
from .fixed_probability import *
|
20
|
-
from .fixed_probability import __all__ as __all_fixed_probability
|
21
|
-
from .linear import *
|
22
|
-
from .linear import __all__ as __all_linear
|
18
|
+
from brainstate.nn._dynamics._dynamics_base import Projection
|
23
19
|
|
24
|
-
__all__ =
|
25
|
-
|
20
|
+
__all__ = []


class ExponentialSynapse(Projection):
    """Placeholder projection; it currently adds nothing on top of ``Projection``."""
|