brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,912 @@
1
+ # Copyright 2024 BrainState Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Comprehensive tests for filter module.
17
+ """
18
+
19
+ import unittest
20
+ from typing import Any
21
+ import numpy as np
22
+
23
+ from brainstate.util.filter import (
24
+ to_predicate,
25
+ WithTag,
26
+ PathContains,
27
+ OfType,
28
+ Any,
29
+ All,
30
+ Not,
31
+ Everything,
32
+ Nothing,
33
+ )
34
+
35
+
36
class MockTaggedObject:
    """Minimal stand-in for a value that carries a ``tag`` attribute.

    Lets the tag-based filter tests run without constructing real state
    objects.
    """

    def __init__(self, tag: str) -> None:
        # Exposed as a plain public attribute so tag-based filters can read it.
        self.tag = tag
40
+
41
+
42
class MockTypedObject:
    """Minimal stand-in for a value that advertises a type via an attribute.

    The object itself is not an instance of the advertised type; it only
    stores that type under the attribute name ``type``.
    """

    def __init__(self, type_value: type) -> None:
        # Stored as ``type`` (not ``type_value``): that is the attribute name
        # the OfType tests inspect.
        self.type = type_value
46
+
47
+
48
class TestToPredicateFunction(unittest.TestCase):
    """Checks how ``to_predicate`` normalizes each supported filter spelling."""

    def test_string_to_withtag(self):
        """A plain string becomes a ``WithTag`` filter on that tag."""
        predicate = to_predicate('trainable')
        self.assertIsInstance(predicate, WithTag)
        self.assertEqual(predicate.tag, 'trainable')

        # The produced filter must actually discriminate on the tag value.
        self.assertTrue(predicate([], MockTaggedObject('trainable')))
        self.assertFalse(predicate([], MockTaggedObject('frozen')))

    def test_type_to_oftype(self):
        """A class object becomes an ``OfType`` filter on that class."""
        predicate = to_predicate(np.ndarray)
        self.assertIsInstance(predicate, OfType)
        self.assertEqual(predicate.type, np.ndarray)

        self.assertTrue(predicate([], np.array([1, 2, 3])))
        self.assertFalse(predicate([], [1, 2, 3]))

    def test_bool_true_to_everything(self):
        """``True`` becomes the match-all ``Everything`` filter."""
        predicate = to_predicate(True)
        self.assertIsInstance(predicate, Everything)

        for path, value in (([], 'anything'), (['path'], None), ([], 42)):
            self.assertTrue(predicate(path, value))

    def test_bool_false_to_nothing(self):
        """``False`` becomes the match-none ``Nothing`` filter."""
        predicate = to_predicate(False)
        self.assertIsInstance(predicate, Nothing)

        for path, value in (([], 'anything'), (['path'], None), ([], 42)):
            self.assertFalse(predicate(path, value))

    def test_ellipsis_to_everything(self):
        """``...`` is an alias for ``Everything``."""
        predicate = to_predicate(...)
        self.assertIsInstance(predicate, Everything)
        self.assertTrue(predicate([], 'test'))

    def test_none_to_nothing(self):
        """``None`` is an alias for ``Nothing``."""
        predicate = to_predicate(None)
        self.assertIsInstance(predicate, Nothing)
        self.assertFalse(predicate([], 'test'))

    def test_callable_passthrough(self):
        """Any other callable is handed back unchanged (the very same object)."""
        def is_special(path, x):
            return x == 'special'

        predicate = to_predicate(is_special)
        self.assertIs(predicate, is_special)
        self.assertTrue(predicate([], 'special'))
        self.assertFalse(predicate([], 'normal'))

    def test_list_to_any(self):
        """A list of specs becomes an ``Any`` over the converted elements."""
        predicate = to_predicate(['trainable', 'frozen'])
        self.assertIsInstance(predicate, Any)

        for tag in ('trainable', 'frozen'):
            self.assertTrue(predicate([], MockTaggedObject(tag)))
        self.assertFalse(predicate([], MockTaggedObject('other')))

    def test_tuple_to_any(self):
        """A tuple of specs is treated the same way as a list."""
        predicate = to_predicate((np.ndarray, list))
        self.assertIsInstance(predicate, Any)

        self.assertTrue(predicate([], np.array([1, 2])))
        self.assertTrue(predicate([], [1, 2]))
        self.assertFalse(predicate([], (1, 2)))

    def test_invalid_type_raises_error(self):
        """Unsupported specs are rejected with a ``TypeError``."""
        with self.assertRaises(TypeError) as ctx:
            to_predicate(42)
        self.assertIn('Invalid collection filter', str(ctx.exception))

        with self.assertRaises(TypeError):
            to_predicate({'key': 'value'})
149
+
150
+
151
class TestWithTagFilter(unittest.TestCase):
    """Behaviour of the ``WithTag`` filter."""

    def test_basic_functionality(self):
        """Only objects whose ``tag`` equals the filter's tag are matched."""
        only_trainable = WithTag('trainable')

        self.assertTrue(only_trainable([], MockTaggedObject('trainable')))
        self.assertFalse(only_trainable([], MockTaggedObject('frozen')))
        # A dict has no ``tag`` attribute at all; this must be a plain
        # non-match rather than an error.
        self.assertFalse(only_trainable([], {'value': 42}))

    def test_repr(self):
        """The repr shows the class name and the quoted tag."""
        self.assertEqual(repr(WithTag('test_tag')), "WithTag('test_tag')")

    def test_immutability(self):
        """Reassigning ``tag`` on an existing filter (frozen dataclass) fails."""
        tag_filter = WithTag('test')
        with self.assertRaises(AttributeError):
            tag_filter.tag = 'modified'

    def test_path_parameter_ignored(self):
        """The path argument has no influence on the outcome."""
        tag_filter = WithTag('test')
        tagged = MockTaggedObject('test')

        for path in ([], ['some', 'path'], ['another', 'nested', 'path']):
            self.assertTrue(tag_filter(path, tagged))
190
+
191
+
192
class TestPathContainsFilter(unittest.TestCase):
    """Behaviour of the ``PathContains`` filter."""

    def test_basic_functionality(self):
        """A path matches exactly when it contains the key somewhere."""
        weight_filter = PathContains('weight')

        matching_paths = (
            ['model', 'layer1', 'weight'],
            ['weight'],
            ['deep', 'nested', 'weight', 'param'],
        )
        for path in matching_paths:
            self.assertTrue(weight_filter(path, None))

        non_matching_paths = (
            ['model', 'layer1', 'bias'],
            [],
            ['other', 'path'],
        )
        for path in non_matching_paths:
            self.assertFalse(weight_filter(path, None))

    def test_numeric_keys(self):
        """Integer keys (e.g. list indices) are matched as well."""
        zero_filter = PathContains(0)

        self.assertTrue(zero_filter([0, 'item'], None))
        self.assertTrue(zero_filter(['list', 0, 'element'], None))
        self.assertFalse(zero_filter([1, 2, 3], None))

    def test_repr(self):
        """The repr shows the class name and the quoted key."""
        self.assertEqual(repr(PathContains('layer2')), "PathContains('layer2')")

    def test_object_parameter_ignored(self):
        """The object argument has no influence once the path matches."""
        path_filter = PathContains('test')

        for value in ('string', 123, None, {'dict': 'value'}):
            self.assertTrue(path_filter(['test'], value))
231
+
232
+
233
class TestOfTypeFilter(unittest.TestCase):
    """Behaviour of the ``OfType`` filter."""

    def test_direct_instance_check(self):
        """Genuine instances match; unrelated values do not."""
        array_filter = OfType(np.ndarray)

        self.assertTrue(array_filter([], np.array([1, 2, 3])))
        self.assertTrue(array_filter([], np.zeros((2, 2))))
        self.assertFalse(array_filter([], [1, 2, 3]))
        self.assertFalse(array_filter([], 42))

    def test_inheritance(self):
        """Instances of subclasses are accepted too."""
        class Base:
            pass

        class Derived(Base):
            pass

        base_filter = OfType(Base)

        self.assertTrue(base_filter([], Base()))
        self.assertTrue(base_filter([], Derived()))
        self.assertFalse(base_filter([], "not related"))

    def test_type_attribute_check(self):
        """Objects advertising a type through a ``type`` attribute are honoured."""
        list_filter = OfType(list)

        # MockTypedObject is not itself a list; it only stores ``list`` in
        # its ``type`` attribute.
        self.assertTrue(list_filter([], MockTypedObject(list)))
        self.assertFalse(list_filter([], MockTypedObject(dict)))

    def test_repr(self):
        """The repr embeds the repr of the wrapped type."""
        self.assertEqual(repr(OfType(str)), f"OfType({str!r})")

    def test_builtin_types(self):
        """Built-in types work and do not cross-match."""
        str_filter, int_filter, list_filter = OfType(str), OfType(int), OfType(list)

        self.assertTrue(str_filter([], "hello"))
        self.assertTrue(int_filter([], 42))
        self.assertTrue(list_filter([], [1, 2, 3]))

        self.assertFalse(str_filter([], 42))
        self.assertFalse(int_filter([], "42"))
        self.assertFalse(list_filter([], (1, 2, 3)))
293
+
294
+
295
class TestAnyFilter(unittest.TestCase):
    """Behaviour of the ``Any`` (logical OR) combinator."""

    def test_basic_or_operation(self):
        """An object matching any one member is accepted."""
        either_tag = Any('trainable', 'frozen')

        self.assertTrue(either_tag([], MockTaggedObject('trainable')))
        self.assertTrue(either_tag([], MockTaggedObject('frozen')))
        self.assertFalse(either_tag([], MockTaggedObject('other')))

    def test_mixed_filter_types(self):
        """Type, tag and path filters can be combined freely."""
        mixed = Any(OfType(np.ndarray), WithTag('special'), PathContains('important'))

        # Each member filter is sufficient on its own...
        self.assertTrue(mixed([], np.array([1, 2])))
        self.assertTrue(mixed([], MockTaggedObject('special')))
        self.assertTrue(mixed(['important'], 'anything'))

        # ...and the result is False only when none of them match.
        self.assertFalse(mixed(['other'], MockTaggedObject('normal')))

    def test_short_circuit_evaluation(self):
        """Evaluation stops at the first member that returns True."""
        calls = []

        def recording_filter(path, x):
            calls.append(x)
            return x == 'match'

        first_wins = Any(lambda p, x: x == 'match', recording_filter)

        self.assertTrue(first_wins([], 'match'))
        # The second member must never have been consulted.
        self.assertEqual(len(calls), 0)

    def test_empty_any(self):
        """With no members there is nothing that could match."""
        self.assertFalse(Any()([], 'anything'))

    def test_repr(self):
        """The repr names the combinator and each member."""
        text = repr(Any(WithTag('tag1'), WithTag('tag2')))
        for fragment in ('Any', "WithTag('tag1')", "WithTag('tag2')"):
            self.assertIn(fragment, text)

    def test_equality(self):
        """Equality depends on the members and their order."""
        reference = Any('tag1', 'tag2')

        self.assertEqual(reference, Any('tag1', 'tag2'))
        self.assertNotEqual(reference, Any('tag2', 'tag1'))
        self.assertNotEqual(reference, 'not a filter')

    def test_hashable(self):
        """Equal ``Any`` filters collapse to a single set entry."""
        first, second = Any('tag1', 'tag2'), Any('tag1', 'tag2')
        self.assertEqual(len({first, second}), 1)
374
+
375
+
376
class TestAllFilter(unittest.TestCase):
    """Test cases for All filter (logical AND of member filters)."""

    def test_basic_and_operation(self):
        """Test basic AND operation: every member filter must accept the object."""
        filter_all = All(
            WithTag('trainable'),
            OfType(np.ndarray)
        )

        # A plain ndarray rejects new attributes, but a Python-level subclass
        # accepts them, so it can carry both a tag and the ndarray type.
        class TaggedArray(np.ndarray):
            def __new__(cls, input_array, tag=None):
                obj = np.asarray(input_array).view(cls)
                obj.tag = tag
                return obj

        arr_obj = TaggedArray([1, 2, 3], tag='trainable')
        self.assertTrue(filter_all([], arr_obj))  # Right tag and right type

        arr_obj2 = TaggedArray([4, 5, 6], tag='frozen')
        self.assertFalse(filter_all([], arr_obj2))  # Right type, wrong tag

        # Right tag but not an ndarray. This is a plain object (not a list),
        # so only the OfType member can be responsible for rejecting it.
        class PlainTaggedObject:
            def __init__(self, tag):
                self.tag = tag

        plain_obj = PlainTaggedObject('trainable')
        self.assertFalse(filter_all([], plain_obj))  # Wrong type

    def test_short_circuit_evaluation(self):
        """Test that All stops at the first member returning False."""
        call_count = [0]

        def counting_filter(path, x):
            call_count[0] += 1
            return True

        filter_all = All(
            lambda p, x: False,  # This will fail
            counting_filter  # This should not be called
        )

        self.assertFalse(filter_all([], 'anything'))
        self.assertEqual(call_count[0], 0)  # Second filter not called

    def test_empty_all(self):
        """Test All with no filters."""
        filter_empty = All()
        # Empty All should return True (no conditions to violate)
        self.assertTrue(filter_empty([], 'anything'))

    def test_complex_combination(self):
        """Test a tag filter combined with two range predicates."""
        class CustomObject:
            def __init__(self, tag, value):
                self.tag = tag
                self.value = value

        filter_complex = All(
            WithTag('important'),
            lambda p, x: hasattr(x, 'value') and x.value > 10,
            lambda p, x: hasattr(x, 'value') and x.value < 100
        )

        obj1 = CustomObject('important', 50)
        obj2 = CustomObject('important', 5)
        obj3 = CustomObject('important', 150)
        obj4 = CustomObject('other', 50)

        self.assertTrue(filter_complex([], obj1))  # All conditions met
        self.assertFalse(filter_complex([], obj2))  # value too small
        self.assertFalse(filter_complex([], obj3))  # value too large
        self.assertFalse(filter_complex([], obj4))  # wrong tag

    def test_repr(self):
        """Test string representation names the combinator and its members."""
        filter_all = All(WithTag('tag1'), OfType(list))
        repr_str = repr(filter_all)
        self.assertIn('All', repr_str)
        self.assertIn("WithTag('tag1')", repr_str)
        self.assertIn('OfType', repr_str)

    def test_equality(self):
        """Test equality comparison (member order is significant)."""
        filter1 = All('tag1', np.ndarray)
        filter2 = All('tag1', np.ndarray)
        filter3 = All(np.ndarray, 'tag1')  # Different order

        self.assertEqual(filter1, filter2)
        self.assertNotEqual(filter1, filter3)

    def test_hashable(self):
        """Test that equal All filters collapse to one dict key."""
        filter1 = All('tag1', np.ndarray)
        filter2 = All('tag1', np.ndarray)

        filter_dict = {filter1: 'value1', filter2: 'value2'}
        self.assertEqual(len(filter_dict), 1)  # Same filters
477
+
478
+
479
class TestNotFilter(unittest.TestCase):
    """Behaviour of the ``Not`` (negation) combinator."""

    def test_basic_negation(self):
        """Negating a tag filter flips its verdict."""
        not_trainable = Not(WithTag('trainable'))

        self.assertFalse(not_trainable([], MockTaggedObject('trainable')))
        self.assertTrue(not_trainable([], MockTaggedObject('frozen')))

    def test_negating_type_filter(self):
        """Negating a type filter accepts everything of another type."""
        not_array = Not(OfType(np.ndarray))

        self.assertFalse(not_array([], np.array([1, 2])))
        for value in ([1, 2], 'string'):
            self.assertTrue(not_array([], value))

    def test_negating_complex_filters(self):
        """Negation composes with ``Any`` and ``All``."""
        # Not(Any(...)): true only when no member matches.
        neither_tag = Not(Any('tag1', 'tag2'))
        for tag in ('tag1', 'tag2'):
            self.assertFalse(neither_tag([], MockTaggedObject(tag)))
        self.assertTrue(neither_tag([], MockTaggedObject('tag3')))

        # Not(All(...)): true as soon as at least one member fails.
        not_tagged_list = Not(All(WithTag('tag'), OfType(list)))

        class TaggedList(list):
            def __init__(self, *args):
                super().__init__(*args)
                self.tag = 'tag'

        # Satisfies every member of the All, so the negation is False.
        self.assertFalse(not_tagged_list([], TaggedList([1, 2, 3])))

        class TaggedArray(np.ndarray):
            def __new__(cls, input_array, tag=None):
                obj = np.asarray(input_array).view(cls)
                obj.tag = tag
                return obj

        # Carries the tag but is not a list, so All fails and Not succeeds.
        self.assertTrue(not_tagged_list([], TaggedArray([], tag='tag')))

    def test_double_negation(self):
        """``Not(Not(f))`` agrees with ``f`` on every input tried."""
        original = WithTag('test')
        double_neg = Not(Not(original))

        for obj in (MockTaggedObject('test'), MockTaggedObject('other')):
            self.assertEqual(original([], obj), double_neg([], obj))

    def test_repr(self):
        """The repr wraps the inner filter's repr."""
        self.assertEqual(repr(Not(WithTag('test'))), "Not(WithTag('test'))")

    def test_equality(self):
        """Equality follows the wrapped filter."""
        reference = Not(WithTag('test'))

        self.assertEqual(reference, Not(WithTag('test')))
        self.assertNotEqual(reference, Not(WithTag('other')))

    def test_hashable(self):
        """Equal ``Not`` filters collapse to a single set entry."""
        self.assertEqual(len({Not(WithTag('test')), Not(WithTag('test'))}), 1)
574
+
575
+
576
class TestEverythingFilter(unittest.TestCase):
    """Test cases for Everything filter."""

    def test_always_returns_true(self):
        """Test that Everything returns True whatever the path or object."""
        filter_all = Everything()

        # Test with various objects
        for obj in (None, 42, 'string', [1, 2, 3], np.array([1, 2]), {'key': 'value'}):
            self.assertTrue(filter_all([], obj))

        # Test with various paths. The empty path is already covered above, so
        # exercise a mixed int/str path instead of repeating ([], None).
        for path in (['path'], ['nested', 'path'], [0, 'mixed']):
            self.assertTrue(filter_all(path, None))

    def test_repr(self):
        """Test string representation."""
        filter_all = Everything()
        self.assertEqual(repr(filter_all), 'Everything()')

    def test_equality(self):
        """Test equality comparison."""
        filter1 = Everything()
        filter2 = Everything()
        filter3 = Nothing()

        self.assertEqual(filter1, filter2)
        self.assertNotEqual(filter1, filter3)

    def test_hashable(self):
        """Test that Everything filters are hashable."""
        filter1 = Everything()
        filter2 = Everything()

        # All Everything instances should be equal and have the same hash
        self.assertEqual(hash(filter1), hash(filter2))

        filter_set = {filter1, filter2}
        self.assertEqual(len(filter_set), 1)

    def test_conversion_from_true(self):
        """Test that True converts to a filter behaving like Everything."""
        filter_from_true = to_predicate(True)
        filter_direct = Everything()

        for obj in [None, 42, 'test', [], {}]:
            # subTest pinpoints which object diverged if the two disagree.
            with self.subTest(obj=obj):
                self.assertEqual(filter_from_true([], obj), filter_direct([], obj))

    def test_conversion_from_ellipsis(self):
        """Test that Ellipsis converts to Everything."""
        filter_from_ellipsis = to_predicate(...)
        filter_direct = Everything()

        self.assertIsInstance(filter_from_ellipsis, Everything)
        self.assertEqual(filter_from_ellipsis, filter_direct)
641
+
642
+
643
class TestNothingFilter(unittest.TestCase):
    """Test cases for Nothing filter."""

    def test_always_returns_false(self):
        """Test that Nothing returns False whatever the path or object."""
        filter_none = Nothing()

        # Test with various objects
        for obj in (None, 42, 'string', [1, 2, 3], np.array([1, 2]), {'key': 'value'}):
            self.assertFalse(filter_none([], obj))

        # Test with various paths. The empty path is already covered above, so
        # exercise a mixed int/str path instead of repeating ([], None).
        for path in (['path'], ['nested', 'path'], [0, 'mixed']):
            self.assertFalse(filter_none(path, None))

    def test_repr(self):
        """Test string representation."""
        filter_none = Nothing()
        self.assertEqual(repr(filter_none), 'Nothing()')

    def test_equality(self):
        """Test equality comparison."""
        filter1 = Nothing()
        filter2 = Nothing()
        filter3 = Everything()

        self.assertEqual(filter1, filter2)
        self.assertNotEqual(filter1, filter3)

    def test_hashable(self):
        """Test that Nothing filters are hashable."""
        filter1 = Nothing()
        filter2 = Nothing()

        # All Nothing instances should be equal and have the same hash
        self.assertEqual(hash(filter1), hash(filter2))

        filter_dict = {filter1: 'value'}
        filter_dict[filter2] = 'new_value'
        self.assertEqual(len(filter_dict), 1)  # Same key

    def test_conversion_from_false(self):
        """Test that False converts to a filter behaving like Nothing."""
        filter_from_false = to_predicate(False)
        filter_direct = Nothing()

        for obj in [None, 42, 'test', [], {}]:
            # subTest pinpoints which object diverged if the two disagree.
            with self.subTest(obj=obj):
                self.assertEqual(filter_from_false([], obj), filter_direct([], obj))

    def test_conversion_from_none(self):
        """Test that None converts to Nothing."""
        filter_from_none = to_predicate(None)
        filter_direct = Nothing()

        self.assertIsInstance(filter_from_none, Nothing)
        self.assertEqual(filter_from_none, filter_direct)
709
+
710
+
711
+ class TestIntegrationScenarios(unittest.TestCase):
712
+ """Integration tests for complex filter combinations."""
713
+
714
+ def test_neural_network_parameter_filtering(self):
715
+ """Test filtering neural network parameters with complex criteria."""
716
+ # Simulate neural network parameters
717
+ class Parameter:
718
+ def __init__(self, shape, tag=None, dtype=None):
719
+ self.shape = shape
720
+ self.tag = tag
721
+ self.dtype = dtype
722
+ self.data = np.random.randn(*shape) if dtype != 'int32' else np.random.randint(0, 10, shape)
723
+
724
+ # Create various parameters
725
+ weight1 = Parameter((10, 20), tag='trainable')
726
+ weight1.data = np.array(weight1.data) # Ensure it's ndarray
727
+
728
+ bias1 = Parameter((20,), tag='trainable')
729
+ bias1.data = np.array(bias1.data)
730
+
731
+ embedding = Parameter((100, 64), tag='frozen')
732
+ embedding.data = np.array(embedding.data)
733
+
734
+ # Complex filter: trainable arrays with shape[0] > 5
735
+ filter_complex = All(
736
+ WithTag('trainable'),
737
+ lambda p, x: hasattr(x, 'data') and isinstance(x.data, np.ndarray),
738
+ lambda p, x: hasattr(x, 'shape') and x.shape[0] > 5
739
+ )
740
+
741
+ self.assertTrue(filter_complex([], weight1)) # Matches all
742
+ self.assertTrue(filter_complex([], bias1)) # shape[0] = 20 > 5
743
+ self.assertFalse(filter_complex([], embedding)) # Wrong tag
744
+
745
+ def test_path_based_model_filtering(self):
746
+ """Test filtering based on model structure paths."""
747
+ # Filter for encoder weights
748
+ encoder_weight_filter = All(
749
+ PathContains('encoder'),
750
+ PathContains('weight')
751
+ )
752
+
753
+ # Test various paths
754
+ paths = [
755
+ (['model', 'encoder', 'layer1', 'weight'], True),
756
+ (['model', 'encoder', 'layer2', 'weight'], True),
757
+ (['model', 'encoder', 'layer1', 'bias'], False), # Not weight
758
+ (['model', 'decoder', 'layer1', 'weight'], False), # Not encoder
759
+ (['encoder', 'attention', 'weight'], True),
760
+ ]
761
+
762
+ for path, expected in paths:
763
+ self.assertEqual(
764
+ encoder_weight_filter(path, None),
765
+ expected,
766
+ f"Failed for path: {path}"
767
+ )
768
+
769
+ def test_selective_gradient_computation(self):
770
+ """Test filter for selective gradient computation."""
771
+ # Only compute gradients for trainable non-embedding layers
772
+ gradient_filter = All(
773
+ WithTag('trainable'),
774
+ Not(PathContains('embedding')),
775
+ Any(
776
+ PathContains('weight'),
777
+ PathContains('bias')
778
+ )
779
+ )
780
+
781
+ # Create test objects
782
+ class Param:
783
+ def __init__(self, tag):
784
+ self.tag = tag
785
+
786
+ trainable_weight = Param('trainable')
787
+ trainable_bias = Param('trainable')
788
+ frozen_weight = Param('frozen')
789
+ trainable_other = Param('trainable')
790
+
791
+ test_cases = [
792
+ (['layer1', 'weight'], trainable_weight, True),
793
+ (['layer1', 'bias'], trainable_bias, True),
794
+ (['embedding', 'weight'], trainable_weight, False), # Excluded path
795
+ (['layer1', 'weight'], frozen_weight, False), # Wrong tag
796
+ (['layer1', 'gamma'], trainable_other, False), # Not weight/bias
797
+ ]
798
+
799
+ for path, obj, expected in test_cases:
800
+ self.assertEqual(
801
+ gradient_filter(path, obj),
802
+ expected,
803
+ f"Failed for path={path}, tag={obj.tag}"
804
+ )
805
+
806
+ def test_demorgan_laws(self):
807
+ """Test De Morgan's laws with filters."""
808
+ tag1_filter = WithTag('tag1')
809
+ tag2_filter = WithTag('tag2')
810
+
811
+ # Not(A or B) == (Not A) and (Not B)
812
+ not_any = Not(Any(tag1_filter, tag2_filter))
813
+ all_not = All(Not(tag1_filter), Not(tag2_filter))
814
+
815
+ test_objects = [
816
+ MockTaggedObject('tag1'),
817
+ MockTaggedObject('tag2'),
818
+ MockTaggedObject('tag3'),
819
+ MockTaggedObject('other'),
820
+ ]
821
+
822
+ for obj in test_objects:
823
+ self.assertEqual(
824
+ not_any([], obj),
825
+ all_not([], obj),
826
+ f"De Morgan's law failed for tag={obj.tag}"
827
+ )
828
+
829
+ # Not(A and B) == (Not A) or (Not B)
830
+ not_all = Not(All(tag1_filter, tag2_filter))
831
+ any_not = Any(Not(tag1_filter), Not(tag2_filter))
832
+
833
+ # For All filter to work, object needs both tags
834
+ class DualTagged:
835
+ def __init__(self, tag1, tag2):
836
+ self.tag = tag1 if tag1 else tag2 # For single tag check
837
+
838
+ # Create object that could match both filters
839
+ dual1 = DualTagged('tag1', None)
840
+ dual2 = DualTagged('tag2', None)
841
+ neither = DualTagged('other', None)
842
+
843
+ for obj in [dual1, dual2, neither]:
844
+ # Since our mock object can only have one tag at a time,
845
+ # All(tag1, tag2) will always be False, so Not(All(...)) will always be True
846
+ # and Any(Not(tag1), Not(tag2)) will depend on the specific tag
847
+ pass # This specific test is limited by our mock implementation
848
+
849
+ def test_filter_chaining_performance(self):
850
+ """Test that filter chaining works correctly."""
851
+ # Create a chain of filters
852
+ base_filter = WithTag('base')
853
+ extended_filter = All(base_filter, OfType(dict))
854
+ final_filter = Any(extended_filter, PathContains('special'))
855
+
856
+ # Test object that matches base but not type
857
+ obj1 = MockTaggedObject('base')
858
+ self.assertFalse(extended_filter([], obj1)) # Not a dict
859
+ self.assertFalse(final_filter([], obj1)) # Doesn't match Any conditions
860
+
861
+ # Test object that matches everything
862
+ class TaggedDict(dict):
863
+ def __init__(self, *args, **kwargs):
864
+ super().__init__(*args, **kwargs)
865
+ self.tag = 'base'
866
+
867
+ obj2 = TaggedDict(key='value')
868
+ self.assertTrue(extended_filter([], obj2)) # Matches both conditions
869
+ self.assertTrue(final_filter([], obj2)) # Matches via extended_filter
870
+
871
+ # Test path-based match
872
+ self.assertTrue(final_filter(['special'], 'anything'))
873
+
874
+ def test_recursive_filter_structures(self):
875
+ """Test deeply nested filter combinations."""
876
+ # Build a complex filter structure
877
+ filter_deep = Any(
878
+ All(
879
+ WithTag('level1'),
880
+ Any(
881
+ All(WithTag('level2a'), OfType(list)),
882
+ All(WithTag('level2b'), OfType(dict))
883
+ )
884
+ ),
885
+ Not(
886
+ All(
887
+ PathContains('excluded'),
888
+ Not(WithTag('override'))
889
+ )
890
+ )
891
+ )
892
+
893
+ # This is a complex filter, let's test a few cases
894
+ # Case 1: Would need an object with tag='level1' and also tag='level2a' and be a list
895
+ # Since objects can only have one tag, this is hard to test directly
896
+ # Instead, test the second branch
897
+
898
+ # Case 2: Matches second branch (not in excluded path without override)
899
+ obj = MockTaggedObject('any')
900
+ self.assertTrue(filter_deep(['included'], obj)) # Not excluded path
901
+
902
+ # Case 3: In excluded path but has override tag
903
+ override_obj = MockTaggedObject('override')
904
+ self.assertTrue(filter_deep(['excluded'], override_obj)) # Has override
905
+
906
+ # Case 4: In excluded path without override - should not match
907
+ regular_obj = MockTaggedObject('regular')
908
+ self.assertFalse(filter_deep(['excluded'], regular_obj)) # Excluded without override
909
+
910
+
911
+ if __name__ == '__main__':
912
+ unittest.main()