braintrace 0.1.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braintrace/__init__.py +79 -0
- braintrace/_compatible_imports.py +62 -0
- braintrace/_compatible_imports_test.py +94 -0
- braintrace/_etrace_algorithms.py +333 -0
- braintrace/_etrace_compiler_base.py +290 -0
- braintrace/_etrace_compiler_graph.py +287 -0
- braintrace/_etrace_compiler_graph_test.py +329 -0
- braintrace/_etrace_compiler_hid_param_op.py +832 -0
- braintrace/_etrace_compiler_hid_param_op_test.py +112 -0
- braintrace/_etrace_compiler_hidden_group.py +954 -0
- braintrace/_etrace_compiler_hidden_group_test.py +843 -0
- braintrace/_etrace_compiler_hidden_pertubation.py +381 -0
- braintrace/_etrace_compiler_hidden_pertubation_test.py +126 -0
- braintrace/_etrace_compiler_module_info.py +551 -0
- braintrace/_etrace_compiler_module_info_test.py +114 -0
- braintrace/_etrace_concepts.py +382 -0
- braintrace/_etrace_concepts_test.py +159 -0
- braintrace/_etrace_debug_jaxpr2code.py +1134 -0
- braintrace/_etrace_debug_visualize.py +1561 -0
- braintrace/_etrace_graph_executor.py +319 -0
- braintrace/_etrace_graph_executor_test.py +67 -0
- braintrace/_etrace_input_data.py +203 -0
- braintrace/_etrace_input_data_test.py +51 -0
- braintrace/_etrace_model_test.py +450 -0
- braintrace/_etrace_model_with_group_state.py +267 -0
- braintrace/_etrace_operators.py +1072 -0
- braintrace/_etrace_operators_test.py +58 -0
- braintrace/_etrace_vjp/__init__.py +29 -0
- braintrace/_etrace_vjp/base.py +671 -0
- braintrace/_etrace_vjp/d_rtrl.py +756 -0
- braintrace/_etrace_vjp/d_rtrl_test.py +205 -0
- braintrace/_etrace_vjp/esd_rtrl.py +847 -0
- braintrace/_etrace_vjp/esd_rtrl_test.py +194 -0
- braintrace/_etrace_vjp/graph_executor.py +718 -0
- braintrace/_etrace_vjp/graph_executor_test.py +102 -0
- braintrace/_etrace_vjp/hybrid.py +604 -0
- braintrace/_etrace_vjp/misc.py +162 -0
- braintrace/_grad_exponential.py +85 -0
- braintrace/_misc.py +403 -0
- braintrace/_state_managment.py +436 -0
- braintrace/_typing.py +91 -0
- braintrace/nn/__init__.py +68 -0
- braintrace/nn/_conv.py +395 -0
- braintrace/nn/_conv_test.py +868 -0
- braintrace/nn/_linear.py +524 -0
- braintrace/nn/_linear_test.py +658 -0
- braintrace/nn/_normalizations.py +508 -0
- braintrace/nn/_normalizations_test.py +695 -0
- braintrace/nn/_readout.py +278 -0
- braintrace/nn/_readout_test.py +763 -0
- braintrace/nn/_rnn.py +1057 -0
- braintrace/nn/_rnn_test.py +710 -0
- braintrace-0.1.1.dist-info/METADATA +137 -0
- braintrace-0.1.1.dist-info/RECORD +57 -0
- braintrace-0.1.1.dist-info/WHEEL +6 -0
- braintrace-0.1.1.dist-info/licenses/LICENSE +202 -0
- braintrace-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Dict, Sequence, Set, List
|
|
17
|
+
|
|
18
|
+
from ._compatible_imports import (
|
|
19
|
+
Var,
|
|
20
|
+
JaxprEqn,
|
|
21
|
+
is_jit_primitive,
|
|
22
|
+
is_scan_primitive,
|
|
23
|
+
is_while_primitive,
|
|
24
|
+
is_cond_primitive,
|
|
25
|
+
)
|
|
26
|
+
from ._etrace_operators import (
|
|
27
|
+
is_etrace_op,
|
|
28
|
+
is_etrace_op_enable_gradient,
|
|
29
|
+
)
|
|
30
|
+
from ._typing import Path
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def find_matched_vars(
    invars: Sequence[Var],
    invar_needed_in_oth_eqns: Set[Var]
) -> List[Var]:
    """
    Select the input variables of an equation that other equations also need.

    Parameters
    ----------
    invars : Sequence[Var]
        The input variables of the equation. Entries that are not ``Var``
        instances (e.g. literals) are ignored.
    invar_needed_in_oth_eqns : Set[Var]
        The variables needed in the other equations.

    Returns
    -------
    List[Var]
        The matching variables, in the order they appear in ``invars``
        (repeated entries are kept).
    """
    return [
        candidate
        for candidate in invars
        if isinstance(candidate, Var) and candidate in invar_needed_in_oth_eqns
    ]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def find_element_exist_in_the_set(
|
|
60
|
+
elements: Sequence[Var],
|
|
61
|
+
the_set: Set[Var]
|
|
62
|
+
) -> Var | None:
|
|
63
|
+
"""
|
|
64
|
+
Checking whether the jaxpr vars contain the weight variables.
|
|
65
|
+
|
|
66
|
+
Parameters
|
|
67
|
+
----------
|
|
68
|
+
elements : Sequence[Var]
|
|
69
|
+
The input variables of the equation.
|
|
70
|
+
the_set : Set[Var]
|
|
71
|
+
The set of the weight variables.
|
|
72
|
+
|
|
73
|
+
Returns
|
|
74
|
+
-------
|
|
75
|
+
Var | None
|
|
76
|
+
The first element found in the set, or None if no element is found.
|
|
77
|
+
"""
|
|
78
|
+
for invar in elements:
|
|
79
|
+
if isinstance(invar, Var) and invar in the_set:
|
|
80
|
+
return invar
|
|
81
|
+
return None
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def check_unsupported_op(
    self,
    eqn: JaxprEqn,
    op_name: str
):
    """
    Checks for unsupported operations involving weight or hidden state variables in the given equation.

    This function verifies whether the specified JAX equation (`eqn`) uses weight or hidden state variables
    in a manner that is currently unsupported. If such usage is detected, a `NotImplementedError` is raised
    with a detailed message.

    Parameters
    ----------
    self : JaxprEvaluation
        The evaluator instance. Its ``weight_invars``, ``hidden_outvars``,
        ``invar_to_hidden_path`` and ``outvar_to_hidden_path`` attributes are read.
    eqn : JaxprEqn
        The JAX equation to be checked.
    op_name : str
        The name of the operation being checked (e.g., 'jit', 'scan', 'while', 'cond');
        it is only used to build the error message.

    Raises
    ------
    NotImplementedError
        If the equation consumes a weight variable, or produces a hidden state variable.
    """
    # checking whether the weight variables are used in the equation
    invar = find_element_exist_in_the_set(eqn.invars, self.weight_invars)
    if invar is not None:
        # ``invar_to_hidden_path`` is documented as a mapping for hidden-state
        # inputs, so a weight variable may not be one of its keys. Fall back to
        # the variable itself so the intended NotImplementedError is raised
        # rather than a confusing KeyError.
        weight_desc = self.invar_to_hidden_path.get(invar, invar)
        raise NotImplementedError(
            f'Currently, we do not support the weight states are used within a {op_name} function. \n'
            f'Please remove your {op_name} on the intermediate steps. \n\n'
            f'The weight state is: {weight_desc}. \n'
            f'The Jaxpr of the {op_name} function is: \n\n'
            f'{eqn} \n\n'
        )

    # checking whether the hidden variables are computed in the equation
    outvar = find_element_exist_in_the_set(eqn.outvars, self.hidden_outvars)
    if outvar is not None:
        # Same defensive lookup as above so error reporting never fails itself.
        hidden_desc = self.outvar_to_hidden_path.get(outvar, outvar)
        raise NotImplementedError(
            f'Currently, we do not support the hidden states are computed within a {op_name} function. \n'
            f'Please remove your {op_name} on the intermediate steps. \n\n'
            f'The hidden state is: {hidden_desc}. \n'
            f'The Jaxpr of the {op_name} function is: \n\n'
            f'{eqn} \n\n'
        )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class JaxprEvaluation(object):
    """
    A base class for evaluating JAX program representations (jaxpr) to extract eligibility trace relationships.

    This class analyzes the computational graph represented as JAX primitives to identify and track
    relationships between weight parameters and hidden states for eligibility trace computation.
    Subclasses must implement the `_eval_eqn` method to define specific evaluation behavior.

    The class handles special JAX primitives such as pjit, scan, while, and cond operations,
    providing appropriate handling or restrictions for eligibility trace compilation.

    Parameters
    ----------
    weight_invars : Set[Var]
        Input variables representing weight parameters in the computational graph.
    hidden_invars : Set[Var]
        Input variables representing hidden states in the computational graph.
    hidden_outvars : Set[Var]
        Output variables representing hidden states in the computational graph.
    invar_to_hidden_path : Dict[Var, Path]
        Mapping from input variables to their paths in the hidden state hierarchy.
    outvar_to_hidden_path : Dict[Var, Path]
        Mapping from output variables to their paths in the hidden state hierarchy.

    Attributes
    ----------
    weight_invars : Set[Var]
        Stored input weight variables.
    hidden_invars : Set[Var]
        Stored input hidden state variables.
    hidden_outvars : Set[Var]
        Stored output hidden state variables.
    invar_to_hidden_path : Dict[Var, Path]
        Stored mapping from input variables to hidden paths.
    outvar_to_hidden_path : Dict[Var, Path]
        Stored mapping from output variables to hidden paths.
    """
    __module__ = 'braintrace'

    def __init__(
        self,
        weight_invars: Set[Var],
        hidden_invars: Set[Var],
        hidden_outvars: Set[Var],
        invar_to_hidden_path: Dict[Var, Path],
        outvar_to_hidden_path: Dict[Var, Path],
    ):
        self.weight_invars = weight_invars
        self.hidden_invars = hidden_invars
        self.hidden_outvars = hidden_outvars
        self.invar_to_hidden_path = invar_to_hidden_path
        self.outvar_to_hidden_path = outvar_to_hidden_path

    def _eval_jaxpr(self, jaxpr) -> None:
        """
        Evaluating the jaxpr for extracting the etrace relationships.

        Each equation is dispatched by primitive kind: jit, scan, while and
        cond primitives go to their dedicated handlers; everything else goes
        straight to ``_eval_eqn``.

        Parameters
        ----------
        jaxpr : Jaxpr
            The jaxpr for the model.
        """

        for eqn in jaxpr.eqns:
            # TODO: add the support for the scan, while, cond, pjit, and other operators
            # Currently, scan, while, and cond are usually not the common operators used in
            # the definition of a brain dynamics model. So we may not need to consider them
            # during the current phase.
            # However, for the long-term maintenance and development, we need to consider them,
            # since users usually create crazy models.

            if is_jit_primitive(eqn):
                self._eval_pjit(eqn)
            elif is_scan_primitive(eqn):
                self._eval_scan(eqn)
            elif is_while_primitive(eqn):
                self._eval_while(eqn)
            elif is_cond_primitive(eqn):
                self._eval_cond(eqn)
            else:
                self._eval_eqn(eqn)

    def _eval_pjit(self, eqn: JaxprEqn) -> None:
        """
        Evaluating the pjit primitive.

        A jit-wrapped etrace operator that has gradients enabled is passed
        directly to ``_eval_eqn``. Any other jit equation is first checked for
        unsupported weight / hidden-state usage, then passed to ``_eval_eqn``.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.
        """
        if is_etrace_op(eqn.params['name']):
            if is_etrace_op_enable_gradient(eqn.params['name']):
                self._eval_eqn(eqn)
                return

        check_unsupported_op(self, eqn, 'jit')

        # treat the pjit as a normal jaxpr equation
        self._eval_eqn(eqn)

    def _eval_scan(self, eqn: JaxprEqn) -> None:
        """
        Evaluating the scan primitive.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.
        """
        # Report unsupported usage under the primitive actually being evaluated.
        check_unsupported_op(self, eqn, 'scan')
        self._eval_eqn(eqn)

    def _eval_while(self, eqn: JaxprEqn) -> None:
        """
        Evaluating the while primitive.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.
        """
        # Report unsupported usage under the primitive actually being evaluated.
        check_unsupported_op(self, eqn, 'while')
        self._eval_eqn(eqn)

    def _eval_cond(self, eqn: JaxprEqn) -> None:
        """
        Evaluating the cond primitive.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.
        """
        check_unsupported_op(self, eqn, 'cond')
        self._eval_eqn(eqn)

    def _eval_eqn(self, eqn):
        """
        Evaluate a single JAX equation.

        This method must be implemented by subclasses to define specific
        evaluation behavior for equations.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.

        Raises
        ------
        NotImplementedError
            This method must be implemented in subclasses.
        """
        raise NotImplementedError(
            'The method "_eval_eqn" should be implemented in the subclass.'
        )
|
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
import threading
|
|
18
|
+
from contextlib import contextmanager
|
|
19
|
+
from typing import Dict, Sequence, Tuple, Optional, NamedTuple
|
|
20
|
+
|
|
21
|
+
import brainstate
|
|
22
|
+
import jax
|
|
23
|
+
|
|
24
|
+
from ._etrace_compiler_hid_param_op import (
|
|
25
|
+
find_hidden_param_op_relations_from_minfo,
|
|
26
|
+
HiddenParamOpRelation,
|
|
27
|
+
)
|
|
28
|
+
from ._etrace_compiler_hidden_group import (
|
|
29
|
+
find_hidden_groups_from_minfo,
|
|
30
|
+
HiddenGroup,
|
|
31
|
+
)
|
|
32
|
+
from ._etrace_compiler_hidden_pertubation import (
|
|
33
|
+
add_hidden_perturbation_from_minfo,
|
|
34
|
+
HiddenPerturbation,
|
|
35
|
+
)
|
|
36
|
+
from ._etrace_compiler_module_info import (
|
|
37
|
+
extract_module_info,
|
|
38
|
+
ModuleInfo,
|
|
39
|
+
)
|
|
40
|
+
from ._typing import (
|
|
41
|
+
Inputs,
|
|
42
|
+
Path,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
# Public API of this module: the compiled graph container and its builder.
__all__ = [
    'ETraceGraph',
    'compile_etrace_graph',
]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def order_hidden_group_index(
    hidden_groups: Sequence[HiddenGroup],
):
    """
    Verifies that hidden group indices match their positions in the sequence.

    This function ensures that the index attribute of each HiddenGroup in the sequence
    matches its position in the sequence. This validation is important for maintaining
    the correct ordering of hidden groups in the eligibility trace compilation process.

    Args:
        hidden_groups (Sequence[HiddenGroup]): A sequence of HiddenGroup objects to validate.

    Raises:
        AssertionError: If any hidden group's index doesn't match its position in the sequence.
    """
    for position, group in enumerate(hidden_groups):
        # Raise explicitly instead of using an ``assert`` statement, so the
        # check still runs when Python is started with ``-O``.
        if group.index != position:
            raise AssertionError(
                f"Hidden group index {group.index} should be equal to its position {position}."
            )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class ETraceGraph(NamedTuple):
    """
    The overall compiled graph for the eligibility trace.

    The eligibility trace graph, tracking the relationship between the etrace weights
    :py:class:`ETraceParam`, the etrace variables :py:class:`ETraceState`, and the etrace
    operations :py:class:`ETraceOp`.

    The following fields are included:

    - ``module_info``: The model information, instance of :class:`ModuleInfo`.
    - ``hidden_groups``: The hidden groups, sequence of :class:`HiddenGroup`.
    - ``hid_path_to_group``: The mapping from the hidden path to the hidden group :class:`HiddenGroup`.
    - ``hidden_param_op_relations``: The hidden parameter operation relations, sequence of :class:`HiddenParamOpRelation`.
    - ``hidden_perturb``: The hidden perturbation, instance of :class:`HiddenPerturbation`, or None.

    Example::

        >>> import braintrace
        >>> import brainstate
        >>> gru = braintrace.nn.GRUCell(10, 20)
        >>> gru.init_state()
        >>> inputs = brainstate.random.randn(10)
        >>> compiled_graph = braintrace.compile_etrace_graph(gru, inputs)
        >>> compiled_graph.dict().keys()
        dict_keys(['module_info', 'hidden_groups', 'hid_path_to_group', 'hidden_param_op_relations', 'hidden_perturb'])

    """

    module_info: ModuleInfo
    hidden_groups: Sequence[HiddenGroup]
    hid_path_to_group: Dict[Path, HiddenGroup]
    hidden_param_op_relations: Sequence[HiddenParamOpRelation]
    hidden_perturb: HiddenPerturbation | None

    def call_hidden_perturb(
        self,
        args: Inputs,
        perturb_data: Sequence[jax.Array],
        old_state_vals: Optional[Sequence[jax.Array]] = None,
    ):
        """
        Run the perturbed jaxpr and post-process its outputs through ``module_info``.

        Parameters
        ----------
        args : Inputs
            The model inputs. They are flattened (together with the state
            values) as jaxpr inputs, and also unpacked positionally into
            ``module_info._process``.
        perturb_data : Sequence[jax.Array]
            The perturbation values forwarded to ``hidden_perturb.eval_jaxpr``.
        old_state_vals : Sequence[jax.Array], optional
            State values to feed the jaxpr. Defaults to the current ``.value``
            of every state in ``module_info.compiled_model_states``.

        Returns
        -------
        Whatever ``module_info._process`` returns for these jaxpr outputs.

        Raises
        ------
        ValueError
            If this graph has no hidden perturbation (``hidden_perturb`` is None),
            e.g. because it was compiled with ``include_hidden_perturb=False``.
        """
        if self.hidden_perturb is None:
            raise ValueError(
                'This ETraceGraph has no hidden perturbation (hidden_perturb is None). '
                'Compile it with include_hidden_perturb=True to use "call_hidden_perturb".'
            )

        # state checking
        if old_state_vals is None:
            old_state_vals = [st.value for st in self.module_info.compiled_model_states]

        # calling the function
        jaxpr_outs = self.hidden_perturb.eval_jaxpr(
            jax.tree.leaves((args, old_state_vals)),
            perturb_data,
        )

        return self.module_info._process(*args, jaxpr_outs=jaxpr_outs)

    def dict(self) -> Dict:
        """Return the fields as a dict (field name -> value), in declaration order."""
        return self._asdict()

    def __repr__(self) -> str:
        return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__))


ETraceGraph.__module__ = 'braintrace'
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class CONTEXT(threading.local):
    """
    Per-thread bookkeeping for the eligibility trace compiler.

    Because this subclasses ``threading.local``, each thread that touches an
    instance gets its own ``compilers`` list. Compilations on different
    threads therefore never see each other's entries.
    """

    def __init__(self):
        # Names of the compilers currently active on this thread (a stack).
        self.compilers = list()

    def add_compiler(self, name: str):
        """Push ``name`` onto this thread's stack of active compilers."""
        active = self.compilers
        active.append(name)


# Module-level compiler context; thread-local by construction.
context = CONTEXT()
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@contextmanager
def compiler_context(name: str):
    """
    Register ``name`` as an active eligibility-trace compiler for the duration of a ``with`` block.

    On entry, ``name`` is pushed onto ``context.compilers``. On exit it is
    popped again, including when an exception is raised. If another compiler
    was already active on this thread, the nested use is rejected.

    Args:
        name (str): The name of the compiler to be added to the context.

    Yields:
        None: Nothing is yielded; the context only guards the compilation.

    Raises:
        NotImplementedError: If a recursive call to "compile_graph" is detected.
    """
    try:
        context.add_compiler(name)
        nested = len(context.compilers) > 1
        if nested:
            # Only one compilation at a time per thread is supported.
            raise NotImplementedError(
                'Detected recursive call to "compile_graph". '
                'This is not supported currently.'
            )
        yield
    finally:
        # Undo the push above whether the body succeeded or raised.
        context.compilers.pop()
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def compile_etrace_graph(
    model: brainstate.nn.Module,
    *model_args: Tuple,
    include_hidden_perturb: bool = True,
) -> ETraceGraph:
    """
    Constructs the eligibility trace graph for a given model based on the provided inputs.

    This is the most important method for the eligibility trace graph. It builds the
    graph for the model, tracking the relationship between the etrace weights
    :py:class:`ETraceParam`, the etrace states :py:class:`ETraceState`, and the etrace
    operations :py:class:`ETraceOp`. These relationships are used for computing the
    weight spatial gradients, the hidden state Jacobian, and the hidden state-weight Jacobian.

    Args:
        model (brainstate.nn.Module): The model for which the eligibility trace graph is to be built.
        model_args (Tuple): The arguments required by the model.
        include_hidden_perturb (bool): Indicates whether to include hidden perturbations in the graph.
            Defaults to True.

    Returns:
        ETraceGraph: The compiled eligibility trace graph containing module information, hidden groups,
            hidden parameter operation relations, and optional hidden perturbations.

    Raises:
        NotImplementedError: If called recursively while another compilation is active
            on the same thread (raised by ``compiler_context``).
    """

    with compiler_context('compile_graph'):
        minfo = extract_module_info(model, *model_args)

        # --- evaluating the relationship for hidden-to-hidden --- #
        hidden_groups, hid_path_to_group = find_hidden_groups_from_minfo(minfo)
        order_hidden_group_index(hidden_groups)

        # --- evaluating the jaxpr for (hidden, param, op) relationships --- #
        hidden_param_op_relations = find_hidden_param_op_relations_from_minfo(
            minfo=minfo,
            hid_path_to_group=hid_path_to_group,
        )

        # --- Rewrite the jaxpr for computing the needed variables --- #
        #
        # Rewrite jaxpr to return all necessary variables, including
        #
        # 1. the original function outputs
        # 2. the hidden states
        # 3. the weight x ===> for computing the weight spatial gradients
        # 4. the y-to-hidden variables ===> for computing the weight spatial gradients
        # 5. the hidden-hidden transition variables ===> for computing the hidden-hidden jacobian
        #
        # Items 1 and 2 are already outputs of ``minfo.jaxpr``; only 3-5 are appended.
        # Variables are collected in lists and deduplicated with ``dict.fromkeys``
        # (insertion-ordered) rather than ``set``, because set iteration order depends
        # on element hashes, which are not guaranteed to be stable across runs. This
        # keeps the order of the extra jaxpr outputs deterministic.

        # 3. all weight x
        out_wx_jaxvars = [
            relation.x
            for relation in hidden_param_op_relations
            if relation.x is not None
        ]

        # 4. all y-to-hidden vars
        out_wy2hid_jaxvars = []
        for relation in hidden_param_op_relations:
            for hpo_jaxpr in relation.y_to_hidden_group_jaxprs:
                out_wy2hid_jaxvars.extend(hpo_jaxpr.invars)
                out_wy2hid_jaxvars.extend(hpo_jaxpr.constvars)

        # 5. hidden-hidden transition vars
        hid2hid_jaxvars = []
        for group in hidden_groups:
            hid2hid_jaxvars.extend(group.hidden_invars)
            hid2hid_jaxvars.extend(group.transition_jaxpr_constvars)

        # all temporary outvars: unique, in first-seen order, excluding anything the
        # jaxpr already returns (original outputs and state outputs).
        existing_outvars = set(minfo.jaxpr.outvars)
        temp_outvars = [
            var
            for var in dict.fromkeys(out_wx_jaxvars + out_wy2hid_jaxvars + hid2hid_jaxvars)
            if var not in existing_outvars
        ]

        # rewrite module_info
        minfo = minfo.add_jaxpr_outs(temp_outvars)

        # --- add perturbations to the hidden states --- #
        # --- new jaxpr with hidden state perturbations for computing the residuals --- #
        hidden_perturb = add_hidden_perturbation_from_minfo(minfo) if include_hidden_perturb else None

        # --- return the compiled graph --- #
        return ETraceGraph(
            module_info=minfo,
            hidden_groups=hidden_groups,
            hid_path_to_group=hid_path_to_group,
            hidden_param_op_relations=hidden_param_op_relations,
            hidden_perturb=hidden_perturb,
        )
|