brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,1122 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import brainunit as u
+ import jax.numpy as jnp
+ import jax.typing
+
+ from brainstate import random, functional as F
+ from brainstate._state import ParamState
+ from brainstate.nn._module import ElementWiseBlock
+ from brainstate.typing import ArrayLike
+
+ __all__ = [
+     # activation functions
+     'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid',
+     'Tanh', 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU',
+     'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU',
+     'Softsign', 'Tanhshrink', 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax',
+
+     # others
+     'Identity', 'SpikeBitwise',
+ ]
+
+
+ class Threshold(ElementWiseBlock):
+     r"""Thresholds each element of the input Tensor.
+
+     Threshold is defined as:
+
+     .. math::
+         y =
+         \begin{cases}
+             x, &\text{ if } x > \text{threshold} \\
+             \text{value}, &\text{ otherwise }
+         \end{cases}
+
+     Args:
+         threshold: The value to threshold at
+         value: The value to replace with
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Threshold(0.1, 20)
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     threshold: float
+     value: float
+
+     def __init__(self, threshold: float, value: float) -> None:
+         super().__init__()
+         self.threshold = threshold
+         self.value = value
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         dtype = u.math.get_dtype(x)
+         return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
+                          x,
+                          jnp.asarray(self.value, dtype=dtype))
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(threshold={self.threshold}, value={self.value})'
+
+
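A quick numeric check of the `jnp.where` semantics above (illustrative; the output repr assumes the default float32 dtype):

    >>> import brainstate.nn as nn
    >>> import jax.numpy as jnp
    >>> m = nn.Threshold(0.1, 20.0)
    >>> m(jnp.array([0.05, 0.5]))  # 0.05 <= 0.1, so it is replaced by 20.0
    Array([20. ,  0.5], dtype=float32)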
+ class ReLU(ElementWiseBlock):
+     r"""Applies the rectified linear unit function element-wise:
+
+     :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.ReLU()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+
+     An implementation of CReLU (https://arxiv.org/abs/1603.05201)::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> import jax.numpy as jnp
+         >>> m = nn.ReLU()
+         >>> x = jnp.expand_dims(bst.random.randn(2), 0)
+         >>> output = jnp.concatenate((m(x), m(-x)))
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.relu(x)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}()'
+
+
+ class RReLU(ElementWiseBlock):
+     r"""Applies the randomized leaky rectified linear unit function, element-wise,
+     as described in the paper
+     `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
+
+     The function is defined as:
+
+     .. math::
+         \text{RReLU}(x) =
+         \begin{cases}
+             x & \text{if } x \geq 0 \\
+             ax & \text{ otherwise }
+         \end{cases}
+
+     where :math:`a` is randomly sampled from the uniform distribution
+     :math:`\mathcal{U}(\text{lower}, \text{upper})`.
+
+     See: https://arxiv.org/pdf/1505.00853.pdf
+
+     Args:
+         lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
+         upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.RReLU(0.1, 0.3)
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+
+     .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
+         https://arxiv.org/abs/1505.00853
+     """
+     __module__ = 'brainstate.nn'
+     lower: float
+     upper: float
+
+     def __init__(
+         self,
+         lower: float = 1. / 8,
+         upper: float = 1. / 3,
+     ):
+         super().__init__()
+         self.lower = lower
+         self.upper = upper
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.rrelu(x, self.lower, self.upper)
+
+     def extra_repr(self):
+         return f'{self.__class__.__name__}(lower={self.lower}, upper={self.upper})'
+
+
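Because :math:`a` is sampled anew on each call, outputs are stochastic; a sketch of reproducing them by reseeding (this assumes `F.rrelu` draws from the global `brainstate.random` state that `bst.random.seed` controls):

    >>> import brainstate as bst
    >>> m = bst.nn.RReLU(0.1, 0.3)
    >>> bst.random.seed(42)
    >>> y1 = m(bst.random.randn(4))
    >>> bst.random.seed(42)
    >>> y2 = m(bst.random.randn(4))  # same seed, same sampled slopes as y1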
+ class Hardtanh(ElementWiseBlock):
+     r"""Applies the HardTanh function element-wise.
+
+     HardTanh is defined as:
+
+     .. math::
+         \text{HardTanh}(x) = \begin{cases}
+             \text{max\_val} & \text{ if } x > \text{ max\_val } \\
+             \text{min\_val} & \text{ if } x < \text{ min\_val } \\
+             x & \text{ otherwise } \\
+         \end{cases}
+
+     Args:
+         min_val: minimum value of the linear region range. Default: -1
+         max_val: maximum value of the linear region range. Default: 1
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Hardtanh(-2, 2)
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     min_val: float
+     max_val: float
+
+     def __init__(
+         self,
+         min_val: float = -1.,
+         max_val: float = 1.,
+     ) -> None:
+         super().__init__()
+         self.min_val = min_val
+         self.max_val = max_val
+         assert self.max_val > self.min_val
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.hard_tanh(x, self.min_val, self.max_val)
+
+     def extra_repr(self) -> str:
+         return f'{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})'
+
+
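Hardtanh is a clip to ``[min_val, max_val]``; a numeric sketch:

    >>> import brainstate.nn as nn
    >>> import jax.numpy as jnp
    >>> m = nn.Hardtanh(-2., 2.)
    >>> m(jnp.array([-3., 0., 3.]))  # values outside [-2, 2] are clipped
    Array([-2.,  0.,  2.], dtype=float32)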
+ class ReLU6(Hardtanh, ElementWiseBlock):
+     r"""Applies the element-wise function:
+
+     .. math::
+         \text{ReLU6}(x) = \min(\max(0,x), 6)
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.ReLU6()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(self):
+         super().__init__(0., 6.)
+
+
+ class Sigmoid(ElementWiseBlock):
+     r"""Applies the element-wise function:
+
+     .. math::
+         \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Sigmoid()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.sigmoid(x)
+
+
+ class Hardsigmoid(ElementWiseBlock):
+     r"""Applies the Hardsigmoid function element-wise.
+
+     Hardsigmoid is defined as:
+
+     .. math::
+         \text{Hardsigmoid}(x) = \begin{cases}
+             0 & \text{if~} x \le -3, \\
+             1 & \text{if~} x \ge +3, \\
+             x / 6 + 1 / 2 & \text{otherwise}
+         \end{cases}
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Hardsigmoid()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.hard_sigmoid(x)
+
+
+ class Tanh(ElementWiseBlock):
+     r"""Applies the Hyperbolic Tangent (Tanh) function element-wise.
+
+     Tanh is defined as:
+
+     .. math::
+         \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Tanh()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.tanh(x)
+
+
+ class SiLU(ElementWiseBlock):
+     r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
+     The SiLU function is also known as the swish function.
+
+     .. math::
+         \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
+
+     .. note::
+         See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
+         where the SiLU (Sigmoid Linear Unit) was originally coined, and see
+         `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
+         in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
+         a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
+         where the SiLU was experimented with later.
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.SiLU()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.silu(x)
+
+
+ class Mish(ElementWiseBlock):
+     r"""Applies the Mish function, element-wise.
+     Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+     .. math::
+         \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+     .. note::
+         See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Mish()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.mish(x)
+
+
+ class Hardswish(ElementWiseBlock):
+     r"""Applies the Hardswish function, element-wise, as described in the paper
+     `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
+
+     Hardswish is defined as:
+
+     .. math::
+         \text{Hardswish}(x) = \begin{cases}
+             0 & \text{if~} x \le -3, \\
+             x & \text{if~} x \ge +3, \\
+             x \cdot (x + 3) / 6 & \text{otherwise}
+         \end{cases}
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Hardswish()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.hard_swish(x)
+
+
+ class ELU(ElementWiseBlock):
+     r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
+     in the paper `Fast and Accurate Deep Network Learning by Exponential Linear
+     Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
+
+     ELU is defined as:
+
+     .. math::
+         \text{ELU}(x) = \begin{cases}
+             x, & \text{ if } x > 0\\
+             \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
+         \end{cases}
+
+     Args:
+         alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.ELU()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     alpha: float
+
+     def __init__(self, alpha: float = 1.) -> None:
+         super().__init__()
+         self.alpha = alpha
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.elu(x, self.alpha)
+
+     def extra_repr(self) -> str:
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+ class CELU(ElementWiseBlock):
+     r"""Applies the element-wise function:
+
+     .. math::
+         \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))
+
+     More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
+
+     Args:
+         alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.CELU()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+
+     .. _`Continuously Differentiable Exponential Linear Units`:
+         https://arxiv.org/abs/1704.07483
+     """
+     __module__ = 'brainstate.nn'
+     alpha: float
+
+     def __init__(self, alpha: float = 1.) -> None:
+         super().__init__()
+         self.alpha = alpha
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.celu(x, self.alpha)
+
+     def extra_repr(self) -> str:
+         return f'{self.__class__.__name__}(alpha={self.alpha})'
+
+
+ class SELU(ElementWiseBlock):
+     r"""Applies the scaled exponential linear unit (SELU) function element-wise:
+
+     .. math::
+         \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
+
+     with :math:`\alpha = 1.6732632423543772848170429916717` and
+     :math:`\text{scale} = 1.0507009873554804934193349852946`.
+
+     More details can be found in the paper `Self-Normalizing Neural Networks`_ .
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.SELU()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+
+     .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.selu(x)
+
+
+ class GLU(ElementWiseBlock):
+     r"""Applies the gated linear unit function
+     :math:`\text{GLU}(a, b) = a \otimes \sigma(b)`, where :math:`a` is the first half
+     of the input along ``dim`` and :math:`b` is the second half.
+
+     Args:
+         dim (int): the dimension on which to split the input. Default: -1
+
+     Shape:
+         - Input: :math:`(\ast_1, N, \ast_2)` where :math:`\ast` means any number of
+           additional dimensions
+         - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M = N / 2`
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.GLU()
+         >>> x = bst.random.randn(4, 2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     dim: int
+
+     def __init__(self, dim: int = -1) -> None:
+         super().__init__()
+         self.dim = dim
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.glu(x, self.dim)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(dim={self.dim})'
+
+
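GLU halves the size of ``dim``: the input is split into halves :math:`a` and :math:`b` along that axis and recombined as :math:`a \otimes \sigma(b)`. A shape-focused sketch:

    >>> import brainstate.nn as nn
    >>> import brainstate as bst
    >>> m = nn.GLU(dim=-1)
    >>> x = bst.random.randn(4, 6)  # the split axis must have even size
    >>> m(x).shape                  # 6 -> 3 after the split
    (4, 3)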
+ class GELU(ElementWiseBlock):
+     r"""Applies the Gaussian Error Linear Units function:
+
+     .. math:: \text{GELU}(x) = x * \Phi(x)
+
+     where :math:`\Phi(x)` is the Cumulative Distribution Function for the Gaussian Distribution.
+
+     When ``approximate=True``, GELU is estimated with:
+
+     .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2 / \pi} * (x + 0.044715 * x^3)))
+
+     Args:
+         approximate (bool, optional): whether to use the tanh approximation.
+             Default: ``False``
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.GELU()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     approximate: bool
+
+     def __init__(self, approximate: bool = False) -> None:
+         super().__init__()
+         self.approximate = approximate
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.gelu(x, approximate=self.approximate)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(approximate={self.approximate})'
+
+
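The exact and tanh-approximate forms agree closely; a sketch comparing them (assuming `F.gelu` follows the `jax.nn.gelu` convention for its ``approximate`` flag):

    >>> import brainstate.nn as nn
    >>> import jax.numpy as jnp
    >>> x = jnp.linspace(-3., 3., 7)
    >>> exact = nn.GELU(approximate=False)(x)
    >>> approx = nn.GELU(approximate=True)(x)
    >>> bool(jnp.max(jnp.abs(exact - approx)) < 1e-2)
    True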
+ class Hardshrink(ElementWiseBlock):
+     r"""Applies the Hard Shrinkage (Hardshrink) function element-wise.
+
+     Hardshrink is defined as:
+
+     .. math::
+         \text{HardShrink}(x) =
+         \begin{cases}
+             x, & \text{ if } x > \lambda \\
+             x, & \text{ if } x < -\lambda \\
+             0, & \text{ otherwise }
+         \end{cases}
+
+     Args:
+         lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Hardshrink()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     lambd: float
+
+     def __init__(self, lambd: float = 0.5) -> None:
+         super().__init__()
+         self.lambd = lambd
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.hard_shrink(x, self.lambd)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(lambd={self.lambd})'
+
+
+ class LeakyReLU(ElementWiseBlock):
+     r"""Applies the element-wise function:
+
+     .. math::
+         \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)
+
+     or
+
+     .. math::
+         \text{LeakyReLU}(x) =
+         \begin{cases}
+             x, & \text{ if } x \geq 0 \\
+             \text{negative\_slope} \times x, & \text{ otherwise }
+         \end{cases}
+
+     Args:
+         negative_slope: Controls the angle of the negative slope (which is used for
+             negative input values). Default: 1e-2
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.LeakyReLU(0.1)
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     negative_slope: float
+
+     def __init__(self, negative_slope: float = 1e-2) -> None:
+         super().__init__()
+         self.negative_slope = negative_slope
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.leaky_relu(x, self.negative_slope)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(negative_slope={self.negative_slope})'
+
+
+ class LogSigmoid(ElementWiseBlock):
+     r"""Applies the element-wise function:
+
+     .. math::
+         \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.LogSigmoid()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.log_sigmoid(x)
+
+
+ class Softplus(ElementWiseBlock):
+     r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
+     \log(1 + \exp(\beta * x))` element-wise.
+
+     Softplus is a smooth approximation to the ReLU function and can be used
+     to constrain the output of a machine to always be positive.
+
+     For numerical stability the implementation reverts to the linear function
+     when :math:`input \times \beta > threshold`.
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Softplus()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.softplus(x)
+
+
+ class Softshrink(ElementWiseBlock):
+     r"""Applies the soft shrinkage function element-wise:
+
+     .. math::
+         \text{SoftShrinkage}(x) =
+         \begin{cases}
+             x - \lambda, & \text{ if } x > \lambda \\
+             x + \lambda, & \text{ if } x < -\lambda \\
+             0, & \text{ otherwise }
+         \end{cases}
+
+     Args:
+         lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Softshrink()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     lambd: float
+
+     def __init__(self, lambd: float = 0.5) -> None:
+         super().__init__()
+         self.lambd = lambd
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.soft_shrink(x, self.lambd)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(lambd={self.lambd})'
+
+
+ class PReLU(ElementWiseBlock):
+     r"""Applies the element-wise function:
+
+     .. math::
+         \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
+
+     or
+
+     .. math::
+         \text{PReLU}(x) =
+         \begin{cases}
+             x, & \text{ if } x \geq 0 \\
+             ax, & \text{ otherwise }
+         \end{cases}
+
+     Here :math:`a` is a learnable parameter. When called without arguments, ``nn.PReLU()`` uses a single
+     parameter :math:`a` across all input channels. If called with ``nn.PReLU(nChannels)``,
+     a separate :math:`a` is used for each input channel.
+
+     .. note::
+         For good performance, weight decay should not be used when learning :math:`a`.
+
+     .. note::
+         The channel dim is the 2nd dim of the input. When the input has fewer than 2 dims,
+         there is no channel dim and the number of channels is 1.
+
+     Args:
+         num_parameters (int): number of :math:`a` to learn.
+             Although it takes an ``int`` as input, only two values are legitimate:
+             1, or the number of channels of the input. Default: 1
+         init (float): the initial value of :math:`a`. Default: 0.25
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Attributes:
+         weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
+
+     Examples::
+
+         >>> import brainstate as bst
+         >>> m = bst.nn.PReLU()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     num_parameters: int
+
+     def __init__(self, num_parameters: int = 1, init: float = 0.25, dtype=None) -> None:
+         super().__init__()
+         self.num_parameters = num_parameters
+         self.weight = ParamState(jnp.ones(num_parameters, dtype=dtype) * init)
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.prelu(x, self.weight.value)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(num_parameters={self.num_parameters})'
+
+
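`weight` is a `ParamState`, so the slopes are picked up by optimizers like any other trainable parameter; a per-channel sketch (assuming `F.prelu` broadcasts the weight against the trailing channel axis):

    >>> import brainstate as bst
    >>> m = bst.nn.PReLU(num_parameters=3)
    >>> x = bst.random.randn(2, 3)  # 3 channels, one learned slope each
    >>> y = m(x)
    >>> m.weight.value.shape
    (3,)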
+ class Softsign(ElementWiseBlock):
+     r"""Applies the element-wise function:
+
+     .. math::
+         \text{SoftSign}(x) = \frac{x}{ 1 + |x|}
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Softsign()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.soft_sign(x)
+
+
+ class Tanhshrink(ElementWiseBlock):
+     r"""Applies the element-wise function:
+
+     .. math::
+         \text{Tanhshrink}(x) = x - \tanh(x)
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Tanhshrink()
+         >>> x = bst.random.randn(2)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.tanh_shrink(x)
+
+
+ class Softmin(ElementWiseBlock):
+     r"""Applies the Softmin function to an n-dimensional input Tensor,
+     rescaling it so that the elements of the n-dimensional output Tensor
+     lie in the range `[0, 1]` and sum to 1.
+
+     Softmin is defined as:
+
+     .. math::
+         \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Args:
+         dim (int): A dimension along which Softmin will be computed (so every slice
+             along dim will sum to 1).
+
+     Returns:
+         a Tensor of the same dimension and shape as the input, with
+         values in the range [0, 1]
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Softmin(dim=1)
+         >>> x = bst.random.randn(2, 3)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     dim: Optional[int]
+
+     def __init__(self, dim: Optional[int] = None) -> None:
+         super().__init__()
+         self.dim = dim
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.softmin(x, self.dim)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(dim={self.dim})'
+
+
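Softmin is Softmax applied to the negated input, which gives a quick consistency check:

    >>> import brainstate.nn as nn
    >>> import brainstate as bst
    >>> import jax.numpy as jnp
    >>> x = bst.random.randn(2, 3)
    >>> lhs = nn.Softmin(dim=1)(x)
    >>> rhs = nn.Softmax(dim=1)(-x)
    >>> bool(jnp.allclose(lhs, rhs, atol=1e-6))
    True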
+ class Softmax(ElementWiseBlock):
+     r"""Applies the Softmax function to an n-dimensional input Tensor,
+     rescaling it so that the elements of the n-dimensional output Tensor
+     lie in the range [0,1] and sum to 1.
+
+     Softmax is defined as:
+
+     .. math::
+         \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Returns:
+         a Tensor of the same dimension and shape as the input with
+         values in the range [0, 1]
+
+     Args:
+         dim (int): A dimension along which Softmax will be computed (so every slice
+             along dim will sum to 1).
+
+     .. note::
+         When log-probabilities are needed, use `LogSoftmax` instead of composing
+         Softmax with a log (it is faster and has better numerical properties).
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Softmax(dim=1)
+         >>> x = bst.random.randn(2, 3)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     dim: Optional[int]
+
+     def __init__(self, dim: Optional[int] = None) -> None:
+         super().__init__()
+         self.dim = dim
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.softmax(x, self.dim)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(dim={self.dim})'
+
+
+ class Softmax2d(ElementWiseBlock):
+     r"""Applies SoftMax over features to each spatial location.
+
+     When given an image of ``Channels x Height x Width``, it will
+     apply `Softmax` to each location :math:`(Channels, h_i, w_j)`.
+
+     Shape:
+         - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
+         - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
+
+     Returns:
+         a Tensor of the same dimension and shape as the input with
+         values in the range [0, 1]
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.Softmax2d()
+         >>> # softmax is applied over the channel dimension
+         >>> x = bst.random.randn(2, 3, 12, 13)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         assert x.ndim == 4 or x.ndim == 3, 'Softmax2d requires a 3D or 4D tensor as input'
+         return F.softmax(x, -3)
+
+
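Since the softmax runs over axis ``-3`` (the channel axis), every spatial location sums to one across channels; a sketch:

    >>> import brainstate.nn as nn
    >>> import brainstate as bst
    >>> import jax.numpy as jnp
    >>> y = nn.Softmax2d()(bst.random.randn(2, 3, 12, 13))
    >>> bool(jnp.allclose(y.sum(axis=1), 1.0, atol=1e-6))  # axis 1 is the channel axis of a 4D input
    True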
+ class LogSoftmax(ElementWiseBlock):
+     r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
+     input Tensor. The LogSoftmax formulation can be simplified as:
+
+     .. math::
+         \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
+
+     Shape:
+         - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+         - Output: :math:`(*)`, same shape as the input.
+
+     Args:
+         dim (int): A dimension along which LogSoftmax will be computed.
+
+     Returns:
+         a Tensor of the same dimension and shape as the input with
+         values in the range [-inf, 0)
+
+     Examples::
+
+         >>> import brainstate.nn as nn
+         >>> import brainstate as bst
+         >>> m = nn.LogSoftmax(dim=1)
+         >>> x = bst.random.randn(2, 3)
+         >>> output = m(x)
+     """
+     __module__ = 'brainstate.nn'
+     dim: Optional[int]
+
+     def __init__(self, dim: Optional[int] = None) -> None:
+         super().__init__()
+         self.dim = dim
+
+     def __call__(self, x: ArrayLike) -> ArrayLike:
+         return F.log_softmax(x, self.dim)
+
+     def __repr__(self):
+         return f'{self.__class__.__name__}(dim={self.dim})'
+
+
+ class Identity(ElementWiseBlock):
+     r"""A placeholder identity operator that is argument-insensitive.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __call__(self, x):
+         return x
+
+
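`Identity` is handy as a drop-in placeholder wherever a module is expected; a one-line sketch:

    >>> import brainstate as bst
    >>> layer = bst.nn.Identity()
    >>> x = bst.random.randn(2)
    >>> assert (layer(x) == x).all()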
+ class SpikeBitwise(ElementWiseBlock):
+     r"""Bitwise operations for spiking inputs.
+
+     .. math::
+
+        \begin{array}{ccc}
+        \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
+        \hline \text { ADD } & x+y & x+y \\
+        \text { AND } & x \cap y & x \cdot y \\
+        \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
+        \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
+        \hline
+        \end{array}
+
+     Args:
+         op: str. The bitwise operation.
+         name: str. The name of the dynamic system.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(self,
+                  op: str = 'add',
+                  name: Optional[str] = None) -> None:
+         super().__init__(name=name)
+         self.op = op
+
+     def __call__(self, x, y):
+         return F.spike_bitwise(x, y, self.op)
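A sketch of the table above on binary spike vectors (illustrative; the op strings ``'add'``, ``'and'``, ``'iand'``, and ``'or'`` are assumed to match what `F.spike_bitwise` accepts):

    >>> import brainstate as bst
    >>> import jax.numpy as jnp
    >>> x = jnp.array([0., 0., 1., 1.])
    >>> y = jnp.array([0., 1., 0., 1.])
    >>> bst.nn.SpikeBitwise(op='iand')(x, y)  # (1 - x) * y
    Array([0., 1., 0., 0.], dtype=float32)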