smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
smftools/preprocessing/flag_duplicate_reads.py

@@ -1,24 +1,28 @@
 # duplicate_detection_with_hier_and_plots.py
 import copy
-import warnings
 import math
 import os
+import warnings
 from collections import defaultdict
-from typing import Dict, Any, Tuple, Union, List, Optional
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

-import torch
 import anndata as ad
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import matplotlib.pyplot as plt
-from tqdm import tqdm
+import torch
+
+from smftools.logging_utils import get_logger

 from ..readwrite import make_dirs

+logger = get_logger(__name__)
+
 # optional imports for clustering / PCA / KDE
 try:
     from scipy.cluster import hierarchy as sch
     from scipy.spatial.distance import pdist, squareform
+
     SCIPY_AVAILABLE = True
 except Exception:
     sch = None
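Both versions guard their SciPy and scikit-learn imports behind module-level availability flags instead of importing unconditionally. A minimal sketch of the same pattern, reduced to one dependency (the function name here is illustrative, not from the package):

    try:
        from scipy.spatial.distance import pdist, squareform
        SCIPY_AVAILABLE = True
    except Exception:
        pdist = squareform = None
        SCIPY_AVAILABLE = False

    def pairwise_hamming(X):
        # Fail with a clear message instead of an ImportError deep in a pipeline.
        if not SCIPY_AVAILABLE:
            raise RuntimeError("scipy is required for pairwise distances")
        return squareform(pdist(X, metric="hamming"))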
@@ -27,10 +31,11 @@ except Exception:
     SCIPY_AVAILABLE = False

 try:
+    from sklearn.cluster import DBSCAN, KMeans
     from sklearn.decomposition import PCA
-    from sklearn.cluster import KMeans, DBSCAN
-    from sklearn.mixture import GaussianMixture
     from sklearn.metrics import silhouette_score
+    from sklearn.mixture import GaussianMixture
+
     SKLEARN_AVAILABLE = True
 except Exception:
     PCA = None
@@ -43,10 +48,16 @@
     gaussian_kde = None


-def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
-    """
-    Merge two .uns dicts. prefer='orig' will keep orig_uns values on conflict,
-    prefer='new' will keep new_uns values on conflict. Conflicts are reported.
+def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer: str = "orig") -> dict:
+    """Merge two ``.uns`` dictionaries while preserving preferred values.
+
+    Args:
+        orig_uns: Original ``.uns`` dictionary.
+        new_uns: New ``.uns`` dictionary to merge.
+        prefer: Which dictionary to prefer on conflict (``"orig"`` or ``"new"``).
+
+    Returns:
+        dict: Merged dictionary.
     """
     out = copy.deepcopy(new_uns) if new_uns is not None else {}
     for k, v in (orig_uns or {}).items():
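From the body shown in this and the next hunk, the merge deep-copies the new dictionary, then walks the original one and keeps the original value on conflicts when prefer="orig". A small worked example with illustrative values:

    orig = {"threshold": 0.07, "notes": "run1"}
    new = {"threshold": 0.10, "extra": True}
    merged = merge_uns_preserve(orig, new, prefer="orig")
    # merged == {"threshold": 0.07, "extra": True, "notes": "run1"}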
@@ -55,7 +66,7 @@ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
         else:
             # present in both: compare quickly (best-effort)
             try:
-                equal = (out[k] == v)
+                equal = out[k] == v
             except Exception:
                 equal = False
             if equal:
@@ -69,19 +80,20 @@ def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
                 out[f"orig_uns__{k}"] = copy.deepcopy(v)
     return out

+
 def flag_duplicate_reads(
-    adata,
-    var_filters_sets,
+    adata: ad.AnnData,
+    var_filters_sets: Sequence[dict[str, Any]],
     distance_threshold: float = 0.07,
     obs_reference_col: str = "Reference_strand",
     sample_col: str = "Barcode",
     output_directory: Optional[str] = None,
     metric_keys: Union[str, List[str]] = ("Fraction_any_C_site_modified",),
-    uns_flag: str = "read_duplicate_detection_performed",
+    uns_flag: str = "flag_duplicate_reads_performed",
     uns_filtered_flag: str = "read_duplicates_removed",
     bypass: bool = False,
     force_redo: bool = False,
-    keep_best_metric: Optional[str] = 'read_quality',
+    keep_best_metric: Optional[str] = "read_quality",
     keep_best_higher: bool = True,
     window_size: int = 50,
     min_overlap_positions: int = 20,
@@ -93,19 +105,137 @@ def flag_duplicate_reads(
     hierarchical_metric: str = "euclidean",
     hierarchical_window: int = 50,
     random_state: int = 0,
-):
+    demux_types: Optional[Sequence[str]] = None,
+    demux_col: str = "demux_type",
+) -> ad.AnnData:
+    """Flag duplicate reads with demux-aware keeper preference.
+
+    Behavior:
+    - All reads are processed (no masking by demux).
+    - At each keeper decision, prefer reads whose ``demux_col`` value is in
+      ``demux_types`` when present. Among candidates, choose by
+      ``keep_best_metric``.
+
+    Args:
+        adata: AnnData object to process.
+        var_filters_sets: Sequence of variable filter definitions.
+        distance_threshold: Distance threshold for duplicate detection.
+        obs_reference_col: Obs column containing reference identifiers.
+        sample_col: Obs column containing sample identifiers.
+        output_directory: Directory for output plots and artifacts.
+        metric_keys: Metric key(s) used in processing.
+        uns_flag: Flag in ``adata.uns`` indicating prior completion.
+        uns_filtered_flag: Flag to mark read duplicates removal.
+        bypass: Whether to skip processing.
+        force_redo: Whether to rerun even if ``uns_flag`` is set.
+        keep_best_metric: Obs column used to select best read within duplicates.
+        keep_best_higher: Whether higher values in ``keep_best_metric`` are preferred.
+        window_size: Window size for local comparisons.
+        min_overlap_positions: Minimum overlapping positions required.
+        do_pca: Whether to run PCA before clustering.
+        pca_n_components: Number of PCA components.
+        pca_center: Whether to center data before PCA.
+        do_hierarchical: Whether to run hierarchical clustering.
+        hierarchical_linkage: Linkage method for hierarchical clustering.
+        hierarchical_metric: Distance metric for hierarchical clustering.
+        hierarchical_window: Window size for hierarchical clustering.
+        random_state: Random seed.
+        demux_types: Preferred demux types for keeper selection.
+        demux_col: Obs column containing demux type labels.
+
+    Returns:
+        anndata.AnnData: AnnData object with duplicate flags stored in ``adata.obs``.
     """
-    Duplicate-flagging pipeline where hierarchical stage operates only on representatives
-    (one representative per lex cluster, i.e. the keeper). Final keeper assignment and
-    enforcement happens only after hierarchical merging.
+    import copy
+    import warnings

-    Returns (adata_unique, adata_full) as before; writes sequence__* columns into adata.obs.
-    """
-    # early exits
+    import anndata as ad
+    import numpy as np
+    import pandas as pd
+
+    # optional imports already guarded at module import time, but re-check
+    try:
+        from scipy.cluster import hierarchy as sch
+        from scipy.spatial.distance import pdist
+
+        SCIPY_AVAILABLE = True
+    except Exception:
+        sch = None
+        pdist = None
+        SCIPY_AVAILABLE = False
+    try:
+        from sklearn.decomposition import PCA
+
+        SKLEARN_AVAILABLE = True
+    except Exception:
+        PCA = None
+        SKLEARN_AVAILABLE = False
+
+    # -------- helper: demux-aware keeper selection --------
+    def _choose_keeper_with_demux_preference(
+        members_idx: List[int],
+        adata_subset: ad.AnnData,
+        obs_index_list: List[Any],
+        *,
+        demux_col: str = "demux_type",
+        preferred_types: Optional[Sequence[str]] = None,
+        keep_best_metric: Optional[str] = None,
+        keep_best_higher: bool = True,
+        lex_keeper_mask: Optional[np.ndarray] = None,  # aligned to members order
+    ) -> int:
+        """
+        Prefer members whose demux_col ∈ preferred_types.
+        Among candidates, pick by keep_best_metric (higher/lower).
+        If metric missing/NaN, prefer lex keeper (via mask) among candidates.
+        Fallback: first candidate.
+        Returns the chosen *member index* (int from members_idx).
+        """
+        # 1) demux-preferred candidates
+        if preferred_types and (demux_col in adata_subset.obs.columns):
+            preferred = set(map(str, preferred_types))
+            demux_series = adata_subset.obs[demux_col].astype("string")
+            names = [obs_index_list[m] for m in members_idx]
+            is_pref = demux_series.loc[names].isin(preferred).to_numpy()
+            candidates = [members_idx[i] for i, ok in enumerate(is_pref) if ok]
+        else:
+            candidates = []
+
+        if not candidates:
+            candidates = list(members_idx)
+
+        # 2) metric-based within candidates
+        if keep_best_metric and (keep_best_metric in adata_subset.obs.columns):
+            cand_names = [obs_index_list[m] for m in candidates]
+            try:
+                vals = pd.to_numeric(
+                    adata_subset.obs.loc[cand_names, keep_best_metric],
+                    errors="coerce",
+                ).to_numpy(dtype=float)
+            except Exception:
+                vals = np.array([np.nan] * len(candidates), dtype=float)
+
+            if not np.all(np.isnan(vals)):
+                if keep_best_higher:
+                    vals = np.where(np.isnan(vals), -np.inf, vals)
+                    return candidates[int(np.nanargmax(vals))]
+                else:
+                    vals = np.where(np.isnan(vals), np.inf, vals)
+                    return candidates[int(np.nanargmin(vals))]
+
+        # 3) metric unhelpful — prefer lex keeper if provided
+        if lex_keeper_mask is not None:
+            for i, midx in enumerate(members_idx):
+                if (midx in candidates) and bool(lex_keeper_mask[i]):
+                    return midx
+
+        # 4) fallback
+        return candidates[0]
+
+    # -------- early exits --------
     already = bool(adata.uns.get(uns_flag, False))
-    if (already and not force_redo):
+    if already and not force_redo:
         if "is_duplicate" in adata.obs.columns:
-            adata_unique = adata[adata.obs["is_duplicate"] == False].copy()
+            adata_unique = adata[~adata.obs["is_duplicate"]].copy()
             return adata_unique, adata
         else:
            return adata.copy(), adata.copy()
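Putting the 0.2.5 signature together, a hypothetical call might look like the sketch below. The filter spec and column names are placeholders that must match your own AnnData, and, as the early-exit branches above show, the function actually returns an (adata_unique, adata_full) pair despite the ad.AnnData annotation:

    adata_unique, adata_full = flag_duplicate_reads(
        adata,
        var_filters_sets=[{"GpC_site": True}],  # placeholder filter definition
        distance_threshold=0.07,
        sample_col="Barcode",
        keep_best_metric="read_quality",
        demux_types=["single"],   # new in 0.2.5: prefer these reads as keepers
        demux_col="demux_type",
    )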
@@ -117,17 +247,23 @@

     # local UnionFind
     class UnionFind:
+        """Disjoint-set union-find helper for clustering indices."""
+
         def __init__(self, size):
+            """Initialize parent pointers for the union-find."""
             self.parent = list(range(size))

         def find(self, x):
+            """Find the root for a member with path compression."""
             while self.parent[x] != x:
                 self.parent[x] = self.parent[self.parent[x]]
                 x = self.parent[x]
             return x

         def union(self, x, y):
-            rx = self.find(x); ry = self.find(y)
+            """Union the sets that contain x and y."""
+            rx = self.find(x)
+            ry = self.find(y)
             if rx != ry:
                 self.parent[ry] = rx

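The local UnionFind is a standard disjoint-set with path compression. Since the class is defined inside flag_duplicate_reads, this usage sketch only illustrates its semantics:

    uf = UnionFind(5)          # five singleton clusters: 0..4
    uf.union(0, 1)
    uf.union(1, 2)
    assert uf.find(2) == uf.find(0)   # 0, 1, 2 now share one root
    assert uf.find(3) != uf.find(0)   # 3 remains its own cluster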
@@ -139,14 +275,14 @@

     for sample in samples:
         for ref in references:
-            print(f"Processing sample={sample} ref={ref}")
+            logger.info("Processing sample=%s ref=%s", sample, ref)
             sample_mask = adata.obs[sample_col] == sample
             ref_mask = adata.obs[obs_reference_col] == ref
             subset_mask = sample_mask & ref_mask
             adata_subset = adata[subset_mask].copy()

             if adata_subset.n_obs < 2:
-                print(f" Skipping {sample}_{ref} (too few reads)")
+                logger.info(" Skipping %s_%s (too few reads)", sample, ref)
                 continue

             N = adata_subset.shape[0]
@@ -162,7 +298,12 @@

             selected_cols = adata.var.index[combined_mask.tolist()].to_list()
             col_indices = [adata.var.index.get_loc(c) for c in selected_cols]
-            print(f" Selected {len(col_indices)} columns out of {adata.var.shape[0]} for {ref}")
+            logger.info(
+                " Selected %s columns out of %s for %s",
+                len(col_indices),
+                adata.var.shape[0],
+                ref,
+            )

             # Extract data matrix (dense numpy) for the subset
             X = adata_subset.X
@@ -187,7 +328,10 @@
             hierarchical_found_dists = []

             # Lexicographic windowed pass function
-            def cluster_pass(X_tensor_local, reverse=False, window=int(window_size), record_distances=False):
+            def cluster_pass(
+                X_tensor_local, reverse=False, window=int(window_size), record_distances=False
+            ):
+                """Perform a lexicographic windowed clustering pass."""
                 N_local = X_tensor_local.shape[0]
                 X_sortable = X_tensor_local.clone().nan_to_num(-1.0)
                 sort_keys = [tuple(row.numpy().tolist()) for row in X_sortable]
@@ -208,10 +352,16 @@
                     if enough_overlap.any():
                         diffs = (row_i_exp != block_rows) & valid_mask
                         hamming_counts = diffs.sum(dim=1).float()
-                        hamming_dists = torch.where(valid_counts > 0, hamming_counts / valid_counts, torch.tensor(float("nan")))
+                        hamming_dists = torch.where(
+                            valid_counts > 0,
+                            hamming_counts / valid_counts,
+                            torch.tensor(float("nan")),
+                        )
                         # record distances (legacy list of all local comparisons)
                         hamming_np = hamming_dists.cpu().numpy().tolist()
-                        local_hamming_dists.extend([float(x) for x in hamming_np if (not np.isnan(x))])
+                        local_hamming_dists.extend(
+                            [float(x) for x in hamming_np if (not np.isnan(x))]
+                        )
                         matches = (hamming_dists < distance_threshold) & (enough_overlap)
                         for offset_local, m in enumerate(matches):
                             if m:
@@ -223,20 +373,28 @@
                     next_local_idx = i + 1
                     if next_local_idx < len(sorted_X):
                         next_global = sorted_idx[next_local_idx]
-                        vm_pair = (~torch.isnan(row_i)) & (~torch.isnan(sorted_X[next_local_idx]))
+                        vm_pair = (~torch.isnan(row_i)) & (
+                            ~torch.isnan(sorted_X[next_local_idx])
+                        )
                         vc = vm_pair.sum().item()
                         if vc >= min_overlap_positions:
-                            d = float(((row_i[vm_pair] != sorted_X[next_local_idx][vm_pair]).sum().item()) / vc)
+                            d = float(
+                                (
+                                    (row_i[vm_pair] != sorted_X[next_local_idx][vm_pair])
+                                    .sum()
+                                    .item()
+                                )
+                                / vc
+                            )
                             if reverse:
                                 rev_hamming_to_prev[next_global] = d
                             else:
                                 fwd_hamming_to_next[sorted_idx[i]] = d
                 return cluster_pairs_local

-            # run forward pass
+            # run forward & reverse windows
             pairs_fwd = cluster_pass(X_tensor, reverse=False, record_distances=True)
             involved_in_fwd = set([p for pair in pairs_fwd for p in pair])
-            # build mask for reverse pass to avoid re-checking items already paired
             mask_for_rev = np.ones(N, dtype=bool)
             if len(involved_in_fwd) > 0:
                 for idx in involved_in_fwd:
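The windowed pass sorts reads lexicographically (NaNs coerced to -1) and compares each read only against its next window_size neighbours with a NaN-masked Hamming distance, so near-duplicates become adjacent without an all-pairs scan. A NumPy toy of one such comparison; this is an illustration, not the packaged torch implementation:

    import numpy as np

    X = np.array([[1.0, 0.0, 1.0, np.nan],
                  [1.0, 0.0, 1.0, 1.0],
                  [0.0, 1.0, np.nan, 0.0]])
    # Row-wise lexicographic order on the NaN-filled matrix.
    order = np.lexsort(np.nan_to_num(X, nan=-1.0).T[::-1])
    i, j = order[0], order[1]                  # adjacent reads in sort order
    valid = ~np.isnan(X[i]) & ~np.isnan(X[j])  # positions covered by both reads
    if valid.sum() >= 2:                       # stands in for min_overlap_positions
        d = float((X[i][valid] != X[j][valid]).mean())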
@@ -245,8 +403,9 @@
             if len(rev_idx_map) > 0:
                 reduced_tensor = X_tensor[rev_idx_map]
                 pairs_rev_local = cluster_pass(reduced_tensor, reverse=True, record_distances=True)
-                # remap local reduced indices to global
-                remapped_rev_pairs = [(int(rev_idx_map[i]), int(rev_idx_map[j])) for (i, j) in pairs_rev_local]
+                remapped_rev_pairs = [
+                    (int(rev_idx_map[i]), int(rev_idx_map[j])) for (i, j) in pairs_rev_local
+                ]
             else:
                 remapped_rev_pairs = []

@@ -265,53 +424,41 @@
             id_map = {old: new for new, old in enumerate(sorted(unique_initial.tolist()))}
             merged_cluster_mapped = np.array([id_map[int(x)] for x in merged_cluster], dtype=int)

-            # cluster sizes and choose lex-keeper per lex-cluster (representatives)
+            # cluster sizes and choose lex-keeper per lex-cluster (demux-aware)
             cluster_sizes = np.zeros_like(merged_cluster_mapped)
             cluster_counts = []
             unique_clusters = np.unique(merged_cluster_mapped)
             keeper_for_cluster = {}
+
+            obs_index = list(adata_subset.obs.index)
             for cid in unique_clusters:
                 members = np.where(merged_cluster_mapped == cid)[0].tolist()
                 csize = int(len(members))
                 cluster_counts.append(csize)
                 cluster_sizes[members] = csize
-                # pick lex keeper (representative)
-                if len(members) == 1:
-                    keeper_for_cluster[cid] = members[0]
-                else:
-                    if keep_best_metric is None:
-                        keeper_for_cluster[cid] = members[0]
-                    else:
-                        obs_index = list(adata_subset.obs.index)
-                        member_names = [obs_index[m] for m in members]
-                        try:
-                            vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
-                        except Exception:
-                            vals = np.array([np.nan] * len(members), dtype=float)
-                        if np.all(np.isnan(vals)):
-                            keeper_for_cluster[cid] = members[0]
-                        else:
-                            if keep_best_higher:
-                                nan_mask = np.isnan(vals)
-                                vals[nan_mask] = -np.inf
-                                rel_idx = int(np.nanargmax(vals))
-                            else:
-                                nan_mask = np.isnan(vals)
-                                vals[nan_mask] = np.inf
-                                rel_idx = int(np.nanargmin(vals))
-                            keeper_for_cluster[cid] = members[rel_idx]

-            # expose lex keeper info (record only; do not enforce deletion yet)
+                keeper_for_cluster[cid] = _choose_keeper_with_demux_preference(
+                    members,
+                    adata_subset,
+                    obs_index,
+                    demux_col=demux_col,
+                    preferred_types=demux_types,
+                    keep_best_metric=keep_best_metric,
+                    keep_best_higher=keep_best_higher,
+                    lex_keeper_mask=None,  # no lex preference yet
+                )
+
+            # expose lex keeper info (record only)
             lex_is_keeper = np.zeros((N,), dtype=bool)
             lex_is_duplicate = np.zeros((N,), dtype=bool)
-            for cid, members in zip(unique_clusters, [np.where(merged_cluster_mapped == cid)[0].tolist() for cid in unique_clusters]):
+            for cid in unique_clusters:
+                members = np.where(merged_cluster_mapped == cid)[0].tolist()
                 keeper_idx = keeper_for_cluster[cid]
                 lex_is_keeper[keeper_idx] = True
                 for m in members:
                     if m != keeper_idx:
                         lex_is_duplicate[m] = True
-            # note: these are just recorded for inspection / later preference
-            # and will be written to adata_subset.obs below
+
             # record lex min pair (min of fwd/rev neighbor) for each read
             min_pair = np.full((N,), np.nan, dtype=float)
             for i in range(N):
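Here the lex-stage keeper choice is delegated to _choose_keeper_with_demux_preference. Its precedence (demux-preferred members first, then best metric among candidates, then the lexicographic keeper, then the first member) can be reproduced standalone. A toy re-creation with made-up data, not the packaged code:

    import numpy as np
    import pandas as pd

    obs = pd.DataFrame(
        {
            "demux_type": ["single", "double", "single"],
            "read_quality": [10.0, 30.0, np.nan],
        },
        index=["read_a", "read_b", "read_c"],
    )
    members = ["read_a", "read_b", "read_c"]

    # Step 1: restrict to demux-preferred members when any exist.
    preferred = obs.loc[members, "demux_type"].isin({"single"})
    candidates = [m for m, ok in zip(members, preferred) if ok] or list(members)

    # Step 2: pick the best metric among the candidates.
    vals = pd.to_numeric(obs.loc[candidates, "read_quality"], errors="coerce")
    keeper = vals.idxmax() if vals.notna().any() else candidates[0]
    assert keeper == "read_a"  # read_b scores higher but is not demux-preferred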
@@ -349,7 +496,12 @@
                     if n_comp <= 0:
                         reps_for_clustering = reps_arr
                     else:
-                        pca = PCA(n_components=n_comp, random_state=int(random_state), svd_solver="auto", copy=True)
+                        pca = PCA(
+                            n_components=n_comp,
+                            random_state=int(random_state),
+                            svd_solver="auto",
+                            copy=True,
+                        )
                         reps_for_clustering = pca.fit_transform(reps_arr)
                 else:
                     reps_for_clustering = reps_arr
@@ -360,10 +512,12 @@
                     Z = sch.linkage(pdist_vec, method=hierarchical_linkage)
                     leaves = sch.leaves_list(Z)
                 except Exception as e:
-                    warnings.warn(f"hierarchical pass failed: {e}; skipping hierarchical stage.")
+                    warnings.warn(
+                        f"hierarchical pass failed: {e}; skipping hierarchical stage."
+                    )
                     leaves = np.arange(len(rep_global_indices), dtype=int)

-                # apply windowed hamming comparisons across ordered reps and union via same UF (so clusters of all reads merge)
+                # windowed hamming comparisons across ordered reps and union
                 order_global_reps = [rep_global_indices[i] for i in leaves]
                 n_reps = len(order_global_reps)
                 for pos in range(n_reps):
@@ -389,55 +543,40 @@
                 merged_cluster_after[i] = uf.find(i)
             unique_final = np.unique(merged_cluster_after)
             id_map_final = {old: new for new, old in enumerate(sorted(unique_final.tolist()))}
-            merged_cluster_mapped_final = np.array([id_map_final[int(x)] for x in merged_cluster_after], dtype=int)
+            merged_cluster_mapped_final = np.array(
+                [id_map_final[int(x)] for x in merged_cluster_after], dtype=int
+            )

-            # compute final cluster members and choose final keeper per final cluster
+            # compute final cluster members and choose final keeper per final cluster (demux-aware)
             cluster_sizes_final = np.zeros_like(merged_cluster_mapped_final)
-            final_cluster_counts = []
-            final_unique = np.unique(merged_cluster_mapped_final)
             final_keeper_for_cluster = {}
             cluster_members_map = {}
-            for cid in final_unique:
+
+            obs_index = list(adata_subset.obs.index)
+            lex_mask_full = lex_is_keeper  # use lex keeper as optional tiebreaker
+
+            for cid in np.unique(merged_cluster_mapped_final):
                 members = np.where(merged_cluster_mapped_final == cid)[0].tolist()
                 cluster_members_map[cid] = members
-                csize = len(members)
-                final_cluster_counts.append(csize)
-                cluster_sizes_final[members] = csize
-                if csize == 1:
-                    final_keeper_for_cluster[cid] = members[0]
-                else:
-                    # prefer keep_best_metric if available; do not automatically prefer lex-keeper here unless you want to;
-                    # (user previously asked for preferring lex keepers — if desired, you can prefer lex_is_keeper among members)
-                    obs_index = list(adata_subset.obs.index)
-                    member_names = [obs_index[m] for m in members]
-                    if keep_best_metric is not None and keep_best_metric in adata_subset.obs.columns:
-                        try:
-                            vals = pd.to_numeric(adata_subset.obs.loc[member_names, keep_best_metric], errors="coerce").to_numpy(dtype=float)
-                        except Exception:
-                            vals = np.array([np.nan] * len(members), dtype=float)
-                        if np.all(np.isnan(vals)):
-                            final_keeper_for_cluster[cid] = members[0]
-                        else:
-                            if keep_best_higher:
-                                nan_mask = np.isnan(vals)
-                                vals[nan_mask] = -np.inf
-                                rel_idx = int(np.nanargmax(vals))
-                            else:
-                                nan_mask = np.isnan(vals)
-                                vals[nan_mask] = np.inf
-                                rel_idx = int(np.nanargmin(vals))
-                            final_keeper_for_cluster[cid] = members[rel_idx]
-                    else:
-                        # if lex keepers present among members, prefer them
-                        lex_members = [m for m in members if lex_is_keeper[m]]
-                        if len(lex_members) > 0:
-                            final_keeper_for_cluster[cid] = lex_members[0]
-                        else:
-                            final_keeper_for_cluster[cid] = members[0]
-
-            # update sequence__is_duplicate based on final clusters: non-keepers in multi-member clusters are duplicates
+                cluster_sizes_final[members] = len(members)
+
+                lex_mask_members = np.array([bool(lex_mask_full[m]) for m in members], dtype=bool)
+
+                keeper = _choose_keeper_with_demux_preference(
+                    members,
+                    adata_subset,
+                    obs_index,
+                    demux_col=demux_col,
+                    preferred_types=demux_types,
+                    keep_best_metric=keep_best_metric,
+                    keep_best_higher=keep_best_higher,
+                    lex_keeper_mask=lex_mask_members,
+                )
+                final_keeper_for_cluster[cid] = keeper
+
+            # update sequence__is_duplicate based on final clusters
             sequence_is_duplicate = np.zeros((N,), dtype=bool)
-            for cid in final_unique:
+            for cid in np.unique(merged_cluster_mapped_final):
                 keeper = final_keeper_for_cluster[cid]
                 members = cluster_members_map[cid]
                 if len(members) > 1:
@@ -446,8 +585,7 @@
                             sequence_is_duplicate[m] = True

             # propagate hierarchical distances into hierarchical_min_pair for all cluster members
-            for (i_g, j_g, d) in hierarchical_pairs:
-                # identify their final cluster ids (after unions)
+            for i_g, j_g, d in hierarchical_pairs:
                 c_i = merged_cluster_mapped_final[int(i_g)]
                 c_j = merged_cluster_mapped_final[int(j_g)]
                 members_i = cluster_members_map.get(c_i, [int(i_g)])
@@ -459,7 +597,7 @@
                     if np.isnan(hierarchical_min_pair[mj]) or (d < hierarchical_min_pair[mj]):
                         hierarchical_min_pair[mj] = d

-            # combine lex-phase min_pair and hierarchical_min_pair into the final sequence__min_hamming_to_pair
+            # combine min pairs
             combined_min = min_pair.copy()
             for i in range(N):
                 hval = hierarchical_min_pair[i]
@@ -475,69 +613,117 @@
             adata_subset.obs["rev_hamming_to_prev"] = rev_hamming_to_prev
             adata_subset.obs["sequence__hier_hamming_to_pair"] = hierarchical_min_pair
             adata_subset.obs["sequence__min_hamming_to_pair"] = combined_min
-            # persist lex bookkeeping columns (informational)
+            # persist lex bookkeeping
             adata_subset.obs["sequence__lex_is_keeper"] = lex_is_keeper
             adata_subset.obs["sequence__lex_is_duplicate"] = lex_is_duplicate

             adata_processed_list.append(adata_subset)

-            histograms.append({
-                "sample": sample,
-                "reference": ref,
-                "distances": local_hamming_dists,  # lex local comparisons
-                "cluster_counts": final_cluster_counts,
-                "hierarchical_pairs": hierarchical_found_dists,
-            })
-
-    # Merge annotated subsets back together BEFORE plotting so plotting sees fwd_hamming_to_next, etc.
+            histograms.append(
+                {
+                    "sample": sample,
+                    "reference": ref,
+                    "distances": local_hamming_dists,
+                    "cluster_counts": [
+                        int(x) for x in np.unique(cluster_sizes_final[cluster_sizes_final > 0])
+                    ],
+                    "hierarchical_pairs": hierarchical_found_dists,
+                }
+            )
+
+    # Merge annotated subsets back together BEFORE plotting
     _original_uns = copy.deepcopy(adata.uns)
     if len(adata_processed_list) == 0:
         return adata.copy(), adata.copy()

     adata_full = ad.concat(adata_processed_list, merge="same", join="outer", index_unique=None)
+
+    # preserve uns (prefer original on conflicts)
+    def merge_uns_preserve(orig_uns: dict, new_uns: dict, prefer="orig") -> dict:
+        """Merge .uns dictionaries while preserving original on conflicts."""
+        out = copy.deepcopy(new_uns) if new_uns is not None else {}
+        for k, v in (orig_uns or {}).items():
+            if k not in out:
+                out[k] = copy.deepcopy(v)
+            else:
+                try:
+                    equal = out[k] == v
+                except Exception:
+                    equal = False
+                if equal:
+                    continue
+                if prefer == "orig":
+                    out[k] = copy.deepcopy(v)
+                else:
+                    out[f"orig_uns__{k}"] = copy.deepcopy(v)
+        return out
+
     adata_full.uns = merge_uns_preserve(_original_uns, adata_full.uns, prefer="orig")

     # Ensure expected numeric columns exist (create if missing)
-    for col in ("fwd_hamming_to_next", "rev_hamming_to_prev", "sequence__min_hamming_to_pair", "sequence__hier_hamming_to_pair"):
+    for col in (
+        "fwd_hamming_to_next",
+        "rev_hamming_to_prev",
+        "sequence__min_hamming_to_pair",
+        "sequence__hier_hamming_to_pair",
+    ):
        if col not in adata_full.obs.columns:
            adata_full.obs[col] = np.nan

-    # histograms (now driven by adata_full if requested)
-    hist_outs = os.path.join(output_directory, "read_pair_hamming_distance_histograms")
-    make_dirs([hist_outs])
-    plot_histogram_pages(histograms,
-                         distance_threshold=distance_threshold,
-                         adata=adata_full,
-                         output_directory=hist_outs,
-                         distance_types=["min","fwd","rev","hier","lex_local"],
-                         sample_key=sample_col,
-                         )
+    # histograms
+    hist_outs = (
+        os.path.join(output_directory, "read_pair_hamming_distance_histograms")
+        if output_directory
+        else None
+    )
+    if hist_outs:
+        make_dirs([hist_outs])
+    plot_histogram_pages(
+        histograms,
+        distance_threshold=distance_threshold,
+        adata=adata_full,
+        output_directory=hist_outs,
+        distance_types=["min", "fwd", "rev", "hier", "lex_local"],
+        sample_key=sample_col,
+    )

     # hamming vs metric scatter
-    scatter_outs = os.path.join(output_directory, "read_pair_hamming_distance_scatter_plots")
-    make_dirs([scatter_outs])
-    plot_hamming_vs_metric_pages(adata_full,
-                                 metric_keys=metric_keys,
-                                 output_dir=scatter_outs,
-                                 hamming_col="sequence__min_hamming_to_pair",
-                                 highlight_threshold=distance_threshold,
-                                 highlight_color="red",
-                                 sample_col=sample_col)
+    scatter_outs = (
+        os.path.join(output_directory, "read_pair_hamming_distance_scatter_plots")
+        if output_directory
+        else None
+    )
+    if scatter_outs:
+        make_dirs([scatter_outs])
+    plot_hamming_vs_metric_pages(
+        adata_full,
+        metric_keys=metric_keys,
+        output_dir=scatter_outs,
+        hamming_col="sequence__min_hamming_to_pair",
+        highlight_threshold=distance_threshold,
+        highlight_color="red",
+        sample_col=sample_col,
+    )

     # boolean columns from neighbor distances
-    fwd_vals = pd.to_numeric(adata_full.obs.get("fwd_hamming_to_next", pd.Series(np.nan, index=adata_full.obs.index)), errors="coerce")
-    rev_vals = pd.to_numeric(adata_full.obs.get("rev_hamming_to_prev", pd.Series(np.nan, index=adata_full.obs.index)), errors="coerce")
+    fwd_vals = pd.to_numeric(
+        adata_full.obs.get("fwd_hamming_to_next", pd.Series(np.nan, index=adata_full.obs.index)),
+        errors="coerce",
+    )
+    rev_vals = pd.to_numeric(
+        adata_full.obs.get("rev_hamming_to_prev", pd.Series(np.nan, index=adata_full.obs.index)),
+        errors="coerce",
+    )
     is_dup_dist = (fwd_vals < float(distance_threshold)) | (rev_vals < float(distance_threshold))
     is_dup_dist = is_dup_dist.fillna(False).astype(bool)
     adata_full.obs["is_duplicate_distance"] = is_dup_dist.values

-    # combine sequence-derived flag with others
-    if "sequence__is_duplicate" in adata_full.obs.columns:
-        seq_dup = adata_full.obs["sequence__is_duplicate"].astype(bool)
-    else:
-        seq_dup = pd.Series(False, index=adata_full.obs.index)
-
-    # cluster-based duplicate indicator (if any clustering columns exist)
+    # combine with sequence flag and any clustering flags
+    seq_dup = (
+        adata_full.obs["sequence__is_duplicate"].astype(bool)
+        if "sequence__is_duplicate" in adata_full.obs.columns
+        else pd.Series(False, index=adata_full.obs.index)
+    )
     cluster_cols = [c for c in adata_full.obs.columns if c.startswith("hamming_cluster__")]
     if cluster_cols:
         cl_mask = pd.Series(False, index=adata_full.obs.index)
550
736
  else:
551
737
  adata_full.obs["is_duplicate_clustering"] = False
552
738
 
553
- final_dup = seq_dup | adata_full.obs["is_duplicate_distance"].astype(bool) | adata_full.obs["is_duplicate_clustering"].astype(bool)
739
+ final_dup = (
740
+ seq_dup
741
+ | adata_full.obs["is_duplicate_distance"].astype(bool)
742
+ | adata_full.obs["is_duplicate_clustering"].astype(bool)
743
+ )
554
744
  adata_full.obs["is_duplicate"] = final_dup.values
555
745
 
556
- # Final keeper enforcement: recompute per-cluster keeper from sequence__merged_cluster_id and
557
- # ensure that keeper is not marked duplicate
558
- if "sequence__merged_cluster_id" in adata_full.obs.columns:
559
- keeper_idx_by_cluster = {}
560
- metric_col = keep_best_metric if 'keep_best_metric' in locals() else None
746
+ # -------- Final keeper enforcement on adata_full (demux-aware) --------
747
+ keeper_idx_by_cluster = {}
748
+ metric_col = keep_best_metric
561
749
 
562
- # group by cluster id
563
- grp = adata_full.obs[["sequence__merged_cluster_id", "sequence__cluster_size"]].copy()
564
- for cid, sub in grp.groupby("sequence__merged_cluster_id"):
565
- try:
566
- members = sub.index.to_list()
567
- except Exception:
568
- members = list(sub.index)
569
- keeper = None
570
- # prefer keep_best_metric (if present), else prefer lex keeper among members, else first member
571
- if metric_col and metric_col in adata_full.obs.columns:
572
- try:
573
- vals = pd.to_numeric(adata_full.obs.loc[members, metric_col], errors="coerce")
574
- if vals.notna().any():
575
- keeper = vals.idxmax() if keep_best_higher else vals.idxmin()
576
- else:
577
- keeper = members[0]
578
- except Exception:
579
- keeper = members[0]
580
- else:
581
- # prefer lex keeper if present in this merged cluster
582
- lex_candidates = [m for m in members if ("sequence__lex_is_keeper" in adata_full.obs.columns and adata_full.obs.at[m, "sequence__lex_is_keeper"])]
583
- if len(lex_candidates) > 0:
584
- keeper = lex_candidates[0]
585
- else:
586
- keeper = members[0]
750
+ # Build an index→row-number mapping
751
+ name_to_pos = {name: i for i, name in enumerate(adata_full.obs.index)}
752
+ obs_index_full = list(adata_full.obs.index)
587
753
 
588
- keeper_idx_by_cluster[cid] = keeper
754
+ lex_col = (
755
+ "sequence__lex_is_keeper" if "sequence__lex_is_keeper" in adata_full.obs.columns else None
756
+ )
589
757
 
590
- # force keepers not to be duplicates
591
- is_dup_series = adata_full.obs["is_duplicate"].astype(bool)
592
- for cid, keeper_idx in keeper_idx_by_cluster.items():
593
- if keeper_idx in adata_full.obs.index:
594
- is_dup_series.at[keeper_idx] = False
595
- # clear sequence__is_duplicate for keeper if present
596
- if "sequence__is_duplicate" in adata_full.obs.columns:
597
- adata_full.obs.at[keeper_idx, "sequence__is_duplicate"] = False
598
- # clear lex duplicate flag too if present
599
- if "sequence__lex_is_duplicate" in adata_full.obs.columns:
600
- adata_full.obs.at[keeper_idx, "sequence__lex_is_duplicate"] = False
758
+ for cid, sub in adata_full.obs.groupby("sequence__merged_cluster_id", dropna=False):
759
+ members_names = sub.index.to_list()
760
+ members_pos = [name_to_pos[n] for n in members_names]
601
761
 
602
- adata_full.obs["is_duplicate"] = is_dup_series.values
762
+ if lex_col:
763
+ lex_mask_members = adata_full.obs.loc[members_names, lex_col].astype(bool).to_numpy()
764
+ else:
765
+ lex_mask_members = np.zeros(len(members_pos), dtype=bool)
766
+
767
+ keeper_pos = _choose_keeper_with_demux_preference(
768
+ members_pos,
769
+ adata_full,
770
+ obs_index_full,
771
+ demux_col=demux_col,
772
+ preferred_types=demux_types,
773
+ keep_best_metric=metric_col,
774
+ keep_best_higher=keep_best_higher,
775
+ lex_keeper_mask=lex_mask_members,
776
+ )
777
+ keeper_name = obs_index_full[keeper_pos]
778
+ keeper_idx_by_cluster[cid] = keeper_name
779
+
780
+ # enforce: keepers are not duplicates
781
+ is_dup_series = adata_full.obs["is_duplicate"].astype(bool)
782
+ for cid, keeper_name in keeper_idx_by_cluster.items():
783
+ if keeper_name in adata_full.obs.index:
784
+ is_dup_series.at[keeper_name] = False
785
+ if "sequence__is_duplicate" in adata_full.obs.columns:
786
+ adata_full.obs.at[keeper_name, "sequence__is_duplicate"] = False
787
+ if "sequence__lex_is_duplicate" in adata_full.obs.columns:
788
+ adata_full.obs.at[keeper_name, "sequence__lex_is_duplicate"] = False
789
+ adata_full.obs["is_duplicate"] = is_dup_series.values
603
790
 
604
791
  # reason column
605
792
  def _dup_reason_row(row):
793
+ """Build a semi-colon delimited duplicate reason string."""
606
794
  reasons = []
607
795
  if row.get("is_duplicate_distance", False):
608
796
  reasons.append("distance_thresh")
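After this enforcement pass, every merged cluster should retain at least one read with is_duplicate == False. A quick post-hoc check one might run, assuming the cluster id column written earlier is present; this snippet is an illustration, not part of the package:

    non_dup_per_cluster = (
        adata_full.obs.groupby("sequence__merged_cluster_id")["is_duplicate"]
        .apply(lambda flags: int((~flags.astype(bool)).sum()))
    )
    assert (non_dup_per_cluster >= 1).all(), "a cluster lost its keeper"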
@@ -632,6 +820,7 @@
 # Plot helpers (use adata_full as input)
 # ---------------------------

+
 def plot_histogram_pages(
     histograms,
     distance_threshold,
@@ -674,10 +863,11 @@
         use_adata = False

     if len(samples) == 0 or len(references) == 0:
-        print("No histogram data to plot.")
+        logger.info("No histogram data to plot.")
         return {"distance_pages": [], "cluster_size_pages": []}

     def clean_array(arr):
+        """Filter array values to finite [0, 1] range for plotting."""
         if arr is None or len(arr) == 0:
             return np.array([], dtype=float)
         a = np.asarray(arr, dtype=float)
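Per its new docstring, clean_array keeps only finite values inside the unit interval, so Hamming distances plot on a fixed [0, 1] axis. An illustrative call based on that documented behavior:

    clean_array([0.1, 1.2, float("nan"), 0.9])  # -> array([0.1, 0.9])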
@@ -707,7 +897,9 @@
             if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
                 grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
             if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
-                grid[(s, r)]["hier"].extend(clean_array(group["sequence__hier_hamming_to_pair"].to_numpy()))
+                grid[(s, r)]["hier"].extend(
+                    clean_array(group["sequence__hier_hamming_to_pair"].to_numpy())
+                )
     else:
         for (s, r), group in grouped:
             if "min" in distance_types and distance_key in group.columns:
@@ -717,7 +909,9 @@
             if "rev" in distance_types and "rev_hamming_to_prev" in group.columns:
                 grid[(s, r)]["rev"].extend(clean_array(group["rev_hamming_to_prev"].to_numpy()))
             if "hier" in distance_types and "sequence__hier_hamming_to_pair" in group.columns:
-                grid[(s, r)]["hier"].extend(clean_array(group["sequence__hier_hamming_to_pair"].to_numpy()))
+                grid[(s, r)]["hier"].extend(
+                    clean_array(group["sequence__hier_hamming_to_pair"].to_numpy())
+                )

     # legacy histograms fallback
     if histograms:
@@ -753,9 +947,17 @@

     # counts (for labels)
     if use_adata:
-        counts = {(s, r): int(((adata.obs[sample_key] == s) & (adata.obs[ref_key] == r)).sum()) for s in samples for r in references}
+        counts = {
+            (s, r): int(((adata.obs[sample_key] == s) & (adata.obs[ref_key] == r)).sum())
+            for s in samples
+            for r in references
+        }
     else:
-        counts = {(s, r): sum(len(grid[(s, r)][dt]) for dt in distance_types) for s in samples for r in references}
+        counts = {
+            (s, r): sum(len(grid[(s, r)][dt]) for dt in distance_types)
+            for s in samples
+            for r in references
+        }

     distance_pages = []
     cluster_size_pages = []
@@ -773,7 +975,9 @@
         # Distance histogram page
         fig_w = figsize_per_cell[0] * ncols
         fig_h = figsize_per_cell[1] * nrows
-        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
+        fig, axes = plt.subplots(
+            nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+        )

         for r_idx, sample_name in enumerate(chunk):
             for c_idx, ref_name in enumerate(references):
@@ -789,17 +993,37 @@
                     vals = vals[(vals >= 0.0) & (vals <= ref_vmax)]
                     if vals.size > 0:
                         any_data = True
-                        ax.hist(vals, bins=bins_edges, alpha=0.5, label=dtype, density=False, stacked=False,
-                                color=dtype_colors.get(dtype, None))
+                        ax.hist(
+                            vals,
+                            bins=bins_edges,
+                            alpha=0.5,
+                            label=dtype,
+                            density=False,
+                            stacked=False,
+                            color=dtype_colors.get(dtype, None),
+                        )
                 if not any_data:
-                    ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
+                    ax.text(
+                        0.5,
+                        0.5,
+                        "No data",
+                        ha="center",
+                        va="center",
+                        transform=ax.transAxes,
+                        fontsize=10,
+                        color="gray",
+                    )
                 # threshold line (make sure it is within axis)
                 ax.axvline(distance_threshold, color="red", linestyle="--", linewidth=1)

                 if r_idx == 0:
                     ax.set_title(str(ref_name), fontsize=10)
                 if c_idx == 0:
-                    total_reads = sum(counts.get((sample_name, ref), 0) for ref in references) if not use_adata else int((adata.obs[sample_key] == sample_name).sum())
+                    total_reads = (
+                        sum(counts.get((sample_name, ref), 0) for ref in references)
+                        if not use_adata
+                        else int((adata.obs[sample_key] == sample_name).sum())
+                    )
                     ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
                 if r_idx == nrows - 1:
                     ax.set_xlabel("Hamming Distance", fontsize=9)
@@ -811,12 +1035,16 @@
                 if r_idx == 0 and c_idx == 0:
                     ax.legend(fontsize=7, loc="upper right")

-        fig.suptitle(f"Hamming distance histograms (rows=samples, cols=references) — page {page+1}/{n_pages}", fontsize=12, y=0.995)
+        fig.suptitle(
+            f"Hamming distance histograms (rows=samples, cols=references) — page {page + 1}/{n_pages}",
+            fontsize=12,
+            y=0.995,
+        )
         fig.tight_layout(rect=[0, 0, 1, 0.96])

         if output_directory:
             os.makedirs(output_directory, exist_ok=True)
-            fname = os.path.join(output_directory, f"hamming_histograms_page_{page+1}.png")
+            fname = os.path.join(output_directory, f"hamming_histograms_page_{page + 1}.png")
             plt.savefig(fname, bbox_inches="tight")
             distance_pages.append(fname)
         else:
@@ -826,22 +1054,43 @@
         # Cluster-size histogram page (unchanged except it uses adata-derived sizes per cluster if available)
         fig_w = figsize_per_cell[0] * ncols
         fig_h = figsize_per_cell[1] * nrows
-        fig2, axes2 = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
+        fig2, axes2 = plt.subplots(
+            nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+        )

         for r_idx, sample_name in enumerate(chunk):
             for c_idx, ref_name in enumerate(references):
                 ax = axes2[r_idx][c_idx]
                 sizes = []
-                if use_adata and ("sequence__merged_cluster_id" in adata.obs.columns and "sequence__cluster_size" in adata.obs.columns):
-                    sub = adata.obs[(adata.obs[sample_key] == sample_name) & (adata.obs[ref_key] == ref_name)]
+                if use_adata and (
+                    "sequence__merged_cluster_id" in adata.obs.columns
+                    and "sequence__cluster_size" in adata.obs.columns
+                ):
+                    sub = adata.obs[
+                        (adata.obs[sample_key] == sample_name) & (adata.obs[ref_key] == ref_name)
+                    ]
                     if not sub.empty:
                         try:
-                            grp = sub.groupby("sequence__merged_cluster_id")["sequence__cluster_size"].first()
-                            sizes = [int(x) for x in grp.to_numpy().tolist() if (pd.notna(x) and np.isfinite(x))]
+                            grp = sub.groupby("sequence__merged_cluster_id")[
+                                "sequence__cluster_size"
+                            ].first()
+                            sizes = [
+                                int(x)
+                                for x in grp.to_numpy().tolist()
+                                if (pd.notna(x) and np.isfinite(x))
+                            ]
                         except Exception:
                             try:
-                                unique_pairs = sub[["sequence__merged_cluster_id", "sequence__cluster_size"]].drop_duplicates()
-                                sizes = [int(x) for x in unique_pairs["sequence__cluster_size"].dropna().astype(int).tolist()]
+                                unique_pairs = sub[
+                                    ["sequence__merged_cluster_id", "sequence__cluster_size"]
+                                ].drop_duplicates()
+                                sizes = [
+                                    int(x)
+                                    for x in unique_pairs["sequence__cluster_size"]
+                                    .dropna()
+                                    .astype(int)
+                                    .tolist()
+                                ]
                             except Exception:
                                 sizes = []
                 if (not sizes) and histograms:
@@ -855,23 +1104,38 @@
                     ax.set_xlabel("Cluster size")
                     ax.set_ylabel("Count")
                 else:
-                    ax.text(0.5, 0.5, "No clusters", ha="center", va="center", transform=ax.transAxes, fontsize=10, color="gray")
+                    ax.text(
+                        0.5,
+                        0.5,
+                        "No clusters",
+                        ha="center",
+                        va="center",
+                        transform=ax.transAxes,
+                        fontsize=10,
+                        color="gray",
+                    )

                 if r_idx == 0:
                     ax.set_title(str(ref_name), fontsize=10)
                 if c_idx == 0:
-                    total_reads = sum(counts.get((sample_name, ref), 0) for ref in references) if not use_adata else int((adata.obs[sample_key] == sample_name).sum())
+                    total_reads = (
+                        sum(counts.get((sample_name, ref), 0) for ref in references)
+                        if not use_adata
+                        else int((adata.obs[sample_key] == sample_name).sum())
+                    )
                     ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=9)
                 if r_idx != nrows - 1:
                     ax.set_xticklabels([])

                 ax.grid(True, alpha=0.25)

-        fig2.suptitle(f"Union-find cluster size histograms — page {page+1}/{n_pages}", fontsize=12, y=0.995)
+        fig2.suptitle(
+            f"Union-find cluster size histograms — page {page + 1}/{n_pages}", fontsize=12, y=0.995
+        )
         fig2.tight_layout(rect=[0, 0, 1, 0.96])

         if output_directory:
-            fname2 = os.path.join(output_directory, f"cluster_size_histograms_page_{page+1}.png")
+            fname2 = os.path.join(output_directory, f"cluster_size_histograms_page_{page + 1}.png")
             plt.savefig(fname2, bbox_inches="tight")
             cluster_size_pages.append(fname2)
         else:
@@ -923,7 +1187,9 @@

     obs = adata.obs
     if sample_col not in obs.columns or ref_col not in obs.columns:
-        raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")
+        raise ValueError(
+            f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
+        )

     # canonicalize samples and refs
     if samples is None:
@@ -964,14 +1230,24 @@
     sY = pd.to_numeric(obs[hamming_col], errors="coerce") if hamming_present else None

     if (sX is not None) and (sY is not None):
-        valid_both = sX.notna() & sY.notna() & np.isfinite(sX.values) & np.isfinite(sY.values)
+        valid_both = (
+            sX.notna() & sY.notna() & np.isfinite(sX.values) & np.isfinite(sY.values)
+        )
         if valid_both.any():
             xvals = sX[valid_both].to_numpy(dtype=float)
             yvals = sY[valid_both].to_numpy(dtype=float)
             xmin, xmax = float(np.nanmin(xvals)), float(np.nanmax(xvals))
             ymin, ymax = float(np.nanmin(yvals)), float(np.nanmax(yvals))
-            xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
-            ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+            xpad = (
+                max(1e-6, (xmax - xmin) * 0.05)
+                if xmax > xmin
+                else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+            )
+            ypad = (
+                max(1e-6, (ymax - ymin) * 0.05)
+                if ymax > ymin
+                else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+            )
             global_xlim = (xmin - xpad, xmax + xpad)
             global_ylim = (ymin - ypad, ymax + ypad)
         else:
@@ -979,23 +1255,39 @@
             sY_finite = sY[np.isfinite(sY)]
             if sX_finite.size > 0:
                 xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
-                xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+                xpad = (
+                    max(1e-6, (xmax - xmin) * 0.05)
+                    if xmax > xmin
+                    else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+                )
                 global_xlim = (xmin - xpad, xmax + xpad)
             if sY_finite.size > 0:
                 ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
-                ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+                ypad = (
+                    max(1e-6, (ymax - ymin) * 0.05)
+                    if ymax > ymin
+                    else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+                )
                 global_ylim = (ymin - ypad, ymax + ypad)
     elif sX is not None:
         sX_finite = sX[np.isfinite(sX)]
         if sX_finite.size > 0:
             xmin, xmax = float(np.nanmin(sX_finite)), float(np.nanmax(sX_finite))
-            xpad = max(1e-6, (xmax - xmin) * 0.05) if xmax > xmin else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+            xpad = (
+                max(1e-6, (xmax - xmin) * 0.05)
+                if xmax > xmin
+                else max(1e-3, abs(xmin) * 0.05 + 1e-3)
+            )
             global_xlim = (xmin - xpad, xmax + xpad)
     elif sY is not None:
         sY_finite = sY[np.isfinite(sY)]
         if sY_finite.size > 0:
             ymin, ymax = float(np.nanmin(sY_finite)), float(np.nanmax(sY_finite))
-            ypad = max(1e-6, (ymax - ymin) * 0.05) if ymax > ymin else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+            ypad = (
+                max(1e-6, (ymax - ymin) * 0.05)
+                if ymax > ymin
+                else max(1e-3, abs(ymin) * 0.05 + 1e-3)
+            )
             global_ylim = (ymin - ypad, ymax + ypad)

     # pagination
@@ -1005,15 +1297,19 @@ def plot_hamming_vs_metric_pages(
         ncols = len(cols)
         fig_w = ncols * figsize_per_cell[0]
         fig_h = nrows * figsize_per_cell[1]
-        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False)
+        fig, axes = plt.subplots(
+            nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+        )
 
         for r_idx, sample_name in enumerate(chunk):
             for c_idx, ref_name in enumerate(cols):
                 ax = axes[r_idx][c_idx]
                 if ref_name == extra_col:
-                    mask = (obs[sample_col].values == sample_name)
+                    mask = obs[sample_col].values == sample_name
                 else:
-                    mask = (obs[sample_col].values == sample_name) & (obs[ref_col].values == ref_name)
+                    mask = (obs[sample_col].values == sample_name) & (
+                        obs[ref_col].values == ref_name
+                    )
 
                 sub = obs[mask]
 
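Each page is one subplot grid whose rows come from the current sample chunk and whose columns are the references (plus an optional all-references column). A minimal sketch of that pagination pattern, with synthetic names and a hypothetical `max_rows_per_page`:

# Split samples into fixed-size row chunks; each chunk drives one
# plt.subplots() grid with nrows == len(chunk) and ncols == len(cols).
samples = ["s1", "s2", "s3", "s4", "s5"]
max_rows_per_page = 2
pages = [
    samples[i : i + max_rows_per_page]
    for i in range(0, len(samples), max_rows_per_page)
]
# pages == [["s1", "s2"], ["s3", "s4"], ["s5"]]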
@@ -1022,7 +1318,9 @@ def plot_hamming_vs_metric_pages(
                 else:
                     x_all = np.array([], dtype=float)
                 if hamming_col in sub.columns:
-                    y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(dtype=float)
+                    y_all = pd.to_numeric(sub[hamming_col], errors="coerce").to_numpy(
+                        dtype=float
+                    )
                 else:
                     y_all = np.array([], dtype=float)
 
@@ -1040,32 +1338,67 @@ def plot_hamming_vs_metric_pages(
                     idxs_valid = np.array([], dtype=int)
 
                 if x.size == 0:
-                    ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
+                    ax.text(
+                        0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes
+                    )
                     clusters_info[(sample_name, ref_name)] = {"diag": None, "n_points": 0}
                 else:
                     # Decide color mapping
-                    if color_by_duplicate and duplicate_col in adata.obs.columns and idxs_valid.size > 0:
+                    if (
+                        color_by_duplicate
+                        and duplicate_col in adata.obs.columns
+                        and idxs_valid.size > 0
+                    ):
                         # get boolean series aligned to idxs_valid
                         try:
-                            dup_flags = adata.obs.loc[idxs_valid, duplicate_col].astype(bool).to_numpy()
+                            dup_flags = (
+                                adata.obs.loc[idxs_valid, duplicate_col].astype(bool).to_numpy()
+                            )
                         except Exception:
                             dup_flags = np.zeros(len(idxs_valid), dtype=bool)
                         mask_dup = dup_flags
                         mask_nondup = ~mask_dup
                         # plot non-duplicates first in gray, duplicates in highlight color
                         if mask_nondup.any():
-                            ax.scatter(x[mask_nondup], y[mask_nondup], s=12, alpha=0.6, rasterized=True, c="lightgray")
+                            ax.scatter(
+                                x[mask_nondup],
+                                y[mask_nondup],
+                                s=12,
+                                alpha=0.6,
+                                rasterized=True,
+                                c="lightgray",
+                            )
                         if mask_dup.any():
-                            ax.scatter(x[mask_dup], y[mask_dup], s=20, alpha=0.9, rasterized=True, c=highlight_color, edgecolors="k", linewidths=0.3)
+                            ax.scatter(
+                                x[mask_dup],
+                                y[mask_dup],
+                                s=20,
+                                alpha=0.9,
+                                rasterized=True,
+                                c=highlight_color,
+                                edgecolors="k",
+                                linewidths=0.3,
+                            )
                     else:
                         # old behavior: highlight by threshold if requested
                         if highlight_threshold is not None and y.size:
                             mask_low = (y < float(highlight_threshold)) & np.isfinite(y)
                             mask_high = ~mask_low
                             if mask_high.any():
-                                ax.scatter(x[mask_high], y[mask_high], s=12, alpha=0.6, rasterized=True)
+                                ax.scatter(
+                                    x[mask_high], y[mask_high], s=12, alpha=0.6, rasterized=True
+                                )
                             if mask_low.any():
-                                ax.scatter(x[mask_low], y[mask_low], s=18, alpha=0.9, rasterized=True, c=highlight_color, edgecolors="k", linewidths=0.3)
+                                ax.scatter(
+                                    x[mask_low],
+                                    y[mask_low],
+                                    s=18,
+                                    alpha=0.9,
+                                    rasterized=True,
+                                    c=highlight_color,
+                                    edgecolors="k",
+                                    linewidths=0.3,
+                                )
                         else:
                             ax.scatter(x, y, s=12, alpha=0.6, rasterized=True)
 
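Both branches above use the same layering trick: draw the background population first, then overlay the flagged points so they always sit on top. A minimal self-contained sketch with synthetic data (the `flagged` mask stands in for either the duplicate flags or the threshold mask):

import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=200)
y = rng.normal(size=200)
flagged = y < -1.0  # stand-in for duplicate flags / below-threshold mask

fig, ax = plt.subplots()
# Background layer: small, translucent, gray points.
ax.scatter(x[~flagged], y[~flagged], s=12, alpha=0.6, c="lightgray", rasterized=True)
# Highlight layer: larger points with a thin black edge, drawn last.
ax.scatter(x[flagged], y[flagged], s=20, alpha=0.9, c="red", edgecolors="k", linewidths=0.3)
plt.show()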
@@ -1081,7 +1414,9 @@ def plot_hamming_vs_metric_pages(
                             zi = gaussian_kde(np.vstack([x, y]))(coords).reshape(xi_g.shape)
                             ax.contourf(xi_g, yi_g, zi, levels=8, alpha=0.35, cmap="Blues")
                         else:
-                            ax.scatter(x, y, c=kde2, s=16, cmap="viridis", alpha=0.7, linewidths=0)
+                            ax.scatter(
+                                x, y, c=kde2, s=16, cmap="viridis", alpha=0.7, linewidths=0
+                            )
                     except Exception:
                         pass
 
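The fallback branch colors each point by its local density rather than drawing filled contours. A minimal sketch of that technique using scipy's gaussian_kde on synthetic data:

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(1)
x = rng.normal(size=500)
y = 0.5 * x + rng.normal(scale=0.5, size=500)

# Fit a 2D KDE to the points and evaluate it at the points themselves;
# the resulting densities become the scatter colors.
density = gaussian_kde(np.vstack([x, y]))(np.vstack([x, y]))

fig, ax = plt.subplots()
ax.scatter(x, y, c=density, s=16, cmap="viridis", alpha=0.7, linewidths=0)
plt.show()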
@@ -1090,16 +1425,29 @@ def plot_hamming_vs_metric_pages(
                         a, b = np.polyfit(x, y, 1)
                         xs = np.linspace(np.nanmin(x), np.nanmax(x), 100)
                         ys = a * xs + b
-                        ax.plot(xs, ys, linestyle="--", linewidth=1.2, alpha=0.9, color="red")
+                        ax.plot(
+                            xs, ys, linestyle="--", linewidth=1.2, alpha=0.9, color="red"
+                        )
                         r = np.corrcoef(x, y)[0, 1]
-                        ax.text(0.98, 0.02, f"r={float(r):.3f}", ha="right", va="bottom", transform=ax.transAxes, fontsize=8,
-                                bbox=dict(facecolor="white", alpha=0.6, boxstyle="round,pad=0.2"))
+                        ax.text(
+                            0.98,
+                            0.02,
+                            f"r={float(r):.3f}",
+                            ha="right",
+                            va="bottom",
+                            transform=ax.transAxes,
+                            fontsize=8,
+                            bbox=dict(
+                                facecolor="white", alpha=0.6, boxstyle="round,pad=0.2"
+                            ),
+                        )
                     except Exception:
                         pass
 
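The trend-line annotation is a degree-1 least-squares fit plus the Pearson correlation; a minimal sketch with tiny synthetic arrays:

import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([0.1, 0.9, 2.2, 2.8, 4.1])

a, b = np.polyfit(x, y, 1)   # slope and intercept of the fit y ~ a*x + b
r = np.corrcoef(x, y)[0, 1]  # Pearson r, annotated on the axis as f"r={r:.3f}"
print(f"slope={a:.3f} intercept={b:.3f} r={r:.3f}")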
                     if clustering:
                         cl_labels, diag = _run_clustering(
-                            x, y,
+                            x,
+                            y,
                             method=clustering.get("method", "dbscan"),
                             n_clusters=clustering.get("n_clusters", 2),
                             dbscan_eps=clustering.get("dbscan_eps", 0.05),
@@ -1113,32 +1461,59 @@ def plot_hamming_vs_metric_pages(
                         if len(unique_nonnoise) > 0:
                             medians = {}
                             for lab in unique_nonnoise:
-                                mask_lab = (cl_labels == lab)
-                                medians[lab] = float(np.median(y[mask_lab])) if mask_lab.any() else float("nan")
-                            sorted_by_median = sorted(unique_nonnoise, key=lambda l: (np.nan if np.isnan(medians[l]) else medians[l]), reverse=True)
+                                mask_lab = cl_labels == lab
+                                medians[lab] = (
+                                    float(np.median(y[mask_lab]))
+                                    if mask_lab.any()
+                                    else float("nan")
+                                )
+                            sorted_by_median = sorted(
+                                unique_nonnoise,
+                                key=lambda idx: (
+                                    np.nan if np.isnan(medians[idx]) else medians[idx]
+                                ),
+                                reverse=True,
+                            )
                             mapping = {old: new for new, old in enumerate(sorted_by_median)}
                             for i_lab in range(len(remapped_labels)):
                                 if remapped_labels[i_lab] != -1:
-                                    remapped_labels[i_lab] = mapping.get(remapped_labels[i_lab], -1)
+                                    remapped_labels[i_lab] = mapping.get(
+                                        remapped_labels[i_lab], -1
+                                    )
                             diag = diag or {}
-                            diag["cluster_median_hamming"] = {int(old): medians[old] for old in medians}
-                            diag["cluster_old_to_new_map"] = {int(old): int(new) for old, new in mapping.items()}
+                            diag["cluster_median_hamming"] = {
+                                int(old): medians[old] for old in medians
+                            }
+                            diag["cluster_old_to_new_map"] = {
+                                int(old): int(new) for old, new in mapping.items()
+                            }
                         else:
                             remapped_labels = cl_labels.copy()
                             diag = diag or {}
                             diag["cluster_median_hamming"] = {}
                             diag["cluster_old_to_new_map"] = {}
 
-                        _overlay_clusters_on_ax(ax, x, y, remapped_labels, diag,
-                                                cmap=clustering.get("cmap", "tab10"),
-                                                hull=clustering.get("hull", True),
-                                                show_cluster_labels=True)
+                        _overlay_clusters_on_ax(
+                            ax,
+                            x,
+                            y,
+                            remapped_labels,
+                            diag,
+                            cmap=clustering.get("cmap", "tab10"),
+                            hull=clustering.get("hull", True),
+                            show_cluster_labels=True,
+                        )
 
-                    clusters_info[(sample_name, ref_name)] = {"diag": diag, "n_points": len(x)}
+                    clusters_info[(sample_name, ref_name)] = {
+                        "diag": diag,
+                        "n_points": len(x),
+                    }
 
                     if write_clusters_to_adata and idxs_valid.size > 0:
-                        colname_safe_ref = (ref_name if ref_name != extra_col else "ALLREFS")
-                        colname = f"hamming_cluster__{metric}__{sample_name}__{colname_safe_ref}"
+                        colname_safe_ref = ref_name if ref_name != extra_col else "ALLREFS"
+                        colname = (
+                            f"hamming_cluster__{metric}__{sample_name}__{colname_safe_ref}"
+                        )
                         if colname not in adata.obs.columns:
                             adata.obs[colname] = np.nan
                         lab_arr = remapped_labels.astype(float)
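The remapping above renumbers cluster labels so that cluster 0 is always the one with the highest median hamming value, while -1 stays reserved for noise. A minimal standalone sketch with synthetic labels:

import numpy as np

labels = np.array([0, 0, 1, 1, -1, 2, 2])
y = np.array([0.2, 0.3, 0.9, 0.8, 0.5, 0.6, 0.7])

# Rank non-noise clusters by median y, highest first, and renumber.
nonnoise = [lab for lab in np.unique(labels) if lab != -1]
medians = {lab: float(np.median(y[labels == lab])) for lab in nonnoise}
order = sorted(nonnoise, key=lambda lab: medians[lab], reverse=True)
mapping = {old: new for new, old in enumerate(order)}
remapped = np.array([mapping.get(lab, -1) for lab in labels])
print(remapped)  # [2 2 0 0 -1 1 1]: median 0.85 -> 0, 0.65 -> 1, 0.25 -> 2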
@@ -1178,7 +1553,11 @@ def plot_hamming_vs_metric_pages(
         plt.show()
         plt.close(fig)
 
-        saved_map[metric] = {"files": files, "clusters_info": clusters_info, "written_cols": written_cols}
+        saved_map[metric] = {
+            "files": files,
+            "clusters_info": clusters_info,
+            "written_cols": written_cols,
+        }
 
     return saved_map
 
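A hypothetical usage sketch of the return value. The full signature is not visible in this diff, so the keyword arguments below are illustrative only; what is grounded here is that the function takes the AnnData and a `clustering` options dict, accepts `write_clusters_to_adata`, and returns a per-metric map with "files", "clusters_info", and "written_cols" keys:

saved = plot_hamming_vs_metric_pages(
    adata,
    clustering={"method": "dbscan", "dbscan_eps": 0.05, "n_clusters": 2},
    write_clusters_to_adata=True,
)
for metric, info in saved.items():
    # Each entry records the written figure files, per-(sample, ref)
    # clustering diagnostics, and any obs columns added to adata.
    print(metric, info["files"], info["written_cols"])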
@@ -1187,7 +1566,7 @@ def _run_clustering(
     x: np.ndarray,
     y: np.ndarray,
     *,
-    method: str = "kmeans", # "kmeans", "dbscan", "gmm", "hdbscan"
+    method: str = "kmeans",  # "kmeans", "dbscan", "gmm", "hdbscan"
     n_clusters: int = 2,
     dbscan_eps: float = 0.05,
     dbscan_min_samples: int = 5,
@@ -1199,9 +1578,9 @@ def _run_clustering(
     Labels follow sklearn conventions (noise -> -1 for DBSCAN/HDBSCAN).
     """
     try:
-        from sklearn.cluster import KMeans, DBSCAN
-        from sklearn.mixture import GaussianMixture
+        from sklearn.cluster import DBSCAN, KMeans
         from sklearn.metrics import silhouette_score
+        from sklearn.mixture import GaussianMixture
     except Exception:
         KMeans = DBSCAN = GaussianMixture = silhouette_score = None
 
@@ -1270,7 +1649,11 @@ def _run_clustering(
 
     # compute silhouette if suitable
     try:
-        if diagnostics.get("n_clusters_found", 0) >= 2 and len(x) >= 3 and silhouette_score is not None:
+        if (
+            diagnostics.get("n_clusters_found", 0) >= 2
+            and len(x) >= 3
+            and silhouette_score is not None
+        ):
            diagnostics["silhouette"] = float(silhouette_score(pts, labels))
         else:
             diagnostics["silhouette"] = None
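The guard mirrors silhouette_score's own preconditions: it needs at least two clusters and at least three points, and the import may have been nulled out by the optional-dependency fallback above. A minimal sketch of the same check on synthetic data, using only sklearn APIs that appear in this diff:

import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(2)
# Two well-separated blobs, so DBSCAN with eps=0.3 finds two clusters.
pts = np.vstack([rng.normal(0, 0.1, (20, 2)), rng.normal(1, 0.1, (20, 2))])
labels = DBSCAN(eps=0.3, min_samples=5).fit_predict(pts)

n_found = len(set(labels)) - (1 if -1 in labels else 0)
if n_found >= 2 and len(pts) >= 3:
    print("silhouette:", float(silhouette_score(pts, labels)))
else:
    print("silhouette undefined for", n_found, "cluster(s)")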
@@ -1305,7 +1688,6 @@ def _overlay_clusters_on_ax(
     Labels == -1 are noise and drawn in grey.
     Also annotates cluster numbers near centroids (contiguous numbers starting at 0).
     """
-    import matplotlib.colors as mcolors
     from scipy.spatial import ConvexHull
 
     labels = np.asarray(labels)
@@ -1323,19 +1705,47 @@ def _overlay_clusters_on_ax(
         if not mask.any():
             continue
         col = (0.6, 0.6, 0.6, 0.6) if lab == -1 else colors[idx % ncolors]
-        ax.scatter(x[mask], y[mask], s=20, c=[col], alpha=alpha_pts, marker=marker, linewidths=0.2, edgecolors="none", rasterized=True)
+        ax.scatter(
+            x[mask],
+            y[mask],
+            s=20,
+            c=[col],
+            alpha=alpha_pts,
+            marker=marker,
+            linewidths=0.2,
+            edgecolors="none",
+            rasterized=True,
+        )
 
         if lab != -1:
             # centroid
             if plot_centroids:
                 cx = float(np.mean(x[mask]))
                 cy = float(np.mean(y[mask]))
-                ax.scatter([cx], [cy], s=centroid_size, marker=centroid_marker, c=[col], edgecolor="k", linewidth=0.6, zorder=10)
+                ax.scatter(
+                    [cx],
+                    [cy],
+                    s=centroid_size,
+                    marker=centroid_marker,
+                    c=[col],
+                    edgecolor="k",
+                    linewidth=0.6,
+                    zorder=10,
+                )
 
                 if show_cluster_labels:
-                    ax.text(cx, cy, str(int(lab)), color="white", fontsize=cluster_label_fontsize,
-                            ha="center", va="center", weight="bold", zorder=12,
-                            bbox=dict(facecolor=(0,0,0,0.5), pad=0.3, boxstyle="round"))
+                    ax.text(
+                        cx,
+                        cy,
+                        str(int(lab)),
+                        color="white",
+                        fontsize=cluster_label_fontsize,
+                        ha="center",
+                        va="center",
+                        weight="bold",
+                        zorder=12,
+                        bbox=dict(facecolor=(0, 0, 0, 0.5), pad=0.3, boxstyle="round"),
+                    )
 
             # hull
             if hull and np.sum(mask) >= 3:
@@ -1343,9 +1753,16 @@ def _overlay_clusters_on_ax(
                     ch_pts = pts[mask]
                     hull_idx = ConvexHull(ch_pts).vertices
                     hull_poly = ch_pts[hull_idx]
-                    ax.fill(hull_poly[:, 0], hull_poly[:, 1], alpha=hull_alpha, facecolor=col, edgecolor=hull_edgecolor, linewidth=0.6, zorder=5)
+                    ax.fill(
+                        hull_poly[:, 0],
+                        hull_poly[:, 1],
+                        alpha=hull_alpha,
+                        facecolor=col,
+                        edgecolor=hull_edgecolor,
+                        linewidth=0.6,
+                        zorder=5,
+                    )
                 except Exception:
                     pass
 
     return None
-
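The hull overlay works because, for 2D input, scipy's ConvexHull.vertices returns the boundary point indices in counter-clockwise order, so indexing the points with them yields a polygon ready for ax.fill(). A minimal sketch with synthetic points:

import numpy as np
from scipy.spatial import ConvexHull

rng = np.random.default_rng(3)
pts = rng.normal(size=(30, 2))

hull = ConvexHull(pts)
poly = pts[hull.vertices]         # (k, 2) polygon outlining the point cloud
print(poly.shape, hull.vertices)  # pass poly[:, 0], poly[:, 1] to ax.fill()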