brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. brainstate/__init__.py +4 -5
  2. brainstate/_module.py +191 -48
  3. brainstate/_module_test.py +95 -21
  4. brainstate/_state.py +17 -0
  5. brainstate/environ.py +2 -2
  6. brainstate/functional/__init__.py +3 -2
  7. brainstate/functional/_activations.py +7 -26
  8. brainstate/functional/_normalization.py +3 -0
  9. brainstate/functional/_others.py +49 -0
  10. brainstate/functional/_spikes.py +0 -1
  11. brainstate/mixin.py +2 -2
  12. brainstate/nn/__init__.py +4 -0
  13. brainstate/nn/_base.py +10 -7
  14. brainstate/nn/_dynamics.py +20 -0
  15. brainstate/nn/_elementwise.py +5 -4
  16. brainstate/nn/_embedding.py +66 -0
  17. brainstate/nn/_misc.py +4 -3
  18. brainstate/nn/_others.py +3 -2
  19. brainstate/nn/_poolings.py +21 -20
  20. brainstate/nn/_poolings_test.py +4 -4
  21. brainstate/nn/_rate_rnns.py +17 -0
  22. brainstate/nn/_readout.py +6 -0
  23. brainstate/optim/__init__.py +0 -1
  24. brainstate/optim/_lr_scheduler_test.py +13 -0
  25. brainstate/optim/_sgd_optimizer.py +18 -17
  26. brainstate/transform/__init__.py +2 -3
  27. brainstate/transform/_autograd.py +1 -1
  28. brainstate/transform/_autograd_test.py +0 -2
  29. brainstate/transform/_jit.py +47 -21
  30. brainstate/transform/_jit_test.py +0 -3
  31. brainstate/transform/_make_jaxpr.py +164 -3
  32. brainstate/transform/_make_jaxpr_test.py +0 -2
  33. brainstate/transform/_progress_bar.py +1 -3
  34. brainstate/util.py +0 -1
  35. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +9 -17
  36. brainstate-0.0.1.post20240622.dist-info/RECORD +64 -0
  37. brainstate/math/__init__.py +0 -21
  38. brainstate/math/_einops.py +0 -787
  39. brainstate/math/_einops_parsing.py +0 -169
  40. brainstate/math/_einops_parsing_test.py +0 -126
  41. brainstate/math/_einops_test.py +0 -346
  42. brainstate/math/_misc.py +0 -298
  43. brainstate/math/_misc_test.py +0 -58
  44. brainstate/nn/functional/__init__.py +0 -25
  45. brainstate/nn/functional/_activations.py +0 -754
  46. brainstate/nn/functional/_normalization.py +0 -69
  47. brainstate/nn/functional/_spikes.py +0 -90
  48. brainstate/nn/init/__init__.py +0 -26
  49. brainstate/nn/init/_base.py +0 -36
  50. brainstate/nn/init/_generic.py +0 -175
  51. brainstate/nn/init/_random_inits.py +0 -489
  52. brainstate/nn/init/_regular_inits.py +0 -109
  53. brainstate/nn/surrogate.py +0 -1740
  54. brainstate-0.0.1.dist-info/RECORD +0 -79
  55. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
  56. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
  57. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
@@ -1,1740 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- import jax
- import jax.numpy as jnp
- import jax.scipy as sci
- from jax.core import Primitive
- from jax.interpreters import batching, ad, mlir
-
- __all__ = [
-   'Surrogate',
-   'Sigmoid',
-   'sigmoid',
-   'PiecewiseQuadratic',
-   'piecewise_quadratic',
-   'PiecewiseExp',
-   'piecewise_exp',
-   'SoftSign',
-   'soft_sign',
-   'Arctan',
-   'arctan',
-   'NonzeroSignLog',
-   'nonzero_sign_log',
-   'ERF',
-   'erf',
-   'PiecewiseLeakyRelu',
-   'piecewise_leaky_relu',
-   'SquarewaveFourierSeries',
-   'squarewave_fourier_series',
-   'S2NN',
-   's2nn',
-   'QPseudoSpike',
-   'q_pseudo_spike',
-   'LeakyRelu',
-   'leaky_relu',
-   'LogTailedRelu',
-   'log_tailed_relu',
-   'ReluGrad',
-   'relu_grad',
-   'GaussianGrad',
-   'gaussian_grad',
-   'InvSquareGrad',
-   'inv_square_grad',
-   'MultiGaussianGrad',
-   'multi_gaussian_grad',
-   'SlayerGrad',
-   'slayer_grad',
- ]
-
-
- def _heaviside_abstract(x, dx):
-   return [x]
-
-
- def _heaviside_imp(x, dx):
-   z = jnp.asarray(x >= 0, dtype=x.dtype)
-   return [z]
-
-
- def _heaviside_batching(args, axes):
-   return heaviside_p.bind(*args), [axes[0]]
-
-
- def _heaviside_jvp(primals, tangents):
-   x, dx = primals
-   tx, tdx = tangents
-   primal_outs = heaviside_p.bind(x, dx)
-   tangent_outs = [dx * tx, ]
-   return primal_outs, tangent_outs
-
-
- heaviside_p = Primitive('heaviside_p')
- heaviside_p.multiple_results = True
- heaviside_p.def_abstract_eval(_heaviside_abstract)
- heaviside_p.def_impl(_heaviside_imp)
- batching.primitive_batchers[heaviside_p] = _heaviside_batching
- ad.primitive_jvps[heaviside_p] = _heaviside_jvp
- mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True))
-
-
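The block above is the core trick of the whole module: `heaviside_p` evaluates a hard step in the forward pass, while its JVP rule substitutes the user-supplied `dx` term as the gradient. A minimal sketch (not part of the package; it assumes the `heaviside_p` definitions above are in scope) of how autodiff then behaves:

    import jax
    import jax.numpy as jnp

    def spike(x):
        # bind with a constant surrogate gradient of 1 everywhere
        return heaviside_p.bind(x, jnp.ones_like(x))[0]

    x = jnp.array([-1.0, 0.0, 2.0])
    print(spike(x))                               # [0. 1. 1.]  (hard step)
    print(jax.grad(lambda v: spike(v).sum())(x))  # [1. 1. 1.]  (the dx we bound)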
- class Surrogate(object):
-   """The base surrogate gradient function.
-
-   To customize a surrogate gradient function, you can inherit this class and
-   implement the `surrogate_fun` and `surrogate_grad` methods.
-
-   Examples
-   --------
-
-   >>> import brainstate as bst
-   >>> import brainstate.nn as nn
-   >>> import jax.numpy as jnp
-
-   >>> class MySurrogate(nn.surrogate.Surrogate):
-   ...   def __init__(self, alpha=1.):
-   ...     super().__init__()
-   ...     self.alpha = alpha
-   ...
-   ...   def surrogate_fun(self, x):
-   ...     return jnp.sin(x) * self.alpha
-   ...
-   ...   def surrogate_grad(self, x):
-   ...     return jnp.cos(x) * self.alpha
-
-   """
-
-   def __call__(self, x):
-     dx = self.surrogate_grad(x)
-     return heaviside_p.bind(x, dx)[0]
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}()'
-
-   def surrogate_fun(self, x) -> jax.Array:
-     """The surrogate function."""
-     raise NotImplementedError
-
-   def surrogate_grad(self, x) -> jax.Array:
-     """The gradient function of the surrogate function."""
-     raise NotImplementedError
-
-
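`Surrogate.__call__` is what every subclass below inherits: the forward value is the step function, the backward value is whatever `surrogate_grad(x)` returns. The same forward/backward split can be written without a custom primitive; here is a self-contained sketch of the pattern using `jax.custom_jvp` (an illustration only, not the brainstate API, with alpha = 4 assumed):

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def sigmoid_spike(x):
        return jnp.asarray(x >= 0, dtype=x.dtype)   # hard step forward

    @sigmoid_spike.defjvp
    def _sigmoid_spike_jvp(primals, tangents):
        (x,), (tx,) = primals, tangents
        s = jax.scipy.special.expit(4.0 * x)        # sigmoid surrogate, alpha = 4
        return sigmoid_spike(x), 4.0 * s * (1.0 - s) * tx

    xs = jnp.linspace(-1.0, 1.0, 5)
    print(jax.vmap(jax.grad(sigmoid_spike))(xs))    # bell-shaped surrogate gradient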
- class Sigmoid(Surrogate):
-   """Spike function with the sigmoid-shaped surrogate gradient.
-
-   See Also
-   --------
-   sigmoid
-
-   """
-
-   def __init__(self, alpha: float = 4.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_fun(self, x):
-     return sci.special.expit(self.alpha * x)
-
-   def surrogate_grad(self, x):
-     sgax = sci.special.expit(x * self.alpha)
-     dx = (1. - sgax) * sgax * self.alpha
-     return dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def sigmoid(
-     x: jax.Array,
-     alpha: float = 4.,
- ):
-   r"""Spike function with the sigmoid-shaped surrogate gradient.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \alpha (1 - \mathrm{sigmoid}(\alpha x)) \, \mathrm{sigmoid}(\alpha x)
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-2, 2, 1000)
-      >>> for alpha in [1., 2., 4.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.sigmoid)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     Parameter to control smoothness of gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-   """
-   return Sigmoid(alpha=alpha)(x)
-
-
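As a quick numeric check of the formulas above (a sketch that assumes brainstate 0.0.1, where `brainstate.nn.surrogate` still exists), the gradient autodiff reports for `sigmoid` matches the stated backward function alpha * s * (1 - s):

    import jax
    import jax.numpy as jnp
    import brainstate.nn as nn

    alpha = 4.0
    xs = jnp.linspace(-2.0, 2.0, 7)
    g = jax.vmap(jax.grad(lambda v: nn.surrogate.sigmoid(v, alpha)))(xs)
    s = jax.scipy.special.expit(alpha * xs)
    print(jnp.allclose(g, alpha * s * (1.0 - s)))   # True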
- class PiecewiseQuadratic(Surrogate):
-   """Judge spiking state with a piecewise quadratic function.
-
-   See Also
-   --------
-   piecewise_quadratic
-
-   """
-
-   def __init__(self, alpha: float = 1.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_fun(self, x):
-     z = jnp.where(x < -1 / self.alpha,
-                   0.,
-                   jnp.where(x > 1 / self.alpha,
-                             1.,
-                             (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5))
-     return z
-
-   def surrogate_grad(self, x):
-     dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., -self.alpha ** 2 * jnp.abs(x) + self.alpha)
-     return dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def piecewise_quadratic(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-   r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      g(x) =
-      \begin{cases}
-      0, & x < -\frac{1}{\alpha} \\
-      -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
-      1, & x > \frac{1}{\alpha} \\
-      \end{cases}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) =
-      \begin{cases}
-      0, & |x| > \frac{1}{\alpha} \\
-      -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha}
-      \end{cases}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for alpha in [0.5, 1., 2., 4.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.piecewise_quadratic)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     Parameter to control smoothness of gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the National Academy of Sciences, 2016, 113(41): 11441-11446.
-   .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
-   .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
-   .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
-   .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
-   """
-   return PiecewiseQuadratic(alpha=alpha)(x)
-
-
- class PiecewiseExp(Surrogate):
-   """Judge spiking state with a piecewise exponential function.
-
-   See Also
-   --------
-   piecewise_exp
-   """
-
-   def __init__(self, alpha: float = 1.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_grad(self, x):
-     dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
-     return dx
-
-   def surrogate_fun(self, x):
-     return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2)
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def piecewise_exp(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-   r"""Judge spiking state with a piecewise exponential function [1]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      \frac{1}{2}e^{\alpha x}, & x < 0 \\
-      1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
-      \end{cases}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for alpha in [0.5, 1., 2., 4.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.piecewise_exp)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     Parameter to control smoothness of gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
-   """
-   return PiecewiseExp(alpha=alpha)(x)
-
-
- class SoftSign(Surrogate):
-   """Judge spiking state with a soft sign function.
-
-   See Also
-   --------
-   soft_sign
-   """
-
-   def __init__(self, alpha=1.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_grad(self, x):
-     dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2
-     return dx
-
-   def surrogate_fun(self, x):
-     return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def soft_sign(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-   r"""Judge spiking state with a soft sign function.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      g(x) = \frac{1}{2} \left(\frac{\alpha x}{1 + |\alpha x|} + 1\right)
-           = \frac{1}{2} \left(\frac{x}{\frac{1}{\alpha} + |x|} + 1\right)
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for alpha in [0.5, 1., 2., 4.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.soft_sign)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     Parameter to control smoothness of gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   """
-   return SoftSign(alpha=alpha)(x)
-
-
- class Arctan(Surrogate):
-   """Judge spiking state with an arctan function.
-
-   See Also
-   --------
-   arctan
-   """
-
-   def __init__(self, alpha=1.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_grad(self, x):
-     dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
-     return dx
-
-   def surrogate_fun(self, x):
-     return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def arctan(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-   r"""Judge spiking state with an arctan function.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for alpha in [0.5, 1., 2., 4.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.arctan)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     Parameter to control smoothness of gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   """
-   return Arctan(alpha=alpha)(x)
-
-
- class NonzeroSignLog(Surrogate):
-   """Judge spiking state with a nonzero sign log function.
-
-   See Also
-   --------
-   nonzero_sign_log
-   """
-
-   def __init__(self, alpha=1.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_grad(self, x):
-     dx = 1. / (1 / self.alpha + jnp.abs(x))
-     return dx
-
-   def surrogate_fun(self, x):
-     return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1)
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def nonzero_sign_log(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-   r"""Judge spiking state with a nonzero sign log function.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)
-
-   where
-
-   .. math::
-
-      \begin{split}\mathrm{NonzeroSign}(x) =
-      \begin{cases}
-      1, & x \geq 0 \\
-      -1, & x < 0 \\
-      \end{cases}\end{split}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}
-
-   This surrogate function has the advantage of a low computational cost in the backward pass.
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for alpha in [0.5, 1., 2., 4.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.nonzero_sign_log)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     Parameter to control smoothness of gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   """
-   return NonzeroSignLog(alpha=alpha)(x)
-
-
- class ERF(Surrogate):
-   """Judge spiking state with an erf function.
-
-   See Also
-   --------
-   erf
-   """
-
-   def __init__(self, alpha=1.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_grad(self, x):
-     dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
-     return dx
-
-   def surrogate_fun(self, x):
-     return (1. - sci.special.erf(-self.alpha * x)) * 0.5
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def erf(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-   r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      \begin{split}
-      g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
-      &= \frac{1}{2} \text{erfc}(-\alpha x) \\
-      &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
-      \end{split}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for alpha in [0.5, 1., 2., 4.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.erf)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     Parameter to control smoothness of gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in Neural Information Processing Systems, 2015, 28: 1117-1125.
-   .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
-   .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
-
-   """
-   return ERF(alpha=alpha)(x)
-
-
- class PiecewiseLeakyRelu(Surrogate):
-   """Judge spiking state with a piecewise leaky relu function.
-
-   See Also
-   --------
-   piecewise_leaky_relu
-   """
-
-   def __init__(self, c=0.01, w=1.):
-     super().__init__()
-     self.c = c
-     self.w = w
-
-   def surrogate_fun(self, x):
-     z = jnp.where(x < -self.w,
-                   self.c * x + self.c * self.w,
-                   jnp.where(x > self.w,
-                             self.c * x - self.c * self.w + 1,
-                             0.5 * x / self.w + 0.5))
-     return z
-
-   def surrogate_grad(self, x):
-     dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w)
-     return dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
-
-
- def piecewise_leaky_relu(
-     x: jax.Array,
-     c: float = 0.01,
-     w: float = 1.,
- ):
-   r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      \begin{split}g(x) =
-      \begin{cases}
-      cx + cw, & x < -w \\
-      \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
-      cx - cw + 1, & x > w \\
-      \end{cases}\end{split}
-
-   Backward function:
-
-   .. math::
-
-      \begin{split}g'(x) =
-      \begin{cases}
-      \frac{1}{w}, & |x| \leq w \\
-      c, & |x| > w
-      \end{cases}\end{split}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for c in [0.01, 0.05, 0.1]:
-      ...   for w in [1., 2.]:
-      ...     grads1 = bst.transform.vector_grad(nn.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
-      ...     plt.plot(xs, grads1, label=f'c={c}, w={w}')
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   c: float
-     When :math:`|x| > w` the gradient is `c`.
-   w: float
-     When :math:`|x| \leq w` the gradient is `1 / w`.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
-   .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
-   .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
-   .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
-   .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
-   .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
-   .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525.
-   .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.
-
-   """
-   return PiecewiseLeakyRelu(c=c, w=w)(x)
-
-
- class SquarewaveFourierSeries(Surrogate):
-   """Judge spiking state with a squarewave Fourier series.
-
-   See Also
-   --------
-   squarewave_fourier_series
-   """
-
-   def __init__(self, n=2, t_period=8.):
-     super().__init__()
-     self.n = n
-     self.t_period = t_period
-
-   def surrogate_grad(self, x):
-     w = jnp.pi * 2. / self.t_period
-     dx = jnp.cos(w * x)
-     for i in range(2, self.n):
-       dx += jnp.cos((2 * i - 1.) * w * x)
-     dx *= 4. / self.t_period
-     return dx
-
-   def surrogate_fun(self, x):
-     w = jnp.pi * 2. / self.t_period
-     ret = jnp.sin(w * x)
-     for i in range(2, self.n):
-       c = (2 * i - 1.)
-       ret += jnp.sin(c * w * x) / c
-     z = 0.5 + 2. / jnp.pi * ret
-     return z
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
-
-
- def squarewave_fourier_series(
-     x: jax.Array,
-     n: int = 2,
-     t_period: float = 8.,
- ):
-   r"""Judge spiking state with a squarewave Fourier series.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      g(x) = 0.5 + \frac{2}{\pi}\sum_{i=1}^n \frac{\sin\left((2i-1) \, 2\pi x/T\right)}{2i-1}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \sum_{i=1}^n \frac{4\cos\left((2i-1) \, 2\pi x/T\right)}{T}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for n in [2, 4, 8]:
-      ...   f = nn.surrogate.SquarewaveFourierSeries(n=n)
-      ...   grads1 = bst.transform.vector_grad(f)(xs)
-      ...   plt.plot(xs, grads1, label=f'n={n}')
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   n: int
-     The number of terms in the Fourier series.
-   t_period: float
-     The period :math:`T` of the squarewave.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   """
-   return SquarewaveFourierSeries(n=n, t_period=t_period)(x)
-
-
- class S2NN(Surrogate):
-   """Judge spiking state with the S2NN surrogate spiking function.
-
-   See Also
-   --------
-   s2nn
-   """
-
-   def __init__(self, alpha=4., beta=1., epsilon=1e-8):
-     super().__init__()
-     self.alpha = alpha
-     self.beta = beta
-     self.epsilon = epsilon
-
-   def surrogate_fun(self, x):
-     z = jnp.where(x < 0.,
-                   sci.special.expit(x * self.alpha),
-                   self.beta * jnp.log(jnp.abs(x + 1.) + self.epsilon) + 0.5)
-     return z
-
-   def surrogate_grad(self, x):
-     sg = sci.special.expit(self.alpha * x)
-     dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.))
-     return dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
-
-
- def s2nn(
-     x: jax.Array,
-     alpha: float = 4.,
-     beta: float = 1.,
-     epsilon: float = 1e-8,
- ):
-   r"""Judge spiking state with the S2NN surrogate spiking function [1]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      \begin{split}g(x) = \begin{cases}
-      \mathrm{sigmoid}(\alpha x), & x < 0 \\
-      \beta \ln(|x + 1|) + 0.5, & x \ge 0
-      \end{cases}\end{split}
-
-   Backward function:
-
-   .. math::
-
-      \begin{split}g'(x) = \begin{cases}
-      \alpha (1 - \mathrm{sigmoid}(\alpha x)) \, \mathrm{sigmoid}(\alpha x), & x < 0 \\
-      \frac{\beta}{x + 1}, & x \ge 0
-      \end{cases}\end{split}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> grads = bst.transform.vector_grad(nn.surrogate.s2nn)(xs, 4., 1.)
-      >>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
-      >>> grads = bst.transform.vector_grad(nn.surrogate.s2nn)(xs, 8., 2.)
-      >>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     The parameter that controls the gradient when ``x < 0``.
-   beta: float
-     The parameter that controls the gradient when ``x >= 0``.
-   epsilon: float
-     A small value to avoid NaN from the logarithm.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Suetake, Kazuma et al. "S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks." ArXiv abs/2201.10879 (2022): n. pag.
-
-   """
-   return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x)
-
-
- class QPseudoSpike(Surrogate):
-   """Judge spiking state with the q-PseudoSpike surrogate function.
-
-   See Also
-   --------
-   q_pseudo_spike
-   """
-
-   def __init__(self, alpha=2.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_grad(self, x):
-     dx = jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), -self.alpha)
-     return dx
-
-   def surrogate_fun(self, x):
-     z = jnp.where(x < 0.,
-                   0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
-                   1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
-     return z
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def q_pseudo_spike(
-     x: jax.Array,
-     alpha: float = 2.,
- ):
-   r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      \begin{split}g(x) =
-      \begin{cases}
-      \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
-      1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
-      \end{cases}\end{split}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for alpha in [0.5, 1., 2., 4.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.q_pseudo_spike)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     The parameter to control tail fatness of the gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Herranz-Celotti, Luca and Jean Rouat. "Surrogate Gradients Design." ArXiv abs/2202.00282 (2022): n. pag.
-   """
-   return QPseudoSpike(alpha=alpha)(x)
-
-
- class LeakyRelu(Surrogate):
-   """Judge spiking state with the Leaky ReLU function.
-
-   See Also
-   --------
-   leaky_relu
-   """
-
-   def __init__(self, alpha=0.1, beta=1.):
-     super().__init__()
-     self.alpha = alpha
-     self.beta = beta
-
-   def surrogate_fun(self, x):
-     return jnp.where(x < 0., self.alpha * x, self.beta * x)
-
-   def surrogate_grad(self, x):
-     dx = jnp.where(x < 0., self.alpha, self.beta)
-     return dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'
-
-
- def leaky_relu(
-     x: jax.Array,
-     alpha: float = 0.1,
-     beta: float = 1.,
- ):
-   r"""Judge spiking state with the Leaky ReLU function.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      \begin{split}g(x) =
-      \begin{cases}
-      \beta \cdot x, & x \geq 0 \\
-      \alpha \cdot x, & x < 0 \\
-      \end{cases}\end{split}
-
-   Backward function:
-
-   .. math::
-
-      \begin{split}g'(x) =
-      \begin{cases}
-      \beta, & x \geq 0 \\
-      \alpha, & x < 0 \\
-      \end{cases}\end{split}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> grads = bst.transform.vector_grad(nn.surrogate.leaky_relu)(xs, 0., 1.)
-      >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     The parameter to control the gradient when :math:`x < 0`.
-   beta: float
-     The parameter to control the gradient when :math:`x \geq 0`.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-   """
-   return LeakyRelu(alpha=alpha, beta=beta)(x)
-
-
- class LogTailedRelu(Surrogate):
-   """Judge spiking state with the Log-tailed ReLU function.
-
-   See Also
-   --------
-   log_tailed_relu
-   """
-
-   def __init__(self, alpha=0.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_fun(self, x):
-     z = jnp.where(x > 1,
-                   jnp.log(x),
-                   jnp.where(x > 0,
-                             x,
-                             self.alpha * x))
-     return z
-
-   def surrogate_grad(self, x):
-     dx = jnp.where(x > 1,
-                    1 / x,
-                    jnp.where(x > 0,
-                              1.,
-                              self.alpha))
-     return dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def log_tailed_relu(
-     x: jax.Array,
-     alpha: float = 0.,
- ):
-   r"""Judge spiking state with the Log-tailed ReLU function [1]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   The original function:
-
-   .. math::
-
-      \begin{split}g(x) =
-      \begin{cases}
-      \alpha x, & x \leq 0 \\
-      x, & 0 < x \leq 1 \\
-      \log(x), & x > 1 \\
-      \end{cases}\end{split}
-
-   Backward function:
-
-   .. math::
-
-      \begin{split}g'(x) =
-      \begin{cases}
-      \alpha, & x \leq 0 \\
-      1, & 0 < x \leq 1 \\
-      \frac{1}{x}, & x > 1 \\
-      \end{cases}\end{split}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> grads = bst.transform.vector_grad(nn.surrogate.log_tailed_relu)(xs, 0.)
-      >>> plt.plot(xs, grads, label=r'$\alpha=0.$')
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     The parameter to control the gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Cai, Zhaowei et al. "Deep Learning with Low Precision by Half-Wave Gaussian Quantization." 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414.
-   """
-   return LogTailedRelu(alpha=alpha)(x)
-
-
- class ReluGrad(Surrogate):
-   """Judge spiking state with the ReLU gradient function.
-
-   See Also
-   --------
-   relu_grad
-   """
-
-   def __init__(self, alpha=0.3, width=1.):
-     super().__init__()
-     self.alpha = alpha
-     self.width = width
-
-   def surrogate_grad(self, x):
-     dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0)
-     return dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'
-
-
- def relu_grad(
-     x: jax.Array,
-     alpha: float = 0.3,
-     width: float = 1.,
- ):
-   r"""Spike function with the ReLU gradient function [1]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \text{ReLU}(\alpha (\mathrm{width}-|x|))
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for s in [0.5, 1.]:
-      ...   for w in [1., 2.]:
-      ...     grads = bst.transform.vector_grad(nn.surrogate.relu_grad)(xs, s, w)
-      ...     plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     The parameter to control the gradient.
-   width: float
-     The parameter to control the width of the gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61-63 (2019).
-   """
-   return ReluGrad(alpha=alpha, width=width)(x)
-
-
- class GaussianGrad(Surrogate):
-   """Judge spiking state with the Gaussian gradient function.
-
-   See Also
-   --------
-   gaussian_grad
-   """
-
-   def __init__(self, sigma=0.5, alpha=0.5):
-     super().__init__()
-     self.sigma = sigma
-     self.alpha = alpha
-
-   def surrogate_grad(self, x):
-     dx = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
-     return self.alpha * dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'
-
-
- def gaussian_grad(
-     x: jax.Array,
-     sigma: float = 0.5,
-     alpha: float = 0.5,
- ):
-   r"""Spike function with the Gaussian gradient function [1]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \alpha \, \mathcal{N}(x, 0, \sigma^{2})
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for s in [0.5, 1., 2.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.gaussian_grad)(xs, s, 0.5)
-      ...   plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   sigma: float
-     The parameter to control the variance of the Gaussian distribution.
-   alpha: float
-     The parameter to control the scale of the gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905-913 (2021).
-   """
-   return GaussianGrad(sigma=sigma, alpha=alpha)(x)
-
-
- class MultiGaussianGrad(Surrogate):
-   """Judge spiking state with the multi-Gaussian gradient function.
-
-   See Also
-   --------
-   multi_gaussian_grad
-   """
-
-   def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
-     super().__init__()
-     self.h = h
-     self.s = s
-     self.sigma = sigma
-     self.scale = scale
-
-   def surrogate_grad(self, x):
-     g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
-     g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
-                  ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
-     g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
-                  ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
-     dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h
-     return self.scale * dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'
-
-
- def multi_gaussian_grad(
-     x: jax.Array,
-     h: float = 0.15,
-     s: float = 6.0,
-     sigma: float = 0.5,
-     scale: float = 0.5,
- ):
-   r"""Spike function with the multi-Gaussian gradient function [1]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = (1+h)\,\mathcal{N}(x, 0, \sigma^{2})
-              - h\,\mathcal{N}(x, \sigma, (s\sigma)^{2})
-              - h\,\mathcal{N}(x, -\sigma, (s\sigma)^{2})
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> grads = bst.transform.vector_grad(nn.surrogate.multi_gaussian_grad)(xs)
-      >>> plt.plot(xs, grads)
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   h: float
-     A hyper-parameter of the approximate function.
-   s: float
-     A hyper-parameter of the approximate function.
-   sigma: float
-     The Gaussian sigma.
-   scale: float
-     The gradient scale.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905-913 (2021).
-   """
-   return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x)
-
-
- class InvSquareGrad(Surrogate):
-   """Judge spiking state with the inverse-square surrogate gradient function.
-
-   See Also
-   --------
-   inv_square_grad
-   """
-
-   def __init__(self, alpha=100.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_grad(self, x):
-     dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2
-     return dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def inv_square_grad(
-     x: jax.Array,
-     alpha: float = 100.
- ):
-   r"""Spike function with the inverse-square surrogate gradient.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \frac{1}{(\alpha |x| + 1)^{2}}
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-1, 1, 1000)
-      >>> for alpha in [1., 10., 100.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.inv_square_grad)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     Parameter to control smoothness of gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-   """
-   return InvSquareGrad(alpha=alpha)(x)
-
-
- class SlayerGrad(Surrogate):
-   """Judge spiking state with the SLAYER surrogate gradient function.
-
-   See Also
-   --------
-   slayer_grad
-   """
-
-   def __init__(self, alpha=1.):
-     super().__init__()
-     self.alpha = alpha
-
-   def surrogate_grad(self, x):
-     dx = jnp.exp(-self.alpha * jnp.abs(x))
-     return dx
-
-   def __repr__(self):
-     return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-
- def slayer_grad(
-     x: jax.Array,
-     alpha: float = 1.
- ):
-   r"""Spike function with the SLAYER surrogate gradient function [1]_.
-
-   The forward function:
-
-   .. math::
-
-      g(x) = \begin{cases}
-      1, & x \geq 0 \\
-      0, & x < 0 \\
-      \end{cases}
-
-   Backward function:
-
-   .. math::
-
-      g'(x) = \exp(-\alpha |x|)
-
-   .. plot::
-      :include-source: True
-
-      >>> import jax
-      >>> import brainstate.nn as nn
-      >>> import brainstate as bst
-      >>> import matplotlib.pyplot as plt
-      >>> xs = jax.numpy.linspace(-3, 3, 1000)
-      >>> for alpha in [0.5, 1., 2., 4.]:
-      ...   grads = bst.transform.vector_grad(nn.surrogate.slayer_grad)(xs, alpha)
-      ...   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-      >>> plt.legend()
-      >>> plt.show()
-
-   Parameters
-   ----------
-   x: jax.Array
-     The input data.
-   alpha: float
-     Parameter to control smoothness of gradient.
-
-   Returns
-   -------
-   out: jax.Array
-     The spiking state.
-
-   References
-   ----------
-   .. [1] Shrestha, S. B. & Orchard, G. SLAYER: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412-1421 (NeurIPS, 2018).
-   """
-   return SlayerGrad(alpha=alpha)(x)
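What all of these surrogates are for, in one picture: they make the threshold crossing of a spiking neuron differentiable, so a whole simulation can be trained end to end with `jax.grad`. A self-contained sketch (illustrative only; the toy neuron model, constants, and spike-count loss are made up for the example and are not brainstate API):

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def spike(x):
        return jnp.asarray(x >= 0, dtype=x.dtype)   # hard threshold forward

    @spike.defjvp
    def _spike_jvp(primals, tangents):
        (x,), (tx,) = primals, tangents
        s = jax.scipy.special.expit(4.0 * x)        # sigmoid surrogate backward
        return spike(x), 4.0 * s * (1.0 - s) * tx

    def loss(w):
        v, n_spk = 0.0, 0.0
        for _ in range(10):                         # 10 time steps
            v = 0.9 * v + w * 0.5                   # leaky integration of the input
            s = spike(v - 1.0)                      # threshold at v = 1
            v = v * (1.0 - s)                       # reset on spike
            n_spk = n_spk + s
        return (n_spk - 5.0) ** 2                   # target: 5 spikes

    print(jax.grad(loss)(jnp.float32(1.0)))         # nonzero: gradient flows through spikes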