brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/util/_others_test.py
CHANGED
@@ -1,962 +1,962 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
Comprehensive tests for the others module.
|
18
|
-
"""
|
19
|
-
|
20
|
-
import pickle
|
21
|
-
import threading
|
22
|
-
import unittest
|
23
|
-
from unittest.mock import MagicMock, patch
|
24
|
-
|
25
|
-
import jax
|
26
|
-
import jax.numpy as jnp
|
27
|
-
|
28
|
-
from brainstate.util._others import (
|
29
|
-
DictManager,
|
30
|
-
DotDict,
|
31
|
-
NameContext,
|
32
|
-
clear_buffer_memory,
|
33
|
-
flatten_dict,
|
34
|
-
get_unique_name,
|
35
|
-
is_instance_eval,
|
36
|
-
merge_dicts,
|
37
|
-
not_instance_eval,
|
38
|
-
split_total,
|
39
|
-
unflatten_dict,
|
40
|
-
)
|
41
|
-
|
42
|
-
|
43
|
-
class TestSplitTotal(unittest.TestCase):
    """Unit tests for ``split_total``.

    The cases below pin two modes: a float ``fraction`` is scaled by
    ``total`` and truncated, while an integer ``fraction`` is returned as an
    absolute count. Bad argument types raise ``TypeError`` and out-of-range
    values raise ``ValueError``.
    """

    def test_float_fraction(self):
        """A float fraction is multiplied by ``total``."""
        cases = [
            (0.5, 50),
            (0.25, 25),
            (0.75, 75),
            (0.0, 0),
            (1.0, 100),
        ]
        for fraction, expected in cases:
            self.assertEqual(split_total(100, fraction), expected)

    def test_int_fraction(self):
        """An integer fraction is treated as an absolute count."""
        for total, count in ((100, 25), (100, 0), (100, 100), (50, 30)):
            self.assertEqual(split_total(total, count), count)

    def test_edge_cases(self):
        """Float results are truncated toward zero."""
        cases = [
            ((1, 0.5), 0),    # int(0.5) == 0
            ((1, 1), 1),
            ((10, 0.99), 9),  # int(9.9) == 9
        ]
        for args, expected in cases:
            self.assertEqual(split_total(*args), expected)

    def test_type_errors(self):
        """Non-numeric arguments and a float ``total`` raise ``TypeError``."""
        bad_calls = [
            (("100", 0.5), "must be an integer"),
            ((100, "0.5"), "must be an integer or float"),
            ((100.5, 0.5), "must be an integer"),
        ]
        for args, message in bad_calls:
            with self.assertRaises(TypeError) as ctx:
                split_total(*args)
            self.assertIn(message, str(ctx.exception))

    def test_value_errors(self):
        """Out-of-range ``total`` or ``fraction`` values raise ``ValueError``."""
        bad_calls = [
            ((-10, 0.5), "must be a positive integer"),    # negative total
            ((0, 0.5), "must be a positive integer"),      # zero total
            ((100, -0.5), "cannot be negative"),           # negative float fraction
            ((100, 1.5), "cannot be greater than 1"),      # float fraction above 1
            ((100, -10), "cannot be negative"),            # negative int fraction
            ((100, 150), "cannot be greater than total"),  # int fraction above total
        ]
        for args, message in bad_calls:
            with self.assertRaises(ValueError) as ctx:
                split_total(*args)
            self.assertIn(message, str(ctx.exception))
|
112
|
-
|
113
|
-
|
114
|
-
class TestNameContext(unittest.TestCase):
    """Tests for ``NameContext`` and ``get_unique_name``."""

    def setUp(self):
        """Reset the shared ``NAME`` context before each test.

        The counters are cleared again on cleanup so names generated by these
        tests do not leak into later tests running in the same thread.
        """
        # A local import is enough; the previous ``global NAME`` declaration
        # only injected an unused module-level name into this test module.
        from brainstate.util._others import NAME
        NAME.typed_names.clear()
        self.addCleanup(NAME.typed_names.clear)

    def test_get_unique_name_basic(self):
        """Repeated requests for one type get consecutive numeric suffixes."""
        names = [get_unique_name('layer') for _ in range(3)]
        self.assertEqual(names, ['layer0', 'layer1', 'layer2'])

    def test_get_unique_name_with_prefix(self):
        """The prefix is prepended; the counter is keyed by type, not prefix."""
        name1 = get_unique_name('layer', 'conv_')
        name2 = get_unique_name('layer', 'conv_')
        name3 = get_unique_name('layer', 'dense_')

        self.assertEqual(name1, 'conv_layer0')
        self.assertEqual(name2, 'conv_layer1')
        # 'dense_' continues the shared 'layer' counter instead of restarting.
        self.assertEqual(name3, 'dense_layer2')

    def test_different_types(self):
        """Each type name keeps its own independent counter."""
        layer1 = get_unique_name('layer')
        neuron1 = get_unique_name('neuron')
        layer2 = get_unique_name('layer')
        neuron2 = get_unique_name('neuron')

        self.assertEqual(layer1, 'layer0')
        self.assertEqual(neuron1, 'neuron0')
        self.assertEqual(layer2, 'layer1')
        self.assertEqual(neuron2, 'neuron1')

    def test_thread_local_context(self):
        """Names generated in separate worker threads each start from 0."""
        results = {}

        def worker(thread_id):
            name1 = get_unique_name(f'type_{thread_id}')
            name2 = get_unique_name(f'type_{thread_id}')
            results[thread_id] = (name1, name2)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # A worker that raised never records a result; fail here with a clear
        # message rather than with a KeyError in the loop below.
        self.assertEqual(sorted(results), [0, 1, 2])

        # Each thread should have independent counters
        for thread_id in range(3):
            self.assertEqual(results[thread_id][0], f'type_{thread_id}0')
            self.assertEqual(results[thread_id][1], f'type_{thread_id}1')

    def test_name_context_reset(self):
        """``reset(name)`` zeroes one counter; ``reset()`` clears all of them."""
        context = NameContext()
        context.typed_names['test'] = 5

        # Reset specific type
        context.reset('test')
        self.assertEqual(context.typed_names.get('test'), 0)

        # Add more names
        context.typed_names['test1'] = 1
        context.typed_names['test2'] = 2

        # Reset all
        context.reset()
        self.assertEqual(len(context.typed_names), 0)
|
194
|
-
|
195
|
-
|
196
|
-
class TestDictManager(unittest.TestCase):
    """Test cases for the ``DictManager`` class."""

    def test_initialization(self):
        """DictManager supports empty, mapping, keyword and item-list construction."""
        # Empty initialization
        dm1 = DictManager()
        self.assertEqual(len(dm1), 0)

        # From dict
        dm2 = DictManager({'a': 1, 'b': 2})
        self.assertEqual(dm2['a'], 1)
        self.assertEqual(dm2['b'], 2)

        # From kwargs
        dm3 = DictManager(a=1, b=2)
        self.assertEqual(dm3['a'], 1)
        self.assertEqual(dm3['b'], 2)

        # From items
        dm4 = DictManager([('a', 1), ('b', 2)])
        self.assertEqual(dm4['a'], 1)
        self.assertEqual(dm4['b'], 2)

    def test_subset(self):
        """``subset`` filters by a type, a tuple of types, or a predicate."""
        dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})

        # By type
        int_subset = dm.subset(int)
        self.assertEqual(dict(int_subset), {'a': 1, 'd': 3})

        # By multiple types
        num_subset = dm.subset((int, float))
        self.assertEqual(dict(num_subset), {'a': 1, 'b': 2.0, 'd': 3})

        # By predicate
        large_subset = dm.subset(lambda x: isinstance(x, (int, float)) and x > 1.5)
        self.assertEqual(dict(large_subset), {'b': 2.0, 'd': 3})

    def test_not_subset(self):
        """``not_subset`` keeps only the values that do not match the filter."""
        dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})

        not_int = dm.not_subset(int)
        self.assertEqual(dict(not_int), {'b': 2.0, 'c': 'text'})

        not_num = dm.not_subset((int, float))
        self.assertEqual(dict(not_num), {'c': 'text'})

    def test_add_unique_key(self):
        """Re-adding a key is allowed only with the identical value."""
        dm = DictManager()
        obj1 = object()
        obj2 = object()

        # Add new key
        dm.add_unique_key('key1', obj1)
        self.assertIs(dm['key1'], obj1)

        # Add same key with same object (should work)
        dm.add_unique_key('key1', obj1)
        self.assertIs(dm['key1'], obj1)

        # Add same key with different object (should fail)
        with self.assertRaises(ValueError) as ctx:
            dm.add_unique_key('key1', obj2)
        self.assertIn("already exists with a different value", str(ctx.exception))

    def test_add_unique_value(self):
        """``add_unique_value`` refuses a value already stored under another key."""
        dm = DictManager()
        obj1 = object()

        # First addition should succeed
        result1 = dm.add_unique_value('key1', obj1)
        self.assertTrue(result1)
        self.assertIs(dm['key1'], obj1)

        # Adding same value with different key should fail
        result2 = dm.add_unique_value('key2', obj1)
        self.assertFalse(result2)
        self.assertNotIn('key2', dm)

        # Adding different value should succeed
        obj2 = object()
        result3 = dm.add_unique_value('key2', obj2)
        self.assertTrue(result3)
        self.assertIs(dm['key2'], obj2)

    def test_unique(self):
        """``unique`` returns a new manager without duplicate value objects."""
        obj1 = object()
        obj2 = object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1, 'd': obj2, 'e': obj1})

        unique_dm = dm.unique()
        self.assertEqual(len(unique_dm), 2)
        # Check that each object appears only once
        values = list(unique_dm.values())
        self.assertEqual(len({id(v) for v in values}), 2)

    def test_unique_inplace(self):
        """``unique_`` de-duplicates in place and returns ``self``."""
        obj1 = object()
        obj2 = object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1})

        result = dm.unique_()
        self.assertIs(result, dm)  # Should return self
        self.assertEqual(len(dm), 2)  # One duplicate removed

    def test_assign(self):
        """``assign`` merges dicts and keyword arguments in place."""
        dm = DictManager({'a': 1})

        # Assign from dict
        dm.assign({'b': 2, 'c': 3})
        self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})

        # Assign from multiple dicts
        dm.assign({'d': 4}, {'e': 5})
        self.assertEqual(len(dm), 5)

        # Assign with kwargs
        dm.assign(f=6, g=7)
        self.assertEqual(dm['f'], 6)
        self.assertEqual(dm['g'], 7)

        # A non-dict positional argument is rejected
        with self.assertRaises(TypeError):
            dm.assign([1, 2, 3])

    def test_split(self):
        """``split`` returns one manager per type plus a remainder."""
        dm = DictManager({
            'a': 1, 'b': 2.0, 'c': 'text',
            'd': 3, 'e': 4.5, 'f': [1, 2],
        })

        int_dm, float_dm, rest = dm.split(int, float)

        self.assertEqual(dict(int_dm), {'a': 1, 'd': 3})
        self.assertEqual(dict(float_dm), {'b': 2.0, 'e': 4.5})
        self.assertEqual(dict(rest), {'c': 'text', 'f': [1, 2]})

    def test_filter_by_predicate(self):
        """``filter_by_predicate`` passes ``(key, value)`` to the predicate."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Filter by key
        filtered = dm.filter_by_predicate(lambda k, v: k in ['a', 'c'])
        self.assertEqual(dict(filtered), {'a': 1, 'c': 3})

        # Filter by value
        filtered = dm.filter_by_predicate(lambda k, v: v > 2)
        self.assertEqual(dict(filtered), {'c': 3, 'd': 4})

        # Filter by both
        filtered = dm.filter_by_predicate(lambda k, v: k == 'a' or v == 4)
        self.assertEqual(dict(filtered), {'a': 1, 'd': 4})

    def test_map_values(self):
        """``map_values`` returns a new manager and leaves the original intact."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        doubled = dm.map_values(lambda x: x * 2)
        self.assertEqual(dict(doubled), {'a': 2, 'b': 4, 'c': 6})

        # Original should be unchanged
        self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})

    def test_map_keys(self):
        """``map_keys`` renames keys and rejects mappings that collide."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        upper = dm.map_keys(str.upper)
        self.assertEqual(dict(upper), {'A': 1, 'B': 2, 'C': 3})

        # Test duplicate key error
        with self.assertRaises(ValueError) as ctx:
            dm.map_keys(lambda x: 'same')
        self.assertIn("duplicate", str(ctx.exception))

    def test_pop_by_keys(self):
        """``pop_by_keys`` removes present keys and ignores missing ones."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        dm.pop_by_keys(['b', 'd'])
        self.assertEqual(dict(dm), {'a': 1, 'c': 3})

        # Pop non-existent keys (should not raise)
        dm.pop_by_keys(['x', 'y'])
        self.assertEqual(dict(dm), {'a': 1, 'c': 3})

    def test_pop_by_values(self):
        """``pop_by_values`` matches by identity (``'id'``) or equality (``'value'``)."""
        obj1, obj2, obj3 = object(), object(), object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj3})

        # By identity
        dm_copy = DictManager(dm)
        dm_copy.pop_by_values([obj2], by='id')
        self.assertEqual(len(dm_copy), 2)
        self.assertNotIn('b', dm_copy)

        # By value equality
        dm2 = DictManager({'a': 1, 'b': 2, 'c': 3})
        dm2.pop_by_values([2, 3], by='value')
        self.assertEqual(dict(dm2), {'a': 1})

        # Invalid method
        with self.assertRaises(ValueError):
            dm.pop_by_values([obj1], by='invalid')

    def test_difference_operations(self):
        """Difference by keys and by values both drop the matching entries."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Difference by keys
        diff = dm.difference_by_keys(['b', 'd'])
        self.assertEqual(dict(diff), {'a': 1, 'c': 3})

        # Difference by values
        diff = dm.difference_by_values([2, 4], by='value')
        self.assertEqual(dict(diff), {'a': 1, 'c': 3})

    def test_intersection_operations(self):
        """Intersection by keys and by values keep only the matching entries."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Intersection by keys
        inter = dm.intersection_by_keys(['a', 'c', 'e'])
        self.assertEqual(dict(inter), {'a': 1, 'c': 3})

        # Intersection by values
        inter = dm.intersection_by_values([1, 3, 5], by='value')
        self.assertEqual(dict(inter), {'a': 1, 'c': 3})

    def test_operators(self):
        """``+``, ``|`` and ``|=`` merge managers; bad operands raise TypeError."""
        dm1 = DictManager({'a': 1, 'b': 2})
        dm2 = DictManager({'c': 3, 'd': 4})

        # Addition
        dm3 = dm1 + dm2
        self.assertEqual(dict(dm3), {'a': 1, 'b': 2, 'c': 3, 'd': 4})
        self.assertIsNot(dm3, dm1)  # New object

        # Addition with regular dict
        dm4 = dm1 + {'e': 5}
        self.assertEqual(dm4['e'], 5)

        # Union operator (Python 3.9+) - only test if available
        import sys
        if sys.version_info >= (3, 9):
            dm5 = dm1 | dm2
            self.assertEqual(dict(dm5), {'a': 1, 'b': 2, 'c': 3, 'd': 4})

            # In-place union
            dm1_copy = DictManager(dm1)
            dm1_copy |= dm2
            self.assertEqual(dict(dm1_copy), {'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Unsupported operands must surface as TypeError, whether the operator
        # raises it directly or returns NotImplemented and Python raises it.
        with self.assertRaises(TypeError):
            _ = dm1 + 123
        if sys.version_info >= (3, 9):
            with self.assertRaises(TypeError):
                _ = dm1 | 123

    def test_copy_operations(self):
        """``copy.copy`` shares values; ``copy.deepcopy`` builds an independent manager."""
        # Go through the ``copy`` module so the real __copy__/__deepcopy__
        # protocol (including the memo argument) is exercised.
        import copy

        obj = object()
        dm1 = DictManager({'a': 1, 'b': obj})

        # Shallow copy
        dm2 = copy.copy(dm1)
        self.assertIsNot(dm2, dm1)
        self.assertEqual(dict(dm2), dict(dm1))
        self.assertIs(dm2['b'], obj)  # Same object reference

        # Deep copy
        dm3 = copy.deepcopy(dm1)
        self.assertIsNot(dm3, dm1)
        self.assertEqual(set(dm3), {'a', 'b'})
        self.assertEqual(dm3['a'], 1)
        # ``object()`` instances *can* be deep-copied; whether ``dm3['b']`` is a
        # fresh object depends on DictManager.__deepcopy__, so it is not pinned.

    def test_jax_pytree(self):
        """DictManager flattens/unflattens and works with ``jax.tree_util``."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        # Flatten
        values, keys = dm.tree_flatten()
        self.assertEqual(len(values), 3)
        self.assertEqual(len(keys), 3)

        # Unflatten
        dm2 = DictManager.tree_unflatten(keys, values)
        self.assertEqual(dict(dm2), dict(dm))

        # Test with JAX tree operations
        dm3 = DictManager({'x': jnp.array([1, 2]), 'y': jnp.array([3, 4])})
        doubled = jax.tree_util.tree_map(lambda x: x * 2, dm3)
        self.assertTrue(jnp.allclose(doubled['x'], jnp.array([2, 4])))
        self.assertTrue(jnp.allclose(doubled['y'], jnp.array([6, 8])))

    def test_repr(self):
        """``repr`` names the class and shows the stored items."""
        dm = DictManager({'a': 1, 'b': 'text'})
        repr_str = repr(dm)
        self.assertIn('DictManager', repr_str)
        self.assertIn("'a': 1", repr_str)
        self.assertIn("'b': 'text'", repr_str)
|
512
|
-
|
513
|
-
|
514
|
-
class TestDotDict(unittest.TestCase):
|
515
|
-
"""Test cases for DotDict class."""
|
516
|
-
|
517
|
-
def test_initialization(self):
|
518
|
-
"""Test DotDict initialization."""
|
519
|
-
# From dict
|
520
|
-
dd1 = DotDict({'a': 1, 'b': 2})
|
521
|
-
self.assertEqual(dd1.a, 1)
|
522
|
-
self.assertEqual(dd1['b'], 2)
|
523
|
-
|
524
|
-
# From kwargs
|
525
|
-
dd2 = DotDict(a=1, b=2)
|
526
|
-
self.assertEqual(dd2.a, 1)
|
527
|
-
self.assertEqual(dd2.b, 2)
|
528
|
-
|
529
|
-
# From tuple
|
530
|
-
dd3 = DotDict(('key', 'value'))
|
531
|
-
self.assertEqual(dd3.key, 'value')
|
532
|
-
|
533
|
-
# From items
|
534
|
-
dd4 = DotDict([('a', 1), ('b', 2)])
|
535
|
-
self.assertEqual(dd4.a, 1)
|
536
|
-
self.assertEqual(dd4.b, 2)
|
537
|
-
|
538
|
-
# Empty
|
539
|
-
dd5 = DotDict()
|
540
|
-
self.assertEqual(len(dd5), 0)
|
541
|
-
|
542
|
-
# Invalid argument
|
543
|
-
with self.assertRaises(TypeError):
|
544
|
-
DotDict(123)
|
545
|
-
|
546
|
-
def test_dot_access(self):
|
547
|
-
"""Test dot notation access."""
|
548
|
-
dd = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})
|
549
|
-
|
550
|
-
# Read access
|
551
|
-
self.assertEqual(dd.a, 1)
|
552
|
-
self.assertEqual(dd.b.c, 2)
|
553
|
-
self.assertEqual(dd.b.d, 3)
|
554
|
-
|
555
|
-
# Write access
|
556
|
-
dd.a = 10
|
557
|
-
dd.b.c = 20
|
558
|
-
self.assertEqual(dd['a'], 10)
|
559
|
-
self.assertEqual(dd['b']['c'], 20)
|
560
|
-
|
561
|
-
# Add new attributes
|
562
|
-
dd.e = 5
|
563
|
-
self.assertEqual(dd['e'], 5)
|
564
|
-
|
565
|
-
def test_nested_dict_conversion(self):
|
566
|
-
"""Test automatic nested dict conversion."""
|
567
|
-
dd = DotDict({
|
568
|
-
'level1': {
|
569
|
-
'level2': {
|
570
|
-
'level3': 'value'
|
571
|
-
}
|
572
|
-
}
|
573
|
-
})
|
574
|
-
|
575
|
-
self.assertIsInstance(dd.level1, DotDict)
|
576
|
-
self.assertIsInstance(dd.level1.level2, DotDict)
|
577
|
-
self.assertEqual(dd.level1.level2.level3, 'value')
|
578
|
-
|
579
|
-
def test_list_tuple_conversion(self):
|
580
|
-
"""Test conversion of lists and tuples containing dicts."""
|
581
|
-
dd = DotDict({
|
582
|
-
'list': [{'a': 1}, {'b': 2}],
|
583
|
-
'tuple': ({'c': 3}, {'d': 4})
|
584
|
-
})
|
585
|
-
|
586
|
-
# List items should be DotDict
|
587
|
-
self.assertIsInstance(dd.list[0], DotDict)
|
588
|
-
self.assertEqual(dd.list[0].a, 1)
|
589
|
-
self.assertIsInstance(dd.list[1], DotDict)
|
590
|
-
self.assertEqual(dd.list[1].b, 2)
|
591
|
-
|
592
|
-
# Tuple items should be DotDict
|
593
|
-
self.assertIsInstance(dd.tuple[0], DotDict)
|
594
|
-
self.assertEqual(dd.tuple[0].c, 3)
|
595
|
-
|
596
|
-
def test_attribute_errors(self):
|
597
|
-
"""Test attribute error handling."""
|
598
|
-
dd = DotDict({'a': 1})
|
599
|
-
|
600
|
-
# Non-existent attribute
|
601
|
-
with self.assertRaises(AttributeError) as ctx:
|
602
|
-
_ = dd.nonexistent
|
603
|
-
self.assertIn("has no attribute 'nonexistent'", str(ctx.exception))
|
604
|
-
|
605
|
-
# Delete non-existent attribute
|
606
|
-
with self.assertRaises(AttributeError):
|
607
|
-
del dd.nonexistent
|
608
|
-
|
609
|
-
# Try to set built-in method
|
610
|
-
with self.assertRaises(AttributeError) as ctx:
|
611
|
-
dd.keys = 'value'
|
612
|
-
self.assertIn("built-in method", str(ctx.exception))
|
613
|
-
|
614
|
-
def test_dir_method(self):
|
615
|
-
"""Test __dir__ method."""
|
616
|
-
dd = DotDict({'a': 1, 'b': 2})
|
617
|
-
attrs = dir(dd)
|
618
|
-
|
619
|
-
self.assertIn('a', attrs)
|
620
|
-
self.assertIn('b', attrs)
|
621
|
-
self.assertIn('keys', attrs) # Built-in method
|
622
|
-
self.assertIn('values', attrs) # Built-in method
|
623
|
-
|
624
|
-
def test_get_method(self):
|
625
|
-
"""Test get method with default."""
|
626
|
-
dd = DotDict({'a': 1})
|
627
|
-
|
628
|
-
self.assertEqual(dd.get('a'), 1)
|
629
|
-
self.assertEqual(dd.get('b'), None)
|
630
|
-
self.assertEqual(dd.get('b', 'default'), 'default')
|
631
|
-
|
632
|
-
def test_copy_operations(self):
|
633
|
-
"""Test copy operations."""
|
634
|
-
dd1 = DotDict({'a': 1, 'b': {'c': 2}})
|
635
|
-
|
636
|
-
# Shallow copy
|
637
|
-
dd2 = dd1.copy()
|
638
|
-
self.assertIsNot(dd2, dd1)
|
639
|
-
self.assertEqual(dd2.a, 1)
|
640
|
-
dd2.a = 10
|
641
|
-
self.assertEqual(dd1.a, 1) # Original unchanged
|
642
|
-
|
643
|
-
# Deep copy
|
644
|
-
dd3 = dd1.deepcopy()
|
645
|
-
self.assertIsNot(dd3, dd1)
|
646
|
-
self.assertIsNot(dd3.b, dd1.b)
|
647
|
-
dd3.b.c = 20
|
648
|
-
self.assertEqual(dd1.b.c, 2) # Original unchanged
|
649
|
-
|
650
|
-
def test_to_dict_from_dict(self):
|
651
|
-
"""Test conversion to/from standard dict."""
|
652
|
-
dd1 = DotDict({
|
653
|
-
'a': 1,
|
654
|
-
'b': {'c': 2, 'd': {'e': 3}},
|
655
|
-
'list': [{'f': 4}]
|
656
|
-
})
|
657
|
-
|
658
|
-
# Convert to dict
|
659
|
-
d = dd1.to_dict()
|
660
|
-
self.assertIsInstance(d, dict)
|
661
|
-
self.assertNotIsInstance(d, DotDict)
|
662
|
-
self.assertIsInstance(d['b'], dict)
|
663
|
-
self.assertNotIsInstance(d['b'], DotDict)
|
664
|
-
self.assertEqual(d['b']['d']['e'], 3)
|
665
|
-
|
666
|
-
# Convert from dict
|
667
|
-
dd2 = DotDict.from_dict(d)
|
668
|
-
self.assertIsInstance(dd2, DotDict)
|
669
|
-
self.assertIsInstance(dd2.b, DotDict)
|
670
|
-
self.assertEqual(dd2.b.d.e, 3)
|
671
|
-
|
672
|
-
def test_update_method(self):
|
673
|
-
"""Test update with recursive merge."""
|
674
|
-
dd = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})
|
675
|
-
|
676
|
-
# Simple update
|
677
|
-
dd.update({'a': 10})
|
678
|
-
self.assertEqual(dd.a, 10)
|
679
|
-
|
680
|
-
# Recursive merge
|
681
|
-
dd.update({'b': {'d': 30, 'e': 4}})
|
682
|
-
self.assertEqual(dd.b.c, 2) # Preserved
|
683
|
-
self.assertEqual(dd.b.d, 30) # Updated
|
684
|
-
self.assertEqual(dd.b.e, 4) # Added
|
685
|
-
|
686
|
-
# Update with kwargs
|
687
|
-
dd.update(f=5, g=6)
|
688
|
-
self.assertEqual(dd.f, 5)
|
689
|
-
self.assertEqual(dd.g, 6)
|
690
|
-
|
691
|
-
# Multiple arguments error
|
692
|
-
with self.assertRaises(TypeError):
|
693
|
-
dd.update({}, {})
|
694
|
-
|
695
|
-
def test_setdefault(self):
|
696
|
-
"""Test setdefault method."""
|
697
|
-
dd = DotDict({'a': 1})
|
698
|
-
|
699
|
-
# Existing key
|
700
|
-
result = dd.setdefault('a', 10)
|
701
|
-
self.assertEqual(result, 1)
|
702
|
-
self.assertEqual(dd.a, 1)
|
703
|
-
|
704
|
-
# New key
|
705
|
-
result = dd.setdefault('b', 2)
|
706
|
-
self.assertEqual(result, 2)
|
707
|
-
self.assertEqual(dd.b, 2)
|
708
|
-
|
709
|
-
# New key with None default
|
710
|
-
result = dd.setdefault('c')
|
711
|
-
self.assertIsNone(result)
|
712
|
-
self.assertIsNone(dd.c)
|
713
|
-
|
714
|
-
def test_pickling(self):
|
715
|
-
"""Test pickling/unpickling."""
|
716
|
-
dd1 = DotDict({'a': 1, 'b': {'c': 2}})
|
717
|
-
|
718
|
-
# Pickle and unpickle
|
719
|
-
pickled = pickle.dumps(dd1)
|
720
|
-
dd2 = pickle.loads(pickled)
|
721
|
-
|
722
|
-
self.assertIsNot(dd2, dd1)
|
723
|
-
self.assertEqual(dd2.a, 1)
|
724
|
-
self.assertEqual(dd2.b.c, 2)
|
725
|
-
self.assertIsInstance(dd2, DotDict)
|
726
|
-
self.assertIsInstance(dd2.b, DotDict)
|
727
|
-
|
728
|
-
def test_jax_pytree(self):
|
729
|
-
"""Test JAX pytree registration."""
|
730
|
-
dd = DotDict({'a': jnp.array([1, 2]), 'b': jnp.array([3, 4])})
|
731
|
-
|
732
|
-
# Tree operations
|
733
|
-
doubled = jax.tree_util.tree_map(lambda x: x * 2, dd)
|
734
|
-
self.assertTrue(jnp.allclose(doubled.a, jnp.array([2, 4])))
|
735
|
-
self.assertTrue(jnp.allclose(doubled.b, jnp.array([6, 8])))
|
736
|
-
|
737
|
-
def test_repr(self):
    """repr() mentions the class name and shows every key/value pair."""
    rendered = repr(DotDict({'a': 1, 'b': 'text'}))
    for fragment in ('DotDict', "'a': 1", "'b': 'text'"):
        self.assertIn(fragment, rendered)
|
744
|
-
|
745
|
-
|
746
|
-
class TestUtilityFunctions(unittest.TestCase):
    """Test cases for the module-level dict and type-predicate helpers."""

    def test_merge_dicts_basic(self):
        """Later dicts win on key clashes, and no input dict is mutated."""
        d1 = {'a': 1, 'b': 2}
        d2 = {'b': 3, 'c': 4}
        d3 = {'d': 5}

        result = merge_dicts(d1, d2, d3)
        self.assertEqual(result, {'a': 1, 'b': 3, 'c': 4, 'd': 5})

        # Every input dict should be unchanged, not only the first one.
        self.assertEqual(d1, {'a': 1, 'b': 2})
        self.assertEqual(d2, {'b': 3, 'c': 4})
        self.assertEqual(d3, {'d': 5})

    def test_merge_dicts_recursive(self):
        """With recursive=True, nested dicts are merged key by key."""
        d1 = {'a': 1, 'b': {'c': 2, 'd': 3}}
        d2 = {'b': {'d': 4, 'e': 5}, 'f': 6}

        result = merge_dicts(d1, d2, recursive=True)
        self.assertEqual(result, {
            'a': 1,
            'b': {'c': 2, 'd': 4, 'e': 5},
            'f': 6
        })

    def test_merge_dicts_non_recursive(self):
        """With recursive=False, a nested dict is replaced wholesale."""
        d1 = {'a': 1, 'b': {'c': 2}}
        d2 = {'b': {'d': 3}}

        result = merge_dicts(d1, d2, recursive=False)
        self.assertEqual(result, {'a': 1, 'b': {'d': 3}})

    def test_merge_dicts_errors(self):
        """Non-dict arguments raise TypeError."""
        with self.assertRaises(TypeError):
            merge_dicts({'a': 1}, [1, 2, 3])

    def test_flatten_dict(self):
        """Nested keys are joined with the separator ('.' by default)."""
        nested = {
            'a': 1,
            'b': {
                'c': 2,
                'd': {
                    'e': 3,
                    'f': 4
                }
            },
            'g': 5
        }

        flat = flatten_dict(nested)
        self.assertEqual(flat, {
            'a': 1,
            'b.c': 2,
            'b.d.e': 3,
            'b.d.f': 4,
            'g': 5
        })

        # Custom separator
        flat_dash = flatten_dict(nested, sep='-')
        self.assertEqual(flat_dash, {
            'a': 1,
            'b-c': 2,
            'b-d-e': 3,
            'b-d-f': 4,
            'g': 5
        })

    def test_unflatten_dict(self):
        """Separator-joined keys are expanded back into nested dicts."""
        flat = {
            'a': 1,
            'b.c': 2,
            'b.d.e': 3,
            'b.d.f': 4,
            'g': 5
        }

        nested = unflatten_dict(flat)
        self.assertEqual(nested, {
            'a': 1,
            'b': {
                'c': 2,
                'd': {
                    'e': 3,
                    'f': 4
                }
            },
            'g': 5
        })

        # Custom separator
        flat_dash = {'a': 1, 'b-c': 2, 'b-d': 3}
        nested_dash = unflatten_dict(flat_dash, sep='-')
        self.assertEqual(nested_dash, {
            'a': 1,
            'b': {'c': 2, 'd': 3}
        })

    def test_flatten_unflatten_roundtrip(self):
        """unflatten_dict(flatten_dict(x)) reproduces x."""
        original = {
            'level1': {
                'level2': {
                    'level3': 'value',
                    'another': 42
                },
                'sibling': 'data'
            },
            'root': 'element'
        }

        flattened = flatten_dict(original)
        unflattened = unflatten_dict(flattened)
        self.assertEqual(unflattened, original)

    def test_is_instance_eval(self):
        """is_instance_eval builds an isinstance predicate over the given types."""
        # Single type
        is_int = is_instance_eval(int)
        self.assertTrue(is_int(5))
        self.assertFalse(is_int("5"))
        self.assertFalse(is_int(5.0))

        # Multiple types
        is_number = is_instance_eval(int, float)
        self.assertTrue(is_number(5))
        self.assertTrue(is_number(5.0))
        self.assertFalse(is_number("5"))

        # Subclasses are accepted, as with isinstance
        class MyList(list):
            pass

        is_list = is_instance_eval(list)
        self.assertTrue(is_list([1, 2, 3]))
        self.assertTrue(is_list(MyList([1, 2, 3])))
        self.assertFalse(is_list((1, 2, 3)))

    def test_not_instance_eval(self):
        """not_instance_eval builds the negated isinstance predicate."""
        # Single type
        not_int = not_instance_eval(int)
        self.assertFalse(not_int(5))
        self.assertTrue(not_int("5"))
        self.assertTrue(not_int(5.0))

        # Multiple types
        not_number = not_instance_eval(int, float)
        self.assertFalse(not_number(5))
        self.assertFalse(not_number(5.0))
        self.assertTrue(not_number("5"))
        self.assertTrue(not_number([1, 2, 3]))
|
904
|
-
|
905
|
-
|
906
|
-
class TestJaxIntegration(unittest.TestCase):
    """Test JAX integration for DictManager and DotDict."""

    def test_dictmanager_pytree_operations(self):
        """Test DictManager with JAX tree operations."""
        dm = DictManager({
            'weights': jnp.array([1.0, 2.0, 3.0]),
            'bias': jnp.array([0.1, 0.2])
        })

        # Tree map
        scaled = jax.tree_util.tree_map(lambda x: x * 2, dm)
        self.assertTrue(jnp.allclose(scaled['weights'], jnp.array([2.0, 4.0, 6.0])))
        self.assertTrue(jnp.allclose(scaled['bias'], jnp.array([0.2, 0.4])))

        # Tree reduce: sum every element of every leaf. The accumulator ends up
        # as a 0-d JAX array, so convert it to a Python float before comparing.
        total = jax.tree_util.tree_reduce(lambda x, y: x + y.sum(), dm, 0.0)
        self.assertAlmostEqual(float(total), 6.3, places=5)

    def test_dotdict_pytree_operations(self):
        """Test DotDict with JAX tree operations."""
        dd = DotDict({
            'model': {
                'weights': jnp.array([[1.0, 2.0], [3.0, 4.0]]),
                'bias': jnp.array([0.1, 0.2])
            },
            'optimizer': {
                'lr': 0.01
            }
        })

        # Tree map on nested structure; non-array leaves pass through.
        def scale_arrays(x):
            return x * 2 if isinstance(x, jnp.ndarray) else x

        scaled = jax.tree_util.tree_map(scale_arrays, dd)
        self.assertTrue(jnp.allclose(
            scaled.model.weights,
            jnp.array([[2.0, 4.0], [6.0, 8.0]])
        ))
        self.assertEqual(scaled.optimizer.lr, 0.01)  # Non-array unchanged

    def test_mixed_pytree_structures(self):
        """Test mixing DictManager and DotDict in pytree operations."""
        # Both container types sit side by side in one plain-dict pytree,
        # which is what this test's name and docstring promise.
        structure = {
            'dict_manager': DictManager({'a': jnp.array([1, 2])}),
            'dot_dict': DotDict({'b': jnp.array([3, 4])}),
        }

        doubled = jax.tree_util.tree_map(lambda x: x * 2, structure)
        self.assertTrue(jnp.allclose(
            doubled['dict_manager']['a'],
            jnp.array([2, 4])
        ))
        self.assertTrue(jnp.allclose(
            doubled['dot_dict'].b,
            jnp.array([6, 8])
        ))
|
959
|
-
|
960
|
-
|
961
|
-
if __name__ == '__main__':
    # Executed directly (not imported): discover and run every TestCase here.
    unittest.main(module='__main__')
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Comprehensive tests for the others module.
|
18
|
+
"""
|
19
|
+
|
20
|
+
import pickle
|
21
|
+
import threading
|
22
|
+
import unittest
|
23
|
+
from unittest.mock import MagicMock, patch
|
24
|
+
|
25
|
+
import jax
|
26
|
+
import jax.numpy as jnp
|
27
|
+
|
28
|
+
from brainstate.util._others import (
|
29
|
+
DictManager,
|
30
|
+
DotDict,
|
31
|
+
NameContext,
|
32
|
+
clear_buffer_memory,
|
33
|
+
flatten_dict,
|
34
|
+
get_unique_name,
|
35
|
+
is_instance_eval,
|
36
|
+
merge_dicts,
|
37
|
+
not_instance_eval,
|
38
|
+
split_total,
|
39
|
+
unflatten_dict,
|
40
|
+
)
|
41
|
+
|
42
|
+
|
43
|
+
class TestSplitTotal(unittest.TestCase):
    """Behaviour of split_total for valid and invalid arguments."""

    def test_float_fraction(self):
        """A float fraction selects that proportion of the total."""
        proportion_cases = ((0.5, 50), (0.25, 25), (0.75, 75), (0.0, 0), (1.0, 100))
        for fraction, expected in proportion_cases:
            self.assertEqual(split_total(100, fraction), expected)

    def test_int_fraction(self):
        """An integer fraction is returned as an absolute count."""
        count_cases = ((100, 25, 25), (100, 0, 0), (100, 100, 100), (50, 30, 30))
        for total, fraction, expected in count_cases:
            self.assertEqual(split_total(total, fraction), expected)

    def test_edge_cases(self):
        """Fractional products are truncated toward zero."""
        self.assertEqual(split_total(1, 0.5), 0)  # int(0.5) = 0
        self.assertEqual(split_total(1, 1), 1)
        self.assertEqual(split_total(10, 0.99), 9)  # int(9.9) = 9

    def test_type_errors(self):
        """A non-int total or a non-numeric fraction raises TypeError."""
        bad_type_cases = (
            (("100", 0.5), "must be an integer"),
            ((100, "0.5"), "must be an integer or float"),
            ((100.5, 0.5), "must be an integer"),
        )
        for arguments, fragment in bad_type_cases:
            with self.assertRaises(TypeError) as ctx:
                split_total(*arguments)
            self.assertIn(fragment, str(ctx.exception))

    def test_value_errors(self):
        """Out-of-range totals or fractions raise ValueError."""
        bad_value_cases = (
            ((-10, 0.5), "must be a positive integer"),     # negative total
            ((0, 0.5), "must be a positive integer"),       # zero total
            ((100, -0.5), "cannot be negative"),            # negative float fraction
            ((100, 1.5), "cannot be greater than 1"),       # float fraction above 1
            ((100, -10), "cannot be negative"),             # negative int fraction
            ((100, 150), "cannot be greater than total"),   # int fraction above total
        )
        for arguments, fragment in bad_value_cases:
            with self.assertRaises(ValueError) as ctx:
                split_total(*arguments)
            self.assertIn(fragment, str(ctx.exception))
|
112
|
+
|
113
|
+
|
114
|
+
class TestNameContext(unittest.TestCase):
    """Test cases for NameContext and get_unique_name."""

    def setUp(self):
        """Clear the shared NAME counters so every test starts from 0."""
        # A local import is enough here. The previous ``global NAME`` also
        # rebound a module-level ``NAME`` in this test module, an unneeded side
        # effect. NAME is presumably the context get_unique_name draws from
        # (the counter expectations below rely on that).
        from brainstate.util._others import NAME
        NAME.typed_names.clear()

    def test_get_unique_name_basic(self):
        """Test basic unique name generation."""
        name1 = get_unique_name('layer')
        name2 = get_unique_name('layer')
        name3 = get_unique_name('layer')

        self.assertEqual(name1, 'layer0')
        self.assertEqual(name2, 'layer1')
        self.assertEqual(name3, 'layer2')

    def test_get_unique_name_with_prefix(self):
        """Test unique name generation with prefix."""
        name1 = get_unique_name('layer', 'conv_')
        name2 = get_unique_name('layer', 'conv_')
        name3 = get_unique_name('layer', 'dense_')

        # The counter is keyed on the type ('layer'), not on the prefix.
        self.assertEqual(name1, 'conv_layer0')
        self.assertEqual(name2, 'conv_layer1')
        self.assertEqual(name3, 'dense_layer2')

    def test_different_types(self):
        """Test unique names for different types."""
        layer1 = get_unique_name('layer')
        neuron1 = get_unique_name('neuron')
        layer2 = get_unique_name('layer')
        neuron2 = get_unique_name('neuron')

        self.assertEqual(layer1, 'layer0')
        self.assertEqual(neuron1, 'neuron0')
        self.assertEqual(layer2, 'layer1')
        self.assertEqual(neuron2, 'neuron1')

    def test_thread_local_context(self):
        """Test that NameContext is thread-local."""
        results = {}

        def worker(thread_id):
            name1 = get_unique_name(f'type_{thread_id}')
            name2 = get_unique_name(f'type_{thread_id}')
            results[thread_id] = (name1, name2)

        threads = []
        for i in range(3):
            t = threading.Thread(target=worker, args=(i,))
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

        # If a worker raised, its entry is missing. Report that directly
        # instead of letting the lookups below fail with a bare KeyError.
        self.assertEqual(sorted(results), [0, 1, 2])

        # Each thread should have independent counters
        for thread_id in range(3):
            self.assertEqual(results[thread_id][0], f'type_{thread_id}0')
            self.assertEqual(results[thread_id][1], f'type_{thread_id}1')

    def test_name_context_reset(self):
        """Test resetting name context."""
        context = NameContext()
        context.typed_names['test'] = 5

        # Reset specific type
        context.reset('test')
        self.assertEqual(context.typed_names.get('test'), 0)

        # Add more names
        context.typed_names['test1'] = 1
        context.typed_names['test2'] = 2

        # Reset all
        context.reset()
        self.assertEqual(len(context.typed_names), 0)
|
194
|
+
|
195
|
+
|
196
|
+
class TestDictManager(unittest.TestCase):
    """Test cases for DictManager class."""

    def test_initialization(self):
        """Test DictManager initialization."""
        # Empty initialization
        dm1 = DictManager()
        self.assertEqual(len(dm1), 0)

        # From dict
        dm2 = DictManager({'a': 1, 'b': 2})
        self.assertEqual(dm2['a'], 1)
        self.assertEqual(dm2['b'], 2)

        # From kwargs
        dm3 = DictManager(a=1, b=2)
        self.assertEqual(dm3['a'], 1)
        self.assertEqual(dm3['b'], 2)

        # From items
        dm4 = DictManager([('a', 1), ('b', 2)])
        self.assertEqual(dm4['a'], 1)
        self.assertEqual(dm4['b'], 2)

    def test_subset(self):
        """Test subset filtering."""
        dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})

        # By type
        int_subset = dm.subset(int)
        self.assertEqual(dict(int_subset), {'a': 1, 'd': 3})

        # By multiple types
        num_subset = dm.subset((int, float))
        self.assertEqual(dict(num_subset), {'a': 1, 'b': 2.0, 'd': 3})

        # By predicate
        large_subset = dm.subset(lambda x: isinstance(x, (int, float)) and x > 1.5)
        self.assertEqual(dict(large_subset), {'b': 2.0, 'd': 3})

    def test_not_subset(self):
        """Test not_subset filtering."""
        dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})

        not_int = dm.not_subset(int)
        self.assertEqual(dict(not_int), {'b': 2.0, 'c': 'text'})

        not_num = dm.not_subset((int, float))
        self.assertEqual(dict(not_num), {'c': 'text'})

    def test_add_unique_key(self):
        """Test adding unique keys."""
        dm = DictManager()
        obj1 = object()
        obj2 = object()

        # Add new key
        dm.add_unique_key('key1', obj1)
        self.assertIs(dm['key1'], obj1)

        # Add same key with same object (should work)
        dm.add_unique_key('key1', obj1)
        self.assertIs(dm['key1'], obj1)

        # Add same key with different object (should fail)
        with self.assertRaises(ValueError) as ctx:
            dm.add_unique_key('key1', obj2)
        self.assertIn("already exists with a different value", str(ctx.exception))

    def test_add_unique_value(self):
        """Test adding unique values."""
        dm = DictManager()
        obj1 = object()

        # First addition should succeed
        result1 = dm.add_unique_value('key1', obj1)
        self.assertTrue(result1)
        self.assertIs(dm['key1'], obj1)

        # Adding same value with different key should fail
        result2 = dm.add_unique_value('key2', obj1)
        self.assertFalse(result2)
        self.assertNotIn('key2', dm)

        # Adding different value should succeed
        obj2 = object()
        result3 = dm.add_unique_value('key2', obj2)
        self.assertTrue(result3)
        self.assertIs(dm['key2'], obj2)

    def test_unique(self):
        """Test getting unique values."""
        obj1 = object()
        obj2 = object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1, 'd': obj2, 'e': obj1})

        unique_dm = dm.unique()
        self.assertEqual(len(unique_dm), 2)
        # Check that each object appears only once
        values = list(unique_dm.values())
        self.assertEqual(len({id(v) for v in values}), 2)

    def test_unique_inplace(self):
        """Test in-place unique operation."""
        obj1 = object()
        obj2 = object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1})

        result = dm.unique_()
        self.assertIs(result, dm)  # Should return self
        self.assertEqual(len(dm), 2)  # One duplicate removed

    def test_assign(self):
        """Test assign method."""
        dm = DictManager({'a': 1})

        # Assign from dict
        dm.assign({'b': 2, 'c': 3})
        self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})

        # Assign from multiple dicts
        dm.assign({'d': 4}, {'e': 5})
        self.assertEqual(len(dm), 5)

        # Assign with kwargs
        dm.assign(f=6, g=7)
        self.assertEqual(dm['f'], 6)
        self.assertEqual(dm['g'], 7)

        # Invalid argument
        with self.assertRaises(TypeError):
            dm.assign([1, 2, 3])

    def test_split(self):
        """Test splitting by types."""
        dm = DictManager({
            'a': 1, 'b': 2.0, 'c': 'text',
            'd': 3, 'e': 4.5, 'f': [1, 2]
        })

        int_dm, float_dm, rest = dm.split(int, float)

        self.assertEqual(dict(int_dm), {'a': 1, 'd': 3})
        self.assertEqual(dict(float_dm), {'b': 2.0, 'e': 4.5})
        self.assertEqual(dict(rest), {'c': 'text', 'f': [1, 2]})

    def test_filter_by_predicate(self):
        """Test filtering with predicate."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Filter by key
        filtered = dm.filter_by_predicate(lambda k, v: k in ['a', 'c'])
        self.assertEqual(dict(filtered), {'a': 1, 'c': 3})

        # Filter by value
        filtered = dm.filter_by_predicate(lambda k, v: v > 2)
        self.assertEqual(dict(filtered), {'c': 3, 'd': 4})

        # Filter by both
        filtered = dm.filter_by_predicate(lambda k, v: k == 'a' or v == 4)
        self.assertEqual(dict(filtered), {'a': 1, 'd': 4})

    def test_map_values(self):
        """Test mapping function to values."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        doubled = dm.map_values(lambda x: x * 2)
        self.assertEqual(dict(doubled), {'a': 2, 'b': 4, 'c': 6})

        # Original should be unchanged
        self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})

    def test_map_keys(self):
        """Test mapping function to keys."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        upper = dm.map_keys(str.upper)
        self.assertEqual(dict(upper), {'A': 1, 'B': 2, 'C': 3})

        # Test duplicate key error
        with self.assertRaises(ValueError) as ctx:
            dm.map_keys(lambda x: 'same')
        self.assertIn("duplicate", str(ctx.exception))

    def test_pop_by_keys(self):
        """Test removing items by keys."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        dm.pop_by_keys(['b', 'd'])
        self.assertEqual(dict(dm), {'a': 1, 'c': 3})

        # Pop non-existent keys (should not raise)
        dm.pop_by_keys(['x', 'y'])
        self.assertEqual(dict(dm), {'a': 1, 'c': 3})

    def test_pop_by_values(self):
        """Test removing items by values."""
        obj1, obj2, obj3 = object(), object(), object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj3})

        # By identity
        dm_copy = DictManager(dm)
        dm_copy.pop_by_values([obj2], by='id')
        self.assertEqual(len(dm_copy), 2)
        self.assertNotIn('b', dm_copy)

        # By value equality
        dm2 = DictManager({'a': 1, 'b': 2, 'c': 3})
        dm2.pop_by_values([2, 3], by='value')
        self.assertEqual(dict(dm2), {'a': 1})

        # Invalid method
        with self.assertRaises(ValueError):
            dm.pop_by_values([obj1], by='invalid')

    def test_difference_operations(self):
        """Test difference operations."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Difference by keys
        diff = dm.difference_by_keys(['b', 'd'])
        self.assertEqual(dict(diff), {'a': 1, 'c': 3})

        # Difference by values
        diff = dm.difference_by_values([2, 4], by='value')
        self.assertEqual(dict(diff), {'a': 1, 'c': 3})

    def test_intersection_operations(self):
        """Test intersection operations."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Intersection by keys
        inter = dm.intersection_by_keys(['a', 'c', 'e'])
        self.assertEqual(dict(inter), {'a': 1, 'c': 3})

        # Intersection by values
        inter = dm.intersection_by_values([1, 3, 5], by='value')
        self.assertEqual(dict(inter), {'a': 1, 'c': 3})

    def test_operators(self):
        """Test operator overloading."""
        dm1 = DictManager({'a': 1, 'b': 2})
        dm2 = DictManager({'c': 3, 'd': 4})

        # Addition
        dm3 = dm1 + dm2
        self.assertEqual(dict(dm3), {'a': 1, 'b': 2, 'c': 3, 'd': 4})
        self.assertIsNot(dm3, dm1)  # New object

        # Addition with regular dict
        dm4 = dm1 + {'e': 5}
        self.assertEqual(dm4['e'], 5)

        # Union operator (Python 3.9+) - only test if available
        import sys
        if sys.version_info >= (3, 9):
            dm5 = dm1 | dm2
            self.assertEqual(dict(dm5), {'a': 1, 'b': 2, 'c': 3, 'd': 4})

            # In-place union
            dm1_copy = DictManager(dm1)
            dm1_copy |= dm2
            self.assertEqual(dict(dm1_copy), {'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Unsupported operands must end in TypeError at the call site. Either
        # the operator raises it itself, or it returns NotImplemented and
        # Python raises TypeError once the reflected operation also fails.
        with self.assertRaises(TypeError):
            _ = dm1 + 123
        if sys.version_info >= (3, 9):
            with self.assertRaises(TypeError):
                _ = dm1 | 123

    def test_copy_operations(self):
        """Test copy operations through the standard ``copy`` protocol."""
        import copy

        obj = object()
        dm1 = DictManager({'a': 1, 'b': obj})

        # Shallow copy (dispatches to DictManager.__copy__)
        dm2 = copy.copy(dm1)
        self.assertIsNot(dm2, dm1)
        self.assertEqual(dict(dm2), dict(dm1))
        self.assertIs(dm2['b'], obj)  # Same object reference

        # Deep copy (dispatches to DictManager.__deepcopy__)
        dm3 = copy.deepcopy(dm1)
        self.assertIsNot(dm3, dm1)
        self.assertEqual(dm3['a'], 1)
        # NOTE: copy.deepcopy(object()) does work and returns a new object.
        # Whether the 'b' entry is duplicated depends on how
        # DictManager.__deepcopy__ treats values, so only the scalar is pinned.

    def test_jax_pytree(self):
        """Test JAX pytree registration."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        # Flatten
        values, keys = dm.tree_flatten()
        self.assertEqual(len(values), 3)
        self.assertEqual(len(keys), 3)

        # Unflatten
        dm2 = DictManager.tree_unflatten(keys, values)
        self.assertEqual(dict(dm2), dict(dm))

        # Test with JAX tree operations
        dm3 = DictManager({'x': jnp.array([1, 2]), 'y': jnp.array([3, 4])})
        doubled = jax.tree_util.tree_map(lambda x: x * 2, dm3)
        self.assertTrue(jnp.allclose(doubled['x'], jnp.array([2, 4])))
        self.assertTrue(jnp.allclose(doubled['y'], jnp.array([6, 8])))

    def test_repr(self):
        """Test string representation."""
        dm = DictManager({'a': 1, 'b': 'text'})
        repr_str = repr(dm)
        self.assertIn('DictManager', repr_str)
        self.assertIn("'a': 1", repr_str)
        self.assertIn("'b': 'text'", repr_str)
|
512
|
+
|
513
|
+
|
514
|
+
class TestDotDict(unittest.TestCase):
|
515
|
+
"""Test cases for DotDict class."""
|
516
|
+
|
517
|
+
def test_initialization(self):
|
518
|
+
"""Test DotDict initialization."""
|
519
|
+
# From dict
|
520
|
+
dd1 = DotDict({'a': 1, 'b': 2})
|
521
|
+
self.assertEqual(dd1.a, 1)
|
522
|
+
self.assertEqual(dd1['b'], 2)
|
523
|
+
|
524
|
+
# From kwargs
|
525
|
+
dd2 = DotDict(a=1, b=2)
|
526
|
+
self.assertEqual(dd2.a, 1)
|
527
|
+
self.assertEqual(dd2.b, 2)
|
528
|
+
|
529
|
+
# From tuple
|
530
|
+
dd3 = DotDict(('key', 'value'))
|
531
|
+
self.assertEqual(dd3.key, 'value')
|
532
|
+
|
533
|
+
# From items
|
534
|
+
dd4 = DotDict([('a', 1), ('b', 2)])
|
535
|
+
self.assertEqual(dd4.a, 1)
|
536
|
+
self.assertEqual(dd4.b, 2)
|
537
|
+
|
538
|
+
# Empty
|
539
|
+
dd5 = DotDict()
|
540
|
+
self.assertEqual(len(dd5), 0)
|
541
|
+
|
542
|
+
# Invalid argument
|
543
|
+
with self.assertRaises(TypeError):
|
544
|
+
DotDict(123)
|
545
|
+
|
546
|
+
def test_dot_access(self):
|
547
|
+
"""Test dot notation access."""
|
548
|
+
dd = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})
|
549
|
+
|
550
|
+
# Read access
|
551
|
+
self.assertEqual(dd.a, 1)
|
552
|
+
self.assertEqual(dd.b.c, 2)
|
553
|
+
self.assertEqual(dd.b.d, 3)
|
554
|
+
|
555
|
+
# Write access
|
556
|
+
dd.a = 10
|
557
|
+
dd.b.c = 20
|
558
|
+
self.assertEqual(dd['a'], 10)
|
559
|
+
self.assertEqual(dd['b']['c'], 20)
|
560
|
+
|
561
|
+
# Add new attributes
|
562
|
+
dd.e = 5
|
563
|
+
self.assertEqual(dd['e'], 5)
|
564
|
+
|
565
|
+
def test_nested_dict_conversion(self):
|
566
|
+
"""Test automatic nested dict conversion."""
|
567
|
+
dd = DotDict({
|
568
|
+
'level1': {
|
569
|
+
'level2': {
|
570
|
+
'level3': 'value'
|
571
|
+
}
|
572
|
+
}
|
573
|
+
})
|
574
|
+
|
575
|
+
self.assertIsInstance(dd.level1, DotDict)
|
576
|
+
self.assertIsInstance(dd.level1.level2, DotDict)
|
577
|
+
self.assertEqual(dd.level1.level2.level3, 'value')
|
578
|
+
|
579
|
+
def test_list_tuple_conversion(self):
|
580
|
+
"""Test conversion of lists and tuples containing dicts."""
|
581
|
+
dd = DotDict({
|
582
|
+
'list': [{'a': 1}, {'b': 2}],
|
583
|
+
'tuple': ({'c': 3}, {'d': 4})
|
584
|
+
})
|
585
|
+
|
586
|
+
# List items should be DotDict
|
587
|
+
self.assertIsInstance(dd.list[0], DotDict)
|
588
|
+
self.assertEqual(dd.list[0].a, 1)
|
589
|
+
self.assertIsInstance(dd.list[1], DotDict)
|
590
|
+
self.assertEqual(dd.list[1].b, 2)
|
591
|
+
|
592
|
+
# Tuple items should be DotDict
|
593
|
+
self.assertIsInstance(dd.tuple[0], DotDict)
|
594
|
+
self.assertEqual(dd.tuple[0].c, 3)
|
595
|
+
|
596
|
+
def test_attribute_errors(self):
|
597
|
+
"""Test attribute error handling."""
|
598
|
+
dd = DotDict({'a': 1})
|
599
|
+
|
600
|
+
# Non-existent attribute
|
601
|
+
with self.assertRaises(AttributeError) as ctx:
|
602
|
+
_ = dd.nonexistent
|
603
|
+
self.assertIn("has no attribute 'nonexistent'", str(ctx.exception))
|
604
|
+
|
605
|
+
# Delete non-existent attribute
|
606
|
+
with self.assertRaises(AttributeError):
|
607
|
+
del dd.nonexistent
|
608
|
+
|
609
|
+
# Try to set built-in method
|
610
|
+
with self.assertRaises(AttributeError) as ctx:
|
611
|
+
dd.keys = 'value'
|
612
|
+
self.assertIn("built-in method", str(ctx.exception))
|
613
|
+
|
614
|
+
def test_dir_method(self):
|
615
|
+
"""Test __dir__ method."""
|
616
|
+
dd = DotDict({'a': 1, 'b': 2})
|
617
|
+
attrs = dir(dd)
|
618
|
+
|
619
|
+
self.assertIn('a', attrs)
|
620
|
+
self.assertIn('b', attrs)
|
621
|
+
self.assertIn('keys', attrs) # Built-in method
|
622
|
+
self.assertIn('values', attrs) # Built-in method
|
623
|
+
|
624
|
+
def test_get_method(self):
|
625
|
+
"""Test get method with default."""
|
626
|
+
dd = DotDict({'a': 1})
|
627
|
+
|
628
|
+
self.assertEqual(dd.get('a'), 1)
|
629
|
+
self.assertEqual(dd.get('b'), None)
|
630
|
+
self.assertEqual(dd.get('b', 'default'), 'default')
|
631
|
+
|
632
|
+
def test_copy_operations(self):
|
633
|
+
"""Test copy operations."""
|
634
|
+
dd1 = DotDict({'a': 1, 'b': {'c': 2}})
|
635
|
+
|
636
|
+
# Shallow copy
|
637
|
+
dd2 = dd1.copy()
|
638
|
+
self.assertIsNot(dd2, dd1)
|
639
|
+
self.assertEqual(dd2.a, 1)
|
640
|
+
dd2.a = 10
|
641
|
+
self.assertEqual(dd1.a, 1) # Original unchanged
|
642
|
+
|
643
|
+
# Deep copy
|
644
|
+
dd3 = dd1.deepcopy()
|
645
|
+
self.assertIsNot(dd3, dd1)
|
646
|
+
self.assertIsNot(dd3.b, dd1.b)
|
647
|
+
dd3.b.c = 20
|
648
|
+
self.assertEqual(dd1.b.c, 2) # Original unchanged
|
649
|
+
|
650
|
+
def test_to_dict_from_dict(self):
|
651
|
+
"""Test conversion to/from standard dict."""
|
652
|
+
dd1 = DotDict({
|
653
|
+
'a': 1,
|
654
|
+
'b': {'c': 2, 'd': {'e': 3}},
|
655
|
+
'list': [{'f': 4}]
|
656
|
+
})
|
657
|
+
|
658
|
+
# Convert to dict
|
659
|
+
d = dd1.to_dict()
|
660
|
+
self.assertIsInstance(d, dict)
|
661
|
+
self.assertNotIsInstance(d, DotDict)
|
662
|
+
self.assertIsInstance(d['b'], dict)
|
663
|
+
self.assertNotIsInstance(d['b'], DotDict)
|
664
|
+
self.assertEqual(d['b']['d']['e'], 3)
|
665
|
+
|
666
|
+
# Convert from dict
|
667
|
+
dd2 = DotDict.from_dict(d)
|
668
|
+
self.assertIsInstance(dd2, DotDict)
|
669
|
+
self.assertIsInstance(dd2.b, DotDict)
|
670
|
+
self.assertEqual(dd2.b.d.e, 3)
|
671
|
+
|
672
|
+
def test_update_method(self):
|
673
|
+
"""Test update with recursive merge."""
|
674
|
+
dd = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})
|
675
|
+
|
676
|
+
# Simple update
|
677
|
+
dd.update({'a': 10})
|
678
|
+
self.assertEqual(dd.a, 10)
|
679
|
+
|
680
|
+
# Recursive merge
|
681
|
+
dd.update({'b': {'d': 30, 'e': 4}})
|
682
|
+
self.assertEqual(dd.b.c, 2) # Preserved
|
683
|
+
self.assertEqual(dd.b.d, 30) # Updated
|
684
|
+
self.assertEqual(dd.b.e, 4) # Added
|
685
|
+
|
686
|
+
# Update with kwargs
|
687
|
+
dd.update(f=5, g=6)
|
688
|
+
self.assertEqual(dd.f, 5)
|
689
|
+
self.assertEqual(dd.g, 6)
|
690
|
+
|
691
|
+
# Multiple arguments error
|
692
|
+
with self.assertRaises(TypeError):
|
693
|
+
dd.update({}, {})
|
694
|
+
|
695
|
+
def test_setdefault(self):
    """setdefault keeps existing values and inserts missing keys."""
    dd = DotDict({'a': 1})

    # Present key: the stored value is returned and the default is ignored.
    self.assertEqual(dd.setdefault('a', 10), 1)
    self.assertEqual(dd.a, 1)

    # Missing key: the supplied default is stored and returned.
    self.assertEqual(dd.setdefault('b', 2), 2)
    self.assertEqual(dd.b, 2)

    # Missing key without an explicit default: None is stored and returned.
    self.assertIsNone(dd.setdefault('c'))
    self.assertIsNone(dd.c)
|
713
|
+
|
714
|
+
def test_pickling(self):
    """A DotDict survives a pickle round trip as a new, equal DotDict."""
    original = DotDict({'a': 1, 'b': {'c': 2}})

    # Serialize and immediately restore.
    restored = pickle.loads(pickle.dumps(original))

    # The result is a distinct object with the same contents...
    self.assertIsNot(restored, original)
    self.assertEqual(restored.a, 1)
    self.assertEqual(restored.b.c, 2)
    # ...and the DotDict type is kept at both nesting levels.
    self.assertIsInstance(restored, DotDict)
    self.assertIsInstance(restored.b, DotDict)
|
727
|
+
|
728
|
+
def test_jax_pytree(self):
    """DotDict is a registered pytree, so tree_map maps over its leaves."""
    dd = DotDict({'a': jnp.array([1, 2]), 'b': jnp.array([3, 4])})

    doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, dd)

    # Each field of the rebuilt DotDict holds the doubled array.
    for field, want in (('a', [2, 4]), ('b', [6, 8])):
        self.assertTrue(jnp.allclose(getattr(doubled, field), jnp.array(want)))
|
736
|
+
|
737
|
+
def test_repr(self):
    """repr() names the class and shows each key/value pair."""
    text = repr(DotDict({'a': 1, 'b': 'text'}))
    for fragment in ('DotDict', "'a': 1", "'b': 'text'"):
        self.assertIn(fragment, text)
|
744
|
+
|
745
|
+
|
746
|
+
class TestUtilityFunctions(unittest.TestCase):
    """Test cases for the dict helpers and the isinstance-predicate factories."""

    def test_merge_dicts_basic(self):
        """Later dicts win on key conflicts, and no input dict is modified."""
        d1 = {'a': 1, 'b': 2}
        d2 = {'b': 3, 'c': 4}
        d3 = {'d': 5}

        result = merge_dicts(d1, d2, d3)
        self.assertEqual(result, {'a': 1, 'b': 3, 'c': 4, 'd': 5})

        # Original dicts should be unchanged. Check every input, not only the
        # first, so an implementation that mutates a later argument is caught.
        self.assertEqual(d1, {'a': 1, 'b': 2})
        self.assertEqual(d2, {'b': 3, 'c': 4})
        self.assertEqual(d3, {'d': 5})

    def test_merge_dicts_recursive(self):
        """With recursive=True, nested dicts under the same key are merged."""
        d1 = {'a': 1, 'b': {'c': 2, 'd': 3}}
        d2 = {'b': {'d': 4, 'e': 5}, 'f': 6}

        result = merge_dicts(d1, d2, recursive=True)
        self.assertEqual(result, {
            'a': 1,
            'b': {'c': 2, 'd': 4, 'e': 5},
            'f': 6
        })

    def test_merge_dicts_non_recursive(self):
        """With recursive=False, a nested dict is replaced wholesale."""
        d1 = {'a': 1, 'b': {'c': 2}}
        d2 = {'b': {'d': 3}}

        result = merge_dicts(d1, d2, recursive=False)
        self.assertEqual(result, {'a': 1, 'b': {'d': 3}})

    def test_merge_dicts_errors(self):
        """A non-dict argument raises TypeError."""
        with self.assertRaises(TypeError):
            merge_dicts({'a': 1}, [1, 2, 3])

    def test_flatten_dict(self):
        """Nested keys are joined with '.' by default, or with a custom sep."""
        nested = {
            'a': 1,
            'b': {
                'c': 2,
                'd': {
                    'e': 3,
                    'f': 4
                }
            },
            'g': 5
        }

        flat = flatten_dict(nested)
        self.assertEqual(flat, {
            'a': 1,
            'b.c': 2,
            'b.d.e': 3,
            'b.d.f': 4,
            'g': 5
        })

        # Custom separator
        flat_dash = flatten_dict(nested, sep='-')
        self.assertEqual(flat_dash, {
            'a': 1,
            'b-c': 2,
            'b-d-e': 3,
            'b-d-f': 4,
            'g': 5
        })

    def test_unflatten_dict(self):
        """Separator-joined keys are split back into nested dicts."""
        flat = {
            'a': 1,
            'b.c': 2,
            'b.d.e': 3,
            'b.d.f': 4,
            'g': 5
        }

        nested = unflatten_dict(flat)
        self.assertEqual(nested, {
            'a': 1,
            'b': {
                'c': 2,
                'd': {
                    'e': 3,
                    'f': 4
                }
            },
            'g': 5
        })

        # Custom separator
        flat_dash = {'a': 1, 'b-c': 2, 'b-d': 3}
        nested_dash = unflatten_dict(flat_dash, sep='-')
        self.assertEqual(nested_dash, {
            'a': 1,
            'b': {'c': 2, 'd': 3}
        })

    def test_flatten_unflatten_roundtrip(self):
        """unflatten_dict(flatten_dict(x)) reproduces x."""
        original = {
            'level1': {
                'level2': {
                    'level3': 'value',
                    'another': 42
                },
                'sibling': 'data'
            },
            'root': 'element'
        }

        flattened = flatten_dict(original)
        unflattened = unflatten_dict(flattened)
        self.assertEqual(unflattened, original)

    def test_is_instance_eval(self):
        """The predicate accepts instances of the given types, including subclasses."""
        # Single type
        is_int = is_instance_eval(int)
        self.assertTrue(is_int(5))
        self.assertFalse(is_int("5"))
        self.assertFalse(is_int(5.0))

        # Multiple types
        is_number = is_instance_eval(int, float)
        self.assertTrue(is_number(5))
        self.assertTrue(is_number(5.0))
        self.assertFalse(is_number("5"))

        # With subclasses
        class MyList(list):
            pass

        is_list = is_instance_eval(list)
        self.assertTrue(is_list([1, 2, 3]))
        self.assertTrue(is_list(MyList([1, 2, 3])))
        self.assertFalse(is_list((1, 2, 3)))

    def test_not_instance_eval(self):
        """The predicate rejects instances of the given types and accepts the rest."""
        # Single type
        not_int = not_instance_eval(int)
        self.assertFalse(not_int(5))
        self.assertTrue(not_int("5"))
        self.assertTrue(not_int(5.0))

        # Multiple types
        not_number = not_instance_eval(int, float)
        self.assertFalse(not_number(5))
        self.assertFalse(not_number(5.0))
        self.assertTrue(not_number("5"))
        self.assertTrue(not_number([1, 2, 3]))
|
904
|
+
|
905
|
+
|
906
|
+
class TestJaxIntegration(unittest.TestCase):
    """Test JAX integration for DictManager and DotDict."""

    def test_dictmanager_pytree_operations(self):
        """Test DictManager with JAX tree operations."""
        dm = DictManager({
            'weights': jnp.array([1.0, 2.0, 3.0]),
            'bias': jnp.array([0.1, 0.2])
        })

        # Tree map: every leaf array is transformed and the keys are kept.
        scaled = jax.tree_util.tree_map(lambda x: x * 2, dm)
        self.assertTrue(jnp.allclose(scaled['weights'], jnp.array([2.0, 4.0, 6.0])))
        self.assertTrue(jnp.allclose(scaled['bias'], jnp.array([0.2, 0.4])))

        # Tree reduce: the accumulator starts at 0.0 and each leaf adds its
        # sum, so the total is (1 + 2 + 3) + (0.1 + 0.2).
        total = jax.tree_util.tree_reduce(lambda x, y: x + y.sum(), dm, 0.0)
        # Convert the 0-d JAX array to a Python float so assertAlmostEqual
        # compares plain numbers instead of relying on array dunder methods.
        self.assertAlmostEqual(float(total), 6.3, places=5)

    def test_dotdict_pytree_operations(self):
        """Test DotDict with JAX tree operations."""
        dd = DotDict({
            'model': {
                'weights': jnp.array([[1.0, 2.0], [3.0, 4.0]]),
                'bias': jnp.array([0.1, 0.2])
            },
            'optimizer': {
                'lr': 0.01
            }
        })

        # Tree map on a nested structure; only array leaves are scaled.
        def scale_arrays(x):
            return x * 2 if isinstance(x, jnp.ndarray) else x

        scaled = jax.tree_util.tree_map(scale_arrays, dd)
        self.assertTrue(jnp.allclose(
            scaled.model.weights,
            jnp.array([[2.0, 4.0], [6.0, 8.0]])
        ))
        self.assertEqual(scaled.optimizer.lr, 0.01)  # Non-array unchanged

    def test_mixed_pytree_structures(self):
        """Test mixing DictManager and DotDict in pytree operations."""
        # Both container types sit side by side in one plain dict, so a single
        # tree_map has to traverse each registered pytree type.
        structure = {
            'dict_manager': DictManager({'a': jnp.array([1, 2])}),
            'dot_dict': DotDict({'b': jnp.array([3, 4])}),
        }

        doubled = jax.tree_util.tree_map(lambda x: x * 2, structure)
        self.assertTrue(jnp.allclose(
            doubled['dict_manager']['a'],
            jnp.array([2, 4])
        ))
        # The DotDict branch is rebuilt as a DotDict with doubled leaves.
        self.assertIsInstance(doubled['dot_dict'], DotDict)
        self.assertTrue(jnp.allclose(
            doubled['dot_dict'].b,
            jnp.array([6, 8])
        ))
|
959
|
+
|
960
|
+
|
961
|
+
if __name__ == "__main__":
    # Run this module's TestCase classes when the file is executed directly.
    unittest.main()
|