brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/transform.py
DELETED
@@ -1,23 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# alias for compilation and augmentation functions
|
17
|
-
|
18
|
-
from .augment import *
|
19
|
-
from .compile import *
|
20
|
-
|
21
|
-
if __name__ == '__main__':
    # NOTE(review): these bare name references have no runtime effect. They
    # appear to exist only so that IDEs/linters treat ``ifelse`` and ``grad``
    # (brought in by the star imports of this module) as used — confirm intent.
    ifelse
    grad
|
brainstate/util/caller.py
DELETED
@@ -1,98 +0,0 @@
|
|
1
|
-
# The file is adapted from the Flax library (https://github.com/google/flax).
|
2
|
-
# The credit should go to the Flax authors.
|
3
|
-
#
|
4
|
-
# Copyright 2024 The Flax Authors.
|
5
|
-
#
|
6
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
7
|
-
# you may not use this file except in compliance with the License.
|
8
|
-
# You may obtain a copy of the License at
|
9
|
-
#
|
10
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
11
|
-
#
|
12
|
-
# Unless required by applicable law or agreed to in writing, software
|
13
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
14
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15
|
-
# See the License for the specific language governing permissions and
|
16
|
-
# limitations under the License.
|
17
|
-
|
18
|
-
import dataclasses
|
19
|
-
from typing import Any, TypeVar, Protocol, Generic
|
20
|
-
|
21
|
-
import jax
|
22
|
-
|
23
|
-
# Public names exported by ``from ... import *``.
__all__ = [
    'DelayedAccessor',
    'CallableProxy',
    'ApplyCaller',
]

# Covariant type variable for the second element of ``ApplyCaller.__call__``'s
# return tuple.
A = TypeVar('A', covariant=True)  # type: ignore[not-supported-yet]
|
30
|
-
|
31
|
-
|
32
|
-
def _identity(x):
|
33
|
-
return x
|
34
|
-
|
35
|
-
|
36
|
-
@dataclasses.dataclass(frozen=True)
|
37
|
-
class GetItem:
|
38
|
-
key: Any
|
39
|
-
|
40
|
-
|
41
|
-
@dataclasses.dataclass(frozen=True)
|
42
|
-
class GetAttr:
|
43
|
-
name: str
|
44
|
-
|
45
|
-
|
46
|
-
@dataclasses.dataclass(frozen=True)
|
47
|
-
class DelayedAccessor:
|
48
|
-
actions: tuple[GetItem | GetAttr, ...] = ()
|
49
|
-
|
50
|
-
def __call__(self, x):
|
51
|
-
for action in self.actions:
|
52
|
-
if isinstance(action, GetItem):
|
53
|
-
x = x[action.key]
|
54
|
-
elif isinstance(action, GetAttr):
|
55
|
-
x = getattr(x, action.name)
|
56
|
-
return x
|
57
|
-
|
58
|
-
def __getattr__(self, name):
|
59
|
-
return DelayedAccessor(self.actions + (GetAttr(name),))
|
60
|
-
|
61
|
-
def __getitem__(self, key):
|
62
|
-
return DelayedAccessor(self.actions + (GetItem(key),))
|
63
|
-
|
64
|
-
|
65
|
-
jax.tree_util.register_static(DelayedAccessor)
|
66
|
-
|
67
|
-
|
68
|
-
class _AccessorCall(Protocol):
    """Call signature expected for functions wrapped by ``CallableProxy``.

    The accumulated accessor is passed first (positional-only), followed by
    whatever positional and keyword arguments the proxy was called with.
    """

    def __call__(self, accessor: DelayedAccessor, /, *args, **kwargs) -> Any:
        """Run the wrapped operation for ``accessor`` with the given arguments."""
|
71
|
-
|
72
|
-
|
73
|
-
class CallableProxy:
    """Callable that accumulates an accessor path before being called.

    ``proxy.a.b[0](*args, **kwargs)`` extends the stored accessor with
    ``.a``, ``.b`` and ``[0]`` (each step returns a new proxy) and finally
    calls ``fun(accessor, *args, **kwargs)``.

    Args:
        fun: Callable invoked as ``fun(accessor, *args, **kwargs)``.
        accessor: Starting accessor; ``DelayedAccessor()`` when omitted.
    """

    def __init__(
        self, fun: _AccessorCall, accessor: DelayedAccessor | None = None
    ):
        self._callable = fun
        self._accessor = DelayedAccessor() if accessor is None else accessor

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with the accumulated accessor prepended."""
        return self._callable(self._accessor, *args, **kwargs)

    def __getattr__(self, name) -> 'CallableProxy':
        """Return a new proxy whose accessor additionally looks up ``.name``.

        Raises:
            AttributeError: For dunder names, and for the proxy's own
                ``_callable``/``_accessor`` slots when they are not set.
        """
        # ``__getattr__`` only runs for names not found normally. Reaching it
        # for ``_callable``/``_accessor`` means the instance was created
        # without ``__init__`` (e.g. by ``copy.copy``/pickle); reading
        # ``self._accessor`` below would then recurse forever. Dunder names are
        # protocol probes (``__setstate__``, ``__deepcopy__``, ...) and must
        # not be turned into proxies.
        if name in ('_callable', '_accessor') or (
            name.startswith('__') and name.endswith('__')
        ):
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {name!r}'
            )
        return CallableProxy(self._callable, getattr(self._accessor, name))

    def __getitem__(self, key) -> 'CallableProxy':
        """Return a new proxy whose accessor additionally looks up ``[key]``."""
        return CallableProxy(self._callable, self._accessor[key])
|
88
|
-
|
89
|
-
|
90
|
-
class ApplyCaller(Protocol, Generic[A]):
    """Structural type for call proxies parameterized by ``A``.

    Attribute and item access are typed to return another ``ApplyCaller``;
    calling one is typed to return a ``(Any, A)`` tuple.
    """

    def __getattr__(self, __name) -> 'ApplyCaller[A]':
        """Typed attribute step; yields another ``ApplyCaller``."""

    def __getitem__(self, __name) -> 'ApplyCaller[A]':
        """Typed item step; yields another ``ApplyCaller``."""

    def __call__(self, *args, **kwargs) -> tuple[Any, A]:
        """Typed invocation returning a ``(result, A)`` pair."""
|
brainstate/util/others.py
DELETED
@@ -1,540 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import copy
|
17
|
-
import functools
|
18
|
-
import gc
|
19
|
-
import threading
|
20
|
-
import types
|
21
|
-
from collections.abc import Iterable
|
22
|
-
from typing import Any, Callable, Tuple, Union, Dict
|
23
|
-
|
24
|
-
import jax
|
25
|
-
from jax.lib import xla_bridge
|
26
|
-
|
27
|
-
from brainstate._utils import set_module_as
|
28
|
-
|
29
|
-
# Public names exported by ``from ... import *``.
__all__ = [
    'split_total',
    'clear_buffer_memory',
    'not_instance_eval',
    'is_instance_eval',
    'DictManager',
    'DotDict',
]
|
37
|
-
|
38
|
-
|
39
|
-
def split_total(
    total: int,
    fraction: Union[int, float],
) -> int:
    """
    Calculate the number of epochs for simulation based on a total and a fraction.

    This function determines the number of epochs to simulate given a total number
    of epochs and either a fraction or a specific number of epochs to run.

    Parameters:
    -----------
    total : int
        The total number of epochs. Must be a positive integer.
    fraction : Union[int, float]
        If integral (``int``, ``bool`` or e.g. a NumPy integer scalar): the specific
        number of epochs to run, which must not exceed the total.
        If a non-integral real number (``float``, NumPy float scalar, ...): a value
        between 0 and 1 representing the fraction of total epochs to run.

    Returns:
    --------
    int
        The calculated number of epochs to simulate.

    Raises:
    -------
    ValueError
        If total is not positive, fraction is negative, if fraction as a float
        is > 1 or as an integer is > total, or if fraction is not a real number.
    AssertionError
        If total is not an integer.
    """
    import numbers

    if not isinstance(total, int):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # the exception type is kept for backward compatibility.
        raise AssertionError("Length must be an integer.")
    if total <= 0:
        raise ValueError("'total' must be a positive integer.")

    # Validate the type before any comparison, so non-numeric input raises the
    # documented ValueError instead of a TypeError from ``fraction < 0``.
    # Integral is tested first because every Integral is also Real.
    if isinstance(fraction, numbers.Integral):
        if fraction < 0:
            raise ValueError("'fraction' value cannot be negative.")
        if fraction > total:
            raise ValueError("'fraction' value cannot be greater than total.")
        return int(fraction)

    if isinstance(fraction, numbers.Real):
        if fraction < 0:
            raise ValueError("'fraction' value cannot be negative.")
        if fraction > 1:
            raise ValueError("'fraction' value cannot be greater than 1.")
        return int(total * fraction)

    raise ValueError("'fraction' must be an integer or float.")
|
92
|
-
|
93
|
-
|
94
|
-
class NameContext(threading.local):
    """Thread-local store of per-prefix counters for generating names."""

    def __init__(self):
        # prefix -> next integer suffix to hand out (re-initialized per thread)
        self.typed_names: Dict[str, int] = {}


NAME = NameContext()


def get_unique_name(type_: str):
    """Return ``type_`` suffixed with the next counter value for that prefix."""
    counters = NAME.typed_names
    suffix = counters.get(type_, 0)
    counters[type_] = suffix + 1
    return f'{type_}{suffix}'
|
109
|
-
|
110
|
-
|
111
|
-
@jax.tree_util.register_pytree_node_class
|
112
|
-
class DictManager(dict):
|
113
|
-
"""
|
114
|
-
DictManager, for collecting all pytree used in the program.
|
115
|
-
|
116
|
-
:py:class:`~.DictManager` supports all features of python dict.
|
117
|
-
"""
|
118
|
-
__module__ = 'brainstate.util'
|
119
|
-
_val_id_to_key: dict
|
120
|
-
|
121
|
-
def subset(self, sep: Union[type, Tuple[type, ...], Callable]) -> 'DictManager':
|
122
|
-
"""
|
123
|
-
Get a new stack with the subset of keys.
|
124
|
-
"""
|
125
|
-
gather = type(self)()
|
126
|
-
if isinstance(sep, types.FunctionType):
|
127
|
-
for k, v in self.items():
|
128
|
-
if sep(v):
|
129
|
-
gather[k] = v
|
130
|
-
return gather
|
131
|
-
else:
|
132
|
-
for k, v in self.items():
|
133
|
-
if isinstance(v, sep):
|
134
|
-
gather[k] = v
|
135
|
-
return gather
|
136
|
-
|
137
|
-
def not_subset(self, sep: Union[type, Tuple[type, ...]]) -> 'DictManager':
|
138
|
-
"""
|
139
|
-
Get a new stack with the subset of keys.
|
140
|
-
"""
|
141
|
-
gather = type(self)()
|
142
|
-
for k, v in self.items():
|
143
|
-
if not isinstance(v, sep):
|
144
|
-
gather[k] = v
|
145
|
-
return gather
|
146
|
-
|
147
|
-
def add_unique_key(self, key: Any, val: Any):
|
148
|
-
"""
|
149
|
-
Add a new element and check if the value is same or not.
|
150
|
-
"""
|
151
|
-
self._check_elem(val)
|
152
|
-
if key in self:
|
153
|
-
if id(val) != id(self[key]):
|
154
|
-
raise ValueError(f'{key} has been registered by {self[key]}, the new value is different from it.')
|
155
|
-
else:
|
156
|
-
self[key] = val
|
157
|
-
|
158
|
-
def add_unique_value(self, key: Any, val: Any) -> bool:
|
159
|
-
"""
|
160
|
-
Add a new element and check if the val is unique.
|
161
|
-
|
162
|
-
Parameters:
|
163
|
-
key: The key of the element.
|
164
|
-
val: The value of the element
|
165
|
-
|
166
|
-
Returns:
|
167
|
-
bool: True if the value is unique, False otherwise.
|
168
|
-
"""
|
169
|
-
self._check_elem(val)
|
170
|
-
if not hasattr(self, '_val_id_to_key'):
|
171
|
-
self._val_id_to_key = {id(v): k for k, v in self.items()}
|
172
|
-
if id(val) not in self._val_id_to_key:
|
173
|
-
self._val_id_to_key[id(val)] = key
|
174
|
-
self[key] = val
|
175
|
-
return True
|
176
|
-
else:
|
177
|
-
return False
|
178
|
-
|
179
|
-
def unique(self) -> 'DictManager':
|
180
|
-
"""
|
181
|
-
Get a new type of collections with unique values.
|
182
|
-
|
183
|
-
If one value is assigned to two or more keys,
|
184
|
-
then only one pair of (key, value) will be returned.
|
185
|
-
"""
|
186
|
-
gather = type(self)()
|
187
|
-
seen = set()
|
188
|
-
for k, v in self.items():
|
189
|
-
if id(v) not in seen:
|
190
|
-
seen.add(id(v))
|
191
|
-
gather[k] = v
|
192
|
-
return gather
|
193
|
-
|
194
|
-
def unique_(self):
|
195
|
-
"""
|
196
|
-
Get a new type of collections with unique values.
|
197
|
-
|
198
|
-
If one value is assigned to two or more keys,
|
199
|
-
then only one pair of (key, value) will be returned.
|
200
|
-
"""
|
201
|
-
seen = set()
|
202
|
-
for k in tuple(self.keys()):
|
203
|
-
v = self[k]
|
204
|
-
if id(v) not in seen:
|
205
|
-
seen.add(id(v))
|
206
|
-
else:
|
207
|
-
self.pop(k)
|
208
|
-
return self
|
209
|
-
|
210
|
-
def assign(self, *args) -> None:
|
211
|
-
"""
|
212
|
-
Assign the value for each element according to the given ``data``.
|
213
|
-
"""
|
214
|
-
for arg in args:
|
215
|
-
assert isinstance(arg, dict), 'Must be an instance of dict.'
|
216
|
-
for k, v in arg.items():
|
217
|
-
self[k] = v
|
218
|
-
|
219
|
-
def split(self, first: type, *others: type) -> Tuple['DictManager', ...]:
|
220
|
-
"""
|
221
|
-
Split the stack into subsets of stack by the given types.
|
222
|
-
"""
|
223
|
-
filters = (first, *others)
|
224
|
-
results = tuple(type(self)() for _ in range(len(filters) + 1))
|
225
|
-
for k, v in self.items():
|
226
|
-
for i, filt in enumerate(filters):
|
227
|
-
if isinstance(v, filt):
|
228
|
-
results[i][k] = v
|
229
|
-
break
|
230
|
-
else:
|
231
|
-
results[-1][k] = v
|
232
|
-
return results
|
233
|
-
|
234
|
-
def pop_by_keys(self, keys: Iterable):
|
235
|
-
"""
|
236
|
-
Pop the elements by the keys.
|
237
|
-
"""
|
238
|
-
for k in tuple(self.keys()):
|
239
|
-
if k in keys:
|
240
|
-
self.pop(k)
|
241
|
-
|
242
|
-
def pop_by_values(self, values: Iterable, by: str = 'id'):
|
243
|
-
"""
|
244
|
-
Pop the elements by the values.
|
245
|
-
|
246
|
-
Args:
|
247
|
-
values: The value ids.
|
248
|
-
by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
|
249
|
-
"""
|
250
|
-
if by == 'id':
|
251
|
-
value_ids = {id(v) for v in values}
|
252
|
-
for k in tuple(self.keys()):
|
253
|
-
if id(self[k]) in value_ids:
|
254
|
-
self.pop(k)
|
255
|
-
elif by == 'value':
|
256
|
-
for k in tuple(self.keys()):
|
257
|
-
if self[k] in values:
|
258
|
-
self.pop(k)
|
259
|
-
else:
|
260
|
-
raise ValueError(f'Unsupported method: {by}')
|
261
|
-
|
262
|
-
def difference_by_keys(self, keys: Iterable):
|
263
|
-
"""
|
264
|
-
Get the difference of the stack by the keys.
|
265
|
-
"""
|
266
|
-
return type(self)({k: v for k, v in self.items() if k not in keys})
|
267
|
-
|
268
|
-
def difference_by_values(self, values: Iterable, by: str = 'id'):
|
269
|
-
"""
|
270
|
-
Get the difference of the stack by the values.
|
271
|
-
|
272
|
-
Args:
|
273
|
-
values: The value ids.
|
274
|
-
by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
|
275
|
-
"""
|
276
|
-
if by == 'id':
|
277
|
-
value_ids = {id(v) for v in values}
|
278
|
-
return type(self)({k: v for k, v in self.items() if id(v) not in value_ids})
|
279
|
-
elif by == 'value':
|
280
|
-
return type(self)({k: v for k, v in self.items() if v not in values})
|
281
|
-
else:
|
282
|
-
raise ValueError(f'Unsupported method: {by}')
|
283
|
-
|
284
|
-
def intersection_by_keys(self, keys: Iterable):
|
285
|
-
"""
|
286
|
-
Get the intersection of the stack by the keys.
|
287
|
-
"""
|
288
|
-
return type(self)({k: v for k, v in self.items() if k in keys})
|
289
|
-
|
290
|
-
def intersection_by_values(self, values: Iterable, by: str = 'id'):
|
291
|
-
"""
|
292
|
-
Get the intersection of the stack by the values.
|
293
|
-
|
294
|
-
Args:
|
295
|
-
values: The value ids.
|
296
|
-
by: str. The discard method, can be ``id`` or ``value``. Default is 'id'.
|
297
|
-
"""
|
298
|
-
if by == 'id':
|
299
|
-
value_ids = {id(v) for v in values}
|
300
|
-
return type(self)({k: v for k, v in self.items() if id(v) in value_ids})
|
301
|
-
elif by == 'value':
|
302
|
-
return type(self)({k: v for k, v in self.items() if v in values})
|
303
|
-
else:
|
304
|
-
raise ValueError(f'Unsupported method: {by}')
|
305
|
-
|
306
|
-
def __add__(self, other: dict):
|
307
|
-
"""
|
308
|
-
Compose other instance of dict.
|
309
|
-
"""
|
310
|
-
new_dict = type(self)(self)
|
311
|
-
new_dict.update(other)
|
312
|
-
return new_dict
|
313
|
-
|
314
|
-
def tree_flatten(self):
|
315
|
-
return tuple(self.values()), tuple(self.keys())
|
316
|
-
|
317
|
-
@classmethod
|
318
|
-
def tree_unflatten(cls, keys, values):
|
319
|
-
return cls(jax.util.safe_zip(keys, values))
|
320
|
-
|
321
|
-
def _check_elem(self, elem: Any):
|
322
|
-
raise NotImplementedError
|
323
|
-
|
324
|
-
def to_dict(self):
|
325
|
-
"""
|
326
|
-
Convert the stack to a dict.
|
327
|
-
|
328
|
-
Returns
|
329
|
-
-------
|
330
|
-
dict
|
331
|
-
The dict object.
|
332
|
-
"""
|
333
|
-
return dict(self)
|
334
|
-
|
335
|
-
def __copy__(self):
|
336
|
-
return type(self)(self)
|
337
|
-
|
338
|
-
|
339
|
-
@set_module_as('brainstate.util')
def clear_buffer_memory(
    platform: str = None,
    array: bool = True,
    compilation: bool = False,
):
    """Clear all on-device buffers.

    This function will be very useful when you call models in a Python loop,
    because it can clear all cached arrays, and clear device memory.

    .. warning::

       This operation may cause errors when you use a deleted buffer.
       Therefore, regenerate data always.

    Parameters
    ----------
    platform: str
      The device to clear its memory. ``None`` selects the default backend.
    array: bool
      Clear all buffer array. Default is True.
    compilation: bool
      Clear compilation cache. Default is False.

    """
    if array:
        # Prefer the public ``jax.live_arrays`` API. The backend-level
        # ``live_buffers()`` path is kept only as a fallback for old JAX
        # versions that lack ``jax.live_arrays``.
        live_arrays = getattr(jax, 'live_arrays', None)
        if live_arrays is not None:
            buffers = live_arrays(platform)
        else:
            buffers = xla_bridge.get_backend(platform).live_buffers()
        for buf in buffers:
            buf.delete()
    if compilation:
        jax.clear_caches()
    gc.collect()
|
371
|
-
|
372
|
-
|
373
|
-
@jax.tree_util.register_pytree_node_class
|
374
|
-
class DotDict(dict):
|
375
|
-
"""Python dictionaries with advanced dot notation access.
|
376
|
-
|
377
|
-
For example:
|
378
|
-
|
379
|
-
>>> d = DotDict({'a': 10, 'b': 20})
|
380
|
-
>>> d.a
|
381
|
-
10
|
382
|
-
>>> d['a']
|
383
|
-
10
|
384
|
-
>>> d.c # this will raise a KeyError
|
385
|
-
KeyError: 'c'
|
386
|
-
>>> d.c = 30 # but you can assign a value to a non-existing item
|
387
|
-
>>> d.c
|
388
|
-
30
|
389
|
-
"""
|
390
|
-
|
391
|
-
__module__ = 'brainstate.util'
|
392
|
-
|
393
|
-
def __init__(self, *args, **kwargs):
|
394
|
-
object.__setattr__(self, '__parent', kwargs.pop('__parent', None))
|
395
|
-
object.__setattr__(self, '__key', kwargs.pop('__key', None))
|
396
|
-
for arg in args:
|
397
|
-
if not arg:
|
398
|
-
continue
|
399
|
-
elif isinstance(arg, dict):
|
400
|
-
for key, val in arg.items():
|
401
|
-
self[key] = self._hook(val)
|
402
|
-
elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
|
403
|
-
self[arg[0]] = self._hook(arg[1])
|
404
|
-
else:
|
405
|
-
for key, val in iter(arg):
|
406
|
-
self[key] = self._hook(val)
|
407
|
-
|
408
|
-
for key, val in kwargs.items():
|
409
|
-
self[key] = self._hook(val)
|
410
|
-
|
411
|
-
def __setattr__(self, name, value):
|
412
|
-
if hasattr(self.__class__, name):
|
413
|
-
raise AttributeError(f"Attribute '{name}' is read-only in '{type(self)}' object.")
|
414
|
-
else:
|
415
|
-
self[name] = value
|
416
|
-
|
417
|
-
def __setitem__(self, name, value):
|
418
|
-
super(DotDict, self).__setitem__(name, value)
|
419
|
-
try:
|
420
|
-
p = object.__getattribute__(self, '__parent')
|
421
|
-
key = object.__getattribute__(self, '__key')
|
422
|
-
except AttributeError:
|
423
|
-
p = None
|
424
|
-
key = None
|
425
|
-
if p is not None:
|
426
|
-
p[key] = self
|
427
|
-
object.__delattr__(self, '__parent')
|
428
|
-
object.__delattr__(self, '__key')
|
429
|
-
|
430
|
-
@classmethod
|
431
|
-
def _hook(cls, item):
|
432
|
-
if isinstance(item, dict):
|
433
|
-
return cls(item)
|
434
|
-
elif isinstance(item, (list, tuple)):
|
435
|
-
return type(item)(cls._hook(elem) for elem in item)
|
436
|
-
return item
|
437
|
-
|
438
|
-
def __getattr__(self, item):
|
439
|
-
return self.__getitem__(item)
|
440
|
-
|
441
|
-
def __delattr__(self, name):
|
442
|
-
del self[name]
|
443
|
-
|
444
|
-
def copy(self):
|
445
|
-
return copy.copy(self)
|
446
|
-
|
447
|
-
def deepcopy(self):
|
448
|
-
return copy.deepcopy(self)
|
449
|
-
|
450
|
-
def __deepcopy__(self, memo):
|
451
|
-
other = self.__class__()
|
452
|
-
memo[id(self)] = other
|
453
|
-
for key, value in self.items():
|
454
|
-
other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
|
455
|
-
return other
|
456
|
-
|
457
|
-
def to_dict(self):
|
458
|
-
base = {}
|
459
|
-
for key, value in self.items():
|
460
|
-
if isinstance(value, type(self)):
|
461
|
-
base[key] = value.to_dict()
|
462
|
-
elif isinstance(value, (list, tuple)):
|
463
|
-
base[key] = type(value)(item.to_dict() if isinstance(item, type(self)) else item
|
464
|
-
for item in value)
|
465
|
-
else:
|
466
|
-
base[key] = value
|
467
|
-
return base
|
468
|
-
|
469
|
-
def update(self, *args, **kwargs):
|
470
|
-
other = {}
|
471
|
-
if args:
|
472
|
-
if len(args) > 1:
|
473
|
-
raise TypeError()
|
474
|
-
other.update(args[0])
|
475
|
-
other.update(kwargs)
|
476
|
-
for k, v in other.items():
|
477
|
-
if (k not in self) or (not isinstance(self[k], dict)) or (not isinstance(v, dict)):
|
478
|
-
self[k] = v
|
479
|
-
else:
|
480
|
-
self[k].update(v)
|
481
|
-
|
482
|
-
def __getnewargs__(self):
|
483
|
-
return tuple(self.items())
|
484
|
-
|
485
|
-
def __getstate__(self):
|
486
|
-
return self
|
487
|
-
|
488
|
-
def __setstate__(self, state):
|
489
|
-
self.update(state)
|
490
|
-
|
491
|
-
def setdefault(self, key, default=None):
|
492
|
-
if key in self:
|
493
|
-
return self[key]
|
494
|
-
else:
|
495
|
-
self[key] = default
|
496
|
-
return default
|
497
|
-
|
498
|
-
def tree_flatten(self):
|
499
|
-
return tuple(self.values()), tuple(self.keys())
|
500
|
-
|
501
|
-
@classmethod
|
502
|
-
def tree_unflatten(cls, keys, values):
|
503
|
-
return cls(jax.util.safe_zip(keys, values))
|
504
|
-
|
505
|
-
|
506
|
-
def _is_not_instance(x, cls):
|
507
|
-
return not isinstance(x, cls)
|
508
|
-
|
509
|
-
|
510
|
-
def _is_instance(x, cls):
|
511
|
-
return isinstance(x, cls)
|
512
|
-
|
513
|
-
|
514
|
-
@set_module_as('brainstate.util')
def not_instance_eval(*cls):
    """
    Build a predicate that holds for inputs which are NOT instances of ``cls``.

    Args:
        *cls: The classes to test against.

    Returns:
        A ``functools.partial`` of ``_is_not_instance`` with ``cls`` bound to
        the tuple of given classes; call it as ``predicate(x)``.
    """
    predicate = functools.partial(_is_not_instance, cls=cls)
    return predicate
|
527
|
-
|
528
|
-
|
529
|
-
@set_module_as('brainstate.util')
def is_instance_eval(*cls):
    """
    Build a predicate that holds for inputs which ARE instances of ``cls``.

    Args:
        *cls: The classes to test against.

    Returns:
        A ``functools.partial`` of ``_is_instance`` with ``cls`` bound to the
        tuple of given classes; call it as ``predicate(x)``.
    """
    predicate = functools.partial(_is_instance, cls=cls)
    return predicate
|