smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +43 -13
- smftools/_settings.py +6 -6
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +9 -1
- smftools/cli/hmm_adata.py +905 -242
- smftools/cli/load_adata.py +432 -280
- smftools/cli/preprocess_adata.py +287 -171
- smftools/cli/spatial_adata.py +141 -53
- smftools/cli_entry.py +119 -178
- smftools/config/__init__.py +3 -1
- smftools/config/conversion.yaml +5 -1
- smftools/config/deaminase.yaml +1 -1
- smftools/config/default.yaml +26 -18
- smftools/config/direct.yaml +8 -3
- smftools/config/discover_input_files.py +19 -5
- smftools/config/experiment_config.py +511 -276
- smftools/constants.py +37 -0
- smftools/datasets/__init__.py +4 -8
- smftools/datasets/datasets.py +32 -18
- smftools/hmm/HMM.py +2133 -1428
- smftools/hmm/__init__.py +24 -14
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +18 -1
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +176 -193
- smftools/hmm/display_hmm.py +23 -7
- smftools/hmm/hmm_readwrite.py +20 -6
- smftools/hmm/nucleosome_hmm_refinement.py +104 -14
- smftools/informatics/__init__.py +55 -13
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +9 -1
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1059 -269
- smftools/informatics/basecalling.py +53 -9
- smftools/informatics/bed_functions.py +357 -114
- smftools/informatics/binarize_converted_base_identities.py +21 -7
- smftools/informatics/complement_base_list.py +9 -6
- smftools/informatics/converted_BAM_to_adata.py +324 -137
- smftools/informatics/fasta_functions.py +251 -89
- smftools/informatics/h5ad_functions.py +202 -30
- smftools/informatics/modkit_extract_to_adata.py +623 -274
- smftools/informatics/modkit_functions.py +87 -44
- smftools/informatics/ohe.py +46 -21
- smftools/informatics/pod5_functions.py +114 -74
- smftools/informatics/run_multiqc.py +20 -14
- smftools/logging_utils.py +51 -0
- smftools/machine_learning/__init__.py +23 -12
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +157 -50
- smftools/machine_learning/data/preprocessing.py +4 -1
- smftools/machine_learning/evaluation/__init__.py +3 -1
- smftools/machine_learning/evaluation/eval_utils.py +13 -14
- smftools/machine_learning/evaluation/evaluators.py +52 -34
- smftools/machine_learning/inference/__init__.py +3 -1
- smftools/machine_learning/inference/inference_utils.py +9 -4
- smftools/machine_learning/inference/lightning_inference.py +14 -13
- smftools/machine_learning/inference/sklearn_inference.py +8 -8
- smftools/machine_learning/inference/sliding_window_inference.py +37 -25
- smftools/machine_learning/models/__init__.py +12 -5
- smftools/machine_learning/models/base.py +34 -43
- smftools/machine_learning/models/cnn.py +22 -13
- smftools/machine_learning/models/lightning_base.py +78 -42
- smftools/machine_learning/models/mlp.py +18 -5
- smftools/machine_learning/models/positional.py +10 -4
- smftools/machine_learning/models/rnn.py +8 -3
- smftools/machine_learning/models/sklearn_models.py +46 -24
- smftools/machine_learning/models/transformer.py +75 -55
- smftools/machine_learning/models/wrappers.py +8 -3
- smftools/machine_learning/training/__init__.py +4 -2
- smftools/machine_learning/training/train_lightning_model.py +42 -23
- smftools/machine_learning/training/train_sklearn_model.py +11 -15
- smftools/machine_learning/utils/__init__.py +3 -1
- smftools/machine_learning/utils/device.py +12 -5
- smftools/machine_learning/utils/grl.py +8 -2
- smftools/metadata.py +443 -0
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +32 -17
- smftools/plotting/autocorrelation_plotting.py +153 -48
- smftools/plotting/classifiers.py +175 -73
- smftools/plotting/general_plotting.py +350 -168
- smftools/plotting/hmm_plotting.py +53 -14
- smftools/plotting/position_stats.py +155 -87
- smftools/plotting/qc_plotting.py +25 -12
- smftools/preprocessing/__init__.py +35 -37
- smftools/preprocessing/append_base_context.py +105 -79
- smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
- smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
- smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
- smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
- smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
- smftools/preprocessing/binarize.py +21 -4
- smftools/preprocessing/binarize_on_Youden.py +127 -31
- smftools/preprocessing/binary_layers_to_ohe.py +18 -11
- smftools/preprocessing/calculate_complexity_II.py +89 -59
- smftools/preprocessing/calculate_consensus.py +28 -19
- smftools/preprocessing/calculate_coverage.py +44 -22
- smftools/preprocessing/calculate_pairwise_differences.py +4 -1
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
- smftools/preprocessing/calculate_position_Youden.py +110 -55
- smftools/preprocessing/calculate_read_length_stats.py +52 -23
- smftools/preprocessing/calculate_read_modification_stats.py +91 -57
- smftools/preprocessing/clean_NaN.py +38 -28
- smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
- smftools/preprocessing/flag_duplicate_reads.py +708 -303
- smftools/preprocessing/invert_adata.py +26 -11
- smftools/preprocessing/load_sample_sheet.py +40 -22
- smftools/preprocessing/make_dirs.py +9 -3
- smftools/preprocessing/min_non_diagonal.py +4 -1
- smftools/preprocessing/recipes.py +58 -23
- smftools/preprocessing/reindex_references_adata.py +93 -27
- smftools/preprocessing/subsample_adata.py +33 -16
- smftools/readwrite.py +264 -109
- smftools/schema/__init__.py +11 -0
- smftools/schema/anndata_schema_v1.yaml +227 -0
- smftools/tools/__init__.py +25 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +165 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +12 -1
- smftools/tools/archived/subset_adata_v2.py +14 -1
- smftools/tools/calculate_umap.py +56 -15
- smftools/tools/cluster_adata_on_methylation.py +122 -47
- smftools/tools/general_tools.py +70 -25
- smftools/tools/position_stats.py +220 -99
- smftools/tools/read_stats.py +50 -29
- smftools/tools/spatial_autocorrelation.py +365 -192
- smftools/tools/subset_adata.py +23 -21
- smftools-0.3.0.dist-info/METADATA +147 -0
- smftools-0.3.0.dist-info/RECORD +182 -0
- smftools-0.2.4.dist-info/METADATA +0 -141
- smftools-0.2.4.dist-info/RECORD +0 -176
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,52 +1,76 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
# duplicate_detection_with_hier_and_plots.py
|
|
2
4
|
import copy
|
|
3
|
-
import warnings
|
|
4
5
|
import math
|
|
5
6
|
import os
|
|
7
|
+
import warnings
|
|
6
8
|
from collections import defaultdict
|
|
7
|
-
from
|
|
9
|
+
from importlib.util import find_spec
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
8
11
|
|
|
9
|
-
import torch
|
|
10
|
-
import anndata as ad
|
|
11
12
|
import numpy as np
|
|
12
13
|
import pandas as pd
|
|
13
|
-
|
|
14
|
-
from
|
|
14
|
+
from scipy.cluster import hierarchy as sch
|
|
15
|
+
from scipy.spatial.distance import pdist, squareform
|
|
16
|
+
from scipy.stats import gaussian_kde
|
|
17
|
+
|
|
18
|
+
from smftools.logging_utils import get_logger
|
|
19
|
+
from smftools.optional_imports import require
|
|
15
20
|
|
|
16
21
|
from ..readwrite import make_dirs
|
|
17
22
|
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
23
|
+
logger = get_logger(__name__)
|
|
24
|
+
|
|
25
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="duplicate read plots")
|
|
26
|
+
torch = require("torch", extra="torch", purpose="duplicate read detection")
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
import anndata as ad
|
|
30
|
+
|
|
31
|
+
SCIPY_AVAILABLE = True
|
|
32
|
+
SKLEARN_AVAILABLE = find_spec("sklearn") is not None
|
|
33
|
+
|
|
34
|
+
PCA = None
|
|
35
|
+
KMeans = DBSCAN = GaussianMixture = silhouette_score = None
|
|
36
|
+
if SKLEARN_AVAILABLE:
|
|
37
|
+
sklearn_cluster = require(
|
|
38
|
+
"sklearn.cluster",
|
|
39
|
+
extra="ml-base",
|
|
40
|
+
purpose="duplicate read clustering",
|
|
41
|
+
)
|
|
42
|
+
sklearn_decomp = require(
|
|
43
|
+
"sklearn.decomposition",
|
|
44
|
+
extra="ml-base",
|
|
45
|
+
purpose="duplicate read PCA",
|
|
46
|
+
)
|
|
47
|
+
sklearn_metrics = require(
|
|
48
|
+
"sklearn.metrics",
|
|
49
|
+
extra="ml-base",
|
|
50
|
+
purpose="duplicate read clustering diagnostics",
|
|
51
|
+
)
|
|
52
|
+
sklearn_mixture = require(
|
|
53
|
+
"sklearn.mixture",
|
|
54
|
+
extra="ml-base",
|
|
55
|
+
purpose="duplicate read clustering",
|
|
56
|
+
)
|
|
57
|
+
DBSCAN = sklearn_cluster.DBSCAN
|
|
58
|
+
KMeans = sklearn_cluster.KMeans
|
|
59
|
+
PCA = sklearn_decomp.PCA
|
|
60
|
+
silhouette_score = sklearn_metrics.silhouette_score
|
|
61
|
+
GaussianMixture = sklearn_mixture.GaussianMixture
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer: str = "orig") -> dict:
|
|
65
|
+
"""Merge two ``.uns`` dictionaries while preserving preferred values.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
orig_uns: Original ``.uns`` dictionary.
|
|
69
|
+
new_uns: New ``.uns`` dictionary to merge.
|
|
70
|
+
prefer: Which dictionary to prefer on conflict (``"orig"`` or ``"new"``).
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
dict: Merged dictionary.
|
|
50
74
|
"""
|
|
51
75
|
out = copy.deepcopy(new_uns) if new_uns is not None else {}
|
|
52
76
|
for k, v in (orig_uns or {}).items():
|
|
@@ -55,7 +79,7 @@ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
|
|
|
55
79
|
else:
|
|
56
80
|
# present in both: compare quickly (best-effort)
|
|
57
81
|
try:
|
|
58
|
-
equal =
|
|
82
|
+
equal = out[k] == v
|
|
59
83
|
except Exception:
|
|
60
84
|
equal = False
|
|
61
85
|
if equal:
|
|
@@ -69,9 +93,10 @@ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
|
|
|
69
93
|
out[f"orig_uns__{k}"] = copy.deepcopy(v)
|
|
70
94
|
return out
|
|
71
95
|
|
|
96
|
+
|
|
72
97
|
def flag_duplicate_reads(
|
|
73
|
-
adata,
|
|
74
|
-
var_filters_sets,
|
|
98
|
+
adata: ad.AnnData,
|
|
99
|
+
var_filters_sets: Sequence[dict[str, Any]],
|
|
75
100
|
distance_threshold: float = 0.07,
|
|
76
101
|
obs_reference_col: str = "Reference_strand",
|
|
77
102
|
sample_col: str = "Barcode",
|
|
@@ -81,7 +106,7 @@ def flag_duplicate_reads(
|
|
|
81
106
|
uns_filtered_flag: str = "read_duplicates_removed",
|
|
82
107
|
bypass: bool = False,
|
|
83
108
|
force_redo: bool = False,
|
|
84
|
-
keep_best_metric: Optional[str] =
|
|
109
|
+
keep_best_metric: Optional[str] = "read_quality",
|
|
85
110
|
keep_best_higher: bool = True,
|
|
86
111
|
window_size: int = 50,
|
|
87
112
|
min_overlap_positions: int = 20,
|
|
@@ -93,19 +118,119 @@ def flag_duplicate_reads(
|
|
|
93
118
|
hierarchical_metric: str = "euclidean",
|
|
94
119
|
hierarchical_window: int = 50,
|
|
95
120
|
random_state: int = 0,
|
|
96
|
-
|
|
121
|
+
demux_types: Optional[Sequence[str]] = None,
|
|
122
|
+
demux_col: str = "demux_type",
|
|
123
|
+
) -> ad.AnnData:
|
|
124
|
+
"""Flag duplicate reads with demux-aware keeper preference.
|
|
125
|
+
|
|
126
|
+
Behavior:
|
|
127
|
+
- All reads are processed (no masking by demux).
|
|
128
|
+
- At each keeper decision, prefer reads whose ``demux_col`` value is in
|
|
129
|
+
``demux_types`` when present. Among candidates, choose by
|
|
130
|
+
``keep_best_metric``.
|
|
131
|
+
|
|
132
|
+
Args:
|
|
133
|
+
adata: AnnData object to process.
|
|
134
|
+
var_filters_sets: Sequence of variable filter definitions.
|
|
135
|
+
distance_threshold: Distance threshold for duplicate detection.
|
|
136
|
+
obs_reference_col: Obs column containing reference identifiers.
|
|
137
|
+
sample_col: Obs column containing sample identifiers.
|
|
138
|
+
output_directory: Directory for output plots and artifacts.
|
|
139
|
+
metric_keys: Metric key(s) used in processing.
|
|
140
|
+
uns_flag: Flag in ``adata.uns`` indicating prior completion.
|
|
141
|
+
uns_filtered_flag: Flag to mark read duplicates removal.
|
|
142
|
+
bypass: Whether to skip processing.
|
|
143
|
+
force_redo: Whether to rerun even if ``uns_flag`` is set.
|
|
144
|
+
keep_best_metric: Obs column used to select best read within duplicates.
|
|
145
|
+
keep_best_higher: Whether higher values in ``keep_best_metric`` are preferred.
|
|
146
|
+
window_size: Window size for local comparisons.
|
|
147
|
+
min_overlap_positions: Minimum overlapping positions required.
|
|
148
|
+
do_pca: Whether to run PCA before clustering.
|
|
149
|
+
pca_n_components: Number of PCA components.
|
|
150
|
+
pca_center: Whether to center data before PCA.
|
|
151
|
+
do_hierarchical: Whether to run hierarchical clustering.
|
|
152
|
+
hierarchical_linkage: Linkage method for hierarchical clustering.
|
|
153
|
+
hierarchical_metric: Distance metric for hierarchical clustering.
|
|
154
|
+
hierarchical_window: Window size for hierarchical clustering.
|
|
155
|
+
random_state: Random seed.
|
|
156
|
+
demux_types: Preferred demux types for keeper selection.
|
|
157
|
+
demux_col: Obs column containing demux type labels.
|
|
158
|
+
|
|
159
|
+
Returns:
|
|
160
|
+
anndata.AnnData: AnnData object with duplicate flags stored in ``adata.obs``.
|
|
97
161
|
"""
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
162
|
+
import copy
|
|
163
|
+
import warnings
|
|
164
|
+
|
|
165
|
+
import anndata as ad
|
|
166
|
+
import numpy as np
|
|
167
|
+
import pandas as pd
|
|
168
|
+
|
|
169
|
+
# -------- helper: demux-aware keeper selection --------
|
|
170
|
+
def _choose_keeper_with_demux_preference(
|
|
171
|
+
members_idx: List[int],
|
|
172
|
+
adata_subset: ad.AnnData,
|
|
173
|
+
obs_index_list: List[Any],
|
|
174
|
+
*,
|
|
175
|
+
demux_col: str = "demux_type",
|
|
176
|
+
preferred_types: Optional[Sequence[str]] = None,
|
|
177
|
+
keep_best_metric: Optional[str] = None,
|
|
178
|
+
keep_best_higher: bool = True,
|
|
179
|
+
lex_keeper_mask: Optional[np.ndarray] = None, # aligned to members order
|
|
180
|
+
) -> int:
|
|
181
|
+
"""
|
|
182
|
+
Prefer members whose demux_col ∈ preferred_types.
|
|
183
|
+
Among candidates, pick by keep_best_metric (higher/lower).
|
|
184
|
+
If metric missing/NaN, prefer lex keeper (via mask) among candidates.
|
|
185
|
+
Fallback: first candidate.
|
|
186
|
+
Returns the chosen *member index* (int from members_idx).
|
|
187
|
+
"""
|
|
188
|
+
# 1) demux-preferred candidates
|
|
189
|
+
if preferred_types and (demux_col in adata_subset.obs.columns):
|
|
190
|
+
preferred = set(map(str, preferred_types))
|
|
191
|
+
demux_series = adata_subset.obs[demux_col].astype("string")
|
|
192
|
+
names = [obs_index_list[m] for m in members_idx]
|
|
193
|
+
is_pref = demux_series.loc[names].isin(preferred).to_numpy()
|
|
194
|
+
candidates = [members_idx[i] for i, ok in enumerate(is_pref) if ok]
|
|
195
|
+
else:
|
|
196
|
+
candidates = []
|
|
101
197
|
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
198
|
+
if not candidates:
|
|
199
|
+
candidates = list(members_idx)
|
|
200
|
+
|
|
201
|
+
# 2) metric-based within candidates
|
|
202
|
+
if keep_best_metric and (keep_best_metric in adata_subset.obs.columns):
|
|
203
|
+
cand_names = [obs_index_list[m] for m in candidates]
|
|
204
|
+
try:
|
|
205
|
+
vals = pd.to_numeric(
|
|
206
|
+
adata_subset.obs.loc[cand_names, keep_best_metric],
|
|
207
|
+
errors="coerce",
|
|
208
|
+
).to_numpy(dtype=float)
|
|
209
|
+
except Exception:
|
|
210
|
+
vals = np.array([np.nan] * len(candidates), dtype=float)
|
|
211
|
+
|
|
212
|
+
if not np.all(np.isnan(vals)):
|
|
213
|
+
if keep_best_higher:
|
|
214
|
+
vals = np.where(np.isnan(vals), -np.inf, vals)
|
|
215
|
+
return candidates[int(np.nanargmax(vals))]
|
|
216
|
+
else:
|
|
217
|
+
vals = np.where(np.isnan(vals), np.inf, vals)
|
|
218
|
+
return candidates[int(np.nanargmin(vals))]
|
|
219
|
+
|
|
220
|
+
# 3) metric unhelpful — prefer lex keeper if provided
|
|
221
|
+
if lex_keeper_mask is not None:
|
|
222
|
+
for i, midx in enumerate(members_idx):
|
|
223
|
+
if (midx in candidates) and bool(lex_keeper_mask[i]):
|
|
224
|
+
return midx
|
|
225
|
+
|
|
226
|
+
# 4) fallback
|
|
227
|
+
return candidates[0]
|
|
228
|
+
|
|
229
|
+
# -------- early exits --------
|
|
105
230
|
already = bool(adata.uns.get(uns_flag, False))
|
|
106
|
-
if
|
|
231
|
+
if already and not force_redo:
|
|
107
232
|
if "is_duplicate" in adata.obs.columns:
|
|
108
|
-
adata_unique = adata[adata.obs["is_duplicate"]
|
|
233
|
+
adata_unique = adata[~adata.obs["is_duplicate"]].copy()
|
|
109
234
|
return adata_unique, adata
|
|
110
235
|
else:
|
|
111
236
|
return adata.copy(), adata.copy()
|
|
@@ -117,17 +242,23 @@ def flag_duplicate_reads(
|
|
|
117
242
|
|
|
118
243
|
# local UnionFind
|
|
119
244
|
class UnionFind:
|
|
245
|
+
"""Disjoint-set union-find helper for clustering indices."""
|
|
246
|
+
|
|
120
247
|
def __init__(self, size):
|
|
248
|
+
"""Initialize parent pointers for the union-find."""
|
|
121
249
|
self.parent = list(range(size))
|
|
122
250
|
|
|
123
251
|
def find(self, x):
|
|
252
|
+
"""Find the root for a member with path compression."""
|
|
124
253
|
while self.parent[x] != x:
|
|
125
254
|
self.parent[x] = self.parent[self.parent[x]]
|
|
126
255
|
x = self.parent[x]
|
|
127
256
|
return x
|
|
128
257
|
|
|
129
258
|
def union(self, x, y):
|
|
130
|
-
|
|
259
|
+
"""Union the sets that contain x and y."""
|
|
260
|
+
rx = self.find(x)
|
|
261
|
+
ry = self.find(y)
|
|
131
262
|
if rx != ry:
|
|
132
263
|
self.parent[ry] = rx
|
|
133
264
|
|
|
@@ -139,14 +270,14 @@ def flag_duplicate_reads(
|
|
|
139
270
|
|
|
140
271
|
for sample in samples:
|
|
141
272
|
for ref in references:
|
|
142
|
-
|
|
273
|
+
logger.info("Processing sample=%s ref=%s", sample, ref)
|
|
143
274
|
sample_mask = adata.obs[sample_col] == sample
|
|
144
275
|
ref_mask = adata.obs[obs_reference_col] == ref
|
|
145
276
|
subset_mask = sample_mask & ref_mask
|
|
146
277
|
adata_subset = adata[subset_mask].copy()
|
|
147
278
|
|
|
148
279
|
if adata_subset.n_obs < 2:
|
|
149
|
-
|
|
280
|
+
logger.info(" Skipping %s_%s (too few reads)", sample, ref)
|
|
150
281
|
continue
|
|
151
282
|
|
|
152
283
|
N = adata_subset.shape[0]
|
|
@@ -162,7 +293,12 @@ def flag_duplicate_reads(
|
|
|
162
293
|
|
|
163
294
|
selected_cols = adata.var.index[combined_mask.tolist()].to_list()
|
|
164
295
|
col_indices = [adata.var.index.get_loc(c) for c in selected_cols]
|
|
165
|
-
|
|
296
|
+
logger.info(
|
|
297
|
+
" Selected %s columns out of %s for %s",
|
|
298
|
+
len(col_indices),
|
|
299
|
+
adata.var.shape[0],
|
|
300
|
+
ref,
|
|
301
|
+
)
|
|
166
302
|
|
|
167
303
|
# Extract data matrix (dense numpy) for the subset
|
|
168
304
|
X = adata_subset.X
|
|
@@ -187,7 +323,10 @@ def flag_duplicate_reads(
|
|
|
187
323
|
hierarchical_found_dists = []
|
|
188
324
|
|
|
189
325
|
# Lexicographic windowed pass function
|
|
190
|
-
def cluster_pass(
|
|
326
|
+
def cluster_pass(
|
|
327
|
+
X_tensor_local, reverse=False, window=int(window_size), record_distances=False
|
|
328
|
+
):
|
|
329
|
+
"""Perform a lexicographic windowed clustering pass."""
|
|
191
330
|
N_local = X_tensor_local.shape[0]
|
|
192
331
|
X_sortable = X_tensor_local.clone().nan_to_num(-1.0)
|
|
193
332
|
sort_keys = [tuple(row.numpy().tolist()) for row in X_sortable]
|
|
@@ -208,10 +347,16 @@ def flag_duplicate_reads(
|
|
|
208
347
|
if enough_overlap.any():
|
|
209
348
|
diffs = (row_i_exp != block_rows) & valid_mask
|
|
210
349
|
hamming_counts = diffs.sum(dim=1).float()
|
|
211
|
-
hamming_dists = torch.where(
|
|
350
|
+
hamming_dists = torch.where(
|
|
351
|
+
valid_counts > 0,
|
|
352
|
+
hamming_counts / valid_counts,
|
|
353
|
+
torch.tensor(float("nan")),
|
|
354
|
+
)
|
|
212
355
|
# record distances (legacy list of all local comparisons)
|
|
213
356
|
hamming_np = hamming_dists.cpu().numpy().tolist()
|
|
214
|
-
local_hamming_dists.extend(
|
|
357
|
+
local_hamming_dists.extend(
|
|
358
|
+
[float(x) for x in hamming_np if (not np.isnan(x))]
|
|
359
|
+
)
|
|
215
360
|
matches = (hamming_dists < distance_threshold) & (enough_overlap)
|
|
216
361
|
for offset_local, m in enumerate(matches):
|
|
217
362
|
if m:
|
|
@@ -223,20 +368,28 @@ def flag_duplicate_reads(
|
|
|
223
368
|
next_local_idx = i + 1
|
|
224
369
|
if next_local_idx < len(sorted_X):
|
|
225
370
|
next_global = sorted_idx[next_local_idx]
|
|
226
|
-
vm_pair = (~torch.isnan(row_i)) & (
|
|
371
|
+
vm_pair = (~torch.isnan(row_i)) & (
|
|
372
|
+
~torch.isnan(sorted_X[next_local_idx])
|
|
373
|
+
)
|
|
227
374
|
vc = vm_pair.sum().item()
|
|
228
375
|
if vc >= min_overlap_positions:
|
|
229
|
-
d = float(
|
|
376
|
+
d = float(
|
|
377
|
+
(
|
|
378
|
+
(row_i[vm_pair] != sorted_X[next_local_idx][vm_pair])
|
|
379
|
+
.sum()
|
|
380
|
+
.item()
|
|
381
|
+
)
|
|
382
|
+
/ vc
|
|
383
|
+
)
|
|
230
384
|
if reverse:
|
|
231
385
|
rev_hamming_to_prev[next_global] = d
|
|
232
386
|
else:
|
|
233
387
|
fwd_hamming_to_next[sorted_idx[i]] = d
|
|
234
388
|
return cluster_pairs_local
|
|
235
389
|
|
|
236
|
-
# run forward
|
|
390
|
+
# run forward & reverse windows
|
|
237
391
|
pairs_fwd = cluster_pass(X_tensor, reverse=False, record_distances=True)
|
|
238
392
|
involved_in_fwd = set([p for pair in pairs_fwd for p in pair])
|
|
239
|
-
# build mask for reverse pass to avoid re-checking items already paired
|
|
240
393
|
mask_for_rev = np.ones(N, dtype=bool)
|
|
241
394
|
if len(involved_in_fwd) > 0:
|
|
242
395
|
for idx in involved_in_fwd:
|
|
@@ -245,8 +398,9 @@ def flag_duplicate_reads(
|
|
|
245
398
|
if len(rev_idx_map) > 0:
|
|
246
399
|
reduced_tensor = X_tensor[rev_idx_map]
|
|
247
400
|
pairs_rev_local = cluster_pass(reduced_tensor, reverse=True, record_distances=True)
|
|
248
|
-
|
|
249
|
-
|
|
401
|
+
remapped_rev_pairs = [
|
|
402
|
+
(int(rev_idx_map[i]), int(rev_idx_map[j])) for (i, j) in pairs_rev_local
|
|
403
|
+
]
|
|
250
404
|
else:
|
|
251
405
|
remapped_rev_pairs = []
|
|
252
406
|
|
|
@@ -265,53 +419,41 @@ def flag_duplicate_reads(
|
|
|
265
419
|
id_map = {old: new for new, old in enumerate(sorted(unique_initial.tolist()))}
|
|
266
420
|
merged_cluster_mapped = np.array([id_map[int(x)] for x in merged_cluster], dtype=int)
|
|
267
421
|
|
|
268
|
-
# cluster sizes and choose lex-keeper per lex-cluster (
|
|
422
|
+
# cluster sizes and choose lex-keeper per lex-cluster (demux-aware)
|
|
269
423
|
cluster_sizes = np.zeros_like(merged_cluster_mapped)
|
|
270
424
|
cluster_counts = []
|
|
271
425
|
unique_clusters = np.unique(merged_cluster_mapped)
|
|
272
426
|
keeper_for_cluster = {}
|
|
427
|
+
|
|
428
|
+
obs_index = list(adata_subset.obs.index)
|
|
273
429
|
for cid in unique_clusters:
|
|
274
430
|
members = np.where(merged_cluster_mapped == cid)[0].tolist()
|
|
275
431
|
csize = int(len(members))
|
|
276
432
|
cluster_counts.append(csize)
|
|
277
433
|
cluster_sizes[members] = csize
|
|
278
|
-
# pick lex keeper (representative)
|
|
279
|
-
if len(members) == 1:
|
|
280
|
-
keeper_for_cluster[cid] = members[0]
|
|
281
|
-
else:
|
|
282
|
-
if keep_best_metric is None:
|
|
283
|
-
keeper_for_cluster[cid] = members[0]
|
|
284
|
-
else:
|
|
285
|
-
obs_index = list(adata_subset.obs.index)
|
|
286
|
-
member_names = [obs_index[m] for m in members]
|
|
287
|
-
try:
|
|
288
|
-
vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
|
|
289
|
-
except Exception:
|
|
290
|
-
vals = np.array([np.nan] * len(members), dtype=float)
|
|
291
|
-
if np.all(np.isnan(vals)):
|
|
292
|
-
keeper_for_cluster[cid] = members[0]
|
|
293
|
-
else:
|
|
294
|
-
if keep_best_higher:
|
|
295
|
-
nan_mask = np.isnan(vals)
|
|
296
|
-
vals[nan_mask] = -np.inf
|
|
297
|
-
rel_idx = int(np.nanargmax(vals))
|
|
298
|
-
else:
|
|
299
|
-
nan_mask = np.isnan(vals)
|
|
300
|
-
vals[nan_mask] = np.inf
|
|
301
|
-
rel_idx = int(np.nanargmin(vals))
|
|
302
|
-
keeper_for_cluster[cid] = members[rel_idx]
|
|
303
434
|
|
|
304
|
-
|
|
435
|
+
keeper_for_cluster[cid] = _choose_keeper_with_demux_preference(
|
|
436
|
+
members,
|
|
437
|
+
adata_subset,
|
|
438
|
+
obs_index,
|
|
439
|
+
demux_col=demux_col,
|
|
440
|
+
preferred_types=demux_types,
|
|
441
|
+
keep_best_metric=keep_best_metric,
|
|
442
|
+
keep_best_higher=keep_best_higher,
|
|
443
|
+
lex_keeper_mask=None, # no lex preference yet
|
|
444
|
+
)
|
|
445
|
+
|
|
446
|
+
# expose lex keeper info (record only)
|
|
305
447
|
lex_is_keeper = np.zeros((N,), dtype=bool)
|
|
306
448
|
lex_is_duplicate = np.zeros((N,), dtype=bool)
|
|
307
|
-
for cid
|
|
449
|
+
for cid in unique_clusters:
|
|
450
|
+
members = np.where(merged_cluster_mapped == cid)[0].tolist()
|
|
308
451
|
keeper_idx = keeper_for_cluster[cid]
|
|
309
452
|
lex_is_keeper[keeper_idx] = True
|
|
310
453
|
for m in members:
|
|
311
454
|
if m != keeper_idx:
|
|
312
455
|
lex_is_duplicate[m] = True
|
|
313
|
-
|
|
314
|
-
# and will be written to adata_subset.obs below
|
|
456
|
+
|
|
315
457
|
# record lex min pair (min of fwd/rev neighbor) for each read
|
|
316
458
|
min_pair = np.full((N,), np.nan, dtype=float)
|
|
317
459
|
for i in range(N):
|
|
@@ -349,7 +491,12 @@ def flag_duplicate_reads(
|
|
|
349
491
|
if n_comp <= 0:
|
|
350
492
|
reps_for_clustering = reps_arr
|
|
351
493
|
else:
|
|
352
|
-
pca = PCA(
|
|
494
|
+
pca = PCA(
|
|
495
|
+
n_components=n_comp,
|
|
496
|
+
random_state=int(random_state),
|
|
497
|
+
svd_solver="auto",
|
|
498
|
+
copy=True,
|
|
499
|
+
)
|
|
353
500
|
reps_for_clustering = pca.fit_transform(reps_arr)
|
|
354
501
|
else:
|
|
355
502
|
reps_for_clustering = reps_arr
|
|
@@ -360,10 +507,12 @@ def flag_duplicate_reads(
|
|
|
360
507
|
Z = sch.linkage(pdist_vec, method=hierarchical_linkage)
|
|
361
508
|
leaves = sch.leaves_list(Z)
|
|
362
509
|
except Exception as e:
|
|
363
|
-
warnings.warn(
|
|
510
|
+
warnings.warn(
|
|
511
|
+
f"hierarchical pass failed: {e}; skipping hierarchical stage."
|
|
512
|
+
)
|
|
364
513
|
leaves = np.arange(len(rep_global_indices), dtype=int)
|
|
365
514
|
|
|
366
|
-
#
|
|
515
|
+
# windowed hamming comparisons across ordered reps and union
|
|
367
516
|
order_global_reps = [rep_global_indices[i] for i in leaves]
|
|
368
517
|
n_reps = len(order_global_reps)
|
|
369
518
|
for pos in range(n_reps):
|
|
@@ -389,55 +538,40 @@ def flag_duplicate_reads(
|
|
|
389
538
|
merged_cluster_after[i] = uf.find(i)
|
|
390
539
|
unique_final = np.unique(merged_cluster_after)
|
|
391
540
|
id_map_final = {old: new for new, old in enumerate(sorted(unique_final.tolist()))}
|
|
392
|
-
merged_cluster_mapped_final = np.array(
|
|
541
|
+
merged_cluster_mapped_final = np.array(
|
|
542
|
+
[id_map_final[int(x)] for x in merged_cluster_after], dtype=int
|
|
543
|
+
)
|
|
393
544
|
|
|
394
|
-
# compute final cluster members and choose final keeper per final cluster
|
|
545
|
+
# compute final cluster members and choose final keeper per final cluster (demux-aware)
|
|
395
546
|
cluster_sizes_final = np.zeros_like(merged_cluster_mapped_final)
|
|
396
|
-
final_cluster_counts = []
|
|
397
|
-
final_unique = np.unique(merged_cluster_mapped_final)
|
|
398
547
|
final_keeper_for_cluster = {}
|
|
399
548
|
cluster_members_map = {}
|
|
400
|
-
|
|
549
|
+
|
|
550
|
+
obs_index = list(adata_subset.obs.index)
|
|
551
|
+
lex_mask_full = lex_is_keeper # use lex keeper as optional tiebreaker
|
|
552
|
+
|
|
553
|
+
for cid in np.unique(merged_cluster_mapped_final):
|
|
401
554
|
members = np.where(merged_cluster_mapped_final == cid)[0].tolist()
|
|
402
555
|
cluster_members_map[cid] = members
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
else:
|
|
421
|
-
if keep_best_higher:
|
|
422
|
-
nan_mask = np.isnan(vals)
|
|
423
|
-
vals[nan_mask] = -np.inf
|
|
424
|
-
rel_idx = int(np.nanargmax(vals))
|
|
425
|
-
else:
|
|
426
|
-
nan_mask = np.isnan(vals)
|
|
427
|
-
vals[nan_mask] = np.inf
|
|
428
|
-
rel_idx = int(np.nanargmin(vals))
|
|
429
|
-
final_keeper_for_cluster[cid] = members[rel_idx]
|
|
430
|
-
else:
|
|
431
|
-
# if lex keepers present among members, prefer them
|
|
432
|
-
lex_members = [m for m in members if lex_is_keeper[m]]
|
|
433
|
-
if len(lex_members) > 0:
|
|
434
|
-
final_keeper_for_cluster[cid] = lex_members[0]
|
|
435
|
-
else:
|
|
436
|
-
final_keeper_for_cluster[cid] = members[0]
|
|
437
|
-
|
|
438
|
-
# update sequence__is_duplicate based on final clusters: non-keepers in multi-member clusters are duplicates
|
|
556
|
+
cluster_sizes_final[members] = len(members)
|
|
557
|
+
|
|
558
|
+
lex_mask_members = np.array([bool(lex_mask_full[m]) for m in members], dtype=bool)
|
|
559
|
+
|
|
560
|
+
keeper = _choose_keeper_with_demux_preference(
|
|
561
|
+
members,
|
|
562
|
+
adata_subset,
|
|
563
|
+
obs_index,
|
|
564
|
+
demux_col=demux_col,
|
|
565
|
+
preferred_types=demux_types,
|
|
566
|
+
keep_best_metric=keep_best_metric,
|
|
567
|
+
keep_best_higher=keep_best_higher,
|
|
568
|
+
lex_keeper_mask=lex_mask_members,
|
|
569
|
+
)
|
|
570
|
+
final_keeper_for_cluster[cid] = keeper
|
|
571
|
+
|
|
572
|
+
# update sequence__is_duplicate based on final clusters
|
|
439
573
|
sequence_is_duplicate = np.zeros((N,), dtype=bool)
|
|
440
|
-
for cid in
|
|
574
|
+
for cid in np.unique(merged_cluster_mapped_final):
|
|
441
575
|
keeper = final_keeper_for_cluster[cid]
|
|
442
576
|
members = cluster_members_map[cid]
|
|
443
577
|
if len(members) > 1:
|
|
@@ -446,8 +580,7 @@ def flag_duplicate_reads(
|
|
|
446
580
|
sequence_is_duplicate[m] = True
|
|
447
581
|
|
|
448
582
|
# propagate hierarchical distances into hierarchical_min_pair for all cluster members
|
|
449
|
-
for
|
|
450
|
-
# identify their final cluster ids (after unions)
|
|
583
|
+
for i_g, j_g, d in hierarchical_pairs:
|
|
451
584
|
c_i = merged_cluster_mapped_final[int(i_g)]
|
|
452
585
|
c_j = merged_cluster_mapped_final[int(j_g)]
|
|
453
586
|
members_i = cluster_members_map.get(c_i, [int(i_g)])
|
|
@@ -459,7 +592,7 @@ def flag_duplicate_reads(
|
|
|
459
592
|
if np.isnan(hierarchical_min_pair[mj]) or (d < hierarchical_min_pair[mj]):
|
|
460
593
|
hierarchical_min_pair[mj] = d
|
|
461
594
|
|
|
462
|
-
# combine
|
|
595
|
+
# combine min pairs
|
|
463
596
|
combined_min = min_pair.copy()
|
|
464
597
|
for i in range(N):
|
|
465
598
|
hval = hierarchical_min_pair[i]
|
|
@@ -475,69 +608,117 @@ def flag_duplicate_reads(
|
|
|
475
608
|
adata_subset.obs["rev_hamming_to_prev"] = rev_hamming_to_prev
|
|
476
609
|
adata_subset.obs["sequence__hier_hamming_to_pair"] = hierarchical_min_pair
|
|
477
610
|
adata_subset.obs["sequence__min_hamming_to_pair"] = combined_min
|
|
478
|
-
# persist lex bookkeeping
|
|
611
|
+
# persist lex bookkeeping
|
|
479
612
|
adata_subset.obs["sequence__lex_is_keeper"] = lex_is_keeper
|
|
480
613
|
adata_subset.obs["sequence__lex_is_duplicate"] = lex_is_duplicate
|
|
481
614
|
|
|
482
615
|
adata_processed_list.append(adata_subset)
|
|
483
616
|
|
|
484
|
-
histograms.append(
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
617
|
+
histograms.append(
|
|
618
|
+
{
|
|
619
|
+
"sample": sample,
|
|
620
|
+
"reference": ref,
|
|
621
|
+
"distances": local_hamming_dists,
|
|
622
|
+
"cluster_counts": [
|
|
623
|
+
int(x) for x in np.unique(cluster_sizes_final[cluster_sizes_final > 0])
|
|
624
|
+
],
|
|
625
|
+
"hierarchical_pairs": hierarchical_found_dists,
|
|
626
|
+
}
|
|
627
|
+
)
|
|
628
|
+
|
|
629
|
+
# Merge annotated subsets back together BEFORE plotting
|
|
493
630
|
_original_uns = copy.deepcopy(adata.uns)
|
|
494
631
|
if len(adata_processed_list) == 0:
|
|
495
632
|
return adata.copy(), adata.copy()
|
|
496
633
|
|
|
497
634
|
adata_full = ad.concat(adata_processed_list, merge="same", join="outer", index_unique=None)
|
|
635
|
+
|
|
636
|
+
# preserve uns (prefer original on conflicts)
|
|
637
|
+
def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
|
|
638
|
+
"""Merge .uns dictionaries while preserving original on conflicts."""
|
|
639
|
+
out = copy.deepcopy(new_uns) if new_uns is not None else {}
|
|
640
|
+
for k, v in (orig_uns or {}).items():
|
|
641
|
+
if k not in out:
|
|
642
|
+
out[k] = copy.deepcopy(v)
|
|
643
|
+
else:
|
|
644
|
+
try:
|
|
645
|
+
equal = out[k] == v
|
|
646
|
+
except Exception:
|
|
647
|
+
equal = False
|
|
648
|
+
if equal:
|
|
649
|
+
continue
|
|
650
|
+
if prefer == "orig":
|
|
651
|
+
out[k] = copy.deepcopy(v)
|
|
652
|
+
else:
|
|
653
|
+
out[f"orig_uns__{k}"] = copy.deepcopy(v)
|
|
654
|
+
return out
|
|
655
|
+
|
|
498
656
|
adata_full.uns = merge_uns_preserve(_original_uns, adata_full.uns, prefer="orig")
|
|
499
657
|
|
|
500
658
|
# Ensure expected numeric columns exist (create if missing)
|
|
501
|
-
for col in (
|
|
659
|
+
for col in (
|
|
660
|
+
"fwd_hamming_to_next",
|
|
661
|
+
"rev_hamming_to_prev",
|
|
662
|
+
"sequence__min_hamming_to_pair",
|
|
663
|
+
"sequence__hier_hamming_to_pair",
|
|
664
|
+
):
|
|
502
665
|
if col not in adata_full.obs.columns:
|
|
503
666
|
adata_full.obs[col] = np.nan
|
|
504
667
|
|
|
505
|
-
# histograms
|
|
506
|
-
hist_outs =
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
668
|
+
# histograms
|
|
669
|
+
hist_outs = (
|
|
670
|
+
os.path.join(output_directory, "read_pair_hamming_distance_histograms")
|
|
671
|
+
if output_directory
|
|
672
|
+
else None
|
|
673
|
+
)
|
|
674
|
+
if hist_outs:
|
|
675
|
+
make_dirs([hist_outs])
|
|
676
|
+
plot_histogram_pages(
|
|
677
|
+
histograms,
|
|
678
|
+
distance_threshold=distance_threshold,
|
|
679
|
+
adata=adata_full,
|
|
680
|
+
output_directory=hist_outs,
|
|
681
|
+
distance_types=["min", "fwd", "rev", "hier", "lex_local"],
|
|
682
|
+
sample_key=sample_col,
|
|
683
|
+
)
|
|
515
684
|
|
|
516
685
|
# hamming vs metric scatter
|
|
517
|
-
scatter_outs =
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
686
|
+
scatter_outs = (
|
|
687
|
+
os.path.join(output_directory, "read_pair_hamming_distance_scatter_plots")
|
|
688
|
+
if output_directory
|
|
689
|
+
else None
|
|
690
|
+
)
|
|
691
|
+
if scatter_outs:
|
|
692
|
+
make_dirs([scatter_outs])
|
|
693
|
+
plot_hamming_vs_metric_pages(
|
|
694
|
+
adata_full,
|
|
695
|
+
metric_keys=metric_keys,
|
|
696
|
+
output_dir=scatter_outs,
|
|
697
|
+
hamming_col="sequence__min_hamming_to_pair",
|
|
698
|
+
highlight_threshold=distance_threshold,
|
|
699
|
+
highlight_color="red",
|
|
700
|
+
sample_col=sample_col,
|
|
701
|
+
)
|
|
526
702
|
|
|
527
703
|
# boolean columns from neighbor distances
|
|
528
|
-
fwd_vals = pd.to_numeric(
|
|
529
|
-
|
|
704
|
+
fwd_vals = pd.to_numeric(
|
|
705
|
+
adata_full.obs.get("fwd_hamming_to_next", pd.Series(np.nan, index=adata_full.obs.index)),
|
|
706
|
+
errors="coerce",
|
|
707
|
+
)
|
|
708
|
+
rev_vals = pd.to_numeric(
|
|
709
|
+
adata_full.obs.get("rev_hamming_to_prev", pd.Series(np.nan, index=adata_full.obs.index)),
|
|
710
|
+
errors="coerce",
|
|
711
|
+
)
|
|
530
712
|
is_dup_dist = (fwd_vals < float(distance_threshold)) | (rev_vals < float(distance_threshold))
|
|
531
713
|
is_dup_dist = is_dup_dist.fillna(False).astype(bool)
|
|
532
714
|
adata_full.obs["is_duplicate_distance"] = is_dup_dist.values
|
|
533
715
|
|
|
534
|
-
# combine sequence
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
# cluster-based duplicate indicator (if any clustering columns exist)
|
|
716
|
+
# combine with sequence flag and any clustering flags
|
|
717
|
+
seq_dup = (
|
|
718
|
+
adata_full.obs["sequence__is_duplicate"].astype(bool)
|
|
719
|
+
if "sequence__is_duplicate" in adata_full.obs.columns
|
|
720
|
+
else pd.Series(False, index=adata_full.obs.index)
|
|
721
|
+
)
|
|
541
722
|
cluster_cols = [c for c in adata_full.obs.columns if c.startswith("hamming_cluster__")]
|
|
542
723
|
if cluster_cols:
|
|
543
724
|
cl_mask = pd.Series(False, index=adata_full.obs.index)
|
|
@@ -550,59 +731,61 @@ def flag_duplicate_reads(
|
|
|
550
731
|
else:
|
|
551
732
|
adata_full.obs["is_duplicate_clustering"] = False
|
|
552
733
|
|
|
553
|
-
final_dup =
|
|
734
|
+
final_dup = (
|
|
735
|
+
seq_dup
|
|
736
|
+
| adata_full.obs["is_duplicate_distance"].astype(bool)
|
|
737
|
+
| adata_full.obs["is_duplicate_clustering"].astype(bool)
|
|
738
|
+
)
|
|
554
739
|
adata_full.obs["is_duplicate"] = final_dup.values
|
|
555
740
|
|
|
556
|
-
# Final keeper enforcement
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
keeper_idx_by_cluster = {}
|
|
560
|
-
metric_col = keep_best_metric if 'keep_best_metric' in locals() else None
|
|
741
|
+
# -------- Final keeper enforcement on adata_full (demux-aware) --------
|
|
742
|
+
keeper_idx_by_cluster = {}
|
|
743
|
+
metric_col = keep_best_metric
|
|
561
744
|
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
try:
|
|
566
|
-
members = sub.index.to_list()
|
|
567
|
-
except Exception:
|
|
568
|
-
members = list(sub.index)
|
|
569
|
-
keeper = None
|
|
570
|
-
# prefer keep_best_metric (if present), else prefer lex keeper among members, else first member
|
|
571
|
-
if metric_col and metric_col in adata_full.obs.columns:
|
|
572
|
-
try:
|
|
573
|
-
vals = pd.to_numeric(adata_full.obs.loc[members, metric_col], errors="coerce")
|
|
574
|
-
if vals.notna().any():
|
|
575
|
-
keeper = vals.idxmax() if keep_best_higher else vals.idxmin()
|
|
576
|
-
else:
|
|
577
|
-
keeper = members[0]
|
|
578
|
-
except Exception:
|
|
579
|
-
keeper = members[0]
|
|
580
|
-
else:
|
|
581
|
-
# prefer lex keeper if present in this merged cluster
|
|
582
|
-
lex_candidates = [m for m in members if ("sequence__lex_is_keeper" in adata_full.obs.columns and adata_full.obs.at[m, "sequence__lex_is_keeper"])]
|
|
583
|
-
if len(lex_candidates) > 0:
|
|
584
|
-
keeper = lex_candidates[0]
|
|
585
|
-
else:
|
|
586
|
-
keeper = members[0]
|
|
745
|
+
# Build an index→row-number mapping
|
|
746
|
+
name_to_pos = {name: i for i, name in enumerate(adata_full.obs.index)}
|
|
747
|
+
obs_index_full = list(adata_full.obs.index)
|
|
587
748
|
|
|
588
|
-
|
|
749
|
+
lex_col = (
|
|
750
|
+
"sequence__lex_is_keeper" if "sequence__lex_is_keeper" in adata_full.obs.columns else None
|
|
751
|
+
)
|
|
589
752
|
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
for
|
|
593
|
-
if keeper_idx in adata_full.obs.index:
|
|
594
|
-
is_dup_series.at[keeper_idx] = False
|
|
595
|
-
# clear sequence__is_duplicate for keeper if present
|
|
596
|
-
if "sequence__is_duplicate" in adata_full.obs.columns:
|
|
597
|
-
adata_full.obs.at[keeper_idx, "sequence__is_duplicate"] = False
|
|
598
|
-
# clear lex duplicate flag too if present
|
|
599
|
-
if "sequence__lex_is_duplicate" in adata_full.obs.columns:
|
|
600
|
-
adata_full.obs.at[keeper_idx, "sequence__lex_is_duplicate"] = False
|
|
753
|
+
for cid, sub in adata_full.obs.groupby("sequence__merged_cluster_id", dropna=False):
|
|
754
|
+
members_names = sub.index.to_list()
|
|
755
|
+
members_pos = [name_to_pos[n] for n in members_names]
|
|
601
756
|
|
|
602
|
-
|
|
757
|
+
if lex_col:
|
|
758
|
+
lex_mask_members = adata_full.obs.loc[members_names, lex_col].astype(bool).to_numpy()
|
|
759
|
+
else:
|
|
760
|
+
lex_mask_members = np.zeros(len(members_pos), dtype=bool)
|
|
761
|
+
|
|
762
|
+
keeper_pos = _choose_keeper_with_demux_preference(
|
|
763
|
+
members_pos,
|
|
764
|
+
adata_full,
|
|
765
|
+
obs_index_full,
|
|
766
|
+
demux_col=demux_col,
|
|
767
|
+
preferred_types=demux_types,
|
|
768
|
+
keep_best_metric=metric_col,
|
|
769
|
+
keep_best_higher=keep_best_higher,
|
|
770
|
+
lex_keeper_mask=lex_mask_members,
|
|
771
|
+
)
|
|
772
|
+
keeper_name = obs_index_full[keeper_pos]
|
|
773
|
+
keeper_idx_by_cluster[cid] = keeper_name
|
|
774
|
+
|
|
775
|
+
# enforce: keepers are not duplicates
|
|
776
|
+
is_dup_series = adata_full.obs["is_duplicate"].astype(bool)
|
|
777
|
+
for cid, keeper_name in keeper_idx_by_cluster.items():
|
|
778
|
+
if keeper_name in adata_full.obs.index:
|
|
779
|
+
is_dup_series.at[keeper_name] = False
|
|
780
|
+
if "sequence__is_duplicate" in adata_full.obs.columns:
|
|
781
|
+
adata_full.obs.at[keeper_name, "sequence__is_duplicate"] = False
|
|
782
|
+
if "sequence__lex_is_duplicate" in adata_full.obs.columns:
|
|
783
|
+
adata_full.obs.at[keeper_name, "sequence__lex_is_duplicate"] = False
|
|
784
|
+
adata_full.obs["is_duplicate"] = is_dup_series.values
|
|
603
785
|
|
|
604
786
|
# reason column
|
|
605
787
|
def _dup_reason_row(row):
|
|
788
|
+
"""Build a semi-colon delimited duplicate reason string."""
|
|
606
789
|
reasons = []
|
|
607
790
|
if row.get("is_duplicate_distance", False):
|
|
608
791
|
reasons.append("distance_thresh")
|
|
@@ -632,6 +815,7 @@ def flag_duplicate_reads(
|
|
|
632
815
|
# Plot helpers (use adata_full as input)
|
|
633
816
|
# ---------------------------
|
|
634
817
|
|
|
818
|
+
|
|
635
819
|
def plot_histogram_pages(
|
|
636
820
|
histograms,
|
|
637
821
|
distance_threshold,
|
|
@@ -674,10 +858,11 @@ def plot_histogram_pages(
|
|
|
674
858
|
use_adata = False
|
|
675
859
|
|
|
676
860
|
if len(samples) == 0 or len(references) == 0:
|
|
677
|
-
|
|
861
|
+
logger.info("No histogram data to plot.")
|
|
678
862
|
return {"distance_pages": [], "cluster_size_pages": []}
|
|
679
863
|
|
|
680
864
|
def clean_array(arr):
|
|
865
|
+
"""Filter array values to finite [0, 1] range for plotting."""
|
|
681
866
|
if arr is None or len(arr) == 0:
|
|
682
867
|
return np.array([], dtype=float)
|
|
683
868
|
a = np.asarray(arr, dtype=float)
|
|
@@ -707,7 +892,9 @@ def plot_histogram_pages(
|
|
|
707
892
|
if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
|
|
708
893
|
grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
|
|
709
894
|
if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
|
|
710
|
-
grid[(s, r)]["hier"].extend(
|
|
895
|
+
grid[(s, r)]["hier"].extend(
|
|
896
|
+
clean_array(group["sequence__hier_hamming_to_pair"].to_numpy())
|
|
897
|
+
)
|
|
711
898
|
else:
|
|
712
899
|
for (s, r), group in grouped:
|
|
713
900
|
if "min" in distance_types and distance_key in group.columns:
|
|
@@ -717,7 +904,9 @@ def plot_histogram_pages(
|
|
|
717
904
|
if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
|
|
718
905
|
grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
|
|
719
906
|
if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
|
|
720
|
-
grid[(s, r)]["hier"].extend(
|
|
907
|
+
grid[(s, r)]["hier"].extend(
|
|
908
|
+
clean_array(group["sequence__hier_hamming_to_pair"].to_numpy())
|
|
909
|
+
)
|
|
721
910
|
|
|
722
911
|
# legacy histograms fallback
|
|
723
912
|
if histograms:
|
|
@@ -753,9 +942,17 @@ def plot_histogram_pages(
|
|
|
753
942
|
|
|
754
943
|
# counts (for labels)
|
|
755
944
|
if use_adata:
|
|
756
|
-
counts = {
|
|
945
|
+
counts = {
|
|
946
|
+
(s, r): int(((adata.obs[sample_key] == s) & (adata.obs[ref_key] == r)).sum())
|
|
947
|
+
for s in samples
|
|
948
|
+
for r in references
|
|
949
|
+
}
|
|
757
950
|
else:
|
|
758
|
-
counts = {
|
|
951
|
+
counts = {
|
|
952
|
+
(s, r): sum(len(grid[(s, r)][dt]) for dt in distance_types)
|
|
953
|
+
for s in samples
|
|
954
|
+
for r in references
|
|
955
|
+
}
|
|
759
956
|
|
|
760
957
|
distance_pages = []
|
|
761
958
|
cluster_size_pages = []
|
|
@@ -773,7 +970,9 @@ def plot_histogram_pages(
|
|
|
773
970
|
# Distance histogram page
|
|
774
971
|
fig_w = figsize_per_cell[0] * ncols
|
|
775
972
|
fig_h = figsize_per_cell[1] * nrows
|
|
776
|
-
fig, axes = plt.subplots(
|
|
973
|
+
fig, axes = plt.subplots(
|
|
974
|
+
nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
|
|
975
|
+
)
|
|
777
976
|
|
|
778
977
|
for r_idx, sample_name in enumerate(chunk):
|
|
779
978
|
for c_idx, ref_name in enumerate(references):
|
|
@@ -789,17 +988,37 @@ def plot_histogram_pages(
|
|
|
789
988
|
vals = vals[(vals >= 0.0) & (vals <= ref_vmax)]
|
|
790
989
|
if vals.size > 0:
|
|
791
990
|
any_data = True
|
|
792
|
-
ax.hist(
|
|
793
|
-
|
|
991
|
+
ax.hist(
|
|
992
|
+
vals,
|
|
993
|
+
bins=bins_edges,
|
|
994
|
+
alpha=0.5,
|
|
995
|
+
label=dtype,
|
|
996
|
+
density=False,
|
|
997
|
+
stacked=False,
|
|
998
|
+
color=dtype_colors.get(dtype, None),
|
|
999
|
+
)
|
|
794
1000
|
if not any_data:
|
|
795
|
-
ax.text(
|
|
1001
|
+
ax.text(
|
|
1002
|
+
0.5,
|
|
1003
|
+
0.5,
|
|
1004
|
+
"No data",
|
|
1005
|
+
ha="center",
|
|
1006
|
+
va="center",
|
|
1007
|
+
transform=ax.transAxes,
|
|
1008
|
+
fontsize=10,
|
|
1009
|
+
color="gray",
|
|
1010
|
+
)
|
|
796
1011
|
# threshold line (make sure it is within axis)
|
|
797
1012
|
ax.axvline(distance_threshold, color="red", linestyle="--", linewidth=1)
|
|
798
1013
|
|
|
799
1014
|
if r_idx == 0:
|
|
800
1015
|
ax.set_title(str(ref_name), fontsize=10)
|
|
801
1016
|
if c_idx == 0:
|
|
802
|
-
total_reads =
|
|
1017
|
+
total_reads = (
|
|
1018
|
+
sum(counts.get((sample_name, ref), 0) for ref in references)
|
|
1019
|
+
if not use_adata
|
|
1020
|
+
else int((adata.obs[sample_key] == sample_name).sum())
|
|
1021
|
+
)
|
|
803
1022
|
ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
|
|
804
1023
|
if r_idx == nrows - 1:
|
|
805
1024
|
ax.set_xlabel("Hamming Distance", fontsize=9)
|
|
@@ -811,12 +1030,16 @@ def plot_histogram_pages(
|
|
|
811
1030
|
if r_idx == 0 and c_idx == 0:
|
|
812
1031
|
ax.legend(fontsize=7, loc="upper right")
|
|
813
1032
|
|
|
814
|
-
fig.suptitle(
|
|
1033
|
+
fig.suptitle(
|
|
1034
|
+
f"Hamming distance histograms (rows=samples, cols=references) — page {page + 1}/{n_pages}",
|
|
1035
|
+
fontsize=12,
|
|
1036
|
+
y=0.995,
|
|
1037
|
+
)
|
|
815
1038
|
fig.tight_layout(rect=[0, 0, 1, 0.96])
|
|
816
1039
|
|
|
817
1040
|
if output_directory:
|
|
818
1041
|
os.makedirs(output_directory, exist_ok=True)
|
|
819
|
-
fname = os.path.join(output_directory, f"hamming_histograms_page_{page+1}.png")
|
|
1042
|
+
fname = os.path.join(output_directory, f"hamming_histograms_page_{page + 1}.png")
|
|
820
1043
|
plt.savefig(fname, bbox_inches="tight")
|
|
821
1044
|
distance_pages.append(fname)
|
|
822
1045
|
else:
|
|
@@ -826,22 +1049,43 @@ def plot_histogram_pages(
|
|
|
826
1049
|
# Cluster-size histogram page (unchanged except it uses adata-derived sizes per cluster if available)
|
|
827
1050
|
fig_w = figsize_per_cell[0] * ncols
|
|
828
1051
|
fig_h = figsize_per_cell[1] * nrows
|
|
829
|
-
fig2, axes2 = plt.subplots(
|
|
1052
|
+
fig2, axes2 = plt.subplots(
|
|
1053
|
+
nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
|
|
1054
|
+
)
|
|
830
1055
|
|
|
831
1056
|
for r_idx, sample_name in enumerate(chunk):
|
|
832
1057
|
for c_idx, ref_name in enumerate(references):
|
|
833
1058
|
ax = axes2[r_idx][c_idx]
|
|
834
1059
|
sizes = []
|
|
835
|
-
if use_adata and (
|
|
836
|
-
|
|
1060
|
+
if use_adata and (
|
|
1061
|
+
"sequence__merged_cluster_id" in adata.obs.columns
|
|
1062
|
+
and "sequence__cluster_size" in adata.obs.columns
|
|
1063
|
+
):
|
|
1064
|
+
sub = adata.obs[
|
|
1065
|
+
(adata.obs[sample_key] == sample_name) & (adata.obs[ref_key] == ref_name)
|
|
1066
|
+
]
|
|
837
1067
|
if not sub.empty:
|
|
838
1068
|
try:
|
|
839
|
-
grp = sub.groupby("sequence__merged_cluster_id")[
|
|
840
|
-
|
|
1069
|
+
grp = sub.groupby("sequence__merged_cluster_id")[
|
|
1070
|
+
"sequence__cluster_size"
|
|
1071
|
+
].first()
|
|
1072
|
+
sizes = [
|
|
1073
|
+
int(x)
|
|
1074
|
+
for x in grp.to_numpy().tolist()
|
|
1075
|
+
if (pd.notna(x) and np.isfinite(x))
|
|
1076
|
+
]
|
|
841
1077
|
except Exception:
|
|
842
1078
|
try:
|
|
843
|
-
unique_pairs = sub[
|
|
844
|
-
|
|
1079
|
+
unique_pairs = sub[
|
|
1080
|
+
["sequence__merged_cluster_id", "sequence__cluster_size"]
|
|
1081
|
+
].drop_duplicates()
|
|
1082
|
+
sizes = [
|
|
1083
|
+
int(x)
|
|
1084
|
+
for x in unique_pairs["sequence__cluster_size"]
|
|
1085
|
+
.dropna()
|
|
1086
|
+
.astype(int)
|
|
1087
|
+
.tolist()
|
|
1088
|
+
]
|
|
845
1089
|
except Exception:
|
|
846
1090
|
sizes = []
|
|
847
1091
|
if (not sizes) and histograms:
|
|
@@ -855,23 +1099,38 @@ def plot_histogram_pages(
|
|
|
855
1099
|
ax.set_xlabel("Cluster size")
|
|
856
1100
|
ax.set_ylabel("Count")
|
|
857
1101
|
else:
|
|
858
|
-
ax.text(
|
|
1102
|
+
ax.text(
|
|
1103
|
+
0.5,
|
|
1104
|
+
0.5,
|
|
1105
|
+
"No clusters",
|
|
1106
|
+
ha="center",
|
|
1107
|
+
va="center",
|
|
1108
|
+
transform=ax.transAxes,
|
|
1109
|
+
fontsize=10,
|
|
1110
|
+
color="gray",
|
|
1111
|
+
)
|
|
859
1112
|
|
|
860
1113
|
if r_idx == 0:
|
|
861
1114
|
ax.set_title(str(ref_name), fontsize=10)
|
|
862
1115
|
if c_idx == 0:
|
|
863
|
-
total_reads =
|
|
1116
|
+
total_reads = (
|
|
1117
|
+
sum(counts.get((sample_name, ref), 0) for ref in references)
|
|
1118
|
+
if not use_adata
|
|
1119
|
+
else int((adata.obs[sample_key] == sample_name).sum())
|
|
1120
|
+
)
|
|
864
1121
|
ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
|
|
865
1122
|
if r_idx != nrows - 1:
|
|
866
1123
|
ax.set_xticklabels([])
|
|
867
1124
|
|
|
868
1125
|
ax.grid(True, alpha=0.25)
|
|
869
1126
|
|
|
870
|
-
fig2.suptitle(
|
|
1127
|
+
fig2.suptitle(
|
|
1128
|
+
f"Union-find cluster size histograms — page {page + 1}/{n_pages}", fontsize=12, y=0.995
|
|
1129
|
+
)
|
|
871
1130
|
fig2.tight_layout(rect=[0, 0, 1, 0.96])
|
|
872
1131
|
|
|
873
1132
|
if output_directory:
|
|
874
|
-
fname2 = os.path.join(output_directory, f"cluster_size_histograms_page_{page+1}.png")
|
|
1133
|
+
fname2 = os.path.join(output_directory, f"cluster_size_histograms_page_{page + 1}.png")
|
|
875
1134
|
plt.savefig(fname2, bbox_inches="tight")
|
|
876
1135
|
cluster_size_pages.append(fname2)
|
|
877
1136
|
else:
|
|
@@ -923,7 +1182,9 @@ def plot_hamming_vs_metric_pages(
|
|
|
923
1182
|
|
|
924
1183
|
obs = adata.obs
|
|
925
1184
|
if sample_col not in obs.columns or ref_col not in obs.columns:
|
|
926
|
-
raise ValueError(
|
|
1185
|
+
raise ValueError(
|
|
1186
|
+
f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
|
|
1187
|
+
)
|
|
927
1188
|
|
|
928
1189
|
# canonicalize samples and refs
|
|
929
1190
|
if samples is None:
|
|
@@ -964,14 +1225,24 @@ def plot_hamming_vs_metric_pages(
|
|
|
964
1225
|
sY = pd.to_numeric(obs[hamming_col], errors="coerce") if hamming_present else None
|
|
965
1226
|
|
|
966
1227
|
if (sX is not None) and (sY is not None):
|
|
967
|
-
valid_both =
|
|
1228
|
+
valid_both = (
|
|
1229
|
+
sX.notna() & sY.notna() & np.isfinite(sX.values) & np.isfinite(sY.values)
|
|
1230
|
+
)
|
|
968
1231
|
if valid_both.any():
|
|
969
1232
|
xvals = sX[valid_both].to_numpy(dtype=float)
|
|
970
1233
|
yvals = sY[valid_both].to_numpy(dtype=float)
|
|
971
1234
|
xmin, xmax = float(np.nanmin(xvals)), float(np.nanmax(xvals))
|
|
972
1235
|
ymin, ymax = float(np.nanmin(yvals)), float(np.nanmax(yvals))
|
|
973
|
-
xpad =
|
|
974
|
-
|
|
1236
|
+
xpad = (
|
|
1237
|
+
max(1e-6, (xmax - xmin) * 0.05)
|
|
1238
|
+
if xmax > xmin
|
|
1239
|
+
else max(1e-3, abs(xmin) * 0.05 + 1e-3)
|
|
1240
|
+
)
|
|
1241
|
+
ypad = (
|
|
1242
|
+
max(1e-6, (ymax - ymin) * 0.05)
|
|
1243
|
+
if ymax > ymin
|
|
1244
|
+
else max(1e-3, abs(ymin) * 0.05 + 1e-3)
|
|
1245
|
+
)
|
|
975
1246
|
global_xlim = (xmin - xpad, xmax + xpad)
|
|
976
1247
|
global_ylim = (ymin - ypad, ymax + ypad)
|
|
977
1248
|
else:
|
|
@@ -979,23 +1250,39 @@ def plot_hamming_vs_metric_pages(
 sY_finite = sY[np.isfinite(sY)]
 if sX_finite.size > 0:
 xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
-xpad =
+xpad = (
+max(1e-6, (xmax - xmin) * 0.05)
+if xmax > xmin
+else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+)
 global_xlim = (xmin - xpad, xmax + xpad)
 if sY_finite.size > 0:
 ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
-ypad =
+ypad = (
+max(1e-6, (ymax - ymin) * 0.05)
+if ymax > ymin
+else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+)
 global_ylim = (ymin - ypad, ymax + ypad)
 elif sX is not None:
 sX_finite = sX[np.isfinite(sX)]
 if sX_finite.size > 0:
 xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
-xpad =
+xpad = (
+max(1e-6, (xmax - xmin) * 0.05)
+if xmax > xmin
+else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+)
 global_xlim = (xmin - xpad, xmax + xpad)
 elif sY is not None:
 sY_finite = sY[np.isfinite(sY)]
 if sY_finite.size > 0:
 ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
-ypad =
+ypad = (
+max(1e-6, (ymax - ymin) * 0.05)
+if ymax > ymin
+else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+)
 global_ylim = (ymin - ypad, ymax + ypad)

 # pagination
@@ -1005,15 +1292,19 @@ def plot_hamming_vs_metric_pages(
 ncols = len(cols)
 fig_w = ncols * figsize_per_cell[0]
 fig_h = nrows * figsize_per_cell[1]
-fig, axes = plt.subplots(
+fig, axes = plt.subplots(
+nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+)

 for r_idx, sample_name in enumerate(chunk):
 for c_idx, ref_name in enumerate(cols):
 ax = axes[r_idx][c_idx]
 if ref_name == extra_col:
-mask =
+mask = obs[sample_col].values == sample_name
 else:
-mask = (obs[sample_col].values == sample_name) & (
+mask = (obs[sample_col].values == sample_name) & (
+obs[ref_col].values == ref_name
+)

 sub = obs[mask]

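The mask logic above selects the reads for each (sample, reference) panel, with the extra column (`extra_col`) dropping the reference constraint so a sample's reads are pooled across references. A small sketch with made-up column names:

```python
# Sketch with hypothetical column names; the real code reads these from adata.obs.
import numpy as np
import pandas as pd

obs = pd.DataFrame({"sample": ["s1", "s1", "s2"],
                    "reference": ["chr1", "chr2", "chr1"]})
sample_col, ref_col, extra_col = "sample", "reference", "__ALL__"

def panel_mask(sample_name, ref_name):
    if ref_name == extra_col:
        return obs[sample_col].values == sample_name
    return (obs[sample_col].values == sample_name) & (obs[ref_col].values == ref_name)

print(panel_mask("s1", "chr1"))     # [ True False False]
print(panel_mask("s1", "__ALL__"))  # [ True  True False]
```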
@@ -1022,7 +1313,9 @@ def plot_hamming_vs_metric_pages(
 else:
 x_all = np.array([], dtype=float)
 if hamming_col in sub.columns:
-y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(
+y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(
+dtype=float
+)
 else:
 y_all = np.array([], dtype=float)

@@ -1040,32 +1333,67 @@ def plot_hamming_vs_metric_pages(
 idxs_valid = np.array([], dtype=int)

 if x.size == 0:
-ax.text(
+ax.text(
+0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes
+)
 clusters_info[(sample_name, ref_name)] = {"diag": None, "n_points": 0}
 else:
 # Decide color mapping
-if
+if (
+color_by_duplicate
+and duplicate_col in adata.obs.columns
+and idxs_valid.size > 0
+):
 # get boolean series aligned to idxs_valid
 try:
-dup_flags =
+dup_flags = (
+adata.obs.loc[idxs_valid, duplicate_col].astype(bool).to_numpy()
+)
 except Exception:
 dup_flags = np.zeros(len(idxs_valid), dtype=bool)
 mask_dup = dup_flags
 mask_nondup = ~mask_dup
 # plot non-duplicates first in gray, duplicates in highlight color
 if mask_nondup.any():
-ax.scatter(
+ax.scatter(
+x[mask_nondup],
+y[mask_nondup],
+s=12,
+alpha=0.6,
+rasterized=True,
+c="lightgray",
+)
 if mask_dup.any():
-ax.scatter(
+ax.scatter(
+x[mask_dup],
+y[mask_dup],
+s=20,
+alpha=0.9,
+rasterized=True,
+c=highlight_color,
+edgecolors="k",
+linewidths=0.3,
+)
 else:
 # old behavior: highlight by threshold if requested
 if highlight_threshold is not None and y.size:
 mask_low = (y < float(highlight_threshold)) & np.isfinite(y)
 mask_high = ~mask_low
 if mask_high.any():
-ax.scatter(
+ax.scatter(
+x[mask_high], y[mask_high], s=12, alpha=0.6, rasterized=True
+)
 if mask_low.any():
-ax.scatter(
+ax.scatter(
+x[mask_low],
+y[mask_low],
+s=18,
+alpha=0.9,
+rasterized=True,
+c=highlight_color,
+edgecolors="k",
+linewidths=0.3,
+)
 else:
 ax.scatter(x, y, s=12, alpha=0.6, rasterized=True)

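The expanded scatter calls implement two-layer highlighting: the background group is drawn first in light gray, then the flagged group (duplicates, or points below the Hamming threshold) is overlaid in a highlight color with thin black edges. A standalone sketch of the same pattern on synthetic data:

```python
# Synthetic data; `flagged` stands in for a duplicate flag or threshold mask.
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
x, y = rng.normal(size=200), rng.normal(size=200)
flagged = y < -1.0

fig, ax = plt.subplots()
ax.scatter(x[~flagged], y[~flagged], s=12, alpha=0.6, c="lightgray", rasterized=True)
ax.scatter(x[flagged], y[flagged], s=20, alpha=0.9, c="crimson",
           edgecolors="k", linewidths=0.3, rasterized=True)
plt.show()
```

Drawing the highlighted points last keeps them visible even in dense regions.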
@@ -1081,7 +1409,9 @@ def plot_hamming_vs_metric_pages(
 zi = gaussian_kde(np.vstack([x, y]))(coords).reshape(xi_g.shape)
 ax.contourf(xi_g, yi_g, zi, levels=8, alpha=0.35, cmap="Blues")
 else:
-ax.scatter(
+ax.scatter(
+x, y, c=kde2, s=16, cmap="viridis", alpha=0.7, linewidths=0
+)
 except Exception:
 pass

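The density overlay around this hunk uses `scipy.stats.gaussian_kde`: either a gridded KDE rendered as filled contours, or (the branch reformatted above) points colored by their own density estimate. A self-contained sketch of both variants on synthetic data:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

rng = np.random.default_rng(1)
x, y = rng.normal(size=300), rng.normal(size=300)

kde = gaussian_kde(np.vstack([x, y]))
xi_g, yi_g = np.meshgrid(np.linspace(x.min(), x.max(), 100),
                         np.linspace(y.min(), y.max(), 100))
zi = kde(np.vstack([xi_g.ravel(), yi_g.ravel()])).reshape(xi_g.shape)

fig, ax = plt.subplots()
ax.contourf(xi_g, yi_g, zi, levels=8, alpha=0.35, cmap="Blues")   # gridded KDE as contours
ax.scatter(x, y, c=kde(np.vstack([x, y])), s=16, cmap="viridis",  # per-point density coloring
           alpha=0.7, linewidths=0)
plt.show()
```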
@@ -1090,16 +1420,29 @@ def plot_hamming_vs_metric_pages(
 a, b = np.polyfit(x, y, 1)
 xs = np.linspace(np.nanmin(x), np.nanmax(x), 100)
 ys = a * xs + b
-ax.plot(
+ax.plot(
+xs, ys, linestyle="--", linewidth=1.2, alpha=0.9, color="red"
+)
 r = np.corrcoef(x, y)[0, 1]
-ax.text(
-
+ax.text(
+0.98,
+0.02,
+f"r={float(r):.3f}",
+ha="right",
+va="bottom",
+transform=ax.transAxes,
+fontsize=8,
+bbox=dict(
+facecolor="white", alpha=0.6, boxstyle="round,pad=0.2"
+),
+)
 except Exception:
 pass

 if clustering:
 cl_labels, diag = _run_clustering(
-x,
+x,
+y,
 method=clustering.get("method", "dbscan"),
 n_clusters=clustering.get("n_clusters", 2),
 dbscan_eps=clustering.get("dbscan_eps", 0.05),
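The reformatted block adds a least-squares trend line (`np.polyfit`) and annotates the Pearson correlation (`np.corrcoef`) in the corner of each panel. A minimal sketch of that annotation on synthetic data:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(2)
x = rng.uniform(0, 1, 200)
y = 0.8 * x + rng.normal(scale=0.1, size=200)

a, b = np.polyfit(x, y, 1)              # slope, intercept of the linear fit
xs = np.linspace(x.min(), x.max(), 100)
r = np.corrcoef(x, y)[0, 1]             # Pearson correlation

fig, ax = plt.subplots()
ax.scatter(x, y, s=12, alpha=0.6)
ax.plot(xs, a * xs + b, linestyle="--", linewidth=1.2, color="red")
ax.text(0.98, 0.02, f"r={r:.3f}", ha="right", va="bottom", transform=ax.transAxes,
        fontsize=8, bbox=dict(facecolor="white", alpha=0.6, boxstyle="round,pad=0.2"))
plt.show()
```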
@@ -1113,32 +1456,59 @@ def plot_hamming_vs_metric_pages(
 if len(unique_nonnoise) > 0:
 medians = {}
 for lab in unique_nonnoise:
-mask_lab =
-medians[lab] =
-
+mask_lab = cl_labels == lab
+medians[lab] = (
+float(np.median(y[mask_lab]))
+if mask_lab.any()
+else float("nan")
+)
+sorted_by_median = sorted(
+unique_nonnoise,
+key=lambda idx: (
+np.nan if np.isnan(medians[idx]) else medians[idx]
+),
+reverse=True,
+)
 mapping = {old: new for new, old in enumerate(sorted_by_median)}
 for i_lab in range(len(remapped_labels)):
 if remapped_labels[i_lab] != -1:
-remapped_labels[i_lab] = mapping.get(
+remapped_labels[i_lab] = mapping.get(
+remapped_labels[i_lab], -1
+)
 diag = diag or {}
-diag["cluster_median_hamming"] = {
-
+diag["cluster_median_hamming"] = {
+int(old): medians[old] for old in medians
+}
+diag["cluster_old_to_new_map"] = {
+int(old): int(new) for old, new in mapping.items()
+}
 else:
 remapped_labels = cl_labels.copy()
 diag = diag or {}
 diag["cluster_median_hamming"] = {}
 diag["cluster_old_to_new_map"] = {}

-_overlay_clusters_on_ax(
-
-
-
+_overlay_clusters_on_ax(
+ax,
+x,
+y,
+remapped_labels,
+diag,
+cmap=clustering.get("cmap", "tab10"),
+hull=clustering.get("hull", True),
+show_cluster_labels=True,
+)

-clusters_info[(sample_name, ref_name)] = {
+clusters_info[(sample_name, ref_name)] = {
+"diag": diag,
+"n_points": len(x),
+}

 if write_clusters_to_adata and idxs_valid.size > 0:
-colname_safe_ref =
-colname =
+colname_safe_ref = ref_name if ref_name != extra_col else "ALLREFS"
+colname = (
+f"hamming_cluster__{metric}__{sample_name}__{colname_safe_ref}"
+)
 if colname not in adata.obs.columns:
 adata.obs[colname] = np.nan
 lab_arr = remapped_labels.astype(float)
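The relabeling logic above orders the non-noise clusters by their median Hamming value and remaps labels so that cluster 0 has the highest median, leaving DBSCAN noise (-1) untouched; the medians and the old-to-new mapping are stored in the diagnostics dict. A compact sketch of the same remapping (assuming no NaN medians):

```python
import numpy as np

def remap_by_median(labels, y):
    labels = np.asarray(labels)
    nonnoise = [lab for lab in np.unique(labels) if lab != -1]
    medians = {lab: float(np.median(y[labels == lab])) for lab in nonnoise}
    order = sorted(nonnoise, key=lambda lab: medians[lab], reverse=True)
    mapping = {old: new for new, old in enumerate(order)}
    remapped = np.array([mapping.get(l, -1) if l != -1 else -1 for l in labels])
    return remapped, {"cluster_median": medians, "old_to_new": mapping}

labels = np.array([0, 0, 1, 1, -1])
y = np.array([0.1, 0.2, 0.8, 0.9, 0.5])
print(remap_by_median(labels, y)[0])  # [ 1  1  0  0 -1]
```

Fixing the label order this way makes cluster 0 comparable across panels regardless of how the clustering algorithm happened to number its clusters.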
@@ -1178,7 +1548,11 @@ def plot_hamming_vs_metric_pages(
 plt.show()
 plt.close(fig)

-saved_map[metric] = {
+saved_map[metric] = {
+"files": files,
+"clusters_info": clusters_info,
+"written_cols": written_cols,
+}

 return saved_map

@@ -1187,7 +1561,7 @@ def _run_clustering(
 x: np.ndarray,
 y: np.ndarray,
 *,
-method: str = "kmeans",
+method: str = "kmeans",  # "kmeans", "dbscan", "gmm", "hdbscan"
 n_clusters: int = 2,
 dbscan_eps: float = 0.05,
 dbscan_min_samples: int = 5,
@@ -1198,13 +1572,6 @@ def _run_clustering(
 Run clustering on 2D points (x,y). Returns labels (len = npoints) and diagnostics dict.
 Labels follow sklearn conventions (noise -> -1 for DBSCAN/HDBSCAN).
 """
-try:
-from sklearn.cluster import KMeans, DBSCAN
-from sklearn.mixture import GaussianMixture
-from sklearn.metrics import silhouette_score
-except Exception:
-KMeans = DBSCAN = GaussianMixture = silhouette_score = None
-
 pts = np.column_stack([x, y])
 diagnostics: Dict[str, Any] = {"method": method, "n_input": len(x)}
 if len(x) < min_points:
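This hunk drops the guarded sklearn imports from inside `_run_clustering`. For readers following along, a minimal, self-contained sketch of the pattern the docstring describes (optional sklearn import, method dispatch, labels plus a diagnostics dict); the names, defaults, and fallbacks here are illustrative and not the package's exact implementation:

```python
import numpy as np

try:
    from sklearn.cluster import DBSCAN, KMeans
    from sklearn.metrics import silhouette_score
except Exception:  # sklearn unavailable
    DBSCAN = KMeans = silhouette_score = None

def run_clustering(x, y, method="dbscan", n_clusters=2, eps=0.05, min_samples=5):
    pts = np.column_stack([x, y])
    diag = {"method": method, "n_input": len(x)}
    if DBSCAN is None or len(x) < 3:
        return np.full(len(x), -1, dtype=int), diag  # everything treated as noise
    if method == "kmeans":
        labels = KMeans(n_clusters=n_clusters, n_init=10).fit_predict(pts)
    else:
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(pts)
    diag["n_clusters_found"] = int(len(set(labels)) - (1 if -1 in labels else 0))
    try:
        if diag["n_clusters_found"] >= 2 and silhouette_score is not None:
            diag["silhouette"] = float(silhouette_score(pts, labels))
    except Exception:
        diag["silhouette"] = None
    return labels, diag
```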
@@ -1270,7 +1637,11 @@ def _run_clustering(

 # compute silhouette if suitable
 try:
-if
+if (
+diagnostics.get("n_clusters_found", 0) >= 2
+and len(x) >= 3
+and silhouette_score is not None
+):
 diagnostics["silhouette"] = float(silhouette_score(pts, labels))
 else:
 diagnostics["silhouette"] = None
@@ -1305,7 +1676,6 @@ def _overlay_clusters_on_ax(
 Labels == -1 are noise and drawn in grey.
 Also annotates cluster numbers near centroids (contiguous numbers starting at 0).
 """
-import matplotlib.colors as mcolors
 from scipy.spatial import ConvexHull

 labels = np.asarray(labels)
@@ -1323,19 +1693,47 @@ def _overlay_clusters_on_ax(
 if not mask.any():
 continue
 col = (0.6, 0.6, 0.6, 0.6) if lab == -1 else colors[idx % ncolors]
-ax.scatter(
+ax.scatter(
+x[mask],
+y[mask],
+s=20,
+c=[col],
+alpha=alpha_pts,
+marker=marker,
+linewidths=0.2,
+edgecolors="none",
+rasterized=True,
+)

 if lab != -1:
 # centroid
 if plot_centroids:
 cx = float(np.mean(x[mask]))
 cy = float(np.mean(y[mask]))
-ax.scatter(
+ax.scatter(
+[cx],
+[cy],
+s=centroid_size,
+marker=centroid_marker,
+c=[col],
+edgecolor="k",
+linewidth=0.6,
+zorder=10,
+)

 if show_cluster_labels:
-ax.text(
-
-
+ax.text(
+cx,
+cy,
+str(int(lab)),
+color="white",
+fontsize=cluster_label_fontsize,
+ha="center",
+va="center",
+weight="bold",
+zorder=12,
+bbox=dict(facecolor=(0, 0, 0, 0.5), pad=0.3, boxstyle="round"),
+)

 # hull
 if hull and np.sum(mask) >= 3:
@@ -1343,9 +1741,16 @@ def _overlay_clusters_on_ax(
 ch_pts = pts[mask]
 hull_idx = ConvexHull(ch_pts).vertices
 hull_poly = ch_pts[hull_idx]
-ax.fill(
+ax.fill(
+hull_poly[:, 0],
+hull_poly[:, 1],
+alpha=hull_alpha,
+facecolor=col,
+edgecolor=hull_edgecolor,
+linewidth=0.6,
+zorder=5,
+)
 except Exception:
 pass

 return None
-
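The final hunk expands the `ax.fill` call that shades each cluster with its convex hull. A short sketch of that hull shading with `scipy.spatial.ConvexHull` on synthetic points:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull

rng = np.random.default_rng(3)
pts = rng.normal(size=(40, 2))

fig, ax = plt.subplots()
ax.scatter(pts[:, 0], pts[:, 1], s=20)
hull = ConvexHull(pts)
poly = pts[hull.vertices]  # hull vertices in boundary order
ax.fill(poly[:, 0], poly[:, 1], alpha=0.15,
        facecolor="tab:blue", edgecolor="k", linewidth=0.6, zorder=0)
plt.show()
```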