lmeeeg 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. lmeeeg-0.1.0/PKG-INFO +85 -0
  2. lmeeeg-0.1.0/README.md +67 -0
  3. lmeeeg-0.1.0/pyproject.toml +31 -0
  4. lmeeeg-0.1.0/setup.cfg +4 -0
  5. lmeeeg-0.1.0/src/lmeeeg/__init__.py +12 -0
  6. lmeeeg-0.1.0/src/lmeeeg/api/__init__.py +14 -0
  7. lmeeeg-0.1.0/src/lmeeeg/api/fit.py +131 -0
  8. lmeeeg-0.1.0/src/lmeeeg/api/infer.py +92 -0
  9. lmeeeg-0.1.0/src/lmeeeg/api/simulate.py +23 -0
  10. lmeeeg-0.1.0/src/lmeeeg/backends/__init__.py +1 -0
  11. lmeeeg-0.1.0/src/lmeeeg/backends/correction/__init__.py +1 -0
  12. lmeeeg-0.1.0/src/lmeeeg/backends/correction/base.py +23 -0
  13. lmeeeg-0.1.0/src/lmeeeg/backends/correction/maxstat_backend.py +71 -0
  14. lmeeeg-0.1.0/src/lmeeeg/backends/correction/mne_cluster_backend.py +92 -0
  15. lmeeeg-0.1.0/src/lmeeeg/backends/correction/mne_tfce_backend.py +87 -0
  16. lmeeeg-0.1.0/src/lmeeeg/backends/lmm/__init__.py +1 -0
  17. lmeeeg-0.1.0/src/lmeeeg/backends/lmm/base.py +33 -0
  18. lmeeeg-0.1.0/src/lmeeeg/backends/lmm/statsmodels_backend.py +126 -0
  19. lmeeeg-0.1.0/src/lmeeeg/backends/ols/__init__.py +1 -0
  20. lmeeeg-0.1.0/src/lmeeeg/backends/ols/base.py +28 -0
  21. lmeeeg-0.1.0/src/lmeeeg/backends/ols/numpy_backend.py +60 -0
  22. lmeeeg-0.1.0/src/lmeeeg/core/__init__.py +1 -0
  23. lmeeeg-0.1.0/src/lmeeeg/core/coding.py +33 -0
  24. lmeeeg-0.1.0/src/lmeeeg/core/contrasts.py +29 -0
  25. lmeeeg-0.1.0/src/lmeeeg/core/design.py +96 -0
  26. lmeeeg-0.1.0/src/lmeeeg/core/formulas.py +64 -0
  27. lmeeeg-0.1.0/src/lmeeeg/core/marginal.py +29 -0
  28. lmeeeg-0.1.0/src/lmeeeg/core/results.py +70 -0
  29. lmeeeg-0.1.0/src/lmeeeg/simulation/__init__.py +23 -0
  30. lmeeeg-0.1.0/src/lmeeeg/simulation/generator.py +657 -0
  31. lmeeeg-0.1.0/src/lmeeeg/simulation/scenarios.py +50 -0
  32. lmeeeg-0.1.0/src/lmeeeg/utils/__init__.py +1 -0
  33. lmeeeg-0.1.0/src/lmeeeg/utils/checks.py +14 -0
  34. lmeeeg-0.1.0/src/lmeeeg/utils/reshape.py +8 -0
  35. lmeeeg-0.1.0/src/lmeeeg/utils/summary.py +8 -0
  36. lmeeeg-0.1.0/src/lmeeeg/viz/__init__.py +1 -0
  37. lmeeeg-0.1.0/src/lmeeeg/viz/mne_helpers.py +8 -0
  38. lmeeeg-0.1.0/src/lmeeeg.egg-info/PKG-INFO +85 -0
  39. lmeeeg-0.1.0/src/lmeeeg.egg-info/SOURCES.txt +40 -0
  40. lmeeeg-0.1.0/src/lmeeeg.egg-info/dependency_links.txt +1 -0
  41. lmeeeg-0.1.0/src/lmeeeg.egg-info/requires.txt +11 -0
  42. lmeeeg-0.1.0/src/lmeeeg.egg-info/top_level.txt +1 -0
lmeeeg-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,85 @@
1
+ Metadata-Version: 2.4
2
+ Name: lmeeeg
3
+ Version: 0.1.0
4
+ Summary: Minimal Python implementation of lmeEEG for random-intercept mass-univariate M/EEG analysis
5
+ Author-email: Hiro YAMASAKI <hiroyoshi.YAMASAKI@univ-amu.fr>
6
+ License: MIT
7
+ Requires-Python: >=3.10
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: numpy>=1.24
10
+ Requires-Dist: pandas>=2.0
11
+ Requires-Dist: statsmodels>=0.14
12
+ Requires-Dist: patsy>=0.5
13
+ Provides-Extra: mne
14
+ Requires-Dist: mne>=1.5; extra == "mne"
15
+ Provides-Extra: dev
16
+ Requires-Dist: pytest>=8.0; extra == "dev"
17
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
18
+
19
+ # LmeEEG
20
+
21
+ Minimal Python package implementing the core lmeEEG workflow for epoched M/EEG data with **random intercepts**:
22
+
23
+ 1. parse a mixed-model style formula at the API edge,
24
+ 2. fit a random-intercept mixed model at each channel × timepoint,
25
+ 3. subtract the fitted random effects to obtain marginal EEG,
26
+ 4. run fast mass-univariate OLS on the marginalized data,
27
+ 5. perform max-stat, cluster, or TFCE correction.
28
+
29
+ ## Current scope
30
+
31
+ - One grouping factor in the public API
32
+ - Random intercept only
33
+ - Trial-wise epoched data shaped `(n_observations, n_channels, n_times)`
34
+ - Cluster / TFCE correction via MNE-Python when installed
35
+ - Tiny simulation utilities for recovery / null checks
36
+
37
+ ## Not yet included
38
+
39
+ - Random slopes
40
+ - Two grouping factors in the public API
41
+ - Real-data validation workflows
42
+ - Parallel / distributed optimization
43
+
44
+ ## Basic example
45
+
46
+ ```python
47
+ import numpy as np
48
+ import pandas as pd
49
+
50
+ from lmeeeg.api.fit import fit_lmm_mass_univariate
51
+ from lmeeeg.api.infer import permute_fixed_effect
52
+ from lmeeeg.simulation.generator import simulate_random_intercept_dataset
53
+
54
+ simulated = simulate_random_intercept_dataset(
55
+ n_subjects=10,
56
+ n_trials_per_subject=12,
57
+ n_channels=4,
58
+ n_times=25,
59
+ effect_channels=[1, 2],
60
+ effect_times=range(8, 14),
61
+ beta=0.8,
62
+ seed=13,
63
+ )
64
+
65
+ fit_result = fit_lmm_mass_univariate(
66
+ eeg=simulated.eeg,
67
+ metadata=simulated.metadata,
68
+ formula="y ~ condition + latency + (1|subject)",
69
+ variable_types={
70
+ "condition": "categorical",
71
+ "latency": "numeric",
72
+ "subject": "group",
73
+ },
74
+ )
75
+
76
+ inference = permute_fixed_effect(
77
+ fit_result=fit_result,
78
+ effect="condition[T.B]",
79
+ correction="maxstat",
80
+ n_permutations=200,
81
+ seed=13,
82
+ )
83
+
84
+ print(inference.corrected_p_values.shape)
85
+ ```
lmeeeg-0.1.0/README.md ADDED
@@ -0,0 +1,67 @@
1
+ # LmeEEG
2
+
3
+ Minimal Python package implementing the core lmeEEG workflow for epoched M/EEG data with **random intercepts**:
4
+
5
+ 1. parse a mixed-model style formula at the API edge,
6
+ 2. fit a random-intercept mixed model at each channel × timepoint,
7
+ 3. subtract the fitted random effects to obtain marginal EEG,
8
+ 4. run fast mass-univariate OLS on the marginalized data,
9
+ 5. perform max-stat, cluster, or TFCE correction.
10
+
11
+ ## Current scope
12
+
13
+ - One grouping factor in the public API
14
+ - Random intercept only
15
+ - Trial-wise epoched data shaped `(n_observations, n_channels, n_times)`
16
+ - Cluster / TFCE correction via MNE-Python when installed
17
+ - Tiny simulation utilities for recovery / null checks
18
+
19
+ ## Not yet included
20
+
21
+ - Random slopes
22
+ - Two grouping factors in the public API
23
+ - Real-data validation workflows
24
+ - Parallel / distributed optimization
25
+
26
+ ## Basic example
27
+
28
+ ```python
29
+ import numpy as np
30
+ import pandas as pd
31
+
32
+ from lmeeeg.api.fit import fit_lmm_mass_univariate
33
+ from lmeeeg.api.infer import permute_fixed_effect
34
+ from lmeeeg.simulation.generator import simulate_random_intercept_dataset
35
+
36
+ simulated = simulate_random_intercept_dataset(
37
+ n_subjects=10,
38
+ n_trials_per_subject=12,
39
+ n_channels=4,
40
+ n_times=25,
41
+ effect_channels=[1, 2],
42
+ effect_times=range(8, 14),
43
+ beta=0.8,
44
+ seed=13,
45
+ )
46
+
47
+ fit_result = fit_lmm_mass_univariate(
48
+ eeg=simulated.eeg,
49
+ metadata=simulated.metadata,
50
+ formula="y ~ condition + latency + (1|subject)",
51
+ variable_types={
52
+ "condition": "categorical",
53
+ "latency": "numeric",
54
+ "subject": "group",
55
+ },
56
+ )
57
+
58
+ inference = permute_fixed_effect(
59
+ fit_result=fit_result,
60
+ effect="condition[T.B]",
61
+ correction="maxstat",
62
+ n_permutations=200,
63
+ seed=13,
64
+ )
65
+
66
+ print(inference.corrected_p_values.shape)
67
+ ```
@@ -0,0 +1,31 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "lmeeeg"
7
+ version = "0.1.0"
8
+ description = "Minimal Python implementation of lmeEEG for random-intercept mass-univariate M/EEG analysis"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = {text = "MIT"}
12
+ authors = [{name = "Hiro YAMASAKI", email = "hiroyoshi.YAMASAKI@univ-amu.fr"}]
13
+ dependencies = [
14
+ "numpy>=1.24",
15
+ "pandas>=2.0",
16
+ "statsmodels>=0.14",
17
+ "patsy>=0.5",
18
+ ]
19
+
20
+ [project.optional-dependencies]
21
+ mne = ["mne>=1.5"]
22
+ dev = [
23
+ "pytest>=8.0",
24
+ "pytest-cov>=4.0",
25
+ ]
26
+
27
+ [tool.setuptools]
28
+ package-dir = {"" = "src"}
29
+
30
+ [tool.setuptools.packages.find]
31
+ where = ["src"]
lmeeeg-0.1.0/setup.cfg ADDED
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,12 @@
1
+ """Top-level package for LmeEEG."""
2
+
3
+ from lmeeeg.api.fit import fit_lmm_mass_univariate
4
+ from lmeeeg.api.infer import permute_fixed_effect
5
+ from lmeeeg.api.simulate import simulate_erp_random_intercept_dataset, simulate_random_intercept_dataset
6
+
7
+ __all__ = [
8
+ "fit_lmm_mass_univariate",
9
+ "permute_fixed_effect",
10
+ "simulate_erp_random_intercept_dataset",
11
+ "simulate_random_intercept_dataset",
12
+ ]
@@ -0,0 +1,14 @@
1
+ """Public API layer."""
2
+
3
+ from lmeeeg.api.fit import FitConfig, fit_lmm_mass_univariate
4
+ from lmeeeg.api.infer import PermutationConfig, permute_fixed_effect
5
+ from lmeeeg.api.simulate import simulate_erp_random_intercept_dataset, simulate_random_intercept_dataset
6
+
7
+ __all__ = [
8
+ "FitConfig",
9
+ "PermutationConfig",
10
+ "fit_lmm_mass_univariate",
11
+ "permute_fixed_effect",
12
+ "simulate_erp_random_intercept_dataset",
13
+ "simulate_random_intercept_dataset",
14
+ ]
@@ -0,0 +1,131 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ from lmeeeg.backends.lmm.statsmodels_backend import StatsModelsLMMBackend
10
+ from lmeeeg.backends.ols.numpy_backend import NumPyOLSBackend
11
+ from lmeeeg.core.design import build_design_spec
12
+ from lmeeeg.core.marginal import compute_marginal_eeg
13
+ from lmeeeg.core.results import ConvergenceSummary, FitResult
14
+
15
+
16
+ @dataclass(slots=True)
17
+ class FitConfig:
18
+ """Configuration for the public fit API.
19
+
20
+ Parameters
21
+ ----------
22
+ lmm_backend_name : str
23
+ Name of the LMM backend.
24
+ ols_backend_name : str
25
+ Name of the OLS backend.
26
+ """
27
+
28
+ lmm_backend_name: str = "statsmodels"
29
+ ols_backend_name: str = "numpy"
30
+
31
+
32
+ # ==============================
33
+ # Public fit entry point
34
+ # ==============================
35
+
36
def fit_lmm_mass_univariate(
    eeg: np.ndarray,
    metadata: pd.DataFrame,
    formula: str,
    variable_types: dict[str, str],
    fit_intercept: bool = True,
    config: FitConfig | None = None,
) -> FitResult:
    """Fit the minimal lmeEEG pipeline.

    Steps: build the design from ``formula``; fit the random-intercept mixed
    model with the LMM backend; subtract the fitted random effects from
    ``eeg`` to obtain the marginal EEG; run mass-univariate OLS on the
    marginal EEG with the fixed-effect design.

    Parameters
    ----------
    eeg : np.ndarray
        EEG data with shape ``(n_observations, n_channels, n_times)``. Any
        array-like is accepted and converted with :func:`numpy.asarray`.
    metadata : pd.DataFrame
        Observation-level metadata. One row per EEG observation, in the same
        order as the first axis of ``eeg``.
    formula : str
        Mixed-model style formula, e.g. ``"y ~ condition + latency + (1|subject)"``.
        The response variable must be ``y`` and is treated as symbolic only.
    variable_types : dict[str, str]
        Explicit variable typing map. Allowed values are ``categorical``,
        ``numeric``, and ``group``. A shallow copy is used and stored on the
        result, so later changes to the caller's dict do not affect it.
    fit_intercept : bool
        Whether to include a fixed intercept.
    config : FitConfig | None
        Backend configuration; ``FitConfig()`` is used when ``None``.

    Returns
    -------
    FitResult
        Result object containing design information, convergence diagnostics,
        marginal EEG, and OLS summary statistics.

    Raises
    ------
    ValueError
        If an unsupported backend is requested, if ``eeg`` is not 3-D, or if
        the number of EEG observations differs from the number of metadata
        rows. Errors raised while building the design propagate unchanged.

    Usage example
    -------------
    fit_result = fit_lmm_mass_univariate(
        eeg=eeg,
        metadata=metadata,
        formula="y ~ condition + latency + (1|subject)",
        variable_types={
            "condition": "categorical",
            "latency": "numeric",
            "subject": "group",
        },
    )
    """
    config = config or FitConfig()

    # Cheap preconditions are checked before any (potentially slow) model work.
    if config.lmm_backend_name != "statsmodels":
        raise ValueError(f"Unsupported LMM backend: {config.lmm_backend_name}")
    if config.ols_backend_name != "numpy":
        raise ValueError(f"Unsupported OLS backend: {config.ols_backend_name}")

    eeg = np.asarray(eeg)
    if eeg.ndim != 3:
        raise ValueError(
            "eeg must have shape (n_observations, n_channels, n_times); "
            f"got an array with shape {eeg.shape}"
        )
    if eeg.shape[0] != len(metadata):
        raise ValueError(
            f"eeg has {eeg.shape[0]} observations but metadata has {len(metadata)} rows; "
            "expected one metadata row per EEG observation"
        )

    # Snapshot the typing map so the returned FitResult is not aliased to a
    # dict the caller may mutate later.
    variable_types = dict(variable_types)

    design_spec = build_design_spec(
        metadata=metadata,
        formula=formula,
        variable_types=variable_types,
        fit_intercept=fit_intercept,
    )

    lmm_backend = StatsModelsLMMBackend()
    lmm_result = lmm_backend.fit_mass_univariate(
        eeg=eeg,
        metadata=metadata,
        design_spec=design_spec,
    )

    # Remove the fitted random effects to obtain the marginal EEG.
    marginal_eeg = compute_marginal_eeg(eeg=eeg, fitted_random_effects=lmm_result.fitted_random_effects)

    ols_backend = NumPyOLSBackend()
    ols_result = ols_backend.fit_mass_univariate(
        eeg=marginal_eeg,
        design_matrix=design_spec.fixed_design_matrix,
        column_names=design_spec.fixed_column_names,
    )

    convergence_summary = ConvergenceSummary.from_feature_table(lmm_result.feature_diagnostics)

    return FitResult(
        formula=formula,
        variable_types=variable_types,
        design_spec=design_spec,
        fixed_effects_maps=lmm_result.fixed_effects_maps,
        random_effect_variance_map=lmm_result.random_effect_variance_map,
        residual_variance_map=lmm_result.residual_variance_map,
        fitted_random_effects=lmm_result.fitted_random_effects,
        feature_diagnostics=lmm_result.feature_diagnostics,
        convergence_summary=convergence_summary,
        marginal_eeg=marginal_eeg,
        ols_betas=ols_result.beta_maps,
        ols_t_values=ols_result.t_value_maps,
        ols_residual_variance=ols_result.residual_variance_map,
        backend_metadata={
            "lmm_backend": config.lmm_backend_name,
            "ols_backend": config.ols_backend_name,
        },
    )
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ from lmeeeg.backends.correction.maxstat_backend import MaxStatCorrectionBackend
6
+ from lmeeeg.backends.correction.mne_cluster_backend import MNEClusterCorrectionBackend
7
+ from lmeeeg.backends.correction.mne_tfce_backend import MNETFCorrectionBackend
8
+ from lmeeeg.core.results import FitResult, InferenceResult
9
+
10
+
11
+ @dataclass(slots=True)
12
+ class PermutationConfig:
13
+ """Configuration for permutation inference."""
14
+
15
+ n_permutations: int = 1000
16
+ seed: int = 0
17
+ tail: int = 0
18
+
19
+
20
+ # ==============================
21
+ # Public inference entry point
22
+ # ==============================
23
+
24
def permute_fixed_effect(
    fit_result: FitResult,
    effect: str,
    correction: str = "cluster",
    n_permutations: int = 1000,
    seed: int = 0,
    tail: int = 0,
    threshold: float | dict[str, float] | None = None,
    adjacency=None,
) -> InferenceResult:
    """Run permutation-based inference for one fixed effect.

    Parameters
    ----------
    fit_result : FitResult
        Result returned by :func:`fit_lmm_mass_univariate`.
    effect : str
        Exact fixed-effect column name to test.
    correction : str
        Correction backend: ``maxstat``, ``cluster``, or ``tfce``. The
        ``cluster`` and ``tfce`` backends require MNE-Python.
    n_permutations : int
        Number of permutations. Integer values must be at least 1.
    seed : int
        Random seed.
    tail : int
        Tail for MNE-compatible permutation code. Use 0 for two-sided,
        1 for positive, -1 for negative.
    threshold : float | dict[str, float] | None
        Cluster threshold or TFCE threshold dictionary.
    adjacency : Any
        Optional adjacency matrix passed through to MNE correction backends.

    Returns
    -------
    InferenceResult
        Corrected inference output.

    Raises
    ------
    ValueError
        If ``effect`` is not a fixed-effect column, ``tail`` is not -1, 0 or 1,
        an integer ``n_permutations`` is smaller than 1, or ``correction`` is
        not a supported backend.

    Usage example
    -------------
    inference = permute_fixed_effect(
        fit_result=fit_result,
        effect="condition[T.B]",
        correction="tfce",
        n_permutations=500,
        seed=1,
    )
    """
    column_names = fit_result.design_spec.fixed_column_names
    if effect not in column_names:
        available = ", ".join(column_names)
        raise ValueError(f"Unknown effect '{effect}'. Available fixed effects: {available}")

    # Validate here so every backend sees the same contract; MNE accepts only
    # these three values, and other backends must not silently ignore a bad one.
    if tail not in (-1, 0, 1):
        raise ValueError(f"tail must be -1, 0 or 1; got {tail!r}")
    # Only integer counts are checked; any non-integer value is left for the
    # backend to interpret or reject.
    if isinstance(n_permutations, int) and n_permutations < 1:
        raise ValueError(f"n_permutations must be at least 1; got {n_permutations!r}")

    backend_classes = {
        "maxstat": MaxStatCorrectionBackend,
        "cluster": MNEClusterCorrectionBackend,
        "tfce": MNETFCorrectionBackend,
    }
    backend_class = backend_classes.get(correction)
    if backend_class is None:
        raise ValueError(f"Unsupported correction backend: {correction}")

    return backend_class().run(
        fit_result=fit_result,
        effect=effect,
        n_permutations=n_permutations,
        seed=seed,
        tail=tail,
        threshold=threshold,
        adjacency=adjacency,
    )
@@ -0,0 +1,23 @@
1
+ """Public simulation helpers."""
2
+
3
+ from lmeeeg.simulation.generator import (
4
+ ERPComponentSpec,
5
+ ERPSimulationConfig,
6
+ ERPSimulationMetadata,
7
+ ERPSimulationResult,
8
+ SimulatedDataset,
9
+ simulate_erp_random_intercept_dataset,
10
+ simulate_random_intercept_dataset,
11
+ )
12
+ from lmeeeg.simulation.scenarios import build_default_erp_component_specs
13
+
14
+ __all__ = [
15
+ "ERPComponentSpec",
16
+ "ERPSimulationConfig",
17
+ "ERPSimulationMetadata",
18
+ "ERPSimulationResult",
19
+ "SimulatedDataset",
20
+ "build_default_erp_component_specs",
21
+ "simulate_erp_random_intercept_dataset",
22
+ "simulate_random_intercept_dataset",
23
+ ]
@@ -0,0 +1 @@
1
+ """Backend interfaces and implementations."""
@@ -0,0 +1 @@
1
+ """Correction backends."""
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any
5
+
6
+ from lmeeeg.core.results import FitResult, InferenceResult
7
+
8
+
9
class BaseCorrectionBackend(ABC):
    """Interface shared by every multiple-comparison correction backend.

    Subclasses must implement :meth:`run`; because ``run`` is abstract, the
    base class itself cannot be instantiated.
    """

    @abstractmethod
    def run(
        self,
        fit_result: FitResult,
        effect: str,
        n_permutations: int,
        seed: int,
        tail: int,
        threshold: float | dict[str, float] | None,
        adjacency: Any,
    ) -> InferenceResult:
        """Apply the correction to one fixed effect of ``fit_result``.

        Parameters
        ----------
        fit_result : FitResult
            Output of the fit pipeline.
        effect : str
            Fixed-effect column name to test.
        n_permutations : int
            Number of permutations.
        seed : int
            Random seed.
        tail : int
            0 two-sided, 1 positive, -1 negative.
        threshold : float | dict[str, float] | None
            Cluster threshold or TFCE threshold dictionary, if used.
        adjacency : Any
            Optional adjacency structure, if used by the backend.

        Returns
        -------
        InferenceResult
            Corrected inference output.
        """
@@ -0,0 +1,71 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ from lmeeeg.backends.correction.base import BaseCorrectionBackend
6
+ from lmeeeg.core.results import FitResult, InferenceResult
7
+
8
+
9
+ # ==============================
10
+ # Max-stat correction backend
11
+ # ==============================
12
+
13
class MaxStatCorrectionBackend(BaseCorrectionBackend):
    """Permutation max-statistic backend on OLS t maps."""

    @staticmethod
    def _tail_oriented(t_values: np.ndarray, tail: int) -> np.ndarray:
        """Return t-values oriented so that larger always means more extreme.

        ``tail=0`` uses ``|t|``, ``tail=1`` uses ``t`` and ``tail=-1`` uses ``-t``.
        """
        if tail == 0:
            return np.abs(t_values)
        return t_values if tail == 1 else -t_values

    def run(
        self,
        fit_result: FitResult,
        effect: str,
        n_permutations: int,
        seed: int,
        tail: int,
        threshold: float | dict[str, float] | None,
        adjacency,
    ) -> InferenceResult:
        """Run max-statistic correction.

        Parameters
        ----------
        fit_result : FitResult
            Fitted pipeline; ``marginal_eeg``, ``ols_t_values`` and the fixed
            design matrix / column names are read from it.
        effect : str
            Fixed-effect column name to test.
        n_permutations : int
            Number of row permutations of the design matrix.
        seed : int
            Seed for :func:`numpy.random.default_rng`.
        tail : int
            ``0`` two-sided (max ``|t|``), ``1`` upper-tailed (max ``t``),
            ``-1`` lower-tailed (max ``-t``).
        threshold, adjacency
            Ignored by this backend; accepted for interface compatibility.

        Returns
        -------
        InferenceResult
            ``null_distribution`` holds, per permutation, the maximum of the
            tail-oriented statistic (``|t|``, ``t`` or ``-t``) over features.
            Corrected p-values are ``(1 + #{null >= observed}) / (n_permutations + 1)``
            on the same oriented scale. Features whose observed t is NaN get a
            NaN p-value.

        Raises
        ------
        ValueError
            If ``tail`` is not -1, 0 or 1.

        Notes
        -----
        This backend uses row shuffling of the design matrix as a simple MVP
        permutation scheme on marginalized data. It is intentionally explicit
        and easy to inspect.
        """
        del threshold, adjacency
        if tail not in (-1, 0, 1):
            raise ValueError(f"tail must be -1, 0 or 1; got {tail!r}")

        rng = np.random.default_rng(seed)
        observed_t = fit_result.ols_t_values[effect]
        x_matrix = fit_result.design_spec.fixed_design_matrix
        y = fit_result.marginal_eeg
        n_observations, n_channels, n_times = y.shape
        y_2d = y.reshape(n_observations, n_channels * n_times)

        effect_index = fit_result.design_spec.fixed_column_names.index(effect)
        degrees_of_freedom = n_observations - x_matrix.shape[1]

        # X_perm^T X_perm == X^T X for any row permutation, so the inverse is
        # loop-invariant and is computed once instead of once per permutation.
        xtx_inv = np.linalg.inv(x_matrix.T @ x_matrix)
        effect_variance_factor = xtx_inv[effect_index, effect_index]

        null_distribution = np.zeros(n_permutations, dtype=float)
        for permutation_index in range(n_permutations):
            x_perm = x_matrix[rng.permutation(n_observations), :]
            beta = xtx_inv @ x_perm.T @ y_2d
            residuals = y_2d - x_perm @ beta
            residual_variance = np.sum(residuals ** 2, axis=0) / degrees_of_freedom
            # Zero-variance features give 0/0 = NaN (or x/0 = inf); NaNs are
            # excluded from the maximum below, so silence the division warnings.
            with np.errstate(divide="ignore", invalid="ignore"):
                t_values = beta[effect_index, :] / np.sqrt(residual_variance * effect_variance_factor)
            oriented = self._tail_oriented(t_values, tail)
            # A NaN maximum never compares >= to anything and would make the
            # corrected p-values too small; ignore NaN features instead.
            null_distribution[permutation_index] = np.max(
                oriented, initial=-np.inf, where=~np.isnan(oriented)
            )

        observed_oriented = self._tail_oriented(np.asarray(observed_t, dtype=float), tail)
        # #{null >= v} = n_permutations - #{null < v}. Searching the sorted null
        # avoids materializing an (n_permutations, n_channels, n_times) array.
        sorted_null = np.sort(null_distribution)
        n_exceeding = n_permutations - np.searchsorted(sorted_null, observed_oriented, side="left")
        corrected_p_values = (1 + n_exceeding) / (n_permutations + 1)
        # An undefined observed statistic has no meaningful p-value.
        corrected_p_values = np.where(np.isnan(observed_oriented), np.nan, corrected_p_values)

        return InferenceResult(
            effect=effect,
            correction="maxstat",
            observed_statistic=observed_t,
            corrected_p_values=corrected_p_values,
            null_distribution=null_distribution,
            clusters=None,
            cluster_p_values=None,
            backend_metadata={
                "backend": "maxstat",
                "n_permutations": n_permutations,
                "permutation_scheme": "row_shuffle_on_marginal_design",
                "tail": tail,
            },
        )
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+
7
+ from lmeeeg.backends.correction.base import BaseCorrectionBackend
8
+ from lmeeeg.core.results import FitResult, InferenceResult
9
+
10
+
11
+ # ==============================
12
+ # MNE cluster correction backend
13
+ # ==============================
14
+
15
class MNEClusterCorrectionBackend(BaseCorrectionBackend):
    """Cluster-based permutation correction using MNE-Python."""

    def run(
        self,
        fit_result: FitResult,
        effect: str,
        n_permutations: int,
        seed: int,
        tail: int,
        threshold: float | dict[str, float] | None,
        adjacency,
    ) -> InferenceResult:
        """Run cluster-based permutation correction with MNE.

        The selected effect is tested on the observation-level marginal EEG
        with a sign-flipping score test. This keeps the correction layer
        separate from the LMM layer.

        1. Regress the marginal EEG on the reduced design (every fixed column
           except ``effect``) and keep the residuals ``r``.
        2. Residualize the effect column on the same reduced design, giving
           ``x_e*`` (Frisch-Waugh-Lovell).
        3. Form per-observation score contributions ``z_i = x_e*_i * r_i``.
           Their sum equals ``x_e*^T y``, which is the OLS estimate of the
           effect times ``x_e*^T x_e*`` (> 0). A one-sample test of
           ``mean(z) != 0`` therefore tests the effect with the other fixed
           columns kept in the model. A positive ``tail`` corresponds to a
           positive effect coefficient.
        4. Pass ``z`` to :func:`mne.stats.permutation_cluster_1samp_test`,
           whose sign flips provide the permutation null.

        Residuals alone must not be tested this way: whenever the reduced
        design contains an intercept, they have exactly zero mean at every
        feature, so a one-sample test on them cannot detect the effect.

        Parameters
        ----------
        fit_result : FitResult
            Fitted pipeline; ``marginal_eeg`` and the fixed design are read.
        effect : str
            Fixed-effect column name to test.
        n_permutations, seed, tail, adjacency
            Passed through to MNE.
        threshold : float | dict[str, float] | None
            Cluster-forming threshold passed to MNE; defaults to ``2.0``.

        Returns
        -------
        InferenceResult
            ``observed_statistic`` is MNE's one-sample statistic on the score
            contributions, shaped ``(n_channels, n_times)``. A voxel's
            ``corrected_p_values`` entry is the p-value of the cluster it
            belongs to, or 1.0 outside all clusters.

        Raises
        ------
        ImportError
            If MNE-Python cannot be imported.
        ValueError
            If the effect column is (numerically) collinear with the other
            fixed-effect columns, which leaves nothing to test.
        """
        # Some environments expose MNE through a Numba caching path that is not
        # available at runtime. Disabling JIT here keeps the optional backend
        # usable without affecting the public API. Note: these are
        # process-wide environment defaults.
        os.environ.setdefault("NUMBA_DISABLE_JIT", "1")
        os.environ.setdefault("MNE_DONTWRITE_HOME", "true")
        try:
            from mne.stats import permutation_cluster_1samp_test
        except Exception as error:  # pragma: no cover
            raise ImportError("MNE-Python is required for cluster correction.") from error

        x_matrix = fit_result.design_spec.fixed_design_matrix
        column_names = fit_result.design_spec.fixed_column_names
        effect_index = column_names.index(effect)
        reduced_columns = [index for index in range(len(column_names)) if index != effect_index]

        y = fit_result.marginal_eeg
        n_observations, n_channels, n_times = y.shape
        y_2d = y.reshape(n_observations, n_channels * n_times)
        effect_column = x_matrix[:, effect_index]

        if reduced_columns:
            x_reduced = x_matrix[:, reduced_columns]
            reduced_pinv = np.linalg.pinv(x_reduced)
            residuals = y_2d - x_reduced @ (reduced_pinv @ y_2d)
            effect_residualized = effect_column - x_reduced @ (reduced_pinv @ effect_column)
        else:
            residuals = y_2d
            effect_residualized = np.asarray(effect_column, dtype=float)

        # A vanishing residualized effect means the effect is explained by the
        # other columns; the score contributions would all be ~0 and MNE's
        # t-statistic would be undefined.
        if np.linalg.norm(effect_residualized) <= 1e-10 * np.linalg.norm(effect_column):
            raise ValueError(
                f"Effect '{effect}' is collinear with the other fixed-effect columns "
                "and cannot be tested."
            )

        score_contributions = residuals * effect_residualized[:, None]
        contributions_3d = score_contributions.reshape(n_observations, n_channels, n_times)
        # MNE expects (n_observations, n_times, n_vertices) with adjacency
        # defined over the last (channel) axis.
        data_for_mne = np.transpose(contributions_3d, (0, 2, 1))
        cluster_threshold = threshold if threshold is not None else 2.0

        observed_t, clusters, cluster_p_values, null_distribution = permutation_cluster_1samp_test(
            X=data_for_mne,
            threshold=cluster_threshold,
            n_permutations=n_permutations,
            tail=tail,
            adjacency=adjacency,
            out_type="mask",
            seed=seed,
            verbose=False,
        )
        # Back to (n_channels, n_times) to match the rest of the package.
        observed_t = observed_t.T
        corrected_p_values = np.ones_like(observed_t, dtype=float)
        for cluster_mask, cluster_p_value in zip(clusters, cluster_p_values):
            corrected_p_values[cluster_mask.T] = np.minimum(corrected_p_values[cluster_mask.T], cluster_p_value)

        return InferenceResult(
            effect=effect,
            correction="cluster",
            observed_statistic=observed_t,
            corrected_p_values=corrected_p_values,
            null_distribution=np.asarray(null_distribution),
            clusters=clusters,
            cluster_p_values=np.asarray(cluster_p_values),
            backend_metadata={
                "backend": "mne_cluster",
                "n_permutations": n_permutations,
                "threshold": cluster_threshold,
                "permutation_scheme": "sign_flip_score_contributions",
            },
        )