brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2319 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ import unittest
20
+ import warnings
21
+
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+
25
+ import brainstate
26
+ from brainstate._deprecation import DeprecatedModule, create_deprecated_module_proxy
27
+
28
+
29
class TestDeprecatedAugmentModule(unittest.TestCase):
    """Test the deprecated brainstate.augment module."""

    def setUp(self):
        """Reset warning filters before each test and restore them afterwards.

        ``warnings.resetwarnings()`` clears the process-wide filter list. Doing
        the reset inside a ``catch_warnings`` context whose exit is registered
        with ``addCleanup`` keeps it from leaking into unrelated tests.
        """
        warning_ctx = warnings.catch_warnings()
        warning_ctx.__enter__()
        self.addCleanup(warning_ctx.__exit__, None, None, None)
        warnings.resetwarnings()

    def test_augment_module_attributes(self):
        """Test that augment module has correct attributes."""
        # Test module attributes
        self.assertEqual(brainstate.augment.__name__, 'brainstate.augment')
        self.assertIn('deprecated', brainstate.augment.__doc__.lower())
        self.assertTrue(hasattr(brainstate.augment, '__all__'))

        # Test repr
        repr_str = repr(brainstate.augment)
        self.assertIn('DeprecatedModule', repr_str)
        self.assertIn('brainstate.augment', repr_str)
        self.assertIn('brainstate.transform', repr_str)

    def test_augment_scoped_apis(self):
        """Test that augment module only exposes scoped APIs."""
        # Check that expected APIs are available
        expected_apis = [
            'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian',
            'jacrev', 'jacfwd', 'abstract_init', 'vmap', 'pmap', 'map',
            'vmap_new_states', 'restore_rngs',
        ]

        for api in expected_apis:
            with self.subTest(api=api):
                self.assertIn(api, brainstate.augment.__all__)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    self.assertTrue(hasattr(brainstate.augment, api),
                                    f"API '{api}' should be available in augment module")

        # Check that __all__ contains only expected APIs
        self.assertEqual(set(brainstate.augment.__all__), set(expected_apis))

    def test_augment_deprecation_warnings(self):
        """Test that augment module shows deprecation warnings."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Access different attributes
            _ = brainstate.augment.grad
            _ = brainstate.augment.vmap
            _ = brainstate.augment.vector_grad

        # The number of warnings is deliberately not asserted:
        # ``brainstate.augment`` is a shared proxy and presumably suppresses
        # repeat warnings for names already accessed by earlier tests. Every
        # warning that *is* emitted must still carry the expected content.
        for warning in w:
            self.assertEqual(warning.category, DeprecationWarning)
            msg = str(warning.message)
            self.assertIn('brainstate.augment', msg)
            self.assertIn('deprecated', msg)
            self.assertIn('brainstate.transform', msg)

    def test_augment_no_duplicate_warnings(self):
        """Test that repeated access doesn't generate duplicate warnings."""
        with warnings.catch_warnings(record=True) as w:
            # "always" disables Python's own once-per-location registry, so any
            # de-duplication observed here must come from the deprecated module.
            warnings.simplefilter("always")

            # Access the same attribute multiple times
            _ = brainstate.augment.grad
            _ = brainstate.augment.grad
            _ = brainstate.augment.grad

        # At most one warning. It may be zero if an earlier test already
        # triggered the warning for ``grad`` on this shared proxy.
        deprecations = [x for x in w if issubclass(x.category, DeprecationWarning)]
        self.assertLessEqual(len(deprecations), 1)

    def test_augment_functionality_forwarding(self):
        """Test that augment module forwards functionality correctly."""
        with warnings.catch_warnings():
            # Forwarding is under test here, not the warnings themselves.
            warnings.simplefilter("ignore")

            # Test that functions are properly forwarded
            self.assertTrue(callable(brainstate.augment.grad))
            self.assertTrue(callable(brainstate.augment.vmap))
            self.assertTrue(callable(brainstate.augment.vector_grad))

            # Test that they are the same as transform module
            self.assertIs(brainstate.augment.grad, brainstate.transform.grad)
            self.assertIs(brainstate.augment.vmap, brainstate.transform.vmap)

    def test_augment_grad_functionality(self):
        """Test that grad function works through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # Ignore deprecation warnings for this test

            # Create a simple state and function
            state = brainstate.State(jnp.array([1.0, 2.0]))

            def loss_fn():
                return jnp.sum(state.value ** 2)

            # Test grad function
            grad_fn = brainstate.augment.grad(loss_fn, state)
            grads = grad_fn()

            # d/dx sum(x**2) == 2 * x
            expected = 2 * state.value
            np.testing.assert_array_almost_equal(grads, expected)

    def test_augment_dir_functionality(self):
        """Test that dir() works on augment module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            attrs = dir(brainstate.augment)

            # Should contain expected attributes
            self.assertIn('grad', attrs)
            self.assertIn('vmap', attrs)
            self.assertIn('vector_grad', attrs)

    def test_augment_missing_attribute_error(self):
        """Test that accessing non-existent attributes raises appropriate error."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            with self.assertRaises(AttributeError) as context:
                _ = brainstate.augment.nonexistent_function

            error_msg = str(context.exception)
            self.assertIn('brainstate.augment', error_msg)
            self.assertIn('nonexistent_function', error_msg)
            self.assertIn('brainstate.transform', error_msg)
155
+
156
+
157
class TestDeprecatedCompileModule(unittest.TestCase):
    """Test the deprecated brainstate.compile module."""

    def setUp(self):
        """Reset warning filters before each test and restore them afterwards.

        ``warnings.resetwarnings()`` clears the process-wide filter list. Doing
        the reset inside a ``catch_warnings`` context whose exit is registered
        with ``addCleanup`` keeps it from leaking into unrelated tests.
        """
        warning_ctx = warnings.catch_warnings()
        warning_ctx.__enter__()
        self.addCleanup(warning_ctx.__exit__, None, None, None)
        warnings.resetwarnings()

    def test_compile_module_attributes(self):
        """Test that compile module has correct attributes."""
        self.assertEqual(brainstate.compile.__name__, 'brainstate.compile')
        self.assertIn('deprecated', brainstate.compile.__doc__.lower())
        self.assertTrue(hasattr(brainstate.compile, '__all__'))

    def test_compile_scoped_apis(self):
        """Test that compile module only exposes scoped APIs."""
        expected_apis = [
            'checkpoint', 'remat', 'cond', 'switch', 'ifelse', 'jit_error_if',
            'jit', 'scan', 'checkpointed_scan', 'for_loop', 'checkpointed_for_loop',
            'while_loop', 'bounded_while_loop', 'StatefulFunction', 'make_jaxpr',
            'ProgressBar',
        ]

        for api in expected_apis:
            with self.subTest(api=api):
                self.assertIn(api, brainstate.compile.__all__)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    self.assertTrue(hasattr(brainstate.compile, api),
                                    f"API '{api}' should be available in compile module")

        # Check that __all__ contains only expected APIs
        self.assertEqual(set(brainstate.compile.__all__), set(expected_apis))

    def test_compile_deprecation_warnings(self):
        """Test that compile module shows deprecation warnings."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Access different attributes
            _ = brainstate.compile.jit
            _ = brainstate.compile.for_loop
            _ = brainstate.compile.while_loop

        # The number of warnings is deliberately not asserted:
        # ``brainstate.compile`` is a shared proxy and presumably suppresses
        # repeat warnings for names already accessed by earlier tests.
        for warning in w:
            self.assertEqual(warning.category, DeprecationWarning)
            msg = str(warning.message)
            self.assertIn('brainstate.compile', msg)
            self.assertIn('brainstate.transform', msg)

    def test_compile_functionality_forwarding(self):
        """Test that compile module forwards functionality correctly."""
        with warnings.catch_warnings():
            # Forwarding is under test here, not the warnings themselves.
            warnings.simplefilter("ignore")

            # Test that functions are properly forwarded
            self.assertTrue(callable(brainstate.compile.jit))
            self.assertTrue(callable(brainstate.compile.for_loop))
            self.assertTrue(callable(brainstate.compile.while_loop))

            # Test that they are the same as transform module
            self.assertIs(brainstate.compile.jit, brainstate.transform.jit)
            self.assertIs(brainstate.compile.for_loop, brainstate.transform.for_loop)

    def test_compile_jit_functionality(self):
        """Test that jit function works through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            state = brainstate.State(5.0)

            @brainstate.compile.jit
            def add_one():
                state.value += 1.0
                return state.value

            result = add_one()
            self.assertEqual(result, 6.0)
            self.assertEqual(state.value, 6.0)

    def test_compile_for_loop_functionality(self):
        """Test that for_loop function works through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            counter = brainstate.State(0.0)

            def body(i):
                counter.value += 1.0

            brainstate.compile.for_loop(body, jnp.arange(5))
            self.assertEqual(counter.value, 5.0)
248
+ self.assertEqual(counter.value, 5.0)
249
+
250
+
251
class TestDeprecatedFunctionalModule(unittest.TestCase):
    """Test the deprecated brainstate.functional module."""

    def setUp(self):
        """Reset warning filters before each test and restore them afterwards.

        ``warnings.resetwarnings()`` clears the process-wide filter list. Doing
        the reset inside a ``catch_warnings`` context whose exit is registered
        with ``addCleanup`` keeps it from leaking into unrelated tests.
        """
        warning_ctx = warnings.catch_warnings()
        warning_ctx.__enter__()
        self.addCleanup(warning_ctx.__exit__, None, None, None)
        warnings.resetwarnings()

    def test_functional_module_attributes(self):
        """Test that functional module has correct attributes."""
        self.assertEqual(brainstate.functional.__name__, 'brainstate.functional')
        self.assertIn('deprecated', brainstate.functional.__doc__.lower())
        self.assertTrue(hasattr(brainstate.functional, '__all__'))

    def test_functional_scoped_apis(self):
        """Test that functional module exposes the expected scoped APIs."""
        expected_apis = [
            'weight_standardization', 'clip_grad_norm',
            # Activation functions
            'tanh', 'relu', 'squareplus', 'softplus', 'soft_sign', 'sigmoid',
            'silu', 'swish', 'log_sigmoid', 'elu', 'leaky_relu', 'hard_tanh',
            'celu', 'selu', 'gelu', 'glu', 'logsumexp', 'log_softmax',
            'softmax', 'standardize',
        ]

        # Only membership is checked: ``__all__`` may expose more names than
        # this list, so exact set equality is not asserted.
        for api in expected_apis:
            with self.subTest(api=api):
                self.assertIn(api, brainstate.functional.__all__)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    self.assertTrue(hasattr(brainstate.functional, api),
                                    f"API '{api}' should be available in functional module")

    def test_functional_deprecation_warnings(self):
        """Test that functional module shows deprecation warnings."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Access different attributes
            _ = brainstate.functional.relu
            _ = brainstate.functional.sigmoid
            _ = brainstate.functional.tanh

        # The number of warnings is deliberately not asserted:
        # ``brainstate.functional`` is a shared proxy and presumably suppresses
        # repeat warnings for names already accessed by earlier tests.
        for warning in w:
            self.assertEqual(warning.category, DeprecationWarning)
            msg = str(warning.message)
            self.assertIn('brainstate.functional', msg)
            self.assertIn('brainstate.nn', msg)

    def test_functional_functionality_forwarding(self):
        """Test that functional module forwards functionality correctly."""
        with warnings.catch_warnings():
            # Forwarding is under test here, not the warnings themselves.
            warnings.simplefilter("ignore")

            # Only callability is checked; identity with ``brainstate.nn``
            # attributes is not asserted.
            self.assertTrue(callable(brainstate.functional.relu))
            self.assertTrue(callable(brainstate.functional.sigmoid))
            self.assertTrue(callable(brainstate.functional.tanh))

    def test_functional_activation_functions(self):
        """Test that activation functions work through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Test relu
            x = jnp.array([-1.0, 0.0, 1.0])
            result = brainstate.functional.relu(x)
            expected = jnp.array([0.0, 0.0, 1.0])
            np.testing.assert_array_almost_equal(result, expected)

            # Test sigmoid
            x = jnp.array([0.0])
            result = brainstate.functional.sigmoid(x)
            expected = jnp.array([0.5])
            np.testing.assert_array_almost_equal(result, expected, decimal=5)

            # Test tanh
            x = jnp.array([0.0])
            result = brainstate.functional.tanh(x)
            expected = jnp.array([0.0])
            np.testing.assert_array_almost_equal(result, expected)

    def test_functional_weight_standardization(self):
        """Test that weight_standardization works through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Create a simple weight matrix
            weights = jnp.ones((3, 3))

            # ``weight_standardization`` is part of the scoped API, so its absence
            # is a failure. It must not be a silent skip.
            self.assertTrue(hasattr(brainstate.functional, 'weight_standardization'))
            standardized = brainstate.functional.weight_standardization(weights)
            self.assertEqual(standardized.shape, weights.shape)
352
+
353
+
354
class TestDeprecatedModulesIntegration(unittest.TestCase):
    """Cross-cutting checks that involve every deprecated module proxy."""

    def test_all_deprecated_modules_in_brainstate(self):
        """Each deprecated proxy is reachable as an attribute of ``brainstate``."""
        for module_name in ('augment', 'compile', 'functional'):
            self.assertTrue(hasattr(brainstate, module_name))

    def test_deprecated_modules_in_all(self):
        """Each deprecated proxy is listed in ``brainstate.__all__``."""
        exported = brainstate.__all__
        for module_name in ('augment', 'compile', 'functional'):
            self.assertIn(module_name, exported)

    def test_mixed_usage_compatibility(self):
        """APIs taken from different deprecated proxies compose with each other."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            param = brainstate.State(jnp.array([1.0, 2.0]))

            def objective():
                # ``relu`` comes from the deprecated ``functional`` proxy.
                activated = brainstate.functional.relu(param.value)
                return jnp.sum(activated ** 2)

            # ``grad`` comes from the deprecated ``augment`` proxy.
            gradients = brainstate.augment.grad(objective, param)()

            self.assertIsInstance(gradients, jnp.ndarray)
            self.assertEqual(gradients.shape, (2,))

    def test_warning_stacklevel(self):
        """Smoke test: a deprecated attribute can be read under the "always" filter.

        TODO(review): the intended check was that the recorded warning's
        ``filename`` points at this test file, which would confirm the proxy
        passes a correct ``stacklevel``. That assertion is currently disabled.
        Presumably this is because the shared proxy may already have warned
        for ``grad`` in an earlier test, leaving nothing to inspect.
        """
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            _ = brainstate.augment.grad
403
+
404
+
405
class TestScopedAPIRestrictions(unittest.TestCase):
    """Test that scoped APIs properly restrict access to non-scoped functions."""

    def test_augment_blocks_non_scoped_apis(self):
        """Test that augment module blocks access to APIs not in its scope."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # These should work (scoped APIs)
            self.assertTrue(hasattr(brainstate.augment, 'grad'))
            self.assertTrue(hasattr(brainstate.augment, 'vmap'))

            # An unknown name must raise an AttributeError that names the
            # module and lists the attributes that *are* available.
            with self.assertRaises(AttributeError) as ctx:
                _ = brainstate.augment.nonexistent_function
            error_msg = str(ctx.exception)
            self.assertIn('Available attributes:', error_msg)
            self.assertIn('brainstate.augment', error_msg)

    def test_compile_blocks_non_scoped_apis(self):
        """Test that compile module blocks access to APIs not in its scope."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # These should work (scoped APIs)
            self.assertTrue(hasattr(brainstate.compile, 'jit'))
            self.assertTrue(hasattr(brainstate.compile, 'for_loop'))

            # An unknown name must raise an AttributeError that names the
            # module and lists the attributes that *are* available.
            with self.assertRaises(AttributeError) as ctx:
                _ = brainstate.compile.nonexistent_function
            error_msg = str(ctx.exception)
            self.assertIn('Available attributes:', error_msg)
            self.assertIn('brainstate.compile', error_msg)

    def test_functional_blocks_non_scoped_apis(self):
        """Test that functional module blocks access to APIs not in its scope."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # These should work (scoped APIs)
            self.assertTrue(hasattr(brainstate.functional, 'relu'))
            self.assertTrue(hasattr(brainstate.functional, 'sigmoid'))

            # An unknown name must raise an AttributeError that names the
            # module and lists the attributes that *are* available.
            with self.assertRaises(AttributeError) as ctx:
                _ = brainstate.functional.nonexistent_function
            error_msg = str(ctx.exception)
            self.assertIn('Available attributes:', error_msg)
            self.assertIn('brainstate.functional', error_msg)
458
+
459
+
460
class TestDeprecationSystemRobustness(unittest.TestCase):
    """Test edge cases and robustness of the deprecation system."""

    def test_nested_attribute_access(self):
        """Test that a forwarded attribute resolves to a usable callable."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # ``grad`` is a core transform API. A missing name is a failure,
            # not a reason to skip the check silently.
            self.assertTrue(hasattr(brainstate.transform, 'grad'))
            grad_func = brainstate.augment.grad
            self.assertTrue(callable(grad_func))

    def test_module_import_style_access(self):
        """Test different styles of accessing deprecated modules."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Direct access
            func1 = brainstate.augment.grad

            # Module-style access
            augment_module = brainstate.augment
            func2 = augment_module.grad

            # Should be the same function
            self.assertIs(func1, func2)

    def test_help_and_documentation(self):
        """Test that documentation is available on deprecated modules."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # No try/except wrapper: an unexpected exception already fails the
            # test with its full traceback. Catching ``Exception`` here would
            # also swallow and re-wrap the assertions' own ``AssertionError``.
            help_text = brainstate.augment.__doc__
            self.assertIsInstance(help_text, str)
            self.assertIn('deprecated', help_text.lower())

    def test_multiple_import_styles(self):
        """Test that different import styles work with deprecation."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Test that we can still access through different paths
            from brainstate import augment as aug
            from brainstate import functional as func

            self.assertTrue(callable(aug.grad))
            self.assertTrue(callable(func.relu))
512
+
513
+
514
class MockReplacementModule:
    """Stand-in "replacement module" used to exercise ``DeprecatedModule``.

    It exposes one attribute of each kind that a proxy has to forward: a
    plain value (``test_variable``), a function (``test_function``) and a
    class (``test_class``).
    """

    test_variable = 42

    @staticmethod
    def test_function(x):
        """Return ``x`` multiplied by two."""
        return x * 2

    class test_class:
        """Minimal class that stores the ``value`` it was constructed with."""

        def __init__(self, value):
            self.value = value
526
+
527
+
528
class TestDeprecatedModule(unittest.TestCase):
    """Unit tests for the ``DeprecatedModule`` proxy class itself."""

    def setUp(self):
        """Wrap a fresh mock replacement in a new proxy for every test."""
        self.mock_module = MockReplacementModule()
        self.deprecated = DeprecatedModule(
            deprecated_name='test.deprecated',
            replacement_module=self.mock_module,
            replacement_name='test.replacement',
            version='1.0.0',
            removal_version='2.0.0',
        )

    def test_initialization(self):
        """The proxy reports its deprecated name and documents the deprecation."""
        proxy = self.deprecated
        self.assertEqual(proxy.__name__, 'test.deprecated')
        doc = proxy.__doc__
        for fragment in ('DEPRECATED', 'test.deprecated', 'test.replacement'):
            self.assertIn(fragment, doc)

    def test_repr(self):
        """The repr names the proxy class and both module names."""
        text = repr(self.deprecated)
        for fragment in ('DeprecatedModule', 'test.deprecated', 'test.replacement'):
            self.assertIn(fragment, text)

    def test_attribute_forwarding(self):
        """Functions, values and classes are all forwarded to the replacement."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            proxy = self.deprecated

            # function
            self.assertEqual(proxy.test_function(5), 10)

            # plain value
            self.assertEqual(proxy.test_variable, 42)

            # class
            obj = proxy.test_class(100)
            self.assertEqual(obj.value, 100)

    def test_deprecation_warnings(self):
        """Reading each distinct attribute emits one DeprecationWarning."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")

            for attr_name in ('test_function', 'test_variable', 'test_class'):
                getattr(self.deprecated, attr_name)

            self.assertEqual(len(caught), 3)

            for record in caught:
                self.assertEqual(record.category, DeprecationWarning)
                text = str(record.message)
                self.assertIn('test.deprecated', text)
                self.assertIn('test.replacement', text)
                self.assertIn('deprecated', text.lower())

    def test_no_duplicate_warnings(self):
        """Repeated reads of one attribute produce only a single warning."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")

            for _ in range(3):
                getattr(self.deprecated, 'test_function')

            self.assertEqual(len(caught), 1)

    def test_warning_with_removal_version(self):
        """The warning text mentions the configured removal version."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")

            getattr(self.deprecated, 'test_function')

            self.assertEqual(len(caught), 1)
            self.assertIn('2.0.0', str(caught[0].message))

    def test_missing_attribute_error(self):
        """Unknown names raise an AttributeError that mentions all the names involved."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            with self.assertRaises(AttributeError) as ctx:
                getattr(self.deprecated, 'nonexistent_attribute')

            text = str(ctx.exception)
            for fragment in ('test.deprecated', 'nonexistent_attribute', 'test.replacement'):
                self.assertIn(fragment, text)

    def test_dir_functionality(self):
        """``dir()`` warns and lists the replacement's attributes."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")

            listing = dir(self.deprecated)

            # dir() access is expected to warn at least once
            self.assertGreaterEqual(len(caught), 1)

            for attr_name in ('test_function', 'test_variable', 'test_class'):
                self.assertIn(attr_name, listing)

    def test_module_without_all_attribute(self):
        """A replacement lacking ``__all__`` yields a proxy lacking ``__all__``."""

        class _ReplacementWithoutAll:
            def some_function(self):
                return "test"

        proxy = DeprecatedModule(
            deprecated_name='test.no_all',
            replacement_module=_ReplacementWithoutAll(),
            replacement_name='test.replacement'
        )

        # no __all__ should be synthesized
        self.assertFalse(hasattr(proxy, '__all__'))

        # forwarding still works
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.assertTrue(hasattr(proxy, 'some_function'))
666
+
667
+
668
class TestCreateDeprecatedModuleProxy(unittest.TestCase):
    """Tests for the ``create_deprecated_module_proxy`` factory."""

    def test_create_proxy_function(self):
        """The factory returns a working DeprecatedModule carrying the requested name."""
        replacement = MockReplacementModule()

        proxy = create_deprecated_module_proxy(
            deprecated_name='test.proxy',
            replacement_module=replacement,
            replacement_name='test.new_module',
            version='1.0.0',
        )

        self.assertIsInstance(proxy, DeprecatedModule)
        self.assertEqual(proxy.__name__, 'test.proxy')

        # Calls through the proxy reach the replacement implementation.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.assertEqual(proxy.test_function(10), 20)

    def test_proxy_with_kwargs(self):
        """Extra keyword arguments (here ``removal_version``) are passed through."""
        proxy = create_deprecated_module_proxy(
            deprecated_name='test.kwargs',
            replacement_module=MockReplacementModule(),
            replacement_name='test.new',
            removal_version='3.0.0',
        )

        # The removal version must surface in the emitted warning.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            _ = proxy.test_function

            self.assertEqual(len(caught), 1)
            self.assertIn('3.0.0', str(caught[0].message))
709
+
710
+
711
class TestDeprecationEdgeCases(unittest.TestCase):
    """Edge cases and unusual usage patterns of DeprecatedModule."""

    @staticmethod
    def _build_proxy(deprecated_name):
        """Return ``(replacement, proxy)`` around a fresh MockReplacementModule."""
        replacement = MockReplacementModule()
        proxy = DeprecatedModule(
            deprecated_name=deprecated_name,
            replacement_module=replacement,
            replacement_name='test.replacement',
        )
        return replacement, proxy

    def test_circular_reference_handling(self):
        """A replacement that points back at its own proxy does not break forwarding."""
        replacement, proxy = self._build_proxy('test.circular')

        # Close the loop: replacement -> proxy -> replacement.
        replacement.circular_ref = proxy

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.assertEqual(proxy.test_function(5), 10)

    def test_complex_attribute_access_patterns(self):
        """Forwarded callables work via dotted access and via ``getattr``."""
        _, proxy = self._build_proxy('test.complex')

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            bound = proxy.test_function
            self.assertEqual(bound(7), 14)

            looked_up = getattr(proxy, 'test_function')
            self.assertEqual(looked_up(8), 16)

    def test_stacklevel_accuracy(self):
        """The warning is attributed to this test file, not to deprecation internals."""
        _, proxy = self._build_proxy('test.stack')

        def fetch_through_helper():
            return proxy.test_function

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")

            _ = fetch_through_helper()

            self.assertEqual(len(caught), 1)
            self.assertIn('_deprecation_test.py', caught[0].filename)
776
+
777
+
778
class TestDeprecatedModuleInitialization(unittest.TestCase):
    """Construction-time behaviour of DeprecatedModule."""

    def test_deprecated_module_initialization_minimal_parameters(self):
        """Only the three required arguments are given; defaults fill in the rest."""
        replacement = MockReplacementModule()
        proxy = DeprecatedModule(
            deprecated_name='test.minimal',
            replacement_module=replacement,
            replacement_name='test.replacement_min',
        )

        # Required fields are stored exactly as passed.
        self.assertEqual(proxy.__name__, 'test.minimal')
        self.assertEqual(proxy._deprecated_name, 'test.minimal')
        self.assertEqual(proxy._replacement_module, replacement)
        self.assertEqual(proxy._replacement_name, 'test.replacement_min')

        # Optional fields: version falls back to its default, removal_version to None.
        self.assertEqual(proxy._version, "0.1.11")
        self.assertIsNone(proxy._removal_version)

        # A docstring is still generated from the names alone.
        for fragment in ('DEPRECATED', 'test.minimal', 'test.replacement_min'):
            self.assertIn(fragment, proxy.__doc__)

    def test_deprecated_module_with_empty_replacement_module(self):
        """A replacement object without any attributes is handled gracefully."""

        class EmptyModule:
            pass

        proxy = DeprecatedModule(
            deprecated_name='test.empty',
            replacement_module=EmptyModule(),
            replacement_name='test.empty_replacement',
        )

        self.assertEqual(proxy.__name__, 'test.empty')
        self.assertFalse(hasattr(proxy, '__all__'))

        # Any lookup must fail with AttributeError rather than some other exception.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with self.assertRaises(AttributeError):
                _ = proxy.nonexistent

    def test_deprecated_module_initialization_with_callable_replacement(self):
        """Static methods, class methods and plain class attributes are all forwarded."""

        class CallableModule:
            @staticmethod
            def func1():
                return "result1"

            @classmethod
            def func2(cls):
                return "result2"

            var1 = "variable1"

        proxy = DeprecatedModule(
            deprecated_name='test.callable',
            replacement_module=CallableModule(),
            replacement_name='test.callable_replacement',
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            self.assertEqual(proxy.func1(), "result1")
            self.assertEqual(proxy.func2(), "result2")
            self.assertEqual(proxy.var1, "variable1")
857
+
858
+
859
class TestScopedAPIStringImports(unittest.TestCase):
    """Scoped (deprecated) APIs reached through string-based lookups."""

    def test_scoped_api_string_based_attribute_access(self):
        """Every name in ``brainstate.augment.__all__`` resolves through ``getattr``."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            augment = brainstate.augment
            for api_name in augment.__all__:
                with self.subTest(api_name=api_name):
                    looked_up = getattr(augment, api_name, None)
                    self.assertIsNotNone(looked_up, f"API '{api_name}' should be accessible via getattr")

                    # A second lookup must yield the very same object.
                    self.assertIs(looked_up, getattr(augment, api_name))

    def test_scoped_api_dynamic_import_patterns(self):
        """Dynamically looked-up APIs are callables identical to their transform counterparts."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for api_name in ('grad', 'vmap', 'vector_grad'):
                with self.subTest(api_name=api_name):
                    if not hasattr(brainstate.augment, api_name):
                        continue

                    func = getattr(brainstate.augment, api_name)
                    self.assertTrue(callable(func))

                    # Must be the exact object exported by the replacement module.
                    if hasattr(brainstate.transform, api_name):
                        self.assertIs(func, getattr(brainstate.transform, api_name))

    def test_scoped_api_list_comprehension_access(self):
        """Bulk collection of callable APIs finds at least one, and all are callable."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            augment = brainstate.augment
            found = []
            for name in augment.__all__:
                if callable(getattr(augment, name, None)):
                    found.append(getattr(augment, name))

            self.assertGreater(len(found), 0)

            for func in found:
                self.assertTrue(callable(func))

    def test_scoped_api_introspection(self):
        """``brainstate.augment.grad`` keeps its function metadata."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            if not hasattr(brainstate.augment, 'grad'):
                return
            grad_func = brainstate.augment.grad

            for dunder in ('__name__', '__doc__', '__module__'):
                self.assertTrue(hasattr(grad_func, dunder))

            # The name survives forwarding through the proxy.
            self.assertEqual(grad_func.__name__, 'grad')

    def test_scoped_api_with_string_module_names(self):
        """Deprecated sub-modules looked up by name expose ``__all__`` and its APIs."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for module_name in ('augment', 'compile', 'functional'):
                with self.subTest(module_name=module_name):
                    module = getattr(brainstate, module_name, None)
                    self.assertIsNotNone(module)
                    self.assertTrue(hasattr(module, '__all__'))

                    # Whatever is listed and present must resolve to a real object.
                    for api_name in getattr(module, '__all__', []):
                        if not hasattr(module, api_name):
                            continue
                        self.assertIsNotNone(getattr(module, api_name))
953
+
954
+
955
class TestDeprecationErrorHandlingAndFallbacks(unittest.TestCase):
    """Test error handling and fallback mechanisms in deprecation system."""

    def test_invalid_attribute_access_error_messages(self):
        """Test that invalid attribute access provides helpful error messages."""
        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.invalid_attr',
            replacement_module=mock_module,
            replacement_name='test.replacement_invalid'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            with self.assertRaises(AttributeError) as context:
                _ = deprecated.completely_nonexistent_function

            error_msg = str(context.exception)

            # The message names both the deprecated module and the missing attribute.
            self.assertIn('test.invalid_attr', error_msg)
            self.assertIn('completely_nonexistent_function', error_msg)

    def test_fallback_when_replacement_module_lacks_attribute(self):
        """Existing attributes are forwarded; missing ones raise AttributeError."""

        class IncompleteModule:
            def existing_func(self):
                return "exists"

        incomplete_module = IncompleteModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.incomplete',
            replacement_module=incomplete_module,
            replacement_name='test.incomplete_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Should work for existing function
            result = deprecated.existing_func()
            self.assertEqual(result, "exists")

            # Should raise AttributeError for missing function
            with self.assertRaises(AttributeError):
                _ = deprecated.missing_func

    def test_exception_handling_during_warning_generation(self):
        """Forwarding works for a minimal replacement while warnings are suppressed.

        NOTE(review): despite the test name, nothing here makes warning
        generation itself fail; this is a smoke test that a plain method of
        the replacement is forwarded and callable. Injecting a real failure
        into warning generation would require hooking DeprecatedModule's
        internals, which are not exercised here.
        """

        class ProblematicModule:
            def test_func(self):
                return "works"

        problematic_module = ProblematicModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.problematic',
            replacement_module=problematic_module,
            replacement_name='test.problematic_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            result = deprecated.test_func()
            self.assertEqual(result, "works")

    def test_graceful_handling_of_special_attributes(self):
        """Test graceful handling of special Python attributes."""
        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.special',
            replacement_module=mock_module,
            replacement_name='test.special_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Metadata comes from the proxy itself (its own name, not the
            # replacement's) and must not raise. Warnings are ignored here, so
            # this does not check whether these accesses warn.
            self.assertEqual(deprecated.__name__, 'test.special')
            self.assertIsInstance(deprecated.__doc__, str)

            # repr should work
            repr_str = repr(deprecated)
            self.assertIsInstance(repr_str, str)

    def test_multiple_error_conditions_simultaneously(self):
        """Errors raised by forwarded callables and missing attributes both propagate."""

        class MultiErrorModule:
            def func1(self):
                raise RuntimeError("Runtime error in func1")

            # func2 is deliberately not defined (and this class has no __all__),
            # so looking it up through the proxy must raise AttributeError.

        error_module = MultiErrorModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.multi_error',
            replacement_module=error_module,
            replacement_name='test.multi_error_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Exceptions from the forwarded callable are not swallowed.
            with self.assertRaises(RuntimeError):
                deprecated.func1()

            with self.assertRaises(AttributeError):
                _ = deprecated.func2

            with self.assertRaises(AttributeError):
                _ = deprecated.nonexistent
1074
+
1075
+
1076
class TestConcurrentAccessAndThreadSafety(unittest.TestCase):
    """Test concurrent access and thread safety of deprecation system.

    ``warnings.catch_warnings`` replaces and later restores process-global
    state (the filter list and ``showwarning``), and the Python docs state it
    is not thread-safe. Entering it inside several worker threads at once can
    restore the wrong state and leak filters into later tests. Each test
    therefore enters it exactly once, in the main thread, around the whole
    start/join cycle.
    """

    @staticmethod
    def _run_threads(target, args_list):
        """Start one thread per entry of ``args_list`` and wait for all of them."""
        import threading

        threads = [threading.Thread(target=target, args=args) for args in args_list]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def test_concurrent_attribute_access(self):
        """Test that concurrent attribute access works correctly."""
        import time

        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.concurrent',
            replacement_module=mock_module,
            replacement_name='test.concurrent_replacement'
        )

        n_threads = 5
        n_iterations = 10
        results = []
        errors = []

        def access_attributes():
            try:
                # Access different attributes multiple times
                for _ in range(n_iterations):
                    result1 = deprecated.test_function(5)
                    result2 = deprecated.test_variable
                    results.append((result1, result2))
                    time.sleep(0.001)  # Small delay to encourage interleaving
            except Exception as e:  # collected and reported by the main thread
                errors.append(e)

        # Warning state is managed once, in the main thread (see class docstring).
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self._run_threads(access_attributes, [()] * n_threads)

        self.assertEqual(len(errors), 0, f"Errors occurred: {errors}")
        # Without errors, every iteration of every thread must have recorded a result.
        self.assertEqual(len(results), n_threads * n_iterations)

        # All results should be consistent
        for result1, result2 in results:
            self.assertEqual(result1, 10)  # test_function(5) should return 10
            self.assertEqual(result2, 42)  # test_variable should be 42

    def test_thread_safety_of_warning_generation(self):
        """Test that warning generation is thread-safe."""
        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.thread_warnings',
            replacement_module=mock_module,
            replacement_name='test.thread_warnings_replacement'
        )

        n_threads = 3
        accessed = ('test_function', 'test_variable', 'test_class')
        completed = []

        def generate_warnings():
            # Access attributes to generate warnings
            for attr_name in accessed:
                getattr(deprecated, attr_name)
            completed.append(True)

        # Record warnings from all workers in the main thread (see class docstring).
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            self._run_threads(generate_warnings, [()] * n_threads)

        # Every worker must have finished without raising.
        self.assertEqual(len(completed), n_threads)

        # Each distinct attribute warns at least once; extra warnings from
        # threads racing on the first access are tolerated.
        deprecation_warnings = [
            record for record in caught
            if issubclass(record.category, DeprecationWarning)
        ]
        self.assertGreaterEqual(len(deprecation_warnings), len(accessed))

    def test_race_condition_in_attribute_caching(self):
        """Test for race conditions in any internal attribute caching."""
        import threading

        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.race_condition',
            replacement_module=mock_module,
            replacement_name='test.race_condition_replacement'
        )

        n_threads = 4
        n_iterations = 20
        results = {}
        lock = threading.Lock()

        def access_same_attribute(thread_id):
            # Access the same attribute multiple times
            for i in range(n_iterations):
                attr = deprecated.test_function
                result = attr(i)

                with lock:
                    results.setdefault(thread_id, []).append(result)

        # Warning state is managed once, in the main thread (see class docstring).
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self._run_threads(access_same_attribute, [(i,) for i in range(n_threads)])

        # Verify all threads got consistent results
        self.assertEqual(len(results), n_threads)
        for thread_id, thread_results in results.items():
            self.assertEqual(len(thread_results), n_iterations)
            for i, result in enumerate(thread_results):
                self.assertEqual(result, i * 2)  # test_function multiplies by 2
1219
+
1220
+
1221
class TestMemoryUsageAndPerformance(unittest.TestCase):
    """Test memory usage and performance aspects of deprecation system."""

    def test_memory_usage_of_deprecated_modules(self):
        """Many independent proxies can coexist and each still forwards correctly.

        This is a smoke check; it does not measure memory consumption.
        """
        # Create many deprecated modules
        modules = []
        for i in range(100):
            mock_module = MockReplacementModule()
            deprecated = DeprecatedModule(
                deprecated_name=f'test.memory_{i}',
                replacement_module=mock_module,
                replacement_name=f'test.memory_replacement_{i}'
            )
            modules.append(deprecated)

        self.assertEqual(len(modules), 100)

        # Basic functionality should still work
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for i in range(0, 100, 10):  # Test every 10th module
                result = modules[i].test_function(1)
                self.assertEqual(result, 2)

    def test_performance_of_attribute_access(self):
        """Test performance of deprecated module attribute access."""
        import time

        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.performance',
            replacement_module=mock_module,
            replacement_name='test.performance_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()

            for _ in range(1000):
                _ = deprecated.test_function
                _ = deprecated.test_variable
                _ = deprecated.test_class

            elapsed = time.perf_counter() - start_time

            # Should complete reasonably quickly (less than 1 second for 1000 iterations)
            self.assertLess(elapsed, 1.0, f"Attribute access took too long: {elapsed}s")

    def test_warning_performance_impact(self):
        """Test that warning generation doesn't significantly impact performance."""
        import time

        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.warning_performance',
            replacement_module=mock_module,
            replacement_name='test.warning_performance_replacement'
        )

        # Test with warnings enabled
        start_time = time.perf_counter()
        with warnings.catch_warnings():
            warnings.simplefilter("always")

            for _ in range(100):
                _ = deprecated.test_function
                _ = deprecated.test_variable

        with_warnings_time = time.perf_counter() - start_time

        # Test with warnings disabled
        start_time = time.perf_counter()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for _ in range(100):
                _ = deprecated.test_function
                _ = deprecated.test_variable

        without_warnings_time = time.perf_counter() - start_time

        # With warnings should not be dramatically slower (less than 10x)
        if without_warnings_time > 0:
            ratio = with_warnings_time / without_warnings_time
            self.assertLess(ratio, 10.0, f"Warning generation too slow: {ratio}x slower")

    def test_memory_leak_prevention(self):
        """Test that deprecated modules don't cause memory leaks."""
        import gc
        import weakref

        n_modules = 50
        weak_refs = []

        for i in range(n_modules):
            mock_module = MockReplacementModule()
            deprecated = DeprecatedModule(
                deprecated_name=f'test.leak_{i}',
                replacement_module=mock_module,
                replacement_name=f'test.leak_replacement_{i}'
            )

            # Access some attributes to trigger any caching
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                _ = deprecated.test_function

            weak_refs.append(weakref.ref(deprecated))

        # Drop the last strong references held by this frame, then collect.
        del deprecated, mock_module, _
        gc.collect()

        self.assertEqual(len(weak_refs), n_modules)

        # Each earlier proxy became unreachable when the loop rebound its
        # variables, so at least some must have been collected. Requiring
        # *all* would be too strict for GC timing, but requiring none (the
        # previous `len(weak_refs) > 0`) could never detect a leak.
        collected = sum(ref() is None for ref in weak_refs)
        self.assertGreater(
            collected, 0,
            "No DeprecatedModule instance was garbage-collected; possible leak"
        )
1349
+
1350
+
1351
+ class TestDeprecatedAugment(unittest.TestCase):
1352
+ """Test suite for the deprecated brainstate.augment module."""
1353
+
1354
+ def test_augment_module_import(self):
1355
+ """Test that the deprecated augment module can be imported."""
1356
+ with warnings.catch_warnings(record=True) as w:
1357
+ warnings.simplefilter("always")
1358
+ import brainstate
1359
+ # Access an attribute to trigger deprecation warning
1360
+ _ = brainstate.augment.grad
1361
+
1362
+ # Check that a deprecation warning was issued (excluding JAX warnings)
1363
+ relevant_warnings = [
1364
+ warning for warning in w
1365
+ if issubclass(warning.category, DeprecationWarning)
1366
+ and 'brainstate.augment' in str(warning.message)
1367
+ ]
1368
+ # self.assertGreater(len(relevant_warnings), 0)
1369
+
1370
+ def test_augmentation_functions(self):
1371
+ """Test that all augmentation functions are accessible."""
1372
+ import brainstate
1373
+
1374
+ augment_funcs = [
1375
+ 'GradientTransform',
1376
+ 'grad',
1377
+ 'vector_grad',
1378
+ 'hessian',
1379
+ 'jacobian',
1380
+ 'jacrev',
1381
+ 'jacfwd',
1382
+ 'abstract_init',
1383
+ 'vmap',
1384
+ 'pmap',
1385
+ 'map',
1386
+ 'vmap_new_states',
1387
+ 'restore_rngs',
1388
+ ]
1389
+
1390
+ for func_name in augment_funcs:
1391
+ with self.subTest(function=func_name):
1392
+ with warnings.catch_warnings(record=True) as w:
1393
+ warnings.simplefilter("always")
1394
+
1395
+ # Access the function
1396
+ func = getattr(brainstate.augment, func_name)
1397
+ self.assertIsNotNone(func)
1398
+
1399
+ # Check that a deprecation warning was issued
1400
+ deprecation_warnings = [warning for warning in w if
1401
+ issubclass(warning.category, DeprecationWarning)]
1402
+ # Filter out the JAX warning
1403
+ relevant_warnings = [w for w in deprecation_warnings if 'brainstate.augment' in str(w.message)]
1404
+ # self.assertGreater(len(relevant_warnings), 0, f"No deprecation warning for {func_name}")
1405
+
1406
+ def test_gradient_functions(self):
1407
+ """Test gradient-related functions."""
1408
+ with warnings.catch_warnings(record=True):
1409
+ warnings.simplefilter("always")
1410
+ import brainstate
1411
+
1412
+ # Test grad
1413
+ grad = brainstate.augment.grad
1414
+ self.assertIsNotNone(grad)
1415
+
1416
+ # Test vector_grad
1417
+ vector_grad = brainstate.augment.vector_grad
1418
+ self.assertIsNotNone(vector_grad)
1419
+
1420
+ # Test GradientTransform
1421
+ GradientTransform = brainstate.augment.GradientTransform
1422
+ self.assertIsNotNone(GradientTransform)
1423
+
1424
+ def test_grad_function(self):
1425
+ """Test grad function functionality."""
1426
+ with warnings.catch_warnings(record=True):
1427
+ warnings.simplefilter("always")
1428
+ import brainstate
1429
+
1430
+ # Test grad function
1431
+ grad = brainstate.augment.grad
1432
+ self.assertIsNotNone(grad)
1433
+ # Just check that it's callable
1434
+ self.assertTrue(callable(grad))
1435
+
1436
+ def test_jacobian_functions(self):
1437
+ """Test Jacobian-related functions."""
1438
+ with warnings.catch_warnings(record=True):
1439
+ warnings.simplefilter("always")
1440
+ import brainstate
1441
+
1442
+ # Test jacobian
1443
+ jacobian = brainstate.augment.jacobian
1444
+ self.assertIsNotNone(jacobian)
1445
+
1446
+ # Test jacrev
1447
+ jacrev = brainstate.augment.jacrev
1448
+ self.assertIsNotNone(jacrev)
1449
+
1450
+ # Test jacfwd
1451
+ jacfwd = brainstate.augment.jacfwd
1452
+ self.assertIsNotNone(jacfwd)
1453
+
1454
+ def test_hessian_function(self):
1455
+ """Test Hessian function."""
1456
+ with warnings.catch_warnings(record=True):
1457
+ warnings.simplefilter("always")
1458
+ import brainstate
1459
+
1460
+ # Test hessian
1461
+ hessian = brainstate.augment.hessian
1462
+ self.assertIsNotNone(hessian)
1463
+ # Just check that it's callable
1464
+ self.assertTrue(callable(hessian))
1465
+
1466
+ def test_mapping_functions(self):
1467
+ """Test mapping-related functions."""
1468
+ with warnings.catch_warnings(record=True):
1469
+ warnings.simplefilter("always")
1470
+ import brainstate
1471
+
1472
+ # Test vmap
1473
+ vmap = brainstate.augment.vmap
1474
+ self.assertIsNotNone(vmap)
1475
+
1476
+ # Test pmap
1477
+ pmap = brainstate.augment.pmap
1478
+ self.assertIsNotNone(pmap)
1479
+
1480
+ # Test map
1481
+ map_func = brainstate.augment.map
1482
+ self.assertIsNotNone(map_func)
1483
+
1484
+ def test_vmap_function(self):
1485
+ """Test vmap function functionality."""
1486
+ with warnings.catch_warnings(record=True):
1487
+ warnings.simplefilter("always")
1488
+ import brainstate
1489
+
1490
+ # Test vmap
1491
+ vmap = brainstate.augment.vmap
1492
+ self.assertIsNotNone(vmap)
1493
+ # Just check that it's callable
1494
+ self.assertTrue(callable(vmap))
1495
+
1496
+ def test_vmap_new_states(self):
1497
+ """Test vmap_new_states function."""
1498
+ with warnings.catch_warnings(record=True):
1499
+ warnings.simplefilter("always")
1500
+ import brainstate
1501
+
1502
+ # Test vmap_new_states
1503
+ vmap_new_states = brainstate.augment.vmap_new_states
1504
+ self.assertIsNotNone(vmap_new_states)
1505
+
1506
+ def test_abstract_init(self):
1507
+ """Test abstract_init function."""
1508
+ with warnings.catch_warnings(record=True):
1509
+ warnings.simplefilter("always")
1510
+ import brainstate
1511
+
1512
+ # Test abstract_init
1513
+ abstract_init = brainstate.augment.abstract_init
1514
+ self.assertIsNotNone(abstract_init)
1515
+
1516
+ def test_restore_rngs(self):
1517
+ """Test restore_rngs function."""
1518
+ with warnings.catch_warnings(record=True):
1519
+ warnings.simplefilter("always")
1520
+ import brainstate
1521
+
1522
+ # Test restore_rngs
1523
+ restore_rngs = brainstate.augment.restore_rngs
1524
+ self.assertIsNotNone(restore_rngs)
1525
+
1526
+ def test_module_attributes(self):
1527
+ """Test module-level attributes."""
1528
+ with warnings.catch_warnings(record=True):
1529
+ warnings.simplefilter("always")
1530
+ import brainstate
1531
+
1532
+ # Test __name__ attribute
1533
+ self.assertEqual(brainstate.augment.__name__, 'brainstate.augment')
1534
+
1535
+ # Test __doc__ attribute
1536
+ self.assertIn('DEPRECATED', brainstate.augment.__doc__)
1537
+
1538
+ # Test __all__ attribute
1539
+ self.assertIsInstance(brainstate.augment.__all__, list)
1540
+ self.assertIn('grad', brainstate.augment.__all__)
1541
+ self.assertIn('vmap', brainstate.augment.__all__)
1542
+
1543
def test_dir_method(self):
    """Test that dir() lists the public API plus standard module attributes.

    Any deprecation warnings are recorded only to keep test output clean.
    No assertion is made about them: the original warning check was
    commented out.
    """
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        import brainstate

        attrs = dir(brainstate.augment)

    expected_attrs = [
        'grad', 'vmap', 'jacobian', 'hessian',
        '__name__', '__doc__', '__all__',
    ]
    for attr in expected_attrs:
        # subTest reports every missing attribute, not just the first one.
        with self.subTest(attr=attr):
            self.assertIn(attr, attrs)
1561
+
1562
def test_invalid_attribute_access(self):
    """Unknown names must raise an AttributeError that mentions name and module."""
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        import brainstate

        with self.assertRaises(AttributeError) as ctx:
            _ = brainstate.augment.NonExistentFunction

        message = str(ctx.exception)
        for fragment in ('NonExistentFunction', 'brainstate.augment'):
            self.assertIn(fragment, message)
1573
+
1574
def test_repr_method(self):
    """repr() of the shim names its type, itself and its replacement module."""
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        import brainstate

        text = repr(brainstate.augment)
        for fragment in ('DeprecatedModule', 'brainstate.augment', 'brainstate.transform'):
            self.assertIn(fragment, text)
1584
+
1585
def test_gradient_transform_class(self):
    """``GradientTransform`` must still be reachable via ``brainstate.augment``."""
    with warnings.catch_warnings(record=True):
        # Record (and discard) any deprecation warnings raised by the shim.
        warnings.simplefilter("always")
        import brainstate

        self.assertIsNotNone(brainstate.augment.GradientTransform)
1594
+
1595
+
1596
class TestDeprecatedCompile(unittest.TestCase):
    """Test suite for the deprecated brainstate.compile module.

    Deprecated names are accessed inside ``warnings.catch_warnings(record=True)``.
    Warnings from the compatibility shim are therefore captured instead of
    cluttering the test output.

    NOTE(review): the original assertions that a ``DeprecationWarning``
    mentioning ``brainstate.compile`` is emitted were commented out. These
    tests therefore only check that the deprecated names still resolve.
    Before re-enabling such checks, confirm that the warning is emitted on
    every access and not only the first. Otherwise the tests become
    order-dependent.
    """

    def test_compile_module_import(self):
        """Test that the deprecated compile module can be imported and accessed."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            # Attribute access is routed through the deprecated-module shim.
            self.assertIsNotNone(brainstate.compile.jit)

    def test_compilation_functions(self):
        """Test that all compilation functions are accessible."""
        import brainstate

        compile_funcs = [
            'checkpoint',
            'remat',
            'cond',
            'switch',
            'ifelse',
            'jit_error_if',
            'jit',
            'scan',
            'checkpointed_scan',
            'for_loop',
            'checkpointed_for_loop',
            'while_loop',
            'bounded_while_loop',
            'StatefulFunction',
            'make_jaxpr',
            'ProgressBar',
        ]

        for func_name in compile_funcs:
            with self.subTest(function=func_name):
                with warnings.catch_warnings(record=True):
                    warnings.simplefilter("always")
                    func = getattr(brainstate.compile, func_name)
                    self.assertIsNotNone(func)

    def test_jit_function(self):
        """Test JIT compilation function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            jit = brainstate.compile.jit
            self.assertIsNotNone(jit)
            self.assertTrue(callable(jit))

    def test_cond_function(self):
        """Test conditional function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            cond = brainstate.compile.cond
            self.assertIsNotNone(cond)
            self.assertTrue(callable(cond))

    def test_ifelse_function(self):
        """Test ifelse function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.ifelse)

    def test_switch_function(self):
        """Test switch function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.switch)

    def test_loop_functions(self):
        """Test loop-related functions."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.for_loop)
            self.assertIsNotNone(brainstate.compile.while_loop)
            self.assertIsNotNone(brainstate.compile.bounded_while_loop)

    def test_scan_functions(self):
        """Test scan-related functions."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.scan)
            self.assertIsNotNone(brainstate.compile.checkpointed_scan)

    def test_checkpoint_functions(self):
        """Test checkpoint-related functions."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.checkpoint)
            # remat = rematerialization
            self.assertIsNotNone(brainstate.compile.remat)

    def test_jit_error_if(self):
        """Test jit_error_if function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.jit_error_if)

    def test_stateful_function(self):
        """Test StatefulFunction class."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.StatefulFunction)

    def test_make_jaxpr(self):
        """Test make_jaxpr function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.make_jaxpr)

    def test_progress_bar(self):
        """Test ProgressBar class."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.ProgressBar)

    def test_checkpointed_for_loop(self):
        """Test checkpointed_for_loop function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.compile.checkpointed_for_loop)

    def test_module_attributes(self):
        """Test module-level attributes."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertEqual(brainstate.compile.__name__, 'brainstate.compile')
            self.assertIn('DEPRECATED', brainstate.compile.__doc__)
            self.assertIsInstance(brainstate.compile.__all__, list)
            self.assertIn('jit', brainstate.compile.__all__)
            self.assertIn('cond', brainstate.compile.__all__)

    def test_dir_method(self):
        """Test that dir() returns appropriate attributes."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            attrs = dir(brainstate.compile)

        expected_attrs = [
            'jit', 'cond', 'scan', 'for_loop', 'while_loop',
            '__name__', '__doc__', '__all__',
        ]
        for attr in expected_attrs:
            with self.subTest(attr=attr):
                self.assertIn(attr, attrs)

    def test_invalid_attribute_access(self):
        """Test that accessing invalid attributes raises appropriate errors."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            with self.assertRaises(AttributeError) as context:
                _ = brainstate.compile.NonExistentFunction

            self.assertIn('NonExistentFunction', str(context.exception))
            self.assertIn('brainstate.compile', str(context.exception))

    def test_repr_method(self):
        """Test the __repr__ method of the deprecated module."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            repr_str = repr(brainstate.compile)
            self.assertIn('DeprecatedModule', repr_str)
            self.assertIn('brainstate.compile', repr_str)
            self.assertIn('brainstate.transform', repr_str)
1852
+
1853
+
1854
class TestDeprecatedFunctional(unittest.TestCase):
    """Test suite for the deprecated brainstate.functional module.

    Deprecated names are accessed inside ``warnings.catch_warnings(record=True)``.
    Warnings from the compatibility shim are therefore captured instead of
    cluttering the test output.

    NOTE(review): the original assertions that a ``DeprecationWarning``
    mentioning ``brainstate.functional`` is emitted were commented out. These
    tests therefore check resolution and numerical behaviour only. Before
    re-enabling such checks, confirm that the warning is emitted on every
    access and not only the first.
    """

    def test_functional_module_import(self):
        """Test that the deprecated functional module can be imported and accessed."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            # Attribute access is routed through the deprecated-module shim.
            self.assertIsNotNone(brainstate.functional.relu)

    def test_activation_functions(self):
        """Test that all activation functions are accessible."""
        import brainstate

        activations = [
            'tanh',
            'relu',
            'squareplus',
            'softplus',
            'soft_sign',
            'sigmoid',
            'silu',
            'swish',
            'log_sigmoid',
            'elu',
            'leaky_relu',
            'hard_tanh',
            'celu',
            'selu',
            'gelu',
            'glu',
            'logsumexp',
            'log_softmax',
            'softmax',
            'standardize',
        ]

        for activation_name in activations:
            with self.subTest(activation=activation_name):
                with warnings.catch_warnings(record=True):
                    warnings.simplefilter("always")
                    activation = getattr(brainstate.functional, activation_name)
                    self.assertIsNotNone(activation)

    def test_activation_functionality(self):
        """Test that deprecated activation functions still work correctly."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])

            result = brainstate.functional.relu(x)
            self.assertTrue(jnp.allclose(result, jnp.maximum(0, x)))

            result = brainstate.functional.sigmoid(x)
            self.assertTrue(jnp.allclose(result, 1 / (1 + jnp.exp(-x))))

            result = brainstate.functional.tanh(x)
            self.assertTrue(jnp.allclose(result, jnp.tanh(x)))

            # softmax output is a probability vector; convert the 0-d JAX
            # array to a Python float before handing it to unittest.
            result = brainstate.functional.softmax(x)
            self.assertAlmostEqual(float(jnp.sum(result)), 1.0, places=5)

    def test_weight_standardization(self):
        """Test weight standardization function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.functional.weight_standardization)

    def test_clip_grad_norm(self):
        """Test clip_grad_norm function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertIsNotNone(brainstate.functional.clip_grad_norm)

    def test_leaky_relu(self):
        """Test leaky_relu with custom negative slope."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
            result = brainstate.functional.leaky_relu(x, negative_slope=0.01)
            # Non-negative inputs pass through unchanged.
            self.assertTrue(jnp.allclose(result[x >= 0], x[x >= 0]))
            # Negative inputs are scaled by the slope.
            self.assertTrue(jnp.allclose(result[x < 0], 0.01 * x[x < 0]))

    def test_elu_activation(self):
        """Test ELU activation function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
            result = brainstate.functional.elu(x, alpha=1.0)
            # Non-negative inputs pass through unchanged.
            self.assertTrue(jnp.allclose(result[x >= 0], x[x >= 0]))
            # Negative inputs follow alpha * (exp(x) - 1).
            expected_neg = 1.0 * (jnp.exp(x[x < 0]) - 1)
            self.assertTrue(jnp.allclose(result[x < 0], expected_neg))

    def test_gelu_activation(self):
        """Test GELU activation function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
            result = brainstate.functional.gelu(x)
            self.assertEqual(result.shape, x.shape)
            # GELU(0) == 0; x[2] is the zero entry.
            self.assertAlmostEqual(float(result[2]), 0.0, places=5)

    def test_softplus_activation(self):
        """Test Softplus activation function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
            result = brainstate.functional.softplus(x)
            self.assertTrue(jnp.allclose(result, jnp.log(1 + jnp.exp(x))))

    def test_log_softmax(self):
        """Test log_softmax function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            x = jnp.array([1.0, 2.0, 3.0])
            result = brainstate.functional.log_softmax(x)
            # exp(log_softmax) must be a probability vector.
            self.assertAlmostEqual(float(jnp.sum(jnp.exp(result))), 1.0, places=5)

    def test_silu_swish(self):
        """Test SiLU (Swish) activation function."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
            result_silu = brainstate.functional.silu(x)
            result_swish = brainstate.functional.swish(x)

            # swish is an alias of silu.
            self.assertTrue(jnp.allclose(result_silu, result_swish))
            # silu(x) == x * sigmoid(x)
            expected = x * brainstate.functional.sigmoid(x)
            self.assertTrue(jnp.allclose(result_silu, expected))

    def test_module_attributes(self):
        """Test module-level attributes."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertEqual(brainstate.functional.__name__, 'brainstate.functional')
            self.assertIn('DEPRECATED', brainstate.functional.__doc__)
            self.assertIsInstance(brainstate.functional.__all__, list)
            self.assertIn('relu', brainstate.functional.__all__)
            self.assertIn('sigmoid', brainstate.functional.__all__)

    def test_dir_method(self):
        """Test that dir() returns appropriate attributes."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            attrs = dir(brainstate.functional)

        expected_attrs = [
            'relu', 'sigmoid', 'tanh', 'softmax',
            '__name__', '__doc__', '__all__',
        ]
        for attr in expected_attrs:
            with self.subTest(attr=attr):
                self.assertIn(attr, attrs)

    def test_invalid_attribute_access(self):
        """Test that accessing invalid attributes raises appropriate errors."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            with self.assertRaises(AttributeError) as context:
                _ = brainstate.functional.NonExistentFunction

            self.assertIn('NonExistentFunction', str(context.exception))
            self.assertIn('brainstate.functional', str(context.exception))

    def test_repr_method(self):
        """Test the __repr__ method of the deprecated module."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            repr_str = repr(brainstate.functional)
            self.assertIn('DeprecatedModule', repr_str)
            self.assertIn('brainstate.functional', repr_str)
            self.assertIn('brainstate.nn', repr_str)
2108
+
2109
+
2110
class TestDeprecatedInit(unittest.TestCase):
    """Test suite for the deprecated brainstate.init module.

    Unlike the other deprecated shims tested here, ``brainstate.init`` reports
    ``braintools.init`` as its ``__name__`` (see ``test_module_attributes``).
    """

    def test_init_module_import(self):
        """Test that accessing the deprecated init module emits a DeprecationWarning."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            import brainstate
            # Access an attribute to trigger the deprecation warning.
            _ = brainstate.init.Constant

        self.assertGreater(len(w), 0)
        self.assertTrue(any(issubclass(warning.category, DeprecationWarning) for warning in w))

    def test_param_function(self):
        """Test the deprecated param function."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            import brainstate

            param = brainstate.init.param
            self.assertIsNotNone(param)

        self.assertTrue(any(issubclass(warning.category, DeprecationWarning) for warning in w))

    def test_initializers(self):
        """Test that all deprecated initializers are accessible.

        NOTE(review): the per-initializer warning assertion was commented out
        originally, so only resolution is checked here.
        """
        import brainstate

        initializers = [
            'Constant',
            'Identity',
            'Normal',
            'TruncatedNormal',
            'Uniform',
            'KaimingUniform',
            'KaimingNormal',
            'XavierUniform',
            'XavierNormal',
            'LecunUniform',
            'LecunNormal',
            'Orthogonal',
            'DeltaOrthogonal',
        ]

        for init_name in initializers:
            with self.subTest(initializer=init_name):
                with warnings.catch_warnings(record=True):
                    warnings.simplefilter("always")
                    initializer = getattr(brainstate.init, init_name)
                    self.assertIsNotNone(initializer)

    def test_initializer_functionality(self):
        """Test that deprecated initializers still work correctly."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            const_init = brainstate.init.Constant(0.5)
            result = const_init((2, 3))
            self.assertEqual(result.shape, (2, 3))
            self.assertTrue(jnp.allclose(result, 0.5))

            normal_init = brainstate.init.Normal(mean=0.0, std=1.0)
            result = normal_init((10, 10))
            self.assertEqual(result.shape, (10, 10))

            uniform_init = brainstate.init.Uniform(low=-1.0, high=1.0)
            result = uniform_init((5, 5))
            self.assertEqual(result.shape, (5, 5))
            self.assertTrue(jnp.all(result >= -1.0))
            self.assertTrue(jnp.all(result <= 1.0))

    def test_kaiming_initializers(self):
        """Test Kaiming (He) initialization methods."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            result = brainstate.init.KaimingUniform()((10, 10))
            self.assertEqual(result.shape, (10, 10))

            result = brainstate.init.KaimingNormal()((10, 10))
            self.assertEqual(result.shape, (10, 10))

    def test_xavier_initializers(self):
        """Test Xavier (Glorot) initialization methods."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            result = brainstate.init.XavierUniform()((10, 10))
            self.assertEqual(result.shape, (10, 10))

            result = brainstate.init.XavierNormal()((10, 10))
            self.assertEqual(result.shape, (10, 10))

    def test_lecun_initializers(self):
        """Test LeCun initialization methods."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            result = brainstate.init.LecunUniform()((10, 10))
            self.assertEqual(result.shape, (10, 10))

            result = brainstate.init.LecunNormal()((10, 10))
            self.assertEqual(result.shape, (10, 10))

    def test_orthogonal_initializers(self):
        """Test Orthogonal initialization methods."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            result = brainstate.init.Orthogonal()((10, 10))
            self.assertEqual(result.shape, (10, 10))

            # DeltaOrthogonal requires a 3D shape.
            result = brainstate.init.DeltaOrthogonal()((3, 3, 3))
            self.assertEqual(result.shape, (3, 3, 3))

    def test_identity_initializer(self):
        """Test Identity initializer."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            result = brainstate.init.Identity()((5, 5))
            self.assertEqual(result.shape, (5, 5))
            self.assertTrue(jnp.allclose(result, jnp.eye(5)))

    def test_truncated_normal_initializer(self):
        """Test TruncatedNormal initializer."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            truncated_normal = brainstate.init.TruncatedNormal(mean=0.0, std=1.0)
            result = truncated_normal((10, 10))
            self.assertEqual(result.shape, (10, 10))

    def test_module_attributes(self):
        """Test module-level attributes."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate

            self.assertEqual(brainstate.init.__name__, 'braintools.init')
            self.assertIsInstance(brainstate.init.__all__, list)
            self.assertIn('Constant', brainstate.init.__all__)
            self.assertIn('Normal', brainstate.init.__all__)

    def test_dir_method(self):
        """Test that dir() returns appropriate attributes and warns about deprecation."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            import brainstate

            attrs = dir(brainstate.init)

        expected_attrs = [
            'Constant', 'Normal', 'Uniform', 'XavierNormal',
            '__name__', '__doc__', '__all__',
        ]
        for attr in expected_attrs:
            with self.subTest(attr=attr):
                self.assertIn(attr, attrs)

        self.assertTrue(any(issubclass(warning.category, DeprecationWarning) for warning in w))