brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,681 @@
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive test suite for the _compatible_import module.
18
+
19
+ This test module provides extensive coverage of the compatibility layer
20
+ functionality, including:
21
+ - JAX version-dependent imports and compatibility
22
+ - Utility functions (safe_map, safe_zip, unzip2)
23
+ - Function wrapping and metadata handling
24
+ - Type safety and error handling
25
+ - Edge cases and boundary conditions
26
+ """
27
+
28
+ import unittest
29
+ from functools import partial
30
+ from unittest.mock import Mock
31
+
32
+ import jax
33
+ import jax.numpy as jnp
34
+ import numpy as np
35
+
36
+ from brainstate import _compatible_import as compat
37
+
38
+
class TestJAXVersionCompatibility(unittest.TestCase):
    """Test JAX version-dependent imports and compatibility.

    These tests only read from ``compat`` and ``jax``; none of them mutates
    JAX's module state, so no per-test setup/teardown is required.
    """

    def test_device_import_compatibility(self):
        """Test Device import works across JAX versions."""
        # Device must be exported and non-None.
        self.assertTrue(hasattr(compat, 'Device'))
        self.assertIsNotNone(compat.Device)

        # Device must be usable as a runtime type for isinstance checks.
        device = jax.devices()[0]
        self.assertIsInstance(device, compat.Device)

    def test_core_imports_availability(self):
        """Test core JAX imports are available."""
        core_types = [
            'ClosedJaxpr', 'Primitive', 'Var', 'JaxprEqn',
            'Jaxpr', 'Literal', 'Tracer',
        ]

        for type_name in core_types:
            self.assertTrue(hasattr(compat, type_name),
                            f"{type_name} should be available")
            self.assertIsNotNone(getattr(compat, type_name))

    def test_function_imports_availability(self):
        """Test function imports are available and callable."""
        functions = [
            'jaxpr_as_fun', 'get_aval', 'to_concrete_aval',
            'extend_axis_env_nd',
        ]

        for func_name in functions:
            self.assertTrue(hasattr(compat, func_name),
                            f"{func_name} should be available")
            self.assertTrue(callable(getattr(compat, func_name)),
                            f"{func_name} should be callable")

    def test_extend_axis_env_nd_functionality(self):
        """Test extend_axis_env_nd works as a context manager."""
        # Single named axis: entering and leaving must not raise.
        with compat.extend_axis_env_nd([('test_axis', 10)]):
            pass

        # Multiple named axes.
        with compat.extend_axis_env_nd([('batch', 32), ('seq', 128)]):
            pass

        # No axes at all.
        with compat.extend_axis_env_nd([]):
            pass

    def test_get_aval_functionality(self):
        """Test get_aval returns an abstract value for arrays and scalars."""
        arr = jnp.array([1, 2, 3])
        aval = compat.get_aval(arr)
        self.assertIsNotNone(aval)

        scalar = jnp.float32(3.14)
        scalar_aval = compat.get_aval(scalar)
        self.assertIsNotNone(scalar_aval)

    def test_to_concrete_aval_functionality(self):
        """Test to_concrete_aval on a concrete JAX array."""
        arr = jnp.array([1, 2, 3])
        result = compat.to_concrete_aval(arr)
        self.assertIsNotNone(result)


class TestUtilityFunctions(unittest.TestCase):
    """Behavioural tests for the sequence helpers ``safe_map``, ``safe_zip`` and ``unzip2``.

    ``safe_map`` and ``safe_zip`` are expected to produce lists and to reject
    inputs of unequal length with an ``AssertionError`` whose message contains
    ``'length mismatch'``. ``unzip2`` is expected to split pairs into two tuples.
    """

    # ---- safe_map -------------------------------------------------------

    def test_safe_map_basic(self):
        """safe_map applies unary, binary and unbound-method callables element-wise."""
        doubled = compat.safe_map(lambda x: x * 2, [1, 2, 3])
        self.assertEqual(doubled, [2, 4, 6])

        summed = compat.safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5, 6])
        self.assertEqual(summed, [5, 7, 9])

        shouted = compat.safe_map(str.upper, ['a', 'b', 'c'])
        self.assertEqual(shouted, ['A', 'B', 'C'])

    def test_safe_map_empty_inputs(self):
        """Empty inputs map to an empty list for any arity."""
        self.assertEqual(compat.safe_map(lambda x: x, []), [])
        self.assertEqual(compat.safe_map(lambda x, y: x + y, [], []), [])

    def test_safe_map_length_mismatch(self):
        """Unequal input lengths are rejected with an AssertionError."""
        with self.assertRaises(AssertionError) as raised:
            compat.safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5])
        self.assertIn('length mismatch', str(raised.exception))

    def test_safe_map_complex_functions(self):
        """Arbitrary expressions and tuple-returning callables are supported."""
        polynomial = compat.safe_map(lambda x: x ** 2 + 1, [1, 2, 3])
        self.assertEqual(polynomial, [2, 5, 10])

        tupled = compat.safe_map(lambda x, y: (x, y), [1, 2], ['a', 'b'])
        self.assertEqual(tupled, [(1, 'a'), (2, 'b')])

    # ---- safe_zip -------------------------------------------------------

    def test_safe_zip_basic(self):
        """safe_zip groups two or three sequences position-wise."""
        table = [
            (([1, 2, 3], ['a', 'b', 'c']), [(1, 'a'), (2, 'b'), (3, 'c')]),
            (([1, 2], [3, 4], [5, 6]), [(1, 3, 5), (2, 4, 6)]),
        ]
        for sequences, expected in table:
            self.assertEqual(compat.safe_zip(*sequences), expected)

    def test_safe_zip_empty_inputs(self):
        """Zipping empty sequences yields an empty list."""
        self.assertEqual(compat.safe_zip([], []), [])
        self.assertEqual(compat.safe_zip([], [], []), [])

    def test_safe_zip_length_mismatch(self):
        """Unequal input lengths are rejected with an AssertionError."""
        with self.assertRaises(AssertionError) as raised:
            compat.safe_zip([1, 2, 3], [4, 5])
        self.assertIn('length mismatch', str(raised.exception))

    def test_safe_zip_single_sequence(self):
        """A single sequence is wrapped into one-element tuples."""
        self.assertEqual(compat.safe_zip([1, 2, 3]), [(1,), (2,), (3,)])

    def test_safe_zip_mixed_types(self):
        """Element types are passed through untouched."""
        zipped = compat.safe_zip([1, 2], ['a', 'b'], [True, False])
        self.assertEqual(zipped, [(1, 'a', True), (2, 'b', False)])

    # ---- unzip2 ---------------------------------------------------------

    def test_unzip2_basic(self):
        """Pairs are split into a tuple of firsts and a tuple of seconds."""
        lefts, rights = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])
        self.assertEqual(lefts, (1, 2, 3))
        self.assertEqual(rights, ('a', 'b', 'c'))

    def test_unzip2_empty(self):
        """An empty input yields two empty tuples."""
        lefts, rights = compat.unzip2([])
        self.assertEqual(lefts, ())
        self.assertEqual(rights, ())

    def test_unzip2_single_pair(self):
        """A single pair yields two one-element tuples."""
        lefts, rights = compat.unzip2([(42, 'answer')])
        self.assertEqual(lefts, (42,))
        self.assertEqual(rights, ('answer',))

    def test_unzip2_mixed_types(self):
        """Heterogeneous first elements (including None) are preserved."""
        lefts, rights = compat.unzip2([(1, 'a'), (2.5, 'b'), (None, 'c')])
        self.assertEqual(lefts, (1, 2.5, None))
        self.assertEqual(rights, ('a', 'b', 'c'))

    def test_unzip2_return_types(self):
        """Both halves are returned as tuples."""
        lefts, rights = compat.unzip2([(1, 'a'), (2, 'b')])
        for half in (lefts, rights):
            self.assertIsInstance(half, tuple)

    def test_unzip2_with_generator(self):
        """A one-shot generator is accepted as input."""
        pairs = (pair for pair in [(1, 'a'), (2, 'b'), (3, 'c')])
        lefts, rights = compat.unzip2(pairs)
        self.assertEqual(lefts, (1, 2, 3))
        self.assertEqual(rights, ('a', 'b', 'c'))


class TestFunctionWrapping(unittest.TestCase):
    """Test function naming (``fun_name``) and metadata copying (``wraps``)."""

    def test_fun_name_basic(self):
        """Test fun_name function with regular functions."""

        def test_function():
            """Test function docstring."""
            pass

        self.assertEqual(compat.fun_name(test_function), 'test_function')

    def test_fun_name_lambda(self):
        """Test fun_name with lambda functions."""
        lambda_func = lambda x: x * 2
        self.assertEqual(compat.fun_name(lambda_func), '<lambda>')

    def test_fun_name_partial(self):
        """Test fun_name unwraps a partial to the underlying function's name."""

        def original_function(x, y):
            return x + y

        partial_func = partial(original_function, 10)
        self.assertEqual(compat.fun_name(partial_func), 'original_function')

    def test_fun_name_nested_partial(self):
        """Test fun_name unwraps nested partials."""

        def base_function(x, y, z):
            return x + y + z

        partial1 = partial(base_function, 1)
        partial2 = partial(partial1, 2)

        self.assertEqual(compat.fun_name(partial2), 'base_function')

    def test_fun_name_no_name_attribute(self):
        """Test fun_name with a callable object that has no __name__."""

        class CallableClass:
            def __call__(self):
                pass

        callable_obj = CallableClass()
        self.assertEqual(compat.fun_name(callable_obj), '<unnamed function>')

    def test_wraps_basic(self):
        """Test wraps copies name, docstring and __wrapped__."""

        def original_function():
            """Original function docstring."""
            return 42

        @compat.wraps(original_function)
        def wrapper():
            return original_function()

        self.assertEqual(wrapper.__name__, 'original_function')
        self.assertEqual(wrapper.__doc__, 'Original function docstring.')
        self.assertEqual(wrapper.__wrapped__, original_function)

    def test_wraps_with_namestr(self):
        """Test wraps with a custom name template."""

        def original_function():
            pass

        @compat.wraps(original_function, namestr="wrapped_{fun}")
        def wrapper():
            pass

        self.assertEqual(wrapper.__name__, 'wrapped_original_function')

    def test_wraps_with_docstr(self):
        """Test wraps with a custom docstring template."""

        def original_function():
            """Original docstring."""
            pass

        @compat.wraps(original_function, docstr="Wrapper for {fun}: {doc}")
        def wrapper():
            pass

        expected_doc = "Wrapper for original_function: Original docstring."
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_with_kwargs(self):
        """Test extra keyword arguments are forwarded to the docstring template."""

        def original_function():
            pass

        @compat.wraps(original_function,
                      docstr="Function {fun} version {version}",
                      version="1.0")
        def wrapper():
            pass

        expected_doc = "Function original_function version 1.0"
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_preserves_annotations(self):
        """Test wraps preserves function annotations."""

        def original_function(x: int, y: str) -> float:
            return float(x)

        @compat.wraps(original_function)
        def wrapper(x: int, y: str) -> float:
            return original_function(x, y)

        self.assertEqual(wrapper.__annotations__, original_function.__annotations__)

    def test_wraps_preserves_dict(self):
        """Test wraps copies custom attributes from the wrapped function's __dict__."""

        def original_function():
            pass

        original_function.custom_attr = "test_value"
        original_function.another_attr = 42

        @compat.wraps(original_function)
        def wrapper():
            pass

        self.assertEqual(wrapper.custom_attr, "test_value")
        self.assertEqual(wrapper.another_attr, 42)

    def test_wraps_handles_exceptions(self):
        """Test wraps tolerates errors raised while reading the wrapped object's metadata."""

        class ExplodingName:
            """Callable whose ``__name__`` lookup raises instead of returning a string."""

            def __call__(self):
                pass

            # A data descriptor in the class dict takes precedence on instance
            # lookup, so ``instance.__name__`` runs this getter and raises.
            @property
            def __name__(self):
                raise RuntimeError("Test exception")

        exploding = ExplodingName()

        # Sanity-check the fixture: reading __name__ must really raise, otherwise
        # the assertion below would pass without exercising the error path.
        # (A ``Mock(side_effect=...)`` attribute would not do this: side_effect
        # only fires when the mock is *called*, not when it is read.)
        with self.assertRaises(RuntimeError):
            getattr(exploding, '__name__')

        @compat.wraps(exploding)
        def wrapper():
            pass

        # Decoration must not propagate the error.
        self.assertTrue(callable(wrapper))

    def test_wraps_with_missing_attributes(self):
        """Test wraps handles a wrapped object lacking function metadata."""

        class MinimalCallable:
            pass

        minimal_func = MinimalCallable()

        @compat.wraps(minimal_func)
        def wrapper():
            pass

        # Should handle missing attributes without crashing.
        self.assertTrue(callable(wrapper))


class TestEdgeCases(unittest.TestCase):
    """Test edge cases and boundary conditions of the utility helpers."""

    def test_safe_map_with_none_inputs(self):
        """Test safe_map passes None elements through to the mapped function."""
        result = compat.safe_map(lambda x: x if x is not None else 'default',
                                 [1, None, 3])
        self.assertEqual(result, [1, 'default', 3])

    def test_safe_map_with_zero_length(self):
        """Test safe_map with zero-length sequences."""
        self.assertEqual(compat.safe_map(str, []), [])

    def test_safe_zip_with_none_values(self):
        """Test safe_zip keeps None values in place."""
        result = compat.safe_zip([1, None, 3], [4, 5, None])
        self.assertEqual(result, [(1, 4), (None, 5), (3, None)])

    def test_unzip2_with_none_values(self):
        """Test unzip2 keeps None values in place."""
        first, second = compat.unzip2([(1, None), (None, 'a'), (2, 'b')])

        self.assertEqual(first, (1, None, 2))
        self.assertEqual(second, (None, 'a', 'b'))

    def test_large_sequences(self):
        """Test utility functions over the full length of large sequences."""
        size = 10000
        large_seq1 = list(range(size))
        large_seq2 = list(range(size, 2 * size))

        # safe_map: element i is i + (size + i).
        result = compat.safe_map(lambda x, y: x + y, large_seq1, large_seq2)
        self.assertEqual(len(result), size)
        self.assertEqual(result[0], size)            # 0 + 10000
        self.assertEqual(result[-1], 3 * size - 2)   # 9999 + 19999

        # safe_zip
        result = compat.safe_zip(large_seq1, large_seq2)
        self.assertEqual(len(result), size)
        self.assertEqual(result[0], (0, size))
        self.assertEqual(result[-1], (size - 1, 2 * size - 1))

        # unzip2 must reconstruct both original sequences exactly.
        first, second = compat.unzip2(list(zip(large_seq1, large_seq2)))
        self.assertEqual(first, tuple(large_seq1))
        self.assertEqual(second, tuple(large_seq2))

    def test_function_name_edge_cases(self):
        """Test fun_name with builtins, unbound methods and nested functions."""
        # Built-in function
        self.assertEqual(compat.fun_name(len), 'len')

        # Unbound method
        self.assertEqual(compat.fun_name(str.upper), 'upper')

        # Nested function
        def outer():
            def inner():
                pass

            return inner

        self.assertEqual(compat.fun_name(outer()), 'inner')

    def test_concurrent_usage(self):
        """Test the utility functions return correct results when called from many threads."""
        import threading

        n_threads = 5
        n_iterations = 100

        map_results = []
        zip_results = []
        unzip_results = []
        errors = []

        def worker():
            try:
                for _ in range(n_iterations):
                    map_results.append(compat.safe_map(lambda x: x * 2, [1, 2, 3]))
                for _ in range(n_iterations):
                    zip_results.append(compat.safe_zip([1, 2], [3, 4]))
                for _ in range(n_iterations):
                    unzip_results.append(tuple(compat.unzip2([(1, 'a'), (2, 'b')])))
            except Exception as e:
                # Collected so the main thread can report it; an exception in a
                # worker thread would otherwise not fail the test.
                errors.append(e)

        threads = [threading.Thread(target=worker) for _ in range(n_threads)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Comparing against [] shows the collected exceptions if any occurred.
        self.assertEqual(errors, [])

        expected_count = n_threads * n_iterations
        self.assertEqual(len(map_results), expected_count)
        self.assertEqual(len(zip_results), expected_count)
        self.assertEqual(len(unzip_results), expected_count)

        self.assertTrue(all(r == [2, 4, 6] for r in map_results))
        self.assertTrue(all(r == [(1, 3), (2, 4)] for r in zip_results))
        self.assertTrue(all(r == ((1, 2), ('a', 'b')) for r in unzip_results))


class TestTypeHints(unittest.TestCase):
    """Check the module's generic type variables and type-preserving helpers."""

    def test_type_variables_defined(self):
        """The TypeVars T, T1, T2 and T3 are exposed as module attributes."""
        for type_var_name in ('T', 'T1', 'T2', 'T3'):
            self.assertTrue(hasattr(compat, type_var_name))

    def test_unzip2_type_preservation(self):
        """unzip2 hands back the original element objects, so their types survive."""
        ints, strs = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])

        self.assertTrue(all(isinstance(value, int) for value in ints))
        self.assertTrue(all(isinstance(value, str) for value in strs))

    def test_safe_functions_with_different_types(self):
        """safe_map and safe_zip accept heterogeneous element types."""
        stringified = compat.safe_map(str, [1, 2.5, '3'])
        self.assertEqual(stringified, ['1', '2.5', '3'])

        triples = compat.safe_zip([1, 2], [3.14, 2.71], ['a', 'b'])
        self.assertEqual(triples, [(1, 3.14, 'a'), (2, 2.71, 'b')])


class TestIntegration(unittest.TestCase):
    """Exercise the helpers together with JAX arrays and transformations."""

    def test_jax_integration(self):
        """safe_map/safe_zip operate on Python lists extracted from JAX arrays."""
        lhs = jnp.array([1, 2, 3]).tolist()
        rhs = jnp.array([4, 5, 6]).tolist()

        self.assertEqual(compat.safe_map(lambda a, b: a + b, lhs, rhs), [5, 7, 9])
        self.assertEqual(compat.safe_zip(lhs, rhs), [(1, 4), (2, 5), (3, 6)])

    def test_with_jax_transformations(self):
        """unzip2 can be used inside a function that is also run under jax.jit."""

        def split_pairs(x):
            # Under jit, x is a tracer; unzip2 only regroups the Python pairs.
            pairs = [(x, x + 1), (x + 2, x + 3)]
            firsts, seconds = compat.unzip2(pairs)
            return jnp.array(firsts), jnp.array(seconds)

        # Run eagerly first, then through jax.jit; both must agree.
        for fn in (split_pairs, jax.jit(split_pairs)):
            lo, hi = fn(10.0)
            np.testing.assert_array_equal(lo, [10.0, 12.0])
            np.testing.assert_array_equal(hi, [11.0, 13.0])


class TestModuleStructure(unittest.TestCase):
    """Check the module's public surface: ``__all__`` and its docstring."""

    def test_all_exports(self):
        """Every expected public name is listed in ``__all__`` and defined."""
        public_api = compat.__all__
        for export in (
            'ClosedJaxpr', 'Primitive', 'extend_axis_env_nd', 'jaxpr_as_fun',
            'get_aval', 'Tracer', 'to_concrete_aval', 'safe_map', 'safe_zip',
            'unzip2', 'wraps', 'Device', 'wrap_init', 'Var', 'JaxprEqn',
            'Jaxpr', 'Literal',
        ):
            self.assertIn(export, public_api, f"{export} should be in __all__")
            self.assertTrue(hasattr(compat, export),
                            f"{export} should be available in module")

    def test_no_unexpected_exports(self):
        """No underscore-prefixed (private) name leaks into ``__all__``."""
        for name in compat.__all__:
            self.assertFalse(name.startswith('_'),
                             f"Private name {name} should not be in __all__")

    def test_module_docstring(self):
        """The module docstring exists and describes the compatibility layer."""
        doc = compat.__doc__
        self.assertIsNotNone(doc)
        self.assertIn('Compatibility layer', doc)


if __name__ == "__main__":
    # Allow running this test module directly, with one line of output per test.
    unittest.main(verbosity=2)