brainstate 0.1.8__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
brainstate/util/filter.py
CHANGED
@@ -1,469 +1,469 @@
|
|
1
|
-
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
-
# The credit should go to the Flax authors.
|
3
|
-
#
|
4
|
-
# Copyright 2024 The Flax Authors.
|
5
|
-
#
|
6
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
-
# you may not use this file except in compliance with the License.
|
8
|
-
# You may obtain a copy of the License at
|
9
|
-
#
|
10
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
-
#
|
12
|
-
# Unless required by applicable law or agreed to in writing, software
|
13
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
-
# See the License for the specific language governing permissions and
|
16
|
-
# limitations under the License.
|
17
|
-
|
18
|
-
import builtins
|
19
|
-
import dataclasses
|
20
|
-
import typing
|
21
|
-
from typing import TYPE_CHECKING
|
22
|
-
|
23
|
-
from brainstate.typing import Filter, PathParts, Predicate, Key
|
24
|
-
|
25
|
-
if TYPE_CHECKING:
|
26
|
-
ellipsis = builtins.ellipsis
|
27
|
-
else:
|
28
|
-
ellipsis = typing.Any
|
29
|
-
|
30
|
-
__all__ = [
|
31
|
-
'to_predicate',
|
32
|
-
'WithTag',
|
33
|
-
'PathContains',
|
34
|
-
'OfType',
|
35
|
-
'Any',
|
36
|
-
'All',
|
37
|
-
'Nothing',
|
38
|
-
'Not',
|
39
|
-
'Everything',
|
40
|
-
]
|
41
|
-
|
42
|
-
|
43
|
-
def to_predicate(the_filter: Filter) -> Predicate:
|
44
|
-
"""
|
45
|
-
Converts a Filter to a predicate function.
|
46
|
-
|
47
|
-
This function takes various types of filters and converts them into
|
48
|
-
corresponding predicate functions that can be used for filtering.
|
49
|
-
|
50
|
-
Args:
|
51
|
-
the_filter (Filter): The filter to be converted. Can be of various types:
|
52
|
-
- str: Converted to a WithTag filter.
|
53
|
-
- type: Converted to an OfType filter.
|
54
|
-
- bool: True becomes Everything(), False becomes Nothing().
|
55
|
-
- Ellipsis: Converted to Everything().
|
56
|
-
- None: Converted to Nothing().
|
57
|
-
- callable: Returned as-is.
|
58
|
-
- list or tuple: Converted to Any filter with elements as arguments.
|
59
|
-
|
60
|
-
Returns:
|
61
|
-
Predicate: A callable predicate function that can be used for filtering.
|
62
|
-
|
63
|
-
Raises:
|
64
|
-
TypeError: If the input filter is of an invalid type.
|
65
|
-
"""
|
66
|
-
|
67
|
-
if isinstance(the_filter, str):
|
68
|
-
return WithTag(the_filter)
|
69
|
-
elif isinstance(the_filter, type):
|
70
|
-
return OfType(the_filter)
|
71
|
-
elif isinstance(the_filter, bool):
|
72
|
-
if the_filter:
|
73
|
-
return Everything()
|
74
|
-
else:
|
75
|
-
return Nothing()
|
76
|
-
elif the_filter is Ellipsis:
|
77
|
-
return Everything()
|
78
|
-
elif the_filter is None:
|
79
|
-
return Nothing()
|
80
|
-
elif callable(the_filter):
|
81
|
-
return the_filter
|
82
|
-
elif isinstance(the_filter, (list, tuple)):
|
83
|
-
return Any(*the_filter)
|
84
|
-
else:
|
85
|
-
raise TypeError(f'Invalid collection filter: {the_filter!r}. ')
|
86
|
-
|
87
|
-
|
88
|
-
@dataclasses.dataclass(frozen=True)
|
89
|
-
class WithTag:
|
90
|
-
"""
|
91
|
-
A filter class that checks if an object has a specific tag.
|
92
|
-
|
93
|
-
This class is a callable that can be used as a predicate function
|
94
|
-
to filter objects based on their 'tag' attribute.
|
95
|
-
|
96
|
-
Attributes:
|
97
|
-
tag (str): The tag to match against.
|
98
|
-
"""
|
99
|
-
|
100
|
-
tag: str
|
101
|
-
|
102
|
-
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
103
|
-
"""
|
104
|
-
Check if the object has a 'tag' attribute matching the specified tag.
|
105
|
-
|
106
|
-
Args:
|
107
|
-
path (PathParts): The path to the current object (not used in this filter).
|
108
|
-
x (typing.Any): The object to check for the tag.
|
109
|
-
|
110
|
-
Returns:
|
111
|
-
bool: True if the object has a 'tag' attribute matching the specified tag, False otherwise.
|
112
|
-
"""
|
113
|
-
return hasattr(x, 'tag') and x.tag == self.tag
|
114
|
-
|
115
|
-
def __repr__(self) -> str:
|
116
|
-
return f'WithTag({self.tag!r})'
|
117
|
-
|
118
|
-
|
119
|
-
@dataclasses.dataclass(frozen=True)
|
120
|
-
class PathContains:
|
121
|
-
"""
|
122
|
-
A filter class that checks if a given key is present in the path.
|
123
|
-
|
124
|
-
This class is a callable that can be used as a predicate function
|
125
|
-
to filter objects based on whether a specific key is present in their path.
|
126
|
-
|
127
|
-
Attributes:
|
128
|
-
key (Key): The key to search for in the path.
|
129
|
-
"""
|
130
|
-
|
131
|
-
key: Key
|
132
|
-
|
133
|
-
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
134
|
-
"""
|
135
|
-
Check if the key is present in the given path.
|
136
|
-
|
137
|
-
Args:
|
138
|
-
path (PathParts): The path to check for the presence of the key.
|
139
|
-
x (typing.Any): The object associated with the path (not used in this filter).
|
140
|
-
|
141
|
-
Returns:
|
142
|
-
bool: True if the key is present in the path, False otherwise.
|
143
|
-
"""
|
144
|
-
return self.key in path
|
145
|
-
|
146
|
-
def __repr__(self) -> str:
|
147
|
-
return f'PathContains({self.key!r})'
|
148
|
-
|
149
|
-
|
150
|
-
@dataclasses.dataclass(frozen=True)
|
151
|
-
class OfType:
|
152
|
-
"""
|
153
|
-
A filter class that checks if an object is of a specific type.
|
154
|
-
|
155
|
-
This class is a callable that can be used as a predicate function
|
156
|
-
to filter objects based on their type.
|
157
|
-
|
158
|
-
Attributes:
|
159
|
-
type (type): The type to match against.
|
160
|
-
"""
|
161
|
-
type: type
|
162
|
-
|
163
|
-
def __call__(self, path: PathParts, x: typing.Any):
|
164
|
-
return isinstance(x, self.type) or (
|
165
|
-
hasattr(x, 'type') and issubclass(x.type, self.type)
|
166
|
-
)
|
167
|
-
|
168
|
-
def __repr__(self):
|
169
|
-
return f'OfType({self.type!r})'
|
170
|
-
|
171
|
-
|
172
|
-
class Any:
|
173
|
-
"""
|
174
|
-
A filter class that combines multiple filters using a logical OR operation.
|
175
|
-
|
176
|
-
This class creates a composite filter that returns True if any of its
|
177
|
-
constituent filters return True.
|
178
|
-
|
179
|
-
Attributes:
|
180
|
-
predicates (tuple): A tuple of predicate functions converted from the input filters.
|
181
|
-
"""
|
182
|
-
|
183
|
-
def __init__(self, *filters: Filter):
|
184
|
-
"""
|
185
|
-
Initialize the Any filter with a variable number of filters.
|
186
|
-
|
187
|
-
Args:
|
188
|
-
*filters (Filter): Variable number of filters to be combined.
|
189
|
-
"""
|
190
|
-
self.predicates = tuple(
|
191
|
-
to_predicate(collection_filter) for collection_filter in filters
|
192
|
-
)
|
193
|
-
|
194
|
-
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
195
|
-
"""
|
196
|
-
Apply the composite filter to the given path and object.
|
197
|
-
|
198
|
-
Args:
|
199
|
-
path (PathParts): The path to the current object.
|
200
|
-
x (typing.Any): The object to be filtered.
|
201
|
-
|
202
|
-
Returns:
|
203
|
-
bool: True if any of the constituent predicates return True, False otherwise.
|
204
|
-
"""
|
205
|
-
return any(predicate(path, x) for predicate in self.predicates)
|
206
|
-
|
207
|
-
def __repr__(self) -> str:
|
208
|
-
"""
|
209
|
-
Return a string representation of the Any filter.
|
210
|
-
|
211
|
-
Returns:
|
212
|
-
str: A string representation of the Any filter, including its predicates.
|
213
|
-
"""
|
214
|
-
return f'Any({", ".join(map(repr, self.predicates))})'
|
215
|
-
|
216
|
-
def __eq__(self, other) -> bool:
|
217
|
-
"""
|
218
|
-
Check if this Any filter is equal to another object.
|
219
|
-
|
220
|
-
Args:
|
221
|
-
other: The object to compare with.
|
222
|
-
|
223
|
-
Returns:
|
224
|
-
bool: True if the other object is an Any filter with the same predicates, False otherwise.
|
225
|
-
"""
|
226
|
-
return isinstance(other, Any) and self.predicates == other.predicates
|
227
|
-
|
228
|
-
def __hash__(self) -> int:
|
229
|
-
"""
|
230
|
-
Compute the hash value for this Any filter.
|
231
|
-
|
232
|
-
Returns:
|
233
|
-
int: The hash value of the predicates tuple.
|
234
|
-
"""
|
235
|
-
return hash(self.predicates)
|
236
|
-
|
237
|
-
|
238
|
-
class All:
|
239
|
-
"""
|
240
|
-
A filter class that combines multiple filters using a logical AND operation.
|
241
|
-
|
242
|
-
This class creates a composite filter that returns True only if all of its
|
243
|
-
constituent filters return True.
|
244
|
-
|
245
|
-
Attributes:
|
246
|
-
predicates (tuple): A tuple of predicate functions converted from the input filters.
|
247
|
-
"""
|
248
|
-
|
249
|
-
def __init__(self, *filters: Filter):
|
250
|
-
"""
|
251
|
-
Initialize the All filter with a variable number of filters.
|
252
|
-
|
253
|
-
Args:
|
254
|
-
*filters (Filter): Variable number of filters to be combined.
|
255
|
-
"""
|
256
|
-
self.predicates = tuple(
|
257
|
-
to_predicate(collection_filter) for collection_filter in filters
|
258
|
-
)
|
259
|
-
|
260
|
-
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
261
|
-
"""
|
262
|
-
Apply the composite filter to the given path and object.
|
263
|
-
|
264
|
-
Args:
|
265
|
-
path (PathParts): The path to the current object.
|
266
|
-
x (typing.Any): The object to be filtered.
|
267
|
-
|
268
|
-
Returns:
|
269
|
-
bool: True if all of the constituent predicates return True, False otherwise.
|
270
|
-
"""
|
271
|
-
return all(predicate(path, x) for predicate in self.predicates)
|
272
|
-
|
273
|
-
def __repr__(self) -> str:
|
274
|
-
"""
|
275
|
-
Return a string representation of the All filter.
|
276
|
-
|
277
|
-
Returns:
|
278
|
-
str: A string representation of the All filter, including its predicates.
|
279
|
-
"""
|
280
|
-
return f'All({", ".join(map(repr, self.predicates))})'
|
281
|
-
|
282
|
-
def __eq__(self, other) -> bool:
|
283
|
-
"""
|
284
|
-
Check if this All filter is equal to another object.
|
285
|
-
|
286
|
-
Args:
|
287
|
-
other: The object to compare with.
|
288
|
-
|
289
|
-
Returns:
|
290
|
-
bool: True if the other object is an All filter with the same predicates, False otherwise.
|
291
|
-
"""
|
292
|
-
return isinstance(other, All) and self.predicates == other.predicates
|
293
|
-
|
294
|
-
def __hash__(self) -> int:
|
295
|
-
"""
|
296
|
-
Compute the hash value for this All filter.
|
297
|
-
|
298
|
-
Returns:
|
299
|
-
int: The hash value of the predicates tuple.
|
300
|
-
"""
|
301
|
-
return hash(self.predicates)
|
302
|
-
|
303
|
-
|
304
|
-
class Not:
|
305
|
-
"""
|
306
|
-
A filter class that negates the result of another filter.
|
307
|
-
|
308
|
-
This class creates a new filter that returns the opposite boolean value
|
309
|
-
of the filter it wraps.
|
310
|
-
|
311
|
-
Attributes:
|
312
|
-
predicate (Predicate): The predicate function converted from the input filter.
|
313
|
-
"""
|
314
|
-
|
315
|
-
def __init__(self, collection_filter: Filter, /):
|
316
|
-
"""
|
317
|
-
Initialize the Not filter with another filter.
|
318
|
-
|
319
|
-
Args:
|
320
|
-
collection_filter (Filter): The filter to be negated.
|
321
|
-
"""
|
322
|
-
self.predicate = to_predicate(collection_filter)
|
323
|
-
|
324
|
-
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
325
|
-
"""
|
326
|
-
Apply the negated filter to the given path and object.
|
327
|
-
|
328
|
-
Args:
|
329
|
-
path (PathParts): The path to the current object.
|
330
|
-
x (typing.Any): The object to be filtered.
|
331
|
-
|
332
|
-
Returns:
|
333
|
-
bool: The negation of the result from the wrapped predicate.
|
334
|
-
"""
|
335
|
-
return not self.predicate(path, x)
|
336
|
-
|
337
|
-
def __repr__(self) -> str:
|
338
|
-
"""
|
339
|
-
Return a string representation of the Not filter.
|
340
|
-
|
341
|
-
Returns:
|
342
|
-
str: A string representation of the Not filter, including its predicate.
|
343
|
-
"""
|
344
|
-
return f'Not({self.predicate!r})'
|
345
|
-
|
346
|
-
def __eq__(self, other) -> bool:
|
347
|
-
"""
|
348
|
-
Check if this Not filter is equal to another object.
|
349
|
-
|
350
|
-
Args:
|
351
|
-
other: The object to compare with.
|
352
|
-
|
353
|
-
Returns:
|
354
|
-
bool: True if the other object is a Not filter with the same predicate, False otherwise.
|
355
|
-
"""
|
356
|
-
return isinstance(other, Not) and self.predicate == other.predicate
|
357
|
-
|
358
|
-
def __hash__(self) -> int:
|
359
|
-
"""
|
360
|
-
Compute the hash value for this Not filter.
|
361
|
-
|
362
|
-
Returns:
|
363
|
-
int: The hash value of the predicate.
|
364
|
-
"""
|
365
|
-
return hash(self.predicate)
|
366
|
-
|
367
|
-
|
368
|
-
class Everything:
|
369
|
-
"""
|
370
|
-
A filter class that always returns True for any input.
|
371
|
-
|
372
|
-
This class represents a filter that matches everything, effectively
|
373
|
-
allowing all objects to pass through without any filtering.
|
374
|
-
"""
|
375
|
-
|
376
|
-
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
377
|
-
"""
|
378
|
-
Always return True, regardless of the input.
|
379
|
-
|
380
|
-
Args:
|
381
|
-
path (PathParts): The path to the current object (not used).
|
382
|
-
x (typing.Any): The object to be filtered (not used).
|
383
|
-
|
384
|
-
Returns:
|
385
|
-
bool: Always returns True.
|
386
|
-
"""
|
387
|
-
return True
|
388
|
-
|
389
|
-
def __repr__(self) -> str:
|
390
|
-
"""
|
391
|
-
Return a string representation of the Everything filter.
|
392
|
-
|
393
|
-
Returns:
|
394
|
-
str: The string 'Everything()'.
|
395
|
-
"""
|
396
|
-
return 'Everything()'
|
397
|
-
|
398
|
-
def __eq__(self, other) -> bool:
|
399
|
-
"""
|
400
|
-
Check if this Everything filter is equal to another object.
|
401
|
-
|
402
|
-
Args:
|
403
|
-
other: The object to compare with.
|
404
|
-
|
405
|
-
Returns:
|
406
|
-
bool: True if the other object is an instance of Everything, False otherwise.
|
407
|
-
"""
|
408
|
-
return isinstance(other, Everything)
|
409
|
-
|
410
|
-
def __hash__(self) -> int:
|
411
|
-
"""
|
412
|
-
Compute the hash value for this Everything filter.
|
413
|
-
|
414
|
-
Returns:
|
415
|
-
int: The hash value of the Everything class.
|
416
|
-
"""
|
417
|
-
return hash(Everything)
|
418
|
-
|
419
|
-
|
420
|
-
class Nothing:
|
421
|
-
"""
|
422
|
-
A filter class that always returns False for any input.
|
423
|
-
|
424
|
-
This class represents a filter that matches nothing, effectively
|
425
|
-
filtering out all objects.
|
426
|
-
"""
|
427
|
-
|
428
|
-
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
429
|
-
"""
|
430
|
-
Always return False, regardless of the input.
|
431
|
-
|
432
|
-
Args:
|
433
|
-
path (PathParts): The path to the current object (not used).
|
434
|
-
x (typing.Any): The object to be filtered (not used).
|
435
|
-
|
436
|
-
Returns:
|
437
|
-
bool: Always returns False.
|
438
|
-
"""
|
439
|
-
return False
|
440
|
-
|
441
|
-
def __repr__(self) -> str:
|
442
|
-
"""
|
443
|
-
Return a string representation of the Nothing filter.
|
444
|
-
|
445
|
-
Returns:
|
446
|
-
str: The string 'Nothing()'.
|
447
|
-
"""
|
448
|
-
return 'Nothing()'
|
449
|
-
|
450
|
-
def __eq__(self, other) -> bool:
|
451
|
-
"""
|
452
|
-
Check if this Nothing filter is equal to another object.
|
453
|
-
|
454
|
-
Args:
|
455
|
-
other: The object to compare with.
|
456
|
-
|
457
|
-
Returns:
|
458
|
-
bool: True if the other object is an instance of Nothing, False otherwise.
|
459
|
-
"""
|
460
|
-
return isinstance(other, Nothing)
|
461
|
-
|
462
|
-
def __hash__(self) -> int:
|
463
|
-
"""
|
464
|
-
Compute the hash value for this Nothing filter.
|
465
|
-
|
466
|
-
Returns:
|
467
|
-
int: The hash value of the Nothing class.
|
468
|
-
"""
|
469
|
-
return hash(Nothing)
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
import builtins
|
19
|
+
import dataclasses
|
20
|
+
import typing
|
21
|
+
from typing import TYPE_CHECKING
|
22
|
+
|
23
|
+
from brainstate.typing import Filter, PathParts, Predicate, Key
|
24
|
+
|
25
|
+
if TYPE_CHECKING:
|
26
|
+
ellipsis = builtins.ellipsis
|
27
|
+
else:
|
28
|
+
ellipsis = typing.Any
|
29
|
+
|
30
|
+
__all__ = [
|
31
|
+
'to_predicate',
|
32
|
+
'WithTag',
|
33
|
+
'PathContains',
|
34
|
+
'OfType',
|
35
|
+
'Any',
|
36
|
+
'All',
|
37
|
+
'Nothing',
|
38
|
+
'Not',
|
39
|
+
'Everything',
|
40
|
+
]
|
41
|
+
|
42
|
+
|
43
|
+
def to_predicate(the_filter: Filter) -> Predicate:
|
44
|
+
"""
|
45
|
+
Converts a Filter to a predicate function.
|
46
|
+
|
47
|
+
This function takes various types of filters and converts them into
|
48
|
+
corresponding predicate functions that can be used for filtering.
|
49
|
+
|
50
|
+
Args:
|
51
|
+
the_filter (Filter): The filter to be converted. Can be of various types:
|
52
|
+
- str: Converted to a WithTag filter.
|
53
|
+
- type: Converted to an OfType filter.
|
54
|
+
- bool: True becomes Everything(), False becomes Nothing().
|
55
|
+
- Ellipsis: Converted to Everything().
|
56
|
+
- None: Converted to Nothing().
|
57
|
+
- callable: Returned as-is.
|
58
|
+
- list or tuple: Converted to Any filter with elements as arguments.
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
Predicate: A callable predicate function that can be used for filtering.
|
62
|
+
|
63
|
+
Raises:
|
64
|
+
TypeError: If the input filter is of an invalid type.
|
65
|
+
"""
|
66
|
+
|
67
|
+
if isinstance(the_filter, str):
|
68
|
+
return WithTag(the_filter)
|
69
|
+
elif isinstance(the_filter, type):
|
70
|
+
return OfType(the_filter)
|
71
|
+
elif isinstance(the_filter, bool):
|
72
|
+
if the_filter:
|
73
|
+
return Everything()
|
74
|
+
else:
|
75
|
+
return Nothing()
|
76
|
+
elif the_filter is Ellipsis:
|
77
|
+
return Everything()
|
78
|
+
elif the_filter is None:
|
79
|
+
return Nothing()
|
80
|
+
elif callable(the_filter):
|
81
|
+
return the_filter
|
82
|
+
elif isinstance(the_filter, (list, tuple)):
|
83
|
+
return Any(*the_filter)
|
84
|
+
else:
|
85
|
+
raise TypeError(f'Invalid collection filter: {the_filter!r}. ')
|
86
|
+
|
87
|
+
|
88
|
+
@dataclasses.dataclass(frozen=True)
|
89
|
+
class WithTag:
|
90
|
+
"""
|
91
|
+
A filter class that checks if an object has a specific tag.
|
92
|
+
|
93
|
+
This class is a callable that can be used as a predicate function
|
94
|
+
to filter objects based on their 'tag' attribute.
|
95
|
+
|
96
|
+
Attributes:
|
97
|
+
tag (str): The tag to match against.
|
98
|
+
"""
|
99
|
+
|
100
|
+
tag: str
|
101
|
+
|
102
|
+
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
103
|
+
"""
|
104
|
+
Check if the object has a 'tag' attribute matching the specified tag.
|
105
|
+
|
106
|
+
Args:
|
107
|
+
path (PathParts): The path to the current object (not used in this filter).
|
108
|
+
x (typing.Any): The object to check for the tag.
|
109
|
+
|
110
|
+
Returns:
|
111
|
+
bool: True if the object has a 'tag' attribute matching the specified tag, False otherwise.
|
112
|
+
"""
|
113
|
+
return hasattr(x, 'tag') and x.tag == self.tag
|
114
|
+
|
115
|
+
def __repr__(self) -> str:
|
116
|
+
return f'WithTag({self.tag!r})'
|
117
|
+
|
118
|
+
|
119
|
+
@dataclasses.dataclass(frozen=True)
|
120
|
+
class PathContains:
|
121
|
+
"""
|
122
|
+
A filter class that checks if a given key is present in the path.
|
123
|
+
|
124
|
+
This class is a callable that can be used as a predicate function
|
125
|
+
to filter objects based on whether a specific key is present in their path.
|
126
|
+
|
127
|
+
Attributes:
|
128
|
+
key (Key): The key to search for in the path.
|
129
|
+
"""
|
130
|
+
|
131
|
+
key: Key
|
132
|
+
|
133
|
+
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
134
|
+
"""
|
135
|
+
Check if the key is present in the given path.
|
136
|
+
|
137
|
+
Args:
|
138
|
+
path (PathParts): The path to check for the presence of the key.
|
139
|
+
x (typing.Any): The object associated with the path (not used in this filter).
|
140
|
+
|
141
|
+
Returns:
|
142
|
+
bool: True if the key is present in the path, False otherwise.
|
143
|
+
"""
|
144
|
+
return self.key in path
|
145
|
+
|
146
|
+
def __repr__(self) -> str:
|
147
|
+
return f'PathContains({self.key!r})'
|
148
|
+
|
149
|
+
|
150
|
+
@dataclasses.dataclass(frozen=True)
|
151
|
+
class OfType:
|
152
|
+
"""
|
153
|
+
A filter class that checks if an object is of a specific type.
|
154
|
+
|
155
|
+
This class is a callable that can be used as a predicate function
|
156
|
+
to filter objects based on their type.
|
157
|
+
|
158
|
+
Attributes:
|
159
|
+
type (type): The type to match against.
|
160
|
+
"""
|
161
|
+
type: type
|
162
|
+
|
163
|
+
def __call__(self, path: PathParts, x: typing.Any):
|
164
|
+
return isinstance(x, self.type) or (
|
165
|
+
hasattr(x, 'type') and issubclass(x.type, self.type)
|
166
|
+
)
|
167
|
+
|
168
|
+
def __repr__(self):
|
169
|
+
return f'OfType({self.type!r})'
|
170
|
+
|
171
|
+
|
172
|
+
class Any:
|
173
|
+
"""
|
174
|
+
A filter class that combines multiple filters using a logical OR operation.
|
175
|
+
|
176
|
+
This class creates a composite filter that returns True if any of its
|
177
|
+
constituent filters return True.
|
178
|
+
|
179
|
+
Attributes:
|
180
|
+
predicates (tuple): A tuple of predicate functions converted from the input filters.
|
181
|
+
"""
|
182
|
+
|
183
|
+
def __init__(self, *filters: Filter):
|
184
|
+
"""
|
185
|
+
Initialize the Any filter with a variable number of filters.
|
186
|
+
|
187
|
+
Args:
|
188
|
+
*filters (Filter): Variable number of filters to be combined.
|
189
|
+
"""
|
190
|
+
self.predicates = tuple(
|
191
|
+
to_predicate(collection_filter) for collection_filter in filters
|
192
|
+
)
|
193
|
+
|
194
|
+
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
195
|
+
"""
|
196
|
+
Apply the composite filter to the given path and object.
|
197
|
+
|
198
|
+
Args:
|
199
|
+
path (PathParts): The path to the current object.
|
200
|
+
x (typing.Any): The object to be filtered.
|
201
|
+
|
202
|
+
Returns:
|
203
|
+
bool: True if any of the constituent predicates return True, False otherwise.
|
204
|
+
"""
|
205
|
+
return any(predicate(path, x) for predicate in self.predicates)
|
206
|
+
|
207
|
+
def __repr__(self) -> str:
|
208
|
+
"""
|
209
|
+
Return a string representation of the Any filter.
|
210
|
+
|
211
|
+
Returns:
|
212
|
+
str: A string representation of the Any filter, including its predicates.
|
213
|
+
"""
|
214
|
+
return f'Any({", ".join(map(repr, self.predicates))})'
|
215
|
+
|
216
|
+
def __eq__(self, other) -> bool:
|
217
|
+
"""
|
218
|
+
Check if this Any filter is equal to another object.
|
219
|
+
|
220
|
+
Args:
|
221
|
+
other: The object to compare with.
|
222
|
+
|
223
|
+
Returns:
|
224
|
+
bool: True if the other object is an Any filter with the same predicates, False otherwise.
|
225
|
+
"""
|
226
|
+
return isinstance(other, Any) and self.predicates == other.predicates
|
227
|
+
|
228
|
+
def __hash__(self) -> int:
|
229
|
+
"""
|
230
|
+
Compute the hash value for this Any filter.
|
231
|
+
|
232
|
+
Returns:
|
233
|
+
int: The hash value of the predicates tuple.
|
234
|
+
"""
|
235
|
+
return hash(self.predicates)
|
236
|
+
|
237
|
+
|
238
|
+
class All:
|
239
|
+
"""
|
240
|
+
A filter class that combines multiple filters using a logical AND operation.
|
241
|
+
|
242
|
+
This class creates a composite filter that returns True only if all of its
|
243
|
+
constituent filters return True.
|
244
|
+
|
245
|
+
Attributes:
|
246
|
+
predicates (tuple): A tuple of predicate functions converted from the input filters.
|
247
|
+
"""
|
248
|
+
|
249
|
+
def __init__(self, *filters: Filter):
|
250
|
+
"""
|
251
|
+
Initialize the All filter with a variable number of filters.
|
252
|
+
|
253
|
+
Args:
|
254
|
+
*filters (Filter): Variable number of filters to be combined.
|
255
|
+
"""
|
256
|
+
self.predicates = tuple(
|
257
|
+
to_predicate(collection_filter) for collection_filter in filters
|
258
|
+
)
|
259
|
+
|
260
|
+
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
261
|
+
"""
|
262
|
+
Apply the composite filter to the given path and object.
|
263
|
+
|
264
|
+
Args:
|
265
|
+
path (PathParts): The path to the current object.
|
266
|
+
x (typing.Any): The object to be filtered.
|
267
|
+
|
268
|
+
Returns:
|
269
|
+
bool: True if all of the constituent predicates return True, False otherwise.
|
270
|
+
"""
|
271
|
+
return all(predicate(path, x) for predicate in self.predicates)
|
272
|
+
|
273
|
+
def __repr__(self) -> str:
|
274
|
+
"""
|
275
|
+
Return a string representation of the All filter.
|
276
|
+
|
277
|
+
Returns:
|
278
|
+
str: A string representation of the All filter, including its predicates.
|
279
|
+
"""
|
280
|
+
return f'All({", ".join(map(repr, self.predicates))})'
|
281
|
+
|
282
|
+
def __eq__(self, other) -> bool:
|
283
|
+
"""
|
284
|
+
Check if this All filter is equal to another object.
|
285
|
+
|
286
|
+
Args:
|
287
|
+
other: The object to compare with.
|
288
|
+
|
289
|
+
Returns:
|
290
|
+
bool: True if the other object is an All filter with the same predicates, False otherwise.
|
291
|
+
"""
|
292
|
+
return isinstance(other, All) and self.predicates == other.predicates
|
293
|
+
|
294
|
+
def __hash__(self) -> int:
    """
    Hash this ``All`` filter.

    The hash is derived from the predicate tuple, which keeps it consistent
    with ``__eq__``, which compares those same tuples.

    Returns:
        int: ``hash`` of the stored predicate tuple.
    """
    stored = self.predicates
    return hash(stored)
|
302
|
+
|
303
|
+
|
304
|
+
class Not:
    """
    Filter that inverts the outcome of a single wrapped filter.

    The wrapped filter is normalized with ``to_predicate`` at construction
    time. Calling the ``Not`` instance returns the boolean negation of that
    predicate's result.

    Attributes:
        predicate (Predicate): The predicate obtained from the wrapped filter.
    """

    def __init__(self, collection_filter: Filter, /):
        """
        Wrap a filter so its result is negated.

        Args:
            collection_filter (Filter): The filter whose result is inverted.
        """
        self.predicate = to_predicate(collection_filter)

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Evaluate the wrapped predicate and invert its result.

        Args:
            path (PathParts): The path to the current object.
            x (typing.Any): The object being tested.

        Returns:
            bool: ``not`` applied to the wrapped predicate's result.
        """
        matched = self.predicate(path, x)
        return not matched

    def __repr__(self) -> str:
        """
        Describe this ``Not`` filter.

        Returns:
            str: ``Not(...)`` wrapping the ``repr`` of the wrapped predicate.
        """
        return 'Not(' + repr(self.predicate) + ')'

    def __eq__(self, other) -> bool:
        """
        Compare this ``Not`` filter with another object.

        Args:
            other: The object to compare against.

        Returns:
            bool: False unless ``other`` is a ``Not`` instance. Otherwise, the
            result of comparing the two wrapped predicates.
        """
        if isinstance(other, Not):
            return self.predicate == other.predicate
        return False

    def __hash__(self) -> int:
        """
        Hash this ``Not`` filter via its wrapped predicate.

        Returns:
            int: ``hash`` of the wrapped predicate.
        """
        wrapped = self.predicate
        return hash(wrapped)
|
366
|
+
|
367
|
+
|
368
|
+
class Everything:
    """
    Filter that accepts every input.

    Calling an instance ignores both arguments and always reports a match.
    All ``Everything`` instances compare equal and share one hash.
    """

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Report a match unconditionally.

        Args:
            path (PathParts): Ignored.
            x (typing.Any): Ignored.

        Returns:
            bool: Always True.
        """
        accepted = True
        return accepted

    def __repr__(self) -> str:
        """
        Describe this filter.

        Returns:
            str: The fixed text ``'Everything()'``.
        """
        text = 'Everything()'
        return text

    def __eq__(self, other) -> bool:
        """
        Compare with another object.

        Args:
            other: The object to compare against.

        Returns:
            bool: True exactly when ``other`` is an ``Everything`` instance.
        """
        same_kind = isinstance(other, Everything)
        return same_kind

    def __hash__(self) -> int:
        """
        Hash this filter.

        Returns:
            int: ``hash`` of the ``Everything`` class object, shared by
            every instance.
        """
        cls_obj = Everything
        return hash(cls_obj)
|
418
|
+
|
419
|
+
|
420
|
+
class Nothing:
    """
    Filter that rejects every input.

    Calling an instance ignores both arguments and never reports a match.
    All ``Nothing`` instances compare equal and share one hash.
    """

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Report a non-match unconditionally.

        Args:
            path (PathParts): Ignored.
            x (typing.Any): Ignored.

        Returns:
            bool: Always False.
        """
        accepted = False
        return accepted

    def __repr__(self) -> str:
        """
        Describe this filter.

        Returns:
            str: The fixed text ``'Nothing()'``.
        """
        text = 'Nothing()'
        return text

    def __eq__(self, other) -> bool:
        """
        Compare with another object.

        Args:
            other: The object to compare against.

        Returns:
            bool: True exactly when ``other`` is a ``Nothing`` instance.
        """
        same_kind = isinstance(other, Nothing)
        return same_kind

    def __hash__(self) -> int:
        """
        Hash this filter.

        Returns:
            int: ``hash`` of the ``Nothing`` class object, shared by every
            instance.
        """
        cls_obj = Nothing
        return hash(cls_obj)
|