brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_elementwise.py
CHANGED
@@ -1,1119 +1,1298 @@
|
|
1
|
-
# Copyright 2024
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from typing import Optional
|
19
|
-
|
20
|
-
import brainunit as u
|
21
|
-
import jax.numpy as jnp
|
22
|
-
|
23
|
-
from brainstate._state import ParamState
|
24
|
-
from brainstate.typing import ArrayLike
|
25
|
-
from . import _activations as F
|
26
|
-
from ._module import ElementWiseBlock
|
27
|
-
|
28
|
-
__all__ = [
|
29
|
-
# activation functions
|
30
|
-
'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid',
|
31
|
-
'Tanh', 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU',
|
32
|
-
'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU',
|
33
|
-
'Softsign', 'Tanhshrink', 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax',
|
34
|
-
|
35
|
-
# others
|
36
|
-
'Identity', 'SpikeBitwise',
|
37
|
-
]
|
38
|
-
|
39
|
-
|
40
|
-
class Threshold(ElementWiseBlock):
|
41
|
-
r"""Thresholds each element of the input Tensor.
|
42
|
-
|
43
|
-
Threshold is defined as:
|
44
|
-
|
45
|
-
.. math::
|
46
|
-
y =
|
47
|
-
\begin{cases}
|
48
|
-
x, &\text{ if } x > \text{threshold} \\
|
49
|
-
\text{value}, &\text{ otherwise }
|
50
|
-
\end{cases}
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
def
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
>>>
|
111
|
-
>>>
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
..
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
|
520
|
-
|
521
|
-
|
522
|
-
|
523
|
-
|
524
|
-
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
"""
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
def
|
618
|
-
return
|
619
|
-
|
620
|
-
|
621
|
-
class
|
622
|
-
r"""Applies the
|
623
|
-
|
624
|
-
|
625
|
-
|
626
|
-
|
627
|
-
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
Examples
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
>>>
|
646
|
-
>>>
|
647
|
-
>>>
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
672
|
-
|
673
|
-
|
674
|
-
|
675
|
-
|
676
|
-
|
677
|
-
|
678
|
-
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
-
|
690
|
-
|
691
|
-
>>> import brainstate as
|
692
|
-
>>>
|
693
|
-
>>>
|
694
|
-
>>>
|
695
|
-
|
696
|
-
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
|
710
|
-
|
711
|
-
|
712
|
-
|
713
|
-
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
|
724
|
-
|
725
|
-
|
726
|
-
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
|
741
|
-
|
742
|
-
|
743
|
-
|
744
|
-
|
745
|
-
|
746
|
-
|
747
|
-
|
748
|
-
|
749
|
-
|
750
|
-
|
751
|
-
|
752
|
-
|
753
|
-
|
754
|
-
|
755
|
-
|
756
|
-
|
757
|
-
|
758
|
-
|
759
|
-
|
760
|
-
|
761
|
-
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
\
|
768
|
-
|
769
|
-
x
|
770
|
-
|
771
|
-
\end{cases}
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
|
783
|
-
|
784
|
-
|
785
|
-
|
786
|
-
|
787
|
-
|
788
|
-
|
789
|
-
|
790
|
-
|
791
|
-
|
792
|
-
|
793
|
-
|
794
|
-
|
795
|
-
|
796
|
-
|
797
|
-
|
798
|
-
def
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
|
804
|
-
|
805
|
-
|
806
|
-
|
807
|
-
|
808
|
-
|
809
|
-
|
810
|
-
|
811
|
-
|
812
|
-
|
813
|
-
x
|
814
|
-
|
815
|
-
|
816
|
-
|
817
|
-
|
818
|
-
|
819
|
-
|
820
|
-
|
821
|
-
|
822
|
-
..
|
823
|
-
|
824
|
-
|
825
|
-
|
826
|
-
|
827
|
-
|
828
|
-
|
829
|
-
|
830
|
-
|
831
|
-
|
832
|
-
|
833
|
-
|
834
|
-
|
835
|
-
|
836
|
-
|
837
|
-
|
838
|
-
|
839
|
-
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
|
855
|
-
|
856
|
-
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
|
861
|
-
|
862
|
-
|
863
|
-
|
864
|
-
|
865
|
-
|
866
|
-
|
867
|
-
|
868
|
-
|
869
|
-
|
870
|
-
|
871
|
-
|
872
|
-
|
873
|
-
|
874
|
-
|
875
|
-
|
876
|
-
|
877
|
-
|
878
|
-
|
879
|
-
|
880
|
-
|
881
|
-
|
882
|
-
|
883
|
-
|
884
|
-
|
885
|
-
|
886
|
-
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
897
|
-
|
898
|
-
|
899
|
-
|
900
|
-
|
901
|
-
|
902
|
-
|
903
|
-
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
|
912
|
-
|
913
|
-
|
914
|
-
|
915
|
-
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
|
937
|
-
|
938
|
-
|
939
|
-
|
940
|
-
|
941
|
-
|
942
|
-
|
943
|
-
|
944
|
-
|
945
|
-
|
946
|
-
|
947
|
-
|
948
|
-
|
949
|
-
|
950
|
-
|
951
|
-
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
|
969
|
-
|
970
|
-
|
971
|
-
|
972
|
-
|
973
|
-
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
|
978
|
-
|
979
|
-
|
980
|
-
|
981
|
-
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
|
988
|
-
|
989
|
-
|
990
|
-
|
991
|
-
|
992
|
-
|
993
|
-
|
994
|
-
|
995
|
-
|
996
|
-
|
997
|
-
|
998
|
-
|
999
|
-
|
1000
|
-
|
1001
|
-
|
1002
|
-
|
1003
|
-
|
1004
|
-
|
1005
|
-
|
1006
|
-
|
1007
|
-
|
1008
|
-
|
1009
|
-
|
1010
|
-
|
1011
|
-
|
1012
|
-
|
1013
|
-
|
1014
|
-
|
1015
|
-
|
1016
|
-
|
1017
|
-
|
1018
|
-
|
1019
|
-
|
1020
|
-
|
1021
|
-
|
1022
|
-
|
1023
|
-
|
1024
|
-
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
>>> import brainstate.nn as nn
|
1029
|
-
>>> import brainstate
|
1030
|
-
>>> m = nn.
|
1031
|
-
>>>
|
1032
|
-
>>>
|
1033
|
-
|
1034
|
-
|
1035
|
-
|
1036
|
-
|
1037
|
-
|
1038
|
-
|
1039
|
-
|
1040
|
-
|
1041
|
-
|
1042
|
-
|
1043
|
-
|
1044
|
-
|
1045
|
-
|
1046
|
-
|
1047
|
-
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1053
|
-
|
1054
|
-
|
1055
|
-
|
1056
|
-
|
1057
|
-
|
1058
|
-
|
1059
|
-
|
1060
|
-
|
1061
|
-
|
1062
|
-
|
1063
|
-
|
1064
|
-
|
1065
|
-
|
1066
|
-
|
1067
|
-
|
1068
|
-
|
1069
|
-
|
1070
|
-
|
1071
|
-
|
1072
|
-
|
1073
|
-
|
1074
|
-
|
1075
|
-
|
1076
|
-
|
1077
|
-
|
1078
|
-
|
1079
|
-
|
1080
|
-
|
1081
|
-
|
1082
|
-
|
1083
|
-
|
1084
|
-
|
1085
|
-
|
1086
|
-
|
1087
|
-
|
1088
|
-
def
|
1089
|
-
return
|
1090
|
-
|
1091
|
-
|
1092
|
-
class
|
1093
|
-
r"""
|
1094
|
-
|
1095
|
-
|
1096
|
-
|
1097
|
-
|
1098
|
-
|
1099
|
-
|
1100
|
-
|
1101
|
-
\text
|
1102
|
-
|
1103
|
-
|
1104
|
-
|
1105
|
-
|
1106
|
-
|
1107
|
-
|
1108
|
-
|
1109
|
-
|
1110
|
-
|
1111
|
-
|
1112
|
-
|
1113
|
-
|
1114
|
-
|
1115
|
-
|
1116
|
-
|
1117
|
-
|
1118
|
-
|
1119
|
-
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from typing import Optional
|
19
|
+
|
20
|
+
import brainunit as u
|
21
|
+
import jax.numpy as jnp
|
22
|
+
|
23
|
+
from brainstate._state import ParamState
|
24
|
+
from brainstate.typing import ArrayLike
|
25
|
+
from . import _activations as F
|
26
|
+
from ._module import ElementWiseBlock
|
27
|
+
|
28
|
+
__all__ = [
|
29
|
+
# activation functions
|
30
|
+
'Threshold', 'ReLU', 'RReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Hardsigmoid',
|
31
|
+
'Tanh', 'SiLU', 'Mish', 'Hardswish', 'ELU', 'CELU', 'SELU', 'GLU', 'GELU',
|
32
|
+
'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'PReLU',
|
33
|
+
'Softsign', 'Tanhshrink', 'Softmin', 'Softmax', 'Softmax2d', 'LogSoftmax',
|
34
|
+
|
35
|
+
# others
|
36
|
+
'Identity', 'SpikeBitwise',
|
37
|
+
]
|
38
|
+
|
39
|
+
|
40
|
+
class Threshold(ElementWiseBlock):
    r"""Replace every element that does not exceed a threshold with a constant.

    The element-wise mapping is:

    .. math::
        y =
        \begin{cases}
        x, &\text{ if } x > \text{threshold} \\
        \text{value}, &\text{ otherwise }
        \end{cases}

    Parameters
    ----------
    threshold : float
        Elements strictly greater than this value pass through unchanged.
    value : float
        Constant written in place of every element that is less than or
        equal to ``threshold``.

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.Threshold(0.1, 20)
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    threshold: float
    value: float

    def __init__(self, threshold: float, value: float) -> None:
        super().__init__()
        self.threshold = threshold
        self.value = value

    def __call__(self, x: ArrayLike) -> ArrayLike:
        # Both scalars are materialised in the input's dtype, so the selected
        # output carries the same dtype as ``x``.
        input_dtype = u.math.get_dtype(x)
        cutoff = jnp.asarray(self.threshold, dtype=input_dtype)
        fill = jnp.asarray(self.value, dtype=input_dtype)
        keep = x > cutoff
        return jnp.where(keep, x, fill)

    def __repr__(self):
        name = type(self).__name__
        return f'{name}(threshold={self.threshold}, value={self.value})'
|
91
|
+
|
92
|
+
|
93
|
+
class ReLU(ElementWiseBlock):
    r"""Applies the rectified linear unit function element-wise.

    The ReLU function is defined as:

    .. math::
        \text{ReLU}(x) = (x)^+ = \max(0, x)

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.ReLU()
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)

    An implementation of CReLU - https://arxiv.org/abs/1603.05201

    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> m = nn.ReLU()
        >>> x = jnp.expand_dims(brainstate.random.randn(2), 0)
        >>> output = jnp.concatenate((m(x), m(-x)))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.relu(x)

    def __repr__(self):
        return f'{self.__class__.__name__}()'
|
134
|
+
|
135
|
+
|
136
|
+
class RReLU(ElementWiseBlock):
    r"""Applies the randomized leaky rectified linear unit function, element-wise.

    As described in the paper `Empirical Evaluation of Rectified Activations in
    Convolutional Network`_.

    The function is defined as:

    .. math::
        \text{RReLU}(x) =
        \begin{cases}
            x & \text{if } x \geq 0 \\
            ax & \text{ otherwise }
        \end{cases}

    where :math:`a` is randomly sampled from uniform distribution
    :math:`\mathcal{U}(\text{lower}, \text{upper})`.

    Parameters
    ----------
    lower : float, optional
        Lower bound of the uniform distribution. Default: :math:`\frac{1}{8}`
    upper : float, optional
        Upper bound of the uniform distribution. Default: :math:`\frac{1}{3}`

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    References
    ----------
    .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
        https://arxiv.org/abs/1505.00853

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.RReLU(0.1, 0.3)
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    lower: float
    upper: float

    def __init__(
        self,
        lower: float = 1. / 8,
        upper: float = 1. / 3,
    ):
        super().__init__()
        self.lower = lower
        self.upper = upper

    def __call__(self, x: ArrayLike) -> ArrayLike:
        # This class only stores the bounds; the slope is produced inside
        # ``F.rrelu`` from ``lower`` and ``upper``.
        return F.rrelu(x, self.lower, self.upper)

    def extra_repr(self):
        return f'{self.__class__.__name__}(lower={self.lower}, upper={self.upper})'
|
199
|
+
|
200
|
+
|
201
|
+
class Hardtanh(ElementWiseBlock):
    r"""Applies the HardTanh function element-wise.

    HardTanh is defined as:

    .. math::
        \text{HardTanh}(x) = \begin{cases}
            \text{max\_val} & \text{ if } x > \text{ max\_val } \\
            \text{min\_val} & \text{ if } x < \text{ min\_val } \\
            x & \text{ otherwise } \\
        \end{cases}

    Parameters
    ----------
    min_val : float, optional
        Minimum value of the linear region range. Default: -1
    max_val : float, optional
        Maximum value of the linear region range. Default: 1.
        Must be strictly greater than ``min_val``.

    Raises
    ------
    AssertionError
        If ``max_val`` is not strictly greater than ``min_val`` (the check is
        an ``assert`` and is therefore skipped under ``python -O``).

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.Hardtanh(-2, 2)
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    min_val: float
    max_val: float

    def __init__(
        self,
        min_val: float = -1.,
        max_val: float = 1.,
    ) -> None:
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        # Kept as an ``assert`` so the raised exception type is unchanged for
        # existing callers; the message makes failures diagnosable.
        assert self.max_val > self.min_val, (
            f'max_val ({self.max_val}) must be greater than min_val ({self.min_val}).'
        )

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.hard_tanh(x, self.min_val, self.max_val)

    def extra_repr(self) -> str:
        return f'{self.__class__.__name__}(min_val={self.min_val}, max_val={self.max_val})'
|
259
|
+
|
260
|
+
|
261
|
+
class ReLU6(Hardtanh):
    r"""Applies the ReLU6 function element-wise.

    ReLU6 is defined as:

    .. math::
        \text{ReLU6}(x) = \min(\max(0,x), 6)

    It is implemented as a :class:`Hardtanh` with ``min_val=0`` and
    ``max_val=6``.

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.ReLU6()
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __init__(self):
        super().__init__(0., 6.)
|
288
|
+
|
289
|
+
|
290
|
+
class Sigmoid(ElementWiseBlock):
    r"""Applies the logistic sigmoid function element-wise.

    Sigmoid is defined as:

    .. math::
        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.Sigmoid()
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.sigmoid(x)
|
317
|
+
|
318
|
+
|
319
|
+
class Hardsigmoid(ElementWiseBlock):
    r"""Element-wise piecewise-linear version of the sigmoid (Hardsigmoid).

    The mapping is:

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.Hardsigmoid()
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.hard_sigmoid(x)
        return activated
|
350
|
+
|
351
|
+
|
352
|
+
class Tanh(ElementWiseBlock):
    r"""Hyperbolic tangent activation, applied element-wise.

    .. math::
        \text{Tanh}(x) = \tanh(x) = \frac{\exp(x) - \exp(-x)}{\exp(x) + \exp(-x)}

    The computation is delegated to ``F.tanh``.

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.Tanh()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.tanh(x)
        return activated
|
379
|
+
|
380
|
+
|
381
|
+
class SiLU(ElementWiseBlock):
    r"""Sigmoid Linear Unit (SiLU, a.k.a. swish), applied element-wise.

    .. math::
        \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}

    The computation is delegated to ``F.silu``.

    Notes
    -----
    The name SiLU was coined in `Gaussian Error Linear Units (GELUs)
    <https://arxiv.org/abs/1606.08415>`_; the function was studied further in
    `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
    in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and
    `Swish: a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_.

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.SiLU()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.silu(x)
        return activated
|
417
|
+
|
418
|
+
|
419
|
+
class Mish(ElementWiseBlock):
    r"""Mish activation (self-regularized, non-monotonic), applied element-wise.

    .. math::
        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))

    The computation is delegated to ``F.mish``.

    Notes
    -----
    Reference: `Mish: A Self Regularized Non-Monotonic Neural Activation
    Function <https://arxiv.org/abs/1908.08681>`_.

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.Mish()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.mish(x)
        return activated
|
451
|
+
|
452
|
+
|
453
|
+
class Hardswish(ElementWiseBlock):
    r"""Piecewise ("hard") swish activation, applied element-wise.

    Proposed in `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_:

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) /6 & \text{otherwise}
        \end{cases}

    The computation is delegated to ``F.hard_swish``.

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.Hardswish()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.hard_swish(x)
        return activated
|
487
|
+
|
488
|
+
|
489
|
+
class ELU(ElementWiseBlock):
    r"""Exponential Linear Unit (ELU), applied element-wise.

    Introduced in `Fast and Accurate Deep Network Learning by Exponential
    Linear Units (ELUs) <https://arxiv.org/abs/1511.07289>`__:

    .. math::
        \text{ELU}(x) = \begin{cases}
            x, & \text{ if } x > 0\\
            \alpha * (\exp(x) - 1), & \text{ if } x \leq 0
        \end{cases}

    Parameters
    ----------
    alpha : float, optional
        Scale :math:`\alpha` of the negative branch. It is stored on the
        module and forwarded to ``F.elu`` on every call. Default: 1.0

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.ELU()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'
    alpha: float  # negative-branch scale forwarded to F.elu

    def __init__(self, alpha: float = 1.) -> None:
        super().__init__()
        self.alpha = alpha

    def __call__(self, x: ArrayLike) -> ArrayLike:
        scale = self.alpha
        return F.elu(x, scale)

    def extra_repr(self) -> str:
        return f'{self.__class__.__name__}(alpha={self.alpha})'
|
535
|
+
|
536
|
+
|
537
|
+
class CELU(ElementWiseBlock):
    r"""Continuously differentiable ELU (CELU), applied element-wise.

    .. math::
        \text{CELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x/\alpha) - 1))

    More details can be found in the paper `Continuously Differentiable
    Exponential Linear Units`_ .

    Parameters
    ----------
    alpha : float, optional
        The :math:`\alpha` value of the CELU formulation. It is stored on the
        module and forwarded to ``F.celu`` on every call. Default: 1.0

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    References
    ----------
    .. _`Continuously Differentiable Exponential Linear Units`:
        https://arxiv.org/abs/1704.07483

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.CELU()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'
    alpha: float  # CELU alpha forwarded to F.celu

    def __init__(self, alpha: float = 1.) -> None:
        super().__init__()
        self.alpha = alpha

    def __call__(self, x: ArrayLike) -> ArrayLike:
        scale = self.alpha
        return F.celu(x, scale)

    def extra_repr(self) -> str:
        return f'{self.__class__.__name__}(alpha={self.alpha})'
|
583
|
+
|
584
|
+
|
585
|
+
class SELU(ElementWiseBlock):
    r"""Scaled Exponential Linear Unit (SELU), applied element-wise.

    .. math::
        \text{SELU}(x) = \text{scale} * (\max(0,x) + \min(0, \alpha * (\exp(x) - 1)))

    with :math:`\alpha = 1.6732632423543772848170429916717` and
    :math:`\text{scale} = 1.0507009873554804934193349852946`.

    The module has no parameters; it calls ``F.selu`` directly.
    More details can be found in the paper `Self-Normalizing Neural Networks`_ .

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    References
    ----------
    .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.SELU()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.selu(x)
        return activated
|
619
|
+
|
620
|
+
|
621
|
+
class GLU(ElementWiseBlock):
    r"""Gated linear unit.

    The input is split into two halves :math:`a` and :math:`b` along ``dim``,
    and the result is

    .. math::
        \text{GLU}(a, b) = a \otimes \sigma(b)

    where :math:`a` is the first half and :math:`b` the second half.
    The computation is delegated to ``F.glu``.

    Parameters
    ----------
    dim : int, optional
        Axis along which the input is split in half. Default: -1

    Shape
    -----
    - Input: :math:`(\ast_1, N, \ast_2)`, where :math:`*` means any number of
      additional dimensions. :math:`N` is presumably required to be even so
      that it can be halved.
    - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> gate = nn.GLU()
        >>> y = gate(brainstate.random.randn(4, 2))
    """
    __module__ = 'brainstate.nn'
    dim: int  # axis that is split into value/gate halves

    def __init__(self, dim: int = -1) -> None:
        super().__init__()
        self.dim = dim

    def __call__(self, x: ArrayLike) -> ArrayLike:
        split_axis = self.dim
        return F.glu(x, split_axis)

    def __repr__(self):
        return f'{self.__class__.__name__}(dim={self.dim})'
|
663
|
+
|
664
|
+
|
665
|
+
class GELU(ElementWiseBlock):
    r"""Applies the Gaussian Error Linear Units function element-wise.

    .. math:: \text{GELU}(x) = x * \Phi(x)

    where :math:`\Phi(x)` is the cumulative distribution function of the
    standard Gaussian distribution.

    When ``approximate`` is True, GELU is estimated with the tanh
    approximation:

    .. math:: \text{GELU}(x) = 0.5 * x * \left(1 + \tanh\left(\sqrt{2 / \pi} * (x + 0.044715 * x^3)\right)\right)

    Parameters
    ----------
    approximate : bool, optional
        Whether to use the tanh approximation. The flag is forwarded to
        ``F.gelu`` as ``approximate=``. Default: False

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.GELU()
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    approximate: bool  # selects the tanh approximation in F.gelu

    def __init__(self, approximate: bool = False) -> None:
        super().__init__()
        self.approximate = approximate

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.gelu(x, approximate=self.approximate)

    def __repr__(self):
        return f'{self.__class__.__name__}(approximate={self.approximate})'
|
709
|
+
|
710
|
+
|
711
|
+
class Hardshrink(ElementWiseBlock):
    r"""Hard shrinkage, applied element-wise.

    Values whose magnitude exceeds :math:`\lambda` pass through unchanged;
    everything else is zeroed:

    .. math::
        \text{HardShrink}(x) =
        \begin{cases}
            x, & \text{ if } x > \lambda \\
            x, & \text{ if } x < -\lambda \\
            0, & \text{ otherwise }
        \end{cases}

    Parameters
    ----------
    lambd : float, optional
        The threshold :math:`\lambda`, forwarded to ``F.hard_shrink``.
        Default: 0.5

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.Hardshrink()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'
    lambd: float  # shrinkage threshold

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def __call__(self, x: ArrayLike) -> ArrayLike:
        threshold = self.lambd
        return F.hard_shrink(x, threshold)

    def __repr__(self):
        return f'{self.__class__.__name__}(lambd={self.lambd})'
|
756
|
+
|
757
|
+
|
758
|
+
class LeakyReLU(ElementWiseBlock):
    r"""Leaky rectified linear unit, applied element-wise.

    .. math::
        \text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)

    or, equivalently,

    .. math::
        \text{LeakyReLU}(x) =
        \begin{cases}
            x, & \text{ if } x \geq 0 \\
            \text{negative\_slope} \times x, & \text{ otherwise }
        \end{cases}

    Parameters
    ----------
    negative_slope : float, optional
        Slope used for negative inputs, forwarded to ``F.leaky_relu``.
        Default: 1e-2

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.LeakyReLU(0.1)
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'
    negative_slope: float  # slope applied to negative inputs

    def __init__(self, negative_slope: float = 1e-2) -> None:
        super().__init__()
        self.negative_slope = negative_slope

    def __call__(self, x: ArrayLike) -> ArrayLike:
        slope = self.negative_slope
        return F.leaky_relu(x, slope)

    def __repr__(self):
        return f'{self.__class__.__name__}(negative_slope={self.negative_slope})'
|
807
|
+
|
808
|
+
|
809
|
+
class LogSigmoid(ElementWiseBlock):
    r"""Logarithm of the logistic sigmoid, applied element-wise.

    .. math::
        \text{LogSigmoid}(x) = \log\left(\frac{ 1 }{ 1 + \exp(-x)}\right)

    The computation is delegated to ``F.log_sigmoid``.

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.LogSigmoid()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.log_sigmoid(x)
        return activated
|
834
|
+
|
835
|
+
|
836
|
+
class Softplus(ElementWiseBlock):
    r"""Applies the Softplus function element-wise.

    .. math::
        \text{Softplus}(x) = \log(1 + \exp(x))

    Softplus is a smooth approximation to the ReLU function and can be used
    to constrain the output of a layer to always be positive.

    Notes
    -----
    This module exposes no ``beta`` or ``threshold`` parameters: it calls
    ``F.softplus(x)`` with no extra arguments. The formula above is therefore
    the :math:`\beta = 1` case of the parameterized form
    :math:`\frac{1}{\beta}\log(1 + \exp(\beta x))`, assuming ``F.softplus``
    implements the standard softplus. Any numerical-stability handling is the
    responsibility of ``F.softplus``.

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.Softplus()
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.softplus(x)
|
867
|
+
|
868
|
+
|
869
|
+
class Softshrink(ElementWiseBlock):
    r"""Soft shrinkage, applied element-wise.

    Values are pulled towards zero by :math:`\lambda`; values inside
    :math:`[-\lambda, \lambda]` become zero:

    .. math::
        \text{SoftShrinkage}(x) =
        \begin{cases}
            x - \lambda, & \text{ if } x > \lambda \\
            x + \lambda, & \text{ if } x < -\lambda \\
            0, & \text{ otherwise }
        \end{cases}

    Parameters
    ----------
    lambd : float, optional
        The threshold :math:`\lambda` (must be no less than zero), forwarded
        to ``F.soft_shrink``. Default: 0.5

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.Softshrink()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'
    lambd: float  # shrinkage threshold

    def __init__(self, lambd: float = 0.5) -> None:
        super().__init__()
        self.lambd = lambd

    def __call__(self, x: ArrayLike) -> ArrayLike:
        threshold = self.lambd
        return F.soft_shrink(x, threshold)

    def __repr__(self):
        return f'{self.__class__.__name__}(lambd={self.lambd})'
|
913
|
+
|
914
|
+
|
915
|
+
class PReLU(ElementWiseBlock):
    r"""Parametric rectified linear unit with a learnable negative slope.

    .. math::
        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)

    or, equivalently,

    .. math::
        \text{PReLU}(x) =
        \begin{cases}
            x, & \text{ if } x \geq 0 \\
            ax, & \text{ otherwise }
        \end{cases}

    Here :math:`a` is learnable. ``nn.PReLU()`` learns a single slope shared
    by all inputs; ``nn.PReLU(n)`` learns ``n`` slopes, one per channel.

    Parameters
    ----------
    num_parameters : int, optional
        Number of slopes :math:`a` to learn. Only two values are meaningful:
        1 (a single shared slope) or the number of input channels. Default: 1
    init : float, optional
        Initial value of every slope. Default: 0.25
    dtype : optional
        Data type passed to ``jnp.ones`` when creating the weight.

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Attributes
    ----------
    weight : ParamState
        The learnable slopes, an array of shape ``(num_parameters,)``
        initialised to ``init``.

    Notes
    -----
    - Weight decay should not be applied to :math:`a` if good performance is
      desired.
    - Combining the input with a multi-element weight is done by
      ``F.prelu``. Check that function to see which input axis the weight is
      aligned with; PyTorch's PReLU uses the second axis, but that is not
      guaranteed here.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> m = brainstate.nn.PReLU()
        >>> x = brainstate.random.randn(2)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    num_parameters: int  # number of learnable slopes

    def __init__(self, num_parameters: int = 1, init: float = 0.25, dtype=None) -> None:
        super().__init__()
        self.num_parameters = num_parameters
        # ones(...) * init, rather than full(...), keeps the original dtype
        # promotion when an integer dtype is requested.
        self.weight = ParamState(jnp.ones(num_parameters, dtype=dtype) * init)

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.prelu(x, self.weight.value)

    def __repr__(self):
        return f'{self.__class__.__name__}(num_parameters={self.num_parameters})'
|
984
|
+
|
985
|
+
|
986
|
+
class Softsign(ElementWiseBlock):
    r"""Softsign activation, applied element-wise.

    .. math::
        \text{SoftSign}(x) = \frac{x}{ 1 + |x|}

    The computation is delegated to ``F.soft_sign``.

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.Softsign()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.soft_sign(x)
        return activated
|
1011
|
+
|
1012
|
+
|
1013
|
+
class Tanhshrink(ElementWiseBlock):
    r"""Tanh shrinkage, applied element-wise.

    .. math::
        \text{Tanhshrink}(x) = x - \tanh(x)

    The computation is delegated to ``F.tanh_shrink``.

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> act = nn.Tanhshrink()
        >>> y = act(brainstate.random.randn(2))
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        activated = F.tanh_shrink(x)
        return activated
|
1038
|
+
|
1039
|
+
|
1040
|
+
class Softmin(ElementWiseBlock):
    r"""Softmin normalisation along one axis of the input.

    The output lies in `[0, 1]` and sums to 1 along the chosen axis:

    .. math::
        \text{Softmin}(x_{i}) = \frac{\exp(-x_i)}{\sum_j \exp(-x_j)}

    Parameters
    ----------
    dim : int, optional
        Axis along which Softmin is computed (every slice along it sums to 1).
        The value, including the default ``None``, is forwarded unchanged to
        ``F.softmin``.

    Shape
    -----
    - Input: :math:`(*)`, an array with any number of dimensions.
    - Output: :math:`(*)`, the same shape as the input.

    Returns
    -------
    Tensor
        An array of the same shape as the input with values in `[0, 1]`.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> norm = nn.Softmin(dim=1)
        >>> y = norm(brainstate.random.randn(2, 3))
    """
    __module__ = 'brainstate.nn'
    dim: Optional[int]  # normalisation axis forwarded to F.softmin

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __call__(self, x: ArrayLike) -> ArrayLike:
        axis = self.dim
        return F.softmin(x, axis)

    def __repr__(self):
        return f'{self.__class__.__name__}(dim={self.dim})'
|
1090
|
+
|
1091
|
+
|
1092
|
+
class Softmax(ElementWiseBlock):
    r"""Applies the Softmax function to an n-dimensional input array.

    Rescales the input so that the elements of the output lie in the range
    [0, 1] and sum to 1 along ``dim``.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Parameters
    ----------
    dim : int, optional
        Axis along which Softmax is computed (every slice along it sums to 1).
        The value, including the default ``None``, is forwarded unchanged to
        ``F.softmax``.

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Returns
    -------
    Tensor
        An array of the same shape as the input with values in the range
        [0, 1].

    Notes
    -----
    This module returns probabilities, not log-probabilities. When the result
    feeds a negative log-likelihood loss, use `LogSoftmax` instead. Computing
    ``log(softmax(x))`` in two steps is less numerically stable.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.Softmax(dim=1)
        >>> x = brainstate.random.randn(2, 3)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    dim: Optional[int]  # normalisation axis forwarded to F.softmax

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.softmax(x, self.dim)

    def __repr__(self):
        return f'{self.__class__.__name__}(dim={self.dim})'
|
1151
|
+
|
1152
|
+
|
1153
|
+
class Softmax2d(ElementWiseBlock):
    r"""Applies Softmax over features at each spatial location.

    Given an image of shape ``Channels x Height x Width``, Softmax is applied
    to each location :math:`(Channels, h_i, w_j)`. Normalisation runs along
    axis ``-3``, which is the channel axis for both supported layouts.

    Shape
    -----
    - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
    - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)

    Returns
    -------
    Tensor
        An array of the same shape as the input with values in the range
        [0, 1].

    Raises
    ------
    ValueError
        If the input is neither 3-D nor 4-D.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.Softmax2d()
        >>> # softmax over the 2nd dimension (channels)
        >>> x = brainstate.random.randn(2, 3, 12, 13)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x: ArrayLike) -> ArrayLike:
        # Explicit check instead of ``assert``: asserts are stripped when
        # Python runs with ``-O``, which would silently skip validation.
        if x.ndim not in (3, 4):
            raise ValueError(
                f'Softmax2d requires a 3D or 4D tensor as input, but got a {x.ndim}D input.'
            )
        # Axis -3 is the channel axis for both (C, H, W) and (N, C, H, W).
        return F.softmax(x, -3)
|
1186
|
+
|
1187
|
+
|
1188
|
+
class LogSoftmax(ElementWiseBlock):
    r"""Applies the :math:`\log(\text{Softmax}(x))` function to an n-dimensional input array.

    The LogSoftmax formulation can be simplified as:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Parameters
    ----------
    dim : int, optional
        Axis along which LogSoftmax is computed. The value, including the
        default ``None``, is forwarded unchanged to ``F.log_softmax``.

    Shape
    -----
    - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
    - Output: :math:`(*)`, same shape as the input.

    Returns
    -------
    Tensor
        An array of the same shape as the input with values in the range
        :math:`(-\infty, 0]`. A value of 0 is reached when a single entry
        carries all of the probability mass, e.g. for a slice of length 1.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate.nn as nn
        >>> import brainstate
        >>> m = nn.LogSoftmax(dim=1)
        >>> x = brainstate.random.randn(2, 3)
        >>> output = m(x)
    """
    __module__ = 'brainstate.nn'
    dim: Optional[int]  # normalisation axis forwarded to F.log_softmax

    def __init__(self, dim: Optional[int] = None) -> None:
        super().__init__()
        self.dim = dim

    def __call__(self, x: ArrayLike) -> ArrayLike:
        return F.log_softmax(x, self.dim)

    def __repr__(self):
        return f'{self.__class__.__name__}(dim={self.dim})'
|
1234
|
+
|
1235
|
+
|
1236
|
+
class Identity(ElementWiseBlock):
    r"""Placeholder module that returns its input unchanged.

    Useful where a layer is structurally required but no transformation
    should be applied. ``__call__`` returns the very object it receives.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> m = nn.Identity()
        >>> x = brainstate.random.randn(2, 3)
        >>> output = m(x)
        >>> assert (output == x).all()
    """
    __module__ = 'brainstate.nn'

    def __call__(self, x):
        return x
|
1253
|
+
|
1254
|
+
|
1255
|
+
class SpikeBitwise(ElementWiseBlock):
    r"""Element-wise bitwise combination of two spiking inputs.

    The operation :math:`g(x, y)` is selected by ``op``:

    .. math::

        \begin{array}{ccc}
        \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
        \hline \text { ADD } & x+y & x+y \\
        \text { AND } & x \cap y & x \cdot y \\
        \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
        \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
        \hline
        \end{array}

    The computation is delegated to ``braintools.spike_bitwise``, which is
    imported lazily on the first call.

    Parameters
    ----------
    op : str, optional
        Name of the bitwise operation, forwarded unchanged to
        ``braintools.spike_bitwise``. The modes in the table above are
        presumably spelled in lowercase (``'add'``, ``'and'``, ``'iand'``,
        ``'or'``), as in the example below. Default: 'add'
    name : str, optional
        Optional module name, passed to the parent constructor.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import brainstate.nn as nn
        >>> m = nn.SpikeBitwise(op='and')
        >>> x = brainstate.random.randn(2, 3) > 0
        >>> y = brainstate.random.randn(2, 3) > 0
        >>> output = m(x, y)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        op: str = 'add',
        name: Optional[str] = None
    ) -> None:
        super().__init__(name=name)
        self.op = op

    def __call__(self, x, y):
        # Imported lazily so braintools is only required when this module is used.
        import braintools
        return braintools.spike_bitwise(x, y, self.op)
|