braintrace 0.1.2__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {braintrace-0.1.2 → braintrace-0.2.1}/PKG-INFO +43 -59
- braintrace-0.2.1/README.md +69 -0
- braintrace-0.2.1/braintrace/__init__.py +260 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace/_compatible_imports.py +1 -5
- braintrace-0.2.1/braintrace/_compile.py +153 -0
- braintrace-0.2.1/braintrace/_compile_test.py +125 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/__init__.py +62 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/_common.py +508 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/_common_test.py +501 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/approx_correctness_test.py +323 -0
- braintrace-0.1.2/braintrace/_etrace_algorithms.py → braintrace-0.2.1/braintrace/_etrace_algorithms/base.py +28 -15
- braintrace-0.2.1/braintrace/_etrace_algorithms/base_test.py +584 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/cross_check_test.py +71 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/d_rtrl.py +56 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/d_rtrl_test.py +1168 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/diagnostic_exploration_test.py +177 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/e_prop.py +248 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/e_prop_test.py +192 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/exact_correctness_test.py +281 -0
- braintrace-0.1.2/braintrace/_etrace_graph_executor.py → braintrace-0.2.1/braintrace/_etrace_algorithms/graph_executor.py +74 -41
- braintrace-0.1.2/braintrace/_etrace_graph_executor_test.py → braintrace-0.2.1/braintrace/_etrace_algorithms/graph_executor_test.py +1 -1
- braintrace-0.1.2/braintrace/_etrace_vjp/esd_rtrl.py → braintrace-0.2.1/braintrace/_etrace_algorithms/io_dim_vjp.py +308 -221
- braintrace-0.2.1/braintrace/_etrace_algorithms/oracle.py +222 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/oracle_models.py +195 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/oracle_test.py +159 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/ostl.py +187 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/ostl_test.py +248 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/osttp.py +201 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/osttp_test.py +172 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/otpe.py +384 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/otpe_test.py +189 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/ottt.py +260 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/ottt_test.py +183 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/param_dim_vjp.py +969 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/pp_prop.py +92 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/pp_prop_test.py +981 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/public_api_test.py +32 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/scan_fusion_test.py +141 -0
- braintrace-0.2.1/braintrace/_etrace_algorithms/transform_correctness_test.py +338 -0
- braintrace-0.1.2/braintrace/_etrace_vjp/base.py → braintrace-0.2.1/braintrace/_etrace_algorithms/vjp_base.py +196 -92
- braintrace-0.2.1/braintrace/_etrace_algorithms/vjp_base_test.py +616 -0
- braintrace-0.1.2/braintrace/_etrace_vjp/graph_executor.py → braintrace-0.2.1/braintrace/_etrace_algorithms/vjp_graph_executor.py +178 -96
- braintrace-0.1.2/braintrace/_etrace_vjp/graph_executor_test.py → braintrace-0.2.1/braintrace/_etrace_algorithms/vjp_graph_executor_test.py +6 -2
- braintrace-0.2.1/braintrace/_etrace_compiler/__init__.py +53 -0
- braintrace-0.1.2/braintrace/_etrace_compiler_base.py → braintrace-0.2.1/braintrace/_etrace_compiler/base.py +25 -11
- braintrace-0.2.1/braintrace/_etrace_compiler/base_test.py +704 -0
- braintrace-0.2.1/braintrace/_etrace_compiler/cell_relation_guardrail_test.py +92 -0
- braintrace-0.2.1/braintrace/_etrace_compiler/compiler_oracle_test.py +256 -0
- braintrace-0.2.1/braintrace/_etrace_compiler/compiler_property_test.py +305 -0
- braintrace-0.2.1/braintrace/_etrace_compiler/diagnostics.py +238 -0
- braintrace-0.2.1/braintrace/_etrace_compiler/diagnostics_test.py +242 -0
- braintrace-0.2.1/braintrace/_etrace_compiler/graph.py +412 -0
- braintrace-0.1.2/braintrace/_etrace_compiler_graph_test.py → braintrace-0.2.1/braintrace/_etrace_compiler/graph_test.py +9 -1
- braintrace-0.2.1/braintrace/_etrace_compiler/hid_param_op.py +1082 -0
- braintrace-0.2.1/braintrace/_etrace_compiler/hid_param_op_test.py +197 -0
- braintrace-0.1.2/braintrace/_etrace_compiler_hidden_group.py → braintrace-0.2.1/braintrace/_etrace_compiler/hidden_group.py +287 -165
- braintrace-0.1.2/braintrace/_etrace_compiler_hidden_group_test.py → braintrace-0.2.1/braintrace/_etrace_compiler/hidden_group_test.py +16 -16
- braintrace-0.1.2/braintrace/_etrace_compiler_hidden_pertubation.py → braintrace-0.2.1/braintrace/_etrace_compiler/hidden_pertubation.py +169 -69
- braintrace-0.1.2/braintrace/_etrace_compiler_hidden_pertubation_test.py → braintrace-0.2.1/braintrace/_etrace_compiler/hidden_pertubation_test.py +4 -4
- braintrace-0.1.2/braintrace/_etrace_model_with_group_state.py → braintrace-0.2.1/braintrace/_etrace_compiler/model4test.py +3 -3
- braintrace-0.1.2/braintrace/_etrace_compiler_module_info.py → braintrace-0.2.1/braintrace/_etrace_compiler/module_info.py +192 -93
- braintrace-0.1.2/braintrace/_etrace_compiler_module_info_test.py → braintrace-0.2.1/braintrace/_etrace_compiler/module_info_test.py +3 -3
- braintrace-0.2.1/braintrace/_etrace_compiler/scenario_catalog.py +355 -0
- braintrace-0.2.1/braintrace/_etrace_compiler/scenario_catalog_test.py +977 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace/_etrace_model_test.py +14 -14
- braintrace-0.2.1/braintrace/_etrace_op/__init__.py +98 -0
- braintrace-0.2.1/braintrace/_etrace_op/_primitive.py +250 -0
- braintrace-0.2.1/braintrace/_etrace_op/_primitive_test.py +285 -0
- braintrace-0.2.1/braintrace/_etrace_op/_registries.py +136 -0
- braintrace-0.2.1/braintrace/_etrace_op/_registries_test.py +165 -0
- braintrace-0.2.1/braintrace/_etrace_op/conv.py +528 -0
- braintrace-0.2.1/braintrace/_etrace_op/conv_test.py +526 -0
- braintrace-0.2.1/braintrace/_etrace_op/dense.py +445 -0
- braintrace-0.2.1/braintrace/_etrace_op/dense_test.py +554 -0
- braintrace-0.2.1/braintrace/_etrace_op/elemwise.py +239 -0
- braintrace-0.2.1/braintrace/_etrace_op/elemwise_test.py +244 -0
- braintrace-0.2.1/braintrace/_etrace_op/lora.py +389 -0
- braintrace-0.2.1/braintrace/_etrace_op/lora_test.py +407 -0
- braintrace-0.2.1/braintrace/_etrace_op/op_rule_oracle.py +60 -0
- braintrace-0.2.1/braintrace/_etrace_op/op_rule_oracle_test.py +171 -0
- braintrace-0.2.1/braintrace/_etrace_op/sparse.py +354 -0
- braintrace-0.2.1/braintrace/_etrace_op/sparse_test.py +454 -0
- braintrace-0.2.1/braintrace/_grad_exponential.py +119 -0
- braintrace-0.1.2/braintrace/_etrace_input_data.py → braintrace-0.2.1/braintrace/_input_data.py +53 -29
- braintrace-0.2.1/braintrace/_legacy/__init__.py +64 -0
- braintrace-0.2.1/braintrace/_legacy/_legacy_test.py +324 -0
- braintrace-0.2.1/braintrace/_legacy/_ops.py +533 -0
- braintrace-0.2.1/braintrace/_legacy/_params.py +319 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace/_misc.py +12 -43
- braintrace-0.2.1/braintrace/_state_managment.py +221 -0
- braintrace-0.2.1/braintrace/_typing.py +118 -0
- braintrace-0.1.2/braintrace/_etrace_vjp/__init__.py → braintrace-0.2.1/braintrace/_version.py +3 -15
- braintrace-0.2.1/braintrace/api_contract_test.py +214 -0
- braintrace-0.2.1/braintrace/legacy_deprecation_test.py +53 -0
- braintrace-0.2.1/braintrace/nn/__init__.py +97 -0
- braintrace-0.2.1/braintrace/nn/_conv.py +76 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace/nn/_conv_test.py +39 -84
- braintrace-0.2.1/braintrace/nn/_linear.py +244 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace/nn/_linear_test.py +48 -55
- braintrace-0.2.1/braintrace/nn/_readout.py +151 -0
- braintrace-0.2.1/braintrace/nn/_readout_test.py +354 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace/nn/_rnn.py +209 -109
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace/nn/_rnn_test.py +2 -2
- braintrace-0.2.1/braintrace/py.typed +0 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace.egg-info/PKG-INFO +43 -59
- braintrace-0.2.1/braintrace.egg-info/SOURCES.txt +112 -0
- braintrace-0.2.1/braintrace.egg-info/requires.txt +29 -0
- braintrace-0.2.1/pyproject.toml +128 -0
- braintrace-0.1.2/README.md +0 -71
- braintrace-0.1.2/braintrace/__init__.py +0 -79
- braintrace-0.1.2/braintrace/_etrace_compiler_graph.py +0 -287
- braintrace-0.1.2/braintrace/_etrace_compiler_hid_param_op.py +0 -832
- braintrace-0.1.2/braintrace/_etrace_compiler_hid_param_op_test.py +0 -112
- braintrace-0.1.2/braintrace/_etrace_concepts.py +0 -382
- braintrace-0.1.2/braintrace/_etrace_concepts_test.py +0 -159
- braintrace-0.1.2/braintrace/_etrace_debug_jaxpr2code.py +0 -1134
- braintrace-0.1.2/braintrace/_etrace_debug_visualize.py +0 -1561
- braintrace-0.1.2/braintrace/_etrace_operators.py +0 -1072
- braintrace-0.1.2/braintrace/_etrace_operators_test.py +0 -58
- braintrace-0.1.2/braintrace/_etrace_vjp/d_rtrl.py +0 -756
- braintrace-0.1.2/braintrace/_etrace_vjp/d_rtrl_test.py +0 -205
- braintrace-0.1.2/braintrace/_etrace_vjp/esd_rtrl_test.py +0 -194
- braintrace-0.1.2/braintrace/_etrace_vjp/hybrid.py +0 -604
- braintrace-0.1.2/braintrace/_etrace_vjp/misc.py +0 -162
- braintrace-0.1.2/braintrace/_grad_exponential.py +0 -85
- braintrace-0.1.2/braintrace/_state_managment.py +0 -436
- braintrace-0.1.2/braintrace/_typing.py +0 -91
- braintrace-0.1.2/braintrace/nn/__init__.py +0 -68
- braintrace-0.1.2/braintrace/nn/_conv.py +0 -395
- braintrace-0.1.2/braintrace/nn/_linear.py +0 -524
- braintrace-0.1.2/braintrace/nn/_normalizations.py +0 -508
- braintrace-0.1.2/braintrace/nn/_normalizations_test.py +0 -695
- braintrace-0.1.2/braintrace/nn/_readout.py +0 -278
- braintrace-0.1.2/braintrace/nn/_readout_test.py +0 -763
- braintrace-0.1.2/braintrace.egg-info/SOURCES.txt +0 -60
- braintrace-0.1.2/braintrace.egg-info/requires.txt +0 -41
- braintrace-0.1.2/pyproject.toml +0 -82
- {braintrace-0.1.2 → braintrace-0.2.1}/LICENSE +0 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace/_compatible_imports_test.py +0 -0
- /braintrace-0.1.2/braintrace/_etrace_input_data_test.py → /braintrace-0.2.1/braintrace/_input_data_test.py +0 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace.egg-info/dependency_links.txt +0 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/braintrace.egg-info/top_level.txt +0 -0
- {braintrace-0.1.2 → braintrace-0.2.1}/setup.cfg +0 -0
|
@@ -1,19 +1,18 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braintrace
|
|
3
|
-
Version: 0.1
|
|
3
|
+
Version: 0.2.1
|
|
4
4
|
Summary: Enabling Scalable Online Learning for Brain Dynamics.
|
|
5
5
|
Author-email: BrainTrace Developers <chao.brain@qq.com>
|
|
6
6
|
License: Apache-2.0 license
|
|
7
7
|
Project-URL: Homepage, https://github.com/chaobrain/braintrace
|
|
8
8
|
Project-URL: Bug Tracker, https://github.com/chaobrain/braintrace/issues
|
|
9
|
-
Project-URL: Documentation, https://
|
|
9
|
+
Project-URL: Documentation, https://brainx.chaobrain.com/braintrace/
|
|
10
10
|
Project-URL: Source Code, https://github.com/chaobrain/braintrace
|
|
11
11
|
Keywords: computational neuroscience,brain-inspired computing,brain modeling,online learning
|
|
12
12
|
Classifier: Natural Language :: English
|
|
13
13
|
Classifier: Operating System :: OS Independent
|
|
14
14
|
Classifier: Programming Language :: Python
|
|
15
15
|
Classifier: Programming Language :: Python :: 3
|
|
16
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
17
16
|
Classifier: Programming Language :: Python :: 3.11
|
|
18
17
|
Classifier: Programming Language :: Python :: 3.12
|
|
19
18
|
Classifier: Programming Language :: Python :: 3.13
|
|
@@ -24,7 +23,7 @@ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
|
|
24
23
|
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
25
24
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
26
25
|
Classifier: Topic :: Software Development :: Libraries
|
|
27
|
-
Requires-Python: >=3.
|
|
26
|
+
Requires-Python: >=3.11
|
|
28
27
|
Description-Content-Type: text/markdown
|
|
29
28
|
License-File: LICENSE
|
|
30
29
|
Requires-Dist: brainstate>=0.2.2
|
|
@@ -34,54 +33,41 @@ Requires-Dist: brainpy-state
|
|
|
34
33
|
Requires-Dist: braintools
|
|
35
34
|
Provides-Extra: cpu
|
|
36
35
|
Requires-Dist: jax[cpu]; extra == "cpu"
|
|
37
|
-
Requires-Dist: brainunit; extra == "cpu"
|
|
38
|
-
Requires-Dist: brainstate; extra == "cpu"
|
|
39
|
-
Requires-Dist: brainpy-state; extra == "cpu"
|
|
40
|
-
Requires-Dist: braintools; extra == "cpu"
|
|
41
36
|
Provides-Extra: cuda12
|
|
42
37
|
Requires-Dist: jax[cuda12]; extra == "cuda12"
|
|
43
|
-
Requires-Dist: brainunit; extra == "cuda12"
|
|
44
|
-
Requires-Dist: brainstate; extra == "cuda12"
|
|
45
|
-
Requires-Dist: brainpy-state; extra == "cuda12"
|
|
46
|
-
Requires-Dist: braintools; extra == "cuda12"
|
|
47
38
|
Provides-Extra: cuda13
|
|
48
39
|
Requires-Dist: jax[cuda13]; extra == "cuda13"
|
|
49
|
-
Requires-Dist: brainunit; extra == "cuda13"
|
|
50
|
-
Requires-Dist: brainstate; extra == "cuda13"
|
|
51
|
-
Requires-Dist: brainpy-state; extra == "cuda13"
|
|
52
|
-
Requires-Dist: braintools; extra == "cuda13"
|
|
53
40
|
Provides-Extra: tpu
|
|
54
41
|
Requires-Dist: jax[tpu]; extra == "tpu"
|
|
55
|
-
Requires-Dist: brainunit; extra == "tpu"
|
|
56
|
-
Requires-Dist: brainstate; extra == "tpu"
|
|
57
|
-
Requires-Dist: brainpy-state; extra == "tpu"
|
|
58
|
-
Requires-Dist: braintools; extra == "tpu"
|
|
59
42
|
Provides-Extra: testing
|
|
60
43
|
Requires-Dist: pytest; extra == "testing"
|
|
61
44
|
Requires-Dist: jax[cpu]; extra == "testing"
|
|
62
|
-
Requires-Dist:
|
|
63
|
-
|
|
64
|
-
Requires-Dist:
|
|
65
|
-
Requires-Dist:
|
|
45
|
+
Requires-Dist: hypothesis; extra == "testing"
|
|
46
|
+
Provides-Extra: dev
|
|
47
|
+
Requires-Dist: pytest; extra == "dev"
|
|
48
|
+
Requires-Dist: jax[cpu]; extra == "dev"
|
|
49
|
+
Requires-Dist: hypothesis; extra == "dev"
|
|
50
|
+
Requires-Dist: mypy>=1.8; extra == "dev"
|
|
51
|
+
Requires-Dist: build; extra == "dev"
|
|
66
52
|
Dynamic: license-file
|
|
67
53
|
|
|
68
54
|
<h1 align="center">BrainTrace</h1>
|
|
69
55
|
<h2 align="center">Eligibility Trace-based Online Learning for Brain Dynamics</h2>
|
|
70
56
|
|
|
71
57
|
<p align="center">
|
|
72
|
-
<img alt="Header image of braintrace." src="https://
|
|
58
|
+
<img alt="Header image of braintrace." src="https://brainx.chaobrain.com/images/braintrace.webp" width=40%>
|
|
73
59
|
</p>
|
|
74
60
|
|
|
75
61
|
<p align="center">
|
|
76
62
|
<a href="https://pypi.org/project/braintrace/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/braintrace"></a>
|
|
77
63
|
<a href="https://github.com/chaobrain/braintrace/blob/main/LICENSE"><img alt="LICENSE" src="https://img.shields.io/badge/License-Apache%202.0-blue.svg"></a>
|
|
78
|
-
<a href="https://
|
|
64
|
+
<a href="https://brainx.chaobrain.com/braintrace/"><img alt="Documentation" src="https://readthedocs.org/projects/braintrace/badge/?version=latest"></a>
|
|
79
65
|
<a href="https://badge.fury.io/py/braintrace"><img alt="PyPI version" src="https://badge.fury.io/py/braintrace.svg"></a>
|
|
80
66
|
<a href="https://github.com/chaobrain/braintrace/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/chaobrain/braintrace/actions/workflows/CI.yml/badge.svg"></a>
|
|
81
67
|
</p>
|
|
82
68
|
|
|
83
69
|
[``braintrace``](https://github.com/chaobrain/braintrace) provides online learning algorithms for biological neural networks.
|
|
84
|
-
It has been integrated into our establishing [brain modeling ecosystem](https://
|
|
70
|
+
It is integrated into our growing [brain modeling ecosystem](https://brainx.chaobrain.com/).
|
|
85
71
|
|
|
86
72
|
## Installation
|
|
87
73
|
|
|
@@ -99,40 +85,38 @@ pip install BrainX -U
|
|
|
99
85
|
|
|
100
86
|
## Documentation
|
|
101
87
|
|
|
102
|
-
The official documentation is hosted on Read the Docs: [https://
|
|
88
|
+
The official documentation is hosted on Read the Docs: [https://brainx.chaobrain.com/braintrace](https://brainx.chaobrain.com/braintrace)
|
|
89
|
+
|
|
90
|
+
## Citation
|
|
91
|
+
|
|
92
|
+
If you use this package in your research, please cite:
|
|
93
|
+
|
|
94
|
+
```bibtex
|
|
95
|
+
|
|
96
|
+
@Article{Wang2026,
|
|
97
|
+
author={Wang, Chaoming
|
|
98
|
+
and Dong, Xingsi
|
|
99
|
+
and Ji, Zilong
|
|
100
|
+
and Xiao, Mingqing
|
|
101
|
+
and Jiang, Jiedong
|
|
102
|
+
and Liu, Xiao
|
|
103
|
+
and Huan, Yuxiang
|
|
104
|
+
and Wu, Si},
|
|
105
|
+
title={Model-agnostic linear-memory online learning in spiking neural networks},
|
|
106
|
+
journal={Nature Communications},
|
|
107
|
+
year={2026},
|
|
108
|
+
month={Jan},
|
|
109
|
+
day={19},
|
|
110
|
+
abstract={Spiking neural networks (SNNs) offer a promising paradigm for modeling brain dynamics and developing neuromorphic intelligence, yet an online learning system capable of training rich spiking dynamics over long horizons with low memory footprints has been missing. Existing online approaches either incur quadratic memory growth, sacrifice biological fidelity through oversimplified models, or lack end-to-end automated tooling. Here, we introduce BrainTrace, a model-agnostic, linear-memory, and automated online learning system for spiking neural networks. BrainTrace standardizes model specification to encompass diverse neuronal and synaptic dynamics; implements a linear-memory online learning rule by exploiting intrinsic properties of spiking dynamics; and provides a compiler that automatically generates optimized online-learning code for arbitrary user-defined models. Across diverse dynamics and tasks, BrainTrace achieves strong learning performance with a low memory footprint and high computational throughput. Critically, these properties enable online fitting of a whole-brain-scale Drosophila SNN that recapitulates region-level functional activity. By reconciling generality, efficiency, and usability, BrainTrace establishes a foundation for spiking network modeling at scale.},
|
|
111
|
+
issn={2041-1723},
|
|
112
|
+
doi={10.1038/s41467-026-68453-w},
|
|
113
|
+
url={https://doi.org/10.1038/s41467-026-68453-w},
|
|
114
|
+
publisher={Nature Publishing Group UK London}
|
|
115
|
+
}
|
|
103
116
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
[//]: # ()
|
|
107
|
-
[//]: # (If you use this package in your research, please cite:)
|
|
108
|
-
|
|
109
|
-
[//]: # ()
|
|
110
|
-
[//]: # (```bibtex)
|
|
111
|
-
|
|
112
|
-
[//]: # (@article {Wang2024.09.24.614728,)
|
|
113
|
-
|
|
114
|
-
[//]: # ( author = {Wang, Chaoming and Dong, Xingsi and Ji, Zilong and Jiang, Jiedong and Liu, Xiao and Wu, Si},)
|
|
115
|
-
|
|
116
|
-
[//]: # ( title = {Enabling Scalable Online Learning in Spiking Neural Networks},)
|
|
117
|
-
|
|
118
|
-
[//]: # ( elocation-id = {2024.09.24.614728},)
|
|
119
|
-
|
|
120
|
-
[//]: # ( year = {2025},)
|
|
121
|
-
|
|
122
|
-
[//]: # ( doi = {10.1101/2024.09.24.614728},)
|
|
123
|
-
|
|
124
|
-
[//]: # ( publisher = {Cold Spring Harbor Laboratory},)
|
|
125
|
-
|
|
126
|
-
[//]: # ( URL = {https://www.biorxiv.org/content/early/2025/07/27/2024.09.24.614728},)
|
|
127
|
-
|
|
128
|
-
[//]: # ( eprint = {https://www.biorxiv.org/content/early/2025/07/27/2024.09.24.614728.full.pdf},)
|
|
129
|
-
|
|
130
|
-
[//]: # ( journal = {bioRxiv})
|
|
131
|
-
|
|
132
|
-
[//]: # (})
|
|
117
|
+
```
|
|
133
118
|
|
|
134
|
-
[//]: # (```)
|
|
135
119
|
|
|
136
120
|
## See also the ecosystem
|
|
137
121
|
|
|
138
|
-
``braintrace`` is one part of our brain simulation ecosystem: https://
|
|
122
|
+
``braintrace`` is one part of our brain simulation ecosystem: https://brainx.chaobrain.com/
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
<h1 align="center">BrainTrace</h1>
|
|
2
|
+
<h2 align="center">Eligibility Trace-based Online Learning for Brain Dynamics</h2>
|
|
3
|
+
|
|
4
|
+
<p align="center">
|
|
5
|
+
<img alt="Header image of braintrace." src="https://brainx.chaobrain.com/images/braintrace.webp" width=40%>
|
|
6
|
+
</p>
|
|
7
|
+
|
|
8
|
+
<p align="center">
|
|
9
|
+
<a href="https://pypi.org/project/braintrace/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/braintrace"></a>
|
|
10
|
+
<a href="https://github.com/chaobrain/braintrace/blob/main/LICENSE"><img alt="LICENSE" src="https://img.shields.io/badge/License-Apache%202.0-blue.svg"></a>
|
|
11
|
+
<a href="https://brainx.chaobrain.com/braintrace/"><img alt="Documentation" src="https://readthedocs.org/projects/braintrace/badge/?version=latest"></a>
|
|
12
|
+
<a href="https://badge.fury.io/py/braintrace"><img alt="PyPI version" src="https://badge.fury.io/py/braintrace.svg"></a>
|
|
13
|
+
<a href="https://github.com/chaobrain/braintrace/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/chaobrain/braintrace/actions/workflows/CI.yml/badge.svg"></a>
|
|
14
|
+
</p>
|
|
15
|
+
|
|
16
|
+
[``braintrace``](https://github.com/chaobrain/braintrace) provides online learning algorithms for biological neural networks.
|
|
17
|
+
It is integrated into our growing [brain modeling ecosystem](https://brainx.chaobrain.com/).
|
|
18
|
+
|
|
19
|
+
## Installation
|
|
20
|
+
|
|
21
|
+
``braintrace`` runs on Python 3.11+ on Linux, macOS, and Windows. You can install ``braintrace`` via pip:
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
pip install braintrace --upgrade
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
Alternatively, you can install `BrainX`, which bundles `braintrace` with other compatible packages for a comprehensive brain modeling ecosystem:
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pip install BrainX -U
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
## Documentation
|
|
34
|
+
|
|
35
|
+
The official documentation is hosted on Read the Docs: [https://brainx.chaobrain.com/braintrace](https://brainx.chaobrain.com/braintrace)
|
|
36
|
+
|
|
37
|
+
## Citation
|
|
38
|
+
|
|
39
|
+
If you use this package in your research, please cite:
|
|
40
|
+
|
|
41
|
+
```bibtex
|
|
42
|
+
|
|
43
|
+
@Article{Wang2026,
|
|
44
|
+
author={Wang, Chaoming
|
|
45
|
+
and Dong, Xingsi
|
|
46
|
+
and Ji, Zilong
|
|
47
|
+
and Xiao, Mingqing
|
|
48
|
+
and Jiang, Jiedong
|
|
49
|
+
and Liu, Xiao
|
|
50
|
+
and Huan, Yuxiang
|
|
51
|
+
and Wu, Si},
|
|
52
|
+
title={Model-agnostic linear-memory online learning in spiking neural networks},
|
|
53
|
+
journal={Nature Communications},
|
|
54
|
+
year={2026},
|
|
55
|
+
month={Jan},
|
|
56
|
+
day={19},
|
|
57
|
+
abstract={Spiking neural networks (SNNs) offer a promising paradigm for modeling brain dynamics and developing neuromorphic intelligence, yet an online learning system capable of training rich spiking dynamics over long horizons with low memory footprints has been missing. Existing online approaches either incur quadratic memory growth, sacrifice biological fidelity through oversimplified models, or lack end-to-end automated tooling. Here, we introduce BrainTrace, a model-agnostic, linear-memory, and automated online learning system for spiking neural networks. BrainTrace standardizes model specification to encompass diverse neuronal and synaptic dynamics; implements a linear-memory online learning rule by exploiting intrinsic properties of spiking dynamics; and provides a compiler that automatically generates optimized online-learning code for arbitrary user-defined models. Across diverse dynamics and tasks, BrainTrace achieves strong learning performance with a low memory footprint and high computational throughput. Critically, these properties enable online fitting of a whole-brain-scale Drosophila SNN that recapitulates region-level functional activity. By reconciling generality, efficiency, and usability, BrainTrace establishes a foundation for spiking network modeling at scale.},
|
|
58
|
+
issn={2041-1723},
|
|
59
|
+
doi={10.1038/s41467-026-68453-w},
|
|
60
|
+
url={https://doi.org/10.1038/s41467-026-68453-w},
|
|
61
|
+
publisher={Nature Publishing Group UK London}
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
## See also the ecosystem
|
|
68
|
+
|
|
69
|
+
``braintrace`` is one part of our brain simulation ecosystem: https://brainx.chaobrain.com/
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
# -*- coding: utf-8 -*-
|
|
17
|
+
|
|
18
|
+
"""braintrace: online learning for recurrent networks via Eligibility Trace Propagation (ETP).
|
|
19
|
+
|
|
20
|
+
``braintrace`` trains recurrent and spiking neural networks **online** — forward
|
|
21
|
+
in time, without backpropagation through time (BPTT). Models mark their
|
|
22
|
+
trainable operations with ETP user-API ops (for example :func:`matmul`,
|
|
23
|
+
:func:`conv`, :func:`sparse_matmul`, :func:`lora_matmul`, :func:`element_wise`)
|
|
24
|
+
instead of wrapping parameters in a special class. A compiler then walks the
|
|
25
|
+
JAX ``jaxpr``, identifies those ETP primitives, and connects each parameter to
|
|
26
|
+
the hidden states it influences so that eligibility traces can be propagated.
|
|
27
|
+
|
|
28
|
+
The public API is organised in four layers, with dependencies pointing strictly
|
|
29
|
+
downward:
|
|
30
|
+
|
|
31
|
+
1. **ETP operators** — the user-facing ops (:func:`matmul`, :func:`conv`, ...),
|
|
32
|
+
the :class:`ETPPrimitive` class, and :func:`register_primitive` for adding
|
|
33
|
+
new ones.
|
|
34
|
+
2. **Compiler** — :func:`compile_etrace_graph` and the analysis containers
|
|
35
|
+
(:class:`ETraceGraph`, :class:`ModuleInfo`, :class:`HiddenGroup`,
|
|
36
|
+
:class:`HiddenParamOpRelation`, :class:`HiddenPerturbation`) plus the
|
|
37
|
+
diagnostics types (:class:`CompilationRecord`, :class:`DiagnosticKind`,
|
|
38
|
+
:class:`DiagnosticLevel`).
|
|
39
|
+
3. **Graph executor** — :class:`ETraceGraphExecutor` /
|
|
40
|
+
:class:`ETraceVjpGraphExecutor`, which run the forward pass and the
|
|
41
|
+
hidden->weight / hidden->hidden Jacobian computations.
|
|
42
|
+
4. **Algorithms** — online-learning orchestrators: the exact algorithms
|
|
43
|
+
:class:`D_RTRL` / :func:`pp_prop` / :class:`ES_D_RTRL`, and the SNN family
|
|
44
|
+
:class:`EProp`, :class:`OSTLRecurrent`, :class:`OSTLFeedforward`,
|
|
45
|
+
:class:`OTPE`, :class:`OTTT`, :class:`OSTTP`.
|
|
46
|
+
|
|
47
|
+
The :mod:`braintrace.nn` subpackage provides ready-made ETP-wired layers
|
|
48
|
+
(linear maps, convolutions, recurrent cells, read-outs).
|
|
49
|
+
|
|
50
|
+
Notes
|
|
51
|
+
-----
|
|
52
|
+
The convenience entry point :func:`compile` wraps a model together with an
|
|
53
|
+
algorithm into a single trainable object and is the recommended starting point.
|
|
54
|
+
The ``braintrace.MatMulOp`` / ``ETraceParam`` style names from the v0.1.x API
|
|
55
|
+
are deprecated shims served lazily with a :class:`DeprecationWarning`; new code
|
|
56
|
+
should mark parameters by routing them through ETP ops instead.
|
|
57
|
+
|
|
58
|
+
Examples
|
|
59
|
+
--------
|
|
60
|
+
.. code-block:: python
|
|
61
|
+
|
|
62
|
+
>>> import braintrace
|
|
63
|
+
>>> # the public API surface is enumerated by __all__
|
|
64
|
+
>>> 'matmul' in braintrace.__all__
|
|
65
|
+
True
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
from typing import TYPE_CHECKING
|
|
70
|
+
|
|
71
|
+
from . import nn
|
|
72
|
+
from ._compile import compile
|
|
73
|
+
from ._etrace_algorithms import (
|
|
74
|
+
ETraceAlgorithm,
|
|
75
|
+
EligibilityTrace,
|
|
76
|
+
ETraceGraphExecutor,
|
|
77
|
+
ETraceVjpAlgorithm,
|
|
78
|
+
ETraceVjpGraphExecutor,
|
|
79
|
+
ParamDimVjpAlgorithm,
|
|
80
|
+
D_RTRL,
|
|
81
|
+
pp_prop,
|
|
82
|
+
ES_D_RTRL,
|
|
83
|
+
IODimVjpAlgorithm,
|
|
84
|
+
EProp,
|
|
85
|
+
OSTLRecurrent,
|
|
86
|
+
OSTLFeedforward,
|
|
87
|
+
OTPE,
|
|
88
|
+
OTTT,
|
|
89
|
+
OSTTP,
|
|
90
|
+
FixedRandomFeedback,
|
|
91
|
+
KappaFilter,
|
|
92
|
+
PresynapticTrace,
|
|
93
|
+
)
|
|
94
|
+
from ._etrace_compiler import (
|
|
95
|
+
ETraceGraph,
|
|
96
|
+
compile_etrace_graph,
|
|
97
|
+
HiddenParamOpRelation,
|
|
98
|
+
find_hidden_param_op_relations_from_minfo,
|
|
99
|
+
find_hidden_param_op_relations_from_module,
|
|
100
|
+
HiddenGroup,
|
|
101
|
+
find_hidden_groups_from_minfo,
|
|
102
|
+
find_hidden_groups_from_module,
|
|
103
|
+
HiddenPerturbation,
|
|
104
|
+
add_hidden_perturbation_from_minfo,
|
|
105
|
+
add_hidden_perturbation_in_module,
|
|
106
|
+
ModuleInfo,
|
|
107
|
+
extract_module_info,
|
|
108
|
+
CompilationRecord,
|
|
109
|
+
DiagnosticKind,
|
|
110
|
+
DiagnosticLevel,
|
|
111
|
+
)
|
|
112
|
+
from ._etrace_op import (
|
|
113
|
+
ETPPrimitive,
|
|
114
|
+
matmul,
|
|
115
|
+
element_wise,
|
|
116
|
+
conv,
|
|
117
|
+
sparse_matmul,
|
|
118
|
+
lora_matmul,
|
|
119
|
+
register_primitive,
|
|
120
|
+
)
|
|
121
|
+
from ._grad_exponential import GradExpon
|
|
122
|
+
from ._input_data import (
|
|
123
|
+
SingleStepData,
|
|
124
|
+
MultiStepData,
|
|
125
|
+
)
|
|
126
|
+
from ._misc import NotSupportedError, CompilationError
|
|
127
|
+
from ._version import __version__, __version_info__
|
|
128
|
+
|
|
129
|
+
if TYPE_CHECKING:
|
|
130
|
+
# The v0.1.x legacy shims are deprecated and served lazily via ``__getattr__``
|
|
131
|
+
# below. Re-import them here so static type checkers / IDEs can still resolve
|
|
132
|
+
# ``braintrace.MatMulOp`` etc.
|
|
133
|
+
from ._legacy import (
|
|
134
|
+
ConvOp,
|
|
135
|
+
ElemWiseOp,
|
|
136
|
+
ElemWiseParam,
|
|
137
|
+
ETraceOp,
|
|
138
|
+
ETraceParam,
|
|
139
|
+
FakeElemWiseParam,
|
|
140
|
+
FakeETraceParam,
|
|
141
|
+
LoraOp,
|
|
142
|
+
MatMulOp,
|
|
143
|
+
NonTempParam,
|
|
144
|
+
SpMatMulOp,
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
__all__ = [
|
|
148
|
+
# version
|
|
149
|
+
'__version__',
|
|
150
|
+
'__version_info__',
|
|
151
|
+
|
|
152
|
+
# algorithms
|
|
153
|
+
'ETraceAlgorithm',
|
|
154
|
+
'EligibilityTrace',
|
|
155
|
+
'ETraceVjpAlgorithm',
|
|
156
|
+
'ETraceVjpGraphExecutor',
|
|
157
|
+
'ParamDimVjpAlgorithm',
|
|
158
|
+
'D_RTRL',
|
|
159
|
+
'pp_prop',
|
|
160
|
+
'ES_D_RTRL',
|
|
161
|
+
'IODimVjpAlgorithm',
|
|
162
|
+
|
|
163
|
+
# one-call entry point
|
|
164
|
+
'compile',
|
|
165
|
+
|
|
166
|
+
# ETP primitives (user API)
|
|
167
|
+
'matmul',
|
|
168
|
+
'element_wise',
|
|
169
|
+
'conv',
|
|
170
|
+
'sparse_matmul',
|
|
171
|
+
'lora_matmul',
|
|
172
|
+
|
|
173
|
+
# ETP primitive class & rule registration
|
|
174
|
+
'ETPPrimitive',
|
|
175
|
+
'register_primitive',
|
|
176
|
+
|
|
177
|
+
# input data
|
|
178
|
+
'SingleStepData',
|
|
179
|
+
'MultiStepData',
|
|
180
|
+
|
|
181
|
+
# graph executor
|
|
182
|
+
'ETraceGraphExecutor',
|
|
183
|
+
|
|
184
|
+
# compiler
|
|
185
|
+
'ETraceGraph',
|
|
186
|
+
'compile_etrace_graph',
|
|
187
|
+
'HiddenGroup',
|
|
188
|
+
'find_hidden_groups_from_minfo',
|
|
189
|
+
'find_hidden_groups_from_module',
|
|
190
|
+
'HiddenParamOpRelation',
|
|
191
|
+
'find_hidden_param_op_relations_from_minfo',
|
|
192
|
+
'find_hidden_param_op_relations_from_module',
|
|
193
|
+
'ModuleInfo',
|
|
194
|
+
'extract_module_info',
|
|
195
|
+
'HiddenPerturbation',
|
|
196
|
+
'add_hidden_perturbation_from_minfo',
|
|
197
|
+
'add_hidden_perturbation_in_module',
|
|
198
|
+
|
|
199
|
+
# compiler diagnostics
|
|
200
|
+
'CompilationRecord',
|
|
201
|
+
'DiagnosticKind',
|
|
202
|
+
'DiagnosticLevel',
|
|
203
|
+
|
|
204
|
+
# gradient utilities
|
|
205
|
+
'GradExpon',
|
|
206
|
+
|
|
207
|
+
# SNN online-learning algorithms
|
|
208
|
+
'EProp',
|
|
209
|
+
'OSTLRecurrent',
|
|
210
|
+
'OSTLFeedforward',
|
|
211
|
+
'OTPE',
|
|
212
|
+
'OTTT',
|
|
213
|
+
'OSTTP',
|
|
214
|
+
'FixedRandomFeedback',
|
|
215
|
+
'KappaFilter',
|
|
216
|
+
'PresynapticTrace',
|
|
217
|
+
|
|
218
|
+
# errors
|
|
219
|
+
'NotSupportedError',
|
|
220
|
+
'CompilationError',
|
|
221
|
+
|
|
222
|
+
# submodules
|
|
223
|
+
'nn',
|
|
224
|
+
]
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
# --- v0.1.x legacy shims: deprecated, served lazily with an access-time warning.
# Each maps the public name -> migration replacement text. The shim classes still
# work; new code should use the primitive-based ETP user-API instead.
#
# The replacement text is interpolated verbatim into the DeprecationWarning emitted
# by the module-level ``__getattr__`` ("... use <text> instead."), so each value
# should read naturally at the end of that sentence. The keys are also the extra
# names that ``__dir__`` advertises on top of ``__all__``.
_DEPRECATED_LEGACY = {
    'MatMulOp': 'braintrace.matmul (with a brainstate.ParamState)',
    'ElemWiseOp': 'braintrace.element_wise',
    'ConvOp': 'braintrace.conv',
    'SpMatMulOp': 'braintrace.sparse_matmul',
    'LoraOp': 'braintrace.lora_matmul',
    'ETraceOp': 'the braintrace ETP primitive functions (matmul, conv, ...)',
    'ETraceParam': 'brainstate.ParamState together with an ETP primitive function',
    'ElemWiseParam': 'brainstate.ParamState together with braintrace.element_wise',
    'NonTempParam': 'brainstate.ParamState with plain JAX ops (keeps the weight out of the ETP graph)',
    'FakeETraceParam': 'a plain object with plain JAX ops',
    'FakeElemWiseParam': 'a plain object with plain JAX ops',
}
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def __getattr__(name):
    """Lazily serve the deprecated v0.1.x names as module attributes.

    Python calls this hook only for names that ordinary module attribute
    lookup did not find.

    Parameters
    ----------
    name : str
        The attribute name being looked up on the package.

    Returns
    -------
    object
        The attribute called ``name`` taken from the ``_legacy`` submodule.

    Raises
    ------
    AttributeError
        If ``name`` is not a key of ``_DEPRECATED_LEGACY``.

    Warns
    -----
    DeprecationWarning
        Whenever ``name`` is a deprecated legacy name; the message names the
        replacement recorded in ``_DEPRECATED_LEGACY``.
    """
    replacement = _DEPRECATED_LEGACY.get(name)
    if replacement is None:
        raise AttributeError(f'module {__name__!r} has no attribute {name!r}')

    import warnings

    message = (
        f'braintrace.{name} is deprecated since 0.2.0 and will be removed in a '
        f'future release; use {replacement} instead.'
    )
    # stacklevel=2 attributes the warning to the code that accessed the name,
    # not to this hook.
    warnings.warn(message, DeprecationWarning, stacklevel=2)

    # Imported only after the warning so the legacy module is loaded on demand.
    from . import _legacy

    return getattr(_legacy, name)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def __dir__():
    """Return the sorted names reported by ``dir()`` on this package.

    Returns
    -------
    list of str
        Every entry of ``__all__`` plus every deprecated legacy name (the keys
        of ``_DEPRECATED_LEGACY``), in ascending order.
    """
    names = [*__all__, *_DEPRECATED_LEGACY]
    names.sort()
    return names
|
|
@@ -29,11 +29,7 @@ __all__ = [
|
|
|
29
29
|
'is_cond_primitive',
|
|
30
30
|
]
|
|
31
31
|
|
|
32
|
-
|
|
33
|
-
from jax.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
|
|
34
|
-
|
|
35
|
-
else:
|
|
36
|
-
from jax.extend.core import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
|
|
32
|
+
from brainstate._compatible_import import Primitive, Var, JaxprEqn, Jaxpr, ClosedJaxpr, Literal
|
|
37
33
|
|
|
38
34
|
|
|
39
35
|
def new_var(suffix, aval):
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
# Copyright 2026 BrainX Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Type, Union
|
|
17
|
+
|
|
18
|
+
from ._etrace_algorithms import (
|
|
19
|
+
ETraceAlgorithm,
|
|
20
|
+
D_RTRL,
|
|
21
|
+
pp_prop,
|
|
22
|
+
EProp,
|
|
23
|
+
OSTLRecurrent,
|
|
24
|
+
OSTLFeedforward,
|
|
25
|
+
OTPE,
|
|
26
|
+
OTTT,
|
|
27
|
+
OSTTP,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
__all__ = ['compile']
|
|
31
|
+
|
|
32
|
+
# Canonical lowercase name (+ aliases) -> algorithm class. No bare ``ostl``
# alias: the ambiguous OSTL factory was removed in 0.2.0, so callers pick
# ``ostl_recurrent`` vs ``ostl_feedforward`` explicitly.
#
# Keys must be lowercase with no surrounding whitespace: ``_resolve_algorithm``
# strips and lower-cases the requested name before looking it up here.
_ALGORITHM_REGISTRY: dict[str, type[ETraceAlgorithm]] = {
    'd_rtrl': D_RTRL,
    'pp_prop': pp_prop,
    'es_d_rtrl': pp_prop,  # alias: resolves to the same class as 'pp_prop'
    'esd_rtrl': pp_prop,  # alias: resolves to the same class as 'pp_prop'
    'eprop': EProp,
    'e_prop': EProp,  # alias: resolves to the same class as 'eprop'
    'ostl_recurrent': OSTLRecurrent,
    'ostl_feedforward': OSTLFeedforward,
    'otpe': OTPE,
    'ottt': OTTT,
    'osttp': OSTTP,
}
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _resolve_algorithm(
    algorithm: Union[str, Type[ETraceAlgorithm]]
) -> Type[ETraceAlgorithm]:
    """Resolve ``algorithm`` to an :class:`ETraceAlgorithm` subclass.

    Parameters
    ----------
    algorithm : type or str
        Either an :class:`ETraceAlgorithm` subclass (returned unchanged) or a
        registered string name, e.g. ``'D_RTRL'``, ``'eprop'``, ``'ottt'``.
        String names are matched case-insensitively and surrounding
        whitespace is ignored.

    Returns
    -------
    type
        The resolved :class:`ETraceAlgorithm` subclass.

    Raises
    ------
    ValueError
        If ``algorithm`` is a string that is not a registered name. The
        message lists every valid name.
    TypeError
        If ``algorithm`` is a class that is not an ``ETraceAlgorithm`` subclass,
        or is neither a class nor a string.
    """
    if isinstance(algorithm, type):
        if issubclass(algorithm, ETraceAlgorithm):
            return algorithm
        raise TypeError(
            f'algorithm class must be a subclass of ETraceAlgorithm, got {algorithm!r}.'
        )
    if isinstance(algorithm, str):
        # Registry keys are lowercase, so normalise the user's spelling first.
        key = algorithm.strip().lower()
        try:
            return _ALGORITHM_REGISTRY[key]
        except KeyError:
            valid = ', '.join(sorted(_ALGORITHM_REGISTRY))
            # ``from None`` hides the internal KeyError: it is an implementation
            # detail of the lookup and would only add a confusing "During
            # handling of the above exception..." section to the traceback.
            raise ValueError(
                f'Unknown algorithm name {algorithm!r}. Valid names: {valid}. '
                'Or pass an ETraceAlgorithm subclass directly.'
            ) from None
    raise TypeError(
        f'algorithm must be an ETraceAlgorithm subclass or a registered string name, '
        f'got {type(algorithm)}.'
    )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def compile(model, algorithm, *example_inputs, **options):
    """Build an online-learning algorithm for ``model`` and compile its
    eligibility-trace graph up front, so the returned learner is ready for
    ``update``.

    Parameters
    ----------
    model : brainstate.nn.Module
        The recurrent model. Its states must already be initialized, e.g. via
        ``brainstate.nn.init_all_states(model)``.
    algorithm : type or str
        An :class:`ETraceAlgorithm` subclass, or a registered string name
        (case-insensitive), e.g. ``'D_RTRL'``, ``'eprop'``, ``'ottt'``.
    *example_inputs
        Example call inputs (arrays / :class:`SingleStepData` /
        :class:`MultiStepData`), matching what ``learner.update(...)`` will later
        receive. They are passed on to :meth:`ETraceAlgorithm.compile_graph`.
        At least one is required.
    **options
        Keyword options passed on to the algorithm constructor, e.g.
        ``vjp_method``, ``leak``, ``fast_solve``, ``trace_dtype``, ``feedback``.

    Returns
    -------
    ETraceAlgorithm
        The compiled learner; call ``.update(*inputs)`` to train.

    Raises
    ------
    ValueError
        If ``algorithm`` is an unknown string name, or no ``example_inputs`` are
        given.
    TypeError
        If ``algorithm`` is a class that is not an ``ETraceAlgorithm`` subclass,
        or is neither a class nor a string.

    Examples
    --------
    .. code-block:: python

        >>> import braintrace
        >>> import brainstate
        >>> import jax.numpy as jnp
        >>> model = MyRNN()
        >>> brainstate.nn.init_all_states(model, batch_size=1)
        >>> x0 = jnp.ones((3,))
        >>> learner = braintrace.compile(model, 'D_RTRL', x0, vjp_method='multi-step')
        >>> y = learner.update(x0)
    """
    # The algorithm is resolved first, so an invalid ``algorithm`` is reported
    # before the missing-input check below.
    algorithm_cls = _resolve_algorithm(algorithm)

    if not example_inputs:
        raise ValueError(
            'compile() needs at least one example input to build the graph '
            'eagerly, e.g. compile(model, "D_RTRL", x0). Pass the same inputs '
            'you will give to learner.update(...).'
        )

    learner = algorithm_cls(model, **options)
    learner.compile_graph(*example_inputs)
    return learner
|