brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,962 +1,962 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Comprehensive tests for the others module.
18
- """
19
-
20
- import pickle
21
- import threading
22
- import unittest
23
- from unittest.mock import MagicMock, patch
24
-
25
- import jax
26
- import jax.numpy as jnp
27
-
28
- from brainstate.util._others import (
29
- DictManager,
30
- DotDict,
31
- NameContext,
32
- clear_buffer_memory,
33
- flatten_dict,
34
- get_unique_name,
35
- is_instance_eval,
36
- merge_dicts,
37
- not_instance_eval,
38
- split_total,
39
- unflatten_dict,
40
- )
41
-
42
-
43
class TestSplitTotal(unittest.TestCase):
    """Unit tests for ``split_total``: valid splits, truncation and validation."""

    def test_float_fraction(self):
        """A float fraction in [0, 1] is applied proportionally to the total."""
        cases = [(0.5, 50), (0.25, 25), (0.75, 75), (0.0, 0), (1.0, 100)]
        for frac, expected in cases:
            with self.subTest(fraction=frac):
                self.assertEqual(split_total(100, frac), expected)

    def test_int_fraction(self):
        """An integer fraction is taken as an absolute count and returned as-is."""
        for total, count in [(100, 25), (100, 0), (100, 100), (50, 30)]:
            with self.subTest(total=total, fraction=count):
                self.assertEqual(split_total(total, count), count)

    def test_edge_cases(self):
        """Fractional results are truncated, not rounded."""
        cases = [
            (1, 0.5, 0),    # int(0.5) == 0
            (1, 1, 1),
            (10, 0.99, 9),  # int(9.9) == 9
        ]
        for total, frac, expected in cases:
            with self.subTest(total=total, fraction=frac):
                self.assertEqual(split_total(total, frac), expected)

    def test_type_errors(self):
        """Non-integer totals and non-numeric fractions raise ``TypeError``."""
        bad_calls = [
            (("100", 0.5), "must be an integer"),
            ((100, "0.5"), "must be an integer or float"),
            ((100.5, 0.5), "must be an integer"),
        ]
        for args, fragment in bad_calls:
            with self.subTest(args=args):
                with self.assertRaises(TypeError) as ctx:
                    split_total(*args)
                self.assertIn(fragment, str(ctx.exception))

    def test_value_errors(self):
        """Out-of-range totals or fractions raise ``ValueError``."""
        bad_calls = [
            ((-10, 0.5), "must be a positive integer"),    # negative total
            ((0, 0.5), "must be a positive integer"),      # zero total
            ((100, -0.5), "cannot be negative"),           # negative float fraction
            ((100, 1.5), "cannot be greater than 1"),      # float fraction above 1
            ((100, -10), "cannot be negative"),            # negative int fraction
            ((100, 150), "cannot be greater than total"),  # int fraction above total
        ]
        for args, fragment in bad_calls:
            with self.subTest(args=args):
                with self.assertRaises(ValueError) as ctx:
                    split_total(*args)
                self.assertIn(fragment, str(ctx.exception))
112
-
113
-
114
class TestNameContext(unittest.TestCase):
    """Test cases for NameContext and get_unique_name."""

    def setUp(self):
        """Snapshot, then reset, the shared ``NAME`` context before each test.

        The snapshot is restored in ``tearDown`` so that running this class
        does not rewind unique-name counters for other tests in the process.
        """
        # Imported locally: no module-level binding of NAME is needed.
        from brainstate.util._others import NAME
        self._name_context = NAME
        self._saved_typed_names = dict(NAME.typed_names)
        NAME.typed_names.clear()

    def tearDown(self):
        """Restore the name counters captured in ``setUp``."""
        self._name_context.typed_names.clear()
        self._name_context.typed_names.update(self._saved_typed_names)

    def test_get_unique_name_basic(self):
        """Test basic unique name generation."""
        name1 = get_unique_name('layer')
        name2 = get_unique_name('layer')
        name3 = get_unique_name('layer')

        self.assertEqual(name1, 'layer0')
        self.assertEqual(name2, 'layer1')
        self.assertEqual(name3, 'layer2')

    def test_get_unique_name_with_prefix(self):
        """Test unique name generation with prefix.

        The counter is keyed by type only, so a different prefix continues
        the same sequence.
        """
        name1 = get_unique_name('layer', 'conv_')
        name2 = get_unique_name('layer', 'conv_')
        name3 = get_unique_name('layer', 'dense_')

        self.assertEqual(name1, 'conv_layer0')
        self.assertEqual(name2, 'conv_layer1')
        self.assertEqual(name3, 'dense_layer2')

    def test_different_types(self):
        """Test unique names for different types."""
        layer1 = get_unique_name('layer')
        neuron1 = get_unique_name('neuron')
        layer2 = get_unique_name('layer')
        neuron2 = get_unique_name('neuron')

        self.assertEqual(layer1, 'layer0')
        self.assertEqual(neuron1, 'neuron0')
        self.assertEqual(layer2, 'layer1')
        self.assertEqual(neuron2, 'neuron1')

    def test_thread_local_context(self):
        """Test that NameContext counters are independent per thread.

        Every thread, including the main one, requests the *same* type name.
        A shared counter would hand out increasing suffixes across threads.
        With per-thread state, each worker starts again from 0.
        """
        # Advance the main thread's counter first so that leakage from the
        # main thread into the workers is detected as well.
        self.assertEqual(get_unique_name('shared'), 'shared0')

        results = {}
        errors = []

        def worker(thread_id):
            try:
                results[thread_id] = (
                    get_unique_name('shared'),
                    get_unique_name('shared'),
                )
            except Exception as exc:
                # threading swallows worker exceptions; surface them instead.
                errors.append(exc)

        threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        self.assertEqual(errors, [])
        for thread_id in range(3):
            self.assertEqual(results[thread_id], ('shared0', 'shared1'))

    def test_name_context_reset(self):
        """Test resetting name context."""
        context = NameContext()
        context.typed_names['test'] = 5

        # Reset specific type
        context.reset('test')
        self.assertEqual(context.typed_names.get('test'), 0)

        # Add more names
        context.typed_names['test1'] = 1
        context.typed_names['test2'] = 2

        # Reset all
        context.reset()
        self.assertEqual(len(context.typed_names), 0)
194
-
195
-
196
class TestDictManager(unittest.TestCase):
    """Test cases for DictManager class."""

    def test_initialization(self):
        """Test DictManager initialization."""
        # Empty initialization
        dm1 = DictManager()
        self.assertEqual(len(dm1), 0)

        # From dict
        dm2 = DictManager({'a': 1, 'b': 2})
        self.assertEqual(dm2['a'], 1)
        self.assertEqual(dm2['b'], 2)

        # From kwargs
        dm3 = DictManager(a=1, b=2)
        self.assertEqual(dm3['a'], 1)
        self.assertEqual(dm3['b'], 2)

        # From items
        dm4 = DictManager([('a', 1), ('b', 2)])
        self.assertEqual(dm4['a'], 1)
        self.assertEqual(dm4['b'], 2)

    def test_subset(self):
        """Test subset filtering."""
        dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})

        # By type
        int_subset = dm.subset(int)
        self.assertEqual(dict(int_subset), {'a': 1, 'd': 3})

        # By multiple types
        num_subset = dm.subset((int, float))
        self.assertEqual(dict(num_subset), {'a': 1, 'b': 2.0, 'd': 3})

        # By predicate
        large_subset = dm.subset(lambda x: isinstance(x, (int, float)) and x > 1.5)
        self.assertEqual(dict(large_subset), {'b': 2.0, 'd': 3})

    def test_not_subset(self):
        """Test not_subset filtering."""
        dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})

        not_int = dm.not_subset(int)
        self.assertEqual(dict(not_int), {'b': 2.0, 'c': 'text'})

        not_num = dm.not_subset((int, float))
        self.assertEqual(dict(not_num), {'c': 'text'})

    def test_add_unique_key(self):
        """Test adding unique keys."""
        dm = DictManager()
        obj1 = object()
        obj2 = object()

        # Add new key
        dm.add_unique_key('key1', obj1)
        self.assertIs(dm['key1'], obj1)

        # Add same key with same object (should work)
        dm.add_unique_key('key1', obj1)
        self.assertIs(dm['key1'], obj1)

        # Add same key with different object (should fail)
        with self.assertRaises(ValueError) as ctx:
            dm.add_unique_key('key1', obj2)
        self.assertIn("already exists with a different value", str(ctx.exception))

    def test_add_unique_value(self):
        """Test adding unique values."""
        dm = DictManager()
        obj1 = object()

        # First addition should succeed
        result1 = dm.add_unique_value('key1', obj1)
        self.assertTrue(result1)
        self.assertIs(dm['key1'], obj1)

        # Adding same value with different key should fail
        result2 = dm.add_unique_value('key2', obj1)
        self.assertFalse(result2)
        self.assertNotIn('key2', dm)

        # Adding different value should succeed
        obj2 = object()
        result3 = dm.add_unique_value('key2', obj2)
        self.assertTrue(result3)
        self.assertIs(dm['key2'], obj2)

    def test_unique(self):
        """Test getting unique values."""
        obj1 = object()
        obj2 = object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1, 'd': obj2, 'e': obj1})

        unique_dm = dm.unique()
        self.assertEqual(len(unique_dm), 2)
        # Check that each object appears only once
        values = list(unique_dm.values())
        self.assertEqual(len(set(id(v) for v in values)), 2)

    def test_unique_inplace(self):
        """Test in-place unique operation."""
        obj1 = object()
        obj2 = object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1})

        result = dm.unique_()
        self.assertIs(result, dm)  # Should return self
        self.assertEqual(len(dm), 2)  # One duplicate removed

    def test_assign(self):
        """Test assign method."""
        dm = DictManager({'a': 1})

        # Assign from dict
        dm.assign({'b': 2, 'c': 3})
        self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})

        # Assign from multiple dicts
        dm.assign({'d': 4}, {'e': 5})
        self.assertEqual(len(dm), 5)

        # Assign with kwargs
        dm.assign(f=6, g=7)
        self.assertEqual(dm['f'], 6)
        self.assertEqual(dm['g'], 7)

        # Invalid argument
        with self.assertRaises(TypeError):
            dm.assign([1, 2, 3])

    def test_split(self):
        """Test splitting by types."""
        dm = DictManager({
            'a': 1, 'b': 2.0, 'c': 'text',
            'd': 3, 'e': 4.5, 'f': [1, 2]
        })

        int_dm, float_dm, rest = dm.split(int, float)

        self.assertEqual(dict(int_dm), {'a': 1, 'd': 3})
        self.assertEqual(dict(float_dm), {'b': 2.0, 'e': 4.5})
        self.assertEqual(dict(rest), {'c': 'text', 'f': [1, 2]})

    def test_filter_by_predicate(self):
        """Test filtering with predicate."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Filter by key
        filtered = dm.filter_by_predicate(lambda k, v: k in ['a', 'c'])
        self.assertEqual(dict(filtered), {'a': 1, 'c': 3})

        # Filter by value
        filtered = dm.filter_by_predicate(lambda k, v: v > 2)
        self.assertEqual(dict(filtered), {'c': 3, 'd': 4})

        # Filter by both
        filtered = dm.filter_by_predicate(lambda k, v: k == 'a' or v == 4)
        self.assertEqual(dict(filtered), {'a': 1, 'd': 4})

    def test_map_values(self):
        """Test mapping function to values."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        doubled = dm.map_values(lambda x: x * 2)
        self.assertEqual(dict(doubled), {'a': 2, 'b': 4, 'c': 6})

        # Original should be unchanged
        self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})

    def test_map_keys(self):
        """Test mapping function to keys."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        upper = dm.map_keys(str.upper)
        self.assertEqual(dict(upper), {'A': 1, 'B': 2, 'C': 3})

        # Test duplicate key error
        with self.assertRaises(ValueError) as ctx:
            dm.map_keys(lambda x: 'same')
        self.assertIn("duplicate", str(ctx.exception))

    def test_pop_by_keys(self):
        """Test removing items by keys."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        dm.pop_by_keys(['b', 'd'])
        self.assertEqual(dict(dm), {'a': 1, 'c': 3})

        # Pop non-existent keys (should not raise)
        dm.pop_by_keys(['x', 'y'])
        self.assertEqual(dict(dm), {'a': 1, 'c': 3})

    def test_pop_by_values(self):
        """Test removing items by values."""
        obj1, obj2, obj3 = object(), object(), object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj3})

        # By identity
        dm_copy = DictManager(dm)
        dm_copy.pop_by_values([obj2], by='id')
        self.assertEqual(len(dm_copy), 2)
        self.assertNotIn('b', dm_copy)

        # By value equality
        dm2 = DictManager({'a': 1, 'b': 2, 'c': 3})
        dm2.pop_by_values([2, 3], by='value')
        self.assertEqual(dict(dm2), {'a': 1})

        # Invalid method
        with self.assertRaises(ValueError):
            dm.pop_by_values([obj1], by='invalid')

    def test_difference_operations(self):
        """Test difference operations."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Difference by keys
        diff = dm.difference_by_keys(['b', 'd'])
        self.assertEqual(dict(diff), {'a': 1, 'c': 3})

        # Difference by values
        diff = dm.difference_by_values([2, 4], by='value')
        self.assertEqual(dict(diff), {'a': 1, 'c': 3})

    def test_intersection_operations(self):
        """Test intersection operations."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Intersection by keys
        inter = dm.intersection_by_keys(['a', 'c', 'e'])
        self.assertEqual(dict(inter), {'a': 1, 'c': 3})

        # Intersection by values
        inter = dm.intersection_by_values([1, 3, 5], by='value')
        self.assertEqual(dict(inter), {'a': 1, 'c': 3})

    def test_operators(self):
        """Test operator overloading."""
        dm1 = DictManager({'a': 1, 'b': 2})
        dm2 = DictManager({'c': 3, 'd': 4})

        # Addition
        dm3 = dm1 + dm2
        self.assertEqual(dict(dm3), {'a': 1, 'b': 2, 'c': 3, 'd': 4})
        self.assertIsNot(dm3, dm1)  # New object
        self.assertEqual(dict(dm1), {'a': 1, 'b': 2})  # Left operand not mutated

        # Addition with regular dict
        dm4 = dm1 + {'e': 5}
        self.assertEqual(dm4['e'], 5)

        # Union operator (Python 3.9+) - only test if available
        import sys
        if sys.version_info >= (3, 9):
            dm5 = dm1 | dm2
            self.assertEqual(dict(dm5), {'a': 1, 'b': 2, 'c': 3, 'd': 4})

            # In-place union
            dm1_copy = DictManager(dm1)
            dm1_copy |= dm2
            self.assertEqual(dict(dm1_copy), {'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Unsupported operand types must end in TypeError, whether the dunder
        # method raises it directly or returns NotImplemented and lets Python
        # raise it.
        with self.assertRaises(TypeError):
            _ = dm1 + 123
        if sys.version_info >= (3, 9):
            with self.assertRaises(TypeError):
                _ = dm1 | 123

    def test_copy_operations(self):
        """Test copy operations."""
        obj = object()
        dm1 = DictManager({'a': 1, 'b': obj})

        # Shallow copy: new container that shares the value objects.
        dm2 = dm1.__copy__()
        self.assertIsNot(dm2, dm1)
        self.assertEqual(dict(dm2), dict(dm1))
        self.assertIs(dm2['b'], obj)  # Same object reference

        # Deep copy: new container with the same keys.
        dm3 = dm1.__deepcopy__({})
        self.assertIsNot(dm3, dm1)
        self.assertEqual(set(dm3), {'a', 'b'})
        self.assertEqual(dm3['a'], 1)
        # NOTE: identity of dm3['b'] is deliberately not asserted. A bare
        # object() *can* be deep-copied (copy.deepcopy returns a fresh
        # instance), but whether values are copied depends on
        # DictManager.__deepcopy__ -- TODO confirm the intended semantics.

    def test_jax_pytree(self):
        """Test JAX pytree registration."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        # Flatten
        values, keys = dm.tree_flatten()
        self.assertEqual(len(values), 3)
        self.assertEqual(len(keys), 3)

        # Unflatten
        dm2 = DictManager.tree_unflatten(keys, values)
        self.assertEqual(dict(dm2), dict(dm))

        # Test with JAX tree operations
        dm3 = DictManager({'x': jnp.array([1, 2]), 'y': jnp.array([3, 4])})
        doubled = jax.tree_util.tree_map(lambda x: x * 2, dm3)
        self.assertTrue(jnp.allclose(doubled['x'], jnp.array([2, 4])))
        self.assertTrue(jnp.allclose(doubled['y'], jnp.array([6, 8])))

    def test_repr(self):
        """Test string representation."""
        dm = DictManager({'a': 1, 'b': 'text'})
        repr_str = repr(dm)
        self.assertIn('DictManager', repr_str)
        self.assertIn("'a': 1", repr_str)
        self.assertIn("'b': 'text'", repr_str)
512
-
513
-
514
class TestDotDict(unittest.TestCase):
    """Behavioural tests for ``DotDict``: construction, attribute access,
    nested conversion, copying, (de)serialisation and pytree support."""

    def test_initialization(self):
        """DotDict accepts a mapping, kwargs, a single pair, pairs, or nothing."""
        from_mapping = DotDict({'a': 1, 'b': 2})
        self.assertEqual(from_mapping.a, 1)
        self.assertEqual(from_mapping['b'], 2)

        from_kwargs = DotDict(a=1, b=2)
        self.assertEqual((from_kwargs.a, from_kwargs.b), (1, 2))

        single_pair = DotDict(('key', 'value'))
        self.assertEqual(single_pair.key, 'value')

        from_pairs = DotDict([('a', 1), ('b', 2)])
        self.assertEqual((from_pairs.a, from_pairs.b), (1, 2))

        self.assertEqual(len(DotDict()), 0)

        # A bare scalar is neither a mapping nor a sequence of pairs.
        self.assertRaises(TypeError, DotDict, 123)

    def test_dot_access(self):
        """Attribute reads and writes mirror item access, including nesting."""
        cfg = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})

        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b.c, 2)
        self.assertEqual(cfg.b.d, 3)

        cfg.a = 10
        cfg.b.c = 20
        self.assertEqual(cfg['a'], 10)
        self.assertEqual(cfg['b']['c'], 20)

        cfg.e = 5  # a brand-new key created through attribute assignment
        self.assertEqual(cfg['e'], 5)

    def test_nested_dict_conversion(self):
        """Nested plain dicts are exposed as DotDict at every level."""
        cfg = DotDict({'level1': {'level2': {'level3': 'value'}}})

        self.assertIsInstance(cfg.level1, DotDict)
        self.assertIsInstance(cfg.level1.level2, DotDict)
        self.assertEqual(cfg.level1.level2.level3, 'value')

    def test_list_tuple_conversion(self):
        """Dicts held inside lists and tuples are converted as well."""
        cfg = DotDict({
            'list': [{'a': 1}, {'b': 2}],
            'tuple': ({'c': 3}, {'d': 4}),
        })

        for idx, (key, val) in enumerate([('a', 1), ('b', 2)]):
            element = cfg.list[idx]
            self.assertIsInstance(element, DotDict)
            self.assertEqual(getattr(element, key), val)

        head = cfg.tuple[0]
        self.assertIsInstance(head, DotDict)
        self.assertEqual(head.c, 3)

    def test_attribute_errors(self):
        """Missing attributes and shadowing built-in methods raise AttributeError."""
        cfg = DotDict({'a': 1})

        with self.assertRaises(AttributeError) as missing:
            getattr(cfg, 'nonexistent')
        self.assertIn("has no attribute 'nonexistent'", str(missing.exception))

        with self.assertRaises(AttributeError):
            delattr(cfg, 'nonexistent')

        with self.assertRaises(AttributeError) as shadowing:
            setattr(cfg, 'keys', 'value')
        self.assertIn("built-in method", str(shadowing.exception))

    def test_dir_method(self):
        """``dir()`` lists both stored keys and the regular dict methods."""
        listing = set(dir(DotDict({'a': 1, 'b': 2})))
        for expected in ('a', 'b', 'keys', 'values'):
            with self.subTest(attr=expected):
                self.assertIn(expected, listing)

    def test_get_method(self):
        """``get`` falls back to ``None`` or to the supplied default."""
        cfg = DotDict({'a': 1})
        lookups = [(('a',), 1), (('b',), None), (('b', 'default'), 'default')]
        for args, expected in lookups:
            with self.subTest(args=args):
                self.assertEqual(cfg.get(*args), expected)

    def test_copy_operations(self):
        """``copy`` isolates the top level; ``deepcopy`` isolates nested dicts too."""
        source = DotDict({'a': 1, 'b': {'c': 2}})

        shallow = source.copy()
        self.assertIsNot(shallow, source)
        self.assertEqual(shallow.a, 1)
        shallow.a = 10
        self.assertEqual(source.a, 1)  # rebinding on the copy must not leak back

        deep = source.deepcopy()
        self.assertIsNot(deep, source)
        self.assertIsNot(deep.b, source.b)
        deep.b.c = 20
        self.assertEqual(source.b.c, 2)  # nested mutation must not leak back

    def test_to_dict_from_dict(self):
        """``to_dict`` yields plain dicts; ``from_dict`` converts them back."""
        nested = DotDict({
            'a': 1,
            'b': {'c': 2, 'd': {'e': 3}},
            'list': [{'f': 4}],
        })

        plain = nested.to_dict()
        for level in (plain, plain['b']):
            self.assertIsInstance(level, dict)
            self.assertNotIsInstance(level, DotDict)
        self.assertEqual(plain['b']['d']['e'], 3)

        rebuilt = DotDict.from_dict(plain)
        self.assertIsInstance(rebuilt, DotDict)
        self.assertIsInstance(rebuilt.b, DotDict)
        self.assertEqual(rebuilt.b.d.e, 3)

    def test_update_method(self):
        """``update`` overwrites scalars and merges nested mappings recursively."""
        cfg = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})

        cfg.update({'a': 10})
        self.assertEqual(cfg.a, 10)

        # Nested dicts are merged key by key instead of being replaced.
        cfg.update({'b': {'d': 30, 'e': 4}})
        self.assertEqual((cfg.b.c, cfg.b.d, cfg.b.e), (2, 30, 4))

        cfg.update(f=5, g=6)
        self.assertEqual((cfg.f, cfg.g), (5, 6))

        # Like dict.update, at most one positional argument is accepted.
        with self.assertRaises(TypeError):
            cfg.update({}, {})

    def test_setdefault(self):
        """``setdefault`` keeps existing values and inserts missing ones."""
        cfg = DotDict({'a': 1})

        self.assertEqual(cfg.setdefault('a', 10), 1)
        self.assertEqual(cfg.a, 1)

        self.assertEqual(cfg.setdefault('b', 2), 2)
        self.assertEqual(cfg.b, 2)

        self.assertIsNone(cfg.setdefault('c'))
        self.assertIsNone(cfg.c)

    def test_pickling(self):
        """A pickle round-trip preserves values and the DotDict type at each level."""
        original = DotDict({'a': 1, 'b': {'c': 2}})
        restored = pickle.loads(pickle.dumps(original))

        self.assertIsNot(restored, original)
        self.assertEqual(restored.a, 1)
        self.assertEqual(restored.b.c, 2)
        self.assertIsInstance(restored, DotDict)
        self.assertIsInstance(restored.b, DotDict)

    def test_jax_pytree(self):
        """DotDict works as a pytree container for ``jax.tree_util.tree_map``."""
        arrays = DotDict({'a': jnp.array([1, 2]), 'b': jnp.array([3, 4])})
        scaled = jax.tree_util.tree_map(lambda leaf: leaf * 2, arrays)
        for field, want in (('a', [2, 4]), ('b', [6, 8])):
            self.assertTrue(jnp.allclose(getattr(scaled, field), jnp.array(want)))

    def test_repr(self):
        """The repr names the class and shows every key/value pair."""
        text = repr(DotDict({'a': 1, 'b': 'text'}))
        for fragment in ('DotDict', "'a': 1", "'b': 'text'"):
            self.assertIn(fragment, text)
744
-
745
-
746
- class TestUtilityFunctions(unittest.TestCase):
747
- """Test cases for utility functions."""
748
-
749
def test_merge_dicts_basic(self):
    """Later mappings win on key conflicts, and the first input is left intact."""
    first = {'a': 1, 'b': 2}
    second = {'b': 3, 'c': 4}
    third = {'d': 5}

    merged = merge_dicts(first, second, third)

    self.assertEqual(merged, {'a': 1, 'b': 3, 'c': 4, 'd': 5})
    # The first argument must not have been used as the accumulator.
    self.assertEqual(first, {'a': 1, 'b': 2})
760
-
761
def test_merge_dicts_recursive(self):
    """With ``recursive=True``, nested dicts are merged key by key."""
    base = {'a': 1, 'b': {'c': 2, 'd': 3}}
    override = {'b': {'d': 4, 'e': 5}, 'f': 6}
    expected = {'a': 1, 'b': {'c': 2, 'd': 4, 'e': 5}, 'f': 6}

    self.assertEqual(merge_dicts(base, override, recursive=True), expected)
772
-
773
def test_merge_dicts_non_recursive(self):
    """With ``recursive=False``, a nested dict is replaced wholesale."""
    merged = merge_dicts({'a': 1, 'b': {'c': 2}}, {'b': {'d': 3}}, recursive=False)
    self.assertEqual(merged, {'a': 1, 'b': {'d': 3}})
780
-
781
def test_merge_dicts_errors(self):
    """A non-mapping argument is rejected with ``TypeError``."""
    self.assertRaises(TypeError, merge_dicts, {'a': 1}, [1, 2, 3])
785
-
786
def test_flatten_dict(self):
    """Nested keys are joined with the separator (``'.'`` unless overridden)."""
    nested = {'a': 1, 'b': {'c': 2, 'd': {'e': 3, 'f': 4}}, 'g': 5}
    # Leaf paths and their values, independent of the separator used.
    leaves = {
        ('a',): 1,
        ('b', 'c'): 2,
        ('b', 'd', 'e'): 3,
        ('b', 'd', 'f'): 4,
        ('g',): 5,
    }

    for sep, extra_kwargs in (('.', {}), ('-', {'sep': '-'})):
        with self.subTest(sep=sep):
            expected = {sep.join(path): value for path, value in leaves.items()}
            self.assertEqual(flatten_dict(nested, **extra_kwargs), expected)
818
-
819
def test_unflatten_dict(self):
    """Separator-joined keys are expanded back into nested dicts."""
    dotted = {'a': 1, 'b.c': 2, 'b.d.e': 3, 'b.d.f': 4, 'g': 5}
    self.assertEqual(
        unflatten_dict(dotted),
        {'a': 1, 'b': {'c': 2, 'd': {'e': 3, 'f': 4}}, 'g': 5},
    )

    dashed = {'a': 1, 'b-c': 2, 'b-d': 3}
    self.assertEqual(
        unflatten_dict(dashed, sep='-'),
        {'a': 1, 'b': {'c': 2, 'd': 3}},
    )
849
-
850
- def test_flatten_unflatten_roundtrip(self):
851
- """Test that flatten/unflatten is reversible."""
852
- original = {
853
- 'level1': {
854
- 'level2': {
855
- 'level3': 'value',
856
- 'another': 42
857
- },
858
- 'sibling': 'data'
859
- },
860
- 'root': 'element'
861
- }
862
-
863
- flattened = flatten_dict(original)
864
- unflattened = unflatten_dict(flattened)
865
- self.assertEqual(unflattened, original)
866
-
867
- def test_is_instance_eval(self):
868
- """Test is_instance_eval function."""
869
- # Single type
870
- is_int = is_instance_eval(int)
871
- self.assertTrue(is_int(5))
872
- self.assertFalse(is_int("5"))
873
- self.assertFalse(is_int(5.0))
874
-
875
- # Multiple types
876
- is_number = is_instance_eval(int, float)
877
- self.assertTrue(is_number(5))
878
- self.assertTrue(is_number(5.0))
879
- self.assertFalse(is_number("5"))
880
-
881
- # With subclasses
882
- class MyList(list):
883
- pass
884
-
885
- is_list = is_instance_eval(list)
886
- self.assertTrue(is_list([1, 2, 3]))
887
- self.assertTrue(is_list(MyList([1, 2, 3])))
888
- self.assertFalse(is_list((1, 2, 3)))
889
-
890
- def test_not_instance_eval(self):
891
- """Test not_instance_eval function."""
892
- # Single type
893
- not_int = not_instance_eval(int)
894
- self.assertFalse(not_int(5))
895
- self.assertTrue(not_int("5"))
896
- self.assertTrue(not_int(5.0))
897
-
898
- # Multiple types
899
- not_number = not_instance_eval(int, float)
900
- self.assertFalse(not_number(5))
901
- self.assertFalse(not_number(5.0))
902
- self.assertTrue(not_number("5"))
903
- self.assertTrue(not_number([1, 2, 3]))
904
-
905
-
906
- class TestJaxIntegration(unittest.TestCase):
907
- """Test JAX integration for DictManager and DotDict."""
908
-
909
- def test_dictmanager_pytree_operations(self):
910
- """Test DictManager with JAX tree operations."""
911
- dm = DictManager({
912
- 'weights': jnp.array([1.0, 2.0, 3.0]),
913
- 'bias': jnp.array([0.1, 0.2])
914
- })
915
-
916
- # Tree map
917
- scaled = jax.tree_util.tree_map(lambda x: x * 2, dm)
918
- self.assertTrue(jnp.allclose(scaled['weights'], jnp.array([2.0, 4.0, 6.0])))
919
- self.assertTrue(jnp.allclose(scaled['bias'], jnp.array([0.2, 0.4])))
920
-
921
- # Tree reduce
922
- total = jax.tree_util.tree_reduce(lambda x, y: x + y.sum(), dm, 0.0)
923
- self.assertAlmostEqual(total, 6.3, places=5)
924
-
925
- def test_dotdict_pytree_operations(self):
926
- """Test DotDict with JAX tree operations."""
927
- dd = DotDict({
928
- 'model': {
929
- 'weights': jnp.array([[1.0, 2.0], [3.0, 4.0]]),
930
- 'bias': jnp.array([0.1, 0.2])
931
- },
932
- 'optimizer': {
933
- 'lr': 0.01
934
- }
935
- })
936
-
937
- # Tree map on nested structure
938
- def scale_arrays(x):
939
- return x * 2 if isinstance(x, jnp.ndarray) else x
940
-
941
- scaled = jax.tree_util.tree_map(scale_arrays, dd)
942
- self.assertTrue(jnp.allclose(
943
- scaled.model.weights,
944
- jnp.array([[2.0, 4.0], [6.0, 8.0]])
945
- ))
946
- self.assertEqual(scaled.optimizer.lr, 0.01) # Non-array unchanged
947
-
948
- def test_mixed_pytree_structures(self):
949
- """Test mixing DictManager and DotDict in pytree operations."""
950
- structure = {
951
- 'dict_manager': DictManager({'a': jnp.array([1, 2])})
952
- }
953
-
954
- doubled = jax.tree_util.tree_map(lambda x: x * 2, structure)
955
- self.assertTrue(jnp.allclose(
956
- doubled['dict_manager']['a'],
957
- jnp.array([2, 4])
958
- ))
959
-
960
-
961
- if __name__ == '__main__':
962
- unittest.main()
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive tests for the others module.
18
+ """
19
+
20
+ import pickle
21
+ import threading
22
+ import unittest
23
+ from unittest.mock import MagicMock, patch
24
+
25
+ import jax
26
+ import jax.numpy as jnp
27
+
28
+ from brainstate.util._others import (
29
+ DictManager,
30
+ DotDict,
31
+ NameContext,
32
+ clear_buffer_memory,
33
+ flatten_dict,
34
+ get_unique_name,
35
+ is_instance_eval,
36
+ merge_dicts,
37
+ not_instance_eval,
38
+ split_total,
39
+ unflatten_dict,
40
+ )
41
+
42
+
43
+ class TestSplitTotal(unittest.TestCase):
44
+ """Test cases for split_total function."""
45
+
46
+ def test_float_fraction(self):
47
+ """Test with float fraction values."""
48
+ self.assertEqual(split_total(100, 0.5), 50)
49
+ self.assertEqual(split_total(100, 0.25), 25)
50
+ self.assertEqual(split_total(100, 0.75), 75)
51
+ self.assertEqual(split_total(100, 0.0), 0)
52
+ self.assertEqual(split_total(100, 1.0), 100)
53
+
54
+ def test_int_fraction(self):
55
+ """Test with integer fraction values."""
56
+ self.assertEqual(split_total(100, 25), 25)
57
+ self.assertEqual(split_total(100, 0), 0)
58
+ self.assertEqual(split_total(100, 100), 100)
59
+ self.assertEqual(split_total(50, 30), 30)
60
+
61
+ def test_edge_cases(self):
62
+ """Test edge cases."""
63
+ self.assertEqual(split_total(1, 0.5), 0) # int(0.5) = 0
64
+ self.assertEqual(split_total(1, 1), 1)
65
+ self.assertEqual(split_total(10, 0.99), 9) # int(9.9) = 9
66
+
67
+ def test_type_errors(self):
68
+ """Test type error handling."""
69
+ with self.assertRaises(TypeError) as ctx:
70
+ split_total("100", 0.5)
71
+ self.assertIn("must be an integer", str(ctx.exception))
72
+
73
+ with self.assertRaises(TypeError) as ctx:
74
+ split_total(100, "0.5")
75
+ self.assertIn("must be an integer or float", str(ctx.exception))
76
+
77
+ with self.assertRaises(TypeError) as ctx:
78
+ split_total(100.5, 0.5)
79
+ self.assertIn("must be an integer", str(ctx.exception))
80
+
81
+ def test_value_errors(self):
82
+ """Test value error handling."""
83
+ # Negative total
84
+ with self.assertRaises(ValueError) as ctx:
85
+ split_total(-10, 0.5)
86
+ self.assertIn("must be a positive integer", str(ctx.exception))
87
+
88
+ # Zero total
89
+ with self.assertRaises(ValueError) as ctx:
90
+ split_total(0, 0.5)
91
+ self.assertIn("must be a positive integer", str(ctx.exception))
92
+
93
+ # Negative fraction (float)
94
+ with self.assertRaises(ValueError) as ctx:
95
+ split_total(100, -0.5)
96
+ self.assertIn("cannot be negative", str(ctx.exception))
97
+
98
+ # Fraction > 1 (float)
99
+ with self.assertRaises(ValueError) as ctx:
100
+ split_total(100, 1.5)
101
+ self.assertIn("cannot be greater than 1", str(ctx.exception))
102
+
103
+ # Negative fraction (int)
104
+ with self.assertRaises(ValueError) as ctx:
105
+ split_total(100, -10)
106
+ self.assertIn("cannot be negative", str(ctx.exception))
107
+
108
+ # Fraction > total (int)
109
+ with self.assertRaises(ValueError) as ctx:
110
+ split_total(100, 150)
111
+ self.assertIn("cannot be greater than total", str(ctx.exception))
112
+
113
+
114
+ class TestNameContext(unittest.TestCase):
115
+ """Test cases for NameContext and get_unique_name."""
116
+
117
+ def setUp(self):
118
+ """Reset the global NAME context before each test."""
119
+ global NAME
120
+ from brainstate.util._others import NAME
121
+ NAME.typed_names.clear()
122
+
123
+ def test_get_unique_name_basic(self):
124
+ """Test basic unique name generation."""
125
+ name1 = get_unique_name('layer')
126
+ name2 = get_unique_name('layer')
127
+ name3 = get_unique_name('layer')
128
+
129
+ self.assertEqual(name1, 'layer0')
130
+ self.assertEqual(name2, 'layer1')
131
+ self.assertEqual(name3, 'layer2')
132
+
133
+ def test_get_unique_name_with_prefix(self):
134
+ """Test unique name generation with prefix."""
135
+ name1 = get_unique_name('layer', 'conv_')
136
+ name2 = get_unique_name('layer', 'conv_')
137
+ name3 = get_unique_name('layer', 'dense_')
138
+
139
+ self.assertEqual(name1, 'conv_layer0')
140
+ self.assertEqual(name2, 'conv_layer1')
141
+ self.assertEqual(name3, 'dense_layer2')
142
+
143
+ def test_different_types(self):
144
+ """Test unique names for different types."""
145
+ layer1 = get_unique_name('layer')
146
+ neuron1 = get_unique_name('neuron')
147
+ layer2 = get_unique_name('layer')
148
+ neuron2 = get_unique_name('neuron')
149
+
150
+ self.assertEqual(layer1, 'layer0')
151
+ self.assertEqual(neuron1, 'neuron0')
152
+ self.assertEqual(layer2, 'layer1')
153
+ self.assertEqual(neuron2, 'neuron1')
154
+
155
+ def test_thread_local_context(self):
156
+ """Test that NameContext is thread-local."""
157
+ results = {}
158
+
159
+ def worker(thread_id):
160
+ name1 = get_unique_name(f'type_{thread_id}')
161
+ name2 = get_unique_name(f'type_{thread_id}')
162
+ results[thread_id] = (name1, name2)
163
+
164
+ threads = []
165
+ for i in range(3):
166
+ t = threading.Thread(target=worker, args=(i,))
167
+ threads.append(t)
168
+ t.start()
169
+
170
+ for t in threads:
171
+ t.join()
172
+
173
+ # Each thread should have independent counters
174
+ for thread_id in range(3):
175
+ self.assertEqual(results[thread_id][0], f'type_{thread_id}0')
176
+ self.assertEqual(results[thread_id][1], f'type_{thread_id}1')
177
+
178
+ def test_name_context_reset(self):
179
+ """Test resetting name context."""
180
+ context = NameContext()
181
+ context.typed_names['test'] = 5
182
+
183
+ # Reset specific type
184
+ context.reset('test')
185
+ self.assertEqual(context.typed_names.get('test'), 0)
186
+
187
+ # Add more names
188
+ context.typed_names['test1'] = 1
189
+ context.typed_names['test2'] = 2
190
+
191
+ # Reset all
192
+ context.reset()
193
+ self.assertEqual(len(context.typed_names), 0)
194
+
195
+
196
+ class TestDictManager(unittest.TestCase):
197
+ """Test cases for DictManager class."""
198
+
199
+ def test_initialization(self):
200
+ """Test DictManager initialization."""
201
+ # Empty initialization
202
+ dm1 = DictManager()
203
+ self.assertEqual(len(dm1), 0)
204
+
205
+ # From dict
206
+ dm2 = DictManager({'a': 1, 'b': 2})
207
+ self.assertEqual(dm2['a'], 1)
208
+ self.assertEqual(dm2['b'], 2)
209
+
210
+ # From kwargs
211
+ dm3 = DictManager(a=1, b=2)
212
+ self.assertEqual(dm3['a'], 1)
213
+ self.assertEqual(dm3['b'], 2)
214
+
215
+ # From items
216
+ dm4 = DictManager([('a', 1), ('b', 2)])
217
+ self.assertEqual(dm4['a'], 1)
218
+ self.assertEqual(dm4['b'], 2)
219
+
220
+ def test_subset(self):
221
+ """Test subset filtering."""
222
+ dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})
223
+
224
+ # By type
225
+ int_subset = dm.subset(int)
226
+ self.assertEqual(dict(int_subset), {'a': 1, 'd': 3})
227
+
228
+ # By multiple types
229
+ num_subset = dm.subset((int, float))
230
+ self.assertEqual(dict(num_subset), {'a': 1, 'b': 2.0, 'd': 3})
231
+
232
+ # By predicate
233
+ large_subset = dm.subset(lambda x: isinstance(x, (int, float)) and x > 1.5)
234
+ self.assertEqual(dict(large_subset), {'b': 2.0, 'd': 3})
235
+
236
+ def test_not_subset(self):
237
+ """Test not_subset filtering."""
238
+ dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})
239
+
240
+ not_int = dm.not_subset(int)
241
+ self.assertEqual(dict(not_int), {'b': 2.0, 'c': 'text'})
242
+
243
+ not_num = dm.not_subset((int, float))
244
+ self.assertEqual(dict(not_num), {'c': 'text'})
245
+
246
+ def test_add_unique_key(self):
247
+ """Test adding unique keys."""
248
+ dm = DictManager()
249
+ obj1 = object()
250
+ obj2 = object()
251
+
252
+ # Add new key
253
+ dm.add_unique_key('key1', obj1)
254
+ self.assertIs(dm['key1'], obj1)
255
+
256
+ # Add same key with same object (should work)
257
+ dm.add_unique_key('key1', obj1)
258
+ self.assertIs(dm['key1'], obj1)
259
+
260
+ # Add same key with different object (should fail)
261
+ with self.assertRaises(ValueError) as ctx:
262
+ dm.add_unique_key('key1', obj2)
263
+ self.assertIn("already exists with a different value", str(ctx.exception))
264
+
265
+ def test_add_unique_value(self):
266
+ """Test adding unique values."""
267
+ dm = DictManager()
268
+ obj1 = object()
269
+
270
+ # First addition should succeed
271
+ result1 = dm.add_unique_value('key1', obj1)
272
+ self.assertTrue(result1)
273
+ self.assertIs(dm['key1'], obj1)
274
+
275
+ # Adding same value with different key should fail
276
+ result2 = dm.add_unique_value('key2', obj1)
277
+ self.assertFalse(result2)
278
+ self.assertNotIn('key2', dm)
279
+
280
+ # Adding different value should succeed
281
+ obj2 = object()
282
+ result3 = dm.add_unique_value('key2', obj2)
283
+ self.assertTrue(result3)
284
+ self.assertIs(dm['key2'], obj2)
285
+
286
+ def test_unique(self):
287
+ """Test getting unique values."""
288
+ obj1 = object()
289
+ obj2 = object()
290
+ dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1, 'd': obj2, 'e': obj1})
291
+
292
+ unique_dm = dm.unique()
293
+ self.assertEqual(len(unique_dm), 2)
294
+ # Check that each object appears only once
295
+ values = list(unique_dm.values())
296
+ self.assertEqual(len(set(id(v) for v in values)), 2)
297
+
298
+ def test_unique_inplace(self):
299
+ """Test in-place unique operation."""
300
+ obj1 = object()
301
+ obj2 = object()
302
+ dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1})
303
+
304
+ result = dm.unique_()
305
+ self.assertIs(result, dm) # Should return self
306
+ self.assertEqual(len(dm), 2) # One duplicate removed
307
+
308
+ def test_assign(self):
309
+ """Test assign method."""
310
+ dm = DictManager({'a': 1})
311
+
312
+ # Assign from dict
313
+ dm.assign({'b': 2, 'c': 3})
314
+ self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})
315
+
316
+ # Assign from multiple dicts
317
+ dm.assign({'d': 4}, {'e': 5})
318
+ self.assertEqual(len(dm), 5)
319
+
320
+ # Assign with kwargs
321
+ dm.assign(f=6, g=7)
322
+ self.assertEqual(dm['f'], 6)
323
+ self.assertEqual(dm['g'], 7)
324
+
325
+ # Invalid argument
326
+ with self.assertRaises(TypeError):
327
+ dm.assign([1, 2, 3])
328
+
329
+ def test_split(self):
330
+ """Test splitting by types."""
331
+ dm = DictManager({
332
+ 'a': 1, 'b': 2.0, 'c': 'text',
333
+ 'd': 3, 'e': 4.5, 'f': [1, 2]
334
+ })
335
+
336
+ int_dm, float_dm, rest = dm.split(int, float)
337
+
338
+ self.assertEqual(dict(int_dm), {'a': 1, 'd': 3})
339
+ self.assertEqual(dict(float_dm), {'b': 2.0, 'e': 4.5})
340
+ self.assertEqual(dict(rest), {'c': 'text', 'f': [1, 2]})
341
+
342
+ def test_filter_by_predicate(self):
343
+ """Test filtering with predicate."""
344
+ dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})
345
+
346
+ # Filter by key
347
+ filtered = dm.filter_by_predicate(lambda k, v: k in ['a', 'c'])
348
+ self.assertEqual(dict(filtered), {'a': 1, 'c': 3})
349
+
350
+ # Filter by value
351
+ filtered = dm.filter_by_predicate(lambda k, v: v > 2)
352
+ self.assertEqual(dict(filtered), {'c': 3, 'd': 4})
353
+
354
+ # Filter by both
355
+ filtered = dm.filter_by_predicate(lambda k, v: k == 'a' or v == 4)
356
+ self.assertEqual(dict(filtered), {'a': 1, 'd': 4})
357
+
358
+ def test_map_values(self):
359
+ """Test mapping function to values."""
360
+ dm = DictManager({'a': 1, 'b': 2, 'c': 3})
361
+
362
+ doubled = dm.map_values(lambda x: x * 2)
363
+ self.assertEqual(dict(doubled), {'a': 2, 'b': 4, 'c': 6})
364
+
365
+ # Original should be unchanged
366
+ self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})
367
+
368
+ def test_map_keys(self):
369
+ """Test mapping function to keys."""
370
+ dm = DictManager({'a': 1, 'b': 2, 'c': 3})
371
+
372
+ upper = dm.map_keys(str.upper)
373
+ self.assertEqual(dict(upper), {'A': 1, 'B': 2, 'C': 3})
374
+
375
+ # Test duplicate key error
376
+ with self.assertRaises(ValueError) as ctx:
377
+ dm.map_keys(lambda x: 'same')
378
+ self.assertIn("duplicate", str(ctx.exception))
379
+
380
+ def test_pop_by_keys(self):
381
+ """Test removing items by keys."""
382
+ dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})
383
+
384
+ dm.pop_by_keys(['b', 'd'])
385
+ self.assertEqual(dict(dm), {'a': 1, 'c': 3})
386
+
387
+ # Pop non-existent keys (should not raise)
388
+ dm.pop_by_keys(['x', 'y'])
389
+ self.assertEqual(dict(dm), {'a': 1, 'c': 3})
390
+
391
+ def test_pop_by_values(self):
392
+ """Test removing items by values."""
393
+ obj1, obj2, obj3 = object(), object(), object()
394
+ dm = DictManager({'a': obj1, 'b': obj2, 'c': obj3})
395
+
396
+ # By identity
397
+ dm_copy = DictManager(dm)
398
+ dm_copy.pop_by_values([obj2], by='id')
399
+ self.assertEqual(len(dm_copy), 2)
400
+ self.assertNotIn('b', dm_copy)
401
+
402
+ # By value equality
403
+ dm2 = DictManager({'a': 1, 'b': 2, 'c': 3})
404
+ dm2.pop_by_values([2, 3], by='value')
405
+ self.assertEqual(dict(dm2), {'a': 1})
406
+
407
+ # Invalid method
408
+ with self.assertRaises(ValueError):
409
+ dm.pop_by_values([obj1], by='invalid')
410
+
411
+ def test_difference_operations(self):
412
+ """Test difference operations."""
413
+ dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})
414
+
415
+ # Difference by keys
416
+ diff = dm.difference_by_keys(['b', 'd'])
417
+ self.assertEqual(dict(diff), {'a': 1, 'c': 3})
418
+
419
+ # Difference by values
420
+ diff = dm.difference_by_values([2, 4], by='value')
421
+ self.assertEqual(dict(diff), {'a': 1, 'c': 3})
422
+
423
+ def test_intersection_operations(self):
424
+ """Test intersection operations."""
425
+ dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})
426
+
427
+ # Intersection by keys
428
+ inter = dm.intersection_by_keys(['a', 'c', 'e'])
429
+ self.assertEqual(dict(inter), {'a': 1, 'c': 3})
430
+
431
+ # Intersection by values
432
+ inter = dm.intersection_by_values([1, 3, 5], by='value')
433
+ self.assertEqual(dict(inter), {'a': 1, 'c': 3})
434
+
435
+ def test_operators(self):
436
+ """Test operator overloading."""
437
+ dm1 = DictManager({'a': 1, 'b': 2})
438
+ dm2 = DictManager({'c': 3, 'd': 4})
439
+
440
+ # Addition
441
+ dm3 = dm1 + dm2
442
+ self.assertEqual(dict(dm3), {'a': 1, 'b': 2, 'c': 3, 'd': 4})
443
+ self.assertIsNot(dm3, dm1) # New object
444
+
445
+ # Addition with regular dict
446
+ dm4 = dm1 + {'e': 5}
447
+ self.assertEqual(dm4['e'], 5)
448
+
449
+ # Union operator (Python 3.9+) - only test if available
450
+ import sys
451
+ if sys.version_info >= (3, 9):
452
+ dm5 = dm1 | dm2
453
+ self.assertEqual(dict(dm5), {'a': 1, 'b': 2, 'c': 3, 'd': 4})
454
+
455
+ # In-place union
456
+ dm1_copy = DictManager(dm1)
457
+ dm1_copy |= dm2
458
+ self.assertEqual(dict(dm1_copy), {'a': 1, 'b': 2, 'c': 3, 'd': 4})
459
+
460
+ # Invalid operations should raise TypeError or return NotImplemented
461
+ # The actual behavior depends on whether __add__ returns NotImplemented
462
+ # or lets Python raise TypeError
463
+ with self.assertRaises(TypeError):
464
+ _ = dm1 + 123
465
+ if sys.version_info >= (3, 9):
466
+ with self.assertRaises(TypeError):
467
+ _ = dm1 | 123
468
+
469
+ def test_copy_operations(self):
470
+ """Test copy operations."""
471
+ obj = object()
472
+ dm1 = DictManager({'a': 1, 'b': obj})
473
+
474
+ # Shallow copy
475
+ dm2 = dm1.__copy__()
476
+ self.assertIsNot(dm2, dm1)
477
+ self.assertEqual(dict(dm2), dict(dm1))
478
+ self.assertIs(dm2['b'], obj) # Same object reference
479
+
480
+ # Deep copy
481
+ dm3 = dm1.__deepcopy__({})
482
+ self.assertIsNot(dm3, dm1)
483
+ self.assertEqual(dm3['a'], 1)
484
+ # Note: object() can't be deep copied, but other values work
485
+
486
+ def test_jax_pytree(self):
487
+ """Test JAX pytree registration."""
488
+ dm = DictManager({'a': 1, 'b': 2, 'c': 3})
489
+
490
+ # Flatten
491
+ values, keys = dm.tree_flatten()
492
+ self.assertEqual(len(values), 3)
493
+ self.assertEqual(len(keys), 3)
494
+
495
+ # Unflatten
496
+ dm2 = DictManager.tree_unflatten(keys, values)
497
+ self.assertEqual(dict(dm2), dict(dm))
498
+
499
+ # Test with JAX tree operations
500
+ dm3 = DictManager({'x': jnp.array([1, 2]), 'y': jnp.array([3, 4])})
501
+ doubled = jax.tree_util.tree_map(lambda x: x * 2, dm3)
502
+ self.assertTrue(jnp.allclose(doubled['x'], jnp.array([2, 4])))
503
+ self.assertTrue(jnp.allclose(doubled['y'], jnp.array([6, 8])))
504
+
505
+ def test_repr(self):
506
+ """Test string representation."""
507
+ dm = DictManager({'a': 1, 'b': 'text'})
508
+ repr_str = repr(dm)
509
+ self.assertIn('DictManager', repr_str)
510
+ self.assertIn("'a': 1", repr_str)
511
+ self.assertIn("'b': 'text'", repr_str)
512
+
513
+
514
+ class TestDotDict(unittest.TestCase):
515
+ """Test cases for DotDict class."""
516
+
517
+ def test_initialization(self):
518
+ """Test DotDict initialization."""
519
+ # From dict
520
+ dd1 = DotDict({'a': 1, 'b': 2})
521
+ self.assertEqual(dd1.a, 1)
522
+ self.assertEqual(dd1['b'], 2)
523
+
524
+ # From kwargs
525
+ dd2 = DotDict(a=1, b=2)
526
+ self.assertEqual(dd2.a, 1)
527
+ self.assertEqual(dd2.b, 2)
528
+
529
+ # From tuple
530
+ dd3 = DotDict(('key', 'value'))
531
+ self.assertEqual(dd3.key, 'value')
532
+
533
+ # From items
534
+ dd4 = DotDict([('a', 1), ('b', 2)])
535
+ self.assertEqual(dd4.a, 1)
536
+ self.assertEqual(dd4.b, 2)
537
+
538
+ # Empty
539
+ dd5 = DotDict()
540
+ self.assertEqual(len(dd5), 0)
541
+
542
+ # Invalid argument
543
+ with self.assertRaises(TypeError):
544
+ DotDict(123)
545
+
546
+ def test_dot_access(self):
547
+ """Test dot notation access."""
548
+ dd = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})
549
+
550
+ # Read access
551
+ self.assertEqual(dd.a, 1)
552
+ self.assertEqual(dd.b.c, 2)
553
+ self.assertEqual(dd.b.d, 3)
554
+
555
+ # Write access
556
+ dd.a = 10
557
+ dd.b.c = 20
558
+ self.assertEqual(dd['a'], 10)
559
+ self.assertEqual(dd['b']['c'], 20)
560
+
561
+ # Add new attributes
562
+ dd.e = 5
563
+ self.assertEqual(dd['e'], 5)
564
+
565
+ def test_nested_dict_conversion(self):
566
+ """Test automatic nested dict conversion."""
567
+ dd = DotDict({
568
+ 'level1': {
569
+ 'level2': {
570
+ 'level3': 'value'
571
+ }
572
+ }
573
+ })
574
+
575
+ self.assertIsInstance(dd.level1, DotDict)
576
+ self.assertIsInstance(dd.level1.level2, DotDict)
577
+ self.assertEqual(dd.level1.level2.level3, 'value')
578
+
579
+ def test_list_tuple_conversion(self):
580
+ """Test conversion of lists and tuples containing dicts."""
581
+ dd = DotDict({
582
+ 'list': [{'a': 1}, {'b': 2}],
583
+ 'tuple': ({'c': 3}, {'d': 4})
584
+ })
585
+
586
+ # List items should be DotDict
587
+ self.assertIsInstance(dd.list[0], DotDict)
588
+ self.assertEqual(dd.list[0].a, 1)
589
+ self.assertIsInstance(dd.list[1], DotDict)
590
+ self.assertEqual(dd.list[1].b, 2)
591
+
592
+ # Tuple items should be DotDict
593
+ self.assertIsInstance(dd.tuple[0], DotDict)
594
+ self.assertEqual(dd.tuple[0].c, 3)
595
+
596
+ def test_attribute_errors(self):
597
+ """Test attribute error handling."""
598
+ dd = DotDict({'a': 1})
599
+
600
+ # Non-existent attribute
601
+ with self.assertRaises(AttributeError) as ctx:
602
+ _ = dd.nonexistent
603
+ self.assertIn("has no attribute 'nonexistent'", str(ctx.exception))
604
+
605
+ # Delete non-existent attribute
606
+ with self.assertRaises(AttributeError):
607
+ del dd.nonexistent
608
+
609
+ # Try to set built-in method
610
+ with self.assertRaises(AttributeError) as ctx:
611
+ dd.keys = 'value'
612
+ self.assertIn("built-in method", str(ctx.exception))
613
+
614
+ def test_dir_method(self):
615
+ """Test __dir__ method."""
616
+ dd = DotDict({'a': 1, 'b': 2})
617
+ attrs = dir(dd)
618
+
619
+ self.assertIn('a', attrs)
620
+ self.assertIn('b', attrs)
621
+ self.assertIn('keys', attrs) # Built-in method
622
+ self.assertIn('values', attrs) # Built-in method
623
+
624
+ def test_get_method(self):
625
+ """Test get method with default."""
626
+ dd = DotDict({'a': 1})
627
+
628
+ self.assertEqual(dd.get('a'), 1)
629
+ self.assertEqual(dd.get('b'), None)
630
+ self.assertEqual(dd.get('b', 'default'), 'default')
631
+
632
+ def test_copy_operations(self):
633
+ """Test copy operations."""
634
+ dd1 = DotDict({'a': 1, 'b': {'c': 2}})
635
+
636
+ # Shallow copy
637
+ dd2 = dd1.copy()
638
+ self.assertIsNot(dd2, dd1)
639
+ self.assertEqual(dd2.a, 1)
640
+ dd2.a = 10
641
+ self.assertEqual(dd1.a, 1) # Original unchanged
642
+
643
+ # Deep copy
644
+ dd3 = dd1.deepcopy()
645
+ self.assertIsNot(dd3, dd1)
646
+ self.assertIsNot(dd3.b, dd1.b)
647
+ dd3.b.c = 20
648
+ self.assertEqual(dd1.b.c, 2) # Original unchanged
649
+
650
+ def test_to_dict_from_dict(self):
651
+ """Test conversion to/from standard dict."""
652
+ dd1 = DotDict({
653
+ 'a': 1,
654
+ 'b': {'c': 2, 'd': {'e': 3}},
655
+ 'list': [{'f': 4}]
656
+ })
657
+
658
+ # Convert to dict
659
+ d = dd1.to_dict()
660
+ self.assertIsInstance(d, dict)
661
+ self.assertNotIsInstance(d, DotDict)
662
+ self.assertIsInstance(d['b'], dict)
663
+ self.assertNotIsInstance(d['b'], DotDict)
664
+ self.assertEqual(d['b']['d']['e'], 3)
665
+
666
+ # Convert from dict
667
+ dd2 = DotDict.from_dict(d)
668
+ self.assertIsInstance(dd2, DotDict)
669
+ self.assertIsInstance(dd2.b, DotDict)
670
+ self.assertEqual(dd2.b.d.e, 3)
671
+
672
+ def test_update_method(self):
673
+ """Test update with recursive merge."""
674
+ dd = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})
675
+
676
+ # Simple update
677
+ dd.update({'a': 10})
678
+ self.assertEqual(dd.a, 10)
679
+
680
+ # Recursive merge
681
+ dd.update({'b': {'d': 30, 'e': 4}})
682
+ self.assertEqual(dd.b.c, 2) # Preserved
683
+ self.assertEqual(dd.b.d, 30) # Updated
684
+ self.assertEqual(dd.b.e, 4) # Added
685
+
686
+ # Update with kwargs
687
+ dd.update(f=5, g=6)
688
+ self.assertEqual(dd.f, 5)
689
+ self.assertEqual(dd.g, 6)
690
+
691
+ # Multiple arguments error
692
+ with self.assertRaises(TypeError):
693
+ dd.update({}, {})
694
+
695
+ def test_setdefault(self):
696
+ """Test setdefault method."""
697
+ dd = DotDict({'a': 1})
698
+
699
+ # Existing key
700
+ result = dd.setdefault('a', 10)
701
+ self.assertEqual(result, 1)
702
+ self.assertEqual(dd.a, 1)
703
+
704
+ # New key
705
+ result = dd.setdefault('b', 2)
706
+ self.assertEqual(result, 2)
707
+ self.assertEqual(dd.b, 2)
708
+
709
+ # New key with None default
710
+ result = dd.setdefault('c')
711
+ self.assertIsNone(result)
712
+ self.assertIsNone(dd.c)
713
+
714
+ def test_pickling(self):
715
+ """Test pickling/unpickling."""
716
+ dd1 = DotDict({'a': 1, 'b': {'c': 2}})
717
+
718
+ # Pickle and unpickle
719
+ pickled = pickle.dumps(dd1)
720
+ dd2 = pickle.loads(pickled)
721
+
722
+ self.assertIsNot(dd2, dd1)
723
+ self.assertEqual(dd2.a, 1)
724
+ self.assertEqual(dd2.b.c, 2)
725
+ self.assertIsInstance(dd2, DotDict)
726
+ self.assertIsInstance(dd2.b, DotDict)
727
+
728
+ def test_jax_pytree(self):
729
+ """Test JAX pytree registration."""
730
+ dd = DotDict({'a': jnp.array([1, 2]), 'b': jnp.array([3, 4])})
731
+
732
+ # Tree operations
733
+ doubled = jax.tree_util.tree_map(lambda x: x * 2, dd)
734
+ self.assertTrue(jnp.allclose(doubled.a, jnp.array([2, 4])))
735
+ self.assertTrue(jnp.allclose(doubled.b, jnp.array([6, 8])))
736
+
737
+ def test_repr(self):
738
+ """Test string representation."""
739
+ dd = DotDict({'a': 1, 'b': 'text'})
740
+ repr_str = repr(dd)
741
+ self.assertIn('DotDict', repr_str)
742
+ self.assertIn("'a': 1", repr_str)
743
+ self.assertIn("'b': 'text'", repr_str)
744
+
745
+
746
+ class TestUtilityFunctions(unittest.TestCase):
747
+ """Test cases for utility functions."""
748
+
749
+ def test_merge_dicts_basic(self):
750
+ """Test basic dict merging."""
751
+ d1 = {'a': 1, 'b': 2}
752
+ d2 = {'b': 3, 'c': 4}
753
+ d3 = {'d': 5}
754
+
755
+ result = merge_dicts(d1, d2, d3)
756
+ self.assertEqual(result, {'a': 1, 'b': 3, 'c': 4, 'd': 5})
757
+
758
+ # Original dicts should be unchanged
759
+ self.assertEqual(d1, {'a': 1, 'b': 2})
760
+
761
+ def test_merge_dicts_recursive(self):
762
+ """Test recursive dict merging."""
763
+ d1 = {'a': 1, 'b': {'c': 2, 'd': 3}}
764
+ d2 = {'b': {'d': 4, 'e': 5}, 'f': 6}
765
+
766
+ result = merge_dicts(d1, d2, recursive=True)
767
+ self.assertEqual(result, {
768
+ 'a': 1,
769
+ 'b': {'c': 2, 'd': 4, 'e': 5},
770
+ 'f': 6
771
+ })
772
+
773
+ def test_merge_dicts_non_recursive(self):
774
+ """Test non-recursive dict merging."""
775
+ d1 = {'a': 1, 'b': {'c': 2}}
776
+ d2 = {'b': {'d': 3}}
777
+
778
+ result = merge_dicts(d1, d2, recursive=False)
779
+ self.assertEqual(result, {'a': 1, 'b': {'d': 3}})
780
+
781
+ def test_merge_dicts_errors(self):
782
+ """Test merge_dicts error handling."""
783
+ with self.assertRaises(TypeError):
784
+ merge_dicts({'a': 1}, [1, 2, 3])
785
+
786
+ def test_flatten_dict(self):
787
+ """Test dictionary flattening."""
788
+ nested = {
789
+ 'a': 1,
790
+ 'b': {
791
+ 'c': 2,
792
+ 'd': {
793
+ 'e': 3,
794
+ 'f': 4
795
+ }
796
+ },
797
+ 'g': 5
798
+ }
799
+
800
+ flat = flatten_dict(nested)
801
+ self.assertEqual(flat, {
802
+ 'a': 1,
803
+ 'b.c': 2,
804
+ 'b.d.e': 3,
805
+ 'b.d.f': 4,
806
+ 'g': 5
807
+ })
808
+
809
+ # Custom separator
810
+ flat_dash = flatten_dict(nested, sep='-')
811
+ self.assertEqual(flat_dash, {
812
+ 'a': 1,
813
+ 'b-c': 2,
814
+ 'b-d-e': 3,
815
+ 'b-d-f': 4,
816
+ 'g': 5
817
+ })
818
+
819
+ def test_unflatten_dict(self):
820
+ """Test dictionary unflattening."""
821
+ flat = {
822
+ 'a': 1,
823
+ 'b.c': 2,
824
+ 'b.d.e': 3,
825
+ 'b.d.f': 4,
826
+ 'g': 5
827
+ }
828
+
829
+ nested = unflatten_dict(flat)
830
+ self.assertEqual(nested, {
831
+ 'a': 1,
832
+ 'b': {
833
+ 'c': 2,
834
+ 'd': {
835
+ 'e': 3,
836
+ 'f': 4
837
+ }
838
+ },
839
+ 'g': 5
840
+ })
841
+
842
+ # Custom separator
843
+ flat_dash = {'a': 1, 'b-c': 2, 'b-d': 3}
844
+ nested_dash = unflatten_dict(flat_dash, sep='-')
845
+ self.assertEqual(nested_dash, {
846
+ 'a': 1,
847
+ 'b': {'c': 2, 'd': 3}
848
+ })
849
+
850
+ def test_flatten_unflatten_roundtrip(self):
851
+ """Test that flatten/unflatten is reversible."""
852
+ original = {
853
+ 'level1': {
854
+ 'level2': {
855
+ 'level3': 'value',
856
+ 'another': 42
857
+ },
858
+ 'sibling': 'data'
859
+ },
860
+ 'root': 'element'
861
+ }
862
+
863
+ flattened = flatten_dict(original)
864
+ unflattened = unflatten_dict(flattened)
865
+ self.assertEqual(unflattened, original)
866
+
867
+ def test_is_instance_eval(self):
868
+ """Test is_instance_eval function."""
869
+ # Single type
870
+ is_int = is_instance_eval(int)
871
+ self.assertTrue(is_int(5))
872
+ self.assertFalse(is_int("5"))
873
+ self.assertFalse(is_int(5.0))
874
+
875
+ # Multiple types
876
+ is_number = is_instance_eval(int, float)
877
+ self.assertTrue(is_number(5))
878
+ self.assertTrue(is_number(5.0))
879
+ self.assertFalse(is_number("5"))
880
+
881
+ # With subclasses
882
+ class MyList(list):
883
+ pass
884
+
885
+ is_list = is_instance_eval(list)
886
+ self.assertTrue(is_list([1, 2, 3]))
887
+ self.assertTrue(is_list(MyList([1, 2, 3])))
888
+ self.assertFalse(is_list((1, 2, 3)))
889
+
890
+ def test_not_instance_eval(self):
891
+ """Test not_instance_eval function."""
892
+ # Single type
893
+ not_int = not_instance_eval(int)
894
+ self.assertFalse(not_int(5))
895
+ self.assertTrue(not_int("5"))
896
+ self.assertTrue(not_int(5.0))
897
+
898
+ # Multiple types
899
+ not_number = not_instance_eval(int, float)
900
+ self.assertFalse(not_number(5))
901
+ self.assertFalse(not_number(5.0))
902
+ self.assertTrue(not_number("5"))
903
+ self.assertTrue(not_number([1, 2, 3]))
904
+
905
+
906
+ class TestJaxIntegration(unittest.TestCase):
907
+ """Test JAX integration for DictManager and DotDict."""
908
+
909
+ def test_dictmanager_pytree_operations(self):
910
+ """Test DictManager with JAX tree operations."""
911
+ dm = DictManager({
912
+ 'weights': jnp.array([1.0, 2.0, 3.0]),
913
+ 'bias': jnp.array([0.1, 0.2])
914
+ })
915
+
916
+ # Tree map
917
+ scaled = jax.tree_util.tree_map(lambda x: x * 2, dm)
918
+ self.assertTrue(jnp.allclose(scaled['weights'], jnp.array([2.0, 4.0, 6.0])))
919
+ self.assertTrue(jnp.allclose(scaled['bias'], jnp.array([0.2, 0.4])))
920
+
921
+ # Tree reduce
922
+ total = jax.tree_util.tree_reduce(lambda x, y: x + y.sum(), dm, 0.0)
923
+ self.assertAlmostEqual(total, 6.3, places=5)
924
+
925
+ def test_dotdict_pytree_operations(self):
926
+ """Test DotDict with JAX tree operations."""
927
+ dd = DotDict({
928
+ 'model': {
929
+ 'weights': jnp.array([[1.0, 2.0], [3.0, 4.0]]),
930
+ 'bias': jnp.array([0.1, 0.2])
931
+ },
932
+ 'optimizer': {
933
+ 'lr': 0.01
934
+ }
935
+ })
936
+
937
+ # Tree map on nested structure
938
+ def scale_arrays(x):
939
+ return x * 2 if isinstance(x, jnp.ndarray) else x
940
+
941
+ scaled = jax.tree_util.tree_map(scale_arrays, dd)
942
+ self.assertTrue(jnp.allclose(
943
+ scaled.model.weights,
944
+ jnp.array([[2.0, 4.0], [6.0, 8.0]])
945
+ ))
946
+ self.assertEqual(scaled.optimizer.lr, 0.01) # Non-array unchanged
947
+
948
+ def test_mixed_pytree_structures(self):
949
+ """Test mixing DictManager and DotDict in pytree operations."""
950
+ structure = {
951
+ 'dict_manager': DictManager({'a': jnp.array([1, 2])})
952
+ }
953
+
954
+ doubled = jax.tree_util.tree_map(lambda x: x * 2, structure)
955
+ self.assertTrue(jnp.allclose(
956
+ doubled['dict_manager']['a'],
957
+ jnp.array([2, 4])
958
+ ))
959
+
960
+
961
+ if __name__ == '__main__':
962
+ unittest.main()