brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/util/filter.py CHANGED
@@ -15,6 +15,107 @@
15
15
  # See the License for the specific language governing permissions and
16
16
  # limitations under the License.
17
17
 
18
+ """
19
+ Filter utilities for traversing and selecting objects in nested structures.
20
+
21
+ This module provides a flexible filtering system for working with nested data
22
+ structures in BrainState. It offers various filter classes and utilities to
23
+ select, match, and transform objects based on their properties, types, or
24
+ positions within a hierarchical structure.
25
+
26
+ Key Features
27
+ ------------
28
+ - **Type-based filtering**: Select objects by their type or inheritance
29
+ - **Tag-based filtering**: Filter objects that have specific tags
30
+ - **Path-based filtering**: Select based on object paths in nested structures
31
+ - **Logical operations**: Combine filters with AND, OR, and NOT operations
32
+ - **Flexible conversion**: Convert various inputs to predicate functions
33
+
34
+ Filter Types
35
+ ------------
36
+ The module provides several built-in filter classes:
37
+
38
+ - :class:`WithTag`: Filters objects with specific tags
39
+ - :class:`PathContains`: Filters based on path contents
40
+ - :class:`OfType`: Filters by object type
41
+ - :class:`Any`: Logical OR combination of filters
42
+ - :class:`All`: Logical AND combination of filters
43
+ - :class:`Not`: Logical negation of a filter
44
+ - :class:`Everything`: Matches all objects
45
+ - :class:`Nothing`: Matches no objects
46
+
47
+ Examples
48
+ --------
49
+
50
+ .. code-block:: python
51
+
52
+ >>> import brainstate as bs
53
+ >>> from brainstate.util.filter import WithTag, OfType, Any, All, Not
54
+ >>>
55
+ >>> # Filter objects with a specific tag
56
+ >>> tag_filter = WithTag('trainable')
57
+ >>>
58
+ >>> # Filter objects of a specific type
59
+ >>> type_filter = OfType(bs.nn.Linear)
60
+ >>>
61
+ >>> # Combine filters with logical operations
62
+ >>> combined_filter = All(
63
+ ... WithTag('trainable'),
64
+ ... OfType(bs.nn.Linear)
65
+ ... )
66
+ >>>
67
+ >>> # Negate a filter
68
+ >>> not_trainable = Not(WithTag('trainable'))
69
+ >>>
70
+ >>> # Use Any for OR operations
71
+ >>> any_filter = Any(
72
+ ... OfType(bs.nn.Linear),
73
+ ... OfType(bs.nn.Conv)
74
+ ... )
75
+
76
+ Using Filters with Tree Operations
77
+ -----------------------------------
78
+
79
+ .. code-block:: python
80
+
81
+ >>> import brainstate as bs
82
+ >>> import jax.tree_util as tree
83
+ >>> from brainstate.util.filter import to_predicate, WithTag
84
+ >>>
85
+ >>> # Create a model with tagged parameters
86
+ >>> class Model(bs.Module):
87
+ ... def __init__(self):
88
+ ... super().__init__()
89
+ ... self.layer1 = bs.nn.Linear(10, 20)
90
+ ... self.layer1.tag = 'trainable'
91
+ ... self.layer2 = bs.nn.Linear(20, 10)
92
+ ... self.layer2.tag = 'frozen'
93
+ >>>
94
+ >>> model = Model()
95
+ >>>
96
+ >>> # Filter trainable parameters
97
+ >>> trainable_filter = to_predicate('trainable')
98
+ >>>
99
+ >>> # Apply filter in tree operations
100
+ >>> def get_trainable_params(model):
101
+ ... return tree.tree_map_with_path(
102
+ ... lambda path, x: x if trainable_filter(path, x) else None,
103
+ ... model
104
+ ... )
105
+
106
+ Notes
107
+ -----
108
+ This module is adapted from the Flax library and provides similar functionality
109
+ for filtering and selecting components in neural network models and other
110
+ hierarchical data structures.
111
+
112
+ See Also
113
+ --------
114
+ brainstate.tree : Tree manipulation utilities
115
+ brainstate.typing : Type definitions for filters and predicates
116
+
117
+ """
118
+
18
119
  import builtins
19
120
  import dataclasses
20
121
  import typing
@@ -42,26 +143,77 @@ __all__ = [
42
143
 
43
144
  def to_predicate(the_filter: Filter) -> Predicate:
44
145
  """
45
- Converts a Filter to a predicate function.
146
+ Convert a Filter to a predicate function.
46
147
 
47
148
  This function takes various types of filters and converts them into
48
- corresponding predicate functions that can be used for filtering.
49
-
50
- Args:
51
- the_filter (Filter): The filter to be converted. Can be of various types:
52
- - str: Converted to a WithTag filter.
53
- - type: Converted to an OfType filter.
54
- - bool: True becomes Everything(), False becomes Nothing().
55
- - Ellipsis: Converted to Everything().
56
- - None: Converted to Nothing().
57
- - callable: Returned as-is.
58
- - list or tuple: Converted to Any filter with elements as arguments.
59
-
60
- Returns:
61
- Predicate: A callable predicate function that can be used for filtering.
62
-
63
- Raises:
64
- TypeError: If the input filter is of an invalid type.
149
+ corresponding predicate functions that can be used for filtering objects
150
+ in nested structures.
151
+
152
+ Parameters
153
+ ----------
154
+ the_filter : Filter
155
+ The filter to be converted. Can be of various types:
156
+
157
+ - **str**: Converted to a :class:`WithTag` filter
158
+ - **type**: Converted to an :class:`OfType` filter
159
+ - **bool**: ``True`` becomes :class:`Everything`, ``False`` becomes :class:`Nothing`
160
+ - **Ellipsis** (...): Converted to :class:`Everything`
161
+ - **None**: Converted to :class:`Nothing`
162
+ - **callable**: Returned as-is
163
+ - **list or tuple**: Converted to :class:`Any` filter with elements as arguments
164
+
165
+ Returns
166
+ -------
167
+ Predicate
168
+ A callable predicate function that takes (path, object) and returns bool.
169
+
170
+ Raises
171
+ ------
172
+ TypeError
173
+ If the input filter is of an invalid type.
174
+
175
+ Examples
176
+ --------
177
+ .. code-block:: python
178
+
179
+ >>> from brainstate.util.filter import to_predicate
180
+ >>>
181
+ >>> # Convert string to WithTag filter
182
+ >>> pred = to_predicate('trainable')
183
+ >>> pred([], type('Param', (), {'tag': 'trainable'})())  # needs a real `tag` attribute; a dict key does not count
184
+ True
185
+ >>>
186
+ >>> # Convert type to OfType filter
187
+ >>> import numpy as np
188
+ >>> pred = to_predicate(np.ndarray)
189
+ >>> pred([], np.array([1, 2, 3]))
190
+ True
191
+ >>>
192
+ >>> # Convert bool to Everything/Nothing
193
+ >>> pred_all = to_predicate(True)
194
+ >>> pred_all([], 'anything')
195
+ True
196
+ >>> pred_none = to_predicate(False)
197
+ >>> pred_none([], 'anything')
198
+ False
199
+ >>>
200
+ >>> # Convert list to Any filter
201
+ >>> pred = to_predicate(['tag1', 'tag2'])
202
+ >>> # This will match objects with either 'tag1' or 'tag2'
203
+
204
+ See Also
205
+ --------
206
+ WithTag : Filter for objects with specific tags
207
+ OfType : Filter for objects of specific types
208
+ Any : Logical OR combination of filters
209
+ Everything : Filter that matches all objects
210
+ Nothing : Filter that matches no objects
211
+
212
+ Notes
213
+ -----
214
+ This function is the main entry point for creating predicate functions
215
+ from various filter specifications. It provides a flexible way to define
216
+ filtering criteria without explicitly instantiating filter classes.
65
217
  """
66
218
 
67
219
  if isinstance(the_filter, str):
@@ -88,27 +240,85 @@ def to_predicate(the_filter: Filter) -> Predicate:
88
240
  @dataclasses.dataclass(frozen=True)
89
241
  class WithTag:
90
242
  """
91
- A filter class that checks if an object has a specific tag.
92
-
93
- This class is a callable that can be used as a predicate function
94
- to filter objects based on their 'tag' attribute.
95
-
96
- Attributes:
97
- tag (str): The tag to match against.
243
+ Filter objects that have a specific tag attribute.
244
+
245
+ This filter checks if an object has a 'tag' attribute that matches
246
+ the specified tag value. It's commonly used to filter parameters or
247
+ modules in neural networks based on their assigned tags.
248
+
249
+ Parameters
250
+ ----------
251
+ tag : str
252
+ The tag value to match against.
253
+
254
+ Attributes
255
+ ----------
256
+ tag : str
257
+ The tag value to match against.
258
+
259
+ Examples
260
+ --------
261
+ .. code-block:: python
262
+
263
+ >>> from brainstate.util.filter import WithTag
264
+ >>> import brainstate as bs
265
+ >>>
266
+ >>> # Create a filter for 'trainable' tag
267
+ >>> filter_trainable = WithTag('trainable')
268
+ >>>
269
+ >>> # Test with an object that has the tag
270
+ >>> class Param:
271
+ ... def __init__(self, tag):
272
+ ... self.tag = tag
273
+ >>>
274
+ >>> param1 = Param('trainable')
275
+ >>> param2 = Param('frozen')
276
+ >>>
277
+ >>> filter_trainable([], param1)
278
+ True
279
+ >>> filter_trainable([], param2)
280
+ False
281
+ >>>
282
+ >>> # Use with neural network modules
283
+ >>> class MyModule(bs.Module):
284
+ ... def __init__(self):
285
+ ... super().__init__()
286
+ ... self.weight = bs.State(bs.random.randn(10, 10))
287
+ ... self.weight.tag = 'trainable'
288
+ ... self.bias = bs.State(bs.zeros(10))
289
+ ... self.bias.tag = 'frozen'
290
+
291
+ See Also
292
+ --------
293
+ PathContains : Filter based on path contents
294
+ OfType : Filter based on object type
295
+ to_predicate : Convert various inputs to predicates
296
+
297
+ Notes
298
+ -----
299
+ The filter only matches objects that have a 'tag' attribute. Objects
300
+ without this attribute will not match, even if the filter is looking
301
+ for a specific tag value.
98
302
  """
99
303
 
100
304
  tag: str
101
305
 
102
306
  def __call__(self, path: PathParts, x: typing.Any) -> bool:
103
307
  """
104
- Check if the object has a 'tag' attribute matching the specified tag.
308
+ Check if the object has a matching tag.
105
309
 
106
- Args:
107
- path (PathParts): The path to the current object (not used in this filter).
108
- x (typing.Any): The object to check for the tag.
310
+ Parameters
311
+ ----------
312
+ path : PathParts
313
+ The path to the current object (not used in this filter).
314
+ x : Any
315
+ The object to check for the tag.
109
316
 
110
- Returns:
111
- bool: True if the object has a 'tag' attribute matching the specified tag, False otherwise.
317
+ Returns
318
+ -------
319
+ bool
320
+ True if the object has a 'tag' attribute matching the specified tag,
321
+ False otherwise.
112
322
  """
113
323
  return hasattr(x, 'tag') and x.tag == self.tag
114
324
 
@@ -119,27 +329,85 @@ class WithTag:
119
329
  @dataclasses.dataclass(frozen=True)
120
330
  class PathContains:
121
331
  """
122
- A filter class that checks if a given key is present in the path.
123
-
124
- This class is a callable that can be used as a predicate function
125
- to filter objects based on whether a specific key is present in their path.
126
-
127
- Attributes:
128
- key (Key): The key to search for in the path.
332
+ Filter objects based on whether their path contains a specific key.
333
+
334
+ This filter checks if a given key appears anywhere in the path to an object
335
+ within a nested structure. It's useful for selecting objects at specific
336
+ locations or with specific names in a hierarchy.
337
+
338
+ Parameters
339
+ ----------
340
+ key : Key
341
+ The key to search for in the path.
342
+
343
+ Attributes
344
+ ----------
345
+ key : Key
346
+ The key to search for in the path.
347
+
348
+ Examples
349
+ --------
350
+ .. code-block:: python
351
+
352
+ >>> from brainstate.util.filter import PathContains
353
+ >>>
354
+ >>> # Create a filter for paths containing 'weight'
355
+ >>> weight_filter = PathContains('weight')
356
+ >>>
357
+ >>> # Test with different paths
358
+ >>> weight_filter(['model', 'layer1', 'weight'], None)
359
+ True
360
+ >>> weight_filter(['model', 'layer1', 'bias'], None)
361
+ False
362
+ >>>
363
+ >>> # Filter for specific layer
364
+ >>> layer2_filter = PathContains('layer2')
365
+ >>> layer2_filter(['model', 'layer2', 'weight'], None)
366
+ True
367
+ >>> layer2_filter(['model', 'layer1', 'weight'], None)
368
+ False
369
+ >>>
370
+ >>> # Use with nested structures
371
+ >>> import jax.tree_util as tree
372
+ >>> nested_dict = {
373
+ ... 'layer1': {'weight': [1, 2, 3], 'bias': [4, 5]},
374
+ ... 'layer2': {'weight': [6, 7, 8], 'bias': [9, 10]}
375
+ ... }
376
+ >>>
377
+ >>> # Filter all 'weight' entries
378
+ >>> def filter_weights(path, value):
379
+ ... return value if weight_filter(path, value) else None
380
+
381
+ See Also
382
+ --------
383
+ WithTag : Filter based on tag attributes
384
+ OfType : Filter based on object type
385
+ to_predicate : Convert various inputs to predicates
386
+
387
+ Notes
388
+ -----
389
+ The path is typically a sequence of keys representing the location of
390
+ an object in a nested structure, such as the attribute names leading
391
+ to a parameter in a neural network model.
129
392
  """
130
393
 
131
394
  key: Key
132
395
 
133
396
  def __call__(self, path: PathParts, x: typing.Any) -> bool:
134
397
  """
135
- Check if the key is present in the given path.
398
+ Check if the key is present in the path.
136
399
 
137
- Args:
138
- path (PathParts): The path to check for the presence of the key.
139
- x (typing.Any): The object associated with the path (not used in this filter).
400
+ Parameters
401
+ ----------
402
+ path : PathParts
403
+ The path to check for the presence of the key.
404
+ x : Any
405
+ The object associated with the path (not used in this filter).
140
406
 
141
- Returns:
142
- bool: True if the key is present in the path, False otherwise.
407
+ Returns
408
+ -------
409
+ bool
410
+ True if the key is present in the path, False otherwise.
143
411
  """
144
412
  return self.key in path
145
413
 
@@ -150,17 +418,86 @@ class PathContains:
150
418
  @dataclasses.dataclass(frozen=True)
151
419
  class OfType:
152
420
  """
153
- A filter class that checks if an object is of a specific type.
154
-
155
- This class is a callable that can be used as a predicate function
156
- to filter objects based on their type.
157
-
158
- Attributes:
159
- type (type): The type to match against.
421
+ Filter objects based on their type.
422
+
423
+ This filter checks if an object is an instance of a specific type or
424
+ if it has a 'type' attribute that is a subclass of the specified type.
425
+ It's useful for filtering specific kinds of objects in a nested structure.
426
+
427
+ Parameters
428
+ ----------
429
+ type : type
430
+ The type to match against.
431
+
432
+ Attributes
433
+ ----------
434
+ type : type
435
+ The type to match against.
436
+
437
+ Examples
438
+ --------
439
+ .. code-block:: python
440
+
441
+ >>> from brainstate.util.filter import OfType
442
+ >>> import numpy as np
443
+ >>> import jax.numpy as jnp
444
+ >>>
445
+ >>> # Create a filter for numpy arrays
446
+ >>> array_filter = OfType(np.ndarray)
447
+ >>>
448
+ >>> # Test with different objects
449
+ >>> array_filter([], np.array([1, 2, 3]))
450
+ True
451
+ >>> array_filter([], [1, 2, 3])
452
+ False
453
+ >>>
454
+ >>> # Filter for specific module types
455
+ >>> import brainstate as bs
456
+ >>> linear_filter = OfType(bs.nn.Linear)
457
+ >>>
458
+ >>> # Use in model filtering
459
+ >>> class Model(bs.nn.Module):
460
+ ... def __init__(self):
461
+ ... super().__init__()
462
+ ... self.linear1 = bs.nn.Linear(10, 20)
463
+ ... self.linear2 = bs.nn.Linear(20, 10)
464
+ ... self.activation = bs.nn.ReLU()
465
+ >>>
466
+ >>> # Filter all Linear layers
467
+ >>> model = Model()
468
+ >>> # linear_filter will match linear1 and linear2, not activation
469
+
470
+ See Also
471
+ --------
472
+ WithTag : Filter based on tag attributes
473
+ PathContains : Filter based on path contents
474
+ to_predicate : Convert various inputs to predicates
475
+
476
+ Notes
477
+ -----
478
+ This filter also checks for objects that have a 'type' attribute,
479
+ which is useful for wrapped or proxy objects that maintain type
480
+ information differently.
160
481
  """
161
482
  type: type
162
483
 
163
484
  def __call__(self, path: PathParts, x: typing.Any):
485
+ """
486
+ Check if the object is of the specified type.
487
+
488
+ Parameters
489
+ ----------
490
+ path : PathParts
491
+ The path to the current object (not used in this filter).
492
+ x : Any
493
+ The object to check.
494
+
495
+ Returns
496
+ -------
497
+ bool
498
+ True if the object is an instance of the specified type or
499
+ has a 'type' attribute that is a subclass of the specified type.
500
+ """
164
501
  return isinstance(x, self.type) or (
165
502
  hasattr(x, 'type') and issubclass(x.type, self.type)
166
503
  )
@@ -171,21 +508,81 @@ class OfType:
171
508
 
172
509
  class Any:
173
510
  """
174
- A filter class that combines multiple filters using a logical OR operation.
175
-
176
- This class creates a composite filter that returns True if any of its
177
- constituent filters return True.
178
-
179
- Attributes:
180
- predicates (tuple): A tuple of predicate functions converted from the input filters.
511
+ Combine multiple filters using logical OR operation.
512
+
513
+ This filter returns True if any of its constituent filters return True.
514
+ It's useful for creating flexible filtering criteria where multiple
515
+ conditions can be satisfied.
516
+
517
+ Parameters
518
+ ----------
519
+ *filters : Filter
520
+ Variable number of filters to be combined with OR logic.
521
+
522
+ Attributes
523
+ ----------
524
+ predicates : tuple of Predicate
525
+ Tuple of predicate functions converted from the input filters.
526
+
527
+ Examples
528
+ --------
529
+ .. code-block:: python
530
+
531
+ >>> from brainstate.util.filter import Any, WithTag, OfType
532
+ >>> import numpy as np
533
+ >>>
534
+ >>> # Create a filter that matches either tag
535
+ >>> trainable_or_frozen = Any('trainable', 'frozen')
536
+ >>>
537
+ >>> # Test with objects
538
+ >>> class Param:
539
+ ... def __init__(self, tag):
540
+ ... self.tag = tag
541
+ >>>
542
+ >>> trainable = Param('trainable')
543
+ >>> frozen = Param('frozen')
544
+ >>> other = Param('other')
545
+ >>>
546
+ >>> trainable_or_frozen([], trainable)
547
+ True
548
+ >>> trainable_or_frozen([], frozen)
549
+ True
550
+ >>> trainable_or_frozen([], other)
551
+ False
552
+ >>>
553
+ >>> # Combine different filter types
554
+ >>> array_or_list = Any(
555
+ ... OfType(np.ndarray),
556
+ ... OfType(list)
557
+ ... )
558
+ >>>
559
+ >>> array_or_list([], np.array([1, 2, 3]))
560
+ True
561
+ >>> array_or_list([], [1, 2, 3])
562
+ True
563
+ >>> array_or_list([], (1, 2, 3))
564
+ False
565
+
566
+ See Also
567
+ --------
568
+ All : Logical AND combination of filters
569
+ Not : Logical negation of a filter
570
+ to_predicate : Convert various inputs to predicates
571
+
572
+ Notes
573
+ -----
574
+ The Any filter short-circuits evaluation, returning True as soon as
575
+ one of its constituent filters returns True.
181
576
  """
182
577
 
183
578
  def __init__(self, *filters: Filter):
184
579
  """
185
- Initialize the Any filter with a variable number of filters.
580
+ Initialize the Any filter.
186
581
 
187
- Args:
188
- *filters (Filter): Variable number of filters to be combined.
582
+ Parameters
583
+ ----------
584
+ *filters : Filter
585
+ Variable number of filters to be combined.
189
586
  """
190
587
  self.predicates = tuple(
191
588
  to_predicate(collection_filter) for collection_filter in filters
@@ -367,22 +764,61 @@ class Not:
367
764
 
368
765
  class Everything:
369
766
  """
370
- A filter class that always returns True for any input.
371
-
372
- This class represents a filter that matches everything, effectively
373
- allowing all objects to pass through without any filtering.
767
+ Filter that matches all objects.
768
+
769
+ This filter always returns True, effectively disabling filtering.
770
+ It's useful as a default filter or when you want to select everything
771
+ in a structure.
772
+
773
+ Examples
774
+ --------
775
+ .. code-block:: python
776
+
777
+ >>> from brainstate.util.filter import Everything
778
+ >>>
779
+ >>> # Create a filter that matches everything
780
+ >>> all_filter = Everything()
781
+ >>>
782
+ >>> # Always returns True
783
+ >>> all_filter([], 'any_object')
784
+ True
785
+ >>> all_filter(['some', 'path'], 42)
786
+ True
787
+ >>> all_filter([], None)
788
+ True
789
+ >>>
790
+ >>> # Useful as a default filter
791
+ >>> def process_data(data, filter=None):
792
+ ... if filter is None:
793
+ ... filter = Everything()
794
+ ... # Process all data when no specific filter is provided
795
+
796
+ See Also
797
+ --------
798
+ Nothing : Filter that matches no objects
799
+ to_predicate : Convert True to Everything filter
800
+
801
+ Notes
802
+ -----
803
+ This filter is equivalent to using ``to_predicate(True)`` or
804
+ ``to_predicate(...)`` (Ellipsis).
374
805
  """
375
806
 
376
807
  def __call__(self, path: PathParts, x: typing.Any) -> bool:
377
808
  """
378
- Always return True, regardless of the input.
809
+ Always return True.
379
810
 
380
- Args:
381
- path (PathParts): The path to the current object (not used).
382
- x (typing.Any): The object to be filtered (not used).
811
+ Parameters
812
+ ----------
813
+ path : PathParts
814
+ The path to the current object (ignored).
815
+ x : Any
816
+ The object to be filtered (ignored).
383
817
 
384
- Returns:
385
- bool: Always returns True.
818
+ Returns
819
+ -------
820
+ bool
821
+ Always returns True.
386
822
  """
387
823
  return True
388
824
 
@@ -419,22 +855,62 @@ class Everything:
419
855
 
420
856
  class Nothing:
421
857
  """
422
- A filter class that always returns False for any input.
423
-
424
- This class represents a filter that matches nothing, effectively
425
- filtering out all objects.
858
+ Filter that matches no objects.
859
+
860
+ This filter always returns False, effectively filtering out all objects.
861
+ It's useful for disabling selection or creating empty filter results.
862
+
863
+ Examples
864
+ --------
865
+ .. code-block:: python
866
+
867
+ >>> from brainstate.util.filter import Nothing
868
+ >>>
869
+ >>> # Create a filter that matches nothing
870
+ >>> none_filter = Nothing()
871
+ >>>
872
+ >>> # Always returns False
873
+ >>> none_filter([], 'any_object')
874
+ False
875
+ >>> none_filter(['some', 'path'], 42)
876
+ False
877
+ >>> none_filter([], None)
878
+ False
879
+ >>>
880
+ >>> # Useful for conditional filtering
881
+ >>> def get_params(model, include_frozen=False):
882
+ ... if include_frozen:
883
+ ... filter = Everything()
884
+ ... else:
885
+ ... filter = Nothing() # Exclude all frozen params
886
+ ... # Apply filter to model parameters
887
+
888
+ See Also
889
+ --------
890
+ Everything : Filter that matches all objects
891
+ to_predicate : Convert False or None to Nothing filter
892
+
893
+ Notes
894
+ -----
895
+ This filter is equivalent to using ``to_predicate(False)`` or
896
+ ``to_predicate(None)``.
426
897
  """
427
898
 
428
899
  def __call__(self, path: PathParts, x: typing.Any) -> bool:
429
900
  """
430
- Always return False, regardless of the input.
901
+ Always return False.
431
902
 
432
- Args:
433
- path (PathParts): The path to the current object (not used).
434
- x (typing.Any): The object to be filtered (not used).
903
+ Parameters
904
+ ----------
905
+ path : PathParts
906
+ The path to the current object (ignored).
907
+ x : Any
908
+ The object to be filtered (ignored).
435
909
 
436
- Returns:
437
- bool: Always returns False.
910
+ Returns
911
+ -------
912
+ bool
913
+ Always returns False.
438
914
  """
439
915
  return False
440
916