smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181) hide show
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,52 +1,76 @@
1
+ from __future__ import annotations
2
+
1
3
  # duplicate_detection_with_hier_and_plots.py
2
4
  import copy
3
- import warnings
4
5
  import math
5
6
  import os
7
+ import warnings
6
8
  from collections import defaultdict
7
- from typing import Dict, Any, Tuple, Union, List, Optional
9
+ from importlib.util import find_spec
10
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
8
11
 
9
- import torch
10
- import anndata as ad
11
12
  import numpy as np
12
13
  import pandas as pd
13
- import matplotlib.pyplot as plt
14
- from tqdm import tqdm
14
+ from scipy.cluster import hierarchy as sch
15
+ from scipy.spatial.distance import pdist, squareform
16
+ from scipy.stats import gaussian_kde
17
+
18
+ from smftools.logging_utils import get_logger
19
+ from smftools.optional_imports import require
15
20
 
16
21
  from ..readwrite import make_dirs
17
22
 
18
- # optional imports for clustering / PCA / KDE
19
- try:
20
- from scipy.cluster import hierarchy as sch
21
- from scipy.spatial.distance import pdist, squareform
22
- SCIPY_AVAILABLE = True
23
- except Exception:
24
- sch = None
25
- pdist = None
26
- squareform = None
27
- SCIPY_AVAILABLE = False
28
-
29
- try:
30
- from sklearn.decomposition import PCA
31
- from sklearn.cluster import KMeans, DBSCAN
32
- from sklearn.mixture import GaussianMixture
33
- from sklearn.metrics import silhouette_score
34
- SKLEARN_AVAILABLE = True
35
- except Exception:
36
- PCA = None
37
- KMeans = DBSCAN = GaussianMixture = silhouette_score = None
38
- SKLEARN_AVAILABLE = False
39
-
40
- try:
41
- from scipy.stats import gaussian_kde
42
- except Exception:
43
- gaussian_kde = None
44
-
45
-
46
- def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
47
- """
48
- Merge two .uns dicts. prefer='orig' will keep orig_uns values on conflict,
49
- prefer='new' will keep new_uns values on conflict. Conflicts are reported.
23
+ logger = get_logger(__name__)
24
+
25
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="duplicate read plots")
26
+ torch = require("torch", extra="torch", purpose="duplicate read detection")
27
+
28
+ if TYPE_CHECKING:
29
+ import anndata as ad
30
+
31
+ SCIPY_AVAILABLE = True
32
+ SKLEARN_AVAILABLE = find_spec("sklearn") is not None
33
+
34
+ PCA = None
35
+ KMeans = DBSCAN = GaussianMixture = silhouette_score = None
36
+ if SKLEARN_AVAILABLE:
37
+ sklearn_cluster = require(
38
+ "sklearn.cluster",
39
+ extra="ml-base",
40
+ purpose="duplicate read clustering",
41
+ )
42
+ sklearn_decomp = require(
43
+ "sklearn.decomposition",
44
+ extra="ml-base",
45
+ purpose="duplicate read PCA",
46
+ )
47
+ sklearn_metrics = require(
48
+ "sklearn.metrics",
49
+ extra="ml-base",
50
+ purpose="duplicate read clustering diagnostics",
51
+ )
52
+ sklearn_mixture = require(
53
+ "sklearn.mixture",
54
+ extra="ml-base",
55
+ purpose="duplicate read clustering",
56
+ )
57
+ DBSCAN = sklearn_cluster.DBSCAN
58
+ KMeans = sklearn_cluster.KMeans
59
+ PCA = sklearn_decomp.PCA
60
+ silhouette_score = sklearn_metrics.silhouette_score
61
+ GaussianMixture = sklearn_mixture.GaussianMixture
62
+
63
+
64
+ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer: str = "orig") -> dict:
65
+ """Merge two ``.uns`` dictionaries while preserving preferred values.
66
+
67
+ Args:
68
+ orig_uns: Original ``.uns`` dictionary.
69
+ new_uns: New ``.uns`` dictionary to merge.
70
+ prefer: Which dictionary to prefer on conflict (``"orig"`` or ``"new"``).
71
+
72
+ Returns:
73
+ dict: Merged dictionary.
50
74
  """
51
75
  out = copy.deepcopy(new_uns) if new_uns is not None else {}
52
76
  for k, v in (orig_uns or {}).items():
@@ -55,7 +79,7 @@ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
55
79
  else:
56
80
  # present in both: compare quickly (best-effort)
57
81
  try:
58
- equal = (out[k] == v)
82
+ equal = out[k] == v
59
83
  except Exception:
60
84
  equal = False
61
85
  if equal:
@@ -69,9 +93,10 @@ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
69
93
  out[f"orig_uns__{k}"] = copy.deepcopy(v)
70
94
  return out
71
95
 
96
+
72
97
  def flag_duplicate_reads(
73
- adata,
74
- var_filters_sets,
98
+ adata: ad.AnnData,
99
+ var_filters_sets: Sequence[dict[str, Any]],
75
100
  distance_threshold: float = 0.07,
76
101
  obs_reference_col: str = "Reference_strand",
77
102
  sample_col: str = "Barcode",
@@ -81,7 +106,7 @@ def flag_duplicate_reads(
81
106
  uns_filtered_flag: str = "read_duplicates_removed",
82
107
  bypass: bool = False,
83
108
  force_redo: bool = False,
84
- keep_best_metric: Optional[str] = 'read_quality',
109
+ keep_best_metric: Optional[str] = "read_quality",
85
110
  keep_best_higher: bool = True,
86
111
  window_size: int = 50,
87
112
  min_overlap_positions: int = 20,
@@ -93,19 +118,119 @@ def flag_duplicate_reads(
93
118
  hierarchical_metric: str = "euclidean",
94
119
  hierarchical_window: int = 50,
95
120
  random_state: int = 0,
96
- ):
121
+ demux_types: Optional[Sequence[str]] = None,
122
+ demux_col: str = "demux_type",
123
+ ) -> ad.AnnData:
124
+ """Flag duplicate reads with demux-aware keeper preference.
125
+
126
+ Behavior:
127
+ - All reads are processed (no masking by demux).
128
+ - At each keeper decision, prefer reads whose ``demux_col`` value is in
129
+ ``demux_types`` when present. Among candidates, choose by
130
+ ``keep_best_metric``.
131
+
132
+ Args:
133
+ adata: AnnData object to process.
134
+ var_filters_sets: Sequence of variable filter definitions.
135
+ distance_threshold: Distance threshold for duplicate detection.
136
+ obs_reference_col: Obs column containing reference identifiers.
137
+ sample_col: Obs column containing sample identifiers.
138
+ output_directory: Directory for output plots and artifacts.
139
+ metric_keys: Metric key(s) used in processing.
140
+ uns_flag: Flag in ``adata.uns`` indicating prior completion.
141
+ uns_filtered_flag: Flag to mark read duplicates removal.
142
+ bypass: Whether to skip processing.
143
+ force_redo: Whether to rerun even if ``uns_flag`` is set.
144
+ keep_best_metric: Obs column used to select best read within duplicates.
145
+ keep_best_higher: Whether higher values in ``keep_best_metric`` are preferred.
146
+ window_size: Window size for local comparisons.
147
+ min_overlap_positions: Minimum overlapping positions required.
148
+ do_pca: Whether to run PCA before clustering.
149
+ pca_n_components: Number of PCA components.
150
+ pca_center: Whether to center data before PCA.
151
+ do_hierarchical: Whether to run hierarchical clustering.
152
+ hierarchical_linkage: Linkage method for hierarchical clustering.
153
+ hierarchical_metric: Distance metric for hierarchical clustering.
154
+ hierarchical_window: Window size for hierarchical clustering.
155
+ random_state: Random seed.
156
+ demux_types: Preferred demux types for keeper selection.
157
+ demux_col: Obs column containing demux type labels.
158
+
159
+ Returns:
160
+ anndata.AnnData: AnnData object with duplicate flags stored in ``adata.obs``.
97
161
  """
98
- Duplicate-flagging pipeline where hierarchical stage operates only on representatives
99
- (one representative per lex cluster, i.e. the keeper). Final keeper assignment and
100
- enforcement happens only after hierarchical merging.
162
+ import copy
163
+ import warnings
164
+
165
+ import anndata as ad
166
+ import numpy as np
167
+ import pandas as pd
168
+
169
+ # -------- helper: demux-aware keeper selection --------
170
+ def _choose_keeper_with_demux_preference(
171
+ members_idx: List[int],
172
+ adata_subset: ad.AnnData,
173
+ obs_index_list: List[Any],
174
+ *,
175
+ demux_col: str = "demux_type",
176
+ preferred_types: Optional[Sequence[str]] = None,
177
+ keep_best_metric: Optional[str] = None,
178
+ keep_best_higher: bool = True,
179
+ lex_keeper_mask: Optional[np.ndarray] = None, # aligned to members order
180
+ ) -> int:
181
+ """
182
+ Prefer members whose demux_col ∈ preferred_types.
183
+ Among candidates, pick by keep_best_metric (higher/lower).
184
+ If metric missing/NaN, prefer lex keeper (via mask) among candidates.
185
+ Fallback: first candidate.
186
+ Returns the chosen *member index* (int from members_idx).
187
+ """
188
+ # 1) demux-preferred candidates
189
+ if preferred_types and (demux_col in adata_subset.obs.columns):
190
+ preferred = set(map(str, preferred_types))
191
+ demux_series = adata_subset.obs[demux_col].astype("string")
192
+ names = [obs_index_list[m] for m in members_idx]
193
+ is_pref = demux_series.loc[names].isin(preferred).to_numpy()
194
+ candidates = [members_idx[i] for i, ok in enumerate(is_pref) if ok]
195
+ else:
196
+ candidates = []
101
197
 
102
- Returns (adata_unique, adata_full) as before; writes sequence__* columns into adata.obs.
103
- """
104
- # early exits
198
+ if not candidates:
199
+ candidates = list(members_idx)
200
+
201
+ # 2) metric-based within candidates
202
+ if keep_best_metric and (keep_best_metric in adata_subset.obs.columns):
203
+ cand_names = [obs_index_list[m] for m in candidates]
204
+ try:
205
+ vals = pd.to_numeric(
206
+ adata_subset.obs.loc[cand_names, keep_best_metric],
207
+ errors="coerce",
208
+ ).to_numpy(dtype=float)
209
+ except Exception:
210
+ vals = np.array([np.nan] * len(candidates), dtype=float)
211
+
212
+ if not np.all(np.isnan(vals)):
213
+ if keep_best_higher:
214
+ vals = np.where(np.isnan(vals), -np.inf, vals)
215
+ return candidates[int(np.nanargmax(vals))]
216
+ else:
217
+ vals = np.where(np.isnan(vals), np.inf, vals)
218
+ return candidates[int(np.nanargmin(vals))]
219
+
220
+ # 3) metric unhelpful — prefer lex keeper if provided
221
+ if lex_keeper_mask is not None:
222
+ for i, midx in enumerate(members_idx):
223
+ if (midx in candidates) and bool(lex_keeper_mask[i]):
224
+ return midx
225
+
226
+ # 4) fallback
227
+ return candidates[0]
228
+
229
+ # -------- early exits --------
105
230
  already = bool(adata.uns.get(uns_flag, False))
106
- if (already and not force_redo):
231
+ if already and not force_redo:
107
232
  if "is_duplicate" in adata.obs.columns:
108
- adata_unique = adata[adata.obs["is_duplicate"] == False].copy()
233
+ adata_unique = adata[~adata.obs["is_duplicate"]].copy()
109
234
  return adata_unique, adata
110
235
  else:
111
236
  return adata.copy(), adata.copy()
@@ -117,17 +242,23 @@ def flag_duplicate_reads(
117
242
 
118
243
  # local UnionFind
119
244
  class UnionFind:
245
+ """Disjoint-set union-find helper for clustering indices."""
246
+
120
247
  def __init__(self, size):
248
+ """Initialize parent pointers for the union-find."""
121
249
  self.parent = list(range(size))
122
250
 
123
251
  def find(self, x):
252
+ """Find the root for a member with path compression."""
124
253
  while self.parent[x] != x:
125
254
  self.parent[x] = self.parent[self.parent[x]]
126
255
  x = self.parent[x]
127
256
  return x
128
257
 
129
258
  def union(self, x, y):
130
- rx = self.find(x); ry = self.find(y)
259
+ """Union the sets that contain x and y."""
260
+ rx = self.find(x)
261
+ ry = self.find(y)
131
262
  if rx != ry:
132
263
  self.parent[ry] = rx
133
264
 
@@ -139,14 +270,14 @@ def flag_duplicate_reads(
139
270
 
140
271
  for sample in samples:
141
272
  for ref in references:
142
- print(f"Processing sample={sample} ref={ref}")
273
+ logger.info("Processing sample=%s ref=%s", sample, ref)
143
274
  sample_mask = adata.obs[sample_col] == sample
144
275
  ref_mask = adata.obs[obs_reference_col] == ref
145
276
  subset_mask = sample_mask & ref_mask
146
277
  adata_subset = adata[subset_mask].copy()
147
278
 
148
279
  if adata_subset.n_obs < 2:
149
- print(f" Skipping {sample}_{ref} (too few reads)")
280
+ logger.info(" Skipping %s_%s (too few reads)", sample, ref)
150
281
  continue
151
282
 
152
283
  N = adata_subset.shape[0]
@@ -162,7 +293,12 @@ def flag_duplicate_reads(
162
293
 
163
294
  selected_cols = adata.var.index[combined_mask.tolist()].to_list()
164
295
  col_indices = [adata.var.index.get_loc(c) for c in selected_cols]
165
- print(f" Selected {len(col_indices)} columns out of {adata.var.shape[0]} for {ref}")
296
+ logger.info(
297
+ " Selected %s columns out of %s for %s",
298
+ len(col_indices),
299
+ adata.var.shape[0],
300
+ ref,
301
+ )
166
302
 
167
303
  # Extract data matrix (dense numpy) for the subset
168
304
  X = adata_subset.X
@@ -187,7 +323,10 @@ def flag_duplicate_reads(
187
323
  hierarchical_found_dists = []
188
324
 
189
325
  # Lexicographic windowed pass function
190
- def cluster_pass(X_tensor_local, reverse=False, window=int(window_size), record_distances=False):
326
+ def cluster_pass(
327
+ X_tensor_local, reverse=False, window=int(window_size), record_distances=False
328
+ ):
329
+ """Perform a lexicographic windowed clustering pass."""
191
330
  N_local = X_tensor_local.shape[0]
192
331
  X_sortable = X_tensor_local.clone().nan_to_num(-1.0)
193
332
  sort_keys = [tuple(row.numpy().tolist()) for row in X_sortable]
@@ -208,10 +347,16 @@ def flag_duplicate_reads(
208
347
  if enough_overlap.any():
209
348
  diffs = (row_i_exp != block_rows) & valid_mask
210
349
  hamming_counts = diffs.sum(dim=1).float()
211
- hamming_dists = torch.where(valid_counts > 0, hamming_counts / valid_counts, torch.tensor(float("nan")))
350
+ hamming_dists = torch.where(
351
+ valid_counts > 0,
352
+ hamming_counts / valid_counts,
353
+ torch.tensor(float("nan")),
354
+ )
212
355
  # record distances (legacy list of all local comparisons)
213
356
  hamming_np = hamming_dists.cpu().numpy().tolist()
214
- local_hamming_dists.extend([float(x) for x in hamming_np if (not np.isnan(x))])
357
+ local_hamming_dists.extend(
358
+ [float(x) for x in hamming_np if (not np.isnan(x))]
359
+ )
215
360
  matches = (hamming_dists < distance_threshold) & (enough_overlap)
216
361
  for offset_local, m in enumerate(matches):
217
362
  if m:
@@ -223,20 +368,28 @@ def flag_duplicate_reads(
223
368
  next_local_idx = i + 1
224
369
  if next_local_idx < len(sorted_X):
225
370
  next_global = sorted_idx[next_local_idx]
226
- vm_pair = (~torch.isnan(row_i)) & (~torch.isnan(sorted_X[next_local_idx]))
371
+ vm_pair = (~torch.isnan(row_i)) & (
372
+ ~torch.isnan(sorted_X[next_local_idx])
373
+ )
227
374
  vc = vm_pair.sum().item()
228
375
  if vc >= min_overlap_positions:
229
- d = float(((row_i[vm_pair] != sorted_X[next_local_idx][vm_pair]).sum().item()) / vc)
376
+ d = float(
377
+ (
378
+ (row_i[vm_pair] != sorted_X[next_local_idx][vm_pair])
379
+ .sum()
380
+ .item()
381
+ )
382
+ / vc
383
+ )
230
384
  if reverse:
231
385
  rev_hamming_to_prev[next_global] = d
232
386
  else:
233
387
  fwd_hamming_to_next[sorted_idx[i]] = d
234
388
  return cluster_pairs_local
235
389
 
236
- # run forward pass
390
+ # run forward & reverse windows
237
391
  pairs_fwd = cluster_pass(X_tensor, reverse=False, record_distances=True)
238
392
  involved_in_fwd = set([p for pair in pairs_fwd for p in pair])
239
- # build mask for reverse pass to avoid re-checking items already paired
240
393
  mask_for_rev = np.ones(N, dtype=bool)
241
394
  if len(involved_in_fwd) > 0:
242
395
  for idx in involved_in_fwd:
@@ -245,8 +398,9 @@ def flag_duplicate_reads(
245
398
  if len(rev_idx_map) > 0:
246
399
  reduced_tensor = X_tensor[rev_idx_map]
247
400
  pairs_rev_local = cluster_pass(reduced_tensor, reverse=True, record_distances=True)
248
- # remap local reduced indices to global
249
- remapped_rev_pairs = [(int(rev_idx_map[i]), int(rev_idx_map[j])) for (i, j) in pairs_rev_local]
401
+ remapped_rev_pairs = [
402
+ (int(rev_idx_map[i]), int(rev_idx_map[j])) for (i, j) in pairs_rev_local
403
+ ]
250
404
  else:
251
405
  remapped_rev_pairs = []
252
406
 
@@ -265,53 +419,41 @@ def flag_duplicate_reads(
265
419
  id_map = {old: new for new, old in enumerate(sorted(unique_initial.tolist()))}
266
420
  merged_cluster_mapped = np.array([id_map[int(x)] for x in merged_cluster], dtype=int)
267
421
 
268
- # cluster sizes and choose lex-keeper per lex-cluster (representatives)
422
+ # cluster sizes and choose lex-keeper per lex-cluster (demux-aware)
269
423
  cluster_sizes = np.zeros_like(merged_cluster_mapped)
270
424
  cluster_counts = []
271
425
  unique_clusters = np.unique(merged_cluster_mapped)
272
426
  keeper_for_cluster = {}
427
+
428
+ obs_index = list(adata_subset.obs.index)
273
429
  for cid in unique_clusters:
274
430
  members = np.where(merged_cluster_mapped == cid)[0].tolist()
275
431
  csize = int(len(members))
276
432
  cluster_counts.append(csize)
277
433
  cluster_sizes[members] = csize
278
- # pick lex keeper (representative)
279
- if len(members) == 1:
280
- keeper_for_cluster[cid] = members[0]
281
- else:
282
- if keep_best_metric is None:
283
- keeper_for_cluster[cid] = members[0]
284
- else:
285
- obs_index = list(adata_subset.obs.index)
286
- member_names = [obs_index[m] for m in members]
287
- try:
288
- vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
289
- except Exception:
290
- vals = np.array([np.nan] * len(members), dtype=float)
291
- if np.all(np.isnan(vals)):
292
- keeper_for_cluster[cid] = members[0]
293
- else:
294
- if keep_best_higher:
295
- nan_mask = np.isnan(vals)
296
- vals[nan_mask] = -np.inf
297
- rel_idx = int(np.nanargmax(vals))
298
- else:
299
- nan_mask = np.isnan(vals)
300
- vals[nan_mask] = np.inf
301
- rel_idx = int(np.nanargmin(vals))
302
- keeper_for_cluster[cid] = members[rel_idx]
303
434
 
304
- # expose lex keeper info (record only; do not enforce deletion yet)
435
+ keeper_for_cluster[cid] = _choose_keeper_with_demux_preference(
436
+ members,
437
+ adata_subset,
438
+ obs_index,
439
+ demux_col=demux_col,
440
+ preferred_types=demux_types,
441
+ keep_best_metric=keep_best_metric,
442
+ keep_best_higher=keep_best_higher,
443
+ lex_keeper_mask=None, # no lex preference yet
444
+ )
445
+
446
+ # expose lex keeper info (record only)
305
447
  lex_is_keeper = np.zeros((N,), dtype=bool)
306
448
  lex_is_duplicate = np.zeros((N,), dtype=bool)
307
- for cid, members in zip(unique_clusters, [np.where(merged_cluster_mapped == cid)[0].tolist() for cid in unique_clusters]):
449
+ for cid in unique_clusters:
450
+ members = np.where(merged_cluster_mapped == cid)[0].tolist()
308
451
  keeper_idx = keeper_for_cluster[cid]
309
452
  lex_is_keeper[keeper_idx] = True
310
453
  for m in members:
311
454
  if m != keeper_idx:
312
455
  lex_is_duplicate[m] = True
313
- # note: these are just recorded for inspection / later preference
314
- # and will be written to adata_subset.obs below
456
+
315
457
  # record lex min pair (min of fwd/rev neighbor) for each read
316
458
  min_pair = np.full((N,), np.nan, dtype=float)
317
459
  for i in range(N):
@@ -349,7 +491,12 @@ def flag_duplicate_reads(
349
491
  if n_comp <= 0:
350
492
  reps_for_clustering = reps_arr
351
493
  else:
352
- pca = PCA(n_components=n_comp, random_state=int(random_state), svd_solver="auto", copy=True)
494
+ pca = PCA(
495
+ n_components=n_comp,
496
+ random_state=int(random_state),
497
+ svd_solver="auto",
498
+ copy=True,
499
+ )
353
500
  reps_for_clustering = pca.fit_transform(reps_arr)
354
501
  else:
355
502
  reps_for_clustering = reps_arr
@@ -360,10 +507,12 @@ def flag_duplicate_reads(
360
507
  Z = sch.linkage(pdist_vec, method=hierarchical_linkage)
361
508
  leaves = sch.leaves_list(Z)
362
509
  except Exception as e:
363
- warnings.warn(f"hierarchical pass failed: {e}; skipping hierarchical stage.")
510
+ warnings.warn(
511
+ f"hierarchical pass failed: {e}; skipping hierarchical stage."
512
+ )
364
513
  leaves = np.arange(len(rep_global_indices), dtype=int)
365
514
 
366
- # apply windowed hamming comparisons across ordered reps and union via same UF (so clusters of all reads merge)
515
+ # windowed hamming comparisons across ordered reps and union
367
516
  order_global_reps = [rep_global_indices[i] for i in leaves]
368
517
  n_reps = len(order_global_reps)
369
518
  for pos in range(n_reps):
@@ -389,55 +538,40 @@ def flag_duplicate_reads(
389
538
  merged_cluster_after[i] = uf.find(i)
390
539
  unique_final = np.unique(merged_cluster_after)
391
540
  id_map_final = {old: new for new, old in enumerate(sorted(unique_final.tolist()))}
392
- merged_cluster_mapped_final = np.array([id_map_final[int(x)] for x in merged_cluster_after], dtype=int)
541
+ merged_cluster_mapped_final = np.array(
542
+ [id_map_final[int(x)] for x in merged_cluster_after], dtype=int
543
+ )
393
544
 
394
- # compute final cluster members and choose final keeper per final cluster
545
+ # compute final cluster members and choose final keeper per final cluster (demux-aware)
395
546
  cluster_sizes_final = np.zeros_like(merged_cluster_mapped_final)
396
- final_cluster_counts = []
397
- final_unique = np.unique(merged_cluster_mapped_final)
398
547
  final_keeper_for_cluster = {}
399
548
  cluster_members_map = {}
400
- for cid in final_unique:
549
+
550
+ obs_index = list(adata_subset.obs.index)
551
+ lex_mask_full = lex_is_keeper # use lex keeper as optional tiebreaker
552
+
553
+ for cid in np.unique(merged_cluster_mapped_final):
401
554
  members = np.where(merged_cluster_mapped_final == cid)[0].tolist()
402
555
  cluster_members_map[cid] = members
403
- csize = len(members)
404
- final_cluster_counts.append(csize)
405
- cluster_sizes_final[members] = csize
406
- if csize == 1:
407
- final_keeper_for_cluster[cid] = members[0]
408
- else:
409
- # prefer keep_best_metric if available; do not automatically prefer lex-keeper here unless you want to;
410
- # (user previously asked for preferring lex keepers — if desired, you can prefer lex_is_keeper among members)
411
- obs_index = list(adata_subset.obs.index)
412
- member_names = [obs_index[m] for m in members]
413
- if keep_best_metric is not None and keep_best_metric in adata_subset.obs.columns:
414
- try:
415
- vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
416
- except Exception:
417
- vals = np.array([np.nan] * len(members), dtype=float)
418
- if np.all(np.isnan(vals)):
419
- final_keeper_for_cluster[cid] = members[0]
420
- else:
421
- if keep_best_higher:
422
- nan_mask = np.isnan(vals)
423
- vals[nan_mask] = -np.inf
424
- rel_idx = int(np.nanargmax(vals))
425
- else:
426
- nan_mask = np.isnan(vals)
427
- vals[nan_mask] = np.inf
428
- rel_idx = int(np.nanargmin(vals))
429
- final_keeper_for_cluster[cid] = members[rel_idx]
430
- else:
431
- # if lex keepers present among members, prefer them
432
- lex_members = [m for m in members if lex_is_keeper[m]]
433
- if len(lex_members) > 0:
434
- final_keeper_for_cluster[cid] = lex_members[0]
435
- else:
436
- final_keeper_for_cluster[cid] = members[0]
437
-
438
- # update sequence__is_duplicate based on final clusters: non-keepers in multi-member clusters are duplicates
556
+ cluster_sizes_final[members] = len(members)
557
+
558
+ lex_mask_members = np.array([bool(lex_mask_full[m]) for m in members], dtype=bool)
559
+
560
+ keeper = _choose_keeper_with_demux_preference(
561
+ members,
562
+ adata_subset,
563
+ obs_index,
564
+ demux_col=demux_col,
565
+ preferred_types=demux_types,
566
+ keep_best_metric=keep_best_metric,
567
+ keep_best_higher=keep_best_higher,
568
+ lex_keeper_mask=lex_mask_members,
569
+ )
570
+ final_keeper_for_cluster[cid] = keeper
571
+
572
+ # update sequence__is_duplicate based on final clusters
439
573
  sequence_is_duplicate = np.zeros((N,), dtype=bool)
440
- for cid in final_unique:
574
+ for cid in np.unique(merged_cluster_mapped_final):
441
575
  keeper = final_keeper_for_cluster[cid]
442
576
  members = cluster_members_map[cid]
443
577
  if len(members) > 1:
@@ -446,8 +580,7 @@ def flag_duplicate_reads(
446
580
  sequence_is_duplicate[m] = True
447
581
 
448
582
  # propagate hierarchical distances into hierarchical_min_pair for all cluster members
449
- for (i_g, j_g, d) in hierarchical_pairs:
450
- # identify their final cluster ids (after unions)
583
+ for i_g, j_g, d in hierarchical_pairs:
451
584
  c_i = merged_cluster_mapped_final[int(i_g)]
452
585
  c_j = merged_cluster_mapped_final[int(j_g)]
453
586
  members_i = cluster_members_map.get(c_i, [int(i_g)])
@@ -459,7 +592,7 @@ def flag_duplicate_reads(
459
592
  if np.isnan(hierarchical_min_pair[mj]) or (d < hierarchical_min_pair[mj]):
460
593
  hierarchical_min_pair[mj] = d
461
594
 
462
- # combine lex-phase min_pair and hierarchical_min_pair into the final sequence__min_hamming_to_pair
595
+ # combine min pairs
463
596
  combined_min = min_pair.copy()
464
597
  for i in range(N):
465
598
  hval = hierarchical_min_pair[i]
@@ -475,69 +608,117 @@ def flag_duplicate_reads(
475
608
  adata_subset.obs["rev_hamming_to_prev"] = rev_hamming_to_prev
476
609
  adata_subset.obs["sequence__hier_hamming_to_pair"] = hierarchical_min_pair
477
610
  adata_subset.obs["sequence__min_hamming_to_pair"] = combined_min
478
- # persist lex bookkeeping columns (informational)
611
+ # persist lex bookkeeping
479
612
  adata_subset.obs["sequence__lex_is_keeper"] = lex_is_keeper
480
613
  adata_subset.obs["sequence__lex_is_duplicate"] = lex_is_duplicate
481
614
 
482
615
  adata_processed_list.append(adata_subset)
483
616
 
484
- histograms.append({
485
- "sample": sample,
486
- "reference": ref,
487
- "distances": local_hamming_dists, # lex local comparisons
488
- "cluster_counts": final_cluster_counts,
489
- "hierarchical_pairs": hierarchical_found_dists,
490
- })
491
-
492
- # Merge annotated subsets back together BEFORE plotting so plotting sees fwd_hamming_to_next, etc.
617
+ histograms.append(
618
+ {
619
+ "sample": sample,
620
+ "reference": ref,
621
+ "distances": local_hamming_dists,
622
+ "cluster_counts": [
623
+ int(x) for x in np.unique(cluster_sizes_final[cluster_sizes_final > 0])
624
+ ],
625
+ "hierarchical_pairs": hierarchical_found_dists,
626
+ }
627
+ )
628
+
629
+ # Merge annotated subsets back together BEFORE plotting
493
630
  _original_uns = copy.deepcopy(adata.uns)
494
631
  if len(adata_processed_list) == 0:
495
632
  return adata.copy(), adata.copy()
496
633
 
497
634
  adata_full = ad.concat(adata_processed_list, merge="same", join="outer", index_unique=None)
635
+
636
+ # preserve uns (prefer original on conflicts)
637
+ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
638
+ """Merge .uns dictionaries while preserving original on conflicts."""
639
+ out = copy.deepcopy(new_uns) if new_uns is not None else {}
640
+ for k, v in (orig_uns or {}).items():
641
+ if k not in out:
642
+ out[k] = copy.deepcopy(v)
643
+ else:
644
+ try:
645
+ equal = out[k] == v
646
+ except Exception:
647
+ equal = False
648
+ if equal:
649
+ continue
650
+ if prefer == "orig":
651
+ out[k] = copy.deepcopy(v)
652
+ else:
653
+ out[f"orig_uns__{k}"] = copy.deepcopy(v)
654
+ return out
655
+
498
656
  adata_full.uns = merge_uns_preserve(_original_uns, adata_full.uns, prefer="orig")
499
657
 
500
658
  # Ensure expected numeric columns exist (create if missing)
501
- for col in ("fwd_hamming_to_next", "rev_hamming_to_prev", "sequence__min_hamming_to_pair", "sequence__hier_hamming_to_pair"):
659
+ for col in (
660
+ "fwd_hamming_to_next",
661
+ "rev_hamming_to_prev",
662
+ "sequence__min_hamming_to_pair",
663
+ "sequence__hier_hamming_to_pair",
664
+ ):
502
665
  if col not in adata_full.obs.columns:
503
666
  adata_full.obs[col] = np.nan
504
667
 
505
- # histograms (now driven by adata_full if requested)
506
- hist_outs = os.path.join(output_directory, "read_pair_hamming_distance_histograms")
507
- make_dirs([hist_outs])
508
- plot_histogram_pages(histograms,
509
- distance_threshold=distance_threshold,
510
- adata=adata_full,
511
- output_directory=hist_outs,
512
- distance_types=["min","fwd","rev","hier","lex_local"],
513
- sample_key=sample_col,
514
- )
668
+ # histograms
669
+ hist_outs = (
670
+ os.path.join(output_directory, "read_pair_hamming_distance_histograms")
671
+ if output_directory
672
+ else None
673
+ )
674
+ if hist_outs:
675
+ make_dirs([hist_outs])
676
+ plot_histogram_pages(
677
+ histograms,
678
+ distance_threshold=distance_threshold,
679
+ adata=adata_full,
680
+ output_directory=hist_outs,
681
+ distance_types=["min", "fwd", "rev", "hier", "lex_local"],
682
+ sample_key=sample_col,
683
+ )
515
684
 
516
685
  # hamming vs metric scatter
517
- scatter_outs = os.path.join(output_directory, "read_pair_hamming_distance_scatter_plots")
518
- make_dirs([scatter_outs])
519
- plot_hamming_vs_metric_pages(adata_full,
520
- metric_keys=metric_keys,
521
- output_dir=scatter_outs,
522
- hamming_col="sequence__min_hamming_to_pair",
523
- highlight_threshold=distance_threshold,
524
- highlight_color="red",
525
- sample_col=sample_col)
686
+ scatter_outs = (
687
+ os.path.join(output_directory, "read_pair_hamming_distance_scatter_plots")
688
+ if output_directory
689
+ else None
690
+ )
691
+ if scatter_outs:
692
+ make_dirs([scatter_outs])
693
+ plot_hamming_vs_metric_pages(
694
+ adata_full,
695
+ metric_keys=metric_keys,
696
+ output_dir=scatter_outs,
697
+ hamming_col="sequence__min_hamming_to_pair",
698
+ highlight_threshold=distance_threshold,
699
+ highlight_color="red",
700
+ sample_col=sample_col,
701
+ )
526
702
 
527
703
  # boolean columns from neighbor distances
528
- fwd_vals = pd.to_numeric(adata_full.obs.get("fwd_hamming_to_next", pd.Series(np.nan, index=adata_full.obs.index)), errors="coerce")
529
- rev_vals = pd.to_numeric(adata_full.obs.get("rev_hamming_to_prev", pd.Series(np.nan, index=adata_full.obs.index)), errors="coerce")
704
+ fwd_vals = pd.to_numeric(
705
+ adata_full.obs.get("fwd_hamming_to_next", pd.Series(np.nan, index=adata_full.obs.index)),
706
+ errors="coerce",
707
+ )
708
+ rev_vals = pd.to_numeric(
709
+ adata_full.obs.get("rev_hamming_to_prev", pd.Series(np.nan, index=adata_full.obs.index)),
710
+ errors="coerce",
711
+ )
530
712
  is_dup_dist = (fwd_vals < float(distance_threshold)) | (rev_vals < float(distance_threshold))
531
713
  is_dup_dist = is_dup_dist.fillna(False).astype(bool)
532
714
  adata_full.obs["is_duplicate_distance"] = is_dup_dist.values
533
715
 
534
- # combine sequence-derived flag with others
535
- if "sequence__is_duplicate" in adata_full.obs.columns:
536
- seq_dup = adata_full.obs["sequence__is_duplicate"].astype(bool)
537
- else:
538
- seq_dup = pd.Series(False, index=adata_full.obs.index)
539
-
540
- # cluster-based duplicate indicator (if any clustering columns exist)
716
+ # combine with sequence flag and any clustering flags
717
+ seq_dup = (
718
+ adata_full.obs["sequence__is_duplicate"].astype(bool)
719
+ if "sequence__is_duplicate" in adata_full.obs.columns
720
+ else pd.Series(False, index=adata_full.obs.index)
721
+ )
541
722
  cluster_cols = [c for c in adata_full.obs.columns if c.startswith("hamming_cluster__")]
542
723
  if cluster_cols:
543
724
  cl_mask = pd.Series(False, index=adata_full.obs.index)
@@ -550,59 +731,61 @@ def flag_duplicate_reads(
550
731
  else:
551
732
  adata_full.obs["is_duplicate_clustering"] = False
552
733
 
553
- final_dup = seq_dup | adata_full.obs["is_duplicate_distance"].astype(bool) | adata_full.obs["is_duplicate_clustering"].astype(bool)
734
+ final_dup = (
735
+ seq_dup
736
+ | adata_full.obs["is_duplicate_distance"].astype(bool)
737
+ | adata_full.obs["is_duplicate_clustering"].astype(bool)
738
+ )
554
739
  adata_full.obs["is_duplicate"] = final_dup.values
555
740
 
556
- # Final keeper enforcement: recompute per-cluster keeper from sequence__merged_cluster_id and
557
- # ensure that keeper is not marked duplicate
558
- if "sequence__merged_cluster_id" in adata_full.obs.columns:
559
- keeper_idx_by_cluster = {}
560
- metric_col = keep_best_metric if 'keep_best_metric' in locals() else None
741
+ # -------- Final keeper enforcement on adata_full (demux-aware) --------
742
+ keeper_idx_by_cluster = {}
743
+ metric_col = keep_best_metric
561
744
 
562
- # group by cluster id
563
- grp = adata_full.obs[["sequence__merged_cluster_id", "sequence__cluster_size"]].copy()
564
- for cid, sub in grp.groupby("sequence__merged_cluster_id"):
565
- try:
566
- members = sub.index.to_list()
567
- except Exception:
568
- members = list(sub.index)
569
- keeper = None
570
- # prefer keep_best_metric (if present), else prefer lex keeper among members, else first member
571
- if metric_col and metric_col in adata_full.obs.columns:
572
- try:
573
- vals = pd.to_numeric(adata_full.obs.loc[members, metric_col], errors="coerce")
574
- if vals.notna().any():
575
- keeper = vals.idxmax() if keep_best_higher else vals.idxmin()
576
- else:
577
- keeper = members[0]
578
- except Exception:
579
- keeper = members[0]
580
- else:
581
- # prefer lex keeper if present in this merged cluster
582
- lex_candidates = [m for m in members if ("sequence__lex_is_keeper" in adata_full.obs.columns and adata_full.obs.at[m, "sequence__lex_is_keeper"])]
583
- if len(lex_candidates) > 0:
584
- keeper = lex_candidates[0]
585
- else:
586
- keeper = members[0]
745
+ # Build an index→row-number mapping
746
+ name_to_pos = {name: i for i, name in enumerate(adata_full.obs.index)}
747
+ obs_index_full = list(adata_full.obs.index)
587
748
 
588
- keeper_idx_by_cluster[cid] = keeper
749
+ lex_col = (
750
+ "sequence__lex_is_keeper" if "sequence__lex_is_keeper" in adata_full.obs.columns else None
751
+ )
589
752
 
590
- # force keepers not to be duplicates
591
- is_dup_series = adata_full.obs["is_duplicate"].astype(bool)
592
- for cid, keeper_idx in keeper_idx_by_cluster.items():
593
- if keeper_idx in adata_full.obs.index:
594
- is_dup_series.at[keeper_idx] = False
595
- # clear sequence__is_duplicate for keeper if present
596
- if "sequence__is_duplicate" in adata_full.obs.columns:
597
- adata_full.obs.at[keeper_idx, "sequence__is_duplicate"] = False
598
- # clear lex duplicate flag too if present
599
- if "sequence__lex_is_duplicate" in adata_full.obs.columns:
600
- adata_full.obs.at[keeper_idx, "sequence__lex_is_duplicate"] = False
753
+ for cid, sub in adata_full.obs.groupby("sequence__merged_cluster_id", dropna=False):
754
+ members_names = sub.index.to_list()
755
+ members_pos = [name_to_pos[n] for n in members_names]
601
756
 
602
- adata_full.obs["is_duplicate"] = is_dup_series.values
757
+ if lex_col:
758
+ lex_mask_members = adata_full.obs.loc[members_names, lex_col].astype(bool).to_numpy()
759
+ else:
760
+ lex_mask_members = np.zeros(len(members_pos), dtype=bool)
761
+
762
+ keeper_pos = _choose_keeper_with_demux_preference(
763
+ members_pos,
764
+ adata_full,
765
+ obs_index_full,
766
+ demux_col=demux_col,
767
+ preferred_types=demux_types,
768
+ keep_best_metric=metric_col,
769
+ keep_best_higher=keep_best_higher,
770
+ lex_keeper_mask=lex_mask_members,
771
+ )
772
+ keeper_name = obs_index_full[keeper_pos]
773
+ keeper_idx_by_cluster[cid] = keeper_name
774
+
775
+ # enforce: keepers are not duplicates
776
+ is_dup_series = adata_full.obs["is_duplicate"].astype(bool)
777
+ for cid, keeper_name in keeper_idx_by_cluster.items():
778
+ if keeper_name in adata_full.obs.index:
779
+ is_dup_series.at[keeper_name] = False
780
+ if "sequence__is_duplicate" in adata_full.obs.columns:
781
+ adata_full.obs.at[keeper_name, "sequence__is_duplicate"] = False
782
+ if "sequence__lex_is_duplicate" in adata_full.obs.columns:
783
+ adata_full.obs.at[keeper_name, "sequence__lex_is_duplicate"] = False
784
+ adata_full.obs["is_duplicate"] = is_dup_series.values
603
785
 
604
786
  # reason column
605
787
  def _dup_reason_row(row):
788
+ """Build a semi-colon delimited duplicate reason string."""
606
789
  reasons = []
607
790
  if row.get("is_duplicate_distance", False):
608
791
  reasons.append("distance_thresh")
@@ -632,6 +815,7 @@ def flag_duplicate_reads(
632
815
  # Plot helpers (use adata_full as input)
633
816
  # ---------------------------
634
817
 
818
+
635
819
  def plot_histogram_pages(
636
820
  histograms,
637
821
  distance_threshold,
@@ -674,10 +858,11 @@ def plot_histogram_pages(
674
858
  use_adata = False
675
859
 
676
860
  if len(samples) == 0 or len(references) == 0:
677
- print("No histogram data to plot.")
861
+ logger.info("No histogram data to plot.")
678
862
  return {"distance_pages": [], "cluster_size_pages": []}
679
863
 
680
864
  def clean_array(arr):
865
+ """Filter array values to finite [0, 1] range for plotting."""
681
866
  if arr is None or len(arr) == 0:
682
867
  return np.array([], dtype=float)
683
868
  a = np.asarray(arr, dtype=float)
@@ -707,7 +892,9 @@ def plot_histogram_pages(
707
892
  if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
708
893
  grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
709
894
  if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
710
- grid[(s, r)]["hier"].extend(clean_array(group["sequence__hier_hamming_to_pair"].to_numpy()))
895
+ grid[(s, r)]["hier"].extend(
896
+ clean_array(group["sequence__hier_hamming_to_pair"].to_numpy())
897
+ )
711
898
  else:
712
899
  for (s, r), group in grouped:
713
900
  if "min" in distance_types and distance_key in group.columns:
@@ -717,7 +904,9 @@ def plot_histogram_pages(
717
904
  if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
718
905
  grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
719
906
  if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
720
- grid[(s, r)]["hier"].extend(clean_array(group["sequence__hier_hamming_to_pair"].to_numpy()))
907
+ grid[(s, r)]["hier"].extend(
908
+ clean_array(group["sequence__hier_hamming_to_pair"].to_numpy())
909
+ )
721
910
 
722
911
  # legacy histograms fallback
723
912
  if histograms:
@@ -753,9 +942,17 @@ def plot_histogram_pages(
753
942
 
754
943
  # counts (for labels)
755
944
  if use_adata:
756
- counts = {(s, r): int(((adata.obs[sample_key] == s) & (adata.obs[ref_key] == r)).sum()) for s in samples for r in references}
945
+ counts = {
946
+ (s, r): int(((adata.obs[sample_key] == s) & (adata.obs[ref_key] == r)).sum())
947
+ for s in samples
948
+ for r in references
949
+ }
757
950
  else:
758
- counts = {(s, r): sum(len(grid[(s, r)][dt]) for dt in distance_types) for s in samples for r in references}
951
+ counts = {
952
+ (s, r): sum(len(grid[(s, r)][dt]) for dt in distance_types)
953
+ for s in samples
954
+ for r in references
955
+ }
759
956
 
760
957
  distance_pages = []
761
958
  cluster_size_pages = []
@@ -773,7 +970,9 @@ def plot_histogram_pages(
773
970
  # Distance histogram page
774
971
  fig_w = figsize_per_cell[0] * ncols
775
972
  fig_h = figsize_per_cell[1] * nrows
776
- fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
973
+ fig, axes = plt.subplots(
974
+ nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
975
+ )
777
976
 
778
977
  for r_idx, sample_name in enumerate(chunk):
779
978
  for c_idx, ref_name in enumerate(references):
@@ -789,17 +988,37 @@ def plot_histogram_pages(
789
988
  vals = vals[(vals >= 0.0) & (vals <= ref_vmax)]
790
989
  if vals.size > 0:
791
990
  any_data = True
792
- ax.hist(vals, bins=bins_edges, alpha=0.5, label=dtype, density=False, stacked=False,
793
- color=dtype_colors.get(dtype, None))
991
+ ax.hist(
992
+ vals,
993
+ bins=bins_edges,
994
+ alpha=0.5,
995
+ label=dtype,
996
+ density=False,
997
+ stacked=False,
998
+ color=dtype_colors.get(dtype, None),
999
+ )
794
1000
  if not any_data:
795
- ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
1001
+ ax.text(
1002
+ 0.5,
1003
+ 0.5,
1004
+ "No data",
1005
+ ha="center",
1006
+ va="center",
1007
+ transform=ax.transAxes,
1008
+ fontsize=10,
1009
+ color="gray",
1010
+ )
796
1011
  # threshold line (make sure it is within axis)
797
1012
  ax.axvline(distance_threshold, color="red", linestyle="--", linewidth=1)
798
1013
 
799
1014
  if r_idx == 0:
800
1015
  ax.set_title(str(ref_name), fontsize=10)
801
1016
  if c_idx == 0:
802
- total_reads = sum(counts.get((sample_name, ref), 0) for ref in references) if not use_adata else int((adata.obs[sample_key] == sample_name).sum())
1017
+ total_reads = (
1018
+ sum(counts.get((sample_name, ref), 0) for ref in references)
1019
+ if not use_adata
1020
+ else int((adata.obs[sample_key] == sample_name).sum())
1021
+ )
803
1022
  ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
804
1023
  if r_idx == nrows - 1:
805
1024
  ax.set_xlabel("Hamming Distance", fontsize=9)
@@ -811,12 +1030,16 @@ def plot_histogram_pages(
811
1030
  if r_idx == 0 and c_idx == 0:
812
1031
  ax.legend(fontsize=7, loc="upper right")
813
1032
 
814
- fig.suptitle(f"Hamming distance histograms (rows=samples, cols=references) — page {page+1}/{n_pages}", fontsize=12, y=0.995)
1033
+ fig.suptitle(
1034
+ f"Hamming distance histograms (rows=samples, cols=references) — page {page + 1}/{n_pages}",
1035
+ fontsize=12,
1036
+ y=0.995,
1037
+ )
815
1038
  fig.tight_layout(rect=[0, 0, 1, 0.96])
816
1039
 
817
1040
  if output_directory:
818
1041
  os.makedirs(output_directory, exist_ok=True)
819
- fname = os.path.join(output_directory, f"hamming_histograms_page_{page+1}.png")
1042
+ fname = os.path.join(output_directory, f"hamming_histograms_page_{page + 1}.png")
820
1043
  plt.savefig(fname, bbox_inches="tight")
821
1044
  distance_pages.append(fname)
822
1045
  else:
@@ -826,22 +1049,43 @@ def plot_histogram_pages(
826
1049
  # Cluster-size histogram page (unchanged except it uses adata-derived sizes per cluster if available)
827
1050
  fig_w = figsize_per_cell[0] * ncols
828
1051
  fig_h = figsize_per_cell[1] * nrows
829
- fig2, axes2 = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
1052
+ fig2, axes2 = plt.subplots(
1053
+ nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
1054
+ )
830
1055
 
831
1056
  for r_idx, sample_name in enumerate(chunk):
832
1057
  for c_idx, ref_name in enumerate(references):
833
1058
  ax = axes2[r_idx][c_idx]
834
1059
  sizes = []
835
- if use_adata and ("sequence__merged_cluster_id" in adata.obs.columns and "sequence__cluster_size" in adata.obs.columns):
836
- sub = adata.obs[(adata.obs[sample_key] == sample_name) & (adata.obs[ref_key] == ref_name)]
1060
+ if use_adata and (
1061
+ "sequence__merged_cluster_id" in adata.obs.columns
1062
+ and "sequence__cluster_size" in adata.obs.columns
1063
+ ):
1064
+ sub = adata.obs[
1065
+ (adata.obs[sample_key] == sample_name) & (adata.obs[ref_key] == ref_name)
1066
+ ]
837
1067
  if not sub.empty:
838
1068
  try:
839
- grp = sub.groupby("sequence__merged_cluster_id")["sequence__cluster_size"].first()
840
- sizes = [int(x) for x in grp.to_numpy().tolist() if (pd.notna(x) and np.isfinite(x))]
1069
+ grp = sub.groupby("sequence__merged_cluster_id")[
1070
+ "sequence__cluster_size"
1071
+ ].first()
1072
+ sizes = [
1073
+ int(x)
1074
+ for x in grp.to_numpy().tolist()
1075
+ if (pd.notna(x) and np.isfinite(x))
1076
+ ]
841
1077
  except Exception:
842
1078
  try:
843
- unique_pairs = sub[["sequence__merged_cluster_id", "sequence__cluster_size"]].drop_duplicates()
844
- sizes = [int(x) for x in unique_pairs["sequence__cluster_size"].dropna().astype(int).tolist()]
1079
+ unique_pairs = sub[
1080
+ ["sequence__merged_cluster_id", "sequence__cluster_size"]
1081
+ ].drop_duplicates()
1082
+ sizes = [
1083
+ int(x)
1084
+ for x in unique_pairs["sequence__cluster_size"]
1085
+ .dropna()
1086
+ .astype(int)
1087
+ .tolist()
1088
+ ]
845
1089
  except Exception:
846
1090
  sizes = []
847
1091
  if (not sizes) and histograms:
@@ -855,23 +1099,38 @@ def plot_histogram_pages(
855
1099
  ax.set_xlabel("Cluster size")
856
1100
  ax.set_ylabel("Count")
857
1101
  else:
858
- ax.text(0.5, 0.5, "No clusters", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
1102
+ ax.text(
1103
+ 0.5,
1104
+ 0.5,
1105
+ "No clusters",
1106
+ ha="center",
1107
+ va="center",
1108
+ transform=ax.transAxes,
1109
+ fontsize=10,
1110
+ color="gray",
1111
+ )
859
1112
 
860
1113
  if r_idx == 0:
861
1114
  ax.set_title(str(ref_name), fontsize=10)
862
1115
  if c_idx == 0:
863
- total_reads = sum(counts.get((sample_name, ref), 0) for ref in references) if not use_adata else int((adata.obs[sample_key] == sample_name).sum())
1116
+ total_reads = (
1117
+ sum(counts.get((sample_name, ref), 0) for ref in references)
1118
+ if not use_adata
1119
+ else int((adata.obs[sample_key] == sample_name).sum())
1120
+ )
864
1121
  ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
865
1122
  if r_idx != nrows - 1:
866
1123
  ax.set_xticklabels([])
867
1124
 
868
1125
  ax.grid(True, alpha=0.25)
869
1126
 
870
- fig2.suptitle(f"Union-find cluster size histograms — page {page+1}/{n_pages}", fontsize=12, y=0.995)
1127
+ fig2.suptitle(
1128
+ f"Union-find cluster size histograms — page {page + 1}/{n_pages}", fontsize=12, y=0.995
1129
+ )
871
1130
  fig2.tight_layout(rect=[0, 0, 1, 0.96])
872
1131
 
873
1132
  if output_directory:
874
- fname2 = os.path.join(output_directory, f"cluster_size_histograms_page_{page+1}.png")
1133
+ fname2 = os.path.join(output_directory, f"cluster_size_histograms_page_{page + 1}.png")
875
1134
  plt.savefig(fname2, bbox_inches="tight")
876
1135
  cluster_size_pages.append(fname2)
877
1136
  else:
@@ -923,7 +1182,9 @@ def plot_hamming_vs_metric_pages(

  obs = adata.obs
  if sample_col not in obs.columns or ref_col not in obs.columns:
- raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")
+ raise ValueError(
+ f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
+ )

  # canonicalize samples and refs
  if samples is None:
@@ -964,14 +1225,24 @@ def plot_hamming_vs_metric_pages(
  sY = pd.to_numeric(obs[hamming_col], errors="coerce") if hamming_present else None

  if (sX is not None) and (sY is not None):
- valid_both = sX.notna() & sY.notna() & np.isfinite(sX.values) & np.isfinite(sY.values)
+ valid_both = (
+ sX.notna() & sY.notna() & np.isfinite(sX.values) & np.isfinite(sY.values)
+ )
  if valid_both.any():
  xvals = sX[valid_both].to_numpy(dtype=float)
  yvals = sY[valid_both].to_numpy(dtype=float)
  xmin, xmax = float(np.nanmin(xvals)), float(np.nanmax(xvals))
  ymin, ymax = float(np.nanmin(yvals)), float(np.nanmax(yvals))
- xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
- ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+ xpad = (
+ max(1e-6, (xmax - xmin) * 0.05)
+ if xmax > xmin
+ else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+ )
+ ypad = (
+ max(1e-6, (ymax - ymin) * 0.05)
+ if ymax > ymin
+ else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+ )
  global_xlim = (xmin - xpad, xmax + xpad)
  global_ylim = (ymin - ypad, ymax + ypad)
  else:
@@ -979,23 +1250,39 @@ def plot_hamming_vs_metric_pages(
  sY_finite = sY[np.isfinite(sY)]
  if sX_finite.size > 0:
  xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
- xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+ xpad = (
+ max(1e-6, (xmax - xmin) * 0.05)
+ if xmax > xmin
+ else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+ )
  global_xlim = (xmin - xpad, xmax + xpad)
  if sY_finite.size > 0:
  ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
- ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+ ypad = (
+ max(1e-6, (ymax - ymin) * 0.05)
+ if ymax > ymin
+ else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+ )
  global_ylim = (ymin - ypad, ymax + ypad)
  elif sX is not None:
  sX_finite = sX[np.isfinite(sX)]
  if sX_finite.size > 0:
  xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
- xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+ xpad = (
+ max(1e-6, (xmax - xmin) * 0.05)
+ if xmax > xmin
+ else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+ )
  global_xlim = (xmin - xpad, xmax + xpad)
  elif sY is not None:
  sY_finite = sY[np.isfinite(sY)]
  if sY_finite.size > 0:
  ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
- ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+ ypad = (
+ max(1e-6, (ymax - ymin) * 0.05)
+ if ymax > ymin
+ else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+ )
  global_ylim = (ymin - ypad, ymax + ypad)

  # pagination
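The four padding hunks above apply one rule: pad each axis limit by 5% of the data span, falling back to a small absolute pad when the span is degenerate (min equals max). A sketch of that rule pulled out into a helper; the name `_pad_limits` is hypothetical, and the library keeps the inline ternaries shown in the diff:

```python
def _pad_limits(vmin: float, vmax: float) -> tuple:
    """Hypothetical helper mirroring the xpad/ypad ternaries in the diff."""
    if vmax > vmin:
        pad = max(1e-6, (vmax - vmin) * 0.05)     # 5% of the data span
    else:
        pad = max(1e-3, abs(vmin) * 0.05 + 1e-3)  # degenerate range fallback
    return vmin - pad, vmax + pad

print(_pad_limits(0.0, 1.0))  # (-0.05, 1.05)
print(_pad_limits(0.2, 0.2))  # roughly (0.189, 0.211)
```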
@@ -1005,15 +1292,19 @@ def plot_hamming_vs_metric_pages(
  ncols = len(cols)
  fig_w = ncols * figsize_per_cell[0]
  fig_h = nrows * figsize_per_cell[1]
- fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
+ fig, axes = plt.subplots(
+ nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+ )

  for r_idx, sample_name in enumerate(chunk):
  for c_idx, ref_name in enumerate(cols):
  ax = axes[r_idx][c_idx]
  if ref_name == extra_col:
- mask = (obs[sample_col].values == sample_name)
+ mask = obs[sample_col].values == sample_name
  else:
- mask = (obs[sample_col].values == sample_name) & (obs[ref_col].values == ref_name)
+ mask = (obs[sample_col].values == sample_name) & (
+ obs[ref_col].values == ref_name
+ )

  sub = obs[mask]

@@ -1022,7 +1313,9 @@ def plot_hamming_vs_metric_pages(
  else:
  x_all = np.array([], dtype=float)
  if hamming_col in sub.columns:
- y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(dtype=float)
+ y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(
+ dtype=float
+ )
  else:
  y_all = np.array([], dtype=float)

@@ -1040,32 +1333,67 @@ def plot_hamming_vs_metric_pages(
  idxs_valid = np.array([], dtype=int)

  if x.size == 0:
- ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
+ ax.text(
+ 0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes
+ )
  clusters_info[(sample_name, ref_name)] = {"diag": None, "n_points": 0}
  else:
  # Decide color mapping
- if color_by_duplicate and duplicate_col in adata.obs.columns and idxs_valid.size > 0:
+ if (
+ color_by_duplicate
+ and duplicate_col in adata.obs.columns
+ and idxs_valid.size > 0
+ ):
  # get boolean series aligned to idxs_valid
  try:
- dup_flags = adata.obs.loc[idxs_valid, duplicate_col].astype(bool).to_numpy()
+ dup_flags = (
+ adata.obs.loc[idxs_valid, duplicate_col].astype(bool).to_numpy()
+ )
  except Exception:
  dup_flags = np.zeros(len(idxs_valid), dtype=bool)
  mask_dup = dup_flags
  mask_nondup = ~mask_dup
  # plot non-duplicates first in gray, duplicates in highlight color
  if mask_nondup.any():
- ax.scatter(x[mask_nondup], y[mask_nondup], s=12, alpha=0.6, rasterized=True, c="lightgray")
+ ax.scatter(
+ x[mask_nondup],
+ y[mask_nondup],
+ s=12,
+ alpha=0.6,
+ rasterized=True,
+ c="lightgray",
+ )
  if mask_dup.any():
- ax.scatter(x[mask_dup], y[mask_dup], s=20, alpha=0.9, rasterized=True, c=highlight_color, edgecolors="k", linewidths=0.3)
+ ax.scatter(
+ x[mask_dup],
+ y[mask_dup],
+ s=20,
+ alpha=0.9,
+ rasterized=True,
+ c=highlight_color,
+ edgecolors="k",
+ linewidths=0.3,
+ )
  else:
  # old behavior: highlight by threshold if requested
  if highlight_threshold is not None and y.size:
  mask_low = (y < float(highlight_threshold)) & np.isfinite(y)
  mask_high = ~mask_low
  if mask_high.any():
- ax.scatter(x[mask_high], y[mask_high], s=12, alpha=0.6, rasterized=True)
+ ax.scatter(
+ x[mask_high], y[mask_high], s=12, alpha=0.6, rasterized=True
+ )
  if mask_low.any():
- ax.scatter(x[mask_low], y[mask_low], s=18, alpha=0.9, rasterized=True, c=highlight_color, edgecolors="k", linewidths=0.3)
+ ax.scatter(
+ x[mask_low],
+ y[mask_low],
+ s=18,
+ alpha=0.9,
+ rasterized=True,
+ c=highlight_color,
+ edgecolors="k",
+ linewidths=0.3,
+ )
  else:
  ax.scatter(x, y, s=12, alpha=0.6, rasterized=True)

@@ -1081,7 +1409,9 @@ def plot_hamming_vs_metric_pages(
  zi = gaussian_kde(np.vstack([x, y]))(coords).reshape(xi_g.shape)
  ax.contourf(xi_g, yi_g, zi, levels=8, alpha=0.35, cmap="Blues")
  else:
- ax.scatter(x, y, c=kde2, s=16, cmap="viridis", alpha=0.7, linewidths=0)
+ ax.scatter(
+ x, y, c=kde2, s=16, cmap="viridis", alpha=0.7, linewidths=0
+ )
  except Exception:
  pass
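The KDE hunk above evaluates a 2-D Gaussian kernel density on a grid and overlays it as filled contours, falling back to density-colored points. A self-contained sketch of the grid/contour pattern with synthetic data; the grid resolution here is illustrative, not the library's:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, 500)
y = 0.5 * x + rng.normal(0.0, 0.5, 500)

# Evaluate the KDE on a regular grid spanning the data.
xi = np.linspace(x.min(), x.max(), 100)
yi = np.linspace(y.min(), y.max(), 100)
xi_g, yi_g = np.meshgrid(xi, yi)
coords = np.vstack([xi_g.ravel(), yi_g.ravel()])
zi = gaussian_kde(np.vstack([x, y]))(coords).reshape(xi_g.shape)

fig, ax = plt.subplots()
ax.scatter(x, y, s=12, alpha=0.6, rasterized=True)
ax.contourf(xi_g, yi_g, zi, levels=8, alpha=0.35, cmap="Blues")  # density overlay
plt.show()
```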

@@ -1090,16 +1420,29 @@ def plot_hamming_vs_metric_pages(
  a, b = np.polyfit(x, y, 1)
  xs = np.linspace(np.nanmin(x), np.nanmax(x), 100)
  ys = a * xs + b
- ax.plot(xs, ys, linestyle="--", linewidth=1.2, alpha=0.9, color="red")
+ ax.plot(
+ xs, ys, linestyle="--", linewidth=1.2, alpha=0.9, color="red"
+ )
  r = np.corrcoef(x, y)[0, 1]
- ax.text(0.98, 0.02, f"r={float(r):.3f}", ha="right", va="bottom", transform=ax.transAxes, fontsize=8,
- bbox=dict(facecolor="white", alpha=0.6, boxstyle="round,pad=0.2"))
+ ax.text(
+ 0.98,
+ 0.02,
+ f"r={float(r):.3f}",
+ ha="right",
+ va="bottom",
+ transform=ax.transAxes,
+ fontsize=8,
+ bbox=dict(
+ facecolor="white", alpha=0.6, boxstyle="round,pad=0.2"
+ ),
+ )
  except Exception:
  pass

  if clustering:
  cl_labels, diag = _run_clustering(
- x, y,
+ x,
+ y,
  method=clustering.get("method", "dbscan"),
  n_clusters=clustering.get("n_clusters", 2),
  dbscan_eps=clustering.get("dbscan_eps", 0.05),
@@ -1113,32 +1456,59 @@ def plot_hamming_vs_metric_pages(
  if len(unique_nonnoise) > 0:
  medians = {}
  for lab in unique_nonnoise:
- mask_lab = (cl_labels == lab)
- medians[lab] = float(np.median(y[mask_lab])) if mask_lab.any() else float("nan")
- sorted_by_median = sorted(unique_nonnoise, key=lambda l: (np.nan if np.isnan(medians[l]) else medians[l]), reverse=True)
+ mask_lab = cl_labels == lab
+ medians[lab] = (
+ float(np.median(y[mask_lab]))
+ if mask_lab.any()
+ else float("nan")
+ )
+ sorted_by_median = sorted(
+ unique_nonnoise,
+ key=lambda idx: (
+ np.nan if np.isnan(medians[idx]) else medians[idx]
+ ),
+ reverse=True,
+ )
  mapping = {old: new for new, old in enumerate(sorted_by_median)}
  for i_lab in range(len(remapped_labels)):
  if remapped_labels[i_lab] != -1:
- remapped_labels[i_lab] = mapping.get(remapped_labels[i_lab], -1)
+ remapped_labels[i_lab] = mapping.get(
+ remapped_labels[i_lab], -1
+ )
  diag = diag or {}
- diag["cluster_median_hamming"] = {int(old): medians[old] for old in medians}
- diag["cluster_old_to_new_map"] = {int(old): int(new) for old, new in mapping.items()}
+ diag["cluster_median_hamming"] = {
+ int(old): medians[old] for old in medians
+ }
+ diag["cluster_old_to_new_map"] = {
+ int(old): int(new) for old, new in mapping.items()
+ }
  else:
  remapped_labels = cl_labels.copy()
  diag = diag or {}
  diag["cluster_median_hamming"] = {}
  diag["cluster_old_to_new_map"] = {}

- _overlay_clusters_on_ax(ax, x, y, remapped_labels, diag,
- cmap=clustering.get("cmap", "tab10"),
- hull=clustering.get("hull", True),
- show_cluster_labels=True)
+ _overlay_clusters_on_ax(
+ ax,
+ x,
+ y,
+ remapped_labels,
+ diag,
+ cmap=clustering.get("cmap", "tab10"),
+ hull=clustering.get("hull", True),
+ show_cluster_labels=True,
+ )

- clusters_info[(sample_name, ref_name)] = {"diag": diag, "n_points": len(x)}
+ clusters_info[(sample_name, ref_name)] = {
+ "diag": diag,
+ "n_points": len(x),
+ }

  if write_clusters_to_adata and idxs_valid.size > 0:
- colname_safe_ref = (ref_name if ref_name != extra_col else "ALLREFS")
- colname = f"hamming_cluster__{metric}__{sample_name}__{colname_safe_ref}"
+ colname_safe_ref = ref_name if ref_name != extra_col else "ALLREFS"
+ colname = (
+ f"hamming_cluster__{metric}__{sample_name}__{colname_safe_ref}"
+ )
  if colname not in adata.obs.columns:
  adata.obs[colname] = np.nan
  lab_arr = remapped_labels.astype(float)
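The relabeling reformatted above ranks the non-noise clusters by their median hamming value, highest first, and renumbers them contiguously from 0 while noise keeps -1; the old-to-new map and per-cluster medians are stashed in the diagnostics dict. A minimal sketch of that remapping on toy labels (standalone, not the library function):

```python
import numpy as np

cl_labels = np.array([0, 0, 1, 1, 2, 2, -1])          # raw cluster labels (-1 = noise)
y = np.array([0.2, 0.2, 0.8, 0.8, 0.5, 0.5, 0.3])     # per-point hamming values

unique_nonnoise = [int(lab) for lab in np.unique(cl_labels) if lab != -1]
medians = {lab: float(np.median(y[cl_labels == lab])) for lab in unique_nonnoise}

# Highest median hamming becomes cluster 0, the next becomes 1, and so on.
sorted_by_median = sorted(unique_nonnoise, key=lambda lab: medians[lab], reverse=True)
mapping = {old: new for new, old in enumerate(sorted_by_median)}

remapped = np.array([mapping[lab] if lab != -1 else -1 for lab in cl_labels])
print(medians)   # {0: 0.2, 1: 0.8, 2: 0.5}
print(remapped)  # old cluster 1 -> 0, 2 -> 1, 0 -> 2; noise stays -1
```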
@@ -1178,7 +1548,11 @@ def plot_hamming_vs_metric_pages(
  plt.show()
  plt.close(fig)

- saved_map[metric] = {"files": files, "clusters_info": clusters_info, "written_cols": written_cols}
+ saved_map[metric] = {
+ "files": files,
+ "clusters_info": clusters_info,
+ "written_cols": written_cols,
+ }

  return saved_map

@@ -1187,7 +1561,7 @@ def _run_clustering(
  x: np.ndarray,
  y: np.ndarray,
  *,
- method: str = "kmeans", # "kmeans", "dbscan", "gmm", "hdbscan"
+ method: str = "kmeans",  # "kmeans", "dbscan", "gmm", "hdbscan"
  n_clusters: int = 2,
  dbscan_eps: float = 0.05,
  dbscan_min_samples: int = 5,
@@ -1198,13 +1572,6 @@ def plot_hamming_vs_metric_pages(
  Run clustering on 2D points (x,y). Returns labels (len = npoints) and diagnostics dict.
  Labels follow sklearn conventions (noise -> -1 for DBSCAN/HDBSCAN).
  """
- try:
- from sklearn.cluster import KMeans, DBSCAN
- from sklearn.mixture import GaussianMixture
- from sklearn.metrics import silhouette_score
- except Exception:
- KMeans = DBSCAN = GaussianMixture = silhouette_score = None
-
  pts = np.column_stack([x, y])
  diagnostics: Dict[str, Any] = {"method": method, "n_input": len(x)}
  if len(x) < min_points:
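The lines removed above are the soft-import guard for scikit-learn inside `_run_clustering`; the silhouette hunk further down still checks `silhouette_score is not None`, so the same guard pattern remains in play. For reference, that pattern looks like this (a generic sketch of optional-dependency handling, not smftools' actual import layout in 0.3.0):

```python
from typing import Optional

# Bind the sklearn entry points, or None when scikit-learn is not installed.
try:
    from sklearn.cluster import KMeans, DBSCAN
    from sklearn.mixture import GaussianMixture
    from sklearn.metrics import silhouette_score
except Exception:
    KMeans = DBSCAN = GaussianMixture = silhouette_score = None


def maybe_silhouette(pts, labels) -> Optional[float]:
    """Return a silhouette score, or None when sklearn (or a second cluster) is missing."""
    if silhouette_score is None or len(set(labels)) < 2:
        return None
    return float(silhouette_score(pts, labels))
```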
@@ -1270,7 +1637,11 @@ def plot_hamming_vs_metric_pages(

  # compute silhouette if suitable
  try:
- if diagnostics.get("n_clusters_found", 0) >= 2 and len(x) >= 3 and silhouette_score is not None:
+ if (
+ diagnostics.get("n_clusters_found", 0) >= 2
+ and len(x) >= 3
+ and silhouette_score is not None
+ ):
  diagnostics["silhouette"] = float(silhouette_score(pts, labels))
  else:
  diagnostics["silhouette"] = None
@@ -1305,7 +1676,6 @@ def _overlay_clusters_on_ax(
  Labels == -1 are noise and drawn in grey.
  Also annotates cluster numbers near centroids (contiguous numbers starting at 0).
  """
- import matplotlib.colors as mcolors
  from scipy.spatial import ConvexHull

  labels = np.asarray(labels)
@@ -1323,19 +1693,47 @@ def _overlay_clusters_on_ax(
  if not mask.any():
  continue
  col = (0.6, 0.6, 0.6, 0.6) if lab == -1 else colors[idx % ncolors]
- ax.scatter(x[mask], y[mask], s=20, c=[col], alpha=alpha_pts, marker=marker, linewidths=0.2, edgecolors="none", rasterized=True)
+ ax.scatter(
+ x[mask],
+ y[mask],
+ s=20,
+ c=[col],
+ alpha=alpha_pts,
+ marker=marker,
+ linewidths=0.2,
+ edgecolors="none",
+ rasterized=True,
+ )

  if lab != -1:
  # centroid
  if plot_centroids:
  cx = float(np.mean(x[mask]))
  cy = float(np.mean(y[mask]))
- ax.scatter([cx], [cy], s=centroid_size, marker=centroid_marker, c=[col], edgecolor="k", linewidth=0.6, zorder=10)
+ ax.scatter(
+ [cx],
+ [cy],
+ s=centroid_size,
+ marker=centroid_marker,
+ c=[col],
+ edgecolor="k",
+ linewidth=0.6,
+ zorder=10,
+ )

  if show_cluster_labels:
- ax.text(cx, cy, str(int(lab)), color="white", fontsize=cluster_label_fontsize,
- ha="center", va="center", weight="bold", zorder=12,
- bbox=dict(facecolor=(0,0,0,0.5), pad=0.3, boxstyle="round"))
+ ax.text(
+ cx,
+ cy,
+ str(int(lab)),
+ color="white",
+ fontsize=cluster_label_fontsize,
+ ha="center",
+ va="center",
+ weight="bold",
+ zorder=12,
+ bbox=dict(facecolor=(0, 0, 0, 0.5), pad=0.3, boxstyle="round"),
+ )

  # hull
  if hull and np.sum(mask) >= 3:
@@ -1343,9 +1741,16 @@ def _overlay_clusters_on_ax(
  ch_pts = pts[mask]
  hull_idx = ConvexHull(ch_pts).vertices
  hull_poly = ch_pts[hull_idx]
- ax.fill(hull_poly[:, 0], hull_poly[:, 1], alpha=hull_alpha, facecolor=col, edgecolor=hull_edgecolor, linewidth=0.6, zorder=5)
+ ax.fill(
+ hull_poly[:, 0],
+ hull_poly[:, 1],
+ alpha=hull_alpha,
+ facecolor=col,
+ edgecolor=hull_edgecolor,
+ linewidth=0.6,
+ zorder=5,
+ )
  except Exception:
  pass

  return None
-
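To round off the overlay hunks: the hull fill uses the standard SciPy convex-hull recipe, taking the hull vertex indices, indexing back into the cluster's points, and filling the polygon, with a try/except because degenerate (e.g. collinear) point sets make `ConvexHull` raise. A standalone sketch with random points; the colors and alpha are illustrative:

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull

rng = np.random.default_rng(1)
pts = rng.normal(size=(40, 2))  # one cluster's (x, y) points

fig, ax = plt.subplots()
ax.scatter(pts[:, 0], pts[:, 1], s=20)

# ConvexHull needs at least 3 non-degenerate points, hence the guard in the diff.
if pts.shape[0] >= 3:
    try:
        hull_idx = ConvexHull(pts).vertices   # hull vertex indices, in boundary order
        hull_poly = pts[hull_idx]
        ax.fill(hull_poly[:, 0], hull_poly[:, 1], alpha=0.15, edgecolor="k", linewidth=0.6)
    except Exception:
        pass  # degenerate geometry -> skip the hull
plt.show()
```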