brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_elementwise.py
@@ -1,1298 +1,1298 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from typing import Optional
19
-
20
- import brainunit as u
21
- import jax.numpy as jnp
22
-
23
- from brainstate._state import ParamState
24
- from brainstate.typing import ArrayLike
25
- from . import _activations as F
26
- from ._module import ElementWiseBlock
27
-
28
- __all__ = [
29
- # activation functions
30
- 'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid',
31
- 'Tanh', 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU',
32
- 'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU',
33
- 'Softsign', 'Tanhshrink', 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax',
34
-
35
- # others
36
- 'Identity', 'SpikeBitwise',
37
- ]
38
-
39
-
40
- class Threshold(ElementWiseBlock):
41
- r"""Thresholds each element of the input Tensor.
42
-
43
- Threshold is defined as:
44
-
45
- .. math::
46
- y =
47
- \begin{cases}
48
- x, &\text{ if } x > \text{threshold} \\
49
- \text{value}, &\text{ otherwise }
50
- \end{cases}
51
-
52
- Parameters
53
- ----------
54
- threshold : float
55
- The value to threshold at.
56
- value : float
57
- The value to replace with.
58
-
59
- Shape
60
- -----
61
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
62
- - Output: :math:`(*)`, same shape as the input.
63
-
64
- Examples
65
- --------
66
- .. code-block:: python
67
-
68
- >>> import brainstate.nn as nn
69
- >>> import brainstate
70
- >>> m = nn.Threshold(0.1, 20)
71
- >>> x = brainstate.random.randn(2)
72
- >>> output = m(x)
73
- """
74
- __module__ = 'brainstate.nn'
75
- threshold: float
76
- value: float
77
-
78
- def __init__(self, threshold: float, value: float) -> None:
79
- super().__init__()
80
- self.threshold = threshold
81
- self.value = value
82
-
83
- def __call__(self, x: ArrayLike) -> ArrayLike:
84
- dtype = u.math.get_dtype(x)
85
- return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
86
- x,
87
- jnp.asarray(self.value, dtype=dtype))
88
-
89
- def __repr__(self):
90
- return f'{self.__class__.__name__}(threshold={self.threshold}, value={self.value})'
91
-
92
-
93
- class ReLU(ElementWiseBlock):
94
- r"""Applies the rectified linear unit function element-wise.
95
-
96
- The ReLU function is defined as:
97
-
98
- .. math::
99
- \text{ReLU}(x) = (x)^+ = \max(0, x)
100
-
101
- Shape
102
- -----
103
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
104
- - Output: :math:`(*)`, same shape as the input.
105
-
106
- Examples
107
- --------
108
- .. code-block:: python
109
-
110
- >>> import brainstate.nn as nn
111
- >>> import brainstate
112
- >>> m = nn.ReLU()
113
- >>> x = brainstate.random.randn(2)
114
- >>> output = m(x)
115
-
116
- An implementation of CReLU - https://arxiv.org/abs/1603.05201
117
-
118
- .. code-block:: python
119
-
120
- >>> import brainstate.nn as nn
121
- >>> import brainstate
122
- >>> import jax.numpy as jnp
123
- >>> m = nn.ReLU()
124
- >>> x = jnp.expand_dims(brainstate.random.randn(2), 0)
125
- >>> output = jnp.concat((m(x), m(-x)))
126
- """
127
- __module__ = 'brainstate.nn'
128
-
129
- def __call__(self, x: ArrayLike) -> ArrayLike:
130
- return F.relu(x)
131
-
132
- def __repr__(self):
133
- return f'{self.__class__.__name__}()'
134
-
135
-
136
- class RReLU(ElementWiseBlock):
137
- r"""Applies the randomized leaky rectified liner unit function, element-wise.
138
-
139
- As described in the paper `Empirical Evaluation of Rectified Activations in
140
- Convolutional Network`_.
141
-
142
- The function is defined as:
143
-
144
- .. math::
145
- \text{RReLU}(x) =
146
- \begin{cases}
147
- x & \text{if } x \geq 0 \\
148
- ax & \text{ otherwise }
149
- \end{cases}
150
-
151
- where :math:`a` is randomly sampled from uniform distribution
152
- :math:`\mathcal{U}(\text{lower}, \text{upper})`.
153
-
154
- Parameters
155
- ----------
156
- lower : float, optional
157
- Lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
158
- upper : float, optional
159
- Upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
160
-
161
- Shape
162
- -----
163
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
164
- - Output: :math:`(*)`, same shape as the input.
165
-
166
- References
167
- ----------
168
- .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
169
- https://arxiv.org/abs/1505.00853
170
-
171
- Examples
172
- --------
173
- .. code-block:: python
174
-
175
- >>> import brainstate.nn as nn
176
- >>> import brainstate
177
- >>> m = nn.RReLU(0.1, 0.3)
178
- >>> x = brainstate.random.randn(2)
179
- >>> output = m(x)
180
- """
181
- __module__ = 'brainstate.nn'
182
- lower: float
183
- upper: float
184
-
185
- def __init__(
186
- self,
187
- lower: float = 1. / 8,
188
- upper: float = 1. / 3,
189
- ):
190
- super().__init__()
191
- self.lower = lower
192
- self.upper = upper
193
-
194
- def __call__(self, x: ArrayLike) -> ArrayLike:
195
- return F.rrelu(x, self.lower, self.upper)
196
-
197
- def extra_repr(self):
198
- return f'{self.__class__.__name__}(lower={self.lower}, upper={self.upper})'
199
-
200
-
201
- class Hardtanh(ElementWiseBlock):
202
- r"""Applies the HardTanh function element-wise.
203
-
204
- HardTanh is defined as:
205
-
206
- .. math::
207
- \text{HardTanh}(x) = \begin{cases}
208
- \text{max\_val} & \text{ if } x > \text{ max\_val } \\
209
- \text{min\_val} & \text{ if } x < \text{ min\_val } \\
210
- x & \text{ otherwise } \\
211
- \end{cases}
212
-
213
- Parameters
214
- ----------
215
- min_val : float, optional
216
- Minimum value of the linear region range. Default: -1
217
- max_val : float, optional
218
- Maximum value of the linear region range. Default: 1
219
-
220
- Notes
221
- -----
222
- Keyword arguments :attr:`min_value` and :attr:`max_value`
223
- have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
224
-
225
- Shape
226
- -----
227
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
228
- - Output: :math:`(*)`, same shape as the input.
229
-
230
- Examples
231
- --------
232
- .. code-block:: python
233
-
234
- >>> import brainstate.nn as nn
235
- >>> import brainstate
236
- >>> m = nn.Hardtanh(-2, 2)
237
- >>> x = brainstate.random.randn(2)
238
- >>> output = m(x)
239
- """
240
- __module__ = 'brainstate.nn'
241
- min_val: float
242
- max_val: float
243
-
244
- def __init__(
245
- self,
246
- min_val: float = -1.,
247
- max_val: float = 1.,
248
- ) -> None:
249
- super().__init__()
250
- self.min_val = min_val
251
- self.max_val = max_val
252
- assert self.max_val > self.min_val
253
-
254
- def __call__(self, x: ArrayLike) -> ArrayLike:
255
- return F.hard_tanh(x, self.min_val, self.max_val)
256
-
257
- def extra_repr(self) -> str:
258
- return f'{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})'
259
-
260
-
261
- class ReLU6(Hardtanh):
262
- r"""Applies the element-wise function.
263
-
264
- ReLU6 is defined as:
265
-
266
- .. math::
267
- \text{ReLU6}(x) = \min(\max(0,x), 6)
268
-
269
- Shape
270
- -----
271
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
272
- - Output: :math:`(*)`, same shape as the input.
273
-
274
- Examples
275
- --------
276
- .. code-block:: python
277
-
278
- >>> import brainstate.nn as nn
279
- >>> import brainstate
280
- >>> m = nn.ReLU6()
281
- >>> x = brainstate.random.randn(2)
282
- >>> output = m(x)
283
- """
284
- __module__ = 'brainstate.nn'
285
-
286
- def __init__(self):
287
- super().__init__(0., 6.)
288
-
289
-
290
- class Sigmoid(ElementWiseBlock):
291
- r"""Applies the element-wise function.
292
-
293
- Sigmoid is defined as:
294
-
295
- .. math::
296
- \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
297
-
298
- Shape
299
- -----
300
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
301
- - Output: :math:`(*)`, same shape as the input.
302
-
303
- Examples
304
- --------
305
- .. code-block:: python
306
-
307
- >>> import brainstate.nn as nn
308
- >>> import brainstate
309
- >>> m = nn.Sigmoid()
310
- >>> x = brainstate.random.randn(2)
311
- >>> output = m(x)
312
- """
313
- __module__ = 'brainstate.nn'
314
-
315
- def __call__(self, x: ArrayLike) -> ArrayLike:
316
- return F.sigmoid(x)
317
-
318
-
319
- class Hardsigmoid(ElementWiseBlock):
320
- r"""Applies the Hardsigmoid function element-wise.
321
-
322
- Hardsigmoid is defined as:
323
-
324
- .. math::
325
- \text{Hardsigmoid}(x) = \begin{cases}
326
- 0 & \text{if~} x \le -3, \\
327
- 1 & \text{if~} x \ge +3, \\
328
- x / 6 + 1 / 2 & \text{otherwise}
329
- \end{cases}
330
-
331
- Shape
332
- -----
333
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
334
- - Output: :math:`(*)`, same shape as the input.
335
-
336
- Examples
337
- --------
338
- .. code-block:: python
339
-
340
- >>> import brainstate.nn as nn
341
- >>> import brainstate
342
- >>> m = nn.Hardsigmoid()
343
- >>> x = brainstate.random.randn(2)
344
- >>> output = m(x)
345
- """
346
- __module__ = 'brainstate.nn'
347
-
348
- def __call__(self, x: ArrayLike) -> ArrayLike:
349
- return F.hard_sigmoid(x)
350
-
351
-
352
- class Tanh(ElementWiseBlock):
353
- r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
354
-
355
- Tanh is defined as:
356
-
357
- .. math::
358
- \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
359
-
360
- Shape
361
- -----
362
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
363
- - Output: :math:`(*)`, same shape as the input.
364
-
365
- Examples
366
- --------
367
- .. code-block:: python
368
-
369
- >>> import brainstate.nn as nn
370
- >>> import brainstate
371
- >>> m = nn.Tanh()
372
- >>> x = brainstate.random.randn(2)
373
- >>> output = m(x)
374
- """
375
- __module__ = 'brainstate.nn'
376
-
377
- def __call__(self, x: ArrayLike) -> ArrayLike:
378
- return F.tanh(x)
379
-
380
-
381
- class SiLU(ElementWiseBlock):
382
- r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
383
-
384
- The SiLU function is also known as the swish function.
385
-
386
- .. math::
387
- \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
388
-
389
- Notes
390
- -----
391
- See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
392
- where the SiLU (Sigmoid Linear Unit) was originally coined, and see
393
- `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
394
- in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
395
- a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
396
- where the SiLU was experimented with later.
397
-
398
- Shape
399
- -----
400
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
401
- - Output: :math:`(*)`, same shape as the input.
402
-
403
- Examples
404
- --------
405
- .. code-block:: python
406
-
407
- >>> import brainstate.nn as nn
408
- >>> import brainstate
409
- >>> m = nn.SiLU()
410
- >>> x = brainstate.random.randn(2)
411
- >>> output = m(x)
412
- """
413
- __module__ = 'brainstate.nn'
414
-
415
- def __call__(self, x: ArrayLike) -> ArrayLike:
416
- return F.silu(x)
417
-
418
-
419
- class Mish(ElementWiseBlock):
420
- r"""Applies the Mish function, element-wise.
421
-
422
- Mish: A Self Regularized Non-Monotonic Neural Activation Function.
423
-
424
- .. math::
425
- \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
426
-
427
- Notes
428
- -----
429
- See `Mish: A Self Regularized Non-Monotonic Neural Activation Function
430
- <https://arxiv.org/abs/1908.08681>`_
431
-
432
- Shape
433
- -----
434
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
435
- - Output: :math:`(*)`, same shape as the input.
436
-
437
- Examples
438
- --------
439
- .. code-block:: python
440
-
441
- >>> import brainstate.nn as nn
442
- >>> import brainstate
443
- >>> m = nn.Mish()
444
- >>> x = brainstate.random.randn(2)
445
- >>> output = m(x)
446
- """
447
- __module__ = 'brainstate.nn'
448
-
449
- def __call__(self, x: ArrayLike) -> ArrayLike:
450
- return F.mish(x)
451
-
452
-
453
- class Hardswish(ElementWiseBlock):
454
- r"""Applies the Hardswish function, element-wise.
455
-
456
- As described in the paper `Searching for MobileNetV3
457
- <https://arxiv.org/abs/1905.02244>`_.
458
-
459
- Hardswish is defined as:
460
-
461
- .. math::
462
- \text{Hardswish}(x) = \begin{cases}
463
- 0 & \text{if~} x \le -3, \\
464
- x & \text{if~} x \ge +3, \\
465
- x \cdot (x + 3) /6 & \text{otherwise}
466
- \end{cases}
467
-
468
- Shape
469
- -----
470
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
471
- - Output: :math:`(*)`, same shape as the input.
472
-
473
- Examples
474
- --------
475
- .. code-block:: python
476
-
477
- >>> import brainstate.nn as nn
478
- >>> import brainstate
479
- >>> m = nn.Hardswish()
480
- >>> x = brainstate.random.randn(2)
481
- >>> output = m(x)
482
- """
483
- __module__ = 'brainstate.nn'
484
-
485
- def __call__(self, x: ArrayLike) -> ArrayLike:
486
- return F.hard_swish(x)
487
-
488
-
489
- class ELU(ElementWiseBlock):
490
- r"""Applies the Exponential Linear Unit (ELU) function, element-wise.
491
-
492
- As described in the paper: `Fast and Accurate Deep Network Learning by
493
- Exponential Linear Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
494
-
495
- ELU is defined as:
496
-
497
- .. math::
498
- \text{ELU}(x) = \begin{cases}
499
- x, & \text{ if } x > 0\\
500
- \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
501
- \end{cases}
502
-
503
- Parameters
504
- ----------
505
- alpha : float, optional
506
- The :math:`\alpha` value for the ELU formulation. Default: 1.0
507
-
508
- Shape
509
- -----
510
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
511
- - Output: :math:`(*)`, same shape as the input.
512
-
513
- Examples
514
- --------
515
- .. code-block:: python
516
-
517
- >>> import brainstate.nn as nn
518
- >>> import brainstate
519
- >>> m = nn.ELU()
520
- >>> x = brainstate.random.randn(2)
521
- >>> output = m(x)
522
- """
523
- __module__ = 'brainstate.nn'
524
- alpha: float
525
-
526
- def __init__(self, alpha: float = 1.) -> None:
527
- super().__init__()
528
- self.alpha = alpha
529
-
530
- def __call__(self, x: ArrayLike) -> ArrayLike:
531
- return F.elu(x, self.alpha)
532
-
533
- def extra_repr(self) -> str:
534
- return f'{self.__class__.__name__}(alpha={self.alpha})'
535
-
536
-
537
- class CELU(ElementWiseBlock):
538
- r"""Applies the element-wise function.
539
-
540
- .. math::
541
- \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
542
-
543
- More details can be found in the paper `Continuously Differentiable Exponential
544
- Linear Units`_ .
545
-
546
- Parameters
547
- ----------
548
- alpha : float, optional
549
- The :math:`\alpha` value for the CELU formulation. Default: 1.0
550
-
551
- Shape
552
- -----
553
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
554
- - Output: :math:`(*)`, same shape as the input.
555
-
556
- References
557
- ----------
558
- .. _`Continuously Differentiable Exponential Linear Units`:
559
- https://arxiv.org/abs/1704.07483
560
-
561
- Examples
562
- --------
563
- .. code-block:: python
564
-
565
- >>> import brainstate.nn as nn
566
- >>> import brainstate
567
- >>> m = nn.CELU()
568
- >>> x = brainstate.random.randn(2)
569
- >>> output = m(x)
570
- """
571
- __module__ = 'brainstate.nn'
572
- alpha: float
573
-
574
- def __init__(self, alpha: float = 1.) -> None:
575
- super().__init__()
576
- self.alpha = alpha
577
-
578
- def __call__(self, x: ArrayLike) -> ArrayLike:
579
- return F.celu(x, self.alpha)
580
-
581
- def extra_repr(self) -> str:
582
- return f'{self.__class__.__name__}(alpha={self.alpha})'
583
-
584
-
585
- class SELU(ElementWiseBlock):
586
- r"""Applied element-wise.
587
-
588
- .. math::
589
- \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
590
-
591
- with :math:`\alpha = 1.6732632423543772848170429916717` and
592
- :math:`\text{scale} = 1.0507009873554804934193349852946`.
593
-
594
- More details can be found in the paper `Self-Normalizing Neural Networks`_ .
595
-
596
- Shape
597
- -----
598
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
599
- - Output: :math:`(*)`, same shape as the input.
600
-
601
- References
602
- ----------
603
- .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
604
-
605
- Examples
606
- --------
607
- .. code-block:: python
608
-
609
- >>> import brainstate.nn as nn
610
- >>> import brainstate
611
- >>> m = nn.SELU()
612
- >>> x = brainstate.random.randn(2)
613
- >>> output = m(x)
614
- """
615
- __module__ = 'brainstate.nn'
616
-
617
- def __call__(self, x: ArrayLike) -> ArrayLike:
618
- return F.selu(x)
619
-
620
-
621
- class GLU(ElementWiseBlock):
622
- r"""Applies the gated linear unit function.
623
-
624
- .. math::
625
- {GLU}(a, b)= a \otimes \sigma(b)
626
-
627
- where :math:`a` is the first half of the input matrices and :math:`b` is
628
- the second half.
629
-
630
- Parameters
631
- ----------
632
- dim : int, optional
633
- The dimension on which to split the input. Default: -1
634
-
635
- Shape
636
- -----
637
- - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
638
- dimensions
639
- - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
640
-
641
- Examples
642
- --------
643
- .. code-block:: python
644
-
645
- >>> import brainstate.nn as nn
646
- >>> import brainstate
647
- >>> m = nn.GLU()
648
- >>> x = brainstate.random.randn(4, 2)
649
- >>> output = m(x)
650
- """
651
- __module__ = 'brainstate.nn'
652
- dim: int
653
-
654
- def __init__(self, dim: int = -1) -> None:
655
- super().__init__()
656
- self.dim = dim
657
-
658
- def __call__(self, x: ArrayLike) -> ArrayLike:
659
- return F.glu(x, self.dim)
660
-
661
- def __repr__(self):
662
- return f'{self.__class__.__name__}(dim={self.dim})'
663
-
664
-
665
- class GELU(ElementWiseBlock):
666
- r"""Applies the Gaussian Error Linear Units function.
667
-
668
- .. math:: \text{GELU}(x) = x * \Phi(x)
669
-
670
- where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian
671
- Distribution.
672
-
673
- When the approximate argument is True, Gelu is estimated with:
674
-
675
- .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3)))
676
-
677
- Parameters
678
- ----------
679
- approximate : bool, optional
680
- Whether to use the tanh approximation algorithm. Default: False
681
-
682
- Shape
683
- -----
684
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
685
- - Output: :math:`(*)`, same shape as the input.
686
-
687
- Examples
688
- --------
689
- .. code-block:: python
690
-
691
- >>> import brainstate.nn as nn
692
- >>> import brainstate
693
- >>> m = nn.GELU()
694
- >>> x = brainstate.random.randn(2)
695
- >>> output = m(x)
696
- """
697
- __module__ = 'brainstate.nn'
698
- approximate: bool
699
-
700
- def __init__(self, approximate: bool = False) -> None:
701
- super().__init__()
702
- self.approximate = approximate
703
-
704
- def __call__(self, x: ArrayLike) -> ArrayLike:
705
- return F.gelu(x, approximate=self.approximate)
706
-
707
- def __repr__(self):
708
- return f'{self.__class__.__name__}(approximate={self.approximate})'
709
-
710
-
711
- class Hardshrink(ElementWiseBlock):
712
- r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
713
-
714
- Hardshrink is defined as:
715
-
716
- .. math::
717
- \text{HardShrink}(x) =
718
- \begin{cases}
719
- x, & \text{ if } x > \lambda \\
720
- x, & \text{ if } x < -\lambda \\
721
- 0, & \text{ otherwise }
722
- \end{cases}
723
-
724
- Parameters
725
- ----------
726
- lambd : float, optional
727
- The :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
728
-
729
- Shape
730
- -----
731
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
732
- - Output: :math:`(*)`, same shape as the input.
733
-
734
- Examples
735
- --------
736
- .. code-block:: python
737
-
738
- >>> import brainstate.nn as nn
739
- >>> import brainstate
740
- >>> m = nn.Hardshrink()
741
- >>> x = brainstate.random.randn(2)
742
- >>> output = m(x)
743
- """
744
- __module__ = 'brainstate.nn'
745
- lambd: float
746
-
747
- def __init__(self, lambd: float = 0.5) -> None:
748
- super().__init__()
749
- self.lambd = lambd
750
-
751
- def __call__(self, x: ArrayLike) -> ArrayLike:
752
- return F.hard_shrink(x, self.lambd)
753
-
754
- def __repr__(self):
755
- return f'{self.__class__.__name__}(lambd={self.lambd})'
756
-
757
-
758
- class LeakyReLU(ElementWiseBlock):
759
- r"""Applies the element-wise function.
760
-
761
- .. math::
762
- \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
763
-
764
- or
765
-
766
- .. math::
767
- \text{LeakyReLU}(x) =
768
- \begin{cases}
769
- x, & \text{ if } x \geq 0 \\
770
- \text{negative\_slope} \times x, & \text{ otherwise }
771
- \end{cases}
772
-
773
- Parameters
774
- ----------
775
- negative_slope : float, optional
776
- Controls the angle of the negative slope (which is used for
777
- negative input values). Default: 1e-2
778
-
779
- Shape
780
- -----
781
- - Input: :math:`(*)` where `*` means, any number of additional
782
- dimensions
783
- - Output: :math:`(*)`, same shape as the input
784
-
785
- Examples
786
- --------
787
- .. code-block:: python
788
-
789
- >>> import brainstate.nn as nn
790
- >>> import brainstate
791
- >>> m = nn.LeakyReLU(0.1)
792
- >>> x = brainstate.random.randn(2)
793
- >>> output = m(x)
794
- """
795
- __module__ = 'brainstate.nn'
796
- negative_slope: float
797
-
798
- def __init__(self, negative_slope: float = 1e-2) -> None:
799
- super().__init__()
800
- self.negative_slope = negative_slope
801
-
802
- def __call__(self, x: ArrayLike) -> ArrayLike:
803
- return F.leaky_relu(x, self.negative_slope)
804
-
805
- def __repr__(self):
806
- return f'{self.__class__.__name__}(negative_slope={self.negative_slope})'
807
-
808
-
809
- class LogSigmoid(ElementWiseBlock):
810
- r"""Applies the element-wise function.
811
-
812
- .. math::
813
- \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
814
-
815
- Shape
816
- -----
817
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
818
- - Output: :math:`(*)`, same shape as the input.
819
-
820
- Examples
821
- --------
822
- .. code-block:: python
823
-
824
- >>> import brainstate.nn as nn
825
- >>> import brainstate
826
- >>> m = nn.LogSigmoid()
827
- >>> x = brainstate.random.randn(2)
828
- >>> output = m(x)
829
- """
830
- __module__ = 'brainstate.nn'
831
-
832
- def __call__(self, x: ArrayLike) -> ArrayLike:
833
- return F.log_sigmoid(x)
834
-
835
-
836
- class Softplus(ElementWiseBlock):
837
- r"""Applies the Softplus function element-wise.
838
-
839
- .. math::
840
- \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
841
-
842
- SoftPlus is a smooth approximation to the ReLU function and can be used
843
- to constrain the output of a machine to always be positive.
844
-
845
- For numerical stability the implementation reverts to the linear function
846
- when :math:`input \times \beta > threshold`.
847
-
848
- Shape
849
- -----
850
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
851
- - Output: :math:`(*)`, same shape as the input.
852
-
853
- Examples
854
- --------
855
- .. code-block:: python
856
-
857
- >>> import brainstate.nn as nn
858
- >>> import brainstate
859
- >>> m = nn.Softplus()
860
- >>> x = brainstate.random.randn(2)
861
- >>> output = m(x)
862
- """
863
- __module__ = 'brainstate.nn'
864
-
865
- def __call__(self, x: ArrayLike) -> ArrayLike:
866
- return F.softplus(x)
867
-
868
-
869
- class Softshrink(ElementWiseBlock):
870
- r"""Applies the soft shrinkage function elementwise.
871
-
872
- .. math::
873
- \text{SoftShrinkage}(x) =
874
- \begin{cases}
875
- x - \lambda, & \text{ if } x > \lambda \\
876
- x + \lambda, & \text{ if } x < -\lambda \\
877
- 0, & \text{ otherwise }
878
- \end{cases}
879
-
880
- Parameters
881
- ----------
882
- lambd : float, optional
883
- The :math:`\lambda` (must be no less than zero) value for the
884
- Softshrink formulation. Default: 0.5
885
-
886
- Shape
887
- -----
888
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
889
- - Output: :math:`(*)`, same shape as the input.
890
-
891
- Examples
892
- --------
893
- .. code-block:: python
894
-
895
- >>> import brainstate.nn as nn
896
- >>> import brainstate
897
- >>> m = nn.Softshrink()
898
- >>> x = brainstate.random.randn(2)
899
- >>> output = m(x)
900
- """
901
- __module__ = 'brainstate.nn'
902
- lambd: float
903
-
904
- def __init__(self, lambd: float = 0.5) -> None:
905
- super().__init__()
906
- self.lambd = lambd
907
-
908
- def __call__(self, x: ArrayLike) -> ArrayLike:
909
- return F.soft_shrink(x, self.lambd)
910
-
911
- def __repr__(self):
912
- return f'{self.__class__.__name__}(lambd={self.lambd})'
913
-
914
-
915
- class PReLU(ElementWiseBlock):
916
- r"""Applies the element-wise function.
917
-
918
- .. math::
919
- \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
920
-
921
- or
922
-
923
- .. math::
924
- \text{PReLU}(x) =
925
- \begin{cases}
926
- x, & \text{ if } x \geq 0 \\
927
- ax, & \text{ otherwise }
928
- \end{cases}
929
-
930
- Here :math:`a` is a learnable parameter. When called without arguments,
931
- `nn.PReLU()` uses a single parameter :math:`a` across all input channels.
932
- If called with `nn.PReLU(nChannels)`, a separate :math:`a` is used for
933
- each input channel.
934
-
935
- Parameters
936
- ----------
937
- num_parameters : int, optional
938
- Number of :math:`a` to learn. Although it takes an int as input,
939
- only two values are legitimate: 1, or the number of channels
940
- at input. Default: 1
941
- init : float, optional
942
- The initial value of :math:`a`. Default: 0.25
943
- dtype : optional
944
- The data type for the weight parameter.
945
-
946
- Shape
947
- -----
948
- - Input: :math:`( *)` where `*` means, any number of additional dimensions.
949
- - Output: :math:`(*)`, same shape as the input.
950
-
951
- Attributes
952
- ----------
953
- weight : Tensor
954
- The learnable weights of shape (:attr:`num_parameters`).
955
-
956
- Notes
957
- -----
958
- - Weight decay should not be used when learning :math:`a` for good performance.
959
- - Channel dim is the 2nd dim of input. When input has dims < 2, then there is
960
- no channel dim and the number of channels = 1.
961
-
962
- Examples
963
- --------
964
- .. code-block:: python
965
-
966
- >>> import brainstate
967
- >>> m = brainstate.nn.PReLU()
968
- >>> x = brainstate.random.randn(2)
969
- >>> output = m(x)
970
- """
971
- __module__ = 'brainstate.nn'
972
- num_parameters: int
973
-
974
- def __init__(self, num_parameters: int = 1, init: float = 0.25, dtype=None) -> None:
975
- super().__init__()
976
- self.num_parameters = num_parameters
977
- self.weight = ParamState(jnp.ones(num_parameters, dtype=dtype) * init)
978
-
979
- def __call__(self, x: ArrayLike) -> ArrayLike:
980
- return F.prelu(x, self.weight.value)
981
-
982
- def __repr__(self):
983
- return f'{self.__class__.__name__}(num_parameters={self.num_parameters})'
984
-
985
-
986
- class Softsign(ElementWiseBlock):
987
- r"""Applies the element-wise function.
988
-
989
- .. math::
990
- \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
991
-
992
- Shape
993
- -----
994
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
995
- - Output: :math:`(*)`, same shape as the input.
996
-
997
- Examples
998
- --------
999
- .. code-block:: python
1000
-
1001
- >>> import brainstate.nn as nn
1002
- >>> import brainstate
1003
- >>> m = nn.Softsign()
1004
- >>> x = brainstate.random.randn(2)
1005
- >>> output = m(x)
1006
- """
1007
- __module__ = 'brainstate.nn'
1008
-
1009
- def __call__(self, x: ArrayLike) -> ArrayLike:
1010
- return F.soft_sign(x)
1011
-
1012
-
1013
- class Tanhshrink(ElementWiseBlock):
1014
- r"""Applies the element-wise function.
1015
-
1016
- .. math::
1017
- \text{Tanhshrink}(x) = x - \tanh(x)
1018
-
1019
- Shape
1020
- -----
1021
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1022
- - Output: :math:`(*)`, same shape as the input.
1023
-
1024
- Examples
1025
- --------
1026
- .. code-block:: python
1027
-
1028
- >>> import brainstate.nn as nn
1029
- >>> import brainstate
1030
- >>> m = nn.Tanhshrink()
1031
- >>> x = brainstate.random.randn(2)
1032
- >>> output = m(x)
1033
- """
1034
- __module__ = 'brainstate.nn'
1035
-
1036
- def __call__(self, x: ArrayLike) -> ArrayLike:
1037
- return F.tanh_shrink(x)
1038
-
1039
-
1040
- class Softmin(ElementWiseBlock):
1041
- r"""Applies the Softmin function to an n-dimensional input Tensor.
1042
-
1043
- Rescales the input so that the elements of the n-dimensional output Tensor
1044
- lie in the range `[0, 1]` and sum to 1.
1045
-
1046
- Softmin is defined as:
1047
-
1048
- .. math::
1049
- \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
1050
-
1051
- Parameters
1052
- ----------
1053
- dim : int, optional
1054
- A dimension along which Softmin will be computed (so every slice
1055
- along dim will sum to 1).
1056
-
1057
- Shape
1058
- -----
1059
- - Input: :math:`(*)` where `*` means, any number of additional dimensions
1060
- - Output: :math:`(*)`, same shape as the input
1061
-
1062
- Returns
1063
- -------
1064
- Tensor
1065
- A Tensor of the same dimension and shape as the input, with
1066
- values in the range [0, 1]
1067
-
1068
- Examples
1069
- --------
1070
- .. code-block:: python
1071
-
1072
- >>> import brainstate.nn as nn
1073
- >>> import brainstate
1074
- >>> m = nn.Softmin(dim=1)
1075
- >>> x = brainstate.random.randn(2, 3)
1076
- >>> output = m(x)
1077
- """
1078
- __module__ = 'brainstate.nn'
1079
- dim: Optional[int]
1080
-
1081
- def __init__(self, dim: Optional[int] = None) -> None:
1082
- super().__init__()
1083
- self.dim = dim
1084
-
1085
- def __call__(self, x: ArrayLike) -> ArrayLike:
1086
- return F.softmin(x, self.dim)
1087
-
1088
- def __repr__(self):
1089
- return f'{self.__class__.__name__}(dim={self.dim})'
1090
-
1091
-
1092
- class Softmax(ElementWiseBlock):
1093
- r"""Applies the Softmax function to an n-dimensional input Tensor.
1094
-
1095
- Rescales the input so that the elements of the n-dimensional output Tensor
1096
- lie in the range [0,1] and sum to 1.
1097
-
1098
- Softmax is defined as:
1099
-
1100
- .. math::
1101
- \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
1102
-
1103
- When the input Tensor is a sparse tensor then the unspecified
1104
- values are treated as ``-inf``.
1105
-
1106
- Parameters
1107
- ----------
1108
- dim : int, optional
1109
- A dimension along which Softmax will be computed (so every slice
1110
- along dim will sum to 1).
1111
-
1112
- Shape
1113
- -----
1114
- - Input: :math:`(*)` where `*` means, any number of additional dimensions
1115
- - Output: :math:`(*)`, same shape as the input
1116
-
1117
- Returns
1118
- -------
1119
- Tensor
1120
- A Tensor of the same dimension and shape as the input with
1121
- values in the range [0, 1]
1122
-
1123
- Notes
1124
- -----
1125
- This module doesn't work directly with NLLLoss, which expects the Log to be
1126
- computed between the Softmax and itself. Use `LogSoftmax` instead (it's
1127
- faster and has better numerical properties).
1128
-
1129
- Examples
1130
- --------
1131
- .. code-block:: python
1132
-
1133
- >>> import brainstate.nn as nn
1134
- >>> import brainstate
1135
- >>> m = nn.Softmax(dim=1)
1136
- >>> x = brainstate.random.randn(2, 3)
1137
- >>> output = m(x)
1138
- """
1139
- __module__ = 'brainstate.nn'
1140
- dim: Optional[int]
1141
-
1142
- def __init__(self, dim: Optional[int] = None) -> None:
1143
- super().__init__()
1144
- self.dim = dim
1145
-
1146
- def __call__(self, x: ArrayLike) -> ArrayLike:
1147
- return F.softmax(x, self.dim)
1148
-
1149
- def __repr__(self):
1150
- return f'{self.__class__.__name__}(dim={self.dim})'
1151
-
1152
-
1153
- class Softmax2d(ElementWiseBlock):
1154
- r"""Applies SoftMax over features to each spatial location.
1155
-
1156
- When given an image of ``Channels x Height x Width``, it will
1157
- apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
1158
-
1159
- Shape
1160
- -----
1161
- - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
1162
- - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
1163
-
1164
- Returns
1165
- -------
1166
- Tensor
1167
- A Tensor of the same dimension and shape as the input with
1168
- values in the range [0, 1]
1169
-
1170
- Examples
1171
- --------
1172
- .. code-block:: python
1173
-
1174
- >>> import brainstate.nn as nn
1175
- >>> import brainstate
1176
- >>> m = nn.Softmax2d()
1177
- >>> # softmax is applied over the 2nd dimension (channels)
1178
- >>> x = brainstate.random.randn(2, 3, 12, 13)
1179
- >>> output = m(x)
1180
- """
1181
- __module__ = 'brainstate.nn'
1182
-
1183
- def __call__(self, x: ArrayLike) -> ArrayLike:
1184
- assert x.ndim == 4 or x.ndim == 3, 'Softmax2d requires a 3D or 4D tensor as input'
1185
- return F.softmax(x, -3)
1186
-
1187
-
1188
- class LogSoftmax(ElementWiseBlock):
1189
- r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
1190
-
1191
- The LogSoftmax formulation can be simplified as:
1192
-
1193
- .. math::
1194
- \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
1195
-
1196
- Parameters
1197
- ----------
1198
- dim : int, optional
1199
- A dimension along which LogSoftmax will be computed.
1200
-
1201
- Shape
1202
- -----
1203
- - Input: :math:`(*)` where `*` means, any number of additional dimensions
1204
- - Output: :math:`(*)`, same shape as the input
1205
-
1206
- Returns
1207
- -------
1208
- Tensor
1209
- A Tensor of the same dimension and shape as the input with
1210
- values in the range [-inf, 0)
1211
-
1212
- Examples
1213
- --------
1214
- .. code-block:: python
1215
-
1216
- >>> import brainstate.nn as nn
1217
- >>> import brainstate
1218
- >>> m = nn.LogSoftmax(dim=1)
1219
- >>> x = brainstate.random.randn(2, 3)
1220
- >>> output = m(x)
1221
- """
1222
- __module__ = 'brainstate.nn'
1223
- dim: Optional[int]
1224
-
1225
- def __init__(self, dim: Optional[int] = None) -> None:
1226
- super().__init__()
1227
- self.dim = dim
1228
-
1229
- def __call__(self, x: ArrayLike) -> ArrayLike:
1230
- return F.log_softmax(x, self.dim)
1231
-
1232
- def __repr__(self):
1233
- return f'{self.__class__.__name__}(dim={self.dim})'
1234
-
1235
-
1236
- class Identity(ElementWiseBlock):
1237
- r"""A placeholder identity operator that is argument-insensitive.
1238
-
1239
- Examples
1240
- --------
1241
- .. code-block:: python
1242
-
1243
- >>> import brainstate.nn as nn
1244
- >>> m = nn.Identity()
1245
- >>> x = brainstate.random.randn(2, 3)
1246
- >>> output = m(x)
1247
- >>> assert (output == x).all()
1248
- """
1249
- __module__ = 'brainstate.nn'
1250
-
1251
- def __call__(self, x):
1252
- return x
1253
-
1254
-
1255
- class SpikeBitwise(ElementWiseBlock):
1256
- r"""Bitwise addition for the spiking inputs.
1257
-
1258
- .. math::
1259
-
1260
- \begin{array}{ccc}
1261
- \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
1262
- \hline \text { ADD } & x+y & x+y \\
1263
- \text { AND } & x \cap y & x \cdot y \\
1264
- \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
1265
- \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
1266
- \hline
1267
- \end{array}
1268
-
1269
- Parameters
1270
- ----------
1271
- op : str, optional
1272
- The bitwise operation. Default: 'add'
1273
- name : str, optional
1274
- The name of the dynamic system.
1275
-
1276
- Examples
1277
- --------
1278
- .. code-block:: python
1279
-
1280
- >>> import brainstate.nn as nn
1281
- >>> m = nn.SpikeBitwise(op='and')
1282
- >>> x = brainstate.random.randn(2, 3) > 0
1283
- >>> y = brainstate.random.randn(2, 3) > 0
1284
- >>> output = m(x, y)
1285
- """
1286
- __module__ = 'brainstate.nn'
1287
-
1288
- def __init__(
1289
- self,
1290
- op: str = 'add',
1291
- name: Optional[str] = None
1292
- ) -> None:
1293
- super().__init__(name=name)
1294
- self.op = op
1295
-
1296
- def __call__(self, x, y):
1297
- import braintools
1298
- return braintools.spike_bitwise(x, y, self.op)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Optional
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate._state import ParamState
24
+ from brainstate.typing import ArrayLike
25
+ from . import _activations as F
26
+ from ._module import ElementWiseBlock
27
+
28
+ __all__ = [
29
+ # activation functions
30
+ 'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid',
31
+ 'Tanh', 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU',
32
+ 'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU',
33
+ 'Softsign', 'Tanhshrink', 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax',
34
+
35
+ # others
36
+ 'Identity', 'SpikeBitwise',
37
+ ]
38
+
39
+
40
+ class Threshold(ElementWiseBlock):
41
+ r"""Thresholds each element of the input Tensor.
42
+
43
+ Threshold is defined as:
44
+
45
+ .. math::
46
+ y =
47
+ \begin{cases}
48
+ x, &\text{ if } x > \text{threshold} \\
49
+ \text{value}, &\text{ otherwise }
50
+ \end{cases}
51
+
52
+ Parameters
53
+ ----------
54
+ threshold : float
55
+ The value to threshold at.
56
+ value : float
57
+ The value to replace with.
58
+
59
+ Shape
60
+ -----
61
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
62
+ - Output: :math:`(*)`, same shape as the input.
63
+
64
+ Examples
65
+ --------
66
+ .. code-block:: python
67
+
68
+ >>> import brainstate.nn as nn
69
+ >>> import brainstate
70
+ >>> m = nn.Threshold(0.1, 20)
71
+ >>> x = brainstate.random.randn(2)
72
+ >>> output = m(x)
73
+ """
74
+ __module__ = 'brainstate.nn'
75
+ threshold: float
76
+ value: float
77
+
78
+ def __init__(self, threshold: float, value: float) -> None:
79
+ super().__init__()
80
+ self.threshold = threshold
81
+ self.value = value
82
+
83
+ def __call__(self, x: ArrayLike) -> ArrayLike:
84
+ dtype = u.math.get_dtype(x)
85
+ return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
86
+ x,
87
+ jnp.asarray(self.value, dtype=dtype))
88
+
89
+ def __repr__(self):
90
+ return f'{self.__class__.__name__}(threshold={self.threshold}, value={self.value})'
91
+
92
+
93
+ class ReLU(ElementWiseBlock):
94
+ r"""Applies the rectified linear unit function element-wise.
95
+
96
+ The ReLU function is defined as:
97
+
98
+ .. math::
99
+ \text{ReLU}(x) = (x)^+ = \max(0, x)
100
+
101
+ Shape
102
+ -----
103
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
104
+ - Output: :math:`(*)`, same shape as the input.
105
+
106
+ Examples
107
+ --------
108
+ .. code-block:: python
109
+
110
+ >>> import brainstate.nn as nn
111
+ >>> import brainstate
112
+ >>> m = nn.ReLU()
113
+ >>> x = brainstate.random.randn(2)
114
+ >>> output = m(x)
115
+
116
+ An implementation of CReLU - https://arxiv.org/abs/1603.05201
117
+
118
+ .. code-block:: python
119
+
120
+ >>> import brainstate.nn as nn
121
+ >>> import brainstate
122
+ >>> import jax.numpy as jnp
123
+ >>> m = nn.ReLU()
124
+ >>> x = jnp.expand_dims(brainstate.random.randn(2), 0)
125
+ >>> output = jnp.concat((m(x), m(-x)))
126
+ """
127
+ __module__ = 'brainstate.nn'
128
+
129
+ def __call__(self, x: ArrayLike) -> ArrayLike:
130
+ return F.relu(x)
131
+
132
+ def __repr__(self):
133
+ return f'{self.__class__.__name__}()'
134
+
135
+
136
+ class RReLU(ElementWiseBlock):
137
+ r"""Applies the randomized leaky rectified liner unit function, element-wise.
138
+
139
+ As described in the paper `Empirical Evaluation of Rectified Activations in
140
+ Convolutional Network`_.
141
+
142
+ The function is defined as:
143
+
144
+ .. math::
145
+ \text{RReLU}(x) =
146
+ \begin{cases}
147
+ x & \text{if } x \geq 0 \\
148
+ ax & \text{ otherwise }
149
+ \end{cases}
150
+
151
+ where :math:`a` is randomly sampled from uniform distribution
152
+ :math:`\mathcal{U}(\text{lower}, \text{upper})`.
153
+
154
+ Parameters
155
+ ----------
156
+ lower : float, optional
157
+ Lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
158
+ upper : float, optional
159
+ Upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
160
+
161
+ Shape
162
+ -----
163
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
164
+ - Output: :math:`(*)`, same shape as the input.
165
+
166
+ References
167
+ ----------
168
+ .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
169
+ https://arxiv.org/abs/1505.00853
170
+
171
+ Examples
172
+ --------
173
+ .. code-block:: python
174
+
175
+ >>> import brainstate.nn as nn
176
+ >>> import brainstate
177
+ >>> m = nn.RReLU(0.1, 0.3)
178
+ >>> x = brainstate.random.randn(2)
179
+ >>> output = m(x)
180
+ """
181
+ __module__ = 'brainstate.nn'
182
+ lower: float
183
+ upper: float
184
+
185
+ def __init__(
186
+ self,
187
+ lower: float = 1. / 8,
188
+ upper: float = 1. / 3,
189
+ ):
190
+ super().__init__()
191
+ self.lower = lower
192
+ self.upper = upper
193
+
194
+ def __call__(self, x: ArrayLike) -> ArrayLike:
195
+ return F.rrelu(x, self.lower, self.upper)
196
+
197
+ def extra_repr(self):
198
+ return f'{self.__class__.__name__}(lower={self.lower}, upper={self.upper})'
199
+
200
+
201
+ class Hardtanh(ElementWiseBlock):
202
+ r"""Applies the HardTanh function element-wise.
203
+
204
+ HardTanh is defined as:
205
+
206
+ .. math::
207
+ \text{HardTanh}(x) = \begin{cases}
208
+ \text{max\_val} & \text{ if } x > \text{ max\_val } \\
209
+ \text{min\_val} & \text{ if } x < \text{ min\_val } \\
210
+ x & \text{ otherwise } \\
211
+ \end{cases}
212
+
213
+ Parameters
214
+ ----------
215
+ min_val : float, optional
216
+ Minimum value of the linear region range. Default: -1
217
+ max_val : float, optional
218
+ Maximum value of the linear region range. Default: 1
219
+
220
+ Notes
221
+ -----
222
+ Keyword arguments :attr:`min_value` and :attr:`max_value`
223
+ have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.
224
+
225
+ Shape
226
+ -----
227
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
228
+ - Output: :math:`(*)`, same shape as the input.
229
+
230
+ Examples
231
+ --------
232
+ .. code-block:: python
233
+
234
+ >>> import brainstate.nn as nn
235
+ >>> import brainstate
236
+ >>> m = nn.Hardtanh(-2, 2)
237
+ >>> x = brainstate.random.randn(2)
238
+ >>> output = m(x)
239
+ """
240
+ __module__ = 'brainstate.nn'
241
+ min_val: float
242
+ max_val: float
243
+
244
+ def __init__(
245
+ self,
246
+ min_val: float = -1.,
247
+ max_val: float = 1.,
248
+ ) -> None:
249
+ super().__init__()
250
+ self.min_val = min_val
251
+ self.max_val = max_val
252
+ assert self.max_val > self.min_val
253
+
254
+ def __call__(self, x: ArrayLike) -> ArrayLike:
255
+ return F.hard_tanh(x, self.min_val, self.max_val)
256
+
257
+ def extra_repr(self) -> str:
258
+ return f'{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})'
259
+
260
+
261
+ class ReLU6(Hardtanh):
262
+ r"""Applies the element-wise function.
263
+
264
+ ReLU6 is defined as:
265
+
266
+ .. math::
267
+ \text{ReLU6}(x) = \min(\max(0,x), 6)
268
+
269
+ Shape
270
+ -----
271
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
272
+ - Output: :math:`(*)`, same shape as the input.
273
+
274
+ Examples
275
+ --------
276
+ .. code-block:: python
277
+
278
+ >>> import brainstate.nn as nn
279
+ >>> import brainstate
280
+ >>> m = nn.ReLU6()
281
+ >>> x = brainstate.random.randn(2)
282
+ >>> output = m(x)
283
+ """
284
+ __module__ = 'brainstate.nn'
285
+
286
+ def __init__(self):
287
+ super().__init__(0., 6.)
288
+
289
+
290
+ class Sigmoid(ElementWiseBlock):
291
+ r"""Applies the element-wise function.
292
+
293
+ Sigmoid is defined as:
294
+
295
+ .. math::
296
+ \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
297
+
298
+ Shape
299
+ -----
300
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
301
+ - Output: :math:`(*)`, same shape as the input.
302
+
303
+ Examples
304
+ --------
305
+ .. code-block:: python
306
+
307
+ >>> import brainstate.nn as nn
308
+ >>> import brainstate
309
+ >>> m = nn.Sigmoid()
310
+ >>> x = brainstate.random.randn(2)
311
+ >>> output = m(x)
312
+ """
313
+ __module__ = 'brainstate.nn'
314
+
315
+ def __call__(self, x: ArrayLike) -> ArrayLike:
316
+ return F.sigmoid(x)
317
+
318
+
319
+ class Hardsigmoid(ElementWiseBlock):
320
+ r"""Applies the Hardsigmoid function element-wise.
321
+
322
+ Hardsigmoid is defined as:
323
+
324
+ .. math::
325
+ \text{Hardsigmoid}(x) = \begin{cases}
326
+ 0 & \text{if~} x \le -3, \\
327
+ 1 & \text{if~} x \ge +3, \\
328
+ x / 6 + 1 / 2 & \text{otherwise}
329
+ \end{cases}
330
+
331
+ Shape
332
+ -----
333
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
334
+ - Output: :math:`(*)`, same shape as the input.
335
+
336
+ Examples
337
+ --------
338
+ .. code-block:: python
339
+
340
+ >>> import brainstate.nn as nn
341
+ >>> import brainstate
342
+ >>> m = nn.Hardsigmoid()
343
+ >>> x = brainstate.random.randn(2)
344
+ >>> output = m(x)
345
+ """
346
+ __module__ = 'brainstate.nn'
347
+
348
+ def __call__(self, x: ArrayLike) -> ArrayLike:
349
+ return F.hard_sigmoid(x)
350
+
351
+
352
+ class Tanh(ElementWiseBlock):
353
+ r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
354
+
355
+ Tanh is defined as:
356
+
357
+ .. math::
358
+ \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
359
+
360
+ Shape
361
+ -----
362
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
363
+ - Output: :math:`(*)`, same shape as the input.
364
+
365
+ Examples
366
+ --------
367
+ .. code-block:: python
368
+
369
+ >>> import brainstate.nn as nn
370
+ >>> import brainstate
371
+ >>> m = nn.Tanh()
372
+ >>> x = brainstate.random.randn(2)
373
+ >>> output = m(x)
374
+ """
375
+ __module__ = 'brainstate.nn'
376
+
377
+ def __call__(self, x: ArrayLike) -> ArrayLike:
378
+ return F.tanh(x)
379
+
380
+
381
+ class SiLU(ElementWiseBlock):
382
+ r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
383
+
384
+ The SiLU function is also known as the swish function.
385
+
386
+ .. math::
387
+ \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
388
+
389
+ Notes
390
+ -----
391
+ See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
392
+ where the SiLU (Sigmoid Linear Unit) was originally coined, and see
393
+ `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
394
+ in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
395
+ a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
396
+ where the SiLU was experimented with later.
397
+
398
+ Shape
399
+ -----
400
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
401
+ - Output: :math:`(*)`, same shape as the input.
402
+
403
+ Examples
404
+ --------
405
+ .. code-block:: python
406
+
407
+ >>> import brainstate.nn as nn
408
+ >>> import brainstate
409
+ >>> m = nn.SiLU()
410
+ >>> x = brainstate.random.randn(2)
411
+ >>> output = m(x)
412
+ """
413
+ __module__ = 'brainstate.nn'
414
+
415
+ def __call__(self, x: ArrayLike) -> ArrayLike:
416
+ return F.silu(x)
417
+
418
+
419
+ class Mish(ElementWiseBlock):
420
+ r"""Applies the Mish function, element-wise.
421
+
422
+ Mish: A Self Regularized Non-Monotonic Neural Activation Function.
423
+
424
+ .. math::
425
+ \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
426
+
427
+ Notes
428
+ -----
429
+ See `Mish: A Self Regularized Non-Monotonic Neural Activation Function
430
+ <https://arxiv.org/abs/1908.08681>`_
431
+
432
+ Shape
433
+ -----
434
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
435
+ - Output: :math:`(*)`, same shape as the input.
436
+
437
+ Examples
438
+ --------
439
+ .. code-block:: python
440
+
441
+ >>> import brainstate.nn as nn
442
+ >>> import brainstate
443
+ >>> m = nn.Mish()
444
+ >>> x = brainstate.random.randn(2)
445
+ >>> output = m(x)
446
+ """
447
+ __module__ = 'brainstate.nn'
448
+
449
+ def __call__(self, x: ArrayLike) -> ArrayLike:
450
+ return F.mish(x)
451
+
452
+
453
+ class Hardswish(ElementWiseBlock):
454
+ r"""Applies the Hardswish function, element-wise.
455
+
456
+ As described in the paper `Searching for MobileNetV3
457
+ <https://arxiv.org/abs/1905.02244>`_.
458
+
459
+ Hardswish is defined as:
460
+
461
+ .. math::
462
+ \text{Hardswish}(x) = \begin{cases}
463
+ 0 & \text{if~} x \le -3, \\
464
+ x & \text{if~} x \ge +3, \\
465
+ x \cdot (x + 3) /6 & \text{otherwise}
466
+ \end{cases}
467
+
468
+ Shape
469
+ -----
470
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
471
+ - Output: :math:`(*)`, same shape as the input.
472
+
473
+ Examples
474
+ --------
475
+ .. code-block:: python
476
+
477
+ >>> import brainstate.nn as nn
478
+ >>> import brainstate
479
+ >>> m = nn.Hardswish()
480
+ >>> x = brainstate.random.randn(2)
481
+ >>> output = m(x)
482
+ """
483
+ __module__ = 'brainstate.nn'
484
+
485
+ def __call__(self, x: ArrayLike) -> ArrayLike:
486
+ return F.hard_swish(x)
487
+
488
+
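The piecewise definition above can be verified directly on a few sample points; this is an
illustrative sketch assuming ``jax.numpy`` is available alongside ``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = jnp.array([-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 4.0])
    >>> # 0 for x <= -3, x for x >= 3, x * (x + 3) / 6 otherwise
    >>> manual = jnp.where(x <= -3, 0.0, jnp.where(x >= 3, x, x * (x + 3) / 6))
    >>> assert jnp.allclose(brainstate.nn.Hardswish()(x), manual, atol=1e-6)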
489
+ class ELU(ElementWiseBlock):
490
+ r"""Applies the Exponential Linear Unit (ELU) function, element-wise.
491
+
492
+ As described in the paper: `Fast and Accurate Deep Network Learning by
493
+ Exponential Linear Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
494
+
495
+ ELU is defined as:
496
+
497
+ .. math::
498
+ \text{ELU}(x) = \begin{cases}
499
+ x, & \text{ if } x > 0\\
500
+ \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
501
+ \end{cases}
502
+
503
+ Parameters
504
+ ----------
505
+ alpha : float, optional
506
+ The :math:`\alpha` value for the ELU formulation. Default: 1.0
507
+
508
+ Shape
509
+ -----
510
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
511
+ - Output: :math:`(*)`, same shape as the input.
512
+
513
+ Examples
514
+ --------
515
+ .. code-block:: python
516
+
517
+ >>> import brainstate.nn as nn
518
+ >>> import brainstate
519
+ >>> m = nn.ELU()
520
+ >>> x = brainstate.random.randn(2)
521
+ >>> output = m(x)
522
+ """
523
+ __module__ = 'brainstate.nn'
524
+ alpha: float
525
+
526
+ def __init__(self, alpha: float = 1.) -> None:
527
+ super().__init__()
528
+ self.alpha = alpha
529
+
530
+ def __call__(self, x: ArrayLike) -> ArrayLike:
531
+ return F.elu(x, self.alpha)
532
+
533
+ def extra_repr(self) -> str:
534
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
535
+
536
+
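Only the negative branch of ELU is affected by ``alpha``; the sketch below (illustrative
only, assuming ``jax.numpy`` is available alongside ``brainstate``) checks the piecewise
formula for a few values of :math:`\alpha`:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = jnp.array([-2.0, -0.5, 0.0, 1.5])
    >>> for alpha in (0.5, 1.0, 2.0):
    ...     m = brainstate.nn.ELU(alpha=alpha)
    ...     manual = jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1.0))
    ...     assert jnp.allclose(m(x), manual, atol=1e-6)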
537
+ class CELU(ElementWiseBlock):
538
+ r"""Applies the element-wise function.
539
+
540
+ .. math::
541
+ \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
542
+
543
+ More details can be found in the paper `Continuously Differentiable Exponential
544
+ Linear Units`_ .
545
+
546
+ Parameters
547
+ ----------
548
+ alpha : float, optional
549
+ The :math:`\alpha` value for the CELU formulation. Default: 1.0
550
+
551
+ Shape
552
+ -----
553
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
554
+ - Output: :math:`(*)`, same shape as the input.
555
+
556
+ References
557
+ ----------
558
+ .. _`Continuously Differentiable Exponential Linear Units`:
559
+ https://arxiv.org/abs/1704.07483
560
+
561
+ Examples
562
+ --------
563
+ .. code-block:: python
564
+
565
+ >>> import brainstate.nn as nn
566
+ >>> import brainstate
567
+ >>> m = nn.CELU()
568
+ >>> x = brainstate.random.randn(2)
569
+ >>> output = m(x)
570
+ """
571
+ __module__ = 'brainstate.nn'
572
+ alpha: float
573
+
574
+ def __init__(self, alpha: float = 1.) -> None:
575
+ super().__init__()
576
+ self.alpha = alpha
577
+
578
+ def __call__(self, x: ArrayLike) -> ArrayLike:
579
+ return F.celu(x, self.alpha)
580
+
581
+ def extra_repr(self) -> str:
582
+ return f'{self.__class__.__name__}(alpha={self.alpha})'
583
+
584
+
585
+ class SELU(ElementWiseBlock):
586
+ r"""Applied element-wise.
587
+
588
+ .. math::
589
+ \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
590
+
591
+ with :math:`\alpha = 1.6732632423543772848170429916717` and
592
+ :math:`\text{scale} = 1.0507009873554804934193349852946`.
593
+
594
+ More details can be found in the paper `Self-Normalizing Neural Networks`_ .
595
+
596
+ Shape
597
+ -----
598
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
599
+ - Output: :math:`(*)`, same shape as the input.
600
+
601
+ References
602
+ ----------
603
+ .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
604
+
605
+ Examples
606
+ --------
607
+ .. code-block:: python
608
+
609
+ >>> import brainstate.nn as nn
610
+ >>> import brainstate
611
+ >>> m = nn.SELU()
612
+ >>> x = brainstate.random.randn(2)
613
+ >>> output = m(x)
614
+ """
615
+ __module__ = 'brainstate.nn'
616
+
617
+ def __call__(self, x: ArrayLike) -> ArrayLike:
618
+ return F.selu(x)
619
+
620
+
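The fixed constants quoted above can be plugged into the formula and compared against the
module output; an illustrative sketch assuming ``jax.numpy`` is available alongside
``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> alpha = 1.6732632423543772848170429916717
    >>> scale = 1.0507009873554804934193349852946
    >>> x = brainstate.random.randn(16)
    >>> manual = scale * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1.0))
    >>> assert jnp.allclose(brainstate.nn.SELU()(x), manual, atol=1e-5)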
621
+ class GLU(ElementWiseBlock):
622
+ r"""Applies the gated linear unit function.
623
+
624
+ .. math::
625
+         \text{GLU}(a, b) = a \otimes \sigma(b)
+
+     where :math:`a` is the first half of the input matrix and :math:`b` is
+     the second half, split along the dimension ``dim``.
629
+
630
+ Parameters
631
+ ----------
632
+ dim : int, optional
633
+ The dimension on which to split the input. Default: -1
634
+
635
+ Shape
636
+ -----
637
+     - Input: :math:`(\ast_1, N, \ast_2)`, where :math:`*` means any number of
+         additional dimensions
639
+ - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
640
+
641
+ Examples
642
+ --------
643
+ .. code-block:: python
644
+
645
+ >>> import brainstate.nn as nn
646
+ >>> import brainstate
647
+ >>> m = nn.GLU()
648
+ >>> x = brainstate.random.randn(4, 2)
649
+ >>> output = m(x)
650
+ """
651
+ __module__ = 'brainstate.nn'
652
+ dim: int
653
+
654
+ def __init__(self, dim: int = -1) -> None:
655
+ super().__init__()
656
+ self.dim = dim
657
+
658
+ def __call__(self, x: ArrayLike) -> ArrayLike:
659
+ return F.glu(x, self.dim)
660
+
661
+ def __repr__(self):
662
+ return f'{self.__class__.__name__}(dim={self.dim})'
663
+
664
+
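GLU halves the size of the chosen dimension: the second half, passed through a sigmoid,
gates the first half. The sketch below is illustrative only and assumes ``jax.numpy`` is
available alongside ``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = brainstate.random.randn(4, 6)
    >>> out = brainstate.nn.GLU(dim=-1)(x)
    >>> assert out.shape == (4, 3)                  # the split dimension is halved
    >>> a, b = jnp.split(x, 2, axis=-1)             # a: first half, b: second half
    >>> sigmoid_b = 1.0 / (1.0 + jnp.exp(-b))
    >>> assert jnp.allclose(out, a * sigmoid_b, atol=1e-6)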
665
+ class GELU(ElementWiseBlock):
666
+ r"""Applies the Gaussian Error Linear Units function.
667
+
668
+ .. math:: \text{GELU}(x) = x * \Phi(x)
669
+
670
+     where :math:`\Phi(x)` is the cumulative distribution function of the standard
+     Gaussian distribution.
+
+     When ``approximate`` is ``True``, GELU is estimated with:
+
+     .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
676
+
677
+ Parameters
678
+ ----------
679
+ approximate : bool, optional
680
+ Whether to use the tanh approximation algorithm. Default: False
681
+
682
+ Shape
683
+ -----
684
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
685
+ - Output: :math:`(*)`, same shape as the input.
686
+
687
+ Examples
688
+ --------
689
+ .. code-block:: python
690
+
691
+ >>> import brainstate.nn as nn
692
+ >>> import brainstate
693
+ >>> m = nn.GELU()
694
+ >>> x = brainstate.random.randn(2)
695
+ >>> output = m(x)
696
+ """
697
+ __module__ = 'brainstate.nn'
698
+ approximate: bool
699
+
700
+ def __init__(self, approximate: bool = False) -> None:
701
+ super().__init__()
702
+ self.approximate = approximate
703
+
704
+ def __call__(self, x: ArrayLike) -> ArrayLike:
705
+ return F.gelu(x, approximate=self.approximate)
706
+
707
+ def __repr__(self):
708
+ return f'{self.__class__.__name__}(approximate={self.approximate})'
709
+
710
+
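The tanh approximation is close to, but not identical with, the exact form; the comparison
below is an illustrative sketch assuming ``jax.numpy`` is available alongside ``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = brainstate.random.randn(1000)
    >>> exact = brainstate.nn.GELU(approximate=False)(x)
    >>> approx = brainstate.nn.GELU(approximate=True)(x)
    >>> # the two variants agree to within roughly 1e-2 element-wise
    >>> assert jnp.max(jnp.abs(exact - approx)) < 1e-2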
711
+ class Hardshrink(ElementWiseBlock):
712
+ r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
713
+
714
+ Hardshrink is defined as:
715
+
716
+ .. math::
717
+ \text{HardShrink}(x) =
718
+ \begin{cases}
719
+ x, & \text{ if } x > \lambda \\
720
+ x, & \text{ if } x < -\lambda \\
721
+ 0, & \text{ otherwise }
722
+ \end{cases}
723
+
724
+ Parameters
725
+ ----------
726
+ lambd : float, optional
727
+ The :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
728
+
729
+ Shape
730
+ -----
731
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
732
+ - Output: :math:`(*)`, same shape as the input.
733
+
734
+ Examples
735
+ --------
736
+ .. code-block:: python
737
+
738
+ >>> import brainstate.nn as nn
739
+ >>> import brainstate
740
+ >>> m = nn.Hardshrink()
741
+ >>> x = brainstate.random.randn(2)
742
+ >>> output = m(x)
743
+ """
744
+ __module__ = 'brainstate.nn'
745
+ lambd: float
746
+
747
+ def __init__(self, lambd: float = 0.5) -> None:
748
+ super().__init__()
749
+ self.lambd = lambd
750
+
751
+ def __call__(self, x: ArrayLike) -> ArrayLike:
752
+ return F.hard_shrink(x, self.lambd)
753
+
754
+ def __repr__(self):
755
+ return f'{self.__class__.__name__}(lambd={self.lambd})'
756
+
757
+
758
+ class LeakyReLU(ElementWiseBlock):
759
+ r"""Applies the element-wise function.
760
+
761
+ .. math::
762
+ \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
763
+
764
+ or
765
+
766
+ .. math::
767
+ \text{LeakyReLU}(x) =
768
+ \begin{cases}
769
+ x, & \text{ if } x \geq 0 \\
770
+ \text{negative\_slope} \times x, & \text{ otherwise }
771
+ \end{cases}
772
+
773
+ Parameters
774
+ ----------
775
+ negative_slope : float, optional
776
+ Controls the angle of the negative slope (which is used for
777
+ negative input values). Default: 1e-2
778
+
779
+ Shape
780
+ -----
781
+     - Input: :math:`(*)`, where :math:`*` means any number of additional
+         dimensions
783
+ - Output: :math:`(*)`, same shape as the input
784
+
785
+ Examples
786
+ --------
787
+ .. code-block:: python
788
+
789
+ >>> import brainstate.nn as nn
790
+ >>> import brainstate
791
+ >>> m = nn.LeakyReLU(0.1)
792
+ >>> x = brainstate.random.randn(2)
793
+ >>> output = m(x)
794
+ """
795
+ __module__ = 'brainstate.nn'
796
+ negative_slope: float
797
+
798
+ def __init__(self, negative_slope: float = 1e-2) -> None:
799
+ super().__init__()
800
+ self.negative_slope = negative_slope
801
+
802
+ def __call__(self, x: ArrayLike) -> ArrayLike:
803
+ return F.leaky_relu(x, self.negative_slope)
804
+
805
+ def __repr__(self):
806
+ return f'{self.__class__.__name__}(negative_slope={self.negative_slope})'
807
+
808
+
809
+ class LogSigmoid(ElementWiseBlock):
810
+ r"""Applies the element-wise function.
811
+
812
+ .. math::
813
+ \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
814
+
815
+ Shape
816
+ -----
817
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
818
+ - Output: :math:`(*)`, same shape as the input.
819
+
820
+ Examples
821
+ --------
822
+ .. code-block:: python
823
+
824
+ >>> import brainstate.nn as nn
825
+ >>> import brainstate
826
+ >>> m = nn.LogSigmoid()
827
+ >>> x = brainstate.random.randn(2)
828
+ >>> output = m(x)
829
+ """
830
+ __module__ = 'brainstate.nn'
831
+
832
+ def __call__(self, x: ArrayLike) -> ArrayLike:
833
+ return F.log_sigmoid(x)
834
+
835
+
836
+ class Softplus(ElementWiseBlock):
837
+ r"""Applies the Softplus function element-wise.
838
+
839
+ .. math::
840
+ \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))
841
+
842
+ SoftPlus is a smooth approximation to the ReLU function and can be used
843
+ to constrain the output of a machine to always be positive.
844
+
845
+ For numerical stability the implementation reverts to the linear function
846
+ when :math:`input \times \beta > threshold`.
847
+
848
+ Shape
849
+ -----
850
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
851
+ - Output: :math:`(*)`, same shape as the input.
852
+
853
+ Examples
854
+ --------
855
+ .. code-block:: python
856
+
857
+ >>> import brainstate.nn as nn
858
+ >>> import brainstate
859
+ >>> m = nn.Softplus()
860
+ >>> x = brainstate.random.randn(2)
861
+ >>> output = m(x)
862
+ """
863
+ __module__ = 'brainstate.nn'
864
+
865
+ def __call__(self, x: ArrayLike) -> ArrayLike:
866
+ return F.softplus(x)
867
+
868
+
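Softplus is a smooth upper bound of ReLU that approaches it for large :math:`|x|`; the
sketch below is illustrative only and assumes ``jax.numpy`` is available alongside
``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = jnp.linspace(-5.0, 5.0, 11)
    >>> sp = brainstate.nn.Softplus()(x)
    >>> relu = jnp.maximum(x, 0.0)
    >>> assert jnp.all(sp >= relu)                         # softplus never falls below ReLU
    >>> assert jnp.allclose(sp[-1], relu[-1], atol=1e-2)   # nearly identical at x = 5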
869
+ class Softshrink(ElementWiseBlock):
870
+ r"""Applies the soft shrinkage function elementwise.
871
+
872
+ .. math::
873
+ \text{SoftShrinkage}(x) =
874
+ \begin{cases}
875
+ x - \lambda, & \text{ if } x > \lambda \\
876
+ x + \lambda, & \text{ if } x < -\lambda \\
877
+ 0, & \text{ otherwise }
878
+ \end{cases}
879
+
880
+ Parameters
881
+ ----------
882
+ lambd : float, optional
883
+         The :math:`\lambda` value (must be non-negative) for the
+         Softshrink formulation. Default: 0.5
885
+
886
+ Shape
887
+ -----
888
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
889
+ - Output: :math:`(*)`, same shape as the input.
890
+
891
+ Examples
892
+ --------
893
+ .. code-block:: python
894
+
895
+ >>> import brainstate.nn as nn
896
+ >>> import brainstate
897
+ >>> m = nn.Softshrink()
898
+ >>> x = brainstate.random.randn(2)
899
+ >>> output = m(x)
900
+ """
901
+ __module__ = 'brainstate.nn'
902
+ lambd: float
903
+
904
+ def __init__(self, lambd: float = 0.5) -> None:
905
+ super().__init__()
906
+ self.lambd = lambd
907
+
908
+ def __call__(self, x: ArrayLike) -> ArrayLike:
909
+ return F.soft_shrink(x, self.lambd)
910
+
911
+ def __repr__(self):
912
+ return f'{self.__class__.__name__}(lambd={self.lambd})'
913
+
914
+
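Hardshrink and Softshrink zero out the same band :math:`[-\lambda, \lambda]`, but Softshrink
additionally shifts the surviving values toward zero by :math:`\lambda`. An illustrative
sketch, assuming ``jax.numpy`` is available alongside ``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = jnp.array([-1.0, -0.3, 0.0, 0.3, 1.0])
    >>> hard = brainstate.nn.Hardshrink(lambd=0.5)(x)
    >>> soft = brainstate.nn.Softshrink(lambd=0.5)(x)
    >>> assert jnp.allclose(hard, jnp.array([-1.0, 0.0, 0.0, 0.0, 1.0]))
    >>> assert jnp.allclose(soft, jnp.array([-0.5, 0.0, 0.0, 0.0, 0.5]))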
915
+ class PReLU(ElementWiseBlock):
916
+ r"""Applies the element-wise function.
917
+
918
+ .. math::
919
+ \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
920
+
921
+ or
922
+
923
+ .. math::
924
+ \text{PReLU}(x) =
925
+ \begin{cases}
926
+ x, & \text{ if } x \geq 0 \\
927
+ ax, & \text{ otherwise }
928
+ \end{cases}
929
+
930
+ Here :math:`a` is a learnable parameter. When called without arguments,
931
+ `nn.PReLU()` uses a single parameter :math:`a` across all input channels.
932
+ If called with `nn.PReLU(nChannels)`, a separate :math:`a` is used for
933
+ each input channel.
934
+
935
+ Parameters
936
+ ----------
937
+ num_parameters : int, optional
938
+         Number of :math:`a` to learn. Although it takes an int as input,
+         only two values are legitimate: 1, or the number of channels
+         of the input. Default: 1
941
+ init : float, optional
942
+ The initial value of :math:`a`. Default: 0.25
943
+ dtype : optional
944
+ The data type for the weight parameter.
945
+
946
+ Shape
947
+ -----
948
+     - Input: :math:`(*)`, where :math:`*` means any number of additional dimensions.
949
+ - Output: :math:`(*)`, same shape as the input.
950
+
951
+ Attributes
952
+ ----------
953
+     weight : ParamState
+         The learnable weights of shape (:attr:`num_parameters`,).
955
+
956
+ Notes
957
+ -----
958
+     - For good performance, weight decay should not be used when learning :math:`a`.
+     - The channel dimension is the second dimension of the input. When the input has
+         fewer than two dimensions, there is no channel dimension and the number of
+         channels is 1.
961
+
962
+ Examples
963
+ --------
964
+ .. code-block:: python
965
+
966
+ >>> import brainstate
967
+ >>> m = brainstate.nn.PReLU()
968
+ >>> x = brainstate.random.randn(2)
969
+ >>> output = m(x)
970
+ """
971
+ __module__ = 'brainstate.nn'
972
+ num_parameters: int
973
+
974
+ def __init__(self, num_parameters: int = 1, init: float = 0.25, dtype=None) -> None:
975
+ super().__init__()
976
+ self.num_parameters = num_parameters
977
+ self.weight = ParamState(jnp.ones(num_parameters, dtype=dtype) * init)
978
+
979
+ def __call__(self, x: ArrayLike) -> ArrayLike:
980
+ return F.prelu(x, self.weight.value)
981
+
982
+ def __repr__(self):
983
+ return f'{self.__class__.__name__}(num_parameters={self.num_parameters})'
984
+
985
+
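With ``num_parameters > 1`` a separate slope is learned per channel; the sketch below is
illustrative only:

.. code-block:: python

    >>> import brainstate
    >>> m = brainstate.nn.PReLU(num_parameters=3, init=0.1)
    >>> x = brainstate.random.randn(4, 3)       # for a 2-D input, the channel axis is the second one
    >>> y = m(x)
    >>> assert y.shape == x.shape
    >>> assert m.weight.value.shape == (3,)     # one learnable slope per channel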
986
+ class Softsign(ElementWiseBlock):
987
+ r"""Applies the element-wise function.
988
+
989
+ .. math::
990
+ \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
991
+
992
+ Shape
993
+ -----
994
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
995
+ - Output: :math:`(*)`, same shape as the input.
996
+
997
+ Examples
998
+ --------
999
+ .. code-block:: python
1000
+
1001
+ >>> import brainstate.nn as nn
1002
+ >>> import brainstate
1003
+ >>> m = nn.Softsign()
1004
+ >>> x = brainstate.random.randn(2)
1005
+ >>> output = m(x)
1006
+ """
1007
+ __module__ = 'brainstate.nn'
1008
+
1009
+ def __call__(self, x: ArrayLike) -> ArrayLike:
1010
+ return F.soft_sign(x)
1011
+
1012
+
1013
+ class Tanhshrink(ElementWiseBlock):
1014
+ r"""Applies the element-wise function.
1015
+
1016
+ .. math::
1017
+ \text{Tanhshrink}(x) = x - \tanh(x)
1018
+
1019
+ Shape
1020
+ -----
1021
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
1022
+ - Output: :math:`(*)`, same shape as the input.
1023
+
1024
+ Examples
1025
+ --------
1026
+ .. code-block:: python
1027
+
1028
+ >>> import brainstate.nn as nn
1029
+ >>> import brainstate
1030
+ >>> m = nn.Tanhshrink()
1031
+ >>> x = brainstate.random.randn(2)
1032
+ >>> output = m(x)
1033
+ """
1034
+ __module__ = 'brainstate.nn'
1035
+
1036
+ def __call__(self, x: ArrayLike) -> ArrayLike:
1037
+ return F.tanh_shrink(x)
1038
+
1039
+
1040
+ class Softmin(ElementWiseBlock):
1041
+ r"""Applies the Softmin function to an n-dimensional input Tensor.
1042
+
1043
+ Rescales the input so that the elements of the n-dimensional output Tensor
1044
+ lie in the range `[0, 1]` and sum to 1.
1045
+
1046
+ Softmin is defined as:
1047
+
1048
+ .. math::
1049
+ \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
1050
+
1051
+ Parameters
1052
+ ----------
1053
+ dim : int, optional
1054
+ A dimension along which Softmin will be computed (so every slice
1055
+ along dim will sum to 1).
1056
+
1057
+ Shape
1058
+ -----
1059
+     - Input: :math:`(*)`, where :math:`*` means any number of additional dimensions
1060
+ - Output: :math:`(*)`, same shape as the input
1061
+
1062
+ Returns
1063
+ -------
1064
+ Tensor
1065
+ A Tensor of the same dimension and shape as the input, with
1066
+ values in the range [0, 1]
1067
+
1068
+ Examples
1069
+ --------
1070
+ .. code-block:: python
1071
+
1072
+ >>> import brainstate.nn as nn
1073
+ >>> import brainstate
1074
+ >>> m = nn.Softmin(dim=1)
1075
+ >>> x = brainstate.random.randn(2, 3)
1076
+ >>> output = m(x)
1077
+ """
1078
+ __module__ = 'brainstate.nn'
1079
+ dim: Optional[int]
1080
+
1081
+ def __init__(self, dim: Optional[int] = None) -> None:
1082
+ super().__init__()
1083
+ self.dim = dim
1084
+
1085
+ def __call__(self, x: ArrayLike) -> ArrayLike:
1086
+ return F.softmin(x, self.dim)
1087
+
1088
+ def __repr__(self):
1089
+ return f'{self.__class__.__name__}(dim={self.dim})'
1090
+
1091
+
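Softmin is equivalent to applying Softmax to the negated input; an illustrative sketch
assuming ``jax.numpy`` is available alongside ``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = brainstate.random.randn(2, 5)
    >>> softmin = brainstate.nn.Softmin(dim=-1)(x)
    >>> softmax_of_neg = brainstate.nn.Softmax(dim=-1)(-x)
    >>> assert jnp.allclose(softmin, softmax_of_neg, atol=1e-6)
    >>> assert jnp.allclose(softmin.sum(axis=-1), 1.0, atol=1e-6)   # each slice sums to 1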
1092
+ class Softmax(ElementWiseBlock):
1093
+ r"""Applies the Softmax function to an n-dimensional input Tensor.
1094
+
1095
+ Rescales the input so that the elements of the n-dimensional output Tensor
1096
+ lie in the range [0,1] and sum to 1.
1097
+
1098
+ Softmax is defined as:
1099
+
1100
+ .. math::
1101
+ \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
1102
+
1103
+     When the input Tensor is a sparse tensor, the unspecified
+     values are treated as ``-inf``.
1105
+
1106
+ Parameters
1107
+ ----------
1108
+ dim : int, optional
1109
+ A dimension along which Softmax will be computed (so every slice
1110
+ along dim will sum to 1).
1111
+
1112
+ Shape
1113
+ -----
1114
+     - Input: :math:`(*)`, where :math:`*` means any number of additional dimensions
1115
+ - Output: :math:`(*)`, same shape as the input
1116
+
1117
+ Returns
1118
+ -------
1119
+ Tensor
1120
+ A Tensor of the same dimension and shape as the input with
1121
+ values in the range [0, 1]
1122
+
1123
+ Notes
1124
+ -----
1125
+     This module does not work directly with a negative log-likelihood loss, which
+     expects log-probabilities as input. Use `LogSoftmax` instead (it is
+     faster and has better numerical properties).
1128
+
1129
+ Examples
1130
+ --------
1131
+ .. code-block:: python
1132
+
1133
+ >>> import brainstate.nn as nn
1134
+ >>> import brainstate
1135
+ >>> m = nn.Softmax(dim=1)
1136
+ >>> x = brainstate.random.randn(2, 3)
1137
+ >>> output = m(x)
1138
+ """
1139
+ __module__ = 'brainstate.nn'
1140
+ dim: Optional[int]
1141
+
1142
+ def __init__(self, dim: Optional[int] = None) -> None:
1143
+ super().__init__()
1144
+ self.dim = dim
1145
+
1146
+ def __call__(self, x: ArrayLike) -> ArrayLike:
1147
+ return F.softmax(x, self.dim)
1148
+
1149
+ def __repr__(self):
1150
+ return f'{self.__class__.__name__}(dim={self.dim})'
1151
+
1152
+
1153
+ class Softmax2d(ElementWiseBlock):
1154
+ r"""Applies SoftMax over features to each spatial location.
1155
+
1156
+ When given an image of ``Channels x Height x Width``, it will
1157
+ apply `Softmax` to each location :math:`(Channels, h_i, w_j)`
1158
+
1159
+ Shape
1160
+ -----
1161
+ - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
1162
+ - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
1163
+
1164
+ Returns
1165
+ -------
1166
+ Tensor
1167
+ A Tensor of the same dimension and shape as the input with
1168
+ values in the range [0, 1]
1169
+
1170
+ Examples
1171
+ --------
1172
+ .. code-block:: python
1173
+
1174
+ >>> import brainstate.nn as nn
1175
+ >>> import brainstate
1176
+ >>> m = nn.Softmax2d()
1177
+         >>> # softmax is applied over the channel dimension (axis -3)
1178
+ >>> x = brainstate.random.randn(2, 3, 12, 13)
1179
+ >>> output = m(x)
1180
+ """
1181
+ __module__ = 'brainstate.nn'
1182
+
1183
+ def __call__(self, x: ArrayLike) -> ArrayLike:
1184
+ assert x.ndim == 4 or x.ndim == 3, 'Softmax2d requires a 3D or 4D tensor as input'
1185
+ return F.softmax(x, -3)
1186
+
1187
+
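For an :math:`(N, C, H, W)` input, ``Softmax2d`` normalises over the channel axis, so the
values at every spatial location sum to 1; an illustrative sketch assuming ``jax.numpy``
is available alongside ``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = brainstate.random.randn(2, 3, 4, 5)      # (N, C, H, W)
    >>> y = brainstate.nn.Softmax2d()(x)
    >>> assert y.shape == x.shape
    >>> assert jnp.allclose(y.sum(axis=1), 1.0, atol=1e-5)   # sums over channels are 1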
1188
+ class LogSoftmax(ElementWiseBlock):
1189
+ r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
1190
+
1191
+ The LogSoftmax formulation can be simplified as:
1192
+
1193
+ .. math::
1194
+ \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
1195
+
1196
+ Parameters
1197
+ ----------
1198
+ dim : int, optional
1199
+ A dimension along which LogSoftmax will be computed.
1200
+
1201
+ Shape
1202
+ -----
1203
+     - Input: :math:`(*)`, where :math:`*` means any number of additional dimensions
1204
+ - Output: :math:`(*)`, same shape as the input
1205
+
1206
+ Returns
1207
+ -------
1208
+ Tensor
1209
+ A Tensor of the same dimension and shape as the input with
1210
+         values in the range (-inf, 0]
1211
+
1212
+ Examples
1213
+ --------
1214
+ .. code-block:: python
1215
+
1216
+ >>> import brainstate.nn as nn
1217
+ >>> import brainstate
1218
+ >>> m = nn.LogSoftmax(dim=1)
1219
+ >>> x = brainstate.random.randn(2, 3)
1220
+ >>> output = m(x)
1221
+ """
1222
+ __module__ = 'brainstate.nn'
1223
+ dim: Optional[int]
1224
+
1225
+ def __init__(self, dim: Optional[int] = None) -> None:
1226
+ super().__init__()
1227
+ self.dim = dim
1228
+
1229
+ def __call__(self, x: ArrayLike) -> ArrayLike:
1230
+ return F.log_softmax(x, self.dim)
1231
+
1232
+ def __repr__(self):
1233
+ return f'{self.__class__.__name__}(dim={self.dim})'
1234
+
1235
+
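``LogSoftmax`` returns log-probabilities, so exponentiating its output recovers the
``Softmax`` result; an illustrative sketch assuming ``jax.numpy`` is available alongside
``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = brainstate.random.randn(2, 5)
    >>> log_p = brainstate.nn.LogSoftmax(dim=-1)(x)
    >>> p = brainstate.nn.Softmax(dim=-1)(x)
    >>> assert jnp.allclose(jnp.exp(log_p), p, atol=1e-6)
    >>> assert jnp.all(log_p <= 0)                   # log-probabilities are non-positive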
1236
+ class Identity(ElementWiseBlock):
1237
+ r"""A placeholder identity operator that is argument-insensitive.
1238
+
1239
+ Examples
1240
+ --------
1241
+ .. code-block:: python
1242
+
1243
+         >>> import brainstate.nn as nn
+         >>> import brainstate
+         >>> m = nn.Identity()
1245
+ >>> x = brainstate.random.randn(2, 3)
1246
+ >>> output = m(x)
1247
+ >>> assert (output == x).all()
1248
+ """
1249
+ __module__ = 'brainstate.nn'
1250
+
1251
+ def __call__(self, x):
1252
+ return x
1253
+
1254
+
1255
+ class SpikeBitwise(ElementWiseBlock):
1256
+ r"""Bitwise addition for the spiking inputs.
1257
+
1258
+ .. math::
1259
+
1260
+ \begin{array}{ccc}
1261
+ \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
1262
+ \hline \text { ADD } & x+y & x+y \\
1263
+ \text { AND } & x \cap y & x \cdot y \\
1264
+ \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
1265
+ \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
1266
+ \hline
1267
+ \end{array}
1268
+
1269
+ Parameters
1270
+ ----------
1271
+ op : str, optional
1272
+ The bitwise operation. Default: 'add'
1273
+ name : str, optional
1274
+ The name of the dynamic system.
1275
+
1276
+ Examples
1277
+ --------
1278
+ .. code-block:: python
1279
+
1280
+         >>> import brainstate.nn as nn
+         >>> import brainstate
+         >>> m = nn.SpikeBitwise(op='and')
1282
+ >>> x = brainstate.random.randn(2, 3) > 0
1283
+ >>> y = brainstate.random.randn(2, 3) > 0
1284
+ >>> output = m(x, y)
1285
+ """
1286
+ __module__ = 'brainstate.nn'
1287
+
1288
+ def __init__(
1289
+ self,
1290
+ op: str = 'add',
1291
+ name: Optional[str] = None
1292
+ ) -> None:
1293
+ super().__init__(name=name)
1294
+ self.op = op
1295
+
1296
+ def __call__(self, x, y):
1297
+ import braintools
1298
+ return braintools.spike_bitwise(x, y, self.op)
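The table above can be checked on 0/1 spike vectors. Note that ``SpikeBitwise`` delegates
to the optional ``braintools`` package at call time, so the illustrative sketch below
assumes ``braintools`` (and ``jax.numpy``) is installed alongside ``brainstate``:

.. code-block:: python

    >>> import jax.numpy as jnp
    >>> import brainstate
    >>> x = jnp.array([0., 0., 1., 1.])
    >>> y = jnp.array([0., 1., 0., 1.])
    >>> add_out = brainstate.nn.SpikeBitwise(op='add')(x, y)   # x + y
    >>> and_out = brainstate.nn.SpikeBitwise(op='and')(x, y)   # x * y
    >>> assert jnp.allclose(add_out, jnp.array([0., 1., 1., 2.]))
    >>> assert jnp.allclose(and_out, jnp.array([0., 0., 0., 1.]))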