smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +49 -7
  4. smftools/cli/hmm_adata.py +250 -32
  5. smftools/cli/latent_adata.py +773 -0
  6. smftools/cli/load_adata.py +78 -74
  7. smftools/cli/preprocess_adata.py +122 -58
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +74 -112
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +52 -4
  12. smftools/config/conversion.yaml +1 -1
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +85 -12
  15. smftools/config/experiment_config.py +146 -1
  16. smftools/constants.py +69 -0
  17. smftools/hmm/HMM.py +88 -0
  18. smftools/hmm/call_hmm_peaks.py +1 -1
  19. smftools/informatics/__init__.py +6 -0
  20. smftools/informatics/bam_functions.py +358 -8
  21. smftools/informatics/binarize_converted_base_identities.py +2 -89
  22. smftools/informatics/converted_BAM_to_adata.py +636 -175
  23. smftools/informatics/h5ad_functions.py +198 -2
  24. smftools/informatics/modkit_extract_to_adata.py +1007 -425
  25. smftools/informatics/sequence_encoding.py +72 -0
  26. smftools/logging_utils.py +21 -2
  27. smftools/metadata.py +1 -1
  28. smftools/plotting/__init__.py +26 -3
  29. smftools/plotting/autocorrelation_plotting.py +22 -4
  30. smftools/plotting/chimeric_plotting.py +1893 -0
  31. smftools/plotting/classifiers.py +28 -14
  32. smftools/plotting/general_plotting.py +62 -1583
  33. smftools/plotting/hmm_plotting.py +1670 -8
  34. smftools/plotting/latent_plotting.py +804 -0
  35. smftools/plotting/plotting_utils.py +243 -0
  36. smftools/plotting/position_stats.py +16 -8
  37. smftools/plotting/preprocess_plotting.py +281 -0
  38. smftools/plotting/qc_plotting.py +8 -3
  39. smftools/plotting/spatial_plotting.py +1134 -0
  40. smftools/plotting/variant_plotting.py +1231 -0
  41. smftools/preprocessing/__init__.py +4 -0
  42. smftools/preprocessing/append_base_context.py +18 -18
  43. smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
  44. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  45. smftools/preprocessing/append_variant_call_layer.py +480 -0
  46. smftools/preprocessing/calculate_consensus.py +1 -1
  47. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  48. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  49. smftools/preprocessing/invert_adata.py +1 -0
  50. smftools/readwrite.py +159 -99
  51. smftools/schema/anndata_schema_v1.yaml +15 -1
  52. smftools/tools/__init__.py +10 -0
  53. smftools/tools/calculate_knn.py +121 -0
  54. smftools/tools/calculate_leiden.py +57 -0
  55. smftools/tools/calculate_nmf.py +130 -0
  56. smftools/tools/calculate_pca.py +180 -0
  57. smftools/tools/calculate_umap.py +79 -80
  58. smftools/tools/position_stats.py +4 -4
  59. smftools/tools/rolling_nn_distance.py +872 -0
  60. smftools/tools/sequence_alignment.py +140 -0
  61. smftools/tools/tensor_factorization.py +217 -0
  62. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
  63. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
  64. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  65. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  66. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,243 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from smftools.logging_utils import get_logger
9
+ from smftools.optional_imports import require
10
+
11
+ sns = require("seaborn", extra="plotting", purpose="plot styling")
12
+
13
+ logger = get_logger(__name__)
14
+
15
+ if TYPE_CHECKING:
16
+ import anndata as ad
17
+
18
+
19
+ def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
20
+ """
21
+ Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
22
+ Always includes 0 and n_positions-1 when possible.
23
+ """
24
+ n_ticks = int(max(2, n_ticks))
25
+ if n_positions <= n_ticks:
26
+ return np.arange(n_positions)
27
+
28
+ pos = np.linspace(0, n_positions - 1, n_ticks)
29
+ return np.unique(np.round(pos).astype(int))
30
+
31
+
32
+ def _select_labels(
33
+ subset: "ad.AnnData", sites: np.ndarray, reference: str, index_col_suffix: str | None
34
+ ) -> np.ndarray:
35
+ """
36
+ Select tick labels for the heatmap axis.
37
+
38
+ Parameters
39
+ ----------
40
+ subset : AnnData view
41
+ The per-bin subset of the AnnData.
42
+ sites : np.ndarray[int]
43
+ Indices of the subset.var positions to annotate.
44
+ reference : str
45
+ Reference name (e.g., '6B6_top').
46
+ index_col_suffix : None or str
47
+ If None → use subset.var_names
48
+ Else → use subset.var[f"{reference}_{index_col_suffix}"]
49
+
50
+ Returns
51
+ -------
52
+ np.ndarray[str]
53
+ The labels to use for tick positions.
54
+ """
55
+ if sites.size == 0:
56
+ return np.array([])
57
+
58
+ if index_col_suffix is None:
59
+ return subset.var_names[sites].astype(str)
60
+
61
+ colname = f"{reference}_{index_col_suffix}"
62
+
63
+ if colname not in subset.var:
64
+ raise KeyError(
65
+ f"index_col_suffix='{index_col_suffix}' requires var column '{colname}', "
66
+ f"but it is not present in adata.var."
67
+ )
68
+
69
+ labels = subset.var[colname].astype(str).values
70
+ return labels[sites]
71
+
72
+
73
def normalized_mean(matrix: np.ndarray, *, ignore_nan: bool = True) -> np.ndarray:
    """Compute min-max normalized column means for a matrix.

    Args:
        matrix: Input matrix (rows = observations, columns = positions).
        ignore_nan: If True (default), NaNs are excluded from the column
            means via ``np.nanmean``; if False, any NaN propagates into the
            corresponding mean.

    Returns:
        1D array of column means rescaled to roughly [0, 1]. A small epsilon
        in the denominator guards against division by zero when all means
        are equal.
    """
    mean = np.nanmean(matrix, axis=0) if ignore_nan else np.mean(matrix, axis=0)
    # Epsilon keeps the divide well-defined for constant-mean input.
    denom = (mean.max() - mean.min()) + 1e-9
    return (mean - mean.min()) / denom
85
+
86
+
87
+ def _layer_to_numpy(
88
+ subset: "ad.AnnData",
89
+ layer_name: str,
90
+ sites: np.ndarray | None = None,
91
+ *,
92
+ fill_nan_strategy: str = "value",
93
+ fill_nan_value: float = -1,
94
+ ) -> np.ndarray:
95
+ """Return a (copied) numpy array for a layer with optional NaN filling."""
96
+ if sites is not None:
97
+ layer_data = subset[:, sites].layers[layer_name]
98
+ else:
99
+ layer_data = subset.layers[layer_name]
100
+
101
+ if hasattr(layer_data, "toarray"):
102
+ arr = layer_data.toarray()
103
+ else:
104
+ arr = np.asarray(layer_data)
105
+
106
+ arr = np.array(arr, copy=True)
107
+
108
+ if fill_nan_strategy == "none":
109
+ return arr
110
+
111
+ if fill_nan_strategy not in {"value", "col_mean"}:
112
+ raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
113
+
114
+ arr = arr.astype(float, copy=False)
115
+
116
+ if fill_nan_strategy == "value":
117
+ return np.where(np.isnan(arr), fill_nan_value, arr)
118
+
119
+ col_mean = np.nanmean(arr, axis=0)
120
+ if np.any(np.isnan(col_mean)):
121
+ col_mean = np.where(np.isnan(col_mean), fill_nan_value, col_mean)
122
+ return np.where(np.isnan(arr), col_mean, arr)
123
+
124
+
125
+ def _infer_zero_is_valid(layer_name: str | None, matrix: np.ndarray) -> bool:
126
+ """Infer whether zeros should count as valid (unmethylated) values."""
127
+ if layer_name and "nan0_0minus1" in layer_name:
128
+ return False
129
+ if np.isnan(matrix).any():
130
+ return True
131
+ if np.any(matrix < 0):
132
+ return False
133
+ return True
134
+
135
+
136
def methylation_fraction(
    matrix: np.ndarray, *, ignore_nan: bool = True, zero_is_valid: bool = False
) -> np.ndarray:
    """
    Per-column fraction of methylated calls.

    Methylated entries are exactly 1. Valid entries are finite, and zeros
    only count as valid when ``zero_is_valid`` is True. With
    ``ignore_nan=False`` NaNs are first recoded as 0. Columns with no valid
    entries yield 0.0.
    """
    data = np.asarray(matrix)
    if not ignore_nan:
        data = np.where(np.isnan(data), 0, data)

    finite = np.isfinite(data)
    if zero_is_valid:
        valid_counts = finite.sum(axis=0)
    else:
        valid_counts = (finite & (data != 0)).sum(axis=0)
    methylated_counts = ((data == 1) & finite).sum(axis=0)

    # where= guards empty columns; out= supplies their 0.0 default.
    fractions = np.zeros_like(methylated_counts, dtype=float)
    np.divide(methylated_counts, valid_counts, out=fractions, where=valid_counts != 0)
    return fractions
157
+
158
+
159
def _methylation_fraction_for_layer(
    matrix: np.ndarray,
    layer_name: str | None,
    *,
    ignore_nan: bool = True,
    zero_is_valid: bool | None = None,
) -> np.ndarray:
    """Layer-aware wrapper around :func:`methylation_fraction`.

    When ``zero_is_valid`` is None, it is inferred from the layer name and
    the matrix contents via :func:`_infer_zero_is_valid`.
    """
    data = np.asarray(matrix)
    resolved_zero_is_valid = (
        _infer_zero_is_valid(layer_name, data) if zero_is_valid is None else zero_is_valid
    )
    return methylation_fraction(
        data, ignore_nan=ignore_nan, zero_is_valid=resolved_zero_is_valid
    )
171
+
172
+
173
def clean_barplot(
    ax,
    mean_values,
    title,
    *,
    y_max: float | None = 1.0,
    y_label: str = "Mean",
    y_ticks: list[float] | None = None,
):
    """Draw a flush gray barplot on ``ax`` with a minimalist frame.

    Args:
        ax: Matplotlib axes to draw on.
        mean_values: Values to plot, one bar per entry.
        title: Plot title.
        y_max: Y-axis upper bound before 5% headroom; None infers it from
            the data (defaulting to 1.0 for empty/non-finite input).
        y_label: Y-axis label.
        y_ticks: Explicit y-axis ticks; defaults to [0, 0.5, 1] when the
            axis keeps the canonical y_max of 1.0.
    """
    logger.debug("Formatting barplot '%s' with %s values.", title, len(mean_values))
    n_bars = len(mean_values)
    ax.bar(np.arange(n_bars), mean_values, color="gray", width=1.0, align="edge")
    ax.set_xlim(0, n_bars)

    # Canonical 0..1 axis gets standard ticks unless caller overrides.
    if y_ticks is None and y_max == 1.0:
        y_ticks = [0.0, 0.5, 1.0]

    upper = y_max
    if upper is None:
        upper = np.nanmax(mean_values) if n_bars else 1.0
    if not np.isfinite(upper) or upper <= 0:
        upper = 1.0
    # 5% headroom so the tallest bar doesn't touch the frame.
    ax.set_ylim(0, upper * 1.05)

    if y_ticks is not None:
        ax.set_yticks(y_ticks)
    ax.set_ylabel(y_label)
    ax.set_title(title, fontsize=12, pad=2)

    # Hide every spine except the left one.
    for spine_name, spine in ax.spines.items():
        spine.set_visible(spine_name == "left")

    ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
213
+
214
+
215
def make_row_colors(meta: pd.DataFrame) -> pd.DataFrame:
    """
    Convert metadata columns to RGB colors without invoking pandas Categorical.map
    (MultiIndex-safe, category-safe).

    Args:
        meta: Metadata frame; one color column is produced per input column.

    Returns:
        DataFrame with the same index as ``meta`` whose cells are RGB tuples;
        each column gets its own palette keyed by the column's unique labels.
    """

    # Hoisted out of the column loop: the label normalizer is loop-invariant.
    def _to_label(x: Any) -> str:
        if x is None:
            return "NA"
        if isinstance(x, float) and np.isnan(x):
            return "NA"
        if isinstance(x, pd.MultiIndex):
            return "MultiIndex"
        if isinstance(x, tuple):
            return "|".join(map(str, x))
        return str(x)

    row_colors = pd.DataFrame(index=meta.index)

    for col in meta.columns:
        s = meta[col].astype("object")

        labels = np.array([_to_label(x) for x in s.to_numpy()], dtype=object)
        uniq = pd.unique(labels)
        # Fresh palette per column so colors are stable within each column.
        palette = dict(zip(uniq, sns.color_palette(n_colors=len(uniq))))

        # Gray fallback should never trigger (palette covers all labels).
        row_colors[col] = [palette.get(lbl, (0.7, 0.7, 0.7)) for lbl in labels]

    return row_colors
@@ -1,7 +1,10 @@
1
1
  from __future__ import annotations
2
2
 
3
+ from smftools.logging_utils import get_logger
3
4
  from smftools.optional_imports import require
4
5
 
6
+ logger = get_logger(__name__)
7
+
5
8
 
6
9
  def plot_volcano_relative_risk(
7
10
  results_dict,
@@ -29,10 +32,11 @@ def plot_volcano_relative_risk(
29
32
 
30
33
  plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
31
34
 
35
+ logger.info("Plotting volcano relative risk plots.")
32
36
  for ref, group_results in results_dict.items():
33
37
  for group_label, (results_df, _) in group_results.items():
34
38
  if results_df.empty:
35
- print(f"Skipping empty results for {ref} / {group_label}")
39
+ logger.warning("Skipping empty results for %s / %s.", ref, group_label)
36
40
  continue
37
41
 
38
42
  # Split by site type
@@ -100,7 +104,7 @@ def plot_volcano_relative_risk(
100
104
  )
101
105
  out_file = os.path.join(save_path, f"{safe_name}.png")
102
106
  plt.savefig(out_file, dpi=300)
103
- print(f"Saved: {out_file}")
107
+ logger.info("Saved volcano relative risk plot to %s.", out_file)
104
108
 
105
109
  plt.show()
106
110
 
@@ -131,10 +135,11 @@ def plot_bar_relative_risk(
131
135
 
132
136
  plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
133
137
 
138
+ logger.info("Plotting bar relative risk plots.")
134
139
  for ref, group_data in results_dict.items():
135
140
  for group_label, (df, _) in group_data.items():
136
141
  if df.empty:
137
- print(f"Skipping empty result for {ref} / {group_label}")
142
+ logger.warning("Skipping empty result for %s / %s.", ref, group_label)
138
143
  continue
139
144
 
140
145
  df = df.copy()
@@ -206,7 +211,7 @@ def plot_bar_relative_risk(
206
211
  )
207
212
  out_file = os.path.join(save_path, f"{safe_name}.png")
208
213
  plt.savefig(out_file, dpi=300)
209
- print(f"📁 Saved: {out_file}")
214
+ logger.info("Saved bar relative risk plot to %s.", out_file)
210
215
 
211
216
  plt.show()
212
217
 
@@ -240,6 +245,8 @@ def plot_positionwise_matrix(
240
245
  plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
241
246
  sns = require("seaborn", extra="plotting", purpose="position stats plots")
242
247
 
248
+ logger.info("Plotting positionwise matrices for key '%s'.", key)
249
+
243
250
  def find_closest_index(index, target):
244
251
  """Find the index value closest to a target value."""
245
252
  index_vals = pd.to_numeric(index, errors="coerce")
@@ -357,7 +364,7 @@ def plot_positionwise_matrix(
357
364
  va="center",
358
365
  fontsize=10,
359
366
  )
360
- print(f"Error plotting line for {highlight_axis}={pos}: {e}")
367
+ logger.warning("Error plotting line for %s=%s: %s", highlight_axis, pos, e)
361
368
 
362
369
  line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
363
370
  line_ax.set_xlabel(f"{'Column' if highlight_axis == 'row' else 'Row'} position")
@@ -373,7 +380,7 @@ def plot_positionwise_matrix(
373
380
  safe_name = group.replace("=", "").replace("__", "_").replace(",", "_")
374
381
  out_file = os.path.join(save_path, f"{key}_{safe_name}.png")
375
382
  plt.savefig(out_file, dpi=300)
376
- print(f"📁 Saved: {out_file}")
383
+ logger.info("Saved positionwise matrix plot to %s.", out_file)
377
384
 
378
385
  plt.show()
379
386
 
@@ -423,6 +430,7 @@ def plot_positionwise_matrix_grid(
423
430
  grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="position stats plots")
424
431
  GridSpec = grid_spec.GridSpec
425
432
 
433
+ logger.info("Plotting positionwise matrix grid for key '%s'.", key)
426
434
  matrices = adata.uns[key]
427
435
  group_labels = list(matrices.keys())
428
436
 
@@ -515,7 +523,7 @@ def plot_positionwise_matrix_grid(
515
523
  os.makedirs(save_path, exist_ok=True)
516
524
  fname = outer_label.replace("_", "").replace("=", "") + ".png"
517
525
  plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches="tight")
518
- print(f"Saved {fname}")
526
+ logger.info("Saved positionwise matrix grid plot to %s.", fname)
519
527
 
520
528
  plt.close(fig)
521
529
 
@@ -527,4 +535,4 @@ def plot_positionwise_matrix_grid(
527
535
  for outer_label in parsed["outer"].unique():
528
536
  plot_one_grid(outer_label)
529
537
 
530
- print("Finished plotting all grids.")
538
+ logger.info("Finished plotting all grids.")
@@ -0,0 +1,281 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any, Dict, List, Sequence
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import scipy.cluster.hierarchy as sch
9
+
10
+ from smftools.logging_utils import get_logger
11
+ from smftools.optional_imports import require
12
+
13
+ colors = require("matplotlib.colors", extra="plotting", purpose="plot rendering")
14
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
15
+ sns = require("seaborn", extra="plotting", purpose="plot styling")
16
+
17
+ logger = get_logger(__name__)
18
+
19
+
20
def plot_read_span_quality_clustermaps(
    adata,
    sample_col: str = "Sample_Names",
    reference_col: str = "Reference_strand",
    quality_layer: str = "base_quality_scores",
    read_span_layer: str = "read_span_mask",
    quality_cmap: str = "viridis",
    read_span_color: str = "#2ca25f",
    max_nan_fraction: float | None = None,
    min_quality: float | None = None,
    min_length: int | None = None,
    min_mapped_length_to_reference_length_ratio: float | None = None,
    demux_types: Sequence[str] = ("single", "double", "already"),
    max_reads: int | None = None,
    xtick_step: int | None = None,
    xtick_rotation: int = 90,
    xtick_fontsize: int = 9,
    show_position_axis: bool = False,
    position_axis_tick_target: int = 25,
    save_path: str | Path | None = None,
) -> List[Dict[str, Any]]:
    """Plot read-span mask and base quality clustermaps side by side.

    Clustering is performed using the base-quality layer ordering, which is then
    applied to the read-span mask to keep the two panels aligned.

    Args:
        adata: AnnData with read-span and base-quality layers.
        sample_col: Column in ``adata.obs`` that identifies samples.
        reference_col: Column in ``adata.obs`` that identifies references.
        quality_layer: Layer name containing base-quality scores.
        read_span_layer: Layer name containing read-span masks.
        quality_cmap: Colormap for base-quality scores.
        read_span_color: Color for read-span mask (1-values); 0-values are white.
        max_nan_fraction: Optional maximum fraction of NaNs allowed per position; positions
            above this threshold are excluded.
        min_quality: Optional minimum read quality filter.
        min_length: Optional minimum mapped length filter.
        min_mapped_length_to_reference_length_ratio: Optional min length ratio filter.
        demux_types: Allowed ``demux_type`` values, if present in ``adata.obs``.
        max_reads: Optional maximum number of reads to plot per sample/reference.
        xtick_step: Spacing between x-axis tick labels (None = no labels).
        xtick_rotation: Rotation for x-axis tick labels.
        xtick_fontsize: Font size for x-axis tick labels.
        show_position_axis: Whether to draw a position axis with tick labels.
        position_axis_tick_target: Approximate number of ticks to show when auto-sizing.
        save_path: Optional output directory for saving plots.

    Returns:
        List of dictionaries with per-plot metadata and output paths.
    """
    logger.info("Plotting read span and quality clustermaps.")

    def _mask_or_true(series_name: str, predicate):
        # Best-effort filter: missing column or failing predicate → all-True mask.
        if series_name not in adata.obs:
            return pd.Series(True, index=adata.obs.index)
        s = adata.obs[series_name]
        try:
            return predicate(s)
        except Exception:
            return pd.Series(True, index=s.index)

    def _resolve_xtick_step(n_positions: int) -> int | None:
        # Explicit step wins; otherwise auto-size only when the axis is shown.
        if xtick_step is not None:
            return xtick_step
        if not show_position_axis:
            return None
        return max(1, int(np.ceil(n_positions / position_axis_tick_target)))

    def _fill_nan_with_col_means(matrix: np.ndarray) -> np.ndarray:
        # Ward linkage cannot handle NaNs; impute with column means (0 for all-NaN cols).
        filled = matrix.copy()
        col_means = np.nanmean(filled, axis=0)
        col_means = np.where(np.isnan(col_means), 0.0, col_means)
        nan_rows, nan_cols = np.where(np.isnan(filled))
        filled[nan_rows, nan_cols] = col_means[nan_cols]
        return filled

    if quality_layer not in adata.layers:
        raise KeyError(f"Layer '{quality_layer}' not found in adata.layers")
    if read_span_layer not in adata.layers:
        raise KeyError(f"Layer '{read_span_layer}' not found in adata.layers")
    if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
        raise ValueError("max_nan_fraction must be between 0 and 1.")
    if position_axis_tick_target < 1:
        raise ValueError("position_axis_tick_target must be at least 1.")

    results: List[Dict[str, Any]] = []
    save_path = Path(save_path) if save_path is not None else None
    if save_path is not None:
        save_path.mkdir(parents=True, exist_ok=True)

    for col in (sample_col, reference_col):
        if col not in adata.obs:
            raise KeyError(f"{col} not in adata.obs")
        if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
            adata.obs[col] = adata.obs[col].astype("category")

    # These QC masks depend only on the function arguments, not on the
    # reference/sample pair, so compute them once instead of per iteration.
    qmask = _mask_or_true(
        "read_quality",
        (lambda s: s >= float(min_quality))
        if (min_quality is not None)
        else (lambda s: pd.Series(True, index=s.index)),
    )
    lm_mask = _mask_or_true(
        "mapped_length",
        (lambda s: s >= float(min_length))
        if (min_length is not None)
        else (lambda s: pd.Series(True, index=s.index)),
    )
    lrr_mask = _mask_or_true(
        "mapped_length_to_reference_length_ratio",
        (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
        if (min_mapped_length_to_reference_length_ratio is not None)
        else (lambda s: pd.Series(True, index=s.index)),
    )
    demux_mask = _mask_or_true(
        "demux_type",
        (lambda s: s.astype("string").isin(list(demux_types)))
        if (demux_types is not None)
        else (lambda s: pd.Series(True, index=s.index)),
    )

    for ref in adata.obs[reference_col].cat.categories:
        for sample in adata.obs[sample_col].cat.categories:
            row_mask = (
                (adata.obs[reference_col] == ref)
                & (adata.obs[sample_col] == sample)
                & qmask
                & lm_mask
                & lrr_mask
                & demux_mask
            )
            if not bool(row_mask.any()):
                continue

            subset = adata[row_mask, :].copy()
            quality_matrix = np.asarray(subset.layers[quality_layer]).astype(float)
            # Negative quality scores are sentinels for missing calls.
            quality_matrix[quality_matrix < 0] = np.nan
            read_span_matrix = np.asarray(subset.layers[read_span_layer]).astype(float)

            if max_nan_fraction is not None:
                nan_mask = np.isnan(quality_matrix) | np.isnan(read_span_matrix)
                nan_fraction = nan_mask.mean(axis=0)
                keep_columns = nan_fraction <= max_nan_fraction
                if not np.any(keep_columns):
                    continue
                quality_matrix = quality_matrix[:, keep_columns]
                read_span_matrix = read_span_matrix[:, keep_columns]
                subset = subset[:, keep_columns].copy()

            if max_reads is not None and quality_matrix.shape[0] > max_reads:
                quality_matrix = quality_matrix[:max_reads]
                read_span_matrix = read_span_matrix[:max_reads]
                subset = subset[:max_reads, :].copy()

            if quality_matrix.size == 0:
                continue

            # Cluster on quality only; reuse the leaf order for both panels.
            quality_filled = _fill_nan_with_col_means(quality_matrix)
            linkage = sch.linkage(quality_filled, method="ward")
            order = sch.leaves_list(linkage)

            quality_matrix = quality_matrix[order]
            read_span_matrix = read_span_matrix[order]

            fig, axes = plt.subplots(
                nrows=2,
                ncols=3,
                figsize=(18, 6),
                sharex="col",
                gridspec_kw={"height_ratios": [1, 4], "width_ratios": [1, 1, 0.05]},
            )
            span_bar_ax, quality_bar_ax, bar_spacer_ax = axes[0]
            span_ax, quality_ax, cbar_ax = axes[1]
            bar_spacer_ax.set_axis_off()

            span_mean = np.nanmean(read_span_matrix, axis=0)
            quality_mean = np.nanmean(quality_matrix, axis=0)
            # +0.5 centers each bar over its heatmap cell.
            bar_positions = np.arange(read_span_matrix.shape[1]) + 0.5
            span_bar_ax.bar(
                bar_positions,
                span_mean,
                color=read_span_color,
                width=1.0,
            )
            span_bar_ax.set_title(f"{read_span_layer} mean")
            span_bar_ax.set_xlim(0, read_span_matrix.shape[1])
            span_bar_ax.tick_params(axis="x", labelbottom=False)

            quality_bar_ax.bar(
                bar_positions,
                quality_mean,
                color="#4c72b0",
                width=1.0,
            )
            quality_bar_ax.set_title(f"{quality_layer} mean")
            quality_bar_ax.set_xlim(0, quality_matrix.shape[1])
            quality_bar_ax.tick_params(axis="x", labelbottom=False)

            # Binary colormap: 0 → white, 1 → read_span_color.
            span_cmap = colors.ListedColormap(["white", read_span_color])
            span_norm = colors.BoundaryNorm([-0.5, 0.5, 1.5], span_cmap.N)
            sns.heatmap(
                read_span_matrix,
                cmap=span_cmap,
                norm=span_norm,
                ax=span_ax,
                yticklabels=False,
                cbar=False,
            )
            span_ax.set_title(read_span_layer)

            sns.heatmap(
                quality_matrix,
                cmap=quality_cmap,
                ax=quality_ax,
                yticklabels=False,
                cbar=True,
                cbar_ax=cbar_ax,
            )
            quality_ax.set_title(quality_layer)

            resolved_step = _resolve_xtick_step(quality_matrix.shape[1])
            for axis in (span_ax, quality_ax):
                if resolved_step is not None and resolved_step > 0:
                    sites = np.arange(0, quality_matrix.shape[1], resolved_step)
                    axis.set_xticks(sites)
                    axis.set_xticklabels(
                        subset.var_names[sites].astype(str),
                        rotation=xtick_rotation,
                        fontsize=xtick_fontsize,
                    )
                else:
                    axis.set_xticks([])
                if show_position_axis or xtick_step is not None:
                    axis.set_xlabel("Position")

            n_reads = quality_matrix.shape[0]
            fig.suptitle(f"{sample} - {ref} - {n_reads} reads")
            fig.tight_layout(rect=(0, 0, 1, 0.95))

            out_file = None
            if save_path is not None:
                safe_name = f"{ref}__{sample}__read_span_quality".replace("=", "").replace(",", "_")
                out_file = save_path / f"{safe_name}.png"
                fig.savefig(out_file, dpi=300, bbox_inches="tight")
                plt.close(fig)
                logger.info("Saved read span/quality clustermap to %s.", out_file)
            else:
                plt.show()

            results.append(
                {
                    "reference": str(ref),
                    "sample": str(sample),
                    "quality_layer": quality_layer,
                    "read_span_layer": read_span_layer,
                    "n_positions": int(quality_matrix.shape[1]),
                    "output_path": str(out_file) if out_file is not None else None,
                }
            )

    return results
@@ -5,10 +5,13 @@ import os
5
5
  import numpy as np
6
6
  import pandas as pd
7
7
 
8
+ from smftools.logging_utils import get_logger
8
9
  from smftools.optional_imports import require
9
10
 
10
11
  plt = require("matplotlib.pyplot", extra="plotting", purpose="QC plots")
11
12
 
13
+ logger = get_logger(__name__)
14
+
12
15
 
13
16
  def plot_read_qc_histograms(
14
17
  adata,
@@ -53,6 +56,7 @@ def plot_read_qc_histograms(
53
56
  dpi : int
54
57
  Figure resolution.
55
58
  """
59
+ logger.info("Plotting read QC histograms to %s.", outdir)
56
60
  os.makedirs(outdir, exist_ok=True)
57
61
 
58
62
  if sample_key not in adata.obs.columns:
@@ -60,7 +64,7 @@ def plot_read_qc_histograms(
60
64
 
61
65
  # Ensure sample_key is categorical for stable ordering
62
66
  samples = adata.obs[sample_key]
63
- if not pd.api.types.is_categorical_dtype(samples):
67
+ if not isinstance(samples.dtype, pd.CategoricalDtype):
64
68
  samples = samples.astype("category")
65
69
  sample_levels = list(samples.cat.categories)
66
70
 
@@ -69,14 +73,14 @@ def plot_read_qc_histograms(
69
73
  is_numeric = {}
70
74
  for key in obs_keys:
71
75
  if key not in adata.obs.columns:
72
- print(f"[WARN] '{key}' not found in obs; skipping.")
76
+ logger.warning("'%s' not found in obs; skipping.", key)
73
77
  continue
74
78
  s = adata.obs[key]
75
79
  num = pd.api.types.is_numeric_dtype(s)
76
80
  valid_keys.append(key)
77
81
  is_numeric[key] = num
78
82
  if not valid_keys:
79
- print("[plot_read_qc_grid] No valid obs_keys to plot.")
83
+ logger.warning("No valid obs_keys to plot.")
80
84
  return
81
85
 
82
86
  # Precompute global numeric ranges (after clipping) so rows share x-axis per column
@@ -174,6 +178,7 @@ def plot_read_qc_histograms(
174
178
  page = start // rows_per_fig + 1
175
179
  out_png = os.path.join(outdir, f"qc_grid_{_sanitize(sample_key)}_page{page}.png")
176
180
  plt.savefig(out_png, bbox_inches="tight")
181
+ logger.info("Saved QC histogram page to %s.", out_png)
177
182
  plt.close(fig)
178
183
 
179
184