brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,62 +1,1223 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- import unittest
18
-
19
- import jax.numpy as jnp
20
-
21
- import brainstate as bst
22
-
23
-
24
- class TestEnviron(unittest.TestCase):
25
- def test_precision(self):
26
- with bst.environ.context(precision=64):
27
- a = bst.random.randn(1)
28
- self.assertEqual(a.dtype, jnp.float64)
29
-
30
- with bst.environ.context(precision=32):
31
- a = bst.random.randn(1)
32
- self.assertEqual(a.dtype, jnp.float32)
33
-
34
- with bst.environ.context(precision=16):
35
- a = bst.random.randn(1)
36
- self.assertEqual(a.dtype, jnp.float16)
37
-
38
- with bst.environ.context(precision='bf16'):
39
- a = bst.random.randn(1)
40
- self.assertEqual(a.dtype, jnp.bfloat16)
41
-
42
- def test_platform(self):
43
- with self.assertRaises(ValueError):
44
- with bst.environ.context(platform='cpu'):
45
- a = bst.random.randn(1)
46
- self.assertEqual(a.device(), 'cpu')
47
-
48
- def test_register_default_behavior(self):
49
- bst.environ.set(dt=0.1)
50
-
51
- dt_ = 0.1
52
-
53
- def dt_behavior(dt):
54
- nonlocal dt_
55
- dt_ = dt
56
- print(f'dt: {dt}')
57
-
58
- bst.environ.register_default_behavior('dt', dt_behavior)
59
-
60
- with bst.environ.context(dt=0.2):
61
- self.assertEqual(dt_, 0.2)
62
- self.assertEqual(dt_, 0.1)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive test suite for the environ module.
18
+
19
+ This test module provides extensive coverage of the environment configuration
20
+ and context management functionality, including:
21
+ - Global environment settings
22
+ - Context-based temporary settings
23
+ - Precision and data type management
24
+ - Callback registration and behavior
25
+ - Thread safety
26
+ - Error handling and validation
27
+ """
28
+
29
+ import threading
30
+ import unittest
31
+ import warnings
32
+ from unittest.mock import patch
33
+
34
+ import jax.numpy as jnp
35
+ import numpy as np
36
+
37
+ import brainstate as bst
38
+
39
+
40
+ class TestEnvironmentCore(unittest.TestCase):
41
+ """Test core environment management functionality."""
42
+
43
+ def setUp(self):
44
+ """Reset environment before each test."""
45
+ bst.environ.reset()
46
+ # Clear any warnings
47
+ warnings.filterwarnings('ignore', category=UserWarning)
48
+
49
+ def tearDown(self):
50
+ """Clean up after each test."""
51
+ # Reset to default state
52
+ bst.environ.reset()
53
+ warnings.resetwarnings()
54
+
55
+ def test_get_set_basic(self):
56
+ """Test basic get and set operations."""
57
+ # Set a value
58
+ bst.environ.set(test_param='test_value')
59
+ self.assertEqual(bst.environ.get('test_param'), 'test_value')
60
+
61
+ # Set multiple values
62
+ bst.environ.set(param1=1, param2='two', param3=3.0)
63
+ self.assertEqual(bst.environ.get('param1'), 1)
64
+ self.assertEqual(bst.environ.get('param2'), 'two')
65
+ self.assertEqual(bst.environ.get('param3'), 3.0)
66
+
67
+ def test_get_with_default(self):
68
+ """Test get with default value."""
69
+ # Non-existent key with default
70
+ result = bst.environ.get('nonexistent', default='default_value')
71
+ self.assertEqual(result, 'default_value')
72
+
73
+ # Existing key ignores default
74
+ bst.environ.set(existing='value')
75
+ result = bst.environ.get('existing', default='default')
76
+ self.assertEqual(result, 'value')
77
+
78
+ def test_get_missing_key_error(self):
79
+ """Test KeyError for missing keys without default."""
80
+ with self.assertRaises(KeyError) as context:
81
+ bst.environ.get('missing_key')
82
+
83
+ error_msg = str(context.exception)
84
+ self.assertIn('missing_key', error_msg)
85
+ self.assertIn('not found', error_msg)
86
+
87
+ def test_get_with_description(self):
88
+ """Test get with description for error messages."""
89
+ with self.assertRaises(KeyError) as context:
90
+ bst.environ.get('missing', desc='Important parameter for computation')
91
+
92
+ error_msg = str(context.exception)
93
+ self.assertIn('Important parameter', error_msg)
94
+
95
+ def test_all_function(self):
96
+ """Test getting all environment settings."""
97
+ # Set various parameters
98
+ bst.environ.set(
99
+ param1='value1',
100
+ param2=42,
101
+ param3=3.14,
102
+ precision=32
103
+ )
104
+
105
+ all_settings = bst.environ.all()
106
+ self.assertIsInstance(all_settings, dict)
107
+ self.assertEqual(all_settings['param1'], 'value1')
108
+ self.assertEqual(all_settings['param2'], 42)
109
+ self.assertEqual(all_settings['param3'], 3.14)
110
+ self.assertEqual(all_settings['precision'], 32)
111
+
112
+ def test_reset_function(self):
113
+ """Test environment reset functionality."""
114
+ # Set custom values
115
+ bst.environ.set(
116
+ custom1='value1',
117
+ custom2='value2',
118
+ precision=64
119
+ )
120
+
121
+ # Verify they're set
122
+ self.assertEqual(bst.environ.get('custom1'), 'value1')
123
+
124
+ # Reset environment
125
+ with warnings.catch_warnings():
126
+ warnings.simplefilter("ignore")
127
+ bst.environ.reset()
128
+
129
+ # Custom values should be gone
130
+ result = bst.environ.get('custom1', default=None)
131
+ self.assertIsNone(result)
132
+
133
+ # Default precision should be restored
134
+ self.assertEqual(bst.environ.get('precision'), bst.environ.DEFAULT_PRECISION)
135
+
136
+ def test_special_environment_keys(self):
137
+ """Test special environment key constants."""
138
+ # Test setting using constants
139
+ bst.environ.set(**{
140
+ bst.environ.DT: 0.01,
141
+ bst.environ.I: 0,
142
+ bst.environ.T: 0.0,
143
+ bst.environ.JIT_ERROR_CHECK: True,
144
+ bst.environ.FIT: False
145
+ })
146
+
147
+ self.assertEqual(bst.environ.get(bst.environ.DT), 0.01)
148
+ self.assertEqual(bst.environ.get(bst.environ.I), 0)
149
+ self.assertEqual(bst.environ.get(bst.environ.T), 0.0)
150
+ self.assertTrue(bst.environ.get(bst.environ.JIT_ERROR_CHECK))
151
+ self.assertFalse(bst.environ.get(bst.environ.FIT))
152
+
153
+ def test_pop_basic(self):
154
+ """Test basic pop operation."""
155
+ # Set a value
156
+ bst.environ.set(pop_test='test_value')
157
+ self.assertEqual(bst.environ.get('pop_test'), 'test_value')
158
+
159
+ # Pop the value
160
+ popped = bst.environ.pop('pop_test')
161
+ self.assertEqual(popped, 'test_value')
162
+
163
+ # Value should be gone
164
+ result = bst.environ.get('pop_test', default=None)
165
+ self.assertIsNone(result)
166
+
167
+ def test_pop_with_default(self):
168
+ """Test pop with default value."""
169
+ # Pop non-existent key with default
170
+ result = bst.environ.pop('nonexistent_pop', default='default_value')
171
+ self.assertEqual(result, 'default_value')
172
+
173
+ # Pop existing key ignores default
174
+ bst.environ.set(existing_pop='value')
175
+ result = bst.environ.pop('existing_pop', default='default')
176
+ self.assertEqual(result, 'value')
177
+
178
+ def test_pop_missing_key_error(self):
179
+ """Test KeyError for missing keys without default."""
180
+ with self.assertRaises(KeyError) as context:
181
+ bst.environ.pop('missing_pop_key')
182
+
183
+ error_msg = str(context.exception)
184
+ self.assertIn('missing_pop_key', error_msg)
185
+ self.assertIn('not found', error_msg)
186
+
187
+ def test_pop_multiple_values(self):
188
+ """Test popping multiple values."""
189
+ # Set multiple values
190
+ bst.environ.set(
191
+ pop1='value1',
192
+ pop2='value2',
193
+ pop3='value3'
194
+ )
195
+
196
+ # Pop them one by one
197
+ v1 = bst.environ.pop('pop1')
198
+ v2 = bst.environ.pop('pop2')
199
+
200
+ self.assertEqual(v1, 'value1')
201
+ self.assertEqual(v2, 'value2')
202
+
203
+ # pop3 should still exist
204
+ self.assertEqual(bst.environ.get('pop3'), 'value3')
205
+
206
+ # pop1 and pop2 should be gone
207
+ self.assertIsNone(bst.environ.get('pop1', default=None))
208
+ self.assertIsNone(bst.environ.get('pop2', default=None))
209
+
210
+ def test_pop_with_context_protection(self):
211
+ """Test that pop is prevented when key is in active context."""
212
+ # Set a global value
213
+ bst.environ.set(protected_key='global_value')
214
+
215
+ # Cannot pop while in context
216
+ with bst.environ.context(protected_key='context_value'):
217
+ with self.assertRaises(ValueError) as context:
218
+ bst.environ.pop('protected_key')
219
+
220
+ error_msg = str(context.exception)
221
+ self.assertIn('Cannot pop', error_msg)
222
+ self.assertIn('active in a context', error_msg)
223
+
224
+ # Can pop after context exits
225
+ popped = bst.environ.pop('protected_key')
226
+ self.assertEqual(popped, 'global_value')
227
+
228
+ def test_pop_nested_context_protection(self):
229
+ """Test pop protection with nested contexts."""
230
+ bst.environ.set(nested_key='global')
231
+
232
+ with bst.environ.context(nested_key='level1'):
233
+ with bst.environ.context(nested_key='level2'):
234
+ # Should indicate 2 active contexts
235
+ with self.assertRaises(ValueError) as context:
236
+ bst.environ.pop('nested_key')
237
+
238
+ error_msg = str(context.exception)
239
+ self.assertIn('2 context(s)', error_msg)
240
+
241
+ def test_pop_does_not_affect_context_values(self):
242
+ """Test that pop doesn't affect context values."""
243
+ # Set both global and context value
244
+ bst.environ.set(dual_key='global')
245
+
246
+ with bst.environ.context(other_key='context_only'):
247
+ # Can pop a key that's only in global (not in this context)
248
+ popped = bst.environ.pop('dual_key')
249
+ self.assertEqual(popped, 'global')
250
+
251
+ # Context-only values remain accessible
252
+ self.assertEqual(bst.environ.get('other_key'), 'context_only')
253
+
254
+ # Context value should be gone after exit
255
+ self.assertIsNone(bst.environ.get('other_key', default=None))
256
+
257
+ def test_pop_precision_key(self):
258
+ """Test popping the precision key."""
259
+ # Set custom precision
260
+ bst.environ.set(precision=64)
261
+ self.assertEqual(bst.environ.get_precision(), 64)
262
+
263
+ # Pop precision
264
+ popped = bst.environ.pop('precision')
265
+ self.assertEqual(popped, 64)
266
+
267
+
268
+ class TestEnvironmentContext(unittest.TestCase):
269
+ """Test context manager functionality."""
270
+
271
+ def setUp(self):
272
+ """Reset environment before each test."""
273
+ bst.environ.reset()
274
+ warnings.filterwarnings('ignore', category=UserWarning)
275
+
276
+ def tearDown(self):
277
+ """Clean up after each test."""
278
+ bst.environ.reset()
279
+ warnings.resetwarnings()
280
+
281
+ def test_basic_context(self):
282
+ """Test basic context manager usage."""
283
+ bst.environ.set(value=10)
284
+
285
+ with bst.environ.context(value=20) as ctx:
286
+ # Value should be 20 in context
287
+ self.assertEqual(bst.environ.get('value'), 20)
288
+ # Context should contain current settings
289
+ self.assertEqual(ctx['value'], 20)
290
+
291
+ # Value should be restored to 10
292
+ self.assertEqual(bst.environ.get('value'), 10)
293
+
294
+ def test_nested_contexts(self):
295
+ """Test nested context managers."""
296
+ bst.environ.set(level=0)
297
+
298
+ with bst.environ.context(level=1):
299
+ self.assertEqual(bst.environ.get('level'), 1)
300
+
301
+ with bst.environ.context(level=2):
302
+ self.assertEqual(bst.environ.get('level'), 2)
303
+
304
+ with bst.environ.context(level=3):
305
+ self.assertEqual(bst.environ.get('level'), 3)
306
+
307
+ # Back to level 2
308
+ self.assertEqual(bst.environ.get('level'), 2)
309
+
310
+ # Back to level 1
311
+ self.assertEqual(bst.environ.get('level'), 1)
312
+
313
+ # Back to level 0
314
+ self.assertEqual(bst.environ.get('level'), 0)
315
+
316
+ def test_context_with_exception(self):
317
+ """Test context manager handles exceptions properly."""
318
+ bst.environ.set(value='original')
319
+
320
+ try:
321
+ with bst.environ.context(value='temporary'):
322
+ self.assertEqual(bst.environ.get('value'), 'temporary')
323
+ raise ValueError("Test exception")
324
+ except ValueError:
325
+ pass
326
+
327
+ # Value should be restored despite exception
328
+ self.assertEqual(bst.environ.get('value'), 'original')
329
+
330
+ def test_context_multiple_parameters(self):
331
+ """Test context with multiple parameters."""
332
+ bst.environ.set(a=1, b=2, c=3)
333
+
334
+ with bst.environ.context(a=10, b=20, c=30, d=40):
335
+ self.assertEqual(bst.environ.get('a'), 10)
336
+ self.assertEqual(bst.environ.get('b'), 20)
337
+ self.assertEqual(bst.environ.get('c'), 30)
338
+ self.assertEqual(bst.environ.get('d'), 40)
339
+
340
+ # Original values restored
341
+ self.assertEqual(bst.environ.get('a'), 1)
342
+ self.assertEqual(bst.environ.get('b'), 2)
343
+ self.assertEqual(bst.environ.get('c'), 3)
344
+ # d should not exist
345
+ result = bst.environ.get('d', default=None)
346
+ self.assertIsNone(result)
347
+
348
+ def test_context_platform_restriction(self):
349
+ """Test that platform cannot be set in context."""
350
+ with self.assertRaises(ValueError) as context:
351
+ with bst.environ.context(platform='cpu'):
352
+ pass
353
+
354
+ self.assertIn('platform', str(context.exception).lower())
355
+ self.assertIn('cannot set', str(context.exception).lower())
356
+
357
+ def test_context_host_device_count_restriction(self):
358
+ """Test that host_device_count cannot be set in context."""
359
+ with self.assertRaises(ValueError) as context:
360
+ with bst.environ.context(host_device_count=4):
361
+ pass
362
+
363
+ self.assertIn('host_device_count', str(context.exception))
364
+
365
+ def test_context_mode_validation(self):
366
+ """Test mode validation in context."""
367
+ # Valid mode
368
+ mode = bst.mixin.Training()
369
+ with bst.environ.context(mode=mode):
370
+ self.assertEqual(bst.environ.get('mode'), mode)
371
+
372
+ def test_context_preserves_unmodified_values(self):
373
+ """Test that context doesn't affect unmodified values."""
374
+ bst.environ.set(unchanged='original', changed='original')
375
+
376
+ with bst.environ.context(changed='modified'):
377
+ self.assertEqual(bst.environ.get('unchanged'), 'original')
378
+ self.assertEqual(bst.environ.get('changed'), 'modified')
379
+
380
+
381
+ class TestPrecisionAndDataTypes(unittest.TestCase):
382
+ """Test precision control and data type functions."""
383
+
384
+ def setUp(self):
385
+ """Reset environment before each test."""
386
+ bst.environ.reset()
387
+ warnings.filterwarnings('ignore', category=UserWarning)
388
+
389
+ def tearDown(self):
390
+ """Clean up after each test."""
391
+ bst.environ.reset()
392
+ warnings.resetwarnings()
393
+
394
+ def test_precision_settings(self):
395
+ """Test different precision settings."""
396
+ precisions = [8, 16, 32, 64, 'bf16']
397
+
398
+ for precision in precisions:
399
+ bst.environ.set(precision=precision)
400
+
401
+ if precision == 'bf16':
402
+ self.assertEqual(bst.environ.get_precision(), 16)
403
+ elif isinstance(precision, str):
404
+ self.assertEqual(bst.environ.get_precision(), int(precision))
405
+ else:
406
+ self.assertEqual(bst.environ.get_precision(), precision)
407
+
408
+ def test_precision_context(self):
409
+ """Test precision changes in context."""
410
+ bst.environ.set(precision=32)
411
+
412
+ with bst.environ.context(precision=64):
413
+ a = bst.random.randn(1)
414
+ self.assertEqual(a.dtype, jnp.float64)
415
+ self.assertEqual(bst.environ.get_precision(), 64)
416
+
417
+ # Precision restored
418
+ b = bst.random.randn(1)
419
+ self.assertEqual(b.dtype, jnp.float32)
420
+ self.assertEqual(bst.environ.get_precision(), 32)
421
+
422
+ def test_dftype_function(self):
423
+ """Test default float type function."""
424
+ # 32-bit precision
425
+ bst.environ.set(precision=32)
426
+ self.assertEqual(bst.environ.dftype(), np.float32)
427
+
428
+ # 64-bit precision
429
+ bst.environ.set(precision=64)
430
+ self.assertEqual(bst.environ.dftype(), np.float64)
431
+
432
+ # 16-bit precision
433
+ bst.environ.set(precision=16)
434
+ self.assertEqual(bst.environ.dftype(), np.float16)
435
+
436
+ # bfloat16 precision
437
+ bst.environ.set(precision='bf16')
438
+ self.assertEqual(bst.environ.dftype(), jnp.bfloat16)
439
+
440
+ def test_ditype_function(self):
441
+ """Test default integer type function."""
442
+ # 32-bit precision
443
+ bst.environ.set(precision=32)
444
+ self.assertEqual(bst.environ.ditype(), np.int32)
445
+
446
+ # 64-bit precision
447
+ bst.environ.set(precision=64)
448
+ self.assertEqual(bst.environ.ditype(), np.int64)
449
+
450
+ # 16-bit precision
451
+ bst.environ.set(precision=16)
452
+ self.assertEqual(bst.environ.ditype(), np.int16)
453
+
454
+ # 8-bit precision
455
+ bst.environ.set(precision=8)
456
+ self.assertEqual(bst.environ.ditype(), np.int8)
457
+
458
+ def test_dutype_function(self):
459
+ """Test default unsigned integer type function."""
460
+ # 32-bit precision
461
+ bst.environ.set(precision=32)
462
+ self.assertEqual(bst.environ.dutype(), np.uint32)
463
+
464
+ # 64-bit precision
465
+ bst.environ.set(precision=64)
466
+ self.assertEqual(bst.environ.dutype(), np.uint64)
467
+
468
+ # 16-bit precision
469
+ bst.environ.set(precision=16)
470
+ self.assertEqual(bst.environ.dutype(), np.uint16)
471
+
472
+ # 8-bit precision
473
+ bst.environ.set(precision=8)
474
+ self.assertEqual(bst.environ.dutype(), np.uint8)
475
+
476
+ def test_dctype_function(self):
477
+ """Test default complex type function."""
478
+ # 32-bit precision
479
+ bst.environ.set(precision=32)
480
+ self.assertEqual(bst.environ.dctype(), np.complex64)
481
+
482
+ # 64-bit precision
483
+ bst.environ.set(precision=64)
484
+ self.assertEqual(bst.environ.dctype(), np.complex128)
485
+
486
+ # 16-bit precision (should use complex64)
487
+ bst.environ.set(precision=16)
488
+ self.assertEqual(bst.environ.dctype(), np.complex64)
489
+
490
+ def test_tolerance_function(self):
491
+ """Test tolerance values for different precisions."""
492
+ # 64-bit precision
493
+ bst.environ.set(precision=64)
494
+ tol = bst.environ.tolerance()
495
+ self.assertAlmostEqual(float(tol), 1e-12, places=14)
496
+
497
+ # 32-bit precision
498
+ bst.environ.set(precision=32)
499
+ tol = bst.environ.tolerance()
500
+ self.assertAlmostEqual(float(tol), 1e-5, places=7)
501
+
502
+ # 16-bit precision
503
+ bst.environ.set(precision=16)
504
+ tol = bst.environ.tolerance()
505
+ self.assertAlmostEqual(float(tol), 1e-2, places=4)
506
+
507
+ def test_invalid_precision(self):
508
+ """Test invalid precision values."""
509
+ invalid_precisions = [128, 'invalid', -1, 3.14]
510
+
511
+ for invalid in invalid_precisions:
512
+ with self.assertRaises(ValueError):
513
+ bst.environ.set(precision=invalid)
514
+
515
+ def test_precision_with_arrays(self):
516
+ """Test that precision affects array creation."""
517
+ # Test with different precisions
518
+ test_cases = [
519
+ (32, jnp.float32),
520
+ (64, jnp.float64),
521
+ (16, jnp.float16),
522
+ ('bf16', jnp.bfloat16),
523
+ ]
524
+
525
+ for precision, expected_dtype in test_cases:
526
+ with bst.environ.context(precision=precision):
527
+ # Create array using random
528
+ arr = bst.random.randn(5)
529
+ self.assertEqual(arr.dtype, expected_dtype)
530
+
531
+
532
+ class TestModeAndSpecialGetters(unittest.TestCase):
533
+ """Test mode management and special getter functions."""
534
+
535
+ def setUp(self):
536
+ """Reset environment before each test."""
537
+ bst.environ.reset()
538
+ warnings.filterwarnings('ignore', category=UserWarning)
539
+
540
+ def tearDown(self):
541
+ """Clean up after each test."""
542
+ bst.environ.reset()
543
+ warnings.resetwarnings()
544
+
545
+ def test_get_dt(self):
546
+ """Test get_dt function."""
547
+ # Set dt
548
+ bst.environ.set(dt=0.01)
549
+ self.assertEqual(bst.environ.get_dt(), 0.01)
550
+
551
+ # Test in context
552
+ with bst.environ.context(dt=0.001):
553
+ self.assertEqual(bst.environ.get_dt(), 0.001)
554
+
555
+ self.assertEqual(bst.environ.get_dt(), 0.01)
556
+
557
+ # Test missing dt
558
+ bst.environ.reset()
559
+ with self.assertRaises(KeyError):
560
+ bst.environ.get_dt()
561
+
562
+ def test_get_mode(self):
563
+ """Test get_mode function."""
564
+ # Set training mode
565
+ training = bst.mixin.Training()
566
+ bst.environ.set(mode=training)
567
+ mode = bst.environ.get('mode')
568
+ self.assertEqual(mode, training)
569
+ self.assertTrue(mode.has(bst.mixin.Training))
570
+
571
+ # Test with batching mode
572
+ batching = bst.mixin.Batching(batch_size=32)
573
+ with bst.environ.context(mode=batching):
574
+ mode = bst.environ.get('mode')
575
+ self.assertEqual(mode, batching)
576
+ self.assertTrue(mode.has(bst.mixin.Batching))
577
+ self.assertEqual(mode.batch_size, 32)
578
+
579
+ # Test missing mode
580
+ bst.environ.reset()
581
+ with self.assertRaises(KeyError):
582
+ bst.environ.get('mode')
583
+
584
+ def test_get_platform(self):
585
+ """Test get_platform function."""
586
+ platform = bst.environ.get_platform()
587
+ self.assertIn(platform, bst.environ.SUPPORTED_PLATFORMS)
588
+
589
+ def test_get_host_device_count(self):
590
+ """Test get_host_device_count function."""
591
+ count = bst.environ.get_host_device_count()
592
+ self.assertIsInstance(count, int)
593
+ self.assertGreaterEqual(count, 1)
594
+
595
+ def test_dt_validation(self):
596
+ """Test dt validation in set function."""
597
+ # Valid dt values
598
+ valid_dts = [0.01, 0.001, 1.0, 0.1]
599
+ for dt in valid_dts:
600
+ bst.environ.set(dt=dt)
601
+ self.assertEqual(bst.environ.get_dt(), dt)
602
+
603
+
604
class TestPlatformAndDevice(unittest.TestCase):
    """Test platform and device management."""

    def setUp(self):
        """Reset the environment and scope a UserWarning filter to this test.

        ``warnings.catch_warnings()`` snapshots the global warning filters.
        Its ``__exit__`` is registered with ``addCleanup``, so the snapshot is
        restored after ``tearDown``. Unlike ``warnings.resetwarnings()``, this
        does not wipe filters installed by the interpreter, ``-W`` options or
        the test runner.
        """
        bst.environ.reset()
        warning_scope = warnings.catch_warnings()
        warning_scope.__enter__()
        self.addCleanup(warning_scope.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Reset the environment (warning filters are restored by a cleanup)."""
        bst.environ.reset()

    @patch('brainstate.environ.config')
    def test_set_platform(self, mock_config):
        """Test platform setting."""
        platforms = ['cpu', 'gpu', 'tpu']

        for platform in platforms:
            bst.environ.set_platform(platform)
            mock_config.update.assert_called_with("jax_platform_name", platform)

        # Test invalid platform
        with self.assertRaises(ValueError):
            bst.environ.set_platform('invalid')

    def test_set_platform_through_set(self):
        """Test setting platform through general set function."""
        with patch('brainstate.environ.config') as mock_config:
            bst.environ.set(platform='gpu')
            mock_config.update.assert_called_with("jax_platform_name", 'gpu')

    def test_set_host_device_count(self):
        """Test host device count setting."""
        import os

        # set_host_device_count() writes the process-wide XLA_FLAGS variable
        # (asserted below). Restore the previous value afterwards so this test
        # does not leak device-count flags into later tests.
        previous_flags = os.environ.get("XLA_FLAGS")

        def restore_xla_flags():
            if previous_flags is None:
                os.environ.pop("XLA_FLAGS", None)
            else:
                os.environ["XLA_FLAGS"] = previous_flags

        self.addCleanup(restore_xla_flags)

        # Set device count
        bst.environ.set_host_device_count(4)
        xla_flags = os.environ.get("XLA_FLAGS", "")
        self.assertIn("--xla_force_host_platform_device_count=4", xla_flags)

        # Update device count: the old flag must be replaced, not appended.
        bst.environ.set_host_device_count(8)
        xla_flags = os.environ.get("XLA_FLAGS", "")
        self.assertIn("--xla_force_host_platform_device_count=8", xla_flags)
        self.assertNotIn("--xla_force_host_platform_device_count=4", xla_flags)

        # Invalid device count
        with self.assertRaises(ValueError):
            bst.environ.set_host_device_count(0)

        with self.assertRaises(ValueError):
            bst.environ.set_host_device_count(-1)

    def test_platform_context_restriction(self):
        """Test that platform cannot be changed in context."""
        with self.assertRaises(ValueError):
            with bst.environ.context(platform='cpu'):
                pass
663
+
664
+
665
class TestCallbackBehavior(unittest.TestCase):
    """Test callback registration and behavior."""

    def setUp(self):
        """Reset the environment and scope a UserWarning filter to this test.

        ``warnings.catch_warnings()`` snapshots the global warning filters and
        its ``__exit__`` (registered via ``addCleanup``) restores them after
        ``tearDown``, instead of wiping every filter with ``resetwarnings()``.
        """
        bst.environ.reset()
        warning_scope = warnings.catch_warnings()
        warning_scope.__enter__()
        self.addCleanup(warning_scope.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)
        self.callback_values = []

    def tearDown(self):
        """Reset the environment (registered behaviors are removed by cleanups)."""
        bst.environ.reset()

    def _register(self, key, callback, **kwargs):
        """Register a default behavior and schedule its removal after the test.

        Unregistering a key that is already gone returns ``False`` (see
        ``test_unregister_callback``), so the cleanup is safe even if the test
        removed the behavior itself.
        """
        bst.environ.register_default_behavior(key, callback, **kwargs)
        self.addCleanup(bst.environ.unregister_default_behavior, key)

    @unittest.skip('Disabled: was commented out upstream; re-enable once the '
                   'context enter/exit callback semantics are confirmed.')
    def test_register_callback(self):
        """Test basic callback registration."""

        def callback(value):
            self.callback_values.append(value)

        self._register('test_param', callback)

        # Callback should be triggered on set
        bst.environ.set(test_param='value1')
        self.assertEqual(self.callback_values, ['value1'])

        # Callback should be triggered on context enter/exit
        with bst.environ.context(test_param='value2'):
            self.assertEqual(self.callback_values, ['value1', 'value2'])

        # Should restore previous value
        self.assertEqual(self.callback_values, ['value1', 'value2', 'value1'])

    def test_register_multiple_callbacks(self):
        """Test registering callbacks for different keys."""
        values_a = []
        values_b = []

        def callback_a(value):
            values_a.append(value)

        def callback_b(value):
            values_b.append(value)

        self._register('param_a', callback_a)
        self._register('param_b', callback_b)

        bst.environ.set(param_a='a1', param_b='b1')
        self.assertEqual(values_a, ['a1'])
        self.assertEqual(values_b, ['b1'])

    def test_replace_callback(self):
        """Test replacing existing callbacks."""

        def callback1(value):
            self.callback_values.append(f'cb1:{value}')

        def callback2(value):
            self.callback_values.append(f'cb2:{value}')

        # Register first callback
        self._register('param', callback1)

        # Try to register second without replace flag
        with self.assertRaises(ValueError):
            bst.environ.register_default_behavior('param', callback2)

        # Register with replace flag
        self._register('param', callback2, replace_if_exist=True)

        # Only second callback should be called
        bst.environ.set(param='test')
        self.assertEqual(self.callback_values, ['cb2:test'])

    def test_unregister_callback(self):
        """Test unregistering callbacks."""

        def callback(value):
            self.callback_values.append(value)

        # Register and test
        self._register('param', callback)
        bst.environ.set(param='value1')
        self.assertEqual(len(self.callback_values), 1)

        # Unregister
        removed = bst.environ.unregister_default_behavior('param')
        self.assertTrue(removed)

        # Callback should not be triggered
        bst.environ.set(param='value2')
        self.assertEqual(len(self.callback_values), 1)  # Still just one

        # Unregister non-existent
        removed = bst.environ.unregister_default_behavior('nonexistent')
        self.assertFalse(removed)

    def test_list_registered_behaviors(self):
        """Test listing registered behaviors."""
        for param in ['param1', 'param2', 'param3']:
            self._register(param, lambda x: None)

        behaviors = bst.environ.list_registered_behaviors()
        for param in ['param1', 'param2', 'param3']:
            self.assertIn(param, behaviors)

    def test_callback_exception_handling(self):
        """Test that exceptions in callbacks are handled gracefully."""

        def failing_callback(value):
            raise RuntimeError(f"Intentional error: {value}")

        self._register('param', failing_callback)

        # Should not crash, but should warn
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            bst.environ.set(param='test')

        # Search every recorded warning instead of assuming the callback
        # failure is the first one emitted.
        messages = [str(item.message) for item in caught]
        self.assertTrue(
            any('Callback' in msg and 'exception' in msg for msg in messages),
            msg=f'no callback-failure warning among {messages!r}',
        )

    def test_callback_validation(self):
        """Test callback validation."""
        # Non-callable
        with self.assertRaises(TypeError):
            bst.environ.register_default_behavior('param', 'not_callable')

        # Non-string key
        with self.assertRaises(TypeError):
            bst.environ.register_default_behavior(123, lambda x: None)

    def test_callback_with_validation(self):
        """Test using callbacks for validation."""

        def validate_positive(value):
            if value <= 0:
                raise ValueError(f"Value must be positive, got {value}")
            self.callback_values.append(value)

        self._register('positive_param', validate_positive)

        # Valid value
        bst.environ.set(positive_param=10)
        self.assertEqual(self.callback_values, [10])

        # Invalid value: the callback's ValueError is reported as a warning
        # (same path as test_callback_exception_handling), and because it
        # raises before appending, the rejected value is never recorded.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            bst.environ.set(positive_param=-5)

        self.assertEqual(self.callback_values, [10])
        self.assertTrue(any('Callback' in str(item.message) for item in caught))
821
+
822
+
823
class TestThreadSafety(unittest.TestCase):
    """Test thread safety of environment operations."""

    def setUp(self):
        """Reset the environment and scope a UserWarning filter to this test.

        ``warnings.catch_warnings()`` snapshots the global warning filters and
        its ``__exit__`` (registered via ``addCleanup``) restores them after
        ``tearDown``, instead of wiping every filter with ``resetwarnings()``.
        """
        bst.environ.reset()
        warning_scope = warnings.catch_warnings()
        warning_scope.__enter__()
        self.addCleanup(warning_scope.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Reset the environment (warning filters are restored by a cleanup)."""
        bst.environ.reset()

    def test_concurrent_set_operations(self):
        """Test concurrent set operations from multiple threads."""
        n_threads = 5
        n_iters = 10
        results = []
        errors = []

        def thread_operation(thread_id):
            try:
                # Each thread sets its own value
                for i in range(n_iters):
                    bst.environ.set(**{f'thread_{thread_id}': i})
                    value = bst.environ.get(f'thread_{thread_id}')
                    results.append((thread_id, value))
            except Exception as e:
                errors.append(e)

        threads = []
        for i in range(n_threads):
            thread = threading.Thread(target=thread_operation, args=(i,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Compare the list itself so a failure shows the actual exceptions.
        self.assertEqual(errors, [])

        # Keys are unique per thread, so each thread must read back exactly the
        # values it wrote, in order.
        for thread_id in range(n_threads):
            observed = [value for tid, value in results if tid == thread_id]
            self.assertEqual(observed, list(range(n_iters)))

    def test_concurrent_context_operations(self):
        """Test concurrent context operations from multiple threads."""
        n_threads = 3
        n_iters = 5
        results = []
        errors = []

        def thread_context_operation(thread_id):
            try:
                bst.environ.set(**{f'base_{thread_id}': 0})

                for i in range(n_iters):
                    with bst.environ.context(**{f'base_{thread_id}': i}):
                        value = bst.environ.get(f'base_{thread_id}')
                        results.append((thread_id, value))

                # Should be back to 0; an AssertionError here is captured in
                # `errors` and reported by the main thread.
                final = bst.environ.get(f'base_{thread_id}')
                self.assertEqual(final, 0)
            except Exception as e:
                errors.append(e)

        threads = []
        for i in range(n_threads):
            thread = threading.Thread(target=thread_context_operation, args=(i,))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Compare the list itself so a failure shows the actual exceptions.
        self.assertEqual(errors, [])

        # Inside each context, a thread must see the value it just entered.
        for thread_id in range(n_threads):
            observed = [value for tid, value in results if tid == thread_id]
            self.assertEqual(observed, list(range(n_iters)))

    def test_concurrent_pop_operations(self):
        """Test concurrent pop operations from multiple threads."""
        # Set up multiple keys
        for i in range(20):
            bst.environ.set(**{f'pop_thread_{i}': f'value_{i}'})

        results = []
        errors = []

        def thread_pop_operation(start, end):
            try:
                for i in range(start, end):
                    try:
                        value = bst.environ.pop(f'pop_thread_{i}')
                        results.append((i, value))
                    except KeyError:
                        # Key might already be popped by another thread
                        pass
            except Exception as e:
                errors.append(e)

        # Create threads that pop different ranges
        threads = []
        ranges = [(0, 5), (5, 10), (10, 15), (15, 20)]
        for start, end in ranges:
            thread = threading.Thread(target=thread_pop_operation, args=(start, end))
            threads.append(thread)
            thread.start()

        for thread in threads:
            thread.join()

        # Compare the list itself so a failure shows the actual exceptions.
        self.assertEqual(errors, [])

        # Whatever was popped must be popped at most once and carry the value
        # stored under that key.
        popped_indices = [i for i, _ in results]
        self.assertEqual(len(popped_indices), len(set(popped_indices)))
        for i, value in results:
            self.assertEqual(value, f'value_{i}')

        # NOTE(review): stricter checks (all 20 keys popped exactly once and
        # none left afterwards) were disabled upstream, presumably because
        # values set in the main thread need not be visible to worker threads.
        # Confirm environ's storage model before re-enabling them.
946
+
947
+
948
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and boundary conditions."""

    def setUp(self):
        """Reset the environment and scope a UserWarning filter to this test.

        ``warnings.catch_warnings()`` snapshots the global warning filters and
        its ``__exit__`` (registered via ``addCleanup``) restores them after
        ``tearDown``, instead of wiping every filter with ``resetwarnings()``.
        """
        bst.environ.reset()
        warning_scope = warnings.catch_warnings()
        warning_scope.__enter__()
        self.addCleanup(warning_scope.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Reset the environment (warning filters are restored by a cleanup)."""
        bst.environ.reset()

    def test_empty_context(self):
        """Test context with no parameters."""
        original = bst.environ.all()

        with bst.environ.context() as ctx:
            # Should be unchanged
            self.assertEqual(ctx, original)

        self.assertEqual(bst.environ.all(), original)

    def test_none_values(self):
        """Test handling of None values."""
        bst.environ.set(none_param=None)
        self.assertIsNone(bst.environ.get('none_param'))

        with bst.environ.context(none_param='not_none'):
            self.assertEqual(bst.environ.get('none_param'), 'not_none')

        self.assertIsNone(bst.environ.get('none_param'))

    def test_complex_data_types(self):
        """Test storing complex data types."""
        # Lists
        bst.environ.set(list_param=[1, 2, 3])
        self.assertEqual(bst.environ.get('list_param'), [1, 2, 3])

        # Dictionaries
        bst.environ.set(dict_param={'a': 1, 'b': 2})
        self.assertEqual(bst.environ.get('dict_param'), {'a': 1, 'b': 2})

        # Tuples
        bst.environ.set(tuple_param=(1, 2, 3))
        self.assertEqual(bst.environ.get('tuple_param'), (1, 2, 3))

        # Custom objects
        class CustomObject:
            def __init__(self, value):
                self.value = value

        obj = CustomObject(42)
        bst.environ.set(obj_param=obj)
        retrieved = bst.environ.get('obj_param')
        self.assertIs(retrieved, obj)
        self.assertEqual(retrieved.value, 42)

    def test_special_string_values(self):
        """Test special string values."""
        special_strings = ['', ' ', '\n', '\t', 'None', 'True', 'False']

        for s in special_strings:
            bst.environ.set(string_param=s)
            self.assertEqual(bst.environ.get('string_param'), s)

    def test_numeric_edge_values(self):
        """Test numeric edge values, including NaN."""
        import math
        import sys

        edge_values = [
            0, -0, 1, -1,
            sys.maxsize, -sys.maxsize,
            float('inf'), float('-inf'),
            float('nan'),  # NaN != NaN, so it needs the dedicated branch below
            1e-100, 1e100,
        ]

        for value in edge_values:
            bst.environ.set(numeric_param=value)
            retrieved = bst.environ.get('numeric_param')
            if isinstance(value, float) and math.isnan(value):
                self.assertTrue(math.isnan(retrieved))
            else:
                self.assertEqual(retrieved, value)

    def test_context_all_interaction(self):
        """Test interaction between context and all() function."""
        bst.environ.set(global_param='global')

        with bst.environ.context(context_param='context', global_param='override'):
            all_values = bst.environ.all()

            # Should include both
            self.assertEqual(all_values['global_param'], 'override')
            self.assertEqual(all_values['context_param'], 'context')

            # Original global values should be in settings
            self.assertIn('precision', all_values)

    def test_deeply_nested_contexts(self):
        """Test deeply nested contexts."""
        depth = 20
        bst.environ.set(depth=0)

        def nested_context(level):
            if level < depth:
                with bst.environ.context(depth=level):
                    self.assertEqual(bst.environ.get('depth'), level)
                    nested_context(level + 1)
                    self.assertEqual(bst.environ.get('depth'), level)

        nested_context(1)
        self.assertEqual(bst.environ.get('depth'), 0)

    def test_set_precision_function(self):
        """Test the dedicated set_precision function."""
        # Valid precisions
        for precision in [8, 16, 32, 64, 'bf16']:
            bst.environ.set_precision(precision)
            self.assertEqual(bst.environ.get('precision'), precision)

        # Invalid precision
        with self.assertRaises(ValueError):
            bst.environ.set_precision(128)

    def test_pop_edge_cases(self):
        """Test edge cases for pop function."""
        # Pop with None value
        bst.environ.set(none_key=None)
        popped = bst.environ.pop('none_key')
        self.assertIsNone(popped)

        # Pop with None as default
        result = bst.environ.pop('missing_key', default=None)
        self.assertIsNone(result)

        # Pop complex data types
        complex_obj = {'nested': {'data': [1, 2, 3]}}
        bst.environ.set(complex_key=complex_obj)
        popped = bst.environ.pop('complex_key')
        self.assertEqual(popped, complex_obj)

        # Verify object identity preservation
        obj = object()
        bst.environ.set(obj_key=obj)
        popped = bst.environ.pop('obj_key')
        self.assertIs(popped, obj)

    def test_pop_all_interaction(self):
        """Test interaction between pop and all() function."""
        # Set multiple values
        bst.environ.set(a=1, b=2, c=3, d=4)
        initial_all = bst.environ.all()

        # Precondition: every key is visible before popping, so the
        # assertNotIn checks below really test pop().
        for key in ('a', 'b', 'c', 'd'):
            self.assertIn(key, initial_all)

        # Pop some values
        bst.environ.pop('b')
        bst.environ.pop('d')

        # Check all() reflects the changes
        after_pop = bst.environ.all()
        self.assertIn('a', after_pop)
        self.assertIn('c', after_pop)
        self.assertNotIn('b', after_pop)
        self.assertNotIn('d', after_pop)

    def test_pop_callback_not_triggered(self):
        """Test that callbacks are not triggered on pop."""
        callback_calls = []

        def callback(value):
            callback_calls.append(value)

        # Register callback; unregister via cleanup so it is removed even if an
        # assertion below fails.
        bst.environ.register_default_behavior('callback_test', callback)
        self.addCleanup(bst.environ.unregister_default_behavior, 'callback_test')

        # Set triggers callback
        bst.environ.set(callback_test='value')
        self.assertEqual(len(callback_calls), 1)

        # Pop should NOT trigger callback
        popped = bst.environ.pop('callback_test')
        self.assertEqual(len(callback_calls), 1)  # Still just 1
        self.assertEqual(popped, 'value')
1134
+
1135
+
1136
class TestIntegration(unittest.TestCase):
    """Integration tests with actual BrainState functionality."""

    def setUp(self):
        """Reset the environment and scope a UserWarning filter to this test.

        ``warnings.catch_warnings()`` snapshots the global warning filters and
        its ``__exit__`` (registered via ``addCleanup``) restores them after
        ``tearDown``, instead of wiping every filter with ``resetwarnings()``.
        """
        bst.environ.reset()
        warning_scope = warnings.catch_warnings()
        warning_scope.__enter__()
        self.addCleanup(warning_scope.__exit__, None, None, None)
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Reset the environment (warning filters are restored by a cleanup)."""
        bst.environ.reset()

    def test_precision_affects_random_arrays(self):
        """Test that precision setting affects random array generation."""
        # Test different precisions
        test_cases = [
            (32, jnp.float32),
            (64, jnp.float64),
            (16, jnp.float16),
            ('bf16', jnp.bfloat16),
        ]

        for precision, expected_dtype in test_cases:
            # subTest reports each failing precision separately.
            with self.subTest(precision=precision):
                with bst.environ.context(precision=precision):
                    arr = bst.random.randn(10)
                    self.assertEqual(arr.dtype, expected_dtype)

    def test_mode_usage(self):
        """Test mode usage in computations."""
        # Create different modes
        training = bst.mixin.Training()
        batching = bst.mixin.Batching(batch_size=32)

        # Test training mode
        bst.environ.set(mode=training)
        mode = bst.environ.get('mode')
        self.assertTrue(mode.has(bst.mixin.Training))

        # Test batching mode
        with bst.environ.context(mode=batching):
            mode = bst.environ.get('mode')
            self.assertTrue(mode.has(bst.mixin.Batching))
            self.assertEqual(mode.batch_size, 32)

    def test_dt_in_numerical_integration(self):
        """Test dt usage in numerical contexts."""
        # Set different dt values
        dt_values = [0.01, 0.001, 0.1]

        for dt in dt_values:
            bst.environ.set(dt=dt)
            retrieved_dt = bst.environ.get_dt()
            self.assertEqual(retrieved_dt, dt)

            # Simulate using dt in computation
            time_steps = int(1.0 / dt)
            self.assertGreater(time_steps, 0)

    def test_combined_settings(self):
        """Test combining multiple settings."""
        # Set multiple parameters
        bst.environ.set(
            precision=64,
            dt=0.01,
            mode=bst.mixin.Training(),
            custom_param='test',
            debug=True
        )

        # Verify all are set
        self.assertEqual(bst.environ.get_precision(), 64)
        self.assertEqual(bst.environ.get_dt(), 0.01)
        self.assertTrue(bst.environ.get('mode').has(bst.mixin.Training))
        self.assertEqual(bst.environ.get('custom_param'), 'test')
        self.assertTrue(bst.environ.get('debug'))

        # Test in nested contexts
        with bst.environ.context(precision=32, debug=False):
            self.assertEqual(bst.environ.get_precision(), 32)
            self.assertFalse(bst.environ.get('debug'))
            # Others unchanged
            self.assertEqual(bst.environ.get_dt(), 0.01)
            self.assertEqual(bst.environ.get('custom_param'), 'test')
1220
+
1221
+
1222
if __name__ == '__main__':
    # Allow running this test module directly with the stdlib unittest runner.
    unittest.main()