scatrans 0.7.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scatrans/__init__.py +56 -0
- scatrans/_bias.py +24 -0
- scatrans/_de.py +555 -0
- scatrans/_permutation.py +229 -0
- scatrans/_utils.py +343 -0
- scatrans/_velocity.py +152 -0
- scatrans/_version.py +24 -0
- scatrans/data/Hs_KEGG_2026.txt +223 -0
- scatrans/data/Mm_KEGG_2026.txt +219 -0
- scatrans/data/Mus_musculus.GRCm39.115_gene_features.parquet +0 -0
- scatrans/data/README.md +94 -0
- scatrans/data/mouse_2020A_gene_features.parquet +0 -0
- scatrans/enrich.py +735 -0
- scatrans/generate_gene_features.py +99 -0
- scatrans/pl.py +1168 -0
- scatrans/pp_bias.py +261 -0
- scatrans/qc.py +41 -0
- scatrans/tl.py +1618 -0
- scatrans-0.7.0.dev0.dist-info/METADATA +826 -0
- scatrans-0.7.0.dev0.dist-info/RECORD +24 -0
- scatrans-0.7.0.dev0.dist-info/WHEEL +5 -0
- scatrans-0.7.0.dev0.dist-info/entry_points.txt +2 -0
- scatrans-0.7.0.dev0.dist-info/licenses/LICENSE +21 -0
- scatrans-0.7.0.dev0.dist-info/top_level.txt +1 -0
scatrans/__init__.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""
|
|
2
|
+
scATrans public API.
|
|
3
|
+
|
|
4
|
+
Recommended usage:
|
|
5
|
+
import scatrans as scat
|
|
6
|
+
scat.active_score(...)
|
|
7
|
+
scat.add_gene_features(...)
|
|
8
|
+
scat.pl.set_style()
|
|
9
|
+
scat.run_enrichment(...)
|
|
10
|
+
|
|
11
|
+
Submodules `pl` and `qc` are intentionally exposed (scanpy-style convention).
|
|
12
|
+
Other internal modules are not part of the stable public surface.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from . import pl, qc
|
|
16
|
+
from .enrich import list_bundled_gene_sets, run_enrichment, run_kegg, simplify_enrichment
|
|
17
|
+
from .generate_gene_features import main as generate_gene_features_main
|
|
18
|
+
from .pp_bias import add_gene_features, list_available_gene_features
|
|
19
|
+
from .tl import (
|
|
20
|
+
active_score,
|
|
21
|
+
diagnose_design,
|
|
22
|
+
differential_expression,
|
|
23
|
+
filter_active_genes,
|
|
24
|
+
restore_raw_counts,
|
|
25
|
+
store_raw_counts,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Names exported by ``from scatrans import *``.
__all__ = [
    # Re-exported from scatrans.tl
    "active_score",
    "differential_expression",
    "diagnose_design",
    "filter_active_genes",
    "restore_raw_counts",
    "store_raw_counts",
    # Re-exported from scatrans.pp_bias
    "add_gene_features",
    "list_available_gene_features",
    # Re-exported from scatrans.enrich
    "run_enrichment",
    "run_kegg",
    "simplify_enrichment",
    "list_bundled_gene_sets",
    # Submodules exposed as attributes
    "pl",
    "qc",
    # Entry point of scatrans.generate_gene_features
    "generate_gene_features_main",
    "__version__",
]

# Version is provided dynamically when possible: use the value recorded in the
# package's ``_version`` module, and fall back to a hard-coded string when that
# module (or its ``version`` attribute) cannot be imported.
try:
    from ._version import version as __version__
except ImportError:
    __version__ = "0.7.0.dev0"
|
|
52
|
+
|
|
53
|
+
# Optional: prevent some internal modules from appearing too prominently
|
|
54
|
+
# in casual inspection while still allowing advanced users to do
|
|
55
|
+
# `import scatrans.tl as tl` if they really need it.
|
|
56
|
+
# We do not delete them aggressively to preserve scanpy-like ergonomics.
|
scatrans/_bias.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Bias correction (Huber regression on gene length + intron number).
|
|
3
|
+
|
|
4
|
+
The actual implementation lives in _utils._fit_huber_bias_correction so it can be
|
|
5
|
+
used from both the main analysis path and from permutation tasks without duplication.
|
|
6
|
+
|
|
7
|
+
Enhanced return: (residual, bias_info_dict) with fit diagnostics for transparency.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from ._utils import _fit_huber_bias_correction as _raw_fit
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def fit_huber_bias_correction(*args, **kwargs) -> tuple[np.ndarray, dict[str, Any]]:
    """Forward every argument to ``_utils._fit_huber_bias_correction``.

    The positional and keyword arguments are passed through untouched and the
    callee's result is handed back unchanged; per the return annotation this is
    a ``(residual, bias_info)`` pair.
    """
    residual_and_info = _raw_fit(*args, **kwargs)
    return residual_and_info
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
__all__ = ["fit_huber_bias_correction"]
|
scatrans/_de.py
ADDED
|
@@ -0,0 +1,555 @@
|
|
|
1
|
+
"""
|
|
2
|
+
scATrans internal differential expression helpers.
|
|
3
|
+
|
|
4
|
+
Contains the wrapper that supports both scanpy rank_genes_groups and PyDESeq2
|
|
5
|
+
for pseudobulk. Extracted so tl.py stays small.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import warnings
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import anndata as ad
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pandas as pd
|
|
17
|
+
import scanpy as sc
|
|
18
|
+
from joblib import Parallel, delayed
|
|
19
|
+
from scipy import sparse
|
|
20
|
+
from scipy.sparse import csr_matrix
|
|
21
|
+
from scipy.stats import chi2
|
|
22
|
+
from statsmodels.stats.multitest import multipletests
|
|
23
|
+
|
|
24
|
+
from ._utils import (
|
|
25
|
+
_is_integer_counts_like,
|
|
26
|
+
_warn_if_low_counts_matrix,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _run_de_wrapper(
    adata: ad.AnnData,
    groupby: str,
    target_group: str,
    reference_group: str,
    de_method: str = "t-test_overestim_var",
    is_pseudobulk: bool = False,
    pb_backend: str = "pydeseq2",
    n_jobs: int = 1,
    labels: Any | None = None,
    strict_pydeseq2_counts: bool = True,
    use_mixed_model: bool = False,
    sample_col: str | None = None,
    mixed_model_pval: str = "wald",
    # Memento (Cell 2024 method-of-moments) as independent cell-level DE backend
    use_memento_de: bool = False,
    memento_capture_rate: float = 0.07,
    memento_num_boot: int = 5000,
    memento_n_cpus: int = -1,
    # Allow providing raw counts separately (common when adata.X is already HVG + log1p)
    counts: str | np.ndarray | sparse.spmatrix | pd.DataFrame | ad.AnnData | None = None,
) -> pd.DataFrame:
    """Run a target-vs-reference DE test with the selected backend.

    Backend selection, in priority order:

    1. ``use_mixed_model=True``  -> ``_run_mixedlm_de`` (requires ``sample_col``).
    2. ``use_memento_de=True``   -> ``_run_memento_de`` (rejected for pseudobulk).
    3. ``is_pseudobulk`` and ``pb_backend == "pydeseq2"`` -> PyDESeq2 on ``adata.X``.
    4. otherwise                 -> ``scanpy.tl.rank_genes_groups`` with ``de_method``.

    Parameters
    ----------
    adata
        Data to test. When ``labels`` is given a copy is made and the labels are
        written to a temporary obs column; otherwise ``adata`` is used directly
        (the scanpy path then stores its result under
        ``adata.uns["_scatrans_rank_genes_groups"]``).
    groupby
        Obs column holding the group assignment (ignored when ``labels`` is given).
    target_group, reference_group
        The two levels being contrasted (target minus reference).
    labels
        Optional per-observation group labels overriding ``groupby``.
    strict_pydeseq2_counts
        If True, raise when ``adata.X`` fails ``_is_integer_counts_like``; if
        False only log a warning and round/clip the values.
    n_jobs
        Forwarded to the mixed model and to PyDESeq2 (as ``n_cpus``).
    counts, memento_*
        Forwarded to the Memento backend only.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``var_names`` with columns ``logFC``, ``p_val`` and ``p_adj``
        (the mixed-model and Memento backends add extra columns). Genes without
        a result get ``logFC=0`` and p-values of 1.

    Raises
    ------
    ValueError
        Missing ``sample_col`` for the mixed model, Memento requested for
        pseudobulk, fewer than two replicates per group for PyDESeq2,
        non-count data in strict mode, or no gene passing the count filter.
    ImportError
        PyDESeq2 requested but not installed.
    """
    if use_mixed_model:
        if sample_col is None:
            raise ValueError("sample_col must be provided when use_mixed_model=True")
        return _run_mixedlm_de(
            adata,
            groupby=groupby,
            target_group=target_group,
            reference_group=reference_group,
            sample_col=sample_col,
            n_jobs=n_jobs,
            labels=labels,
            mixed_model_pval=mixed_model_pval,
        )

    if use_memento_de:
        if is_pseudobulk:
            raise ValueError(
                "use_memento_de=True is not supported with use_pseudobulk=True "
                "(Memento is a cell-level method-of-moments estimator; use PyDESeq2 for pseudobulk)."
            )
        return _run_memento_de(
            adata,
            groupby=groupby,
            target_group=target_group,
            reference_group=reference_group,
            labels=labels,
            capture_rate=memento_capture_rate,
            num_boot=memento_num_boot,
            n_cpus=memento_n_cpus,
            counts=counts,
        )

    # Only copy when we have to write a temporary grouping column.
    ad_temp = adata.copy() if labels is not None else adata
    use_groupby = groupby

    if labels is not None:
        use_groupby = "_de_temp_group"
        ad_temp.obs[use_groupby] = pd.Categorical(
            np.asarray(labels).astype(str), categories=[reference_group, target_group]
        )

    if is_pseudobulk and pb_backend == "pydeseq2":
        try:
            from pydeseq2.dds import DeseqDataSet
            from pydeseq2.ds import DeseqStats
        except ImportError as e:
            raise ImportError(
                "pydeseq2 is required when pseudobulk_de_backend='pydeseq2'. "
                "Install with: pip install pydeseq2 or 'scatrans[pseudobulk]'"
            ) from e

        # Compare on the string view of the grouping column so the replicate
        # counts agree with the (string-based) design column built below.
        group_values = ad_temp.obs[use_groupby].astype(str).to_numpy()
        is_target = group_values == target_group
        is_reference = group_values == reference_group
        n_t = int(is_target.sum())
        n_r = int(is_reference.sum())
        if n_t < 2 or n_r < 2:
            raise ValueError(
                f"PyDESeq2 requires >=2 replicates per group. Found {n_t} target, {n_r} ref."
            )

        # Keep only samples that belong to the contrast: any other level would
        # turn into a missing value in the two-category design column.
        in_contrast = is_target | is_reference
        if in_contrast.all():
            X_all = ad_temp.X
            sample_names = ad_temp.obs_names
        else:
            X_all = ad_temp.X[in_contrast]
            sample_names = ad_temp.obs_names[in_contrast]
            group_values = group_values[in_contrast]

        is_count_like = _is_integer_counts_like(X_all)

        if not is_count_like:
            msg = (
                "Data passed to PyDESeq2 does not look like raw non-negative integer counts. "
                "PyDESeq2 requires unnormalized integer counts in adata.X. "
                "If you intentionally want to allow rounding, set strict_pydeseq2_counts=False."
            )
            if strict_pydeseq2_counts:
                raise ValueError(msg)
            logger.warning(msg)
        else:
            _warn_if_low_counts_matrix(X_all)

        if sparse.issparse(X_all):
            # Filter genes on the sparse matrix first so only kept genes are densified.
            gene_sums = np.asarray(X_all.sum(axis=0)).ravel()
            gene_keep = gene_sums >= 10
            if gene_keep.sum() == 0:
                raise ValueError("No genes passed the DESeq2 count filter (sum(counts) >= 10).")
            X_filtered = X_all[:, gene_keep].toarray()
            X_filtered = np.clip(np.round(np.nan_to_num(X_filtered)), 0, None).astype(int)
            counts_use = pd.DataFrame(
                X_filtered, index=sample_names, columns=ad_temp.var_names[gene_keep]
            )
        else:
            X = np.asarray(X_all)
            X = np.clip(np.round(np.nan_to_num(X)), 0, None).astype(int)
            counts_df = pd.DataFrame(X, index=sample_names, columns=ad_temp.var_names)
            gene_keep = counts_df.sum(axis=0) >= 10
            counts_use = counts_df.loc[:, gene_keep].copy()

        if counts_use.shape[1] == 0:
            raise ValueError("No genes passed the DESeq2 count filter (sum(counts) >= 10).")

        metadata = pd.DataFrame(
            {
                use_groupby: pd.Categorical(
                    group_values, categories=[reference_group, target_group]
                )
            },
            index=counts_use.index,
        )

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Two constructor signatures are tried; a TypeError from the first
            # (design_factors/ref_level) triggers the formula-based variant.
            try:
                dds = DeseqDataSet(
                    counts=counts_use,
                    metadata=metadata,
                    design_factors=use_groupby,
                    ref_level=[use_groupby, reference_group],
                    quiet=True,
                    n_cpus=n_jobs,
                )
            except TypeError:
                dds = DeseqDataSet(
                    counts=counts_use,
                    metadata=metadata,
                    design=f"~{use_groupby}",
                    refit_cooks=True,
                    quiet=True,
                    n_cpus=n_jobs,
                )
            dds.deseq2()

            # Same idea for DeseqStats: retry without ``quiet`` on TypeError.
            try:
                stat_res = DeseqStats(
                    dds,
                    contrast=[use_groupby, target_group, reference_group],
                    quiet=True,
                    n_cpus=n_jobs,
                )
            except TypeError:
                stat_res = DeseqStats(
                    dds,
                    contrast=[use_groupby, target_group, reference_group],
                    n_cpus=n_jobs,
                )
            stat_res.summary()

        # Genes removed by the count filter come back as NaN and are neutralised.
        res2 = stat_res.results_df.copy().reindex(ad_temp.var_names)
        de_df = pd.DataFrame(index=ad_temp.var_names)
        de_df["logFC"] = res2["log2FoldChange"].fillna(0.0)
        de_df["p_val"] = res2.get("pvalue", pd.Series(1.0, index=res2.index)).fillna(1.0)
        de_df["p_adj"] = res2.get("padj", pd.Series(1.0, index=res2.index)).fillna(1.0)
        return de_df

    # Standard scanpy path (works for both regular and pseudobulk when not using pydeseq2)
    rank_key = "_scatrans_rank_genes_groups"
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sc.tl.rank_genes_groups(
            ad_temp,
            groupby=use_groupby,
            groups=[target_group],
            reference=reference_group,
            method=de_method,
            key_added=rank_key,
        )
    de_raw = sc.get.rank_genes_groups_df(ad_temp, group=target_group, key=rank_key).set_index(
        "names"
    )
    de_df = pd.DataFrame(index=ad_temp.var_names)
    de_df["logFC"] = de_raw["logfoldchanges"].reindex(ad_temp.var_names).fillna(0.0)
    de_df["p_val"] = de_raw["pvals"].reindex(ad_temp.var_names).fillna(1.0)
    de_df["p_adj"] = de_raw["pvals_adj"].reindex(ad_temp.var_names).fillna(1.0)
    return de_df
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _run_mixedlm_de(
    adata: ad.AnnData,
    groupby: str,
    target_group: str,
    reference_group: str,
    sample_col: str,
    n_jobs: int = 1,
    labels: Any | None = None,
    mixed_model_pval: str = "wald",
) -> pd.DataFrame:
    """
    Mixed linear model (LMM) DE + Delta Variance using statsmodels mixedlm.

    Models: y_log ~ C(condition) + (1 | sample)
    - logFC: coefficient for the target condition relative to ``reference_group``
      (on the log1p scale). The reference level is pinned explicitly, so the sign
      does not depend on the alphabetical order of the group names.
    - p_val / p_adj: Wald p for the condition fixed effect, or the LRT p when
      ``mixed_model_pval="lrt"`` (BH adjusted across genes).
    - delta_variance: fraction of total variance (var_fe + re_var + resid)
      attributable to the fixed condition effect.
    - delta_var_pval: LRT p-value for the contribution of condition
      (full vs reduced ~1 + (1|sample)); degrees of freedom equal the number of
      fixed-effect terms added by the condition.

    Parameters
    ----------
    adata
        Expression data; ``X`` is library-size normalised and log1p-transformed
        on an internal copy before fitting.
    groupby
        Obs column with the condition (ignored when ``labels`` is given).
    target_group, reference_group
        Levels being contrasted.
    sample_col
        Obs column used as the random-intercept grouping.
    n_jobs
        Number of joblib workers; non-positive or missing values fall back to 1.
    labels
        Optional per-observation labels overriding ``groupby``.
    mixed_model_pval
        ``"wald"`` or ``"lrt"``; anything else logs a warning and uses ``"wald"``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``var_names`` with ``logFC``, ``p_val``, ``p_adj``,
        ``delta_variance`` and ``delta_var_pval``. Genes whose fit fails or whose
        expression is constant get the neutral row (0, 1, 1, 0, 1).

    Raises
    ------
    ImportError
        statsmodels' formula API cannot be imported.
    ValueError
        ``sample_col`` is not a column of ``adata.obs``.

    This provides a lightweight Python analogue to variancePartition/dream (fraction of variation explained)
    + LMM DE, suitable for cell-level data with sample-level random effects (addresses pseudoreplication).
    For full voom + precision weights + dreampy/dreamlet on pseudobulk, or NEBULA NB-GLMM, see external packages.
    """
    try:
        import statsmodels.formula.api as smf
    except ImportError as e:
        raise ImportError(
            "statsmodels is required for use_mixed_model=True (it is a core dependency of scatrans)."
        ) from e

    ad_temp = adata.copy() if labels is not None else adata
    use_groupby = groupby
    if labels is not None:
        use_groupby = "_de_temp_group"
        ad_temp.obs[use_groupby] = pd.Categorical(
            np.asarray(labels).astype(str), categories=[reference_group, target_group]
        )

    if sample_col not in ad_temp.obs.columns:
        raise ValueError(f"sample_col='{sample_col}' not found in adata.obs")

    # Prepare expression for LMM: always use log1p library-size normalized (LMM assumes approx Gaussian on log scale)
    # Work on a temp copy to avoid mutating caller adata state for this auxiliary norm
    ad_expr = ad_temp.copy()
    # If very sparse or raw counts, normalize; safe for already-log too (will just re-log)
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sc.pp.normalize_total(ad_expr, target_sum=1e4)
            sc.pp.log1p(ad_expr)
    except Exception:
        # Fallback: manual log1p on X if pp fails (e.g. already transformed or edge data)
        if sparse.issparse(ad_expr.X):
            Xn = np.log1p(ad_expr.X.toarray())
        else:
            Xn = np.log1p(np.asarray(ad_expr.X))
        ad_expr.X = Xn

    expr_mat = ad_expr.X.toarray() if sparse.issparse(ad_expr.X) else np.asarray(ad_expr.X)

    obs = ad_temp.obs
    condition_str = obs[use_groupby].astype(str).values
    samples = obs[sample_col].astype(str).values

    # Pin the baseline: with plain strings the formula machinery picks the
    # alphabetically first level as reference, which flips the sign of the
    # coefficient whenever target_group sorts before reference_group. An ordered
    # Categorical with the reference first makes the contrast target - reference.
    levels = sorted(pd.unique(condition_str).tolist())
    if reference_group in levels:
        levels.remove(reference_group)
        levels.insert(0, reference_group)
    condition = pd.Categorical(condition_str, categories=levels)

    # Exact name of the target coefficient under treatment coding.
    target_param = f"C(condition)[T.{target_group}]"

    n_genes = expr_mat.shape[1]
    var_names = ad_temp.var_names

    # Per-gene worker (returns idx, logfc, wald_p, lrt_p, delta_var)
    def _fit_gene_mixed(idx: int):
        y = expr_mat[:, idx].astype(float)
        # guard against empty / all-zero / constant (mixedlm will be singular)
        if y.size == 0 or np.allclose(y, y[0]):
            return idx, 0.0, 1.0, 1.0, 0.0
        df = pd.DataFrame({"y": y, "condition": condition, "sample": samples})
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # Full model
                md_full = smf.mixedlm("y ~ C(condition)", df, groups=df["sample"])
                m_full = md_full.fit(reml=False, maxiter=200, disp=False)

                # Reduced (null) for LRT on condition contribution
                md_null = smf.mixedlm("y ~ 1", df, groups=df["sample"])
                m_null = md_null.fit(reml=False, maxiter=200, disp=False)

                # LRT statistic and p; df = number of fixed-effect terms the
                # condition adds (1 for a two-level condition).
                lrt_stat = -2.0 * (m_null.llf - m_full.llf)
                lrt_df = max(1, len(m_full.fe_params) - len(m_null.fe_params))
                lrt_p = float(chi2.sf(max(lrt_stat, 0.0), lrt_df))

                # Extract condition coef (target vs ref): exact name first, then
                # the looser substring scan, then the positional fallback.
                coef_name = None
                if target_param in m_full.params.index:
                    coef_name = target_param
                else:
                    for pname in m_full.params.index:
                        if "condition" in str(pname) and target_group in str(pname):
                            coef_name = pname
                            break
                if coef_name is None:
                    # fallback: take the second coef if intercept + one more
                    if len(m_full.params) >= 2:
                        coef_name = m_full.params.index[1]
                    else:
                        coef_name = m_full.params.index[0]
                logfc = float(m_full.params.get(coef_name, 0.0))
                p_wald = float(m_full.pvalues.get(coef_name, 1.0))

                # Delta variance: var attributable to fixed effects / total modeled var
                exog = m_full.model.exog
                beta = np.asarray(m_full.fe_params)
                fe_contrib = exog @ beta
                var_fe = float(np.var(fe_contrib))
                re_var = 0.0
                try:
                    if (
                        hasattr(m_full, "cov_re")
                        and m_full.cov_re is not None
                        and len(m_full.cov_re) > 0
                    ):
                        re_var = float(np.diag(m_full.cov_re)[0])  # first (only) RE variance
                except Exception:
                    re_var = 0.0
                resid_var = float(getattr(m_full, "scale", 0.0))
                total_v = var_fe + max(re_var, 0.0) + max(resid_var, 0.0)
                delta_var = var_fe / total_v if total_v > 1e-12 else 0.0

                return idx, logfc, p_wald, lrt_p, float(np.clip(delta_var, 0.0, 1.0))
        except Exception:
            # Degenerate fit (few samples per group, collinear, etc.) -> non-informative
            return idx, 0.0, 1.0, 1.0, 0.0

    # Parallel execution via joblib's loky backend; non-positive n_jobs means serial.
    effective_jobs = max(1, n_jobs) if n_jobs and n_jobs > 0 else 1

    results = Parallel(n_jobs=effective_jobs, backend="loky")(
        delayed(_fit_gene_mixed)(i) for i in range(n_genes)
    )

    # Assemble
    results = sorted(results, key=lambda t: t[0])
    logfcs = np.array([r[1] for r in results], dtype=float)
    p_walds = np.array([r[2] for r in results], dtype=float)
    p_lrts = np.array([r[3] for r in results], dtype=float)
    dvars = np.array([r[4] for r in results], dtype=float)

    # Choose which p-value to expose as the main "p_val" for active_score weighting and default filtering.
    #   "wald": the coefficient test (standard for logFC-like effect)
    #   "lrt":  the likelihood ratio test for the condition term contribution (ties directly to delta_variance)
    if mixed_model_pval == "lrt":
        main_pvals = p_lrts
    else:
        if mixed_model_pval != "wald":
            logger.warning("mixed_model_pval must be 'wald' or 'lrt'; falling back to 'wald'.")
        main_pvals = p_walds

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        p_adjs = multipletests(main_pvals, method="fdr_bh")[1]

    de_df = pd.DataFrame(index=var_names)
    de_df["logFC"] = pd.Series(logfcs, index=var_names)
    de_df["p_val"] = pd.Series(main_pvals, index=var_names)
    de_df["p_adj"] = pd.Series(p_adjs, index=var_names)
    de_df["delta_variance"] = pd.Series(dvars, index=var_names)
    de_df["delta_var_pval"] = pd.Series(p_lrts, index=var_names)
    return de_df
|
|
388
|
+
|
|
389
|
+
|
|
390
|
+
def _run_memento_de(
    adata: ad.AnnData,
    groupby: str,
    target_group: str,
    reference_group: str,
    labels: Any | None = None,
    capture_rate: float = 0.07,
    num_boot: int = 5000,
    n_cpus: int = -1,
    counts: str | np.ndarray | sparse.spmatrix | pd.DataFrame | ad.AnnData | None = None,
) -> pd.DataFrame:
    """Memento (method of moments) cell-level DE backend.

    The data are restricted to cells of ``target_group`` / ``reference_group``,
    a binary ``stim`` column (1 = target) is added, ``X`` is replaced by a CSR
    count matrix and ``memento.binary_test_1d`` is run on the result.

    Count resolution order:
      1. explicit ``counts`` (layer name, AnnData, or array-like),
      2. ``layers["counts"]``,
      3. ``adata.raw`` (aligned to the current ``var_names`` when it holds more genes),
      4. current ``X`` if ``_is_integer_counts_like`` accepts it,
      5. current ``X`` as-is, with a warning.

    Explicit ``counts`` that still have one row per cell of the *input* ``adata``
    are subset to the compared cells so they match the filtered data.

    Parameters
    ----------
    adata
        Input data; never modified (all changes happen on a copy).
    groupby
        Obs column with the group assignment (ignored when ``labels`` is given).
    target_group, reference_group
        Levels being contrasted.
    labels
        Optional per-observation labels overriding ``groupby``.
    capture_rate, num_boot
        Forwarded to ``memento.binary_test_1d``.
    n_cpus
        Forwarded as ``num_cpus``; non-positive or missing values become ``-1``.
    counts
        Optional raw counts (see resolution order above).

    Returns
    -------
    pandas.DataFrame
        Indexed by ``adata.var_names`` with at minimum ``logFC``, ``p_val`` and
        ``p_adj`` (BH-adjusted here), plus ``memento_de_se`` / ``memento_dv_*``
        when Memento reports them. Genes absent from the Memento result get
        ``logFC=0`` and p-values of 1.

    Raises
    ------
    ImportError
        ``memento`` is not installed.
    ValueError
        Unknown layer name, no overlapping genes with a ``counts`` AnnData, or a
        count matrix whose shape does not match the compared cells/genes.
    """
    try:
        import memento
    except ImportError as e:
        raise ImportError(
            "memento-de is required when use_memento_de=True. "
            'Install with: pip install "scatrans[memento]" (or pip install memento-de)'
        ) from e

    ad_temp = adata.copy() if labels is not None else adata
    use_groupby = groupby

    if labels is not None:
        use_groupby = "_memento_temp_group"
        ad_temp.obs[use_groupby] = pd.Categorical(
            np.asarray(labels).astype(str), categories=[reference_group, target_group]
        )

    # Restrict to the two groups being compared
    n_obs_input = ad_temp.n_obs
    keep = ad_temp.obs[use_groupby].astype(str).isin([target_group, reference_group])
    keep_mask = np.asarray(keep, dtype=bool)
    n_obs_kept = int(keep_mask.sum())
    ad_temp = ad_temp[keep].copy()

    # Binary treatment column expected by memento.binary_test_1d
    ad_temp.obs["stim"] = (ad_temp.obs[use_groupby].astype(str) == target_group).astype(int)

    def _to_csr(x):
        """Return ``x`` as a SciPy CSR matrix."""
        if sparse.issparse(x):
            return x.tocsr()
        return csr_matrix(np.asarray(x))

    def _to_csr_kept_cells(x):
        """CSR-convert ``x``; drop rows of filtered-out cells if it spans the input."""
        mat = _to_csr(x)
        if n_obs_kept != n_obs_input and mat.shape[0] == n_obs_input:
            mat = mat[keep_mask]
        return mat

    raw_counts = None

    # 1. Explicit `counts` argument (most flexible)
    if counts is not None:
        if isinstance(counts, str):
            if counts not in ad_temp.layers:
                raise ValueError(f"counts='{counts}' layer not found in adata.layers")
            # Layer of the already filtered copy: rows are aligned.
            raw_counts = _to_csr(ad_temp.layers[counts])
        elif isinstance(counts, ad.AnnData):
            counts_ad = counts
            if n_obs_kept != n_obs_input and counts_ad.n_obs == n_obs_input:
                counts_ad = counts_ad[keep_mask]
            if counts_ad.var_names.tolist() != ad_temp.var_names.tolist():
                common = counts_ad.var_names.intersection(ad_temp.var_names)
                if len(common) == 0:
                    raise ValueError(
                        "No overlapping genes between provided counts AnnData and current adata"
                    )
                raw_counts = _to_csr(counts_ad[:, common].X)
                ad_temp = ad_temp[:, common].copy()
            else:
                raw_counts = _to_csr(counts_ad.X)
        else:
            raw_counts = _to_csr_kept_cells(counts)

        if raw_counts.shape != ad_temp.shape:
            raise ValueError(
                f"Provided counts have shape {raw_counts.shape} but the compared data "
                f"have shape {ad_temp.shape} (cells x genes)."
            )
        logger.info("Memento: using explicitly provided counts.")

    # 2. adata.layers["counts"]
    if raw_counts is None and "counts" in getattr(ad_temp, "layers", {}):
        raw_counts = _to_csr(ad_temp.layers["counts"])
        logger.info("Memento: using 'counts' layer.")

    # 3. adata.raw (if it has the counts)
    if raw_counts is None and getattr(ad_temp, "raw", None) is not None:
        raw = ad_temp.raw
        if raw.shape[1] >= ad_temp.shape[1]:
            raw_x = None
            if raw.var_names.equals(ad_temp.var_names):
                raw_x = raw.X
            elif ad_temp.var_names.isin(raw.var_names).all():
                # raw holds extra (or reordered) genes: align columns by name.
                raw_x = raw[:, ad_temp.var_names].X
            elif raw.shape[1] == ad_temp.shape[1]:
                # Same width but different names: keep the positional behaviour.
                raw_x = raw.X
            if raw_x is not None:
                raw_counts = _to_csr(raw_x)
                logger.info("Memento: using counts from adata.raw.")

    # 4. Current .X if it already looks like raw counts
    if raw_counts is None and _is_integer_counts_like(ad_temp.X):
        raw_counts = _to_csr(ad_temp.X)
        logger.info("Memento: using current .X (looks like raw counts).")

    if raw_counts is None:
        logger.warning(
            "Could not obtain raw counts for Memento. "
            "Memento works best with raw integer UMI counts. "
            "Please call scat.pp.store_raw_counts(adata) early (before HVG + log), "
            "or provide via the `counts` parameter, or ensure adata.raw / layers['counts'] has raw counts."
        )
        raw_counts = _to_csr(ad_temp.X)

    # Every branch above yields CSR, so no further conversion is needed.
    ad_temp.X = raw_counts

    # Effective cpus
    effective_cpus = n_cpus if n_cpus and n_cpus > 0 else -1

    result = memento.binary_test_1d(
        adata=ad_temp,
        treatment_col="stim",
        capture_rate=capture_rate,
        num_boot=num_boot,
        num_cpus=effective_cpus,
    )

    # result may be indexed by gene or have a 'gene' column (handle both)
    if isinstance(result, pd.DataFrame):
        if "gene" in result.columns:
            result = result.set_index("gene")
        res_index = result.index
    else:
        # Fallback (should not happen)
        res_index = ad_temp.var_names
        result = pd.DataFrame(index=res_index)

    def _numeric_column(name: str, default: float) -> pd.Series:
        """Numeric view of ``result[name]``; a constant Series if the column is absent."""
        if name in result.columns:
            return (
                pd.to_numeric(result[name], errors="coerce").reindex(res_index).fillna(default)
            )
        return pd.Series(default, index=res_index, dtype=float)

    de_df = pd.DataFrame(index=res_index)
    de_df["logFC"] = _numeric_column("de_coef", 0.0)
    pvals = _numeric_column("de_pval", 1.0)
    de_df["p_val"] = pvals

    # BH adjustment (Memento may return raw p; we make p_adj consistent with other backends)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        de_df["p_adj"] = multipletests(pvals.values, method="fdr_bh")[1]

    # Expose Memento's native columns for users who want mean + variability signals
    for src, dst in [
        ("de_se", "memento_de_se"),
        ("dv_coef", "memento_dv_coef"),
        ("dv_se", "memento_dv_se"),
        ("dv_pval", "memento_dv_pval"),
    ]:
        if src in result.columns:
            de_df[dst] = pd.to_numeric(result[src], errors="coerce").reindex(res_index)

    # Re-align to the var_names of the adata object that was passed into the wrapper
    # (important for the labels= permutation case and any internal subsetting)
    de_df = de_df.reindex(adata.var_names).fillna({"logFC": 0.0, "p_val": 1.0, "p_adj": 1.0})

    return de_df
|