brainstate-0.0.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79):
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/util.py ADDED
@@ -0,0 +1,747 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import copy
17
+ import functools
18
+ import gc
19
+ import types
20
+ import warnings
21
+ from collections.abc import Iterable
22
+ from typing import Any, Callable, Tuple, Union, Sequence
23
+
24
+ import jax
25
+ from jax.lib import xla_bridge
26
+
27
+ from ._utils import set_module_as
28
+
29
+
30
# Public API of ``brainstate.util``. ``clear_name_cache`` is exported because the
# duplicate-name error raised by ``check_name_uniqueness`` tells users to call it.
__all__ = [
    'unique_name',
    'clear_name_cache',
    'clear_buffer_memory',
    'not_instance_eval',
    'is_instance_eval',
    'DictManager',
    'MemScaling',
    'IdMemScaling',
    'DotDict',
]
40
+
41
# Registry of explicitly chosen names: name -> ``id()`` of the object that claimed it.
_name2id = {}
# Per-type counters used to generate default names of the form ``f'{type_}{index}'``.
_typed_names = {}
43
+
44
+
45
@set_module_as('brainstate.util')
def check_name_uniqueness(name, obj):
    """Check that ``name`` is a valid identifier not already claimed by another object.

    The first call with a given ``name`` registers it for ``obj`` (keyed by ``id(obj)``).
    Later calls with the same ``name`` succeed only for that very same object.

    Parameters
    ----------
    name : str
        The candidate name.
    obj : Any
        The object that wants to use ``name``.

    Raises
    ------
    ValueError
        If ``name`` is not a valid Python identifier, or if it is already
        registered to a different object.
    """
    if not name.isidentifier():
        raise ValueError(
            f'"{name}" isn\'t a valid identifier '
            f'according to Python language definition. '
            f'Please choose another name.'
        )
    # ``id()`` values are ints, so ``None`` unambiguously means "not registered yet".
    registered_id = _name2id.get(name)
    if registered_id is None:
        _name2id[name] = id(obj)
    elif registered_id != id(obj):
        # Point users at the reset helper that actually exists in this module
        # (the previous text referenced a BrainPy API).
        raise ValueError(
            f'In brainstate, each object should have a unique name. '
            f'However, we detect that {obj} has a used name "{name}". \n'
            f'If you try to run multiple trials, you may need \n\n'
            f'>>> brainstate.util.clear_name_cache() \n\n'
            f'to clear all cached names. '
        )
65
+
66
+
67
def get_unique_name(type_: str):
    """Return the next auto-generated name for ``type_``.

    Names have the form ``f'{type_}{index}'``. ``index`` starts at 0 and is
    incremented per type in the module-level ``_typed_names`` counter table.
    """
    index = _typed_names.setdefault(type_, 0)
    _typed_names[type_] = index + 1
    return f'{type_}{index}'
74
+
75
+
76
@set_module_as('brainstate.util')
def unique_name(name=None, self=None):
    """Get the unique name for an object.

    Parameters
    ----------
    name : str, optional
        The expected name. If None, a default name is generated from the class
        name of ``self``. Otherwise, ``name`` is checked (and registered for
        ``self``) to guarantee its uniqueness.
    self : Any, optional
        The object being named (not a string). It is required when ``name`` is
        None, because the default name is built from ``self.__class__.__name__``.

    Returns
    -------
    name : str
        The unique name for this object.

    Raises
    ------
    AssertionError
        If both ``name`` and ``self`` are None.
    ValueError
        If ``name`` is not a valid identifier or is already used by another object.
    """
    if name is not None:
        check_name_uniqueness(name=name, obj=self)
        return name
    if self is None:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        # The exception type stays AssertionError for backward compatibility.
        raise AssertionError('If name is None, self should be provided.')
    return get_unique_name(type_=self.__class__.__name__)
99
+
100
+
101
@set_module_as('brainstate.util')
def clear_name_cache(ignore_warn: bool = True):
    """Forget every registered name and reset the per-type name counters.

    Parameters
    ----------
    ignore_warn : bool
        When False, emit a ``UserWarning`` announcing the reset. Default is True.
    """
    for cache in (_name2id, _typed_names):
        cache.clear()
    if ignore_warn:
        return
    warnings.warn('All named models and their ids are cleared.', UserWarning)
108
+
109
+
110
@jax.tree_util.register_pytree_node_class
class DictManager(dict):
    """
    DictManager, for collecting all pytree used in the program.

    :py:class:`~.DictManager` supports all features of python dict. It adds helpers
    to filter, split and combine its items.

    It is registered as a JAX pytree: the dict values are the children, and the
    dict keys are the auxiliary data.

    Subclasses that call :py:meth:`add_unique_elem` must implement :py:meth:`_check_elem`.
    """
    __module__ = 'brainstate.util'

    @staticmethod
    def _as_predicate(sep: Union[type, Tuple[type, ...], Callable]) -> Callable[[Any], bool]:
        """Convert a type filter or a predicate callable into a predicate over values.

        Objects that ``isinstance`` understands natively become
        ``lambda v: isinstance(v, sep)``. These are classes (their metaclass provides
        ``__instancecheck__``), tuples of classes, and ``typing`` unions.

        Any other callable is returned unchanged and used as the predicate. This
        covers plain functions, lambdas, builtins, bound methods, and
        ``functools.partial`` objects such as those built by ``is_instance_eval``.
        """
        if callable(sep) and not hasattr(sep, '__instancecheck__'):
            return sep
        return lambda v: isinstance(v, sep)

    @staticmethod
    def _as_container(items: Iterable):
        """Return ``items`` if it supports ``in``; otherwise materialize it into a tuple.

        Callers test membership once per key. A one-shot iterator (e.g. a generator)
        would be partly consumed by each test and silently miss later matches.
        """
        return items if hasattr(items, '__contains__') else tuple(items)

    def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
        """
        Get a new manager holding only the items whose values match ``sep``.

        Args:
          sep: A type or tuple of types, keeping values where ``isinstance(value, sep)``
            holds. Alternatively, a predicate callable, keeping values where
            ``sep(value)`` is truthy.

        Returns:
          A new instance of ``type(self)``.
        """
        keep = self._as_predicate(sep)
        gather = type(self)()
        for k, v in self.items():
            if keep(v):
                gather[k] = v
        return gather

    def not_subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
        """
        Get a new manager holding only the items whose values do NOT match ``sep``.

        ``sep`` is interpreted exactly as in :py:meth:`subset`.
        """
        matches = self._as_predicate(sep)
        gather = type(self)()
        for k, v in self.items():
            if not matches(v):
                gather[k] = v
        return gather

    def add_unique_elem(self, key: Any, var: Any):
        """Add ``var`` under ``key``; re-adding the identical object is a no-op.

        Raises:
          ValueError: If ``key`` is already bound to a different object.
        """
        self._check_elem(var)
        if key in self:
            if id(var) != id(self[key]):
                raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
        else:
            self[key] = var

    def unique(self) -> 'DictManager':
        """
        Get a new type of collections with unique values.

        If one value (compared by ``id``) is assigned to two or more keys,
        only the first (key, value) pair in iteration order is kept.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            if id(v) not in seen:
                seen.add(id(v))
                gather[k] = v
        return gather

    def assign(self, *args) -> None:
        """
        Assign, in place, every (key, value) pair of each given dict.
        """
        for arg in args:
            assert isinstance(arg, dict), 'Must be an instance of dict.'
            for k, v in arg.items():
                self[k] = v

    def split(self, first: type, *others: type) -> Tuple['DictManager', ...]:
        """
        Split the items into ``1 + len(others) + 1`` new managers.

        Each value goes to the manager of the first filter it matches (filters are
        interpreted as in :py:meth:`subset`). Values matching no filter go to the
        last manager.
        """
        filters = tuple(self._as_predicate(f) for f in (first, *others))
        results = tuple(type(self)() for _ in range(len(filters) + 1))
        for k, v in self.items():
            for i, matches in enumerate(filters):
                if matches(v):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v
        return results

    def pop_by_keys(self, keys: Iterable):
        """
        Remove, in place, every item whose key is in ``keys``.
        """
        keys = self._as_container(keys)
        for k in tuple(self.keys()):
            if k in keys:
                self.pop(k)

    def pop_by_values(self, values: Iterable, by: str = 'id'):
        """
        Remove, in place, every item whose value is in ``values``.

        Args:
          values: The values to remove.
          by: str. The discard method, can be ``id`` (identity) or ``value``
            (``==`` membership). Default is 'id'.

        Raises:
          ValueError: If ``by`` is not supported.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            for k in tuple(self.keys()):
                if id(self[k]) in value_ids:
                    self.pop(k)
        elif by == 'value':
            values = self._as_container(values)
            for k in tuple(self.keys()):
                if self[k] in values:
                    self.pop(k)
        else:
            raise ValueError(f'Unsupported method: {by}')

    def difference_by_keys(self, keys: Iterable):
        """
        Get a new manager without the items whose key is in ``keys``.
        """
        keys = self._as_container(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys})

    def difference_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get a new manager without the items whose value is in ``values``.

        Args:
          values: The values to exclude.
          by: str. The comparison method, can be ``id`` or ``value``. Default is 'id'.

        Raises:
          ValueError: If ``by`` is not supported.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
        elif by == 'value':
            values = self._as_container(values)
            return type(self)({k: v for k, v in self.items() if v not in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def intersection_by_keys(self, keys: Iterable):
        """
        Get a new manager with only the items whose key is in ``keys``.
        """
        keys = self._as_container(keys)
        return type(self)({k: v for k, v in self.items() if k in keys})

    def intersection_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get a new manager with only the items whose value is in ``values``.

        Args:
          values: The values to keep.
          by: str. The comparison method, can be ``id`` or ``value``. Default is 'id'.

        Raises:
          ValueError: If ``by`` is not supported.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
        elif by == 'value':
            values = self._as_container(values)
            return type(self)({k: v for k, v in self.items() if v in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def union_by_value_ids(self, other: dict) -> 'DictManager':
        """
        Union this manager with ``other``, de-duplicating values by identity.

        The result starts as a copy of ``self``. Each item of ``other`` whose value
        (compared by ``id``) is not yet present is then added. An added item
        overwrites any entry under the same key, as ``dict.update`` would.

        Args:
          other: The dict to merge in.

        Returns:
          A new instance of ``type(self)``.
        """
        gather = type(self)(self)
        seen = {id(v) for v in gather.values()}
        for k, v in other.items():
            if id(v) not in seen:
                seen.add(id(v))
                gather[k] = v
        return gather

    def __add__(self, other: dict):
        """
        Compose other instance of dict; entries of ``other`` win on key clashes.
        """
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def tree_flatten(self):
        # children = values, auxiliary data = keys (a tuple, hence hashable)
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        # Length-checked zip using only builtins (instead of ``jax.util.safe_zip``).
        keys, values = tuple(keys), tuple(values)
        if len(keys) != len(values):
            raise ValueError(f'Length mismatch: {len(keys)} keys but {len(values)} values.')
        return cls(zip(keys, values))

    def _check_elem(self, elem: Any):
        # Validation hook used by ``add_unique_elem``; subclasses must override it.
        raise NotImplementedError

    def to_dict(self):
        """
        Convert the stack to a dict.

        Returns
        -------
        dict
            A shallow copy as a plain ``dict``.
        """
        return dict(self)

    def __copy__(self):
        return type(self)(self)
307
+
308
+
309
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: Union[str, None] = None,
    array: bool = True,
    compilation: bool = False,
):
    """Clear all on-device buffers.

    This function will be very useful when you call models in a Python loop,
    because it can clear all cached arrays, and clear device memory.

    .. warning::

       This operation may cause errors when you use a deleted buffer.
       Therefore, regenerate data always.

    Parameters
    ----------
    platform: str, optional
        The platform whose memory is cleared. ``None`` lets JAX pick its default backend.
    array: bool
        Clear all buffer array. Default is True.
    compilation: bool
        Clear compilation cache. Default is False.

    """
    if array:
        # ``jax.live_arrays`` is JAX's public API for enumerating live arrays.
        # Fall back to the backend client's ``live_buffers()`` only on JAX versions
        # that do not provide it.
        live_arrays = getattr(jax, 'live_arrays', None)
        if live_arrays is not None:
            buffers = live_arrays(platform)
        else:
            buffers = xla_bridge.get_backend(platform).live_buffers()
        for buf in buffers:
            buf.delete()
    if compilation:
        jax.clear_caches()
    gc.collect()
341
+
342
+
343
class MemScaling:
    """
    Affine map between a membrane-potential range and a standard range.

    With the stored ``scale`` and ``bias``:

    * forward (original -> standard): ``(x + bias) / scale``
    * reverse (standard -> original): ``x * scale - bias``

    The offset step (``x + bias``) and the scaling step (``x / scale``) are also
    available on their own, together with their reverses. Every method accepts
    optional ``bias`` / ``scale`` overrides; ``None`` means "use the stored value".
    """
    __module__ = 'brainstate.util'

    def __init__(self, scale, bias):
        self._scale = scale
        self._bias = bias

    @classmethod
    def transform(
        cls,
        oring_range: Sequence[Union[float, int]],
        target_range: Sequence[Union[float, int]] = (0., 1.)
    ) -> 'MemScaling':
        """Create a scaling object mapping ``oring_range`` onto ``target_range``.

        Args:
          oring_range: ``[V_min, V_max]``, the original membrane-potential range.
          target_range: ``[scaled_V_min, scaled_V_max]``, the standard range.
            Default is ``(0., 1.)``.

        Returns:
          A ``cls`` instance. Its :py:meth:`scale_offset` sends ``V_min`` to
          ``scaled_V_min`` and ``V_max`` to ``scaled_V_max``.
        """
        v_lo, v_hi = oring_range
        s_lo, s_hi = target_range
        factor = (v_hi - v_lo) / (s_hi - s_lo)
        shift = s_lo * factor - v_lo
        return cls(scale=factor, bias=shift)

    def scale_offset(self, x, bias=None, scale=None):
        """Map ``x`` to the standard range: ``(x + bias) / scale``.

        Parameters
        ----------
        x : array_like
            The membrane potential.
        bias : float, optional
            Overrides the stored bias when given.
        scale : float, optional
            Overrides the stored scale when given.

        Returns
        -------
        x : array_like
            The membrane potential expressed in the standard range.
        """
        shift = self._bias if bias is None else bias
        factor = self._scale if scale is None else scale
        return (x + shift) / factor

    def scale(self, x, scale=None):
        """Apply only the scaling step: ``x / scale``.

        Parameters
        ----------
        x : array_like
            The membrane potential.
        scale : float, optional
            Overrides the stored scale when given.

        Returns
        -------
        x : array_like
            The scaled value.
        """
        factor = self._scale if scale is None else scale
        return x / factor

    def offset(self, x, bias=None):
        """Apply only the offset step: ``x + bias``.

        Parameters
        ----------
        x : array_like
            The membrane potential.
        bias : float, optional
            Overrides the stored bias when given.

        Returns
        -------
        x : array_like
            The shifted value.
        """
        shift = self._bias if bias is None else bias
        return x + shift

    def rev_scale(self, x, scale=None):
        """Undo the scaling step: ``x * scale``.

        Parameters
        ----------
        x : array_like
            A value in the scaled space.
        scale : float, optional
            Overrides the stored scale when given.

        Returns
        -------
        x : array_like
            The unscaled value.
        """
        factor = self._scale if scale is None else scale
        return x * factor

    def rev_offset(self, x, bias=None):
        """Undo the offset step: ``x - bias``.

        Parameters
        ----------
        x : array_like
            A value in the shifted space.
        bias : float, optional
            Overrides the stored bias when given.

        Returns
        -------
        x : array_like
            The unshifted value.
        """
        shift = self._bias if bias is None else bias
        return x - shift

    def rev_scale_offset(self, x, bias=None, scale=None):
        """Map ``x`` from the standard range back to the original range: ``x * scale - bias``.

        Parameters
        ----------
        x : array_like
            The membrane potential in the standard range.
        bias : float, optional
            Overrides the stored bias when given.
        scale : float, optional
            Overrides the stored scale when given.

        Returns
        -------
        x : array_like
            The original membrane potential.
        """
        shift = self._bias if bias is None else bias
        factor = self._scale if scale is None else scale
        return x * factor - shift

    def clone(self):
        """
        Return a new ``MemScaling`` with the same scale and bias.

        Returns
        -------
        scaling : MemScaling
            The copy (always of class ``MemScaling``).
        """
        return MemScaling(scale=self._scale, bias=self._bias)
518
+
519
+
520
class IdMemScaling(MemScaling):
    """
    The identity scaling object.

    It stores ``scale=1.`` and ``bias=0.``, but every forward and reverse method
    returns its input ``x`` unchanged. The ``bias`` / ``scale`` arguments are
    accepted for interface compatibility and ignored.
    """
    __module__ = 'brainstate.util'

    def __init__(self):
        super().__init__(scale=1., bias=0.)

    def scale_offset(self, x, bias=None, scale=None):
        """Identity forward transform; returns ``x`` as is."""
        return x

    def scale(self, x, scale=None):
        """Identity scaling step; returns ``x`` as is."""
        return x

    def offset(self, x, bias=None):
        """Identity offset step; returns ``x`` as is."""
        return x

    def rev_scale(self, x, scale=None):
        """Identity reverse scaling step; returns ``x`` as is."""
        return x

    def rev_offset(self, x, bias=None):
        """Identity reverse offset step; returns ``x`` as is."""
        return x

    def rev_scale_offset(self, x, bias=None, scale=None):
        """Identity reverse transform; returns ``x`` as is."""
        return x

    def clone(self):
        """Return a fresh ``IdMemScaling``."""
        return IdMemScaling()
578
+
579
+
580
class _DotDictMissingKey(KeyError, AttributeError):
    """Raised when a missing key of a :class:`DotDict` is read as an attribute.

    It subclasses ``KeyError``, so existing ``except KeyError`` handlers keep working.
    It also subclasses ``AttributeError``, so ``hasattr(d, name)`` and
    ``getattr(d, name, default)`` follow the normal attribute protocol.
    """


@jax.tree_util.register_pytree_node_class
class DotDict(dict):
    """Python dictionaries with advanced dot notation access.

    ``dict`` values given to the constructor are converted to :class:`DotDict`.
    This also applies to dicts found inside lists, tuples and namedtuples.

    For example:

    >>> d = DotDict({'a': 10, 'b': 20})
    >>> d.a
    10
    >>> d['a']
    10
    >>> d.c  # this will raise a KeyError (which is also an AttributeError)
    KeyError: 'c'
    >>> d.c = 30  # but you can assign a value to a non-existing item
    >>> d.c
    30
    """

    __module__ = 'brainstate.util'

    def __init__(self, *args, **kwargs):
        # ``__parent`` / ``__key`` are stored as real instance attributes. They bypass
        # ``__setattr__``, which would otherwise store them as dict items. When a
        # parent is given, the first ``__setitem__`` links this object into
        # ``parent[key]``.
        object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
        object.__setattr__(self, '__key', kwargs.pop('__key', None))
        for arg in args:
            if not arg:
                continue
            elif isinstance(arg, dict):
                for key, val in arg.items():
                    self[key] = self._hook(val)
            elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
                # a single ``(key, value)`` pair
                self[arg[0]] = self._hook(arg[1])
            else:
                # an iterable of ``(key, value)`` pairs
                for key, val in iter(arg):
                    self[key] = self._hook(val)

        for key, val in kwargs.items():
            self[key] = self._hook(val)

    def __setattr__(self, name, value):
        # Names defined on the class (methods etc.) cannot be shadowed by items.
        if hasattr(self.__class__, name):
            raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
        else:
            self[name] = value

    def __setitem__(self, name, value):
        super(DotDict, self).__setitem__(name, value)
        try:
            p = object.__getattribute__(self, '__parent')
            key = object.__getattribute__(self, '__key')
        except AttributeError:
            # already linked (attributes deleted below) or constructed without ``__init__``
            p = None
            key = None
        if p is not None:
            p[key] = self
            object.__delattr__(self, '__parent')
            object.__delattr__(self, '__key')

    @classmethod
    def _hook(cls, item):
        """Recursively convert dicts (also inside lists/tuples) into ``cls`` instances."""
        if isinstance(item, dict):
            return cls(item)
        if isinstance(item, (list, tuple)):
            converted = (cls._hook(elem) for elem in item)
            if isinstance(item, tuple) and hasattr(item, '_fields'):
                # namedtuple constructors take one positional argument per field,
                # not a single iterable
                return type(item)(*converted)
            return type(item)(converted)
        return item

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails: fall back to the items.
        try:
            return self[item]
        except KeyError:
            raise _DotDictMissingKey(item) from None

    def __delattr__(self, name):
        del self[name]

    def copy(self):
        """Return a shallow copy."""
        return copy.copy(self)

    def deepcopy(self):
        """Return a deep copy."""
        return copy.deepcopy(self)

    def __deepcopy__(self, memo):
        other = self.__class__()
        memo[id(self)] = other
        for key, value in self.items():
            other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return other

    def to_dict(self):
        """Recursively convert to plain dicts (also inside list/tuple values)."""
        base = {}
        for key, value in self.items():
            if isinstance(value, type(self)):
                base[key] = value.to_dict()
            elif isinstance(value, (list, tuple)):
                base[key] = type(value)(item.to_dict() if isinstance(item, type(self)) else item
                                        for item in value)
            else:
                base[key] = value
        return base

    def update(self, *args, **kwargs):
        """Update from a mapping and/or keyword arguments.

        When both the existing and the new value are dicts, the existing one is
        updated recursively instead of being replaced.

        Raises:
          TypeError: If more than one positional argument is given.
        """
        other = {}
        if args:
            if len(args) > 1:
                raise TypeError(f'update expected at most 1 positional argument, got {len(args)}')
            other.update(args[0])
        other.update(kwargs)
        for k, v in other.items():
            if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
                self[k] = v
            else:
                self[k].update(v)

    def __getnewargs__(self):
        return tuple(self.items())

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)

    def setdefault(self, key, default=None):
        if key in self:
            return self[key]
        else:
            self[key] = default
            return default

    def tree_flatten(self):
        # children = values, auxiliary data = keys (a tuple, hence hashable)
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        # Length-checked zip using only builtins (instead of ``jax.util.safe_zip``).
        keys, values = tuple(keys), tuple(values)
        if len(keys) != len(values):
            raise ValueError(f'Length mismatch: {len(keys)} keys but {len(values)} values.')
        return cls(zip(keys, values))
711
+
712
+
713
def _is_not_instance(x, cls):
    """Return True when ``x`` is not an instance of ``cls`` (a type or tuple of types)."""
    matches = isinstance(x, cls)
    return not matches
715
+
716
+
717
def _is_instance(x, cls):
    """Return whether ``x`` is an instance of ``cls`` (a type or tuple of types)."""
    result = isinstance(x, cls)
    return result
719
+
720
+
721
@set_module_as('brainstate.util')
def not_instance_eval(*cls):
    """
    Build a predicate that is true for objects that are NOT instances of any of ``cls``.

    Args:
      *cls: The classes to test against.

    Returns:
      A ``functools.partial`` of ``_is_not_instance`` with ``cls`` bound. Calling it
      with an object ``x`` returns ``not isinstance(x, cls)``.
    """
    predicate = functools.partial(_is_not_instance, cls=cls)
    return predicate
734
+
735
+
736
@set_module_as('brainstate.util')
def is_instance_eval(*cls):
    """
    Build a predicate that is true for objects that are instances of any of ``cls``.

    Args:
      *cls: The classes to test against.

    Returns:
      A ``functools.partial`` of ``_is_instance`` with ``cls`` bound. Calling it
      with an object ``x`` returns ``isinstance(x, cls)``.
    """
    predicate = functools.partial(_is_instance, cls=cls)
    return predicate