braintrace 0.1.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. braintrace/__init__.py +79 -0
  2. braintrace/_compatible_imports.py +62 -0
  3. braintrace/_compatible_imports_test.py +94 -0
  4. braintrace/_etrace_algorithms.py +333 -0
  5. braintrace/_etrace_compiler_base.py +290 -0
  6. braintrace/_etrace_compiler_graph.py +287 -0
  7. braintrace/_etrace_compiler_graph_test.py +329 -0
  8. braintrace/_etrace_compiler_hid_param_op.py +832 -0
  9. braintrace/_etrace_compiler_hid_param_op_test.py +112 -0
  10. braintrace/_etrace_compiler_hidden_group.py +954 -0
  11. braintrace/_etrace_compiler_hidden_group_test.py +843 -0
  12. braintrace/_etrace_compiler_hidden_pertubation.py +381 -0
  13. braintrace/_etrace_compiler_hidden_pertubation_test.py +126 -0
  14. braintrace/_etrace_compiler_module_info.py +551 -0
  15. braintrace/_etrace_compiler_module_info_test.py +114 -0
  16. braintrace/_etrace_concepts.py +382 -0
  17. braintrace/_etrace_concepts_test.py +159 -0
  18. braintrace/_etrace_debug_jaxpr2code.py +1134 -0
  19. braintrace/_etrace_debug_visualize.py +1561 -0
  20. braintrace/_etrace_graph_executor.py +319 -0
  21. braintrace/_etrace_graph_executor_test.py +67 -0
  22. braintrace/_etrace_input_data.py +203 -0
  23. braintrace/_etrace_input_data_test.py +51 -0
  24. braintrace/_etrace_model_test.py +450 -0
  25. braintrace/_etrace_model_with_group_state.py +267 -0
  26. braintrace/_etrace_operators.py +1072 -0
  27. braintrace/_etrace_operators_test.py +58 -0
  28. braintrace/_etrace_vjp/__init__.py +29 -0
  29. braintrace/_etrace_vjp/base.py +671 -0
  30. braintrace/_etrace_vjp/d_rtrl.py +756 -0
  31. braintrace/_etrace_vjp/d_rtrl_test.py +205 -0
  32. braintrace/_etrace_vjp/esd_rtrl.py +847 -0
  33. braintrace/_etrace_vjp/esd_rtrl_test.py +194 -0
  34. braintrace/_etrace_vjp/graph_executor.py +718 -0
  35. braintrace/_etrace_vjp/graph_executor_test.py +102 -0
  36. braintrace/_etrace_vjp/hybrid.py +604 -0
  37. braintrace/_etrace_vjp/misc.py +162 -0
  38. braintrace/_grad_exponential.py +85 -0
  39. braintrace/_misc.py +403 -0
  40. braintrace/_state_managment.py +436 -0
  41. braintrace/_typing.py +91 -0
  42. braintrace/nn/__init__.py +68 -0
  43. braintrace/nn/_conv.py +395 -0
  44. braintrace/nn/_conv_test.py +868 -0
  45. braintrace/nn/_linear.py +524 -0
  46. braintrace/nn/_linear_test.py +658 -0
  47. braintrace/nn/_normalizations.py +508 -0
  48. braintrace/nn/_normalizations_test.py +695 -0
  49. braintrace/nn/_readout.py +278 -0
  50. braintrace/nn/_readout_test.py +763 -0
  51. braintrace/nn/_rnn.py +1057 -0
  52. braintrace/nn/_rnn_test.py +710 -0
  53. braintrace-0.1.1.dist-info/METADATA +137 -0
  54. braintrace-0.1.1.dist-info/RECORD +57 -0
  55. braintrace-0.1.1.dist-info/WHEEL +6 -0
  56. braintrace-0.1.1.dist-info/licenses/LICENSE +202 -0
  57. braintrace-0.1.1.dist-info/top_level.txt +1 -0
braintrace/__init__.py ADDED
@@ -0,0 +1,79 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ __version__ = "0.1.1"
20
+ __version_info__ = (0, 1, 1)
21
+
22
+ from braintrace._etrace_algorithms import *
23
+ from braintrace._etrace_algorithms import __all__ as _alg_all
24
+ from braintrace._etrace_compiler_graph import *
25
+ from braintrace._etrace_compiler_graph import __all__ as _compiler_all
26
+ from braintrace._etrace_compiler_hid_param_op import *
27
+ from braintrace._etrace_compiler_hid_param_op import __all__ as _hid_param_all
28
+ from braintrace._etrace_compiler_hidden_group import *
29
+ from braintrace._etrace_compiler_hidden_group import __all__ as _hid_group_all
30
+ from braintrace._etrace_compiler_hidden_pertubation import *
31
+ from braintrace._etrace_compiler_hidden_pertubation import __all__ as _hid_pertub_all
32
+ from braintrace._etrace_compiler_module_info import *
33
+ from braintrace._etrace_compiler_module_info import __all__ as _mod_info_all
34
+ from braintrace._etrace_concepts import *
35
+ from braintrace._etrace_concepts import __all__ as _con_all
36
+ from braintrace._etrace_graph_executor import *
37
+ from braintrace._etrace_graph_executor import __all__ as _exec_all
38
+ from braintrace._etrace_input_data import *
39
+ from braintrace._etrace_input_data import __all__ as _data_all
40
+ from braintrace._etrace_operators import *
41
+ from braintrace._etrace_operators import __all__ as _op_all
42
+ from braintrace._etrace_vjp import *
43
+ from braintrace._etrace_vjp import __all__ as _vjp_all
44
+ from braintrace._grad_exponential import *
45
+ from braintrace._grad_exponential import __all__ as _grad_exp_all
46
+ from braintrace._misc import *
47
+ from braintrace._misc import __all__ as _misc_all
48
+ from . import nn
49
+
50
# Public API of the package: the ``nn`` subpackage followed by every name that
# the private implementation modules export through their own ``__all__``.
__all__ = (
    ['nn']
    + _alg_all
    + _compiler_all
    + _hid_param_all
    + _hid_group_all
    + _hid_pertub_all
    + _mod_info_all
    + _con_all
    + _exec_all
    + _data_all
    + _op_all
    + _vjp_all
    + _grad_exp_all
    + _misc_all
)

# The per-module ``__all__`` aliases were only needed to build the list above;
# remove them so they do not remain as attributes of the package.
del (
    _alg_all,
    _compiler_all,
    _hid_param_all,
    _hid_group_all,
    _hid_pertub_all,
    _mod_info_all,
    _con_all,
    _exec_all,
    _data_all,
    _op_all,
    _vjp_all,
    _grad_exp_all,
    _misc_all,
)
58
+
59
+
60
def __getattr__(name):
    """
    Resolve deprecated module-level names lazily (module ``__getattr__``, PEP 562).

    ``ETraceState``, ``ETraceGroupState`` and ``ETraceTreeState`` are forwarded
    to ``brainstate.HiddenState``, ``brainstate.HiddenGroupState`` and
    ``brainstate.HiddenTreeState`` respectively, and a ``DeprecationWarning``
    is issued for the access.

    Parameters
    ----------
    name : str
        The attribute name that normal module lookup failed to find.

    Returns
    -------
    Any
        The ``brainstate`` attribute that replaces the deprecated name.

    Raises
    ------
    AttributeError
        If ``name`` is not one of the deprecated aliases. The message follows
        the interpreter's usual ``module '...' has no attribute '...'`` form.
    """
    deprecated_aliases = {
        'ETraceState': 'HiddenState',
        'ETraceGroupState': 'HiddenGroupState',
        'ETraceTreeState': 'HiddenTreeState',
    }

    try:
        replacement = deprecated_aliases[name]
    except KeyError:
        # Name the module as well as the attribute so the error reads like a
        # regular failed attribute lookup instead of a bare identifier.
        raise AttributeError(
            f"module {__name__!r} has no attribute {name!r}"
        ) from None

    # Only imported when a deprecated alias is actually requested.
    import warnings
    import brainstate

    warnings.warn(
        f"braintrace.{name} is deprecated and will be removed in a future release. "
        f"Please use brainstate.{replacement} instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    return getattr(brainstate, replacement)
braintrace/_compatible_imports.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import jax
17
+
18
+ __all__ = [
19
+ 'Primitive',
20
+ 'Var',
21
+ 'JaxprEqn',
22
+ 'Jaxpr',
23
+ 'ClosedJaxpr',
24
+ 'Literal',
25
+ 'new_var',
26
+ 'is_jit_primitive',
27
+ 'is_scan_primitive',
28
+ 'is_while_primitive',
29
+ 'is_cond_primitive',
30
+ ]
31
+
32
+ if jax.__version_info__ < (0, 4, 38):
33
+ from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
34
+
35
+ else:
36
+ from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
37
+
38
+
39
def new_var(suffix, aval):
    """
    Build a ``Var`` with the constructor signature of the installed JAX.

    Parameters
    ----------
    suffix
        Passed as the first positional argument of ``Var`` when the installed
        JAX is older than 0.6.2; not used otherwise.
    aval
        Passed to ``Var`` in both cases.

    Returns
    -------
    Var
        The newly constructed variable.
    """
    takes_suffix = jax.__version_info__ < (0, 6, 2)
    ctor_args = (suffix, aval) if takes_suffix else (aval,)
    return Var(*ctor_args)
44
+
45
+
46
def is_jit_primitive(eqn: JaxprEqn) -> bool:
    """
    Tell whether a jaxpr equation applies the jit primitive.

    The primitive name compared against depends on the installed JAX:
    ``'pjit'`` before 0.7.0 and ``'jit'`` from 0.7.0 onwards.

    Parameters
    ----------
    eqn : JaxprEqn
        The equation whose ``primitive.name`` is inspected.

    Returns
    -------
    bool
        ``True`` when the equation's primitive name matches.
    """
    expected_name = 'pjit' if jax.__version_info__ < (0, 7, 0) else 'jit'
    return eqn.primitive.name == expected_name
51
+
52
+
53
def is_scan_primitive(eqn: JaxprEqn) -> bool:
    """Return ``True`` when the primitive of ``eqn`` is named ``'scan'``."""
    primitive_name = eqn.primitive.name
    return primitive_name == 'scan'
55
+
56
+
57
def is_while_primitive(eqn: JaxprEqn) -> bool:
    """Return ``True`` when the primitive of ``eqn`` is named ``'while'``."""
    primitive_name = eqn.primitive.name
    return primitive_name == 'while'
59
+
60
+
61
def is_cond_primitive(eqn: JaxprEqn) -> bool:
    """Return ``True`` when the primitive of ``eqn`` is named ``'cond'``."""
    primitive_name = eqn.primitive.name
    return primitive_name == 'cond'
braintrace/_compatible_imports_test.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import jax.numpy as jnp
17
+ from jax import jit, make_jaxpr, lax
18
+
19
+ from braintrace._compatible_imports import (
20
+ is_jit_primitive, is_scan_primitive, is_while_primitive,
21
+ is_cond_primitive
22
+ )
23
+
24
+
25
class TestPrimitive:
    """Check each ``is_*_primitive`` helper against a jaxpr traced with ``make_jaxpr``."""

    def test_jit(self):
        """The first equation traced from a ``jit``-wrapped function is a jit equation."""

        @jit
        def jit_function(x, y):
            return x ** 2 + jnp.sin(y)

        # The function is already jitted, so the traced jaxpr is expected to
        # start with the jit call rather than the arithmetic inside it.
        jaxpr_jit = make_jaxpr(jit_function)(2.0, 1.0)
        assert is_jit_primitive(jaxpr_jit.eqns[0])

    def test_scan(self):
        """The first equation traced from ``lax.scan`` is a scan equation."""

        def scan_step(carry, x):
            return carry + x, carry * x

        def scan_function(init, xs):
            return lax.scan(scan_step, init, xs)

        # Create sample data
        init_val = 1.0
        xs = jnp.array([1.0, 2.0, 3.0, 4.0])

        jaxpr_scan = make_jaxpr(scan_function)(init_val, xs)
        assert is_scan_primitive(jaxpr_scan.eqns[0])

    def test_while(self):
        """The first equation traced from ``lax.while_loop`` is a while equation."""

        def while_cond(carry):
            i, x = carry
            return i < 5

        def while_body(carry):
            i, x = carry
            return i + 1, x * 2

        def while_function(init_carry):
            return lax.while_loop(while_cond, while_body, init_carry)

        init_carry = (0, 1.0)
        jaxpr_while = make_jaxpr(while_function)(init_carry)
        assert is_while_primitive(jaxpr_while.eqns[0])

    def test_cond(self):
        """The last equation traced from ``lax.cond`` is a cond equation."""

        def true_branch(x):
            return x * 2

        def false_branch(x):
            return x + 1

        def cond_function(pred, x):
            return lax.cond(pred, true_branch, false_branch, x)

        # The last equation is checked, not the first: equations may precede
        # the cond itself in the traced jaxpr.
        jaxpr_cond = make_jaxpr(cond_function)(True, 5.0)
        assert is_cond_primitive(jaxpr_cond.eqns[-1])

    def test_switch(self):
        """The last equation traced from ``lax.switch`` is also a cond equation."""

        def branch_0(x):
            return x * 2

        def branch_1(x):
            return x + 10

        def branch_2(x):
            return x ** 2

        def switch_function(index, x):
            return lax.switch(index, [branch_0, branch_1, branch_2], x)

        jaxpr_switch = make_jaxpr(switch_function)(1, 3.0)
        assert is_cond_primitive(jaxpr_switch.eqns[-1])
braintrace/_etrace_algorithms.py ADDED
@@ -0,0 +1,333 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Author: Chaoming Wang <chao.brain@qq.com>
16
+ # Date: 2024-04-03
17
+ # Copyright: 2024, Chaoming Wang
18
+ # ==============================================================================
19
+
20
+ # -*- coding: utf-8 -*-
21
+
22
+ from typing import Dict, Any, Optional
23
+
24
+ import brainstate
25
+
26
+ from ._etrace_compiler_graph import ETraceGraph
27
+ from ._etrace_graph_executor import ETraceGraphExecutor
28
+ from ._typing import Path
29
+
30
+ __all__ = [
31
+ 'ETraceAlgorithm',
32
+ 'EligibilityTrace',
33
+ ]
34
+
35
+
36
class EligibilityTrace(brainstate.ShortTermState):
    """
    The state for storing the eligibility trace during the computation of
    online learning algorithms.

    Examples
    --------
    When you are using :class:`braintrace.IODimVjpAlgorithm`, you can get
    the eligibility trace of the weight by calling:

    .. code-block:: python

       >>> etrace = etrace_algorithm.get_etrace_of(weight)

    """
    # Report the public package, not this private module, as the class's module.
    __module__ = 'braintrace'
52
+
53
+
54
class ETraceAlgorithm(brainstate.nn.Module):
    r"""
    The base class for the eligibility trace algorithm.

    Subclasses must implement :meth:`update`, :meth:`init_etrace_state` and
    :meth:`get_etrace_of`; the versions defined here raise
    ``NotImplementedError``.

    Parameters
    ----------
    model : brainstate.nn.Module
        The model function, which receives the input arguments and returns the model output.
    graph_executor : ETraceGraphExecutor
        The executor used to compile the etrace graph (``compile_graph``) and
        to access it (``graph``).
    name : str, optional
        The name of the etrace algorithm.

    Attributes
    ----------
    model4compile : brainstate.nn.Module
        The model given at construction.
    graph_executor : ETraceGraphExecutor
        The etrace graph executor, also exposed as ``executor``.
    graph : ETraceGraph
        The etrace graph, read from ``graph_executor.graph``.
    param_states : brainstate.util.FlattedDict[Path, brainstate.ParamState]
        The weight states.
    hidden_states : brainstate.util.FlattedDict[Path, brainstate.HiddenState]
        The hidden states.
    other_states : brainstate.util.FlattedDict[Path, brainstate.State]
        The other states.
    is_compiled : bool
        Whether the etrace algorithm has been compiled.
    running_index : brainstate.LongTermState
        The running index, initialized to ``0``.

    Raises
    ------
    ValueError
        If ``model`` is not a ``brainstate.nn.Module`` or ``graph_executor``
        is not an ``ETraceGraphExecutor``.
    """
    # Report the public package, not this private module, as the class's module.
    __module__ = 'braintrace'

    def __init__(
        self,
        model: brainstate.nn.Module,
        graph_executor: ETraceGraphExecutor,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # the model
        if not isinstance(model, brainstate.nn.Module):
            raise ValueError(
                'The model should be a brainstate.nn.Module, this can help us to '
                f'better obtain the program structure. But we got {type(model)}.'
            )
        self.model4compile = model

        # the graph
        if not isinstance(graph_executor, ETraceGraphExecutor):
            raise ValueError(
                'The graph should be a ETraceGraphExecutor, this can help us to '
                f'better obtain the program structure. But we got {type(graph_executor)}.'
            )
        self.graph_executor = graph_executor

        # The flag to indicate whether the etrace algorithm has been compiled
        self.is_compiled = False

        # the running index
        self.running_index = brainstate.LongTermState(0)

        # State partitions; left as ``None`` until a property first asks for
        # them, at which point ``_split_state`` fills in all three at once.
        self._param_states = None
        self._hidden_states = None
        self._other_states = None

    @property
    def graph(self) -> ETraceGraph:
        """
        Get the etrace graph.

        Returns
        -------
        ETraceGraph
            The etrace graph held by the graph executor.
        """
        return self.graph_executor.graph

    @property
    def executor(self) -> ETraceGraphExecutor:
        """
        Get the etrace graph executor.

        Returns
        -------
        ETraceGraphExecutor
            The etrace graph executor.
        """
        return self.graph_executor

    @property
    def param_states(self) -> brainstate.util.FlattedDict[Path, brainstate.ParamState]:
        """
        Get the parameter weight states.

        The states are split on first access and cached afterwards.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.ParamState]
            The parameter weight states.
        """
        if self._param_states is None:
            self._split_state()
        return self._param_states

    @property
    def hidden_states(self) -> brainstate.util.FlattedDict[Path, brainstate.HiddenState]:
        """
        Get the hidden states.

        The states are split on first access and cached afterwards.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.HiddenState]
            The hidden states.
        """
        if self._hidden_states is None:
            self._split_state()
        return self._hidden_states

    @property
    def other_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
        """
        Get the other states.

        The states are split on first access and cached afterwards.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.State]
            The other states.
        """
        if self._other_states is None:
            self._split_state()
        return self._other_states

    def _split_state(self):
        """
        Partition the model states retrieved by the graph into three groups.

        Fills ``_param_states``, ``_hidden_states`` and ``_other_states`` from
        ``self.graph.module_info.retrieved_model_states``.
        """
        # --- the state separation --- #
        #
        # [NOTE]
        #
        # `ETraceAlgorithm` depends on the states created in the
        # `ETraceGraphExecutor`, including:
        #
        # - the weight states, which are invariant during the training process
        # - the hidden states, the recurrent states, which may be changed between different training epochs
        # - the other states, which may be changed between different training epochs
        retrieved_states = self.graph.module_info.retrieved_model_states
        (
            self._param_states,
            self._hidden_states,
            self._other_states
        ) = retrieved_states.split(brainstate.ParamState, brainstate.HiddenState, ...)

    def compile_graph(self, *args) -> None:
        r"""
        Compile the eligibility trace graph of the relationship between etrace weights, states and operators.

        The compilation process includes:

        - building the etrace graph
        - separating the states
        - initializing the etrace states

        Calling this method again after a successful compilation does nothing.

        Parameters
        ----------
        *args
            The input arguments.
        """
        if self.is_compiled:
            return

        # --- the model etrace graph -- #
        self.graph_executor.compile_graph(*args)

        # --- the initialization of the states --- #
        self.init_etrace_state(*args)

        # mark the graph is compiled
        self.is_compiled = True

    @property
    def path_to_states(self) -> brainstate.util.FlattedDict[Path, brainstate.State]:
        """
        Get the path to the states.

        Returns
        -------
        brainstate.util.FlattedDict[Path, brainstate.State]
            The mapping from path to states.
        """
        return self.graph_executor.path_to_states

    @property
    def state_id_to_path(self) -> Dict[int, Path]:
        """
        Get the state ID to the path.

        Returns
        -------
        Dict[int, Path]
            The mapping from state ID to path.
        """
        return self.graph_executor.state_id_to_path

    def show_graph(self) -> None:
        """
        Show the etrace graph.

        Delegates to ``graph_executor.show_graph()`` and returns its result.
        """
        return self.graph_executor.show_graph()

    def __call__(self, *args) -> Any:
        """
        Update the model and the eligibility trace states.

        Parameters
        ----------
        *args
            The input arguments.

        Returns
        -------
        Any
            The output of the update method.
        """
        return self.update(*args)

    def update(self, *args) -> Any:
        """
        Update the model and the eligibility trace states.

        Parameters
        ----------
        *args
            The input arguments.

        Returns
        -------
        Any
            The model output.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
        """
        raise NotImplementedError

    def init_etrace_state(self, *args, **kwargs) -> None:
        """
        Initialize the eligibility trace states of the etrace algorithm.

        This method is needed after compiling the etrace graph. See `.compile_graph()` for the details.

        Parameters
        ----------
        *args
            The positional arguments.
        **kwargs
            The keyword arguments.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
        """
        raise NotImplementedError

    def get_etrace_of(self, weight: brainstate.ParamState | Path) -> Any:
        """
        Get the eligibility trace of the given weight.

        Parameters
        ----------
        weight : brainstate.ParamState | Path
            The parameter weight or path to the weight.

        Returns
        -------
        Any
            The eligibility trace.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
        """
        raise NotImplementedError