smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,22 +1,87 @@
 from __future__ import annotations
 
-import numpy as np
-import seaborn as sns
-import matplotlib.pyplot as plt
-import scipy.cluster.hierarchy as sch
-import matplotlib.gridspec as gridspec
-import os
 import math
+import os
+from pathlib import Path
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
+
+import matplotlib.gridspec as gridspec
+import matplotlib.pyplot as plt
+import numpy as np
 import pandas as pd
+import scipy.cluster.hierarchy as sch
+import seaborn as sns
+
+
+def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
+    """
+    Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
+    Always includes 0 and n_positions-1 when possible.
+    """
+    n_ticks = int(max(2, n_ticks))
+    if n_positions <= n_ticks:
+        return np.arange(n_positions)
+
+    # linspace gives fixed count
+    pos = np.linspace(0, n_positions - 1, n_ticks)
+    return np.unique(np.round(pos).astype(int))
+
+
+def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix: str | None):
+    """
+    Select tick labels for the heatmap axis.
+
+    Parameters
+    ----------
+    subset : AnnData view
+        The per-bin subset of the AnnData.
+    sites : np.ndarray[int]
+        Indices of the subset.var positions to annotate.
+    reference : str
+        Reference name (e.g., '6B6_top').
+    index_col_suffix : None or str
+        If None → use subset.var_names
+        Else → use subset.var[f"{reference}_{index_col_suffix}"]
+
+    Returns
+    -------
+    np.ndarray[str]
+        The labels to use for tick positions.
+    """
+    if sites.size == 0:
+        return np.array([])
+
+    # Default behavior: use var_names
+    if index_col_suffix is None:
+        return subset.var_names[sites].astype(str)
+
+    # Otherwise: use a computed column adata.var[f"{reference}_{suffix}"]
+    colname = f"{reference}_{index_col_suffix}"
+
+    if colname not in subset.var:
+        raise KeyError(
+            f"index_col_suffix='{index_col_suffix}' requires var column '{colname}', "
+            f"but it is not present in adata.var."
+        )
+
+    labels = subset.var[colname].astype(str).values
+    return labels[sites]
 
-from typing import Optional, Mapping, Sequence, Any, Dict, List
-from pathlib import Path
 
 def normalized_mean(matrix: np.ndarray) -> np.ndarray:
+    """Compute normalized column means for a matrix.
+
+    Args:
+        matrix: Input matrix.
+
+    Returns:
+        1D array of normalized means.
+    """
     mean = np.nanmean(matrix, axis=0)
     denom = (mean.max() - mean.min()) + 1e-9
     return (mean - mean.min()) / denom
 
+
 def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
     """
     Fraction methylated per column.
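
Note: the arithmetic in the new `_fixed_tick_positions` helper is easy to sanity-check in isolation. A minimal sketch with hypothetical values (not part of the package):

    import numpy as np

    pos = np.linspace(0, 12 - 1, 5)            # 5 evenly spaced floats across 12 positions
    ticks = np.unique(np.round(pos).astype(int))
    assert ticks[0] == 0 and ticks[-1] == 11   # endpoints are always kept
    print(ticks)                               # [ 0  3  6  8 11]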
@@ -31,14 +96,20 @@ def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
     valid = valid_mask.sum(axis=0)
 
     return np.divide(
-        methylated, valid,
-        out=np.zeros_like(methylated, dtype=float),
-        where=valid != 0
+        methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0
     )
 
+
 def clean_barplot(ax, mean_values, title):
+    """Format a barplot with consistent axes and labels.
+
+    Args:
+        ax: Matplotlib axes.
+        mean_values: Values to plot.
+        title: Plot title.
+    """
     x = np.arange(len(mean_values))
-    ax.bar(x, mean_values, color="gray", width=1.0, align='edge')
+    ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
     ax.set_xlim(0, len(mean_values))
     ax.set_ylim(0, 1)
     ax.set_yticks([0.0, 0.5, 1.0])
@@ -47,9 +118,10 @@ def clean_barplot(ax, mean_values, title):
 
     # Hide all spines except left
     for spine_name, spine in ax.spines.items():
-        spine.set_visible(spine_name == 'left')
+        spine.set_visible(spine_name == "left")
+
+    ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
 
-    ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
 
 # def combined_hmm_raw_clustermap(
 #     adata,
@@ -92,7 +164,7 @@ def clean_barplot(ax, mean_values, title):
 #             (adata.obs['read_length'] >= min_length) &
 #             (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
 #         ]
-
+
 #         mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
 #         subset = subset[:, mask]
 
@@ -204,7 +276,7 @@ def clean_barplot(ax, mean_values, title):
 #             clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
 #             clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
 #             clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
-
+
 #             hmm_labels = subset.var_names.astype(int)
 #             hmm_label_spacing = 150
 #             sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
@@ -258,7 +330,7 @@ def clean_barplot(ax, mean_values, title):
 #             "bin_boundaries": bin_boundaries,
 #             "percentages": percentages
 #             })
-
+
 #         #adata.uns['clustermap_results'] = results
 
 #     except Exception as e:
@@ -271,83 +343,131 @@ def combined_hmm_raw_clustermap(
     adata,
     sample_col: str = "Sample_Names",
     reference_col: str = "Reference_strand",
-
     hmm_feature_layer: str = "hmm_combined",
-
     layer_gpc: str = "nan0_0minus1",
     layer_cpg: str = "nan0_0minus1",
-    layer_any_c: str = "nan0_0minus1",
+    layer_c: str = "nan0_0minus1",
     layer_a: str = "nan0_0minus1",
-
     cmap_hmm: str = "tab10",
     cmap_gpc: str = "coolwarm",
     cmap_cpg: str = "viridis",
-    cmap_any_c: str = "coolwarm",
+    cmap_c: str = "coolwarm",
     cmap_a: str = "coolwarm",
-
     min_quality: int = 20,
     min_length: int = 200,
     min_mapped_length_to_reference_length_ratio: float = 0.8,
     min_position_valid_fraction: float = 0.5,
-
+    demux_types: Sequence[str] = ("single", "double", "already"),
+    sample_mapping: Optional[Mapping[str, str]] = None,
     save_path: str | Path | None = None,
     normalize_hmm: bool = False,
-
     sort_by: str = "gpc",
     bins: Optional[Dict[str, Any]] = None,
-
     deaminase: bool = False,
     min_signal: float = 0.0,
-
     # ---- fixed tick label controls (counts, not spacing)
     n_xticks_hmm: int = 10,
     n_xticks_any_c: int = 8,
     n_xticks_gpc: int = 8,
     n_xticks_cpg: int = 8,
     n_xticks_a: int = 8,
+    index_col_suffix: str | None = None,
 ):
     """
     Makes a multi-panel clustermap per (sample, reference):
-    HMM panel (always) + optional raw panels for any_C, GpC, CpG, and A sites.
+    HMM panel (always) + optional raw panels for C, GpC, CpG, and A sites.
 
     Panels are added only if the corresponding site mask exists AND has >0 sites.
 
     sort_by options:
-        'gpc', 'cpg', 'any_c', 'any_a', 'gpc_cpg', 'none', or 'obs:<col>'
+        'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
     """
+
     def pick_xticks(labels: np.ndarray, n_ticks: int):
+        """Pick tick indices/labels from an array."""
         if labels.size == 0:
             return [], []
         idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
         idx = np.unique(idx)
         return idx.tolist(), labels[idx].tolist()
-
+
+    # Helper: build a True mask if filter is inactive or column missing
+    def _mask_or_true(series_name: str, predicate):
+        """Return a mask from predicate or an all-True mask."""
+        if series_name not in adata.obs:
+            return pd.Series(True, index=adata.obs.index)
+        s = adata.obs[series_name]
+        try:
+            return predicate(s)
+        except Exception:
+            # Fallback: all True if bad dtype / predicate failure
+            return pd.Series(True, index=adata.obs.index)
+
     results = []
     signal_type = "deamination" if deaminase else "methylation"
 
     for ref in adata.obs[reference_col].cat.categories:
         for sample in adata.obs[sample_col].cat.categories:
+            # Optionally remap sample label for display
+            display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
+            # Row-level masks (obs)
+            qmask = _mask_or_true(
+                "read_quality",
+                (lambda s: s >= float(min_quality))
+                if (min_quality is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+            lm_mask = _mask_or_true(
+                "mapped_length",
+                (lambda s: s >= float(min_length))
+                if (min_length is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+            lrr_mask = _mask_or_true(
+                "mapped_length_to_reference_length_ratio",
+                (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
+                if (min_mapped_length_to_reference_length_ratio is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+
+            demux_mask = _mask_or_true(
+                "demux_type",
+                (lambda s: s.astype("string").isin(list(demux_types)))
+                if (demux_types is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+
+            ref_mask = adata.obs[reference_col] == ref
+            sample_mask = adata.obs[sample_col] == sample
+
+            row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
+
+            if not bool(row_mask.any()):
+                print(
+                    f"No reads for {display_sample} - {ref} after read quality and length filtering"
+                )
+                continue
 
             try:
                 # ---- subset reads ----
-                subset = adata[
-                    (adata.obs[reference_col] == ref) &
-                    (adata.obs[sample_col] == sample) &
-                    (adata.obs["read_quality"] >= min_quality) &
-                    (adata.obs["read_length"] >= min_length) &
-                    (
-                        adata.obs["mapped_length_to_reference_length_ratio"]
-                        > min_mapped_length_to_reference_length_ratio
-                    )
-                ]
-
-                # ---- valid fraction filter ----
-                vf_key = f"{ref}_valid_fraction"
-                if vf_key in subset.var:
-                    mask = subset.var[vf_key].astype(float) > float(min_position_valid_fraction)
-                    subset = subset[:, mask]
+                subset = adata[row_mask, :].copy()
+
+                # Column-level mask (var)
+                if min_position_valid_fraction is not None:
+                    valid_key = f"{ref}_valid_fraction"
+                    if valid_key in subset.var:
+                        v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
+                        col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
+                        if col_mask.any():
+                            subset = subset[:, col_mask].copy()
+                        else:
+                            print(
+                                f"No positions left after valid_fraction filter for {display_sample} - {ref}"
+                            )
+                            continue
 
                 if subset.shape[0] == 0:
+                    print(f"No reads left after filtering for {display_sample} - {ref}")
                     continue
 
                 # ---- bins ----
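
Note: the `_mask_or_true` pattern above composes per-column boolean Series with `&`, falling back to all-True when a filter is inactive or a column is missing. A toy illustration under hypothetical data (not part of the package):

    import pandas as pd

    obs = pd.DataFrame({
        "read_quality": [30, 10, 25],
        "demux_type": ["single", "double", "unknown"],
    })

    def mask_or_true(name, predicate):
        # All-True mask when the column is missing or the predicate fails.
        if name not in obs:
            return pd.Series(True, index=obs.index)
        try:
            return predicate(obs[name])
        except Exception:
            return pd.Series(True, index=obs.index)

    qmask = mask_or_true("read_quality", lambda s: s >= 20.0)
    dmask = mask_or_true("demux_type", lambda s: s.astype("string").isin(["single", "double"]))
    print((qmask & dmask).tolist())  # [True, False, False]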
@@ -358,6 +478,7 @@ def combined_hmm_raw_clustermap(
 
                 # ---- site masks (robust) ----
                 def _sites(*keys):
+                    """Return indices for the first matching site key."""
                     for k in keys:
                         if k in subset.var:
                             return np.where(subset.var[k].values)[0]
@@ -368,13 +489,14 @@ def combined_hmm_raw_clustermap(
                 any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
                 any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
 
-                def _labels(sites):
-                    return subset.var_names[sites].astype(int) if sites.size else np.array([])
-
-                gpc_labels = _labels(gpc_sites)
-                cpg_labels = _labels(cpg_sites)
-                any_c_labels = _labels(any_c_sites)
-                any_a_labels = _labels(any_a_sites)
+                # ---- labels via _select_labels ----
+                # HMM uses *all* columns
+                hmm_sites = np.arange(subset.n_vars, dtype=int)
+                hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
+                gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
+                cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
+                any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
+                any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
 
                 # storage
                 stacked_hmm = []
@@ -411,11 +533,11 @@ def combined_hmm_raw_clustermap(
                         linkage = sch.linkage(sb[:, cpg_sites].layers[layer_cpg], method="ward")
                         order = sch.leaves_list(linkage)
 
-                    elif sort_by == "any_c" and any_c_sites.size:
-                        linkage = sch.linkage(sb[:, any_c_sites].layers[layer_any_c], method="ward")
+                    elif sort_by == "c" and any_c_sites.size:
+                        linkage = sch.linkage(sb[:, any_c_sites].layers[layer_c], method="ward")
                         order = sch.leaves_list(linkage)
 
-                    elif sort_by == "any_a" and any_a_sites.size:
+                    elif sort_by == "a" and any_a_sites.size:
                         linkage = sch.linkage(sb[:, any_a_sites].layers[layer_a], method="ward")
                         order = sch.leaves_list(linkage)
 
@@ -423,6 +545,12 @@ def combined_hmm_raw_clustermap(
                         linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
                         order = sch.leaves_list(linkage)
 
+                    elif sort_by == "hmm" and hmm_sites.size:
+                        linkage = sch.linkage(
+                            sb[:, hmm_sites].layers[hmm_feature_layer], method="ward"
+                        )
+                        order = sch.leaves_list(linkage)
+
                     else:
                         order = np.arange(n)
 
@@ -431,7 +559,7 @@ def combined_hmm_raw_clustermap(
                 # ---- collect matrices ----
                 stacked_hmm.append(sb.layers[hmm_feature_layer])
                 if any_c_sites.size:
-                    stacked_any_c.append(sb[:, any_c_sites].layers[layer_any_c])
+                    stacked_any_c.append(sb[:, any_c_sites].layers[layer_c])
                 if gpc_sites.size:
                     stacked_gpc.append(sb[:, gpc_sites].layers[layer_gpc])
                 if cpg_sites.size:
@@ -446,46 +574,62 @@ def combined_hmm_raw_clustermap(
 
                 # ---------------- stack ----------------
                 hmm_matrix = np.vstack(stacked_hmm)
-                mean_hmm = normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
+                mean_hmm = (
+                    normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
+                )
 
                 panels = [
-                    ("HMM", hmm_matrix, subset.var_names.astype(int), cmap_hmm, mean_hmm, n_xticks_hmm),
+                    (
+                        f"HMM - {hmm_feature_layer}",
+                        hmm_matrix,
+                        hmm_labels,
+                        cmap_hmm,
+                        mean_hmm,
+                        n_xticks_hmm,
+                    ),
                 ]
 
                 if stacked_any_c:
                     m = np.vstack(stacked_any_c)
-                    panels.append(("any_C", m, any_c_labels, cmap_any_c, methylation_fraction(m), n_xticks_any_c))
+                    panels.append(
+                        ("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c)
+                    )
 
                 if stacked_gpc:
                     m = np.vstack(stacked_gpc)
-                    panels.append(("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc))
+                    panels.append(
+                        ("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc)
+                    )
 
                 if stacked_cpg:
                     m = np.vstack(stacked_cpg)
-                    panels.append(("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg))
+                    panels.append(
+                        ("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg)
+                    )
 
                 if stacked_any_a:
                     m = np.vstack(stacked_any_a)
-                    panels.append(("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a))
+                    panels.append(
+                        ("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a)
+                    )
 
                 # ---------------- plotting ----------------
                 n_panels = len(panels)
                 fig = plt.figure(figsize=(4.5 * n_panels, 10))
                 gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
-                fig.suptitle(f"{sample} — {ref} — {total_reads} reads ({signal_type})",
-                             fontsize=14, y=0.98)
+                fig.suptitle(
+                    f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
+                )
 
                 axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
                 axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
 
                 for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
-
                     # ---- your clean barplot ----
                     clean_barplot(axes_bar[i], mean_vec, name)
 
                     # ---- heatmap ----
-                    sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
-                                yticklabels=False, cbar=False)
+                    sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i], yticklabels=False, cbar=False)
 
                     # ---- xticks ----
                     xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
@@ -509,6 +653,7 @@ def combined_hmm_raw_clustermap(
 
             except Exception:
                 import traceback
+
                 traceback.print_exc()
                 continue
 
@@ -628,7 +773,7 @@ def combined_hmm_raw_clustermap(
 #                     order = np.arange(num_reads)
 #                 elif sort_by == "any_a":
 #                     linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
-#                     order = sch.leaves_list(linkage)
+#                     order = sch.leaves_list(linkage)
 #                 else:
 #                     raise ValueError(f"Unsupported sort_by option: {sort_by}")
 
@@ -657,13 +802,13 @@ def combined_hmm_raw_clustermap(
 #                     order = np.arange(num_reads)
 #                 elif sort_by == "any_a":
 #                     linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
-#                     order = sch.leaves_list(linkage)
+#                     order = sch.leaves_list(linkage)
 #                 else:
 #                     raise ValueError(f"Unsupported sort_by option: {sort_by}")
-
+
 #                 stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
-
-
+
+
 #                 row_labels.extend([bin_label] * num_reads)
 #                 bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
 #                 last_idx += num_reads
@@ -686,7 +831,7 @@ def combined_hmm_raw_clustermap(
 #             if any_a_matrix.size > 0:
 #                 mean_any_a = methylation_fraction(any_a_matrix)
 #                 gs_dim += 1
-
+
 
 #             fig = plt.figure(figsize=(18, 12))
 #             gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
@@ -718,8 +863,8 @@ def combined_hmm_raw_clustermap(
 #                 sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
 #                 axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
 #                 for boundary in bin_boundaries[:-1]:
-#                     axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
-#                 current_ax +=1
+#                     axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
+#                 current_ax +=1
 
 #             results.append({
 #                 "sample": sample,
@@ -731,7 +876,7 @@ def combined_hmm_raw_clustermap(
 #                 "bin_labels": bin_labels,
 #                 "bin_boundaries": bin_boundaries,
 #                 "percentages": percentages
-#                 })
+#                 })
 
 #         if stacked_any_a:
 #             if any_a_matrix.size > 0:
@@ -751,7 +896,7 @@ def combined_hmm_raw_clustermap(
 #                 "bin_labels": bin_labels,
 #                 "bin_boundaries": bin_boundaries,
 #                 "percentages": percentages
-#                 })
+#                 })
 
 #         plt.tight_layout()
 
@@ -769,7 +914,7 @@ def combined_hmm_raw_clustermap(
 #             print(f"Summary for {sample} - {ref}:")
 #             for bin_label, percent in percentages.items():
 #                 print(f" - {bin_label}: {percent:.1f}%")
-
+
 #         adata.uns['clustermap_results'] = results
 
 #     except Exception as e:
@@ -777,52 +922,41 @@ def combined_hmm_raw_clustermap(
 #         traceback.print_exc()
 #         continue
 
-    def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
-        """
-        Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
-        Always includes 0 and n_positions-1 when possible.
-        """
-        n_ticks = int(max(2, n_ticks))
-        if n_positions <= n_ticks:
-            return np.arange(n_positions)
-
-        # linspace gives fixed count
-        pos = np.linspace(0, n_positions - 1, n_ticks)
-        return np.unique(np.round(pos).astype(int))
 
 def combined_raw_clustermap(
     adata,
     sample_col: str = "Sample_Names",
     reference_col: str = "Reference_strand",
     mod_target_bases: Sequence[str] = ("GpC", "CpG"),
-    layer_any_c: str = "nan0_0minus1",
+    layer_c: str = "nan0_0minus1",
     layer_gpc: str = "nan0_0minus1",
     layer_cpg: str = "nan0_0minus1",
     layer_a: str = "nan0_0minus1",
-    cmap_any_c: str = "coolwarm",
+    cmap_c: str = "coolwarm",
     cmap_gpc: str = "coolwarm",
     cmap_cpg: str = "viridis",
     cmap_a: str = "coolwarm",
-    min_quality: float = 20,
-    min_length: int = 200,
-    min_mapped_length_to_reference_length_ratio: float = 0.8,
-    min_position_valid_fraction: float = 0.5,
+    min_quality: float | None = 20,
+    min_length: int | None = 200,
+    min_mapped_length_to_reference_length_ratio: float | None = 0,
+    min_position_valid_fraction: float | None = 0,
+    demux_types: Sequence[str] = ("single", "double", "already"),
    sample_mapping: Optional[Mapping[str, str]] = None,
     save_path: str | Path | None = None,
-    sort_by: str = "gpc",  # 'gpc','cpg','any_c','gpc_cpg','any_a','none','obs:<col>'
+    sort_by: str = "gpc",  # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
     bins: Optional[Dict[str, Any]] = None,
     deaminase: bool = False,
     min_signal: float = 0,
-    # NEW tick controls
     n_xticks_any_c: int = 10,
     n_xticks_gpc: int = 10,
     n_xticks_cpg: int = 10,
     n_xticks_any_a: int = 10,
     xtick_rotation: int = 90,
     xtick_fontsize: int = 9,
+    index_col_suffix: str | None = None,
 ):
     """
-    Plot stacked heatmaps + per-position mean barplots for any_C, GpC, CpG, and optional A.
+    Plot stacked heatmaps + per-position mean barplots for C, GpC, CpG, and optional A.
 
     Key fixes vs old version:
     - order computed ONCE per bin, applied to all matrices
@@ -838,6 +972,18 @@ def combined_raw_clustermap(
     One entry per (sample, ref) plot with matrices + bin metadata.
     """
 
+    # Helper: build a True mask if filter is inactive or column missing
+    def _mask_or_true(series_name: str, predicate):
+        """Return a mask from predicate or an all-True mask."""
+        if series_name not in adata.obs:
+            return pd.Series(True, index=adata.obs.index)
+        s = adata.obs[series_name]
+        try:
+            return predicate(s)
+        except Exception:
+            # Fallback: all True if bad dtype / predicate failure
+            return pd.Series(True, index=adata.obs.index)
+
     results: List[Dict[str, Any]] = []
     save_path = Path(save_path) if save_path is not None else None
     if save_path is not None:
@@ -856,24 +1002,63 @@ def combined_raw_clustermap(
 
     for ref in adata.obs[reference_col].cat.categories:
         for sample in adata.obs[sample_col].cat.categories:
-
             # Optionally remap sample label for display
             display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
 
-            try:
-                subset = adata[
-                    (adata.obs[reference_col] == ref) &
-                    (adata.obs[sample_col] == sample) &
-                    (adata.obs["read_quality"] >= min_quality) &
-                    (adata.obs["mapped_length"] >= min_length) &
-                    (adata.obs["mapped_length_to_reference_length_ratio"] >= min_mapped_length_to_reference_length_ratio)
-                ]
+            # Row-level masks (obs)
+            qmask = _mask_or_true(
+                "read_quality",
+                (lambda s: s >= float(min_quality))
+                if (min_quality is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+            lm_mask = _mask_or_true(
+                "mapped_length",
+                (lambda s: s >= float(min_length))
+                if (min_length is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+            lrr_mask = _mask_or_true(
+                "mapped_length_to_reference_length_ratio",
+                (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
+                if (min_mapped_length_to_reference_length_ratio is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+
+            demux_mask = _mask_or_true(
+                "demux_type",
+                (lambda s: s.astype("string").isin(list(demux_types)))
+                if (demux_types is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+
+            ref_mask = adata.obs[reference_col] == ref
+            sample_mask = adata.obs[sample_col] == sample
+
+            row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
+
+            if not bool(row_mask.any()):
+                print(
+                    f"No reads for {display_sample} - {ref} after read quality and length filtering"
+                )
+                continue
 
-                # position-level mask
-                valid_key = f"{ref}_valid_fraction"
-                if valid_key in subset.var:
-                    mask = subset.var[valid_key].astype(float).values > float(min_position_valid_fraction)
-                    subset = subset[:, mask]
+            try:
+                subset = adata[row_mask, :].copy()
+
+                # Column-level mask (var)
+                if min_position_valid_fraction is not None:
+                    valid_key = f"{ref}_valid_fraction"
+                    if valid_key in subset.var:
+                        v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
+                        col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
+                        if col_mask.any():
+                            subset = subset[:, col_mask].copy()
+                        else:
+                            print(
+                                f"No positions left after valid_fraction filter for {display_sample} - {ref}"
+                            )
+                            continue
 
                 if subset.shape[0] == 0:
                     print(f"No reads left after filtering for {display_sample} - {ref}")
893
1078
 
894
1079
  if include_any_c:
895
1080
  any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
896
- gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
897
- cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
1081
+ gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
1082
+ cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
898
1083
 
899
1084
  num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
900
1085
 
901
- any_c_labels = subset.var_names[any_c_sites].astype(str)
902
- gpc_labels = subset.var_names[gpc_sites].astype(str)
903
- cpg_labels = subset.var_names[cpg_sites].astype(str)
1086
+ any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
1087
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
1088
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
904
1089
 
905
1090
  if include_any_a:
906
1091
  any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
907
1092
  num_any_a = len(any_a_sites)
908
- any_a_labels = subset.var_names[any_a_sites].astype(str)
1093
+ any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
909
1094
 
910
1095
  stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
911
1096
  row_labels, bin_labels, bin_boundaries = [], [], []
@@ -932,23 +1117,31 @@ def combined_raw_clustermap(
932
1117
  order = np.argsort(subset_bin.obs[colname].values)
933
1118
 
934
1119
  elif sort_by == "gpc" and num_gpc > 0:
935
- linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
1120
+ linkage = sch.linkage(
1121
+ subset_bin[:, gpc_sites].layers[layer_gpc], method="ward"
1122
+ )
936
1123
  order = sch.leaves_list(linkage)
937
1124
 
938
1125
  elif sort_by == "cpg" and num_cpg > 0:
939
- linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
1126
+ linkage = sch.linkage(
1127
+ subset_bin[:, cpg_sites].layers[layer_cpg], method="ward"
1128
+ )
940
1129
  order = sch.leaves_list(linkage)
941
1130
 
942
- elif sort_by == "any_c" and num_any_c > 0:
943
- linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
1131
+ elif sort_by == "c" and num_any_c > 0:
1132
+ linkage = sch.linkage(
1133
+ subset_bin[:, any_c_sites].layers[layer_c], method="ward"
1134
+ )
944
1135
  order = sch.leaves_list(linkage)
945
1136
 
946
1137
  elif sort_by == "gpc_cpg":
947
1138
  linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
948
1139
  order = sch.leaves_list(linkage)
949
1140
 
950
- elif sort_by == "any_a" and num_any_a > 0:
951
- linkage = sch.linkage(subset_bin[:, any_a_sites].layers[layer_a], method="ward")
1141
+ elif sort_by == "a" and num_any_a > 0:
1142
+ linkage = sch.linkage(
1143
+ subset_bin[:, any_a_sites].layers[layer_a], method="ward"
1144
+ )
952
1145
  order = sch.leaves_list(linkage)
953
1146
 
954
1147
  elif sort_by == "none":
@@ -961,7 +1154,7 @@ def combined_raw_clustermap(
961
1154
 
962
1155
  # stack consistently
963
1156
  if include_any_c and num_any_c > 0:
964
- stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_any_c])
1157
+ stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_c])
965
1158
  if include_any_c and num_gpc > 0:
966
1159
  stacked_gpc.append(subset_bin[:, gpc_sites].layers[layer_gpc])
967
1160
  if include_any_c and num_cpg > 0:
@@ -981,57 +1174,65 @@ def combined_raw_clustermap(
981
1174
 
982
1175
  if include_any_c and stacked_any_c:
983
1176
  any_c_matrix = np.vstack(stacked_any_c)
984
- gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
985
- cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
1177
+ gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
1178
+ cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
986
1179
 
987
1180
  mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
988
- mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
989
- mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
1181
+ mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
1182
+ mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
990
1183
 
991
1184
  if any_c_matrix.size:
992
- blocks.append(dict(
993
- name="any_c",
994
- matrix=any_c_matrix,
995
- mean=mean_any_c,
996
- labels=any_c_labels,
997
- cmap=cmap_any_c,
998
- n_xticks=n_xticks_any_c,
999
- title="any C site Modification Signal"
1000
- ))
1185
+ blocks.append(
1186
+ dict(
1187
+ name="c",
1188
+ matrix=any_c_matrix,
1189
+ mean=mean_any_c,
1190
+ labels=any_c_labels,
1191
+ cmap=cmap_c,
1192
+ n_xticks=n_xticks_any_c,
1193
+ title="any C site Modification Signal",
1194
+ )
1195
+ )
1001
1196
  if gpc_matrix.size:
1002
- blocks.append(dict(
1003
- name="gpc",
1004
- matrix=gpc_matrix,
1005
- mean=mean_gpc,
1006
- labels=gpc_labels,
1007
- cmap=cmap_gpc,
1008
- n_xticks=n_xticks_gpc,
1009
- title="GpC Modification Signal"
1010
- ))
1197
+ blocks.append(
1198
+ dict(
1199
+ name="gpc",
1200
+ matrix=gpc_matrix,
1201
+ mean=mean_gpc,
1202
+ labels=gpc_labels,
1203
+ cmap=cmap_gpc,
1204
+ n_xticks=n_xticks_gpc,
1205
+ title="GpC Modification Signal",
1206
+ )
1207
+ )
1011
1208
  if cpg_matrix.size:
1012
- blocks.append(dict(
1013
- name="cpg",
1014
- matrix=cpg_matrix,
1015
- mean=mean_cpg,
1016
- labels=cpg_labels,
1017
- cmap=cmap_cpg,
1018
- n_xticks=n_xticks_cpg,
1019
- title="CpG Modification Signal"
1020
- ))
1209
+ blocks.append(
1210
+ dict(
1211
+ name="cpg",
1212
+ matrix=cpg_matrix,
1213
+ mean=mean_cpg,
1214
+ labels=cpg_labels,
1215
+ cmap=cmap_cpg,
1216
+ n_xticks=n_xticks_cpg,
1217
+ title="CpG Modification Signal",
1218
+ )
1219
+ )
1021
1220
 
1022
1221
  if include_any_a and stacked_any_a:
1023
1222
  any_a_matrix = np.vstack(stacked_any_a)
1024
1223
  mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
1025
1224
  if any_a_matrix.size:
1026
- blocks.append(dict(
1027
- name="any_a",
1028
- matrix=any_a_matrix,
1029
- mean=mean_any_a,
1030
- labels=any_a_labels,
1031
- cmap=cmap_a,
1032
- n_xticks=n_xticks_any_a,
1033
- title="any A site Modification Signal"
1034
- ))
1225
+ blocks.append(
1226
+ dict(
1227
+ name="a",
1228
+ matrix=any_a_matrix,
1229
+ mean=mean_any_a,
1230
+ labels=any_a_labels,
1231
+ cmap=cmap_a,
1232
+ n_xticks=n_xticks_any_a,
1233
+ title="any A site Modification Signal",
1234
+ )
1235
+ )
1035
1236
 
1036
1237
  if not blocks:
1037
1238
  print(f"No matrices to plot for {display_sample} - {ref}")
@@ -1043,7 +1244,7 @@ def combined_raw_clustermap(
                 fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
 
                 axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
-                axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
+                axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
 
                 # ----------------------------
                 # plot blocks
@@ -1059,20 +1260,14 @@ def combined_raw_clustermap(
 
                     # heatmap
                     sns.heatmap(
-                        mat,
-                        cmap=blk["cmap"],
-                        ax=axes_heat[i],
-                        yticklabels=False,
-                        cbar=False
+                        mat, cmap=blk["cmap"], ax=axes_heat[i], yticklabels=False, cbar=False
                     )
 
                     # fixed tick labels
                     tick_pos = _fixed_tick_positions(len(labels), n_xticks)
                     axes_heat[i].set_xticks(tick_pos)
                     axes_heat[i].set_xticklabels(
-                        labels[tick_pos],
-                        rotation=xtick_rotation,
-                        fontsize=xtick_fontsize
+                        labels[tick_pos], rotation=xtick_rotation, fontsize=xtick_fontsize
                     )
 
                     # bin separators
@@ -1085,7 +1280,12 @@ def combined_raw_clustermap(
 
                 # save or show
                 if save_path is not None:
-                    safe_name = f"{ref}__{display_sample}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
+                    safe_name = (
+                        f"{ref}__{display_sample}".replace("=", "")
+                        .replace("__", "_")
+                        .replace(",", "_")
+                        .replace(" ", "_")
+                    )
                     out_file = save_path / f"{safe_name}.png"
                     fig.savefig(out_file, dpi=300)
                     plt.close(fig)
@@ -1111,20 +1311,15 @@ def combined_raw_clustermap(
                 for bin_label, percent in percentages.items():
                     print(f" - {bin_label}: {percent:.1f}%")
 
-            except Exception as e:
+            except Exception:
                 import traceback
+
                 traceback.print_exc()
                 continue
 
-    # store once at the end (HDF5 safe)
-    # matrices won't be HDF5-safe; store only metadata + maybe hit counts
-    # adata.uns["clustermap_results"] = [
-    #     {k: v for k, v in r.items() if not k.endswith("_matrix")}
-    #     for r in results
-    # ]
-
     return results
 
+
 def plot_hmm_layers_rolling_by_sample_ref(
     adata,
     layers: Optional[Sequence[str]] = None,
@@ -1141,7 +1336,7 @@ def plot_hmm_layers_rolling_by_sample_ref(
     output_dir: Optional[str] = None,
     save: bool = True,
     show_raw: bool = False,
-    cmap: str = "tab10",
+    cmap: str = "tab20",
     use_var_coords: bool = True,
 ):
     """
@@ -1191,7 +1386,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
 
     # --- basic checks / defaults ---
     if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
-        raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")
+        raise ValueError(
+            f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
+        )
 
     # canonicalize samples / refs
     if samples is None:
@@ -1214,7 +1411,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
     if layers is None:
         layers = list(adata.layers.keys())
         if len(layers) == 0:
-            raise ValueError("No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot.")
+            raise ValueError(
+                "No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
+            )
     layers = list(layers)
 
     # x coordinates (positions)
@@ -1253,19 +1452,29 @@ def plot_hmm_layers_rolling_by_sample_ref(
 
         fig_w = figsize_per_cell[0] * ncols
        fig_h = figsize_per_cell[1] * nrows
-        fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
-                                 figsize=(fig_w, fig_h), dpi=dpi,
-                                 squeeze=False)
+        fig, axes = plt.subplots(
+            nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
+        )
 
         for r_idx, sample_name in enumerate(chunk):
             for c_idx, ref_name in enumerate(refs_all):
                 ax = axes[r_idx][c_idx]
 
                 # subset adata
-                mask = (adata.obs[sample_col].values == sample_name) & (adata.obs[ref_col].values == ref_name)
+                mask = (adata.obs[sample_col].values == sample_name) & (
+                    adata.obs[ref_col].values == ref_name
+                )
                 sub = adata[mask]
                 if sub.n_obs == 0:
-                    ax.text(0.5, 0.5, "No reads", ha="center", va="center", transform=ax.transAxes, color="gray")
+                    ax.text(
+                        0.5,
+                        0.5,
+                        "No reads",
+                        ha="center",
+                        va="center",
+                        transform=ax.transAxes,
+                        color="gray",
+                    )
                     ax.set_xticks([])
                     ax.set_yticks([])
                     if r_idx == 0:
1315
1524
  smoothed = col_mean
1316
1525
  else:
1317
1526
  ser = pd.Series(col_mean)
1318
- smoothed = ser.rolling(window=window, min_periods=min_periods, center=center).mean().to_numpy()
1527
+ smoothed = (
1528
+ ser.rolling(window=window, min_periods=min_periods, center=center)
1529
+ .mean()
1530
+ .to_numpy()
1531
+ )
1319
1532
 
1320
1533
  # x axis: x_coords (trim/pad to match length)
1321
1534
  L = len(col_mean)
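
Note: the smoothing step above is a plain pandas rolling mean over the per-position means. A toy example of its edge behavior with `center=True` and `min_periods=1`:

    import pandas as pd

    col_mean = [0.0, 1.0, 2.0, 3.0, 4.0]
    smoothed = (
        pd.Series(col_mean)
        .rolling(window=3, min_periods=1, center=True)
        .mean()
        .to_numpy()
    )
    print(smoothed)  # [0.5 1.  2.  3.  3.5]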
@@ -1325,7 +1538,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
                     if show_raw:
                         ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)
 
-                    ax.plot(x, smoothed[:L], label=layer, color=colors[li], linewidth=1.2, alpha=0.95, zorder=2)
+                    ax.plot(
+                        x,
+                        smoothed[:L],
+                        label=layer,
+                        color=colors[li],
+                        linewidth=1.2,
+                        alpha=0.95,
+                        zorder=2,
+                    )
                     plotted_any = True
 
                 # labels / titles
@@ -1343,11 +1564,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
 
                 ax.grid(True, alpha=0.2)
 
-        fig.suptitle(f"Rolling mean of layer positional means (window={window}) — page {page+1}/{total_pages}", fontsize=11, y=0.995)
+        fig.suptitle(
+            f"Rolling mean of layer positional means (window={window}) — page {page + 1}/{total_pages}",
+            fontsize=11,
+            y=0.995,
+        )
         fig.tight_layout(rect=[0, 0, 1, 0.97])
 
         if save:
-            fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
+            fname = os.path.join(outdir, f"hmm_layers_rolling_page{page + 1}.png")
             plt.savefig(fname, bbox_inches="tight", dpi=dpi)
             saved_files.append(fname)
         else: