brainstate 0.1.10-py2.py3-none-any.whl → 0.2.1-py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
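Note on the layout change above: the `brainstate.augment` and `brainstate.compile` subpackages are consolidated into a single `brainstate.transform` subpackage (e.g. `brainstate/{compile → transform}/_jit.py`, `brainstate/{augment → transform}/_autograd.py`). A minimal migration sketch, assuming the public names in the moved modules (such as `jit` and `grad`) are re-exported unchanged from the new location; the exact 0.2.1 re-exports are not verified here:

    # Hypothetical import migration inferred from the renames listed above;
    # the public API of brainstate 0.2.1 is an assumption, not verified.
    import brainstate

    def loss_fn(x):
        return (x ** 2).sum()

    # brainstate 0.1.x layout (subpackages the listing shows as removed):
    #   jitted = brainstate.compile.jit(loss_fn)
    #   grad_fn = brainstate.augment.grad(loss_fn)

    # brainstate 0.2.x layout (subpackage the listing shows as added):
    #   jitted = brainstate.transform.jit(loss_fn)
    #   grad_fn = brainstate.transform.grad(loss_fn)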
@@ -1,169 +1,830 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- from absl.testing import absltest
- from absl.testing import parameterized
-
- import brainstate
-
-
- class Test_Activation(parameterized.TestCase):
-
-     def test_Threshold(self):
-         threshold_layer = brainstate.nn.Threshold(5, 20)
-         input = brainstate.random.randn(2)
-         output = threshold_layer(input)
-
-     def test_ReLU(self):
-         ReLU_layer = brainstate.nn.ReLU()
-         input = brainstate.random.randn(2)
-         output = ReLU_layer(input)
-
-     def test_RReLU(self):
-         RReLU_layer = brainstate.nn.RReLU(lower=0, upper=1)
-         input = brainstate.random.randn(2)
-         output = RReLU_layer(input)
-
-     def test_Hardtanh(self):
-         Hardtanh_layer = brainstate.nn.Hardtanh(min_val=0, max_val=1, )
-         input = brainstate.random.randn(2)
-         output = Hardtanh_layer(input)
-
-     def test_ReLU6(self):
-         ReLU6_layer = brainstate.nn.ReLU6()
-         input = brainstate.random.randn(2)
-         output = ReLU6_layer(input)
-
-     def test_Sigmoid(self):
-         Sigmoid_layer = brainstate.nn.Sigmoid()
-         input = brainstate.random.randn(2)
-         output = Sigmoid_layer(input)
-
-     def test_Hardsigmoid(self):
-         Hardsigmoid_layer = brainstate.nn.Hardsigmoid()
-         input = brainstate.random.randn(2)
-         output = Hardsigmoid_layer(input)
-
-     def test_Tanh(self):
-         Tanh_layer = brainstate.nn.Tanh()
-         input = brainstate.random.randn(2)
-         output = Tanh_layer(input)
-
-     def test_SiLU(self):
-         SiLU_layer = brainstate.nn.SiLU()
-         input = brainstate.random.randn(2)
-         output = SiLU_layer(input)
-
-     def test_Mish(self):
-         Mish_layer = brainstate.nn.Mish()
-         input = brainstate.random.randn(2)
-         output = Mish_layer(input)
-
-     def test_Hardswish(self):
-         Hardswish_layer = brainstate.nn.Hardswish()
-         input = brainstate.random.randn(2)
-         output = Hardswish_layer(input)
-
-     def test_ELU(self):
-         ELU_layer = brainstate.nn.ELU(alpha=0.5, )
-         input = brainstate.random.randn(2)
-         output = ELU_layer(input)
-
-     def test_CELU(self):
-         CELU_layer = brainstate.nn.CELU(alpha=0.5, )
-         input = brainstate.random.randn(2)
-         output = CELU_layer(input)
-
-     def test_SELU(self):
-         SELU_layer = brainstate.nn.SELU()
-         input = brainstate.random.randn(2)
-         output = SELU_layer(input)
-
-     def test_GLU(self):
-         GLU_layer = brainstate.nn.GLU()
-         input = brainstate.random.randn(4, 2)
-         output = GLU_layer(input)
-
-     @parameterized.product(
-         approximate=['tanh', 'none']
-     )
-     def test_GELU(self, approximate):
-         GELU_layer = brainstate.nn.GELU()
-         input = brainstate.random.randn(2)
-         output = GELU_layer(input)
-
-     def test_Hardshrink(self):
-         Hardshrink_layer = brainstate.nn.Hardshrink(lambd=1)
-         input = brainstate.random.randn(2)
-         output = Hardshrink_layer(input)
-
-     def test_LeakyReLU(self):
-         LeakyReLU_layer = brainstate.nn.LeakyReLU()
-         input = brainstate.random.randn(2)
-         output = LeakyReLU_layer(input)
-
-     def test_LogSigmoid(self):
-         LogSigmoid_layer = brainstate.nn.LogSigmoid()
-         input = brainstate.random.randn(2)
-         output = LogSigmoid_layer(input)
-
-     def test_Softplus(self):
-         Softplus_layer = brainstate.nn.Softplus()
-         input = brainstate.random.randn(2)
-         output = Softplus_layer(input)
-
-     def test_Softshrink(self):
-         Softshrink_layer = brainstate.nn.Softshrink(lambd=1)
-         input = brainstate.random.randn(2)
-         output = Softshrink_layer(input)
-
-     def test_PReLU(self):
-         PReLU_layer = brainstate.nn.PReLU(num_parameters=2, init=0.5)
-         input = brainstate.random.randn(2)
-         output = PReLU_layer(input)
-
-     def test_Softsign(self):
-         Softsign_layer = brainstate.nn.Softsign()
-         input = brainstate.random.randn(2)
-         output = Softsign_layer(input)
-
-     def test_Tanhshrink(self):
-         Tanhshrink_layer = brainstate.nn.Tanhshrink()
-         input = brainstate.random.randn(2)
-         output = Tanhshrink_layer(input)
-
-     def test_Softmin(self):
-         Softmin_layer = brainstate.nn.Softmin(dim=2)
-         input = brainstate.random.randn(2, 3, 4)
-         output = Softmin_layer(input)
-
-     def test_Softmax(self):
-         Softmax_layer = brainstate.nn.Softmax(dim=2)
-         input = brainstate.random.randn(2, 3, 4)
-         output = Softmax_layer(input)
-
-     def test_Softmax2d(self):
-         Softmax2d_layer = brainstate.nn.Softmax2d()
-         input = brainstate.random.randn(2, 3, 12, 13)
-         output = Softmax2d_layer(input)
-
-     def test_LogSoftmax(self):
-         LogSoftmax_layer = brainstate.nn.LogSoftmax(dim=2)
-         input = brainstate.random.randn(2, 3, 4)
-         output = LogSoftmax_layer(input)
-
-
- if __name__ == '__main__':
-     absltest.main()
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import unittest
+ from absl.testing import absltest
+ from absl.testing import parameterized
+
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ import brainstate
+ import brainstate.nn as nn
+
+
+ class TestActivationFunctions(parameterized.TestCase):
+     """Comprehensive tests for activation functions."""
+
+     def setUp(self):
+         """Set up test fixtures."""
+         self.seed = 42
+         self.key = jax.random.PRNGKey(self.seed)
+
+     def _check_shape_preservation(self, layer, input_shape):
+         """Helper to check if layer preserves input shape."""
+         x = jax.random.normal(self.key, input_shape)
+         output = layer(x)
+         self.assertEqual(output.shape, x.shape)
+
+     def _check_gradient_flow(self, layer, input_shape):
+         """Helper to check if gradients can flow through the layer."""
+         x = jax.random.normal(self.key, input_shape)
+
+         def loss_fn(x):
+             return jnp.sum(layer(x))
+
+         grad = jax.grad(loss_fn)(x)
+         self.assertEqual(grad.shape, x.shape)
+         # Check that gradients are not all zeros (for most activations)
+         if not isinstance(layer, (nn.Threshold, nn.Hardtanh, nn.ReLU6)):
+             self.assertFalse(jnp.allclose(grad, 0.0))
+
+     # Test Threshold
+     def test_threshold_functionality(self):
+         """Test Threshold activation function."""
+         layer = nn.Threshold(threshold=0.5, value=0.0)
+
+         # Test with values above and below threshold
+         x = jnp.array([-1.0, 0.0, 0.3, 0.7, 1.0])
+         output = layer(x)
+         expected = jnp.array([0.0, 0.0, 0.0, 0.7, 1.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     @parameterized.parameters(
+         ((2,), ),
+         ((3, 4), ),
+         ((2, 3, 4), ),
+         ((2, 3, 4, 5), ),
+     )
+     def test_threshold_shapes(self, shape):
+         """Test Threshold with different input shapes."""
+         layer = nn.Threshold(threshold=0.1, value=20)
+         self._check_shape_preservation(layer, shape)
+
+     # Test ReLU
+     def test_relu_functionality(self):
+         """Test ReLU activation function."""
+         layer = nn.ReLU()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+         expected = jnp.array([0.0, 0.0, 0.0, 1.0, 2.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     @parameterized.parameters(
+         ((10,), ),
+         ((5, 10), ),
+         ((3, 5, 10), ),
+     )
+     def test_relu_shapes_and_gradients(self, shape):
+         """Test ReLU shapes and gradients."""
+         layer = nn.ReLU()
+         self._check_shape_preservation(layer, shape)
+         self._check_gradient_flow(layer, shape)
+
+     # Test RReLU
+     def test_rrelu_functionality(self):
+         """Test RReLU activation function."""
+         layer = nn.RReLU(lower=0.1, upper=0.3)
+
+         # Test positive and negative values
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+         # Negative values should be scaled by a factor in [lower, upper]
+         negative_mask = x < 0
+         if jnp.any(negative_mask):
+             scaled = output[negative_mask] / x[negative_mask]
+             self.assertTrue(jnp.all((scaled >= 0.1) & (scaled <= 0.3)))
+
+     # Test Hardtanh
+     def test_hardtanh_functionality(self):
+         """Test Hardtanh activation function."""
+         layer = nn.Hardtanh(min_val=-1.0, max_val=1.0)
+
+         x = jnp.array([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
+         output = layer(x)
+         expected = jnp.array([-1.0, -1.0, -0.5, 0.0, 0.5, 1.0, 1.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     def test_hardtanh_custom_bounds(self):
+         """Test Hardtanh with custom bounds."""
+         layer = nn.Hardtanh(min_val=-2.0, max_val=3.0)
+
+         x = jnp.array([-3.0, -2.0, 0.0, 3.0, 4.0])
+         output = layer(x)
+         expected = jnp.array([-2.0, -2.0, 0.0, 3.0, 3.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test ReLU6
+     def test_relu6_functionality(self):
+         """Test ReLU6 activation function."""
+         layer = nn.ReLU6()
+
+         x = jnp.array([-2.0, 0.0, 3.0, 6.0, 8.0])
+         output = layer(x)
+         expected = jnp.array([0.0, 0.0, 3.0, 6.0, 6.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test Sigmoid
+     def test_sigmoid_functionality(self):
+         """Test Sigmoid activation function."""
+         layer = nn.Sigmoid()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # Check sigmoid properties
+         self.assertTrue(jnp.all((output >= 0.0) & (output <= 1.0)))
+         np.testing.assert_allclose(output[2], 0.5, rtol=1e-5) # sigmoid(0) = 0.5
+
+     @parameterized.parameters(
+         ((10,), ),
+         ((5, 10), ),
+         ((3, 5, 10), ),
+     )
+     def test_sigmoid_shapes_and_gradients(self, shape):
+         """Test Sigmoid shapes and gradients."""
+         layer = nn.Sigmoid()
+         self._check_shape_preservation(layer, shape)
+         self._check_gradient_flow(layer, shape)
+
+     # Test Hardsigmoid
+     def test_hardsigmoid_functionality(self):
+         """Test Hardsigmoid activation function."""
+         layer = nn.Hardsigmoid()
+
+         x = jnp.array([-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 4.0])
+         output = layer(x)
+
+         # Check bounds
+         self.assertTrue(jnp.all((output >= 0.0) & (output <= 1.0)))
+         # Check specific values
+         np.testing.assert_allclose(output[1], 0.0, rtol=1e-5) # x=-3
+         np.testing.assert_allclose(output[3], 0.5, rtol=1e-5) # x=0
+         np.testing.assert_allclose(output[5], 1.0, rtol=1e-5) # x=3
+
+     # Test Tanh
+     def test_tanh_functionality(self):
+         """Test Tanh activation function."""
+         layer = nn.Tanh()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # Check tanh properties
+         self.assertTrue(jnp.all((output >= -1.0) & (output <= 1.0)))
+         np.testing.assert_allclose(output[2], 0.0, rtol=1e-5) # tanh(0) = 0
+
+     # Test SiLU (Swish)
+     def test_silu_functionality(self):
+         """Test SiLU activation function."""
+         layer = nn.SiLU()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # SiLU(x) = x * sigmoid(x)
+         expected = x * jax.nn.sigmoid(x)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test Mish
+     def test_mish_functionality(self):
+         """Test Mish activation function."""
+         layer = nn.Mish()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Mish(x) = x * tanh(softplus(x))
+         expected = x * jnp.tanh(jax.nn.softplus(x))
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test Hardswish
+     def test_hardswish_functionality(self):
+         """Test Hardswish activation function."""
+         layer = nn.Hardswish()
+
+         x = jnp.array([-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 4.0])
+         output = layer(x)
+
+         # Check boundary conditions
+         np.testing.assert_allclose(output[1], 0.0, rtol=1e-5) # x=-3
+         np.testing.assert_allclose(output[5], 3.0, rtol=1e-5) # x=3
+
+     # Test ELU
+     def test_elu_functionality(self):
+         """Test ELU activation function."""
+         layer = nn.ELU(alpha=1.0)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+         # Check ELU formula for negative values
+         negative_mask = x <= 0
+         expected_negative = 1.0 * (jnp.exp(x[negative_mask]) - 1)
+         np.testing.assert_allclose(output[negative_mask], expected_negative, rtol=1e-5)
+
+     def test_elu_with_different_alpha(self):
+         """Test ELU with different alpha values."""
+         alpha = 2.0
+         layer = nn.ELU(alpha=alpha)
+
+         x = jnp.array([-1.0])
+         output = layer(x)
+         expected = alpha * (jnp.exp(x) - 1)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test CELU
+     def test_celu_functionality(self):
+         """Test CELU activation function."""
+         layer = nn.CELU(alpha=1.0)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+
+     # Test SELU
+     def test_selu_functionality(self):
+         """Test SELU activation function."""
+         layer = nn.SELU()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Check that output is scaled ELU
+         # SELU has specific scale and alpha values
+         scale = 1.0507009873554804934193349852946
+         alpha = 1.6732632423543772848170429916717
+
+         positive_mask = x > 0
+         self.assertTrue(jnp.all(output[positive_mask] == scale * x[positive_mask]))
+
+     # Test GLU
+     def test_glu_functionality(self):
+         """Test GLU activation function."""
+         layer = nn.GLU(dim=-1)
+
+         # GLU splits input in half along specified dimension
+         x = jnp.array([[1.0, 2.0, 3.0, 4.0],
+                        [5.0, 6.0, 7.0, 8.0]])
+         output = layer(x)
+
+         # Output should have half the size along the split dimension
+         self.assertEqual(output.shape, (2, 2))
+
+     def test_glu_different_dimensions(self):
+         """Test GLU with different split dimensions."""
+         # Test splitting along different dimensions
+         x = jax.random.normal(self.key, (4, 6, 8))
+
+         layer_0 = nn.GLU(dim=0)
+         output_0 = layer_0(x)
+         self.assertEqual(output_0.shape, (2, 6, 8))
+
+         layer_1 = nn.GLU(dim=1)
+         output_1 = layer_1(x)
+         self.assertEqual(output_1.shape, (4, 3, 8))
+
+         layer_2 = nn.GLU(dim=2)
+         output_2 = layer_2(x)
+         self.assertEqual(output_2.shape, (4, 6, 4))
+
+     # Test GELU
+     def test_gelu_functionality(self):
+         """Test GELU activation function."""
+         layer = nn.GELU(approximate=False)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # GELU should be smooth and differentiable everywhere
+         np.testing.assert_allclose(output[2], 0.0, rtol=1e-5) # GELU(0) ≈ 0
+
+     def test_gelu_approximate(self):
+         """Test GELU with tanh approximation."""
+         layer_exact = nn.GELU(approximate=False)
+         layer_approx = nn.GELU(approximate=True)
+
+         x = jnp.array([-1.0, 0.0, 1.0])
+         output_exact = layer_exact(x)
+         output_approx = layer_approx(x)
+
+         # Approximation should be close but not exactly equal
+         np.testing.assert_allclose(output_exact, output_approx, rtol=1e-2)
+
+     # Test Hardshrink
+     def test_hardshrink_functionality(self):
+         """Test Hardshrink activation function."""
+         lambd = 0.5
+         layer = nn.Hardshrink(lambd=lambd)
+
+         x = jnp.array([-1.0, -0.6, -0.5, -0.3, 0.0, 0.3, 0.5, 0.6, 1.0])
+         output = layer(x)
+
+         # Check each value according to hardshrink formula
+         expected = []
+         for xi in x:
+             if xi > lambd:
+                 expected.append(xi)
+             elif xi < -lambd:
+                 expected.append(xi)
+             else:
+                 expected.append(0.0)
+         expected = jnp.array(expected)
+
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test LeakyReLU
+     def test_leaky_relu_functionality(self):
+         """Test LeakyReLU activation function."""
+         negative_slope = 0.01
+         layer = nn.LeakyReLU(negative_slope=negative_slope)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+         # Negative values should be scaled
+         negative_mask = x < 0
+         expected_negative = negative_slope * x[negative_mask]
+         np.testing.assert_allclose(output[negative_mask], expected_negative, rtol=1e-5)
+
+     def test_leaky_relu_custom_slope(self):
+         """Test LeakyReLU with custom negative slope."""
+         negative_slope = 0.2
+         layer = nn.LeakyReLU(negative_slope=negative_slope)
+
+         x = jnp.array([-5.0])
+         output = layer(x)
+         expected = negative_slope * x
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test LogSigmoid
+     def test_log_sigmoid_functionality(self):
+         """Test LogSigmoid activation function."""
+         layer = nn.LogSigmoid()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # LogSigmoid(x) = log(sigmoid(x))
+         expected = jnp.log(jax.nn.sigmoid(x))
+         np.testing.assert_allclose(output, expected, rtol=1e-2)
+
+         # Output should always be negative or zero
+         self.assertTrue(jnp.all(output <= 0.0))
+
+     # Test Softplus
+     def test_softplus_functionality(self):
+         """Test Softplus activation function."""
+         layer = nn.Softplus()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # Softplus is a smooth approximation to ReLU
+         # Should always be positive
+         self.assertTrue(jnp.all(output > 0.0))
+
+         # For large positive values, should approximate x
+         np.testing.assert_allclose(output[-1], x[-1], rtol=1e-2)
+
+     # Test Softshrink
+     def test_softshrink_functionality(self):
+         """Test Softshrink activation function."""
+         lambd = 0.5
+         layer = nn.Softshrink(lambd=lambd)
+
+         x = jnp.array([-1.0, -0.5, -0.3, 0.0, 0.3, 0.5, 1.0])
+         output = layer(x)
+
+         # Check the softshrink formula
+         for i in range(len(x)):
+             if x[i] > lambd:
+                 expected = x[i] - lambd
+             elif x[i] < -lambd:
+                 expected = x[i] + lambd
+             else:
+                 expected = 0.0
+             np.testing.assert_allclose(output[i], expected, rtol=1e-5)
+
+     # Test PReLU
+     def test_prelu_functionality(self):
+         """Test PReLU activation function."""
+         layer = nn.PReLU(num_parameters=1, init=0.25)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+         # Negative values should be scaled by learned parameter
+         negative_mask = x < 0
+         # Check that negative values are scaled
+         self.assertTrue(jnp.all(output[negative_mask] != x[negative_mask]))
+
+     def test_prelu_multi_channel(self):
+         """Test PReLU with multiple channels."""
+         num_channels = 3
+         layer = nn.PReLU(num_parameters=num_channels, init=0.25)
+
+         # Input shape: (batch, channels, height, width)
+         x = jax.random.normal(self.key, (2, 4, 4, num_channels))
+         output = layer(x)
+
+         self.assertEqual(output.shape, x.shape)
+
+     # Test Softsign
+     def test_softsign_functionality(self):
+         """Test Softsign activation function."""
+         layer = nn.Softsign()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # Softsign(x) = x / (1 + |x|)
+         expected = x / (1 + jnp.abs(x))
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+         # Output should be bounded between -1 and 1
+         self.assertTrue(jnp.all((output >= -1.0) & (output <= 1.0)))
+
+     # Test Tanhshrink
+     def test_tanhshrink_functionality(self):
+         """Test Tanhshrink activation function."""
+         layer = nn.Tanhshrink()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Tanhshrink(x) = x - tanh(x)
+         expected = x - jnp.tanh(x)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test Softmin
+     def test_softmin_functionality(self):
+         """Test Softmin activation function."""
+         layer = nn.Softmin(dim=-1)
+
+         x = jnp.array([[1.0, 2.0, 3.0],
+                        [4.0, 5.0, 6.0]])
+         output = layer(x)
+
+         # Softmin should sum to 1 along the specified dimension
+         sums = jnp.sum(output, axis=-1)
+         np.testing.assert_allclose(sums, jnp.ones_like(sums), rtol=1e-5)
+
+         # Higher values should have lower probabilities
+         self.assertTrue(jnp.all(output[:, 0] > output[:, 1]))
+         self.assertTrue(jnp.all(output[:, 1] > output[:, 2]))
+
+     # Test Softmax
+     def test_softmax_functionality(self):
+         """Test Softmax activation function."""
+         layer = nn.Softmax(dim=-1)
+
+         x = jnp.array([[1.0, 2.0, 3.0],
+                        [4.0, 5.0, 6.0]])
+         output = layer(x)
+
+         # Softmax should sum to 1 along the specified dimension
+         sums = jnp.sum(output, axis=-1)
+         np.testing.assert_allclose(sums, jnp.ones_like(sums), rtol=1e-5)
+
+         # Higher values should have higher probabilities
+         self.assertTrue(jnp.all(output[:, 2] > output[:, 1]))
+         self.assertTrue(jnp.all(output[:, 1] > output[:, 0]))
+
+     def test_softmax_numerical_stability(self):
+         """Test Softmax numerical stability with large values."""
+         layer = nn.Softmax(dim=-1)
+
+         # Test with large values that could cause overflow
+         x = jnp.array([[1000.0, 1000.0, 1000.0]])
+         output = layer(x)
+
+         # Should still sum to 1 and have equal probabilities
+         np.testing.assert_allclose(jnp.sum(output), 1.0, rtol=1e-5)
+         np.testing.assert_allclose(output[0, 0], 1/3, rtol=1e-5)
+
+     # Test Softmax2d
+     def test_softmax2d_functionality(self):
+         """Test Softmax2d activation function."""
+         layer = nn.Softmax2d()
+
+         # Input shape: (batch, channels, height, width)
+         x = jax.random.normal(self.key, (2, 3, 4, 5))
+         output = layer(x)
+
+         self.assertEqual(output.shape, x.shape)
+
+         # Should sum to 1 across channels for each spatial location
+         channel_sums = jnp.sum(output, axis=1)
+         np.testing.assert_allclose(channel_sums, jnp.ones_like(channel_sums), rtol=1e-5)
+
+     def test_softmax2d_3d_input(self):
+         """Test Softmax2d with 3D input."""
+         layer = nn.Softmax2d()
+
+         # Input shape: (channels, height, width)
+         x = jax.random.normal(self.key, (3, 4, 5))
+         output = layer(x)
+
+         self.assertEqual(output.shape, x.shape)
+
+     # Test LogSoftmax
+     def test_log_softmax_functionality(self):
+         """Test LogSoftmax activation function."""
+         layer = nn.LogSoftmax(dim=-1)
+
+         x = jnp.array([[1.0, 2.0, 3.0],
+                        [4.0, 5.0, 6.0]])
+         output = layer(x)
+
+         # LogSoftmax = log(softmax(x))
+         softmax_output = jax.nn.softmax(x, axis=-1)
+         expected = jnp.log(softmax_output)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+         # Output should be all negative or zero
+         self.assertTrue(jnp.all(output <= 0.0))
+
+     def test_log_softmax_numerical_stability(self):
+         """Test LogSoftmax numerical stability."""
+         layer = nn.LogSoftmax(dim=-1)
+
+         # Test with values that could cause numerical issues
+         x = jnp.array([[1000.0, 0.0, -1000.0]])
+         output = layer(x)
+
+         # Should not contain NaN or Inf
+         self.assertFalse(jnp.any(jnp.isnan(output)))
+         self.assertFalse(jnp.any(jnp.isinf(output)))
+
+     # Test Identity
+     def test_identity_functionality(self):
+         """Test Identity activation function."""
+         layer = nn.Identity()
+
+         x = jax.random.normal(self.key, (3, 4, 5))
+         output = layer(x)
+
+         # Should be exactly equal to input
+         np.testing.assert_array_equal(output, x)
+
+     def test_identity_gradient(self):
+         """Test Identity gradient flow."""
+         layer = nn.Identity()
+
+         x = jax.random.normal(self.key, (3, 4))
+
+         def loss_fn(x):
+             return jnp.sum(layer(x))
+
+         grad = jax.grad(loss_fn)(x)
+
+         # Gradient should be all ones
+         np.testing.assert_allclose(grad, jnp.ones_like(x), rtol=1e-5)
+
+     # Test SpikeBitwise
+     def test_spike_bitwise_add(self):
+         """Test SpikeBitwise with ADD operation."""
+         layer = nn.SpikeBitwise(op='and')
+
+         x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
+         y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
+         output = layer(x, y)
+
+         expected = jnp.logical_and(x, y)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     def test_spike_bitwise_and(self):
+         """Test SpikeBitwise with AND operation."""
+         layer = nn.SpikeBitwise(op='and')
+
+         x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
+         y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
+         output = layer(x, y)
+
+         expected = x * y
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     def test_spike_bitwise_iand(self):
+         """Test SpikeBitwise with IAND operation."""
+         layer = nn.SpikeBitwise(op='iand')
+
+         x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
+         y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
+         output = layer(x, y)
+
+         expected = (1 - x) * y
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     def test_spike_bitwise_or(self):
+         """Test SpikeBitwise with OR operation."""
+         layer = nn.SpikeBitwise(op='or')
+
+         x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
+         y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
+         output = layer(x, y)
+
+         expected = (x + y) - (x * y)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+
+ class TestEdgeCases(parameterized.TestCase):
+     """Test edge cases and boundary conditions."""
+
+     def test_zero_input(self):
+         """Test all activations with zero input."""
+         x = jnp.zeros((3, 4))
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+             nn.SiLU(),
+             nn.ELU(),
+             nn.GELU(),
+             nn.Softplus(),
+             nn.Softsign(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             self.assertEqual(output.shape, x.shape)
+             self.assertFalse(jnp.any(jnp.isnan(output)))
+
+     def test_large_positive_input(self):
+         """Test activations with very large positive values."""
+         x = jnp.ones((2, 3)) * 1000.0
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+             nn.Hardsigmoid(),
+             nn.Hardswish(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             self.assertFalse(jnp.any(jnp.isnan(output)))
+             self.assertFalse(jnp.any(jnp.isinf(output)))
+
+     def test_large_negative_input(self):
+         """Test activations with very large negative values."""
+         x = jnp.ones((2, 3)) * -1000.0
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+             nn.Hardsigmoid(),
+             nn.Hardswish(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             self.assertFalse(jnp.any(jnp.isnan(output)))
+             self.assertFalse(jnp.any(jnp.isinf(output)))
+
+     def test_nan_propagation(self):
+         """Test that NaN inputs produce NaN outputs (where appropriate)."""
+         x = jnp.array([jnp.nan, 1.0, 2.0])
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             self.assertTrue(jnp.isnan(output[0]))
+
+     def test_inf_handling(self):
+         """Test handling of infinite values."""
+         x = jnp.array([jnp.inf, -jnp.inf, 1.0])
+
+         # ReLU should handle inf properly
+         relu = nn.ReLU()
+         output = relu(x)
+         self.assertEqual(output[0], jnp.inf)
+         self.assertEqual(output[1], 0.0)
+
+         # Sigmoid should saturate
+         sigmoid = nn.Sigmoid()
+         output = sigmoid(x)
+         np.testing.assert_allclose(output[0], 1.0, rtol=1e-5)
+         np.testing.assert_allclose(output[1], 0.0, rtol=1e-5)
+
+
+ class TestBatchProcessing(parameterized.TestCase):
+     """Test batch processing capabilities."""
+
+     @parameterized.parameters(
+         (nn.ReLU(), ),
+         (nn.Sigmoid(), ),
+         (nn.Tanh(), ),
+         (nn.GELU(), ),
+         (nn.SiLU(), ),
+         (nn.ELU(), ),
+     )
+     def test_batch_consistency(self, activation):
+         """Test that batch processing gives same results as individual processing."""
+         # Process as batch
+         batch_input = jax.random.normal(jax.random.PRNGKey(42), (5, 10))
+         batch_output = activation(batch_input)
+
+         # Process individually
+         individual_outputs = []
+         for i in range(5):
+             individual_output = activation(batch_input[i])
+             individual_outputs.append(individual_output)
+         individual_outputs = jnp.stack(individual_outputs)
+
+         np.testing.assert_allclose(batch_output, individual_outputs, rtol=1e-5)
+
+     def test_different_batch_sizes(self):
+         """Test activations with different batch sizes."""
+         activation = nn.ReLU()
+
+         for batch_size in [1, 10, 100]:
+             x = jax.random.normal(jax.random.PRNGKey(42), (batch_size, 20))
+             output = activation(x)
+             self.assertEqual(output.shape[0], batch_size)
+
+
+ class TestMemoryAndPerformance(parameterized.TestCase):
+     """Test memory and performance characteristics."""
+
+     def test_in_place_operations(self):
+         """Test that activations don't modify input in-place."""
+         x_original = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
+         x = x_original.copy()
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             np.testing.assert_array_equal(x, x_original)
+
+     def test_jit_compilation(self):
+         """Test that activations work with JIT compilation."""
+         @jax.jit
+         def forward(x):
+             relu = nn.ReLU()
+             return relu(x)
+
+         x = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
+         output = forward(x)
+
+         # Should not raise any errors and produce valid output
+         self.assertEqual(output.shape, x.shape)
+
+     @parameterized.parameters(
+         (nn.ReLU(), ),
+         (nn.Sigmoid(), ),
+         (nn.Tanh(), ),
+     )
+     def test_vmap_compatibility(self, activation):
+         """Test compatibility with vmap."""
+         def single_forward(x):
+             return activation(x)
+
+         batch_forward = jax.vmap(single_forward)
+
+         x = jax.random.normal(jax.random.PRNGKey(42), (5, 10, 20))
+         output = batch_forward(x)
+
+         self.assertEqual(output.shape, x.shape)
+
+
+ if __name__ == '__main__':
+     absltest.main()