brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -1,1025 +1,1025 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """
17
- Utility functions and classes for BrainState.
18
-
19
- This module provides various utility functions and enhanced dictionary classes
20
- for managing collections, memory, and object operations in the BrainState framework.
21
- """
22
-
23
- import copy
24
- import functools
25
- import gc
26
- import threading
27
- import types
28
- import warnings
29
- from collections.abc import Iterable, Mapping, MutableMapping
30
- from typing import (
31
- Any, Callable, Dict, Iterator, List, Optional,
32
- Tuple, Type, TypeVar, Union, overload
33
- )
34
-
35
- import jax
36
- from jax.lib import xla_bridge
37
-
38
- from brainstate._utils import set_module_as
39
-
40
- __all__ = [
41
- 'split_total',
42
- 'clear_buffer_memory',
43
- 'not_instance_eval',
44
- 'is_instance_eval',
45
- 'DictManager',
46
- 'DotDict',
47
- 'get_unique_name',
48
- 'merge_dicts',
49
- 'flatten_dict',
50
- 'unflatten_dict',
51
- ]
52
-
53
- T = TypeVar('T')
54
- V = TypeVar('V')
55
- K = TypeVar('K')
56
-
57
-
58
def split_total(
    total: int,
    fraction: Union[int, float],
) -> int:
    """
    Calculate the number of epochs for simulation based on a total and a fraction.

    This function determines the number of epochs to simulate given a total number
    of epochs and either a fraction or a specific number of epochs to run.

    Parameters
    ----------
    total : int
        The total number of epochs. Must be a positive integer (``bool`` is
        rejected even though it subclasses ``int``).
    fraction : Union[int, float]
        If ``float``: A value between 0 and 1 representing the fraction of total
        epochs to run; the result is ``floor(total * fraction)``.
        If ``int``: The specific number of epochs to run, must not exceed the total.

    Returns
    -------
    int
        The calculated number of epochs to simulate.

    Raises
    ------
    TypeError
        If ``total`` is not an integer, or ``fraction`` is neither an integer
        nor a float (``bool`` is rejected for both arguments).
    ValueError
        If ``total`` is not positive, ``fraction`` is negative or NaN, or if
        ``fraction`` as float is > 1 or as int is > ``total``.

    Notes
    -----
    The float product is rounded to 9 decimal places before truncation so that
    binary floating-point error cannot drop a whole epoch: ``100 * 0.29``
    evaluates to ``28.999999999999996``, which a plain ``int()`` would turn
    into 28 instead of 29.

    Examples
    --------
    >>> split_total(100, 0.5)
    50
    >>> split_total(100, 25)
    25
    >>> split_total(100, 0.29)
    29
    >>> split_total(100, 1.5)
    Traceback (most recent call last):
        ...
    ValueError: 'fraction' value cannot be greater than 1, got 1.5.
    """
    # ``bool`` is a subclass of ``int``; treat it as a type error, not as 0/1.
    if isinstance(total, bool) or not isinstance(total, int):
        raise TypeError(f"'total' must be an integer, got {type(total).__name__}.")
    if total <= 0:
        raise ValueError(f"'total' must be a positive integer, got {total}.")

    if isinstance(fraction, bool):
        raise TypeError(f"'fraction' must be an integer or float, got {type(fraction).__name__}.")

    if isinstance(fraction, float):
        if fraction != fraction:  # NaN is the only float that is unequal to itself.
            raise ValueError(f"'fraction' value cannot be NaN, got {fraction}.")
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > 1:
            raise ValueError(f"'fraction' value cannot be greater than 1, got {fraction}.")
        # Absorb representation error before flooring (see Notes).
        return int(round(total * fraction, 9))

    if isinstance(fraction, int):
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > total:
            raise ValueError(f"'fraction' value cannot be greater than total ({total}), got {fraction}.")
        return fraction

    raise TypeError(f"'fraction' must be an integer or float, got {type(fraction).__name__}.")
119
-
120
-
121
class NameContext(threading.local):
    """
    Per-thread registry of name counters.

    Because this class derives from ``threading.local``, each thread works with
    its own ``typed_names`` mapping.
    """

    def __init__(self):
        # type name -> next integer suffix to hand out
        self.typed_names: Dict[str, int] = {}

    def reset(self, type_: Optional[str] = None) -> None:
        """
        Reset name counters.

        Parameters
        ----------
        type_ : str, optional
            When given, restart that type's counter at 0 (only if the type has
            already been seen). When ``None``, forget every type.
        """
        counters = self.typed_names
        if type_ is None:
            counters.clear()
            return
        if type_ in counters:
            counters[type_] = 0


# Module-wide counter store used by ``get_unique_name``.
NAME = NameContext()
136
-
137
-
138
@set_module_as('brainstate.util')
def get_unique_name(type_: str, prefix: str = '') -> str:
    """
    Return a fresh name for ``type_`` built as ``<prefix><type_><counter>``.

    The counter is kept per ``type_`` in the module-level ``NAME`` store and is
    advanced on every call, independently of ``prefix``.

    Parameters
    ----------
    type_ : str
        The base type name; also the key of the counter that is advanced.
    prefix : str, optional
        Text placed before the type name. Falsy values add nothing.

    Returns
    -------
    str
        A unique name combining prefix, type, and counter.

    Examples
    --------
    >>> get_unique_name('layer')
    'layer0'
    >>> get_unique_name('layer', 'conv_')
    'conv_layer1'
    """
    counters = NAME.typed_names
    index = counters.setdefault(type_, 0)
    counters[type_] = index + 1
    return f'{prefix or ""}{type_}{index}'
169
-
170
-
171
@jax.tree_util.register_pytree_node_class
class DictManager(dict, MutableMapping[K, V]):
    """
    Enhanced dictionary for managing collections in BrainState.

    DictManager extends the standard Python dict with additional methods for
    filtering, splitting, and managing collections of objects. It's registered
    as a JAX pytree node for compatibility with JAX transformations: the values
    are the pytree children and the keys are the auxiliary data.

    The "unique" helpers (``add_unique_value``, ``unique``, ``unique_``)
    compare values by identity (``is``), not by equality.

    Examples
    --------
    >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    >>> dm.subset(int)  # Get only integer values
    DictManager({'a': 1})
    >>> dm.unique()  # Get unique values only
    DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    """

    __module__ = 'brainstate.util'
    _val_id_to_key: Dict[int, Any]

    def __init__(self, *args, **kwargs):
        """Initialize DictManager with optional dict-like arguments (same as ``dict``)."""
        super().__init__(*args, **kwargs)
        # Legacy ``id(value) -> key`` index, kept so code that inspects it still
        # works. Uniqueness checks no longer consult it (see add_unique_value).
        self._val_id_to_key = {}

    def subset(self, sep: Union[Type, Tuple[Type, ...], Callable[[Any], bool]]) -> 'DictManager':
        """
        Get a new DictManager with a subset of items based on value type or predicate.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...], Callable]
            If Type or Tuple of Types: Select values that are instances of these types.
            If Callable: Select values where sep(value) returns True.

        Returns
        -------
        DictManager
            A new DictManager containing only matching items.

        Examples
        --------
        >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
        >>> dm.subset(int)
        DictManager({'a': 1})
        >>> dm.subset(lambda x: isinstance(x, (int, float)))
        DictManager({'a': 1, 'b': 2.0})
        """
        gather = type(self)()
        # Classes are callable too, so they must be routed to isinstance().
        if callable(sep) and not isinstance(sep, type):
            for k, v in self.items():
                if sep(v):
                    gather[k] = v
        else:
            for k, v in self.items():
                if isinstance(v, sep):
                    gather[k] = v
        return gather

    def not_subset(self, sep: Union[Type, Tuple[Type, ...]]) -> 'DictManager':
        """
        Get a new DictManager excluding items of specified types.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...]]
            Types to exclude from the result.

        Returns
        -------
        DictManager
            A new DictManager excluding items of specified types.
        """
        gather = type(self)()
        for k, v in self.items():
            if not isinstance(v, sep):
                gather[k] = v
        return gather

    def add_unique_key(self, key: K, val: V) -> None:
        """
        Add a new element ensuring the key maps to a unique value.

        Re-adding the very same object under an existing key is a no-op.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Raises
        ------
        ValueError
            If the key already exists with a different value (by identity).
        """
        self._check_elem(val)
        if key in self:
            if id(val) != id(self[key]):
                raise ValueError(
                    f"Key '{key}' already exists with a different value. "
                    f"Existing: {self[key]}, New: {val}"
                )
        else:
            self[key] = val

    def add_unique_value(self, key: K, val: V) -> bool:
        """
        Add a new element only if the value is unique across all entries.

        Uniqueness is by identity: the value is rejected when the very same
        object is already stored under any key of this manager.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Returns
        -------
        bool
            True if the value was added (was unique), False otherwise.
        """
        self._check_elem(val)
        # Scan the live values instead of a cached ``id`` index. Such a cache
        # misses items given to the constructor or stored via ``[]``/``update``,
        # keeps ids of deleted entries, and can report false duplicates once a
        # new object receives a recycled id (ids are unique only among live objects).
        if any(existing is val for existing in self.values()):
            return False
        self[key] = val
        return True

    def unique(self) -> 'DictManager':
        """
        Get a new DictManager with unique values only.

        If multiple keys map to the same value (by identity),
        only the first key-value pair is retained.

        Returns
        -------
        DictManager
            A new DictManager with unique values.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            v_id = id(v)
            if v_id not in seen:
                seen.add(v_id)
                gather[k] = v
        return gather

    def unique_(self) -> 'DictManager':
        """
        Remove duplicate values (by identity) in-place, keeping the first key.

        Returns
        -------
        DictManager
            Self, for method chaining.
        """
        seen = set()
        keys_to_remove = []
        for k, v in self.items():
            v_id = id(v)
            if v_id in seen:
                keys_to_remove.append(k)
            else:
                seen.add(v_id)

        # Deletion is deferred: a dict must not change size while iterated.
        for k in keys_to_remove:
            del self[k]
        return self

    def assign(self, *args: Dict[K, V], **kwargs: V) -> None:
        """
        Update the DictManager with multiple dictionaries.

        Parameters
        ----------
        *args : Dict
            Dictionaries to merge into this one, applied left to right.
        **kwargs
            Additional key-value pairs to add (applied last).

        Raises
        ------
        TypeError
            If a positional argument is not a ``dict`` instance.
        """
        for arg in args:
            if not isinstance(arg, dict):
                raise TypeError(f"Arguments must be dict instances, got {type(arg).__name__}")
            self.update(arg)
        if kwargs:
            self.update(kwargs)

    def split(self, *types: Type) -> Tuple['DictManager', ...]:
        """
        Split the DictManager into multiple based on value types.

        Each item goes to the first type it is an instance of.

        Parameters
        ----------
        *types : Type
            Types to use for splitting. Each type gets its own DictManager.

        Returns
        -------
        Tuple[DictManager, ...]
            A tuple of DictManagers, one for each type plus one for unmatched items.
        """
        results = tuple(type(self)() for _ in range(len(types) + 1))

        for k, v in self.items():
            for i, type_ in enumerate(types):
                if isinstance(v, type_):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v

        return results

    def filter_by_predicate(self, predicate: Callable[[K, V], bool]) -> 'DictManager':
        """
        Filter items using a predicate function.

        Parameters
        ----------
        predicate : Callable[[key, value], bool]
            Function that returns True for items to keep.

        Returns
        -------
        DictManager
            A new DictManager with filtered items.
        """
        return type(self)({k: v for k, v in self.items() if predicate(k, v)})

    def map_values(self, func: Callable[[V], Any]) -> 'DictManager':
        """
        Apply a function to all values.

        Parameters
        ----------
        func : Callable
            Function to apply to each value.

        Returns
        -------
        DictManager
            A new DictManager with transformed values.
        """
        return type(self)({k: func(v) for k, v in self.items()})

    def map_keys(self, func: Callable[[K], Any]) -> 'DictManager':
        """
        Apply a function to all keys.

        Parameters
        ----------
        func : Callable
            Function to apply to each key.

        Returns
        -------
        DictManager
            A new DictManager with transformed keys.

        Raises
        ------
        ValueError
            If the transformation creates duplicate keys.
        """
        result = type(self)()
        for k, v in self.items():
            new_key = func(k)
            if new_key in result:
                raise ValueError(f"Key transformation created duplicate: {new_key}")
            result[new_key] = v
        return result

    def pop_by_keys(self, keys: Iterable[K]) -> None:
        """Remove multiple keys from the DictManager (missing keys are ignored)."""
        keys_set = set(keys)
        for k in list(self.keys()):
            if k in keys_set:
                self.pop(k)

    @staticmethod
    def _value_matcher(values: Iterable[V], by: str) -> Callable[[Any], bool]:
        """Build a predicate telling whether a value belongs to ``values`` under rule ``by``."""
        if by == 'id':
            value_ids = {id(v) for v in values}
            return lambda v: id(v) in value_ids
        elif by == 'value':
            values_set = set(values) if not isinstance(values, set) else values
            return lambda v: v in values_set
        raise ValueError(f"Invalid comparison method: {by}. Use 'id' or 'value'.")

    def pop_by_values(self, values: Iterable[V], by: str = 'id') -> None:
        """
        Remove items by their values.

        Parameters
        ----------
        values : Iterable
            Values to remove.
        by : str
            Comparison method: 'id' (identity) or 'value' (equality; values
            must be hashable).

        Raises
        ------
        ValueError
            If ``by`` is neither 'id' nor 'value'.
        """
        matches = self._value_matcher(values, by)
        keys_to_remove = [k for k, v in self.items() if matches(v)]
        for k in keys_to_remove:
            del self[k]

    def difference_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items not in the specified keys."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys_set})

    def difference_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are not in the specified collection ('id' or 'value' comparison)."""
        matches = self._value_matcher(values, by)
        return type(self)({k: v for k, v in self.items() if not matches(v)})

    def intersection_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items with keys in the specified collection."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k in keys_set})

    def intersection_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are in the specified collection ('id' or 'value' comparison)."""
        matches = self._value_matcher(values, by)
        return type(self)({k: v for k, v in self.items() if matches(v)})

    def __add__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the + operator (right side wins)."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __or__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the | operator (right side wins)."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __ior__(self, other: Mapping[K, V]) -> 'DictManager':
        """Update in-place with another mapping using |= operator."""
        if not isinstance(other, Mapping):
            return NotImplemented
        self.update(other)
        return self

    def tree_flatten(self) -> Tuple[Tuple[V, ...], Tuple[K, ...]]:
        """Flatten for JAX pytree: values are children, keys are auxiliary data."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys: Tuple[K, ...], values: Tuple[V, ...]) -> 'DictManager':
        """Unflatten from JAX pytree by re-pairing keys and values in order."""
        return cls(zip(keys, values))

    def _check_elem(self, elem: Any) -> None:
        """Override in subclasses to validate elements before they are added."""
        pass

    def to_dict(self) -> Dict[K, V]:
        """Convert to a standard Python dict (shallow)."""
        return dict(self)

    def __copy__(self) -> 'DictManager':
        """Shallow copy."""
        return type(self)(self)

    def __deepcopy__(self, memo: Dict[int, Any]) -> 'DictManager':
        """Deep copy that also handles managers referencing themselves."""
        duplicate = type(self)()
        # Register before recursing so self-references resolve to the new copy
        # instead of recursing forever.
        memo[id(self)] = duplicate
        for k, v in self.items():
            duplicate[copy.deepcopy(k, memo)] = copy.deepcopy(v, memo)
        return duplicate

    def __repr__(self) -> str:
        """String representation."""
        items = ', '.join(f'{k!r}: {v!r}' for k, v in self.items())
        return f'{self.__class__.__name__}({{{items}}})'
567
-
568
-
569
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: Optional[str] = None,
    array: bool = True,
    compilation: bool = False,
) -> None:
    """
    Clear on-device memory buffers and optionally compilation cache.

    This function is useful when running models in loops to prevent memory leaks
    by clearing cached arrays and freeing device memory.

    .. warning::
       This operation may invalidate existing array references.
       Regenerate data after calling this function.

    Parameters
    ----------
    platform : str, optional
        The specific device platform to clear. If None, clears the default platform.
    array : bool, default=True
        Whether to delete all live arrays on that platform.
    compilation : bool, default=False
        Whether to clear the compilation cache.

    Warns
    -----
    RuntimeWarning
        If the live arrays cannot be enumerated or deleted; the function is
        best-effort and does not raise in that case.

    Examples
    --------
    >>> clear_buffer_memory()  # Clear array buffers
    >>> clear_buffer_memory(compilation=True)  # Also clear compilation cache
    """
    if array:
        try:
            # ``jax.live_arrays`` is the public API for enumerating device
            # arrays. The former ``xla_bridge.get_backend(...).live_buffers()``
            # route is not available in recent jaxlib releases, so the old code
            # only emitted the warning below and never freed anything.
            for arr in jax.live_arrays(platform):
                arr.delete()
        except Exception as e:
            # Best effort: report the failure instead of crashing the caller's loop.
            warnings.warn(f"Failed to clear buffers: {e}", RuntimeWarning, stacklevel=2)

    if compilation:
        jax.clear_caches()

    gc.collect()
611
-
612
-
613
def _dotdict_rebuild_sequence(template, items):
    """Rebuild a list/tuple of the same type as ``template`` from ``items``.

    Namedtuples take their fields positionally, so unlike ``list``/``tuple``
    they cannot be rebuilt by passing one iterable to the constructor.
    """
    if isinstance(template, tuple) and hasattr(template, '_fields'):
        return type(template)(*items)
    return type(template)(items)


def _dotdict_to_plain(value):
    """Recursively turn DotDicts (also those nested in lists/tuples) into plain dicts."""
    if isinstance(value, DotDict):
        return value.to_dict()
    if isinstance(value, (list, tuple)):
        return _dotdict_rebuild_sequence(value, [_dotdict_to_plain(item) for item in value])
    return value


@jax.tree_util.register_pytree_node_class
class DotDict(dict, MutableMapping[str, Any]):
    """
    Dictionary with dot notation access to nested keys.

    DotDict allows accessing dictionary items using attribute syntax,
    making code more readable when dealing with nested configurations.

    Nested ``dict`` values, including those inside lists and tuples (namedtuples
    too), are converted to DotDict when passed to the constructor or to
    :meth:`update`. Values assigned with attribute or item syntax are stored
    as given.

    Examples
    --------
    >>> config = DotDict({'model': {'layers': 3, 'units': 64}})
    >>> config.model.layers
    3
    >>> config.model.units = 128
    >>> config['model']['units']
    128

    Attributes
    ----------
    All dictionary keys become accessible as attributes unless they conflict
    with built-in methods.
    """

    __module__ = 'brainstate.util'

    def __init__(self, *args, **kwargs):
        """
        Initialize DotDict with dict-like arguments.

        Parameters
        ----------
        *args
            Positional arguments: dicts, a single ``(key, value)`` tuple, or
            iterables of key-value pairs.
        **kwargs
            Keyword arguments become key-value pairs. The special keywords
            ``__parent`` and ``__key`` are consumed: on the first item
            assignment this DotDict is written into ``__parent[__key]``.

        Raises
        ------
        TypeError
            If a positional argument cannot be interpreted as key-value pairs.
        """
        # Handle parent reference for nested updates
        object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
        object.__setattr__(self, '__key', kwargs.pop('__key', None))

        # Process positional arguments
        for arg in args:
            if not arg:
                continue
            elif isinstance(arg, dict):
                for key, val in arg.items():
                    self[key] = self._hook(val)
            elif isinstance(arg, tuple) and len(arg) == 2 and not isinstance(arg[0], tuple):
                # Single key-value pair
                self[arg[0]] = self._hook(arg[1])
            else:
                # Iterable of key-value pairs
                try:
                    for key, val in arg:
                        self[key] = self._hook(val)
                except (TypeError, ValueError) as e:
                    raise TypeError(f"Invalid argument type for DotDict: {type(arg).__name__}") from e

        # Process keyword arguments
        for key, val in kwargs.items():
            self[key] = self._hook(val)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set attribute as dictionary item (class attributes such as methods are protected)."""
        if hasattr(self.__class__, name):
            raise AttributeError(
                f"Cannot set attribute '{name}': it's a built-in method of {self.__class__.__name__}"
            )
        self[name] = value

    def __setitem__(self, name: str, value: Any) -> None:
        """Set item and, on first assignment, attach self to the pending parent."""
        super().__setitem__(name, value)
        try:
            parent = object.__getattribute__(self, '__parent')
            key = object.__getattribute__(self, '__key')
            if parent is not None:
                parent[key] = self
                # The link is one-shot: drop it once self is attached.
                object.__delattr__(self, '__parent')
                object.__delattr__(self, '__key')
        except AttributeError:
            # Instances created without __init__ (e.g. via copy) have no link.
            pass

    @classmethod
    def _hook(cls, item: Any) -> Any:
        """Convert nested dicts (also inside lists/tuples/namedtuples) to DotDict."""
        if isinstance(item, dict) and not isinstance(item, cls):
            return cls(item)
        elif isinstance(item, (list, tuple)):
            return _dotdict_rebuild_sequence(item, [cls._hook(elem) for elem in item])
        return item

    def __getattr__(self, item: str) -> Any:
        """Get attribute from dictionary."""
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'") from None

    def __delattr__(self, name: str) -> None:
        """Delete attribute from dictionary."""
        try:
            del self[name]
        except KeyError:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") from None

    def __dir__(self) -> List[str]:
        """List all attributes including dict keys."""
        return list(self.keys()) + dir(self.__class__)

    def get(self, key: str, default: Any = None) -> Any:
        """Get item with default value."""
        return super().get(key, default)

    def copy(self) -> 'DotDict':
        """Create a shallow copy."""
        return copy.copy(self)

    def deepcopy(self) -> 'DotDict':
        """Create a deep copy."""
        return copy.deepcopy(self)

    def __deepcopy__(self, memo: Dict[int, Any]) -> 'DotDict':
        """Deep copy implementation (registers in ``memo`` first to support cycles)."""
        other = self.__class__()
        memo[id(self)] = other
        for key, value in self.items():
            other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return other

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to standard dict recursively.

        Returns
        -------
        dict
            A standard Python dict in which nested DotDicts, at any depth of
            list/tuple nesting, are also converted.
        """
        return {key: _dotdict_to_plain(value) for key, value in self.items()}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> 'DotDict':
        """
        Create DotDict from standard dict.

        Parameters
        ----------
        d : dict
            Standard Python dictionary.

        Returns
        -------
        DotDict
            A new DotDict instance.
        """
        return cls(d)

    def update(self, *args, **kwargs) -> None:
        """
        Update with recursive merge for nested dicts.

        Parameters
        ----------
        *args
            At most one mapping (anything with ``items()``) or iterable of
            key-value pairs to merge.
        **kwargs
            Key-value pairs to merge (applied after the positional argument).

        Raises
        ------
        TypeError
            If more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError(f"update expected at most 1 argument, got {len(args)}")
        other = args[0] if args else {}

        # Normalise to a private dict: the caller's mapping is never mutated,
        # and iterables of pairs (which have no ``.update``) are accepted.
        if hasattr(other, 'items'):
            merged = dict(other.items())
        else:
            merged = dict(other)
        merged.update(kwargs)

        for k, v in merged.items():
            if k in self and isinstance(self[k], dict) and isinstance(v, dict):
                # Recursive merge for nested dicts
                if not isinstance(self[k], DotDict):
                    self[k] = type(self)(self[k])
                self[k].update(v)
            else:
                self[k] = self._hook(v)

    def setdefault(self, key: str, default: Any = None) -> Any:
        """Set default value if key doesn't exist, then return the stored value."""
        if key not in self:
            self[key] = default
        return self[key]

    def __getstate__(self) -> Dict[str, Any]:
        """Get state for pickling."""
        return dict(self)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Set state from pickling."""
        self.update(state)

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
        """Flatten for JAX pytree: values are children, keys are auxiliary data."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys: Tuple[str, ...], values: Tuple[Any, ...]) -> 'DotDict':
        """Unflatten from JAX pytree."""
        return cls(zip(keys, values))

    def __repr__(self) -> str:
        """String representation."""
        items = ', '.join(f'{k!r}: {v!r}' for k, v in self.items())
        return f'{self.__class__.__name__}({{{items}}})'
842
-
843
-
844
@set_module_as('brainstate.util')
def merge_dicts(*dicts: Dict[K, V], recursive: bool = True) -> Dict[K, V]:
    """
    Merge several dictionaries into a new one.

    Dictionaries are applied left to right, so later ones win on conflicting
    keys. With ``recursive=True``, when both the accumulated value and the
    incoming value for a key are dicts, they are merged into a fresh dict
    rather than replaced. The inputs are not modified.

    Parameters
    ----------
    *dicts : Dict
        Dictionaries to merge (later ones override earlier ones).
    recursive : bool, default=True
        Whether to recursively merge nested dicts.

    Returns
    -------
    Dict
        Merged dictionary.

    Raises
    ------
    TypeError
        If any positional argument is not a ``dict``.

    Examples
    --------
    >>> d1 = {'a': 1, 'b': {'c': 2}}
    >>> d2 = {'b': {'d': 3}, 'e': 4}
    >>> merge_dicts(d1, d2)
    {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
    """
    merged = {}
    for source in dicts:
        if not isinstance(source, dict):
            raise TypeError(f"All arguments must be dicts, got {type(source).__name__}")

        for key, incoming in source.items():
            # A missing key yields None here, which is never a dict.
            nested = recursive and isinstance(merged.get(key), dict) and isinstance(incoming, dict)
            merged[key] = merge_dicts(merged[key], incoming, recursive=True) if nested else incoming

    return merged
881
-
882
-
883
@set_module_as('brainstate.util')
def flatten_dict(
    d: Dict[str, Any],
    parent_key: str = '',
    sep: str = '.'
) -> Dict[str, Any]:
    """
    Flatten a nested dictionary.

    Nested keys are joined with ``sep``. Top-level keys (when no
    ``parent_key`` is given) are kept unchanged. Empty nested dictionaries are
    kept as ``{}`` leaf values so that ``unflatten_dict`` can rebuild them.

    Parameters
    ----------
    d : Dict
        Dictionary to flatten.
    parent_key : str, default=''
        Prefix for keys. ``''`` (or ``None``) means "no prefix".
    sep : str, default='.'
        Separator between nested keys.

    Returns
    -------
    Dict
        Flattened dictionary.

    Examples
    --------
    >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    >>> flatten_dict(d)
    {'a': 1, 'b.c': 2, 'b.d.e': 3}
    >>> flatten_dict({'a': {}, 0: {'x': 1}})
    {'a': {}, '0.x': 1}
    """
    flat: Dict[str, Any] = {}
    # Test for "no prefix" explicitly: a falsy key such as 0 is a real prefix
    # and must not be dropped when recursing.
    has_prefix = parent_key is not None and parent_key != ''
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if has_prefix else k
        if isinstance(v, dict) and v:
            flat.update(flatten_dict(v, new_key, sep=sep))
        else:
            # Leaves, including empty dicts, are stored directly.
            flat[new_key] = v
    return flat
920
-
921
-
922
@set_module_as('brainstate.util')
def unflatten_dict(
    d: Dict[str, Any],
    sep: str = '.'
) -> Dict[str, Any]:
    """
    Unflatten a dictionary with separated keys.

    String keys are split on ``sep`` and turned into nested dictionaries.
    Non-string keys cannot contain a separator and are copied unchanged.
    ``flatten_dict`` leaves non-string top-level keys as they are, so this
    keeps such round trips working. Keys are processed in order; a later plain
    key replaces whatever nested dict was built for it so far.

    Parameters
    ----------
    d : Dict
        Flattened dictionary.
    sep : str, default='.'
        Separator in keys.

    Returns
    -------
    Dict
        Nested dictionary.

    Raises
    ------
    ValueError
        If a key's path runs through an existing non-dict value, e.g.
        ``{'a': 1, 'a.b': 2}``.

    Examples
    --------
    >>> d = {'a': 1, 'b.c': 2, 'b.d.e': 3}
    >>> unflatten_dict(d)
    {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    """
    result: Dict[Any, Any] = {}

    for key, value in d.items():
        if not isinstance(key, str):
            result[key] = value
            continue

        *parents, leaf = key.split(sep)
        current = result
        for part in parents:
            if part not in current:
                current[part] = {}
            child = current[part]
            if not isinstance(child, dict):
                raise ValueError(
                    f"Cannot unflatten key {key!r}: {part!r} already holds a "
                    f"non-dict value ({type(child).__name__})."
                )
            current = child

        current[leaf] = value

    return result
962
-
963
-
964
def _is_not_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
    """Return True when ``x`` is not an instance of ``cls`` (a type or tuple of types)."""
    return False if isinstance(x, cls) else True
967
-
968
-
969
def _is_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
    """Return True when ``x`` is an instance of ``cls`` (a type or tuple of types)."""
    return True if isinstance(x, cls) else False
972
-
973
-
974
@set_module_as('brainstate.util')
def not_instance_eval(*cls: Type) -> Callable[[Any], bool]:
    """
    Build a predicate that tests whether its argument is NOT an instance of any given class.

    Parameters
    ----------
    *cls : Type
        Classes to check against; they are passed to ``isinstance`` as a tuple.

    Returns
    -------
    Callable
        A one-argument function returning True if the input is not an
        instance of any of ``cls``.

    Examples
    --------
    >>> not_int = not_instance_eval(int)
    >>> not_int(5)
    False
    >>> not_int("hello")
    True
    """
    predicate = functools.partial(_is_not_instance, cls=cls)
    return predicate
998
-
999
-
1000
- @set_module_as('brainstate.util')
1001
- def is_instance_eval(*cls: Type) -> Callable[[Any], bool]:
1002
- """
1003
- Create a partial function to check if input IS an instance of given classes.
1004
-
1005
- Parameters
1006
- ----------
1007
- *cls : Type
1008
- Classes to check against.
1009
-
1010
- Returns
1011
- -------
1012
- Callable
1013
- A function that returns True if input is an instance of any given class.
1014
-
1015
- Examples
1016
- --------
1017
- >>> is_number = is_instance_eval(int, float)
1018
- >>> is_number(5)
1019
- True
1020
- >>> is_number(3.14)
1021
- True
1022
- >>> is_number("hello")
1023
- False
1024
- """
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Utility functions and classes for BrainState.
18
+
19
+ This module provides various utility functions and enhanced dictionary classes
20
+ for managing collections, memory, and object operations in the BrainState framework.
21
+ """
22
+
23
+ import copy
24
+ import functools
25
+ import gc
26
+ import threading
27
+ import types
28
+ import warnings
29
+ from collections.abc import Iterable, Mapping, MutableMapping
30
+ from typing import (
31
+ Any, Callable, Dict, Iterator, List, Optional,
32
+ Tuple, Type, TypeVar, Union, overload
33
+ )
34
+
35
+ import jax
36
+ from jax.lib import xla_bridge
37
+
38
+ from brainstate._utils import set_module_as
39
+
40
# Public names of this module, i.e. what ``from <module> import *`` exports.
__all__ = [
    'split_total',
    'clear_buffer_memory',
    'not_instance_eval',
    'is_instance_eval',
    'DictManager',
    'DotDict',
    'get_unique_name',
    'merge_dicts',
    'flatten_dict',
    'unflatten_dict',
]

# Generic type variables used in annotations below: ``K`` stands for mapping
# keys and ``V`` for mapping values (e.g. ``DictManager(dict, MutableMapping[K, V])``).
T = TypeVar('T')
V = TypeVar('V')
K = TypeVar('K')
56
+
57
+
58
def split_total(
    total: int,
    fraction: Union[int, float],
) -> int:
    """
    Calculate the number of epochs for simulation based on a total and a fraction.

    This function determines the number of epochs to simulate given a total number
    of epochs and either a fraction or a specific number of epochs to run.

    Parameters
    ----------
    total : int
        The total number of epochs. Must be a positive integer; any
        :class:`numbers.Integral` (e.g. ``numpy.int64``) is accepted.
    fraction : Union[int, float]
        If an integer (:class:`numbers.Integral`): the specific number of
        epochs to run, which must not exceed ``total``.
        If another real number (``float``, ``numpy.float32``,
        ``fractions.Fraction``, ...): a value between 0 and 1 giving the
        fraction of ``total`` to run. The result is rounded down.

    Returns
    -------
    int
        The calculated number of epochs to simulate.

    Raises
    ------
    TypeError
        If ``total`` is not an integer, or ``fraction`` is not a real number.
    ValueError
        If ``total`` is not positive, ``fraction`` is negative or NaN, or if
        ``fraction`` as a real number is > 1 or as an integer is > ``total``.

    Examples
    --------
    >>> split_total(100, 0.5)
    50
    >>> split_total(100, 25)
    25
    >>> split_total(100, 1.5)
    Traceback (most recent call last):
        ...
    ValueError: 'fraction' value cannot be greater than 1, got 1.5.
    """
    import math
    import numbers

    if not isinstance(total, numbers.Integral):
        raise TypeError(f"'total' must be an integer, got {type(total).__name__}.")
    if total <= 0:
        raise ValueError(f"'total' must be a positive integer, got {total}.")

    # Integral must be tested before Real: every Integral (bool included) is
    # also Real, and an integer ``fraction`` means an absolute epoch count.
    if isinstance(fraction, numbers.Integral):
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > total:
            raise ValueError(f"'fraction' value cannot be greater than total ({total}), got {fraction}.")
        return int(fraction)

    if isinstance(fraction, numbers.Real):
        # NaN compares False with everything, so it would slip past the range
        # checks below and fail later inside int().
        if math.isnan(fraction):
            raise ValueError(f"'fraction' value cannot be NaN, got {fraction}.")
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > 1:
            raise ValueError(f"'fraction' value cannot be greater than 1, got {fraction}.")
        return int(total * fraction)

    raise TypeError(f"'fraction' must be an integer or float, got {type(fraction).__name__}.")
119
+
120
+
121
class NameContext(threading.local):
    """Per-thread table of naming counters consulted by ``get_unique_name``."""

    def __init__(self):
        # Type name -> next integer suffix to hand out.
        self.typed_names: Dict[str, int] = dict()

    def reset(self, type_: Optional[str] = None) -> None:
        """
        Reset naming counters.

        Parameters
        ----------
        type_ : str, optional
            When given, the counter of this type is set back to 0, if that type
            has been seen before. When None, all counters are discarded.
        """
        counters = self.typed_names
        if type_ is not None:
            if type_ in counters:
                counters[type_] = 0
            return
        counters.clear()


NAME = NameContext()
136
+
137
+
138
@set_module_as('brainstate.util')
def get_unique_name(type_: str, prefix: str = '') -> str:
    """
    Produce the next name for ``type_``, numbered by a per-type counter.

    The counter lives in the thread-local ``NAME`` context. It is shared by
    all prefixes of the same ``type_`` and is incremented on every call.

    Parameters
    ----------
    type_ : str
        The base type name.
    prefix : str, optional
        Additional prefix to add before the type name.

    Returns
    -------
    str
        ``f'{prefix}{type_}{counter}'``, with the counter starting at 0.

    Examples
    --------
    >>> get_unique_name('layer')
    'layer0'
    >>> get_unique_name('layer', 'conv_')
    'conv_layer1'
    """
    counters = NAME.typed_names
    index = counters.setdefault(type_, 0)
    counters[type_] = index + 1
    stem = f'{prefix}{type_}' if prefix else type_
    return f'{stem}{index}'
169
+
170
+
171
@jax.tree_util.register_pytree_node_class
class DictManager(dict, MutableMapping[K, V]):
    """
    Enhanced dictionary for managing collections in BrainState.

    DictManager extends the standard Python dict with additional methods for
    filtering, splitting, and managing collections of objects. It's registered
    as a JAX pytree node (values are the children, keys the auxiliary data) for
    compatibility with JAX transformations.

    Examples
    --------
    >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    >>> dm.subset(int)  # Get only integer values
    DictManager({'a': 1})
    >>> dm.unique()  # Get unique values only
    DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    """

    __module__ = 'brainstate.util'
    # id(value) -> key for values inserted through ``add_unique_value``. This
    # is bookkeeping only; uniqueness is always decided from the live contents.
    _val_id_to_key: Dict[int, Any]

    def __init__(self, *args, **kwargs):
        """Initialize DictManager with the same arguments ``dict`` accepts."""
        super().__init__(*args, **kwargs)
        self._val_id_to_key = {}

    def subset(self, sep: Union[Type, Tuple[Type, ...], Callable[[Any], bool]]) -> 'DictManager':
        """
        Get a new DictManager with a subset of items based on value type or predicate.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...], Callable]
            If Type or Tuple of Types: select values that are instances of these types.
            If a non-class Callable: select values where ``sep(value)`` returns True.

        Returns
        -------
        DictManager
            A new DictManager containing only matching items.

        Examples
        --------
        >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
        >>> dm.subset(int)
        DictManager({'a': 1})
        >>> dm.subset(lambda x: isinstance(x, (int, float)))
        DictManager({'a': 1, 'b': 2.0})
        """
        gather = type(self)()
        # Classes are callable too, so they are excluded from the predicate path.
        if callable(sep) and not isinstance(sep, type):
            for k, v in self.items():
                if sep(v):
                    gather[k] = v
        else:
            for k, v in self.items():
                if isinstance(v, sep):
                    gather[k] = v
        return gather

    def not_subset(self, sep: Union[Type, Tuple[Type, ...]]) -> 'DictManager':
        """
        Get a new DictManager excluding items of specified types.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...]]
            Types to exclude from the result.

        Returns
        -------
        DictManager
            A new DictManager excluding items of specified types.
        """
        gather = type(self)()
        for k, v in self.items():
            if not isinstance(v, sep):
                gather[k] = v
        return gather

    def add_unique_key(self, key: K, val: V) -> None:
        """
        Add a new element ensuring the key maps to a unique value.

        Re-adding the very same object (identity) under an existing key is a
        no-op.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Raises
        ------
        ValueError
            If the key already exists with a different (non-identical) value.
        """
        self._check_elem(val)
        if key in self:
            if val is not self[key]:
                raise ValueError(
                    f"Key '{key}' already exists with a different value. "
                    f"Existing: {self[key]}, New: {val}"
                )
        else:
            self[key] = val

    def add_unique_value(self, key: K, val: V) -> bool:
        """
        Add a new element only if the value is unique across all entries.

        Uniqueness is by identity (``is``) against every value currently
        stored. This includes values given at construction time or assigned
        with ``dm[key] = value``, and excludes values that have since been
        removed. If ``key`` already exists with another value, that value is
        replaced.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Returns
        -------
        bool
            True if the value was added (was unique), False otherwise.
        """
        self._check_elem(val)
        # Scan the live values: an id cache alone misses values inserted by
        # other means and goes stale (or is fooled by id reuse) after deletion.
        if any(existing is val for existing in self.values()):
            return False
        self[key] = val
        # Instances created without __init__ (e.g. via __new__) lack the cache.
        if not hasattr(self, '_val_id_to_key'):
            self._val_id_to_key = {}
        self._val_id_to_key[id(val)] = key
        return True

    def unique(self) -> 'DictManager':
        """
        Get a new DictManager with unique values only.

        If multiple keys map to the same value (by identity),
        only the first key-value pair is retained.

        Returns
        -------
        DictManager
            A new DictManager with unique values.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            v_id = id(v)
            if v_id not in seen:
                seen.add(v_id)
                gather[k] = v
        return gather

    def unique_(self) -> 'DictManager':
        """
        Remove duplicate values (by identity) in-place, keeping the first key.

        Returns
        -------
        DictManager
            Self, for method chaining.
        """
        seen = set()
        keys_to_remove = []
        for k, v in self.items():
            v_id = id(v)
            if v_id in seen:
                keys_to_remove.append(k)
            else:
                seen.add(v_id)

        # Deletion is deferred: a dict must not change size while iterated.
        for k in keys_to_remove:
            del self[k]
        return self

    def assign(self, *args: Dict[K, V], **kwargs: V) -> None:
        """
        Update the DictManager with multiple dictionaries.

        Parameters
        ----------
        *args : Dict
            Dictionaries to merge into this one, applied in order.
        **kwargs
            Additional key-value pairs to add (applied last).

        Raises
        ------
        TypeError
            If a positional argument is not a ``dict``.
        """
        for arg in args:
            if not isinstance(arg, dict):
                raise TypeError(f"Arguments must be dict instances, got {type(arg).__name__}")
            self.update(arg)
        if kwargs:
            self.update(kwargs)

    def split(self, *types: Type) -> Tuple['DictManager', ...]:
        """
        Split the DictManager into multiple based on value types.

        Each item goes to the first type it is an instance of.

        Parameters
        ----------
        *types : Type
            Types to use for splitting. Each type gets its own DictManager.

        Returns
        -------
        Tuple[DictManager, ...]
            A tuple of DictManagers, one for each type plus one for unmatched items.
        """
        results = tuple(type(self)() for _ in range(len(types) + 1))

        for k, v in self.items():
            for i, type_ in enumerate(types):
                if isinstance(v, type_):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v

        return results

    def filter_by_predicate(self, predicate: Callable[[K, V], bool]) -> 'DictManager':
        """
        Filter items using a predicate function.

        Parameters
        ----------
        predicate : Callable[[key, value], bool]
            Function that returns True for items to keep.

        Returns
        -------
        DictManager
            A new DictManager with filtered items.
        """
        return type(self)({k: v for k, v in self.items() if predicate(k, v)})

    def map_values(self, func: Callable[[V], Any]) -> 'DictManager':
        """
        Apply a function to all values.

        Parameters
        ----------
        func : Callable
            Function to apply to each value.

        Returns
        -------
        DictManager
            A new DictManager with transformed values.
        """
        return type(self)({k: func(v) for k, v in self.items()})

    def map_keys(self, func: Callable[[K], Any]) -> 'DictManager':
        """
        Apply a function to all keys.

        Parameters
        ----------
        func : Callable
            Function to apply to each key.

        Returns
        -------
        DictManager
            A new DictManager with transformed keys.

        Raises
        ------
        ValueError
            If the transformation creates duplicate keys.
        """
        result = type(self)()
        for k, v in self.items():
            new_key = func(k)
            if new_key in result:
                raise ValueError(f"Key transformation created duplicate: {new_key}")
            result[new_key] = v
        return result

    def pop_by_keys(self, keys: Iterable[K]) -> None:
        """Remove the given keys from the DictManager; missing keys are ignored."""
        keys_set = set(keys)
        for k in list(self.keys()):
            if k in keys_set:
                self.pop(k)

    def pop_by_values(self, values: Iterable[V], by: str = 'id') -> None:
        """
        Remove items by their values.

        Parameters
        ----------
        values : Iterable
            Values to remove.
        by : str
            Comparison method: 'id' (identity) or 'value' (equality; the
            values must be hashable).

        Raises
        ------
        ValueError
            If ``by`` is neither 'id' nor 'value'.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            keys_to_remove = [k for k, v in self.items() if id(v) in value_ids]
        elif by == 'value':
            values_set = set(values) if not isinstance(values, set) else values
            keys_to_remove = [k for k, v in self.items() if v in values_set]
        else:
            raise ValueError(f"Invalid comparison method: {by}. Use 'id' or 'value'.")

        for k in keys_to_remove:
            del self[k]

    def difference_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items whose keys are not in ``keys``."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys_set})

    def difference_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are not in ``values`` (by 'id' or by 'value' equality)."""
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
        elif by == 'value':
            values_set = set(values) if not isinstance(values, set) else values
            return type(self)({k: v for k, v in self.items() if v not in values_set})
        else:
            raise ValueError(f"Invalid comparison method: {by}. Use 'id' or 'value'.")

    def intersection_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items whose keys are in ``keys``."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k in keys_set})

    def intersection_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are in ``values`` (by 'id' or by 'value' equality)."""
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
        elif by == 'value':
            values_set = set(values) if not isinstance(values, set) else values
            return type(self)({k: v for k, v in self.items() if v in values_set})
        else:
            raise ValueError(f"Invalid comparison method: {by}. Use 'id' or 'value'.")

    def __add__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the + operator (``other`` wins on conflicts)."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __or__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the | operator (``other`` wins on conflicts)."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __ior__(self, other: Mapping[K, V]) -> 'DictManager':
        """Update in-place with another mapping using the |= operator."""
        if not isinstance(other, Mapping):
            return NotImplemented
        self.update(other)
        return self

    def tree_flatten(self) -> Tuple[Tuple[V, ...], Tuple[K, ...]]:
        """Flatten for JAX pytree: values are children, keys are aux data."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys: Tuple[K, ...], values: Tuple[V, ...]) -> 'DictManager':
        """Unflatten from JAX pytree."""
        return cls(zip(keys, values))

    def _check_elem(self, elem: Any) -> None:
        """Override in subclasses to validate elements before insertion."""
        pass

    def to_dict(self) -> Dict[K, V]:
        """Convert to a standard Python dict (shallow)."""
        return dict(self)

    def __copy__(self) -> 'DictManager':
        """Shallow copy."""
        return type(self)(self)

    def __deepcopy__(self, memo: Dict[int, Any]) -> 'DictManager':
        """
        Deep copy.

        The new instance is registered in ``memo`` before copying the items,
        so containers that (directly or indirectly) reference themselves are
        copied instead of recursing forever.
        """
        result = type(self)()
        memo[id(self)] = result
        for k, v in self.items():
            result[copy.deepcopy(k, memo)] = copy.deepcopy(v, memo)
        return result

    def __repr__(self) -> str:
        """String representation of the form ``ClassName({key: value, ...})``."""
        items = ', '.join(f'{k!r}: {v!r}' for k, v in self.items())
        return f'{self.__class__.__name__}({{{items}}})'
567
+
568
+
569
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: Optional[str] = None,
    array: bool = True,
    compilation: bool = False,
) -> None:
    """
    Clear on-device memory buffers and optionally compilation cache.

    This function is useful when running models in loops to prevent memory leaks
    by clearing cached arrays and freeing device memory.

    .. warning::
       This deletes every live array on the selected platform, so existing
       array references become invalid. Regenerate data after calling this
       function.

    Parameters
    ----------
    platform : str, optional
        The device platform whose arrays are deleted. If None, the default
        backend is used.
    array : bool, default=True
        Whether to delete live arrays.
    compilation : bool, default=False
        Whether to clear JAX's caches via ``jax.clear_caches()``.

    Notes
    -----
    Failures while listing or deleting arrays are reported as a
    ``RuntimeWarning`` rather than raised, so this stays a best-effort cleanup.
    A Python garbage collection is always run at the end.

    Examples
    --------
    >>> clear_buffer_memory()  # Clear array buffers
    >>> clear_buffer_memory(compilation=True)  # Also clear compilation cache
    """
    if array:
        try:
            live_arrays = getattr(jax, 'live_arrays', None)
            if live_arrays is not None:
                buffers = live_arrays(platform)
            else:
                # Fallback for older JAX releases that predate jax.live_arrays.
                buffers = xla_bridge.get_backend(platform).live_buffers()
            for buf in buffers:
                buf.delete()
        except Exception as e:
            warnings.warn(f"Failed to clear buffers: {e}", RuntimeWarning)

    if compilation:
        jax.clear_caches()

    gc.collect()
611
+
612
+
613
def _rebuild_sequence(template: Union[list, tuple], items: List[Any]) -> Union[list, tuple]:
    """
    Rebuild a list/tuple with ``template``'s type from ``items``.

    Namedtuples take their fields positionally, so they are rebuilt with
    ``type(template)(*items)``. Every other list/tuple type receives ``items``
    as a single iterable argument.
    """
    if isinstance(template, tuple) and hasattr(template, '_fields'):
        return type(template)(*items)
    return type(template)(items)


@jax.tree_util.register_pytree_node_class
class DotDict(dict, MutableMapping[str, Any]):
    """
    Dictionary with dot notation access to nested keys.

    DotDict allows accessing dictionary items using attribute syntax,
    making code more readable when dealing with nested configurations.
    Nested dicts (also inside lists/tuples, including namedtuples) are
    converted to DotDict on insertion through the constructor, ``update`` and
    ``setdefault``.

    Examples
    --------
    >>> config = DotDict({'model': {'layers': 3, 'units': 64}})
    >>> config.model.layers
    3
    >>> config.model.units = 128
    >>> config['model']['units']
    128

    Attributes
    ----------
    All dictionary keys become accessible as attributes unless they conflict
    with built-in methods.
    """

    __module__ = 'brainstate.util'

    def __init__(self, *args, **kwargs):
        """
        Initialize DotDict with dict-like arguments.

        Parameters
        ----------
        *args
            Positional arguments: dicts, a single ``(key, value)`` tuple, or
            iterables of key-value pairs.
        **kwargs
            Keyword arguments become key-value pairs.

        Raises
        ------
        TypeError
            If a positional argument cannot be interpreted as key-value pairs.
        """
        # Handle parent reference for nested updates
        object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
        object.__setattr__(self, '__key', kwargs.pop('__key', None))

        # Process positional arguments
        for arg in args:
            if not arg:
                continue
            elif isinstance(arg, dict):
                for key, val in arg.items():
                    self[key] = self._hook(val)
            elif isinstance(arg, tuple) and len(arg) == 2 and not isinstance(arg[0], tuple):
                # Single key-value pair
                self[arg[0]] = self._hook(arg[1])
            else:
                # Iterable of key-value pairs
                try:
                    for key, val in arg:
                        self[key] = self._hook(val)
                except (TypeError, ValueError) as e:
                    raise TypeError(f"Invalid argument type for DotDict: {type(arg).__name__}") from e

        # Process keyword arguments
        for key, val in kwargs.items():
            self[key] = self._hook(val)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set attribute as dictionary item, refusing names defined on the class."""
        if hasattr(self.__class__, name):
            raise AttributeError(
                f"Cannot set attribute '{name}': it's a built-in method of {self.__class__.__name__}"
            )
        self[name] = value

    def __setitem__(self, name: str, value: Any) -> None:
        """Set item; on the first write, re-attach to a recorded parent if any."""
        super().__setitem__(name, value)
        try:
            parent = object.__getattribute__(self, '__parent')
            key = object.__getattribute__(self, '__key')
            if parent is not None:
                parent[key] = self
            # The parent link is used once and then dropped; later writes hit
            # the AttributeError branch below.
            object.__delattr__(self, '__parent')
            object.__delattr__(self, '__key')
        except AttributeError:
            pass

    @classmethod
    def _hook(cls, item: Any) -> Any:
        """Convert nested dicts to ``cls``, recursing into lists and tuples."""
        if isinstance(item, dict) and not isinstance(item, cls):
            return cls(item)
        elif isinstance(item, (list, tuple)):
            return _rebuild_sequence(item, [cls._hook(elem) for elem in item])
        return item

    def __getattr__(self, item: str) -> Any:
        """Get attribute from dictionary (only called when normal lookup fails)."""
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")

    def __delattr__(self, name: str) -> None:
        """Delete attribute from dictionary."""
        try:
            del self[name]
        except KeyError:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __dir__(self) -> List[str]:
        """List all attributes including dict keys."""
        return list(self.keys()) + dir(self.__class__)

    def get(self, key: str, default: Any = None) -> Any:
        """Get item with default value."""
        return super().get(key, default)

    def copy(self) -> 'DotDict':
        """Create a shallow copy."""
        return copy.copy(self)

    def deepcopy(self) -> 'DotDict':
        """Create a deep copy."""
        return copy.deepcopy(self)

    def __deepcopy__(self, memo: Dict[int, Any]) -> 'DotDict':
        """Deep copy; registers the copy in ``memo`` first to support cycles."""
        other = self.__class__()
        memo[id(self)] = other
        for key, value in self.items():
            other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return other

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to standard dict recursively.

        Nested DotDicts are converted at any depth, including inside (nested)
        lists and tuples, mirroring the conversion ``_hook`` performs on input.

        Returns
        -------
        dict
            A standard Python dict with nested DotDicts also converted.
        """
        def convert(value: Any) -> Any:
            if isinstance(value, DotDict):
                return value.to_dict()
            if isinstance(value, (list, tuple)):
                return _rebuild_sequence(value, [convert(elem) for elem in value])
            return value

        return {key: convert(value) for key, value in self.items()}

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> 'DotDict':
        """
        Create DotDict from standard dict.

        Parameters
        ----------
        d : dict
            Standard Python dictionary.

        Returns
        -------
        DotDict
            A new DotDict instance.
        """
        return cls(d)

    def update(self, *args, **kwargs) -> None:
        """
        Update with recursive merge for nested dicts.

        Accepts the same inputs as :meth:`dict.update`: at most one mapping or
        iterable of ``(key, value)`` pairs, plus keyword arguments. When both
        the existing and the incoming value for a key are dicts, they are
        merged recursively.

        Parameters
        ----------
        *args
            At most one mapping or iterable of key-value pairs to merge.
        **kwargs
            Key-value pairs to merge; they take precedence over ``args``.

        Raises
        ------
        TypeError
            If more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError(f"update expected at most 1 argument, got {len(args)}")

        # Copy into a fresh dict: never mutate the caller's object, and accept
        # iterables of pairs like dict.update does.
        other = {}
        if args:
            source = args[0]
            if hasattr(source, 'items'):
                other.update(source.items())
            else:
                other.update(source)
        other.update(kwargs)

        for k, v in other.items():
            if k in self and isinstance(self[k], dict) and isinstance(v, dict):
                # Recursive merge for nested dicts
                if not isinstance(self[k], DotDict):
                    self[k] = DotDict(self[k])
                self[k].update(v)
            else:
                self[k] = self._hook(v)

    def setdefault(self, key: str, default: Any = None) -> Any:
        """
        Return ``self[key]``, inserting ``default`` (converted by ``_hook``) if missing.

        Converting the default means ``cfg.setdefault('opt', {}).lr = 0.1``
        works like any other nested DotDict.
        """
        if key not in self:
            self[key] = self._hook(default)
        return self[key]

    def __getstate__(self) -> Dict[str, Any]:
        """Get state for pickling."""
        return dict(self)

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Set state from pickling."""
        self.update(state)

    def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
        """Flatten for JAX pytree: values are children, keys are aux data."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys: Tuple[str, ...], values: Tuple[Any, ...]) -> 'DotDict':
        """Unflatten from JAX pytree."""
        return cls(zip(keys, values))

    def __repr__(self) -> str:
        """String representation of the form ``ClassName({key: value, ...})``."""
        items = ', '.join(f'{k!r}: {v!r}' for k, v in self.items())
        return f'{type(self).__name__}({{{items}}})'
842
+
843
+
844
@set_module_as('brainstate.util')
def merge_dicts(*dicts: Dict[K, V], recursive: bool = True) -> Dict[K, V]:
    """
    Combine several dictionaries into a new one.

    Dictionaries are applied left to right, so values from later arguments win.
    With ``recursive=True``, two ``dict`` values under the same key are merged
    by calling this function on them instead of being replaced.

    Parameters
    ----------
    *dicts : Dict
        Dictionaries to merge (later ones override earlier ones).
    recursive : bool, default=True
        Whether to recursively merge nested dicts.

    Returns
    -------
    Dict
        Merged dictionary.

    Raises
    ------
    TypeError
        If any positional argument is not a ``dict``.

    Examples
    --------
    >>> d1 = {'a': 1, 'b': {'c': 2}}
    >>> d2 = {'b': {'d': 3}, 'e': 4}
    >>> merge_dicts(d1, d2)
    {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
    """
    merged = {}
    for mapping in dicts:
        if not isinstance(mapping, dict):
            raise TypeError(f"All arguments must be dicts, got {type(mapping).__name__}")
        for key, incoming in mapping.items():
            existing = merged.get(key)
            both_dicts = isinstance(existing, dict) and isinstance(incoming, dict)
            merged[key] = (
                merge_dicts(existing, incoming, recursive=True)
                if recursive and both_dicts
                else incoming
            )
    return merged
881
+
882
+
883
@set_module_as('brainstate.util')
def flatten_dict(
    d: Dict[str, Any],
    parent_key: str = '',
    sep: str = '.'
) -> Dict[str, Any]:
    """
    Flatten a nested dictionary.

    Nested keys are joined with ``sep``. Top-level keys (when no
    ``parent_key`` is given) are kept unchanged. Empty nested dictionaries are
    kept as ``{}`` leaf values so that ``unflatten_dict`` can rebuild them.

    Parameters
    ----------
    d : Dict
        Dictionary to flatten.
    parent_key : str, default=''
        Prefix for keys. ``''`` (or ``None``) means "no prefix".
    sep : str, default='.'
        Separator between nested keys.

    Returns
    -------
    Dict
        Flattened dictionary.

    Examples
    --------
    >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    >>> flatten_dict(d)
    {'a': 1, 'b.c': 2, 'b.d.e': 3}
    >>> flatten_dict({'a': {}, 0: {'x': 1}})
    {'a': {}, '0.x': 1}
    """
    flat: Dict[str, Any] = {}
    # Test for "no prefix" explicitly: a falsy key such as 0 is a real prefix
    # and must not be dropped when recursing.
    has_prefix = parent_key is not None and parent_key != ''
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if has_prefix else k
        if isinstance(v, dict) and v:
            flat.update(flatten_dict(v, new_key, sep=sep))
        else:
            # Leaves, including empty dicts, are stored directly.
            flat[new_key] = v
    return flat
920
+
921
+
922
+ @set_module_as('brainstate.util')
923
def unflatten_dict(
    d: Dict[str, Any],
    sep: str = '.'
) -> Dict[str, Any]:
    """
    Unflatten a dictionary with separated keys.

    String keys are split on ``sep``, and each component becomes one level of
    nesting. Non-string keys cannot contain a separator, so they are copied to
    the top level unchanged. This mirrors ``flatten_dict``, which leaves such
    keys untouched at the top level.

    Parameters
    ----------
    d : Dict
        Flattened dictionary.
    sep : str, default='.'
        Separator in keys.

    Returns
    -------
    Dict
        Nested dictionary. When the same path is assigned more than once,
        the entry processed last (in ``d``'s iteration order) wins.

    Raises
    ------
    TypeError
        If a key has to descend through a path component that already holds a
        non-mapping value, e.g. ``{'a': 1, 'a.b': 2}``.
    ValueError
        If ``sep`` is the empty string (raised by ``str.split``).

    Notes
    -----
    Values are stored by reference. If a value in ``d`` is itself a dict and a
    later key descends into it, the new entry is written into that same object.

    Examples
    --------
    >>> d = {'a': 1, 'b.c': 2, 'b.d.e': 3}
    >>> unflatten_dict(d)
    {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    """
    from collections.abc import MutableMapping

    result: Dict[str, Any] = {}

    for key, value in d.items():
        if not isinstance(key, str):
            # Only strings can be split on ``sep``; previously this crashed
            # with AttributeError on ``key.split``.
            result[key] = value
            continue

        parts = key.split(sep)
        current = result

        for depth, part in enumerate(parts[:-1]):
            child = current.setdefault(part, {})
            if not isinstance(child, MutableMapping):
                # e.g. {'a': 1, 'a.b': 2}: 'a' is already a leaf, so there is
                # nowhere to put 'b'. Report the clash explicitly instead of an
                # opaque "object does not support item assignment".
                prefix = sep.join(parts[:depth + 1])
                raise TypeError(
                    f"Cannot unflatten key {key!r}: prefix {prefix!r} already "
                    f"holds a non-dict value of type {type(child).__name__}"
                )
            current = child

        current[parts[-1]] = value

    return result
962
+
963
+
964
def _is_not_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
    """Return ``True`` when ``x`` is *not* an instance of ``cls`` (a type or tuple of types)."""
    if isinstance(x, cls):
        return False
    return True
967
+
968
+
969
def _is_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
    """Return ``True`` when ``x`` is an instance of ``cls`` (a type or tuple of types)."""
    matches = isinstance(x, cls)
    return matches
972
+
973
+
974
+ @set_module_as('brainstate.util')
975
def not_instance_eval(*cls: Type) -> Callable[[Any], bool]:
    """
    Create a partial function to check if input is NOT an instance of given classes.

    The positional classes are collected into a tuple and bound as the ``cls``
    keyword of ``_is_not_instance`` via ``functools.partial``. With no classes,
    ``isinstance(x, ())`` is always ``False``, so the predicate returns
    ``True`` for every input.

    Parameters
    ----------
    *cls : Type
        Classes to check against.

    Returns
    -------
    Callable
        A function that returns True if input is not an instance of any given class.

    Examples
    --------
    >>> not_int = not_instance_eval(int)
    >>> not_int(5)
    False
    >>> not_int("hello")
    True
    """
    classes = cls  # already a tuple of the positional arguments
    predicate = functools.partial(_is_not_instance, cls=classes)
    return predicate
998
+
999
+
1000
+ @set_module_as('brainstate.util')
1001
def is_instance_eval(*cls: Type) -> Callable[[Any], bool]:
    """
    Create a partial function to check if input IS an instance of given classes.

    The positional classes are collected into a tuple and bound as the ``cls``
    keyword of ``_is_instance`` via ``functools.partial``. With no classes,
    ``isinstance(x, ())`` is always ``False``, so the predicate returns
    ``False`` for every input.

    Parameters
    ----------
    *cls : Type
        Classes to check against.

    Returns
    -------
    Callable
        A function that returns True if input is an instance of any given class.

    Examples
    --------
    >>> is_number = is_instance_eval(int, float)
    >>> is_number(5)
    True
    >>> is_number(3.14)
    True
    >>> is_number("hello")
    False
    """
    classes = cls  # already a tuple of the positional arguments
    predicate = functools.partial(_is_instance, cls=classes)
    return predicate