brainstate-0.1.9-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in a supported public registry, and is provided for informational purposes only.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
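
The rename entries above show the old brainstate.compile and brainstate.augment subpackages being folded into a new brainstate.transform package, and the initializers moving from brainstate.init into brainstate.nn.init. The sketch below is a hypothetical illustration of how downstream imports might change under that layout; the symbol names are assumptions inferred only from the renamed file names, not verified against the released 0.2.0 API.

# Hypothetical import migration inferred from the file renames above
# (assumed, not verified against the 0.2.0 API).
#
# brainstate 0.1.9 layout:
#     from brainstate.compile import jit      # brainstate/compile/_jit.py
#     from brainstate.augment import grad     # brainstate/augment/_autograd.py
#     import brainstate.init as init          # brainstate/init/_random_inits.py
#
# brainstate 0.2.0 layout, assuming the renames map directly onto import paths:
from brainstate.transform import jit, grad    # brainstate/transform/{_jit.py, _autograd.py}
import brainstate.nn.init as init             # brainstate/nn/init.py
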
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -13,157 +13,818 @@
  # limitations under the License.
  # ==============================================================================

+ import unittest
  from absl.testing import absltest
  from absl.testing import parameterized

+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
  import brainstate
+ import brainstate.nn as nn
+
+
+ class TestActivationFunctions(parameterized.TestCase):
+     """Comprehensive tests for activation functions."""
+
+     def setUp(self):
+         """Set up test fixtures."""
+         self.seed = 42
+         self.key = jax.random.PRNGKey(self.seed)
+
+     def _check_shape_preservation(self, layer, input_shape):
+         """Helper to check if layer preserves input shape."""
+         x = jax.random.normal(self.key, input_shape)
+         output = layer(x)
+         self.assertEqual(output.shape, x.shape)
+
+     def _check_gradient_flow(self, layer, input_shape):
+         """Helper to check if gradients can flow through the layer."""
+         x = jax.random.normal(self.key, input_shape)
+
+         def loss_fn(x):
+             return jnp.sum(layer(x))
+
+         grad = jax.grad(loss_fn)(x)
+         self.assertEqual(grad.shape, x.shape)
+         # Check that gradients are not all zeros (for most activations)
+         if not isinstance(layer, (nn.Threshold, nn.Hardtanh, nn.ReLU6)):
+             self.assertFalse(jnp.allclose(grad, 0.0))
+
+     # Test Threshold
+     def test_threshold_functionality(self):
+         """Test Threshold activation function."""
+         layer = nn.Threshold(threshold=0.5, value=0.0)
+
+         # Test with values above and below threshold
+         x = jnp.array([-1.0, 0.0, 0.3, 0.7, 1.0])
+         output = layer(x)
+         expected = jnp.array([0.0, 0.0, 0.0, 0.7, 1.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     @parameterized.parameters(
+         ((2,), ),
+         ((3, 4), ),
+         ((2, 3, 4), ),
+         ((2, 3, 4, 5), ),
+     )
+     def test_threshold_shapes(self, shape):
+         """Test Threshold with different input shapes."""
+         layer = nn.Threshold(threshold=0.1, value=20)
+         self._check_shape_preservation(layer, shape)
+
+     # Test ReLU
+     def test_relu_functionality(self):
+         """Test ReLU activation function."""
+         layer = nn.ReLU()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+         expected = jnp.array([0.0, 0.0, 0.0, 1.0, 2.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     @parameterized.parameters(
+         ((10,), ),
+         ((5, 10), ),
+         ((3, 5, 10), ),
+     )
+     def test_relu_shapes_and_gradients(self, shape):
+         """Test ReLU shapes and gradients."""
+         layer = nn.ReLU()
+         self._check_shape_preservation(layer, shape)
+         self._check_gradient_flow(layer, shape)
+
+     # Test RReLU
+     def test_rrelu_functionality(self):
+         """Test RReLU activation function."""
+         layer = nn.RReLU(lower=0.1, upper=0.3)
+
+         # Test positive and negative values
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+         # Negative values should be scaled by a factor in [lower, upper]
+         negative_mask = x < 0
+         if jnp.any(negative_mask):
+             scaled = output[negative_mask] / x[negative_mask]
+             self.assertTrue(jnp.all((scaled >= 0.1) & (scaled <= 0.3)))
+
+     # Test Hardtanh
+     def test_hardtanh_functionality(self):
+         """Test Hardtanh activation function."""
+         layer = nn.Hardtanh(min_val=-1.0, max_val=1.0)
+
+         x = jnp.array([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
+         output = layer(x)
+         expected = jnp.array([-1.0, -1.0, -0.5, 0.0, 0.5, 1.0, 1.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     def test_hardtanh_custom_bounds(self):
+         """Test Hardtanh with custom bounds."""
+         layer = nn.Hardtanh(min_val=-2.0, max_val=3.0)
+
+         x = jnp.array([-3.0, -2.0, 0.0, 3.0, 4.0])
+         output = layer(x)
+         expected = jnp.array([-2.0, -2.0, 0.0, 3.0, 3.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test ReLU6
+     def test_relu6_functionality(self):
+         """Test ReLU6 activation function."""
+         layer = nn.ReLU6()
+
+         x = jnp.array([-2.0, 0.0, 3.0, 6.0, 8.0])
+         output = layer(x)
+         expected = jnp.array([0.0, 0.0, 3.0, 6.0, 6.0])
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test Sigmoid
+     def test_sigmoid_functionality(self):
+         """Test Sigmoid activation function."""
+         layer = nn.Sigmoid()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # Check sigmoid properties
+         self.assertTrue(jnp.all((output >= 0.0) & (output <= 1.0)))
+         np.testing.assert_allclose(output[2], 0.5, rtol=1e-5)  # sigmoid(0) = 0.5
+
+     @parameterized.parameters(
+         ((10,), ),
+         ((5, 10), ),
+         ((3, 5, 10), ),
+     )
+     def test_sigmoid_shapes_and_gradients(self, shape):
+         """Test Sigmoid shapes and gradients."""
+         layer = nn.Sigmoid()
+         self._check_shape_preservation(layer, shape)
+         self._check_gradient_flow(layer, shape)
+
+     # Test Hardsigmoid
+     def test_hardsigmoid_functionality(self):
+         """Test Hardsigmoid activation function."""
+         layer = nn.Hardsigmoid()
+
+         x = jnp.array([-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 4.0])
+         output = layer(x)
+
+         # Check bounds
+         self.assertTrue(jnp.all((output >= 0.0) & (output <= 1.0)))
+         # Check specific values
+         np.testing.assert_allclose(output[1], 0.0, rtol=1e-5)  # x=-3
+         np.testing.assert_allclose(output[3], 0.5, rtol=1e-5)  # x=0
+         np.testing.assert_allclose(output[5], 1.0, rtol=1e-5)  # x=3
+
+     # Test Tanh
+     def test_tanh_functionality(self):
+         """Test Tanh activation function."""
+         layer = nn.Tanh()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # Check tanh properties
+         self.assertTrue(jnp.all((output >= -1.0) & (output <= 1.0)))
+         np.testing.assert_allclose(output[2], 0.0, rtol=1e-5)  # tanh(0) = 0
+
+     # Test SiLU (Swish)
+     def test_silu_functionality(self):
+         """Test SiLU activation function."""
+         layer = nn.SiLU()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # SiLU(x) = x * sigmoid(x)
+         expected = x * jax.nn.sigmoid(x)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test Mish
+     def test_mish_functionality(self):
+         """Test Mish activation function."""
+         layer = nn.Mish()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Mish(x) = x * tanh(softplus(x))
+         expected = x * jnp.tanh(jax.nn.softplus(x))
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test Hardswish
+     def test_hardswish_functionality(self):
+         """Test Hardswish activation function."""
+         layer = nn.Hardswish()
+
+         x = jnp.array([-4.0, -3.0, -1.0, 0.0, 1.0, 3.0, 4.0])
+         output = layer(x)
+
+         # Check boundary conditions
+         np.testing.assert_allclose(output[1], 0.0, rtol=1e-5)  # x=-3
+         np.testing.assert_allclose(output[5], 3.0, rtol=1e-5)  # x=3
+
+     # Test ELU
+     def test_elu_functionality(self):
+         """Test ELU activation function."""
+         layer = nn.ELU(alpha=1.0)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+         # Check ELU formula for negative values
+         negative_mask = x <= 0
+         expected_negative = 1.0 * (jnp.exp(x[negative_mask]) - 1)
+         np.testing.assert_allclose(output[negative_mask], expected_negative, rtol=1e-5)
+
+     def test_elu_with_different_alpha(self):
+         """Test ELU with different alpha values."""
+         alpha = 2.0
+         layer = nn.ELU(alpha=alpha)
+
+         x = jnp.array([-1.0])
+         output = layer(x)
+         expected = alpha * (jnp.exp(x) - 1)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test CELU
+     def test_celu_functionality(self):
+         """Test CELU activation function."""
+         layer = nn.CELU(alpha=1.0)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+
+     # Test SELU
+     def test_selu_functionality(self):
+         """Test SELU activation function."""
+         layer = nn.SELU()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Check that output is scaled ELU
+         # SELU has specific scale and alpha values
+         scale = 1.0507009873554804934193349852946
+         alpha = 1.6732632423543772848170429916717
+
+         positive_mask = x > 0
+         self.assertTrue(jnp.all(output[positive_mask] == scale * x[positive_mask]))
+
+     # Test GLU
+     def test_glu_functionality(self):
+         """Test GLU activation function."""
+         layer = nn.GLU(dim=-1)
+
+         # GLU splits input in half along specified dimension
+         x = jnp.array([[1.0, 2.0, 3.0, 4.0],
+                        [5.0, 6.0, 7.0, 8.0]])
+         output = layer(x)
+
+         # Output should have half the size along the split dimension
+         self.assertEqual(output.shape, (2, 2))
+
+     def test_glu_different_dimensions(self):
+         """Test GLU with different split dimensions."""
+         # Test splitting along different dimensions
+         x = jax.random.normal(self.key, (4, 6, 8))
+
+         layer_0 = nn.GLU(dim=0)
+         output_0 = layer_0(x)
+         self.assertEqual(output_0.shape, (2, 6, 8))
+
+         layer_1 = nn.GLU(dim=1)
+         output_1 = layer_1(x)
+         self.assertEqual(output_1.shape, (4, 3, 8))
+
+         layer_2 = nn.GLU(dim=2)
+         output_2 = layer_2(x)
+         self.assertEqual(output_2.shape, (4, 6, 4))
+
+     # Test GELU
+     def test_gelu_functionality(self):
+         """Test GELU activation function."""
+         layer = nn.GELU(approximate=False)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # GELU should be smooth and differentiable everywhere
+         np.testing.assert_allclose(output[2], 0.0, rtol=1e-5)  # GELU(0) ≈ 0
+
+     def test_gelu_approximate(self):
+         """Test GELU with tanh approximation."""
+         layer_exact = nn.GELU(approximate=False)
+         layer_approx = nn.GELU(approximate=True)
+
+         x = jnp.array([-1.0, 0.0, 1.0])
+         output_exact = layer_exact(x)
+         output_approx = layer_approx(x)
+
+         # Approximation should be close but not exactly equal
+         np.testing.assert_allclose(output_exact, output_approx, rtol=1e-2)
+
+     # Test Hardshrink
+     def test_hardshrink_functionality(self):
+         """Test Hardshrink activation function."""
+         lambd = 0.5
+         layer = nn.Hardshrink(lambd=lambd)
+
+         x = jnp.array([-1.0, -0.6, -0.5, -0.3, 0.0, 0.3, 0.5, 0.6, 1.0])
+         output = layer(x)
+
+         # Check each value according to hardshrink formula
+         expected = []
+         for xi in x:
+             if xi > lambd:
+                 expected.append(xi)
+             elif xi < -lambd:
+                 expected.append(xi)
+             else:
+                 expected.append(0.0)
+         expected = jnp.array(expected)
+
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test LeakyReLU
+     def test_leaky_relu_functionality(self):
+         """Test LeakyReLU activation function."""
+         negative_slope = 0.01
+         layer = nn.LeakyReLU(negative_slope=negative_slope)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+         # Negative values should be scaled
+         negative_mask = x < 0
+         expected_negative = negative_slope * x[negative_mask]
+         np.testing.assert_allclose(output[negative_mask], expected_negative, rtol=1e-5)
+
+     def test_leaky_relu_custom_slope(self):
+         """Test LeakyReLU with custom negative slope."""
+         negative_slope = 0.2
+         layer = nn.LeakyReLU(negative_slope=negative_slope)
+
+         x = jnp.array([-5.0])
+         output = layer(x)
+         expected = negative_slope * x
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test LogSigmoid
+     def test_log_sigmoid_functionality(self):
+         """Test LogSigmoid activation function."""
+         layer = nn.LogSigmoid()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # LogSigmoid(x) = log(sigmoid(x))
+         expected = jnp.log(jax.nn.sigmoid(x))
+         np.testing.assert_allclose(output, expected, rtol=1e-2)
+
+         # Output should always be negative or zero
+         self.assertTrue(jnp.all(output <= 0.0))
+
+     # Test Softplus
+     def test_softplus_functionality(self):
+         """Test Softplus activation function."""
+         layer = nn.Softplus()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # Softplus is a smooth approximation to ReLU
+         # Should always be positive
+         self.assertTrue(jnp.all(output > 0.0))
+
+         # For large positive values, should approximate x
+         np.testing.assert_allclose(output[-1], x[-1], rtol=1e-2)
+
+     # Test Softshrink
+     def test_softshrink_functionality(self):
+         """Test Softshrink activation function."""
+         lambd = 0.5
+         layer = nn.Softshrink(lambd=lambd)
+
+         x = jnp.array([-1.0, -0.5, -0.3, 0.0, 0.3, 0.5, 1.0])
+         output = layer(x)
+
+         # Check the softshrink formula
+         for i in range(len(x)):
+             if x[i] > lambd:
+                 expected = x[i] - lambd
+             elif x[i] < -lambd:
+                 expected = x[i] + lambd
+             else:
+                 expected = 0.0
+             np.testing.assert_allclose(output[i], expected, rtol=1e-5)
+
+     # Test PReLU
+     def test_prelu_functionality(self):
+         """Test PReLU activation function."""
+         layer = nn.PReLU(num_parameters=1, init=0.25)
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Positive values should remain unchanged
+         self.assertTrue(jnp.all(output[x > 0] == x[x > 0]))
+         # Negative values should be scaled by learned parameter
+         negative_mask = x < 0
+         # Check that negative values are scaled
+         self.assertTrue(jnp.all(output[negative_mask] != x[negative_mask]))
+
+     def test_prelu_multi_channel(self):
+         """Test PReLU with multiple channels."""
+         num_channels = 3
+         layer = nn.PReLU(num_parameters=num_channels, init=0.25)
+
+         # Input shape: (batch, channels, height, width)
+         x = jax.random.normal(self.key, (2, 4, 4, num_channels))
+         output = layer(x)
+
+         self.assertEqual(output.shape, x.shape)
+
+     # Test Softsign
+     def test_softsign_functionality(self):
+         """Test Softsign activation function."""
+         layer = nn.Softsign()
+
+         x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
+         output = layer(x)
+
+         # Softsign(x) = x / (1 + |x|)
+         expected = x / (1 + jnp.abs(x))
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+         # Output should be bounded between -1 and 1
+         self.assertTrue(jnp.all((output >= -1.0) & (output <= 1.0)))
+
+     # Test Tanhshrink
+     def test_tanhshrink_functionality(self):
+         """Test Tanhshrink activation function."""
+         layer = nn.Tanhshrink()
+
+         x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+         output = layer(x)
+
+         # Tanhshrink(x) = x - tanh(x)
+         expected = x - jnp.tanh(x)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     # Test Softmin
+     def test_softmin_functionality(self):
+         """Test Softmin activation function."""
+         layer = nn.Softmin(dim=-1)
+
+         x = jnp.array([[1.0, 2.0, 3.0],
+                        [4.0, 5.0, 6.0]])
+         output = layer(x)
+
+         # Softmin should sum to 1 along the specified dimension
+         sums = jnp.sum(output, axis=-1)
+         np.testing.assert_allclose(sums, jnp.ones_like(sums), rtol=1e-5)
+
+         # Higher values should have lower probabilities
+         self.assertTrue(jnp.all(output[:, 0] > output[:, 1]))
+         self.assertTrue(jnp.all(output[:, 1] > output[:, 2]))
+
+     # Test Softmax
+     def test_softmax_functionality(self):
+         """Test Softmax activation function."""
+         layer = nn.Softmax(dim=-1)
+
+         x = jnp.array([[1.0, 2.0, 3.0],
+                        [4.0, 5.0, 6.0]])
+         output = layer(x)
+
+         # Softmax should sum to 1 along the specified dimension
+         sums = jnp.sum(output, axis=-1)
+         np.testing.assert_allclose(sums, jnp.ones_like(sums), rtol=1e-5)
+
+         # Higher values should have higher probabilities
+         self.assertTrue(jnp.all(output[:, 2] > output[:, 1]))
+         self.assertTrue(jnp.all(output[:, 1] > output[:, 0]))
+
+     def test_softmax_numerical_stability(self):
+         """Test Softmax numerical stability with large values."""
+         layer = nn.Softmax(dim=-1)
+
+         # Test with large values that could cause overflow
+         x = jnp.array([[1000.0, 1000.0, 1000.0]])
+         output = layer(x)
+
+         # Should still sum to 1 and have equal probabilities
+         np.testing.assert_allclose(jnp.sum(output), 1.0, rtol=1e-5)
+         np.testing.assert_allclose(output[0, 0], 1/3, rtol=1e-5)
+
+     # Test Softmax2d
+     def test_softmax2d_functionality(self):
+         """Test Softmax2d activation function."""
+         layer = nn.Softmax2d()
+
+         # Input shape: (batch, channels, height, width)
+         x = jax.random.normal(self.key, (2, 3, 4, 5))
+         output = layer(x)
+
+         self.assertEqual(output.shape, x.shape)
+
+         # Should sum to 1 across channels for each spatial location
+         channel_sums = jnp.sum(output, axis=1)
+         np.testing.assert_allclose(channel_sums, jnp.ones_like(channel_sums), rtol=1e-5)
+
+     def test_softmax2d_3d_input(self):
+         """Test Softmax2d with 3D input."""
+         layer = nn.Softmax2d()
+
+         # Input shape: (channels, height, width)
+         x = jax.random.normal(self.key, (3, 4, 5))
+         output = layer(x)
+
+         self.assertEqual(output.shape, x.shape)
+
+     # Test LogSoftmax
+     def test_log_softmax_functionality(self):
+         """Test LogSoftmax activation function."""
+         layer = nn.LogSoftmax(dim=-1)

+         x = jnp.array([[1.0, 2.0, 3.0],
+                        [4.0, 5.0, 6.0]])
+         output = layer(x)

- class Test_Activation(parameterized.TestCase):
-
-     def test_Threshold(self):
-         threshold_layer = brainstate.nn.Threshold(5, 20)
-         input = brainstate.random.randn(2)
-         output = threshold_layer(input)
-
-     def test_ReLU(self):
-         ReLU_layer = brainstate.nn.ReLU()
-         input = brainstate.random.randn(2)
-         output = ReLU_layer(input)
-
-     def test_RReLU(self):
-         RReLU_layer = brainstate.nn.RReLU(lower=0, upper=1)
-         input = brainstate.random.randn(2)
-         output = RReLU_layer(input)
-
-     def test_Hardtanh(self):
-         Hardtanh_layer = brainstate.nn.Hardtanh(min_val=0, max_val=1, )
-         input = brainstate.random.randn(2)
-         output = Hardtanh_layer(input)
-
-     def test_ReLU6(self):
-         ReLU6_layer = brainstate.nn.ReLU6()
-         input = brainstate.random.randn(2)
-         output = ReLU6_layer(input)
-
-     def test_Sigmoid(self):
-         Sigmoid_layer = brainstate.nn.Sigmoid()
-         input = brainstate.random.randn(2)
-         output = Sigmoid_layer(input)
-
-     def test_Hardsigmoid(self):
-         Hardsigmoid_layer = brainstate.nn.Hardsigmoid()
-         input = brainstate.random.randn(2)
-         output = Hardsigmoid_layer(input)
-
-     def test_Tanh(self):
-         Tanh_layer = brainstate.nn.Tanh()
-         input = brainstate.random.randn(2)
-         output = Tanh_layer(input)
-
-     def test_SiLU(self):
-         SiLU_layer = brainstate.nn.SiLU()
-         input = brainstate.random.randn(2)
-         output = SiLU_layer(input)
-
-     def test_Mish(self):
-         Mish_layer = brainstate.nn.Mish()
-         input = brainstate.random.randn(2)
-         output = Mish_layer(input)
-
-     def test_Hardswish(self):
-         Hardswish_layer = brainstate.nn.Hardswish()
-         input = brainstate.random.randn(2)
-         output = Hardswish_layer(input)
-
-     def test_ELU(self):
-         ELU_layer = brainstate.nn.ELU(alpha=0.5, )
-         input = brainstate.random.randn(2)
-         output = ELU_layer(input)
-
-     def test_CELU(self):
-         CELU_layer = brainstate.nn.CELU(alpha=0.5, )
-         input = brainstate.random.randn(2)
-         output = CELU_layer(input)
-
-     def test_SELU(self):
-         SELU_layer = brainstate.nn.SELU()
-         input = brainstate.random.randn(2)
-         output = SELU_layer(input)
-
-     def test_GLU(self):
-         GLU_layer = brainstate.nn.GLU()
-         input = brainstate.random.randn(4, 2)
-         output = GLU_layer(input)
-
-     @parameterized.product(
-         approximate=['tanh', 'none']
+         # LogSoftmax = log(softmax(x))
+         softmax_output = jax.nn.softmax(x, axis=-1)
+         expected = jnp.log(softmax_output)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+         # Output should be all negative or zero
+         self.assertTrue(jnp.all(output <= 0.0))
+
+     def test_log_softmax_numerical_stability(self):
+         """Test LogSoftmax numerical stability."""
+         layer = nn.LogSoftmax(dim=-1)
+
+         # Test with values that could cause numerical issues
+         x = jnp.array([[1000.0, 0.0, -1000.0]])
+         output = layer(x)
+
+         # Should not contain NaN or Inf
+         self.assertFalse(jnp.any(jnp.isnan(output)))
+         self.assertFalse(jnp.any(jnp.isinf(output)))
+
+     # Test Identity
+     def test_identity_functionality(self):
+         """Test Identity activation function."""
+         layer = nn.Identity()
+
+         x = jax.random.normal(self.key, (3, 4, 5))
+         output = layer(x)
+
+         # Should be exactly equal to input
+         np.testing.assert_array_equal(output, x)
+
+     def test_identity_gradient(self):
+         """Test Identity gradient flow."""
+         layer = nn.Identity()
+
+         x = jax.random.normal(self.key, (3, 4))
+
+         def loss_fn(x):
+             return jnp.sum(layer(x))
+
+         grad = jax.grad(loss_fn)(x)
+
+         # Gradient should be all ones
+         np.testing.assert_allclose(grad, jnp.ones_like(x), rtol=1e-5)
+
+     # Test SpikeBitwise
+     def test_spike_bitwise_add(self):
+         """Test SpikeBitwise with ADD operation."""
+         layer = nn.SpikeBitwise(op='and')
+
+         x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
+         y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
+         output = layer(x, y)
+
+         expected = jnp.logical_and(x, y)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     def test_spike_bitwise_and(self):
+         """Test SpikeBitwise with AND operation."""
+         layer = nn.SpikeBitwise(op='and')
+
+         x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
+         y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
+         output = layer(x, y)
+
+         expected = x * y
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     def test_spike_bitwise_iand(self):
+         """Test SpikeBitwise with IAND operation."""
+         layer = nn.SpikeBitwise(op='iand')
+
+         x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
+         y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
+         output = layer(x, y)
+
+         expected = (1 - x) * y
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+     def test_spike_bitwise_or(self):
+         """Test SpikeBitwise with OR operation."""
+         layer = nn.SpikeBitwise(op='or')
+
+         x = jnp.array([[1.0, 0.0], [1.0, 1.0]])
+         y = jnp.array([[1.0, 1.0], [0.0, 1.0]])
+         output = layer(x, y)
+
+         expected = (x + y) - (x * y)
+         np.testing.assert_allclose(output, expected, rtol=1e-5)
+
+
+ class TestEdgeCases(parameterized.TestCase):
+     """Test edge cases and boundary conditions."""
+
+     def test_zero_input(self):
+         """Test all activations with zero input."""
+         x = jnp.zeros((3, 4))
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+             nn.SiLU(),
+             nn.ELU(),
+             nn.GELU(),
+             nn.Softplus(),
+             nn.Softsign(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             self.assertEqual(output.shape, x.shape)
+             self.assertFalse(jnp.any(jnp.isnan(output)))
+
+     def test_large_positive_input(self):
+         """Test activations with very large positive values."""
+         x = jnp.ones((2, 3)) * 1000.0
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+             nn.Hardsigmoid(),
+             nn.Hardswish(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             self.assertFalse(jnp.any(jnp.isnan(output)))
+             self.assertFalse(jnp.any(jnp.isinf(output)))
+
+     def test_large_negative_input(self):
+         """Test activations with very large negative values."""
+         x = jnp.ones((2, 3)) * -1000.0
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+             nn.Hardsigmoid(),
+             nn.Hardswish(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             self.assertFalse(jnp.any(jnp.isnan(output)))
+             self.assertFalse(jnp.any(jnp.isinf(output)))
+
+     def test_nan_propagation(self):
+         """Test that NaN inputs produce NaN outputs (where appropriate)."""
+         x = jnp.array([jnp.nan, 1.0, 2.0])
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             self.assertTrue(jnp.isnan(output[0]))
+
+     def test_inf_handling(self):
+         """Test handling of infinite values."""
+         x = jnp.array([jnp.inf, -jnp.inf, 1.0])
+
+         # ReLU should handle inf properly
+         relu = nn.ReLU()
+         output = relu(x)
+         self.assertEqual(output[0], jnp.inf)
+         self.assertEqual(output[1], 0.0)
+
+         # Sigmoid should saturate
+         sigmoid = nn.Sigmoid()
+         output = sigmoid(x)
+         np.testing.assert_allclose(output[0], 1.0, rtol=1e-5)
+         np.testing.assert_allclose(output[1], 0.0, rtol=1e-5)
+
+
+ class TestBatchProcessing(parameterized.TestCase):
+     """Test batch processing capabilities."""
+
+     @parameterized.parameters(
+         (nn.ReLU(), ),
+         (nn.Sigmoid(), ),
+         (nn.Tanh(), ),
+         (nn.GELU(), ),
+         (nn.SiLU(), ),
+         (nn.ELU(), ),
      )
-     def test_GELU(self, approximate):
-         GELU_layer = brainstate.nn.GELU()
-         input = brainstate.random.randn(2)
-         output = GELU_layer(input)
-
-     def test_Hardshrink(self):
-         Hardshrink_layer = brainstate.nn.Hardshrink(lambd=1)
-         input = brainstate.random.randn(2)
-         output = Hardshrink_layer(input)
-
-     def test_LeakyReLU(self):
-         LeakyReLU_layer = brainstate.nn.LeakyReLU()
-         input = brainstate.random.randn(2)
-         output = LeakyReLU_layer(input)
-
-     def test_LogSigmoid(self):
-         LogSigmoid_layer = brainstate.nn.LogSigmoid()
-         input = brainstate.random.randn(2)
-         output = LogSigmoid_layer(input)
-
-     def test_Softplus(self):
-         Softplus_layer = brainstate.nn.Softplus()
-         input = brainstate.random.randn(2)
-         output = Softplus_layer(input)
-
-     def test_Softshrink(self):
-         Softshrink_layer = brainstate.nn.Softshrink(lambd=1)
-         input = brainstate.random.randn(2)
-         output = Softshrink_layer(input)
-
-     def test_PReLU(self):
-         PReLU_layer = brainstate.nn.PReLU(num_parameters=2, init=0.5)
-         input = brainstate.random.randn(2)
-         output = PReLU_layer(input)
-
-     def test_Softsign(self):
-         Softsign_layer = brainstate.nn.Softsign()
-         input = brainstate.random.randn(2)
-         output = Softsign_layer(input)
-
-     def test_Tanhshrink(self):
-         Tanhshrink_layer = brainstate.nn.Tanhshrink()
-         input = brainstate.random.randn(2)
-         output = Tanhshrink_layer(input)
-
-     def test_Softmin(self):
-         Softmin_layer = brainstate.nn.Softmin(dim=2)
-         input = brainstate.random.randn(2, 3, 4)
-         output = Softmin_layer(input)
-
-     def test_Softmax(self):
-         Softmax_layer = brainstate.nn.Softmax(dim=2)
-         input = brainstate.random.randn(2, 3, 4)
-         output = Softmax_layer(input)
-
-     def test_Softmax2d(self):
-         Softmax2d_layer = brainstate.nn.Softmax2d()
-         input = brainstate.random.randn(2, 3, 12, 13)
-         output = Softmax2d_layer(input)
-
-     def test_LogSoftmax(self):
-         LogSoftmax_layer = brainstate.nn.LogSoftmax(dim=2)
-         input = brainstate.random.randn(2, 3, 4)
-         output = LogSoftmax_layer(input)
+     def test_batch_consistency(self, activation):
+         """Test that batch processing gives same results as individual processing."""
+         # Process as batch
+         batch_input = jax.random.normal(jax.random.PRNGKey(42), (5, 10))
+         batch_output = activation(batch_input)
+
+         # Process individually
+         individual_outputs = []
+         for i in range(5):
+             individual_output = activation(batch_input[i])
+             individual_outputs.append(individual_output)
+         individual_outputs = jnp.stack(individual_outputs)
+
+         np.testing.assert_allclose(batch_output, individual_outputs, rtol=1e-5)
+
+     def test_different_batch_sizes(self):
+         """Test activations with different batch sizes."""
+         activation = nn.ReLU()
+
+         for batch_size in [1, 10, 100]:
+             x = jax.random.normal(jax.random.PRNGKey(42), (batch_size, 20))
+             output = activation(x)
+             self.assertEqual(output.shape[0], batch_size)
+
+
+ class TestMemoryAndPerformance(parameterized.TestCase):
+     """Test memory and performance characteristics."""
+
+     def test_in_place_operations(self):
+         """Test that activations don't modify input in-place."""
+         x_original = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
+         x = x_original.copy()
+
+         activations = [
+             nn.ReLU(),
+             nn.Sigmoid(),
+             nn.Tanh(),
+         ]
+
+         for activation in activations:
+             output = activation(x)
+             np.testing.assert_array_equal(x, x_original)
+
+     def test_jit_compilation(self):
+         """Test that activations work with JIT compilation."""
+         @jax.jit
+         def forward(x):
+             relu = nn.ReLU()
+             return relu(x)
+
+         x = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
+         output = forward(x)
+
+         # Should not raise any errors and produce valid output
+         self.assertEqual(output.shape, x.shape)
+
+     @parameterized.parameters(
+         (nn.ReLU(), ),
+         (nn.Sigmoid(), ),
+         (nn.Tanh(), ),
+     )
+     def test_vmap_compatibility(self, activation):
+         """Test compatibility with vmap."""
+         def single_forward(x):
+             return activation(x)
+
+         batch_forward = jax.vmap(single_forward)
+
+         x = jax.random.normal(jax.random.PRNGKey(42), (5, 10, 20))
+         output = batch_forward(x)
+
+         self.assertEqual(output.shape, x.shape)


  if __name__ == '__main__':
-     absltest.main()
+     absltest.main()
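
For orientation, the rewritten test module exercises each activation layer as a plain callable on JAX arrays and still ends in absltest.main(), so it can be run directly (for example via python brainstate/nn/_activations_test.py or through pytest). A minimal usage sketch consistent with the tests above, with the module and class names taken from the diff and everything else assumed:

import jax.numpy as jnp
import brainstate.nn as nn

layer = nn.ReLU()                                     # layer class exercised in test_relu_functionality
out = layer(jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0]))
print(out)                                            # per the test expectation: [0. 0. 0. 1. 2.]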