brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/util.py
DELETED
@@ -1,746 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import copy
|
17
|
-
import functools
|
18
|
-
import gc
|
19
|
-
import types
|
20
|
-
import warnings
|
21
|
-
from collections.abc import Iterable
|
22
|
-
from typing import Any, Callable, Tuple, Union, Sequence
|
23
|
-
|
24
|
-
import jax
|
25
|
-
from jax.lib import xla_bridge
|
26
|
-
|
27
|
-
from ._utils import set_module_as
|
28
|
-
|
29
|
-
# Names exported by ``from brainstate.util import *``.
__all__ = [
  # naming helpers
  'unique_name',
  # device memory
  'clear_buffer_memory',
  # predicate factories
  'not_instance_eval',
  'is_instance_eval',
  # containers and scaling objects
  'DictManager',
  'MemScaling',
  'IdMemScaling',
  'DotDict',
]

# Explicitly chosen names -> ``id`` of the object that claimed them.
_name2id = {}
# Per-type counters used to generate default names ('<TypeName><counter>').
_typed_names = {}
|
42
|
-
|
43
|
-
|
44
|
-
@set_module_as('brainstate.util')
def check_name_uniqueness(name, obj):
  """Check that ``name`` is a valid identifier not already claimed by another object.

  The first time a name is seen it is registered for ``obj`` (stored as
  ``id(obj)`` in the module-level ``_name2id`` cache). Later calls with the
  same name succeed only when they pass the same object.

  Note: ownership is tracked by ``id(obj)``; Python may reuse an id once the
  original object has been garbage collected.

  Args:
    name: str. The candidate name.
    obj: The object that wants to own ``name``.

  Raises:
    ValueError: If ``name`` is not a valid Python identifier, or if it is
      already registered to a different object.
  """
  if not name.isidentifier():
    raise ValueError(
      f'"{name}" isn\'t a valid identifier '
      'according to Python language definition. '
      'Please choose another name.'
    )
  if name in _name2id:
    if _name2id[name] != id(obj):
      # Point users at the cache-reset helper that actually exists in this module.
      raise ValueError(
        'In brainstate, each object should have a unique name. '
        f'However, we detect that {obj} has a used name "{name}". \n'
        'If you try to run multiple trials, you may need \n\n'
        '>>> brainstate.util.clear_name_cache() \n\n'
        'to clear all cached names. '
      )
  else:
    _name2id[name] = id(obj)
|
64
|
-
|
65
|
-
|
66
|
-
def get_unique_name(type_: str):
  """Return the next auto-generated name for objects of kind ``type_``.

  The name is ``type_`` followed by a per-type counter. The counter starts at 0
  and advances on every call, e.g. ``'Foo0'``, ``'Foo1'``, ...
  """
  index = _typed_names.setdefault(type_, 0)
  _typed_names[type_] = index + 1
  return f'{type_}{index}'
|
73
|
-
|
74
|
-
|
75
|
-
@set_module_as('brainstate.util')
def unique_name(name=None, self=None):
  """Get the unique name for this object.

  Parameters
  ----------
  name : str, optional
    The expected name. If None, a default unique name is generated from the
    class name of ``self``. Otherwise, the provided name is checked (and
    registered) by ``check_name_uniqueness`` to guarantee its uniqueness.
  self : object, optional
    The object being named. Required when ``name`` is None, because
    ``self.__class__.__name__`` is used as the prefix of the generated name.

  Returns
  -------
  name : str
    The unique name for this object.

  Raises
  ------
  ValueError
    If ``name`` is None and ``self`` is not provided, or (from
    ``check_name_uniqueness``) if ``name`` is invalid or already taken.
  """
  if name is not None:
    check_name_uniqueness(name=name, obj=self)
    return name
  if self is None:
    # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``,
    # which would silently produce names like 'NoneType0'.
    raise ValueError('If name is None, self should be provided.')
  return get_unique_name(type_=self.__class__.__name__)
|
98
|
-
|
99
|
-
|
100
|
-
@set_module_as('brainstate.util')
def clear_name_cache(ignore_warn: bool = True):
  """Forget every registered name and reset all auto-naming counters.

  Args:
    ignore_warn: bool. When False, a ``UserWarning`` is emitted after clearing.
  """
  for registry in (_name2id, _typed_names):
    registry.clear()
  if ignore_warn:
    return
  warnings.warn('All named models and their ids are cleared.', UserWarning)
|
107
|
-
|
108
|
-
|
109
|
-
def _as_lookup(items: Iterable):
  """Return ``items`` in a form that supports repeated ``in`` tests.

  Containers (lists, tuples, sets, dict views, ...) are returned unchanged so
  their own membership semantics are kept. Anything else, notably one-shot
  iterators such as generators (which the first ``in`` test would partially
  consume), is materialized into a tuple.
  """
  import collections.abc
  if isinstance(items, collections.abc.Container):
    return items
  return tuple(items)


@jax.tree_util.register_pytree_node_class
class DictManager(dict):
  """
  DictManager, for collecting all pytree used in the program.

  :py:class:`~.DictManager` supports all features of python dict. It adds
  helpers to filter, split and combine entries by key, by value type, by
  predicate, or by value identity. It is registered as a JAX pytree whose
  leaves are the dict values and whose auxiliary data are the keys.
  """
  __module__ = 'brainstate.util'

  def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
    """
    Get a new stack with only the entries whose value matches ``sep``.

    Args:
      sep: Either a type / tuple of types (an entry is kept when
        ``isinstance(value, sep)``), or a predicate (an entry is kept when
        ``sep(value)`` is truthy). Functions, lambdas, builtins, bound methods
        and ``functools.partial`` objects (such as those returned by
        ``is_instance_eval`` / ``not_instance_eval``) are treated as predicates.

    Returns:
      A new instance of the same type.
    """
    gather = type(self)()
    # Classes are callable too, so ``callable(sep)`` alone would misroute type
    # filters; only function-like objects are treated as predicates.
    if isinstance(sep, (types.FunctionType, types.BuiltinFunctionType,
                        types.MethodType, functools.partial)):
      for k, v in self.items():
        if sep(v):
          gather[k] = v
    else:
      for k, v in self.items():
        if isinstance(v, sep):
          gather[k] = v
    return gather

  def not_subset(self, sep: Union[type, Tuple[type, ...]]) -> 'DictManager':
    """
    Get a new stack with the entries whose value is NOT an instance of ``sep``.
    """
    gather = type(self)()
    for k, v in self.items():
      if not isinstance(v, sep):
        gather[k] = v
    return gather

  def add_unique_elem(self, key: Any, var: Any):
    """Add a new element.

    The element is first validated by ``_check_elem`` (which raises
    ``NotImplementedError`` in this base class; subclasses are expected to
    override it).

    Raises:
      ValueError: If ``key`` is already registered with a different object
        (compared by identity).
    """
    self._check_elem(var)
    if key in self:
      if id(var) != id(self[key]):
        raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
    else:
      self[key] = var

  def unique(self) -> 'DictManager':
    """
    Get a new type of collections with unique values.

    If one value is assigned to two or more keys,
    then only one pair of (key, value) will be returned.
    """
    gather = type(self)()
    seen = set()
    for k, v in self.items():
      if id(v) not in seen:
        seen.add(id(v))
        gather[k] = v
    return gather

  def assign(self, *args) -> None:
    """
    Assign the value for each element according to the given ``data``.
    """
    for arg in args:
      assert isinstance(arg, dict), 'Must be an instance of dict.'
      for k, v in arg.items():
        self[k] = v

  def split(self, first: type, *others: type) -> Tuple['DictManager', ...]:
    """
    Split the stack into subsets of stack by the given types.

    Each entry goes to the first filter it is an instance of. Entries matching
    no filter go to the extra, last result.
    """
    filters = (first, *others)
    results = tuple(type(self)() for _ in range(len(filters) + 1))
    for k, v in self.items():
      for i, filt in enumerate(filters):
        if isinstance(v, filt):
          results[i][k] = v
          break
      else:
        results[-1][k] = v
    return results

  def pop_by_keys(self, keys: Iterable):
    """
    Pop the elements by the keys.
    """
    keys = _as_lookup(keys)  # a generator would be consumed by repeated ``in``
    for k in tuple(self.keys()):
      if k in keys:
        self.pop(k)

  def pop_by_values(self, values: Iterable, by: str = 'id'):
    """
    Pop the elements by the values.

    Args:
      values: The values to remove.
      by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
    """
    if by == 'id':
      value_ids = {id(v) for v in values}
      for k in tuple(self.keys()):
        if id(self[k]) in value_ids:
          self.pop(k)
    elif by == 'value':
      values = _as_lookup(values)
      for k in tuple(self.keys()):
        if self[k] in values:
          self.pop(k)
    else:
      raise ValueError(f'Unsupported method: {by}')

  def difference_by_keys(self, keys: Iterable):
    """
    Get the difference of the stack by the keys.
    """
    keys = _as_lookup(keys)
    return type(self)({k: v for k, v in self.items() if k not in keys})

  def difference_by_values(self, values: Iterable, by: str = 'id'):
    """
    Get the difference of the stack by the values.

    Args:
      values: The values to exclude.
      by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
    """
    if by == 'id':
      value_ids = {id(v) for v in values}
      return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
    elif by == 'value':
      values = _as_lookup(values)
      return type(self)({k: v for k, v in self.items() if v not in values})
    else:
      raise ValueError(f'Unsupported method: {by}')

  def intersection_by_keys(self, keys: Iterable):
    """
    Get the intersection of the stack by the keys.
    """
    keys = _as_lookup(keys)
    return type(self)({k: v for k, v in self.items() if k in keys})

  def intersection_by_values(self, values: Iterable, by: str = 'id'):
    """
    Get the intersection of the stack by the values.

    Args:
      values: The values to keep.
      by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
    """
    if by == 'id':
      value_ids = {id(v) for v in values}
      return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
    elif by == 'value':
      values = _as_lookup(values)
      return type(self)({k: v for k, v in self.items() if v in values})
    else:
      raise ValueError(f'Unsupported method: {by}')

  def union_by_value_ids(self, other: dict):
    """
    Union the stack with ``other``, de-duplicating entries by value identity.

    Args:
      other: dict. The entries to merge in.

    Returns:
      A new instance holding every entry of ``self``, plus each entry of
      ``other`` whose value object (compared with ``id``) is not already
      present. As with ``__add__``, an added entry whose key already exists
      replaces the previous value under that key.
    """
    gather = type(self)(self)
    seen = {id(v) for v in gather.values()}
    for k, v in other.items():
      if id(v) not in seen:
        seen.add(id(v))
        gather[k] = v
    return gather

  def __add__(self, other: dict):
    """
    Compose other instance of dict.
    """
    new_dict = type(self)(self)
    new_dict.update(other)
    return new_dict

  def tree_flatten(self):
    # Leaves are the values; the keys travel as static auxiliary data.
    return tuple(self.values()), tuple(self.keys())

  @classmethod
  def tree_unflatten(cls, keys, values):
    # Pair keys with leaves strictly, without depending on ``jax.util`` helpers.
    if len(keys) != len(values):
      raise ValueError(f'Expected {len(keys)} values, but got {len(values)}.')
    return cls(list(zip(keys, values)))

  def _check_elem(self, elem: Any):
    # Validation hook used by ``add_unique_elem``; subclasses must override it.
    raise NotImplementedError

  def to_dict(self):
    """
    Convert the stack to a dict.

    Returns
    -------
    dict
      The dict object.
    """
    return dict(self)

  def __copy__(self):
    return type(self)(self)
|
306
|
-
|
307
|
-
|
308
|
-
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: Union[str, None] = None,
    array: bool = True,
    compilation: bool = False,
):
  """Clear all on-device buffers.

  This function will be very useful when you call models in a Python loop,
  because it can clear all cached arrays, and clear device memory.

  .. warning::

     This operation may cause errors when you use a deleted buffer.
     Therefore, regenerate data always.

  Parameters
  ----------
  platform: str, optional
    The platform whose live arrays are deleted. None means JAX's default backend.
  array: bool
    Clear all buffer array. Default is True.
  compilation: bool
    Clear compilation cache. Default is False.

  """
  if array:
    # ``jax.live_arrays`` is the public API for enumerating live device arrays;
    # it replaces the deprecated ``xla_bridge.get_backend(...).live_buffers()``.
    for buf in jax.live_arrays(platform):
      buf.delete()
  if compilation:
    jax.clear_caches()
  gc.collect()
|
340
|
-
|
341
|
-
|
342
|
-
class MemScaling(object):
  """
  Affine rescaling between a membrane-potential range and a standard range.

  The object stores a ``scale`` and a ``bias``.
  - The forward map (:meth:`scale_offset`) is ``(x + bias) / scale``.
  - The reverse map (:meth:`rev_scale_offset`) is ``x * scale - bias``.
  - :meth:`offset` / :meth:`rev_offset` add / subtract the bias alone.
  - :meth:`scale` / :meth:`rev_scale` divide / multiply by the scale alone.
  """
  __module__ = 'brainstate.util'

  def __init__(self, scale, bias):
    self._scale = scale
    self._bias = bias

  @classmethod
  def transform(
      cls,
      oring_range: Sequence[Union[float, int]],
      target_range: Sequence[Union[float, int]] = (0., 1.)
  ) -> 'MemScaling':
    """Build a scaling object that maps ``oring_range`` onto ``target_range``.

    Args:
      oring_range: ``[V_min, V_max]``, the original potential range.
      target_range: ``[scaled_V_min, scaled_V_max]``, the values that
        ``V_min`` / ``V_max`` are sent to by :meth:`scale_offset`.

    Returns:
      A new instance of ``cls``.
    """
    v_lo, v_hi = oring_range
    t_lo, t_hi = target_range
    factor = (v_hi - v_lo) / (t_hi - t_lo)
    shift = t_lo * factor - v_lo
    return cls(scale=factor, bias=shift)

  def scale_offset(self, x, bias=None, scale=None):
    """Map the potential ``x`` to the standard range, i.e. ``(x + bias) / scale``.

    ``bias`` and ``scale`` fall back to the stored values when left as None.
    """
    b = self._bias if bias is None else bias
    s = self._scale if scale is None else scale
    return (x + b) / s

  def scale(self, x, scale=None):
    """Divide ``x`` by ``scale`` (the stored scale when None)."""
    s = self._scale if scale is None else scale
    return x / s

  def offset(self, x, bias=None):
    """Add ``bias`` (the stored bias when None) to ``x``."""
    b = self._bias if bias is None else bias
    return x + b

  def rev_scale(self, x, scale=None):
    """Undo :meth:`scale`: multiply ``x`` by ``scale`` (the stored scale when None)."""
    s = self._scale if scale is None else scale
    return x * s

  def rev_offset(self, x, bias=None):
    """Undo :meth:`offset`: subtract ``bias`` (the stored bias when None) from ``x``."""
    b = self._bias if bias is None else bias
    return x - b

  def rev_scale_offset(self, x, bias=None, scale=None):
    """Map a standard-range value back to the potential, i.e. ``x * scale - bias``.

    ``bias`` and ``scale`` fall back to the stored values when left as None.
    """
    b = self._bias if bias is None else bias
    s = self._scale if scale is None else scale
    return x * s - b

  def clone(self):
    """Return a new :class:`MemScaling` with the same scale and bias."""
    return MemScaling(scale=self._scale, bias=self._bias)
|
517
|
-
|
518
|
-
|
519
|
-
class IdMemScaling(MemScaling):
  """
  Identity membrane-potential scaling.

  It is constructed as ``scale=1., bias=0.``. Every forward and reverse
  transform returns its input unchanged and ignores any ``bias`` / ``scale``
  override passed to it.
  """
  __module__ = 'brainstate.util'

  def __init__(self):
    super().__init__(scale=1., bias=0.)

  def scale_offset(self, x, bias=None, scale=None):
    """Identity forward map: return ``x`` unchanged."""
    return x

  def scale(self, x, scale=None):
    """Identity scaling: return ``x`` unchanged."""
    return x

  def offset(self, x, bias=None):
    """Identity offset: return ``x`` unchanged."""
    return x

  def rev_scale(self, x, scale=None):
    """Identity reverse scaling: return ``x`` unchanged."""
    return x

  def rev_offset(self, x, bias=None):
    """Identity reverse offset: return ``x`` unchanged."""
    return x

  def rev_scale_offset(self, x, bias=None, scale=None):
    """Identity reverse map: return ``x`` unchanged."""
    return x

  def clone(self):
    """Return a fresh :class:`IdMemScaling`."""
    return IdMemScaling()
|
577
|
-
|
578
|
-
|
579
|
-
class _DotDictMissingError(KeyError, AttributeError):
  """Raised when a missing key of a ``DotDict`` is read as an attribute.

  It subclasses ``KeyError`` so existing ``except KeyError`` handlers keep
  working. It also subclasses ``AttributeError`` so that ``hasattr`` and
  ``getattr(obj, name, default)`` behave normally instead of crashing.
  """


def _rebuild_sequence(seq, fn):
  """Rebuild the list/tuple ``seq`` with ``fn`` applied to each element, keeping its type.

  Named tuples take their fields positionally, so they are rebuilt with
  ``type(seq)(*elements)``. Every other list/tuple type is rebuilt from an
  iterable.
  """
  elements = [fn(elem) for elem in seq]
  if isinstance(seq, tuple) and hasattr(type(seq), '_fields'):
    return type(seq)(*elements)
  return type(seq)(elements)


@jax.tree_util.register_pytree_node_class
class DotDict(dict):
  """Python dictionaries with advanced dot notation access.

  Nested ``dict`` values, including those inside lists/tuples, are converted
  to ``DotDict`` at construction time.

  For example:

  >>> d = DotDict({'a': 10, 'b': 20})
  >>> d.a
  10
  >>> d['a']
  10
  >>> d.c  # raises an error that is both a KeyError and an AttributeError
  KeyError: 'c'
  >>> d.c = 30  # but you can assign a value to a non-existing item
  >>> d.c
  30
  """

  __module__ = 'brainstate.util'

  def __init__(self, *args, **kwargs):
    # Optional ``__parent`` / ``__key`` markers: on the first item assignment,
    # this object inserts itself as ``__parent[__key]`` (see ``__setitem__``).
    object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
    object.__setattr__(self, '__key', kwargs.pop('__key', None))
    for arg in args:
      if not arg:
        continue
      elif isinstance(arg, dict):
        for key, val in arg.items():
          self[key] = self._hook(val)
      elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
        # A single ``(key, value)`` pair.
        self[arg[0]] = self._hook(arg[1])
      else:
        # An iterable of ``(key, value)`` pairs.
        for key, val in iter(arg):
          self[key] = self._hook(val)

    for key, val in kwargs.items():
      self[key] = self._hook(val)

  def __setattr__(self, name, value):
    # Class attributes (methods such as ``keys``/``items``) cannot be shadowed.
    if hasattr(self.__class__, name):
      raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
    else:
      self[name] = value

  def __setitem__(self, name, value):
    super(DotDict, self).__setitem__(name, value)
    try:
      p = object.__getattribute__(self, '__parent')
      key = object.__getattribute__(self, '__key')
    except AttributeError:
      # The markers were already removed (or never set, e.g. after unpickling).
      p = None
      key = None
    if p is not None:
      p[key] = self
      object.__delattr__(self, '__parent')
      object.__delattr__(self, '__key')

  @classmethod
  def _hook(cls, item):
    # Recursively convert nested dicts (also inside lists/tuples) to ``cls``.
    if isinstance(item, dict):
      return cls(item)
    elif isinstance(item, (list, tuple)):
      return _rebuild_sequence(item, cls._hook)
    return item

  def __getattr__(self, item):
    # Only called when normal attribute lookup fails, i.e. for item keys.
    try:
      return self.__getitem__(item)
    except KeyError:
      raise _DotDictMissingError(item) from None

  def __delattr__(self, name):
    del self[name]

  def copy(self):
    return copy.copy(self)

  def deepcopy(self):
    return copy.deepcopy(self)

  def __deepcopy__(self, memo):
    other = self.__class__()
    memo[id(self)] = other
    for key, value in self.items():
      other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
    return other

  def to_dict(self):
    """Recursively convert to plain ``dict`` objects.

    Nested instances, including those held directly inside list/tuple values,
    are converted too; list/tuple (and namedtuple) types are preserved.
    """
    own_type = type(self)

    def convert(item):
      return item.to_dict() if isinstance(item, own_type) else item

    base = {}
    for key, value in self.items():
      if isinstance(value, own_type):
        base[key] = value.to_dict()
      elif isinstance(value, (list, tuple)):
        base[key] = _rebuild_sequence(value, convert)
      else:
        base[key] = value
    return base

  def update(self, *args, **kwargs):
    """Update in place, merging recursively where both old and new values are dicts."""
    other = {}
    if args:
      if len(args) > 1:
        raise TypeError(f'update expected at most 1 positional argument, got {len(args)}')
      other.update(args[0])
    other.update(kwargs)
    for k, v in other.items():
      if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
        self[k] = v
      else:
        self[k].update(v)

  def __getnewargs__(self):
    return tuple(self.items())

  def __getstate__(self):
    return self

  def __setstate__(self, state):
    self.update(state)

  def setdefault(self, key, default=None):
    if key in self:
      return self[key]
    else:
      self[key] = default
      return default

  def tree_flatten(self):
    # Leaves are the values; the keys travel as static auxiliary data.
    return tuple(self.values()), tuple(self.keys())

  @classmethod
  def tree_unflatten(cls, keys, values):
    # Pair keys with leaves strictly, without depending on ``jax.util`` helpers.
    if len(keys) != len(values):
      raise ValueError(f'Expected {len(keys)} values, but got {len(values)}.')
    return cls(list(zip(keys, values)))
|
710
|
-
|
711
|
-
|
712
|
-
def _is_not_instance(x, cls):
  # Negated ``isinstance``: True exactly when ``x`` is not an instance of ``cls``.
  if isinstance(x, cls):
    return False
  return True
|
714
|
-
|
715
|
-
|
716
|
-
def _is_instance(x, cls):
  # Plain-function wrapper around ``isinstance`` so it can be bound with ``functools.partial``.
  matched = isinstance(x, cls)
  return matched
|
718
|
-
|
719
|
-
|
720
|
-
@set_module_as('brainstate.util')
def not_instance_eval(*cls):
  """
  Build a predicate that is True for objects that are NOT instances of ``cls``.

  Args:
    *cls: The classes to check against (collected into a tuple).

  Returns:
    A ``functools.partial`` of ``_is_not_instance`` with the tuple of classes
    bound as the ``cls`` keyword argument.
  """
  predicate = functools.partial(_is_not_instance, cls=cls)
  return predicate
|
733
|
-
|
734
|
-
|
735
|
-
@set_module_as('brainstate.util')
def is_instance_eval(*cls):
  """
  Build a predicate that is True for objects that are instances of ``cls``.

  Args:
    *cls: The classes to check against (collected into a tuple).

  Returns:
    A ``functools.partial`` of ``_is_instance`` with the tuple of classes bound
    as the ``cls`` keyword argument.
  """
  predicate = functools.partial(_is_instance, cls=cls)
  return predicate
|