brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/util/_others.py
CHANGED
@@ -1,1025 +1,1025 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
"""
|
17
|
-
Utility functions and classes for BrainState.
|
18
|
-
|
19
|
-
This module provides various utility functions and enhanced dictionary classes
|
20
|
-
for managing collections, memory, and object operations in the BrainState framework.
|
21
|
-
"""
|
22
|
-
|
23
|
-
import copy
|
24
|
-
import functools
|
25
|
-
import gc
|
26
|
-
import threading
|
27
|
-
import types
|
28
|
-
import warnings
|
29
|
-
from collections.abc import Iterable, Mapping, MutableMapping
|
30
|
-
from typing import (
|
31
|
-
Any, Callable, Dict, Iterator, List, Optional,
|
32
|
-
Tuple, Type, TypeVar, Union, overload
|
33
|
-
)
|
34
|
-
|
35
|
-
import jax
|
36
|
-
from jax.lib import xla_bridge
|
37
|
-
|
38
|
-
from brainstate._utils import set_module_as
|
39
|
-
|
40
|
-
# Public names re-exported by ``from brainstate.util._others import *``.
__all__ = [
    'split_total', 'clear_buffer_memory',
    'not_instance_eval', 'is_instance_eval',
    'DictManager', 'DotDict',
    'get_unique_name',
    'merge_dicts', 'flatten_dict', 'unflatten_dict',
]

# Generic type variables used by the annotations in this module
# (``K``/``V`` parameterize ``DictManager``'s mapping types).
T = TypeVar('T')
V = TypeVar('V')
K = TypeVar('K')
|
56
|
-
|
57
|
-
|
58
|
-
def split_total(
    total: int,
    fraction: Union[int, float],
) -> int:
    """
    Calculate the number of epochs for simulation based on a total and a fraction.

    This function determines the number of epochs to simulate given a total number
    of epochs and either a fraction or a specific number of epochs to run.

    Parameters
    ----------
    total : int
        The total number of epochs. Must be a positive integer. Any integral
        type (e.g. ``numpy.int64``) is accepted; ``bool`` is rejected.
    fraction : Union[int, float]
        If a real non-integral number (``float``, ``numpy.float32``, ...): a value
        between 0 and 1 representing the fraction of total epochs to run.
        If an integer: the specific number of epochs to run, which must not
        exceed ``total``. ``bool`` is rejected.

    Returns
    -------
    int
        The calculated number of epochs to simulate, as a plain Python ``int``.

    Raises
    ------
    TypeError
        If ``total`` is not an integer, or ``fraction`` is not a real number
        (``bool`` counts as invalid for both).
    ValueError
        If ``total`` is not positive, ``fraction`` is negative or NaN, or if
        ``fraction`` as a float is > 1 or as an int is > ``total``.

    Examples
    --------
    >>> split_total(100, 0.5)
    50
    >>> split_total(100, 25)
    25
    >>> split_total(100, 1.5)  # Raises ValueError
    ValueError: 'fraction' value cannot be greater than 1, got 1.5.
    """
    import numbers

    # ``bool`` subclasses ``int``; without this guard ``True``/``False`` would be
    # silently treated as 1/0 (and ``True`` would even be returned as-is).
    if isinstance(total, bool) or not isinstance(total, numbers.Integral):
        raise TypeError(f"'total' must be an integer, got {type(total).__name__}.")
    total = int(total)
    if total <= 0:
        raise ValueError(f"'total' must be a positive integer, got {total}.")

    if isinstance(fraction, bool):
        raise TypeError(f"'fraction' must be an integer or float, got {type(fraction).__name__}.")

    # Integral must be tested before Real, because every Integral is also Real.
    if isinstance(fraction, numbers.Integral):
        fraction = int(fraction)
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > total:
            raise ValueError(f"'fraction' value cannot be greater than total ({total}), got {fraction}.")
        return fraction

    if isinstance(fraction, numbers.Real):
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > 1:
            raise ValueError(f"'fraction' value cannot be greater than 1, got {fraction}.")
        # NaN compares False against everything, so it slips past both checks above.
        if fraction != fraction:
            raise ValueError(f"'fraction' value must not be NaN, got {fraction}.")
        return int(total * fraction)

    raise TypeError(f"'fraction' must be an integer or float, got {type(fraction).__name__}.")
|
119
|
-
|
120
|
-
|
121
|
-
class NameContext(threading.local):
    """
    Per-thread registry of naming counters.

    Because this class derives from :class:`threading.local`, every thread sees
    its own ``typed_names`` dict, which maps a type string to the next integer
    suffix to hand out.
    """

    def __init__(self):
        self.typed_names: Dict[str, int] = dict()

    def reset(self, type_: Optional[str] = None) -> None:
        """
        Reset naming counters.

        Parameters
        ----------
        type_ : str, optional
            When given, that type's counter is set back to zero if it has been
            used before (unknown types are left untouched). When omitted, every
            counter is discarded.
        """
        counters = self.typed_names
        if type_ is not None:
            if type_ in counters:
                counters[type_] = 0
            return
        counters.clear()
|
133
|
-
|
134
|
-
|
135
|
-
# Shared counter registry behind ``get_unique_name``; thread-local by virtue of
# ``NameContext`` subclassing ``threading.local``.
NAME = NameContext()


@set_module_as('brainstate.util')
def get_unique_name(type_: str, prefix: str = '') -> str:
    """
    Produce a fresh name for ``type_`` by appending a per-type counter.

    The counter is keyed on ``type_`` alone (not on ``prefix``), starts at 0 and
    is incremented on every call.

    Parameters
    ----------
    type_ : str
        The base type name; also the key of the counter that is advanced.
    prefix : str, optional
        Text placed before ``type_`` when non-empty.

    Returns
    -------
    str
        ``prefix + type_ + counter`` (``type_ + counter`` when ``prefix`` is falsy).

    Examples
    --------
    >>> get_unique_name('layer')
    'layer0'
    >>> get_unique_name('layer', 'conv_')
    'conv_layer1'
    """
    counters = NAME.typed_names
    index = counters.setdefault(type_, 0)
    name = f'{prefix}{type_}{index}' if prefix else f'{type_}{index}'
    counters[type_] = index + 1
    return name
|
169
|
-
|
170
|
-
|
171
|
-
@jax.tree_util.register_pytree_node_class
class DictManager(dict, MutableMapping[K, V]):
    """
    Enhanced dictionary for managing collections in BrainState.

    DictManager extends the standard Python dict with additional methods for
    filtering, splitting, and managing collections of objects. It's registered
    as a JAX pytree node for compatibility with JAX transformations.

    Examples
    --------
    >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    >>> dm.subset(int)  # Get only integer values
    DictManager({'a': 1})
    >>> dm.unique()  # Get unique values only
    DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    """

    __module__ = 'brainstate.util'
    # Snapshot index ``id(value) -> key``. It is rebuilt from the current
    # contents by ``add_unique_value``, because ordinary dict mutation
    # (``[]``, ``update``, ``del``, ...) does not maintain it.
    _val_id_to_key: Dict[int, Any]

    def __init__(self, *args, **kwargs):
        """Initialize DictManager with optional dict-like arguments."""
        super().__init__(*args, **kwargs)
        self._val_id_to_key = {}

    def subset(self, sep: Union[Type, Tuple[Type, ...], Callable[[Any], bool]]) -> 'DictManager':
        """
        Get a new DictManager with a subset of items based on value type or predicate.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...], Callable]
            If Type or Tuple of Types: Select values that are instances of these types.
            If Callable: Select values where sep(value) returns True.

        Returns
        -------
        DictManager
            A new DictManager containing only matching items.

        Examples
        --------
        >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
        >>> dm.subset(int)
        DictManager({'a': 1})
        >>> dm.subset(lambda x: isinstance(x, (int, float)))
        DictManager({'a': 1, 'b': 2.0})
        """
        gather = type(self)()
        # Classes are callable too, so they must be routed to the isinstance path.
        if callable(sep) and not isinstance(sep, type):
            for k, v in self.items():
                if sep(v):
                    gather[k] = v
        else:
            for k, v in self.items():
                if isinstance(v, sep):
                    gather[k] = v
        return gather

    def not_subset(self, sep: Union[Type, Tuple[Type, ...]]) -> 'DictManager':
        """
        Get a new DictManager excluding items of specified types.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...]]
            Types to exclude from the result.

        Returns
        -------
        DictManager
            A new DictManager excluding items of specified types.
        """
        gather = type(self)()
        for k, v in self.items():
            if not isinstance(v, sep):
                gather[k] = v
        return gather

    def add_unique_key(self, key: K, val: V) -> None:
        """
        Add a new element ensuring the key maps to a unique value.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Raises
        ------
        ValueError
            If the key already exists with a different value (compared by identity).
        """
        self._check_elem(val)
        if key in self:
            if id(val) != id(self[key]):
                raise ValueError(
                    f"Key '{key}' already exists with a different value. "
                    f"Existing: {self[key]}, New: {val}"
                )
        else:
            self[key] = val

    def add_unique_value(self, key: K, val: V) -> bool:
        """
        Add a new element only if the value is unique across all entries.

        Uniqueness is by identity and is checked against the *current* contents,
        including values inserted through the constructor, ``[]`` or ``update``.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Returns
        -------
        bool
            True if the value was added (was unique), False otherwise.
        """
        self._check_elem(val)
        # Rebuild the index every call. Values may have been added or removed
        # through paths that never touch the index, and a stale entry for a
        # deleted object could even match a new object that reused its ``id``.
        self._val_id_to_key = {id(v): k for k, v in self.items()}

        val_id = id(val)
        if val_id in self._val_id_to_key:
            return False
        self._val_id_to_key[val_id] = key
        self[key] = val
        return True

    def unique(self) -> 'DictManager':
        """
        Get a new DictManager with unique values only.

        If multiple keys map to the same value (by identity),
        only the first key-value pair is retained.

        Returns
        -------
        DictManager
            A new DictManager with unique values.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            v_id = id(v)
            if v_id not in seen:
                seen.add(v_id)
                gather[k] = v
        return gather

    def unique_(self) -> 'DictManager':
        """
        Remove duplicate values (by identity) in-place, keeping the first key.

        Returns
        -------
        DictManager
            Self, for method chaining.
        """
        seen = set()
        keys_to_remove = []
        for k, v in self.items():
            v_id = id(v)
            if v_id in seen:
                keys_to_remove.append(k)
            else:
                seen.add(v_id)

        # Deleting is deferred so the dict is not mutated during iteration.
        for k in keys_to_remove:
            del self[k]
        return self

    def assign(self, *args: Dict[K, V], **kwargs: V) -> None:
        """
        Update the DictManager with multiple dictionaries.

        Parameters
        ----------
        *args : Dict
            Dictionaries to merge into this one, applied in order.
        **kwargs
            Additional key-value pairs to add (applied last).

        Raises
        ------
        TypeError
            If a positional argument is not a ``dict`` instance.
        """
        for arg in args:
            if not isinstance(arg, dict):
                raise TypeError(f"Arguments must be dict instances, got {type(arg).__name__}")
            self.update(arg)
        if kwargs:
            self.update(kwargs)

    def split(self, *types: Type) -> Tuple['DictManager', ...]:
        """
        Split the DictManager into multiple based on value types.

        Parameters
        ----------
        *types : Type
            Types to use for splitting. Each type gets its own DictManager; an
            item goes to the first type it is an instance of.

        Returns
        -------
        Tuple[DictManager, ...]
            A tuple of DictManagers, one for each type plus one for unmatched items.
        """
        results = tuple(type(self)() for _ in range(len(types) + 1))

        for k, v in self.items():
            for i, type_ in enumerate(types):
                if isinstance(v, type_):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v

        return results

    def filter_by_predicate(self, predicate: Callable[[K, V], bool]) -> 'DictManager':
        """
        Filter items using a predicate function.

        Parameters
        ----------
        predicate : Callable[[key, value], bool]
            Function that returns True for items to keep.

        Returns
        -------
        DictManager
            A new DictManager with filtered items.
        """
        return type(self)({k: v for k, v in self.items() if predicate(k, v)})

    def map_values(self, func: Callable[[V], Any]) -> 'DictManager':
        """
        Apply a function to all values.

        Parameters
        ----------
        func : Callable
            Function to apply to each value.

        Returns
        -------
        DictManager
            A new DictManager with transformed values.
        """
        return type(self)({k: func(v) for k, v in self.items()})

    def map_keys(self, func: Callable[[K], Any]) -> 'DictManager':
        """
        Apply a function to all keys.

        Parameters
        ----------
        func : Callable
            Function to apply to each key.

        Returns
        -------
        DictManager
            A new DictManager with transformed keys.

        Raises
        ------
        ValueError
            If the transformation creates duplicate keys.
        """
        result = type(self)()
        for k, v in self.items():
            new_key = func(k)
            if new_key in result:
                raise ValueError(f"Key transformation created duplicate: {new_key}")
            result[new_key] = v
        return result

    def pop_by_keys(self, keys: Iterable[K]) -> None:
        """Remove multiple keys from the DictManager; missing keys are ignored."""
        keys_set = set(keys)
        for k in list(self.keys()):
            if k in keys_set:
                self.pop(k)

    def pop_by_values(self, values: Iterable[V], by: str = 'id') -> None:
        """
        Remove items by their values.

        Parameters
        ----------
        values : Iterable
            Values to remove.
        by : str
            Comparison method: 'id' (identity) or 'value' (equality; values
            must be hashable).

        Raises
        ------
        ValueError
            If ``by`` is not 'id' or 'value'.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            keys_to_remove = [k for k, v in self.items() if id(v) in value_ids]
        elif by == 'value':
            values_set = set(values) if not isinstance(values, set) else values
            keys_to_remove = [k for k, v in self.items() if v in values_set]
        else:
            raise ValueError(f"Invalid comparison method: {by}. Use 'id' or 'value'.")

        for k in keys_to_remove:
            del self[k]

    def difference_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items not in the specified keys."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys_set})

    def difference_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are not in the specified collection ('id' or 'value' comparison)."""
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
        elif by == 'value':
            values_set = set(values) if not isinstance(values, set) else values
            return type(self)({k: v for k, v in self.items() if v not in values_set})
        else:
            raise ValueError(f"Invalid comparison method: {by}. Use 'id' or 'value'.")

    def intersection_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items with keys in the specified collection."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k in keys_set})

    def intersection_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are in the specified collection ('id' or 'value' comparison)."""
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
        elif by == 'value':
            values_set = set(values) if not isinstance(values, set) else values
            return type(self)({k: v for k, v in self.items() if v in values_set})
        else:
            raise ValueError(f"Invalid comparison method: {by}. Use 'id' or 'value'.")

    def __add__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the + operator (right side wins on conflicts)."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __or__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the | operator (Python 3.9+)."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __ior__(self, other: Mapping[K, V]) -> 'DictManager':
        """Update in-place with another mapping using |= operator."""
        if not isinstance(other, Mapping):
            return NotImplemented
        self.update(other)
        return self

    def tree_flatten(self) -> Tuple[Tuple[V, ...], Tuple[K, ...]]:
        """Flatten for JAX pytree: values are children, keys are auxiliary data."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys: Tuple[K, ...], values: Tuple[V, ...]) -> 'DictManager':
        """Unflatten from JAX pytree."""
        return cls(zip(keys, values))

    def _check_elem(self, elem: Any) -> None:
        """Override in subclasses to validate elements."""
        pass

    def to_dict(self) -> Dict[K, V]:
        """Convert to a standard Python dict (shallow)."""
        return dict(self)

    def __copy__(self) -> 'DictManager':
        """Shallow copy."""
        return type(self)(self)

    def __deepcopy__(self, memo: Dict[int, Any]) -> 'DictManager':
        """Deep copy that preserves shared and self-referencing structure."""
        result = type(self)()
        # Register before recursing so cycles (e.g. ``dm['x'] = dm``) terminate
        # and shared references are copied once.
        memo[id(self)] = result
        for k, v in self.items():
            result[copy.deepcopy(k, memo)] = copy.deepcopy(v, memo)
        return result

    def __repr__(self) -> str:
        """String representation."""
        items = ', '.join(f'{k!r}: {v!r}' for k, v in self.items())
        return f'{self.__class__.__name__}({{{items}}})'
|
567
|
-
|
568
|
-
|
569
|
-
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: Optional[str] = None,
    array: bool = True,
    compilation: bool = False,
) -> None:
    """
    Clear on-device memory buffers and optionally compilation cache.

    This function is useful when running models in loops to prevent memory leaks
    by clearing cached arrays and freeing device memory.

    .. warning::
        This operation may invalidate existing array references.
        Regenerate data after calling this function.

    Parameters
    ----------
    platform : str, optional
        The specific device platform to clear. If None, clears the default platform.
    array : bool, default=True
        Whether to clear array buffers. Failures are reported as a
        ``RuntimeWarning`` rather than raised.
    compilation : bool, default=False
        Whether to clear the compilation cache (``jax.clear_caches``).

    Examples
    --------
    >>> clear_buffer_memory()  # Clear array buffers
    >>> clear_buffer_memory(compilation=True)  # Also clear compilation cache
    """
    if array:
        try:
            if hasattr(jax, 'live_arrays'):
                # Public API. ``Backend.live_buffers`` (and ``jax.lib.xla_bridge``)
                # may be missing in newer jax/jaxlib releases.
                live = jax.live_arrays(platform)
            else:
                # Fallback for older JAX versions without ``jax.live_arrays``.
                live = xla_bridge.get_backend(platform).live_buffers()
            for buf in live:
                buf.delete()
        except Exception as e:
            # Best-effort cleanup: report the problem without aborting the caller.
            warnings.warn(f"Failed to clear buffers: {e}", RuntimeWarning, stacklevel=2)

    if compilation:
        jax.clear_caches()

    gc.collect()
|
611
|
-
|
612
|
-
|
613
|
-
@jax.tree_util.register_pytree_node_class
|
614
|
-
class DotDict(dict, MutableMapping[str, Any]):
|
615
|
-
"""
|
616
|
-
Dictionary with dot notation access to nested keys.
|
617
|
-
|
618
|
-
DotDict allows accessing dictionary items using attribute syntax,
|
619
|
-
making code more readable when dealing with nested configurations.
|
620
|
-
|
621
|
-
Examples
|
622
|
-
--------
|
623
|
-
>>> config = DotDict({'model': {'layers': 3, 'units': 64}})
|
624
|
-
>>> config.model.layers
|
625
|
-
3
|
626
|
-
>>> config.model.units = 128
|
627
|
-
>>> config['model']['units']
|
628
|
-
128
|
629
|
-
|
630
|
-
Attributes
|
631
|
-
----------
|
632
|
-
All dictionary keys become accessible as attributes unless they conflict
|
633
|
-
with built-in methods.
|
634
|
-
"""
|
635
|
-
|
636
|
-
__module__ = 'brainstate.util'
|
637
|
-
|
638
|
-
def __init__(self, *args, **kwargs):
|
639
|
-
"""
|
640
|
-
Initialize DotDict with dict-like arguments.
|
641
|
-
|
642
|
-
Parameters
|
643
|
-
----------
|
644
|
-
*args
|
645
|
-
Positional arguments (dicts, iterables of pairs).
|
646
|
-
**kwargs
|
647
|
-
Keyword arguments become key-value pairs.
|
648
|
-
"""
|
649
|
-
# Handle parent reference for nested updates
|
650
|
-
object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
|
651
|
-
object.__setattr__(self, '__key', kwargs.pop('__key', None))
|
652
|
-
|
653
|
-
# Process positional arguments
|
654
|
-
for arg in args:
|
655
|
-
if not arg:
|
656
|
-
continue
|
657
|
-
elif isinstance(arg, dict):
|
658
|
-
for key, val in arg.items():
|
659
|
-
self[key] = self._hook(val)
|
660
|
-
elif isinstance(arg, tuple) and len(arg) == 2 and not isinstance(arg[0], tuple):
|
661
|
-
# Single key-value pair
|
662
|
-
self[arg[0]] = self._hook(arg[1])
|
663
|
-
else:
|
664
|
-
# Iterable of key-value pairs
|
665
|
-
try:
|
666
|
-
for key, val in arg:
|
667
|
-
self[key] = self._hook(val)
|
668
|
-
except (TypeError, ValueError) as e:
|
669
|
-
raise TypeError(f"Invalid argument type for DotDict: {type(arg).__name__}") from e
|
670
|
-
|
671
|
-
# Process keyword arguments
|
672
|
-
for key, val in kwargs.items():
|
673
|
-
self[key] = self._hook(val)
|
674
|
-
|
675
|
-
def __setattr__(self, name: str, value: Any) -> None:
|
676
|
-
"""Set attribute as dictionary item."""
|
677
|
-
if hasattr(self.__class__, name):
|
678
|
-
raise AttributeError(
|
679
|
-
f"Cannot set attribute '{name}': it's a built-in method of {self.__class__.__name__}"
|
680
|
-
)
|
681
|
-
self[name] = value
|
682
|
-
|
683
|
-
def __setitem__(self, name: str, value: Any) -> None:
|
684
|
-
"""Set item and update parent if nested."""
|
685
|
-
super().__setitem__(name, value)
|
686
|
-
try:
|
687
|
-
parent = object.__getattribute__(self, '__parent')
|
688
|
-
key = object.__getattribute__(self, '__key')
|
689
|
-
if parent is not None:
|
690
|
-
parent[key] = self
|
691
|
-
object.__delattr__(self, '__parent')
|
692
|
-
object.__delattr__(self, '__key')
|
693
|
-
except AttributeError:
|
694
|
-
pass
|
695
|
-
|
696
|
-
@classmethod
|
697
|
-
def _hook(cls, item: Any) -> Any:
|
698
|
-
"""Convert nested dicts to DotDict."""
|
699
|
-
if isinstance(item, dict) and not isinstance(item, cls):
|
700
|
-
return cls(item)
|
701
|
-
elif isinstance(item, (list, tuple)):
|
702
|
-
return type(item)(cls._hook(elem) for elem in item)
|
703
|
-
return item
|
704
|
-
|
705
|
-
def __getattr__(self, item: str) -> Any:
|
706
|
-
"""Get attribute from dictionary."""
|
707
|
-
try:
|
708
|
-
return self[item]
|
709
|
-
except KeyError:
|
710
|
-
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
|
711
|
-
|
712
|
-
def __delattr__(self, name: str) -> None:
|
713
|
-
"""Delete attribute from dictionary."""
|
714
|
-
try:
|
715
|
-
del self[name]
|
716
|
-
except KeyError:
|
717
|
-
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
|
718
|
-
|
719
|
-
def __dir__(self) -> List[str]:
|
720
|
-
"""List all attributes including dict keys."""
|
721
|
-
return list(self.keys()) + dir(self.__class__)
|
722
|
-
|
723
|
-
def get(self, key: str, default: Any = None) -> Any:
|
724
|
-
"""Get item with default value."""
|
725
|
-
return super().get(key, default)
|
726
|
-
|
727
|
-
def copy(self) -> 'DotDict':
|
728
|
-
"""Create a shallow copy."""
|
729
|
-
return copy.copy(self)
|
730
|
-
|
731
|
-
def deepcopy(self) -> 'DotDict':
|
732
|
-
"""Create a deep copy."""
|
733
|
-
return copy.deepcopy(self)
|
734
|
-
|
735
|
-
def __deepcopy__(self, memo: Dict[int, Any]) -> 'DotDict':
|
736
|
-
"""Deep copy implementation."""
|
737
|
-
other = self.__class__()
|
738
|
-
memo[id(self)] = other
|
739
|
-
for key, value in self.items():
|
740
|
-
other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
|
741
|
-
return other
|
742
|
-
|
743
|
-
def to_dict(self) -> Dict[str, Any]:
    """
    Convert to standard dict recursively.

    Every ``DotDict`` reachable through values, or through lists and tuples
    nested at any depth, is converted to a plain ``dict``. List and tuple
    containers keep their type; named tuples are rebuilt field by field.

    Returns
    -------
    dict
        A standard Python dict with nested DotDicts also converted.
    """

    def _convert(value: Any) -> Any:
        # Recurse into sequences at any depth (not just one level), so a
        # DotDict inside e.g. a list of lists is converted as well.
        if isinstance(value, DotDict):
            return value.to_dict()
        if isinstance(value, (list, tuple)):
            items = [_convert(item) for item in value]
            if isinstance(value, tuple) and hasattr(value, '_fields'):
                # Named tuples take their fields positionally.
                return type(value)(*items)
            return type(value)(items)
        return value

    return {key: _convert(value) for key, value in self.items()}
|
764
|
-
|
765
|
-
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> 'DotDict':
    """
    Alternate constructor: build an instance of ``cls`` from a standard dict.

    Parameters
    ----------
    d : dict
        Standard Python dictionary.

    Returns
    -------
    DotDict
        A new instance created as ``cls(d)``.
    """
    instance = cls(d)
    return instance
|
781
|
-
|
782
|
-
def update(self, *args, **kwargs) -> None:
    """
    Update with recursive merge for nested dicts.

    Unlike ``dict.update``, when both the existing value and the incoming
    value for a key are dicts, the incoming dict is merged into the existing
    one instead of replacing it. All other incoming values are passed through
    ``_hook`` before being stored.

    Parameters
    ----------
    *args
        At most one source: a mapping (anything with ``items()``) or an
        iterable of key-value pairs.
    **kwargs
        Key-value pairs to merge; they take precedence over ``args``.

    Raises
    ------
    TypeError
        If more than one positional argument is given.
    """
    if len(args) > 1:
        raise TypeError(f"update expected at most 1 argument, got {len(args)}")

    # Build a fresh dict so that mappings and iterables of pairs are both
    # accepted (like ``dict.update``), and merging ``kwargs`` never mutates
    # the caller's object.
    other: Dict[Any, Any] = {}
    if args:
        source = args[0]
        if hasattr(source, 'items'):
            other.update(source.items())
        else:
            other.update(source)
    other.update(kwargs)

    for k, v in other.items():
        if k in self and isinstance(self[k], dict) and isinstance(v, dict):
            # Recursive merge for nested dicts: promote a plain dict first so
            # the nested merge uses these same semantics.
            if not isinstance(self[k], DotDict):
                self[k] = DotDict(self[k])
            self[k].update(v)
        else:
            self[k] = self._hook(v)
|
814
|
-
|
815
|
-
def setdefault(self, key: str, default: Any = None) -> Any:
    """
    Return ``self[key]``, first storing ``default`` under ``key`` if it is absent.
    """
    if key in self:
        return self[key]
    self[key] = default
    return self[key]
|
820
|
-
|
821
|
-
def __getstate__(self) -> Dict[str, Any]:
    """
    Pickle hook: capture the items as a plain ``dict``.

    Only the items are captured; instance attributes are not included.
    """
    state = dict(self)
    return state
|
824
|
-
|
825
|
-
def __setstate__(self, state: Dict[str, Any]) -> None:
    """Pickle hook: restore items by merging ``state`` via ``self.update``."""
    self.update(state)
|
828
|
-
|
829
|
-
def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    JAX pytree flatten hook.

    Returns
    -------
    tuple
        ``(children, aux_data)``: the values as children and the keys as
        auxiliary data, both in insertion order.
    """
    keys = tuple(self.keys())
    children = tuple(self.values())
    return children, keys
|
832
|
-
|
833
|
-
@classmethod
def tree_unflatten(cls, keys: Tuple[str, ...], values: Tuple[Any, ...]) -> 'DotDict':
    """
    JAX pytree unflatten hook: rebuild an instance from keys and children.
    """
    pairs = zip(keys, values)
    return cls(pairs)
|
837
|
-
|
838
|
-
def __repr__(self) -> str:
    """Render as ``DotDict({...})`` with each key and value shown via ``repr``."""
    body = ', '.join(f'{k!r}: {v!r}' for k, v in self.items())
    return 'DotDict({' + body + '})'
|
842
|
-
|
843
|
-
|
844
|
-
@set_module_as('brainstate.util')
def merge_dicts(*dicts: Dict[K, V], recursive: bool = True) -> Dict[K, V]:
    """
    Merge multiple dictionaries into a new one.

    Later dictionaries override earlier ones key by key. With
    ``recursive=True``, a key whose existing and incoming values are both
    dicts is itself merged (into a new dict) rather than overwritten.

    Parameters
    ----------
    *dicts : Dict
        Dictionaries to merge (later ones override earlier ones).
    recursive : bool, default=True
        Whether to recursively merge nested dicts.

    Returns
    -------
    Dict
        Merged dictionary.

    Raises
    ------
    TypeError
        If any positional argument is not a ``dict``.

    Examples
    --------
    >>> d1 = {'a': 1, 'b': {'c': 2}}
    >>> d2 = {'b': {'d': 3}, 'e': 4}
    >>> merge_dicts(d1, d2)
    {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
    """
    merged: Dict[K, V] = {}

    for current in dicts:
        if not isinstance(current, dict):
            raise TypeError(f"All arguments must be dicts, got {type(current).__name__}")

        for key, value in current.items():
            both_dicts = (
                key in merged
                and isinstance(merged[key], dict)
                and isinstance(value, dict)
            )
            merged[key] = (
                merge_dicts(merged[key], value, recursive=True)
                if recursive and both_dicts
                else value
            )

    return merged
|
881
|
-
|
882
|
-
|
883
|
-
@set_module_as('brainstate.util')
def flatten_dict(
    d: Dict[str, Any],
    parent_key: str = '',
    sep: str = '.'
) -> Dict[str, Any]:
    """
    Flatten a nested dictionary.

    Keys of nested dicts are joined to their parent key with ``sep``. An
    empty nested dict is kept as a leaf value instead of being recursed into.
    Recursing into it would silently drop the key, and ``unflatten_dict``
    could then not reconstruct the original structure.

    Parameters
    ----------
    d : Dict
        Dictionary to flatten.
    parent_key : str, default=''
        Prefix for keys.
    sep : str, default='.'
        Separator between nested keys.

    Returns
    -------
    Dict
        Flattened dictionary.

    Examples
    --------
    >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    >>> flatten_dict(d)
    {'a': 1, 'b.c': 2, 'b.d.e': 3}
    >>> flatten_dict({'a': {}, 'b': {'c': 1}})
    {'a': {}, 'b.c': 1}
    """
    flat: Dict[str, Any] = {}
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict) and v:
            flat.update(flatten_dict(v, new_key, sep=sep))
        else:
            flat[new_key] = v
    return flat
|
920
|
-
|
921
|
-
|
922
|
-
@set_module_as('brainstate.util')
def unflatten_dict(
    d: Dict[str, Any],
    sep: str = '.'
) -> Dict[str, Any]:
    """
    Unflatten a dictionary with separated keys.

    String keys are split on ``sep`` into a path of nested dicts. Non-string
    keys are stored unchanged at the top level; ``flatten_dict`` also leaves
    top-level non-string keys untouched.

    Parameters
    ----------
    d : Dict
        Flattened dictionary.
    sep : str, default='.'
        Separator in keys.

    Returns
    -------
    Dict
        Nested dictionary.

    Raises
    ------
    ValueError
        If a key's path runs through a prefix that already holds a non-mapping
        value, e.g. ``{'a': 1, 'a.b': 2}``.

    Examples
    --------
    >>> d = {'a': 1, 'b.c': 2, 'b.d.e': 3}
    >>> unflatten_dict(d)
    {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    """
    result: Dict[Any, Any] = {}

    for key, value in d.items():
        parts = key.split(sep) if isinstance(key, str) else [key]
        current = result

        for depth, part in enumerate(parts[:-1]):
            node = current.setdefault(part, {})
            if not isinstance(node, MutableMapping):
                # Report the conflict clearly instead of failing later with an
                # unrelated TypeError from indexing into the leaf value.
                prefix = sep.join(parts[:depth + 1])
                raise ValueError(
                    f"Cannot unflatten key {key!r}: prefix {prefix!r} already "
                    f"holds a non-dict value of type {type(node).__name__}"
                )
            current = node

        current[parts[-1]] = value

    return result
|
962
|
-
|
963
|
-
|
964
|
-
def _is_not_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
    """Return True when ``x`` is NOT an instance of ``cls`` (a type or tuple of types)."""
    matches = isinstance(x, cls)
    return not matches
|
967
|
-
|
968
|
-
|
969
|
-
def _is_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
    """Return True when ``x`` is an instance of ``cls`` (a type or tuple of types)."""
    outcome = isinstance(x, cls)
    return outcome
|
972
|
-
|
973
|
-
|
974
|
-
@set_module_as('brainstate.util')
def not_instance_eval(*cls: Type) -> Callable[[Any], bool]:
    """
    Create a predicate that is True when its input is NOT an instance of any given class.

    Parameters
    ----------
    *cls : Type
        Classes to check against.

    Returns
    -------
    Callable
        A :func:`functools.partial` of ``_is_not_instance`` with ``cls`` bound
        to the tuple of classes. It returns True if the input is not an
        instance of any of them.

    Examples
    --------
    >>> not_int = not_instance_eval(int)
    >>> not_int(5)
    False
    >>> not_int("hello")
    True
    """
    predicate = functools.partial(_is_not_instance, cls=cls)
    return predicate
|
998
|
-
|
999
|
-
|
1000
|
-
@set_module_as('brainstate.util')
|
1001
|
-
def is_instance_eval(*cls: Type) -> Callable[[Any], bool]:
|
1002
|
-
"""
|
1003
|
-
Create a partial function to check if input IS an instance of given classes.
|
1004
|
-
|
1005
|
-
Parameters
|
1006
|
-
----------
|
1007
|
-
*cls : Type
|
1008
|
-
Classes to check against.
|
1009
|
-
|
1010
|
-
Returns
|
1011
|
-
-------
|
1012
|
-
Callable
|
1013
|
-
A function that returns True if input is an instance of any given class.
|
1014
|
-
|
1015
|
-
Examples
|
1016
|
-
--------
|
1017
|
-
>>> is_number = is_instance_eval(int, float)
|
1018
|
-
>>> is_number(5)
|
1019
|
-
True
|
1020
|
-
>>> is_number(3.14)
|
1021
|
-
True
|
1022
|
-
>>> is_number("hello")
|
1023
|
-
False
|
1024
|
-
"""
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""
|
17
|
+
Utility functions and classes for BrainState.
|
18
|
+
|
19
|
+
This module provides various utility functions and enhanced dictionary classes
|
20
|
+
for managing collections, memory, and object operations in the BrainState framework.
|
21
|
+
"""
|
22
|
+
|
23
|
+
import copy
|
24
|
+
import functools
|
25
|
+
import gc
|
26
|
+
import threading
|
27
|
+
import types
|
28
|
+
import warnings
|
29
|
+
from collections.abc import Iterable, Mapping, MutableMapping
|
30
|
+
from typing import (
|
31
|
+
Any, Callable, Dict, Iterator, List, Optional,
|
32
|
+
Tuple, Type, TypeVar, Union, overload
|
33
|
+
)
|
34
|
+
|
35
|
+
import jax
|
36
|
+
from jax.lib import xla_bridge
|
37
|
+
|
38
|
+
from brainstate._utils import set_module_as
|
39
|
+
|
40
|
+
# Public API of this module (order preserved).
__all__ = [
    'split_total',
    'clear_buffer_memory',
    'not_instance_eval',
    'is_instance_eval',
    'DictManager',
    'DotDict',
    'get_unique_name',
    'merge_dicts',
    'flatten_dict',
    'unflatten_dict',
]

# Generic type variables used in the signatures of this module.
T = TypeVar('T')
V = TypeVar('V')
K = TypeVar('K')
|
56
|
+
|
57
|
+
|
58
|
+
def split_total(
    total: int,
    fraction: Union[int, float],
) -> int:
    """
    Calculate the number of epochs for simulation based on a total and a fraction.

    This function determines the number of epochs to simulate given a total number
    of epochs and either a fraction or a specific number of epochs to run.

    Parameters
    ----------
    total : int
        The total number of epochs. Must be a positive integer. Any
        :class:`numbers.Integral` (e.g. ``numpy.int64``) is accepted; ``bool``
        is not.
    fraction : Union[int, float]
        If an integer (:class:`numbers.Integral`, excluding ``bool``): the
        specific number of epochs to run, which must not exceed ``total``.
        If any other real number (e.g. ``float``): a value between 0 and 1
        giving the fraction of ``total`` to run. The result is truncated
        toward zero.

    Returns
    -------
    int
        The calculated number of epochs to simulate.

    Raises
    ------
    TypeError
        If ``total`` is not an integer, or ``fraction`` is neither an integer
        nor a real number. ``bool`` is rejected for both.
    ValueError
        If ``total`` is not positive, or ``fraction`` is negative or NaN.
        Also if ``fraction`` is a real number > 1, or an integer > ``total``.

    Examples
    --------
    >>> split_total(100, 0.5)
    50
    >>> split_total(100, 25)
    25
    >>> split_total(100, 1.5)
    Traceback (most recent call last):
        ...
    ValueError: 'fraction' value cannot be greater than 1, got 1.5.
    """
    import math
    import numbers

    # ``bool`` subclasses ``int``, so ``split_total(True, 0.5)`` or
    # ``split_total(10, True)`` would otherwise be silently accepted.
    if isinstance(total, bool) or not isinstance(total, numbers.Integral):
        raise TypeError(f"'total' must be an integer, got {type(total).__name__}.")
    if total <= 0:
        raise ValueError(f"'total' must be a positive integer, got {total}.")

    if isinstance(fraction, bool):
        raise TypeError(f"'fraction' must be an integer or float, got {type(fraction).__name__}.")

    if isinstance(fraction, numbers.Integral):
        # An absolute number of epochs.
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > total:
            raise ValueError(f"'fraction' value cannot be greater than total ({total}), got {fraction}.")
        return int(fraction)

    if isinstance(fraction, numbers.Real):
        # A proportion of ``total``. NaN would pass both range checks below.
        if math.isnan(fraction):
            raise ValueError(f"'fraction' value cannot be NaN, got {fraction}.")
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > 1:
            raise ValueError(f"'fraction' value cannot be greater than 1, got {fraction}.")
        return int(total * fraction)

    raise TypeError(f"'fraction' must be an integer or float, got {type(fraction).__name__}.")
|
119
|
+
|
120
|
+
|
121
|
+
class NameContext(threading.local):
    """
    Per-thread registry of counters used to generate unique names.

    As a :class:`threading.local` subclass, ``__init__`` runs again in every
    thread that touches the instance, so each thread sees its own
    ``typed_names`` mapping.
    """

    def __init__(self):
        # Maps a type name to the next integer suffix to hand out.
        self.typed_names: Dict[str, int] = {}

    def reset(self, type_: Optional[str] = None) -> None:
        """
        Restart numbering.

        Parameters
        ----------
        type_ : str, optional
            If given, restart only this type's counter (a no-op when the type
            has never been used). Otherwise forget every counter.
        """
        if type_ is not None:
            if type_ in self.typed_names:
                self.typed_names[type_] = 0
            return
        self.typed_names.clear()


# Module-level (per-thread) registry read and updated by ``get_unique_name``.
NAME = NameContext()
|
136
|
+
|
137
|
+
|
138
|
+
@set_module_as('brainstate.util')
def get_unique_name(type_: str, prefix: str = '') -> str:
    """
    Get a unique name for the given object type.

    The numeric suffix is a per-type counter held in ``NAME.typed_names``. It
    is keyed by ``type_`` alone, so the prefix does not affect numbering.

    Parameters
    ----------
    type_ : str
        The base type name.
    prefix : str, optional
        Additional prefix to add before the type name.

    Returns
    -------
    str
        A unique name combining prefix, type, and counter.

    Examples
    --------
    >>> get_unique_name('layer')
    'layer0'
    >>> get_unique_name('layer', 'conv_')
    'conv_layer1'
    """
    counters = NAME.typed_names
    index = counters.get(type_, 0)
    counters[type_] = index + 1
    head = f'{prefix}{type_}' if prefix else type_
    return f'{head}{index}'
|
169
|
+
|
170
|
+
|
171
|
+
@jax.tree_util.register_pytree_node_class
class DictManager(dict, MutableMapping[K, V]):
    """
    Enhanced dictionary for managing collections in BrainState.

    DictManager extends the standard Python dict with additional methods for
    filtering, splitting, and managing collections of objects. It's registered
    as a JAX pytree node for compatibility with JAX transformations.

    Examples
    --------
    >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    >>> dm.subset(int)  # Get only integer values
    DictManager({'a': 1})
    >>> dm.unique()  # Get unique values only
    DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    """

    __module__ = 'brainstate.util'
    # Initialised to an empty dict by ``__init__``. ``add_unique_value``
    # decides uniqueness from the live values, not from this mapping.
    _val_id_to_key: Dict[int, Any]

    def __init__(self, *args, **kwargs):
        """Initialize DictManager with optional dict-like arguments."""
        super().__init__(*args, **kwargs)
        self._val_id_to_key = {}

    def subset(self, sep: Union[Type, Tuple[Type, ...], Callable[[Any], bool]]) -> 'DictManager':
        """
        Get a new DictManager with a subset of items based on value type or predicate.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...], Callable]
            If Type or Tuple of Types: Select values that are instances of these types.
            If Callable: Select values where sep(value) returns True.

        Returns
        -------
        DictManager
            A new DictManager containing only matching items.

        Examples
        --------
        >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
        >>> dm.subset(int)
        DictManager({'a': 1})
        >>> dm.subset(lambda x: isinstance(x, (int, float)))
        DictManager({'a': 1, 'b': 2.0})
        """
        gather = type(self)()
        if callable(sep) and not isinstance(sep, type):
            for k, v in self.items():
                if sep(v):
                    gather[k] = v
        else:
            for k, v in self.items():
                if isinstance(v, sep):
                    gather[k] = v
        return gather

    def not_subset(self, sep: Union[Type, Tuple[Type, ...]]) -> 'DictManager':
        """
        Get a new DictManager excluding items of specified types.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...]]
            Types to exclude from the result.

        Returns
        -------
        DictManager
            A new DictManager excluding items of specified types.
        """
        gather = type(self)()
        for k, v in self.items():
            if not isinstance(v, sep):
                gather[k] = v
        return gather

    def add_unique_key(self, key: K, val: V) -> None:
        """
        Add a new element ensuring the key maps to a unique value.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Raises
        ------
        ValueError
            If the key already exists with a different value (by identity).
        """
        self._check_elem(val)
        if key in self:
            if self[key] is not val:
                raise ValueError(
                    f"Key '{key}' already exists with a different value. "
                    f"Existing: {self[key]}, New: {val}"
                )
        else:
            self[key] = val

    def add_unique_value(self, key: K, val: V) -> bool:
        """
        Add a new element only if the value is unique across all entries.

        Uniqueness is by identity (``is``) and is checked against the current
        contents. Values passed to the constructor or inserted via ``update``
        or item assignment are therefore respected, and a value that was
        removed can be added again. The check is linear in the number of
        entries.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Returns
        -------
        bool
            True if the value was added (was unique), False otherwise.
        """
        self._check_elem(val)
        # A cached id -> key index cannot be trusted here. It misses values
        # added through any other route, and keeps ids of removed values,
        # which may be reused by new objects.
        if any(existing is val for existing in self.values()):
            return False
        self[key] = val
        return True

    def unique(self) -> 'DictManager':
        """
        Get a new DictManager with unique values only.

        If multiple keys map to the same value (by identity),
        only the first key-value pair is retained.

        Returns
        -------
        DictManager
            A new DictManager with unique values.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            v_id = id(v)
            if v_id not in seen:
                seen.add(v_id)
                gather[k] = v
        return gather

    def unique_(self) -> 'DictManager':
        """
        Remove duplicate values in-place.

        Returns
        -------
        DictManager
            Self, for method chaining.
        """
        seen = set()
        keys_to_remove = []
        for k, v in self.items():
            v_id = id(v)
            if v_id in seen:
                keys_to_remove.append(k)
            else:
                seen.add(v_id)

        for k in keys_to_remove:
            del self[k]
        return self

    def assign(self, *args: Dict[K, V], **kwargs: V) -> None:
        """
        Update the DictManager with multiple dictionaries.

        Parameters
        ----------
        *args : Dict
            Dictionaries to merge into this one.
        **kwargs
            Additional key-value pairs to add.

        Raises
        ------
        TypeError
            If a positional argument is not a ``dict``.
        """
        for arg in args:
            if not isinstance(arg, dict):
                raise TypeError(f"Arguments must be dict instances, got {type(arg).__name__}")
            self.update(arg)
        if kwargs:
            self.update(kwargs)

    def split(self, *types: Type) -> Tuple['DictManager', ...]:
        """
        Split the DictManager into multiple based on value types.

        Parameters
        ----------
        *types : Type
            Types to use for splitting. Each type gets its own DictManager.

        Returns
        -------
        Tuple[DictManager, ...]
            A tuple of DictManagers, one for each type plus one for unmatched items.
            Each item goes to the first matching type only.
        """
        results = tuple(type(self)() for _ in range(len(types) + 1))

        for k, v in self.items():
            for i, type_ in enumerate(types):
                if isinstance(v, type_):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v

        return results

    def filter_by_predicate(self, predicate: Callable[[K, V], bool]) -> 'DictManager':
        """
        Filter items using a predicate function.

        Parameters
        ----------
        predicate : Callable[[key, value], bool]
            Function that returns True for items to keep.

        Returns
        -------
        DictManager
            A new DictManager with filtered items.
        """
        return type(self)({k: v for k, v in self.items() if predicate(k, v)})

    def map_values(self, func: Callable[[V], Any]) -> 'DictManager':
        """
        Apply a function to all values.

        Parameters
        ----------
        func : Callable
            Function to apply to each value.

        Returns
        -------
        DictManager
            A new DictManager with transformed values.
        """
        return type(self)({k: func(v) for k, v in self.items()})

    def map_keys(self, func: Callable[[K], Any]) -> 'DictManager':
        """
        Apply a function to all keys.

        Parameters
        ----------
        func : Callable
            Function to apply to each key.

        Returns
        -------
        DictManager
            A new DictManager with transformed keys.

        Raises
        ------
        ValueError
            If the transformation creates duplicate keys.
        """
        result = type(self)()
        for k, v in self.items():
            new_key = func(k)
            if new_key in result:
                raise ValueError(f"Key transformation created duplicate: {new_key}")
            result[new_key] = v
        return result

    def pop_by_keys(self, keys: Iterable[K]) -> None:
        """Remove multiple keys from the DictManager; absent keys are ignored."""
        keys_set = set(keys)
        for k in list(self.keys()):
            if k in keys_set:
                self.pop(k)

    @staticmethod
    def _value_matcher(values: Iterable[V], by: str) -> Callable[[Any], bool]:
        """Build the membership test shared by the ``*_by_values`` methods."""
        if by == 'id':
            value_ids = {id(v) for v in values}
            return lambda v: id(v) in value_ids
        if by == 'value':
            if isinstance(values, set):
                pool = values
            else:
                # Materialise once: ``values`` may be a one-shot iterator.
                items = list(values)
                try:
                    pool = set(items)
                except TypeError:
                    # Unhashable targets (e.g. lists) cannot live in a set;
                    # fall back to a linear equality scan instead of failing.
                    pool = items
            return lambda v: v in pool
        raise ValueError(f"Invalid comparison method: {by}. Use 'id' or 'value'.")

    def pop_by_values(self, values: Iterable[V], by: str = 'id') -> None:
        """
        Remove items by their values.

        Parameters
        ----------
        values : Iterable
            Values to remove.
        by : str
            Comparison method: 'id' (identity) or 'value' (equality).

        Raises
        ------
        ValueError
            If ``by`` is neither 'id' nor 'value'.
        """
        matches = self._value_matcher(values, by)
        # Collect first: the dict must not change size while being iterated.
        keys_to_remove = [k for k, v in self.items() if matches(v)]
        for k in keys_to_remove:
            del self[k]

    def difference_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items not in the specified keys."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys_set})

    def difference_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are not in the specified collection ('id' or 'value' comparison)."""
        matches = self._value_matcher(values, by)
        return type(self)({k: v for k, v in self.items() if not matches(v)})

    def intersection_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items with keys in the specified collection."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k in keys_set})

    def intersection_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are in the specified collection ('id' or 'value' comparison)."""
        matches = self._value_matcher(values, by)
        return type(self)({k: v for k, v in self.items() if matches(v)})

    def __add__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the + operator."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __or__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the | operator (Python 3.9+)."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __ior__(self, other: Mapping[K, V]) -> 'DictManager':
        """Update in-place with another mapping using |= operator."""
        if not isinstance(other, Mapping):
            return NotImplemented
        self.update(other)
        return self

    def tree_flatten(self) -> Tuple[Tuple[V, ...], Tuple[K, ...]]:
        """Flatten for JAX pytree: values are children, keys are aux data."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys: Tuple[K, ...], values: Tuple[V, ...]) -> 'DictManager':
        """Unflatten from JAX pytree."""
        return cls(zip(keys, values))

    def _check_elem(self, elem: Any) -> None:
        """Override in subclasses to validate elements."""
        pass

    def to_dict(self) -> Dict[K, V]:
        """Convert to a standard Python dict."""
        return dict(self)

    def __copy__(self) -> 'DictManager':
        """Shallow copy."""
        return type(self)(self)

    def __deepcopy__(self, memo: Dict[int, Any]) -> 'DictManager':
        """
        Deep copy.

        The empty result is registered in ``memo`` before the contents are
        copied. Cyclic references back to this manager then resolve to the
        copy instead of recursing forever, and shared sub-objects stay shared.
        """
        result = type(self)()
        memo[id(self)] = result
        for k, v in self.items():
            result[copy.deepcopy(k, memo)] = copy.deepcopy(v, memo)
        return result

    def __repr__(self) -> str:
        """String representation."""
        items = ', '.join(f'{k!r}: {v!r}' for k, v in self.items())
        return f'{self.__class__.__name__}({{{items}}})'
|
567
|
+
|
568
|
+
|
569
|
+
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: Optional[str] = None,
    array: bool = True,
    compilation: bool = False,
) -> None:
    """
    Clear on-device memory buffers and optionally compilation cache.

    This function is useful when running models in loops to prevent memory leaks
    by clearing cached arrays and freeing device memory.

    .. warning::
       This operation may invalidate existing array references.
       Regenerate data after calling this function.

    Parameters
    ----------
    platform : str, optional
        The specific device platform to clear. If None, clears the default platform.
    array : bool, default=True
        Whether to clear array buffers.
    compilation : bool, default=False
        Whether to clear the compilation cache.

    Notes
    -----
    Buffer deletion is best effort. If live arrays cannot be enumerated or
    deleted, a ``RuntimeWarning`` is issued instead of raising.
    ``gc.collect()`` is always run at the end.

    Examples
    --------
    >>> clear_buffer_memory()  # Clear array buffers
    >>> clear_buffer_memory(compilation=True)  # Also clear compilation cache
    """
    if array:
        try:
            live_arrays = getattr(jax, 'live_arrays', None)
            if live_arrays is not None:
                buffers = live_arrays(platform)
            else:
                # Fallback for JAX versions without ``jax.live_arrays``.
                buffers = xla_bridge.get_backend(platform).live_buffers()
            for buf in buffers:
                buf.delete()
        except Exception as e:
            warnings.warn(f"Failed to clear buffers: {e}", RuntimeWarning, stacklevel=2)

    if compilation:
        jax.clear_caches()

    gc.collect()
|
611
|
+
|
612
|
+
|
613
|
+
@jax.tree_util.register_pytree_node_class
|
614
|
+
class DotDict(dict, MutableMapping[str, Any]):
|
615
|
+
"""
|
616
|
+
Dictionary with dot notation access to nested keys.
|
617
|
+
|
618
|
+
DotDict allows accessing dictionary items using attribute syntax,
|
619
|
+
making code more readable when dealing with nested configurations.
|
620
|
+
|
621
|
+
Examples
|
622
|
+
--------
|
623
|
+
>>> config = DotDict({'model': {'layers': 3, 'units': 64}})
|
624
|
+
>>> config.model.layers
|
625
|
+
3
|
626
|
+
>>> config.model.units = 128
|
627
|
+
>>> config['model']['units']
|
628
|
+
128
|
629
|
+
|
630
|
+
Attributes
|
631
|
+
----------
|
632
|
+
All dictionary keys become accessible as attributes unless they conflict
|
633
|
+
with built-in methods.
|
634
|
+
"""
|
635
|
+
|
636
|
+
__module__ = 'brainstate.util'
|
637
|
+
|
638
|
+
def __init__(self, *args, **kwargs):
    """
    Build the mapping from any mix of dicts, key/value pairs, and keywords.

    Parameters
    ----------
    *args
        Each argument may be a dict, a single ``(key, value)`` tuple, or an
        iterable of ``(key, value)`` pairs. Falsy arguments are skipped.
    **kwargs
        Extra key-value pairs, applied last. The special keywords
        ``__parent`` and ``__key`` are not stored as items. They are kept as
        raw instance attributes.

    Raises
    ------
    TypeError
        If a positional argument cannot be interpreted as key-value pairs.
    """
    # Stash the optional parent link as raw instance attributes, bypassing
    # the class's own __setattr__ (which would store them as items).
    parent_link = kwargs.pop('__parent', None)
    key_link = kwargs.pop('__key', None)
    object.__setattr__(self, '__parent', parent_link)
    object.__setattr__(self, '__key', key_link)

    def _store(pairs):
        # Every incoming value goes through ``_hook`` before being stored.
        for key, val in pairs:
            self[key] = self._hook(val)

    for source in args:
        if not source:
            continue
        if isinstance(source, dict):
            _store(source.items())
        elif isinstance(source, tuple) and len(source) == 2 and not isinstance(source[0], tuple):
            # A bare 2-tuple is read as a single (key, value) pair.
            _store((source,))
        else:
            try:
                _store(source)
            except (TypeError, ValueError) as e:
                raise TypeError(f"Invalid argument type for DotDict: {type(source).__name__}") from e

    _store(kwargs.items())
|
674
|
+
|
675
|
+
def __setattr__(self, name: str, value: Any) -> None:
    """
    Store ``value`` under the key ``name`` using attribute syntax.

    Raises
    ------
    AttributeError
        If ``name`` is already an attribute of the class (a method such as
        ``keys`` or ``update``, or a class attribute). Attribute syntax could
        never read such a key back, so the assignment is rejected.
    """
    owner_cls = self.__class__
    if not hasattr(owner_cls, name):
        self[name] = value
        return
    raise AttributeError(
        f"Cannot set attribute '{name}': it's a built-in method of {owner_cls.__name__}"
    )
|
682
|
+
|
683
|
+
def __setitem__(self, name: str, value: Any) -> None:
    """
    Store an item and, if this instance was created with a parent link,
    attach it to that parent.

    When ``__init__`` received a non-``None`` ``__parent`` together with a
    ``__key``, the first write registers this instance as
    ``parent[key]``. The link attributes are then removed, so the
    attachment happens only once. Instances without the link attributes
    (e.g. created without running ``__init__``) simply store the item.
    """
    super().__setitem__(name, value)
    try:
        owner = object.__getattribute__(self, '__parent')
        slot = object.__getattribute__(self, '__key')
        if owner is None:
            return
        owner[slot] = self
        object.__delattr__(self, '__parent')
        object.__delattr__(self, '__key')
    except AttributeError:
        return
|
695
|
+
|
696
|
+
@classmethod
def _hook(cls, item: Any) -> Any:
    """
    Recursively convert nested plain dicts into instances of ``cls``.

    Parameters
    ----------
    item : Any
        Value being stored.

    Returns
    -------
    Any
        - A ``dict`` that is not already a ``cls`` instance is wrapped in ``cls``.
        - Lists and tuples are rebuilt with the same container type, each
          element converted recursively.
        - Any other value is returned unchanged.
    """
    if isinstance(item, dict) and not isinstance(item, cls):
        return cls(item)
    if isinstance(item, (list, tuple)):
        converted = [cls._hook(elem) for elem in item]
        if isinstance(item, tuple) and hasattr(item, '_fields'):
            # Named tuples take one positional argument per field, not a
            # single iterable, so ``type(item)(converted)`` would raise.
            return type(item)(*converted)
        return type(item)(converted)
    return item
|
704
|
+
|
705
|
+
def __getattr__(self, item: str) -> Any:
    """
    Resolve ``item`` as a dictionary key.

    Python only calls this after normal attribute lookup fails, so methods
    and real instance attributes always take precedence over keys.

    Raises
    ------
    AttributeError
        If ``item`` is not a key of this mapping.
    """
    try:
        found = self[item]
    except KeyError:
        type_name = self.__class__.__name__
        raise AttributeError(f"'{type_name}' object has no attribute '{item}'")
    return found
|
711
|
+
|
712
|
+
def __delattr__(self, name: str) -> None:
    """
    Remove the key ``name`` using ``del obj.name`` syntax.

    Raises
    ------
    AttributeError
        If ``name`` is not a key of this mapping.
    """
    try:
        del self[name]
    except KeyError:
        type_name = self.__class__.__name__
        raise AttributeError(f"'{type_name}' object has no attribute '{name}'")
|
718
|
+
|
719
|
+
def __dir__(self) -> List[str]:
    """
    List attribute names, including string keys, for ``dir()`` and tab completion.

    Only ``str`` keys are included. ``dir()`` sorts the returned list, and
    mixing non-string keys (e.g. ints) with attribute names would raise
    ``TypeError``. Such keys cannot be reached with attribute syntax anyway.

    Returns
    -------
    list of str
        String keys followed by the names from ``dir(self.__class__)``.
    """
    key_names = [key for key in self.keys() if isinstance(key, str)]
    return key_names + dir(self.__class__)
|
722
|
+
|
723
|
+
def get(self, key: str, default: Any = None) -> Any:
    """
    Return the value stored under ``key``, or ``default`` if it is absent.

    Delegates to the parent class's ``get``. A missing key never raises.
    """
    parent_get = super().get
    return parent_get(key, default)
|
726
|
+
|
727
|
+
def copy(self) -> 'DotDict':
    """
    Return a shallow copy.

    Delegates to :func:`copy.copy`, which uses this class's
    ``__getstate__``/``__setstate__`` protocol.
    """
    duplicate = copy.copy(self)
    return duplicate
|
730
|
+
|
731
|
+
def deepcopy(self) -> 'DotDict':
    """
    Return a deep copy.

    Delegates to :func:`copy.deepcopy`, which dispatches to ``__deepcopy__``.
    """
    duplicate = copy.deepcopy(self)
    return duplicate
|
734
|
+
|
735
|
+
def __deepcopy__(self, memo: Dict[int, Any]) -> 'DotDict':
    """
    Support :func:`copy.deepcopy`.

    A fresh, empty instance of the same class is registered in ``memo``
    before any contents are copied. Self-references and shared sub-objects
    therefore resolve to the copy instead of recursing. Keys and values are
    deep-copied and assigned through ``__setitem__`` (no ``_hook``
    conversion).
    """
    clone = self.__class__()
    memo[id(self)] = clone
    for k, v in self.items():
        # Value first, then key: the same order as evaluating
        # ``clone[deepcopy(k)] = deepcopy(v)``.
        new_val = copy.deepcopy(v, memo)
        new_key = copy.deepcopy(k, memo)
        clone[new_key] = new_val
    return clone
|
742
|
+
|
743
|
+
def to_dict(self) -> Dict[str, Any]:
    """
    Convert to standard dict recursively.

    Nested DotDicts are converted with their own ``to_dict``. This happens
    at any depth inside lists and tuples, mirroring ``_hook``, which
    converts nested dicts at every depth. Named tuples are rebuilt
    positionally. Other values, including plain ``dict`` values, are kept
    as-is.

    Returns
    -------
    dict
        A standard Python dict with nested DotDicts also converted.
    """

    def _convert(obj: Any) -> Any:
        # Convert one value, descending into list/tuple containers.
        if isinstance(obj, DotDict):
            return obj.to_dict()
        if isinstance(obj, (list, tuple)):
            items = [_convert(elem) for elem in obj]
            if isinstance(obj, tuple) and hasattr(obj, '_fields'):
                # Named tuples take one positional argument per field.
                return type(obj)(*items)
            return type(obj)(items)
        return obj

    return {key: _convert(value) for key, value in self.items()}
|
764
|
+
|
765
|
+
@classmethod
def from_dict(cls, d: Dict[str, Any]) -> 'DotDict':
    """
    Build an instance from a standard dict.

    Equivalent to ``cls(d)``: nested dicts are converted via ``_hook``.

    Parameters
    ----------
    d : dict
        Standard Python dictionary.

    Returns
    -------
    DotDict
        A new instance of ``cls``.
    """
    instance = cls(d)
    return instance
|
781
|
+
|
782
|
+
def update(self, *args, **kwargs) -> None:
    """
    Update with recursive merge for nested dicts.

    Unlike :meth:`dict.update`, when both the existing value and the
    incoming value for a key are dicts, they are merged recursively. The
    incoming dict does not replace the existing one. An existing plain dict
    is first converted with ``_hook``. Every other incoming value is
    converted with ``_hook`` and stored, replacing any previous value.

    Parameters
    ----------
    *args
        At most one argument: a mapping (anything with ``items()``) or an
        iterable of ``(key, value)`` pairs, as accepted by
        :meth:`dict.update`.
    **kwargs
        Key-value pairs to merge; applied after ``args``.

    Raises
    ------
    TypeError
        If more than one positional argument is given, or the positional
        argument is neither a mapping nor an iterable of pairs.
    """
    if len(args) > 1:
        raise TypeError(f"update expected at most 1 argument, got {len(args)}")
    other = args[0] if args else {}

    if hasattr(other, 'items'):
        incoming = dict(other.items())
    else:
        # Iterable of (key, value) pairs, as dict.update accepts; lists and
        # tuples have no ``update`` method, so normalise to a dict first.
        incoming = dict(other)
    incoming.update(kwargs)

    for k, v in incoming.items():
        if k in self and isinstance(self[k], dict) and isinstance(v, dict):
            # Recursive merge for nested dicts
            if not isinstance(self[k], DotDict):
                self[k] = self._hook(self[k])
            self[k].update(v)
        else:
            self[k] = self._hook(v)
|
814
|
+
|
815
|
+
def setdefault(self, key: str, default: Any = None) -> Any:
    """
    Return ``self[key]``, storing ``default`` under ``key`` first if absent.

    ``default`` is stored as given; it is not passed through ``_hook``.
    """
    if key in self:
        return self[key]
    self[key] = default
    return self[key]
|
820
|
+
|
821
|
+
def __getstate__(self) -> Dict[str, Any]:
    """Return a plain ``dict`` snapshot of the items, used by pickle and copy."""
    snapshot = dict(self)
    return snapshot
|
824
|
+
|
825
|
+
def __setstate__(self, state: Dict[str, Any]) -> None:
    """
    Restore items from a pickled/copied state.

    The state is merged in via ``update``, so nested dicts are merged
    recursively and converted with ``_hook``.
    """
    merge_in = self.update
    merge_in(state)
|
828
|
+
|
829
|
+
def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    Flatten for JAX pytree.

    Returns
    -------
    tuple
        ``(children, aux_data)``: the values and the keys, both in insertion
        order. This matches the argument order expected by
        ``tree_unflatten``.
    """
    names, leaves = [], []
    for name, leaf in self.items():
        names.append(name)
        leaves.append(leaf)
    return tuple(leaves), tuple(names)
|
832
|
+
|
833
|
+
@classmethod
def tree_unflatten(cls, keys: Tuple[str, ...], values: Tuple[Any, ...]) -> 'DotDict':
    """
    Unflatten from JAX pytree.

    Pairs ``keys`` (the aux data returned by ``tree_flatten``) with
    ``values`` (the children) positionally and builds a new instance. The
    values pass through ``__init__``, so ``_hook`` conversion applies.
    """
    pairs = zip(keys, values)
    return cls(pairs)
|
837
|
+
|
838
|
+
def __repr__(self) -> str:
    """
    Return ``DotDict({...})`` with each key and value shown via ``repr``.

    The prefix is always ``DotDict``, even for subclasses.
    """
    parts = []
    for k, v in self.items():
        parts.append(f'{k!r}: {v!r}')
    body = ', '.join(parts)
    return 'DotDict({' + body + '})'
|
842
|
+
|
843
|
+
|
844
|
+
@set_module_as('brainstate.util')
def merge_dicts(*dicts: Dict[K, V], recursive: bool = True) -> Dict[K, V]:
    """
    Merge multiple dictionaries into a new one.

    Inputs are not mutated. With ``recursive=True``, a key whose existing and
    incoming values are both dicts gets a freshly merged dict. Otherwise the
    later value replaces the earlier one. Values that are not merged are
    stored by reference (shallow).

    Parameters
    ----------
    *dicts : Dict
        Dictionaries to merge (later ones override earlier ones).
    recursive : bool, default=True
        Whether to recursively merge nested dicts.

    Returns
    -------
    Dict
        Merged dictionary.

    Raises
    ------
    TypeError
        If any positional argument is not a ``dict``.

    Examples
    --------
    >>> d1 = {'a': 1, 'b': {'c': 2}}
    >>> d2 = {'b': {'d': 3}, 'e': 4}
    >>> merge_dicts(d1, d2)
    {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
    """
    merged = {}
    for source in dicts:
        if not isinstance(source, dict):
            raise TypeError(f"All arguments must be dicts, got {type(source).__name__}")

        for key, incoming in source.items():
            existing = merged.get(key)
            can_merge = recursive and isinstance(existing, dict) and isinstance(incoming, dict)
            merged[key] = merge_dicts(existing, incoming, recursive=True) if can_merge else incoming

    return merged
|
881
|
+
|
882
|
+
|
883
|
+
@set_module_as('brainstate.util')
def flatten_dict(
    d: Dict[str, Any],
    parent_key: str = '',
    sep: str = '.'
) -> Dict[str, Any]:
    """
    Flatten a nested dictionary into one level of joined keys.

    Nested ``dict`` values are expanded recursively, and their keys are
    joined to the parent key with ``sep``. Empty nested dicts contribute no
    entries. If two paths produce the same joined key, the later value wins.

    Parameters
    ----------
    d : Dict
        Dictionary to flatten.
    parent_key : str, default=''
        Prefix for keys.
    sep : str, default='.'
        Separator between nested keys.

    Returns
    -------
    Dict
        Flattened dictionary.

    Examples
    --------
    >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    >>> flatten_dict(d)
    {'a': 1, 'b.c': 2, 'b.d.e': 3}
    """
    flat = {}
    for k, v in d.items():
        full_key = k if not parent_key else f"{parent_key}{sep}{k}"
        if isinstance(v, dict):
            flat.update(flatten_dict(v, full_key, sep=sep))
        else:
            flat[full_key] = v
    return flat
|
920
|
+
|
921
|
+
|
922
|
+
@set_module_as('brainstate.util')
def unflatten_dict(
    d: Dict[str, Any],
    sep: str = '.'
) -> Dict[str, Any]:
    """
    Unflatten a dictionary with separated keys.

    This is the inverse of :func:`flatten_dict`. Each string key is split on
    ``sep``, and the value is stored under the resulting chain of nested
    dicts. Non-string keys are kept as single top-level keys; ``flatten_dict``
    passes them through unchanged at the top level. Entries are applied in
    order, so a later plain key such as ``'a'`` replaces a nested dict built
    from earlier ``'a.*'`` keys.

    Parameters
    ----------
    d : Dict
        Flattened dictionary.
    sep : str, default='.'
        Separator in keys.

    Returns
    -------
    Dict
        Nested dictionary.

    Raises
    ------
    TypeError
        If a key path runs through an entry that already holds a
        non-mapping value, e.g. ``{'a': 1, 'a.b': 2}``.

    Examples
    --------
    >>> d = {'a': 1, 'b.c': 2, 'b.d.e': 3}
    >>> unflatten_dict(d)
    {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
    """
    from collections.abc import MutableMapping

    result = {}

    for key, value in d.items():
        parts = key.split(sep) if isinstance(key, str) else [key]
        current = result

        for part in parts[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]
            if not isinstance(current, MutableMapping):
                # A previous entry stored a leaf where this path needs a dict.
                raise TypeError(
                    f"Cannot unflatten key {key!r}: {part!r} already holds a "
                    f"non-mapping value of type {type(current).__name__}"
                )

        current[parts[-1]] = value

    return result
|
962
|
+
|
963
|
+
|
964
|
+
def _is_not_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
    """Return ``True`` when ``x`` is not an instance of ``cls`` (a type or a tuple of types)."""
    matches = isinstance(x, cls)
    return not matches
|
967
|
+
|
968
|
+
|
969
|
+
def _is_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
    """Return ``True`` when ``x`` is an instance of ``cls`` (a type or a tuple of types)."""
    matches = isinstance(x, cls)
    return matches
|
972
|
+
|
973
|
+
|
974
|
+
@set_module_as('brainstate.util')
def not_instance_eval(*cls: Type) -> Callable[[Any], bool]:
    """
    Create a partial function to check if input is NOT an instance of given classes.

    The classes are bound as a tuple to ``_is_not_instance`` via
    :func:`functools.partial`. With no classes given, the predicate returns
    ``True`` for every input.

    Parameters
    ----------
    *cls : Type
        Classes to check against.

    Returns
    -------
    Callable
        A function that returns True if input is not an instance of any given class.

    Examples
    --------
    >>> not_int = not_instance_eval(int)
    >>> not_int(5)
    False
    >>> not_int("hello")
    True
    """
    classes = cls
    predicate = functools.partial(_is_not_instance, cls=classes)
    return predicate
|
998
|
+
|
999
|
+
|
1000
|
+
@set_module_as('brainstate.util')
def is_instance_eval(*cls: Type) -> Callable[[Any], bool]:
    """
    Create a partial function to check if input IS an instance of given classes.

    The classes are bound as a tuple to ``_is_instance`` via
    :func:`functools.partial`. With no classes given, the predicate returns
    ``False`` for every input.

    Parameters
    ----------
    *cls : Type
        Classes to check against.

    Returns
    -------
    Callable
        A function that returns True if input is an instance of any given class.

    Examples
    --------
    >>> is_number = is_instance_eval(int, float)
    >>> is_number(5)
    True
    >>> is_number(3.14)
    True
    >>> is_number("hello")
    False
    """
    classes = cls
    predicate = functools.partial(_is_instance, cls=classes)
    return predicate
|