brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,160 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import unittest
|
19
|
+
|
20
|
+
import jax
|
21
|
+
from absl.testing import absltest
|
22
|
+
|
23
|
+
import brainstate as bst
|
24
|
+
|
25
|
+
|
26
|
+
class TestNestedMapping(absltest.TestCase):
    """Checks key access, attribute access and mutation on ``bst.util.NestedDict``."""

    def test_create_state(self):
        """Leaves at both nesting levels are reachable through item access."""
        tree = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})

        assert tree['a'].value == 1
        assert tree['b']['c'].value == 2

    def test_get_attr(self):
        """Top-level entries are also reachable through attribute access."""
        tree = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})

        assert tree.a.value == 1
        assert tree.b['c'].value == 2

    def test_set_attr(self):
        """Writes made through attribute access are visible through item access."""
        tree = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})

        tree.a.value = 3
        tree.b['c'].value = 4

        assert tree['a'].value == 3
        assert tree['b']['c'].value == 4

    def test_set_attr_variables(self):
        """Updating ``.value`` keeps the stored objects as ``ParamState`` instances."""
        tree = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})

        tree.a.value = 3
        tree.b['c'].value = 4

        assert isinstance(tree.a, bst.ParamState)
        assert tree.a.value == 3
        assert isinstance(tree.b['c'], bst.ParamState)
        assert tree.b['c'].value == 4

    def test_add_nested_attr(self):
        """A new leaf can be inserted into a nested level."""
        tree = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})
        tree.b['d'] = bst.ParamState(5)

        assert tree['b']['d'].value == 5

    def test_delete_nested_attr(self):
        """A nested leaf can be removed with ``del``."""
        tree = bst.util.NestedDict({'a': bst.ParamState(1), 'b': {'c': bst.ParamState(2)}})
        del tree['b']['c']

        assert 'c' not in tree['b']

    def test_integer_access(self):
        """States of modules held in a list are addressed by integer index."""

        class Foo(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.layers = [bst.nn.Linear(1, 2), bst.nn.Linear(2, 3)]

        net = Foo()
        refs = bst.graph.treefy_states(net)

        # The module itself and its treefied states must agree, layer by layer.
        for idx, shape in enumerate([(1, 2), (2, 3)]):
            assert net.layers[idx].weight.value['weight'].shape == shape
            assert refs.layers[idx]['weight'].value['weight'].shape == shape

    def test_pure_dict(self):
        """``to_pure_dict`` returns a plain ``dict`` whose leaves hold JAX arrays."""
        net = bst.nn.Linear(4, 5)
        as_dict = bst.graph.treefy_states(net).to_pure_dict()
        assert isinstance(as_dict, dict)
        for entry in ('weight', 'bias'):
            assert isinstance(as_dict['weight'].value[entry], jax.Array)
|
92
|
+
|
93
|
+
|
94
|
+
class TestSplit(unittest.TestCase):
    """Checks splitting a treefied state mapping by filter."""

    def test_split(self):
        """Split the states of a BatchNorm + Linear model into params and the rest."""

        class Model(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = bst.nn.BatchNorm1d([10, 3])
                self.linear = bst.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        with bst.environ.context(fit=True):
            net = Model()
            inputs = bst.random.randn(1, 10, 3)
            outputs = net(inputs)
            self.assertEqual(outputs.shape, (1, 10, 4))

        tree = bst.graph.treefy_states(net)

        # A single filter with no catch-all is expected to raise here.
        with self.assertRaises(ValueError):
            param_part, rest_part = tree.split(bst.ParamState)

        param_part, rest_part = tree.split(bst.ParamState, ...)
        print()
        print(param_part)
        print(rest_part)

        self.assertTrue(len(param_part.to_flat()) == 2)
        self.assertTrue(len(rest_part.to_flat()) == 2)
|
123
|
+
|
124
|
+
|
125
|
+
class TestStateMap2(unittest.TestCase):
    """Smoke test: a flattened state mapping can be wrapped in ``NestedDict``."""

    def test1(self):
        """Build a model, flatten its states and construct a ``NestedDict`` from them."""

        class Model(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = bst.nn.BatchNorm1d([10, 3])
                self.linear = bst.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        with bst.environ.context(fit=True):
            net = Model()
            flat_states = bst.graph.treefy_states(net).to_flat()
            # No assertion follows: the test only requires construction to succeed.
            flat_states = bst.util.NestedDict(flat_states)
|
140
|
+
|
141
|
+
|
142
|
+
class TestFlattedMapping(unittest.TestCase):
    """Checks that module-level and graph-level collection helpers agree."""

    def test1(self):
        """``model.states()`` / ``model.nodes()`` must equal their ``bst.graph`` counterparts."""

        class Model(bst.nn.Module):
            def __init__(self):
                super().__init__()
                self.batchnorm = bst.nn.BatchNorm1d([10, 3])
                self.linear = bst.nn.Linear([10, 3], [10, 4])

            def __call__(self, x):
                return self.linear(self.batchnorm(x))

        net = Model()
        self.assertTrue(net.states() == bst.graph.states(net))

        print(net.nodes())
        self.assertTrue(net.nodes() == bst.graph.nodes(net))
|
@@ -0,0 +1,28 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
# Names exported by ``from ... import *`` for this module.
__all__ = ['BrainStateError', 'TraceContextError']
|
21
|
+
|
22
|
+
|
23
|
+
class BrainStateError(Exception):
    """Exception type defined by this module; ``TraceContextError`` derives from it."""
|
25
|
+
|
26
|
+
|
27
|
+
class TraceContextError(BrainStateError):
    """Specialised ``BrainStateError``; adds no behaviour of its own."""
|
@@ -0,0 +1,178 @@
|
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
import builtins
|
21
|
+
import dataclasses
|
22
|
+
import typing
|
23
|
+
from typing import TYPE_CHECKING
|
24
|
+
|
25
|
+
from brainstate.typing import Filter, PathParts, Predicate, Key
|
26
|
+
|
27
|
+
# ``ellipsis`` alias: falls back to ``typing.Any`` at runtime, and resolves to
# ``builtins.ellipsis`` only while a static type checker analyses the module.
if not TYPE_CHECKING:
    ellipsis = typing.Any
else:
    ellipsis = builtins.ellipsis
|
31
|
+
|
32
|
+
# Only ``to_predicate`` is exported by ``from ... import *``.
__all__ = ['to_predicate']
|
35
|
+
|
36
|
+
|
37
|
+
def to_predicate(the_filter: Filter) -> Predicate:
    """
    Converts a Filter to a predicate function.

    The conversion rules are checked in this order:

    - ``str``             -> ``WithTagFilter(the_filter)``
    - a class (``type``)  -> ``OfTypeFilter(the_filter)``
    - ``True`` / ``...``  -> ``EverythingFilter()``
    - ``False`` / ``None``-> ``NothingFilter()``
    - any other callable  -> returned unchanged
    - ``list`` / ``tuple``-> ``AnyFilter(*the_filter)``

    Args:
      the_filter: The filter specification to convert.

    Returns:
      A callable taking ``(path, x)``.

    Raises:
      TypeError: If ``the_filter`` matches none of the rules above.
    """

    if isinstance(the_filter, str):
        return WithTagFilter(the_filter)
    elif isinstance(the_filter, type):
        # Classes are callable too, so this check must precede ``callable``.
        return OfTypeFilter(the_filter)
    elif isinstance(the_filter, bool):
        if the_filter:
            return EverythingFilter()
        else:
            return NothingFilter()
    elif the_filter is Ellipsis:
        return EverythingFilter()
    elif the_filter is None:
        return NothingFilter()
    elif callable(the_filter):
        return the_filter
    elif isinstance(the_filter, (list, tuple)):
        return AnyFilter(*the_filter)
    else:
        # ``!r`` is a conversion and must not be preceded by ``:``; written as
        # ``:!r`` it is taken as a format spec and formatting itself fails
        # before the TypeError can be raised.
        raise TypeError(f'Invalid collection filter: {the_filter!r}. ')
|
61
|
+
|
62
|
+
|
63
|
+
@dataclasses.dataclass(frozen=True)
class WithTagFilter:
    """Predicate that matches objects whose ``tag`` attribute equals ``self.tag``."""

    # Tag value that a candidate's ``tag`` attribute is compared against.
    tag: str

    def __call__(self, path: PathParts, x: typing.Any):
        """Return the result of ``x.tag == self.tag``, or ``False`` if ``x`` has no ``tag``."""
        if not hasattr(x, 'tag'):
            return False
        return x.tag == self.tag

    def __repr__(self):
        return 'WithTag({!r})'.format(self.tag)
|
72
|
+
|
73
|
+
|
74
|
+
@dataclasses.dataclass(frozen=True)
class PathContainsFilter:
    """Predicate that matches when ``self.key`` is one of the elements of the path."""

    # Path element to look for.
    key: Key

    def __call__(self, path: PathParts, x: typing.Any):
        """Return whether ``self.key`` occurs in ``path``; ``x`` is not inspected."""
        key_present = self.key in path
        return key_present

    def __repr__(self):
        return 'PathContains({!r})'.format(self.key)
|
83
|
+
|
84
|
+
|
85
|
+
@dataclasses.dataclass(frozen=True)
class OfTypeFilter:
    """Predicate that matches by type, either of the object or of its ``type`` attribute."""

    # Class that candidates are tested against.
    type: type

    def __call__(self, path: PathParts, x: typing.Any):
        """Return whether ``x`` is an instance of ``self.type``, or carries a
        ``type`` attribute that is a subclass of it; ``path`` is not inspected."""
        if isinstance(x, self.type):
            return True
        # Fall back to objects that describe their type through a ``type`` attribute.
        return hasattr(x, 'type') and issubclass(x.type, self.type)

    def __repr__(self):
        return 'OfType({!r})'.format(self.type)
|
96
|
+
|
97
|
+
|
98
|
+
class AnyFilter:
|
99
|
+
def __init__(self, *filters: Filter):
|
100
|
+
self.predicates = tuple(
|
101
|
+
to_predicate(collection_filter) for collection_filter in filters
|
102
|
+
)
|
103
|
+
|
104
|
+
def __call__(self, path: PathParts, x: typing.Any):
|
105
|
+
return any(predicate(path, x) for predicate in self.predicates)
|
106
|
+
|
107
|
+
def __repr__(self):
|
108
|
+
return f'Any({", ".join(map(repr, self.predicates))})'
|
109
|
+
|
110
|
+
def __eq__(self, other):
|
111
|
+
return isinstance(other, AnyFilter) and self.predicates == other.predicates
|
112
|
+
|
113
|
+
def __hash__(self):
|
114
|
+
return hash(self.predicates)
|
115
|
+
|
116
|
+
|
117
|
+
class AllFilter:
|
118
|
+
def __init__(self, *filters: Filter):
|
119
|
+
self.predicates = tuple(
|
120
|
+
to_predicate(collection_filter) for collection_filter in filters
|
121
|
+
)
|
122
|
+
|
123
|
+
def __call__(self, path: PathParts, x: typing.Any):
|
124
|
+
return all(predicate(path, x) for predicate in self.predicates)
|
125
|
+
|
126
|
+
def __repr__(self):
|
127
|
+
return f'All({", ".join(map(repr, self.predicates))})'
|
128
|
+
|
129
|
+
def __eq__(self, other):
|
130
|
+
return isinstance(other, AllFilter) and self.predicates == other.predicates
|
131
|
+
|
132
|
+
def __hash__(self):
|
133
|
+
return hash(self.predicates)
|
134
|
+
|
135
|
+
|
136
|
+
class NotFilter:
    """Predicate that inverts the result of a single wrapped filter."""

    def __init__(self, collection_filter: Filter, /):
        # The wrapped filter is normalised to a predicate at construction time.
        self.predicate = to_predicate(collection_filter)

    def __call__(self, path: PathParts, x: typing.Any):
        """Return the boolean negation of the wrapped predicate's result."""
        matched = self.predicate(path, x)
        return not matched

    def __repr__(self):
        return 'Not({!r})'.format(self.predicate)

    def __eq__(self, other):
        if not isinstance(other, NotFilter):
            return False
        return self.predicate == other.predicate

    def __hash__(self):
        return hash(self.predicate)
|
151
|
+
|
152
|
+
|
153
|
+
class EverythingFilter:
    """Predicate that accepts every ``(path, x)`` pair."""

    def __eq__(self, other):
        # Instances carry no state, so any two of them compare equal.
        return isinstance(other, EverythingFilter)

    def __hash__(self):
        # Hash of the class itself, shared by all instances to stay consistent with ``__eq__``.
        return hash(EverythingFilter)

    def __repr__(self):
        return 'Everything()'

    def __call__(self, path: PathParts, x: typing.Any):
        """Always return ``True``; neither argument is inspected."""
        return True
|
165
|
+
|
166
|
+
|
167
|
+
class NothingFilter:
    """Predicate that rejects every ``(path, x)`` pair."""

    def __eq__(self, other):
        # Instances carry no state, so any two of them compare equal.
        return isinstance(other, NothingFilter)

    def __hash__(self):
        # Hash of the class itself, shared by all instances to stay consistent with ``__eq__``.
        return hash(NothingFilter)

    def __repr__(self):
        return 'Nothing()'

    def __call__(self, path: PathParts, x: typing.Any):
        """Always return ``False``; neither argument is inspected."""
        return False
|