smftools 0.2.4__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133) hide show
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +7 -1
  5. smftools/cli/hmm_adata.py +902 -244
  6. smftools/cli/load_adata.py +318 -198
  7. smftools/cli/preprocess_adata.py +285 -171
  8. smftools/cli/spatial_adata.py +137 -53
  9. smftools/cli_entry.py +94 -178
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +5 -1
  12. smftools/config/deaminase.yaml +1 -1
  13. smftools/config/default.yaml +22 -17
  14. smftools/config/direct.yaml +8 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +505 -276
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2125 -1426
  21. smftools/hmm/__init__.py +2 -3
  22. smftools/hmm/archived/call_hmm_peaks.py +16 -1
  23. smftools/hmm/call_hmm_peaks.py +173 -193
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +379 -156
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +195 -29
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +347 -168
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +145 -85
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +8 -8
  84. smftools/preprocessing/append_base_context.py +105 -79
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  86. smftools/preprocessing/{archives → archived}/calculate_complexity.py +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +127 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +44 -22
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +103 -55
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +70 -37
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +688 -271
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +93 -27
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +264 -109
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/METADATA +15 -43
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.4.dist-info/RECORD +0 -176
  128. /smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +0 -0
  129. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  130. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  131. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  132. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  133. {smftools-0.2.4.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,16 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
- import numpy as np
4
- import seaborn as sns
5
- import matplotlib.pyplot as plt
6
- import scipy.cluster.hierarchy as sch
7
- import matplotlib.gridspec as gridspec
8
- import os
9
3
  import math
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
7
+
8
+ import matplotlib.gridspec as gridspec
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
10
11
  import pandas as pd
12
+ import scipy.cluster.hierarchy as sch
13
+ import seaborn as sns
11
14
 
12
- from typing import Optional, Mapping, Sequence, Any, Dict, List, Tuple
13
- from pathlib import Path
14
15
 
15
16
  def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
16
17
  """
@@ -25,6 +26,7 @@ def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
25
26
  pos = np.linspace(0, n_positions - 1, n_ticks)
26
27
  return np.unique(np.round(pos).astype(int))
27
28
 
29
+
28
30
  def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix: str | None):
29
31
  """
30
32
  Select tick labels for the heatmap axis.
@@ -65,11 +67,21 @@ def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix:
65
67
  labels = subset.var[colname].astype(str).values
66
68
  return labels[sites]
67
69
 
70
+
68
71
  def normalized_mean(matrix: np.ndarray) -> np.ndarray:
72
+ """Compute normalized column means for a matrix.
73
+
74
+ Args:
75
+ matrix: Input matrix.
76
+
77
+ Returns:
78
+ 1D array of normalized means.
79
+ """
69
80
  mean = np.nanmean(matrix, axis=0)
70
81
  denom = (mean.max() - mean.min()) + 1e-9
71
82
  return (mean - mean.min()) / denom
72
83
 
84
+
73
85
  def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
74
86
  """
75
87
  Fraction methylated per column.
@@ -84,14 +96,20 @@ def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
84
96
  valid = valid_mask.sum(axis=0)
85
97
 
86
98
  return np.divide(
87
- methylated, valid,
88
- out=np.zeros_like(methylated, dtype=float),
89
- where=valid != 0
99
+ methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0
90
100
  )
91
101
 
102
+
92
103
  def clean_barplot(ax, mean_values, title):
104
+ """Format a barplot with consistent axes and labels.
105
+
106
+ Args:
107
+ ax: Matplotlib axes.
108
+ mean_values: Values to plot.
109
+ title: Plot title.
110
+ """
93
111
  x = np.arange(len(mean_values))
94
- ax.bar(x, mean_values, color="gray", width=1.0, align='edge')
112
+ ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
95
113
  ax.set_xlim(0, len(mean_values))
96
114
  ax.set_ylim(0, 1)
97
115
  ax.set_yticks([0.0, 0.5, 1.0])
@@ -100,9 +118,10 @@ def clean_barplot(ax, mean_values, title):
100
118
 
101
119
  # Hide all spines except left
102
120
  for spine_name, spine in ax.spines.items():
103
- spine.set_visible(spine_name == 'left')
121
+ spine.set_visible(spine_name == "left")
122
+
123
+ ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
104
124
 
105
- ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
106
125
 
107
126
  # def combined_hmm_raw_clustermap(
108
127
  # adata,
@@ -145,7 +164,7 @@ def clean_barplot(ax, mean_values, title):
145
164
  # (adata.obs['read_length'] >= min_length) &
146
165
  # (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
147
166
  # ]
148
-
167
+
149
168
  # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
150
169
  # subset = subset[:, mask]
151
170
 
@@ -257,7 +276,7 @@ def clean_barplot(ax, mean_values, title):
257
276
  # clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
258
277
  # clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
259
278
  # clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
260
-
279
+
261
280
  # hmm_labels = subset.var_names.astype(int)
262
281
  # hmm_label_spacing = 150
263
282
  # sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
@@ -311,7 +330,7 @@ def clean_barplot(ax, mean_values, title):
311
330
  # "bin_boundaries": bin_boundaries,
312
331
  # "percentages": percentages
313
332
  # })
314
-
333
+
315
334
  # #adata.uns['clustermap_results'] = results
316
335
 
317
336
  # except Exception as e:
@@ -319,45 +338,39 @@ def clean_barplot(ax, mean_values, title):
319
338
  # traceback.print_exc()
320
339
  # continue
321
340
 
341
+
322
342
  def combined_hmm_raw_clustermap(
323
343
  adata,
324
344
  sample_col: str = "Sample_Names",
325
345
  reference_col: str = "Reference_strand",
326
-
327
346
  hmm_feature_layer: str = "hmm_combined",
328
-
329
347
  layer_gpc: str = "nan0_0minus1",
330
348
  layer_cpg: str = "nan0_0minus1",
331
349
  layer_c: str = "nan0_0minus1",
332
350
  layer_a: str = "nan0_0minus1",
333
-
334
351
  cmap_hmm: str = "tab10",
335
352
  cmap_gpc: str = "coolwarm",
336
353
  cmap_cpg: str = "viridis",
337
354
  cmap_c: str = "coolwarm",
338
355
  cmap_a: str = "coolwarm",
339
-
340
356
  min_quality: int = 20,
341
357
  min_length: int = 200,
342
358
  min_mapped_length_to_reference_length_ratio: float = 0.8,
343
359
  min_position_valid_fraction: float = 0.5,
344
-
360
+ demux_types: Sequence[str] = ("single", "double", "already"),
361
+ sample_mapping: Optional[Mapping[str, str]] = None,
345
362
  save_path: str | Path | None = None,
346
363
  normalize_hmm: bool = False,
347
-
348
364
  sort_by: str = "gpc",
349
365
  bins: Optional[Dict[str, Any]] = None,
350
-
351
366
  deaminase: bool = False,
352
367
  min_signal: float = 0.0,
353
-
354
368
  # ---- fixed tick label controls (counts, not spacing)
355
369
  n_xticks_hmm: int = 10,
356
370
  n_xticks_any_c: int = 8,
357
371
  n_xticks_gpc: int = 8,
358
372
  n_xticks_cpg: int = 8,
359
373
  n_xticks_a: int = 8,
360
-
361
374
  index_col_suffix: str | None = None,
362
375
  ):
363
376
  """
@@ -369,39 +382,92 @@ def combined_hmm_raw_clustermap(
369
382
  sort_by options:
370
383
  'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
371
384
  """
385
+
372
386
  def pick_xticks(labels: np.ndarray, n_ticks: int):
387
+ """Pick tick indices/labels from an array."""
373
388
  if labels.size == 0:
374
389
  return [], []
375
390
  idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
376
391
  idx = np.unique(idx)
377
392
  return idx.tolist(), labels[idx].tolist()
378
-
393
+
394
+ # Helper: build a True mask if filter is inactive or column missing
395
+ def _mask_or_true(series_name: str, predicate):
396
+ """Return a mask from predicate or an all-True mask."""
397
+ if series_name not in adata.obs:
398
+ return pd.Series(True, index=adata.obs.index)
399
+ s = adata.obs[series_name]
400
+ try:
401
+ return predicate(s)
402
+ except Exception:
403
+ # Fallback: all True if bad dtype / predicate failure
404
+ return pd.Series(True, index=adata.obs.index)
405
+
379
406
  results = []
380
407
  signal_type = "deamination" if deaminase else "methylation"
381
408
 
382
409
  for ref in adata.obs[reference_col].cat.categories:
383
410
  for sample in adata.obs[sample_col].cat.categories:
411
+ # Optionally remap sample label for display
412
+ display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
413
+ # Row-level masks (obs)
414
+ qmask = _mask_or_true(
415
+ "read_quality",
416
+ (lambda s: s >= float(min_quality))
417
+ if (min_quality is not None)
418
+ else (lambda s: pd.Series(True, index=s.index)),
419
+ )
420
+ lm_mask = _mask_or_true(
421
+ "mapped_length",
422
+ (lambda s: s >= float(min_length))
423
+ if (min_length is not None)
424
+ else (lambda s: pd.Series(True, index=s.index)),
425
+ )
426
+ lrr_mask = _mask_or_true(
427
+ "mapped_length_to_reference_length_ratio",
428
+ (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
429
+ if (min_mapped_length_to_reference_length_ratio is not None)
430
+ else (lambda s: pd.Series(True, index=s.index)),
431
+ )
432
+
433
+ demux_mask = _mask_or_true(
434
+ "demux_type",
435
+ (lambda s: s.astype("string").isin(list(demux_types)))
436
+ if (demux_types is not None)
437
+ else (lambda s: pd.Series(True, index=s.index)),
438
+ )
439
+
440
+ ref_mask = adata.obs[reference_col] == ref
441
+ sample_mask = adata.obs[sample_col] == sample
442
+
443
+ row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
444
+
445
+ if not bool(row_mask.any()):
446
+ print(
447
+ f"No reads for {display_sample} - {ref} after read quality and length filtering"
448
+ )
449
+ continue
384
450
 
385
451
  try:
386
452
  # ---- subset reads ----
387
- subset = adata[
388
- (adata.obs[reference_col] == ref) &
389
- (adata.obs[sample_col] == sample) &
390
- (adata.obs["read_quality"] >= min_quality) &
391
- (adata.obs["read_length"] >= min_length) &
392
- (
393
- adata.obs["mapped_length_to_reference_length_ratio"]
394
- > min_mapped_length_to_reference_length_ratio
395
- )
396
- ]
397
-
398
- # ---- valid fraction filter ----
399
- vf_key = f"{ref}_valid_fraction"
400
- if vf_key in subset.var:
401
- mask = subset.var[vf_key].astype(float) > float(min_position_valid_fraction)
402
- subset = subset[:, mask]
453
+ subset = adata[row_mask, :].copy()
454
+
455
+ # Column-level mask (var)
456
+ if min_position_valid_fraction is not None:
457
+ valid_key = f"{ref}_valid_fraction"
458
+ if valid_key in subset.var:
459
+ v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
460
+ col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
461
+ if col_mask.any():
462
+ subset = subset[:, col_mask].copy()
463
+ else:
464
+ print(
465
+ f"No positions left after valid_fraction filter for {display_sample} - {ref}"
466
+ )
467
+ continue
403
468
 
404
469
  if subset.shape[0] == 0:
470
+ print(f"No reads left after filtering for {display_sample} - {ref}")
405
471
  continue
406
472
 
407
473
  # ---- bins ----
@@ -412,22 +478,23 @@ def combined_hmm_raw_clustermap(
412
478
 
413
479
  # ---- site masks (robust) ----
414
480
  def _sites(*keys):
481
+ """Return indices for the first matching site key."""
415
482
  for k in keys:
416
483
  if k in subset.var:
417
484
  return np.where(subset.var[k].values)[0]
418
485
  return np.array([], dtype=int)
419
486
 
420
- gpc_sites = _sites(f"{ref}_GpC_site")
421
- cpg_sites = _sites(f"{ref}_CpG_site")
487
+ gpc_sites = _sites(f"{ref}_GpC_site")
488
+ cpg_sites = _sites(f"{ref}_CpG_site")
422
489
  any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
423
490
  any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
424
491
 
425
492
  # ---- labels via _select_labels ----
426
493
  # HMM uses *all* columns
427
- hmm_sites = np.arange(subset.n_vars, dtype=int)
428
- hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
429
- gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
430
- cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
494
+ hmm_sites = np.arange(subset.n_vars, dtype=int)
495
+ hmm_labels = _select_labels(subset, hmm_sites, ref, index_col_suffix)
496
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
497
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
431
498
  any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
432
499
  any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
433
500
 
@@ -477,9 +544,11 @@ def combined_hmm_raw_clustermap(
477
544
  elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
478
545
  linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
479
546
  order = sch.leaves_list(linkage)
480
-
547
+
481
548
  elif sort_by == "hmm" and hmm_sites.size:
482
- linkage = sch.linkage(sb[:, hmm_sites].layers[hmm_feature_layer], method="ward")
549
+ linkage = sch.linkage(
550
+ sb[:, hmm_sites].layers[hmm_feature_layer], method="ward"
551
+ )
483
552
  order = sch.leaves_list(linkage)
484
553
 
485
554
  else:
@@ -505,46 +574,62 @@ def combined_hmm_raw_clustermap(
505
574
 
506
575
  # ---------------- stack ----------------
507
576
  hmm_matrix = np.vstack(stacked_hmm)
508
- mean_hmm = normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
577
+ mean_hmm = (
578
+ normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
579
+ )
509
580
 
510
581
  panels = [
511
- (f"HMM - {hmm_feature_layer}", hmm_matrix, hmm_labels, cmap_hmm, mean_hmm, n_xticks_hmm),
582
+ (
583
+ f"HMM - {hmm_feature_layer}",
584
+ hmm_matrix,
585
+ hmm_labels,
586
+ cmap_hmm,
587
+ mean_hmm,
588
+ n_xticks_hmm,
589
+ ),
512
590
  ]
513
591
 
514
592
  if stacked_any_c:
515
593
  m = np.vstack(stacked_any_c)
516
- panels.append(("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c))
594
+ panels.append(
595
+ ("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c)
596
+ )
517
597
 
518
598
  if stacked_gpc:
519
599
  m = np.vstack(stacked_gpc)
520
- panels.append(("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc))
600
+ panels.append(
601
+ ("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc)
602
+ )
521
603
 
522
604
  if stacked_cpg:
523
605
  m = np.vstack(stacked_cpg)
524
- panels.append(("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg))
606
+ panels.append(
607
+ ("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg)
608
+ )
525
609
 
526
610
  if stacked_any_a:
527
611
  m = np.vstack(stacked_any_a)
528
- panels.append(("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a))
612
+ panels.append(
613
+ ("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a)
614
+ )
529
615
 
530
616
  # ---------------- plotting ----------------
531
617
  n_panels = len(panels)
532
618
  fig = plt.figure(figsize=(4.5 * n_panels, 10))
533
619
  gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
534
- fig.suptitle(f"{sample} — {ref} — {total_reads} reads ({signal_type})",
535
- fontsize=14, y=0.98)
620
+ fig.suptitle(
621
+ f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
622
+ )
536
623
 
537
624
  axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
538
625
  axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
539
626
 
540
627
  for i, (name, matrix, labels, cmap, mean_vec, n_ticks) in enumerate(panels):
541
-
542
628
  # ---- your clean barplot ----
543
629
  clean_barplot(axes_bar[i], mean_vec, name)
544
630
 
545
631
  # ---- heatmap ----
546
- sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i],
547
- yticklabels=False, cbar=False)
632
+ sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i], yticklabels=False, cbar=False)
548
633
 
549
634
  # ---- xticks ----
550
635
  xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
@@ -568,6 +653,7 @@ def combined_hmm_raw_clustermap(
568
653
 
569
654
  except Exception:
570
655
  import traceback
656
+
571
657
  traceback.print_exc()
572
658
  continue
573
659
 
@@ -687,7 +773,7 @@ def combined_hmm_raw_clustermap(
687
773
  # order = np.arange(num_reads)
688
774
  # elif sort_by == "any_a":
689
775
  # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
690
- # order = sch.leaves_list(linkage)
776
+ # order = sch.leaves_list(linkage)
691
777
  # else:
692
778
  # raise ValueError(f"Unsupported sort_by option: {sort_by}")
693
779
 
@@ -716,13 +802,13 @@ def combined_hmm_raw_clustermap(
716
802
  # order = np.arange(num_reads)
717
803
  # elif sort_by == "any_a":
718
804
  # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
719
- # order = sch.leaves_list(linkage)
805
+ # order = sch.leaves_list(linkage)
720
806
  # else:
721
807
  # raise ValueError(f"Unsupported sort_by option: {sort_by}")
722
-
808
+
723
809
  # stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
724
-
725
-
810
+
811
+
726
812
  # row_labels.extend([bin_label] * num_reads)
727
813
  # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
728
814
  # last_idx += num_reads
@@ -745,7 +831,7 @@ def combined_hmm_raw_clustermap(
745
831
  # if any_a_matrix.size > 0:
746
832
  # mean_any_a = methylation_fraction(any_a_matrix)
747
833
  # gs_dim += 1
748
-
834
+
749
835
 
750
836
  # fig = plt.figure(figsize=(18, 12))
751
837
  # gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
@@ -777,8 +863,8 @@ def combined_hmm_raw_clustermap(
777
863
  # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
778
864
  # axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
779
865
  # for boundary in bin_boundaries[:-1]:
780
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
781
- # current_ax +=1
866
+ # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
867
+ # current_ax +=1
782
868
 
783
869
  # results.append({
784
870
  # "sample": sample,
@@ -790,7 +876,7 @@ def combined_hmm_raw_clustermap(
790
876
  # "bin_labels": bin_labels,
791
877
  # "bin_boundaries": bin_boundaries,
792
878
  # "percentages": percentages
793
- # })
879
+ # })
794
880
 
795
881
  # if stacked_any_a:
796
882
  # if any_a_matrix.size > 0:
@@ -810,7 +896,7 @@ def combined_hmm_raw_clustermap(
810
896
  # "bin_labels": bin_labels,
811
897
  # "bin_boundaries": bin_boundaries,
812
898
  # "percentages": percentages
813
- # })
899
+ # })
814
900
 
815
901
  # plt.tight_layout()
816
902
 
@@ -828,7 +914,7 @@ def combined_hmm_raw_clustermap(
828
914
  # print(f"Summary for {sample} - {ref}:")
829
915
  # for bin_label, percent in percentages.items():
830
916
  # print(f" - {bin_label}: {percent:.1f}%")
831
-
917
+
832
918
  # adata.uns['clustermap_results'] = results
833
919
 
834
920
  # except Exception as e:
@@ -836,6 +922,7 @@ def combined_hmm_raw_clustermap(
836
922
  # traceback.print_exc()
837
923
  # continue
838
924
 
925
+
839
926
  def combined_raw_clustermap(
840
927
  adata,
841
928
  sample_col: str = "Sample_Names",
@@ -849,10 +936,11 @@ def combined_raw_clustermap(
849
936
  cmap_gpc: str = "coolwarm",
850
937
  cmap_cpg: str = "viridis",
851
938
  cmap_a: str = "coolwarm",
852
- min_quality: float = 20,
853
- min_length: int = 200,
854
- min_mapped_length_to_reference_length_ratio: float = 0.8,
855
- min_position_valid_fraction: float = 0.5,
939
+ min_quality: float | None = 20,
940
+ min_length: int | None = 200,
941
+ min_mapped_length_to_reference_length_ratio: float | None = 0,
942
+ min_position_valid_fraction: float | None = 0,
943
+ demux_types: Sequence[str] = ("single", "double", "already"),
856
944
  sample_mapping: Optional[Mapping[str, str]] = None,
857
945
  save_path: str | Path | None = None,
858
946
  sort_by: str = "gpc", # 'gpc','cpg','c','gpc_cpg','a','none','obs:<col>'
@@ -884,6 +972,18 @@ def combined_raw_clustermap(
884
972
  One entry per (sample, ref) plot with matrices + bin metadata.
885
973
  """
886
974
 
975
+ # Helper: build a True mask if filter is inactive or column missing
976
+ def _mask_or_true(series_name: str, predicate):
977
+ """Return a mask from predicate or an all-True mask."""
978
+ if series_name not in adata.obs:
979
+ return pd.Series(True, index=adata.obs.index)
980
+ s = adata.obs[series_name]
981
+ try:
982
+ return predicate(s)
983
+ except Exception:
984
+ # Fallback: all True if bad dtype / predicate failure
985
+ return pd.Series(True, index=adata.obs.index)
986
+
887
987
  results: List[Dict[str, Any]] = []
888
988
  save_path = Path(save_path) if save_path is not None else None
889
989
  if save_path is not None:
@@ -902,24 +1002,63 @@ def combined_raw_clustermap(
902
1002
 
903
1003
  for ref in adata.obs[reference_col].cat.categories:
904
1004
  for sample in adata.obs[sample_col].cat.categories:
905
-
906
1005
  # Optionally remap sample label for display
907
1006
  display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
908
1007
 
909
- try:
910
- subset = adata[
911
- (adata.obs[reference_col] == ref) &
912
- (adata.obs[sample_col] == sample) &
913
- (adata.obs["read_quality"] >= min_quality) &
914
- (adata.obs["mapped_length"] >= min_length) &
915
- (adata.obs["mapped_length_to_reference_length_ratio"] >= min_mapped_length_to_reference_length_ratio)
916
- ]
1008
+ # Row-level masks (obs)
1009
+ qmask = _mask_or_true(
1010
+ "read_quality",
1011
+ (lambda s: s >= float(min_quality))
1012
+ if (min_quality is not None)
1013
+ else (lambda s: pd.Series(True, index=s.index)),
1014
+ )
1015
+ lm_mask = _mask_or_true(
1016
+ "mapped_length",
1017
+ (lambda s: s >= float(min_length))
1018
+ if (min_length is not None)
1019
+ else (lambda s: pd.Series(True, index=s.index)),
1020
+ )
1021
+ lrr_mask = _mask_or_true(
1022
+ "mapped_length_to_reference_length_ratio",
1023
+ (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
1024
+ if (min_mapped_length_to_reference_length_ratio is not None)
1025
+ else (lambda s: pd.Series(True, index=s.index)),
1026
+ )
1027
+
1028
+ demux_mask = _mask_or_true(
1029
+ "demux_type",
1030
+ (lambda s: s.astype("string").isin(list(demux_types)))
1031
+ if (demux_types is not None)
1032
+ else (lambda s: pd.Series(True, index=s.index)),
1033
+ )
1034
+
1035
+ ref_mask = adata.obs[reference_col] == ref
1036
+ sample_mask = adata.obs[sample_col] == sample
1037
+
1038
+ row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
1039
+
1040
+ if not bool(row_mask.any()):
1041
+ print(
1042
+ f"No reads for {display_sample} - {ref} after read quality and length filtering"
1043
+ )
1044
+ continue
917
1045
 
918
- # position-level mask
919
- valid_key = f"{ref}_valid_fraction"
920
- if valid_key in subset.var:
921
- mask = subset.var[valid_key].astype(float).values > float(min_position_valid_fraction)
922
- subset = subset[:, mask]
1046
+ try:
1047
+ subset = adata[row_mask, :].copy()
1048
+
1049
+ # Column-level mask (var)
1050
+ if min_position_valid_fraction is not None:
1051
+ valid_key = f"{ref}_valid_fraction"
1052
+ if valid_key in subset.var:
1053
+ v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
1054
+ col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
1055
+ if col_mask.any():
1056
+ subset = subset[:, col_mask].copy()
1057
+ else:
1058
+ print(
1059
+ f"No positions left after valid_fraction filter for {display_sample} - {ref}"
1060
+ )
1061
+ continue
923
1062
 
924
1063
  if subset.shape[0] == 0:
925
1064
  print(f"No reads left after filtering for {display_sample} - {ref}")
@@ -939,14 +1078,14 @@ def combined_raw_clustermap(
939
1078
 
940
1079
  if include_any_c:
941
1080
  any_c_sites = np.where(subset.var.get(f"{ref}_C_site", False).values)[0]
942
- gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
943
- cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
1081
+ gpc_sites = np.where(subset.var.get(f"{ref}_GpC_site", False).values)[0]
1082
+ cpg_sites = np.where(subset.var.get(f"{ref}_CpG_site", False).values)[0]
944
1083
 
945
1084
  num_any_c, num_gpc, num_cpg = len(any_c_sites), len(gpc_sites), len(cpg_sites)
946
1085
 
947
1086
  any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
948
- gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
949
- cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
1087
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
1088
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
950
1089
 
951
1090
  if include_any_a:
952
1091
  any_a_sites = np.where(subset.var.get(f"{ref}_A_site", False).values)[0]
@@ -978,15 +1117,21 @@ def combined_raw_clustermap(
978
1117
  order = np.argsort(subset_bin.obs[colname].values)
979
1118
 
980
1119
  elif sort_by == "gpc" and num_gpc > 0:
981
- linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
1120
+ linkage = sch.linkage(
1121
+ subset_bin[:, gpc_sites].layers[layer_gpc], method="ward"
1122
+ )
982
1123
  order = sch.leaves_list(linkage)
983
1124
 
984
1125
  elif sort_by == "cpg" and num_cpg > 0:
985
- linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
1126
+ linkage = sch.linkage(
1127
+ subset_bin[:, cpg_sites].layers[layer_cpg], method="ward"
1128
+ )
986
1129
  order = sch.leaves_list(linkage)
987
1130
 
988
1131
  elif sort_by == "c" and num_any_c > 0:
989
- linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_c], method="ward")
1132
+ linkage = sch.linkage(
1133
+ subset_bin[:, any_c_sites].layers[layer_c], method="ward"
1134
+ )
990
1135
  order = sch.leaves_list(linkage)
991
1136
 
992
1137
  elif sort_by == "gpc_cpg":
@@ -994,7 +1139,9 @@ def combined_raw_clustermap(
994
1139
  order = sch.leaves_list(linkage)
995
1140
 
996
1141
  elif sort_by == "a" and num_any_a > 0:
997
- linkage = sch.linkage(subset_bin[:, any_a_sites].layers[layer_a], method="ward")
1142
+ linkage = sch.linkage(
1143
+ subset_bin[:, any_a_sites].layers[layer_a], method="ward"
1144
+ )
998
1145
  order = sch.leaves_list(linkage)
999
1146
 
1000
1147
  elif sort_by == "none":
@@ -1027,57 +1174,65 @@ def combined_raw_clustermap(
1027
1174
 
1028
1175
  if include_any_c and stacked_any_c:
1029
1176
  any_c_matrix = np.vstack(stacked_any_c)
1030
- gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
1031
- cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
1177
+ gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
1178
+ cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
1032
1179
 
1033
1180
  mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
1034
- mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
1035
- mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
1181
+ mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
1182
+ mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
1036
1183
 
1037
1184
  if any_c_matrix.size:
1038
- blocks.append(dict(
1039
- name="c",
1040
- matrix=any_c_matrix,
1041
- mean=mean_any_c,
1042
- labels=any_c_labels,
1043
- cmap=cmap_c,
1044
- n_xticks=n_xticks_any_c,
1045
- title="any C site Modification Signal"
1046
- ))
1185
+ blocks.append(
1186
+ dict(
1187
+ name="c",
1188
+ matrix=any_c_matrix,
1189
+ mean=mean_any_c,
1190
+ labels=any_c_labels,
1191
+ cmap=cmap_c,
1192
+ n_xticks=n_xticks_any_c,
1193
+ title="any C site Modification Signal",
1194
+ )
1195
+ )
1047
1196
  if gpc_matrix.size:
1048
- blocks.append(dict(
1049
- name="gpc",
1050
- matrix=gpc_matrix,
1051
- mean=mean_gpc,
1052
- labels=gpc_labels,
1053
- cmap=cmap_gpc,
1054
- n_xticks=n_xticks_gpc,
1055
- title="GpC Modification Signal"
1056
- ))
1197
+ blocks.append(
1198
+ dict(
1199
+ name="gpc",
1200
+ matrix=gpc_matrix,
1201
+ mean=mean_gpc,
1202
+ labels=gpc_labels,
1203
+ cmap=cmap_gpc,
1204
+ n_xticks=n_xticks_gpc,
1205
+ title="GpC Modification Signal",
1206
+ )
1207
+ )
1057
1208
  if cpg_matrix.size:
1058
- blocks.append(dict(
1059
- name="cpg",
1060
- matrix=cpg_matrix,
1061
- mean=mean_cpg,
1062
- labels=cpg_labels,
1063
- cmap=cmap_cpg,
1064
- n_xticks=n_xticks_cpg,
1065
- title="CpG Modification Signal"
1066
- ))
1209
+ blocks.append(
1210
+ dict(
1211
+ name="cpg",
1212
+ matrix=cpg_matrix,
1213
+ mean=mean_cpg,
1214
+ labels=cpg_labels,
1215
+ cmap=cmap_cpg,
1216
+ n_xticks=n_xticks_cpg,
1217
+ title="CpG Modification Signal",
1218
+ )
1219
+ )
1067
1220
 
1068
1221
  if include_any_a and stacked_any_a:
1069
1222
  any_a_matrix = np.vstack(stacked_any_a)
1070
1223
  mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
1071
1224
  if any_a_matrix.size:
1072
- blocks.append(dict(
1073
- name="a",
1074
- matrix=any_a_matrix,
1075
- mean=mean_any_a,
1076
- labels=any_a_labels,
1077
- cmap=cmap_a,
1078
- n_xticks=n_xticks_any_a,
1079
- title="any A site Modification Signal"
1080
- ))
1225
+ blocks.append(
1226
+ dict(
1227
+ name="a",
1228
+ matrix=any_a_matrix,
1229
+ mean=mean_any_a,
1230
+ labels=any_a_labels,
1231
+ cmap=cmap_a,
1232
+ n_xticks=n_xticks_any_a,
1233
+ title="any A site Modification Signal",
1234
+ )
1235
+ )
1081
1236
 
1082
1237
  if not blocks:
1083
1238
  print(f"No matrices to plot for {display_sample} - {ref}")
@@ -1089,7 +1244,7 @@ def combined_raw_clustermap(
1089
1244
  fig.suptitle(f"{display_sample} - {ref} - {total_reads} reads", fontsize=14, y=0.97)
1090
1245
 
1091
1246
  axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
1092
- axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
1247
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
1093
1248
 
1094
1249
  # ----------------------------
1095
1250
  # plot blocks
@@ -1105,20 +1260,14 @@ def combined_raw_clustermap(
1105
1260
 
1106
1261
  # heatmap
1107
1262
  sns.heatmap(
1108
- mat,
1109
- cmap=blk["cmap"],
1110
- ax=axes_heat[i],
1111
- yticklabels=False,
1112
- cbar=False
1263
+ mat, cmap=blk["cmap"], ax=axes_heat[i], yticklabels=False, cbar=False
1113
1264
  )
1114
1265
 
1115
1266
  # fixed tick labels
1116
1267
  tick_pos = _fixed_tick_positions(len(labels), n_xticks)
1117
1268
  axes_heat[i].set_xticks(tick_pos)
1118
1269
  axes_heat[i].set_xticklabels(
1119
- labels[tick_pos],
1120
- rotation=xtick_rotation,
1121
- fontsize=xtick_fontsize
1270
+ labels[tick_pos], rotation=xtick_rotation, fontsize=xtick_fontsize
1122
1271
  )
1123
1272
 
1124
1273
  # bin separators
@@ -1131,7 +1280,12 @@ def combined_raw_clustermap(
1131
1280
 
1132
1281
  # save or show
1133
1282
  if save_path is not None:
1134
- safe_name = f"{ref}__{display_sample}".replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
1283
+ safe_name = (
1284
+ f"{ref}__{display_sample}".replace("=", "")
1285
+ .replace("__", "_")
1286
+ .replace(",", "_")
1287
+ .replace(" ", "_")
1288
+ )
1135
1289
  out_file = save_path / f"{safe_name}.png"
1136
1290
  fig.savefig(out_file, dpi=300)
1137
1291
  plt.close(fig)
@@ -1157,20 +1311,15 @@ def combined_raw_clustermap(
1157
1311
  for bin_label, percent in percentages.items():
1158
1312
  print(f" - {bin_label}: {percent:.1f}%")
1159
1313
 
1160
- except Exception as e:
1314
+ except Exception:
1161
1315
  import traceback
1316
+
1162
1317
  traceback.print_exc()
1163
1318
  continue
1164
1319
 
1165
- # store once at the end (HDF5 safe)
1166
- # matrices won't be HDF5-safe; store only metadata + maybe hit counts
1167
- # adata.uns["clustermap_results"] = [
1168
- # {k: v for k, v in r.items() if not k.endswith("_matrix")}
1169
- # for r in results
1170
- # ]
1171
-
1172
1320
  return results
1173
1321
 
1322
+
1174
1323
  def plot_hmm_layers_rolling_by_sample_ref(
1175
1324
  adata,
1176
1325
  layers: Optional[Sequence[str]] = None,
@@ -1237,7 +1386,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
1237
1386
 
1238
1387
  # --- basic checks / defaults ---
1239
1388
  if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
1240
- raise ValueError(f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs")
1389
+ raise ValueError(
1390
+ f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
1391
+ )
1241
1392
 
1242
1393
  # canonicalize samples / refs
1243
1394
  if samples is None:
@@ -1260,7 +1411,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
1260
1411
  if layers is None:
1261
1412
  layers = list(adata.layers.keys())
1262
1413
  if len(layers) == 0:
1263
- raise ValueError("No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot.")
1414
+ raise ValueError(
1415
+ "No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
1416
+ )
1264
1417
  layers = list(layers)
1265
1418
 
1266
1419
  # x coordinates (positions)
@@ -1299,19 +1452,29 @@ def plot_hmm_layers_rolling_by_sample_ref(
1299
1452
 
1300
1453
  fig_w = figsize_per_cell[0] * ncols
1301
1454
  fig_h = figsize_per_cell[1] * nrows
1302
- fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
1303
- figsize=(fig_w, fig_h), dpi=dpi,
1304
- squeeze=False)
1455
+ fig, axes = plt.subplots(
1456
+ nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h), dpi=dpi, squeeze=False
1457
+ )
1305
1458
 
1306
1459
  for r_idx, sample_name in enumerate(chunk):
1307
1460
  for c_idx, ref_name in enumerate(refs_all):
1308
1461
  ax = axes[r_idx][c_idx]
1309
1462
 
1310
1463
  # subset adata
1311
- mask = (adata.obs[sample_col].values == sample_name) & (adata.obs[ref_col].values == ref_name)
1464
+ mask = (adata.obs[sample_col].values == sample_name) & (
1465
+ adata.obs[ref_col].values == ref_name
1466
+ )
1312
1467
  sub = adata[mask]
1313
1468
  if sub.n_obs == 0:
1314
- ax.text(0.5, 0.5, "No reads", ha="center", va="center", transform=ax.transAxes, color="gray")
1469
+ ax.text(
1470
+ 0.5,
1471
+ 0.5,
1472
+ "No reads",
1473
+ ha="center",
1474
+ va="center",
1475
+ transform=ax.transAxes,
1476
+ color="gray",
1477
+ )
1315
1478
  ax.set_xticks([])
1316
1479
  ax.set_yticks([])
1317
1480
  if r_idx == 0:
@@ -1361,7 +1524,11 @@ def plot_hmm_layers_rolling_by_sample_ref(
1361
1524
  smoothed = col_mean
1362
1525
  else:
1363
1526
  ser = pd.Series(col_mean)
1364
- smoothed = ser.rolling(window=window, min_periods=min_periods, center=center).mean().to_numpy()
1527
+ smoothed = (
1528
+ ser.rolling(window=window, min_periods=min_periods, center=center)
1529
+ .mean()
1530
+ .to_numpy()
1531
+ )
1365
1532
 
1366
1533
  # x axis: x_coords (trim/pad to match length)
1367
1534
  L = len(col_mean)
@@ -1371,7 +1538,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
1371
1538
  if show_raw:
1372
1539
  ax.plot(x, col_mean[:L], linewidth=0.7, alpha=0.25, zorder=1)
1373
1540
 
1374
- ax.plot(x, smoothed[:L], label=layer, color=colors[li], linewidth=1.2, alpha=0.95, zorder=2)
1541
+ ax.plot(
1542
+ x,
1543
+ smoothed[:L],
1544
+ label=layer,
1545
+ color=colors[li],
1546
+ linewidth=1.2,
1547
+ alpha=0.95,
1548
+ zorder=2,
1549
+ )
1375
1550
  plotted_any = True
1376
1551
 
1377
1552
  # labels / titles
@@ -1389,11 +1564,15 @@ def plot_hmm_layers_rolling_by_sample_ref(
1389
1564
 
1390
1565
  ax.grid(True, alpha=0.2)
1391
1566
 
1392
- fig.suptitle(f"Rolling mean of layer positional means (window={window}) — page {page+1}/{total_pages}", fontsize=11, y=0.995)
1567
+ fig.suptitle(
1568
+ f"Rolling mean of layer positional means (window={window}) — page {page + 1}/{total_pages}",
1569
+ fontsize=11,
1570
+ y=0.995,
1571
+ )
1393
1572
  fig.tight_layout(rect=[0, 0, 1, 0.97])
1394
1573
 
1395
1574
  if save:
1396
- fname = os.path.join(outdir, f"hmm_layers_rolling_page{page+1}.png")
1575
+ fname = os.path.join(outdir, f"hmm_layers_rolling_page{page + 1}.png")
1397
1576
  plt.savefig(fname, bbox_inches="tight", dpi=dpi)
1398
1577
  saved_files.append(fname)
1399
1578
  else: