brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
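The largest single entry is a full rewrite of brainstate/surrogate.py, whose previous contents are reproduced below. The file builds spike functions whose forward pass is a hard Heaviside step and whose backward pass uses a smooth surrogate derivative, wired up through a custom JAX primitive (heaviside_p). For orientation, the same straight-through pattern can be sketched with plain jax.custom_jvp; this is an illustrative analogy only, not the package's implementation, which registers a raw Primitive with its own batching, JVP, and lowering rules:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def heaviside_st(x, alpha):
    # Forward pass: a hard 0/1 step, as in _heaviside_imp below.
    return jnp.where(x >= 0, 1.0, 0.0)

@heaviside_st.defjvp
def _heaviside_st_jvp(primals, tangents):
    x, alpha = primals
    tx, _ = tangents
    # Backward pass: a sigmoid-shaped surrogate derivative
    # (the same form as Sigmoid.surrogate_grad in the file below).
    s = jax.nn.sigmoid(alpha * x)
    return heaviside_st(x, alpha), alpha * s * (1. - s) * tx

print(jax.grad(lambda v: heaviside_st(v, 4.0))(0.1))  # nonzero, despite the hard step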
brainstate/surrogate.py
CHANGED
@@ -1,1957 +1,1957 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
import jax
|
19
|
-
import jax.numpy as jnp
|
20
|
-
import jax.scipy as sci
|
21
|
-
from jax.interpreters import batching, ad, mlir
|
22
|
-
|
23
|
-
from brainstate._compatible_import import Primitive
|
24
|
-
from brainstate.util import PrettyObject
|
25
|
-
|
26
|
-
__all__ = [
|
27
|
-
'Surrogate',
|
28
|
-
'Sigmoid',
|
29
|
-
'sigmoid',
|
30
|
-
'PiecewiseQuadratic',
|
31
|
-
'piecewise_quadratic',
|
32
|
-
'PiecewiseExp',
|
33
|
-
'piecewise_exp',
|
34
|
-
'SoftSign',
|
35
|
-
'soft_sign',
|
36
|
-
'Arctan',
|
37
|
-
'arctan',
|
38
|
-
'NonzeroSignLog',
|
39
|
-
'nonzero_sign_log',
|
40
|
-
'ERF',
|
41
|
-
'erf',
|
42
|
-
'PiecewiseLeakyRelu',
|
43
|
-
'piecewise_leaky_relu',
|
44
|
-
'SquarewaveFourierSeries',
|
45
|
-
'squarewave_fourier_series',
|
46
|
-
'S2NN',
|
47
|
-
's2nn',
|
48
|
-
'QPseudoSpike',
|
49
|
-
'q_pseudo_spike',
|
50
|
-
'LeakyRelu',
|
51
|
-
'leaky_relu',
|
52
|
-
'LogTailedRelu',
|
53
|
-
'log_tailed_relu',
|
54
|
-
'ReluGrad',
|
55
|
-
'relu_grad',
|
56
|
-
'GaussianGrad',
|
57
|
-
'gaussian_grad',
|
58
|
-
'InvSquareGrad',
|
59
|
-
'inv_square_grad',
|
60
|
-
'MultiGaussianGrad',
|
61
|
-
'multi_gaussian_grad',
|
62
|
-
'SlayerGrad',
|
63
|
-
'slayer_grad',
|
64
|
-
]
|
65
|
-
|
66
|
-
|
67
|
-
def _heaviside_abstract(x, dx):
|
68
|
-
return [x]
|
69
|
-
|
70
|
-
|
71
|
-
def _heaviside_imp(x, dx):
|
72
|
-
z = jnp.asarray(x >= 0, dtype=x.dtype)
|
73
|
-
return [z]
|
74
|
-
|
75
|
-
|
76
|
-
def _heaviside_batching(args, axes):
|
77
|
-
x, dx = args
|
78
|
-
if axes[0] != axes[1]:
|
79
|
-
dx = batching.moveaxis(dx, axes[1], axes[0])
|
80
|
-
return heaviside_p.bind(x, dx), tuple([axes[0]])
|
81
|
-
|
82
|
-
|
83
|
-
def _heaviside_jvp(primals, tangents):
|
84
|
-
x, dx = primals
|
85
|
-
tx, tdx = tangents
|
86
|
-
primal_outs = heaviside_p.bind(x, dx)
|
87
|
-
tangent_outs = [dx * tx, ]
|
88
|
-
return primal_outs, tangent_outs
|
89
|
-
|
90
|
-
|
91
|
-
heaviside_p = Primitive('heaviside_p')
|
92
|
-
heaviside_p.multiple_results = True
|
93
|
-
heaviside_p.def_abstract_eval(_heaviside_abstract)
|
94
|
-
heaviside_p.def_impl(_heaviside_imp)
|
95
|
-
batching.primitive_batchers[heaviside_p] = _heaviside_batching
|
96
|
-
ad.primitive_jvps[heaviside_p] = _heaviside_jvp
|
97
|
-
mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True))
|
98
|
-
|
99
|
-
|
100
|
-
class Surrogate(PrettyObject):
|
101
|
-
"""The base surrograte gradient function.
|
102
|
-
|
103
|
-
To customize a surrogate gradient function, you can inherit this class and
|
104
|
-
implement the `surrogate_fun` and `surrogate_grad` methods.
|
105
|
-
|
106
|
-
Examples
|
107
|
-
--------
|
108
|
-
|
109
|
-
>>> import brainstate as brainstate
|
110
|
-
>>> import brainstate.nn as nn
|
111
|
-
>>> import jax.numpy as jnp
|
112
|
-
|
113
|
-
>>> class MySurrogate(brainstate.surrogate.Surrogate):
|
114
|
-
... def __init__(self, alpha=1.):
|
115
|
-
... super().__init__()
|
116
|
-
... self.alpha = alpha
|
117
|
-
...
|
118
|
-
... def surrogate_fun(self, x):
|
119
|
-
... return jnp.sin(x) * self.alpha
|
120
|
-
...
|
121
|
-
... def surrogate_grad(self, x):
|
122
|
-
... return jnp.cos(x) * self.alpha
|
123
|
-
|
124
|
-
"""
|
125
|
-
|
126
|
-
def __call__(self, x):
|
127
|
-
dx = self.surrogate_grad(x)
|
128
|
-
return heaviside_p.bind(x, dx)[0]
|
129
|
-
|
130
|
-
def surrogate_fun(self, x) -> jax.Array:
|
131
|
-
"""The surrogate function."""
|
132
|
-
raise NotImplementedError
|
133
|
-
|
134
|
-
def surrogate_grad(self, x) -> jax.Array:
|
135
|
-
"""The gradient function of the surrogate function."""
|
136
|
-
raise NotImplementedError
|
137
|
-
|
138
|
-
|
139
|
-
class Sigmoid(Surrogate):
|
140
|
-
"""Spike function with the sigmoid-shaped surrogate gradient.
|
141
|
-
|
142
|
-
This class implements a spiking neuron activation with a sigmoid-shaped
|
143
|
-
surrogate gradient for backpropagation. It can be used in spiking neural
|
144
|
-
networks to approximate the non-differentiable step function during training.
|
145
|
-
|
146
|
-
Parameters
|
147
|
-
----------
|
148
|
-
alpha : float, optional
|
149
|
-
A parameter controlling the steepness of the sigmoid curve in the
|
150
|
-
surrogate gradient. Higher values make the transition sharper.
|
151
|
-
Default is 4.0.
|
152
|
-
|
153
|
-
See Also
|
154
|
-
--------
|
155
|
-
sigmoid : Function version of this class.
|
156
|
-
|
157
|
-
"""
|
158
|
-
|
159
|
-
def __init__(self, alpha: float = 4.):
|
160
|
-
super().__init__()
|
161
|
-
self.alpha = alpha
|
162
|
-
|
163
|
-
def surrogate_fun(self, x):
|
164
|
-
"""Compute the surrogate function.
|
165
|
-
|
166
|
-
Parameters
|
167
|
-
----------
|
168
|
-
x : jax.Array
|
169
|
-
The input array.
|
170
|
-
|
171
|
-
Returns
|
172
|
-
-------
|
173
|
-
jax.Array
|
174
|
-
The output of the surrogate function.
|
175
|
-
"""
|
176
|
-
return sci.special.expit(self.alpha * x)
|
177
|
-
|
178
|
-
def surrogate_grad(self, x):
|
179
|
-
"""Compute the gradient of the surrogate function.
|
180
|
-
|
181
|
-
Parameters
|
182
|
-
----------
|
183
|
-
x : jax.Array
|
184
|
-
The input array.
|
185
|
-
|
186
|
-
Returns
|
187
|
-
-------
|
188
|
-
jax.Array
|
189
|
-
The gradient of the surrogate function.
|
190
|
-
"""
|
191
|
-
sgax = sci.special.expit(x * self.alpha)
|
192
|
-
dx = (1. - sgax) * sgax * self.alpha
|
193
|
-
return dx
|
194
|
-
|
195
|
-
def __repr__(self):
|
196
|
-
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
197
|
-
|
198
|
-
def __hash__(self):
|
199
|
-
return hash((self.__class__, self.alpha))
|
200
|
-
|
201
|
-
|
202
|
-
def sigmoid(
|
203
|
-
x: jax.Array,
|
204
|
-
alpha: float = 4.,
|
205
|
-
):
|
206
|
-
r"""
|
207
|
-
Compute a spike function with a sigmoid-shaped surrogate gradient.
|
208
|
-
|
209
|
-
This function implements a spiking neuron activation with a sigmoid-shaped
|
210
|
-
surrogate gradient for backpropagation. It can be used in spiking neural
|
211
|
-
networks to approximate the non-differentiable step function during training.
|
212
|
-
|
213
|
-
If `origin=False`, return the forward function:
|
214
|
-
|
215
|
-
.. math::
|
216
|
-
|
217
|
-
g(x) = \begin{cases}
|
218
|
-
1, & x \geq 0 \\
|
219
|
-
0, & x < 0 \\
|
220
|
-
\end{cases}
|
221
|
-
|
222
|
-
If `origin=True`, computes the original function:
|
223
|
-
|
224
|
-
.. math::
|
225
|
-
|
226
|
-
g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}
|
227
|
-
|
228
|
-
Backward function:
|
229
|
-
|
230
|
-
.. math::
|
231
|
-
|
232
|
-
g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x)
|
233
|
-
|
234
|
-
.. plot::
|
235
|
-
:include-source: True
|
236
|
-
|
237
|
-
>>> import jax
|
238
|
-
>>> import brainstate.nn as nn
|
239
|
-
>>> import brainstate as brainstate
|
240
|
-
>>> import matplotlib.pyplot as plt
|
241
|
-
>>> xs = jax.numpy.linspace(-2, 2, 1000)
|
242
|
-
>>> for alpha in [1., 2., 4.]:
|
243
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.sigmoid)(xs, alpha)
|
244
|
-
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
245
|
-
>>> plt.legend()
|
246
|
-
>>> plt.show()
|
247
|
-
|
248
|
-
Parameters
|
249
|
-
----------
|
250
|
-
x : jax.Array
|
251
|
-
The input array representing the neuron's membrane potential.
|
252
|
-
alpha : float, optional
|
253
|
-
A parameter controlling the steepness of the sigmoid curve in the
|
254
|
-
surrogate gradient. Higher values make the transition sharper.
|
255
|
-
Default is 4.0.
|
256
|
-
|
257
|
-
Returns
|
258
|
-
-------
|
259
|
-
jax.Array
|
260
|
-
An array of the same shape as the input, containing binary values (0 or 1)
|
261
|
-
representing the spiking state of each neuron.
|
262
|
-
|
263
|
-
Notes
|
264
|
-
-----
|
265
|
-
The forward pass uses a step function (1 for x >= 0, 0 for x < 0),
|
266
|
-
while the backward pass uses a sigmoid-shaped surrogate gradient for
|
267
|
-
smooth optimization.
|
268
|
-
|
269
|
-
The surrogate gradient is defined as:
|
270
|
-
g'(x) = alpha * (1 - sigmoid(alpha * x)) * sigmoid(alpha * x)
|
271
|
-
|
272
|
-
"""
|
273
|
-
return Sigmoid(alpha=alpha)(x)
|
274
|
-
|
275
|
-
|
276
|
-
class PiecewiseQuadratic(Surrogate):
|
277
|
-
"""Judge spiking state with a piecewise quadratic function.
|
278
|
-
|
279
|
-
See Also
|
280
|
-
--------
|
281
|
-
piecewise_quadratic
|
282
|
-
|
283
|
-
"""
|
284
|
-
|
285
|
-
def __init__(self, alpha: float = 1.):
|
286
|
-
super().__init__()
|
287
|
-
self.alpha = alpha
|
288
|
-
|
289
|
-
def surrogate_fun(self, x):
|
290
|
-
z = jnp.where(
|
291
|
-
x < -1 / self.alpha,
|
292
|
-
0.,
|
293
|
-
jnp.where(
|
294
|
-
x > 1 / self.alpha,
|
295
|
-
1.,
|
296
|
-
(-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5
|
297
|
-
)
|
298
|
-
)
|
299
|
-
return z
|
300
|
-
|
301
|
-
def surrogate_grad(self, x):
|
302
|
-
dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha))
|
303
|
-
return dx
|
304
|
-
|
305
|
-
def __repr__(self):
|
306
|
-
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
307
|
-
|
308
|
-
def __hash__(self):
|
309
|
-
return hash((self.__class__, self.alpha))
|
310
|
-
|
311
|
-
|
312
|
-
def piecewise_quadratic(
|
313
|
-
x: jax.Array,
|
314
|
-
alpha: float = 1.,
|
315
|
-
):
|
316
|
-
r"""
|
317
|
-
Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
|
318
|
-
|
319
|
-
This function implements a surrogate gradient method for spiking neural networks
|
320
|
-
using a piecewise quadratic function. It provides a differentiable approximation
|
321
|
-
of the step function used in the forward pass of spiking neurons.
|
322
|
-
|
323
|
-
If `origin=False`, computes the forward function:
|
324
|
-
|
325
|
-
.. math::
|
326
|
-
|
327
|
-
g(x) = \begin{cases}
|
328
|
-
1, & x \geq 0 \\
|
329
|
-
0, & x < 0 \\
|
330
|
-
\end{cases}
|
331
|
-
|
332
|
-
If `origin=True`, computes the original function:
|
333
|
-
|
334
|
-
.. math::
|
335
|
-
|
336
|
-
g(x) =
|
337
|
-
\begin{cases}
|
338
|
-
0, & x < -\frac{1}{\alpha} \\
|
339
|
-
-\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
|
340
|
-
1, & x > \frac{1}{\alpha} \\
|
341
|
-
\end{cases}
|
342
|
-
|
343
|
-
Backward function:
|
344
|
-
|
345
|
-
.. math::
|
346
|
-
|
347
|
-
g'(x) =
|
348
|
-
\begin{cases}
|
349
|
-
0, & |x| > \frac{1}{\alpha} \\
|
350
|
-
-\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha}
|
351
|
-
\end{cases}
|
352
|
-
|
353
|
-
.. plot::
|
354
|
-
:include-source: True
|
355
|
-
|
356
|
-
>>> import jax
|
357
|
-
>>> import brainstate.nn as nn
|
358
|
-
>>> import brainstate as brainstate
|
359
|
-
>>> import matplotlib.pyplot as plt
|
360
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
361
|
-
>>> for alpha in [0.5, 1., 2., 4.]:
|
362
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_quadratic)(xs, alpha)
|
363
|
-
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
364
|
-
>>> plt.legend()
|
365
|
-
>>> plt.show()
|
366
|
-
|
367
|
-
Parameters
|
368
|
-
----------
|
369
|
-
x : jax.Array
|
370
|
-
The input array representing the neuron's membrane potential.
|
371
|
-
alpha : float, optional
|
372
|
-
A parameter controlling the steepness of the surrogate gradient.
|
373
|
-
Higher values result in a steeper gradient. Default is 1.0.
|
374
|
-
|
375
|
-
Returns
|
376
|
-
-------
|
377
|
-
jax.Array
|
378
|
-
An array of the same shape as the input, containing binary values (0 or 1)
|
379
|
-
representing the spiking state of each neuron.
|
380
|
-
|
381
|
-
Notes
|
382
|
-
-----
|
383
|
-
The function uses different computations for forward and backward passes:
|
384
|
-
- Forward: Step function (1 for x >= 0, 0 for x < 0)
|
385
|
-
- Backward: Piecewise quadratic function for smooth gradient
|
386
|
-
|
387
|
-
The surrogate gradient is defined as:
|
388
|
-
g'(x) = 0 if |x| > 1/alpha
|
389
|
-
-alpha^2|x| + alpha if |x| <= 1/alpha
|
390
|
-
|
391
|
-
References
|
392
|
-
----------
|
393
|
-
.. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446.
|
394
|
-
.. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
|
395
|
-
.. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
|
396
|
-
.. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
|
397
|
-
.. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
|
398
|
-
"""
|
399
|
-
return PiecewiseQuadratic(alpha=alpha)(x)
|
400
|
-
|
401
|
-
|
402
|
-
class PiecewiseExp(Surrogate):
|
403
|
-
"""
|
404
|
-
Judge spiking state with a piecewise exponential function.
|
405
|
-
|
406
|
-
This class implements a surrogate gradient method for spiking neural networks
|
407
|
-
using a piecewise exponential function. It provides a differentiable approximation
|
408
|
-
of the step function used in the forward pass of spiking neurons.
|
409
|
-
|
410
|
-
Parameters
|
411
|
-
----------
|
412
|
-
alpha : float, optional
|
413
|
-
A parameter controlling the steepness of the surrogate gradient.
|
414
|
-
Higher values result in a steeper gradient. Default is 1.0.
|
415
|
-
|
416
|
-
See Also
|
417
|
-
--------
|
418
|
-
piecewise_exp : Function version of this class.
|
419
|
-
"""
|
420
|
-
|
421
|
-
def __init__(self, alpha: float = 1.):
|
422
|
-
super().__init__()
|
423
|
-
self.alpha = alpha
|
424
|
-
|
425
|
-
def surrogate_grad(self, x):
|
426
|
-
"""
|
427
|
-
Compute the surrogate gradient.
|
428
|
-
|
429
|
-
Parameters
|
430
|
-
----------
|
431
|
-
x : jax.Array
|
432
|
-
The input array.
|
433
|
-
|
434
|
-
Returns
|
435
|
-
-------
|
436
|
-
jax.Array
|
437
|
-
The surrogate gradient.
|
438
|
-
"""
|
439
|
-
dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
|
440
|
-
return dx
|
441
|
-
|
442
|
-
def surrogate_fun(self, x):
|
443
|
-
"""
|
444
|
-
Compute the surrogate function.
|
445
|
-
|
446
|
-
Parameters
|
447
|
-
----------
|
448
|
-
x : jax.Array
|
449
|
-
The input array.
|
450
|
-
|
451
|
-
Returns
|
452
|
-
-------
|
453
|
-
jax.Array
|
454
|
-
The output of the surrogate function.
|
455
|
-
"""
|
456
|
-
return jnp.where(
|
457
|
-
x < 0,
|
458
|
-
jnp.exp(self.alpha * x) / 2,
|
459
|
-
1 - jnp.exp(-self.alpha * x) / 2
|
460
|
-
)
|
461
|
-
|
462
|
-
def __repr__(self):
|
463
|
-
"""
|
464
|
-
Return a string representation of the PiecewiseExp instance.
|
465
|
-
|
466
|
-
Returns
|
467
|
-
-------
|
468
|
-
str
|
469
|
-
A string representation of the instance.
|
470
|
-
"""
|
471
|
-
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
472
|
-
|
473
|
-
def __hash__(self):
|
474
|
-
"""
|
475
|
-
Compute a hash value for the PiecewiseExp instance.
|
476
|
-
|
477
|
-
Returns
|
478
|
-
-------
|
479
|
-
int
|
480
|
-
A hash value for the instance.
|
481
|
-
"""
|
482
|
-
return hash((self.__class__, self.alpha))
|
483
|
-
|
484
|
-
|
485
|
-
def piecewise_exp(
|
486
|
-
x: jax.Array,
|
487
|
-
alpha: float = 1.,
|
488
|
-
|
489
|
-
):
|
490
|
-
r"""Judge spiking state with a piecewise exponential function [1]_.
|
491
|
-
|
492
|
-
This function implements a surrogate gradient method for spiking neural networks
|
493
|
-
using a piecewise exponential function. It provides a differentiable approximation
|
494
|
-
of the step function used in the forward pass of spiking neurons.
|
495
|
-
|
496
|
-
If `origin=False`, computes the forward function:
|
497
|
-
|
498
|
-
.. math::
|
499
|
-
|
500
|
-
g(x) = \begin{cases}
|
501
|
-
1, & x \geq 0 \\
|
502
|
-
0, & x < 0 \\
|
503
|
-
\end{cases}
|
504
|
-
|
505
|
-
If `origin=True`, computes the original function:
|
506
|
-
|
507
|
-
.. math::
|
508
|
-
|
509
|
-
g(x) = \begin{cases}
|
510
|
-
\frac{1}{2}e^{\alpha x}, & x < 0 \\
|
511
|
-
1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
|
512
|
-
\end{cases}
|
513
|
-
|
514
|
-
Backward function:
|
515
|
-
|
516
|
-
.. math::
|
517
|
-
|
518
|
-
g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}
|
519
|
-
|
520
|
-
.. plot::
|
521
|
-
:include-source: True
|
522
|
-
|
523
|
-
>>> import jax
|
524
|
-
>>> import brainstate.nn as nn
|
525
|
-
>>> import brainstate as brainstate
|
526
|
-
>>> import matplotlib.pyplot as plt
|
527
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
528
|
-
>>> for alpha in [0.5, 1., 2., 4.]:
|
529
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_exp)(xs, alpha)
|
530
|
-
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
531
|
-
>>> plt.legend()
|
532
|
-
>>> plt.show()
|
533
|
-
|
534
|
-
Parameters
|
535
|
-
----------
|
536
|
-
x : jax.Array
|
537
|
-
The input array representing the neuron's membrane potential.
|
538
|
-
alpha : float, optional
|
539
|
-
A parameter controlling the steepness of the surrogate gradient.
|
540
|
-
Higher values result in a steeper gradient. Default is 1.0.
|
541
|
-
|
542
|
-
Returns
|
543
|
-
-------
|
544
|
-
jax.Array
|
545
|
-
An array of the same shape as the input, containing binary values (0 or 1)
|
546
|
-
representing the spiking state of each neuron.
|
547
|
-
|
548
|
-
Notes
|
549
|
-
-----
|
550
|
-
The function uses different computations for forward and backward passes:
|
551
|
-
- Forward: Step function (1 for x >= 0, 0 for x < 0)
|
552
|
-
- Backward: Piecewise exponential function for smooth gradient
|
553
|
-
|
554
|
-
The surrogate gradient is defined as:
|
555
|
-
g'(x) = (alpha / 2) * exp(-alpha * |x|)
|
556
|
-
|
557
|
-
References
|
558
|
-
----------
|
559
|
-
.. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
|
560
|
-
"""
|
561
|
-
return PiecewiseExp(alpha=alpha)(x)
|
562
|
-
|
563
|
-
|
564
|
-
class SoftSign(Surrogate):
|
565
|
-
"""Judge spiking state with a soft sign function.
|
566
|
-
|
567
|
-
See Also
|
568
|
-
--------
|
569
|
-
soft_sign
|
570
|
-
"""
|
571
|
-
|
572
|
-
def __init__(self, alpha=1.):
|
573
|
-
super().__init__()
|
574
|
-
self.alpha = alpha
|
575
|
-
|
576
|
-
def surrogate_grad(self, x):
|
577
|
-
dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2
|
578
|
-
return dx
|
579
|
-
|
580
|
-
def surrogate_fun(self, x):
|
581
|
-
return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5
|
582
|
-
|
583
|
-
def __repr__(self):
|
584
|
-
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
585
|
-
|
586
|
-
def __hash__(self):
|
587
|
-
return hash((self.__class__, self.alpha))
|
588
|
-
|
589
|
-
|
590
|
-
def soft_sign(
|
591
|
-
x: jax.Array,
|
592
|
-
alpha: float = 1.,
|
593
|
-
|
594
|
-
):
|
595
|
-
r"""Judge spiking state with a soft sign function.
|
596
|
-
|
597
|
-
If `origin=False`, computes the forward function:
|
598
|
-
|
599
|
-
.. math::
|
600
|
-
|
601
|
-
g(x) = \begin{cases}
|
602
|
-
1, & x \geq 0 \\
|
603
|
-
0, & x < 0 \\
|
604
|
-
\end{cases}
|
605
|
-
|
606
|
-
If `origin=True`, computes the original function:
|
607
|
-
|
608
|
-
.. math::
|
609
|
-
|
610
|
-
g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)
|
611
|
-
= \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)
|
612
|
-
|
613
|
-
Backward function:
|
614
|
-
|
615
|
-
.. math::
|
616
|
-
|
617
|
-
g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}
|
618
|
-
|
619
|
-
.. plot::
|
620
|
-
:include-source: True
|
621
|
-
|
622
|
-
>>> import jax
|
623
|
-
>>> import brainstate.nn as nn
|
624
|
-
>>> import brainstate as brainstate
|
625
|
-
>>> import matplotlib.pyplot as plt
|
626
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
627
|
-
>>> for alpha in [0.5, 1., 2., 4.]:
|
628
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.soft_sign)(xs, alpha)
|
629
|
-
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
630
|
-
>>> plt.legend()
|
631
|
-
>>> plt.show()
|
632
|
-
|
633
|
-
Parameters
|
634
|
-
----------
|
635
|
-
x: jax.Array, Array
|
636
|
-
The input data.
|
637
|
-
alpha: float
|
638
|
-
Parameter to control smoothness of gradient
|
639
|
-
|
640
|
-
|
641
|
-
Returns
|
642
|
-
-------
|
643
|
-
out: jax.Array
|
644
|
-
The spiking state.
|
645
|
-
|
646
|
-
"""
|
647
|
-
return SoftSign(alpha=alpha)(x)
|
648
|
-
|
649
|
-
|
650
|
-
class Arctan(Surrogate):
|
651
|
-
"""Judge spiking state with an arctan function.
|
652
|
-
|
653
|
-
See Also
|
654
|
-
--------
|
655
|
-
arctan
|
656
|
-
"""
|
657
|
-
|
658
|
-
def __init__(self, alpha=1.):
|
659
|
-
super().__init__()
|
660
|
-
self.alpha = alpha
|
661
|
-
|
662
|
-
def surrogate_grad(self, x):
|
663
|
-
dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
|
664
|
-
return dx
|
665
|
-
|
666
|
-
def surrogate_fun(self, x):
|
667
|
-
return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
|
668
|
-
|
669
|
-
def __repr__(self):
|
670
|
-
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
671
|
-
|
672
|
-
def __hash__(self):
|
673
|
-
return hash((self.__class__, self.alpha))
|
674
|
-
|
675
|
-
|
676
|
-
def arctan(
|
677
|
-
x: jax.Array,
|
678
|
-
alpha: float = 1.,
|
679
|
-
|
680
|
-
):
|
681
|
-
r"""Judge spiking state with an arctan function.
|
682
|
-
|
683
|
-
If `origin=False`, computes the forward function:
|
684
|
-
|
685
|
-
.. math::
|
686
|
-
|
687
|
-
g(x) = \begin{cases}
|
688
|
-
1, & x \geq 0 \\
|
689
|
-
0, & x < 0 \\
|
690
|
-
\end{cases}
|
691
|
-
|
692
|
-
If `origin=True`, computes the original function:
|
693
|
-
|
694
|
-
.. math::
|
695
|
-
|
696
|
-
g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}
|
697
|
-
|
698
|
-
Backward function:
|
699
|
-
|
700
|
-
.. math::
|
701
|
-
|
702
|
-
g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}
|
703
|
-
|
704
|
-
.. plot::
|
705
|
-
:include-source: True
|
706
|
-
|
707
|
-
>>> import jax
|
708
|
-
>>> import brainstate.nn as nn
|
709
|
-
>>> import brainstate as brainstate
|
710
|
-
>>> import matplotlib.pyplot as plt
|
711
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
712
|
-
>>> for alpha in [0.5, 1., 2., 4.]:
|
713
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.arctan)(xs, alpha)
|
714
|
-
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
715
|
-
>>> plt.legend()
|
716
|
-
>>> plt.show()
|
717
|
-
|
718
|
-
Parameters
|
719
|
-
----------
|
720
|
-
x: jax.Array, Array
|
721
|
-
The input data.
|
722
|
-
alpha: float
|
723
|
-
Parameter to control smoothness of gradient
|
724
|
-
|
725
|
-
|
726
|
-
Returns
|
727
|
-
-------
|
728
|
-
out: jax.Array
|
729
|
-
The spiking state.
|
730
|
-
|
731
|
-
"""
|
732
|
-
return Arctan(alpha=alpha)(x)
|
733
|
-
|
734
|
-
|
735
|
-
class NonzeroSignLog(Surrogate):
|
736
|
-
"""Judge spiking state with a nonzero sign log function.
|
737
|
-
|
738
|
-
See Also
|
739
|
-
--------
|
740
|
-
nonzero_sign_log
|
741
|
-
"""
|
742
|
-
|
743
|
-
def __init__(self, alpha=1.):
|
744
|
-
super().__init__()
|
745
|
-
self.alpha = alpha
|
746
|
-
|
747
|
-
def surrogate_grad(self, x):
|
748
|
-
dx = 1. / (1 / self.alpha + jnp.abs(x))
|
749
|
-
return dx
|
750
|
-
|
751
|
-
def surrogate_fun(self, x):
|
752
|
-
return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1)
|
753
|
-
|
754
|
-
def __repr__(self):
|
755
|
-
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
756
|
-
|
757
|
-
def __hash__(self):
|
758
|
-
return hash((self.__class__, self.alpha))
|
759
|
-
|
760
|
-
|
761
|
-
def nonzero_sign_log(
|
762
|
-
x: jax.Array,
|
763
|
-
alpha: float = 1.,
|
764
|
-
|
765
|
-
):
|
766
|
-
r"""Judge spiking state with a nonzero sign log function.
|
767
|
-
|
768
|
-
If `origin=False`, computes the forward function:
|
769
|
-
|
770
|
-
.. math::
|
771
|
-
|
772
|
-
g(x) = \begin{cases}
|
773
|
-
1, & x \geq 0 \\
|
774
|
-
0, & x < 0 \\
|
775
|
-
\end{cases}
|
776
|
-
|
777
|
-
If `origin=True`, computes the original function:
|
778
|
-
|
779
|
-
.. math::
|
780
|
-
|
781
|
-
g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)
|
782
|
-
|
783
|
-
where
|
784
|
-
|
785
|
-
.. math::
|
786
|
-
|
787
|
-
\begin{split}\mathrm{NonzeroSign}(x) =
|
788
|
-
\begin{cases}
|
789
|
-
1, & x \geq 0 \\
|
790
|
-
-1, & x < 0 \\
|
791
|
-
\end{cases}\end{split}
|
792
|
-
|
793
|
-
Backward function:
|
794
|
-
|
795
|
-
.. math::
|
796
|
-
|
797
|
-
g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}
|
798
|
-
|
799
|
-
This surrogate function has the advantage of low computation cost during the backward.
|
800
|
-
|
801
|
-
|
802
|
-
.. plot::
|
803
|
-
:include-source: True
|
804
|
-
|
805
|
-
>>> import jax
|
806
|
-
>>> import brainstate.nn as nn
|
807
|
-
>>> import brainstate as brainstate
|
808
|
-
>>> import matplotlib.pyplot as plt
|
809
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
810
|
-
>>> for alpha in [0.5, 1., 2., 4.]:
|
811
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
|
812
|
-
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
813
|
-
>>> plt.legend()
|
814
|
-
>>> plt.show()
|
815
|
-
|
816
|
-
Parameters
|
817
|
-
----------
|
818
|
-
x: jax.Array, Array
|
819
|
-
The input data.
|
820
|
-
alpha: float
|
821
|
-
Parameter to control smoothness of gradient
|
822
|
-
|
823
|
-
|
824
|
-
Returns
|
825
|
-
-------
|
826
|
-
out: jax.Array
|
827
|
-
The spiking state.
|
828
|
-
|
829
|
-
"""
|
830
|
-
return NonzeroSignLog(alpha=alpha)(x)
|
831
|
-
|
832
|
-
|
833
|
-
class ERF(Surrogate):
|
834
|
-
"""Judge spiking state with an erf function.
|
835
|
-
|
836
|
-
See Also
|
837
|
-
--------
|
838
|
-
erf
|
839
|
-
"""
|
840
|
-
|
841
|
-
def __init__(self, alpha=1.):
|
842
|
-
super().__init__()
|
843
|
-
self.alpha = alpha
|
844
|
-
|
845
|
-
def surrogate_grad(self, x):
|
846
|
-
dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
|
847
|
-
return dx
|
848
|
-
|
849
|
-
def surrogate_fun(self, x):
|
850
|
-
return sci.special.erf(-self.alpha * x) * 0.5
|
851
|
-
|
852
|
-
def __repr__(self):
|
853
|
-
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
854
|
-
|
855
|
-
def __hash__(self):
|
856
|
-
return hash((self.__class__, self.alpha))
|
857
|
-
|
858
|
-
|
859
|
-
def erf(
|
860
|
-
x: jax.Array,
|
861
|
-
alpha: float = 1.,
|
862
|
-
|
863
|
-
):
|
864
|
-
r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.
|
865
|
-
|
866
|
-
If `origin=False`, computes the forward function:
|
867
|
-
|
868
|
-
.. math::
|
869
|
-
|
870
|
-
g(x) = \begin{cases}
|
871
|
-
1, & x \geq 0 \\
|
872
|
-
0, & x < 0 \\
|
873
|
-
\end{cases}
|
874
|
-
|
875
|
-
If `origin=True`, computes the original function:
|
876
|
-
|
877
|
-
.. math::
|
878
|
-
|
879
|
-
\begin{split}
|
880
|
-
g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
|
881
|
-
&= \frac{1}{2} \text{erfc}(-\alpha x) \\
|
882
|
-
&= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
|
883
|
-
\end{split}
|
884
|
-
|
885
|
-
Backward function:
|
886
|
-
|
887
|
-
.. math::
|
888
|
-
|
889
|
-
g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}
|
890
|
-
|
891
|
-
.. plot::
|
892
|
-
:include-source: True
|
893
|
-
|
894
|
-
>>> import jax
|
895
|
-
>>> import brainstate.nn as nn
|
896
|
-
>>> import brainstate as brainstate
|
897
|
-
>>> import matplotlib.pyplot as plt
|
898
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
899
|
-
>>> for alpha in [0.5, 1., 2., 4.]:
|
900
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
|
901
|
-
>>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
|
902
|
-
>>> plt.legend()
|
903
|
-
>>> plt.show()
|
904
|
-
|
905
|
-
Parameters
|
906
|
-
----------
|
907
|
-
x: jax.Array, Array
|
908
|
-
The input data.
|
909
|
-
alpha: float
|
910
|
-
Parameter to control smoothness of gradient
|
911
|
-
|
912
|
-
|
913
|
-
Returns
|
914
|
-
-------
|
915
|
-
out: jax.Array
|
916
|
-
The spiking state.
|
917
|
-
|
918
|
-
References
|
919
|
-
----------
|
920
|
-
.. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125.
|
921
|
-
.. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
|
922
|
-
.. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
|
923
|
-
|
924
|
-
"""
|
925
|
-
return ERF(alpha=alpha)(x)
|
926
|
-
|
927
|
-
|
928
|
-
class PiecewiseLeakyRelu(Surrogate):
|
929
|
-
"""Judge spiking state with a piecewise leaky relu function.
|
930
|
-
|
931
|
-
See Also
|
932
|
-
--------
|
933
|
-
piecewise_leaky_relu
|
934
|
-
"""
|
935
|
-
|
936
|
-
def __init__(self, c=0.01, w=1.):
|
937
|
-
super().__init__()
|
938
|
-
self.c = c
|
939
|
-
self.w = w
|
940
|
-
|
941
|
-
def surrogate_fun(self, x):
|
942
|
-
z = jnp.where(x < -self.w,
|
943
|
-
self.c * x + self.c * self.w,
|
944
|
-
jnp.where(x > self.w,
|
945
|
-
self.c * x - self.c * self.w + 1,
|
946
|
-
0.5 * x / self.w + 0.5))
|
947
|
-
return z
|
948
|
-
|
949
|
-
def surrogate_grad(self, x):
|
950
|
-
dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w)
|
951
|
-
return dx
|
952
|
-
|
953
|
-
def __repr__(self):
|
954
|
-
return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
|
955
|
-
|
956
|
-
def __hash__(self):
|
957
|
-
return hash((self.__class__, self.c, self.w))
|
958
|
-
|
959
|
-
|
960
|
-
def piecewise_leaky_relu(
|
961
|
-
x: jax.Array,
|
962
|
-
c: float = 0.01,
|
963
|
-
w: float = 1.,
|
964
|
-
|
965
|
-
):
|
966
|
-
r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.
|
967
|
-
|
968
|
-
If `origin=False`, computes the forward function:
|
969
|
-
|
970
|
-
.. math::
|
971
|
-
|
972
|
-
g(x) = \begin{cases}
|
973
|
-
1, & x \geq 0 \\
|
974
|
-
0, & x < 0 \\
|
975
|
-
\end{cases}
|
976
|
-
|
977
|
-
If `origin=True`, computes the original function:
|
978
|
-
|
979
|
-
.. math::
|
980
|
-
|
981
|
-
\begin{split}g(x) =
|
982
|
-
\begin{cases}
|
983
|
-
cx + cw, & x < -w \\
|
984
|
-
\frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
|
985
|
-
cx - cw + 1, & x > w \\
|
986
|
-
\end{cases}\end{split}
|
987
|
-
|
988
|
-
Backward function:
|
989
|
-
|
990
|
-
.. math::
|
991
|
-
|
992
|
-
\begin{split}g'(x) =
|
993
|
-
\begin{cases}
|
994
|
-
\frac{1}{w}, & |x| \leq w \\
|
995
|
-
c, & |x| > w
|
996
|
-
\end{cases}\end{split}
|
997
|
-
|
998
|
-
.. plot::
|
999
|
-
:include-source: True
|
1000
|
-
|
1001
|
-
>>> import jax
|
1002
|
-
>>> import brainstate.nn as nn
|
1003
|
-
>>> import brainstate as brainstate
|
1004
|
-
>>> import matplotlib.pyplot as plt
|
1005
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1006
|
-
>>> for c in [0.01, 0.05, 0.1]:
|
1007
|
-
>>> for w in [1., 2.]:
|
1008
|
-
>>> grads1 = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
|
1009
|
-
>>> plt.plot(xs, grads1, label=f'x={c}, w={w}')
|
1010
|
-
>>> plt.legend()
|
1011
|
-
>>> plt.show()
|
1012
|
-
|
1013
|
-
Parameters
|
1014
|
-
----------
|
1015
|
-
x: jax.Array, Array
|
1016
|
-
The input data.
|
1017
|
-
c: float
|
1018
|
-
When :math:`|x| > w` the gradient is `c`.
|
1019
|
-
w: float
|
1020
|
-
When :math:`|x| <= w` the gradient is `1 / w`.
|
1021
|
-
|
1022
|
-
|
1023
|
-
Returns
|
1024
|
-
-------
|
1025
|
-
out: jax.Array
|
1026
|
-
The spiking state.
|
1027
|
-
|
1028
|
-
References
|
1029
|
-
----------
|
1030
|
-
.. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
|
1031
|
-
.. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
|
1032
|
-
.. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
|
1033
|
-
.. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
|
1034
|
-
.. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
|
1035
|
-
.. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
|
1036
|
-
.. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525.
|
1037
|
-
.. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.
|
1038
|
-
|
1039
|
-
"""
|
1040
|
-
return PiecewiseLeakyRelu(c=c, w=w)(x)
|
1041
|
-
|
1042
|
-
|
1043
|
-
class SquarewaveFourierSeries(Surrogate):
|
1044
|
-
"""Judge spiking state with a squarewave fourier series.
|
1045
|
-
|
1046
|
-
See Also
|
1047
|
-
--------
|
1048
|
-
squarewave_fourier_series
|
1049
|
-
"""
|
1050
|
-
|
1051
|
-
def __init__(self, n=2, t_period=8.):
|
1052
|
-
super().__init__()
|
1053
|
-
self.n = n
|
1054
|
-
self.t_period = t_period
|
1055
|
-
|
1056
|
-
def surrogate_grad(self, x):
|
1057
|
-
|
1058
|
-
w = jnp.pi * 2. / self.t_period
|
1059
|
-
dx = jnp.cos(w * x)
|
1060
|
-
for i in range(2, self.n):
|
1061
|
-
dx += jnp.cos((2 * i - 1.) * w * x)
|
1062
|
-
dx *= 4. / self.t_period
|
1063
|
-
return dx
|
1064
|
-
|
1065
|
-
def surrogate_fun(self, x):
|
1066
|
-
|
1067
|
-
w = jnp.pi * 2. / self.t_period
|
1068
|
-
ret = jnp.sin(w * x)
|
1069
|
-
for i in range(2, self.n):
|
1070
|
-
c = (2 * i - 1.)
|
1071
|
-
ret += jnp.sin(c * w * x) / c
|
1072
|
-
z = 0.5 + 2. / jnp.pi * ret
|
1073
|
-
return z
|
1074
|
-
|
1075
|
-
def __repr__(self):
|
1076
|
-
return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
|
1077
|
-
|
1078
|
-
def __hash__(self):
|
1079
|
-
return hash((self.__class__, self.n, self.t_period))
|
1080
|
-
|
1081
|
-
|
1082
|
-
def squarewave_fourier_series(
|
1083
|
-
x: jax.Array,
|
1084
|
-
n: int = 2,
|
1085
|
-
t_period: float = 8.,
|
1086
|
-
|
1087
|
-
):
|
1088
|
-
r"""Judge spiking state with a squarewave fourier series.
|
1089
|
-
|
1090
|
-
If `origin=False`, computes the forward function:
|
1091
|
-
|
1092
|
-
.. math::
|
1093
|
-
|
1094
|
-
g(x) = \begin{cases}
|
1095
|
-
1, & x \geq 0 \\
|
1096
|
-
0, & x < 0 \\
|
1097
|
-
\end{cases}
|
1098
|
-
|
1099
|
-
If `origin=True`, computes the original function:
|
1100
|
-
|
1101
|
-
.. math::
|
1102
|
-
|
1103
|
-
g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 }
|
1104
|
-
|
1105
|
-
Backward function:
|
1106
|
-
|
1107
|
-
.. math::
|
1108
|
-
|
1109
|
-
g'(x) = \sum_{i=1}^n\frac{4\cos\left((2 * i - 1.) * 2\pi * x / T\right)}{T}
|
1110
|
-
|
1111
|
-
.. plot::
|
1112
|
-
:include-source: True
|
1113
|
-
|
1114
|
-
>>> import jax
|
1115
|
-
>>> import brainstate.nn as nn
|
1116
|
-
>>> import brainstate as brainstate
|
1117
|
-
>>> import matplotlib.pyplot as plt
|
1118
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1119
|
-
>>> for n in [2, 4, 8]:
|
1120
|
-
>>> f = brainstate.surrogate.SquarewaveFourierSeries(n=n)
|
1121
|
-
>>> grads1 = brainstate.augment.vector_grad(f)(xs)
|
1122
|
-
>>> plt.plot(xs, grads1, label=f'n={n}')
|
1123
|
-
>>> plt.legend()
|
1124
|
-
>>> plt.show()
|
1125
|
-
|
1126
|
-
Parameters
|
1127
|
-
----------
|
1128
|
-
x: jax.Array, Array
|
1129
|
-
The input data.
|
1130
|
-
n: int
|
1131
|
-
t_period: float
|
1132
|
-
|
1133
|
-
|
1134
|
-
Returns
|
1135
|
-
-------
|
1136
|
-
out: jax.Array
|
1137
|
-
The spiking state.
|
1138
|
-
|
1139
|
-
"""
|
1140
|
-
|
1141
|
-
return SquarewaveFourierSeries(n=n, t_period=t_period)(x)
|
1142
|
-
|
1143
|
-
|
1144
|
-
class S2NN(Surrogate):
|
1145
|
-
"""Judge spiking state with the S2NN surrogate spiking function.
|
1146
|
-
|
1147
|
-
See Also
|
1148
|
-
--------
|
1149
|
-
s2nn
|
1150
|
-
"""
|
1151
|
-
|
1152
|
-
def __init__(self, alpha=4., beta=1., epsilon=1e-8):
|
1153
|
-
super().__init__()
|
1154
|
-
self.alpha = alpha
|
1155
|
-
self.beta = beta
|
1156
|
-
self.epsilon = epsilon
|
1157
|
-
|
1158
|
-
def surrogate_fun(self, x):
|
1159
|
-
z = jnp.where(x < 0.,
|
1160
|
-
sci.special.expit(x * self.alpha),
|
1161
|
-
self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5)
|
1162
|
-
return z
|
1163
|
-
|
1164
|
-
def surrogate_grad(self, x):
|
1165
|
-
sg = sci.special.expit(self.alpha * x)
|
1166
|
-
dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.))
|
1167
|
-
return dx
|
1168
|
-
|
1169
|
-
def __repr__(self):
|
1170
|
-
return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
|
1171
|
-
|
1172
|
-
def __hash__(self):
|
1173
|
-
return hash((self.__class__, self.alpha, self.beta, self.epsilon))
|
1174
|
-
|
1175
|
-
|
1176
|
-
def s2nn(
|
1177
|
-
x: jax.Array,
|
1178
|
-
alpha: float = 4.,
|
1179
|
-
beta: float = 1.,
|
1180
|
-
epsilon: float = 1e-8,
|
1181
|
-
|
1182
|
-
):
|
1183
|
-
r"""Judge spiking state with the S2NN surrogate spiking function [1]_.
|
1184
|
-
|
1185
|
-
If `origin=False`, computes the forward function:
|
1186
|
-
|
1187
|
-
.. math::
|
1188
|
-
|
1189
|
-
g(x) = \begin{cases}
|
1190
|
-
1, & x \geq 0 \\
|
1191
|
-
0, & x < 0 \\
|
1192
|
-
\end{cases}
|
1193
|
-
|
1194
|
-
If `origin=True`, computes the original function:
|
1195
|
-
|
1196
|
-
.. math::
|
1197
|
-
|
1198
|
-
\begin{split}g(x) = \begin{cases}
|
1199
|
-
\mathrm{sigmoid} (\alpha x), x < 0 \\
|
1200
|
-
\beta \ln(|x + 1|) + 0.5, x \ge 0
|
1201
|
-
\end{cases}\end{split}
|
1202
|
-
|
1203
|
-
Backward function:
|
1204
|
-
|
1205
|
-
.. math::
|
1206
|
-
|
1207
|
-
\begin{split}g'(x) = \begin{cases}
|
1208
|
-
\alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), x < 0 \\
|
1209
|
-
\frac{\beta}{(x + 1)}, x \ge 0
|
1210
|
-
\end{cases}\end{split}
|
1211
|
-
|
1212
|
-
.. plot::
|
1213
|
-
:include-source: True
|
1214
|
-
|
1215
|
-
>>> import jax
|
1216
|
-
>>> import brainstate.nn as nn
|
1217
|
-
>>> import brainstate as brainstate
|
1218
|
-
>>> import matplotlib.pyplot as plt
|
1219
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1220
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 4., 1.)
|
1221
|
-
>>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
|
1222
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 8., 2.)
|
1223
|
-
>>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
|
1224
|
-
>>> plt.legend()
|
1225
|
-
>>> plt.show()
|
1226
|
-
|
1227
|
-
Parameters
|
1228
|
-
----------
|
1229
|
-
x: jax.Array, Array
|
1230
|
-
The input data.
|
1231
|
-
alpha: float
|
1232
|
-
The param that controls the gradient when ``x < 0``.
|
1233
|
-
beta: float
|
1234
|
-
The param that controls the gradient when ``x >= 0``
|
1235
|
-
epsilon: float
|
1236
|
-
Avoid nan
|
1237
|
-
|
1238
|
-
|
1239
|
-
Returns
|
1240
|
-
-------
|
1241
|
-
out: jax.Array
|
1242
|
-
The spiking state.
|
1243
|
-
|
1244
|
-
References
|
1245
|
-
----------
|
1246
|
-
.. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag.
|
1247
|
-
|
1248
|
-
"""
|
1249
|
-
return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x)
|
1250
|
-
|
1251
|
-
|
1252
|
-
class QPseudoSpike(Surrogate):
|
1253
|
-
"""Judge spiking state with the q-PseudoSpike surrogate function.
|
1254
|
-
|
1255
|
-
See Also
|
1256
|
-
--------
|
1257
|
-
q_pseudo_spike
|
1258
|
-
"""
|
1259
|
-
|
1260
|
-
def __init__(self, alpha=2.):
|
1261
|
-
super().__init__()
|
1262
|
-
self.alpha = alpha
|
1263
|
-
|
1264
|
-
def surrogate_grad(self, x):
|
1265
|
-
dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha)
|
1266
|
-
return dx
|
1267
|
-
|
1268
|
-
def surrogate_fun(self, x):
|
1269
|
-
z = jnp.where(x < 0.,
|
1270
|
-
0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
|
1271
|
-
1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
|
1272
|
-
return z
|
1273
|
-
|
1274
|
-
def __repr__(self):
|
1275
|
-
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
1276
|
-
|
1277
|
-
def __hash__(self):
|
1278
|
-
return hash((self.__class__, self.alpha))
|
1279
|
-
|
1280
|
-
|
1281
|
-
def q_pseudo_spike(
|
1282
|
-
x: jax.Array,
|
1283
|
-
alpha: float = 2.,
|
1284
|
-
|
1285
|
-
):
|
1286
|
-
r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.
|
1287
|
-
|
1288
|
-
If `origin=False`, computes the forward function:
|
1289
|
-
|
1290
|
-
.. math::
|
1291
|
-
|
1292
|
-
g(x) = \begin{cases}
|
1293
|
-
1, & x \geq 0 \\
|
1294
|
-
0, & x < 0 \\
|
1295
|
-
\end{cases}
|
1296
|
-
|
1297
|
-
If `origin=True`, computes the original function:
|
1298
|
-
|
1299
|
-
.. math::
|
1300
|
-
|
1301
|
-
\begin{split}g(x) =
|
1302
|
-
\begin{cases}
|
1303
|
-
\frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
|
1304
|
-
1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
|
1305
|
-
\end{cases}\end{split}
|
1306
|
-
|
1307
|
-
Backward function:
|
1308
|
-
|
1309
|
-
.. math::
|
1310
|
-
|
1311
|
-
g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}
|
1312
|
-
|
1313
|
-
.. plot::
|
1314
|
-
:include-source: True
|
1315
|
-
|
1316
|
-
>>> import jax
|
1317
|
-
>>> import brainstate.nn as nn
|
1318
|
-
>>> import brainstate as brainstate
|
1319
|
-
>>> import matplotlib.pyplot as plt
|
1320
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1321
|
-
>>> for alpha in [0.5, 1., 2., 4.]:
|
1322
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.q_pseudo_spike)(xs, alpha)
|
1323
|
-
>>> plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
|
1324
|
-
>>> plt.legend()
|
1325
|
-
>>> plt.show()
|
1326
|
-
|
1327
|
-
Parameters
|
1328
|
-
----------
|
1329
|
-
x: jax.Array, Array
|
1330
|
-
The input data.
|
1331
|
-
alpha: float
|
1332
|
-
The parameter to control tail fatness of gradient.
|
1333
|
-
|
1334
|
-
|
1335
|
-
Returns
|
1336
|
-
-------
|
1337
|
-
out: jax.Array
|
1338
|
-
The spiking state.
|
1339
|
-
|
1340
|
-
References
|
1341
|
-
----------
|
1342
|
-
.. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag.
|
1343
|
-
"""
|
1344
|
-
return QPseudoSpike(alpha=alpha)(x)
|
1345
|
-
|
1346
|
-
|
1347
|
-
class LeakyRelu(Surrogate):
|
1348
|
-
"""Judge spiking state with the Leaky ReLU function.
|
1349
|
-
|
1350
|
-
See Also
|
1351
|
-
--------
|
1352
|
-
leaky_relu
|
1353
|
-
"""
|
1354
|
-
|
1355
|
-
def __init__(self, alpha=0.1, beta=1.):
|
1356
|
-
super().__init__()
|
1357
|
-
self.alpha = alpha
|
1358
|
-
self.beta = beta
|
1359
|
-
|
1360
|
-
def surrogate_fun(self, x):
|
1361
|
-
return jnp.where(x < 0., self.alpha * x, self.beta * x)
|
1362
|
-
|
1363
|
-
def surrogate_grad(self, x):
|
1364
|
-
dx = jnp.where(x < 0., self.alpha, self.beta)
|
1365
|
-
return dx
|
1366
|
-
|
1367
|
-
def __repr__(self):
|
1368
|
-
return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'
|
1369
|
-
|
1370
|
-
def __hash__(self):
|
1371
|
-
return hash((self.__class__, self.alpha, self.beta))
|
1372
|
-
|
1373
|
-
|
1374
|
-
def leaky_relu(
|
1375
|
-
x: jax.Array,
|
1376
|
-
alpha: float = 0.1,
|
1377
|
-
beta: float = 1.,
|
1378
|
-
|
1379
|
-
):
|
1380
|
-
r"""Judge spiking state with the Leaky ReLU function.
|
1381
|
-
|
1382
|
-
If `origin=False`, computes the forward function:
|
1383
|
-
|
1384
|
-
.. math::
|
1385
|
-
|
1386
|
-
g(x) = \begin{cases}
|
1387
|
-
1, & x \geq 0 \\
|
1388
|
-
0, & x < 0 \\
|
1389
|
-
\end{cases}
|
1390
|
-
|
1391
|
-
If `origin=True`, computes the original function:
|
1392
|
-
|
1393
|
-
.. math::
|
1394
|
-
|
1395
|
-
\begin{split}g(x) =
|
1396
|
-
\begin{cases}
|
1397
|
-
\beta \cdot x, & x \geq 0 \\
|
1398
|
-
\alpha \cdot x, & x < 0 \\
|
1399
|
-
\end{cases}\end{split}
|
1400
|
-
|
1401
|
-
Backward function:
|
1402
|
-
|
1403
|
-
.. math::
|
1404
|
-
|
1405
|
-
\begin{split}g'(x) =
|
1406
|
-
\begin{cases}
|
1407
|
-
\beta, & x \geq 0 \\
|
1408
|
-
\alpha, & x < 0 \\
|
1409
|
-
\end{cases}\end{split}
|
1410
|
-
|
1411
|
-
.. plot::
|
1412
|
-
:include-source: True
|
1413
|
-
|
1414
|
-
>>> import jax
|
1415
|
-
>>> import brainstate.nn as nn
|
1416
|
-
>>> import brainstate as brainstate
|
1417
|
-
>>> import matplotlib.pyplot as plt
|
1418
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1419
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
|
1420
|
-
>>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
|
1421
|
-
>>> plt.legend()
|
1422
|
-
>>> plt.show()
|
1423
|
-
|
1424
|
-
Parameters
|
1425
|
-
----------
|
1426
|
-
x: jax.Array, Array
|
1427
|
-
The input data.
|
1428
|
-
alpha: float
|
1429
|
-
The parameter to control the gradient when :math:`x < 0`.
|
1430
|
-
beta: float
|
1431
|
-
The parameter to control the gradient when :math:`x >= 0`.
|
1432
|
-
|
1433
|
-
|
1434
|
-
Returns
|
1435
|
-
-------
|
1436
|
-
out: jax.Array
|
1437
|
-
The spiking state.
|
1438
|
-
"""
|
1439
|
-
return LeakyRelu(alpha=alpha, beta=beta)(x)
|
1440
|
-
|
1441
|
-
|
1442
|
-
class LogTailedRelu(Surrogate):
|
1443
|
-
"""Judge spiking state with the Log-tailed ReLU function.
|
1444
|
-
|
1445
|
-
See Also
|
1446
|
-
--------
|
1447
|
-
log_tailed_relu
|
1448
|
-
"""
|
1449
|
-
|
1450
|
-
def __init__(self, alpha=0.):
|
1451
|
-
super().__init__()
|
1452
|
-
self.alpha = alpha
|
1453
|
-
|
1454
|
-
def surrogate_fun(self, x):
|
1455
|
-
z = jnp.where(x > 1,
|
1456
|
-
jnp.log(x),
|
1457
|
-
jnp.where(x > 0,
|
1458
|
-
x,
|
1459
|
-
self.alpha * x))
|
1460
|
-
return z
|
1461
|
-
|
1462
|
-
def surrogate_grad(self, x):
|
1463
|
-
dx = jnp.where(x > 1,
|
1464
|
-
1 / x,
|
1465
|
-
jnp.where(x > 0,
|
1466
|
-
1.,
|
1467
|
-
self.alpha))
|
1468
|
-
return dx
|
1469
|
-
|
1470
|
-
def __repr__(self):
|
1471
|
-
return f'{self.__class__.__name__}(alpha={self.alpha})'
|
1472
|
-
|
1473
|
-
def __hash__(self):
|
1474
|
-
return hash((self.__class__, self.alpha))
|
1475
|
-
|
1476
|
-
|
1477
|
-
def log_tailed_relu(
|
1478
|
-
x: jax.Array,
|
1479
|
-
alpha: float = 0.,
|
1480
|
-
|
1481
|
-
):
|
1482
|
-
r"""Judge spiking state with the Log-tailed ReLU function [1]_.
|
1483
|
-
|
1484
|
-
If `origin=False`, computes the forward function:
|
1485
|
-
|
1486
|
-
.. math::
|
1487
|
-
|
1488
|
-
g(x) = \begin{cases}
|
1489
|
-
1, & x \geq 0 \\
|
1490
|
-
0, & x < 0 \\
|
1491
|
-
\end{cases}
|
1492
|
-
|
1493
|
-
If `origin=True`, computes the original function:
|
1494
|
-
|
1495
|
-
.. math::
|
1496
|
-
|
1497
|
-
\begin{split}g(x) =
|
1498
|
-
\begin{cases}
|
1499
|
-
\alpha x, & x \leq 0 \\
|
1500
|
-
x, & 0 < x \leq 0 \\
|
1501
|
-
log(x), x > 1 \\
|
1502
|
-
\end{cases}\end{split}
|
1503
|
-
|
1504
|
-
Backward function:
|
1505
|
-
|
1506
|
-
.. math::
|
1507
|
-
|
1508
|
-
\begin{split}g'(x) =
|
1509
|
-
\begin{cases}
|
1510
|
-
\alpha, & x \leq 0 \\
|
1511
|
-
1, & 0 < x \leq 0 \\
|
1512
|
-
\frac{1}{x}, x > 1 \\
|
1513
|
-
\end{cases}\end{split}
|
1514
|
-
|
1515
|
-
.. plot::
|
1516
|
-
:include-source: True
|
1517
|
-
|
1518
|
-
>>> import jax
|
1519
|
-
>>> import brainstate.nn as nn
|
1520
|
-
>>> import brainstate as brainstate
|
1521
|
-
>>> import matplotlib.pyplot as plt
|
1522
|
-
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1523
|
-
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
|
1524
|
-
>>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
|
1525
|
-
>>> plt.legend()
|
1526
|
-
>>> plt.show()
|
1527
|
-
|
1528
|
-
Parameters
|
1529
|
-
----------
|
1530
|
-
x: jax.Array, Array
|
1531
|
-
The input data.
|
1532
|
-
alpha: float
|
1533
|
-
The parameter to control the gradient.
|
1534
|
-
|
1535
|
-
|
1536
|
-
Returns
|
1537
|
-
-------
|
1538
|
-
out: jax.Array
|
1539
|
-
The spiking state.
|
1540
|
-
|
1541
|
-
References
|
1542
|
-
----------
|
1543
|
-
.. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414.
|
1544
|
-
"""
|
1545
|
-
return LogTailedRelu(alpha=alpha)(x)
|
1546
|
-
|
1547
|
-
|
1548
|
-
- class ReluGrad(Surrogate):
-     """Judge spiking state with the ReLU gradient function.
-
-     See Also
-     --------
-     relu_grad
-     """
-
-     def __init__(self, alpha=0.3, width=1.):
-         super().__init__()
-         self.alpha = alpha
-         self.width = width
-
-     def surrogate_grad(self, x):
-         dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0)
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha, self.width))
-
-
- def relu_grad(
-     x: jax.Array,
-     alpha: float = 0.3,
-     width: float = 1.,
- ):
-     r"""Spike function with the ReLU gradient function [1]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-           1, & x \geq 0 \\
-           0, & x < 0 \\
-           \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|))
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for s in [0.5, 1.]:
-        >>>     for w in [1., 2.]:
-        >>>         grads = brainstate.augment.vector_grad(brainstate.surrogate.relu_grad)(xs, s, w)
-        >>>         plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x: jax.Array, Array
-         The input data.
-     alpha: float
-         The parameter to control the gradient.
-     width: float
-         The parameter to control the width of the gradient.
-
-     Returns
-     -------
-     out: jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019).
-     """
-     return ReluGrad(alpha=alpha, width=width)(x)
-
-
- class GaussianGrad(Surrogate):
-     """Judge spiking state with the Gaussian gradient function.
-
-     See Also
-     --------
-     gaussian_grad
-     """
-
-     def __init__(self, sigma=0.5, alpha=0.5):
-         super().__init__()
-         self.sigma = sigma
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         # Gaussian density N(x; 0, sigma^2). The parentheses around
-         # (2 * sigma^2) matter; without them the exponent is multiplied
-         # by sigma^2 instead of divided by it.
-         dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
-         return self.alpha * dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha, self.sigma))
-
-
- def gaussian_grad(
-     x: jax.Array,
-     sigma: float = 0.5,
-     alpha: float = 0.5,
- ):
-     r"""Spike function with the Gaussian gradient function [1]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-           1, & x \geq 0 \\
-           0, & x < 0 \\
-           \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \alpha * \text{gaussian}(x, 0., \sigma)
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for s in [0.5, 1., 2.]:
-        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.gaussian_grad)(xs, s, 0.5)
-        >>>     plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x: jax.Array, Array
-         The input data.
-     sigma: float
-         The parameter to control the variance of the gaussian distribution.
-     alpha: float
-         The parameter to control the scale of the gradient.
-
-     Returns
-     -------
-     out: jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
-     """
-     return GaussianGrad(sigma=sigma, alpha=alpha)(x)
-
-
- class MultiGaussianGrad(Surrogate):
-     """Judge spiking state with the multi-Gaussian gradient function.
-
-     See Also
-     --------
-     multi_gaussian_grad
-     """
-
-     def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
-         super().__init__()
-         self.h = h
-         self.s = s
-         self.sigma = sigma
-         self.scale = scale
-
-     def surrogate_grad(self, x):
-         g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
-         g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
-                      ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
-         g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
-                      ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
-         dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h
-         return self.scale * dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.h, self.s, self.sigma, self.scale))
-
-
- def multi_gaussian_grad(
-     x: jax.Array,
-     h: float = 0.15,
-     s: float = 6.0,
-     sigma: float = 0.5,
-     scale: float = 0.5,
- ):
-     r"""Spike function with the multi-Gaussian gradient function [1]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-           1, & x \geq 0 \\
-           0, & x < 0 \\
-           \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = (1+h)\,\mathcal{N}(x; 0, \sigma^2)
-                - h\,\mathcal{N}(x; \sigma, (s\sigma)^2)
-                - h\,\mathcal{N}(x; -\sigma, (s\sigma)^2)
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.multi_gaussian_grad)(xs)
-        >>> plt.plot(xs, grads)
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x: jax.Array, Array
-         The input data.
-     h: float
-         Hyperparameter controlling the weight of the two negative side Gaussians.
-     s: float
-         Hyperparameter controlling the width of the two side Gaussians.
-     sigma: float
-         The standard deviation of the central Gaussian.
-     scale: float
-         The gradient scale.
-
-     Returns
-     -------
-     out: jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
-     """
-     return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x)
-
-
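To see what the `h` parameter does in practice — a minimal sketch, assuming `brainstate` is installed and using only the default parameters above — the multi-Gaussian gradient is positive near the threshold and dips slightly negative further away:

    import jax.numpy as jnp
    import brainstate

    xs = jnp.array([0., 1.5])
    g = brainstate.augment.vector_grad(brainstate.surrogate.multi_gaussian_grad)(xs)
    print(g)  # positive peak at x = 0, small negative side lobe at x = 1.5
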
- class InvSquareGrad(Surrogate):
-     """Judge spiking state with the inverse-square surrogate gradient function.
-
-     See Also
-     --------
-     inv_square_grad
-     """
-
-     def __init__(self, alpha=100.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def inv_square_grad(
-     x: jax.Array,
-     alpha: float = 100.
- ):
-     r"""Spike function with the inverse-square surrogate gradient.
-
-     Forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-           1, & x \geq 0 \\
-           0, & x < 0 \\
-           \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \frac{1}{(\alpha |x| + 1)^2}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-1, 1, 1000)
-        >>> for alpha in [1., 10., 100.]:
-        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.inv_square_grad)(xs, alpha)
-        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x: jax.Array, Array
-         The input data.
-     alpha: float
-         Parameter to control the smoothness of the gradient.
-
-     Returns
-     -------
-     out: jax.Array
-         The spiking state.
-     """
-     return InvSquareGrad(alpha=alpha)(x)
-
-
- class SlayerGrad(Surrogate):
-     """Judge spiking state with the SLAYER surrogate gradient function.
-
-     See Also
-     --------
-     slayer_grad
-     """
-
-     def __init__(self, alpha=1.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = jnp.exp(-self.alpha * jnp.abs(x))
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def slayer_grad(
-     x: jax.Array,
-     alpha: float = 1.
- ):
-     r"""Spike function with the SLAYER surrogate gradient function [1]_.
-
-     Forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-           1, & x \geq 0 \\
-           0, & x < 0 \\
-           \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \exp(-\alpha |x|)
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for alpha in [0.5, 1., 2., 4.]:
-        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.slayer_grad)(xs, alpha)
-        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x: jax.Array, Array
-         The input data.
-     alpha: float
-         Parameter to control the smoothness of the gradient.
-
-     Returns
-     -------
-     out: jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Shrestha, S. B. & Orchard, G. SLAYER: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018).
-     """
-     return SlayerGrad(alpha=alpha)(x)
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ import jax
+ import jax.numpy as jnp
+ import jax.scipy as sci
+ from jax.interpreters import batching, ad, mlir
+
+ from brainstate._compatible_import import Primitive
+ from brainstate.util import PrettyObject
+
+ __all__ = [
+     'Surrogate',
+     'Sigmoid',
+     'sigmoid',
+     'PiecewiseQuadratic',
+     'piecewise_quadratic',
+     'PiecewiseExp',
+     'piecewise_exp',
+     'SoftSign',
+     'soft_sign',
+     'Arctan',
+     'arctan',
+     'NonzeroSignLog',
+     'nonzero_sign_log',
+     'ERF',
+     'erf',
+     'PiecewiseLeakyRelu',
+     'piecewise_leaky_relu',
+     'SquarewaveFourierSeries',
+     'squarewave_fourier_series',
+     'S2NN',
+     's2nn',
+     'QPseudoSpike',
+     'q_pseudo_spike',
+     'LeakyRelu',
+     'leaky_relu',
+     'LogTailedRelu',
+     'log_tailed_relu',
+     'ReluGrad',
+     'relu_grad',
+     'GaussianGrad',
+     'gaussian_grad',
+     'InvSquareGrad',
+     'inv_square_grad',
+     'MultiGaussianGrad',
+     'multi_gaussian_grad',
+     'SlayerGrad',
+     'slayer_grad',
+ ]
+
+
+ def _heaviside_abstract(x, dx):
+     # The primitive produces a single output with the same aval as ``x``.
+     return [x]
+
+
+ def _heaviside_imp(x, dx):
+     # Forward pass: the Heaviside step function. ``dx`` is carried along
+     # only so that the JVP rule below can use it.
+     z = jnp.asarray(x >= 0, dtype=x.dtype)
+     return [z]
+
+
+ def _heaviside_batching(args, axes):
+     x, dx = args
+     if axes[0] != axes[1]:
+         dx = batching.moveaxis(dx, axes[1], axes[0])
+     return heaviside_p.bind(x, dx), tuple([axes[0]])
+
+
+ def _heaviside_jvp(primals, tangents):
+     x, dx = primals
+     tx, tdx = tangents
+     primal_outs = heaviside_p.bind(x, dx)
+     # The surrogate gradient ``dx`` stands in for the derivative of the
+     # step function, which is zero almost everywhere.
+     tangent_outs = [dx * tx, ]
+     return primal_outs, tangent_outs
+
+
+ heaviside_p = Primitive('heaviside_p')
+ heaviside_p.multiple_results = True
+ heaviside_p.def_abstract_eval(_heaviside_abstract)
+ heaviside_p.def_impl(_heaviside_imp)
+ batching.primitive_batchers[heaviside_p] = _heaviside_batching
+ ad.primitive_jvps[heaviside_p] = _heaviside_jvp
+ mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True))
+
+
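This primitive is what decouples the hard forward step from the soft backward pass. A minimal sketch of its behavior (assuming the module is importable as `brainstate.surrogate`; `heaviside_p` is defined at module level but not listed in `__all__`): binding the primitive directly yields the step output, while `jax.grad` picks up whatever `dx` was passed in.

    import jax
    import jax.numpy as jnp
    from brainstate.surrogate import heaviside_p

    def step_with_unit_grad(v):
        # dx = 1 everywhere, so the backward pass behaves like the identity.
        return heaviside_p.bind(v, jnp.ones_like(v))[0]

    v = jnp.array([-1.0, 0.0, 2.0])
    print(step_with_unit_grad(v))                               # [0., 1., 1.]
    print(jax.grad(lambda u: step_with_unit_grad(u).sum())(v))  # [1., 1., 1.]
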
+ class Surrogate(PrettyObject):
+     """The base surrogate gradient function.
+
+     To customize a surrogate gradient function, you can inherit this class and
+     implement the `surrogate_fun` and `surrogate_grad` methods.
+
+     Examples
+     --------
+
+     >>> import brainstate
+     >>> import jax.numpy as jnp
+
+     >>> class MySurrogate(brainstate.surrogate.Surrogate):
+     ...     def __init__(self, alpha=1.):
+     ...         super().__init__()
+     ...         self.alpha = alpha
+     ...
+     ...     def surrogate_fun(self, x):
+     ...         return jnp.sin(x) * self.alpha
+     ...
+     ...     def surrogate_grad(self, x):
+     ...         return jnp.cos(x) * self.alpha
+
+     """
+
+     def __call__(self, x):
+         dx = self.surrogate_grad(x)
+         return heaviside_p.bind(x, dx)[0]
+
+     def surrogate_fun(self, x) -> jax.Array:
+         """The surrogate function."""
+         raise NotImplementedError
+
+     def surrogate_grad(self, x) -> jax.Array:
+         """The gradient function of the surrogate function."""
+         raise NotImplementedError
+
+
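Beyond the docstring example, a subclass only strictly needs `surrogate_grad`, since `__call__` binds it to the heaviside primitive. A minimal sketch (the `TriangleSurrogate` name and its triangular gradient are illustrative, not part of the library):

    import jax
    import jax.numpy as jnp
    import brainstate

    class TriangleSurrogate(brainstate.surrogate.Surrogate):
        # Triangular surrogate gradient of width 1 around the threshold.
        def surrogate_grad(self, x):
            return jnp.maximum(1. - jnp.abs(x), 0.)

    fn = TriangleSurrogate()
    v = jnp.array([-2.0, -0.5, 0.5, 2.0])
    print(fn(v))                               # step function: [0., 0., 1., 1.]
    print(jax.grad(lambda u: fn(u).sum())(v))  # triangle:      [0., 0.5, 0.5, 0.]
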
+ class Sigmoid(Surrogate):
+     """Spike function with the sigmoid-shaped surrogate gradient.
+
+     This class implements a spiking neuron activation with a sigmoid-shaped
+     surrogate gradient for backpropagation. It can be used in spiking neural
+     networks to approximate the non-differentiable step function during training.
+
+     Parameters
+     ----------
+     alpha : float, optional
+         A parameter controlling the steepness of the sigmoid curve in the
+         surrogate gradient. Higher values make the transition sharper.
+         Default is 4.0.
+
+     See Also
+     --------
+     sigmoid : Function version of this class.
+
+     """
+
+     def __init__(self, alpha: float = 4.):
+         super().__init__()
+         self.alpha = alpha
+
+     def surrogate_fun(self, x):
+         """Compute the surrogate function.
+
+         Parameters
+         ----------
+         x : jax.Array
+             The input array.
+
+         Returns
+         -------
+         jax.Array
+             The output of the surrogate function.
+         """
+         return sci.special.expit(self.alpha * x)
+
+     def surrogate_grad(self, x):
+         """Compute the gradient of the surrogate function.
+
+         Parameters
+         ----------
+         x : jax.Array
+             The input array.
+
+         Returns
+         -------
+         jax.Array
+             The gradient of the surrogate function.
+         """
+         sgax = sci.special.expit(x * self.alpha)
+         dx = (1. - sgax) * sgax * self.alpha
+         return dx
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.alpha))
+
+
+ def sigmoid(
+     x: jax.Array,
+     alpha: float = 4.,
+ ):
+     r"""
+     Compute a spike function with a sigmoid-shaped surrogate gradient.
+
+     This function implements a spiking neuron activation with a sigmoid-shaped
+     surrogate gradient for backpropagation. It can be used in spiking neural
+     networks to approximate the non-differentiable step function during training.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate:
+
+     .. math::
+
+        g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}
+
+     Backward function:
+
+     .. math::
+
+        g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x)
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-2, 2, 1000)
+        >>> for alpha in [1., 2., 4.]:
+        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.sigmoid)(xs, alpha)
+        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x : jax.Array
+         The input array representing the neuron's membrane potential.
+     alpha : float, optional
+         A parameter controlling the steepness of the sigmoid curve in the
+         surrogate gradient. Higher values make the transition sharper.
+         Default is 4.0.
+
+     Returns
+     -------
+     jax.Array
+         An array of the same shape as the input, containing binary values (0 or 1)
+         representing the spiking state of each neuron.
+
+     Notes
+     -----
+     The forward pass uses a step function (1 for x >= 0, 0 for x < 0),
+     while the backward pass uses a sigmoid-shaped surrogate gradient for
+     smooth optimization.
+
+     The surrogate gradient is defined as:
+     g'(x) = alpha * (1 - sigmoid(alpha * x)) * sigmoid(alpha * x)
+
+     """
+     return Sigmoid(alpha=alpha)(x)
+
+
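As a quick numerical check of the backward formula above — a sketch, assuming `brainstate` is installed — the gradient obtained through `vector_grad` matches the closed form α·s·(1−s):

    import jax.numpy as jnp
    import jax.scipy as sci
    import brainstate

    xs = jnp.linspace(-1., 1., 5)
    alpha = 4.
    grads = brainstate.augment.vector_grad(brainstate.surrogate.sigmoid)(xs, alpha)
    s = sci.special.expit(alpha * xs)
    print(jnp.allclose(grads, alpha * s * (1. - s)))  # True
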
+ class PiecewiseQuadratic(Surrogate):
+     """Judge spiking state with a piecewise quadratic function.
+
+     See Also
+     --------
+     piecewise_quadratic
+
+     """
+
+     def __init__(self, alpha: float = 1.):
+         super().__init__()
+         self.alpha = alpha
+
+     def surrogate_fun(self, x):
+         z = jnp.where(
+             x < -1 / self.alpha,
+             0.,
+             jnp.where(
+                 x > 1 / self.alpha,
+                 1.,
+                 (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5
+             )
+         )
+         return z
+
+     def surrogate_grad(self, x):
+         # The derivative of ``surrogate_fun`` above:
+         # g'(x) = alpha - alpha^2 * |x| for |x| <= 1/alpha, else 0.
+         dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., self.alpha - self.alpha ** 2 * jnp.abs(x))
+         return dx
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.alpha))
+
+
+ def piecewise_quadratic(
+     x: jax.Array,
+     alpha: float = 1.,
+ ):
+     r"""
+     Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
+
+     This function implements a surrogate gradient method for spiking neural networks
+     using a piecewise quadratic function. It provides a differentiable approximation
+     of the step function used in the forward pass of spiking neurons.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate:
+
+     .. math::
+
+        g(x) =
+           \begin{cases}
+              0, & x < -\frac{1}{\alpha} \\
+              -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
+              1, & x > \frac{1}{\alpha} \\
+           \end{cases}
+
+     Backward function:
+
+     .. math::
+
+        g'(x) =
+           \begin{cases}
+              0, & |x| > \frac{1}{\alpha} \\
+              -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha}
+           \end{cases}
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-3, 3, 1000)
+        >>> for alpha in [0.5, 1., 2., 4.]:
+        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_quadratic)(xs, alpha)
+        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x : jax.Array
+         The input array representing the neuron's membrane potential.
+     alpha : float, optional
+         A parameter controlling the steepness of the surrogate gradient.
+         Higher values result in a steeper gradient. Default is 1.0.
+
+     Returns
+     -------
+     jax.Array
+         An array of the same shape as the input, containing binary values (0 or 1)
+         representing the spiking state of each neuron.
+
+     Notes
+     -----
+     The function uses different computations for forward and backward passes:
+
+     - Forward: Step function (1 for x >= 0, 0 for x < 0)
+     - Backward: Piecewise quadratic function for smooth gradient
+
+     The surrogate gradient is defined as::
+
+         g'(x) = 0                     if |x| > 1/alpha
+                 -alpha^2 |x| + alpha  if |x| <= 1/alpha
+
+     References
+     ----------
+     .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446.
+     .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
+     .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
+     .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
+     .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
+     """
+     return PiecewiseQuadratic(alpha=alpha)(x)
+
+
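Since `surrogate_grad` is meant to be the derivative of `surrogate_fun`, the two can be cross-checked with autodiff. A sketch, restricted to |x| ≤ 1/α where the surrogate is smooth:

    import jax
    import jax.numpy as jnp
    import brainstate

    f = brainstate.surrogate.PiecewiseQuadratic(alpha=1.)
    xs = jnp.linspace(-0.9, 0.9, 7)
    auto = jax.vmap(jax.grad(f.surrogate_fun))(xs)
    print(jnp.allclose(auto, f.surrogate_grad(xs), atol=1e-5))  # True
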
+ class PiecewiseExp(Surrogate):
+     """
+     Judge spiking state with a piecewise exponential function.
+
+     This class implements a surrogate gradient method for spiking neural networks
+     using a piecewise exponential function. It provides a differentiable approximation
+     of the step function used in the forward pass of spiking neurons.
+
+     Parameters
+     ----------
+     alpha : float, optional
+         A parameter controlling the steepness of the surrogate gradient.
+         Higher values result in a steeper gradient. Default is 1.0.
+
+     See Also
+     --------
+     piecewise_exp : Function version of this class.
+     """
+
+     def __init__(self, alpha: float = 1.):
+         super().__init__()
+         self.alpha = alpha
+
+     def surrogate_grad(self, x):
+         """
+         Compute the surrogate gradient.
+
+         Parameters
+         ----------
+         x : jax.Array
+             The input array.
+
+         Returns
+         -------
+         jax.Array
+             The surrogate gradient.
+         """
+         dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
+         return dx
+
+     def surrogate_fun(self, x):
+         """
+         Compute the surrogate function.
+
+         Parameters
+         ----------
+         x : jax.Array
+             The input array.
+
+         Returns
+         -------
+         jax.Array
+             The output of the surrogate function.
+         """
+         return jnp.where(
+             x < 0,
+             jnp.exp(self.alpha * x) / 2,
+             1 - jnp.exp(-self.alpha * x) / 2
+         )
+
+     def __repr__(self):
+         """
+         Return a string representation of the PiecewiseExp instance.
+
+         Returns
+         -------
+         str
+             A string representation of the instance.
+         """
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+     def __hash__(self):
+         """
+         Compute a hash value for the PiecewiseExp instance.
+
+         Returns
+         -------
+         int
+             A hash value for the instance.
+         """
+         return hash((self.__class__, self.alpha))
+
+
+ def piecewise_exp(
+     x: jax.Array,
+     alpha: float = 1.,
+ ):
+     r"""Judge spiking state with a piecewise exponential function [1]_.
+
+     This function implements a surrogate gradient method for spiking neural networks
+     using a piecewise exponential function. It provides a differentiable approximation
+     of the step function used in the forward pass of spiking neurons.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           \frac{1}{2}e^{\alpha x}, & x < 0 \\
+           1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
+           \end{cases}
+
+     Backward function:
+
+     .. math::
+
+        g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-3, 3, 1000)
+        >>> for alpha in [0.5, 1., 2., 4.]:
+        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_exp)(xs, alpha)
+        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x : jax.Array
+         The input array representing the neuron's membrane potential.
+     alpha : float, optional
+         A parameter controlling the steepness of the surrogate gradient.
+         Higher values result in a steeper gradient. Default is 1.0.
+
+     Returns
+     -------
+     jax.Array
+         An array of the same shape as the input, containing binary values (0 or 1)
+         representing the spiking state of each neuron.
+
+     Notes
+     -----
+     The function uses different computations for forward and backward passes:
+
+     - Forward: Step function (1 for x >= 0, 0 for x < 0)
+     - Backward: Piecewise exponential function for smooth gradient
+
+     The surrogate gradient is defined as::
+
+         g'(x) = (alpha / 2) * exp(-alpha * |x|)
+
+     References
+     ----------
+     .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
+     """
+     return PiecewiseExp(alpha=alpha)(x)
+
+
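In practice, these spike functions are applied to the distance from the firing threshold. A hypothetical LIF-style usage (the `v_th` value and membrane potentials are illustrative):

    import jax.numpy as jnp
    import brainstate

    v = jnp.array([0.8, 1.2, 1.5])   # membrane potentials
    v_th = 1.0                       # firing threshold
    spikes = brainstate.surrogate.piecewise_exp(v - v_th, alpha=1.)
    print(spikes)  # [0., 1., 1.]; gradients w.r.t. v flow through the surrogate
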
+ class SoftSign(Surrogate):
+     """Judge spiking state with a soft sign function.
+
+     See Also
+     --------
+     soft_sign
+     """
+
+     def __init__(self, alpha=1.):
+         super().__init__()
+         self.alpha = alpha
+
+     def surrogate_grad(self, x):
+         dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2
+         return dx
+
+     def surrogate_fun(self, x):
+         return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.alpha))
+
+
+ def soft_sign(
+     x: jax.Array,
+     alpha: float = 1.,
+ ):
+     r"""Judge spiking state with a soft sign function.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate:
+
+     .. math::
+
+        g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)
+             = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)
+
+     Backward function:
+
+     .. math::
+
+        g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-3, 3, 1000)
+        >>> for alpha in [0.5, 1., 2., 4.]:
+        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.soft_sign)(xs, alpha)
+        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x: jax.Array, Array
+         The input data.
+     alpha: float
+         Parameter to control the smoothness of the gradient.
+
+     Returns
+     -------
+     out: jax.Array
+         The spiking state.
+
+     """
+     return SoftSign(alpha=alpha)(x)
+
+
+ class Arctan(Surrogate):
+     """Judge spiking state with an arctan function.
+
+     See Also
+     --------
+     arctan
+     """
+
+     def __init__(self, alpha=1.):
+         super().__init__()
+         self.alpha = alpha
+
+     def surrogate_grad(self, x):
+         dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
+         return dx
+
+     def surrogate_fun(self, x):
+         # ``jnp.arctan`` (single argument), not ``jnp.arctan2``, implements
+         # g(x) = arctan(pi/2 * alpha * x) / pi + 0.5 from the docstring.
+         return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.alpha))
+
+
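Every surrogate class in this module defines `__repr__` and `__hash__` the same way, so instances constructed with equal parameters hash identically. A small sketch:

    import brainstate

    f1 = brainstate.surrogate.Arctan(alpha=2.)
    f2 = brainstate.surrogate.Arctan(alpha=2.)
    print(f1)                    # Arctan(alpha=2.0)
    print(hash(f1) == hash(f2))  # True: the hash depends only on class and alpha
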
+ def arctan(
+     x: jax.Array,
+     alpha: float = 1.,
+ ):
+     r"""Judge spiking state with an arctan function.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate:
+
+     .. math::
+
+        g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}
+
+     Backward function:
+
+     .. math::
+
+        g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-3, 3, 1000)
+        >>> for alpha in [0.5, 1., 2., 4.]:
+        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.arctan)(xs, alpha)
+        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x: jax.Array, Array
+         The input data.
+     alpha: float
+         Parameter to control the smoothness of the gradient.
+
+     Returns
+     -------
+     out: jax.Array
+         The spiking state.
+
+     """
+     return Arctan(alpha=alpha)(x)
+
+
+ class NonzeroSignLog(Surrogate):
+     """Judge spiking state with a nonzero sign log function.
+
+     See Also
+     --------
+     nonzero_sign_log
+     """
+
+     def __init__(self, alpha=1.):
+         super().__init__()
+         self.alpha = alpha
+
+     def surrogate_grad(self, x):
+         dx = 1. / (1 / self.alpha + jnp.abs(x))
+         return dx
+
+     def surrogate_fun(self, x):
+         return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.alpha))
+
+
+ def nonzero_sign_log(
+     x: jax.Array,
+     alpha: float = 1.,
+ ):
+     r"""Judge spiking state with a nonzero sign log function.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate:
+
+     .. math::
+
+        g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)
+
+     where
+
+     .. math::
+
+        \begin{split}\mathrm{NonzeroSign}(x) =
+           \begin{cases}
+              1, & x \geq 0 \\
+              -1, & x < 0 \\
+           \end{cases}\end{split}
+
+     Backward function:
+
+     .. math::
+
+        g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}
+
+     This surrogate function has the advantage of a low computational cost in
+     the backward pass.
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-3, 3, 1000)
+        >>> for alpha in [0.5, 1., 2., 4.]:
+        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
+        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x: jax.Array, Array
+         The input data.
+     alpha: float
+         Parameter to control the smoothness of the gradient.
+
+     Returns
+     -------
+     out: jax.Array
+         The spiking state.
+
+     """
+     return NonzeroSignLog(alpha=alpha)(x)
+
+
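Because the heaviside primitive registers a batching rule, all of these spike functions compose with `jax.vmap`. A sketch using `nonzero_sign_log`:

    import jax
    import jax.numpy as jnp
    import brainstate

    batch = jnp.linspace(-1., 1., 6).reshape(2, 3)
    out = jax.vmap(brainstate.surrogate.nonzero_sign_log)(batch)
    print(out.shape)  # (2, 3)
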
+ class ERF(Surrogate):
+     """Judge spiking state with an erf function.
+
+     See Also
+     --------
+     erf
+     """
+
+     def __init__(self, alpha=1.):
+         super().__init__()
+         self.alpha = alpha
+
+     def surrogate_grad(self, x):
+         dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
+         return dx
+
+     def surrogate_fun(self, x):
+         # 0.5 * erfc(-alpha * x) = 0.5 * (1 - erf(-alpha * x)), the integral
+         # of the Gaussian backward function documented in ``erf``.
+         return sci.special.erfc(-self.alpha * x) * 0.5
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.alpha))
+
+
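With the `erfc` form above, the smoothed function behaves like a soft step: 0.5 at the threshold and saturating to 0 and 1 on either side. A quick sketch:

    import jax.numpy as jnp
    import brainstate

    f = brainstate.surrogate.ERF(alpha=1.)
    print(f.surrogate_fun(jnp.array([-5., 0., 5.])))  # approximately [0., 0.5, 1.]
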
+ def erf(
+     x: jax.Array,
+     alpha: float = 1.,
+ ):
+     r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate:
+
+     .. math::
+
+        \begin{split}
+        g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
+        &= \frac{1}{2} \text{erfc}(-\alpha x) \\
+        &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
+        \end{split}
+
+     Backward function:
+
+     .. math::
+
+        g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-3, 3, 1000)
+        >>> for alpha in [0.5, 1., 2., 4.]:
+        >>>     grads = brainstate.augment.vector_grad(brainstate.surrogate.erf)(xs, alpha)
+        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x: jax.Array, Array
+         The input data.
+     alpha: float
+         Parameter to control the smoothness of the gradient.
+
+     Returns
+     -------
+     out: jax.Array
+         The spiking state.
+
+     References
+     ----------
+     .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125.
+     .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
+     .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
+
+     """
+     return ERF(alpha=alpha)(x)
+
+
+ class PiecewiseLeakyRelu(Surrogate):
+     """Judge spiking state with a piecewise leaky relu function.
+
+     See Also
+     --------
+     piecewise_leaky_relu
+     """
+
+     def __init__(self, c=0.01, w=1.):
+         super().__init__()
+         self.c = c
+         self.w = w
+
+     def surrogate_fun(self, x):
+         z = jnp.where(x < -self.w,
+                       self.c * x + self.c * self.w,
+                       jnp.where(x > self.w,
+                                 self.c * x - self.c * self.w + 1,
+                                 0.5 * x / self.w + 0.5))
+         return z
+
+     def surrogate_grad(self, x):
+         dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w)
+         return dx
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.c, self.w))
+
+
+ def piecewise_leaky_relu(
+     x: jax.Array,
+     c: float = 0.01,
+     w: float = 1.,
+ ):
+     r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate:
+
+     .. math::
+
+        \begin{split}g(x) =
+           \begin{cases}
+              cx + cw, & x < -w \\
+              \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
+              cx - cw + 1, & x > w \\
+           \end{cases}\end{split}
+
+     Backward function:
+
+     .. math::
+
+        \begin{split}g'(x) =
+           \begin{cases}
+              \frac{1}{w}, & |x| \leq w \\
+              c, & |x| > w
+           \end{cases}\end{split}
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-3, 3, 1000)
+        >>> for c in [0.01, 0.05, 0.1]:
+        >>>     for w in [1., 2.]:
+        >>>         grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
+        >>>         plt.plot(xs, grads, label=f'c={c}, w={w}')
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x: jax.Array, Array
+         The input data.
+     c: float
+         When :math:`|x| > w` the gradient is `c`.
+     w: float
+         When :math:`|x| <= w` the gradient is `1 / w`.
+
+     Returns
+     -------
+     out: jax.Array
+         The spiking state.
+
+     References
+     ----------
+     .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
+     .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
+     .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
+     .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
+     .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
+     .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
+     .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525.
+     .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.
+
+     """
+     return PiecewiseLeakyRelu(c=c, w=w)(x)
+
+
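The backward function is flat at `1/w` inside the window and `c` outside, which the following sketch makes concrete:

    import jax.numpy as jnp
    import brainstate

    g = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_leaky_relu)(
        jnp.array([-2., 0., 2.]), c=0.01, w=1.)
    print(g)  # [0.01, 1.0, 0.01]
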
+ class SquarewaveFourierSeries(Surrogate):
+     """Judge spiking state with a squarewave Fourier series.
+
+     See Also
+     --------
+     squarewave_fourier_series
+     """
+
+     def __init__(self, n=2, t_period=8.):
+         super().__init__()
+         self.n = n
+         self.t_period = t_period
+
+     def surrogate_grad(self, x):
+         w = jnp.pi * 2. / self.t_period
+         dx = jnp.cos(w * x)
+         for i in range(2, self.n):
+             dx += jnp.cos((2 * i - 1.) * w * x)
+         dx *= 4. / self.t_period
+         return dx
+
+     def surrogate_fun(self, x):
+         w = jnp.pi * 2. / self.t_period
+         ret = jnp.sin(w * x)
+         for i in range(2, self.n):
+             c = (2 * i - 1.)
+             ret += jnp.sin(c * w * x) / c
+         z = 0.5 + 2. / jnp.pi * ret
+         return z
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.n, self.t_period))
+
+
+ def squarewave_fourier_series(
+     x: jax.Array,
+     n: int = 2,
+     t_period: float = 8.,
+ ):
+     r"""Judge spiking state with a squarewave Fourier series.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate (matching the
+     implementation in ``SquarewaveFourierSeries.surrogate_fun``):
+
+     .. math::
+
+        g(x) = 0.5 + \frac{2}{\pi}\sum_{i=1}^n \frac{\sin\left((2i-1)\,2\pi x/T\right)}{2i-1}
+
+     Backward function:
+
+     .. math::
+
+        g'(x) = \frac{4}{T}\sum_{i=1}^n \cos\left((2i-1)\,2\pi x/T\right)
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-3, 3, 1000)
+        >>> for n in [2, 4, 8]:
+        >>>     f = brainstate.surrogate.SquarewaveFourierSeries(n=n)
+        >>>     grads = brainstate.augment.vector_grad(f)(xs)
+        >>>     plt.plot(xs, grads, label=f'n={n}')
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x: jax.Array, Array
+         The input data.
+     n: int
+         The number of harmonics kept in the Fourier series.
+     t_period: float
+         The period :math:`T` of the square wave.
+
+     Returns
+     -------
+     out: jax.Array
+         The spiking state.
+
+     """
+     return SquarewaveFourierSeries(n=n, t_period=t_period)(x)
+
+
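More harmonics sharpen the smooth approximation toward the step. A sketch:

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-0.5, 0.5])
    for n in [2, 8, 32]:
        f = brainstate.surrogate.SquarewaveFourierSeries(n=n, t_period=8.)
        print(n, f.surrogate_fun(x))  # tends toward [0., 1.] as n grows
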
+ class S2NN(Surrogate):
+     """Judge spiking state with the S2NN surrogate spiking function.
+
+     See Also
+     --------
+     s2nn
+     """
+
+     def __init__(self, alpha=4., beta=1., epsilon=1e-8):
+         super().__init__()
+         self.alpha = alpha
+         self.beta = beta
+         self.epsilon = epsilon
+
+     def surrogate_fun(self, x):
+         z = jnp.where(x < 0.,
+                       sci.special.expit(x * self.alpha),
+                       self.beta * jnp.log(jnp.abs(x + 1.) + self.epsilon) + 0.5)
+         return z
+
+     def surrogate_grad(self, x):
+         sg = sci.special.expit(self.alpha * x)
+         dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.))
+         return dx
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.alpha, self.beta, self.epsilon))
+
+
+ def s2nn(
+     x: jax.Array,
+     alpha: float = 4.,
+     beta: float = 1.,
+     epsilon: float = 1e-8,
+ ):
+     r"""Judge spiking state with the S2NN surrogate spiking function [1]_.
+
+     The forward function:
+
+     .. math::
+
+        g(x) = \begin{cases}
+           1, & x \geq 0 \\
+           0, & x < 0 \\
+           \end{cases}
+
+     The original function approximated by the surrogate:
+
+     .. math::
+
+        \begin{split}g(x) = \begin{cases}
+           \mathrm{sigmoid}(\alpha x), & x < 0 \\
+           \beta \ln(|x + 1|) + 0.5, & x \ge 0
+        \end{cases}\end{split}
+
+     Backward function:
+
+     .. math::
+
+        \begin{split}g'(x) = \begin{cases}
+           \alpha (1 - \mathrm{sigmoid}(\alpha x)) \mathrm{sigmoid}(\alpha x), & x < 0 \\
+           \frac{\beta}{(x + 1)}, & x \ge 0
+        \end{cases}\end{split}
+
+     .. plot::
+        :include-source: True
+
+        >>> import jax
+        >>> import brainstate
+        >>> import matplotlib.pyplot as plt
+        >>> xs = jax.numpy.linspace(-3, 3, 1000)
+        >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 4., 1.)
+        >>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
+        >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 8., 2.)
+        >>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
+        >>> plt.legend()
+        >>> plt.show()
+
+     Parameters
+     ----------
+     x: jax.Array, Array
+         The input data.
+     alpha: float
+         The parameter that controls the gradient when ``x < 0``.
+     beta: float
+         The parameter that controls the gradient when ``x >= 0``.
+     epsilon: float
+         A small constant added inside the logarithm to avoid ``nan``.
+
+     Returns
+     -------
+     out: jax.Array
+         The spiking state.
+
+     References
+     ----------
+     .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag.
+
+     """
+     return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x)
+
+
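The S2NN gradient is deliberately asymmetric around the threshold; evaluating it on either side shows the two branches. A sketch:

    import jax.numpy as jnp
    import brainstate

    g = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(
        jnp.array([-1., 1.]), 4., 1.)
    print(g)  # sigmoid-branch value at x = -1; beta / (x + 1) = 0.5 at x = 1
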
+ class QPseudoSpike(Surrogate):
+     """Judge spiking state with the q-PseudoSpike surrogate function.
+
+     See Also
+     --------
+     q_pseudo_spike
+     """
+
+     def __init__(self, alpha=2.):
+         super().__init__()
+         self.alpha = alpha
+
+     def surrogate_grad(self, x):
+         dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha)
+         return dx
+
+     def surrogate_fun(self, x):
+         # For x < 0 the base is 1 + 2|x|/(alpha - 1), which equals the
+         # documented (1 - 2x/(alpha - 1)) with x < 0; both branches then
+         # keep the base positive.
+         z = jnp.where(x < 0.,
+                       0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
+                       1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
+         return z
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+     def __hash__(self):
+         return hash((self.__class__, self.alpha))
+
+
def q_pseudo_spike(
|
1282
|
+
x: jax.Array,
|
1283
|
+
alpha: float = 2.,
|
1284
|
+
|
1285
|
+
):
|
1286
|
+
r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.
|
1287
|
+
|
1288
|
+
If `origin=False`, computes the forward function:
|
1289
|
+
|
1290
|
+
.. math::
|
1291
|
+
|
1292
|
+
g(x) = \begin{cases}
|
1293
|
+
1, & x \geq 0 \\
|
1294
|
+
0, & x < 0 \\
|
1295
|
+
\end{cases}
|
1296
|
+
|
1297
|
+
If `origin=True`, computes the original function:
|
1298
|
+
|
1299
|
+
.. math::
|
1300
|
+
|
1301
|
+
\begin{split}g(x) =
|
1302
|
+
\begin{cases}
|
1303
|
+
\frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
|
1304
|
+
1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
|
1305
|
+
\end{cases}\end{split}
|
1306
|
+
|
1307
|
+
Backward function:
|
1308
|
+
|
1309
|
+
.. math::
|
1310
|
+
|
1311
|
+
g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}
|
1312
|
+
|
1313
|
+
.. plot::
|
1314
|
+
:include-source: True
|
1315
|
+
|
1316
|
+
>>> import jax
|
1317
|
+
>>> import brainstate.nn as nn
|
1318
|
+
>>> import brainstate as brainstate
|
1319
|
+
>>> import matplotlib.pyplot as plt
|
1320
|
+
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1321
|
+
>>> for alpha in [0.5, 1., 2., 4.]:
|
1322
|
+
>>> grads = brainstate.augment.vector_grad(brainstate.surrogate.q_pseudo_spike)(xs, alpha)
|
1323
|
+
>>> plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
|
1324
|
+
>>> plt.legend()
|
1325
|
+
>>> plt.show()
|
1326
|
+
|
1327
|
+
Parameters
|
1328
|
+
----------
|
1329
|
+
x: jax.Array, Array
|
1330
|
+
The input data.
|
1331
|
+
alpha: float
|
1332
|
+
The parameter to control tail fatness of gradient.
|
1333
|
+
|
1334
|
+
|
1335
|
+
Returns
|
1336
|
+
-------
|
1337
|
+
out: jax.Array
|
1338
|
+
The spiking state.
|
1339
|
+
|
1340
|
+
References
|
1341
|
+
----------
|
1342
|
+
.. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag.
|
1343
|
+
"""
|
1344
|
+
return QPseudoSpike(alpha=alpha)(x)
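

# Illustrative sketch, not part of the library: how `alpha` controls the tail
# fatness of the q-PseudoSpike surrogate gradient. `_demo_q_pseudo_spike_tails`
# is a hypothetical helper name used only for demonstration.
def _demo_q_pseudo_spike_tails():
    import jax
    import jax.numpy as jnp

    # Differentiate the spike function w.r.t. its input; the custom surrogate
    # gradient is used in the backward pass.
    grad_fn = jax.vmap(jax.grad(q_pseudo_spike), in_axes=(0, None))
    xs = jnp.linspace(-3., 3., 7)
    thin = grad_fn(xs, 4.)  # larger alpha: gradient decays faster in the tails
    fat = grad_fn(xs, 2.)   # smaller alpha: fatter tails, wider credit assignment
    return thin, fat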


class LeakyRelu(Surrogate):
    """Compute the spiking state using the leaky ReLU surrogate function.

    See Also
    --------
    leaky_relu
    """

    def __init__(self, alpha=0.1, beta=1.):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def surrogate_fun(self, x):
        return jnp.where(x < 0., self.alpha * x, self.beta * x)

    def surrogate_grad(self, x):
        dx = jnp.where(x < 0., self.alpha, self.beta)
        return dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'

    def __hash__(self):
        return hash((self.__class__, self.alpha, self.beta))


def leaky_relu(
    x: jax.Array,
    alpha: float = 0.1,
    beta: float = 1.,
):
    r"""Compute the spiking state using the leaky ReLU surrogate function.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    The original function:

    .. math::

       \begin{split}g(x) =
          \begin{cases}
          \beta \cdot x, & x \geq 0 \\
          \alpha \cdot x, & x < 0 \\
          \end{cases}\end{split}

    The backward (surrogate gradient) function:

    .. math::

       \begin{split}g'(x) =
          \begin{cases}
          \beta, & x \geq 0 \\
          \alpha, & x < 0 \\
          \end{cases}\end{split}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
       >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
        The input data.
    alpha: float
        The parameter that controls the gradient when :math:`x < 0`.
    beta: float
        The parameter that controls the gradient when :math:`x \geq 0`.

    Returns
    -------
    out: jax.Array
        The spiking state.
    """
    return LeakyRelu(alpha=alpha, beta=beta)(x)
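

# Illustrative sketch, not part of the library: the hand-written
# `surrogate_grad` of `LeakyRelu` agrees with automatic differentiation of its
# `surrogate_fun`. `_demo_leaky_relu_grad_check` is a hypothetical helper name.
def _demo_leaky_relu_grad_check():
    import jax
    import jax.numpy as jnp

    sg = LeakyRelu(alpha=0.1, beta=1.)
    xs = jnp.array([-2., -0.5, 0.5, 2.])  # avoid the kink at x = 0
    auto = jax.vmap(jax.grad(sg.surrogate_fun))(xs)  # autodiff of the original fn
    manual = sg.surrogate_grad(xs)                   # closed-form piecewise grad
    assert jnp.allclose(auto, manual)
    return manual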


class LogTailedRelu(Surrogate):
    """Compute the spiking state using the log-tailed ReLU surrogate function.

    See Also
    --------
    log_tailed_relu
    """

    def __init__(self, alpha=0.):
        super().__init__()
        self.alpha = alpha

    def surrogate_fun(self, x):
        z = jnp.where(x > 1,
                      jnp.log(x),
                      jnp.where(x > 0,
                                x,
                                self.alpha * x))
        return z

    def surrogate_grad(self, x):
        dx = jnp.where(x > 1,
                       1 / x,
                       jnp.where(x > 0,
                                 1.,
                                 self.alpha))
        return dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'

    def __hash__(self):
        return hash((self.__class__, self.alpha))


def log_tailed_relu(
    x: jax.Array,
    alpha: float = 0.,
):
    r"""Compute the spiking state using the log-tailed ReLU surrogate function [1]_.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    The original function:

    .. math::

       \begin{split}g(x) =
          \begin{cases}
          \alpha x, & x \leq 0 \\
          x, & 0 < x \leq 1 \\
          \log(x), & x > 1 \\
          \end{cases}\end{split}

    The backward (surrogate gradient) function:

    .. math::

       \begin{split}g'(x) =
          \begin{cases}
          \alpha, & x \leq 0 \\
          1, & 0 < x \leq 1 \\
          \frac{1}{x}, & x > 1 \\
          \end{cases}\end{split}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.log_tailed_relu)(xs, 0.)
       >>> plt.plot(xs, grads, label=r'$\alpha=0.$')
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
        The input data.
    alpha: float
        The parameter that controls the gradient when :math:`x \leq 0`.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414.
    """
    return LogTailedRelu(alpha=alpha)(x)
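

# Illustrative sketch, not part of the library: the three pieces of the
# log-tailed ReLU surrogate gradient evaluated one region at a time.
# `_demo_log_tailed_relu_pieces` is a hypothetical helper name.
def _demo_log_tailed_relu_pieces():
    import jax.numpy as jnp

    sg = LogTailedRelu(alpha=0.)
    # x <= 0 -> alpha; 0 < x <= 1 -> 1; x > 1 -> 1/x
    xs = jnp.array([-1., 0.5, 2.])
    grads = sg.surrogate_grad(xs)
    assert jnp.allclose(grads, jnp.array([0., 1., 0.5]))
    return grads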


class ReluGrad(Surrogate):
    """Compute the spiking state using the ReLU-shaped surrogate gradient.

    See Also
    --------
    relu_grad
    """

    def __init__(self, alpha=0.3, width=1.):
        super().__init__()
        self.alpha = alpha
        self.width = width

    def surrogate_grad(self, x):
        dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0)
        return dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'

    def __hash__(self):
        return hash((self.__class__, self.alpha, self.width))


def relu_grad(
    x: jax.Array,
    alpha: float = 0.3,
    width: float = 1.,
):
    r"""Spike function with the ReLU-shaped surrogate gradient [1]_.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    The backward (surrogate gradient) function:

    .. math::

       g'(x) = \text{ReLU}(\alpha (\mathrm{width} - |x|))

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for s in [0.5, 1.]:
       ...     for w in [1., 2.]:
       ...         grads = brainstate.augment.vector_grad(brainstate.surrogate.relu_grad)(xs, s, w)
       ...         plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
        The input data.
    alpha: float
        The parameter that controls the scale of the gradient.
    width: float
        The parameter that controls the width of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019).
    """
    return ReluGrad(alpha=alpha, width=width)(x)
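

# Illustrative sketch, not part of the library: the ReLU-shaped surrogate
# gradient peaks at alpha * width and is exactly zero once |x| exceeds `width`.
# `_demo_relu_grad_support` is a hypothetical helper name.
def _demo_relu_grad_support():
    import jax.numpy as jnp

    sg = ReluGrad(alpha=0.3, width=1.)
    xs = jnp.array([-2., -1., 0., 1., 2.])
    grads = sg.surrogate_grad(xs)
    # Linear decay from the peak at x = 0; zero outside [-width, width].
    assert jnp.allclose(grads, jnp.array([0., 0., 0.3, 0., 0.]))
    return grads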


class GaussianGrad(Surrogate):
    """Compute the spiking state using the Gaussian surrogate gradient.

    See Also
    --------
    gaussian_grad
    """

    def __init__(self, sigma=0.5, alpha=0.5):
        super().__init__()
        self.sigma = sigma
        self.alpha = alpha

    def surrogate_grad(self, x):
        # Scaled Gaussian density: alpha * N(x; 0, sigma^2).
        dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
        return self.alpha * dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'

    def __hash__(self):
        return hash((self.__class__, self.alpha, self.sigma))


def gaussian_grad(
    x: jax.Array,
    sigma: float = 0.5,
    alpha: float = 0.5,
):
    r"""Spike function with the Gaussian surrogate gradient [1]_.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    The backward (surrogate gradient) function:

    .. math::

       g'(x) = \alpha \, \mathcal{N}(x; 0, \sigma^2)

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for s in [0.5, 1., 2.]:
       ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.gaussian_grad)(xs, s, 0.5)
       ...     plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
        The input data.
    sigma: float
        The parameter that controls the variance of the Gaussian distribution.
    alpha: float
        The parameter that controls the scale of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
    """
    return GaussianGrad(sigma=sigma, alpha=alpha)(x)
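

# Illustrative sketch, not part of the library: the peak of the Gaussian
# surrogate gradient at x = 0 is alpha / (sqrt(2*pi) * sigma), i.e. the scaled
# Gaussian density at its mean. `_demo_gaussian_grad_peak` is a hypothetical
# helper name.
def _demo_gaussian_grad_peak():
    import jax.numpy as jnp

    sigma, alpha = 0.5, 0.5
    peak = GaussianGrad(sigma=sigma, alpha=alpha).surrogate_grad(jnp.array(0.))
    expected = alpha / (jnp.sqrt(2 * jnp.pi) * sigma)
    assert jnp.allclose(peak, expected)
    return peak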


class MultiGaussianGrad(Surrogate):
    """Compute the spiking state using the multi-Gaussian surrogate gradient.

    See Also
    --------
    multi_gaussian_grad
    """

    def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
        super().__init__()
        self.h = h
        self.s = s
        self.sigma = sigma
        self.scale = scale

    def surrogate_grad(self, x):
        # Central Gaussian minus two shifted, wider Gaussians weighted by h.
        g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
        g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
                     ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
        g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
                     ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
        dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h
        return self.scale * dx

    def __repr__(self):
        return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'

    def __hash__(self):
        return hash((self.__class__, self.h, self.s, self.sigma, self.scale))


def multi_gaussian_grad(
    x: jax.Array,
    h: float = 0.15,
    s: float = 6.0,
    sigma: float = 0.5,
    scale: float = 0.5,
):
    r"""Spike function with the multi-Gaussian surrogate gradient [1]_.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    The backward (surrogate gradient) function:

    .. math::

       g'(x) = (1+h)\,\mathcal{N}(x; 0, \sigma^2)
               - h\,\mathcal{N}(x; \sigma, (s\sigma)^2)
               - h\,\mathcal{N}(x; -\sigma, (s\sigma)^2)

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.multi_gaussian_grad)(xs)
       >>> plt.plot(xs, grads)
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
        The input data.
    h: float
        The weight of the two subtracted side Gaussians.
    s: float
        The width factor of the side Gaussians (their standard deviation is :math:`s\sigma`).
    sigma: float
        The standard deviation of the central Gaussian.
    scale: float
        The gradient scale.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
    """
    return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x)
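

# Illustrative sketch, not part of the library: unlike the plain Gaussian,
# the multi-Gaussian surrogate gradient dips below zero away from the
# threshold, because the two side Gaussians are subtracted.
# `_demo_multi_gaussian_lobes` is a hypothetical helper name.
def _demo_multi_gaussian_lobes():
    import jax.numpy as jnp

    sg = MultiGaussianGrad()  # defaults: h=0.15, s=6.0, sigma=0.5, scale=0.5
    center = sg.surrogate_grad(jnp.array(0.))  # positive at the threshold
    tail = sg.surrogate_grad(jnp.array(2.))    # slightly negative in the tail
    assert center > 0. and tail < 0.
    return center, tail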


class InvSquareGrad(Surrogate):
    """Compute the spiking state using the inverse-square surrogate gradient.

    See Also
    --------
    inv_square_grad
    """

    def __init__(self, alpha=100.):
        super().__init__()
        self.alpha = alpha

    def surrogate_grad(self, x):
        dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2
        return dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'

    def __hash__(self):
        return hash((self.__class__, self.alpha))


def inv_square_grad(
    x: jax.Array,
    alpha: float = 100.
):
    r"""Spike function with the inverse-square surrogate gradient.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    The backward (surrogate gradient) function:

    .. math::

       g'(x) = \frac{1}{(\alpha |x| + 1)^2}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-1, 1, 1000)
       >>> for alpha in [1., 10., 100.]:
       ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.inv_square_grad)(xs, alpha)
       ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
        The input data.
    alpha: float
        The parameter that controls the smoothness of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.
    """
    return InvSquareGrad(alpha=alpha)(x)
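

# Illustrative sketch, not part of the library: `__hash__` is derived from the
# class and its parameters, so two configurations with the same alpha hash
# alike (and different alphas almost surely do not), which helps avoid
# needless `jax.jit` retracing when a surrogate is used as a static argument.
# `_demo_inv_square_grad_hash` is a hypothetical helper name.
def _demo_inv_square_grad_hash():
    a = InvSquareGrad(alpha=100.)
    b = InvSquareGrad(alpha=100.)
    c = InvSquareGrad(alpha=10.)
    assert hash(a) == hash(b)
    assert hash(a) != hash(c)
    return repr(a)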


class SlayerGrad(Surrogate):
    """Compute the spiking state using the SLAYER surrogate gradient.

    See Also
    --------
    slayer_grad
    """

    def __init__(self, alpha=1.):
        super().__init__()
        self.alpha = alpha

    def surrogate_grad(self, x):
        dx = jnp.exp(-self.alpha * jnp.abs(x))
        return dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'

    def __hash__(self):
        return hash((self.__class__, self.alpha))


def slayer_grad(
    x: jax.Array,
    alpha: float = 1.
):
    r"""Spike function with the SLAYER surrogate gradient [1]_.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    The backward (surrogate gradient) function:

    .. math::

       g'(x) = \exp(-\alpha |x|)

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for alpha in [0.5, 1., 2., 4.]:
       ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.slayer_grad)(xs, alpha)
       ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
        The input data.
    alpha: float
        The parameter that controls the smoothness of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Shrestha, S. B. & Orchard, G. SLAYER: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018).
    """
    return SlayerGrad(alpha=alpha)(x)
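

# Illustrative sketch, not part of the library: a single threshold crossing in
# which `slayer_grad` provides a smooth gradient pathway even for neurons that
# did not spike. All names here (`v_th`, `total_spikes`) are hypothetical and
# used only for demonstration.
def _demo_slayer_lif_step():
    import jax
    import jax.numpy as jnp

    v_th = 1.0  # firing threshold

    def total_spikes(v):
        # Forward pass: Heaviside-like spikes; backward pass: exp(-alpha * |x|).
        return jnp.sum(slayer_grad(v - v_th, alpha=1.))

    v = jnp.array([0.8, 1.2])  # one sub-threshold, one supra-threshold neuron
    # exp(-alpha * |x|) is nonzero everywhere, so both entries get a gradient.
    return jax.grad(total_spikes)(v)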