brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
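Taken together, the renames above fold the 0.1.9 `brainstate.compile` and `brainstate.augment` namespaces into a single `brainstate.transform` package, and move the `brainstate.init` initializers under `brainstate.nn.init`. A minimal import-migration sketch, assuming the public names are re-exported unchanged under the new paths (the re-exports themselves are not shown in this diff):

.. code-block:: python

    # Hypothetical migration based on the file moves listed above;
    # assumes the 0.2.0 modules re-export the same public names.
    # 0.1.9 layout:
    #   from brainstate.compile import jit    # brainstate/compile/_jit.py
    #   from brainstate.augment import grad   # brainstate/augment/_autograd.py
    #   import brainstate.init                # brainstate/init/_random_inits.py
    # 0.2.0 layout (assumed re-exports):
    from brainstate.transform import jit      # brainstate/transform/_jit.py
    from brainstate.transform import grad     # brainstate/transform/_autograd.py
    import brainstate.nn.init                 # brainstate/nn/init.py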
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -49,20 +49,26 @@ class Threshold(ElementWiseBlock):
  \text{value}, &\text{ otherwise }
  \end{cases}

- Args:
- threshold: The value to threshold at
- value: The value to replace with
+ Parameters
+ ----------
+ threshold : float
+ The value to threshold at.
+ value : float
+ The value to replace with.

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
  >>> import brainstate
  >>> m = nn.Threshold(0.1, 20)
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -85,30 +91,38 @@ class Threshold(ElementWiseBlock):


  class ReLU(ElementWiseBlock):
- r"""Applies the rectified linear unit function element-wise:
+ r"""Applies the rectified linear unit function element-wise.
+
+ The ReLU function is defined as:

- :math:`\text{ReLU}(x) = (x)^+ = \max(0, x)`
+ .. math::
+ \text{ReLU}(x) = (x)^+ = \max(0, x)

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.ReLU()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)

+ An implementation of CReLU - https://arxiv.org/abs/1603.05201

- An implementation of CReLU - https://arxiv.org/abs/1603.05201
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
+ >>> import jax.numpy as jnp
  >>> m = nn.ReLU()
- >>> x = random.randn(2).unsqueeze(0)
- >>> output = jax.numpy.concat((m(x), m(-x)))
+ >>> x = brainstate.random.randn(2).unsqueeze(0)
+ >>> output = jnp.concat((m(x), m(-x)))
  """
  __module__ = 'brainstate.nn'

@@ -120,10 +134,10 @@ class ReLU(ElementWiseBlock):


  class RReLU(ElementWiseBlock):
- r"""Applies the randomized leaky rectified liner unit function, element-wise,
- as described in the paper:
+ r"""Applies the randomized leaky rectified liner unit function, element-wise.

- `Empirical Evaluation of Rectified Activations in Convolutional Network`_.
+ As described in the paper `Empirical Evaluation of Rectified Activations in
+ Convolutional Network`_.

  The function is defined as:

@@ -137,26 +151,32 @@ class RReLU(ElementWiseBlock):
  where :math:`a` is randomly sampled from uniform distribution
  :math:`\mathcal{U}(\text{lower}, \text{upper})`.

- See: https://arxiv.org/pdf/1505.00853.pdf
+ Parameters
+ ----------
+ lower : float, optional
+ Lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
+ upper : float, optional
+ Upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`

- Args:
- lower: lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
- upper: upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ References
+ ----------
+ .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
+ https://arxiv.org/abs/1505.00853

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.RReLU(0.1, 0.3)
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
-
- .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
- https://arxiv.org/abs/1505.00853
  """
  __module__ = 'brainstate.nn'
  lower: float
@@ -190,23 +210,31 @@ class Hardtanh(ElementWiseBlock):
  x & \text{ otherwise } \\
  \end{cases}

- Args:
- min_val: minimum value of the linear region range. Default: -1
- max_val: maximum value of the linear region range. Default: 1
+ Parameters
+ ----------
+ min_val : float, optional
+ Minimum value of the linear region range. Default: -1
+ max_val : float, optional
+ Maximum value of the linear region range. Default: 1

+ Notes
+ -----
  Keyword arguments :attr:`min_value` and :attr:`max_value`
  have been deprecated in favor of :attr:`min_val` and :attr:`max_val`.

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Hardtanh(-2, 2)
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -230,22 +258,27 @@ class Hardtanh(ElementWiseBlock):
  return f'{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})'


- class ReLU6(Hardtanh, ElementWiseBlock):
- r"""Applies the element-wise function:
+ class ReLU6(Hardtanh):
+ r"""Applies the element-wise function.
+
+ ReLU6 is defined as:

  .. math::
  \text{ReLU6}(x) = \min(\max(0,x), 6)

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.ReLU6()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
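Note the base-class change in the hunk above: `ReLU6` now derives from `Hardtanh` alone. Because `Hardtanh` is itself declared as an `ElementWiseBlock` subclass, the dropped base was redundant and the MRO still contains it. A quick check, assuming both names are exported from `brainstate.nn` in 0.2.0:

.. code-block:: python

    # Sketch only; assumes ElementWiseBlock and Hardtanh are public names.
    import brainstate.nn as nn

    m = nn.ReLU6()
    assert isinstance(m, nn.Hardtanh)          # direct base
    assert isinstance(m, nn.ElementWiseBlock)  # still inherited via Hardtanh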
@@ -255,22 +288,26 @@ class ReLU6(Hardtanh, ElementWiseBlock):


  class Sigmoid(ElementWiseBlock):
- r"""Applies the element-wise function:
+ r"""Applies the element-wise function.
+
+ Sigmoid is defined as:

  .. math::
  \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
-
- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Sigmoid()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -291,16 +328,19 @@ class Hardsigmoid(ElementWiseBlock):
  x / 6 + 1 / 2 & \text{otherwise}
  \end{cases}

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Hardsigmoid()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -317,16 +357,19 @@ class Tanh(ElementWiseBlock):
  .. math::
  \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)} {\exp(x) + \exp(-x)}

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Tanh()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -337,28 +380,34 @@ class Tanh(ElementWiseBlock):

  class SiLU(ElementWiseBlock):
  r"""Applies the Sigmoid Linear Unit (SiLU) function, element-wise.
+
  The SiLU function is also known as the swish function.

  .. math::
  \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

- .. note::
- See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
- where the SiLU (Sigmoid Linear Unit) was originally coined, and see
- `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
- in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
- a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
- where the SiLU was experimented with later.
+ Notes
+ -----
+ See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
+ where the SiLU (Sigmoid Linear Unit) was originally coined, and see
+ `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
+ in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
+ a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
+ where the SiLU was experimented with later.

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
+ >>> import brainstate
  >>> m = nn.SiLU()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -369,25 +418,30 @@ class SiLU(ElementWiseBlock):

  class Mish(ElementWiseBlock):
  r"""Applies the Mish function, element-wise.
+
  Mish: A Self Regularized Non-Monotonic Neural Activation Function.

  .. math::
  \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

- .. note::
- See `Mish: A Self Regularized Non-Monotonic Neural Activation Function <https://arxiv.org/abs/1908.08681>`_
-
+ Notes
+ -----
+ See `Mish: A Self Regularized Non-Monotonic Neural Activation Function
+ <https://arxiv.org/abs/1908.08681>`_

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Mish()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -397,8 +451,10 @@ class Mish(ElementWiseBlock):


  class Hardswish(ElementWiseBlock):
- r"""Applies the Hardswish function, element-wise, as described in the paper:
- `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.
+ r"""Applies the Hardswish function, element-wise.
+
+ As described in the paper `Searching for MobileNetV3
+ <https://arxiv.org/abs/1905.02244>`_.

  Hardswish is defined as:

@@ -409,17 +465,19 @@ class Hardswish(ElementWiseBlock):
  x \cdot (x + 3) /6 & \text{otherwise}
  \end{cases}

+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
-
- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Hardswish()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -429,9 +487,10 @@ class Hardswish(ElementWiseBlock):


  class ELU(ElementWiseBlock):
- r"""Applies the Exponential Linear Unit (ELU) function, element-wise, as described
- in the paper: `Fast and Accurate Deep Network Learning by Exponential Linear
- Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.
+ r"""Applies the Exponential Linear Unit (ELU) function, element-wise.
+
+ As described in the paper: `Fast and Accurate Deep Network Learning by
+ Exponential Linear Units (ELUs) <https://arxiv.org/abs/1511.07289>`__.

  ELU is defined as:

@@ -441,19 +500,24 @@ class ELU(ElementWiseBlock):
  \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
  \end{cases}

- Args:
- alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+ Parameters
+ ----------
+ alpha : float, optional
+ The :math:`\alpha` value for the ELU formulation. Default: 1.0

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.ELU()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -471,30 +535,38 @@ class ELU(ElementWiseBlock):


  class CELU(ElementWiseBlock):
- r"""Applies the element-wise function:
+ r"""Applies the element-wise function.

  .. math::
  \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

- More details can be found in the paper `Continuously Differentiable Exponential Linear Units`_ .
+ More details can be found in the paper `Continuously Differentiable Exponential
+ Linear Units`_ .

- Args:
- alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
+ Parameters
+ ----------
+ alpha : float, optional
+ The :math:`\alpha` value for the CELU formulation. Default: 1.0

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ References
+ ----------
+ .. _`Continuously Differentiable Exponential Linear Units`:
+ https://arxiv.org/abs/1704.07483
+
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.CELU()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
-
- .. _`Continuously Differentiable Exponential Linear Units`:
- https://arxiv.org/abs/1704.07483
  """
  __module__ = 'brainstate.nn'
  alpha: float
@@ -511,7 +583,7 @@ class CELU(ElementWiseBlock):


  class SELU(ElementWiseBlock):
- r"""Applied element-wise, as:
+ r"""Applied element-wise.

  .. math::
  \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))
@@ -521,20 +593,24 @@ class SELU(ElementWiseBlock):

  More details can be found in the paper `Self-Normalizing Neural Networks`_ .

+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ References
+ ----------
+ .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.SELU()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
-
- .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  """
  __module__ = 'brainstate.nn'

@@ -543,24 +619,33 @@ class SELU(ElementWiseBlock):


  class GLU(ElementWiseBlock):
- r"""Applies the gated linear unit function
- :math:`{GLU}(a, b)= a \otimes \sigma(b)` where :math:`a` is the first half
- of the input matrices and :math:`b` is the second half.
+ r"""Applies the gated linear unit function.

- Args:
- dim (int): the dimension on which to split the input. Default: -1
+ .. math::
+ {GLU}(a, b)= a \otimes \sigma(b)
+
+ where :math:`a` is the first half of the input matrices and :math:`b` is
+ the second half.

- Shape:
- - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
- dimensions
- - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+ Parameters
+ ----------
+ dim : int, optional
+ The dimension on which to split the input. Default: -1

- Examples::
+ Shape
+ -----
+ - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.GLU()
- >>> x = random.randn(4, 2)
+ >>> x = brainstate.random.randn(4, 2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -578,30 +663,35 @@ class GLU(ElementWiseBlock):


  class GELU(ElementWiseBlock):
- r"""Applies the Gaussian Error Linear Units function:
+ r"""Applies the Gaussian Error Linear Units function.

  .. math:: \text{GELU}(x) = x * \Phi(x)

- where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.
+ where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian
+ Distribution.

- When the approximate argument is 'tanh', Gelu is estimated with:
+ When the approximate argument is True, Gelu is estimated with:

  .. math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3)))

- Args:
- approximate (str, optional): the gelu approximation algorithm to use:
- ``'none'`` | ``'tanh'``. Default: ``'none'``
+ Parameters
+ ----------
+ approximate : bool, optional
+ Whether to use the tanh approximation algorithm. Default: False

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.GELU()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
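Beyond the docstring rewrite, the GELU hunk above records a signature change: `approximate` is now documented as a bool (default False) instead of the 0.1.9 string flag ``'none'`` | ``'tanh'``. A hedged migration sketch based only on the two docstrings:

.. code-block:: python

    import brainstate.nn as nn

    # 0.1.9 (string flag, per the old docstring):
    #   m = nn.GELU(approximate='tanh')
    # 0.2.0 (boolean flag, per the new docstring):
    m_exact = nn.GELU()                  # exact CDF formulation
    m_tanh = nn.GELU(approximate=True)   # tanh approximation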
@@ -631,19 +721,24 @@ class Hardshrink(ElementWiseBlock):
  0, & \text{ otherwise }
  \end{cases}

- Args:
- lambd: the :math:`\lambda` value for the Hardshrink formulation. Default: 0.5
+ Parameters
+ ----------
+ lambd : float, optional
+ The :math:`\lambda` value for the Hardshrink formulation. Default: 0.5

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Hardshrink()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -661,12 +756,11 @@ class Hardshrink(ElementWiseBlock):


  class LeakyReLU(ElementWiseBlock):
- r"""Applies the element-wise function:
+ r"""Applies the element-wise function.

  .. math::
  \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

-
  or

  .. math::
@@ -676,21 +770,26 @@ class LeakyReLU(ElementWiseBlock):
  \text{LeakyReLU}(x) =
  \begin{cases}
  x, & \text{ if } x \geq 0 \\
  \text{negative\_slope} \times x, & \text{ otherwise }
  \end{cases}

- Args:
- negative_slope: Controls the angle of the negative slope (which is used for
- negative input values). Default: 1e-2
+ Parameters
+ ----------
+ negative_slope : float, optional
+ Controls the angle of the negative slope (which is used for
+ negative input values). Default: 1e-2

- Shape:
- - Input: :math:`(*)` where `*` means, any number of additional
- dimensions
- - Output: :math:`(*)`, same shape as the input
+ Shape
+ -----
+ - Input: :math:`(*)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(*)`, same shape as the input

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.LeakyReLU(0.1)
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -708,21 +807,24 @@ class LeakyReLU(ElementWiseBlock):


  class LogSigmoid(ElementWiseBlock):
- r"""Applies the element-wise function:
+ r"""Applies the element-wise function.

  .. math::
  \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.LogSigmoid()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -732,8 +834,10 @@ class LogSigmoid(ElementWiseBlock):


  class Softplus(ElementWiseBlock):
- r"""Applies the Softplus function :math:`\text{Softplus}(x) = \frac{1}{\beta} *
- \log(1 + \exp(\beta * x))` element-wise.
+ r"""Applies the Softplus function element-wise.
+
+ .. math::
+ \text{Softplus}(x) = \frac{1}{\beta} * \log(1 + \exp(\beta * x))

  SoftPlus is a smooth approximation to the ReLU function and can be used
  to constrain the output of a machine to always be positive.
@@ -741,16 +845,19 @@ class Softplus(ElementWiseBlock):
  For numerical stability the implementation reverts to the linear function
  when :math:`input \times \beta > threshold`.

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Softplus()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -760,7 +867,7 @@ class Softplus(ElementWiseBlock):


  class Softshrink(ElementWiseBlock):
- r"""Applies the soft shrinkage function elementwise:
+ r"""Applies the soft shrinkage function elementwise.

  .. math::
  \text{SoftShrinkage}(x) =
@@ -770,19 +877,25 @@ class Softshrink(ElementWiseBlock):
  0, & \text{ otherwise }
  \end{cases}

- Args:
- lambd: the :math:`\lambda` (must be no less than zero) value for the Softshrink formulation. Default: 0.5
+ Parameters
+ ----------
+ lambd : float, optional
+ The :math:`\lambda` (must be no less than zero) value for the
+ Softshrink formulation. Default: 0.5

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Softshrink()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -800,7 +913,7 @@ class Softshrink(ElementWiseBlock):


  class PReLU(ElementWiseBlock):
- r"""Applies the element-wise function:
+ r"""Applies the element-wise function.

  .. math::
  \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
@@ -814,35 +927,43 @@ class PReLU(ElementWiseBlock):
  ax, & \text{ otherwise }
  \end{cases}

- Here :math:`a` is a learnable parameter. When called without arguments, `nn.PReLU()` uses a single
- parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
- a separate :math:`a` is used for each input channel.
-
-
- .. note::
- weight decay should not be used when learning :math:`a` for good performance.
+ Here :math:`a` is a learnable parameter. When called without arguments,
+ `nn.PReLU()` uses a single parameter :math:`a` across all input channels.
+ If called with `nn.PReLU(nChannels)`, a separate :math:`a` is used for
+ each input channel.
+
+ Parameters
+ ----------
+ num_parameters : int, optional
+ Number of :math:`a` to learn. Although it takes an int as input,
+ there is only two values are legitimate: 1, or the number of channels
+ at input. Default: 1
+ init : float, optional
+ The initial value of :math:`a`. Default: 0.25
+ dtype : optional
+ The data type for the weight parameter.
+
+ Shape
+ -----
+ - Input: :math:`( *)` where `*` means, any number of additional dimensions.
+ - Output: :math:`(*)`, same shape as the input.
+
+ Attributes
+ ----------
+ weight : Tensor
+ The learnable weights of shape (:attr:`num_parameters`).
+
+ Notes
+ -----
+ - Weight decay should not be used when learning :math:`a` for good performance.
+ - Channel dim is the 2nd dim of input. When input has dims < 2, then there is
+ no channel dim and the number of channels = 1.
+
+ Examples
+ --------
+ .. code-block:: python

- .. note::
- Channel dim is the 2nd dim of input. When input has dims < 2, then there is
- no channel dim and the number of channels = 1.
-
- Args:
- num_parameters (int): number of :math:`a` to learn.
- Although it takes an int as input, there is only two values are legitimate:
- 1, or the number of channels at input. Default: 1
- init (float): the initial value of :math:`a`. Default: 0.25
-
- Shape:
- - Input: :math:`( *)` where `*` means, any number of additional
- dimensions.
- - Output: :math:`(*)`, same shape as the input.
-
- Attributes:
- weight (Tensor): the learnable weights of shape (:attr:`num_parameters`).
-
- Examples::
-
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = brainstate.nn.PReLU()
  >>> x = brainstate.random.randn(2)
  >>> output = m(x)
@@ -863,21 +984,24 @@ class PReLU(ElementWiseBlock):


  class Softsign(ElementWiseBlock):
- r"""Applies the element-wise function:
+ r"""Applies the element-wise function.

  .. math::
  \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Softsign()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -887,21 +1011,24 @@ class Softsign(ElementWiseBlock):


  class Tanhshrink(ElementWiseBlock):
- r"""Applies the element-wise function:
+ r"""Applies the element-wise function.

  .. math::
  \text{Tanhshrink}(x) = x - \tanh(x)

- Shape:
- - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- - Output: :math:`(*)`, same shape as the input.
+ Shape
+ -----
+ - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
+ - Output: :math:`(*)`, same shape as the input.

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Tanhshrink()
- >>> x = random.randn(2)
+ >>> x = brainstate.random.randn(2)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -911,8 +1038,9 @@ class Tanhshrink(ElementWiseBlock):


  class Softmin(ElementWiseBlock):
- r"""Applies the Softmin function to an n-dimensional input Tensor
- rescaling them so that the elements of the n-dimensional output Tensor
+ r"""Applies the Softmin function to an n-dimensional input Tensor.
+
+ Rescales the input so that the elements of the n-dimensional output Tensor
  lie in the range `[0, 1]` and sum to 1.

  Softmin is defined as:
@@ -920,25 +1048,31 @@ class Softmin(ElementWiseBlock):
  .. math::
  \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

- Shape:
- - Input: :math:`(*)` where `*` means, any number of additional
- dimensions
- - Output: :math:`(*)`, same shape as the input
-
- Args:
- dim (int): A dimension along which Softmin will be computed (so every slice
- along dim will sum to 1).
-
- Returns:
- a Tensor of the same dimension and shape as the input, with
+ Parameters
+ ----------
+ dim : int, optional
+ A dimension along which Softmin will be computed (so every slice
+ along dim will sum to 1).
+
+ Shape
+ -----
+ - Input: :math:`(*)` where `*` means, any number of additional dimensions
+ - Output: :math:`(*)`, same shape as the input
+
+ Returns
+ -------
+ Tensor
+ A Tensor of the same dimension and shape as the input, with
  values in the range [0, 1]

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Softmin(dim=1)
- >>> x = random.randn(2, 3)
+ >>> x = brainstate.random.randn(2, 3)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -956,8 +1090,9 @@ class Softmin(ElementWiseBlock):


  class Softmax(ElementWiseBlock):
- r"""Applies the Softmax function to an n-dimensional input Tensor
- rescaling them so that the elements of the n-dimensional output Tensor
+ r"""Applies the Softmax function to an n-dimensional input Tensor.
+
+ Rescales the input so that the elements of the n-dimensional output Tensor
  lie in the range [0,1] and sum to 1.

  Softmax is defined as:
@@ -968,32 +1103,38 @@ class Softmax(ElementWiseBlock):
  When the input Tensor is a sparse tensor then the unspecified
  values are treated as ``-inf``.

- Shape:
- - Input: :math:`(*)` where `*` means, any number of additional
- dimensions
- - Output: :math:`(*)`, same shape as the input
-
- Returns:
- a Tensor of the same dimension and shape as the input with
+ Parameters
+ ----------
+ dim : int, optional
+ A dimension along which Softmax will be computed (so every slice
+ along dim will sum to 1).
+
+ Shape
+ -----
+ - Input: :math:`(*)` where `*` means, any number of additional dimensions
+ - Output: :math:`(*)`, same shape as the input
+
+ Returns
+ -------
+ Tensor
+ A Tensor of the same dimension and shape as the input with
  values in the range [0, 1]

- Args:
- dim (int): A dimension along which Softmax will be computed (so every slice
- along dim will sum to 1).
+ Notes
+ -----
+ This module doesn't work directly with NLLLoss, which expects the Log to be
+ computed between the Softmax and itself. Use `LogSoftmax` instead (it's
+ faster and has better numerical properties).

- .. note::
- This module doesn't work directly with NLLLoss,
- which expects the Log to be computed between the Softmax and itself.
- Use `LogSoftmax` instead (it's faster and has better numerical properties).
-
- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Softmax(dim=1)
- >>> x = random.randn(2, 3)
+ >>> x = brainstate.random.randn(2, 3)
  >>> output = m(x)
-
  """
  __module__ = 'brainstate.nn'
  dim: Optional[int]
@@ -1015,21 +1156,26 @@ class Softmax2d(ElementWiseBlock):
  When given an image of ``Channels x Height x Width``, it will
  apply `Softmax` to each location :math:`(Channels, h_i, w_j)`

- Shape:
- - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
- - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
+ Shape
+ -----
+ - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
+ - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

- Returns:
- a Tensor of the same dimension and shape as the input with
+ Returns
+ -------
+ Tensor
+ A Tensor of the same dimension and shape as the input with
  values in the range [0, 1]

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.Softmax2d()
  >>> # you softmax over the 2nd dimension
- >>> x = random.randn(2, 3, 12, 13)
+ >>> x = brainstate.random.randn(2, 3, 12, 13)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -1040,30 +1186,37 @@ class Softmax2d(ElementWiseBlock):


  class LogSoftmax(ElementWiseBlock):
- r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional
- input Tensor. The LogSoftmax formulation can be simplified as:
+ r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input Tensor.
+
+ The LogSoftmax formulation can be simplified as:

  .. math::
  \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

- Shape:
- - Input: :math:`(*)` where `*` means, any number of additional
- dimensions
- - Output: :math:`(*)`, same shape as the input
+ Parameters
+ ----------
+ dim : int, optional
+ A dimension along which LogSoftmax will be computed.

- Args:
- dim (int): A dimension along which LogSoftmax will be computed.
+ Shape
+ -----
+ - Input: :math:`(*)` where `*` means, any number of additional dimensions
+ - Output: :math:`(*)`, same shape as the input

- Returns:
- a Tensor of the same dimension and shape as the input with
+ Returns
+ -------
+ Tensor
+ A Tensor of the same dimension and shape as the input with
  values in the range [-inf, 0)

- Examples::
+ Examples
+ --------
+ .. code-block:: python

  >>> import brainstate.nn as nn
- >>> import brainstate as brainstate
+ >>> import brainstate
  >>> m = nn.LogSoftmax(dim=1)
- >>> x = random.randn(2, 3)
+ >>> x = brainstate.random.randn(2, 3)
  >>> output = m(x)
  """
  __module__ = 'brainstate.nn'
@@ -1082,6 +1235,16 @@ class LogSoftmax(ElementWiseBlock):

  class Identity(ElementWiseBlock):
  r"""A placeholder identity operator that is argument-insensitive.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate.nn as nn
+ >>> m = nn.Identity()
+ >>> x = brainstate.random.randn(2, 3)
+ >>> output = m(x)
+ >>> assert (output == x).all()
  """
  __module__ = 'brainstate.nn'

@@ -1103,17 +1266,33 @@ class SpikeBitwise(ElementWiseBlock):
  \hline
  \end{array}

- Args:
- op: str. The bitwise operation.
- name: str. The name of the dynamic system.
+ Parameters
+ ----------
+ op : str, optional
+ The bitwise operation. Default: 'add'
+ name : str, optional
+ The name of the dynamic system.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate.nn as nn
+ >>> m = nn.SpikeBitwise(op='and')
+ >>> x = brainstate.random.randn(2, 3) > 0
+ >>> y = brainstate.random.randn(2, 3) > 0
+ >>> output = m(x, y)
  """
  __module__ = 'brainstate.nn'

- def __init__(self,
- op: str = 'add',
- name: Optional[str] = None) -> None:
+ def __init__(
+ self,
+ op: str = 'add',
+ name: Optional[str] = None
+ ) -> None:
  super().__init__(name=name)
  self.op = op

  def __call__(self, x, y):
- return F.spike_bitwise(x, y, self.op)
+ import braintools
+ return braintools.spike_bitwise(x, y, self.op)
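The final hunk swaps the internal `F.spike_bitwise` helper for `braintools.spike_bitwise`, imported lazily inside `__call__`, so `braintools` becomes a runtime requirement for this class in 0.2.0. A minimal usage sketch under that assumption:

.. code-block:: python

    # Assumes braintools is installed; 0.2.0 imports it at call time.
    import brainstate
    import brainstate.nn as nn

    x = brainstate.random.randn(2, 3) > 0
    y = brainstate.random.randn(2, 3) > 0
    out = nn.SpikeBitwise(op='and')(x, y)  # delegates to braintools.spike_bitwise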