brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
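Most of the churn above is a package reorganization rather than new behavior: `brainstate.compile` and `brainstate.augment` are folded into a single `brainstate.transform` package, `brainstate/init` moves to `brainstate/nn/init.py`, and several top-level modules (`surrogate.py`, the `optim` and `functional` packages, `util/scaling.py`) are removed outright. A minimal migration sketch, assuming the public import paths follow the file moves listed above; the exact 0.2.0 re-exports are not visible in this diff, so the new paths below are unverified:

    # brainstate 0.1.10 (old layout)
    from brainstate.compile import jit          # compile/_jit.py
    from brainstate.augment import vector_grad  # augment/_autograd.py

    # brainstate 0.2.0 (new layout, inferred from the renames above)
    from brainstate.transform import jit, vector_grad  # transform/_jit.py, transform/_autograd.py
    from brainstate.nn import init                     # nn/init.py (was init/_random_inits.py)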
brainstate/surrogate.py DELETED
@@ -1,1957 +0,0 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- # -*- coding: utf-8 -*-
-
- import jax
- import jax.numpy as jnp
- import jax.scipy as sci
- from jax.interpreters import batching, ad, mlir
-
- from brainstate._compatible_import import Primitive
- from brainstate.util import PrettyObject
-
- __all__ = [
-     'Surrogate',
-     'Sigmoid',
-     'sigmoid',
-     'PiecewiseQuadratic',
-     'piecewise_quadratic',
-     'PiecewiseExp',
-     'piecewise_exp',
-     'SoftSign',
-     'soft_sign',
-     'Arctan',
-     'arctan',
-     'NonzeroSignLog',
-     'nonzero_sign_log',
-     'ERF',
-     'erf',
-     'PiecewiseLeakyRelu',
-     'piecewise_leaky_relu',
-     'SquarewaveFourierSeries',
-     'squarewave_fourier_series',
-     'S2NN',
-     's2nn',
-     'QPseudoSpike',
-     'q_pseudo_spike',
-     'LeakyRelu',
-     'leaky_relu',
-     'LogTailedRelu',
-     'log_tailed_relu',
-     'ReluGrad',
-     'relu_grad',
-     'GaussianGrad',
-     'gaussian_grad',
-     'InvSquareGrad',
-     'inv_square_grad',
-     'MultiGaussianGrad',
-     'multi_gaussian_grad',
-     'SlayerGrad',
-     'slayer_grad',
- ]
-
-
- def _heaviside_abstract(x, dx):
-     return [x]
-
-
- def _heaviside_imp(x, dx):
-     z = jnp.asarray(x >= 0, dtype=x.dtype)
-     return [z]
-
-
- def _heaviside_batching(args, axes):
-     x, dx = args
-     if axes[0] != axes[1]:
-         dx = batching.moveaxis(dx, axes[1], axes[0])
-     return heaviside_p.bind(x, dx), tuple([axes[0]])
-
-
- def _heaviside_jvp(primals, tangents):
-     x, dx = primals
-     tx, tdx = tangents
-     primal_outs = heaviside_p.bind(x, dx)
-     tangent_outs = [dx * tx, ]
-     return primal_outs, tangent_outs
-
-
- heaviside_p = Primitive('heaviside_p')
- heaviside_p.multiple_results = True
- heaviside_p.def_abstract_eval(_heaviside_abstract)
- heaviside_p.def_impl(_heaviside_imp)
- batching.primitive_batchers[heaviside_p] = _heaviside_batching
- ad.primitive_jvps[heaviside_p] = _heaviside_jvp
- mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_results=True))
-
-
- class Surrogate(PrettyObject):
-     """The base surrogate gradient function.
-
-     To customize a surrogate gradient function, you can inherit this class and
-     implement the `surrogate_fun` and `surrogate_grad` methods.
-
-     Examples
-     --------
-
-     >>> import brainstate
-     >>> import jax.numpy as jnp
-
-     >>> class MySurrogate(brainstate.surrogate.Surrogate):
-     ...     def __init__(self, alpha=1.):
-     ...         super().__init__()
-     ...         self.alpha = alpha
-     ...
-     ...     def surrogate_fun(self, x):
-     ...         return jnp.sin(x) * self.alpha
-     ...
-     ...     def surrogate_grad(self, x):
-     ...         return jnp.cos(x) * self.alpha
-
-     """
-
-     def __call__(self, x):
-         dx = self.surrogate_grad(x)
-         return heaviside_p.bind(x, dx)[0]
-
-     def surrogate_fun(self, x) -> jax.Array:
-         """The surrogate function."""
-         raise NotImplementedError
-
-     def surrogate_grad(self, x) -> jax.Array:
-         """The gradient function of the surrogate function."""
-         raise NotImplementedError
-
-
- class Sigmoid(Surrogate):
-     """Spike function with the sigmoid-shaped surrogate gradient.
-
-     This class implements a spiking neuron activation with a sigmoid-shaped
-     surrogate gradient for backpropagation. It can be used in spiking neural
-     networks to approximate the non-differentiable step function during training.
-
-     Parameters
-     ----------
-     alpha : float, optional
-         A parameter controlling the steepness of the sigmoid curve in the
-         surrogate gradient. Higher values make the transition sharper.
-         Default is 4.0.
-
-     See Also
-     --------
-     sigmoid : Function version of this class.
-
-     """
-
-     def __init__(self, alpha: float = 4.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_fun(self, x):
-         """Compute the surrogate function.
-
-         Parameters
-         ----------
-         x : jax.Array
-             The input array.
-
-         Returns
-         -------
-         jax.Array
-             The output of the surrogate function.
-         """
-         return sci.special.expit(self.alpha * x)
-
-     def surrogate_grad(self, x):
-         """Compute the gradient of the surrogate function.
-
-         Parameters
-         ----------
-         x : jax.Array
-             The input array.
-
-         Returns
-         -------
-         jax.Array
-             The gradient of the surrogate function.
-         """
-         sgax = sci.special.expit(x * self.alpha)
-         dx = (1. - sgax) * sgax * self.alpha
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def sigmoid(
-     x: jax.Array,
-     alpha: float = 4.,
- ):
-     r"""
-     Compute a spike function with a sigmoid-shaped surrogate gradient.
-
-     This function implements a spiking neuron activation with a sigmoid-shaped
-     surrogate gradient for backpropagation. It can be used in spiking neural
-     networks to approximate the non-differentiable step function during training.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \alpha (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x)
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-2, 2, 1000)
-        >>> for alpha in [1., 2., 4.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.sigmoid)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input array representing the neuron's membrane potential.
-     alpha : float, optional
-         A parameter controlling the steepness of the sigmoid curve in the
-         surrogate gradient. Higher values make the transition sharper.
-         Default is 4.0.
-
-     Returns
-     -------
-     jax.Array
-         An array of the same shape as the input, containing binary values (0 or 1)
-         representing the spiking state of each neuron.
-
-     Notes
-     -----
-     The forward pass uses a step function (1 for x >= 0, 0 for x < 0),
-     while the backward pass uses a sigmoid-shaped surrogate gradient for
-     smooth optimization.
-
-     The surrogate gradient is defined as:
-     g'(x) = alpha * (1 - sigmoid(alpha * x)) * sigmoid(alpha * x)
-
-     """
-     return Sigmoid(alpha=alpha)(x)
-
-
- class PiecewiseQuadratic(Surrogate):
-     """Judge spiking state with a piecewise quadratic function.
-
-     See Also
-     --------
-     piecewise_quadratic
-
-     """
-
-     def __init__(self, alpha: float = 1.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_fun(self, x):
-         z = jnp.where(
-             x < -1 / self.alpha,
-             0.,
-             jnp.where(
-                 x > 1 / self.alpha,
-                 1.,
-                 (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5
-             )
-         )
-         return z
-
-     def surrogate_grad(self, x):
-         dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., -self.alpha ** 2 * jnp.abs(x) + self.alpha)
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def piecewise_quadratic(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-     r"""
-     Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
-
-     This function implements a surrogate gradient method for spiking neural networks
-     using a piecewise quadratic function. It provides a differentiable approximation
-     of the step function used in the forward pass of spiking neurons.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        g(x) =
-        \begin{cases}
-        0, & x < -\frac{1}{\alpha} \\
-        -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
-        1, & x > \frac{1}{\alpha} \\
-        \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) =
-        \begin{cases}
-        0, & |x| > \frac{1}{\alpha} \\
-        -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha}
-        \end{cases}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for alpha in [0.5, 1., 2., 4.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_quadratic)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input array representing the neuron's membrane potential.
-     alpha : float, optional
-         A parameter controlling the steepness of the surrogate gradient.
-         Higher values result in a steeper gradient. Default is 1.0.
-
-     Returns
-     -------
-     jax.Array
-         An array of the same shape as the input, containing binary values (0 or 1)
-         representing the spiking state of each neuron.
-
-     Notes
-     -----
-     The function uses different computations for forward and backward passes:
-
-     - Forward: step function (1 for x >= 0, 0 for x < 0)
-     - Backward: piecewise quadratic function for a smooth gradient
-
-     The surrogate gradient is defined as:
-     g'(x) = 0                     if |x| > 1/alpha
-             -alpha^2|x| + alpha   if |x| <= 1/alpha
-
-     References
-     ----------
-     .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the National Academy of Sciences, 2016, 113(41): 11441-11446.
-     .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
-     .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
-     .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
-     .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
-     """
-     return PiecewiseQuadratic(alpha=alpha)(x)
-
-
- class PiecewiseExp(Surrogate):
-     """
-     Judge spiking state with a piecewise exponential function.
-
-     This class implements a surrogate gradient method for spiking neural networks
-     using a piecewise exponential function. It provides a differentiable approximation
-     of the step function used in the forward pass of spiking neurons.
-
-     Parameters
-     ----------
-     alpha : float, optional
-         A parameter controlling the steepness of the surrogate gradient.
-         Higher values result in a steeper gradient. Default is 1.0.
-
-     See Also
-     --------
-     piecewise_exp : Function version of this class.
-     """
-
-     def __init__(self, alpha: float = 1.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         """
-         Compute the surrogate gradient.
-
-         Parameters
-         ----------
-         x : jax.Array
-             The input array.
-
-         Returns
-         -------
-         jax.Array
-             The surrogate gradient.
-         """
-         dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
-         return dx
-
-     def surrogate_fun(self, x):
-         """
-         Compute the surrogate function.
-
-         Parameters
-         ----------
-         x : jax.Array
-             The input array.
-
-         Returns
-         -------
-         jax.Array
-             The output of the surrogate function.
-         """
-         return jnp.where(
-             x < 0,
-             jnp.exp(self.alpha * x) / 2,
-             1 - jnp.exp(-self.alpha * x) / 2
-         )
-
-     def __repr__(self):
-         """
-         Return a string representation of the PiecewiseExp instance.
-
-         Returns
-         -------
-         str
-             A string representation of the instance.
-         """
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         """
-         Compute a hash value for the PiecewiseExp instance.
-
-         Returns
-         -------
-         int
-             A hash value for the instance.
-         """
-         return hash((self.__class__, self.alpha))
-
-
- def piecewise_exp(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-     r"""Judge spiking state with a piecewise exponential function [1]_.
-
-     This function implements a surrogate gradient method for spiking neural networks
-     using a piecewise exponential function. It provides a differentiable approximation
-     of the step function used in the forward pass of spiking neurons.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        \frac{1}{2}e^{\alpha x}, & x < 0 \\
-        1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
-        \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for alpha in [0.5, 1., 2., 4.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_exp)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input array representing the neuron's membrane potential.
-     alpha : float, optional
-         A parameter controlling the steepness of the surrogate gradient.
-         Higher values result in a steeper gradient. Default is 1.0.
-
-     Returns
-     -------
-     jax.Array
-         An array of the same shape as the input, containing binary values (0 or 1)
-         representing the spiking state of each neuron.
-
-     Notes
-     -----
-     The function uses different computations for forward and backward passes:
-
-     - Forward: step function (1 for x >= 0, 0 for x < 0)
-     - Backward: piecewise exponential function for a smooth gradient
-
-     The surrogate gradient is defined as:
-     g'(x) = (alpha / 2) * exp(-alpha * |x|)
-
-     References
-     ----------
-     .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
-     """
-     return PiecewiseExp(alpha=alpha)(x)
-
-
- class SoftSign(Surrogate):
-     """Judge spiking state with a soft sign function.
-
-     See Also
-     --------
-     soft_sign
-     """
-
-     def __init__(self, alpha=1.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2
-         return dx
-
-     def surrogate_fun(self, x):
-         return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def soft_sign(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-     r"""Judge spiking state with a soft sign function.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)
-        = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for alpha in [0.5, 1., 2., 4.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.soft_sign)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         Parameter to control the smoothness of the gradient.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-     """
-     return SoftSign(alpha=alpha)(x)
-
-
- class Arctan(Surrogate):
-     """Judge spiking state with an arctan function.
-
-     See Also
-     --------
-     arctan
-     """
-
-     def __init__(self, alpha=1.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
-         return dx
-
-     def surrogate_fun(self, x):
-         return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def arctan(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-     r"""Judge spiking state with an arctan function.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for alpha in [0.5, 1., 2., 4.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.arctan)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         Parameter to control the smoothness of the gradient.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-     """
-     return Arctan(alpha=alpha)(x)
-
-
- class NonzeroSignLog(Surrogate):
-     """Judge spiking state with a nonzero sign log function.
-
-     See Also
-     --------
-     nonzero_sign_log
-     """
-
-     def __init__(self, alpha=1.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = 1. / (1 / self.alpha + jnp.abs(x))
-         return dx
-
-     def surrogate_fun(self, x):
-         return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1)
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def nonzero_sign_log(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-     r"""Judge spiking state with a nonzero sign log function.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)
-
-     where
-
-     .. math::
-
-        \begin{split}\mathrm{NonzeroSign}(x) =
-        \begin{cases}
-        1, & x \geq 0 \\
-        -1, & x < 0 \\
-        \end{cases}\end{split}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}
-
-     This surrogate function has the advantage of a low computational cost in the backward pass.
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for alpha in [0.5, 1., 2., 4.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.nonzero_sign_log)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         Parameter to control the smoothness of the gradient.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-     """
-     return NonzeroSignLog(alpha=alpha)(x)
-
-
- class ERF(Surrogate):
-     """Judge spiking state with an erf function.
-
-     See Also
-     --------
-     erf
-     """
-
-     def __init__(self, alpha=1.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
-         return dx
-
-     def surrogate_fun(self, x):
-         return 0.5 * sci.special.erfc(-self.alpha * x)
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def erf(
-     x: jax.Array,
-     alpha: float = 1.,
- ):
-     r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        \begin{split}
-        g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
-        &= \frac{1}{2} \text{erfc}(-\alpha x) \\
-        &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
-        \end{split}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for alpha in [0.5, 1., 2., 4.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.erf)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         Parameter to control the smoothness of the gradient.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in Neural Information Processing Systems, 2015, 28: 1117-1125.
-     .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
-     .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
-     """
-     return ERF(alpha=alpha)(x)
-
-
- class PiecewiseLeakyRelu(Surrogate):
-     """Judge spiking state with a piecewise leaky relu function.
-
-     See Also
-     --------
-     piecewise_leaky_relu
-     """
-
-     def __init__(self, c=0.01, w=1.):
-         super().__init__()
-         self.c = c
-         self.w = w
-
-     def surrogate_fun(self, x):
-         z = jnp.where(x < -self.w,
-                       self.c * x + self.c * self.w,
-                       jnp.where(x > self.w,
-                                 self.c * x - self.c * self.w + 1,
-                                 0.5 * x / self.w + 0.5))
-         return z
-
-     def surrogate_grad(self, x):
-         dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w)
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.c, self.w))
-
-
- def piecewise_leaky_relu(
-     x: jax.Array,
-     c: float = 0.01,
-     w: float = 1.,
- ):
-     r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        \begin{split}g(x) =
-        \begin{cases}
-        cx + cw, & x < -w \\
-        \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
-        cx - cw + 1, & x > w \\
-        \end{cases}\end{split}
-
-     Backward function:
-
-     .. math::
-
-        \begin{split}g'(x) =
-        \begin{cases}
-        \frac{1}{w}, & |x| \leq w \\
-        c, & |x| > w
-        \end{cases}\end{split}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for c in [0.01, 0.05, 0.1]:
-        ...     for w in [1., 2.]:
-        ...         grads1 = brainstate.augment.vector_grad(brainstate.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
-        ...         plt.plot(xs, grads1, label=f'c={c}, w={w}')
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     c : float
-         When :math:`|x| > w` the gradient is `c`.
-     w : float
-         When :math:`|x| \leq w` the gradient is `1 / w`.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
-     .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
-     .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
-     .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
-     .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
-     .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
-     .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 2020: 1519-1525.
-     .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.
-     """
-     return PiecewiseLeakyRelu(c=c, w=w)(x)
-
-
- class SquarewaveFourierSeries(Surrogate):
-     """Judge spiking state with a squarewave Fourier series.
-
-     See Also
-     --------
-     squarewave_fourier_series
-     """
-
-     def __init__(self, n=2, t_period=8.):
-         super().__init__()
-         self.n = n
-         self.t_period = t_period
-
-     def surrogate_grad(self, x):
-         w = jnp.pi * 2. / self.t_period
-         dx = jnp.cos(w * x)
-         for i in range(2, self.n):
-             dx += jnp.cos((2 * i - 1.) * w * x)
-         dx *= 4. / self.t_period
-         return dx
-
-     def surrogate_fun(self, x):
-         w = jnp.pi * 2. / self.t_period
-         ret = jnp.sin(w * x)
-         for i in range(2, self.n):
-             c = (2 * i - 1.)
-             ret += jnp.sin(c * w * x) / c
-         z = 0.5 + 2. / jnp.pi * ret
-         return z
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.n, self.t_period))
-
-
- def squarewave_fourier_series(
-     x: jax.Array,
-     n: int = 2,
-     t_period: float = 8.,
- ):
-     r"""Judge spiking state with a squarewave Fourier series.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        g(x) = \frac{1}{2} + \frac{2}{\pi}\sum_{i=1}^n \frac{\sin\left((2i-1) \cdot 2\pi x / T\right)}{2i-1}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \sum_{i=1}^n \frac{4}{T} \cos\left((2i-1) \cdot 2\pi x / T\right)
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for n in [2, 4, 8]:
-        ...     f = brainstate.surrogate.SquarewaveFourierSeries(n=n)
-        ...     grads1 = brainstate.augment.vector_grad(f)(xs)
-        ...     plt.plot(xs, grads1, label=f'n={n}')
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     n : int
-         The number of terms in the Fourier series.
-     t_period : float
-         The period :math:`T` of the square wave.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-     """
-     return SquarewaveFourierSeries(n=n, t_period=t_period)(x)
-
-
- class S2NN(Surrogate):
-     """Judge spiking state with the S2NN surrogate spiking function.
-
-     See Also
-     --------
-     s2nn
-     """
-
-     def __init__(self, alpha=4., beta=1., epsilon=1e-8):
-         super().__init__()
-         self.alpha = alpha
-         self.beta = beta
-         self.epsilon = epsilon
-
-     def surrogate_fun(self, x):
-         z = jnp.where(x < 0.,
-                       sci.special.expit(x * self.alpha),
-                       self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5)
-         return z
-
-     def surrogate_grad(self, x):
-         sg = sci.special.expit(self.alpha * x)
-         dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.))
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha, self.beta, self.epsilon))
-
-
- def s2nn(
-     x: jax.Array,
-     alpha: float = 4.,
-     beta: float = 1.,
-     epsilon: float = 1e-8,
- ):
-     r"""Judge spiking state with the S2NN surrogate spiking function [1]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        \begin{split}g(x) = \begin{cases}
-        \mathrm{sigmoid} (\alpha x), & x < 0 \\
-        \beta \ln(|x + 1|) + 0.5, & x \ge 0
-        \end{cases}\end{split}
-
-     Backward function:
-
-     .. math::
-
-        \begin{split}g'(x) = \begin{cases}
-        \alpha (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), & x < 0 \\
-        \frac{\beta}{(x + 1)}, & x \ge 0
-        \end{cases}\end{split}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 4., 1.)
-        >>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
-        >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.s2nn)(xs, 8., 2.)
-        >>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         The parameter that controls the gradient when ``x < 0``.
-     beta : float
-         The parameter that controls the gradient when ``x >= 0``.
-     epsilon : float
-         A small constant added inside the logarithm to avoid ``log(0)``.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Suetake, Kazuma et al. "S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks." ArXiv abs/2201.10879 (2022).
-     """
-     return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x)
-
-
- class QPseudoSpike(Surrogate):
-     """Judge spiking state with the q-PseudoSpike surrogate function.
-
-     See Also
-     --------
-     q_pseudo_spike
-     """
-
-     def __init__(self, alpha=2.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), -self.alpha)
-         return dx
-
-     def surrogate_fun(self, x):
-         z = jnp.where(x < 0.,
-                       0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
-                       1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
-         return z
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def q_pseudo_spike(
-     x: jax.Array,
-     alpha: float = 2.,
- ):
-     r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        \begin{split}g(x) =
-        \begin{cases}
-        \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
-        1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
-        \end{cases}\end{split}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for alpha in [0.5, 1., 2., 4.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.q_pseudo_spike)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         The parameter to control the tail fatness of the gradient.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Herranz-Celotti, Luca and Jean Rouat. "Surrogate Gradients Design." ArXiv abs/2202.00282 (2022).
-     """
-     return QPseudoSpike(alpha=alpha)(x)
-
-
- class LeakyRelu(Surrogate):
-     """Judge spiking state with the Leaky ReLU function.
-
-     See Also
-     --------
-     leaky_relu
-     """
-
-     def __init__(self, alpha=0.1, beta=1.):
-         super().__init__()
-         self.alpha = alpha
-         self.beta = beta
-
-     def surrogate_fun(self, x):
-         return jnp.where(x < 0., self.alpha * x, self.beta * x)
-
-     def surrogate_grad(self, x):
-         dx = jnp.where(x < 0., self.alpha, self.beta)
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha, self.beta))
-
-
- def leaky_relu(
-     x: jax.Array,
-     alpha: float = 0.1,
-     beta: float = 1.,
- ):
-     r"""Judge spiking state with the Leaky ReLU function.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        \begin{split}g(x) =
-        \begin{cases}
-        \beta \cdot x, & x \geq 0 \\
-        \alpha \cdot x, & x < 0 \\
-        \end{cases}\end{split}
-
-     Backward function:
-
-     .. math::
-
-        \begin{split}g'(x) =
-        \begin{cases}
-        \beta, & x \geq 0 \\
-        \alpha, & x < 0 \\
-        \end{cases}\end{split}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.leaky_relu)(xs, 0., 1.)
-        >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         The parameter to control the gradient when :math:`x < 0`.
-     beta : float
-         The parameter to control the gradient when :math:`x \geq 0`.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-     """
-     return LeakyRelu(alpha=alpha, beta=beta)(x)
-
-
- class LogTailedRelu(Surrogate):
-     """Judge spiking state with the Log-tailed ReLU function.
-
-     See Also
-     --------
-     log_tailed_relu
-     """
-
-     def __init__(self, alpha=0.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_fun(self, x):
-         z = jnp.where(x > 1,
-                       jnp.log(x),
-                       jnp.where(x > 0,
-                                 x,
-                                 self.alpha * x))
-         return z
-
-     def surrogate_grad(self, x):
-         dx = jnp.where(x > 1,
-                        1 / x,
-                        jnp.where(x > 0,
-                                  1.,
-                                  self.alpha))
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def log_tailed_relu(
-     x: jax.Array,
-     alpha: float = 0.,
- ):
-     r"""Judge spiking state with the Log-tailed ReLU function [1]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     The original function being approximated:
-
-     .. math::
-
-        \begin{split}g(x) =
-        \begin{cases}
-        \alpha x, & x \leq 0 \\
-        x, & 0 < x \leq 1 \\
-        \log(x), & x > 1 \\
-        \end{cases}\end{split}
-
-     Backward function:
-
-     .. math::
-
-        \begin{split}g'(x) =
-        \begin{cases}
-        \alpha, & x \leq 0 \\
-        1, & 0 < x \leq 1 \\
-        \frac{1}{x}, & x > 1 \\
-        \end{cases}\end{split}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.log_tailed_relu)(xs, 0.)
-        >>> plt.plot(xs, grads, label=r'$\alpha=0.$')
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         The parameter to control the gradient when :math:`x \leq 0`.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Cai, Zhaowei et al. "Deep Learning with Low Precision by Half-Wave Gaussian Quantization." 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414.
-     """
-     return LogTailedRelu(alpha=alpha)(x)
-
-
- class ReluGrad(Surrogate):
-     """Judge spiking state with the ReLU gradient function.
-
-     See Also
-     --------
-     relu_grad
-     """
-
-     def __init__(self, alpha=0.3, width=1.):
-         super().__init__()
-         self.alpha = alpha
-         self.width = width
-
-     def surrogate_grad(self, x):
-         dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0)
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha, self.width))
-
-
- def relu_grad(
-     x: jax.Array,
-     alpha: float = 0.3,
-     width: float = 1.,
- ):
-     r"""Spike function with the ReLU gradient function [1]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|))
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for s in [0.5, 1.]:
-        ...     for w in [1., 2.]:
-        ...         grads = brainstate.augment.vector_grad(brainstate.surrogate.relu_grad)(xs, s, w)
-        ...         plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         The parameter to control the gradient.
-     width : float
-         The parameter to control the width of the gradient.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61-63 (2019).
-     """
-     return ReluGrad(alpha=alpha, width=width)(x)
-
-
- class GaussianGrad(Surrogate):
-     """Judge spiking state with the Gaussian gradient function.
-
-     See Also
-     --------
-     gaussian_grad
-     """
-
-     def __init__(self, sigma=0.5, alpha=0.5):
-         super().__init__()
-         self.sigma = sigma
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
-         return self.alpha * dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha, self.sigma))
-
-
- def gaussian_grad(
-     x: jax.Array,
-     sigma: float = 0.5,
-     alpha: float = 0.5,
- ):
-     r"""Spike function with the Gaussian gradient function [1]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \alpha * \text{gaussian}(x, 0., \sigma)
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for s in [0.5, 1., 2.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.gaussian_grad)(xs, s, 0.5)
-        ...     plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     sigma : float
-         The parameter to control the variance of the Gaussian distribution.
-     alpha : float
-         The parameter to control the scale of the gradient.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905-913 (2021).
-     """
-     return GaussianGrad(sigma=sigma, alpha=alpha)(x)
-
-
- class MultiGaussianGrad(Surrogate):
-     """Judge spiking state with the multi-Gaussian gradient function.
-
-     See Also
-     --------
-     multi_gaussian_grad
-     """
-
-     def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
-         super().__init__()
-         self.h = h
-         self.s = s
-         self.sigma = sigma
-         self.scale = scale
-
-     def surrogate_grad(self, x):
-         g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
-         g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
-                      ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
-         g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
-                      ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
-         dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h
-         return self.scale * dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.h, self.s, self.sigma, self.scale))
-
-
- def multi_gaussian_grad(
-     x: jax.Array,
-     h: float = 0.15,
-     s: float = 6.0,
-     sigma: float = 0.5,
-     scale: float = 0.5,
- ):
-     r"""Spike function with the multi-Gaussian gradient function [1]_.
-
-     The forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = (1+h)\,\mathcal{N}(x; 0, \sigma^{2})
-        - h\,\mathcal{N}(x; \sigma, (s\sigma)^{2})
-        - h\,\mathcal{N}(x; -\sigma, (s\sigma)^{2})
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> grads = brainstate.augment.vector_grad(brainstate.surrogate.multi_gaussian_grad)(xs)
-        >>> plt.plot(xs, grads)
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     h : float
-         A hyperparameter of the approximation function.
-     s : float
-         A hyperparameter of the approximation function.
-     sigma : float
-         The Gaussian sigma.
-     scale : float
-         The gradient scale.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905-913 (2021).
-     """
-     return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x)
-
-
- class InvSquareGrad(Surrogate):
-     """Judge spiking state with the inverse-square surrogate gradient function.
-
-     See Also
-     --------
-     inv_square_grad
-     """
-
-     def __init__(self, alpha=100.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def inv_square_grad(
-     x: jax.Array,
-     alpha: float = 100.
- ):
-     r"""Spike function with the inverse-square surrogate gradient.
-
-     Forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \frac{1}{(\alpha |x| + 1)^2}
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-1, 1, 1000)
-        >>> for alpha in [1., 10., 100.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.inv_square_grad)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         Parameter to control the smoothness of the gradient.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-     """
-     return InvSquareGrad(alpha=alpha)(x)
-
-
- class SlayerGrad(Surrogate):
-     """Judge spiking state with the slayer surrogate gradient function.
-
-     See Also
-     --------
-     slayer_grad
-     """
-
-     def __init__(self, alpha=1.):
-         super().__init__()
-         self.alpha = alpha
-
-     def surrogate_grad(self, x):
-         dx = jnp.exp(-self.alpha * jnp.abs(x))
-         return dx
-
-     def __repr__(self):
-         return f'{self.__class__.__name__}(alpha={self.alpha})'
-
-     def __hash__(self):
-         return hash((self.__class__, self.alpha))
-
-
- def slayer_grad(
-     x: jax.Array,
-     alpha: float = 1.
- ):
-     r"""Spike function with the slayer surrogate gradient function [1]_.
-
-     Forward function:
-
-     .. math::
-
-        g(x) = \begin{cases}
-        1, & x \geq 0 \\
-        0, & x < 0 \\
-        \end{cases}
-
-     Backward function:
-
-     .. math::
-
-        g'(x) = \exp(-\alpha |x|)
-
-     .. plot::
-        :include-source: True
-
-        >>> import jax
-        >>> import brainstate
-        >>> import matplotlib.pyplot as plt
-        >>> xs = jax.numpy.linspace(-3, 3, 1000)
-        >>> for alpha in [0.5, 1., 2., 4.]:
-        ...     grads = brainstate.augment.vector_grad(brainstate.surrogate.slayer_grad)(xs, alpha)
-        ...     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-        >>> plt.legend()
-        >>> plt.show()
-
-     Parameters
-     ----------
-     x : jax.Array
-         The input data.
-     alpha : float
-         Parameter to control the smoothness of the gradient.
-
-     Returns
-     -------
-     out : jax.Array
-         The spiking state.
-
-     References
-     ----------
-     .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412-1421 (NeurIPS, 2018).
-     """
-     return SlayerGrad(alpha=alpha)(x)
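The deleted module's core mechanism, a Heaviside step in the forward pass whose backward pass is replaced by a smooth surrogate via the custom `heaviside_p` primitive, can be reproduced in plain JAX with `jax.custom_jvp`. A minimal, self-contained sketch of the sigmoid-shaped case (an illustration of the pattern, not the library's implementation):

    import jax
    import jax.numpy as jnp
    from jax.scipy.special import expit

    ALPHA = 4.0  # steepness of the sigmoid surrogate, as in Sigmoid(alpha=4.)

    @jax.custom_jvp
    def spike(x):
        # Forward pass: Heaviside step, mirroring _heaviside_imp above.
        return jnp.asarray(x >= 0, dtype=x.dtype)

    @spike.defjvp
    def _spike_jvp(primals, tangents):
        (x,), (tx,) = primals, tangents
        # Backward pass: g'(x) = alpha * sigmoid(alpha x) * (1 - sigmoid(alpha x)),
        # the same surrogate gradient as Sigmoid.surrogate_grad.
        s = expit(ALPHA * x)
        return spike(x), ALPHA * s * (1.0 - s) * tx

    xs = jnp.linspace(-2.0, 2.0, 5)
    print(jax.vmap(jax.grad(spike))(xs))  # bell-shaped surrogate gradients, peaked at 0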