braintrace 0.1.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrace/__init__.py +79 -0
- braintrace/_compatible_imports.py +62 -0
- braintrace/_compatible_imports_test.py +94 -0
- braintrace/_etrace_algorithms.py +333 -0
- braintrace/_etrace_compiler_base.py +290 -0
- braintrace/_etrace_compiler_graph.py +287 -0
- braintrace/_etrace_compiler_graph_test.py +329 -0
- braintrace/_etrace_compiler_hid_param_op.py +832 -0
- braintrace/_etrace_compiler_hid_param_op_test.py +112 -0
- braintrace/_etrace_compiler_hidden_group.py +954 -0
- braintrace/_etrace_compiler_hidden_group_test.py +843 -0
- braintrace/_etrace_compiler_hidden_pertubation.py +381 -0
- braintrace/_etrace_compiler_hidden_pertubation_test.py +126 -0
- braintrace/_etrace_compiler_module_info.py +551 -0
- braintrace/_etrace_compiler_module_info_test.py +114 -0
- braintrace/_etrace_concepts.py +382 -0
- braintrace/_etrace_concepts_test.py +159 -0
- braintrace/_etrace_debug_jaxpr2code.py +1134 -0
- braintrace/_etrace_debug_visualize.py +1561 -0
- braintrace/_etrace_graph_executor.py +319 -0
- braintrace/_etrace_graph_executor_test.py +67 -0
- braintrace/_etrace_input_data.py +203 -0
- braintrace/_etrace_input_data_test.py +51 -0
- braintrace/_etrace_model_test.py +450 -0
- braintrace/_etrace_model_with_group_state.py +267 -0
- braintrace/_etrace_operators.py +1072 -0
- braintrace/_etrace_operators_test.py +58 -0
- braintrace/_etrace_vjp/__init__.py +29 -0
- braintrace/_etrace_vjp/base.py +671 -0
- braintrace/_etrace_vjp/d_rtrl.py +756 -0
- braintrace/_etrace_vjp/d_rtrl_test.py +205 -0
- braintrace/_etrace_vjp/esd_rtrl.py +847 -0
- braintrace/_etrace_vjp/esd_rtrl_test.py +194 -0
- braintrace/_etrace_vjp/graph_executor.py +718 -0
- braintrace/_etrace_vjp/graph_executor_test.py +102 -0
- braintrace/_etrace_vjp/hybrid.py +604 -0
- braintrace/_etrace_vjp/misc.py +162 -0
- braintrace/_grad_exponential.py +85 -0
- braintrace/_misc.py +403 -0
- braintrace/_state_managment.py +436 -0
- braintrace/_typing.py +91 -0
- braintrace/nn/__init__.py +68 -0
- braintrace/nn/_conv.py +395 -0
- braintrace/nn/_conv_test.py +868 -0
- braintrace/nn/_linear.py +524 -0
- braintrace/nn/_linear_test.py +658 -0
- braintrace/nn/_normalizations.py +508 -0
- braintrace/nn/_normalizations_test.py +695 -0
- braintrace/nn/_readout.py +278 -0
- braintrace/nn/_readout_test.py +763 -0
- braintrace/nn/_rnn.py +1057 -0
- braintrace/nn/_rnn_test.py +710 -0
- braintrace-0.1.1.dist-info/METADATA +137 -0
- braintrace-0.1.1.dist-info/RECORD +57 -0
- braintrace-0.1.1.dist-info/WHEEL +6 -0
- braintrace-0.1.1.dist-info/licenses/LICENSE +202 -0
- braintrace-0.1.1.dist-info/top_level.txt +1 -0
braintrace/__init__.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
# -*- coding: utf-8 -*-
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
__version__ = "0.1.1"
# Derived from ``__version__`` so the two values cannot drift apart. This
# expects a purely numeric dotted version string such as "0.1.1".
__version_info__ = tuple(map(int, __version__.split(".")))
|
|
21
|
+
|
|
22
|
+
from braintrace._etrace_algorithms import *
|
|
23
|
+
from braintrace._etrace_algorithms import __all__ as _alg_all
|
|
24
|
+
from braintrace._etrace_compiler_graph import *
|
|
25
|
+
from braintrace._etrace_compiler_graph import __all__ as _compiler_all
|
|
26
|
+
from braintrace._etrace_compiler_hid_param_op import *
|
|
27
|
+
from braintrace._etrace_compiler_hid_param_op import __all__ as _hid_param_all
|
|
28
|
+
from braintrace._etrace_compiler_hidden_group import *
|
|
29
|
+
from braintrace._etrace_compiler_hidden_group import __all__ as _hid_group_all
|
|
30
|
+
from braintrace._etrace_compiler_hidden_pertubation import *
|
|
31
|
+
from braintrace._etrace_compiler_hidden_pertubation import __all__ as _hid_pertub_all
|
|
32
|
+
from braintrace._etrace_compiler_module_info import *
|
|
33
|
+
from braintrace._etrace_compiler_module_info import __all__ as _mod_info_all
|
|
34
|
+
from braintrace._etrace_concepts import *
|
|
35
|
+
from braintrace._etrace_concepts import __all__ as _con_all
|
|
36
|
+
from braintrace._etrace_graph_executor import *
|
|
37
|
+
from braintrace._etrace_graph_executor import __all__ as _exec_all
|
|
38
|
+
from braintrace._etrace_input_data import *
|
|
39
|
+
from braintrace._etrace_input_data import __all__ as _data_all
|
|
40
|
+
from braintrace._etrace_operators import *
|
|
41
|
+
from braintrace._etrace_operators import __all__ as _op_all
|
|
42
|
+
from braintrace._etrace_vjp import *
|
|
43
|
+
from braintrace._etrace_vjp import __all__ as _vjp_all
|
|
44
|
+
from braintrace._grad_exponential import *
|
|
45
|
+
from braintrace._grad_exponential import __all__ as _grad_exp_all
|
|
46
|
+
from braintrace._misc import *
|
|
47
|
+
from braintrace._misc import __all__ as _misc_all
|
|
48
|
+
from . import nn
|
|
49
|
+
|
|
50
|
+
# Public API: the ``nn`` subpackage followed by every submodule's ``__all__``,
# concatenated in the same order the submodules are imported above.
__all__ = [
    'nn',
    *_alg_all, *_compiler_all, *_hid_param_all, *_hid_group_all, *_hid_pertub_all,
    *_mod_info_all, *_con_all, *_exec_all, *_data_all, *_op_all, *_vjp_all,
    *_grad_exp_all, *_misc_all,
]

# The per-module name lists were only needed to build ``__all__``; remove them
# from the package namespace.
del (
    _alg_all, _compiler_all, _hid_param_all, _hid_group_all, _hid_pertub_all,
    _mod_info_all, _con_all, _exec_all, _data_all, _op_all, _vjp_all,
    _grad_exp_all, _misc_all,
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Deprecated top-level aliases, mapped to the name of their replacement in
# ``brainstate``.
_DEPRECATED_ALIASES = {
    'ETraceState': 'HiddenState',
    'ETraceGroupState': 'HiddenGroupState',
    'ETraceTreeState': 'HiddenTreeState',
}


def __getattr__(name):
    """Resolve deprecated aliases lazily (PEP 562 module-level ``__getattr__``).

    Python calls this only for attributes that normal module lookup did not find.

    Parameters
    ----------
    name : str
        The attribute name that was not found on the module.

    Returns
    -------
    object
        The ``brainstate`` attribute replacing the deprecated alias ``name``.
        A ``DeprecationWarning`` is emitted first.

    Raises
    ------
    AttributeError
        If ``name`` is not one of the deprecated aliases.
    """
    try:
        replacement = _DEPRECATED_ALIASES[name]
    except KeyError:
        # Standard module-attribute error message; suppress the KeyError context.
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None

    # Imported lazily so this fallback costs nothing on normal attribute access.
    import warnings
    import brainstate

    warnings.warn(
        f"braintrace.{name} is deprecated and will be removed in a future release. "
        f"Please use brainstate.{replacement} instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    return getattr(brainstate, replacement)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import jax
|
|
17
|
+
|
|
18
|
+
# Public names of this compatibility module.
__all__ = [
    # jaxpr IR classes, imported from a JAX-version-dependent location
    'Primitive', 'Var', 'JaxprEqn', 'Jaxpr', 'ClosedJaxpr', 'Literal',
    # version-independent helpers
    'new_var',
    'is_jit_primitive', 'is_scan_primitive', 'is_while_primitive', 'is_cond_primitive',
]
|
|
31
|
+
|
|
32
|
+
# The jaxpr IR classes are imported from ``jax.core`` on JAX releases older
# than 0.4.38 and from ``jax.extend.core`` on 0.4.38 and newer.
if jax.__version_info__ < (0, 4, 38):
    from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal

else:
    from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def new_var(suffix, aval):
    """Create a jaxpr ``Var`` for ``aval`` across JAX versions.

    Parameters
    ----------
    suffix : str
        Variable-name suffix. It is passed to ``Var`` only on JAX releases older
        than 0.6.2 and is ignored on newer ones.
    aval
        The abstract value of the new variable.

    Returns
    -------
    Var
        The newly constructed variable.
    """
    uses_suffix = jax.__version_info__ < (0, 6, 2)
    ctor_args = (suffix, aval) if uses_suffix else (aval,)
    return Var(*ctor_args)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def is_jit_primitive(eqn: JaxprEqn) -> bool:
    """Return ``True`` if ``eqn`` is a call to a jitted function.

    The primitive behind ``jax.jit`` is named ``'pjit'`` before JAX 0.7.0 and
    ``'jit'`` from 0.7.0 on.
    """
    jit_name = 'pjit' if jax.__version_info__ < (0, 7, 0) else 'jit'
    return eqn.primitive.name == jit_name
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _primitive_name(eqn: JaxprEqn) -> str:
    """Return the name of the primitive applied by ``eqn``."""
    return eqn.primitive.name


def is_scan_primitive(eqn: JaxprEqn) -> bool:
    """Return ``True`` if ``eqn`` applies the ``scan`` primitive (``jax.lax.scan``)."""
    return _primitive_name(eqn) == 'scan'


def is_while_primitive(eqn: JaxprEqn) -> bool:
    """Return ``True`` if ``eqn`` applies the ``while`` primitive (``jax.lax.while_loop``)."""
    return _primitive_name(eqn) == 'while'


def is_cond_primitive(eqn: JaxprEqn) -> bool:
    """Return ``True`` if ``eqn`` applies the ``cond`` primitive.

    This primitive is emitted by both ``jax.lax.cond`` and ``jax.lax.switch``.
    """
    return _primitive_name(eqn) == 'cond'
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import jax.numpy as jnp
|
|
17
|
+
from jax import jit, make_jaxpr, lax
|
|
18
|
+
|
|
19
|
+
from braintrace._compatible_imports import (
|
|
20
|
+
is_jit_primitive, is_scan_primitive, is_while_primitive,
|
|
21
|
+
is_cond_primitive
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class TestPrimitive:
    """Check that the ``is_*_primitive`` helpers recognize the equations JAX emits."""

    def test_jit(self):
        @jit
        def jit_function(x, y):
            return x ** 2 + jnp.sin(y)

        # Tracing a jitted function does not inline it: the first equation of
        # the outer jaxpr is the nested jit call that wraps the function body.
        jaxpr_jit = make_jaxpr(jit_function)(2.0, 1.0)
        assert is_jit_primitive(jaxpr_jit.eqns[0])

    def test_scan(self):
        def scan_step(carry, x):
            return carry + x, carry * x

        def scan_function(init, xs):
            return lax.scan(scan_step, init, xs)

        # Create sample data
        init_val = 1.0
        xs = jnp.array([1.0, 2.0, 3.0, 4.0])

        jaxpr_scan = make_jaxpr(scan_function)(init_val, xs)
        assert is_scan_primitive(jaxpr_scan.eqns[0])

    def test_while(self):
        def while_cond(carry):
            i, _ = carry
            return i < 5

        def while_body(carry):
            i, x = carry
            return i + 1, x * 2

        def while_function(init_carry):
            return lax.while_loop(while_cond, while_body, init_carry)

        init_carry = (0, 1.0)
        jaxpr_while = make_jaxpr(while_function)(init_carry)
        assert is_while_primitive(jaxpr_while.eqns[0])

    def test_cond(self):
        def true_branch(x):
            return x * 2

        def false_branch(x):
            return x + 1

        def cond_function(pred, x):
            return lax.cond(pred, true_branch, false_branch, x)

        # The predicate may be converted by preceding equations, so the
        # ``cond`` equation is looked up at the end of the jaxpr.
        jaxpr_cond = make_jaxpr(cond_function)(True, 5.0)
        assert is_cond_primitive(jaxpr_cond.eqns[-1])

    def test_fori_loop(self):
        def fori_function(x):
            return lax.fori_loop(0, 3, lambda i, acc: acc * 2, x)

        # With static (Python int) bounds, ``fori_loop`` is implemented via ``scan``.
        jaxpr_fori = make_jaxpr(fori_function)(1.0)
        assert any(is_scan_primitive(eqn) for eqn in jaxpr_fori.eqns)

    def test_switch(self):
        def branch_0(x):
            return x * 2

        def branch_1(x):
            return x + 10

        def branch_2(x):
            return x ** 2

        def switch_function(index, x):
            return lax.switch(index, [branch_0, branch_1, branch_2], x)

        # ``lax.switch`` is lowered to the same ``cond`` primitive as ``lax.cond``;
        # the index is preprocessed first, so check the last equation.
        jaxpr_switch = make_jaxpr(switch_function)(1, 3.0)
        assert is_cond_primitive(jaxpr_switch.eqns[-1])
|
|
@@ -0,0 +1,333 @@
|
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
# Author: Chaoming Wang <chao.brain@qq.com>
|
|
16
|
+
# Date: 2024-04-03
|
|
17
|
+
# Copyright: 2024, Chaoming Wang
|
|
18
|
+
# ==============================================================================
|
|
19
|
+
|
|
20
|
+
# -*- coding: utf-8 -*-
|
|
21
|
+
|
|
22
|
+
from typing import Dict, Any, Optional
|
|
23
|
+
|
|
24
|
+
import brainstate
|
|
25
|
+
|
|
26
|
+
from ._etrace_compiler_graph import ETraceGraph
|
|
27
|
+
from ._etrace_graph_executor import ETraceGraphExecutor
|
|
28
|
+
from ._typing import Path
|
|
29
|
+
|
|
30
|
+
# Public names of this module; ``braintrace/__init__.py`` re-exports them.
__all__ = ['ETraceAlgorithm', 'EligibilityTrace']
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class EligibilityTrace(brainstate.ShortTermState):
    """
    The state for storing the eligibility trace during the computation of
    online learning algorithms.

    It is a ``brainstate.ShortTermState``.

    Examples
    --------
    When you are using :class:`braintrace.IODimVjpAlgorithm`, you can get
    the eligibility trace of the weight by calling
    :meth:`ETraceAlgorithm.get_etrace_of`:

    .. code-block:: python

        >>> etrace = etrace_algorithm.get_etrace_of(weight)

    """
    # Report the class as part of the public ``braintrace`` namespace, which
    # re-exports it, rather than this private module.
    __module__ = 'braintrace'
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class ETraceAlgorithm(brainstate.nn.Module):
    r"""
    The base class for the eligibility trace algorithm.

    A concrete algorithm subclasses this class and implements :meth:`update`,
    :meth:`init_etrace_state` and :meth:`get_etrace_of`. This base class:

    - stores the model and its graph executor;
    - compiles them once through :meth:`compile_graph`;
    - exposes the model states, split into parameter, hidden and other states.

    Parameters
    ----------
    model : brainstate.nn.Module
        The model function, which receives the input arguments and returns the model output.
    graph_executor : ETraceGraphExecutor
        The executor that builds and runs the etrace graph of ``model``.
    name : str, optional
        The name of the etrace algorithm.

    Attributes
    ----------
    model4compile : brainstate.nn.Module
        The ``model`` passed to the constructor.
    graph_executor : ETraceGraphExecutor
        The etrace graph executor (also exposed as :attr:`executor`).
    graph : ETraceGraph
        The etrace graph held by :attr:`graph_executor`.
    param_states : brainstate.util.FlattedDict[Path, brainstate.ParamState]
        The weight states.
    hidden_states : brainstate.util.FlattedDict[Path, brainstate.HiddenState]
        The hidden states.
    other_states : brainstate.util.FlattedDict[Path, brainstate.State]
        The other states.
    is_compiled : bool
        Whether :meth:`compile_graph` has completed.
    running_index : brainstate.LongTermState
        The running index, initialized to ``0``.

    Raises
    ------
    ValueError
        If ``model`` is not a ``brainstate.nn.Module``, or if
        ``graph_executor`` is not an ``ETraceGraphExecutor``.
    """
    __module__ = 'braintrace'

    def __init__(
        self,
        model: brainstate.nn.Module,
        graph_executor: ETraceGraphExecutor,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # the model
        if not isinstance(model, brainstate.nn.Module):
            raise ValueError(
                f'The model should be a brainstate.nn.Module, this can help us to '
                f'better obtain the program structure. But we got {type(model)}.'
            )
        self.model4compile = model

        # the graph
        if not isinstance(graph_executor, ETraceGraphExecutor):
            raise ValueError(
                f'The graph should be a ETraceGraphExecutor, this can help us to '
                f'better obtain the program structure. But we got {type(graph_executor)}.'
            )
        self.graph_executor = graph_executor

        # The flag to indicate whether the etrace algorithm has been compiled
        self.is_compiled = False

        # the running index
        self.running_index = brainstate.LongTermState(0)

        # Caches for the split model states. ``_split_state`` fills all three
        # on the first access of ``param_states``/``hidden_states``/``other_states``.
        self._param_states = None
        self._hidden_states = None
        self._other_states = None

    @property
    def graph(self) -> ETraceGraph:
        """
        Get the etrace graph.

        Returns
        -------
        ETraceGraph
            The etrace graph held by the graph executor.
        """
        return self.graph_executor.graph

    @property
    def executor(self) -> ETraceGraphExecutor:
        """
        Get the etrace graph executor.

        Returns
        -------
        ETraceGraphExecutor
            The etrace graph executor (the same object as ``graph_executor``).
        """
        return self.graph_executor

    @property
    def param_states(self) -> brainstate.util.FlattedDict[Path, brainstate.ParamState]:
        """
        Get the parameter weight states.

        The states are split and cached on first access.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.ParamState]
            The parameter weight states.
        """
        if self._param_states is None:
            self._split_state()
        return self._param_states

    @property
    def hidden_states(self) -> brainstate.util.FlattedDict[Path, brainstate.HiddenState]:
        """
        Get the hidden states.

        The states are split and cached on first access.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.HiddenState]
            The hidden states.
        """
        if self._hidden_states is None:
            self._split_state()
        return self._hidden_states

    @property
    def other_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
        """
        Get the other states.

        The states are split and cached on first access.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.State]
            The states that are neither parameter nor hidden states.
        """
        if self._other_states is None:
            self._split_state()
        return self._other_states

    def _split_state(self):
        """Split the graph's retrieved model states and cache the three groups.

        The states in ``self.graph.module_info.retrieved_model_states`` are
        split into:

        - parameter states (``brainstate.ParamState``);
        - hidden states (``brainstate.HiddenState``);
        - all remaining states (the ``...`` filter).
        """
        # [NOTE]
        #
        # `ETraceAlgorithm` depends on the states collected by the
        # `ETraceGraphExecutor`, namely:
        #
        # - the weight states
        # - the hidden states (the recurrent states)
        # - the other states
        (
            self._param_states,
            self._hidden_states,
            self._other_states
        ) = self.graph.module_info.retrieved_model_states.split(brainstate.ParamState, brainstate.HiddenState, ...)

    def compile_graph(self, *args) -> None:
        r"""
        Compile the eligibility trace graph of the relationship between etrace weights, states and operators.

        Compilation runs only once; calls made after it has completed do nothing.
        It consists of:

        - compiling the graph executor with ``*args``;
        - initializing the etrace states via :meth:`init_etrace_state`.

        The model states are not split here. They are split lazily on the
        first access of :attr:`param_states`, :attr:`hidden_states` or
        :attr:`other_states`. ``is_compiled`` is set only after both steps
        have returned.

        Parameters
        ----------
        *args
            The input arguments.
        """

        if not self.is_compiled:
            # --- the model etrace graph -- #
            self.graph_executor.compile_graph(*args)

            # --- the initialization of the states --- #
            self.init_etrace_state(*args)

            # mark the graph is compiled
            self.is_compiled = True

    @property
    def path_to_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
        """
        Get the path to the states.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.State]
            The mapping from path to states, as provided by the graph executor.
        """
        return self.graph_executor.path_to_states

    @property
    def state_id_to_path(self) -> Dict[int, Path]:
        """
        Get the state ID to the path.

        Returns
        -------
        Dict[int, Path]
            The mapping from state ID to path, as provided by the graph executor.
        """
        return self.graph_executor.state_id_to_path

    def show_graph(self) -> None:
        """
        Show the etrace graph.

        This delegates to the graph executor's ``show_graph`` and returns its result.
        """
        return self.graph_executor.show_graph()

    def __call__(self, *args) -> Any:
        """
        Update the model and the eligibility trace states.

        This is equivalent to calling :meth:`update`.

        Parameters
        ----------
        *args
            The input arguments.

        Returns
        -------
        Any
            The output of the update method.
        """
        return self.update(*args)

    def update(self, *args) -> Any:
        """
        Update the model and the eligibility trace states.

        Parameters
        ----------
        *args
            The input arguments.

        Returns
        -------
        Any
            The model output.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
        """
        raise NotImplementedError

    def init_etrace_state(self, *args, **kwargs) -> None:
        """
        Initialize the eligibility trace states of the etrace algorithm.

        This method is called by :meth:`compile_graph` after the graph executor
        has been compiled.

        Parameters
        ----------
        *args
            The positional arguments.
        **kwargs
            The keyword arguments.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
        """
        raise NotImplementedError

    def get_etrace_of(self, weight: brainstate.ParamState | Path) -> Any:
        """
        Get the eligibility trace of the given weight.

        Parameters
        ----------
        weight : brainstate.ParamState | Path
            The parameter weight or path to the weight.

        Returns
        -------
        Any
            The eligibility trace.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
        """
        raise NotImplementedError
|