brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -58
- brainstate/_compatible_import.py +340 -148
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +45 -55
- brainstate/_state.py +1652 -1605
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -563
- brainstate/environ_test.py +1223 -62
- brainstate/graph/__init__.py +22 -29
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1433 -365
- brainstate/mixin_test.py +1017 -77
- brainstate/nn/__init__.py +137 -135
- brainstate/nn/_activations.py +1100 -808
- brainstate/nn/_activations_test.py +354 -331
- brainstate/nn/_collective_ops.py +633 -514
- brainstate/nn/_collective_ops_test.py +774 -43
- brainstate/nn/_common.py +226 -178
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +2010 -501
- brainstate/nn/_conv_test.py +849 -238
- brainstate/nn/_delay.py +575 -588
- brainstate/nn/_delay_test.py +243 -238
- brainstate/nn/_dropout.py +618 -426
- brainstate/nn/_dropout_test.py +477 -100
- brainstate/nn/_dynamics.py +1267 -1343
- brainstate/nn/_dynamics_test.py +67 -78
- brainstate/nn/_elementwise.py +1298 -1119
- brainstate/nn/_elementwise_test.py +830 -169
- brainstate/nn/_embedding.py +408 -58
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
- brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
- brainstate/nn/_exp_euler.py +254 -92
- brainstate/nn/_exp_euler_test.py +377 -35
- brainstate/nn/_linear.py +744 -424
- brainstate/nn/_linear_test.py +475 -107
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +384 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -975
- brainstate/nn/_normalizations_test.py +699 -73
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +2239 -1177
- brainstate/nn/_poolings_test.py +953 -217
- brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +216 -89
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +809 -553
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
- brainstate/random/__init__.py +270 -24
- brainstate/random/_rand_funs.py +3938 -3616
- brainstate/random/_rand_funs_test.py +640 -567
- brainstate/random/_rand_seed.py +675 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1409
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
- brainstate/{augment → transform}/_autograd.py +1025 -778
- brainstate/{augment → transform}/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +220 -220
- brainstate/{compile → transform}/_error_if.py +94 -92
- brainstate/{compile → transform}/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +38 -38
- brainstate/{compile → transform}/_jit.py +399 -346
- brainstate/{compile → transform}/_jit_test.py +143 -143
- brainstate/{compile → transform}/_loop_collect_return.py +675 -536
- brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
- brainstate/{compile → transform}/_loop_no_collection.py +283 -184
- brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +255 -202
- brainstate/{augment → transform}/_random.py +171 -151
- brainstate/{compile → transform}/_unvmap.py +256 -159
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +837 -304
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +27 -50
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +945 -469
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +910 -523
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/util/filter.py
CHANGED
@@ -1,469 +1,945 @@
|
|
1
|
-
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
-
# The credit should go to the Flax authors.
|
3
|
-
#
|
4
|
-
# Copyright 2024 The Flax Authors.
|
5
|
-
#
|
6
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
-
# you may not use this file except in compliance with the License.
|
8
|
-
# You may obtain a copy of the License at
|
9
|
-
#
|
10
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
-
#
|
12
|
-
# Unless required by applicable law or agreed to in writing, software
|
13
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
-
# See the License for the specific language governing permissions and
|
16
|
-
# limitations under the License.
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
return
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
390
|
-
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
"""
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
filtering
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
"""
|
19
|
+
Filter utilities for traversing and selecting objects in nested structures.
|
20
|
+
|
21
|
+
This module provides a flexible filtering system for working with nested data
|
22
|
+
structures in BrainState. It offers various filter classes and utilities to
|
23
|
+
select, match, and transform objects based on their properties, types, or
|
24
|
+
positions within a hierarchical structure.
|
25
|
+
|
26
|
+
Key Features
|
27
|
+
------------
|
28
|
+
- **Type-based filtering**: Select objects by their type or inheritance
|
29
|
+
- **Tag-based filtering**: Filter objects that have specific tags
|
30
|
+
- **Path-based filtering**: Select based on object paths in nested structures
|
31
|
+
- **Logical operations**: Combine filters with AND, OR, and NOT operations
|
32
|
+
- **Flexible conversion**: Convert various inputs to predicate functions
|
33
|
+
|
34
|
+
Filter Types
|
35
|
+
------------
|
36
|
+
The module provides several built-in filter classes:
|
37
|
+
|
38
|
+
- :class:`WithTag`: Filters objects with specific tags
|
39
|
+
- :class:`PathContains`: Filters based on path contents
|
40
|
+
- :class:`OfType`: Filters by object type
|
41
|
+
- :class:`Any`: Logical OR combination of filters
|
42
|
+
- :class:`All`: Logical AND combination of filters
|
43
|
+
- :class:`Not`: Logical negation of a filter
|
44
|
+
- :class:`Everything`: Matches all objects
|
45
|
+
- :class:`Nothing`: Matches no objects
|
46
|
+
|
47
|
+
Examples
|
48
|
+
--------
|
49
|
+
|
50
|
+
.. code-block:: python
|
51
|
+
|
52
|
+
>>> import brainstate as bs
|
53
|
+
>>> from brainstate.util.filter import WithTag, OfType, Any, All, Not
|
54
|
+
>>>
|
55
|
+
>>> # Filter objects with a specific tag
|
56
|
+
>>> tag_filter = WithTag('trainable')
|
57
|
+
>>>
|
58
|
+
>>> # Filter objects of a specific type
|
59
|
+
>>> type_filter = OfType(bs.nn.Linear)
|
60
|
+
>>>
|
61
|
+
>>> # Combine filters with logical operations
|
62
|
+
>>> combined_filter = All(
|
63
|
+
... WithTag('trainable'),
|
64
|
+
... OfType(bs.nn.Linear)
|
65
|
+
... )
|
66
|
+
>>>
|
67
|
+
>>> # Negate a filter
|
68
|
+
>>> not_trainable = Not(WithTag('trainable'))
|
69
|
+
>>>
|
70
|
+
>>> # Use Any for OR operations
|
71
|
+
>>> any_filter = Any(
|
72
|
+
... OfType(bs.nn.Linear),
|
73
|
+
... OfType(bs.nn.Conv)
|
74
|
+
... )
|
75
|
+
|
76
|
+
Using Filters with Tree Operations
|
77
|
+
-----------------------------------
|
78
|
+
|
79
|
+
.. code-block:: python
|
80
|
+
|
81
|
+
>>> import brainstate as bs
|
82
|
+
>>> import jax.tree_util as tree
|
83
|
+
>>> from brainstate.util.filter import to_predicate, WithTag
|
84
|
+
>>>
|
85
|
+
>>> # Create a model with tagged parameters
|
86
|
+
>>> class Model(bs.nn.Module):
|
87
|
+
... def __init__(self):
|
88
|
+
... super().__init__()
|
89
|
+
... self.layer1 = bs.nn.Linear(10, 20)
|
90
|
+
... self.layer1.tag = 'trainable'
|
91
|
+
... self.layer2 = bs.nn.Linear(20, 10)
|
92
|
+
... self.layer2.tag = 'frozen'
|
93
|
+
>>>
|
94
|
+
>>> model = Model()
|
95
|
+
>>>
|
96
|
+
>>> # Filter trainable parameters
|
97
|
+
>>> trainable_filter = to_predicate('trainable')
|
98
|
+
>>>
|
99
|
+
>>> # Apply filter in tree operations
|
100
|
+
>>> def get_trainable_params(model):
|
101
|
+
... return tree.tree_map_with_path(
|
102
|
+
... lambda path, x: x if trainable_filter(path, x) else None,
|
103
|
+
... model
|
104
|
+
... )
|
105
|
+
|
106
|
+
Notes
|
107
|
+
-----
|
108
|
+
This module is adapted from the Flax library and provides similar functionality
|
109
|
+
for filtering and selecting components in neural network models and other
|
110
|
+
hierarchical data structures.
|
111
|
+
|
112
|
+
See Also
|
113
|
+
--------
|
114
|
+
brainstate.tree : Tree manipulation utilities
|
115
|
+
brainstate.typing : Type definitions for filters and predicates
|
116
|
+
|
117
|
+
"""
|
118
|
+
|
119
|
+
import builtins
|
120
|
+
import dataclasses
|
121
|
+
import typing
|
122
|
+
from typing import TYPE_CHECKING
|
123
|
+
|
124
|
+
from brainstate.typing import Filter, PathParts, Predicate, Key
|
125
|
+
|
126
|
+
# ``ellipsis`` is an alias used for annotations: static type checkers bind it
# to ``builtins.ellipsis``, while at runtime it is bound to ``typing.Any``.
if TYPE_CHECKING:
    ellipsis = builtins.ellipsis
else:
    ellipsis = typing.Any

# Public names exported by this module.
__all__ = [
    'to_predicate',
    'WithTag',
    'PathContains',
    'OfType',
    'Any',
    'All',
    'Nothing',
    'Not',
    'Everything',
]
|
142
|
+
|
143
|
+
|
144
|
+
def to_predicate(the_filter: Filter) -> Predicate:
    """
    Convert a filter specification into a predicate callable.

    The specification is inspected in a fixed order and the first matching
    rule decides which predicate is produced.

    Parameters
    ----------
    the_filter : Filter
        The specification to convert. Rules, in the order they are checked:

        - **str**: wrapped as ``WithTag(the_filter)``
        - **type**: wrapped as ``OfType(the_filter)``
        - **bool**: ``True`` becomes :class:`Everything`, ``False`` becomes
          :class:`Nothing`
        - **Ellipsis** (``...``): converted to :class:`Everything`
        - **None**: converted to :class:`Nothing`
        - **callable**: returned unchanged
        - **list or tuple**: expanded into ``Any(*the_filter)``

    Returns
    -------
    Predicate
        A callable invoked as ``predicate(path, obj)``.

    Raises
    ------
    TypeError
        If ``the_filter`` matches none of the rules above.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import to_predicate
        >>>
        >>> # A string becomes a WithTag filter, which inspects the ``tag``
        >>> # *attribute* of the object (not a dictionary key).
        >>> class Param:
        ...     def __init__(self, tag):
        ...         self.tag = tag
        >>> pred = to_predicate('trainable')
        >>> pred([], Param('trainable'))
        True
        >>> pred([], {'tag': 'trainable'})
        False
        >>>
        >>> # A type becomes an OfType filter
        >>> import numpy as np
        >>> pred = to_predicate(np.ndarray)
        >>> pred([], np.array([1, 2, 3]))
        True
        >>>
        >>> # Booleans become Everything / Nothing
        >>> pred_all = to_predicate(True)
        >>> pred_all([], 'anything')
        True
        >>> pred_none = to_predicate(False)
        >>> pred_none([], 'anything')
        False
        >>>
        >>> # A list becomes an Any filter over its elements
        >>> pred = to_predicate(['tag1', 'tag2'])
        >>> # This matches objects tagged either 'tag1' or 'tag2'

    See Also
    --------
    WithTag : Filter for objects with specific tags
    OfType : Filter for objects of specific types
    Any : Logical OR combination of filters
    Everything : Filter that matches all objects
    Nothing : Filter that matches no objects
    """
    if isinstance(the_filter, str):
        return WithTag(the_filter)
    if isinstance(the_filter, type):
        # Classes are callable too, so this test must precede ``callable``.
        return OfType(the_filter)
    if isinstance(the_filter, bool):
        return Everything() if the_filter else Nothing()
    if the_filter is Ellipsis:
        return Everything()
    if the_filter is None:
        return Nothing()
    if callable(the_filter):
        return the_filter
    if isinstance(the_filter, (list, tuple)):
        return Any(*the_filter)
    raise TypeError(f'Invalid collection filter: {the_filter!r}. ')
|
238
|
+
|
239
|
+
|
240
|
+
@dataclasses.dataclass(frozen=True)
class WithTag:
    """
    Predicate that matches objects whose ``tag`` attribute equals a value.

    An object matches only when it exposes a ``tag`` attribute and that
    attribute compares equal to :attr:`tag`. Objects without a ``tag``
    attribute never match.

    Parameters
    ----------
    tag : str
        The tag value an object must carry in order to match.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import WithTag
        >>>
        >>> class Param:
        ...     def __init__(self, tag):
        ...         self.tag = tag
        >>>
        >>> is_trainable = WithTag('trainable')
        >>> is_trainable([], Param('trainable'))
        True
        >>> is_trainable([], Param('frozen'))
        False
        >>> # No ``tag`` attribute at all -> no match
        >>> is_trainable([], object())
        False

    See Also
    --------
    PathContains : Filter based on path contents
    OfType : Filter based on object type
    to_predicate : Convert various inputs to predicates
    """

    tag: str

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Report whether ``x`` carries the configured tag.

        Parameters
        ----------
        path : PathParts
            Location of ``x`` in the enclosing structure; ignored here.
        x : Any
            The object to inspect.

        Returns
        -------
        bool
            ``False`` when ``x`` has no ``tag`` attribute, otherwise the
            result of comparing ``x.tag`` with the configured tag.
        """
        if not hasattr(x, 'tag'):
            return False
        return x.tag == self.tag

    def __repr__(self) -> str:
        return f'WithTag({self.tag!r})'
|
327
|
+
|
328
|
+
|
329
|
+
@dataclasses.dataclass(frozen=True)
class PathContains:
    """
    Predicate that matches when a given key occurs in an object's path.

    The path passed to the predicate is tested with the ``in`` operator, so
    the key matches wherever it appears along the path.

    Parameters
    ----------
    key : Key
        The path component to look for.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import PathContains
        >>>
        >>> has_weight = PathContains('weight')
        >>> has_weight(['model', 'layer1', 'weight'], None)
        True
        >>> has_weight(['model', 'layer1', 'bias'], None)
        False
        >>>
        >>> in_layer2 = PathContains('layer2')
        >>> in_layer2(['model', 'layer2', 'weight'], None)
        True
        >>> in_layer2(['model', 'layer1', 'weight'], None)
        False

    See Also
    --------
    WithTag : Filter based on tag attributes
    OfType : Filter based on object type
    to_predicate : Convert various inputs to predicates
    """

    key: Key

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Report whether the configured key occurs in ``path``.

        Parameters
        ----------
        path : PathParts
            The path to search.
        x : Any
            The object found at ``path``; ignored here.

        Returns
        -------
        bool
            Result of ``key in path``.
        """
        wanted = self.key
        return wanted in path

    def __repr__(self) -> str:
        return f'PathContains({self.key!r})'
|
416
|
+
|
417
|
+
|
418
|
+
@dataclasses.dataclass(frozen=True)
class OfType:
    """
    Predicate that matches objects by type.

    An object matches when it is an instance of :attr:`type`, or when it
    exposes a ``type`` attribute holding a class that is a subclass of
    :attr:`type`.

    Parameters
    ----------
    type : type
        The type to match against.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import OfType
        >>> import numpy as np
        >>>
        >>> is_array = OfType(np.ndarray)
        >>> is_array([], np.array([1, 2, 3]))
        True
        >>> is_array([], [1, 2, 3])
        False
        >>>
        >>> # Objects that advertise a class through a ``type`` attribute
        >>> class Ref:
        ...     def __init__(self, type):
        ...         self.type = type
        >>> OfType(int)([], Ref(bool))
        True
        >>> OfType(int)([], Ref(str))
        False
        >>> # A ``type`` attribute that is not a class never matches
        >>> OfType(int)([], Ref('int'))
        False

    See Also
    --------
    WithTag : Filter based on tag attributes
    PathContains : Filter based on path contents
    to_predicate : Convert various inputs to predicates

    Notes
    -----
    A ``type`` attribute that is not a class (for example a string or
    ``None``) is treated as "no match" instead of being handed to
    ``issubclass``, which would raise ``TypeError``.
    """

    type: type

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Report whether ``x`` is of the configured type.

        Parameters
        ----------
        path : PathParts
            Location of ``x`` in the enclosing structure; ignored here.
        x : Any
            The object to inspect.

        Returns
        -------
        bool
            ``True`` if ``x`` is an instance of the configured type, or if
            ``x.type`` is a class that subclasses it; ``False`` otherwise.
        """
        if isinstance(x, self.type):
            return True
        declared = getattr(x, 'type', None)
        # ``issubclass`` raises TypeError for non-class first arguments, so
        # only consult it when the attribute really is a class.
        return isinstance(declared, type) and issubclass(declared, self.type)

    def __repr__(self) -> str:
        return f'OfType({self.type!r})'
|
507
|
+
|
508
|
+
|
509
|
+
class Any:
    """
    Combine multiple filters using logical OR operation.

    This filter returns True if any of its constituent filters return True.
    It's useful for creating flexible filtering criteria where multiple
    conditions can be satisfied.

    Parameters
    ----------
    *filters : Filter
        Variable number of filters to be combined with OR logic.

    Attributes
    ----------
    predicates : tuple of Predicate
        Tuple of predicate functions converted from the input filters.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import Any, WithTag, OfType
        >>> import numpy as np
        >>>
        >>> # Create a filter that matches either tag
        >>> trainable_or_frozen = Any('trainable', 'frozen')
        >>>
        >>> # Test with objects
        >>> class Param:
        ...     def __init__(self, tag):
        ...         self.tag = tag
        >>>
        >>> trainable = Param('trainable')
        >>> frozen = Param('frozen')
        >>> other = Param('other')
        >>>
        >>> trainable_or_frozen([], trainable)
        True
        >>> trainable_or_frozen([], frozen)
        True
        >>> trainable_or_frozen([], other)
        False
        >>>
        >>> # Combine different filter types
        >>> array_or_list = Any(
        ...     OfType(np.ndarray),
        ...     OfType(list)
        ... )
        >>>
        >>> array_or_list([], np.array([1, 2, 3]))
        True
        >>> array_or_list([], [1, 2, 3])
        True
        >>> array_or_list([], (1, 2, 3))
        False

    See Also
    --------
    All : Logical AND combination of filters
    Not : Logical negation of a filter
    to_predicate : Convert various inputs to predicates

    Notes
    -----
    The Any filter short-circuits evaluation, returning True as soon as
    one of its constituent filters returns True. When constructed with no
    filters it matches nothing: calling it returns False.
    """

    def __init__(self, *filters: Filter):
        """
        Initialize the Any filter.

        Parameters
        ----------
        *filters : Filter
            Variable number of filters to be combined. Each one is passed
            through ``to_predicate`` and stored in ``predicates``.
        """
        self.predicates = tuple(
            to_predicate(collection_filter) for collection_filter in filters
        )

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Apply the composite filter to the given path and object.

        Parameters
        ----------
        path : PathParts
            The path to the current object.
        x : Any
            The object to be filtered.

        Returns
        -------
        bool
            True if any of the constituent predicates return True,
            False otherwise.
        """
        # The generator keeps evaluation lazy so ``any`` can stop at the
        # first predicate that matches.
        return any(predicate(path, x) for predicate in self.predicates)

    def __repr__(self) -> str:
        """
        Return a string representation of the Any filter.

        Returns
        -------
        str
            ``Any(...)`` listing the ``repr`` of every predicate.
        """
        return f'Any({", ".join(map(repr, self.predicates))})'

    def __eq__(self, other) -> bool:
        """
        Check if this Any filter is equal to another object.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        bool
            True if the other object is an Any filter with the same
            predicates, False otherwise.
        """
        return isinstance(other, Any) and self.predicates == other.predicates

    def __hash__(self) -> int:
        """
        Compute the hash value for this Any filter.

        Returns
        -------
        int
            The hash value of the predicates tuple.
        """
        return hash(self.predicates)
|
633
|
+
|
634
|
+
|
635
|
+
class All:
    """
    Combine multiple filters using logical AND operation.

    This class creates a composite filter that returns True only if all of
    its constituent filters return True.

    Parameters
    ----------
    *filters : Filter
        Variable number of filters to be combined with AND logic.

    Attributes
    ----------
    predicates : tuple
        A tuple of predicate functions converted from the input filters.

    Notes
    -----
    Evaluation stops at the first constituent predicate that returns a
    false value. When constructed with no filters it matches everything:
    calling it returns True.
    """

    def __init__(self, *filters: Filter):
        """
        Initialize the All filter with a variable number of filters.

        Parameters
        ----------
        *filters : Filter
            Variable number of filters to be combined. Each one is passed
            through ``to_predicate`` and stored in ``predicates``.
        """
        self.predicates = tuple(
            to_predicate(collection_filter) for collection_filter in filters
        )

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Apply the composite filter to the given path and object.

        Parameters
        ----------
        path : PathParts
            The path to the current object.
        x : Any
            The object to be filtered.

        Returns
        -------
        bool
            True if all of the constituent predicates return True,
            False otherwise.
        """
        # The generator keeps evaluation lazy so ``all`` can stop at the
        # first predicate that rejects the object.
        return all(predicate(path, x) for predicate in self.predicates)

    def __repr__(self) -> str:
        """
        Return a string representation of the All filter.

        Returns
        -------
        str
            ``All(...)`` listing the ``repr`` of every predicate.
        """
        return f'All({", ".join(map(repr, self.predicates))})'

    def __eq__(self, other) -> bool:
        """
        Check if this All filter is equal to another object.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        bool
            True if the other object is an All filter with the same
            predicates, False otherwise.
        """
        return isinstance(other, All) and self.predicates == other.predicates

    def __hash__(self) -> int:
        """
        Compute the hash value for this All filter.

        Returns
        -------
        int
            The hash value of the predicates tuple.
        """
        return hash(self.predicates)
|
699
|
+
|
700
|
+
|
701
|
+
class Not:
    """
    Negate the result of another filter.

    This class creates a new filter that returns the opposite boolean value
    of the filter it wraps.

    Parameters
    ----------
    collection_filter : Filter
        The filter to be negated (positional-only).

    Attributes
    ----------
    predicate : Predicate
        The predicate function converted from the input filter.
    """

    def __init__(self, collection_filter: Filter, /):
        """
        Initialize the Not filter with another filter.

        Parameters
        ----------
        collection_filter : Filter
            The filter to be negated. It is passed through
            ``to_predicate`` and stored in ``predicate``.
        """
        self.predicate = to_predicate(collection_filter)

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Apply the negated filter to the given path and object.

        Parameters
        ----------
        path : PathParts
            The path to the current object.
        x : Any
            The object to be filtered.

        Returns
        -------
        bool
            The negation of the result from the wrapped predicate.
        """
        return not self.predicate(path, x)

    def __repr__(self) -> str:
        """
        Return a string representation of the Not filter.

        Returns
        -------
        str
            ``Not(...)`` wrapping the ``repr`` of the predicate.
        """
        return f'Not({self.predicate!r})'

    def __eq__(self, other) -> bool:
        """
        Check if this Not filter is equal to another object.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        bool
            True if the other object is a Not filter with the same
            predicate, False otherwise.
        """
        return isinstance(other, Not) and self.predicate == other.predicate

    def __hash__(self) -> int:
        """
        Compute the hash value for this Not filter.

        Returns
        -------
        int
            The hash value of the predicate.
        """
        return hash(self.predicate)
|
763
|
+
|
764
|
+
|
765
|
+
class Everything:
    """
    Filter that matches all objects.

    This filter always returns True, effectively disabling filtering.
    It's useful as a default filter or when you want to select everything
    in a structure.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import Everything
        >>>
        >>> # Create a filter that matches everything
        >>> all_filter = Everything()
        >>>
        >>> # Always returns True
        >>> all_filter([], 'any_object')
        True
        >>> all_filter(['some', 'path'], 42)
        True
        >>> all_filter([], None)
        True
        >>>
        >>> # Useful as a default filter
        >>> def process_data(data, selector=None):
        ...     if selector is None:
        ...         selector = Everything()
        ...     # Process all data when no specific filter is provided

    See Also
    --------
    Nothing : Filter that matches no objects
    to_predicate : Convert True to Everything filter

    Notes
    -----
    This filter is equivalent to using ``to_predicate(True)`` or
    ``to_predicate(...)`` (Ellipsis).
    """

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Always return True.

        Parameters
        ----------
        path : PathParts
            The path to the current object (ignored).
        x : Any
            The object to be filtered (ignored).

        Returns
        -------
        bool
            Always returns True.
        """
        return True

    def __repr__(self) -> str:
        """
        Return a string representation of the Everything filter.

        Returns
        -------
        str
            The string 'Everything()'.
        """
        return 'Everything()'

    def __eq__(self, other) -> bool:
        """
        Check if this Everything filter is equal to another object.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        bool
            True if the other object is an instance of Everything,
            False otherwise.
        """
        return isinstance(other, Everything)

    def __hash__(self) -> int:
        """
        Compute the hash value for this Everything filter.

        Returns
        -------
        int
            The hash value of the Everything class, so all instances
            share one hash.
        """
        return hash(Everything)
|
854
|
+
|
855
|
+
|
856
|
+
class Nothing:
    """
    Filter that matches no objects.

    This filter always returns False, effectively filtering out all objects.
    It's useful for disabling selection or creating empty filter results.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import Everything, Nothing
        >>>
        >>> # Create a filter that matches nothing
        >>> none_filter = Nothing()
        >>>
        >>> # Always returns False
        >>> none_filter([], 'any_object')
        False
        >>> none_filter(['some', 'path'], 42)
        False
        >>> none_filter([], None)
        False
        >>>
        >>> # Useful for conditional filtering
        >>> def get_params(model, include_frozen=False):
        ...     if include_frozen:
        ...         frozen_filter = Everything()  # Select every frozen param
        ...     else:
        ...         frozen_filter = Nothing()  # Select no frozen params
        ...     # Apply frozen_filter to the model's frozen parameters

    See Also
    --------
    Everything : Filter that matches all objects
    to_predicate : Convert False or None to Nothing filter

    Notes
    -----
    This filter is equivalent to using ``to_predicate(False)`` or
    ``to_predicate(None)``.
    """

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Always return False.

        Parameters
        ----------
        path : PathParts
            The path to the current object (ignored).
        x : Any
            The object to be filtered (ignored).

        Returns
        -------
        bool
            Always returns False.
        """
        return False

    def __repr__(self) -> str:
        """
        Return a string representation of the Nothing filter.

        Returns
        -------
        str
            The string 'Nothing()'.
        """
        return 'Nothing()'

    def __eq__(self, other) -> bool:
        """
        Check if this Nothing filter is equal to another object.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        bool
            True if the other object is an instance of Nothing,
            False otherwise.
        """
        return isinstance(other, Nothing)

    def __hash__(self) -> int:
        """
        Compute the hash value for this Nothing filter.

        Returns
        -------
        int
            The hash value of the Nothing class, so all instances
            share one hash.
        """
        return hash(Nothing)
|