brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,1223 +1,1223 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Comprehensive test suite for the environ module.
18
-
19
- This test module provides extensive coverage of the environment configuration
20
- and context management functionality, including:
21
- - Global environment settings
22
- - Context-based temporary settings
23
- - Precision and data type management
24
- - Callback registration and behavior
25
- - Thread safety
26
- - Error handling and validation
27
- """
28
-
29
- import threading
30
- import unittest
31
- import warnings
32
- from unittest.mock import patch
33
-
34
- import jax.numpy as jnp
35
- import numpy as np
36
-
37
- import brainstate as bst
38
-
39
-
40
- class TestEnvironmentCore(unittest.TestCase):
41
- """Test core environment management functionality."""
42
-
43
- def setUp(self):
44
- """Reset environment before each test."""
45
- bst.environ.reset()
46
- # Clear any warnings
47
- warnings.filterwarnings('ignore', category=UserWarning)
48
-
49
- def tearDown(self):
50
- """Clean up after each test."""
51
- # Reset to default state
52
- bst.environ.reset()
53
- warnings.resetwarnings()
54
-
55
- def test_get_set_basic(self):
56
- """Test basic get and set operations."""
57
- # Set a value
58
- bst.environ.set(test_param='test_value')
59
- self.assertEqual(bst.environ.get('test_param'), 'test_value')
60
-
61
- # Set multiple values
62
- bst.environ.set(param1=1, param2='two', param3=3.0)
63
- self.assertEqual(bst.environ.get('param1'), 1)
64
- self.assertEqual(bst.environ.get('param2'), 'two')
65
- self.assertEqual(bst.environ.get('param3'), 3.0)
66
-
67
- def test_get_with_default(self):
68
- """Test get with default value."""
69
- # Non-existent key with default
70
- result = bst.environ.get('nonexistent', default='default_value')
71
- self.assertEqual(result, 'default_value')
72
-
73
- # Existing key ignores default
74
- bst.environ.set(existing='value')
75
- result = bst.environ.get('existing', default='default')
76
- self.assertEqual(result, 'value')
77
-
78
- def test_get_missing_key_error(self):
79
- """Test KeyError for missing keys without default."""
80
- with self.assertRaises(KeyError) as context:
81
- bst.environ.get('missing_key')
82
-
83
- error_msg = str(context.exception)
84
- self.assertIn('missing_key', error_msg)
85
- self.assertIn('not found', error_msg)
86
-
87
- def test_get_with_description(self):
88
- """Test get with description for error messages."""
89
- with self.assertRaises(KeyError) as context:
90
- bst.environ.get('missing', desc='Important parameter for computation')
91
-
92
- error_msg = str(context.exception)
93
- self.assertIn('Important parameter', error_msg)
94
-
95
- def test_all_function(self):
96
- """Test getting all environment settings."""
97
- # Set various parameters
98
- bst.environ.set(
99
- param1='value1',
100
- param2=42,
101
- param3=3.14,
102
- precision=32
103
- )
104
-
105
- all_settings = bst.environ.all()
106
- self.assertIsInstance(all_settings, dict)
107
- self.assertEqual(all_settings['param1'], 'value1')
108
- self.assertEqual(all_settings['param2'], 42)
109
- self.assertEqual(all_settings['param3'], 3.14)
110
- self.assertEqual(all_settings['precision'], 32)
111
-
112
- def test_reset_function(self):
113
- """Test environment reset functionality."""
114
- # Set custom values
115
- bst.environ.set(
116
- custom1='value1',
117
- custom2='value2',
118
- precision=64
119
- )
120
-
121
- # Verify they're set
122
- self.assertEqual(bst.environ.get('custom1'), 'value1')
123
-
124
- # Reset environment
125
- with warnings.catch_warnings():
126
- warnings.simplefilter("ignore")
127
- bst.environ.reset()
128
-
129
- # Custom values should be gone
130
- result = bst.environ.get('custom1', default=None)
131
- self.assertIsNone(result)
132
-
133
- # Default precision should be restored
134
- self.assertEqual(bst.environ.get('precision'), bst.environ.DEFAULT_PRECISION)
135
-
136
- def test_special_environment_keys(self):
137
- """Test special environment key constants."""
138
- # Test setting using constants
139
- bst.environ.set(**{
140
- bst.environ.DT: 0.01,
141
- bst.environ.I: 0,
142
- bst.environ.T: 0.0,
143
- bst.environ.JIT_ERROR_CHECK: True,
144
- bst.environ.FIT: False
145
- })
146
-
147
- self.assertEqual(bst.environ.get(bst.environ.DT), 0.01)
148
- self.assertEqual(bst.environ.get(bst.environ.I), 0)
149
- self.assertEqual(bst.environ.get(bst.environ.T), 0.0)
150
- self.assertTrue(bst.environ.get(bst.environ.JIT_ERROR_CHECK))
151
- self.assertFalse(bst.environ.get(bst.environ.FIT))
152
-
153
- def test_pop_basic(self):
154
- """Test basic pop operation."""
155
- # Set a value
156
- bst.environ.set(pop_test='test_value')
157
- self.assertEqual(bst.environ.get('pop_test'), 'test_value')
158
-
159
- # Pop the value
160
- popped = bst.environ.pop('pop_test')
161
- self.assertEqual(popped, 'test_value')
162
-
163
- # Value should be gone
164
- result = bst.environ.get('pop_test', default=None)
165
- self.assertIsNone(result)
166
-
167
- def test_pop_with_default(self):
168
- """Test pop with default value."""
169
- # Pop non-existent key with default
170
- result = bst.environ.pop('nonexistent_pop', default='default_value')
171
- self.assertEqual(result, 'default_value')
172
-
173
- # Pop existing key ignores default
174
- bst.environ.set(existing_pop='value')
175
- result = bst.environ.pop('existing_pop', default='default')
176
- self.assertEqual(result, 'value')
177
-
178
- def test_pop_missing_key_error(self):
179
- """Test KeyError for missing keys without default."""
180
- with self.assertRaises(KeyError) as context:
181
- bst.environ.pop('missing_pop_key')
182
-
183
- error_msg = str(context.exception)
184
- self.assertIn('missing_pop_key', error_msg)
185
- self.assertIn('not found', error_msg)
186
-
187
- def test_pop_multiple_values(self):
188
- """Test popping multiple values."""
189
- # Set multiple values
190
- bst.environ.set(
191
- pop1='value1',
192
- pop2='value2',
193
- pop3='value3'
194
- )
195
-
196
- # Pop them one by one
197
- v1 = bst.environ.pop('pop1')
198
- v2 = bst.environ.pop('pop2')
199
-
200
- self.assertEqual(v1, 'value1')
201
- self.assertEqual(v2, 'value2')
202
-
203
- # pop3 should still exist
204
- self.assertEqual(bst.environ.get('pop3'), 'value3')
205
-
206
- # pop1 and pop2 should be gone
207
- self.assertIsNone(bst.environ.get('pop1', default=None))
208
- self.assertIsNone(bst.environ.get('pop2', default=None))
209
-
210
- def test_pop_with_context_protection(self):
211
- """Test that pop is prevented when key is in active context."""
212
- # Set a global value
213
- bst.environ.set(protected_key='global_value')
214
-
215
- # Cannot pop while in context
216
- with bst.environ.context(protected_key='context_value'):
217
- with self.assertRaises(ValueError) as context:
218
- bst.environ.pop('protected_key')
219
-
220
- error_msg = str(context.exception)
221
- self.assertIn('Cannot pop', error_msg)
222
- self.assertIn('active in a context', error_msg)
223
-
224
- # Can pop after context exits
225
- popped = bst.environ.pop('protected_key')
226
- self.assertEqual(popped, 'global_value')
227
-
228
- def test_pop_nested_context_protection(self):
229
- """Test pop protection with nested contexts."""
230
- bst.environ.set(nested_key='global')
231
-
232
- with bst.environ.context(nested_key='level1'):
233
- with bst.environ.context(nested_key='level2'):
234
- # Should indicate 2 active contexts
235
- with self.assertRaises(ValueError) as context:
236
- bst.environ.pop('nested_key')
237
-
238
- error_msg = str(context.exception)
239
- self.assertIn('2 context(s)', error_msg)
240
-
241
- def test_pop_does_not_affect_context_values(self):
242
- """Test that pop doesn't affect context values."""
243
- # Set both global and context value
244
- bst.environ.set(dual_key='global')
245
-
246
- with bst.environ.context(other_key='context_only'):
247
- # Can pop a key that's only in global (not in this context)
248
- popped = bst.environ.pop('dual_key')
249
- self.assertEqual(popped, 'global')
250
-
251
- # Context-only values remain accessible
252
- self.assertEqual(bst.environ.get('other_key'), 'context_only')
253
-
254
- # Context value should be gone after exit
255
- self.assertIsNone(bst.environ.get('other_key', default=None))
256
-
257
- def test_pop_precision_key(self):
258
- """Test popping the precision key."""
259
- # Set custom precision
260
- bst.environ.set(precision=64)
261
- self.assertEqual(bst.environ.get_precision(), 64)
262
-
263
- # Pop precision
264
- popped = bst.environ.pop('precision')
265
- self.assertEqual(popped, 64)
266
-
267
-
268
- class TestEnvironmentContext(unittest.TestCase):
269
- """Test context manager functionality."""
270
-
271
- def setUp(self):
272
- """Reset environment before each test."""
273
- bst.environ.reset()
274
- warnings.filterwarnings('ignore', category=UserWarning)
275
-
276
- def tearDown(self):
277
- """Clean up after each test."""
278
- bst.environ.reset()
279
- warnings.resetwarnings()
280
-
281
- def test_basic_context(self):
282
- """Test basic context manager usage."""
283
- bst.environ.set(value=10)
284
-
285
- with bst.environ.context(value=20) as ctx:
286
- # Value should be 20 in context
287
- self.assertEqual(bst.environ.get('value'), 20)
288
- # Context should contain current settings
289
- self.assertEqual(ctx['value'], 20)
290
-
291
- # Value should be restored to 10
292
- self.assertEqual(bst.environ.get('value'), 10)
293
-
294
- def test_nested_contexts(self):
295
- """Test nested context managers."""
296
- bst.environ.set(level=0)
297
-
298
- with bst.environ.context(level=1):
299
- self.assertEqual(bst.environ.get('level'), 1)
300
-
301
- with bst.environ.context(level=2):
302
- self.assertEqual(bst.environ.get('level'), 2)
303
-
304
- with bst.environ.context(level=3):
305
- self.assertEqual(bst.environ.get('level'), 3)
306
-
307
- # Back to level 2
308
- self.assertEqual(bst.environ.get('level'), 2)
309
-
310
- # Back to level 1
311
- self.assertEqual(bst.environ.get('level'), 1)
312
-
313
- # Back to level 0
314
- self.assertEqual(bst.environ.get('level'), 0)
315
-
316
- def test_context_with_exception(self):
317
- """Test context manager handles exceptions properly."""
318
- bst.environ.set(value='original')
319
-
320
- try:
321
- with bst.environ.context(value='temporary'):
322
- self.assertEqual(bst.environ.get('value'), 'temporary')
323
- raise ValueError("Test exception")
324
- except ValueError:
325
- pass
326
-
327
- # Value should be restored despite exception
328
- self.assertEqual(bst.environ.get('value'), 'original')
329
-
330
- def test_context_multiple_parameters(self):
331
- """Test context with multiple parameters."""
332
- bst.environ.set(a=1, b=2, c=3)
333
-
334
- with bst.environ.context(a=10, b=20, c=30, d=40):
335
- self.assertEqual(bst.environ.get('a'), 10)
336
- self.assertEqual(bst.environ.get('b'), 20)
337
- self.assertEqual(bst.environ.get('c'), 30)
338
- self.assertEqual(bst.environ.get('d'), 40)
339
-
340
- # Original values restored
341
- self.assertEqual(bst.environ.get('a'), 1)
342
- self.assertEqual(bst.environ.get('b'), 2)
343
- self.assertEqual(bst.environ.get('c'), 3)
344
- # d should not exist
345
- result = bst.environ.get('d', default=None)
346
- self.assertIsNone(result)
347
-
348
- def test_context_platform_restriction(self):
349
- """Test that platform cannot be set in context."""
350
- with self.assertRaises(ValueError) as context:
351
- with bst.environ.context(platform='cpu'):
352
- pass
353
-
354
- self.assertIn('platform', str(context.exception).lower())
355
- self.assertIn('cannot set', str(context.exception).lower())
356
-
357
- def test_context_host_device_count_restriction(self):
358
- """Test that host_device_count cannot be set in context."""
359
- with self.assertRaises(ValueError) as context:
360
- with bst.environ.context(host_device_count=4):
361
- pass
362
-
363
- self.assertIn('host_device_count', str(context.exception))
364
-
365
- def test_context_mode_validation(self):
366
- """Test mode validation in context."""
367
- # Valid mode
368
- mode = bst.mixin.Training()
369
- with bst.environ.context(mode=mode):
370
- self.assertEqual(bst.environ.get('mode'), mode)
371
-
372
- def test_context_preserves_unmodified_values(self):
373
- """Test that context doesn't affect unmodified values."""
374
- bst.environ.set(unchanged='original', changed='original')
375
-
376
- with bst.environ.context(changed='modified'):
377
- self.assertEqual(bst.environ.get('unchanged'), 'original')
378
- self.assertEqual(bst.environ.get('changed'), 'modified')
379
-
380
-
381
- class TestPrecisionAndDataTypes(unittest.TestCase):
382
- """Test precision control and data type functions."""
383
-
384
- def setUp(self):
385
- """Reset environment before each test."""
386
- bst.environ.reset()
387
- warnings.filterwarnings('ignore', category=UserWarning)
388
-
389
- def tearDown(self):
390
- """Clean up after each test."""
391
- bst.environ.reset()
392
- warnings.resetwarnings()
393
-
394
- def test_precision_settings(self):
395
- """Test different precision settings."""
396
- precisions = [8, 16, 32, 64, 'bf16']
397
-
398
- for precision in precisions:
399
- bst.environ.set(precision=precision)
400
-
401
- if precision == 'bf16':
402
- self.assertEqual(bst.environ.get_precision(), 16)
403
- elif isinstance(precision, str):
404
- self.assertEqual(bst.environ.get_precision(), int(precision))
405
- else:
406
- self.assertEqual(bst.environ.get_precision(), precision)
407
-
408
- def test_precision_context(self):
409
- """Test precision changes in context."""
410
- bst.environ.set(precision=32)
411
-
412
- with bst.environ.context(precision=64):
413
- a = bst.random.randn(1)
414
- self.assertEqual(a.dtype, jnp.float64)
415
- self.assertEqual(bst.environ.get_precision(), 64)
416
-
417
- # Precision restored
418
- b = bst.random.randn(1)
419
- self.assertEqual(b.dtype, jnp.float32)
420
- self.assertEqual(bst.environ.get_precision(), 32)
421
-
422
- def test_dftype_function(self):
423
- """Test default float type function."""
424
- # 32-bit precision
425
- bst.environ.set(precision=32)
426
- self.assertEqual(bst.environ.dftype(), np.float32)
427
-
428
- # 64-bit precision
429
- bst.environ.set(precision=64)
430
- self.assertEqual(bst.environ.dftype(), np.float64)
431
-
432
- # 16-bit precision
433
- bst.environ.set(precision=16)
434
- self.assertEqual(bst.environ.dftype(), np.float16)
435
-
436
- # bfloat16 precision
437
- bst.environ.set(precision='bf16')
438
- self.assertEqual(bst.environ.dftype(), jnp.bfloat16)
439
-
440
- def test_ditype_function(self):
441
- """Test default integer type function."""
442
- # 32-bit precision
443
- bst.environ.set(precision=32)
444
- self.assertEqual(bst.environ.ditype(), np.int32)
445
-
446
- # 64-bit precision
447
- bst.environ.set(precision=64)
448
- self.assertEqual(bst.environ.ditype(), np.int64)
449
-
450
- # 16-bit precision
451
- bst.environ.set(precision=16)
452
- self.assertEqual(bst.environ.ditype(), np.int16)
453
-
454
- # 8-bit precision
455
- bst.environ.set(precision=8)
456
- self.assertEqual(bst.environ.ditype(), np.int8)
457
-
458
- def test_dutype_function(self):
459
- """Test default unsigned integer type function."""
460
- # 32-bit precision
461
- bst.environ.set(precision=32)
462
- self.assertEqual(bst.environ.dutype(), np.uint32)
463
-
464
- # 64-bit precision
465
- bst.environ.set(precision=64)
466
- self.assertEqual(bst.environ.dutype(), np.uint64)
467
-
468
- # 16-bit precision
469
- bst.environ.set(precision=16)
470
- self.assertEqual(bst.environ.dutype(), np.uint16)
471
-
472
- # 8-bit precision
473
- bst.environ.set(precision=8)
474
- self.assertEqual(bst.environ.dutype(), np.uint8)
475
-
476
- def test_dctype_function(self):
477
- """Test default complex type function."""
478
- # 32-bit precision
479
- bst.environ.set(precision=32)
480
- self.assertEqual(bst.environ.dctype(), np.complex64)
481
-
482
- # 64-bit precision
483
- bst.environ.set(precision=64)
484
- self.assertEqual(bst.environ.dctype(), np.complex128)
485
-
486
- # 16-bit precision (should use complex64)
487
- bst.environ.set(precision=16)
488
- self.assertEqual(bst.environ.dctype(), np.complex64)
489
-
490
- def test_tolerance_function(self):
491
- """Test tolerance values for different precisions."""
492
- # 64-bit precision
493
- bst.environ.set(precision=64)
494
- tol = bst.environ.tolerance()
495
- self.assertAlmostEqual(float(tol), 1e-12, places=14)
496
-
497
- # 32-bit precision
498
- bst.environ.set(precision=32)
499
- tol = bst.environ.tolerance()
500
- self.assertAlmostEqual(float(tol), 1e-5, places=7)
501
-
502
- # 16-bit precision
503
- bst.environ.set(precision=16)
504
- tol = bst.environ.tolerance()
505
- self.assertAlmostEqual(float(tol), 1e-2, places=4)
506
-
507
- def test_invalid_precision(self):
508
- """Test invalid precision values."""
509
- invalid_precisions = [128, 'invalid', -1, 3.14]
510
-
511
- for invalid in invalid_precisions:
512
- with self.assertRaises(ValueError):
513
- bst.environ.set(precision=invalid)
514
-
515
- def test_precision_with_arrays(self):
516
- """Test that precision affects array creation."""
517
- # Test with different precisions
518
- test_cases = [
519
- (32, jnp.float32),
520
- (64, jnp.float64),
521
- (16, jnp.float16),
522
- ('bf16', jnp.bfloat16),
523
- ]
524
-
525
- for precision, expected_dtype in test_cases:
526
- with bst.environ.context(precision=precision):
527
- # Create array using random
528
- arr = bst.random.randn(5)
529
- self.assertEqual(arr.dtype, expected_dtype)
530
-
531
-
532
- class TestModeAndSpecialGetters(unittest.TestCase):
533
- """Test mode management and special getter functions."""
534
-
535
- def setUp(self):
536
- """Reset environment before each test."""
537
- bst.environ.reset()
538
- warnings.filterwarnings('ignore', category=UserWarning)
539
-
540
- def tearDown(self):
541
- """Clean up after each test."""
542
- bst.environ.reset()
543
- warnings.resetwarnings()
544
-
545
- def test_get_dt(self):
546
- """Test get_dt function."""
547
- # Set dt
548
- bst.environ.set(dt=0.01)
549
- self.assertEqual(bst.environ.get_dt(), 0.01)
550
-
551
- # Test in context
552
- with bst.environ.context(dt=0.001):
553
- self.assertEqual(bst.environ.get_dt(), 0.001)
554
-
555
- self.assertEqual(bst.environ.get_dt(), 0.01)
556
-
557
- # Test missing dt
558
- bst.environ.reset()
559
- with self.assertRaises(KeyError):
560
- bst.environ.get_dt()
561
-
562
- def test_get_mode(self):
563
- """Test get_mode function."""
564
- # Set training mode
565
- training = bst.mixin.Training()
566
- bst.environ.set(mode=training)
567
- mode = bst.environ.get('mode')
568
- self.assertEqual(mode, training)
569
- self.assertTrue(mode.has(bst.mixin.Training))
570
-
571
- # Test with batching mode
572
- batching = bst.mixin.Batching(batch_size=32)
573
- with bst.environ.context(mode=batching):
574
- mode = bst.environ.get('mode')
575
- self.assertEqual(mode, batching)
576
- self.assertTrue(mode.has(bst.mixin.Batching))
577
- self.assertEqual(mode.batch_size, 32)
578
-
579
- # Test missing mode
580
- bst.environ.reset()
581
- with self.assertRaises(KeyError):
582
- bst.environ.get('mode')
583
-
584
- def test_get_platform(self):
585
- """Test get_platform function."""
586
- platform = bst.environ.get_platform()
587
- self.assertIn(platform, bst.environ.SUPPORTED_PLATFORMS)
588
-
589
- def test_get_host_device_count(self):
590
- """Test get_host_device_count function."""
591
- count = bst.environ.get_host_device_count()
592
- self.assertIsInstance(count, int)
593
- self.assertGreaterEqual(count, 1)
594
-
595
- def test_dt_validation(self):
596
- """Test dt validation in set function."""
597
- # Valid dt values
598
- valid_dts = [0.01, 0.001, 1.0, 0.1]
599
- for dt in valid_dts:
600
- bst.environ.set(dt=dt)
601
- self.assertEqual(bst.environ.get_dt(), dt)
602
-
603
-
604
- class TestPlatformAndDevice(unittest.TestCase):
605
- """Test platform and device management."""
606
-
607
- def setUp(self):
608
- """Reset environment before each test."""
609
- bst.environ.reset()
610
- warnings.filterwarnings('ignore', category=UserWarning)
611
-
612
- def tearDown(self):
613
- """Clean up after each test."""
614
- bst.environ.reset()
615
- warnings.resetwarnings()
616
-
617
- @patch('brainstate.environ.config')
618
- def test_set_platform(self, mock_config):
619
- """Test platform setting."""
620
- platforms = ['cpu', 'gpu', 'tpu']
621
-
622
- for platform in platforms:
623
- bst.environ.set_platform(platform)
624
- mock_config.update.assert_called_with("jax_platform_name", platform)
625
-
626
- # Test invalid platform
627
- with self.assertRaises(ValueError):
628
- bst.environ.set_platform('invalid')
629
-
630
- def test_set_platform_through_set(self):
631
- """Test setting platform through general set function."""
632
- with patch('brainstate.environ.config') as mock_config:
633
- bst.environ.set(platform='gpu')
634
- mock_config.update.assert_called_with("jax_platform_name", 'gpu')
635
-
636
- def test_set_host_device_count(self):
637
- """Test host device count setting."""
638
- import os
639
-
640
- # Set device count
641
- bst.environ.set_host_device_count(4)
642
- xla_flags = os.environ.get("XLA_FLAGS", "")
643
- self.assertIn("--xla_force_host_platform_device_count=4", xla_flags)
644
-
645
- # Update device count
646
- bst.environ.set_host_device_count(8)
647
- xla_flags = os.environ.get("XLA_FLAGS", "")
648
- self.assertIn("--xla_force_host_platform_device_count=8", xla_flags)
649
- self.assertNotIn("--xla_force_host_platform_device_count=4", xla_flags)
650
-
651
- # Invalid device count
652
- with self.assertRaises(ValueError):
653
- bst.environ.set_host_device_count(0)
654
-
655
- with self.assertRaises(ValueError):
656
- bst.environ.set_host_device_count(-1)
657
-
658
- def test_platform_context_restriction(self):
659
- """Test that platform cannot be changed in context."""
660
- with self.assertRaises(ValueError):
661
- with bst.environ.context(platform='cpu'):
662
- pass
663
-
664
-
665
- class TestCallbackBehavior(unittest.TestCase):
666
- """Test callback registration and behavior."""
667
-
668
- def setUp(self):
669
- """Reset environment before each test."""
670
- bst.environ.reset()
671
- warnings.filterwarnings('ignore', category=UserWarning)
672
- self.callback_values = []
673
-
674
- def tearDown(self):
675
- """Clean up after each test."""
676
- bst.environ.reset()
677
- warnings.resetwarnings()
678
-
679
- # def test_register_callback(self):
680
- # """Test basic callback registration."""
681
- # def callback(value):
682
- # self.callback_values.append(value)
683
- #
684
- # brainstate.environ.register_default_behavior('test_param', callback)
685
- #
686
- # # Callback should be triggered on set
687
- # brainstate.environ.set(test_param='value1')
688
- # self.assertEqual(self.callback_values, ['value1'])
689
- #
690
- # # Callback should be triggered on context enter/exit
691
- # with brainstate.environ.context(test_param='value2'):
692
- # self.assertEqual(self.callback_values, ['value1', 'value2'])
693
- #
694
- # # Should restore previous value
695
- # self.assertEqual(self.callback_values, ['value1', 'value2', 'value1'])
696
-
697
- def test_register_multiple_callbacks(self):
698
- """Test registering callbacks for different keys."""
699
- values_a = []
700
- values_b = []
701
-
702
- def callback_a(value):
703
- values_a.append(value)
704
-
705
- def callback_b(value):
706
- values_b.append(value)
707
-
708
- bst.environ.register_default_behavior('param_a', callback_a)
709
- bst.environ.register_default_behavior('param_b', callback_b)
710
-
711
- bst.environ.set(param_a='a1', param_b='b1')
712
- self.assertEqual(values_a, ['a1'])
713
- self.assertEqual(values_b, ['b1'])
714
-
715
- def test_replace_callback(self):
716
- """Test replacing existing callbacks."""
717
-
718
- def callback1(value):
719
- self.callback_values.append(f'cb1:{value}')
720
-
721
- def callback2(value):
722
- self.callback_values.append(f'cb2:{value}')
723
-
724
- # Register first callback
725
- bst.environ.register_default_behavior('param', callback1)
726
-
727
- # Try to register second without replace flag
728
- with self.assertRaises(ValueError):
729
- bst.environ.register_default_behavior('param', callback2)
730
-
731
- # Register with replace flag
732
- bst.environ.register_default_behavior('param', callback2, replace_if_exist=True)
733
-
734
- # Only second callback should be called
735
- bst.environ.set(param='test')
736
- self.assertEqual(self.callback_values, ['cb2:test'])
737
-
738
- def test_unregister_callback(self):
739
- """Test unregistering callbacks."""
740
-
741
- def callback(value):
742
- self.callback_values.append(value)
743
-
744
- # Register and test
745
- bst.environ.register_default_behavior('param', callback)
746
- bst.environ.set(param='value1')
747
- self.assertEqual(len(self.callback_values), 1)
748
-
749
- # Unregister
750
- removed = bst.environ.unregister_default_behavior('param')
751
- self.assertTrue(removed)
752
-
753
- # Callback should not be triggered
754
- bst.environ.set(param='value2')
755
- self.assertEqual(len(self.callback_values), 1) # Still just one
756
-
757
- # Unregister non-existent
758
- removed = bst.environ.unregister_default_behavior('nonexistent')
759
- self.assertFalse(removed)
760
-
761
- def test_list_registered_behaviors(self):
762
- """Test listing registered behaviors."""
763
- # Initially empty or with system defaults
764
- initial = bst.environ.list_registered_behaviors()
765
-
766
- # Register some behaviors
767
- bst.environ.register_default_behavior('param1', lambda x: None)
768
- bst.environ.register_default_behavior('param2', lambda x: None)
769
- bst.environ.register_default_behavior('param3', lambda x: None)
770
-
771
- behaviors = bst.environ.list_registered_behaviors()
772
- for param in ['param1', 'param2', 'param3']:
773
- self.assertIn(param, behaviors)
774
-
775
- def test_callback_exception_handling(self):
776
- """Test that exceptions in callbacks are handled gracefully."""
777
-
778
- def failing_callback(value):
779
- raise RuntimeError(f"Intentional error: {value}")
780
-
781
- bst.environ.register_default_behavior('param', failing_callback)
782
-
783
- # Should not crash, but should warn
784
- with warnings.catch_warnings(record=True) as w:
785
- warnings.simplefilter("always")
786
- bst.environ.set(param='test')
787
-
788
- # Should have a warning
789
- self.assertTrue(len(w) > 0)
790
- self.assertIn('Callback', str(w[0].message))
791
- self.assertIn('exception', str(w[0].message))
792
-
793
- def test_callback_validation(self):
794
- """Test callback validation."""
795
- # Non-callable
796
- with self.assertRaises(TypeError):
797
- bst.environ.register_default_behavior('param', 'not_callable')
798
-
799
- # Non-string key
800
- with self.assertRaises(TypeError):
801
- bst.environ.register_default_behavior(123, lambda x: None)
802
-
803
- def test_callback_with_validation(self):
804
- """Test using callbacks for validation."""
805
-
806
- def validate_positive(value):
807
- if value <= 0:
808
- raise ValueError(f"Value must be positive, got {value}")
809
- self.callback_values.append(value)
810
-
811
- bst.environ.register_default_behavior('positive_param', validate_positive)
812
-
813
- # Valid value
814
- bst.environ.set(positive_param=10)
815
- self.assertEqual(self.callback_values, [10])
816
-
817
- # Invalid value should raise through warning system
818
- with warnings.catch_warnings(record=True):
819
- warnings.simplefilter("always")
820
- bst.environ.set(positive_param=-5)
821
-
822
-
823
- class TestThreadSafety(unittest.TestCase):
824
- """Test thread safety of environment operations."""
825
-
826
- def setUp(self):
827
- """Reset environment before each test."""
828
- bst.environ.reset()
829
- warnings.filterwarnings('ignore', category=UserWarning)
830
-
831
- def tearDown(self):
832
- """Clean up after each test."""
833
- bst.environ.reset()
834
- warnings.resetwarnings()
835
-
836
- def test_concurrent_set_operations(self):
837
- """Test concurrent set operations from multiple threads."""
838
- results = []
839
- errors = []
840
-
841
- def thread_operation(thread_id):
842
- try:
843
- # Each thread sets its own value
844
- for i in range(10):
845
- bst.environ.set(**{f'thread_{thread_id}': i})
846
- value = bst.environ.get(f'thread_{thread_id}')
847
- results.append((thread_id, value))
848
- except Exception as e:
849
- errors.append(e)
850
-
851
- threads = []
852
- for i in range(5):
853
- thread = threading.Thread(target=thread_operation, args=(i,))
854
- threads.append(thread)
855
- thread.start()
856
-
857
- for thread in threads:
858
- thread.join()
859
-
860
- # Should have no errors
861
- self.assertEqual(len(errors), 0)
862
-
863
- # Each thread should have written its values
864
- for i in range(5):
865
- try:
866
- final_value = bst.environ.get(f'thread_{i}')
867
- except KeyError:
868
- pass
869
-
870
- def test_concurrent_context_operations(self):
871
- """Test concurrent context operations from multiple threads."""
872
- results = []
873
- errors = []
874
-
875
- def thread_context_operation(thread_id):
876
- try:
877
- bst.environ.set(**{f'base_{thread_id}': 0})
878
-
879
- for i in range(5):
880
- with bst.environ.context(**{f'base_{thread_id}': i}):
881
- value = bst.environ.get(f'base_{thread_id}')
882
- results.append((thread_id, value))
883
-
884
- # Should be back to 0
885
- final = bst.environ.get(f'base_{thread_id}')
886
- self.assertEqual(final, 0)
887
- except Exception as e:
888
- errors.append(e)
889
-
890
- threads = []
891
- for i in range(3):
892
- thread = threading.Thread(target=thread_context_operation, args=(i,))
893
- threads.append(thread)
894
- thread.start()
895
-
896
- for thread in threads:
897
- thread.join()
898
-
899
- # Should have no errors
900
- self.assertEqual(len(errors), 0)
901
-
902
- def test_concurrent_pop_operations(self):
903
- """Test concurrent pop operations from multiple threads."""
904
- # Set up multiple keys
905
- for i in range(20):
906
- bst.environ.set(**{f'pop_thread_{i}': f'value_{i}'})
907
-
908
- results = []
909
- errors = []
910
-
911
- def thread_pop_operation(start, end):
912
- try:
913
- for i in range(start, end):
914
- try:
915
- value = bst.environ.pop(f'pop_thread_{i}')
916
- results.append((i, value))
917
- except KeyError:
918
- # Key might already be popped by another thread
919
- pass
920
- except Exception as e:
921
- errors.append(e)
922
-
923
- # Create threads that pop different ranges
924
- threads = []
925
- ranges = [(0, 5), (5, 10), (10, 15), (15, 20)]
926
- for start, end in ranges:
927
- thread = threading.Thread(target=thread_pop_operation, args=(start, end))
928
- threads.append(thread)
929
- thread.start()
930
-
931
- for thread in threads:
932
- thread.join()
933
-
934
- # Should have no errors
935
- self.assertEqual(len(errors), 0)
936
-
937
- # # All keys should be popped (each exactly once)
938
- # popped_indices = [r[0] for r in results]
939
- # self.assertEqual(len(popped_indices), 20)
940
- # self.assertEqual(len(set(popped_indices)), 20) # All unique
941
- #
942
- # # All values should be gone
943
- # for i in range(20):
944
- # result = brainstate.environ.get(f'pop_thread_{i}', default=None)
945
- # self.assertIsNone(result)
946
-
947
-
948
- class TestEdgeCases(unittest.TestCase):
949
- """Test edge cases and boundary conditions."""
950
-
951
- def setUp(self):
952
- """Reset environment before each test."""
953
- bst.environ.reset()
954
- warnings.filterwarnings('ignore', category=UserWarning)
955
-
956
- def tearDown(self):
957
- """Clean up after each test."""
958
- bst.environ.reset()
959
- warnings.resetwarnings()
960
-
961
- def test_empty_context(self):
962
- """Test context with no parameters."""
963
- original = bst.environ.all()
964
-
965
- with bst.environ.context() as ctx:
966
- # Should be unchanged
967
- self.assertEqual(ctx, original)
968
-
969
- self.assertEqual(bst.environ.all(), original)
970
-
971
- def test_none_values(self):
972
- """Test handling of None values."""
973
- bst.environ.set(none_param=None)
974
- self.assertIsNone(bst.environ.get('none_param'))
975
-
976
- with bst.environ.context(none_param='not_none'):
977
- self.assertEqual(bst.environ.get('none_param'), 'not_none')
978
-
979
- self.assertIsNone(bst.environ.get('none_param'))
980
-
981
- def test_complex_data_types(self):
982
- """Test storing complex data types."""
983
- # Lists
984
- bst.environ.set(list_param=[1, 2, 3])
985
- self.assertEqual(bst.environ.get('list_param'), [1, 2, 3])
986
-
987
- # Dictionaries
988
- bst.environ.set(dict_param={'a': 1, 'b': 2})
989
- self.assertEqual(bst.environ.get('dict_param'), {'a': 1, 'b': 2})
990
-
991
- # Tuples
992
- bst.environ.set(tuple_param=(1, 2, 3))
993
- self.assertEqual(bst.environ.get('tuple_param'), (1, 2, 3))
994
-
995
- # Custom objects
996
- class CustomObject:
997
- def __init__(self, value):
998
- self.value = value
999
-
1000
- obj = CustomObject(42)
1001
- bst.environ.set(obj_param=obj)
1002
- retrieved = bst.environ.get('obj_param')
1003
- self.assertIs(retrieved, obj)
1004
- self.assertEqual(retrieved.value, 42)
1005
-
1006
- def test_special_string_values(self):
1007
- """Test special string values."""
1008
- special_strings = ['', ' ', '\n', '\t', 'None', 'True', 'False']
1009
-
1010
- for s in special_strings:
1011
- bst.environ.set(string_param=s)
1012
- self.assertEqual(bst.environ.get('string_param'), s)
1013
-
1014
- def test_numeric_edge_values(self):
1015
- """Test numeric edge values."""
1016
- import sys
1017
-
1018
- edge_values = [
1019
- 0, -0, 1, -1,
1020
- sys.maxsize, -sys.maxsize,
1021
- float('inf'), float('-inf'),
1022
- 1e-100, 1e100,
1023
- ]
1024
-
1025
- for value in edge_values:
1026
- bst.environ.set(numeric_param=value)
1027
- retrieved = bst.environ.get('numeric_param')
1028
- if value != value: # NaN check
1029
- self.assertTrue(retrieved != retrieved)
1030
- else:
1031
- self.assertEqual(retrieved, value)
1032
-
1033
- def test_context_all_interaction(self):
1034
- """Test interaction between context and all() function."""
1035
- bst.environ.set(global_param='global')
1036
-
1037
- with bst.environ.context(context_param='context', global_param='override'):
1038
- all_values = bst.environ.all()
1039
-
1040
- # Should include both
1041
- self.assertEqual(all_values['global_param'], 'override')
1042
- self.assertEqual(all_values['context_param'], 'context')
1043
-
1044
- # Original global values should be in settings
1045
- self.assertIn('precision', all_values)
1046
-
1047
- def test_deeply_nested_contexts(self):
1048
- """Test deeply nested contexts."""
1049
- depth = 20
1050
- bst.environ.set(depth=0)
1051
-
1052
- def nested_context(level):
1053
- if level < depth:
1054
- with bst.environ.context(depth=level):
1055
- self.assertEqual(bst.environ.get('depth'), level)
1056
- nested_context(level + 1)
1057
- self.assertEqual(bst.environ.get('depth'), level)
1058
-
1059
- nested_context(1)
1060
- self.assertEqual(bst.environ.get('depth'), 0)
1061
-
1062
- def test_set_precision_function(self):
1063
- """Test the dedicated set_precision function."""
1064
- # Valid precisions
1065
- for precision in [8, 16, 32, 64, 'bf16']:
1066
- bst.environ.set_precision(precision)
1067
- self.assertEqual(bst.environ.get('precision'), precision)
1068
-
1069
- # Invalid precision
1070
- with self.assertRaises(ValueError):
1071
- bst.environ.set_precision(128)
1072
-
1073
- def test_pop_edge_cases(self):
1074
- """Test edge cases for pop function."""
1075
- # Pop with None value
1076
- bst.environ.set(none_key=None)
1077
- popped = bst.environ.pop('none_key')
1078
- self.assertIsNone(popped)
1079
-
1080
- # Pop with None as default
1081
- result = bst.environ.pop('missing_key', default=None)
1082
- self.assertIsNone(result)
1083
-
1084
- # Pop complex data types
1085
- complex_obj = {'nested': {'data': [1, 2, 3]}}
1086
- bst.environ.set(complex_key=complex_obj)
1087
- popped = bst.environ.pop('complex_key')
1088
- self.assertEqual(popped, complex_obj)
1089
-
1090
- # Verify object identity preservation
1091
- obj = object()
1092
- bst.environ.set(obj_key=obj)
1093
- popped = bst.environ.pop('obj_key')
1094
- self.assertIs(popped, obj)
1095
-
1096
- def test_pop_all_interaction(self):
1097
- """Test interaction between pop and all() function."""
1098
- # Set multiple values
1099
- bst.environ.set(a=1, b=2, c=3, d=4)
1100
- initial_all = bst.environ.all()
1101
-
1102
- # Pop some values
1103
- bst.environ.pop('b')
1104
- bst.environ.pop('d')
1105
-
1106
- # Check all() reflects the changes
1107
- after_pop = bst.environ.all()
1108
- self.assertIn('a', after_pop)
1109
- self.assertIn('c', after_pop)
1110
- self.assertNotIn('b', after_pop)
1111
- self.assertNotIn('d', after_pop)
1112
-
1113
- def test_pop_callback_not_triggered(self):
1114
- """Test that callbacks are not triggered on pop."""
1115
- callback_calls = []
1116
-
1117
- def callback(value):
1118
- callback_calls.append(value)
1119
-
1120
- # Register callback
1121
- bst.environ.register_default_behavior('callback_test', callback)
1122
-
1123
- # Set triggers callback
1124
- bst.environ.set(callback_test='value')
1125
- self.assertEqual(len(callback_calls), 1)
1126
-
1127
- # Pop should NOT trigger callback
1128
- popped = bst.environ.pop('callback_test')
1129
- self.assertEqual(len(callback_calls), 1) # Still just 1
1130
- self.assertEqual(popped, 'value')
1131
-
1132
- # Unregister callback
1133
- bst.environ.unregister_default_behavior('callback_test')
1134
-
1135
-
1136
- class TestIntegration(unittest.TestCase):
1137
- """Integration tests with actual BrainState functionality."""
1138
-
1139
- def setUp(self):
1140
- """Reset environment before each test."""
1141
- bst.environ.reset()
1142
- warnings.filterwarnings('ignore', category=UserWarning)
1143
-
1144
- def tearDown(self):
1145
- """Clean up after each test."""
1146
- bst.environ.reset()
1147
- warnings.resetwarnings()
1148
-
1149
- def test_precision_affects_random_arrays(self):
1150
- """Test that precision setting affects random array generation."""
1151
- # Test different precisions
1152
- test_cases = [
1153
- (32, jnp.float32),
1154
- (64, jnp.float64),
1155
- (16, jnp.float16),
1156
- ('bf16', jnp.bfloat16),
1157
- ]
1158
-
1159
- for precision, expected_dtype in test_cases:
1160
- with bst.environ.context(precision=precision):
1161
- arr = bst.random.randn(10)
1162
- self.assertEqual(arr.dtype, expected_dtype)
1163
-
1164
- def test_mode_usage(self):
1165
- """Test mode usage in computations."""
1166
- # Create different modes
1167
- training = bst.mixin.Training()
1168
- batching = bst.mixin.Batching(batch_size=32)
1169
-
1170
- # Test training mode
1171
- bst.environ.set(mode=training)
1172
- mode = bst.environ.get('mode')
1173
- self.assertTrue(mode.has(bst.mixin.Training))
1174
-
1175
- # Test batching mode
1176
- with bst.environ.context(mode=batching):
1177
- mode = bst.environ.get('mode')
1178
- self.assertTrue(mode.has(bst.mixin.Batching))
1179
- self.assertEqual(mode.batch_size, 32)
1180
-
1181
- def test_dt_in_numerical_integration(self):
1182
- """Test dt usage in numerical contexts."""
1183
- # Set different dt values
1184
- dt_values = [0.01, 0.001, 0.1]
1185
-
1186
- for dt in dt_values:
1187
- bst.environ.set(dt=dt)
1188
- retrieved_dt = bst.environ.get_dt()
1189
- self.assertEqual(retrieved_dt, dt)
1190
-
1191
- # Simulate using dt in computation
1192
- time_steps = int(1.0 / dt)
1193
- self.assertGreater(time_steps, 0)
1194
-
1195
- def test_combined_settings(self):
1196
- """Test combining multiple settings."""
1197
- # Set multiple parameters
1198
- bst.environ.set(
1199
- precision=64,
1200
- dt=0.01,
1201
- mode=bst.mixin.Training(),
1202
- custom_param='test',
1203
- debug=True
1204
- )
1205
-
1206
- # Verify all are set
1207
- self.assertEqual(bst.environ.get_precision(), 64)
1208
- self.assertEqual(bst.environ.get_dt(), 0.01)
1209
- self.assertTrue(bst.environ.get('mode').has(bst.mixin.Training))
1210
- self.assertEqual(bst.environ.get('custom_param'), 'test')
1211
- self.assertTrue(bst.environ.get('debug'))
1212
-
1213
- # Test in nested contexts
1214
- with bst.environ.context(precision=32, debug=False):
1215
- self.assertEqual(bst.environ.get_precision(), 32)
1216
- self.assertFalse(bst.environ.get('debug'))
1217
- # Others unchanged
1218
- self.assertEqual(bst.environ.get_dt(), 0.01)
1219
- self.assertEqual(bst.environ.get('custom_param'), 'test')
1220
-
1221
-
1222
- if __name__ == '__main__':
1223
- unittest.main()
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive test suite for the environ module.
18
+
19
+ This test module provides extensive coverage of the environment configuration
20
+ and context management functionality, including:
21
+ - Global environment settings
22
+ - Context-based temporary settings
23
+ - Precision and data type management
24
+ - Callback registration and behavior
25
+ - Thread safety
26
+ - Error handling and validation
27
+ """
28
+
29
+ import threading
30
+ import unittest
31
+ import warnings
32
+ from unittest.mock import patch
33
+
34
+ import jax.numpy as jnp
35
+ import numpy as np
36
+
37
+ import brainstate as bst
38
+
39
+
40
+ class TestEnvironmentCore(unittest.TestCase):
41
+ """Test core environment management functionality."""
42
+
43
+ def setUp(self):
44
+ """Reset environment before each test."""
45
+ bst.environ.reset()
46
+ # Clear any warnings
47
+ warnings.filterwarnings('ignore', category=UserWarning)
48
+
49
+ def tearDown(self):
50
+ """Clean up after each test."""
51
+ # Reset to default state
52
+ bst.environ.reset()
53
+ warnings.resetwarnings()
54
+
55
+ def test_get_set_basic(self):
56
+ """Test basic get and set operations."""
57
+ # Set a value
58
+ bst.environ.set(test_param='test_value')
59
+ self.assertEqual(bst.environ.get('test_param'), 'test_value')
60
+
61
+ # Set multiple values
62
+ bst.environ.set(param1=1, param2='two', param3=3.0)
63
+ self.assertEqual(bst.environ.get('param1'), 1)
64
+ self.assertEqual(bst.environ.get('param2'), 'two')
65
+ self.assertEqual(bst.environ.get('param3'), 3.0)
66
+
67
+ def test_get_with_default(self):
68
+ """Test get with default value."""
69
+ # Non-existent key with default
70
+ result = bst.environ.get('nonexistent', default='default_value')
71
+ self.assertEqual(result, 'default_value')
72
+
73
+ # Existing key ignores default
74
+ bst.environ.set(existing='value')
75
+ result = bst.environ.get('existing', default='default')
76
+ self.assertEqual(result, 'value')
77
+
78
+ def test_get_missing_key_error(self):
79
+ """Test KeyError for missing keys without default."""
80
+ with self.assertRaises(KeyError) as context:
81
+ bst.environ.get('missing_key')
82
+
83
+ error_msg = str(context.exception)
84
+ self.assertIn('missing_key', error_msg)
85
+ self.assertIn('not found', error_msg)
86
+
87
+ def test_get_with_description(self):
88
+ """Test get with description for error messages."""
89
+ with self.assertRaises(KeyError) as context:
90
+ bst.environ.get('missing', desc='Important parameter for computation')
91
+
92
+ error_msg = str(context.exception)
93
+ self.assertIn('Important parameter', error_msg)
94
+
95
+ def test_all_function(self):
96
+ """Test getting all environment settings."""
97
+ # Set various parameters
98
+ bst.environ.set(
99
+ param1='value1',
100
+ param2=42,
101
+ param3=3.14,
102
+ precision=32
103
+ )
104
+
105
+ all_settings = bst.environ.all()
106
+ self.assertIsInstance(all_settings, dict)
107
+ self.assertEqual(all_settings['param1'], 'value1')
108
+ self.assertEqual(all_settings['param2'], 42)
109
+ self.assertEqual(all_settings['param3'], 3.14)
110
+ self.assertEqual(all_settings['precision'], 32)
111
+
112
+ def test_reset_function(self):
113
+ """Test environment reset functionality."""
114
+ # Set custom values
115
+ bst.environ.set(
116
+ custom1='value1',
117
+ custom2='value2',
118
+ precision=64
119
+ )
120
+
121
+ # Verify they're set
122
+ self.assertEqual(bst.environ.get('custom1'), 'value1')
123
+
124
+ # Reset environment
125
+ with warnings.catch_warnings():
126
+ warnings.simplefilter("ignore")
127
+ bst.environ.reset()
128
+
129
+ # Custom values should be gone
130
+ result = bst.environ.get('custom1', default=None)
131
+ self.assertIsNone(result)
132
+
133
+ # Default precision should be restored
134
+ self.assertEqual(bst.environ.get('precision'), bst.environ.DEFAULT_PRECISION)
135
+
136
+ def test_special_environment_keys(self):
137
+ """Test special environment key constants."""
138
+ # Test setting using constants
139
+ bst.environ.set(**{
140
+ bst.environ.DT: 0.01,
141
+ bst.environ.I: 0,
142
+ bst.environ.T: 0.0,
143
+ bst.environ.JIT_ERROR_CHECK: True,
144
+ bst.environ.FIT: False
145
+ })
146
+
147
+ self.assertEqual(bst.environ.get(bst.environ.DT), 0.01)
148
+ self.assertEqual(bst.environ.get(bst.environ.I), 0)
149
+ self.assertEqual(bst.environ.get(bst.environ.T), 0.0)
150
+ self.assertTrue(bst.environ.get(bst.environ.JIT_ERROR_CHECK))
151
+ self.assertFalse(bst.environ.get(bst.environ.FIT))
152
+
153
+ def test_pop_basic(self):
154
+ """Test basic pop operation."""
155
+ # Set a value
156
+ bst.environ.set(pop_test='test_value')
157
+ self.assertEqual(bst.environ.get('pop_test'), 'test_value')
158
+
159
+ # Pop the value
160
+ popped = bst.environ.pop('pop_test')
161
+ self.assertEqual(popped, 'test_value')
162
+
163
+ # Value should be gone
164
+ result = bst.environ.get('pop_test', default=None)
165
+ self.assertIsNone(result)
166
+
167
+ def test_pop_with_default(self):
168
+ """Test pop with default value."""
169
+ # Pop non-existent key with default
170
+ result = bst.environ.pop('nonexistent_pop', default='default_value')
171
+ self.assertEqual(result, 'default_value')
172
+
173
+ # Pop existing key ignores default
174
+ bst.environ.set(existing_pop='value')
175
+ result = bst.environ.pop('existing_pop', default='default')
176
+ self.assertEqual(result, 'value')
177
+
178
+ def test_pop_missing_key_error(self):
179
+ """Test KeyError for missing keys without default."""
180
+ with self.assertRaises(KeyError) as context:
181
+ bst.environ.pop('missing_pop_key')
182
+
183
+ error_msg = str(context.exception)
184
+ self.assertIn('missing_pop_key', error_msg)
185
+ self.assertIn('not found', error_msg)
186
+
187
+ def test_pop_multiple_values(self):
188
+ """Test popping multiple values."""
189
+ # Set multiple values
190
+ bst.environ.set(
191
+ pop1='value1',
192
+ pop2='value2',
193
+ pop3='value3'
194
+ )
195
+
196
+ # Pop them one by one
197
+ v1 = bst.environ.pop('pop1')
198
+ v2 = bst.environ.pop('pop2')
199
+
200
+ self.assertEqual(v1, 'value1')
201
+ self.assertEqual(v2, 'value2')
202
+
203
+ # pop3 should still exist
204
+ self.assertEqual(bst.environ.get('pop3'), 'value3')
205
+
206
+ # pop1 and pop2 should be gone
207
+ self.assertIsNone(bst.environ.get('pop1', default=None))
208
+ self.assertIsNone(bst.environ.get('pop2', default=None))
209
+
210
+ def test_pop_with_context_protection(self):
211
+ """Test that pop is prevented when key is in active context."""
212
+ # Set a global value
213
+ bst.environ.set(protected_key='global_value')
214
+
215
+ # Cannot pop while in context
216
+ with bst.environ.context(protected_key='context_value'):
217
+ with self.assertRaises(ValueError) as context:
218
+ bst.environ.pop('protected_key')
219
+
220
+ error_msg = str(context.exception)
221
+ self.assertIn('Cannot pop', error_msg)
222
+ self.assertIn('active in a context', error_msg)
223
+
224
+ # Can pop after context exits
225
+ popped = bst.environ.pop('protected_key')
226
+ self.assertEqual(popped, 'global_value')
227
+
228
+ def test_pop_nested_context_protection(self):
229
+ """Test pop protection with nested contexts."""
230
+ bst.environ.set(nested_key='global')
231
+
232
+ with bst.environ.context(nested_key='level1'):
233
+ with bst.environ.context(nested_key='level2'):
234
+ # Should indicate 2 active contexts
235
+ with self.assertRaises(ValueError) as context:
236
+ bst.environ.pop('nested_key')
237
+
238
+ error_msg = str(context.exception)
239
+ self.assertIn('2 context(s)', error_msg)
240
+
241
+ def test_pop_does_not_affect_context_values(self):
242
+ """Test that pop doesn't affect context values."""
243
+ # Set both global and context value
244
+ bst.environ.set(dual_key='global')
245
+
246
+ with bst.environ.context(other_key='context_only'):
247
+ # Can pop a key that's only in global (not in this context)
248
+ popped = bst.environ.pop('dual_key')
249
+ self.assertEqual(popped, 'global')
250
+
251
+ # Context-only values remain accessible
252
+ self.assertEqual(bst.environ.get('other_key'), 'context_only')
253
+
254
+ # Context value should be gone after exit
255
+ self.assertIsNone(bst.environ.get('other_key', default=None))
256
+
257
+ def test_pop_precision_key(self):
258
+ """Test popping the precision key."""
259
+ # Set custom precision
260
+ bst.environ.set(precision=64)
261
+ self.assertEqual(bst.environ.get_precision(), 64)
262
+
263
+ # Pop precision
264
+ popped = bst.environ.pop('precision')
265
+ self.assertEqual(popped, 64)
266
+
267
+
268
+ class TestEnvironmentContext(unittest.TestCase):
269
+ """Test context manager functionality."""
270
+
271
+ def setUp(self):
272
+ """Reset environment before each test."""
273
+ bst.environ.reset()
274
+ warnings.filterwarnings('ignore', category=UserWarning)
275
+
276
+ def tearDown(self):
277
+ """Clean up after each test."""
278
+ bst.environ.reset()
279
+ warnings.resetwarnings()
280
+
281
+ def test_basic_context(self):
282
+ """Test basic context manager usage."""
283
+ bst.environ.set(value=10)
284
+
285
+ with bst.environ.context(value=20) as ctx:
286
+ # Value should be 20 in context
287
+ self.assertEqual(bst.environ.get('value'), 20)
288
+ # Context should contain current settings
289
+ self.assertEqual(ctx['value'], 20)
290
+
291
+ # Value should be restored to 10
292
+ self.assertEqual(bst.environ.get('value'), 10)
293
+
294
+ def test_nested_contexts(self):
295
+ """Test nested context managers."""
296
+ bst.environ.set(level=0)
297
+
298
+ with bst.environ.context(level=1):
299
+ self.assertEqual(bst.environ.get('level'), 1)
300
+
301
+ with bst.environ.context(level=2):
302
+ self.assertEqual(bst.environ.get('level'), 2)
303
+
304
+ with bst.environ.context(level=3):
305
+ self.assertEqual(bst.environ.get('level'), 3)
306
+
307
+ # Back to level 2
308
+ self.assertEqual(bst.environ.get('level'), 2)
309
+
310
+ # Back to level 1
311
+ self.assertEqual(bst.environ.get('level'), 1)
312
+
313
+ # Back to level 0
314
+ self.assertEqual(bst.environ.get('level'), 0)
315
+
316
+ def test_context_with_exception(self):
317
+ """Test context manager handles exceptions properly."""
318
+ bst.environ.set(value='original')
319
+
320
+ try:
321
+ with bst.environ.context(value='temporary'):
322
+ self.assertEqual(bst.environ.get('value'), 'temporary')
323
+ raise ValueError("Test exception")
324
+ except ValueError:
325
+ pass
326
+
327
+ # Value should be restored despite exception
328
+ self.assertEqual(bst.environ.get('value'), 'original')
329
+
330
+ def test_context_multiple_parameters(self):
331
+ """Test context with multiple parameters."""
332
+ bst.environ.set(a=1, b=2, c=3)
333
+
334
+ with bst.environ.context(a=10, b=20, c=30, d=40):
335
+ self.assertEqual(bst.environ.get('a'), 10)
336
+ self.assertEqual(bst.environ.get('b'), 20)
337
+ self.assertEqual(bst.environ.get('c'), 30)
338
+ self.assertEqual(bst.environ.get('d'), 40)
339
+
340
+ # Original values restored
341
+ self.assertEqual(bst.environ.get('a'), 1)
342
+ self.assertEqual(bst.environ.get('b'), 2)
343
+ self.assertEqual(bst.environ.get('c'), 3)
344
+ # d should not exist
345
+ result = bst.environ.get('d', default=None)
346
+ self.assertIsNone(result)
347
+
348
+ def test_context_platform_restriction(self):
349
+ """Test that platform cannot be set in context."""
350
+ with self.assertRaises(ValueError) as context:
351
+ with bst.environ.context(platform='cpu'):
352
+ pass
353
+
354
+ self.assertIn('platform', str(context.exception).lower())
355
+ self.assertIn('cannot set', str(context.exception).lower())
356
+
357
+ def test_context_host_device_count_restriction(self):
358
+ """Test that host_device_count cannot be set in context."""
359
+ with self.assertRaises(ValueError) as context:
360
+ with bst.environ.context(host_device_count=4):
361
+ pass
362
+
363
+ self.assertIn('host_device_count', str(context.exception))
364
+
365
+ def test_context_mode_validation(self):
366
+ """Test mode validation in context."""
367
+ # Valid mode
368
+ mode = bst.mixin.Training()
369
+ with bst.environ.context(mode=mode):
370
+ self.assertEqual(bst.environ.get('mode'), mode)
371
+
372
+ def test_context_preserves_unmodified_values(self):
373
+ """Test that context doesn't affect unmodified values."""
374
+ bst.environ.set(unchanged='original', changed='original')
375
+
376
+ with bst.environ.context(changed='modified'):
377
+ self.assertEqual(bst.environ.get('unchanged'), 'original')
378
+ self.assertEqual(bst.environ.get('changed'), 'modified')
379
+
380
+
381
+ class TestPrecisionAndDataTypes(unittest.TestCase):
382
+ """Test precision control and data type functions."""
383
+
384
+ def setUp(self):
385
+ """Reset environment before each test."""
386
+ bst.environ.reset()
387
+ warnings.filterwarnings('ignore', category=UserWarning)
388
+
389
+ def tearDown(self):
390
+ """Clean up after each test."""
391
+ bst.environ.reset()
392
+ warnings.resetwarnings()
393
+
394
+ def test_precision_settings(self):
395
+ """Test different precision settings."""
396
+ precisions = [8, 16, 32, 64, 'bf16']
397
+
398
+ for precision in precisions:
399
+ bst.environ.set(precision=precision)
400
+
401
+ if precision == 'bf16':
402
+ self.assertEqual(bst.environ.get_precision(), 16)
403
+ elif isinstance(precision, str):
404
+ self.assertEqual(bst.environ.get_precision(), int(precision))
405
+ else:
406
+ self.assertEqual(bst.environ.get_precision(), precision)
407
+
408
+ def test_precision_context(self):
409
+ """Test precision changes in context."""
410
+ bst.environ.set(precision=32)
411
+
412
+ with bst.environ.context(precision=64):
413
+ a = bst.random.randn(1)
414
+ self.assertEqual(a.dtype, jnp.float64)
415
+ self.assertEqual(bst.environ.get_precision(), 64)
416
+
417
+ # Precision restored
418
+ b = bst.random.randn(1)
419
+ self.assertEqual(b.dtype, jnp.float32)
420
+ self.assertEqual(bst.environ.get_precision(), 32)
421
+
422
+ def test_dftype_function(self):
423
+ """Test default float type function."""
424
+ # 32-bit precision
425
+ bst.environ.set(precision=32)
426
+ self.assertEqual(bst.environ.dftype(), np.float32)
427
+
428
+ # 64-bit precision
429
+ bst.environ.set(precision=64)
430
+ self.assertEqual(bst.environ.dftype(), np.float64)
431
+
432
+ # 16-bit precision
433
+ bst.environ.set(precision=16)
434
+ self.assertEqual(bst.environ.dftype(), np.float16)
435
+
436
+ # bfloat16 precision
437
+ bst.environ.set(precision='bf16')
438
+ self.assertEqual(bst.environ.dftype(), jnp.bfloat16)
439
+
440
+ def test_ditype_function(self):
441
+ """Test default integer type function."""
442
+ # 32-bit precision
443
+ bst.environ.set(precision=32)
444
+ self.assertEqual(bst.environ.ditype(), np.int32)
445
+
446
+ # 64-bit precision
447
+ bst.environ.set(precision=64)
448
+ self.assertEqual(bst.environ.ditype(), np.int64)
449
+
450
+ # 16-bit precision
451
+ bst.environ.set(precision=16)
452
+ self.assertEqual(bst.environ.ditype(), np.int16)
453
+
454
+ # 8-bit precision
455
+ bst.environ.set(precision=8)
456
+ self.assertEqual(bst.environ.ditype(), np.int8)
457
+
458
+ def test_dutype_function(self):
459
+ """Test default unsigned integer type function."""
460
+ # 32-bit precision
461
+ bst.environ.set(precision=32)
462
+ self.assertEqual(bst.environ.dutype(), np.uint32)
463
+
464
+ # 64-bit precision
465
+ bst.environ.set(precision=64)
466
+ self.assertEqual(bst.environ.dutype(), np.uint64)
467
+
468
+ # 16-bit precision
469
+ bst.environ.set(precision=16)
470
+ self.assertEqual(bst.environ.dutype(), np.uint16)
471
+
472
+ # 8-bit precision
473
+ bst.environ.set(precision=8)
474
+ self.assertEqual(bst.environ.dutype(), np.uint8)
475
+
476
+ def test_dctype_function(self):
477
+ """Test default complex type function."""
478
+ # 32-bit precision
479
+ bst.environ.set(precision=32)
480
+ self.assertEqual(bst.environ.dctype(), np.complex64)
481
+
482
+ # 64-bit precision
483
+ bst.environ.set(precision=64)
484
+ self.assertEqual(bst.environ.dctype(), np.complex128)
485
+
486
+ # 16-bit precision (should use complex64)
487
+ bst.environ.set(precision=16)
488
+ self.assertEqual(bst.environ.dctype(), np.complex64)
489
+
490
+ def test_tolerance_function(self):
491
+ """Test tolerance values for different precisions."""
492
+ # 64-bit precision
493
+ bst.environ.set(precision=64)
494
+ tol = bst.environ.tolerance()
495
+ self.assertAlmostEqual(float(tol), 1e-12, places=14)
496
+
497
+ # 32-bit precision
498
+ bst.environ.set(precision=32)
499
+ tol = bst.environ.tolerance()
500
+ self.assertAlmostEqual(float(tol), 1e-5, places=7)
501
+
502
+ # 16-bit precision
503
+ bst.environ.set(precision=16)
504
+ tol = bst.environ.tolerance()
505
+ self.assertAlmostEqual(float(tol), 1e-2, places=4)
506
+
507
+ def test_invalid_precision(self):
508
+ """Test invalid precision values."""
509
+ invalid_precisions = [128, 'invalid', -1, 3.14]
510
+
511
+ for invalid in invalid_precisions:
512
+ with self.assertRaises(ValueError):
513
+ bst.environ.set(precision=invalid)
514
+
515
+ def test_precision_with_arrays(self):
516
+ """Test that precision affects array creation."""
517
+ # Test with different precisions
518
+ test_cases = [
519
+ (32, jnp.float32),
520
+ (64, jnp.float64),
521
+ (16, jnp.float16),
522
+ ('bf16', jnp.bfloat16),
523
+ ]
524
+
525
+ for precision, expected_dtype in test_cases:
526
+ with bst.environ.context(precision=precision):
527
+ # Create array using random
528
+ arr = bst.random.randn(5)
529
+ self.assertEqual(arr.dtype, expected_dtype)
530
+
531
+
532
class TestModeAndSpecialGetters(unittest.TestCase):
    """Exercise mode storage and the dedicated getter helpers of ``environ``."""

    def setUp(self):
        """Start every test from a pristine environment with UserWarnings muted."""
        bst.environ.reset()
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Return the environment and the warning filters to their defaults."""
        bst.environ.reset()
        warnings.resetwarnings()

    def test_get_dt(self):
        """``get_dt`` sees the global value, context overrides, and a missing key."""
        global_dt, scoped_dt = 0.01, 0.001

        bst.environ.set(dt=global_dt)
        self.assertEqual(bst.environ.get_dt(), global_dt)

        # A context temporarily shadows the global value...
        with bst.environ.context(dt=scoped_dt):
            self.assertEqual(bst.environ.get_dt(), scoped_dt)
        # ...and leaving it brings the global value back.
        self.assertEqual(bst.environ.get_dt(), global_dt)

        # Once reset, no dt is configured at all.
        bst.environ.reset()
        self.assertRaises(KeyError, bst.environ.get_dt)

    def test_get_mode(self):
        """The stored mode object is returned as-is, both globally and in a context."""
        train_mode = bst.mixin.Training()
        bst.environ.set(mode=train_mode)
        current = bst.environ.get('mode')
        self.assertEqual(current, train_mode)
        self.assertTrue(current.has(bst.mixin.Training))

        batch_mode = bst.mixin.Batching(batch_size=32)
        with bst.environ.context(mode=batch_mode):
            current = bst.environ.get('mode')
            self.assertEqual(current, batch_mode)
            self.assertTrue(current.has(bst.mixin.Batching))
            self.assertEqual(current.batch_size, 32)

        # Without any mode configured the lookup fails with KeyError.
        bst.environ.reset()
        self.assertRaises(KeyError, bst.environ.get, 'mode')

    def test_get_platform(self):
        """The reported platform is one of the advertised supported platforms."""
        self.assertIn(bst.environ.get_platform(), bst.environ.SUPPORTED_PLATFORMS)

    def test_get_host_device_count(self):
        """The host device count is a positive integer."""
        n_devices = bst.environ.get_host_device_count()
        self.assertIsInstance(n_devices, int)
        self.assertGreaterEqual(n_devices, 1)

    def test_dt_validation(self):
        """Every accepted dt value round-trips through ``set``/``get_dt``."""
        for step in (0.01, 0.001, 1.0, 0.1):
            bst.environ.set(dt=step)
            self.assertEqual(bst.environ.get_dt(), step)
602
+
603
+
604
class TestPlatformAndDevice(unittest.TestCase):
    """Test platform and device management."""

    def setUp(self):
        """Reset environment before each test."""
        bst.environ.reset()
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Clean up after each test."""
        bst.environ.reset()
        warnings.resetwarnings()

    @patch('brainstate.environ.config')
    def test_set_platform(self, mock_config):
        """Test platform setting with the JAX config patched out."""
        platforms = ['cpu', 'gpu', 'tpu']

        for platform in platforms:
            bst.environ.set_platform(platform)
            mock_config.update.assert_called_with("jax_platform_name", platform)

        # Test invalid platform
        with self.assertRaises(ValueError):
            bst.environ.set_platform('invalid')

    def test_set_platform_through_set(self):
        """Test setting platform through the general ``set`` function."""
        with patch('brainstate.environ.config') as mock_config:
            bst.environ.set(platform='gpu')
            mock_config.update.assert_called_with("jax_platform_name", 'gpu')

    def test_set_host_device_count(self):
        """Test host device count setting via the XLA_FLAGS environment variable."""
        import os

        # set_host_device_count rewrites the process-wide XLA_FLAGS variable.
        # Snapshot it so this test does not leak the modified flags into later tests.
        original_flags = os.environ.get("XLA_FLAGS")

        def _restore_xla_flags():
            # Put back exactly what was there before, including "not set at all".
            if original_flags is None:
                os.environ.pop("XLA_FLAGS", None)
            else:
                os.environ["XLA_FLAGS"] = original_flags

        self.addCleanup(_restore_xla_flags)

        # Set device count
        bst.environ.set_host_device_count(4)
        xla_flags = os.environ.get("XLA_FLAGS", "")
        self.assertIn("--xla_force_host_platform_device_count=4", xla_flags)

        # Updating the count must replace the old flag rather than append to it.
        bst.environ.set_host_device_count(8)
        xla_flags = os.environ.get("XLA_FLAGS", "")
        self.assertIn("--xla_force_host_platform_device_count=8", xla_flags)
        self.assertNotIn("--xla_force_host_platform_device_count=4", xla_flags)

        # Invalid device count
        with self.assertRaises(ValueError):
            bst.environ.set_host_device_count(0)

        with self.assertRaises(ValueError):
            bst.environ.set_host_device_count(-1)

    def test_platform_context_restriction(self):
        """Test that platform cannot be changed in context."""
        with self.assertRaises(ValueError):
            with bst.environ.context(platform='cpu'):
                pass
663
+
664
+
665
class TestCallbackBehavior(unittest.TestCase):
    """Test callback registration and behavior."""

    def setUp(self):
        """Reset environment before each test."""
        bst.environ.reset()
        warnings.filterwarnings('ignore', category=UserWarning)
        self.callback_values = []

    def tearDown(self):
        """Clean up after each test."""
        bst.environ.reset()
        warnings.resetwarnings()

    def _register(self, key, callback, **kwargs):
        """Register ``callback`` for ``key`` and schedule its removal.

        Whether ``environ.reset()`` also clears registered behaviors is not
        visible here, so every successful registration is undone explicitly.
        ``unregister_default_behavior`` returns False (rather than raising)
        for a key that is already gone, so the cleanup is harmless if a test
        unregisters the key itself.
        """
        bst.environ.register_default_behavior(key, callback, **kwargs)
        self.addCleanup(bst.environ.unregister_default_behavior, key)

    # NOTE(review): this test was disabled in the original; the reason is not
    # recorded here. It is kept for reference (it uses the ``brainstate`` name).
    # def test_register_callback(self):
    #     """Test basic callback registration."""
    #     def callback(value):
    #         self.callback_values.append(value)
    #
    #     brainstate.environ.register_default_behavior('test_param', callback)
    #
    #     # Callback should be triggered on set
    #     brainstate.environ.set(test_param='value1')
    #     self.assertEqual(self.callback_values, ['value1'])
    #
    #     # Callback should be triggered on context enter/exit
    #     with brainstate.environ.context(test_param='value2'):
    #         self.assertEqual(self.callback_values, ['value1', 'value2'])
    #
    #     # Should restore previous value
    #     self.assertEqual(self.callback_values, ['value1', 'value2', 'value1'])

    def test_register_multiple_callbacks(self):
        """Test registering callbacks for different keys."""
        values_a = []
        values_b = []

        def callback_a(value):
            values_a.append(value)

        def callback_b(value):
            values_b.append(value)

        self._register('param_a', callback_a)
        self._register('param_b', callback_b)

        bst.environ.set(param_a='a1', param_b='b1')
        self.assertEqual(values_a, ['a1'])
        self.assertEqual(values_b, ['b1'])

    def test_replace_callback(self):
        """Test replacing existing callbacks."""

        def callback1(value):
            self.callback_values.append(f'cb1:{value}')

        def callback2(value):
            self.callback_values.append(f'cb2:{value}')

        # Register first callback
        self._register('param', callback1)

        # A second registration without the replace flag must be refused.
        with self.assertRaises(ValueError):
            bst.environ.register_default_behavior('param', callback2)

        # Register with replace flag
        self._register('param', callback2, replace_if_exist=True)

        # Only second callback should be called
        bst.environ.set(param='test')
        self.assertEqual(self.callback_values, ['cb2:test'])

    def test_unregister_callback(self):
        """Test unregistering callbacks."""

        def callback(value):
            self.callback_values.append(value)

        # Register and test
        self._register('param', callback)
        bst.environ.set(param='value1')
        self.assertEqual(len(self.callback_values), 1)

        # Unregister
        removed = bst.environ.unregister_default_behavior('param')
        self.assertTrue(removed)

        # Callback should not be triggered
        bst.environ.set(param='value2')
        self.assertEqual(len(self.callback_values), 1)  # Still just one

        # Unregister non-existent
        removed = bst.environ.unregister_default_behavior('nonexistent')
        self.assertFalse(removed)

    def test_list_registered_behaviors(self):
        """Test listing registered behaviors."""
        keys = ['param1', 'param2', 'param3']
        for key in keys:
            self._register(key, lambda x: None)

        behaviors = bst.environ.list_registered_behaviors()
        for key in keys:
            self.assertIn(key, behaviors)

    def test_callback_exception_handling(self):
        """Test that exceptions in callbacks are turned into warnings."""

        def failing_callback(value):
            raise RuntimeError(f"Intentional error: {value}")

        self._register('param', failing_callback)

        # Should not crash, but should warn
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            bst.environ.set(param='test')

            # Should have a warning
            self.assertTrue(len(w) > 0)
            self.assertIn('Callback', str(w[0].message))
            self.assertIn('exception', str(w[0].message))

    def test_callback_validation(self):
        """Test callback validation."""
        # Non-callable
        with self.assertRaises(TypeError):
            bst.environ.register_default_behavior('param', 'not_callable')

        # Non-string key
        with self.assertRaises(TypeError):
            bst.environ.register_default_behavior(123, lambda x: None)

    def test_callback_with_validation(self):
        """Test using callbacks for validation."""

        def validate_positive(value):
            if value <= 0:
                raise ValueError(f"Value must be positive, got {value}")
            self.callback_values.append(value)

        self._register('positive_param', validate_positive)

        # Valid value
        bst.environ.set(positive_param=10)
        self.assertEqual(self.callback_values, [10])

        # An invalid value makes the callback raise; the failure surfaces as a warning.
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            bst.environ.set(positive_param=-5)

        # The callback raised before recording, so -5 must not have been appended.
        self.assertEqual(self.callback_values, [10])
        self.assertTrue(len(caught) > 0)
821
+
822
+
823
class TestThreadSafety(unittest.TestCase):
    """Test thread safety of environment operations."""

    def setUp(self):
        """Reset environment before each test."""
        bst.environ.reset()
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Clean up after each test."""
        bst.environ.reset()
        warnings.resetwarnings()

    @staticmethod
    def _run_in_threads(target, args_per_thread):
        """Start one thread per argument tuple, then wait for all of them to finish."""
        threads = [threading.Thread(target=target, args=args) for args in args_per_thread]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    def test_concurrent_set_operations(self):
        """Test concurrent set operations from multiple threads."""
        n_threads, n_iters = 5, 10
        results = []
        errors = []

        def thread_operation(thread_id):
            key = f'thread_{thread_id}'
            try:
                # Each thread writes and reads back its own private key.
                for i in range(n_iters):
                    bst.environ.set(**{key: i})
                    value = bst.environ.get(key)
                    results.append((thread_id, value))
            except Exception as e:  # reported by the main thread below
                errors.append(e)

        self._run_in_threads(thread_operation, [(i,) for i in range(n_threads)])

        # Should have no errors
        self.assertEqual(errors, [])

        # No other thread touches a thread's key, so each read must observe the
        # value that same thread wrote just before it. Each thread appends in
        # order, so filtering by thread id keeps its sequence intact.
        for thread_id in range(n_threads):
            observed = [value for owner, value in results if owner == thread_id]
            self.assertEqual(observed, list(range(n_iters)))

    def test_concurrent_context_operations(self):
        """Test concurrent context operations from multiple threads."""
        n_threads, n_iters = 3, 5
        results = []
        errors = []

        def thread_context_operation(thread_id):
            key = f'base_{thread_id}'
            try:
                bst.environ.set(**{key: 0})

                for i in range(n_iters):
                    with bst.environ.context(**{key: i}):
                        value = bst.environ.get(key)
                        results.append((thread_id, value))

                # Should be back to 0; an AssertionError here is collected in errors.
                final = bst.environ.get(key)
                self.assertEqual(final, 0)
            except Exception as e:
                errors.append(e)

        self._run_in_threads(thread_context_operation, [(i,) for i in range(n_threads)])

        # Should have no errors
        self.assertEqual(errors, [])

        # Inside each context the thread must see the value it just pushed.
        for thread_id in range(n_threads):
            observed = [value for owner, value in results if owner == thread_id]
            self.assertEqual(observed, list(range(n_iters)))

    def test_concurrent_pop_operations(self):
        """Test concurrent pop operations from multiple threads."""
        # Set up multiple keys
        for i in range(20):
            bst.environ.set(**{f'pop_thread_{i}': f'value_{i}'})

        results = []
        errors = []

        def thread_pop_operation(start, end):
            try:
                for i in range(start, end):
                    try:
                        value = bst.environ.pop(f'pop_thread_{i}')
                        results.append((i, value))
                    except KeyError:
                        # The key may not be visible or may already be gone; tolerated.
                        pass
            except Exception as e:
                errors.append(e)

        # Threads pop disjoint ranges of keys.
        ranges = [(0, 5), (5, 10), (10, 15), (15, 20)]
        self._run_in_threads(thread_pop_operation, ranges)

        # Should have no errors
        self.assertEqual(errors, [])

        # NOTE(review): the stricter checks below were disabled in the original;
        # the reason is not recorded here. They are kept for reference.
        # # All keys should be popped (each exactly once)
        # popped_indices = [r[0] for r in results]
        # self.assertEqual(len(popped_indices), 20)
        # self.assertEqual(len(set(popped_indices)), 20)  # All unique
        #
        # # All values should be gone
        # for i in range(20):
        #     result = brainstate.environ.get(f'pop_thread_{i}', default=None)
        #     self.assertIsNone(result)
946
+
947
+
948
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and boundary conditions."""

    def setUp(self):
        """Reset environment before each test."""
        bst.environ.reset()
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Clean up after each test."""
        bst.environ.reset()
        warnings.resetwarnings()

    def test_empty_context(self):
        """Test context with no parameters."""
        original = bst.environ.all()

        with bst.environ.context() as ctx:
            # An empty context must not change anything.
            self.assertEqual(ctx, original)

        self.assertEqual(bst.environ.all(), original)

    def test_none_values(self):
        """Test handling of None values."""
        bst.environ.set(none_param=None)
        self.assertIsNone(bst.environ.get('none_param'))

        with bst.environ.context(none_param='not_none'):
            self.assertEqual(bst.environ.get('none_param'), 'not_none')

        self.assertIsNone(bst.environ.get('none_param'))

    def test_complex_data_types(self):
        """Test storing complex data types."""
        # Lists
        bst.environ.set(list_param=[1, 2, 3])
        self.assertEqual(bst.environ.get('list_param'), [1, 2, 3])

        # Dictionaries
        bst.environ.set(dict_param={'a': 1, 'b': 2})
        self.assertEqual(bst.environ.get('dict_param'), {'a': 1, 'b': 2})

        # Tuples
        bst.environ.set(tuple_param=(1, 2, 3))
        self.assertEqual(bst.environ.get('tuple_param'), (1, 2, 3))

        # Custom objects are stored by reference, not copied.
        class CustomObject:
            def __init__(self, value):
                self.value = value

        obj = CustomObject(42)
        bst.environ.set(obj_param=obj)
        retrieved = bst.environ.get('obj_param')
        self.assertIs(retrieved, obj)
        self.assertEqual(retrieved.value, 42)

    def test_special_string_values(self):
        """Test special string values."""
        special_strings = ['', ' ', '\n', '\t', 'None', 'True', 'False']

        for s in special_strings:
            with self.subTest(value=s):
                bst.environ.set(string_param=s)
                self.assertEqual(bst.environ.get('string_param'), s)

    def test_numeric_edge_values(self):
        """Test numeric edge values, including infinities and NaN."""
        import math
        import sys

        edge_values = [
            0, -0, 1, -1,
            sys.maxsize, -sys.maxsize,
            float('inf'), float('-inf'),
            1e-100, 1e100,
            float('nan'),  # included so the NaN branch below is actually exercised
        ]

        for value in edge_values:
            with self.subTest(value=value):
                bst.environ.set(numeric_param=value)
                retrieved = bst.environ.get('numeric_param')
                if isinstance(value, float) and math.isnan(value):
                    # NaN never compares equal to itself; check NaN-ness instead.
                    self.assertTrue(math.isnan(retrieved))
                else:
                    self.assertEqual(retrieved, value)

    def test_context_all_interaction(self):
        """Test interaction between context and all() function."""
        bst.environ.set(global_param='global')

        with bst.environ.context(context_param='context', global_param='override'):
            all_values = bst.environ.all()

            # Should include both the override and the context-only key.
            self.assertEqual(all_values['global_param'], 'override')
            self.assertEqual(all_values['context_param'], 'context')

            # Original global values should be in settings
            self.assertIn('precision', all_values)

    def test_deeply_nested_contexts(self):
        """Test deeply nested contexts."""
        depth = 20
        bst.environ.set(depth=0)

        def nested_context(level):
            if level < depth:
                with bst.environ.context(depth=level):
                    self.assertEqual(bst.environ.get('depth'), level)
                    nested_context(level + 1)
                    # The inner context has exited; this level's value is back.
                    self.assertEqual(bst.environ.get('depth'), level)

        nested_context(1)
        self.assertEqual(bst.environ.get('depth'), 0)

    def test_set_precision_function(self):
        """Test the dedicated set_precision function."""
        # Valid precisions
        for precision in [8, 16, 32, 64, 'bf16']:
            with self.subTest(precision=precision):
                bst.environ.set_precision(precision)
                self.assertEqual(bst.environ.get('precision'), precision)

        # Invalid precision
        with self.assertRaises(ValueError):
            bst.environ.set_precision(128)

    def test_pop_edge_cases(self):
        """Test edge cases for pop function."""
        # Pop with None value
        bst.environ.set(none_key=None)
        popped = bst.environ.pop('none_key')
        self.assertIsNone(popped)

        # Pop with None as default
        result = bst.environ.pop('missing_key', default=None)
        self.assertIsNone(result)

        # Pop complex data types
        complex_obj = {'nested': {'data': [1, 2, 3]}}
        bst.environ.set(complex_key=complex_obj)
        popped = bst.environ.pop('complex_key')
        self.assertEqual(popped, complex_obj)

        # Verify object identity preservation
        obj = object()
        bst.environ.set(obj_key=obj)
        popped = bst.environ.pop('obj_key')
        self.assertIs(popped, obj)

    def test_pop_all_interaction(self):
        """Test interaction between pop and all() function."""
        # Set multiple values
        bst.environ.set(a=1, b=2, c=3, d=4)

        # Pop some values
        bst.environ.pop('b')
        bst.environ.pop('d')

        # Check all() reflects the changes
        after_pop = bst.environ.all()
        self.assertIn('a', after_pop)
        self.assertIn('c', after_pop)
        self.assertNotIn('b', after_pop)
        self.assertNotIn('d', after_pop)

    def test_pop_callback_not_triggered(self):
        """Test that callbacks are not triggered on pop."""
        callback_calls = []

        def callback(value):
            callback_calls.append(value)

        # Register the callback and make sure it is removed even if an assertion fails.
        bst.environ.register_default_behavior('callback_test', callback)
        self.addCleanup(bst.environ.unregister_default_behavior, 'callback_test')

        # Set triggers callback
        bst.environ.set(callback_test='value')
        self.assertEqual(len(callback_calls), 1)

        # Pop should NOT trigger callback
        popped = bst.environ.pop('callback_test')
        self.assertEqual(len(callback_calls), 1)  # Still just 1
        self.assertEqual(popped, 'value')
1134
+
1135
+
1136
class TestIntegration(unittest.TestCase):
    """End-to-end checks that environment settings reach BrainState features."""

    def setUp(self):
        """Begin each test with default settings and UserWarnings silenced."""
        bst.environ.reset()
        warnings.filterwarnings('ignore', category=UserWarning)

    def tearDown(self):
        """Undo any settings and warning filters the test changed."""
        bst.environ.reset()
        warnings.resetwarnings()

    def test_precision_affects_random_arrays(self):
        """Random arrays are generated with the dtype implied by the active precision."""
        # Insertion order matches the order in which precisions are checked.
        expected_dtypes = {
            32: jnp.float32,
            64: jnp.float64,
            16: jnp.float16,
            'bf16': jnp.bfloat16,
        }
        for precision, dtype in expected_dtypes.items():
            with bst.environ.context(precision=precision):
                sample = bst.random.randn(10)
                self.assertEqual(sample.dtype, dtype)

    def test_mode_usage(self):
        """Mode objects keep their identity and attributes once stored."""
        train_mode = bst.mixin.Training()
        batch_mode = bst.mixin.Batching(batch_size=32)

        bst.environ.set(mode=train_mode)
        self.assertTrue(bst.environ.get('mode').has(bst.mixin.Training))

        with bst.environ.context(mode=batch_mode):
            active = bst.environ.get('mode')
            self.assertTrue(active.has(bst.mixin.Batching))
            self.assertEqual(active.batch_size, 32)

    def test_dt_in_numerical_integration(self):
        """A configured dt round-trips and yields a usable number of steps."""
        for step in (0.01, 0.001, 0.1):
            bst.environ.set(dt=step)
            current_dt = bst.environ.get_dt()
            self.assertEqual(current_dt, step)

            # Pretend to integrate over one unit of simulated time.
            n_steps = int(1.0 / step)
            self.assertGreater(n_steps, 0)

    def test_combined_settings(self):
        """Several settings coexist, and a context overrides only what it names."""
        settings = dict(
            precision=64,
            dt=0.01,
            mode=bst.mixin.Training(),
            custom_param='test',
            debug=True,
        )
        bst.environ.set(**settings)

        self.assertEqual(bst.environ.get_precision(), 64)
        self.assertEqual(bst.environ.get_dt(), 0.01)
        self.assertTrue(bst.environ.get('mode').has(bst.mixin.Training))
        self.assertEqual(bst.environ.get('custom_param'), 'test')
        self.assertTrue(bst.environ.get('debug'))

        with bst.environ.context(precision=32, debug=False):
            # Overridden keys change...
            self.assertEqual(bst.environ.get_precision(), 32)
            self.assertFalse(bst.environ.get('debug'))
            # ...while untouched keys keep their global values.
            self.assertEqual(bst.environ.get_dt(), 0.01)
            self.assertEqual(bst.environ.get('custom_param'), 'test')
1220
+
1221
+
1222
# Run the whole suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()