scatrans 0.7.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scatrans/__init__.py ADDED
@@ -0,0 +1,56 @@
1
+ """
2
+ scATrans public API.
3
+
4
+ Recommended usage:
5
+ import scatrans as scat
6
+ scat.active_score(...)
7
+ scat.add_gene_features(...)
8
+ scat.pl.set_style()
9
+ scat.run_enrichment(...)
10
+
11
+ Submodules `pl` and `qc` are intentionally exposed (scanpy-style convention).
12
+ Other internal modules are not part of the stable public surface.
13
+ """
14
+
15
+ from . import pl, qc
16
+ from .enrich import list_bundled_gene_sets, run_enrichment, run_kegg, simplify_enrichment
17
+ from .generate_gene_features import main as generate_gene_features_main
18
+ from .pp_bias import add_gene_features, list_available_gene_features
19
+ from .tl import (
20
+ active_score,
21
+ diagnose_design,
22
+ differential_expression,
23
+ filter_active_genes,
24
+ restore_raw_counts,
25
+ store_raw_counts,
26
+ )
27
+
28
__all__ = [
    "active_score",
    "differential_expression",
    "diagnose_design",
    "filter_active_genes",
    "restore_raw_counts",
    "store_raw_counts",
    "add_gene_features",
    "list_available_gene_features",
    "run_enrichment",
    "run_kegg",
    "simplify_enrichment",
    "list_bundled_gene_sets",
    "pl",
    "qc",
    "generate_gene_features_main",
    "__version__",
]

# Prefer the version recorded in a generated ``_version.py`` module when one is
# present; otherwise fall back to the hard-coded development version string.
try:
    from ._version import version as __version__
except ImportError:
    __version__ = "0.7.0.dev0"

# NOTE: internal modules (``tl``, ``enrich``, ``pp_bias``, ``generate_gene_features``)
# are not hidden or deleted and remain importable as ``scatrans.<module>`` for
# advanced use; only the names in ``__all__`` (plus the ``pl``/``qc`` submodules)
# are intended as the stable public API.
scatrans/_bias.py ADDED
@@ -0,0 +1,24 @@
1
+ """
2
+ Bias correction (Huber regression on gene length + intron number).
3
+
4
+ The actual implementation lives in _utils._fit_huber_bias_correction so it can be
5
+ used from both the main analysis path and from permutation tasks without duplication.
6
+
7
+ Enhanced return: (residual, bias_info_dict) with fit diagnostics for transparency.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import Any
13
+
14
+ import numpy as np
15
+
16
+ from ._utils import _fit_huber_bias_correction as _raw_fit
17
+
18
+
19
def fit_huber_bias_correction(*args, **kwargs) -> tuple[np.ndarray, dict[str, Any]]:
    """Return ``(residual, bias_info)`` from the Huber bias-correction fit.

    This is a pass-through to ``_utils._fit_huber_bias_correction`` (imported in this
    module as ``_raw_fit``): every positional and keyword argument is forwarded
    unchanged and its result is returned as-is.
    """
    fitted = _raw_fit(*args, **kwargs)
    return fitted


__all__ = ["fit_huber_bias_correction"]
scatrans/_de.py ADDED
@@ -0,0 +1,555 @@
1
+ """
2
+ scATrans internal differential expression helpers.
3
+
4
+ Contains the wrapper that supports both scanpy rank_genes_groups and PyDESeq2
5
+ for pseudobulk. Extracted so tl.py stays small.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import warnings
12
+ from typing import Any
13
+
14
+ import anndata as ad
15
+ import numpy as np
16
+ import pandas as pd
17
+ import scanpy as sc
18
+ from joblib import Parallel, delayed
19
+ from scipy import sparse
20
+ from scipy.sparse import csr_matrix
21
+ from scipy.stats import chi2
22
+ from statsmodels.stats.multitest import multipletests
23
+
24
+ from ._utils import (
25
+ _is_integer_counts_like,
26
+ _warn_if_low_counts_matrix,
27
+ )
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
def _run_de_wrapper(
    adata: ad.AnnData,
    groupby: str,
    target_group: str,
    reference_group: str,
    de_method: str = "t-test_overestim_var",
    is_pseudobulk: bool = False,
    pb_backend: str = "pydeseq2",
    n_jobs: int = 1,
    labels: Any | None = None,
    strict_pydeseq2_counts: bool = True,
    use_mixed_model: bool = False,
    sample_col: str | None = None,
    mixed_model_pval: str = "wald",
    # Memento (Cell 2024 method-of-moments) as independent cell-level DE backend
    use_memento_de: bool = False,
    memento_capture_rate: float = 0.07,
    memento_num_boot: int = 5000,
    memento_n_cpus: int = -1,
    # Allow providing raw counts separately (common when adata.X is already HVG + log1p)
    counts: str | np.ndarray | sparse.spmatrix | pd.DataFrame | ad.AnnData | None = None,
) -> pd.DataFrame:
    """Run DE of ``target_group`` vs ``reference_group`` and return per-gene statistics.

    The backend is chosen in this order (first match wins):

    1. ``use_mixed_model=True`` -> :func:`_run_mixedlm_de` (requires ``sample_col``);
       also returns ``delta_variance`` and ``delta_var_pval``.
    2. ``use_memento_de=True`` -> :func:`_run_memento_de`, a cell-level
       method-of-moments backend (not allowed together with ``is_pseudobulk=True``).
    3. ``is_pseudobulk=True`` and ``pb_backend == "pydeseq2"`` -> PyDESeq2 on
       ``adata.X``, which must look like non-negative integer counts unless
       ``strict_pydeseq2_counts=False`` (then values are rounded with a warning).
    4. Otherwise -> ``scanpy.tl.rank_genes_groups`` with ``method=de_method``.

    Parameters
    ----------
    adata
        Expression data; ``adata.obs[groupby]`` holds the group labels unless
        ``labels`` is given.
    groupby, target_group, reference_group
        ``adata.obs`` column and the two levels to compare (target vs reference).
    labels
        Optional per-observation labels overriding ``adata.obs[groupby]`` (e.g.
        permuted labels). Values other than the two compared groups are treated as
        missing. ``adata`` is copied before the labels are attached.
    n_jobs
        Parallelism forwarded to the mixed-model and PyDESeq2 backends.
    counts
        Raw counts for the Memento backend (layer name, matrix, DataFrame or AnnData).

    Returns
    -------
    pandas.DataFrame
        Indexed by ``adata.var_names`` with ``logFC``, ``p_val`` and ``p_adj``; missing
        statistics are filled with ``logFC=0`` and p-values of 1.

    Raises
    ------
    ValueError
        Missing ``sample_col`` for the mixed model, Memento combined with pseudobulk,
        fewer than two PyDESeq2 replicates per group, non-count data with
        ``strict_pydeseq2_counts=True``, or no gene passing the DESeq2 count filter.
    ImportError
        If the optional backend package required by the chosen path is missing.

    Notes
    -----
    When ``labels`` is None the scanpy backend runs on the caller's object and leaves
    its result in ``adata.uns["_scatrans_rank_genes_groups"]``.
    """
    if use_mixed_model:
        if sample_col is None:
            raise ValueError("sample_col must be provided when use_mixed_model=True")
        return _run_mixedlm_de(
            adata,
            groupby=groupby,
            target_group=target_group,
            reference_group=reference_group,
            sample_col=sample_col,
            n_jobs=n_jobs,
            labels=labels,
            mixed_model_pval=mixed_model_pval,
        )

    if use_memento_de:
        if is_pseudobulk:
            raise ValueError(
                "use_memento_de=True is not supported with use_pseudobulk=True "
                "(Memento is a cell-level method-of-moments estimator; use PyDESeq2 for pseudobulk)."
            )
        return _run_memento_de(
            adata,
            groupby=groupby,
            target_group=target_group,
            reference_group=reference_group,
            labels=labels,
            capture_rate=memento_capture_rate,
            num_boot=memento_num_boot,
            n_cpus=memento_n_cpus,
            counts=counts,
        )

    # Attach override labels on a copy so the caller's obs is never modified.
    ad_temp = adata.copy() if labels is not None else adata
    use_groupby = groupby
    if labels is not None:
        use_groupby = "_de_temp_group"
        ad_temp.obs[use_groupby] = pd.Categorical(
            np.asarray(labels).astype(str), categories=[reference_group, target_group]
        )

    if is_pseudobulk and pb_backend == "pydeseq2":
        return _run_pydeseq2_de(
            ad_temp,
            use_groupby,
            target_group,
            reference_group,
            n_jobs=n_jobs,
            strict_pydeseq2_counts=strict_pydeseq2_counts,
        )

    # Standard scanpy path (works for both regular and pseudobulk when not using pydeseq2)
    return _run_scanpy_de(
        ad_temp, use_groupby, target_group, reference_group, de_method=de_method
    )


def _run_pydeseq2_de(
    ad_temp: ad.AnnData,
    use_groupby: str,
    target_group: str,
    reference_group: str,
    *,
    n_jobs: int,
    strict_pydeseq2_counts: bool,
) -> pd.DataFrame:
    """PyDESeq2 DE of target vs reference on ``ad_temp.X`` (helper of _run_de_wrapper).

    Rows of ``ad_temp`` are presumably pseudobulk samples (the caller selects this
    path only when ``is_pseudobulk=True``).
    """
    try:
        from pydeseq2.dds import DeseqDataSet
        from pydeseq2.ds import DeseqStats
    except ImportError as e:
        raise ImportError(
            "pydeseq2 is required when pseudobulk_de_backend='pydeseq2'. "
            "Install with: pip install pydeseq2 or 'scatrans[pseudobulk]'"
        ) from e

    groups = ad_temp.obs[use_groupby]
    is_target = (groups == target_group).to_numpy()
    is_reference = (groups == reference_group).to_numpy()
    n_t = int(is_target.sum())
    n_r = int(is_reference.sum())
    if n_t < 2 or n_r < 2:
        raise ValueError(
            f"PyDESeq2 requires >=2 replicates per group. Found {n_t} target, {n_r} ref."
        )

    # Only the two compared groups may enter the design: any other group would be
    # mapped to NaN by the two-level Categorical below and break the design matrix.
    in_contrast = is_target | is_reference
    if not in_contrast.all():
        ad_temp = ad_temp[in_contrast].copy()

    if not _is_integer_counts_like(ad_temp.X):
        msg = (
            "Data passed to PyDESeq2 does not look like raw non-negative integer counts. "
            "PyDESeq2 requires unnormalized integer counts in adata.X. "
            "If you intentionally want to allow rounding, set strict_pydeseq2_counts=False."
        )
        if strict_pydeseq2_counts:
            raise ValueError(msg)
        logger.warning(msg)
    else:
        _warn_if_low_counts_matrix(ad_temp.X)

    # Filter genes with total count < 10 and coerce to non-negative ints. For sparse
    # input the filter uses the unrounded sums so only kept genes are densified.
    X = ad_temp.X
    if sparse.issparse(X):
        gene_keep = np.asarray(X.sum(axis=0)).ravel() >= 10
        X_kept = X[:, gene_keep].toarray()
    else:
        X_dense = np.clip(np.round(np.nan_to_num(np.asarray(X))), 0, None).astype(int)
        gene_keep = X_dense.sum(axis=0) >= 10
        X_kept = X_dense[:, gene_keep]
    if not gene_keep.any():
        raise ValueError("No genes passed the DESeq2 count filter (sum(counts) >= 10).")
    X_kept = np.clip(np.round(np.nan_to_num(X_kept)), 0, None).astype(int)
    counts_use = pd.DataFrame(
        X_kept, index=ad_temp.obs_names, columns=ad_temp.var_names[gene_keep]
    )

    condition = ad_temp.obs[use_groupby].astype(str).values
    metadata = pd.DataFrame(
        {use_groupby: pd.Categorical(condition, categories=[reference_group, target_group])},
        index=counts_use.index,
    )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Older PyDESeq2 releases take design_factors/ref_level; newer ones use a
        # formula ``design`` and reject the old keywords with TypeError.
        try:
            dds = DeseqDataSet(
                counts=counts_use,
                metadata=metadata,
                design_factors=use_groupby,
                ref_level=[use_groupby, reference_group],
                quiet=True,
                n_cpus=n_jobs,
            )
        except TypeError:
            dds = DeseqDataSet(
                counts=counts_use,
                metadata=metadata,
                design=f"~{use_groupby}",
                refit_cooks=True,
                quiet=True,
                n_cpus=n_jobs,
            )
        dds.deseq2()

        try:
            stat_res = DeseqStats(
                dds,
                contrast=[use_groupby, target_group, reference_group],
                quiet=True,
                n_cpus=n_jobs,
            )
        except TypeError:
            stat_res = DeseqStats(
                dds,
                contrast=[use_groupby, target_group, reference_group],
                n_cpus=n_jobs,
            )
        stat_res.summary()

    # Filtered-out genes come back as NaN and are reported as non-informative.
    res2 = stat_res.results_df.copy().reindex(ad_temp.var_names)
    de_df = pd.DataFrame(index=ad_temp.var_names)
    de_df["logFC"] = res2["log2FoldChange"].fillna(0.0)
    de_df["p_val"] = res2.get("pvalue", pd.Series(1.0, index=res2.index)).fillna(1.0)
    de_df["p_adj"] = res2.get("padj", pd.Series(1.0, index=res2.index)).fillna(1.0)
    return de_df


def _run_scanpy_de(
    ad_temp: ad.AnnData,
    use_groupby: str,
    target_group: str,
    reference_group: str,
    *,
    de_method: str,
) -> pd.DataFrame:
    """scanpy ``rank_genes_groups`` DE of target vs reference (helper of _run_de_wrapper)."""
    rank_key = "_scatrans_rank_genes_groups"
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sc.tl.rank_genes_groups(
            ad_temp,
            groupby=use_groupby,
            groups=[target_group],
            reference=reference_group,
            method=de_method,
            key_added=rank_key,
        )
    de_raw = sc.get.rank_genes_groups_df(ad_temp, group=target_group, key=rank_key).set_index(
        "names"
    )
    var_names = ad_temp.var_names
    de_df = pd.DataFrame(index=var_names)
    de_df["logFC"] = de_raw["logfoldchanges"].reindex(var_names).fillna(0.0)
    de_df["p_val"] = de_raw["pvals"].reindex(var_names).fillna(1.0)
    de_df["p_adj"] = de_raw["pvals_adj"].reindex(var_names).fillna(1.0)
    return de_df
222
+
223
+
224
def _run_mixedlm_de(
    adata: ad.AnnData,
    groupby: str,
    target_group: str,
    reference_group: str,
    sample_col: str,
    n_jobs: int = 1,
    labels: Any | None = None,
    mixed_model_pval: str = "wald",
) -> pd.DataFrame:
    """
    Mixed linear model (LMM) DE + Delta Variance using statsmodels mixedlm.

    Only cells labelled ``target_group`` or ``reference_group`` are modelled. For each
    gene, on log1p library-size-normalised expression, two models are fitted by ML:

        full:    y ~ is_target + (1 | sample)
        reduced: y ~ 1         + (1 | sample)

    ``is_target`` is 1 for target cells and 0 for reference cells. The coefficient is
    therefore always target minus reference, independent of how the labels sort.

    Output columns (indexed by ``adata.var_names``):
    - logFC: ``is_target`` coefficient (difference on the log1p scale)
    - p_val / p_adj: Wald p (default) or LRT p (``mixed_model_pval="lrt"``), BH-adjusted
      across genes
    - delta_variance: variance of the fixed-effect fit divided by
      (that variance + random-effect variance + residual variance), clipped to [0, 1]
    - delta_var_pval: LRT p-value (chi2, df=1) of full vs reduced model

    Genes whose expression is constant or whose fit fails get logFC=0, p-values of 1
    and delta_variance=0; non-finite fit statistics are replaced the same way.

    This provides a lightweight Python analogue to variancePartition/dream (fraction of
    variation explained) + LMM DE, suitable for cell-level data with sample-level random
    effects (addresses pseudoreplication). For full voom + precision weights +
    dreampy/dreamlet on pseudobulk, or NEBULA NB-GLMM, see external packages.

    Parameters
    ----------
    labels
        Optional per-cell labels overriding ``adata.obs[groupby]``.
    n_jobs
        joblib convention: positive values are worker counts, negative values count back
        from the number of CPUs (-1 = all); 0/None run serially.

    Raises
    ------
    ValueError
        If ``labels`` has the wrong length, ``sample_col`` is not in ``adata.obs``, or one
        of the two groups has no cells.
    ImportError
        If statsmodels is not installed.
    """
    try:
        import statsmodels.formula.api as smf
    except ImportError as e:
        raise ImportError(
            "statsmodels is required for use_mixed_model=True (it is a core dependency of scatrans)."
        ) from e

    target = str(target_group)
    reference = str(reference_group)
    if labels is not None:
        group_labels = np.asarray(labels).astype(str)
        if group_labels.shape[0] != adata.n_obs:
            raise ValueError(
                f"labels has {group_labels.shape[0]} entries but adata has {adata.n_obs} cells"
            )
    else:
        group_labels = adata.obs[groupby].astype(str).to_numpy()

    if sample_col not in adata.obs.columns:
        raise ValueError(f"sample_col='{sample_col}' not found in adata.obs")

    if mixed_model_pval not in ("wald", "lrt"):
        logger.warning("mixed_model_pval must be 'wald' or 'lrt'; falling back to 'wald'.")

    in_contrast = np.isin(group_labels, [target, reference])
    n_target = int(np.sum(group_labels == target))
    n_reference = int(np.sum(group_labels == reference))
    if n_target == 0 or n_reference == 0:
        raise ValueError(
            f"Mixed-model DE needs cells in both groups; found {n_target} '{target}' "
            f"and {n_reference} '{reference}' cells."
        )

    # Subset copy: other groups must not enter the model, and normalisation below must
    # not mutate the caller's adata.
    ad_expr = adata[in_contrast].copy()
    # Always log1p library-size normalise (LMM assumes approx Gaussian on log scale).
    # NOTE(review): data that is already log-normalised is normalised and logged again.
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sc.pp.normalize_total(ad_expr, target_sum=1e4)
            sc.pp.log1p(ad_expr)
    except Exception:
        # Fallback: manual log1p on X if pp fails (e.g. already transformed or edge data)
        if sparse.issparse(ad_expr.X):
            Xn = np.log1p(ad_expr.X.toarray())
        else:
            Xn = np.log1p(np.asarray(ad_expr.X))
        ad_expr.X = Xn

    expr_mat = ad_expr.X.toarray() if sparse.issparse(ad_expr.X) else np.asarray(ad_expr.X)
    is_target = (group_labels[in_contrast] == target).astype(float)
    samples = ad_expr.obs[sample_col].astype(str).to_numpy()
    n_genes = expr_mat.shape[1]
    var_names = adata.var_names

    def _finite_or(value: float, default: float) -> float:
        """Return ``value`` unless it is NaN/inf (which would poison BH adjustment)."""
        return value if np.isfinite(value) else default

    # Per-gene worker (returns idx, logfc, wald_p, lrt_p, delta_var). It receives only its
    # own expression column so loky does not re-pickle the whole matrix per task.
    def _fit_gene_mixed(idx: int, y: np.ndarray):
        # Constant expression (e.g. all zeros) makes the mixed model singular.
        if np.allclose(y, y[0]):
            return idx, 0.0, 1.0, 1.0, 0.0
        df = pd.DataFrame({"y": y, "is_target": is_target, "sample": samples})
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                m_full = smf.mixedlm("y ~ is_target", df, groups=df["sample"]).fit(
                    reml=False, maxiter=200, disp=False
                )
                # Reduced (null) model for the LRT on the condition contribution
                m_null = smf.mixedlm("y ~ 1", df, groups=df["sample"]).fit(
                    reml=False, maxiter=200, disp=False
                )

                # One added fixed-effect column -> chi2 with df=1
                lrt_stat = -2.0 * (m_null.llf - m_full.llf)
                lrt_p = float(chi2.sf(max(lrt_stat, 0.0), 1))

                logfc = float(m_full.params.get("is_target", 0.0))
                p_wald = float(m_full.pvalues.get("is_target", 1.0))

                # Delta variance: var attributable to fixed effects / total modeled var
                fe_contrib = m_full.model.exog @ np.asarray(m_full.fe_params)
                var_fe = float(np.var(fe_contrib))
                re_var = 0.0
                try:
                    cov_re = getattr(m_full, "cov_re", None)
                    if cov_re is not None and len(cov_re) > 0:
                        re_var = float(np.diag(cov_re)[0])  # single random intercept
                except Exception:
                    re_var = 0.0
                resid_var = float(getattr(m_full, "scale", 0.0))
                total_v = var_fe + max(re_var, 0.0) + max(resid_var, 0.0)
                delta_var = var_fe / total_v if total_v > 1e-12 else 0.0
        except Exception:
            # Degenerate fit (few samples per group, collinear, etc.) -> non-informative
            return idx, 0.0, 1.0, 1.0, 0.0
        return (
            idx,
            _finite_or(logfc, 0.0),
            _finite_or(p_wald, 1.0),
            _finite_or(lrt_p, 1.0),
            _finite_or(float(np.clip(delta_var, 0.0, 1.0)), 0.0),
        )

    effective_jobs = n_jobs if n_jobs else 1
    results = Parallel(n_jobs=effective_jobs, backend="loky")(
        delayed(_fit_gene_mixed)(i, expr_mat[:, i].astype(float)) for i in range(n_genes)
    )

    results = sorted(results, key=lambda t: t[0])
    logfcs = np.array([r[1] for r in results], dtype=float)
    p_walds = np.array([r[2] for r in results], dtype=float)
    p_lrts = np.array([r[3] for r in results], dtype=float)
    dvars = np.array([r[4] for r in results], dtype=float)

    # "wald": coefficient test (standard for a logFC-like effect);
    # "lrt": likelihood-ratio test of the condition term (ties to delta_variance).
    main_pvals = p_lrts if mixed_model_pval == "lrt" else p_walds

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        p_adjs = multipletests(main_pvals, method="fdr_bh")[1]

    de_df = pd.DataFrame(index=var_names)
    de_df["logFC"] = pd.Series(logfcs, index=var_names)
    de_df["p_val"] = pd.Series(main_pvals, index=var_names)
    de_df["p_adj"] = pd.Series(p_adjs, index=var_names)
    de_df["delta_variance"] = pd.Series(dvars, index=var_names)
    de_df["delta_var_pval"] = pd.Series(p_lrts, index=var_names)
    return de_df
388
+
389
+
390
def _memento_as_csr(x: Any) -> csr_matrix:
    """Return ``x`` (dense array, DataFrame or any scipy sparse matrix) as CSR."""
    if sparse.issparse(x):
        return x.tocsr()
    return csr_matrix(np.asarray(x))


def _memento_select_rows(x: Any, keep: np.ndarray, n_sub: int) -> Any:
    """Restrict a user-supplied count matrix to the cells kept for the comparison.

    ``x`` may have one row per cell of the original (unsubset) AnnData, in which case the
    boolean mask ``keep`` is applied, or already exactly ``n_sub`` rows (only the kept
    cells), in which case it is returned unchanged. Rows are matched by position.
    """
    if isinstance(x, pd.DataFrame):
        x = x.to_numpy()
    elif sparse.issparse(x):
        x = x.tocsr()  # CSR supports boolean row masks (COO does not)
    else:
        x = np.asarray(x)
    n_rows = x.shape[0]
    if n_rows == keep.shape[0]:
        return x[keep]
    if n_rows == n_sub:
        return x
    raise ValueError(
        f"Provided counts have {n_rows} rows; expected {keep.shape[0]} (all cells) "
        f"or {n_sub} (cells in the two compared groups)."
    )


def _memento_resolve_counts(
    ad_sub: ad.AnnData,
    counts: Any,
    keep: np.ndarray,
) -> tuple[csr_matrix, ad.AnnData]:
    """Choose the raw-count matrix for Memento, aligned to ``ad_sub``'s cells and genes.

    Priority:
      1. explicit ``counts`` (layer name, matrix, DataFrame or AnnData)
      2. ``ad_sub.layers["counts"]``
      3. ``ad_sub.raw``, restricted to ``ad_sub.var_names``, if it contains all of them
      4. ``ad_sub.X`` if it looks like integer counts
      5. ``ad_sub.X`` as a last resort, with a warning

    Returns ``(csr_counts, ad_sub)``; ``ad_sub`` is restricted to the shared genes when
    ``counts`` is an AnnData whose ``var_names`` differ.
    """
    if counts is not None:
        if isinstance(counts, str):
            if counts not in ad_sub.layers:
                raise ValueError(f"counts='{counts}' layer not found in adata.layers")
            # Layers were subset together with ad_sub, so they are already aligned.
            raw = ad_sub.layers[counts]
        elif isinstance(counts, ad.AnnData):
            counts_ad = counts
            if counts.var_names.tolist() != ad_sub.var_names.tolist():
                common = counts.var_names.intersection(ad_sub.var_names)
                if len(common) == 0:
                    raise ValueError(
                        "No overlapping genes between provided counts AnnData and current adata"
                    )
                counts_ad = counts[:, common]
                ad_sub = ad_sub[:, common].copy()
            raw = _memento_select_rows(counts_ad.X, keep, ad_sub.n_obs)
        else:
            raw = _memento_select_rows(counts, keep, ad_sub.n_obs)
        logger.info("Memento: using explicitly provided counts.")
        return _memento_as_csr(raw), ad_sub

    if "counts" in ad_sub.layers:
        logger.info("Memento: using 'counts' layer.")
        return _memento_as_csr(ad_sub.layers["counts"]), ad_sub

    raw_obj = ad_sub.raw
    if raw_obj is not None and ad_sub.var_names.isin(raw_obj.var_names).all():
        # .raw often holds more genes than .X (stored before HVG selection): select the
        # genes of .X, in the same order, so the matrix can replace .X.
        logger.info("Memento: using counts from adata.raw.")
        return _memento_as_csr(raw_obj[:, ad_sub.var_names].X), ad_sub

    if _is_integer_counts_like(ad_sub.X):
        logger.info("Memento: using current .X (looks like raw counts).")
        return _memento_as_csr(ad_sub.X), ad_sub

    logger.warning(
        "Could not obtain raw counts for Memento. "
        "Memento works best with raw integer UMI counts. "
        "Please call scat.store_raw_counts(adata) early (before HVG + log), "
        "or provide via the `counts` parameter, or ensure adata.raw / layers['counts'] has raw counts."
    )
    return _memento_as_csr(ad_sub.X), ad_sub


def _run_memento_de(
    adata: ad.AnnData,
    groupby: str,
    target_group: str,
    reference_group: str,
    labels: Any | None = None,
    capture_rate: float = 0.07,
    num_boot: int = 5000,
    n_cpus: int = -1,
    counts: str | np.ndarray | sparse.spmatrix | pd.DataFrame | ad.AnnData | None = None,
) -> pd.DataFrame:
    """Memento (method of moments) cell-level DE backend.

    Cells outside ``target_group``/``reference_group`` are dropped, a binary ``stim``
    column (1 = target) is added, and ``memento.binary_test_1d`` is run on raw counts.
    Raw counts are resolved by :func:`_memento_resolve_counts`. Explicit matrix-like
    ``counts`` may have one row per cell of ``adata`` or one per kept cell.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``adata.var_names`` with ``logFC`` (Memento ``de_coef``), ``p_val``
        (``de_pval``) and ``p_adj`` (BH across returned genes), plus, when Memento
        returns them, ``memento_de_se``, ``memento_dv_coef``, ``memento_dv_se`` and
        ``memento_dv_pval``. Genes without a result get logFC=0 and p-values of 1.

    Raises
    ------
    ImportError
        If memento-de is not installed.
    ValueError
        If ``labels`` has the wrong length or the provided ``counts`` cannot be
        aligned (missing layer, no shared genes, wrong number of rows).
    """
    try:
        import memento
    except ImportError as e:
        raise ImportError(
            "memento-de is required when use_memento_de=True. "
            'Install with: pip install "scatrans[memento]" (or pip install memento-de)'
        ) from e

    # Resolve group labels without copying the full AnnData.
    if labels is not None:
        group_values = pd.Series(
            pd.Categorical(
                np.asarray(labels).astype(str), categories=[reference_group, target_group]
            ),
            index=adata.obs_names,
        )
    else:
        group_values = adata.obs[groupby]
    group_str = group_values.astype(str).to_numpy()

    # Restrict to the two groups being compared (copy: .X is replaced below).
    keep = np.isin(group_str, [target_group, reference_group])
    ad_temp = adata[keep].copy()

    # Binary treatment column expected by memento.binary_test_1d
    ad_temp.obs["stim"] = (group_str[keep] == target_group).astype(int)

    raw_counts, ad_temp = _memento_resolve_counts(ad_temp, counts, keep)
    ad_temp.X = raw_counts

    effective_cpus = n_cpus if n_cpus and n_cpus > 0 else -1

    result = memento.binary_test_1d(
        adata=ad_temp,
        treatment_col="stim",
        capture_rate=capture_rate,
        num_boot=num_boot,
        num_cpus=effective_cpus,
    )

    # result may be indexed by gene or have a 'gene' column (handle both)
    if isinstance(result, pd.DataFrame):
        if "gene" in result.columns:
            result = result.set_index("gene")
    else:
        # Fallback (should not happen)
        result = pd.DataFrame(index=ad_temp.var_names)
    res_index = result.index

    def _numeric(column: str, fill: float) -> pd.Series:
        """Numeric version of a Memento column, NaN/missing replaced by ``fill``."""
        if column not in result.columns:
            return pd.Series(fill, index=res_index, dtype=float)
        return pd.to_numeric(result[column], errors="coerce").fillna(fill)

    de_df = pd.DataFrame(index=res_index)
    de_df["logFC"] = _numeric("de_coef", 0.0)
    pvals = _numeric("de_pval", 1.0)
    de_df["p_val"] = pvals

    # BH adjustment (Memento may return raw p; we make p_adj consistent with other backends)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        if len(pvals):
            de_df["p_adj"] = multipletests(pvals.to_numpy(), method="fdr_bh")[1]
        else:
            de_df["p_adj"] = pd.Series(dtype=float)

    # Expose Memento's native columns for users who want mean + variability signals
    for src, dst in [
        ("de_se", "memento_de_se"),
        ("dv_coef", "memento_dv_coef"),
        ("dv_se", "memento_dv_se"),
        ("dv_pval", "memento_dv_pval"),
    ]:
        if src in result.columns:
            de_df[dst] = pd.to_numeric(result[src], errors="coerce")

    # Re-align to the var_names of the adata object that was passed in (genes dropped by
    # counts alignment or missing from Memento's output become non-informative).
    de_df = de_df.reindex(adata.var_names).fillna({"logFC": 0.0, "p_val": 1.0, "p_adj": 1.0})

    return de_df