brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_synapse.py
DELETED
@@ -1,505 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
|
19
|
-
from typing import Optional, Callable
|
20
|
-
|
21
|
-
import brainunit as u
|
22
|
-
|
23
|
-
from brainstate import init, environ
|
24
|
-
from brainstate._state import ShortTermState, HiddenState
|
25
|
-
from brainstate.mixin import AlignPost
|
26
|
-
from brainstate.typing import ArrayLike, Size
|
27
|
-
from ._dynamics import Dynamics
|
28
|
-
from ._exp_euler import exp_euler_step
|
29
|
-
|
30
|
-
__all__ = [
    'Synapse',
    'Expon',
    'DualExpon',
    'Alpha',
    'AMPA',
    'GABAa',
]
|
33
|
-
|
34
|
-
|
35
|
-
class Synapse(Dynamics):
    """
    Common base class for synaptic dynamics models.

    ``Synapse`` adds no state or behaviour of its own on top of
    :class:`Dynamics`. It exists so that every synapse model defined alongside
    it shares one recognisable parent type.

    Concrete synapses create their state variables in ``init_state``, restore
    them in ``reset_state`` and advance them by one time step in ``update``.

    See Also
    --------
    Expon : First-order exponential decay synapse.
    DualExpon : Synapse with separate rise and decay time constants.
    Alpha : Alpha-function synapse.
    AMPA : Kinetic AMPA receptor model.
    GABAa : Kinetic GABAa receptor model.
    STP : Synapse with short-term plasticity.
    STD : Synapse with short-term depression.
    """
    __module__ = 'brainstate.nn'
|
59
|
-
|
60
|
-
|
61
|
-
class Expon(Synapse, AlignPost):
    r"""
    Exponential decay synapse model.

    A first-order synapse whose conductance ``g`` relaxes with time constant
    ``tau``. The right-hand side integrated each step is
    ``sum_current_inputs(-g) / tau``, i.e. ``-g`` plus any registered current
    inputs, all divided by ``tau``:

    $$
    \frac{dg}{dt} = \frac{-g + I_{\text{current}}}{\tau}
    $$

    After integration, the registered delta inputs and the optional ``x``
    argument of :meth:`update` are added to ``g`` as instantaneous jumps.

    Parameters
    ----------
    in_size : Size
        Size of the input.
    name : str, optional
        Name of the synapse instance.
    tau : ArrayLike, default=8.0*u.ms
        Decay time constant.
    g_initializer : ArrayLike or Callable, default=init.ZeroInit(unit=u.mS)
        Initial value or initializer for the synaptic conductance.

    Attributes
    ----------
    g : HiddenState
        Synaptic conductance state variable.
    tau : ArrayLike
        Decay time constant, as returned by ``init.param(tau, self.varshape)``.

    Notes
    -----
    The implementation uses an exponential Euler integration step
    (``exp_euler_step``). The output of :meth:`update` is the updated
    conductance ``g``.

    This class inherits from :py:class:`AlignPost`, so it can be used in
    projection patterns where synaptic variables are aligned with
    post-synaptic neurons.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        tau: ArrayLike = 8.0 * u.ms,
        g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau = init.param(tau, self.varshape)
        self.g_initializer = g_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the conductance state ``g`` from ``g_initializer``."""
        self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Re-initialise ``g`` in place from ``g_initializer``."""
        self.g.value = init.param(self.g_initializer, self.varshape, batch_size)

    def update(self, x=None):
        """
        Advance the conductance by one time step.

        Parameters
        ----------
        x : ArrayLike, optional
            Extra jump added to ``g`` after integration and delta inputs.

        Returns
        -------
        ArrayLike
            The updated conductance ``g``.
        """
        # Integrate dg/dt = (-g + current inputs) / tau over one step.
        g = exp_euler_step(lambda g: self.sum_current_inputs(-g) / self.tau, self.g.value)
        # Delta inputs act as instantaneous jumps after integration.
        self.g.value = self.sum_delta_inputs(g)
        if x is not None:
            self.g.value += x
        return self.g.value
|
127
|
-
|
128
|
-
|
129
|
-
class DualExpon(Synapse, AlignPost):
    r"""
    Dual exponential synapse model.

    The synapse keeps two exponentially decaying components with separate
    rise and decay time constants. Its output is their scaled difference,
    which gives a finite rise time followed by a slower decay:

    $$
    \frac{dg_{rise}}{dt} = -\frac{g_{rise}}{\tau_{rise}}, \qquad
    \frac{dg_{decay}}{dt} = -\frac{g_{decay}}{\tau_{decay}}, \qquad
    g = a \, (g_{decay} - g_{rise})
    $$

    Delta inputs and the optional ``x`` passed to :meth:`update` are added as
    the same instantaneous jump to both components.

    The output scale factor is

    $$
    a = \frac{\tau_{decay} - \tau_{rise}}{\tau_{rise} \, \tau_{decay}} \, A
    $$

    When ``A`` is not given, it defaults to

    $$
    A = \frac{\tau_{decay}}{\tau_{decay} - \tau_{rise}}
        \left(\frac{\tau_{rise}}{\tau_{decay}}\right)^{\tau_{rise} / (\tau_{rise} - \tau_{decay})}
    $$

    This is the reciprocal of the continuous-time peak of
    $g_{decay} - g_{rise}$ after a unit jump in both components.

    Parameters
    ----------
    in_size : Size
        Size of the input.
    name : str, optional
        Name of the synapse instance.
    tau_decay : ArrayLike, default=10.0*u.ms
        Decay time constant.
    tau_rise : ArrayLike, default=1.0*u.ms
        Rise time constant. Must differ from ``tau_decay``; both formulas
        above divide by ``tau_decay - tau_rise``.
    A : ArrayLike, optional
        Amplitude factor. If None, the peak-normalising value above is used.
    g_initializer : ArrayLike or Callable, default=init.ZeroInit(unit=u.mS)
        Initial value or initializer for both conductance components.

    Attributes
    ----------
    g_rise : HiddenState
        Rise component of the synaptic conductance.
    g_decay : HiddenState
        Decay component of the synaptic conductance.
    tau_rise : ArrayLike
        Rise time constant.
    tau_decay : ArrayLike
        Decay time constant.
    a : ArrayLike
        Output scale factor computed from ``tau_rise``, ``tau_decay`` and ``A``.

    Notes
    -----
    The implementation uses an exponential Euler integration step.

    This class inherits from :py:class:`AlignPost`, so it can be used in
    projection patterns where synaptic variables are aligned with
    post-synaptic neurons.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        tau_decay: ArrayLike = 10.0 * u.ms,
        tau_rise: ArrayLike = 1.0 * u.ms,
        A: Optional[ArrayLike] = None,
        g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau_decay = init.param(tau_decay, self.varshape)
        self.tau_rise = init.param(tau_rise, self.varshape)
        A = self._format_dual_exp_A(A)
        self.a = (self.tau_decay - self.tau_rise) / self.tau_rise / self.tau_decay * A
        self.g_initializer = g_initializer

    def _format_dual_exp_A(self, A):
        """Return ``A`` as a parameter, or the peak-normalising default when None."""
        A = init.param(A, sizes=self.varshape, allow_none=True)
        if A is None:
            A = (
                self.tau_decay / (self.tau_decay - self.tau_rise) *
                u.math.float_power(self.tau_rise / self.tau_decay,
                                   self.tau_rise / (self.tau_rise - self.tau_decay))
            )
        return A

    def init_state(self, batch_size: int = None, **kwargs):
        """Create both conductance components from ``g_initializer``."""
        self.g_rise = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
        self.g_decay = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Re-initialise both conductance components from ``g_initializer``."""
        self.g_rise.value = init.param(self.g_initializer, self.varshape, batch_size)
        self.g_decay.value = init.param(self.g_initializer, self.varshape, batch_size)

    def update(self, x=None):
        """
        Advance both components by one time step and return the output.

        Parameters
        ----------
        x : ArrayLike, optional
            Extra jump added to both components after integration.

        Returns
        -------
        ArrayLike
            ``a * (g_decay - g_rise)`` after the update.
        """
        g_rise = exp_euler_step(lambda h: -h / self.tau_rise, self.g_rise.value)
        g_decay = exp_euler_step(lambda g: -g / self.tau_decay, self.g_decay.value)
        # Collect the delta inputs exactly once and apply the same jump to both
        # components. Reading them twice is not guaranteed to give the same
        # result: callable inputs would be evaluated twice, and inputs consumed
        # by the first read would reach only ``g_rise``.
        delta = self.sum_delta_inputs(0. * g_rise)
        self.g_rise.value = g_rise + delta
        self.g_decay.value = g_decay + delta
        if x is not None:
            self.g_rise.value += x
            self.g_decay.value += x
        return self.a * (self.g_decay.value - self.g_rise.value)
|
233
|
-
|
234
|
-
|
235
|
-
class Alpha(Synapse):
    r"""
    Alpha synapse model.

    Implements the alpha-function synapse with a pair of first-order equations
    that share one time constant ``tau``:

    $$
    \frac{dh}{dt} = -\frac{h}{\tau}, \qquad
    \frac{dg}{dt} = -\frac{g}{\tau} + \frac{h}{\tau}
    $$

    Delta inputs and the optional ``x`` passed to :meth:`update` are added to
    the auxiliary variable ``h``; ``g`` is the output. A unit jump in ``h`` at
    ``t = 0`` (with ``g = 0``) yields

    $$
    g(t) = \frac{t}{\tau} e^{-t/\tau}, \quad t \ge 0,
    $$

    which rises and then falls, peaking at $t = \tau$ with value $e^{-1}$.

    Parameters
    ----------
    in_size : Size
        Size of the input.
    name : str, optional
        Name of the synapse instance.
    tau : ArrayLike, default=8.0*u.ms
        Time constant of the alpha function.
    g_initializer : ArrayLike or Callable, default=init.ZeroInit(unit=u.mS)
        Initial value or initializer for both ``g`` and ``h``.

    Attributes
    ----------
    g : HiddenState
        Synaptic conductance state variable (the output).
    h : HiddenState
        Auxiliary state variable that receives the inputs.
    tau : ArrayLike
        Time constant of the alpha function.

    Notes
    -----
    The implementation uses an exponential Euler integration step. Within
    one call to :meth:`update`, ``g`` is advanced using ``h`` from before this
    step's inputs, so a new input first affects ``g`` on the following step.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        tau: ArrayLike = 8.0 * u.ms,
        g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.tau = init.param(tau, self.varshape)
        self.g_initializer = g_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``g`` and ``h`` from ``g_initializer``."""
        self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
        self.h = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Re-initialise ``g`` and ``h`` from ``g_initializer``."""
        self.g.value = init.param(self.g_initializer, self.varshape, batch_size)
        self.h.value = init.param(self.g_initializer, self.varshape, batch_size)

    def update(self, x=None):
        """
        Advance the synapse by one time step.

        Parameters
        ----------
        x : ArrayLike, optional
            Extra jump added to ``h`` after integration.

        Returns
        -------
        ArrayLike
            The updated conductance ``g``.
        """
        h = exp_euler_step(lambda h: -h / self.tau, self.h.value)
        # g is driven by the pre-update h (before this step's inputs).
        self.g.value = exp_euler_step(lambda g, h: -g / self.tau + h / self.tau, self.g.value, self.h.value)
        self.h.value = self.sum_delta_inputs(h)
        if x is not None:
            self.h.value += x
        return self.g.value
|
305
|
-
|
306
|
-
|
307
|
-
class AMPA(Synapse):
    r"""AMPA receptor synapse model.

    This class implements a kinetic model of AMPA (α-amino-3-hydroxy-5-methyl-4-isoxazolepropionic acid)
    receptor-mediated synaptic transmission. AMPA receptors are ionotropic glutamate receptors that mediate
    fast excitatory synaptic transmission in the central nervous system.

    The model uses a Markov process approach to describe the state transitions of AMPA receptors
    between closed and open states, governed by neurotransmitter binding:

    $$
    \frac{dg}{dt} = \alpha [T] (1-g) - \beta g
    $$

    $$
    I_{syn} = g_{max} \cdot g \cdot (V - E)
    $$

    where:
    - $g$ represents the fraction of receptors in the open state
    - $\alpha$ is the binding rate constant [ms^-1 mM^-1]
    - $\beta$ is the unbinding rate constant [ms^-1]
    - $[T]$ is the neurotransmitter concentration [mM]
    - $I_{syn}$ is the resulting synaptic current
    - $g_{max}$ is the maximum conductance
    - $V$ is the membrane potential
    - $E$ is the reversal potential

    The neurotransmitter concentration $[T]$ follows a square pulse of amplitude T and
    duration T_dur after each presynaptic spike. This class only integrates $g$;
    the current $I_{syn}$ is computed elsewhere.

    Parameters
    ----------
    in_size : Size
        Size of the input.
    name : str, optional
        Name of the synapse instance.
    alpha : ArrayLike, default=0.98/(u.ms*u.mM)
        Binding rate constant [ms^-1 mM^-1].
    beta : ArrayLike, default=0.18/u.ms
        Unbinding rate constant [ms^-1].
    T : ArrayLike, default=0.5*u.mM
        Peak neurotransmitter concentration when released [mM].
    T_dur : ArrayLike, default=0.5*u.ms
        Duration of neurotransmitter presence in the synaptic cleft [ms].
    g_initializer : ArrayLike or Callable, default=init.ZeroInit(unit=u.mS)
        Initial value or initializer for the synaptic conductance.

    Attributes
    ----------
    g : HiddenState
        Open-receptor variable. It carries the unit of its initializer (mS by
        default) and saturates at ``1`` in that unit, because the update uses
        ``1 * u.get_unit(g)`` as the "fully open" level.
    spike_arrival_time : ShortTermState
        Time of the most recent presynaptic spike (initialised to -1e7 ms).
    alpha, beta, T : ArrayLike
        Kinetic parameters as given above.
    T_duration : ArrayLike
        Transmitter pulse duration (the ``T_dur`` argument).

    Notes
    -----
    - The model captures the fast-rising and relatively fast-decaying excitatory currents
      characteristic of AMPA receptor-mediated transmission.
    - The time course of the synaptic conductance is determined by both the binding and
      unbinding rate constants and the duration of transmitter presence.
    - This implementation uses an exponential Euler integration method.
    - :meth:`update` reads the current time from ``environ.get('t')``, so ``t``
      must be set in the environment context.

    References
    ----------
    .. [1] Destexhe, A., Mainen, Z. F., & Sejnowski, T. J. (1994). Synthesis of models for
           excitable membranes, synaptic transmission and neuromodulation using a common
           kinetic formalism. Journal of computational neuroscience, 1(3), 195-230.
    .. [2] Vijayan, S., & Kopell, N. J. (2012). Thalamic model of awake alpha oscillations
           and implications for stimulus processing. Proceedings of the National Academy
           of Sciences, 109(45), 18553-18558.
    """
    # Match the public module path used by the other synapse classes.
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        alpha: ArrayLike = 0.98 / (u.ms * u.mM),
        beta: ArrayLike = 0.18 / u.ms,
        T: ArrayLike = 0.5 * u.mM,
        T_dur: ArrayLike = 0.5 * u.ms,
        g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
    ):
        super().__init__(name=name, in_size=in_size)

        # parameters
        self.alpha = init.param(alpha, self.varshape)
        self.beta = init.param(beta, self.varshape)
        self.T = init.param(T, self.varshape)
        self.T_duration = init.param(T_dur, self.varshape)
        self.g_initializer = g_initializer

    def _initial_spike_arrival_time(self, batch_size):
        # A far-past arrival time (-1e7 ms) so that no transmitter pulse is
        # active before the first presynaptic spike.
        return init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_size)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``g`` and ``spike_arrival_time``."""
        self.g = HiddenState(init.param(self.g_initializer, self.varshape, batch_size))
        self.spike_arrival_time = ShortTermState(self._initial_spike_arrival_time(batch_size))

    def reset_state(self, batch_or_mode=None, **kwargs):
        """Re-initialise ``g`` and ``spike_arrival_time`` in place."""
        self.g.value = init.param(self.g_initializer, self.varshape, batch_or_mode)
        self.spike_arrival_time.value = self._initial_spike_arrival_time(batch_or_mode)

    def update(self, pre_spike):
        """
        Advance the receptor kinetics by one time step.

        Parameters
        ----------
        pre_spike : ArrayLike
            Presynaptic spike indicator; where truthy, the spike arrival time
            is set to the current time.

        Returns
        -------
        ArrayLike
            The updated open-receptor variable ``g``.
        """
        t = environ.get('t')
        self.spike_arrival_time.value = u.math.where(pre_spike, t, self.spike_arrival_time.value)
        # Square transmitter pulse: T while within T_duration of the last spike, 0 otherwise.
        TT = ((t - self.spike_arrival_time.value) < self.T_duration) * self.T
        # "1" is expressed in g's unit so the (1 - g) term is dimensionally consistent.
        dg = lambda g: self.alpha * TT * (1 * u.get_unit(g) - g) - self.beta * g
        self.g.value = exp_euler_step(dg, self.g.value)
        return self.g.value
|
414
|
-
|
415
|
-
|
416
|
-
class GABAa(AMPA):
    r"""GABAa receptor synapse model.

    This class implements a kinetic model of GABAa (gamma-aminobutyric acid type A)
    receptor-mediated synaptic transmission. GABAa receptors are ionotropic chloride channels
    that mediate fast inhibitory synaptic transmission in the central nervous system.

    The model uses the same Markov process approach as the AMPA model but with different
    kinetic parameters appropriate for GABAa receptors:

    $$
    \frac{dg}{dt} = \alpha [T] (1-g) - \beta g
    $$

    $$
    I_{syn} = - g_{max} \cdot g \cdot (V - E)
    $$

    where:
    - $g$ represents the fraction of receptors in the open state
    - $\alpha$ is the binding rate constant [ms^-1 mM^-1], typically slower than AMPA
    - $\beta$ is the unbinding rate constant [ms^-1]
    - $[T]$ is the neurotransmitter (GABA) concentration [mM]
    - $I_{syn}$ is the resulting synaptic current (note the negative sign indicating inhibition)
    - $g_{max}$ is the maximum conductance
    - $V$ is the membrane potential
    - $E$ is the reversal potential (typically around -80 mV for chloride)

    The neurotransmitter concentration $[T]$ follows a square pulse of amplitude T and
    duration T_dur after each presynaptic spike. The dynamics are inherited unchanged
    from :class:`AMPA`; only the default parameters differ.

    Parameters
    ----------
    in_size : Size
        Size of the input.
    name : str, optional
        Name of the synapse instance.
    alpha : ArrayLike, default=0.53/(u.ms*u.mM)
        Binding rate constant [ms^-1 mM^-1]. Smaller than the AMPA default.
    beta : ArrayLike, default=0.18/u.ms
        Unbinding rate constant [ms^-1].
    T : ArrayLike, default=1.0*u.mM
        Peak neurotransmitter concentration when released [mM]. Higher than the AMPA default.
    T_dur : ArrayLike, default=1.0*u.ms
        Duration of neurotransmitter presence in the synaptic cleft [ms]. Longer than the AMPA default.
    g_initializer : ArrayLike or Callable, default=init.ZeroInit(unit=u.mS)
        Initial value or initializer for the synaptic conductance.

    Attributes
    ----------
    Inherits all attributes from the AMPA class.

    Notes
    -----
    - GABAa receptors typically produce slower-rising and longer-lasting currents compared to AMPA receptors.
    - The inhibitory nature of GABAa receptors is reflected in the convention of using a negative sign in the
      synaptic current equation.
    - The reversal potential for GABAa receptors is typically around -80 mV (due to chloride), making them
      inhibitory for neurons with resting potentials more positive than this value.
    - This model does not include desensitization, which can be significant for prolonged GABA exposure.

    References
    ----------
    .. [1] Destexhe, A., Mainen, Z. F., & Sejnowski, T. J. (1994). Synthesis of models for
           excitable membranes, synaptic transmission and neuromodulation using a common
           kinetic formalism. Journal of computational neuroscience, 1(3), 195-230.
    .. [2] Destexhe, A., & Paré, D. (1999). Impact of network activity on the integrative
           properties of neocortical pyramidal neurons in vivo. Journal of neurophysiology,
           81(4), 1531-1547.
    """
    # Match the public module path used by the other synapse classes.
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        name: Optional[str] = None,
        alpha: ArrayLike = 0.53 / (u.ms * u.mM),
        beta: ArrayLike = 0.18 / u.ms,
        T: ArrayLike = 1.0 * u.mM,
        T_dur: ArrayLike = 1.0 * u.ms,
        g_initializer: ArrayLike | Callable = init.ZeroInit(unit=u.mS),
    ):
        super().__init__(
            in_size=in_size,
            name=name,
            alpha=alpha,
            beta=beta,
            T=T,
            T_dur=T_dur,
            g_initializer=g_initializer,
        )
|
brainstate/nn/_synapse_test.py
DELETED
@@ -1,131 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
import unittest
|
18
|
-
|
19
|
-
import brainunit as u
|
20
|
-
import jax.numpy as jnp
|
21
|
-
import pytest
|
22
|
-
|
23
|
-
import brainstate
|
24
|
-
from brainstate.nn import Expon, STP, STD
|
25
|
-
|
26
|
-
|
27
|
-
class TestSynapse(unittest.TestCase):
    """Tests for the synapse models exported by ``brainstate.nn``."""

    def setUp(self):
        self.in_size = 10
        self.batch_size = 5
        self.time_steps = 100

    def generate_input(self):
        """Random inputs of shape (time_steps, batch_size, in_size), in mS."""
        return brainstate.random.randn(self.time_steps, self.batch_size, self.in_size) * u.mS

    def test_expon_synapse(self):
        tau = 20.0 * u.ms
        synapse = Expon(self.in_size, tau=tau)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau, tau)

        # Test forward pass (init_state creates the states in place; it returns nothing)
        synapse.init_state(self.batch_size)
        call = brainstate.compile.jit(synapse)
        with brainstate.environ.context(dt=0.1 * u.ms):
            for t in range(self.time_steps):
                out = call(inputs[t])
                self.assertEqual(out.shape, (self.batch_size, self.in_size))

            # Test exponential decay, under the same dt context as the forward pass
            constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
            out1 = call(constant_input)
            out2 = call(constant_input)
            self.assertTrue(jnp.all(out2 > out1))  # Output should increase with constant input

    @pytest.mark.skip(reason="Not implemented yet")
    def test_stp_synapse(self):
        tau_d = 200.0 * u.ms
        tau_f = 20.0 * u.ms
        U = 0.2
        synapse = STP(self.in_size, tau_d=tau_d, tau_f=tau_f, U=U)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau_d, tau_d)
        self.assertEqual(synapse.tau_f, tau_f)
        self.assertEqual(synapse.U, U)

        # Test forward pass
        synapse.init_state(self.batch_size)
        call = brainstate.compile.jit(synapse)
        for t in range(self.time_steps):
            out = call(inputs[t])
            self.assertEqual(out.shape, (self.batch_size, self.in_size))

        # Test short-term plasticity
        constant_input = jnp.ones((self.batch_size, self.in_size)) * u.mS
        out1 = call(constant_input)
        out2 = call(constant_input)
        self.assertTrue(jnp.any(out2 != out1))  # Output should change due to STP

    @pytest.mark.skip(reason="Not implemented yet")
    def test_std_synapse(self):
        tau = 200.0
        U = 0.2
        synapse = STD(self.in_size, tau=tau, U=U)
        inputs = self.generate_input()

        # Test initialization
        self.assertEqual(synapse.in_size, (self.in_size,))
        self.assertEqual(synapse.out_size, (self.in_size,))
        self.assertEqual(synapse.tau, tau)
        self.assertEqual(synapse.U, U)

        # Test forward pass
        synapse.init_state(self.batch_size)
        for t in range(self.time_steps):
            out = synapse(inputs[t])
            self.assertEqual(out.shape, (self.batch_size, self.in_size))

        # Test short-term depression
        constant_input = jnp.ones((self.batch_size, self.in_size))
        out1 = synapse(constant_input)
        out2 = synapse(constant_input)
        self.assertTrue(jnp.all(out2 < out1))  # Output should decrease due to STD

    def test_keep_size(self):
        in_size = (2, 3)
        for SynapseClass in [Expon, ]:
            synapse = SynapseClass(in_size)
            self.assertEqual(synapse.in_size, in_size)
            self.assertEqual(synapse.out_size, in_size)

            inputs = brainstate.random.randn(self.time_steps, self.batch_size, *in_size) * u.mS
            synapse.init_state(self.batch_size)
            call = brainstate.compile.jit(synapse)
            with brainstate.environ.context(dt=0.1 * u.ms):
                for t in range(self.time_steps):
                    out = call(inputs[t])
                    self.assertEqual(out.shape, (self.batch_size, *in_size))
|
127
|
-
|
128
|
-
|
129
|
-
def _main():
    """Run this module's tests under a default (unitless) ``dt`` context.

    Tests that need a unit-aware time step open their own context.
    """
    with brainstate.environ.context(dt=0.1):
        unittest.main()


if __name__ == '__main__':
    _main()
|