brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,681 +1,681 @@
1
- # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Comprehensive test suite for the _compatible_import module.
18
-
19
- This test module provides extensive coverage of the compatibility layer
20
- functionality, including:
21
- - JAX version-dependent imports and compatibility
22
- - Utility functions (safe_map, safe_zip, unzip2)
23
- - Function wrapping and metadata handling
24
- - Type safety and error handling
25
- - Edge cases and boundary conditions
26
- """
27
-
28
- import unittest
29
- from functools import partial
30
- from unittest.mock import Mock
31
-
32
- import jax
33
- import jax.numpy as jnp
34
- import numpy as np
35
-
36
- from brainstate import _compatible_import as compat
37
-
38
-
39
class TestJAXVersionCompatibility(unittest.TestCase):
    """Test JAX version-dependent imports and compatibility.

    No test in this class mutates global JAX state, so no setUp/tearDown
    fixture is needed.
    """

    def test_device_import_compatibility(self):
        """Test Device import works across JAX versions."""
        # Device must be exported, non-None, and usable as an isinstance target.
        self.assertTrue(hasattr(compat, 'Device'))
        self.assertIsNotNone(compat.Device)

        device = jax.devices()[0]
        self.assertIsInstance(device, compat.Device)

    def test_core_imports_availability(self):
        """Test core JAX imports are available."""
        core_types = (
            'ClosedJaxpr', 'Primitive', 'Var', 'JaxprEqn',
            'Jaxpr', 'Literal', 'Tracer',
        )
        for type_name in core_types:
            # subTest reports every missing name, not just the first one.
            with self.subTest(name=type_name):
                self.assertTrue(hasattr(compat, type_name),
                                f"{type_name} should be available")
                self.assertIsNotNone(getattr(compat, type_name))

    def test_function_imports_availability(self):
        """Test function imports are available."""
        functions = (
            'jaxpr_as_fun', 'get_aval', 'to_concrete_aval',
            'extend_axis_env_nd',
        )
        for func_name in functions:
            with self.subTest(name=func_name):
                self.assertTrue(hasattr(compat, func_name),
                                f"{func_name} should be available")
                self.assertTrue(callable(getattr(compat, func_name)),
                                f"{func_name} should be callable")

    def test_extend_axis_env_nd_functionality(self):
        """Test extend_axis_env_nd context manager."""
        # Entering and exiting the context must succeed for one axis,
        # several axes, and no axes at all.
        axis_specs = (
            [('test_axis', 10)],
            [('batch', 32), ('seq', 128)],
            [],
        )
        for axes in axis_specs:
            with self.subTest(axes=axes):
                with compat.extend_axis_env_nd(axes):
                    pass

    def test_get_aval_functionality(self):
        """Test get_aval returns an abstract value matching the input's shape."""
        arr = jnp.array([1, 2, 3])
        aval = compat.get_aval(arr)
        self.assertIsNotNone(aval)
        self.assertEqual(aval.shape, (3,))

        scalar = jnp.float32(3.14)
        scalar_aval = compat.get_aval(scalar)
        self.assertIsNotNone(scalar_aval)
        self.assertEqual(scalar_aval.shape, ())

    def test_to_concrete_aval_functionality(self):
        """Test to_concrete_aval accepts a concrete JAX array."""
        arr = jnp.array([1, 2, 3])
        result = compat.to_concrete_aval(arr)
        self.assertIsNotNone(result)
        # TODO: cover Python scalar inputs once the expected return value
        # for them is specified.
125
-
126
-
127
class TestUtilityFunctions(unittest.TestCase):
    """Exercise the sequence helpers safe_map, safe_zip and unzip2."""

    # ---- safe_map -------------------------------------------------------

    def test_safe_map_basic(self):
        """safe_map applies unary, binary and method callables element-wise."""
        cases = (
            ((lambda x: x * 2, [1, 2, 3]), [2, 4, 6]),
            ((lambda x, y: x + y, [1, 2, 3], [4, 5, 6]), [5, 7, 9]),
            ((str.upper, ['a', 'b', 'c']), ['A', 'B', 'C']),
        )
        for call_args, expected in cases:
            self.assertEqual(compat.safe_map(*call_args), expected)

    def test_safe_map_empty_inputs(self):
        """safe_map over empty sequences yields an empty list."""
        self.assertEqual(compat.safe_map(lambda x: x, []), [])
        self.assertEqual(compat.safe_map(lambda x, y: x + y, [], []), [])

    def test_safe_map_length_mismatch(self):
        """safe_map rejects sequences of unequal length."""
        with self.assertRaises(AssertionError) as ctx:
            compat.safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5])
        self.assertIn('length mismatch', str(ctx.exception))

    def test_safe_map_complex_functions(self):
        """safe_map works with compound expressions and tuple-returning callables."""
        squared_plus_one = compat.safe_map(lambda x: x ** 2 + 1, [1, 2, 3])
        self.assertEqual(squared_plus_one, [2, 5, 10])

        paired = compat.safe_map(lambda x, y: (x, y), [1, 2], ['a', 'b'])
        self.assertEqual(paired, [(1, 'a'), (2, 'b')])

    # ---- safe_zip -------------------------------------------------------

    def test_safe_zip_basic(self):
        """safe_zip pairs up two or three equal-length sequences."""
        self.assertEqual(compat.safe_zip([1, 2, 3], ['a', 'b', 'c']),
                         [(1, 'a'), (2, 'b'), (3, 'c')])
        self.assertEqual(compat.safe_zip([1, 2], [3, 4], [5, 6]),
                         [(1, 3, 5), (2, 4, 6)])

    def test_safe_zip_empty_inputs(self):
        """safe_zip of empty sequences yields an empty list."""
        for empties in ([[], []], [[], [], []]):
            self.assertEqual(compat.safe_zip(*empties), [])

    def test_safe_zip_length_mismatch(self):
        """safe_zip rejects sequences of unequal length."""
        with self.assertRaises(AssertionError) as ctx:
            compat.safe_zip([1, 2, 3], [4, 5])
        self.assertIn('length mismatch', str(ctx.exception))

    def test_safe_zip_single_sequence(self):
        """safe_zip of one sequence yields 1-tuples."""
        self.assertEqual(compat.safe_zip([1, 2, 3]), [(1,), (2,), (3,)])

    def test_safe_zip_mixed_types(self):
        """safe_zip does not care about element types."""
        self.assertEqual(compat.safe_zip([1, 2], ['a', 'b'], [True, False]),
                         [(1, 'a', True), (2, 'b', False)])

    # ---- unzip2 ---------------------------------------------------------

    def test_unzip2_basic(self):
        """unzip2 splits pairs into a tuple of firsts and a tuple of seconds."""
        lefts, rights = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])
        self.assertEqual(lefts, (1, 2, 3))
        self.assertEqual(rights, ('a', 'b', 'c'))

    def test_unzip2_empty(self):
        """unzip2 of nothing yields two empty tuples."""
        lefts, rights = compat.unzip2([])
        self.assertEqual(lefts, ())
        self.assertEqual(rights, ())

    def test_unzip2_single_pair(self):
        """unzip2 of one pair yields two 1-tuples."""
        lefts, rights = compat.unzip2([(42, 'answer')])
        self.assertEqual(lefts, (42,))
        self.assertEqual(rights, ('answer',))

    def test_unzip2_mixed_types(self):
        """unzip2 keeps heterogeneous elements, including None."""
        lefts, rights = compat.unzip2([(1, 'a'), (2.5, 'b'), (None, 'c')])
        self.assertEqual(lefts, (1, 2.5, None))
        self.assertEqual(rights, ('a', 'b', 'c'))

    def test_unzip2_return_types(self):
        """Both halves returned by unzip2 are tuples."""
        halves = compat.unzip2([(1, 'a'), (2, 'b')])
        lefts, rights = halves
        for half in (lefts, rights):
            self.assertIsInstance(half, tuple)

    def test_unzip2_with_generator(self):
        """unzip2 consumes a one-shot generator of pairs."""
        pair_stream = ((num, letter) for num, letter in zip((1, 2, 3), 'abc'))
        lefts, rights = compat.unzip2(pair_stream)
        self.assertEqual(lefts, (1, 2, 3))
        self.assertEqual(rights, ('a', 'b', 'c'))
263
-
264
-
265
class TestFunctionWrapping(unittest.TestCase):
    """Test function wrapping and metadata handling (fun_name and wraps)."""

    def test_fun_name_basic(self):
        """Test fun_name with regular functions."""

        def test_function():
            """Test function docstring."""
            pass

        self.assertEqual(compat.fun_name(test_function), 'test_function')

    def test_fun_name_lambda(self):
        """Test fun_name with lambda functions."""
        lambda_func = lambda x: x * 2
        self.assertEqual(compat.fun_name(lambda_func), '<lambda>')

    def test_fun_name_partial(self):
        """Test fun_name unwraps a functools.partial to the underlying function."""

        def original_function(x, y):
            return x + y

        partial_func = partial(original_function, 10)
        self.assertEqual(compat.fun_name(partial_func), 'original_function')

    def test_fun_name_nested_partial(self):
        """Test fun_name unwraps nested partial objects."""

        def base_function(x, y, z):
            return x + y + z

        partial2 = partial(partial(base_function, 1), 2)
        self.assertEqual(compat.fun_name(partial2), 'base_function')

    def test_fun_name_no_name_attribute(self):
        """Test fun_name falls back for callables without __name__."""

        class CallableClass:
            def __call__(self):
                pass

        self.assertEqual(compat.fun_name(CallableClass()), '<unnamed function>')

    def test_wraps_basic(self):
        """Test wraps copies name and docstring and records __wrapped__."""

        def original_function():
            """Original function docstring."""
            return 42

        @compat.wraps(original_function)
        def wrapper():
            return original_function()

        self.assertEqual(wrapper.__name__, 'original_function')
        self.assertEqual(wrapper.__doc__, 'Original function docstring.')
        self.assertIs(wrapper.__wrapped__, original_function)

    def test_wraps_with_namestr(self):
        """Test wraps with custom name string."""

        def original_function():
            pass

        @compat.wraps(original_function, namestr="wrapped_{fun}")
        def wrapper():
            pass

        self.assertEqual(wrapper.__name__, 'wrapped_original_function')

    def test_wraps_with_docstr(self):
        """Test wraps with custom docstring."""

        def original_function():
            """Original docstring."""
            pass

        @compat.wraps(original_function, docstr="Wrapper for {fun}: {doc}")
        def wrapper():
            pass

        expected_doc = "Wrapper for original_function: Original docstring."
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_with_kwargs(self):
        """Test wraps forwards extra keyword arguments to docstr formatting."""

        def original_function():
            pass

        @compat.wraps(original_function,
                      docstr="Function {fun} version {version}",
                      version="1.0")
        def wrapper():
            pass

        expected_doc = "Function original_function version 1.0"
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_preserves_annotations(self):
        """Test wraps preserves function annotations."""

        def original_function(x: int, y: str) -> float:
            return float(x)

        @compat.wraps(original_function)
        def wrapper(x: int, y: str) -> float:
            return original_function(x, y)

        self.assertEqual(wrapper.__annotations__, original_function.__annotations__)

    def test_wraps_preserves_dict(self):
        """Test wraps preserves function __dict__."""

        def original_function():
            pass

        original_function.custom_attr = "test_value"
        original_function.another_attr = 42

        @compat.wraps(original_function)
        def wrapper():
            pass

        self.assertEqual(wrapper.custom_attr, "test_value")
        self.assertEqual(wrapper.another_attr, 42)

    def test_wraps_handles_exceptions(self):
        """Test wraps does not propagate errors raised while copying metadata."""
        # Case 1: __name__ is a Mock rather than a string. A Mock's side_effect
        # fires only when it is *called*, so reading __name__ here does not
        # raise; this checks that wraps copes with a non-string name.
        mock_func = Mock()
        mock_func.__name__ = Mock(side_effect=Exception("Test exception"))

        @compat.wraps(mock_func)
        def wrapper():
            pass

        self.assertTrue(callable(wrapper))

        # Case 2: reading __name__ itself raises. A property defined in the
        # class body is found by instance attribute lookup, so
        # ``RaisingName().__name__`` raises. This is the situation the test
        # name describes.
        class RaisingName:
            def __call__(self):
                pass

            @property
            def __name__(self):
                raise Exception("Test exception")

        @compat.wraps(RaisingName())
        def other_wrapper():
            pass

        self.assertTrue(callable(other_wrapper))

    def test_wraps_with_missing_attributes(self):
        """Test wraps handles missing attributes gracefully."""

        class MinimalCallable:
            pass

        @compat.wraps(MinimalCallable())
        def wrapper():
            pass

        # Must not crash even though the wrapped object has no __name__/__doc__.
        self.assertTrue(callable(wrapper))
436
-
437
-
438
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and boundary conditions of the utility helpers."""

    def test_safe_map_with_none_inputs(self):
        """Test safe_map passes None elements through to the callable."""
        result = compat.safe_map(lambda x: x if x is not None else 'default',
                                 [1, None, 3])
        self.assertEqual(result, [1, 'default', 3])

    def test_safe_map_with_zero_length(self):
        """Test safe_map with zero-length sequences."""
        self.assertEqual(compat.safe_map(str, []), [])

    def test_safe_zip_with_none_values(self):
        """Test safe_zip with None values."""
        result = compat.safe_zip([1, None, 3], [4, 5, None])
        self.assertEqual(result, [(1, 4), (None, 5), (3, None)])

    def test_unzip2_with_none_values(self):
        """Test unzip2 with None values."""
        first, second = compat.unzip2([(1, None), (None, 'a'), (2, 'b')])
        self.assertEqual(first, (1, None, 2))
        self.assertEqual(second, (None, 'a', 'b'))

    def test_large_sequences(self):
        """Test utility functions over the full 10,000-element sequences."""
        size = 10000
        left = list(range(size))
        right = list(range(size, 2 * size))

        summed = compat.safe_map(lambda x, y: x + y, left, right)
        self.assertEqual(len(summed), size)
        self.assertEqual(summed[0], 10000)  # 0 + 10000
        self.assertEqual(summed[-1], left[-1] + right[-1])

        zipped = compat.safe_zip(left, right)
        self.assertEqual(len(zipped), size)
        self.assertEqual(zipped[0], (0, 10000))
        self.assertEqual(zipped[-1], (left[-1], right[-1]))

        first, second = compat.unzip2(list(zip(left, right)))
        self.assertEqual(first, tuple(left))
        self.assertEqual(second, tuple(right))

    def test_function_name_edge_cases(self):
        """Test fun_name with built-ins, unbound methods and nested functions."""

        def outer():
            def inner():
                pass

            return inner

        cases = (
            (len, 'len'),
            (str.upper, 'upper'),
            (outer(), 'inner'),
        )
        for func, expected in cases:
            with self.subTest(expected=expected):
                self.assertEqual(compat.fun_name(func), expected)

    def test_concurrent_usage(self):
        """Test the helpers return correct results when called from several threads."""
        import threading

        n_threads = 5
        n_iters = 100
        map_results = []
        zip_results = []
        unzip_results = []
        errors = []

        def worker():
            try:
                for _ in range(n_iters):
                    map_results.append(compat.safe_map(lambda x: x * 2, [1, 2, 3]))
                for _ in range(n_iters):
                    zip_results.append(compat.safe_zip([1, 2], [3, 4]))
                for _ in range(n_iters):
                    first, second = compat.unzip2([(1, 'a'), (2, 'b')])
                    unzip_results.append((first, second))
            except Exception as e:  # surfaced by the errors assertion below
                errors.append(e)

        threads = [threading.Thread(target=worker) for _ in range(n_threads)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # assertEqual on the list shows the captured exceptions on failure.
        self.assertEqual(errors, [])

        expected_count = n_threads * n_iters
        self.assertEqual(len(map_results), expected_count)
        self.assertEqual(len(zip_results), expected_count)
        self.assertEqual(len(unzip_results), expected_count)

        # Every call, from every thread, must produce the same correct value.
        self.assertTrue(all(r == [2, 4, 6] for r in map_results))
        self.assertTrue(all(r == [(1, 3), (2, 4)] for r in zip_results))
        self.assertTrue(all(r == ((1, 2), ('a', 'b')) for r in unzip_results))
573
-
574
-
575
class TestTypeHints(unittest.TestCase):
    """Check the module's generic TypeVars and that helpers are type-agnostic."""

    def test_type_variables_defined(self):
        """The TypeVars T, T1, T2 and T3 are exposed by the module."""
        for type_var_name in ('T', 'T1', 'T2', 'T3'):
            self.assertTrue(hasattr(compat, type_var_name))

    def test_unzip2_type_preservation(self):
        """unzip2 leaves the element types of each half untouched."""
        numbers, letters = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])

        self.assertTrue(all(isinstance(value, int) for value in numbers))
        self.assertTrue(all(isinstance(value, str) for value in letters))

    def test_safe_functions_with_different_types(self):
        """safe_map and safe_zip accept heterogeneous element types."""
        self.assertEqual(compat.safe_map(str, [1, 2.5, '3']),
                         ['1', '2.5', '3'])
        self.assertEqual(
            compat.safe_zip([1, 2], [3.14, 2.71], ['a', 'b']),
            [(1, 3.14, 'a'), (2, 2.71, 'b')],
        )
608
-
609
-
610
class TestIntegration(unittest.TestCase):
    """Use the compat helpers together with JAX arrays and transformations."""

    def test_jax_integration(self):
        """safe_map/safe_zip operate on Python lists converted from JAX arrays."""
        lhs = jnp.array([1, 2, 3]).tolist()
        rhs = jnp.array([4, 5, 6]).tolist()

        self.assertEqual(compat.safe_map(lambda a, b: a + b, lhs, rhs),
                         [5, 7, 9])
        self.assertEqual(compat.safe_zip(lhs, rhs),
                         [(1, 4), (2, 5), (3, 6)])

    def test_with_jax_transformations(self):
        """unzip2 works both eagerly and inside a jax.jit-traced function."""

        def split_pairs(x):
            firsts, seconds = compat.unzip2([(x, x + 1), (x + 2, x + 3)])
            return jnp.array(firsts), jnp.array(seconds)

        # Eager call first, then the jitted version, each with the same input.
        for fn in (split_pairs, jax.jit(split_pairs)):
            evens, odds = fn(10.0)
            np.testing.assert_array_equal(evens, [10.0, 12.0])
            np.testing.assert_array_equal(odds, [11.0, 13.0])
648
-
649
-
650
class TestModuleStructure(unittest.TestCase):
    """Check the module's public surface: ``__all__`` and its docstring."""

    def test_all_exports(self):
        """Every expected public name is listed in __all__ and defined."""
        expected_exports = (
            'ClosedJaxpr', 'Primitive', 'extend_axis_env_nd', 'jaxpr_as_fun',
            'get_aval', 'Tracer', 'to_concrete_aval', 'safe_map', 'safe_zip',
            'unzip2', 'wraps', 'Device', 'wrap_init', 'Var', 'JaxprEqn',
            'Jaxpr', 'Literal',
        )
        public_names = compat.__all__
        for export in expected_exports:
            self.assertIn(export, public_names,
                          f"{export} should be in __all__")
            self.assertTrue(hasattr(compat, export),
                            f"{export} should be available in module")

    def test_no_unexpected_exports(self):
        """__all__ contains no underscore-prefixed (private) names."""
        for exported in compat.__all__:
            is_private = exported.startswith('_')
            self.assertFalse(is_private,
                             f"Private name {exported} should not be in __all__")

    def test_module_docstring(self):
        """The module docstring exists and describes the compatibility layer."""
        doc = compat.__doc__
        self.assertIsNotNone(doc)
        self.assertIn('Compatibility layer', doc)


if __name__ == '__main__':
    # Run every TestCase in this module with per-test verbose output.
    unittest.main(verbosity=2)
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive test suite for the _compatible_import module.
18
+
19
+ This test module provides extensive coverage of the compatibility layer
20
+ functionality, including:
21
+ - JAX version-dependent imports and compatibility
22
+ - Utility functions (safe_map, safe_zip, unzip2)
23
+ - Function wrapping and metadata handling
24
+ - Type safety and error handling
25
+ - Edge cases and boundary conditions
26
+ """
27
+
28
+ import unittest
29
+ from functools import partial
30
+ from unittest.mock import Mock
31
+
32
+ import jax
33
+ import jax.numpy as jnp
34
+ import numpy as np
35
+
36
+ from brainstate import _compatible_import as compat
37
+
38
+
39
+ class TestJAXVersionCompatibility(unittest.TestCase):
40
+ """Test JAX version-dependent imports and compatibility."""
41
+
42
+ def setUp(self):
43
+ """Set up test environment."""
44
+ self.original_version = jax.__version_info__
45
+
46
+ def tearDown(self):
47
+ """Clean up after tests."""
48
+ # Restore original version info
49
+ jax.__version_info__ = self.original_version
50
+
51
+ def test_device_import_compatibility(self):
52
+ """Test Device import works across JAX versions."""
53
+ # Test that Device is available and importable
54
+ self.assertTrue(hasattr(compat, 'Device'))
55
+ self.assertIsNotNone(compat.Device)
56
+
57
+ # Test Device can be used for type checking
58
+ device = jax.devices()[0]
59
+ self.assertIsInstance(device, compat.Device)
60
+
61
+ def test_core_imports_availability(self):
62
+ """Test core JAX imports are available."""
63
+ # Core types should be available
64
+ core_types = [
65
+ 'ClosedJaxpr', 'Primitive', 'Var', 'JaxprEqn',
66
+ 'Jaxpr', 'Literal', 'Tracer'
67
+ ]
68
+
69
+ for type_name in core_types:
70
+ self.assertTrue(hasattr(compat, type_name),
71
+ f"{type_name} should be available")
72
+ self.assertIsNotNone(getattr(compat, type_name))
73
+
74
+ def test_function_imports_availability(self):
75
+ """Test function imports are available."""
76
+ functions = [
77
+ 'jaxpr_as_fun', 'get_aval', 'to_concrete_aval',
78
+ 'extend_axis_env_nd'
79
+ ]
80
+
81
+ for func_name in functions:
82
+ self.assertTrue(hasattr(compat, func_name),
83
+ f"{func_name} should be available")
84
+ self.assertTrue(callable(getattr(compat, func_name)),
85
+ f"{func_name} should be callable")
86
+
87
+ def test_extend_axis_env_nd_functionality(self):
88
+ """Test extend_axis_env_nd context manager."""
89
+ # Test basic functionality
90
+ with compat.extend_axis_env_nd([('test_axis', 10)]):
91
+ # Context should execute without error
92
+ pass
93
+
94
+ # Test with multiple axes
95
+ with compat.extend_axis_env_nd([('batch', 32), ('seq', 128)]):
96
+ pass
97
+
98
+ # Test with empty axes
99
+ with compat.extend_axis_env_nd([]):
100
+ pass
101
+
102
def test_get_aval_functionality(self):
    """get_aval returns an abstract value for both arrays and scalars."""
    samples = (
        jnp.array([1, 2, 3]),   # a one-dimensional JAX array
        jnp.float32(3.14),      # a JAX scalar
    )
    for value in samples:
        self.assertIsNotNone(compat.get_aval(value))
113
+
114
def test_to_concrete_aval_functionality(self):
    """to_concrete_aval returns a non-None result for a concrete JAX array."""
    arr = jnp.array([1, 2, 3])
    result = compat.to_concrete_aval(arr)
    self.assertIsNotNone(result)
125
+
126
+
127
class TestUtilityFunctions(unittest.TestCase):
    """Exercise the safe_map, safe_zip and unzip2 helpers from compat."""

    # --- safe_map ---------------------------------------------------------

    def test_safe_map_basic(self):
        """safe_map applies unary and multi-argument callables element-wise."""
        doubled = compat.safe_map(lambda v: v * 2, [1, 2, 3])
        self.assertEqual(doubled, [2, 4, 6])

        summed = compat.safe_map(lambda a, b: a + b, [1, 2, 3], [4, 5, 6])
        self.assertEqual(summed, [5, 7, 9])

        upper = compat.safe_map(str.upper, ['a', 'b', 'c'])
        self.assertEqual(upper, ['A', 'B', 'C'])

    def test_safe_map_empty_inputs(self):
        """safe_map over empty sequences yields an empty list."""
        self.assertEqual(compat.safe_map(lambda v: v, []), [])
        self.assertEqual(compat.safe_map(lambda a, b: a + b, [], []), [])

    def test_safe_map_length_mismatch(self):
        """Sequences of different lengths make safe_map fail loudly."""
        with self.assertRaisesRegex(AssertionError, 'length mismatch'):
            compat.safe_map(lambda a, b: a + b, [1, 2, 3], [4, 5])

    def test_safe_map_complex_functions(self):
        """safe_map forwards arbitrary return values, including tuples."""
        squares_plus_one = compat.safe_map(lambda v: v ** 2 + 1, [1, 2, 3])
        self.assertEqual(squares_plus_one, [2, 5, 10])

        paired = compat.safe_map(lambda a, b: (a, b), [1, 2], ['a', 'b'])
        self.assertEqual(paired, [(1, 'a'), (2, 'b')])

    # --- safe_zip ---------------------------------------------------------

    def test_safe_zip_basic(self):
        """safe_zip combines two or more equal-length sequences."""
        self.assertEqual(compat.safe_zip([1, 2, 3], ['a', 'b', 'c']),
                         [(1, 'a'), (2, 'b'), (3, 'c')])
        self.assertEqual(compat.safe_zip([1, 2], [3, 4], [5, 6]),
                         [(1, 3, 5), (2, 4, 6)])

    def test_safe_zip_empty_inputs(self):
        """safe_zip over empty sequences yields an empty list."""
        for empties in ([[], []], [[], [], []]):
            self.assertEqual(compat.safe_zip(*empties), [])

    def test_safe_zip_length_mismatch(self):
        """Sequences of different lengths make safe_zip fail loudly."""
        with self.assertRaisesRegex(AssertionError, 'length mismatch'):
            compat.safe_zip([1, 2, 3], [4, 5])

    def test_safe_zip_single_sequence(self):
        """A single sequence is zipped into one-element tuples."""
        self.assertEqual(compat.safe_zip([1, 2, 3]), [(1,), (2,), (3,)])

    def test_safe_zip_mixed_types(self):
        """safe_zip does not care about the element types."""
        zipped = compat.safe_zip([1, 2], ['a', 'b'], [True, False])
        self.assertEqual(zipped, [(1, 'a', True), (2, 'b', False)])

    # --- unzip2 -----------------------------------------------------------

    def test_unzip2_basic(self):
        """unzip2 splits pairs into a tuple of firsts and a tuple of seconds."""
        numbers, letters = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])
        self.assertEqual(numbers, (1, 2, 3))
        self.assertEqual(letters, ('a', 'b', 'c'))

    def test_unzip2_empty(self):
        """unzip2 of nothing gives two empty tuples."""
        left, right = compat.unzip2([])
        self.assertEqual(left, ())
        self.assertEqual(right, ())

    def test_unzip2_single_pair(self):
        """A single pair produces two one-element tuples."""
        left, right = compat.unzip2([(42, 'answer')])
        self.assertEqual(left, (42,))
        self.assertEqual(right, ('answer',))

    def test_unzip2_mixed_types(self):
        """unzip2 keeps heterogeneous values, including None, unchanged."""
        left, right = compat.unzip2([(1, 'a'), (2.5, 'b'), (None, 'c')])
        self.assertEqual(left, (1, 2.5, None))
        self.assertEqual(right, ('a', 'b', 'c'))

    def test_unzip2_return_types(self):
        """Both halves returned by unzip2 are tuples."""
        first, second = compat.unzip2([(1, 'a'), (2, 'b')])
        for half in (first, second):
            self.assertIsInstance(half, tuple)

    def test_unzip2_with_generator(self):
        """unzip2 accepts a one-shot generator of pairs."""
        source = [(1, 'a'), (2, 'b'), (3, 'c')]
        left, right = compat.unzip2(pair for pair in source)
        self.assertEqual(left, (1, 2, 3))
        self.assertEqual(right, ('a', 'b', 'c'))
263
+
264
+
265
class TestFunctionWrapping(unittest.TestCase):
    """Test function wrapping and metadata handling (fun_name / wraps)."""

    def test_fun_name_basic(self):
        """Test fun_name function with regular functions."""

        def test_function():
            """Test function docstring."""
            pass

        name = compat.fun_name(test_function)
        self.assertEqual(name, 'test_function')

    def test_fun_name_lambda(self):
        """Test fun_name with lambda functions."""
        name = compat.fun_name(lambda x: x * 2)
        self.assertEqual(name, '<lambda>')

    def test_fun_name_partial(self):
        """Test fun_name with partial functions."""

        def original_function(x, y):
            return x + y

        partial_func = partial(original_function, 10)
        name = compat.fun_name(partial_func)
        self.assertEqual(name, 'original_function')

    def test_fun_name_nested_partial(self):
        """Test fun_name with nested partial functions."""

        def base_function(x, y, z):
            return x + y + z

        partial1 = partial(base_function, 1)
        partial2 = partial(partial1, 2)

        name = compat.fun_name(partial2)
        self.assertEqual(name, 'base_function')

    def test_fun_name_no_name_attribute(self):
        """Test fun_name with objects without __name__."""

        class CallableClass:
            def __call__(self):
                pass

        name = compat.fun_name(CallableClass())
        self.assertEqual(name, '<unnamed function>')

    def test_wraps_basic(self):
        """Test basic wraps functionality."""

        def original_function():
            """Original function docstring."""
            return 42

        @compat.wraps(original_function)
        def wrapper():
            return original_function()

        self.assertEqual(wrapper.__name__, 'original_function')
        self.assertEqual(wrapper.__doc__, 'Original function docstring.')
        # __wrapped__ must be the very same object, not merely an equal one.
        self.assertIs(wrapper.__wrapped__, original_function)

    def test_wraps_with_namestr(self):
        """Test wraps with custom name string."""

        def original_function():
            pass

        @compat.wraps(original_function, namestr="wrapped_{fun}")
        def wrapper():
            pass

        self.assertEqual(wrapper.__name__, 'wrapped_original_function')

    def test_wraps_with_docstr(self):
        """Test wraps with custom docstring."""

        def original_function():
            """Original docstring."""
            pass

        @compat.wraps(original_function, docstr="Wrapper for {fun}: {doc}")
        def wrapper():
            pass

        expected_doc = "Wrapper for original_function: Original docstring."
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_with_kwargs(self):
        """Test wraps with additional keyword arguments."""

        def original_function():
            pass

        @compat.wraps(original_function,
                      docstr="Function {fun} version {version}",
                      version="1.0")
        def wrapper():
            pass

        expected_doc = "Function original_function version 1.0"
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_preserves_annotations(self):
        """Test wraps copies the wrapped function's annotations."""

        def original_function(x: int, y: str) -> float:
            return float(x)

        # The wrapper is deliberately unannotated: if it declared the same
        # annotations itself, the assertion would pass even if wraps copied
        # nothing.
        @compat.wraps(original_function)
        def wrapper(x, y):
            return original_function(x, y)

        self.assertEqual(wrapper.__annotations__, original_function.__annotations__)

    def test_wraps_preserves_dict(self):
        """Test wraps preserves function __dict__."""

        def original_function():
            pass

        original_function.custom_attr = "test_value"
        original_function.another_attr = 42

        @compat.wraps(original_function)
        def wrapper():
            pass

        self.assertEqual(wrapper.custom_attr, "test_value")
        self.assertEqual(wrapper.another_attr, 42)

    def test_wraps_handles_exceptions(self):
        """Test wraps still returns the wrapper when reading metadata raises."""

        class ExplodingName:
            # Reading ``__name__`` on an instance raises. The previous
            # approach, assigning ``Mock(side_effect=...)`` to ``__name__``,
            # never raised: attribute access only returns the Mock, it does
            # not call it.
            @property
            def __name__(self):
                raise RuntimeError("Test exception")

            def __call__(self):
                pass

        @compat.wraps(ExplodingName())
        def wrapper():
            return 'wrapped result'

        # The decorator must hand back a usable wrapper instead of raising.
        self.assertTrue(callable(wrapper))
        self.assertEqual(wrapper(), 'wrapped result')

    def test_wraps_with_missing_attributes(self):
        """Test wraps handles missing attributes gracefully."""

        class MinimalCallable:
            pass

        @compat.wraps(MinimalCallable())
        def wrapper():
            pass

        # Should handle missing attributes without crashing.
        self.assertTrue(callable(wrapper))
436
+
437
+
438
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and boundary conditions of the compat helpers."""

    def test_safe_map_with_none_inputs(self):
        """Test safe_map behavior with None inputs."""
        result = compat.safe_map(lambda x: x if x is not None else 'default',
                                 [1, None, 3])
        self.assertEqual(result, [1, 'default', 3])

    def test_safe_map_with_zero_length(self):
        """Test safe_map with zero-length sequences."""
        result = compat.safe_map(str, [])
        self.assertEqual(result, [])

    def test_safe_zip_with_none_values(self):
        """Test safe_zip with None values."""
        result = compat.safe_zip([1, None, 3], [4, 5, None])
        expected = [(1, 4), (None, 5), (3, None)]
        self.assertEqual(result, expected)

    def test_unzip2_with_none_values(self):
        """Test unzip2 with None values."""
        first, second = compat.unzip2([(1, None), (None, 'a'), (2, 'b')])
        self.assertEqual(first, (1, None, 2))
        self.assertEqual(second, (None, 'a', 'b'))

    def test_large_sequences(self):
        """Test the utility functions on 10,000-element inputs end to end."""
        size = 10_000
        left = list(range(size))
        right = list(range(size, 2 * size))

        # safe_map over the full inputs (previously only a 100-item prefix
        # was used, so large inputs were never actually exercised).
        sums = compat.safe_map(lambda x, y: x + y, left, right)
        self.assertEqual(len(sums), size)
        self.assertEqual(sums[0], size)            # 0 + 10000
        self.assertEqual(sums[-1], 3 * size - 2)   # 9999 + 19999

        zipped = compat.safe_zip(left, right)
        self.assertEqual(len(zipped), size)
        self.assertEqual(zipped[0], (0, size))
        self.assertEqual(zipped[-1], (size - 1, 2 * size - 1))

        # unzip2 must invert safe_zip exactly.
        first, second = compat.unzip2(zipped)
        self.assertEqual(first, tuple(left))
        self.assertEqual(second, tuple(right))

    def test_function_name_edge_cases(self):
        """Test fun_name with edge cases."""
        # Built-in function
        self.assertEqual(compat.fun_name(len), 'len')

        # Method
        self.assertEqual(compat.fun_name(str.upper), 'upper')

        # Nested function
        def outer():
            def inner():
                pass

            return inner

        self.assertEqual(compat.fun_name(outer()), 'inner')

    def test_concurrent_usage(self):
        """Test the utility functions return correct results across threads."""
        import threading

        n_threads = 5
        n_iters = 100
        map_results = []
        errors = []

        def worker():
            # Assertion failures inside a thread would not fail the test,
            # so every problem is collected and checked on the main thread.
            try:
                for _ in range(n_iters):
                    map_results.append(compat.safe_map(lambda x: x * 2, [1, 2, 3]))

                for _ in range(n_iters):
                    zipped = compat.safe_zip([1, 2], [3, 4])
                    if zipped != [(1, 3), (2, 4)]:
                        raise AssertionError(f"unexpected safe_zip result: {zipped!r}")

                for _ in range(n_iters):
                    first, second = compat.unzip2([(1, 'a'), (2, 'b')])
                    if (first, second) != ((1, 2), ('a', 'b')):
                        raise AssertionError(
                            f"unexpected unzip2 result: {(first, second)!r}")
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=worker) for _ in range(n_threads)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        self.assertEqual(errors, [])
        self.assertEqual(len(map_results), n_threads * n_iters)
        self.assertTrue(all(r == [2, 4, 6] for r in map_results))
573
+
574
+
575
class TestTypeHints(unittest.TestCase):
    """Check the exported TypeVars and how the helpers treat element types."""

    def test_type_variables_defined(self):
        """compat exposes the T, T1, T2 and T3 type variables."""
        for var_name in ('T', 'T1', 'T2', 'T3'):
            self.assertTrue(hasattr(compat, var_name))

    def test_unzip2_type_preservation(self):
        """unzip2 leaves each element's type untouched."""
        ints, strs = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])
        self.assertTrue(all(isinstance(item, int) for item in ints))
        self.assertTrue(all(isinstance(item, str) for item in strs))

    def test_safe_functions_with_different_types(self):
        """safe_map and safe_zip accept heterogeneous element types."""
        stringified = compat.safe_map(str, [1, 2.5, '3'])
        self.assertEqual(stringified, ['1', '2.5', '3'])

        zipped = compat.safe_zip([1, 2], [3.14, 2.71], ['a', 'b'])
        self.assertEqual(zipped, [(1, 3.14, 'a'), (2, 2.71, 'b')])
608
+
609
+
610
class TestIntegration(unittest.TestCase):
    """Use the compat helpers together with JAX arrays and transformations."""

    def test_jax_integration(self):
        """safe_map and safe_zip work on values pulled out of JAX arrays."""
        lhs = jnp.array([1, 2, 3]).tolist()
        rhs = jnp.array([4, 5, 6]).tolist()

        self.assertEqual(compat.safe_map(lambda a, b: a + b, lhs, rhs), [5, 7, 9])
        self.assertEqual(compat.safe_zip(lhs, rhs), [(1, 4), (2, 5), (3, 6)])

    def test_with_jax_transformations(self):
        """unzip2 can be used both eagerly and inside jax.jit-traced code."""

        def split_pairs(x):
            # unzip2 runs on a Python list of (possibly traced) scalars.
            first, second = compat.unzip2([(x, x + 1), (x + 2, x + 3)])
            return jnp.array(first), jnp.array(second)

        # Check the eager call first, then the jitted one.
        for fn in (split_pairs, jax.jit(split_pairs)):
            evens, odds = fn(10.0)
            np.testing.assert_array_equal(evens, [10.0, 12.0])
            np.testing.assert_array_equal(odds, [11.0, 13.0])
648
+
649
+
650
class TestModuleStructure(unittest.TestCase):
    """Test module structure and __all__ exports."""

    def test_all_exports(self):
        """Test that __all__ contains, and the module defines, every expected export."""
        expected_exports = (
            'ClosedJaxpr', 'Primitive', 'extend_axis_env_nd', 'jaxpr_as_fun',
            'get_aval', 'Tracer', 'to_concrete_aval', 'safe_map', 'safe_zip',
            'unzip2', 'wraps', 'Device', 'wrap_init', 'Var', 'JaxprEqn',
            'Jaxpr', 'Literal',
        )

        for export in expected_exports:
            # One subTest per name so every missing export is reported.
            with self.subTest(export=export):
                self.assertIn(export, compat.__all__,
                              f"{export} should be in __all__")
                self.assertTrue(hasattr(compat, export),
                                f"{export} should be available in module")

    def test_no_unexpected_exports(self):
        """Test that no private names are exported."""
        for name in compat.__all__:
            with self.subTest(name=name):
                self.assertFalse(name.startswith('_'),
                                 f"Private name {name} should not be in __all__")

    def test_module_docstring(self):
        """Test module has proper docstring."""
        self.assertIsNotNone(compat.__doc__)
        self.assertIn('Compatibility layer', compat.__doc__)
678
+
679
+
680
if __name__ == "__main__":
    # Running this file directly executes the whole suite with per-test output.
    unittest.main(verbosity=2)