smftools 0.1.7__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (174)
  1. smftools/__init__.py +7 -6
  2. smftools/_version.py +1 -1
  3. smftools/cli/cli_flows.py +94 -0
  4. smftools/cli/hmm_adata.py +338 -0
  5. smftools/cli/load_adata.py +577 -0
  6. smftools/cli/preprocess_adata.py +363 -0
  7. smftools/cli/spatial_adata.py +564 -0
  8. smftools/cli_entry.py +435 -0
  9. smftools/config/__init__.py +1 -0
  10. smftools/config/conversion.yaml +38 -0
  11. smftools/config/deaminase.yaml +61 -0
  12. smftools/config/default.yaml +264 -0
  13. smftools/config/direct.yaml +41 -0
  14. smftools/config/discover_input_files.py +115 -0
  15. smftools/config/experiment_config.py +1288 -0
  16. smftools/hmm/HMM.py +1576 -0
  17. smftools/hmm/__init__.py +20 -0
  18. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  19. smftools/hmm/call_hmm_peaks.py +106 -0
  20. smftools/{tools → hmm}/display_hmm.py +3 -3
  21. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  22. smftools/{tools → hmm}/train_hmm.py +1 -1
  23. smftools/informatics/__init__.py +13 -9
  24. smftools/informatics/archived/deaminase_smf.py +132 -0
  25. smftools/informatics/archived/fast5_to_pod5.py +43 -0
  26. smftools/informatics/archived/helpers/archived/__init__.py +71 -0
  27. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +126 -0
  28. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +87 -0
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +213 -0
  30. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +90 -0
  31. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +259 -0
  32. smftools/informatics/{helpers → archived/helpers/archived}/count_aligned_reads.py +2 -2
  33. smftools/informatics/{helpers → archived/helpers/archived}/demux_and_index_BAM.py +8 -10
  34. smftools/informatics/{helpers → archived/helpers/archived}/extract_base_identities.py +30 -4
  35. smftools/informatics/{helpers → archived/helpers/archived}/extract_mods.py +15 -13
  36. smftools/informatics/{helpers → archived/helpers/archived}/extract_read_features_from_bam.py +4 -2
  37. smftools/informatics/{helpers → archived/helpers/archived}/find_conversion_sites.py +5 -4
  38. smftools/informatics/{helpers → archived/helpers/archived}/generate_converted_FASTA.py +2 -0
  39. smftools/informatics/{helpers → archived/helpers/archived}/get_chromosome_lengths.py +9 -8
  40. smftools/informatics/archived/helpers/archived/index_fasta.py +24 -0
  41. smftools/informatics/{helpers → archived/helpers/archived}/make_modbed.py +1 -2
  42. smftools/informatics/{helpers → archived/helpers/archived}/modQC.py +2 -2
  43. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +250 -0
  44. smftools/informatics/{helpers → archived/helpers/archived}/separate_bam_by_bc.py +8 -7
  45. smftools/informatics/{helpers → archived/helpers/archived}/split_and_index_BAM.py +8 -12
  46. smftools/informatics/archived/subsample_fasta_from_bed.py +49 -0
  47. smftools/informatics/bam_functions.py +812 -0
  48. smftools/informatics/basecalling.py +67 -0
  49. smftools/informatics/bed_functions.py +366 -0
  50. smftools/informatics/binarize_converted_base_identities.py +172 -0
  51. smftools/informatics/{helpers/converted_BAM_to_adata_II.py → converted_BAM_to_adata.py} +198 -50
  52. smftools/informatics/fasta_functions.py +255 -0
  53. smftools/informatics/h5ad_functions.py +197 -0
  54. smftools/informatics/{helpers/modkit_extract_to_adata.py → modkit_extract_to_adata.py} +147 -61
  55. smftools/informatics/modkit_functions.py +129 -0
  56. smftools/informatics/ohe.py +160 -0
  57. smftools/informatics/pod5_functions.py +224 -0
  58. smftools/informatics/{helpers/run_multiqc.py → run_multiqc.py} +5 -2
  59. smftools/machine_learning/__init__.py +12 -0
  60. smftools/machine_learning/data/__init__.py +2 -0
  61. smftools/machine_learning/data/anndata_data_module.py +234 -0
  62. smftools/machine_learning/evaluation/__init__.py +2 -0
  63. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  64. smftools/machine_learning/evaluation/evaluators.py +223 -0
  65. smftools/machine_learning/inference/__init__.py +3 -0
  66. smftools/machine_learning/inference/inference_utils.py +27 -0
  67. smftools/machine_learning/inference/lightning_inference.py +68 -0
  68. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  69. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  70. smftools/machine_learning/models/base.py +295 -0
  71. smftools/machine_learning/models/cnn.py +138 -0
  72. smftools/machine_learning/models/lightning_base.py +345 -0
  73. smftools/machine_learning/models/mlp.py +26 -0
  74. smftools/{tools → machine_learning}/models/positional.py +3 -2
  75. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  76. smftools/machine_learning/models/sklearn_models.py +273 -0
  77. smftools/machine_learning/models/transformer.py +303 -0
  78. smftools/machine_learning/training/__init__.py +2 -0
  79. smftools/machine_learning/training/train_lightning_model.py +135 -0
  80. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  81. smftools/plotting/__init__.py +4 -1
  82. smftools/plotting/autocorrelation_plotting.py +609 -0
  83. smftools/plotting/general_plotting.py +1292 -140
  84. smftools/plotting/hmm_plotting.py +260 -0
  85. smftools/plotting/qc_plotting.py +270 -0
  86. smftools/preprocessing/__init__.py +15 -8
  87. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  88. smftools/preprocessing/append_base_context.py +122 -0
  89. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  90. smftools/preprocessing/binarize.py +17 -0
  91. smftools/preprocessing/binarize_on_Youden.py +2 -2
  92. smftools/preprocessing/calculate_complexity_II.py +248 -0
  93. smftools/preprocessing/calculate_coverage.py +10 -1
  94. smftools/preprocessing/calculate_position_Youden.py +1 -1
  95. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  96. smftools/preprocessing/clean_NaN.py +17 -1
  97. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  98. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  99. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  100. smftools/preprocessing/invert_adata.py +12 -5
  101. smftools/preprocessing/load_sample_sheet.py +19 -4
  102. smftools/readwrite.py +1021 -89
  103. smftools/tools/__init__.py +3 -32
  104. smftools/tools/calculate_umap.py +5 -5
  105. smftools/tools/general_tools.py +3 -3
  106. smftools/tools/position_stats.py +468 -106
  107. smftools/tools/read_stats.py +115 -1
  108. smftools/tools/spatial_autocorrelation.py +562 -0
  109. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/METADATA +14 -9
  110. smftools-0.2.3.dist-info/RECORD +173 -0
  111. smftools-0.2.3.dist-info/entry_points.txt +2 -0
  112. smftools/informatics/fast5_to_pod5.py +0 -21
  113. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  114. smftools/informatics/helpers/__init__.py +0 -74
  115. smftools/informatics/helpers/align_and_sort_BAM.py +0 -59
  116. smftools/informatics/helpers/aligned_BAM_to_bed.py +0 -74
  117. smftools/informatics/helpers/bam_qc.py +0 -66
  118. smftools/informatics/helpers/bed_to_bigwig.py +0 -39
  119. smftools/informatics/helpers/binarize_converted_base_identities.py +0 -79
  120. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +0 -55
  121. smftools/informatics/helpers/index_fasta.py +0 -12
  122. smftools/informatics/helpers/make_dirs.py +0 -21
  123. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  124. smftools/informatics/load_adata.py +0 -182
  125. smftools/informatics/readwrite.py +0 -106
  126. smftools/informatics/subsample_fasta_from_bed.py +0 -47
  127. smftools/preprocessing/append_C_context.py +0 -82
  128. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  129. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  130. smftools/preprocessing/filter_reads_on_length.py +0 -51
  131. smftools/tools/call_hmm_peaks.py +0 -105
  132. smftools/tools/data/__init__.py +0 -2
  133. smftools/tools/data/anndata_data_module.py +0 -90
  134. smftools/tools/inference/__init__.py +0 -1
  135. smftools/tools/inference/lightning_inference.py +0 -41
  136. smftools/tools/models/base.py +0 -14
  137. smftools/tools/models/cnn.py +0 -34
  138. smftools/tools/models/lightning_base.py +0 -41
  139. smftools/tools/models/mlp.py +0 -17
  140. smftools/tools/models/sklearn_models.py +0 -40
  141. smftools/tools/models/transformer.py +0 -133
  142. smftools/tools/training/__init__.py +0 -1
  143. smftools/tools/training/train_lightning_model.py +0 -47
  144. smftools-0.1.7.dist-info/RECORD +0 -136
  145. /smftools/{tools/evaluation → cli}/__init__.py +0 -0
  146. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  147. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  148. /smftools/informatics/{basecall_pod5s.py → archived/basecall_pod5s.py} +0 -0
  149. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  150. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  151. /smftools/informatics/{helpers → archived/helpers/archived}/canoncall.py +0 -0
  152. /smftools/informatics/{helpers → archived/helpers/archived}/converted_BAM_to_adata.py +0 -0
  153. /smftools/informatics/{helpers → archived/helpers/archived}/extract_read_lengths_from_bed.py +0 -0
  154. /smftools/informatics/{helpers → archived/helpers/archived}/extract_readnames_from_BAM.py +0 -0
  155. /smftools/informatics/{helpers → archived/helpers/archived}/get_native_references.py +0 -0
  156. /smftools/informatics/{helpers → archived/helpers}/archived/informatics.py +0 -0
  157. /smftools/informatics/{helpers → archived/helpers}/archived/load_adata.py +0 -0
  158. /smftools/informatics/{helpers → archived/helpers/archived}/modcall.py +0 -0
  159. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_batching.py +0 -0
  160. /smftools/informatics/{helpers → archived/helpers/archived}/ohe_layers_decode.py +0 -0
  161. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_decode.py +0 -0
  162. /smftools/informatics/{helpers → archived/helpers/archived}/one_hot_encode.py +0 -0
  163. /smftools/informatics/{subsample_pod5.py → archived/subsample_pod5.py} +0 -0
  164. /smftools/informatics/{helpers/complement_base_list.py → complement_base_list.py} +0 -0
  165. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  166. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  167. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  168. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  169. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  170. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  171. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  172. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  173. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/WHEEL +0 -0
  174. {smftools-0.1.7.dist-info → smftools-0.2.3.dist-info}/licenses/LICENSE +0 -0
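Much of this release is a reorganization: the informatics helpers move under an archived/ tree, the HMM code moves from smftools/tools to smftools/hmm, and the model/training code moves from smftools/tools to smftools/machine_learning. A minimal sketch of what that would presumably mean for downstream imports, assuming the moved modules keep their module-level names (paths taken from the rename list above, not verified against the installed wheel):

# 0.1.7 module paths (hypothetical usage)
import smftools.tools.apply_hmm_batched
import smftools.tools.train_hmm
import smftools.tools.models.rnn

# 0.2.3 module paths after the moves listed above
import smftools.hmm.apply_hmm_batched
import smftools.hmm.train_hmm
import smftools.machine_learning.models.rnn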
smftools/preprocessing/flag_duplicate_reads.py
@@ -1,149 +1,1351 @@
1
+ # duplicate_detection_with_hier_and_plots.py
2
+ import copy
3
+ import warnings
4
+ import math
5
+ import os
6
+ from collections import defaultdict
7
+ from typing import Dict, Any, Tuple, Union, List, Optional
8
+
1
9
  import torch
10
+ import anndata as ad
11
+ import numpy as np
12
+ import pandas as pd
13
+ import matplotlib.pyplot as plt
2
14
  from tqdm import tqdm
3
15
 
4
- class UnionFind:
5
- def __init__(self, size):
6
- self.parent = torch.arange(size)
16
+ from ..readwrite import make_dirs
17
+
18
+ # optional imports for clustering / PCA / KDE
19
+ try:
20
+ from scipy.cluster import hierarchy as sch
21
+ from scipy.spatial.distance import pdist, squareform
22
+ SCIPY_AVAILABLE = True
23
+ except Exception:
24
+ sch = None
25
+ pdist = None
26
+ squareform = None
27
+ SCIPY_AVAILABLE = False
28
+
29
+ try:
30
+ from sklearn.decomposition import PCA
31
+ from sklearn.cluster import KMeans, DBSCAN
32
+ from sklearn.mixture import GaussianMixture
33
+ from sklearn.metrics import silhouette_score
34
+ SKLEARN_AVAILABLE = True
35
+ except Exception:
36
+ PCA = None
37
+ KMeans = DBSCAN = GaussianMixture = silhouette_score = None
38
+ SKLEARN_AVAILABLE = False
39
+
40
+ try:
41
+ from scipy.stats import gaussian_kde
42
+ except Exception:
43
+ gaussian_kde = None
44
+
45
+
46
+ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
47
+ """
48
+ Merge two .uns dicts. prefer='orig' will keep orig_uns values on conflict,
49
+ prefer='new' will keep new_uns values on conflict. Conflicts are reported.
50
+ """
51
+ out = copy.deepcopy(new_uns) if new_uns is not None else {}
52
+ for k, v in (orig_uns or {}).items():
53
+ if k not in out:
54
+ out[k] = copy.deepcopy(v)
55
+ else:
56
+ # present in both: compare quickly (best-effort)
57
+ try:
58
+ equal = (out[k] == v)
59
+ except Exception:
60
+ equal = False
61
+ if equal:
62
+ continue
63
+ # conflict
64
+ warnings.warn(f".uns conflict for key '{k}'; keeping '{prefer}' value.")
65
+ if prefer == "orig":
66
+ out[k] = copy.deepcopy(v)
67
+ else:
68
+ # keep out[k] (the new one) and also stash orig under a suffix
69
+ out[f"orig_uns__{k}"] = copy.deepcopy(v)
70
+ return out
71
+
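A minimal usage sketch of merge_uns_preserve with hypothetical dictionaries (not part of the diff):

old_uns = {"pipeline": "v1", "threshold": 0.05}
new_uns = {"pipeline": "v2", "n_reads": 1200}

merged = merge_uns_preserve(old_uns, new_uns, prefer="orig")
# merged == {"pipeline": "v1", "n_reads": 1200, "threshold": 0.05}
# the conflicting "pipeline" key triggers a warning; with prefer="new" the value "v2"
# is kept and the original is stashed under "orig_uns__pipeline"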
72
+ def flag_duplicate_reads(
73
+ adata,
74
+ var_filters_sets,
75
+ distance_threshold: float = 0.07,
76
+ obs_reference_col: str = "Reference_strand",
77
+ sample_col: str = "Barcode",
78
+ output_directory: Optional[str] = None,
79
+ metric_keys: Union[str, List[str]] = ("Fraction_any_C_site_modified",),
80
+ uns_flag: str = "read_duplicate_detection_performed",
81
+ uns_filtered_flag: str = "read_duplicates_removed",
82
+ bypass: bool = False,
83
+ force_redo: bool = False,
84
+ keep_best_metric: Optional[str] = 'read_quality',
85
+ keep_best_higher: bool = True,
86
+ window_size: int = 50,
87
+ min_overlap_positions: int = 20,
88
+ do_pca: bool = False,
89
+ pca_n_components: int = 50,
90
+ pca_center: bool = True,
91
+ do_hierarchical: bool = True,
92
+ hierarchical_linkage: str = "average",
93
+ hierarchical_metric: str = "euclidean",
94
+ hierarchical_window: int = 50,
95
+ random_state: int = 0,
96
+ ):
97
+ """
98
+ Duplicate-flagging pipeline in which the hierarchical stage operates only on representatives
99
+ (one representative per lex cluster, i.e. the keeper). Final keeper assignment and
100
+ enforcement happen only after hierarchical merging.
101
+
102
+ Returns (adata_unique, adata_full) and writes the sequence__* bookkeeping columns into adata.obs.
103
+ """
104
+ # early exits
105
+ already = bool(adata.uns.get(uns_flag, False))
106
+ if (already and not force_redo):
107
+ if "is_duplicate" in adata.obs.columns:
108
+ adata_unique = adata[adata.obs["is_duplicate"] == False].copy()
109
+ return adata_unique, adata
110
+ else:
111
+ return adata.copy(), adata.copy()
112
+ if bypass:
113
+ return None, adata
114
+
115
+ if isinstance(metric_keys, str):
116
+ metric_keys = [metric_keys]
117
+
118
+ # local UnionFind
119
+ class UnionFind:
120
+ def __init__(self, size):
121
+ self.parent = list(range(size))
122
+
123
+ def find(self, x):
124
+ while self.parent[x] != x:
125
+ self.parent[x] = self.parent[self.parent[x]]
126
+ x = self.parent[x]
127
+ return x
128
+
129
+ def union(self, x, y):
130
+ rx = self.find(x); ry = self.find(y)
131
+ if rx != ry:
132
+ self.parent[ry] = rx
133
+
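For orientation, the nested union-find above turns pairwise near-duplicate hits into connected clusters; a small illustrative sketch with toy indices (not part of the diff):

uf = UnionFind(5)
for i, j in [(0, 1), (1, 2), (3, 4)]:   # near-duplicate pairs found by the passes below
    uf.union(i, j)

roots = [uf.find(i) for i in range(5)]
# roots == [0, 0, 0, 3, 3]: reads 0-2 collapse into one cluster, reads 3-4 into another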
134
+ adata_processed_list = []
135
+ histograms = []
136
+
137
+ samples = adata.obs[sample_col].astype("category").cat.categories
138
+ references = adata.obs[obs_reference_col].astype("category").cat.categories
139
+
140
+ for sample in samples:
141
+ for ref in references:
142
+ print(f"Processing sample={sample} ref={ref}")
143
+ sample_mask = adata.obs[sample_col] == sample
144
+ ref_mask = adata.obs[obs_reference_col] == ref
145
+ subset_mask = sample_mask & ref_mask
146
+ adata_subset = adata[subset_mask].copy()
147
+
148
+ if adata_subset.n_obs < 2:
149
+ print(f" Skipping {sample}_{ref} (too few reads)")
150
+ continue
151
+
152
+ N = adata_subset.shape[0]
153
+
154
+ # Build mask of columns (vars) to use
155
+ combined_mask = np.zeros(len(adata.var), dtype=bool)
156
+ for var_set in var_filters_sets:
157
+ if any(str(ref) in str(v) for v in var_set):
158
+ per_col_mask = np.ones(len(adata.var), dtype=bool)
159
+ for key in var_set:
160
+ per_col_mask &= np.asarray(adata.var[key].values, dtype=bool)
161
+ combined_mask |= per_col_mask
162
+
163
+ selected_cols = adata.var.index[combined_mask.tolist()].to_list()
164
+ col_indices = [adata.var.index.get_loc(c) for c in selected_cols]
165
+ print(f" Selected {len(col_indices)} columns out of {adata.var.shape[0]} for {ref}")
166
+
167
+ # Extract data matrix (dense numpy) for the subset
168
+ X = adata_subset.X
169
+ if not isinstance(X, np.ndarray):
170
+ try:
171
+ X = X.toarray()
172
+ except Exception:
173
+ X = np.asarray(X)
174
+ X_sub = X[:, col_indices].astype(float) # keep NaNs
175
+
176
+ # convert to torch for some vector ops
177
+ X_tensor = torch.from_numpy(X_sub.copy())
178
+
179
+ # per-read nearest distances recorded
180
+ fwd_hamming_to_next = np.full((N,), np.nan, dtype=float)
181
+ rev_hamming_to_prev = np.full((N,), np.nan, dtype=float)
182
+ hierarchical_min_pair = np.full((N,), np.nan, dtype=float)
183
+
184
+ # legacy local lexicographic pairwise hamming distances (for histogram)
185
+ local_hamming_dists = []
186
+ # hierarchical discovered dists (for histogram)
187
+ hierarchical_found_dists = []
188
+
189
+ # Lexicographic windowed pass function
190
+ def cluster_pass(X_tensor_local, reverse=False, window=int(window_size), record_distances=False):
191
+ N_local = X_tensor_local.shape[0]
192
+ X_sortable = X_tensor_local.clone().nan_to_num(-1.0)
193
+ sort_keys = [tuple(row.numpy().tolist()) for row in X_sortable]
194
+ sorted_idx = sorted(range(N_local), key=lambda i: sort_keys[i], reverse=reverse)
195
+ sorted_X = X_tensor_local[sorted_idx]
196
+ cluster_pairs_local = []
197
+
198
+ for i in range(len(sorted_X)):
199
+ row_i = sorted_X[i]
200
+ j_range_local = range(i + 1, min(i + 1 + window, len(sorted_X)))
201
+ if len(j_range_local) == 0:
202
+ continue
203
+ block_rows = sorted_X[list(j_range_local)]
204
+ row_i_exp = row_i.unsqueeze(0) # (1, D)
205
+ valid_mask = (~torch.isnan(row_i_exp)) & (~torch.isnan(block_rows)) # (M, D)
206
+ valid_counts = valid_mask.sum(dim=1).float()
207
+ enough_overlap = valid_counts >= float(min_overlap_positions)
208
+ if enough_overlap.any():
209
+ diffs = (row_i_exp != block_rows) & valid_mask
210
+ hamming_counts = diffs.sum(dim=1).float()
211
+ hamming_dists = torch.where(valid_counts > 0, hamming_counts / valid_counts, torch.tensor(float("nan")))
212
+ # record distances (legacy list of all local comparisons)
213
+ hamming_np = hamming_dists.cpu().numpy().tolist()
214
+ local_hamming_dists.extend([float(x) for x in hamming_np if (not np.isnan(x))])
215
+ matches = (hamming_dists < distance_threshold) & (enough_overlap)
216
+ for offset_local, m in enumerate(matches):
217
+ if m:
218
+ i_global = sorted_idx[i]
219
+ j_global = sorted_idx[i + 1 + offset_local]
220
+ cluster_pairs_local.append((i_global, j_global))
221
+ if record_distances:
222
+ # record next neighbor distance for the item (global index)
223
+ next_local_idx = i + 1
224
+ if next_local_idx < len(sorted_X):
225
+ next_global = sorted_idx[next_local_idx]
226
+ vm_pair = (~torch.isnan(row_i)) & (~torch.isnan(sorted_X[next_local_idx]))
227
+ vc = vm_pair.sum().item()
228
+ if vc >= min_overlap_positions:
229
+ d = float(((row_i[vm_pair] != sorted_X[next_local_idx][vm_pair]).sum().item()) / vc)
230
+ if reverse:
231
+ rev_hamming_to_prev[next_global] = d
232
+ else:
233
+ fwd_hamming_to_next[sorted_idx[i]] = d
234
+ return cluster_pairs_local
235
+
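The windowed comparison above boils down to a NaN-aware normalized Hamming distance: only positions observed in both reads count, the distance is the number of mismatches divided by the overlap, and pairs with too little overlap are ignored. A standalone sketch of that metric on toy binary reads (illustrative only):

import numpy as np

def nan_hamming(a, b, min_overlap=3):
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    valid = ~np.isnan(a) & ~np.isnan(b)              # positions observed in both reads
    if valid.sum() < min_overlap:
        return np.nan                                # not enough overlap to call a duplicate
    return float((a[valid] != b[valid]).sum()) / float(valid.sum())

r1 = [1, 0, 0, 1, np.nan, 1]
r2 = [1, 0, 1, 1, 0, np.nan]
nan_hamming(r1, r2)   # 4 shared positions, 1 mismatch -> 0.25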
236
+ # run forward pass
237
+ pairs_fwd = cluster_pass(X_tensor, reverse=False, record_distances=True)
238
+ involved_in_fwd = set([p for pair in pairs_fwd for p in pair])
239
+ # build mask for reverse pass to avoid re-checking items already paired
240
+ mask_for_rev = np.ones(N, dtype=bool)
241
+ if len(involved_in_fwd) > 0:
242
+ for idx in involved_in_fwd:
243
+ mask_for_rev[idx] = False
244
+ rev_idx_map = np.nonzero(mask_for_rev)[0].tolist()
245
+ if len(rev_idx_map) > 0:
246
+ reduced_tensor = X_tensor[rev_idx_map]
247
+ pairs_rev_local = cluster_pass(reduced_tensor, reverse=True, record_distances=True)
248
+ # remap local reduced indices to global
249
+ remapped_rev_pairs = [(int(rev_idx_map[i]), int(rev_idx_map[j])) for (i, j) in pairs_rev_local]
250
+ else:
251
+ remapped_rev_pairs = []
252
+
253
+ all_pairs = pairs_fwd + remapped_rev_pairs
254
+
255
+ # initial union-find based on lex pairs
256
+ uf = UnionFind(N)
257
+ for i, j in all_pairs:
258
+ uf.union(i, j)
259
+
260
+ # initial merged clusters (lex-level)
261
+ merged_cluster = np.zeros((N,), dtype=int)
262
+ for i in range(N):
263
+ merged_cluster[i] = uf.find(i)
264
+ unique_initial = np.unique(merged_cluster)
265
+ id_map = {old: new for new, old in enumerate(sorted(unique_initial.tolist()))}
266
+ merged_cluster_mapped = np.array([id_map[int(x)] for x in merged_cluster], dtype=int)
267
+
268
+ # cluster sizes and choose lex-keeper per lex-cluster (representatives)
269
+ cluster_sizes = np.zeros_like(merged_cluster_mapped)
270
+ cluster_counts = []
271
+ unique_clusters = np.unique(merged_cluster_mapped)
272
+ keeper_for_cluster = {}
273
+ for cid in unique_clusters:
274
+ members = np.where(merged_cluster_mapped == cid)[0].tolist()
275
+ csize = int(len(members))
276
+ cluster_counts.append(csize)
277
+ cluster_sizes[members] = csize
278
+ # pick lex keeper (representative)
279
+ if len(members) == 1:
280
+ keeper_for_cluster[cid] = members[0]
281
+ else:
282
+ if keep_best_metric is None:
283
+ keeper_for_cluster[cid] = members[0]
284
+ else:
285
+ obs_index = list(adata_subset.obs.index)
286
+ member_names = [obs_index[m] for m in members]
287
+ try:
288
+ vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
289
+ except Exception:
290
+ vals = np.array([np.nan] * len(members), dtype=float)
291
+ if np.all(np.isnan(vals)):
292
+ keeper_for_cluster[cid] = members[0]
293
+ else:
294
+ if keep_best_higher:
295
+ nan_mask = np.isnan(vals)
296
+ vals[nan_mask] = -np.inf
297
+ rel_idx = int(np.nanargmax(vals))
298
+ else:
299
+ nan_mask = np.isnan(vals)
300
+ vals[nan_mask] = np.inf
301
+ rel_idx = int(np.nanargmin(vals))
302
+ keeper_for_cluster[cid] = members[rel_idx]
303
+
304
+ # expose lex keeper info (record only; do not enforce deletion yet)
305
+ lex_is_keeper = np.zeros((N,), dtype=bool)
306
+ lex_is_duplicate = np.zeros((N,), dtype=bool)
307
+ for cid, members in zip(unique_clusters, [np.where(merged_cluster_mapped == cid)[0].tolist() for cid in unique_clusters]):
308
+ keeper_idx = keeper_for_cluster[cid]
309
+ lex_is_keeper[keeper_idx] = True
310
+ for m in members:
311
+ if m != keeper_idx:
312
+ lex_is_duplicate[m] = True
313
+ # note: these are just recorded for inspection / later preference
314
+ # and will be written to adata_subset.obs below
315
+ # record lex min pair (min of fwd/rev neighbor) for each read
316
+ min_pair = np.full((N,), np.nan, dtype=float)
317
+ for i in range(N):
318
+ a = fwd_hamming_to_next[i]
319
+ b = rev_hamming_to_prev[i]
320
+ vals = []
321
+ if not np.isnan(a):
322
+ vals.append(a)
323
+ if not np.isnan(b):
324
+ vals.append(b)
325
+ if vals:
326
+ min_pair[i] = float(np.nanmin(vals))
327
+
328
+ # --- hierarchical on representatives only ---
329
+ hierarchical_pairs = [] # (rep_global_i, rep_global_j, d)
330
+ rep_global_indices = sorted(set(keeper_for_cluster.values()))
331
+ if do_hierarchical and len(rep_global_indices) > 1:
332
+ if not SKLEARN_AVAILABLE:
333
+ warnings.warn("sklearn not available; skipping PCA/hierarchical pass.")
334
+ elif not SCIPY_AVAILABLE:
335
+ warnings.warn("scipy not available; skipping hierarchical pass.")
336
+ else:
337
+ # build reps array and impute for PCA
338
+ reps_X = X_sub[rep_global_indices, :]
339
+ reps_arr = np.array(reps_X, dtype=float, copy=True)
340
+ col_means = np.nanmean(reps_arr, axis=0)
341
+ col_means = np.where(np.isnan(col_means), 0.0, col_means)
342
+ inds = np.where(np.isnan(reps_arr))
343
+ if inds[0].size > 0:
344
+ reps_arr[inds] = np.take(col_means, inds[1])
345
+
346
+ # PCA if requested
347
+ if do_pca and PCA is not None:
348
+ n_comp = min(int(pca_n_components), reps_arr.shape[1], reps_arr.shape[0])
349
+ if n_comp <= 0:
350
+ reps_for_clustering = reps_arr
351
+ else:
352
+ pca = PCA(n_components=n_comp, random_state=int(random_state), svd_solver="auto", copy=True)
353
+ reps_for_clustering = pca.fit_transform(reps_arr)
354
+ else:
355
+ reps_for_clustering = reps_arr
356
+
357
+ # linkage & leaves (ordering)
358
+ try:
359
+ pdist_vec = pdist(reps_for_clustering, metric=hierarchical_metric)
360
+ Z = sch.linkage(pdist_vec, method=hierarchical_linkage)
361
+ leaves = sch.leaves_list(Z)
362
+ except Exception as e:
363
+ warnings.warn(f"hierarchical pass failed: {e}; skipping hierarchical stage.")
364
+ leaves = np.arange(len(rep_global_indices), dtype=int)
365
+
366
+ # apply windowed hamming comparisons across ordered reps and union via same UF (so clusters of all reads merge)
367
+ order_global_reps = [rep_global_indices[i] for i in leaves]
368
+ n_reps = len(order_global_reps)
369
+ for pos in range(n_reps):
370
+ i_global = order_global_reps[pos]
371
+ for jpos in range(pos + 1, min(pos + 1 + hierarchical_window, n_reps)):
372
+ j_global = order_global_reps[jpos]
373
+ vi = X_sub[int(i_global), :]
374
+ vj = X_sub[int(j_global), :]
375
+ valid_mask = (~np.isnan(vi)) & (~np.isnan(vj))
376
+ overlap = int(valid_mask.sum())
377
+ if overlap < min_overlap_positions:
378
+ continue
379
+ diffs = (vi[valid_mask] != vj[valid_mask]).sum()
380
+ d = float(diffs) / float(overlap)
381
+ if d < distance_threshold:
382
+ uf.union(int(i_global), int(j_global))
383
+ hierarchical_pairs.append((int(i_global), int(j_global), float(d)))
384
+ hierarchical_found_dists.append(float(d))
385
+
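The hierarchical stage uses the dendrogram only to order the representatives and then reuses the same windowed Hamming comparison along that order. A minimal sketch of the scipy calls involved, on a hypothetical already-imputed representatives matrix:

import numpy as np
from scipy.spatial.distance import pdist
from scipy.cluster import hierarchy as sch

rng = np.random.default_rng(0)
reps = rng.random((8, 20))                       # hypothetical (n_reps, n_positions) matrix, NaNs imputed
Z = sch.linkage(pdist(reps, metric="euclidean"), method="average")
order = sch.leaves_list(Z)                       # leaf order places similar representatives next to each other
# a small window of pairwise comparisons along `order` is then enough to catch
# duplicates that the lexicographic passes missed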
386
+ # after hierarchical unions, reconstruct merged clusters for all reads
387
+ merged_cluster_after = np.zeros((N,), dtype=int)
388
+ for i in range(N):
389
+ merged_cluster_after[i] = uf.find(i)
390
+ unique_final = np.unique(merged_cluster_after)
391
+ id_map_final = {old: new for new, old in enumerate(sorted(unique_final.tolist()))}
392
+ merged_cluster_mapped_final = np.array([id_map_final[int(x)] for x in merged_cluster_after], dtype=int)
393
+
394
+ # compute final cluster members and choose final keeper per final cluster
395
+ cluster_sizes_final = np.zeros_like(merged_cluster_mapped_final)
396
+ final_cluster_counts = []
397
+ final_unique = np.unique(merged_cluster_mapped_final)
398
+ final_keeper_for_cluster = {}
399
+ cluster_members_map = {}
400
+ for cid in final_unique:
401
+ members = np.where(merged_cluster_mapped_final == cid)[0].tolist()
402
+ cluster_members_map[cid] = members
403
+ csize = len(members)
404
+ final_cluster_counts.append(csize)
405
+ cluster_sizes_final[members] = csize
406
+ if csize == 1:
407
+ final_keeper_for_cluster[cid] = members[0]
408
+ else:
409
+ # prefer keep_best_metric when it is available; lex keepers are not automatically preferred here
410
+ # (to prefer lex keepers instead, select lex_is_keeper members before falling back to keep_best_metric)
411
+ obs_index = list(adata_subset.obs.index)
412
+ member_names = [obs_index[m] for m in members]
413
+ if keep_best_metric is not None and keep_best_metric in adata_subset.obs.columns:
414
+ try:
415
+ vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
416
+ except Exception:
417
+ vals = np.array([np.nan] * len(members), dtype=float)
418
+ if np.all(np.isnan(vals)):
419
+ final_keeper_for_cluster[cid] = members[0]
420
+ else:
421
+ if keep_best_higher:
422
+ nan_mask = np.isnan(vals)
423
+ vals[nan_mask] = -np.inf
424
+ rel_idx = int(np.nanargmax(vals))
425
+ else:
426
+ nan_mask = np.isnan(vals)
427
+ vals[nan_mask] = np.inf
428
+ rel_idx = int(np.nanargmin(vals))
429
+ final_keeper_for_cluster[cid] = members[rel_idx]
430
+ else:
431
+ # if lex keepers present among members, prefer them
432
+ lex_members = [m for m in members if lex_is_keeper[m]]
433
+ if len(lex_members) > 0:
434
+ final_keeper_for_cluster[cid] = lex_members[0]
435
+ else:
436
+ final_keeper_for_cluster[cid] = members[0]
437
+
438
+ # update sequence__is_duplicate based on final clusters: non-keepers in multi-member clusters are duplicates
439
+ sequence_is_duplicate = np.zeros((N,), dtype=bool)
440
+ for cid in final_unique:
441
+ keeper = final_keeper_for_cluster[cid]
442
+ members = cluster_members_map[cid]
443
+ if len(members) > 1:
444
+ for m in members:
445
+ if m != keeper:
446
+ sequence_is_duplicate[m] = True
447
+
448
+ # propagate hierarchical distances into hierarchical_min_pair for all cluster members
449
+ for (i_g, j_g, d) in hierarchical_pairs:
450
+ # identify their final cluster ids (after unions)
451
+ c_i = merged_cluster_mapped_final[int(i_g)]
452
+ c_j = merged_cluster_mapped_final[int(j_g)]
453
+ members_i = cluster_members_map.get(c_i, [int(i_g)])
454
+ members_j = cluster_members_map.get(c_j, [int(j_g)])
455
+ for mi in members_i:
456
+ if np.isnan(hierarchical_min_pair[mi]) or (d < hierarchical_min_pair[mi]):
457
+ hierarchical_min_pair[mi] = d
458
+ for mj in members_j:
459
+ if np.isnan(hierarchical_min_pair[mj]) or (d < hierarchical_min_pair[mj]):
460
+ hierarchical_min_pair[mj] = d
461
+
462
+ # combine lex-phase min_pair and hierarchical_min_pair into the final sequence__min_hamming_to_pair
463
+ combined_min = min_pair.copy()
464
+ for i in range(N):
465
+ hval = hierarchical_min_pair[i]
466
+ if not np.isnan(hval):
467
+ if np.isnan(combined_min[i]) or (hval < combined_min[i]):
468
+ combined_min[i] = hval
469
+
470
+ # write columns back into adata_subset.obs
471
+ adata_subset.obs["sequence__is_duplicate"] = sequence_is_duplicate
472
+ adata_subset.obs["sequence__merged_cluster_id"] = merged_cluster_mapped_final
473
+ adata_subset.obs["sequence__cluster_size"] = cluster_sizes_final
474
+ adata_subset.obs["fwd_hamming_to_next"] = fwd_hamming_to_next
475
+ adata_subset.obs["rev_hamming_to_prev"] = rev_hamming_to_prev
476
+ adata_subset.obs["sequence__hier_hamming_to_pair"] = hierarchical_min_pair
477
+ adata_subset.obs["sequence__min_hamming_to_pair"] = combined_min
478
+ # persist lex bookkeeping columns (informational)
479
+ adata_subset.obs["sequence__lex_is_keeper"] = lex_is_keeper
480
+ adata_subset.obs["sequence__lex_is_duplicate"] = lex_is_duplicate
481
+
482
+ adata_processed_list.append(adata_subset)
483
+
484
+ histograms.append({
485
+ "sample": sample,
486
+ "reference": ref,
487
+ "distances": local_hamming_dists, # lex local comparisons
488
+ "cluster_counts": final_cluster_counts,
489
+ "hierarchical_pairs": hierarchical_found_dists,
490
+ })
491
+
492
+ # Merge annotated subsets back together BEFORE plotting so plotting sees fwd_hamming_to_next, etc.
493
+ _original_uns = copy.deepcopy(adata.uns)
494
+ if len(adata_processed_list) == 0:
495
+ return adata.copy(), adata.copy()
496
+
497
+ adata_full = ad.concat(adata_processed_list, merge="same", join="outer", index_unique=None)
498
+ adata_full.uns = merge_uns_preserve(_original_uns, adata_full.uns, prefer="orig")
499
+
500
+ # Ensure expected numeric columns exist (create if missing)
501
+ for col in ("fwd_hamming_to_next", "rev_hamming_to_prev", "sequence__min_hamming_to_pair", "sequence__hier_hamming_to_pair"):
502
+ if col not in adata_full.obs.columns:
503
+ adata_full.obs[col] = np.nan
504
+
505
+ # histograms (now driven by adata_full if requested)
506
+ hist_outs = os.path.join(output_directory, "read_pair_hamming_distance_histograms")
507
+ make_dirs([hist_outs])
508
+ plot_histogram_pages(histograms,
509
+ distance_threshold=distance_threshold,
510
+ adata=adata_full,
511
+ output_directory=hist_outs,
512
+ distance_types=["min","fwd","rev","hier","lex_local"],
513
+ sample_key=sample_col,
514
+ )
515
+
516
+ # hamming vs metric scatter
517
+ scatter_outs = os.path.join(output_directory, "read_pair_hamming_distance_scatter_plots")
518
+ make_dirs([scatter_outs])
519
+ plot_hamming_vs_metric_pages(adata_full,
520
+ metric_keys=metric_keys,
521
+ output_dir=scatter_outs,
522
+ hamming_col="sequence__min_hamming_to_pair",
523
+ highlight_threshold=distance_threshold,
524
+ highlight_color="red",
525
+ sample_col=sample_col)
526
+
527
+ # boolean columns from neighbor distances
528
+ fwd_vals = pd.to_numeric(adata_full.obs.get("fwd_hamming_to_next", pd.Series(np.nan, index=adata_full.obs.index)), errors="coerce")
529
+ rev_vals = pd.to_numeric(adata_full.obs.get("rev_hamming_to_prev", pd.Series(np.nan, index=adata_full.obs.index)), errors="coerce")
530
+ is_dup_dist = (fwd_vals < float(distance_threshold)) | (rev_vals < float(distance_threshold))
531
+ is_dup_dist = is_dup_dist.fillna(False).astype(bool)
532
+ adata_full.obs["is_duplicate_distance"] = is_dup_dist.values
533
+
534
+ # combine sequence-derived flag with others
535
+ if "sequence__is_duplicate" in adata_full.obs.columns:
536
+ seq_dup = adata_full.obs["sequence__is_duplicate"].astype(bool)
537
+ else:
538
+ seq_dup = pd.Series(False, index=adata_full.obs.index)
539
+
540
+ # cluster-based duplicate indicator (if any clustering columns exist)
541
+ cluster_cols = [c for c in adata_full.obs.columns if c.startswith("hamming_cluster__")]
542
+ if cluster_cols:
543
+ cl_mask = pd.Series(False, index=adata_full.obs.index)
544
+ for c in cluster_cols:
545
+ vals = pd.to_numeric(adata_full.obs[c], errors="coerce")
546
+ mask_pos = (vals > 0) & (vals != -1)
547
+ mask_pos = mask_pos.fillna(False)
548
+ cl_mask |= mask_pos
549
+ adata_full.obs["is_duplicate_clustering"] = cl_mask.values
550
+ else:
551
+ adata_full.obs["is_duplicate_clustering"] = False
552
+
553
+ final_dup = seq_dup | adata_full.obs["is_duplicate_distance"].astype(bool) | adata_full.obs["is_duplicate_clustering"].astype(bool)
554
+ adata_full.obs["is_duplicate"] = final_dup.values
555
+
556
+ # Final keeper enforcement: recompute per-cluster keeper from sequence__merged_cluster_id and
557
+ # ensure that keeper is not marked duplicate
558
+ if "sequence__merged_cluster_id" in adata_full.obs.columns:
559
+ keeper_idx_by_cluster = {}
560
+ metric_col = keep_best_metric if 'keep_best_metric' in locals() else None
561
+
562
+ # group by cluster id
563
+ grp = adata_full.obs[["sequence__merged_cluster_id", "sequence__cluster_size"]].copy()
564
+ for cid, sub in grp.groupby("sequence__merged_cluster_id"):
565
+ try:
566
+ members = sub.index.to_list()
567
+ except Exception:
568
+ members = list(sub.index)
569
+ keeper = None
570
+ # prefer keep_best_metric (if present), else prefer lex keeper among members, else first member
571
+ if metric_col and metric_col in adata_full.obs.columns:
572
+ try:
573
+ vals = pd.to_numeric(adata_full.obs.loc[members, metric_col], errors="coerce")
574
+ if vals.notna().any():
575
+ keeper = vals.idxmax() if keep_best_higher else vals.idxmin()
576
+ else:
577
+ keeper = members[0]
578
+ except Exception:
579
+ keeper = members[0]
580
+ else:
581
+ # prefer lex keeper if present in this merged cluster
582
+ lex_candidates = [m for m in members if ("sequence__lex_is_keeper" in adata_full.obs.columns and adata_full.obs.at[m, "sequence__lex_is_keeper"])]
583
+ if len(lex_candidates) > 0:
584
+ keeper = lex_candidates[0]
585
+ else:
586
+ keeper = members[0]
587
+
588
+ keeper_idx_by_cluster[cid] = keeper
589
+
590
+ # force keepers not to be duplicates
591
+ is_dup_series = adata_full.obs["is_duplicate"].astype(bool)
592
+ for cid, keeper_idx in keeper_idx_by_cluster.items():
593
+ if keeper_idx in adata_full.obs.index:
594
+ is_dup_series.at[keeper_idx] = False
595
+ # clear sequence__is_duplicate for keeper if present
596
+ if "sequence__is_duplicate" in adata_full.obs.columns:
597
+ adata_full.obs.at[keeper_idx, "sequence__is_duplicate"] = False
598
+ # clear lex duplicate flag too if present
599
+ if "sequence__lex_is_duplicate" in adata_full.obs.columns:
600
+ adata_full.obs.at[keeper_idx, "sequence__lex_is_duplicate"] = False
601
+
602
+ adata_full.obs["is_duplicate"] = is_dup_series.values
603
+
604
+ # reason column
605
+ def _dup_reason_row(row):
606
+ reasons = []
607
+ if row.get("is_duplicate_distance", False):
608
+ reasons.append("distance_thresh")
609
+ if row.get("is_duplicate_clustering", False):
610
+ reasons.append("hamming_metric_cluster")
611
+ if bool(row.get("sequence__is_duplicate", False)):
612
+ reasons.append("sequence_cluster")
613
+ return ";".join(reasons) if reasons else ""
614
+
615
+ try:
616
+ reasons = adata_full.obs.apply(_dup_reason_row, axis=1)
617
+ adata_full.obs["is_duplicate_reason"] = reasons.values
618
+ except Exception:
619
+ adata_full.obs["is_duplicate_reason"] = ""
620
+
621
+ adata_unique = adata_full[~adata_full.obs["is_duplicate"].astype(bool)].copy()
622
+
623
+ # mark flags in .uns
624
+ adata_unique.uns[uns_flag] = True
625
+ adata_unique.uns[uns_filtered_flag] = True
626
+ adata_full.uns[uns_flag] = True
627
+
628
+ return adata_unique, adata_full
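A hedged usage sketch of the new signature (the column names, var filter sets, and output path below are hypothetical; the var filter entries must be boolean columns in adata.var whose names contain the reference name):

adata_unique, adata_full = flag_duplicate_reads(
    adata,
    var_filters_sets=[["Ref1_GpC_site"], ["Ref2_GpC_site"]],   # hypothetical boolean var columns
    distance_threshold=0.07,
    obs_reference_col="Reference_strand",
    sample_col="Barcode",
    output_directory="qc/duplicate_detection",                 # hypothetical output path for the plots
    keep_best_metric="read_quality",
    do_hierarchical=True,
)
# adata_full carries the sequence__* bookkeeping columns plus is_duplicate and
# is_duplicate_reason; adata_unique keeps one representative read per duplicate cluster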
7
629
 
8
- def find(self, x):
9
- while self.parent[x] != x:
10
- self.parent[x] = self.parent[self.parent[x]]
11
- x = self.parent[x]
12
- return x
13
630
 
14
- def union(self, x, y):
15
- root_x = self.find(x)
16
- root_y = self.find(y)
17
- if root_x != root_y:
18
- self.parent[root_y] = root_x
631
+ # ---------------------------
632
+ # Plot helpers (use adata_full as input)
633
+ # ---------------------------
19
634
 
635
+ def plot_histogram_pages(
636
+ histograms,
637
+ distance_threshold,
638
+ output_directory=None,
639
+ rows_per_page=6,
640
+ bins=50,
641
+ dpi=160,
642
+ figsize_per_cell=(5, 3),
643
+ adata: Optional[ad.AnnData] = None,
644
+ sample_key: str = "Barcode",
645
+ ref_key: str = "Reference_strand",
646
+ distance_key: str = "sequence__min_hamming_to_pair",
647
+ distance_types: Optional[List[str]] = None,
648
+ ):
649
+ """
650
+ Plot Hamming-distance histograms as a grid (rows=samples, cols=references).
20
651
 
21
- def flag_duplicate_reads(adata, var_filters_sets, distance_threshold=0.05, obs_reference_col='Reference_strand'):
22
- import numpy as np
23
- import pandas as pd
24
- import matplotlib.pyplot as plt
652
+ Changes:
653
+ - Ensures that every subplot in a column (same ref) uses the same X-axis range and the same bins,
654
+ computed from the union of values for that reference across samples and distance types (clamped to [0,1]).
655
+ """
656
+ if distance_types is None:
657
+ distance_types = ["min", "fwd", "rev", "hier", "lex_local"]
25
658
 
26
- all_hamming_dists = []
27
- merged_results = []
659
+ # canonicalize samples / refs
660
+ if adata is not None and sample_key in adata.obs.columns and ref_key in adata.obs.columns:
661
+ obs = adata.obs
662
+ sseries = obs[sample_key]
663
+ if not pd.api.types.is_categorical_dtype(sseries):
664
+ sseries = sseries.astype("category")
665
+ samples = list(sseries.cat.categories)
666
+ rseries = obs[ref_key]
667
+ if not pd.api.types.is_categorical_dtype(rseries):
668
+ rseries = rseries.astype("category")
669
+ references = list(rseries.cat.categories)
670
+ use_adata = True
671
+ else:
672
+ samples = sorted({h["sample"] for h in histograms})
673
+ references = sorted({h["reference"] for h in histograms})
674
+ use_adata = False
28
675
 
29
- references = adata.obs[obs_reference_col].cat.categories
676
+ if len(samples) == 0 or len(references) == 0:
677
+ print("No histogram data to plot.")
678
+ return {"distance_pages": [], "cluster_size_pages": []}
30
679
 
680
+ def clean_array(arr):
681
+ if arr is None or len(arr) == 0:
682
+ return np.array([], dtype=float)
683
+ a = np.asarray(arr, dtype=float)
684
+ a = a[np.isfinite(a)]
685
+ a = a[(a >= 0.0) & (a <= 1.0)]
686
+ return a
687
+
688
+ grid = defaultdict(lambda: defaultdict(list))
689
+ # populate from adata if available
690
+ if use_adata:
691
+ obs = adata.obs
692
+ try:
693
+ grouped = obs.groupby([sample_key, ref_key])
694
+ except Exception:
695
+ grouped = []
696
+ for s in samples:
697
+ for r in references:
698
+ sub = obs[(obs[sample_key] == s) & (obs[ref_key] == r)]
699
+ if not sub.empty:
700
+ grouped.append(((s, r), sub))
701
+ if isinstance(grouped, dict) or hasattr(grouped, "groups"):
702
+ for (s, r), group in grouped:
703
+ if "min" in distance_types and distance_key in group.columns:
704
+ grid[(s, r)]["min"].extend(clean_array(group[distance_key].to_numpy()))
705
+ if "fwd" in distance_types and "fwd_hamming_to_next" in group.columns:
706
+ grid[(s, r)]["fwd"].extend(clean_array(group["fwd_hamming_to_next"].to_numpy()))
707
+ if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
708
+ grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
709
+ if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
710
+ grid[(s, r)]["hier"].extend(clean_array(group["sequence__hier_hamming_to_pair"].to_numpy()))
711
+ else:
712
+ for (s, r), group in grouped:
713
+ if "min" in distance_types and distance_key in group.columns:
714
+ grid[(s, r)]["min"].extend(clean_array(group[distance_key].to_numpy()))
715
+ if "fwd" in distance_types and "fwd_hamming_to_next" in group.columns:
716
+ grid[(s, r)]["fwd"].extend(clean_array(group["fwd_hamming_to_next"].to_numpy()))
717
+ if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
718
+ grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
719
+ if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
720
+ grid[(s, r)]["hier"].extend(clean_array(group["sequence__hier_hamming_to_pair"].to_numpy()))
721
+
722
+ # legacy histograms fallback
723
+ if histograms:
724
+ for h in histograms:
725
+ key = (h["sample"], h["reference"])
726
+ if "lex_local" in distance_types:
727
+ grid[key]["lex_local"].extend(clean_array(h.get("distances", [])))
728
+ if "hier" in distance_types and "hierarchical_pairs" in h:
729
+ grid[key]["hier"].extend(clean_array(h.get("hierarchical_pairs", [])))
730
+ if "cluster_counts" in h:
731
+ grid[key]["_legacy_cluster_counts"].extend(h.get("cluster_counts", []))
732
+
733
+ # Compute per-reference global x-range and bin edges (so every subplot in a column uses same bins)
734
+ ref_xmax = {}
31
735
  for ref in references:
32
- print(f'🔹 Processing reference: {ref}')
33
-
34
- ref_mask = adata.obs[obs_reference_col] == ref
35
- adata_subset = adata[ref_mask].copy()
36
- N = adata_subset.shape[0]
37
-
38
- combined_mask = torch.zeros(len(adata.var), dtype=torch.bool)
39
- for var_set in var_filters_sets:
40
- if any(ref in v for v in var_set):
41
- set_mask = torch.ones(len(adata.var), dtype=torch.bool)
42
- for key in var_set:
43
- set_mask &= torch.from_numpy(adata.var[key].values)
44
- combined_mask |= set_mask
45
-
46
- selected_cols = adata.var.index[combined_mask.numpy()].to_list()
47
- col_indices = [adata.var.index.get_loc(col) for col in selected_cols]
48
-
49
- print(f"Selected {len(col_indices)} columns out of {adata.var.shape[0]} for {ref}")
50
-
51
- X = adata_subset.X
52
- if not isinstance(X, np.ndarray):
53
- X = X.toarray()
54
- X_subset = X[:, col_indices]
55
- X_tensor = torch.from_numpy(X_subset.astype(np.float32))
56
-
57
- fwd_hamming_to_next = torch.full((N,), float('nan'))
58
- rev_hamming_to_prev = torch.full((N,), float('nan'))
59
-
60
- def cluster_pass(X_tensor, reverse=False, window_size=50, record_distances=False):
61
- N_local = X_tensor.shape[0]
62
- X_sortable = X_tensor.nan_to_num(-1)
63
- sort_keys = X_sortable.tolist()
64
- sorted_idx = sorted(range(N_local), key=lambda i: sort_keys[i], reverse=reverse)
65
- sorted_X = X_tensor[sorted_idx]
66
-
67
- cluster_pairs = []
68
-
69
- for i in tqdm(range(len(sorted_X)), desc=f"Pass {'rev' if reverse else 'fwd'} ({ref})"):
70
- row_i = sorted_X[i]
71
- j_range = range(i + 1, min(i + 1 + window_size, len(sorted_X)))
72
-
73
- if len(j_range) > 0:
74
- row_i_exp = row_i.unsqueeze(0)
75
- block_rows = sorted_X[j_range]
76
- valid_mask = (~torch.isnan(row_i_exp)) & (~torch.isnan(block_rows))
77
- valid_counts = valid_mask.sum(dim=1)
78
- diffs = (row_i_exp != block_rows) & valid_mask
79
- hamming_dists = diffs.sum(dim=1) / valid_counts.clamp(min=1)
80
- all_hamming_dists.extend(hamming_dists.cpu().numpy().tolist())
81
-
82
- matches = (hamming_dists < distance_threshold) & (valid_counts > 0)
83
- for offset_idx, m in zip(j_range, matches):
84
- if m:
85
- cluster_pairs.append((sorted_idx[i], sorted_idx[offset_idx]))
86
-
87
- if record_distances and i + 1 < len(sorted_X):
88
- next_idx = sorted_idx[i + 1]
89
- valid_mask_pair = (~torch.isnan(row_i)) & (~torch.isnan(sorted_X[i + 1]))
90
- if valid_mask_pair.sum() > 0:
91
- d = (row_i[valid_mask_pair] != sorted_X[i + 1][valid_mask_pair]).sum()
92
- norm_d = d.item() / valid_mask_pair.sum().item()
93
- if reverse:
94
- rev_hamming_to_prev[next_idx] = norm_d
736
+ vals_for_ref = []
737
+ for s in samples:
738
+ for dt in distance_types:
739
+ a = np.asarray(grid[(s, ref)].get(dt, []), dtype=float)
740
+ if a.size:
741
+ a = a[np.isfinite(a)]
742
+ if a.size:
743
+ vals_for_ref.append(a)
744
+ if vals_for_ref:
745
+ allvals = np.concatenate(vals_for_ref)
746
+ vmax = float(np.nanmax(allvals)) if np.isfinite(allvals).any() else 1.0
747
+ # pad slightly to include uppermost bin and always keep at least distance_threshold
748
+ vmax = max(vmax, float(distance_threshold))
749
+ vmax = min(1.0, max(0.01, vmax)) # clamp to [0.01, 1.0] to avoid degenerate bins
750
+ else:
751
+ vmax = 1.0
752
+ ref_xmax[ref] = vmax
753
+
754
+ # counts (for labels)
755
+ if use_adata:
756
+ counts = {(s, r): int(((adata.obs[sample_key] == s) & (adata.obs[ref_key] == r)).sum()) for s in samples for r in references}
757
+ else:
758
+ counts = {(s, r): sum(len(grid[(s, r)][dt]) for dt in distance_types) for s in samples for r in references}
759
+
760
+ distance_pages = []
761
+ cluster_size_pages = []
762
+ n_pages = math.ceil(len(samples) / rows_per_page)
763
+ palette = plt.get_cmap("tab10")
764
+ dtype_colors = {dt: palette(i % 10) for i, dt in enumerate(distance_types)}
765
+
766
+ for page in range(n_pages):
767
+ start = page * rows_per_page
768
+ end = min(start + rows_per_page, len(samples))
769
+ chunk = samples[start:end]
770
+ nrows = len(chunk)
771
+ ncols = len(references)
772
+
773
+ # Distance histogram page
774
+ fig_w = figsize_per_cell[0] * ncols
775
+ fig_h = figsize_per_cell[1] * nrows
776
+ fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
777
+
778
+ for r_idx, sample_name in enumerate(chunk):
779
+ for c_idx, ref_name in enumerate(references):
780
+ ax = axes[r_idx][c_idx]
781
+ any_data = False
782
+ # pick per-column bins based on ref_xmax
783
+ ref_vmax = ref_xmax.get(ref_name, 1.0)
784
+ bins_edges = np.linspace(0.0, ref_vmax, bins + 1)
785
+ for dtype in distance_types:
786
+ vals = np.asarray(grid[(sample_name, ref_name)].get(dtype, []), dtype=float)
787
+ if vals.size > 0:
788
+ vals = vals[np.isfinite(vals)]
789
+ vals = vals[(vals >= 0.0) & (vals <= ref_vmax)]
790
+ if vals.size > 0:
791
+ any_data = True
792
+ ax.hist(vals, bins=bins_edges, alpha=0.5, label=dtype, density=False, stacked=False,
793
+ color=dtype_colors.get(dtype, None))
794
+ if not any_data:
795
+ ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
796
+ # threshold line (make sure it is within axis)
797
+ ax.axvline(distance_threshold, color="red", linestyle="--", linewidth=1)
798
+
799
+ if r_idx == 0:
800
+ ax.set_title(str(ref_name), fontsize=10)
801
+ if c_idx == 0:
802
+ total_reads = sum(counts.get((sample_name, ref), 0) for ref in references) if not use_adata else int((adata.obs[sample_key] == sample_name).sum())
803
+ ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
804
+ if r_idx == nrows - 1:
805
+ ax.set_xlabel("Hamming Distance", fontsize=9)
806
+ else:
807
+ ax.set_xticklabels([])
808
+
809
+ ax.set_xlim(left=0.0, right=ref_vmax)
810
+ ax.grid(True, alpha=0.25)
811
+ if r_idx == 0 and c_idx == 0:
812
+ ax.legend(fontsize=7, loc="upper right")
813
+
814
+ fig.suptitle(f"Hamming distance histograms (rows=samples, cols=references) — page {page+1}/{n_pages}", fontsize=12, y=0.995)
815
+ fig.tight_layout(rect=[0, 0, 1, 0.96])
816
+
817
+ if output_directory:
818
+ os.makedirs(output_directory, exist_ok=True)
819
+ fname = os.path.join(output_directory, f"hamming_histograms_page_{page+1}.png")
820
+ plt.savefig(fname, bbox_inches="tight")
821
+ distance_pages.append(fname)
822
+ else:
823
+ plt.show()
824
+ plt.close(fig)
825
+
826
+ # Cluster-size histogram page (unchanged except it uses adata-derived sizes per cluster if available)
827
+ fig_w = figsize_per_cell[0] * ncols
828
+ fig_h = figsize_per_cell[1] * nrows
829
+ fig2, axes2 = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
830
+
831
+ for r_idx, sample_name in enumerate(chunk):
832
+ for c_idx, ref_name in enumerate(references):
833
+ ax = axes2[r_idx][c_idx]
834
+ sizes = []
835
+ if use_adata and ("sequence__merged_cluster_id" in adata.obs.columns and "sequence__cluster_size" in adata.obs.columns):
836
+ sub = adata.obs[(adata.obs[sample_key] == sample_name) & (adata.obs[ref_key] == ref_name)]
837
+ if not sub.empty:
838
+ try:
839
+ grp = sub.groupby("sequence__merged_cluster_id")["sequence__cluster_size"].first()
840
+ sizes = [int(x) for x in grp.to_numpy().tolist() if (pd.notna(x) and np.isfinite(x))]
841
+ except Exception:
842
+ try:
843
+ unique_pairs = sub[["sequence__merged_cluster_id", "sequence__cluster_size"]].drop_duplicates()
844
+ sizes = [int(x) for x in unique_pairs["sequence__cluster_size"].dropna().astype(int).tolist()]
845
+ except Exception:
846
+ sizes = []
847
+ if (not sizes) and histograms:
848
+ for h in histograms:
849
+ if h.get("sample") == sample_name and h.get("reference") == ref_name:
850
+ sizes = h.get("cluster_counts", []) or []
851
+ break
852
+
853
+ if sizes:
854
+ ax.hist(sizes, bins=range(1, max(2, max(sizes) + 1)), alpha=0.8, align="left")
855
+ ax.set_xlabel("Cluster size")
856
+ ax.set_ylabel("Count")
857
+ else:
858
+ ax.text(0.5, 0.5, "No clusters", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
859
+
860
+ if r_idx == 0:
861
+ ax.set_title(str(ref_name), fontsize=10)
862
+ if c_idx == 0:
863
+ total_reads = sum(counts.get((sample_name, ref), 0) for ref in references) if not use_adata else int((adata.obs[sample_key] == sample_name).sum())
864
+ ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
865
+ if r_idx != nrows - 1:
866
+ ax.set_xticklabels([])
867
+
868
+ ax.grid(True, alpha=0.25)
869
+
870
+ fig2.suptitle(f"Union-find cluster size histograms — page {page+1}/{n_pages}", fontsize=12, y=0.995)
871
+ fig2.tight_layout(rect=[0, 0, 1, 0.96])
872
+
873
+ if output_directory:
874
+ fname2 = os.path.join(output_directory, f"cluster_size_histograms_page_{page+1}.png")
875
+ plt.savefig(fname2, bbox_inches="tight")
876
+ cluster_size_pages.append(fname2)
877
+ else:
878
+ plt.show()
879
+ plt.close(fig2)
880
+
881
+ return {"distance_pages": distance_pages, "cluster_size_pages": cluster_size_pages}
882
+
883
+
884
+ def plot_hamming_vs_metric_pages(
885
+ adata,
886
+ metric_keys: Union[str, List[str]],
887
+ hamming_col: str = "fwd_hamming_to_next",
888
+ sample_col: str = "Barcode",
889
+ ref_col: str = "Reference_strand",
890
+ references: Optional[List[str]] = None,
891
+ samples: Optional[List[str]] = None,
892
+ rows_per_fig: int = 6,
893
+ dpi: int = 160,
894
+ filename_prefix: str = "hamming_vs_metric",
895
+ output_dir: Optional[str] = None,
896
+ kde: bool = False,
897
+ contour: bool = False,
898
+ regression: bool = True,
899
+ show_ticks: bool = True,
900
+ clustering: Optional[Dict[str, Any]] = None,
901
+ write_clusters_to_adata: bool = False,
902
+ figsize_per_cell: Tuple[float, float] = (4.0, 3.0),
903
+ random_state: int = 0,
904
+ highlight_threshold: Optional[float] = None,
905
+ highlight_color: str = "red",
906
+ color_by_duplicate: bool = False,
907
+ duplicate_col: str = "is_duplicate",
908
+ ) -> Dict[str, Any]:
909
+ """
910
+ Plot hamming (y) vs metric (x).
911
+
912
+ New behavior:
913
+ - If color_by_duplicate is True and adata.obs[duplicate_col] exists, points are colored by that boolean:
914
+ duplicates -> highlight_color (with edge), non-duplicates -> gray
915
+ - If color_by_duplicate is False, previous highlight_threshold behavior is preserved.
916
+ """
917
+ if isinstance(metric_keys, str):
918
+ metric_keys = [metric_keys]
919
+ metric_keys = list(metric_keys)
920
+
921
+ if output_dir is not None:
922
+ os.makedirs(output_dir, exist_ok=True)
923
+
924
+ obs = adata.obs
925
+ if sample_col not in obs.columns or ref_col not in obs.columns:
926
+ raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")
927
+
928
+ # canonicalize samples and refs
929
+ if samples is None:
930
+ sseries = obs[sample_col]
931
+ if not pd.api.types.is_categorical_dtype(sseries):
932
+ sseries = sseries.astype("category")
933
+ samples_all = list(sseries.cat.categories)
934
+ else:
935
+ samples_all = list(samples)
936
+
937
+ if references is None:
938
+ rseries = obs[ref_col]
939
+ if not pd.api.types.is_categorical_dtype(rseries):
940
+ rseries = rseries.astype("category")
941
+ refs_all = list(rseries.cat.categories)
942
+ else:
943
+ refs_all = list(references)
944
+
945
+ extra_col = "ALLREFS"
946
+ cols = list(refs_all) + [extra_col]
947
+
948
+ saved_map: Dict[str, Any] = {}
949
+
950
+ for metric in metric_keys:
951
+ clusters_info: Dict[Tuple[str, str], Dict[str, Any]] = {}
952
+ files: List[str] = []
953
+ written_cols: List[str] = []
954
+
955
+ # compute global x/y limits robustly by aligning Series
956
+ global_xlim = None
957
+ global_ylim = None
958
+
959
+ metric_present = metric in obs.columns
960
+ hamming_present = hamming_col in obs.columns
961
+
962
+ if metric_present or hamming_present:
963
+ sX = obs[metric].astype(float) if metric_present else None
964
+ sY = pd.to_numeric(obs[hamming_col], errors="coerce") if hamming_present else None
965
+
966
+ if (sX is not None) and (sY is not None):
967
+ valid_both = sX.notna() & sY.notna() & np.isfinite(sX.values) & np.isfinite(sY.values)
968
+ if valid_both.any():
969
+ xvals = sX[valid_both].to_numpy(dtype=float)
970
+ yvals = sY[valid_both].to_numpy(dtype=float)
971
+ xmin, xmax = float(np.nanmin(xvals)), float(np.nanmax(xvals))
972
+ ymin, ymax = float(np.nanmin(yvals)), float(np.nanmax(yvals))
973
+ xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
974
+ ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
975
+ global_xlim = (xmin - xpad, xmax + xpad)
976
+ global_ylim = (ymin - ypad, ymax + ypad)
977
+ else:
978
+ sX_finite = sX[np.isfinite(sX)]
979
+ sY_finite = sY[np.isfinite(sY)]
980
+ if sX_finite.size > 0:
981
+ xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
982
+ xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
983
+ global_xlim = (xmin - xpad, xmax + xpad)
984
+ if sY_finite.size > 0:
985
+ ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
986
+ ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
987
+ global_ylim = (ymin - ypad, ymax + ypad)
988
+ elif sX is not None:
989
+ sX_finite = sX[np.isfinite(sX)]
990
+ if sX_finite.size > 0:
991
+ xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
992
+ xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
993
+ global_xlim = (xmin - xpad, xmax + xpad)
994
+ elif sY is not None:
995
+ sY_finite = sY[np.isfinite(sY)]
996
+ if sY_finite.size > 0:
997
+ ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
998
+ ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
999
+ global_ylim = (ymin - ypad, ymax + ypad)
1000
+
1001
+ # pagination
1002
+ for start in range(0, len(samples_all), rows_per_fig):
1003
+ chunk = samples_all[start : start + rows_per_fig]
1004
+ nrows = len(chunk)
1005
+ ncols = len(cols)
1006
+ fig_w = ncols * figsize_per_cell[0]
1007
+ fig_h = nrows * figsize_per_cell[1]
1008
+ fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
1009
+
1010
+ for r_idx, sample_name in enumerate(chunk):
1011
+ for c_idx, ref_name in enumerate(cols):
1012
+ ax = axes[r_idx][c_idx]
1013
+ if ref_name == extra_col:
1014
+ mask = (obs[sample_col].values == sample_name)
1015
+ else:
1016
+ mask = (obs[sample_col].values == sample_name) & (obs[ref_col].values == ref_name)
1017
+
1018
+ sub = obs[mask]
1019
+
1020
+ if metric in sub.columns:
1021
+ x_all = pd.to_numeric(sub[metric], errors="coerce").to_numpy(dtype=float)
1022
+ else:
1023
+ x_all = np.array([], dtype=float)
1024
+ if hamming_col in sub.columns:
1025
+ y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(dtype=float)
1026
+ else:
1027
+ y_all = np.array([], dtype=float)
1028
+
1029
+ idxs = sub.index.to_numpy()
1030
+
1031
+ # drop nan pairs
1032
+ if x_all.size and y_all.size and len(x_all) == len(y_all):
1033
+ valid_pair_mask = np.isfinite(x_all) & np.isfinite(y_all)
1034
+ x = x_all[valid_pair_mask]
1035
+ y = y_all[valid_pair_mask]
1036
+ idxs_valid = idxs[valid_pair_mask]
1037
+ else:
1038
+ x = np.array([], dtype=float)
1039
+ y = np.array([], dtype=float)
1040
+ idxs_valid = np.array([], dtype=int)
1041
+
1042
+ if x.size == 0:
1043
+ ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
1044
+ clusters_info[(sample_name, ref_name)] = {"diag": None, "n_points": 0}
1045
+ else:
1046
+ # Decide color mapping
1047
+ if color_by_duplicate and duplicate_col in adata.obs.columns and idxs_valid.size > 0:
1048
+ # get boolean series aligned to idxs_valid
1049
+ try:
1050
+ dup_flags = adata.obs.loc[idxs_valid, duplicate_col].astype(bool).to_numpy()
1051
+ except Exception:
1052
+ dup_flags = np.zeros(len(idxs_valid), dtype=bool)
1053
+ mask_dup = dup_flags
1054
+ mask_nondup = ~mask_dup
1055
+ # plot non-duplicates first in gray, duplicates in highlight color
1056
+ if mask_nondup.any():
1057
+ ax.scatter(x[mask_nondup], y[mask_nondup], s=12, alpha=0.6, rasterized=True, c="lightgray")
1058
+ if mask_dup.any():
1059
+ ax.scatter(x[mask_dup], y[mask_dup], s=20, alpha=0.9, rasterized=True, c=highlight_color, edgecolors="k", linewidths=0.3)
1060
+ else:
1061
+ # old behavior: highlight by threshold if requested
1062
+ if highlight_threshold is not None and y.size:
1063
+ mask_low = (y < float(highlight_threshold)) & np.isfinite(y)
1064
+ mask_high = ~mask_low
1065
+ if mask_high.any():
1066
+ ax.scatter(x[mask_high], y[mask_high], s=12, alpha=0.6, rasterized=True)
1067
+ if mask_low.any():
1068
+ ax.scatter(x[mask_low], y[mask_low], s=18, alpha=0.9, rasterized=True, c=highlight_color, edgecolors="k", linewidths=0.3)
1069
+ else:
1070
+ ax.scatter(x, y, s=12, alpha=0.6, rasterized=True)
1071
+
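+ # Optional Gaussian-KDE density overlay: filled contours on an 80x80 grid when contour=True, otherwise points recolored by local density.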
1072
+ if kde and gaussian_kde is not None and x.size >= 4:
1073
+ try:
1074
+ xy = np.vstack([x, y])
1075
+ kde2 = gaussian_kde(xy)(xy)
1076
+ if contour:
1077
+ xi = np.linspace(np.nanmin(x), np.nanmax(x), 80)
1078
+ yi = np.linspace(np.nanmin(y), np.nanmax(y), 80)
1079
+ xi_g, yi_g = np.meshgrid(xi, yi)
1080
+ coords = np.vstack([xi_g.ravel(), yi_g.ravel()])
1081
+ zi = gaussian_kde(np.vstack([x, y]))(coords).reshape(xi_g.shape)
1082
+ ax.contourf(xi_g, yi_g, zi, levels=8, alpha=0.35, cmap="Blues")
1083
+ else:
1084
+ ax.scatter(x, y, c=kde2, s=16, cmap="viridis", alpha=0.7, linewidths=0)
1085
+ except Exception:
1086
+ pass
1087
+
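+ # Optional linear fit: degree-1 np.polyfit trend line plus a Pearson r annotation in the lower-right corner.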
1088
+ if regression and x.size >= 2:
1089
+ try:
1090
+ a, b = np.polyfit(x, y, 1)
1091
+ xs = np.linspace(np.nanmin(x), np.nanmax(x), 100)
1092
+ ys = a * xs + b
1093
+ ax.plot(xs, ys, linestyle="--", linewidth=1.2, alpha=0.9, color="red")
1094
+ r = np.corrcoef(x, y)[0, 1]
1095
+ ax.text(0.98, 0.02, f"r={float(r):.3f}", ha="right", va="bottom", transform=ax.transAxes, fontsize=8,
1096
+ bbox=dict(facecolor="white", alpha=0.6, boxstyle="round,pad=0.2"))
1097
+ except Exception:
1098
+ pass
1099
+
1100
+ if clustering:
1101
+ cl_labels, diag = _run_clustering(
1102
+ x, y,
1103
+ method=clustering.get("method", "dbscan"),
1104
+ n_clusters=clustering.get("n_clusters", 2),
1105
+ dbscan_eps=clustering.get("dbscan_eps", 0.05),
1106
+ dbscan_min_samples=clustering.get("dbscan_min_samples", 5),
1107
+ random_state=random_state,
1108
+ min_points=clustering.get("min_points", 8),
1109
+ )
1110
+
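+ # Relabel clusters so that label 0 is the cluster with the highest median hamming distance; noise stays -1.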
1111
+ remapped_labels = cl_labels.copy()
1112
+ unique_nonnoise = sorted([u for u in np.unique(cl_labels) if u != -1])
1113
+ if len(unique_nonnoise) > 0:
1114
+ medians = {}
1115
+ for lab in unique_nonnoise:
1116
+ mask_lab = (cl_labels == lab)
1117
+ medians[lab] = float(np.median(y[mask_lab])) if mask_lab.any() else float("nan")
1118
+ sorted_by_median = sorted(unique_nonnoise, key=lambda l: (np.nan if np.isnan(medians[l]) else medians[l]), reverse=True)
1119
+ mapping = {old: new for new, old in enumerate(sorted_by_median)}
1120
+ for i_lab in range(len(remapped_labels)):
1121
+ if remapped_labels[i_lab] != -1:
1122
+ remapped_labels[i_lab] = mapping.get(remapped_labels[i_lab], -1)
1123
+ diag = diag or {}
1124
+ diag["cluster_median_hamming"] = {int(old): medians[old] for old in medians}
1125
+ diag["cluster_old_to_new_map"] = {int(old): int(new) for old, new in mapping.items()}
95
1126
  else:
96
- fwd_hamming_to_next[sorted_idx[i]] = norm_d
1127
+ remapped_labels = cl_labels.copy()
1128
+ diag = diag or {}
1129
+ diag["cluster_median_hamming"] = {}
1130
+ diag["cluster_old_to_new_map"] = {}
1131
+
1132
+ _overlay_clusters_on_ax(ax, x, y, remapped_labels, diag,
1133
+ cmap=clustering.get("cmap", "tab10"),
1134
+ hull=clustering.get("hull", True),
1135
+ show_cluster_labels=True)
1136
+
1137
+ clusters_info[(sample_name, ref_name)] = {"diag": diag, "n_points": len(x)}
1138
+
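+ # Persist per-panel cluster labels to adata.obs as hamming_cluster__<metric>__<sample>__<ref> (NaN for reads outside this panel).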
1139
+ if write_clusters_to_adata and idxs_valid.size > 0:
1140
+ colname_safe_ref = (ref_name if ref_name != extra_col else "ALLREFS")
1141
+ colname = f"hamming_cluster__{metric}__{sample_name}__{colname_safe_ref}"
1142
+ if colname not in adata.obs.columns:
1143
+ adata.obs[colname] = np.nan
1144
+ lab_arr = remapped_labels.astype(float)
1145
+ adata.obs.loc[idxs_valid, colname] = lab_arr
1146
+ if colname not in written_cols:
1147
+ written_cols.append(colname)
1148
+
1149
+ if r_idx == 0:
1150
+ ax.set_title(str(ref_name), fontsize=9)
1151
+ if c_idx == 0:
1152
+ total_reads = int((obs[sample_col] == sample_name).sum())
1153
+ ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
1154
+ if r_idx == nrows - 1:
1155
+ ax.set_xlabel(metric, fontsize=8)
1156
+
1157
+ if global_xlim is not None:
1158
+ ax.set_xlim(global_xlim)
1159
+ if global_ylim is not None:
1160
+ ax.set_ylim(global_ylim)
1161
+
1162
+ if not show_ticks:
1163
+ ax.set_xticklabels([])
1164
+ ax.set_yticklabels([])
1165
+
1166
+ ax.grid(True, alpha=0.25)
1167
+
1168
+ fig.suptitle(f"Hamming ({hamming_col}) vs {metric}", y=0.995, fontsize=11)
1169
+ fig.tight_layout(rect=[0, 0, 1, 0.97])
1170
+
1171
+ page_idx = start // rows_per_fig + 1
1172
+ fname = f"{filename_prefix}_{metric}_page{page_idx}.png"
1173
+ if output_dir:
1174
+ outpath = os.path.join(output_dir, fname)
1175
+ plt.savefig(outpath, bbox_inches="tight", dpi=dpi)
1176
+ files.append(outpath)
1177
+ else:
1178
+ plt.show()
1179
+ plt.close(fig)
1180
+
1181
+ saved_map[metric] = {"files": files, "clusters_info": clusters_info, "written_cols": written_cols}
1182
+
1183
+ return saved_map
1184
+
1185
+
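A minimal usage sketch for the plotting function above. The exported name is assumed here to be plot_hamming_vs_metric (illustrative only), adata is assumed to be the first positional argument, and the obs column names are placeholders; only the parameters visible in the body above are relied on.

import anndata as ad
import numpy as np
import pandas as pd

# Toy AnnData: 200 reads across 2 samples x 2 references, with one metric column,
# a nearest-neighbor hamming column, and a duplicate flag (all names illustrative).
rng = np.random.default_rng(0)
n = 200
obs = pd.DataFrame({
    "Sample": rng.choice(["s1", "s2"], size=n),
    "Reference": rng.choice(["refA", "refB"], size=n),
    "read_gc": rng.uniform(0.3, 0.7, size=n),
    "nearest_hamming": rng.uniform(0.0, 0.5, size=n),
    "is_duplicate": rng.random(n) < 0.1,
})
adata = ad.AnnData(X=np.zeros((n, 1)), obs=obs)

results = plot_hamming_vs_metric(          # assumed name for the function defined above
    adata,
    metric_keys="read_gc",
    sample_col="Sample",
    ref_col="Reference",
    hamming_col="nearest_hamming",
    color_by_duplicate=True,
    duplicate_col="is_duplicate",
    clustering={"method": "dbscan", "dbscan_eps": 0.05, "min_points": 8},
    write_clusters_to_adata=True,
    output_dir="qc_plots",
)
# results["read_gc"] -> {"files": [...], "clusters_info": {(sample, ref): ...}, "written_cols": [...]}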
1186
+ def _run_clustering(
1187
+ x: np.ndarray,
1188
+ y: np.ndarray,
1189
+ *,
1190
+ method: str = "kmeans",  # "kmeans", "dbscan", or "gmm"; anything else falls back to DBSCAN, then KMeans
1191
+ n_clusters: int = 2,
1192
+ dbscan_eps: float = 0.05,
1193
+ dbscan_min_samples: int = 5,
1194
+ random_state: int = 0,
1195
+ min_points: int = 10,
1196
+ ) -> Tuple[np.ndarray, Dict[str, Any]]:
1197
+ """
1198
+ Run clustering on 2D points (x,y). Returns labels (len = npoints) and diagnostics dict.
1199
+ Labels follow sklearn conventions (noise and skipped points -> -1, as in DBSCAN).
1200
+ """
1201
+ try:
1202
+ from sklearn.cluster import KMeans, DBSCAN
1203
+ from sklearn.mixture import GaussianMixture
1204
+ from sklearn.metrics import silhouette_score
1205
+ except Exception:
1206
+ KMeans = DBSCAN = GaussianMixture = silhouette_score = None
1207
+
1208
+ pts = np.column_stack([x, y])
1209
+ diagnostics: Dict[str, Any] = {"method": method, "n_input": len(x)}
1210
+ if len(x) < min_points:
1211
+ diagnostics["skipped"] = True
1212
+ return np.full(len(x), -1, dtype=int), diagnostics
1213
+
1214
+ method = (method or "kmeans").lower()
1215
+ labels = np.full(len(x), -1, dtype=int)
1216
+
1217
+ try:
1218
+ if method == "kmeans" and KMeans is not None:
1219
+ km = KMeans(n_clusters=max(1, int(n_clusters)), random_state=random_state)
1220
+ labels = km.fit_predict(pts)
1221
+ diagnostics["centers"] = km.cluster_centers_
1222
+ diagnostics["n_clusters_found"] = int(len(np.unique(labels)))
1223
+ elif method == "dbscan" and DBSCAN is not None:
1224
+ db = DBSCAN(eps=float(dbscan_eps), min_samples=int(dbscan_min_samples))
1225
+ labels = db.fit_predict(pts)
1226
+ uniq = [u for u in np.unique(labels) if u != -1]
1227
+ diagnostics["n_clusters_found"] = int(len(uniq))
1228
+ elif method == "gmm" and GaussianMixture is not None:
1229
+ gm = GaussianMixture(n_components=max(1, int(n_clusters)), random_state=random_state)
1230
+ labels = gm.fit_predict(pts)
1231
+ diagnostics["means"] = gm.means_
1232
+ diagnostics["covariances"] = getattr(gm, "covariances_", None)
1233
+ diagnostics["n_clusters_found"] = int(len(np.unique(labels)))
1234
+ else:
1235
+ # fallback: try DBSCAN then KMeans
1236
+ if DBSCAN is not None:
1237
+ db = DBSCAN(eps=float(dbscan_eps), min_samples=int(dbscan_min_samples))
1238
+ labels = db.fit_predict(pts)
1239
+ if (labels == -1).all() and KMeans is not None:
1240
+ km = KMeans(n_clusters=max(1, int(n_clusters)), random_state=random_state)
1241
+ labels = km.fit_predict(pts)
1242
+ diagnostics["fallback_to"] = "kmeans"
1243
+ diagnostics["centers"] = km.cluster_centers_
1244
+ diagnostics["n_clusters_found"] = int(len(np.unique(labels)))
1245
+ elif KMeans is not None:
1246
+ km = KMeans(n_clusters=max(1, int(n_clusters)), random_state=random_state)
1247
+ labels = km.fit_predict(pts)
1248
+ diagnostics["n_clusters_found"] = int(len(np.unique(labels)))
1249
+ else:
1250
+ diagnostics["skipped"] = True
1251
+ return np.full(len(x), -1, dtype=int), diagnostics
1252
+
1253
+ except Exception as e:
1254
+ diagnostics["error"] = str(e)
1255
+ diagnostics["skipped"] = True
1256
+ return np.full(len(x), -1, dtype=int), diagnostics
97
1257
 
98
- return cluster_pairs
1258
+ # remap non-noise labels to contiguous ints starting at 0 (keep -1 for noise)
1259
+ unique_nonnoise = sorted([u for u in np.unique(labels) if u != -1])
1260
+ if unique_nonnoise:
1261
+ mapping = {old: new for new, old in enumerate(unique_nonnoise)}
1262
+ remapped = np.full_like(labels, -1)
1263
+ for i, lab in enumerate(labels):
1264
+ if lab != -1:
1265
+ remapped[i] = mapping.get(lab, -1)
1266
+ labels = remapped
1267
+ diagnostics["n_clusters_found"] = int(len(unique_nonnoise))
1268
+ else:
1269
+ diagnostics["n_clusters_found"] = 0
99
1270
 
100
- pairs_fwd = cluster_pass(X_tensor, reverse=False, record_distances=True)
101
- involved_in_fwd = set([p[0] for p in pairs_fwd] + [p[1] for p in pairs_fwd])
102
- mask_for_rev = torch.ones(N, dtype=torch.bool)
103
- mask_for_rev[list(involved_in_fwd)] = False
104
- pairs_rev = cluster_pass(X_tensor[mask_for_rev], reverse=True, record_distances=True)
1271
+ # compute silhouette if suitable
1272
+ try:
1273
+ if diagnostics.get("n_clusters_found", 0) >= 2 and len(x) >= 3 and silhouette_score is not None:
1274
+ diagnostics["silhouette"] = float(silhouette_score(pts, labels))
1275
+ else:
1276
+ diagnostics["silhouette"] = None
1277
+ except Exception:
1278
+ diagnostics["silhouette"] = None
105
1279
 
106
- all_pairs = pairs_fwd + [(list(mask_for_rev.nonzero(as_tuple=True)[0])[i], list(mask_for_rev.nonzero(as_tuple=True)[0])[j]) for i, j in pairs_rev]
1280
+ diagnostics["skipped"] = False
1281
+ return labels.astype(int), diagnostics
107
1282
 
108
- uf = UnionFind(N)
109
- for i, j in all_pairs:
110
- uf.union(i, j)
111
1283
 
112
- merged_cluster = torch.zeros(N, dtype=torch.long)
113
- for i in range(N):
114
- merged_cluster[i] = uf.find(i)
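A small sanity-check sketch for _run_clustering as defined above (scikit-learn assumed installed; the data is synthetic):

import numpy as np

# Two well-separated 2D blobs -> expect two clusters with contiguous labels {0, 1}.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0.2, 0.02, 50), rng.normal(0.8, 0.02, 50)])
y = np.concatenate([rng.normal(0.1, 0.02, 50), rng.normal(0.4, 0.02, 50)])

labels, diag = _run_clustering(x, y, method="kmeans", n_clusters=2, min_points=10)
print(diag["n_clusters_found"], diag.get("silhouette"))   # e.g. 2 and a silhouette close to 1.0
# labels is all -1 when fewer than min_points points are given or sklearn is unavailable.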
1284
+ def _overlay_clusters_on_ax(
1285
+ ax,
1286
+ x,
1287
+ y,
1288
+ labels,
1289
+ diagnostics,
1290
+ *,
1291
+ cmap="tab20",
1292
+ alpha_pts=0.6,
1293
+ marker="o",
1294
+ plot_centroids=True,
1295
+ centroid_marker="X",
1296
+ centroid_size=60,
1297
+ hull=True,
1298
+ hull_alpha=0.12,
1299
+ hull_edgecolor="k",
1300
+ show_cluster_labels=True,
1301
+ cluster_label_fontsize=8,
1302
+ ):
1303
+ """
1304
+ Color points by label, plot centroids and optional convex hulls.
1305
+ Labels == -1 are noise and drawn in grey.
1306
+ Also annotates cluster numbers near centroids (contiguous numbers starting at 0).
1307
+ """
1308
+ import matplotlib.colors as mcolors
1309
+ from scipy.spatial import ConvexHull
115
1310
 
116
- cluster_sizes = torch.zeros_like(merged_cluster)
117
- for cid in merged_cluster.unique():
118
- members = (merged_cluster == cid).nonzero(as_tuple=True)[0]
119
- cluster_sizes[members] = len(members)
1311
+ labels = np.asarray(labels)
1312
+ pts = np.column_stack([x, y])
120
1313
 
121
- is_duplicate = torch.zeros(N, dtype=torch.bool)
122
- for cid in merged_cluster.unique():
123
- members = (merged_cluster == cid).nonzero(as_tuple=True)[0]
124
- if len(members) > 1:
125
- is_duplicate[members[1:]] = True
1314
+ unique = np.unique(labels)
1315
+ # sort so noise (-1) comes last for drawing
1316
+ unique = sorted(unique.tolist(), key=lambda v: (v == -1, v))
1317
+ cmap_obj = plt.get_cmap(cmap)
1318
+ ncolors = max(8, len(unique))
1319
+ colors = [cmap_obj(i / float(ncolors)) for i in range(ncolors)]
126
1320
 
127
- adata_subset.obs['is_duplicate'] = is_duplicate.numpy()
128
- adata_subset.obs['merged_cluster_id'] = merged_cluster.numpy()
129
- adata_subset.obs['cluster_size'] = cluster_sizes.numpy()
130
- adata_subset.obs['fwd_hamming_to_next'] = fwd_hamming_to_next.numpy()
131
- adata_subset.obs['rev_hamming_to_prev'] = rev_hamming_to_prev.numpy()
1321
+ for idx, lab in enumerate(unique):
1322
+ mask = labels == lab
1323
+ if not mask.any():
1324
+ continue
1325
+ col = (0.6, 0.6, 0.6, 0.6) if lab == -1 else colors[idx % ncolors]
1326
+ ax.scatter(x[mask], y[mask], s=20, c=[col], alpha=alpha_pts, marker=marker, linewidths=0.2, edgecolors="none", rasterized=True)
132
1327
 
133
- merged_results.append(adata_subset.obs)
1328
+ if lab != -1:
1329
+ # centroid
1330
+ if plot_centroids:
1331
+ cx = float(np.mean(x[mask]))
1332
+ cy = float(np.mean(y[mask]))
1333
+ ax.scatter([cx], [cy], s=centroid_size, marker=centroid_marker, c=[col], edgecolor="k", linewidth=0.6, zorder=10)
134
1334
 
135
- merged_obs = pd.concat(merged_results)
136
- adata.obs = adata.obs.join(merged_obs[['is_duplicate', 'merged_cluster_id', 'cluster_size', 'fwd_hamming_to_next', 'rev_hamming_to_prev']])
1335
+ if show_cluster_labels:
1336
+ ax.text(cx, cy, str(int(lab)), color="white", fontsize=cluster_label_fontsize,
1337
+ ha="center", va="center", weight="bold", zorder=12,
1338
+ bbox=dict(facecolor=(0,0,0,0.5), pad=0.3, boxstyle="round"))
137
1339
 
138
- adata_unique = adata[~adata.obs['is_duplicate']].copy()
1340
+ # hull
1341
+ if hull and np.sum(mask) >= 3:
1342
+ try:
1343
+ ch_pts = pts[mask]
1344
+ hull_idx = ConvexHull(ch_pts).vertices
1345
+ hull_poly = ch_pts[hull_idx]
1346
+ ax.fill(hull_poly[:, 0], hull_poly[:, 1], alpha=hull_alpha, facecolor=col, edgecolor=hull_edgecolor, linewidth=0.6, zorder=5)
1347
+ except Exception:
1348
+ pass
139
1349
 
140
- plt.figure(figsize=(5, 4))
141
- plt.hist(all_hamming_dists, bins=50, alpha=0.75)
142
- plt.axvline(distance_threshold, color="red", linestyle="--", label=f"threshold = {distance_threshold}")
143
- plt.xlabel("Hamming Distance")
144
- plt.ylabel("Frequency")
145
- plt.title("Histogram of Pairwise Hamming Distances")
146
- plt.legend()
147
- plt.show()
1350
+ return None
148
1351
 
149
- return adata_unique, adata
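And a minimal sketch exercising _overlay_clusters_on_ax on synthetic points (matplotlib and scipy assumed available; the axis labels are illustrative):

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
x = rng.uniform(0.0, 1.0, 60)
y = rng.uniform(0.0, 1.0, 60)
labels = np.where(x > 0.5, 0, 1)   # two fake clusters
labels[:5] = -1                    # a few noise points, drawn in grey

fig, ax = plt.subplots(figsize=(4, 3))
_overlay_clusters_on_ax(ax, x, y, labels, diagnostics={}, hull=True, show_cluster_labels=True)
ax.set_xlabel("metric")
ax.set_ylabel("hamming distance to nearest read")
plt.show()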