brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_neuron.py
DELETED
@@ -1,705 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from typing import Callable, Optional
|
19
|
-
|
20
|
-
import brainunit as u
|
21
|
-
import jax
|
22
|
-
|
23
|
-
from brainstate import init, surrogate, environ
|
24
|
-
from brainstate._state import HiddenState, ShortTermState
|
25
|
-
from brainstate.typing import ArrayLike, Size
|
26
|
-
from ._dynamics import Dynamics
|
27
|
-
from ._exp_euler import exp_euler_step
|
28
|
-
|
29
|
-
__all__ = [
    'Neuron',
    'IF',
    'LIF',
    'LIFRef',
    'ALIF',
]
|
32
|
-
|
33
|
-
|
34
|
-
class Neuron(Dynamics):
    """
    Base class for all spiking neuron models.

    This abstract class serves as the foundation for implementing various spiking neuron
    models. It extends the Dynamics class and stores the spike-generation settings
    (surrogate spike function and reset mode) that subclasses use in their updates.

    All neuron models should inherit from this class and implement the required methods,
    particularly the `get_spike()` method which defines the spike generation mechanism.

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    spk_fun : Callable, default=surrogate.InvSquareGrad()
        Surrogate gradient function for the non-differentiable spike generation.
        The default instance is created once, when the class is defined, and is
        shared by every neuron that does not pass its own ``spk_fun``.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:
        - 'soft': subtract threshold from membrane potential
        - 'hard': use stop_gradient for reset
    name : str, optional
        Name of the neuron layer.

    Raises
    ------
    ValueError
        If ``spk_reset`` is neither ``'soft'`` nor ``'hard'``.

    Methods
    -------
    get_spike(*args, **kwargs)
        Abstract method that generates spikes based on neuron state variables.
        Must be implemented by subclasses.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        spk_fun: Callable = surrogate.InvSquareGrad(),
        spk_reset: str = 'soft',
        name: Optional[str] = None,
    ):
        # Subclasses branch on ``spk_reset == 'soft'`` and treat every other value
        # as a hard reset, so an unrecognized value (e.g. a typo) would silently
        # change the reset mode. Reject it up front instead.
        if spk_reset not in ('soft', 'hard'):
            raise ValueError(
                f"spk_reset must be 'soft' or 'hard', but got {spk_reset!r}."
            )
        super().__init__(in_size, name=name)
        self.spk_reset = spk_reset
        self.spk_fun = spk_fun

    def get_spike(self, *args, **kwargs):
        """Generate spikes from the neuron state. Must be overridden by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
|
79
|
-
|
80
|
-
|
81
|
-
class IF(Neuron):
    r"""Integrate-and-Fire (IF) neuron model.

    This class implements an Integrate-and-Fire style spiking neuron. It integrates the
    input current into the membrane potential; once the potential reaches the threshold
    a spike is emitted, and the potential is reset at the start of the next update.

    The membrane potential follows the differential equation implemented in ``update``:

    $$
    \tau \frac{dV}{dt} = -V + R \cdot I(t)
    $$

    Spike condition:
    If $V \geq V_{th}$: emit spike and reset $V = V - V_{th}$ (soft reset) or $V = 0$ (hard reset)

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    V_th : ArrayLike, default=1. * u.mV
        Firing threshold voltage. Must be positive and non-zero, because
        ``get_spike`` divides by it.
    V_initializer : Callable, default=init.Constant(0. * u.mV)
        Initializer for the membrane potential state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:
        - 'soft': subtract threshold V = V - V_th
        - 'hard': V = V - stop_gradient(V) for spiking neurons, i.e. the value becomes
          0 while the reset itself carries no gradient
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential (created by ``init_state``).

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the neuron state variables.
    reset_state(batch_size=None, **kwargs)
        Reset the neuron state variables.
    get_spike(V=None)
        Generate spikes based on the membrane potential.
    update(x=0. * u.mA)
        Update the neuron state for one time step and return spikes.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>>
    >>> # Create an IF neuron layer with 10 neurons
    >>> if_neuron = bs.nn.IF(10, tau=8*u.ms, V_th=1.2*u.mV)
    >>>
    >>> # Initialize the state
    >>> if_neuron.init_state(batch_size=1)
    >>>
    >>> # Apply an input current and update the neuron state. The exponential
    >>> # Euler integrator needs the time step ``dt`` from the environment.
    >>> with bs.environ.context(dt=0.1*u.ms):
    ...     spikes = if_neuron.update(x=2.0*u.mA)
    >>>
    >>> # Create a network with IF neurons (``Sequential`` takes modules as
    >>> # separate positional arguments)
    >>> network = bs.nn.Sequential(
    ...     bs.nn.IF(100, tau=5.0*u.ms),
    ...     bs.nn.Linear(100, 10),
    ... )

    Notes
    -----
    - Despite the name, the implemented dynamics contain the ``-V`` decay term: without
      input the membrane potential relaxes exponentially towards 0 with time constant
      ``tau``. The model is therefore numerically equivalent to ``LIF`` with
      ``V_rest = V_reset = 0``; it is not a perfect (non-leaky) integrator.
    - The time-dependent dynamics are integrated using an exponential Euler method.
    - The reset is driven by the spike computed from the potential stored at the end
      of the previous step, so a threshold crossing is reported in the step where it
      happens and the reset is applied at the beginning of the following step.

    References
    ----------
    .. [1] Lapicque, L. (1907). Recherches quantitatives sur l'excitation électrique
           des nerfs traitée comme une polarisation. Journal de Physiologie et de
           Pathologie Générale, 9, 620-635.
    .. [2] Abbott, L. F. (1999). Lapicque's introduction of the integrate-and-fire
           model neuron (1907). Brain Research Bulletin, 50(5-6), 303-304.
    .. [3] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model:
           I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19.
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,  # should be positive
        V_initializer: Callable = init.Constant(0. * u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: Optional[str] = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to the layer's variable shape
        self.R = init.param(R, self.varshape)
        self.tau = init.param(tau, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_initializer = V_initializer

    def init_state(self, batch_size: Optional[int] = None, **kwargs):
        """Create the membrane-potential state ``V`` from ``V_initializer``."""
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: Optional[int] = None, **kwargs):
        """Overwrite ``V`` with a freshly initialized value."""
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)

    def get_spike(self, V=None):
        """Return ``spk_fun((V - V_th) / V_th)``; ``V`` defaults to the stored potential."""
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / self.V_th
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        """Advance the neuron by one step with input current ``x`` and return spikes."""
        # reset, driven by the spike of the previously stored potential
        last_V = self.V.value
        last_spike = self.get_spike(last_V)
        # 'soft' subtracts V_th; 'hard' subtracts a gradient-stopped copy of the
        # potential, so a neuron whose spike value is 1 lands at 0.
        reset_ref = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V)
        V = last_V - reset_ref * last_spike
        # membrane potential
        dv = lambda v: (-v + self.R * self.sum_current_inputs(x, v)) / self.tau
        V = exp_euler_step(dv, V)
        V = self.sum_delta_inputs(V)
        self.V.value = V
        return self.get_spike(V)
|
217
|
-
|
218
|
-
|
219
|
-
class LIF(Neuron):
    r"""Leaky Integrate-and-Fire (LIF) neuron model.

    This class implements the Leaky Integrate-and-Fire neuron model, which extends the basic
    Integrate-and-Fire model by adding a leak term. The leak causes the membrane potential
    to decay towards a resting value in the absence of input, making the model more
    biologically plausible.

    The model is characterized by the following differential equation:

    $$
    \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t)
    $$

    Spike condition:
    If $V \geq V_{th}$: emit spike and reset $V = V - (V_{th} - V_{reset})$ (soft reset)
    or $V = V_{reset}$ (hard reset)

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    V_th : ArrayLike, default=1. * u.mV
        Firing threshold voltage. Must differ from ``V_reset``, because ``get_spike``
        divides by ``V_th - V_reset``.
    V_reset : ArrayLike, default=0. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=0. * u.mV
        Resting membrane potential.
    V_initializer : Callable, default=init.Constant(0. * u.mV)
        Initializer for the membrane potential state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:
        - 'soft': subtract the threshold-to-reset gap, V = V - (V_th - V_reset)
        - 'hard': V = V_reset, computed through stop_gradient so the reset itself
          carries no gradient
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential (created by ``init_state``).

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the neuron state variables.
    reset_state(batch_size=None, **kwargs)
        Reset the neuron state variables.
    get_spike(V=None)
        Generate spikes based on the membrane potential.
    update(x=0. * u.mA)
        Update the neuron state for one time step and return spikes.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>>
    >>> # Create a LIF neuron layer with 10 neurons
    >>> lif = brainstate.nn.LIF(10, tau=10*u.ms, V_th=0.8*u.mV)
    >>>
    >>> # Initialize the state
    >>> lif.init_state(batch_size=1)
    >>>
    >>> # Apply an input current and update the neuron state. The exponential
    >>> # Euler integrator needs the time step ``dt`` from the environment.
    >>> with brainstate.environ.context(dt=0.1*u.ms):
    ...     spikes = lif.update(x=1.5*u.mA)

    Notes
    -----
    - The leak term causes the membrane potential to decay exponentially towards V_rest
      with time constant tau when no input is present.
    - The time-dependent dynamics are integrated using an exponential Euler method.
    - Spike generation is non-differentiable, so surrogate gradients are used for
      backpropagation during training.
    - The reset is driven by the spike computed from the potential stored at the end
      of the previous step and is applied at the beginning of the following step.

    References
    ----------
    .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
           Neuronal dynamics: From single neurons to networks and models of cognition.
           Cambridge University Press.
    .. [2] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model:
           I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        V_reset: ArrayLike = 0. * u.mV,
        V_rest: ArrayLike = 0. * u.mV,
        V_initializer: Callable = init.Constant(0. * u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: Optional[str] = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to the layer's variable shape
        self.R = init.param(R, self.varshape)
        self.tau = init.param(tau, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_rest = init.param(V_rest, self.varshape)
        self.V_reset = init.param(V_reset, self.varshape)
        self.V_initializer = V_initializer

    def init_state(self, batch_size: Optional[int] = None, **kwargs):
        """Create the membrane-potential state ``V`` from ``V_initializer``."""
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: Optional[int] = None, **kwargs):
        """Overwrite ``V`` with a freshly initialized value."""
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)

    def get_spike(self, V: Optional[ArrayLike] = None):
        """Return ``spk_fun((V - V_th) / (V_th - V_reset))``; ``V`` defaults to the stored potential."""
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        """Advance the neuron by one step with input current ``x`` and return spikes."""
        last_v = self.V.value
        lst_spk = self.get_spike(last_v)
        # 'soft' subtracts (V_th - V_reset); 'hard' subtracts (stop_gradient(V) - V_reset),
        # which puts a neuron whose spike value is 1 exactly at V_reset.
        reset_ref = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v)
        V = last_v - (reset_ref - self.V_reset) * lst_spk
        # membrane potential
        dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau
        V = exp_euler_step(dv, V)
        V = self.sum_delta_inputs(V)
        self.V.value = V
        return self.get_spike(V)
|
354
|
-
|
355
|
-
|
356
|
-
class LIFRef(Neuron):
    r"""Leaky Integrate-and-Fire neuron model with refractory period.

    This class implements a Leaky Integrate-and-Fire neuron model that includes a
    refractory period after spiking, during which the neuron's membrane potential is
    frozen regardless of input. This better captures the behavior of biological neurons
    that exhibit a recovery period after action potential generation.

    The model is characterized by the following equations:

    When not in refractory period:
    $$
    \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t)
    $$

    During refractory period the potential is held at its post-reset value
    ($V_{reset}$ for the hard reset).

    Spike condition:
    If $V \geq V_{th}$: emit spike, reset $V$ (see ``spk_reset``), and enter refractory
    period for $\tau_{ref}$

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    tau_ref : ArrayLike, default=5. * u.ms
        Refractory period duration.
    V_th : ArrayLike, default=1. * u.mV
        Firing threshold voltage. Must differ from ``V_reset``, because ``get_spike``
        divides by ``V_th - V_reset``.
    V_reset : ArrayLike, default=0. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=0. * u.mV
        Resting membrane potential.
    V_initializer : Callable, default=init.Constant(0. * u.mV)
        Initializer for the membrane potential state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:
        - 'soft': subtract the threshold-to-reset gap, V = V - (V_th - V_reset)
        - 'hard': V = V_reset, computed through stop_gradient so the reset itself
          carries no gradient
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential (created by ``init_state``).
    last_spike_time : ShortTermState
        Time of the last spike, used to implement the refractory period. Initialized
        to ``-1e7 ms`` so that no neuron starts out refractory.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the neuron state variables.
    reset_state(batch_size=None, **kwargs)
        Reset the neuron state variables.
    get_spike(V=None)
        Generate spikes based on the membrane potential.
    update(x=0. * u.mA)
        Update the neuron state for one time step and return spikes.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>>
    >>> # Create a LIFRef neuron layer with 10 neurons
    >>> lifref = bs.nn.LIFRef(10,
    ...                       tau=10*u.ms,
    ...                       tau_ref=5*u.ms,
    ...                       V_th=0.8*u.mV)
    >>>
    >>> # Initialize the state
    >>> lifref.init_state(batch_size=1)
    >>>
    >>> # Apply an input current and update the neuron state. ``update`` reads the
    >>> # current time ``t`` from the environment, and the exponential Euler
    >>> # integrator needs the time step ``dt``.
    >>> with bs.environ.context(t=0.*u.ms, dt=0.1*u.ms):
    ...     spikes = lifref.update(x=1.5*u.mA)
    >>>
    >>> # Create a network with refractory neurons (``Sequential`` takes modules
    >>> # as separate positional arguments)
    >>> network = bs.nn.Sequential(
    ...     bs.nn.LIFRef(100, tau_ref=4*u.ms),
    ...     bs.nn.Linear(100, 10),
    ... )

    Notes
    -----
    - The refractory period is implemented by tracking the time of the last spike
      and discarding the integrated update while ``t - last_spike_time < tau_ref``.
    - During the refractory period the membrane potential stays at its post-reset value
      (exactly ``V_reset`` for the hard reset; ``V - (V_th - V_reset)`` for the soft
      reset), regardless of current or delta inputs.
    - Refractory periods prevent high-frequency repetitive firing and are critical
      for realistic neural dynamics.
    - The time-dependent dynamics are integrated using an exponential Euler method.
    - The simulation environment time variable 't' is used to track the refractory state,
      so it must be set (e.g. via ``brainstate.environ.context``) before calling ``update``.

    References
    ----------
    .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
           Neuronal dynamics: From single neurons to networks and models of cognition.
           Cambridge University Press.
    .. [2] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model:
           I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19.
    .. [3] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on
           neural networks, 14(6), 1569-1572.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        tau_ref: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        V_reset: ArrayLike = 0. * u.mV,
        V_rest: ArrayLike = 0. * u.mV,
        V_initializer: Callable = init.Constant(0. * u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: Optional[str] = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to the layer's variable shape
        self.R = init.param(R, self.varshape)
        self.tau = init.param(tau, self.varshape)
        self.tau_ref = init.param(tau_ref, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_rest = init.param(V_rest, self.varshape)
        self.V_reset = init.param(V_reset, self.varshape)
        self.V_initializer = V_initializer

    def init_state(self, batch_size: Optional[int] = None, **kwargs):
        """Create ``V`` and ``last_spike_time`` (far in the past, i.e. not refractory)."""
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))
        self.last_spike_time = ShortTermState(init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_size))

    def reset_state(self, batch_size: Optional[int] = None, **kwargs):
        """Re-initialize ``V`` and push ``last_spike_time`` back to ``-1e7 ms``."""
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)
        self.last_spike_time.value = init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_size)

    def get_spike(self, V: Optional[ArrayLike] = None):
        """Return ``spk_fun((V - V_th) / (V_th - V_reset))``; ``V`` defaults to the stored potential."""
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        """Advance the neuron by one step with input current ``x`` and return spikes.

        Reads the current time ``t`` from ``environ``.
        """
        t = environ.get('t')
        last_v = self.V.value
        lst_spk = self.get_spike(last_v)
        # 'soft' subtracts (V_th - V_reset); 'hard' subtracts (stop_gradient(V) - V_reset),
        # which puts a neuron whose spike value is 1 exactly at V_reset.
        reset_ref = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v)
        last_v = last_v - (reset_ref - self.V_reset) * lst_spk
        # membrane potential
        dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau
        V = exp_euler_step(dv, last_v)
        V = self.sum_delta_inputs(V)
        # Neurons still inside the refractory window keep their post-reset potential.
        self.V.value = u.math.where(t - self.last_spike_time.value < self.tau_ref, last_v, V)
        # spike time evaluation: stamp the current time on neurons at/above threshold
        lst_spk_time = u.math.where(self.V.value >= self.V_th, t, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(lst_spk_time)
        return self.get_spike()
|
523
|
-
|
524
|
-
|
525
|
-
class ALIF(Neuron):
    r"""Adaptive Leaky Integrate-and-Fire (ALIF) neuron model.

    ALIF extends the LIF neuron with a spike-triggered adaptation variable
    ``a``. On every step, ``a`` is incremented by the previous step's spike
    output and then decays exponentially with time constant ``tau_a``. The
    firing threshold is raised by ``beta * a``. The result is spike-frequency
    adaptation: a strong response at stimulus onset, followed by a reduced
    firing rate under sustained input.

    The continuous dynamics are:

    $$
    \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t)
    $$

    $$
    \tau_a \frac{da}{dt} = -a
    $$

    Spike condition: $V \geq V_{th} + \beta \cdot a$. After a spike, $a$ is
    incremented by the spike output (1 for a full spike), and $V$ is reset
    according to ``spk_reset``.

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    tau_a : ArrayLike, default=100. * u.ms
        Adaptation time constant (typically much longer than ``tau``).
    V_th : ArrayLike, default=1. * u.mV
        Base firing threshold voltage.
    V_reset : ArrayLike, default=0. * u.mV
        Reset voltage after a spike.
    V_rest : ArrayLike, default=0. * u.mV
        Resting membrane potential.
    beta : ArrayLike, default=0.1 * u.mV
        Scales how strongly the adaptation variable raises the threshold.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:

        - ``'soft'``: subtract ``V_th - V_reset`` from ``V``. The adaptive
          term ``beta * a`` is not subtracted.
        - any other value (e.g. ``'hard'``): reset ``V`` to ``V_reset``. The
          reset amount is computed from a stop-gradient copy of ``V``.
    V_initializer : Callable, default=init.Constant(0. * u.mV)
        Initializer for the membrane potential state.
    a_initializer : Callable, default=init.Constant(0.)
        Initializer for the adaptation variable.
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    a : HiddenState
        Adaptation variable. It is incremented by each spike and decays
        exponentially; it is dimensionless with the default ``a_initializer``.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Create the ``V`` and ``a`` states.
    reset_state(batch_size=None, **kwargs)
        Re-initialize the ``V`` and ``a`` states in place.
    get_spike(V=None, a=None)
        Compute the spike output from the membrane potential and adaptation.
    update(x=0. * u.mA)
        Advance the neuron one time step and return its spikes.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>>
    >>> # An ALIF layer with 10 neurons
    >>> alif = bs.nn.ALIF(10, tau=10 * u.ms, tau_a=200 * u.ms, beta=0.2 * u.mV)
    >>> alif.init_state(batch_size=1)
    >>>
    >>> # The exponential-Euler integration needs a time step; it is
    >>> # presumably read from the brainstate environment.
    >>> with bs.environ.context(dt=0.1 * u.ms):
    ...     spikes = alif.update(x=1.5 * u.mA)

    Notes
    -----
    - The effective threshold is ``V_th + beta * a``. Recent activity
      therefore makes further firing progressively harder.
    - Because ``tau_a`` is usually much larger than ``tau``, adaptation lasts
      well beyond a single membrane time constant.
    - Both ``V`` and ``a`` are integrated with the exponential Euler method.

    References
    ----------
    .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
           Neuronal dynamics: From single neurons to networks and models of cognition.
           Cambridge University Press.
    .. [2] Brette, R., & Gerstner, W. (2005). Adaptive exponential integrate-and-fire model
           as an effective description of neuronal activity. Journal of neurophysiology,
           94(5), 3637-3642.
    .. [3] Naud, R., Marcille, N., Clopath, C., & Gerstner, W. (2008). Firing patterns in
           the adaptive exponential integrate-and-fire model. Biological cybernetics,
           99(4), 335-347.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        tau_a: ArrayLike = 100. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        V_reset: ArrayLike = 0. * u.mV,
        V_rest: ArrayLike = 0. * u.mV,
        beta: ArrayLike = 0.1 * u.mV,
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        V_initializer: Callable = init.Constant(0. * u.mV),
        a_initializer: Callable = init.Constant(0.),
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # Per-neuron parameters, materialized to the layer's variable shape.
        shape = self.varshape
        self.R = init.param(R, shape)
        self.tau = init.param(tau, shape)
        self.tau_a = init.param(tau_a, shape)
        self.V_th = init.param(V_th, shape)
        self.V_reset = init.param(V_reset, shape)
        self.V_rest = init.param(V_rest, shape)
        self.beta = init.param(beta, shape)

        # State initializers, used by init_state / reset_state.
        self.V_initializer = V_initializer
        self.a_initializer = a_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the ``V`` and ``a`` hidden states (``batch_size`` is optional)."""
        shape = self.varshape
        self.V = HiddenState(init.param(self.V_initializer, shape, batch_size))
        self.a = HiddenState(init.param(self.a_initializer, shape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Refill the existing ``V`` and ``a`` states from their initializers."""
        shape = self.varshape
        self.V.value = init.param(self.V_initializer, shape, batch_size)
        self.a.value = init.param(self.a_initializer, shape, batch_size)

    def get_spike(self, V=None, a=None):
        """Return ``spk_fun`` applied to the normalized distance above the adaptive threshold.

        ``V`` and ``a`` default to the current state values when omitted.
        """
        if V is None:
            V = self.V.value
        if a is None:
            a = self.a.value
        # Keep the original subtraction order so floating-point results are unchanged.
        excess = V - self.V_th - self.beta * a
        return self.spk_fun(excess / (self.V_th - self.V_reset))

    def update(self, x=0. * u.mA):
        """Advance one step: apply the reset and adaptation jump, integrate, and return spikes.

        Parameters
        ----------
        x : ArrayLike, default=0. * u.mA
            External input current passed to ``sum_current_inputs``.
        """
        prev_v = self.V.value
        prev_a = self.a.value
        spiked = self.get_spike(prev_v, prev_a)

        # Soft reset subtracts (V_th - V_reset). Any other mode subtracts the
        # detached (V - V_reset), which lands on V_reset for a full spike.
        if self.spk_reset == 'soft':
            reset_ref = self.V_th
        else:
            reset_ref = jax.lax.stop_gradient(prev_v)
        v_next = prev_v - (reset_ref - self.V_reset) * spiked
        a_next = prev_a + spiked

        def membrane_rhs(v):
            return (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau

        def adaptation_rhs(adapt):
            return -adapt / self.tau_a

        v_next = exp_euler_step(membrane_rhs, v_next)
        a_next = exp_euler_step(adaptation_rhs, a_next)
        self.V.value = self.sum_delta_inputs(v_next)
        self.a.value = a_next
        return self.get_spike(self.V.value, self.a.value)
|