brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2319 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
|
19
|
+
import unittest
|
20
|
+
import warnings
|
21
|
+
|
22
|
+
import jax.numpy as jnp
|
23
|
+
import numpy as np
|
24
|
+
|
25
|
+
import brainstate
|
26
|
+
from brainstate._deprecation import DeprecatedModule, create_deprecated_module_proxy
|
27
|
+
|
28
|
+
|
29
|
+
class TestDeprecatedAugmentModule(unittest.TestCase):
    """Test the deprecated brainstate.augment module."""

    def setUp(self):
        """Reset warning filters before each test, restoring them afterwards.

        ``warnings.resetwarnings()`` on its own discards every filter in the
        process (including ``-W`` options) for the remainder of the run.
        Entering a ``catch_warnings`` context first, and exiting it through
        ``addCleanup``, confines the reset to this single test.
        """
        catcher = warnings.catch_warnings()
        catcher.__enter__()
        self.addCleanup(catcher.__exit__, None, None, None)
        warnings.resetwarnings()

    def test_augment_module_attributes(self):
        """Test that augment module has correct attributes."""
        # Test module attributes
        self.assertEqual(brainstate.augment.__name__, 'brainstate.augment')
        self.assertIn('deprecated', brainstate.augment.__doc__.lower())
        self.assertTrue(hasattr(brainstate.augment, '__all__'))

        # Test repr
        repr_str = repr(brainstate.augment)
        self.assertIn('DeprecatedModule', repr_str)
        self.assertIn('brainstate.augment', repr_str)
        self.assertIn('brainstate.transform', repr_str)

    def test_augment_scoped_apis(self):
        """Test that augment module only exposes scoped APIs."""
        # Check that expected APIs are available
        expected_apis = [
            'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian',
            'jacrev', 'jacfwd', 'abstract_init', 'vmap', 'pmap', 'map',
            'vmap_new_states', 'restore_rngs',
        ]

        for api in expected_apis:
            with self.subTest(api=api):
                self.assertIn(api, brainstate.augment.__all__)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    self.assertTrue(hasattr(brainstate.augment, api),
                                    f"API '{api}' should be available in augment module")

        # Check that __all__ contains only expected APIs
        self.assertEqual(set(brainstate.augment.__all__), set(expected_apis))

    def test_augment_deprecation_warnings(self):
        """Test that augment module shows deprecation warnings."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Access different attributes
            _ = brainstate.augment.grad
            _ = brainstate.augment.vmap
            _ = brainstate.augment.vector_grad

            # No minimum warning count is asserted: ``brainstate.augment`` is a
            # shared proxy that presumably warns only once per attribute, so an
            # earlier test may already have triggered these warnings. Only the
            # content of whatever was recorded is checked.
            for warning in w:
                self.assertEqual(warning.category, DeprecationWarning)
                msg = str(warning.message)
                self.assertIn('brainstate.augment', msg)
                self.assertIn('deprecated', msg)
                self.assertIn('brainstate.transform', msg)

    def test_augment_no_duplicate_warnings(self):
        """Test that repeated access doesn't generate duplicate warnings."""
        with warnings.catch_warnings(record=True) as w:
            # "always" disables Python's own warning de-duplication, so any
            # suppression of repeats must come from the deprecation proxy.
            warnings.simplefilter("always")

            # Access the same attribute multiple times
            _ = brainstate.augment.grad
            _ = brainstate.augment.grad
            _ = brainstate.augment.grad

        augment_warnings = [
            warning for warning in w
            if issubclass(warning.category, DeprecationWarning)
            and 'brainstate.augment' in str(warning.message)
        ]
        # At most one warning; zero is possible when an earlier test already
        # triggered the warning for ``grad`` on the shared module proxy.
        self.assertLessEqual(len(augment_warnings), 1)

    def test_augment_functionality_forwarding(self):
        """Test that augment module forwards functionality correctly."""
        # Test that functions are properly forwarded
        self.assertTrue(callable(brainstate.augment.grad))
        self.assertTrue(callable(brainstate.augment.vmap))
        self.assertTrue(callable(brainstate.augment.vector_grad))

        # Test that they are the same as transform module
        self.assertIs(brainstate.augment.grad, brainstate.transform.grad)
        self.assertIs(brainstate.augment.vmap, brainstate.transform.vmap)

    def test_augment_grad_functionality(self):
        """Test that grad function works through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # Ignore deprecation warnings for this test

            # Create a simple state and function
            state = brainstate.State(jnp.array([1.0, 2.0]))

            def loss_fn():
                return jnp.sum(state.value ** 2)

            # Test grad function
            grad_fn = brainstate.augment.grad(loss_fn, state)
            grads = grad_fn()

            # d/dx sum(x**2) == 2 * x
            expected = 2 * state.value
            np.testing.assert_array_almost_equal(grads, expected)

    def test_augment_dir_functionality(self):
        """Test that dir() works on augment module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            attrs = dir(brainstate.augment)

            # Should contain expected attributes
            self.assertIn('grad', attrs)
            self.assertIn('vmap', attrs)
            self.assertIn('vector_grad', attrs)

    def test_augment_missing_attribute_error(self):
        """Test that accessing non-existent attributes raises appropriate error."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            with self.assertRaises(AttributeError) as context:
                _ = brainstate.augment.nonexistent_function

            error_msg = str(context.exception)
            self.assertIn('brainstate.augment', error_msg)
            self.assertIn('nonexistent_function', error_msg)
            self.assertIn('brainstate.transform', error_msg)


class TestDeprecatedCompileModule(unittest.TestCase):
    """Test the deprecated brainstate.compile module."""

    def setUp(self):
        """Reset warning filters before each test, restoring them afterwards.

        ``warnings.resetwarnings()`` on its own discards every filter in the
        process (including ``-W`` options) for the remainder of the run.
        Entering a ``catch_warnings`` context first, and exiting it through
        ``addCleanup``, confines the reset to this single test.
        """
        catcher = warnings.catch_warnings()
        catcher.__enter__()
        self.addCleanup(catcher.__exit__, None, None, None)
        warnings.resetwarnings()

    def test_compile_module_attributes(self):
        """Test that compile module has correct attributes."""
        self.assertEqual(brainstate.compile.__name__, 'brainstate.compile')
        self.assertIn('deprecated', brainstate.compile.__doc__.lower())
        self.assertTrue(hasattr(brainstate.compile, '__all__'))

    def test_compile_scoped_apis(self):
        """Test that compile module only exposes scoped APIs."""
        expected_apis = [
            'checkpoint', 'remat', 'cond', 'switch', 'ifelse', 'jit_error_if',
            'jit', 'scan', 'checkpointed_scan', 'for_loop', 'checkpointed_for_loop',
            'while_loop', 'bounded_while_loop', 'StatefulFunction', 'make_jaxpr',
            'ProgressBar',
        ]

        for api in expected_apis:
            with self.subTest(api=api):
                self.assertIn(api, brainstate.compile.__all__)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    self.assertTrue(hasattr(brainstate.compile, api),
                                    f"API '{api}' should be available in compile module")

        # Check that __all__ contains only expected APIs
        self.assertEqual(set(brainstate.compile.__all__), set(expected_apis))

    def test_compile_deprecation_warnings(self):
        """Test that compile module shows deprecation warnings."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Access different attributes
            _ = brainstate.compile.jit
            _ = brainstate.compile.for_loop
            _ = brainstate.compile.while_loop

            # No minimum warning count is asserted: ``brainstate.compile`` is a
            # shared proxy that presumably warns only once per attribute, so an
            # earlier test may already have triggered these warnings. Only the
            # content of whatever was recorded is checked.
            for warning in w:
                self.assertEqual(warning.category, DeprecationWarning)
                msg = str(warning.message)
                self.assertIn('brainstate.compile', msg)
                self.assertIn('brainstate.transform', msg)

    def test_compile_functionality_forwarding(self):
        """Test that compile module forwards functionality correctly."""
        # Test that functions are properly forwarded
        self.assertTrue(callable(brainstate.compile.jit))
        self.assertTrue(callable(brainstate.compile.for_loop))
        self.assertTrue(callable(brainstate.compile.while_loop))

        # Test that they are the same as transform module
        self.assertIs(brainstate.compile.jit, brainstate.transform.jit)
        self.assertIs(brainstate.compile.for_loop, brainstate.transform.for_loop)

    def test_compile_jit_functionality(self):
        """Test that jit function works through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            state = brainstate.State(5.0)

            @brainstate.compile.jit
            def add_one():
                state.value += 1.0
                return state.value

            result = add_one()
            self.assertEqual(result, 6.0)
            self.assertEqual(state.value, 6.0)

    def test_compile_for_loop_functionality(self):
        """Test that for_loop function works through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            counter = brainstate.State(0.0)

            def body(i):
                # The loop index is unused; each iteration just increments.
                counter.value += 1.0

            brainstate.compile.for_loop(body, jnp.arange(5))
            self.assertEqual(counter.value, 5.0)


class TestDeprecatedFunctionalModule(unittest.TestCase):
    """Test the deprecated brainstate.functional module."""

    def setUp(self):
        """Reset warning filters before each test, restoring them afterwards.

        ``warnings.resetwarnings()`` on its own discards every filter in the
        process (including ``-W`` options) for the remainder of the run.
        Entering a ``catch_warnings`` context first, and exiting it through
        ``addCleanup``, confines the reset to this single test.
        """
        catcher = warnings.catch_warnings()
        catcher.__enter__()
        self.addCleanup(catcher.__exit__, None, None, None)
        warnings.resetwarnings()

    def test_functional_module_attributes(self):
        """Test that functional module has correct attributes."""
        self.assertEqual(brainstate.functional.__name__, 'brainstate.functional')
        self.assertIn('deprecated', brainstate.functional.__doc__.lower())
        self.assertTrue(hasattr(brainstate.functional, '__all__'))

    def test_functional_scoped_apis(self):
        """Test that functional module exposes the expected scoped APIs."""
        expected_apis = [
            'weight_standardization', 'clip_grad_norm',
            # Activation functions
            'tanh', 'relu', 'squareplus', 'softplus', 'soft_sign', 'sigmoid',
            'silu', 'swish', 'log_sigmoid', 'elu', 'leaky_relu', 'hard_tanh',
            'celu', 'selu', 'gelu', 'glu', 'logsumexp', 'log_softmax',
            'softmax', 'standardize',
        ]

        for api in expected_apis:
            with self.subTest(api=api):
                self.assertIn(api, brainstate.functional.__all__)
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    self.assertTrue(hasattr(brainstate.functional, api),
                                    f"API '{api}' should be available in functional module")

        # Only "expected_apis is a subset of __all__" is checked (above); exact
        # equality with __all__ is not asserted for this module.

    def test_functional_deprecation_warnings(self):
        """Test that functional module shows deprecation warnings."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # Access different attributes
            _ = brainstate.functional.relu
            _ = brainstate.functional.sigmoid
            _ = brainstate.functional.tanh

            # No minimum warning count is asserted: ``brainstate.functional`` is
            # a shared proxy that presumably warns only once per attribute, so
            # an earlier test may already have triggered these warnings. Only
            # the content of whatever was recorded is checked.
            for warning in w:
                self.assertEqual(warning.category, DeprecationWarning)
                msg = str(warning.message)
                self.assertIn('brainstate.functional', msg)
                self.assertIn('brainstate.nn', msg)

    def test_functional_functionality_forwarding(self):
        """Test that functional module forwards functionality correctly."""
        # Test that functions are properly forwarded
        self.assertTrue(callable(brainstate.functional.relu))
        self.assertTrue(callable(brainstate.functional.sigmoid))
        self.assertTrue(callable(brainstate.functional.tanh))

    def test_functional_activation_functions(self):
        """Test that activation functions work through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Test relu
            x = jnp.array([-1.0, 0.0, 1.0])
            result = brainstate.functional.relu(x)
            expected = jnp.array([0.0, 0.0, 1.0])
            np.testing.assert_array_almost_equal(result, expected)

            # Test sigmoid
            x = jnp.array([0.0])
            result = brainstate.functional.sigmoid(x)
            expected = jnp.array([0.5])
            np.testing.assert_array_almost_equal(result, expected, decimal=5)

            # Test tanh
            x = jnp.array([0.0])
            result = brainstate.functional.tanh(x)
            expected = jnp.array([0.0])
            np.testing.assert_array_almost_equal(result, expected)

    def test_functional_weight_standardization(self):
        """Test that weight_standardization works through deprecated module."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Create a simple weight matrix
            weights = jnp.ones((3, 3))

            # ``weight_standardization`` is part of the functional scope (see
            # test_functional_scoped_apis), so its absence is a failure rather
            # than a reason to silently skip the check.
            self.assertTrue(hasattr(brainstate.functional, 'weight_standardization'))
            standardized = brainstate.functional.weight_standardization(weights)
            self.assertEqual(standardized.shape, weights.shape)


class TestDeprecatedModulesIntegration(unittest.TestCase):
    """Integration tests for all deprecated modules."""

    def test_all_deprecated_modules_in_brainstate(self):
        """Test that all deprecated modules are available in brainstate."""
        self.assertTrue(hasattr(brainstate, 'augment'))
        self.assertTrue(hasattr(brainstate, 'compile'))
        self.assertTrue(hasattr(brainstate, 'functional'))

    def test_deprecated_modules_in_all(self):
        """Test that deprecated modules are in __all__."""
        self.assertIn('augment', brainstate.__all__)
        self.assertIn('compile', brainstate.__all__)
        self.assertIn('functional', brainstate.__all__)

    def test_mixed_usage_compatibility(self):
        """Test that users can mix deprecated and new modules."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Create a state
            state = brainstate.State(jnp.array([1.0, 2.0]))

            def loss_fn():
                x = brainstate.functional.relu(state.value)  # deprecated
                return jnp.sum(x ** 2)

            # Use deprecated augment with new transform
            grad_fn = brainstate.augment.grad(loss_fn, state)  # deprecated
            grads = grad_fn()

            # Should work correctly
            self.assertIsInstance(grads, jnp.ndarray)
            self.assertEqual(grads.shape, (2,))
            # Both inputs are positive, so relu is the identity and the
            # gradient of sum(x**2) is 2 * x.
            np.testing.assert_array_almost_equal(grads, jnp.array([2.0, 4.0]))

    def test_warning_stacklevel(self):
        """Test that warnings point to user code, not internal code."""
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # This should generate a warning pointing to this line
            _ = brainstate.augment.grad

        augment_warnings = [
            warning for warning in w
            if issubclass(warning.category, DeprecationWarning)
            and 'brainstate.augment' in str(warning.message)
        ]
        # The shared proxy may already have warned about ``grad`` in an earlier
        # test, in which case nothing was recorded and there is no location to
        # inspect; otherwise the warning must be attributed to this file.
        if augment_warnings:
            self.assertIn('_deprecation_test.py', augment_warnings[0].filename)


class TestScopedAPIRestrictions(unittest.TestCase):
    """Test that scoped APIs properly restrict access to non-scoped functions."""

    def test_augment_blocks_non_scoped_apis(self):
        """Test that augment module blocks access to APIs not in its scope."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # These should work (scoped APIs)
            self.assertTrue(hasattr(brainstate.augment, 'grad'))
            self.assertTrue(hasattr(brainstate.augment, 'vmap'))

            # ``jit`` exists on the replacement module but is outside the
            # augment scope, so the proxy must not forward it.
            self.assertTrue(hasattr(brainstate.transform, 'jit'))
            self.assertNotIn('jit', brainstate.augment.__all__)
            with self.assertRaises(AttributeError):
                _ = brainstate.augment.jit

            # A name that exists nowhere yields the descriptive error message.
            with self.assertRaises(AttributeError) as context:
                _ = brainstate.augment.nonexistent_function
            self.assertIn('Available attributes:', str(context.exception))
            self.assertIn('brainstate.augment', str(context.exception))

    def test_compile_blocks_non_scoped_apis(self):
        """Test that compile module blocks access to APIs not in its scope."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # These should work (scoped APIs)
            self.assertTrue(hasattr(brainstate.compile, 'jit'))
            self.assertTrue(hasattr(brainstate.compile, 'for_loop'))

            # ``grad`` exists on the replacement module but is outside the
            # compile scope, so the proxy must not forward it.
            self.assertTrue(hasattr(brainstate.transform, 'grad'))
            self.assertNotIn('grad', brainstate.compile.__all__)
            with self.assertRaises(AttributeError):
                _ = brainstate.compile.grad

            # A name that exists nowhere yields the descriptive error message.
            with self.assertRaises(AttributeError) as context:
                _ = brainstate.compile.nonexistent_function
            self.assertIn('Available attributes:', str(context.exception))

    def test_functional_blocks_non_scoped_apis(self):
        """Test that functional module blocks access to APIs not in its scope."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # These should work (scoped APIs)
            self.assertTrue(hasattr(brainstate.functional, 'relu'))
            self.assertTrue(hasattr(brainstate.functional, 'sigmoid'))

            # This should NOT work
            with self.assertRaises(AttributeError) as context:
                _ = brainstate.functional.nonexistent_function
            self.assertIn('Available attributes:', str(context.exception))


class TestDeprecationSystemRobustness(unittest.TestCase):
    """Test edge cases and robustness of the deprecation system."""

    def test_nested_attribute_access(self):
        """Test accessing nested attributes doesn't break."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # The replacement must provide ``grad``; previously this was only an
            # ``if`` guard, which silently turned the test into a no-op.
            self.assertTrue(hasattr(brainstate.transform, 'grad'))
            grad_func = brainstate.augment.grad
            self.assertTrue(callable(grad_func))

    def test_module_import_style_access(self):
        """Test different styles of accessing deprecated modules."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Direct access
            func1 = brainstate.augment.grad

            # Module-style access
            augment_module = brainstate.augment
            func2 = augment_module.grad

            # Should be the same function
            self.assertIs(func1, func2)

    def test_help_and_documentation(self):
        """Test that documentation is available on deprecated modules."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Any exception raised here fails the test with its own traceback;
            # no broad ``except Exception`` that would mask assertion failures.
            help_text = brainstate.augment.__doc__
            self.assertIsInstance(help_text, str)
            self.assertIn('deprecated', help_text.lower())

    def test_multiple_import_styles(self):
        """Test that different import styles work with deprecation."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Test that we can still access through different paths
            from brainstate import augment as aug
            from brainstate import functional as func

            self.assertTrue(callable(aug.grad))
            self.assertTrue(callable(func.relu))


class MockReplacementModule:
    """Stand-in for a replacement module, used to exercise ``DeprecatedModule``.

    Carries one attribute of each kind a real module might expose: a plain
    callable, a constant, and a class.
    """

    # Constant surrogate for a module-level variable.
    test_variable = 42

    @staticmethod
    def test_function(x):
        """Return ``x`` multiplied by two."""
        doubled = x * 2
        return doubled

    class test_class:
        """Trivial class that stores the value it is constructed with."""

        def __init__(self, value):
            self.value = value


class TestDeprecatedModule(unittest.TestCase):
    """Behavioural tests for the ``DeprecatedModule`` proxy."""

    def setUp(self):
        """Create a proxy that forwards ``test.deprecated`` to a mock module."""
        self.mock_module = MockReplacementModule()
        self.deprecated = DeprecatedModule(
            deprecated_name='test.deprecated',
            replacement_module=self.mock_module,
            replacement_name='test.replacement',
            version='1.0.0',
            removal_version='2.0.0',
        )

    def test_initialization(self):
        """The proxy exposes its old name and a docstring naming both modules."""
        proxy = self.deprecated
        self.assertEqual(proxy.__name__, 'test.deprecated')
        for expected in ('DEPRECATED', 'test.deprecated', 'test.replacement'):
            self.assertIn(expected, proxy.__doc__)

    def test_repr(self):
        """repr() identifies the proxy class and both module names."""
        text = repr(self.deprecated)
        for expected in ('DeprecatedModule', 'test.deprecated', 'test.replacement'):
            self.assertIn(expected, text)

    def test_attribute_forwarding(self):
        """Functions, plain values and classes resolve on the replacement module."""
        proxy = self.deprecated
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # A forwarded function keeps its behaviour.
            self.assertEqual(proxy.test_function(5), 10)

            # A forwarded module-level value is returned unchanged.
            self.assertEqual(proxy.test_variable, 42)

            # A forwarded class can still be instantiated.
            obj = proxy.test_class(100)
            self.assertEqual(obj.value, 100)

    def test_deprecation_warnings(self):
        """Touching three distinct attributes yields three DeprecationWarnings."""
        names = ('test_function', 'test_variable', 'test_class')
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            for name in names:
                getattr(self.deprecated, name)

        self.assertEqual(len(caught), 3)
        for record in caught:
            self.assertEqual(record.category, DeprecationWarning)
            message = str(record.message)
            self.assertIn('test.deprecated', message)
            self.assertIn('test.replacement', message)
            self.assertIn('deprecated', message.lower())

    def test_no_duplicate_warnings(self):
        """Test that accessing the same attribute multiple times only warns once."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            for _ in range(3):
                getattr(self.deprecated, 'test_function')

        self.assertEqual(len(caught), 1)

    def test_warning_with_removal_version(self):
        """The warning text mentions the planned removal version."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            _ = self.deprecated.test_function

        self.assertEqual(len(caught), 1)
        self.assertIn('2.0.0', str(caught[0].message))

    def test_missing_attribute_error(self):
        """Unknown names raise an AttributeError that mentions both modules."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with self.assertRaises(AttributeError) as ctx:
                _ = self.deprecated.nonexistent_attribute

        message = str(ctx.exception)
        for expected in ('test.deprecated', 'nonexistent_attribute', 'test.replacement'):
            self.assertIn(expected, message)

    def test_dir_functionality(self):
        """dir() lists the forwarded names and emits at least one warning."""
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            listing = dir(self.deprecated)

        self.assertGreaterEqual(len(caught), 1)
        for name in ('test_function', 'test_variable', 'test_class'):
            self.assertIn(name, listing)

    def test_module_without_all_attribute(self):
        """A replacement without ``__all__`` yields a proxy without ``__all__``."""

        class ModuleWithoutAll:
            def some_function(self):
                return "test"

        proxy = DeprecatedModule(
            deprecated_name='test.no_all',
            replacement_module=ModuleWithoutAll(),
            replacement_name='test.replacement',
        )

        self.assertFalse(hasattr(proxy, '__all__'))

        # Forwarding still works even though nothing is advertised.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.assertTrue(hasattr(proxy, 'some_function'))
|
666
|
+
|
667
|
+
|
668
|
+
class TestCreateDeprecatedModuleProxy(unittest.TestCase):
    """Tests for the ``create_deprecated_module_proxy`` factory."""

    def test_create_proxy_function(self):
        """The factory returns a working ``DeprecatedModule`` instance."""
        proxy = create_deprecated_module_proxy(
            deprecated_name='test.proxy',
            replacement_module=MockReplacementModule(),
            replacement_name='test.new_module',
            version='1.0.0',
        )

        self.assertIsInstance(proxy, DeprecatedModule)
        self.assertEqual(proxy.__name__, 'test.proxy')

        # The proxy forwards calls to the replacement module.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.assertEqual(proxy.test_function(10), 20)

    def test_proxy_with_kwargs(self):
        """Extra keyword arguments such as ``removal_version`` are honoured."""
        proxy = create_deprecated_module_proxy(
            deprecated_name='test.kwargs',
            replacement_module=MockReplacementModule(),
            replacement_name='test.new',
            removal_version='3.0.0',
        )

        # The removal version should surface in the warning text.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            _ = proxy.test_function

        self.assertEqual(len(caught), 1)
        self.assertIn('3.0.0', str(caught[0].message))
|
709
|
+
|
710
|
+
|
711
|
+
class TestDeprecationEdgeCases(unittest.TestCase):
    """Edge cases for ``DeprecatedModule`` attribute forwarding."""

    @staticmethod
    def _make_proxy(deprecated_name):
        """Wrap a fresh ``MockReplacementModule`` under *deprecated_name*.

        Returns ``(proxy, replacement)``.
        """
        replacement = MockReplacementModule()
        proxy = DeprecatedModule(
            deprecated_name=deprecated_name,
            replacement_module=replacement,
            replacement_name='test.replacement',
        )
        return proxy, replacement

    def test_circular_reference_handling(self):
        """A replacement that points back at its proxy still forwards normally."""
        proxy, replacement = self._make_proxy('test.circular')

        # Introduce a reference cycle: replacement -> proxy -> replacement.
        replacement.circular_ref = proxy

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.assertEqual(proxy.test_function(5), 10)

    def test_complex_attribute_access_patterns(self):
        """Attributes fetched by dot access or by ``getattr`` behave the same."""
        proxy, _ = self._make_proxy('test.complex')

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Dot access, stored and called later.
            doubled = proxy.test_function
            self.assertEqual(doubled(7), 14)

            # String-based lookup.
            looked_up = getattr(proxy, 'test_function')
            self.assertEqual(looked_up(8), 16)

    def test_stacklevel_accuracy(self):
        """The warning is attributed to this test file, not to library internals."""
        proxy, _ = self._make_proxy('test.stack')

        def fetch():
            return proxy.test_function

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            fetch()

        self.assertEqual(len(caught), 1)
        self.assertIn('_deprecation_test.py', caught[0].filename)
|
776
|
+
|
777
|
+
|
778
|
+
class TestDeprecatedModuleInitialization(unittest.TestCase):
    """Construction-time behaviour of ``DeprecatedModule``."""

    def test_deprecated_module_initialization_minimal_parameters(self):
        """Only the three required constructor arguments are supplied."""
        replacement = MockReplacementModule()
        proxy = DeprecatedModule(
            deprecated_name='test.minimal',
            replacement_module=replacement,
            replacement_name='test.replacement_min',
        )

        # Required arguments are stored verbatim.
        self.assertEqual(proxy.__name__, 'test.minimal')
        self.assertEqual(proxy._deprecated_name, 'test.minimal')
        self.assertEqual(proxy._replacement_module, replacement)
        self.assertEqual(proxy._replacement_name, 'test.replacement_min')

        # Omitted optionals: ``version`` falls back to its default ("0.1.11"),
        # while ``removal_version`` stays unset.
        self.assertEqual(proxy._version, "0.1.11")
        self.assertIsNone(proxy._removal_version)

        # A deprecation docstring is generated without explicit version arguments.
        for expected in ('DEPRECATED', 'test.minimal', 'test.replacement_min'):
            self.assertIn(expected, proxy.__doc__)

    def test_deprecated_module_with_empty_replacement_module(self):
        """A replacement object with no attributes is wrapped without error."""

        class EmptyModule:
            pass

        proxy = DeprecatedModule(
            deprecated_name='test.empty',
            replacement_module=EmptyModule(),
            replacement_name='test.empty_replacement',
        )

        self.assertEqual(proxy.__name__, 'test.empty')
        self.assertFalse(hasattr(proxy, '__all__'))

        # Any lookup must fail with AttributeError rather than something else.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with self.assertRaises(AttributeError):
                _ = proxy.nonexistent

    def test_deprecated_module_initialization_with_callable_replacement(self):
        """Static methods, class methods and plain values all forward."""

        class CallableModule:
            @staticmethod
            def func1():
                return "result1"

            @classmethod
            def func2(cls):
                return "result2"

            var1 = "variable1"

        proxy = DeprecatedModule(
            deprecated_name='test.callable',
            replacement_module=CallableModule(),
            replacement_name='test.callable_replacement',
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.assertEqual(proxy.func1(), "result1")
            self.assertEqual(proxy.func2(), "result2")
            self.assertEqual(proxy.var1, "variable1")
|
857
|
+
|
858
|
+
|
859
|
+
class TestScopedAPIStringImports(unittest.TestCase):
    """Scoped (deprecated) APIs resolved through string-based lookups."""

    def test_scoped_api_string_based_attribute_access(self):
        """Every name in ``brainstate.augment.__all__`` resolves via ``getattr``."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            augment = brainstate.augment

            for api_name in augment.__all__:
                with self.subTest(api_name=api_name):
                    # A lookup with a default must not fall back to None...
                    attr = getattr(augment, api_name, None)
                    self.assertIsNotNone(attr, f"API '{api_name}' should be accessible via getattr")

                    # ...and must return the same object as a lookup without one.
                    self.assertIs(attr, getattr(augment, api_name))

    def test_scoped_api_dynamic_import_patterns(self):
        """Names looked up dynamically are callables shared with ``transform``."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            augment = brainstate.augment

            for api_name in ('grad', 'vmap', 'vector_grad'):
                with self.subTest(api_name=api_name):
                    if not hasattr(augment, api_name):
                        continue
                    func = getattr(augment, api_name)
                    self.assertTrue(callable(func))

                    # The deprecated alias should hand back the transform object itself.
                    if hasattr(brainstate.transform, api_name):
                        self.assertIs(func, getattr(brainstate.transform, api_name))

    def test_scoped_api_list_comprehension_access(self):
        """Filtering ``__all__`` for callables finds at least one callable."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            augment = brainstate.augment
            found = [
                getattr(augment, name)
                for name in augment.__all__
                if callable(getattr(augment, name, None))
            ]

            self.assertGreater(len(found), 0)
            self.assertTrue(all(callable(obj) for obj in found))

    def test_scoped_api_introspection(self):
        """``augment.grad`` keeps ordinary function metadata."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if not hasattr(brainstate.augment, 'grad'):
                return
            grad_func = brainstate.augment.grad

            for dunder in ('__name__', '__doc__', '__module__'):
                self.assertTrue(hasattr(grad_func, dunder))
            self.assertEqual(grad_func.__name__, 'grad')

    def test_scoped_api_with_string_module_names(self):
        """Deprecated sub-modules fetched by name expose a usable ``__all__``."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for module_name in ('augment', 'compile', 'functional'):
                with self.subTest(module_name=module_name):
                    module = getattr(brainstate, module_name, None)
                    self.assertIsNotNone(module)
                    self.assertTrue(hasattr(module, '__all__'))

                    # Each advertised API that resolves must not be None.
                    for api_name in getattr(module, '__all__', []):
                        if hasattr(module, api_name):
                            self.assertIsNotNone(getattr(module, api_name))
|
953
|
+
|
954
|
+
|
955
|
+
class TestDeprecationErrorHandlingAndFallbacks(unittest.TestCase):
    """Test error handling and fallback mechanisms in deprecation system."""

    def test_invalid_attribute_access_error_messages(self):
        """Test that invalid attribute access provides helpful error messages."""
        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.invalid_attr',
            replacement_module=mock_module,
            replacement_name='test.replacement_invalid'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            with self.assertRaises(AttributeError) as context:
                _ = deprecated.completely_nonexistent_function

            error_msg = str(context.exception)

            # The message names both the deprecated module and the missing attribute.
            self.assertIn('test.invalid_attr', error_msg)
            self.assertIn('completely_nonexistent_function', error_msg)

    def test_fallback_when_replacement_module_lacks_attribute(self):
        """Test fallback behavior when replacement module lacks expected attribute."""

        class IncompleteModule:
            def existing_func(self):
                return "exists"

        incomplete_module = IncompleteModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.incomplete',
            replacement_module=incomplete_module,
            replacement_name='test.incomplete_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Should work for existing function
            result = deprecated.existing_func()
            self.assertEqual(result, "exists")

            # Should raise AttributeError for missing function
            with self.assertRaises(AttributeError):
                _ = deprecated.missing_func

    def test_exception_handling_during_warning_generation(self):
        """Forwarded calls succeed while deprecation warnings are suppressed.

        Smoke test only: the replacement is an ordinary object, so no exception
        is actually injected into warning generation. It checks that a
        forwarded method call works under an ``ignore`` warnings filter.
        """

        class ProblematicModule:
            def test_func(self):
                return "works"

        problematic_module = ProblematicModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.problematic',
            replacement_module=problematic_module,
            replacement_name='test.problematic_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            result = deprecated.test_func()
            self.assertEqual(result, "works")

    def test_graceful_handling_of_special_attributes(self):
        """Test graceful handling of special Python attributes."""
        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.special',
            replacement_module=mock_module,
            replacement_name='test.special_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # These must not raise. Warnings are suppressed here, so this test
            # does not check whether they would warn.
            self.assertEqual(deprecated.__name__, 'test.special')
            self.assertIsInstance(deprecated.__doc__, str)

            # repr should work
            repr_str = repr(deprecated)
            self.assertIsInstance(repr_str, str)

    def test_multiple_error_conditions_simultaneously(self):
        """Test handling multiple error conditions at once."""

        class MultiErrorModule:
            def func1(self):
                raise RuntimeError("Runtime error in func1")

            # ``func2`` is intentionally not defined (and this class has no
            # ``__all__``), so looking it up through the proxy must raise
            # AttributeError.

        error_module = MultiErrorModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.multi_error',
            replacement_module=error_module,
            replacement_name='test.multi_error_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            # Errors raised by the forwarded callable propagate unchanged.
            with self.assertRaises(RuntimeError):
                deprecated.func1()

            # Missing attributes surface as AttributeError.
            with self.assertRaises(AttributeError):
                _ = deprecated.func2

            with self.assertRaises(AttributeError):
                _ = deprecated.nonexistent
|
1074
|
+
|
1075
|
+
|
1076
|
+
class TestConcurrentAccessAndThreadSafety(unittest.TestCase):
    """Test concurrent access and thread safety of deprecation system.

    ``warnings.catch_warnings`` swaps process-global warning state, and the
    Python documentation states it is not thread-safe. If it were entered inside
    worker threads, threads could restore each other's saved state out of order
    and leak filters (or a recording hook) into unrelated tests. Each test below
    therefore enters it exactly once, in the main thread, around the whole
    start/join phase.
    """

    def test_concurrent_attribute_access(self):
        """Test that concurrent attribute access works correctly."""
        import threading
        import time

        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.concurrent',
            replacement_module=mock_module,
            replacement_name='test.concurrent_replacement'
        )

        n_threads = 5
        n_iterations = 10
        results = []
        errors = []

        def access_attributes():
            try:
                # Access different attributes multiple times
                for _ in range(n_iterations):
                    result1 = deprecated.test_function(5)
                    result2 = deprecated.test_variable
                    results.append((result1, result2))
                    time.sleep(0.001)  # Small delay to encourage interleaving
            except Exception as e:  # surfaced by the assertion below
                errors.append(e)

        threads = [threading.Thread(target=access_attributes) for _ in range(n_threads)]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        self.assertEqual(len(errors), 0, f"Errors occurred: {errors}")
        # Without errors, every thread completes all of its iterations.
        self.assertEqual(len(results), n_threads * n_iterations)

        # All results should be consistent
        for result1, result2 in results:
            self.assertEqual(result1, 10)  # test_function(5) should return 10
            self.assertEqual(result2, 42)  # test_variable should be 42

    def test_thread_safety_of_warning_generation(self):
        """Test that warning generation is thread-safe."""
        import threading

        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.thread_warnings',
            replacement_module=mock_module,
            replacement_name='test.thread_warnings_replacement'
        )

        n_threads = 3
        completed = []
        errors = []

        def generate_warnings():
            try:
                # Access attributes to generate warnings
                _ = deprecated.test_function
                _ = deprecated.test_variable
                _ = deprecated.test_class
            except Exception as e:  # surfaced by the assertion below
                errors.append(e)
            else:
                completed.append(True)

        threads = [threading.Thread(target=generate_warnings) for _ in range(n_threads)]

        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        self.assertEqual(len(errors), 0, f"Errors occurred: {errors}")
        self.assertEqual(len(completed), n_threads)

        # Each of the three distinct attributes should warn at least once,
        # regardless of which thread reaches it first.
        deprecations = [r for r in caught if issubclass(r.category, DeprecationWarning)]
        self.assertGreaterEqual(len(deprecations), 3)

    def test_race_condition_in_attribute_caching(self):
        """Test for race conditions in any internal attribute caching."""
        import threading

        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.race_condition',
            replacement_module=mock_module,
            replacement_name='test.race_condition_replacement'
        )

        n_threads = 4
        n_iterations = 20
        results = {}
        errors = []
        lock = threading.Lock()

        def access_same_attribute(thread_id):
            try:
                # Access the same attribute multiple times
                for i in range(n_iterations):
                    attr = deprecated.test_function
                    result = attr(i)

                    with lock:
                        results.setdefault(thread_id, []).append(result)
            except Exception as e:  # surfaced by the assertion below
                errors.append(e)

        threads = [
            threading.Thread(target=access_same_attribute, args=(i,))
            for i in range(n_threads)
        ]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        self.assertEqual(len(errors), 0, f"Errors occurred: {errors}")

        # Verify all threads got consistent results
        self.assertEqual(len(results), n_threads)
        for thread_results in results.values():
            self.assertEqual(len(thread_results), n_iterations)
            for i, result in enumerate(thread_results):
                self.assertEqual(result, i * 2)  # test_function multiplies by 2
|
1219
|
+
|
1220
|
+
|
1221
|
+
class TestMemoryUsageAndPerformance(unittest.TestCase):
    """Test memory usage and performance aspects of deprecation system.

    Timings use ``time.perf_counter`` (monotonic, highest available resolution)
    rather than ``time.time``, which can jump with clock adjustments and may be
    coarse on some platforms.
    """

    def test_memory_usage_of_deprecated_modules(self):
        """Test that deprecated modules don't consume excessive memory."""

        # Create many deprecated modules
        modules = []
        for i in range(100):
            mock_module = MockReplacementModule()
            deprecated = DeprecatedModule(
                deprecated_name=f'test.memory_{i}',
                replacement_module=mock_module,
                replacement_name=f'test.memory_replacement_{i}'
            )
            modules.append(deprecated)

        self.assertEqual(len(modules), 100)

        # Basic functionality should still work
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for i in range(0, 100, 10):  # Test every 10th module
                result = modules[i].test_function(1)
                self.assertEqual(result, 2)

    def test_performance_of_attribute_access(self):
        """Test performance of deprecated module attribute access."""
        import time

        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.performance',
            replacement_module=mock_module,
            replacement_name='test.performance_replacement'
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            start_time = time.perf_counter()
            for _ in range(1000):
                _ = deprecated.test_function
                _ = deprecated.test_variable
                _ = deprecated.test_class
            elapsed = time.perf_counter() - start_time

        # Should complete reasonably quickly (less than 1 second for 1000 iterations)
        self.assertLess(elapsed, 1.0, f"Attribute access took too long: {elapsed}s")

    def test_warning_performance_impact(self):
        """Test that warning generation doesn't significantly impact performance."""
        import time

        mock_module = MockReplacementModule()
        deprecated = DeprecatedModule(
            deprecated_name='test.warning_performance',
            replacement_module=mock_module,
            replacement_name='test.warning_performance_replacement'
        )

        # Warnings enabled. ``record=True`` keeps emitted warnings in memory, so
        # the measurement is not dominated by writing them to stderr.
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            start_time = time.perf_counter()
            for _ in range(100):
                _ = deprecated.test_function
                _ = deprecated.test_variable
            with_warnings_time = time.perf_counter() - start_time

        # Warnings disabled.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            start_time = time.perf_counter()
            for _ in range(100):
                _ = deprecated.test_function
                _ = deprecated.test_variable
            without_warnings_time = time.perf_counter() - start_time

        # With warnings should not be dramatically slower (less than 10x)
        if without_warnings_time > 0:
            ratio = with_warnings_time / without_warnings_time
            self.assertLess(ratio, 10.0, f"Warning generation too slow: {ratio}x slower")

    def test_memory_leak_prevention(self):
        """Test that deprecated modules don't cause memory leaks."""
        import gc
        import weakref

        weak_refs = []
        for i in range(50):
            mock_module = MockReplacementModule()
            deprecated = DeprecatedModule(
                deprecated_name=f'test.leak_{i}',
                replacement_module=mock_module,
                replacement_name=f'test.leak_replacement_{i}'
            )

            # Access some attributes to trigger any caching
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                _ = deprecated.test_function

            weak_refs.append(weakref.ref(deprecated))

        # Drop the last strong references this frame holds, then collect.
        del deprecated, mock_module, _
        gc.collect()

        self.assertEqual(len(weak_refs), 50)

        # Not every instance is guaranteed to be reclaimed, but if none is,
        # something is keeping every proxy alive.
        collected = sum(1 for ref in weak_refs if ref() is None)
        self.assertGreater(
            collected, 0,
            "No DeprecatedModule instance was garbage-collected; possible leak",
        )
|
1349
|
+
|
1350
|
+
|
1351
|
+
class TestDeprecatedAugment(unittest.TestCase):
|
1352
|
+
"""Test suite for the deprecated brainstate.augment module."""
|
1353
|
+
|
1354
|
+
def test_augment_module_import(self):
    """Test that the deprecated augment module can be imported and used.

    Only successful access is asserted. ``DeprecatedModule`` warns just once
    per attribute, and ``brainstate.augment`` is presumably one proxy shared by
    the whole test session. Whether a warning is recorded here therefore
    depends on test order, so the warning count is not checked.
    """
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        import brainstate
        # Access an attribute through the deprecated alias.
        grad = brainstate.augment.grad

    self.assertTrue(callable(grad))
|
1369
|
+
|
1370
|
+
def test_augmentation_functions(self):
    """Test that all augmentation functions are accessible.

    Only accessibility is asserted. Deprecation warnings fire once per
    attribute, and ``brainstate.augment`` is presumably shared across the test
    session, so whether a warning is recorded here depends on test order.
    """
    import brainstate

    augment_funcs = [
        'GradientTransform',
        'grad',
        'vector_grad',
        'hessian',
        'jacobian',
        'jacrev',
        'jacfwd',
        'abstract_init',
        'vmap',
        'pmap',
        'map',
        'vmap_new_states',
        'restore_rngs',
    ]

    for func_name in augment_funcs:
        with self.subTest(function=func_name):
            with warnings.catch_warnings(record=True):
                warnings.simplefilter("always")
                func = getattr(brainstate.augment, func_name)
            self.assertIsNotNone(func)
|
1405
|
+
|
1406
|
+
def test_gradient_functions(self):
|
1407
|
+
"""Test gradient-related functions."""
|
1408
|
+
with warnings.catch_warnings(record=True):
|
1409
|
+
warnings.simplefilter("always")
|
1410
|
+
import brainstate
|
1411
|
+
|
1412
|
+
# Test grad
|
1413
|
+
grad = brainstate.augment.grad
|
1414
|
+
self.assertIsNotNone(grad)
|
1415
|
+
|
1416
|
+
# Test vector_grad
|
1417
|
+
vector_grad = brainstate.augment.vector_grad
|
1418
|
+
self.assertIsNotNone(vector_grad)
|
1419
|
+
|
1420
|
+
# Test GradientTransform
|
1421
|
+
GradientTransform = brainstate.augment.GradientTransform
|
1422
|
+
self.assertIsNotNone(GradientTransform)
|
1423
|
+
|
1424
|
+
def test_grad_function(self):
|
1425
|
+
"""Test grad function functionality."""
|
1426
|
+
with warnings.catch_warnings(record=True):
|
1427
|
+
warnings.simplefilter("always")
|
1428
|
+
import brainstate
|
1429
|
+
|
1430
|
+
# Test grad function
|
1431
|
+
grad = brainstate.augment.grad
|
1432
|
+
self.assertIsNotNone(grad)
|
1433
|
+
# Just check that it's callable
|
1434
|
+
self.assertTrue(callable(grad))
|
1435
|
+
|
1436
|
+
def test_jacobian_functions(self):
|
1437
|
+
"""Test Jacobian-related functions."""
|
1438
|
+
with warnings.catch_warnings(record=True):
|
1439
|
+
warnings.simplefilter("always")
|
1440
|
+
import brainstate
|
1441
|
+
|
1442
|
+
# Test jacobian
|
1443
|
+
jacobian = brainstate.augment.jacobian
|
1444
|
+
self.assertIsNotNone(jacobian)
|
1445
|
+
|
1446
|
+
# Test jacrev
|
1447
|
+
jacrev = brainstate.augment.jacrev
|
1448
|
+
self.assertIsNotNone(jacrev)
|
1449
|
+
|
1450
|
+
# Test jacfwd
|
1451
|
+
jacfwd = brainstate.augment.jacfwd
|
1452
|
+
self.assertIsNotNone(jacfwd)
|
1453
|
+
|
1454
|
+
def test_hessian_function(self):
|
1455
|
+
"""Test Hessian function."""
|
1456
|
+
with warnings.catch_warnings(record=True):
|
1457
|
+
warnings.simplefilter("always")
|
1458
|
+
import brainstate
|
1459
|
+
|
1460
|
+
# Test hessian
|
1461
|
+
hessian = brainstate.augment.hessian
|
1462
|
+
self.assertIsNotNone(hessian)
|
1463
|
+
# Just check that it's callable
|
1464
|
+
self.assertTrue(callable(hessian))
|
1465
|
+
|
1466
|
+
def test_mapping_functions(self):
|
1467
|
+
"""Test mapping-related functions."""
|
1468
|
+
with warnings.catch_warnings(record=True):
|
1469
|
+
warnings.simplefilter("always")
|
1470
|
+
import brainstate
|
1471
|
+
|
1472
|
+
# Test vmap
|
1473
|
+
vmap = brainstate.augment.vmap
|
1474
|
+
self.assertIsNotNone(vmap)
|
1475
|
+
|
1476
|
+
# Test pmap
|
1477
|
+
pmap = brainstate.augment.pmap
|
1478
|
+
self.assertIsNotNone(pmap)
|
1479
|
+
|
1480
|
+
# Test map
|
1481
|
+
map_func = brainstate.augment.map
|
1482
|
+
self.assertIsNotNone(map_func)
|
1483
|
+
|
1484
|
+
def test_vmap_function(self):
|
1485
|
+
"""Test vmap function functionality."""
|
1486
|
+
with warnings.catch_warnings(record=True):
|
1487
|
+
warnings.simplefilter("always")
|
1488
|
+
import brainstate
|
1489
|
+
|
1490
|
+
# Test vmap
|
1491
|
+
vmap = brainstate.augment.vmap
|
1492
|
+
self.assertIsNotNone(vmap)
|
1493
|
+
# Just check that it's callable
|
1494
|
+
self.assertTrue(callable(vmap))
|
1495
|
+
|
1496
|
+
def test_vmap_new_states(self):
|
1497
|
+
"""Test vmap_new_states function."""
|
1498
|
+
with warnings.catch_warnings(record=True):
|
1499
|
+
warnings.simplefilter("always")
|
1500
|
+
import brainstate
|
1501
|
+
|
1502
|
+
# Test vmap_new_states
|
1503
|
+
vmap_new_states = brainstate.augment.vmap_new_states
|
1504
|
+
self.assertIsNotNone(vmap_new_states)
|
1505
|
+
|
1506
|
+
def test_abstract_init(self):
|
1507
|
+
"""Test abstract_init function."""
|
1508
|
+
with warnings.catch_warnings(record=True):
|
1509
|
+
warnings.simplefilter("always")
|
1510
|
+
import brainstate
|
1511
|
+
|
1512
|
+
# Test abstract_init
|
1513
|
+
abstract_init = brainstate.augment.abstract_init
|
1514
|
+
self.assertIsNotNone(abstract_init)
|
1515
|
+
|
1516
|
+
def test_restore_rngs(self):
|
1517
|
+
"""Test restore_rngs function."""
|
1518
|
+
with warnings.catch_warnings(record=True):
|
1519
|
+
warnings.simplefilter("always")
|
1520
|
+
import brainstate
|
1521
|
+
|
1522
|
+
# Test restore_rngs
|
1523
|
+
restore_rngs = brainstate.augment.restore_rngs
|
1524
|
+
self.assertIsNotNone(restore_rngs)
|
1525
|
+
|
1526
|
+
def test_module_attributes(self):
|
1527
|
+
"""Test module-level attributes."""
|
1528
|
+
with warnings.catch_warnings(record=True):
|
1529
|
+
warnings.simplefilter("always")
|
1530
|
+
import brainstate
|
1531
|
+
|
1532
|
+
# Test __name__ attribute
|
1533
|
+
self.assertEqual(brainstate.augment.__name__, 'brainstate.augment')
|
1534
|
+
|
1535
|
+
# Test __doc__ attribute
|
1536
|
+
self.assertIn('DEPRECATED', brainstate.augment.__doc__)
|
1537
|
+
|
1538
|
+
# Test __all__ attribute
|
1539
|
+
self.assertIsInstance(brainstate.augment.__all__, list)
|
1540
|
+
self.assertIn('grad', brainstate.augment.__all__)
|
1541
|
+
self.assertIn('vmap', brainstate.augment.__all__)
|
1542
|
+
|
1543
|
+
def test_dir_method(self):
|
1544
|
+
"""Test that dir() returns appropriate attributes."""
|
1545
|
+
with warnings.catch_warnings(record=True) as w:
|
1546
|
+
warnings.simplefilter("always")
|
1547
|
+
import brainstate
|
1548
|
+
|
1549
|
+
attrs = dir(brainstate.augment)
|
1550
|
+
|
1551
|
+
# Check that expected attributes are present
|
1552
|
+
expected_attrs = [
|
1553
|
+
'grad', 'vmap', 'jacobian', 'hessian',
|
1554
|
+
'__name__', '__doc__', '__all__'
|
1555
|
+
]
|
1556
|
+
for attr in expected_attrs:
|
1557
|
+
self.assertIn(attr, attrs)
|
1558
|
+
|
1559
|
+
# Check that a deprecation warning was issued
|
1560
|
+
# self.assertTrue(any(issubclass(warning.category, DeprecationWarning) for warning in w))
|
1561
|
+
|
1562
|
+
def test_invalid_attribute_access(self):
|
1563
|
+
"""Test that accessing invalid attributes raises appropriate errors."""
|
1564
|
+
with warnings.catch_warnings(record=True):
|
1565
|
+
warnings.simplefilter("always")
|
1566
|
+
import brainstate
|
1567
|
+
|
1568
|
+
with self.assertRaises(AttributeError) as context:
|
1569
|
+
_ = brainstate.augment.NonExistentFunction
|
1570
|
+
|
1571
|
+
self.assertIn('NonExistentFunction', str(context.exception))
|
1572
|
+
self.assertIn('brainstate.augment', str(context.exception))
|
1573
|
+
|
1574
|
+
def test_repr_method(self):
|
1575
|
+
"""Test the __repr__ method of the deprecated module."""
|
1576
|
+
with warnings.catch_warnings(record=True):
|
1577
|
+
warnings.simplefilter("always")
|
1578
|
+
import brainstate
|
1579
|
+
|
1580
|
+
repr_str = repr(brainstate.augment)
|
1581
|
+
self.assertIn('DeprecatedModule', repr_str)
|
1582
|
+
self.assertIn('brainstate.augment', repr_str)
|
1583
|
+
self.assertIn('brainstate.transform', repr_str)
|
1584
|
+
|
1585
|
+
def test_gradient_transform_class(self):
|
1586
|
+
"""Test GradientTransform class."""
|
1587
|
+
with warnings.catch_warnings(record=True):
|
1588
|
+
warnings.simplefilter("always")
|
1589
|
+
import brainstate
|
1590
|
+
|
1591
|
+
# Test GradientTransform class
|
1592
|
+
GradientTransform = brainstate.augment.GradientTransform
|
1593
|
+
self.assertIsNotNone(GradientTransform)
|
1594
|
+
|
1595
|
+
|
1596
|
+
class TestDeprecatedCompile(unittest.TestCase):
    """Test suite for the deprecated brainstate.compile module."""

    # Every name the deprecated ``brainstate.compile`` shim is expected to forward.
    _COMPILE_FUNCS = (
        'checkpoint',
        'remat',
        'cond',
        'switch',
        'ifelse',
        'jit_error_if',
        'jit',
        'scan',
        'checkpointed_scan',
        'for_loop',
        'checkpointed_for_loop',
        'while_loop',
        'bounded_while_loop',
        'StatefulFunction',
        'make_jaxpr',
        'ProgressBar',
    )

    @staticmethod
    def _get(name):
        """Return ``brainstate.compile.<name>`` while recording (not printing) warnings.

        NOTE(review): whether the shim emits a ``DeprecationWarning`` on every
        access is deliberately not asserted here. TODO: confirm the intended
        warning behavior of the shim and add an assertion for it.
        """
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            return getattr(brainstate.compile, name)

    def _assert_available(self, *names, check_callable=False):
        """Assert that each of ``names`` resolves to a non-None attribute.

        Args:
            *names: Attribute names to look up on ``brainstate.compile``.
            check_callable: When True, additionally require each attribute to be callable.
        """
        for name in names:
            with self.subTest(name=name):
                obj = self._get(name)
                self.assertIsNotNone(obj)
                if check_callable:
                    self.assertTrue(callable(obj))

    def test_compile_module_import(self):
        """Test that the deprecated compile module can be imported and accessed."""
        self._assert_available('jit')

    def test_compilation_functions(self):
        """Test that all compilation functions are accessible."""
        self._assert_available(*self._COMPILE_FUNCS)

    def test_jit_function(self):
        """Test that jit is exposed and callable."""
        self._assert_available('jit', check_callable=True)

    def test_cond_function(self):
        """Test that cond is exposed and callable."""
        self._assert_available('cond', check_callable=True)

    def test_ifelse_function(self):
        """Test ifelse function."""
        self._assert_available('ifelse')

    def test_switch_function(self):
        """Test switch function."""
        self._assert_available('switch')

    def test_loop_functions(self):
        """Test loop-related functions."""
        self._assert_available('for_loop', 'while_loop', 'bounded_while_loop')

    def test_scan_functions(self):
        """Test scan-related functions."""
        self._assert_available('scan', 'checkpointed_scan')

    def test_checkpoint_functions(self):
        """Test checkpoint-related functions (checkpoint and remat)."""
        self._assert_available('checkpoint', 'remat')

    def test_jit_error_if(self):
        """Test jit_error_if function."""
        self._assert_available('jit_error_if')

    def test_stateful_function(self):
        """Test StatefulFunction class."""
        self._assert_available('StatefulFunction')

    def test_make_jaxpr(self):
        """Test make_jaxpr function."""
        self._assert_available('make_jaxpr')

    def test_progress_bar(self):
        """Test ProgressBar class."""
        self._assert_available('ProgressBar')

    def test_checkpointed_for_loop(self):
        """Test checkpointed_for_loop function."""
        self._assert_available('checkpointed_for_loop')

    def test_module_attributes(self):
        """Test module-level attributes."""
        self.assertEqual(self._get('__name__'), 'brainstate.compile')
        self.assertIn('DEPRECATED', self._get('__doc__'))

        all_names = self._get('__all__')
        self.assertIsInstance(all_names, list)
        self.assertIn('jit', all_names)
        self.assertIn('cond', all_names)

    def test_dir_method(self):
        """Test that dir() returns appropriate attributes."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            attrs = dir(brainstate.compile)

        expected_attrs = [
            'jit', 'cond', 'scan', 'for_loop', 'while_loop',
            '__name__', '__doc__', '__all__',
        ]
        for attr in expected_attrs:
            self.assertIn(attr, attrs)

    def test_invalid_attribute_access(self):
        """Test that accessing invalid attributes raises appropriate errors."""
        with self.assertRaises(AttributeError) as context:
            self._get('NonExistentFunction')

        # The error message should name both the missing attribute and the module.
        message = str(context.exception)
        self.assertIn('NonExistentFunction', message)
        self.assertIn('brainstate.compile', message)

    def test_repr_method(self):
        """Test the __repr__ method of the deprecated module."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            repr_str = repr(brainstate.compile)

        self.assertIn('DeprecatedModule', repr_str)
        self.assertIn('brainstate.compile', repr_str)
        self.assertIn('brainstate.transform', repr_str)
|
1852
|
+
|
1853
|
+
|
1854
|
+
class TestDeprecatedFunctional(unittest.TestCase):
    """Test suite for the deprecated brainstate.functional module."""

    # Activation names the deprecated ``brainstate.functional`` shim should forward.
    _ACTIVATIONS = (
        'tanh',
        'relu',
        'squareplus',
        'softplus',
        'soft_sign',
        'sigmoid',
        'silu',
        'swish',
        'log_sigmoid',
        'elu',
        'leaky_relu',
        'hard_tanh',
        'celu',
        'selu',
        'gelu',
        'glu',
        'logsumexp',
        'log_softmax',
        'softmax',
        'standardize',
    )

    @staticmethod
    def _get(name):
        """Return ``brainstate.functional.<name>`` while recording (not printing) warnings.

        NOTE(review): whether the shim emits a ``DeprecationWarning`` on every
        access is deliberately not asserted here. TODO: confirm the intended
        warning behavior of the shim and add an assertion for it.
        """
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            return getattr(brainstate.functional, name)

    def _assert_available(self, *names):
        """Assert that each of ``names`` resolves to a non-None attribute."""
        for name in names:
            with self.subTest(name=name):
                self.assertIsNotNone(self._get(name))

    def test_functional_module_import(self):
        """Test that the deprecated functional module can be imported and accessed."""
        self._assert_available('relu')

    def test_activation_functions(self):
        """Test that all activation functions are accessible."""
        self._assert_available(*self._ACTIVATIONS)

    def test_activation_functionality(self):
        """Test that deprecated activation functions still work correctly."""
        x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])

        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            functional = brainstate.functional

            self.assertTrue(jnp.allclose(functional.relu(x), jnp.maximum(0, x)))
            self.assertTrue(jnp.allclose(functional.sigmoid(x), 1 / (1 + jnp.exp(-x))))
            self.assertTrue(jnp.allclose(functional.tanh(x), jnp.tanh(x)))
            # float() turns the 0-d JAX array into a Python scalar for assertAlmostEqual.
            self.assertAlmostEqual(float(jnp.sum(functional.softmax(x))), 1.0, places=5)

    def test_weight_standardization(self):
        """Test weight standardization function."""
        self._assert_available('weight_standardization')

    def test_clip_grad_norm(self):
        """Test clip_grad_norm function."""
        self._assert_available('clip_grad_norm')

    def test_leaky_relu(self):
        """Test leaky_relu with a custom ``negative_slope``."""
        x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            result = brainstate.functional.leaky_relu(x, negative_slope=0.01)

        non_negative = x >= 0
        negative = x < 0
        # Non-negative inputs pass through unchanged; negative ones are scaled.
        self.assertTrue(jnp.allclose(result[non_negative], x[non_negative]))
        self.assertTrue(jnp.allclose(result[negative], 0.01 * x[negative]))

    def test_elu_activation(self):
        """Test ELU activation function."""
        x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            result = brainstate.functional.elu(x, alpha=1.0)

        non_negative = x >= 0
        negative = x < 0
        self.assertTrue(jnp.allclose(result[non_negative], x[non_negative]))
        # ELU formula for negative inputs: alpha * (exp(x) - 1), with alpha = 1.0.
        expected_neg = 1.0 * (jnp.exp(x[negative]) - 1)
        self.assertTrue(jnp.allclose(result[negative], expected_neg))

    def test_gelu_activation(self):
        """Test GELU activation function."""
        x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            result = brainstate.functional.gelu(x)

        self.assertEqual(result.shape, x.shape)
        # GELU(0) should be (approximately) 0; x[2] == 0.0.
        self.assertAlmostEqual(float(result[2]), 0.0, places=5)

    def test_softplus_activation(self):
        """Test Softplus activation function."""
        x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            result = brainstate.functional.softplus(x)

        expected = jnp.log(1 + jnp.exp(x))
        self.assertTrue(jnp.allclose(result, expected))

    def test_log_softmax(self):
        """Test log_softmax function."""
        x = jnp.array([1.0, 2.0, 3.0])
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            result = brainstate.functional.log_softmax(x)

        # exp(log_softmax(x)) must be a probability distribution summing to 1.
        self.assertAlmostEqual(float(jnp.sum(jnp.exp(result))), 1.0, places=5)

    def test_silu_swish(self):
        """Test SiLU (Swish) activation function."""
        x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            result_silu = brainstate.functional.silu(x)
            result_swish = brainstate.functional.swish(x)

        # swish is an alias of silu, so both must agree.
        self.assertTrue(jnp.allclose(result_silu, result_swish))

        # Compare against an independent reference x * sigmoid(x), rather than
        # the sigmoid exported by the deprecated module under test.
        expected = x * (1 / (1 + jnp.exp(-x)))
        self.assertTrue(jnp.allclose(result_silu, expected))

    def test_module_attributes(self):
        """Test module-level attributes."""
        self.assertEqual(self._get('__name__'), 'brainstate.functional')
        self.assertIn('DEPRECATED', self._get('__doc__'))

        all_names = self._get('__all__')
        self.assertIsInstance(all_names, list)
        self.assertIn('relu', all_names)
        self.assertIn('sigmoid', all_names)

    def test_dir_method(self):
        """Test that dir() returns appropriate attributes."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            attrs = dir(brainstate.functional)

        expected_attrs = [
            'relu', 'sigmoid', 'tanh', 'softmax',
            '__name__', '__doc__', '__all__',
        ]
        for attr in expected_attrs:
            self.assertIn(attr, attrs)

    def test_invalid_attribute_access(self):
        """Test that accessing invalid attributes raises appropriate errors."""
        with self.assertRaises(AttributeError) as context:
            self._get('NonExistentFunction')

        # The error message should name both the missing attribute and the module.
        message = str(context.exception)
        self.assertIn('NonExistentFunction', message)
        self.assertIn('brainstate.functional', message)

    def test_repr_method(self):
        """Test the __repr__ method of the deprecated module."""
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            import brainstate
            repr_str = repr(brainstate.functional)

        self.assertIn('DeprecatedModule', repr_str)
        self.assertIn('brainstate.functional', repr_str)
        self.assertIn('brainstate.nn', repr_str)
|
2108
|
+
|
2109
|
+
|
2110
|
+
class TestDeprecatedInit(unittest.TestCase):
|
2111
|
+
"""Test suite for the deprecated brainstate.init module."""
|
2112
|
+
|
2113
|
+
def test_init_module_import(self):
    """Accessing an attribute of the deprecated init module must emit a DeprecationWarning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        import brainstate

        # Attribute access on the shim is what is expected to trigger the warning.
        _ = brainstate.init.Constant

        self.assertGreater(len(caught), 0)
        saw_deprecation = any(
            issubclass(record.category, DeprecationWarning) for record in caught
        )
        self.assertTrue(saw_deprecation)
|
2124
|
+
|
2125
|
+
def test_param_function(self):
    """``brainstate.init.param`` must be reachable and its access must raise a DeprecationWarning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        import brainstate

        param_fn = brainstate.init.param
        self.assertIsNotNone(param_fn)

        categories = [record.category for record in caught]
        self.assertTrue(any(issubclass(cat, DeprecationWarning) for cat in categories))
|
2137
|
+
|
2138
|
+
def test_initializers(self):
    """Test that all deprecated initializers are accessible through ``brainstate.init``.

    NOTE(review): whether each access emits a ``DeprecationWarning`` is not
    asserted here. TODO: confirm the shim's warning behavior before adding such
    an assertion.
    """
    import brainstate

    initializers = [
        'Constant',
        'Identity',
        'Normal',
        'TruncatedNormal',
        'Uniform',
        'KaimingUniform',
        'KaimingNormal',
        'XavierUniform',
        'XavierNormal',
        'LecunUniform',
        'LecunNormal',
        'Orthogonal',
        'DeltaOrthogonal',
    ]

    for init_name in initializers:
        with self.subTest(initializer=init_name):
            # Record (rather than print) any warnings raised by the attribute access.
            with warnings.catch_warnings(record=True):
                warnings.simplefilter("always")
                initializer = getattr(brainstate.init, init_name)
                self.assertIsNotNone(initializer)
|
2174
|
+
|
2175
|
+
def test_initializer_functionality(self):
    """Deprecated Constant/Normal/Uniform initializers must still build arrays of the requested shape."""
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        import brainstate
        init_mod = brainstate.init

        # Constant: every entry equals the fill value.
        filled = init_mod.Constant(0.5)((2, 3))
        self.assertEqual(filled.shape, (2, 3))
        self.assertTrue(jnp.allclose(filled, 0.5))

        # Normal: only the shape is deterministic.
        gaussian = init_mod.Normal(mean=0.0, std=1.0)((10, 10))
        self.assertEqual(gaussian.shape, (10, 10))

        # Uniform: samples must lie within [low, high].
        bounded = init_mod.Uniform(low=-1.0, high=1.0)((5, 5))
        self.assertEqual(bounded.shape, (5, 5))
        self.assertTrue(jnp.all(bounded >= -1.0))
        self.assertTrue(jnp.all(bounded <= 1.0))
|
2198
|
+
|
2199
|
+
def test_kaiming_initializers(self):
    """Kaiming (He) initializers must produce arrays of the requested (10, 10) shape."""
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        import brainstate

        for cls_name in ('KaimingUniform', 'KaimingNormal'):
            initializer = getattr(brainstate.init, cls_name)()
            self.assertEqual(initializer((10, 10)).shape, (10, 10))
|
2214
|
+
|
2215
|
+
def test_xavier_initializers(self):
    """Xavier (Glorot) initializers, built with default arguments, sample the requested shape."""
    with warnings.catch_warnings(record=True):
        # Capture warnings emitted by the deprecated ``brainstate.init`` alias.
        warnings.simplefilter("always")
        import brainstate

        shape = (10, 10)
        for cls_name in ('XavierUniform', 'XavierNormal'):
            initializer = getattr(brainstate.init, cls_name)()
            self.assertEqual(initializer(shape).shape, shape)
|
2230
|
+
|
2231
|
+
def test_lecun_initializers(self):
    """LeCun initializers, built with default arguments, sample the requested shape."""
    with warnings.catch_warnings(record=True):
        # Capture warnings emitted by the deprecated ``brainstate.init`` alias.
        warnings.simplefilter("always")
        import brainstate

        shape = (10, 10)
        for cls_name in ('LecunUniform', 'LecunNormal'):
            drawn = getattr(brainstate.init, cls_name)()(shape)
            self.assertEqual(drawn.shape, shape)
|
2246
|
+
|
2247
|
+
def test_orthogonal_initializers(self):
    """Orthogonal-family initializers sample arrays of the requested shape."""
    with warnings.catch_warnings(record=True):
        # Capture warnings emitted by the deprecated ``brainstate.init`` alias.
        warnings.simplefilter("always")
        import brainstate

        cases = (
            ('Orthogonal', (10, 10)),
            # DeltaOrthogonal is exercised with a 3-D shape, which it requires.
            ('DeltaOrthogonal', (3, 3, 3)),
        )
        for cls_name, shape in cases:
            out = getattr(brainstate.init, cls_name)()(shape)
            self.assertEqual(out.shape, shape)
|
2262
|
+
|
2263
|
+
def test_identity_initializer(self):
    """The Identity initializer yields a square identity matrix."""
    with warnings.catch_warnings(record=True):
        # Capture warnings emitted by the deprecated ``brainstate.init`` alias.
        warnings.simplefilter("always")
        import brainstate

        n = 5
        eye_like = brainstate.init.Identity()((n, n))
        self.assertEqual(eye_like.shape, (n, n))
        # Values must match jnp.eye element-wise, within allclose tolerance.
        self.assertTrue(jnp.allclose(eye_like, jnp.eye(n)))
|
2276
|
+
|
2277
|
+
def test_truncated_normal_initializer(self):
    """TruncatedNormal, given mean/std keywords, samples the requested shape."""
    with warnings.catch_warnings(record=True):
        # Capture warnings emitted by the deprecated ``brainstate.init`` alias.
        warnings.simplefilter("always")
        import brainstate

        target_shape = (10, 10)
        sampler = brainstate.init.TruncatedNormal(mean=0.0, std=1.0)
        self.assertEqual(sampler(target_shape).shape, target_shape)
|
2287
|
+
|
2288
|
+
def test_module_attributes(self):
    """The deprecated ``brainstate.init`` alias reports braintools' module metadata."""
    with warnings.catch_warnings(record=True):
        # Capture warnings emitted by the deprecated ``brainstate.init`` alias.
        warnings.simplefilter("always")
        import brainstate
        proxy = brainstate.init

        # The alias reports the name of the braintools module it forwards to.
        self.assertEqual(proxy.__name__, 'braintools.init')

        # __all__ must be a list that advertises the basic initializers.
        public_names = proxy.__all__
        self.assertIsInstance(public_names, list)
        for required in ('Constant', 'Normal'):
            self.assertIn(required, public_names)
|
2301
|
+
|
2302
|
+
def test_dir_method(self):
    """dir() on the deprecated alias lists initializers plus the module dunders."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning so the final check can inspect them.
        warnings.simplefilter("always")
        import brainstate

        listing = dir(brainstate.init)

        for required in ('Constant', 'Normal', 'Uniform', 'XavierNormal',
                         '__name__', '__doc__', '__all__'):
            self.assertIn(required, listing)

        # At least one DeprecationWarning must have been recorded in this block.
        # NOTE(review): any DeprecationWarning satisfies this check, not only
        # ones mentioning ``brainstate.init``. Consider filtering on the
        # message text, as the per-initializer test attempts.
        self.assertTrue(
            any(issubclass(record.category, DeprecationWarning) for record in caught)
        )
|