brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,912 @@
|
|
1
|
+
# Copyright 2024 BrainState Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
"""
|
16
|
+
Comprehensive tests for filter module.
|
17
|
+
"""
|
18
|
+
|
19
|
+
import unittest
|
20
|
+
from typing import Any
|
21
|
+
import numpy as np
|
22
|
+
|
23
|
+
from brainstate.util.filter import (
|
24
|
+
to_predicate,
|
25
|
+
WithTag,
|
26
|
+
PathContains,
|
27
|
+
OfType,
|
28
|
+
Any,
|
29
|
+
All,
|
30
|
+
Not,
|
31
|
+
Everything,
|
32
|
+
Nothing,
|
33
|
+
)
|
34
|
+
|
35
|
+
|
36
|
+
class MockTaggedObject:
    """Lightweight test double that carries nothing but a ``tag`` attribute.

    Tag-based filters only inspect ``obj.tag``, so this is enough to drive
    them without constructing real state objects.
    """

    def __init__(self, tag: str):
        # Keep the label exactly as given; filters compare it by attribute access.
        self.tag = tag
|
40
|
+
|
41
|
+
|
42
|
+
class MockTypedObject:
    """Test double exposing a ``type`` attribute that names some Python type.

    Lets tests check type filters against objects that advertise a type via
    an attribute instead of being an instance of it.
    """

    def __init__(self, type_value: type):
        # Stored under the attribute name ``type`` (deliberately shadowing the
        # builtin name only as an attribute, not as a local).
        self.type = type_value
|
46
|
+
|
47
|
+
|
48
|
+
class TestToPredicateFunction(unittest.TestCase):
    """Checks how ``to_predicate`` normalises each supported filter spec."""

    def test_string_to_withtag(self):
        """A plain string becomes a ``WithTag`` predicate on that string."""
        predicate = to_predicate('trainable')
        self.assertIsInstance(predicate, WithTag)
        self.assertEqual(predicate.tag, 'trainable')

        # Behavioural check: only the matching tag is accepted.
        matching = MockTaggedObject('trainable')
        mismatching = MockTaggedObject('frozen')
        self.assertTrue(predicate([], matching))
        self.assertFalse(predicate([], mismatching))

    def test_type_to_oftype(self):
        """A type object becomes an ``OfType`` predicate on that type."""
        predicate = to_predicate(np.ndarray)
        self.assertIsInstance(predicate, OfType)
        self.assertEqual(predicate.type, np.ndarray)

        # An ndarray passes; a plain list with the same contents does not.
        self.assertTrue(predicate([], np.array([1, 2, 3])))
        self.assertFalse(predicate([], [1, 2, 3]))

    def test_bool_true_to_everything(self):
        """``True`` maps to ``Everything``, which accepts any (path, value)."""
        predicate = to_predicate(True)
        self.assertIsInstance(predicate, Everything)

        for path, value in (([], 'anything'), (['path'], None), ([], 42)):
            self.assertTrue(predicate(path, value))

    def test_bool_false_to_nothing(self):
        """``False`` maps to ``Nothing``, which rejects any (path, value)."""
        predicate = to_predicate(False)
        self.assertIsInstance(predicate, Nothing)

        for path, value in (([], 'anything'), (['path'], None), ([], 42)):
            self.assertFalse(predicate(path, value))

    def test_ellipsis_to_everything(self):
        """The ``Ellipsis`` singleton is treated like ``True``."""
        predicate = to_predicate(Ellipsis)
        self.assertIsInstance(predicate, Everything)
        self.assertTrue(predicate([], 'test'))

    def test_none_to_nothing(self):
        """``None`` is treated like ``False``."""
        predicate = to_predicate(None)
        self.assertIsInstance(predicate, Nothing)
        self.assertFalse(predicate([], 'test'))

    def test_callable_passthrough(self):
        """An arbitrary callable is handed back unchanged (same object)."""
        def is_special(path, x):
            return x == 'special'

        predicate = to_predicate(is_special)
        self.assertIs(predicate, is_special)
        self.assertTrue(predicate([], 'special'))
        self.assertFalse(predicate([], 'normal'))

    def test_list_to_any(self):
        """A list of specs becomes an ``Any`` over the converted members."""
        predicate = to_predicate(['trainable', 'frozen'])
        self.assertIsInstance(predicate, Any)

        for accepted_tag in ('trainable', 'frozen'):
            self.assertTrue(predicate([], MockTaggedObject(accepted_tag)))
        self.assertFalse(predicate([], MockTaggedObject('other')))

    def test_tuple_to_any(self):
        """A tuple of specs is handled the same way as a list."""
        predicate = to_predicate((np.ndarray, list))
        self.assertIsInstance(predicate, Any)

        self.assertTrue(predicate([], np.array([1, 2])))
        self.assertTrue(predicate([], [1, 2]))
        # A tuple is neither an ndarray nor a list.
        self.assertFalse(predicate([], (1, 2)))

    def test_invalid_type_raises_error(self):
        """Unsupported spec types are rejected with ``TypeError``."""
        with self.assertRaises(TypeError) as ctx:
            to_predicate(42)
        self.assertIn('Invalid collection filter', str(ctx.exception))

        # A dict is not an accepted filter spec either.
        with self.assertRaises(TypeError):
            to_predicate({'key': 'value'})
|
149
|
+
|
150
|
+
|
151
|
+
class TestWithTagFilter(unittest.TestCase):
    """Behaviour of the tag-matching ``WithTag`` predicate."""

    def test_basic_functionality(self):
        """Only objects whose ``tag`` equals the requested value are accepted."""
        wants_trainable = WithTag('trainable')

        self.assertTrue(wants_trainable([], MockTaggedObject('trainable')))
        self.assertFalse(wants_trainable([], MockTaggedObject('frozen')))
        # A dict has no ``tag`` attribute and must simply be rejected.
        self.assertFalse(wants_trainable([], {'value': 42}))

    def test_repr(self):
        """``repr`` shows the class name and the quoted tag."""
        self.assertEqual(repr(WithTag('test_tag')), "WithTag('test_tag')")

    def test_immutability(self):
        """Reassigning ``tag`` after construction raises ``AttributeError``."""
        tag_filter = WithTag('test')
        with self.assertRaises(AttributeError):
            tag_filter.tag = 'modified'

    def test_path_parameter_ignored(self):
        """The path argument has no influence on the result."""
        tag_filter = WithTag('test')
        target = MockTaggedObject('test')

        for path in ([], ['some', 'path'], ['another', 'nested', 'path']):
            self.assertTrue(tag_filter(path, target))
|
190
|
+
|
191
|
+
|
192
|
+
class TestPathContainsFilter(unittest.TestCase):
    """Behaviour of the path-membership ``PathContains`` predicate."""

    def test_basic_functionality(self):
        """A path matches iff the key appears anywhere in it."""
        weight_filter = PathContains('weight')

        hits = (
            ['model', 'layer1', 'weight'],
            ['weight'],
            ['deep', 'nested', 'weight', 'param'],
        )
        misses = (
            ['model', 'layer1', 'bias'],
            [],
            ['other', 'path'],
        )
        for path in hits:
            self.assertTrue(weight_filter(path, None))
        for path in misses:
            self.assertFalse(weight_filter(path, None))

    def test_numeric_keys(self):
        """Integer keys (e.g. sequence indices) are matched as well."""
        zero_filter = PathContains(0)

        for path in ([0, 'item'], ['list', 0, 'element']):
            self.assertTrue(zero_filter(path, None))
        self.assertFalse(zero_filter([1, 2, 3], None))

    def test_repr(self):
        """``repr`` shows the class name and the quoted key."""
        self.assertEqual(repr(PathContains('layer2')), "PathContains('layer2')")

    def test_object_parameter_ignored(self):
        """The value argument has no influence when the path matches."""
        path_filter = PathContains('test')

        for payload in ('string', 123, None, {'dict': 'value'}):
            self.assertTrue(path_filter(['test'], payload))
|
231
|
+
|
232
|
+
|
233
|
+
class TestOfTypeFilter(unittest.TestCase):
    """Behaviour of the type-matching ``OfType`` predicate."""

    def test_direct_instance_check(self):
        """Direct instances are accepted; unrelated values are rejected."""
        array_filter = OfType(np.ndarray)

        for candidate in (np.array([1, 2, 3]), np.zeros((2, 2))):
            self.assertTrue(array_filter([], candidate))
        for candidate in ([1, 2, 3], 42):
            self.assertFalse(array_filter([], candidate))

    def test_inheritance(self):
        """Instances of subclasses count as instances of the base type."""
        class Parent:
            pass

        class Child(Parent):
            pass

        parent_filter = OfType(Parent)

        self.assertTrue(parent_filter([], Parent()))
        self.assertTrue(parent_filter([], Child()))
        self.assertFalse(parent_filter([], "not related"))

    def test_type_attribute_check(self):
        """Objects advertising a matching ``type`` attribute are accepted."""
        list_filter = OfType(list)

        # Neither object is a list; only the ``type`` attribute differs.
        self.assertTrue(list_filter([], MockTypedObject(list)))
        self.assertFalse(list_filter([], MockTypedObject(dict)))

    def test_repr(self):
        """``repr`` embeds the ``repr`` of the wrapped type."""
        self.assertEqual(repr(OfType(str)), f"OfType({str!r})")

    def test_builtin_types(self):
        """Built-in types such as ``str``/``int``/``list`` work as targets."""
        str_filter, int_filter, list_filter = OfType(str), OfType(int), OfType(list)

        positives = ((str_filter, "hello"), (int_filter, 42), (list_filter, [1, 2, 3]))
        for predicate, value in positives:
            self.assertTrue(predicate([], value))

        negatives = ((str_filter, 42), (int_filter, "42"), (list_filter, (1, 2, 3)))
        for predicate, value in negatives:
            self.assertFalse(predicate([], value))
|
293
|
+
|
294
|
+
|
295
|
+
class TestAnyFilter(unittest.TestCase):
    """Behaviour of the ``Any`` (logical OR) filter combinator."""

    def test_basic_or_operation(self):
        """An object passes if at least one member filter accepts it."""
        either = Any('trainable', 'frozen')

        for tag in ('trainable', 'frozen'):
            self.assertTrue(either([], MockTaggedObject(tag)))
        self.assertFalse(either([], MockTaggedObject('other')))

    def test_mixed_filter_types(self):
        """Type, tag and path filters can be combined in one ``Any``."""
        combined = Any(
            OfType(np.ndarray),
            WithTag('special'),
            PathContains('important'),
        )

        # Each member alone is sufficient.
        self.assertTrue(combined([], np.array([1, 2])))
        self.assertTrue(combined([], MockTaggedObject('special')))
        self.assertTrue(combined(['important'], 'anything'))

        # No member matches -> rejected.
        self.assertFalse(combined(['other'], MockTaggedObject('normal')))

    def test_short_circuit_evaluation(self):
        """Evaluation stops at the first member that returns True."""
        invocations = []

        def recording_filter(path, x):
            invocations.append(x)
            return x == 'match'

        first_wins = Any(
            lambda p, x: x == 'match',  # satisfied immediately
            recording_filter,           # must therefore never be reached
        )

        self.assertTrue(first_wins([], 'match'))
        self.assertEqual(len(invocations), 0)

    def test_empty_any(self):
        """With no members there is nothing to satisfy, so it rejects."""
        self.assertFalse(Any()([], 'anything'))

    def test_repr(self):
        """``repr`` names the combinator and each member filter."""
        text = repr(Any(WithTag('tag1'), WithTag('tag2')))
        for fragment in ('Any', "WithTag('tag1')", "WithTag('tag2')"):
            self.assertIn(fragment, text)

    def test_equality(self):
        """Equality is member-wise and order-sensitive."""
        ab = Any('tag1', 'tag2')
        ab_again = Any('tag1', 'tag2')
        ba = Any('tag2', 'tag1')  # same members, reversed order

        self.assertEqual(ab, ab_again)
        self.assertNotEqual(ab, ba)
        self.assertNotEqual(ab, 'not a filter')

    def test_hashable(self):
        """Equal ``Any`` filters hash alike and collapse in a set."""
        first = Any('tag1', 'tag2')
        second = Any('tag1', 'tag2')
        self.assertEqual(len({first, second}), 1)
|
374
|
+
|
375
|
+
|
376
|
+
class TestAllFilter(unittest.TestCase):
    """Test cases for the ``All`` (logical AND) filter combinator."""

    def test_basic_and_operation(self):
        """An object must satisfy every member filter for ``All`` to accept it."""
        filter_all = All(
            WithTag('trainable'),
            OfType(np.ndarray)
        )

        # ndarray subclass so that an array instance can also carry a ``tag``.
        class TaggedArray(np.ndarray):
            def __new__(cls, input_array, tag=None):
                obj = np.asarray(input_array).view(cls)
                obj.tag = tag
                return obj

        # Right type and right tag -> accepted.
        arr_obj = TaggedArray([1, 2, 3], tag='trainable')
        self.assertTrue(filter_all([], arr_obj))

        # Right type, wrong tag -> rejected.
        arr_obj2 = TaggedArray([4, 5, 6], tag='frozen')
        self.assertFalse(filter_all([], arr_obj2))

        # Right tag, but a plain object rather than an ndarray -> rejected.
        # (Previously named ``ListWithTag`` although it is not a list; what
        # matters here is only that it fails the ``OfType(np.ndarray)`` check.)
        class TaggedPlainObject:
            def __init__(self, tag):
                self.tag = tag

        plain_obj = TaggedPlainObject('trainable')
        self.assertFalse(filter_all([], plain_obj))

    def test_short_circuit_evaluation(self):
        """Evaluation stops at the first member that returns False."""
        call_count = [0]

        def counting_filter(path, x):
            call_count[0] += 1
            return True

        filter_all = All(
            lambda p, x: False,  # fails immediately
            counting_filter      # must therefore never be reached
        )

        self.assertFalse(filter_all([], 'anything'))
        self.assertEqual(call_count[0], 0)

    def test_empty_all(self):
        """With no members there is nothing to violate, so it accepts."""
        filter_empty = All()
        self.assertTrue(filter_empty([], 'anything'))

    def test_complex_combination(self):
        """A tag filter and two range lambdas must all hold simultaneously."""
        class CustomObject:
            def __init__(self, tag, value):
                self.tag = tag
                self.value = value

        filter_complex = All(
            WithTag('important'),
            lambda p, x: hasattr(x, 'value') and x.value > 10,
            lambda p, x: hasattr(x, 'value') and x.value < 100
        )

        obj1 = CustomObject('important', 50)
        obj2 = CustomObject('important', 5)
        obj3 = CustomObject('important', 150)
        obj4 = CustomObject('other', 50)

        self.assertTrue(filter_complex([], obj1))   # all conditions met
        self.assertFalse(filter_complex([], obj2))  # value too small
        self.assertFalse(filter_complex([], obj3))  # value too large
        self.assertFalse(filter_complex([], obj4))  # wrong tag

    def test_repr(self):
        """``repr`` names the combinator and its member filters."""
        filter_all = All(WithTag('tag1'), OfType(list))
        repr_str = repr(filter_all)
        self.assertIn('All', repr_str)
        self.assertIn("WithTag('tag1')", repr_str)
        self.assertIn('OfType', repr_str)

    def test_equality(self):
        """Equality is member-wise and order-sensitive."""
        filter1 = All('tag1', np.ndarray)
        filter2 = All('tag1', np.ndarray)
        filter3 = All(np.ndarray, 'tag1')  # same members, different order

        self.assertEqual(filter1, filter2)
        self.assertNotEqual(filter1, filter3)

    def test_hashable(self):
        """Equal ``All`` filters hash alike and collapse to one dict key."""
        filter1 = All('tag1', np.ndarray)
        filter2 = All('tag1', np.ndarray)

        filter_dict = {filter1: 'value1', filter2: 'value2'}
        self.assertEqual(len(filter_dict), 1)
|
477
|
+
|
478
|
+
|
479
|
+
class TestNotFilter(unittest.TestCase):
    """Behaviour of the ``Not`` (logical negation) filter wrapper."""

    def test_basic_negation(self):
        """Negating a tag filter flips its verdict."""
        not_trainable = Not(WithTag('trainable'))

        self.assertFalse(not_trainable([], MockTaggedObject('trainable')))
        self.assertTrue(not_trainable([], MockTaggedObject('frozen')))

    def test_negating_type_filter(self):
        """Negating a type filter accepts everything except that type."""
        not_array = Not(OfType(np.ndarray))

        self.assertFalse(not_array([], np.array([1, 2])))
        for other in ([1, 2], 'string'):
            self.assertTrue(not_array([], other))

    def test_negating_complex_filters(self):
        """``Not`` composes with ``Any`` and ``All``."""
        # Not(Any(...)): accepted only when no member matches.
        neither = Not(Any('tag1', 'tag2'))
        for rejected_tag in ('tag1', 'tag2'):
            self.assertFalse(neither([], MockTaggedObject(rejected_tag)))
        self.assertTrue(neither([], MockTaggedObject('tag3')))

        # Not(All(...)): accepted when at least one member fails.
        not_both = Not(All(WithTag('tag'), OfType(list)))

        class ListCarryingTag(list):
            def __init__(self, *args):
                super().__init__(*args)
                self.tag = 'tag'

        # Is a list AND has the tag -> All accepts -> Not rejects.
        self.assertFalse(not_both([], ListCarryingTag([1, 2, 3])))

        class ArrayCarryingTag(np.ndarray):
            def __new__(cls, input_array, tag=None):
                instance = np.asarray(input_array).view(cls)
                instance.tag = tag
                return instance

        # Has the tag but is not a list -> All rejects -> Not accepts.
        self.assertTrue(not_both([], ArrayCarryingTag([], tag='tag')))

    def test_double_negation(self):
        """``Not(Not(f))`` gives the same verdicts as ``f``."""
        base = WithTag('test')
        twice_negated = Not(Not(base))

        for candidate in (MockTaggedObject('test'), MockTaggedObject('other')):
            self.assertEqual(base([], candidate), twice_negated([], candidate))

    def test_repr(self):
        """``repr`` wraps the inner filter's ``repr``."""
        self.assertEqual(repr(Not(WithTag('test'))), "Not(WithTag('test'))")

    def test_equality(self):
        """Two ``Not`` wrappers are equal iff their inner filters are."""
        same_a = Not(WithTag('test'))
        same_b = Not(WithTag('test'))
        different = Not(WithTag('other'))

        self.assertEqual(same_a, same_b)
        self.assertNotEqual(same_a, different)

    def test_hashable(self):
        """Equal ``Not`` filters hash alike and collapse in a set."""
        first = Not(WithTag('test'))
        second = Not(WithTag('test'))
        self.assertEqual(len({first, second}), 1)
|
574
|
+
|
575
|
+
|
576
|
+
class TestEverythingFilter(unittest.TestCase):
    """Test cases for the always-accepting ``Everything`` filter."""

    def test_always_returns_true(self):
        """``Everything`` accepts any value at any path."""
        filter_all = Everything()

        # Various values at the empty path.
        for value in (None, 42, 'string', [1, 2, 3], np.array([1, 2]), {'key': 'value'}):
            self.assertTrue(filter_all([], value))

        # Various paths with a fixed value.  The earlier version re-asserted
        # ``filter_all([], None)`` here, which duplicated the first check above;
        # a mixed str/int path is used instead to add real coverage.
        for path in (['path'], ['nested', 'path'], ['deep', 'nested', 'path', 0]):
            self.assertTrue(filter_all(path, None))

    def test_repr(self):
        """``repr`` is the bare constructor call."""
        filter_all = Everything()
        self.assertEqual(repr(filter_all), 'Everything()')

    def test_equality(self):
        """All ``Everything`` instances are equal; ``Nothing`` is not."""
        filter1 = Everything()
        filter2 = Everything()
        filter3 = Nothing()

        self.assertEqual(filter1, filter2)
        self.assertNotEqual(filter1, filter3)

    def test_hashable(self):
        """All ``Everything`` instances share a hash and collapse in a set."""
        filter1 = Everything()
        filter2 = Everything()

        self.assertEqual(hash(filter1), hash(filter2))

        filter_set = {filter1, filter2}
        self.assertEqual(len(filter_set), 1)

    def test_conversion_from_true(self):
        """``to_predicate(True)`` behaves like a directly built ``Everything``."""
        filter_from_true = to_predicate(True)
        filter_direct = Everything()

        test_cases = [None, 42, 'test', [], {}]
        for obj in test_cases:
            self.assertEqual(
                filter_from_true([], obj),
                filter_direct([], obj)
            )

    def test_conversion_from_ellipsis(self):
        """``to_predicate(...)`` yields an ``Everything`` equal to a fresh one."""
        filter_from_ellipsis = to_predicate(...)
        filter_direct = Everything()

        self.assertIsInstance(filter_from_ellipsis, Everything)
        self.assertEqual(filter_from_ellipsis, filter_direct)
|
641
|
+
|
642
|
+
|
643
|
+
class TestNothingFilter(unittest.TestCase):
    """Test cases for Nothing filter."""

    def test_always_returns_false(self):
        """Test that Nothing always returns False, for any object and any path."""
        filter_none = Nothing()

        # Test with various objects
        objects = [None, 42, 'string', [1, 2, 3], np.array([1, 2]), {'key': 'value'}]
        for obj in objects:
            with self.subTest(obj=obj):
                self.assertFalse(filter_none([], obj))

        # Test with various paths (including the empty path)
        for path in (['path'], ['nested', 'path'], []):
            with self.subTest(path=path):
                self.assertFalse(filter_none(path, None))

    def test_repr(self):
        """Test string representation."""
        filter_none = Nothing()
        self.assertEqual(repr(filter_none), 'Nothing()')

    def test_equality(self):
        """Test equality comparison."""
        filter1 = Nothing()
        filter2 = Nothing()
        filter3 = Everything()

        self.assertEqual(filter1, filter2)
        self.assertNotEqual(filter1, filter3)

    def test_hashable(self):
        """Test that Nothing filters are hashable."""
        filter1 = Nothing()
        filter2 = Nothing()

        # All Nothing instances should be equal and have same hash
        self.assertEqual(hash(filter1), hash(filter2))

        # Equal instances collapse onto a single dict key.
        filter_dict = {filter1: 'value'}
        filter_dict[filter2] = 'new_value'
        self.assertEqual(len(filter_dict), 1)  # Same key

    def test_conversion_from_false(self):
        """Test that False converts to Nothing."""
        filter_from_false = to_predicate(False)
        filter_direct = Nothing()

        # Same type/equality checks as the None conversion below.
        self.assertIsInstance(filter_from_false, Nothing)
        self.assertEqual(filter_from_false, filter_direct)

        # Should behave identically; subTest labels which input failed.
        for obj in [None, 42, 'test', [], {}]:
            with self.subTest(obj=obj):
                self.assertEqual(
                    filter_from_false([], obj),
                    filter_direct([], obj)
                )

    def test_conversion_from_none(self):
        """Test that None converts to Nothing."""
        filter_from_none = to_predicate(None)
        filter_direct = Nothing()

        self.assertIsInstance(filter_from_none, Nothing)
        self.assertEqual(filter_from_none, filter_direct)
|
709
|
+
|
710
|
+
|
711
|
+
class TestIntegrationScenarios(unittest.TestCase):
    """Integration tests for complex filter combinations."""

    def test_neural_network_parameter_filtering(self):
        """Test filtering neural network parameters with complex criteria."""

        # Minimal stand-in for a network parameter: a tag, a shape and data.
        class Parameter:
            def __init__(self, shape, tag=None, dtype=None):
                self.shape = shape
                self.tag = tag
                self.dtype = dtype
                # Values are irrelevant to the filters; zeros keep the test
                # deterministic (np.zeros already returns an ndarray).
                self.data = np.zeros(shape, dtype=np.int32 if dtype == 'int32' else np.float64)

        weight1 = Parameter((10, 20), tag='trainable')
        bias1 = Parameter((20,), tag='trainable')
        embedding = Parameter((100, 64), tag='frozen')
        small_bias = Parameter((3,), tag='trainable')

        # Complex filter: trainable ndarray-backed parameters with shape[0] > 5
        filter_complex = All(
            WithTag('trainable'),
            lambda p, x: hasattr(x, 'data') and isinstance(x.data, np.ndarray),
            lambda p, x: hasattr(x, 'shape') and x.shape[0] > 5
        )

        self.assertTrue(filter_complex([], weight1))  # shape[0] = 10 > 5
        self.assertTrue(filter_complex([], bias1))  # shape[0] = 20 > 5
        self.assertFalse(filter_complex([], embedding))  # Wrong tag
        self.assertFalse(filter_complex([], small_bias))  # shape[0] = 3, fails shape predicate

    def test_path_based_model_filtering(self):
        """Test filtering based on model structure paths."""
        # Filter for encoder weights
        encoder_weight_filter = All(
            PathContains('encoder'),
            PathContains('weight')
        )

        paths = [
            (['model', 'encoder', 'layer1', 'weight'], True),
            (['model', 'encoder', 'layer2', 'weight'], True),
            (['model', 'encoder', 'layer1', 'bias'], False),  # Not weight
            (['model', 'decoder', 'layer1', 'weight'], False),  # Not encoder
            (['encoder', 'attention', 'weight'], True),
        ]

        for path, expected in paths:
            with self.subTest(path=path):
                self.assertEqual(
                    encoder_weight_filter(path, None),
                    expected,
                    f"Failed for path: {path}"
                )

    def test_selective_gradient_computation(self):
        """Test filter for selective gradient computation."""
        # Only compute gradients for trainable non-embedding layers
        gradient_filter = All(
            WithTag('trainable'),
            Not(PathContains('embedding')),
            Any(
                PathContains('weight'),
                PathContains('bias')
            )
        )

        class Param:
            def __init__(self, tag):
                self.tag = tag

        trainable_weight = Param('trainable')
        trainable_bias = Param('trainable')
        frozen_weight = Param('frozen')
        trainable_other = Param('trainable')

        test_cases = [
            (['layer1', 'weight'], trainable_weight, True),
            (['layer1', 'bias'], trainable_bias, True),
            (['embedding', 'weight'], trainable_weight, False),  # Excluded path
            (['layer1', 'weight'], frozen_weight, False),  # Wrong tag
            (['layer1', 'gamma'], trainable_other, False),  # Not weight/bias
        ]

        for path, obj, expected in test_cases:
            with self.subTest(path=path, tag=obj.tag):
                self.assertEqual(
                    gradient_filter(path, obj),
                    expected,
                    f"Failed for path={path}, tag={obj.tag}"
                )

    def test_demorgan_laws(self):
        """Test both of De Morgan's laws with filters."""
        tag1_filter = WithTag('tag1')
        tag2_filter = WithTag('tag2')

        # Not(A or B) == (Not A) and (Not B)
        not_any = Not(Any(tag1_filter, tag2_filter))
        all_not = All(Not(tag1_filter), Not(tag2_filter))

        # Not(A and B) == (Not A) or (Not B)
        not_all = Not(All(tag1_filter, tag2_filter))
        any_not = Any(Not(tag1_filter), Not(tag2_filter))

        test_objects = [
            MockTaggedObject('tag1'),
            MockTaggedObject('tag2'),
            MockTaggedObject('tag3'),
            MockTaggedObject('other'),
        ]

        # Both laws are pure boolean identities, so they must hold for every
        # object regardless of how WithTag decides a match.
        for obj in test_objects:
            with self.subTest(tag=obj.tag):
                self.assertEqual(
                    not_any([], obj),
                    all_not([], obj),
                    f"De Morgan's law failed for tag={obj.tag}"
                )
                self.assertEqual(
                    not_all([], obj),
                    any_not([], obj),
                    f"De Morgan's law (conjunction) failed for tag={obj.tag}"
                )

    def test_filter_chaining_performance(self):
        """Test that chained (nested) filters compose correctly."""
        # Create a chain of filters
        base_filter = WithTag('base')
        extended_filter = All(base_filter, OfType(dict))
        final_filter = Any(extended_filter, PathContains('special'))

        # Test object that matches base but not type
        obj1 = MockTaggedObject('base')
        self.assertFalse(extended_filter([], obj1))  # Not a dict
        self.assertFalse(final_filter([], obj1))  # Doesn't match Any conditions

        # Test object that matches everything
        class TaggedDict(dict):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.tag = 'base'

        obj2 = TaggedDict(key='value')
        self.assertTrue(extended_filter([], obj2))  # Matches both conditions
        self.assertTrue(final_filter([], obj2))  # Matches via extended_filter

        # Test path-based match
        self.assertTrue(final_filter(['special'], 'anything'))

    def test_recursive_filter_structures(self):
        """Test deeply nested filter combinations."""
        # Build a complex filter structure
        filter_deep = Any(
            All(
                WithTag('level1'),
                Any(
                    All(WithTag('level2a'), OfType(list)),
                    All(WithTag('level2b'), OfType(dict))
                )
            ),
            Not(
                All(
                    PathContains('excluded'),
                    Not(WithTag('override'))
                )
            )
        )

        # Case 1 (first branch) would need an object carrying tag 'level1' and
        # tag 'level2a' at once; the single-tag mock cannot express that, so
        # only the second branch is exercised below.

        # Case 2: Matches second branch (not in excluded path without override)
        obj = MockTaggedObject('any')
        self.assertTrue(filter_deep(['included'], obj))  # Not excluded path

        # Case 3: In excluded path but has override tag
        override_obj = MockTaggedObject('override')
        self.assertTrue(filter_deep(['excluded'], override_obj))  # Has override

        # Case 4: In excluded path without override - should not match
        regular_obj = MockTaggedObject('regular')
        self.assertFalse(filter_deep(['excluded'], regular_obj))  # Excluded without override
|
909
|
+
|
910
|
+
|
911
|
+
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|