brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,681 @@
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive test suite for the _compatible_import module.
18
+
19
+ This test module provides extensive coverage of the compatibility layer
20
+ functionality, including:
21
+ - JAX version-dependent imports and compatibility
22
+ - Utility functions (safe_map, safe_zip, unzip2)
23
+ - Function wrapping and metadata handling
24
+ - Type safety and error handling
25
+ - Edge cases and boundary conditions
26
+ """
27
+
28
+ import unittest
29
+ from functools import partial
30
+ from unittest.mock import Mock
31
+
32
+ import jax
33
+ import jax.numpy as jnp
34
+ import numpy as np
35
+
36
+ from brainstate import _compatible_import as compat
37
+
38
+
39
class TestJAXVersionCompatibility(unittest.TestCase):
    """Test JAX version-dependent imports and compatibility.

    These tests only *read* from ``jax`` and ``compat``; none of them patch
    module-level attributes, so no setUp/tearDown bookkeeping is required.
    """

    def test_device_import_compatibility(self):
        """Test Device import works across JAX versions."""
        # Device must be exported and non-None.
        self.assertTrue(hasattr(compat, 'Device'))
        self.assertIsNotNone(compat.Device)

        # A real device object returned by JAX must be an instance of the
        # re-exported type, whichever JAX version is installed.
        device = jax.devices()[0]
        self.assertIsInstance(device, compat.Device)

    def test_core_imports_availability(self):
        """Test core JAX imports are available."""
        core_types = [
            'ClosedJaxpr', 'Primitive', 'Var', 'JaxprEqn',
            'Jaxpr', 'Literal', 'Tracer'
        ]

        for type_name in core_types:
            self.assertTrue(hasattr(compat, type_name),
                            f"{type_name} should be available")
            self.assertIsNotNone(getattr(compat, type_name))

    def test_function_imports_availability(self):
        """Test function imports are available."""
        functions = [
            'jaxpr_as_fun', 'get_aval', 'to_concrete_aval',
            'extend_axis_env_nd'
        ]

        for func_name in functions:
            self.assertTrue(hasattr(compat, func_name),
                            f"{func_name} should be available")
            self.assertTrue(callable(getattr(compat, func_name)),
                            f"{func_name} should be callable")

    def test_extend_axis_env_nd_functionality(self):
        """Test extend_axis_env_nd context manager."""
        # A single named axis.
        with compat.extend_axis_env_nd([('test_axis', 10)]):
            pass

        # Several named axes at once.
        with compat.extend_axis_env_nd([('batch', 32), ('seq', 128)]):
            pass

        # No axes at all must also be accepted.
        with compat.extend_axis_env_nd([]):
            pass

    def test_get_aval_functionality(self):
        """Test get_aval returns an abstract value with matching shape/dtype."""
        arr = jnp.array([1, 2, 3])
        aval = compat.get_aval(arr)
        self.assertIsNotNone(aval)
        self.assertEqual(tuple(aval.shape), (3,))
        self.assertEqual(aval.dtype, arr.dtype)

        scalar = jnp.float32(3.14)
        scalar_aval = compat.get_aval(scalar)
        self.assertIsNotNone(scalar_aval)
        self.assertEqual(tuple(scalar_aval.shape), ())
        self.assertEqual(scalar_aval.dtype, scalar.dtype)

    def test_to_concrete_aval_functionality(self):
        """Test to_concrete_aval accepts a concrete JAX array."""
        arr = jnp.array([1, 2, 3])
        result = compat.to_concrete_aval(arr)
        self.assertIsNotNone(result)
125
+
126
+
127
class TestUtilityFunctions(unittest.TestCase):
    """Behaviour of the ``safe_map``, ``safe_zip`` and ``unzip2`` helpers."""

    # -- safe_map ---------------------------------------------------------

    def test_safe_map_basic(self):
        """safe_map applies unary, binary and builtin callables element-wise."""
        doubled = compat.safe_map(lambda v: v * 2, [1, 2, 3])
        self.assertEqual(doubled, [2, 4, 6])

        summed = compat.safe_map(lambda a, b: a + b, [1, 2, 3], [4, 5, 6])
        self.assertEqual(summed, [5, 7, 9])

        upper = compat.safe_map(str.upper, ['a', 'b', 'c'])
        self.assertEqual(upper, ['A', 'B', 'C'])

    def test_safe_map_empty_inputs(self):
        """Empty inputs yield an empty list regardless of arity."""
        self.assertEqual(compat.safe_map(lambda v: v, []), [])
        self.assertEqual(compat.safe_map(lambda a, b: a + b, [], []), [])

    def test_safe_map_length_mismatch(self):
        """Sequences of unequal length are rejected with an AssertionError."""
        with self.assertRaises(AssertionError) as ctx:
            compat.safe_map(lambda a, b: a + b, [1, 2, 3], [4, 5])

        self.assertIn('length mismatch', str(ctx.exception))

    def test_safe_map_complex_functions(self):
        """Compound expressions and tuple-returning callables are supported."""
        squares_plus_one = compat.safe_map(lambda v: v ** 2 + 1, [1, 2, 3])
        self.assertEqual(squares_plus_one, [2, 5, 10])

        paired = compat.safe_map(lambda a, b: (a, b), [1, 2], ['a', 'b'])
        self.assertEqual(paired, [(1, 'a'), (2, 'b')])

    # -- safe_zip ---------------------------------------------------------

    def test_safe_zip_basic(self):
        """safe_zip pairs up two or three equal-length sequences."""
        self.assertEqual(
            compat.safe_zip([1, 2, 3], ['a', 'b', 'c']),
            [(1, 'a'), (2, 'b'), (3, 'c')],
        )
        self.assertEqual(
            compat.safe_zip([1, 2], [3, 4], [5, 6]),
            [(1, 3, 5), (2, 4, 6)],
        )

    def test_safe_zip_empty_inputs(self):
        """Zipping empty sequences yields an empty list."""
        for arity in (2, 3):
            with self.subTest(arity=arity):
                empties = [[] for _ in range(arity)]
                self.assertEqual(compat.safe_zip(*empties), [])

    def test_safe_zip_length_mismatch(self):
        """Sequences of unequal length are rejected with an AssertionError."""
        with self.assertRaises(AssertionError) as ctx:
            compat.safe_zip([1, 2, 3], [4, 5])

        self.assertIn('length mismatch', str(ctx.exception))

    def test_safe_zip_single_sequence(self):
        """A single sequence is wrapped into one-element tuples."""
        self.assertEqual(compat.safe_zip([1, 2, 3]), [(1,), (2,), (3,)])

    def test_safe_zip_mixed_types(self):
        """Heterogeneous element types are zipped as-is."""
        self.assertEqual(
            compat.safe_zip([1, 2], ['a', 'b'], [True, False]),
            [(1, 'a', True), (2, 'b', False)],
        )

    # -- unzip2 -----------------------------------------------------------

    def test_unzip2_basic(self):
        """unzip2 splits pairs into two parallel tuples."""
        lefts, rights = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])

        self.assertEqual(lefts, (1, 2, 3))
        self.assertEqual(rights, ('a', 'b', 'c'))

    def test_unzip2_empty(self):
        """An empty input produces two empty tuples."""
        lefts, rights = compat.unzip2([])
        self.assertEqual(lefts, ())
        self.assertEqual(rights, ())

    def test_unzip2_single_pair(self):
        """A single pair produces two one-element tuples."""
        lefts, rights = compat.unzip2([(42, 'answer')])
        self.assertEqual(lefts, (42,))
        self.assertEqual(rights, ('answer',))

    def test_unzip2_mixed_types(self):
        """Values of any type (including None) pass through untouched."""
        lefts, rights = compat.unzip2([(1, 'a'), (2.5, 'b'), (None, 'c')])

        self.assertEqual(lefts, (1, 2.5, None))
        self.assertEqual(rights, ('a', 'b', 'c'))

    def test_unzip2_return_types(self):
        """Both halves are returned as tuples."""
        lefts, rights = compat.unzip2([(1, 'a'), (2, 'b')])

        self.assertIsInstance(lefts, tuple)
        self.assertIsInstance(rights, tuple)

    def test_unzip2_with_generator(self):
        """unzip2 consumes arbitrary iterables, not only sequences."""

        def gen_pairs():
            for item in ((1, 'a'), (2, 'b'), (3, 'c')):
                yield item

        lefts, rights = compat.unzip2(gen_pairs())
        self.assertEqual(lefts, (1, 2, 3))
        self.assertEqual(rights, ('a', 'b', 'c'))
263
+
264
+
265
class TestFunctionWrapping(unittest.TestCase):
    """Test function wrapping (``wraps``) and name lookup (``fun_name``)."""

    def test_fun_name_basic(self):
        """Test fun_name function with regular functions."""

        def test_function():
            """Test function docstring."""
            pass

        name = compat.fun_name(test_function)
        self.assertEqual(name, 'test_function')

    def test_fun_name_lambda(self):
        """Test fun_name with lambda functions."""
        lambda_func = lambda x: x * 2
        name = compat.fun_name(lambda_func)
        self.assertEqual(name, '<lambda>')

    def test_fun_name_partial(self):
        """Test fun_name unwraps functools.partial objects."""

        def original_function(x, y):
            return x + y

        partial_func = partial(original_function, 10)
        name = compat.fun_name(partial_func)
        self.assertEqual(name, 'original_function')

    def test_fun_name_nested_partial(self):
        """Test fun_name unwraps partials nested inside partials."""

        def base_function(x, y, z):
            return x + y + z

        partial1 = partial(base_function, 1)
        partial2 = partial(partial1, 2)

        name = compat.fun_name(partial2)
        self.assertEqual(name, 'base_function')

    def test_fun_name_no_name_attribute(self):
        """Test fun_name with objects without __name__."""

        class CallableClass:
            def __call__(self):
                pass

        callable_obj = CallableClass()
        name = compat.fun_name(callable_obj)
        self.assertEqual(name, '<unnamed function>')

    def test_wraps_basic(self):
        """Test basic wraps functionality."""

        def original_function():
            """Original function docstring."""
            return 42

        @compat.wraps(original_function)
        def wrapper():
            return original_function()

        self.assertEqual(wrapper.__name__, 'original_function')
        self.assertEqual(wrapper.__doc__, 'Original function docstring.')
        self.assertEqual(wrapper.__wrapped__, original_function)

    def test_wraps_with_namestr(self):
        """Test wraps with custom name string."""

        def original_function():
            pass

        @compat.wraps(original_function, namestr="wrapped_{fun}")
        def wrapper():
            pass

        self.assertEqual(wrapper.__name__, 'wrapped_original_function')

    def test_wraps_with_docstr(self):
        """Test wraps with custom docstring."""

        def original_function():
            """Original docstring."""
            pass

        @compat.wraps(original_function, docstr="Wrapper for {fun}: {doc}")
        def wrapper():
            pass

        expected_doc = "Wrapper for original_function: Original docstring."
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_with_kwargs(self):
        """Test wraps forwards extra keyword arguments to the docstr format."""

        def original_function():
            pass

        @compat.wraps(original_function,
                      docstr="Function {fun} version {version}",
                      version="1.0")
        def wrapper():
            pass

        expected_doc = "Function original_function version 1.0"
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_preserves_annotations(self):
        """Test wraps preserves function annotations."""

        def original_function(x: int, y: str) -> float:
            return float(x)

        @compat.wraps(original_function)
        def wrapper(x: int, y: str) -> float:
            return original_function(x, y)

        self.assertEqual(wrapper.__annotations__, original_function.__annotations__)

    def test_wraps_preserves_dict(self):
        """Test wraps copies the wrapped function's __dict__."""

        def original_function():
            pass

        original_function.custom_attr = "test_value"
        original_function.another_attr = 42

        @compat.wraps(original_function)
        def wrapper():
            pass

        self.assertEqual(wrapper.custom_attr, "test_value")
        self.assertEqual(wrapper.another_attr, 42)

    def test_wraps_handles_exceptions(self):
        """Test wraps still returns a usable wrapper when metadata lookup raises."""

        class ExplodingName:
            """Callable whose ``__name__`` lookup raises.

            A Mock with ``__name__ = Mock(side_effect=...)`` would not do:
            ``side_effect`` only fires when the attribute is *called*, not
            when it is read, so the failure path would never be exercised.
            """

            def __call__(self):
                return None

            @property
            def __name__(self):
                raise RuntimeError("Test exception")

        exploding = ExplodingName()
        # Sanity check: reading __name__ really does raise.
        with self.assertRaises(RuntimeError):
            getattr(exploding, '__name__')

        @compat.wraps(exploding)
        def wrapper():
            return 'ok'

        # The exception must not escape, and the wrapper must stay callable.
        self.assertTrue(callable(wrapper))
        self.assertEqual(wrapper(), 'ok')

    def test_wraps_with_missing_attributes(self):
        """Test wraps handles missing attributes gracefully."""

        class MinimalCallable:
            pass

        minimal_func = MinimalCallable()

        @compat.wraps(minimal_func)
        def wrapper():
            pass

        # Missing __name__/__doc__/etc. must not crash the decorator.
        self.assertTrue(callable(wrapper))
436
+
437
+
438
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and boundary conditions of the utility helpers."""

    def test_safe_map_with_none_inputs(self):
        """Test safe_map passes None elements through to the callable."""
        result = compat.safe_map(lambda x: x if x is not None else 'default',
                                 [1, None, 3])
        self.assertEqual(result, [1, 'default', 3])

    def test_safe_map_with_zero_length(self):
        """Test safe_map with zero-length sequences."""
        result = compat.safe_map(str, [])
        self.assertEqual(result, [])

    def test_safe_zip_with_none_values(self):
        """Test safe_zip with None values."""
        result = compat.safe_zip([1, None, 3], [4, 5, None])
        expected = [(1, 4), (None, 5), (3, None)]
        self.assertEqual(result, expected)

    def test_unzip2_with_none_values(self):
        """Test unzip2 with None values."""
        pairs = [(1, None), (None, 'a'), (2, 'b')]
        first, second = compat.unzip2(pairs)

        self.assertEqual(first, (1, None, 2))
        self.assertEqual(second, (None, 'a', 'b'))

    def test_large_sequences(self):
        """Test utility functions agree with builtins on longer sequences."""
        size = 100
        left = list(range(size))
        right = list(range(10000, 10000 + size))

        # safe_map: element-wise sum must match a plain comprehension.
        mapped = compat.safe_map(lambda x, y: x + y, left, right)
        self.assertEqual(mapped, [x + y for x, y in zip(left, right)])

        # safe_zip: must match the builtin zip exactly.
        zipped = compat.safe_zip(left, right)
        self.assertEqual(zipped, list(zip(left, right)))

        # unzip2: must invert the zip.
        first, second = compat.unzip2(zipped)
        self.assertEqual(first, tuple(left))
        self.assertEqual(second, tuple(right))

    def test_function_name_edge_cases(self):
        """Test fun_name with builtins, methods and nested functions."""
        # Built-in function
        name = compat.fun_name(len)
        self.assertEqual(name, 'len')

        # Method
        name = compat.fun_name(str.upper)
        self.assertEqual(name, 'upper')

        # Nested function
        def outer():
            def inner():
                pass

            return inner

        inner_func = outer()
        name = compat.fun_name(inner_func)
        self.assertEqual(name, 'inner')

    def test_concurrent_usage(self):
        """Test utility functions return correct results from many threads."""
        import threading

        num_threads = 5
        iterations = 100
        expected = ([2, 4, 6], [(1, 3), (2, 4)], ((1, 2), ('a', 'b')))

        results = []
        errors = []

        def worker():
            try:
                for _ in range(iterations):
                    mapped = compat.safe_map(lambda x: x * 2, [1, 2, 3])
                    zipped = compat.safe_zip([1, 2], [3, 4])
                    unzipped = compat.unzip2([(1, 'a'), (2, 'b')])
                    # list.append is atomic in CPython, so no lock is needed.
                    results.append((mapped, zipped, unzipped))
            except Exception as e:  # recorded and asserted on below
                errors.append(e)

        threads = [threading.Thread(target=worker) for _ in range(num_threads)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Comparing against [] shows the captured exceptions on failure.
        self.assertEqual(errors, [])
        # Every call from every thread must have produced the right value.
        self.assertEqual(results, [expected] * (num_threads * iterations))
573
+
574
+
575
class TestTypeHints(unittest.TestCase):
    """Check the module's generic type variables and type-preserving helpers."""

    def test_type_variables_defined(self):
        """The T, T1, T2 and T3 type variables are exposed by the module."""
        for type_var_name in ('T', 'T1', 'T2', 'T3'):
            self.assertTrue(hasattr(compat, type_var_name))

    def test_unzip2_type_preservation(self):
        """unzip2 keeps the element type of each column intact."""
        numbers, letters = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])

        self.assertTrue(all(isinstance(value, int) for value in numbers))
        self.assertTrue(all(isinstance(value, str) for value in letters))

    def test_safe_functions_with_different_types(self):
        """safe_map and safe_zip accept heterogeneous element types."""
        as_strings = compat.safe_map(str, [1, 2.5, '3'])
        self.assertEqual(as_strings, ['1', '2.5', '3'])

        triples = compat.safe_zip([1, 2], [3.14, 2.71], ['a', 'b'])
        self.assertEqual(triples, [(1, 3.14, 'a'), (2, 2.71, 'b')])
608
+
609
+
610
class TestIntegration(unittest.TestCase):
    """Exercise the compatibility helpers alongside real JAX values."""

    def test_jax_integration(self):
        """safe_map / safe_zip work on lists converted from JAX arrays."""
        xs = jnp.array([1, 2, 3]).tolist()
        ys = jnp.array([4, 5, 6]).tolist()

        self.assertEqual(compat.safe_map(lambda a, b: a + b, xs, ys), [5, 7, 9])
        self.assertEqual(compat.safe_zip(xs, ys), [(1, 4), (2, 5), (3, 6)])

    def test_with_jax_transformations(self):
        """unzip2 behaves the same eagerly and when traced by jax.jit."""

        def split_pairs(x):
            # Python-level unzip2 runs at trace time under jit.
            lefts, rights = compat.unzip2([(x, x + 1), (x + 2, x + 3)])
            return jnp.array(lefts), jnp.array(rights)

        # Eager call first, then the jitted version.
        for fn in (split_pairs, jax.jit(split_pairs)):
            lefts, rights = fn(10.0)
            np.testing.assert_array_equal(lefts, [10.0, 12.0])
            np.testing.assert_array_equal(rights, [11.0, 13.0])
648
+
649
+
650
class TestModuleStructure(unittest.TestCase):
    """Verify the public surface declared by the module's ``__all__``."""

    def test_all_exports(self):
        """Every expected public name is listed in __all__ and defined."""
        required = (
            'ClosedJaxpr', 'Primitive', 'extend_axis_env_nd', 'jaxpr_as_fun',
            'get_aval', 'Tracer', 'to_concrete_aval', 'safe_map', 'safe_zip',
            'unzip2', 'wraps', 'Device', 'wrap_init', 'Var', 'JaxprEqn',
            'Jaxpr', 'Literal',
        )
        public = compat.__all__

        for symbol in required:
            self.assertIn(symbol, public,
                          f"{symbol} should be in __all__")
            self.assertTrue(hasattr(compat, symbol),
                            f"{symbol} should be available in module")

    def test_no_unexpected_exports(self):
        """__all__ must not leak underscore-prefixed (private) names."""
        for symbol in compat.__all__:
            is_private = symbol.startswith('_')
            self.assertFalse(is_private,
                             f"Private name {symbol} should not be in __all__")

    def test_module_docstring(self):
        """The module carries a descriptive docstring."""
        module_doc = compat.__doc__
        self.assertIsNotNone(module_doc)
        self.assertIn('Compatibility layer', module_doc)
678
+
679
+
680
if __name__ == "__main__":
    # Running this file directly executes the whole suite with per-test output.
    unittest.main(verbosity=2)