braintrace 0.1.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. braintrace/__init__.py +79 -0
  2. braintrace/_compatible_imports.py +62 -0
  3. braintrace/_compatible_imports_test.py +94 -0
  4. braintrace/_etrace_algorithms.py +333 -0
  5. braintrace/_etrace_compiler_base.py +290 -0
  6. braintrace/_etrace_compiler_graph.py +287 -0
  7. braintrace/_etrace_compiler_graph_test.py +329 -0
  8. braintrace/_etrace_compiler_hid_param_op.py +832 -0
  9. braintrace/_etrace_compiler_hid_param_op_test.py +112 -0
  10. braintrace/_etrace_compiler_hidden_group.py +954 -0
  11. braintrace/_etrace_compiler_hidden_group_test.py +843 -0
  12. braintrace/_etrace_compiler_hidden_pertubation.py +381 -0
  13. braintrace/_etrace_compiler_hidden_pertubation_test.py +126 -0
  14. braintrace/_etrace_compiler_module_info.py +551 -0
  15. braintrace/_etrace_compiler_module_info_test.py +114 -0
  16. braintrace/_etrace_concepts.py +382 -0
  17. braintrace/_etrace_concepts_test.py +159 -0
  18. braintrace/_etrace_debug_jaxpr2code.py +1134 -0
  19. braintrace/_etrace_debug_visualize.py +1561 -0
  20. braintrace/_etrace_graph_executor.py +319 -0
  21. braintrace/_etrace_graph_executor_test.py +67 -0
  22. braintrace/_etrace_input_data.py +203 -0
  23. braintrace/_etrace_input_data_test.py +51 -0
  24. braintrace/_etrace_model_test.py +450 -0
  25. braintrace/_etrace_model_with_group_state.py +267 -0
  26. braintrace/_etrace_operators.py +1072 -0
  27. braintrace/_etrace_operators_test.py +58 -0
  28. braintrace/_etrace_vjp/__init__.py +29 -0
  29. braintrace/_etrace_vjp/base.py +671 -0
  30. braintrace/_etrace_vjp/d_rtrl.py +756 -0
  31. braintrace/_etrace_vjp/d_rtrl_test.py +205 -0
  32. braintrace/_etrace_vjp/esd_rtrl.py +847 -0
  33. braintrace/_etrace_vjp/esd_rtrl_test.py +194 -0
  34. braintrace/_etrace_vjp/graph_executor.py +718 -0
  35. braintrace/_etrace_vjp/graph_executor_test.py +102 -0
  36. braintrace/_etrace_vjp/hybrid.py +604 -0
  37. braintrace/_etrace_vjp/misc.py +162 -0
  38. braintrace/_grad_exponential.py +85 -0
  39. braintrace/_misc.py +403 -0
  40. braintrace/_state_managment.py +436 -0
  41. braintrace/_typing.py +91 -0
  42. braintrace/nn/__init__.py +68 -0
  43. braintrace/nn/_conv.py +395 -0
  44. braintrace/nn/_conv_test.py +868 -0
  45. braintrace/nn/_linear.py +524 -0
  46. braintrace/nn/_linear_test.py +658 -0
  47. braintrace/nn/_normalizations.py +508 -0
  48. braintrace/nn/_normalizations_test.py +695 -0
  49. braintrace/nn/_readout.py +278 -0
  50. braintrace/nn/_readout_test.py +763 -0
  51. braintrace/nn/_rnn.py +1057 -0
  52. braintrace/nn/_rnn_test.py +710 -0
  53. braintrace-0.1.1.dist-info/METADATA +137 -0
  54. braintrace-0.1.1.dist-info/RECORD +57 -0
  55. braintrace-0.1.1.dist-info/WHEEL +6 -0
  56. braintrace-0.1.1.dist-info/licenses/LICENSE +202 -0
  57. braintrace-0.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,290 @@
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Dict, Sequence, Set, List
17
+
18
+ from ._compatible_imports import (
19
+ Var,
20
+ JaxprEqn,
21
+ is_jit_primitive,
22
+ is_scan_primitive,
23
+ is_while_primitive,
24
+ is_cond_primitive,
25
+ )
26
+ from ._etrace_operators import (
27
+ is_etrace_op,
28
+ is_etrace_op_enable_gradient,
29
+ )
30
+ from ._typing import Path
31
+
32
+
33
def find_matched_vars(
    invars: Sequence[Var],
    invar_needed_in_oth_eqns: Set[Var]
) -> List[Var]:
    """
    Collect the input variables that are also required by other equations.

    Only genuine ``Var`` objects are considered; any other entry of ``invars``
    (such as a jaxpr literal) is skipped. The result keeps the order of
    ``invars`` and keeps duplicates if a variable appears more than once.

    Parameters
    ----------
    invars : Sequence[Var]
        The input variables of the equation.
    invar_needed_in_oth_eqns : Set[Var]
        The variables needed in the other equations.

    Returns
    -------
    List[Var]
        The entries of ``invars`` that are ``Var`` instances and members of
        ``invar_needed_in_oth_eqns``.
    """
    return [
        candidate
        for candidate in invars
        if isinstance(candidate, Var) and candidate in invar_needed_in_oth_eqns
    ]
57
+
58
+
59
def find_element_exist_in_the_set(
    elements: Sequence[Var],
    the_set: Set[Var]
) -> Var | None:
    """
    Return the first ``Var`` in ``elements`` that is a member of ``the_set``.

    Entries that are not ``Var`` instances (e.g. jaxpr literals) are never
    reported, even if they would test as members of the set. The scan stops
    at the first hit.

    Parameters
    ----------
    elements : Sequence[Var]
        The variables to scan, in order.
    the_set : Set[Var]
        The set of variables to look for.

    Returns
    -------
    Var | None
        The first matching element, or None if no element is found.
    """
    hits = (
        candidate
        for candidate in elements
        if isinstance(candidate, Var) and candidate in the_set
    )
    return next(hits, None)
82
+
83
+
84
def check_unsupported_op(
    self,
    eqn: JaxprEqn,
    op_name: str
):
    """
    Checks for unsupported operations involving weight or hidden state variables in the given equation.

    This function verifies whether the specified JAX equation (`eqn`) uses weight or hidden state variables
    in a manner that is currently unsupported. If such usage is detected, a `NotImplementedError` is raised
    with a detailed message.

    Parameters
    ----------
    self : JaxprEvaluation
        The evaluator whose ``weight_invars``, ``hidden_outvars``,
        ``invar_to_hidden_path`` and ``outvar_to_hidden_path`` are consulted.
    eqn : JaxprEqn
        The JAX equation to be checked.
    op_name : str
        The name of the operation being checked (e.g., 'jit', 'scan', 'while', 'cond').

    Raises
    ------
    NotImplementedError
        If any input of the equation is a weight variable, or any output of
        the equation is a hidden-state variable.
    """
    # checking whether the weight variables are used in the equation
    invar = find_element_exist_in_the_set(eqn.invars, self.weight_invars)
    if invar is not None:
        # ``invar_to_hidden_path`` maps *hidden-state* inputs to their paths, so a
        # weight variable is not necessarily a key. Fall back to the variable itself
        # so that a KeyError cannot mask the intended NotImplementedError.
        weight_desc = self.invar_to_hidden_path.get(invar, invar)
        raise NotImplementedError(
            f'Currently, we do not support the weight states are used within a {op_name} function. \n'
            f'Please remove your {op_name} on the intermediate steps. \n\n'
            f'The weight state is: {weight_desc}. \n'
            f'The Jaxpr of the {op_name} function is: \n\n'
            f'{eqn} \n\n'
        )

    # checking whether the hidden variables are computed in the equation
    outvar = find_element_exist_in_the_set(eqn.outvars, self.hidden_outvars)
    if outvar is not None:
        # same defensive lookup: report the variable itself if no path is registered
        hidden_desc = self.outvar_to_hidden_path.get(outvar, outvar)
        raise NotImplementedError(
            f'Currently, we do not support the hidden states are computed within a {op_name} function. \n'
            f'Please remove your {op_name} on the intermediate steps. \n\n'
            f'The hidden state is: {hidden_desc}. \n'
            f'The Jaxpr of the {op_name} function is: \n\n'
            f'{eqn} \n\n'
        )
132
+
133
+
134
class JaxprEvaluation(object):
    """
    A base class for evaluating JAX program representations (jaxpr) to extract eligibility trace relationships.

    This class analyzes the computational graph represented as JAX primitives to identify and track
    relationships between weight parameters and hidden states for eligibility trace computation.
    Subclasses must implement the `_eval_eqn` method to define specific evaluation behavior.

    The class handles special JAX primitives such as pjit, scan, while, and cond operations,
    providing appropriate handling or restrictions for eligibility trace compilation.

    Parameters
    ----------
    weight_invars : Set[Var]
        Input variables representing weight parameters in the computational graph.
    hidden_invars : Set[Var]
        Input variables representing hidden states in the computational graph.
    hidden_outvars : Set[Var]
        Output variables representing hidden states in the computational graph.
    invar_to_hidden_path : Dict[Var, Path]
        Mapping from input variables to their paths in the hidden state hierarchy.
    outvar_to_hidden_path : Dict[Var, Path]
        Mapping from output variables to their paths in the hidden state hierarchy.

    Attributes
    ----------
    weight_invars : Set[Var]
        Stored input weight variables.
    hidden_invars : Set[Var]
        Stored input hidden state variables.
    hidden_outvars : Set[Var]
        Stored output hidden state variables.
    invar_to_hidden_path : Dict[Var, Path]
        Stored mapping from input variables to hidden paths.
    outvar_to_hidden_path : Dict[Var, Path]
        Stored mapping from output variables to hidden paths.
    """
    __module__ = 'braintrace'

    def __init__(
        self,
        weight_invars: Set[Var],
        hidden_invars: Set[Var],
        hidden_outvars: Set[Var],
        invar_to_hidden_path: Dict[Var, Path],
        outvar_to_hidden_path: Dict[Var, Path],
    ):
        self.weight_invars = weight_invars
        self.hidden_invars = hidden_invars
        self.hidden_outvars = hidden_outvars
        self.invar_to_hidden_path = invar_to_hidden_path
        self.outvar_to_hidden_path = outvar_to_hidden_path

    def _eval_jaxpr(self, jaxpr) -> None:
        """
        Evaluating the jaxpr for extracting the etrace relationships.

        Each equation of ``jaxpr.eqns`` is dispatched, in order, to the handler
        for its primitive kind (jit, scan, while, cond); every other equation
        goes straight to :meth:`_eval_eqn`.

        Parameters
        ----------
        jaxpr : Jaxpr
            The jaxpr for the model.
        """
        for eqn in jaxpr.eqns:
            # TODO: add the support for the scan, while, cond, pjit, and other operators
            # Currently, scan, while, and cond are usually not the common operators used in
            # the definition of a brain dynamics model. So we may not need to consider them
            # during the current phase.
            # However, for the long-term maintenance and development, we need to consider them,
            # since users usually create crazy models.

            if is_jit_primitive(eqn):
                self._eval_pjit(eqn)
            elif is_scan_primitive(eqn):
                self._eval_scan(eqn)
            elif is_while_primitive(eqn):
                self._eval_while(eqn)
            elif is_cond_primitive(eqn):
                self._eval_cond(eqn)
            else:
                self._eval_eqn(eqn)

    def _eval_pjit(self, eqn: JaxprEqn) -> None:
        """
        Evaluating the pjit primitive.

        If the jit call is a registered etrace operator (looked up by
        ``eqn.params['name']``), it is passed to :meth:`_eval_eqn` only when its
        gradient is enabled; otherwise it is skipped. Any other jit call is
        first validated by :func:`check_unsupported_op` and then treated as an
        ordinary equation.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.
        """
        op_name = eqn.params['name']
        if is_etrace_op(op_name):
            if is_etrace_op_enable_gradient(op_name):
                self._eval_eqn(eqn)
            return

        check_unsupported_op(self, eqn, 'jit')

        # treat the pjit as a normal jaxpr equation
        self._eval_eqn(eqn)

    def _eval_scan(self, eqn: JaxprEqn) -> None:
        """
        Evaluating the scan primitive.

        The equation is validated by :func:`check_unsupported_op` (reported as
        a ``scan``) and then handled as an ordinary equation.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.
        """
        # the label must match the primitive; it was previously reported as 'while'
        check_unsupported_op(self, eqn, 'scan')
        self._eval_eqn(eqn)

    def _eval_while(self, eqn: JaxprEqn) -> None:
        """
        Evaluating the while primitive.

        The equation is validated by :func:`check_unsupported_op` (reported as
        a ``while``) and then handled as an ordinary equation.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.
        """
        # the label must match the primitive; it was previously reported as 'scan'
        check_unsupported_op(self, eqn, 'while')
        self._eval_eqn(eqn)

    def _eval_cond(self, eqn: JaxprEqn) -> None:
        """
        Evaluating the cond primitive.

        The equation is validated by :func:`check_unsupported_op` (reported as
        a ``cond``) and then handled as an ordinary equation.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.
        """
        check_unsupported_op(self, eqn, 'cond')
        self._eval_eqn(eqn)

    def _eval_eqn(self, eqn):
        """
        Evaluate a single JAX equation.

        This method must be implemented by subclasses to define specific
        evaluation behavior for equations.

        Parameters
        ----------
        eqn : JaxprEqn
            The JAX equation to evaluate.

        Raises
        ------
        NotImplementedError
            This method must be implemented in subclasses.
        """
        raise NotImplementedError(
            'The method "_eval_eqn" should be implemented in the subclass.'
        )
@@ -0,0 +1,287 @@
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import threading
18
+ from contextlib import contextmanager
19
+ from typing import Dict, Sequence, Tuple, Optional, NamedTuple
20
+
21
+ import brainstate
22
+ import jax
23
+
24
+ from ._etrace_compiler_hid_param_op import (
25
+ find_hidden_param_op_relations_from_minfo,
26
+ HiddenParamOpRelation,
27
+ )
28
+ from ._etrace_compiler_hidden_group import (
29
+ find_hidden_groups_from_minfo,
30
+ HiddenGroup,
31
+ )
32
+ from ._etrace_compiler_hidden_pertubation import (
33
+ add_hidden_perturbation_from_minfo,
34
+ HiddenPerturbation,
35
+ )
36
+ from ._etrace_compiler_module_info import (
37
+ extract_module_info,
38
+ ModuleInfo,
39
+ )
40
+ from ._typing import (
41
+ Inputs,
42
+ Path,
43
+ )
44
+
45
# Names exported by ``from ... import *``: the compiled-graph container and its builder.
__all__ = ['ETraceGraph', 'compile_etrace_graph']
49
+
50
+
51
def order_hidden_group_index(
    hidden_groups: Sequence[HiddenGroup],
):
    """
    Verifies that hidden group indices match their positions in the sequence.

    This function ensures that the index attribute of each HiddenGroup in the sequence
    matches its position in the sequence. This validation is important for maintaining
    the correct ordering of hidden groups in the eligibility trace compilation process.
    Despite its name, the function does not reorder anything; it only checks.

    Args:
        hidden_groups (Sequence[HiddenGroup]): A sequence of HiddenGroup objects to validate.

    Raises:
        AssertionError: If any hidden group's index doesn't match its position in the sequence.
    """
    for position, group in enumerate(hidden_groups):
        if group.index != position:
            # An explicit raise (instead of ``assert``) keeps this check active
            # under ``python -O``, while preserving the documented exception type.
            raise AssertionError(
                f"Hidden group index {group.index} should be equal to its position {position}."
            )
69
+
70
+
71
class ETraceGraph(NamedTuple):
    """
    The overall compiled graph for the eligibility trace.

    The eligibility trace graph, tracking the relationship between the etrace weights
    :py:class:`ETraceParam`, the etrace variables :py:class:`ETraceState`, and the etrace
    operations :py:class:`ETraceOp`.

    The following fields are included:

    - ``module_info``: The model information, instance of :class:`ModuleInfo`.
    - ``hidden_groups``: The hidden groups, sequence of :class:`HiddenGroup`.
    - ``hid_path_to_group``: The mapping from the hidden path to the hidden group :class:`HiddenGroup`.
    - ``hidden_param_op_relations``: The hidden parameter operation relations, sequence of :class:`HiddenParamOpRelation`.
    - ``hidden_perturb``: The hidden perturbation, instance of :class:`HiddenPerturbation`, or None.

    Example::

        >>> import braintrace
        >>> import brainstate
        >>> gru = braintrace.nn.GRUCell(10, 20)
        >>> gru.init_state()
        >>> inputs = brainstate.random.randn(10)
        >>> compiled_graph = braintrace.compile_etrace_graph(gru, inputs)
        >>> compiled_graph.dict().keys()
        dict_keys(['module_info', 'hidden_groups', 'hid_path_to_group', 'hidden_param_op_relations', 'hidden_perturb'])

    """

    module_info: ModuleInfo
    hidden_groups: Sequence[HiddenGroup]
    hid_path_to_group: Dict[Path, HiddenGroup]
    hidden_param_op_relations: Sequence[HiddenParamOpRelation]
    hidden_perturb: HiddenPerturbation | None

    def call_hidden_perturb(
        self,
        args: Inputs,
        perturb_data: Sequence[jax.Array],
        old_state_vals: Optional[Sequence[jax.Array]] = None,
    ):
        """
        Run the model's perturbed jaxpr and post-process its outputs.

        Parameters
        ----------
        args : Inputs
            The model inputs. They are flattened (together with the state values)
            for the jaxpr call, and unpacked as ``*args`` when calling
            ``module_info._process``.
        perturb_data : Sequence[jax.Array]
            The perturbation values forwarded to ``hidden_perturb.eval_jaxpr``.
        old_state_vals : Sequence[jax.Array], optional
            The state values to feed into the jaxpr. If omitted, the current
            ``.value`` of every state in ``module_info.compiled_model_states``
            is used.

        Returns
        -------
        Whatever ``module_info._process`` returns for these jaxpr outputs.

        Raises
        ------
        ValueError
            If the graph was compiled without a hidden perturbation
            (``hidden_perturb`` is None, e.g. ``include_hidden_perturb=False``).
        """
        # Fail with an explicit message instead of an opaque AttributeError on None.
        if self.hidden_perturb is None:
            raise ValueError(
                'This ETraceGraph has no hidden perturbation. '
                'Compile it with "include_hidden_perturb=True" to use "call_hidden_perturb".'
            )

        # state checking
        if old_state_vals is None:
            old_state_vals = [st.value for st in self.module_info.compiled_model_states]

        # calling the function
        jaxpr_outs = self.hidden_perturb.eval_jaxpr(
            jax.tree.leaves((args, old_state_vals)),
            perturb_data,
        )

        return self.module_info._process(*args, jaxpr_outs=jaxpr_outs)

    def dict(self) -> Dict:
        """Return the fields of this graph as a ``dict`` (same as ``_asdict()``)."""
        return self._asdict()

    def __repr__(self) -> str:
        return repr(brainstate.util.PrettyMapping(self._asdict(), type_name=self.__class__.__name__))
128
+
129
+
130
# Present ETraceGraph as belonging to the top-level ``braintrace`` package.
setattr(ETraceGraph, '__module__', 'braintrace')
131
+
132
+
133
class CONTEXT(threading.local):
    """
    Thread-local bookkeeping for the eligibility trace compiler.

    Because this class derives from :class:`threading.local`, each thread that
    touches an instance gets its own ``compilers`` list. Compilations running
    in different threads therefore never see each other's entries.
    """

    def __init__(self):
        # names of the compilers currently active in this thread, most recent last
        self.compilers = list()

    def add_compiler(self, name: str):
        """Push ``name`` onto this thread's stack of active compilers."""
        active = self.compilers
        active.append(name)
146
+
147
+
148
# Single shared handle; its attributes are per-thread because CONTEXT derives from threading.local.
context: CONTEXT = CONTEXT()
149
+
150
+
151
@contextmanager
def compiler_context(name: str):
    """
    Context manager that marks an eligibility trace compilation as active.

    ``name`` is pushed onto the thread-local ``context.compilers`` stack on entry
    and popped again on exit, including when an error occurs. If another
    compilation is already active in the same thread, entering raises instead
    of yielding.

    Args:
        name (str): The name of the compiler to register while the block runs.

    Yields:
        None: Nothing is bound by ``with ... as``.

    Raises:
        NotImplementedError: If a recursive call to "compile_graph" is detected.
    """
    try:
        context.add_compiler(name)
        # Exactly one active entry is allowed; anything more means a nested compile.
        if len(context.compilers) <= 1:
            yield
        else:
            raise NotImplementedError(
                'Detected recursive call to "compile_graph". '
                'This is not supported currently.'
            )
    finally:
        context.compilers.pop()
182
+
183
+
184
def _ordered_unique(items):
    """Return the distinct elements of ``items`` as a list, in first-seen order."""
    return list(dict.fromkeys(items))


def compile_etrace_graph(
    model: brainstate.nn.Module,
    *model_args: Tuple,
    include_hidden_perturb: bool = True,
) -> ETraceGraph:
    """
    Constructs the eligibility trace graph for a given model based on the provided inputs.

    This is the most important method for the eligibility trace graph. It builds the
    graph for the model, tracking the relationship between the etrace weights
    :py:class:`ETraceParam`, the etrace states :py:class:`ETraceState`, and the etrace
    operations :py:class:`ETraceOp`, which will be used for computing the weight
    spatial gradients, the hidden state Jacobian, and the hidden state-weight Jacobian.

    Args:
        model (brainstate.nn.Module): The model for which the eligibility trace graph is to be built.
        model_args (Tuple): The arguments required by the model.
        include_hidden_perturb (bool): Indicates whether to include hidden perturbations in the graph.
            Defaults to True. When False, ``hidden_perturb`` of the result is None.

    Returns:
        ETraceGraph: The compiled eligibility trace graph containing module information, hidden groups,
        hidden parameter operation relations, and optional hidden perturbations.

    Raises:
        NotImplementedError: If called recursively from within another graph compilation.
    """

    with compiler_context('compile_graph'):

        assert isinstance(model_args, tuple)
        minfo = extract_module_info(model, *model_args)

        # --- evaluating the relationship for hidden-to-hidden --- #
        hidden_groups, hid_path_to_group = find_hidden_groups_from_minfo(minfo)
        order_hidden_group_index(hidden_groups)

        # --- evaluating the jaxpr for (hidden, param, op) relationships --- #
        hidden_param_op_relations = find_hidden_param_op_relations_from_minfo(
            minfo=minfo,
            hid_path_to_group=hid_path_to_group,
        )

        # --- Rewrite the jaxpr for computing the needed variables --- #
        #
        # Rewrite jaxpr to return all necessary variables, including
        #
        # 1. the original function outputs
        # 2. the hidden states
        # 3. the weight x ===> for computing the weight spatial gradients
        # 4. the y-to-hidden variables ===> for computing the weight spatial gradients
        # 5. the hidden-hidden transition variables ===> for computing the hidden-hidden jacobian
        #
        # De-duplication keeps first-seen order (insertion-ordered dict) instead of
        # using ``set``: set iteration order depends on the variables' hashes, which
        # would make the order of the extra jaxpr outputs vary between runs.

        # all weight x
        out_wx_jaxvars = _ordered_unique(
            relation.x
            for relation in hidden_param_op_relations
            if relation.x is not None
        )

        # all y-to-hidden vars
        wy2hid_candidates = []
        for relation in hidden_param_op_relations:
            for hpo_jaxpr in relation.y_to_hidden_group_jaxprs:
                wy2hid_candidates.extend(hpo_jaxpr.invars)
                wy2hid_candidates.extend(hpo_jaxpr.constvars)
        out_wy2hid_jaxvars = _ordered_unique(wy2hid_candidates)

        # hidden-hidden transition vars
        hid2hid_candidates = []
        for group in hidden_groups:
            hid2hid_candidates.extend(group.hidden_invars)
            hid2hid_candidates.extend(group.transition_jaxpr_constvars)
        hid2hid_jaxvars = _ordered_unique(hid2hid_candidates)

        # all temporary outvars: every variable collected above that is not already an
        # output of the jaxpr. (The state outputs ``minfo.jaxpr.outvars[minfo.num_var_out:]``
        # are themselves jaxpr outputs, so they would always be excluded here anyway.)
        existing_outvars = set(minfo.jaxpr.outvars)
        temp_outvars = [
            var
            for var in _ordered_unique(out_wx_jaxvars + out_wy2hid_jaxvars + hid2hid_jaxvars)
            if var not in existing_outvars
        ]

        # rewrite module_info
        minfo = minfo.add_jaxpr_outs(temp_outvars)

        # --- add perturbations to the hidden states --- #
        # --- new jaxpr with hidden state perturbations for computing the residuals --- #
        hidden_perturb = add_hidden_perturbation_from_minfo(minfo) if include_hidden_perturb else None

        # --- return the compiled graph --- #
        return ETraceGraph(
            module_info=minfo,
            hidden_groups=hidden_groups,
            hid_path_to_group=hid_path_to_group,
            hidden_param_op_relations=hidden_param_op_relations,
            hidden_perturb=hidden_perturb,
        )