brainstate 0.2.0-py2.py3-none-any.whl → 0.2.1-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,830 +1,830 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
from absl.testing import absltest
|
18
|
-
from absl.testing import parameterized
|
19
|
-
|
20
|
-
import jax
|
21
|
-
import jax.numpy as jnp
|
22
|
-
import numpy as np
|
23
|
-
|
24
|
-
import brainstate
|
25
|
-
import brainstate.nn as nn
|
26
|
-
|
27
|
-
|
28
|
-
class TestActivationFunctions(parameterized.TestCase):
|
29
|
-
"""Comprehensive tests for activation functions."""
|
30
|
-
|
31
|
-
def setUp(self):
|
32
|
-
"""Set up test fixtures."""
|
33
|
-
self.seed = 42
|
34
|
-
self.key = jax.random.PRNGKey(self.seed)
|
35
|
-
|
36
|
-
def _check_shape_preservation(self, layer, input_shape):
|
37
|
-
"""Helper to check if layer preserves input shape."""
|
38
|
-
x = jax.random.normal(self.key, input_shape)
|
39
|
-
output = layer(x)
|
40
|
-
self.assertEqual(output.shape, x.shape)
|
41
|
-
|
42
|
-
def _check_gradient_flow(self, layer, input_shape):
|
43
|
-
"""Helper to check if gradients can flow through the layer."""
|
44
|
-
x = jax.random.normal(self.key, input_shape)
|
45
|
-
|
46
|
-
def loss_fn(x):
|
47
|
-
return jnp.sum(layer(x))
|
48
|
-
|
49
|
-
grad = jax.grad(loss_fn)(x)
|
50
|
-
self.assertEqual(grad.shape, x.shape)
|
51
|
-
# Check that gradients are not all zeros (for most activations)
|
52
|
-
if not isinstance(layer, (nn.Threshold, nn.Hardtanh, nn.ReLU6)):
|
53
|
-
self.assertFalse(jnp.allclose(grad, 0.0))
|
54
|
-
|
55
|
-
# Test Threshold
|
56
|
-
def test_threshold_functionality(self):
|
57
|
-
"""Test Threshold activation function."""
|
58
|
-
layer = nn.Threshold(threshold=0.5, value=0.0)
|
59
|
-
|
60
|
-
# Test with values above and below threshold
|
61
|
-
x = jnp.array([-1.0, 0.0, 0.3, 0.7, 1.0])
|
62
|
-
output = layer(x)
|
63
|
-
expected = jnp.array([0.0, 0.0, 0.0, 0.7, 1.0])
|
64
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
65
|
-
|
66
|
-
@parameterized.parameters(
|
67
|
-
((2,), ),
|
68
|
-
((3, 4), ),
|
69
|
-
((2, 3, 4), ),
|
70
|
-
((2, 3, 4, 5), ),
|
71
|
-
)
|
72
|
-
def test_threshold_shapes(self, shape):
|
73
|
-
"""Test Threshold with different input shapes."""
|
74
|
-
layer = nn.Threshold(threshold=0.1, value=20)
|
75
|
-
self._check_shape_preservation(layer, shape)
|
76
|
-
|
77
|
-
# Test ReLU
|
78
|
-
def test_relu_functionality(self):
|
79
|
-
"""Test ReLU activation function."""
|
80
|
-
layer = nn.ReLU()
|
81
|
-
|
82
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
83
|
-
output = layer(x)
|
84
|
-
expected = jnp.array([0.0, 0.0, 0.0, 1.0, 2.0])
|
85
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
86
|
-
|
87
|
-
@parameterized.parameters(
|
88
|
-
((10,), ),
|
89
|
-
((5, 10), ),
|
90
|
-
((3, 5, 10), ),
|
91
|
-
)
|
92
|
-
def test_relu_shapes_and_gradients(self, shape):
|
93
|
-
"""Test ReLU shapes and gradients."""
|
94
|
-
layer = nn.ReLU()
|
95
|
-
self._check_shape_preservation(layer, shape)
|
96
|
-
self._check_gradient_flow(layer, shape)
|
97
|
-
|
98
|
-
# Test RReLU
|
99
|
-
def test_rrelu_functionality(self):
|
100
|
-
"""Test RReLU activation function."""
|
101
|
-
layer = nn.RReLU(lower=0.1, upper=0.3)
|
102
|
-
|
103
|
-
# Test positive and negative values
|
104
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
105
|
-
output = layer(x)
|
106
|
-
|
107
|
-
# Positive values should remain unchanged
|
108
|
-
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
109
|
-
# Negative values should be scaled by a factor in [lower, upper]
|
110
|
-
negative_mask = x < 0
|
111
|
-
if jnp.any(negative_mask):
|
112
|
-
scaled = output[negative_mask] / x[negative_mask]
|
113
|
-
self.assertTrue(jnp.all((scaled >= 0.1) & (scaled <= 0.3)))
|
114
|
-
|
115
|
-
# Test Hardtanh
|
116
|
-
def test_hardtanh_functionality(self):
|
117
|
-
"""Test Hardtanh activation function."""
|
118
|
-
layer = nn.Hardtanh(min_val=-1.0, max_val=1.0)
|
119
|
-
|
120
|
-
x = jnp.array([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
|
121
|
-
output = layer(x)
|
122
|
-
expected = jnp.array([-1.0, -1.0, -0.5, 0.0, 0.5, 1.0, 1.0])
|
123
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
124
|
-
|
125
|
-
def test_hardtanh_custom_bounds(self):
|
126
|
-
"""Test Hardtanh with custom bounds."""
|
127
|
-
layer = nn.Hardtanh(min_val=-2.0, max_val=3.0)
|
128
|
-
|
129
|
-
x = jnp.array([-3.0, -2.0, 0.0, 3.0, 4.0])
|
130
|
-
output = layer(x)
|
131
|
-
expected = jnp.array([-2.0, -2.0, 0.0, 3.0, 3.0])
|
132
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
133
|
-
|
134
|
-
# Test ReLU6
|
135
|
-
def test_relu6_functionality(self):
|
136
|
-
"""Test ReLU6 activation function."""
|
137
|
-
layer = nn.ReLU6()
|
138
|
-
|
139
|
-
x = jnp.array([-2.0, 0.0, 3.0, 6.0, 8.0])
|
140
|
-
output = layer(x)
|
141
|
-
expected = jnp.array([0.0, 0.0, 3.0, 6.0, 6.0])
|
142
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
143
|
-
|
144
|
-
# Test Sigmoid
|
145
|
-
def test_sigmoid_functionality(self):
|
146
|
-
"""Test Sigmoid activation function."""
|
147
|
-
layer = nn.Sigmoid()
|
148
|
-
|
149
|
-
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
150
|
-
output = layer(x)
|
151
|
-
|
152
|
-
# Check sigmoid properties
|
153
|
-
self.assertTrue(jnp.all((output >= 0.0) & (output <= 1.0)))
|
154
|
-
np.testing.assert_allclose(output[2], 0.5, rtol=1e-5) # sigmoid(0) = 0.5
|
155
|
-
|
156
|
-
@parameterized.parameters(
|
157
|
-
((10,), ),
|
158
|
-
((5, 10), ),
|
159
|
-
((3, 5, 10), ),
|
160
|
-
)
|
161
|
-
def test_sigmoid_shapes_and_gradients(self, shape):
|
162
|
-
"""Test Sigmoid shapes and gradients."""
|
163
|
-
layer = nn.Sigmoid()
|
164
|
-
self._check_shape_preservation(layer, shape)
|
165
|
-
self._check_gradient_flow(layer, shape)
|
166
|
-
|
167
|
-
# Test Hardsigmoid
|
168
|
-
def test_hardsigmoid_functionality(self):
|
169
|
-
"""Test Hardsigmoid activation function."""
|
170
|
-
layer = nn.Hardsigmoid()
|
171
|
-
|
172
|
-
x = jnp.array([-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 4.0])
|
173
|
-
output = layer(x)
|
174
|
-
|
175
|
-
# Check bounds
|
176
|
-
self.assertTrue(jnp.all((output >= 0.0) & (output <= 1.0)))
|
177
|
-
# Check specific values
|
178
|
-
np.testing.assert_allclose(output[1], 0.0, rtol=1e-5) # x=-3
|
179
|
-
np.testing.assert_allclose(output[3], 0.5, rtol=1e-5) # x=0
|
180
|
-
np.testing.assert_allclose(output[5], 1.0, rtol=1e-5) # x=3
|
181
|
-
|
182
|
-
# Test Tanh
|
183
|
-
def test_tanh_functionality(self):
|
184
|
-
"""Test Tanh activation function."""
|
185
|
-
layer = nn.Tanh()
|
186
|
-
|
187
|
-
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
188
|
-
output = layer(x)
|
189
|
-
|
190
|
-
# Check tanh properties
|
191
|
-
self.assertTrue(jnp.all((output >= -1.0) & (output <= 1.0)))
|
192
|
-
np.testing.assert_allclose(output[2], 0.0, rtol=1e-5) # tanh(0) = 0
|
193
|
-
|
194
|
-
# Test SiLU (Swish)
|
195
|
-
def test_silu_functionality(self):
|
196
|
-
"""Test SiLU activation function."""
|
197
|
-
layer = nn.SiLU()
|
198
|
-
|
199
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
200
|
-
output = layer(x)
|
201
|
-
|
202
|
-
# SiLU(x) = x * sigmoid(x)
|
203
|
-
expected = x * jax.nn.sigmoid(x)
|
204
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
205
|
-
|
206
|
-
# Test Mish
|
207
|
-
def test_mish_functionality(self):
|
208
|
-
"""Test Mish activation function."""
|
209
|
-
layer = nn.Mish()
|
210
|
-
|
211
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
212
|
-
output = layer(x)
|
213
|
-
|
214
|
-
# Mish(x) = x * tanh(softplus(x))
|
215
|
-
expected = x * jnp.tanh(jax.nn.softplus(x))
|
216
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
217
|
-
|
218
|
-
# Test Hardswish
|
219
|
-
def test_hardswish_functionality(self):
|
220
|
-
"""Test Hardswish activation function."""
|
221
|
-
layer = nn.Hardswish()
|
222
|
-
|
223
|
-
x = jnp.array([-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 4.0])
|
224
|
-
output = layer(x)
|
225
|
-
|
226
|
-
# Check boundary conditions
|
227
|
-
np.testing.assert_allclose(output[1], 0.0, rtol=1e-5) # x=-3
|
228
|
-
np.testing.assert_allclose(output[5], 3.0, rtol=1e-5) # x=3
|
229
|
-
|
230
|
-
# Test ELU
|
231
|
-
def test_elu_functionality(self):
|
232
|
-
"""Test ELU activation function."""
|
233
|
-
layer = nn.ELU(alpha=1.0)
|
234
|
-
|
235
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
236
|
-
output = layer(x)
|
237
|
-
|
238
|
-
# Positive values should remain unchanged
|
239
|
-
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
240
|
-
# Check ELU formula for negative values
|
241
|
-
negative_mask = x <= 0
|
242
|
-
expected_negative = 1.0 * (jnp.exp(x[negative_mask]) - 1)
|
243
|
-
np.testing.assert_allclose(output[negative_mask], expected_negative, rtol=1e-5)
|
244
|
-
|
245
|
-
def test_elu_with_different_alpha(self):
|
246
|
-
"""Test ELU with different alpha values."""
|
247
|
-
alpha = 2.0
|
248
|
-
layer = nn.ELU(alpha=alpha)
|
249
|
-
|
250
|
-
x = jnp.array([-1.0])
|
251
|
-
output = layer(x)
|
252
|
-
expected = alpha * (jnp.exp(x) - 1)
|
253
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
254
|
-
|
255
|
-
# Test CELU
|
256
|
-
def test_celu_functionality(self):
|
257
|
-
"""Test CELU activation function."""
|
258
|
-
layer = nn.CELU(alpha=1.0)
|
259
|
-
|
260
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
261
|
-
output = layer(x)
|
262
|
-
|
263
|
-
# Positive values should remain unchanged
|
264
|
-
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
265
|
-
|
266
|
-
# Test SELU
|
267
|
-
def test_selu_functionality(self):
|
268
|
-
"""Test SELU activation function."""
|
269
|
-
layer = nn.SELU()
|
270
|
-
|
271
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
272
|
-
output = layer(x)
|
273
|
-
|
274
|
-
# Check that output is scaled ELU
|
275
|
-
# SELU has specific scale and alpha values
|
276
|
-
scale = 1.0507009873554804934193349852946
|
277
|
-
alpha = 1.6732632423543772848170429916717
|
278
|
-
|
279
|
-
positive_mask = x > 0
|
280
|
-
self.assertTrue(jnp.all(output[positive_mask] == scale * x[positive_mask]))
|
281
|
-
|
282
|
-
# Test GLU
|
283
|
-
def test_glu_functionality(self):
|
284
|
-
"""Test GLU activation function."""
|
285
|
-
layer = nn.GLU(dim=-1)
|
286
|
-
|
287
|
-
# GLU splits input in half along specified dimension
|
288
|
-
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
|
289
|
-
[5.0, 6.0, 7.0, 8.0]])
|
290
|
-
output = layer(x)
|
291
|
-
|
292
|
-
# Output should have half the size along the split dimension
|
293
|
-
self.assertEqual(output.shape, (2, 2))
|
294
|
-
|
295
|
-
def test_glu_different_dimensions(self):
|
296
|
-
"""Test GLU with different split dimensions."""
|
297
|
-
# Test splitting along different dimensions
|
298
|
-
x = jax.random.normal(self.key, (4, 6, 8))
|
299
|
-
|
300
|
-
layer_0 = nn.GLU(dim=0)
|
301
|
-
output_0 = layer_0(x)
|
302
|
-
self.assertEqual(output_0.shape, (2, 6, 8))
|
303
|
-
|
304
|
-
layer_1 = nn.GLU(dim=1)
|
305
|
-
output_1 = layer_1(x)
|
306
|
-
self.assertEqual(output_1.shape, (4, 3, 8))
|
307
|
-
|
308
|
-
layer_2 = nn.GLU(dim=2)
|
309
|
-
output_2 = layer_2(x)
|
310
|
-
self.assertEqual(output_2.shape, (4, 6, 4))
|
311
|
-
|
312
|
-
# Test GELU
|
313
|
-
def test_gelu_functionality(self):
|
314
|
-
"""Test GELU activation function."""
|
315
|
-
layer = nn.GELU(approximate=False)
|
316
|
-
|
317
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
318
|
-
output = layer(x)
|
319
|
-
|
320
|
-
# GELU should be smooth and differentiable everywhere
|
321
|
-
np.testing.assert_allclose(output[2], 0.0, rtol=1e-5) # GELU(0) ≈ 0
|
322
|
-
|
323
|
-
def test_gelu_approximate(self):
|
324
|
-
"""Test GELU with tanh approximation."""
|
325
|
-
layer_exact = nn.GELU(approximate=False)
|
326
|
-
layer_approx = nn.GELU(approximate=True)
|
327
|
-
|
328
|
-
x = jnp.array([-1.0, 0.0, 1.0])
|
329
|
-
output_exact = layer_exact(x)
|
330
|
-
output_approx = layer_approx(x)
|
331
|
-
|
332
|
-
# Approximation should be close but not exactly equal
|
333
|
-
np.testing.assert_allclose(output_exact, output_approx, rtol=1e-2)
|
334
|
-
|
335
|
-
# Test Hardshrink
|
336
|
-
def test_hardshrink_functionality(self):
|
337
|
-
"""Test Hardshrink activation function."""
|
338
|
-
lambd = 0.5
|
339
|
-
layer = nn.Hardshrink(lambd=lambd)
|
340
|
-
|
341
|
-
x = jnp.array([-1.0, -0.6, -0.5, -0.3, 0.0, 0.3, 0.5, 0.6, 1.0])
|
342
|
-
output = layer(x)
|
343
|
-
|
344
|
-
# Check each value according to hardshrink formula
|
345
|
-
expected = []
|
346
|
-
for xi in x:
|
347
|
-
if xi > lambd:
|
348
|
-
expected.append(xi)
|
349
|
-
elif xi < -lambd:
|
350
|
-
expected.append(xi)
|
351
|
-
else:
|
352
|
-
expected.append(0.0)
|
353
|
-
expected = jnp.array(expected)
|
354
|
-
|
355
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
356
|
-
|
357
|
-
# Test LeakyReLU
|
358
|
-
def test_leaky_relu_functionality(self):
|
359
|
-
"""Test LeakyReLU activation function."""
|
360
|
-
negative_slope = 0.01
|
361
|
-
layer = nn.LeakyReLU(negative_slope=negative_slope)
|
362
|
-
|
363
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
364
|
-
output = layer(x)
|
365
|
-
|
366
|
-
# Positive values should remain unchanged
|
367
|
-
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
368
|
-
# Negative values should be scaled
|
369
|
-
negative_mask = x < 0
|
370
|
-
expected_negative = negative_slope * x[negative_mask]
|
371
|
-
np.testing.assert_allclose(output[negative_mask], expected_negative, rtol=1e-5)
|
372
|
-
|
373
|
-
def test_leaky_relu_custom_slope(self):
|
374
|
-
"""Test LeakyReLU with custom negative slope."""
|
375
|
-
negative_slope = 0.2
|
376
|
-
layer = nn.LeakyReLU(negative_slope=negative_slope)
|
377
|
-
|
378
|
-
x = jnp.array([-5.0])
|
379
|
-
output = layer(x)
|
380
|
-
expected = negative_slope * x
|
381
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
382
|
-
|
383
|
-
# Test LogSigmoid
|
384
|
-
def test_log_sigmoid_functionality(self):
|
385
|
-
"""Test LogSigmoid activation function."""
|
386
|
-
layer = nn.LogSigmoid()
|
387
|
-
|
388
|
-
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
389
|
-
output = layer(x)
|
390
|
-
|
391
|
-
# LogSigmoid(x) = log(sigmoid(x))
|
392
|
-
expected = jnp.log(jax.nn.sigmoid(x))
|
393
|
-
np.testing.assert_allclose(output, expected, rtol=1e-2)
|
394
|
-
|
395
|
-
# Output should always be negative or zero
|
396
|
-
self.assertTrue(jnp.all(output <= 0.0))
|
397
|
-
|
398
|
-
# Test Softplus
|
399
|
-
def test_softplus_functionality(self):
|
400
|
-
"""Test Softplus activation function."""
|
401
|
-
layer = nn.Softplus()
|
402
|
-
|
403
|
-
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
404
|
-
output = layer(x)
|
405
|
-
|
406
|
-
# Softplus is a smooth approximation to ReLU
|
407
|
-
# Should always be positive
|
408
|
-
self.assertTrue(jnp.all(output > 0.0))
|
409
|
-
|
410
|
-
# For large positive values, should approximate x
|
411
|
-
np.testing.assert_allclose(output[-1], x[-1], rtol=1e-2)
|
412
|
-
|
413
|
-
# Test Softshrink
|
414
|
-
def test_softshrink_functionality(self):
|
415
|
-
"""Test Softshrink activation function."""
|
416
|
-
lambd = 0.5
|
417
|
-
layer = nn.Softshrink(lambd=lambd)
|
418
|
-
|
419
|
-
x = jnp.array([-1.0, -0.5, -0.3, 0.0, 0.3, 0.5, 1.0])
|
420
|
-
output = layer(x)
|
421
|
-
|
422
|
-
# Check the softshrink formula
|
423
|
-
for i in range(len(x)):
|
424
|
-
if x[i] > lambd:
|
425
|
-
expected = x[i] - lambd
|
426
|
-
elif x[i] < -lambd:
|
427
|
-
expected = x[i] + lambd
|
428
|
-
else:
|
429
|
-
expected = 0.0
|
430
|
-
np.testing.assert_allclose(output[i], expected, rtol=1e-5)
|
431
|
-
|
432
|
-
# Test PReLU
|
433
|
-
def test_prelu_functionality(self):
|
434
|
-
"""Test PReLU activation function."""
|
435
|
-
layer = nn.PReLU(num_parameters=1, init=0.25)
|
436
|
-
|
437
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
438
|
-
output = layer(x)
|
439
|
-
|
440
|
-
# Positive values should remain unchanged
|
441
|
-
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
442
|
-
# Negative values should be scaled by learned parameter
|
443
|
-
negative_mask = x < 0
|
444
|
-
# Check that negative values are scaled
|
445
|
-
self.assertTrue(jnp.all(output[negative_mask] != x[negative_mask]))
|
446
|
-
|
447
|
-
def test_prelu_multi_channel(self):
|
448
|
-
"""Test PReLU with multiple channels."""
|
449
|
-
num_channels = 3
|
450
|
-
layer = nn.PReLU(num_parameters=num_channels, init=0.25)
|
451
|
-
|
452
|
-
# Input shape: (batch, channels, height, width)
|
453
|
-
x = jax.random.normal(self.key, (2, 4, 4, num_channels))
|
454
|
-
output = layer(x)
|
455
|
-
|
456
|
-
self.assertEqual(output.shape, x.shape)
|
457
|
-
|
458
|
-
# Test Softsign
|
459
|
-
def test_softsign_functionality(self):
|
460
|
-
"""Test Softsign activation function."""
|
461
|
-
layer = nn.Softsign()
|
462
|
-
|
463
|
-
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
464
|
-
output = layer(x)
|
465
|
-
|
466
|
-
# Softsign(x) = x / (1 + |x|)
|
467
|
-
expected = x / (1 + jnp.abs(x))
|
468
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
469
|
-
|
470
|
-
# Output should be bounded between -1 and 1
|
471
|
-
self.assertTrue(jnp.all((output >= -1.0) & (output <= 1.0)))
|
472
|
-
|
473
|
-
# Test Tanhshrink
|
474
|
-
def test_tanhshrink_functionality(self):
|
475
|
-
"""Test Tanhshrink activation function."""
|
476
|
-
layer = nn.Tanhshrink()
|
477
|
-
|
478
|
-
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
479
|
-
output = layer(x)
|
480
|
-
|
481
|
-
# Tanhshrink(x) = x - tanh(x)
|
482
|
-
expected = x - jnp.tanh(x)
|
483
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
484
|
-
|
485
|
-
# Test Softmin
|
486
|
-
def test_softmin_functionality(self):
|
487
|
-
"""Test Softmin activation function."""
|
488
|
-
layer = nn.Softmin(dim=-1)
|
489
|
-
|
490
|
-
x = jnp.array([[1.0, 2.0, 3.0],
|
491
|
-
[4.0, 5.0, 6.0]])
|
492
|
-
output = layer(x)
|
493
|
-
|
494
|
-
# Softmin should sum to 1 along the specified dimension
|
495
|
-
sums = jnp.sum(output, axis=-1)
|
496
|
-
np.testing.assert_allclose(sums, jnp.ones_like(sums), rtol=1e-5)
|
497
|
-
|
498
|
-
# Higher values should have lower probabilities
|
499
|
-
self.assertTrue(jnp.all(output[:, 0] > output[:, 1]))
|
500
|
-
self.assertTrue(jnp.all(output[:, 1] > output[:, 2]))
|
501
|
-
|
502
|
-
# Test Softmax
|
503
|
-
def test_softmax_functionality(self):
|
504
|
-
"""Test Softmax activation function."""
|
505
|
-
layer = nn.Softmax(dim=-1)
|
506
|
-
|
507
|
-
x = jnp.array([[1.0, 2.0, 3.0],
|
508
|
-
[4.0, 5.0, 6.0]])
|
509
|
-
output = layer(x)
|
510
|
-
|
511
|
-
# Softmax should sum to 1 along the specified dimension
|
512
|
-
sums = jnp.sum(output, axis=-1)
|
513
|
-
np.testing.assert_allclose(sums, jnp.ones_like(sums), rtol=1e-5)
|
514
|
-
|
515
|
-
# Higher values should have higher probabilities
|
516
|
-
self.assertTrue(jnp.all(output[:, 2] > output[:, 1]))
|
517
|
-
self.assertTrue(jnp.all(output[:, 1] > output[:, 0]))
|
518
|
-
|
519
|
-
def test_softmax_numerical_stability(self):
|
520
|
-
"""Test Softmax numerical stability with large values."""
|
521
|
-
layer = nn.Softmax(dim=-1)
|
522
|
-
|
523
|
-
# Test with large values that could cause overflow
|
524
|
-
x = jnp.array([[1000.0, 1000.0, 1000.0]])
|
525
|
-
output = layer(x)
|
526
|
-
|
527
|
-
# Should still sum to 1 and have equal probabilities
|
528
|
-
np.testing.assert_allclose(jnp.sum(output), 1.0, rtol=1e-5)
|
529
|
-
np.testing.assert_allclose(output[0, 0], 1/3, rtol=1e-5)
|
530
|
-
|
531
|
-
# Test Softmax2d
|
532
|
-
def test_softmax2d_functionality(self):
|
533
|
-
"""Test Softmax2d activation function."""
|
534
|
-
layer = nn.Softmax2d()
|
535
|
-
|
536
|
-
# Input shape: (batch, channels, height, width)
|
537
|
-
x = jax.random.normal(self.key, (2, 3, 4, 5))
|
538
|
-
output = layer(x)
|
539
|
-
|
540
|
-
self.assertEqual(output.shape, x.shape)
|
541
|
-
|
542
|
-
# Should sum to 1 across channels for each spatial location
|
543
|
-
channel_sums = jnp.sum(output, axis=1)
|
544
|
-
np.testing.assert_allclose(channel_sums, jnp.ones_like(channel_sums), rtol=1e-5)
|
545
|
-
|
546
|
-
def test_softmax2d_3d_input(self):
|
547
|
-
"""Test Softmax2d with 3D input."""
|
548
|
-
layer = nn.Softmax2d()
|
549
|
-
|
550
|
-
# Input shape: (channels, height, width)
|
551
|
-
x = jax.random.normal(self.key, (3, 4, 5))
|
552
|
-
output = layer(x)
|
553
|
-
|
554
|
-
self.assertEqual(output.shape, x.shape)
|
555
|
-
|
556
|
-
# Test LogSoftmax
|
557
|
-
def test_log_softmax_functionality(self):
|
558
|
-
"""Test LogSoftmax activation function."""
|
559
|
-
layer = nn.LogSoftmax(dim=-1)
|
560
|
-
|
561
|
-
x = jnp.array([[1.0, 2.0, 3.0],
|
562
|
-
[4.0, 5.0, 6.0]])
|
563
|
-
output = layer(x)
|
564
|
-
|
565
|
-
# LogSoftmax = log(softmax(x))
|
566
|
-
softmax_output = jax.nn.softmax(x, axis=-1)
|
567
|
-
expected = jnp.log(softmax_output)
|
568
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
569
|
-
|
570
|
-
# Output should be all negative or zero
|
571
|
-
self.assertTrue(jnp.all(output <= 0.0))
|
572
|
-
|
573
|
-
def test_log_softmax_numerical_stability(self):
|
574
|
-
"""Test LogSoftmax numerical stability."""
|
575
|
-
layer = nn.LogSoftmax(dim=-1)
|
576
|
-
|
577
|
-
# Test with values that could cause numerical issues
|
578
|
-
x = jnp.array([[1000.0, 0.0, -1000.0]])
|
579
|
-
output = layer(x)
|
580
|
-
|
581
|
-
# Should not contain NaN or Inf
|
582
|
-
self.assertFalse(jnp.any(jnp.isnan(output)))
|
583
|
-
self.assertFalse(jnp.any(jnp.isinf(output)))
|
584
|
-
|
585
|
-
# Test Identity
|
586
|
-
def test_identity_functionality(self):
|
587
|
-
"""Test Identity activation function."""
|
588
|
-
layer = nn.Identity()
|
589
|
-
|
590
|
-
x = jax.random.normal(self.key, (3, 4, 5))
|
591
|
-
output = layer(x)
|
592
|
-
|
593
|
-
# Should be exactly equal to input
|
594
|
-
np.testing.assert_array_equal(output, x)
|
595
|
-
|
596
|
-
def test_identity_gradient(self):
|
597
|
-
"""Test Identity gradient flow."""
|
598
|
-
layer = nn.Identity()
|
599
|
-
|
600
|
-
x = jax.random.normal(self.key, (3, 4))
|
601
|
-
|
602
|
-
def loss_fn(x):
|
603
|
-
return jnp.sum(layer(x))
|
604
|
-
|
605
|
-
grad = jax.grad(loss_fn)(x)
|
606
|
-
|
607
|
-
# Gradient should be all ones
|
608
|
-
np.testing.assert_allclose(grad, jnp.ones_like(x), rtol=1e-5)
|
609
|
-
|
610
|
-
# Test SpikeBitwise
|
611
|
-
def test_spike_bitwise_add(self):
|
612
|
-
"""Test SpikeBitwise with ADD operation."""
|
613
|
-
layer = nn.SpikeBitwise(op='and')
|
614
|
-
|
615
|
-
x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
|
616
|
-
y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
|
617
|
-
output = layer(x, y)
|
618
|
-
|
619
|
-
expected = jnp.logical_and(x, y)
|
620
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
621
|
-
|
622
|
-
def test_spike_bitwise_and(self):
|
623
|
-
"""Test SpikeBitwise with AND operation."""
|
624
|
-
layer = nn.SpikeBitwise(op='and')
|
625
|
-
|
626
|
-
x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
|
627
|
-
y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
|
628
|
-
output = layer(x, y)
|
629
|
-
|
630
|
-
expected = x * y
|
631
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
632
|
-
|
633
|
-
def test_spike_bitwise_iand(self):
|
634
|
-
"""Test SpikeBitwise with IAND operation."""
|
635
|
-
layer = nn.SpikeBitwise(op='iand')
|
636
|
-
|
637
|
-
x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
|
638
|
-
y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
|
639
|
-
output = layer(x, y)
|
640
|
-
|
641
|
-
expected = (1 - x) * y
|
642
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
643
|
-
|
644
|
-
def test_spike_bitwise_or(self):
|
645
|
-
"""Test SpikeBitwise with OR operation."""
|
646
|
-
layer = nn.SpikeBitwise(op='or')
|
647
|
-
|
648
|
-
x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
|
649
|
-
y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
|
650
|
-
output = layer(x, y)
|
651
|
-
|
652
|
-
expected = (x + y) - (x * y)
|
653
|
-
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
654
|
-
|
655
|
-
|
656
|
-
class TestEdgeCases(parameterized.TestCase):
|
657
|
-
"""Test edge cases and boundary conditions."""
|
658
|
-
|
659
|
-
def test_zero_input(self):
|
660
|
-
"""Test all activations with zero input."""
|
661
|
-
x = jnp.zeros((3, 4))
|
662
|
-
|
663
|
-
activations = [
|
664
|
-
nn.ReLU(),
|
665
|
-
nn.Sigmoid(),
|
666
|
-
nn.Tanh(),
|
667
|
-
nn.SiLU(),
|
668
|
-
nn.ELU(),
|
669
|
-
nn.GELU(),
|
670
|
-
nn.Softplus(),
|
671
|
-
nn.Softsign(),
|
672
|
-
]
|
673
|
-
|
674
|
-
for activation in activations:
|
675
|
-
output = activation(x)
|
676
|
-
self.assertEqual(output.shape, x.shape)
|
677
|
-
self.assertFalse(jnp.any(jnp.isnan(output)))
|
678
|
-
|
679
|
-
def test_large_positive_input(self):
|
680
|
-
"""Test activations with very large positive values."""
|
681
|
-
x = jnp.ones((2, 3)) * 1000.0
|
682
|
-
|
683
|
-
activations = [
|
684
|
-
nn.ReLU(),
|
685
|
-
nn.Sigmoid(),
|
686
|
-
nn.Tanh(),
|
687
|
-
nn.Hardsigmoid(),
|
688
|
-
nn.Hardswish(),
|
689
|
-
]
|
690
|
-
|
691
|
-
for activation in activations:
|
692
|
-
output = activation(x)
|
693
|
-
self.assertFalse(jnp.any(jnp.isnan(output)))
|
694
|
-
self.assertFalse(jnp.any(jnp.isinf(output)))
|
695
|
-
|
696
|
-
def test_large_negative_input(self):
|
697
|
-
"""Test activations with very large negative values."""
|
698
|
-
x = jnp.ones((2, 3)) * -1000.0
|
699
|
-
|
700
|
-
activations = [
|
701
|
-
nn.ReLU(),
|
702
|
-
nn.Sigmoid(),
|
703
|
-
nn.Tanh(),
|
704
|
-
nn.Hardsigmoid(),
|
705
|
-
nn.Hardswish(),
|
706
|
-
]
|
707
|
-
|
708
|
-
for activation in activations:
|
709
|
-
output = activation(x)
|
710
|
-
self.assertFalse(jnp.any(jnp.isnan(output)))
|
711
|
-
self.assertFalse(jnp.any(jnp.isinf(output)))
|
712
|
-
|
713
|
-
def test_nan_propagation(self):
|
714
|
-
"""Test that NaN inputs produce NaN outputs (where appropriate)."""
|
715
|
-
x = jnp.array([jnp.nan, 1.0, 2.0])
|
716
|
-
|
717
|
-
activations = [
|
718
|
-
nn.ReLU(),
|
719
|
-
nn.Sigmoid(),
|
720
|
-
nn.Tanh(),
|
721
|
-
]
|
722
|
-
|
723
|
-
for activation in activations:
|
724
|
-
output = activation(x)
|
725
|
-
self.assertTrue(jnp.isnan(output[0]))
|
726
|
-
|
727
|
-
def test_inf_handling(self):
|
728
|
-
"""Test handling of infinite values."""
|
729
|
-
x = jnp.array([jnp.inf, -jnp.inf, 1.0])
|
730
|
-
|
731
|
-
# ReLU should handle inf properly
|
732
|
-
relu = nn.ReLU()
|
733
|
-
output = relu(x)
|
734
|
-
self.assertEqual(output[0], jnp.inf)
|
735
|
-
self.assertEqual(output[1], 0.0)
|
736
|
-
|
737
|
-
# Sigmoid should saturate
|
738
|
-
sigmoid = nn.Sigmoid()
|
739
|
-
output = sigmoid(x)
|
740
|
-
np.testing.assert_allclose(output[0], 1.0, rtol=1e-5)
|
741
|
-
np.testing.assert_allclose(output[1], 0.0, rtol=1e-5)
|
742
|
-
|
743
|
-
|
744
|
-
class TestBatchProcessing(parameterized.TestCase):
|
745
|
-
"""Test batch processing capabilities."""
|
746
|
-
|
747
|
-
@parameterized.parameters(
|
748
|
-
(nn.ReLU(), ),
|
749
|
-
(nn.Sigmoid(), ),
|
750
|
-
(nn.Tanh(), ),
|
751
|
-
(nn.GELU(), ),
|
752
|
-
(nn.SiLU(), ),
|
753
|
-
(nn.ELU(), ),
|
754
|
-
)
|
755
|
-
def test_batch_consistency(self, activation):
|
756
|
-
"""Test that batch processing gives same results as individual processing."""
|
757
|
-
# Process as batch
|
758
|
-
batch_input = jax.random.normal(jax.random.PRNGKey(42), (5, 10))
|
759
|
-
batch_output = activation(batch_input)
|
760
|
-
|
761
|
-
# Process individually
|
762
|
-
individual_outputs = []
|
763
|
-
for i in range(5):
|
764
|
-
individual_output = activation(batch_input[i])
|
765
|
-
individual_outputs.append(individual_output)
|
766
|
-
individual_outputs = jnp.stack(individual_outputs)
|
767
|
-
|
768
|
-
np.testing.assert_allclose(batch_output, individual_outputs, rtol=1e-5)
|
769
|
-
|
770
|
-
def test_different_batch_sizes(self):
|
771
|
-
"""Test activations with different batch sizes."""
|
772
|
-
activation = nn.ReLU()
|
773
|
-
|
774
|
-
for batch_size in [1, 10, 100]:
|
775
|
-
x = jax.random.normal(jax.random.PRNGKey(42), (batch_size, 20))
|
776
|
-
output = activation(x)
|
777
|
-
self.assertEqual(output.shape[0], batch_size)
|
778
|
-
|
779
|
-
|
780
|
-
class TestMemoryAndPerformance(parameterized.TestCase):
|
781
|
-
"""Test memory and performance characteristics."""
|
782
|
-
|
783
|
-
def test_in_place_operations(self):
|
784
|
-
"""Test that activations don't modify input in-place."""
|
785
|
-
x_original = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
|
786
|
-
x = x_original.copy()
|
787
|
-
|
788
|
-
activations = [
|
789
|
-
nn.ReLU(),
|
790
|
-
nn.Sigmoid(),
|
791
|
-
nn.Tanh(),
|
792
|
-
]
|
793
|
-
|
794
|
-
for activation in activations:
|
795
|
-
output = activation(x)
|
796
|
-
np.testing.assert_array_equal(x, x_original)
|
797
|
-
|
798
|
-
def test_jit_compilation(self):
|
799
|
-
"""Test that activations work with JIT compilation."""
|
800
|
-
@jax.jit
|
801
|
-
def forward(x):
|
802
|
-
relu = nn.ReLU()
|
803
|
-
return relu(x)
|
804
|
-
|
805
|
-
x = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
|
806
|
-
output = forward(x)
|
807
|
-
|
808
|
-
# Should not raise any errors and produce valid output
|
809
|
-
self.assertEqual(output.shape, x.shape)
|
810
|
-
|
811
|
-
@parameterized.parameters(
|
812
|
-
(nn.ReLU(), ),
|
813
|
-
(nn.Sigmoid(), ),
|
814
|
-
(nn.Tanh(), ),
|
815
|
-
)
|
816
|
-
def test_vmap_compatibility(self, activation):
|
817
|
-
"""Test compatibility with vmap."""
|
818
|
-
def single_forward(x):
|
819
|
-
return activation(x)
|
820
|
-
|
821
|
-
batch_forward = jax.vmap(single_forward)
|
822
|
-
|
823
|
-
x = jax.random.normal(jax.random.PRNGKey(42), (5, 10, 20))
|
824
|
-
output = batch_forward(x)
|
825
|
-
|
826
|
-
self.assertEqual(output.shape, x.shape)
|
827
|
-
|
828
|
-
|
829
|
-
if __name__ == '__main__':
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
from absl.testing import absltest
|
18
|
+
from absl.testing import parameterized
|
19
|
+
|
20
|
+
import jax
|
21
|
+
import jax.numpy as jnp
|
22
|
+
import numpy as np
|
23
|
+
|
24
|
+
import brainstate
|
25
|
+
import brainstate.nn as nn
|
26
|
+
|
27
|
+
|
28
|
+
class TestActivationFunctions(parameterized.TestCase):
|
29
|
+
"""Comprehensive tests for activation functions."""
|
30
|
+
|
31
|
+
def setUp(self):
|
32
|
+
"""Set up test fixtures."""
|
33
|
+
self.seed = 42
|
34
|
+
self.key = jax.random.PRNGKey(self.seed)
|
35
|
+
|
36
|
+
def _check_shape_preservation(self, layer, input_shape):
|
37
|
+
"""Helper to check if layer preserves input shape."""
|
38
|
+
x = jax.random.normal(self.key, input_shape)
|
39
|
+
output = layer(x)
|
40
|
+
self.assertEqual(output.shape, x.shape)
|
41
|
+
|
42
|
+
def _check_gradient_flow(self, layer, input_shape):
|
43
|
+
"""Helper to check if gradients can flow through the layer."""
|
44
|
+
x = jax.random.normal(self.key, input_shape)
|
45
|
+
|
46
|
+
def loss_fn(x):
|
47
|
+
return jnp.sum(layer(x))
|
48
|
+
|
49
|
+
grad = jax.grad(loss_fn)(x)
|
50
|
+
self.assertEqual(grad.shape, x.shape)
|
51
|
+
# Check that gradients are not all zeros (for most activations)
|
52
|
+
if not isinstance(layer, (nn.Threshold, nn.Hardtanh, nn.ReLU6)):
|
53
|
+
self.assertFalse(jnp.allclose(grad, 0.0))
|
54
|
+
|
55
|
+
# Test Threshold
|
56
|
+
def test_threshold_functionality(self):
|
57
|
+
"""Test Threshold activation function."""
|
58
|
+
layer = nn.Threshold(threshold=0.5, value=0.0)
|
59
|
+
|
60
|
+
# Test with values above and below threshold
|
61
|
+
x = jnp.array([-1.0, 0.0, 0.3, 0.7, 1.0])
|
62
|
+
output = layer(x)
|
63
|
+
expected = jnp.array([0.0, 0.0, 0.0, 0.7, 1.0])
|
64
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
65
|
+
|
66
|
+
@parameterized.parameters(
|
67
|
+
((2,), ),
|
68
|
+
((3, 4), ),
|
69
|
+
((2, 3, 4), ),
|
70
|
+
((2, 3, 4, 5), ),
|
71
|
+
)
|
72
|
+
def test_threshold_shapes(self, shape):
|
73
|
+
"""Test Threshold with different input shapes."""
|
74
|
+
layer = nn.Threshold(threshold=0.1, value=20)
|
75
|
+
self._check_shape_preservation(layer, shape)
|
76
|
+
|
77
|
+
# Test ReLU
|
78
|
+
def test_relu_functionality(self):
|
79
|
+
"""Test ReLU activation function."""
|
80
|
+
layer = nn.ReLU()
|
81
|
+
|
82
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
83
|
+
output = layer(x)
|
84
|
+
expected = jnp.array([0.0, 0.0, 0.0, 1.0, 2.0])
|
85
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
86
|
+
|
87
|
+
@parameterized.parameters(
|
88
|
+
((10,), ),
|
89
|
+
((5, 10), ),
|
90
|
+
((3, 5, 10), ),
|
91
|
+
)
|
92
|
+
def test_relu_shapes_and_gradients(self, shape):
|
93
|
+
"""Test ReLU shapes and gradients."""
|
94
|
+
layer = nn.ReLU()
|
95
|
+
self._check_shape_preservation(layer, shape)
|
96
|
+
self._check_gradient_flow(layer, shape)
|
97
|
+
|
98
|
+
# Test RReLU
|
99
|
+
def test_rrelu_functionality(self):
|
100
|
+
"""Test RReLU activation function."""
|
101
|
+
layer = nn.RReLU(lower=0.1, upper=0.3)
|
102
|
+
|
103
|
+
# Test positive and negative values
|
104
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
105
|
+
output = layer(x)
|
106
|
+
|
107
|
+
# Positive values should remain unchanged
|
108
|
+
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
109
|
+
# Negative values should be scaled by a factor in [lower, upper]
|
110
|
+
negative_mask = x < 0
|
111
|
+
if jnp.any(negative_mask):
|
112
|
+
scaled = output[negative_mask] / x[negative_mask]
|
113
|
+
self.assertTrue(jnp.all((scaled >= 0.1) & (scaled <= 0.3)))
|
114
|
+
|
115
|
+
# Test Hardtanh
|
116
|
+
def test_hardtanh_functionality(self):
|
117
|
+
"""Test Hardtanh activation function."""
|
118
|
+
layer = nn.Hardtanh(min_val=-1.0, max_val=1.0)
|
119
|
+
|
120
|
+
x = jnp.array([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
|
121
|
+
output = layer(x)
|
122
|
+
expected = jnp.array([-1.0, -1.0, -0.5, 0.0, 0.5, 1.0, 1.0])
|
123
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
124
|
+
|
125
|
+
def test_hardtanh_custom_bounds(self):
|
126
|
+
"""Test Hardtanh with custom bounds."""
|
127
|
+
layer = nn.Hardtanh(min_val=-2.0, max_val=3.0)
|
128
|
+
|
129
|
+
x = jnp.array([-3.0, -2.0, 0.0, 3.0, 4.0])
|
130
|
+
output = layer(x)
|
131
|
+
expected = jnp.array([-2.0, -2.0, 0.0, 3.0, 3.0])
|
132
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
133
|
+
|
134
|
+
# Test ReLU6
|
135
|
+
def test_relu6_functionality(self):
|
136
|
+
"""Test ReLU6 activation function."""
|
137
|
+
layer = nn.ReLU6()
|
138
|
+
|
139
|
+
x = jnp.array([-2.0, 0.0, 3.0, 6.0, 8.0])
|
140
|
+
output = layer(x)
|
141
|
+
expected = jnp.array([0.0, 0.0, 3.0, 6.0, 6.0])
|
142
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
143
|
+
|
144
|
+
# Test Sigmoid
|
145
|
+
def test_sigmoid_functionality(self):
|
146
|
+
"""Test Sigmoid activation function."""
|
147
|
+
layer = nn.Sigmoid()
|
148
|
+
|
149
|
+
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
150
|
+
output = layer(x)
|
151
|
+
|
152
|
+
# Check sigmoid properties
|
153
|
+
self.assertTrue(jnp.all((output >= 0.0) & (output <= 1.0)))
|
154
|
+
np.testing.assert_allclose(output[2], 0.5, rtol=1e-5) # sigmoid(0) = 0.5
|
155
|
+
|
156
|
+
@parameterized.parameters(
|
157
|
+
((10,), ),
|
158
|
+
((5, 10), ),
|
159
|
+
((3, 5, 10), ),
|
160
|
+
)
|
161
|
+
def test_sigmoid_shapes_and_gradients(self, shape):
|
162
|
+
"""Test Sigmoid shapes and gradients."""
|
163
|
+
layer = nn.Sigmoid()
|
164
|
+
self._check_shape_preservation(layer, shape)
|
165
|
+
self._check_gradient_flow(layer, shape)
|
166
|
+
|
167
|
+
# Test Hardsigmoid
|
168
|
+
def test_hardsigmoid_functionality(self):
|
169
|
+
"""Test Hardsigmoid activation function."""
|
170
|
+
layer = nn.Hardsigmoid()
|
171
|
+
|
172
|
+
x = jnp.array([-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 4.0])
|
173
|
+
output = layer(x)
|
174
|
+
|
175
|
+
# Check bounds
|
176
|
+
self.assertTrue(jnp.all((output >= 0.0) & (output <= 1.0)))
|
177
|
+
# Check specific values
|
178
|
+
np.testing.assert_allclose(output[1], 0.0, rtol=1e-5) # x=-3
|
179
|
+
np.testing.assert_allclose(output[3], 0.5, rtol=1e-5) # x=0
|
180
|
+
np.testing.assert_allclose(output[5], 1.0, rtol=1e-5) # x=3
|
181
|
+
|
182
|
+
# Test Tanh
|
183
|
+
def test_tanh_functionality(self):
|
184
|
+
"""Test Tanh activation function."""
|
185
|
+
layer = nn.Tanh()
|
186
|
+
|
187
|
+
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
188
|
+
output = layer(x)
|
189
|
+
|
190
|
+
# Check tanh properties
|
191
|
+
self.assertTrue(jnp.all((output >= -1.0) & (output <= 1.0)))
|
192
|
+
np.testing.assert_allclose(output[2], 0.0, rtol=1e-5) # tanh(0) = 0
|
193
|
+
|
194
|
+
# Test SiLU (Swish)
|
195
|
+
def test_silu_functionality(self):
|
196
|
+
"""Test SiLU activation function."""
|
197
|
+
layer = nn.SiLU()
|
198
|
+
|
199
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
200
|
+
output = layer(x)
|
201
|
+
|
202
|
+
# SiLU(x) = x * sigmoid(x)
|
203
|
+
expected = x * jax.nn.sigmoid(x)
|
204
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
205
|
+
|
206
|
+
# Test Mish
|
207
|
+
def test_mish_functionality(self):
|
208
|
+
"""Test Mish activation function."""
|
209
|
+
layer = nn.Mish()
|
210
|
+
|
211
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
212
|
+
output = layer(x)
|
213
|
+
|
214
|
+
# Mish(x) = x * tanh(softplus(x))
|
215
|
+
expected = x * jnp.tanh(jax.nn.softplus(x))
|
216
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
217
|
+
|
218
|
+
# Test Hardswish
|
219
|
+
def test_hardswish_functionality(self):
|
220
|
+
"""Test Hardswish activation function."""
|
221
|
+
layer = nn.Hardswish()
|
222
|
+
|
223
|
+
x = jnp.array([-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 4.0])
|
224
|
+
output = layer(x)
|
225
|
+
|
226
|
+
# Check boundary conditions
|
227
|
+
np.testing.assert_allclose(output[1], 0.0, rtol=1e-5) # x=-3
|
228
|
+
np.testing.assert_allclose(output[5], 3.0, rtol=1e-5) # x=3
|
229
|
+
|
230
|
+
# Test ELU
|
231
|
+
def test_elu_functionality(self):
|
232
|
+
"""Test ELU activation function."""
|
233
|
+
layer = nn.ELU(alpha=1.0)
|
234
|
+
|
235
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
236
|
+
output = layer(x)
|
237
|
+
|
238
|
+
# Positive values should remain unchanged
|
239
|
+
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
240
|
+
# Check ELU formula for negative values
|
241
|
+
negative_mask = x <= 0
|
242
|
+
expected_negative = 1.0 * (jnp.exp(x[negative_mask]) - 1)
|
243
|
+
np.testing.assert_allclose(output[negative_mask], expected_negative, rtol=1e-5)
|
244
|
+
|
245
|
+
def test_elu_with_different_alpha(self):
|
246
|
+
"""Test ELU with different alpha values."""
|
247
|
+
alpha = 2.0
|
248
|
+
layer = nn.ELU(alpha=alpha)
|
249
|
+
|
250
|
+
x = jnp.array([-1.0])
|
251
|
+
output = layer(x)
|
252
|
+
expected = alpha * (jnp.exp(x) - 1)
|
253
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
254
|
+
|
255
|
+
# Test CELU
|
256
|
+
def test_celu_functionality(self):
|
257
|
+
"""Test CELU activation function."""
|
258
|
+
layer = nn.CELU(alpha=1.0)
|
259
|
+
|
260
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
261
|
+
output = layer(x)
|
262
|
+
|
263
|
+
# Positive values should remain unchanged
|
264
|
+
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
265
|
+
|
266
|
+
# Test SELU
|
267
|
+
def test_selu_functionality(self):
|
268
|
+
"""Test SELU activation function."""
|
269
|
+
layer = nn.SELU()
|
270
|
+
|
271
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
272
|
+
output = layer(x)
|
273
|
+
|
274
|
+
# Check that output is scaled ELU
|
275
|
+
# SELU has specific scale and alpha values
|
276
|
+
scale = 1.0507009873554804934193349852946
|
277
|
+
alpha = 1.6732632423543772848170429916717
|
278
|
+
|
279
|
+
positive_mask = x > 0
|
280
|
+
self.assertTrue(jnp.all(output[positive_mask] == scale * x[positive_mask]))
|
281
|
+
|
282
|
+
# Test GLU
|
283
|
+
def test_glu_functionality(self):
|
284
|
+
"""Test GLU activation function."""
|
285
|
+
layer = nn.GLU(dim=-1)
|
286
|
+
|
287
|
+
# GLU splits input in half along specified dimension
|
288
|
+
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
|
289
|
+
[5.0, 6.0, 7.0, 8.0]])
|
290
|
+
output = layer(x)
|
291
|
+
|
292
|
+
# Output should have half the size along the split dimension
|
293
|
+
self.assertEqual(output.shape, (2, 2))
|
294
|
+
|
295
|
+
def test_glu_different_dimensions(self):
|
296
|
+
"""Test GLU with different split dimensions."""
|
297
|
+
# Test splitting along different dimensions
|
298
|
+
x = jax.random.normal(self.key, (4, 6, 8))
|
299
|
+
|
300
|
+
layer_0 = nn.GLU(dim=0)
|
301
|
+
output_0 = layer_0(x)
|
302
|
+
self.assertEqual(output_0.shape, (2, 6, 8))
|
303
|
+
|
304
|
+
layer_1 = nn.GLU(dim=1)
|
305
|
+
output_1 = layer_1(x)
|
306
|
+
self.assertEqual(output_1.shape, (4, 3, 8))
|
307
|
+
|
308
|
+
layer_2 = nn.GLU(dim=2)
|
309
|
+
output_2 = layer_2(x)
|
310
|
+
self.assertEqual(output_2.shape, (4, 6, 4))
|
311
|
+
|
312
|
+
# Test GELU
|
313
|
+
def test_gelu_functionality(self):
|
314
|
+
"""Test GELU activation function."""
|
315
|
+
layer = nn.GELU(approximate=False)
|
316
|
+
|
317
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
318
|
+
output = layer(x)
|
319
|
+
|
320
|
+
# GELU should be smooth and differentiable everywhere
|
321
|
+
np.testing.assert_allclose(output[2], 0.0, rtol=1e-5) # GELU(0) ≈ 0
|
322
|
+
|
323
|
+
def test_gelu_approximate(self):
|
324
|
+
"""Test GELU with tanh approximation."""
|
325
|
+
layer_exact = nn.GELU(approximate=False)
|
326
|
+
layer_approx = nn.GELU(approximate=True)
|
327
|
+
|
328
|
+
x = jnp.array([-1.0, 0.0, 1.0])
|
329
|
+
output_exact = layer_exact(x)
|
330
|
+
output_approx = layer_approx(x)
|
331
|
+
|
332
|
+
# Approximation should be close but not exactly equal
|
333
|
+
np.testing.assert_allclose(output_exact, output_approx, rtol=1e-2)
|
334
|
+
|
335
|
+
# Test Hardshrink
|
336
|
+
def test_hardshrink_functionality(self):
|
337
|
+
"""Test Hardshrink activation function."""
|
338
|
+
lambd = 0.5
|
339
|
+
layer = nn.Hardshrink(lambd=lambd)
|
340
|
+
|
341
|
+
x = jnp.array([-1.0, -0.6, -0.5, -0.3, 0.0, 0.3, 0.5, 0.6, 1.0])
|
342
|
+
output = layer(x)
|
343
|
+
|
344
|
+
# Check each value according to hardshrink formula
|
345
|
+
expected = []
|
346
|
+
for xi in x:
|
347
|
+
if xi > lambd:
|
348
|
+
expected.append(xi)
|
349
|
+
elif xi < -lambd:
|
350
|
+
expected.append(xi)
|
351
|
+
else:
|
352
|
+
expected.append(0.0)
|
353
|
+
expected = jnp.array(expected)
|
354
|
+
|
355
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
356
|
+
|
357
|
+
# Test LeakyReLU
|
358
|
+
def test_leaky_relu_functionality(self):
|
359
|
+
"""Test LeakyReLU activation function."""
|
360
|
+
negative_slope = 0.01
|
361
|
+
layer = nn.LeakyReLU(negative_slope=negative_slope)
|
362
|
+
|
363
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
364
|
+
output = layer(x)
|
365
|
+
|
366
|
+
# Positive values should remain unchanged
|
367
|
+
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
368
|
+
# Negative values should be scaled
|
369
|
+
negative_mask = x < 0
|
370
|
+
expected_negative = negative_slope * x[negative_mask]
|
371
|
+
np.testing.assert_allclose(output[negative_mask], expected_negative, rtol=1e-5)
|
372
|
+
|
373
|
+
def test_leaky_relu_custom_slope(self):
|
374
|
+
"""Test LeakyReLU with custom negative slope."""
|
375
|
+
negative_slope = 0.2
|
376
|
+
layer = nn.LeakyReLU(negative_slope=negative_slope)
|
377
|
+
|
378
|
+
x = jnp.array([-5.0])
|
379
|
+
output = layer(x)
|
380
|
+
expected = negative_slope * x
|
381
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
382
|
+
|
383
|
+
# Test LogSigmoid
|
384
|
+
def test_log_sigmoid_functionality(self):
|
385
|
+
"""Test LogSigmoid activation function."""
|
386
|
+
layer = nn.LogSigmoid()
|
387
|
+
|
388
|
+
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
389
|
+
output = layer(x)
|
390
|
+
|
391
|
+
# LogSigmoid(x) = log(sigmoid(x))
|
392
|
+
expected = jnp.log(jax.nn.sigmoid(x))
|
393
|
+
np.testing.assert_allclose(output, expected, rtol=1e-2)
|
394
|
+
|
395
|
+
# Output should always be negative or zero
|
396
|
+
self.assertTrue(jnp.all(output <= 0.0))
|
397
|
+
|
398
|
+
# Test Softplus
|
399
|
+
def test_softplus_functionality(self):
|
400
|
+
"""Test Softplus activation function."""
|
401
|
+
layer = nn.Softplus()
|
402
|
+
|
403
|
+
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
404
|
+
output = layer(x)
|
405
|
+
|
406
|
+
# Softplus is a smooth approximation to ReLU
|
407
|
+
# Should always be positive
|
408
|
+
self.assertTrue(jnp.all(output > 0.0))
|
409
|
+
|
410
|
+
# For large positive values, should approximate x
|
411
|
+
np.testing.assert_allclose(output[-1], x[-1], rtol=1e-2)
|
412
|
+
|
413
|
+
# Test Softshrink
|
414
|
+
def test_softshrink_functionality(self):
|
415
|
+
"""Test Softshrink activation function."""
|
416
|
+
lambd = 0.5
|
417
|
+
layer = nn.Softshrink(lambd=lambd)
|
418
|
+
|
419
|
+
x = jnp.array([-1.0, -0.5, -0.3, 0.0, 0.3, 0.5, 1.0])
|
420
|
+
output = layer(x)
|
421
|
+
|
422
|
+
# Check the softshrink formula
|
423
|
+
for i in range(len(x)):
|
424
|
+
if x[i] > lambd:
|
425
|
+
expected = x[i] - lambd
|
426
|
+
elif x[i] < -lambd:
|
427
|
+
expected = x[i] + lambd
|
428
|
+
else:
|
429
|
+
expected = 0.0
|
430
|
+
np.testing.assert_allclose(output[i], expected, rtol=1e-5)
|
431
|
+
|
432
|
+
# Test PReLU
|
433
|
+
def test_prelu_functionality(self):
|
434
|
+
"""Test PReLU activation function."""
|
435
|
+
layer = nn.PReLU(num_parameters=1, init=0.25)
|
436
|
+
|
437
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
438
|
+
output = layer(x)
|
439
|
+
|
440
|
+
# Positive values should remain unchanged
|
441
|
+
self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
|
442
|
+
# Negative values should be scaled by learned parameter
|
443
|
+
negative_mask = x < 0
|
444
|
+
# Check that negative values are scaled
|
445
|
+
self.assertTrue(jnp.all(output[negative_mask] != x[negative_mask]))
|
446
|
+
|
447
|
+
def test_prelu_multi_channel(self):
|
448
|
+
"""Test PReLU with multiple channels."""
|
449
|
+
num_channels = 3
|
450
|
+
layer = nn.PReLU(num_parameters=num_channels, init=0.25)
|
451
|
+
|
452
|
+
# Input shape: (batch, channels, height, width)
|
453
|
+
x = jax.random.normal(self.key, (2, 4, 4, num_channels))
|
454
|
+
output = layer(x)
|
455
|
+
|
456
|
+
self.assertEqual(output.shape, x.shape)
|
457
|
+
|
458
|
+
# Test Softsign
|
459
|
+
def test_softsign_functionality(self):
|
460
|
+
"""Test Softsign activation function."""
|
461
|
+
layer = nn.Softsign()
|
462
|
+
|
463
|
+
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
|
464
|
+
output = layer(x)
|
465
|
+
|
466
|
+
# Softsign(x) = x / (1 + |x|)
|
467
|
+
expected = x / (1 + jnp.abs(x))
|
468
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
469
|
+
|
470
|
+
# Output should be bounded between -1 and 1
|
471
|
+
self.assertTrue(jnp.all((output >= -1.0) & (output <= 1.0)))
|
472
|
+
|
473
|
+
# Test Tanhshrink
|
474
|
+
def test_tanhshrink_functionality(self):
|
475
|
+
"""Test Tanhshrink activation function."""
|
476
|
+
layer = nn.Tanhshrink()
|
477
|
+
|
478
|
+
x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
|
479
|
+
output = layer(x)
|
480
|
+
|
481
|
+
# Tanhshrink(x) = x - tanh(x)
|
482
|
+
expected = x - jnp.tanh(x)
|
483
|
+
np.testing.assert_allclose(output, expected, rtol=1e-5)
|
484
|
+
|
485
|
+
# Test Softmin
|
486
|
+
def test_softmin_functionality(self):
|
487
|
+
"""Test Softmin activation function."""
|
488
|
+
layer = nn.Softmin(dim=-1)
|
489
|
+
|
490
|
+
x = jnp.array([[1.0, 2.0, 3.0],
|
491
|
+
[4.0, 5.0, 6.0]])
|
492
|
+
output = layer(x)
|
493
|
+
|
494
|
+
# Softmin should sum to 1 along the specified dimension
|
495
|
+
sums = jnp.sum(output, axis=-1)
|
496
|
+
np.testing.assert_allclose(sums, jnp.ones_like(sums), rtol=1e-5)
|
497
|
+
|
498
|
+
# Higher values should have lower probabilities
|
499
|
+
self.assertTrue(jnp.all(output[:, 0] > output[:, 1]))
|
500
|
+
self.assertTrue(jnp.all(output[:, 1] > output[:, 2]))
|
501
|
+
|
502
|
+
# Test Softmax
|
503
|
+
def test_softmax_functionality(self):
|
504
|
+
"""Test Softmax activation function."""
|
505
|
+
layer = nn.Softmax(dim=-1)
|
506
|
+
|
507
|
+
x = jnp.array([[1.0, 2.0, 3.0],
|
508
|
+
[4.0, 5.0, 6.0]])
|
509
|
+
output = layer(x)
|
510
|
+
|
511
|
+
# Softmax should sum to 1 along the specified dimension
|
512
|
+
sums = jnp.sum(output, axis=-1)
|
513
|
+
np.testing.assert_allclose(sums, jnp.ones_like(sums), rtol=1e-5)
|
514
|
+
|
515
|
+
# Higher values should have higher probabilities
|
516
|
+
self.assertTrue(jnp.all(output[:, 2] > output[:, 1]))
|
517
|
+
self.assertTrue(jnp.all(output[:, 1] > output[:, 0]))
|
518
|
+
|
519
|
+
def test_softmax_numerical_stability(self):
|
520
|
+
"""Test Softmax numerical stability with large values."""
|
521
|
+
layer = nn.Softmax(dim=-1)
|
522
|
+
|
523
|
+
# Test with large values that could cause overflow
|
524
|
+
x = jnp.array([[1000.0, 1000.0, 1000.0]])
|
525
|
+
output = layer(x)
|
526
|
+
|
527
|
+
# Should still sum to 1 and have equal probabilities
|
528
|
+
np.testing.assert_allclose(jnp.sum(output), 1.0, rtol=1e-5)
|
529
|
+
np.testing.assert_allclose(output[0, 0], 1/3, rtol=1e-5)
|
530
|
+
|
531
|
+
# Test Softmax2d
|
532
|
+
def test_softmax2d_functionality(self):
|
533
|
+
"""Test Softmax2d activation function."""
|
534
|
+
layer = nn.Softmax2d()
|
535
|
+
|
536
|
+
# Input shape: (batch, channels, height, width)
|
537
|
+
x = jax.random.normal(self.key, (2, 3, 4, 5))
|
538
|
+
output = layer(x)
|
539
|
+
|
540
|
+
self.assertEqual(output.shape, x.shape)
|
541
|
+
|
542
|
+
# Should sum to 1 across channels for each spatial location
|
543
|
+
channel_sums = jnp.sum(output, axis=1)
|
544
|
+
np.testing.assert_allclose(channel_sums, jnp.ones_like(channel_sums), rtol=1e-5)
|
545
|
+
|
546
|
+
def test_softmax2d_3d_input(self):
|
547
|
+
"""Test Softmax2d with 3D input."""
|
548
|
+
layer = nn.Softmax2d()
|
549
|
+
|
550
|
+
# Input shape: (channels, height, width)
|
551
|
+
x = jax.random.normal(self.key, (3, 4, 5))
|
552
|
+
output = layer(x)
|
553
|
+
|
554
|
+
self.assertEqual(output.shape, x.shape)

    # Test LogSoftmax
    def test_log_softmax_functionality(self):
        """Test LogSoftmax activation function."""
        layer = nn.LogSoftmax(dim=-1)

        x = jnp.array([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])
        output = layer(x)

        # LogSoftmax = log(softmax(x))
        softmax_output = jax.nn.softmax(x, axis=-1)
        expected = jnp.log(softmax_output)
        np.testing.assert_allclose(output, expected, rtol=1e-5)

        # Output should be all negative or zero
        self.assertTrue(jnp.all(output <= 0.0))

    def test_log_softmax_numerical_stability(self):
        """Test LogSoftmax numerical stability."""
        layer = nn.LogSoftmax(dim=-1)

        # Test with values that could cause numerical issues
        x = jnp.array([[1000.0, 0.0, -1000.0]])
        output = layer(x)

        # Should not contain NaN or Inf
        self.assertFalse(jnp.any(jnp.isnan(output)))
        self.assertFalse(jnp.any(jnp.isinf(output)))
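
    # Illustrative sketch of why the extreme inputs above stay finite: log-softmax
    # can be written as (x - max(x)) - log(sum(exp(x - max(x)))), which never
    # exponentiates a large positive number. Assumes nn.LogSoftmax matches this
    # standard formulation.
    def test_log_softmax_matches_shifted_reference(self):
        """Sketch: log-softmax should equal the max-shifted log-sum-exp reference."""
        layer = nn.LogSoftmax(dim=-1)
        x = jnp.array([[1000.0, 0.0, -1000.0]])

        shifted = x - jnp.max(x, axis=-1, keepdims=True)
        expected = shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis=-1, keepdims=True))

        np.testing.assert_allclose(layer(x), expected, rtol=1e-5)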

    # Test Identity
    def test_identity_functionality(self):
        """Test Identity activation function."""
        layer = nn.Identity()

        x = jax.random.normal(self.key, (3, 4, 5))
        output = layer(x)

        # Should be exactly equal to input
        np.testing.assert_array_equal(output, x)

    def test_identity_gradient(self):
        """Test Identity gradient flow."""
        layer = nn.Identity()

        x = jax.random.normal(self.key, (3, 4))

        def loss_fn(x):
            return jnp.sum(layer(x))

        grad = jax.grad(loss_fn)(x)

        # Gradient should be all ones
        np.testing.assert_allclose(grad, jnp.ones_like(x), rtol=1e-5)

    # Test SpikeBitwise
    def test_spike_bitwise_add(self):
        """Test SpikeBitwise with ADD operation."""
        layer = nn.SpikeBitwise(op='add')

        x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
        y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
        output = layer(x, y)

        # ADD on spike trains is element-wise addition: x + y
        expected = x + y
        np.testing.assert_allclose(output, expected, rtol=1e-5)

    def test_spike_bitwise_and(self):
        """Test SpikeBitwise with AND operation."""
        layer = nn.SpikeBitwise(op='and')

        x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
        y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
        output = layer(x, y)

        expected = x * y
        np.testing.assert_allclose(output, expected, rtol=1e-5)

    def test_spike_bitwise_iand(self):
        """Test SpikeBitwise with IAND operation."""
        layer = nn.SpikeBitwise(op='iand')

        x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
        y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
        output = layer(x, y)

        expected = (1 - x) * y
        np.testing.assert_allclose(output, expected, rtol=1e-5)

    def test_spike_bitwise_or(self):
        """Test SpikeBitwise with OR operation."""
        layer = nn.SpikeBitwise(op='or')

        x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
        y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
        output = layer(x, y)

        expected = (x + y) - (x * y)
        np.testing.assert_allclose(output, expected, rtol=1e-5)
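
    # Illustrative sketch: on {0, 1} spike inputs the arithmetic forms used in
    # the tests above (x*y, (1-x)*y, x+y-x*y) reproduce the AND, IAND and OR
    # truth tables exactly.
    def test_spike_bitwise_arithmetic_matches_boolean_logic(self):
        """Sketch: arithmetic spike-bitwise forms agree with boolean logic on {0, 1}."""
        x = jnp.array([0.0, 0.0, 1.0, 1.0])
        y = jnp.array([0.0, 1.0, 0.0, 1.0])

        np.testing.assert_array_equal(x * y, jnp.logical_and(x, y).astype(x.dtype))
        np.testing.assert_array_equal((1 - x) * y,
                                      jnp.logical_and(jnp.logical_not(x), y).astype(x.dtype))
        np.testing.assert_array_equal(x + y - x * y, jnp.logical_or(x, y).astype(x.dtype))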


class TestEdgeCases(parameterized.TestCase):
    """Test edge cases and boundary conditions."""

    def test_zero_input(self):
        """Test all activations with zero input."""
        x = jnp.zeros((3, 4))

        activations = [
            nn.ReLU(),
            nn.Sigmoid(),
            nn.Tanh(),
            nn.SiLU(),
            nn.ELU(),
            nn.GELU(),
            nn.Softplus(),
            nn.Softsign(),
        ]

        for activation in activations:
            output = activation(x)
            self.assertEqual(output.shape, x.shape)
            self.assertFalse(jnp.any(jnp.isnan(output)))

    def test_large_positive_input(self):
        """Test activations with very large positive values."""
        x = jnp.ones((2, 3)) * 1000.0

        activations = [
            nn.ReLU(),
            nn.Sigmoid(),
            nn.Tanh(),
            nn.Hardsigmoid(),
            nn.Hardswish(),
        ]

        for activation in activations:
            output = activation(x)
            self.assertFalse(jnp.any(jnp.isnan(output)))
            self.assertFalse(jnp.any(jnp.isinf(output)))

    def test_large_negative_input(self):
        """Test activations with very large negative values."""
        x = jnp.ones((2, 3)) * -1000.0

        activations = [
            nn.ReLU(),
            nn.Sigmoid(),
            nn.Tanh(),
            nn.Hardsigmoid(),
            nn.Hardswish(),
        ]

        for activation in activations:
            output = activation(x)
            self.assertFalse(jnp.any(jnp.isnan(output)))
            self.assertFalse(jnp.any(jnp.isinf(output)))

    def test_nan_propagation(self):
        """Test that NaN inputs produce NaN outputs (where appropriate)."""
        x = jnp.array([jnp.nan, 1.0, 2.0])

        activations = [
            nn.ReLU(),
            nn.Sigmoid(),
            nn.Tanh(),
        ]

        for activation in activations:
            output = activation(x)
            self.assertTrue(jnp.isnan(output[0]))

    def test_inf_handling(self):
        """Test handling of infinite values."""
        x = jnp.array([jnp.inf, -jnp.inf, 1.0])

        # ReLU should handle inf properly
        relu = nn.ReLU()
        output = relu(x)
        self.assertEqual(output[0], jnp.inf)
        self.assertEqual(output[1], 0.0)

        # Sigmoid should saturate
        sigmoid = nn.Sigmoid()
        output = sigmoid(x)
        np.testing.assert_allclose(output[0], 1.0, rtol=1e-5)
        np.testing.assert_allclose(output[1], 0.0, rtol=1e-5)
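
    # Illustrative sketch, assuming nn.Softplus uses the standard numerically
    # stable formulation: for large positive x, softplus(x) = log(1 + e^x)
    # approaches x instead of overflowing.
    def test_softplus_large_input_asymptote(self):
        """Sketch: softplus of large positive inputs stays finite and close to x."""
        layer = nn.Softplus()
        x = jnp.array([100.0, 500.0, 1000.0])
        output = layer(x)

        self.assertFalse(jnp.any(jnp.isinf(output)))
        np.testing.assert_allclose(output, x, rtol=1e-5)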


class TestBatchProcessing(parameterized.TestCase):
    """Test batch processing capabilities."""

    @parameterized.parameters(
        (nn.ReLU(), ),
        (nn.Sigmoid(), ),
        (nn.Tanh(), ),
        (nn.GELU(), ),
        (nn.SiLU(), ),
        (nn.ELU(), ),
    )
    def test_batch_consistency(self, activation):
        """Test that batch processing gives same results as individual processing."""
        # Process as batch
        batch_input = jax.random.normal(jax.random.PRNGKey(42), (5, 10))
        batch_output = activation(batch_input)

        # Process individually
        individual_outputs = []
        for i in range(5):
            individual_output = activation(batch_input[i])
            individual_outputs.append(individual_output)
        individual_outputs = jnp.stack(individual_outputs)

        np.testing.assert_allclose(batch_output, individual_outputs, rtol=1e-5)
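
    # Illustrative sketch: the same batch/individual consistency can be stated
    # with jax.vmap over the leading (batch) axis.
    def test_batch_consistency_via_vmap(self):
        """Sketch: vmapping an activation over the batch axis matches the direct call."""
        activation = nn.Tanh()
        batch_input = jax.random.normal(jax.random.PRNGKey(42), (5, 10))

        np.testing.assert_allclose(
            activation(batch_input),
            jax.vmap(activation)(batch_input),
            rtol=1e-5,
        )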

    def test_different_batch_sizes(self):
        """Test activations with different batch sizes."""
        activation = nn.ReLU()

        for batch_size in [1, 10, 100]:
            x = jax.random.normal(jax.random.PRNGKey(42), (batch_size, 20))
            output = activation(x)
            self.assertEqual(output.shape[0], batch_size)


class TestMemoryAndPerformance(parameterized.TestCase):
    """Test memory and performance characteristics."""

    def test_in_place_operations(self):
        """Test that activations don't modify input in-place."""
        x_original = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
        x = x_original.copy()

        activations = [
            nn.ReLU(),
            nn.Sigmoid(),
            nn.Tanh(),
        ]

        for activation in activations:
            output = activation(x)
            np.testing.assert_array_equal(x, x_original)

    def test_jit_compilation(self):
        """Test that activations work with JIT compilation."""
        @jax.jit
        def forward(x):
            relu = nn.ReLU()
            return relu(x)

        x = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
        output = forward(x)

        # Should not raise any errors and produce valid output
        self.assertEqual(output.shape, x.shape)

    @parameterized.parameters(
        (nn.ReLU(), ),
        (nn.Sigmoid(), ),
        (nn.Tanh(), ),
    )
    def test_vmap_compatibility(self, activation):
        """Test compatibility with vmap."""
        def single_forward(x):
            return activation(x)

        batch_forward = jax.vmap(single_forward)

        x = jax.random.normal(jax.random.PRNGKey(42), (5, 10, 20))
        output = batch_forward(x)

        self.assertEqual(output.shape, x.shape)
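
    # Illustrative sketch: jit and vmap compose, so a jitted, vmapped activation
    # should agree with the eager call. Uses only stateless activations.
    @parameterized.parameters(
        (nn.ReLU(), ),
        (nn.Tanh(), ),
    )
    def test_jit_vmap_composition(self, activation):
        """Sketch: jax.jit(jax.vmap(activation)) matches the direct call."""
        def single_forward(x):
            return activation(x)

        fn = jax.jit(jax.vmap(single_forward))
        x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))

        np.testing.assert_allclose(fn(x), activation(x), rtol=1e-5)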


if __name__ == '__main__':
    absltest.main()