brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/mixin_test.py
CHANGED
@@ -19,57 +19,55 @@ import brainstate as bc
|
|
19
19
|
|
20
20
|
|
21
21
|
class TestMixin(unittest.TestCase):
    """Smoke test: the expected names are reachable on ``bc.mixin``."""

    def test_mixin(self):
        # Every listed attribute must exist on the module and be truthy.
        expected_names = (
            'Mixin',
            'ParamDesc',
            'ParamDescriber',
            'JointTypes',
            'OneOfTypes',
            'Mode',
            'Batching',
            'Training',
        )
        for attr_name in expected_names:
            self.assertTrue(getattr(bc.mixin, attr_name))
|
33
31
|
|
34
32
|
|
35
33
|
class TestMode(unittest.TestCase):
    """Exercise the ``is_a`` / ``has`` queries of the mode classes in ``bc.mixin``."""

    def test_JointMode(self):
        mx = bc.mixin

        # A joint mode built from two modes answers for both of them.
        both = mx.JointMode(mx.Batching(), mx.Training())
        self.assertTrue(both.is_a(mx.JointTypes[mx.Batching, mx.Training]))
        self.assertTrue(both.has(mx.Batching))
        self.assertTrue(both.has(mx.Training))

        # A joint mode built from a single mode.
        single = mx.JointMode(mx.Batching())
        self.assertTrue(single.is_a(mx.JointTypes[mx.Batching]))
        self.assertTrue(single.is_a(mx.Batching))
        self.assertTrue(single.has(mx.Batching))

    def test_Training(self):
        mx = bc.mixin
        mode = mx.Training()

        # Positive queries against its own type.
        self.assertTrue(mode.is_a(mx.Training))
        self.assertTrue(mode.is_a(mx.JointTypes[mx.Training]))
        self.assertTrue(mode.has(mx.Training))
        self.assertTrue(mode.has(mx.JointTypes[mx.Training]))

        # Negative queries against the other mode.
        self.assertFalse(mode.is_a(mx.Batching))
        self.assertFalse(mode.has(mx.Batching))

    def test_Batching(self):
        mx = bc.mixin
        mode = mx.Batching()

        # Positive queries against its own type.
        self.assertTrue(mode.is_a(mx.Batching))
        self.assertTrue(mode.is_a(mx.JointTypes[mx.Batching]))
        self.assertTrue(mode.has(mx.Batching))
        self.assertTrue(mode.has(mx.JointTypes[mx.Batching]))

        # Negative queries against the other mode.
        self.assertFalse(mode.is_a(mx.Training))
        self.assertFalse(mode.has(mx.Training))

    def test_Mode(self):
        mx = bc.mixin
        mode = mx.Mode()

        # Positive queries against its own type.
        self.assertTrue(mode.is_a(mx.Mode))
        self.assertTrue(mode.is_a(mx.JointTypes[mx.Mode]))
        self.assertTrue(mode.has(mx.Mode))
        self.assertTrue(mode.has(mx.JointTypes[mx.Mode]))

        # The base mode is neither Training nor Batching.
        self.assertFalse(mode.is_a(mx.Training))
        self.assertFalse(mode.has(mx.Training))
        self.assertFalse(mode.is_a(mx.Batching))
        self.assertFalse(mode.has(mx.Batching))
|
brainstate/nn/__init__.py
CHANGED
@@ -13,65 +13,40 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
+
|
16
17
|
from . import metrics
|
17
|
-
from .
|
18
|
-
from .
|
19
|
-
from .
|
20
|
-
from .
|
18
|
+
from ._collective_ops import *
|
19
|
+
from ._collective_ops import __all__ as collective_ops_all
|
20
|
+
from ._dyn_impl import *
|
21
|
+
from ._dyn_impl import __all__ as dyn_impl_all
|
21
22
|
from ._dynamics import *
|
22
23
|
from ._dynamics import __all__ as dynamics_all
|
23
24
|
from ._elementwise import *
|
24
25
|
from ._elementwise import __all__ as elementwise_all
|
25
|
-
from .
|
26
|
-
from .
|
27
|
-
from .
|
28
|
-
from .
|
29
|
-
from .
|
30
|
-
from .
|
31
|
-
from ._others import *
|
32
|
-
from ._others import __all__ as others_all
|
33
|
-
from ._poolings import *
|
34
|
-
from ._poolings import __all__ as poolings_all
|
35
|
-
from ._projection import *
|
36
|
-
from ._projection import __all__ as _projection_all
|
37
|
-
from ._rate_rnns import *
|
38
|
-
from ._rate_rnns import __all__ as rate_rnns
|
39
|
-
from ._readout import *
|
40
|
-
from ._readout import __all__ as readout_all
|
41
|
-
from ._synouts import *
|
42
|
-
from ._synouts import __all__ as synouts_all
|
43
|
-
from .event import *
|
44
|
-
from .event import __all__ as event_all
|
26
|
+
from ._exp_euler import *
|
27
|
+
from ._exp_euler import __all__ as exp_euler_all
|
28
|
+
from ._interaction import *
|
29
|
+
from ._interaction import __all__ as interaction_all
|
30
|
+
from ._module import *
|
31
|
+
from ._module import __all__ as module_all
|
45
32
|
|
46
33
|
__all__ = (
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
rate_rnns +
|
56
|
-
readout_all +
|
57
|
-
synouts_all +
|
58
|
-
_projection_all +
|
59
|
-
_misc_all +
|
60
|
-
event_all
|
34
|
+
['metrics']
|
35
|
+
+ collective_ops_all
|
36
|
+
+ dyn_impl_all
|
37
|
+
+ dynamics_all
|
38
|
+
+ elementwise_all
|
39
|
+
+ module_all
|
40
|
+
+ exp_euler_all
|
41
|
+
+ interaction_all
|
61
42
|
)
|
62
43
|
|
63
44
|
del (
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
poolings_all,
|
72
|
-
readout_all,
|
73
|
-
synouts_all,
|
74
|
-
_projection_all,
|
75
|
-
_misc_all,
|
76
|
-
event_all
|
45
|
+
collective_ops_all,
|
46
|
+
dyn_impl_all,
|
47
|
+
dynamics_all,
|
48
|
+
elementwise_all,
|
49
|
+
module_all,
|
50
|
+
exp_euler_all,
|
51
|
+
interaction_all,
|
77
52
|
)
|
@@ -0,0 +1,199 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
from collections import namedtuple
|
19
|
+
from typing import Dict, Callable, TypeVar
|
20
|
+
|
21
|
+
import jax
|
22
|
+
|
23
|
+
from brainstate._utils import set_module_as
|
24
|
+
from brainstate.graph import nodes
|
25
|
+
from ._module import Module
|
26
|
+
|
27
|
+
# the maximum order
|
28
|
+
MAX_ORDER = 10
|
29
|
+
|
30
|
+
# State Load Results
|
31
|
+
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
|
32
|
+
|
33
|
+
T = TypeVar('T', bound=Module)
|
34
|
+
|
35
|
+
__all__ = [
|
36
|
+
'MAX_ORDER', 'call_order', 'init_all_states', 'reset_all_states',
|
37
|
+
'load_all_states', 'save_all_states', 'assign_state_values',
|
38
|
+
]
|
39
|
+
|
40
|
+
|
41
|
+
@set_module_as('brainstate.nn')
def call_order(level: int = 0, check_order_boundary: bool = True):
    """Build a decorator that tags a function with a call-order level.

    The returned decorator stores ``level`` on the decorated function as the
    attribute ``call_order`` and hands the same function back. The collective
    helpers in this module sort on that attribute in ascending order, so the
    lower the level, the earlier the function is called.

    Example::

        import brainstate as bst

        @bst.nn.call_order(1)
        def init_state(self):
            ...

    Parameters
    ----------
    level: int
        The call order level.
    check_order_boundary: bool
        Whether to validate ``level``. If True, a level outside
        ``[0, MAX_ORDER)`` raises a ValueError; pass False to allow
        any level (for example a negative one).

    Returns
    -------
    The decorator used to wrap the function.

    Raises
    ------
    ValueError
        If ``check_order_boundary`` is True and ``level`` is outside
        ``[0, MAX_ORDER)``.
    """
    if check_order_boundary:
        if level < 0 or level >= MAX_ORDER:
            raise ValueError(f'"call_order" must be an integer in [0, {MAX_ORDER}). but we got {level}.')

    def wrap(fun: Callable):
        # Tag the function in place; it is returned unchanged otherwise.
        setattr(fun, 'call_order', level)
        return fun

    return wrap
|
74
|
+
|
75
|
+
|
76
|
+
@set_module_as('brainstate.nn')
def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
    """
    Collectively initialize the states of all Module nodes found in ``target``.

    Nodes whose ``init_state`` carries no ``call_order`` attribute are
    initialized first, in traversal order. The remaining nodes are then
    initialized in ascending ``call_order``.

    Args:
      target: The target Module.
      *args: Positional arguments forwarded to every ``init_state`` call.
      exclude: Optional filter, passed to ``filter`` on the collected nodes;
        the nodes it selects are left out of the initialization.
      **kwargs: Keyword arguments forwarded to every ``init_state`` call.

    Returns:
      The target Module.
    """
    candidates = nodes(target).filter(Module)
    if exclude is not None:
        candidates = candidates - candidates.filter(exclude)

    # Initialize unordered nodes right away; postpone the ordered ones.
    postponed = []
    for module in list(candidates.values()):
        if not hasattr(module.init_state, 'call_order'):
            module.init_state(*args, **kwargs)
            continue
        postponed.append(module)

    # Initialize the ordered nodes, lowest ``call_order`` first.
    postponed.sort(key=lambda m: m.init_state.call_order)
    for module in postponed:
        module.init_state(*args, **kwargs)

    return target
|
105
|
+
|
106
|
+
|
107
|
+
@set_module_as('brainstate.nn')
def reset_all_states(target: Module, *args, **kwargs) -> Module:
    """
    Collectively reset the states of all Module nodes found in ``target``.

    Nodes whose ``reset_state`` carries no ``call_order`` attribute are reset
    first, in traversal order. The remaining nodes are then reset in
    ascending ``call_order``.

    Args:
      target: The target Module.
      *args: Positional arguments forwarded to every ``reset_state`` call.
      **kwargs: Keyword arguments forwarded to every ``reset_state`` call.

    Returns:
      The target Module.
    """
    postponed = []

    # Reset nodes whose `reset_state` has no `call_order` right away.
    for _, module in nodes(target).filter(Module).items():
        if not hasattr(module.reset_state, 'call_order'):
            module.reset_state(*args, **kwargs)
            continue
        postponed.append(module)

    # Reset the ordered nodes, lowest `call_order` first.
    postponed.sort(key=lambda m: m.reset_state.call_order)
    for module in postponed:
        module.reset_state(*args, **kwargs)

    return target
|
132
|
+
|
133
|
+
|
134
|
+
@set_module_as('brainstate.nn')
def load_all_states(target: Module, state_dict: Dict, **kwargs):
    """
    Copy parameters and buffers from :attr:`state_dict` into
    the target and its descendants.

    Every node of ``target`` receives ``state_dict[path]`` through its
    ``load_state`` method, where ``path`` is the node's path in ``target``.

    Args:
      target: Module. The dynamical system to load its states.
      state_dict: dict. A dict containing parameters and persistent buffers,
        indexed by node path. NOTE(review): a path absent from ``state_dict``
        is looked up directly, so it presumably raises ``KeyError`` instead
        of being reported as missing -- confirm this is intended.
      **kwargs: Keyword arguments forwarded to every ``load_state`` call.

    Returns
    -------
    ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:

    * **missing_keys** is a list of str containing the missing keys
    * **unexpected_keys** is a list of str containing the unexpected keys
    """
    all_missing = []
    all_unexpected = []
    for path, node in nodes(target).items():
        outcome = node.load_state(state_dict[path], **kwargs)
        if outcome is None:
            # This node reported nothing.
            continue
        missing, unexpected = outcome
        # Prefix each reported key with the path of the node that owns it.
        all_missing += [f'{path}.{key}' for key in missing]
        all_unexpected += [f'{path}.{key}' for key in unexpected]
    return StateLoadResult(all_missing, all_unexpected)
|
160
|
+
|
161
|
+
|
162
|
+
@set_module_as('brainstate.nn')
def save_all_states(target: Module, **kwargs) -> Dict:
    """
    Collect the states of every node in ``target`` into a dictionary for later
    disk serialization.

    Args:
      target: Module. The node to save its states.
      **kwargs: Keyword arguments forwarded to every ``save_state`` call.

    Returns:
      Dict. The state dict for serialization, mapping each key of
      ``target.nodes()`` to the result of that node's ``save_state``.
    """
    collected = {}
    for key, node in target.nodes().items():
        collected[key] = node.save_state(**kwargs)
    return collected
|
174
|
+
|
175
|
+
|
176
|
+
@set_module_as('brainstate.nn')
def assign_state_values(target: Module, *state_by_abs_path: Dict):
    """
    Assign state values according to the given state dictionaries.

    The given dictionaries are merged in order, so a later dictionary
    overrides an earlier one on a shared key. Every key that also appears in
    ``target.states()`` has its value converted with ``jax.numpy.asarray``
    and assigned to the ``value`` of the matching state.

    Parameters
    ----------
    target: Module
      The target module.
    state_by_abs_path: dict
      The state dictionary which is accessed by the "absolute" accessing method.

    Returns
    -------
    A tuple ``(unexpected_keys, missing_keys)``: the keys that were given but
    are not states of ``target``, followed by the states of ``target`` that
    were not given. Note the order: unexpected first, missing second.
    """
    # Merge all given dictionaries into one.
    provided = {}
    for mapping in state_by_abs_path:
        provided.update(mapping)

    states = target.states()
    provided_keys = set(provided.keys())
    state_keys = set(states.keys())

    # Only keys present on both sides are assigned.
    for key in state_keys.intersection(provided_keys):
        states[key].value = jax.numpy.asarray(provided[key])

    unexpected_keys = list(provided_keys - state_keys)
    missing_keys = list(state_keys - provided_keys)
    return unexpected_keys, missing_keys
|
@@ -0,0 +1,46 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
from ._dynamics_neuron import *
|
18
|
+
from ._dynamics_neuron import __all__ as dyn_neuron_all
|
19
|
+
from ._dynamics_synapse import *
|
20
|
+
from ._dynamics_synapse import __all__ as dyn_synapse_all
|
21
|
+
from ._inputs import *
|
22
|
+
from ._inputs import __all__ as inputs_all
|
23
|
+
from ._projection_alignpost import *
|
24
|
+
from ._projection_alignpost import __all__ as alignpost_all
|
25
|
+
from ._rate_rnns import *
|
26
|
+
from ._rate_rnns import __all__ as rate_rnns
|
27
|
+
from ._readout import *
|
28
|
+
from ._readout import __all__ as readout_all
|
29
|
+
|
30
|
+
__all__ = (
|
31
|
+
dyn_neuron_all
|
32
|
+
+ dyn_synapse_all
|
33
|
+
+ inputs_all
|
34
|
+
+ alignpost_all
|
35
|
+
+ rate_rnns
|
36
|
+
+ readout_all
|
37
|
+
)
|
38
|
+
|
39
|
+
del (
|
40
|
+
dyn_neuron_all,
|
41
|
+
dyn_synapse_all,
|
42
|
+
inputs_all,
|
43
|
+
readout_all,
|
44
|
+
alignpost_all,
|
45
|
+
rate_rnns,
|
46
|
+
)
|