brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
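Taken together, the moves above split the old brainstate.transform namespace in two: compilation-style transforms (jit, conditionals, loops, make_jaxpr) now live under brainstate/compile/, and function augmentations (autograd, eval_shape, vmap-style mapping) under brainstate/augment/. The brainstate/nn/event modules are promoted to a top-level brainstate/event package, and the flat brainstate/util.py becomes the brainstate/util/ package. A minimal migration sketch in Python follows; the vector_grad spelling is confirmed by the docstring updates in surrogate.py below, while the compile.jit spelling is an assumption inferred from the new file layout:

    # Migration sketch for brainstate 0.0.2 -> 0.1.0, inferred from the
    # file moves in this diff; treat the exact public names as assumptions
    # except where the surrogate.py docstrings below confirm them.
    import jax.numpy as jnp
    import brainstate as bst

    xs = jnp.linspace(-2.0, 2.0, 1000)

    # 0.0.2 spelling (the brainstate.transform module is removed in 0.1.0):
    #   grads = bst.transform.vector_grad(bst.surrogate.sigmoid)(xs, 4.0)

    # 0.1.0 spelling, as used in the updated docstrings:
    grads = bst.augment.vector_grad(bst.surrogate.sigmoid)(xs, 4.0)

    # Staging/compilation helpers moved to brainstate.compile
    # (assumed public name, mirroring brainstate/compile/_jit.py):
    step = bst.compile.jit(lambda v: v * 0.9 + 0.1)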
brainstate/surrogate.py CHANGED
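Most of the churn in the hunks below is a mechanical reindentation from two-space to four-space style. The substantive changes are a new `from __future__ import annotations` import and docstring examples rewritten from bst.transform.vector_grad with nn.surrogate.* to bst.augment.vector_grad with bst.surrogate.* (each plot example also gains an explicit >>> import jax). For reference, here is a runnable version of the customization example from the updated Surrogate docstring; the class, the two method names, and the heaviside binding come from the diff, and the rest is an illustrative sketch:

    import jax.numpy as jnp
    import brainstate as bst

    class MySurrogate(bst.surrogate.Surrogate):
        # A sine forward proxy with a cosine surrogate gradient.
        def __init__(self, alpha=1.):
            super().__init__()
            self.alpha = alpha

        def surrogate_fun(self, x):
            # Smooth stand-in for the Heaviside forward function.
            return jnp.sin(x) * self.alpha

        def surrogate_grad(self, x):
            # Gradient substituted during the backward pass.
            return jnp.cos(x) * self.alpha

    # Surrogate.__call__ binds the heaviside primitive with this gradient:
    spikes = MySurrogate(alpha=1.)(jnp.linspace(-1.0, 1.0, 5))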
@@ -14,6 +14,7 @@
 # ==============================================================================
 
 # -*- coding: utf-8 -*-
+from __future__ import annotations
 
 import jax
 import jax.numpy as jnp
@@ -22,65 +23,65 @@ from jax.core import Primitive
 from jax.interpreters import batching, ad, mlir
 
 __all__ = [
-  'Surrogate',
-  'Sigmoid',
-  'sigmoid',
-  'PiecewiseQuadratic',
-  'piecewise_quadratic',
-  'PiecewiseExp',
-  'piecewise_exp',
-  'SoftSign',
-  'soft_sign',
-  'Arctan',
-  'arctan',
-  'NonzeroSignLog',
-  'nonzero_sign_log',
-  'ERF',
-  'erf',
-  'PiecewiseLeakyRelu',
-  'piecewise_leaky_relu',
-  'SquarewaveFourierSeries',
-  'squarewave_fourier_series',
-  'S2NN',
-  's2nn',
-  'QPseudoSpike',
-  'q_pseudo_spike',
-  'LeakyRelu',
-  'leaky_relu',
-  'LogTailedRelu',
-  'log_tailed_relu',
-  'ReluGrad',
-  'relu_grad',
-  'GaussianGrad',
-  'gaussian_grad',
-  'InvSquareGrad',
-  'inv_square_grad',
-  'MultiGaussianGrad',
-  'multi_gaussian_grad',
-  'SlayerGrad',
-  'slayer_grad',
+    'Surrogate',
+    'Sigmoid',
+    'sigmoid',
+    'PiecewiseQuadratic',
+    'piecewise_quadratic',
+    'PiecewiseExp',
+    'piecewise_exp',
+    'SoftSign',
+    'soft_sign',
+    'Arctan',
+    'arctan',
+    'NonzeroSignLog',
+    'nonzero_sign_log',
+    'ERF',
+    'erf',
+    'PiecewiseLeakyRelu',
+    'piecewise_leaky_relu',
+    'SquarewaveFourierSeries',
+    'squarewave_fourier_series',
+    'S2NN',
+    's2nn',
+    'QPseudoSpike',
+    'q_pseudo_spike',
+    'LeakyRelu',
+    'leaky_relu',
+    'LogTailedRelu',
+    'log_tailed_relu',
+    'ReluGrad',
+    'relu_grad',
+    'GaussianGrad',
+    'gaussian_grad',
+    'InvSquareGrad',
+    'inv_square_grad',
+    'MultiGaussianGrad',
+    'multi_gaussian_grad',
+    'SlayerGrad',
+    'slayer_grad',
 ]
 
 
 def _heaviside_abstract(x, dx):
-  return [x]
+    return [x]
 
 
 def _heaviside_imp(x, dx):
-  z = jnp.asarray(x >= 0, dtype=x.dtype)
-  return [z]
+    z = jnp.asarray(x >= 0, dtype=x.dtype)
+    return [z]
 
 
 def _heaviside_batching(args, axes):
-  return heaviside_p.bind(*args), [axes[0]]
+    return heaviside_p.bind(*args), [axes[0]]
 
 
 def _heaviside_jvp(primals, tangents):
-  x, dx = primals
-  tx, tdx = tangents
-  primal_outs = heaviside_p.bind(x, dx)
-  tangent_outs = [dx * tx, ]
-  return primal_outs, tangent_outs
+    x, dx = primals
+    tx, tdx = tangents
+    primal_outs = heaviside_p.bind(x, dx)
+    tangent_outs = [dx * tx, ]
+    return primal_outs, tangent_outs
 
 
 heaviside_p = Primitive('heaviside_p')
@@ -93,260 +94,262 @@ mlir.register_lowering(heaviside_p, mlir.lower_fun(_heaviside_imp, multiple_resu
 
 
 class Surrogate(object):
-  """The base surrograte gradient function.
+    """The base surrograte gradient function.
 
-  To customize a surrogate gradient function, you can inherit this class and
-  implement the `surrogate_fun` and `surrogate_grad` methods.
+    To customize a surrogate gradient function, you can inherit this class and
+    implement the `surrogate_fun` and `surrogate_grad` methods.
 
-  Examples
-  --------
+    Examples
+    --------
 
-  >>> import brainstate as bst
-  >>> import brainstate.nn as nn
-  >>> import jax.numpy as jnp
+    >>> import brainstate as bst
+    >>> import brainstate.nn as nn
+    >>> import jax.numpy as jnp
 
-  >>> class MySurrogate(nn.surrogate.Surrogate):
-  ...   def __init__(self, alpha=1.):
-  ...     super().__init__()
-  ...     self.alpha = alpha
-  ...
-  ...   def surrogate_fun(self, x):
-  ...     return jnp.sin(x) * self.alpha
-  ...
-  ...   def surrogate_grad(self, x):
-  ...     return jnp.cos(x) * self.alpha
+    >>> class MySurrogate(bst.surrogate.Surrogate):
+    ...     def __init__(self, alpha=1.):
+    ...         super().__init__()
+    ...         self.alpha = alpha
+    ...
+    ...     def surrogate_fun(self, x):
+    ...         return jnp.sin(x) * self.alpha
+    ...
+    ...     def surrogate_grad(self, x):
+    ...         return jnp.cos(x) * self.alpha
 
-  """
+    """
 
-  def __call__(self, x):
-    dx = self.surrogate_grad(x)
-    return heaviside_p.bind(x, dx)[0]
+    def __call__(self, x):
+        dx = self.surrogate_grad(x)
+        return heaviside_p.bind(x, dx)[0]
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}()'
+    def __repr__(self):
+        return f'{self.__class__.__name__}()'
 
-  def surrogate_fun(self, x) -> jax.Array:
-    """The surrogate function."""
-    raise NotImplementedError
+    def surrogate_fun(self, x) -> jax.Array:
+        """The surrogate function."""
+        raise NotImplementedError
 
-  def surrogate_grad(self, x) -> jax.Array:
-    """The gradient function of the surrogate function."""
-    raise NotImplementedError
+    def surrogate_grad(self, x) -> jax.Array:
+        """The gradient function of the surrogate function."""
+        raise NotImplementedError
 
 
 class Sigmoid(Surrogate):
-  """Spike function with the sigmoid-shaped surrogate gradient.
+    """Spike function with the sigmoid-shaped surrogate gradient.
 
-  See Also
-  --------
-  sigmoid
+    See Also
+    --------
+    sigmoid
 
-  """
+    """
 
-  def __init__(self, alpha: float = 4.):
-    super().__init__()
-    self.alpha = alpha
+    def __init__(self, alpha: float = 4.):
+        super().__init__()
+        self.alpha = alpha
 
-  def surrogate_fun(self, x):
-    return sci.special.expit(self.alpha * x)
+    def surrogate_fun(self, x):
+        return sci.special.expit(self.alpha * x)
 
-  def surrogate_grad(self, x):
-    sgax = sci.special.expit(x * self.alpha)
-    dx = (1. - sgax) * sgax * self.alpha
-    return dx
+    def surrogate_grad(self, x):
+        sgax = sci.special.expit(x * self.alpha)
+        dx = (1. - sgax) * sgax * self.alpha
+        return dx
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}(alpha={self.alpha})'
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
 
-  def __hash__(self):
-    return hash((self.__class__, self.alpha))
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
 
 
 def sigmoid(
     x: jax.Array,
     alpha: float = 4.,
 ):
-  r"""Spike function with the sigmoid-shaped surrogate gradient.
+    r"""Spike function with the sigmoid-shaped surrogate gradient.
 
-  If `origin=False`, return the forward function:
+    If `origin=False`, return the forward function:
 
-  .. math::
+    .. math::
 
-     g(x) = \begin{cases}
-     1, & x \geq 0 \\
-     0, & x < 0 \\
-     \end{cases}
+       g(x) = \begin{cases}
+       1, & x \geq 0 \\
+       0, & x < 0 \\
+       \end{cases}
 
-  If `origin=True`, computes the original function:
+    If `origin=True`, computes the original function:
 
-  .. math::
+    .. math::
 
-     g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}
+       g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}
 
-  Backward function:
+    Backward function:
 
-  .. math::
+    .. math::
 
-     g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x)
+       g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x)
 
-  .. plot::
-     :include-source: True
+    .. plot::
+       :include-source: True
 
-     >>> import brainstate.nn as nn
-     >>> import brainstate as bst
-     >>> import matplotlib.pyplot as plt
-     >>> xs = jax.numpy.linspace(-2, 2, 1000)
-     >>> for alpha in [1., 2., 4.]:
-     >>>   grads = bst.transform.vector_grad(nn.surrogate.sigmoid)(xs, alpha)
-     >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-     >>> plt.legend()
-     >>> plt.show()
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-2, 2, 1000)
+       >>> for alpha in [1., 2., 4.]:
+       >>>     grads = bst.augment.vector_grad(bst.surrogate.sigmoid)(xs, alpha)
+       >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
 
-  Parameters
-  ----------
-  x: jax.Array, Array
-    The input data.
-  alpha: float
-    Parameter to control smoothness of gradient
+    Parameters
+    ----------
+    x: jax.Array, Array
+      The input data.
+    alpha: float
+      Parameter to control smoothness of gradient
 
 
-  Returns
-  -------
-  out: jax.Array
-    The spiking state.
-  """
-  return Sigmoid(alpha=alpha)(x)
+    Returns
+    -------
+    out: jax.Array
+      The spiking state.
+    """
+    return Sigmoid(alpha=alpha)(x)
 
 
 class PiecewiseQuadratic(Surrogate):
-  """Judge spiking state with a piecewise quadratic function.
+    """Judge spiking state with a piecewise quadratic function.
 
-  See Also
-  --------
-  piecewise_quadratic
+    See Also
+    --------
+    piecewise_quadratic
 
-  """
+    """
 
-  def __init__(self, alpha: float = 1.):
-    super().__init__()
-    self.alpha = alpha
+    def __init__(self, alpha: float = 1.):
+        super().__init__()
+        self.alpha = alpha
 
-  def surrogate_fun(self, x):
-    z = jnp.where(x < -1 / self.alpha,
-                  0.,
-                  jnp.where(x > 1 / self.alpha,
-                            1.,
-                            (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5))
-    return z
+    def surrogate_fun(self, x):
+        z = jnp.where(x < -1 / self.alpha,
+                      0.,
+                      jnp.where(x > 1 / self.alpha,
+                                1.,
+                                (-self.alpha * jnp.abs(x) / 2 + 1) * self.alpha * x + 0.5))
+        return z
 
-  def surrogate_grad(self, x):
-    dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha))
-    return dx
+    def surrogate_grad(self, x):
+        dx = jnp.where(jnp.abs(x) > 1 / self.alpha, 0., (-(self.alpha * x) ** 2 + self.alpha))
+        return dx
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}(alpha={self.alpha})'
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
 
-  def __hash__(self):
-    return hash((self.__class__, self.alpha))
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
 
 
 def piecewise_quadratic(
     x: jax.Array,
     alpha: float = 1.,
 ):
-  r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
+    r"""Judge spiking state with a piecewise quadratic function [1]_ [2]_ [3]_ [4]_ [5]_.
 
-  If `origin=False`, computes the forward function:
+    If `origin=False`, computes the forward function:
 
-  .. math::
+    .. math::
 
-     g(x) = \begin{cases}
-     1, & x \geq 0 \\
-     0, & x < 0 \\
-     \end{cases}
+       g(x) = \begin{cases}
+       1, & x \geq 0 \\
+       0, & x < 0 \\
+       \end{cases}
 
-  If `origin=True`, computes the original function:
+    If `origin=True`, computes the original function:
 
-  .. math::
+    .. math::
 
-     g(x) =
-     \begin{cases}
-     0, & x < -\frac{1}{\alpha} \\
-     -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
-     1, & x > \frac{1}{\alpha} \\
-     \end{cases}
+       g(x) =
+       \begin{cases}
+       0, & x < -\frac{1}{\alpha} \\
+       -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\
+       1, & x > \frac{1}{\alpha} \\
+       \end{cases}
 
-  Backward function:
+    Backward function:
 
-  .. math::
+    .. math::
 
-     g'(x) =
-     \begin{cases}
-     0, & |x| > \frac{1}{\alpha} \\
-     -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha}
-     \end{cases}
+       g'(x) =
+       \begin{cases}
+       0, & |x| > \frac{1}{\alpha} \\
+       -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha}
+       \end{cases}
 
-  .. plot::
-     :include-source: True
+    .. plot::
+       :include-source: True
 
-     >>> import brainstate.nn as nn
-     >>> import brainstate as bst
-     >>> import matplotlib.pyplot as plt
-     >>> xs = jax.numpy.linspace(-3, 3, 1000)
-     >>> for alpha in [0.5, 1., 2., 4.]:
-     >>>   grads = bst.transform.vector_grad(nn.surrogate.piecewise_quadratic)(xs, alpha)
-     >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-     >>> plt.legend()
-     >>> plt.show()
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>     grads = bst.augment.vector_grad(bst.surrogate.piecewise_quadratic)(xs, alpha)
+       >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
 
-  Parameters
-  ----------
-  x: jax.Array, Array
-    The input data.
-  alpha: float
-    Parameter to control smoothness of gradient
+    Parameters
+    ----------
+    x: jax.Array, Array
+      The input data.
+    alpha: float
+      Parameter to control smoothness of gradient
 
 
-  Returns
-  -------
-  out: jax.Array
-    The spiking state.
+    Returns
+    -------
+    out: jax.Array
+      The spiking state.
 
-  References
-  ----------
-  .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446.
-  .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
-  .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
-  .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
-  .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
-  """
-  return PiecewiseQuadratic(alpha=alpha)(x)
+    References
+    ----------
+    .. [1] Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the national academy of sciences, 2016, 113(41): 11441-11446.
+    .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
+    .. [3] Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
+    .. [4] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
+    .. [5] Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
+    """
+    return PiecewiseQuadratic(alpha=alpha)(x)
 
 
 class PiecewiseExp(Surrogate):
-  """Judge spiking state with a piecewise exponential function.
+    """Judge spiking state with a piecewise exponential function.
 
-  See Also
-  --------
-  piecewise_exp
-  """
+    See Also
+    --------
+    piecewise_exp
+    """
 
-  def __init__(self, alpha: float = 1.):
-    super().__init__()
-    self.alpha = alpha
+    def __init__(self, alpha: float = 1.):
+        super().__init__()
+        self.alpha = alpha
 
-  def surrogate_grad(self, x):
-    dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
-    return dx
+    def surrogate_grad(self, x):
+        dx = (self.alpha / 2) * jnp.exp(-self.alpha * jnp.abs(x))
+        return dx
 
-  def surrogate_fun(self, x):
-    return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2)
+    def surrogate_fun(self, x):
+        return jnp.where(x < 0, jnp.exp(self.alpha * x) / 2, 1 - jnp.exp(-self.alpha * x) / 2)
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}(alpha={self.alpha})'
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
 
-  def __hash__(self):
-    return hash((self.__class__, self.alpha))
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
 
 
 def piecewise_exp(
@@ -354,89 +357,90 @@ def piecewise_exp(
     alpha: float = 1.,
 
 ):
-  r"""Judge spiking state with a piecewise exponential function [1]_.
+    r"""Judge spiking state with a piecewise exponential function [1]_.
 
-  If `origin=False`, computes the forward function:
+    If `origin=False`, computes the forward function:
 
-  .. math::
+    .. math::
 
-     g(x) = \begin{cases}
-     1, & x \geq 0 \\
-     0, & x < 0 \\
-     \end{cases}
+       g(x) = \begin{cases}
+       1, & x \geq 0 \\
+       0, & x < 0 \\
+       \end{cases}
 
-  If `origin=True`, computes the original function:
+    If `origin=True`, computes the original function:
 
-  .. math::
+    .. math::
 
-     g(x) = \begin{cases}
-     \frac{1}{2}e^{\alpha x}, & x < 0 \\
-     1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
-     \end{cases}
+       g(x) = \begin{cases}
+       \frac{1}{2}e^{\alpha x}, & x < 0 \\
+       1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0
+       \end{cases}
 
-  Backward function:
+    Backward function:
 
-  .. math::
+    .. math::
 
-     g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}
+       g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}
 
-  .. plot::
-     :include-source: True
+    .. plot::
+       :include-source: True
 
-     >>> import brainstate.nn as nn
-     >>> import brainstate as bst
-     >>> import matplotlib.pyplot as plt
-     >>> xs = jax.numpy.linspace(-3, 3, 1000)
-     >>> for alpha in [0.5, 1., 2., 4.]:
-     >>>   grads = bst.transform.vector_grad(nn.surrogate.piecewise_exp)(xs, alpha)
-     >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-     >>> plt.legend()
-     >>> plt.show()
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>     grads = bst.augment.vector_grad(bst.surrogate.piecewise_exp)(xs, alpha)
+       >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
 
-  Parameters
-  ----------
-  x: jax.Array, Array
-    The input data.
-  alpha: float
-    Parameter to control smoothness of gradient
+    Parameters
+    ----------
+    x: jax.Array, Array
+      The input data.
+    alpha: float
+      Parameter to control smoothness of gradient
 
 
-  Returns
-  -------
-  out: jax.Array
-    The spiking state.
+    Returns
+    -------
+    out: jax.Array
+      The spiking state.
 
-  References
-  ----------
-  .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
-  """
-  return PiecewiseExp(alpha=alpha)(x)
+    References
+    ----------
+    .. [1] Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
+    """
+    return PiecewiseExp(alpha=alpha)(x)
 
 
 class SoftSign(Surrogate):
-  """Judge spiking state with a soft sign function.
+    """Judge spiking state with a soft sign function.
 
-  See Also
-  --------
-  soft_sign
-  """
+    See Also
+    --------
+    soft_sign
+    """
 
-  def __init__(self, alpha=1.):
-    super().__init__()
-    self.alpha = alpha
+    def __init__(self, alpha=1.):
+        super().__init__()
+        self.alpha = alpha
 
-  def surrogate_grad(self, x):
-    dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2
-    return dx
+    def surrogate_grad(self, x):
+        dx = self.alpha * 0.5 / (1 + jnp.abs(self.alpha * x)) ** 2
+        return dx
 
-  def surrogate_fun(self, x):
-    return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5
+    def surrogate_fun(self, x):
+        return x / (2 / self.alpha + 2 * jnp.abs(x)) + 0.5
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}(alpha={self.alpha})'
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
 
-  def __hash__(self):
-    return hash((self.__class__, self.alpha))
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
 
 
 def soft_sign(
@@ -444,84 +448,85 @@ def soft_sign(
     alpha: float = 1.,
 
 ):
-  r"""Judge spiking state with a soft sign function.
+    r"""Judge spiking state with a soft sign function.
 
-  If `origin=False`, computes the forward function:
+    If `origin=False`, computes the forward function:
 
-  .. math::
+    .. math::
 
-     g(x) = \begin{cases}
-     1, & x \geq 0 \\
-     0, & x < 0 \\
-     \end{cases}
+       g(x) = \begin{cases}
+       1, & x \geq 0 \\
+       0, & x < 0 \\
+       \end{cases}
 
-  If `origin=True`, computes the original function:
+    If `origin=True`, computes the original function:
 
-  .. math::
+    .. math::
 
-     g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)
-            = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)
+       g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)
+              = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)
 
-  Backward function:
+    Backward function:
 
-  .. math::
+    .. math::
 
-     g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}
+       g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}
 
-  .. plot::
-     :include-source: True
+    .. plot::
+       :include-source: True
 
-     >>> import brainstate.nn as nn
-     >>> import brainstate as bst
-     >>> import matplotlib.pyplot as plt
-     >>> xs = jax.numpy.linspace(-3, 3, 1000)
-     >>> for alpha in [0.5, 1., 2., 4.]:
-     >>>   grads = bst.transform.vector_grad(nn.surrogate.soft_sign)(xs, alpha)
-     >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-     >>> plt.legend()
-     >>> plt.show()
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>     grads = bst.augment.vector_grad(bst.surrogate.soft_sign)(xs, alpha)
+       >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
 
-  Parameters
-  ----------
-  x: jax.Array, Array
-    The input data.
-  alpha: float
-    Parameter to control smoothness of gradient
+    Parameters
+    ----------
+    x: jax.Array, Array
+      The input data.
+    alpha: float
+      Parameter to control smoothness of gradient
 
 
-  Returns
-  -------
-  out: jax.Array
-    The spiking state.
+    Returns
+    -------
+    out: jax.Array
+      The spiking state.
 
-  """
-  return SoftSign(alpha=alpha)(x)
+    """
+    return SoftSign(alpha=alpha)(x)
 
 
 class Arctan(Surrogate):
-  """Judge spiking state with an arctan function.
+    """Judge spiking state with an arctan function.
 
-  See Also
-  --------
-  arctan
-  """
+    See Also
+    --------
+    arctan
+    """
 
-  def __init__(self, alpha=1.):
-    super().__init__()
-    self.alpha = alpha
+    def __init__(self, alpha=1.):
+        super().__init__()
+        self.alpha = alpha
 
-  def surrogate_grad(self, x):
-    dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
-    return dx
+    def surrogate_grad(self, x):
+        dx = self.alpha * 0.5 / (1 + (jnp.pi / 2 * self.alpha * x) ** 2)
+        return dx
 
-  def surrogate_fun(self, x):
-    return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
+    def surrogate_fun(self, x):
+        return jnp.arctan2(jnp.pi / 2 * self.alpha * x) / jnp.pi + 0.5
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}(alpha={self.alpha})'
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
 
-  def __hash__(self):
-    return hash((self.__class__, self.alpha))
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
 
 
 def arctan(
@@ -529,83 +534,84 @@ def arctan(
     alpha: float = 1.,
 
 ):
-  r"""Judge spiking state with an arctan function.
+    r"""Judge spiking state with an arctan function.
 
-  If `origin=False`, computes the forward function:
+    If `origin=False`, computes the forward function:
 
-  .. math::
+    .. math::
 
-     g(x) = \begin{cases}
-     1, & x \geq 0 \\
-     0, & x < 0 \\
-     \end{cases}
+       g(x) = \begin{cases}
+       1, & x \geq 0 \\
+       0, & x < 0 \\
+       \end{cases}
 
-  If `origin=True`, computes the original function:
+    If `origin=True`, computes the original function:
 
-  .. math::
+    .. math::
 
-     g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}
+       g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}
 
-  Backward function:
+    Backward function:
 
-  .. math::
+    .. math::
 
-     g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}
+       g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}
 
-  .. plot::
-     :include-source: True
+    .. plot::
+       :include-source: True
 
-     >>> import brainstate.nn as nn
-     >>> import brainstate as bst
-     >>> import matplotlib.pyplot as plt
-     >>> xs = jax.numpy.linspace(-3, 3, 1000)
-     >>> for alpha in [0.5, 1., 2., 4.]:
-     >>>   grads = bst.transform.vector_grad(nn.surrogate.arctan)(xs, alpha)
-     >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-     >>> plt.legend()
-     >>> plt.show()
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>     grads = bst.augment.vector_grad(bst.surrogate.arctan)(xs, alpha)
+       >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
 
-  Parameters
-  ----------
-  x: jax.Array, Array
-    The input data.
-  alpha: float
-    Parameter to control smoothness of gradient
+    Parameters
+    ----------
+    x: jax.Array, Array
+      The input data.
+    alpha: float
+      Parameter to control smoothness of gradient
 
 
-  Returns
-  -------
-  out: jax.Array
-    The spiking state.
+    Returns
+    -------
+    out: jax.Array
+      The spiking state.
 
-  """
-  return Arctan(alpha=alpha)(x)
+    """
+    return Arctan(alpha=alpha)(x)
 
 
 class NonzeroSignLog(Surrogate):
-  """Judge spiking state with a nonzero sign log function.
+    """Judge spiking state with a nonzero sign log function.
 
-  See Also
-  --------
-  nonzero_sign_log
-  """
+    See Also
+    --------
+    nonzero_sign_log
+    """
 
-  def __init__(self, alpha=1.):
-    super().__init__()
-    self.alpha = alpha
+    def __init__(self, alpha=1.):
+        super().__init__()
+        self.alpha = alpha
 
-  def surrogate_grad(self, x):
-    dx = 1. / (1 / self.alpha + jnp.abs(x))
-    return dx
+    def surrogate_grad(self, x):
+        dx = 1. / (1 / self.alpha + jnp.abs(x))
+        return dx
 
-  def surrogate_fun(self, x):
-    return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1)
+    def surrogate_fun(self, x):
+        return jnp.where(x < 0, -1., 1.) * jnp.log(jnp.abs(self.alpha * x) + 1)
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}(alpha={self.alpha})'
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
 
-  def __hash__(self):
-    return hash((self.__class__, self.alpha))
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
 
 
 def nonzero_sign_log(
@@ -613,96 +619,97 @@ def nonzero_sign_log(
     alpha: float = 1.,
 
 ):
-  r"""Judge spiking state with a nonzero sign log function.
+    r"""Judge spiking state with a nonzero sign log function.
 
-  If `origin=False`, computes the forward function:
+    If `origin=False`, computes the forward function:
 
-  .. math::
+    .. math::
 
-     g(x) = \begin{cases}
-     1, & x \geq 0 \\
-     0, & x < 0 \\
-     \end{cases}
+       g(x) = \begin{cases}
+       1, & x \geq 0 \\
+       0, & x < 0 \\
+       \end{cases}
 
-  If `origin=True`, computes the original function:
+    If `origin=True`, computes the original function:
 
-  .. math::
+    .. math::
 
-     g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)
+       g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)
 
-  where
+    where
 
-  .. math::
+    .. math::
 
-     \begin{split}\mathrm{NonzeroSign}(x) =
-     \begin{cases}
-     1, & x \geq 0 \\
-     -1, & x < 0 \\
-     \end{cases}\end{split}
+       \begin{split}\mathrm{NonzeroSign}(x) =
+       \begin{cases}
+       1, & x \geq 0 \\
+       -1, & x < 0 \\
+       \end{cases}\end{split}
 
-  Backward function:
+    Backward function:
 
-  .. math::
+    .. math::
 
-     g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}
+       g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}
 
-  This surrogate function has the advantage of low computation cost during the backward.
+    This surrogate function has the advantage of low computation cost during the backward.
 
 
-  .. plot::
-     :include-source: True
+    .. plot::
+       :include-source: True
 
-     >>> import brainstate.nn as nn
-     >>> import brainstate as bst
-     >>> import matplotlib.pyplot as plt
-     >>> xs = jax.numpy.linspace(-3, 3, 1000)
-     >>> for alpha in [0.5, 1., 2., 4.]:
-     >>>   grads = bst.transform.vector_grad(nn.surrogate.nonzero_sign_log)(xs, alpha)
-     >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-     >>> plt.legend()
-     >>> plt.show()
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>     grads = bst.augment.vector_grad(bst.surrogate.nonzero_sign_log)(xs, alpha)
+       >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
 
-  Parameters
-  ----------
-  x: jax.Array, Array
-    The input data.
-  alpha: float
-    Parameter to control smoothness of gradient
+    Parameters
+    ----------
+    x: jax.Array, Array
+      The input data.
+    alpha: float
+      Parameter to control smoothness of gradient
 
 
-  Returns
-  -------
-  out: jax.Array
-    The spiking state.
+    Returns
+    -------
+    out: jax.Array
+      The spiking state.
 
-  """
-  return NonzeroSignLog(alpha=alpha)(x)
+    """
+    return NonzeroSignLog(alpha=alpha)(x)
 
 
 class ERF(Surrogate):
-  """Judge spiking state with an erf function.
+    """Judge spiking state with an erf function.
 
-  See Also
-  --------
-  erf
-  """
+    See Also
+    --------
+    erf
+    """
 
-  def __init__(self, alpha=1.):
-    super().__init__()
-    self.alpha = alpha
+    def __init__(self, alpha=1.):
+        super().__init__()
+        self.alpha = alpha
 
-  def surrogate_grad(self, x):
-    dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
-    return dx
+    def surrogate_grad(self, x):
+        dx = (self.alpha / jnp.sqrt(jnp.pi)) * jnp.exp(-jnp.power(self.alpha, 2) * x * x)
+        return dx
 
-  def surrogate_fun(self, x):
-    return sci.special.erf(-self.alpha * x) * 0.5
+    def surrogate_fun(self, x):
+        return sci.special.erf(-self.alpha * x) * 0.5
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}(alpha={self.alpha})'
+    def __repr__(self):
+        return f'{self.__class__.__name__}(alpha={self.alpha})'
 
-  def __hash__(self):
-    return hash((self.__class__, self.alpha))
+    def __hash__(self):
+        return hash((self.__class__, self.alpha))
 
 
 def erf(
@@ -710,99 +717,100 @@ def erf(
     alpha: float = 1.,
 
 ):
-  r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.
+    r"""Judge spiking state with an erf function [1]_ [2]_ [3]_.
 
-  If `origin=False`, computes the forward function:
+    If `origin=False`, computes the forward function:
 
-  .. math::
+    .. math::
 
-     g(x) = \begin{cases}
-     1, & x \geq 0 \\
-     0, & x < 0 \\
-     \end{cases}
+       g(x) = \begin{cases}
+       1, & x \geq 0 \\
+       0, & x < 0 \\
+       \end{cases}
 
-  If `origin=True`, computes the original function:
+    If `origin=True`, computes the original function:
 
-  .. math::
+    .. math::
 
-     \begin{split}
-     g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
-     &= \frac{1}{2} \text{erfc}(-\alpha x) \\
-     &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
-     \end{split}
+       \begin{split}
+       g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\
+       &= \frac{1}{2} \text{erfc}(-\alpha x) \\
+       &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt
+       \end{split}
 
-  Backward function:
+    Backward function:
 
-  .. math::
+    .. math::
 
-     g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}
+       g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}
 
-  .. plot::
-     :include-source: True
+    .. plot::
+       :include-source: True
 
-     >>> import brainstate.nn as nn
-     >>> import brainstate as bst
-     >>> import matplotlib.pyplot as plt
-     >>> xs = jax.numpy.linspace(-3, 3, 1000)
-     >>> for alpha in [0.5, 1., 2., 4.]:
-     >>>   grads = bst.transform.vector_grad(nn.surrogate.nonzero_sign_log)(xs, alpha)
-     >>>   plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
-     >>> plt.legend()
-     >>> plt.show()
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for alpha in [0.5, 1., 2., 4.]:
+       >>>     grads = bst.augment.vector_grad(bst.surrogate.nonzero_sign_log)(xs, alpha)
+       >>>     plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
+       >>> plt.legend()
+       >>> plt.show()
 
-  Parameters
-  ----------
-  x: jax.Array, Array
-    The input data.
-  alpha: float
-    Parameter to control smoothness of gradient
+    Parameters
+    ----------
+    x: jax.Array, Array
+      The input data.
+    alpha: float
+      Parameter to control smoothness of gradient
 
 
-  Returns
-  -------
-  out: jax.Array
-    The spiking state.
+    Returns
+    -------
+    out: jax.Array
+      The spiking state.
 
-  References
-  ----------
-  .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125.
-  .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
-  .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
+    References
+    ----------
+    .. [1] Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in neural information processing systems, 2015, 28: 1117-1125.
+    .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
+    .. [3] Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
 
-  """
-  return ERF(alpha=alpha)(x)
+    """
+    return ERF(alpha=alpha)(x)
 
 
 class PiecewiseLeakyRelu(Surrogate):
-  """Judge spiking state with a piecewise leaky relu function.
+    """Judge spiking state with a piecewise leaky relu function.
 
-  See Also
-  --------
-  piecewise_leaky_relu
-  """
+    See Also
+    --------
+    piecewise_leaky_relu
+    """
 
-  def __init__(self, c=0.01, w=1.):
-    super().__init__()
-    self.c = c
-    self.w = w
+    def __init__(self, c=0.01, w=1.):
+        super().__init__()
+        self.c = c
+        self.w = w
 
-  def surrogate_fun(self, x):
-    z = jnp.where(x < -self.w,
-                  self.c * x + self.c * self.w,
-                  jnp.where(x > self.w,
-                            self.c * x - self.c * self.w + 1,
-                            0.5 * x / self.w + 0.5))
-    return z
+    def surrogate_fun(self, x):
+        z = jnp.where(x < -self.w,
+                      self.c * x + self.c * self.w,
+                      jnp.where(x > self.w,
+                                self.c * x - self.c * self.w + 1,
+                                0.5 * x / self.w + 0.5))
+        return z
 
-  def surrogate_grad(self, x):
-    dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w)
-    return dx
+    def surrogate_grad(self, x):
+        dx = jnp.where(jnp.abs(x) > self.w, self.c, 1 / self.w)
+        return dx
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
+    def __repr__(self):
+        return f'{self.__class__.__name__}(c={self.c}, w={self.w})'
 
-  def __hash__(self):
-    return hash((self.__class__, self.c, self.w))
+    def __hash__(self):
+        return hash((self.__class__, self.c, self.w))
 
 
 def piecewise_leaky_relu(
@@ -811,119 +819,120 @@ def piecewise_leaky_relu(
     w: float = 1.,
 
 ):
-  r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.
+    r"""Judge spiking state with a piecewise leaky relu function [1]_ [2]_ [3]_ [4]_ [5]_ [6]_ [7]_ [8]_.
 
-  If `origin=False`, computes the forward function:
+    If `origin=False`, computes the forward function:
 
-  .. math::
+    .. math::
 
-     g(x) = \begin{cases}
-     1, & x \geq 0 \\
-     0, & x < 0 \\
-     \end{cases}
+       g(x) = \begin{cases}
+       1, & x \geq 0 \\
+       0, & x < 0 \\
+       \end{cases}
 
-  If `origin=True`, computes the original function:
+    If `origin=True`, computes the original function:
 
-  .. math::
+    .. math::
 
-     \begin{split}g(x) =
-     \begin{cases}
-     cx + cw, & x < -w \\
-     \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
-     cx - cw + 1, & x > w \\
-     \end{cases}\end{split}
+       \begin{split}g(x) =
+       \begin{cases}
+       cx + cw, & x < -w \\
+       \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\
+       cx - cw + 1, & x > w \\
+       \end{cases}\end{split}
 
-  Backward function:
+    Backward function:
 
-  .. math::
+    .. math::
 
-     \begin{split}g'(x) =
-     \begin{cases}
-     \frac{1}{w}, & |x| \leq w \\
-     c, & |x| > w
-     \end{cases}\end{split}
+       \begin{split}g'(x) =
+       \begin{cases}
+       \frac{1}{w}, & |x| \leq w \\
+       c, & |x| > w
+       \end{cases}\end{split}
 
-  .. plot::
-     :include-source: True
+    .. plot::
+       :include-source: True
 
-     >>> import brainstate.nn as nn
-     >>> import brainstate as bst
-     >>> import matplotlib.pyplot as plt
-     >>> xs = jax.numpy.linspace(-3, 3, 1000)
-     >>> for c in [0.01, 0.05, 0.1]:
-     >>>   for w in [1., 2.]:
-     >>>     grads1 = bst.transform.vector_grad(nn.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
-     >>>     plt.plot(xs, grads1, label=f'x={c}, w={w}')
-     >>> plt.legend()
-     >>> plt.show()
+       >>> import jax
+       >>> import brainstate.nn as nn
+       >>> import brainstate as bst
+       >>> import matplotlib.pyplot as plt
+       >>> xs = jax.numpy.linspace(-3, 3, 1000)
+       >>> for c in [0.01, 0.05, 0.1]:
+       >>>     for w in [1., 2.]:
+       >>>         grads1 = bst.augment.vector_grad(bst.surrogate.piecewise_leaky_relu)(xs, c=c, w=w)
+       >>>         plt.plot(xs, grads1, label=f'x={c}, w={w}')
+       >>> plt.legend()
+       >>> plt.show()
 
-  Parameters
-  ----------
-  x: jax.Array, Array
-    The input data.
-  c: float
-    When :math:`|x| > w` the gradient is `c`.
-  w: float
-    When :math:`|x| <= w` the gradient is `1 / w`.
+    Parameters
+    ----------
+    x: jax.Array, Array
+      The input data.
+    c: float
+      When :math:`|x| > w` the gradient is `c`.
+    w: float
+      When :math:`|x| <= w` the gradient is `1 / w`.
 
 
-  Returns
-  -------
-  out: jax.Array
-    The spiking state.
+    Returns
+    -------
+    out: jax.Array
+      The spiking state.
 
-  References
-  ----------
-  .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
-  .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
-  .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
-  .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
-  .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
-  .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
-  .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525.
-  .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.
+    References
+    ----------
+    .. [1] Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
+    .. [2] Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in neuroscience, 2018, 12: 331.
+    .. [3] Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
+    .. [4] Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
+    .. [5] Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
+    .. [6] Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
+    .. [7] Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 1519-1525.
+    .. [8] Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.
 
-  """
-  return PiecewiseLeakyRelu(c=c, w=w)(x)
+    """
+    return PiecewiseLeakyRelu(c=c, w=w)(x)
 
 
 class SquarewaveFourierSeries(Surrogate):
-  """Judge spiking state with a squarewave fourier series.
+    """Judge spiking state with a squarewave fourier series.
 
-  See Also
-  --------
-  squarewave_fourier_series
-  """
+    See Also
+    --------
+    squarewave_fourier_series
+    """
 
-  def __init__(self, n=2, t_period=8.):
-    super().__init__()
-    self.n = n
-    self.t_period = t_period
+    def __init__(self, n=2, t_period=8.):
+        super().__init__()
+        self.n = n
+        self.t_period = t_period
 
-  def surrogate_grad(self, x):
+    def surrogate_grad(self, x):
 
-    w = jnp.pi * 2. / self.t_period
-    dx = jnp.cos(w * x)
-    for i in range(2, self.n):
-      dx += jnp.cos((2 * i - 1.) * w * x)
-    dx *= 4. / self.t_period
-    return dx
+        w = jnp.pi * 2. / self.t_period
+        dx = jnp.cos(w * x)
+        for i in range(2, self.n):
+            dx += jnp.cos((2 * i - 1.) * w * x)
+        dx *= 4. / self.t_period
+        return dx
 
-  def surrogate_fun(self, x):
+    def surrogate_fun(self, x):
 
-    w = jnp.pi * 2. / self.t_period
-    ret = jnp.sin(w * x)
-    for i in range(2, self.n):
-      c = (2 * i - 1.)
-      ret += jnp.sin(c * w * x) / c
-    z = 0.5 + 2. / jnp.pi * ret
-    return z
+        w = jnp.pi * 2. / self.t_period
+        ret = jnp.sin(w * x)
+        for i in range(2, self.n):
+            c = (2 * i - 1.)
+            ret += jnp.sin(c * w * x) / c
+        z = 0.5 + 2. / jnp.pi * ret
+        return z
 
-  def __repr__(self):
-    return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
+    def __repr__(self):
+        return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})'
 
-  def __hash__(self):
-    return hash((self.__class__, self.n, self.t_period))
+    def __hash__(self):
+        return hash((self.__class__, self.n, self.t_period))
 
 
 def squarewave_fourier_series(
@@ -932,91 +941,92 @@ def squarewave_fourier_series(
932
941
  t_period: float = 8.,
933
942
 
934
943
  ):
935
- r"""Judge spiking state with a squarewave fourier series.
944
+ r"""Judge spiking state with a squarewave fourier series.
936
945
 
937
- If `origin=False`, computes the forward function:
946
+ If `origin=False`, computes the forward function:
938
947
 
939
- .. math::
948
+ .. math::
940
949
 
941
- g(x) = \begin{cases}
942
- 1, & x \geq 0 \\
943
- 0, & x < 0 \\
944
- \end{cases}
950
+ g(x) = \begin{cases}
951
+ 1, & x \geq 0 \\
952
+ 0, & x < 0 \\
953
+ \end{cases}
945
954
 
946
- If `origin=True`, computes the original function:
955
+ If `origin=True`, computes the original function:
947
956
 
948
- .. math::
957
+ .. math::
949
958
 
950
- g(x) = 0.5 + \frac{1}{\pi}*\sum_{i=1}^n {\sin\left({(2i-1)*2\pi}*x/T\right) \over 2i-1 }
959
+ g(x) = 0.5 + \frac{2}{\pi}\sum_{i=1}^n \frac{\sin\left((2i-1) \cdot 2\pi x / T\right)}{2i-1}
951
960
 
952
- Backward function:
961
+ Backward function:
953
962
 
954
- .. math::
963
+ .. math::
955
964
 
956
- g'(x) = \sum_{i=1}^n\frac{4\cos\left((2 * i - 1.) * 2\pi * x / T\right)}{T}
965
+ g'(x) = \frac{4}{T} \sum_{i=1}^n \cos\left((2i-1) \cdot 2\pi x / T\right)
957
966
 
958
- .. plot::
959
- :include-source: True
967
+ .. plot::
968
+ :include-source: True
960
969
 
961
- >>> import brainstate.nn as nn
962
- >>> import brainstate as bst
963
- >>> import matplotlib.pyplot as plt
964
- >>> xs = jax.numpy.linspace(-3, 3, 1000)
965
- >>> for n in [2, 4, 8]:
966
- >>> f = nn.surrogate.SquarewaveFourierSeries(n=n)
967
- >>> grads1 = bst.transform.vector_grad(f)(xs)
968
- >>> plt.plot(xs, grads1, label=f'n={n}')
969
- >>> plt.legend()
970
- >>> plt.show()
970
+ >>> import jax
971
+ >>> import brainstate.nn as nn
972
+ >>> import brainstate as bst
973
+ >>> import matplotlib.pyplot as plt
974
+ >>> xs = jax.numpy.linspace(-3, 3, 1000)
975
+ >>> for n in [2, 4, 8]:
976
+ >>> f = bst.surrogate.SquarewaveFourierSeries(n=n)
977
+ >>> grads1 = bst.augment.vector_grad(f)(xs)
978
+ >>> plt.plot(xs, grads1, label=f'n={n}')
979
+ >>> plt.legend()
980
+ >>> plt.show()
971
981
 
972
- Parameters
973
- ----------
974
- x: jax.Array, Array
975
- The input data.
976
- n: int
977
- t_period: float
982
+ Parameters
983
+ ----------
984
+ x: jax.Array, Array
985
+ The input data.
986
+ n: int
+ The number of terms used in the Fourier series approximation.
987
+ t_period: float
+ The period :math:`T` of the square wave.
978
988
 
979
989
 
980
- Returns
981
- -------
982
- out: jax.Array
983
- The spiking state.
990
+ Returns
991
+ -------
992
+ out: jax.Array
993
+ The spiking state.
984
994
 
985
- """
995
+ """
986
996
 
987
- return SquarewaveFourierSeries(n=n, t_period=t_period)(x)
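As a quick numerical cross-check of the two formulas above, the analytic surrogate_grad should match a centered finite difference of surrogate_fun. A standalone sketch (jax.numpy only; sq_fun and sq_grad are illustrative restatements of the two method bodies in this hunk):

import jax.numpy as jnp

def sq_fun(x, n=2, t_period=8.):
    # surrogate_fun above, restated standalone
    w = jnp.pi * 2. / t_period
    ret = jnp.sin(w * x)
    for i in range(2, n):
        ret += jnp.sin((2 * i - 1.) * w * x) / (2 * i - 1.)
    return 0.5 + 2. / jnp.pi * ret

def sq_grad(x, n=2, t_period=8.):
    # surrogate_grad above, restated standalone
    w = jnp.pi * 2. / t_period
    dx = jnp.cos(w * x)
    for i in range(2, n):
        dx += jnp.cos((2 * i - 1.) * w * x)
    return dx * 4. / t_period

x, eps = jnp.linspace(-3., 3., 7), 1e-3
fd = (sq_fun(x + eps, n=4) - sq_fun(x - eps, n=4)) / (2 * eps)
print(jnp.max(jnp.abs(fd - sq_grad(x, n=4))))  # ~0, up to finite-difference error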
997
+ return SquarewaveFourierSeries(n=n, t_period=t_period)(x)
988
998
 
989
999
 
990
1000
  class S2NN(Surrogate):
991
- """Judge spiking state with the S2NN surrogate spiking function.
1001
+ """Judge spiking state with the S2NN surrogate spiking function.
992
1002
 
993
- See Also
994
- --------
995
- s2nn
996
- """
1003
+ See Also
1004
+ --------
1005
+ s2nn
1006
+ """
997
1007
 
998
- def __init__(self, alpha=4., beta=1., epsilon=1e-8):
999
- super().__init__()
1000
- self.alpha = alpha
1001
- self.beta = beta
1002
- self.epsilon = epsilon
1008
+ def __init__(self, alpha=4., beta=1., epsilon=1e-8):
1009
+ super().__init__()
1010
+ self.alpha = alpha
1011
+ self.beta = beta
1012
+ self.epsilon = epsilon
1003
1013
 
1004
- def surrogate_fun(self, x):
1005
- z = jnp.where(x < 0.,
1006
- sci.special.expit(x * self.alpha),
1007
- self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5)
1008
- return z
1014
+ def surrogate_fun(self, x):
1015
+ z = jnp.where(x < 0.,
1016
+ sci.special.expit(x * self.alpha),
1017
+ self.beta * jnp.log(jnp.abs((x + 1.)) + self.epsilon) + 0.5)
1018
+ return z
1009
1019
 
1010
- def surrogate_grad(self, x):
1011
- sg = sci.special.expit(self.alpha * x)
1012
- dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.))
1013
- return dx
1020
+ def surrogate_grad(self, x):
1021
+ sg = sci.special.expit(self.alpha * x)
1022
+ dx = jnp.where(x < 0., self.alpha * sg * (1. - sg), self.beta / (x + 1.))
1023
+ return dx
1014
1024
 
1015
- def __repr__(self):
1016
- return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
1025
+ def __repr__(self):
1026
+ return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})'
1017
1027
 
1018
- def __hash__(self):
1019
- return hash((self.__class__, self.alpha, self.beta, self.epsilon))
1028
+ def __hash__(self):
1029
+ return hash((self.__class__, self.alpha, self.beta, self.epsilon))
1020
1030
 
1021
1031
 
1022
1032
  def s2nn(
@@ -1026,101 +1036,102 @@ def s2nn(
1026
1036
  epsilon: float = 1e-8,
1027
1037
 
1028
1038
  ):
1029
- r"""Judge spiking state with the S2NN surrogate spiking function [1]_.
1030
-
1031
- If `origin=False`, computes the forward function:
1032
-
1033
- .. math::
1034
-
1035
- g(x) = \begin{cases}
1036
- 1, & x \geq 0 \\
1037
- 0, & x < 0 \\
1038
- \end{cases}
1039
+ r"""Judge spiking state with the S2NN surrogate spiking function [1]_.
1039
1040
 
1040
- If `origin=True`, computes the original function:
1041
+ If `origin=False`, computes the forward function:
1041
1042
 
1042
- .. math::
1043
+ .. math::
1043
1044
 
1044
- \begin{split}g(x) = \begin{cases}
1045
- \mathrm{sigmoid} (\alpha x), x < 0 \\
1046
- \beta \ln(|x + 1|) + 0.5, x \ge 0
1047
- \end{cases}\end{split}
1045
+ g(x) = \begin{cases}
1046
+ 1, & x \geq 0 \\
1047
+ 0, & x < 0 \\
1048
+ \end{cases}
1048
1049
 
1049
- Backward function:
1050
+ If `origin=True`, computes the original function:
1050
1051
 
1051
- .. math::
1052
+ .. math::
1052
1053
 
1053
- \begin{split}g'(x) = \begin{cases}
1054
- \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), x < 0 \\
1055
- \frac{\beta}{(x + 1)}, x \ge 0
1056
- \end{cases}\end{split}
1057
-
1058
- .. plot::
1059
- :include-source: True
1060
-
1061
- >>> import brainstate.nn as nn
1062
- >>> import brainstate as bst
1063
- >>> import matplotlib.pyplot as plt
1064
- >>> xs = jax.numpy.linspace(-3, 3, 1000)
1065
- >>> grads = bst.transform.vector_grad(nn.surrogate.s2nn)(xs, 4., 1.)
1066
- >>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
1067
- >>> grads = bst.transform.vector_grad(nn.surrogate.s2nn)(xs, 8., 2.)
1068
- >>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
1069
- >>> plt.legend()
1070
- >>> plt.show()
1071
-
1072
- Parameters
1073
- ----------
1074
- x: jax.Array, Array
1075
- The input data.
1076
- alpha: float
1077
- The param that controls the gradient when ``x < 0``.
1078
- beta: float
1079
- The param that controls the gradient when ``x >= 0``
1080
- epsilon: float
1081
- Avoid nan
1054
+ \begin{split}g(x) = \begin{cases}
1055
+ \mathrm{sigmoid}(\alpha x), & x < 0 \\
1056
+ \beta \ln(|x + 1|) + 0.5, & x \ge 0
1057
+ \end{cases}\end{split}
1082
1058
 
1059
+ Backward function:
1083
1060
 
1084
- Returns
1085
- -------
1086
- out: jax.Array
1087
- The spiking state.
1061
+ .. math::
1088
1062
 
1089
- References
1090
- ----------
1091
- .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag.
1063
+ \begin{split}g'(x) = \begin{cases}
1064
+ \alpha * (1 - \mathrm{sigmoid}(\alpha x)) \mathrm{sigmoid}(\alpha x), & x < 0 \\
1065
+ \frac{\beta}{(x + 1)}, & x \ge 0
1066
+ \end{cases}\end{split}
1092
1067
 
1093
- """
1094
- return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x)
1068
+ .. plot::
1069
+ :include-source: True
1070
+
1071
+ >>> import jax
1072
+ >>> import brainstate.nn as nn
1073
+ >>> import brainstate as bst
1074
+ >>> import matplotlib.pyplot as plt
1075
+ >>> xs = jax.numpy.linspace(-3, 3, 1000)
1076
+ >>> grads = bst.augment.vector_grad(bst.surrogate.s2nn)(xs, 4., 1.)
1077
+ >>> plt.plot(xs, grads, label=r'$\alpha=4, \beta=1$')
1078
+ >>> grads = bst.augment.vector_grad(bst.surrogate.s2nn)(xs, 8., 2.)
1079
+ >>> plt.plot(xs, grads, label=r'$\alpha=8, \beta=2$')
1080
+ >>> plt.legend()
1081
+ >>> plt.show()
1082
+
1083
+ Parameters
1084
+ ----------
1085
+ x: jax.Array, Array
1086
+ The input data.
1087
+ alpha: float
1088
+ The parameter that controls the gradient when ``x < 0``.
1089
+ beta: float
1090
+ The parameter that controls the gradient when ``x >= 0``.
1091
+ epsilon: float
1092
+ A small constant added inside the logarithm to avoid ``nan``.
1093
+
1094
+
1095
+ Returns
1096
+ -------
1097
+ out: jax.Array
1098
+ The spiking state.
1099
+
1100
+ References
1101
+ ----------
1102
+ .. [1] Suetake, Kazuma et al. “S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks.” ArXiv abs/2201.10879 (2022): n. pag.
1103
+
1104
+ """
1105
+ return S2NN(alpha=alpha, beta=beta, epsilon=epsilon)(x)
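One property worth noting: with the defaults alpha=4, beta=1, the two branches of the backward rule meet at the threshold, since alpha * sigmoid(0) * (1 - sigmoid(0)) = alpha / 4 = beta, so the surrogate gradient is continuous at x = 0. A standalone check (s2nn_grad is an illustrative restatement of surrogate_grad above):

import jax.numpy as jnp
from jax.scipy.special import expit

def s2nn_grad(x, alpha=4., beta=1.):
    # The two-branch backward rule from the docstring.
    sg = expit(alpha * x)
    return jnp.where(x < 0., alpha * sg * (1. - sg), beta / (x + 1.))

print(s2nn_grad(jnp.array([-1e-6, 0., 1e-6])))  # all ~1.0: continuous at the threshold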
1095
1106
 
1096
1107
 
1097
1108
  class QPseudoSpike(Surrogate):
1098
- """Judge spiking state with the q-PseudoSpike surrogate function.
1109
+ """Judge spiking state with the q-PseudoSpike surrogate function.
1099
1110
 
1100
- See Also
1101
- --------
1102
- q_pseudo_spike
1103
- """
1111
+ See Also
1112
+ --------
1113
+ q_pseudo_spike
1114
+ """
1104
1115
 
1105
- def __init__(self, alpha=2.):
1106
- super().__init__()
1107
- self.alpha = alpha
1116
+ def __init__(self, alpha=2.):
1117
+ super().__init__()
1118
+ self.alpha = alpha
1108
1119
 
1109
- def surrogate_grad(self, x):
1110
- dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha)
1111
- return dx
1120
+ def surrogate_grad(self, x):
1121
+ dx = jnp.power(1 + 2 / (self.alpha + 1) * jnp.abs(x), -self.alpha)
1122
+ return dx
1112
1123
 
1113
- def surrogate_fun(self, x):
1114
- z = jnp.where(x < 0.,
1115
- 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
1116
- 1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
1117
- return z
1124
+ def surrogate_fun(self, x):
1125
+ z = jnp.where(x < 0.,
1126
+ 0.5 * jnp.power(1 - 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha),
1127
+ 1. - 0.5 * jnp.power(1 + 2 / (self.alpha - 1) * jnp.abs(x), 1 - self.alpha))
1128
+ return z
1118
1129
 
1119
- def __repr__(self):
1120
- return f'{self.__class__.__name__}(alpha={self.alpha})'
1130
+ def __repr__(self):
1131
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
1121
1132
 
1122
- def __hash__(self):
1123
- return hash((self.__class__, self.alpha))
1133
+ def __hash__(self):
1134
+ return hash((self.__class__, self.alpha))
1124
1135
 
1125
1136
 
1126
1137
  def q_pseudo_spike(
@@ -1128,91 +1139,92 @@ def q_pseudo_spike(
1128
1139
  alpha: float = 2.,
1129
1140
 
1130
1141
  ):
1131
- r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.
1142
+ r"""Judge spiking state with the q-PseudoSpike surrogate function [1]_.
1132
1143
 
1133
- If `origin=False`, computes the forward function:
1144
+ If `origin=False`, computes the forward function:
1134
1145
 
1135
- .. math::
1146
+ .. math::
1136
1147
 
1137
- g(x) = \begin{cases}
1138
- 1, & x \geq 0 \\
1139
- 0, & x < 0 \\
1140
- \end{cases}
1148
+ g(x) = \begin{cases}
1149
+ 1, & x \geq 0 \\
1150
+ 0, & x < 0 \\
1151
+ \end{cases}
1141
1152
 
1142
- If `origin=True`, computes the original function:
1153
+ If `origin=True`, computes the original function:
1143
1154
 
1144
- .. math::
1155
+ .. math::
1145
1156
 
1146
- \begin{split}g(x) =
1147
- \begin{cases}
1148
- \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
1149
- 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
1150
- \end{cases}\end{split}
1157
+ \begin{split}g(x) =
1158
+ \begin{cases}
1159
+ \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\
1160
+ 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0.
1161
+ \end{cases}\end{split}
1151
1162
 
1152
- Backward function:
1163
+ Backward function:
1153
1164
 
1154
- .. math::
1165
+ .. math::
1155
1166
 
1156
- g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}
1167
+ g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}
1157
1168
 
1158
- .. plot::
1159
- :include-source: True
1169
+ .. plot::
1170
+ :include-source: True
1160
1171
 
1161
- >>> import brainstate.nn as nn
1162
- >>> import brainstate as bst
1163
- >>> import matplotlib.pyplot as plt
1164
- >>> xs = jax.numpy.linspace(-3, 3, 1000)
1165
- >>> for alpha in [0.5, 1., 2., 4.]:
1166
- >>> grads = bst.transform.vector_grad(nn.surrogate.q_pseudo_spike)(xs, alpha)
1167
- >>> plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
1168
- >>> plt.legend()
1169
- >>> plt.show()
1172
+ >>> import jax
1173
+ >>> import brainstate.nn as nn
1174
+ >>> import brainstate as bst
1175
+ >>> import matplotlib.pyplot as plt
1176
+ >>> xs = jax.numpy.linspace(-3, 3, 1000)
1177
+ >>> for alpha in [0.5, 1., 2., 4.]:
1178
+ >>> grads = bst.augment.vector_grad(bst.surrogate.q_pseudo_spike)(xs, alpha)
1179
+ >>> plt.plot(xs, grads, label=r'$\alpha=$' + str(alpha))
1180
+ >>> plt.legend()
1181
+ >>> plt.show()
1170
1182
 
1171
- Parameters
1172
- ----------
1173
- x: jax.Array, Array
1174
- The input data.
1175
- alpha: float
1176
- The parameter to control tail fatness of gradient.
1183
+ Parameters
1184
+ ----------
1185
+ x: jax.Array, Array
1186
+ The input data.
1187
+ alpha: float
1188
+ The parameter to control the tail fatness of the gradient.
1177
1189
 
1178
1190
 
1179
- Returns
1180
- -------
1181
- out: jax.Array
1182
- The spiking state.
1191
+ Returns
1192
+ -------
1193
+ out: jax.Array
1194
+ The spiking state.
1183
1195
 
1184
- References
1185
- ----------
1186
- .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag.
1187
- """
1188
- return QPseudoSpike(alpha=alpha)(x)
1196
+ References
1197
+ ----------
1198
+ .. [1] Herranz-Celotti, Luca and Jean Rouat. “Surrogate Gradients Design.” ArXiv abs/2202.00282 (2022): n. pag.
1199
+ """
1200
+ return QPseudoSpike(alpha=alpha)(x)
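A standalone look at the alpha "tail fatness" knob, mirroring the backward rule as implemented in QPseudoSpike.surrogate_grad above (q_grad is an illustrative name): away from the threshold, larger alpha decays the gradient faster.

import jax.numpy as jnp

def q_grad(x, alpha):
    # The backward rule as implemented in surrogate_grad above.
    return jnp.power(1 + 2 / (alpha + 1) * jnp.abs(x), -alpha)

for alpha in (1., 2., 4.):
    print(alpha, q_grad(jnp.asarray(2.0), alpha))
# 1.0 -> ~0.333, 2.0 -> ~0.184, 4.0 -> ~0.095: larger alpha, thinner tail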
1189
1201
 
1190
1202
 
1191
1203
  class LeakyRelu(Surrogate):
1192
- """Judge spiking state with the Leaky ReLU function.
1204
+ """Judge spiking state with the Leaky ReLU function.
1193
1205
 
1194
- See Also
1195
- --------
1196
- leaky_relu
1197
- """
1206
+ See Also
1207
+ --------
1208
+ leaky_relu
1209
+ """
1198
1210
 
1199
- def __init__(self, alpha=0.1, beta=1.):
1200
- super().__init__()
1201
- self.alpha = alpha
1202
- self.beta = beta
1211
+ def __init__(self, alpha=0.1, beta=1.):
1212
+ super().__init__()
1213
+ self.alpha = alpha
1214
+ self.beta = beta
1203
1215
 
1204
- def surrogate_fun(self, x):
1205
- return jnp.where(x < 0., self.alpha * x, self.beta * x)
1216
+ def surrogate_fun(self, x):
1217
+ return jnp.where(x < 0., self.alpha * x, self.beta * x)
1206
1218
 
1207
- def surrogate_grad(self, x):
1208
- dx = jnp.where(x < 0., self.alpha, self.beta)
1209
- return dx
1219
+ def surrogate_grad(self, x):
1220
+ dx = jnp.where(x < 0., self.alpha, self.beta)
1221
+ return dx
1210
1222
 
1211
- def __repr__(self):
1212
- return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'
1223
+ def __repr__(self):
1224
+ return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})'
1213
1225
 
1214
- def __hash__(self):
1215
- return hash((self.__class__, self.alpha, self.beta))
1226
+ def __hash__(self):
1227
+ return hash((self.__class__, self.alpha, self.beta))
1216
1228
 
1217
1229
 
1218
1230
  def leaky_relu(
@@ -1221,100 +1233,101 @@ def leaky_relu(
1221
1233
  beta: float = 1.,
1222
1234
 
1223
1235
  ):
1224
- r"""Judge spiking state with the Leaky ReLU function.
1236
+ r"""Judge spiking state with the Leaky ReLU function.
1225
1237
 
1226
- If `origin=False`, computes the forward function:
1227
-
1228
- .. math::
1229
-
1230
- g(x) = \begin{cases}
1231
- 1, & x \geq 0 \\
1232
- 0, & x < 0 \\
1233
- \end{cases}
1238
+ If `origin=False`, computes the forward function:
1234
1239
 
1235
- If `origin=True`, computes the original function:
1240
+ .. math::
1236
1241
 
1237
- .. math::
1242
+ g(x) = \begin{cases}
1243
+ 1, & x \geq 0 \\
1244
+ 0, & x < 0 \\
1245
+ \end{cases}
1238
1246
 
1239
- \begin{split}g(x) =
1240
- \begin{cases}
1241
- \beta \cdot x, & x \geq 0 \\
1242
- \alpha \cdot x, & x < 0 \\
1243
- \end{cases}\end{split}
1244
-
1245
- Backward function:
1246
-
1247
- .. math::
1248
-
1249
- \begin{split}g'(x) =
1250
- \begin{cases}
1251
- \beta, & x \geq 0 \\
1252
- \alpha, & x < 0 \\
1253
- \end{cases}\end{split}
1247
+ If `origin=True`, computes the original function:
1254
1248
 
1255
- .. plot::
1256
- :include-source: True
1249
+ .. math::
1257
1250
 
1258
- >>> import brainstate.nn as nn
1259
- >>> import brainstate as bst
1260
- >>> import matplotlib.pyplot as plt
1261
- >>> xs = jax.numpy.linspace(-3, 3, 1000)
1262
- >>> grads = bst.transform.vector_grad(nn.surrogate.leaky_relu)(xs, 0., 1.)
1263
- >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
1264
- >>> plt.legend()
1265
- >>> plt.show()
1251
+ \begin{split}g(x) =
1252
+ \begin{cases}
1253
+ \beta \cdot x, & x \geq 0 \\
1254
+ \alpha \cdot x, & x < 0 \\
1255
+ \end{cases}\end{split}
1266
1256
 
1267
- Parameters
1268
- ----------
1269
- x: jax.Array, Array
1270
- The input data.
1271
- alpha: float
1272
- The parameter to control the gradient when :math:`x < 0`.
1273
- beta: float
1274
- The parameter to control the gradient when :math:`x >= 0`.
1257
+ Backward function:
1275
1258
 
1259
+ .. math::
1276
1260
 
1277
- Returns
1278
- -------
1279
- out: jax.Array
1280
- The spiking state.
1281
- """
1282
- return LeakyRelu(alpha=alpha, beta=beta)(x)
1261
+ \begin{split}g'(x) =
1262
+ \begin{cases}
1263
+ \beta, & x \geq 0 \\
1264
+ \alpha, & x < 0 \\
1265
+ \end{cases}\end{split}
1266
+
1267
+ .. plot::
1268
+ :include-source: True
1269
+
1270
+ >>> import jax
1271
+ >>> import brainstate.nn as nn
1272
+ >>> import brainstate as bst
1273
+ >>> import matplotlib.pyplot as plt
1274
+ >>> xs = jax.numpy.linspace(-3, 3, 1000)
1275
+ >>> grads = bst.augment.vector_grad(bst.surrogate.leaky_relu)(xs, 0., 1.)
1276
+ >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
1277
+ >>> plt.legend()
1278
+ >>> plt.show()
1279
+
1280
+ Parameters
1281
+ ----------
1282
+ x: jax.Array, Array
1283
+ The input data.
1284
+ alpha: float
1285
+ The parameter to control the gradient when :math:`x < 0`.
1286
+ beta: float
1287
+ The parameter to control the gradient when :math:`x \geq 0`.
1288
+
1289
+
1290
+ Returns
1291
+ -------
1292
+ out: jax.Array
1293
+ The spiking state.
1294
+ """
1295
+ return LeakyRelu(alpha=alpha, beta=beta)(x)
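The backward rule here is just a two-valued step; with alpha=0, beta=1 it degenerates to the plain straight-through ReLU-style estimator. A two-line standalone restatement (leaky_relu_grad is an illustrative name):

import jax.numpy as jnp

def leaky_relu_grad(x, alpha=0.1, beta=1.):
    # The backward rule above: alpha on the negative side, beta elsewhere.
    return jnp.where(x < 0., alpha, beta)

print(leaky_relu_grad(jnp.array([-1., 1.]), alpha=0., beta=1.))  # [0. 1.]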
1283
1296
 
1284
1297
 
1285
1298
  class LogTailedRelu(Surrogate):
1286
- """Judge spiking state with the Log-tailed ReLU function.
1299
+ """Judge spiking state with the Log-tailed ReLU function.
1287
1300
 
1288
- See Also
1289
- --------
1290
- log_tailed_relu
1291
- """
1301
+ See Also
1302
+ --------
1303
+ log_tailed_relu
1304
+ """
1292
1305
 
1293
- def __init__(self, alpha=0.):
1294
- super().__init__()
1295
- self.alpha = alpha
1306
+ def __init__(self, alpha=0.):
1307
+ super().__init__()
1308
+ self.alpha = alpha
1296
1309
 
1297
- def surrogate_fun(self, x):
1298
- z = jnp.where(x > 1,
1299
- jnp.log(x),
1300
- jnp.where(x > 0,
1301
- x,
1302
- self.alpha * x))
1303
- return z
1310
+ def surrogate_fun(self, x):
1311
+ z = jnp.where(x > 1,
1312
+ jnp.log(x),
1313
+ jnp.where(x > 0,
1314
+ x,
1315
+ self.alpha * x))
1316
+ return z
1304
1317
 
1305
- def surrogate_grad(self, x):
1306
- dx = jnp.where(x > 1,
1307
- 1 / x,
1308
- jnp.where(x > 0,
1309
- 1.,
1310
- self.alpha))
1311
- return dx
1318
+ def surrogate_grad(self, x):
1319
+ dx = jnp.where(x > 1,
1320
+ 1 / x,
1321
+ jnp.where(x > 0,
1322
+ 1.,
1323
+ self.alpha))
1324
+ return dx
1312
1325
 
1313
- def __repr__(self):
1314
- return f'{self.__class__.__name__}(alpha={self.alpha})'
1326
+ def __repr__(self):
1327
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
1315
1328
 
1316
- def __hash__(self):
1317
- return hash((self.__class__, self.alpha))
1329
+ def __hash__(self):
1330
+ return hash((self.__class__, self.alpha))
1318
1331
 
1319
1332
 
1320
1333
  def log_tailed_relu(
@@ -1322,93 +1335,94 @@ def log_tailed_relu(
1322
1335
  alpha: float = 0.,
1323
1336
 
1324
1337
  ):
1325
- r"""Judge spiking state with the Log-tailed ReLU function [1]_.
1326
-
1327
- If `origin=False`, computes the forward function:
1328
-
1329
- .. math::
1330
-
1331
- g(x) = \begin{cases}
1332
- 1, & x \geq 0 \\
1333
- 0, & x < 0 \\
1334
- \end{cases}
1338
+ r"""Judge spiking state with the Log-tailed ReLU function [1]_.
1335
1339
 
1336
- If `origin=True`, computes the original function:
1340
+ If `origin=False`, computes the forward function:
1337
1341
 
1338
- .. math::
1342
+ .. math::
1339
1343
 
1340
- \begin{split}g(x) =
1341
- \begin{cases}
1342
- \alpha x, & x \leq 0 \\
1343
- x, & 0 < x \leq 0 \\
1344
- log(x), x > 1 \\
1345
- \end{cases}\end{split}
1346
-
1347
- Backward function:
1348
-
1349
- .. math::
1350
-
1351
- \begin{split}g'(x) =
1352
- \begin{cases}
1353
- \alpha, & x \leq 0 \\
1354
- 1, & 0 < x \leq 0 \\
1355
- \frac{1}{x}, x > 1 \\
1356
- \end{cases}\end{split}
1344
+ g(x) = \begin{cases}
1345
+ 1, & x \geq 0 \\
1346
+ 0, & x < 0 \\
1347
+ \end{cases}
1357
1348
 
1358
- .. plot::
1359
- :include-source: True
1349
+ If `origin=True`, computes the original function:
1360
1350
 
1361
- >>> import brainstate.nn as nn
1362
- >>> import brainstate as bst
1363
- >>> import matplotlib.pyplot as plt
1364
- >>> xs = jax.numpy.linspace(-3, 3, 1000)
1365
- >>> grads = bst.transform.vector_grad(nn.surrogate.leaky_relu)(xs, 0., 1.)
1366
- >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
1367
- >>> plt.legend()
1368
- >>> plt.show()
1351
+ .. math::
1369
1352
 
1370
- Parameters
1371
- ----------
1372
- x: jax.Array, Array
1373
- The input data.
1374
- alpha: float
1375
- The parameter to control the gradient.
1353
+ \begin{split}g(x) =
1354
+ \begin{cases}
1355
+ \alpha x, & x \leq 0 \\
1356
+ x, & 0 < x \leq 1 \\
1357
+ \log(x), & x > 1 \\
1358
+ \end{cases}\end{split}
1376
1359
 
1360
+ Backward function:
1377
1361
 
1378
- Returns
1379
- -------
1380
- out: jax.Array
1381
- The spiking state.
1362
+ .. math::
1382
1363
 
1383
- References
1384
- ----------
1385
- .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414.
1386
- """
1387
- return LogTailedRelu(alpha=alpha)(x)
1364
+ \begin{split}g'(x) =
1365
+ \begin{cases}
1366
+ \alpha, & x \leq 0 \\
1367
+ 1, & 0 < x \leq 1 \\
1368
+ \frac{1}{x}, & x > 1 \\
1369
+ \end{cases}\end{split}
1370
+
1371
+ .. plot::
1372
+ :include-source: True
1373
+
1374
+ >>> import jax
1375
+ >>> import brainstate.nn as nn
1376
+ >>> import brainstate as bst
1377
+ >>> import matplotlib.pyplot as plt
1378
+ >>> xs = jax.numpy.linspace(-3, 3, 1000)
1379
+ >>> grads = bst.augment.vector_grad(bst.surrogate.leaky_relu)(xs, 0., 1.)
1380
+ >>> plt.plot(xs, grads, label=r'$\alpha=0., \beta=1.$')
1381
+ >>> plt.legend()
1382
+ >>> plt.show()
1383
+
1384
+ Parameters
1385
+ ----------
1386
+ x: jax.Array, Array
1387
+ The input data.
1388
+ alpha: float
1389
+ The parameter to control the gradient.
1390
+
1391
+
1392
+ Returns
1393
+ -------
1394
+ out: jax.Array
1395
+ The spiking state.
1396
+
1397
+ References
1398
+ ----------
1399
+ .. [1] Cai, Zhaowei et al. “Deep Learning with Low Precision by Half-Wave Gaussian Quantization.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2017): 5406-5414.
1400
+ """
1401
+ return LogTailedRelu(alpha=alpha)(x)
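A standalone finite-difference check of the three-piece rule above, evaluated away from the breakpoints at x = 0 and x = 1 (lt_fun and lt_grad are illustrative restatements of the two method bodies):

import jax.numpy as jnp

def lt_fun(x, alpha=0.):
    # alpha*x for x <= 0, x on (0, 1], log(x) for x > 1
    return jnp.where(x > 1, jnp.log(x), jnp.where(x > 0, x, alpha * x))

def lt_grad(x, alpha=0.):
    return jnp.where(x > 1, 1 / x, jnp.where(x > 0, 1., alpha))

x, eps = jnp.array([-2., 0.5, 3.]), 1e-4
fd = (lt_fun(x + eps) - lt_fun(x - eps)) / (2 * eps)
print(jnp.max(jnp.abs(fd - lt_grad(x))))  # ~0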
1388
1402
 
1389
1403
 
1390
1404
  class ReluGrad(Surrogate):
1391
- """Judge spiking state with the ReLU gradient function.
1405
+ """Judge spiking state with the ReLU gradient function.
1392
1406
 
1393
- See Also
1394
- --------
1395
- relu_grad
1396
- """
1407
+ See Also
1408
+ --------
1409
+ relu_grad
1410
+ """
1397
1411
 
1398
- def __init__(self, alpha=0.3, width=1.):
1399
- super().__init__()
1400
- self.alpha = alpha
1401
- self.width = width
1412
+ def __init__(self, alpha=0.3, width=1.):
1413
+ super().__init__()
1414
+ self.alpha = alpha
1415
+ self.width = width
1402
1416
 
1403
- def surrogate_grad(self, x):
1404
- dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0)
1405
- return dx
1417
+ def surrogate_grad(self, x):
1418
+ dx = jnp.maximum(self.alpha * self.width - jnp.abs(x) * self.alpha, 0)
1419
+ return dx
1406
1420
 
1407
- def __repr__(self):
1408
- return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'
1421
+ def __repr__(self):
1422
+ return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})'
1409
1423
 
1410
- def __hash__(self):
1411
- return hash((self.__class__, self.alpha, self.width))
1424
+ def __hash__(self):
1425
+ return hash((self.__class__, self.alpha, self.width))
1412
1426
 
1413
1427
 
1414
1428
  def relu_grad(
@@ -1416,80 +1430,81 @@ def relu_grad(
1416
1430
  alpha: float = 0.3,
1417
1431
  width: float = 1.,
1418
1432
  ):
1419
- r"""Spike function with the ReLU gradient function [1]_.
1433
+ r"""Spike function with the ReLU gradient function [1]_.
1420
1434
 
1421
- The forward function:
1435
+ The forward function:
1422
1436
 
1423
- .. math::
1437
+ .. math::
1424
1438
 
1425
- g(x) = \begin{cases}
1426
- 1, & x \geq 0 \\
1427
- 0, & x < 0 \\
1428
- \end{cases}
1429
-
1430
- Backward function:
1431
-
1432
- .. math::
1433
-
1434
- g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|))
1435
-
1436
- .. plot::
1437
- :include-source: True
1438
-
1439
- >>> import brainstate.nn as nn
1440
- >>> import brainstate as bst
1441
- >>> import matplotlib.pyplot as plt
1442
- >>> xs = jax.numpy.linspace(-3, 3, 1000)
1443
- >>> for s in [0.5, 1.]:
1444
- >>> for w in [1, 2.]:
1445
- >>> grads = bst.transform.vector_grad(nn.surrogate.relu_grad)(xs, s, w)
1446
- >>> plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
1447
- >>> plt.legend()
1448
- >>> plt.show()
1449
-
1450
- Parameters
1451
- ----------
1452
- x: jax.Array, Array
1453
- The input data.
1454
- alpha: float
1455
- The parameter to control the gradient.
1456
- width: float
1457
- The parameter to control the width of the gradient.
1458
-
1459
- Returns
1460
- -------
1461
- out: jax.Array
1462
- The spiking state.
1463
-
1464
- References
1465
- ----------
1466
- .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019).
1467
- """
1468
- return ReluGrad(alpha=alpha, width=width)(x)
1439
+ g(x) = \begin{cases}
1440
+ 1, & x \geq 0 \\
1441
+ 0, & x < 0 \\
1442
+ \end{cases}
1443
+
1444
+ Backward function:
1445
+
1446
+ .. math::
1447
+
1448
+ g'(x) = \text{ReLU}(\alpha * (\mathrm{width}-|x|))
1449
+
1450
+ .. plot::
1451
+ :include-source: True
1452
+
1453
+ >>> import jax
1454
+ >>> import brainstate.nn as nn
1455
+ >>> import brainstate as bst
1456
+ >>> import matplotlib.pyplot as plt
1457
+ >>> xs = jax.numpy.linspace(-3, 3, 1000)
1458
+ >>> for s in [0.5, 1.]:
1459
+ >>> for w in [1, 2.]:
1460
+ >>> grads = bst.augment.vector_grad(bst.surrogate.relu_grad)(xs, s, w)
1461
+ >>> plt.plot(xs, grads, label=r'$\alpha=$' + f'{s}, width={w}')
1462
+ >>> plt.legend()
1463
+ >>> plt.show()
1464
+
1465
+ Parameters
1466
+ ----------
1467
+ x: jax.Array, Array
1468
+ The input data.
1469
+ alpha: float
1470
+ The parameter to control the gradient.
1471
+ width: float
1472
+ The parameter to control the width of the gradient.
1473
+
1474
+ Returns
1475
+ -------
1476
+ out: jax.Array
1477
+ The spiking state.
1478
+
1479
+ References
1480
+ ----------
1481
+ .. [1] Neftci, E. O., Mostafa, H. & Zenke, F. Surrogate gradient learning in spiking neural networks. IEEE Signal Process. Mag. 36, 61–63 (2019).
1482
+ """
1483
+ return ReluGrad(alpha=alpha, width=width)(x)
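Geometrically, this backward rule is a triangle: nonzero only for |x| < width, with peak alpha * width at x = 0. A standalone restatement (relu_grad_fn is an illustrative name):

import jax.numpy as jnp

def relu_grad_fn(x, alpha=0.3, width=1.):
    # Triangle of height alpha*width centered on the threshold.
    return jnp.maximum(alpha * width - jnp.abs(x) * alpha, 0)

print(relu_grad_fn(jnp.array([-2., 0., 0.5, 2.])))  # [0.   0.3  0.15 0.  ]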
1469
1484
 
1470
1485
 
1471
1486
  class GaussianGrad(Surrogate):
1472
- """Judge spiking state with the Gaussian gradient function.
1487
+ """Judge spiking state with the Gaussian gradient function.
1473
1488
 
1474
- See Also
1475
- --------
1476
- gaussian_grad
1477
- """
1489
+ See Also
1490
+ --------
1491
+ gaussian_grad
1492
+ """
1478
1493
 
1479
- def __init__(self, sigma=0.5, alpha=0.5):
1480
- super().__init__()
1481
- self.sigma = sigma
1482
- self.alpha = alpha
1494
+ def __init__(self, sigma=0.5, alpha=0.5):
1495
+ super().__init__()
1496
+ self.sigma = sigma
1497
+ self.alpha = alpha
1483
1498
 
1484
- def surrogate_grad(self, x):
1485
- dx = jnp.exp(-(x ** 2) / 2 * jnp.power(self.sigma, 2)) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
1486
- return self.alpha * dx
1499
+ def surrogate_grad(self, x):
1500
+ dx = jnp.exp(-(x ** 2) / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
1501
+ return self.alpha * dx
1487
1502
 
1488
- def __repr__(self):
1489
- return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'
1503
+ def __repr__(self):
1504
+ return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})'
1490
1505
 
1491
- def __hash__(self):
1492
- return hash((self.__class__, self.alpha, self.sigma))
1506
+ def __hash__(self):
1507
+ return hash((self.__class__, self.alpha, self.sigma))
1493
1508
 
1494
1509
 
1495
1510
  def gaussian_grad(
@@ -1497,86 +1512,87 @@ def gaussian_grad(
1497
1512
  sigma: float = 0.5,
1498
1513
  alpha: float = 0.5,
1499
1514
  ):
1500
- r"""Spike function with the Gaussian gradient function [1]_.
1515
+ r"""Spike function with the Gaussian gradient function [1]_.
1501
1516
 
1502
- The forward function:
1517
+ The forward function:
1503
1518
 
1504
- .. math::
1519
+ .. math::
1505
1520
 
1506
- g(x) = \begin{cases}
1507
- 1, & x \geq 0 \\
1508
- 0, & x < 0 \\
1509
- \end{cases}
1521
+ g(x) = \begin{cases}
1522
+ 1, & x \geq 0 \\
1523
+ 0, & x < 0 \\
1524
+ \end{cases}
1510
1525
 
1511
- Backward function:
1526
+ Backward function:
1512
1527
 
1513
- .. math::
1528
+ .. math::
1514
1529
 
1515
- g'(x) = \alpha * \text{gaussian}(x, 0., \sigma)
1530
+ g'(x) = \alpha * \text{gaussian}(x, 0., \sigma)
1516
1531
 
1517
- .. plot::
1518
- :include-source: True
1532
+ .. plot::
1533
+ :include-source: True
1519
1534
 
1520
- >>> import brainstate.nn as nn
1521
- >>> import brainstate as bst
1522
- >>> import matplotlib.pyplot as plt
1523
- >>> xs = jax.numpy.linspace(-3, 3, 1000)
1524
- >>> for s in [0.5, 1., 2.]:
1525
- >>> grads = bst.transform.vector_grad(nn.surrogate.gaussian_grad)(xs, s, 0.5)
1526
- >>> plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
1527
- >>> plt.legend()
1528
- >>> plt.show()
1535
+ >>> import jax
1536
+ >>> import brainstate.nn as nn
1537
+ >>> import brainstate as bst
1538
+ >>> import matplotlib.pyplot as plt
1539
+ >>> xs = jax.numpy.linspace(-3, 3, 1000)
1540
+ >>> for s in [0.5, 1., 2.]:
1541
+ >>> grads = bst.augment.vector_grad(bst.surrogate.gaussian_grad)(xs, s, 0.5)
1542
+ >>> plt.plot(xs, grads, label=r'$\alpha=0.5, \sigma=$' + str(s))
1543
+ >>> plt.legend()
1544
+ >>> plt.show()
1529
1545
 
1530
- Parameters
1531
- ----------
1532
- x: jax.Array, Array
1533
- The input data.
1534
- sigma: float
1535
- The parameter to control the variance of gaussian distribution.
1536
- alpha: float
1537
- The parameter to control the scale of the gradient.
1546
+ Parameters
1547
+ ----------
1548
+ x: jax.Array, Array
1549
+ The input data.
1550
+ sigma: float
1551
+ The parameter to control the variance of gaussian distribution.
1552
+ alpha: float
1553
+ The parameter to control the scale of the gradient.
1538
1554
 
1539
- Returns
1540
- -------
1541
- out: jax.Array
1542
- The spiking state.
1555
+ Returns
1556
+ -------
1557
+ out: jax.Array
1558
+ The spiking state.
1543
1559
 
1544
- References
1545
- ----------
1546
- .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
1547
- """
1548
- return GaussianGrad(sigma=sigma, alpha=alpha)(x)
1560
+ References
1561
+ ----------
1562
+ .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
1563
+ """
1564
+ return GaussianGrad(sigma=sigma, alpha=alpha)(x)
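Since g'(x) = alpha * N(x; 0, sigma^2) per the backward formula above, it should integrate to alpha over the real line. A standalone sanity check with a Riemann sum (gaussian_grad_fn is an illustrative restatement of the documented formula):

import jax.numpy as jnp

def gaussian_grad_fn(x, sigma=0.5, alpha=0.5):
    # alpha times the N(0, sigma^2) density.
    pdf = jnp.exp(-x ** 2 / (2 * sigma ** 2)) / (jnp.sqrt(2 * jnp.pi) * sigma)
    return alpha * pdf

x = jnp.linspace(-5., 5., 10001)
print(jnp.sum(gaussian_grad_fn(x)) * (x[1] - x[0]))  # ~0.5 == alpha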
1549
1565
 
1550
1566
 
1551
1567
  class MultiGaussianGrad(Surrogate):
1552
- """Judge spiking state with the multi-Gaussian gradient function.
1568
+ """Judge spiking state with the multi-Gaussian gradient function.
1553
1569
 
1554
- See Also
1555
- --------
1556
- multi_gaussian_grad
1557
- """
1570
+ See Also
1571
+ --------
1572
+ multi_gaussian_grad
1573
+ """
1558
1574
 
1559
- def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
1560
- super().__init__()
1561
- self.h = h
1562
- self.s = s
1563
- self.sigma = sigma
1564
- self.scale = scale
1575
+ def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
1576
+ super().__init__()
1577
+ self.h = h
1578
+ self.s = s
1579
+ self.sigma = sigma
1580
+ self.scale = scale
1565
1581
 
1566
- def surrogate_grad(self, x):
1567
- g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
1568
- g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
1569
- ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
1570
- g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
1571
- ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
1572
- dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h
1573
- return self.scale * dx
1582
+ def surrogate_grad(self, x):
1583
+ g1 = jnp.exp(-x ** 2 / (2 * jnp.power(self.sigma, 2))) / (jnp.sqrt(2 * jnp.pi) * self.sigma)
1584
+ g2 = jnp.exp(-(x - self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
1585
+ ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
1586
+ g3 = jnp.exp(-(x + self.sigma) ** 2 / (2 * jnp.power(self.s * self.sigma, 2))
1587
+ ) / (jnp.sqrt(2 * jnp.pi) * self.s * self.sigma)
1588
+ dx = g1 * (1. + self.h) - g2 * self.h - g3 * self.h
1589
+ return self.scale * dx
1574
1590
 
1575
- def __repr__(self):
1576
- return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'
1591
+ def __repr__(self):
1592
+ return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})'
1577
1593
 
1578
- def __hash__(self):
1579
- return hash((self.__class__, self.h, self.s, self.sigma, self.scale))
1594
+ def __hash__(self):
1595
+ return hash((self.__class__, self.h, self.s, self.sigma, self.scale))
1580
1596
 
1581
1597
 
1582
1598
  def multi_gaussian_grad(
@@ -1586,209 +1602,212 @@ def multi_gaussian_grad(
1586
1602
  sigma: float = 0.5,
1587
1603
  scale: float = 0.5,
1588
1604
  ):
1589
- r"""Spike function with the multi-Gaussian gradient function [1]_.
1605
+ r"""Spike function with the multi-Gaussian gradient function [1]_.
1590
1606
 
1591
- The forward function:
1607
+ The forward function:
1592
1608
 
1593
- .. math::
1609
+ .. math::
1594
1610
 
1595
- g(x) = \begin{cases}
1596
- 1, & x \geq 0 \\
1597
- 0, & x < 0 \\
1598
- \end{cases}
1599
-
1600
- Backward function:
1601
-
1602
- .. math::
1603
-
1604
- \begin{array}{l}
1605
- g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2})
1606
- -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})-
1607
- h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2})
1608
- \end{array}
1609
-
1610
-
1611
- .. plot::
1612
- :include-source: True
1613
-
1614
- >>> import brainstate.nn as nn
1615
- >>> import brainstate as bst
1616
- >>> import matplotlib.pyplot as plt
1617
- >>> xs = jax.numpy.linspace(-3, 3, 1000)
1618
- >>> grads = bst.transform.vector_grad(nn.surrogate.multi_gaussian_grad)(xs)
1619
- >>> plt.plot(xs, grads)
1620
- >>> plt.show()
1621
-
1622
- Parameters
1623
- ----------
1624
- x: jax.Array, Array
1625
- The input data.
1626
- h: float
1627
- The hyper-parameters of approximate function
1628
- s: float
1629
- The hyper-parameters of approximate function
1630
- sigma: float
1631
- The gaussian sigma.
1632
- scale: float
1633
- The gradient scale.
1634
-
1635
- Returns
1636
- -------
1637
- out: jax.Array
1638
- The spiking state.
1639
-
1640
- References
1641
- ----------
1642
- .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
1643
- """
1644
- return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x)
1611
+ g(x) = \begin{cases}
1612
+ 1, & x \geq 0 \\
1613
+ 0, & x < 0 \\
1614
+ \end{cases}
1615
+
1616
+ Backward function:
1617
+
1618
+ .. math::
1619
+
1620
+ \begin{array}{l}
1621
+ g'(x)=(1+h){{{\mathcal{N}}}}(x, 0, {\sigma }^{2})
1622
+ -h{{{\mathcal{N}}}}(x, \sigma,{(s\sigma )}^{2})-
1623
+ h{{{\mathcal{N}}}}(x, -\sigma ,{(s\sigma )}^{2})
1624
+ \end{array}
1625
+
1626
+
1627
+ .. plot::
1628
+ :include-source: True
1629
+
1630
+ >>> import jax
1631
+ >>> import brainstate.nn as nn
1632
+ >>> import brainstate as bst
1633
+ >>> import matplotlib.pyplot as plt
1634
+ >>> xs = jax.numpy.linspace(-3, 3, 1000)
1635
+ >>> grads = bst.augment.vector_grad(bst.surrogate.multi_gaussian_grad)(xs)
1636
+ >>> plt.plot(xs, grads)
1637
+ >>> plt.show()
1638
+
1639
+ Parameters
1640
+ ----------
1641
+ x: jax.Array, Array
1642
+ The input data.
1643
+ h: float
1644
+ The hyper-parameters of approximate function
1645
+ s: float
1646
+ The hyper-parameters of approximate function
1647
+ sigma: float
1648
+ The gaussian sigma.
1649
+ scale: float
1650
+ The gradient scale.
1651
+
1652
+ Returns
1653
+ -------
1654
+ out: jax.Array
1655
+ The spiking state.
1656
+
1657
+ References
1658
+ ----------
1659
+ .. [1] Yin, B., Corradi, F. & Bohté, S.M. Accurate and efficient time-domain classification with adaptive spiking recurrent neural networks. Nat Mach Intell 3, 905–913 (2021).
1660
+ """
1661
+ return MultiGaussianGrad(h=h, s=s, sigma=sigma, scale=scale)(x)
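Reading off the backward formula above, the three Gaussians carry total mass (1 + h) - h - h = 1 - h, so the gradient integrates to scale * (1 - h) = 0.425 for the defaults, and the subtracted side Gaussians produce negative lobes away from the threshold. A standalone check (normal_pdf and mg_grad are illustrative names):

import jax.numpy as jnp

def normal_pdf(x, mu, sd):
    return jnp.exp(-(x - mu) ** 2 / (2 * sd ** 2)) / (jnp.sqrt(2 * jnp.pi) * sd)

def mg_grad(x, h=0.15, s=6.0, sigma=0.5, scale=0.5):
    # (1+h) N(0, sigma^2) minus two side Gaussians of weight h and width s*sigma.
    return scale * ((1. + h) * normal_pdf(x, 0., sigma)
                    - h * normal_pdf(x, sigma, s * sigma)
                    - h * normal_pdf(x, -sigma, s * sigma))

x = jnp.linspace(-20., 20., 40001)
print(jnp.sum(mg_grad(x)) * (x[1] - x[0]))  # ~0.425 == scale * (1 - h)
print(mg_grad(jnp.asarray(3.0)) < 0)        # True: negative side lobe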
1645
1662
 
1646
1663
 
1647
1664
  class InvSquareGrad(Surrogate):
1648
- """Judge spiking state with the inverse-square surrogate gradient function.
1665
+ """Judge spiking state with the inverse-square surrogate gradient function.
1649
1666
 
1650
- See Also
1651
- --------
1652
- inv_square_grad
1653
- """
1667
+ See Also
1668
+ --------
1669
+ inv_square_grad
1670
+ """
1654
1671
 
1655
- def __init__(self, alpha=100.):
1656
- super().__init__()
1657
- self.alpha = alpha
1672
+ def __init__(self, alpha=100.):
1673
+ super().__init__()
1674
+ self.alpha = alpha
1658
1675
 
1659
- def surrogate_grad(self, x):
1660
- dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2
1661
- return dx
1676
+ def surrogate_grad(self, x):
1677
+ dx = 1. / (self.alpha * jnp.abs(x) + 1.0) ** 2
1678
+ return dx
1662
1679
 
1663
- def __repr__(self):
1664
- return f'{self.__class__.__name__}(alpha={self.alpha})'
1680
+ def __repr__(self):
1681
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
1665
1682
 
1666
- def __hash__(self):
1667
- return hash((self.__class__, self.alpha))
1683
+ def __hash__(self):
1684
+ return hash((self.__class__, self.alpha))
1668
1685
 
1669
1686
 
1670
1687
  def inv_square_grad(
1671
1688
  x: jax.Array,
1672
1689
  alpha: float = 100.
1673
1690
  ):
1674
- r"""Spike function with the inverse-square surrogate gradient.
1691
+ r"""Spike function with the inverse-square surrogate gradient.
1675
1692
 
1676
- Forward function:
1693
+ Forward function:
1677
1694
 
1678
- .. math::
1695
+ .. math::
1679
1696
 
1680
- g(x) = \begin{cases}
1681
- 1, & x \geq 0 \\
1682
- 0, & x < 0 \\
1683
- \end{cases}
1697
+ g(x) = \begin{cases}
1698
+ 1, & x \geq 0 \\
1699
+ 0, & x < 0 \\
1700
+ \end{cases}
1684
1701
 
1685
- Backward function:
1702
+ Backward function:
1686
1703
 
1687
- .. math::
1704
+ .. math::
1688
1705
 
1689
- g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2}
1706
+ g'(x) = \frac{1}{(\alpha * |x| + 1.) ^ 2}
1690
1707
 
1691
1708
 
1692
- .. plot::
1693
- :include-source: True
1709
+ .. plot::
1710
+ :include-source: True
1694
1711
 
1695
- >>> import brainstate.nn as nn
1696
- >>> import brainstate as bst
1697
- >>> import matplotlib.pyplot as plt
1698
- >>> xs = jax.numpy.linspace(-1, 1, 1000)
1699
- >>> for alpha in [1., 10., 100.]:
1700
- >>> grads = bst.transform.vector_grad(nn.surrogate.inv_square_grad)(xs, alpha)
1701
- >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
1702
- >>> plt.legend()
1703
- >>> plt.show()
1712
+ >>> import jax
1713
+ >>> import brainstate.nn as nn
1714
+ >>> import brainstate as bst
1715
+ >>> import matplotlib.pyplot as plt
1716
+ >>> xs = jax.numpy.linspace(-1, 1, 1000)
1717
+ >>> for alpha in [1., 10., 100.]:
1718
+ >>> grads = bst.augment.vector_grad(bst.surrogate.inv_square_grad)(xs, alpha)
1719
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
1720
+ >>> plt.legend()
1721
+ >>> plt.show()
1704
1722
 
1705
- Parameters
1706
- ----------
1707
- x: jax.Array, Array
1708
- The input data.
1709
- alpha: float
1710
- Parameter to control smoothness of gradient
1723
+ Parameters
1724
+ ----------
1725
+ x: jax.Array, Array
1726
+ The input data.
1727
+ alpha: float
1728
+ Parameter to control smoothness of gradient
1711
1729
 
1712
- Returns
1713
- -------
1714
- out: jax.Array
1715
- The spiking state.
1716
- """
1717
- return InvSquareGrad(alpha=alpha)(x)
1730
+ Returns
1731
+ -------
1732
+ out: jax.Array
1733
+ The spiking state.
1734
+ """
1735
+ return InvSquareGrad(alpha=alpha)(x)
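Here alpha sets how sharply the gradient concentrates at the threshold; the peak is always 1 at x = 0. A standalone restatement (inv_square is an illustrative name):

import jax.numpy as jnp

def inv_square(x, alpha=100.):
    # Peak 1 at x = 0, decaying as an inverse square in alpha*|x|.
    return 1. / (alpha * jnp.abs(x) + 1.) ** 2

for alpha in (1., 10., 100.):
    print(alpha, inv_square(jnp.asarray(0.1), alpha))
# 1.0 -> ~0.826, 10.0 -> 0.25, 100.0 -> ~0.0083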
1718
1736
 
1719
1737
 
1720
1738
  class SlayerGrad(Surrogate):
1721
- """Judge spiking state with the slayer surrogate gradient function.
1739
+ """Judge spiking state with the slayer surrogate gradient function.
1722
1740
 
1723
- See Also
1724
- --------
1725
- slayer_grad
1726
- """
1741
+ See Also
1742
+ --------
1743
+ slayer_grad
1744
+ """
1727
1745
 
1728
- def __init__(self, alpha=1.):
1729
- super().__init__()
1730
- self.alpha = alpha
1746
+ def __init__(self, alpha=1.):
1747
+ super().__init__()
1748
+ self.alpha = alpha
1731
1749
 
1732
- def surrogate_grad(self, x):
1733
- dx = jnp.exp(-self.alpha * jnp.abs(x))
1734
- return dx
1750
+ def surrogate_grad(self, x):
1751
+ dx = jnp.exp(-self.alpha * jnp.abs(x))
1752
+ return dx
1735
1753
 
1736
- def __repr__(self):
1737
- return f'{self.__class__.__name__}(alpha={self.alpha})'
1754
+ def __repr__(self):
1755
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
1738
1756
 
1739
- def __hash__(self):
1740
- return hash((self.__class__, self.alpha))
1757
+ def __hash__(self):
1758
+ return hash((self.__class__, self.alpha))
1741
1759
 
1742
1760
 
1743
1761
  def slayer_grad(
1744
1762
  x: jax.Array,
1745
1763
  alpha: float = 1.
1746
1764
  ):
1747
- r"""Spike function with the slayer surrogate gradient function.
1765
+ r"""Spike function with the slayer surrogate gradient function.
1748
1766
 
1749
- Forward function:
1767
+ Forward function:
1750
1768
 
1751
- .. math::
1769
+ .. math::
1752
1770
 
1753
- g(x) = \begin{cases}
1754
- 1, & x \geq 0 \\
1755
- 0, & x < 0 \\
1756
- \end{cases}
1757
-
1758
- Backward function:
1759
-
1760
- .. math::
1761
-
1762
- g'(x) = \exp(-\alpha |x|)
1763
-
1764
-
1765
- .. plot::
1766
- :include-source: True
1767
-
1768
- >>> import brainstate.nn as nn
1769
- >>> import brainstate as bst
1770
- >>> import matplotlib.pyplot as plt
1771
- >>> xs = jax.numpy.linspace(-3, 3, 1000)
1772
- >>> for alpha in [0.5, 1., 2., 4.]:
1773
- >>> grads = bst.transform.vector_grad(nn.surrogate.slayer_grad)(xs, alpha)
1774
- >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
1775
- >>> plt.legend()
1776
- >>> plt.show()
1777
-
1778
- Parameters
1779
- ----------
1780
- x: jax.Array, Array
1781
- The input data.
1782
- alpha: float
1783
- Parameter to control smoothness of gradient
1784
-
1785
- Returns
1786
- -------
1787
- out: jax.Array
1788
- The spiking state.
1789
-
1790
- References
1791
- ----------
1792
- .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018).
1793
- """
1794
- return SlayerGrad(alpha=alpha)(x)
1771
+ g(x) = \begin{cases}
1772
+ 1, & x \geq 0 \\
1773
+ 0, & x < 0 \\
1774
+ \end{cases}
1775
+
1776
+ Backward function:
1777
+
1778
+ .. math::
1779
+
1780
+ g'(x) = \exp(-\alpha |x|)
1781
+
1782
+
1783
+ .. plot::
1784
+ :include-source: True
1785
+
1786
+ >>> import jax
1787
+ >>> import brainstate.nn as nn
1788
+ >>> import brainstate as bst
1789
+ >>> import matplotlib.pyplot as plt
1790
+ >>> xs = jax.numpy.linspace(-3, 3, 1000)
1791
+ >>> for alpha in [0.5, 1., 2., 4.]:
1792
+ >>> grads = bst.augment.vector_grad(bst.surrogate.slayer_grad)(xs, alpha)
1793
+ >>> plt.plot(xs, grads, label=r'$\alpha$=' + str(alpha))
1794
+ >>> plt.legend()
1795
+ >>> plt.show()
1796
+
1797
+ Parameters
1798
+ ----------
1799
+ x: jax.Array, Array
1800
+ The input data.
1801
+ alpha: float
1802
+ Parameter to control smoothness of gradient
1803
+
1804
+ Returns
1805
+ -------
1806
+ out: jax.Array
1807
+ The spiking state.
1808
+
1809
+ References
1810
+ ----------
1811
+ .. [1] Shrestha, S. B. & Orchard, G. Slayer: spike layer error reassignment in time. In Advances in Neural Information Processing Systems Vol. 31, 1412–1421 (NeurIPS, 2018).
1812
+ """
1813
+ return SlayerGrad(alpha=alpha)(x)
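The exponential backward rule halves every ln(2)/alpha away from the threshold, which gives a concrete reading of alpha as an inverse width. A standalone restatement (slayer_grad_fn is an illustrative name):

import jax.numpy as jnp

def slayer_grad_fn(x, alpha=1.):
    # exp(-alpha * |x|): peak 1 at the threshold, exponential tails.
    return jnp.exp(-alpha * jnp.abs(x))

print(slayer_grad_fn(jnp.log(2.)))            # 0.5
print(slayer_grad_fn(jnp.log(2.), alpha=2.))  # 0.25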