smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/chimeric_adata.py +1563 -0
- smftools/cli/helpers.py +18 -2
- smftools/cli/hmm_adata.py +18 -1
- smftools/cli/latent_adata.py +522 -67
- smftools/cli/load_adata.py +2 -2
- smftools/cli/preprocess_adata.py +32 -93
- smftools/cli/recipes.py +26 -0
- smftools/cli/spatial_adata.py +23 -109
- smftools/cli/variant_adata.py +423 -0
- smftools/cli_entry.py +41 -5
- smftools/config/conversion.yaml +0 -10
- smftools/config/deaminase.yaml +3 -0
- smftools/config/default.yaml +49 -13
- smftools/config/experiment_config.py +96 -3
- smftools/constants.py +4 -0
- smftools/hmm/call_hmm_peaks.py +1 -1
- smftools/informatics/binarize_converted_base_identities.py +2 -89
- smftools/informatics/converted_BAM_to_adata.py +53 -13
- smftools/informatics/h5ad_functions.py +83 -0
- smftools/informatics/modkit_extract_to_adata.py +4 -0
- smftools/plotting/__init__.py +26 -12
- smftools/plotting/autocorrelation_plotting.py +22 -4
- smftools/plotting/chimeric_plotting.py +1893 -0
- smftools/plotting/classifiers.py +28 -14
- smftools/plotting/general_plotting.py +58 -3362
- smftools/plotting/hmm_plotting.py +1586 -2
- smftools/plotting/latent_plotting.py +804 -0
- smftools/plotting/plotting_utils.py +243 -0
- smftools/plotting/position_stats.py +16 -8
- smftools/plotting/preprocess_plotting.py +281 -0
- smftools/plotting/qc_plotting.py +8 -3
- smftools/plotting/spatial_plotting.py +1134 -0
- smftools/plotting/variant_plotting.py +1231 -0
- smftools/preprocessing/__init__.py +3 -0
- smftools/preprocessing/append_base_context.py +1 -1
- smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
- smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
- smftools/preprocessing/append_variant_call_layer.py +480 -0
- smftools/preprocessing/flag_duplicate_reads.py +4 -4
- smftools/preprocessing/invert_adata.py +1 -0
- smftools/readwrite.py +109 -85
- smftools/tools/__init__.py +6 -0
- smftools/tools/calculate_knn.py +121 -0
- smftools/tools/calculate_nmf.py +18 -7
- smftools/tools/calculate_pca.py +180 -0
- smftools/tools/calculate_umap.py +70 -154
- smftools/tools/position_stats.py +4 -4
- smftools/tools/rolling_nn_distance.py +640 -3
- smftools/tools/sequence_alignment.py +140 -0
- smftools/tools/tensor_factorization.py +52 -4
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,804 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING, Dict, Mapping, Sequence
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from smftools.logging_utils import get_logger
|
|
11
|
+
from smftools.optional_imports import require
|
|
12
|
+
from smftools.plotting.plotting_utils import _fixed_tick_positions
|
|
13
|
+
|
|
14
|
+
patches = require("matplotlib.patches", extra="plotting", purpose="plot rendering")
|
|
15
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
|
|
16
|
+
|
|
17
|
+
sns = require("seaborn", extra="plotting", purpose="plot styling")
|
|
18
|
+
|
|
19
|
+
grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="heatmap plotting")
|
|
20
|
+
|
|
21
|
+
logger = get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
import anndata as ad
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def plot_nmf_components(
    adata: "ad.AnnData",
    *,
    output_dir: Path | str,
    components_key: str = "H_nmf",
    suffix: str | None = None,
    heatmap_name: str = "heatmap.png",
    lineplot_name: str = "lineplot.png",
    max_features: int = 2000,
) -> Dict[str, Path]:
    """Plot NMF component weights as a heatmap and per-component scatter plot.

    Rows of the stored matrix (positions) with no finite, non-zero weight are
    dropped before plotting. If more than ``max_features`` rows remain, only the
    rows with the largest absolute weight in any component are kept, in their
    original position order.

    Args:
        adata: AnnData object containing NMF results.
        output_dir: Directory to write plots into; created if it does not exist.
        components_key: Key in ``adata.varm`` storing the H matrix
            (rows = positions, columns = components).
        suffix: Optional suffix; when truthy, ``f"{components_key}_{suffix}"`` is
            read instead of ``components_key``.
        heatmap_name: Filename for the heatmap plot (prefixed with the resolved key).
        lineplot_name: Filename for the scatter plot (prefixed with the resolved key).
        max_features: Maximum number of features to plot (top-weighted by
            component). A falsy value disables downsampling.

    Returns:
        Dict[str, Path]: Paths to created plots (keys: ``heatmap`` and ``lineplot``).
        Empty when the key is missing or no row has a finite, non-zero weight.

    Raises:
        ValueError: If the stored components are not a 2D array.
    """
    logger.info("Plotting NMF components to %s.", output_dir)
    if suffix:
        components_key = f"{components_key}_{suffix}"

    heatmap_name = f"{components_key}_{heatmap_name}"
    lineplot_name = f"{components_key}_{lineplot_name}"

    if components_key not in adata.varm:
        logger.warning("NMF components key '%s' not found in adata.varm.", components_key)
        return {}

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    components = np.asarray(adata.varm[components_key])
    if components.ndim != 2:
        raise ValueError(f"NMF components must be 2D; got shape {components.shape}.")

    all_positions = np.arange(components.shape[0])

    # ``NaN != 0`` is True, so a plain non-zero test would keep all-NaN rows; those
    # rows would then score NaN, sort last in argsort, and be chosen as "top" rows.
    informative = np.isfinite(components) & (components != 0)
    nonzero_mask = np.any(informative, axis=1)
    if not np.any(nonzero_mask):
        logger.warning("NMF components are all zeros or non-finite; skipping plot generation.")
        return {}

    components = components[nonzero_mask]
    feature_positions = all_positions[nonzero_mask]

    if max_features and components.shape[0] > max_features:
        scores = np.nanmax(np.abs(components), axis=1)
        # Keep the highest-scoring rows but restore positional order for plotting.
        top_idx = np.sort(np.argsort(scores)[-max_features:])
        components = components[top_idx]
        feature_positions = feature_positions[top_idx]
        logger.info(
            "Downsampled NMF features from %s to %s for plotting.",
            nonzero_mask.sum(),
            components.shape[0],
        )

    n_features, n_components = components.shape
    feature_labels = feature_positions.astype(str)
    component_labels = [f"C{i + 1}" for i in range(n_components)]

    heatmap_width = max(8, min(20, n_features / 60))
    heatmap_height = max(2.5, 0.6 * n_components + 1.5)
    heatmap_path = output_path / heatmap_name
    fig, ax = plt.subplots(figsize=(heatmap_width, heatmap_height))
    try:
        sns.heatmap(
            components.T,
            ax=ax,
            cmap="viridis",
            cbar_kws={"label": "Component weight"},
            xticklabels=feature_labels if n_features <= 60 else False,
            yticklabels=component_labels,
        )
        ax.set_xlabel("Position index")
        ax.set_ylabel("NMF component")
        if n_features > 60:
            # Too many columns to label individually; label an evenly spaced subset.
            tick_positions = np.asarray(_fixed_tick_positions(n_features, min(20, n_features)))
            ax.set_xticks(tick_positions + 0.5)
            ax.set_xticklabels(
                feature_positions[tick_positions].astype(str), rotation=90, fontsize=8
            )
        fig.tight_layout()
        fig.savefig(heatmap_path, dpi=200)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)
    logger.info("Saved NMF heatmap to %s.", heatmap_path)

    lineplot_path = output_path / lineplot_name
    fig, ax = plt.subplots(figsize=(max(8, min(20, n_features / 50)), 3.5))
    try:
        for idx, label in enumerate(component_labels):
            ax.scatter(feature_positions, components[:, idx], label=label, s=14, alpha=0.75)
        ax.set_xlabel("Position index")
        ax.set_ylabel("Component weight")
        if n_features <= 60:
            ax.set_xticks(feature_positions)
            ax.set_xticklabels(feature_labels, rotation=90, fontsize=8)
        ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), frameon=False)
        # Leave the right-hand margin free for the legend placed outside the axes.
        fig.tight_layout(rect=[0, 0, 0.82, 1])
        fig.savefig(lineplot_path, dpi=200)
    finally:
        plt.close(fig)
    logger.info("Saved NMF line plot to %s.", lineplot_path)

    return {"heatmap": heatmap_path, "lineplot": lineplot_path}
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def plot_pca_components(
    adata: "ad.AnnData",
    *,
    output_dir: Path | str,
    components_key: str = "PCs",
    suffix: str | None = None,
    heatmap_name: str = "heatmap.png",
    lineplot_name: str = "lineplot.png",
    max_features: int = 2000,
) -> Dict[str, Path]:
    """Plot PCA component loadings as a heatmap and per-component scatter plot.

    Rows of the stored loadings (positions) with no finite, non-zero loading are
    dropped. If more than ``max_features`` rows remain, the rows with the largest
    absolute loading in any component are kept, in their original position order.
    Loadings are signed, so ranking by magnitude keeps strongly negative features.
    The heatmap uses a diverging colormap centred at zero.

    Args:
        adata: AnnData object containing PCA results.
        output_dir: Directory to write plots into; created if it does not exist.
        components_key: Key in ``adata.varm`` storing the components
            (rows = positions, columns = components).
        suffix: Optional suffix; when truthy, ``f"{components_key}_{suffix}"`` is
            read instead of ``components_key``.
        heatmap_name: Filename for the heatmap plot (prefixed with the resolved key).
        lineplot_name: Filename for the scatter plot (prefixed with the resolved key).
        max_features: Maximum number of features to plot (top-weighted by
            component). A falsy value disables downsampling.

    Returns:
        Dict[str, Path]: Paths to created plots (keys: ``heatmap`` and ``lineplot``).
        Empty when the key is missing or no row has a finite, non-zero loading.

    Raises:
        ValueError: If the stored components are not a 2D array.
    """
    logger.info("Plotting PCA components to %s.", output_dir)
    if suffix:
        components_key = f"{components_key}_{suffix}"

    heatmap_name = f"{components_key}_{heatmap_name}"
    lineplot_name = f"{components_key}_{lineplot_name}"

    if components_key not in adata.varm:
        logger.warning("PCA components key '%s' not found in adata.varm.", components_key)
        return {}

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    components = np.asarray(adata.varm[components_key])
    if components.ndim != 2:
        raise ValueError(f"PCA components must be 2D; got shape {components.shape}.")

    all_positions = np.arange(components.shape[0])

    # ``NaN != 0`` is True, so require a finite non-zero entry; otherwise all-NaN
    # rows survive and, scoring NaN, are picked first by argsort's tail.
    informative = np.isfinite(components) & (components != 0)
    nonzero_mask = np.any(informative, axis=1)
    if not np.any(nonzero_mask):
        logger.warning("PCA components are all zeros or non-finite; skipping plot generation.")
        return {}

    components = components[nonzero_mask]
    feature_positions = all_positions[nonzero_mask]

    if max_features and components.shape[0] > max_features:
        # Rank by magnitude: a large negative loading is as informative as a
        # large positive one.
        scores = np.nanmax(np.abs(components), axis=1)
        top_idx = np.sort(np.argsort(scores)[-max_features:])
        components = components[top_idx]
        feature_positions = feature_positions[top_idx]
        logger.info(
            "Downsampled PCA features from %s to %s for plotting.",
            nonzero_mask.sum(),
            components.shape[0],
        )

    n_features, n_components = components.shape
    feature_labels = feature_positions.astype(str)
    component_labels = [f"PC{i + 1}" for i in range(n_components)]

    heatmap_width = max(8, min(20, n_features / 60))
    heatmap_height = max(2.5, 0.6 * n_components + 1.5)
    heatmap_path = output_path / heatmap_name
    fig, ax = plt.subplots(figsize=(heatmap_width, heatmap_height))
    try:
        sns.heatmap(
            components.T,
            ax=ax,
            cmap="coolwarm",
            # Centre the diverging colormap so white means "zero loading".
            center=0.0,
            cbar_kws={"label": "Component loading"},
            xticklabels=feature_labels if n_features <= 60 else False,
            yticklabels=component_labels,
        )
        ax.set_xlabel("Position index")
        ax.set_ylabel("PCA component")
        if n_features > 60:
            tick_positions = np.asarray(_fixed_tick_positions(n_features, min(20, n_features)))
            ax.set_xticks(tick_positions + 0.5)
            ax.set_xticklabels(
                feature_positions[tick_positions].astype(str), rotation=90, fontsize=8
            )
        fig.tight_layout()
        fig.savefig(heatmap_path, dpi=200)
    finally:
        plt.close(fig)
    logger.info("Saved PCA heatmap to %s.", heatmap_path)

    lineplot_path = output_path / lineplot_name
    fig, ax = plt.subplots(figsize=(max(8, min(20, n_features / 50)), 3.5))
    try:
        for idx, label in enumerate(component_labels):
            ax.scatter(feature_positions, components[:, idx], label=label, s=14, alpha=0.75)
        ax.set_xlabel("Position index")
        ax.set_ylabel("Component loading")
        if n_features <= 60:
            ax.set_xticks(feature_positions)
            ax.set_xticklabels(feature_labels, rotation=90, fontsize=8)
        ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1), frameon=False)
        # Leave the right-hand margin free for the legend placed outside the axes.
        fig.tight_layout(rect=[0, 0, 0.82, 1])
        fig.savefig(lineplot_path, dpi=200)
    finally:
        plt.close(fig)
    logger.info("Saved PCA line plot to %s.", lineplot_path)

    return {"heatmap": heatmap_path, "lineplot": lineplot_path}
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def plot_cp_sequence_components(
    adata: "ad.AnnData",
    *,
    output_dir: Path | str,
    components_key: str = "H_cp_sequence",
    uns_key: str = "cp_sequence",
    base_factors_key: str | None = None,
    suffix: str | None = None,
    heatmap_name: str = "cp_sequence_position_heatmap.png",
    lineplot_name: str = "cp_sequence_position_lineplot.png",
    base_factors_name: str = "cp_sequence_base_weights.png",
    max_positions: int = 2000,
) -> Dict[str, Path]:
    """Plot CP sequence components as heatmaps and line plots.

    Position factors are read from ``adata.varm[components_key]``. Rows with no
    finite value are dropped. If more than ``max_positions`` rows remain, the
    rows with the largest absolute weight are kept.

    Base factors are read from ``adata.uns[uns_key]["base_factors"]``, with labels
    from ``adata.uns[uns_key]["base_labels"]``. If that is unavailable, they fall
    back to ``adata.uns[base_factors_key]`` with labels from
    ``adata.uns["cp_base_labels"]``. Missing or mismatched labels are replaced by
    ``B1..Bn``.

    Args:
        adata: AnnData object with CP decomposition in ``varm`` and ``uns``.
        output_dir: Directory to write plots into (created if missing).
        components_key: Key in ``adata.varm`` for position factors.
        uns_key: Key in ``adata.uns`` for CP metadata (base factors/labels).
        base_factors_key: Optional key in ``adata.uns`` for base factors.
        suffix: Optional suffix appended to ``components_key``, ``uns_key`` and
            (if given) ``base_factors_key``.
        heatmap_name: Filename for the heatmap plot.
        lineplot_name: Filename for the per-component scatter ("line") plot.
        base_factors_name: Filename for the base factors plot.
        max_positions: Maximum number of positions to plot; falsy disables it.

    Returns:
        Dict[str, Path]: Paths to generated plots. The keys ``heatmap``,
        ``lineplot`` and ``base_factors`` are present only for plots that were
        actually written.

    Raises:
        ValueError: If the position factors are not a 2D array.
    """
    logger.info("Plotting CP sequence components to %s.", output_dir)
    if suffix:
        components_key = f"{components_key}_{suffix}"
        if base_factors_key is not None:
            base_factors_key = f"{base_factors_key}_{suffix}"
        uns_key = f"{uns_key}_{suffix}"

    heatmap_name = f"{components_key}_{heatmap_name}"
    lineplot_name = f"{components_key}_{lineplot_name}"
    base_name = f"{components_key}_{base_factors_name}"

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    if components_key not in adata.varm:
        logger.warning("CP components key '%s' not found in adata.varm.", components_key)
        return {}

    components = np.asarray(adata.varm[components_key])
    if components.ndim != 2:
        raise ValueError(f"CP position factors must be 2D; got shape {components.shape}.")

    position_indices = np.arange(components.shape[0])
    valid_mask = np.isfinite(components).any(axis=1)
    if not np.all(valid_mask):
        dropped = int(np.sum(~valid_mask))
        logger.info("Dropping %s CP positions with no finite weights before plotting.", dropped)
        components = components[valid_mask]
        position_indices = position_indices[valid_mask]

    if max_positions and components.shape[0] > max_positions:
        original_count = components.shape[0]
        scores = np.nanmax(np.abs(components), axis=1)
        # Keep the strongest rows, then restore positional order.
        top_idx = np.sort(np.argsort(scores)[-max_positions:])
        components = components[top_idx]
        position_indices = position_indices[top_idx]
        logger.info(
            "Downsampled CP positions from %s to %s for plotting.",
            original_count,
            max_positions,
        )

    outputs: Dict[str, Path] = {}
    if components.size == 0:
        logger.warning("No finite CP position factors available; skipping position plots.")
    else:
        n_positions, n_components = components.shape
        component_labels = [f"C{i + 1}" for i in range(n_components)]

        heatmap_width = max(8, min(20, n_positions / 60))
        heatmap_height = max(2.5, 0.6 * n_components + 1.5)
        heatmap_path = output_path / heatmap_name
        fig, ax = plt.subplots(figsize=(heatmap_width, heatmap_height))
        try:
            sns.heatmap(
                components.T,
                ax=ax,
                cmap="viridis",
                cbar_kws={"label": "Component weight"},
                xticklabels=position_indices if n_positions <= 60 else False,
                yticklabels=component_labels,
            )
            ax.set_xlabel("Position index")
            ax.set_ylabel("CP component")
            if n_positions > 60:
                # Label an evenly spaced subset, matching the NMF/PCA heatmaps.
                tick_positions = np.asarray(
                    _fixed_tick_positions(n_positions, min(20, n_positions))
                )
                ax.set_xticks(tick_positions + 0.5)
                ax.set_xticklabels(
                    position_indices[tick_positions].astype(str), rotation=90, fontsize=8
                )
            fig.tight_layout()
            fig.savefig(heatmap_path, dpi=200)
        finally:
            plt.close(fig)
        logger.info("Saved CP sequence heatmap to %s.", heatmap_path)

        lineplot_path = output_path / lineplot_name
        fig, ax = plt.subplots(figsize=(max(8, min(20, n_positions / 50)), 3.5))
        try:
            for idx, label in enumerate(component_labels):
                ax.scatter(position_indices, components[:, idx], label=label, s=20, alpha=0.8)
            ax.set_xlabel("Position index")
            ax.set_ylabel("Component weight")
            if n_positions <= 60:
                ax.set_xticks(position_indices)
                ax.set_xticklabels([str(pos) for pos in position_indices], rotation=90, fontsize=8)
            ax.legend(loc="upper right", frameon=False)
            fig.tight_layout()
            fig.savefig(lineplot_path, dpi=200)
        finally:
            plt.close(fig)
        logger.info("Saved CP sequence line plot to %s.", lineplot_path)

        outputs["heatmap"] = heatmap_path
        outputs["lineplot"] = lineplot_path

    base_factors = None
    base_labels = None
    # Accept any mapping, not only a plain ``dict``.
    if uns_key in adata.uns and isinstance(adata.uns[uns_key], Mapping):
        base_factors = adata.uns[uns_key].get("base_factors")
        base_labels = adata.uns[uns_key].get("base_labels")
    if base_factors is None and base_factors_key:
        base_factors = adata.uns.get(base_factors_key)
        base_labels = adata.uns.get("cp_base_labels")

    if base_factors is not None:
        base_factors = np.asarray(base_factors)
        # A zero-column array would divide by zero when computing bar widths.
        if base_factors.ndim != 2 or base_factors.size == 0:
            logger.warning(
                "CP base factors must be 2D and non-empty; got shape %s.",
                base_factors.shape,
            )
        else:
            n_bases, n_base_components = base_factors.shape

            # Labels may be a numpy array (``labels or default`` would raise on it),
            # may contain bytes, or may not match the number of rows.
            tick_labels = None
            if base_labels is not None:
                try:
                    tick_labels = [
                        label.decode() if isinstance(label, bytes) else str(label)
                        for label in base_labels
                    ]
                except TypeError:
                    tick_labels = None
            if not tick_labels or len(tick_labels) != n_bases:
                if tick_labels:
                    logger.warning(
                        "CP base labels have length %s but base factors have %s rows; "
                        "using generic labels.",
                        len(tick_labels),
                        n_bases,
                    )
                tick_labels = [f"B{i + 1}" for i in range(n_bases)]

            width = 0.8 / n_base_components
            x = np.arange(n_bases)
            base_path = output_path / base_name
            fig, ax = plt.subplots(figsize=(4.5, 3))
            try:
                for idx in range(n_base_components):
                    ax.bar(
                        x + idx * width,
                        base_factors[:, idx],
                        width=width,
                        label=f"C{idx + 1}",
                    )
                # Centre each tick under its group of bars.
                ax.set_xticks(x + width * (n_base_components - 1) / 2)
                ax.set_xticklabels(tick_labels)
                ax.set_ylabel("Base factor weight")
                ax.legend(loc="upper right", frameon=False)
                fig.tight_layout()
                fig.savefig(base_path, dpi=200)
            finally:
                plt.close(fig)
            outputs["base_factors"] = base_path
            logger.info("Saved CP base factors plot to %s.", base_path)

    return outputs
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _resolve_embedding(adata: "ad.AnnData", basis: str) -> np.ndarray:
|
|
408
|
+
key = basis if basis.startswith("X_") else f"X_{basis}"
|
|
409
|
+
if key not in adata.obsm:
|
|
410
|
+
raise KeyError(f"Embedding '{key}' not found in adata.obsm.")
|
|
411
|
+
embedding = np.asarray(adata.obsm[key])
|
|
412
|
+
if embedding.shape[1] < 2:
|
|
413
|
+
raise ValueError(f"Embedding '{key}' must have at least two dimensions.")
|
|
414
|
+
return embedding[:, :2]
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def plot_embedding(
    adata: "ad.AnnData",
    *,
    basis: str,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot a 2D embedding with scanpy-style color options.

    One PNG is written per color key, named ``{prefix or basis}_{key}.png``.
    Non-numeric columns (categorical, object, string) use a discrete ``tab20``
    palette with a legend, and missing values are drawn in grey. Numeric
    (including boolean and nullable) columns use ``viridis`` with a colorbar, and
    missing values are left unplotted as NaN.

    Args:
        adata: AnnData object with ``obsm['X_<basis>']``.
        basis: Embedding basis name (e.g., ``'umap'``, ``'pca'``).
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots (created if missing).
        prefix: Optional filename prefix; defaults to ``basis``.
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Dict[str, Path]: Mapping of color keys to saved plot paths. Keys absent
        from ``adata.obs`` are skipped with a warning.

    Raises:
        KeyError: If the embedding is missing from ``adata.obsm``.
        ValueError: If the embedding is not 2D with at least two columns.
    """
    logger.info("Plotting %s embedding to %s.", basis, output_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    embedding = _resolve_embedding(adata, basis)
    colors = [color] if isinstance(color, str) else list(color)
    filename_prefix = prefix or basis
    saved: Dict[str, Path] = {}

    for color_key in colors:
        if color_key not in adata.obs:
            logger.warning("Color key '%s' not found in adata.obs; skipping.", color_key)
            continue
        values = adata.obs[color_key]

        # Path separators in a column name would otherwise point savefig at a
        # non-existent subdirectory.
        safe_key = str(color_key)
        for separator in (" ", "/", "\\"):
            safe_key = safe_key.replace(separator, "_")
        output_file = output_path / f"{filename_prefix}_{safe_key}.png"

        fig, ax = plt.subplots(figsize=(5.5, 4.5))
        try:
            # Anything non-numeric (category, object, pandas "string", ...) is
            # treated as discrete; these would fail a float conversion anyway.
            if not pd.api.types.is_numeric_dtype(values.dtype):
                categories = pd.Categorical(values)
                label_strings = categories.categories.astype(str)
                palette = sns.color_palette("tab20", n_colors=len(label_strings))
                # Missing values get code -1 and a neutral grey.
                point_colors = [
                    palette[code] if code >= 0 else "#bdbdbd" for code in categories.codes
                ]
                ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=point_colors,
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                handles = [
                    patches.Patch(color=swatch, label=str(label))
                    for label, swatch in zip(label_strings, palette)
                ]
                ax.legend(handles=handles, loc="best", fontsize=8, frameon=False)
            else:
                # ``to_numpy`` with ``na_value`` also handles nullable dtypes
                # (Int64, boolean) whose pd.NA would break ``astype(float)``.
                numeric = values.to_numpy(dtype=float, na_value=np.nan)
                scatter = ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=numeric,
                    cmap="viridis",
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                fig.colorbar(scatter, ax=ax, label=color_key)

            ax.set_xlabel("Component 1")
            ax.set_ylabel("Component 2")
            ax.set_title(f"{color_key}")
            fig.tight_layout()
            fig.savefig(output_file, dpi=200)
        finally:
            # Release the figure even if plotting this key fails.
            plt.close(fig)

        logger.info("Saved %s embedding plot to %s.", basis, output_file)
        saved[color_key] = output_file

    return saved
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def _grid_dimensions(n_items: int, ncols: int | None) -> tuple[int, int]:
|
|
511
|
+
if n_items < 1:
|
|
512
|
+
return 0, 0
|
|
513
|
+
if ncols is None:
|
|
514
|
+
ncols = 2 if n_items > 1 else 1
|
|
515
|
+
ncols = max(1, min(ncols, n_items))
|
|
516
|
+
nrows = int(math.ceil(n_items / ncols))
|
|
517
|
+
return nrows, ncols
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
def plot_embedding_grid(
    adata: "ad.AnnData",
    *,
    basis: str,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    ncols: int | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Path | None:
    """Plot a 2D embedding grid with legends to the right of each subplot.

    Each valid color key gets one scatter panel, followed by a narrow axis that
    holds its categorical legend. Non-numeric columns use a ``tab20`` palette,
    with missing values in grey. Numeric (including boolean and nullable) columns
    use ``viridis`` with a colorbar attached to the panel. The figure is saved as
    ``{prefix or basis}_grid.png``.

    Args:
        adata: AnnData object with ``obsm['X_<basis>']``.
        basis: Embedding basis name (e.g., ``'umap'``, ``'pca'``).
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots (created if missing).
        prefix: Optional filename prefix; defaults to ``basis``.
        ncols: Number of columns in the grid (default: 2, or 1 for a single key).
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Path to the saved grid image, or None if no valid color keys exist.

    Raises:
        KeyError: If the embedding is missing from ``adata.obsm``.
        ValueError: If the embedding is not 2D with at least two columns.
    """
    logger.info("Plotting %s embedding grid to %s.", basis, output_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    embedding = _resolve_embedding(adata, basis)
    colors = [color] if isinstance(color, str) else list(color)

    valid_colors = []
    for color_key in colors:
        if color_key not in adata.obs:
            logger.warning("Color key '%s' not found in adata.obs; skipping.", color_key)
            continue
        valid_colors.append(color_key)

    if not valid_colors:
        return None

    nrows, ncols = _grid_dimensions(len(valid_colors), ncols)
    plot_width = 4.8
    legend_width = 2.4
    plot_height = 4.2
    fig = plt.figure(
        figsize=(ncols * (plot_width + legend_width), nrows * plot_height),
    )
    try:
        # Each logical column is a (plot, legend) pair of grid columns.
        grid = grid_spec.GridSpec(
            nrows,
            ncols * 2,
            figure=fig,
            width_ratios=[plot_width, legend_width] * ncols,
            wspace=0.08,
            hspace=0.35,
        )

        for idx, color_key in enumerate(valid_colors):
            row, col = divmod(idx, ncols)
            ax = fig.add_subplot(grid[row, col * 2])
            legend_ax = fig.add_subplot(grid[row, col * 2 + 1])
            legend_ax.axis("off")

            values = adata.obs[color_key]
            # Non-numeric dtypes (category, object, pandas "string") are discrete.
            if not pd.api.types.is_numeric_dtype(values.dtype):
                categories = pd.Categorical(values)
                label_strings = categories.categories.astype(str)
                palette = sns.color_palette("tab20", n_colors=len(label_strings))
                # Missing values get code -1 and a neutral grey.
                point_colors = [
                    palette[code] if code >= 0 else "#bdbdbd" for code in categories.codes
                ]
                ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=point_colors,
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                handles = [
                    patches.Patch(color=swatch, label=str(label))
                    for label, swatch in zip(label_strings, palette)
                ]
                legend_ax.legend(handles=handles, loc="center left", fontsize=8, frameon=False)
            else:
                # ``na_value`` lets nullable dtypes (Int64, boolean) convert cleanly.
                numeric = values.to_numpy(dtype=float, na_value=np.nan)
                scatter = ax.scatter(
                    embedding[:, 0],
                    embedding[:, 1],
                    c=numeric,
                    cmap="viridis",
                    s=point_size,
                    alpha=alpha,
                    linewidths=0,
                )
                fig.colorbar(scatter, ax=ax, fraction=0.046, pad=0.02, shrink=0.9)

            ax.set_xlabel("Component 1")
            ax.set_ylabel("Component 2")
            ax.set_title(f"{color_key}")

        fig.tight_layout()

        filename_prefix = prefix or basis
        output_file = output_path / f"{filename_prefix}_grid.png"
        fig.savefig(output_file, dpi=200)
    finally:
        # Release the figure even if any panel fails to render.
        plt.close(fig)

    logger.info("Saved %s embedding grid to %s.", basis, output_file)
    return output_file
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def plot_umap(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot the UMAP embedding colored by one or more ``adata.obs`` columns.

    Args:
        adata: AnnData object holding ``obsm['X_umap']``, or
            ``obsm['X_umap_<subset>']`` when ``subset`` is given.
        subset: Optional subset name; when truthy the basis ``umap_<subset>`` is used.
        color: Obs column name or list of names to color by.
        output_dir: Directory to save plots.
        prefix: Optional filename prefix (defaults to the basis name).
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Dict[str, Path]: Mapping of color keys to saved plot paths.
    """
    logger.info("Plotting UMAP embedding to %s.", output_dir)
    umap_key = f"umap_{subset}" if subset else "umap"
    # Forward the marker styling so ``point_size``/``alpha`` actually take effect.
    return plot_embedding(
        adata,
        basis=umap_key,
        color=color,
        output_dir=output_dir,
        prefix=prefix,
        point_size=point_size,
        alpha=alpha,
    )
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
def plot_umap_grid(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    ncols: int | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Path | None:
    """Draw a grid of UMAP scatter panels, one per ``color`` key.

    Args:
        adata: AnnData object holding ``obsm['X_umap']``, or
            ``obsm['X_umap_<subset>']`` when ``subset`` is given.
        subset: Optional subset name; when truthy the basis ``umap_<subset>`` is used.
        color: Obs column name or list of names to color by.
        output_dir: Directory to save the grid image.
        prefix: Optional filename prefix (defaults to the basis name).
        ncols: Number of panel columns in the grid.
        point_size: Marker size for scatter plots.
        alpha: Marker transparency.

    Returns:
        Path to the saved grid image, or None if no requested key exists in ``obs``.
    """
    logger.info("Plotting UMAP embedding grid to %s.", output_dir)
    basis_name = "umap" if not subset else f"umap_{subset}"
    grid_kwargs = {
        "basis": basis_name,
        "color": color,
        "output_dir": output_dir,
        "prefix": prefix,
        "ncols": ncols,
        "point_size": point_size,
        "alpha": alpha,
    }
    return plot_embedding_grid(adata, **grid_kwargs)
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
def plot_pca(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Dict[str, Path]:
    """Plot a PCA embedding colored by one or more keys.

    Thin wrapper around ``plot_embedding`` that resolves the PCA basis name.

    Args:
        adata: AnnData object holding the embedding.
        subset: Optional subset suffix; when truthy the basis becomes
            ``"pca_{subset}"``, otherwise ``"pca"``.
        color: Key or keys used to color the points.
        output_dir: Directory the plot(s) are written into.
        prefix: Optional filename prefix forwarded to ``plot_embedding``.
        point_size: Marker size forwarded to ``plot_embedding``.
        alpha: Marker transparency forwarded to ``plot_embedding``.

    Returns:
        The value returned by ``plot_embedding``; presumably a mapping from
        color key to saved file path (per the annotation) — confirm there.
    """
    logger.info("Plotting PCA embedding to %s.", output_dir)

    if subset:
        pca_key = f"pca_{subset}"
    else:
        pca_key = "pca"

    # Forward point_size/alpha: they were previously accepted but silently ignored.
    return plot_embedding(
        adata,
        basis=pca_key,
        color=color,
        output_dir=output_dir,
        prefix=prefix,
        point_size=point_size,
        alpha=alpha,
    )
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
def plot_pca_grid(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    color: str | Sequence[str],
    output_dir: Path | str,
    prefix: str | None = None,
    ncols: int | None = None,
    point_size: float = 12,
    alpha: float = 0.8,
) -> Path | None:
    """Render a PCA embedding for several color keys as one grid figure.

    Resolves the PCA basis name and delegates all drawing to
    ``plot_embedding_grid``.

    Args:
        adata: AnnData object holding the embedding.
        subset: Optional subset suffix; a truthy value selects the basis
            ``"pca_{subset}"``, otherwise ``"pca"`` is used.
        color: Key or keys used to color the panels.
        output_dir: Directory the grid image is written into.
        prefix: Optional filename prefix forwarded to the grid plotter.
        ncols: Optional number of grid columns forwarded to the grid plotter.
        point_size: Marker size forwarded to the grid plotter.
        alpha: Marker transparency forwarded to the grid plotter.

    Returns:
        Whatever ``plot_embedding_grid`` returns (annotated ``Path | None``).
    """
    logger.info("Plotting PCA embedding grid to %s.", output_dir)

    # Subset-specific runs use a suffixed basis name.
    basis_name = "pca" if not subset else f"pca_{subset}"

    grid_options = dict(
        color=color,
        output_dir=output_dir,
        prefix=prefix,
        ncols=ncols,
        point_size=point_size,
        alpha=alpha,
    )
    return plot_embedding_grid(adata, basis=basis_name, **grid_options)
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def plot_pca_explained_variance(
    adata: "ad.AnnData",
    *,
    subset: str | None = None,
    output_dir: Path | str,
    pca_key: str = "pca",
    suffix: str | None = None,
    max_pcs: int | None = None,
) -> Path | None:
    """Plot cumulative explained variance for PCA results.

    The results are looked up in ``adata.uns[key]``. ``key`` is ``pca_key``,
    optionally extended to ``"{pca_key}_{subset}"`` and then ``"..._{suffix}"``.
    That entry must be a mapping containing a ``"variance_ratio"`` sequence.

    Args:
        adata: AnnData object containing PCA results in ``uns``.
        subset: Optional subset suffix used in key naming.
        output_dir: Directory to write the plot into (created if missing).
        pca_key: Base key in ``adata.uns`` storing PCA results.
        suffix: Optional suffix to append to the key.
        max_pcs: Optional cap on number of PCs to plot; must be >= 1 if given.

    Returns:
        Path to the saved ``{key}_explained_variance.png``. Returns None, after
        logging a warning, when the explained variance is unavailable or empty,
        or when ``max_pcs`` is not positive.
    """
    logger.info("Plotting PCA explained variance to %s.", output_dir)

    if subset:
        pca_key = f"{pca_key}_{subset}"
    if suffix:
        pca_key = f"{pca_key}_{suffix}"

    # One lookup covers both "key missing" and "entry lacks variance_ratio".
    pca_data = adata.uns[pca_key] if pca_key in adata.uns else None
    if not isinstance(pca_data, Mapping) or "variance_ratio" not in pca_data:
        logger.warning("Explained variance ratio not found in adata.uns[%s].", pca_key)
        return None

    variance_ratio = np.asarray(pca_data["variance_ratio"], dtype=float)
    if variance_ratio.size == 0:
        logger.warning("Explained variance ratio for %s is empty; skipping plot.", pca_key)
        return None

    if max_pcs is not None:
        # A non-positive cap would either plot nothing (0) or silently drop
        # trailing PCs via negative slicing; neither is a meaningful request.
        if max_pcs < 1:
            logger.warning("Invalid max_pcs=%s for %s; skipping plot.", max_pcs, pca_key)
            return None
        variance_ratio = variance_ratio[:max_pcs]

    cumulative = np.cumsum(variance_ratio)
    pcs = np.arange(1, len(variance_ratio) + 1)  # 1-based PC indices for the x-axis

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(pcs, cumulative, color="#DD8452", marker="o", label="Cumulative variance")
    ax.set_xlabel("Principal component")
    ax.set_ylabel("Cumulative explained variance")
    ax.set_ylim(0, 1.05)
    ax.grid(True, axis="y", alpha=0.3)
    ax.legend(frameon=False)
    fig.tight_layout()

    out_file = output_path / f"{pca_key}_explained_variance.png"
    fig.savefig(out_file, dpi=200)
    plt.close(fig)
    logger.info("Saved PCA explained variance plot to %s.", out_file)

    return out_file
|