smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164) hide show
  1. smftools/__init__.py +39 -7
  2. smftools/_settings.py +2 -0
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +34 -6
  7. smftools/cli/hmm_adata.py +239 -33
  8. smftools/cli/latent_adata.py +318 -0
  9. smftools/cli/load_adata.py +167 -131
  10. smftools/cli/preprocess_adata.py +180 -53
  11. smftools/cli/spatial_adata.py +152 -100
  12. smftools/cli_entry.py +38 -1
  13. smftools/config/__init__.py +2 -0
  14. smftools/config/conversion.yaml +11 -1
  15. smftools/config/default.yaml +42 -2
  16. smftools/config/experiment_config.py +59 -1
  17. smftools/constants.py +65 -0
  18. smftools/datasets/__init__.py +2 -0
  19. smftools/hmm/HMM.py +97 -3
  20. smftools/hmm/__init__.py +24 -13
  21. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  22. smftools/hmm/archived/calculate_distances.py +2 -0
  23. smftools/hmm/archived/call_hmm_peaks.py +2 -0
  24. smftools/hmm/archived/train_hmm.py +2 -0
  25. smftools/hmm/call_hmm_peaks.py +5 -2
  26. smftools/hmm/display_hmm.py +4 -1
  27. smftools/hmm/hmm_readwrite.py +7 -2
  28. smftools/hmm/nucleosome_hmm_refinement.py +2 -0
  29. smftools/informatics/__init__.py +59 -34
  30. smftools/informatics/archived/bam_conversion.py +2 -0
  31. smftools/informatics/archived/bam_direct.py +2 -0
  32. smftools/informatics/archived/basecall_pod5s.py +2 -0
  33. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  34. smftools/informatics/archived/conversion_smf.py +2 -0
  35. smftools/informatics/archived/deaminase_smf.py +1 -0
  36. smftools/informatics/archived/direct_smf.py +2 -0
  37. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  38. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  39. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
  40. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  41. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  42. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  43. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  44. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  45. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  46. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  47. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  48. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  49. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  50. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  52. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  53. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  54. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  55. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  56. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  57. smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
  58. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  59. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  60. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  61. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  62. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  63. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  64. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  65. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
  66. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  67. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  68. smftools/informatics/archived/print_bam_query_seq.py +2 -0
  69. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  70. smftools/informatics/archived/subsample_pod5.py +2 -0
  71. smftools/informatics/bam_functions.py +1093 -176
  72. smftools/informatics/basecalling.py +2 -0
  73. smftools/informatics/bed_functions.py +271 -61
  74. smftools/informatics/binarize_converted_base_identities.py +3 -0
  75. smftools/informatics/complement_base_list.py +2 -0
  76. smftools/informatics/converted_BAM_to_adata.py +641 -176
  77. smftools/informatics/fasta_functions.py +94 -10
  78. smftools/informatics/h5ad_functions.py +123 -4
  79. smftools/informatics/modkit_extract_to_adata.py +1019 -431
  80. smftools/informatics/modkit_functions.py +2 -0
  81. smftools/informatics/ohe.py +2 -0
  82. smftools/informatics/pod5_functions.py +3 -2
  83. smftools/informatics/sequence_encoding.py +72 -0
  84. smftools/logging_utils.py +21 -2
  85. smftools/machine_learning/__init__.py +22 -6
  86. smftools/machine_learning/data/__init__.py +2 -0
  87. smftools/machine_learning/data/anndata_data_module.py +18 -4
  88. smftools/machine_learning/data/preprocessing.py +2 -0
  89. smftools/machine_learning/evaluation/__init__.py +2 -0
  90. smftools/machine_learning/evaluation/eval_utils.py +2 -0
  91. smftools/machine_learning/evaluation/evaluators.py +14 -9
  92. smftools/machine_learning/inference/__init__.py +2 -0
  93. smftools/machine_learning/inference/inference_utils.py +2 -0
  94. smftools/machine_learning/inference/lightning_inference.py +6 -1
  95. smftools/machine_learning/inference/sklearn_inference.py +2 -0
  96. smftools/machine_learning/inference/sliding_window_inference.py +2 -0
  97. smftools/machine_learning/models/__init__.py +2 -0
  98. smftools/machine_learning/models/base.py +7 -2
  99. smftools/machine_learning/models/cnn.py +7 -2
  100. smftools/machine_learning/models/lightning_base.py +16 -11
  101. smftools/machine_learning/models/mlp.py +5 -1
  102. smftools/machine_learning/models/positional.py +7 -2
  103. smftools/machine_learning/models/rnn.py +5 -1
  104. smftools/machine_learning/models/sklearn_models.py +14 -9
  105. smftools/machine_learning/models/transformer.py +7 -2
  106. smftools/machine_learning/models/wrappers.py +6 -2
  107. smftools/machine_learning/training/__init__.py +2 -0
  108. smftools/machine_learning/training/train_lightning_model.py +13 -3
  109. smftools/machine_learning/training/train_sklearn_model.py +2 -0
  110. smftools/machine_learning/utils/__init__.py +2 -0
  111. smftools/machine_learning/utils/device.py +5 -1
  112. smftools/machine_learning/utils/grl.py +5 -1
  113. smftools/metadata.py +1 -1
  114. smftools/optional_imports.py +31 -0
  115. smftools/plotting/__init__.py +41 -31
  116. smftools/plotting/autocorrelation_plotting.py +9 -5
  117. smftools/plotting/classifiers.py +16 -4
  118. smftools/plotting/general_plotting.py +2415 -629
  119. smftools/plotting/hmm_plotting.py +97 -9
  120. smftools/plotting/position_stats.py +15 -7
  121. smftools/plotting/qc_plotting.py +6 -1
  122. smftools/preprocessing/__init__.py +36 -37
  123. smftools/preprocessing/append_base_context.py +17 -17
  124. smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
  125. smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
  126. smftools/preprocessing/archived/calculate_complexity.py +2 -0
  127. smftools/preprocessing/archived/mark_duplicates.py +2 -0
  128. smftools/preprocessing/archived/preprocessing.py +2 -0
  129. smftools/preprocessing/archived/remove_duplicates.py +2 -0
  130. smftools/preprocessing/binary_layers_to_ohe.py +2 -1
  131. smftools/preprocessing/calculate_complexity_II.py +4 -1
  132. smftools/preprocessing/calculate_consensus.py +1 -1
  133. smftools/preprocessing/calculate_pairwise_differences.py +2 -0
  134. smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
  135. smftools/preprocessing/calculate_position_Youden.py +9 -2
  136. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  137. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
  138. smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
  139. smftools/preprocessing/flag_duplicate_reads.py +42 -54
  140. smftools/preprocessing/make_dirs.py +2 -1
  141. smftools/preprocessing/min_non_diagonal.py +2 -0
  142. smftools/preprocessing/recipes.py +2 -0
  143. smftools/readwrite.py +53 -17
  144. smftools/schema/anndata_schema_v1.yaml +15 -1
  145. smftools/tools/__init__.py +30 -18
  146. smftools/tools/archived/apply_hmm.py +2 -0
  147. smftools/tools/archived/classifiers.py +2 -0
  148. smftools/tools/archived/classify_methylated_features.py +2 -0
  149. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  150. smftools/tools/archived/subset_adata_v1.py +2 -0
  151. smftools/tools/archived/subset_adata_v2.py +2 -0
  152. smftools/tools/calculate_leiden.py +57 -0
  153. smftools/tools/calculate_nmf.py +119 -0
  154. smftools/tools/calculate_umap.py +93 -8
  155. smftools/tools/cluster_adata_on_methylation.py +7 -1
  156. smftools/tools/position_stats.py +17 -27
  157. smftools/tools/rolling_nn_distance.py +235 -0
  158. smftools/tools/tensor_factorization.py +169 -0
  159. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
  160. smftools-0.3.1.dist-info/RECORD +189 -0
  161. smftools-0.2.5.dist-info/RECORD +0 -181
  162. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
  163. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
  164. {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,16 +1,37 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import ast
4
+ import json
3
5
  import math
4
6
  import os
5
7
  from pathlib import Path
6
- from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Tuple
7
9
 
8
- import matplotlib.gridspec as gridspec
9
- import matplotlib.pyplot as plt
10
10
  import numpy as np
11
11
  import pandas as pd
12
12
  import scipy.cluster.hierarchy as sch
13
- import seaborn as sns
13
+
14
+ from smftools.logging_utils import get_logger
15
+ from smftools.optional_imports import require
16
+
17
+ colors = require("matplotlib.colors", extra="plotting", purpose="plot rendering")
18
+ gridspec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
19
+ patches = require("matplotlib.patches", extra="plotting", purpose="plot rendering")
20
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
21
+ sns = require("seaborn", extra="plotting", purpose="plot styling")
22
+
23
+ logger = get_logger(__name__)
24
+
25
+ DNA_5COLOR_PALETTE = {
26
+ "A": "#00A000", # green
27
+ "C": "#0000FF", # blue
28
+ "G": "#FF7F00", # orange
29
+ "T": "#FF0000", # red
30
+ "OTHER": "#808080", # gray (N, PAD, unknown)
31
+ }
32
+
33
+ if TYPE_CHECKING:
34
+ import anndata as ad
14
35
 
15
36
 
16
37
  def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
@@ -68,7 +89,7 @@ def _select_labels(subset, sites: np.ndarray, reference: str, index_col_suffix:
68
89
  return labels[sites]
69
90
 
70
91
 
71
- def normalized_mean(matrix: np.ndarray) -> np.ndarray:
92
+ def normalized_mean(matrix: np.ndarray, *, ignore_nan: bool = True) -> np.ndarray:
72
93
  """Compute normalized column means for a matrix.
73
94
 
74
95
  Args:
@@ -77,19 +98,362 @@ def normalized_mean(matrix: np.ndarray) -> np.ndarray:
77
98
  Returns:
78
99
  1D array of normalized means.
79
100
  """
80
- mean = np.nanmean(matrix, axis=0)
101
+ mean = np.nanmean(matrix, axis=0) if ignore_nan else np.mean(matrix, axis=0)
81
102
  denom = (mean.max() - mean.min()) + 1e-9
82
103
  return (mean - mean.min()) / denom
83
104
 
84
105
 
85
- def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
106
+ def plot_nmf_components(
107
+ adata: "ad.AnnData",
108
+ *,
109
+ output_dir: Path | str,
110
+ components_key: str = "H_nmf",
111
+ heatmap_name: str = "nmf_H_heatmap.png",
112
+ lineplot_name: str = "nmf_H_lineplot.png",
113
+ max_features: int = 2000,
114
+ ) -> Dict[str, Path]:
115
+ """Plot NMF component weights as a heatmap and per-component line plot.
116
+
117
+ Args:
118
+ adata: AnnData object containing NMF results.
119
+ output_dir: Directory to write plots into.
120
+ components_key: Key in ``adata.varm`` storing the H matrix.
121
+ heatmap_name: Filename for the heatmap plot.
122
+ lineplot_name: Filename for the line plot.
123
+ max_features: Maximum number of features to plot (top-weighted by component).
124
+
125
+ Returns:
126
+ Dict[str, Path]: Paths to created plots (keys: ``heatmap`` and ``lineplot``).
127
+ """
128
+ if components_key not in adata.varm:
129
+ logger.warning("NMF components key '%s' not found in adata.varm.", components_key)
130
+ return {}
131
+
132
+ output_path = Path(output_dir)
133
+ output_path.mkdir(parents=True, exist_ok=True)
134
+
135
+ components = np.asarray(adata.varm[components_key])
136
+ if components.ndim != 2:
137
+ raise ValueError(f"NMF components must be 2D; got shape {components.shape}.")
138
+
139
+ feature_labels = (
140
+ np.asarray(adata.var_names).astype(str)
141
+ if adata.shape[1] == components.shape[0]
142
+ else np.array([str(i) for i in range(components.shape[0])])
143
+ )
144
+
145
+ nonzero_mask = np.any(components != 0, axis=1)
146
+ if not np.any(nonzero_mask):
147
+ logger.warning("NMF components are all zeros; skipping plot generation.")
148
+ return {}
149
+
150
+ components = components[nonzero_mask]
151
+ feature_labels = feature_labels[nonzero_mask]
152
+
153
+ if max_features and components.shape[0] > max_features:
154
+ scores = np.nanmax(components, axis=1)
155
+ top_idx = np.argsort(scores)[-max_features:]
156
+ top_idx = np.sort(top_idx)
157
+ components = components[top_idx]
158
+ feature_labels = feature_labels[top_idx]
159
+ logger.info(
160
+ "Downsampled NMF features from %s to %s for plotting.",
161
+ nonzero_mask.sum(),
162
+ components.shape[0],
163
+ )
164
+
165
+ n_features, n_components = components.shape
166
+ component_labels = [f"C{i + 1}" for i in range(n_components)]
167
+
168
+ heatmap_width = max(8, min(20, n_features / 60))
169
+ heatmap_height = max(2.5, 0.6 * n_components + 1.5)
170
+ fig, ax = plt.subplots(figsize=(heatmap_width, heatmap_height))
171
+ sns.heatmap(
172
+ components.T,
173
+ ax=ax,
174
+ cmap="viridis",
175
+ cbar_kws={"label": "Component weight"},
176
+ xticklabels=feature_labels if n_features <= 60 else False,
177
+ yticklabels=component_labels,
178
+ )
179
+ ax.set_xlabel("Feature")
180
+ ax.set_ylabel("NMF component")
181
+ fig.tight_layout()
182
+ heatmap_path = output_path / heatmap_name
183
+ fig.savefig(heatmap_path, dpi=200)
184
+ plt.close(fig)
185
+
186
+ fig, ax = plt.subplots(figsize=(max(8, min(20, n_features / 50)), 3.5))
187
+ x = np.arange(n_features)
188
+ for idx, label in enumerate(component_labels):
189
+ ax.plot(x, components[:, idx], label=label, linewidth=1.5)
190
+ ax.set_xlabel("Feature index")
191
+ ax.set_ylabel("Component weight")
192
+ if n_features <= 60:
193
+ ax.set_xticks(x)
194
+ ax.set_xticklabels(feature_labels, rotation=90, fontsize=8)
195
+ ax.legend(loc="upper right", frameon=False)
196
+ fig.tight_layout()
197
+ lineplot_path = output_path / lineplot_name
198
+ fig.savefig(lineplot_path, dpi=200)
199
+ plt.close(fig)
200
+
201
+ return {"heatmap": heatmap_path, "lineplot": lineplot_path}
202
+
203
+
204
+ def plot_cp_sequence_components(
205
+ adata: "ad.AnnData",
206
+ *,
207
+ output_dir: Path | str,
208
+ components_key: str = "H_cp_sequence",
209
+ uns_key: str = "cp_sequence",
210
+ heatmap_name: str = "cp_sequence_position_heatmap.png",
211
+ lineplot_name: str = "cp_sequence_position_lineplot.png",
212
+ base_name: str = "cp_sequence_base_weights.png",
213
+ max_positions: int = 2000,
214
+ ) -> Dict[str, Path]:
215
+ """Plot CP decomposition position and base factors.
216
+
217
+ Args:
218
+ adata: AnnData object containing CP decomposition results.
219
+ output_dir: Directory to write plots into.
220
+ components_key: Key in ``adata.varm`` storing position factors.
221
+ uns_key: Key in ``adata.uns`` storing base factors.
222
+ heatmap_name: Filename for position heatmap.
223
+ lineplot_name: Filename for position line plot.
224
+ base_name: Filename for base factor bar plot.
225
+ max_positions: Maximum number of positions to plot.
226
+
227
+ Returns:
228
+ Dict[str, Path]: Paths to created plots.
229
+ """
230
+ if components_key not in adata.varm:
231
+ logger.warning("CP components key '%s' not found in adata.varm.", components_key)
232
+ return {}
233
+
234
+ output_path = Path(output_dir)
235
+ output_path.mkdir(parents=True, exist_ok=True)
236
+
237
+ components = np.asarray(adata.varm[components_key])
238
+ if components.ndim != 2:
239
+ raise ValueError(f"CP position factors must be 2D; got shape {components.shape}.")
240
+
241
+ feature_labels = (
242
+ np.asarray(adata.var_names).astype(str)
243
+ if adata.shape[1] == components.shape[0]
244
+ else np.array([str(i) for i in range(components.shape[0])])
245
+ )
246
+
247
+ if max_positions and components.shape[0] > max_positions:
248
+ original_count = components.shape[0]
249
+ scores = np.nanmax(np.abs(components), axis=1)
250
+ top_idx = np.argsort(scores)[-max_positions:]
251
+ top_idx = np.sort(top_idx)
252
+ components = components[top_idx]
253
+ feature_labels = feature_labels[top_idx]
254
+ logger.info(
255
+ "Downsampled CP positions from %s to %s for plotting.",
256
+ original_count,
257
+ max_positions,
258
+ )
259
+
260
+ n_positions, n_components = components.shape
261
+ component_labels = [f"C{i + 1}" for i in range(n_components)]
262
+
263
+ heatmap_width = max(8, min(20, n_positions / 60))
264
+ heatmap_height = max(2.5, 0.6 * n_components + 1.5)
265
+ fig, ax = plt.subplots(figsize=(heatmap_width, heatmap_height))
266
+ sns.heatmap(
267
+ components.T,
268
+ ax=ax,
269
+ cmap="viridis",
270
+ cbar_kws={"label": "Component weight"},
271
+ xticklabels=feature_labels if n_positions <= 60 else False,
272
+ yticklabels=component_labels,
273
+ )
274
+ ax.set_xlabel("Position")
275
+ ax.set_ylabel("CP component")
276
+ fig.tight_layout()
277
+ heatmap_path = output_path / heatmap_name
278
+ fig.savefig(heatmap_path, dpi=200)
279
+ plt.close(fig)
280
+
281
+ fig, ax = plt.subplots(figsize=(max(8, min(20, n_positions / 50)), 3.5))
282
+ x = np.arange(n_positions)
283
+ for idx, label in enumerate(component_labels):
284
+ ax.plot(x, components[:, idx], label=label, linewidth=1.5)
285
+ ax.set_xlabel("Position index")
286
+ ax.set_ylabel("Component weight")
287
+ if n_positions <= 60:
288
+ ax.set_xticks(x)
289
+ ax.set_xticklabels(feature_labels, rotation=90, fontsize=8)
290
+ ax.legend(loc="upper right", frameon=False)
291
+ fig.tight_layout()
292
+ lineplot_path = output_path / lineplot_name
293
+ fig.savefig(lineplot_path, dpi=200)
294
+ plt.close(fig)
295
+
296
+ outputs = {"heatmap": heatmap_path, "lineplot": lineplot_path}
297
+ if uns_key in adata.uns:
298
+ base_factors = adata.uns[uns_key].get("base_factors")
299
+ base_labels = adata.uns[uns_key].get("base_labels")
300
+ if base_factors is not None:
301
+ base_factors = np.asarray(base_factors)
302
+ if base_factors.ndim != 2 or base_factors.size == 0:
303
+ logger.warning(
304
+ "CP base factors must be 2D and non-empty; got shape %s.",
305
+ base_factors.shape,
306
+ )
307
+ else:
308
+ base_labels = base_labels or [f"B{i + 1}" for i in range(base_factors.shape[0])]
309
+ fig, ax = plt.subplots(figsize=(4.5, 3))
310
+ width = 0.8 / base_factors.shape[1]
311
+ x = np.arange(base_factors.shape[0])
312
+ for idx in range(base_factors.shape[1]):
313
+ ax.bar(
314
+ x + idx * width,
315
+ base_factors[:, idx],
316
+ width=width,
317
+ label=f"C{idx + 1}",
318
+ )
319
+ ax.set_xticks(x + width * (base_factors.shape[1] - 1) / 2)
320
+ ax.set_xticklabels(base_labels)
321
+ ax.set_ylabel("Base factor weight")
322
+ ax.legend(loc="upper right", frameon=False)
323
+ fig.tight_layout()
324
+ base_path = output_path / base_name
325
+ fig.savefig(base_path, dpi=200)
326
+ plt.close(fig)
327
+ outputs["base_factors"] = base_path
328
+
329
+ return outputs
330
+
331
+
332
+ def _resolve_feature_color(cmap: Any) -> Tuple[float, float, float, float]:
333
+ """Resolve a representative feature color from a colormap or color spec."""
334
+ if isinstance(cmap, str):
335
+ try:
336
+ cmap_obj = plt.get_cmap(cmap)
337
+ return colors.to_rgba(cmap_obj(1.0))
338
+ except Exception:
339
+ return colors.to_rgba(cmap)
340
+
341
+ if isinstance(cmap, colors.Colormap):
342
+ if hasattr(cmap, "colors") and cmap.colors:
343
+ return colors.to_rgba(cmap.colors[-1])
344
+ return colors.to_rgba(cmap(1.0))
345
+
346
+ return colors.to_rgba("black")
347
+
348
+
349
+ def _build_hmm_feature_cmap(
350
+ cmap: Any,
351
+ *,
352
+ zero_color: str = "#f5f1e8",
353
+ nan_color: str = "#E6E6E6",
354
+ ) -> colors.Colormap:
355
+ """Build a two-color HMM colormap with explicit NaN/under handling."""
356
+ feature_color = _resolve_feature_color(cmap)
357
+ hmm_cmap = colors.LinearSegmentedColormap.from_list(
358
+ "hmm_feature_cmap",
359
+ [zero_color, feature_color],
360
+ )
361
+ hmm_cmap.set_bad(nan_color)
362
+ hmm_cmap.set_under(nan_color)
363
+ return hmm_cmap
364
+
365
+
366
+ def _map_length_matrix_to_subclasses(
367
+ length_matrix: np.ndarray,
368
+ feature_ranges: Sequence[Tuple[int, int, Any]],
369
+ ) -> np.ndarray:
370
+ """Map length values into subclass integer codes based on feature ranges."""
371
+ mapped = np.zeros_like(length_matrix, dtype=float)
372
+ finite_mask = np.isfinite(length_matrix)
373
+ for idx, (min_len, max_len, _color) in enumerate(feature_ranges, start=1):
374
+ mask = finite_mask & (length_matrix >= min_len) & (length_matrix <= max_len)
375
+ mapped[mask] = float(idx)
376
+ mapped[~finite_mask] = np.nan
377
+ return mapped
378
+
379
+
380
+ def _build_length_feature_cmap(
381
+ feature_ranges: Sequence[Tuple[int, int, Any]],
382
+ *,
383
+ zero_color: str = "#f5f1e8",
384
+ nan_color: str = "#E6E6E6",
385
+ ) -> Tuple[colors.Colormap, colors.BoundaryNorm]:
386
+ """Build a discrete colormap and norm for length-based subclasses."""
387
+ color_list = [zero_color] + [color for _, _, color in feature_ranges]
388
+ cmap = colors.ListedColormap(color_list, name="hmm_length_feature_cmap")
389
+ cmap.set_bad(nan_color)
390
+ bounds = np.arange(-0.5, len(color_list) + 0.5, 1)
391
+ norm = colors.BoundaryNorm(bounds, cmap.N)
392
+ return cmap, norm
393
+
394
+
395
+ def _layer_to_numpy(
396
+ subset,
397
+ layer_name: str,
398
+ sites: np.ndarray | None = None,
399
+ *,
400
+ fill_nan_strategy: str = "value",
401
+ fill_nan_value: float = -1,
402
+ ) -> np.ndarray:
403
+ """Return a (copied) numpy array for a layer with optional NaN filling."""
404
+ if sites is not None:
405
+ layer_data = subset[:, sites].layers[layer_name]
406
+ else:
407
+ layer_data = subset.layers[layer_name]
408
+
409
+ if hasattr(layer_data, "toarray"):
410
+ arr = layer_data.toarray()
411
+ else:
412
+ arr = np.asarray(layer_data)
413
+
414
+ arr = np.array(arr, copy=True)
415
+
416
+ if fill_nan_strategy == "none":
417
+ return arr
418
+
419
+ if fill_nan_strategy not in {"value", "col_mean"}:
420
+ raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
421
+
422
+ arr = arr.astype(float, copy=False)
423
+
424
+ if fill_nan_strategy == "value":
425
+ return np.where(np.isnan(arr), fill_nan_value, arr)
426
+
427
+ col_mean = np.nanmean(arr, axis=0)
428
+ if np.any(np.isnan(col_mean)):
429
+ col_mean = np.where(np.isnan(col_mean), fill_nan_value, col_mean)
430
+ return np.where(np.isnan(arr), col_mean, arr)
431
+
432
+
433
+ def _infer_zero_is_valid(layer_name: str | None, matrix: np.ndarray) -> bool:
434
+ """Infer whether zeros should count as valid (unmethylated) values."""
435
+ if layer_name and "nan0_0minus1" in layer_name:
436
+ return False
437
+ if np.isnan(matrix).any():
438
+ return True
439
+ if np.any(matrix < 0):
440
+ return False
441
+ return True
442
+
443
+
444
+ def methylation_fraction(
445
+ matrix: np.ndarray, *, ignore_nan: bool = True, zero_is_valid: bool = False
446
+ ) -> np.ndarray:
86
447
  """
87
448
  Fraction methylated per column.
88
449
  Methylated = 1
89
- Valid = finite AND not 0
450
+ Valid = finite AND not 0 (unless zero_is_valid=True)
90
451
  """
91
452
  matrix = np.asarray(matrix)
92
- valid_mask = np.isfinite(matrix) & (matrix != 0)
453
+ if not ignore_nan:
454
+ matrix = np.where(np.isnan(matrix), 0, matrix)
455
+ finite_mask = np.isfinite(matrix)
456
+ valid_mask = finite_mask if zero_is_valid else (finite_mask & (matrix != 0))
93
457
  methyl_mask = (matrix == 1) & np.isfinite(matrix)
94
458
 
95
459
  methylated = methyl_mask.sum(axis=0)
@@ -100,20 +464,53 @@ def methylation_fraction(matrix: np.ndarray) -> np.ndarray:
100
464
  )
101
465
 
102
466
 
103
- def clean_barplot(ax, mean_values, title):
467
+ def _methylation_fraction_for_layer(
468
+ matrix: np.ndarray,
469
+ layer_name: str | None,
470
+ *,
471
+ ignore_nan: bool = True,
472
+ zero_is_valid: bool | None = None,
473
+ ) -> np.ndarray:
474
+ """Compute methylation fractions with layer-aware zero handling."""
475
+ matrix = np.asarray(matrix)
476
+ if zero_is_valid is None:
477
+ zero_is_valid = _infer_zero_is_valid(layer_name, matrix)
478
+ return methylation_fraction(matrix, ignore_nan=ignore_nan, zero_is_valid=zero_is_valid)
479
+
480
+
481
+ def clean_barplot(
482
+ ax,
483
+ mean_values,
484
+ title,
485
+ *,
486
+ y_max: float | None = 1.0,
487
+ y_label: str = "Mean",
488
+ y_ticks: list[float] | None = None,
489
+ ):
104
490
  """Format a barplot with consistent axes and labels.
105
491
 
106
492
  Args:
107
493
  ax: Matplotlib axes.
108
494
  mean_values: Values to plot.
109
495
  title: Plot title.
496
+ y_max: Optional y-axis max; inferred from data if not provided.
497
+ y_label: Y-axis label.
498
+ y_ticks: Optional y-axis ticks.
110
499
  """
111
500
  x = np.arange(len(mean_values))
112
501
  ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
113
502
  ax.set_xlim(0, len(mean_values))
114
- ax.set_ylim(0, 1)
115
- ax.set_yticks([0.0, 0.5, 1.0])
116
- ax.set_ylabel("Mean")
503
+ if y_ticks is None and y_max == 1.0:
504
+ y_ticks = [0.0, 0.5, 1.0]
505
+ if y_max is None:
506
+ y_max = np.nanmax(mean_values) if len(mean_values) else 1.0
507
+ if not np.isfinite(y_max) or y_max <= 0:
508
+ y_max = 1.0
509
+ y_max *= 1.05
510
+ ax.set_ylim(0, y_max)
511
+ if y_ticks is not None:
512
+ ax.set_yticks(y_ticks)
513
+ ax.set_ylabel(y_label)
117
514
  ax.set_title(title, fontsize=12, pad=2)
118
515
 
119
516
  # Hide all spines except left
@@ -123,222 +520,6 @@ def clean_barplot(ax, mean_values, title):
123
520
  ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
124
521
 
125
522
 
126
- # def combined_hmm_raw_clustermap(
127
- # adata,
128
- # sample_col='Sample_Names',
129
- # reference_col='Reference_strand',
130
- # hmm_feature_layer="hmm_combined",
131
- # layer_gpc="nan0_0minus1",
132
- # layer_cpg="nan0_0minus1",
133
- # layer_any_c="nan0_0minus1",
134
- # cmap_hmm="tab10",
135
- # cmap_gpc="coolwarm",
136
- # cmap_cpg="viridis",
137
- # cmap_any_c='coolwarm',
138
- # min_quality=20,
139
- # min_length=200,
140
- # min_mapped_length_to_reference_length_ratio=0.8,
141
- # min_position_valid_fraction=0.5,
142
- # sample_mapping=None,
143
- # save_path=None,
144
- # normalize_hmm=False,
145
- # sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', or 'obs:<column>'
146
- # bins=None,
147
- # deaminase=False,
148
- # min_signal=0
149
- # ):
150
-
151
- # results = []
152
- # if deaminase:
153
- # signal_type = 'deamination'
154
- # else:
155
- # signal_type = 'methylation'
156
-
157
- # for ref in adata.obs[reference_col].cat.categories:
158
- # for sample in adata.obs[sample_col].cat.categories:
159
- # try:
160
- # subset = adata[
161
- # (adata.obs[reference_col] == ref) &
162
- # (adata.obs[sample_col] == sample) &
163
- # (adata.obs['read_quality'] >= min_quality) &
164
- # (adata.obs['read_length'] >= min_length) &
165
- # (adata.obs['mapped_length_to_reference_length_ratio'] > min_mapped_length_to_reference_length_ratio)
166
- # ]
167
-
168
- # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
169
- # subset = subset[:, mask]
170
-
171
- # if subset.shape[0] == 0:
172
- # print(f" No reads left after filtering for {sample} - {ref}")
173
- # continue
174
-
175
- # if bins:
176
- # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
177
- # bins_temp = bins
178
- # else:
179
- # print(f"Using all reads for clustermap for {sample} - {ref}")
180
- # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
181
-
182
- # # Get column positions (not var_names!) of site masks
183
- # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
184
- # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
185
- # any_c_sites = np.where(subset.var[f"{ref}_any_C_site"].values)[0]
186
- # num_gpc = len(gpc_sites)
187
- # num_cpg = len(cpg_sites)
188
- # num_c = len(any_c_sites)
189
- # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites} for {sample} - {ref}")
190
-
191
- # # Use var_names for x-axis tick labels
192
- # gpc_labels = subset.var_names[gpc_sites].astype(int)
193
- # cpg_labels = subset.var_names[cpg_sites].astype(int)
194
- # any_c_labels = subset.var_names[any_c_sites].astype(int)
195
-
196
- # stacked_hmm_feature, stacked_gpc, stacked_cpg, stacked_any_c = [], [], [], []
197
- # row_labels, bin_labels = [], []
198
- # bin_boundaries = []
199
-
200
- # total_reads = subset.shape[0]
201
- # percentages = {}
202
- # last_idx = 0
203
-
204
- # for bin_label, bin_filter in bins_temp.items():
205
- # subset_bin = subset[bin_filter].copy()
206
- # num_reads = subset_bin.shape[0]
207
- # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
208
- # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
209
- # percentages[bin_label] = percent_reads
210
-
211
- # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
212
- # # Determine sorting order
213
- # if sort_by.startswith("obs:"):
214
- # colname = sort_by.split("obs:")[1]
215
- # order = np.argsort(subset_bin.obs[colname].values)
216
- # elif sort_by == "gpc":
217
- # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
218
- # order = sch.leaves_list(linkage)
219
- # elif sort_by == "cpg":
220
- # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
221
- # order = sch.leaves_list(linkage)
222
- # elif sort_by == "gpc_cpg":
223
- # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
224
- # order = sch.leaves_list(linkage)
225
- # elif sort_by == "none":
226
- # order = np.arange(num_reads)
227
- # elif sort_by == "any_c":
228
- # linkage = sch.linkage(subset_bin.layers[layer_any_c], method="ward")
229
- # order = sch.leaves_list(linkage)
230
- # else:
231
- # raise ValueError(f"Unsupported sort_by option: {sort_by}")
232
-
233
- # stacked_hmm_feature.append(subset_bin[order].layers[hmm_feature_layer])
234
- # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
235
- # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
236
- # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
237
-
238
- # row_labels.extend([bin_label] * num_reads)
239
- # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
240
- # last_idx += num_reads
241
- # bin_boundaries.append(last_idx)
242
-
243
- # if stacked_hmm_feature:
244
- # hmm_matrix = np.vstack(stacked_hmm_feature)
245
- # gpc_matrix = np.vstack(stacked_gpc)
246
- # cpg_matrix = np.vstack(stacked_cpg)
247
- # any_c_matrix = np.vstack(stacked_any_c)
248
-
249
- # if hmm_matrix.size > 0:
250
- # def normalized_mean(matrix):
251
- # mean = np.nanmean(matrix, axis=0)
252
- # normalized = (mean - mean.min()) / (mean.max() - mean.min() + 1e-9)
253
- # return normalized
254
-
255
- # def methylation_fraction(matrix):
256
- # methylated = (matrix == 1).sum(axis=0)
257
- # valid = (matrix != 0).sum(axis=0)
258
- # return np.divide(methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0)
259
-
260
- # if normalize_hmm:
261
- # mean_hmm = normalized_mean(hmm_matrix)
262
- # else:
263
- # mean_hmm = np.nanmean(hmm_matrix, axis=0)
264
- # mean_gpc = methylation_fraction(gpc_matrix)
265
- # mean_cpg = methylation_fraction(cpg_matrix)
266
- # mean_any_c = methylation_fraction(any_c_matrix)
267
-
268
- # fig = plt.figure(figsize=(18, 12))
269
- # gs = gridspec.GridSpec(2, 4, height_ratios=[1, 6], hspace=0.01)
270
- # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
271
-
272
- # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(4)]
273
- # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(4)]
274
-
275
- # clean_barplot(axes_bar[0], mean_hmm, f"{hmm_feature_layer} HMM Features")
276
- # clean_barplot(axes_bar[1], mean_gpc, f"GpC Accessibility Signal")
277
- # clean_barplot(axes_bar[2], mean_cpg, f"CpG Accessibility Signal")
278
- # clean_barplot(axes_bar[3], mean_any_c, f"Any C Accessibility Signal")
279
-
280
- # hmm_labels = subset.var_names.astype(int)
281
- # hmm_label_spacing = 150
282
- # sns.heatmap(hmm_matrix, cmap=cmap_hmm, ax=axes_heat[0], xticklabels=hmm_labels[::hmm_label_spacing], yticklabels=False, cbar=False)
283
- # axes_heat[0].set_xticks(range(0, len(hmm_labels), hmm_label_spacing))
284
- # axes_heat[0].set_xticklabels(hmm_labels[::hmm_label_spacing], rotation=90, fontsize=10)
285
- # for boundary in bin_boundaries[:-1]:
286
- # axes_heat[0].axhline(y=boundary, color="black", linewidth=2)
287
-
288
- # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[1], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
289
- # axes_heat[1].set_xticks(range(0, len(gpc_labels), 5))
290
- # axes_heat[1].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
291
- # for boundary in bin_boundaries[:-1]:
292
- # axes_heat[1].axhline(y=boundary, color="black", linewidth=2)
293
-
294
- # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
295
- # axes_heat[2].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
296
- # for boundary in bin_boundaries[:-1]:
297
- # axes_heat[2].axhline(y=boundary, color="black", linewidth=2)
298
-
299
- # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[3], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
300
- # axes_heat[3].set_xticks(range(0, len(any_c_labels), 20))
301
- # axes_heat[3].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
302
- # for boundary in bin_boundaries[:-1]:
303
- # axes_heat[3].axhline(y=boundary, color="black", linewidth=2)
304
-
305
- # plt.tight_layout()
306
-
307
- # if save_path:
308
- # save_name = f"{ref} — {sample}"
309
- # os.makedirs(save_path, exist_ok=True)
310
- # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
311
- # out_file = os.path.join(save_path, f"{safe_name}.png")
312
- # plt.savefig(out_file, dpi=300)
313
- # print(f"Saved: {out_file}")
314
- # plt.close()
315
- # else:
316
- # plt.show()
317
-
318
- # print(f"Summary for {sample} - {ref}:")
319
- # for bin_label, percent in percentages.items():
320
- # print(f" - {bin_label}: {percent:.1f}%")
321
-
322
- # results.append({
323
- # "sample": sample,
324
- # "ref": ref,
325
- # "hmm_matrix": hmm_matrix,
326
- # "gpc_matrix": gpc_matrix,
327
- # "cpg_matrix": cpg_matrix,
328
- # "row_labels": row_labels,
329
- # "bin_labels": bin_labels,
330
- # "bin_boundaries": bin_boundaries,
331
- # "percentages": percentages
332
- # })
333
-
334
- # #adata.uns['clustermap_results'] = results
335
-
336
- # except Exception as e:
337
- # import traceback
338
- # traceback.print_exc()
339
- # continue
340
-
341
-
342
523
  def combined_hmm_raw_clustermap(
343
524
  adata,
344
525
  sample_col: str = "Sample_Names",
@@ -372,6 +553,8 @@ def combined_hmm_raw_clustermap(
372
553
  n_xticks_cpg: int = 8,
373
554
  n_xticks_a: int = 8,
374
555
  index_col_suffix: str | None = None,
556
+ fill_nan_strategy: str = "value",
557
+ fill_nan_value: float = -1,
375
558
  ):
376
559
  """
377
560
  Makes a multi-panel clustermap per (sample, reference):
@@ -381,7 +564,11 @@ def combined_hmm_raw_clustermap(
381
564
 
382
565
  sort_by options:
383
566
  'gpc', 'cpg', 'c', 'a', 'gpc_cpg', 'none', 'hmm', or 'obs:<col>'
567
+
568
+ NaN fill strategy is applied in-memory for clustering/plotting only.
384
569
  """
570
+ if fill_nan_strategy not in {"none", "value", "col_mean"}:
571
+ raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
385
572
 
386
573
  def pick_xticks(labels: np.ndarray, n_ticks: int):
387
574
  """Pick tick indices/labels from an array."""
@@ -500,10 +687,15 @@ def combined_hmm_raw_clustermap(
500
687
 
501
688
  # storage
502
689
  stacked_hmm = []
690
+ stacked_hmm_raw = []
503
691
  stacked_any_c = []
692
+ stacked_any_c_raw = []
504
693
  stacked_gpc = []
694
+ stacked_gpc_raw = []
505
695
  stacked_cpg = []
696
+ stacked_cpg_raw = []
506
697
  stacked_any_a = []
698
+ stacked_any_a_raw = []
507
699
 
508
700
  row_labels, bin_labels, bin_boundaries = [], [], []
509
701
  total_reads = subset.n_obs
@@ -526,29 +718,69 @@ def combined_hmm_raw_clustermap(
526
718
  order = np.argsort(sb.obs[colname].values)
527
719
 
528
720
  elif sort_by == "gpc" and gpc_sites.size:
529
- linkage = sch.linkage(sb[:, gpc_sites].layers[layer_gpc], method="ward")
721
+ gpc_matrix = _layer_to_numpy(
722
+ sb,
723
+ layer_gpc,
724
+ gpc_sites,
725
+ fill_nan_strategy=fill_nan_strategy,
726
+ fill_nan_value=fill_nan_value,
727
+ )
728
+ linkage = sch.linkage(gpc_matrix, method="ward")
530
729
  order = sch.leaves_list(linkage)
531
730
 
532
731
  elif sort_by == "cpg" and cpg_sites.size:
533
- linkage = sch.linkage(sb[:, cpg_sites].layers[layer_cpg], method="ward")
732
+ cpg_matrix = _layer_to_numpy(
733
+ sb,
734
+ layer_cpg,
735
+ cpg_sites,
736
+ fill_nan_strategy=fill_nan_strategy,
737
+ fill_nan_value=fill_nan_value,
738
+ )
739
+ linkage = sch.linkage(cpg_matrix, method="ward")
534
740
  order = sch.leaves_list(linkage)
535
741
 
536
742
  elif sort_by == "c" and any_c_sites.size:
537
- linkage = sch.linkage(sb[:, any_c_sites].layers[layer_c], method="ward")
743
+ any_c_matrix = _layer_to_numpy(
744
+ sb,
745
+ layer_c,
746
+ any_c_sites,
747
+ fill_nan_strategy=fill_nan_strategy,
748
+ fill_nan_value=fill_nan_value,
749
+ )
750
+ linkage = sch.linkage(any_c_matrix, method="ward")
538
751
  order = sch.leaves_list(linkage)
539
752
 
540
753
  elif sort_by == "a" and any_a_sites.size:
541
- linkage = sch.linkage(sb[:, any_a_sites].layers[layer_a], method="ward")
754
+ any_a_matrix = _layer_to_numpy(
755
+ sb,
756
+ layer_a,
757
+ any_a_sites,
758
+ fill_nan_strategy=fill_nan_strategy,
759
+ fill_nan_value=fill_nan_value,
760
+ )
761
+ linkage = sch.linkage(any_a_matrix, method="ward")
542
762
  order = sch.leaves_list(linkage)
543
763
 
544
764
  elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
545
- linkage = sch.linkage(sb.layers[layer_gpc], method="ward")
765
+ gpc_matrix = _layer_to_numpy(
766
+ sb,
767
+ layer_gpc,
768
+ None,
769
+ fill_nan_strategy=fill_nan_strategy,
770
+ fill_nan_value=fill_nan_value,
771
+ )
772
+ linkage = sch.linkage(gpc_matrix, method="ward")
546
773
  order = sch.leaves_list(linkage)
547
774
 
548
775
  elif sort_by == "hmm" and hmm_sites.size:
549
- linkage = sch.linkage(
550
- sb[:, hmm_sites].layers[hmm_feature_layer], method="ward"
776
+ hmm_matrix = _layer_to_numpy(
777
+ sb,
778
+ hmm_feature_layer,
779
+ hmm_sites,
780
+ fill_nan_strategy=fill_nan_strategy,
781
+ fill_nan_value=fill_nan_value,
551
782
  )
783
+ linkage = sch.linkage(hmm_matrix, method="ward")
552
784
  order = sch.leaves_list(linkage)
553
785
 
554
786
  else:
@@ -557,15 +789,100 @@ def combined_hmm_raw_clustermap(
557
789
  sb = sb[order]
558
790
 
559
791
  # ---- collect matrices ----
560
- stacked_hmm.append(sb.layers[hmm_feature_layer])
792
+ stacked_hmm.append(
793
+ _layer_to_numpy(
794
+ sb,
795
+ hmm_feature_layer,
796
+ None,
797
+ fill_nan_strategy=fill_nan_strategy,
798
+ fill_nan_value=fill_nan_value,
799
+ )
800
+ )
801
+ stacked_hmm_raw.append(
802
+ _layer_to_numpy(
803
+ sb,
804
+ hmm_feature_layer,
805
+ None,
806
+ fill_nan_strategy="none",
807
+ fill_nan_value=fill_nan_value,
808
+ )
809
+ )
561
810
  if any_c_sites.size:
562
- stacked_any_c.append(sb[:, any_c_sites].layers[layer_c])
811
+ stacked_any_c.append(
812
+ _layer_to_numpy(
813
+ sb,
814
+ layer_c,
815
+ any_c_sites,
816
+ fill_nan_strategy=fill_nan_strategy,
817
+ fill_nan_value=fill_nan_value,
818
+ )
819
+ )
820
+ stacked_any_c_raw.append(
821
+ _layer_to_numpy(
822
+ sb,
823
+ layer_c,
824
+ any_c_sites,
825
+ fill_nan_strategy="none",
826
+ fill_nan_value=fill_nan_value,
827
+ )
828
+ )
563
829
  if gpc_sites.size:
564
- stacked_gpc.append(sb[:, gpc_sites].layers[layer_gpc])
830
+ stacked_gpc.append(
831
+ _layer_to_numpy(
832
+ sb,
833
+ layer_gpc,
834
+ gpc_sites,
835
+ fill_nan_strategy=fill_nan_strategy,
836
+ fill_nan_value=fill_nan_value,
837
+ )
838
+ )
839
+ stacked_gpc_raw.append(
840
+ _layer_to_numpy(
841
+ sb,
842
+ layer_gpc,
843
+ gpc_sites,
844
+ fill_nan_strategy="none",
845
+ fill_nan_value=fill_nan_value,
846
+ )
847
+ )
565
848
  if cpg_sites.size:
566
- stacked_cpg.append(sb[:, cpg_sites].layers[layer_cpg])
849
+ stacked_cpg.append(
850
+ _layer_to_numpy(
851
+ sb,
852
+ layer_cpg,
853
+ cpg_sites,
854
+ fill_nan_strategy=fill_nan_strategy,
855
+ fill_nan_value=fill_nan_value,
856
+ )
857
+ )
858
+ stacked_cpg_raw.append(
859
+ _layer_to_numpy(
860
+ sb,
861
+ layer_cpg,
862
+ cpg_sites,
863
+ fill_nan_strategy="none",
864
+ fill_nan_value=fill_nan_value,
865
+ )
866
+ )
567
867
  if any_a_sites.size:
568
- stacked_any_a.append(sb[:, any_a_sites].layers[layer_a])
868
+ stacked_any_a.append(
869
+ _layer_to_numpy(
870
+ sb,
871
+ layer_a,
872
+ any_a_sites,
873
+ fill_nan_strategy=fill_nan_strategy,
874
+ fill_nan_value=fill_nan_value,
875
+ )
876
+ )
877
+ stacked_any_a_raw.append(
878
+ _layer_to_numpy(
879
+ sb,
880
+ layer_a,
881
+ any_a_sites,
882
+ fill_nan_strategy="none",
883
+ fill_nan_value=fill_nan_value,
884
+ )
885
+ )
569
886
 
570
887
  row_labels.extend([bin_label] * n)
571
888
  bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
@@ -574,16 +891,21 @@ def combined_hmm_raw_clustermap(
574
891
 
575
892
  # ---------------- stack ----------------
576
893
  hmm_matrix = np.vstack(stacked_hmm)
894
+ hmm_matrix_raw = np.vstack(stacked_hmm_raw)
577
895
  mean_hmm = (
578
- normalized_mean(hmm_matrix) if normalize_hmm else np.nanmean(hmm_matrix, axis=0)
896
+ normalized_mean(hmm_matrix_raw)
897
+ if normalize_hmm
898
+ else np.nanmean(hmm_matrix_raw, axis=0)
579
899
  )
900
+ hmm_plot_matrix = hmm_matrix_raw
901
+ hmm_plot_cmap = _build_hmm_feature_cmap(cmap_hmm)
580
902
 
581
903
  panels = [
582
904
  (
583
905
  f"HMM - {hmm_feature_layer}",
584
- hmm_matrix,
906
+ hmm_plot_matrix,
585
907
  hmm_labels,
586
- cmap_hmm,
908
+ hmm_plot_cmap,
587
909
  mean_hmm,
588
910
  n_xticks_hmm,
589
911
  ),
@@ -591,26 +913,58 @@ def combined_hmm_raw_clustermap(
591
913
 
592
914
  if stacked_any_c:
593
915
  m = np.vstack(stacked_any_c)
916
+ m_raw = np.vstack(stacked_any_c_raw)
594
917
  panels.append(
595
- ("C", m, any_c_labels, cmap_c, methylation_fraction(m), n_xticks_any_c)
918
+ (
919
+ "C",
920
+ m,
921
+ any_c_labels,
922
+ cmap_c,
923
+ _methylation_fraction_for_layer(m_raw, layer_c),
924
+ n_xticks_any_c,
925
+ )
596
926
  )
597
927
 
598
928
  if stacked_gpc:
599
929
  m = np.vstack(stacked_gpc)
930
+ m_raw = np.vstack(stacked_gpc_raw)
600
931
  panels.append(
601
- ("GpC", m, gpc_labels, cmap_gpc, methylation_fraction(m), n_xticks_gpc)
932
+ (
933
+ "GpC",
934
+ m,
935
+ gpc_labels,
936
+ cmap_gpc,
937
+ _methylation_fraction_for_layer(m_raw, layer_gpc),
938
+ n_xticks_gpc,
939
+ )
602
940
  )
603
941
 
604
942
  if stacked_cpg:
605
943
  m = np.vstack(stacked_cpg)
944
+ m_raw = np.vstack(stacked_cpg_raw)
606
945
  panels.append(
607
- ("CpG", m, cpg_labels, cmap_cpg, methylation_fraction(m), n_xticks_cpg)
946
+ (
947
+ "CpG",
948
+ m,
949
+ cpg_labels,
950
+ cmap_cpg,
951
+ _methylation_fraction_for_layer(m_raw, layer_cpg),
952
+ n_xticks_cpg,
953
+ )
608
954
  )
609
955
 
610
956
  if stacked_any_a:
611
957
  m = np.vstack(stacked_any_a)
958
+ m_raw = np.vstack(stacked_any_a_raw)
612
959
  panels.append(
613
- ("A", m, any_a_labels, cmap_a, methylation_fraction(m), n_xticks_a)
960
+ (
961
+ "A",
962
+ m,
963
+ any_a_labels,
964
+ cmap_a,
965
+ _methylation_fraction_for_layer(m_raw, layer_a),
966
+ n_xticks_a,
967
+ )
614
968
  )
615
969
 
616
970
  # ---------------- plotting ----------------
@@ -629,7 +983,15 @@ def combined_hmm_raw_clustermap(
629
983
  clean_barplot(axes_bar[i], mean_vec, name)
630
984
 
631
985
  # ---- heatmap ----
632
- sns.heatmap(matrix, cmap=cmap, ax=axes_heat[i], yticklabels=False, cbar=False)
986
+ heatmap_kwargs = dict(
987
+ cmap=cmap,
988
+ ax=axes_heat[i],
989
+ yticklabels=False,
990
+ cbar=False,
991
+ )
992
+ if name.startswith("HMM -"):
993
+ heatmap_kwargs.update(vmin=0.0, vmax=1.0)
994
+ sns.heatmap(matrix, **heatmap_kwargs)
633
995
 
634
996
  # ---- xticks ----
635
997
  xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
@@ -658,271 +1020,6 @@ def combined_hmm_raw_clustermap(
658
1020
  continue
659
1021
 
660
1022
 
661
- # def combined_raw_clustermap(
662
- # adata,
663
- # sample_col='Sample_Names',
664
- # reference_col='Reference_strand',
665
- # mod_target_bases=['GpC', 'CpG'],
666
- # layer_any_c="nan0_0minus1",
667
- # layer_gpc="nan0_0minus1",
668
- # layer_cpg="nan0_0minus1",
669
- # layer_a="nan0_0minus1",
670
- # cmap_any_c="coolwarm",
671
- # cmap_gpc="coolwarm",
672
- # cmap_cpg="viridis",
673
- # cmap_a="coolwarm",
674
- # min_quality=20,
675
- # min_length=200,
676
- # min_mapped_length_to_reference_length_ratio=0.8,
677
- # min_position_valid_fraction=0.5,
678
- # sample_mapping=None,
679
- # save_path=None,
680
- # sort_by="gpc", # options: 'gpc', 'cpg', 'gpc_cpg', 'none', 'any_a', or 'obs:<column>'
681
- # bins=None,
682
- # deaminase=False,
683
- # min_signal=0
684
- # ):
685
-
686
- # results = []
687
-
688
- # for ref in adata.obs[reference_col].cat.categories:
689
- # for sample in adata.obs[sample_col].cat.categories:
690
- # try:
691
- # subset = adata[
692
- # (adata.obs[reference_col] == ref) &
693
- # (adata.obs[sample_col] == sample) &
694
- # (adata.obs['read_quality'] >= min_quality) &
695
- # (adata.obs['mapped_length'] >= min_length) &
696
- # (adata.obs['mapped_length_to_reference_length_ratio'] >= min_mapped_length_to_reference_length_ratio)
697
- # ]
698
-
699
- # mask = subset.var[f"{ref}_valid_fraction"].astype(float) > float(min_position_valid_fraction)
700
- # subset = subset[:, mask]
701
-
702
- # if subset.shape[0] == 0:
703
- # print(f" No reads left after filtering for {sample} - {ref}")
704
- # continue
705
-
706
- # if bins:
707
- # print(f"Using defined bins to subset clustermap for {sample} - {ref}")
708
- # bins_temp = bins
709
- # else:
710
- # print(f"Using all reads for clustermap for {sample} - {ref}")
711
- # bins_temp = {"All": (subset.obs['Reference_strand'] == ref)}
712
-
713
- # num_any_c = 0
714
- # num_gpc = 0
715
- # num_cpg = 0
716
- # num_any_a = 0
717
-
718
- # # Get column positions (not var_names!) of site masks
719
- # if any(base in ["C", "CpG", "GpC"] for base in mod_target_bases):
720
- # any_c_sites = np.where(subset.var[f"{ref}_C_site"].values)[0]
721
- # gpc_sites = np.where(subset.var[f"{ref}_GpC_site"].values)[0]
722
- # cpg_sites = np.where(subset.var[f"{ref}_CpG_site"].values)[0]
723
- # num_any_c = len(any_c_sites)
724
- # num_gpc = len(gpc_sites)
725
- # num_cpg = len(cpg_sites)
726
- # print(f"Found {num_gpc} GpC sites at {gpc_sites} \nand {num_cpg} CpG sites at {cpg_sites}\n and {num_any_c} any_C sites at {any_c_sites} for {sample} - {ref}")
727
-
728
- # # Use var_names for x-axis tick labels
729
- # gpc_labels = subset.var_names[gpc_sites].astype(int)
730
- # cpg_labels = subset.var_names[cpg_sites].astype(int)
731
- # any_c_labels = subset.var_names[any_c_sites].astype(int)
732
- # stacked_any_c, stacked_gpc, stacked_cpg = [], [], []
733
-
734
- # if "A" in mod_target_bases:
735
- # any_a_sites = np.where(subset.var[f"{ref}_A_site"].values)[0]
736
- # num_any_a = len(any_a_sites)
737
- # print(f"Found {num_any_a} any_A sites at {any_a_sites} for {sample} - {ref}")
738
- # any_a_labels = subset.var_names[any_a_sites].astype(int)
739
- # stacked_any_a = []
740
-
741
- # row_labels, bin_labels = [], []
742
- # bin_boundaries = []
743
-
744
- # total_reads = subset.shape[0]
745
- # percentages = {}
746
- # last_idx = 0
747
-
748
- # for bin_label, bin_filter in bins_temp.items():
749
- # subset_bin = subset[bin_filter].copy()
750
- # num_reads = subset_bin.shape[0]
751
- # print(f"analyzing {num_reads} reads for {bin_label} bin in {sample} - {ref}")
752
- # percent_reads = (num_reads / total_reads) * 100 if total_reads > 0 else 0
753
- # percentages[bin_label] = percent_reads
754
-
755
- # if num_reads > 0 and num_cpg > 0 and num_gpc > 0:
756
- # # Determine sorting order
757
- # if sort_by.startswith("obs:"):
758
- # colname = sort_by.split("obs:")[1]
759
- # order = np.argsort(subset_bin.obs[colname].values)
760
- # elif sort_by == "gpc":
761
- # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
762
- # order = sch.leaves_list(linkage)
763
- # elif sort_by == "cpg":
764
- # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
765
- # order = sch.leaves_list(linkage)
766
- # elif sort_by == "any_c":
767
- # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
768
- # order = sch.leaves_list(linkage)
769
- # elif sort_by == "gpc_cpg":
770
- # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
771
- # order = sch.leaves_list(linkage)
772
- # elif sort_by == "none":
773
- # order = np.arange(num_reads)
774
- # elif sort_by == "any_a":
775
- # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
776
- # order = sch.leaves_list(linkage)
777
- # else:
778
- # raise ValueError(f"Unsupported sort_by option: {sort_by}")
779
-
780
- # stacked_any_c.append(subset_bin[order][:, any_c_sites].layers[layer_any_c])
781
- # stacked_gpc.append(subset_bin[order][:, gpc_sites].layers[layer_gpc])
782
- # stacked_cpg.append(subset_bin[order][:, cpg_sites].layers[layer_cpg])
783
-
784
- # if num_reads > 0 and num_any_a > 0:
785
- # # Determine sorting order
786
- # if sort_by.startswith("obs:"):
787
- # colname = sort_by.split("obs:")[1]
788
- # order = np.argsort(subset_bin.obs[colname].values)
789
- # elif sort_by == "gpc":
790
- # linkage = sch.linkage(subset_bin[:, gpc_sites].layers[layer_gpc], method="ward")
791
- # order = sch.leaves_list(linkage)
792
- # elif sort_by == "cpg":
793
- # linkage = sch.linkage(subset_bin[:, cpg_sites].layers[layer_cpg], method="ward")
794
- # order = sch.leaves_list(linkage)
795
- # elif sort_by == "any_c":
796
- # linkage = sch.linkage(subset_bin[:, any_c_sites].layers[layer_any_c], method="ward")
797
- # order = sch.leaves_list(linkage)
798
- # elif sort_by == "gpc_cpg":
799
- # linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
800
- # order = sch.leaves_list(linkage)
801
- # elif sort_by == "none":
802
- # order = np.arange(num_reads)
803
- # elif sort_by == "any_a":
804
- # linkage = sch.linkage(subset_bin.layers[layer_a], method="ward")
805
- # order = sch.leaves_list(linkage)
806
- # else:
807
- # raise ValueError(f"Unsupported sort_by option: {sort_by}")
808
-
809
- # stacked_any_a.append(subset_bin[order][:, any_a_sites].layers[layer_a])
810
-
811
-
812
- # row_labels.extend([bin_label] * num_reads)
813
- # bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
814
- # last_idx += num_reads
815
- # bin_boundaries.append(last_idx)
816
-
817
- # gs_dim = 0
818
-
819
- # if stacked_any_c:
820
- # any_c_matrix = np.vstack(stacked_any_c)
821
- # gpc_matrix = np.vstack(stacked_gpc)
822
- # cpg_matrix = np.vstack(stacked_cpg)
823
- # if any_c_matrix.size > 0:
824
- # mean_gpc = methylation_fraction(gpc_matrix)
825
- # mean_cpg = methylation_fraction(cpg_matrix)
826
- # mean_any_c = methylation_fraction(any_c_matrix)
827
- # gs_dim += 3
828
-
829
- # if stacked_any_a:
830
- # any_a_matrix = np.vstack(stacked_any_a)
831
- # if any_a_matrix.size > 0:
832
- # mean_any_a = methylation_fraction(any_a_matrix)
833
- # gs_dim += 1
834
-
835
-
836
- # fig = plt.figure(figsize=(18, 12))
837
- # gs = gridspec.GridSpec(2, gs_dim, height_ratios=[1, 6], hspace=0.01)
838
- # fig.suptitle(f"{sample} - {ref} - {total_reads} reads", fontsize=14, y=0.95)
839
- # axes_heat = [fig.add_subplot(gs[1, i]) for i in range(gs_dim)]
840
- # axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(gs_dim)]
841
-
842
- # current_ax = 0
843
-
844
- # if stacked_any_c:
845
- # if any_c_matrix.size > 0:
846
- # clean_barplot(axes_bar[current_ax], mean_any_c, f"any C site Modification Signal")
847
- # sns.heatmap(any_c_matrix, cmap=cmap_any_c, ax=axes_heat[current_ax], xticklabels=any_c_labels[::20], yticklabels=False, cbar=False)
848
- # axes_heat[current_ax].set_xticks(range(0, len(any_c_labels), 20))
849
- # axes_heat[current_ax].set_xticklabels(any_c_labels[::20], rotation=90, fontsize=10)
850
- # for boundary in bin_boundaries[:-1]:
851
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
852
- # current_ax +=1
853
-
854
- # clean_barplot(axes_bar[current_ax], mean_gpc, f"GpC Modification Signal")
855
- # sns.heatmap(gpc_matrix, cmap=cmap_gpc, ax=axes_heat[current_ax], xticklabels=gpc_labels[::5], yticklabels=False, cbar=False)
856
- # axes_heat[current_ax].set_xticks(range(0, len(gpc_labels), 5))
857
- # axes_heat[current_ax].set_xticklabels(gpc_labels[::5], rotation=90, fontsize=10)
858
- # for boundary in bin_boundaries[:-1]:
859
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
860
- # current_ax +=1
861
-
862
- # clean_barplot(axes_bar[current_ax], mean_cpg, f"CpG Modification Signal")
863
- # sns.heatmap(cpg_matrix, cmap=cmap_cpg, ax=axes_heat[2], xticklabels=cpg_labels, yticklabels=False, cbar=False)
864
- # axes_heat[current_ax].set_xticklabels(cpg_labels, rotation=90, fontsize=10)
865
- # for boundary in bin_boundaries[:-1]:
866
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
867
- # current_ax +=1
868
-
869
- # results.append({
870
- # "sample": sample,
871
- # "ref": ref,
872
- # "any_c_matrix": any_c_matrix,
873
- # "gpc_matrix": gpc_matrix,
874
- # "cpg_matrix": cpg_matrix,
875
- # "row_labels": row_labels,
876
- # "bin_labels": bin_labels,
877
- # "bin_boundaries": bin_boundaries,
878
- # "percentages": percentages
879
- # })
880
-
881
- # if stacked_any_a:
882
- # if any_a_matrix.size > 0:
883
- # clean_barplot(axes_bar[current_ax], mean_any_a, f"any A site Modification Signal")
884
- # sns.heatmap(any_a_matrix, cmap=cmap_a, ax=axes_heat[current_ax], xticklabels=any_a_labels[::20], yticklabels=False, cbar=False)
885
- # axes_heat[current_ax].set_xticks(range(0, len(any_a_labels), 20))
886
- # axes_heat[current_ax].set_xticklabels(any_a_labels[::20], rotation=90, fontsize=10)
887
- # for boundary in bin_boundaries[:-1]:
888
- # axes_heat[current_ax].axhline(y=boundary, color="black", linewidth=2)
889
- # current_ax +=1
890
-
891
- # results.append({
892
- # "sample": sample,
893
- # "ref": ref,
894
- # "any_a_matrix": any_a_matrix,
895
- # "row_labels": row_labels,
896
- # "bin_labels": bin_labels,
897
- # "bin_boundaries": bin_boundaries,
898
- # "percentages": percentages
899
- # })
900
-
901
- # plt.tight_layout()
902
-
903
- # if save_path:
904
- # save_name = f"{ref} — {sample}"
905
- # os.makedirs(save_path, exist_ok=True)
906
- # safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
907
- # out_file = os.path.join(save_path, f"{safe_name}.png")
908
- # plt.savefig(out_file, dpi=300)
909
- # print(f"Saved: {out_file}")
910
- # plt.close()
911
- # else:
912
- # plt.show()
913
-
914
- # print(f"Summary for {sample} - {ref}:")
915
- # for bin_label, percent in percentages.items():
916
- # print(f" - {bin_label}: {percent:.1f}%")
917
-
918
- # adata.uns['clustermap_results'] = results
919
-
920
- # except Exception as e:
921
- # import traceback
922
- # traceback.print_exc()
923
- # continue
924
-
925
-
926
1023
  def combined_raw_clustermap(
927
1024
  adata,
928
1025
  sample_col: str = "Sample_Names",
@@ -954,6 +1051,8 @@ def combined_raw_clustermap(
954
1051
  xtick_rotation: int = 90,
955
1052
  xtick_fontsize: int = 9,
956
1053
  index_col_suffix: str | None = None,
1054
+ fill_nan_strategy: str = "value",
1055
+ fill_nan_value: float = -1,
957
1056
  ):
958
1057
  """
959
1058
  Plot stacked heatmaps + per-position mean barplots for C, GpC, CpG, and optional A.
@@ -964,6 +1063,7 @@ def combined_raw_clustermap(
964
1063
  - NaNs excluded from methylation denominators
965
1064
  - var_names not forced to int
966
1065
  - fixed count of x tick labels per block (controllable)
1066
+ - optional NaN fill strategy for clustering/plotting (in-memory only)
967
1067
  - adata.uns updated once at end
968
1068
 
969
1069
  Returns
@@ -971,6 +1071,8 @@ def combined_raw_clustermap(
971
1071
  results : list[dict]
972
1072
  One entry per (sample, ref) plot with matrices + bin metadata.
973
1073
  """
1074
+ if fill_nan_strategy not in {"none", "value", "col_mean"}:
1075
+ raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
974
1076
 
975
1077
  # Helper: build a True mask if filter is inactive or column missing
976
1078
  def _mask_or_true(series_name: str, predicate):
@@ -1093,6 +1195,12 @@ def combined_raw_clustermap(
1093
1195
  any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
1094
1196
 
1095
1197
  stacked_any_c, stacked_gpc, stacked_cpg, stacked_any_a = [], [], [], []
1198
+ stacked_any_c_raw, stacked_gpc_raw, stacked_cpg_raw, stacked_any_a_raw = (
1199
+ [],
1200
+ [],
1201
+ [],
1202
+ [],
1203
+ )
1096
1204
  row_labels, bin_labels, bin_boundaries = [], [], []
1097
1205
  percentages = {}
1098
1206
  last_idx = 0
@@ -1117,31 +1225,58 @@ def combined_raw_clustermap(
1117
1225
  order = np.argsort(subset_bin.obs[colname].values)
1118
1226
 
1119
1227
  elif sort_by == "gpc" and num_gpc > 0:
1120
- linkage = sch.linkage(
1121
- subset_bin[:, gpc_sites].layers[layer_gpc], method="ward"
1228
+ gpc_matrix = _layer_to_numpy(
1229
+ subset_bin,
1230
+ layer_gpc,
1231
+ gpc_sites,
1232
+ fill_nan_strategy=fill_nan_strategy,
1233
+ fill_nan_value=fill_nan_value,
1122
1234
  )
1235
+ linkage = sch.linkage(gpc_matrix, method="ward")
1123
1236
  order = sch.leaves_list(linkage)
1124
1237
 
1125
1238
  elif sort_by == "cpg" and num_cpg > 0:
1126
- linkage = sch.linkage(
1127
- subset_bin[:, cpg_sites].layers[layer_cpg], method="ward"
1239
+ cpg_matrix = _layer_to_numpy(
1240
+ subset_bin,
1241
+ layer_cpg,
1242
+ cpg_sites,
1243
+ fill_nan_strategy=fill_nan_strategy,
1244
+ fill_nan_value=fill_nan_value,
1128
1245
  )
1246
+ linkage = sch.linkage(cpg_matrix, method="ward")
1129
1247
  order = sch.leaves_list(linkage)
1130
1248
 
1131
1249
  elif sort_by == "c" and num_any_c > 0:
1132
- linkage = sch.linkage(
1133
- subset_bin[:, any_c_sites].layers[layer_c], method="ward"
1250
+ any_c_matrix = _layer_to_numpy(
1251
+ subset_bin,
1252
+ layer_c,
1253
+ any_c_sites,
1254
+ fill_nan_strategy=fill_nan_strategy,
1255
+ fill_nan_value=fill_nan_value,
1134
1256
  )
1257
+ linkage = sch.linkage(any_c_matrix, method="ward")
1135
1258
  order = sch.leaves_list(linkage)
1136
1259
 
1137
1260
  elif sort_by == "gpc_cpg":
1138
- linkage = sch.linkage(subset_bin.layers[layer_gpc], method="ward")
1261
+ gpc_matrix = _layer_to_numpy(
1262
+ subset_bin,
1263
+ layer_gpc,
1264
+ None,
1265
+ fill_nan_strategy=fill_nan_strategy,
1266
+ fill_nan_value=fill_nan_value,
1267
+ )
1268
+ linkage = sch.linkage(gpc_matrix, method="ward")
1139
1269
  order = sch.leaves_list(linkage)
1140
1270
 
1141
1271
  elif sort_by == "a" and num_any_a > 0:
1142
- linkage = sch.linkage(
1143
- subset_bin[:, any_a_sites].layers[layer_a], method="ward"
1272
+ any_a_matrix = _layer_to_numpy(
1273
+ subset_bin,
1274
+ layer_a,
1275
+ any_a_sites,
1276
+ fill_nan_strategy=fill_nan_strategy,
1277
+ fill_nan_value=fill_nan_value,
1144
1278
  )
1279
+ linkage = sch.linkage(any_a_matrix, method="ward")
1145
1280
  order = sch.leaves_list(linkage)
1146
1281
 
1147
1282
  elif sort_by == "none":
@@ -1154,13 +1289,81 @@ def combined_raw_clustermap(
1154
1289
 
1155
1290
  # stack consistently
1156
1291
  if include_any_c and num_any_c > 0:
1157
- stacked_any_c.append(subset_bin[:, any_c_sites].layers[layer_c])
1292
+ stacked_any_c.append(
1293
+ _layer_to_numpy(
1294
+ subset_bin,
1295
+ layer_c,
1296
+ any_c_sites,
1297
+ fill_nan_strategy=fill_nan_strategy,
1298
+ fill_nan_value=fill_nan_value,
1299
+ )
1300
+ )
1301
+ stacked_any_c_raw.append(
1302
+ _layer_to_numpy(
1303
+ subset_bin,
1304
+ layer_c,
1305
+ any_c_sites,
1306
+ fill_nan_strategy="none",
1307
+ fill_nan_value=fill_nan_value,
1308
+ )
1309
+ )
1158
1310
  if include_any_c and num_gpc > 0:
1159
- stacked_gpc.append(subset_bin[:, gpc_sites].layers[layer_gpc])
1311
+ stacked_gpc.append(
1312
+ _layer_to_numpy(
1313
+ subset_bin,
1314
+ layer_gpc,
1315
+ gpc_sites,
1316
+ fill_nan_strategy=fill_nan_strategy,
1317
+ fill_nan_value=fill_nan_value,
1318
+ )
1319
+ )
1320
+ stacked_gpc_raw.append(
1321
+ _layer_to_numpy(
1322
+ subset_bin,
1323
+ layer_gpc,
1324
+ gpc_sites,
1325
+ fill_nan_strategy="none",
1326
+ fill_nan_value=fill_nan_value,
1327
+ )
1328
+ )
1160
1329
  if include_any_c and num_cpg > 0:
1161
- stacked_cpg.append(subset_bin[:, cpg_sites].layers[layer_cpg])
1330
+ stacked_cpg.append(
1331
+ _layer_to_numpy(
1332
+ subset_bin,
1333
+ layer_cpg,
1334
+ cpg_sites,
1335
+ fill_nan_strategy=fill_nan_strategy,
1336
+ fill_nan_value=fill_nan_value,
1337
+ )
1338
+ )
1339
+ stacked_cpg_raw.append(
1340
+ _layer_to_numpy(
1341
+ subset_bin,
1342
+ layer_cpg,
1343
+ cpg_sites,
1344
+ fill_nan_strategy="none",
1345
+ fill_nan_value=fill_nan_value,
1346
+ )
1347
+ )
1162
1348
  if include_any_a and num_any_a > 0:
1163
- stacked_any_a.append(subset_bin[:, any_a_sites].layers[layer_a])
1349
+ stacked_any_a.append(
1350
+ _layer_to_numpy(
1351
+ subset_bin,
1352
+ layer_a,
1353
+ any_a_sites,
1354
+ fill_nan_strategy=fill_nan_strategy,
1355
+ fill_nan_value=fill_nan_value,
1356
+ )
1357
+ )
1358
+ stacked_any_a_raw.append(
1359
+ _layer_to_numpy(
1360
+ subset_bin,
1361
+ layer_a,
1362
+ any_a_sites,
1363
+ fill_nan_strategy="none",
1364
+ fill_nan_value=fill_nan_value,
1365
+ )
1366
+ )
1164
1367
 
1165
1368
  row_labels.extend([bin_label] * num_reads)
1166
1369
  bin_labels.append(f"{bin_label}: {num_reads} reads ({percent_reads:.1f}%)")
@@ -1174,12 +1377,31 @@ def combined_raw_clustermap(
1174
1377
 
1175
1378
  if include_any_c and stacked_any_c:
1176
1379
  any_c_matrix = np.vstack(stacked_any_c)
1380
+ any_c_matrix_raw = np.vstack(stacked_any_c_raw)
1177
1381
  gpc_matrix = np.vstack(stacked_gpc) if stacked_gpc else np.empty((0, 0))
1382
+ gpc_matrix_raw = (
1383
+ np.vstack(stacked_gpc_raw) if stacked_gpc_raw else np.empty((0, 0))
1384
+ )
1178
1385
  cpg_matrix = np.vstack(stacked_cpg) if stacked_cpg else np.empty((0, 0))
1386
+ cpg_matrix_raw = (
1387
+ np.vstack(stacked_cpg_raw) if stacked_cpg_raw else np.empty((0, 0))
1388
+ )
1179
1389
 
1180
- mean_any_c = methylation_fraction(any_c_matrix) if any_c_matrix.size else None
1181
- mean_gpc = methylation_fraction(gpc_matrix) if gpc_matrix.size else None
1182
- mean_cpg = methylation_fraction(cpg_matrix) if cpg_matrix.size else None
1390
+ mean_any_c = (
1391
+ _methylation_fraction_for_layer(any_c_matrix_raw, layer_c)
1392
+ if any_c_matrix_raw.size
1393
+ else None
1394
+ )
1395
+ mean_gpc = (
1396
+ _methylation_fraction_for_layer(gpc_matrix_raw, layer_gpc)
1397
+ if gpc_matrix_raw.size
1398
+ else None
1399
+ )
1400
+ mean_cpg = (
1401
+ _methylation_fraction_for_layer(cpg_matrix_raw, layer_cpg)
1402
+ if cpg_matrix_raw.size
1403
+ else None
1404
+ )
1183
1405
 
1184
1406
  if any_c_matrix.size:
1185
1407
  blocks.append(
@@ -1220,7 +1442,12 @@ def combined_raw_clustermap(
1220
1442
 
1221
1443
  if include_any_a and stacked_any_a:
1222
1444
  any_a_matrix = np.vstack(stacked_any_a)
1223
- mean_any_a = methylation_fraction(any_a_matrix) if any_a_matrix.size else None
1445
+ any_a_matrix_raw = np.vstack(stacked_any_a_raw)
1446
+ mean_any_a = (
1447
+ _methylation_fraction_for_layer(any_a_matrix_raw, layer_a)
1448
+ if any_a_matrix_raw.size
1449
+ else None
1450
+ )
1224
1451
  if any_a_matrix.size:
1225
1452
  blocks.append(
1226
1453
  dict(
@@ -1320,112 +1547,1530 @@ def combined_raw_clustermap(
1320
1547
  return results
1321
1548
 
1322
1549
 
1323
- def plot_hmm_layers_rolling_by_sample_ref(
1550
+ def combined_hmm_length_clustermap(
1324
1551
  adata,
1325
- layers: Optional[Sequence[str]] = None,
1326
- sample_col: str = "Barcode",
1327
- ref_col: str = "Reference_strand",
1328
- samples: Optional[Sequence[str]] = None,
1329
- references: Optional[Sequence[str]] = None,
1330
- window: int = 51,
1331
- min_periods: int = 1,
1332
- center: bool = True,
1333
- rows_per_page: int = 6,
1334
- figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
1335
- dpi: int = 160,
1336
- output_dir: Optional[str] = None,
1337
- save: bool = True,
1338
- show_raw: bool = False,
1339
- cmap: str = "tab20",
1340
- use_var_coords: bool = True,
1552
+ sample_col: str = "Sample_Names",
1553
+ reference_col: str = "Reference_strand",
1554
+ length_layer: str = "hmm_combined_lengths",
1555
+ layer_gpc: str = "nan0_0minus1",
1556
+ layer_cpg: str = "nan0_0minus1",
1557
+ layer_c: str = "nan0_0minus1",
1558
+ layer_a: str = "nan0_0minus1",
1559
+ cmap_lengths: Any = "Greens",
1560
+ cmap_gpc: str = "coolwarm",
1561
+ cmap_cpg: str = "viridis",
1562
+ cmap_c: str = "coolwarm",
1563
+ cmap_a: str = "coolwarm",
1564
+ min_quality: int = 20,
1565
+ min_length: int = 200,
1566
+ min_mapped_length_to_reference_length_ratio: float = 0.8,
1567
+ min_position_valid_fraction: float = 0.5,
1568
+ demux_types: Sequence[str] = ("single", "double", "already"),
1569
+ sample_mapping: Optional[Mapping[str, str]] = None,
1570
+ save_path: str | Path | None = None,
1571
+ sort_by: str = "gpc",
1572
+ bins: Optional[Dict[str, Any]] = None,
1573
+ deaminase: bool = False,
1574
+ min_signal: float = 0.0,
1575
+ n_xticks_lengths: int = 10,
1576
+ n_xticks_any_c: int = 8,
1577
+ n_xticks_gpc: int = 8,
1578
+ n_xticks_cpg: int = 8,
1579
+ n_xticks_a: int = 8,
1580
+ index_col_suffix: str | None = None,
1581
+ fill_nan_strategy: str = "value",
1582
+ fill_nan_value: float = -1,
1583
+ length_feature_ranges: Optional[Sequence[Tuple[int, int, Any]]] = None,
1341
1584
  ):
1342
1585
  """
1343
- For each sample (row) and reference (col) plot the rolling average of the
1344
- positional mean (mean across reads) for each layer listed.
1345
-
1346
- Parameters
1347
- ----------
1348
- adata : AnnData
1349
- Input annotated data (expects obs columns sample_col and ref_col).
1350
- layers : list[str] | None
1351
- Which adata.layers to plot. If None, attempts to autodetect layers whose
1352
- matrices look like "HMM" outputs (else will error). If None and layers
1353
- cannot be found, user must pass a list.
1354
- sample_col, ref_col : str
1355
- obs columns used to group rows.
1356
- samples, references : optional lists
1357
- explicit ordering of samples / references. If None, categories in adata.obs are used.
1358
- window : int
1359
- rolling window size (odd recommended). If window <= 1, no smoothing applied.
1360
- min_periods : int
1361
- min periods param for pd.Series.rolling.
1362
- center : bool
1363
- center the rolling window.
1364
- rows_per_page : int
1365
- paginate rows per page into multiple figures if needed.
1366
- figsize_per_cell : (w,h)
1367
- per-subplot size in inches.
1368
- dpi : int
1369
- figure dpi when saving.
1370
- output_dir : str | None
1371
- directory to save pages; created if necessary. If None and save=True, uses cwd.
1372
- save : bool
1373
- whether to save PNG files.
1374
- show_raw : bool
1375
- draw unsmoothed mean as faint line under smoothed curve.
1376
- cmap : str
1377
- matplotlib colormap for layer lines.
1378
- use_var_coords : bool
1379
- if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.
1586
+ Plot clustermaps for length-encoded HMM feature layers with optional subclass colors.
1380
1587
 
1381
- Returns
1382
- -------
1383
- saved_files : list[str]
1384
- list of saved filenames (may be empty if save=False).
1588
+ Length-based feature ranges map integer lengths into subclass colors for accessible
1589
+ and footprint layers. Raw methylation panels are included when available.
1385
1590
  """
1591
+ if fill_nan_strategy not in {"none", "value", "col_mean"}:
1592
+ raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
1386
1593
 
1387
- # --- basic checks / defaults ---
1388
- if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
1389
- raise ValueError(
1390
- f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
1391
- )
1594
+ def pick_xticks(labels: np.ndarray, n_ticks: int):
1595
+ """Pick tick indices/labels from an array."""
1596
+ if labels.size == 0:
1597
+ return [], []
1598
+ idx = np.linspace(0, len(labels) - 1, n_ticks).round().astype(int)
1599
+ idx = np.unique(idx)
1600
+ return idx.tolist(), labels[idx].tolist()
1392
1601
 
1393
- # canonicalize samples / refs
1394
- if samples is None:
1395
- sseries = adata.obs[sample_col]
1396
- if not pd.api.types.is_categorical_dtype(sseries):
1397
- sseries = sseries.astype("category")
1398
- samples_all = list(sseries.cat.categories)
1399
- else:
1400
- samples_all = list(samples)
1602
+ def _mask_or_true(series_name: str, predicate):
1603
+ """Return a mask from predicate or an all-True mask."""
1604
+ if series_name not in adata.obs:
1605
+ return pd.Series(True, index=adata.obs.index)
1606
+ s = adata.obs[series_name]
1607
+ try:
1608
+ return predicate(s)
1609
+ except Exception:
1610
+ return pd.Series(True, index=adata.obs.index)
1401
1611
 
1402
- if references is None:
1403
- rseries = adata.obs[ref_col]
1404
- if not pd.api.types.is_categorical_dtype(rseries):
1405
- rseries = rseries.astype("category")
1406
- refs_all = list(rseries.cat.categories)
1407
- else:
1408
- refs_all = list(references)
1612
+ results = []
1613
+ signal_type = "deamination" if deaminase else "methylation"
1614
+ feature_ranges = tuple(length_feature_ranges or ())
1409
1615
 
1410
- # choose layers: if not provided, try a sensible default: all layers
1411
- if layers is None:
1412
- layers = list(adata.layers.keys())
1413
- if len(layers) == 0:
1414
- raise ValueError(
1415
- "No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
1616
+ for ref in adata.obs[reference_col].cat.categories:
1617
+ for sample in adata.obs[sample_col].cat.categories:
1618
+ display_sample = sample_mapping.get(sample, sample) if sample_mapping else sample
1619
+ qmask = _mask_or_true(
1620
+ "read_quality",
1621
+ (lambda s: s >= float(min_quality))
1622
+ if (min_quality is not None)
1623
+ else (lambda s: pd.Series(True, index=s.index)),
1416
1624
  )
1417
- layers = list(layers)
1418
-
1419
- # x coordinates (positions)
1420
- try:
1421
- if use_var_coords:
1422
- x_coords = np.array([int(v) for v in adata.var_names])
1423
- else:
1424
- raise Exception("user disabled var coords")
1425
- except Exception:
1426
- # fallback to 0..n_vars-1
1427
- x_coords = np.arange(adata.shape[1], dtype=int)
1428
-
1625
+ lm_mask = _mask_or_true(
1626
+ "mapped_length",
1627
+ (lambda s: s >= float(min_length))
1628
+ if (min_length is not None)
1629
+ else (lambda s: pd.Series(True, index=s.index)),
1630
+ )
1631
+ lrr_mask = _mask_or_true(
1632
+ "mapped_length_to_reference_length_ratio",
1633
+ (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
1634
+ if (min_mapped_length_to_reference_length_ratio is not None)
1635
+ else (lambda s: pd.Series(True, index=s.index)),
1636
+ )
1637
+
1638
+ demux_mask = _mask_or_true(
1639
+ "demux_type",
1640
+ (lambda s: s.astype("string").isin(list(demux_types)))
1641
+ if (demux_types is not None)
1642
+ else (lambda s: pd.Series(True, index=s.index)),
1643
+ )
1644
+
1645
+ ref_mask = adata.obs[reference_col] == ref
1646
+ sample_mask = adata.obs[sample_col] == sample
1647
+
1648
+ row_mask = ref_mask & sample_mask & qmask & lm_mask & lrr_mask & demux_mask
1649
+
1650
+ if not bool(row_mask.any()):
1651
+ print(
1652
+ f"No reads for {display_sample} - {ref} after read quality and length filtering"
1653
+ )
1654
+ continue
1655
+
1656
+ try:
1657
+ subset = adata[row_mask, :].copy()
1658
+
1659
+ if min_position_valid_fraction is not None:
1660
+ valid_key = f"{ref}_valid_fraction"
1661
+ if valid_key in subset.var:
1662
+ v = pd.to_numeric(subset.var[valid_key], errors="coerce").to_numpy()
1663
+ col_mask = np.asarray(v > float(min_position_valid_fraction), dtype=bool)
1664
+ if col_mask.any():
1665
+ subset = subset[:, col_mask].copy()
1666
+ else:
1667
+ print(
1668
+ f"No positions left after valid_fraction filter for {display_sample} - {ref}"
1669
+ )
1670
+ continue
1671
+
1672
+ if subset.shape[0] == 0:
1673
+ print(f"No reads left after filtering for {display_sample} - {ref}")
1674
+ continue
1675
+
1676
+ if bins is None:
1677
+ bins_temp = {"All": np.ones(subset.n_obs, dtype=bool)}
1678
+ else:
1679
+ bins_temp = bins
1680
+
1681
+ def _sites(*keys):
1682
+ """Return indices for the first matching site key."""
1683
+ for k in keys:
1684
+ if k in subset.var:
1685
+ return np.where(subset.var[k].values)[0]
1686
+ return np.array([], dtype=int)
1687
+
1688
+ gpc_sites = _sites(f"{ref}_GpC_site")
1689
+ cpg_sites = _sites(f"{ref}_CpG_site")
1690
+ any_c_sites = _sites(f"{ref}_any_C_site", f"{ref}_C_site")
1691
+ any_a_sites = _sites(f"{ref}_A_site", f"{ref}_any_A_site")
1692
+
1693
+ length_sites = np.arange(subset.n_vars, dtype=int)
1694
+ length_labels = _select_labels(subset, length_sites, ref, index_col_suffix)
1695
+ gpc_labels = _select_labels(subset, gpc_sites, ref, index_col_suffix)
1696
+ cpg_labels = _select_labels(subset, cpg_sites, ref, index_col_suffix)
1697
+ any_c_labels = _select_labels(subset, any_c_sites, ref, index_col_suffix)
1698
+ any_a_labels = _select_labels(subset, any_a_sites, ref, index_col_suffix)
1699
+
1700
+ stacked_lengths = []
1701
+ stacked_lengths_raw = []
1702
+ stacked_any_c = []
1703
+ stacked_any_c_raw = []
1704
+ stacked_gpc = []
1705
+ stacked_gpc_raw = []
1706
+ stacked_cpg = []
1707
+ stacked_cpg_raw = []
1708
+ stacked_any_a = []
1709
+ stacked_any_a_raw = []
1710
+
1711
+ row_labels, bin_labels, bin_boundaries = [], [], []
1712
+ total_reads = subset.n_obs
1713
+ percentages = {}
1714
+ last_idx = 0
1715
+
1716
+ for bin_label, bin_filter in bins_temp.items():
1717
+ sb = subset[bin_filter].copy()
1718
+ n = sb.n_obs
1719
+ if n == 0:
1720
+ continue
1721
+
1722
+ pct = (n / total_reads) * 100 if total_reads else 0
1723
+ percentages[bin_label] = pct
1724
+
1725
+ if sort_by.startswith("obs:"):
1726
+ colname = sort_by.split("obs:")[1]
1727
+ order = np.argsort(sb.obs[colname].values)
1728
+ elif sort_by == "gpc" and gpc_sites.size:
1729
+ gpc_matrix = _layer_to_numpy(
1730
+ sb,
1731
+ layer_gpc,
1732
+ gpc_sites,
1733
+ fill_nan_strategy=fill_nan_strategy,
1734
+ fill_nan_value=fill_nan_value,
1735
+ )
1736
+ linkage = sch.linkage(gpc_matrix, method="ward")
1737
+ order = sch.leaves_list(linkage)
1738
+ elif sort_by == "cpg" and cpg_sites.size:
1739
+ cpg_matrix = _layer_to_numpy(
1740
+ sb,
1741
+ layer_cpg,
1742
+ cpg_sites,
1743
+ fill_nan_strategy=fill_nan_strategy,
1744
+ fill_nan_value=fill_nan_value,
1745
+ )
1746
+ linkage = sch.linkage(cpg_matrix, method="ward")
1747
+ order = sch.leaves_list(linkage)
1748
+ elif sort_by == "c" and any_c_sites.size:
1749
+ any_c_matrix = _layer_to_numpy(
1750
+ sb,
1751
+ layer_c,
1752
+ any_c_sites,
1753
+ fill_nan_strategy=fill_nan_strategy,
1754
+ fill_nan_value=fill_nan_value,
1755
+ )
1756
+ linkage = sch.linkage(any_c_matrix, method="ward")
1757
+ order = sch.leaves_list(linkage)
1758
+ elif sort_by == "a" and any_a_sites.size:
1759
+ any_a_matrix = _layer_to_numpy(
1760
+ sb,
1761
+ layer_a,
1762
+ any_a_sites,
1763
+ fill_nan_strategy=fill_nan_strategy,
1764
+ fill_nan_value=fill_nan_value,
1765
+ )
1766
+ linkage = sch.linkage(any_a_matrix, method="ward")
1767
+ order = sch.leaves_list(linkage)
1768
+ elif sort_by == "gpc_cpg" and gpc_sites.size and cpg_sites.size:
1769
+ gpc_matrix = _layer_to_numpy(
1770
+ sb,
1771
+ layer_gpc,
1772
+ None,
1773
+ fill_nan_strategy=fill_nan_strategy,
1774
+ fill_nan_value=fill_nan_value,
1775
+ )
1776
+ linkage = sch.linkage(gpc_matrix, method="ward")
1777
+ order = sch.leaves_list(linkage)
1778
+ elif sort_by == "hmm" and length_sites.size:
1779
+ length_matrix = _layer_to_numpy(
1780
+ sb,
1781
+ length_layer,
1782
+ length_sites,
1783
+ fill_nan_strategy=fill_nan_strategy,
1784
+ fill_nan_value=fill_nan_value,
1785
+ )
1786
+ linkage = sch.linkage(length_matrix, method="ward")
1787
+ order = sch.leaves_list(linkage)
1788
+ else:
1789
+ order = np.arange(n)
1790
+
1791
+ sb = sb[order]
1792
+
1793
+ stacked_lengths.append(
1794
+ _layer_to_numpy(
1795
+ sb,
1796
+ length_layer,
1797
+ None,
1798
+ fill_nan_strategy=fill_nan_strategy,
1799
+ fill_nan_value=fill_nan_value,
1800
+ )
1801
+ )
1802
+ stacked_lengths_raw.append(
1803
+ _layer_to_numpy(
1804
+ sb,
1805
+ length_layer,
1806
+ None,
1807
+ fill_nan_strategy="none",
1808
+ fill_nan_value=fill_nan_value,
1809
+ )
1810
+ )
1811
+ if any_c_sites.size:
1812
+ stacked_any_c.append(
1813
+ _layer_to_numpy(
1814
+ sb,
1815
+ layer_c,
1816
+ any_c_sites,
1817
+ fill_nan_strategy=fill_nan_strategy,
1818
+ fill_nan_value=fill_nan_value,
1819
+ )
1820
+ )
1821
+ stacked_any_c_raw.append(
1822
+ _layer_to_numpy(
1823
+ sb,
1824
+ layer_c,
1825
+ any_c_sites,
1826
+ fill_nan_strategy="none",
1827
+ fill_nan_value=fill_nan_value,
1828
+ )
1829
+ )
1830
+ if gpc_sites.size:
1831
+ stacked_gpc.append(
1832
+ _layer_to_numpy(
1833
+ sb,
1834
+ layer_gpc,
1835
+ gpc_sites,
1836
+ fill_nan_strategy=fill_nan_strategy,
1837
+ fill_nan_value=fill_nan_value,
1838
+ )
1839
+ )
1840
+ stacked_gpc_raw.append(
1841
+ _layer_to_numpy(
1842
+ sb,
1843
+ layer_gpc,
1844
+ gpc_sites,
1845
+ fill_nan_strategy="none",
1846
+ fill_nan_value=fill_nan_value,
1847
+ )
1848
+ )
1849
+ if cpg_sites.size:
1850
+ stacked_cpg.append(
1851
+ _layer_to_numpy(
1852
+ sb,
1853
+ layer_cpg,
1854
+ cpg_sites,
1855
+ fill_nan_strategy=fill_nan_strategy,
1856
+ fill_nan_value=fill_nan_value,
1857
+ )
1858
+ )
1859
+ stacked_cpg_raw.append(
1860
+ _layer_to_numpy(
1861
+ sb,
1862
+ layer_cpg,
1863
+ cpg_sites,
1864
+ fill_nan_strategy="none",
1865
+ fill_nan_value=fill_nan_value,
1866
+ )
1867
+ )
1868
+ if any_a_sites.size:
1869
+ stacked_any_a.append(
1870
+ _layer_to_numpy(
1871
+ sb,
1872
+ layer_a,
1873
+ any_a_sites,
1874
+ fill_nan_strategy=fill_nan_strategy,
1875
+ fill_nan_value=fill_nan_value,
1876
+ )
1877
+ )
1878
+ stacked_any_a_raw.append(
1879
+ _layer_to_numpy(
1880
+ sb,
1881
+ layer_a,
1882
+ any_a_sites,
1883
+ fill_nan_strategy="none",
1884
+ fill_nan_value=fill_nan_value,
1885
+ )
1886
+ )
1887
+
1888
+ row_labels.extend([bin_label] * n)
1889
+ bin_labels.append(f"{bin_label}: {n} reads ({pct:.1f}%)")
1890
+ last_idx += n
1891
+ bin_boundaries.append(last_idx)
1892
+
1893
+ length_matrix = np.vstack(stacked_lengths)
1894
+ length_matrix_raw = np.vstack(stacked_lengths_raw)
1895
+ capped_lengths = np.where(length_matrix_raw > 1, 1.0, length_matrix_raw)
1896
+ mean_lengths = np.nanmean(capped_lengths, axis=0)
1897
+ length_plot_matrix = length_matrix_raw
1898
+ length_plot_cmap = cmap_lengths
1899
+ length_plot_norm = None
1900
+
1901
+ if feature_ranges:
1902
+ length_plot_matrix = _map_length_matrix_to_subclasses(
1903
+ length_matrix_raw, feature_ranges
1904
+ )
1905
+ length_plot_cmap, length_plot_norm = _build_length_feature_cmap(feature_ranges)
1906
+
1907
+ panels = [
1908
+ (
1909
+ f"HMM lengths - {length_layer}",
1910
+ length_plot_matrix,
1911
+ length_labels,
1912
+ length_plot_cmap,
1913
+ mean_lengths,
1914
+ n_xticks_lengths,
1915
+ length_plot_norm,
1916
+ ),
1917
+ ]
1918
+
1919
+ if stacked_any_c:
1920
+ m = np.vstack(stacked_any_c)
1921
+ m_raw = np.vstack(stacked_any_c_raw)
1922
+ panels.append(
1923
+ (
1924
+ "C",
1925
+ m,
1926
+ any_c_labels,
1927
+ cmap_c,
1928
+ _methylation_fraction_for_layer(m_raw, layer_c),
1929
+ n_xticks_any_c,
1930
+ None,
1931
+ )
1932
+ )
1933
+
1934
+ if stacked_gpc:
1935
+ m = np.vstack(stacked_gpc)
1936
+ m_raw = np.vstack(stacked_gpc_raw)
1937
+ panels.append(
1938
+ (
1939
+ "GpC",
1940
+ m,
1941
+ gpc_labels,
1942
+ cmap_gpc,
1943
+ _methylation_fraction_for_layer(m_raw, layer_gpc),
1944
+ n_xticks_gpc,
1945
+ None,
1946
+ )
1947
+ )
1948
+
1949
+ if stacked_cpg:
1950
+ m = np.vstack(stacked_cpg)
1951
+ m_raw = np.vstack(stacked_cpg_raw)
1952
+ panels.append(
1953
+ (
1954
+ "CpG",
1955
+ m,
1956
+ cpg_labels,
1957
+ cmap_cpg,
1958
+ _methylation_fraction_for_layer(m_raw, layer_cpg),
1959
+ n_xticks_cpg,
1960
+ None,
1961
+ )
1962
+ )
1963
+
1964
+ if stacked_any_a:
1965
+ m = np.vstack(stacked_any_a)
1966
+ m_raw = np.vstack(stacked_any_a_raw)
1967
+ panels.append(
1968
+ (
1969
+ "A",
1970
+ m,
1971
+ any_a_labels,
1972
+ cmap_a,
1973
+ _methylation_fraction_for_layer(m_raw, layer_a),
1974
+ n_xticks_a,
1975
+ None,
1976
+ )
1977
+ )
1978
+
1979
+ n_panels = len(panels)
1980
+ fig = plt.figure(figsize=(4.5 * n_panels, 10))
1981
+ gs = gridspec.GridSpec(2, n_panels, height_ratios=[1, 6], hspace=0.01)
1982
+ fig.suptitle(
1983
+ f"{sample} — {ref} — {total_reads} reads ({signal_type})", fontsize=14, y=0.98
1984
+ )
1985
+
1986
+ axes_heat = [fig.add_subplot(gs[1, i]) for i in range(n_panels)]
1987
+ axes_bar = [fig.add_subplot(gs[0, i], sharex=axes_heat[i]) for i in range(n_panels)]
1988
+
1989
+ for i, (name, matrix, labels, cmap, mean_vec, n_ticks, norm) in enumerate(panels):
1990
+ clean_barplot(axes_bar[i], mean_vec, name)
1991
+
1992
+ heatmap_kwargs = dict(
1993
+ cmap=cmap,
1994
+ ax=axes_heat[i],
1995
+ yticklabels=False,
1996
+ cbar=False,
1997
+ )
1998
+ if norm is not None:
1999
+ heatmap_kwargs["norm"] = norm
2000
+ sns.heatmap(matrix, **heatmap_kwargs)
2001
+
2002
+ xtick_pos, xtick_labels = pick_xticks(np.asarray(labels), n_ticks)
2003
+ axes_heat[i].set_xticks(xtick_pos)
2004
+ axes_heat[i].set_xticklabels(xtick_labels, rotation=90, fontsize=8)
2005
+
2006
+ for boundary in bin_boundaries[:-1]:
2007
+ axes_heat[i].axhline(y=boundary, color="black", linewidth=1.2)
2008
+
2009
+ plt.tight_layout()
2010
+
2011
+ if save_path:
2012
+ save_path = Path(save_path)
2013
+ save_path.mkdir(parents=True, exist_ok=True)
2014
+ safe_name = f"{ref}__{sample}".replace("/", "_")
2015
+ out_file = save_path / f"{safe_name}.png"
2016
+ plt.savefig(out_file, dpi=300)
2017
+ plt.close(fig)
2018
+ else:
2019
+ plt.show()
2020
+
2021
+ results.append((sample, ref))
2022
+
2023
+ except Exception:
2024
+ import traceback
2025
+
2026
+ traceback.print_exc()
2027
+ print(f"Failed {sample} - {ref} - {length_layer}")
2028
+
2029
+ return results
2030
+
2031
+
2032
+ def make_row_colors(meta: pd.DataFrame) -> pd.DataFrame:
2033
+ """
2034
+ Convert metadata columns to RGB colors without invoking pandas Categorical.map
2035
+ (MultiIndex-safe, category-safe).
2036
+ """
2037
+ row_colors = pd.DataFrame(index=meta.index)
2038
+
2039
+ for col in meta.columns:
2040
+ # Force plain python objects to avoid ExtensionArray/Categorical behavior
2041
+ s = meta[col].astype("object")
2042
+
2043
+ def _to_label(x):
2044
+ if x is None:
2045
+ return "NA"
2046
+ if isinstance(x, float) and np.isnan(x):
2047
+ return "NA"
2048
+ # If a MultiIndex object is stored in a cell (rare), bucket it
2049
+ if isinstance(x, pd.MultiIndex):
2050
+ return "MultiIndex"
2051
+ # Tuples are common when MultiIndex-ish things get stored as values
2052
+ if isinstance(x, tuple):
2053
+ return "|".join(map(str, x))
2054
+ return str(x)
2055
+
2056
+ labels = np.array([_to_label(x) for x in s.to_numpy()], dtype=object)
2057
+ uniq = pd.unique(labels)
2058
+ palette = dict(zip(uniq, sns.color_palette(n_colors=len(uniq))))
2059
+
2060
+ # Map via python loop -> no pandas map machinery
2061
+ colors = [palette.get(lbl, (0.7, 0.7, 0.7)) for lbl in labels]
2062
+ row_colors[col] = colors
2063
+
2064
+ return row_colors
2065
+
2066
+
2067
+ def plot_rolling_nn_and_layer(
2068
+ subset,
2069
+ obsm_key: str = "rolling_nn_dist",
2070
+ layer_key: str = "nan0_0minus1",
2071
+ meta_cols=("Reference_strand", "Sample"),
2072
+ col_cluster: bool = False,
2073
+ fill_nn_with_colmax: bool = True,
2074
+ fill_layer_value: float = 0.0,
2075
+ drop_all_nan_windows: bool = True,
2076
+ max_nan_fraction: float | None = None,
2077
+ var_valid_fraction_col: str | None = None,
2078
+ var_nan_fraction_col: str | None = None,
2079
+ figsize=(14, 10),
2080
+ right_panel_var_mask=None, # optional boolean mask over subset.var to reduce width
2081
+ robust=True,
2082
+ title: str | None = None,
2083
+ xtick_step: int | None = None,
2084
+ xtick_rotation: int = 90,
2085
+ xtick_fontsize: int = 8,
2086
+ save_name=None,
2087
+ ):
2088
+ """
2089
+ 1) Cluster rows by subset.obsm[obsm_key] (rolling NN distances)
2090
+ 2) Plot two heatmaps side-by-side in the SAME row order, with mean barplots above:
2091
+ - left: rolling NN distance matrix
2092
+ - right: subset.layers[layer_key] matrix
2093
+
2094
+ Handles categorical/MultiIndex issues in metadata coloring.
2095
+
2096
+ Args:
2097
+ subset: AnnData subset with rolling NN distances stored in ``obsm``.
2098
+ obsm_key: Key in ``subset.obsm`` containing rolling NN distances.
2099
+ layer_key: Layer name to plot alongside rolling NN distances.
2100
+ meta_cols: Obs columns used for row color annotations.
2101
+ col_cluster: Whether to cluster columns in the rolling NN clustermap.
2102
+ fill_nn_with_colmax: Fill NaNs in rolling NN distances with per-column max values.
2103
+ fill_layer_value: Fill NaNs in the layer heatmap with this value.
2104
+ drop_all_nan_windows: Drop rolling windows that are all NaN.
2105
+ max_nan_fraction: Maximum allowed NaN fraction per position (filtering columns).
2106
+ var_valid_fraction_col: ``subset.var`` column with valid fractions (1 - NaN fraction).
2107
+ var_nan_fraction_col: ``subset.var`` column with NaN fractions.
2108
+ figsize: Figure size for the combined plot.
2109
+ right_panel_var_mask: Optional boolean mask over ``subset.var`` for the right panel.
2110
+ robust: Use robust color scaling in seaborn.
2111
+ title: Optional figure title (suptitle).
2112
+ xtick_step: Spacing between x-axis tick labels.
2113
+ xtick_rotation: Rotation for x-axis tick labels.
2114
+ xtick_fontsize: Font size for x-axis tick labels.
2115
+ save_name: Optional output path for saving the plot.
2116
+ """
2117
+ if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
2118
+ raise ValueError("max_nan_fraction must be between 0 and 1.")
2119
+
2120
+ def _apply_xticks(ax, labels, step):
2121
+ if labels is None or len(labels) == 0:
2122
+ ax.set_xticks([])
2123
+ return
2124
+ if step is None or step <= 0:
2125
+ step = max(1, len(labels) // 10)
2126
+ ticks = np.arange(0, len(labels), step)
2127
+ ax.set_xticks(ticks + 0.5)
2128
+ ax.set_xticklabels(
2129
+ [labels[i] for i in ticks],
2130
+ rotation=xtick_rotation,
2131
+ fontsize=xtick_fontsize,
2132
+ )
2133
+
2134
+ # --- rolling NN distances
2135
+ X = subset.obsm[obsm_key]
2136
+ valid = ~np.all(np.isnan(X), axis=1)
2137
+
2138
+ X_df = pd.DataFrame(X[valid], index=subset.obs_names[valid])
2139
+
2140
+ if drop_all_nan_windows:
2141
+ X_df = X_df.loc[:, ~X_df.isna().all(axis=0)]
2142
+
2143
+ X_df_filled = X_df.copy()
2144
+ if fill_nn_with_colmax:
2145
+ col_max = X_df_filled.max(axis=0, skipna=True)
2146
+ X_df_filled = X_df_filled.fillna(col_max)
2147
+
2148
+ # Ensure non-MultiIndex index for seaborn
2149
+ X_df_filled.index = X_df_filled.index.astype(str)
2150
+
2151
+ # --- row colors from metadata (MultiIndex-safe)
2152
+ meta = subset.obs.loc[X_df.index, list(meta_cols)].copy()
2153
+ meta.index = meta.index.astype(str)
2154
+ row_colors = make_row_colors(meta)
2155
+
2156
+ # --- get row order via clustermap
2157
+ g = sns.clustermap(
2158
+ X_df_filled,
2159
+ cmap="viridis",
2160
+ col_cluster=col_cluster,
2161
+ row_cluster=True,
2162
+ row_colors=row_colors,
2163
+ xticklabels=False,
2164
+ yticklabels=False,
2165
+ robust=robust,
2166
+ )
2167
+ row_order = g.dendrogram_row.reordered_ind
2168
+ ordered_index = X_df_filled.index[row_order]
2169
+ plt.close(g.fig)
2170
+
2171
+ # reorder rolling NN matrix
2172
+ X_ord = X_df_filled.loc[ordered_index]
2173
+
2174
+ # --- layer matrix
2175
+ L = subset.layers[layer_key]
2176
+ L = L.toarray() if hasattr(L, "toarray") else np.asarray(L)
2177
+
2178
+ L_df = pd.DataFrame(L[valid], index=subset.obs_names[valid], columns=subset.var_names)
2179
+ L_df.index = L_df.index.astype(str)
2180
+
2181
+ if right_panel_var_mask is not None:
2182
+ # right_panel_var_mask must be boolean array/Series aligned to subset.var_names
2183
+ if hasattr(right_panel_var_mask, "values"):
2184
+ right_panel_var_mask = right_panel_var_mask.values
2185
+ right_panel_var_mask = np.asarray(right_panel_var_mask, dtype=bool)
2186
+
2187
+ if max_nan_fraction is not None:
2188
+ nan_fraction = None
2189
+ if var_nan_fraction_col and var_nan_fraction_col in subset.var:
2190
+ nan_fraction = pd.to_numeric(
2191
+ subset.var[var_nan_fraction_col], errors="coerce"
2192
+ ).to_numpy()
2193
+ elif var_valid_fraction_col and var_valid_fraction_col in subset.var:
2194
+ valid_fraction = pd.to_numeric(
2195
+ subset.var[var_valid_fraction_col], errors="coerce"
2196
+ ).to_numpy()
2197
+ nan_fraction = 1 - valid_fraction
2198
+ if nan_fraction is not None:
2199
+ nan_mask = nan_fraction <= max_nan_fraction
2200
+ if right_panel_var_mask is None:
2201
+ right_panel_var_mask = nan_mask
2202
+ else:
2203
+ right_panel_var_mask = right_panel_var_mask & nan_mask
2204
+
2205
+ if right_panel_var_mask is not None:
2206
+ if right_panel_var_mask.size != L_df.shape[1]:
2207
+ raise ValueError("right_panel_var_mask must align with subset.var_names.")
2208
+ L_df = L_df.loc[:, right_panel_var_mask]
2209
+
2210
+ L_ord = L_df.loc[ordered_index]
2211
+ L_plot = L_ord.fillna(fill_layer_value)
2212
+
2213
+ # --- plot side-by-side with barplots above
2214
+ fig = plt.figure(figsize=figsize)
2215
+ gs = fig.add_gridspec(
2216
+ 2,
2217
+ 4,
2218
+ width_ratios=[1, 0.05, 1, 0.05],
2219
+ height_ratios=[1, 6],
2220
+ wspace=0.2,
2221
+ hspace=0.05,
2222
+ )
2223
+
2224
+ ax1 = fig.add_subplot(gs[1, 0])
2225
+ ax1_cbar = fig.add_subplot(gs[1, 1])
2226
+ ax2 = fig.add_subplot(gs[1, 2])
2227
+ ax2_cbar = fig.add_subplot(gs[1, 3])
2228
+ ax1_bar = fig.add_subplot(gs[0, 0], sharex=ax1)
2229
+ ax2_bar = fig.add_subplot(gs[0, 2], sharex=ax2)
2230
+ fig.add_subplot(gs[0, 1]).axis("off")
2231
+ fig.add_subplot(gs[0, 3]).axis("off")
2232
+
2233
+ mean_nn = np.nanmean(X_ord.to_numpy(), axis=0)
2234
+ clean_barplot(
2235
+ ax1_bar,
2236
+ mean_nn,
2237
+ obsm_key,
2238
+ y_max=None,
2239
+ y_label="Mean distance",
2240
+ y_ticks=None,
2241
+ )
2242
+
2243
+ sns.heatmap(
2244
+ X_ord,
2245
+ ax=ax1,
2246
+ cmap="viridis",
2247
+ xticklabels=False,
2248
+ yticklabels=False,
2249
+ robust=robust,
2250
+ cbar_ax=ax1_cbar,
2251
+ )
2252
+ starts = subset.uns.get(f"{obsm_key}_starts")
2253
+ if starts is not None:
2254
+ starts = np.asarray(starts)
2255
+ window_labels = [str(s) for s in starts]
2256
+ try:
2257
+ col_idx = X_ord.columns.to_numpy()
2258
+ if np.issubdtype(col_idx.dtype, np.number):
2259
+ col_idx = col_idx.astype(int)
2260
+ if col_idx.size and col_idx.max() < len(starts):
2261
+ window_labels = [str(s) for s in starts[col_idx]]
2262
+ except Exception:
2263
+ window_labels = [str(s) for s in starts]
2264
+ _apply_xticks(ax1, window_labels, xtick_step)
2265
+
2266
+ methylation_fraction = _methylation_fraction_for_layer(L_ord.to_numpy(), layer_key)
2267
+ clean_barplot(
2268
+ ax2_bar,
2269
+ methylation_fraction,
2270
+ layer_key,
2271
+ y_max=1.0,
2272
+ y_label="Methylation fraction",
2273
+ y_ticks=[0.0, 0.5, 1.0],
2274
+ )
2275
+
2276
+ sns.heatmap(
2277
+ L_plot,
2278
+ ax=ax2,
2279
+ cmap="coolwarm",
2280
+ xticklabels=False,
2281
+ yticklabels=False,
2282
+ robust=robust,
2283
+ cbar_ax=ax2_cbar,
2284
+ )
2285
+ _apply_xticks(ax2, [str(x) for x in L_plot.columns], xtick_step)
2286
+
2287
+ if title:
2288
+ fig.suptitle(title)
2289
+
2290
+ if save_name is not None:
2291
+ fname = os.path.join(save_name)
2292
+ plt.savefig(fname, dpi=200, bbox_inches="tight")
2293
+
2294
+ else:
2295
+ plt.show()
2296
+
2297
+ return ordered_index
2298
+
2299
+
2300
+ def plot_sequence_integer_encoding_clustermaps(
2301
+ adata,
2302
+ sample_col: str = "Sample_Names",
2303
+ reference_col: str = "Reference_strand",
2304
+ layer: str = "sequence_integer_encoding",
2305
+ mismatch_layer: str = "mismatch_integer_encoding",
2306
+ min_quality: float | None = 20,
2307
+ min_length: int | None = 200,
2308
+ min_mapped_length_to_reference_length_ratio: float | None = 0,
2309
+ demux_types: Sequence[str] = ("single", "double", "already"),
2310
+ sort_by: str = "none", # "none", "hierarchical", "obs:<col>"
2311
+ cmap: str = "viridis",
2312
+ max_unknown_fraction: float | None = None,
2313
+ unknown_values: Sequence[int] = (4, 5),
2314
+ xtick_step: int | None = None,
2315
+ xtick_rotation: int = 90,
2316
+ xtick_fontsize: int = 9,
2317
+ max_reads: int | None = None,
2318
+ save_path: str | Path | None = None,
2319
+ use_dna_5color_palette: bool = True,
2320
+ show_numeric_colorbar: bool = False,
2321
+ show_position_axis: bool = False,
2322
+ position_axis_tick_target: int = 25,
2323
+ ):
2324
+ """Plot integer-encoded sequence clustermaps per sample/reference.
2325
+
2326
+ Args:
2327
+ adata: AnnData with a ``sequence_integer_encoding`` layer.
2328
+ sample_col: Column in ``adata.obs`` that identifies samples.
2329
+ reference_col: Column in ``adata.obs`` that identifies references.
2330
+ layer: Layer name containing integer-encoded sequences.
2331
+ mismatch_layer: Optional layer name containing mismatch integer encodings.
2332
+ min_quality: Optional minimum read quality filter.
2333
+ min_length: Optional minimum mapped length filter.
2334
+ min_mapped_length_to_reference_length_ratio: Optional min length ratio filter.
2335
+ demux_types: Allowed ``demux_type`` values, if present in ``adata.obs``.
2336
+ sort_by: Row sorting strategy: ``none``, ``hierarchical``, or ``obs:<col>``.
2337
+ cmap: Matplotlib colormap for the heatmap when ``use_dna_5color_palette`` is False.
2338
+ max_unknown_fraction: Optional maximum fraction of ``unknown_values`` allowed per
2339
+ position; positions above this threshold are excluded.
2340
+ unknown_values: Integer values to treat as unknown/padding.
2341
+ xtick_step: Spacing between x-axis tick labels (None = no labels).
2342
+ xtick_rotation: Rotation for x-axis tick labels.
2343
+ xtick_fontsize: Font size for x-axis tick labels.
2344
+ max_reads: Optional maximum number of reads to plot per sample/reference.
2345
+ save_path: Optional output directory for saving plots.
2346
+ use_dna_5color_palette: Whether to use a fixed A/C/G/T/Other palette.
2347
+ show_numeric_colorbar: If False, use a legend instead of a numeric colorbar.
2348
+ show_position_axis: Whether to draw a position axis with tick labels.
2349
+ position_axis_tick_target: Approximate number of ticks to show when auto-sizing.
2350
+
2351
+ Returns:
2352
+ List of dictionaries with per-plot metadata and output paths.
2353
+ """
2354
+
2355
+ def _mask_or_true(series_name: str, predicate):
2356
+ if series_name not in adata.obs:
2357
+ return pd.Series(True, index=adata.obs.index)
2358
+ s = adata.obs[series_name]
2359
+ try:
2360
+ return predicate(s)
2361
+ except Exception:
2362
+ return pd.Series(True, index=adata.obs.index)
2363
+
2364
+ if layer not in adata.layers:
2365
+ raise KeyError(f"Layer '{layer}' not found in adata.layers")
2366
+
2367
+ if max_unknown_fraction is not None and not (0 <= max_unknown_fraction <= 1):
2368
+ raise ValueError("max_unknown_fraction must be between 0 and 1.")
2369
+
2370
+ if position_axis_tick_target < 1:
2371
+ raise ValueError("position_axis_tick_target must be at least 1.")
2372
+
2373
+ results: List[Dict[str, Any]] = []
2374
+ save_path = Path(save_path) if save_path is not None else None
2375
+ if save_path is not None:
2376
+ save_path.mkdir(parents=True, exist_ok=True)
2377
+
2378
+ for col in (sample_col, reference_col):
2379
+ if col not in adata.obs:
2380
+ raise KeyError(f"{col} not in adata.obs")
2381
+ if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
2382
+ adata.obs[col] = adata.obs[col].astype("category")
2383
+
2384
+ int_to_base = adata.uns.get("sequence_integer_decoding_map", {}) or {}
2385
+ if not int_to_base:
2386
+ encoding_map = adata.uns.get("sequence_integer_encoding_map", {}) or {}
2387
+ int_to_base = {int(v): str(k) for k, v in encoding_map.items()} if encoding_map else {}
2388
+
2389
+ coerced_int_to_base = {}
2390
+ for key, value in int_to_base.items():
2391
+ try:
2392
+ coerced_key = int(key)
2393
+ except Exception:
2394
+ continue
2395
+ coerced_int_to_base[coerced_key] = str(value)
2396
+ int_to_base = coerced_int_to_base
2397
+
2398
+ def normalize_base(base: str) -> str:
2399
+ return base if base in {"A", "C", "G", "T"} else "OTHER"
2400
+
2401
+ mismatch_int_to_base = {}
2402
+ if mismatch_layer in adata.layers:
2403
+ mismatch_encoding_map = adata.uns.get("mismatch_integer_encoding_map", {}) or {}
2404
+ mismatch_int_to_base = {
2405
+ int(v): str(k)
2406
+ for k, v in mismatch_encoding_map.items()
2407
+ if isinstance(v, (int, np.integer))
2408
+ }
2409
+
2410
+ def _resolve_xtick_step(n_positions: int) -> int | None:
2411
+ if xtick_step is not None:
2412
+ return xtick_step
2413
+ if not show_position_axis:
2414
+ return None
2415
+ return max(1, int(np.ceil(n_positions / position_axis_tick_target)))
2416
+
2417
+ for ref in adata.obs[reference_col].cat.categories:
2418
+ for sample in adata.obs[sample_col].cat.categories:
2419
+ qmask = _mask_or_true(
2420
+ "read_quality",
2421
+ (lambda s: s >= float(min_quality))
2422
+ if (min_quality is not None)
2423
+ else (lambda s: pd.Series(True, index=s.index)),
2424
+ )
2425
+ lm_mask = _mask_or_true(
2426
+ "mapped_length",
2427
+ (lambda s: s >= float(min_length))
2428
+ if (min_length is not None)
2429
+ else (lambda s: pd.Series(True, index=s.index)),
2430
+ )
2431
+ lrr_mask = _mask_or_true(
2432
+ "mapped_length_to_reference_length_ratio",
2433
+ (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
2434
+ if (min_mapped_length_to_reference_length_ratio is not None)
2435
+ else (lambda s: pd.Series(True, index=s.index)),
2436
+ )
2437
+ demux_mask = _mask_or_true(
2438
+ "demux_type",
2439
+ (lambda s: s.astype("string").isin(list(demux_types)))
2440
+ if (demux_types is not None)
2441
+ else (lambda s: pd.Series(True, index=s.index)),
2442
+ )
2443
+
2444
+ row_mask = (
2445
+ (adata.obs[reference_col] == ref)
2446
+ & (adata.obs[sample_col] == sample)
2447
+ & qmask
2448
+ & lm_mask
2449
+ & lrr_mask
2450
+ & demux_mask
2451
+ )
2452
+ if not bool(row_mask.any()):
2453
+ continue
2454
+
2455
+ subset = adata[row_mask, :].copy()
2456
+ matrix = np.asarray(subset.layers[layer])
2457
+ mismatch_matrix = None
2458
+ if mismatch_layer in subset.layers:
2459
+ mismatch_matrix = np.asarray(subset.layers[mismatch_layer])
2460
+
2461
+ if max_unknown_fraction is not None:
2462
+ unknown_mask = np.isin(matrix, np.asarray(unknown_values))
2463
+ unknown_fraction = unknown_mask.mean(axis=0)
2464
+ keep_columns = unknown_fraction <= max_unknown_fraction
2465
+ if not np.any(keep_columns):
2466
+ continue
2467
+ matrix = matrix[:, keep_columns]
2468
+ subset = subset[:, keep_columns].copy()
2469
+ if mismatch_matrix is not None:
2470
+ mismatch_matrix = mismatch_matrix[:, keep_columns]
2471
+
2472
+ if max_reads is not None and matrix.shape[0] > max_reads:
2473
+ matrix = matrix[:max_reads]
2474
+ subset = subset[:max_reads, :].copy()
2475
+ if mismatch_matrix is not None:
2476
+ mismatch_matrix = mismatch_matrix[:max_reads]
2477
+
2478
+ if matrix.size == 0:
2479
+ continue
2480
+
2481
+ if use_dna_5color_palette and not int_to_base:
2482
+ uniq_vals = np.unique(matrix[~pd.isna(matrix)])
2483
+ guess = {}
2484
+ for val in uniq_vals:
2485
+ try:
2486
+ int_val = int(val)
2487
+ except Exception:
2488
+ continue
2489
+ guess[int_val] = {0: "A", 1: "C", 2: "G", 3: "T"}.get(int_val, "OTHER")
2490
+ int_to_base_local = guess
2491
+ else:
2492
+ int_to_base_local = int_to_base
2493
+
2494
+ order = None
2495
+ if sort_by.startswith("obs:"):
2496
+ colname = sort_by.split("obs:")[1]
2497
+ order = np.argsort(subset.obs[colname].values)
2498
+ elif sort_by == "hierarchical":
2499
+ linkage = sch.linkage(np.nan_to_num(matrix), method="ward")
2500
+ order = sch.leaves_list(linkage)
2501
+ elif sort_by != "none":
2502
+ raise ValueError("sort_by must be 'none', 'hierarchical', or 'obs:<col>'")
2503
+
2504
+ if order is not None:
2505
+ matrix = matrix[order]
2506
+ if mismatch_matrix is not None:
2507
+ mismatch_matrix = mismatch_matrix[order]
2508
+
2509
+ has_mismatch = mismatch_matrix is not None
2510
+ fig, axes = plt.subplots(
2511
+ ncols=2 if has_mismatch else 1,
2512
+ figsize=(18, 6) if has_mismatch else (12, 6),
2513
+ sharey=has_mismatch,
2514
+ )
2515
+ if not isinstance(axes, np.ndarray):
2516
+ axes = np.asarray([axes])
2517
+ ax = axes[0]
2518
+
2519
+ if use_dna_5color_palette and int_to_base_local:
2520
+ int_to_color = {
2521
+ int(int_val): DNA_5COLOR_PALETTE[normalize_base(str(base))]
2522
+ for int_val, base in int_to_base_local.items()
2523
+ }
2524
+ uniq_matrix = np.unique(matrix[~pd.isna(matrix)])
2525
+ for val in uniq_matrix:
2526
+ try:
2527
+ int_val = int(val)
2528
+ except Exception:
2529
+ continue
2530
+ if int_val not in int_to_color:
2531
+ int_to_color[int_val] = DNA_5COLOR_PALETTE["OTHER"]
2532
+
2533
+ ordered = sorted(int_to_color.items(), key=lambda x: x[0])
2534
+ colors_list = [color for _, color in ordered]
2535
+ bounds = [int_val - 0.5 for int_val, _ in ordered]
2536
+ bounds.append(ordered[-1][0] + 0.5)
2537
+
2538
+ cmap_obj = colors.ListedColormap(colors_list)
2539
+ norm = colors.BoundaryNorm(bounds, cmap_obj.N)
2540
+
2541
+ sns.heatmap(
2542
+ matrix,
2543
+ cmap=cmap_obj,
2544
+ norm=norm,
2545
+ ax=ax,
2546
+ yticklabels=False,
2547
+ cbar=show_numeric_colorbar,
2548
+ )
2549
+
2550
+ legend_handles = [
2551
+ patches.Patch(facecolor=DNA_5COLOR_PALETTE["A"], label="A"),
2552
+ patches.Patch(facecolor=DNA_5COLOR_PALETTE["C"], label="C"),
2553
+ patches.Patch(facecolor=DNA_5COLOR_PALETTE["G"], label="G"),
2554
+ patches.Patch(facecolor=DNA_5COLOR_PALETTE["T"], label="T"),
2555
+ patches.Patch(
2556
+ facecolor=DNA_5COLOR_PALETTE["OTHER"],
2557
+ label="Other (N / PAD / unknown)",
2558
+ ),
2559
+ ]
2560
+ ax.legend(
2561
+ handles=legend_handles,
2562
+ title="Base",
2563
+ loc="upper left",
2564
+ bbox_to_anchor=(1.02, 1.0),
2565
+ frameon=False,
2566
+ )
2567
+ else:
2568
+ sns.heatmap(matrix, cmap=cmap, ax=ax, yticklabels=False, cbar=True)
2569
+
2570
+ ax.set_title(layer)
2571
+
2572
+ resolved_step = _resolve_xtick_step(matrix.shape[1])
2573
+ if resolved_step is not None and resolved_step > 0:
2574
+ sites = np.arange(0, matrix.shape[1], resolved_step)
2575
+ ax.set_xticks(sites)
2576
+ ax.set_xticklabels(
2577
+ subset.var_names[sites].astype(str),
2578
+ rotation=xtick_rotation,
2579
+ fontsize=xtick_fontsize,
2580
+ )
2581
+ else:
2582
+ ax.set_xticks([])
2583
+ if show_position_axis or xtick_step is not None:
2584
+ ax.set_xlabel("Position")
2585
+
2586
+ if has_mismatch:
2587
+ mismatch_ax = axes[1]
2588
+ mismatch_int_to_base_local = mismatch_int_to_base or int_to_base_local
2589
+ if use_dna_5color_palette and mismatch_int_to_base_local:
2590
+ mismatch_int_to_color = {}
2591
+ for int_val, base in mismatch_int_to_base_local.items():
2592
+ base_upper = str(base).upper()
2593
+ if base_upper == "PAD":
2594
+ mismatch_int_to_color[int(int_val)] = "#D3D3D3"
2595
+ elif base_upper == "N":
2596
+ mismatch_int_to_color[int(int_val)] = "#808080"
2597
+ else:
2598
+ mismatch_int_to_color[int(int_val)] = DNA_5COLOR_PALETTE[
2599
+ normalize_base(base_upper)
2600
+ ]
2601
+
2602
+ uniq_mismatch = np.unique(mismatch_matrix[~pd.isna(mismatch_matrix)])
2603
+ for val in uniq_mismatch:
2604
+ try:
2605
+ int_val = int(val)
2606
+ except Exception:
2607
+ continue
2608
+ if int_val not in mismatch_int_to_color:
2609
+ mismatch_int_to_color[int_val] = DNA_5COLOR_PALETTE["OTHER"]
2610
+
2611
+ ordered_mismatch = sorted(mismatch_int_to_color.items(), key=lambda x: x[0])
2612
+ mismatch_colors = [color for _, color in ordered_mismatch]
2613
+ mismatch_bounds = [int_val - 0.5 for int_val, _ in ordered_mismatch]
2614
+ mismatch_bounds.append(ordered_mismatch[-1][0] + 0.5)
2615
+
2616
+ mismatch_cmap = colors.ListedColormap(mismatch_colors)
2617
+ mismatch_norm = colors.BoundaryNorm(mismatch_bounds, mismatch_cmap.N)
2618
+
2619
+ sns.heatmap(
2620
+ mismatch_matrix,
2621
+ cmap=mismatch_cmap,
2622
+ norm=mismatch_norm,
2623
+ ax=mismatch_ax,
2624
+ yticklabels=False,
2625
+ cbar=show_numeric_colorbar,
2626
+ )
2627
+
2628
+ mismatch_legend_handles = [
2629
+ patches.Patch(facecolor=DNA_5COLOR_PALETTE["A"], label="A"),
2630
+ patches.Patch(facecolor=DNA_5COLOR_PALETTE["C"], label="C"),
2631
+ patches.Patch(facecolor=DNA_5COLOR_PALETTE["G"], label="G"),
2632
+ patches.Patch(facecolor=DNA_5COLOR_PALETTE["T"], label="T"),
2633
+ patches.Patch(facecolor="#808080", label="Match/N"),
2634
+ patches.Patch(facecolor="#D3D3D3", label="PAD"),
2635
+ ]
2636
+ mismatch_ax.legend(
2637
+ handles=mismatch_legend_handles,
2638
+ title="Mismatch base",
2639
+ loc="upper left",
2640
+ bbox_to_anchor=(1.02, 1.0),
2641
+ frameon=False,
2642
+ )
2643
+ else:
2644
+ sns.heatmap(
2645
+ mismatch_matrix,
2646
+ cmap=cmap,
2647
+ ax=mismatch_ax,
2648
+ yticklabels=False,
2649
+ cbar=True,
2650
+ )
2651
+
2652
+ mismatch_ax.set_title(mismatch_layer)
2653
+ if resolved_step is not None and resolved_step > 0:
2654
+ sites = np.arange(0, mismatch_matrix.shape[1], resolved_step)
2655
+ mismatch_ax.set_xticks(sites)
2656
+ mismatch_ax.set_xticklabels(
2657
+ subset.var_names[sites].astype(str),
2658
+ rotation=xtick_rotation,
2659
+ fontsize=xtick_fontsize,
2660
+ )
2661
+ else:
2662
+ mismatch_ax.set_xticks([])
2663
+ if show_position_axis or xtick_step is not None:
2664
+ mismatch_ax.set_xlabel("Position")
2665
+
2666
+ fig.suptitle(f"{sample} - {ref}")
2667
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
2668
+
2669
+ out_file = None
2670
+ if save_path is not None:
2671
+ safe_name = f"{ref}__{sample}__{layer}".replace("=", "").replace(",", "_")
2672
+ out_file = save_path / f"{safe_name}.png"
2673
+ fig.savefig(out_file, dpi=300, bbox_inches="tight")
2674
+ plt.close(fig)
2675
+ else:
2676
+ plt.show()
2677
+
2678
+ results.append(
2679
+ {
2680
+ "reference": str(ref),
2681
+ "sample": str(sample),
2682
+ "layer": layer,
2683
+ "n_positions": int(matrix.shape[1]),
2684
+ "mismatch_layer": mismatch_layer if has_mismatch else None,
2685
+ "mismatch_layer_present": bool(has_mismatch),
2686
+ "output_path": str(out_file) if out_file is not None else None,
2687
+ }
2688
+ )
2689
+
2690
+ return results
2691
+
2692
+
2693
+ def plot_read_span_quality_clustermaps(
2694
+ adata,
2695
+ sample_col: str = "Sample_Names",
2696
+ reference_col: str = "Reference_strand",
2697
+ quality_layer: str = "base_quality_scores",
2698
+ read_span_layer: str = "read_span_mask",
2699
+ quality_cmap: str = "viridis",
2700
+ read_span_color: str = "#2ca25f",
2701
+ max_nan_fraction: float | None = None,
2702
+ min_quality: float | None = None,
2703
+ min_length: int | None = None,
2704
+ min_mapped_length_to_reference_length_ratio: float | None = None,
2705
+ demux_types: Sequence[str] = ("single", "double", "already"),
2706
+ max_reads: int | None = None,
2707
+ xtick_step: int | None = None,
2708
+ xtick_rotation: int = 90,
2709
+ xtick_fontsize: int = 9,
2710
+ show_position_axis: bool = False,
2711
+ position_axis_tick_target: int = 25,
2712
+ save_path: str | Path | None = None,
2713
+ ) -> List[Dict[str, Any]]:
2714
+ """Plot read-span mask and base quality clustermaps side by side.
2715
+
2716
+ Clustering is performed using the base-quality layer ordering, which is then
2717
+ applied to the read-span mask to keep the two panels aligned.
2718
+
2719
+ Args:
2720
+ adata: AnnData with read-span and base-quality layers.
2721
+ sample_col: Column in ``adata.obs`` that identifies samples.
2722
+ reference_col: Column in ``adata.obs`` that identifies references.
2723
+ quality_layer: Layer name containing base-quality scores.
2724
+ read_span_layer: Layer name containing read-span masks.
2725
+ quality_cmap: Colormap for base-quality scores.
2726
+ read_span_color: Color for read-span mask (1-values); 0-values are white.
2727
+ max_nan_fraction: Optional maximum fraction of NaNs allowed per position; positions
2728
+ above this threshold are excluded.
2729
+ min_quality: Optional minimum read quality filter.
2730
+ min_length: Optional minimum mapped length filter.
2731
+ min_mapped_length_to_reference_length_ratio: Optional min length ratio filter.
2732
+ demux_types: Allowed ``demux_type`` values, if present in ``adata.obs``.
2733
+ max_reads: Optional maximum number of reads to plot per sample/reference.
2734
+ xtick_step: Spacing between x-axis tick labels (None = no labels).
2735
+ xtick_rotation: Rotation for x-axis tick labels.
2736
+ xtick_fontsize: Font size for x-axis tick labels.
2737
+ show_position_axis: Whether to draw a position axis with tick labels.
2738
+ position_axis_tick_target: Approximate number of ticks to show when auto-sizing.
2739
+ save_path: Optional output directory for saving plots.
2740
+
2741
+ Returns:
2742
+ List of dictionaries with per-plot metadata and output paths.
2743
+ """
2744
+
2745
+ def _mask_or_true(series_name: str, predicate):
2746
+ if series_name not in adata.obs:
2747
+ return pd.Series(True, index=adata.obs.index)
2748
+ s = adata.obs[series_name]
2749
+ try:
2750
+ return predicate(s)
2751
+ except Exception:
2752
+ return pd.Series(True, index=adata.obs.index)
2753
+
2754
+ def _resolve_xtick_step(n_positions: int) -> int | None:
2755
+ if xtick_step is not None:
2756
+ return xtick_step
2757
+ if not show_position_axis:
2758
+ return None
2759
+ return max(1, int(np.ceil(n_positions / position_axis_tick_target)))
2760
+
2761
+ def _fill_nan_with_col_means(matrix: np.ndarray) -> np.ndarray:
2762
+ filled = matrix.copy()
2763
+ col_means = np.nanmean(filled, axis=0)
2764
+ col_means = np.where(np.isnan(col_means), 0.0, col_means)
2765
+ nan_rows, nan_cols = np.where(np.isnan(filled))
2766
+ filled[nan_rows, nan_cols] = col_means[nan_cols]
2767
+ return filled
2768
+
2769
+ if quality_layer not in adata.layers:
2770
+ raise KeyError(f"Layer '{quality_layer}' not found in adata.layers")
2771
+ if read_span_layer not in adata.layers:
2772
+ raise KeyError(f"Layer '{read_span_layer}' not found in adata.layers")
2773
+ if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
2774
+ raise ValueError("max_nan_fraction must be between 0 and 1.")
2775
+ if position_axis_tick_target < 1:
2776
+ raise ValueError("position_axis_tick_target must be at least 1.")
2777
+
2778
+ results: List[Dict[str, Any]] = []
2779
+ save_path = Path(save_path) if save_path is not None else None
2780
+ if save_path is not None:
2781
+ save_path.mkdir(parents=True, exist_ok=True)
2782
+
2783
+ for col in (sample_col, reference_col):
2784
+ if col not in adata.obs:
2785
+ raise KeyError(f"{col} not in adata.obs")
2786
+ if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
2787
+ adata.obs[col] = adata.obs[col].astype("category")
2788
+
2789
+ for ref in adata.obs[reference_col].cat.categories:
2790
+ for sample in adata.obs[sample_col].cat.categories:
2791
+ qmask = _mask_or_true(
2792
+ "read_quality",
2793
+ (lambda s: s >= float(min_quality))
2794
+ if (min_quality is not None)
2795
+ else (lambda s: pd.Series(True, index=s.index)),
2796
+ )
2797
+ lm_mask = _mask_or_true(
2798
+ "mapped_length",
2799
+ (lambda s: s >= float(min_length))
2800
+ if (min_length is not None)
2801
+ else (lambda s: pd.Series(True, index=s.index)),
2802
+ )
2803
+ lrr_mask = _mask_or_true(
2804
+ "mapped_length_to_reference_length_ratio",
2805
+ (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
2806
+ if (min_mapped_length_to_reference_length_ratio is not None)
2807
+ else (lambda s: pd.Series(True, index=s.index)),
2808
+ )
2809
+ demux_mask = _mask_or_true(
2810
+ "demux_type",
2811
+ (lambda s: s.astype("string").isin(list(demux_types)))
2812
+ if (demux_types is not None)
2813
+ else (lambda s: pd.Series(True, index=s.index)),
2814
+ )
2815
+
2816
+ row_mask = (
2817
+ (adata.obs[reference_col] == ref)
2818
+ & (adata.obs[sample_col] == sample)
2819
+ & qmask
2820
+ & lm_mask
2821
+ & lrr_mask
2822
+ & demux_mask
2823
+ )
2824
+ if not bool(row_mask.any()):
2825
+ continue
2826
+
2827
+ subset = adata[row_mask, :].copy()
2828
+ quality_matrix = np.asarray(subset.layers[quality_layer]).astype(float)
2829
+ quality_matrix[quality_matrix < 0] = np.nan
2830
+ read_span_matrix = np.asarray(subset.layers[read_span_layer]).astype(float)
2831
+
2832
+ if max_nan_fraction is not None:
2833
+ nan_mask = np.isnan(quality_matrix) | np.isnan(read_span_matrix)
2834
+ nan_fraction = nan_mask.mean(axis=0)
2835
+ keep_columns = nan_fraction <= max_nan_fraction
2836
+ if not np.any(keep_columns):
2837
+ continue
2838
+ quality_matrix = quality_matrix[:, keep_columns]
2839
+ read_span_matrix = read_span_matrix[:, keep_columns]
2840
+ subset = subset[:, keep_columns].copy()
2841
+
2842
+ if max_reads is not None and quality_matrix.shape[0] > max_reads:
2843
+ quality_matrix = quality_matrix[:max_reads]
2844
+ read_span_matrix = read_span_matrix[:max_reads]
2845
+ subset = subset[:max_reads, :].copy()
2846
+
2847
+ if quality_matrix.size == 0:
2848
+ continue
2849
+
2850
+ quality_filled = _fill_nan_with_col_means(quality_matrix)
2851
+ linkage = sch.linkage(quality_filled, method="ward")
2852
+ order = sch.leaves_list(linkage)
2853
+
2854
+ quality_matrix = quality_matrix[order]
2855
+ read_span_matrix = read_span_matrix[order]
2856
+
2857
+ fig, axes = plt.subplots(
2858
+ nrows=2,
2859
+ ncols=3,
2860
+ figsize=(18, 6),
2861
+ sharex="col",
2862
+ gridspec_kw={"height_ratios": [1, 4], "width_ratios": [1, 1, 0.05]},
2863
+ )
2864
+ span_bar_ax, quality_bar_ax, bar_spacer_ax = axes[0]
2865
+ span_ax, quality_ax, cbar_ax = axes[1]
2866
+ bar_spacer_ax.set_axis_off()
2867
+
2868
+ span_mean = np.nanmean(read_span_matrix, axis=0)
2869
+ quality_mean = np.nanmean(quality_matrix, axis=0)
2870
+ bar_positions = np.arange(read_span_matrix.shape[1]) + 0.5
2871
+ span_bar_ax.bar(
2872
+ bar_positions,
2873
+ span_mean,
2874
+ color=read_span_color,
2875
+ width=1.0,
2876
+ )
2877
+ span_bar_ax.set_title(f"{read_span_layer} mean")
2878
+ span_bar_ax.set_xlim(0, read_span_matrix.shape[1])
2879
+ span_bar_ax.tick_params(axis="x", labelbottom=False)
2880
+
2881
+ quality_bar_ax.bar(
2882
+ bar_positions,
2883
+ quality_mean,
2884
+ color="#4c72b0",
2885
+ width=1.0,
2886
+ )
2887
+ quality_bar_ax.set_title(f"{quality_layer} mean")
2888
+ quality_bar_ax.set_xlim(0, quality_matrix.shape[1])
2889
+ quality_bar_ax.tick_params(axis="x", labelbottom=False)
2890
+
2891
+ span_cmap = colors.ListedColormap(["white", read_span_color])
2892
+ span_norm = colors.BoundaryNorm([-0.5, 0.5, 1.5], span_cmap.N)
2893
+ sns.heatmap(
2894
+ read_span_matrix,
2895
+ cmap=span_cmap,
2896
+ norm=span_norm,
2897
+ ax=span_ax,
2898
+ yticklabels=False,
2899
+ cbar=False,
2900
+ )
2901
+ span_ax.set_title(read_span_layer)
2902
+
2903
+ sns.heatmap(
2904
+ quality_matrix,
2905
+ cmap=quality_cmap,
2906
+ ax=quality_ax,
2907
+ yticklabels=False,
2908
+ cbar=True,
2909
+ cbar_ax=cbar_ax,
2910
+ )
2911
+ quality_ax.set_title(quality_layer)
2912
+
2913
+ resolved_step = _resolve_xtick_step(quality_matrix.shape[1])
2914
+ for axis in (span_ax, quality_ax):
2915
+ if resolved_step is not None and resolved_step > 0:
2916
+ sites = np.arange(0, quality_matrix.shape[1], resolved_step)
2917
+ axis.set_xticks(sites)
2918
+ axis.set_xticklabels(
2919
+ subset.var_names[sites].astype(str),
2920
+ rotation=xtick_rotation,
2921
+ fontsize=xtick_fontsize,
2922
+ )
2923
+ else:
2924
+ axis.set_xticks([])
2925
+ if show_position_axis or xtick_step is not None:
2926
+ axis.set_xlabel("Position")
2927
+
2928
+ fig.suptitle(f"{sample} - {ref}")
2929
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
2930
+
2931
+ out_file = None
2932
+ if save_path is not None:
2933
+ safe_name = f"{ref}__{sample}__read_span_quality".replace("=", "").replace(",", "_")
2934
+ out_file = save_path / f"{safe_name}.png"
2935
+ fig.savefig(out_file, dpi=300, bbox_inches="tight")
2936
+ plt.close(fig)
2937
+ else:
2938
+ plt.show()
2939
+
2940
+ results.append(
2941
+ {
2942
+ "reference": str(ref),
2943
+ "sample": str(sample),
2944
+ "quality_layer": quality_layer,
2945
+ "read_span_layer": read_span_layer,
2946
+ "n_positions": int(quality_matrix.shape[1]),
2947
+ "output_path": str(out_file) if out_file is not None else None,
2948
+ }
2949
+ )
2950
+
2951
+ return results
2952
+
2953
+
2954
+ def plot_hmm_layers_rolling_by_sample_ref(
2955
+ adata,
2956
+ layers: Optional[Sequence[str]] = None,
2957
+ sample_col: str = "Barcode",
2958
+ ref_col: str = "Reference_strand",
2959
+ samples: Optional[Sequence[str]] = None,
2960
+ references: Optional[Sequence[str]] = None,
2961
+ window: int = 51,
2962
+ min_periods: int = 1,
2963
+ center: bool = True,
2964
+ rows_per_page: int = 6,
2965
+ figsize_per_cell: Tuple[float, float] = (4.0, 2.5),
2966
+ dpi: int = 160,
2967
+ output_dir: Optional[str] = None,
2968
+ save: bool = True,
2969
+ show_raw: bool = False,
2970
+ cmap: str = "tab20",
2971
+ layer_colors: Optional[Mapping[str, Any]] = None,
2972
+ use_var_coords: bool = True,
2973
+ reindexed_var_suffix: str = "reindexed",
2974
+ ):
2975
+ """
2976
+ For each sample (row) and reference (col) plot the rolling average of the
2977
+ positional mean (mean across reads) for each layer listed.
2978
+
2979
+ Parameters
2980
+ ----------
2981
+ adata : AnnData
2982
+ Input annotated data (expects obs columns sample_col and ref_col).
2983
+ layers : list[str] | None
2984
+ Which adata.layers to plot. If None, attempts to autodetect layers whose
2985
+ matrices look like "HMM" outputs (else will error). If None and layers
2986
+ cannot be found, user must pass a list.
2987
+ sample_col, ref_col : str
2988
+ obs columns used to group rows.
2989
+ samples, references : optional lists
2990
+ explicit ordering of samples / references. If None, categories in adata.obs are used.
2991
+ window : int
2992
+ rolling window size (odd recommended). If window <= 1, no smoothing applied.
2993
+ min_periods : int
2994
+ min periods param for pd.Series.rolling.
2995
+ center : bool
2996
+ center the rolling window.
2997
+ rows_per_page : int
2998
+ paginate rows per page into multiple figures if needed.
2999
+ figsize_per_cell : (w,h)
3000
+ per-subplot size in inches.
3001
+ dpi : int
3002
+ figure dpi when saving.
3003
+ output_dir : str | None
3004
+ directory to save pages; created if necessary. If None and save=True, uses cwd.
3005
+ save : bool
3006
+ whether to save PNG files.
3007
+ show_raw : bool
3008
+ draw unsmoothed mean as faint line under smoothed curve.
3009
+ cmap : str
3010
+ matplotlib colormap for layer lines.
3011
+ layer_colors : dict[str, Any] | None
3012
+ Optional mapping of layer name to explicit line colors.
3013
+ use_var_coords : bool
3014
+ if True, tries to use adata.var_names (coerced to int) as x-axis coordinates; otherwise uses 0..n-1.
3015
+ reindexed_var_suffix : str
3016
+ Suffix for per-reference reindexed var columns (e.g., ``Reference_reindexed``) used when available.
3017
+
3018
+ Returns
3019
+ -------
3020
+ saved_files : list[str]
3021
+ list of saved filenames (may be empty if save=False).
3022
+ """
3023
+
3024
+ # --- basic checks / defaults ---
3025
+ if sample_col not in adata.obs.columns or ref_col not in adata.obs.columns:
3026
+ raise ValueError(
3027
+ f"sample_col '{sample_col}' and ref_col '{ref_col}' must exist in adata.obs"
3028
+ )
3029
+
3030
+ # canonicalize samples / refs
3031
+ if samples is None:
3032
+ sseries = adata.obs[sample_col]
3033
+ if not pd.api.types.is_categorical_dtype(sseries):
3034
+ sseries = sseries.astype("category")
3035
+ samples_all = list(sseries.cat.categories)
3036
+ else:
3037
+ samples_all = list(samples)
3038
+
3039
+ if references is None:
3040
+ rseries = adata.obs[ref_col]
3041
+ if not pd.api.types.is_categorical_dtype(rseries):
3042
+ rseries = rseries.astype("category")
3043
+ refs_all = list(rseries.cat.categories)
3044
+ else:
3045
+ refs_all = list(references)
3046
+
3047
+ # choose layers: if not provided, try a sensible default: all layers
3048
+ if layers is None:
3049
+ layers = list(adata.layers.keys())
3050
+ if len(layers) == 0:
3051
+ raise ValueError(
3052
+ "No adata.layers found. Please pass `layers=[...]` of the HMM layers to plot."
3053
+ )
3054
+ layers = list(layers)
3055
+
3056
+ # x coordinates (positions) + optional labels
3057
+ x_labels = None
3058
+ try:
3059
+ if use_var_coords:
3060
+ x_coords = np.array([int(v) for v in adata.var_names])
3061
+ else:
3062
+ raise Exception("user disabled var coords")
3063
+ except Exception:
3064
+ # fallback to 0..n_vars-1, but keep var_names as labels
3065
+ x_coords = np.arange(adata.shape[1], dtype=int)
3066
+ x_labels = adata.var_names.astype(str).tolist()
3067
+
3068
+ ref_reindexed_cols = {
3069
+ ref: f"{ref}_{reindexed_var_suffix}"
3070
+ for ref in refs_all
3071
+ if f"{ref}_{reindexed_var_suffix}" in adata.var
3072
+ }
3073
+
1429
3074
  # make output dir
1430
3075
  if save:
1431
3076
  outdir = output_dir or os.getcwd()
@@ -1441,7 +3086,9 @@ def plot_hmm_layers_rolling_by_sample_ref(
1441
3086
  # color cycle for layers
1442
3087
  cmap_obj = plt.get_cmap(cmap)
1443
3088
  n_layers = max(1, len(layers))
1444
- colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]
3089
+ fallback_colors = [cmap_obj(i / max(1, n_layers - 1)) for i in range(n_layers)]
3090
+ layer_colors = layer_colors or {}
3091
+ colors = [layer_colors.get(layer, fallback_colors[idx]) for idx, layer in enumerate(layers)]
1445
3092
 
1446
3093
  for page in range(total_pages):
1447
3094
  start = page * rows_per_page
@@ -1486,6 +3133,14 @@ def plot_hmm_layers_rolling_by_sample_ref(
1486
3133
 
1487
3134
  # for each layer, compute positional mean across reads (ignore NaNs)
1488
3135
  plotted_any = False
3136
+ reindexed_col = ref_reindexed_cols.get(ref_name)
3137
+ if reindexed_col is not None:
3138
+ try:
3139
+ ref_coords = np.asarray(adata.var[reindexed_col], dtype=int)
3140
+ except Exception:
3141
+ ref_coords = x_coords
3142
+ else:
3143
+ ref_coords = x_coords
1489
3144
  for li, layer in enumerate(layers):
1490
3145
  if layer in sub.layers:
1491
3146
  mat = sub.layers[layer]
@@ -1519,6 +3174,8 @@ def plot_hmm_layers_rolling_by_sample_ref(
1519
3174
  if np.all(np.isnan(col_mean)):
1520
3175
  continue
1521
3176
 
3177
+ valid_mask = np.isfinite(col_mean)
3178
+
1522
3179
  # smooth via pandas rolling (centered)
1523
3180
  if (window is None) or (window <= 1):
1524
3181
  smoothed = col_mean
@@ -1529,10 +3186,11 @@ def plot_hmm_layers_rolling_by_sample_ref(
1529
3186
  .mean()
1530
3187
  .to_numpy()
1531
3188
  )
3189
+ smoothed = np.where(valid_mask, smoothed, np.nan)
1532
3190
 
1533
3191
  # x axis: x_coords (trim/pad to match length)
1534
3192
  L = len(col_mean)
1535
- x = x_coords[:L]
3193
+ x = ref_coords[:L]
1536
3194
 
1537
3195
  # optionally plot raw faint line first
1538
3196
  if show_raw:
@@ -1557,6 +3215,13 @@ def plot_hmm_layers_rolling_by_sample_ref(
1557
3215
  ax.set_ylabel(f"{sample_name}\n(n={total_reads})", fontsize=8)
1558
3216
  if r_idx == nrows - 1:
1559
3217
  ax.set_xlabel("position", fontsize=8)
3218
+ if x_labels is not None and reindexed_col is None:
3219
+ max_ticks = 8
3220
+ tick_step = max(1, int(math.ceil(len(x_labels) / max_ticks)))
3221
+ tick_positions = x_coords[::tick_step]
3222
+ tick_labels = x_labels[::tick_step]
3223
+ ax.set_xticks(tick_positions)
3224
+ ax.set_xticklabels(tick_labels, fontsize=7, rotation=45, ha="right")
1560
3225
 
1561
3226
  # legend (only show in top-left plot to reduce clutter)
1562
3227
  if (r_idx == 0 and c_idx == 0) and plotted_any:
@@ -1580,3 +3245,124 @@ def plot_hmm_layers_rolling_by_sample_ref(
1580
3245
  plt.close(fig)
1581
3246
 
1582
3247
  return saved_files
3248
+
3249
+
3250
+ def _resolve_embedding(adata: "ad.AnnData", basis: str) -> np.ndarray:
3251
+ key = basis if basis.startswith("X_") else f"X_{basis}"
3252
+ if key not in adata.obsm:
3253
+ raise KeyError(f"Embedding '{key}' not found in adata.obsm.")
3254
+ embedding = np.asarray(adata.obsm[key])
3255
+ if embedding.shape[1] < 2:
3256
+ raise ValueError(f"Embedding '{key}' must have at least two dimensions.")
3257
+ return embedding[:, :2]
3258
+
3259
+
3260
def plot_embedding(
    adata: "ad.AnnData",
    *,
    basis: str,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot a 2D embedding with scanpy-style color options.

    Args:
        adata: AnnData object with ``obsm['X_<basis>']``.
        basis: Embedding basis name (e.g., ``'umap'``, ``'pca'``).
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots.
        prefix: Optional filename prefix (defaults to *basis*).
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Dict[str, Path]: Mapping of color keys to saved plot paths.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    embedding = _resolve_embedding(adata, basis)
    colors = [color] if isinstance(color, str) else list(color)
    saved: Dict[str, Path] = {}

    for color_key in colors:
        if color_key not in adata.obs:
            logger.warning("Color key '%s' not found in adata.obs; skipping.", color_key)
            continue
        values = adata.obs[color_key]
        fig, ax = plt.subplots(figsize=(5.5, 4.5))

        # `pd.api.types.is_categorical_dtype` is deprecated since pandas 2.1;
        # test the dtype instance directly instead.
        if isinstance(values.dtype, pd.CategoricalDtype) or values.dtype == object:
            categories = pd.Categorical(values)
            label_strings = categories.categories.astype(str)
            palette = sns.color_palette("tab20", n_colors=len(label_strings))
            color_map = dict(zip(label_strings, palette))
            # Map each category code to its palette color; code -1 marks
            # missing values, which are drawn in neutral grey.
            mapped = [
                palette[code] if code >= 0 else "#bdbdbd" for code in categories.codes
            ]
            ax.scatter(
                embedding[:, 0],
                embedding[:, 1],
                c=mapped,
                s=point_size,
                alpha=alpha,
                linewidths=0,
            )
            handles = [
                patches.Patch(color=color_map[label], label=str(label)) for label in label_strings
            ]
            ax.legend(handles=handles, loc="best", fontsize=8, frameon=False)
        else:
            scatter = ax.scatter(
                embedding[:, 0],
                embedding[:, 1],
                c=values.astype(float),
                cmap="viridis",
                s=point_size,
                alpha=alpha,
                linewidths=0,
            )
            fig.colorbar(scatter, ax=ax, label=color_key)

        ax.set_xlabel(f"{basis.upper()} 1")
        ax.set_ylabel(f"{basis.upper()} 2")
        ax.set_title(f"{basis.upper()} colored by {color_key}")
        fig.tight_layout()

        filename_prefix = prefix or basis
        safe_key = str(color_key).replace(" ", "_")
        output_file = output_path / f"{filename_prefix}_{safe_key}.png"
        fig.savefig(output_file, dpi=200)
        plt.close(fig)
        saved[color_key] = output_file

    return saved
3349
+
3350
+
3351
def plot_umap(
    adata: "ad.AnnData",
    *,
    color: str | Sequence[str],
    output_dir: Path | str,
) -> Dict[str, Path]:
    """Convenience wrapper: render the UMAP embedding colored by *color*."""
    return plot_embedding(
        adata,
        basis="umap",
        color=color,
        output_dir=output_dir,
        prefix="umap",
    )
3359
+
3360
+
3361
def plot_pca(
    adata: "ad.AnnData",
    *,
    color: str | Sequence[str],
    output_dir: Path | str,
) -> Dict[str, Path]:
    """Convenience wrapper: render the PCA embedding colored by *color*."""
    return plot_embedding(
        adata,
        basis="pca",
        color=color,
        output_dir=output_dir,
        prefix="pca",
    )