brainstate 0.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +45 -0
- brainstate/_module.py +1466 -0
- brainstate/_module_test.py +133 -0
- brainstate/_state.py +378 -0
- brainstate/_state_test.py +41 -0
- brainstate/_utils.py +21 -0
- brainstate/environ.py +375 -0
- brainstate/functional/__init__.py +25 -0
- brainstate/functional/_activations.py +754 -0
- brainstate/functional/_normalization.py +69 -0
- brainstate/functional/_spikes.py +90 -0
- brainstate/init/__init__.py +26 -0
- brainstate/init/_base.py +36 -0
- brainstate/init/_generic.py +175 -0
- brainstate/init/_random_inits.py +489 -0
- brainstate/init/_regular_inits.py +109 -0
- brainstate/math/__init__.py +21 -0
- brainstate/math/_einops.py +787 -0
- brainstate/math/_einops_parsing.py +169 -0
- brainstate/math/_einops_parsing_test.py +126 -0
- brainstate/math/_einops_test.py +346 -0
- brainstate/math/_misc.py +298 -0
- brainstate/math/_misc_test.py +58 -0
- brainstate/mixin.py +373 -0
- brainstate/mixin_test.py +73 -0
- brainstate/nn/__init__.py +68 -0
- brainstate/nn/_base.py +248 -0
- brainstate/nn/_connections.py +686 -0
- brainstate/nn/_dynamics.py +406 -0
- brainstate/nn/_elementwise.py +1437 -0
- brainstate/nn/_misc.py +132 -0
- brainstate/nn/_normalizations.py +389 -0
- brainstate/nn/_others.py +100 -0
- brainstate/nn/_poolings.py +1228 -0
- brainstate/nn/_poolings_test.py +231 -0
- brainstate/nn/_projection/__init__.py +32 -0
- brainstate/nn/_projection/_align_post.py +528 -0
- brainstate/nn/_projection/_align_pre.py +599 -0
- brainstate/nn/_projection/_delta.py +241 -0
- brainstate/nn/_projection/_utils.py +17 -0
- brainstate/nn/_projection/_vanilla.py +101 -0
- brainstate/nn/_rate_rnns.py +393 -0
- brainstate/nn/_readout.py +130 -0
- brainstate/nn/_synouts.py +166 -0
- brainstate/nn/functional/__init__.py +25 -0
- brainstate/nn/functional/_activations.py +754 -0
- brainstate/nn/functional/_normalization.py +69 -0
- brainstate/nn/functional/_spikes.py +90 -0
- brainstate/nn/init/__init__.py +26 -0
- brainstate/nn/init/_base.py +36 -0
- brainstate/nn/init/_generic.py +175 -0
- brainstate/nn/init/_random_inits.py +489 -0
- brainstate/nn/init/_regular_inits.py +109 -0
- brainstate/nn/surrogate.py +1740 -0
- brainstate/optim/__init__.py +23 -0
- brainstate/optim/_lr_scheduler.py +486 -0
- brainstate/optim/_lr_scheduler_test.py +36 -0
- brainstate/optim/_sgd_optimizer.py +1148 -0
- brainstate/random.py +5148 -0
- brainstate/random_test.py +576 -0
- brainstate/surrogate.py +1740 -0
- brainstate/transform/__init__.py +36 -0
- brainstate/transform/_autograd.py +585 -0
- brainstate/transform/_autograd_test.py +1183 -0
- brainstate/transform/_control.py +665 -0
- brainstate/transform/_controls_test.py +220 -0
- brainstate/transform/_jit.py +239 -0
- brainstate/transform/_jit_error.py +158 -0
- brainstate/transform/_jit_test.py +102 -0
- brainstate/transform/_make_jaxpr.py +573 -0
- brainstate/transform/_make_jaxpr_test.py +133 -0
- brainstate/transform/_progress_bar.py +113 -0
- brainstate/typing.py +69 -0
- brainstate/util.py +747 -0
- brainstate-0.0.1.dist-info/LICENSE +202 -0
- brainstate-0.0.1.dist-info/METADATA +101 -0
- brainstate-0.0.1.dist-info/RECORD +79 -0
- brainstate-0.0.1.dist-info/WHEEL +6 -0
- brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/surrogate.py
ADDED
@@ -0,0 +1,1740 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
import jax
|
19
|
+
import jax.numpy as jnp
|
20
|
+
import jax.scipy as sci
|
21
|
+
from jax.core import Primitive
|
22
|
+
from jax.interpreters import batching, ad, mlir
|
23
|
+
|
24
|
+
# Public API: the base class, then each surrogate as a (class, functional wrapper) pair.
__all__ = [
    'Surrogate',
    'Sigmoid', 'sigmoid',
    'PiecewiseQuadratic', 'piecewise_quadratic',
    'PiecewiseExp', 'piecewise_exp',
    'SoftSign', 'soft_sign',
    'Arctan', 'arctan',
    'NonzeroSignLog', 'nonzero_sign_log',
    'ERF', 'erf',
    'PiecewiseLeakyRelu', 'piecewise_leaky_relu',
    'SquarewaveFourierSeries', 'squarewave_fourier_series',
    'S2NN', 's2nn',
    'QPseudoSpike', 'q_pseudo_spike',
    'LeakyRelu', 'leaky_relu',
    'LogTailedRelu', 'log_tailed_relu',
    'ReluGrad', 'relu_grad',
    'GaussianGrad', 'gaussian_grad',
    'InvSquareGrad', 'inv_square_grad',
    'MultiGaussianGrad', 'multi_gaussian_grad',
    'SlayerGrad', 'slayer_grad',
]
|
63
|
+
|
64
|
+
|
65
|
+
def _heaviside_abstract(x, dx):
    """Abstract-evaluation rule of ``heaviside_p``.

    The primitive has a single output, described by the first operand ``x``
    itself; ``dx`` does not influence the result.
    """
    outputs = [x]
    return outputs
|
67
|
+
|
68
|
+
|
69
|
+
def _heaviside_imp(x, dx):
    """Concrete implementation of ``heaviside_p``.

    Returns a one-element list holding the step function of ``x``
    (1 where ``x >= 0``, 0 elsewhere) in the dtype of ``x``.
    ``dx`` is not used by the forward computation.
    """
    is_non_negative = jnp.greater_equal(x, 0)
    return [jnp.asarray(is_non_negative, dtype=x.dtype)]
|
72
|
+
|
73
|
+
|
74
|
+
def _heaviside_batching(args, axes):
    """Batching rule of ``heaviside_p``.

    The primitive is re-bound on the batched operands unchanged, and the
    output is reported as batched along the axis of the first operand.
    """
    # NOTE(review): only axes[0] is consulted; this presumes the second operand
    # is batched along the same axis as the first — confirm against callers.
    first_operand_axis = axes[0]
    batched_outputs = heaviside_p.bind(*args)
    return batched_outputs, [first_operand_axis]
|
76
|
+
|
77
|
+
|
78
|
+
def _heaviside_jvp(primals, tangents):
    """JVP rule of ``heaviside_p``.

    The primal output is the primitive itself; the output tangent is the
    supplied surrogate gradient ``dx`` multiplied by the tangent of ``x``.
    The tangent of ``dx`` is ignored.
    """
    x, dx = primals
    x_tangent, _dx_tangent = tangents
    primal_outs = heaviside_p.bind(x, dx)
    return primal_outs, [dx * x_tangent]
|
84
|
+
|
85
|
+
|
86
|
+
# Primitive computing the step function of its first operand while carrying a
# user-supplied surrogate gradient as its second operand.
heaviside_p = Primitive('heaviside_p')
heaviside_p.multiple_results = True  # every rule returns a list of outputs

# Evaluation rules.
heaviside_p.def_impl(_heaviside_imp)
heaviside_p.def_abstract_eval(_heaviside_abstract)

# Transformation rules: differentiation and batching.
ad.primitive_jvps[heaviside_p] = _heaviside_jvp
batching.primitive_batchers[heaviside_p] = _heaviside_batching

# Lowering is derived from the implementation function.
mlir.register_lowering(
    heaviside_p,
    mlir.lower_fun(_heaviside_imp, multiple_results=True),
)
|
93
|
+
|
94
|
+
|
95
|
+
class Surrogate(object):
    """The base surrogate gradient function.

    Calling an instance evaluates ``surrogate_grad(x)`` and binds it, together
    with ``x``, to the ``heaviside_p`` primitive; the first (only) output of
    the primitive is returned.

    To customize a surrogate gradient function, inherit from this class and
    implement the ``surrogate_fun`` and ``surrogate_grad`` methods.

    Examples
    --------

    >>> import brainstate as bst
    >>> import brainstate.nn as nn
    >>> import jax.numpy as jnp

    >>> class MySurrogate(nn.surrogate.Surrogate):
    ...   def __init__(self, alpha=1.):
    ...     super().__init__()
    ...     self.alpha = alpha
    ...
    ...   def surrogate_fun(self, x):
    ...     return jnp.sin(x) * self.alpha
    ...
    ...   def surrogate_grad(self, x):
    ...     return jnp.cos(x) * self.alpha

    """

    def __call__(self, x):
        # The surrogate gradient travels with x as the second operand of the primitive.
        grad = self.surrogate_grad(x)
        outputs = heaviside_p.bind(x, grad)
        return outputs[0]

    def __repr__(self):
        return type(self).__name__ + '()'

    def surrogate_fun(self, x) -> jax.Array:
        """The surrogate function; must be provided by subclasses."""
        raise NotImplementedError

    def surrogate_grad(self, x) -> jax.Array:
        """The gradient of the surrogate function; must be provided by subclasses."""
        raise NotImplementedError
|
135
|
+
|
136
|
+
|
137
|
+
class Sigmoid(Surrogate):
    """Spike function with the sigmoid-shaped surrogate gradient.

    ``alpha`` scales the input of the sigmoid and, as a factor, its gradient.

    See Also
    --------
    sigmoid

    """

    def __init__(self, alpha: float = 4.):
        super().__init__()
        self.alpha = alpha

    def __repr__(self):
        return f'{type(self).__name__}(alpha={self.alpha})'

    def surrogate_fun(self, x):
        """Return ``expit(alpha * x)``."""
        scaled = self.alpha * x
        return sci.special.expit(scaled)

    def surrogate_grad(self, x):
        """Return ``alpha * (1 - s) * s`` with ``s = expit(alpha * x)``."""
        s = sci.special.expit(x * self.alpha)
        return (1. - s) * s * self.alpha
|
160
|
+
|
161
|
+
|
162
|
+
def sigmoid(
    x: jax.Array,
    alpha: float = 4.,
):
    r"""Spike function with the sigmoid-shaped surrogate gradient.

    Forward pass (the value returned by this function):

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    Smooth surrogate function:

    .. math::

       g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}

    Backward function:

    .. math::

       g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x)

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-2, 2, 1000)
       >>> for alpha in [1., 2., 4.]:
       ...   grads = bst.transform.vector_grad(nn.surrogate.sigmoid)(xs, alpha)
       ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
      The input data.
    alpha: float
      Parameter to control smoothness of gradient

    Returns
    -------
    out: jax.Array
      The spiking state.
    """
    spike_fn = Sigmoid(alpha=alpha)
    return spike_fn(x)
|
216
|
+
|
217
|
+
|
218
|
+
class PiecewiseQuadratic(Surrogate):
    r"""Judge spiking state with a piecewise quadratic function.

    The surrogate function is

    .. math::

       g(x) = \begin{cases}
          0, & x < -\frac{1}{\alpha} \\
          -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
          1, & x > \frac{1}{\alpha}
          \end{cases}

    and its gradient is :math:`-\alpha^2|x| + \alpha` for
    :math:`|x| \leq 1/\alpha` and 0 elsewhere.

    See Also
    --------
    piecewise_quadratic

    """

    def __init__(self, alpha: float = 1.):
        super().__init__()
        self.alpha = alpha

    def surrogate_fun(self, x):
        """Evaluate the piecewise quadratic surrogate function at ``x``."""
        z = jnp.where(x < -1 / self.alpha,
                      0.,
                      jnp.where(x > 1 / self.alpha,
                                1.,
                                (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5))
        return z

    def surrogate_grad(self, x):
        """Evaluate the derivative of ``surrogate_fun`` at ``x``."""
        # d/dx of (-alpha^2 |x| x / 2 + alpha x + 1/2) is -alpha^2 |x| + alpha:
        # linear in |x| (not quadratic in x), reaching 0 at |x| = 1 / alpha.
        abs_x = jnp.abs(x)
        inner = -(self.alpha ** 2) * abs_x + self.alpha
        return jnp.where(abs_x > 1 / self.alpha, 0., inner)

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'
|
245
|
+
|
246
|
+
|
247
|
+
def piecewise_quadratic(
    x: jax.Array,
    alpha: float = 1.,
):
    r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.

    Forward pass (the value returned by this function):

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    Smooth surrogate function:

    .. math::

       g(x) =
          \begin{cases}
          0, & x < -\frac{1}{\alpha} \\
          -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha}  \\
          1, & x > \frac{1}{\alpha} \\
          \end{cases}

    Backward function:

    .. math::

       g'(x) =
          \begin{cases}
          0, & |x| > \frac{1}{\alpha} \\
          -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha}
          \end{cases}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for alpha in [0.5, 1., 2., 4.]:
       ...   grads = bst.transform.vector_grad(nn.surrogate.piecewise_quadratic)(xs, alpha)
       ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
      The input data.
    alpha: float
      Parameter to control smoothness of gradient

    Returns
    -------
    out: jax.Array
      The spiking state.

    References
    ----------
    .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446.
    .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
    .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
    .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
    .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
    """
    spike_fn = PiecewiseQuadratic(alpha=alpha)
    return spike_fn(x)
|
318
|
+
|
319
|
+
|
320
|
+
class PiecewiseExp(Surrogate):
    """Judge spiking state with a piecewise exponential function.

    ``alpha`` is the decay rate of the exponential on either side of zero.

    See Also
    --------
    piecewise_exp
    """

    def __init__(self, alpha: float = 1.):
        super().__init__()
        self.alpha = alpha

    def __repr__(self):
        return f'{type(self).__name__}(alpha={self.alpha})'

    def surrogate_fun(self, x):
        """Return ``exp(alpha*x)/2`` for ``x < 0`` and ``1 - exp(-alpha*x)/2`` otherwise."""
        negative_side = jnp.exp(self.alpha * x) / 2
        non_negative_side = 1 - jnp.exp(-self.alpha * x) / 2
        return jnp.where(x < 0, negative_side, non_negative_side)

    def surrogate_grad(self, x):
        """Return ``(alpha / 2) * exp(-alpha * |x|)``."""
        half_alpha = self.alpha / 2
        return half_alpha * jnp.exp(-self.alpha * jnp.abs(x))
|
341
|
+
|
342
|
+
|
343
|
+
def piecewise_exp(
    x: jax.Array,
    alpha: float = 1.,
):
    r"""Judge spiking state with a piecewise exponential function [1]_.

    Forward pass (the value returned by this function):

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    Smooth surrogate function:

    .. math::

       g(x) = \begin{cases}
          \frac{1}{2}e^{\alpha x}, & x < 0 \\
          1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
          \end{cases}

    Backward function:

    .. math::

       g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for alpha in [0.5, 1., 2., 4.]:
       ...   grads = bst.transform.vector_grad(nn.surrogate.piecewise_exp)(xs, alpha)
       ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
      The input data.
    alpha: float
      Parameter to control smoothness of gradient

    Returns
    -------
    out: jax.Array
      The spiking state.

    References
    ----------
    .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
    """
    spike_fn = PiecewiseExp(alpha=alpha)
    return spike_fn(x)
|
405
|
+
|
406
|
+
|
407
|
+
class SoftSign(Surrogate):
    """Judge spiking state with a soft sign function.

    See Also
    --------
    soft_sign
    """

    def __init__(self, alpha=1.):
        super().__init__()
        self.alpha = alpha

    def __repr__(self):
        return f'{type(self).__name__}(alpha={self.alpha})'

    def surrogate_fun(self, x):
        """Return ``x / (2/alpha + 2|x|) + 0.5``."""
        denominator = 2 / self.alpha + 2 * jnp.abs(x)
        return x / denominator + 0.5

    def surrogate_grad(self, x):
        """Return ``0.5 * alpha / (1 + |alpha * x|)**2``."""
        squared = (1 + jnp.abs(self.alpha * x)) ** 2
        return self.alpha * 0.5 / squared
|
428
|
+
|
429
|
+
|
430
|
+
def soft_sign(
    x: jax.Array,
    alpha: float = 1.,
):
    r"""Judge spiking state with a soft sign function.

    Forward pass (the value returned by this function):

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    Smooth surrogate function:

    .. math::

       g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)
          = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)

    Backward function:

    .. math::

       g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for alpha in [0.5, 1., 2., 4.]:
       ...   grads = bst.transform.vector_grad(nn.surrogate.soft_sign)(xs, alpha)
       ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
      The input data.
    alpha: float
      Parameter to control smoothness of gradient

    Returns
    -------
    out: jax.Array
      The spiking state.
    """
    spike_fn = SoftSign(alpha=alpha)
    return spike_fn(x)
|
487
|
+
|
488
|
+
|
489
|
+
class Arctan(Surrogate):
    r"""Judge spiking state with an arctan function.

    The surrogate function is
    :math:`g(x) = \frac{1}{\pi}\arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}`
    and its gradient is
    :math:`g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}`.

    See Also
    --------
    arctan
    """

    def __init__(self, alpha=1.):
        super().__init__()
        self.alpha = alpha

    def surrogate_grad(self, x):
        """Evaluate the derivative of ``surrogate_fun`` at ``x``."""
        dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
        return dx

    def surrogate_fun(self, x):
        """Evaluate ``arctan(pi/2 * alpha * x) / pi + 0.5``."""
        # Single-argument inverse tangent: jnp.arctan2 takes two operands
        # (y, x) and cannot be called with one.
        return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'
|
510
|
+
|
511
|
+
|
512
|
+
def arctan(
    x: jax.Array,
    alpha: float = 1.,
):
    r"""Judge spiking state with an arctan function.

    Forward pass (the value returned by this function):

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    Smooth surrogate function:

    .. math::

       g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}

    Backward function:

    .. math::

       g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for alpha in [0.5, 1., 2., 4.]:
       ...   grads = bst.transform.vector_grad(nn.surrogate.arctan)(xs, alpha)
       ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
      The input data.
    alpha: float
      Parameter to control smoothness of gradient

    Returns
    -------
    out: jax.Array
      The spiking state.
    """
    spike_fn = Arctan(alpha=alpha)
    return spike_fn(x)
|
568
|
+
|
569
|
+
|
570
|
+
class NonzeroSignLog(Surrogate):
    """Judge spiking state with a nonzero sign log function.

    See Also
    --------
    nonzero_sign_log
    """

    def __init__(self, alpha=1.):
        super().__init__()
        self.alpha = alpha

    def __repr__(self):
        return f'{type(self).__name__}(alpha={self.alpha})'

    def surrogate_fun(self, x):
        """Return ``sign(x) * log(|alpha * x| + 1)`` with ``sign(0) = 1``."""
        nonzero_sign = jnp.where(x < 0, -1., 1.)
        magnitude = jnp.log(jnp.abs(self.alpha * x) + 1)
        return nonzero_sign * magnitude

    def surrogate_grad(self, x):
        """Return ``1 / (1/alpha + |x|)``."""
        denominator = 1 / self.alpha + jnp.abs(x)
        return 1. / denominator
|
591
|
+
|
592
|
+
|
593
|
+
def nonzero_sign_log(
    x: jax.Array,
    alpha: float = 1.,
):
    r"""Judge spiking state with a nonzero sign log function.

    Forward pass (the value returned by this function):

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    Smooth surrogate function:

    .. math::

       g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)

    where

    .. math::

       \begin{split}\mathrm{NonzeroSign}(x) =
          \begin{cases}
          1, & x \geq 0 \\
          -1, & x < 0 \\
          \end{cases}\end{split}

    Backward function:

    .. math::

       g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}

    This surrogate function has the advantage of low computation cost during the backward.

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for alpha in [0.5, 1., 2., 4.]:
       ...   grads = bst.transform.vector_grad(nn.surrogate.nonzero_sign_log)(xs, alpha)
       ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
      The input data.
    alpha: float
      Parameter to control smoothness of gradient

    Returns
    -------
    out: jax.Array
      The spiking state.
    """
    spike_fn = NonzeroSignLog(alpha=alpha)
    return spike_fn(x)
|
662
|
+
|
663
|
+
|
664
|
+
class ERF(Surrogate):
    r"""Judge spiking state with an erf function.

    The surrogate function is
    :math:`g(x) = \frac{1}{2}(1 - \mathrm{erf}(-\alpha x))` and its gradient is
    :math:`g'(x) = \frac{\alpha}{\sqrt{\pi}} e^{-\alpha^2 x^2}`.

    See Also
    --------
    erf
    """

    def __init__(self, alpha=1.):
        super().__init__()
        self.alpha = alpha

    def surrogate_grad(self, x):
        """Evaluate the derivative of ``surrogate_fun`` at ``x``."""
        dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
        return dx

    def surrogate_fun(self, x):
        """Evaluate ``0.5 * (1 - erf(-alpha * x))``."""
        # The "1 -" term makes the function rise from 0 to 1 and gives it the
        # derivative returned by surrogate_grad; 0.5 * erf(-alpha * x) alone
        # falls from 0.5 to -0.5.
        return 0.5 * (1. - sci.special.erf(-self.alpha * x))

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'
|
685
|
+
|
686
|
+
|
687
|
+
def erf(
    x: jax.Array,
    alpha: float = 1.,
):
    r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.

    Forward pass (the value returned by this function):

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    Smooth surrogate function:

    .. math::

       \begin{split}
        g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
        &= \frac{1}{2} \text{erfc}(-\alpha x) \\
        &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
        \end{split}

    Backward function:

    .. math::

       g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for alpha in [0.5, 1., 2., 4.]:
       ...   grads = bst.transform.vector_grad(nn.surrogate.erf)(xs, alpha)
       ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
      The input data.
    alpha: float
      Parameter to control smoothness of gradient

    Returns
    -------
    out: jax.Array
      The spiking state.

    References
    ----------
    .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125.
    .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
    .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.

    """
    return ERF(alpha=alpha)(x)
|
753
|
+
|
754
|
+
|
755
|
+
class PiecewiseLeakyRelu(Surrogate):
    """Judge spiking state with a piecewise leaky relu function.

    ``w`` is the half-width of the central ramp and ``c`` the slope of the
    two outer segments.

    See Also
    --------
    piecewise_leaky_relu
    """

    def __init__(self, c=0.01, w=1.):
        super().__init__()
        self.c = c
        self.w = w

    def __repr__(self):
        return f'{type(self).__name__}(c={self.c}, w={self.w})'

    def surrogate_fun(self, x):
        """Return the three-segment piecewise linear function of ``x``."""
        left_tail = self.c * x + self.c * self.w
        right_tail = self.c * x - self.c * self.w + 1
        ramp = 0.5 * x / self.w + 0.5
        return jnp.where(x < -self.w,
                         left_tail,
                         jnp.where(x > self.w, right_tail, ramp))

    def surrogate_grad(self, x):
        """Return ``c`` where ``|x| > w`` and ``1 / w`` elsewhere."""
        # NOTE(review): the ramp of surrogate_fun has slope 0.5 / w while this
        # returns 1 / w — confirm the factor of two is intended.
        outside_ramp = jnp.abs(x) > self.w
        return jnp.where(outside_ramp, self.c, 1 / self.w)
|
782
|
+
|
783
|
+
|
784
|
+
def piecewise_leaky_relu(
    x: jax.Array,
    c: float = 0.01,
    w: float = 1.,
):
    r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.

    Forward pass (the value returned by this function):

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
          \end{cases}

    Smooth surrogate function:

    .. math::

       \begin{split}g(x) =
          \begin{cases}
          cx + cw, & x < -w \\
          \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
          cx - cw + 1, & x > w \\
          \end{cases}\end{split}

    Backward function:

    .. math::

       \begin{split}g'(x) =
          \begin{cases}
          \frac{1}{w}, & |x| \leq w \\
          c, & |x| > w
          \end{cases}\end{split}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for c in [0.01, 0.05, 0.1]:
       ...   for w in [1., 2.]:
       ...     grads1 = bst.transform.vector_grad(nn.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
       ...     plt.plot(xs, grads1, label=f'c={c}, w={w}')
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array
      The input data.
    c: float
      When :math:`|x| > w` the gradient is `c`.
    w: float
      When :math:`|x| <= w` the gradient is `1 / w`.

    Returns
    -------
    out: jax.Array
      The spiking state.

    References
    ----------
    .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
    .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
    .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
    .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
    .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
    .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
    .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525.
    .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.

    """
    spike_fn = PiecewiseLeakyRelu(c=c, w=w)
    return spike_fn(x)
|
864
|
+
|
865
|
+
|
866
|
+
class SquarewaveFourierSeries(Surrogate):
    """Judge spiking state with a squarewave fourier series.

    Both methods start from the fundamental term and then add the odd
    harmonics ``2*i - 1`` for ``i`` in ``range(2, n)``, with angular
    frequency ``2*pi / t_period``.

    See Also
    --------
    squarewave_fourier_series
    """

    # NOTE(review): ``range(2, n)`` sums the harmonics i = 1 .. n-1, so the
    # default n=2 keeps only the fundamental — confirm whether n terms were
    # intended.

    def __init__(self, n=2, t_period=8.):
        super().__init__()
        self.n = n
        self.t_period = t_period

    def __repr__(self):
        return f'{type(self).__name__}(n={self.n}, t_period={self.t_period})'

    def surrogate_fun(self, x):
        """Return ``0.5 + (2/pi) * sum(sin(k*w*x) / k)`` over the odd harmonics ``k``."""
        omega = jnp.pi * 2. / self.t_period
        series = jnp.sin(omega * x)
        for i in range(2, self.n):
            harmonic = (2 * i - 1.)
            series = series + jnp.sin(harmonic * omega * x) / harmonic
        return 0.5 + 2. / jnp.pi * series

    def surrogate_grad(self, x):
        """Return ``(4 / t_period) * sum(cos(k*w*x))`` over the odd harmonics ``k``."""
        omega = jnp.pi * 2. / self.t_period
        series = jnp.cos(omega * x)
        for i in range(2, self.n):
            series = series + jnp.cos((2 * i - 1.) * omega * x)
        return series * (4. / self.t_period)
|
900
|
+
|
901
|
+
|
902
|
+
def squarewave_fourier_series(
    x: jax.Array,
    n: int = 2,
    t_period: float = 8.,

):
    r"""Judge spiking state with a squarewave fourier series.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Original function (a truncated Fourier series of a square wave with
    period :math:`T`), whose derivative is the backward function:

    .. math::

        g(x) = 0.5 + \frac{2}{\pi} \sum_{i=1}^n \frac{\sin\left((2i-1) \cdot 2\pi x / T\right)}{2i-1}

    Backward function:

    .. math::

        g'(x) = \sum_{i=1}^n \frac{4\cos\left((2i-1) \cdot 2\pi x / T\right)}{T}

    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-3, 3, 1000)
        >>> for n in [2, 4, 8]:
        >>>     f = nn.surrogate.SquarewaveFourierSeries(n=n)
        >>>     grads1 = bst.transform.vector_grad(f)(xs)
        >>>     plt.plot(xs, grads1, label=f'n={n}')
        >>> plt.legend()
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    n: int
        The truncation order :math:`n` of the series.
    t_period: float
        The period :math:`T` of the square wave.

    Returns
    -------
    out: jax.Array
        The spiking state.

    """

    return SquarewaveFourierSeries(n=n, t_period=t_period)(x)
|
961
|
+
|
962
|
+
|
963
|
+
class S2NN(Surrogate):
    """Judge spiking state with the S2NN surrogate spiking function.

    ``alpha`` shapes the sigmoid branch used for ``x < 0``, ``beta`` scales
    the logarithmic branch used otherwise, and ``epsilon`` is added inside
    the logarithm of ``surrogate_fun``.

    See Also
    --------
    s2nn
    """

    def __init__(self, alpha=4., beta=1., epsilon=1e-8):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.epsilon = epsilon

    def surrogate_fun(self, x):
        """Sigmoid of ``alpha * x`` for ``x < 0``; ``beta * log(|x + 1| + epsilon) + 0.5`` otherwise."""
        sigmoid_branch = sci.special.expit(x * self.alpha)
        log_branch = self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5
        return jnp.where(x < 0., sigmoid_branch, log_branch)

    def surrogate_grad(self, x):
        """Sigmoid derivative for ``x < 0``; ``beta / (x + 1)`` otherwise."""
        sg = sci.special.expit(self.alpha * x)
        sigmoid_slope = self.alpha * sg * (1. - sg)
        log_slope = self.beta / (x + 1.)
        return jnp.where(x < 0., sigmoid_slope, log_slope)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
|
990
|
+
|
991
|
+
|
992
|
+
def s2nn(
    x: jax.Array,
    alpha: float = 4.,
    beta: float = 1.,
    epsilon: float = 1e-8,

):
    r"""Judge spiking state with the S2NN surrogate spiking function [1]_.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Original function, whose derivative is the backward function:

    .. math::

        \begin{split}g(x) = \begin{cases}
        \mathrm{sigmoid} (\alpha x), & x < 0 \\
        \beta \ln(|x + 1|) + 0.5, & x \ge 0
        \end{cases}\end{split}

    Backward function:

    .. math::

        \begin{split}g'(x) = \begin{cases}
        \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), & x < 0 \\
        \frac{\beta}{(x + 1)}, & x \ge 0
        \end{cases}\end{split}

    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-3, 3, 1000)
        >>> grads = bst.transform.vector_grad(nn.surrogate.s2nn)(xs, 4., 1.)
        >>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
        >>> grads = bst.transform.vector_grad(nn.surrogate.s2nn)(xs, 8., 2.)
        >>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
        >>> plt.legend()
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        The param that controls the gradient when ``x < 0``.
    beta: float
        The param that controls the gradient when ``x >= 0``.
    epsilon: float
        Small constant added inside the logarithm of the original function
        to avoid NaN.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag.

    """
    return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x)
|
1065
|
+
|
1066
|
+
|
1067
|
+
class QPseudoSpike(Surrogate):
    """Judge spiking state with the q-PseudoSpike surrogate function.

    ``alpha`` controls the tail fatness of the surrogate gradient.

    See Also
    --------
    q_pseudo_spike
    """

    def __init__(self, alpha=2.):
        super().__init__()
        self.alpha = alpha

    def surrogate_grad(self, x):
        """Return ``(1 + 2|x| / (alpha + 1)) ** (-alpha)``."""
        # NOTE(review): the denominator here is ``alpha + 1`` while
        # ``surrogate_fun`` uses ``alpha - 1``; left unchanged because
        # ``alpha - 1`` is singular at ``alpha == 1`` -- confirm the intent.
        dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha)
        return dx

    def surrogate_fun(self, x):
        """Return the q-PseudoSpike primitive.

        ``0.5 * (1 + 2|x| / (alpha - 1)) ** (1 - alpha)`` for ``x < 0`` and
        one minus that quantity for ``x >= 0``.
        """
        # Both branches share the same symmetric tail term. For x < 0 the
        # documented ``1 - 2x / (alpha - 1)`` equals ``1 + 2|x| / (alpha - 1)``;
        # the previous ``1 - 2|x| / (alpha - 1)`` flipped the sign and produced
        # negative or infinite values.
        tail = 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha)
        z = jnp.where(x < 0., tail, 1. - tail)
        return z

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'
|
1091
|
+
|
1092
|
+
|
1093
|
+
def q_pseudo_spike(
    x: jax.Array,
    alpha: float = 2.,

):
    r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Original function:

    .. math::

        \begin{split}g(x) =
        \begin{cases}
        \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
        1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
        \end{cases}\end{split}

    Backward function, as implemented by ``QPseudoSpike.surrogate_grad``
    (note the :math:`\alpha+1` denominator, which differs from the
    :math:`\alpha-1` of the original function above):

    .. math::

        g'(x) = (1+\frac{2|x|}{\alpha+1})^{-\alpha}

    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-3, 3, 1000)
        >>> for alpha in [0.5, 1., 2., 4.]:
        >>>     grads = bst.transform.vector_grad(nn.surrogate.q_pseudo_spike)(xs, alpha)
        >>>     plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
        >>> plt.legend()
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        The parameter to control tail fatness of gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag.
    """
    return QPseudoSpike(alpha=alpha)(x)
|
1156
|
+
|
1157
|
+
|
1158
|
+
class LeakyRelu(Surrogate):
    """Judge spiking state with the Leaky ReLU function.

    ``alpha`` is the slope applied where ``x < 0`` and ``beta`` the slope
    applied everywhere else.

    See Also
    --------
    leaky_relu
    """

    def __init__(self, alpha=0.1, beta=1.):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def surrogate_fun(self, x):
        """Piecewise-linear function: ``alpha * x`` for ``x < 0``, else ``beta * x``."""
        is_negative = x < 0.
        return jnp.where(is_negative, self.alpha * x, self.beta * x)

    def surrogate_grad(self, x):
        """Piecewise-constant slope: ``alpha`` for ``x < 0``, else ``beta``."""
        is_negative = x < 0.
        return jnp.where(is_negative, self.alpha, self.beta)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(alpha={self.alpha}, beta={self.beta})'
|
1180
|
+
|
1181
|
+
|
1182
|
+
def leaky_relu(
    x: jax.Array,
    alpha: float = 0.1,
    beta: float = 1.,

):
    r"""Judge spiking state with the Leaky ReLU function.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Original function, whose derivative is the backward function:

    .. math::

        \begin{split}g(x) =
        \begin{cases}
        \beta \cdot x, & x \geq 0 \\
        \alpha \cdot x, & x < 0 \\
        \end{cases}\end{split}

    Backward function:

    .. math::

        \begin{split}g'(x) =
        \begin{cases}
        \beta, & x \geq 0 \\
        \alpha, & x < 0 \\
        \end{cases}\end{split}

    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-3, 3, 1000)
        >>> grads = bst.transform.vector_grad(nn.surrogate.leaky_relu)(xs, 0., 1.)
        >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
        >>> plt.legend()
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        The parameter to control the gradient when :math:`x < 0`.
    beta: float
        The parameter to control the gradient when :math:`x >= 0`.

    Returns
    -------
    out: jax.Array
        The spiking state.
    """
    return LeakyRelu(alpha=alpha, beta=beta)(x)
|
1247
|
+
|
1248
|
+
|
1249
|
+
class LogTailedRelu(Surrogate):
    """Judge spiking state with the Log-tailed ReLU function.

    ``alpha`` is the slope applied where ``x <= 0``.

    See Also
    --------
    log_tailed_relu
    """

    def __init__(self, alpha=0.):
        super().__init__()
        self.alpha = alpha

    def surrogate_fun(self, x):
        """``log(x)`` for ``x > 1``, ``x`` for ``0 < x <= 1``, ``alpha * x`` otherwise."""
        # Resolve the two lower regions first, then overlay the log tail.
        below_one = jnp.where(x > 0, x, self.alpha * x)
        return jnp.where(x > 1, jnp.log(x), below_one)

    def surrogate_grad(self, x):
        """``1 / x`` for ``x > 1``, ``1`` for ``0 < x <= 1``, ``alpha`` otherwise."""
        # Resolve the two lower regions first, then overlay the 1/x tail.
        below_one = jnp.where(x > 0, 1., self.alpha)
        return jnp.where(x > 1, 1 / x, below_one)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(alpha={self.alpha})'
|
1279
|
+
|
1280
|
+
|
1281
|
+
def log_tailed_relu(
    x: jax.Array,
    alpha: float = 0.,

):
    r"""Judge spiking state with the Log-tailed ReLU function [1]_.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Original function, whose derivative is the backward function:

    .. math::

        \begin{split}g(x) =
        \begin{cases}
        \alpha x, & x \leq 0 \\
        x, & 0 < x \leq 1 \\
        \log(x), & x > 1 \\
        \end{cases}\end{split}

    Backward function:

    .. math::

        \begin{split}g'(x) =
        \begin{cases}
        \alpha, & x \leq 0 \\
        1, & 0 < x \leq 1 \\
        \frac{1}{x}, & x > 1 \\
        \end{cases}\end{split}

    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-3, 3, 1000)
        >>> grads = bst.transform.vector_grad(nn.surrogate.log_tailed_relu)(xs, 0.)
        >>> plt.plot(xs, grads, label=r'$\alpha=0.$')
        >>> plt.legend()
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        The parameter to control the gradient when :math:`x \leq 0`.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414.
    """
    return LogTailedRelu(alpha=alpha)(x)
|
1349
|
+
|
1350
|
+
|
1351
|
+
class ReluGrad(Surrogate):
    """Judge spiking state with the ReLU gradient function.

    The surrogate gradient is a triangle of height ``alpha * width`` that
    reaches zero at ``|x| = width``.

    See Also
    --------
    relu_grad
    """

    def __init__(self, alpha=0.3, width=1.):
        super().__init__()
        self.alpha = alpha
        self.width = width

    def surrogate_grad(self, x):
        """Return ``max(alpha * width - |x| * alpha, 0)``."""
        peak = self.alpha * self.width
        falloff = jnp.abs(x) * self.alpha
        return jnp.maximum(peak - falloff, 0)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(alpha={self.alpha}, width={self.width})'
|
1370
|
+
|
1371
|
+
|
1372
|
+
def relu_grad(
    x: jax.Array,
    alpha: float = 0.3,
    width: float = 1.,
):
    r"""Spike function with the ReLU gradient function [1]_.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Backward function:

    .. math::

        g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|))

    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-3, 3, 1000)
        >>> for s in [0.5, 1.]:
        >>>     for w in [1, 2.]:
        >>>         grads = bst.transform.vector_grad(nn.surrogate.relu_grad)(xs, s, w)
        >>>         plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
        >>> plt.legend()
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        The parameter to control the gradient.
    width: float
        The parameter to control the width of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019).
    """
    surrogate = ReluGrad(alpha=alpha, width=width)
    return surrogate(x)
|
1427
|
+
|
1428
|
+
|
1429
|
+
class GaussianGrad(Surrogate):
    """Judge spiking state with the Gaussian gradient function.

    The surrogate gradient is ``alpha`` times the probability density of a
    zero-mean normal distribution with standard deviation ``sigma``.

    See Also
    --------
    gaussian_grad
    """

    def __init__(self, sigma=0.5, alpha=0.5):
        super().__init__()
        self.sigma = sigma
        self.alpha = alpha

    def surrogate_grad(self, x):
        """Return ``alpha * exp(-x**2 / (2 * sigma**2)) / (sqrt(2 * pi) * sigma)``."""
        # The exponent must be divided by 2 * sigma^2. The previous form
        # ``-(x ** 2) / 2 * sigma ** 2`` multiplied by sigma^2 because of
        # operator precedence, so a larger sigma narrowed the curve.
        variance_x2 = 2 * jnp.power(self.sigma, 2)
        dx = jnp.exp(-(x ** 2) / variance_x2) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
        return self.alpha * dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'
|
1448
|
+
|
1449
|
+
|
1450
|
+
def gaussian_grad(
    x: jax.Array,
    sigma: float = 0.5,
    alpha: float = 0.5,
):
    r"""Spike function with the Gaussian gradient function [1]_.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Backward function:

    .. math::

        g'(x) = \alpha * \text{gaussian}(x, 0., \sigma)

    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-3, 3, 1000)
        >>> for s in [0.5, 1., 2.]:
        >>>     grads = bst.transform.vector_grad(nn.surrogate.gaussian_grad)(xs, s, 0.5)
        >>>     plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
        >>> plt.legend()
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    sigma: float
        The parameter to control the variance of gaussian distribution.
    alpha: float
        The parameter to control the scale of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
    """
    surrogate = GaussianGrad(sigma=sigma, alpha=alpha)
    return surrogate(x)
|
1504
|
+
|
1505
|
+
|
1506
|
+
class MultiGaussianGrad(Surrogate):
    """Judge spiking state with the multi-Gaussian gradient function.

    The surrogate gradient combines a central Gaussian of width ``sigma``
    (weighted ``1 + h``) with two wider Gaussians of width ``s * sigma``
    centred at ``+sigma`` and ``-sigma`` (each weighted ``-h``); the result
    is multiplied by ``scale``.

    See Also
    --------
    multi_gaussian_grad
    """

    def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
        super().__init__()
        self.h = h
        self.s = s
        self.sigma = sigma
        self.scale = scale

    def surrogate_grad(self, x):
        """Return ``scale * ((1 + h) * N(x; 0, sigma) - h * N(x; sigma, s*sigma) - h * N(x; -sigma, s*sigma))``."""
        root_two_pi = jnp.sqrt(2 * jnp.pi)

        # Central lobe.
        centre = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (root_two_pi * self.sigma)

        # The two side lobes share the same spread and normalisation.
        side_spread = 2 * jnp.power(self.s * self.sigma, 2)
        side_norm = root_two_pi * self.s * self.sigma
        right = jnp.exp(-(x - self.sigma) ** 2 / side_spread) / side_norm
        left = jnp.exp(-(x + self.sigma) ** 2 / side_spread) / side_norm

        combined = centre * (1. + self.h) - right * self.h - left * self.h
        return self.scale * combined

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'
|
1532
|
+
|
1533
|
+
|
1534
|
+
def multi_gaussian_grad(
    x: jax.Array,
    h: float = 0.15,
    s: float = 6.0,
    sigma: float = 0.5,
    scale: float = 0.5,
):
    r"""Spike function with the multi-Gaussian gradient function [1]_.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Backward function:

    .. math::

        \begin{array}{l}
        g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2})
        -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})-
        h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2})
        \end{array}


    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-3, 3, 1000)
        >>> grads = bst.transform.vector_grad(nn.surrogate.multi_gaussian_grad)(xs)
        >>> plt.plot(xs, grads)
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    h: float
        The hyper-parameters of approximate function
    s: float
        The hyper-parameters of approximate function
    sigma: float
        The gaussian sigma.
    scale: float
        The gradient scale.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
    """
    surrogate = MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)
    return surrogate(x)
|
1597
|
+
|
1598
|
+
|
1599
|
+
class InvSquareGrad(Surrogate):
    """Judge spiking state with the inverse-square surrogate gradient function.

    ``alpha`` scales ``|x|`` inside the squared denominator.

    See Also
    --------
    inv_square_grad
    """

    def __init__(self, alpha=100.):
        super().__init__()
        self.alpha = alpha

    def surrogate_grad(self, x):
        """Return ``1 / (alpha * |x| + 1) ** 2``."""
        squared_denominator = (self.alpha * jnp.abs(x) + 1.0) ** 2
        return 1. / squared_denominator

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(alpha={self.alpha})'
|
1617
|
+
|
1618
|
+
|
1619
|
+
def inv_square_grad(
    x: jax.Array,
    alpha: float = 100.
):
    r"""Spike function with the inverse-square surrogate gradient.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Backward function:

    .. math::

        g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2}


    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-1, 1, 1000)
        >>> for alpha in [1., 10., 100.]:
        >>>     grads = bst.transform.vector_grad(nn.surrogate.inv_square_grad)(xs, alpha)
        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
        >>> plt.legend()
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        Parameter to control smoothness of gradient

    Returns
    -------
    out: jax.Array
        The spiking state.
    """
    surrogate = InvSquareGrad(alpha=alpha)
    return surrogate(x)
|
1667
|
+
|
1668
|
+
|
1669
|
+
class SlayerGrad(Surrogate):
    """Judge spiking state with the slayer surrogate gradient function.

    ``alpha`` is the decay rate of the exponential surrogate gradient.

    See Also
    --------
    slayer_grad
    """

    def __init__(self, alpha=1.):
        super().__init__()
        self.alpha = alpha

    def surrogate_grad(self, x):
        """Return ``exp(-alpha * |x|)``."""
        exponent = -self.alpha * jnp.abs(x)
        return jnp.exp(exponent)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(alpha={self.alpha})'
|
1687
|
+
|
1688
|
+
|
1689
|
+
def slayer_grad(
    x: jax.Array,
    alpha: float = 1.
):
    r"""Spike function with the slayer surrogate gradient function [1]_.

    Forward function:

    .. math::

        g(x) = \begin{cases}
        1, & x \geq 0 \\
        0, & x < 0 \\
        \end{cases}

    Backward function:

    .. math::

        g'(x) = \exp(-\alpha |x|)


    .. plot::
        :include-source: True

        >>> import brainstate.nn as nn
        >>> import brainstate as bst
        >>> import matplotlib.pyplot as plt
        >>> xs = jax.numpy.linspace(-3, 3, 1000)
        >>> for alpha in [0.5, 1., 2., 4.]:
        >>>     grads = bst.transform.vector_grad(nn.surrogate.slayer_grad)(xs, alpha)
        >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
        >>> plt.legend()
        >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        Parameter to control smoothness of gradient

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018).
    """
    surrogate = SlayerGrad(alpha=alpha)
    return surrogate(x)
|