brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
brainstate/surrogate.py
CHANGED
(Removed lines are collapsed in this view; each hunk below shows context and added lines.)

@@ -14,6 +14,7 @@
 # ==============================================================================

 # -*- coding: utf-8 -*-
+from __future__ import annotations

 import jax
 import jax.numpy as jnp
@@ -22,65 +23,65 @@ from jax.core import Primitive
 from jax.interpreters import batching, ad, mlir

 __all__ = [
+    'Surrogate',
+    'Sigmoid',
+    'sigmoid',
+    'PiecewiseQuadratic',
+    'piecewise_quadratic',
+    'PiecewiseExp',
+    'piecewise_exp',
+    'SoftSign',
+    'soft_sign',
+    'Arctan',
+    'arctan',
+    'NonzeroSignLog',
+    'nonzero_sign_log',
+    'ERF',
+    'erf',
+    'PiecewiseLeakyRelu',
+    'piecewise_leaky_relu',
+    'SquarewaveFourierSeries',
+    'squarewave_fourier_series',
+    'S2NN',
+    's2nn',
+    'QPseudoSpike',
+    'q_pseudo_spike',
+    'LeakyRelu',
+    'leaky_relu',
+    'LogTailedRelu',
+    'log_tailed_relu',
+    'ReluGrad',
+    'relu_grad',
+    'GaussianGrad',
+    'gaussian_grad',
+    'InvSquareGrad',
+    'inv_square_grad',
+    'MultiGaussianGrad',
+    'multi_gaussian_grad',
+    'SlayerGrad',
+    'slayer_grad',
 ]


 def _heaviside_abstract(x, dx):
+    return [x]


 def _heaviside_imp(x, dx):
+    z = jnp.asarray(x >= 0, dtype=x.dtype)
+    return [z]


 def _heaviside_batching(args, axes):
+    return heaviside_p.bind(*args), [axes[0]]


 def _heaviside_jvp(primals, tangents):
+    x, dx = primals
+    tx, tdx = tangents
+    primal_outs = heaviside_p.bind(x, dx)
+    tangent_outs = [dx * tx, ]
+    return primal_outs, tangent_outs


 heaviside_p = Primitive('heaviside_p')
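The rules above give `heaviside_p` a hard step in the forward pass while its JVP routes the
tangent through a caller-supplied surrogate derivative (`tangent_outs = [dx * tx]`). A minimal
self-contained sketch of the same contract, using `jax.custom_jvp` instead of a hand-registered
primitive (the primitive route additionally needs the abstract-eval, batching, and lowering
registrations shown above); the `spike` function here is illustrative, not part of brainstate:

    import jax
    import jax.numpy as jnp

    # Forward: hard Heaviside step. Backward: sigmoid-shaped surrogate.
    @jax.custom_jvp
    def spike(x, alpha=4.0):
        return jnp.asarray(x >= 0, dtype=x.dtype)

    @spike.defjvp
    def spike_jvp(primals, tangents):
        x, alpha = primals
        tx, _ = tangents
        sg = jax.nn.sigmoid(alpha * x)
        # The tangent flows through the smooth surrogate derivative,
        # not the (zero almost everywhere) derivative of the step.
        return spike(x, alpha), alpha * sg * (1.0 - sg) * tx

    print(jax.grad(spike)(0.1))  # ~0.96: nonzero despite the hard step forward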
@@ -93,260 +94,262 @@ mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_resu


 class Surrogate(object):
+    """The base surrogate gradient function.
+
+    To customize a surrogate gradient function, you can inherit this class and
+    implement the `surrogate_fun` and `surrogate_grad` methods.
+
+    Examples
+    --------
+
+    >>> import brainstate as bst
+    >>> import brainstate.nn as nn
+    >>> import jax.numpy as jnp
+
+    >>> class MySurrogate(bst.surrogate.Surrogate):
+    ...   def __init__(self, alpha=1.):
+    ...     super().__init__()
+    ...     self.alpha = alpha
+    ...
+    ...   def surrogate_fun(self, x):
+    ...     return jnp.sin(x) * self.alpha
+    ...
+    ...   def surrogate_grad(self, x):
+    ...     return jnp.cos(x) * self.alpha
+
+    """
+
+    def __call__(self, x):
+        dx = self.surrogate_grad(x)
+        return heaviside_p.bind(x, dx)[0]
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}()'
+
+    def surrogate_fun(self, x) -> jax.Array:
+        """The surrogate function."""
+        raise NotImplementedError
+
+    def surrogate_grad(self, x) -> jax.Array:
+        """The gradient function of the surrogate function."""
+        raise NotImplementedError


 class Sigmoid(Surrogate):
+    """Spike function with the sigmoid-shaped surrogate gradient.
+
+    See Also
+    --------
+    sigmoid
+    """
+
+    def __init__(self, alpha: float = 4.):
+        super().__init__()
+        self.alpha = alpha
+
+    def surrogate_fun(self, x):
+        return sci.special.expit(self.alpha * x)
+
+    def surrogate_grad(self, x):
+        sgax = sci.special.expit(x * self.alpha)
+        dx = (1. - sgax) * sgax * self.alpha
+        return dx
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))


 def sigmoid(
     x: jax.Array,
     alpha: float = 4.,
 ):
+    r"""Spike function with the sigmoid-shaped surrogate gradient.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}
+
+    Backward function:
+
+    .. math::
+
+       g'(x) = \alpha (1 - \mathrm{sigmoid}(\alpha x)) \, \mathrm{sigmoid}(\alpha x)
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-2, 2, 1000)
+       >>> for alpha in [1., 2., 4.]:
+       >>>   grads = bst.augment.vector_grad(bst.surrogate.sigmoid)(xs, alpha)
+       >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    alpha: float
+        Parameter to control smoothness of gradient.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+    """
+    return Sigmoid(alpha=alpha)(x)


 class PiecewiseQuadratic(Surrogate):
+    """Judge spiking state with a piecewise quadratic function.
+
+    See Also
+    --------
+    piecewise_quadratic
+    """
+
+    def __init__(self, alpha: float = 1.):
+        super().__init__()
+        self.alpha = alpha
+
+    def surrogate_fun(self, x):
+        z = jnp.where(x < -1 / self.alpha,
+                      0.,
+                      jnp.where(x > 1 / self.alpha,
+                                1.,
+                                (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5))
+        return z
+
+    def surrogate_grad(self, x):
+        dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha))
+        return dx
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))


 def piecewise_quadratic(
     x: jax.Array,
     alpha: float = 1.,
 ):
+    r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       g(x) =
+          \begin{cases}
+          0, & x < -\frac{1}{\alpha} \\
+          -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
+          1, & x > \frac{1}{\alpha} \\
+          \end{cases}
+
+    Backward function:
+
+    .. math::
+
+       g'(x) =
+          \begin{cases}
+          0, & |x| > \frac{1}{\alpha} \\
+          -\alpha^2|x| + \alpha, & |x| \leq \frac{1}{\alpha}
+          \end{cases}
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>   grads = bst.augment.vector_grad(bst.surrogate.piecewise_quadratic)(xs, alpha)
+       >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    alpha: float
+        Parameter to control smoothness of gradient.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+
+    References
+    ----------
+    .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the National Academy of Sciences, 2016, 113(41): 11441-11446.
+    .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
+    .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
+    .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
+    .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
+    """
+    return PiecewiseQuadratic(alpha=alpha)(x)


 class PiecewiseExp(Surrogate):
+    """Judge spiking state with a piecewise exponential function.
+
+    See Also
+    --------
+    piecewise_exp
+    """
+
+    def __init__(self, alpha: float = 1.):
+        super().__init__()
+        self.alpha = alpha
+
+    def surrogate_grad(self, x):
+        dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
+        return dx
+
+    def surrogate_fun(self, x):
+        return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2)
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
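Every class in this file follows the same pairing: `surrogate_fun` is the smooth function whose
derivative the analytic `surrogate_grad` is meant to be, and `Surrogate.__call__` feeds that
derivative into `heaviside_p`. A quick hedged consistency check (assuming brainstate 0.1.0 is
installed, as in the docstring examples) that autodiff of `Sigmoid.surrogate_fun` reproduces
the hand-written `surrogate_grad`:

    import jax
    import jax.numpy as jnp
    import brainstate as bst

    s = bst.surrogate.Sigmoid(alpha=4.0)
    xs = jnp.linspace(-2.0, 2.0, 7)

    # Autodiff of the smooth surrogate function ...
    auto = jax.vmap(jax.grad(s.surrogate_fun))(xs)
    # ... should agree with the analytic gradient used in the backward pass.
    print(jnp.allclose(auto, s.surrogate_grad(xs), atol=1e-5))  # True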
@@ -354,89 +357,90 @@ def piecewise_exp(
     alpha: float = 1.,

 ):
+    r"""Judge spiking state with a piecewise exponential function [1]_.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          \frac{1}{2}e^{\alpha x}, & x < 0 \\
+          1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
+          \end{cases}
+
+    Backward function:
+
+    .. math::
+
+       g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>   grads = bst.augment.vector_grad(bst.surrogate.piecewise_exp)(xs, alpha)
+       >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    alpha: float
+        Parameter to control smoothness of gradient.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+
+    References
+    ----------
+    .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
+    """
+    return PiecewiseExp(alpha=alpha)(x)


 class SoftSign(Surrogate):
+    """Judge spiking state with a soft sign function.
+
+    See Also
+    --------
+    soft_sign
+    """
+
+    def __init__(self, alpha=1.):
+        super().__init__()
+        self.alpha = alpha
+
+    def surrogate_grad(self, x):
+        dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2
+        return dx
+
+    def surrogate_fun(self, x):
+        return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
@@ -444,84 +448,85 @@ def soft_sign(
     alpha: float = 1.,

 ):
+    r"""Judge spiking state with a soft sign function.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)
+            = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)
+
+    Backward function:
+
+    .. math::
+
+       g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>   grads = bst.augment.vector_grad(bst.surrogate.soft_sign)(xs, alpha)
+       >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    alpha: float
+        Parameter to control smoothness of gradient.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+    """
+    return SoftSign(alpha=alpha)(x)


 class Arctan(Surrogate):
+    """Judge spiking state with an arctan function.
+
+    See Also
+    --------
+    arctan
+    """
+
+    def __init__(self, alpha=1.):
+        super().__init__()
+        self.alpha = alpha
+
+    def surrogate_grad(self, x):
+        dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
+        return dx
+
+    def surrogate_fun(self, x):
+        return jnp.arctan(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
@@ -529,83 +534,84 @@ def arctan(
     alpha: float = 1.,

 ):
+    r"""Judge spiking state with an arctan function.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}
+
+    Backward function:
+
+    .. math::
+
+       g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>   grads = bst.augment.vector_grad(bst.surrogate.arctan)(xs, alpha)
+       >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    alpha: float
+        Parameter to control smoothness of gradient.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+    """
+    return Arctan(alpha=alpha)(x)


 class NonzeroSignLog(Surrogate):
+    """Judge spiking state with a nonzero sign log function.
+
+    See Also
+    --------
+    nonzero_sign_log
+    """
+
+    def __init__(self, alpha=1.):
+        super().__init__()
+        self.alpha = alpha
+
+    def surrogate_grad(self, x):
+        dx = 1. / (1 / self.alpha + jnp.abs(x))
+        return dx
+
+    def surrogate_fun(self, x):
+        return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1)
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
@@ -613,96 +619,97 @@ def nonzero_sign_log(
     alpha: float = 1.,

 ):
+    r"""Judge spiking state with a nonzero sign log function.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)
+
+    where
+
+    .. math::
+
+       \begin{split}\mathrm{NonzeroSign}(x) =
+       \begin{cases}
+       1, & x \geq 0 \\
+       -1, & x < 0 \\
+       \end{cases}\end{split}
+
+    Backward function:
+
+    .. math::
+
+       g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}
+
+    This surrogate function has the advantage of low computational cost in the backward pass.
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>   grads = bst.augment.vector_grad(bst.surrogate.nonzero_sign_log)(xs, alpha)
+       >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    alpha: float
+        Parameter to control smoothness of gradient.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+    """
+    return NonzeroSignLog(alpha=alpha)(x)


 class ERF(Surrogate):
+    """Judge spiking state with an erf function.
+
+    See Also
+    --------
+    erf
+    """
+
+    def __init__(self, alpha=1.):
+        super().__init__()
+        self.alpha = alpha
+
+    def surrogate_grad(self, x):
+        dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
+        return dx
+
+    def surrogate_fun(self, x):
+        return sci.special.erf(-self.alpha * x) * 0.5
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
@@ -710,99 +717,100 @@ def erf(
     alpha: float = 1.,

 ):
+    r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       \begin{split}
+       g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
+            &= \frac{1}{2} \text{erfc}(-\alpha x) \\
+            &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
+       \end{split}
+
+    Backward function:
+
+    .. math::
+
+       g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>   grads = bst.augment.vector_grad(bst.surrogate.erf)(xs, alpha)
+       >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    alpha: float
+        Parameter to control smoothness of gradient.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+
+    References
+    ----------
+    .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in Neural Information Processing Systems, 2015, 28: 1117-1125.
+    .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
+    .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
+    """
+    return ERF(alpha=alpha)(x)


 class PiecewiseLeakyRelu(Surrogate):
+    """Judge spiking state with a piecewise leaky relu function.
+
+    See Also
+    --------
+    piecewise_leaky_relu
+    """
+
+    def __init__(self, c=0.01, w=1.):
+        super().__init__()
+        self.c = c
+        self.w = w
+
+    def surrogate_fun(self, x):
+        z = jnp.where(x < -self.w,
+                      self.c * x + self.c * self.w,
+                      jnp.where(x > self.w,
+                                self.c * x - self.c * self.w + 1,
+                                0.5 * x / self.w + 0.5))
+        return z
+
+    def surrogate_grad(self, x):
+        dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w)
+        return dx
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.c, self.w))
@@ -811,119 +819,120 @@ def piecewise_leaky_relu(
     w: float = 1.,

 ):
+    r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       \begin{split}g(x) =
+       \begin{cases}
+       cx + cw, & x < -w \\
+       \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
+       cx - cw + 1, & x > w \\
+       \end{cases}\end{split}
+
+    Backward function:
+
+    .. math::
+
+       \begin{split}g'(x) =
+       \begin{cases}
+       \frac{1}{w}, & |x| \leq w \\
+       c, & |x| > w
+       \end{cases}\end{split}
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for c in [0.01, 0.05, 0.1]:
+       >>>   for w in [1., 2.]:
+       >>>     grads1 = bst.augment.vector_grad(bst.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
+       >>>     plt.plot(xs, grads1, label=f'c={c}, w={w}')
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    c: float
+        When :math:`|x| > w` the gradient is `c`.
+    w: float
+        When :math:`|x| <= w` the gradient is `1 / w`.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+
+    References
+    ----------
+    .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
+    .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
+    .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
+    .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
+    .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
+    .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
+    .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525.
+    .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.
+    """
+    return PiecewiseLeakyRelu(c=c, w=w)(x)


 class SquarewaveFourierSeries(Surrogate):
+    """Judge spiking state with a squarewave Fourier series.
+
+    See Also
+    --------
+    squarewave_fourier_series
+    """
+
+    def __init__(self, n=2, t_period=8.):
+        super().__init__()
+        self.n = n
+        self.t_period = t_period
+
+    def surrogate_grad(self, x):
+        w = jnp.pi * 2. / self.t_period
+        dx = jnp.cos(w * x)
+        for i in range(2, self.n):
+            dx += jnp.cos((2 * i - 1.) * w * x)
+        dx *= 4. / self.t_period
+        return dx
+
+    def surrogate_fun(self, x):
+        w = jnp.pi * 2. / self.t_period
+        ret = jnp.sin(w * x)
+        for i in range(2, self.n):
+            c = (2 * i - 1.)
+            ret += jnp.sin(c * w * x) / c
+        z = 0.5 + 2. / jnp.pi * ret
+        return z
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.n, self.t_period))
@@ -932,91 +941,92 @@ def squarewave_fourier_series(
     t_period: float = 8.,

 ):
+    r"""Judge spiking state with a squarewave Fourier series.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       g(x) = 0.5 + \frac{1}{\pi}\sum_{i=1}^n \frac{\sin\left((2i-1) \, 2\pi x / T\right)}{2i-1}
+
+    Backward function:
+
+    .. math::
+
+       g'(x) = \sum_{i=1}^n \frac{4\cos\left((2i-1) \, 2\pi x / T\right)}{T}
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for n in [2, 4, 8]:
+       >>>   f = bst.surrogate.SquarewaveFourierSeries(n=n)
+       >>>   grads1 = bst.augment.vector_grad(f)(xs)
+       >>>   plt.plot(xs, grads1, label=f'n={n}')
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    n: int
+    t_period: float
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+    """
+    return SquarewaveFourierSeries(n=n, t_period=t_period)(x)


 class S2NN(Surrogate):
+    """Judge spiking state with the S2NN surrogate spiking function.
+
+    See Also
+    --------
+    s2nn
+    """
+
+    def __init__(self, alpha=4., beta=1., epsilon=1e-8):
+        super().__init__()
+        self.alpha = alpha
+        self.beta = beta
+        self.epsilon = epsilon
+
+    def surrogate_fun(self, x):
+        z = jnp.where(x < 0.,
+                      sci.special.expit(x * self.alpha),
+                      self.beta * jnp.log(jnp.abs(x + 1.) + self.epsilon) + 0.5)
+        return z
+
+    def surrogate_grad(self, x):
+        sg = sci.special.expit(self.alpha * x)
+        dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.))
+        return dx
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha, self.beta, self.epsilon))
@@ -1026,101 +1036,102 @@ def s2nn(
     epsilon: float = 1e-8,

 ):
+    r"""Judge spiking state with the S2NN surrogate spiking function [1]_.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       \begin{split}g(x) = \begin{cases}
+       \mathrm{sigmoid}(\alpha x), & x < 0 \\
+       \beta \ln(|x + 1|) + 0.5, & x \ge 0
+       \end{cases}\end{split}
+
+    Backward function:
+
+    .. math::
+
+       \begin{split}g'(x) = \begin{cases}
+       \alpha (1 - \mathrm{sigmoid}(\alpha x)) \mathrm{sigmoid}(\alpha x), & x < 0 \\
+       \frac{\beta}{(x + 1)}, & x \ge 0
+       \end{cases}\end{split}
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> grads = bst.augment.vector_grad(bst.surrogate.s2nn)(xs, 4., 1.)
+       >>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
+       >>> grads = bst.augment.vector_grad(bst.surrogate.s2nn)(xs, 8., 2.)
+       >>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    alpha: float
+        The param that controls the gradient when ``x < 0``.
+    beta: float
+        The param that controls the gradient when ``x >= 0``.
+    epsilon: float
+        Small constant added inside the logarithm to avoid ``nan``.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+
+    References
+    ----------
+    .. [1] Suetake, Kazuma et al. "S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks." ArXiv abs/2201.10879 (2022): n. pag.
+    """
+    return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x)


 class QPseudoSpike(Surrogate):
+    """Judge spiking state with the q-PseudoSpike surrogate function.
+
+    See Also
+    --------
+    q_pseudo_spike
+    """
+
+    def __init__(self, alpha=2.):
+        super().__init__()
+        self.alpha = alpha
+
+    def surrogate_grad(self, x):
+        dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha)
+        return dx
+
+    def surrogate_fun(self, x):
+        z = jnp.where(x < 0.,
+                      0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
+                      1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
+        return z
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
@@ -1128,91 +1139,92 @@ def q_pseudo_spike(
     alpha: float = 2.,

 ):
+    r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.
+
+    If `origin=False`, computes the forward function:
+
+    .. math::
+
+       g(x) = \begin{cases}
+          1, & x \geq 0 \\
+          0, & x < 0 \\
+          \end{cases}
+
+    If `origin=True`, computes the original function:
+
+    .. math::
+
+       \begin{split}g(x) =
+       \begin{cases}
+       \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
+       1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
+       \end{cases}\end{split}
+
+    Backward function:
+
+    .. math::
+
+       g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}
+
+    .. plot::
+       :include-source: True
+
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>   grads = bst.augment.vector_grad(bst.surrogate.q_pseudo_spike)(xs, alpha)
+       >>>   plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
+
+    Parameters
+    ----------
+    x: jax.Array, Array
+        The input data.
+    alpha: float
+        The parameter to control tail fatness of gradient.
+
+    Returns
+    -------
+    out: jax.Array
+        The spiking state.
+
+    References
+    ----------
+    .. [1] Herranz-Celotti, Luca and Jean Rouat. "Surrogate Gradients Design." ArXiv abs/2202.00282 (2022): n. pag.
+    """
+    return QPseudoSpike(alpha=alpha)(x)


 class LeakyRelu(Surrogate):
+    """Judge spiking state with the Leaky ReLU function.
+
+    See Also
+    --------
+    leaky_relu
+    """
+
+    def __init__(self, alpha=0.1, beta=1.):
+        super().__init__()
+        self.alpha = alpha
+        self.beta = beta
+
+    def surrogate_fun(self, x):
+        return jnp.where(x < 0., self.alpha * x, self.beta * x)
+
+    def surrogate_grad(self, x):
+        dx = jnp.where(x < 0., self.alpha, self.beta)
+        return dx
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'
+
+    def __hash__(self):
+        return hash((self.__class__, self.alpha, self.beta))
@@ -1221,100 +1233,101 @@ def leaky_relu(
|
|
1221
1233
|
beta: float = 1.,
|
1222
1234
|
|
1223
1235
|
):
|
1224
|
-
|
1236
|
+
r"""Judge spiking state with the Leaky ReLU function.
|
1225
1237
|
|
1226
|
-
|
1227
|
-
|
1228
|
-
.. math::
|
1229
|
-
|
1230
|
-
g(x) = \begin{cases}
|
1231
|
-
1, & x \geq 0 \\
|
1232
|
-
0, & x < 0 \\
|
1233
|
-
\end{cases}
|
1238
|
+
If `origin=False`, computes the forward function:
|
1234
1239
|
|
1235
|
-
|
1240
|
+
.. math::
|
1236
1241
|
|
1237
|
-
|
1242
|
+
g(x) = \begin{cases}
|
1243
|
+
1, & x \geq 0 \\
|
1244
|
+
0, & x < 0 \\
|
1245
|
+
\end{cases}
|
1238
1246
|
|
1239
|
-
|
1240
|
-
\begin{cases}
|
1241
|
-
\beta \cdot x, & x \geq 0 \\
|
1242
|
-
\alpha \cdot x, & x < 0 \\
|
1243
|
-
\end{cases}\end{split}
|
1244
|
-
|
1245
|
-
Backward function:
|
1246
|
-
|
1247
|
-
.. math::
|
1248
|
-
|
1249
|
-
\begin{split}g'(x) =
|
1250
|
-
\begin{cases}
|
1251
|
-
\beta, & x \geq 0 \\
|
1252
|
-
\alpha, & x < 0 \\
|
1253
|
-
\end{cases}\end{split}
|
1247
|
+
If `origin=True`, computes the original function:
|
1254
1248
|
|
1255
|
-
|
1256
|
-
:include-source: True
|
1249
|
+
.. math::
|
1257
1250
|
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
|
1263
|
-
>>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
|
1264
|
-
>>> plt.legend()
|
1265
|
-
>>> plt.show()
|
1251
|
+
\begin{split}g(x) =
|
1252
|
+
\begin{cases}
|
1253
|
+
\beta \cdot x, & x \geq 0 \\
|
1254
|
+
\alpha \cdot x, & x < 0 \\
|
1255
|
+
\end{cases}\end{split}
|
1266
1256
|
|
1267
|
-
|
1268
|
-
----------
|
1269
|
-
x: jax.Array, Array
|
1270
|
-
The input data.
|
1271
|
-
alpha: float
|
1272
|
-
The parameter to control the gradient when :math:`x < 0`.
|
1273
|
-
beta: float
|
1274
|
-
The parameter to control the gradient when :math:`x >= 0`.
|
1257
|
+
Backward function:
|
1275
1258
|
|
1259
|
+
.. math::
|
1276
1260
|
|
1277
|
-
|
1278
|
-
|
1279
|
-
|
1280
|
-
|
1281
|
-
|
1282
|
-
|
1261
|
+
\begin{split}g'(x) =
|
1262
|
+
\begin{cases}
|
1263
|
+
\beta, & x \geq 0 \\
|
1264
|
+
\alpha, & x < 0 \\
|
1265
|
+
\end{cases}\end{split}
|
1266
|
+
|
1267
|
+
.. plot::
|
1268
|
+
:include-source: True
|
1269
|
+
|
1270
|
+
>>> import jax
|
1271
|
+
>>> import brainstate.nn as nn
|
1272
|
+
>>> import brainstate as bst
|
1273
|
+
>>> import matplotlib.pyplot as plt
|
1274
|
+
>>> xs = jax.numpy.linspace(-3, 3, 1000)
|
1275
|
+
>>> grads = bst.augment.vector_grad(bst.surrogate.leaky_relu)(xs, 0., 1.)
|
1276
|
+
>>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
|
1277
|
+
>>> plt.legend()
|
1278
|
+
>>> plt.show()
|
1279
|
+
|
1280
|
+
Parameters
|
1281
|
+
----------
|
1282
|
+
x: jax.Array, Array
|
1283
|
+
The input data.
|
1284
|
+
alpha: float
|
1285
|
+
The parameter to control the gradient when :math:`x < 0`.
|
1286
|
+
beta: float
|
1287
|
+
The parameter to control the gradient when :math:`x >= 0`.
|
1288
|
+
|
1289
|
+
|
1290
|
+
Returns
|
1291
|
+
-------
|
1292
|
+
out: jax.Array
|
1293
|
+
The spiking state.
|
1294
|
+
"""
|
1295
|
+
return LeakyRelu(alpha=alpha, beta=beta)(x)
|
1283
1296
|
|
1284
1297
|
|
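To make the forward/backward split above concrete, here is a minimal sketch in plain JAX (using jax.custom_jvp rather than brainstate's Surrogate machinery; the function names are ours) of a step spike whose backward pass is the leaky-ReLU surrogate gradient:

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def spike(x):
        return (x >= 0.).astype(x.dtype)  # forward: Heaviside step

    @spike.defjvp
    def spike_jvp(primals, tangents):
        x, = primals
        x_dot, = tangents
        alpha, beta = 0.1, 1.0
        surrogate = jnp.where(x < 0., alpha, beta)  # the documented g'(x)
        return spike(x), surrogate * x_dot

    grads = jax.vmap(jax.grad(spike))(jnp.linspace(-1., 1., 5))
    print(grads)  # [0.1 0.1 1.  1.  1. ]
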
class LogTailedRelu(Surrogate):
    """Judge spiking state with the Log-tailed ReLU function.

    See Also
    --------
    log_tailed_relu
    """

    def __init__(self, alpha=0.):
        super().__init__()
        self.alpha = alpha

    def surrogate_fun(self, x):
        z = jnp.where(x > 1,
                      jnp.log(x),
                      jnp.where(x > 0,
                                x,
                                self.alpha * x))
        return z

    def surrogate_grad(self, x):
        dx = jnp.where(x > 1,
                       1 / x,
                       jnp.where(x > 0,
                                 1.,
                                 self.alpha))
        return dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'

    def __hash__(self):
        return hash((self.__class__, self.alpha))


def log_tailed_relu(
@@ -1322,93 +1335,94 @@ def log_tailed_relu(
    alpha: float = 0.,

):
    r"""Judge spiking state with the Log-tailed ReLU function [1]_.

    If `origin=False`, computes the forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
       \end{cases}

    If `origin=True`, computes the original function:

    .. math::

       \begin{split}g(x) =
          \begin{cases}
          \alpha x, & x \leq 0 \\
          x, & 0 < x \leq 1 \\
          \log(x), & x > 1 \\
          \end{cases}\end{split}

    Backward function:

    .. math::

       \begin{split}g'(x) =
          \begin{cases}
          \alpha, & x \leq 0 \\
          1, & 0 < x \leq 1 \\
          \frac{1}{x}, & x > 1 \\
          \end{cases}\end{split}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> grads = bst.augment.vector_grad(bst.surrogate.log_tailed_relu)(xs, 0.)
       >>> plt.plot(xs, grads, label=r'$\alpha=0.$')
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        The parameter to control the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Cai, Zhaowei et al. "Deep Learning with Low Precision by Half-Wave Gaussian Quantization." 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414.
    """
    return LogTailedRelu(alpha=alpha)(x)

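The piecewise definition above is easy to evaluate directly; a small sketch (plain jax.numpy only, helper name ours) showing the three pieces of the `origin=True` function:

    import jax.numpy as jnp

    def log_tailed_relu_original(x, alpha=0.):
        # alpha*x for x <= 0, identity on (0, 1], log(x) beyond 1.
        return jnp.where(x > 1, jnp.log(x), jnp.where(x > 0, x, alpha * x))

    print(log_tailed_relu_original(jnp.array([-2., 0.5, jnp.e])))
    # [-0.   0.5  1. ]  -- the tail grows logarithmically rather than linearly
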
class ReluGrad(Surrogate):
    """Judge spiking state with the ReLU gradient function.

    See Also
    --------
    relu_grad
    """

    def __init__(self, alpha=0.3, width=1.):
        super().__init__()
        self.alpha = alpha
        self.width = width

    def surrogate_grad(self, x):
        dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0)
        return dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'

    def __hash__(self):
        return hash((self.__class__, self.alpha, self.width))


def relu_grad(
@@ -1416,80 +1430,81 @@ def relu_grad(
    alpha: float = 0.3,
    width: float = 1.,
):
    r"""Spike function with the ReLU gradient function [1]_.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
       \end{cases}

    Backward function:

    .. math::

       g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|))

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for s in [0.5, 1.]:
       >>>   for w in [1, 2.]:
       >>>     grads = bst.augment.vector_grad(bst.surrogate.relu_grad)(xs, s, w)
       >>>     plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        The parameter to control the gradient.
    width: float
        The parameter to control the width of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019).
    """
    return ReluGrad(alpha=alpha, width=width)(x)

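As a quick numeric check of the backward rule above (a sketch assuming only jax.numpy; the helper name is ours): the gradient is a triangle of height alpha*width that reaches zero at |x| = width:

    import jax.numpy as jnp

    def relu_surrogate_grad(x, alpha=0.3, width=1.):
        # ReLU(alpha * (width - |x|)), written as the class above computes it.
        return jnp.maximum(alpha * width - jnp.abs(x) * alpha, 0.)

    xs = jnp.array([-2., -1., 0., 0.5, 1., 2.])
    print(relu_surrogate_grad(xs))  # [0.   0.   0.3  0.15 0.   0.  ]
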
class GaussianGrad(Surrogate):
    """Judge spiking state with the Gaussian gradient function.

    See Also
    --------
    gaussian_grad
    """

    def __init__(self, sigma=0.5, alpha=0.5):
        super().__init__()
        self.sigma = sigma
        self.alpha = alpha

    def surrogate_grad(self, x):
        dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
        return self.alpha * dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'

    def __hash__(self):
        return hash((self.__class__, self.alpha, self.sigma))


def gaussian_grad(
@@ -1497,86 +1512,87 @@ def gaussian_grad(
    sigma: float = 0.5,
    alpha: float = 0.5,
):
    r"""Spike function with the Gaussian gradient function [1]_.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
       \end{cases}

    Backward function:

    .. math::

       g'(x) = \alpha * \text{gaussian}(x, 0., \sigma)

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for s in [0.5, 1., 2.]:
       >>>   grads = bst.augment.vector_grad(bst.surrogate.gaussian_grad)(xs, s, 0.5)
       >>>   plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    sigma: float
        The parameter to control the variance of the gaussian distribution.
    alpha: float
        The parameter to control the scale of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
    """
    return GaussianGrad(sigma=sigma, alpha=alpha)(x)

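The backward rule g'(x) = alpha * gaussian(x, 0, sigma) can be written with the normal density spelled out; a minimal sketch assuming only jax.numpy (helper name ours):

    import jax.numpy as jnp

    def gaussian_surrogate_grad(x, sigma=0.5, alpha=0.5):
        # alpha times the N(0, sigma^2) density evaluated at x.
        pdf = jnp.exp(-x ** 2 / (2 * sigma ** 2)) / (jnp.sqrt(2 * jnp.pi) * sigma)
        return alpha * pdf

    print(gaussian_surrogate_grad(jnp.array([0.])))  # ~0.3989 at the peak for the defaults
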
class MultiGaussianGrad(Surrogate):
    """Judge spiking state with the multi-Gaussian gradient function.

    See Also
    --------
    multi_gaussian_grad
    """

    def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
        super().__init__()
        self.h = h
        self.s = s
        self.sigma = sigma
        self.scale = scale

    def surrogate_grad(self, x):
        g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
        g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
                     ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
        g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
                     ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
        dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h
        return self.scale * dx

    def __repr__(self):
        return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'

    def __hash__(self):
        return hash((self.__class__, self.h, self.s, self.sigma, self.scale))


def multi_gaussian_grad(
@@ -1586,209 +1602,212 @@ def multi_gaussian_grad(
    sigma: float = 0.5,
    scale: float = 0.5,
):
    r"""Spike function with the multi-Gaussian gradient function [1]_.

    The forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
       \end{cases}

    Backward function:

    .. math::

       g'(x) = (1+h)\,\mathcal{N}(x; 0, \sigma^{2})
               - h\,\mathcal{N}(x; \sigma, (s\sigma)^{2})
               - h\,\mathcal{N}(x; -\sigma, (s\sigma)^{2})

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> grads = bst.augment.vector_grad(bst.surrogate.multi_gaussian_grad)(xs)
       >>> plt.plot(xs, grads)
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    h: float
        The hyper-parameter of the approximate function.
    s: float
        The hyper-parameter of the approximate function.
    sigma: float
        The gaussian sigma.
    scale: float
        The gradient scale.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
    """
    return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x)

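The three-term combination above is a tall central Gaussian minus two wide side Gaussians, which produces small negative lobes away from the threshold; a sketch of the same formula in plain jax.numpy (helper names ours):

    import jax.numpy as jnp

    def normal_pdf(x, mu, sigma):
        return jnp.exp(-(x - mu) ** 2 / (2 * sigma ** 2)) / (jnp.sqrt(2 * jnp.pi) * sigma)

    def multi_gaussian_surrogate_grad(x, h=0.15, s=6.0, sigma=0.5, scale=0.5):
        dx = ((1. + h) * normal_pdf(x, 0., sigma)
              - h * normal_pdf(x, sigma, s * sigma)
              - h * normal_pdf(x, -sigma, s * sigma))
        return scale * dx

    # Positive peak at 0; slightly negative values further out.
    print(multi_gaussian_surrogate_grad(jnp.array([0., 1., 2.])))
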
class InvSquareGrad(Surrogate):
    """Judge spiking state with the inverse-square surrogate gradient function.

    See Also
    --------
    inv_square_grad
    """

    def __init__(self, alpha=100.):
        super().__init__()
        self.alpha = alpha

    def surrogate_grad(self, x):
        dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2
        return dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'

    def __hash__(self):
        return hash((self.__class__, self.alpha))


def inv_square_grad(
    x: jax.Array,
    alpha: float = 100.
):
    r"""Spike function with the inverse-square surrogate gradient.

    Forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
       \end{cases}

    Backward function:

    .. math::

       g'(x) = \frac{1}{(\alpha |x| + 1)^{2}}

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-1, 1, 1000)
       >>> for alpha in [1., 10., 100.]:
       >>>   grads = bst.augment.vector_grad(bst.surrogate.inv_square_grad)(xs, alpha)
       >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        The parameter to control the smoothness of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.
    """
    return InvSquareGrad(alpha=alpha)(x)

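A quick numeric check of the inverse-square rule above (plain jax.numpy, helper name ours): the gradient peaks at 1 for x = 0 and decays quadratically, so the default alpha = 100 makes it very sharp:

    import jax.numpy as jnp

    def inv_square_surrogate_grad(x, alpha=100.):
        return 1. / (alpha * jnp.abs(x) + 1.) ** 2

    print(inv_square_surrogate_grad(jnp.array([0., 0.01, 0.1])))
    # [1.      0.25    0.00826]
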
class SlayerGrad(Surrogate):
    """Judge spiking state with the SLAYER surrogate gradient function.

    See Also
    --------
    slayer_grad
    """

    def __init__(self, alpha=1.):
        super().__init__()
        self.alpha = alpha

    def surrogate_grad(self, x):
        dx = jnp.exp(-self.alpha * jnp.abs(x))
        return dx

    def __repr__(self):
        return f'{self.__class__.__name__}(alpha={self.alpha})'

    def __hash__(self):
        return hash((self.__class__, self.alpha))


def slayer_grad(
    x: jax.Array,
    alpha: float = 1.
):
    r"""Spike function with the SLAYER surrogate gradient function [1]_.

    Forward function:

    .. math::

       g(x) = \begin{cases}
          1, & x \geq 0 \\
          0, & x < 0 \\
       \end{cases}

    Backward function:

    .. math::

       g'(x) = \exp(-\alpha |x|)

    .. plot::
       :include-source: True

       >>> import jax
       >>> import brainstate.nn as nn
       >>> import brainstate as bst
       >>> import matplotlib.pyplot as plt
       >>> xs = jax.numpy.linspace(-3, 3, 1000)
       >>> for alpha in [0.5, 1., 2., 4.]:
       >>>   grads = bst.augment.vector_grad(bst.surrogate.slayer_grad)(xs, alpha)
       >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
       >>> plt.legend()
       >>> plt.show()

    Parameters
    ----------
    x: jax.Array, Array
        The input data.
    alpha: float
        The parameter to control the smoothness of the gradient.

    Returns
    -------
    out: jax.Array
        The spiking state.

    References
    ----------
    .. [1] Shrestha, S. B. & Orchard, G. SLAYER: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018).
    """
    return SlayerGrad(alpha=alpha)(x)

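Finally, the SLAYER gradient can be wired into a custom derivative the same way as the other rules here; a minimal sketch in plain JAX (jax.custom_jvp, not brainstate's Surrogate class; the names are ours):

    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def spike(x):
        return (x >= 0.).astype(x.dtype)  # forward: Heaviside step

    @spike.defjvp
    def spike_jvp(primals, tangents):
        x, = primals
        x_dot, = tangents
        alpha = 1.0
        return spike(x), jnp.exp(-alpha * jnp.abs(x)) * x_dot  # g'(x) = exp(-alpha*|x|)

    print(jax.grad(spike)(0.5))  # exp(-0.5) ~ 0.6065
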