brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,497 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import copy
|
19
|
+
import functools
|
20
|
+
import gc
|
21
|
+
import threading
|
22
|
+
import types
|
23
|
+
from collections.abc import Iterable
|
24
|
+
from typing import Any, Callable, Tuple, Union, Dict
|
25
|
+
|
26
|
+
import jax
|
27
|
+
from jax.lib import xla_bridge
|
28
|
+
|
29
|
+
from brainstate._utils import set_module_as
|
30
|
+
|
31
|
+
# Public names of this module: the names exported by ``from ... import *``.
__all__ = [
    'clear_buffer_memory',
    'not_instance_eval',
    'is_instance_eval',
    'DictManager',
    'DotDict',
]
|
38
|
+
|
39
|
+
|
40
|
+
class NameContext(threading.local):
    """Per-thread registry of the counters used by :func:`get_unique_name`."""

    def __init__(self):
        # Maps a type prefix to the next integer suffix that will be handed out.
        self.typed_names: Dict[str, int] = {}


NAME = NameContext()


def get_unique_name(type_: str):
    """Get the unique name for the given object type.

    The returned name is ``type_`` followed by a counter that starts at ``0``
    and increases by one on every call with the same ``type_``. Counters live
    in :data:`NAME`, a ``threading.local``, so each thread counts separately.
    """
    counters = NAME.typed_names
    index = counters.setdefault(type_, 0)
    counters[type_] = index + 1
    return f'{type_}{index}'
|
55
|
+
|
56
|
+
|
57
|
+
# Sentinel for "no entry" lookups where ``None`` could be a legitimate key.
_MISSING = object()

# Filter objects that ``DictManager`` uses directly as value predicates instead of
# passing them to ``isinstance``.
_PREDICATE_TYPES = (
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    functools.partial,
)


def _as_predicate(sep) -> Callable[[Any], bool]:
    """Normalize a ``DictManager`` filter into a one-argument predicate.

    Functions, lambdas, bound methods, builtins and ``functools.partial`` objects
    (such as those returned by ``is_instance_eval`` / ``not_instance_eval``) are
    used as-is. Anything else (a type or a tuple of types) is matched with
    ``isinstance``.
    """
    if isinstance(sep, _PREDICATE_TYPES):
        return sep
    return lambda value: isinstance(value, sep)


def _as_container(items):
    """Return ``items`` in a form that supports repeated ``in`` tests.

    Generators and other one-shot iterables have no ``__contains__``. The first
    membership test would consume them, so they are materialized into a tuple.
    """
    if hasattr(items, '__contains__'):
        return items
    return tuple(items)


@jax.tree_util.register_pytree_node_class
class DictManager(dict):
    """
    DictManager, for collecting all pytree used in the program.

    :py:class:`~.DictManager` supports all features of python dict. It is
    registered as a JAX pytree: the dict values are the children and the keys
    are the auxiliary data.

    The filtering methods (``subset``, ``not_subset``, ``split``) accept either
    a type or tuple of types, matched with ``isinstance``, or a predicate
    function (including ``functools.partial`` objects).

    ``add_unique_key`` and ``add_unique_value`` call ``_check_elem``, which
    subclasses must implement.
    """
    __module__ = 'brainstate.util'
    # Lazily created by ``add_unique_value``: maps ``id(value)`` to the key it was stored under.
    _val_id_to_key: dict

    def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
        """
        Get a new stack containing only the elements whose value matches ``sep``.

        Args:
            sep: A type / tuple of types (matched with ``isinstance``) or a
                predicate called with each value.
        """
        keep = _as_predicate(sep)
        gather = type(self)()
        for k, v in self.items():
            if keep(v):
                gather[k] = v
        return gather

    def not_subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
        """
        Get a new stack containing only the elements whose value does NOT match ``sep``.

        Args:
            sep: A type / tuple of types (matched with ``isinstance``) or a
                predicate called with each value.
        """
        drop = _as_predicate(sep)
        gather = type(self)()
        for k, v in self.items():
            if not drop(v):
                gather[k] = v
        return gather

    def add_unique_key(self, key: Any, val: Any):
        """
        Add a new element and check if the value is same or not.

        Raises:
            ValueError: if ``key`` is already registered with a different object.
        """
        self._check_elem(val)
        if key in self:
            if id(val) != id(self[key]):
                raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
        else:
            self[key] = val

    def add_unique_value(self, key: Any, val: Any) -> bool:
        """
        Add a new element and check if the val is unique.

        Parameters:
            key: The key of the element.
            val: The value of the element

        Returns:
            bool: True if the value is unique (and was added), False otherwise.
        """
        self._check_elem(val)
        if not hasattr(self, '_val_id_to_key'):
            self._val_id_to_key = {id(v): k for k, v in self.items()}
        known_key = self._val_id_to_key.get(id(val), _MISSING)
        # The cache can be stale: the recorded entry may have been popped, or the
        # recorded object may have died and its id been reused. Only trust it when
        # the recorded key still holds this very object.
        if known_key is not _MISSING and known_key in self and self[known_key] is val:
            return False
        self._val_id_to_key[id(val)] = key
        self[key] = val
        return True

    def unique(self) -> 'DictManager':
        """
        Get a new type of collections with unique values.

        If one value is assigned to two or more keys,
        then only one pair of (key, value) will be returned.
        """
        gather = type(self)()
        seen = set()
        for k, v in self.items():
            if id(v) not in seen:
                seen.add(id(v))
                gather[k] = v
        return gather

    def unique_(self):
        """
        Remove, in place, every element whose value (by identity) appeared earlier.

        If one value is assigned to two or more keys, only the first
        (key, value) pair is kept. Returns ``self``.
        """
        seen = set()
        for k in tuple(self.keys()):
            v = self[k]
            if id(v) not in seen:
                seen.add(id(v))
            else:
                self.pop(k)
        return self

    def assign(self, *args) -> None:
        """
        Assign the value for each element according to the given ``data``.
        """
        for arg in args:
            assert isinstance(arg, dict), 'Must be an instance of dict.'
            for k, v in arg.items():
                self[k] = v

    def split(self, first: Union[type, Callable], *others: Union[type, Callable]) -> Tuple['DictManager', ...]:
        """
        Split the stack into subsets of stack by the given filters.

        Each element goes into the subset of the first filter it matches. The
        last returned subset holds the elements that match none of them.
        """
        predicates = tuple(_as_predicate(f) for f in (first, *others))
        results = tuple(type(self)() for _ in range(len(predicates) + 1))
        for k, v in self.items():
            for i, pred in enumerate(predicates):
                if pred(v):
                    results[i][k] = v
                    break
            else:
                results[-1][k] = v
        return results

    def pop_by_keys(self, keys: Iterable):
        """
        Pop the elements by the keys.
        """
        keys = _as_container(keys)
        for k in tuple(self.keys()):
            if k in keys:
                self.pop(k)

    def pop_by_values(self, values: Iterable, by: str = 'id'):
        """
        Pop the elements by the values.

        Args:
            values: The values to remove.
            by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            for k in tuple(self.keys()):
                if id(self[k]) in value_ids:
                    self.pop(k)
        elif by == 'value':
            values = _as_container(values)
            for k in tuple(self.keys()):
                if self[k] in values:
                    self.pop(k)
        else:
            raise ValueError(f'Unsupported method: {by}')

    def difference_by_keys(self, keys: Iterable):
        """
        Get the difference of the stack by the keys.
        """
        keys = _as_container(keys)
        return type(self)({k: v for k, v in self.items() if k not in keys})

    def difference_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get the difference of the stack by the values.

        Args:
            values: The values to exclude.
            by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
        elif by == 'value':
            values = _as_container(values)
            return type(self)({k: v for k, v in self.items() if v not in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def intersection_by_keys(self, keys: Iterable):
        """
        Get the intersection of the stack by the keys.
        """
        keys = _as_container(keys)
        return type(self)({k: v for k, v in self.items() if k in keys})

    def intersection_by_values(self, values: Iterable, by: str = 'id'):
        """
        Get the intersection of the stack by the values.

        Args:
            values: The values to keep.
            by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
        """
        if by == 'id':
            value_ids = {id(v) for v in values}
            return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
        elif by == 'value':
            values = _as_container(values)
            return type(self)({k: v for k, v in self.items() if v in values})
        else:
            raise ValueError(f'Unsupported method: {by}')

    def union_by_value_ids(self, other: dict):
        """
        Union the stack by the value ids.

        Args:
            other: The mapping whose elements are merged into a copy of this stack.

        Returns:
            A new stack with every element of ``self`` plus each element of
            ``other`` whose value object (compared by ``id``) is not already
            present. As with ``__add__``, such an element replaces an existing
            entry that has the same key.
        """
        result = type(self)(self)
        seen = {id(v) for v in self.values()}
        for k, v in other.items():
            if id(v) not in seen:
                seen.add(id(v))
                result[k] = v
        return result

    def __add__(self, other: dict):
        """
        Compose other instance of dict.
        """
        new_dict = type(self)(self)
        new_dict.update(other)
        return new_dict

    def tree_flatten(self):
        # Children are the values; the keys are static auxiliary data.
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        values = tuple(values)
        if len(keys) != len(values):
            raise ValueError(f'Expected {len(keys)} values for keys {keys}, got {len(values)}.')
        return cls(zip(keys, values))

    def _check_elem(self, elem: Any):
        # Subclasses validate the element type here.
        raise NotImplementedError

    def to_dict(self):
        """
        Convert the stack to a dict.

        Returns
        -------
        dict
          The dict object.
        """
        return dict(self)

    def __copy__(self):
        return type(self)(self)
|
294
|
+
|
295
|
+
|
296
|
+
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: str = None,
    array: bool = True,
    compilation: bool = False,
):
    """Clear all on-device buffers.

    This function will be very useful when you call models in a Python loop,
    because it can clear all cached arrays, and clear device memory.

    .. warning::

       This operation may cause errors when you use a deleted buffer.
       Therefore, regenerate data always.

    Parameters
    ----------
    platform: str
      The device platform whose memory is cleared. ``None`` selects JAX's
      default backend.
    array: bool
      Clear all buffer array. Default is True.
    compilation: bool
      Clear compilation cache. Default is False.

    """
    if array:
        live_arrays = getattr(jax, 'live_arrays', None)
        if live_arrays is not None:
            # Public API. Newer JAX releases no longer provide ``Client.live_buffers()``.
            buffers = live_arrays(platform)
        else:
            # Fallback for old JAX versions that predate ``jax.live_arrays``.
            buffers = xla_bridge.get_backend(platform).live_buffers()
        for buf in buffers:
            buf.delete()
    if compilation:
        jax.clear_caches()
    gc.collect()
|
328
|
+
|
329
|
+
|
330
|
+
class _DotDictMissingAttr(KeyError, AttributeError):
    """Raised when a :class:`DotDict` attribute lookup or deletion finds no such key.

    It is a ``KeyError``, which keeps the historical ``DotDict`` behavior. It is
    also an ``AttributeError``, so ``hasattr``, ``getattr(obj, name, default)``
    and libraries that probe for optional attributes work as expected.
    """


def _rebuild_sequence(seq, items):
    """Rebuild a list/tuple of the same type as ``seq`` from ``items``.

    Namedtuples take their fields positionally, so they are expanded with ``*``.
    """
    if isinstance(seq, tuple) and hasattr(seq, '_fields'):
        return type(seq)(*items)
    return type(seq)(items)


@jax.tree_util.register_pytree_node_class
class DotDict(dict):
    """Python dictionaries with advanced dot notation access.

    For example:

    >>> d = DotDict({'a': 10, 'b': 20})
    >>> d.a
    10
    >>> d['a']
    10
    >>> hasattr(d, 'c')  # ``d.c`` raises an error that is a KeyError and an AttributeError
    False
    >>> d.c = 30  # but you can assign a value to a non-existing item
    >>> d.c
    30

    Dicts passed to the constructor, including dicts nested in lists/tuples,
    are converted to ``DotDict`` recursively. Values assigned later with
    ``d.name = value`` or ``d[name] = value`` are stored as-is.
    """

    __module__ = 'brainstate.util'

    def __init__(self, *args, **kwargs):
        # Optional back-reference: while set, the first write into this DotDict
        # stores it into ``__parent[__key]`` (see ``__setitem__``).
        object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
        object.__setattr__(self, '__key', kwargs.pop('__key', None))
        for arg in args:
            if not arg:
                continue
            elif isinstance(arg, dict):
                for key, val in arg.items():
                    self[key] = self._hook(val)
            elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
                # A single ``(key, value)`` pair.
                self[arg[0]] = self._hook(arg[1])
            else:
                # An iterable of ``(key, value)`` pairs.
                for key, val in iter(arg):
                    self[key] = self._hook(val)

        for key, val in kwargs.items():
            self[key] = self._hook(val)

    def __setattr__(self, name, value):
        # Names defined on the class (e.g. ``items``, ``copy``) cannot be shadowed.
        if hasattr(self.__class__, name):
            raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
        else:
            self[name] = value

    def __setitem__(self, name, value):
        super(DotDict, self).__setitem__(name, value)
        try:
            p = object.__getattribute__(self, '__parent')
            key = object.__getattribute__(self, '__key')
        except AttributeError:
            # Instances created without ``__init__`` (e.g. by copy/pickle) have no link.
            p = None
            key = None
        if p is not None:
            # Attach to the parent on first write, then drop the back-reference.
            p[key] = self
            object.__delattr__(self, '__parent')
            object.__delattr__(self, '__key')

    @classmethod
    def _hook(cls, item):
        """Recursively convert dicts (also inside lists/tuples) into this class."""
        if isinstance(item, dict):
            return cls(item)
        elif isinstance(item, (list, tuple)):
            return _rebuild_sequence(item, [cls._hook(elem) for elem in item])
        return item

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails, i.e. for data keys.
        try:
            return self[item]
        except KeyError:
            raise _DotDictMissingAttr(item) from None

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise _DotDictMissingAttr(name) from None

    def copy(self):
        """Return a shallow copy."""
        return copy.copy(self)

    def deepcopy(self):
        """Return a deep copy."""
        return copy.deepcopy(self)

    def __deepcopy__(self, memo):
        other = self.__class__()
        memo[id(self)] = other
        for key, value in self.items():
            other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return other

    def to_dict(self):
        """Convert to plain ``dict``s, recursing into nested DotDicts and list/tuple items."""
        base = {}
        for key, value in self.items():
            if isinstance(value, type(self)):
                base[key] = value.to_dict()
            elif isinstance(value, (list, tuple)):
                base[key] = _rebuild_sequence(
                    value,
                    [item.to_dict() if isinstance(item, type(self)) else item for item in value],
                )
            else:
                base[key] = value
        return base

    def update(self, *args, **kwargs):
        """Recursively update: nested dicts present on both sides are merged, not replaced."""
        other = {}
        if args:
            if len(args) > 1:
                raise TypeError(f'update expected at most 1 positional argument, got {len(args)}')
            other.update(args[0])
        other.update(kwargs)
        for k, v in other.items():
            if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
                self[k] = v
            else:
                self[k].update(v)

    def __getnewargs__(self):
        return tuple(self.items())

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)

    def setdefault(self, key, default=None):
        if key in self:
            return self[key]
        else:
            self[key] = default
            return default

    def tree_flatten(self):
        # Children are the values; the keys are static auxiliary data.
        return tuple(self.values()), tuple(self.keys())

    @classmethod
    def tree_unflatten(cls, keys, values):
        values = tuple(values)
        if len(keys) != len(values):
            raise ValueError(f'Expected {len(keys)} values for keys {keys}, got {len(values)}.')
        return cls(zip(keys, values))
|
461
|
+
|
462
|
+
|
463
|
+
def _is_not_instance(x, cls):
    """Return ``True`` when ``x`` is *not* an instance of ``cls`` (a class or tuple of classes)."""
    matches = isinstance(x, cls)
    return not matches
|
465
|
+
|
466
|
+
|
467
|
+
def _is_instance(x, cls):
    """Return ``True`` when ``x`` is an instance of ``cls`` (a class or tuple of classes)."""
    matched = isinstance(x, cls)
    return matched
|
469
|
+
|
470
|
+
|
471
|
+
@set_module_as('brainstate.util')
def not_instance_eval(*cls):
    """
    Build a predicate that is true for values that are NOT instances of any of ``cls``.

    Args:
      *cls: The classes to check against.

    Returns:
      A ``functools.partial`` of ``_is_not_instance`` with ``cls`` bound as the
      tuple of the given classes. Call it with the value to test.
    """
    classes = cls
    predicate = functools.partial(_is_not_instance, cls=classes)
    return predicate
|
484
|
+
|
485
|
+
|
486
|
+
@set_module_as('brainstate.util')
def is_instance_eval(*cls):
    """
    Build a predicate that is true for values that are instances of any of ``cls``.

    Args:
      *cls: The classes to check against.

    Returns:
      A ``functools.partial`` of ``_is_instance`` with ``cls`` bound as the
      tuple of the given classes. Call it with the value to test.
    """
    classes = cls
    predicate = functools.partial(_is_instance, cls=classes)
    return predicate
|
@@ -0,0 +1,208 @@
|
|
1
|
+
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
+
# The credit should go to the Flax authors.
|
3
|
+
#
|
4
|
+
# Copyright 2024 The Flax Authors & 2024 BDP Ecosystem.
|
5
|
+
#
|
6
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
+
# you may not use this file except in compliance with the License.
|
8
|
+
# You may obtain a copy of the License at
|
9
|
+
#
|
10
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
+
#
|
12
|
+
# Unless required by applicable law or agreed to in writing, software
|
13
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
+
# See the License for the specific language governing permissions and
|
16
|
+
# limitations under the License.
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
import dataclasses
|
21
|
+
import threading
|
22
|
+
from abc import ABC, abstractmethod
|
23
|
+
from functools import partial
|
24
|
+
from typing import Any, Iterator, Mapping, TypeVar, Union, Callable, Optional
|
25
|
+
|
26
|
+
# Public names of this module: the names exported by ``from ... import *``.
__all__ = [
    'PrettyType',
    'PrettyAttr',
    'PrettyRepr',
    'PrettyMapping',
    'MappingReprMixin',
]

# Generic key (A) and value (B) type variables; used to parameterize ``MappingReprMixin``.
A = TypeVar('A')
B = TypeVar('B')
|
36
|
+
|
37
|
+
|
38
|
+
@dataclasses.dataclass
class PrettyType:
    """
    Configuration for pretty representation of objects.

    This is the header item that ``__pretty_repr__`` must yield first.
    ``get_repr`` renders the object as ``<type><start><attributes><end>``.

    Attributes:
        type: The displayed type name; a class is rendered by its ``__name__``.
        start: Text placed right after the type name.
        end: Text closing the representation.
        value_sep: Separator placed between each attribute key and its value.
        elem_indent: Prefix of every attribute line; it is also inserted after
            each newline inside a multi-line attribute value.
        empty_repr: Text placed between ``start`` and ``end`` when no
            attributes are yielded.
    """
    type: Union[str, type]
    start: str = '('
    end: str = ')'
    value_sep: str = '='
    elem_indent: str = ' '
    empty_repr: str = ''
|
49
|
+
|
50
|
+
|
51
|
+
@dataclasses.dataclass
class PrettyAttr:
    """
    Configuration for pretty representation of attributes.

    Attributes:
        key: The attribute name as displayed.
        value: The attribute value. A ``str`` is inserted verbatim; anything
            else is rendered with ``repr``.
        start: Text placed before the key.
        end: Text placed after the value.
    """
    key: str
    value: Union[str, Any]
    start: str = ''
    end: str = ''
|
60
|
+
|
61
|
+
|
62
|
+
class PrettyRepr(ABC):
    """
    Interface for pretty representation of objects.

    Subclasses implement ``__pretty_repr__`` as a generator. It first yields a
    :class:`PrettyType` describing the object, then one :class:`PrettyAttr`
    per attribute to display. ``__repr__`` is then provided by :func:`get_repr`.

    Example::

        >>> class MyObject(PrettyRepr):
        ...     def __init__(self, key, value):
        ...         self.key = key
        ...         self.value = value
        ...
        ...     def __pretty_repr__(self):
        ...         yield PrettyType(type='MyObject', start='{', end='}')
        ...         yield PrettyAttr('key', self.key)
        ...         yield PrettyAttr('value', self.value)
    """
    __slots__ = ()

    @abstractmethod
    def __pretty_repr__(self) -> Iterator[Union[PrettyType, PrettyAttr]]:
        """Yield a ``PrettyType`` followed by the ``PrettyAttr`` items to render."""
        raise NotImplementedError

    def __repr__(self) -> str:
        # repr the individual object with the pretty representation
        return get_repr(self)
|
84
|
+
|
85
|
+
|
86
|
+
def _repr_elem(obj: PrettyType, elem: Any) -> str:
    """Render one attribute line of a pretty representation.

    Args:
        obj: The header ``PrettyType``; supplies ``elem_indent`` and ``value_sep``.
        elem: Must be a ``PrettyAttr``. A string value is inserted verbatim;
            any other value is rendered with ``repr``.

    Returns:
        ``<elem_indent><start><key><value_sep><value><end>``. Every newline inside
        the value is followed by ``elem_indent`` so nested reprs stay aligned.

    Raises:
        TypeError: if ``elem`` is not a ``PrettyAttr``.
    """
    if not isinstance(elem, PrettyAttr):
        raise TypeError(f'Item must be PrettyAttr, got {type(elem).__name__}')

    value = elem.value if isinstance(elem.value, str) else repr(elem.value)
    value = value.replace('\n', '\n' + obj.elem_indent)

    return f'{obj.elem_indent}{elem.start}{elem.key}{obj.value_sep}{value}{elem.end}'
|
94
|
+
|
95
|
+
|
96
|
+
def get_repr(obj: PrettyRepr) -> str:
    """
    Get the pretty representation of an object.

    ``obj.__pretty_repr__()`` must produce a ``PrettyType`` first, followed by
    zero or more ``PrettyAttr`` items. Attribute lines are joined with ``','``
    plus a newline and surrounded by newlines. With no attributes,
    ``empty_repr`` is used instead.

    Raises:
        TypeError: if ``obj`` is not a ``PrettyRepr``, if the first produced item
            is not a ``PrettyType`` (including when nothing is produced), or if
            a later item is not a ``PrettyAttr``.
    """
    if not isinstance(obj, PrettyRepr):
        raise TypeError(f'Object {obj!r} is not representable')

    # ``iter`` also accepts producers that return a list/tuple instead of a generator.
    iterator = iter(obj.__pretty_repr__())
    try:
        # repr object; a missing first item becomes a TypeError instead of a bare StopIteration
        obj_repr = next(iterator, None)
        if not isinstance(obj_repr, PrettyType):
            raise TypeError(f'First item must be PrettyType, got {type(obj_repr).__name__}')

        # repr attributes
        elem_reprs = tuple(map(partial(_repr_elem, obj_repr), iterator))
    finally:
        # Close generator-based producers even on error, so their ``finally``
        # clauses run now rather than whenever the generator is garbage-collected.
        close = getattr(iterator, 'close', None)
        if close is not None:
            close()

    elems = ',\n'.join(elem_reprs)
    if elems:
        elems = '\n' + elems + '\n'
    else:
        elems = obj_repr.empty_repr

    # repr object type
    type_repr = obj_repr.type if isinstance(obj_repr.type, str) else obj_repr.type.__name__

    # return repr
    return f'{type_repr}{obj_repr.start}{elems}{obj_repr.end}'
|
123
|
+
|
124
|
+
|
125
|
+
class MappingReprMixin(Mapping[A, B]):
    """
    Mixin providing ``__pretty_repr__`` that renders a mapping as ``{key: value, ...}``.

    Keys are displayed with ``repr``. The mixin does not define ``__repr__``
    itself.
    """

    def __pretty_repr__(self):
        yield PrettyType(type='', value_sep=': ', start='{', end='}')
        yield from (PrettyAttr(repr(key), value) for key, value in self.items())
|
135
|
+
|
136
|
+
|
137
|
+
@dataclasses.dataclass(repr=False)
class PrettyMapping(PrettyRepr):
    """
    Wrapper that pretty-prints an arbitrary mapping as ``{key: value, ...}``.

    ``repr=False`` stops the dataclass from generating ``__repr__``, so the one
    inherited from ``PrettyRepr`` is used.
    """
    mapping: Mapping

    def __pretty_repr__(self):
        header = PrettyType(type='', value_sep=': ', start='{', end='}')
        yield header
        entries = self.mapping.items()
        for k, v in entries:
            yield PrettyAttr(repr(k), v)
|
149
|
+
|
150
|
+
|
151
|
+
@dataclasses.dataclass
class PrettyReprContext(threading.local):
    """Per-thread state shared by nested ``pretty_repr_avoid_duplicate`` calls."""
    # ``None`` when no top-level pretty repr is running on this thread. Otherwise
    # it maps ``id(node)`` to the node itself for every node already rendered in
    # the current top-level repr. Holding the node rather than just its id
    # presumably keeps it alive so the id cannot be reused mid-repr.
    seen_modules_repr: dict[int, Any] | None = None


# Module-wide instance; being a ``threading.local``, each thread sees its own ``seen_modules_repr``.
CONTEXT = PrettyReprContext()
|
158
|
+
|
159
|
+
|
160
|
+
def _default_repr_object(node):
    """Yield the default header: the node's class with the default ``PrettyType`` settings."""
    header = PrettyType(type=type(node))
    yield header
|
162
|
+
|
163
|
+
|
164
|
+
def _default_repr_attr(node):
    """Yield a ``PrettyAttr`` (value pre-rendered with ``repr``) for each public entry of ``vars(node)``."""
    public_items = ((attr, val) for attr, val in vars(node).items() if not attr.startswith('_'))
    for attr_name, attr_value in public_items:
        yield PrettyAttr(attr_name, repr(attr_value))
|
169
|
+
|
170
|
+
|
171
|
+
def pretty_repr_avoid_duplicate(
    node,
    repr_object: Optional[Callable] = None,
    repr_attr: Optional[Callable] = None
):
    """
    Pretty representation of an object avoiding duplicate representations.

    A generator suitable as a ``__pretty_repr__`` implementation. The outermost
    call on a thread creates the per-thread ``CONTEXT.seen_modules_repr`` dict.
    Every node is recorded there after its header is produced. A node that is
    met again during the same top-level repr is rendered as ``<Type>(...)``,
    which also stops infinite recursion on cycles.

    Args:
        node: The object to represent.
        repr_object: Generator yielding the ``PrettyType`` header for ``node``.
            Defaults to the node's class.
        repr_attr: Generator yielding the ``PrettyAttr`` items for ``node``.
            Defaults to the public entries of ``vars(node)``.
    """
    if repr_object is None:
        repr_object = _default_repr_object
    if repr_attr is None:
        repr_attr = _default_repr_attr

    clear_seen = CONTEXT.seen_modules_repr is None
    if clear_seen:
        CONTEXT.seen_modules_repr = {}

    # Everything after installing the shared state is inside ``try``. If the
    # consumer abandons this generator after the header, or ``repr_object``
    # raises, the per-thread state is still reset; otherwise it would stay set
    # and leak into every later repr on this thread.
    try:
        # Avoid infinite recursion
        if id(node) in CONTEXT.seen_modules_repr:
            yield PrettyType(type=type(node), empty_repr='...')
            return

        # repr object
        yield from repr_object(node)

        # Add to seen modules; the node is stored as well as its id.
        CONTEXT.seen_modules_repr[id(node)] = node

        # repr attributes
        yield from repr_attr(node)
    finally:
        if clear_seen:
            CONTEXT.seen_modules_repr = None
|