brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/util/filter.py
CHANGED
@@ -15,6 +15,107 @@
|
|
15
15
|
# See the License for the specific language governing permissions and
|
16
16
|
# limitations under the License.
|
17
17
|
|
18
|
+
"""
|
19
|
+
Filter utilities for traversing and selecting objects in nested structures.
|
20
|
+
|
21
|
+
This module provides a flexible filtering system for working with nested data
|
22
|
+
structures in BrainState. It offers various filter classes and utilities to
|
23
|
+
select, match, and transform objects based on their properties, types, or
|
24
|
+
positions within a hierarchical structure.
|
25
|
+
|
26
|
+
Key Features
|
27
|
+
------------
|
28
|
+
- **Type-based filtering**: Select objects by their type or inheritance
|
29
|
+
- **Tag-based filtering**: Filter objects that have specific tags
|
30
|
+
- **Path-based filtering**: Select based on object paths in nested structures
|
31
|
+
- **Logical operations**: Combine filters with AND, OR, and NOT operations
|
32
|
+
- **Flexible conversion**: Convert various inputs to predicate functions
|
33
|
+
|
34
|
+
Filter Types
|
35
|
+
------------
|
36
|
+
The module provides several built-in filter classes:
|
37
|
+
|
38
|
+
- :class:`WithTag`: Filters objects with specific tags
|
39
|
+
- :class:`PathContains`: Filters based on path contents
|
40
|
+
- :class:`OfType`: Filters by object type
|
41
|
+
- :class:`Any`: Logical OR combination of filters
|
42
|
+
- :class:`All`: Logical AND combination of filters
|
43
|
+
- :class:`Not`: Logical negation of a filter
|
44
|
+
- :class:`Everything`: Matches all objects
|
45
|
+
- :class:`Nothing`: Matches no objects
|
46
|
+
|
47
|
+
Examples
|
48
|
+
--------
|
49
|
+
|
50
|
+
.. code-block:: python
|
51
|
+
|
52
|
+
>>> import brainstate as bs
|
53
|
+
>>> from brainstate.util.filter import WithTag, OfType, Any, All, Not
|
54
|
+
>>>
|
55
|
+
>>> # Filter objects with a specific tag
|
56
|
+
>>> tag_filter = WithTag('trainable')
|
57
|
+
>>>
|
58
|
+
>>> # Filter objects of a specific type
|
59
|
+
>>> type_filter = OfType(bs.nn.Linear)
|
60
|
+
>>>
|
61
|
+
>>> # Combine filters with logical operations
|
62
|
+
>>> combined_filter = All(
|
63
|
+
... WithTag('trainable'),
|
64
|
+
... OfType(bs.nn.Linear)
|
65
|
+
... )
|
66
|
+
>>>
|
67
|
+
>>> # Negate a filter
|
68
|
+
>>> not_trainable = Not(WithTag('trainable'))
|
69
|
+
>>>
|
70
|
+
>>> # Use Any for OR operations
|
71
|
+
>>> any_filter = Any(
|
72
|
+
... OfType(bs.nn.Linear),
|
73
|
+
... OfType(bs.nn.Conv)
|
74
|
+
... )
|
75
|
+
|
76
|
+
Using Filters with Tree Operations
|
77
|
+
-----------------------------------
|
78
|
+
|
79
|
+
.. code-block:: python
|
80
|
+
|
81
|
+
>>> import brainstate as bs
|
82
|
+
>>> import jax.tree_util as tree
|
83
|
+
>>> from brainstate.util.filter import to_predicate, WithTag
|
84
|
+
>>>
|
85
|
+
>>> # Create a model with tagged parameters
|
86
|
+
>>> class Model(bs.nn.Module):
|
87
|
+
... def __init__(self):
|
88
|
+
... super().__init__()
|
89
|
+
... self.layer1 = bs.nn.Linear(10, 20)
|
90
|
+
... self.layer1.tag = 'trainable'
|
91
|
+
... self.layer2 = bs.nn.Linear(20, 10)
|
92
|
+
... self.layer2.tag = 'frozen'
|
93
|
+
>>>
|
94
|
+
>>> model = Model()
|
95
|
+
>>>
|
96
|
+
>>> # Filter trainable parameters
|
97
|
+
>>> trainable_filter = to_predicate('trainable')
|
98
|
+
>>>
|
99
|
+
>>> # Apply filter in tree operations
|
100
|
+
>>> def get_trainable_params(model):
|
101
|
+
... return tree.tree_map_with_path(
|
102
|
+
... lambda path, x: x if trainable_filter(path, x) else None,
|
103
|
+
... model
|
104
|
+
... )
|
105
|
+
|
106
|
+
Notes
|
107
|
+
-----
|
108
|
+
This module is adapted from the Flax library and provides similar functionality
|
109
|
+
for filtering and selecting components in neural network models and other
|
110
|
+
hierarchical data structures.
|
111
|
+
|
112
|
+
See Also
|
113
|
+
--------
|
114
|
+
brainstate.tree : Tree manipulation utilities
|
115
|
+
brainstate.typing : Type definitions for filters and predicates
|
116
|
+
|
117
|
+
"""
|
118
|
+
|
18
119
|
import builtins
|
19
120
|
import dataclasses
|
20
121
|
import typing
|
@@ -42,26 +143,77 @@ __all__ = [
|
|
42
143
|
|
43
144
|
def to_predicate(the_filter: Filter) -> Predicate:
|
44
145
|
"""
|
45
|
-
|
146
|
+
Convert a Filter to a predicate function.
|
46
147
|
|
47
148
|
This function takes various types of filters and converts them into
|
48
|
-
corresponding predicate functions that can be used for filtering
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
149
|
+
corresponding predicate functions that can be used for filtering objects
|
150
|
+
in nested structures.
|
151
|
+
|
152
|
+
Parameters
|
153
|
+
----------
|
154
|
+
the_filter : Filter
|
155
|
+
The filter to be converted. Can be of various types:
|
156
|
+
|
157
|
+
- **str**: Converted to a :class:`WithTag` filter
|
158
|
+
- **type**: Converted to an :class:`OfType` filter
|
159
|
+
- **bool**: ``True`` becomes :class:`Everything`, ``False`` becomes :class:`Nothing`
|
160
|
+
- **Ellipsis** (...): Converted to :class:`Everything`
|
161
|
+
- **None**: Converted to :class:`Nothing`
|
162
|
+
- **callable**: Returned as-is
|
163
|
+
- **list or tuple**: Converted to :class:`Any` filter with elements as arguments
|
164
|
+
|
165
|
+
Returns
|
166
|
+
-------
|
167
|
+
Predicate
|
168
|
+
A callable predicate function that takes (path, object) and returns bool.
|
169
|
+
|
170
|
+
Raises
|
171
|
+
------
|
172
|
+
TypeError
|
173
|
+
If the input filter is of an invalid type.
|
174
|
+
|
175
|
+
Examples
|
176
|
+
--------
|
177
|
+
.. code-block:: python
|
178
|
+
|
179
|
+
>>> from brainstate.util.filter import to_predicate
|
180
|
+
>>>
|
181
|
+
>>> # Convert string to WithTag filter
|
182
|
+
>>> pred = to_predicate('trainable')
|
183
|
+
>>> pred([], type('Param', (), {'tag': 'trainable'})())
|
184
|
+
True
|
185
|
+
>>>
|
186
|
+
>>> # Convert type to OfType filter
|
187
|
+
>>> import numpy as np
|
188
|
+
>>> pred = to_predicate(np.ndarray)
|
189
|
+
>>> pred([], np.array([1, 2, 3]))
|
190
|
+
True
|
191
|
+
>>>
|
192
|
+
>>> # Convert bool to Everything/Nothing
|
193
|
+
>>> pred_all = to_predicate(True)
|
194
|
+
>>> pred_all([], 'anything')
|
195
|
+
True
|
196
|
+
>>> pred_none = to_predicate(False)
|
197
|
+
>>> pred_none([], 'anything')
|
198
|
+
False
|
199
|
+
>>>
|
200
|
+
>>> # Convert list to Any filter
|
201
|
+
>>> pred = to_predicate(['tag1', 'tag2'])
|
202
|
+
>>> # This will match objects with either 'tag1' or 'tag2'
|
203
|
+
|
204
|
+
See Also
|
205
|
+
--------
|
206
|
+
WithTag : Filter for objects with specific tags
|
207
|
+
OfType : Filter for objects of specific types
|
208
|
+
Any : Logical OR combination of filters
|
209
|
+
Everything : Filter that matches all objects
|
210
|
+
Nothing : Filter that matches no objects
|
211
|
+
|
212
|
+
Notes
|
213
|
+
-----
|
214
|
+
This function is the main entry point for creating predicate functions
|
215
|
+
from various filter specifications. It provides a flexible way to define
|
216
|
+
filtering criteria without explicitly instantiating filter classes.
|
65
217
|
"""
|
66
218
|
|
67
219
|
if isinstance(the_filter, str):
|
@@ -88,27 +240,85 @@ def to_predicate(the_filter: Filter) -> Predicate:
|
|
88
240
|
@dataclasses.dataclass(frozen=True)
|
89
241
|
class WithTag:
|
90
242
|
"""
|
91
|
-
|
92
|
-
|
93
|
-
This
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
243
|
+
Filter objects that have a specific tag attribute.
|
244
|
+
|
245
|
+
This filter checks if an object has a 'tag' attribute that matches
|
246
|
+
the specified tag value. It's commonly used to filter parameters or
|
247
|
+
modules in neural networks based on their assigned tags.
|
248
|
+
|
249
|
+
Parameters
|
250
|
+
----------
|
251
|
+
tag : str
|
252
|
+
The tag value to match against.
|
253
|
+
|
254
|
+
Attributes
|
255
|
+
----------
|
256
|
+
tag : str
|
257
|
+
The tag value to match against.
|
258
|
+
|
259
|
+
Examples
|
260
|
+
--------
|
261
|
+
.. code-block:: python
|
262
|
+
|
263
|
+
>>> from brainstate.util.filter import WithTag
|
264
|
+
>>> import brainstate as bs
|
265
|
+
>>>
|
266
|
+
>>> # Create a filter for 'trainable' tag
|
267
|
+
>>> filter_trainable = WithTag('trainable')
|
268
|
+
>>>
|
269
|
+
>>> # Test with an object that has the tag
|
270
|
+
>>> class Param:
|
271
|
+
... def __init__(self, tag):
|
272
|
+
... self.tag = tag
|
273
|
+
>>>
|
274
|
+
>>> param1 = Param('trainable')
|
275
|
+
>>> param2 = Param('frozen')
|
276
|
+
>>>
|
277
|
+
>>> filter_trainable([], param1)
|
278
|
+
True
|
279
|
+
>>> filter_trainable([], param2)
|
280
|
+
False
|
281
|
+
>>>
|
282
|
+
>>> # Use with neural network modules
|
283
|
+
>>> class MyModule(bs.nn.Module):
|
284
|
+
... def __init__(self):
|
285
|
+
... super().__init__()
|
286
|
+
... self.weight = bs.State(bs.random.randn(10, 10))
|
287
|
+
... self.weight.tag = 'trainable'
|
288
|
+
... self.bias = bs.State(bs.zeros(10))
|
289
|
+
... self.bias.tag = 'frozen'
|
290
|
+
|
291
|
+
See Also
|
292
|
+
--------
|
293
|
+
PathContains : Filter based on path contents
|
294
|
+
OfType : Filter based on object type
|
295
|
+
to_predicate : Convert various inputs to predicates
|
296
|
+
|
297
|
+
Notes
|
298
|
+
-----
|
299
|
+
The filter only matches objects that have a 'tag' attribute. Objects
|
300
|
+
without this attribute will not match, even if the filter is looking
|
301
|
+
for a specific tag value.
|
98
302
|
"""
|
99
303
|
|
100
304
|
tag: str
|
101
305
|
|
102
306
|
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
103
307
|
"""
|
104
|
-
Check if the object has a
|
308
|
+
Check if the object has a matching tag.
|
105
309
|
|
106
|
-
|
107
|
-
|
108
|
-
|
310
|
+
Parameters
|
311
|
+
----------
|
312
|
+
path : PathParts
|
313
|
+
The path to the current object (not used in this filter).
|
314
|
+
x : Any
|
315
|
+
The object to check for the tag.
|
109
316
|
|
110
|
-
Returns
|
111
|
-
|
317
|
+
Returns
|
318
|
+
-------
|
319
|
+
bool
|
320
|
+
True if the object has a 'tag' attribute matching the specified tag,
|
321
|
+
False otherwise.
|
112
322
|
"""
|
113
323
|
return hasattr(x, 'tag') and x.tag == self.tag
|
114
324
|
|
@@ -119,27 +329,85 @@ class WithTag:
|
|
119
329
|
@dataclasses.dataclass(frozen=True)
|
120
330
|
class PathContains:
|
121
331
|
"""
|
122
|
-
|
123
|
-
|
124
|
-
This
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
332
|
+
Filter objects based on whether their path contains a specific key.
|
333
|
+
|
334
|
+
This filter checks if a given key appears anywhere in the path to an object
|
335
|
+
within a nested structure. It's useful for selecting objects at specific
|
336
|
+
locations or with specific names in a hierarchy.
|
337
|
+
|
338
|
+
Parameters
|
339
|
+
----------
|
340
|
+
key : Key
|
341
|
+
The key to search for in the path.
|
342
|
+
|
343
|
+
Attributes
|
344
|
+
----------
|
345
|
+
key : Key
|
346
|
+
The key to search for in the path.
|
347
|
+
|
348
|
+
Examples
|
349
|
+
--------
|
350
|
+
.. code-block:: python
|
351
|
+
|
352
|
+
>>> from brainstate.util.filter import PathContains
|
353
|
+
>>>
|
354
|
+
>>> # Create a filter for paths containing 'weight'
|
355
|
+
>>> weight_filter = PathContains('weight')
|
356
|
+
>>>
|
357
|
+
>>> # Test with different paths
|
358
|
+
>>> weight_filter(['model', 'layer1', 'weight'], None)
|
359
|
+
True
|
360
|
+
>>> weight_filter(['model', 'layer1', 'bias'], None)
|
361
|
+
False
|
362
|
+
>>>
|
363
|
+
>>> # Filter for specific layer
|
364
|
+
>>> layer2_filter = PathContains('layer2')
|
365
|
+
>>> layer2_filter(['model', 'layer2', 'weight'], None)
|
366
|
+
True
|
367
|
+
>>> layer2_filter(['model', 'layer1', 'weight'], None)
|
368
|
+
False
|
369
|
+
>>>
|
370
|
+
>>> # Use with nested structures
|
371
|
+
>>> import jax.tree_util as tree
|
372
|
+
>>> nested_dict = {
|
373
|
+
... 'layer1': {'weight': [1, 2, 3], 'bias': [4, 5]},
|
374
|
+
... 'layer2': {'weight': [6, 7, 8], 'bias': [9, 10]}
|
375
|
+
... }
|
376
|
+
>>>
|
377
|
+
>>> # Filter all 'weight' entries
|
378
|
+
>>> def filter_weights(path, value):
|
379
|
+
... return value if weight_filter(path, value) else None
|
380
|
+
|
381
|
+
See Also
|
382
|
+
--------
|
383
|
+
WithTag : Filter based on tag attributes
|
384
|
+
OfType : Filter based on object type
|
385
|
+
to_predicate : Convert various inputs to predicates
|
386
|
+
|
387
|
+
Notes
|
388
|
+
-----
|
389
|
+
The path is typically a sequence of keys representing the location of
|
390
|
+
an object in a nested structure, such as the attribute names leading
|
391
|
+
to a parameter in a neural network model.
|
129
392
|
"""
|
130
393
|
|
131
394
|
key: Key
|
132
395
|
|
133
396
|
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
134
397
|
"""
|
135
|
-
Check if the key is present in the
|
398
|
+
Check if the key is present in the path.
|
136
399
|
|
137
|
-
|
138
|
-
|
139
|
-
|
400
|
+
Parameters
|
401
|
+
----------
|
402
|
+
path : PathParts
|
403
|
+
The path to check for the presence of the key.
|
404
|
+
x : Any
|
405
|
+
The object associated with the path (not used in this filter).
|
140
406
|
|
141
|
-
Returns
|
142
|
-
|
407
|
+
Returns
|
408
|
+
-------
|
409
|
+
bool
|
410
|
+
True if the key is present in the path, False otherwise.
|
143
411
|
"""
|
144
412
|
return self.key in path
|
145
413
|
|
@@ -150,17 +418,86 @@ class PathContains:
|
|
150
418
|
@dataclasses.dataclass(frozen=True)
|
151
419
|
class OfType:
|
152
420
|
"""
|
153
|
-
|
154
|
-
|
155
|
-
This
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
421
|
+
Filter objects based on their type.
|
422
|
+
|
423
|
+
This filter checks if an object is an instance of a specific type or
|
424
|
+
if it has a 'type' attribute that is a subclass of the specified type.
|
425
|
+
It's useful for filtering specific kinds of objects in a nested structure.
|
426
|
+
|
427
|
+
Parameters
|
428
|
+
----------
|
429
|
+
type : type
|
430
|
+
The type to match against.
|
431
|
+
|
432
|
+
Attributes
|
433
|
+
----------
|
434
|
+
type : type
|
435
|
+
The type to match against.
|
436
|
+
|
437
|
+
Examples
|
438
|
+
--------
|
439
|
+
.. code-block:: python
|
440
|
+
|
441
|
+
>>> from brainstate.util.filter import OfType
|
442
|
+
>>> import numpy as np
|
443
|
+
>>> import jax.numpy as jnp
|
444
|
+
>>>
|
445
|
+
>>> # Create a filter for numpy arrays
|
446
|
+
>>> array_filter = OfType(np.ndarray)
|
447
|
+
>>>
|
448
|
+
>>> # Test with different objects
|
449
|
+
>>> array_filter([], np.array([1, 2, 3]))
|
450
|
+
True
|
451
|
+
>>> array_filter([], [1, 2, 3])
|
452
|
+
False
|
453
|
+
>>>
|
454
|
+
>>> # Filter for specific module types
|
455
|
+
>>> import brainstate as bs
|
456
|
+
>>> linear_filter = OfType(bs.nn.Linear)
|
457
|
+
>>>
|
458
|
+
>>> # Use in model filtering
|
459
|
+
>>> class Model(bs.nn.Module):
|
460
|
+
... def __init__(self):
|
461
|
+
... super().__init__()
|
462
|
+
... self.linear1 = bs.nn.Linear(10, 20)
|
463
|
+
... self.linear2 = bs.nn.Linear(20, 10)
|
464
|
+
... self.activation = bs.nn.ReLU()
|
465
|
+
>>>
|
466
|
+
>>> # Filter all Linear layers
|
467
|
+
>>> model = Model()
|
468
|
+
>>> # linear_filter will match linear1 and linear2, not activation
|
469
|
+
|
470
|
+
See Also
|
471
|
+
--------
|
472
|
+
WithTag : Filter based on tag attributes
|
473
|
+
PathContains : Filter based on path contents
|
474
|
+
to_predicate : Convert various inputs to predicates
|
475
|
+
|
476
|
+
Notes
|
477
|
+
-----
|
478
|
+
This filter also checks for objects that have a 'type' attribute,
|
479
|
+
which is useful for wrapped or proxy objects that maintain type
|
480
|
+
information differently.
|
160
481
|
"""
|
161
482
|
type: type
|
162
483
|
|
163
484
|
def __call__(self, path: PathParts, x: typing.Any):
|
485
|
+
"""
|
486
|
+
Check if the object is of the specified type.
|
487
|
+
|
488
|
+
Parameters
|
489
|
+
----------
|
490
|
+
path : PathParts
|
491
|
+
The path to the current object (not used in this filter).
|
492
|
+
x : Any
|
493
|
+
The object to check.
|
494
|
+
|
495
|
+
Returns
|
496
|
+
-------
|
497
|
+
bool
|
498
|
+
True if the object is an instance of the specified type or
|
499
|
+
has a 'type' attribute that is a subclass of the specified type.
|
500
|
+
"""
|
164
501
|
return isinstance(x, self.type) or (
|
165
502
|
hasattr(x, 'type') and issubclass(x.type, self.type)
|
166
503
|
)
|
@@ -171,21 +508,81 @@ class OfType:
|
|
171
508
|
|
172
509
|
class Any:
|
173
510
|
"""
|
174
|
-
|
175
|
-
|
176
|
-
This
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
511
|
+
Combine multiple filters using logical OR operation.
|
512
|
+
|
513
|
+
This filter returns True if any of its constituent filters return True.
|
514
|
+
It's useful for creating flexible filtering criteria where multiple
|
515
|
+
conditions can be satisfied.
|
516
|
+
|
517
|
+
Parameters
|
518
|
+
----------
|
519
|
+
*filters : Filter
|
520
|
+
Variable number of filters to be combined with OR logic.
|
521
|
+
|
522
|
+
Attributes
|
523
|
+
----------
|
524
|
+
predicates : tuple of Predicate
|
525
|
+
Tuple of predicate functions converted from the input filters.
|
526
|
+
|
527
|
+
Examples
|
528
|
+
--------
|
529
|
+
.. code-block:: python
|
530
|
+
|
531
|
+
>>> from brainstate.util.filter import Any, WithTag, OfType
|
532
|
+
>>> import numpy as np
|
533
|
+
>>>
|
534
|
+
>>> # Create a filter that matches either tag
|
535
|
+
>>> trainable_or_frozen = Any('trainable', 'frozen')
|
536
|
+
>>>
|
537
|
+
>>> # Test with objects
|
538
|
+
>>> class Param:
|
539
|
+
... def __init__(self, tag):
|
540
|
+
... self.tag = tag
|
541
|
+
>>>
|
542
|
+
>>> trainable = Param('trainable')
|
543
|
+
>>> frozen = Param('frozen')
|
544
|
+
>>> other = Param('other')
|
545
|
+
>>>
|
546
|
+
>>> trainable_or_frozen([], trainable)
|
547
|
+
True
|
548
|
+
>>> trainable_or_frozen([], frozen)
|
549
|
+
True
|
550
|
+
>>> trainable_or_frozen([], other)
|
551
|
+
False
|
552
|
+
>>>
|
553
|
+
>>> # Combine different filter types
|
554
|
+
>>> array_or_list = Any(
|
555
|
+
... OfType(np.ndarray),
|
556
|
+
... OfType(list)
|
557
|
+
... )
|
558
|
+
>>>
|
559
|
+
>>> array_or_list([], np.array([1, 2, 3]))
|
560
|
+
True
|
561
|
+
>>> array_or_list([], [1, 2, 3])
|
562
|
+
True
|
563
|
+
>>> array_or_list([], (1, 2, 3))
|
564
|
+
False
|
565
|
+
|
566
|
+
See Also
|
567
|
+
--------
|
568
|
+
All : Logical AND combination of filters
|
569
|
+
Not : Logical negation of a filter
|
570
|
+
to_predicate : Convert various inputs to predicates
|
571
|
+
|
572
|
+
Notes
|
573
|
+
-----
|
574
|
+
The Any filter short-circuits evaluation, returning True as soon as
|
575
|
+
one of its constituent filters returns True.
|
181
576
|
"""
|
182
577
|
|
183
578
|
def __init__(self, *filters: Filter):
|
184
579
|
"""
|
185
|
-
Initialize the Any filter
|
580
|
+
Initialize the Any filter.
|
186
581
|
|
187
|
-
|
188
|
-
|
582
|
+
Parameters
|
583
|
+
----------
|
584
|
+
*filters : Filter
|
585
|
+
Variable number of filters to be combined.
|
189
586
|
"""
|
190
587
|
self.predicates = tuple(
|
191
588
|
to_predicate(collection_filter) for collection_filter in filters
|
@@ -367,22 +764,61 @@ class Not:
|
|
367
764
|
|
368
765
|
class Everything:
|
369
766
|
"""
|
370
|
-
|
371
|
-
|
372
|
-
This
|
373
|
-
|
767
|
+
Filter that matches all objects.
|
768
|
+
|
769
|
+
This filter always returns True, effectively disabling filtering.
|
770
|
+
It's useful as a default filter or when you want to select everything
|
771
|
+
in a structure.
|
772
|
+
|
773
|
+
Examples
|
774
|
+
--------
|
775
|
+
.. code-block:: python
|
776
|
+
|
777
|
+
>>> from brainstate.util.filter import Everything
|
778
|
+
>>>
|
779
|
+
>>> # Create a filter that matches everything
|
780
|
+
>>> all_filter = Everything()
|
781
|
+
>>>
|
782
|
+
>>> # Always returns True
|
783
|
+
>>> all_filter([], 'any_object')
|
784
|
+
True
|
785
|
+
>>> all_filter(['some', 'path'], 42)
|
786
|
+
True
|
787
|
+
>>> all_filter([], None)
|
788
|
+
True
|
789
|
+
>>>
|
790
|
+
>>> # Useful as a default filter
|
791
|
+
>>> def process_data(data, filter=None):
|
792
|
+
... if filter is None:
|
793
|
+
... filter = Everything()
|
794
|
+
... # Process all data when no specific filter is provided
|
795
|
+
|
796
|
+
See Also
|
797
|
+
--------
|
798
|
+
Nothing : Filter that matches no objects
|
799
|
+
to_predicate : Convert True to Everything filter
|
800
|
+
|
801
|
+
Notes
|
802
|
+
-----
|
803
|
+
This filter is equivalent to using ``to_predicate(True)`` or
|
804
|
+
``to_predicate(...)`` (Ellipsis).
|
374
805
|
"""
|
375
806
|
|
376
807
|
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
377
808
|
"""
|
378
|
-
Always return True
|
809
|
+
Always return True.
|
379
810
|
|
380
|
-
|
381
|
-
|
382
|
-
|
811
|
+
Parameters
|
812
|
+
----------
|
813
|
+
path : PathParts
|
814
|
+
The path to the current object (ignored).
|
815
|
+
x : Any
|
816
|
+
The object to be filtered (ignored).
|
383
817
|
|
384
|
-
Returns
|
385
|
-
|
818
|
+
Returns
|
819
|
+
-------
|
820
|
+
bool
|
821
|
+
Always returns True.
|
386
822
|
"""
|
387
823
|
return True
|
388
824
|
|
@@ -419,22 +855,62 @@ class Everything:
|
|
419
855
|
|
420
856
|
class Nothing:
|
421
857
|
"""
|
422
|
-
|
423
|
-
|
424
|
-
This
|
425
|
-
|
858
|
+
Filter that matches no objects.
|
859
|
+
|
860
|
+
This filter always returns False, effectively filtering out all objects.
|
861
|
+
It's useful for disabling selection or creating empty filter results.
|
862
|
+
|
863
|
+
Examples
|
864
|
+
--------
|
865
|
+
.. code-block:: python
|
866
|
+
|
867
|
+
>>> from brainstate.util.filter import Nothing
|
868
|
+
>>>
|
869
|
+
>>> # Create a filter that matches nothing
|
870
|
+
>>> none_filter = Nothing()
|
871
|
+
>>>
|
872
|
+
>>> # Always returns False
|
873
|
+
>>> none_filter([], 'any_object')
|
874
|
+
False
|
875
|
+
>>> none_filter(['some', 'path'], 42)
|
876
|
+
False
|
877
|
+
>>> none_filter([], None)
|
878
|
+
False
|
879
|
+
>>>
|
880
|
+
>>> # Useful for conditional filtering
|
881
|
+
>>> def get_params(model, include_frozen=False):
|
882
|
+
... if include_frozen:
|
883
|
+
... filter = Everything()
|
884
|
+
... else:
|
885
|
+
... filter = Nothing() # Matches no objects, so nothing is selected
|
886
|
+
... # Apply filter to model parameters
|
887
|
+
|
888
|
+
See Also
|
889
|
+
--------
|
890
|
+
Everything : Filter that matches all objects
|
891
|
+
to_predicate : Convert False or None to Nothing filter
|
892
|
+
|
893
|
+
Notes
|
894
|
+
-----
|
895
|
+
This filter is equivalent to using ``to_predicate(False)`` or
|
896
|
+
``to_predicate(None)``.
|
426
897
|
"""
|
427
898
|
|
428
899
|
def __call__(self, path: PathParts, x: typing.Any) -> bool:
|
429
900
|
"""
|
430
|
-
Always return False
|
901
|
+
Always return False.
|
431
902
|
|
432
|
-
|
433
|
-
|
434
|
-
|
903
|
+
Parameters
|
904
|
+
----------
|
905
|
+
path : PathParts
|
906
|
+
The path to the current object (ignored).
|
907
|
+
x : Any
|
908
|
+
The object to be filtered (ignored).
|
435
909
|
|
436
|
-
Returns
|
437
|
-
|
910
|
+
Returns
|
911
|
+
-------
|
912
|
+
bool
|
913
|
+
Always returns False.
|
438
914
|
"""
|
439
915
|
return False
|
440
916
|
|