smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +7 -6
- smftools/_version.py +1 -1
- smftools/cli/cli_flows.py +94 -0
- smftools/cli/hmm_adata.py +338 -0
- smftools/cli/load_adata.py +577 -0
- smftools/cli/preprocess_adata.py +363 -0
- smftools/cli/spatial_adata.py +564 -0
- smftools/cli_entry.py +435 -0
- smftools/config/__init__.py +1 -0
- smftools/config/conversion.yaml +38 -0
- smftools/config/deaminase.yaml +61 -0
- smftools/config/default.yaml +264 -0
- smftools/config/direct.yaml +41 -0
- smftools/config/discover_input_files.py +115 -0
- smftools/config/experiment_config.py +1288 -0
- smftools/hmm/HMM.py +1576 -0
- smftools/hmm/__init__.py +20 -0
- smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
- smftools/hmm/call_hmm_peaks.py +106 -0
- smftools/{tools → hmm}/display_hmm.py +3 -3
- smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
- smftools/{tools → hmm}/train_hmm.py +1 -1
- smftools/informatics/__init__.py +13 -9
- smftools/informatics/archived/deaminase_smf.py +132 -0
- smftools/informatics/archived/fast5_to_pod5.py +43 -0
- smftools/informatics/archived/helpers/archived/__init__.py +71 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
- smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
- smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
- smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
- smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
- smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
- smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
- smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
- smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
- smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
- smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
- smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
- smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
- smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
- smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
- smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
- smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
- smftools/informatics/bam_functions.py +812 -0
- smftools/informatics/basecalling.py +67 -0
- smftools/informatics/bed_functions.py +366 -0
- smftools/informatics/binarize_converted_base_identities.py +172 -0
- smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
- smftools/informatics/fasta_functions.py +255 -0
- smftools/informatics/h5ad_functions.py +197 -0
- smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
- smftools/informatics/modkit_functions.py +129 -0
- smftools/informatics/ohe.py +160 -0
- smftools/informatics/pod5_functions.py +224 -0
- smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
- smftools/machine_learning/__init__.py +12 -0
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +234 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +31 -0
- smftools/machine_learning/evaluation/evaluators.py +223 -0
- smftools/machine_learning/inference/__init__.py +3 -0
- smftools/machine_learning/inference/inference_utils.py +27 -0
- smftools/machine_learning/inference/lightning_inference.py +68 -0
- smftools/machine_learning/inference/sklearn_inference.py +55 -0
- smftools/machine_learning/inference/sliding_window_inference.py +114 -0
- smftools/machine_learning/models/base.py +295 -0
- smftools/machine_learning/models/cnn.py +138 -0
- smftools/machine_learning/models/lightning_base.py +345 -0
- smftools/machine_learning/models/mlp.py +26 -0
- smftools/{tools → machine_learning}/models/positional.py +3 -2
- smftools/{tools → machine_learning}/models/rnn.py +2 -1
- smftools/machine_learning/models/sklearn_models.py +273 -0
- smftools/machine_learning/models/transformer.py +303 -0
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +135 -0
- smftools/machine_learning/training/train_sklearn_model.py +114 -0
- smftools/plotting/__init__.py +4 -1
- smftools/plotting/autocorrelation_plotting.py +609 -0
- smftools/plotting/general_plotting.py +1292 -140
- smftools/plotting/hmm_plotting.py +260 -0
- smftools/plotting/qc_plotting.py +270 -0
- smftools/preprocessing/__init__.py +15 -8
- smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
- smftools/preprocessing/append_base_context.py +122 -0
- smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
- smftools/preprocessing/binarize.py +17 -0
- smftools/preprocessing/binarize_on_Youden.py +2 -2
- smftools/preprocessing/calculate_complexity_II.py +248 -0
- smftools/preprocessing/calculate_coverage.py +10 -1
- smftools/preprocessing/calculate_position_Youden.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +101 -0
- smftools/preprocessing/clean_NaN.py +17 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
- smftools/preprocessing/flag_duplicate_reads.py +1326 -124
- smftools/preprocessing/invert_adata.py +12 -5
- smftools/preprocessing/load_sample_sheet.py +19 -4
- smftools/readwrite.py +1021 -89
- smftools/tools/__init__.py +3 -32
- smftools/tools/calculate_umap.py +5 -5
- smftools/tools/general_tools.py +3 -3
- smftools/tools/position_stats.py +468 -106
- smftools/tools/read_stats.py +115 -1
- smftools/tools/spatial_autocorrelation.py +562 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
- smftools-0.2.3.dist-info/RECORD +173 -0
- smftools-0.2.3.dist-info/entry_points.txt +2 -0
- smftools/informatics/fast5_to_pod5.py +0 -21
- smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
- smftools/informatics/helpers/__init__.py +0 -74
- smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
- smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
- smftools/informatics/helpers/bam_qc.py +0 -66
- smftools/informatics/helpers/bed_to_bigwig.py +0 -39
- smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
- smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
- smftools/informatics/helpers/index_fasta.py +0 -12
- smftools/informatics/helpers/make_dirs.py +0 -21
- smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
- smftools/informatics/load_adata.py +0 -182
- smftools/informatics/readwrite.py +0 -106
- smftools/informatics/subsample_fasta_from_bed.py +0 -47
- smftools/preprocessing/append_C_context.py +0 -82
- smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
- smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
- smftools/preprocessing/filter_reads_on_length.py +0 -51
- smftools/tools/call_hmm_peaks.py +0 -105
- smftools/tools/data/__init__.py +0 -2
- smftools/tools/data/anndata_data_module.py +0 -90
- smftools/tools/inference/__init__.py +0 -1
- smftools/tools/inference/lightning_inference.py +0 -41
- smftools/tools/models/base.py +0 -14
- smftools/tools/models/cnn.py +0 -34
- smftools/tools/models/lightning_base.py +0 -41
- smftools/tools/models/mlp.py +0 -17
- smftools/tools/models/sklearn_models.py +0 -40
- smftools/tools/models/transformer.py +0 -133
- smftools/tools/training/__init__.py +0 -1
- smftools/tools/training/train_lightning_model.py +0 -47
- smftools-0.1.7.dist-info/RECORD +0 -136
- /smftools/{tools/evaluation → cli}/__init__.py +0 -0
- /smftools/{tools → hmm}/calculate_distances.py +0 -0
- /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
- /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
- /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
- /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
- /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
- /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
- /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
- /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
- /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
- /smftools/{tools → machine_learning}/models/__init__.py +0 -0
- /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
- /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
- /smftools/{tools → machine_learning}/utils/device.py +0 -0
- /smftools/{tools → machine_learning}/utils/grl.py +0 -0
- /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
- /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
- {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,149 +1,1351 @@
|
|
|
1
|
+
# duplicate_detection_with_hier_and_plots.py
|
|
2
|
+
import copy
|
|
3
|
+
import warnings
|
|
4
|
+
import math
|
|
5
|
+
import os
|
|
6
|
+
from collections import defaultdict
|
|
7
|
+
from typing import Dict, Any, Tuple, Union, List, Optional
|
|
8
|
+
|
|
1
9
|
import torch
|
|
10
|
+
import anndata as ad
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
2
14
|
from tqdm import tqdm
|
|
3
15
|
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
16
|
+
from ..readwrite import make_dirs
|
|
17
|
+
|
|
18
|
+
# optional imports for clustering / PCA / KDE
|
|
19
|
+
try:
|
|
20
|
+
from scipy.cluster import hierarchy as sch
|
|
21
|
+
from scipy.spatial.distance import pdist, squareform
|
|
22
|
+
SCIPY_AVAILABLE = True
|
|
23
|
+
except Exception:
|
|
24
|
+
sch = None
|
|
25
|
+
pdist = None
|
|
26
|
+
squareform = None
|
|
27
|
+
SCIPY_AVAILABLE = False
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
from sklearn.decomposition import PCA
|
|
31
|
+
from sklearn.cluster import KMeans, DBSCAN
|
|
32
|
+
from sklearn.mixture import GaussianMixture
|
|
33
|
+
from sklearn.metrics import silhouette_score
|
|
34
|
+
SKLEARN_AVAILABLE = True
|
|
35
|
+
except Exception:
|
|
36
|
+
PCA = None
|
|
37
|
+
KMeans = DBSCAN = GaussianMixture = silhouette_score = None
|
|
38
|
+
SKLEARN_AVAILABLE = False
|
|
39
|
+
|
|
40
|
+
try:
|
|
41
|
+
from scipy.stats import gaussian_kde
|
|
42
|
+
except Exception:
|
|
43
|
+
gaussian_kde = None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _uns_values_equal(a, b) -> bool:
    """Best-effort equality test for two ``.uns`` values.

    Plain ``a == b`` is tried first and converted to ``bool`` inside the
    guard. Array-like values (e.g. numpy arrays, pandas objects) return an
    element-wise result whose truth value is ambiguous, so ``bool()`` raises;
    in that case fall back to ``np.array_equal``. Values that still cannot be
    compared are treated as unequal (i.e. reported as a conflict).
    """
    try:
        return bool(a == b)
    except Exception:
        # element-wise / non-boolean comparison result; try array semantics
        pass
    try:
        return bool(np.array_equal(a, b))
    except Exception:
        return False


def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
    """
    Merge two .uns dicts, reporting keys whose values conflict.

    Parameters
    ----------
    orig_uns : dict or None
        The original ``.uns`` mapping. ``None`` (or an empty dict) is treated
        as ``{}``.
    new_uns : dict or None
        The new ``.uns`` mapping; it forms the base of the result. ``None`` is
        treated as ``{}``.
    prefer : str, default "orig"
        Which value wins when a key exists in both dicts with unequal values.
        ``"orig"`` keeps the ``orig_uns`` value. Any other value
        (conventionally ``"new"``) keeps the ``new_uns`` value and also stashes
        the ``orig_uns`` value under the key ``f"orig_uns__{key}"``.

    Returns
    -------
    dict
        A new dict. Every value is a deep copy, so neither input is aliased
        or mutated.

    Warns
    -----
    UserWarning
        Once for each key present in both dicts whose values compare unequal.
        Array-valued entries (numpy/pandas) are compared with array semantics
        rather than crashing on an ambiguous truth value.
    """
    out = copy.deepcopy(new_uns) if new_uns is not None else {}
    for k, v in (orig_uns or {}).items():
        if k not in out:
            out[k] = copy.deepcopy(v)
            continue
        # present in both: identical values need no action
        if _uns_values_equal(out[k], v):
            continue
        # conflict
        warnings.warn(f".uns conflict for key '{k}'; keeping '{prefer}' value.")
        if prefer == "orig":
            out[k] = copy.deepcopy(v)
        else:
            # keep out[k] (the new one) and also stash orig under a prefixed key
            out[f"orig_uns__{k}"] = copy.deepcopy(v)
    return out
|
|
71
|
+
|
|
72
|
+
def flag_duplicate_reads(
|
|
73
|
+
adata,
|
|
74
|
+
var_filters_sets,
|
|
75
|
+
distance_threshold: float = 0.07,
|
|
76
|
+
obs_reference_col: str = "Reference_strand",
|
|
77
|
+
sample_col: str = "Barcode",
|
|
78
|
+
output_directory: Optional[str] = None,
|
|
79
|
+
metric_keys: Union[str, List[str]] = ("Fraction_any_C_site_modified",),
|
|
80
|
+
uns_flag: str = "read_duplicate_detection_performed",
|
|
81
|
+
uns_filtered_flag: str = "read_duplicates_removed",
|
|
82
|
+
bypass: bool = False,
|
|
83
|
+
force_redo: bool = False,
|
|
84
|
+
keep_best_metric: Optional[str] = 'read_quality',
|
|
85
|
+
keep_best_higher: bool = True,
|
|
86
|
+
window_size: int = 50,
|
|
87
|
+
min_overlap_positions: int = 20,
|
|
88
|
+
do_pca: bool = False,
|
|
89
|
+
pca_n_components: int = 50,
|
|
90
|
+
pca_center: bool = True,
|
|
91
|
+
do_hierarchical: bool = True,
|
|
92
|
+
hierarchical_linkage: str = "average",
|
|
93
|
+
hierarchical_metric: str = "euclidean",
|
|
94
|
+
hierarchical_window: int = 50,
|
|
95
|
+
random_state: int = 0,
|
|
96
|
+
):
|
|
97
|
+
"""
|
|
98
|
+
Duplicate-flagging pipeline where hierarchical stage operates only on representatives
|
|
99
|
+
(one representative per lex cluster, i.e. the keeper). Final keeper assignment and
|
|
100
|
+
enforcement happens only after hierarchical merging.
|
|
101
|
+
|
|
102
|
+
Returns (adata_unique, adata_full) as before; writes sequence__* columns into adata.obs.
|
|
103
|
+
"""
|
|
104
|
+
# early exits
|
|
105
|
+
already = bool(adata.uns.get(uns_flag, False))
|
|
106
|
+
if (already and not force_redo):
|
|
107
|
+
if "is_duplicate" in adata.obs.columns:
|
|
108
|
+
adata_unique = adata[adata.obs["is_duplicate"] == False].copy()
|
|
109
|
+
return adata_unique, adata
|
|
110
|
+
else:
|
|
111
|
+
return adata.copy(), adata.copy()
|
|
112
|
+
if bypass:
|
|
113
|
+
return None, adata
|
|
114
|
+
|
|
115
|
+
if isinstance(metric_keys, str):
|
|
116
|
+
metric_keys = [metric_keys]
|
|
117
|
+
|
|
118
|
+
# local UnionFind
|
|
119
|
+
class UnionFind:
|
|
120
|
+
def __init__(self, size):
|
|
121
|
+
self.parent = list(range(size))
|
|
122
|
+
|
|
123
|
+
def find(self, x):
|
|
124
|
+
while self.parent[x] != x:
|
|
125
|
+
self.parent[x] = self.parent[self.parent[x]]
|
|
126
|
+
x = self.parent[x]
|
|
127
|
+
return x
|
|
128
|
+
|
|
129
|
+
def union(self, x, y):
|
|
130
|
+
rx = self.find(x); ry = self.find(y)
|
|
131
|
+
if rx != ry:
|
|
132
|
+
self.parent[ry] = rx
|
|
133
|
+
|
|
134
|
+
adata_processed_list = []
|
|
135
|
+
histograms = []
|
|
136
|
+
|
|
137
|
+
samples = adata.obs[sample_col].astype("category").cat.categories
|
|
138
|
+
references = adata.obs[obs_reference_col].astype("category").cat.categories
|
|
139
|
+
|
|
140
|
+
for sample in samples:
|
|
141
|
+
for ref in references:
|
|
142
|
+
print(f"Processing sample={sample} ref={ref}")
|
|
143
|
+
sample_mask = adata.obs[sample_col] == sample
|
|
144
|
+
ref_mask = adata.obs[obs_reference_col] == ref
|
|
145
|
+
subset_mask = sample_mask & ref_mask
|
|
146
|
+
adata_subset = adata[subset_mask].copy()
|
|
147
|
+
|
|
148
|
+
if adata_subset.n_obs < 2:
|
|
149
|
+
print(f" Skipping {sample}_{ref} (too few reads)")
|
|
150
|
+
continue
|
|
151
|
+
|
|
152
|
+
N = adata_subset.shape[0]
|
|
153
|
+
|
|
154
|
+
# Build mask of columns (vars) to use
|
|
155
|
+
combined_mask = np.zeros(len(adata.var), dtype=bool)
|
|
156
|
+
for var_set in var_filters_sets:
|
|
157
|
+
if any(str(ref) in str(v) for v in var_set):
|
|
158
|
+
per_col_mask = np.ones(len(adata.var), dtype=bool)
|
|
159
|
+
for key in var_set:
|
|
160
|
+
per_col_mask &= np.asarray(adata.var[key].values, dtype=bool)
|
|
161
|
+
combined_mask |= per_col_mask
|
|
162
|
+
|
|
163
|
+
selected_cols = adata.var.index[combined_mask.tolist()].to_list()
|
|
164
|
+
col_indices = [adata.var.index.get_loc(c) for c in selected_cols]
|
|
165
|
+
print(f" Selected {len(col_indices)} columns out of {adata.var.shape[0]} for {ref}")
|
|
166
|
+
|
|
167
|
+
# Extract data matrix (dense numpy) for the subset
|
|
168
|
+
X = adata_subset.X
|
|
169
|
+
if not isinstance(X, np.ndarray):
|
|
170
|
+
try:
|
|
171
|
+
X = X.toarray()
|
|
172
|
+
except Exception:
|
|
173
|
+
X = np.asarray(X)
|
|
174
|
+
X_sub = X[:, col_indices].astype(float) # keep NaNs
|
|
175
|
+
|
|
176
|
+
# convert to torch for some vector ops
|
|
177
|
+
X_tensor = torch.from_numpy(X_sub.copy())
|
|
178
|
+
|
|
179
|
+
# per-read nearest distances recorded
|
|
180
|
+
fwd_hamming_to_next = np.full((N,), np.nan, dtype=float)
|
|
181
|
+
rev_hamming_to_prev = np.full((N,), np.nan, dtype=float)
|
|
182
|
+
hierarchical_min_pair = np.full((N,), np.nan, dtype=float)
|
|
183
|
+
|
|
184
|
+
# legacy local lexographic pairwise hamming distances (for histogram)
|
|
185
|
+
local_hamming_dists = []
|
|
186
|
+
# hierarchical discovered dists (for histogram)
|
|
187
|
+
hierarchical_found_dists = []
|
|
188
|
+
|
|
189
|
+
# Lexicographic windowed pass function
|
|
190
|
+
def cluster_pass(X_tensor_local, reverse=False, window=int(window_size), record_distances=False):
|
|
191
|
+
N_local = X_tensor_local.shape[0]
|
|
192
|
+
X_sortable = X_tensor_local.clone().nan_to_num(-1.0)
|
|
193
|
+
sort_keys = [tuple(row.numpy().tolist()) for row in X_sortable]
|
|
194
|
+
sorted_idx = sorted(range(N_local), key=lambda i: sort_keys[i], reverse=reverse)
|
|
195
|
+
sorted_X = X_tensor_local[sorted_idx]
|
|
196
|
+
cluster_pairs_local = []
|
|
197
|
+
|
|
198
|
+
for i in range(len(sorted_X)):
|
|
199
|
+
row_i = sorted_X[i]
|
|
200
|
+
j_range_local = range(i + 1, min(i + 1 + window, len(sorted_X)))
|
|
201
|
+
if len(j_range_local) == 0:
|
|
202
|
+
continue
|
|
203
|
+
block_rows = sorted_X[list(j_range_local)]
|
|
204
|
+
row_i_exp = row_i.unsqueeze(0) # (1, D)
|
|
205
|
+
valid_mask = (~torch.isnan(row_i_exp)) & (~torch.isnan(block_rows)) # (M, D)
|
|
206
|
+
valid_counts = valid_mask.sum(dim=1).float()
|
|
207
|
+
enough_overlap = valid_counts >= float(min_overlap_positions)
|
|
208
|
+
if enough_overlap.any():
|
|
209
|
+
diffs = (row_i_exp != block_rows) & valid_mask
|
|
210
|
+
hamming_counts = diffs.sum(dim=1).float()
|
|
211
|
+
hamming_dists = torch.where(valid_counts > 0, hamming_counts / valid_counts, torch.tensor(float("nan")))
|
|
212
|
+
# record distances (legacy list of all local comparisons)
|
|
213
|
+
hamming_np = hamming_dists.cpu().numpy().tolist()
|
|
214
|
+
local_hamming_dists.extend([float(x) for x in hamming_np if (not np.isnan(x))])
|
|
215
|
+
matches = (hamming_dists < distance_threshold) & (enough_overlap)
|
|
216
|
+
for offset_local, m in enumerate(matches):
|
|
217
|
+
if m:
|
|
218
|
+
i_global = sorted_idx[i]
|
|
219
|
+
j_global = sorted_idx[i + 1 + offset_local]
|
|
220
|
+
cluster_pairs_local.append((i_global, j_global))
|
|
221
|
+
if record_distances:
|
|
222
|
+
# record next neighbor distance for the item (global index)
|
|
223
|
+
next_local_idx = i + 1
|
|
224
|
+
if next_local_idx < len(sorted_X):
|
|
225
|
+
next_global = sorted_idx[next_local_idx]
|
|
226
|
+
vm_pair = (~torch.isnan(row_i)) & (~torch.isnan(sorted_X[next_local_idx]))
|
|
227
|
+
vc = vm_pair.sum().item()
|
|
228
|
+
if vc >= min_overlap_positions:
|
|
229
|
+
d = float(((row_i[vm_pair] != sorted_X[next_local_idx][vm_pair]).sum().item()) / vc)
|
|
230
|
+
if reverse:
|
|
231
|
+
rev_hamming_to_prev[next_global] = d
|
|
232
|
+
else:
|
|
233
|
+
fwd_hamming_to_next[sorted_idx[i]] = d
|
|
234
|
+
return cluster_pairs_local
|
|
235
|
+
|
|
236
|
+
# run forward pass
|
|
237
|
+
pairs_fwd = cluster_pass(X_tensor, reverse=False, record_distances=True)
|
|
238
|
+
involved_in_fwd = set([p for pair in pairs_fwd for p in pair])
|
|
239
|
+
# build mask for reverse pass to avoid re-checking items already paired
|
|
240
|
+
mask_for_rev = np.ones(N, dtype=bool)
|
|
241
|
+
if len(involved_in_fwd) > 0:
|
|
242
|
+
for idx in involved_in_fwd:
|
|
243
|
+
mask_for_rev[idx] = False
|
|
244
|
+
rev_idx_map = np.nonzero(mask_for_rev)[0].tolist()
|
|
245
|
+
if len(rev_idx_map) > 0:
|
|
246
|
+
reduced_tensor = X_tensor[rev_idx_map]
|
|
247
|
+
pairs_rev_local = cluster_pass(reduced_tensor, reverse=True, record_distances=True)
|
|
248
|
+
# remap local reduced indices to global
|
|
249
|
+
remapped_rev_pairs = [(int(rev_idx_map[i]), int(rev_idx_map[j])) for (i, j) in pairs_rev_local]
|
|
250
|
+
else:
|
|
251
|
+
remapped_rev_pairs = []
|
|
252
|
+
|
|
253
|
+
all_pairs = pairs_fwd + remapped_rev_pairs
|
|
254
|
+
|
|
255
|
+
# initial union-find based on lex pairs
|
|
256
|
+
uf = UnionFind(N)
|
|
257
|
+
for i, j in all_pairs:
|
|
258
|
+
uf.union(i, j)
|
|
259
|
+
|
|
260
|
+
# initial merged clusters (lex-level)
|
|
261
|
+
merged_cluster = np.zeros((N,), dtype=int)
|
|
262
|
+
for i in range(N):
|
|
263
|
+
merged_cluster[i] = uf.find(i)
|
|
264
|
+
unique_initial = np.unique(merged_cluster)
|
|
265
|
+
id_map = {old: new for new, old in enumerate(sorted(unique_initial.tolist()))}
|
|
266
|
+
merged_cluster_mapped = np.array([id_map[int(x)] for x in merged_cluster], dtype=int)
|
|
267
|
+
|
|
268
|
+
# cluster sizes and choose lex-keeper per lex-cluster (representatives)
|
|
269
|
+
cluster_sizes = np.zeros_like(merged_cluster_mapped)
|
|
270
|
+
cluster_counts = []
|
|
271
|
+
unique_clusters = np.unique(merged_cluster_mapped)
|
|
272
|
+
keeper_for_cluster = {}
|
|
273
|
+
for cid in unique_clusters:
|
|
274
|
+
members = np.where(merged_cluster_mapped == cid)[0].tolist()
|
|
275
|
+
csize = int(len(members))
|
|
276
|
+
cluster_counts.append(csize)
|
|
277
|
+
cluster_sizes[members] = csize
|
|
278
|
+
# pick lex keeper (representative)
|
|
279
|
+
if len(members) == 1:
|
|
280
|
+
keeper_for_cluster[cid] = members[0]
|
|
281
|
+
else:
|
|
282
|
+
if keep_best_metric is None:
|
|
283
|
+
keeper_for_cluster[cid] = members[0]
|
|
284
|
+
else:
|
|
285
|
+
obs_index = list(adata_subset.obs.index)
|
|
286
|
+
member_names = [obs_index[m] for m in members]
|
|
287
|
+
try:
|
|
288
|
+
vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
|
|
289
|
+
except Exception:
|
|
290
|
+
vals = np.array([np.nan] * len(members), dtype=float)
|
|
291
|
+
if np.all(np.isnan(vals)):
|
|
292
|
+
keeper_for_cluster[cid] = members[0]
|
|
293
|
+
else:
|
|
294
|
+
if keep_best_higher:
|
|
295
|
+
nan_mask = np.isnan(vals)
|
|
296
|
+
vals[nan_mask] = -np.inf
|
|
297
|
+
rel_idx = int(np.nanargmax(vals))
|
|
298
|
+
else:
|
|
299
|
+
nan_mask = np.isnan(vals)
|
|
300
|
+
vals[nan_mask] = np.inf
|
|
301
|
+
rel_idx = int(np.nanargmin(vals))
|
|
302
|
+
keeper_for_cluster[cid] = members[rel_idx]
|
|
303
|
+
|
|
304
|
+
# expose lex keeper info (record only; do not enforce deletion yet)
|
|
305
|
+
lex_is_keeper = np.zeros((N,), dtype=bool)
|
|
306
|
+
lex_is_duplicate = np.zeros((N,), dtype=bool)
|
|
307
|
+
for cid, members in zip(unique_clusters, [np.where(merged_cluster_mapped == cid)[0].tolist() for cid in unique_clusters]):
|
|
308
|
+
keeper_idx = keeper_for_cluster[cid]
|
|
309
|
+
lex_is_keeper[keeper_idx] = True
|
|
310
|
+
for m in members:
|
|
311
|
+
if m != keeper_idx:
|
|
312
|
+
lex_is_duplicate[m] = True
|
|
313
|
+
# note: these are just recorded for inspection / later preference
|
|
314
|
+
# and will be written to adata_subset.obs below
|
|
315
|
+
# record lex min pair (min of fwd/rev neighbor) for each read
|
|
316
|
+
min_pair = np.full((N,), np.nan, dtype=float)
|
|
317
|
+
for i in range(N):
|
|
318
|
+
a = fwd_hamming_to_next[i]
|
|
319
|
+
b = rev_hamming_to_prev[i]
|
|
320
|
+
vals = []
|
|
321
|
+
if not np.isnan(a):
|
|
322
|
+
vals.append(a)
|
|
323
|
+
if not np.isnan(b):
|
|
324
|
+
vals.append(b)
|
|
325
|
+
if vals:
|
|
326
|
+
min_pair[i] = float(np.nanmin(vals))
|
|
327
|
+
|
|
328
|
+
# --- hierarchical on representatives only ---
|
|
329
|
+
hierarchical_pairs = [] # (rep_global_i, rep_global_j, d)
|
|
330
|
+
rep_global_indices = sorted(set(keeper_for_cluster.values()))
|
|
331
|
+
if do_hierarchical and len(rep_global_indices) > 1:
|
|
332
|
+
if not SKLEARN_AVAILABLE:
|
|
333
|
+
warnings.warn("sklearn not available; skipping PCA/hierarchical pass.")
|
|
334
|
+
elif not SCIPY_AVAILABLE:
|
|
335
|
+
warnings.warn("scipy not available; skipping hierarchical pass.")
|
|
336
|
+
else:
|
|
337
|
+
# build reps array and impute for PCA
|
|
338
|
+
reps_X = X_sub[rep_global_indices, :]
|
|
339
|
+
reps_arr = np.array(reps_X, dtype=float, copy=True)
|
|
340
|
+
col_means = np.nanmean(reps_arr, axis=0)
|
|
341
|
+
col_means = np.where(np.isnan(col_means), 0.0, col_means)
|
|
342
|
+
inds = np.where(np.isnan(reps_arr))
|
|
343
|
+
if inds[0].size > 0:
|
|
344
|
+
reps_arr[inds] = np.take(col_means, inds[1])
|
|
345
|
+
|
|
346
|
+
# PCA if requested
|
|
347
|
+
if do_pca and PCA is not None:
|
|
348
|
+
n_comp = min(int(pca_n_components), reps_arr.shape[1], reps_arr.shape[0])
|
|
349
|
+
if n_comp <= 0:
|
|
350
|
+
reps_for_clustering = reps_arr
|
|
351
|
+
else:
|
|
352
|
+
pca = PCA(n_components=n_comp, random_state=int(random_state), svd_solver="auto", copy=True)
|
|
353
|
+
reps_for_clustering = pca.fit_transform(reps_arr)
|
|
354
|
+
else:
|
|
355
|
+
reps_for_clustering = reps_arr
|
|
356
|
+
|
|
357
|
+
# linkage & leaves (ordering)
|
|
358
|
+
try:
|
|
359
|
+
pdist_vec = pdist(reps_for_clustering, metric=hierarchical_metric)
|
|
360
|
+
Z = sch.linkage(pdist_vec, method=hierarchical_linkage)
|
|
361
|
+
leaves = sch.leaves_list(Z)
|
|
362
|
+
except Exception as e:
|
|
363
|
+
warnings.warn(f"hierarchical pass failed: {e}; skipping hierarchical stage.")
|
|
364
|
+
leaves = np.arange(len(rep_global_indices), dtype=int)
|
|
365
|
+
|
|
366
|
+
# apply windowed hamming comparisons across ordered reps and union via same UF (so clusters of all reads merge)
|
|
367
|
+
order_global_reps = [rep_global_indices[i] for i in leaves]
|
|
368
|
+
n_reps = len(order_global_reps)
|
|
369
|
+
for pos in range(n_reps):
|
|
370
|
+
i_global = order_global_reps[pos]
|
|
371
|
+
for jpos in range(pos + 1, min(pos + 1 + hierarchical_window, n_reps)):
|
|
372
|
+
j_global = order_global_reps[jpos]
|
|
373
|
+
vi = X_sub[int(i_global), :]
|
|
374
|
+
vj = X_sub[int(j_global), :]
|
|
375
|
+
valid_mask = (~np.isnan(vi)) & (~np.isnan(vj))
|
|
376
|
+
overlap = int(valid_mask.sum())
|
|
377
|
+
if overlap < min_overlap_positions:
|
|
378
|
+
continue
|
|
379
|
+
diffs = (vi[valid_mask] != vj[valid_mask]).sum()
|
|
380
|
+
d = float(diffs) / float(overlap)
|
|
381
|
+
if d < distance_threshold:
|
|
382
|
+
uf.union(int(i_global), int(j_global))
|
|
383
|
+
hierarchical_pairs.append((int(i_global), int(j_global), float(d)))
|
|
384
|
+
hierarchical_found_dists.append(float(d))
|
|
385
|
+
|
|
386
|
+
# after hierarchical unions, reconstruct merged clusters for all reads
|
|
387
|
+
merged_cluster_after = np.zeros((N,), dtype=int)
|
|
388
|
+
for i in range(N):
|
|
389
|
+
merged_cluster_after[i] = uf.find(i)
|
|
390
|
+
unique_final = np.unique(merged_cluster_after)
|
|
391
|
+
id_map_final = {old: new for new, old in enumerate(sorted(unique_final.tolist()))}
|
|
392
|
+
merged_cluster_mapped_final = np.array([id_map_final[int(x)] for x in merged_cluster_after], dtype=int)
|
|
393
|
+
|
|
394
|
+
# compute final cluster members and choose final keeper per final cluster
|
|
395
|
+
cluster_sizes_final = np.zeros_like(merged_cluster_mapped_final)
|
|
396
|
+
final_cluster_counts = []
|
|
397
|
+
final_unique = np.unique(merged_cluster_mapped_final)
|
|
398
|
+
final_keeper_for_cluster = {}
|
|
399
|
+
cluster_members_map = {}
|
|
400
|
+
for cid in final_unique:
|
|
401
|
+
members = np.where(merged_cluster_mapped_final == cid)[0].tolist()
|
|
402
|
+
cluster_members_map[cid] = members
|
|
403
|
+
csize = len(members)
|
|
404
|
+
final_cluster_counts.append(csize)
|
|
405
|
+
cluster_sizes_final[members] = csize
|
|
406
|
+
if csize == 1:
|
|
407
|
+
final_keeper_for_cluster[cid] = members[0]
|
|
408
|
+
else:
|
|
409
|
+
# prefer keep_best_metric if available; do not automatically prefer lex-keeper here unless you want to;
|
|
410
|
+
# (user previously asked for preferring lex keepers — if desired, you can prefer lex_is_keeper among members)
|
|
411
|
+
obs_index = list(adata_subset.obs.index)
|
|
412
|
+
member_names = [obs_index[m] for m in members]
|
|
413
|
+
if keep_best_metric is not None and keep_best_metric in adata_subset.obs.columns:
|
|
414
|
+
try:
|
|
415
|
+
vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
|
|
416
|
+
except Exception:
|
|
417
|
+
vals = np.array([np.nan] * len(members), dtype=float)
|
|
418
|
+
if np.all(np.isnan(vals)):
|
|
419
|
+
final_keeper_for_cluster[cid] = members[0]
|
|
420
|
+
else:
|
|
421
|
+
if keep_best_higher:
|
|
422
|
+
nan_mask = np.isnan(vals)
|
|
423
|
+
vals[nan_mask] = -np.inf
|
|
424
|
+
rel_idx = int(np.nanargmax(vals))
|
|
425
|
+
else:
|
|
426
|
+
nan_mask = np.isnan(vals)
|
|
427
|
+
vals[nan_mask] = np.inf
|
|
428
|
+
rel_idx = int(np.nanargmin(vals))
|
|
429
|
+
final_keeper_for_cluster[cid] = members[rel_idx]
|
|
430
|
+
else:
|
|
431
|
+
# if lex keepers present among members, prefer them
|
|
432
|
+
lex_members = [m for m in members if lex_is_keeper[m]]
|
|
433
|
+
if len(lex_members) > 0:
|
|
434
|
+
final_keeper_for_cluster[cid] = lex_members[0]
|
|
435
|
+
else:
|
|
436
|
+
final_keeper_for_cluster[cid] = members[0]
|
|
437
|
+
|
|
438
|
+
# update sequence__is_duplicate based on final clusters: non-keepers in multi-member clusters are duplicates
|
|
439
|
+
sequence_is_duplicate = np.zeros((N,), dtype=bool)
|
|
440
|
+
for cid in final_unique:
|
|
441
|
+
keeper = final_keeper_for_cluster[cid]
|
|
442
|
+
members = cluster_members_map[cid]
|
|
443
|
+
if len(members) > 1:
|
|
444
|
+
for m in members:
|
|
445
|
+
if m != keeper:
|
|
446
|
+
sequence_is_duplicate[m] = True
|
|
447
|
+
|
|
448
|
+
# propagate hierarchical distances into hierarchical_min_pair for all cluster members
|
|
449
|
+
for (i_g, j_g, d) in hierarchical_pairs:
|
|
450
|
+
# identify their final cluster ids (after unions)
|
|
451
|
+
c_i = merged_cluster_mapped_final[int(i_g)]
|
|
452
|
+
c_j = merged_cluster_mapped_final[int(j_g)]
|
|
453
|
+
members_i = cluster_members_map.get(c_i, [int(i_g)])
|
|
454
|
+
members_j = cluster_members_map.get(c_j, [int(j_g)])
|
|
455
|
+
for mi in members_i:
|
|
456
|
+
if np.isnan(hierarchical_min_pair[mi]) or (d < hierarchical_min_pair[mi]):
|
|
457
|
+
hierarchical_min_pair[mi] = d
|
|
458
|
+
for mj in members_j:
|
|
459
|
+
if np.isnan(hierarchical_min_pair[mj]) or (d < hierarchical_min_pair[mj]):
|
|
460
|
+
hierarchical_min_pair[mj] = d
|
|
461
|
+
|
|
462
|
+
# combine lex-phase min_pair and hierarchical_min_pair into the final sequence__min_hamming_to_pair
|
|
463
|
+
combined_min = min_pair.copy()
|
|
464
|
+
for i in range(N):
|
|
465
|
+
hval = hierarchical_min_pair[i]
|
|
466
|
+
if not np.isnan(hval):
|
|
467
|
+
if np.isnan(combined_min[i]) or (hval < combined_min[i]):
|
|
468
|
+
combined_min[i] = hval
|
|
469
|
+
|
|
470
|
+
# write columns back into adata_subset.obs
|
|
471
|
+
adata_subset.obs["sequence__is_duplicate"] = sequence_is_duplicate
|
|
472
|
+
adata_subset.obs["sequence__merged_cluster_id"] = merged_cluster_mapped_final
|
|
473
|
+
adata_subset.obs["sequence__cluster_size"] = cluster_sizes_final
|
|
474
|
+
adata_subset.obs["fwd_hamming_to_next"] = fwd_hamming_to_next
|
|
475
|
+
adata_subset.obs["rev_hamming_to_prev"] = rev_hamming_to_prev
|
|
476
|
+
adata_subset.obs["sequence__hier_hamming_to_pair"] = hierarchical_min_pair
|
|
477
|
+
adata_subset.obs["sequence__min_hamming_to_pair"] = combined_min
|
|
478
|
+
# persist lex bookkeeping columns (informational)
|
|
479
|
+
adata_subset.obs["sequence__lex_is_keeper"] = lex_is_keeper
|
|
480
|
+
adata_subset.obs["sequence__lex_is_duplicate"] = lex_is_duplicate
|
|
481
|
+
|
|
482
|
+
adata_processed_list.append(adata_subset)
|
|
483
|
+
|
|
484
|
+
histograms.append({
|
|
485
|
+
"sample": sample,
|
|
486
|
+
"reference": ref,
|
|
487
|
+
"distances": local_hamming_dists, # lex local comparisons
|
|
488
|
+
"cluster_counts": final_cluster_counts,
|
|
489
|
+
"hierarchical_pairs": hierarchical_found_dists,
|
|
490
|
+
})
|
|
491
|
+
|
|
492
|
+
# Merge annotated subsets back together BEFORE plotting so plotting sees fwd_hamming_to_next, etc.
|
|
493
|
+
_original_uns = copy.deepcopy(adata.uns)
|
|
494
|
+
if len(adata_processed_list) == 0:
|
|
495
|
+
return adata.copy(), adata.copy()
|
|
496
|
+
|
|
497
|
+
adata_full = ad.concat(adata_processed_list, merge="same", join="outer", index_unique=None)
|
|
498
|
+
adata_full.uns = merge_uns_preserve(_original_uns, adata_full.uns, prefer="orig")
|
|
499
|
+
|
|
500
|
+
# Ensure expected numeric columns exist (create if missing)
|
|
501
|
+
for col in ("fwd_hamming_to_next", "rev_hamming_to_prev", "sequence__min_hamming_to_pair", "sequence__hier_hamming_to_pair"):
|
|
502
|
+
if col not in adata_full.obs.columns:
|
|
503
|
+
adata_full.obs[col] = np.nan
|
|
504
|
+
|
|
505
|
+
# histograms (now driven by adata_full if requested)
|
|
506
|
+
hist_outs = os.path.join(output_directory, "read_pair_hamming_distance_histograms")
|
|
507
|
+
make_dirs([hist_outs])
|
|
508
|
+
plot_histogram_pages(histograms,
|
|
509
|
+
distance_threshold=distance_threshold,
|
|
510
|
+
adata=adata_full,
|
|
511
|
+
output_directory=hist_outs,
|
|
512
|
+
distance_types=["min","fwd","rev","hier","lex_local"],
|
|
513
|
+
sample_key=sample_col,
|
|
514
|
+
)
|
|
515
|
+
|
|
516
|
+
# hamming vs metric scatter
|
|
517
|
+
scatter_outs = os.path.join(output_directory, "read_pair_hamming_distance_scatter_plots")
|
|
518
|
+
make_dirs([scatter_outs])
|
|
519
|
+
plot_hamming_vs_metric_pages(adata_full,
|
|
520
|
+
metric_keys=metric_keys,
|
|
521
|
+
output_dir=scatter_outs,
|
|
522
|
+
hamming_col="sequence__min_hamming_to_pair",
|
|
523
|
+
highlight_threshold=distance_threshold,
|
|
524
|
+
highlight_color="red",
|
|
525
|
+
sample_col=sample_col)
|
|
526
|
+
|
|
527
|
+
# boolean columns from neighbor distances
|
|
528
|
+
fwd_vals = pd.to_numeric(adata_full.obs.get("fwd_hamming_to_next", pd.Series(np.nan, index=adata_full.obs.index)), errors="coerce")
|
|
529
|
+
rev_vals = pd.to_numeric(adata_full.obs.get("rev_hamming_to_prev", pd.Series(np.nan, index=adata_full.obs.index)), errors="coerce")
|
|
530
|
+
is_dup_dist = (fwd_vals < float(distance_threshold)) | (rev_vals < float(distance_threshold))
|
|
531
|
+
is_dup_dist = is_dup_dist.fillna(False).astype(bool)
|
|
532
|
+
adata_full.obs["is_duplicate_distance"] = is_dup_dist.values
|
|
533
|
+
|
|
534
|
+
# combine sequence-derived flag with others
|
|
535
|
+
if "sequence__is_duplicate" in adata_full.obs.columns:
|
|
536
|
+
seq_dup = adata_full.obs["sequence__is_duplicate"].astype(bool)
|
|
537
|
+
else:
|
|
538
|
+
seq_dup = pd.Series(False, index=adata_full.obs.index)
|
|
539
|
+
|
|
540
|
+
# cluster-based duplicate indicator (if any clustering columns exist)
|
|
541
|
+
cluster_cols = [c for c in adata_full.obs.columns if c.startswith("hamming_cluster__")]
|
|
542
|
+
if cluster_cols:
|
|
543
|
+
cl_mask = pd.Series(False, index=adata_full.obs.index)
|
|
544
|
+
for c in cluster_cols:
|
|
545
|
+
vals = pd.to_numeric(adata_full.obs[c], errors="coerce")
|
|
546
|
+
mask_pos = (vals > 0) & (vals != -1)
|
|
547
|
+
mask_pos = mask_pos.fillna(False)
|
|
548
|
+
cl_mask |= mask_pos
|
|
549
|
+
adata_full.obs["is_duplicate_clustering"] = cl_mask.values
|
|
550
|
+
else:
|
|
551
|
+
adata_full.obs["is_duplicate_clustering"] = False
|
|
552
|
+
|
|
553
|
+
final_dup = seq_dup | adata_full.obs["is_duplicate_distance"].astype(bool) | adata_full.obs["is_duplicate_clustering"].astype(bool)
|
|
554
|
+
adata_full.obs["is_duplicate"] = final_dup.values
|
|
555
|
+
|
|
556
|
+
# Final keeper enforcement: recompute per-cluster keeper from sequence__merged_cluster_id and
|
|
557
|
+
# ensure that keeper is not marked duplicate
|
|
558
|
+
if "sequence__merged_cluster_id" in adata_full.obs.columns:
|
|
559
|
+
keeper_idx_by_cluster = {}
|
|
560
|
+
metric_col = keep_best_metric if 'keep_best_metric' in locals() else None
|
|
561
|
+
|
|
562
|
+
# group by cluster id
|
|
563
|
+
grp = adata_full.obs[["sequence__merged_cluster_id", "sequence__cluster_size"]].copy()
|
|
564
|
+
for cid, sub in grp.groupby("sequence__merged_cluster_id"):
|
|
565
|
+
try:
|
|
566
|
+
members = sub.index.to_list()
|
|
567
|
+
except Exception:
|
|
568
|
+
members = list(sub.index)
|
|
569
|
+
keeper = None
|
|
570
|
+
# prefer keep_best_metric (if present), else prefer lex keeper among members, else first member
|
|
571
|
+
if metric_col and metric_col in adata_full.obs.columns:
|
|
572
|
+
try:
|
|
573
|
+
vals = pd.to_numeric(adata_full.obs.loc[members, metric_col], errors="coerce")
|
|
574
|
+
if vals.notna().any():
|
|
575
|
+
keeper = vals.idxmax() if keep_best_higher else vals.idxmin()
|
|
576
|
+
else:
|
|
577
|
+
keeper = members[0]
|
|
578
|
+
except Exception:
|
|
579
|
+
keeper = members[0]
|
|
580
|
+
else:
|
|
581
|
+
# prefer lex keeper if present in this merged cluster
|
|
582
|
+
lex_candidates = [m for m in members if ("sequence__lex_is_keeper" in adata_full.obs.columns and adata_full.obs.at[m, "sequence__lex_is_keeper"])]
|
|
583
|
+
if len(lex_candidates) > 0:
|
|
584
|
+
keeper = lex_candidates[0]
|
|
585
|
+
else:
|
|
586
|
+
keeper = members[0]
|
|
587
|
+
|
|
588
|
+
keeper_idx_by_cluster[cid] = keeper
|
|
589
|
+
|
|
590
|
+
# force keepers not to be duplicates
|
|
591
|
+
is_dup_series = adata_full.obs["is_duplicate"].astype(bool)
|
|
592
|
+
for cid, keeper_idx in keeper_idx_by_cluster.items():
|
|
593
|
+
if keeper_idx in adata_full.obs.index:
|
|
594
|
+
is_dup_series.at[keeper_idx] = False
|
|
595
|
+
# clear sequence__is_duplicate for keeper if present
|
|
596
|
+
if "sequence__is_duplicate" in adata_full.obs.columns:
|
|
597
|
+
adata_full.obs.at[keeper_idx, "sequence__is_duplicate"] = False
|
|
598
|
+
# clear lex duplicate flag too if present
|
|
599
|
+
if "sequence__lex_is_duplicate" in adata_full.obs.columns:
|
|
600
|
+
adata_full.obs.at[keeper_idx, "sequence__lex_is_duplicate"] = False
|
|
601
|
+
|
|
602
|
+
adata_full.obs["is_duplicate"] = is_dup_series.values
|
|
603
|
+
|
|
604
|
+
# reason column
|
|
605
|
+
def _dup_reason_row(row):
|
|
606
|
+
reasons = []
|
|
607
|
+
if row.get("is_duplicate_distance", False):
|
|
608
|
+
reasons.append("distance_thresh")
|
|
609
|
+
if row.get("is_duplicate_clustering", False):
|
|
610
|
+
reasons.append("hamming_metric_cluster")
|
|
611
|
+
if bool(row.get("sequence__is_duplicate", False)):
|
|
612
|
+
reasons.append("sequence_cluster")
|
|
613
|
+
return ";".join(reasons) if reasons else ""
|
|
614
|
+
|
|
615
|
+
try:
|
|
616
|
+
reasons = adata_full.obs.apply(_dup_reason_row, axis=1)
|
|
617
|
+
adata_full.obs["is_duplicate_reason"] = reasons.values
|
|
618
|
+
except Exception:
|
|
619
|
+
adata_full.obs["is_duplicate_reason"] = ""
|
|
620
|
+
|
|
621
|
+
adata_unique = adata_full[~adata_full.obs["is_duplicate"].astype(bool)].copy()
|
|
622
|
+
|
|
623
|
+
# mark flags in .uns
|
|
624
|
+
adata_unique.uns[uns_flag] = True
|
|
625
|
+
adata_unique.uns[uns_filtered_flag] = True
|
|
626
|
+
adata_full.uns[uns_flag] = True
|
|
627
|
+
|
|
628
|
+
return adata_unique, adata_full
|
|
7
629
|
|
|
8
|
-
def find(self, x):
|
|
9
|
-
while self.parent[x] != x:
|
|
10
|
-
self.parent[x] = self.parent[self.parent[x]]
|
|
11
|
-
x = self.parent[x]
|
|
12
|
-
return x
|
|
13
630
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
if root_x != root_y:
|
|
18
|
-
self.parent[root_y] = root_x
|
|
631
|
+
# ---------------------------
|
|
632
|
+
# Plot helpers (use adata_full as input)
|
|
633
|
+
# ---------------------------
|
|
19
634
|
|
|
635
|
+
def plot_histogram_pages(
    histograms,
    distance_threshold,
    output_directory=None,
    rows_per_page=6,
    bins=50,
    dpi=160,
    figsize_per_cell=(5, 3),
    adata: Optional[ad.AnnData] = None,
    sample_key: str = "Barcode",
    ref_key: str = "Reference_strand",
    distance_key: str = "sequence__min_hamming_to_pair",
    distance_types: Optional[List[str]] = None,
):
    """
    Plot Hamming-distance histograms as a grid (rows=samples, cols=references).

    For every chunk of ``rows_per_page`` samples two pages are produced:

    * a distance page overlaying one histogram per entry of ``distance_types``;
    * a cluster-size page histogramming the size of each merged sequence cluster.

    Every subplot in a column (same reference) uses the same x-axis range and the
    same bin edges. These are computed from the union of that reference's values
    across all samples and distance types. Values outside [0, 1] are dropped.

    Parameters
    ----------
    histograms : list of dict or None
        Per (sample, reference) records with keys ``"sample"`` and ``"reference"``,
        plus optionally ``"distances"`` (plotted as ``"lex_local"``),
        ``"hierarchical_pairs"`` (used for ``"hier"`` only where ``adata`` gave no
        per-read values) and ``"cluster_counts"`` (cluster-size fallback).
    distance_threshold : float
        Drawn as a dashed red vertical line. Each column's x-range extends to at
        least this value, capped at 1.0.
    output_directory : str, optional
        When given, pages are saved there as PNG files; otherwise ``plt.show()``.
    rows_per_page : int
        Sample rows per page (values below 1 are treated as 1).
    bins : int
        Number of bins in each distance subplot.
    dpi : int
        Figure resolution.
    figsize_per_cell : tuple of float
        (width, height) of one subplot, in inches.
    adata : AnnData, optional
        Used only when ``adata.obs`` has both ``sample_key`` and ``ref_key``.
        In that case samples and references are taken from those columns' category
        order, and per-read distances are read from ``adata.obs``.
    sample_key, ref_key : str
        ``adata.obs`` columns holding the sample / reference labels.
    distance_key : str
        ``adata.obs`` column plotted as the ``"min"`` distance type.
    distance_types : list of str, optional
        Any of ``"min"``, ``"fwd"``, ``"rev"``, ``"hier"``, ``"lex_local"``
        (default: all five).

    Returns
    -------
    dict
        ``{"distance_pages": [...], "cluster_size_pages": [...]}``: paths of the
        saved PNG files (both lists are empty when nothing was saved).
    """
    if distance_types is None:
        distance_types = ["min", "fwd", "rev", "hier", "lex_local"]
    histograms = list(histograms or [])
    rows_per_page = max(1, int(rows_per_page))

    # Distance type -> adata.obs column providing per-read values for it.
    # "lex_local" has no obs column; it only comes from ``histograms``.
    obs_distance_columns = {
        "min": distance_key,
        "fwd": "fwd_hamming_to_next",
        "rev": "rev_hamming_to_prev",
        "hier": "sequence__hier_hamming_to_pair",
    }

    def _categories(series):
        # Row/column order follows category order; plain columns are converted
        # first. (``pd.api.types.is_categorical_dtype`` is deprecated in pandas 2.1+.)
        if not isinstance(series.dtype, pd.CategoricalDtype):
            series = series.astype("category")
        return list(series.cat.categories)

    def clean_array(arr):
        # Keep only finite values inside [0, 1], as a float array.
        if arr is None or len(arr) == 0:
            return np.array([], dtype=float)
        a = np.asarray(arr, dtype=float)
        a = a[np.isfinite(a)]
        return a[(a >= 0.0) & (a <= 1.0)]

    use_adata = (
        adata is not None
        and sample_key in adata.obs.columns
        and ref_key in adata.obs.columns
    )
    if use_adata:
        obs = adata.obs
        samples = _categories(obs[sample_key])
        references = _categories(obs[ref_key])
    else:
        obs = None
        samples = sorted({h["sample"] for h in histograms})
        references = sorted({h["reference"] for h in histograms})

    if len(samples) == 0 or len(references) == 0:
        print("No histogram data to plot.")
        return {"distance_pages": [], "cluster_size_pages": []}

    # grid[(sample, reference)][distance_type] -> list of distances in [0, 1]
    grid = defaultdict(lambda: defaultdict(list))

    if use_adata:
        try:
            # observed=True skips empty category combinations (they add no values)
            # and avoids pandas' FutureWarning about the ``observed`` default.
            grouped = list(obs.groupby([sample_key, ref_key], observed=True))
        except Exception:
            # Fallback: build the (sample, reference) groups by explicit masking.
            grouped = []
            for s in samples:
                for r in references:
                    sub = obs[(obs[sample_key] == s) & (obs[ref_key] == r)]
                    if not sub.empty:
                        grouped.append(((s, r), sub))
        for (s, r), group in grouped:
            for dt, col in obs_distance_columns.items():
                if dt in distance_types and col in group.columns:
                    grid[(s, r)][dt].extend(clean_array(group[col].to_numpy()))

    # Legacy ``histograms`` input. The per-pair ``hierarchical_pairs`` are only a
    # fallback for cells where adata gave no per-read "hier" values; adding both
    # would count the same hierarchical matches twice.
    cells_with_obs_hier = {key for key, per_type in grid.items() if per_type.get("hier")}
    for h in histograms:
        key = (h["sample"], h["reference"])
        if "lex_local" in distance_types:
            grid[key]["lex_local"].extend(clean_array(h.get("distances", [])))
        if (
            "hier" in distance_types
            and "hierarchical_pairs" in h
            and key not in cells_with_obs_hier
        ):
            grid[key]["hier"].extend(clean_array(h.get("hierarchical_pairs", [])))

    # Per-reference x-range shared by every subplot in that column.
    ref_xmax = {}
    for ref in references:
        pieces = [
            np.asarray(grid[(s, ref)].get(dt, []), dtype=float)
            for s in samples
            for dt in distance_types
        ]
        allvals = np.concatenate(pieces) if pieces else np.array([], dtype=float)
        allvals = allvals[np.isfinite(allvals)]
        if allvals.size:
            # Always reach at least the threshold; clamp to [0.01, 1.0] so the
            # bins never collapse to zero width.
            vmax = max(float(allvals.max()), float(distance_threshold))
            ref_xmax[ref] = min(1.0, max(0.01, vmax))
        else:
            ref_xmax[ref] = 1.0

    # Read totals shown in each sample's row label.
    if use_adata:
        sample_totals = {s: int((obs[sample_key] == s).sum()) for s in samples}
    else:
        sample_totals = {
            s: sum(len(grid[(s, r)][dt]) for r in references for dt in distance_types)
            for s in samples
        }

    has_cluster_cols = use_adata and (
        "sequence__merged_cluster_id" in obs.columns
        and "sequence__cluster_size" in obs.columns
    )

    def _cluster_sizes(sample_name, ref_name):
        # One size per merged cluster from adata, else the legacy cluster_counts.
        sizes = []
        if has_cluster_cols:
            sub = obs[(obs[sample_key] == sample_name) & (obs[ref_key] == ref_name)]
            if not sub.empty:
                try:
                    first_sizes = sub.groupby(
                        "sequence__merged_cluster_id", observed=True
                    )["sequence__cluster_size"].first()
                    sizes = [
                        int(x)
                        for x in first_sizes.to_numpy().tolist()
                        if (pd.notna(x) and np.isfinite(x))
                    ]
                except Exception:
                    try:
                        unique_pairs = sub[
                            ["sequence__merged_cluster_id", "sequence__cluster_size"]
                        ].drop_duplicates()
                        sizes = [
                            int(x)
                            for x in unique_pairs["sequence__cluster_size"].dropna().astype(int).tolist()
                        ]
                    except Exception:
                        sizes = []
        if not sizes:
            for h in histograms:
                if h.get("sample") == sample_name and h.get("reference") == ref_name:
                    legacy = h.get("cluster_counts")
                    # ``list(...)`` also accepts numpy arrays, whose truth value
                    # would be ambiguous in an ``or`` expression.
                    sizes = [] if legacy is None else list(legacy)
                    break
        return sizes

    def _label_cell(ax, r_idx, c_idx, nrows, ref_name, sample_name):
        # Column titles on the first row, sample labels on the first column and
        # x tick labels only on the bottom row.
        if r_idx == 0:
            ax.set_title(str(ref_name), fontsize=10)
        if c_idx == 0:
            ax.set_ylabel(f"{sample_name}\n(n={sample_totals[sample_name]})", fontsize=9)
        if r_idx != nrows - 1:
            ax.tick_params(axis="x", labelbottom=False)

    def _emit(fig, filename, saved):
        # Save under output_directory (recording the path) or show; always close.
        if output_directory:
            path = os.path.join(output_directory, filename)
            fig.savefig(path, bbox_inches="tight")
            saved.append(path)
        else:
            plt.show()
        plt.close(fig)

    if output_directory:
        os.makedirs(output_directory, exist_ok=True)

    distance_pages = []
    cluster_size_pages = []
    n_pages = math.ceil(len(samples) / rows_per_page)
    ncols = len(references)
    palette = plt.get_cmap("tab10")
    dtype_colors = {dt: palette(i % 10) for i, dt in enumerate(distance_types)}

    for page in range(n_pages):
        chunk = samples[page * rows_per_page:(page + 1) * rows_per_page]
        nrows = len(chunk)
        figsize = (figsize_per_cell[0] * ncols, figsize_per_cell[1] * nrows)

        # ---- Distance histogram page ----
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, dpi=dpi, squeeze=False)
        for r_idx, sample_name in enumerate(chunk):
            for c_idx, ref_name in enumerate(references):
                ax = axes[r_idx][c_idx]
                ref_vmax = ref_xmax.get(ref_name, 1.0)
                bin_edges = np.linspace(0.0, ref_vmax, bins + 1)
                any_data = False
                for dtype in distance_types:
                    vals = np.asarray(grid[(sample_name, ref_name)].get(dtype, []), dtype=float)
                    vals = vals[np.isfinite(vals)]
                    vals = vals[(vals >= 0.0) & (vals <= ref_vmax)]
                    if vals.size:
                        any_data = True
                        ax.hist(vals, bins=bin_edges, alpha=0.5, label=dtype, density=False,
                                stacked=False, color=dtype_colors.get(dtype, None))
                if not any_data:
                    ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes,
                            fontsize=10, color="gray")
                ax.axvline(distance_threshold, color="red", linestyle="--", linewidth=1)

                _label_cell(ax, r_idx, c_idx, nrows, ref_name, sample_name)
                if r_idx == nrows - 1:
                    ax.set_xlabel("Hamming Distance", fontsize=9)
                ax.set_xlim(left=0.0, right=ref_vmax)
                ax.grid(True, alpha=0.25)
                # Only draw a legend when the cell actually has labelled histograms.
                if r_idx == 0 and c_idx == 0 and ax.get_legend_handles_labels()[0]:
                    ax.legend(fontsize=7, loc="upper right")

        fig.suptitle(f"Hamming distance histograms (rows=samples, cols=references) — page {page+1}/{n_pages}", fontsize=12, y=0.995)
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        _emit(fig, f"hamming_histograms_page_{page+1}.png", distance_pages)

        # ---- Cluster-size histogram page ----
        fig2, axes2 = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, dpi=dpi, squeeze=False)
        for r_idx, sample_name in enumerate(chunk):
            for c_idx, ref_name in enumerate(references):
                ax = axes2[r_idx][c_idx]
                sizes = _cluster_sizes(sample_name, ref_name)
                if sizes:
                    max_size = max(1, int(max(sizes)))
                    # Edges 1..max+1 give each integer size k its own bin [k, k+1).
                    # numpy's last bin is closed, so edges ending at max would merge
                    # the two largest sizes, and a lone edge would draw nothing.
                    ax.hist(sizes, bins=np.arange(1, max_size + 2), alpha=0.8, align="left")
                    ax.set_xlabel("Cluster size")
                    ax.set_ylabel("Count")
                else:
                    ax.text(0.5, 0.5, "No clusters", ha="center", va="center", transform=ax.transAxes,
                            fontsize=10, color="gray")

                _label_cell(ax, r_idx, c_idx, nrows, ref_name, sample_name)
                ax.grid(True, alpha=0.25)

        fig2.suptitle(f"Union-find cluster size histograms — page {page+1}/{n_pages}", fontsize=12, y=0.995)
        fig2.tight_layout(rect=[0, 0, 1, 0.96])
        _emit(fig2, f"cluster_size_histograms_page_{page+1}.png", cluster_size_pages)

    return {"distance_pages": distance_pages, "cluster_size_pages": cluster_size_pages}
|
|
882
|
+
|
|
883
|
+
|
|
884
|
+
def plot_hamming_vs_metric_pages(
|
|
885
|
+
adata,
|
|
886
|
+
metric_keys: Union[str, List[str]],
|
|
887
|
+
hamming_col: str = "fwd_hamming_to_next",
|
|
888
|
+
sample_col: str = "Barcode",
|
|
889
|
+
ref_col: str = "Reference_strand",
|
|
890
|
+
references: Optional[List[str]] = None,
|
|
891
|
+
samples: Optional[List[str]] = None,
|
|
892
|
+
rows_per_fig: int = 6,
|
|
893
|
+
dpi: int = 160,
|
|
894
|
+
filename_prefix: str = "hamming_vs_metric",
|
|
895
|
+
output_dir: Optional[str] = None,
|
|
896
|
+
kde: bool = False,
|
|
897
|
+
contour: bool = False,
|
|
898
|
+
regression: bool = True,
|
|
899
|
+
show_ticks: bool = True,
|
|
900
|
+
clustering: Optional[Dict[str, Any]] = None,
|
|
901
|
+
write_clusters_to_adata: bool = False,
|
|
902
|
+
figsize_per_cell: Tuple[float, float] = (4.0, 3.0),
|
|
903
|
+
random_state: int = 0,
|
|
904
|
+
highlight_threshold: Optional[float] = None,
|
|
905
|
+
highlight_color: str = "red",
|
|
906
|
+
color_by_duplicate: bool = False,
|
|
907
|
+
duplicate_col: str = "is_duplicate",
|
|
908
|
+
) -> Dict[str, Any]:
|
|
909
|
+
"""
|
|
910
|
+
Plot hamming (y) vs metric (x).
|
|
911
|
+
|
|
912
|
+
New behavior:
|
|
913
|
+
- If color_by_duplicate is True and adata.obs[duplicate_col] exists, points are colored by that boolean:
|
|
914
|
+
duplicates -> highlight_color (with edge), non-duplicates -> gray
|
|
915
|
+
- If color_by_duplicate is False, previous highlight_threshold behavior is preserved.
|
|
916
|
+
"""
|
|
917
|
+
if isinstance(metric_keys, str):
|
|
918
|
+
metric_keys = [metric_keys]
|
|
919
|
+
metric_keys = list(metric_keys)
|
|
920
|
+
|
|
921
|
+
if output_dir is not None:
|
|
922
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
923
|
+
|
|
924
|
+
obs = adata.obs
|
|
925
|
+
if sample_col not in obs.columns or ref_col not in obs.columns:
|
|
926
|
+
raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")
|
|
927
|
+
|
|
928
|
+
# canonicalize samples and refs
|
|
929
|
+
if samples is None:
|
|
930
|
+
sseries = obs[sample_col]
|
|
931
|
+
if not pd.api.types.is_categorical_dtype(sseries):
|
|
932
|
+
sseries = sseries.astype("category")
|
|
933
|
+
samples_all = list(sseries.cat.categories)
|
|
934
|
+
else:
|
|
935
|
+
samples_all = list(samples)
|
|
936
|
+
|
|
937
|
+
if references is None:
|
|
938
|
+
rseries = obs[ref_col]
|
|
939
|
+
if not pd.api.types.is_categorical_dtype(rseries):
|
|
940
|
+
rseries = rseries.astype("category")
|
|
941
|
+
refs_all = list(rseries.cat.categories)
|
|
942
|
+
else:
|
|
943
|
+
refs_all = list(references)
|
|
944
|
+
|
|
945
|
+
extra_col = "ALLREFS"
|
|
946
|
+
cols = list(refs_all) + [extra_col]
|
|
947
|
+
|
|
948
|
+
saved_map: Dict[str, Any] = {}
|
|
949
|
+
|
|
950
|
+
for metric in metric_keys:
|
|
951
|
+
clusters_info: Dict[Tuple[str, str], Dict[str, Any]] = {}
|
|
952
|
+
files: List[str] = []
|
|
953
|
+
written_cols: List[str] = []
|
|
954
|
+
|
|
955
|
+
# compute global x/y limits robustly by aligning Series
|
|
956
|
+
global_xlim = None
|
|
957
|
+
global_ylim = None
|
|
958
|
+
|
|
959
|
+
metric_present = metric in obs.columns
|
|
960
|
+
hamming_present = hamming_col in obs.columns
|
|
961
|
+
|
|
962
|
+
if metric_present or hamming_present:
|
|
963
|
+
sX = obs[metric].astype(float) if metric_present else None
|
|
964
|
+
sY = pd.to_numeric(obs[hamming_col], errors="coerce") if hamming_present else None
|
|
965
|
+
|
|
966
|
+
if (sX is not None) and (sY is not None):
|
|
967
|
+
valid_both = sX.notna() & sY.notna() & np.isfinite(sX.values) & np.isfinite(sY.values)
|
|
968
|
+
if valid_both.any():
|
|
969
|
+
xvals = sX[valid_both].to_numpy(dtype=float)
|
|
970
|
+
yvals = sY[valid_both].to_numpy(dtype=float)
|
|
971
|
+
xmin, xmax = float(np.nanmin(xvals)), float(np.nanmax(xvals))
|
|
972
|
+
ymin, ymax = float(np.nanmin(yvals)), float(np.nanmax(yvals))
|
|
973
|
+
xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
|
|
974
|
+
ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
|
|
975
|
+
global_xlim = (xmin - xpad, xmax + xpad)
|
|
976
|
+
global_ylim = (ymin - ypad, ymax + ypad)
|
|
977
|
+
else:
|
|
978
|
+
sX_finite = sX[np.isfinite(sX)]
|
|
979
|
+
sY_finite = sY[np.isfinite(sY)]
|
|
980
|
+
if sX_finite.size > 0:
|
|
981
|
+
xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
|
|
982
|
+
xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
|
|
983
|
+
global_xlim = (xmin - xpad, xmax + xpad)
|
|
984
|
+
if sY_finite.size > 0:
|
|
985
|
+
ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
|
|
986
|
+
ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
|
|
987
|
+
global_ylim = (ymin - ypad, ymax + ypad)
|
|
988
|
+
elif sX is not None:
|
|
989
|
+
sX_finite = sX[np.isfinite(sX)]
|
|
990
|
+
if sX_finite.size > 0:
|
|
991
|
+
xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
|
|
992
|
+
xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
|
|
993
|
+
global_xlim = (xmin - xpad, xmax + xpad)
|
|
994
|
+
elif sY is not None:
|
|
995
|
+
sY_finite = sY[np.isfinite(sY)]
|
|
996
|
+
if sY_finite.size > 0:
|
|
997
|
+
ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
|
|
998
|
+
ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
|
|
999
|
+
global_ylim = (ymin - ypad, ymax + ypad)
|
|
1000
|
+
|
|
1001
|
+
# pagination
|
|
1002
|
+
for start in range(0, len(samples_all), rows_per_fig):
|
|
1003
|
+
chunk = samples_all[start : start + rows_per_fig]
|
|
1004
|
+
nrows = len(chunk)
|
|
1005
|
+
ncols = len(cols)
|
|
1006
|
+
fig_w = ncols * figsize_per_cell[0]
|
|
1007
|
+
fig_h = nrows * figsize_per_cell[1]
|
|
1008
|
+
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
|
|
1009
|
+
|
|
1010
|
+
for r_idx, sample_name in enumerate(chunk):
|
|
1011
|
+
for c_idx, ref_name in enumerate(cols):
|
|
1012
|
+
ax = axes[r_idx][c_idx]
|
|
1013
|
+
if ref_name == extra_col:
|
|
1014
|
+
mask = (obs[sample_col].values == sample_name)
|
|
1015
|
+
else:
|
|
1016
|
+
mask = (obs[sample_col].values == sample_name) & (obs[ref_col].values == ref_name)
|
|
1017
|
+
|
|
1018
|
+
sub = obs[mask]
|
|
1019
|
+
|
|
1020
|
+
if metric in sub.columns:
|
|
1021
|
+
x_all = pd.to_numeric(sub[metric], errors="coerce").to_numpy(dtype=float)
|
|
1022
|
+
else:
|
|
1023
|
+
x_all = np.array([], dtype=float)
|
|
1024
|
+
if hamming_col in sub.columns:
|
|
1025
|
+
y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(dtype=float)
|
|
1026
|
+
else:
|
|
1027
|
+
y_all = np.array([], dtype=float)
|
|
1028
|
+
|
|
1029
|
+
idxs = sub.index.to_numpy()
|
|
1030
|
+
|
|
1031
|
+
# drop nan pairs
|
|
1032
|
+
if x_all.size and y_all.size and len(x_all) == len(y_all):
|
|
1033
|
+
valid_pair_mask = np.isfinite(x_all) & np.isfinite(y_all)
|
|
1034
|
+
x = x_all[valid_pair_mask]
|
|
1035
|
+
y = y_all[valid_pair_mask]
|
|
1036
|
+
idxs_valid = idxs[valid_pair_mask]
|
|
1037
|
+
else:
|
|
1038
|
+
x = np.array([], dtype=float)
|
|
1039
|
+
y = np.array([], dtype=float)
|
|
1040
|
+
idxs_valid = np.array([], dtype=int)
|
|
1041
|
+
|
|
1042
|
+
if x.size == 0:
|
|
1043
|
+
ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
|
|
1044
|
+
clusters_info[(sample_name, ref_name)] = {"diag": None, "n_points": 0}
|
|
1045
|
+
else:
|
|
1046
|
+
# Decide color mapping
|
|
1047
|
+
if color_by_duplicate and duplicate_col in adata.obs.columns and idxs_valid.size > 0:
|
|
1048
|
+
# get boolean series aligned to idxs_valid
|
|
1049
|
+
try:
|
|
1050
|
+
dup_flags = adata.obs.loc[idxs_valid, duplicate_col].astype(bool).to_numpy()
|
|
1051
|
+
except Exception:
|
|
1052
|
+
dup_flags = np.zeros(len(idxs_valid), dtype=bool)
|
|
1053
|
+
mask_dup = dup_flags
|
|
1054
|
+
mask_nondup = ~mask_dup
|
|
1055
|
+
# plot non-duplicates first in gray, duplicates in highlight color
|
|
1056
|
+
if mask_nondup.any():
|
|
1057
|
+
ax.scatter(x[mask_nondup], y[mask_nondup], s=12, alpha=0.6, rasterized=True, c="lightgray")
|
|
1058
|
+
if mask_dup.any():
|
|
1059
|
+
ax.scatter(x[mask_dup], y[mask_dup], s=20, alpha=0.9, rasterized=True, c=highlight_color, edgecolors="k", linewidths=0.3)
|
|
1060
|
+
else:
|
|
1061
|
+
# old behavior: highlight by threshold if requested
|
|
1062
|
+
if highlight_threshold is not None and y.size:
|
|
1063
|
+
mask_low = (y < float(highlight_threshold)) & np.isfinite(y)
|
|
1064
|
+
mask_high = ~mask_low
|
|
1065
|
+
if mask_high.any():
|
|
1066
|
+
ax.scatter(x[mask_high], y[mask_high], s=12, alpha=0.6, rasterized=True)
|
|
1067
|
+
if mask_low.any():
|
|
1068
|
+
ax.scatter(x[mask_low], y[mask_low], s=18, alpha=0.9, rasterized=True, c=highlight_color, edgecolors="k", linewidths=0.3)
|
|
1069
|
+
else:
|
|
1070
|
+
ax.scatter(x, y, s=12, alpha=0.6, rasterized=True)
|
|
1071
|
+
|
|
1072
|
+
if kde and gaussian_kde is not None and x.size >= 4:
|
|
1073
|
+
try:
|
|
1074
|
+
xy = np.vstack([x, y])
|
|
1075
|
+
kde2 = gaussian_kde(xy)(xy)
|
|
1076
|
+
if contour:
|
|
1077
|
+
xi = np.linspace(np.nanmin(x), np.nanmax(x), 80)
|
|
1078
|
+
yi = np.linspace(np.nanmin(y), np.nanmax(y), 80)
|
|
1079
|
+
xi_g, yi_g = np.meshgrid(xi, yi)
|
|
1080
|
+
coords = np.vstack([xi_g.ravel(), yi_g.ravel()])
|
|
1081
|
+
zi = gaussian_kde(np.vstack([x, y]))(coords).reshape(xi_g.shape)
|
|
1082
|
+
ax.contourf(xi_g, yi_g, zi, levels=8, alpha=0.35, cmap="Blues")
|
|
1083
|
+
else:
|
|
1084
|
+
ax.scatter(x, y, c=kde2, s=16, cmap="viridis", alpha=0.7, linewidths=0)
|
|
1085
|
+
except Exception:
|
|
1086
|
+
pass
|
|
1087
|
+
|
|
1088
|
+
if regression and x.size >= 2:
|
|
1089
|
+
try:
|
|
1090
|
+
a, b = np.polyfit(x, y, 1)
|
|
1091
|
+
xs = np.linspace(np.nanmin(x), np.nanmax(x), 100)
|
|
1092
|
+
ys = a * xs + b
|
|
1093
|
+
ax.plot(xs, ys, linestyle="--", linewidth=1.2, alpha=0.9, color="red")
|
|
1094
|
+
r = np.corrcoef(x, y)[0, 1]
|
|
1095
|
+
ax.text(0.98, 0.02, f"r={float(r):.3f}", ha="right", va="bottom", transform=ax.transAxes, fontsize=8,
|
|
1096
|
+
bbox=dict(facecolor="white", alpha=0.6, boxstyle="round,pad=0.2"))
|
|
1097
|
+
except Exception:
|
|
1098
|
+
pass
|
|
1099
|
+
|
|
1100
|
+
if clustering:
|
|
1101
|
+
cl_labels, diag = _run_clustering(
|
|
1102
|
+
x, y,
|
|
1103
|
+
method=clustering.get("method", "dbscan"),
|
|
1104
|
+
n_clusters=clustering.get("n_clusters", 2),
|
|
1105
|
+
dbscan_eps=clustering.get("dbscan_eps", 0.05),
|
|
1106
|
+
dbscan_min_samples=clustering.get("dbscan_min_samples", 5),
|
|
1107
|
+
random_state=random_state,
|
|
1108
|
+
min_points=clustering.get("min_points", 8),
|
|
1109
|
+
)
|
|
1110
|
+
|
|
1111
|
+
remapped_labels = cl_labels.copy()
|
|
1112
|
+
unique_nonnoise = sorted([u for u in np.unique(cl_labels) if u != -1])
|
|
1113
|
+
if len(unique_nonnoise) > 0:
|
|
1114
|
+
medians = {}
|
|
1115
|
+
for lab in unique_nonnoise:
|
|
1116
|
+
mask_lab = (cl_labels == lab)
|
|
1117
|
+
medians[lab] = float(np.median(y[mask_lab])) if mask_lab.any() else float("nan")
|
|
1118
|
+
sorted_by_median = sorted(unique_nonnoise, key=lambda l: (np.nan if np.isnan(medians[l]) else medians[l]), reverse=True)
|
|
1119
|
+
mapping = {old: new for new, old in enumerate(sorted_by_median)}
|
|
1120
|
+
for i_lab in range(len(remapped_labels)):
|
|
1121
|
+
if remapped_labels[i_lab] != -1:
|
|
1122
|
+
remapped_labels[i_lab] = mapping.get(remapped_labels[i_lab], -1)
|
|
1123
|
+
diag = diag or {}
|
|
1124
|
+
diag["cluster_median_hamming"] = {int(old): medians[old] for old in medians}
|
|
1125
|
+
diag["cluster_old_to_new_map"] = {int(old): int(new) for old, new in mapping.items()}
|
|
95
1126
|
else:
|
|
96
|
-
|
|
1127
|
+
remapped_labels = cl_labels.copy()
|
|
1128
|
+
diag = diag or {}
|
|
1129
|
+
diag["cluster_median_hamming"] = {}
|
|
1130
|
+
diag["cluster_old_to_new_map"] = {}
|
|
1131
|
+
|
|
1132
|
+
_overlay_clusters_on_ax(ax, x, y, remapped_labels, diag,
|
|
1133
|
+
cmap=clustering.get("cmap", "tab10"),
|
|
1134
|
+
hull=clustering.get("hull", True),
|
|
1135
|
+
show_cluster_labels=True)
|
|
1136
|
+
|
|
1137
|
+
clusters_info[(sample_name, ref_name)] = {"diag": diag, "n_points": len(x)}
|
|
1138
|
+
|
|
1139
|
+
if write_clusters_to_adata and idxs_valid.size > 0:
|
|
1140
|
+
colname_safe_ref = (ref_name if ref_name != extra_col else "ALLREFS")
|
|
1141
|
+
colname = f"hamming_cluster__{metric}__{sample_name}__{colname_safe_ref}"
|
|
1142
|
+
if colname not in adata.obs.columns:
|
|
1143
|
+
adata.obs[colname] = np.nan
|
|
1144
|
+
lab_arr = remapped_labels.astype(float)
|
|
1145
|
+
adata.obs.loc[idxs_valid, colname] = lab_arr
|
|
1146
|
+
if colname not in written_cols:
|
|
1147
|
+
written_cols.append(colname)
|
|
1148
|
+
|
|
1149
|
+
if r_idx == 0:
|
|
1150
|
+
ax.set_title(str(ref_name), fontsize=9)
|
|
1151
|
+
if c_idx == 0:
|
|
1152
|
+
total_reads = int((obs[sample_col] == sample_name).sum())
|
|
1153
|
+
ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
|
|
1154
|
+
if r_idx == nrows - 1:
|
|
1155
|
+
ax.set_xlabel(metric, fontsize=8)
|
|
1156
|
+
|
|
1157
|
+
if global_xlim is not None:
|
|
1158
|
+
ax.set_xlim(global_xlim)
|
|
1159
|
+
if global_ylim is not None:
|
|
1160
|
+
ax.set_ylim(global_ylim)
|
|
1161
|
+
|
|
1162
|
+
if not show_ticks:
|
|
1163
|
+
ax.set_xticklabels([])
|
|
1164
|
+
ax.set_yticklabels([])
|
|
1165
|
+
|
|
1166
|
+
ax.grid(True, alpha=0.25)
|
|
1167
|
+
|
|
1168
|
+
fig.suptitle(f"Hamming ({hamming_col}) vs {metric}", y=0.995, fontsize=11)
|
|
1169
|
+
fig.tight_layout(rect=[0, 0, 1, 0.97])
|
|
1170
|
+
|
|
1171
|
+
page_idx = start // rows_per_fig + 1
|
|
1172
|
+
fname = f"{filename_prefix}_{metric}_page{page_idx}.png"
|
|
1173
|
+
if output_dir:
|
|
1174
|
+
outpath = os.path.join(output_dir, fname)
|
|
1175
|
+
plt.savefig(outpath, bbox_inches="tight", dpi=dpi)
|
|
1176
|
+
files.append(outpath)
|
|
1177
|
+
else:
|
|
1178
|
+
plt.show()
|
|
1179
|
+
plt.close(fig)
|
|
1180
|
+
|
|
1181
|
+
saved_map[metric] = {"files": files, "clusters_info": clusters_info, "written_cols": written_cols}
|
|
1182
|
+
|
|
1183
|
+
return saved_map
|
|
1184
|
+
|
|
1185
|
+
|
|
1186
|
+
def _run_clustering(
|
|
1187
|
+
x: np.ndarray,
|
|
1188
|
+
y: np.ndarray,
|
|
1189
|
+
*,
|
|
1190
|
+
method: str = "kmeans", # "kmeans", "dbscan", "gmm", "hdbscan"
|
|
1191
|
+
n_clusters: int = 2,
|
|
1192
|
+
dbscan_eps: float = 0.05,
|
|
1193
|
+
dbscan_min_samples: int = 5,
|
|
1194
|
+
random_state: int = 0,
|
|
1195
|
+
min_points: int = 10,
|
|
1196
|
+
) -> Tuple[np.ndarray, Dict[str, Any]]:
|
|
1197
|
+
"""
|
|
1198
|
+
Run clustering on 2D points (x,y). Returns labels (len = npoints) and diagnostics dict.
|
|
1199
|
+
Labels follow sklearn conventions (noise -> -1 for DBSCAN/HDBSCAN).
|
|
1200
|
+
"""
|
|
1201
|
+
try:
|
|
1202
|
+
from sklearn.cluster import KMeans, DBSCAN
|
|
1203
|
+
from sklearn.mixture import GaussianMixture
|
|
1204
|
+
from sklearn.metrics import silhouette_score
|
|
1205
|
+
except Exception:
|
|
1206
|
+
KMeans = DBSCAN = GaussianMixture = silhouette_score = None
|
|
1207
|
+
|
|
1208
|
+
pts = np.column_stack([x, y])
|
|
1209
|
+
diagnostics: Dict[str, Any] = {"method": method, "n_input": len(x)}
|
|
1210
|
+
if len(x) < min_points:
|
|
1211
|
+
diagnostics["skipped"] = True
|
|
1212
|
+
return np.full(len(x), -1, dtype=int), diagnostics
|
|
1213
|
+
|
|
1214
|
+
method = (method or "kmeans").lower()
|
|
1215
|
+
labels = np.full(len(x), -1, dtype=int)
|
|
1216
|
+
|
|
1217
|
+
try:
|
|
1218
|
+
if method == "kmeans" and KMeans is not None:
|
|
1219
|
+
km = KMeans(n_clusters=max(1, int(n_clusters)), random_state=random_state)
|
|
1220
|
+
labels = km.fit_predict(pts)
|
|
1221
|
+
diagnostics["centers"] = km.cluster_centers_
|
|
1222
|
+
diagnostics["n_clusters_found"] = int(len(np.unique(labels)))
|
|
1223
|
+
elif method == "dbscan" and DBSCAN is not None:
|
|
1224
|
+
db = DBSCAN(eps=float(dbscan_eps), min_samples=int(dbscan_min_samples))
|
|
1225
|
+
labels = db.fit_predict(pts)
|
|
1226
|
+
uniq = [u for u in np.unique(labels) if u != -1]
|
|
1227
|
+
diagnostics["n_clusters_found"] = int(len(uniq))
|
|
1228
|
+
elif method == "gmm" and GaussianMixture is not None:
|
|
1229
|
+
gm = GaussianMixture(n_components=max(1, int(n_clusters)), random_state=random_state)
|
|
1230
|
+
labels = gm.fit_predict(pts)
|
|
1231
|
+
diagnostics["means"] = gm.means_
|
|
1232
|
+
diagnostics["covariances"] = getattr(gm, "covariances_", None)
|
|
1233
|
+
diagnostics["n_clusters_found"] = int(len(np.unique(labels)))
|
|
1234
|
+
else:
|
|
1235
|
+
# fallback: try DBSCAN then KMeans
|
|
1236
|
+
if DBSCAN is not None:
|
|
1237
|
+
db = DBSCAN(eps=float(dbscan_eps), min_samples=int(dbscan_min_samples))
|
|
1238
|
+
labels = db.fit_predict(pts)
|
|
1239
|
+
if (labels == -1).all() and KMeans is not None:
|
|
1240
|
+
km = KMeans(n_clusters=max(1, int(n_clusters)), random_state=random_state)
|
|
1241
|
+
labels = km.fit_predict(pts)
|
|
1242
|
+
diagnostics["fallback_to"] = "kmeans"
|
|
1243
|
+
diagnostics["centers"] = km.cluster_centers_
|
|
1244
|
+
diagnostics["n_clusters_found"] = int(len(np.unique(labels)))
|
|
1245
|
+
elif KMeans is not None:
|
|
1246
|
+
km = KMeans(n_clusters=max(1, int(n_clusters)), random_state=random_state)
|
|
1247
|
+
labels = km.fit_predict(pts)
|
|
1248
|
+
diagnostics["n_clusters_found"] = int(len(np.unique(labels)))
|
|
1249
|
+
else:
|
|
1250
|
+
diagnostics["skipped"] = True
|
|
1251
|
+
return np.full(len(x), -1, dtype=int), diagnostics
|
|
1252
|
+
|
|
1253
|
+
except Exception as e:
|
|
1254
|
+
diagnostics["error"] = str(e)
|
|
1255
|
+
diagnostics["skipped"] = True
|
|
1256
|
+
return np.full(len(x), -1, dtype=int), diagnostics
|
|
97
1257
|
|
|
98
|
-
|
|
1258
|
+
# remap non-noise labels to contiguous ints starting at 0 (keep -1 for noise)
|
|
1259
|
+
unique_nonnoise = sorted([u for u in np.unique(labels) if u != -1])
|
|
1260
|
+
if unique_nonnoise:
|
|
1261
|
+
mapping = {old: new for new, old in enumerate(unique_nonnoise)}
|
|
1262
|
+
remapped = np.full_like(labels, -1)
|
|
1263
|
+
for i, lab in enumerate(labels):
|
|
1264
|
+
if lab != -1:
|
|
1265
|
+
remapped[i] = mapping.get(lab, -1)
|
|
1266
|
+
labels = remapped
|
|
1267
|
+
diagnostics["n_clusters_found"] = int(len(unique_nonnoise))
|
|
1268
|
+
else:
|
|
1269
|
+
diagnostics["n_clusters_found"] = 0
|
|
99
1270
|
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
1271
|
+
# compute silhouette if suitable
|
|
1272
|
+
try:
|
|
1273
|
+
if diagnostics.get("n_clusters_found", 0) >= 2 and len(x) >= 3 and silhouette_score is not None:
|
|
1274
|
+
diagnostics["silhouette"] = float(silhouette_score(pts, labels))
|
|
1275
|
+
else:
|
|
1276
|
+
diagnostics["silhouette"] = None
|
|
1277
|
+
except Exception:
|
|
1278
|
+
diagnostics["silhouette"] = None
|
|
105
1279
|
|
|
106
|
-
|
|
1280
|
+
diagnostics["skipped"] = False
|
|
1281
|
+
return labels.astype(int), diagnostics
|
|
107
1282
|
|
|
108
|
-
uf = UnionFind(N)
|
|
109
|
-
for i, j in all_pairs:
|
|
110
|
-
uf.union(i, j)
|
|
111
1283
|
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
1284
|
+
def _overlay_clusters_on_ax(
    ax,
    x,
    y,
    labels,
    diagnostics,
    *,
    cmap="tab20",
    alpha_pts=0.6,
    marker="o",
    plot_centroids=True,
    centroid_marker="X",
    centroid_size=60,
    hull=True,
    hull_alpha=0.12,
    hull_edgecolor="k",
    show_cluster_labels=True,
    cluster_label_fontsize=8,
):
    """
    Color points by label, plot centroids and optional convex hulls.

    Labels == -1 are noise and drawn in grey. For every other label the
    centroid (mean of x and y of its points) can be marked, the label number
    annotated at the centroid, and a filled convex hull drawn.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    x, y : array-like
        Point coordinates, one per entry of ``labels``.
    labels : array-like of int
        Cluster label per point; -1 means noise.
    diagnostics : dict
        Clustering diagnostics; accepted for call-site symmetry, not used here.
    cmap : str
        Matplotlib colormap name sampled for cluster colors.
    alpha_pts, marker
        Point transparency and marker style.
    plot_centroids, centroid_marker, centroid_size
        Whether/how to mark each cluster's centroid.
    hull, hull_alpha, hull_edgecolor
        Whether/how to draw each cluster's convex hull (needs >= 3 points).
    show_cluster_labels, cluster_label_fontsize
        Whether/how to annotate the label number at each centroid. This works
        independently of ``plot_centroids``.

    Returns
    -------
    None
    """
    from scipy.spatial import ConvexHull

    x = np.asarray(x)
    y = np.asarray(y)
    labels = np.asarray(labels)
    pts = np.column_stack([x, y])

    # sort so noise (-1) comes last for drawing; clusters stay in ascending
    # order so label k receives the k-th sampled color
    unique = sorted(np.unique(labels).tolist(), key=lambda v: (v == -1, v))
    cmap_obj = plt.get_cmap(cmap)
    ncolors = max(8, len(unique))
    colors = [cmap_obj(i / float(ncolors)) for i in range(ncolors)]
    noise_color = (0.6, 0.6, 0.6, 0.6)

    for idx, lab in enumerate(unique):
        mask = labels == lab
        if not mask.any():
            continue
        col = noise_color if lab == -1 else colors[idx % ncolors]
        ax.scatter(x[mask], y[mask], s=20, c=[col], alpha=alpha_pts, marker=marker, linewidths=0.2, edgecolors="none", rasterized=True)

        if lab == -1:
            continue

        # Centroid is computed unconditionally so the label annotation does
        # not depend on whether the centroid marker is drawn.
        cx = float(np.mean(x[mask]))
        cy = float(np.mean(y[mask]))
        if plot_centroids:
            ax.scatter([cx], [cy], s=centroid_size, marker=centroid_marker, c=[col], edgecolor="k", linewidth=0.6, zorder=10)
        if show_cluster_labels:
            ax.text(cx, cy, str(int(lab)), color="white", fontsize=cluster_label_fontsize,
                    ha="center", va="center", weight="bold", zorder=12,
                    bbox=dict(facecolor=(0, 0, 0, 0.5), pad=0.3, boxstyle="round"))

        # hull
        if hull and np.sum(mask) >= 3:
            try:
                ch_pts = pts[mask]
                hull_poly = ch_pts[ConvexHull(ch_pts).vertices]
                ax.fill(hull_poly[:, 0], hull_poly[:, 1], alpha=hull_alpha, facecolor=col, edgecolor=hull_edgecolor, linewidth=0.6, zorder=5)
            except Exception:
                # Degenerate point sets (collinear / duplicated points) make
                # Qhull fail; the hull is decorative, so skip it for this cluster.
                pass
    return None
|
|
148
1351
|
|
|
149
|
-
return adata_unique, adata
|