scatrans 0.7.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scatrans/__init__.py +56 -0
- scatrans/_bias.py +24 -0
- scatrans/_de.py +555 -0
- scatrans/_permutation.py +229 -0
- scatrans/_utils.py +343 -0
- scatrans/_velocity.py +152 -0
- scatrans/_version.py +24 -0
- scatrans/data/Hs_KEGG_2026.txt +223 -0
- scatrans/data/Mm_KEGG_2026.txt +219 -0
- scatrans/data/Mus_musculus.GRCm39.115_gene_features.parquet +0 -0
- scatrans/data/README.md +94 -0
- scatrans/data/mouse_2020A_gene_features.parquet +0 -0
- scatrans/enrich.py +735 -0
- scatrans/generate_gene_features.py +99 -0
- scatrans/pl.py +1168 -0
- scatrans/pp_bias.py +261 -0
- scatrans/qc.py +41 -0
- scatrans/tl.py +1618 -0
- scatrans-0.7.0.dev0.dist-info/METADATA +826 -0
- scatrans-0.7.0.dev0.dist-info/RECORD +24 -0
- scatrans-0.7.0.dev0.dist-info/WHEEL +5 -0
- scatrans-0.7.0.dev0.dist-info/entry_points.txt +2 -0
- scatrans-0.7.0.dev0.dist-info/licenses/LICENSE +21 -0
- scatrans-0.7.0.dev0.dist-info/top_level.txt +1 -0
scatrans/_permutation.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Permutation testing support for active_score significance.
|
|
3
|
+
|
|
4
|
+
The per-replicate worker (_single_permutation_task) and its orchestration
(run_permutation_test) were among the largest contributors to tl.py's line
count, so they were extracted into this module.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import warnings
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import scanpy as sc
|
|
16
|
+
from joblib import Parallel, delayed
|
|
17
|
+
from statsmodels.stats.multitest import multipletests
|
|
18
|
+
|
|
19
|
+
from ._de import _run_de_wrapper
|
|
20
|
+
|
|
21
|
+
# local import to avoid circulars at module load
|
|
22
|
+
from ._utils import (
|
|
23
|
+
_fit_huber_bias_correction,
|
|
24
|
+
_soft_scale,
|
|
25
|
+
)
|
|
26
|
+
from ._velocity import _compute_velocity_delta
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _single_permutation_task(
    seed: int,
    original_labels: np.ndarray,
    target_group: str,
    reference_group: str,
    adata_subset: Any,
    X_features: np.ndarray | None,
    valid_feat: np.ndarray,
    uns_layer: Any,
    spl_layer: Any,
    total_us_for_filter: np.ndarray,
    min_total_counts: int,
    weight_fc: float,
    weight_unspliced: float,
    weight_pval: float,
    lambda_fc: float,
    lambda_res: float,
    lambda_pval: float,
    is_pseudobulk: bool,
    pb_backend: str,
    de_method: str,
    prior_weight: float,
    de_preprocess: str,
    strict_pydeseq2_counts: bool,
    bias_correction: str = "huber_length_intron",
    # Memento support for permutation (advanced, usually False for speed)
    use_memento_de: bool = False,
    memento_capture_rate: float = 0.07,
    memento_num_boot: int = 5000,
    memento_n_cpus: int = -1,
) -> np.ndarray:
    """Compute the active score vector for one label-shuffled replicate.

    Steps:

    1. Shuffle ``original_labels`` with a generator seeded by ``seed``. Up to 50
       draws are tried to obtain an ordering that differs from the original; a
       warning is logged if every draw matched.
    2. Optionally normalize + log1p a *copy* of ``adata_subset`` (controlled by
       ``de_preprocess``) and run differential expression on the shuffled labels
       through ``_run_de_wrapper``.
    3. Recompute the velocity delta between shuffled target and reference cells
       and pass it through ``_fit_huber_bias_correction`` with the same
       ``bias_correction`` setting as the real run, so permuted scores stay
       comparable to the observed ones.
    4. Soft-scale logFC, the bias-corrected residual and -log10(p_adj), and
       return their weighted average multiplied by 100.

    Returns
    -------
    np.ndarray
        Per-gene active score for this shuffle.
    """
    max_attempts = 50
    rng = np.random.default_rng(seed)

    # Draw until the shuffle differs from the original ordering (bounded retries).
    shuffled_labels = rng.permutation(original_labels)
    attempts = 1
    while np.array_equal(shuffled_labels, original_labels):
        if attempts >= max_attempts:
            logger.warning("Failed to generate a different permutation after 50 attempts.")
            break
        shuffled_labels = rng.permutation(original_labels)
        attempts += 1

    # Work on a copy so the caller's AnnData is never modified by preprocessing.
    ad_temp = adata_subset.copy()

    # "normalize_log1p" always normalizes; "auto" normalizes unless PyDESeq2
    # pseudobulk is used (it needs raw counts) or the data is already log1p'd;
    # "none" (and any other value) leaves the copy untouched.
    uses_pydeseq2_pseudobulk = is_pseudobulk and pb_backend == "pydeseq2"
    needs_normalization = de_preprocess == "normalize_log1p" or (
        de_preprocess == "auto"
        and not uses_pydeseq2_pseudobulk
        and "log1p" not in ad_temp.uns
    )
    if needs_normalization:
        sc.pp.normalize_total(ad_temp, target_sum=1e4)
        sc.pp.log1p(ad_temp)

    perm_de_df = _run_de_wrapper(
        ad_temp,
        groupby="_unused_when_labels_provided",
        target_group=target_group,
        reference_group=reference_group,
        de_method=de_method,
        is_pseudobulk=is_pseudobulk,
        pb_backend=pb_backend,
        n_jobs=1,
        labels=shuffled_labels,
        strict_pydeseq2_counts=strict_pydeseq2_counts,
        use_memento_de=use_memento_de,
        memento_capture_rate=memento_capture_rate,
        memento_num_boot=memento_num_boot,
        memento_n_cpus=memento_n_cpus,
    )

    target_mask = shuffled_labels == target_group
    reference_mask = shuffled_labels == reference_group
    delta_velocity, _unused_second, _unused_gamma_ref = _compute_velocity_delta(
        uns_layer, spl_layer, target_mask, reference_mask, prior_weight
    )

    expr_totals = np.asarray(total_us_for_filter)
    valid_expr = expr_totals >= min_total_counts

    # Shared bias correction, honouring the user's bias_correction setting.
    gene_var = adata_subset.var
    residual, _unused_bias_info = _fit_huber_bias_correction(
        delta_velocity,
        gene_var["gene_length"].values,
        gene_var["intron_number"].values,
        expr_totals,
        valid_feat,
        valid_expr,
        X_features,
        bias_correction=bias_correction,
    )

    fc_component = _soft_scale(perm_de_df["logFC"].values, lambda_fc)
    residual_component = _soft_scale(residual, lambda_res)
    pval_component = _soft_scale(
        -np.log10(perm_de_df["p_adj"].values + 1e-300), lambda_pval
    )

    weighted_sum = (
        weight_fc * fc_component
        + weight_unspliced * residual_component
        + weight_pval * pval_component
    )
    return weighted_sum / (weight_fc + weight_unspliced + weight_pval) * 100.0
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _quiet_single_permutation_task(*args: Any, **kwargs: Any) -> np.ndarray:
    """Run ``_single_permutation_task`` with all warnings silenced in the executing process.

    Warning filters are per-process state. A ``warnings.catch_warnings()`` block
    in the parent therefore does not reach loky worker processes. Applying the
    filter here makes suppression effective for both sequential and parallel runs.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return _single_permutation_task(*args, **kwargs)


def run_permutation_test(
    *,
    n_perm: int,
    effective_n_jobs: int,
    random_seed: int,
    obs_labels: np.ndarray,
    target_group: str,
    reference_group: str,
    adata: Any,
    X_features: np.ndarray | None,
    valid_feat: np.ndarray,
    velocity_layer_for_perm_uns: Any,
    velocity_layer_for_perm_spl: Any,
    total_us_raw: np.ndarray,
    min_total_counts: int,
    weight_fc: float,
    weight_unspliced: float,
    weight_pval: float,
    lambda_fc: float,
    lambda_res: float,
    lambda_pval: float,
    is_pseudobulk: bool,
    perm_pb_backend: str,
    perm_de_method: str,
    prior_weight: float,
    de_preprocess: str,
    strict_pydeseq2_counts: bool,
    real_score: np.ndarray,
    bias_correction: str = "huber_length_intron",
    # Memento forwarding for advanced consistent permutation
    use_memento_de: bool = False,
    memento_capture_rate: float = 0.07,
    memento_num_boot: int = 5000,
    memento_n_cpus: int = -1,
) -> tuple:
    """Run the full parallel permutation test and return (pvals, fdr, use_fdr_for_significance, reason_if_disabled).

    Each replicate ``i`` is computed by ``_single_permutation_task`` with seed
    ``random_seed + i``. ``bias_correction`` and the Memento settings are passed
    through so permutations match the user's chosen setting.

    Empirical p-values use the +1 correction:
    ``(1 + #{perm >= observed}) / (n_perm + 1)``. A gene whose observed score is
    NaN gets p = 1, because NaN comparisons are always False and would otherwise
    yield the smallest achievable p-value.

    Benjamini-Hochberg FDR is computed over the genes flagged by
    ``adata.var["valid_expr"]`` (all genes if that column is absent); other
    genes get FDR 1.

    Returns
    -------
    tuple
        ``(pvals, fdr, True, None)``. The small-n decision stays in the caller.

    Raises
    ------
    ValueError
        If ``n_perm`` is smaller than 1.
    """
    if n_perm < 1:
        raise ValueError(f"n_perm must be a positive integer, got {n_perm!r}.")

    logger.info("Running parallel permutation testing (%d iterations)...", n_perm)

    # Silences warnings raised in this process (e.g. by joblib itself); worker
    # processes are silenced by _quiet_single_permutation_task.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        perm_results = Parallel(n_jobs=effective_n_jobs, backend="loky")(
            delayed(_quiet_single_permutation_task)(
                i + random_seed,
                obs_labels,
                target_group,
                reference_group,
                adata,
                X_features,
                valid_feat,
                velocity_layer_for_perm_uns,
                velocity_layer_for_perm_spl,
                total_us_raw,
                min_total_counts,
                weight_fc,
                weight_unspliced,
                weight_pval,
                lambda_fc,
                lambda_res,
                lambda_pval,
                is_pseudobulk,
                perm_pb_backend,
                perm_de_method,
                prior_weight,
                de_preprocess,
                strict_pydeseq2_counts,
                bias_correction=bias_correction,
                use_memento_de=use_memento_de,
                memento_capture_rate=memento_capture_rate,
                memento_num_boot=memento_num_boot,
                memento_n_cpus=memento_n_cpus,
            )
            for i in range(n_perm)
        )

    perm_scores_matrix = np.vstack(perm_results)  # shape (n_perm, n_genes)
    observed = np.asarray(real_score, dtype=float).reshape(1, -1)
    exceed_count = np.sum(perm_scores_matrix >= observed, axis=0)
    pvals = (1.0 + exceed_count) / (n_perm + 1.0)
    pvals[np.isnan(observed.ravel())] = 1.0

    n_vars = adata.n_vars
    fdr = np.ones(n_vars)
    # Coerce to a plain boolean ndarray so the mask works whatever dtype the
    # stored column has (e.g. a pandas Series, or 0/1 floats).
    valid_expr = np.asarray(
        adata.var.get("valid_expr", np.ones(n_vars, dtype=bool)), dtype=bool
    )
    if valid_expr.any():
        fdr[valid_expr] = multipletests(pvals[valid_expr], method="fdr_bh")[1]

    use_fdr = True
    disabled_reason = None

    # The caller (active_score) still computes current_max_perm in the same way.
    # We only return the arrays; the small-n decision stays in the orchestrator for clarity.
    return pvals, fdr, use_fdr, disabled_reason
|
scatrans/_utils.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
"""
|
|
2
|
+
scATrans internal utilities (not part of public API).
|
|
3
|
+
|
|
4
|
+
Small, pure or near-pure helper functions extracted from the original tl.py
|
|
5
|
+
to keep the core active_score readable and to enable reuse (esp. bias correction).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from math import comb # re-exported for permutation use
|
|
12
|
+
from typing import Any, Iterable
|
|
13
|
+
|
|
14
|
+
import anndata as ad
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pandas as pd
|
|
17
|
+
from scipy import sparse
|
|
18
|
+
from sklearn.linear_model import HuberRegressor
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Names exported by ``from ._utils import *``. ``comb`` (from ``math``) is
# re-exported so sibling modules can use it without importing ``math`` directly.
__all__ = [
    # standard-library re-export
    "comb",
    # raw-count sanity checks
    "_is_integer_counts_like",
    "_warn_if_not_integer_counts_matrix",
    "_warn_if_low_counts_matrix",
    # matrix arithmetic / normalization
    "_safe_add_matrices",
    "_normalize_velocity_layers_by_size_factor",
    "_get_group_mean",
    # score scaling
    "_get_exponential_scale_lambda",
    "_soft_scale",
    # pseudobulk aggregation and bias correction
    "_pseudobulk_with_layers",
    "_fit_huber_bias_correction",
]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _is_integer_counts_like(X: Any, max_check: int = 100000) -> bool:
|
|
40
|
+
if sparse.issparse(X):
|
|
41
|
+
data = X.data
|
|
42
|
+
if data.size == 0:
|
|
43
|
+
return True
|
|
44
|
+
if not np.all(np.isfinite(data)):
|
|
45
|
+
return False
|
|
46
|
+
vals = data
|
|
47
|
+
else:
|
|
48
|
+
arr = np.asarray(X)
|
|
49
|
+
if not np.all(np.isfinite(arr)):
|
|
50
|
+
return False
|
|
51
|
+
vals = arr.ravel()
|
|
52
|
+
|
|
53
|
+
if vals.size == 0:
|
|
54
|
+
return True
|
|
55
|
+
|
|
56
|
+
if vals.size > max_check:
|
|
57
|
+
rng = np.random.default_rng(0)
|
|
58
|
+
vals = rng.choice(vals, size=max_check, replace=False)
|
|
59
|
+
|
|
60
|
+
return np.all(vals >= 0) and np.allclose(vals, np.round(vals))
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _warn_if_not_integer_counts_matrix(X: Any, max_check: int = 100000) -> None:
    """Log a warning if ``X`` does not look like raw non-negative integer counts.

    The check is delegated to ``_is_integer_counts_like`` (which inspects at
    most ``max_check`` sampled values). This function only logs; it never raises
    and never modifies ``X``.
    """
    looks_like_counts = _is_integer_counts_like(X, max_check=max_check)
    if looks_like_counts:
        return
    logger.warning(
        "Data passed to PyDESeq2 may not be raw non-negative integer counts. "
        "Please ensure the input contains unnormalized counts."
    )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _warn_if_low_counts_matrix(X: Any, max_check: int = 100000) -> None:
|
|
72
|
+
vals = X.data if sparse.issparse(X) else np.asarray(X).ravel()
|
|
73
|
+
|
|
74
|
+
vals = vals[np.isfinite(vals)]
|
|
75
|
+
if vals.size == 0:
|
|
76
|
+
return
|
|
77
|
+
|
|
78
|
+
if vals.size > max_check:
|
|
79
|
+
rng = np.random.default_rng(0)
|
|
80
|
+
vals = rng.choice(vals, size=max_check, replace=False)
|
|
81
|
+
|
|
82
|
+
if vals.max() < 30:
|
|
83
|
+
logger.warning(
|
|
84
|
+
"Maximum count passed to PyDESeq2 is <30. This may be valid for small datasets, "
|
|
85
|
+
"but please verify that the matrix contains raw counts, not normalized/log-transformed values."
|
|
86
|
+
)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _safe_add_matrices(a: Any, b: Any) -> Any:
|
|
90
|
+
if sparse.issparse(a) or sparse.issparse(b):
|
|
91
|
+
return sparse.csr_matrix(a) + sparse.csr_matrix(b)
|
|
92
|
+
return np.asarray(a) + np.asarray(b)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _normalize_velocity_layers_by_size_factor(
    uns_layer: Any, spl_layer: Any, target_sum: float | None = None
) -> tuple[Any, Any, np.ndarray, np.ndarray]:
    """Scale every row of the unspliced and spliced layers to a common total.

    Row totals are taken from ``unspliced + spliced``. Each row is multiplied
    by ``target_sum / row_total``, with ``row_total`` floored at 1e-8. When
    ``target_sum`` is None it defaults to the median of the positive row totals.

    Returns
    -------
    (scaled_unspliced, scaled_spliced, row_totals, factors)
        If either layer is sparse, both scaled layers are sparse (computed from
        CSR copies). Otherwise they are dense arrays. If no row has a positive
        total, the layers are returned unchanged together with all-ones factors.
    """
    combined = _safe_add_matrices(uns_layer, spl_layer)
    row_totals = np.asarray(combined.sum(axis=1)).ravel()
    has_counts = row_totals > 0

    if not has_counts.any():
        return uns_layer, spl_layer, row_totals, np.ones_like(row_totals, dtype=float)

    scale_to = np.median(row_totals[has_counts]) if target_sum is None else target_sum
    factors = scale_to / np.maximum(row_totals, 1e-8)

    if sparse.issparse(uns_layer) or sparse.issparse(spl_layer):
        # Left-multiplying by a diagonal matrix scales rows without densifying.
        row_scaler = sparse.diags(factors)
        return (
            row_scaler @ sparse.csr_matrix(uns_layer),
            row_scaler @ sparse.csr_matrix(spl_layer),
            row_totals,
            factors,
        )

    per_row = factors[:, None]
    return np.asarray(uns_layer) * per_row, np.asarray(spl_layer) * per_row, row_totals, factors
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _get_group_mean(matrix: Any, mask: np.ndarray) -> np.ndarray:
|
|
125
|
+
if np.sum(mask) == 0:
|
|
126
|
+
raise ValueError("Cannot compute group mean for an empty group.")
|
|
127
|
+
sub = matrix[mask]
|
|
128
|
+
if sparse.issparse(sub):
|
|
129
|
+
return np.asarray(sub.mean(axis=0)).ravel()
|
|
130
|
+
return np.asarray(sub.mean(axis=0)).ravel()
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def _get_exponential_scale_lambda(x: np.ndarray) -> float:
|
|
134
|
+
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
|
|
135
|
+
x_pos = np.clip(x, 0.0, None)
|
|
136
|
+
nonzero_x = x_pos[x_pos > 0]
|
|
137
|
+
if len(nonzero_x) < 2:
|
|
138
|
+
return 1e-8
|
|
139
|
+
med = np.median(nonzero_x)
|
|
140
|
+
return med / np.log(2.0) if med > 0 else 1e-8
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _soft_scale(x: np.ndarray, lam: float) -> np.ndarray:
|
|
144
|
+
x = np.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
|
|
145
|
+
x_pos = np.clip(x, 0.0, None)
|
|
146
|
+
if lam <= 1e-8:
|
|
147
|
+
return np.zeros_like(x)
|
|
148
|
+
return 1.0 - np.exp(-x_pos / lam)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _pseudobulk_with_layers(
    adata: ad.AnnData,
    sample_col: str,
    groupby: str,
    layers: Iterable[str] = ("spliced", "unspliced"),
    x_layer: str | None = None,
    use_total_for_x: bool = False,
    min_cells: int = 10,
    min_counts: int = 1000,
) -> ad.AnnData:
    """Aggregate cells into (sample, group) pseudobulk units, preserving the requested layers.

    Cells are grouped by the string values of ``adata.obs[sample_col]`` and
    ``adata.obs[groupby]``; units appear in order of first occurrence. A unit
    is kept only if it has at least ``min_cells`` cells and its summed X-source
    counts reach ``min_counts``. For each kept unit, the X source and every
    requested layer are summed over its cells, with NaNs treated as 0.

    Parameters
    ----------
    adata
        Per-cell AnnData.
    sample_col, groupby
        ``adata.obs`` columns that define a pseudobulk unit.
    layers
        Layers to aggregate into ``adata_pb.layers``. Any iterable is accepted;
        it is materialised once, so generators work too.
    x_layer
        Layer used as the X source (``None`` means ``adata.X``). Ignored when
        ``use_total_for_x`` is True.
    use_total_for_x
        If True, the X source is ``layers["spliced"] + layers["unspliced"]``.
    min_cells, min_counts
        Per-unit filters.

    Returns
    -------
    AnnData
        One row per kept unit, with CSR ``X`` and layers and a copy of
        ``adata.var``. Obs columns are ``sample_col``, ``groupby``, ``n_cells``,
        ``total_counts`` and ``pb_x_source``. Obs names are
        ``"<sample>_<group>"``, made unique.

    Raises
    ------
    ValueError
        If a column or layer is missing, or no unit survives filtering.
    """
    layers = list(layers)  # iterated several times below

    if sample_col not in adata.obs.columns:
        raise ValueError(f"sample_col='{sample_col}' not found.")
    if groupby not in adata.obs.columns:
        raise ValueError(f"groupby='{groupby}' not found.")
    for layer in layers:
        if layer not in adata.layers:
            raise ValueError(f"Layer '{layer}' not found in adata.layers")

    if use_total_for_x:
        for layer in ("spliced", "unspliced"):
            if layer not in adata.layers:
                raise ValueError(f"Layer '{layer}' not found in adata.layers")
        X_source = _safe_add_matrices(adata.layers["spliced"], adata.layers["unspliced"])
        x_source_name = "spliced + unspliced"
    else:
        if x_layer is not None and x_layer not in adata.layers:
            raise ValueError(f"x_layer '{x_layer}' not found in adata.layers")
        X_source = adata.X if x_layer is None else adata.layers[x_layer]
        x_source_name = "adata.X" if x_layer is None else f"layer '{x_layer}'"

    samples = adata.obs[sample_col].astype(str).to_numpy()
    groups = adata.obs[groupby].astype(str).to_numpy()
    # Integer unit codes in first-appearance order. Joining the two strings
    # with a separator could merge or mis-split units whose names contain it.
    unit_codes = (
        pd.DataFrame({"sample": samples, "group": groups})
        .groupby(["sample", "group"], sort=False)
        .ngroup()
        .to_numpy()
    )
    n_units = int(unit_codes.max()) + 1 if unit_codes.size else 0

    X_rows, obs_rows = [], []
    layer_rows: dict[str, list] = {layer: [] for layer in layers}

    for code in range(n_units):
        mask = unit_codes == code
        n_cells = int(mask.sum())
        if n_cells < min_cells:
            continue
        x_sum = np.nan_to_num(np.asarray(X_source[mask].sum(axis=0)).ravel())
        total_counts = float(x_sum.sum())
        if total_counts < min_counts:
            continue

        first_cell = int(np.argmax(mask))  # any cell of the unit carries its labels
        X_rows.append(sparse.csr_matrix(x_sum.reshape(1, -1)))
        obs_rows.append(
            {
                sample_col: samples[first_cell],
                groupby: groups[first_cell],
                "n_cells": n_cells,
                "total_counts": total_counts,
                "pb_x_source": x_source_name,
            }
        )
        for layer in layers:
            l_sum = np.nan_to_num(np.asarray(adata.layers[layer][mask].sum(axis=0)).ravel())
            layer_rows[layer].append(sparse.csr_matrix(l_sum.reshape(1, -1)))

    if not X_rows:
        raise ValueError("No samples remained after pseudobulk filtering.")

    adata_pb = ad.AnnData(
        X=sparse.vstack(X_rows).tocsr(),
        obs=pd.DataFrame(obs_rows),
        var=adata.var.copy(),
    )
    adata_pb.obs.index = (
        adata_pb.obs[sample_col].astype(str) + "_" + adata_pb.obs[groupby].astype(str)
    )
    for layer in layers:
        adata_pb.layers[layer] = sparse.vstack(layer_rows[layer]).tocsr()
    adata_pb.obs_names_make_unique()
    return adata_pb
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _is_bias_correction_enabled(val: Any) -> bool:
|
|
230
|
+
"""Return True unless the user explicitly disabled bias correction."""
|
|
231
|
+
if val is None or val is False:
|
|
232
|
+
return False
|
|
233
|
+
if isinstance(val, str):
|
|
234
|
+
v = val.strip().lower()
|
|
235
|
+
if v in ("none", "off", "no", "false", "disable", ""):
|
|
236
|
+
return False
|
|
237
|
+
return True
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def _fit_huber_bias_correction(
    delta_velocity: np.ndarray,
    gene_length: np.ndarray,
    intron_number: np.ndarray,
    total_us_for_weights: np.ndarray,
    valid_feat: np.ndarray,
    valid_expr: np.ndarray,
    X_features: np.ndarray | None,
    min_fit_obs: int = 30,
    huber_epsilon: float = 1.35,
    huber_max_iter: int = 500,
    bias_correction: str = "huber_length_intron",
) -> tuple[np.ndarray, dict[str, Any]]:
    """
    Shared Huber regression bias correction (or median fallback).

    Used by both the main analysis path and the permutation tasks so the
    correction logic stays in one place.

    bias_correction controls behavior (see ``_is_bias_correction_enabled``):
      - "huber_length_intron" (default), "huber", "yes", "on", ...: regress
        delta_velocity on log1p(gene_length) and log1p(intron_number) with a
        weighted HuberRegressor. Weights are total counts clipped at their 95th
        percentile. The fit uses genes in ``valid_feat & valid_expr`` and needs
        ``X_features`` and at least ``min_fit_obs`` such genes; the residual
        is ``delta_velocity - prediction`` for ``valid_feat`` genes. If the fit
        cannot be run or raises, the residual is instead delta_velocity minus
        its nan-median over ``valid_expr`` genes.
      - "none", "off", False, None: disable correction entirely; the residual
        is the raw delta_velocity (no subtraction of fit or median).

    In every case the residual of genes outside ``valid_expr`` is set to 0.

    Returns (residual, bias_info) where bias_info contains:
      - "bias_corrected": bool (True also for the median fallback)
      - "method": the effective method, "huber_length_intron" when correction
        is enabled (including the median fallback) or "none" when disabled
      - "n_genes_used_for_fit": int
      - "fallback_to_median": bool
      - "coef_gene_length", "coef_intron_number" (if regression succeeded, else NaN)
      - "intercept" (if regression succeeded, else NaN)
    """
    # Local import: keeps this function independent of module-level import order.
    import warnings

    enabled = _is_bias_correction_enabled(bias_correction)
    residual = np.zeros_like(delta_velocity, dtype=float)
    bias_info: dict[str, Any] = {
        "bias_corrected": False,
        # Report the effective method. str(bias_correction) reported
        # "huber_length_intron" when correction was disabled via None.
        "method": "huber_length_intron" if enabled else "none",
        "n_genes_used_for_fit": 0,
        "fallback_to_median": False,
        "coef_gene_length": np.nan,
        "coef_intron_number": np.nan,
        "intercept": np.nan,
    }

    if not enabled:
        # No correction at all: residual == raw delta (zeroed for invalid expr)
        residual = np.array(delta_velocity, dtype=float, copy=True)
        residual[~valid_expr] = 0.0
        return residual, bias_info

    fit_mask = valid_feat & valid_expr
    n_fit = int(fit_mask.sum())
    bias_info["n_genes_used_for_fit"] = n_fit

    regression_succeeded = False
    if X_features is not None and n_fit >= min_fit_obs:
        try:
            X_fit = np.column_stack(
                [
                    np.log1p(gene_length[fit_mask]),
                    np.log1p(intron_number[fit_mask]),
                ]
            )
            fit_totals = total_us_for_weights[fit_mask]
            # Cap weights at the 95th percentile so a few very deep genes cannot dominate the fit.
            weights = np.clip(fit_totals, None, np.percentile(fit_totals, 95))
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                model = HuberRegressor(epsilon=huber_epsilon, max_iter=huber_max_iter).fit(
                    X_fit, delta_velocity[fit_mask], sample_weight=weights
                )
                # NOTE(review): X_features must have one row per valid_feat gene
                # (enforced by the assignment below). It presumably holds the same
                # log1p(length), log1p(intron) columns as X_fit; confirm with callers.
                pred = model.predict(X_features)
            residual[valid_feat] = delta_velocity[valid_feat] - pred
        except Exception as e:
            logger.warning("Bias correction failed. Falling back to median. Reason: %s", e)
        else:
            regression_succeeded = True
            bias_info["bias_corrected"] = True
            if hasattr(model, "coef_") and len(model.coef_) >= 2:
                bias_info["coef_gene_length"] = float(model.coef_[0])
                bias_info["coef_intron_number"] = float(model.coef_[1])
            if hasattr(model, "intercept_"):
                bias_info["intercept"] = float(model.intercept_)
            # NOTE(review): after a successful fit, genes in valid_expr but not in
            # valid_feat keep residual 0 (they are not median-corrected) — confirm intended.

    if not regression_succeeded and valid_expr.sum() > 0:
        residual[valid_expr] = delta_velocity[valid_expr] - np.nanmedian(delta_velocity[valid_expr])
        bias_info["fallback_to_median"] = True
        bias_info["bias_corrected"] = True  # median correction still applied

    residual[~valid_expr] = 0.0
    return residual, bias_info
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
# warnings is used inside _fit_huber_bias_correction
|
|
343
|
+
import warnings # noqa: E402 (executed at import time)
|