smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +49 -7
  4. smftools/cli/hmm_adata.py +250 -32
  5. smftools/cli/latent_adata.py +773 -0
  6. smftools/cli/load_adata.py +78 -74
  7. smftools/cli/preprocess_adata.py +122 -58
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +74 -112
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +52 -4
  12. smftools/config/conversion.yaml +1 -1
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +85 -12
  15. smftools/config/experiment_config.py +146 -1
  16. smftools/constants.py +69 -0
  17. smftools/hmm/HMM.py +88 -0
  18. smftools/hmm/call_hmm_peaks.py +1 -1
  19. smftools/informatics/__init__.py +6 -0
  20. smftools/informatics/bam_functions.py +358 -8
  21. smftools/informatics/binarize_converted_base_identities.py +2 -89
  22. smftools/informatics/converted_BAM_to_adata.py +636 -175
  23. smftools/informatics/h5ad_functions.py +198 -2
  24. smftools/informatics/modkit_extract_to_adata.py +1007 -425
  25. smftools/informatics/sequence_encoding.py +72 -0
  26. smftools/logging_utils.py +21 -2
  27. smftools/metadata.py +1 -1
  28. smftools/plotting/__init__.py +26 -3
  29. smftools/plotting/autocorrelation_plotting.py +22 -4
  30. smftools/plotting/chimeric_plotting.py +1893 -0
  31. smftools/plotting/classifiers.py +28 -14
  32. smftools/plotting/general_plotting.py +62 -1583
  33. smftools/plotting/hmm_plotting.py +1670 -8
  34. smftools/plotting/latent_plotting.py +804 -0
  35. smftools/plotting/plotting_utils.py +243 -0
  36. smftools/plotting/position_stats.py +16 -8
  37. smftools/plotting/preprocess_plotting.py +281 -0
  38. smftools/plotting/qc_plotting.py +8 -3
  39. smftools/plotting/spatial_plotting.py +1134 -0
  40. smftools/plotting/variant_plotting.py +1231 -0
  41. smftools/preprocessing/__init__.py +4 -0
  42. smftools/preprocessing/append_base_context.py +18 -18
  43. smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
  44. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  45. smftools/preprocessing/append_variant_call_layer.py +480 -0
  46. smftools/preprocessing/calculate_consensus.py +1 -1
  47. smftools/preprocessing/calculate_read_modification_stats.py +6 -1
  48. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  49. smftools/preprocessing/invert_adata.py +1 -0
  50. smftools/readwrite.py +159 -99
  51. smftools/schema/anndata_schema_v1.yaml +15 -1
  52. smftools/tools/__init__.py +10 -0
  53. smftools/tools/calculate_knn.py +121 -0
  54. smftools/tools/calculate_leiden.py +57 -0
  55. smftools/tools/calculate_nmf.py +130 -0
  56. smftools/tools/calculate_pca.py +180 -0
  57. smftools/tools/calculate_umap.py +79 -80
  58. smftools/tools/position_stats.py +4 -4
  59. smftools/tools/rolling_nn_distance.py +872 -0
  60. smftools/tools/sequence_alignment.py +140 -0
  61. smftools/tools/tensor_factorization.py +217 -0
  62. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
  63. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
  64. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  65. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  66. {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,804 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Dict, Mapping, Sequence
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+
10
+ from smftools.logging_utils import get_logger
11
+ from smftools.optional_imports import require
12
+ from smftools.plotting.plotting_utils import _fixed_tick_positions
13
+
14
+ patches = require("matplotlib.patches", extra="plotting", purpose="plot rendering")
15
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
16
+
17
+ sns = require("seaborn", extra="plotting", purpose="plot styling")
18
+
19
+ grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
20
+
21
+ logger = get_logger(__name__)
22
+
23
+ if TYPE_CHECKING:
24
+ import anndata as ad
25
+
26
+
27
def plot_nmf_components(
    adata: "ad.AnnData",
    *,
    output_dir: Path | str,
    components_key: str = "H_nmf",
    suffix: str | None = None,
    heatmap_name: str = "heatmap.png",
    lineplot_name: str = "lineplot.png",
    max_features: int = 2000,
) -> Dict[str, Path]:
    """Render NMF component weights as a heatmap and a per-component scatter plot.

    Args:
        adata: AnnData object whose ``varm`` holds the NMF component matrix.
        output_dir: Directory the figures are written to (created if missing).
        components_key: Base key in ``adata.varm`` for the component matrix.
        suffix: Optional suffix; when given, ``_<suffix>`` is appended to
            ``components_key`` before the lookup.
        heatmap_name: Heatmap filename; the resolved key is prepended to it.
        lineplot_name: Scatter-plot filename; the resolved key is prepended to it.
        max_features: Cap on the number of plotted features. Rows with the
            largest per-row maximum weight are kept; a falsy value disables the cap.

    Returns:
        Dict[str, Path]: Paths under the keys ``heatmap`` and ``lineplot``, or an
        empty dict when the key is missing or every weight is zero.

    Raises:
        ValueError: If the stored component matrix is not two-dimensional.
    """
    logger.info("Plotting NMF components to %s.", output_dir)
    key = f"{components_key}_{suffix}" if suffix else components_key
    heatmap_file = f"{key}_{heatmap_name}"
    scatter_file = f"{key}_{lineplot_name}"

    if key not in adata.varm:
        logger.warning("NMF components key '%s' not found in adata.varm.", key)
        return {}

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    weights = np.asarray(adata.varm[key])
    if weights.ndim != 2:
        raise ValueError(f"NMF components must be 2D; got shape {weights.shape}.")

    # Features whose weight is zero in every component carry no information.
    keep = np.any(weights != 0, axis=1)
    if not np.any(keep):
        logger.warning("NMF components are all zeros; skipping plot generation.")
        return {}

    positions = np.arange(weights.shape[0])[keep]
    weights = weights[keep]

    if max_features and weights.shape[0] > max_features:
        # Keep the highest-scoring rows, then restore positional order.
        row_scores = np.nanmax(weights, axis=1)
        chosen = np.sort(np.argsort(row_scores)[-max_features:])
        weights = weights[chosen]
        positions = positions[chosen]
        logger.info(
            "Downsampled NMF features from %s to %s for plotting.",
            keep.sum(),
            weights.shape[0],
        )

    n_features, n_components = weights.shape
    position_labels = positions.astype(str)
    component_labels = [f"C{number}" for number in range(1, n_components + 1)]
    few_features = n_features <= 60

    # --- Heatmap: components on the y axis, positions on the x axis. ---
    fig_size = (max(8, min(20, n_features / 60)), max(2.5, 0.6 * n_components + 1.5))
    fig, ax = plt.subplots(figsize=fig_size)
    sns.heatmap(
        weights.T,
        ax=ax,
        cmap="viridis",
        cbar_kws={"label": "Component weight"},
        xticklabels=position_labels if few_features else False,
        yticklabels=component_labels,
    )
    ax.set_xlabel("Position index")
    ax.set_ylabel("NMF component")
    if not few_features:
        # Too many columns to label individually; place a fixed number of ticks.
        ticks = _fixed_tick_positions(n_features, min(20, n_features))
        ax.set_xticks(ticks + 0.5)
        ax.set_xticklabels(positions[ticks].astype(str), rotation=90, fontsize=8)
    fig.tight_layout()
    heatmap_path = target_dir / heatmap_file
    fig.savefig(heatmap_path, dpi=200)
    plt.close(fig)
    logger.info("Saved NMF heatmap to %s.", heatmap_path)

    # --- Scatter plot: one series per component against position index. ---
    fig, ax = plt.subplots(figsize=(max(8, min(20, n_features / 50)), 3.5))
    for series, label in zip(weights.T, component_labels):
        ax.scatter(positions, series, label=label, s=14, alpha=0.75)
    ax.set_xlabel("Position index")
    ax.set_ylabel("Component weight")
    if few_features:
        ax.set_xticks(positions)
        ax.set_xticklabels(position_labels, rotation=90, fontsize=8)
    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), frameon=False)
    # Reserve the right-hand strip for the legend placed outside the axes.
    fig.tight_layout(rect=[0, 0, 0.82, 1])
    lineplot_path = target_dir / scatter_file
    fig.savefig(lineplot_path, dpi=200)
    plt.close(fig)
    logger.info("Saved NMF line plot to %s.", lineplot_path)

    return {"heatmap": heatmap_path, "lineplot": lineplot_path}
135
+
136
+
137
def plot_pca_components(
    adata: "ad.AnnData",
    *,
    output_dir: Path | str,
    components_key: str = "PCs",
    suffix: str | None = None,
    heatmap_name: str = "heatmap.png",
    lineplot_name: str = "lineplot.png",
    max_features: int = 2000,
) -> Dict[str, Path]:
    """Plot PCA component loadings as a heatmap and per-component scatter plot.

    Args:
        adata: AnnData object containing PCA results.
        output_dir: Directory to write plots into (created if missing).
        components_key: Key in ``adata.varm`` storing the components.
        suffix: Optional suffix; when given, ``_<suffix>`` is appended to
            ``components_key`` before the lookup.
        heatmap_name: Filename for the heatmap plot; the resolved key is prepended.
        lineplot_name: Filename for the scatter plot; the resolved key is prepended.
        max_features: Maximum number of features to plot. Features are ranked by
            their largest absolute loading across components; a falsy value
            disables the cap.

    Returns:
        Dict[str, Path]: Paths to created plots (keys: ``heatmap`` and ``lineplot``),
        or an empty dict when the key is missing or all loadings are zero.

    Raises:
        ValueError: If the stored component matrix is not two-dimensional.
    """
    logger.info("Plotting PCA components to %s.", output_dir)
    if suffix:
        components_key = f"{components_key}_{suffix}"

    heatmap_name = f"{components_key}_{heatmap_name}"
    lineplot_name = f"{components_key}_{lineplot_name}"

    if components_key not in adata.varm:
        logger.warning("PCA components key '%s' not found in adata.varm.", components_key)
        return {}

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    components = np.asarray(adata.varm[components_key])
    if components.ndim != 2:
        raise ValueError(f"PCA components must be 2D; got shape {components.shape}.")

    all_positions = np.arange(components.shape[0])

    # Drop features whose loading is zero in every component.
    nonzero_mask = np.any(components != 0, axis=1)
    if not np.any(nonzero_mask):
        logger.warning("PCA components are all zeros; skipping plot generation.")
        return {}

    components = components[nonzero_mask]
    feature_positions = all_positions[nonzero_mask]

    if max_features and components.shape[0] > max_features:
        # PCA loadings are signed, so rank by magnitude: a strongly negative
        # loading is as informative as a strongly positive one.
        scores = np.nanmax(np.abs(components), axis=1)
        top_idx = np.sort(np.argsort(scores)[-max_features:])
        components = components[top_idx]
        feature_positions = feature_positions[top_idx]
        logger.info(
            "Downsampled PCA features from %s to %s for plotting.",
            nonzero_mask.sum(),
            components.shape[0],
        )

    n_features, n_components = components.shape
    feature_labels = feature_positions.astype(str)
    component_labels = [f"PC{i + 1}" for i in range(n_components)]

    heatmap_width = max(8, min(20, n_features / 60))
    heatmap_height = max(2.5, 0.6 * n_components + 1.5)
    fig, ax = plt.subplots(figsize=(heatmap_width, heatmap_height))
    sns.heatmap(
        components.T,
        ax=ax,
        cmap="coolwarm",
        cbar_kws={"label": "Component loading"},
        xticklabels=feature_labels if n_features <= 60 else False,
        yticklabels=component_labels,
    )
    ax.set_xlabel("Position index")
    ax.set_ylabel("PCA component")
    if n_features > 60:
        # Too many columns to label individually; place a fixed number of ticks.
        tick_positions = _fixed_tick_positions(n_features, min(20, n_features))
        ax.set_xticks(tick_positions + 0.5)
        ax.set_xticklabels(feature_positions[tick_positions].astype(str), rotation=90, fontsize=8)
    fig.tight_layout()
    heatmap_path = output_path / heatmap_name
    fig.savefig(heatmap_path, dpi=200)
    plt.close(fig)
    logger.info("Saved PCA heatmap to %s.", heatmap_path)

    fig, ax = plt.subplots(figsize=(max(8, min(20, n_features / 50)), 3.5))
    x = feature_positions
    for idx, label in enumerate(component_labels):
        ax.scatter(x, components[:, idx], label=label, s=14, alpha=0.75)
    ax.set_xlabel("Position index")
    ax.set_ylabel("Component loading")
    if n_features <= 60:
        ax.set_xticks(x)
        ax.set_xticklabels(feature_labels, rotation=90, fontsize=8)
    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), frameon=False)
    # Reserve the right-hand strip for the legend placed outside the axes.
    fig.tight_layout(rect=[0, 0, 0.82, 1])
    lineplot_path = output_path / lineplot_name
    fig.savefig(lineplot_path, dpi=200)
    plt.close(fig)
    logger.info("Saved PCA line plot to %s.", lineplot_path)

    return {"heatmap": heatmap_path, "lineplot": lineplot_path}
245
+
246
+
247
def plot_cp_sequence_components(
    adata: "ad.AnnData",
    *,
    output_dir: Path | str,
    components_key: str = "H_cp_sequence",
    uns_key: str = "cp_sequence",
    base_factors_key: str | None = None,
    suffix: str | None = None,
    heatmap_name: str = "cp_sequence_position_heatmap.png",
    lineplot_name: str = "cp_sequence_position_lineplot.png",
    base_factors_name: str = "cp_sequence_base_weights.png",
    max_positions: int = 2000,
) -> Dict[str, Path]:
    """Plot CP sequence components as heatmaps and line plots.

    Args:
        adata: AnnData object with CP decomposition in ``varm`` and ``uns``.
        output_dir: Directory to write plots into (created if missing).
        components_key: Key in ``adata.varm`` for position factors.
        uns_key: Key in ``adata.uns`` for CP metadata (base factors/labels).
        base_factors_key: Optional key in ``adata.uns`` for base factors, used
            when ``adata.uns[uns_key]`` does not provide them.
        suffix: Optional suffix appended to the component keys.
        heatmap_name: Filename for the heatmap plot.
        lineplot_name: Filename for the line plot.
        base_factors_name: Filename for the base factors plot.
        max_positions: Maximum number of positions to plot; positions are ranked
            by their largest absolute weight. A falsy value disables the cap.

    Returns:
        Dict[str, Path]: Paths to generated plots. Possible keys are ``heatmap``,
        ``lineplot`` and ``base_factors``; empty when the components key is missing.

    Raises:
        ValueError: If the stored position factors are not two-dimensional.
    """
    logger.info("Plotting CP sequence components to %s.", output_dir)
    if suffix:
        components_key = f"{components_key}_{suffix}"
        if base_factors_key is not None:
            base_factors_key = f"{base_factors_key}_{suffix}"
        uns_key = f"{uns_key}_{suffix}"

    heatmap_name = f"{components_key}_{heatmap_name}"
    lineplot_name = f"{components_key}_{lineplot_name}"
    base_name = f"{components_key}_{base_factors_name}"

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    if components_key not in adata.varm:
        logger.warning("CP components key '%s' not found in adata.varm.", components_key)
        return {}

    components = np.asarray(adata.varm[components_key])
    if components.ndim != 2:
        raise ValueError(f"CP position factors must be 2D; got shape {components.shape}.")

    position_indices = np.arange(components.shape[0])
    valid_mask = np.isfinite(components).any(axis=1)
    if not np.all(valid_mask):
        dropped = int(np.sum(~valid_mask))
        logger.info("Dropping %s CP positions with no finite weights before plotting.", dropped)
        components = components[valid_mask]
        position_indices = position_indices[valid_mask]

    if max_positions and components.shape[0] > max_positions:
        original_count = components.shape[0]
        # Rank by magnitude, then restore positional order for plotting.
        scores = np.nanmax(np.abs(components), axis=1)
        top_idx = np.sort(np.argsort(scores)[-max_positions:])
        components = components[top_idx]
        position_indices = position_indices[top_idx]
        logger.info(
            "Downsampled CP positions from %s to %s for plotting.",
            original_count,
            max_positions,
        )

    outputs: Dict[str, Path] = {}
    if components.size == 0:
        logger.warning("No finite CP position factors available; skipping position plots.")
    else:
        n_positions, n_components = components.shape
        component_labels = [f"C{i + 1}" for i in range(n_components)]

        heatmap_width = max(8, min(20, n_positions / 60))
        heatmap_height = max(2.5, 0.6 * n_components + 1.5)
        fig, ax = plt.subplots(figsize=(heatmap_width, heatmap_height))
        sns.heatmap(
            components.T,
            ax=ax,
            cmap="viridis",
            cbar_kws={"label": "Component weight"},
            xticklabels=position_indices if n_positions <= 60 else False,
            yticklabels=component_labels,
        )
        ax.set_xlabel("Position index")
        ax.set_ylabel("CP component")
        fig.tight_layout()
        heatmap_path = output_path / heatmap_name
        fig.savefig(heatmap_path, dpi=200)
        plt.close(fig)
        logger.info("Saved CP sequence heatmap to %s.", heatmap_path)

        fig, ax = plt.subplots(figsize=(max(8, min(20, n_positions / 50)), 3.5))
        x = position_indices
        for idx, label in enumerate(component_labels):
            ax.scatter(x, components[:, idx], label=label, s=20, alpha=0.8)
        ax.set_xlabel("Position index")
        ax.set_ylabel("Component weight")
        if n_positions <= 60:
            ax.set_xticks(x)
            ax.set_xticklabels([str(pos) for pos in x], rotation=90, fontsize=8)
        ax.legend(loc="upper right", frameon=False)
        fig.tight_layout()
        lineplot_path = output_path / lineplot_name
        fig.savefig(lineplot_path, dpi=200)
        plt.close(fig)
        logger.info("Saved CP sequence line plot to %s.", lineplot_path)

        outputs["heatmap"] = heatmap_path
        outputs["lineplot"] = lineplot_path

    # Locate base factors: prefer the metadata mapping, then the explicit key.
    base_factors = None
    base_labels = None
    if uns_key in adata.uns and isinstance(adata.uns[uns_key], Mapping):
        base_factors = adata.uns[uns_key].get("base_factors")
        base_labels = adata.uns[uns_key].get("base_labels")
    if base_factors is None and base_factors_key:
        base_factors = adata.uns.get(base_factors_key)
        base_labels = adata.uns.get("cp_base_labels")

    if base_factors is not None:
        base_factors = np.asarray(base_factors)
        # ``size == 0`` also rejects a zero-column matrix, which would otherwise
        # divide by zero when computing the bar width below.
        if base_factors.ndim != 2 or base_factors.size == 0:
            logger.warning(
                "CP base factors must be 2D and non-empty; got shape %s.",
                base_factors.shape,
            )
        else:
            n_bases, n_base_components = base_factors.shape

            # Labels may arrive as a list or as a NumPy array, so avoid
            # truth-testing them directly (ambiguous for arrays).
            if base_labels is None:
                label_list = []
            else:
                raw_labels = np.atleast_1d(np.asarray(base_labels, dtype=object)).ravel()
                label_list = [str(label) for label in raw_labels]
            if len(label_list) != n_bases:
                if label_list:
                    logger.warning(
                        "Got %s CP base labels for %s base factors; using default labels.",
                        len(label_list),
                        n_bases,
                    )
                label_list = [f"B{i + 1}" for i in range(n_bases)]

            fig, ax = plt.subplots(figsize=(4.5, 3))
            width = 0.8 / n_base_components
            x = np.arange(n_bases)
            # Grouped bars: one bar per component within each base group.
            for idx in range(n_base_components):
                ax.bar(
                    x + idx * width,
                    base_factors[:, idx],
                    width=width,
                    label=f"C{idx + 1}",
                )
            # Centre each tick under its group of bars.
            ax.set_xticks(x + width * (n_base_components - 1) / 2)
            ax.set_xticklabels(label_list)
            ax.set_ylabel("Base factor weight")
            ax.legend(loc="upper right", frameon=False)
            fig.tight_layout()
            base_path = output_path / base_name
            fig.savefig(base_path, dpi=200)
            plt.close(fig)
            outputs["base_factors"] = base_path
            logger.info("Saved CP base factors plot to %s.", base_path)

    return outputs
405
+
406
+
407
def _resolve_embedding(adata: "ad.AnnData", basis: str) -> np.ndarray:
    """Return the first two dimensions of an embedding stored in ``adata.obsm``.

    Args:
        adata: Object exposing an ``obsm`` mapping of embeddings.
        basis: Embedding name, with or without the ``X_`` prefix.

    Returns:
        np.ndarray: The first two columns of the embedding.

    Raises:
        KeyError: If the embedding key is not present in ``adata.obsm``.
        ValueError: If the embedding is not a 2D array with at least two columns.
    """
    key = basis if basis.startswith("X_") else f"X_{basis}"
    if key not in adata.obsm:
        raise KeyError(f"Embedding '{key}' not found in adata.obsm.")
    embedding = np.asarray(adata.obsm[key])
    # Check ndim first so a 1D array reports the same descriptive error instead
    # of an IndexError from ``shape[1]``.
    if embedding.ndim != 2 or embedding.shape[1] < 2:
        raise ValueError(f"Embedding '{key}' must have at least two dimensions.")
    return embedding[:, :2]
415
+
416
+
417
def plot_embedding(
    adata: "ad.AnnData",
    *,
    basis: str,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot a 2D embedding with scanpy-style color options.

    One PNG is written per color key. Non-numeric columns (categorical, object,
    string, ...) are drawn with a discrete palette and a legend; numeric columns
    are drawn with a continuous colormap and a colorbar.

    Args:
        adata: AnnData object with ``obsm['X_<basis>']``.
        basis: Embedding basis name (e.g., ``'umap'``, ``'pca'``).
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots (created if missing).
        prefix: Optional filename prefix; defaults to ``basis``.
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Dict[str, Path]: Mapping of color keys to saved plot paths. Keys missing
        from ``adata.obs`` are skipped with a warning.
    """
    logger.info("Plotting %s embedding to %s.", basis, output_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    embedding = _resolve_embedding(adata, basis)
    colors = [color] if isinstance(color, str) else list(color)
    filename_prefix = prefix or basis
    saved: Dict[str, Path] = {}

    for color_key in colors:
        if color_key not in adata.obs:
            logger.warning("Color key '%s' not found in adata.obs; skipping.", color_key)
            continue
        values = adata.obs[color_key]
        # Anything that is not numeric is treated as discrete. This covers
        # pandas string dtypes as well as object columns, which would fail the
        # float conversion in the continuous branch.
        is_discrete = isinstance(
            values.dtype, pd.CategoricalDtype
        ) or not pd.api.types.is_numeric_dtype(values.dtype)

        fig, ax = plt.subplots(figsize=(5.5, 4.5))
        try:
            if is_discrete:
                categories = pd.Categorical(values)
                label_strings = categories.categories.astype(str)
                palette = sns.color_palette("tab20", n_colors=len(label_strings))
                # Code -1 marks a missing value; draw those points in grey.
                point_colors = [
                    palette[code] if code >= 0 else "#bdbdbd" for code in categories.codes
                ]
                ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=point_colors,
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                # Pair labels with colors positionally so categories whose
                # string forms coincide still get their own legend color.
                handles = [
                    patches.Patch(color=swatch, label=str(label))
                    for label, swatch in zip(label_strings, palette)
                ]
                ax.legend(handles=handles, loc="best", fontsize=8, frameon=False)
            else:
                scatter = ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=values.astype(float),
                    cmap="viridis",
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                fig.colorbar(scatter, ax=ax, label=color_key)

            ax.set_xlabel("Component 1")
            ax.set_ylabel("Component 2")
            ax.set_title(f"{color_key}")
            fig.tight_layout()

            safe_key = str(color_key).replace(" ", "_")
            output_file = output_path / f"{filename_prefix}_{safe_key}.png"
            fig.savefig(output_file, dpi=200)
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
        logger.info("Saved %s embedding plot to %s.", basis, output_file)
        saved[color_key] = output_file

    return saved
508
+
509
+
510
def _grid_dimensions(n_items: int, ncols: int | None) -> tuple[int, int]:
    """Compute ``(nrows, ncols)`` for laying out ``n_items`` panels.

    When ``ncols`` is ``None`` two columns are used (one for a single item).
    The column count is clamped to the range ``[1, n_items]``. An input of
    fewer than one item yields ``(0, 0)``.
    """
    if n_items < 1:
        return 0, 0

    requested = ncols
    if requested is None:
        requested = 1 if n_items <= 1 else 2

    # Never more columns than items, never fewer than one.
    columns = max(1, min(requested, n_items))
    rows = int(math.ceil(n_items / columns))
    return rows, columns
518
+
519
+
520
def plot_embedding_grid(
    adata: "ad.AnnData",
    *,
    basis: str,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    ncols: int | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Path | None:
    """Plot a 2D embedding grid with legends to the right of each subplot.

    Each color key gets a scatter panel plus a narrow companion axis that hosts
    the legend for discrete columns. Non-numeric columns (categorical, object,
    string, ...) use a discrete palette; numeric columns use a continuous
    colormap with a colorbar attached to the scatter panel.

    Args:
        adata: AnnData object with ``obsm['X_<basis>']``.
        basis: Embedding basis name (e.g., ``'umap'``, ``'pca'``).
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots (created if missing).
        prefix: Optional filename prefix; defaults to ``basis``.
        ncols: Number of columns in the grid.
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Path to the saved grid image, or None if no valid color keys exist.
    """
    logger.info("Plotting %s embedding grid to %s.", basis, output_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    embedding = _resolve_embedding(adata, basis)
    colors = [color] if isinstance(color, str) else list(color)

    valid_colors = []
    for color_key in colors:
        if color_key not in adata.obs:
            logger.warning("Color key '%s' not found in adata.obs; skipping.", color_key)
            continue
        valid_colors.append(color_key)

    if not valid_colors:
        return None

    nrows, ncols = _grid_dimensions(len(valid_colors), ncols)
    plot_width = 4.8
    legend_width = 2.4
    plot_height = 4.2
    fig = plt.figure(
        figsize=(ncols * (plot_width + legend_width), nrows * plot_height),
    )
    try:
        # Every logical column is a (scatter, legend) pair of grid columns.
        grid = grid_spec.GridSpec(
            nrows,
            ncols * 2,
            figure=fig,
            width_ratios=[plot_width, legend_width] * ncols,
            wspace=0.08,
            hspace=0.35,
        )

        for panel_idx, color_key in enumerate(valid_colors):
            row, col = divmod(panel_idx, ncols)
            ax = fig.add_subplot(grid[row, col * 2])
            legend_ax = fig.add_subplot(grid[row, col * 2 + 1])
            legend_ax.axis("off")

            values = adata.obs[color_key]
            # Anything that is not numeric is treated as discrete. This covers
            # pandas string dtypes as well as object columns, which would fail
            # the float conversion in the continuous branch.
            is_discrete = isinstance(
                values.dtype, pd.CategoricalDtype
            ) or not pd.api.types.is_numeric_dtype(values.dtype)

            if is_discrete:
                categories = pd.Categorical(values)
                label_strings = categories.categories.astype(str)
                palette = sns.color_palette("tab20", n_colors=len(label_strings))
                # Code -1 marks a missing value; draw those points in grey.
                point_colors = [
                    palette[code] if code >= 0 else "#bdbdbd" for code in categories.codes
                ]
                ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=point_colors,
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                # Pair labels with colors positionally so categories whose
                # string forms coincide still get their own legend color.
                handles = [
                    patches.Patch(color=swatch, label=str(label))
                    for label, swatch in zip(label_strings, palette)
                ]
                legend_ax.legend(handles=handles, loc="center left", fontsize=8, frameon=False)
            else:
                scatter = ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=values.astype(float),
                    cmap="viridis",
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                fig.colorbar(scatter, ax=ax, fraction=0.046, pad=0.02, shrink=0.9)

            ax.set_xlabel("Component 1")
            ax.set_ylabel("Component 2")
            ax.set_title(f"{color_key}")

        fig.tight_layout()

        filename_prefix = prefix or basis
        output_file = output_path / f"{filename_prefix}_grid.png"
        fig.savefig(output_file, dpi=200)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    logger.info("Saved %s embedding grid to %s.", basis, output_file)
    return output_file
638
+
639
+
640
def plot_umap(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot the UMAP embedding, one image per color key.

    Thin wrapper around ``plot_embedding`` that selects the basis ``umap`` or,
    when ``subset`` is given, ``umap_<subset>``.

    Args:
        adata: AnnData object holding the UMAP embedding in ``obsm``.
        subset: Optional subset name appended to the basis key.
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots.
        prefix: Optional filename prefix.
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Dict[str, Path]: Mapping of color keys to saved plot paths.
    """
    logger.info("Plotting UMAP embedding to %s.", output_dir)

    umap_key = f"umap_{subset}" if subset else "umap"

    # Forward the marker styling too; it was previously accepted but dropped.
    return plot_embedding(
        adata,
        basis=umap_key,
        color=color,
        output_dir=output_dir,
        prefix=prefix,
        point_size=point_size,
        alpha=alpha,
    )
658
+
659
+
660
def plot_umap_grid(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    ncols: int | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Path | None:
    """Plot the UMAP embedding as a single grid image.

    Delegates to ``plot_embedding_grid`` using the basis ``umap``, or
    ``umap_<subset>`` when ``subset`` is truthy; all other arguments are
    passed through unchanged.
    """
    logger.info("Plotting UMAP embedding grid to %s.", output_dir)

    basis_name = "umap"
    if subset:
        basis_name = f"umap_{subset}"

    forwarded = {
        "basis": basis_name,
        "color": color,
        "output_dir": output_dir,
        "prefix": prefix,
        "ncols": ncols,
        "point_size": point_size,
        "alpha": alpha,
    }
    return plot_embedding_grid(adata, **forwarded)
688
+
689
+
690
def plot_pca(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot the PCA embedding, one image per color key.

    Thin wrapper around ``plot_embedding`` that selects the basis ``pca`` or,
    when ``subset`` is given, ``pca_<subset>``.

    Args:
        adata: AnnData object holding the PCA embedding in ``obsm``.
        subset: Optional subset name appended to the basis key.
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots.
        prefix: Optional filename prefix.
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Dict[str, Path]: Mapping of color keys to saved plot paths.
    """
    logger.info("Plotting PCA embedding to %s.", output_dir)

    pca_key = f"pca_{subset}" if subset else "pca"

    # Forward the marker styling too; it was previously accepted but dropped.
    return plot_embedding(
        adata,
        basis=pca_key,
        color=color,
        output_dir=output_dir,
        prefix=prefix,
        point_size=point_size,
        alpha=alpha,
    )
706
+
707
+
708
def plot_pca_grid(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    ncols: int | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Path | None:
    """Plot the PCA embedding as a single grid image.

    Delegates to ``plot_embedding_grid`` using the basis ``pca``, or
    ``pca_<subset>`` when ``subset`` is truthy; all other arguments are
    passed through unchanged.
    """
    logger.info("Plotting PCA embedding grid to %s.", output_dir)

    basis_name = "pca"
    if subset:
        basis_name = f"pca_{subset}"

    forwarded = {
        "basis": basis_name,
        "color": color,
        "output_dir": output_dir,
        "prefix": prefix,
        "ncols": ncols,
        "point_size": point_size,
        "alpha": alpha,
    }
    return plot_embedding_grid(adata, **forwarded)
736
+
737
+
738
def plot_pca_explained_variance(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    output_dir: Path | str,
    pca_key: str = "pca",
    suffix: str | None = None,
    max_pcs: int | None = None,
) -> Path | None:
    """Draw the cumulative explained-variance curve for a stored PCA result.

    Args:
        adata: AnnData object whose ``uns`` holds the PCA result mapping.
        subset: Optional subset name; appended to ``pca_key`` first.
        output_dir: Directory the figure is written to (created if missing).
        pca_key: Base key in ``adata.uns``; the entry must be a mapping with a
            ``variance_ratio`` item.
        suffix: Optional suffix; appended to the key after ``subset``.
        max_pcs: Optional cap on the number of leading PCs plotted.

    Returns:
        Path to the saved figure, or None when the variance ratios are missing
        or empty.
    """
    logger.info("Plotting PCA explained variance to %s.", output_dir)

    # Build the lookup key: <pca_key>[_<subset>][_<suffix>].
    for extra in (subset, suffix):
        if extra:
            pca_key = f"{pca_key}_{extra}"

    stored = adata.uns[pca_key] if pca_key in adata.uns else None
    if not isinstance(stored, Mapping) or "variance_ratio" not in stored:
        logger.warning("Explained variance ratio not found in adata.uns[%s].", pca_key)
        return None

    ratios = np.asarray(stored["variance_ratio"], dtype=float)
    if ratios.size == 0:
        logger.warning("Explained variance ratio for %s is empty; skipping plot.", pca_key)
        return None

    if max_pcs is not None:
        ratios = ratios[:max_pcs]

    pc_numbers = np.arange(1, len(ratios) + 1)
    running_total = np.cumsum(ratios)

    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(pc_numbers, running_total, color="#DD8452", marker="o", label="Cumulative variance")
    ax.set_xlabel("Principal component")
    ax.set_ylabel("Cumulative explained variance")
    ax.set_ylim(0, 1.05)
    ax.grid(True, axis="y", alpha=0.3)
    ax.legend(frameon=False)
    fig.tight_layout()

    out_file = target_dir / f"{pca_key}_explained_variance.png"
    fig.savefig(out_file, dpi=200)
    plt.close(fig)
    logger.info("Saved PCA explained variance plot to %s.", out_file)

    return out_file