smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,16 +1,20 @@
1
1
  from __future__ import annotations
2
2
 
3
- import numpy as np
4
- import seaborn as sns
5
- import matplotlib.pyplot as plt
6
- import scipy.cluster.hierarchy as sch
7
- import matplotlib.gridspec as gridspec
8
- import os
9
3
  import math
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
7
+
8
+ import numpy as np
10
9
  import pandas as pd
10
+ import scipy.cluster.hierarchy as sch
11
+
12
+ from smftools.optional_imports import require
13
+
14
+ gridspec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
15
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
16
+ sns = require("seaborn", extra="plotting", purpose="plot styling")
11
17
 
12
- from typing import Optional, Mapping, Sequence, Any, Dict, List, Tuple
13
- from pathlib import Path
14
18
 
15
19
  def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
16
20
  """
@@ -25,6 +29,7 @@ def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
25
29
  pos = np.linspace(0, n_positions - 1, n_ticks)
26
30
  return np.unique(np.round(pos).astype(int))
27
31
 
32
+
28
33
  def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix: str | None):
29
34
  """
30
35
  Select tick labels for the heatmap axis.
@@ -65,11 +70,21 @@ def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix:
65
70
  labels = subset.var[colname].astype(str).values
66
71
  return labels[sites]
67
72
 
73
+
68
74
  def normalized_mean(matrix: np.ndarray) -> np.ndarray:
75
+ """Compute normalized column means for a matrix.
76
+
77
+ Args:
78
+ matrix: Input matrix.
79
+
80
+ Returns:
81
+ 1D array of normalized means.
82
+ """
69
83
  mean = np.nanmean(matrix, axis=0)
70
84
  denom = (mean.max() - mean.min()) + 1e-9
71
85
  return (mean - mean.min()) / denom
72
86
 
87
+
73
88
  def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
74
89
  """
75
90
  Fraction methylated per column.
@@ -84,14 +99,20 @@ def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
84
99
  valid = valid_mask.sum(axis=0)
85
100
 
86
101
  return np.divide(
87
- methylated, valid,
88
- out=np.zeros_like(methylated, dtype=float),
89
- where=valid != 0
102
+ methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0
90
103
  )
91
104
 
105
+
92
106
  def clean_barplot(ax, mean_values, title):
107
+ """Format a barplot with consistent axes and labels.
108
+
109
+ Args:
110
+ ax: Matplotlib axes.
111
+ mean_values: Values to plot.
112
+ title: Plot title.
113
+ """
93
114
  x = np.arange(len(mean_values))
94
- ax.bar(x, mean_values, color="gray", width=1.0, align='edge')
115
+ ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
95
116
  ax.set_xlim(0, len(mean_values))
96
117
  ax.set_ylim(0, 1)
97
118
  ax.set_yticks([0.0, 0.5, 1.0])
@@ -100,9 +121,10 @@ def clean_barplot(ax, mean_values, title):
100
121
 
101
122
  # Hide all spines except left
102
123
  for spine_name, spine in ax.spines.items():
103
- spine.set_visible(spine_name == 'left')
124
+ spine.set_visible(spine_name == "left")
125
+
126
+ ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
104
127
 
105
- ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
106
128
 
107
129
  # def combined_hmm_raw_clustermap(
108
130
  # adata,
@@ -145,7 +167,7 @@ def clean_barplot(ax, mean_values, title):
145
167
  # (adata.obs['read_length'] >= min_length) &
146
168
  # (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
147
169
  # ]
148
-
170
+
149
171
  # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
150
172
  # subset = subset[:, mask]
151
173
 
@@ -257,7 +279,7 @@ def clean_barplot(ax, mean_values, title):
257
279
  # clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
258
280
  # clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
259
281
  # clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
260
-
282
+
261
283
  # hmm_labels = subset.var_names.astype(int)
262
284
  # hmm_label_spacing = 150
263
285
  # sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
@@ -311,7 +333,7 @@ def clean_barplot(ax, mean_values, title):
311
333
  # "bin_boundaries": bin_boundaries,
312
334
  # "percentages": percentages
313
335
  # })
314
-
336
+
315
337
  # #adata.uns['clustermap_results'] = results
316
338
 
317
339
  # except Exception as e:
@@ -319,45 +341,39 @@ def clean_barplot(ax, mean_values, title):
319
341
  # traceback.print_exc()
320
342
  # continue
321
343
 
344
+
322
345
  def combined_hmm_raw_clustermap(
323
346
  adata,
324
347
  sample_col: str = "Sample_Names",
325
348
  reference_col: str = "Reference_strand",
326
-
327
349
  hmm_feature_layer: str = "hmm_combined",
328
-
329
350
  layer_gpc: str = "nan0_0minus1",
330
351
  layer_cpg: str = "nan0_0minus1",
331
352
  layer_c: str = "nan0_0minus1",
332
353
  layer_a: str = "nan0_0minus1",
333
-
334
354
  cmap_hmm: str = "tab10",
335
355
  cmap_gpc: str = "coolwarm",
336
356
  cmap_cpg: str = "viridis",
337
357
  cmap_c: str = "coolwarm",
338
358
  cmap_a: str = "coolwarm",
339
-
340
359
  min_quality: int = 20,
341
360
  min_length: int = 200,
342
361
  min_mapped_length_to_reference_length_ratio: float = 0.8,
343
362
  min_position_valid_fraction: float = 0.5,
344
-
363
+ demux_types: Sequence[str] = ("single", "double", "already"),
364
+ sample_mapping: Optional[Mapping[str, str]] = None,
345
365
  save_path: str | Path | None = None,
346
366
  normalize_hmm: bool = False,
347
-
348
367
  sort_by: str = "gpc",
349
368
  bins: Optional[Dict[str, Any]] = None,
350
-
351
369
  deaminase: bool = False,
352
370
  min_signal: float = 0.0,
353
-
354
371
  # ---- fixed tick label controls (counts, not spacing)
355
372
  n_xticks_hmm: int = 10,
356
373
  n_xticks_any_c: int = 8,
357
374
  n_xticks_gpc: int = 8,
358
375
  n_xticks_cpg: int = 8,
359
376
  n_xticks_a: int = 8,
360
-
361
377
  index_col_suffix: str | None = None,
362
378
  ):
363
379
  """
@@ -369,39 +385,92 @@ def combined_hmm_raw_clustermap(
369
385
  sort_by options:
370
386
  'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
371
387
  """
388
+
372
389
  def pick_xticks(labels: np.ndarray, n_ticks: int):
390
+ """Pick tick indices/labels from an array."""
373
391
  if labels.size == 0:
374
392
  return [], []
375
393
  idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
376
394
  idx = np.unique(idx)
377
395
  return idx.tolist(), labels[idx].tolist()
378
-
396
+
397
+ # Helper: build a True mask if filter is inactive or column missing
398
+ def _mask_or_true(series_name: str, predicate):
399
+ """Return a mask from predicate or an all-True mask."""
400
+ if series_name not in adata.obs:
401
+ return pd.Series(True, index=adata.obs.index)
402
+ s = adata.obs[series_name]
403
+ try:
404
+ return predicate(s)
405
+ except Exception:
406
+ # Fallback: all True if bad dtype / predicate failure
407
+ return pd.Series(True, index=adata.obs.index)
408
+
379
409
  results = []
380
410
  signal_type = "deamination" if deaminase else "methylation"
381
411
 
382
412
  for ref in adata.obs[reference_col].cat.categories:
383
413
  for sample in adata.obs[sample_col].cat.categories:
414
+ # Optionally remap sample label for display
415
+ display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
416
+ # Row-level masks (obs)
417
+ qmask = _mask_or_true(
418
+ "read_quality",
419
+ (lambda s: s >= float(min_quality))
420
+ if (min_quality is not None)
421
+ else (lambda s: pd.Series(True, index=s.index)),
422
+ )
423
+ lm_mask = _mask_or_true(
424
+ "mapped_length",
425
+ (lambda s: s >= float(min_length))
426
+ if (min_length is not None)
427
+ else (lambda s: pd.Series(True, index=s.index)),
428
+ )
429
+ lrr_mask = _mask_or_true(
430
+ "mapped_length_to_reference_length_ratio",
431
+ (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
432
+ if (min_mapped_length_to_reference_length_ratio is not None)
433
+ else (lambda s: pd.Series(True, index=s.index)),
434
+ )
435
+
436
+ demux_mask = _mask_or_true(
437
+ "demux_type",
438
+ (lambda s: s.astype("string").isin(list(demux_types)))
439
+ if (demux_types is not None)
440
+ else (lambda s: pd.Series(True, index=s.index)),
441
+ )
442
+
443
+ ref_mask = adata.obs[reference_col] == ref
444
+ sample_mask = adata.obs[sample_col] == sample
445
+
446
+ row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
447
+
448
+ if not bool(row_mask.any()):
449
+ print(
450
+ f"No reads for {display_sample} - {ref} after read quality and length filtering"
451
+ )
452
+ continue
384
453
 
385
454
  try:
386
455
  # ---- subset reads ----
387
- subset = adata[
388
- (adata.obs[reference_col] == ref) &
389
- (adata.obs[sample_col] == sample) &
390
- (adata.obs["read_quality"] >= min_quality) &
391
- (adata.obs["read_length"] >= min_length) &
392
- (
393
- adata.obs["mapped_length_to_reference_length_ratio"]
394
- > min_mapped_length_to_reference_length_ratio
395
- )
396
- ]
397
-
398
- # ---- valid fraction filter ----
399
- vf_key = f"{ref}_valid_fraction"
400
- if vf_key in subset.var:
401
- mask = subset.var[vf_key].astype(float) > float(min_position_valid_fraction)
402
- subset = subset[:, mask]
456
+ subset = adata[row_mask, :].copy()
457
+
458
+ # Column-level mask (var)
459
+ if min_position_valid_fraction is not None:
460
+ valid_key = f"{ref}_valid_fraction"
461
+ if valid_key in subset.var:
462
+ v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
463
+ col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
464
+ if col_mask.any():
465
+ subset = subset[:, col_mask].copy()
466
+ else:
467
+ print(
468
+ f"No positions left after valid_fraction filter for {display_sample} - {ref}"
469
+ )
470
+ continue
403
471
 
404
472
  if subset.shape[0] == 0:
473
+ print(f"No reads left after filtering for {display_sample} - {ref}")
405
474
  continue
406
475
 
407
476
  # ---- bins ----
@@ -412,22 +481,23 @@ def combined_hmm_raw_clustermap(
412
481
 
413
482
  # ---- site masks (robust) ----
414
483
  def _sites(*keys):
484
+ """Return indices for the first matching site key."""
415
485
  for k in keys:
416
486
  if k in subset.var:
417
487
  return np.where(subset.var[k].values)[0]
418
488
  return np.array([], dtype=int)
419
489
 
420
- gpc_sites = _sites(f"{ref}_GpC_site")
421
- cpg_sites = _sites(f"{ref}_CpG_site")
490
+ gpc_sites = _sites(f"{ref}_GpC_site")
491
+ cpg_sites = _sites(f"{ref}_CpG_site")
422
492
  any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
423
493
  any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
424
494
 
425
495
  # ---- labels via _select_labels ----
426
496
  # HMM uses *all* columns
427
- hmm_sites = np.arange(subset.n_vars, dtype=int)
428
- hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
429
- gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
430
- cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
497
+ hmm_sites = np.arange(subset.n_vars, dtype=int)
498
+ hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
499
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
500
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
431
501
  any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
432
502
  any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
433
503
 
@@ -477,9 +547,11 @@ def combined_hmm_raw_clustermap(
477
547
  elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
478
548
  linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
479
549
  order = sch.leaves_list(linkage)
480
-
550
+
481
551
  elif sort_by == "hmm" and hmm_sites.size:
482
- linkage = sch.linkage(sb[:, hmm_sites].layers[hmm_feature_layer], method="ward")
552
+ linkage = sch.linkage(
553
+ sb[:, hmm_sites].layers[hmm_feature_layer], method="ward"
554
+ )
483
555
  order = sch.leaves_list(linkage)
484
556
 
485
557
  else:
@@ -505,46 +577,62 @@ def combined_hmm_raw_clustermap(
505
577
 
506
578
  # ---------------- stack ----------------
507
579
  hmm_matrix = np.vstack(stacked_hmm)
508
- mean_hmm = normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
580
+ mean_hmm = (
581
+ normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
582
+ )
509
583
 
510
584
  panels = [
511
- (f"HMM - {hmm_feature_layer}", hmm_matrix, hmm_labels, cmap_hmm, mean_hmm, n_xticks_hmm),
585
+ (
586
+ f"HMM - {hmm_feature_layer}",
587
+ hmm_matrix,
588
+ hmm_labels,
589
+ cmap_hmm,
590
+ mean_hmm,
591
+ n_xticks_hmm,
592
+ ),
512
593
  ]
513
594
 
514
595
  if stacked_any_c:
515
596
  m = np.vstack(stacked_any_c)
516
- panels.append(("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c))
597
+ panels.append(
598
+ ("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c)
599
+ )
517
600
 
518
601
  if stacked_gpc:
519
602
  m = np.vstack(stacked_gpc)
520
- panels.append(("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc))
603
+ panels.append(
604
+ ("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc)
605
+ )
521
606
 
522
607
  if stacked_cpg:
523
608
  m = np.vstack(stacked_cpg)
524
- panels.append(("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg))
609
+ panels.append(
610
+ ("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg)
611
+ )
525
612
 
526
613
  if stacked_any_a:
527
614
  m = np.vstack(stacked_any_a)
528
- panels.append(("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a))
615
+ panels.append(
616
+ ("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a)
617
+ )
529
618
 
530
619
  # ---------------- plotting ----------------
531
620
  n_panels = len(panels)
532
621
  fig = plt.figure(figsize=(4.5 * n_panels, 10))
533
622
  gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
534
- fig.suptitle(f"{sample} — {ref} — {total_reads} reads ({signal_type})",
535
- fontsize=14, y=0.98)
623
+ fig.suptitle(
624
+ f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
625
+ )
536
626
 
537
627
  axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
538
628
  axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
539
629
 
540
630
  for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
541
-
542
631
  # ---- your clean barplot ----
543
632
  clean_barplot(axes_bar[i], mean_vec, name)
544
633
 
545
634
  # ---- heatmap ----
546
- sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
547
- yticklabels=False, cbar=False)
635
+ sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i], yticklabels=False, cbar=False)
548
636
 
549
637
  # ---- xticks ----
550
638
  xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
@@ -568,6 +656,7 @@ def combined_hmm_raw_clustermap(
568
656
 
569
657
  except Exception:
570
658
  import traceback
659
+
571
660
  traceback.print_exc()
572
661
  continue
573
662
 
@@ -687,7 +776,7 @@ def combined_hmm_raw_clustermap(
687
776
  # order = np.arange(num_reads)
688
777
  # elif sort_by == "any_a":
689
778
  # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
690
- # order = sch.leaves_list(linkage)
779
+ # order = sch.leaves_list(linkage)
691
780
  # else:
692
781
  # raise ValueError(f"Unsupported sort_by option: {sort_by}")
693
782
 
@@ -716,13 +805,13 @@ def combined_hmm_raw_clustermap(
716
805
  # order = np.arange(num_reads)
717
806
  # elif sort_by == "any_a":
718
807
  # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
719
- # order = sch.leaves_list(linkage)
808
+ # order = sch.leaves_list(linkage)
720
809
  # else:
721
810
  # raise ValueError(f"Unsupported sort_by option: {sort_by}")
722
-
811
+
723
812
  # stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
724
-
725
-
813
+
814
+
726
815
  # row_labels.extend([bin_label] * num_reads)
727
816
  # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
728
817
  # last_idx += num_reads
@@ -745,7 +834,7 @@ def combined_hmm_raw_clustermap(
745
834
  # if any_a_matrix.size > 0:
746
835
  # mean_any_a = methylation_fraction(any_a_matrix)
747
836
  # gs_dim += 1
748
-
837
+
749
838
 
750
839
  # fig = plt.figure(figsize=(18, 12))
751
840
  # gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
@@ -777,8 +866,8 @@ def combined_hmm_raw_clustermap(
777
866
  # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
778
867
  # axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
779
868
  # for boundary in bin_boundaries[:-1]:
780
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
781
- # current_ax +=1
869
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
870
+ # current_ax +=1
782
871
 
783
872
  # results.append({
784
873
  # "sample": sample,
@@ -790,7 +879,7 @@ def combined_hmm_raw_clustermap(
790
879
  # "bin_labels": bin_labels,
791
880
  # "bin_boundaries": bin_boundaries,
792
881
  # "percentages": percentages
793
- # })
882
+ # })
794
883
 
795
884
  # if stacked_any_a:
796
885
  # if any_a_matrix.size > 0:
@@ -810,7 +899,7 @@ def combined_hmm_raw_clustermap(
810
899
  # "bin_labels": bin_labels,
811
900
  # "bin_boundaries": bin_boundaries,
812
901
  # "percentages": percentages
813
- # })
902
+ # })
814
903
 
815
904
  # plt.tight_layout()
816
905
 
@@ -828,7 +917,7 @@ def combined_hmm_raw_clustermap(
828
917
  # print(f"Summary for {sample} - {ref}:")
829
918
  # for bin_label, percent in percentages.items():
830
919
  # print(f" - {bin_label}: {percent:.1f}%")
831
-
920
+
832
921
  # adata.uns['clustermap_results'] = results
833
922
 
834
923
  # except Exception as e:
@@ -836,6 +925,7 @@ def combined_hmm_raw_clustermap(
836
925
  # traceback.print_exc()
837
926
  # continue
838
927
 
928
+
839
929
  def combined_raw_clustermap(
840
930
  adata,
841
931
  sample_col: str = "Sample_Names",
@@ -849,10 +939,11 @@ def combined_raw_clustermap(
849
939
  cmap_gpc: str = "coolwarm",
850
940
  cmap_cpg: str = "viridis",
851
941
  cmap_a: str = "coolwarm",
852
- min_quality: float = 20,
853
- min_length: int = 200,
854
- min_mapped_length_to_reference_length_ratio: float = 0.8,
855
- min_position_valid_fraction: float = 0.5,
942
+ min_quality: float | None = 20,
943
+ min_length: int | None = 200,
944
+ min_mapped_length_to_reference_length_ratio: float | None = 0,
945
+ min_position_valid_fraction: float | None = 0,
946
+ demux_types: Sequence[str] = ("single", "double", "already"),
856
947
  sample_mapping: Optional[Mapping[str, str]] = None,
857
948
  save_path: str | Path | None = None,
858
949
  sort_by: str = "gpc", # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
@@ -884,6 +975,18 @@ def combined_raw_clustermap(
884
975
  One entry per (sample, ref) plot with matrices + bin metadata.
885
976
  """
886
977
 
978
+ # Helper: build a True mask if filter is inactive or column missing
979
+ def _mask_or_true(series_name: str, predicate):
980
+ """Return a mask from predicate or an all-True mask."""
981
+ if series_name not in adata.obs:
982
+ return pd.Series(True, index=adata.obs.index)
983
+ s = adata.obs[series_name]
984
+ try:
985
+ return predicate(s)
986
+ except Exception:
987
+ # Fallback: all True if bad dtype / predicate failure
988
+ return pd.Series(True, index=adata.obs.index)
989
+
887
990
  results: List[Dict[str, Any]] = []
888
991
  save_path = Path(save_path) if save_path is not None else None
889
992
  if save_path is not None:
@@ -902,24 +1005,63 @@ def combined_raw_clustermap(
902
1005
 
903
1006
  for ref in adata.obs[reference_col].cat.categories:
904
1007
  for sample in adata.obs[sample_col].cat.categories:
905
-
906
1008
  # Optionally remap sample label for display
907
1009
  display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
908
1010
 
909
- try:
910
- subset = adata[
911
- (adata.obs[reference_col] == ref) &
912
- (adata.obs[sample_col] == sample) &
913
- (adata.obs["read_quality"] >= min_quality) &
914
- (adata.obs["mapped_length"] >= min_length) &
915
- (adata.obs["mapped_length_to_reference_length_ratio"] >= min_mapped_length_to_reference_length_ratio)
916
- ]
1011
+ # Row-level masks (obs)
1012
+ qmask = _mask_or_true(
1013
+ "read_quality",
1014
+ (lambda s: s >= float(min_quality))
1015
+ if (min_quality is not None)
1016
+ else (lambda s: pd.Series(True, index=s.index)),
1017
+ )
1018
+ lm_mask = _mask_or_true(
1019
+ "mapped_length",
1020
+ (lambda s: s >= float(min_length))
1021
+ if (min_length is not None)
1022
+ else (lambda s: pd.Series(True, index=s.index)),
1023
+ )
1024
+ lrr_mask = _mask_or_true(
1025
+ "mapped_length_to_reference_length_ratio",
1026
+ (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
1027
+ if (min_mapped_length_to_reference_length_ratio is not None)
1028
+ else (lambda s: pd.Series(True, index=s.index)),
1029
+ )
1030
+
1031
+ demux_mask = _mask_or_true(
1032
+ "demux_type",
1033
+ (lambda s: s.astype("string").isin(list(demux_types)))
1034
+ if (demux_types is not None)
1035
+ else (lambda s: pd.Series(True, index=s.index)),
1036
+ )
1037
+
1038
+ ref_mask = adata.obs[reference_col] == ref
1039
+ sample_mask = adata.obs[sample_col] == sample
1040
+
1041
+ row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
1042
+
1043
+ if not bool(row_mask.any()):
1044
+ print(
1045
+ f"No reads for {display_sample} - {ref} after read quality and length filtering"
1046
+ )
1047
+ continue
917
1048
 
918
- # position-level mask
919
- valid_key = f"{ref}_valid_fraction"
920
- if valid_key in subset.var:
921
- mask = subset.var[valid_key].astype(float).values > float(min_position_valid_fraction)
922
- subset = subset[:, mask]
1049
+ try:
1050
+ subset = adata[row_mask, :].copy()
1051
+
1052
+ # Column-level mask (var)
1053
+ if min_position_valid_fraction is not None:
1054
+ valid_key = f"{ref}_valid_fraction"
1055
+ if valid_key in subset.var:
1056
+ v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
1057
+ col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
1058
+ if col_mask.any():
1059
+ subset = subset[:, col_mask].copy()
1060
+ else:
1061
+ print(
1062
+ f"No positions left after valid_fraction filter for {display_sample} - {ref}"
1063
+ )
1064
+ continue
923
1065
 
924
1066
  if subset.shape[0] == 0:
925
1067
  print(f"No reads left after filtering for {display_sample} - {ref}")
@@ -939,14 +1081,14 @@ def combined_raw_clustermap(
939
1081
 
940
1082
  if include_any_c:
941
1083
  any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
942
- gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
943
- cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
1084
+ gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
1085
+ cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
944
1086
 
945
1087
  num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
946
1088
 
947
1089
  any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
948
- gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
949
- cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
1090
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
1091
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
950
1092
 
951
1093
  if include_any_a:
952
1094
  any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
@@ -978,15 +1120,21 @@ def combined_raw_clustermap(
978
1120
  order = np.argsort(subset_bin.obs[colname].values)
979
1121
 
980
1122
  elif sort_by == "gpc" and num_gpc > 0:
981
- linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
1123
+ linkage = sch.linkage(
1124
+ subset_bin[:, gpc_sites].layers[layer_gpc], method="ward"
1125
+ )
982
1126
  order = sch.leaves_list(linkage)
983
1127
 
984
1128
  elif sort_by == "cpg" and num_cpg > 0:
985
- linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
1129
+ linkage = sch.linkage(
1130
+ subset_bin[:, cpg_sites].layers[layer_cpg], method="ward"
1131
+ )
986
1132
  order = sch.leaves_list(linkage)
987
1133
 
988
1134
  elif sort_by == "c" and num_any_c > 0:
989
- linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_c], method="ward")
1135
+ linkage = sch.linkage(
1136
+ subset_bin[:, any_c_sites].layers[layer_c], method="ward"
1137
+ )
990
1138
  order = sch.leaves_list(linkage)
991
1139
 
992
1140
  elif sort_by == "gpc_cpg":
@@ -994,7 +1142,9 @@ def combined_raw_clustermap(
994
1142
  order = sch.leaves_list(linkage)
995
1143
 
996
1144
  elif sort_by == "a" and num_any_a > 0:
997
- linkage = sch.linkage(subset_bin[:, any_a_sites].layers[layer_a], method="ward")
1145
+ linkage = sch.linkage(
1146
+ subset_bin[:, any_a_sites].layers[layer_a], method="ward"
1147
+ )
998
1148
  order = sch.leaves_list(linkage)
999
1149
 
1000
1150
  elif sort_by == "none":
@@ -1027,57 +1177,65 @@ def combined_raw_clustermap(
1027
1177
 
1028
1178
  if include_any_c and stacked_any_c:
1029
1179
  any_c_matrix = np.vstack(stacked_any_c)
1030
- gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
1031
- cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
1180
+ gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
1181
+ cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
1032
1182
 
1033
1183
  mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
1034
- mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
1035
- mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
1184
+ mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
1185
+ mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
1036
1186
 
1037
1187
  if any_c_matrix.size:
1038
- blocks.append(dict(
1039
- name="c",
1040
- matrix=any_c_matrix,
1041
- mean=mean_any_c,
1042
- labels=any_c_labels,
1043
- cmap=cmap_c,
1044
- n_xticks=n_xticks_any_c,
1045
- title="any C site Modification Signal"
1046
- ))
1188
+ blocks.append(
1189
+ dict(
1190
+ name="c",
1191
+ matrix=any_c_matrix,
1192
+ mean=mean_any_c,
1193
+ labels=any_c_labels,
1194
+ cmap=cmap_c,
1195
+ n_xticks=n_xticks_any_c,
1196
+ title="any C site Modification Signal",
1197
+ )
1198
+ )
1047
1199
  if gpc_matrix.size:
1048
- blocks.append(dict(
1049
- name="gpc",
1050
- matrix=gpc_matrix,
1051
- mean=mean_gpc,
1052
- labels=gpc_labels,
1053
- cmap=cmap_gpc,
1054
- n_xticks=n_xticks_gpc,
1055
- title="GpC Modification Signal"
1056
- ))
1200
+ blocks.append(
1201
+ dict(
1202
+ name="gpc",
1203
+ matrix=gpc_matrix,
1204
+ mean=mean_gpc,
1205
+ labels=gpc_labels,
1206
+ cmap=cmap_gpc,
1207
+ n_xticks=n_xticks_gpc,
1208
+ title="GpC Modification Signal",
1209
+ )
1210
+ )
1057
1211
  if cpg_matrix.size:
1058
- blocks.append(dict(
1059
- name="cpg",
1060
- matrix=cpg_matrix,
1061
- mean=mean_cpg,
1062
- labels=cpg_labels,
1063
- cmap=cmap_cpg,
1064
- n_xticks=n_xticks_cpg,
1065
- title="CpG Modification Signal"
1066
- ))
1212
+ blocks.append(
1213
+ dict(
1214
+ name="cpg",
1215
+ matrix=cpg_matrix,
1216
+ mean=mean_cpg,
1217
+ labels=cpg_labels,
1218
+ cmap=cmap_cpg,
1219
+ n_xticks=n_xticks_cpg,
1220
+ title="CpG Modification Signal",
1221
+ )
1222
+ )
1067
1223
 
1068
1224
  if include_any_a and stacked_any_a:
1069
1225
  any_a_matrix = np.vstack(stacked_any_a)
1070
1226
  mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
1071
1227
  if any_a_matrix.size:
1072
- blocks.append(dict(
1073
- name="a",
1074
- matrix=any_a_matrix,
1075
- mean=mean_any_a,
1076
- labels=any_a_labels,
1077
- cmap=cmap_a,
1078
- n_xticks=n_xticks_any_a,
1079
- title="any A site Modification Signal"
1080
- ))
1228
+ blocks.append(
1229
+ dict(
1230
+ name="a",
1231
+ matrix=any_a_matrix,
1232
+ mean=mean_any_a,
1233
+ labels=any_a_labels,
1234
+ cmap=cmap_a,
1235
+ n_xticks=n_xticks_any_a,
1236
+ title="any A site Modification Signal",
1237
+ )
1238
+ )
1081
1239
 
1082
1240
  if not blocks:
1083
1241
  print(f"No matrices to plot for {display_sample} - {ref}")
@@ -1089,7 +1247,7 @@ def combined_raw_clustermap(
1089
1247
  fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
1090
1248
 
1091
1249
  axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
1092
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
1250
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
1093
1251
 
1094
1252
  # ----------------------------
1095
1253
  # plot blocks
@@ -1105,20 +1263,14 @@ def combined_raw_clustermap(
1105
1263
 
1106
1264
  # heatmap
1107
1265
  sns.heatmap(
1108
- mat,
1109
- cmap=blk["cmap"],
1110
- ax=axes_heat[i],
1111
- yticklabels=False,
1112
- cbar=False
1266
+ mat, cmap=blk["cmap"], ax=axes_heat[i], yticklabels=False, cbar=False
1113
1267
  )
1114
1268
 
1115
1269
  # fixed tick labels
1116
1270
  tick_pos = _fixed_tick_positions(len(labels), n_xticks)
1117
1271
  axes_heat[i].set_xticks(tick_pos)
1118
1272
  axes_heat[i].set_xticklabels(
1119
- labels[tick_pos],
1120
- rotation=xtick_rotation,
1121
- fontsize=xtick_fontsize
1273
+ labels[tick_pos], rotation=xtick_rotation, fontsize=xtick_fontsize
1122
1274
  )
1123
1275
 
1124
1276
  # bin separators
@@ -1131,7 +1283,12 @@ def combined_raw_clustermap(
1131
1283
 
1132
1284
  # save or show
1133
1285
  if save_path is not None:
1134
- safe_name = f"{ref}__{display_sample}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
1286
+ safe_name = (
1287
+ f"{ref}__{display_sample}".replace("=", "")
1288
+ .replace("__", "_")
1289
+ .replace(",", "_")
1290
+ .replace(" ", "_")
1291
+ )
1135
1292
  out_file = save_path / f"{safe_name}.png"
1136
1293
  fig.savefig(out_file, dpi=300)
1137
1294
  plt.close(fig)
@@ -1157,20 +1314,15 @@ def combined_raw_clustermap(
1157
1314
  for bin_label, percent in percentages.items():
1158
1315
  print(f" - {bin_label}: {percent:.1f}%")
1159
1316
 
1160
- except Exception as e:
1317
+ except Exception:
1161
1318
  import traceback
1319
+
1162
1320
  traceback.print_exc()
1163
1321
  continue
1164
1322
 
1165
- # store once at the end (HDF5 safe)
1166
- # matrices won't be HDF5-safe; store only metadata + maybe hit counts
1167
- # adata.uns["clustermap_results"] = [
1168
- # {k: v for k, v in r.items() if not k.endswith("_matrix")}
1169
- # for r in results
1170
- # ]
1171
-
1172
1323
  return results
1173
1324
 
1325
+
1174
1326
  def plot_hmm_layers_rolling_by_sample_ref(
1175
1327
  adata,
1176
1328
  layers: Optional[Sequence[str]] = None,
@@ -1237,7 +1389,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
1237
1389
 
1238
1390
  # --- basic checks / defaults ---
1239
1391
  if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
1240
- raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")
1392
+ raise ValueError(
1393
+ f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
1394
+ )
1241
1395
 
1242
1396
  # canonicalize samples / refs
1243
1397
  if samples is None:
@@ -1260,7 +1414,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
1260
1414
  if layers is None:
1261
1415
  layers = list(adata.layers.keys())
1262
1416
  if len(layers) == 0:
1263
- raise ValueError("No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot.")
1417
+ raise ValueError(
1418
+ "No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
1419
+ )
1264
1420
  layers = list(layers)
1265
1421
 
1266
1422
  # x coordinates (positions)
@@ -1299,19 +1455,29 @@ def plot_hmm_layers_rolling_by_sample_ref(
1299
1455
 
1300
1456
  fig_w = figsize_per_cell[0] * ncols
1301
1457
  fig_h = figsize_per_cell[1] * nrows
1302
- fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
1303
- figsize=(fig_w, fig_h), dpi=dpi,
1304
- squeeze=False)
1458
+ fig, axes = plt.subplots(
1459
+ nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
1460
+ )
1305
1461
 
1306
1462
  for r_idx, sample_name in enumerate(chunk):
1307
1463
  for c_idx, ref_name in enumerate(refs_all):
1308
1464
  ax = axes[r_idx][c_idx]
1309
1465
 
1310
1466
  # subset adata
1311
- mask = (adata.obs[sample_col].values == sample_name) & (adata.obs[ref_col].values == ref_name)
1467
+ mask = (adata.obs[sample_col].values == sample_name) & (
1468
+ adata.obs[ref_col].values == ref_name
1469
+ )
1312
1470
  sub = adata[mask]
1313
1471
  if sub.n_obs == 0:
1314
- ax.text(0.5, 0.5, "No reads", ha="center", va="center", transform=ax.transAxes, color="gray")
1472
+ ax.text(
1473
+ 0.5,
1474
+ 0.5,
1475
+ "No reads",
1476
+ ha="center",
1477
+ va="center",
1478
+ transform=ax.transAxes,
1479
+ color="gray",
1480
+ )
1315
1481
  ax.set_xticks([])
1316
1482
  ax.set_yticks([])
1317
1483
  if r_idx == 0:
@@ -1361,7 +1527,11 @@ def plot_hmm_layers_rolling_by_sample_ref(
1361
1527
  smoothed = col_mean
1362
1528
  else:
1363
1529
  ser = pd.Series(col_mean)
1364
- smoothed = ser.rolling(window=window, min_periods=min_periods, center=center).mean().to_numpy()
1530
+ smoothed = (
1531
+ ser.rolling(window=window, min_periods=min_periods, center=center)
1532
+ .mean()
1533
+ .to_numpy()
1534
+ )
1365
1535
 
1366
1536
  # x axis: x_coords (trim/pad to match length)
1367
1537
  L = len(col_mean)
@@ -1371,7 +1541,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
1371
1541
  if show_raw:
1372
1542
  ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)
1373
1543
 
1374
- ax.plot(x, smoothed[:L], label=layer, color=colors[li], linewidth=1.2, alpha=0.95, zorder=2)
1544
+ ax.plot(
1545
+ x,
1546
+ smoothed[:L],
1547
+ label=layer,
1548
+ color=colors[li],
1549
+ linewidth=1.2,
1550
+ alpha=0.95,
1551
+ zorder=2,
1552
+ )
1375
1553
  plotted_any = True
1376
1554
 
1377
1555
  # labels / titles
@@ -1389,11 +1567,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
1389
1567
 
1390
1568
  ax.grid(True, alpha=0.2)
1391
1569
 
1392
- fig.suptitle(f"Rolling mean of layer positional means (window={window}) — page {page+1}/{total_pages}", fontsize=11, y=0.995)
1570
+ fig.suptitle(
1571
+ f"Rolling mean of layer positional means (window={window}) — page {page + 1}/{total_pages}",
1572
+ fontsize=11,
1573
+ y=0.995,
1574
+ )
1393
1575
  fig.tight_layout(rect=[0, 0, 1, 0.97])
1394
1576
 
1395
1577
  if save:
1396
- fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
1578
+ fname = os.path.join(outdir, f"hmm_layers_rolling_page{page + 1}.png")
1397
1579
  plt.savefig(fname, bbox_inches="tight", dpi=dpi)
1398
1580
  saved_files.append(fname)
1399
1581
  else: