smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. The information it contains is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +6 -8
- smftools/_settings.py +4 -6
- smftools/_version.py +1 -1
- smftools/cli/helpers.py +54 -0
- smftools/cli/hmm_adata.py +937 -256
- smftools/cli/load_adata.py +448 -268
- smftools/cli/preprocess_adata.py +469 -263
- smftools/cli/spatial_adata.py +536 -319
- smftools/cli_entry.py +97 -182
- smftools/config/__init__.py +1 -1
- smftools/config/conversion.yaml +17 -6
- smftools/config/deaminase.yaml +12 -10
- smftools/config/default.yaml +142 -33
- smftools/config/direct.yaml +11 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +594 -264
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +2 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2128 -1418
- smftools/hmm/__init__.py +2 -9
- smftools/hmm/archived/call_hmm_peaks.py +121 -0
- smftools/hmm/call_hmm_peaks.py +299 -91
- smftools/hmm/display_hmm.py +19 -6
- smftools/hmm/hmm_readwrite.py +13 -4
- smftools/hmm/nucleosome_hmm_refinement.py +102 -14
- smftools/informatics/__init__.py +30 -7
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
- smftools/informatics/archived/print_bam_query_seq.py +7 -1
- smftools/informatics/bam_functions.py +397 -175
- smftools/informatics/basecalling.py +51 -9
- smftools/informatics/bed_functions.py +90 -57
- smftools/informatics/binarize_converted_base_identities.py +18 -7
- smftools/informatics/complement_base_list.py +7 -6
- smftools/informatics/converted_BAM_to_adata.py +265 -122
- smftools/informatics/fasta_functions.py +161 -83
- smftools/informatics/h5ad_functions.py +196 -30
- smftools/informatics/modkit_extract_to_adata.py +609 -270
- smftools/informatics/modkit_functions.py +85 -44
- smftools/informatics/ohe.py +44 -21
- smftools/informatics/pod5_functions.py +112 -73
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +2 -7
- smftools/machine_learning/data/anndata_data_module.py +143 -50
- smftools/machine_learning/data/preprocessing.py +2 -1
- smftools/machine_learning/evaluation/__init__.py +1 -1
- smftools/machine_learning/evaluation/eval_utils.py +11 -14
- smftools/machine_learning/evaluation/evaluators.py +46 -33
- smftools/machine_learning/inference/__init__.py +1 -1
- smftools/machine_learning/inference/inference_utils.py +7 -4
- smftools/machine_learning/inference/lightning_inference.py +9 -13
- smftools/machine_learning/inference/sklearn_inference.py +6 -8
- smftools/machine_learning/inference/sliding_window_inference.py +35 -25
- smftools/machine_learning/models/__init__.py +10 -5
- smftools/machine_learning/models/base.py +28 -42
- smftools/machine_learning/models/cnn.py +15 -11
- smftools/machine_learning/models/lightning_base.py +71 -40
- smftools/machine_learning/models/mlp.py +13 -4
- smftools/machine_learning/models/positional.py +3 -2
- smftools/machine_learning/models/rnn.py +3 -2
- smftools/machine_learning/models/sklearn_models.py +39 -22
- smftools/machine_learning/models/transformer.py +68 -53
- smftools/machine_learning/models/wrappers.py +2 -1
- smftools/machine_learning/training/__init__.py +2 -2
- smftools/machine_learning/training/train_lightning_model.py +29 -20
- smftools/machine_learning/training/train_sklearn_model.py +9 -15
- smftools/machine_learning/utils/__init__.py +1 -1
- smftools/machine_learning/utils/device.py +7 -4
- smftools/machine_learning/utils/grl.py +3 -1
- smftools/metadata.py +443 -0
- smftools/plotting/__init__.py +19 -5
- smftools/plotting/autocorrelation_plotting.py +145 -44
- smftools/plotting/classifiers.py +162 -72
- smftools/plotting/general_plotting.py +422 -197
- smftools/plotting/hmm_plotting.py +42 -13
- smftools/plotting/position_stats.py +147 -87
- smftools/plotting/qc_plotting.py +20 -12
- smftools/preprocessing/__init__.py +10 -12
- smftools/preprocessing/append_base_context.py +115 -80
- smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
- smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
- smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +129 -31
- smftools/preprocessing/binary_layers_to_ohe.py +17 -11
- smftools/preprocessing/calculate_complexity_II.py +86 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +50 -25
- smftools/preprocessing/calculate_pairwise_differences.py +2 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
- smftools/preprocessing/calculate_position_Youden.py +118 -54
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
- smftools/preprocessing/flag_duplicate_reads.py +689 -272
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +8 -3
- smftools/preprocessing/min_non_diagonal.py +2 -1
- smftools/preprocessing/recipes.py +56 -23
- smftools/preprocessing/reindex_references_adata.py +103 -0
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +331 -82
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +3 -4
- smftools/tools/archived/classifiers.py +163 -0
- smftools/tools/archived/subset_adata_v1.py +10 -1
- smftools/tools/archived/subset_adata_v2.py +12 -1
- smftools/tools/calculate_umap.py +54 -15
- smftools/tools/cluster_adata_on_methylation.py +115 -46
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +229 -98
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
- smftools-0.2.5.dist-info/RECORD +181 -0
- smftools-0.2.3.dist-info/RECORD +0 -173
- /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
- /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
- /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
- /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
- /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
- /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
- /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,24 +1,28 @@
|
|
|
1
1
|
# duplicate_detection_with_hier_and_plots.py
|
|
2
2
|
import copy
|
|
3
|
-
import warnings
|
|
4
3
|
import math
|
|
5
4
|
import os
|
|
5
|
+
import warnings
|
|
6
6
|
from collections import defaultdict
|
|
7
|
-
from typing import Dict,
|
|
7
|
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
8
8
|
|
|
9
|
-
import torch
|
|
10
9
|
import anndata as ad
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
11
|
import numpy as np
|
|
12
12
|
import pandas as pd
|
|
13
|
-
import
|
|
14
|
-
|
|
13
|
+
import torch
|
|
14
|
+
|
|
15
|
+
from smftools.logging_utils import get_logger
|
|
15
16
|
|
|
16
17
|
from ..readwrite import make_dirs
|
|
17
18
|
|
|
19
|
+
logger = get_logger(__name__)
|
|
20
|
+
|
|
18
21
|
# optional imports for clustering / PCA / KDE
|
|
19
22
|
try:
|
|
20
23
|
from scipy.cluster import hierarchy as sch
|
|
21
24
|
from scipy.spatial.distance import pdist, squareform
|
|
25
|
+
|
|
22
26
|
SCIPY_AVAILABLE = True
|
|
23
27
|
except Exception:
|
|
24
28
|
sch = None
|
|
@@ -27,10 +31,11 @@ except Exception:
|
|
|
27
31
|
SCIPY_AVAILABLE = False
|
|
28
32
|
|
|
29
33
|
try:
|
|
34
|
+
from sklearn.cluster import DBSCAN, KMeans
|
|
30
35
|
from sklearn.decomposition import PCA
|
|
31
|
-
from sklearn.cluster import KMeans, DBSCAN
|
|
32
|
-
from sklearn.mixture import GaussianMixture
|
|
33
36
|
from sklearn.metrics import silhouette_score
|
|
37
|
+
from sklearn.mixture import GaussianMixture
|
|
38
|
+
|
|
34
39
|
SKLEARN_AVAILABLE = True
|
|
35
40
|
except Exception:
|
|
36
41
|
PCA = None
|
|
@@ -43,10 +48,16 @@ except Exception:
|
|
|
43
48
|
gaussian_kde = None
|
|
44
49
|
|
|
45
50
|
|
|
46
|
-
def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
|
|
47
|
-
"""
|
|
48
|
-
|
|
49
|
-
|
|
51
|
+
def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer: str = "orig") -> dict:
|
|
52
|
+
"""Merge two ``.uns`` dictionaries while preserving preferred values.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
orig_uns: Original ``.uns`` dictionary.
|
|
56
|
+
new_uns: New ``.uns`` dictionary to merge.
|
|
57
|
+
prefer: Which dictionary to prefer on conflict (``"orig"`` or ``"new"``).
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
dict: Merged dictionary.
|
|
50
61
|
"""
|
|
51
62
|
out = copy.deepcopy(new_uns) if new_uns is not None else {}
|
|
52
63
|
for k, v in (orig_uns or {}).items():
|
|
@@ -55,7 +66,7 @@ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
|
|
|
55
66
|
else:
|
|
56
67
|
# present in both: compare quickly (best-effort)
|
|
57
68
|
try:
|
|
58
|
-
equal =
|
|
69
|
+
equal = out[k] == v
|
|
59
70
|
except Exception:
|
|
60
71
|
equal = False
|
|
61
72
|
if equal:
|
|
@@ -69,19 +80,20 @@ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
|
|
|
69
80
|
out[f"orig_uns__{k}"] = copy.deepcopy(v)
|
|
70
81
|
return out
|
|
71
82
|
|
|
83
|
+
|
|
72
84
|
def flag_duplicate_reads(
|
|
73
|
-
adata,
|
|
74
|
-
var_filters_sets,
|
|
85
|
+
adata: ad.AnnData,
|
|
86
|
+
var_filters_sets: Sequence[dict[str, Any]],
|
|
75
87
|
distance_threshold: float = 0.07,
|
|
76
88
|
obs_reference_col: str = "Reference_strand",
|
|
77
89
|
sample_col: str = "Barcode",
|
|
78
90
|
output_directory: Optional[str] = None,
|
|
79
91
|
metric_keys: Union[str, List[str]] = ("Fraction_any_C_site_modified",),
|
|
80
|
-
uns_flag: str = "
|
|
92
|
+
uns_flag: str = "flag_duplicate_reads_performed",
|
|
81
93
|
uns_filtered_flag: str = "read_duplicates_removed",
|
|
82
94
|
bypass: bool = False,
|
|
83
95
|
force_redo: bool = False,
|
|
84
|
-
keep_best_metric: Optional[str] =
|
|
96
|
+
keep_best_metric: Optional[str] = "read_quality",
|
|
85
97
|
keep_best_higher: bool = True,
|
|
86
98
|
window_size: int = 50,
|
|
87
99
|
min_overlap_positions: int = 20,
|
|
@@ -93,19 +105,137 @@ def flag_duplicate_reads(
|
|
|
93
105
|
hierarchical_metric: str = "euclidean",
|
|
94
106
|
hierarchical_window: int = 50,
|
|
95
107
|
random_state: int = 0,
|
|
96
|
-
|
|
108
|
+
demux_types: Optional[Sequence[str]] = None,
|
|
109
|
+
demux_col: str = "demux_type",
|
|
110
|
+
) -> ad.AnnData:
|
|
111
|
+
"""Flag duplicate reads with demux-aware keeper preference.
|
|
112
|
+
|
|
113
|
+
Behavior:
|
|
114
|
+
- All reads are processed (no masking by demux).
|
|
115
|
+
- At each keeper decision, prefer reads whose ``demux_col`` value is in
|
|
116
|
+
``demux_types`` when present. Among candidates, choose by
|
|
117
|
+
``keep_best_metric``.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
adata: AnnData object to process.
|
|
121
|
+
var_filters_sets: Sequence of variable filter definitions.
|
|
122
|
+
distance_threshold: Distance threshold for duplicate detection.
|
|
123
|
+
obs_reference_col: Obs column containing reference identifiers.
|
|
124
|
+
sample_col: Obs column containing sample identifiers.
|
|
125
|
+
output_directory: Directory for output plots and artifacts.
|
|
126
|
+
metric_keys: Metric key(s) used in processing.
|
|
127
|
+
uns_flag: Flag in ``adata.uns`` indicating prior completion.
|
|
128
|
+
uns_filtered_flag: Flag to mark read duplicates removal.
|
|
129
|
+
bypass: Whether to skip processing.
|
|
130
|
+
force_redo: Whether to rerun even if ``uns_flag`` is set.
|
|
131
|
+
keep_best_metric: Obs column used to select best read within duplicates.
|
|
132
|
+
keep_best_higher: Whether higher values in ``keep_best_metric`` are preferred.
|
|
133
|
+
window_size: Window size for local comparisons.
|
|
134
|
+
min_overlap_positions: Minimum overlapping positions required.
|
|
135
|
+
do_pca: Whether to run PCA before clustering.
|
|
136
|
+
pca_n_components: Number of PCA components.
|
|
137
|
+
pca_center: Whether to center data before PCA.
|
|
138
|
+
do_hierarchical: Whether to run hierarchical clustering.
|
|
139
|
+
hierarchical_linkage: Linkage method for hierarchical clustering.
|
|
140
|
+
hierarchical_metric: Distance metric for hierarchical clustering.
|
|
141
|
+
hierarchical_window: Window size for hierarchical clustering.
|
|
142
|
+
random_state: Random seed.
|
|
143
|
+
demux_types: Preferred demux types for keeper selection.
|
|
144
|
+
demux_col: Obs column containing demux type labels.
|
|
145
|
+
|
|
146
|
+
Returns:
|
|
147
|
+
anndata.AnnData: AnnData object with duplicate flags stored in ``adata.obs``.
|
|
97
148
|
"""
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
enforcement happens only after hierarchical merging.
|
|
149
|
+
import copy
|
|
150
|
+
import warnings
|
|
101
151
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
152
|
+
import anndata as ad
|
|
153
|
+
import numpy as np
|
|
154
|
+
import pandas as pd
|
|
155
|
+
|
|
156
|
+
# optional imports already guarded at module import time, but re-check
|
|
157
|
+
try:
|
|
158
|
+
from scipy.cluster import hierarchy as sch
|
|
159
|
+
from scipy.spatial.distance import pdist
|
|
160
|
+
|
|
161
|
+
SCIPY_AVAILABLE = True
|
|
162
|
+
except Exception:
|
|
163
|
+
sch = None
|
|
164
|
+
pdist = None
|
|
165
|
+
SCIPY_AVAILABLE = False
|
|
166
|
+
try:
|
|
167
|
+
from sklearn.decomposition import PCA
|
|
168
|
+
|
|
169
|
+
SKLEARN_AVAILABLE = True
|
|
170
|
+
except Exception:
|
|
171
|
+
PCA = None
|
|
172
|
+
SKLEARN_AVAILABLE = False
|
|
173
|
+
|
|
174
|
+
# -------- helper: demux-aware keeper selection --------
|
|
175
|
+
def _choose_keeper_with_demux_preference(
|
|
176
|
+
members_idx: List[int],
|
|
177
|
+
adata_subset: ad.AnnData,
|
|
178
|
+
obs_index_list: List[Any],
|
|
179
|
+
*,
|
|
180
|
+
demux_col: str = "demux_type",
|
|
181
|
+
preferred_types: Optional[Sequence[str]] = None,
|
|
182
|
+
keep_best_metric: Optional[str] = None,
|
|
183
|
+
keep_best_higher: bool = True,
|
|
184
|
+
lex_keeper_mask: Optional[np.ndarray] = None, # aligned to members order
|
|
185
|
+
) -> int:
|
|
186
|
+
"""
|
|
187
|
+
Prefer members whose demux_col ∈ preferred_types.
|
|
188
|
+
Among candidates, pick by keep_best_metric (higher/lower).
|
|
189
|
+
If metric missing/NaN, prefer lex keeper (via mask) among candidates.
|
|
190
|
+
Fallback: first candidate.
|
|
191
|
+
Returns the chosen *member index* (int from members_idx).
|
|
192
|
+
"""
|
|
193
|
+
# 1) demux-preferred candidates
|
|
194
|
+
if preferred_types and (demux_col in adata_subset.obs.columns):
|
|
195
|
+
preferred = set(map(str, preferred_types))
|
|
196
|
+
demux_series = adata_subset.obs[demux_col].astype("string")
|
|
197
|
+
names = [obs_index_list[m] for m in members_idx]
|
|
198
|
+
is_pref = demux_series.loc[names].isin(preferred).to_numpy()
|
|
199
|
+
candidates = [members_idx[i] for i, ok in enumerate(is_pref) if ok]
|
|
200
|
+
else:
|
|
201
|
+
candidates = []
|
|
202
|
+
|
|
203
|
+
if not candidates:
|
|
204
|
+
candidates = list(members_idx)
|
|
205
|
+
|
|
206
|
+
# 2) metric-based within candidates
|
|
207
|
+
if keep_best_metric and (keep_best_metric in adata_subset.obs.columns):
|
|
208
|
+
cand_names = [obs_index_list[m] for m in candidates]
|
|
209
|
+
try:
|
|
210
|
+
vals = pd.to_numeric(
|
|
211
|
+
adata_subset.obs.loc[cand_names, keep_best_metric],
|
|
212
|
+
errors="coerce",
|
|
213
|
+
).to_numpy(dtype=float)
|
|
214
|
+
except Exception:
|
|
215
|
+
vals = np.array([np.nan] * len(candidates), dtype=float)
|
|
216
|
+
|
|
217
|
+
if not np.all(np.isnan(vals)):
|
|
218
|
+
if keep_best_higher:
|
|
219
|
+
vals = np.where(np.isnan(vals), -np.inf, vals)
|
|
220
|
+
return candidates[int(np.nanargmax(vals))]
|
|
221
|
+
else:
|
|
222
|
+
vals = np.where(np.isnan(vals), np.inf, vals)
|
|
223
|
+
return candidates[int(np.nanargmin(vals))]
|
|
224
|
+
|
|
225
|
+
# 3) metric unhelpful — prefer lex keeper if provided
|
|
226
|
+
if lex_keeper_mask is not None:
|
|
227
|
+
for i, midx in enumerate(members_idx):
|
|
228
|
+
if (midx in candidates) and bool(lex_keeper_mask[i]):
|
|
229
|
+
return midx
|
|
230
|
+
|
|
231
|
+
# 4) fallback
|
|
232
|
+
return candidates[0]
|
|
233
|
+
|
|
234
|
+
# -------- early exits --------
|
|
105
235
|
already = bool(adata.uns.get(uns_flag, False))
|
|
106
|
-
if
|
|
236
|
+
if already and not force_redo:
|
|
107
237
|
if "is_duplicate" in adata.obs.columns:
|
|
108
|
-
adata_unique = adata[adata.obs["is_duplicate"]
|
|
238
|
+
adata_unique = adata[~adata.obs["is_duplicate"]].copy()
|
|
109
239
|
return adata_unique, adata
|
|
110
240
|
else:
|
|
111
241
|
return adata.copy(), adata.copy()
|
|
@@ -117,17 +247,23 @@ def flag_duplicate_reads(
|
|
|
117
247
|
|
|
118
248
|
# local UnionFind
|
|
119
249
|
class UnionFind:
|
|
250
|
+
"""Disjoint-set union-find helper for clustering indices."""
|
|
251
|
+
|
|
120
252
|
def __init__(self, size):
|
|
253
|
+
"""Initialize parent pointers for the union-find."""
|
|
121
254
|
self.parent = list(range(size))
|
|
122
255
|
|
|
123
256
|
def find(self, x):
|
|
257
|
+
"""Find the root for a member with path compression."""
|
|
124
258
|
while self.parent[x] != x:
|
|
125
259
|
self.parent[x] = self.parent[self.parent[x]]
|
|
126
260
|
x = self.parent[x]
|
|
127
261
|
return x
|
|
128
262
|
|
|
129
263
|
def union(self, x, y):
|
|
130
|
-
|
|
264
|
+
"""Union the sets that contain x and y."""
|
|
265
|
+
rx = self.find(x)
|
|
266
|
+
ry = self.find(y)
|
|
131
267
|
if rx != ry:
|
|
132
268
|
self.parent[ry] = rx
|
|
133
269
|
|
|
@@ -139,14 +275,14 @@ def flag_duplicate_reads(
|
|
|
139
275
|
|
|
140
276
|
for sample in samples:
|
|
141
277
|
for ref in references:
|
|
142
|
-
|
|
278
|
+
logger.info("Processing sample=%s ref=%s", sample, ref)
|
|
143
279
|
sample_mask = adata.obs[sample_col] == sample
|
|
144
280
|
ref_mask = adata.obs[obs_reference_col] == ref
|
|
145
281
|
subset_mask = sample_mask & ref_mask
|
|
146
282
|
adata_subset = adata[subset_mask].copy()
|
|
147
283
|
|
|
148
284
|
if adata_subset.n_obs < 2:
|
|
149
|
-
|
|
285
|
+
logger.info(" Skipping %s_%s (too few reads)", sample, ref)
|
|
150
286
|
continue
|
|
151
287
|
|
|
152
288
|
N = adata_subset.shape[0]
|
|
@@ -162,7 +298,12 @@ def flag_duplicate_reads(
|
|
|
162
298
|
|
|
163
299
|
selected_cols = adata.var.index[combined_mask.tolist()].to_list()
|
|
164
300
|
col_indices = [adata.var.index.get_loc(c) for c in selected_cols]
|
|
165
|
-
|
|
301
|
+
logger.info(
|
|
302
|
+
" Selected %s columns out of %s for %s",
|
|
303
|
+
len(col_indices),
|
|
304
|
+
adata.var.shape[0],
|
|
305
|
+
ref,
|
|
306
|
+
)
|
|
166
307
|
|
|
167
308
|
# Extract data matrix (dense numpy) for the subset
|
|
168
309
|
X = adata_subset.X
|
|
@@ -187,7 +328,10 @@ def flag_duplicate_reads(
|
|
|
187
328
|
hierarchical_found_dists = []
|
|
188
329
|
|
|
189
330
|
# Lexicographic windowed pass function
|
|
190
|
-
def cluster_pass(
|
|
331
|
+
def cluster_pass(
|
|
332
|
+
X_tensor_local, reverse=False, window=int(window_size), record_distances=False
|
|
333
|
+
):
|
|
334
|
+
"""Perform a lexicographic windowed clustering pass."""
|
|
191
335
|
N_local = X_tensor_local.shape[0]
|
|
192
336
|
X_sortable = X_tensor_local.clone().nan_to_num(-1.0)
|
|
193
337
|
sort_keys = [tuple(row.numpy().tolist()) for row in X_sortable]
|
|
@@ -208,10 +352,16 @@ def flag_duplicate_reads(
|
|
|
208
352
|
if enough_overlap.any():
|
|
209
353
|
diffs = (row_i_exp != block_rows) & valid_mask
|
|
210
354
|
hamming_counts = diffs.sum(dim=1).float()
|
|
211
|
-
hamming_dists = torch.where(
|
|
355
|
+
hamming_dists = torch.where(
|
|
356
|
+
valid_counts > 0,
|
|
357
|
+
hamming_counts / valid_counts,
|
|
358
|
+
torch.tensor(float("nan")),
|
|
359
|
+
)
|
|
212
360
|
# record distances (legacy list of all local comparisons)
|
|
213
361
|
hamming_np = hamming_dists.cpu().numpy().tolist()
|
|
214
|
-
local_hamming_dists.extend(
|
|
362
|
+
local_hamming_dists.extend(
|
|
363
|
+
[float(x) for x in hamming_np if (not np.isnan(x))]
|
|
364
|
+
)
|
|
215
365
|
matches = (hamming_dists < distance_threshold) & (enough_overlap)
|
|
216
366
|
for offset_local, m in enumerate(matches):
|
|
217
367
|
if m:
|
|
@@ -223,20 +373,28 @@ def flag_duplicate_reads(
|
|
|
223
373
|
next_local_idx = i + 1
|
|
224
374
|
if next_local_idx < len(sorted_X):
|
|
225
375
|
next_global = sorted_idx[next_local_idx]
|
|
226
|
-
vm_pair = (~torch.isnan(row_i)) & (
|
|
376
|
+
vm_pair = (~torch.isnan(row_i)) & (
|
|
377
|
+
~torch.isnan(sorted_X[next_local_idx])
|
|
378
|
+
)
|
|
227
379
|
vc = vm_pair.sum().item()
|
|
228
380
|
if vc >= min_overlap_positions:
|
|
229
|
-
d = float(
|
|
381
|
+
d = float(
|
|
382
|
+
(
|
|
383
|
+
(row_i[vm_pair] != sorted_X[next_local_idx][vm_pair])
|
|
384
|
+
.sum()
|
|
385
|
+
.item()
|
|
386
|
+
)
|
|
387
|
+
/ vc
|
|
388
|
+
)
|
|
230
389
|
if reverse:
|
|
231
390
|
rev_hamming_to_prev[next_global] = d
|
|
232
391
|
else:
|
|
233
392
|
fwd_hamming_to_next[sorted_idx[i]] = d
|
|
234
393
|
return cluster_pairs_local
|
|
235
394
|
|
|
236
|
-
# run forward
|
|
395
|
+
# run forward & reverse windows
|
|
237
396
|
pairs_fwd = cluster_pass(X_tensor, reverse=False, record_distances=True)
|
|
238
397
|
involved_in_fwd = set([p for pair in pairs_fwd for p in pair])
|
|
239
|
-
# build mask for reverse pass to avoid re-checking items already paired
|
|
240
398
|
mask_for_rev = np.ones(N, dtype=bool)
|
|
241
399
|
if len(involved_in_fwd) > 0:
|
|
242
400
|
for idx in involved_in_fwd:
|
|
@@ -245,8 +403,9 @@ def flag_duplicate_reads(
|
|
|
245
403
|
if len(rev_idx_map) > 0:
|
|
246
404
|
reduced_tensor = X_tensor[rev_idx_map]
|
|
247
405
|
pairs_rev_local = cluster_pass(reduced_tensor, reverse=True, record_distances=True)
|
|
248
|
-
|
|
249
|
-
|
|
406
|
+
remapped_rev_pairs = [
|
|
407
|
+
(int(rev_idx_map[i]), int(rev_idx_map[j])) for (i, j) in pairs_rev_local
|
|
408
|
+
]
|
|
250
409
|
else:
|
|
251
410
|
remapped_rev_pairs = []
|
|
252
411
|
|
|
@@ -265,53 +424,41 @@ def flag_duplicate_reads(
|
|
|
265
424
|
id_map = {old: new for new, old in enumerate(sorted(unique_initial.tolist()))}
|
|
266
425
|
merged_cluster_mapped = np.array([id_map[int(x)] for x in merged_cluster], dtype=int)
|
|
267
426
|
|
|
268
|
-
# cluster sizes and choose lex-keeper per lex-cluster (
|
|
427
|
+
# cluster sizes and choose lex-keeper per lex-cluster (demux-aware)
|
|
269
428
|
cluster_sizes = np.zeros_like(merged_cluster_mapped)
|
|
270
429
|
cluster_counts = []
|
|
271
430
|
unique_clusters = np.unique(merged_cluster_mapped)
|
|
272
431
|
keeper_for_cluster = {}
|
|
432
|
+
|
|
433
|
+
obs_index = list(adata_subset.obs.index)
|
|
273
434
|
for cid in unique_clusters:
|
|
274
435
|
members = np.where(merged_cluster_mapped == cid)[0].tolist()
|
|
275
436
|
csize = int(len(members))
|
|
276
437
|
cluster_counts.append(csize)
|
|
277
438
|
cluster_sizes[members] = csize
|
|
278
|
-
# pick lex keeper (representative)
|
|
279
|
-
if len(members) == 1:
|
|
280
|
-
keeper_for_cluster[cid] = members[0]
|
|
281
|
-
else:
|
|
282
|
-
if keep_best_metric is None:
|
|
283
|
-
keeper_for_cluster[cid] = members[0]
|
|
284
|
-
else:
|
|
285
|
-
obs_index = list(adata_subset.obs.index)
|
|
286
|
-
member_names = [obs_index[m] for m in members]
|
|
287
|
-
try:
|
|
288
|
-
vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
|
|
289
|
-
except Exception:
|
|
290
|
-
vals = np.array([np.nan] * len(members), dtype=float)
|
|
291
|
-
if np.all(np.isnan(vals)):
|
|
292
|
-
keeper_for_cluster[cid] = members[0]
|
|
293
|
-
else:
|
|
294
|
-
if keep_best_higher:
|
|
295
|
-
nan_mask = np.isnan(vals)
|
|
296
|
-
vals[nan_mask] = -np.inf
|
|
297
|
-
rel_idx = int(np.nanargmax(vals))
|
|
298
|
-
else:
|
|
299
|
-
nan_mask = np.isnan(vals)
|
|
300
|
-
vals[nan_mask] = np.inf
|
|
301
|
-
rel_idx = int(np.nanargmin(vals))
|
|
302
|
-
keeper_for_cluster[cid] = members[rel_idx]
|
|
303
439
|
|
|
304
|
-
|
|
440
|
+
keeper_for_cluster[cid] = _choose_keeper_with_demux_preference(
|
|
441
|
+
members,
|
|
442
|
+
adata_subset,
|
|
443
|
+
obs_index,
|
|
444
|
+
demux_col=demux_col,
|
|
445
|
+
preferred_types=demux_types,
|
|
446
|
+
keep_best_metric=keep_best_metric,
|
|
447
|
+
keep_best_higher=keep_best_higher,
|
|
448
|
+
lex_keeper_mask=None, # no lex preference yet
|
|
449
|
+
)
|
|
450
|
+
|
|
451
|
+
# expose lex keeper info (record only)
|
|
305
452
|
lex_is_keeper = np.zeros((N,), dtype=bool)
|
|
306
453
|
lex_is_duplicate = np.zeros((N,), dtype=bool)
|
|
307
|
-
for cid
|
|
454
|
+
for cid in unique_clusters:
|
|
455
|
+
members = np.where(merged_cluster_mapped == cid)[0].tolist()
|
|
308
456
|
keeper_idx = keeper_for_cluster[cid]
|
|
309
457
|
lex_is_keeper[keeper_idx] = True
|
|
310
458
|
for m in members:
|
|
311
459
|
if m != keeper_idx:
|
|
312
460
|
lex_is_duplicate[m] = True
|
|
313
|
-
|
|
314
|
-
# and will be written to adata_subset.obs below
|
|
461
|
+
|
|
315
462
|
# record lex min pair (min of fwd/rev neighbor) for each read
|
|
316
463
|
min_pair = np.full((N,), np.nan, dtype=float)
|
|
317
464
|
for i in range(N):
|
|
@@ -349,7 +496,12 @@ def flag_duplicate_reads(
|
|
|
349
496
|
if n_comp <= 0:
|
|
350
497
|
reps_for_clustering = reps_arr
|
|
351
498
|
else:
|
|
352
|
-
pca = PCA(
|
|
499
|
+
pca = PCA(
|
|
500
|
+
n_components=n_comp,
|
|
501
|
+
random_state=int(random_state),
|
|
502
|
+
svd_solver="auto",
|
|
503
|
+
copy=True,
|
|
504
|
+
)
|
|
353
505
|
reps_for_clustering = pca.fit_transform(reps_arr)
|
|
354
506
|
else:
|
|
355
507
|
reps_for_clustering = reps_arr
|
|
@@ -360,10 +512,12 @@ def flag_duplicate_reads(
|
|
|
360
512
|
Z = sch.linkage(pdist_vec, method=hierarchical_linkage)
|
|
361
513
|
leaves = sch.leaves_list(Z)
|
|
362
514
|
except Exception as e:
|
|
363
|
-
warnings.warn(
|
|
515
|
+
warnings.warn(
|
|
516
|
+
f"hierarchical pass failed: {e}; skipping hierarchical stage."
|
|
517
|
+
)
|
|
364
518
|
leaves = np.arange(len(rep_global_indices), dtype=int)
|
|
365
519
|
|
|
366
|
-
#
|
|
520
|
+
# windowed hamming comparisons across ordered reps and union
|
|
367
521
|
order_global_reps = [rep_global_indices[i] for i in leaves]
|
|
368
522
|
n_reps = len(order_global_reps)
|
|
369
523
|
for pos in range(n_reps):
|
|
@@ -389,55 +543,40 @@ def flag_duplicate_reads(
|
|
|
389
543
|
merged_cluster_after[i] = uf.find(i)
|
|
390
544
|
unique_final = np.unique(merged_cluster_after)
|
|
391
545
|
id_map_final = {old: new for new, old in enumerate(sorted(unique_final.tolist()))}
|
|
392
|
-
merged_cluster_mapped_final = np.array(
|
|
546
|
+
merged_cluster_mapped_final = np.array(
|
|
547
|
+
[id_map_final[int(x)] for x in merged_cluster_after], dtype=int
|
|
548
|
+
)
|
|
393
549
|
|
|
394
|
-
# compute final cluster members and choose final keeper per final cluster
|
|
550
|
+
# compute final cluster members and choose final keeper per final cluster (demux-aware)
|
|
395
551
|
cluster_sizes_final = np.zeros_like(merged_cluster_mapped_final)
|
|
396
|
-
final_cluster_counts = []
|
|
397
|
-
final_unique = np.unique(merged_cluster_mapped_final)
|
|
398
552
|
final_keeper_for_cluster = {}
|
|
399
553
|
cluster_members_map = {}
|
|
400
|
-
|
|
554
|
+
|
|
555
|
+
obs_index = list(adata_subset.obs.index)
|
|
556
|
+
lex_mask_full = lex_is_keeper # use lex keeper as optional tiebreaker
|
|
557
|
+
|
|
558
|
+
for cid in np.unique(merged_cluster_mapped_final):
|
|
401
559
|
members = np.where(merged_cluster_mapped_final == cid)[0].tolist()
|
|
402
560
|
cluster_members_map[cid] = members
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
else:
|
|
421
|
-
if keep_best_higher:
|
|
422
|
-
nan_mask = np.isnan(vals)
|
|
423
|
-
vals[nan_mask] = -np.inf
|
|
424
|
-
rel_idx = int(np.nanargmax(vals))
|
|
425
|
-
else:
|
|
426
|
-
nan_mask = np.isnan(vals)
|
|
427
|
-
vals[nan_mask] = np.inf
|
|
428
|
-
rel_idx = int(np.nanargmin(vals))
|
|
429
|
-
final_keeper_for_cluster[cid] = members[rel_idx]
|
|
430
|
-
else:
|
|
431
|
-
# if lex keepers present among members, prefer them
|
|
432
|
-
lex_members = [m for m in members if lex_is_keeper[m]]
|
|
433
|
-
if len(lex_members) > 0:
|
|
434
|
-
final_keeper_for_cluster[cid] = lex_members[0]
|
|
435
|
-
else:
|
|
436
|
-
final_keeper_for_cluster[cid] = members[0]
|
|
437
|
-
|
|
438
|
-
# update sequence__is_duplicate based on final clusters: non-keepers in multi-member clusters are duplicates
|
|
561
|
+
cluster_sizes_final[members] = len(members)
|
|
562
|
+
|
|
563
|
+
lex_mask_members = np.array([bool(lex_mask_full[m]) for m in members], dtype=bool)
|
|
564
|
+
|
|
565
|
+
keeper = _choose_keeper_with_demux_preference(
|
|
566
|
+
members,
|
|
567
|
+
adata_subset,
|
|
568
|
+
obs_index,
|
|
569
|
+
demux_col=demux_col,
|
|
570
|
+
preferred_types=demux_types,
|
|
571
|
+
keep_best_metric=keep_best_metric,
|
|
572
|
+
keep_best_higher=keep_best_higher,
|
|
573
|
+
lex_keeper_mask=lex_mask_members,
|
|
574
|
+
)
|
|
575
|
+
final_keeper_for_cluster[cid] = keeper
|
|
576
|
+
|
|
577
|
+
# update sequence__is_duplicate based on final clusters
|
|
439
578
|
sequence_is_duplicate = np.zeros((N,), dtype=bool)
|
|
440
|
-
for cid in
|
|
579
|
+
for cid in np.unique(merged_cluster_mapped_final):
|
|
441
580
|
keeper = final_keeper_for_cluster[cid]
|
|
442
581
|
members = cluster_members_map[cid]
|
|
443
582
|
if len(members) > 1:
|
|
@@ -446,8 +585,7 @@ def flag_duplicate_reads(
|
|
|
446
585
|
sequence_is_duplicate[m] = True
|
|
447
586
|
|
|
448
587
|
# propagate hierarchical distances into hierarchical_min_pair for all cluster members
|
|
449
|
-
for
|
|
450
|
-
# identify their final cluster ids (after unions)
|
|
588
|
+
for i_g, j_g, d in hierarchical_pairs:
|
|
451
589
|
c_i = merged_cluster_mapped_final[int(i_g)]
|
|
452
590
|
c_j = merged_cluster_mapped_final[int(j_g)]
|
|
453
591
|
members_i = cluster_members_map.get(c_i, [int(i_g)])
|
|
@@ -459,7 +597,7 @@ def flag_duplicate_reads(
|
|
|
459
597
|
if np.isnan(hierarchical_min_pair[mj]) or (d < hierarchical_min_pair[mj]):
|
|
460
598
|
hierarchical_min_pair[mj] = d
|
|
461
599
|
|
|
462
|
-
# combine
|
|
600
|
+
# combine min pairs
|
|
463
601
|
combined_min = min_pair.copy()
|
|
464
602
|
for i in range(N):
|
|
465
603
|
hval = hierarchical_min_pair[i]
|
|
@@ -475,69 +613,117 @@ def flag_duplicate_reads(
|
|
|
475
613
|
adata_subset.obs["rev_hamming_to_prev"] = rev_hamming_to_prev
|
|
476
614
|
adata_subset.obs["sequence__hier_hamming_to_pair"] = hierarchical_min_pair
|
|
477
615
|
adata_subset.obs["sequence__min_hamming_to_pair"] = combined_min
|
|
478
|
-
# persist lex bookkeeping
|
|
616
|
+
# persist lex bookkeeping
|
|
479
617
|
adata_subset.obs["sequence__lex_is_keeper"] = lex_is_keeper
|
|
480
618
|
adata_subset.obs["sequence__lex_is_duplicate"] = lex_is_duplicate
|
|
481
619
|
|
|
482
620
|
adata_processed_list.append(adata_subset)
|
|
483
621
|
|
|
484
|
-
histograms.append(
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
622
|
+
histograms.append(
|
|
623
|
+
{
|
|
624
|
+
"sample": sample,
|
|
625
|
+
"reference": ref,
|
|
626
|
+
"distances": local_hamming_dists,
|
|
627
|
+
"cluster_counts": [
|
|
628
|
+
int(x) for x in np.unique(cluster_sizes_final[cluster_sizes_final > 0])
|
|
629
|
+
],
|
|
630
|
+
"hierarchical_pairs": hierarchical_found_dists,
|
|
631
|
+
}
|
|
632
|
+
)
|
|
633
|
+
|
|
634
|
+
# Merge annotated subsets back together BEFORE plotting
|
|
493
635
|
_original_uns = copy.deepcopy(adata.uns)
|
|
494
636
|
if len(adata_processed_list) == 0:
|
|
495
637
|
return adata.copy(), adata.copy()
|
|
496
638
|
|
|
497
639
|
adata_full = ad.concat(adata_processed_list, merge="same", join="outer", index_unique=None)
|
|
640
|
+
|
|
641
|
+
# preserve uns (prefer original on conflicts)
|
|
642
|
+
def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
|
|
643
|
+
"""Merge .uns dictionaries while preserving original on conflicts."""
|
|
644
|
+
out = copy.deepcopy(new_uns) if new_uns is not None else {}
|
|
645
|
+
for k, v in (orig_uns or {}).items():
|
|
646
|
+
if k not in out:
|
|
647
|
+
out[k] = copy.deepcopy(v)
|
|
648
|
+
else:
|
|
649
|
+
try:
|
|
650
|
+
equal = out[k] == v
|
|
651
|
+
except Exception:
|
|
652
|
+
equal = False
|
|
653
|
+
if equal:
|
|
654
|
+
continue
|
|
655
|
+
if prefer == "orig":
|
|
656
|
+
out[k] = copy.deepcopy(v)
|
|
657
|
+
else:
|
|
658
|
+
out[f"orig_uns__{k}"] = copy.deepcopy(v)
|
|
659
|
+
return out
|
|
660
|
+
|
|
498
661
|
adata_full.uns = merge_uns_preserve(_original_uns, adata_full.uns, prefer="orig")
|
|
499
662
|
|
|
500
663
|
# Ensure expected numeric columns exist (create if missing)
|
|
501
|
-
for col in (
|
|
664
|
+
for col in (
|
|
665
|
+
"fwd_hamming_to_next",
|
|
666
|
+
"rev_hamming_to_prev",
|
|
667
|
+
"sequence__min_hamming_to_pair",
|
|
668
|
+
"sequence__hier_hamming_to_pair",
|
|
669
|
+
):
|
|
502
670
|
if col not in adata_full.obs.columns:
|
|
503
671
|
adata_full.obs[col] = np.nan
|
|
504
672
|
|
|
505
|
-
# histograms
|
|
506
|
-
hist_outs =
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
673
|
+
# histograms
|
|
674
|
+
hist_outs = (
|
|
675
|
+
os.path.join(output_directory, "read_pair_hamming_distance_histograms")
|
|
676
|
+
if output_directory
|
|
677
|
+
else None
|
|
678
|
+
)
|
|
679
|
+
if hist_outs:
|
|
680
|
+
make_dirs([hist_outs])
|
|
681
|
+
plot_histogram_pages(
|
|
682
|
+
histograms,
|
|
683
|
+
distance_threshold=distance_threshold,
|
|
684
|
+
adata=adata_full,
|
|
685
|
+
output_directory=hist_outs,
|
|
686
|
+
distance_types=["min", "fwd", "rev", "hier", "lex_local"],
|
|
687
|
+
sample_key=sample_col,
|
|
688
|
+
)
|
|
515
689
|
|
|
516
690
|
# hamming vs metric scatter
|
|
517
|
-
scatter_outs =
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
691
|
+
scatter_outs = (
|
|
692
|
+
os.path.join(output_directory, "read_pair_hamming_distance_scatter_plots")
|
|
693
|
+
if output_directory
|
|
694
|
+
else None
|
|
695
|
+
)
|
|
696
|
+
if scatter_outs:
|
|
697
|
+
make_dirs([scatter_outs])
|
|
698
|
+
plot_hamming_vs_metric_pages(
|
|
699
|
+
adata_full,
|
|
700
|
+
metric_keys=metric_keys,
|
|
701
|
+
output_dir=scatter_outs,
|
|
702
|
+
hamming_col="sequence__min_hamming_to_pair",
|
|
703
|
+
highlight_threshold=distance_threshold,
|
|
704
|
+
highlight_color="red",
|
|
705
|
+
sample_col=sample_col,
|
|
706
|
+
)
|
|
526
707
|
|
|
527
708
|
# boolean columns from neighbor distances
|
|
528
|
-
fwd_vals = pd.to_numeric(
|
|
529
|
-
|
|
709
|
+
fwd_vals = pd.to_numeric(
|
|
710
|
+
adata_full.obs.get("fwd_hamming_to_next", pd.Series(np.nan, index=adata_full.obs.index)),
|
|
711
|
+
errors="coerce",
|
|
712
|
+
)
|
|
713
|
+
rev_vals = pd.to_numeric(
|
|
714
|
+
adata_full.obs.get("rev_hamming_to_prev", pd.Series(np.nan, index=adata_full.obs.index)),
|
|
715
|
+
errors="coerce",
|
|
716
|
+
)
|
|
530
717
|
is_dup_dist = (fwd_vals < float(distance_threshold)) | (rev_vals < float(distance_threshold))
|
|
531
718
|
is_dup_dist = is_dup_dist.fillna(False).astype(bool)
|
|
532
719
|
adata_full.obs["is_duplicate_distance"] = is_dup_dist.values
|
|
533
720
|
|
|
534
|
-
# combine sequence
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
# cluster-based duplicate indicator (if any clustering columns exist)
|
|
721
|
+
# combine with sequence flag and any clustering flags
|
|
722
|
+
seq_dup = (
|
|
723
|
+
adata_full.obs["sequence__is_duplicate"].astype(bool)
|
|
724
|
+
if "sequence__is_duplicate" in adata_full.obs.columns
|
|
725
|
+
else pd.Series(False, index=adata_full.obs.index)
|
|
726
|
+
)
|
|
541
727
|
cluster_cols = [c for c in adata_full.obs.columns if c.startswith("hamming_cluster__")]
|
|
542
728
|
if cluster_cols:
|
|
543
729
|
cl_mask = pd.Series(False, index=adata_full.obs.index)
|
|
@@ -550,59 +736,61 @@ def flag_duplicate_reads(
|
|
|
550
736
|
else:
|
|
551
737
|
adata_full.obs["is_duplicate_clustering"] = False
|
|
552
738
|
|
|
553
|
-
final_dup =
|
|
739
|
+
final_dup = (
|
|
740
|
+
seq_dup
|
|
741
|
+
| adata_full.obs["is_duplicate_distance"].astype(bool)
|
|
742
|
+
| adata_full.obs["is_duplicate_clustering"].astype(bool)
|
|
743
|
+
)
|
|
554
744
|
adata_full.obs["is_duplicate"] = final_dup.values
|
|
555
745
|
|
|
556
|
-
# Final keeper enforcement
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
keeper_idx_by_cluster = {}
|
|
560
|
-
metric_col = keep_best_metric if 'keep_best_metric' in locals() else None
|
|
746
|
+
# -------- Final keeper enforcement on adata_full (demux-aware) --------
|
|
747
|
+
keeper_idx_by_cluster = {}
|
|
748
|
+
metric_col = keep_best_metric
|
|
561
749
|
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
try:
|
|
566
|
-
members = sub.index.to_list()
|
|
567
|
-
except Exception:
|
|
568
|
-
members = list(sub.index)
|
|
569
|
-
keeper = None
|
|
570
|
-
# prefer keep_best_metric (if present), else prefer lex keeper among members, else first member
|
|
571
|
-
if metric_col and metric_col in adata_full.obs.columns:
|
|
572
|
-
try:
|
|
573
|
-
vals = pd.to_numeric(adata_full.obs.loc[members, metric_col], errors="coerce")
|
|
574
|
-
if vals.notna().any():
|
|
575
|
-
keeper = vals.idxmax() if keep_best_higher else vals.idxmin()
|
|
576
|
-
else:
|
|
577
|
-
keeper = members[0]
|
|
578
|
-
except Exception:
|
|
579
|
-
keeper = members[0]
|
|
580
|
-
else:
|
|
581
|
-
# prefer lex keeper if present in this merged cluster
|
|
582
|
-
lex_candidates = [m for m in members if ("sequence__lex_is_keeper" in adata_full.obs.columns and adata_full.obs.at[m, "sequence__lex_is_keeper"])]
|
|
583
|
-
if len(lex_candidates) > 0:
|
|
584
|
-
keeper = lex_candidates[0]
|
|
585
|
-
else:
|
|
586
|
-
keeper = members[0]
|
|
750
|
+
# Build an index→row-number mapping
|
|
751
|
+
name_to_pos = {name: i for i, name in enumerate(adata_full.obs.index)}
|
|
752
|
+
obs_index_full = list(adata_full.obs.index)
|
|
587
753
|
|
|
588
|
-
|
|
754
|
+
lex_col = (
|
|
755
|
+
"sequence__lex_is_keeper" if "sequence__lex_is_keeper" in adata_full.obs.columns else None
|
|
756
|
+
)
|
|
589
757
|
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
for
|
|
593
|
-
if keeper_idx in adata_full.obs.index:
|
|
594
|
-
is_dup_series.at[keeper_idx] = False
|
|
595
|
-
# clear sequence__is_duplicate for keeper if present
|
|
596
|
-
if "sequence__is_duplicate" in adata_full.obs.columns:
|
|
597
|
-
adata_full.obs.at[keeper_idx, "sequence__is_duplicate"] = False
|
|
598
|
-
# clear lex duplicate flag too if present
|
|
599
|
-
if "sequence__lex_is_duplicate" in adata_full.obs.columns:
|
|
600
|
-
adata_full.obs.at[keeper_idx, "sequence__lex_is_duplicate"] = False
|
|
758
|
+
for cid, sub in adata_full.obs.groupby("sequence__merged_cluster_id", dropna=False):
|
|
759
|
+
members_names = sub.index.to_list()
|
|
760
|
+
members_pos = [name_to_pos[n] for n in members_names]
|
|
601
761
|
|
|
602
|
-
|
|
762
|
+
if lex_col:
|
|
763
|
+
lex_mask_members = adata_full.obs.loc[members_names, lex_col].astype(bool).to_numpy()
|
|
764
|
+
else:
|
|
765
|
+
lex_mask_members = np.zeros(len(members_pos), dtype=bool)
|
|
766
|
+
|
|
767
|
+
keeper_pos = _choose_keeper_with_demux_preference(
|
|
768
|
+
members_pos,
|
|
769
|
+
adata_full,
|
|
770
|
+
obs_index_full,
|
|
771
|
+
demux_col=demux_col,
|
|
772
|
+
preferred_types=demux_types,
|
|
773
|
+
keep_best_metric=metric_col,
|
|
774
|
+
keep_best_higher=keep_best_higher,
|
|
775
|
+
lex_keeper_mask=lex_mask_members,
|
|
776
|
+
)
|
|
777
|
+
keeper_name = obs_index_full[keeper_pos]
|
|
778
|
+
keeper_idx_by_cluster[cid] = keeper_name
|
|
779
|
+
|
|
780
|
+
# enforce: keepers are not duplicates
|
|
781
|
+
is_dup_series = adata_full.obs["is_duplicate"].astype(bool)
|
|
782
|
+
for cid, keeper_name in keeper_idx_by_cluster.items():
|
|
783
|
+
if keeper_name in adata_full.obs.index:
|
|
784
|
+
is_dup_series.at[keeper_name] = False
|
|
785
|
+
if "sequence__is_duplicate" in adata_full.obs.columns:
|
|
786
|
+
adata_full.obs.at[keeper_name, "sequence__is_duplicate"] = False
|
|
787
|
+
if "sequence__lex_is_duplicate" in adata_full.obs.columns:
|
|
788
|
+
adata_full.obs.at[keeper_name, "sequence__lex_is_duplicate"] = False
|
|
789
|
+
adata_full.obs["is_duplicate"] = is_dup_series.values
|
|
603
790
|
|
|
604
791
|
# reason column
|
|
605
792
|
def _dup_reason_row(row):
|
|
793
|
+
"""Build a semi-colon delimited duplicate reason string."""
|
|
606
794
|
reasons = []
|
|
607
795
|
if row.get("is_duplicate_distance", False):
|
|
608
796
|
reasons.append("distance_thresh")
|
|
@@ -632,6 +820,7 @@ def flag_duplicate_reads(
|
|
|
632
820
|
# Plot helpers (use adata_full as input)
|
|
633
821
|
# ---------------------------
|
|
634
822
|
|
|
823
|
+
|
|
635
824
|
def plot_histogram_pages(
|
|
636
825
|
histograms,
|
|
637
826
|
distance_threshold,
|
|
@@ -674,10 +863,11 @@ def plot_histogram_pages(
|
|
|
674
863
|
use_adata = False
|
|
675
864
|
|
|
676
865
|
if len(samples) == 0 or len(references) == 0:
|
|
677
|
-
|
|
866
|
+
logger.info("No histogram data to plot.")
|
|
678
867
|
return {"distance_pages": [], "cluster_size_pages": []}
|
|
679
868
|
|
|
680
869
|
def clean_array(arr):
|
|
870
|
+
"""Filter array values to finite [0, 1] range for plotting."""
|
|
681
871
|
if arr is None or len(arr) == 0:
|
|
682
872
|
return np.array([], dtype=float)
|
|
683
873
|
a = np.asarray(arr, dtype=float)
|
|
@@ -707,7 +897,9 @@ def plot_histogram_pages(
|
|
|
707
897
|
if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
|
|
708
898
|
grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
|
|
709
899
|
if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
|
|
710
|
-
grid[(s, r)]["hier"].extend(
|
|
900
|
+
grid[(s, r)]["hier"].extend(
|
|
901
|
+
clean_array(group["sequence__hier_hamming_to_pair"].to_numpy())
|
|
902
|
+
)
|
|
711
903
|
else:
|
|
712
904
|
for (s, r), group in grouped:
|
|
713
905
|
if "min" in distance_types and distance_key in group.columns:
|
|
@@ -717,7 +909,9 @@ def plot_histogram_pages(
|
|
|
717
909
|
if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
|
|
718
910
|
grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
|
|
719
911
|
if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
|
|
720
|
-
grid[(s, r)]["hier"].extend(
|
|
912
|
+
grid[(s, r)]["hier"].extend(
|
|
913
|
+
clean_array(group["sequence__hier_hamming_to_pair"].to_numpy())
|
|
914
|
+
)
|
|
721
915
|
|
|
722
916
|
# legacy histograms fallback
|
|
723
917
|
if histograms:
|
|
@@ -753,9 +947,17 @@ def plot_histogram_pages(
|
|
|
753
947
|
|
|
754
948
|
# counts (for labels)
|
|
755
949
|
if use_adata:
|
|
756
|
-
counts = {
|
|
950
|
+
counts = {
|
|
951
|
+
(s, r): int(((adata.obs[sample_key] == s) & (adata.obs[ref_key] == r)).sum())
|
|
952
|
+
for s in samples
|
|
953
|
+
for r in references
|
|
954
|
+
}
|
|
757
955
|
else:
|
|
758
|
-
counts = {
|
|
956
|
+
counts = {
|
|
957
|
+
(s, r): sum(len(grid[(s, r)][dt]) for dt in distance_types)
|
|
958
|
+
for s in samples
|
|
959
|
+
for r in references
|
|
960
|
+
}
|
|
759
961
|
|
|
760
962
|
distance_pages = []
|
|
761
963
|
cluster_size_pages = []
|
|
@@ -773,7 +975,9 @@ def plot_histogram_pages(
|
|
|
773
975
|
# Distance histogram page
|
|
774
976
|
fig_w = figsize_per_cell[0] * ncols
|
|
775
977
|
fig_h = figsize_per_cell[1] * nrows
|
|
776
|
-
fig, axes = plt.subplots(
|
|
978
|
+
fig, axes = plt.subplots(
|
|
979
|
+
nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
|
|
980
|
+
)
|
|
777
981
|
|
|
778
982
|
for r_idx, sample_name in enumerate(chunk):
|
|
779
983
|
for c_idx, ref_name in enumerate(references):
|
|
@@ -789,17 +993,37 @@ def plot_histogram_pages(
|
|
|
789
993
|
vals = vals[(vals >= 0.0) & (vals <= ref_vmax)]
|
|
790
994
|
if vals.size > 0:
|
|
791
995
|
any_data = True
|
|
792
|
-
ax.hist(
|
|
793
|
-
|
|
996
|
+
ax.hist(
|
|
997
|
+
vals,
|
|
998
|
+
bins=bins_edges,
|
|
999
|
+
alpha=0.5,
|
|
1000
|
+
label=dtype,
|
|
1001
|
+
density=False,
|
|
1002
|
+
stacked=False,
|
|
1003
|
+
color=dtype_colors.get(dtype, None),
|
|
1004
|
+
)
|
|
794
1005
|
if not any_data:
|
|
795
|
-
ax.text(
|
|
1006
|
+
ax.text(
|
|
1007
|
+
0.5,
|
|
1008
|
+
0.5,
|
|
1009
|
+
"No data",
|
|
1010
|
+
ha="center",
|
|
1011
|
+
va="center",
|
|
1012
|
+
transform=ax.transAxes,
|
|
1013
|
+
fontsize=10,
|
|
1014
|
+
color="gray",
|
|
1015
|
+
)
|
|
796
1016
|
# threshold line (make sure it is within axis)
|
|
797
1017
|
ax.axvline(distance_threshold, color="red", linestyle="--", linewidth=1)
|
|
798
1018
|
|
|
799
1019
|
if r_idx == 0:
|
|
800
1020
|
ax.set_title(str(ref_name), fontsize=10)
|
|
801
1021
|
if c_idx == 0:
|
|
802
|
-
total_reads =
|
|
1022
|
+
total_reads = (
|
|
1023
|
+
sum(counts.get((sample_name, ref), 0) for ref in references)
|
|
1024
|
+
if not use_adata
|
|
1025
|
+
else int((adata.obs[sample_key] == sample_name).sum())
|
|
1026
|
+
)
|
|
803
1027
|
ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
|
|
804
1028
|
if r_idx == nrows - 1:
|
|
805
1029
|
ax.set_xlabel("Hamming Distance", fontsize=9)
|
|
@@ -811,12 +1035,16 @@ def plot_histogram_pages(
|
|
|
811
1035
|
if r_idx == 0 and c_idx == 0:
|
|
812
1036
|
ax.legend(fontsize=7, loc="upper right")
|
|
813
1037
|
|
|
814
|
-
fig.suptitle(
|
|
1038
|
+
fig.suptitle(
|
|
1039
|
+
f"Hamming distance histograms (rows=samples, cols=references) — page {page + 1}/{n_pages}",
|
|
1040
|
+
fontsize=12,
|
|
1041
|
+
y=0.995,
|
|
1042
|
+
)
|
|
815
1043
|
fig.tight_layout(rect=[0, 0, 1, 0.96])
|
|
816
1044
|
|
|
817
1045
|
if output_directory:
|
|
818
1046
|
os.makedirs(output_directory, exist_ok=True)
|
|
819
|
-
fname = os.path.join(output_directory, f"hamming_histograms_page_{page+1}.png")
|
|
1047
|
+
fname = os.path.join(output_directory, f"hamming_histograms_page_{page + 1}.png")
|
|
820
1048
|
plt.savefig(fname, bbox_inches="tight")
|
|
821
1049
|
distance_pages.append(fname)
|
|
822
1050
|
else:
|
|
@@ -826,22 +1054,43 @@ def plot_histogram_pages(
|
|
|
826
1054
|
# Cluster-size histogram page (unchanged except it uses adata-derived sizes per cluster if available)
|
|
827
1055
|
fig_w = figsize_per_cell[0] * ncols
|
|
828
1056
|
fig_h = figsize_per_cell[1] * nrows
|
|
829
|
-
fig2, axes2 = plt.subplots(
|
|
1057
|
+
fig2, axes2 = plt.subplots(
|
|
1058
|
+
nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
|
|
1059
|
+
)
|
|
830
1060
|
|
|
831
1061
|
for r_idx, sample_name in enumerate(chunk):
|
|
832
1062
|
for c_idx, ref_name in enumerate(references):
|
|
833
1063
|
ax = axes2[r_idx][c_idx]
|
|
834
1064
|
sizes = []
|
|
835
|
-
if use_adata and (
|
|
836
|
-
|
|
1065
|
+
if use_adata and (
|
|
1066
|
+
"sequence__merged_cluster_id" in adata.obs.columns
|
|
1067
|
+
and "sequence__cluster_size" in adata.obs.columns
|
|
1068
|
+
):
|
|
1069
|
+
sub = adata.obs[
|
|
1070
|
+
(adata.obs[sample_key] == sample_name) & (adata.obs[ref_key] == ref_name)
|
|
1071
|
+
]
|
|
837
1072
|
if not sub.empty:
|
|
838
1073
|
try:
|
|
839
|
-
grp = sub.groupby("sequence__merged_cluster_id")[
|
|
840
|
-
|
|
1074
|
+
grp = sub.groupby("sequence__merged_cluster_id")[
|
|
1075
|
+
"sequence__cluster_size"
|
|
1076
|
+
].first()
|
|
1077
|
+
sizes = [
|
|
1078
|
+
int(x)
|
|
1079
|
+
for x in grp.to_numpy().tolist()
|
|
1080
|
+
if (pd.notna(x) and np.isfinite(x))
|
|
1081
|
+
]
|
|
841
1082
|
except Exception:
|
|
842
1083
|
try:
|
|
843
|
-
unique_pairs = sub[
|
|
844
|
-
|
|
1084
|
+
unique_pairs = sub[
|
|
1085
|
+
["sequence__merged_cluster_id", "sequence__cluster_size"]
|
|
1086
|
+
].drop_duplicates()
|
|
1087
|
+
sizes = [
|
|
1088
|
+
int(x)
|
|
1089
|
+
for x in unique_pairs["sequence__cluster_size"]
|
|
1090
|
+
.dropna()
|
|
1091
|
+
.astype(int)
|
|
1092
|
+
.tolist()
|
|
1093
|
+
]
|
|
845
1094
|
except Exception:
|
|
846
1095
|
sizes = []
|
|
847
1096
|
if (not sizes) and histograms:
|
|
@@ -855,23 +1104,38 @@ def plot_histogram_pages(
|
|
|
855
1104
|
ax.set_xlabel("Cluster size")
|
|
856
1105
|
ax.set_ylabel("Count")
|
|
857
1106
|
else:
|
|
858
|
-
ax.text(
|
|
1107
|
+
ax.text(
|
|
1108
|
+
0.5,
|
|
1109
|
+
0.5,
|
|
1110
|
+
"No clusters",
|
|
1111
|
+
ha="center",
|
|
1112
|
+
va="center",
|
|
1113
|
+
transform=ax.transAxes,
|
|
1114
|
+
fontsize=10,
|
|
1115
|
+
color="gray",
|
|
1116
|
+
)
|
|
859
1117
|
|
|
860
1118
|
if r_idx == 0:
|
|
861
1119
|
ax.set_title(str(ref_name), fontsize=10)
|
|
862
1120
|
if c_idx == 0:
|
|
863
|
-
total_reads =
|
|
1121
|
+
total_reads = (
|
|
1122
|
+
sum(counts.get((sample_name, ref), 0) for ref in references)
|
|
1123
|
+
if not use_adata
|
|
1124
|
+
else int((adata.obs[sample_key] == sample_name).sum())
|
|
1125
|
+
)
|
|
864
1126
|
ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
|
|
865
1127
|
if r_idx != nrows - 1:
|
|
866
1128
|
ax.set_xticklabels([])
|
|
867
1129
|
|
|
868
1130
|
ax.grid(True, alpha=0.25)
|
|
869
1131
|
|
|
870
|
-
fig2.suptitle(
|
|
1132
|
+
fig2.suptitle(
|
|
1133
|
+
f"Union-find cluster size histograms — page {page + 1}/{n_pages}", fontsize=12, y=0.995
|
|
1134
|
+
)
|
|
871
1135
|
fig2.tight_layout(rect=[0, 0, 1, 0.96])
|
|
872
1136
|
|
|
873
1137
|
if output_directory:
|
|
874
|
-
fname2 = os.path.join(output_directory, f"cluster_size_histograms_page_{page+1}.png")
|
|
1138
|
+
fname2 = os.path.join(output_directory, f"cluster_size_histograms_page_{page + 1}.png")
|
|
875
1139
|
plt.savefig(fname2, bbox_inches="tight")
|
|
876
1140
|
cluster_size_pages.append(fname2)
|
|
877
1141
|
else:
|
|
@@ -923,7 +1187,9 @@ def plot_hamming_vs_metric_pages(
|
|
|
923
1187
|
|
|
924
1188
|
obs = adata.obs
|
|
925
1189
|
if sample_col not in obs.columns or ref_col not in obs.columns:
|
|
926
|
-
raise ValueError(
|
|
1190
|
+
raise ValueError(
|
|
1191
|
+
f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
|
|
1192
|
+
)
|
|
927
1193
|
|
|
928
1194
|
# canonicalize samples and refs
|
|
929
1195
|
if samples is None:
|
|
@@ -964,14 +1230,24 @@ def plot_hamming_vs_metric_pages(
|
|
|
964
1230
|
sY = pd.to_numeric(obs[hamming_col], errors="coerce") if hamming_present else None
|
|
965
1231
|
|
|
966
1232
|
if (sX is not None) and (sY is not None):
|
|
967
|
-
valid_both =
|
|
1233
|
+
valid_both = (
|
|
1234
|
+
sX.notna() & sY.notna() & np.isfinite(sX.values) & np.isfinite(sY.values)
|
|
1235
|
+
)
|
|
968
1236
|
if valid_both.any():
|
|
969
1237
|
xvals = sX[valid_both].to_numpy(dtype=float)
|
|
970
1238
|
yvals = sY[valid_both].to_numpy(dtype=float)
|
|
971
1239
|
xmin, xmax = float(np.nanmin(xvals)), float(np.nanmax(xvals))
|
|
972
1240
|
ymin, ymax = float(np.nanmin(yvals)), float(np.nanmax(yvals))
|
|
973
|
-
xpad =
|
|
974
|
-
|
|
1241
|
+
xpad = (
|
|
1242
|
+
max(1e-6, (xmax - xmin) * 0.05)
|
|
1243
|
+
if xmax > xmin
|
|
1244
|
+
else max(1e-3, abs(xmin) * 0.05 + 1e-3)
|
|
1245
|
+
)
|
|
1246
|
+
ypad = (
|
|
1247
|
+
max(1e-6, (ymax - ymin) * 0.05)
|
|
1248
|
+
if ymax > ymin
|
|
1249
|
+
else max(1e-3, abs(ymin) * 0.05 + 1e-3)
|
|
1250
|
+
)
|
|
975
1251
|
global_xlim = (xmin - xpad, xmax + xpad)
|
|
976
1252
|
global_ylim = (ymin - ypad, ymax + ypad)
|
|
977
1253
|
else:
|
|
@@ -979,23 +1255,39 @@ def plot_hamming_vs_metric_pages(
|
|
|
979
1255
|
sY_finite = sY[np.isfinite(sY)]
|
|
980
1256
|
if sX_finite.size > 0:
|
|
981
1257
|
xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
|
|
982
|
-
xpad =
|
|
1258
|
+
xpad = (
|
|
1259
|
+
max(1e-6, (xmax - xmin) * 0.05)
|
|
1260
|
+
if xmax > xmin
|
|
1261
|
+
else max(1e-3, abs(xmin) * 0.05 + 1e-3)
|
|
1262
|
+
)
|
|
983
1263
|
global_xlim = (xmin - xpad, xmax + xpad)
|
|
984
1264
|
if sY_finite.size > 0:
|
|
985
1265
|
ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
|
|
986
|
-
ypad =
|
|
1266
|
+
ypad = (
|
|
1267
|
+
max(1e-6, (ymax - ymin) * 0.05)
|
|
1268
|
+
if ymax > ymin
|
|
1269
|
+
else max(1e-3, abs(ymin) * 0.05 + 1e-3)
|
|
1270
|
+
)
|
|
987
1271
|
global_ylim = (ymin - ypad, ymax + ypad)
|
|
988
1272
|
elif sX is not None:
|
|
989
1273
|
sX_finite = sX[np.isfinite(sX)]
|
|
990
1274
|
if sX_finite.size > 0:
|
|
991
1275
|
xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
|
|
992
|
-
xpad =
|
|
1276
|
+
xpad = (
|
|
1277
|
+
max(1e-6, (xmax - xmin) * 0.05)
|
|
1278
|
+
if xmax > xmin
|
|
1279
|
+
else max(1e-3, abs(xmin) * 0.05 + 1e-3)
|
|
1280
|
+
)
|
|
993
1281
|
global_xlim = (xmin - xpad, xmax + xpad)
|
|
994
1282
|
elif sY is not None:
|
|
995
1283
|
sY_finite = sY[np.isfinite(sY)]
|
|
996
1284
|
if sY_finite.size > 0:
|
|
997
1285
|
ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
|
|
998
|
-
ypad =
|
|
1286
|
+
ypad = (
|
|
1287
|
+
max(1e-6, (ymax - ymin) * 0.05)
|
|
1288
|
+
if ymax > ymin
|
|
1289
|
+
else max(1e-3, abs(ymin) * 0.05 + 1e-3)
|
|
1290
|
+
)
|
|
999
1291
|
global_ylim = (ymin - ypad, ymax + ypad)
|
|
1000
1292
|
|
|
1001
1293
|
# pagination

@@ -1005,15 +1297,19 @@ def plot_hamming_vs_metric_pages(
         ncols = len(cols)
         fig_w = ncols * figsize_per_cell[0]
         fig_h = nrows * figsize_per_cell[1]
-        fig, axes = plt.subplots(
+        fig, axes = plt.subplots(
+            nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+        )

         for r_idx, sample_name in enumerate(chunk):
             for c_idx, ref_name in enumerate(cols):
                 ax = axes[r_idx][c_idx]
                 if ref_name == extra_col:
-                    mask =
+                    mask = obs[sample_col].values == sample_name
                 else:
-                    mask = (obs[sample_col].values == sample_name) & (
+                    mask = (obs[sample_col].values == sample_name) & (
+                        obs[ref_col].values == ref_name
+                    )

                 sub = obs[mask]


@@ -1022,7 +1318,9 @@ def plot_hamming_vs_metric_pages(
                 else:
                     x_all = np.array([], dtype=float)
                 if hamming_col in sub.columns:
-                    y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(
+                    y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(
+                        dtype=float
+                    )
                 else:
                     y_all = np.array([], dtype=float)


@@ -1040,32 +1338,67 @@ def plot_hamming_vs_metric_pages(
                     idxs_valid = np.array([], dtype=int)

                 if x.size == 0:
-                    ax.text(
+                    ax.text(
+                        0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes
+                    )
                     clusters_info[(sample_name, ref_name)] = {"diag": None, "n_points": 0}
                 else:
                     # Decide color mapping
-                    if
+                    if (
+                        color_by_duplicate
+                        and duplicate_col in adata.obs.columns
+                        and idxs_valid.size > 0
+                    ):
                         # get boolean series aligned to idxs_valid
                         try:
-                            dup_flags =
+                            dup_flags = (
+                                adata.obs.loc[idxs_valid, duplicate_col].astype(bool).to_numpy()
+                            )
                         except Exception:
                             dup_flags = np.zeros(len(idxs_valid), dtype=bool)
                         mask_dup = dup_flags
                         mask_nondup = ~mask_dup
                         # plot non-duplicates first in gray, duplicates in highlight color
                         if mask_nondup.any():
-                            ax.scatter(
+                            ax.scatter(
+                                x[mask_nondup],
+                                y[mask_nondup],
+                                s=12,
+                                alpha=0.6,
+                                rasterized=True,
+                                c="lightgray",
+                            )
                         if mask_dup.any():
-                            ax.scatter(
+                            ax.scatter(
+                                x[mask_dup],
+                                y[mask_dup],
+                                s=20,
+                                alpha=0.9,
+                                rasterized=True,
+                                c=highlight_color,
+                                edgecolors="k",
+                                linewidths=0.3,
+                            )
                     else:
                         # old behavior: highlight by threshold if requested
                         if highlight_threshold is not None and y.size:
                             mask_low = (y < float(highlight_threshold)) & np.isfinite(y)
                             mask_high = ~mask_low
                             if mask_high.any():
-                                ax.scatter(
+                                ax.scatter(
+                                    x[mask_high], y[mask_high], s=12, alpha=0.6, rasterized=True
+                                )
                             if mask_low.any():
-                                ax.scatter(
+                                ax.scatter(
+                                    x[mask_low],
+                                    y[mask_low],
+                                    s=18,
+                                    alpha=0.9,
+                                    rasterized=True,
+                                    c=highlight_color,
+                                    edgecolors="k",
+                                    linewidths=0.3,
+                                )
                         else:
                             ax.scatter(x, y, s=12, alpha=0.6, rasterized=True)
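
Both highlighting branches above follow the same layering pattern: draw the bulk of the points first in a muted style, then overplot the subset of interest (duplicates, or points below a Hamming threshold) with a stronger colour and edge. A self-contained sketch of the threshold variant on synthetic data; the threshold and colour values here are arbitrary, not package defaults:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, 500)  # stand-in for a per-read metric
y = rng.uniform(0.0, 0.3, 500)  # stand-in for a Hamming distance
highlight_threshold = 0.05      # illustrative cutoff
highlight_color = "crimson"     # illustrative colour

# Points below the threshold are "interesting"; NaNs never pass the comparison.
mask_low = (y < float(highlight_threshold)) & np.isfinite(y)
mask_high = ~mask_low

fig, ax = plt.subplots()
if mask_high.any():  # background layer: small, translucent markers
    ax.scatter(x[mask_high], y[mask_high], s=12, alpha=0.6, rasterized=True)
if mask_low.any():   # highlight layer drawn on top, with a thin black edge
    ax.scatter(
        x[mask_low],
        y[mask_low],
        s=18,
        alpha=0.9,
        c=highlight_color,
        edgecolors="k",
        linewidths=0.3,
        rasterized=True,
    )
plt.show()
```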


@@ -1081,7 +1414,9 @@ def plot_hamming_vs_metric_pages(
                             zi = gaussian_kde(np.vstack([x, y]))(coords).reshape(xi_g.shape)
                             ax.contourf(xi_g, yi_g, zi, levels=8, alpha=0.35, cmap="Blues")
                         else:
-                            ax.scatter(
+                            ax.scatter(
+                                x, y, c=kde2, s=16, cmap="viridis", alpha=0.7, linewidths=0
+                            )
                     except Exception:
                         pass


@@ -1090,16 +1425,29 @@ def plot_hamming_vs_metric_pages(
                         a, b = np.polyfit(x, y, 1)
                         xs = np.linspace(np.nanmin(x), np.nanmax(x), 100)
                         ys = a * xs + b
-                        ax.plot(
+                        ax.plot(
+                            xs, ys, linestyle="--", linewidth=1.2, alpha=0.9, color="red"
+                        )
                         r = np.corrcoef(x, y)[0, 1]
-                        ax.text(
-
+                        ax.text(
+                            0.98,
+                            0.02,
+                            f"r={float(r):.3f}",
+                            ha="right",
+                            va="bottom",
+                            transform=ax.transAxes,
+                            fontsize=8,
+                            bbox=dict(
+                                facecolor="white", alpha=0.6, boxstyle="round,pad=0.2"
+                            ),
+                        )
                     except Exception:
                         pass

                     if clustering:
                         cl_labels, diag = _run_clustering(
-                            x,
+                            x,
+                            y,
                             method=clustering.get("method", "dbscan"),
                             n_clusters=clustering.get("n_clusters", 2),
                             dbscan_eps=clustering.get("dbscan_eps", 0.05),

@@ -1113,32 +1461,59 @@ def plot_hamming_vs_metric_pages(
                         if len(unique_nonnoise) > 0:
                             medians = {}
                             for lab in unique_nonnoise:
-                                mask_lab =
-                                medians[lab] =
-
+                                mask_lab = cl_labels == lab
+                                medians[lab] = (
+                                    float(np.median(y[mask_lab]))
+                                    if mask_lab.any()
+                                    else float("nan")
+                                )
+                            sorted_by_median = sorted(
+                                unique_nonnoise,
+                                key=lambda idx: (
+                                    np.nan if np.isnan(medians[idx]) else medians[idx]
+                                ),
+                                reverse=True,
+                            )
                             mapping = {old: new for new, old in enumerate(sorted_by_median)}
                             for i_lab in range(len(remapped_labels)):
                                 if remapped_labels[i_lab] != -1:
-                                    remapped_labels[i_lab] = mapping.get(
+                                    remapped_labels[i_lab] = mapping.get(
+                                        remapped_labels[i_lab], -1
+                                    )
                             diag = diag or {}
-                            diag["cluster_median_hamming"] = {
-
+                            diag["cluster_median_hamming"] = {
+                                int(old): medians[old] for old in medians
+                            }
+                            diag["cluster_old_to_new_map"] = {
+                                int(old): int(new) for old, new in mapping.items()
+                            }
                         else:
                             remapped_labels = cl_labels.copy()
                             diag = diag or {}
                             diag["cluster_median_hamming"] = {}
                             diag["cluster_old_to_new_map"] = {}

-                        _overlay_clusters_on_ax(
-
-
-
+                        _overlay_clusters_on_ax(
+                            ax,
+                            x,
+                            y,
+                            remapped_labels,
+                            diag,
+                            cmap=clustering.get("cmap", "tab10"),
+                            hull=clustering.get("hull", True),
+                            show_cluster_labels=True,
+                        )

-                        clusters_info[(sample_name, ref_name)] = {
+                        clusters_info[(sample_name, ref_name)] = {
+                            "diag": diag,
+                            "n_points": len(x),
+                        }

                         if write_clusters_to_adata and idxs_valid.size > 0:
-                            colname_safe_ref =
-                            colname =
+                            colname_safe_ref = ref_name if ref_name != extra_col else "ALLREFS"
+                            colname = (
+                                f"hamming_cluster__{metric}__{sample_name}__{colname_safe_ref}"
+                            )
                             if colname not in adata.obs.columns:
                                 adata.obs[colname] = np.nan
                             lab_arr = remapped_labels.astype(float)
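
The block above renumbers cluster labels so that label 0 corresponds to the cluster with the highest median y (the Hamming metric on this plot's y-axis), while noise points keep the label -1, before the labels are written back into `adata.obs`. A toy, self-contained illustration of that renumbering (the data are invented for the example):

```python
import numpy as np

y = np.array([0.90, 0.80, 0.10, 0.30, 0.50, 0.60, 0.20])
cl_labels = np.array([2, 2, 0, 0, 1, 1, -1])  # raw clustering output; -1 = noise

remapped_labels = cl_labels.copy()
unique_nonnoise = [int(lab) for lab in np.unique(cl_labels) if lab != -1]

# Median y per cluster, then sort clusters so the highest median becomes label 0.
medians = {lab: float(np.median(y[cl_labels == lab])) for lab in unique_nonnoise}
sorted_by_median = sorted(unique_nonnoise, key=lambda lab: medians[lab], reverse=True)
mapping = {old: new for new, old in enumerate(sorted_by_median)}

for i in range(len(remapped_labels)):
    if remapped_labels[i] != -1:
        remapped_labels[i] = mapping.get(remapped_labels[i], -1)

print(mapping)           # {2: 0, 1: 1, 0: 2}
print(remapped_labels)   # [ 0  0  2  2  1  1 -1]
```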

@@ -1178,7 +1553,11 @@ def plot_hamming_vs_metric_pages(
             plt.show()
             plt.close(fig)

-        saved_map[metric] = {
+        saved_map[metric] = {
+            "files": files,
+            "clusters_info": clusters_info,
+            "written_cols": written_cols,
+        }

     return saved_map


@@ -1187,7 +1566,7 @@ def _run_clustering(
     x: np.ndarray,
     y: np.ndarray,
     *,
-    method: str = "kmeans",
+    method: str = "kmeans",  # "kmeans", "dbscan", "gmm", "hdbscan"
     n_clusters: int = 2,
     dbscan_eps: float = 0.05,
     dbscan_min_samples: int = 5,

@@ -1199,9 +1578,9 @@ def _run_clustering(
     Labels follow sklearn conventions (noise -> -1 for DBSCAN/HDBSCAN).
     """
     try:
-        from sklearn.cluster import
-        from sklearn.mixture import GaussianMixture
+        from sklearn.cluster import DBSCAN, KMeans
         from sklearn.metrics import silhouette_score
+        from sklearn.mixture import GaussianMixture
     except Exception:
         KMeans = DBSCAN = GaussianMixture = silhouette_score = None
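
`_run_clustering` treats scikit-learn as an optional dependency: the imports are wrapped in try/except, the names fall back to `None`, and later code checks for `None` before using them. A minimal sketch of the same pattern (a standalone example, not the package's helper; `n_init="auto"` assumes a reasonably recent scikit-learn):

```python
import numpy as np

try:
    from sklearn.cluster import KMeans
except Exception:  # scikit-learn missing or broken: degrade gracefully
    KMeans = None


def cluster_points(pts, n_clusters=2):
    """Return integer labels, or all -1 when scikit-learn is unavailable."""
    if KMeans is None:
        return np.full(len(pts), -1, dtype=int)
    return KMeans(n_clusters=n_clusters, n_init="auto").fit_predict(pts)


pts = np.random.default_rng(0).normal(size=(50, 2))
print(np.unique(cluster_points(pts)))  # [0 1] with sklearn installed, [-1] without
```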


@@ -1270,7 +1649,11 @@ def _run_clustering(

     # compute silhouette if suitable
     try:
-        if
+        if (
+            diagnostics.get("n_clusters_found", 0) >= 2
+            and len(x) >= 3
+            and silhouette_score is not None
+        ):
             diagnostics["silhouette"] = float(silhouette_score(pts, labels))
         else:
             diagnostics["silhouette"] = None
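
The guard above exists because `silhouette_score` raises a `ValueError` unless there are at least two distinct labels and more samples than labels, so the diagnostics step checks the cluster count, the point count, and that scikit-learn actually imported before calling it. A small standalone illustration, assuming scikit-learn is installed:

```python
import numpy as np
from sklearn.metrics import silhouette_score

pts = np.array([[0.0, 0.0], [0.1, 0.1], [1.0, 1.0], [1.1, 0.9]])
labels = np.array([0, 0, 1, 1])

diagnostics = {"n_clusters_found": len(set(labels.tolist()) - {-1})}
if diagnostics["n_clusters_found"] >= 2 and len(pts) >= 3:
    diagnostics["silhouette"] = float(silhouette_score(pts, labels))
else:
    diagnostics["silhouette"] = None  # e.g. one cluster found, or too few points

print(diagnostics)  # high silhouette here, since the two clusters are well separated
```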

@@ -1305,7 +1688,6 @@ def _overlay_clusters_on_ax(
     Labels == -1 are noise and drawn in grey.
     Also annotates cluster numbers near centroids (contiguous numbers starting at 0).
     """
-    import matplotlib.colors as mcolors
     from scipy.spatial import ConvexHull

     labels = np.asarray(labels)

@@ -1323,19 +1705,47 @@ def _overlay_clusters_on_ax(
         if not mask.any():
             continue
         col = (0.6, 0.6, 0.6, 0.6) if lab == -1 else colors[idx % ncolors]
-        ax.scatter(
+        ax.scatter(
+            x[mask],
+            y[mask],
+            s=20,
+            c=[col],
+            alpha=alpha_pts,
+            marker=marker,
+            linewidths=0.2,
+            edgecolors="none",
+            rasterized=True,
+        )

         if lab != -1:
             # centroid
             if plot_centroids:
                 cx = float(np.mean(x[mask]))
                 cy = float(np.mean(y[mask]))
-                ax.scatter(
+                ax.scatter(
+                    [cx],
+                    [cy],
+                    s=centroid_size,
+                    marker=centroid_marker,
+                    c=[col],
+                    edgecolor="k",
+                    linewidth=0.6,
+                    zorder=10,
+                )

                 if show_cluster_labels:
-                    ax.text(
-
-
+                    ax.text(
+                        cx,
+                        cy,
+                        str(int(lab)),
+                        color="white",
+                        fontsize=cluster_label_fontsize,
+                        ha="center",
+                        va="center",
+                        weight="bold",
+                        zorder=12,
+                        bbox=dict(facecolor=(0, 0, 0, 0.5), pad=0.3, boxstyle="round"),
+                    )

             # hull
             if hull and np.sum(mask) >= 3:

@@ -1343,9 +1753,16 @@ def _overlay_clusters_on_ax(
                 ch_pts = pts[mask]
                 hull_idx = ConvexHull(ch_pts).vertices
                 hull_poly = ch_pts[hull_idx]
-                ax.fill(
+                ax.fill(
+                    hull_poly[:, 0],
+                    hull_poly[:, 1],
+                    alpha=hull_alpha,
+                    facecolor=col,
+                    edgecolor=hull_edgecolor,
+                    linewidth=0.6,
+                    zorder=5,
+                )
             except Exception:
                 pass

     return None
-
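
For the hull overlay in `_overlay_clusters_on_ax`, the recipe is: take the cluster's points, ask `scipy.spatial.ConvexHull` for the hull vertex indices, and fill that polygon behind the scatter. A self-contained sketch on random data (the styling values are arbitrary, not the package defaults):

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull

rng = np.random.default_rng(1)
pts = rng.normal(loc=(0.5, 0.5), scale=0.08, size=(40, 2))

fig, ax = plt.subplots()
ax.scatter(pts[:, 0], pts[:, 1], s=20, c="tab:blue", alpha=0.8)

if len(pts) >= 3:  # ConvexHull needs at least 3 points in 2-D
    try:
        hull_poly = pts[ConvexHull(pts).vertices]  # hull vertices, counter-clockwise
        ax.fill(
            hull_poly[:, 0],
            hull_poly[:, 1],
            alpha=0.15,
            facecolor="tab:blue",
            edgecolor="k",
            linewidth=0.6,
            zorder=0,
        )
    except Exception:
        pass  # degenerate geometry (e.g. collinear points); just skip the hull
plt.show()
```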