brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/transform.py DELETED
@@ -1,23 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # alias for compilation and augmentation functions
17
-
18
- from .augment import *
19
- from .compile import *
20
-
21
# NOTE(review): this guard only evaluates the names ``ifelse`` and ``grad``
# (brought in by the star imports from ``.augment`` / ``.compile``) and does
# nothing with them; presumably it exists so linters/IDEs do not flag the star
# imports as unused. Executing this file directly would fail before reaching
# it, because the module uses relative imports.
if __name__ == '__main__':
    ifelse
    grad
brainstate/util/caller.py DELETED
@@ -1,98 +0,0 @@
1
- # The file is adapted from the Flax library (https://github.com/google/flax).
2
- # The credit should go to the Flax authors.
3
- #
4
- # Copyright 2024 The Flax Authors.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
- import dataclasses
19
- from typing import Any, TypeVar, Protocol, Generic
20
-
21
- import jax
22
-
23
# Public names of this module (what ``from ... import *`` exports).
__all__ = [
    'DelayedAccessor',
    'CallableProxy',
    'ApplyCaller',
]
28
-
29
# Covariant type variable parameterizing ``ApplyCaller``; it is the type of the
# second element of the tuple that ``ApplyCaller.__call__`` is annotated to return.
A = TypeVar('A', covariant=True)  # type: ignore[not-supported-yet]
30
-
31
-
32
- def _identity(x):
33
- return x
34
-
35
-
36
- @dataclasses.dataclass(frozen=True)
37
- class GetItem:
38
- key: Any
39
-
40
-
41
- @dataclasses.dataclass(frozen=True)
42
- class GetAttr:
43
- name: str
44
-
45
-
46
- @dataclasses.dataclass(frozen=True)
47
- class DelayedAccessor:
48
- actions: tuple[GetItem | GetAttr, ...] = ()
49
-
50
- def __call__(self, x):
51
- for action in self.actions:
52
- if isinstance(action, GetItem):
53
- x = x[action.key]
54
- elif isinstance(action, GetAttr):
55
- x = getattr(x, action.name)
56
- return x
57
-
58
- def __getattr__(self, name):
59
- return DelayedAccessor(self.actions + (GetAttr(name),))
60
-
61
- def __getitem__(self, key):
62
- return DelayedAccessor(self.actions + (GetItem(key),))
63
-
64
-
65
- jax.tree_util.register_static(DelayedAccessor)
66
-
67
-
68
- class _AccessorCall(Protocol):
69
- def __call__(self, accessor: DelayedAccessor, /, *args, **kwargs) -> Any:
70
- ...
71
-
72
-
73
- class CallableProxy:
74
- def __init__(
75
- self, fun: _AccessorCall, accessor: DelayedAccessor | None = None
76
- ):
77
- self._callable = fun
78
- self._accessor = DelayedAccessor() if accessor is None else accessor
79
-
80
- def __call__(self, *args, **kwargs):
81
- return self._callable(self._accessor, *args, **kwargs)
82
-
83
- def __getattr__(self, name) -> 'CallableProxy':
84
- return CallableProxy(self._callable, getattr(self._accessor, name))
85
-
86
- def __getitem__(self, key) -> 'CallableProxy':
87
- return CallableProxy(self._callable, self._accessor[key])
88
-
89
-
90
class ApplyCaller(Protocol, Generic[A]):
    """Structural type for a chainable caller.

    Attribute access and indexing are annotated to yield another
    ``ApplyCaller[A]``; calling it is annotated to return ``tuple[Any, A]``.
    """

    def __getattr__(self, __name) -> 'ApplyCaller[A]':
        """Caller for the named attribute (interface only)."""

    def __getitem__(self, __name) -> 'ApplyCaller[A]':
        """Caller for the given key (interface only)."""

    def __call__(self, *args, **kwargs) -> tuple[Any, A]:
        """Invoke the target (interface only)."""
brainstate/util/others.py DELETED
@@ -1,540 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import copy
17
- import functools
18
- import gc
19
- import threading
20
- import types
21
- from collections.abc import Iterable
22
- from typing import Any, Callable, Tuple, Union, Dict
23
-
24
- import jax
25
- from jax.lib import xla_bridge
26
-
27
- from brainstate._utils import set_module_as
28
-
29
# Public API of this module (what ``from ... import *`` exports).
__all__ = [
    'split_total',
    'clear_buffer_memory',
    'not_instance_eval',
    'is_instance_eval',
    'DictManager',
    'DotDict',
]
37
-
38
-
39
def split_total(
    total: int,
    fraction: Union[int, float],
) -> int:
    """
    Calculate the number of epochs for simulation based on a total and a fraction.

    This function determines the number of epochs to simulate given a total number
    of epochs and either a fraction or a specific number of epochs to run.

    Parameters:
    -----------
    total : int
        The total number of epochs. Must be a positive integer.
    fraction : Union[int, float]
        If a real number (``float``, NumPy floating scalar, ...): a value between
        0 and 1 giving the fraction of ``total`` to run; the result is truncated
        with ``int``.
        If an integer (``int``, NumPy integer scalar, ...): the specific number
        of epochs to run; it must not exceed ``total``.

    Returns:
    --------
    int
        The calculated number of epochs to simulate.

    Raises:
    -------
    ValueError
        If total is not positive, fraction is negative, fraction as a real
        number is > 1, fraction as an integer is > total, or fraction is not a
        number at all.
    AssertionError
        If total is not an integer.
    """
    import numbers

    # Raised explicitly (not via ``assert``) so the check survives ``python -O``;
    # the exception type is kept because it is part of the documented contract.
    if not isinstance(total, int):
        raise AssertionError("Length must be an integer.")
    if total <= 0:
        raise ValueError("'total' must be a positive integer.")

    # Dispatch on the type *before* comparing: ``'3' < 0`` would raise
    # TypeError instead of the documented ValueError. Integral is tested first
    # because every Integral is also Real.
    if isinstance(fraction, numbers.Integral):
        if fraction < 0:
            raise ValueError("'fraction' value cannot be negative.")
        if fraction > total:
            raise ValueError("'fraction' value cannot be greater than total.")
        return fraction if isinstance(fraction, int) else int(fraction)

    if isinstance(fraction, numbers.Real):
        if fraction < 0:
            raise ValueError("'fraction' value cannot be negative.")
        if fraction > 1:
            raise ValueError("'fraction' value cannot be greater than 1.")
        return int(total * fraction)

    raise ValueError("'fraction' must be an integer or float.")
92
-
93
-
94
class NameContext(threading.local):
    """Per-thread registry of name counters.

    Because this subclasses ``threading.local``, every thread that touches the
    instance runs ``__init__`` on first access and gets its own
    ``typed_names`` mapping.
    """

    def __init__(self):
        counters: Dict[str, int] = {}
        self.typed_names = counters


NAME = NameContext()


def get_unique_name(type_: str):
    """Return ``type_`` suffixed with the next value of its per-thread counter.

    The first call for a given ``type_`` yields ``f'{type_}0'``, the next one
    ``f'{type_}1'``, and so on.
    """
    registry = NAME.typed_names
    index = registry.get(type_, 0)
    registry[type_] = index + 1
    return f'{type_}{index}'
109
-
110
-
111
@jax.tree_util.register_pytree_node_class
class DictManager(dict):
    """
    DictManager, for collecting all pytree used in the program.

    :py:class:`~.DictManager` supports all features of python dict. It is also
    registered as a JAX pytree whose leaves are the values and whose auxiliary
    data is the tuple of keys.

    Subclasses are expected to override :py:meth:`_check_elem`. The base
    implementation raises ``NotImplementedError``, so :py:meth:`add_unique_key`
    and :py:meth:`add_unique_value` only work on subclasses that override it.
    """
    __module__ = 'brainstate.util'
    # Lazily built by ``add_unique_value``: maps ``id(value)`` to its key.
    _val_id_to_key: dict

    @staticmethod
    def _as_container(items: Iterable):
        """Return ``items`` in a form that supports repeated ``in`` tests.

        Real containers (list, tuple, set, dict, str, ...) are returned unchanged,
        so their membership semantics are preserved. One-shot iterables such as
        generators are materialized into a tuple, because every ``in`` test
        against a generator consumes it.
        """
        from collections.abc import Container
        return items if isinstance(items, Container) else tuple(items)

    def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
        """
        Get a new manager holding the items selected by ``sep``.

        Args:
            sep: A class or tuple of classes, in which case values that are
                instances are kept. Alternatively a predicate called with each
                value, in which case values with a truthy result are kept.
                Functions, lambdas, bound methods, builtins and
                ``functools.partial`` objects are treated as predicates.

        Returns:
            A new instance of the same type with the selected items.
        """
        # ``functools.partial`` objects (e.g. those returned by
        # ``is_instance_eval``), builtins and bound methods are not
        # ``types.FunctionType``; without listing them they would be passed to
        # ``isinstance`` and raise TypeError.
        predicate_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
            types.MethodDescriptorType,
            functools.partial,
        )
        gather = type(self)()
        if isinstance(sep, predicate_types):
            for k, v in self.items():
                if sep(v):
                    gather[k] = v
        else:
            for k, v in self.items():
                if isinstance(v, sep):
                    gather[k] = v
        return gather

    def not_subset(self, sep: Union[type, Tuple[type, ...]]) -> 'DictManager':
        """
        Get a new manager holding the items whose value is NOT an instance of ``sep``.
        """
        gather = type(self)()
        for k, v in self.items():
            if not isinstance(v, sep):
                gather[k] = v
        return gather

    def add_unique_key(self, key: Any, val: Any):
        """
        Add a new element, checking that ``key`` is not already bound to a different object.

        Raises:
            ValueError: If ``key`` is present and maps to a different object
                (compared by identity).
        """
        self._check_elem(val)
        if key in self:
            if id(val) != id(self[key]):
                raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
        else:
            self[key] = val

    def add_unique_value(self, key: Any, val: Any) -> bool:
        """
        Add a new element and check if the val is unique.

        Parameters:
            key: The key of the element.
            val: The value of the element

        Returns:
            bool: True if the value is unique, False otherwise.
        """
        self._check_elem(val)
        # NOTE(review): the id cache is built once and is not updated by plain
        # ``__setitem__``/``pop``; it only tracks values added through here.
        if not hasattr(self, '_val_id_to_key'):
            self._val_id_to_key = {id(v): k for k, v in self.items()}
        if id(val) not in self._val_id_to_key:
            self._val_id_to_key[id(val)] = key
            self[key] = val
            return True
        else:
            return False

    def unique(self) -> 'DictManager':
        """
        Get a new type of collections with unique values.

        If one value (by identity) is assigned to two or more keys,
        then only the first (key, value) pair is kept.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            if id(v) not in seen:
                seen.add(id(v))
                gather[k] = v
        return gather

    def unique_(self):
        """
        In-place version of :py:meth:`unique`; returns ``self``.
        """
        seen = set()
        for k in tuple(self.keys()):
            v = self[k]
            if id(v) not in seen:
                seen.add(id(v))
            else:
                self.pop(k)
        return self

    def assign(self, *args) -> None:
        """
        Assign the value for each element according to the given dicts.
        """
        for arg in args:
            assert isinstance(arg, dict), 'Must be an instance of dict.'
            for k, v in arg.items():
                self[k] = v

    def split(self, first: type, *others: type) -> Tuple['DictManager', ...]:
        """
        Split the stack into subsets of stack by the given types.

        Returns ``len(filters) + 1`` managers: one per filter (first matching
        filter wins) plus a final one with the values matching no filter.
        """
        filters = (first, *others)
        results = tuple(type(self)() for _ in range(len(filters) + 1))
        for k, v in self.items():
            for i, filt in enumerate(filters):
                if isinstance(v, filt):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v
        return results

    def pop_by_keys(self, keys: Iterable):
        """
        Pop the elements whose key is in ``keys``.
        """
        keys = self._as_container(keys)
        for k in tuple(self.keys()):
            if k in keys:
                self.pop(k)

    def pop_by_values(self, values: Iterable, by: str = 'id'):
        """
        Pop the elements by the values.

        Args:
            values: The values to remove.
            by: str. ``'id'`` compares by identity, ``'value'`` by equality. Default is 'id'.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            for k in tuple(self.keys()):
                if id(self[k]) in value_ids:
                    self.pop(k)
        elif by == 'value':
            values = self._as_container(values)
            for k in tuple(self.keys()):
                if self[k] in values:
                    self.pop(k)
        else:
            raise ValueError(f'Unsupported method: {by}')

    def difference_by_keys(self, keys: Iterable):
        """
        Get the difference of the stack by the keys.
        """
        keys = self._as_container(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys})

    def difference_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get the difference of the stack by the values.

        Args:
            values: The values to exclude.
            by: str. ``'id'`` compares by identity, ``'value'`` by equality. Default is 'id'.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
        elif by == 'value':
            values = self._as_container(values)
            return type(self)({k: v for k, v in self.items() if v not in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def intersection_by_keys(self, keys: Iterable):
        """
        Get the intersection of the stack by the keys.
        """
        keys = self._as_container(keys)
        return type(self)({k: v for k, v in self.items() if k in keys})

    def intersection_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get the intersection of the stack by the values.

        Args:
            values: The values to keep.
            by: str. ``'id'`` compares by identity, ``'value'`` by equality. Default is 'id'.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
        elif by == 'value':
            values = self._as_container(values)
            return type(self)({k: v for k, v in self.items() if v in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def __add__(self, other: dict):
        """
        Return a new manager with ``other``'s items merged over a copy of ``self``.
        """
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def tree_flatten(self):
        """Pytree flattening: values are the leaves, keys are the aux data."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        """Rebuild a manager from the keys (aux data) and the leaves."""
        # Length-checked zip without relying on the deprecated ``jax.util`` module.
        values = tuple(values)
        if len(keys) != len(values):
            raise ValueError(
                f'Cannot unflatten {cls.__name__}: got {len(keys)} keys '
                f'but {len(values)} values.'
            )
        return cls(list(zip(keys, values)))

    def _check_elem(self, elem: Any):
        """Validate a value before insertion; must be overridden by subclasses."""
        raise NotImplementedError

    def to_dict(self):
        """
        Convert the stack to a dict.

        Returns
        -------
        dict
            The dict object.
        """
        return dict(self)

    def __copy__(self):
        return type(self)(self)
337
-
338
-
339
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: Union[str, None] = None,
    array: bool = True,
    compilation: bool = False,
):
    """Clear all on-device buffers.

    This function will be very useful when you call models in a Python loop,
    because it can clear all cached arrays, and clear device memory.

    .. warning::

       This operation may cause errors when you use a deleted buffer.
       Therefore, regenerate data always.

    Parameters
    ----------
    platform: str, optional
        The platform whose memory to clear (``None`` selects JAX's default).
    array: bool
        Delete all live on-device arrays. Default is True.
    compilation: bool
        Clear compilation cache. Default is False.

    Returns
    -------
    None
    """
    if array:
        # ``jax.live_arrays`` is the public API for enumerating live arrays;
        # ``xla_bridge.get_backend(...).live_buffers()`` is deprecated/removed
        # in recent JAX and is kept only as a fallback for older versions.
        live_arrays = getattr(jax, 'live_arrays', None)
        if live_arrays is not None:
            buffers = live_arrays(platform)
        else:
            buffers = xla_bridge.get_backend(platform).live_buffers()
        for buf in buffers:
            buf.delete()
    if compilation:
        jax.clear_caches()
    # Let Python drop any remaining references to the deleted buffers.
    gc.collect()
371
-
372
-
373
class _DotDictMissing(KeyError, AttributeError):
    """Error for a missing :class:`DotDict` key accessed as an attribute.

    It is a :class:`KeyError`, so existing ``except KeyError`` handlers keep
    working. It is also an :class:`AttributeError`, so ``hasattr(d, name)`` and
    ``getattr(d, name, default)`` treat the attribute as absent instead of
    propagating the error.
    """


@jax.tree_util.register_pytree_node_class
class DotDict(dict):
    """Python dictionaries with advanced dot notation access.

    Nested ``dict`` values, including dicts inside lists/tuples, passed to the
    constructor are converted to :class:`DotDict`.

    For example:

    >>> d = DotDict({'a': 10, 'b': 20})
    >>> d.a
    10
    >>> d['a']
    10
    >>> hasattr(d, 'c')
    False
    >>> d.c = 30  # you can assign a value to a non-existing item
    >>> d.c
    30

    Reading a missing key as an attribute (``d.missing``) raises an error that
    is both a :class:`KeyError` and an :class:`AttributeError`.
    """

    __module__ = 'brainstate.util'

    def __init__(self, *args, **kwargs):
        object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
        object.__setattr__(self, '__key', kwargs.pop('__key', None))
        for arg in args:
            if not arg:
                continue
            elif isinstance(arg, dict):
                for key, val in arg.items():
                    self[key] = self._hook(val)
            elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
                # A single ``(key, value)`` pair.
                self[arg[0]] = self._hook(arg[1])
            else:
                # An iterable of ``(key, value)`` pairs.
                for key, val in iter(arg):
                    self[key] = self._hook(val)

        for key, val in kwargs.items():
            self[key] = self._hook(val)

    def __setattr__(self, name, value):
        # Names defined on the class (methods such as ``copy``) cannot be shadowed.
        if hasattr(self.__class__, name):
            raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
        else:
            self[name] = value

    def __setitem__(self, name, value):
        super(DotDict, self).__setitem__(name, value)
        # If this dict was created with ``__parent``/``__key``, attach it to the
        # parent on the first write, then forget the link.
        try:
            p = object.__getattribute__(self, '__parent')
            key = object.__getattribute__(self, '__key')
        except AttributeError:
            p = None
            key = None
        if p is not None:
            p[key] = self
            object.__delattr__(self, '__parent')
            object.__delattr__(self, '__key')

    @classmethod
    def _hook(cls, item):
        """Recursively convert dicts (also inside lists/tuples) into ``cls``."""
        if isinstance(item, dict):
            return cls(item)
        elif isinstance(item, (list, tuple)):
            return type(item)(cls._hook(elem) for elem in item)
        return item

    def __getattr__(self, item):
        try:
            return self.__getitem__(item)
        except KeyError:
            # Raise an error that is also an AttributeError so ``hasattr`` and
            # ``getattr(obj, name, default)`` work; ``from None`` hides the
            # redundant inner KeyError.
            raise _DotDictMissing(item) from None

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise _DotDictMissing(name) from None

    def copy(self):
        """Return a shallow copy."""
        return copy.copy(self)

    def deepcopy(self):
        """Return a deep copy."""
        return copy.deepcopy(self)

    def __deepcopy__(self, memo):
        other = self.__class__()
        memo[id(self)] = other
        for key, value in self.items():
            other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return other

    def to_dict(self):
        """Recursively convert to plain ``dict`` objects (also inside lists/tuples)."""
        base = {}
        for key, value in self.items():
            if isinstance(value, type(self)):
                base[key] = value.to_dict()
            elif isinstance(value, (list, tuple)):
                base[key] = type(value)(item.to_dict() if isinstance(item, type(self)) else item
                                        for item in value)
            else:
                base[key] = value
        return base

    def update(self, *args, **kwargs):
        """Update from a mapping and/or keywords, merging nested dicts recursively."""
        other = {}
        if args:
            if len(args) > 1:
                raise TypeError(f'update expected at most 1 positional argument, got {len(args)}')
            other.update(args[0])
        other.update(kwargs)
        for k, v in other.items():
            if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
                self[k] = v
            else:
                self[k].update(v)

    def __getnewargs__(self):
        return tuple(self.items())

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)

    def setdefault(self, key, default=None):
        if key in self:
            return self[key]
        else:
            self[key] = default
            return default

    def tree_flatten(self):
        """Pytree flattening: values are the leaves, keys are the aux data."""
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        """Rebuild from the keys (aux data) and the leaves."""
        # Length-checked zip without relying on the deprecated ``jax.util`` module.
        values = tuple(values)
        if len(keys) != len(values):
            raise ValueError(
                f'Cannot unflatten {cls.__name__}: got {len(keys)} keys '
                f'but {len(values)} values.'
            )
        return cls(list(zip(keys, values)))
504
-
505
-
506
- def _is_not_instance(x, cls):
507
- return not isinstance(x, cls)
508
-
509
-
510
- def _is_instance(x, cls):
511
- return isinstance(x, cls)
512
-
513
-
514
@set_module_as('brainstate.util')
def not_instance_eval(*cls):
    """
    Build a predicate that is True for objects that are NOT instances of ``cls``.

    Args:
        *cls: The classes to check against; they are passed together to
            ``isinstance`` as a tuple.

    Returns:
        A ``functools.partial`` of ``_is_not_instance`` with ``cls`` bound,
        called as ``predicate(x)``.
    """
    bound = {'cls': cls}
    return functools.partial(_is_not_instance, **bound)
527
-
528
-
529
@set_module_as('brainstate.util')
def is_instance_eval(*cls):
    """
    Build a predicate that is True for objects that are instances of ``cls``.

    Args:
        *cls: The classes to check against; they are passed together to
            ``isinstance`` as a tuple.

    Returns:
        A ``functools.partial`` of ``_is_instance`` with ``cls`` bound,
        called as ``predicate(x)``.
    """
    bound = {'cls': cls}
    return functools.partial(_is_instance, **bound)