brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1025 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """
17
+ Utility functions and classes for BrainState.
18
+
19
+ This module provides various utility functions and enhanced dictionary classes
20
+ for managing collections, memory, and object operations in the BrainState framework.
21
+ """
22
+
23
+ import copy
24
+ import functools
25
+ import gc
26
+ import threading
27
+ import types
28
+ import warnings
29
+ from collections.abc import Iterable, Mapping, MutableMapping
30
+ from typing import (
31
+ Any, Callable, Dict, Iterator, List, Optional,
32
+ Tuple, Type, TypeVar, Union, overload
33
+ )
34
+
35
+ import jax
36
+ from jax.lib import xla_bridge
37
+
38
+ from brainstate._utils import set_module_as
39
+
40
# Public API of this module, consumed by ``from ... import *``.
# NOTE(review): several names below (e.g. 'not_instance_eval', 'merge_dicts',
# 'flatten_dict') are not defined in the visible part of this file; presumably
# they are defined further down — confirm, since a star-import fails on any
# listed name that does not exist.
__all__ = [
    'split_total',
    'clear_buffer_memory',
    'not_instance_eval',
    'is_instance_eval',
    'DictManager',
    'DotDict',
    'get_unique_name',
    'merge_dicts',
    'flatten_dict',
    'unflatten_dict',
]
52
+
53
# Generic type parameters for annotations in this module: ``K``/``V``
# parameterize mapping keys/values (see ``DictManager``); ``T`` is a general
# placeholder.
K = TypeVar('K')
V = TypeVar('V')
T = TypeVar('T')
56
+
57
+
58
def split_total(
    total: int,
    fraction: Union[int, float],
) -> int:
    """
    Calculate the number of epochs for simulation based on a total and a fraction.

    This function determines the number of epochs to simulate given a total number
    of epochs and either a fraction or a specific number of epochs to run.

    Parameters
    ----------
    total : int
        The total number of epochs. Must be a positive integer. Any integral
        type registered with :class:`numbers.Integral` (e.g. NumPy integer
        scalars) is accepted.
    fraction : Union[int, float]
        If integral: the specific number of epochs to run; must lie in
        ``[0, total]``.
        If a non-integral real number (e.g. ``float``): the fraction of total
        epochs to run; must lie in ``[0, 1]``. The product ``total * fraction``
        is truncated toward zero.

    Returns
    -------
    int
        The calculated number of epochs to simulate, always a plain ``int``.

    Raises
    ------
    TypeError
        If ``total`` is not an integer, or ``fraction`` is not a real number.
    ValueError
        If ``total`` is not positive, ``fraction`` is negative or NaN, or if
        ``fraction`` as a real number is > 1 or as an integer is > ``total``.

    Examples
    --------
    >>> split_total(100, 0.5)
    50
    >>> split_total(100, 25)
    25
    >>> split_total(100, 1.5)
    Traceback (most recent call last):
        ...
    ValueError: 'fraction' value cannot be greater than 1, got 1.5.
    """
    import math
    import numbers

    if not isinstance(total, numbers.Integral):
        raise TypeError(f"'total' must be an integer, got {type(total).__name__}.")
    # Normalize NumPy/bool integers to a plain int for comparisons and return.
    total = int(total)
    if total <= 0:
        raise ValueError(f"'total' must be a positive integer, got {total}.")

    # Integral types are checked first: every integer is also a numbers.Real.
    if isinstance(fraction, numbers.Integral):
        fraction = int(fraction)
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > total:
            raise ValueError(f"'fraction' value cannot be greater than total ({total}), got {fraction}.")
        return fraction

    if isinstance(fraction, numbers.Real):
        # NaN compares False against every bound, so it must be rejected explicitly.
        if math.isnan(fraction):
            raise ValueError(f"'fraction' value cannot be NaN, got {fraction}.")
        if fraction < 0:
            raise ValueError(f"'fraction' value cannot be negative, got {fraction}.")
        if fraction > 1:
            raise ValueError(f"'fraction' value cannot be greater than 1, got {fraction}.")
        return int(total * fraction)

    raise TypeError(f"'fraction' must be an integer or float, got {type(fraction).__name__}.")
119
+
120
+
121
class NameContext(threading.local):
    """
    Per-thread registry of counters used to mint unique names.

    Because this subclasses ``threading.local``, every thread that uses an
    instance sees its own ``typed_names`` mapping.
    """

    def __init__(self):
        # Maps a type label to the next integer suffix to hand out.
        self.typed_names: Dict[str, int] = dict()

    def reset(self, type_: Optional[str] = None) -> None:
        """Zero one existing counter, or drop every counter when ``type_`` is None."""
        counters = self.typed_names
        if type_ is not None:
            # Unknown labels are ignored rather than created.
            if type_ in counters:
                counters[type_] = 0
            return
        counters.clear()


# Module-wide registry; per-thread because NameContext is a threading.local.
NAME = NameContext()
136
+
137
+
138
@set_module_as('brainstate.util')
def get_unique_name(type_: str, prefix: str = '') -> str:
    """
    Return a fresh name built from ``prefix``, ``type_`` and a running counter.

    The counter is kept per ``type_`` in the module-level ``NAME`` registry
    (a ``NameContext``) and advances by one on every call. ``prefix`` only
    decorates the result; it does not get a counter of its own.

    Parameters
    ----------
    type_ : str
        Base type label; selects which counter is consumed.
    prefix : str, optional
        Text placed in front of ``type_`` when it is non-empty.

    Returns
    -------
    str
        The prefix (if any), then ``type_``, then the counter value as it was
        before this call incremented it.

    Examples
    --------
    >>> get_unique_name('layer')
    'layer0'
    >>> get_unique_name('layer', 'conv_')
    'conv_layer1'
    """
    counters = NAME.typed_names
    index = counters.setdefault(type_, 0)
    counters[type_] = index + 1
    stem = f'{prefix}{type_}' if prefix else type_
    return f'{stem}{index}'
169
+
170
+
171
@jax.tree_util.register_pytree_node_class
class DictManager(dict, MutableMapping[K, V]):
    """
    Enhanced dictionary for managing collections in BrainState.

    DictManager extends the standard Python dict with additional methods for
    filtering, splitting, and managing collections of objects. It's registered
    as a JAX pytree node for compatibility with JAX transformations: the values
    are the pytree children and the keys are carried as auxiliary data.

    The "unique" helpers (``add_unique_value``, ``unique``, ``unique_``)
    compare values by identity (``id``), not by equality.

    Examples
    --------
    >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    >>> dm.subset(int)  # Get only integer values
    DictManager({'a': 1})
    >>> dm.unique()  # Get unique values only
    DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
    """

    __module__ = 'brainstate.util'

    # Cache ``id(value) -> key`` used by ``add_unique_value``. It is NOT the
    # source of truth: ordinary dict operations (constructor, ``update``,
    # ``del``, ``pop``, copies, unpickling) bypass it, so it is re-validated
    # against the dict contents before being trusted.
    _val_id_to_key: Dict[int, Any]

    def __init__(self, *args, **kwargs):
        """Initialize DictManager with optional dict-like arguments."""
        super().__init__(*args, **kwargs)
        self._val_id_to_key = {}

    def subset(self, sep: Union[Type, Tuple[Type, ...], Callable[[Any], bool]]) -> 'DictManager':
        """
        Get a new DictManager with a subset of items based on value type or predicate.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...], Callable]
            If Type or Tuple of Types: Select values that are instances of these types.
            If Callable: Select values where sep(value) returns True.

        Returns
        -------
        DictManager
            A new DictManager containing only matching items.

        Examples
        --------
        >>> dm = DictManager({'a': 1, 'b': 2.0, 'c': 'text'})
        >>> dm.subset(int)
        DictManager({'a': 1})
        >>> dm.subset(lambda x: isinstance(x, (int, float)))
        DictManager({'a': 1, 'b': 2.0})
        """
        gather = type(self)()
        if callable(sep) and not isinstance(sep, type):
            for k, v in self.items():
                if sep(v):
                    gather[k] = v
        else:
            for k, v in self.items():
                if isinstance(v, sep):
                    gather[k] = v
        return gather

    def not_subset(self, sep: Union[Type, Tuple[Type, ...], Callable[[Any], bool]]) -> 'DictManager':
        """
        Get a new DictManager excluding items of specified types or matching a predicate.

        This is the complement of :meth:`subset` and accepts the same forms
        of ``sep``.

        Parameters
        ----------
        sep : Union[Type, Tuple[Type, ...], Callable]
            If Type or Tuple of Types: exclude values that are instances of these types.
            If Callable: exclude values where sep(value) returns True.

        Returns
        -------
        DictManager
            A new DictManager without the excluded items.
        """
        gather = type(self)()
        if callable(sep) and not isinstance(sep, type):
            for k, v in self.items():
                if not sep(v):
                    gather[k] = v
        else:
            for k, v in self.items():
                if not isinstance(v, sep):
                    gather[k] = v
        return gather

    def add_unique_key(self, key: K, val: V) -> None:
        """
        Add a new element ensuring the key maps to a unique value.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Raises
        ------
        ValueError
            If the key already exists with a different value (by identity).
        """
        self._check_elem(val)
        if key in self:
            if val is not self[key]:
                raise ValueError(
                    f"Key '{key}' already exists with a different value. "
                    f"Existing: {self[key]}, New: {val}"
                )
        else:
            self[key] = val

    def _rebuild_value_index(self) -> Dict[int, Any]:
        """Recompute the ``id(value) -> key`` cache from the current contents."""
        index = {id(v): k for k, v in self.items()}
        self._val_id_to_key = index
        return index

    def _value_index(self) -> Dict[int, Any]:
        """Return the ``id(value) -> key`` cache, rebuilding it when visibly stale."""
        index = getattr(self, '_val_id_to_key', None)
        # A size mismatch means entries were added or removed through ordinary
        # dict operations that bypass the cache.
        if index is None or len(index) != len(self):
            index = self._rebuild_value_index()
        return index

    def add_unique_value(self, key: K, val: V) -> bool:
        """
        Add a new element only if the value is unique across all entries.

        Uniqueness is by identity. Values inserted through ordinary dict
        operations (constructor, ``update``, item assignment) are taken into
        account, as are deletions. NOTE(review): rebinding an existing key to
        a different object via plain item assignment, without changing the
        number of entries, is not detected until the cache is next rebuilt.

        Parameters
        ----------
        key : Any
            The key to add.
        val : Any
            The value to associate with the key.

        Returns
        -------
        bool
            True if the value was added (was unique), False otherwise.
        """
        self._check_elem(val)
        index = self._value_index()
        val_id = id(val)
        if val_id in index:
            known_key = index[val_id]
            if known_key in self and self[known_key] is val:
                return False
            # Stale entry (key removed/rebound, or id reused after GC): resync.
            index = self._rebuild_value_index()
            if val_id in index:
                return False
        index[val_id] = key
        self[key] = val
        return True

    def unique(self) -> 'DictManager':
        """
        Get a new DictManager with unique values only.

        If multiple keys map to the same value (by identity),
        only the first key-value pair is retained.

        Returns
        -------
        DictManager
            A new DictManager with unique values.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            v_id = id(v)
            if v_id not in seen:
                seen.add(v_id)
                gather[k] = v
        return gather

    def unique_(self) -> 'DictManager':
        """
        Remove duplicate values (by identity) in-place, keeping the first key.

        Returns
        -------
        DictManager
            Self, for method chaining.
        """
        seen = set()
        keys_to_remove = []
        for k, v in self.items():
            v_id = id(v)
            if v_id in seen:
                keys_to_remove.append(k)
            else:
                seen.add(v_id)

        for k in keys_to_remove:
            del self[k]
        return self

    def assign(self, *args: Dict[K, V], **kwargs: V) -> None:
        """
        Update the DictManager with multiple dictionaries.

        Parameters
        ----------
        *args : Dict
            Dictionaries to merge into this one.
        **kwargs
            Additional key-value pairs to add.

        Raises
        ------
        TypeError
            If a positional argument is not a ``dict``.
        """
        for arg in args:
            if not isinstance(arg, dict):
                raise TypeError(f"Arguments must be dict instances, got {type(arg).__name__}")
            self.update(arg)
        if kwargs:
            self.update(kwargs)

    def split(self, *types: Type) -> Tuple['DictManager', ...]:
        """
        Split the DictManager into multiple based on value types.

        Parameters
        ----------
        *types : Type
            Types to use for splitting. Each type gets its own DictManager;
            an item goes to the first type it is an instance of.

        Returns
        -------
        Tuple[DictManager, ...]
            A tuple of DictManagers, one for each type plus one for unmatched items.
        """
        results = tuple(type(self)() for _ in range(len(types) + 1))

        for k, v in self.items():
            for i, type_ in enumerate(types):
                if isinstance(v, type_):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v

        return results

    def filter_by_predicate(self, predicate: Callable[[K, V], bool]) -> 'DictManager':
        """
        Filter items using a predicate function.

        Parameters
        ----------
        predicate : Callable[[key, value], bool]
            Function that returns True for items to keep.

        Returns
        -------
        DictManager
            A new DictManager with filtered items.
        """
        return type(self)({k: v for k, v in self.items() if predicate(k, v)})

    def map_values(self, func: Callable[[V], Any]) -> 'DictManager':
        """
        Apply a function to all values.

        Parameters
        ----------
        func : Callable
            Function to apply to each value.

        Returns
        -------
        DictManager
            A new DictManager with transformed values.
        """
        return type(self)({k: func(v) for k, v in self.items()})

    def map_keys(self, func: Callable[[K], Any]) -> 'DictManager':
        """
        Apply a function to all keys.

        Parameters
        ----------
        func : Callable
            Function to apply to each key.

        Returns
        -------
        DictManager
            A new DictManager with transformed keys.

        Raises
        ------
        ValueError
            If the transformation creates duplicate keys.
        """
        result = type(self)()
        for k, v in self.items():
            new_key = func(k)
            if new_key in result:
                raise ValueError(f"Key transformation created duplicate: {new_key}")
            result[new_key] = v
        return result

    def pop_by_keys(self, keys: Iterable[K]) -> None:
        """Remove multiple keys from the DictManager (missing keys are ignored)."""
        keys_set = set(keys)
        for k in list(self.keys()):
            if k in keys_set:
                self.pop(k)

    def _keys_matching_values(self, values: Iterable[V], by: str) -> List[K]:
        """Return, in insertion order, the keys whose values occur in ``values``."""
        if by == 'id':
            value_ids = {id(v) for v in values}
            return [k for k, v in self.items() if id(v) in value_ids]
        if by == 'value':
            if isinstance(values, set):
                pool = values
            else:
                candidates = list(values)
                try:
                    pool = set(candidates)
                except TypeError:
                    # Unhashable candidates: fall back to a linear ``==`` search.
                    pool = candidates
            return [k for k, v in self.items() if v in pool]
        raise ValueError(f"Invalid comparison method: {by}. Use 'id' or 'value'.")

    def pop_by_values(self, values: Iterable[V], by: str = 'id') -> None:
        """
        Remove items by their values.

        Parameters
        ----------
        values : Iterable
            Values to remove.
        by : str
            Comparison method: 'id' (identity) or 'value' (equality). With
            'value', unhashable candidate values are compared with ``==``.

        Raises
        ------
        ValueError
            If ``by`` is neither 'id' nor 'value'.
        """
        for k in self._keys_matching_values(values, by):
            del self[k]

    def difference_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items not in the specified keys."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys_set})

    def difference_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are not in the specified collection ('id' or 'value' comparison)."""
        matched = set(self._keys_matching_values(values, by))
        return type(self)({k: v for k, v in self.items() if k not in matched})

    def intersection_by_keys(self, keys: Iterable[K]) -> 'DictManager':
        """Get items with keys in the specified collection."""
        keys_set = set(keys)
        return type(self)({k: v for k, v in self.items() if k in keys_set})

    def intersection_by_values(self, values: Iterable[V], by: str = 'id') -> 'DictManager':
        """Get items whose values are in the specified collection ('id' or 'value' comparison)."""
        return type(self)({k: self[k] for k in self._keys_matching_values(values, by)})

    def __add__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the + operator."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __or__(self, other: Mapping[K, V]) -> 'DictManager':
        """Combine with another mapping using the | operator (Python 3.9+)."""
        if not isinstance(other, Mapping):
            return NotImplemented
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def __ior__(self, other: Mapping[K, V]) -> 'DictManager':
        """Update in-place with another mapping using |= operator."""
        if not isinstance(other, Mapping):
            return NotImplemented
        self.update(other)
        return self

    def tree_flatten(self) -> Tuple[Tuple[V, ...], Tuple[K, ...]]:
        """Flatten for JAX pytree: values are children, keys are auxiliary data."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys: Tuple[K, ...], values: Tuple[V, ...]) -> 'DictManager':
        """Unflatten from JAX pytree."""
        return cls(zip(keys, values))

    def _check_elem(self, elem: Any) -> None:
        """Override in subclasses to validate elements."""
        pass

    def to_dict(self) -> Dict[K, V]:
        """Convert to a standard Python dict."""
        return dict(self)

    def __copy__(self) -> 'DictManager':
        """Shallow copy."""
        return type(self)(self)

    def __deepcopy__(self, memo: Dict[int, Any]) -> 'DictManager':
        """Deep copy that supports self-referential contents."""
        result = type(self)()
        # Register before recursing so references back to ``self`` resolve
        # to the copy instead of recursing forever.
        memo[id(self)] = result
        for k, v in self.items():
            result[copy.deepcopy(k, memo)] = copy.deepcopy(v, memo)
        return result

    def __repr__(self) -> str:
        """String representation."""
        items = ', '.join(f'{k!r}: {v!r}' for k, v in self.items())
        return f'{self.__class__.__name__}({{{items}}})'
567
+
568
+
569
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: Optional[str] = None,
    array: bool = True,
    compilation: bool = False,
) -> None:
    """
    Clear on-device memory buffers and optionally the compilation cache.

    This function is useful when running models in loops to prevent memory leaks
    by clearing cached arrays and freeing device memory.

    .. warning::
        This operation may invalidate existing array references.
        Regenerate data after calling this function.

    Parameters
    ----------
    platform : str, optional
        The specific device platform to clear. If None, clears the default platform.
    array : bool, default=True
        Whether to delete live array buffers.
    compilation : bool, default=False
        Whether to clear the compilation cache (via ``jax.clear_caches``).

    Notes
    -----
    Live arrays are enumerated with ``jax.live_arrays`` when that function is
    available, falling back to the backend client's ``live_buffers()``
    otherwise. Arrays that are already deleted are skipped. A failure to
    delete one buffer does not stop the remaining ones from being deleted.
    Failures are reported as a ``RuntimeWarning`` rather than raised. A
    garbage-collection pass is always run at the end.

    Examples
    --------
    >>> clear_buffer_memory()  # Clear array buffers
    >>> clear_buffer_memory(compilation=True)  # Also clear compilation cache
    """
    if array:
        try:
            if hasattr(jax, 'live_arrays'):
                live = list(jax.live_arrays(platform))
            else:
                # Older JAX releases: enumerate through the backend client.
                live = list(xla_bridge.get_backend(platform).live_buffers())
        except Exception as e:
            warnings.warn(f"Failed to clear buffers: {e}", RuntimeWarning)
            live = []

        errors = []
        for buf in live:
            try:
                is_deleted = getattr(buf, 'is_deleted', None)
                if is_deleted is not None and is_deleted():
                    continue
                buf.delete()
            except Exception as e:
                # Keep going so that as much memory as possible is freed.
                errors.append(e)
        if errors:
            warnings.warn(
                f"Failed to clear buffers: {errors[0]} "
                f"({len(errors)} buffer(s) could not be deleted)",
                RuntimeWarning,
            )

    if compilation:
        jax.clear_caches()

    gc.collect()
611
+
612
+
613
+ @jax.tree_util.register_pytree_node_class
614
+ class DotDict(dict, MutableMapping[str, Any]):
615
+ """
616
+ Dictionary with dot notation access to nested keys.
617
+
618
+ DotDict allows accessing dictionary items using attribute syntax,
619
+ making code more readable when dealing with nested configurations.
620
+
621
+ Examples
622
+ --------
623
+ >>> config = DotDict({'model': {'layers': 3, 'units': 64}})
624
+ >>> config.model.layers
625
+ 3
626
+ >>> config.model.units = 128
627
+ >>> config['model']['units']
628
+ 128
629
+
630
+ Attributes
631
+ ----------
632
+ All dictionary keys become accessible as attributes unless they conflict
633
+ with built-in methods.
634
+ """
635
+
636
+ __module__ = 'brainstate.util'
637
+
638
+ def __init__(self, *args, **kwargs):
639
+ """
640
+ Initialize DotDict with dict-like arguments.
641
+
642
+ Parameters
643
+ ----------
644
+ *args
645
+ Positional arguments (dicts, iterables of pairs).
646
+ **kwargs
647
+ Keyword arguments become key-value pairs.
648
+ """
649
+ # Handle parent reference for nested updates
650
+ object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
651
+ object.__setattr__(self, '__key', kwargs.pop('__key', None))
652
+
653
+ # Process positional arguments
654
+ for arg in args:
655
+ if not arg:
656
+ continue
657
+ elif isinstance(arg, dict):
658
+ for key, val in arg.items():
659
+ self[key] = self._hook(val)
660
+ elif isinstance(arg, tuple) and len(arg) == 2 and not isinstance(arg[0], tuple):
661
+ # Single key-value pair
662
+ self[arg[0]] = self._hook(arg[1])
663
+ else:
664
+ # Iterable of key-value pairs
665
+ try:
666
+ for key, val in arg:
667
+ self[key] = self._hook(val)
668
+ except (TypeError, ValueError) as e:
669
+ raise TypeError(f"Invalid argument type for DotDict: {type(arg).__name__}") from e
670
+
671
+ # Process keyword arguments
672
+ for key, val in kwargs.items():
673
+ self[key] = self._hook(val)
674
+
675
+ def __setattr__(self, name: str, value: Any) -> None:
676
+ """Set attribute as dictionary item."""
677
+ if hasattr(self.__class__, name):
678
+ raise AttributeError(
679
+ f"Cannot set attribute '{name}': it's a built-in method of {self.__class__.__name__}"
680
+ )
681
+ self[name] = value
682
+
683
+ def __setitem__(self, name: str, value: Any) -> None:
684
+ """Set item and update parent if nested."""
685
+ super().__setitem__(name, value)
686
+ try:
687
+ parent = object.__getattribute__(self, '__parent')
688
+ key = object.__getattribute__(self, '__key')
689
+ if parent is not None:
690
+ parent[key] = self
691
+ object.__delattr__(self, '__parent')
692
+ object.__delattr__(self, '__key')
693
+ except AttributeError:
694
+ pass
695
+
696
+ @classmethod
697
+ def _hook(cls, item: Any) -> Any:
698
+ """Convert nested dicts to DotDict."""
699
+ if isinstance(item, dict) and not isinstance(item, cls):
700
+ return cls(item)
701
+ elif isinstance(item, (list, tuple)):
702
+ return type(item)(cls._hook(elem) for elem in item)
703
+ return item
704
+
705
+ def __getattr__(self, item: str) -> Any:
706
+ """Get attribute from dictionary."""
707
+ try:
708
+ return self[item]
709
+ except KeyError:
710
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
711
+
712
+ def __delattr__(self, name: str) -> None:
713
+ """Delete attribute from dictionary."""
714
+ try:
715
+ del self[name]
716
+ except KeyError:
717
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
718
+
719
+ def __dir__(self) -> List[str]:
720
+ """List all attributes including dict keys."""
721
+ return list(self.keys()) + dir(self.__class__)
722
+
723
+ def get(self, key: str, default: Any = None) -> Any:
724
+ """Get item with default value."""
725
+ return super().get(key, default)
726
+
727
+ def copy(self) -> 'DotDict':
728
+ """Create a shallow copy."""
729
+ return copy.copy(self)
730
+
731
+ def deepcopy(self) -> 'DotDict':
732
+ """Create a deep copy."""
733
+ return copy.deepcopy(self)
734
+
735
+ def __deepcopy__(self, memo: Dict[int, Any]) -> 'DotDict':
736
+ """Deep copy implementation."""
737
+ other = self.__class__()
738
+ memo[id(self)] = other
739
+ for key, value in self.items():
740
+ other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
741
+ return other
742
+
743
+ def to_dict(self) -> Dict[str, Any]:
744
+ """
745
+ Convert to standard dict recursively.
746
+
747
+ Returns
748
+ -------
749
+ dict
750
+ A standard Python dict with nested DotDicts also converted.
751
+ """
752
+ result = {}
753
+ for key, value in self.items():
754
+ if isinstance(value, DotDict):
755
+ result[key] = value.to_dict()
756
+ elif isinstance(value, (list, tuple)):
757
+ result[key] = type(value)(
758
+ item.to_dict() if isinstance(item, DotDict) else item
759
+ for item in value
760
+ )
761
+ else:
762
+ result[key] = value
763
+ return result
764
+
765
+ @classmethod
766
+ def from_dict(cls, d: Dict[str, Any]) -> 'DotDict':
767
+ """
768
+ Create DotDict from standard dict.
769
+
770
+ Parameters
771
+ ----------
772
+ d : dict
773
+ Standard Python dictionary.
774
+
775
+ Returns
776
+ -------
777
+ DotDict
778
+ A new DotDict instance.
779
+ """
780
+ return cls(d)
781
+
782
+ def update(self, *args, **kwargs) -> None:
783
+ """
784
+ Update with recursive merge for nested dicts.
785
+
786
+ Parameters
787
+ ----------
788
+ *args
789
+ Dict-like objects to merge.
790
+ **kwargs
791
+ Key-value pairs to merge.
792
+ """
793
+ if args:
794
+ if len(args) > 1:
795
+ raise TypeError(f"update expected at most 1 argument, got {len(args)}")
796
+ other = args[0]
797
+ else:
798
+ other = {}
799
+
800
+ if hasattr(other, 'items'):
801
+ other = dict(other.items())
802
+ other.update(kwargs)
803
+
804
+ for k, v in other.items():
805
+ if k in self and isinstance(self[k], dict) and isinstance(v, dict):
806
+ # Recursive merge for nested dicts
807
+ if isinstance(self[k], DotDict):
808
+ self[k].update(v)
809
+ else:
810
+ self[k] = DotDict(self[k])
811
+ self[k].update(v)
812
+ else:
813
+ self[k] = self._hook(v)
814
+
815
+ def setdefault(self, key: str, default: Any = None) -> Any:
816
+ """Set default value if key doesn't exist."""
817
+ if key not in self:
818
+ self[key] = default
819
+ return self[key]
820
+
821
+ def __getstate__(self) -> Dict[str, Any]:
822
+ """Get state for pickling."""
823
+ return dict(self)
824
+
825
+ def __setstate__(self, state: Dict[str, Any]) -> None:
826
+ """Set state from pickling."""
827
+ self.update(state)
828
+
829
+ def tree_flatten(self) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
830
+ """Flatten for JAX pytree."""
831
+ return tuple(self.values()), tuple(self.keys())
832
+
833
+ @classmethod
834
+ def tree_unflatten(cls, keys: Tuple[str, ...], values: Tuple[Any, ...]) -> 'DotDict':
835
+ """Unflatten from JAX pytree."""
836
+ return cls(zip(keys, values))
837
+
838
+ def __repr__(self) -> str:
839
+ """String representation."""
840
+ items = ', '.join(f'{k!r}: {v!r}' for k, v in self.items())
841
+ return f'DotDict({{{items}}})'
842
+
843
+
844
+ @set_module_as('brainstate.util')
845
+ def merge_dicts(*dicts: Dict[K, V], recursive: bool = True) -> Dict[K, V]:
846
+ """
847
+ Merge multiple dictionaries.
848
+
849
+ Parameters
850
+ ----------
851
+ *dicts : Dict
852
+ Dictionaries to merge (later ones override earlier ones).
853
+ recursive : bool, default=True
854
+ Whether to recursively merge nested dicts.
855
+
856
+ Returns
857
+ -------
858
+ Dict
859
+ Merged dictionary.
860
+
861
+ Examples
862
+ --------
863
+ >>> d1 = {'a': 1, 'b': {'c': 2}}
864
+ >>> d2 = {'b': {'d': 3}, 'e': 4}
865
+ >>> merge_dicts(d1, d2)
866
+ {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4}
867
+ """
868
+ result = {}
869
+
870
+ for d in dicts:
871
+ if not isinstance(d, dict):
872
+ raise TypeError(f"All arguments must be dicts, got {type(d).__name__}")
873
+
874
+ for key, value in d.items():
875
+ if recursive and key in result and isinstance(result[key], dict) and isinstance(value, dict):
876
+ result[key] = merge_dicts(result[key], value, recursive=True)
877
+ else:
878
+ result[key] = value
879
+
880
+ return result
881
+
882
+
883
+ @set_module_as('brainstate.util')
884
+ def flatten_dict(
885
+ d: Dict[str, Any],
886
+ parent_key: str = '',
887
+ sep: str = '.'
888
+ ) -> Dict[str, Any]:
889
+ """
890
+ Flatten a nested dictionary.
891
+
892
+ Parameters
893
+ ----------
894
+ d : Dict
895
+ Dictionary to flatten.
896
+ parent_key : str, default=''
897
+ Prefix for keys.
898
+ sep : str, default='.'
899
+ Separator between nested keys.
900
+
901
+ Returns
902
+ -------
903
+ Dict
904
+ Flattened dictionary.
905
+
906
+ Examples
907
+ --------
908
+ >>> d = {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
909
+ >>> flatten_dict(d)
910
+ {'a': 1, 'b.c': 2, 'b.d.e': 3}
911
+ """
912
+ items = []
913
+ for k, v in d.items():
914
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
915
+ if isinstance(v, dict):
916
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
917
+ else:
918
+ items.append((new_key, v))
919
+ return dict(items)
920
+
921
+
922
+ @set_module_as('brainstate.util')
923
+ def unflatten_dict(
924
+ d: Dict[str, Any],
925
+ sep: str = '.'
926
+ ) -> Dict[str, Any]:
927
+ """
928
+ Unflatten a dictionary with separated keys.
929
+
930
+ Parameters
931
+ ----------
932
+ d : Dict
933
+ Flattened dictionary.
934
+ sep : str, default='.'
935
+ Separator in keys.
936
+
937
+ Returns
938
+ -------
939
+ Dict
940
+ Nested dictionary.
941
+
942
+ Examples
943
+ --------
944
+ >>> d = {'a': 1, 'b.c': 2, 'b.d.e': 3}
945
+ >>> unflatten_dict(d)
946
+ {'a': 1, 'b': {'c': 2, 'd': {'e': 3}}}
947
+ """
948
+ result = {}
949
+
950
+ for key, value in d.items():
951
+ parts = key.split(sep)
952
+ current = result
953
+
954
+ for part in parts[:-1]:
955
+ if part not in current:
956
+ current[part] = {}
957
+ current = current[part]
958
+
959
+ current[parts[-1]] = value
960
+
961
+ return result
962
+
963
+
964
+ def _is_not_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
965
+ """Check if x is not an instance of cls."""
966
+ return not isinstance(x, cls)
967
+
968
+
969
+ def _is_instance(x: Any, cls: Union[Type, Tuple[Type, ...]]) -> bool:
970
+ """Check if x is an instance of cls."""
971
+ return isinstance(x, cls)
972
+
973
+
974
+ @set_module_as('brainstate.util')
975
+ def not_instance_eval(*cls: Type) -> Callable[[Any], bool]:
976
+ """
977
+ Create a partial function to check if input is NOT an instance of given classes.
978
+
979
+ Parameters
980
+ ----------
981
+ *cls : Type
982
+ Classes to check against.
983
+
984
+ Returns
985
+ -------
986
+ Callable
987
+ A function that returns True if input is not an instance of any given class.
988
+
989
+ Examples
990
+ --------
991
+ >>> not_int = not_instance_eval(int)
992
+ >>> not_int(5)
993
+ False
994
+ >>> not_int("hello")
995
+ True
996
+ """
997
+ return functools.partial(_is_not_instance, cls=cls)
998
+
999
+
1000
+ @set_module_as('brainstate.util')
1001
+ def is_instance_eval(*cls: Type) -> Callable[[Any], bool]:
1002
+ """
1003
+ Create a partial function to check if input IS an instance of given classes.
1004
+
1005
+ Parameters
1006
+ ----------
1007
+ *cls : Type
1008
+ Classes to check against.
1009
+
1010
+ Returns
1011
+ -------
1012
+ Callable
1013
+ A function that returns True if input is an instance of any given class.
1014
+
1015
+ Examples
1016
+ --------
1017
+ >>> is_number = is_instance_eval(int, float)
1018
+ >>> is_number(5)
1019
+ True
1020
+ >>> is_number(3.14)
1021
+ True
1022
+ >>> is_number("hello")
1023
+ False
1024
+ """
1025
+ return functools.partial(_is_instance, cls=cls)