mbrila 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153) hide show
  1. mbrila-0.1.0/.github/workflows/ci.yml +24 -0
  2. mbrila-0.1.0/.gitignore +78 -0
  3. mbrila-0.1.0/.python-version +1 -0
  4. mbrila-0.1.0/LICENSE +21 -0
  5. mbrila-0.1.0/PKG-INFO +457 -0
  6. mbrila-0.1.0/README.md +421 -0
  7. mbrila-0.1.0/THIRD_PARTY_NOTICES.md +96 -0
  8. mbrila-0.1.0/data/demo_v1v2_data.pkl +0 -0
  9. mbrila-0.1.0/examples/synthetic/demo_adm.py +218 -0
  10. mbrila-0.1.0/examples/synthetic/demo_common.py +1472 -0
  11. mbrila-0.1.0/examples/synthetic/demo_custom_kernel.py +281 -0
  12. mbrila-0.1.0/examples/synthetic/demo_dlag.py +158 -0
  13. mbrila-0.1.0/examples/synthetic/demo_dlag_ssm.py +195 -0
  14. mbrila-0.1.0/examples/synthetic/demo_gpfa_ssm.py +187 -0
  15. mbrila-0.1.0/examples/synthetic/demo_lds.py +152 -0
  16. mbrila-0.1.0/examples/synthetic/demo_matern.py +195 -0
  17. mbrila-0.1.0/examples/synthetic/demo_mdlag_freq.py +182 -0
  18. mbrila-0.1.0/examples/synthetic/demo_mdlag_ssm.py +198 -0
  19. mbrila-0.1.0/examples/synthetic/demo_mdlag_time.py +180 -0
  20. mbrila-0.1.0/examples/v1v2/compare_v1v2_runs.py +263 -0
  21. mbrila-0.1.0/examples/v1v2/demo_adm.py +272 -0
  22. mbrila-0.1.0/examples/v1v2/demo_custom_kernel.py +301 -0
  23. mbrila-0.1.0/examples/v1v2/demo_dlag.py +180 -0
  24. mbrila-0.1.0/examples/v1v2/demo_dlag_ssm.py +255 -0
  25. mbrila-0.1.0/examples/v1v2/demo_mdlag_freq.py +190 -0
  26. mbrila-0.1.0/examples/v1v2/demo_mdlag_ssm.py +196 -0
  27. mbrila-0.1.0/examples/v1v2/demo_mdlag_time.py +189 -0
  28. mbrila-0.1.0/examples/v1v2/v1v2_common.py +1092 -0
  29. mbrila-0.1.0/notebooks/nb_helpers.py +468 -0
  30. mbrila-0.1.0/notebooks/synthetic/demo_adm.ipynb +586 -0
  31. mbrila-0.1.0/notebooks/synthetic/demo_custom_kernel.ipynb +682 -0
  32. mbrila-0.1.0/notebooks/synthetic/demo_dlag.ipynb +579 -0
  33. mbrila-0.1.0/notebooks/synthetic/demo_dlag_ssm.ipynb +642 -0
  34. mbrila-0.1.0/notebooks/synthetic/demo_gpfa_ssm.ipynb +562 -0
  35. mbrila-0.1.0/notebooks/synthetic/demo_lds.ipynb +495 -0
  36. mbrila-0.1.0/notebooks/synthetic/demo_matern.ipynb +613 -0
  37. mbrila-0.1.0/notebooks/synthetic/demo_mdlag_freq.ipynb +651 -0
  38. mbrila-0.1.0/notebooks/synthetic/demo_mdlag_ssm.ipynb +660 -0
  39. mbrila-0.1.0/notebooks/synthetic/demo_mdlag_time.ipynb +617 -0
  40. mbrila-0.1.0/notebooks/v1v2/demo_adm.ipynb +662 -0
  41. mbrila-0.1.0/notebooks/v1v2/demo_custom_kernel.ipynb +678 -0
  42. mbrila-0.1.0/notebooks/v1v2/demo_dlag.ipynb +643 -0
  43. mbrila-0.1.0/notebooks/v1v2/demo_dlag_ssm.ipynb +656 -0
  44. mbrila-0.1.0/notebooks/v1v2/demo_mdlag_freq.ipynb +663 -0
  45. mbrila-0.1.0/notebooks/v1v2/demo_mdlag_ssm.ipynb +669 -0
  46. mbrila-0.1.0/notebooks/v1v2/demo_mdlag_time.ipynb +663 -0
  47. mbrila-0.1.0/pyproject.toml +133 -0
  48. mbrila-0.1.0/src/mbrila/__init__.py +107 -0
  49. mbrila-0.1.0/src/mbrila/_testing/__init__.py +5 -0
  50. mbrila-0.1.0/src/mbrila/_testing/synthetic.py +63 -0
  51. mbrila-0.1.0/src/mbrila/api.py +105 -0
  52. mbrila-0.1.0/src/mbrila/core/__init__.py +38 -0
  53. mbrila-0.1.0/src/mbrila/core/base_model.py +217 -0
  54. mbrila-0.1.0/src/mbrila/core/data.py +210 -0
  55. mbrila-0.1.0/src/mbrila/core/delay_spec.py +95 -0
  56. mbrila-0.1.0/src/mbrila/core/inference_engine.py +122 -0
  57. mbrila-0.1.0/src/mbrila/core/kernel_spec.py +109 -0
  58. mbrila-0.1.0/src/mbrila/core/latent_spec.py +139 -0
  59. mbrila-0.1.0/src/mbrila/core/observation_spec.py +86 -0
  60. mbrila-0.1.0/src/mbrila/core/registry.py +62 -0
  61. mbrila-0.1.0/src/mbrila/delays/__init__.py +7 -0
  62. mbrila-0.1.0/src/mbrila/delays/fixed.py +136 -0
  63. mbrila-0.1.0/src/mbrila/delays/none.py +62 -0
  64. mbrila-0.1.0/src/mbrila/delays/time_varying.py +172 -0
  65. mbrila-0.1.0/src/mbrila/dynamics/__init__.py +19 -0
  66. mbrila-0.1.0/src/mbrila/dynamics/exact_gp.py +376 -0
  67. mbrila-0.1.0/src/mbrila/dynamics/free_lds.py +184 -0
  68. mbrila-0.1.0/src/mbrila/dynamics/kernel_to_sde.py +304 -0
  69. mbrila-0.1.0/src/mbrila/dynamics/markov_gp.py +301 -0
  70. mbrila-0.1.0/src/mbrila/dynamics/ssm_base.py +75 -0
  71. mbrila-0.1.0/src/mbrila/frequency/__init__.py +10 -0
  72. mbrila-0.1.0/src/mbrila/frequency/fft.py +112 -0
  73. mbrila-0.1.0/src/mbrila/inference/__init__.py +17 -0
  74. mbrila-0.1.0/src/mbrila/inference/ard_helpers.py +352 -0
  75. mbrila-0.1.0/src/mbrila/inference/em_exact.py +390 -0
  76. mbrila-0.1.0/src/mbrila/inference/hmm/__init__.py +0 -0
  77. mbrila-0.1.0/src/mbrila/inference/kalman/__init__.py +23 -0
  78. mbrila-0.1.0/src/mbrila/inference/kalman/parallel.py +545 -0
  79. mbrila-0.1.0/src/mbrila/inference/kalman/sequential.py +305 -0
  80. mbrila-0.1.0/src/mbrila/inference/kalman/state.py +114 -0
  81. mbrila-0.1.0/src/mbrila/inference/kalman_em.py +576 -0
  82. mbrila-0.1.0/src/mbrila/inference/optim.py +54 -0
  83. mbrila-0.1.0/src/mbrila/inference/vem_ard.py +512 -0
  84. mbrila-0.1.0/src/mbrila/inference/vem_ard_freq.py +740 -0
  85. mbrila-0.1.0/src/mbrila/inference/vem_kalman_ard.py +573 -0
  86. mbrila-0.1.0/src/mbrila/init/__init__.py +12 -0
  87. mbrila-0.1.0/src/mbrila/init/factor_analysis.py +238 -0
  88. mbrila-0.1.0/src/mbrila/init/pcca.py +159 -0
  89. mbrila-0.1.0/src/mbrila/init/scale_anchor.py +163 -0
  90. mbrila-0.1.0/src/mbrila/kernels/__init__.py +43 -0
  91. mbrila-0.1.0/src/mbrila/kernels/base.py +120 -0
  92. mbrila-0.1.0/src/mbrila/kernels/matern.py +250 -0
  93. mbrila-0.1.0/src/mbrila/kernels/mose.py +338 -0
  94. mbrila-0.1.0/src/mbrila/kernels/validate.py +184 -0
  95. mbrila-0.1.0/src/mbrila/metrics/__init__.py +0 -0
  96. mbrila-0.1.0/src/mbrila/models/__init__.py +9 -0
  97. mbrila-0.1.0/src/mbrila/models/adm.py +523 -0
  98. mbrila-0.1.0/src/mbrila/models/dlag.py +629 -0
  99. mbrila-0.1.0/src/mbrila/models/gpfa.py +316 -0
  100. mbrila-0.1.0/src/mbrila/models/lds.py +218 -0
  101. mbrila-0.1.0/src/mbrila/models/mdlag.py +457 -0
  102. mbrila-0.1.0/src/mbrila/observations/__init__.py +15 -0
  103. mbrila-0.1.0/src/mbrila/observations/ard.py +668 -0
  104. mbrila-0.1.0/src/mbrila/observations/linear_regression.py +211 -0
  105. mbrila-0.1.0/src/mbrila/observations/multi_region.py +205 -0
  106. mbrila-0.1.0/src/mbrila/synthetic/__init__.py +41 -0
  107. mbrila-0.1.0/src/mbrila/synthetic/multiregion.py +504 -0
  108. mbrila-0.1.0/src/mbrila/synthetic/scenarios.py +166 -0
  109. mbrila-0.1.0/src/mbrila/utils/__init__.py +0 -0
  110. mbrila-0.1.0/src/mbrila/utils/device.py +33 -0
  111. mbrila-0.1.0/tests/__init__.py +0 -0
  112. mbrila-0.1.0/tests/integration/__init__.py +0 -0
  113. mbrila-0.1.0/tests/recovery/__init__.py +0 -0
  114. mbrila-0.1.0/tests/recovery/test_adm_recovery.py +108 -0
  115. mbrila-0.1.0/tests/recovery/test_dlag_recovery.py +128 -0
  116. mbrila-0.1.0/tests/recovery/test_mdlag_freq_recovery.py +133 -0
  117. mbrila-0.1.0/tests/recovery/test_mdlag_recovery.py +170 -0
  118. mbrila-0.1.0/tests/recovery/test_mdlag_ssm_recovery.py +185 -0
  119. mbrila-0.1.0/tests/unit/__init__.py +0 -0
  120. mbrila-0.1.0/tests/unit/test_adm_components.py +316 -0
  121. mbrila-0.1.0/tests/unit/test_adm_pcca_init.py +122 -0
  122. mbrila-0.1.0/tests/unit/test_ard_helpers.py +268 -0
  123. mbrila-0.1.0/tests/unit/test_ard_observation.py +423 -0
  124. mbrila-0.1.0/tests/unit/test_core_data.py +114 -0
  125. mbrila-0.1.0/tests/unit/test_core_engine_and_model.py +257 -0
  126. mbrila-0.1.0/tests/unit/test_core_specs.py +181 -0
  127. mbrila-0.1.0/tests/unit/test_cosine_lr_min.py +173 -0
  128. mbrila-0.1.0/tests/unit/test_dlag_components.py +168 -0
  129. mbrila-0.1.0/tests/unit/test_dlag_ssm.py +328 -0
  130. mbrila-0.1.0/tests/unit/test_em_exact.py +184 -0
  131. mbrila-0.1.0/tests/unit/test_exact_gp.py +284 -0
  132. mbrila-0.1.0/tests/unit/test_factor_analysis.py +81 -0
  133. mbrila-0.1.0/tests/unit/test_fixed_delay.py +111 -0
  134. mbrila-0.1.0/tests/unit/test_frequency.py +300 -0
  135. mbrila-0.1.0/tests/unit/test_gpfa_and_build_model.py +352 -0
  136. mbrila-0.1.0/tests/unit/test_kalman_em_parallel_parity.py +48 -0
  137. mbrila-0.1.0/tests/unit/test_kalman_parallel.py +228 -0
  138. mbrila-0.1.0/tests/unit/test_kalman_sequential.py +174 -0
  139. mbrila-0.1.0/tests/unit/test_kalman_state.py +72 -0
  140. mbrila-0.1.0/tests/unit/test_kernel_extensibility.py +313 -0
  141. mbrila-0.1.0/tests/unit/test_lagged_cov_grid.py +237 -0
  142. mbrila-0.1.0/tests/unit/test_lds_naive_ssm.py +250 -0
  143. mbrila-0.1.0/tests/unit/test_linear_regression.py +131 -0
  144. mbrila-0.1.0/tests/unit/test_mdlag_ssm.py +328 -0
  145. mbrila-0.1.0/tests/unit/test_mdlag_ssm_joint_ll_em.py +161 -0
  146. mbrila-0.1.0/tests/unit/test_mose_grad.py +80 -0
  147. mbrila-0.1.0/tests/unit/test_no_delay.py +166 -0
  148. mbrila-0.1.0/tests/unit/test_pcca_init.py +172 -0
  149. mbrila-0.1.0/tests/unit/test_registry.py +57 -0
  150. mbrila-0.1.0/tests/unit/test_variational_kalman_inputs.py +265 -0
  151. mbrila-0.1.0/tests/unit/test_vem_ard.py +243 -0
  152. mbrila-0.1.0/tests/unit/test_vem_ard_freq.py +352 -0
  153. mbrila-0.1.0/uv.lock +1446 -0
@@ -0,0 +1,24 @@
1
+ name: ci
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+
8
+ concurrency:
9
+ group: ${{ github.workflow }}-${{ github.ref }}
10
+ cancel-in-progress: true
11
+
12
+ jobs:
13
+ check:
14
+ runs-on: ubuntu-latest
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+ - uses: astral-sh/setup-uv@v3
18
+ with:
19
+ enable-cache: true
20
+ - run: uv sync --extra dev
21
+ - run: uv run ruff check src tests
22
+ - run: uv run ruff format --check src tests
23
+ - run: uv run mypy src
24
+ - run: uv run pytest -q -m "not recovery"
@@ -0,0 +1,78 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.egg-info/
6
+ .Python
7
+ build/
8
+ dist/
9
+ wheels/
10
+ *.egg
11
+
12
+ # Virtual environments
13
+ # (uv.lock and .python-version are intentionally tracked — do not ignore)
14
+ .venv/
15
+ venv/
16
+ ENV/
17
+ .env
18
+
19
+ # Tooling caches
20
+ .pytest_cache/
21
+ .mypy_cache/
22
+ .ruff_cache/
23
+ .coverage
24
+ .coverage.*
25
+ htmlcov/
26
+ .cache/
27
+
28
+ # Editors / tools / OS
29
+ .vscode/
30
+ .idea/
31
+ .claude/
32
+ *.swp
33
+ *.swo
34
+ .DS_Store
35
+ Thumbs.db
36
+ *~
37
+
38
+ # Documentation builds
39
+ site/
40
+ docs/_build/
41
+
42
+ # Jupyter
43
+ .ipynb_checkpoints/
44
+
45
+ # SLURM / cluster job logs
46
+ *.out
47
+ slurm-*.out
48
+ *.log
49
+
50
+ # scratch/ — only the demo scripts (scratch/*.py) are part of the
51
+ # release; every generated output (figures, arrays, text reports) is
52
+ # ignored. scratch holds no tracked .txt, so ignore them all.
53
+ scratch/**/*.png
54
+ scratch/**/*.npz
55
+ scratch/**/*.pt
56
+ scratch/**/*.pdf
57
+ scratch/**/*.txt
58
+ scratch/
59
+
60
+ examples/**/demo_outputs*
61
+
62
+ # Neural data files (keep large recordings out of the repo)
63
+ *.pkl
64
+ # Exception: the V1/V2 demo data shipped with the V1V2 examples.
65
+ !data/demo_v1v2_data.pkl
66
+
67
+ # Internal dev files — not part of the public release
68
+ CLAUDE.md
69
+ python-pytorch-multi-brain-region-cheerful-clarke.md
70
+ job.sh
71
+ job_v1v2.sh
72
+ job_test.sh
73
+ scripts/
74
+ # Test for scripts/check_no_trial_loops.py — excluded with scripts/ itself.
75
+ tests/unit/test_check_no_trial_loops.py
76
+ profile_adm_vs_dlag_ssm.py
77
+ tmp*/
78
+ notebooks/_test/
@@ -0,0 +1 @@
1
+ 3.12
mbrila-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 mbrila contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
mbrila-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,457 @@
1
+ Metadata-Version: 2.4
2
+ Name: mbrila
3
+ Version: 0.1.0
4
+ Summary: Multiple Brain Region Interaction using Latent Analysis — unified PyTorch library for multi-region neural latent variable models.
5
+ Project-URL: Homepage, https://github.com/BRAINML-GT/MBRILA
6
+ Project-URL: Issues, https://github.com/BRAINML-GT/MBRILA/issues
7
+ Project-URL: Repository, https://github.com/BRAINML-GT/MBRILA
8
+ Author-email: Weihan Li <weihanli@gatech.edu>
9
+ License: MIT
10
+ License-File: LICENSE
11
+ Keywords: gaussian-process,latent-variable,multi-region,neuroscience,pytorch,state-space-model
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: MIT License
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
18
+ Requires-Python: <3.13,>=3.12
19
+ Requires-Dist: matplotlib>=3.10.9
20
+ Requires-Dist: numpy>=2.0
21
+ Requires-Dist: scikit-learn>=1.5
22
+ Requires-Dist: scipy>=1.13
23
+ Requires-Dist: torch>=2.5
24
+ Requires-Dist: tqdm>=4.66
25
+ Provides-Extra: dev
26
+ Requires-Dist: hypothesis>=6.100; extra == 'dev'
27
+ Requires-Dist: mypy>=1.11; extra == 'dev'
28
+ Requires-Dist: pytest-cov>=5.0; extra == 'dev'
29
+ Requires-Dist: pytest>=8.0; extra == 'dev'
30
+ Requires-Dist: ruff>=0.6; extra == 'dev'
31
+ Provides-Extra: docs
32
+ Requires-Dist: mkdocs-material>=9.5; extra == 'docs'
33
+ Requires-Dist: mkdocs>=1.6; extra == 'docs'
34
+ Requires-Dist: mkdocstrings[python]>=0.26; extra == 'docs'
35
+ Description-Content-Type: text/markdown
36
+
37
+ # mbrila
38
+
39
+ **M**ultiple **B**rain **R**egion **I**nteraction using **L**atent **A**nalysis
40
+
41
+ [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
42
+ [![Python 3.12](https://img.shields.io/badge/python-3.12-blue.svg)](https://www.python.org/)
43
+
44
+ A unified, GPU-native PyTorch library for **multi-region neural latent
45
+ variable models** — methods that infer shared low-dimensional latent
46
+ processes across simultaneously recorded brain regions, optionally with
47
+ **inter-region communication delays** (constant or time-varying).
48
+
49
+ ## Why GP-SSM
50
+
51
+ Multi-region neural population dynamics are typically slow and smooth —
52
+ they have a **Gaussian process structure** on the latent level (this is
53
+ the modelling assumption behind GPFA, DLAG, mDLAG, …). The price of a
54
+ dense GP prior is **O(T³) inference cost** in the trial length T:
55
+ fitting 200-bin trials with 5 latents already needs a 1000-dim Cholesky
56
+ per EM iter. Free state-space models (linear LDS, RNN) scale to long T
57
+ but throw away GP smoothness and timescale priors.
58
+
59
+ `mbrila` builds on the recent observation
60
+ ([Li et al., ICML 2025](https://proceedings.mlr.press/v267/li25ck.html))
61
+ that **any stationary GP can be approximately lifted into a Markov
62
+ state-space model** via an AR(P) realisation. The resulting GP-SSM
63
+ keeps the GP modelling structure (timescale priors, ARD, inter-region
64
+ delays as kernel arguments) but inherits the **O(log T) GPU parallel
65
+ Kalman scan** for inference.
66
+
67
+ ## What you can do with this library
68
+
69
+ - **Drop in any of the standard methods** — ADM, DLAG (exact GP),
70
+ DLAG-SSM, mDLAG (time-domain / frequency-domain circulant /
71
+ Kalman-SSM), GPFA-SSM, free LDS — with a single `model = Preset(...)`
72
+ call and a unified `model.fit(data)` interface.
73
+ - **Swap GP kernels** without touching the inference engine. Any
74
+ stationary scalar kernel — MOSE/RBF, Matérn-1/2, 3/2, 5/2, or your own
75
+ `BaseKernel` subclass — slots in via `kernel_factory_*` callables.
76
+ Matérn-5/2 even has an **exact** finite-state SDE form (no AR(P)
77
+ approximation needed); see [`notebooks/synthetic/demo_matern.ipynb`](notebooks/synthetic/demo_matern.ipynb).
78
+ - **Write your own kernel in ~10 lines** — implement `cov(τ)` on a
79
+ `BaseKernel` subclass and plug it into every GP-prior model. See
80
+ [`notebooks/synthetic/demo_custom_kernel.ipynb`](notebooks/synthetic/demo_custom_kernel.ipynb)
81
+ for a Rational-Quadratic-kernel tutorial.
82
+ - **Compare methods fairly** — every method ships with the same
83
+ evaluation pipeline (delay-recovery RMSE on synthetic, NLB-style
84
+ held-out-neuron co-smoothing on real data, both with multi-seed
85
+ averaging).
86
+ - **Pick speed vs accuracy per use case** — exact GP (`O(T³)`,
87
+ reference) ↔ SSM-GP via Markov AR(P) lift (`O(log T)` on GPU) ↔
88
+ frequency-domain circulant approximation.
89
+ - **Browse results in Jupyter** — every method has a self-contained
90
+ notebook (see [`notebooks/`](notebooks/)) with results baked in.
91
+
92
+ All operations are **fully batched over trials** and **pure PyTorch**
93
+ — no per-trial Python loops, GPU by
94
+ default.
95
+
96
+ ---
97
+
98
+ ## The GP-SSM framework
99
+
100
+ The headline modelling choice in `mbrila` is the **latent prior class**
101
+ — what assumption you make about how latent state evolves in time. Four
102
+ families:
103
+
104
+ | Latent prior | What it assumes | Inference cost | Engine class |
105
+ |---|---|---|---|
106
+ | **Dense exact GP** | latent ~ GP with a stationary kernel; covariance is a full `T × T` matrix | `O(T³)` per EM iter | `ExactEMEngine` · `VEMARDEngine` (ARD) · `VEMARDFreqEngine` (circulant ≈ dense in freq domain) |
107
+ | **SSM-GP via AR(P) lift** | latent ~ same GP, **approximated** by a `P`-step Markov state-space model | `O(log T)` (parallel scan) | `KalmanEMEngine` · `VEMKalmanARDEngine` (ARD) |
108
+ | **SSM-GP via exact SDE** | latent ~ Matérn-`p/2` GP, **exactly** representable as a finite-state SDE | `O(log T)` | `KalmanEMEngine` |
109
+ | **Free SSM** | latent ~ generic linear-Gaussian Markov chain, learnable `(A, Q)`, no GP / no kernel | `O(log T)` | `KalmanEMEngine` |
110
+
111
+ Within a chosen latent prior, the model is further specified by:
112
+
113
+ - **GP kernel** (only for the three GP families): MOSE/RBF, Matérn,
114
+ or any user-defined `BaseKernel`. **The kernel also encodes the
115
+ inter-region communication delay** through its lagged covariance
116
+ `cov(τ + δ_j − δ_i)`. The delay parameterisation — `NoDelay`,
117
+ `FixedDelay` (constant `δ`), `TimeVaryingDelay` (`δ(t)`) — is
118
+ implemented at the dynamics layer (`mbrila.delays`) and is what
119
+ separates GPFA-SSM (no delay) from DLAG (fixed) from ADM (time-varying).
120
+ Practical note: `FixedDelay` drops in cleanly with any kernel, but
121
+ `TimeVaryingDelay` is significantly more expressive and its
122
+ high-dimensional `δ(t)` parameter space needs a careful
123
+ initialisation — specifically the rank-1 deflation init that
124
+ `ADM` ships, which breaks the symmetry between latent components
125
+ before joint training. If you want a time-varying-delay variant on a
126
+ custom kernel, follow `ADM` end-to-end as the reference — not just
127
+ the delay class, but the initialisation recipe.
128
+ - **Observation model**: linear-Gaussian
129
+ (`MultiRegionLinearObservation`) for ADM / DLAG / GPFA / LDS, or
130
+ variational ARD (`ARDObservation`) for the mDLAG family.
131
+ - **Model structure**: `LatentSpec(n_regions, n_across, n_within)` —
132
+ the per-region neuron-count tuple plus the across-region (shared)
133
+ and within-region (private) latent counts.
134
+
135
+ Engine compatibility is enforced by capability matching, e.g. `LDS`
136
+ cannot accept `ExactEMEngine` because it has no kernel; `MDLAG` with
137
+ ARD cannot accept the non-ARD `KalmanEMEngine`.
138
+
139
+ ### Model presets
140
+
141
+ A "method" is just a name for a configured combination. The headline
142
+ presets cover the established multi-region methods, plus their
143
+ SSM-approximate cousins:
144
+
145
+ | Preset | Latent prior class | Delay | Engine class | Notes |
146
+ |---|---|---|---|---|
147
+ | **`ADM`** | **SSM-GP** (AR(P) lift of MOSE/RBF) | time-varying `δ(t)` | `KalmanEMEngine` | `O(log T)` parallel scan |
148
+ | **`DLAG`** | dense exact GP (MOSE/RBF) | constant `δ` | `ExactEMEngine` (default) **or** `KalmanEMEngine` | the second route gives a DLAG-SSM AR(P) approximation |
149
+ | **`MDLAG`** | dense exact GP + **ARD** | constant `δ` | `VEMARDEngine` (time) / `VEMARDFreqEngine` (freq, ~22× faster) / `VEMKalmanARDEngine` (SSM) | ARD prunes redundant latents automatically |
150
+ | **`GPFA-SSM`** | SSM-GP, **no delay** | — | `KalmanEMEngine` | SSM (AR(P) lift) approximation of GPFA; shared-only baseline (`n_within = 0`). |
151
+ | **`LDS`** | **free SSM** (no GP prior) | — | `KalmanEMEngine` | no-kernel baseline |
152
+
153
+ The framework is **user-extensible** along the kernel dimension: writing a new stationary
154
+ kernel by subclassing `BaseKernel` and supplying `cov(τ)` lets you plug
155
+ that kernel into any GP-prior model. See
156
+ [`notebooks/synthetic/demo_custom_kernel.ipynb`](notebooks/synthetic/demo_custom_kernel.ipynb)
157
+ for an end-to-end Rational-Quadratic-kernel tutorial.
158
+
159
+ The presets shipped here all model inter-region interaction as a
160
+ **communication delay** in the kernel's lagged covariance. This is one
161
+ particular hypothesis about how brain regions interact — and the GP
162
+ kernel is the natural place to encode others. We encourage users to
163
+ design new kernels that capture different forms of inter-region
164
+ interaction and contribute them back.
165
+
166
+ ---
167
+
168
+ ## Installation
169
+
170
+ `mbrila` targets Python 3.12 (PyTorch does not yet support 3.13) and is
171
+ managed with [`uv`](https://docs.astral.sh/uv/):
172
+
173
+ ```bash
174
+ git clone https://github.com/BRAINML-GT/MBRILA.git mbrila
175
+ cd mbrila
176
+ uv sync # runtime dependencies
177
+ uv sync --extra dev # + dev tools (pytest, ruff, mypy)
178
+ ```
179
+
180
+ Or with plain pip:
181
+
182
+ ```bash
183
+ pip install -e .
184
+ ```
185
+
186
+ Default device is CUDA when available, CPU otherwise; nothing is
187
+ hard-coded — pass `--device cpu` or `device="cpu"` to force CPU.
188
+
189
+ ### Quickstart
190
+
191
+ The fastest way to see the library in action is to open one of the
192
+ Jupyter notebooks — they already have results baked in:
193
+
194
+ ```bash
195
+ jupyter lab notebooks/synthetic/demo_adm.ipynb # ADM on synthetic delay-recovery
196
+ jupyter lab notebooks/v1v2/demo_dlag_ssm.ipynb # DLAG-SSM on real V1/V2 data
197
+ jupyter lab notebooks/synthetic/demo_custom_kernel.ipynb # plug in your own GP kernel
198
+ ```
199
+
200
+ Each notebook is self-contained and produces every diagnostic figure
201
+ inline. See [Examples](#examples) below for the full list.
202
+
203
+ ---
204
+
205
+ ## Examples
206
+
207
+ > 📓 **Start with the Jupyter notebooks** in [`notebooks/`](notebooks/) —
208
+ > they are the easiest entry point. Every method has its own
209
+ > self-contained notebook that loads data, builds the model, fits it,
210
+ > and produces every diagnostic figure inline (convergence,
211
+ > inter-region delay, smoother latents, PSTH heatmap, co-smoothing
212
+ > reconstruction, ARD α bar, headline metric). Just open one and read
213
+ > top-to-bottom — no CLI needed, no shell scripts to read, results are
214
+ > baked into the file so you can browse them even before running.
215
+
216
+ For automation / sweeps / SLURM jobs, the CLI demos in
217
+ [`examples/`](examples/) cover the same methods with the same configs
218
+ (every notebook has a one-to-one CLI counterpart with identical
219
+ defaults).
220
+
221
+ ### Notebooks (recommended)
222
+
223
+ ```
224
+ notebooks/
225
+ ├── synthetic/ # ground-truth-delay recovery on synthetic GP data
226
+ │ ├── demo_adm.ipynb
227
+ │ ├── demo_dlag.ipynb (exact-GP engine)
228
+ │ ├── demo_dlag_ssm.ipynb (SSM-GP engine)
229
+ │ ├── demo_mdlag_time.ipynb / demo_mdlag_freq.ipynb / demo_mdlag_ssm.ipynb
230
+ │ ├── demo_gpfa_ssm.ipynb
231
+ │ ├── demo_lds.ipynb
232
+ │ ├── demo_matern.ipynb (Matérn-5/2 with exact SDE form)
233
+ │ └── demo_custom_kernel.ipynb (user-defined Rational Quadratic — kernel-as-axis tutorial)
234
+ └── v1v2/ # real-data co-smoothing on V1/V2 visual-cortex recordings
235
+ ├── demo_adm.ipynb / demo_dlag.ipynb / demo_dlag_ssm.ipynb
236
+ ├── demo_mdlag_time.ipynb / demo_mdlag_freq.ipynb / demo_mdlag_ssm.ipynb
237
+ └── demo_custom_kernel.ipynb
238
+ ```
239
+
240
+ Every notebook begins with a markdown banner stating the engine class
241
+ (dense exact GP / SSM-GP / SSM-GP exact-SDE / free SSM) and a config
242
+ table, and then runs the full fit-evaluate-plot pipeline. Diagnostic
243
+ figures are produced inline using shared helpers in
244
+ [`notebooks/nb_helpers.py`](notebooks/nb_helpers.py).
245
+
246
+ ### CLI scripts (for sweeps / SLURM)
247
+
248
+ ```
249
+ examples/
250
+ ├── synthetic/ # same as notebooks/synthetic/ but CLI
251
+ │ ├── demo_adm.py / demo_dlag.py / demo_dlag_ssm.py
252
+ │ ├── demo_mdlag_time.py / demo_mdlag_freq.py / demo_mdlag_ssm.py
253
+ │ ├── demo_gpfa_ssm.py / demo_lds.py
254
+ │ ├── demo_matern.py / demo_custom_kernel.py
255
+ │ └── demo_common.py
256
+ └── v1v2/ # same as notebooks/v1v2/ but CLI
257
+ ├── demo_adm.py / demo_dlag.py / demo_dlag_ssm.py
258
+ ├── demo_mdlag_time.py / demo_mdlag_freq.py / demo_mdlag_ssm.py
259
+ ├── demo_custom_kernel.py
260
+ └── v1v2_common.py
261
+ ```
262
+
263
+ Each CLI demo accepts `--help` for the full flag list.
264
+
265
+ ### Synthetic data — delay recovery
266
+
267
+ Synthetic multi-region data is sampled from exact Gaussian processes
268
+ with a **known ground-truth delay**, so the headline metric is
269
+ delay-recovery RMSE against truth.
270
+
271
+ ```bash
272
+ # CLI (one method per command)
273
+ uv run python examples/synthetic/demo_adm.py
274
+ uv run python examples/synthetic/demo_dlag.py # exact-GP DLAG
275
+ uv run python examples/synthetic/demo_dlag_ssm.py # SSM-GP DLAG
276
+ uv run python examples/synthetic/demo_mdlag_time.py # dense time-domain mDLAG
277
+ uv run python examples/synthetic/demo_mdlag_freq.py # frequency-domain mDLAG
278
+ uv run python examples/synthetic/demo_mdlag_ssm.py # mDLAG-SSM (Kalman + ARD)
279
+ uv run python examples/synthetic/demo_gpfa_ssm.py # shared-only SSM-GP, no delay
280
+ uv run python examples/synthetic/demo_lds.py # free-SSM baseline
281
+ uv run python examples/synthetic/demo_matern.py # Matérn-5/2 kernel
282
+ uv run python examples/synthetic/demo_custom_kernel.py # custom RQ kernel
283
+ ```
284
+
285
+ Each run writes per-pair delay overlays, per-region latent traces, y
286
+ reconstruction, convergence trace, and `summary.json` into the
287
+ preset's output directory.
288
+
289
+ ### Real data — V1/V2 visual cortex
290
+
291
+ The V1/V2 dataset shipped with the demos is from **Semedo et al.,
292
+ *Cortical Areas Interact through a Communication Subspace*, Neuron
293
+ 2019** — see [Citation](#citation) below.
294
+
295
+ The shipped pickle ([`data/demo_v1v2_data.pkl`](data/demo_v1v2_data.pkl))
296
+ is one recording session, 400 trials, with spike counts
297
+ Gaussian-smoothed in time and z-scored so the linear-Gaussian
298
+ emission models in this library see well-behaved inputs. The layout is
299
+ a dict with `V1` / `V2` arrays of shape `(n_trials, T, n_neurons)` =
300
+ `(400, 64, 72)` / `(400, 64, 22)`.
301
+
302
+ Real recordings have **no ground-truth delay**, so the headline metric
303
+ is **held-out-neuron co-smoothing RMSE**: a fraction of neurons per
304
+ region is hidden from inference and predicted from the posterior latents
305
+ inferred from the remaining context neurons. Reported per region:
306
+ `holdout_psth_rmse_{V1, V2}` (PSTH-level prediction).
307
+
308
+ V1V2 demos vary the **`--split-seed`** (train/val/test partition) and
309
+ **average over `--n-holdout-seeds`** different held-out-neuron masks per
310
+ split. The 3-split-seed std is the reported method-stability error bar:
311
+
312
+ ```bash
313
+ DATA=data/demo_v1v2_data.pkl # swap in your own pickle here
314
+ SEEDS=(0 1 2)
315
+ N_HOLDOUT_SEEDS=3
316
+
317
+ for SPLIT_SEED in "${SEEDS[@]}"; do
318
+ uv run python examples/v1v2/demo_adm.py \
319
+ --data-path "$DATA" \
320
+ --seed 0 --split-seed "${SPLIT_SEED}" \
321
+ --holdout-seed 0 --n-holdout-seeds "${N_HOLDOUT_SEEDS}" \
322
+ --out-dir "examples/v1v2/demo_outputs/adm/split_${SPLIT_SEED}"
323
+ done
324
+
325
+ # Other methods: swap `demo_adm.py` for any of
326
+ # demo_dlag.py | demo_dlag_ssm.py
327
+ # demo_mdlag_time.py | demo_mdlag_freq.py | demo_mdlag_ssm.py
328
+ # demo_custom_kernel.py
329
+ ```
330
+
331
+ Then aggregate the methods into one comparison:
332
+
333
+ ```bash
334
+ uv run python examples/v1v2/compare_v1v2_runs.py \
335
+ --label adm --runs "examples/v1v2/demo_outputs/adm/split_*" \
336
+ --label dlag --runs "examples/v1v2/demo_outputs/dlag/split_*" \
337
+ --label dlag_ssm --runs "examples/v1v2/demo_outputs/dlag_ssm/split_*" \
338
+ --label mdlag_time --runs "examples/v1v2/demo_outputs/mdlag_time/split_*" \
339
+ --label mdlag_freq --runs "examples/v1v2/demo_outputs/mdlag_freq/split_*" \
340
+ --label mdlag_ssm --runs "examples/v1v2/demo_outputs/mdlag_ssm/split_*" \
341
+ --label custom_kernel --runs "examples/v1v2/demo_outputs/custom_kernel/split_*" \
342
+ --out-dir examples/v1v2/demo_outputs/_compare
343
+ ```
344
+
345
+ ---
346
+
347
+ ## Directory layout
348
+
349
+ ```
350
+ src/mbrila/
351
+ ├── core/ abstract base classes + MultiRegionData container + LatentSpec
352
+ ├── kernels/ MOSE (RBF) · Matérn-1/2, 3/2, 5/2 · BaseKernel ABC (user extension point)
353
+ ├── delays/ NoDelay · FixedDelay · TimeVaryingDelay
354
+ ├── dynamics/ MarkovianGPLatent (kernel → AR(P) lift) · ExactGPLatent · FreeLDSLatent
355
+ ├── observations/ MultiRegionLinearObservation · ARDObservation
356
+ ├── inference/ ExactEMEngine · KalmanEMEngine · VEMARDEngine (time / freq) · VEMKalmanARDEngine
357
+ │ (parallel-scan Kalman filter/smoother, Särkkä & García-Fernández 2021)
358
+ ├── init/ pCCA emission init · rank-1 deflation init · scale anchor
359
+ ├── frequency/ FFT utilities for the frequency-domain mDLAG engine
360
+ ├── models/ ADM · DLAG · MDLAG · GPFA · LDS — assembled presets
361
+ ├── synthetic/ multi-region scenario generator (exact-GP sampling, configurable
362
+ │ delay shapes / per-latent heterogeneity / SNR)
363
+ ├── metrics/ evaluation metrics
364
+ └── utils/ device handling + shared helpers
365
+
366
+ examples/synthetic/ end-to-end CLI demos on synthetic data (ground-truth delay)
367
+ examples/v1v2/ end-to-end CLI demos on V1/V2 data (co-smoothing metric)
368
+ notebooks/synthetic/ Jupyter version of every synthetic demo
369
+ notebooks/v1v2/ Jupyter version of every V1V2 demo
370
+ ```
371
+
372
+ ---
373
+
374
+ ## Evaluation metrics
375
+
376
+ - **Synthetic data** — delay-recovery RMSE: how closely the fitted
377
+ delay matches the known ground-truth delay (in time bins).
378
+ - **Real data** — co-smoothing RMSE on held-out neurons (NLB-style),
379
+ per region, both trial-mean (PSTH) and trial-by-trial. This is the metric that fairly compares model classes on real data, because full-set reconstruction RMSE saturates at the spike-noise floor.
380
+
381
+ Log-likelihood / ELBO traces are kept as **convergence diagnostics
382
+ only**, never as a cross-model performance metric — different engines
383
+ optimise different surrogates (joint LL, marginal LL, true ELBO,
384
+ proxy ELBO), and absolute values are not comparable across model
385
+ classes.
386
+
387
+ ---
388
+
389
+ ## Citation
390
+
391
+ If you use `mbrila`, please cite the ADM paper that introduces this
392
+ GP-SSM framework:
393
+
394
+ ```bibtex
395
+ @inproceedings{li2025learning,
396
+ title={Learning Time-Varying Multi-Region Brain Communications via Scalable Markovian Gaussian Processes},
397
+ author={Li, Weihan and Wang, Yule and Li, Chengrui and Wu, Anqi},
398
+ booktitle={International Conference on Machine Learning},
399
+ pages={36021--36041},
400
+ year={2025},
401
+ organization={PMLR}
402
+ }
403
+ ```
404
+
405
+ If you additionally use models reimplemented here, please also cite
406
+ their original publications:
407
+
408
+ **DLAG** — Gokcen et al., Nature Computational Science 2022.
409
+ [doi:10.1038/s43588-022-00282-5](https://doi.org/10.1038/s43588-022-00282-5)
410
+
411
+ **mDLAG** — Gokcen et al., NeurIPS 2023.
412
+ [nips.cc/virtual/2023/poster/70171](https://nips.cc/virtual/2023/poster/70171)
413
+
414
+ **fast-mDLAG** (the `--mdlag-engine freq` path) — Gokcen et al., Neural
415
+ Computation 2025. [doi:10.1162/neco.a.22](https://doi.org/10.1162/neco.a.22)
416
+
417
+ ### Datasets
418
+
419
+ The **V1/V2 visual cortex** data used in `examples/v1v2/` and
420
+ `notebooks/v1v2/` is from:
421
+
422
+ > Semedo, J. D., Zandvakili, A., Machens, C. K., Yu, B. M., & Kohn, A.
423
+ > (2019). *Cortical Areas Interact through a Communication Subspace*.
424
+ > **Neuron**, 102(1), 249–259.
425
+ > [doi:10.1016/j.neuron.2019.01.026](https://doi.org/10.1016/j.neuron.2019.01.026)
426
+
427
+ If you use that data in published work, please cite Semedo et al. 2019
428
+ in addition to `mbrila`.
429
+
430
+ ---
431
+
432
+ ## Contributing
433
+
434
+ Contributions are warmly welcomed — new kernels, new presets, bug
435
+ fixes, documentation improvements, or simply opening an issue with
436
+ your use case. Open a PR or issue on
437
+ [GitHub](https://github.com/BRAINML-GT/MBRILA).
438
+
439
+ ---
440
+
441
+ ## License & Acknowledgements
442
+
443
+ `mbrila` is released under the **MIT License** — see [`LICENSE`](LICENSE).
444
+
445
+ `mbrila` is an **independent PyTorch reimplementation**: it does not
446
+ import or copy any upstream source code. Its models reimplement
447
+ algorithms from separate research codebases:
448
+
449
+ | Model | Reimplemented from | Original author | Original license |
450
+ |---|---|---|---|
451
+ | ADM | Adaptive Delay Model (Python) | Li et al. 2025 | MIT |
452
+ | DLAG | DLAG (MATLAB) | Evren Gokcen et al. 2022 | MIT |
453
+ | mDLAG / fast-mDLAG | fast-mDLAG (MATLAB) | Evren Gokcen et al. 2023, 2025 | MIT |
454
+
455
+ All upstream projects are MIT-licensed; their copyright notices are
456
+ reproduced in [`THIRD_PARTY_NOTICES.md`](THIRD_PARTY_NOTICES.md) as
457
+ an acknowledgement.