scatrans 0.7.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,229 @@
1
+ """
2
+ Permutation testing support for active_score significance.
3
+
4
+ The heavy _single_permutation_task (and the orchestration) was one of the
5
+ biggest contributors to tl.py line count. Extracted here.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import warnings
12
+ from typing import Any
13
+
14
+ import numpy as np
15
+ import scanpy as sc
16
+ from joblib import Parallel, delayed
17
+ from statsmodels.stats.multitest import multipletests
18
+
19
+ from ._de import _run_de_wrapper
20
+
21
+ # local import to avoid circulars at module load
22
+ from ._utils import (
23
+ _fit_huber_bias_correction,
24
+ _soft_scale,
25
+ )
26
+ from ._velocity import _compute_velocity_delta
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
def _single_permutation_task(
    seed: int,
    original_labels: np.ndarray,
    target_group: str,
    reference_group: str,
    adata_subset: Any,
    X_features: np.ndarray | None,
    valid_feat: np.ndarray,
    uns_layer: Any,
    spl_layer: Any,
    total_us_for_filter: np.ndarray,
    min_total_counts: int,
    weight_fc: float,
    weight_unspliced: float,
    weight_pval: float,
    lambda_fc: float,
    lambda_res: float,
    lambda_pval: float,
    is_pseudobulk: bool,
    pb_backend: str,
    de_method: str,
    prior_weight: float,
    de_preprocess: str,
    strict_pydeseq2_counts: bool,
    bias_correction: str = "huber_length_intron",
    # Memento support for permutation (advanced, usually False for speed)
    use_memento_de: bool = False,
    memento_capture_rate: float = 0.07,
    memento_num_boot: int = 5000,
    memento_n_cpus: int = -1,
) -> np.ndarray:
    """Compute the active-score vector for one shuffled labelling.

    The group labels are permuted with a generator seeded by ``seed``; the
    DE wrapper, the velocity delta and the shared bias correction are then
    re-run on the shuffled labels and combined into a weighted score.

    ``bias_correction`` is forwarded unchanged to
    ``_fit_huber_bias_correction`` so permuted scores use the same correction
    setting as the caller passes in. The ``memento_*`` / ``use_memento_de``
    arguments are forwarded unchanged to ``_run_de_wrapper``.

    Returns
    -------
    np.ndarray
        ``(w_fc * s_fc + w_unspliced * s_res + w_pval * s_pval) / sum(w) * 100``
        where each ``s_*`` is the soft-scaled component.
    """
    rng = np.random.default_rng(seed)

    # Draw up to 50 shuffles and keep the first one that differs from the
    # original ordering; if none does, the last draw is used and we warn.
    found_distinct_shuffle = False
    for _attempt in range(50):
        shuffled_labels = rng.permutation(original_labels)
        if not np.array_equal(shuffled_labels, original_labels):
            found_distinct_shuffle = True
            break
    if not found_distinct_shuffle:
        logger.warning("Failed to generate a different permutation after 50 attempts.")

    perm_adata = adata_subset.copy()

    # "normalize_log1p" always normalises; "auto" normalises only when the
    # pseudobulk/pydeseq2 combination is not in use and no "log1p" entry is
    # present in .uns; any other value (e.g. "none") leaves the data as is.
    on_pydeseq2_pseudobulk = is_pseudobulk and pb_backend == "pydeseq2"
    should_log_normalize = de_preprocess == "normalize_log1p" or (
        de_preprocess == "auto"
        and not on_pydeseq2_pseudobulk
        and "log1p" not in perm_adata.uns
    )
    if should_log_normalize:
        sc.pp.normalize_total(perm_adata, target_sum=1e4)
        sc.pp.log1p(perm_adata)

    perm_de_df = _run_de_wrapper(
        perm_adata,
        groupby="_unused_when_labels_provided",
        target_group=target_group,
        reference_group=reference_group,
        de_method=de_method,
        is_pseudobulk=is_pseudobulk,
        pb_backend=pb_backend,
        n_jobs=1,
        labels=shuffled_labels,
        strict_pydeseq2_counts=strict_pydeseq2_counts,
        use_memento_de=use_memento_de,
        memento_capture_rate=memento_capture_rate,
        memento_num_boot=memento_num_boot,
        memento_n_cpus=memento_n_cpus,
    )

    # Velocity delta between the shuffled target and reference groups.
    target_mask = shuffled_labels == target_group
    reference_mask = shuffled_labels == reference_group
    delta_velocity, _, _gamma_ref = _compute_velocity_delta(
        uns_layer, spl_layer, target_mask, reference_mask, prior_weight
    )

    total_counts = np.asarray(total_us_for_filter)
    expressed_mask = total_counts >= min_total_counts

    # Shared bias-correction routine, fed with the per-gene annotations of
    # the (unshuffled) input object.
    var_table = adata_subset.var
    gene_length = var_table["gene_length"].values
    intron_number = var_table["intron_number"].values

    residual, _bias_info = _fit_huber_bias_correction(
        delta_velocity,
        gene_length,
        intron_number,
        total_counts,
        valid_feat,
        expressed_mask,
        X_features,
        bias_correction=bias_correction,
    )

    fc_component = _soft_scale(perm_de_df["logFC"].values, lambda_fc)
    residual_component = _soft_scale(residual, lambda_res)
    # 1e-300 keeps log10 finite when an adjusted p-value is exactly zero.
    neg_log_padj = -np.log10(perm_de_df["p_adj"].values + 1e-300)
    pval_component = _soft_scale(neg_log_padj, lambda_pval)

    weight_total = weight_fc + weight_unspliced + weight_pval
    weighted_sum = (
        weight_fc * fc_component
        + weight_unspliced * residual_component
        + weight_pval * pval_component
    )
    return weighted_sum / weight_total * 100.0
136
+
137
+
138
def run_permutation_test(
    *,
    n_perm: int,
    effective_n_jobs: int,
    random_seed: int,
    obs_labels: np.ndarray,
    target_group: str,
    reference_group: str,
    adata: Any,
    X_features: np.ndarray | None,
    valid_feat: np.ndarray,
    velocity_layer_for_perm_uns: Any,
    velocity_layer_for_perm_spl: Any,
    total_us_raw: np.ndarray,
    min_total_counts: int,
    weight_fc: float,
    weight_unspliced: float,
    weight_pval: float,
    lambda_fc: float,
    lambda_res: float,
    lambda_pval: float,
    is_pseudobulk: bool,
    perm_pb_backend: str,
    perm_de_method: str,
    prior_weight: float,
    de_preprocess: str,
    strict_pydeseq2_counts: bool,
    real_score: np.ndarray,
    bias_correction: str = "huber_length_intron",
    # Memento forwarding for advanced consistent permutation
    use_memento_de: bool = False,
    memento_capture_rate: float = 0.07,
    memento_num_boot: int = 5000,
    memento_n_cpus: int = -1,
) -> tuple:
    """Run ``n_perm`` permutation replicates in parallel and derive p-values.

    Each replicate is one call to ``_single_permutation_task`` seeded with
    ``random_seed + k`` for ``k`` in ``range(n_perm)``; ``bias_correction``
    and the memento options are forwarded to every replicate.

    Returns
    -------
    tuple
        ``(pvals, fdr, use_fdr, disabled_reason)``. ``pvals`` is
        ``(1 + #{perm >= real}) / (n_perm + 1)`` per gene; ``fdr`` is the
        Benjamini-Hochberg adjustment over genes flagged in
        ``adata.var["valid_expr"]`` (all genes when that column is absent) and
        1.0 elsewhere. ``use_fdr`` is always ``True`` and ``disabled_reason``
        always ``None`` here.
    """
    logger.info("Running parallel permutation testing (%d iterations)...", n_perm)

    # Arguments shared by every replicate, keyed by the worker's own
    # parameter names so the renames (e.g. perm_pb_backend -> pb_backend)
    # are explicit.
    shared_task_kwargs = {
        "original_labels": obs_labels,
        "target_group": target_group,
        "reference_group": reference_group,
        "adata_subset": adata,
        "X_features": X_features,
        "valid_feat": valid_feat,
        "uns_layer": velocity_layer_for_perm_uns,
        "spl_layer": velocity_layer_for_perm_spl,
        "total_us_for_filter": total_us_raw,
        "min_total_counts": min_total_counts,
        "weight_fc": weight_fc,
        "weight_unspliced": weight_unspliced,
        "weight_pval": weight_pval,
        "lambda_fc": lambda_fc,
        "lambda_res": lambda_res,
        "lambda_pval": lambda_pval,
        "is_pseudobulk": is_pseudobulk,
        "pb_backend": perm_pb_backend,
        "de_method": perm_de_method,
        "prior_weight": prior_weight,
        "de_preprocess": de_preprocess,
        "strict_pydeseq2_counts": strict_pydeseq2_counts,
        "bias_correction": bias_correction,
        "use_memento_de": use_memento_de,
        "memento_capture_rate": memento_capture_rate,
        "memento_num_boot": memento_num_boot,
        "memento_n_cpus": memento_n_cpus,
    }
    schedule_replicate = delayed(_single_permutation_task)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        worker_pool = Parallel(n_jobs=effective_n_jobs, backend="loky")
        perm_results = worker_pool(
            schedule_replicate(replicate + random_seed, **shared_task_kwargs)
            for replicate in range(n_perm)
        )

    # One row per replicate, one column per gene.
    perm_scores_matrix = np.vstack(perm_results)
    observed_row = real_score.reshape(1, -1)
    exceed_count = np.sum(perm_scores_matrix >= observed_row, axis=0)
    # The +1 terms count the observed score as one of the permutations.
    pvals = (1.0 + exceed_count) / (n_perm + 1.0)

    fdr = np.ones(adata.n_vars)
    valid_expr = adata.var.get("valid_expr", np.ones(adata.n_vars, dtype=bool))
    if valid_expr.sum() > 0:
        adjusted = multipletests(pvals[valid_expr], method="fdr_bh")[1]
        fdr[valid_expr] = adjusted

    # The caller (active_score) still computes current_max_perm in the same way.
    # We only return the arrays; the small-n decision stays in the orchestrator for clarity.
    use_fdr = True
    disabled_reason = None
    return pvals, fdr, use_fdr, disabled_reason
scatrans/_utils.py ADDED
@@ -0,0 +1,343 @@
1
+ """
2
+ scATrans internal utilities (not part of public API).
3
+
4
+ Small, pure or near-pure helper functions extracted from the original tl.py
5
+ to keep the core active_score readable and to enable reuse (esp. bias correction).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ from math import comb # re-exported for permutation use
12
+ from typing import Any, Iterable
13
+
14
+ import anndata as ad
15
+ import numpy as np
16
+ import pandas as pd
17
+ from scipy import sparse
18
+ from sklearn.linear_model import HuberRegressor
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ # Re-export for modules that need it without importing math directly
24
+ __all__ = [
25
+ "comb",
26
+ "_is_integer_counts_like",
27
+ "_warn_if_not_integer_counts_matrix",
28
+ "_warn_if_low_counts_matrix",
29
+ "_safe_add_matrices",
30
+ "_normalize_velocity_layers_by_size_factor",
31
+ "_get_group_mean",
32
+ "_get_exponential_scale_lambda",
33
+ "_soft_scale",
34
+ "_pseudobulk_with_layers",
35
+ "_fit_huber_bias_correction",
36
+ ]
37
+
38
+
39
def _is_integer_counts_like(X: Any, max_check: int = 100000) -> bool:
    """Return True when ``X`` looks like non-negative integer counts.

    Parameters
    ----------
    X
        Dense array-like or SciPy sparse matrix. For sparse input only the
        stored entries (``X.data``) are inspected.
    max_check
        Upper bound on the number of values tested for sign and integrality.
        Larger inputs are subsampled without replacement using a generator
        with the fixed seed 0, so the answer is deterministic.

    Returns
    -------
    bool
        ``True`` for empty input; ``False`` if any inspected value is
        non-finite, negative, or not (numerically close to) an integer.
        Always a plain Python ``bool``.
    """
    vals = X.data if sparse.issparse(X) else np.asarray(X).ravel()

    if vals.size == 0:
        return True
    # The finiteness test runs on all values, before any subsampling.
    if not np.all(np.isfinite(vals)):
        return False

    if vals.size > max_check:
        rng = np.random.default_rng(0)
        vals = rng.choice(vals, size=max_check, replace=False)

    # bool(...) so callers never receive a numpy bool from the short-circuit.
    return bool(np.all(vals >= 0) and np.allclose(vals, np.round(vals)))
61
+
62
+
63
def _warn_if_not_integer_counts_matrix(X: Any, max_check: int = 100000) -> None:
    """Log a warning when ``X`` fails the ``_is_integer_counts_like`` check.

    ``max_check`` is forwarded to ``_is_integer_counts_like`` unchanged.
    Nothing is raised and nothing is returned.
    """
    looks_like_counts = _is_integer_counts_like(X, max_check=max_check)
    if looks_like_counts:
        return
    logger.warning(
        "Data passed to PyDESeq2 may not be raw non-negative integer counts. "
        "Please ensure the input contains unnormalized counts."
    )
69
+
70
+
71
def _warn_if_low_counts_matrix(X: Any, max_check: int = 100000) -> None:
    """Log a warning when the largest finite value in ``X`` is below 30.

    Parameters
    ----------
    X
        Dense array-like or SciPy sparse matrix. For sparse input only the
        stored entries (``X.data``) are inspected.
    max_check
        Accepted for backward compatibility and no longer used. The maximum
        is taken over every finite value: a single pass is no more expensive
        than drawing a subsample, and a subsample can miss the true maximum
        and produce a spurious warning.

    Non-finite entries are ignored; input with no finite entries produces no
    warning. Nothing is raised and nothing is returned.
    """
    vals = X.data if sparse.issparse(X) else np.asarray(X).ravel()

    vals = vals[np.isfinite(vals)]
    if vals.size == 0:
        return

    if vals.max() < 30:
        logger.warning(
            "Maximum count passed to PyDESeq2 is <30. This may be valid for small datasets, "
            "but please verify that the matrix contains raw counts, not normalized/log-transformed values."
        )
87
+
88
+
89
def _safe_add_matrices(a: Any, b: Any) -> Any:
    """Add two matrices element-wise, tolerating a sparse/dense mix.

    If either operand is a SciPy sparse matrix, both are converted to CSR and
    the sum is returned as a sparse matrix; otherwise both are converted with
    ``np.asarray`` and a dense array is returned.
    """
    any_sparse = sparse.issparse(a) or sparse.issparse(b)
    if not any_sparse:
        return np.asarray(a) + np.asarray(b)
    left, right = sparse.csr_matrix(a), sparse.csr_matrix(b)
    return left + right
93
+
94
+
95
def _normalize_velocity_layers_by_size_factor(
    uns_layer: Any, spl_layer: Any, target_sum: float | None = None
) -> tuple[Any, Any, np.ndarray, np.ndarray]:
    """Scale both layers row-wise by a shared per-row size factor.

    The per-row total is the row sum of ``uns_layer + spl_layer``. Each row of
    both layers is multiplied by ``target_sum / max(row_total, 1e-8)``; when
    ``target_sum`` is ``None`` the median of the strictly positive row totals
    is used.

    Returns
    -------
    tuple
        ``(scaled_uns, scaled_spl, row_totals, factors)``. If no row has a
        positive total the layers are returned untouched with all-ones
        factors. The scaled layers are sparse when either input is sparse,
        dense arrays otherwise.
    """
    combined = _safe_add_matrices(uns_layer, spl_layer)
    row_totals = np.asarray(combined.sum(axis=1)).ravel()
    has_counts = row_totals > 0

    if not has_counts.any():
        unit_factors = np.ones_like(row_totals, dtype=float)
        return uns_layer, spl_layer, row_totals, unit_factors

    if target_sum is None:
        target_sum = np.median(row_totals[has_counts])

    # The 1e-8 floor avoids division by zero for rows with no counts.
    factors = target_sum / np.maximum(row_totals, 1e-8)

    if not (sparse.issparse(uns_layer) or sparse.issparse(spl_layer)):
        per_row = factors[:, None]
        scaled_uns = np.asarray(uns_layer) * per_row
        scaled_spl = np.asarray(spl_layer) * per_row
        return scaled_uns, scaled_spl, row_totals, factors

    # Left-multiplying by a diagonal matrix scales rows without densifying.
    row_scaler = sparse.diags(factors)
    uns_csr = sparse.csr_matrix(uns_layer)
    spl_csr = sparse.csr_matrix(spl_layer)
    return row_scaler @ uns_csr, row_scaler @ spl_csr, row_totals, factors
122
+
123
+
124
def _get_group_mean(matrix: Any, mask: np.ndarray) -> np.ndarray:
    """Return the column-wise mean of the rows of ``matrix`` selected by ``mask``.

    Works for dense arrays and SciPy sparse matrices alike; the result is
    always a flat 1-D array.

    Raises
    ------
    ValueError
        If ``mask`` sums to zero, i.e. it selects no rows.
    """
    if np.sum(mask) == 0:
        raise ValueError("Cannot compute group mean for an empty group.")
    # np.asarray(...).ravel() flattens both the ndarray result of a dense
    # mean and the 2-D matrix result of a sparse mean.
    column_means = matrix[mask].mean(axis=0)
    return np.asarray(column_means).ravel()
131
+
132
+
133
def _get_exponential_scale_lambda(x: np.ndarray) -> float:
    """Derive the scale ``median / ln(2)`` from the strictly positive values of ``x``.

    NaN and infinite entries are treated as 0 and therefore ignored, as are
    negative entries. Returns the floor value ``1e-8`` when fewer than two
    positive values remain.
    """
    floor = 1e-8
    cleaned = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    positives = cleaned[cleaned > 0]
    if len(positives) < 2:
        return floor
    med = np.median(positives)
    if med > 0:
        return med / np.log(2.0)
    return floor
141
+
142
+
143
def _soft_scale(x: np.ndarray, lam: float) -> np.ndarray:
    """Map ``x`` into ``[0, 1)`` with ``1 - exp(-max(x, 0) / lam)``.

    NaN and infinite entries are replaced by 0 first, and negative entries
    are clipped to 0, so all of them map to 0. When ``lam <= 1e-8`` the
    scale is considered degenerate and an all-zero array is returned.
    """
    cleaned = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
    if lam <= 1e-8:
        return np.zeros_like(cleaned)
    non_negative = np.clip(cleaned, 0.0, None)
    decay = np.exp(-non_negative / lam)
    return 1.0 - decay
149
+
150
+
151
def _pseudobulk_with_layers(
    adata: ad.AnnData,
    sample_col: str,
    groupby: str,
    layers: Iterable[str] = ("spliced", "unspliced"),
    x_layer: str | None = None,
    use_total_for_x: bool = False,
    min_cells: int = 10,
    min_counts: int = 1000,
) -> ad.AnnData:
    """Aggregate to pseudobulk while preserving the requested layers.

    Cells are grouped by the pair ``(sample_col, groupby)`` (both compared as
    strings) and summed per group, in order of first appearance.

    Parameters
    ----------
    adata
        Input object; ``adata.var`` is copied onto the result.
    sample_col, groupby
        Columns of ``adata.obs`` defining the groups.
    layers
        Layers to sum alongside ``X``. Any iterable is accepted; it is
        materialised once, so one-shot iterators work.
    x_layer
        Layer used as the source of ``X``; ``adata.X`` when ``None``.
        Ignored when ``use_total_for_x`` is true.
    use_total_for_x
        Use ``layers["spliced"] + layers["unspliced"]`` as the source of
        ``X`` (these two layers are read regardless of ``layers``).
    min_cells
        Groups with fewer cells are dropped.
    min_counts
        Groups whose summed ``X`` total is below this are dropped.

    Returns
    -------
    AnnData
        One row per retained group, with CSR ``X`` and layers. ``obs`` holds
        the two grouping columns plus ``n_cells``, ``total_counts`` and
        ``pb_x_source``; the index is ``"<sample>_<group>"``, made unique.

    Raises
    ------
    ValueError
        If a column or layer is missing, or no group survives filtering.
    """
    # ``layers`` is walked several times below; a generator would be
    # exhausted after the validation loop and silently yield no layers.
    layers = tuple(layers)

    if sample_col not in adata.obs.columns:
        raise ValueError(f"sample_col='{sample_col}' not found.")
    if groupby not in adata.obs.columns:
        raise ValueError(f"groupby='{groupby}' not found.")
    for layer in layers:
        if layer not in adata.layers:
            raise ValueError(f"Layer '{layer}' not found in adata.layers")

    if use_total_for_x:
        X_source = _safe_add_matrices(adata.layers["spliced"], adata.layers["unspliced"])
        x_source_name = "spliced + unspliced"
    else:
        if x_layer is not None and x_layer not in adata.layers:
            raise ValueError(f"x_layer '{x_layer}' not found in adata.layers")
        X_source = adata.X if x_layer is None else adata.layers[x_layer]
        x_source_name = "adata.X" if x_layer is None else f"layer '{x_layer}'"

    sample_ids = adata.obs[sample_col].astype(str).to_numpy()
    group_values = adata.obs[groupby].astype(str).to_numpy()

    # Group on the (sample, group) pair itself rather than on a joined
    # string: a joined "sample||group" key merges or mis-splits groups when
    # a label happens to contain the separator. Codes are assigned in order
    # of first appearance.
    sample_codes, _ = pd.factorize(sample_ids)
    group_codes, group_levels = pd.factorize(group_values)
    pair_codes, pair_levels = pd.factorize(sample_codes * len(group_levels) + group_codes)

    X_rows, obs_rows = [], []
    layer_rows: dict[str, list] = {layer: [] for layer in layers}

    for code in range(len(pair_levels)):
        mask = pair_codes == code
        n_cells = int(mask.sum())
        if n_cells < min_cells:
            continue
        x_sum = np.nan_to_num(np.asarray(X_source[mask].sum(axis=0)).ravel())
        total_counts = float(x_sum.sum())
        if total_counts < min_counts:
            continue

        # Every cell of the group shares the same labels; read them from the
        # first member.
        first_cell = int(np.argmax(mask))
        X_rows.append(sparse.csr_matrix(x_sum.reshape(1, -1)))
        obs_rows.append(
            {
                sample_col: sample_ids[first_cell],
                groupby: group_values[first_cell],
                "n_cells": n_cells,
                "total_counts": total_counts,
                "pb_x_source": x_source_name,
            }
        )
        for layer in layers:
            l_sum = np.nan_to_num(np.asarray(adata.layers[layer][mask].sum(axis=0)).ravel())
            layer_rows[layer].append(sparse.csr_matrix(l_sum.reshape(1, -1)))

    if not X_rows:
        raise ValueError("No samples remained after pseudobulk filtering.")

    adata_pb = ad.AnnData(
        X=sparse.vstack(X_rows).tocsr(),
        obs=pd.DataFrame(obs_rows),
        var=adata.var.copy(),
    )
    adata_pb.obs.index = (
        adata_pb.obs[sample_col].astype(str) + "_" + adata_pb.obs[groupby].astype(str)
    )
    for layer in layers:
        adata_pb.layers[layer] = sparse.vstack(layer_rows[layer]).tocsr()
    adata_pb.obs_names_make_unique()
    return adata_pb
227
+
228
+
229
def _is_bias_correction_enabled(val: Any) -> bool:
    """Tell whether a ``bias_correction`` setting leaves correction switched on.

    Only ``None``, the literal ``False``, and the strings "none", "off",
    "no", "false", "disable" and "" (case-insensitive, surrounding
    whitespace ignored) switch it off; every other value counts as enabled.
    """
    disabling_words = ("none", "off", "no", "false", "disable", "")
    if val is None or val is False:
        return False
    if not isinstance(val, str):
        return True
    return val.strip().lower() not in disabling_words
238
+
239
+
240
def _fit_huber_bias_correction(
    delta_velocity: np.ndarray,
    gene_length: np.ndarray,
    intron_number: np.ndarray,
    total_us_for_weights: np.ndarray,
    valid_feat: np.ndarray,
    valid_expr: np.ndarray,
    X_features: np.ndarray | None,
    min_fit_obs: int = 30,
    huber_epsilon: float = 1.35,
    huber_max_iter: int = 500,
    bias_correction: str = "huber_length_intron",
) -> tuple[np.ndarray, dict[str, Any]]:
    """
    Shared Huber regression bias correction (or median fallback).

    Used by both the main analysis path and the permutation tasks so the
    correction logic stays in one place (DRY).

    bias_correction controls behavior:
    - "huber_length_intron" (default), "huber", "yes", "on": perform the
      length+intron Huber correction (with median fallback if regression
      cannot be fit).
    - "none", "off", False, None: disable correction entirely; residual is
      the raw delta_velocity (no subtraction of fit or median). This keeps
      the basic analysis clean for users who do not want the correction.

    The regression is fit on genes that are both ``valid_feat`` and
    ``valid_expr``, on ``log1p(gene_length)`` and ``log1p(intron_number)``,
    weighted by ``total_us_for_weights`` capped at its 95th percentile. It is
    attempted only when ``X_features`` is given and at least ``min_fit_obs``
    genes qualify. Predictions are made on ``X_features`` and subtracted from
    ``delta_velocity[valid_feat]``.
    NOTE(review): this assumes ``X_features`` has one row per ``valid_feat``
    gene with the same two log1p columns -- confirm against the caller.

    Any exception during fitting is logged and triggers the median fallback:
    the median of ``delta_velocity[valid_expr]`` is subtracted instead.
    Entries outside ``valid_expr`` are always 0 in the result.

    Returns (residual, bias_info_dict) where bias_info contains:
    - "bias_corrected": bool
    - "method": the effective method used ("huber_length_intron" or "none")
    - "n_genes_used_for_fit": int
    - "fallback_to_median": bool
    - "coef_gene_length", "coef_intron_number" (NaN unless regression succeeded)
    - "intercept" (NaN unless available)
    """
    # Imported here so the function does not depend on where (or whether)
    # the module imports ``warnings`` at top level.
    import warnings

    enabled = _is_bias_correction_enabled(bias_correction)
    bias_info: dict[str, Any] = {
        "bias_corrected": False,
        # Report what is actually done, not the raw user setting: a disabled
        # run must not be labelled "huber_length_intron" (or "False").
        "method": "huber_length_intron" if enabled else "none",
        "n_genes_used_for_fit": 0,
        "fallback_to_median": False,
        "coef_gene_length": np.nan,
        "coef_intron_number": np.nan,
        "intercept": np.nan,
    }

    if not enabled:
        # No correction at all: residual == raw delta (zeroed for invalid expr)
        residual = np.array(delta_velocity, dtype=float, copy=True)
        residual[~valid_expr] = 0.0
        return residual, bias_info

    residual = np.zeros_like(delta_velocity, dtype=float)
    fit_mask = valid_feat & valid_expr
    n_fit = int(fit_mask.sum())
    bias_info["n_genes_used_for_fit"] = n_fit

    regression_succeeded = False
    if X_features is not None and n_fit >= min_fit_obs:
        try:
            X_fit = np.column_stack(
                [
                    np.log1p(gene_length[fit_mask]),
                    np.log1p(intron_number[fit_mask]),
                ]
            )
            # Cap weights at the 95th percentile so a few very highly
            # expressed genes cannot dominate the fit.
            fit_totals = total_us_for_weights[fit_mask]
            weights = np.clip(fit_totals, a_min=None, a_max=np.percentile(fit_totals, 95))
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model = HuberRegressor(epsilon=huber_epsilon, max_iter=huber_max_iter).fit(
                    X_fit, delta_velocity[fit_mask], sample_weight=weights
                )
            pred = model.predict(X_features)
            residual[valid_feat] = delta_velocity[valid_feat] - pred
            regression_succeeded = True
            bias_info["bias_corrected"] = True
            if hasattr(model, "coef_") and len(model.coef_) >= 2:
                bias_info["coef_gene_length"] = float(model.coef_[0])
                bias_info["coef_intron_number"] = float(model.coef_[1])
            if hasattr(model, "intercept_"):
                bias_info["intercept"] = float(model.intercept_)
        except Exception as e:
            # Deliberately broad: any fitting problem degrades to the median
            # fallback below instead of aborting the analysis.
            logger.warning("Bias correction failed. Falling back to median. Reason: %s", e)

    if not regression_succeeded and valid_expr.sum() > 0:
        residual[valid_expr] = delta_velocity[valid_expr] - np.nanmedian(delta_velocity[valid_expr])
        bias_info["fallback_to_median"] = True
        bias_info["bias_corrected"] = True  # median correction still applied

    residual[~valid_expr] = 0.0
    return residual, bias_info
340
+
341
+
342
+ # warnings is used inside _fit_huber_bias_correction
343
+ import warnings # noqa: E402 (executed at import time)