brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +588 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
  127. brainstate-0.1.10.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
brainstate/util/others.py CHANGED
@@ -1,540 +1,540 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import copy
17
- import functools
18
- import gc
19
- import threading
20
- import types
21
- from collections.abc import Iterable
22
- from typing import Any, Callable, Tuple, Union, Dict
23
-
24
- import jax
25
- from jax.lib import xla_bridge
26
-
27
- from brainstate._utils import set_module_as
28
-
29
- __all__ = [
30
- 'split_total',
31
- 'clear_buffer_memory',
32
- 'not_instance_eval',
33
- 'is_instance_eval',
34
- 'DictManager',
35
- 'DotDict',
36
- ]
37
-
38
-
39
- def split_total(
40
- total: int,
41
- fraction: Union[int, float],
42
- ) -> int:
43
- """
44
- Calculate the number of epochs for simulation based on a total and a fraction.
45
-
46
- This function determines the number of epochs to simulate given a total number
47
- of epochs and either a fraction or a specific number of epochs to run.
48
-
49
- Parameters:
50
- -----------
51
- total : int
52
- The total number of epochs. Must be a positive integer.
53
- fraction : Union[int, float]
54
- If ``float``: A value between 0 and 1 representing the fraction of total epochs to run.
55
- If ``int``: The specific number of epochs to run, must not exceed the total.
56
-
57
- Returns:
58
- --------
59
- int
60
- The calculated number of epochs to simulate.
61
-
62
- Raises:
63
- -------
64
- ValueError
65
- If total is not positive, fraction is negative, or if fraction as float is > 1
66
- or as int is > total.
67
- AssertionError
68
- If total is not an integer.
69
- """
70
- assert isinstance(total, int), "Length must be an integer."
71
- if total <= 0:
72
- raise ValueError("'total' must be a positive integer.")
73
- if fraction < 0:
74
- raise ValueError("'fraction' value cannot be negative.")
75
-
76
- if isinstance(fraction, float):
77
- if fraction < 0:
78
- raise ValueError("'fraction' value cannot be negative.")
79
- if fraction > 1:
80
- raise ValueError("'fraction' value cannot be greater than 1.")
81
- return int(total * fraction)
82
-
83
- elif isinstance(fraction, int):
84
- if fraction < 0:
85
- raise ValueError("'fraction' value cannot be negative.")
86
- if fraction > total:
87
- raise ValueError("'fraction' value cannot be greater than total.")
88
- return fraction
89
-
90
- else:
91
- raise ValueError("'fraction' must be an integer or float.")
92
-
93
-
94
- class NameContext(threading.local):
95
- def __init__(self):
96
- self.typed_names: Dict[str, int] = {}
97
-
98
-
99
- NAME = NameContext()
100
-
101
-
102
- def get_unique_name(type_: str):
103
- """Get the unique name for the given object type."""
104
- if type_ not in NAME.typed_names:
105
- NAME.typed_names[type_] = 0
106
- name = f'{type_}{NAME.typed_names[type_]}'
107
- NAME.typed_names[type_] += 1
108
- return name
109
-
110
-
111
- @jax.tree_util.register_pytree_node_class
112
- class DictManager(dict):
113
- """
114
- DictManager, for collecting all pytree used in the program.
115
-
116
- :py:class:`~.DictManager` supports all features of python dict.
117
- """
118
- __module__ = 'brainstate.util'
119
- _val_id_to_key: dict
120
-
121
- def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
122
- """
123
- Get a new stack with the subset of keys.
124
- """
125
- gather = type(self)()
126
- if isinstance(sep, types.FunctionType):
127
- for k, v in self.items():
128
- if sep(v):
129
- gather[k] = v
130
- return gather
131
- else:
132
- for k, v in self.items():
133
- if isinstance(v, sep):
134
- gather[k] = v
135
- return gather
136
-
137
- def not_subset(self, sep: Union[type, Tuple[type, ...]]) -> 'DictManager':
138
- """
139
- Get a new stack with the subset of keys.
140
- """
141
- gather = type(self)()
142
- for k, v in self.items():
143
- if not isinstance(v, sep):
144
- gather[k] = v
145
- return gather
146
-
147
- def add_unique_key(self, key: Any, val: Any):
148
- """
149
- Add a new element and check if the value is same or not.
150
- """
151
- self._check_elem(val)
152
- if key in self:
153
- if id(val) != id(self[key]):
154
- raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
155
- else:
156
- self[key] = val
157
-
158
- def add_unique_value(self, key: Any, val: Any) -> bool:
159
- """
160
- Add a new element and check if the val is unique.
161
-
162
- Parameters:
163
- key: The key of the element.
164
- val: The value of the element
165
-
166
- Returns:
167
- bool: True if the value is unique, False otherwise.
168
- """
169
- self._check_elem(val)
170
- if not hasattr(self, '_val_id_to_key'):
171
- self._val_id_to_key = {id(v): k for k, v in self.items()}
172
- if id(val) not in self._val_id_to_key:
173
- self._val_id_to_key[id(val)] = key
174
- self[key] = val
175
- return True
176
- else:
177
- return False
178
-
179
- def unique(self) -> 'DictManager':
180
- """
181
- Get a new type of collections with unique values.
182
-
183
- If one value is assigned to two or more keys,
184
- then only one pair of (key, value) will be returned.
185
- """
186
- gather = type(self)()
187
- seen = set()
188
- for k, v in self.items():
189
- if id(v) not in seen:
190
- seen.add(id(v))
191
- gather[k] = v
192
- return gather
193
-
194
- def unique_(self):
195
- """
196
- Get a new type of collections with unique values.
197
-
198
- If one value is assigned to two or more keys,
199
- then only one pair of (key, value) will be returned.
200
- """
201
- seen = set()
202
- for k in tuple(self.keys()):
203
- v = self[k]
204
- if id(v) not in seen:
205
- seen.add(id(v))
206
- else:
207
- self.pop(k)
208
- return self
209
-
210
- def assign(self, *args) -> None:
211
- """
212
- Assign the value for each element according to the given ``data``.
213
- """
214
- for arg in args:
215
- assert isinstance(arg, dict), 'Must be an instance of dict.'
216
- for k, v in arg.items():
217
- self[k] = v
218
-
219
- def split(self, first: type, *others: type) -> Tuple['DictManager', ...]:
220
- """
221
- Split the stack into subsets of stack by the given types.
222
- """
223
- filters = (first, *others)
224
- results = tuple(type(self)() for _ in range(len(filters) + 1))
225
- for k, v in self.items():
226
- for i, filt in enumerate(filters):
227
- if isinstance(v, filt):
228
- results[i][k] = v
229
- break
230
- else:
231
- results[-1][k] = v
232
- return results
233
-
234
- def pop_by_keys(self, keys: Iterable):
235
- """
236
- Pop the elements by the keys.
237
- """
238
- for k in tuple(self.keys()):
239
- if k in keys:
240
- self.pop(k)
241
-
242
- def pop_by_values(self, values: Iterable, by: str = 'id'):
243
- """
244
- Pop the elements by the values.
245
-
246
- Args:
247
- values: The value ids.
248
- by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
249
- """
250
- if by == 'id':
251
- value_ids = {id(v) for v in values}
252
- for k in tuple(self.keys()):
253
- if id(self[k]) in value_ids:
254
- self.pop(k)
255
- elif by == 'value':
256
- for k in tuple(self.keys()):
257
- if self[k] in values:
258
- self.pop(k)
259
- else:
260
- raise ValueError(f'Unsupported method: {by}')
261
-
262
- def difference_by_keys(self, keys: Iterable):
263
- """
264
- Get the difference of the stack by the keys.
265
- """
266
- return type(self)({k: v for k, v in self.items() if k not in keys})
267
-
268
- def difference_by_values(self, values: Iterable, by: str = 'id'):
269
- """
270
- Get the difference of the stack by the values.
271
-
272
- Args:
273
- values: The value ids.
274
- by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
275
- """
276
- if by == 'id':
277
- value_ids = {id(v) for v in values}
278
- return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
279
- elif by == 'value':
280
- return type(self)({k: v for k, v in self.items() if v not in values})
281
- else:
282
- raise ValueError(f'Unsupported method: {by}')
283
-
284
- def intersection_by_keys(self, keys: Iterable):
285
- """
286
- Get the intersection of the stack by the keys.
287
- """
288
- return type(self)({k: v for k, v in self.items() if k in keys})
289
-
290
- def intersection_by_values(self, values: Iterable, by: str = 'id'):
291
- """
292
- Get the intersection of the stack by the values.
293
-
294
- Args:
295
- values: The value ids.
296
- by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
297
- """
298
- if by == 'id':
299
- value_ids = {id(v) for v in values}
300
- return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
301
- elif by == 'value':
302
- return type(self)({k: v for k, v in self.items() if v in values})
303
- else:
304
- raise ValueError(f'Unsupported method: {by}')
305
-
306
- def __add__(self, other: dict):
307
- """
308
- Compose other instance of dict.
309
- """
310
- new_dict = type(self)(self)
311
- new_dict.update(other)
312
- return new_dict
313
-
314
- def tree_flatten(self):
315
- return tuple(self.values()), tuple(self.keys())
316
-
317
- @classmethod
318
- def tree_unflatten(cls, keys, values):
319
- return cls(jax.util.safe_zip(keys, values))
320
-
321
- def _check_elem(self, elem: Any):
322
- raise NotImplementedError
323
-
324
- def to_dict(self):
325
- """
326
- Convert the stack to a dict.
327
-
328
- Returns
329
- -------
330
- dict
331
- The dict object.
332
- """
333
- return dict(self)
334
-
335
- def __copy__(self):
336
- return type(self)(self)
337
-
338
-
339
- @set_module_as('brainstate.util')
340
- def clear_buffer_memory(
341
- platform: str = None,
342
- array: bool = True,
343
- compilation: bool = False,
344
- ):
345
- """Clear all on-device buffers.
346
-
347
- This function will be very useful when you call models in a Python loop,
348
- because it can clear all cached arrays, and clear device memory.
349
-
350
- .. warning::
351
-
352
- This operation may cause errors when you use a deleted buffer.
353
- Therefore, regenerate data always.
354
-
355
- Parameters
356
- ----------
357
- platform: str
358
- The device to clear its memory.
359
- array: bool
360
- Clear all buffer array. Default is True.
361
- compilation: bool
362
- Clear compilation cache. Default is False.
363
-
364
- """
365
- if array:
366
- for buf in xla_bridge.get_backend(platform).live_buffers():
367
- buf.delete()
368
- if compilation:
369
- jax.clear_caches()
370
- gc.collect()
371
-
372
-
373
- @jax.tree_util.register_pytree_node_class
374
- class DotDict(dict):
375
- """Python dictionaries with advanced dot notation access.
376
-
377
- For example:
378
-
379
- >>> d = DotDict({'a': 10, 'b': 20})
380
- >>> d.a
381
- 10
382
- >>> d['a']
383
- 10
384
- >>> d.c # this will raise a KeyError
385
- KeyError: 'c'
386
- >>> d.c = 30 # but you can assign a value to a non-existing item
387
- >>> d.c
388
- 30
389
- """
390
-
391
- __module__ = 'brainstate.util'
392
-
393
- def __init__(self, *args, **kwargs):
394
- object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
395
- object.__setattr__(self, '__key', kwargs.pop('__key', None))
396
- for arg in args:
397
- if not arg:
398
- continue
399
- elif isinstance(arg, dict):
400
- for key, val in arg.items():
401
- self[key] = self._hook(val)
402
- elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
403
- self[arg[0]] = self._hook(arg[1])
404
- else:
405
- for key, val in iter(arg):
406
- self[key] = self._hook(val)
407
-
408
- for key, val in kwargs.items():
409
- self[key] = self._hook(val)
410
-
411
- def __setattr__(self, name, value):
412
- if hasattr(self.__class__, name):
413
- raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
414
- else:
415
- self[name] = value
416
-
417
- def __setitem__(self, name, value):
418
- super(DotDict, self).__setitem__(name, value)
419
- try:
420
- p = object.__getattribute__(self, '__parent')
421
- key = object.__getattribute__(self, '__key')
422
- except AttributeError:
423
- p = None
424
- key = None
425
- if p is not None:
426
- p[key] = self
427
- object.__delattr__(self, '__parent')
428
- object.__delattr__(self, '__key')
429
-
430
- @classmethod
431
- def _hook(cls, item):
432
- if isinstance(item, dict):
433
- return cls(item)
434
- elif isinstance(item, (list, tuple)):
435
- return type(item)(cls._hook(elem) for elem in item)
436
- return item
437
-
438
- def __getattr__(self, item):
439
- return self.__getitem__(item)
440
-
441
- def __delattr__(self, name):
442
- del self[name]
443
-
444
- def copy(self):
445
- return copy.copy(self)
446
-
447
- def deepcopy(self):
448
- return copy.deepcopy(self)
449
-
450
- def __deepcopy__(self, memo):
451
- other = self.__class__()
452
- memo[id(self)] = other
453
- for key, value in self.items():
454
- other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
455
- return other
456
-
457
- def to_dict(self):
458
- base = {}
459
- for key, value in self.items():
460
- if isinstance(value, type(self)):
461
- base[key] = value.to_dict()
462
- elif isinstance(value, (list, tuple)):
463
- base[key] = type(value)(item.to_dict() if isinstance(item, type(self)) else item
464
- for item in value)
465
- else:
466
- base[key] = value
467
- return base
468
-
469
- def update(self, *args, **kwargs):
470
- other = {}
471
- if args:
472
- if len(args) > 1:
473
- raise TypeError()
474
- other.update(args[0])
475
- other.update(kwargs)
476
- for k, v in other.items():
477
- if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
478
- self[k] = v
479
- else:
480
- self[k].update(v)
481
-
482
- def __getnewargs__(self):
483
- return tuple(self.items())
484
-
485
- def __getstate__(self):
486
- return self
487
-
488
- def __setstate__(self, state):
489
- self.update(state)
490
-
491
- def setdefault(self, key, default=None):
492
- if key in self:
493
- return self[key]
494
- else:
495
- self[key] = default
496
- return default
497
-
498
- def tree_flatten(self):
499
- return tuple(self.values()), tuple(self.keys())
500
-
501
- @classmethod
502
- def tree_unflatten(cls, keys, values):
503
- return cls(jax.util.safe_zip(keys, values))
504
-
505
-
506
- def _is_not_instance(x, cls):
507
- return not isinstance(x, cls)
508
-
509
-
510
- def _is_instance(x, cls):
511
- return isinstance(x, cls)
512
-
513
-
514
- @set_module_as('brainstate.util')
515
- def not_instance_eval(*cls):
516
- """
517
- Create a partial function to evaluate if the input is not an instance of the given class.
518
-
519
- Args:
520
- *cls: The classes to check.
521
-
522
- Returns:
523
- The partial function.
524
-
525
- """
526
- return functools.partial(_is_not_instance, cls=cls)
527
-
528
-
529
- @set_module_as('brainstate.util')
530
- def is_instance_eval(*cls):
531
- """
532
- Create a partial function to evaluate if the input is an instance of the given class.
533
-
534
- Args:
535
- *cls: The classes to check.
536
-
537
- Returns:
538
- The partial function.
539
- """
540
- return functools.partial(_is_instance, cls=cls)
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import copy
17
+ import functools
18
+ import gc
19
+ import threading
20
+ import types
21
+ from collections.abc import Iterable
22
+ from typing import Any, Callable, Tuple, Union, Dict
23
+
24
+ import jax
25
+ from jax.lib import xla_bridge
26
+
27
+ from brainstate._utils import set_module_as
28
+
29
+ __all__ = [
30
+ 'split_total',
31
+ 'clear_buffer_memory',
32
+ 'not_instance_eval',
33
+ 'is_instance_eval',
34
+ 'DictManager',
35
+ 'DotDict',
36
+ ]
37
+
38
+
39
+ def split_total(
40
+ total: int,
41
+ fraction: Union[int, float],
42
+ ) -> int:
43
+ """
44
+ Calculate the number of epochs for simulation based on a total and a fraction.
45
+
46
+ This function determines the number of epochs to simulate given a total number
47
+ of epochs and either a fraction or a specific number of epochs to run.
48
+
49
+ Parameters:
50
+ -----------
51
+ total : int
52
+ The total number of epochs. Must be a positive integer.
53
+ fraction : Union[int, float]
54
+ If ``float``: A value between 0 and 1 representing the fraction of total epochs to run.
55
+ If ``int``: The specific number of epochs to run, must not exceed the total.
56
+
57
+ Returns:
58
+ --------
59
+ int
60
+ The calculated number of epochs to simulate.
61
+
62
+ Raises:
63
+ -------
64
+ ValueError
65
+ If total is not positive, fraction is negative, or if fraction as float is > 1
66
+ or as int is > total.
67
+ AssertionError
68
+ If total is not an integer.
69
+ """
70
+ assert isinstance(total, int), "Length must be an integer."
71
+ if total <= 0:
72
+ raise ValueError("'total' must be a positive integer.")
73
+ if fraction < 0:
74
+ raise ValueError("'fraction' value cannot be negative.")
75
+
76
+ if isinstance(fraction, float):
77
+ if fraction < 0:
78
+ raise ValueError("'fraction' value cannot be negative.")
79
+ if fraction > 1:
80
+ raise ValueError("'fraction' value cannot be greater than 1.")
81
+ return int(total * fraction)
82
+
83
+ elif isinstance(fraction, int):
84
+ if fraction < 0:
85
+ raise ValueError("'fraction' value cannot be negative.")
86
+ if fraction > total:
87
+ raise ValueError("'fraction' value cannot be greater than total.")
88
+ return fraction
89
+
90
+ else:
91
+ raise ValueError("'fraction' must be an integer or float.")
92
+
93
+
94
+ class NameContext(threading.local):
95
+ def __init__(self):
96
+ self.typed_names: Dict[str, int] = {}
97
+
98
+
99
+ NAME = NameContext()
100
+
101
+
102
+ def get_unique_name(type_: str):
103
+ """Get the unique name for the given object type."""
104
+ if type_ not in NAME.typed_names:
105
+ NAME.typed_names[type_] = 0
106
+ name = f'{type_}{NAME.typed_names[type_]}'
107
+ NAME.typed_names[type_] += 1
108
+ return name
109
+
110
+
111
+ @jax.tree_util.register_pytree_node_class
112
+ class DictManager(dict):
113
+ """
114
+ DictManager, for collecting all pytree used in the program.
115
+
116
+ :py:class:`~.DictManager` supports all features of python dict.
117
+ """
118
+ __module__ = 'brainstate.util'
119
+ _val_id_to_key: dict
120
+
121
+ def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
122
+ """
123
+ Get a new stack with the subset of keys.
124
+ """
125
+ gather = type(self)()
126
+ if isinstance(sep, types.FunctionType):
127
+ for k, v in self.items():
128
+ if sep(v):
129
+ gather[k] = v
130
+ return gather
131
+ else:
132
+ for k, v in self.items():
133
+ if isinstance(v, sep):
134
+ gather[k] = v
135
+ return gather
136
+
137
+ def not_subset(self, sep: Union[type, Tuple[type, ...]]) -> 'DictManager':
138
+ """
139
+ Get a new stack with the subset of keys.
140
+ """
141
+ gather = type(self)()
142
+ for k, v in self.items():
143
+ if not isinstance(v, sep):
144
+ gather[k] = v
145
+ return gather
146
+
147
+ def add_unique_key(self, key: Any, val: Any):
148
+ """
149
+ Add a new element and check if the value is same or not.
150
+ """
151
+ self._check_elem(val)
152
+ if key in self:
153
+ if id(val) != id(self[key]):
154
+ raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
155
+ else:
156
+ self[key] = val
157
+
158
+ def add_unique_value(self, key: Any, val: Any) -> bool:
159
+ """
160
+ Add a new element and check if the val is unique.
161
+
162
+ Parameters:
163
+ key: The key of the element.
164
+ val: The value of the element
165
+
166
+ Returns:
167
+ bool: True if the value is unique, False otherwise.
168
+ """
169
+ self._check_elem(val)
170
+ if not hasattr(self, '_val_id_to_key'):
171
+ self._val_id_to_key = {id(v): k for k, v in self.items()}
172
+ if id(val) not in self._val_id_to_key:
173
+ self._val_id_to_key[id(val)] = key
174
+ self[key] = val
175
+ return True
176
+ else:
177
+ return False
178
+
179
+ def unique(self) -> 'DictManager':
180
+ """
181
+ Get a new type of collections with unique values.
182
+
183
+ If one value is assigned to two or more keys,
184
+ then only one pair of (key, value) will be returned.
185
+ """
186
+ gather = type(self)()
187
+ seen = set()
188
+ for k, v in self.items():
189
+ if id(v) not in seen:
190
+ seen.add(id(v))
191
+ gather[k] = v
192
+ return gather
193
+
194
+ def unique_(self):
195
+ """
196
+ Get a new type of collections with unique values.
197
+
198
+ If one value is assigned to two or more keys,
199
+ then only one pair of (key, value) will be returned.
200
+ """
201
+ seen = set()
202
+ for k in tuple(self.keys()):
203
+ v = self[k]
204
+ if id(v) not in seen:
205
+ seen.add(id(v))
206
+ else:
207
+ self.pop(k)
208
+ return self
209
+
210
+ def assign(self, *args) -> None:
211
+ """
212
+ Assign the value for each element according to the given ``data``.
213
+ """
214
+ for arg in args:
215
+ assert isinstance(arg, dict), 'Must be an instance of dict.'
216
+ for k, v in arg.items():
217
+ self[k] = v
218
+
219
+ def split(self, first: type, *others: type) -> Tuple['DictManager', ...]:
220
+ """
221
+ Split the stack into subsets of stack by the given types.
222
+ """
223
+ filters = (first, *others)
224
+ results = tuple(type(self)() for _ in range(len(filters) + 1))
225
+ for k, v in self.items():
226
+ for i, filt in enumerate(filters):
227
+ if isinstance(v, filt):
228
+ results[i][k] = v
229
+ break
230
+ else:
231
+ results[-1][k] = v
232
+ return results
233
+
234
+ def pop_by_keys(self, keys: Iterable):
235
+ """
236
+ Pop the elements by the keys.
237
+ """
238
+ for k in tuple(self.keys()):
239
+ if k in keys:
240
+ self.pop(k)
241
+
242
+ def pop_by_values(self, values: Iterable, by: str = 'id'):
243
+ """
244
+ Pop the elements by the values.
245
+
246
+ Args:
247
+ values: The value ids.
248
+ by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
249
+ """
250
+ if by == 'id':
251
+ value_ids = {id(v) for v in values}
252
+ for k in tuple(self.keys()):
253
+ if id(self[k]) in value_ids:
254
+ self.pop(k)
255
+ elif by == 'value':
256
+ for k in tuple(self.keys()):
257
+ if self[k] in values:
258
+ self.pop(k)
259
+ else:
260
+ raise ValueError(f'Unsupported method: {by}')
261
+
262
+ def difference_by_keys(self, keys: Iterable):
263
+ """
264
+ Get the difference of the stack by the keys.
265
+ """
266
+ return type(self)({k: v for k, v in self.items() if k not in keys})
267
+
268
+ def difference_by_values(self, values: Iterable, by: str = 'id'):
269
+ """
270
+ Get the difference of the stack by the values.
271
+
272
+ Args:
273
+ values: The value ids.
274
+ by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
275
+ """
276
+ if by == 'id':
277
+ value_ids = {id(v) for v in values}
278
+ return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
279
+ elif by == 'value':
280
+ return type(self)({k: v for k, v in self.items() if v not in values})
281
+ else:
282
+ raise ValueError(f'Unsupported method: {by}')
283
+
284
+ def intersection_by_keys(self, keys: Iterable):
285
+ """
286
+ Get the intersection of the stack by the keys.
287
+ """
288
+ return type(self)({k: v for k, v in self.items() if k in keys})
289
+
290
+ def intersection_by_values(self, values: Iterable, by: str = 'id'):
291
+ """
292
+ Get the intersection of the stack by the values.
293
+
294
+ Args:
295
+ values: The value ids.
296
+ by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
297
+ """
298
+ if by == 'id':
299
+ value_ids = {id(v) for v in values}
300
+ return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
301
+ elif by == 'value':
302
+ return type(self)({k: v for k, v in self.items() if v in values})
303
+ else:
304
+ raise ValueError(f'Unsupported method: {by}')
305
+
306
def __add__(self, other: dict):
    """
    Merge with another dict without modifying this one.

    Returns a new instance of the same type that starts as a copy of
    ``self`` and is then updated with ``other``, so entries of ``other``
    win on key collisions.
    """
    merged = type(self)(self)
    merged.update(other)
    return merged
313
+
314
def tree_flatten(self):
    """
    Split the mapping into its values and its keys.

    Returns:
      A pair ``(values, keys)``; both are tuples in the mapping's
      iteration order.
    """
    leaves = tuple(self.values())
    aux_keys = tuple(self.keys())
    return leaves, aux_keys
316
+
317
@classmethod
def tree_unflatten(cls, keys, values):
    """
    Rebuild an instance from the pair produced by ``tree_flatten``.

    Args:
      keys: The keys, in order.
      values: The values, positionally aligned with ``keys``.

    Returns:
      A new ``cls`` instance built from the ``(key, value)`` pairs.

    Raises:
      ValueError: If ``keys`` and ``values`` differ in length.
    """
    # Pair the two sequences with an explicit length check instead of
    # ``jax.util.safe_zip``, so this does not depend on a deprecated
    # ``jax.util`` helper.
    keys, values = tuple(keys), tuple(values)
    if len(keys) != len(values):
        raise ValueError(
            f'length mismatch: got {len(keys)} keys and {len(values)} values'
        )
    return cls(list(zip(keys, values)))
320
+
321
def _check_elem(self, elem: Any):
    """
    Element check; not implemented here.

    This implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
323
+
324
def to_dict(self):
    """
    Convert the stack to a dict.

    Returns
    -------
    dict
      A new built-in ``dict`` holding the same items. The copy is
      shallow: the values themselves are not copied.
    """
    plain = dict(self)
    return plain
334
+
335
def __copy__(self):
    """Return a shallow copy as a new instance of the same type."""
    same_type = type(self)
    return same_type(self)
337
+
338
+
339
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: str = None,
    array: bool = True,
    compilation: bool = False,
):
    """Clear all on-device buffers.

    This function will be very useful when you call models in a Python loop,
    because it can clear all cached arrays, and clear device memory.

    .. warning::

       This operation may cause errors when you use a deleted buffer.
       Therefore, regenerate data always.

    Parameters
    ----------
    platform: str
      The device to clear its memory. ``None`` selects the default backend.
    array: bool
      Clear all buffer array. Default is True.
    compilation: bool
      Clear compilation cache. Default is False.

    """
    if array:
        # ``jax.live_arrays`` is the public way to enumerate live arrays on a
        # backend; it replaces the internal
        # ``xla_bridge.get_backend(platform).live_buffers()`` call.
        for buf in jax.live_arrays(platform):
            buf.delete()
    if compilation:
        jax.clear_caches()
    # Always run a collection pass so Python-side references to the released
    # objects are reclaimed as well.
    gc.collect()
371
+
372
+
373
@jax.tree_util.register_pytree_node_class
class DotDict(dict):
    """Python dictionaries with advanced dot notation access.

    For example:

    >>> d = DotDict({'a': 10, 'b': 20})
    >>> d.a
    10
    >>> d['a']
    10
    >>> d.c  # this will raise a KeyError
    KeyError: 'c'
    >>> d.c = 30  # but you can assign a value to a non-existing item
    >>> d.c
    30

    Nested dicts passed to the constructor (including dicts found inside
    lists, tuples and namedtuples) are converted to ``DotDict`` as well.
    """

    __module__ = 'brainstate.util'

    def __init__(self, *args, **kwargs):
        # '__parent' / '__key' are kept as plain instance attributes (not as
        # items); ``__setitem__`` reads them to re-attach this dict to its
        # parent on first assignment.
        object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
        object.__setattr__(self, '__key', kwargs.pop('__key', None))
        for arg in args:
            if not arg:
                continue
            elif isinstance(arg, dict):
                for key, val in arg.items():
                    self[key] = self._hook(val)
            elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
                # A single ``(key, value)`` pair.
                self[arg[0]] = self._hook(arg[1])
            else:
                # Any other iterable of ``(key, value)`` pairs.
                for key, val in iter(arg):
                    self[key] = self._hook(val)

        for key, val in kwargs.items():
            self[key] = self._hook(val)

    def __setattr__(self, name, value):
        """Store ``value`` as item ``name``; names defined on the class are read-only."""
        if hasattr(self.__class__, name):
            raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
        else:
            self[name] = value

    def __setitem__(self, name, value):
        """Set an item and, if a parent link is pending, attach this dict to it."""
        super(DotDict, self).__setitem__(name, value)
        try:
            p = object.__getattribute__(self, '__parent')
            key = object.__getattribute__(self, '__key')
        except AttributeError:
            p = None
            key = None
        if p is not None:
            p[key] = self
            # The link is only needed once; drop it after attaching.
            object.__delattr__(self, '__parent')
            object.__delattr__(self, '__key')

    @staticmethod
    def _rebuild_sequence(template, elems):
        """Build a sequence of the same type as ``template`` from the list ``elems``."""
        if isinstance(template, tuple) and hasattr(template, '_fields'):
            # namedtuple constructors take one positional argument per field,
            # not a single iterable.
            return type(template)(*elems)
        return type(template)(elems)

    @classmethod
    def _hook(cls, item):
        """Recursively convert dicts (also inside lists/tuples) into ``cls`` instances."""
        if isinstance(item, dict):
            return cls(item)
        elif isinstance(item, (list, tuple)):
            return cls._rebuild_sequence(item, [cls._hook(elem) for elem in item])
        return item

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails. A missing key raises
        # ``KeyError`` (not ``AttributeError``), as shown in the class docstring.
        return self.__getitem__(item)

    def __delattr__(self, name):
        """Delete item ``name``."""
        del self[name]

    def copy(self):
        """Return a shallow copy made with ``copy.copy``."""
        return copy.copy(self)

    def deepcopy(self):
        """Return a deep copy made with ``copy.deepcopy``."""
        return copy.deepcopy(self)

    def __deepcopy__(self, memo):
        other = self.__class__()
        # Register the copy before recursing so self-references resolve to it.
        memo[id(self)] = other
        for key, value in self.items():
            other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return other

    def to_dict(self):
        """Convert to a plain ``dict``, recursing into nested instances and into lists/tuples."""
        base = {}
        for key, value in self.items():
            if isinstance(value, type(self)):
                base[key] = value.to_dict()
            elif isinstance(value, (list, tuple)):
                base[key] = self._rebuild_sequence(
                    value,
                    [item.to_dict() if isinstance(item, type(self)) else item
                     for item in value],
                )
            else:
                base[key] = value
        return base

    def update(self, *args, **kwargs):
        """Update in place; when both the existing and the new value are dicts they are merged recursively.

        Raises:
          TypeError: If more than one positional argument is given.
        """
        other = {}
        if args:
            if len(args) > 1:
                raise TypeError(f'update expected at most 1 positional argument, got {len(args)}')
            other.update(args[0])
        other.update(kwargs)
        for k, v in other.items():
            if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
                self[k] = v
            else:
                self[k].update(v)

    def __getnewargs__(self):
        return tuple(self.items())

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)

    def setdefault(self, key, default=None):
        """Return ``self[key]`` if present; otherwise store and return ``default``."""
        if key in self:
            return self[key]
        else:
            self[key] = default
            return default

    def tree_flatten(self):
        """Return ``(values, keys)`` as tuples, in iteration order."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        """Rebuild an instance from the ``keys`` and ``values`` produced by ``tree_flatten``.

        Raises:
          ValueError: If ``keys`` and ``values`` differ in length.
        """
        # Pair the two sequences with an explicit length check instead of
        # ``jax.util.safe_zip``, so this does not depend on a deprecated
        # ``jax.util`` helper.
        keys, values = tuple(keys), tuple(values)
        if len(keys) != len(values):
            raise ValueError(
                f'length mismatch: got {len(keys)} keys and {len(values)} values'
            )
        # Must be a list: ``__init__`` treats a tuple argument as one pair.
        return cls(list(zip(keys, values)))
504
+
505
+
506
def _is_not_instance(x, cls):
    """Return ``True`` when ``x`` is not an instance of ``cls``, else ``False``."""
    if isinstance(x, cls):
        return False
    return True
508
+
509
+
510
def _is_instance(x, cls):
    """Return ``True`` when ``x`` is an instance of ``cls``, else ``False``."""
    matched = isinstance(x, cls)
    return matched
512
+
513
+
514
@set_module_as('brainstate.util')
def not_instance_eval(*cls):
    """
    Build a one-argument predicate that is true for objects which are
    not instances of any of the given classes.

    Args:
      *cls: The classes to check against; they are collected into a tuple
        and bound as the ``cls`` keyword of the predicate.

    Returns:
      A ``functools.partial`` wrapping ``_is_not_instance``.

    """
    classes = cls
    predicate = functools.partial(_is_not_instance, cls=classes)
    return predicate
527
+
528
+
529
@set_module_as('brainstate.util')
def is_instance_eval(*cls):
    """
    Build a one-argument predicate that is true for objects which are
    instances of at least one of the given classes.

    Args:
      *cls: The classes to check against; they are collected into a tuple
        and bound as the ``cls`` keyword of the predicate.

    Returns:
      A ``functools.partial`` wrapping ``_is_instance``.
    """
    classes = cls
    predicate = functools.partial(_is_instance, cls=classes)
    return predicate