brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/environ_test.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -13,50 +13,1211 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
"""
|
17
|
+
Comprehensive test suite for the environ module.
|
16
18
|
|
19
|
+
This test module provides extensive coverage of the environment configuration
|
20
|
+
and context management functionality, including:
|
21
|
+
- Global environment settings
|
22
|
+
- Context-based temporary settings
|
23
|
+
- Precision and data type management
|
24
|
+
- Callback registration and behavior
|
25
|
+
- Thread safety
|
26
|
+
- Error handling and validation
|
27
|
+
"""
|
28
|
+
|
29
|
+
import threading
|
17
30
|
import unittest
|
31
|
+
import warnings
|
32
|
+
from unittest.mock import patch
|
18
33
|
|
19
34
|
import jax.numpy as jnp
|
35
|
+
import numpy as np
|
20
36
|
|
21
37
|
import brainstate as bst
|
22
38
|
|
23
39
|
|
24
|
-
class
|
25
|
-
|
40
|
+
class TestEnvironmentCore(unittest.TestCase):
|
41
|
+
"""Test core environment management functionality."""
|
42
|
+
|
43
|
+
def setUp(self):
|
44
|
+
"""Reset environment before each test."""
|
45
|
+
bst.environ.reset()
|
46
|
+
# Clear any warnings
|
47
|
+
warnings.filterwarnings('ignore', category=UserWarning)
|
48
|
+
|
49
|
+
def tearDown(self):
|
50
|
+
"""Clean up after each test."""
|
51
|
+
# Reset to default state
|
52
|
+
bst.environ.reset()
|
53
|
+
warnings.resetwarnings()
|
54
|
+
|
55
|
+
def test_get_set_basic(self):
|
56
|
+
"""Test basic get and set operations."""
|
57
|
+
# Set a value
|
58
|
+
bst.environ.set(test_param='test_value')
|
59
|
+
self.assertEqual(bst.environ.get('test_param'), 'test_value')
|
60
|
+
|
61
|
+
# Set multiple values
|
62
|
+
bst.environ.set(param1=1, param2='two', param3=3.0)
|
63
|
+
self.assertEqual(bst.environ.get('param1'), 1)
|
64
|
+
self.assertEqual(bst.environ.get('param2'), 'two')
|
65
|
+
self.assertEqual(bst.environ.get('param3'), 3.0)
|
66
|
+
|
67
|
+
def test_get_with_default(self):
|
68
|
+
"""Test get with default value."""
|
69
|
+
# Non-existent key with default
|
70
|
+
result = bst.environ.get('nonexistent', default='default_value')
|
71
|
+
self.assertEqual(result, 'default_value')
|
72
|
+
|
73
|
+
# Existing key ignores default
|
74
|
+
bst.environ.set(existing='value')
|
75
|
+
result = bst.environ.get('existing', default='default')
|
76
|
+
self.assertEqual(result, 'value')
|
77
|
+
|
78
|
+
def test_get_missing_key_error(self):
|
79
|
+
"""Test KeyError for missing keys without default."""
|
80
|
+
with self.assertRaises(KeyError) as context:
|
81
|
+
bst.environ.get('missing_key')
|
82
|
+
|
83
|
+
error_msg = str(context.exception)
|
84
|
+
self.assertIn('missing_key', error_msg)
|
85
|
+
self.assertIn('not found', error_msg)
|
86
|
+
|
87
|
+
def test_get_with_description(self):
|
88
|
+
"""Test get with description for error messages."""
|
89
|
+
with self.assertRaises(KeyError) as context:
|
90
|
+
bst.environ.get('missing', desc='Important parameter for computation')
|
91
|
+
|
92
|
+
error_msg = str(context.exception)
|
93
|
+
self.assertIn('Important parameter', error_msg)
|
94
|
+
|
95
|
+
def test_all_function(self):
|
96
|
+
"""Test getting all environment settings."""
|
97
|
+
# Set various parameters
|
98
|
+
bst.environ.set(
|
99
|
+
param1='value1',
|
100
|
+
param2=42,
|
101
|
+
param3=3.14,
|
102
|
+
precision=32
|
103
|
+
)
|
104
|
+
|
105
|
+
all_settings = bst.environ.all()
|
106
|
+
self.assertIsInstance(all_settings, dict)
|
107
|
+
self.assertEqual(all_settings['param1'], 'value1')
|
108
|
+
self.assertEqual(all_settings['param2'], 42)
|
109
|
+
self.assertEqual(all_settings['param3'], 3.14)
|
110
|
+
self.assertEqual(all_settings['precision'], 32)
|
111
|
+
|
112
|
+
def test_reset_function(self):
|
113
|
+
"""Test environment reset functionality."""
|
114
|
+
# Set custom values
|
115
|
+
bst.environ.set(
|
116
|
+
custom1='value1',
|
117
|
+
custom2='value2',
|
118
|
+
precision=64
|
119
|
+
)
|
120
|
+
|
121
|
+
# Verify they're set
|
122
|
+
self.assertEqual(bst.environ.get('custom1'), 'value1')
|
123
|
+
|
124
|
+
# Reset environment
|
125
|
+
with warnings.catch_warnings():
|
126
|
+
warnings.simplefilter("ignore")
|
127
|
+
bst.environ.reset()
|
128
|
+
|
129
|
+
# Custom values should be gone
|
130
|
+
result = bst.environ.get('custom1', default=None)
|
131
|
+
self.assertIsNone(result)
|
132
|
+
|
133
|
+
# Default precision should be restored
|
134
|
+
self.assertEqual(bst.environ.get('precision'), bst.environ.DEFAULT_PRECISION)
|
135
|
+
|
136
|
+
def test_special_environment_keys(self):
|
137
|
+
"""Test special environment key constants."""
|
138
|
+
# Test setting using constants
|
139
|
+
bst.environ.set(**{
|
140
|
+
bst.environ.DT: 0.01,
|
141
|
+
bst.environ.I: 0,
|
142
|
+
bst.environ.T: 0.0,
|
143
|
+
bst.environ.JIT_ERROR_CHECK: True,
|
144
|
+
bst.environ.FIT: False
|
145
|
+
})
|
146
|
+
|
147
|
+
self.assertEqual(bst.environ.get(bst.environ.DT), 0.01)
|
148
|
+
self.assertEqual(bst.environ.get(bst.environ.I), 0)
|
149
|
+
self.assertEqual(bst.environ.get(bst.environ.T), 0.0)
|
150
|
+
self.assertTrue(bst.environ.get(bst.environ.JIT_ERROR_CHECK))
|
151
|
+
self.assertFalse(bst.environ.get(bst.environ.FIT))
|
152
|
+
|
153
|
+
def test_pop_basic(self):
|
154
|
+
"""Test basic pop operation."""
|
155
|
+
# Set a value
|
156
|
+
bst.environ.set(pop_test='test_value')
|
157
|
+
self.assertEqual(bst.environ.get('pop_test'), 'test_value')
|
158
|
+
|
159
|
+
# Pop the value
|
160
|
+
popped = bst.environ.pop('pop_test')
|
161
|
+
self.assertEqual(popped, 'test_value')
|
162
|
+
|
163
|
+
# Value should be gone
|
164
|
+
result = bst.environ.get('pop_test', default=None)
|
165
|
+
self.assertIsNone(result)
|
166
|
+
|
167
|
+
def test_pop_with_default(self):
|
168
|
+
"""Test pop with default value."""
|
169
|
+
# Pop non-existent key with default
|
170
|
+
result = bst.environ.pop('nonexistent_pop', default='default_value')
|
171
|
+
self.assertEqual(result, 'default_value')
|
172
|
+
|
173
|
+
# Pop existing key ignores default
|
174
|
+
bst.environ.set(existing_pop='value')
|
175
|
+
result = bst.environ.pop('existing_pop', default='default')
|
176
|
+
self.assertEqual(result, 'value')
|
177
|
+
|
178
|
+
def test_pop_missing_key_error(self):
|
179
|
+
"""Test KeyError for missing keys without default."""
|
180
|
+
with self.assertRaises(KeyError) as context:
|
181
|
+
bst.environ.pop('missing_pop_key')
|
182
|
+
|
183
|
+
error_msg = str(context.exception)
|
184
|
+
self.assertIn('missing_pop_key', error_msg)
|
185
|
+
self.assertIn('not found', error_msg)
|
186
|
+
|
187
|
+
def test_pop_multiple_values(self):
|
188
|
+
"""Test popping multiple values."""
|
189
|
+
# Set multiple values
|
190
|
+
bst.environ.set(
|
191
|
+
pop1='value1',
|
192
|
+
pop2='value2',
|
193
|
+
pop3='value3'
|
194
|
+
)
|
195
|
+
|
196
|
+
# Pop them one by one
|
197
|
+
v1 = bst.environ.pop('pop1')
|
198
|
+
v2 = bst.environ.pop('pop2')
|
199
|
+
|
200
|
+
self.assertEqual(v1, 'value1')
|
201
|
+
self.assertEqual(v2, 'value2')
|
202
|
+
|
203
|
+
# pop3 should still exist
|
204
|
+
self.assertEqual(bst.environ.get('pop3'), 'value3')
|
205
|
+
|
206
|
+
# pop1 and pop2 should be gone
|
207
|
+
self.assertIsNone(bst.environ.get('pop1', default=None))
|
208
|
+
self.assertIsNone(bst.environ.get('pop2', default=None))
|
209
|
+
|
210
|
+
def test_pop_with_context_protection(self):
|
211
|
+
"""Test that pop is prevented when key is in active context."""
|
212
|
+
# Set a global value
|
213
|
+
bst.environ.set(protected_key='global_value')
|
214
|
+
|
215
|
+
# Cannot pop while in context
|
216
|
+
with bst.environ.context(protected_key='context_value'):
|
217
|
+
with self.assertRaises(ValueError) as context:
|
218
|
+
bst.environ.pop('protected_key')
|
219
|
+
|
220
|
+
error_msg = str(context.exception)
|
221
|
+
self.assertIn('Cannot pop', error_msg)
|
222
|
+
self.assertIn('active in a context', error_msg)
|
223
|
+
|
224
|
+
# Can pop after context exits
|
225
|
+
popped = bst.environ.pop('protected_key')
|
226
|
+
self.assertEqual(popped, 'global_value')
|
227
|
+
|
228
|
+
def test_pop_nested_context_protection(self):
|
229
|
+
"""Test pop protection with nested contexts."""
|
230
|
+
bst.environ.set(nested_key='global')
|
231
|
+
|
232
|
+
with bst.environ.context(nested_key='level1'):
|
233
|
+
with bst.environ.context(nested_key='level2'):
|
234
|
+
# Should indicate 2 active contexts
|
235
|
+
with self.assertRaises(ValueError) as context:
|
236
|
+
bst.environ.pop('nested_key')
|
237
|
+
|
238
|
+
error_msg = str(context.exception)
|
239
|
+
self.assertIn('2 context(s)', error_msg)
|
240
|
+
|
241
|
+
def test_pop_does_not_affect_context_values(self):
|
242
|
+
"""Test that pop doesn't affect context values."""
|
243
|
+
# Set both global and context value
|
244
|
+
bst.environ.set(dual_key='global')
|
245
|
+
|
246
|
+
with bst.environ.context(other_key='context_only'):
|
247
|
+
# Can pop a key that's only in global (not in this context)
|
248
|
+
popped = bst.environ.pop('dual_key')
|
249
|
+
self.assertEqual(popped, 'global')
|
250
|
+
|
251
|
+
# Context-only values remain accessible
|
252
|
+
self.assertEqual(bst.environ.get('other_key'), 'context_only')
|
253
|
+
|
254
|
+
# Context value should be gone after exit
|
255
|
+
self.assertIsNone(bst.environ.get('other_key', default=None))
|
256
|
+
|
257
|
+
def test_pop_precision_key(self):
|
258
|
+
"""Test popping the precision key."""
|
259
|
+
# Set custom precision
|
260
|
+
bst.environ.set(precision=64)
|
261
|
+
self.assertEqual(bst.environ.get_precision(), 64)
|
262
|
+
|
263
|
+
# Pop precision
|
264
|
+
popped = bst.environ.pop('precision')
|
265
|
+
self.assertEqual(popped, 64)
|
266
|
+
|
267
|
+
|
268
|
+
class TestEnvironmentContext(unittest.TestCase):
|
269
|
+
"""Test context manager functionality."""
|
270
|
+
|
271
|
+
def setUp(self):
|
272
|
+
"""Reset environment before each test."""
|
273
|
+
bst.environ.reset()
|
274
|
+
warnings.filterwarnings('ignore', category=UserWarning)
|
275
|
+
|
276
|
+
def tearDown(self):
|
277
|
+
"""Clean up after each test."""
|
278
|
+
bst.environ.reset()
|
279
|
+
warnings.resetwarnings()
|
280
|
+
|
281
|
+
def test_basic_context(self):
|
282
|
+
"""Test basic context manager usage."""
|
283
|
+
bst.environ.set(value=10)
|
284
|
+
|
285
|
+
with bst.environ.context(value=20) as ctx:
|
286
|
+
# Value should be 20 in context
|
287
|
+
self.assertEqual(bst.environ.get('value'), 20)
|
288
|
+
# Context should contain current settings
|
289
|
+
self.assertEqual(ctx['value'], 20)
|
290
|
+
|
291
|
+
# Value should be restored to 10
|
292
|
+
self.assertEqual(bst.environ.get('value'), 10)
|
293
|
+
|
294
|
+
def test_nested_contexts(self):
|
295
|
+
"""Test nested context managers."""
|
296
|
+
bst.environ.set(level=0)
|
297
|
+
|
298
|
+
with bst.environ.context(level=1):
|
299
|
+
self.assertEqual(bst.environ.get('level'), 1)
|
300
|
+
|
301
|
+
with bst.environ.context(level=2):
|
302
|
+
self.assertEqual(bst.environ.get('level'), 2)
|
303
|
+
|
304
|
+
with bst.environ.context(level=3):
|
305
|
+
self.assertEqual(bst.environ.get('level'), 3)
|
306
|
+
|
307
|
+
# Back to level 2
|
308
|
+
self.assertEqual(bst.environ.get('level'), 2)
|
309
|
+
|
310
|
+
# Back to level 1
|
311
|
+
self.assertEqual(bst.environ.get('level'), 1)
|
312
|
+
|
313
|
+
# Back to level 0
|
314
|
+
self.assertEqual(bst.environ.get('level'), 0)
|
315
|
+
|
316
|
+
def test_context_with_exception(self):
|
317
|
+
"""Test context manager handles exceptions properly."""
|
318
|
+
bst.environ.set(value='original')
|
319
|
+
|
320
|
+
try:
|
321
|
+
with bst.environ.context(value='temporary'):
|
322
|
+
self.assertEqual(bst.environ.get('value'), 'temporary')
|
323
|
+
raise ValueError("Test exception")
|
324
|
+
except ValueError:
|
325
|
+
pass
|
326
|
+
|
327
|
+
# Value should be restored despite exception
|
328
|
+
self.assertEqual(bst.environ.get('value'), 'original')
|
329
|
+
|
330
|
+
def test_context_multiple_parameters(self):
|
331
|
+
"""Test context with multiple parameters."""
|
332
|
+
bst.environ.set(a=1, b=2, c=3)
|
333
|
+
|
334
|
+
with bst.environ.context(a=10, b=20, c=30, d=40):
|
335
|
+
self.assertEqual(bst.environ.get('a'), 10)
|
336
|
+
self.assertEqual(bst.environ.get('b'), 20)
|
337
|
+
self.assertEqual(bst.environ.get('c'), 30)
|
338
|
+
self.assertEqual(bst.environ.get('d'), 40)
|
339
|
+
|
340
|
+
# Original values restored
|
341
|
+
self.assertEqual(bst.environ.get('a'), 1)
|
342
|
+
self.assertEqual(bst.environ.get('b'), 2)
|
343
|
+
self.assertEqual(bst.environ.get('c'), 3)
|
344
|
+
# d should not exist
|
345
|
+
result = bst.environ.get('d', default=None)
|
346
|
+
self.assertIsNone(result)
|
347
|
+
|
348
|
+
def test_context_platform_restriction(self):
|
349
|
+
"""Test that platform cannot be set in context."""
|
350
|
+
with self.assertRaises(ValueError) as context:
|
351
|
+
with bst.environ.context(platform='cpu'):
|
352
|
+
pass
|
353
|
+
|
354
|
+
self.assertIn('platform', str(context.exception).lower())
|
355
|
+
self.assertIn('cannot set', str(context.exception).lower())
|
356
|
+
|
357
|
+
def test_context_host_device_count_restriction(self):
|
358
|
+
"""Test that host_device_count cannot be set in context."""
|
359
|
+
with self.assertRaises(ValueError) as context:
|
360
|
+
with bst.environ.context(host_device_count=4):
|
361
|
+
pass
|
362
|
+
|
363
|
+
self.assertIn('host_device_count', str(context.exception))
|
364
|
+
|
365
|
+
def test_context_mode_validation(self):
|
366
|
+
"""Test mode validation in context."""
|
367
|
+
# Valid mode
|
368
|
+
mode = bst.mixin.Training()
|
369
|
+
with bst.environ.context(mode=mode):
|
370
|
+
self.assertEqual(bst.environ.get('mode'), mode)
|
371
|
+
|
372
|
+
def test_context_preserves_unmodified_values(self):
|
373
|
+
"""Test that context doesn't affect unmodified values."""
|
374
|
+
bst.environ.set(unchanged='original', changed='original')
|
375
|
+
|
376
|
+
with bst.environ.context(changed='modified'):
|
377
|
+
self.assertEqual(bst.environ.get('unchanged'), 'original')
|
378
|
+
self.assertEqual(bst.environ.get('changed'), 'modified')
|
379
|
+
|
380
|
+
|
381
|
+
class TestPrecisionAndDataTypes(unittest.TestCase):
|
382
|
+
"""Test precision control and data type functions."""
|
383
|
+
|
384
|
+
def setUp(self):
|
385
|
+
"""Reset environment before each test."""
|
386
|
+
bst.environ.reset()
|
387
|
+
warnings.filterwarnings('ignore', category=UserWarning)
|
388
|
+
|
389
|
+
def tearDown(self):
|
390
|
+
"""Clean up after each test."""
|
391
|
+
bst.environ.reset()
|
392
|
+
warnings.resetwarnings()
|
393
|
+
|
394
|
+
def test_precision_settings(self):
|
395
|
+
"""Test different precision settings."""
|
396
|
+
precisions = [8, 16, 32, 64, 'bf16']
|
397
|
+
|
398
|
+
for precision in precisions:
|
399
|
+
bst.environ.set(precision=precision)
|
400
|
+
|
401
|
+
if precision == 'bf16':
|
402
|
+
self.assertEqual(bst.environ.get_precision(), 16)
|
403
|
+
elif isinstance(precision, str):
|
404
|
+
self.assertEqual(bst.environ.get_precision(), int(precision))
|
405
|
+
else:
|
406
|
+
self.assertEqual(bst.environ.get_precision(), precision)
|
407
|
+
|
408
|
+
def test_precision_context(self):
|
409
|
+
"""Test precision changes in context."""
|
410
|
+
bst.environ.set(precision=32)
|
411
|
+
|
26
412
|
with bst.environ.context(precision=64):
|
27
413
|
a = bst.random.randn(1)
|
28
414
|
self.assertEqual(a.dtype, jnp.float64)
|
415
|
+
self.assertEqual(bst.environ.get_precision(), 64)
|
29
416
|
|
30
|
-
|
31
|
-
|
32
|
-
|
417
|
+
# Precision restored
|
418
|
+
b = bst.random.randn(1)
|
419
|
+
self.assertEqual(b.dtype, jnp.float32)
|
420
|
+
self.assertEqual(bst.environ.get_precision(), 32)
|
33
421
|
|
34
|
-
|
35
|
-
|
36
|
-
|
422
|
+
def test_dftype_function(self):
|
423
|
+
"""Test default float type function."""
|
424
|
+
# 32-bit precision
|
425
|
+
bst.environ.set(precision=32)
|
426
|
+
self.assertEqual(bst.environ.dftype(), np.float32)
|
37
427
|
|
38
|
-
|
39
|
-
|
40
|
-
|
428
|
+
# 64-bit precision
|
429
|
+
bst.environ.set(precision=64)
|
430
|
+
self.assertEqual(bst.environ.dftype(), np.float64)
|
431
|
+
|
432
|
+
# 16-bit precision
|
433
|
+
bst.environ.set(precision=16)
|
434
|
+
self.assertEqual(bst.environ.dftype(), np.float16)
|
435
|
+
|
436
|
+
# bfloat16 precision
|
437
|
+
bst.environ.set(precision='bf16')
|
438
|
+
self.assertEqual(bst.environ.dftype(), jnp.bfloat16)
|
439
|
+
|
440
|
+
def test_ditype_function(self):
|
441
|
+
"""Test default integer type function."""
|
442
|
+
# 32-bit precision
|
443
|
+
bst.environ.set(precision=32)
|
444
|
+
self.assertEqual(bst.environ.ditype(), np.int32)
|
445
|
+
|
446
|
+
# 64-bit precision
|
447
|
+
bst.environ.set(precision=64)
|
448
|
+
self.assertEqual(bst.environ.ditype(), np.int64)
|
449
|
+
|
450
|
+
# 16-bit precision
|
451
|
+
bst.environ.set(precision=16)
|
452
|
+
self.assertEqual(bst.environ.ditype(), np.int16)
|
453
|
+
|
454
|
+
# 8-bit precision
|
455
|
+
bst.environ.set(precision=8)
|
456
|
+
self.assertEqual(bst.environ.ditype(), np.int8)
|
457
|
+
|
458
|
+
def test_dutype_function(self):
|
459
|
+
"""Test default unsigned integer type function."""
|
460
|
+
# 32-bit precision
|
461
|
+
bst.environ.set(precision=32)
|
462
|
+
self.assertEqual(bst.environ.dutype(), np.uint32)
|
463
|
+
|
464
|
+
# 64-bit precision
|
465
|
+
bst.environ.set(precision=64)
|
466
|
+
self.assertEqual(bst.environ.dutype(), np.uint64)
|
467
|
+
|
468
|
+
# 16-bit precision
|
469
|
+
bst.environ.set(precision=16)
|
470
|
+
self.assertEqual(bst.environ.dutype(), np.uint16)
|
471
|
+
|
472
|
+
# 8-bit precision
|
473
|
+
bst.environ.set(precision=8)
|
474
|
+
self.assertEqual(bst.environ.dutype(), np.uint8)
|
475
|
+
|
476
|
+
def test_dctype_function(self):
|
477
|
+
"""Test default complex type function."""
|
478
|
+
# 32-bit precision
|
479
|
+
bst.environ.set(precision=32)
|
480
|
+
self.assertEqual(bst.environ.dctype(), np.complex64)
|
41
481
|
|
42
|
-
|
482
|
+
# 64-bit precision
|
483
|
+
bst.environ.set(precision=64)
|
484
|
+
self.assertEqual(bst.environ.dctype(), np.complex128)
|
485
|
+
|
486
|
+
# 16-bit precision (should use complex64)
|
487
|
+
bst.environ.set(precision=16)
|
488
|
+
self.assertEqual(bst.environ.dctype(), np.complex64)
|
489
|
+
|
490
|
+
def test_tolerance_function(self):
|
491
|
+
"""Test tolerance values for different precisions."""
|
492
|
+
# 64-bit precision
|
493
|
+
bst.environ.set(precision=64)
|
494
|
+
tol = bst.environ.tolerance()
|
495
|
+
self.assertAlmostEqual(float(tol), 1e-12, places=14)
|
496
|
+
|
497
|
+
# 32-bit precision
|
498
|
+
bst.environ.set(precision=32)
|
499
|
+
tol = bst.environ.tolerance()
|
500
|
+
self.assertAlmostEqual(float(tol), 1e-5, places=7)
|
501
|
+
|
502
|
+
# 16-bit precision
|
503
|
+
bst.environ.set(precision=16)
|
504
|
+
tol = bst.environ.tolerance()
|
505
|
+
self.assertAlmostEqual(float(tol), 1e-2, places=4)
|
506
|
+
|
507
|
+
def test_invalid_precision(self):
|
508
|
+
"""Test invalid precision values."""
|
509
|
+
invalid_precisions = [128, 'invalid', -1, 3.14]
|
510
|
+
|
511
|
+
for invalid in invalid_precisions:
|
512
|
+
with self.assertRaises(ValueError):
|
513
|
+
bst.environ.set(precision=invalid)
|
514
|
+
|
515
|
+
def test_precision_with_arrays(self):
|
516
|
+
"""Test that precision affects array creation."""
|
517
|
+
# Test with different precisions
|
518
|
+
test_cases = [
|
519
|
+
(32, jnp.float32),
|
520
|
+
(64, jnp.float64),
|
521
|
+
(16, jnp.float16),
|
522
|
+
('bf16', jnp.bfloat16),
|
523
|
+
]
|
524
|
+
|
525
|
+
for precision, expected_dtype in test_cases:
|
526
|
+
with bst.environ.context(precision=precision):
|
527
|
+
# Create array using random
|
528
|
+
arr = bst.random.randn(5)
|
529
|
+
self.assertEqual(arr.dtype, expected_dtype)
|
530
|
+
|
531
|
+
|
532
|
+
class TestModeAndSpecialGetters(unittest.TestCase):
|
533
|
+
"""Test mode management and special getter functions."""
|
534
|
+
|
535
|
+
def setUp(self):
|
536
|
+
"""Reset environment before each test."""
|
537
|
+
bst.environ.reset()
|
538
|
+
warnings.filterwarnings('ignore', category=UserWarning)
|
539
|
+
|
540
|
+
def tearDown(self):
|
541
|
+
"""Clean up after each test."""
|
542
|
+
bst.environ.reset()
|
543
|
+
warnings.resetwarnings()
|
544
|
+
|
545
|
+
def test_get_dt(self):
|
546
|
+
"""Test get_dt function."""
|
547
|
+
# Set dt
|
548
|
+
bst.environ.set(dt=0.01)
|
549
|
+
self.assertEqual(bst.environ.get_dt(), 0.01)
|
550
|
+
|
551
|
+
# Test in context
|
552
|
+
with bst.environ.context(dt=0.001):
|
553
|
+
self.assertEqual(bst.environ.get_dt(), 0.001)
|
554
|
+
|
555
|
+
self.assertEqual(bst.environ.get_dt(), 0.01)
|
556
|
+
|
557
|
+
# Test missing dt
|
558
|
+
bst.environ.reset()
|
559
|
+
with self.assertRaises(KeyError):
|
560
|
+
bst.environ.get_dt()
|
561
|
+
|
562
|
+
def test_get_mode(self):
|
563
|
+
"""Test get_mode function."""
|
564
|
+
# Set training mode
|
565
|
+
training = bst.mixin.Training()
|
566
|
+
bst.environ.set(mode=training)
|
567
|
+
mode = bst.environ.get('mode')
|
568
|
+
self.assertEqual(mode, training)
|
569
|
+
self.assertTrue(mode.has(bst.mixin.Training))
|
570
|
+
|
571
|
+
# Test with batching mode
|
572
|
+
batching = bst.mixin.Batching(batch_size=32)
|
573
|
+
with bst.environ.context(mode=batching):
|
574
|
+
mode = bst.environ.get('mode')
|
575
|
+
self.assertEqual(mode, batching)
|
576
|
+
self.assertTrue(mode.has(bst.mixin.Batching))
|
577
|
+
self.assertEqual(mode.batch_size, 32)
|
578
|
+
|
579
|
+
# Test missing mode
|
580
|
+
bst.environ.reset()
|
581
|
+
with self.assertRaises(KeyError):
|
582
|
+
bst.environ.get('mode')
|
583
|
+
|
584
|
+
def test_get_platform(self):
|
585
|
+
"""Test get_platform function."""
|
586
|
+
platform = bst.environ.get_platform()
|
587
|
+
self.assertIn(platform, bst.environ.SUPPORTED_PLATFORMS)
|
588
|
+
|
589
|
+
def test_get_host_device_count(self):
|
590
|
+
"""Test get_host_device_count function."""
|
591
|
+
count = bst.environ.get_host_device_count()
|
592
|
+
self.assertIsInstance(count, int)
|
593
|
+
self.assertGreaterEqual(count, 1)
|
594
|
+
|
595
|
+
def test_dt_validation(self):
|
596
|
+
"""Test dt validation in set function."""
|
597
|
+
# Valid dt values
|
598
|
+
valid_dts = [0.01, 0.001, 1.0, 0.1]
|
599
|
+
for dt in valid_dts:
|
600
|
+
bst.environ.set(dt=dt)
|
601
|
+
self.assertEqual(bst.environ.get_dt(), dt)
|
602
|
+
|
603
|
+
|
604
|
+
class TestPlatformAndDevice(unittest.TestCase):
    """Test platform and device management."""

    def setUp(self):
        """Reset the environment and silence ``UserWarning`` for one test."""
        bst.environ.reset()
        # Snapshot the warning filters and restore them after the test.
        # ``warnings.resetwarnings()`` would instead drop every filter in the
        # process, including those installed by the test runner or ``-W``.
        saved_filters = warnings.catch_warnings()
        saved_filters.__enter__()
        self.addCleanup(saved_filters.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Reset the environment after each test."""
        bst.environ.reset()

    @patch('brainstate.environ.config')
    def test_set_platform(self, mock_config):
        """Each supported platform is forwarded to ``config.update``."""
        for platform in ('cpu', 'gpu', 'tpu'):
            bst.environ.set_platform(platform)
            mock_config.update.assert_called_with("jax_platform_name", platform)

        # An unknown platform name is rejected.
        with self.assertRaises(ValueError):
            bst.environ.set_platform('invalid')

    def test_set_platform_through_set(self):
        """``set(platform=...)`` reaches ``config.update`` as well."""
        with patch('brainstate.environ.config') as mock_config:
            bst.environ.set(platform='gpu')
            mock_config.update.assert_called_with("jax_platform_name", 'gpu')

    def test_set_host_device_count(self):
        """``set_host_device_count`` writes, replaces and validates the XLA flag."""
        import os

        flag = "--xla_force_host_platform_device_count="

        # ``patch.dict`` restores ``os.environ`` on exit, so the XLA_FLAGS
        # edits made below do not leak into tests that run afterwards.
        with patch.dict(os.environ):
            # Set device count
            bst.environ.set_host_device_count(4)
            xla_flags = os.environ.get("XLA_FLAGS", "")
            self.assertIn(flag + "4", xla_flags)

            # Update device count: the new value replaces the old one.
            bst.environ.set_host_device_count(8)
            xla_flags = os.environ.get("XLA_FLAGS", "")
            self.assertIn(flag + "8", xla_flags)
            self.assertNotIn(flag + "4", xla_flags)

            # Invalid device counts
            for bad_count in (0, -1):
                with self.assertRaises(ValueError):
                    bst.environ.set_host_device_count(bad_count)

    def test_platform_context_restriction(self):
        """Passing ``platform`` to ``context`` raises ``ValueError``."""
        with self.assertRaises(ValueError):
            with bst.environ.context(platform='cpu'):
                pass
|
663
|
+
|
664
|
+
|
665
|
+
class TestCallbackBehavior(unittest.TestCase):
    """Test callback registration and behavior."""

    def setUp(self):
        """Reset the environment, silence ``UserWarning`` and clear the log."""
        bst.environ.reset()
        # Snapshot the warning filters and restore them after the test.
        # ``warnings.resetwarnings()`` would instead drop every filter in the
        # process, including those installed by the test runner or ``-W``.
        saved_filters = warnings.catch_warnings()
        saved_filters.__enter__()
        self.addCleanup(saved_filters.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)
        # Values received by callbacks defined inside the tests.
        self.callback_values = []

    def tearDown(self):
        """Reset the environment after each test."""
        bst.environ.reset()

    # NOTE(review): this test is disabled in the original file; the reason is
    # not visible here. Kept verbatim so it can be revived or removed on purpose.
    # def test_register_callback(self):
    #     """Test basic callback registration."""
    #     def callback(value):
    #         self.callback_values.append(value)
    #
    #     brainstate.environ.register_default_behavior('test_param', callback)
    #
    #     # Callback should be triggered on set
    #     brainstate.environ.set(test_param='value1')
    #     self.assertEqual(self.callback_values, ['value1'])
    #
    #     # Callback should be triggered on context enter/exit
    #     with brainstate.environ.context(test_param='value2'):
    #         self.assertEqual(self.callback_values, ['value1', 'value2'])
    #
    #     # Should restore previous value
    #     self.assertEqual(self.callback_values, ['value1', 'value2', 'value1'])

    def test_register_multiple_callbacks(self):
        """Callbacks registered for different keys each get their own value."""
        values_a = []
        values_b = []

        def callback_a(value):
            values_a.append(value)

        def callback_b(value):
            values_b.append(value)

        bst.environ.register_default_behavior('param_a', callback_a)
        bst.environ.register_default_behavior('param_b', callback_b)

        bst.environ.set(param_a='a1', param_b='b1')
        self.assertEqual(values_a, ['a1'])
        self.assertEqual(values_b, ['b1'])

    def test_replace_callback(self):
        """A callback is only replaced when ``replace_if_exist=True`` is given."""

        def callback1(value):
            self.callback_values.append(f'cb1:{value}')

        def callback2(value):
            self.callback_values.append(f'cb2:{value}')

        # Register first callback
        bst.environ.register_default_behavior('param', callback1)

        # Registering a second one without the replace flag is an error.
        with self.assertRaises(ValueError):
            bst.environ.register_default_behavior('param', callback2)

        # Register with replace flag
        bst.environ.register_default_behavior('param', callback2, replace_if_exist=True)

        # Only second callback should be called
        bst.environ.set(param='test')
        self.assertEqual(self.callback_values, ['cb2:test'])

    def test_unregister_callback(self):
        """An unregistered callback is no longer invoked by ``set``."""

        def callback(value):
            self.callback_values.append(value)

        # Register and test
        bst.environ.register_default_behavior('param', callback)
        bst.environ.set(param='value1')
        self.assertEqual(len(self.callback_values), 1)

        # Unregister
        removed = bst.environ.unregister_default_behavior('param')
        self.assertTrue(removed)

        # Callback should not be triggered
        bst.environ.set(param='value2')
        self.assertEqual(len(self.callback_values), 1)  # Still just one

        # Unregistering an unknown key reports False.
        removed = bst.environ.unregister_default_behavior('nonexistent')
        self.assertFalse(removed)

    def test_list_registered_behaviors(self):
        """Newly registered keys appear in ``list_registered_behaviors``."""
        new_keys = ['param1', 'param2', 'param3']
        for key in new_keys:
            bst.environ.register_default_behavior(key, lambda x: None)

        behaviors = bst.environ.list_registered_behaviors()
        for key in new_keys:
            self.assertIn(key, behaviors)

    def test_callback_exception_handling(self):
        """An exception raised by a callback surfaces as a warning."""

        def failing_callback(value):
            raise RuntimeError(f"Intentional error: {value}")

        bst.environ.register_default_behavior('param', failing_callback)

        # Should not crash, but should warn
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            bst.environ.set(param='test')

            # Should have a warning
            self.assertGreater(len(w), 0)
            self.assertIn('Callback', str(w[0].message))
            self.assertIn('exception', str(w[0].message))

    def test_callback_validation(self):
        """``register_default_behavior`` rejects bad callbacks and bad keys."""
        # Non-callable
        with self.assertRaises(TypeError):
            bst.environ.register_default_behavior('param', 'not_callable')

        # Non-string key
        with self.assertRaises(TypeError):
            bst.environ.register_default_behavior(123, lambda x: None)

    def test_callback_with_validation(self):
        """A validating callback records good values and rejects bad ones."""

        def validate_positive(value):
            if value <= 0:
                raise ValueError(f"Value must be positive, got {value}")
            self.callback_values.append(value)

        bst.environ.register_default_behavior('positive_param', validate_positive)

        # Valid value
        bst.environ.set(positive_param=10)
        self.assertEqual(self.callback_values, [10])

        # Invalid value: the callback raises, which is reported as a warning
        # (see test_callback_exception_handling) rather than propagated.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            bst.environ.set(positive_param=-5)

            self.assertGreater(len(caught), 0)

        # The validator raised before recording, so -5 was never appended.
        self.assertEqual(self.callback_values, [10])
|
821
|
+
|
822
|
+
|
823
|
+
class TestThreadSafety(unittest.TestCase):
    """Test thread safety of environment operations."""

    def setUp(self):
        """Reset the environment and silence ``UserWarning`` for one test."""
        bst.environ.reset()
        # Snapshot the warning filters and restore them after the test.
        # ``warnings.resetwarnings()`` would instead drop every filter in the
        # process, including those installed by the test runner or ``-W``.
        saved_filters = warnings.catch_warnings()
        saved_filters.__enter__()
        self.addCleanup(saved_filters.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Reset the environment after each test."""
        bst.environ.reset()

    def test_concurrent_set_operations(self):
        """Each thread reads back the value it just wrote to its own key."""
        results = []
        errors = []

        def thread_operation(thread_id):
            key = f'thread_{thread_id}'
            try:
                # Each thread sets its own value
                for i in range(10):
                    bst.environ.set(**{key: i})
                    results.append((thread_id, bst.environ.get(key)))
            except Exception as e:
                errors.append(e)

        threads = [
            threading.Thread(target=thread_operation, args=(i,))
            for i in range(5)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Comparing against [] makes a failure show the captured exceptions.
        self.assertEqual(errors, [])

        # Every write must have been read back unchanged by its own thread.
        # Visibility of these keys from the main thread is deliberately not
        # asserted: the original test tolerated a KeyError there.
        expected = [(t, i) for t in range(5) for i in range(10)]
        self.assertEqual(sorted(results), expected)

    def test_concurrent_context_operations(self):
        """Contexts entered in parallel threads override and then restore."""
        results = []
        errors = []

        def thread_context_operation(thread_id):
            key = f'base_{thread_id}'
            try:
                bst.environ.set(**{key: 0})

                for i in range(5):
                    with bst.environ.context(**{key: i}):
                        results.append((thread_id, bst.environ.get(key)))

                    # Should be back to 0. A failed assertion here is an
                    # Exception, so it is collected in ``errors`` below.
                    final = bst.environ.get(key)
                    self.assertEqual(final, 0)
            except Exception as e:
                errors.append(e)

        threads = [
            threading.Thread(target=thread_context_operation, args=(i,))
            for i in range(3)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Comparing against [] makes a failure show the captured exceptions.
        self.assertEqual(errors, [])

        # Inside each context the thread must have seen the overriding value.
        expected = [(t, i) for t in range(3) for i in range(5)]
        self.assertEqual(sorted(results), expected)

    def test_concurrent_pop_operations(self):
        """Popping disjoint key ranges from several threads raises no error."""
        # Set up multiple keys
        for i in range(20):
            bst.environ.set(**{f'pop_thread_{i}': f'value_{i}'})

        results = []
        errors = []

        def thread_pop_operation(start, end):
            try:
                for i in range(start, end):
                    try:
                        value = bst.environ.pop(f'pop_thread_{i}')
                        results.append((i, value))
                    except KeyError:
                        # Key might already be popped by another thread
                        pass
            except Exception as e:
                errors.append(e)

        # Create threads that pop different ranges
        ranges = [(0, 5), (5, 10), (10, 15), (15, 20)]
        threads = [
            threading.Thread(target=thread_pop_operation, args=(start, end))
            for start, end in ranges
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Comparing against [] makes a failure show the captured exceptions.
        self.assertEqual(errors, [])

        # NOTE(review): the stricter checks below are disabled in the original
        # file; the reason is not visible here. Kept verbatim.
        # # All keys should be popped (each exactly once)
        # popped_indices = [r[0] for r in results]
        # self.assertEqual(len(popped_indices), 20)
        # self.assertEqual(len(set(popped_indices)), 20)  # All unique
        #
        # # All values should be gone
        # for i in range(20):
        #     result = brainstate.environ.get(f'pop_thread_{i}', default=None)
        #     self.assertIsNone(result)
|
946
|
+
|
947
|
+
|
948
|
+
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and boundary conditions."""

    def setUp(self):
        """Reset the environment and silence ``UserWarning`` for one test."""
        bst.environ.reset()
        # Snapshot the warning filters and restore them after the test.
        # ``warnings.resetwarnings()`` would instead drop every filter in the
        # process, including those installed by the test runner or ``-W``.
        saved_filters = warnings.catch_warnings()
        saved_filters.__enter__()
        self.addCleanup(saved_filters.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Reset the environment after each test."""
        bst.environ.reset()

    def test_empty_context(self):
        """A context with no parameters leaves ``all()`` unchanged."""
        original = bst.environ.all()

        with bst.environ.context() as ctx:
            # Should be unchanged
            self.assertEqual(ctx, original)

        self.assertEqual(bst.environ.all(), original)

    def test_none_values(self):
        """``None`` can be stored, overridden in a context, and restored."""
        bst.environ.set(none_param=None)
        self.assertIsNone(bst.environ.get('none_param'))

        with bst.environ.context(none_param='not_none'):
            self.assertEqual(bst.environ.get('none_param'), 'not_none')

        self.assertIsNone(bst.environ.get('none_param'))

    def test_complex_data_types(self):
        """Containers and arbitrary objects round-trip through ``set``/``get``."""
        # Lists
        bst.environ.set(list_param=[1, 2, 3])
        self.assertEqual(bst.environ.get('list_param'), [1, 2, 3])

        # Dictionaries
        bst.environ.set(dict_param={'a': 1, 'b': 2})
        self.assertEqual(bst.environ.get('dict_param'), {'a': 1, 'b': 2})

        # Tuples
        bst.environ.set(tuple_param=(1, 2, 3))
        self.assertEqual(bst.environ.get('tuple_param'), (1, 2, 3))

        # Custom objects
        class CustomObject:
            def __init__(self, value):
                self.value = value

        obj = CustomObject(42)
        bst.environ.set(obj_param=obj)
        retrieved = bst.environ.get('obj_param')
        # The very same object comes back, not a copy.
        self.assertIs(retrieved, obj)
        self.assertEqual(retrieved.value, 42)

    def test_special_string_values(self):
        """Empty, whitespace and keyword-like strings are stored verbatim."""
        special_strings = ['', ' ', '\n', '\t', 'None', 'True', 'False']

        for s in special_strings:
            bst.environ.set(string_param=s)
            self.assertEqual(bst.environ.get('string_param'), s)

    def test_numeric_edge_values(self):
        """Extreme integers and floats round-trip through ``set``/``get``."""
        import sys

        edge_values = [
            0, -0, 1, -1,
            sys.maxsize, -sys.maxsize,
            float('inf'), float('-inf'),
            1e-100, 1e100,
        ]

        for value in edge_values:
            bst.environ.set(numeric_param=value)
            retrieved = bst.environ.get('numeric_param')
            if value != value:  # NaN check
                # NOTE(review): no NaN is in ``edge_values`` today, so this
                # branch is not exercised; kept for when one is added.
                self.assertTrue(retrieved != retrieved)
            else:
                self.assertEqual(retrieved, value)

    def test_context_all_interaction(self):
        """``all()`` inside a context shows overrides next to global settings."""
        bst.environ.set(global_param='global')

        with bst.environ.context(context_param='context', global_param='override'):
            all_values = bst.environ.all()

            # Should include both
            self.assertEqual(all_values['global_param'], 'override')
            self.assertEqual(all_values['context_param'], 'context')

            # Original global values should be in settings
            self.assertIn('precision', all_values)

    def test_deeply_nested_contexts(self):
        """Nested contexts each see their own value and unwind in order."""
        depth = 20
        bst.environ.set(depth=0)

        def nested_context(level):
            if level < depth:
                with bst.environ.context(depth=level):
                    self.assertEqual(bst.environ.get('depth'), level)
                    nested_context(level + 1)
                    # After the inner contexts exit, this level is active again.
                    self.assertEqual(bst.environ.get('depth'), level)

        nested_context(1)
        self.assertEqual(bst.environ.get('depth'), 0)

    def test_set_precision_function(self):
        """``set_precision`` stores valid values and rejects an invalid one."""
        # Valid precisions
        for precision in [8, 16, 32, 64, 'bf16']:
            bst.environ.set_precision(precision)
            self.assertEqual(bst.environ.get('precision'), precision)

        # Invalid precision
        with self.assertRaises(ValueError):
            bst.environ.set_precision(128)

    def test_pop_edge_cases(self):
        """``pop`` handles ``None`` values, defaults and object identity."""
        # Pop with None value
        bst.environ.set(none_key=None)
        popped = bst.environ.pop('none_key')
        self.assertIsNone(popped)

        # Pop with None as default
        result = bst.environ.pop('missing_key', default=None)
        self.assertIsNone(result)

        # Pop complex data types
        complex_obj = {'nested': {'data': [1, 2, 3]}}
        bst.environ.set(complex_key=complex_obj)
        popped = bst.environ.pop('complex_key')
        self.assertEqual(popped, complex_obj)

        # Verify object identity preservation
        obj = object()
        bst.environ.set(obj_key=obj)
        popped = bst.environ.pop('obj_key')
        self.assertIs(popped, obj)

    def test_pop_all_interaction(self):
        """Keys removed with ``pop`` disappear from ``all()``."""
        # Set multiple values
        bst.environ.set(a=1, b=2, c=3, d=4)
        initial_all = bst.environ.all()
        # Before popping, every key just set is reported.
        for key in ('a', 'b', 'c', 'd'):
            self.assertIn(key, initial_all)

        # Pop some values
        bst.environ.pop('b')
        bst.environ.pop('d')

        # Check all() reflects the changes
        after_pop = bst.environ.all()
        self.assertIn('a', after_pop)
        self.assertIn('c', after_pop)
        self.assertNotIn('b', after_pop)
        self.assertNotIn('d', after_pop)

    def test_pop_callback_not_triggered(self):
        """``pop`` does not invoke the callback registered for the key."""
        callback_calls = []

        def callback(value):
            callback_calls.append(value)

        # Register callback and make sure it is unregistered even when an
        # assertion below fails.
        bst.environ.register_default_behavior('callback_test', callback)
        self.addCleanup(bst.environ.unregister_default_behavior, 'callback_test')

        # Set triggers callback
        bst.environ.set(callback_test='value')
        self.assertEqual(len(callback_calls), 1)

        # Pop should NOT trigger callback
        popped = bst.environ.pop('callback_test')
        self.assertEqual(len(callback_calls), 1)  # Still just 1
        self.assertEqual(popped, 'value')
|
1134
|
+
|
1135
|
+
|
1136
|
+
class TestIntegration(unittest.TestCase):
    """Integration tests with actual BrainState functionality."""

    def setUp(self):
        """Reset the environment and silence ``UserWarning`` for one test."""
        bst.environ.reset()
        # Snapshot the warning filters and restore them after the test.
        # ``warnings.resetwarnings()`` would instead drop every filter in the
        # process, including those installed by the test runner or ``-W``.
        saved_filters = warnings.catch_warnings()
        saved_filters.__enter__()
        self.addCleanup(saved_filters.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Reset the environment after each test."""
        bst.environ.reset()

    def test_precision_affects_random_arrays(self):
        """The active precision decides the dtype of ``bst.random.randn``."""
        # Test different precisions
        test_cases = [
            (32, jnp.float32),
            (64, jnp.float64),
            (16, jnp.float16),
            ('bf16', jnp.bfloat16),
        ]

        for precision, expected_dtype in test_cases:
            # subTest makes a failure name the precision that broke.
            with self.subTest(precision=precision):
                with bst.environ.context(precision=precision):
                    arr = bst.random.randn(10)
                    self.assertEqual(arr.dtype, expected_dtype)

    def test_mode_usage(self):
        """Modes can be set globally and overridden inside a context."""
        # Create different modes
        training = bst.mixin.Training()
        batching = bst.mixin.Batching(batch_size=32)

        # Test training mode
        bst.environ.set(mode=training)
        mode = bst.environ.get('mode')
        self.assertTrue(mode.has(bst.mixin.Training))

        # Test batching mode
        with bst.environ.context(mode=batching):
            mode = bst.environ.get('mode')
            self.assertTrue(mode.has(bst.mixin.Batching))
            self.assertEqual(mode.batch_size, 32)

    def test_dt_in_numerical_integration(self):
        """A stored ``dt`` is returned unchanged and usable for step counts."""
        # Set different dt values
        dt_values = [0.01, 0.001, 0.1]

        for dt in dt_values:
            bst.environ.set(dt=dt)
            retrieved_dt = bst.environ.get_dt()
            self.assertEqual(retrieved_dt, dt)

            # Simulate using dt in computation
            time_steps = int(1.0 / dt)
            self.assertGreater(time_steps, 0)

    def test_combined_settings(self):
        """Several settings coexist; a context overrides only what it names."""
        # Set multiple parameters
        bst.environ.set(
            precision=64,
            dt=0.01,
            mode=bst.mixin.Training(),
            custom_param='test',
            debug=True
        )

        # Verify all are set
        self.assertEqual(bst.environ.get_precision(), 64)
        self.assertEqual(bst.environ.get_dt(), 0.01)
        self.assertTrue(bst.environ.get('mode').has(bst.mixin.Training))
        self.assertEqual(bst.environ.get('custom_param'), 'test')
        self.assertTrue(bst.environ.get('debug'))

        # Test in nested contexts
        with bst.environ.context(precision=32, debug=False):
            self.assertEqual(bst.environ.get_precision(), 32)
            self.assertFalse(bst.environ.get('debug'))
            # Others unchanged
            self.assertEqual(bst.environ.get_dt(), 0.01)
            self.assertEqual(bst.environ.get('custom_param'), 'test')
|
57
1220
|
|
58
|
-
bst.environ.register_default_behavior('dt', dt_behavior)
|
59
1221
|
|
60
|
-
|
61
|
-
|
62
|
-
self.assertEqual(dt_, 0.1)
|
1222
|
+
# Run the whole suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|