brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,962 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Comprehensive tests for the others module.
18
+ """
19
+
20
+ import pickle
21
+ import threading
22
+ import unittest
23
+ from unittest.mock import MagicMock, patch
24
+
25
+ import jax
26
+ import jax.numpy as jnp
27
+
28
+ from brainstate.util._others import (
29
+ DictManager,
30
+ DotDict,
31
+ NameContext,
32
+ clear_buffer_memory,
33
+ flatten_dict,
34
+ get_unique_name,
35
+ is_instance_eval,
36
+ merge_dicts,
37
+ not_instance_eval,
38
+ split_total,
39
+ unflatten_dict,
40
+ )
41
+
42
+
43
class TestSplitTotal(unittest.TestCase):
    """Behavioural checks for ``split_total``.

    Every test walks a small table of inputs and wraps each row in
    ``subTest`` so that a failure reports the exact arguments involved.
    """

    def _assert_splits(self, rows):
        """Check ``split_total(total, fraction) == expected`` for every row."""
        for total, fraction, expected in rows:
            with self.subTest(total=total, fraction=fraction):
                self.assertEqual(split_total(total, fraction), expected)

    def _assert_raises_with(self, exc_type, rows):
        """Check each ``(total, fraction)`` raises *exc_type* mentioning a fragment."""
        for (total, fraction), fragment in rows:
            with self.subTest(total=total, fraction=fraction):
                with self.assertRaises(exc_type) as ctx:
                    split_total(total, fraction)
                self.assertIn(fragment, str(ctx.exception))

    def test_float_fraction(self):
        """A float fraction in [0, 1] selects that share of ``total``."""
        self._assert_splits([
            (100, 0.5, 50),
            (100, 0.25, 25),
            (100, 0.75, 75),
            (100, 0.0, 0),
            (100, 1.0, 100),
        ])

    def test_int_fraction(self):
        """An integer fraction is returned as-is when it lies in [0, total]."""
        self._assert_splits([
            (100, 25, 25),
            (100, 0, 0),
            (100, 100, 100),
            (50, 30, 30),
        ])

    def test_edge_cases(self):
        """Fractional products are truncated by ``int`` rather than rounded."""
        self._assert_splits([
            (1, 0.5, 0),    # int(0.5) = 0
            (1, 1, 1),
            (10, 0.99, 9),  # int(9.9) = 9
        ])

    def test_type_errors(self):
        """Wrongly typed ``total`` / ``fraction`` arguments raise ``TypeError``."""
        self._assert_raises_with(TypeError, [
            (("100", 0.5), "must be an integer"),
            ((100, "0.5"), "must be an integer or float"),
            ((100.5, 0.5), "must be an integer"),
        ])

    def test_value_errors(self):
        """Out-of-range ``total`` / ``fraction`` values raise ``ValueError``."""
        self._assert_raises_with(ValueError, [
            ((-10, 0.5), "must be a positive integer"),     # negative total
            ((0, 0.5), "must be a positive integer"),       # zero total
            ((100, -0.5), "cannot be negative"),            # negative float fraction
            ((100, 1.5), "cannot be greater than 1"),       # float fraction above 1
            ((100, -10), "cannot be negative"),             # negative int fraction
            ((100, 150), "cannot be greater than total"),   # int fraction above total
        ])
112
+
113
+
114
class TestNameContext(unittest.TestCase):
    """Test cases for NameContext and get_unique_name."""

    def setUp(self):
        """Reset the global NAME context before each test.

        The expected names below (``layer0``, ``layer1``, ...) assume the
        per-type counters start out empty, so the shared ``NAME`` context is
        cleared first. ``NAME`` is imported locally rather than via
        ``global NAME`` so that running ``setUp`` does not inject a
        module-level ``NAME`` binding into this test module as a side effect.
        """
        from brainstate.util._others import NAME
        NAME.typed_names.clear()

    def test_get_unique_name_basic(self):
        """Repeated requests for one type yield consecutive numeric suffixes."""
        name1 = get_unique_name('layer')
        name2 = get_unique_name('layer')
        name3 = get_unique_name('layer')

        self.assertEqual(name1, 'layer0')
        self.assertEqual(name2, 'layer1')
        self.assertEqual(name3, 'layer2')

    def test_get_unique_name_with_prefix(self):
        """The prefix is prepended; the counter is keyed by type, not prefix."""
        name1 = get_unique_name('layer', 'conv_')
        name2 = get_unique_name('layer', 'conv_')
        name3 = get_unique_name('layer', 'dense_')

        self.assertEqual(name1, 'conv_layer0')
        self.assertEqual(name2, 'conv_layer1')
        # Same type 'layer' with a different prefix continues the same counter.
        self.assertEqual(name3, 'dense_layer2')

    def test_different_types(self):
        """Each type name keeps its own independent counter."""
        layer1 = get_unique_name('layer')
        neuron1 = get_unique_name('neuron')
        layer2 = get_unique_name('layer')
        neuron2 = get_unique_name('neuron')

        self.assertEqual(layer1, 'layer0')
        self.assertEqual(neuron1, 'neuron0')
        self.assertEqual(layer2, 'layer1')
        self.assertEqual(neuron2, 'neuron1')

    def test_thread_local_context(self):
        """Test that NameContext is thread-local.

        NOTE(review): every worker uses a distinct type name
        (``type_<id>``), so these assertions would also pass if the counters
        were shared across threads. Consider having all workers request the
        same type name if thread-locality itself is meant to be verified.
        """
        results = {}
        errors = []

        def worker(thread_id):
            # Exceptions raised in a worker thread do not propagate to the
            # main thread; record them so the test fails with a clear message
            # instead of a later KeyError on ``results``.
            try:
                name1 = get_unique_name(f'type_{thread_id}')
                name2 = get_unique_name(f'type_{thread_id}')
                results[thread_id] = (name1, name2)
            except Exception as exc:  # noqa: BLE001 - re-reported below
                errors.append((thread_id, exc))

        threads = []
        for i in range(3):
            t = threading.Thread(target=worker, args=(i,))
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

        self.assertFalse(errors, f"worker threads raised: {errors!r}")

        # Each thread should have independent counters
        for thread_id in range(3):
            self.assertEqual(results[thread_id][0], f'type_{thread_id}0')
            self.assertEqual(results[thread_id][1], f'type_{thread_id}1')

    def test_name_context_reset(self):
        """``reset(type)`` zeroes one counter; ``reset()`` clears all of them."""
        context = NameContext()
        context.typed_names['test'] = 5

        # Reset specific type
        context.reset('test')
        self.assertEqual(context.typed_names.get('test'), 0)

        # Add more names
        context.typed_names['test1'] = 1
        context.typed_names['test2'] = 2

        # Reset all
        context.reset()
        self.assertEqual(len(context.typed_names), 0)
194
+
195
+
196
class TestDictManager(unittest.TestCase):
    """Test cases for DictManager class."""

    def test_initialization(self):
        """DictManager can be built empty, from a dict, kwargs, or item pairs."""
        # Empty initialization
        dm1 = DictManager()
        self.assertEqual(len(dm1), 0)

        # From dict
        dm2 = DictManager({'a': 1, 'b': 2})
        self.assertEqual(dm2['a'], 1)
        self.assertEqual(dm2['b'], 2)

        # From kwargs
        dm3 = DictManager(a=1, b=2)
        self.assertEqual(dm3['a'], 1)
        self.assertEqual(dm3['b'], 2)

        # From items
        dm4 = DictManager([('a', 1), ('b', 2)])
        self.assertEqual(dm4['a'], 1)
        self.assertEqual(dm4['b'], 2)

    def test_subset(self):
        """``subset`` keeps entries matching a type, type tuple, or predicate."""
        dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})

        # By type
        int_subset = dm.subset(int)
        self.assertEqual(dict(int_subset), {'a': 1, 'd': 3})

        # By multiple types
        num_subset = dm.subset((int, float))
        self.assertEqual(dict(num_subset), {'a': 1, 'b': 2.0, 'd': 3})

        # By predicate
        large_subset = dm.subset(lambda x: isinstance(x, (int, float)) and x > 1.5)
        self.assertEqual(dict(large_subset), {'b': 2.0, 'd': 3})

    def test_not_subset(self):
        """``not_subset`` keeps entries that do NOT match the given type(s)."""
        dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text', 'd': 3})

        not_int = dm.not_subset(int)
        self.assertEqual(dict(not_int), {'b': 2.0, 'c': 'text'})

        not_num = dm.not_subset((int, float))
        self.assertEqual(dict(not_num), {'c': 'text'})

    def test_add_unique_key(self):
        """Re-adding a key is allowed only with the identical value."""
        dm = DictManager()
        obj1 = object()
        obj2 = object()

        # Add new key
        dm.add_unique_key('key1', obj1)
        self.assertIs(dm['key1'], obj1)

        # Add same key with same object (should work)
        dm.add_unique_key('key1', obj1)
        self.assertIs(dm['key1'], obj1)

        # Add same key with different object (should fail)
        with self.assertRaises(ValueError) as ctx:
            dm.add_unique_key('key1', obj2)
        self.assertIn("already exists with a different value", str(ctx.exception))

    def test_add_unique_value(self):
        """``add_unique_value`` returns False and skips values already stored."""
        dm = DictManager()
        obj1 = object()

        # First addition should succeed
        result1 = dm.add_unique_value('key1', obj1)
        self.assertTrue(result1)
        self.assertIs(dm['key1'], obj1)

        # Adding same value with different key should fail
        result2 = dm.add_unique_value('key2', obj1)
        self.assertFalse(result2)
        self.assertNotIn('key2', dm)

        # Adding different value should succeed
        obj2 = object()
        result3 = dm.add_unique_value('key2', obj2)
        self.assertTrue(result3)
        self.assertIs(dm['key2'], obj2)

    def test_unique(self):
        """``unique`` returns a new manager with each value object kept once."""
        obj1 = object()
        obj2 = object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1, 'd': obj2, 'e': obj1})

        unique_dm = dm.unique()
        self.assertEqual(len(unique_dm), 2)
        # Check that each object appears only once (compared by identity).
        values = list(unique_dm.values())
        self.assertEqual(len({id(v) for v in values}), 2)

    def test_unique_inplace(self):
        """``unique_`` de-duplicates in place and returns ``self``."""
        obj1 = object()
        obj2 = object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj1})

        result = dm.unique_()
        self.assertIs(result, dm)  # Should return self
        self.assertEqual(len(dm), 2)  # One duplicate removed

    def test_assign(self):
        """``assign`` merges dict positionals and kwargs; non-dicts raise."""
        dm = DictManager({'a': 1})

        # Assign from dict
        dm.assign({'b': 2, 'c': 3})
        self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})

        # Assign from multiple dicts
        dm.assign({'d': 4}, {'e': 5})
        self.assertEqual(len(dm), 5)

        # Assign with kwargs
        dm.assign(f=6, g=7)
        self.assertEqual(dm['f'], 6)
        self.assertEqual(dm['g'], 7)

        # Invalid argument
        with self.assertRaises(TypeError):
            dm.assign([1, 2, 3])

    def test_split(self):
        """``split`` partitions by each type in order plus a remainder."""
        dm = DictManager({
            'a': 1, 'b': 2.0, 'c': 'text',
            'd': 3, 'e': 4.5, 'f': [1, 2],
        })

        int_dm, float_dm, rest = dm.split(int, float)

        self.assertEqual(dict(int_dm), {'a': 1, 'd': 3})
        self.assertEqual(dict(float_dm), {'b': 2.0, 'e': 4.5})
        self.assertEqual(dict(rest), {'c': 'text', 'f': [1, 2]})

    def test_filter_by_predicate(self):
        """``filter_by_predicate`` passes ``(key, value)`` to the predicate."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Filter by key
        filtered = dm.filter_by_predicate(lambda k, v: k in ['a', 'c'])
        self.assertEqual(dict(filtered), {'a': 1, 'c': 3})

        # Filter by value
        filtered = dm.filter_by_predicate(lambda k, v: v > 2)
        self.assertEqual(dict(filtered), {'c': 3, 'd': 4})

        # Filter by both
        filtered = dm.filter_by_predicate(lambda k, v: k == 'a' or v == 4)
        self.assertEqual(dict(filtered), {'a': 1, 'd': 4})

    def test_map_values(self):
        """``map_values`` returns a transformed copy, leaving the source intact."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        doubled = dm.map_values(lambda x: x * 2)
        self.assertEqual(dict(doubled), {'a': 2, 'b': 4, 'c': 6})

        # Original should be unchanged
        self.assertEqual(dict(dm), {'a': 1, 'b': 2, 'c': 3})

    def test_map_keys(self):
        """``map_keys`` renames keys and rejects mappings that collide."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        upper = dm.map_keys(str.upper)
        self.assertEqual(dict(upper), {'A': 1, 'B': 2, 'C': 3})

        # Test duplicate key error
        with self.assertRaises(ValueError) as ctx:
            dm.map_keys(lambda x: 'same')
        self.assertIn("duplicate", str(ctx.exception))

    def test_pop_by_keys(self):
        """``pop_by_keys`` removes listed keys and ignores missing ones."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        dm.pop_by_keys(['b', 'd'])
        self.assertEqual(dict(dm), {'a': 1, 'c': 3})

        # Pop non-existent keys (should not raise)
        dm.pop_by_keys(['x', 'y'])
        self.assertEqual(dict(dm), {'a': 1, 'c': 3})

    def test_pop_by_values(self):
        """``pop_by_values`` matches by identity (``'id'``) or equality (``'value'``)."""
        obj1, obj2, obj3 = object(), object(), object()
        dm = DictManager({'a': obj1, 'b': obj2, 'c': obj3})

        # By identity
        dm_copy = DictManager(dm)
        dm_copy.pop_by_values([obj2], by='id')
        self.assertEqual(len(dm_copy), 2)
        self.assertNotIn('b', dm_copy)

        # By value equality
        dm2 = DictManager({'a': 1, 'b': 2, 'c': 3})
        dm2.pop_by_values([2, 3], by='value')
        self.assertEqual(dict(dm2), {'a': 1})

        # Invalid method
        with self.assertRaises(ValueError):
            dm.pop_by_values([obj1], by='invalid')

    def test_difference_operations(self):
        """``difference_by_*`` drop entries whose key/value is listed."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Difference by keys
        diff = dm.difference_by_keys(['b', 'd'])
        self.assertEqual(dict(diff), {'a': 1, 'c': 3})

        # Difference by values
        diff = dm.difference_by_values([2, 4], by='value')
        self.assertEqual(dict(diff), {'a': 1, 'c': 3})

    def test_intersection_operations(self):
        """``intersection_by_*`` keep only entries whose key/value is listed."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Intersection by keys (unknown key 'e' is simply ignored)
        inter = dm.intersection_by_keys(['a', 'c', 'e'])
        self.assertEqual(dict(inter), {'a': 1, 'c': 3})

        # Intersection by values
        inter = dm.intersection_by_values([1, 3, 5], by='value')
        self.assertEqual(dict(inter), {'a': 1, 'c': 3})

    def test_operators(self):
        """``+`` and ``|`` / ``|=`` merge mappings; other operands raise TypeError."""
        import sys

        dm1 = DictManager({'a': 1, 'b': 2})
        dm2 = DictManager({'c': 3, 'd': 4})

        # Addition
        dm3 = dm1 + dm2
        self.assertEqual(dict(dm3), {'a': 1, 'b': 2, 'c': 3, 'd': 4})
        self.assertIsNot(dm3, dm1)  # New object

        # Addition with regular dict
        dm4 = dm1 + {'e': 5}
        self.assertEqual(dm4['e'], 5)

        # Union operators (Python 3.9+) - only test if available
        if sys.version_info >= (3, 9):
            dm5 = dm1 | dm2
            self.assertEqual(dict(dm5), {'a': 1, 'b': 2, 'c': 3, 'd': 4})

            # In-place union
            dm1_copy = DictManager(dm1)
            dm1_copy |= dm2
            self.assertEqual(dict(dm1_copy), {'a': 1, 'b': 2, 'c': 3, 'd': 4})

        # Invalid operands must surface as TypeError, whether __add__/__or__
        # raise it directly or return NotImplemented and let Python raise it.
        with self.assertRaises(TypeError):
            _ = dm1 + 123
        if sys.version_info >= (3, 9):
            with self.assertRaises(TypeError):
                _ = dm1 | 123

    def test_copy_operations(self):
        """``__copy__`` shares values; ``__deepcopy__`` returns a new manager."""
        obj = object()
        dm1 = DictManager({'a': 1, 'b': obj})

        # Shallow copy
        dm2 = dm1.__copy__()
        self.assertIsNot(dm2, dm1)
        self.assertEqual(dict(dm2), dict(dm1))
        self.assertIs(dm2['b'], obj)  # Same object reference

        # Deep copy
        dm3 = dm1.__deepcopy__({})
        self.assertIsNot(dm3, dm1)
        self.assertEqual(dm3['a'], 1)
        # NOTE: only the int entry is checked here. A bare ``object()`` *can*
        # be deep-copied (``copy.deepcopy(object())`` yields a new object);
        # whether dm3['b'] is a fresh object depends on
        # DictManager.__deepcopy__, which this test does not pin down.

    def test_jax_pytree(self):
        """DictManager round-trips through its pytree flatten/unflatten hooks."""
        dm = DictManager({'a': 1, 'b': 2, 'c': 3})

        # Flatten
        values, keys = dm.tree_flatten()
        self.assertEqual(len(values), 3)
        self.assertEqual(len(keys), 3)

        # Unflatten
        dm2 = DictManager.tree_unflatten(keys, values)
        self.assertEqual(dict(dm2), dict(dm))

        # Test with JAX tree operations
        dm3 = DictManager({'x': jnp.array([1, 2]), 'y': jnp.array([3, 4])})
        doubled = jax.tree_util.tree_map(lambda x: x * 2, dm3)
        self.assertTrue(jnp.allclose(doubled['x'], jnp.array([2, 4])))
        self.assertTrue(jnp.allclose(doubled['y'], jnp.array([6, 8])))

    def test_repr(self):
        """``repr`` names the class and shows the stored items."""
        dm = DictManager({'a': 1, 'b': 'text'})
        repr_str = repr(dm)
        self.assertIn('DictManager', repr_str)
        self.assertIn("'a': 1", repr_str)
        self.assertIn("'b': 'text'", repr_str)
512
+
513
+
514
+ class TestDotDict(unittest.TestCase):
515
+ """Test cases for DotDict class."""
516
+
517
+ def test_initialization(self):
518
+ """Test DotDict initialization."""
519
+ # From dict
520
+ dd1 = DotDict({'a': 1, 'b': 2})
521
+ self.assertEqual(dd1.a, 1)
522
+ self.assertEqual(dd1['b'], 2)
523
+
524
+ # From kwargs
525
+ dd2 = DotDict(a=1, b=2)
526
+ self.assertEqual(dd2.a, 1)
527
+ self.assertEqual(dd2.b, 2)
528
+
529
+ # From tuple
530
+ dd3 = DotDict(('key', 'value'))
531
+ self.assertEqual(dd3.key, 'value')
532
+
533
+ # From items
534
+ dd4 = DotDict([('a', 1), ('b', 2)])
535
+ self.assertEqual(dd4.a, 1)
536
+ self.assertEqual(dd4.b, 2)
537
+
538
+ # Empty
539
+ dd5 = DotDict()
540
+ self.assertEqual(len(dd5), 0)
541
+
542
+ # Invalid argument
543
+ with self.assertRaises(TypeError):
544
+ DotDict(123)
545
+
546
+ def test_dot_access(self):
547
+ """Test dot notation access."""
548
+ dd = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})
549
+
550
+ # Read access
551
+ self.assertEqual(dd.a, 1)
552
+ self.assertEqual(dd.b.c, 2)
553
+ self.assertEqual(dd.b.d, 3)
554
+
555
+ # Write access
556
+ dd.a = 10
557
+ dd.b.c = 20
558
+ self.assertEqual(dd['a'], 10)
559
+ self.assertEqual(dd['b']['c'], 20)
560
+
561
+ # Add new attributes
562
+ dd.e = 5
563
+ self.assertEqual(dd['e'], 5)
564
+
565
+ def test_nested_dict_conversion(self):
566
+ """Test automatic nested dict conversion."""
567
+ dd = DotDict({
568
+ 'level1': {
569
+ 'level2': {
570
+ 'level3': 'value'
571
+ }
572
+ }
573
+ })
574
+
575
+ self.assertIsInstance(dd.level1, DotDict)
576
+ self.assertIsInstance(dd.level1.level2, DotDict)
577
+ self.assertEqual(dd.level1.level2.level3, 'value')
578
+
579
+ def test_list_tuple_conversion(self):
580
+ """Test conversion of lists and tuples containing dicts."""
581
+ dd = DotDict({
582
+ 'list': [{'a': 1}, {'b': 2}],
583
+ 'tuple': ({'c': 3}, {'d': 4})
584
+ })
585
+
586
+ # List items should be DotDict
587
+ self.assertIsInstance(dd.list[0], DotDict)
588
+ self.assertEqual(dd.list[0].a, 1)
589
+ self.assertIsInstance(dd.list[1], DotDict)
590
+ self.assertEqual(dd.list[1].b, 2)
591
+
592
+ # Tuple items should be DotDict
593
+ self.assertIsInstance(dd.tuple[0], DotDict)
594
+ self.assertEqual(dd.tuple[0].c, 3)
595
+
596
+ def test_attribute_errors(self):
597
+ """Test attribute error handling."""
598
+ dd = DotDict({'a': 1})
599
+
600
+ # Non-existent attribute
601
+ with self.assertRaises(AttributeError) as ctx:
602
+ _ = dd.nonexistent
603
+ self.assertIn("has no attribute 'nonexistent'", str(ctx.exception))
604
+
605
+ # Delete non-existent attribute
606
+ with self.assertRaises(AttributeError):
607
+ del dd.nonexistent
608
+
609
+ # Try to set built-in method
610
+ with self.assertRaises(AttributeError) as ctx:
611
+ dd.keys = 'value'
612
+ self.assertIn("built-in method", str(ctx.exception))
613
+
614
+ def test_dir_method(self):
615
+ """Test __dir__ method."""
616
+ dd = DotDict({'a': 1, 'b': 2})
617
+ attrs = dir(dd)
618
+
619
+ self.assertIn('a', attrs)
620
+ self.assertIn('b', attrs)
621
+ self.assertIn('keys', attrs) # Built-in method
622
+ self.assertIn('values', attrs) # Built-in method
623
+
624
+ def test_get_method(self):
625
+ """Test get method with default."""
626
+ dd = DotDict({'a': 1})
627
+
628
+ self.assertEqual(dd.get('a'), 1)
629
+ self.assertEqual(dd.get('b'), None)
630
+ self.assertEqual(dd.get('b', 'default'), 'default')
631
+
632
+ def test_copy_operations(self):
633
+ """Test copy operations."""
634
+ dd1 = DotDict({'a': 1, 'b': {'c': 2}})
635
+
636
+ # Shallow copy
637
+ dd2 = dd1.copy()
638
+ self.assertIsNot(dd2, dd1)
639
+ self.assertEqual(dd2.a, 1)
640
+ dd2.a = 10
641
+ self.assertEqual(dd1.a, 1) # Original unchanged
642
+
643
+ # Deep copy
644
+ dd3 = dd1.deepcopy()
645
+ self.assertIsNot(dd3, dd1)
646
+ self.assertIsNot(dd3.b, dd1.b)
647
+ dd3.b.c = 20
648
+ self.assertEqual(dd1.b.c, 2) # Original unchanged
649
+
650
+ def test_to_dict_from_dict(self):
651
+ """Test conversion to/from standard dict."""
652
+ dd1 = DotDict({
653
+ 'a': 1,
654
+ 'b': {'c': 2, 'd': {'e': 3}},
655
+ 'list': [{'f': 4}]
656
+ })
657
+
658
+ # Convert to dict
659
+ d = dd1.to_dict()
660
+ self.assertIsInstance(d, dict)
661
+ self.assertNotIsInstance(d, DotDict)
662
+ self.assertIsInstance(d['b'], dict)
663
+ self.assertNotIsInstance(d['b'], DotDict)
664
+ self.assertEqual(d['b']['d']['e'], 3)
665
+
666
+ # Convert from dict
667
+ dd2 = DotDict.from_dict(d)
668
+ self.assertIsInstance(dd2, DotDict)
669
+ self.assertIsInstance(dd2.b, DotDict)
670
+ self.assertEqual(dd2.b.d.e, 3)
671
+
672
+ def test_update_method(self):
673
+ """Test update with recursive merge."""
674
+ dd = DotDict({'a': 1, 'b': {'c': 2, 'd': 3}})
675
+
676
+ # Simple update
677
+ dd.update({'a': 10})
678
+ self.assertEqual(dd.a, 10)
679
+
680
+ # Recursive merge
681
+ dd.update({'b': {'d': 30, 'e': 4}})
682
+ self.assertEqual(dd.b.c, 2) # Preserved
683
+ self.assertEqual(dd.b.d, 30) # Updated
684
+ self.assertEqual(dd.b.e, 4) # Added
685
+
686
+ # Update with kwargs
687
+ dd.update(f=5, g=6)
688
+ self.assertEqual(dd.f, 5)
689
+ self.assertEqual(dd.g, 6)
690
+
691
+ # Multiple arguments error
692
+ with self.assertRaises(TypeError):
693
+ dd.update({}, {})
694
+
695
+ def test_setdefault(self):
696
+ """Test setdefault method."""
697
+ dd = DotDict({'a': 1})
698
+
699
+ # Existing key
700
+ result = dd.setdefault('a', 10)
701
+ self.assertEqual(result, 1)
702
+ self.assertEqual(dd.a, 1)
703
+
704
+ # New key
705
+ result = dd.setdefault('b', 2)
706
+ self.assertEqual(result, 2)
707
+ self.assertEqual(dd.b, 2)
708
+
709
+ # New key with None default
710
+ result = dd.setdefault('c')
711
+ self.assertIsNone(result)
712
+ self.assertIsNone(dd.c)
713
+
714
+ def test_pickling(self):
715
+ """Test pickling/unpickling."""
716
+ dd1 = DotDict({'a': 1, 'b': {'c': 2}})
717
+
718
+ # Pickle and unpickle
719
+ pickled = pickle.dumps(dd1)
720
+ dd2 = pickle.loads(pickled)
721
+
722
+ self.assertIsNot(dd2, dd1)
723
+ self.assertEqual(dd2.a, 1)
724
+ self.assertEqual(dd2.b.c, 2)
725
+ self.assertIsInstance(dd2, DotDict)
726
+ self.assertIsInstance(dd2.b, DotDict)
727
+
728
def test_jax_pytree(self):
    """``DotDict`` works as a JAX pytree: ``tree_map`` reaches every leaf."""
    leaves = DotDict({'a': jnp.array([1, 2]), 'b': jnp.array([3, 4])})
    doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, leaves)

    # Each mapped leaf stays reachable by attribute on the result.
    expected = (('a', jnp.array([2, 4])), ('b', jnp.array([6, 8])))
    for key, want in expected:
        self.assertTrue(jnp.allclose(getattr(doubled, key), want))
736
+
737
def test_repr(self):
    """``repr`` mentions the class name and every key/value pair."""
    text = repr(DotDict({'a': 1, 'b': 'text'}))
    for fragment in ('DotDict', "'a': 1", "'b': 'text'"):
        self.assertIn(fragment, text)
744
+
745
+
746
class TestUtilityFunctions(unittest.TestCase):
    """Test cases for the dict helper functions.

    Covers ``merge_dicts``, ``flatten_dict`` / ``unflatten_dict`` and the
    ``is_instance_eval`` / ``not_instance_eval`` predicate factories.
    """

    def test_merge_dicts_basic(self):
        """Later dicts win on key conflicts and no input dict is mutated."""
        d1 = {'a': 1, 'b': 2}
        d2 = {'b': 3, 'c': 4}
        d3 = {'d': 5}

        result = merge_dicts(d1, d2, d3)
        self.assertEqual(result, {'a': 1, 'b': 3, 'c': 4, 'd': 5})

        # Original dicts should be unchanged. Check every input, not just the
        # first: an implementation could also write into a later argument.
        self.assertEqual(d1, {'a': 1, 'b': 2})
        self.assertEqual(d2, {'b': 3, 'c': 4})
        self.assertEqual(d3, {'d': 5})

    def test_merge_dicts_recursive(self):
        """With ``recursive=True`` nested dicts are merged key by key."""
        d1 = {'a': 1, 'b': {'c': 2, 'd': 3}}
        d2 = {'b': {'d': 4, 'e': 5}, 'f': 6}

        result = merge_dicts(d1, d2, recursive=True)
        self.assertEqual(result, {
            'a': 1,
            'b': {'c': 2, 'd': 4, 'e': 5},
            'f': 6
        })

    def test_merge_dicts_non_recursive(self):
        """With ``recursive=False`` a later nested dict replaces the earlier one."""
        d1 = {'a': 1, 'b': {'c': 2}}
        d2 = {'b': {'d': 3}}

        result = merge_dicts(d1, d2, recursive=False)
        self.assertEqual(result, {'a': 1, 'b': {'d': 3}})

    def test_merge_dicts_errors(self):
        """A non-dict argument raises ``TypeError``."""
        with self.assertRaises(TypeError):
            merge_dicts({'a': 1}, [1, 2, 3])

    def test_flatten_dict(self):
        """Nested keys are joined with the separator ('.' unless overridden)."""
        nested = {
            'a': 1,
            'b': {
                'c': 2,
                'd': {
                    'e': 3,
                    'f': 4
                }
            },
            'g': 5
        }

        flat = flatten_dict(nested)
        self.assertEqual(flat, {
            'a': 1,
            'b.c': 2,
            'b.d.e': 3,
            'b.d.f': 4,
            'g': 5
        })

        # Custom separator
        flat_dash = flatten_dict(nested, sep='-')
        self.assertEqual(flat_dash, {
            'a': 1,
            'b-c': 2,
            'b-d-e': 3,
            'b-d-f': 4,
            'g': 5
        })

    def test_unflatten_dict(self):
        """Separator-joined keys are split back into nested dicts."""
        flat = {
            'a': 1,
            'b.c': 2,
            'b.d.e': 3,
            'b.d.f': 4,
            'g': 5
        }

        nested = unflatten_dict(flat)
        self.assertEqual(nested, {
            'a': 1,
            'b': {
                'c': 2,
                'd': {
                    'e': 3,
                    'f': 4
                }
            },
            'g': 5
        })

        # Custom separator
        flat_dash = {'a': 1, 'b-c': 2, 'b-d': 3}
        nested_dash = unflatten_dict(flat_dash, sep='-')
        self.assertEqual(nested_dash, {
            'a': 1,
            'b': {'c': 2, 'd': 3}
        })

    def test_flatten_unflatten_roundtrip(self):
        """Test that flatten/unflatten is reversible."""
        original = {
            'level1': {
                'level2': {
                    'level3': 'value',
                    'another': 42
                },
                'sibling': 'data'
            },
            'root': 'element'
        }

        flattened = flatten_dict(original)
        unflattened = unflatten_dict(flattened)
        self.assertEqual(unflattened, original)

    def test_is_instance_eval(self):
        """``is_instance_eval`` builds an isinstance predicate over the given types."""
        # Single type
        is_int = is_instance_eval(int)
        self.assertTrue(is_int(5))
        self.assertFalse(is_int("5"))
        self.assertFalse(is_int(5.0))

        # Multiple types
        is_number = is_instance_eval(int, float)
        self.assertTrue(is_number(5))
        self.assertTrue(is_number(5.0))
        self.assertFalse(is_number("5"))

        # With subclasses
        class MyList(list):
            pass

        is_list = is_instance_eval(list)
        self.assertTrue(is_list([1, 2, 3]))
        self.assertTrue(is_list(MyList([1, 2, 3])))
        self.assertFalse(is_list((1, 2, 3)))

    def test_not_instance_eval(self):
        """``not_instance_eval`` builds the negated predicate."""
        # Single type
        not_int = not_instance_eval(int)
        self.assertFalse(not_int(5))
        self.assertTrue(not_int("5"))
        self.assertTrue(not_int(5.0))

        # Multiple types
        not_number = not_instance_eval(int, float)
        self.assertFalse(not_number(5))
        self.assertFalse(not_number(5.0))
        self.assertTrue(not_number("5"))
        self.assertTrue(not_number([1, 2, 3]))
904
+
905
+
906
class TestJaxIntegration(unittest.TestCase):
    """Test JAX pytree integration for DictManager and DotDict."""

    def test_dictmanager_pytree_operations(self):
        """``tree_map`` and ``tree_reduce`` traverse a DictManager's leaves."""
        dm = DictManager({
            'weights': jnp.array([1.0, 2.0, 3.0]),
            'bias': jnp.array([0.1, 0.2])
        })

        # Tree map
        scaled = jax.tree_util.tree_map(lambda x: x * 2, dm)
        self.assertTrue(jnp.allclose(scaled['weights'], jnp.array([2.0, 4.0, 6.0])))
        self.assertTrue(jnp.allclose(scaled['bias'], jnp.array([0.2, 0.4])))

        # Tree reduce: sum of every element across all leaves.
        total = jax.tree_util.tree_reduce(lambda x, y: x + y.sum(), dm, 0.0)
        # The reduction yields a 0-d JAX array; convert it to a Python float so
        # assertAlmostEqual does plain float arithmetic (float32 rounding is
        # well inside places=5).
        self.assertAlmostEqual(float(total), 6.3, places=5)

    def test_dotdict_pytree_operations(self):
        """``tree_map`` reaches leaves of a nested DotDict, including non-array ones."""
        dd = DotDict({
            'model': {
                'weights': jnp.array([[1.0, 2.0], [3.0, 4.0]]),
                'bias': jnp.array([0.1, 0.2])
            },
            'optimizer': {
                'lr': 0.01
            }
        })

        # Tree map on nested structure: scale arrays, pass anything else through.
        def scale_arrays(x):
            return x * 2 if isinstance(x, jnp.ndarray) else x

        scaled = jax.tree_util.tree_map(scale_arrays, dd)
        self.assertTrue(jnp.allclose(
            scaled.model.weights,
            jnp.array([[2.0, 4.0], [6.0, 8.0]])
        ))
        # Every array leaf must be scaled, not just the first one.
        self.assertTrue(jnp.allclose(
            scaled.model.bias,
            jnp.array([0.2, 0.4])
        ))
        self.assertEqual(scaled.optimizer.lr, 0.01)  # Non-array unchanged

    def test_mixed_pytree_structures(self):
        """Test mixing DictManager and DotDict in one pytree."""
        # Both container types sit side by side in a plain dict, so a single
        # tree_map has to descend through each registered node type.
        structure = {
            'dict_manager': DictManager({'a': jnp.array([1, 2])}),
            'dot_dict': DotDict({'b': jnp.array([3, 4])}),
        }

        doubled = jax.tree_util.tree_map(lambda x: x * 2, structure)
        self.assertTrue(jnp.allclose(
            doubled['dict_manager']['a'],
            jnp.array([2, 4])
        ))
        self.assertTrue(jnp.allclose(
            doubled['dot_dict'].b,
            jnp.array([6, 8])
        ))
959
+
960
+
961
if __name__ == '__main__':
    # Executing this file directly runs every TestCase defined in it.
    unittest.main()