brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,681 @@
|
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Comprehensive test suite for the _compatible_import module.
|
18
|
+
|
19
|
+
This test module provides extensive coverage of the compatibility layer
|
20
|
+
functionality, including:
|
21
|
+
- JAX version-dependent imports and compatibility
|
22
|
+
- Utility functions (safe_map, safe_zip, unzip2)
|
23
|
+
- Function wrapping and metadata handling
|
24
|
+
- Type safety and error handling
|
25
|
+
- Edge cases and boundary conditions
|
26
|
+
"""
|
27
|
+
|
28
|
+
import unittest
|
29
|
+
from functools import partial
|
30
|
+
from unittest.mock import Mock
|
31
|
+
|
32
|
+
import jax
|
33
|
+
import jax.numpy as jnp
|
34
|
+
import numpy as np
|
35
|
+
|
36
|
+
from brainstate import _compatible_import as compat
|
37
|
+
|
38
|
+
|
39
|
+
class TestJAXVersionCompatibility(unittest.TestCase):
    """Test JAX version-dependent imports and compatibility."""

    def setUp(self):
        """Snapshot ``jax.__version_info__`` before each test."""
        self.original_version = jax.__version_info__

    def tearDown(self):
        """Restore the ``jax.__version_info__`` snapshot taken in ``setUp``.

        Safeguard so that a test which patches the version tuple cannot leak
        that change into later tests.
        """
        jax.__version_info__ = self.original_version

    def test_device_import_compatibility(self):
        """Test Device import works across JAX versions."""
        # Device must be exported and bound to a real object.
        self.assertTrue(hasattr(compat, 'Device'))
        self.assertIsNotNone(compat.Device)

        # A device handed out by JAX itself must be an instance of the
        # re-exported class, otherwise isinstance-based checks would break.
        device = jax.devices()[0]
        self.assertIsInstance(device, compat.Device)

    def test_core_imports_availability(self):
        """Test core JAX imports are available."""
        core_types = [
            'ClosedJaxpr', 'Primitive', 'Var', 'JaxprEqn',
            'Jaxpr', 'Literal', 'Tracer',
        ]

        for type_name in core_types:
            self.assertTrue(hasattr(compat, type_name),
                            f"{type_name} should be available")
            self.assertIsNotNone(getattr(compat, type_name))

    def test_function_imports_availability(self):
        """Test function imports are available."""
        functions = [
            'jaxpr_as_fun', 'get_aval', 'to_concrete_aval',
            'extend_axis_env_nd',
        ]

        for func_name in functions:
            self.assertTrue(hasattr(compat, func_name),
                            f"{func_name} should be available")
            self.assertTrue(callable(getattr(compat, func_name)),
                            f"{func_name} should be callable")

    def test_extend_axis_env_nd_functionality(self):
        """Test extend_axis_env_nd context manager.

        Entering and leaving the context must succeed for a single axis,
        several axes, and an empty axis list.
        """
        with compat.extend_axis_env_nd([('test_axis', 10)]):
            pass

        with compat.extend_axis_env_nd([('batch', 32), ('seq', 128)]):
            pass

        with compat.extend_axis_env_nd([]):
            pass

    def test_get_aval_functionality(self):
        """Test get_aval returns an abstract value describing its input.

        Beyond being non-None, the abstract value must report the same shape
        and dtype as the concrete array it was derived from.
        """
        # 1-D JAX array
        arr = jnp.array([1, 2, 3])
        aval = compat.get_aval(arr)
        self.assertIsNotNone(aval)
        self.assertEqual(aval.shape, arr.shape)
        self.assertEqual(aval.dtype, arr.dtype)

        # 0-d (scalar) JAX array
        scalar = jnp.float32(3.14)
        scalar_aval = compat.get_aval(scalar)
        self.assertIsNotNone(scalar_aval)
        self.assertEqual(scalar_aval.shape, ())
        self.assertEqual(scalar_aval.dtype, scalar.dtype)

    def test_to_concrete_aval_functionality(self):
        """Test to_concrete_aval accepts a concrete JAX array."""
        arr = jnp.array([1, 2, 3])
        result = compat.to_concrete_aval(arr)
        self.assertIsNotNone(result)
|
125
|
+
|
126
|
+
|
127
|
+
class TestUtilityFunctions(unittest.TestCase):
    """Behavioural checks for the ``safe_map``, ``safe_zip`` and ``unzip2`` helpers."""

    def setUp(self):
        """These tests share no fixture."""

    def tearDown(self):
        """Nothing is left behind to clean up."""

    def test_safe_map_basic(self):
        """safe_map applies unary, binary and built-in callables element-wise."""
        scenarios = (
            (lambda x: x * 2, ([1, 2, 3],), [2, 4, 6]),
            (lambda x, y: x + y, ([1, 2, 3], [4, 5, 6]), [5, 7, 9]),
            (str.upper, (['a', 'b', 'c'],), ['A', 'B', 'C']),
        )
        for fn, sequences, expected in scenarios:
            self.assertEqual(compat.safe_map(fn, *sequences), expected)

    def test_safe_map_empty_inputs(self):
        """safe_map over empty sequences yields an empty list."""
        self.assertEqual(compat.safe_map(lambda x: x, []), [])
        self.assertEqual(compat.safe_map(lambda x, y: x + y, [], []), [])

    def test_safe_map_length_mismatch(self):
        """safe_map rejects sequences of unequal length."""
        with self.assertRaises(AssertionError) as raised:
            compat.safe_map(lambda x, y: x + y, [1, 2, 3], [4, 5])
        self.assertIn('length mismatch', str(raised.exception))

    def test_safe_map_complex_functions(self):
        """safe_map works with compound expressions and tuple-returning callables."""
        self.assertEqual(compat.safe_map(lambda x: x ** 2 + 1, [1, 2, 3]), [2, 5, 10])
        self.assertEqual(
            compat.safe_map(lambda x, y: (x, y), [1, 2], ['a', 'b']),
            [(1, 'a'), (2, 'b')],
        )

    def test_safe_zip_basic(self):
        """safe_zip pairs up two or more equal-length sequences."""
        pairs = compat.safe_zip([1, 2, 3], ['a', 'b', 'c'])
        self.assertEqual(pairs, [(1, 'a'), (2, 'b'), (3, 'c')])

        triples = compat.safe_zip([1, 2], [3, 4], [5, 6])
        self.assertEqual(triples, [(1, 3, 5), (2, 4, 6)])

    def test_safe_zip_empty_inputs(self):
        """safe_zip over two or three empty sequences yields an empty list."""
        for arity in (2, 3):
            empties = [[] for _ in range(arity)]
            self.assertEqual(compat.safe_zip(*empties), [])

    def test_safe_zip_length_mismatch(self):
        """safe_zip rejects sequences of unequal length."""
        with self.assertRaises(AssertionError) as raised:
            compat.safe_zip([1, 2, 3], [4, 5])
        self.assertIn('length mismatch', str(raised.exception))

    def test_safe_zip_single_sequence(self):
        """safe_zip with one sequence wraps each element in a 1-tuple."""
        self.assertEqual(compat.safe_zip([1, 2, 3]), [(1,), (2,), (3,)])

    def test_safe_zip_mixed_types(self):
        """safe_zip is agnostic to element types."""
        self.assertEqual(
            compat.safe_zip([1, 2], ['a', 'b'], [True, False]),
            [(1, 'a', True), (2, 'b', False)],
        )

    def test_unzip2_basic(self):
        """unzip2 splits a list of pairs into two tuples."""
        numbers, letters = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])
        self.assertEqual(numbers, (1, 2, 3))
        self.assertEqual(letters, ('a', 'b', 'c'))

    def test_unzip2_empty(self):
        """unzip2 of nothing gives two empty tuples."""
        numbers, letters = compat.unzip2([])
        self.assertEqual(numbers, ())
        self.assertEqual(letters, ())

    def test_unzip2_single_pair(self):
        """unzip2 of a single pair gives two 1-tuples."""
        numbers, letters = compat.unzip2([(42, 'answer')])
        self.assertEqual(numbers, (42,))
        self.assertEqual(letters, ('answer',))

    def test_unzip2_mixed_types(self):
        """unzip2 keeps heterogeneous elements, including None, in order."""
        lefts, rights = compat.unzip2([(1, 'a'), (2.5, 'b'), (None, 'c')])
        self.assertEqual(lefts, (1, 2.5, None))
        self.assertEqual(rights, ('a', 'b', 'c'))

    def test_unzip2_return_types(self):
        """Both halves returned by unzip2 are tuples."""
        first, second = compat.unzip2([(1, 'a'), (2, 'b')])
        for half in (first, second):
            self.assertIsInstance(half, tuple)

    def test_unzip2_with_generator(self):
        """unzip2 consumes a lazily-produced sequence of pairs."""
        source = [(1, 'a'), (2, 'b'), (3, 'c')]
        lazy_pairs = (pair for pair in source)
        numbers, letters = compat.unzip2(lazy_pairs)
        self.assertEqual(numbers, (1, 2, 3))
        self.assertEqual(letters, ('a', 'b', 'c'))
|
263
|
+
|
264
|
+
|
265
|
+
class TestFunctionWrapping(unittest.TestCase):
    """Test function wrapping and metadata handling."""

    def setUp(self):
        """No per-test fixture is needed."""

    def tearDown(self):
        """No per-test cleanup is needed."""

    def test_fun_name_basic(self):
        """Test fun_name function with regular functions."""

        def test_function():
            """Test function docstring."""
            pass

        name = compat.fun_name(test_function)
        self.assertEqual(name, 'test_function')

    def test_fun_name_lambda(self):
        """Test fun_name with lambda functions."""
        lambda_func = lambda x: x * 2
        name = compat.fun_name(lambda_func)
        self.assertEqual(name, '<lambda>')

    def test_fun_name_partial(self):
        """Test fun_name with partial functions."""

        def original_function(x, y):
            return x + y

        partial_func = partial(original_function, 10)
        name = compat.fun_name(partial_func)
        self.assertEqual(name, 'original_function')

    def test_fun_name_nested_partial(self):
        """Test fun_name with nested partial functions."""

        def base_function(x, y, z):
            return x + y + z

        partial1 = partial(base_function, 1)
        partial2 = partial(partial1, 2)

        name = compat.fun_name(partial2)
        self.assertEqual(name, 'base_function')

    def test_fun_name_no_name_attribute(self):
        """Test fun_name with objects without __name__."""

        class CallableClass:
            def __call__(self):
                pass

        callable_obj = CallableClass()
        name = compat.fun_name(callable_obj)
        self.assertEqual(name, '<unnamed function>')

    def test_wraps_basic(self):
        """Test basic wraps functionality."""

        def original_function():
            """Original function docstring."""
            return 42

        @compat.wraps(original_function)
        def wrapper():
            return original_function()

        self.assertEqual(wrapper.__name__, 'original_function')
        self.assertEqual(wrapper.__doc__, 'Original function docstring.')
        # __wrapped__ must be the very same object, not merely an equal one.
        self.assertIs(wrapper.__wrapped__, original_function)

    def test_wraps_with_namestr(self):
        """Test wraps with custom name string."""

        def original_function():
            pass

        @compat.wraps(original_function, namestr="wrapped_{fun}")
        def wrapper():
            pass

        self.assertEqual(wrapper.__name__, 'wrapped_original_function')

    def test_wraps_with_docstr(self):
        """Test wraps with custom docstring."""

        def original_function():
            """Original docstring."""
            pass

        @compat.wraps(original_function, docstr="Wrapper for {fun}: {doc}")
        def wrapper():
            pass

        expected_doc = "Wrapper for original_function: Original docstring."
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_with_kwargs(self):
        """Test wraps with additional keyword arguments."""

        def original_function():
            pass

        @compat.wraps(original_function,
                      docstr="Function {fun} version {version}",
                      version="1.0")
        def wrapper():
            pass

        expected_doc = "Function original_function version 1.0"
        self.assertEqual(wrapper.__doc__, expected_doc)

    def test_wraps_preserves_annotations(self):
        """Test wraps preserves function annotations."""

        def original_function(x: int, y: str) -> float:
            return float(x)

        @compat.wraps(original_function)
        def wrapper(x: int, y: str) -> float:
            return original_function(x, y)

        self.assertEqual(wrapper.__annotations__, original_function.__annotations__)

    def test_wraps_preserves_dict(self):
        """Test wraps preserves function __dict__."""

        def original_function():
            pass

        original_function.custom_attr = "test_value"
        original_function.another_attr = 42

        @compat.wraps(original_function)
        def wrapper():
            pass

        self.assertEqual(wrapper.custom_attr, "test_value")
        self.assertEqual(wrapper.another_attr, 42)

    def test_wraps_handles_exceptions(self):
        """Test wraps still returns a usable function when the wrapped object misbehaves.

        Two inputs are exercised:

        * A ``Mock`` whose ``__name__`` is itself a ``Mock``. ``side_effect``
          only fires when a mock is *called*, so reading ``__name__`` here does
          not raise. It yields a non-string object instead.
        * An object whose ``__name__`` is a property that raises on access.

        In both cases the decorated function must still come back callable.
        """
        # Case 1: __name__ is a (non-string) Mock object.
        mock_func = Mock()
        mock_func.__name__ = Mock(side_effect=Exception("Test exception"))

        @compat.wraps(mock_func)
        def mock_wrapper():
            pass

        self.assertTrue(callable(mock_wrapper))
        self.assertIsNone(mock_wrapper())

        # Case 2: merely reading __name__ raises.
        class RaisingName:
            @property
            def __name__(self):
                raise Exception("Test exception")

            def __call__(self):
                pass

        @compat.wraps(RaisingName())
        def raising_wrapper():
            pass

        self.assertTrue(callable(raising_wrapper))
        self.assertIsNone(raising_wrapper())

    def test_wraps_with_missing_attributes(self):
        """Test wraps handles a wrapped object lacking function metadata.

        ``MinimalCallable`` instances define no ``__name__``, ``__call__`` or
        annotations of their own.
        """

        class MinimalCallable:
            pass

        minimal_func = MinimalCallable()

        @compat.wraps(minimal_func)
        def wrapper():
            pass

        # Should handle missing attributes without crashing, and the wrapper
        # itself must remain invocable.
        self.assertTrue(callable(wrapper))
        self.assertIsNone(wrapper())
|
436
|
+
|
437
|
+
|
438
|
+
class TestEdgeCases(unittest.TestCase):
    """Test edge cases and boundary conditions."""

    def setUp(self):
        """No per-test fixture is needed."""

    def tearDown(self):
        """No per-test cleanup is needed."""

    def test_safe_map_with_none_inputs(self):
        """Test safe_map behavior with None inputs."""
        # The mapped function itself decides how to treat None.
        result = compat.safe_map(lambda x: x if x is not None else 'default',
                                 [1, None, 3])
        self.assertEqual(result, [1, 'default', 3])

    def test_safe_map_with_zero_length(self):
        """Test safe_map with zero-length sequences."""
        result = compat.safe_map(str, [])
        self.assertEqual(result, [])

    def test_safe_zip_with_none_values(self):
        """Test safe_zip with None values."""
        result = compat.safe_zip([1, None, 3], [4, 5, None])
        expected = [(1, 4), (None, 5), (3, None)]
        self.assertEqual(result, expected)

    def test_unzip2_with_none_values(self):
        """Test unzip2 with None values."""
        pairs = [(1, None), (None, 'a'), (2, 'b')]
        first, second = compat.unzip2(pairs)

        self.assertEqual(first, (1, None, 2))
        self.assertEqual(second, (None, 'a', 'b'))

    def test_large_sequences(self):
        """Test utility functions over the full length of large sequences.

        Both ends of every result are checked so that truncation or
        misalignment anywhere in the sequence is caught.
        """
        size = 10000
        large_seq1 = list(range(size))
        large_seq2 = list(range(size, 2 * size))

        # safe_map
        sums = compat.safe_map(lambda x, y: x + y, large_seq1, large_seq2)
        self.assertEqual(len(sums), size)
        self.assertEqual(sums[0], size)            # 0 + 10000
        self.assertEqual(sums[-1], 3 * size - 2)   # 9999 + 19999

        # safe_zip
        zipped = compat.safe_zip(large_seq1, large_seq2)
        self.assertEqual(len(zipped), size)
        self.assertEqual(zipped[0], (0, size))
        self.assertEqual(zipped[-1], (size - 1, 2 * size - 1))

        # unzip2 must invert safe_zip exactly.
        first, second = compat.unzip2(zipped)
        self.assertEqual(first, tuple(large_seq1))
        self.assertEqual(second, tuple(large_seq2))

    def test_function_name_edge_cases(self):
        """Test fun_name with edge cases."""
        # Built-in function
        name = compat.fun_name(len)
        self.assertEqual(name, 'len')

        # Method descriptor
        name = compat.fun_name(str.upper)
        self.assertEqual(name, 'upper')

        # Nested function
        def outer():
            def inner():
                pass

            return inner

        inner_func = outer()
        name = compat.fun_name(inner_func)
        self.assertEqual(name, 'inner')

    def test_concurrent_usage(self):
        """Test the utility functions give correct results when called from several threads."""
        import threading

        num_threads = 5
        iterations = 100
        results = []
        errors = []

        def worker():
            # unittest assertions raised inside a thread would not fail the
            # test, so mismatches are raised here, collected in `errors`, and
            # asserted on from the main thread.
            try:
                for _ in range(iterations):
                    results.append(compat.safe_map(lambda x: x * 2, [1, 2, 3]))

                for _ in range(iterations):
                    zipped = compat.safe_zip([1, 2], [3, 4])
                    if zipped != [(1, 3), (2, 4)]:
                        raise AssertionError(f"unexpected safe_zip result: {zipped!r}")

                for _ in range(iterations):
                    first, second = compat.unzip2([(1, 'a'), (2, 'b')])
                    if (first, second) != ((1, 2), ('a', 'b')):
                        raise AssertionError(
                            f"unexpected unzip2 result: {(first, second)!r}")
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=worker) for _ in range(num_threads)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Should have no errors
        self.assertEqual(errors, [])
        # Every thread contributed every safe_map result, and each is correct.
        self.assertEqual(len(results), num_threads * iterations)
        self.assertTrue(all(result == [2, 4, 6] for result in results))
|
573
|
+
|
574
|
+
|
575
|
+
class TestTypeHints(unittest.TestCase):
    """Checks for the module's generic type variables and type-preserving helpers."""

    def test_type_variables_defined(self):
        """The generic TypeVars T, T1, T2 and T3 are exposed by the module."""
        for type_var_name in ('T', 'T1', 'T2', 'T3'):
            self.assertTrue(hasattr(compat, type_var_name))

    def test_unzip2_type_preservation(self):
        """unzip2 hands back the original element objects unchanged in type."""
        numbers, letters = compat.unzip2([(1, 'a'), (2, 'b'), (3, 'c')])

        self.assertTrue(all(isinstance(value, int) for value in numbers))
        self.assertTrue(all(isinstance(value, str) for value in letters))

    def test_safe_functions_with_different_types(self):
        """safe_map and safe_zip accept heterogeneous element types."""
        stringified = compat.safe_map(str, [1, 2.5, '3'])
        self.assertEqual(stringified, ['1', '2.5', '3'])

        self.assertEqual(
            compat.safe_zip([1, 2], [3.14, 2.71], ['a', 'b']),
            [(1, 3.14, 'a'), (2, 2.71, 'b')],
        )
|
608
|
+
|
609
|
+
|
610
|
+
class TestIntegration(unittest.TestCase):
    """End-to-end checks combining the helpers with JAX arrays and transforms."""

    def test_jax_integration(self):
        """Lists pulled out of JAX arrays flow through safe_map and safe_zip."""
        left = jnp.array([1, 2, 3]).tolist()
        right = jnp.array([4, 5, 6]).tolist()

        self.assertEqual(compat.safe_map(lambda a, b: a + b, left, right), [5, 7, 9])
        self.assertEqual(compat.safe_zip(left, right), [(1, 4), (2, 5), (3, 6)])

    def test_with_jax_transformations(self):
        """unzip2 behaves the same eagerly and under jax.jit tracing."""

        def split_pairs(x):
            # Python-level helper used on traced values inside the function body.
            firsts, seconds = compat.unzip2([(x, x + 1), (x + 2, x + 3)])
            return jnp.array(firsts), jnp.array(seconds)

        # Run once eagerly, then once through jit; both must agree.
        for fn in (split_pairs, jax.jit(split_pairs)):
            firsts, seconds = fn(10.0)
            np.testing.assert_array_equal(firsts, [10.0, 12.0])
            np.testing.assert_array_equal(seconds, [11.0, 13.0])
|
648
|
+
|
649
|
+
|
650
|
+
class TestModuleStructure(unittest.TestCase):
    """Checks on the public surface (``__all__``) and docs of the compat module."""

    def test_all_exports(self):
        """Every expected public name is listed in __all__ and actually defined."""
        expected_exports = (
            'ClosedJaxpr', 'Primitive', 'extend_axis_env_nd', 'jaxpr_as_fun',
            'get_aval', 'Tracer', 'to_concrete_aval', 'safe_map', 'safe_zip',
            'unzip2', 'wraps', 'Device', 'wrap_init', 'Var', 'JaxprEqn',
            'Jaxpr', 'Literal',
        )
        public_names = compat.__all__

        for export in expected_exports:
            self.assertIn(export, public_names,
                          f"{export} should be in __all__")
            self.assertTrue(hasattr(compat, export),
                            f"{export} should be available in module")

    def test_no_unexpected_exports(self):
        """No underscore-prefixed (private) name leaks into __all__."""
        for exported in compat.__all__:
            self.assertFalse(exported.startswith('_'),
                             f"Private name {exported} should not be in __all__")

    def test_module_docstring(self):
        """The module carries a docstring that describes it as a compatibility layer."""
        module_doc = compat.__doc__
        self.assertIsNotNone(module_doc)
        self.assertIn('Compatibility layer', module_doc)
|
678
|
+
|
679
|
+
|
680
|
+
if __name__ == "__main__":
    # Executed directly (not imported): run this module's tests with verbose output.
    unittest.main(verbosity=2)
|