brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/util/filter.py CHANGED
@@ -1,945 +1,945 @@
1
- # The file is adapted from the Flax library (https://github.com/google/flax).
2
- # The credit should go to the Flax authors.
3
- #
4
- # Copyright 2024 The Flax Authors.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
- """
19
- Filter utilities for traversing and selecting objects in nested structures.
20
-
21
- This module provides a flexible filtering system for working with nested data
22
- structures in BrainState. It offers various filter classes and utilities to
23
- select, match, and transform objects based on their properties, types, or
24
- positions within a hierarchical structure.
25
-
26
- Key Features
27
- ------------
28
- - **Type-based filtering**: Select objects by their type or inheritance
29
- - **Tag-based filtering**: Filter objects that have specific tags
30
- - **Path-based filtering**: Select based on object paths in nested structures
31
- - **Logical operations**: Combine filters with AND, OR, and NOT operations
32
- - **Flexible conversion**: Convert various inputs to predicate functions
33
-
34
- Filter Types
35
- ------------
36
- The module provides several built-in filter classes:
37
-
38
- - :class:`WithTag`: Filters objects with specific tags
39
- - :class:`PathContains`: Filters based on path contents
40
- - :class:`OfType`: Filters by object type
41
- - :class:`Any`: Logical OR combination of filters
42
- - :class:`All`: Logical AND combination of filters
43
- - :class:`Not`: Logical negation of a filter
44
- - :class:`Everything`: Matches all objects
45
- - :class:`Nothing`: Matches no objects
46
-
47
- Examples
48
- --------
49
-
50
- .. code-block:: python
51
-
52
- >>> import brainstate as bs
53
- >>> from brainstate.util.filter import WithTag, OfType, Any, All, Not
54
- >>>
55
- >>> # Filter objects with a specific tag
56
- >>> tag_filter = WithTag('trainable')
57
- >>>
58
- >>> # Filter objects of a specific type
59
- >>> type_filter = OfType(bs.nn.Linear)
60
- >>>
61
- >>> # Combine filters with logical operations
62
- >>> combined_filter = All(
63
- ... WithTag('trainable'),
64
- ... OfType(bs.nn.Linear)
65
- ... )
66
- >>>
67
- >>> # Negate a filter
68
- >>> not_trainable = Not(WithTag('trainable'))
69
- >>>
70
- >>> # Use Any for OR operations
71
- >>> any_filter = Any(
72
- ... OfType(bs.nn.Linear),
73
- ... OfType(bs.nn.Conv)
74
- ... )
75
-
76
- Using Filters with Tree Operations
77
- -----------------------------------
78
-
79
- .. code-block:: python
80
-
81
- >>> import brainstate as bs
82
- >>> import jax.tree_util as tree
83
- >>> from brainstate.util.filter import to_predicate, WithTag
84
- >>>
85
- >>> # Create a model with tagged parameters
86
- >>> class Model(bs.Module):
87
- ... def __init__(self):
88
- ... super().__init__()
89
- ... self.layer1 = bs.nn.Linear(10, 20)
90
- ... self.layer1.tag = 'trainable'
91
- ... self.layer2 = bs.nn.Linear(20, 10)
92
- ... self.layer2.tag = 'frozen'
93
- >>>
94
- >>> model = Model()
95
- >>>
96
- >>> # Filter trainable parameters
97
- >>> trainable_filter = to_predicate('trainable')
98
- >>>
99
- >>> # Apply filter in tree operations
100
- >>> def get_trainable_params(model):
101
- ... return tree.tree_map_with_path(
102
- ... lambda path, x: x if trainable_filter(path, x) else None,
103
- ... model
104
- ... )
105
-
106
- Notes
107
- -----
108
- This module is adapted from the Flax library and provides similar functionality
109
- for filtering and selecting components in neural network models and other
110
- hierarchical data structures.
111
-
112
- See Also
113
- --------
114
- brainstate.tree : Tree manipulation utilities
115
- brainstate.typing : Type definitions for filters and predicates
116
-
117
- """
118
-
119
- import builtins
120
- import dataclasses
121
- import typing
122
- from typing import TYPE_CHECKING
123
-
124
- from brainstate.typing import Filter, PathParts, Predicate, Key
125
-
126
- if TYPE_CHECKING:
127
- ellipsis = builtins.ellipsis
128
- else:
129
- ellipsis = typing.Any
130
-
131
- __all__ = [
132
- 'to_predicate',
133
- 'WithTag',
134
- 'PathContains',
135
- 'OfType',
136
- 'Any',
137
- 'All',
138
- 'Nothing',
139
- 'Not',
140
- 'Everything',
141
- ]
142
-
143
-
144
- def to_predicate(the_filter: Filter) -> Predicate:
145
- """
146
- Convert a Filter to a predicate function.
147
-
148
- This function takes various types of filters and converts them into
149
- corresponding predicate functions that can be used for filtering objects
150
- in nested structures.
151
-
152
- Parameters
153
- ----------
154
- the_filter : Filter
155
- The filter to be converted. Can be of various types:
156
-
157
- - **str**: Converted to a :class:`WithTag` filter
158
- - **type**: Converted to an :class:`OfType` filter
159
- - **bool**: ``True`` becomes :class:`Everything`, ``False`` becomes :class:`Nothing`
160
- - **Ellipsis** (...): Converted to :class:`Everything`
161
- - **None**: Converted to :class:`Nothing`
162
- - **callable**: Returned as-is
163
- - **list or tuple**: Converted to :class:`Any` filter with elements as arguments
164
-
165
- Returns
166
- -------
167
- Predicate
168
- A callable predicate function that takes (path, object) and returns bool.
169
-
170
- Raises
171
- ------
172
- TypeError
173
- If the input filter is of an invalid type.
174
-
175
- Examples
176
- --------
177
- .. code-block:: python
178
-
179
- >>> from brainstate.util.filter import to_predicate
180
- >>>
181
- >>> # Convert string to WithTag filter
182
- >>> pred = to_predicate('trainable')
183
- >>> pred([], {'tag': 'trainable'})
184
- True
185
- >>>
186
- >>> # Convert type to OfType filter
187
- >>> import numpy as np
188
- >>> pred = to_predicate(np.ndarray)
189
- >>> pred([], np.array([1, 2, 3]))
190
- True
191
- >>>
192
- >>> # Convert bool to Everything/Nothing
193
- >>> pred_all = to_predicate(True)
194
- >>> pred_all([], 'anything')
195
- True
196
- >>> pred_none = to_predicate(False)
197
- >>> pred_none([], 'anything')
198
- False
199
- >>>
200
- >>> # Convert list to Any filter
201
- >>> pred = to_predicate(['tag1', 'tag2'])
202
- >>> # This will match objects with either 'tag1' or 'tag2'
203
-
204
- See Also
205
- --------
206
- WithTag : Filter for objects with specific tags
207
- OfType : Filter for objects of specific types
208
- Any : Logical OR combination of filters
209
- Everything : Filter that matches all objects
210
- Nothing : Filter that matches no objects
211
-
212
- Notes
213
- -----
214
- This function is the main entry point for creating predicate functions
215
- from various filter specifications. It provides a flexible way to define
216
- filtering criteria without explicitly instantiating filter classes.
217
- """
218
-
219
- if isinstance(the_filter, str):
220
- return WithTag(the_filter)
221
- elif isinstance(the_filter, type):
222
- return OfType(the_filter)
223
- elif isinstance(the_filter, bool):
224
- if the_filter:
225
- return Everything()
226
- else:
227
- return Nothing()
228
- elif the_filter is Ellipsis:
229
- return Everything()
230
- elif the_filter is None:
231
- return Nothing()
232
- elif callable(the_filter):
233
- return the_filter
234
- elif isinstance(the_filter, (list, tuple)):
235
- return Any(*the_filter)
236
- else:
237
- raise TypeError(f'Invalid collection filter: {the_filter!r}. ')
238
-
239
-
240
- @dataclasses.dataclass(frozen=True)
241
- class WithTag:
242
- """
243
- Filter objects that have a specific tag attribute.
244
-
245
- This filter checks if an object has a 'tag' attribute that matches
246
- the specified tag value. It's commonly used to filter parameters or
247
- modules in neural networks based on their assigned tags.
248
-
249
- Parameters
250
- ----------
251
- tag : str
252
- The tag value to match against.
253
-
254
- Attributes
255
- ----------
256
- tag : str
257
- The tag value to match against.
258
-
259
- Examples
260
- --------
261
- .. code-block:: python
262
-
263
- >>> from brainstate.util.filter import WithTag
264
- >>> import brainstate as bs
265
- >>>
266
- >>> # Create a filter for 'trainable' tag
267
- >>> filter_trainable = WithTag('trainable')
268
- >>>
269
- >>> # Test with an object that has the tag
270
- >>> class Param:
271
- ... def __init__(self, tag):
272
- ... self.tag = tag
273
- >>>
274
- >>> param1 = Param('trainable')
275
- >>> param2 = Param('frozen')
276
- >>>
277
- >>> filter_trainable([], param1)
278
- True
279
- >>> filter_trainable([], param2)
280
- False
281
- >>>
282
- >>> # Use with neural network modules
283
- >>> class MyModule(bs.Module):
284
- ... def __init__(self):
285
- ... super().__init__()
286
- ... self.weight = bs.State(bs.random.randn(10, 10))
287
- ... self.weight.tag = 'trainable'
288
- ... self.bias = bs.State(bs.zeros(10))
289
- ... self.bias.tag = 'frozen'
290
-
291
- See Also
292
- --------
293
- PathContains : Filter based on path contents
294
- OfType : Filter based on object type
295
- to_predicate : Convert various inputs to predicates
296
-
297
- Notes
298
- -----
299
- The filter only matches objects that have a 'tag' attribute. Objects
300
- without this attribute will not match, even if the filter is looking
301
- for a specific tag value.
302
- """
303
-
304
- tag: str
305
-
306
- def __call__(self, path: PathParts, x: typing.Any) -> bool:
307
- """
308
- Check if the object has a matching tag.
309
-
310
- Parameters
311
- ----------
312
- path : PathParts
313
- The path to the current object (not used in this filter).
314
- x : Any
315
- The object to check for the tag.
316
-
317
- Returns
318
- -------
319
- bool
320
- True if the object has a 'tag' attribute matching the specified tag,
321
- False otherwise.
322
- """
323
- return hasattr(x, 'tag') and x.tag == self.tag
324
-
325
- def __repr__(self) -> str:
326
- return f'WithTag({self.tag!r})'
327
-
328
-
329
- @dataclasses.dataclass(frozen=True)
330
- class PathContains:
331
- """
332
- Filter objects based on whether their path contains a specific key.
333
-
334
- This filter checks if a given key appears anywhere in the path to an object
335
- within a nested structure. It's useful for selecting objects at specific
336
- locations or with specific names in a hierarchy.
337
-
338
- Parameters
339
- ----------
340
- key : Key
341
- The key to search for in the path.
342
-
343
- Attributes
344
- ----------
345
- key : Key
346
- The key to search for in the path.
347
-
348
- Examples
349
- --------
350
- .. code-block:: python
351
-
352
- >>> from brainstate.util.filter import PathContains
353
- >>>
354
- >>> # Create a filter for paths containing 'weight'
355
- >>> weight_filter = PathContains('weight')
356
- >>>
357
- >>> # Test with different paths
358
- >>> weight_filter(['model', 'layer1', 'weight'], None)
359
- True
360
- >>> weight_filter(['model', 'layer1', 'bias'], None)
361
- False
362
- >>>
363
- >>> # Filter for specific layer
364
- >>> layer2_filter = PathContains('layer2')
365
- >>> layer2_filter(['model', 'layer2', 'weight'], None)
366
- True
367
- >>> layer2_filter(['model', 'layer1', 'weight'], None)
368
- False
369
- >>>
370
- >>> # Use with nested structures
371
- >>> import jax.tree_util as tree
372
- >>> nested_dict = {
373
- ... 'layer1': {'weight': [1, 2, 3], 'bias': [4, 5]},
374
- ... 'layer2': {'weight': [6, 7, 8], 'bias': [9, 10]}
375
- ... }
376
- >>>
377
- >>> # Filter all 'weight' entries
378
- >>> def filter_weights(path, value):
379
- ... return value if weight_filter(path, value) else None
380
-
381
- See Also
382
- --------
383
- WithTag : Filter based on tag attributes
384
- OfType : Filter based on object type
385
- to_predicate : Convert various inputs to predicates
386
-
387
- Notes
388
- -----
389
- The path is typically a sequence of keys representing the location of
390
- an object in a nested structure, such as the attribute names leading
391
- to a parameter in a neural network model.
392
- """
393
-
394
- key: Key
395
-
396
- def __call__(self, path: PathParts, x: typing.Any) -> bool:
397
- """
398
- Check if the key is present in the path.
399
-
400
- Parameters
401
- ----------
402
- path : PathParts
403
- The path to check for the presence of the key.
404
- x : Any
405
- The object associated with the path (not used in this filter).
406
-
407
- Returns
408
- -------
409
- bool
410
- True if the key is present in the path, False otherwise.
411
- """
412
- return self.key in path
413
-
414
- def __repr__(self) -> str:
415
- return f'PathContains({self.key!r})'
416
-
417
-
418
- @dataclasses.dataclass(frozen=True)
419
- class OfType:
420
- """
421
- Filter objects based on their type.
422
-
423
- This filter checks if an object is an instance of a specific type or
424
- if it has a 'type' attribute that is a subclass of the specified type.
425
- It's useful for filtering specific kinds of objects in a nested structure.
426
-
427
- Parameters
428
- ----------
429
- type : type
430
- The type to match against.
431
-
432
- Attributes
433
- ----------
434
- type : type
435
- The type to match against.
436
-
437
- Examples
438
- --------
439
- .. code-block:: python
440
-
441
- >>> from brainstate.util.filter import OfType
442
- >>> import numpy as np
443
- >>> import jax.numpy as jnp
444
- >>>
445
- >>> # Create a filter for numpy arrays
446
- >>> array_filter = OfType(np.ndarray)
447
- >>>
448
- >>> # Test with different objects
449
- >>> array_filter([], np.array([1, 2, 3]))
450
- True
451
- >>> array_filter([], [1, 2, 3])
452
- False
453
- >>>
454
- >>> # Filter for specific module types
455
- >>> import brainstate as bs
456
- >>> linear_filter = OfType(bs.nn.Linear)
457
- >>>
458
- >>> # Use in model filtering
459
- >>> class Model(bs.nn.Module):
460
- ... def __init__(self):
461
- ... super().__init__()
462
- ... self.linear1 = bs.nn.Linear(10, 20)
463
- ... self.linear2 = bs.nn.Linear(20, 10)
464
- ... self.activation = bs.nn.ReLU()
465
- >>>
466
- >>> # Filter all Linear layers
467
- >>> model = Model()
468
- >>> # linear_filter will match linear1 and linear2, not activation
469
-
470
- See Also
471
- --------
472
- WithTag : Filter based on tag attributes
473
- PathContains : Filter based on path contents
474
- to_predicate : Convert various inputs to predicates
475
-
476
- Notes
477
- -----
478
- This filter also checks for objects that have a 'type' attribute,
479
- which is useful for wrapped or proxy objects that maintain type
480
- information differently.
481
- """
482
- type: type
483
-
484
- def __call__(self, path: PathParts, x: typing.Any):
485
- """
486
- Check if the object is of the specified type.
487
-
488
- Parameters
489
- ----------
490
- path : PathParts
491
- The path to the current object (not used in this filter).
492
- x : Any
493
- The object to check.
494
-
495
- Returns
496
- -------
497
- bool
498
- True if the object is an instance of the specified type or
499
- has a 'type' attribute that is a subclass of the specified type.
500
- """
501
- return isinstance(x, self.type) or (
502
- hasattr(x, 'type') and issubclass(x.type, self.type)
503
- )
504
-
505
- def __repr__(self):
506
- return f'OfType({self.type!r})'
507
-
508
-
509
- class Any:
510
- """
511
- Combine multiple filters using logical OR operation.
512
-
513
- This filter returns True if any of its constituent filters return True.
514
- It's useful for creating flexible filtering criteria where multiple
515
- conditions can be satisfied.
516
-
517
- Parameters
518
- ----------
519
- *filters : Filter
520
- Variable number of filters to be combined with OR logic.
521
-
522
- Attributes
523
- ----------
524
- predicates : tuple of Predicate
525
- Tuple of predicate functions converted from the input filters.
526
-
527
- Examples
528
- --------
529
- .. code-block:: python
530
-
531
- >>> from brainstate.util.filter import Any, WithTag, OfType
532
- >>> import numpy as np
533
- >>>
534
- >>> # Create a filter that matches either tag
535
- >>> trainable_or_frozen = Any('trainable', 'frozen')
536
- >>>
537
- >>> # Test with objects
538
- >>> class Param:
539
- ... def __init__(self, tag):
540
- ... self.tag = tag
541
- >>>
542
- >>> trainable = Param('trainable')
543
- >>> frozen = Param('frozen')
544
- >>> other = Param('other')
545
- >>>
546
- >>> trainable_or_frozen([], trainable)
547
- True
548
- >>> trainable_or_frozen([], frozen)
549
- True
550
- >>> trainable_or_frozen([], other)
551
- False
552
- >>>
553
- >>> # Combine different filter types
554
- >>> array_or_list = Any(
555
- ... OfType(np.ndarray),
556
- ... OfType(list)
557
- ... )
558
- >>>
559
- >>> array_or_list([], np.array([1, 2, 3]))
560
- True
561
- >>> array_or_list([], [1, 2, 3])
562
- True
563
- >>> array_or_list([], (1, 2, 3))
564
- False
565
-
566
- See Also
567
- --------
568
- All : Logical AND combination of filters
569
- Not : Logical negation of a filter
570
- to_predicate : Convert various inputs to predicates
571
-
572
- Notes
573
- -----
574
- The Any filter short-circuits evaluation, returning True as soon as
575
- one of its constituent filters returns True.
576
- """
577
-
578
- def __init__(self, *filters: Filter):
579
- """
580
- Initialize the Any filter.
581
-
582
- Parameters
583
- ----------
584
- *filters : Filter
585
- Variable number of filters to be combined.
586
- """
587
- self.predicates = tuple(
588
- to_predicate(collection_filter) for collection_filter in filters
589
- )
590
-
591
- def __call__(self, path: PathParts, x: typing.Any) -> bool:
592
- """
593
- Apply the composite filter to the given path and object.
594
-
595
- Args:
596
- path (PathParts): The path to the current object.
597
- x (typing.Any): The object to be filtered.
598
-
599
- Returns:
600
- bool: True if any of the constituent predicates return True, False otherwise.
601
- """
602
- return any(predicate(path, x) for predicate in self.predicates)
603
-
604
- def __repr__(self) -> str:
605
- """
606
- Return a string representation of the Any filter.
607
-
608
- Returns:
609
- str: A string representation of the Any filter, including its predicates.
610
- """
611
- return f'Any({", ".join(map(repr, self.predicates))})'
612
-
613
- def __eq__(self, other) -> bool:
614
- """
615
- Check if this Any filter is equal to another object.
616
-
617
- Args:
618
- other: The object to compare with.
619
-
620
- Returns:
621
- bool: True if the other object is an Any filter with the same predicates, False otherwise.
622
- """
623
- return isinstance(other, Any) and self.predicates == other.predicates
624
-
625
- def __hash__(self) -> int:
626
- """
627
- Compute the hash value for this Any filter.
628
-
629
- Returns:
630
- int: The hash value of the predicates tuple.
631
- """
632
- return hash(self.predicates)
633
-
634
-
635
- class All:
636
- """
637
- A filter class that combines multiple filters using a logical AND operation.
638
-
639
- This class creates a composite filter that returns True only if all of its
640
- constituent filters return True.
641
-
642
- Attributes:
643
- predicates (tuple): A tuple of predicate functions converted from the input filters.
644
- """
645
-
646
- def __init__(self, *filters: Filter):
647
- """
648
- Initialize the All filter with a variable number of filters.
649
-
650
- Args:
651
- *filters (Filter): Variable number of filters to be combined.
652
- """
653
- self.predicates = tuple(
654
- to_predicate(collection_filter) for collection_filter in filters
655
- )
656
-
657
- def __call__(self, path: PathParts, x: typing.Any) -> bool:
658
- """
659
- Apply the composite filter to the given path and object.
660
-
661
- Args:
662
- path (PathParts): The path to the current object.
663
- x (typing.Any): The object to be filtered.
664
-
665
- Returns:
666
- bool: True if all of the constituent predicates return True, False otherwise.
667
- """
668
- return all(predicate(path, x) for predicate in self.predicates)
669
-
670
- def __repr__(self) -> str:
671
- """
672
- Return a string representation of the All filter.
673
-
674
- Returns:
675
- str: A string representation of the All filter, including its predicates.
676
- """
677
- return f'All({", ".join(map(repr, self.predicates))})'
678
-
679
- def __eq__(self, other) -> bool:
680
- """
681
- Check if this All filter is equal to another object.
682
-
683
- Args:
684
- other: The object to compare with.
685
-
686
- Returns:
687
- bool: True if the other object is an All filter with the same predicates, False otherwise.
688
- """
689
- return isinstance(other, All) and self.predicates == other.predicates
690
-
691
- def __hash__(self) -> int:
692
- """
693
- Compute the hash value for this All filter.
694
-
695
- Returns:
696
- int: The hash value of the predicates tuple.
697
- """
698
- return hash(self.predicates)
699
-
700
-
701
- class Not:
702
- """
703
- A filter class that negates the result of another filter.
704
-
705
- This class creates a new filter that returns the opposite boolean value
706
- of the filter it wraps.
707
-
708
- Attributes:
709
- predicate (Predicate): The predicate function converted from the input filter.
710
- """
711
-
712
- def __init__(self, collection_filter: Filter, /):
713
- """
714
- Initialize the Not filter with another filter.
715
-
716
- Args:
717
- collection_filter (Filter): The filter to be negated.
718
- """
719
- self.predicate = to_predicate(collection_filter)
720
-
721
- def __call__(self, path: PathParts, x: typing.Any) -> bool:
722
- """
723
- Apply the negated filter to the given path and object.
724
-
725
- Args:
726
- path (PathParts): The path to the current object.
727
- x (typing.Any): The object to be filtered.
728
-
729
- Returns:
730
- bool: The negation of the result from the wrapped predicate.
731
- """
732
- return not self.predicate(path, x)
733
-
734
- def __repr__(self) -> str:
735
- """
736
- Return a string representation of the Not filter.
737
-
738
- Returns:
739
- str: A string representation of the Not filter, including its predicate.
740
- """
741
- return f'Not({self.predicate!r})'
742
-
743
- def __eq__(self, other) -> bool:
744
- """
745
- Check if this Not filter is equal to another object.
746
-
747
- Args:
748
- other: The object to compare with.
749
-
750
- Returns:
751
- bool: True if the other object is a Not filter with the same predicate, False otherwise.
752
- """
753
- return isinstance(other, Not) and self.predicate == other.predicate
754
-
755
- def __hash__(self) -> int:
756
- """
757
- Compute the hash value for this Not filter.
758
-
759
- Returns:
760
- int: The hash value of the predicate.
761
- """
762
- return hash(self.predicate)
763
-
764
-
765
- class Everything:
766
- """
767
- Filter that matches all objects.
768
-
769
- This filter always returns True, effectively disabling filtering.
770
- It's useful as a default filter or when you want to select everything
771
- in a structure.
772
-
773
- Examples
774
- --------
775
- .. code-block:: python
776
-
777
- >>> from brainstate.util.filter import Everything
778
- >>>
779
- >>> # Create a filter that matches everything
780
- >>> all_filter = Everything()
781
- >>>
782
- >>> # Always returns True
783
- >>> all_filter([], 'any_object')
784
- True
785
- >>> all_filter(['some', 'path'], 42)
786
- True
787
- >>> all_filter([], None)
788
- True
789
- >>>
790
- >>> # Useful as a default filter
791
- >>> def process_data(data, filter=None):
792
- ... if filter is None:
793
- ... filter = Everything()
794
- ... # Process all data when no specific filter is provided
795
-
796
- See Also
797
- --------
798
- Nothing : Filter that matches no objects
799
- to_predicate : Convert True to Everything filter
800
-
801
- Notes
802
- -----
803
- This filter is equivalent to using ``to_predicate(True)`` or
804
- ``to_predicate(...)`` (Ellipsis).
805
- """
806
-
807
- def __call__(self, path: PathParts, x: typing.Any) -> bool:
808
- """
809
- Always return True.
810
-
811
- Parameters
812
- ----------
813
- path : PathParts
814
- The path to the current object (ignored).
815
- x : Any
816
- The object to be filtered (ignored).
817
-
818
- Returns
819
- -------
820
- bool
821
- Always returns True.
822
- """
823
- return True
824
-
825
- def __repr__(self) -> str:
826
- """
827
- Return a string representation of the Everything filter.
828
-
829
- Returns:
830
- str: The string 'Everything()'.
831
- """
832
- return 'Everything()'
833
-
834
- def __eq__(self, other) -> bool:
835
- """
836
- Check if this Everything filter is equal to another object.
837
-
838
- Args:
839
- other: The object to compare with.
840
-
841
- Returns:
842
- bool: True if the other object is an instance of Everything, False otherwise.
843
- """
844
- return isinstance(other, Everything)
845
-
846
- def __hash__(self) -> int:
847
- """
848
- Compute the hash value for this Everything filter.
849
-
850
- Returns:
851
- int: The hash value of the Everything class.
852
- """
853
- return hash(Everything)
854
-
855
-
856
class Nothing:
    """
    Filter that matches no objects.

    This filter always returns False, effectively filtering out all objects.
    It's useful for disabling selection or creating empty filter results.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import Everything, Nothing
        >>>
        >>> # Create a filter that matches nothing
        >>> none_filter = Nothing()
        >>>
        >>> # Always returns False
        >>> none_filter([], 'any_object')
        False
        >>> none_filter(['some', 'path'], 42)
        False
        >>> none_filter([], None)
        False
        >>>
        >>> # Useful for switching a selection on or off
        >>> def make_selector(select_all):
        ...     # ``Nothing`` rejects every object, ``Everything`` accepts every object
        ...     return Everything() if select_all else Nothing()
        >>> make_selector(False)([], 'weight')
        False
        >>> make_selector(True)([], 'weight')
        True

    See Also
    --------
    Everything : Filter that matches all objects
    to_predicate : Convert False or None to Nothing filter

    Notes
    -----
    This filter is equivalent to using ``to_predicate(False)`` or
    ``to_predicate(None)``. All ``Nothing`` instances compare equal and
    share the same hash.
    """

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Always return False.

        Parameters
        ----------
        path : PathParts
            The path to the current object (ignored).
        x : Any
            The object to be filtered (ignored).

        Returns
        -------
        bool
            Always returns False.
        """
        return False

    def __repr__(self) -> str:
        """
        Return a string representation of the Nothing filter.

        Returns
        -------
        str
            The string ``'Nothing()'``.
        """
        return 'Nothing()'

    def __eq__(self, other) -> bool:
        """
        Check if this Nothing filter is equal to another object.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        bool
            True if ``other`` is an instance of Nothing, False otherwise.
        """
        return isinstance(other, Nothing)

    def __hash__(self) -> int:
        """
        Compute the hash value for this Nothing filter.

        Returns
        -------
        int
            The hash of the ``Nothing`` class itself, so that every instance
            hashes identically (consistent with ``__eq__``).
        """
        return hash(Nothing)
1
+ # The file is adapted from the Flax library (https://github.com/google/flax).
2
+ # The credit should go to the Flax authors.
3
+ #
4
+ # Copyright 2024 The Flax Authors.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """
19
+ Filter utilities for traversing and selecting objects in nested structures.
20
+
21
+ This module provides a flexible filtering system for working with nested data
22
+ structures in BrainState. It offers various filter classes and utilities to
23
+ select, match, and transform objects based on their properties, types, or
24
+ positions within a hierarchical structure.
25
+
26
+ Key Features
27
+ ------------
28
+ - **Type-based filtering**: Select objects by their type or inheritance
29
+ - **Tag-based filtering**: Filter objects that have specific tags
30
+ - **Path-based filtering**: Select based on object paths in nested structures
31
+ - **Logical operations**: Combine filters with AND, OR, and NOT operations
32
+ - **Flexible conversion**: Convert various inputs to predicate functions
33
+
34
+ Filter Types
35
+ ------------
36
+ The module provides several built-in filter classes:
37
+
38
+ - :class:`WithTag`: Filters objects with specific tags
39
+ - :class:`PathContains`: Filters based on path contents
40
+ - :class:`OfType`: Filters by object type
41
+ - :class:`Any`: Logical OR combination of filters
42
+ - :class:`All`: Logical AND combination of filters
43
+ - :class:`Not`: Logical negation of a filter
44
+ - :class:`Everything`: Matches all objects
45
+ - :class:`Nothing`: Matches no objects
46
+
47
+ Examples
48
+ --------
49
+
50
+ .. code-block:: python
51
+
52
+ >>> import brainstate as bs
53
+ >>> from brainstate.util.filter import WithTag, OfType, Any, All, Not
54
+ >>>
55
+ >>> # Filter objects with a specific tag
56
+ >>> tag_filter = WithTag('trainable')
57
+ >>>
58
+ >>> # Filter objects of a specific type
59
+ >>> type_filter = OfType(bs.nn.Linear)
60
+ >>>
61
+ >>> # Combine filters with logical operations
62
+ >>> combined_filter = All(
63
+ ... WithTag('trainable'),
64
+ ... OfType(bs.nn.Linear)
65
+ ... )
66
+ >>>
67
+ >>> # Negate a filter
68
+ >>> not_trainable = Not(WithTag('trainable'))
69
+ >>>
70
+ >>> # Use Any for OR operations
71
+ >>> any_filter = Any(
72
+ ... OfType(bs.nn.Linear),
73
+ ... OfType(bs.nn.Conv)
74
+ ... )
75
+
76
+ Using Filters with Tree Operations
77
+ -----------------------------------
78
+
79
+ .. code-block:: python
80
+
81
+ >>> import brainstate as bs
82
+ >>> import jax.tree_util as tree
83
+ >>> from brainstate.util.filter import to_predicate, WithTag
84
+ >>>
85
+ >>> # Create a model with tagged parameters
86
+ >>> class Model(bs.Module):
87
+ ... def __init__(self):
88
+ ... super().__init__()
89
+ ... self.layer1 = bs.nn.Linear(10, 20)
90
+ ... self.layer1.tag = 'trainable'
91
+ ... self.layer2 = bs.nn.Linear(20, 10)
92
+ ... self.layer2.tag = 'frozen'
93
+ >>>
94
+ >>> model = Model()
95
+ >>>
96
+ >>> # Filter trainable parameters
97
+ >>> trainable_filter = to_predicate('trainable')
98
+ >>>
99
+ >>> # Apply filter in tree operations
100
+ >>> def get_trainable_params(model):
101
+ ... return tree.tree_map_with_path(
102
+ ... lambda path, x: x if trainable_filter(path, x) else None,
103
+ ... model
104
+ ... )
105
+
106
+ Notes
107
+ -----
108
+ This module is adapted from the Flax library and provides similar functionality
109
+ for filtering and selecting components in neural network models and other
110
+ hierarchical data structures.
111
+
112
+ See Also
113
+ --------
114
+ brainstate.tree : Tree manipulation utilities
115
+ brainstate.typing : Type definitions for filters and predicates
116
+
117
+ """
118
+
119
+ import builtins
120
+ import dataclasses
121
+ import typing
122
+ from typing import TYPE_CHECKING
123
+
124
+ from brainstate.typing import Filter, PathParts, Predicate, Key
125
+
126
# ``builtins.ellipsis`` exists only in the type stubs that static checkers
# read; the name is absent at runtime. The real alias is therefore bound only
# under ``TYPE_CHECKING``, and ``typing.Any`` stands in when the module executes.
# NOTE(review): ``ellipsis`` is not referenced anywhere else in this module;
# presumably kept for parity with the Flax original — confirm before removing.
if TYPE_CHECKING:
    ellipsis = builtins.ellipsis
else:
    ellipsis = typing.Any

# Public names exported by ``from brainstate.util.filter import *``.
__all__ = [
    'to_predicate',
    'WithTag',
    'PathContains',
    'OfType',
    'Any',
    'All',
    'Nothing',
    'Not',
    'Everything',
]
142
+
143
+
144
def to_predicate(the_filter: Filter) -> Predicate:
    """
    Convert a Filter to a predicate function.

    This function takes various types of filters and converts them into
    corresponding predicate functions that can be used for filtering objects
    in nested structures.

    Parameters
    ----------
    the_filter : Filter
        The filter to be converted. Can be of various types:

        - **str**: Converted to a :class:`WithTag` filter
        - **type**: Converted to an :class:`OfType` filter
        - **bool**: ``True`` becomes :class:`Everything`, ``False`` becomes :class:`Nothing`
        - **Ellipsis** (...): Converted to :class:`Everything`
        - **None**: Converted to :class:`Nothing`
        - **callable**: Returned as-is
        - **list or tuple**: Converted to :class:`Any` filter with elements as arguments

    Returns
    -------
    Predicate
        A callable predicate function that takes (path, object) and returns bool.

    Raises
    ------
    TypeError
        If the input filter is of an invalid type.

    Examples
    --------
    .. code-block:: python

        >>> from types import SimpleNamespace
        >>> from brainstate.util.filter import to_predicate
        >>>
        >>> # Convert string to WithTag filter; it matches on the ``tag`` *attribute*
        >>> pred = to_predicate('trainable')
        >>> pred([], SimpleNamespace(tag='trainable'))
        True
        >>> pred([], {'tag': 'trainable'})  # dict keys are not attributes
        False
        >>>
        >>> # Convert type to OfType filter
        >>> import numpy as np
        >>> pred = to_predicate(np.ndarray)
        >>> pred([], np.array([1, 2, 3]))
        True
        >>>
        >>> # Convert bool to Everything/Nothing
        >>> pred_all = to_predicate(True)
        >>> pred_all([], 'anything')
        True
        >>> pred_none = to_predicate(False)
        >>> pred_none([], 'anything')
        False
        >>>
        >>> # Convert list to Any filter
        >>> pred = to_predicate(['tag1', 'tag2'])
        >>> # This will match objects whose ``tag`` attribute is 'tag1' or 'tag2'

    See Also
    --------
    WithTag : Filter for objects with specific tags
    OfType : Filter for objects of specific types
    Any : Logical OR combination of filters
    Everything : Filter that matches all objects
    Nothing : Filter that matches no objects

    Notes
    -----
    This function is the main entry point for creating predicate functions
    from various filter specifications. It provides a flexible way to define
    filtering criteria without explicitly instantiating filter classes.
    """

    # Branch order matters: classes are callable, so the ``type`` branch must
    # come before the generic ``callable`` branch; otherwise a class would be
    # returned unchanged and later invoked as ``cls(path, x)``.
    if isinstance(the_filter, str):
        return WithTag(the_filter)
    elif isinstance(the_filter, type):
        return OfType(the_filter)
    elif isinstance(the_filter, bool):
        if the_filter:
            return Everything()
        else:
            return Nothing()
    elif the_filter is Ellipsis:
        return Everything()
    elif the_filter is None:
        return Nothing()
    elif callable(the_filter):
        # Already a predicate (including filter instances such as WithTag).
        return the_filter
    elif isinstance(the_filter, (list, tuple)):
        # A sequence of filters means "match any of them"; elements are
        # converted recursively by ``Any``.
        return Any(*the_filter)
    else:
        raise TypeError(f'Invalid collection filter: {the_filter!r}. ')
238
+
239
+
240
@dataclasses.dataclass(frozen=True)
class WithTag:
    """
    Predicate that selects objects carrying a given ``tag`` attribute.

    An object matches when it exposes an attribute named ``tag`` whose value
    compares equal to :attr:`tag`. The path argument is never consulted, and
    objects lacking a ``tag`` attribute are always rejected.

    Parameters
    ----------
    tag : str
        Value that a matching object's ``tag`` attribute must equal.

    Attributes
    ----------
    tag : str
        Value that a matching object's ``tag`` attribute must equal.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import WithTag
        >>>
        >>> class Param:
        ...     def __init__(self, tag):
        ...         self.tag = tag
        >>>
        >>> is_trainable = WithTag('trainable')
        >>> is_trainable([], Param('trainable'))
        True
        >>> is_trainable([], Param('frozen'))
        False
        >>> is_trainable([], object())  # no ``tag`` attribute at all
        False

    See Also
    --------
    PathContains : Filter based on path contents
    OfType : Filter based on object type
    to_predicate : Convert various inputs to predicates

    Notes
    -----
    ``to_predicate('name')`` builds exactly ``WithTag('name')``. Because the
    class is a frozen dataclass, two instances with the same tag compare
    equal and hash alike.
    """

    tag: str

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Test whether ``x`` carries the configured tag.

        Parameters
        ----------
        path : PathParts
            Location of ``x`` in the enclosing structure (unused).
        x : Any
            Candidate object.

        Returns
        -------
        bool
            False when ``x`` has no ``tag`` attribute; otherwise the result
            of comparing ``x.tag`` with :attr:`tag`.
        """
        if not hasattr(x, 'tag'):
            # Untagged objects can never match.
            return False
        return x.tag == self.tag

    def __repr__(self) -> str:
        return f'WithTag({self.tag!r})'
327
+
328
+
329
@dataclasses.dataclass(frozen=True)
class PathContains:
    """
    Predicate that selects objects whose path includes a given key.

    The check is plain membership (``key in path``), so the key may occur at
    any depth of the path. The object itself is never inspected.

    Parameters
    ----------
    key : Key
        Path component to look for.

    Attributes
    ----------
    key : Key
        Path component to look for.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import PathContains
        >>>
        >>> in_weight = PathContains('weight')
        >>> in_weight(['model', 'layer1', 'weight'], None)
        True
        >>> in_weight(['model', 'layer1', 'bias'], None)
        False
        >>>
        >>> # The key may appear at any position, not just the end
        >>> in_layer2 = PathContains('layer2')
        >>> in_layer2(['model', 'layer2', 'weight'], None)
        True
        >>> in_layer2(['model', 'layer1', 'weight'], None)
        False

    See Also
    --------
    WithTag : Filter based on tag attributes
    OfType : Filter based on object type
    to_predicate : Convert various inputs to predicates

    Notes
    -----
    A path is the sequence of keys (attribute names, dict keys, indices, ...)
    leading from the root of a nested structure to an object.
    """

    key: Key

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Test whether :attr:`key` occurs in ``path``.

        Parameters
        ----------
        path : PathParts
            Sequence of keys leading to ``x``.
        x : Any
            The object at ``path`` (unused).

        Returns
        -------
        bool
            True when :attr:`key` is one of the components of ``path``.
        """
        key_present = self.key in path
        return key_present

    def __repr__(self) -> str:
        return f'PathContains({self.key!r})'
416
+
417
+
418
@dataclasses.dataclass(frozen=True)
class OfType:
    """
    Filter objects based on their type.

    This filter checks if an object is an instance of a specific type or
    if it has a 'type' attribute that is a subclass of the specified type.
    It's useful for filtering specific kinds of objects in a nested structure.

    Parameters
    ----------
    type : type
        The type to match against.

    Attributes
    ----------
    type : type
        The type to match against.

    Examples
    --------
    .. code-block:: python

        >>> from types import SimpleNamespace
        >>> from brainstate.util.filter import OfType
        >>> import numpy as np
        >>>
        >>> # Create a filter for numpy arrays
        >>> array_filter = OfType(np.ndarray)
        >>>
        >>> # Test with different objects
        >>> array_filter([], np.array([1, 2, 3]))
        True
        >>> array_filter([], [1, 2, 3])
        False
        >>>
        >>> # Proxy objects may advertise their type through a ``type`` attribute
        >>> array_filter([], SimpleNamespace(type=np.ndarray))
        True
        >>> # A ``type`` attribute that is not a class simply does not match
        >>> array_filter([], SimpleNamespace(type='ndarray'))
        False
        >>>
        >>> # Filter for specific module types
        >>> import brainstate as bs
        >>> linear_filter = OfType(bs.nn.Linear)
        >>> # linear_filter matches Linear layers and nothing else

    See Also
    --------
    WithTag : Filter based on tag attributes
    PathContains : Filter based on path contents
    to_predicate : Convert various inputs to predicates

    Notes
    -----
    This filter also checks for objects that have a 'type' attribute,
    which is useful for wrapped or proxy objects that maintain type
    information differently. Only a ``type`` attribute that is itself a
    class is considered; any other value (a string, ``None``, ...) is
    ignored rather than raising ``TypeError``.
    """
    type: type

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Check if the object is of the specified type.

        Parameters
        ----------
        path : PathParts
            The path to the current object (not used in this filter).
        x : Any
            The object to check.

        Returns
        -------
        bool
            True if the object is an instance of the specified type or
            has a 'type' attribute that is a subclass of the specified type.
        """
        if isinstance(x, self.type):
            return True
        declared = getattr(x, 'type', None)
        # ``issubclass`` raises TypeError for non-class arguments, and many
        # objects expose a ``type`` attribute holding e.g. a string; such
        # objects must simply not match. ``builtins.type`` is spelled out
        # because the dataclass field is also called ``type``.
        return isinstance(declared, builtins.type) and issubclass(declared, self.type)

    def __repr__(self) -> str:
        return f'OfType({self.type!r})'
507
+
508
+
509
class Any:
    """
    Logical OR over several filters.

    Every positional argument is normalized with :func:`to_predicate`, so
    strings, types, booleans, ``None``, ``...``, callables and nested
    lists/tuples are all accepted. The combined predicate is true as soon as
    one constituent predicate is true.

    Parameters
    ----------
    *filters : Filter
        Filters to combine. With no filters the result is always False.

    Attributes
    ----------
    predicates : tuple of Predicate
        The normalized predicates, in the order they were given.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import Any, OfType
        >>> import numpy as np
        >>>
        >>> class Param:
        ...     def __init__(self, tag):
        ...         self.tag = tag
        >>>
        >>> trainable_or_frozen = Any('trainable', 'frozen')
        >>> trainable_or_frozen([], Param('trainable'))
        True
        >>> trainable_or_frozen([], Param('frozen'))
        True
        >>> trainable_or_frozen([], Param('other'))
        False
        >>>
        >>> array_or_list = Any(OfType(np.ndarray), OfType(list))
        >>> array_or_list([], np.array([1, 2, 3]))
        True
        >>> array_or_list([], [1, 2, 3])
        True
        >>> array_or_list([], (1, 2, 3))
        False

    See Also
    --------
    All : Logical AND combination of filters
    Not : Logical negation of a filter
    to_predicate : Convert various inputs to predicates

    Notes
    -----
    Evaluation stops at the first predicate that returns a truthy value, so
    later predicates are not called. Two ``Any`` objects compare equal, and
    hash alike, when their ``predicates`` tuples are equal.
    """

    def __init__(self, *filters: Filter):
        """
        Normalize and store the filters to be OR-ed together.

        Parameters
        ----------
        *filters : Filter
            Filters to combine; each is passed through :func:`to_predicate`.
        """
        self.predicates = tuple(map(to_predicate, filters))

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Evaluate the disjunction for one object.

        Parameters
        ----------
        path : PathParts
            Location of ``x`` in the enclosing structure.
        x : Any
            Candidate object.

        Returns
        -------
        bool
            True if at least one constituent predicate accepts ``(path, x)``.
        """
        for predicate in self.predicates:
            if predicate(path, x):
                return True
        return False

    def __repr__(self) -> str:
        """
        Render the filter together with its constituent predicates.

        Returns
        -------
        str
            Text of the form ``Any(p1, p2, ...)``.
        """
        parts = ', '.join(repr(predicate) for predicate in self.predicates)
        return f'Any({parts})'

    def __eq__(self, other) -> bool:
        """
        Compare with another object.

        Parameters
        ----------
        other : object
            Object to compare against.

        Returns
        -------
        bool
            True only for another ``Any`` holding equal predicates.
        """
        if not isinstance(other, Any):
            return False
        return self.predicates == other.predicates

    def __hash__(self) -> int:
        """
        Hash consistently with ``__eq__``.

        Returns
        -------
        int
            Hash of the ``predicates`` tuple.
        """
        return hash(self.predicates)
633
+
634
+
635
class All:
    """
    Logical AND over several filters.

    Every positional argument is normalized with :func:`to_predicate`. The
    combined predicate is true only when every constituent predicate is true.

    Parameters
    ----------
    *filters : Filter
        Filters to combine. With no filters the result is always True.

    Attributes
    ----------
    predicates : tuple of Predicate
        The normalized predicates, in the order they were given.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import All, PathContains
        >>>
        >>> class Param:
        ...     def __init__(self, tag):
        ...         self.tag = tag
        >>>
        >>> trainable_weight = All('trainable', PathContains('weight'))
        >>> trainable_weight(['layer', 'weight'], Param('trainable'))
        True
        >>> trainable_weight(['layer', 'bias'], Param('trainable'))
        False

    See Also
    --------
    Any : Logical OR combination of filters
    Not : Logical negation of a filter

    Notes
    -----
    Evaluation stops at the first predicate that returns a falsy value.
    Two ``All`` objects compare equal, and hash alike, when their
    ``predicates`` tuples are equal.
    """

    def __init__(self, *filters: Filter):
        """
        Normalize and store the filters to be AND-ed together.

        Parameters
        ----------
        *filters : Filter
            Filters to combine; each is passed through :func:`to_predicate`.
        """
        self.predicates = tuple(map(to_predicate, filters))

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Evaluate the conjunction for one object.

        Parameters
        ----------
        path : PathParts
            Location of ``x`` in the enclosing structure.
        x : Any
            Candidate object.

        Returns
        -------
        bool
            True if every constituent predicate accepts ``(path, x)``.
        """
        for predicate in self.predicates:
            if not predicate(path, x):
                return False
        return True

    def __repr__(self) -> str:
        """
        Render the filter together with its constituent predicates.

        Returns
        -------
        str
            Text of the form ``All(p1, p2, ...)``.
        """
        parts = ', '.join(repr(predicate) for predicate in self.predicates)
        return f'All({parts})'

    def __eq__(self, other) -> bool:
        """
        Compare with another object.

        Parameters
        ----------
        other : object
            Object to compare against.

        Returns
        -------
        bool
            True only for another ``All`` holding equal predicates.
        """
        if not isinstance(other, All):
            return False
        return self.predicates == other.predicates

    def __hash__(self) -> int:
        """
        Hash consistently with ``__eq__``.

        Returns
        -------
        int
            Hash of the ``predicates`` tuple.
        """
        return hash(self.predicates)
699
+
700
+
701
class Not:
    """
    Logical negation of a single filter.

    The wrapped filter is normalized with :func:`to_predicate`; the resulting
    predicate accepts exactly the objects the wrapped one rejects.

    Parameters
    ----------
    collection_filter : Filter
        Filter to negate (positional-only).

    Attributes
    ----------
    predicate : Predicate
        The normalized predicate being negated.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import Not
        >>>
        >>> class Param:
        ...     def __init__(self, tag):
        ...         self.tag = tag
        >>>
        >>> not_frozen = Not('frozen')
        >>> not_frozen([], Param('frozen'))
        False
        >>> not_frozen([], Param('trainable'))
        True

    See Also
    --------
    Any : Logical OR combination of filters
    All : Logical AND combination of filters
    """

    def __init__(self, collection_filter: Filter, /):
        """
        Normalize and store the filter to negate.

        Parameters
        ----------
        collection_filter : Filter
            Filter to negate; passed through :func:`to_predicate`.
        """
        self.predicate = to_predicate(collection_filter)

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Evaluate the negated predicate for one object.

        Parameters
        ----------
        path : PathParts
            Location of ``x`` in the enclosing structure.
        x : Any
            Candidate object.

        Returns
        -------
        bool
            False when the wrapped predicate accepts ``(path, x)``, True otherwise.
        """
        if self.predicate(path, x):
            return False
        return True

    def __repr__(self) -> str:
        """
        Render the filter together with the predicate it negates.

        Returns
        -------
        str
            Text of the form ``Not(p)``.
        """
        return f'Not({self.predicate!r})'

    def __eq__(self, other) -> bool:
        """
        Compare with another object.

        Parameters
        ----------
        other : object
            Object to compare against.

        Returns
        -------
        bool
            True only for another ``Not`` wrapping an equal predicate.
        """
        if not isinstance(other, Not):
            return False
        return self.predicate == other.predicate

    def __hash__(self) -> int:
        """
        Hash consistently with ``__eq__``.

        Returns
        -------
        int
            Hash of the wrapped predicate.
        """
        return hash(self.predicate)
763
+
764
+
765
class Everything:
    """
    Filter that accepts every object.

    Calling it always yields True regardless of the path or the object, which
    makes it a natural "no filtering" default.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import Everything
        >>>
        >>> accept_all = Everything()
        >>> accept_all([], 'any_object')
        True
        >>> accept_all(['some', 'path'], 42)
        True
        >>> accept_all([], None)
        True
        >>>
        >>> # As a default when the caller supplies no predicate
        >>> def process_data(data, predicate=None):
        ...     if predicate is None:
        ...         predicate = Everything()
        ...     return [item for item in data if predicate([], item)]
        >>> process_data([1, 2, 3])
        [1, 2, 3]

    See Also
    --------
    Nothing : Filter that matches no objects
    to_predicate : Convert True to Everything filter

    Notes
    -----
    ``to_predicate(True)`` and ``to_predicate(...)`` (Ellipsis) both produce
    an ``Everything`` instance. All instances compare equal and share one hash.
    """

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Accept the object unconditionally.

        Parameters
        ----------
        path : PathParts
            Location of ``x`` (ignored).
        x : Any
            Candidate object (ignored).

        Returns
        -------
        bool
            Always True.
        """
        return True

    def __repr__(self) -> str:
        """
        Render the filter.

        Returns
        -------
        str
            The text ``'Everything()'``.
        """
        return 'Everything()'

    def __eq__(self, other) -> bool:
        """
        Compare with another object.

        Parameters
        ----------
        other : object
            Object to compare against.

        Returns
        -------
        bool
            True exactly when ``other`` is an ``Everything`` instance.
        """
        return isinstance(other, Everything)

    def __hash__(self) -> int:
        """
        Hash consistently with ``__eq__``.

        Returns
        -------
        int
            Hash of the ``Everything`` class, shared by every instance.
        """
        return hash(Everything)
854
+
855
+
856
class Nothing:
    """
    Filter that matches no objects.

    This filter always returns False, effectively filtering out all objects.
    It's useful for disabling selection or creating empty filter results.

    Examples
    --------
    .. code-block:: python

        >>> from brainstate.util.filter import Everything, Nothing
        >>>
        >>> # Create a filter that matches nothing
        >>> none_filter = Nothing()
        >>>
        >>> # Always returns False
        >>> none_filter([], 'any_object')
        False
        >>> none_filter(['some', 'path'], 42)
        False
        >>> none_filter([], None)
        False
        >>>
        >>> # Useful for switching a selection on or off
        >>> def make_selector(select_all):
        ...     # ``Nothing`` rejects every object, ``Everything`` accepts every object
        ...     return Everything() if select_all else Nothing()
        >>> make_selector(False)([], 'weight')
        False
        >>> make_selector(True)([], 'weight')
        True

    See Also
    --------
    Everything : Filter that matches all objects
    to_predicate : Convert False or None to Nothing filter

    Notes
    -----
    This filter is equivalent to using ``to_predicate(False)`` or
    ``to_predicate(None)``. All ``Nothing`` instances compare equal and
    share the same hash.
    """

    def __call__(self, path: PathParts, x: typing.Any) -> bool:
        """
        Always return False.

        Parameters
        ----------
        path : PathParts
            The path to the current object (ignored).
        x : Any
            The object to be filtered (ignored).

        Returns
        -------
        bool
            Always returns False.
        """
        return False

    def __repr__(self) -> str:
        """
        Return a string representation of the Nothing filter.

        Returns
        -------
        str
            The string ``'Nothing()'``.
        """
        return 'Nothing()'

    def __eq__(self, other) -> bool:
        """
        Check if this Nothing filter is equal to another object.

        Parameters
        ----------
        other : object
            The object to compare with.

        Returns
        -------
        bool
            True if ``other`` is an instance of Nothing, False otherwise.
        """
        return isinstance(other, Nothing)

    def __hash__(self) -> int:
        """
        Compute the hash value for this Nothing filter.

        Returns
        -------
        int
            The hash of the ``Nothing`` class itself, so that every instance
            hashes identically (consistent with ``__eq__``).
        """
        return hash(Nothing)