brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +509 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
  127. brainstate-0.1.9.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -1,1119 +1,1119 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from typing import Optional
19
-
20
- import brainunit as u
21
- import jax.numpy as jnp
22
-
23
- from brainstate import functional as F
24
- from brainstate._state import ParamState
25
- from brainstate.typing import ArrayLike
26
- from ._module import ElementWiseBlock
27
-
28
- __all__ = [
29
- # activation functions
30
- 'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid',
31
- 'Tanh', 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU',
32
- 'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU',
33
- 'Softsign', 'Tanhshrink', 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax',
34
-
35
- # others
36
- 'Identity', 'SpikeBitwise',
37
- ]
38
-
39
-
40
- class Threshold(ElementWiseBlock):
41
- r"""Thresholds each element of the input Tensor.
42
-
43
- Threshold is defined as:
44
-
45
- .. math::
46
- y =
47
- \begin{cases}
48
- x, &\text{ if } x > \text{threshold} \\
49
- \text{value}, &\text{ otherwise }
50
- \end{cases}
51
-
52
- Args:
53
- threshold: The value to threshold at
54
- value: The value to replace with
55
-
56
- Shape:
57
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
58
- - Output: :math:`(*)`, same shape as the input.
59
-
60
- Examples::
61
-
62
- >>> import brainstate.nn as nn
63
- >>> import brainstate
64
- >>> m = nn.Threshold(0.1, 20)
65
- >>> x = brainstate.random.randn(2)
66
- >>> output = m(x)
67
- """
68
- __module__ = 'brainstate.nn'
69
- threshold: float
70
- value: float
71
-
72
- def __init__(self, threshold: float, value: float) -> None:
73
- super().__init__()
74
- self.threshold = threshold
75
- self.value = value
76
-
77
- def __call__(self, x: ArrayLike) -> ArrayLike:
78
- dtype = u.math.get_dtype(x)
79
- return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
80
- x,
81
- jnp.asarray(self.value, dtype=dtype))
82
-
83
- def __repr__(self):
84
- return f'{self.__class__.__name__}(threshold={self.threshold}, value={self.value})'
85
-
86
-
87
- class ReLU(ElementWiseBlock):
88
- r"""Applies the rectified linear unit function element-wise:
89
-
90
- :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
91
-
92
- Shape:
93
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
94
- - Output: :math:`(*)`, same shape as the input.
95
-
96
- Examples::
97
-
98
- >>> import brainstate.nn as nn
99
- >>> import brainstate as brainstate
100
- >>> m = nn.ReLU()
101
- >>> x = brainstate.random.randn(2)
102
- >>> output = m(x)
103
-
104
-
105
- An implementation of CReLU - https://arxiv.org/abs/1603.05201
106
-
107
- >>> import brainstate.nn as nn
108
- >>> import jax, brainstate
109
- >>> m = nn.ReLU()
110
- >>> x = brainstate.random.randn(2)[None, :]
111
- >>> output = jax.numpy.concatenate((m(x), m(-x)))
112
- """
113
- __module__ = 'brainstate.nn'
114
-
115
- def __call__(self, x: ArrayLike) -> ArrayLike:
116
- return F.relu(x)
117
-
118
- def __repr__(self):
119
- return f'{self.__class__.__name__}()'
120
-
121
-
122
- class RReLU(ElementWiseBlock):
123
- r"""Applies the randomized leaky rectified linear unit function, element-wise,
124
- as described in the paper:
125
-
126
- `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
127
-
128
- The function is defined as:
129
-
130
- .. math::
131
- \text{RReLU}(x) =
132
- \begin{cases}
133
- x & \text{if } x \geq 0 \\
134
- ax & \text{ otherwise }
135
- \end{cases}
136
-
137
- where :math:`a` is randomly sampled from uniform distribution
138
- :math:`\mathcal{U}(\text{lower}, \text{upper})`.
139
-
140
- See: https://arxiv.org/pdf/1505.00853.pdf
141
-
142
- Args:
143
- lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
144
- upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
145
-
146
- Shape:
147
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
148
- - Output: :math:`(*)`, same shape as the input.
149
-
150
- Examples::
151
-
152
- >>> import brainstate.nn as nn
153
- >>> import brainstate as brainstate
154
- >>> m = nn.RReLU(0.1, 0.3)
155
- >>> x = brainstate.random.randn(2)
156
- >>> output = m(x)
157
-
158
- .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
159
- https://arxiv.org/abs/1505.00853
160
- """
161
- __module__ = 'brainstate.nn'
162
- lower: float
163
- upper: float
164
-
165
- def __init__(
166
- self,
167
- lower: float = 1. / 8,
168
- upper: float = 1. / 3,
169
- ):
170
- super().__init__()
171
- self.lower = lower
172
- self.upper = upper
173
-
174
- def __call__(self, x: ArrayLike) -> ArrayLike:
175
- return F.rrelu(x, self.lower, self.upper)
176
-
177
- def extra_repr(self):
178
- return f'{self.__class__.__name__}(lower={self.lower}, upper={self.upper})'
179
-
180
-
181
- class Hardtanh(ElementWiseBlock):
182
- r"""Applies the HardTanh function element-wise.
183
-
184
- HardTanh is defined as:
185
-
186
- .. math::
187
- \text{HardTanh}(x) = \begin{cases}
188
- \text{max\_val} & \text{ if } x > \text{ max\_val } \\
189
- \text{min\_val} & \text{ if } x < \text{ min\_val } \\
190
- x & \text{ otherwise } \\
191
- \end{cases}
192
-
193
- Args:
194
- min_val: minimum value of the linear region range. Default: -1
195
- max_val: maximum value of the linear region range. Default: 1
196
-
197
- Keyword arguments :attr:`min_value` and :attr:`max_value`
198
- have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
199
-
200
- Shape:
201
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
202
- - Output: :math:`(*)`, same shape as the input.
203
-
204
- Examples::
205
-
206
- >>> import brainstate.nn as nn
207
- >>> import brainstate as brainstate
208
- >>> m = nn.Hardtanh(-2, 2)
209
- >>> x = brainstate.random.randn(2)
210
- >>> output = m(x)
211
- """
212
- __module__ = 'brainstate.nn'
213
- min_val: float
214
- max_val: float
215
-
216
- def __init__(
217
- self,
218
- min_val: float = -1.,
219
- max_val: float = 1.,
220
- ) -> None:
221
- super().__init__()
222
- self.min_val = min_val
223
- self.max_val = max_val
224
- assert self.max_val > self.min_val
225
-
226
- def __call__(self, x: ArrayLike) -> ArrayLike:
227
- return F.hard_tanh(x, self.min_val, self.max_val)
228
-
229
- def extra_repr(self) -> str:
230
- return f'{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})'
231
-
232
-
233
- class ReLU6(Hardtanh, ElementWiseBlock):
234
- r"""Applies the element-wise function:
235
-
236
- .. math::
237
- \text{ReLU6}(x) = \min(\max(0,x), 6)
238
-
239
- Shape:
240
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
241
- - Output: :math:`(*)`, same shape as the input.
242
-
243
- Examples::
244
-
245
- >>> import brainstate.nn as nn
246
- >>> import brainstate as brainstate
247
- >>> m = nn.ReLU6()
248
- >>> x = brainstate.random.randn(2)
249
- >>> output = m(x)
250
- """
251
- __module__ = 'brainstate.nn'
252
-
253
- def __init__(self):
254
- super().__init__(0., 6.)
255
-
256
-
257
- class Sigmoid(ElementWiseBlock):
258
- r"""Applies the element-wise function:
259
-
260
- .. math::
261
- \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
262
-
263
-
264
- Shape:
265
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
266
- - Output: :math:`(*)`, same shape as the input.
267
-
268
- Examples::
269
-
270
- >>> import brainstate.nn as nn
271
- >>> import brainstate as brainstate
272
- >>> m = nn.Sigmoid()
273
- >>> x = brainstate.random.randn(2)
274
- >>> output = m(x)
275
- """
276
- __module__ = 'brainstate.nn'
277
-
278
- def __call__(self, x: ArrayLike) -> ArrayLike:
279
- return F.sigmoid(x)
280
-
281
-
282
- class Hardsigmoid(ElementWiseBlock):
283
- r"""Applies the Hardsigmoid function element-wise.
284
-
285
- Hardsigmoid is defined as:
286
-
287
- .. math::
288
- \text{Hardsigmoid}(x) = \begin{cases}
289
- 0 & \text{if~} x \le -3, \\
290
- 1 & \text{if~} x \ge +3, \\
291
- x / 6 + 1 / 2 & \text{otherwise}
292
- \end{cases}
293
-
294
- Shape:
295
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
296
- - Output: :math:`(*)`, same shape as the input.
297
-
298
- Examples::
299
-
300
- >>> import brainstate.nn as nn
301
- >>> import brainstate as brainstate
302
- >>> m = nn.Hardsigmoid()
303
- >>> x = brainstate.random.randn(2)
304
- >>> output = m(x)
305
- """
306
- __module__ = 'brainstate.nn'
307
-
308
- def __call__(self, x: ArrayLike) -> ArrayLike:
309
- return F.hard_sigmoid(x)
310
-
311
-
312
- class Tanh(ElementWiseBlock):
313
- r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
314
-
315
- Tanh is defined as:
316
-
317
- .. math::
318
- \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
319
-
320
- Shape:
321
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
322
- - Output: :math:`(*)`, same shape as the input.
323
-
324
- Examples::
325
-
326
- >>> import brainstate.nn as nn
327
- >>> import brainstate as brainstate
328
- >>> m = nn.Tanh()
329
- >>> x = brainstate.random.randn(2)
330
- >>> output = m(x)
331
- """
332
- __module__ = 'brainstate.nn'
333
-
334
- def __call__(self, x: ArrayLike) -> ArrayLike:
335
- return F.tanh(x)
336
-
337
-
338
- class SiLU(ElementWiseBlock):
339
- r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
340
- The SiLU function is also known as the swish function.
341
-
342
- .. math::
343
- \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
344
-
345
- .. note::
346
- See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
347
- where the SiLU (Sigmoid Linear Unit) was originally coined, and see
348
- `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
349
- in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
350
- a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
351
- where the SiLU was experimented with later.
352
-
353
- Shape:
354
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
355
- - Output: :math:`(*)`, same shape as the input.
356
-
357
- Examples::
358
-
359
- >>> import brainstate.nn as nn
360
- >>> m = nn.SiLU()
361
- >>> import brainstate; x = brainstate.random.randn(2)
362
- >>> output = m(x)
363
- """
364
- __module__ = 'brainstate.nn'
365
-
366
- def __call__(self, x: ArrayLike) -> ArrayLike:
367
- return F.silu(x)
368
-
369
-
370
- class Mish(ElementWiseBlock):
371
- r"""Applies the Mish function, element-wise.
372
- Mish: A Self Regularized Non-Monotonic Neural Activation Function.
373
-
374
- .. math::
375
- \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
376
-
377
- .. note::
378
- See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
379
-
380
-
381
- Shape:
382
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
383
- - Output: :math:`(*)`, same shape as the input.
384
-
385
- Examples::
386
-
387
- >>> import brainstate.nn as nn
388
- >>> import brainstate as brainstate
389
- >>> m = nn.Mish()
390
- >>> x = brainstate.random.randn(2)
391
- >>> output = m(x)
392
- """
393
- __module__ = 'brainstate.nn'
394
-
395
- def __call__(self, x: ArrayLike) -> ArrayLike:
396
- return F.mish(x)
397
-
398
-
399
- class Hardswish(ElementWiseBlock):
400
- r"""Applies the Hardswish function, element-wise, as described in the paper:
401
- `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
402
-
403
- Hardswish is defined as:
404
-
405
- .. math::
406
- \text{Hardswish}(x) = \begin{cases}
407
- 0 & \text{if~} x \le -3, \\
408
- x & \text{if~} x \ge +3, \\
409
- x \cdot (x + 3) /6 & \text{otherwise}
410
- \end{cases}
411
-
412
-
413
- Shape:
414
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
415
- - Output: :math:`(*)`, same shape as the input.
416
-
417
- Examples::
418
-
419
- >>> import brainstate.nn as nn
420
- >>> import brainstate as brainstate
421
- >>> m = nn.Hardswish()
422
- >>> x = brainstate.random.randn(2)
423
- >>> output = m(x)
424
- """
425
- __module__ = 'brainstate.nn'
426
-
427
- def __call__(self, x: ArrayLike) -> ArrayLike:
428
- return F.hard_swish(x)
429
-
430
-
431
- class ELU(ElementWiseBlock):
432
- r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
433
- in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
434
- Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
435
-
436
- ELU is defined as:
437
-
438
- .. math::
439
- \text{ELU}(x) = \begin{cases}
440
- x, & \text{ if } x > 0\\
441
- \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
442
- \end{cases}
443
-
444
- Args:
445
- alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
446
-
447
- Shape:
448
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
449
- - Output: :math:`(*)`, same shape as the input.
450
-
451
- Examples::
452
-
453
- >>> import brainstate.nn as nn
454
- >>> import brainstate as brainstate
455
- >>> m = nn.ELU()
456
- >>> x = brainstate.random.randn(2)
457
- >>> output = m(x)
458
- """
459
- __module__ = 'brainstate.nn'
460
- alpha: float
461
-
462
- def __init__(self, alpha: float = 1.) -> None:
463
- super().__init__()
464
- self.alpha = alpha
465
-
466
- def __call__(self, x: ArrayLike) -> ArrayLike:
467
- return F.elu(x, self.alpha)
468
-
469
- def extra_repr(self) -> str:
470
- return f'{self.__class__.__name__}(alpha={self.alpha})'
471
-
472
-
473
- class CELU(ElementWiseBlock):
474
- r"""Applies the element-wise function:
475
-
476
- .. math::
477
- \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
478
-
479
- More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
480
-
481
- Args:
482
- alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
483
-
484
- Shape:
485
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
486
- - Output: :math:`(*)`, same shape as the input.
487
-
488
- Examples::
489
-
490
- >>> import brainstate.nn as nn
491
- >>> import brainstate as brainstate
492
- >>> m = nn.CELU()
493
- >>> x = brainstate.random.randn(2)
494
- >>> output = m(x)
495
-
496
- .. _`Continuously Differentiable Exponential Linear Units`:
497
- https://arxiv.org/abs/1704.07483
498
- """
499
- __module__ = 'brainstate.nn'
500
- alpha: float
501
-
502
- def __init__(self, alpha: float = 1.) -> None:
503
- super().__init__()
504
- self.alpha = alpha
505
-
506
- def __call__(self, x: ArrayLike) -> ArrayLike:
507
- return F.celu(x, self.alpha)
508
-
509
- def extra_repr(self) -> str:
510
- return f'{self.__class__.__name__}(alpha={self.alpha})'
511
-
512
-
513
- class SELU(ElementWiseBlock):
514
- r"""Applied element-wise, as:
515
-
516
- .. math::
517
- \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
518
-
519
- with :math:`\alpha = 1.6732632423543772848170429916717` and
520
- :math:`\text{scale} = 1.0507009873554804934193349852946`.
521
-
522
- More details can be found in the paper `Self-Normalizing Neural Networks`_ .
523
-
524
-
525
- Shape:
526
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
527
- - Output: :math:`(*)`, same shape as the input.
528
-
529
- Examples::
530
-
531
- >>> import brainstate.nn as nn
532
- >>> import brainstate as brainstate
533
- >>> m = nn.SELU()
534
- >>> x = brainstate.random.randn(2)
535
- >>> output = m(x)
536
-
537
- .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
538
- """
539
- __module__ = 'brainstate.nn'
540
-
541
- def __call__(self, x: ArrayLike) -> ArrayLike:
542
- return F.selu(x)
543
-
544
-
545
- class GLU(ElementWiseBlock):
546
- r"""Applies the gated linear unit function
547
- :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
548
- of the input matrices and :math:`b` is the second half.
549
-
550
- Args:
551
- dim (int): the dimension on which to split the input. Default: -1
552
-
553
- Shape:
554
- - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
555
- dimensions
556
- - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
557
-
558
- Examples::
559
-
560
- >>> import brainstate.nn as nn
561
- >>> import brainstate as brainstate
562
- >>> m = nn.GLU()
563
- >>> x = brainstate.random.randn(4, 2)
564
- >>> output = m(x)
565
- """
566
- __module__ = 'brainstate.nn'
567
- dim: int
568
-
569
- def __init__(self, dim: int = -1) -> None:
570
- super().__init__()
571
- self.dim = dim
572
-
573
- def __call__(self, x: ArrayLike) -> ArrayLike:
574
- return F.glu(x, self.dim)
575
-
576
- def __repr__(self):
577
- return f'{self.__class__.__name__}(dim={self.dim})'
578
-
579
-
580
- class GELU(ElementWiseBlock):
581
- r"""Applies the Gaussian Error Linear Units function:
582
-
583
- .. math:: \text{GELU}(x) = x * \Phi(x)
584
-
585
- where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
586
-
587
- When the approximate argument is 'tanh', Gelu is estimated with:
588
-
589
- .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3)))
590
-
591
- Args:
592
- approximate (bool, optional): whether to use the tanh approximation of
593
- the GELU. Default: ``False``
594
-
595
- Shape:
596
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
597
- - Output: :math:`(*)`, same shape as the input.
598
-
599
- Examples::
600
-
601
- >>> import brainstate.nn as nn
602
- >>> import brainstate as brainstate
603
- >>> m = nn.GELU()
604
- >>> x = brainstate.random.randn(2)
605
- >>> output = m(x)
606
- """
607
- __module__ = 'brainstate.nn'
608
- approximate: bool
609
-
610
- def __init__(self, approximate: bool = False) -> None:
611
- super().__init__()
612
- self.approximate = approximate
613
-
614
- def __call__(self, x: ArrayLike) -> ArrayLike:
615
- return F.gelu(x, approximate=self.approximate)
616
-
617
- def __repr__(self):
618
- return f'{self.__class__.__name__}(approximate={self.approximate})'
619
-
620
-
621
- class Hardshrink(ElementWiseBlock):
622
- r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
623
-
624
- Hardshrink is defined as:
625
-
626
- .. math::
627
- \text{HardShrink}(x) =
628
- \begin{cases}
629
- x, & \text{ if } x > \lambda \\
630
- x, & \text{ if } x < -\lambda \\
631
- 0, & \text{ otherwise }
632
- \end{cases}
633
-
634
- Args:
635
- lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
636
-
637
- Shape:
638
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
639
- - Output: :math:`(*)`, same shape as the input.
640
-
641
- Examples::
642
-
643
- >>> import brainstate.nn as nn
644
- >>> import brainstate as brainstate
645
- >>> m = nn.Hardshrink()
646
- >>> x = brainstate.random.randn(2)
647
- >>> output = m(x)
648
- """
649
- __module__ = 'brainstate.nn'
650
- lambd: float
651
-
652
- def __init__(self, lambd: float = 0.5) -> None:
653
- super().__init__()
654
- self.lambd = lambd
655
-
656
- def __call__(self, x: ArrayLike) -> ArrayLike:
657
- return F.hard_shrink(x, self.lambd)
658
-
659
- def __repr__(self):
660
- return f'{self.__class__.__name__}(lambd={self.lambd})'
661
-
662
-
663
- class LeakyReLU(ElementWiseBlock):
664
- r"""Applies the element-wise function:
665
-
666
- .. math::
667
- \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
668
-
669
-
670
- or
671
-
672
- .. math::
673
- \text{LeakyReLU}(x) =
674
- \begin{cases}
675
- x, & \text{ if } x \geq 0 \\
676
- \text{negative\_slope} \times x, & \text{ otherwise }
677
- \end{cases}
678
-
679
- Args:
680
- negative_slope: Controls the angle of the negative slope (which is used for
681
- negative input values). Default: 1e-2
682
-
683
- Shape:
684
- - Input: :math:`(*)` where `*` means, any number of additional
685
- dimensions
686
- - Output: :math:`(*)`, same shape as the input
687
-
688
- Examples::
689
-
690
- >>> import brainstate.nn as nn
691
- >>> import brainstate as brainstate
692
- >>> m = nn.LeakyReLU(0.1)
693
- >>> x = brainstate.random.randn(2)
694
- >>> output = m(x)
695
- """
696
- __module__ = 'brainstate.nn'
697
- negative_slope: float
698
-
699
- def __init__(self, negative_slope: float = 1e-2) -> None:
700
- super().__init__()
701
- self.negative_slope = negative_slope
702
-
703
- def __call__(self, x: ArrayLike) -> ArrayLike:
704
- return F.leaky_relu(x, self.negative_slope)
705
-
706
- def __repr__(self):
707
- return f'{self.__class__.__name__}(negative_slope={self.negative_slope})'
708
-
709
-
710
- class LogSigmoid(ElementWiseBlock):
711
- r"""Applies the element-wise function:
712
-
713
- .. math::
714
- \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
715
-
716
- Shape:
717
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
718
- - Output: :math:`(*)`, same shape as the input.
719
-
720
- Examples::
721
-
722
- >>> import brainstate.nn as nn
723
- >>> import brainstate as brainstate
724
- >>> m = nn.LogSigmoid()
725
- >>> x = brainstate.random.randn(2)
726
- >>> output = m(x)
727
- """
728
- __module__ = 'brainstate.nn'
729
-
730
- def __call__(self, x: ArrayLike) -> ArrayLike:
731
- return F.log_sigmoid(x)
732
-
733
-
734
- class Softplus(ElementWiseBlock):
735
- r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
736
- \log(1 + \exp(\beta * x))` element-wise.
737
-
738
- SoftPlus is a smooth approximation to the ReLU function and can be used
739
- to constrain the output of a machine to always be positive.
740
-
741
- For numerical stability the implementation reverts to the linear function
742
- when :math:`input \times \beta > threshold`.
743
-
744
- Shape:
745
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
746
- - Output: :math:`(*)`, same shape as the input.
747
-
748
- Examples::
749
-
750
- >>> import brainstate.nn as nn
751
- >>> import brainstate as brainstate
752
- >>> m = nn.Softplus()
753
- >>> x = brainstate.random.randn(2)
754
- >>> output = m(x)
755
- """
756
- __module__ = 'brainstate.nn'
757
-
758
- def __call__(self, x: ArrayLike) -> ArrayLike:
759
- return F.softplus(x)
760
-
761
-
762
- class Softshrink(ElementWiseBlock):
763
- r"""Applies the soft shrinkage function elementwise:
764
-
765
- .. math::
766
- \text{SoftShrinkage}(x) =
767
- \begin{cases}
768
- x - \lambda, & \text{ if } x > \lambda \\
769
- x + \lambda, & \text{ if } x < -\lambda \\
770
- 0, & \text{ otherwise }
771
- \end{cases}
772
-
773
- Args:
774
- lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
775
-
776
- Shape:
777
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
778
- - Output: :math:`(*)`, same shape as the input.
779
-
780
- Examples::
781
-
782
- >>> import brainstate.nn as nn
783
- >>> import brainstate as brainstate
784
- >>> m = nn.Softshrink()
785
- >>> x = brainstate.random.randn(2)
786
- >>> output = m(x)
787
- """
788
- __module__ = 'brainstate.nn'
789
- lambd: float
790
-
791
- def __init__(self, lambd: float = 0.5) -> None:
792
- super().__init__()
793
- self.lambd = lambd
794
-
795
- def __call__(self, x: ArrayLike) -> ArrayLike:
796
- return F.soft_shrink(x, self.lambd)
797
-
798
- def __repr__(self):
799
- return f'{self.__class__.__name__}(lambd={self.lambd})'
800
-
801
-
802
- class PReLU(ElementWiseBlock):
803
- r"""Applies the element-wise function:
804
-
805
- .. math::
806
- \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
807
-
808
- or
809
-
810
- .. math::
811
- \text{PReLU}(x) =
812
- \begin{cases}
813
- x, & \text{ if } x \geq 0 \\
814
- ax, & \text{ otherwise }
815
- \end{cases}
816
-
817
- Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
818
- parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
819
- a separate :math:`a` is used for each input channel.
820
-
821
-
822
- .. note::
823
- weight decay should not be used when learning :math:`a` for good performance.
824
-
825
- .. note::
826
- Channel dim is the 2nd dim of input. When input has dims < 2, then there is
827
- no channel dim and the number of channels = 1.
828
-
829
- Args:
830
- num_parameters (int): number of :math:`a` to learn.
831
- Although it takes an int as input, only two values are legitimate:
832
- 1, or the number of channels of the input. Default: 1
833
- init (float): the initial value of :math:`a`. Default: 0.25
834
-
835
- Shape:
836
- - Input: :math:`( *)` where `*` means, any number of additional
837
- dimensions.
838
- - Output: :math:`(*)`, same shape as the input.
839
-
840
- Attributes:
841
- weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
842
-
843
- Examples::
844
-
845
- >>> import brainstate as brainstate
846
- >>> m = brainstate.nn.PReLU()
847
- >>> x = brainstate.random.randn(2)
848
- >>> output = m(x)
849
- """
850
- __module__ = 'brainstate.nn'
851
- num_parameters: int
852
-
853
- def __init__(self, num_parameters: int = 1, init: float = 0.25, dtype=None) -> None:
854
- super().__init__()
855
- self.num_parameters = num_parameters
856
- self.weight = ParamState(jnp.ones(num_parameters, dtype=dtype) * init)
857
-
858
- def __call__(self, x: ArrayLike) -> ArrayLike:
859
- return F.prelu(x, self.weight.value)
860
-
861
- def __repr__(self):
862
- return f'{self.__class__.__name__}(num_parameters={self.num_parameters})'
863
-
864
-
865
- class Softsign(ElementWiseBlock):
866
- r"""Applies the element-wise function:
867
-
868
- .. math::
869
- \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
870
-
871
- Shape:
872
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
873
- - Output: :math:`(*)`, same shape as the input.
874
-
875
- Examples::
876
-
877
- >>> import brainstate.nn as nn
878
- >>> import brainstate as brainstate
879
- >>> m = nn.Softsign()
880
- >>> x = brainstate.random.randn(2)
881
- >>> output = m(x)
882
- """
883
- __module__ = 'brainstate.nn'
884
-
885
- def __call__(self, x: ArrayLike) -> ArrayLike:
886
- return F.soft_sign(x)
887
-
888
-
889
- class Tanhshrink(ElementWiseBlock):
890
- r"""Applies the element-wise function:
891
-
892
- .. math::
893
- \text{Tanhshrink}(x) = x - \tanh(x)
894
-
895
- Shape:
896
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
897
- - Output: :math:`(*)`, same shape as the input.
898
-
899
- Examples::
900
-
901
- >>> import brainstate.nn as nn
902
- >>> import brainstate as brainstate
903
- >>> m = nn.Tanhshrink()
904
- >>> x = brainstate.random.randn(2)
905
- >>> output = m(x)
906
- """
907
- __module__ = 'brainstate.nn'
908
-
909
- def __call__(self, x: ArrayLike) -> ArrayLike:
910
- return F.tanh_shrink(x)
911
-
912
-
913
- class Softmin(ElementWiseBlock):
914
- r"""Applies the Softmin function to an n-dimensional input Tensor
915
- rescaling it so that the elements of the n-dimensional output Tensor
916
- lie in the range `[0, 1]` and sum to 1.
917
-
918
- Softmin is defined as:
919
-
920
- .. math::
921
- \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
922
-
923
- Shape:
924
- - Input: :math:`(*)` where `*` means, any number of additional
925
- dimensions
926
- - Output: :math:`(*)`, same shape as the input
927
-
928
- Args:
929
- dim (int): A dimension along which Softmin will be computed (so every slice
930
- along dim will sum to 1).
931
-
932
- Returns:
933
- a Tensor of the same dimension and shape as the input, with
934
- values in the range [0, 1]
935
-
936
- Examples::
937
-
938
- >>> import brainstate.nn as nn
939
- >>> import brainstate as brainstate
940
- >>> m = nn.Softmin(dim=1)
941
- >>> x = brainstate.random.randn(2, 3)
942
- >>> output = m(x)
943
- """
944
- __module__ = 'brainstate.nn'
945
- dim: Optional[int]
946
-
947
- def __init__(self, dim: Optional[int] = None) -> None:
948
- super().__init__()
949
- self.dim = dim
950
-
951
- def __call__(self, x: ArrayLike) -> ArrayLike:
952
- return F.softmin(x, self.dim)
953
-
954
- def __repr__(self):
955
- return f'{self.__class__.__name__}(dim={self.dim})'
956
-
957
-
958
- class Softmax(ElementWiseBlock):
959
- r"""Applies the Softmax function to an n-dimensional input Tensor
960
- rescaling it so that the elements of the n-dimensional output Tensor
961
- lie in the range [0,1] and sum to 1.
962
-
963
- Softmax is defined as:
964
-
965
- .. math::
966
- \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
967
-
968
- When the input Tensor is a sparse tensor then the unspecified
969
- values are treated as ``-inf``.
970
-
971
- Shape:
972
- - Input: :math:`(*)` where `*` means, any number of additional
973
- dimensions
974
- - Output: :math:`(*)`, same shape as the input
975
-
976
- Returns:
977
- a Tensor of the same dimension and shape as the input with
978
- values in the range [0, 1]
979
-
980
- Args:
981
- dim (int): A dimension along which Softmax will be computed (so every slice
982
- along dim will sum to 1).
983
-
984
- .. note::
985
- This module doesn't work directly with NLLLoss,
986
- which expects the Log to be computed between the Softmax and itself.
987
- Use `LogSoftmax` instead (it's faster and has better numerical properties).
988
-
989
- Examples::
990
-
991
- >>> import brainstate.nn as nn
992
- >>> import brainstate as brainstate
993
- >>> m = nn.Softmax(dim=1)
994
- >>> x = brainstate.random.randn(2, 3)
995
- >>> output = m(x)
996
-
997
- """
998
- __module__ = 'brainstate.nn'
999
- dim: Optional[int]
1000
-
1001
- def __init__(self, dim: Optional[int] = None) -> None:
1002
- super().__init__()
1003
- self.dim = dim
1004
-
1005
- def __call__(self, x: ArrayLike) -> ArrayLike:
1006
- return F.softmax(x, self.dim)
1007
-
1008
- def __repr__(self):
1009
- return f'{self.__class__.__name__}(dim={self.dim})'
1010
-
1011
-
1012
- class Softmax2d(ElementWiseBlock):
1013
- r"""Applies SoftMax over features to each spatial location.
1014
-
1015
- When given an image of ``Channels x Height x Width``, it will
1016
- apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
1017
-
1018
- Shape:
1019
- - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
1020
- - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
1021
-
1022
- Returns:
1023
- a Tensor of the same dimension and shape as the input with
1024
- values in the range [0, 1]
1025
-
1026
- Examples::
1027
-
1028
- >>> import brainstate.nn as nn
1029
- >>> import brainstate as brainstate
1030
- >>> m = nn.Softmax2d()
1031
- >>> # you softmax over the 2nd dimension
1032
- >>> x = brainstate.random.randn(2, 3, 12, 13)
1033
- >>> output = m(x)
1034
- """
1035
- __module__ = 'brainstate.nn'
1036
-
1037
- def __call__(self, x: ArrayLike) -> ArrayLike:
1038
- assert x.ndim == 4 or x.ndim == 3, 'Softmax2d requires a 3D or 4D tensor as input'
1039
- return F.softmax(x, -3)
1040
-
1041
-
1042
- class LogSoftmax(ElementWiseBlock):
1043
- r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
1044
- input Tensor. The LogSoftmax formulation can be simplified as:
1045
-
1046
- .. math::
1047
- \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
1048
-
1049
- Shape:
1050
- - Input: :math:`(*)` where `*` means, any number of additional
1051
- dimensions
1052
- - Output: :math:`(*)`, same shape as the input
1053
-
1054
- Args:
1055
- dim (int): A dimension along which LogSoftmax will be computed.
1056
-
1057
- Returns:
1058
- a Tensor of the same dimension and shape as the input with
1059
- values in the range [-inf, 0)
1060
-
1061
- Examples::
1062
-
1063
- >>> import brainstate.nn as nn
1064
- >>> import brainstate as brainstate
1065
- >>> m = nn.LogSoftmax(dim=1)
1066
- >>> x = brainstate.random.randn(2, 3)
1067
- >>> output = m(x)
1068
- """
1069
- __module__ = 'brainstate.nn'
1070
- dim: Optional[int]
1071
-
1072
- def __init__(self, dim: Optional[int] = None) -> None:
1073
- super().__init__()
1074
- self.dim = dim
1075
-
1076
- def __call__(self, x: ArrayLike) -> ArrayLike:
1077
- return F.log_softmax(x, self.dim)
1078
-
1079
- def __repr__(self):
1080
- return f'{self.__class__.__name__}(dim={self.dim})'
1081
-
1082
-
1083
- class Identity(ElementWiseBlock):
1084
- r"""A placeholder identity operator that is argument-insensitive.
1085
- """
1086
- __module__ = 'brainstate.nn'
1087
-
1088
- def __call__(self, x):
1089
- return x
1090
-
1091
-
1092
- class SpikeBitwise(ElementWiseBlock):
1093
- r"""Bitwise addition for the spiking inputs.
1094
-
1095
- .. math::
1096
-
1097
- \begin{array}{ccc}
1098
- \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
1099
- \hline \text { ADD } & x+y & x+y \\
1100
- \text { AND } & x \cap y & x \cdot y \\
1101
- \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
1102
- \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
1103
- \hline
1104
- \end{array}
1105
-
1106
- Args:
1107
- op: str. The bitwise operation.
1108
- name: str. The name of the dynamic system.
1109
- """
1110
- __module__ = 'brainstate.nn'
1111
-
1112
- def __init__(self,
1113
- op: str = 'add',
1114
- name: Optional[str] = None) -> None:
1115
- super().__init__(name=name)
1116
- self.op = op
1117
-
1118
- def __call__(self, x, y):
1119
- return F.spike_bitwise(x, y, self.op)
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Optional
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate._state import ParamState
24
+ from brainstate.typing import ArrayLike
25
+ from . import _activations as F
26
+ from ._module import ElementWiseBlock
27
+
28
+ __all__ = [
29
+ # activation functions
30
+ 'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid',
31
+ 'Tanh', 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU',
32
+ 'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU',
33
+ 'Softsign', 'Tanhshrink', 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax',
34
+
35
+ # others
36
+ 'Identity', 'SpikeBitwise',
37
+ ]
38
+
39
+
40
+ class Threshold(ElementWiseBlock):
41
+ r"""Thresholds each element of the input Tensor.
42
+
43
+ Threshold is defined as:
44
+
45
+ .. math::
46
+ y =
47
+ \begin{cases}
48
+ x, &\text{ if } x > \text{threshold} \\
49
+ \text{value}, &\text{ otherwise }
50
+ \end{cases}
51
+
52
+ Args:
53
+ threshold: The value to threshold at
54
+ value: The value to replace with
55
+
56
+ Shape:
57
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
58
+ - Output: :math:`(*)`, same shape as the input.
59
+
60
+ Examples::
61
+
62
+ >>> import brainstate.nn as nn
63
+ >>> import brainstate
64
+ >>> m = nn.Threshold(0.1, 20)
65
+ >>> x = brainstate.random.randn(2)
66
+ >>> output = m(x)
67
+ """
68
+ __module__ = 'brainstate.nn'
69
+ threshold: float
70
+ value: float
71
+
72
+ def __init__(self, threshold: float, value: float) -> None:
73
+ super().__init__()
74
+ self.threshold = threshold
75
+ self.value = value
76
+
77
+ def __call__(self, x: ArrayLike) -> ArrayLike:
78
+ dtype = u.math.get_dtype(x)
79
+ return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
80
+ x,
81
+ jnp.asarray(self.value, dtype=dtype))
82
+
83
+ def __repr__(self):
84
+ return f'{self.__class__.__name__}(threshold={self.threshold}, value={self.value})'
85
+
86
+
87
+ class ReLU(ElementWiseBlock):
88
+ r"""Applies the rectified linear unit function element-wise:
89
+
90
+ :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
91
+
92
+ Shape:
93
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
94
+ - Output: :math:`(*)`, same shape as the input.
95
+
96
+ Examples::
97
+
98
+ >>> import brainstate.nn as nn
99
+ >>> import brainstate as brainstate
100
+ >>> m = nn.ReLU()
101
+ >>> x = brainstate.random.randn(2)
102
+ >>> output = m(x)
103
+
104
+
105
+ An implementation of CReLU - https://arxiv.org/abs/1603.05201
106
+
107
+ >>> import brainstate.nn as nn
108
+ >>> import jax, brainstate
109
+ >>> m = nn.ReLU()
110
+ >>> x = brainstate.random.randn(2)[None, :]
111
+ >>> output = jax.numpy.concatenate((m(x), m(-x)))
112
+ """
113
+ __module__ = 'brainstate.nn'
114
+
115
+ def __call__(self, x: ArrayLike) -> ArrayLike:
116
+ return F.relu(x)
117
+
118
+ def __repr__(self):
119
+ return f'{self.__class__.__name__}()'
120
+
121
+
122
+ class RReLU(ElementWiseBlock):
123
+ r"""Applies the randomized leaky rectified linear unit function, element-wise,
124
+ as described in the paper:
125
+
126
+ `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
127
+
128
+ The function is defined as:
129
+
130
+ .. math::
131
+ \text{RReLU}(x) =
132
+ \begin{cases}
133
+ x & \text{if } x \geq 0 \\
134
+ ax & \text{ otherwise }
135
+ \end{cases}
136
+
137
+ where :math:`a` is randomly sampled from uniform distribution
138
+ :math:`\mathcal{U}(\text{lower}, \text{upper})`.
139
+
140
+ See: https://arxiv.org/pdf/1505.00853.pdf
141
+
142
+ Args:
143
+ lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
144
+ upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
145
+
146
+ Shape:
147
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
148
+ - Output: :math:`(*)`, same shape as the input.
149
+
150
+ Examples::
151
+
152
+ >>> import brainstate.nn as nn
153
+ >>> import brainstate as brainstate
154
+ >>> m = nn.RReLU(0.1, 0.3)
155
+ >>> x = brainstate.random.randn(2)
156
+ >>> output = m(x)
157
+
158
+ .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
159
+ https://arxiv.org/abs/1505.00853
160
+ """
161
+ __module__ = 'brainstate.nn'
162
+ lower: float
163
+ upper: float
164
+
165
+ def __init__(
166
+ self,
167
+ lower: float = 1. / 8,
168
+ upper: float = 1. / 3,
169
+ ):
170
+ super().__init__()
171
+ self.lower = lower
172
+ self.upper = upper
173
+
174
+ def __call__(self, x: ArrayLike) -> ArrayLike:
175
+ return F.rrelu(x, self.lower, self.upper)
176
+
177
+ def extra_repr(self):
178
+ return f'{self.__class__.__name__}(lower={self.lower}, upper={self.upper})'
179
+
180
+
181
+ class Hardtanh(ElementWiseBlock):
182
+ r"""Applies the HardTanh function element-wise.
183
+
184
+ HardTanh is defined as:
185
+
186
+ .. math::
187
+ \text{HardTanh}(x) = \begin{cases}
188
+ \text{max\_val} & \text{ if } x > \text{ max\_val } \\
189
+ \text{min\_val} & \text{ if } x < \text{ min\_val } \\
190
+ x & \text{ otherwise } \\
191
+ \end{cases}
192
+
193
+ Args:
194
+ min_val: minimum value of the linear region range. Default: -1
195
+ max_val: maximum value of the linear region range. Default: 1
196
+
197
+ Keyword arguments :attr:`min_value` and :attr:`max_value`
198
+ have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
199
+
200
+ Shape:
201
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
202
+ - Output: :math:`(*)`, same shape as the input.
203
+
204
+ Examples::
205
+
206
+ >>> import brainstate.nn as nn
207
+ >>> import brainstate as brainstate
208
+ >>> m = nn.Hardtanh(-2, 2)
209
+ >>> x = brainstate.random.randn(2)
210
+ >>> output = m(x)
211
+ """
212
+ __module__ = 'brainstate.nn'
213
+ min_val: float
214
+ max_val: float
215
+
216
+ def __init__(
217
+ self,
218
+ min_val: float = -1.,
219
+ max_val: float = 1.,
220
+ ) -> None:
221
+ super().__init__()
222
+ self.min_val = min_val
223
+ self.max_val = max_val
224
+ assert self.max_val > self.min_val
225
+
226
+ def __call__(self, x: ArrayLike) -> ArrayLike:
227
+ return F.hard_tanh(x, self.min_val, self.max_val)
228
+
229
+ def extra_repr(self) -> str:
230
+ return f'{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})'
231
+
232
+
233
+ class ReLU6(Hardtanh, ElementWiseBlock):
234
+ r"""Applies the element-wise function:
235
+
236
+ .. math::
237
+ \text{ReLU6}(x) = \min(\max(0,x), 6)
238
+
239
+ Shape:
240
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
241
+ - Output: :math:`(*)`, same shape as the input.
242
+
243
+ Examples::
244
+
245
+ >>> import brainstate.nn as nn
246
+ >>> import brainstate as brainstate
247
+ >>> m = nn.ReLU6()
248
+ >>> x = brainstate.random.randn(2)
249
+ >>> output = m(x)
250
+ """
251
+ __module__ = 'brainstate.nn'
252
+
253
+ def __init__(self):
254
+ super().__init__(0., 6.)
255
+
256
+
257
+ class Sigmoid(ElementWiseBlock):
258
+ r"""Applies the element-wise function:
259
+
260
+ .. math::
261
+ \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
262
+
263
+
264
+ Shape:
265
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
266
+ - Output: :math:`(*)`, same shape as the input.
267
+
268
+ Examples::
269
+
270
+ >>> import brainstate.nn as nn
271
+ >>> import brainstate as brainstate
272
+ >>> m = nn.Sigmoid()
273
+ >>> x = brainstate.random.randn(2)
274
+ >>> output = m(x)
275
+ """
276
+ __module__ = 'brainstate.nn'
277
+
278
+ def __call__(self, x: ArrayLike) -> ArrayLike:
279
+ return F.sigmoid(x)
280
+
281
+
282
+ class Hardsigmoid(ElementWiseBlock):
283
+ r"""Applies the Hardsigmoid function element-wise.
284
+
285
+ Hardsigmoid is defined as:
286
+
287
+ .. math::
288
+ \text{Hardsigmoid}(x) = \begin{cases}
289
+ 0 & \text{if~} x \le -3, \\
290
+ 1 & \text{if~} x \ge +3, \\
291
+ x / 6 + 1 / 2 & \text{otherwise}
292
+ \end{cases}
293
+
294
+ Shape:
295
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
296
+ - Output: :math:`(*)`, same shape as the input.
297
+
298
+ Examples::
299
+
300
+ >>> import brainstate.nn as nn
301
+ >>> import brainstate as brainstate
302
+ >>> m = nn.Hardsigmoid()
303
+ >>> x = brainstate.random.randn(2)
304
+ >>> output = m(x)
305
+ """
306
+ __module__ = 'brainstate.nn'
307
+
308
+ def __call__(self, x: ArrayLike) -> ArrayLike:
309
+ return F.hard_sigmoid(x)
310
+
311
+
312
+ class Tanh(ElementWiseBlock):
313
+ r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
314
+
315
+ Tanh is defined as:
316
+
317
+ .. math::
318
+ \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
319
+
320
+ Shape:
321
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
322
+ - Output: :math:`(*)`, same shape as the input.
323
+
324
+ Examples::
325
+
326
+ >>> import brainstate.nn as nn
327
+ >>> import brainstate as brainstate
328
+ >>> m = nn.Tanh()
329
+ >>> x = brainstate.random.randn(2)
330
+ >>> output = m(x)
331
+ """
332
+ __module__ = 'brainstate.nn'
333
+
334
+ def __call__(self, x: ArrayLike) -> ArrayLike:
335
+ return F.tanh(x)
336
+
337
+
338
+ class SiLU(ElementWiseBlock):
339
+ r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
340
+ The SiLU function is also known as the swish function.
341
+
342
+ .. math::
343
+ \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
344
+
345
+ .. note::
346
+ See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
347
+ where the SiLU (Sigmoid Linear Unit) was originally coined, and see
348
+ `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
349
+ in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
350
+ a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
351
+ where the SiLU was experimented with later.
352
+
353
+ Shape:
354
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
355
+ - Output: :math:`(*)`, same shape as the input.
356
+
357
+ Examples::
358
+
359
+ >>> import brainstate.nn as nn
360
+ >>> m = nn.SiLU()
361
+ >>> import brainstate; x = brainstate.random.randn(2)
362
+ >>> output = m(x)
363
+ """
364
+ __module__ = 'brainstate.nn'
365
+
366
+ def __call__(self, x: ArrayLike) -> ArrayLike:
367
+ return F.silu(x)
368
+
369
+
370
+ class Mish(ElementWiseBlock):
371
+ r"""Applies the Mish function, element-wise.
372
+ Mish: A Self Regularized Non-Monotonic Neural Activation Function.
373
+
374
+ .. math::
375
+ \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
376
+
377
+ .. note::
378
+ See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
379
+
380
+
381
+ Shape:
382
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
383
+ - Output: :math:`(*)`, same shape as the input.
384
+
385
+ Examples::
386
+
387
+ >>> import brainstate.nn as nn
388
+ >>> import brainstate as brainstate
389
+ >>> m = nn.Mish()
390
+ >>> x = brainstate.random.randn(2)
391
+ >>> output = m(x)
392
+ """
393
+ __module__ = 'brainstate.nn'
394
+
395
+ def __call__(self, x: ArrayLike) -> ArrayLike:
396
+ return F.mish(x)
397
+
398
+
399
+ class Hardswish(ElementWiseBlock):
400
+ r"""Applies the Hardswish function, element-wise, as described in the paper:
401
+ `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
402
+
403
+ Hardswish is defined as:
404
+
405
+ .. math::
406
+ \text{Hardswish}(x) = \begin{cases}
407
+ 0 & \text{if~} x \le -3, \\
408
+ x & \text{if~} x \ge +3, \\
409
+ x \cdot (x + 3) /6 & \text{otherwise}
410
+ \end{cases}
411
+
412
+
413
+ Shape:
414
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
415
+ - Output: :math:`(*)`, same shape as the input.
416
+
417
+ Examples::
418
+
419
+ >>> import brainstate.nn as nn
420
+ >>> import brainstate as brainstate
421
+ >>> m = nn.Hardswish()
422
+ >>> x = brainstate.random.randn(2)
423
+ >>> output = m(x)
424
+ """
425
+ __module__ = 'brainstate.nn'
426
+
427
+ def __call__(self, x: ArrayLike) -> ArrayLike:
428
+ return F.hard_swish(x)
429
+
430
+
431
+ class ELU(ElementWiseBlock):
432
+ r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
433
+ in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
434
+ Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
435
+
436
+ ELU is defined as:
437
+
438
+ .. math::
439
+ \text{ELU}(x) = \begin{cases}
440
+ x, & \text{ if } x > 0\\
441
+ \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
442
+ \end{cases}
443
+
444
+ Args:
445
+ alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
446
+
447
+ Shape:
448
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
449
+ - Output: :math:`(*)`, same shape as the input.
450
+
451
+ Examples::
452
+
453
+ >>> import brainstate.nn as nn
454
+ >>> import brainstate as brainstate
455
+ >>> m = nn.ELU()
456
+ >>> x = brainstate.random.randn(2)
457
+ >>> output = m(x)
458
+ """
459
+ __module__ = 'brainstate.nn'
460
+ alpha: float
461
+
462
+ def __init__(self, alpha: float = 1.) -> None:
463
+ super().__init__()
464
+ self.alpha = alpha
465
+
466
+ def __call__(self, x: ArrayLike) -> ArrayLike:
467
+ return F.elu(x, self.alpha)
468
+
469
+ def extra_repr(self) -> str:
470
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
471
+
472
+
473
+ class CELU(ElementWiseBlock):
474
+ r"""Applies the element-wise function:
475
+
476
+ .. math::
477
+ \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
478
+
479
+ More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
480
+
481
+ Args:
482
+ alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
483
+
484
+ Shape:
485
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
486
+ - Output: :math:`(*)`, same shape as the input.
487
+
488
+ Examples::
489
+
490
+ >>> import brainstate.nn as nn
491
+ >>> import brainstate as brainstate
492
+ >>> m = nn.CELU()
493
+ >>> x = brainstate.random.randn(2)
494
+ >>> output = m(x)
495
+
496
+ .. _`Continuously Differentiable Exponential Linear Units`:
497
+ https://arxiv.org/abs/1704.07483
498
+ """
499
+ __module__ = 'brainstate.nn'
500
+ alpha: float
501
+
502
+ def __init__(self, alpha: float = 1.) -> None:
503
+ super().__init__()
504
+ self.alpha = alpha
505
+
506
+ def __call__(self, x: ArrayLike) -> ArrayLike:
507
+ return F.celu(x, self.alpha)
508
+
509
+ def extra_repr(self) -> str:
510
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
511
+
512
+
513
+ class SELU(ElementWiseBlock):
514
+ r"""Applied element-wise, as:
515
+
516
+ .. math::
517
+ \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
518
+
519
+ with :math:`\alpha = 1.6732632423543772848170429916717` and
520
+ :math:`\text{scale} = 1.0507009873554804934193349852946`.
521
+
522
+ More details can be found in the paper `Self-Normalizing Neural Networks`_ .
523
+
524
+
525
+ Shape:
526
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
527
+ - Output: :math:`(*)`, same shape as the input.
528
+
529
+ Examples::
530
+
531
+ >>> import brainstate.nn as nn
532
+ >>> import brainstate as brainstate
533
+ >>> m = nn.SELU()
534
+ >>> x = brainstate.random.randn(2)
535
+ >>> output = m(x)
536
+
537
+ .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
538
+ """
539
+ __module__ = 'brainstate.nn'
540
+
541
+ def __call__(self, x: ArrayLike) -> ArrayLike:
542
+ return F.selu(x)
543
+
544
+
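Illustrative sketch (not part of the package) that reproduces SELU from the two constants quoted in the docstring:

    import jax.numpy as jnp
    import brainstate

    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946
    x = brainstate.random.randn(16)
    ref = scale * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1.0))
    print(jnp.allclose(brainstate.nn.SELU()(x), ref, atol=1e-5))   # expected: True
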
545
+ class GLU(ElementWiseBlock):
546
+ r"""Applies the gated linear unit function
547
+    :math:`\text{GLU}(a, b) = a \otimes \sigma(b)` where :math:`a` is the first half
548
+    of the input (split along the ``dim`` dimension) and :math:`b` is the second half.
549
+
550
+ Args:
551
+ dim (int): the dimension on which to split the input. Default: -1
552
+
553
+ Shape:
554
+ - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
555
+ dimensions
556
+ - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
557
+
558
+ Examples::
559
+
560
+ >>> import brainstate.nn as nn
561
+        >>> import brainstate
562
+ >>> m = nn.GLU()
563
+        >>> x = brainstate.random.randn(4, 2)
564
+ >>> output = m(x)
565
+ """
566
+ __module__ = 'brainstate.nn'
567
+ dim: int
568
+
569
+ def __init__(self, dim: int = -1) -> None:
570
+ super().__init__()
571
+ self.dim = dim
572
+
573
+ def __call__(self, x: ArrayLike) -> ArrayLike:
574
+ return F.glu(x, self.dim)
575
+
576
+ def __repr__(self):
577
+ return f'{self.__class__.__name__}(dim={self.dim})'
578
+
579
+
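A minimal sketch of the halving behaviour described above; it assumes an even split dimension and the a * sigmoid(b) convention from the docstring:

    import jax.numpy as jnp
    import brainstate

    m = brainstate.nn.GLU(dim=-1)
    x = brainstate.random.randn(4, 6)        # the split dimension must have even size
    y = m(x)
    print(y.shape)                           # expected: (4, 3)
    a, b = jnp.split(x, 2, axis=-1)          # GLU(a, b) = a * sigmoid(b)
    ref = a / (1.0 + jnp.exp(-b))
    print(jnp.allclose(y, ref, atol=1e-5))   # expected: True
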
580
+ class GELU(ElementWiseBlock):
581
+ r"""Applies the Gaussian Error Linear Units function:
582
+
583
+ .. math:: \text{GELU}(x) = x * \Phi(x)
584
+
585
+ where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
586
+
587
+    When ``approximate=True``, GELU is estimated with:
588
+
589
+    .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
590
+
591
+ Args:
592
+        approximate (bool, optional): whether to use the tanh approximation
593
+            of GELU. Default: ``False``
594
+
595
+ Shape:
596
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
597
+ - Output: :math:`(*)`, same shape as the input.
598
+
599
+ Examples::
600
+
601
+ >>> import brainstate.nn as nn
602
+        >>> import brainstate
603
+ >>> m = nn.GELU()
604
+        >>> x = brainstate.random.randn(2)
605
+ >>> output = m(x)
606
+ """
607
+ __module__ = 'brainstate.nn'
608
+ approximate: bool
609
+
610
+ def __init__(self, approximate: bool = False) -> None:
611
+ super().__init__()
612
+ self.approximate = approximate
613
+
614
+ def __call__(self, x: ArrayLike) -> ArrayLike:
615
+ return F.gelu(x, approximate=self.approximate)
616
+
617
+ def __repr__(self):
618
+ return f'{self.__class__.__name__}(approximate={self.approximate})'
619
+
620
+
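Sketch comparing the exact and tanh-approximated forms (assuming ``approximate`` toggles the tanh estimate, as in the code above):

    import jax.numpy as jnp
    import brainstate

    x = brainstate.random.randn(1000)
    exact = brainstate.nn.GELU(approximate=False)(x)
    approx = brainstate.nn.GELU(approximate=True)(x)
    # the tanh form tracks the exact erf-based form closely
    print(jnp.max(jnp.abs(exact - approx)))   # typically below ~1e-2
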
621
+ class Hardshrink(ElementWiseBlock):
622
+ r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
623
+
624
+ Hardshrink is defined as:
625
+
626
+ .. math::
627
+ \text{HardShrink}(x) =
628
+ \begin{cases}
629
+ x, & \text{ if } x > \lambda \\
630
+ x, & \text{ if } x < -\lambda \\
631
+ 0, & \text{ otherwise }
632
+ \end{cases}
633
+
634
+ Args:
635
+ lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
636
+
637
+ Shape:
638
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
639
+ - Output: :math:`(*)`, same shape as the input.
640
+
641
+ Examples::
642
+
643
+ >>> import brainstate.nn as nn
644
+        >>> import brainstate
645
+ >>> m = nn.Hardshrink()
646
+        >>> x = brainstate.random.randn(2)
647
+ >>> output = m(x)
648
+ """
649
+ __module__ = 'brainstate.nn'
650
+ lambd: float
651
+
652
+ def __init__(self, lambd: float = 0.5) -> None:
653
+ super().__init__()
654
+ self.lambd = lambd
655
+
656
+ def __call__(self, x: ArrayLike) -> ArrayLike:
657
+ return F.hard_shrink(x, self.lambd)
658
+
659
+ def __repr__(self):
660
+ return f'{self.__class__.__name__}(lambd={self.lambd})'
661
+
662
+
663
+ class LeakyReLU(ElementWiseBlock):
664
+ r"""Applies the element-wise function:
665
+
666
+ .. math::
667
+ \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
668
+
669
+
670
+ or
671
+
672
+ .. math::
673
+ \text{LeakyReLU}(x) =
674
+ \begin{cases}
675
+ x, & \text{ if } x \geq 0 \\
676
+ \text{negative\_slope} \times x, & \text{ otherwise }
677
+ \end{cases}
678
+
679
+ Args:
680
+ negative_slope: Controls the angle of the negative slope (which is used for
681
+ negative input values). Default: 1e-2
682
+
683
+ Shape:
684
+ - Input: :math:`(*)` where `*` means, any number of additional
685
+ dimensions
686
+ - Output: :math:`(*)`, same shape as the input
687
+
688
+ Examples::
689
+
690
+ >>> import brainstate.nn as nn
691
+        >>> import brainstate
692
+ >>> m = nn.LeakyReLU(0.1)
693
+        >>> x = brainstate.random.randn(2)
694
+ >>> output = m(x)
695
+ """
696
+ __module__ = 'brainstate.nn'
697
+ negative_slope: float
698
+
699
+ def __init__(self, negative_slope: float = 1e-2) -> None:
700
+ super().__init__()
701
+ self.negative_slope = negative_slope
702
+
703
+ def __call__(self, x: ArrayLike) -> ArrayLike:
704
+ return F.leaky_relu(x, self.negative_slope)
705
+
706
+ def __repr__(self):
707
+ return f'{self.__class__.__name__}(negative_slope={self.negative_slope})'
708
+
709
+
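Quick check of the piecewise definition above (illustrative only):

    import jax.numpy as jnp
    import brainstate

    m = brainstate.nn.LeakyReLU(negative_slope=0.1)
    x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])
    ref = jnp.where(x >= 0, x, 0.1 * x)      # negative_slope * x on the negative side
    print(jnp.allclose(m(x), ref))           # expected: True
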
710
+ class LogSigmoid(ElementWiseBlock):
711
+ r"""Applies the element-wise function:
712
+
713
+ .. math::
714
+ \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
715
+
716
+ Shape:
717
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
718
+ - Output: :math:`(*)`, same shape as the input.
719
+
720
+ Examples::
721
+
722
+ >>> import brainstate.nn as nn
723
+        >>> import brainstate
724
+ >>> m = nn.LogSigmoid()
725
+        >>> x = brainstate.random.randn(2)
726
+ >>> output = m(x)
727
+ """
728
+ __module__ = 'brainstate.nn'
729
+
730
+ def __call__(self, x: ArrayLike) -> ArrayLike:
731
+ return F.log_sigmoid(x)
732
+
733
+
734
+ class Softplus(ElementWiseBlock):
735
+ r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
736
+ \log(1 + \exp(\beta * x))` element-wise.
737
+
738
+ SoftPlus is a smooth approximation to the ReLU function and can be used
739
+ to constrain the output of a machine to always be positive.
740
+
741
+ For numerical stability the implementation reverts to the linear function
742
+ when :math:`input \times \beta > threshold`.
743
+
744
+ Shape:
745
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
746
+ - Output: :math:`(*)`, same shape as the input.
747
+
748
+ Examples::
749
+
750
+ >>> import brainstate.nn as nn
751
+        >>> import brainstate
752
+ >>> m = nn.Softplus()
753
+        >>> x = brainstate.random.randn(2)
754
+ >>> output = m(x)
755
+ """
756
+ __module__ = 'brainstate.nn'
757
+
758
+ def __call__(self, x: ArrayLike) -> ArrayLike:
759
+ return F.softplus(x)
760
+
761
+
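Illustrative check of the Softplus formula, assuming the default corresponds to beta = 1 and the inputs stay below the linear-reversion threshold:

    import jax.numpy as jnp
    import brainstate

    x = jnp.linspace(-5.0, 5.0, 11)
    ref = jnp.log1p(jnp.exp(x))              # log(1 + exp(x)), i.e. beta = 1
    print(jnp.allclose(brainstate.nn.Softplus()(x), ref, atol=1e-5))   # expected: True
    # related identity: LogSigmoid(x) == -Softplus(-x)
    print(jnp.allclose(brainstate.nn.LogSigmoid()(x), -brainstate.nn.Softplus()(-x), atol=1e-5))
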
762
+ class Softshrink(ElementWiseBlock):
763
+ r"""Applies the soft shrinkage function elementwise:
764
+
765
+ .. math::
766
+ \text{SoftShrinkage}(x) =
767
+ \begin{cases}
768
+ x - \lambda, & \text{ if } x > \lambda \\
769
+ x + \lambda, & \text{ if } x < -\lambda \\
770
+ 0, & \text{ otherwise }
771
+ \end{cases}
772
+
773
+ Args:
774
+ lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
775
+
776
+ Shape:
777
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
778
+ - Output: :math:`(*)`, same shape as the input.
779
+
780
+ Examples::
781
+
782
+ >>> import brainstate.nn as nn
783
+        >>> import brainstate
784
+ >>> m = nn.Softshrink()
785
+        >>> x = brainstate.random.randn(2)
786
+ >>> output = m(x)
787
+ """
788
+ __module__ = 'brainstate.nn'
789
+ lambd: float
790
+
791
+ def __init__(self, lambd: float = 0.5) -> None:
792
+ super().__init__()
793
+ self.lambd = lambd
794
+
795
+ def __call__(self, x: ArrayLike) -> ArrayLike:
796
+ return F.soft_shrink(x, self.lambd)
797
+
798
+ def __repr__(self):
799
+ return f'{self.__class__.__name__}(lambd={self.lambd})'
800
+
801
+
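Side-by-side sketch of Hardshrink and Softshrink with the default lambda of 0.5 (illustrative, expected values shown approximately in the comments):

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([-1.0, -0.4, 0.0, 0.4, 1.0])
    hard = brainstate.nn.Hardshrink(lambd=0.5)(x)
    soft = brainstate.nn.Softshrink(lambd=0.5)(x)
    print(hard)   # values with |x| > 0.5 pass through unchanged: ~[-1., 0., 0., 0., 1.]
    print(soft)   # surviving values are also shifted toward zero:  ~[-0.5, 0., 0., 0., 0.5]
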
802
+ class PReLU(ElementWiseBlock):
803
+ r"""Applies the element-wise function:
804
+
805
+ .. math::
806
+ \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
807
+
808
+ or
809
+
810
+ .. math::
811
+ \text{PReLU}(x) =
812
+ \begin{cases}
813
+ x, & \text{ if } x \geq 0 \\
814
+ ax, & \text{ otherwise }
815
+ \end{cases}
816
+
817
+ Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
818
+ parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
819
+ a separate :math:`a` is used for each input channel.
820
+
821
+
822
+ .. note::
823
+        For good performance, weight decay should not be used when learning :math:`a`.
824
+
825
+ .. note::
826
+        The channel dimension is the 2nd dimension of the input. When the input has
827
+        fewer than 2 dimensions, there is no channel dimension and the number of channels is 1.
828
+
829
+ Args:
830
+ num_parameters (int): number of :math:`a` to learn.
831
+            Although it takes an int as input, only two values are legitimate:
832
+ 1, or the number of channels at input. Default: 1
833
+ init (float): the initial value of :math:`a`. Default: 0.25
834
+
835
+ Shape:
836
+ - Input: :math:`( *)` where `*` means, any number of additional
837
+ dimensions.
838
+ - Output: :math:`(*)`, same shape as the input.
839
+
840
+ Attributes:
841
+ weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
842
+
843
+ Examples::
844
+
845
+        >>> import brainstate
846
+ >>> m = brainstate.nn.PReLU()
847
+ >>> x = brainstate.random.randn(2)
848
+ >>> output = m(x)
849
+ """
850
+ __module__ = 'brainstate.nn'
851
+ num_parameters: int
852
+
853
+ def __init__(self, num_parameters: int = 1, init: float = 0.25, dtype=None) -> None:
854
+ super().__init__()
855
+ self.num_parameters = num_parameters
856
+ self.weight = ParamState(jnp.ones(num_parameters, dtype=dtype) * init)
857
+
858
+ def __call__(self, x: ArrayLike) -> ArrayLike:
859
+ return F.prelu(x, self.weight.value)
860
+
861
+ def __repr__(self):
862
+ return f'{self.__class__.__name__}(num_parameters={self.num_parameters})'
863
+
864
+
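Sketch of the shared versus per-channel parameterisation; the per-channel call assumes the weight broadcasts along the channel (2nd) dimension as described in the notes above:

    import brainstate

    shared = brainstate.nn.PReLU()                                 # one slope for all channels
    per_channel = brainstate.nn.PReLU(num_parameters=3, init=0.1)  # one slope per channel
    print(shared.weight.value.shape)        # (1,)
    print(per_channel.weight.value.shape)   # (3,)

    x = brainstate.random.randn(4, 3)       # channel dimension has size 3
    y = per_channel(x)
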
865
+ class Softsign(ElementWiseBlock):
866
+ r"""Applies the element-wise function:
867
+
868
+ .. math::
869
+ \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
870
+
871
+ Shape:
872
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
873
+ - Output: :math:`(*)`, same shape as the input.
874
+
875
+ Examples::
876
+
877
+ >>> import brainstate.nn as nn
878
+        >>> import brainstate
879
+ >>> m = nn.Softsign()
880
+        >>> x = brainstate.random.randn(2)
881
+ >>> output = m(x)
882
+ """
883
+ __module__ = 'brainstate.nn'
884
+
885
+ def __call__(self, x: ArrayLike) -> ArrayLike:
886
+ return F.soft_sign(x)
887
+
888
+
889
+ class Tanhshrink(ElementWiseBlock):
890
+ r"""Applies the element-wise function:
891
+
892
+ .. math::
893
+ \text{Tanhshrink}(x) = x - \tanh(x)
894
+
895
+ Shape:
896
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
897
+ - Output: :math:`(*)`, same shape as the input.
898
+
899
+ Examples::
900
+
901
+ >>> import brainstate.nn as nn
902
+        >>> import brainstate
903
+ >>> m = nn.Tanhshrink()
904
+        >>> x = brainstate.random.randn(2)
905
+ >>> output = m(x)
906
+ """
907
+ __module__ = 'brainstate.nn'
908
+
909
+ def __call__(self, x: ArrayLike) -> ArrayLike:
910
+ return F.tanh_shrink(x)
911
+
912
+
913
+ class Softmin(ElementWiseBlock):
914
+ r"""Applies the Softmin function to an n-dimensional input Tensor
915
+    rescaling it so that the elements of the n-dimensional output Tensor
916
+ lie in the range `[0, 1]` and sum to 1.
917
+
918
+ Softmin is defined as:
919
+
920
+ .. math::
921
+ \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
922
+
923
+ Shape:
924
+ - Input: :math:`(*)` where `*` means, any number of additional
925
+ dimensions
926
+ - Output: :math:`(*)`, same shape as the input
927
+
928
+ Args:
929
+ dim (int): A dimension along which Softmin will be computed (so every slice
930
+ along dim will sum to 1).
931
+
932
+ Returns:
933
+ a Tensor of the same dimension and shape as the input, with
934
+ values in the range [0, 1]
935
+
936
+ Examples::
937
+
938
+ >>> import brainstate.nn as nn
939
+        >>> import brainstate
940
+ >>> m = nn.Softmin(dim=1)
941
+        >>> x = brainstate.random.randn(2, 3)
942
+ >>> output = m(x)
943
+ """
944
+ __module__ = 'brainstate.nn'
945
+ dim: Optional[int]
946
+
947
+ def __init__(self, dim: Optional[int] = None) -> None:
948
+ super().__init__()
949
+ self.dim = dim
950
+
951
+ def __call__(self, x: ArrayLike) -> ArrayLike:
952
+ return F.softmin(x, self.dim)
953
+
954
+ def __repr__(self):
955
+ return f'{self.__class__.__name__}(dim={self.dim})'
956
+
957
+
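Illustrative identity: Softmin is Softmax applied to the negated input (assuming negative ``dim`` values are accepted, as in JAX):

    import jax.numpy as jnp
    import brainstate

    x = brainstate.random.randn(2, 3)
    smin = brainstate.nn.Softmin(dim=-1)(x)
    print(jnp.sum(smin, axis=-1))                                             # each row sums to ~1
    print(jnp.allclose(smin, brainstate.nn.Softmax(dim=-1)(-x), atol=1e-6))   # expected: True
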
958
+ class Softmax(ElementWiseBlock):
959
+ r"""Applies the Softmax function to an n-dimensional input Tensor
960
+    rescaling it so that the elements of the n-dimensional output Tensor
961
+ lie in the range [0,1] and sum to 1.
962
+
963
+ Softmax is defined as:
964
+
965
+ .. math::
966
+ \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
967
+
968
+ When the input Tensor is a sparse tensor then the unspecified
969
+ values are treated as ``-inf``.
970
+
971
+ Shape:
972
+ - Input: :math:`(*)` where `*` means, any number of additional
973
+ dimensions
974
+ - Output: :math:`(*)`, same shape as the input
975
+
976
+ Returns:
977
+ a Tensor of the same dimension and shape as the input with
978
+ values in the range [0, 1]
979
+
980
+ Args:
981
+ dim (int): A dimension along which Softmax will be computed (so every slice
982
+ along dim will sum to 1).
983
+
984
+ .. note::
985
+ This module doesn't work directly with NLLLoss,
986
+ which expects the Log to be computed between the Softmax and itself.
987
+ Use `LogSoftmax` instead (it's faster and has better numerical properties).
988
+
989
+ Examples::
990
+
991
+ >>> import brainstate.nn as nn
992
+        >>> import brainstate
993
+ >>> m = nn.Softmax(dim=1)
994
+        >>> x = brainstate.random.randn(2, 3)
995
+ >>> output = m(x)
996
+
997
+ """
998
+ __module__ = 'brainstate.nn'
999
+ dim: Optional[int]
1000
+
1001
+ def __init__(self, dim: Optional[int] = None) -> None:
1002
+ super().__init__()
1003
+ self.dim = dim
1004
+
1005
+ def __call__(self, x: ArrayLike) -> ArrayLike:
1006
+ return F.softmax(x, self.dim)
1007
+
1008
+ def __repr__(self):
1009
+ return f'{self.__class__.__name__}(dim={self.dim})'
1010
+
1011
+
1012
+ class Softmax2d(ElementWiseBlock):
1013
+ r"""Applies SoftMax over features to each spatial location.
1014
+
1015
+ When given an image of ``Channels x Height x Width``, it will
1016
+ apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
1017
+
1018
+ Shape:
1019
+ - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
1020
+ - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
1021
+
1022
+ Returns:
1023
+ a Tensor of the same dimension and shape as the input with
1024
+ values in the range [0, 1]
1025
+
1026
+ Examples::
1027
+
1028
+ >>> import brainstate.nn as nn
1029
+        >>> import brainstate
1030
+ >>> m = nn.Softmax2d()
1031
+        >>> # softmax is applied over the channel dimension (the 2nd dimension of a 4D input)
1032
+        >>> x = brainstate.random.randn(2, 3, 12, 13)
1033
+ >>> output = m(x)
1034
+ """
1035
+ __module__ = 'brainstate.nn'
1036
+
1037
+ def __call__(self, x: ArrayLike) -> ArrayLike:
1038
+ assert x.ndim == 4 or x.ndim == 3, 'Softmax2d requires a 3D or 4D tensor as input'
1039
+ return F.softmax(x, -3)
1040
+
1041
+
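Sketch showing that Softmax2d normalises across channels at every spatial location (illustrative only):

    import jax.numpy as jnp
    import brainstate

    m = brainstate.nn.Softmax2d()
    x = brainstate.random.randn(2, 3, 12, 13)    # (N, C, H, W)
    y = m(x)
    # the values across the channel axis sum to 1 at every (n, h, w)
    print(jnp.allclose(jnp.sum(y, axis=1), 1.0, atol=1e-5))   # expected: True
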
1042
+ class LogSoftmax(ElementWiseBlock):
1043
+ r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
1044
+ input Tensor. The LogSoftmax formulation can be simplified as:
1045
+
1046
+ .. math::
1047
+ \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
1048
+
1049
+ Shape:
1050
+ - Input: :math:`(*)` where `*` means, any number of additional
1051
+ dimensions
1052
+ - Output: :math:`(*)`, same shape as the input
1053
+
1054
+ Args:
1055
+ dim (int): A dimension along which LogSoftmax will be computed.
1056
+
1057
+ Returns:
1058
+ a Tensor of the same dimension and shape as the input with
1059
+ values in the range [-inf, 0)
1060
+
1061
+ Examples::
1062
+
1063
+ >>> import brainstate.nn as nn
1064
+        >>> import brainstate
1065
+ >>> m = nn.LogSoftmax(dim=1)
1066
+        >>> x = brainstate.random.randn(2, 3)
1067
+ >>> output = m(x)
1068
+ """
1069
+ __module__ = 'brainstate.nn'
1070
+ dim: Optional[int]
1071
+
1072
+ def __init__(self, dim: Optional[int] = None) -> None:
1073
+ super().__init__()
1074
+ self.dim = dim
1075
+
1076
+ def __call__(self, x: ArrayLike) -> ArrayLike:
1077
+ return F.log_softmax(x, self.dim)
1078
+
1079
+ def __repr__(self):
1080
+ return f'{self.__class__.__name__}(dim={self.dim})'
1081
+
1082
+
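Quick consistency check between Softmax and LogSoftmax (illustrative only):

    import jax.numpy as jnp
    import brainstate

    x = brainstate.random.randn(2, 3)
    p = brainstate.nn.Softmax(dim=-1)(x)
    logp = brainstate.nn.LogSoftmax(dim=-1)(x)
    print(jnp.allclose(logp, jnp.log(p), atol=1e-5))   # expected: True
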
1083
+ class Identity(ElementWiseBlock):
1084
+ r"""A placeholder identity operator that is argument-insensitive.
1085
+ """
1086
+ __module__ = 'brainstate.nn'
1087
+
1088
+ def __call__(self, x):
1089
+ return x
1090
+
1091
+
1092
+ class SpikeBitwise(ElementWiseBlock):
1093
+ r"""Bitwise addition for the spiking inputs.
1094
+
1095
+ .. math::
1096
+
1097
+ \begin{array}{ccc}
1098
+ \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
1099
+ \hline \text { ADD } & x+y & x+y \\
1100
+ \text { AND } & x \cap y & x \cdot y \\
1101
+ \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
1102
+ \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
1103
+ \hline
1104
+ \end{array}
1105
+
1106
+ Args:
1107
+ op: str. The bitwise operation.
1108
+ name: str. The name of the dynamic system.
1109
+ """
1110
+ __module__ = 'brainstate.nn'
1111
+
1112
+ def __init__(self,
1113
+ op: str = 'add',
1114
+ name: Optional[str] = None) -> None:
1115
+ super().__init__(name=name)
1116
+ self.op = op
1117
+
1118
+ def __call__(self, x, y):
1119
+ return F.spike_bitwise(x, y, self.op)
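
A sketch of the truth table above for binary spike inputs; only 'add' is documented as the default, so the remaining op names ('and', 'iand', 'or') are assumptions based on the table:

    import jax.numpy as jnp
    import brainstate

    x = jnp.array([0., 0., 1., 1.])
    y = jnp.array([0., 1., 0., 1.])
    for op in ('add', 'and', 'iand', 'or'):   # op strings other than 'add' are assumed
        print(op, brainstate.nn.SpikeBitwise(op=op)(x, y))
    # per the table: add -> [0 1 1 2], and -> [0 0 0 1], iand -> [0 1 0 0], or -> [0 1 1 1]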