smftools 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/chimeric_adata.py +1563 -0
- smftools/cli/helpers.py +49 -7
- smftools/cli/hmm_adata.py +250 -32
- smftools/cli/latent_adata.py +773 -0
- smftools/cli/load_adata.py +78 -74
- smftools/cli/preprocess_adata.py +122 -58
- smftools/cli/recipes.py +26 -0
- smftools/cli/spatial_adata.py +74 -112
- smftools/cli/variant_adata.py +423 -0
- smftools/cli_entry.py +52 -4
- smftools/config/conversion.yaml +1 -1
- smftools/config/deaminase.yaml +3 -0
- smftools/config/default.yaml +85 -12
- smftools/config/experiment_config.py +146 -1
- smftools/constants.py +69 -0
- smftools/hmm/HMM.py +88 -0
- smftools/hmm/call_hmm_peaks.py +1 -1
- smftools/informatics/__init__.py +6 -0
- smftools/informatics/bam_functions.py +358 -8
- smftools/informatics/binarize_converted_base_identities.py +2 -89
- smftools/informatics/converted_BAM_to_adata.py +636 -175
- smftools/informatics/h5ad_functions.py +198 -2
- smftools/informatics/modkit_extract_to_adata.py +1007 -425
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/metadata.py +1 -1
- smftools/plotting/__init__.py +26 -3
- smftools/plotting/autocorrelation_plotting.py +22 -4
- smftools/plotting/chimeric_plotting.py +1893 -0
- smftools/plotting/classifiers.py +28 -14
- smftools/plotting/general_plotting.py +62 -1583
- smftools/plotting/hmm_plotting.py +1670 -8
- smftools/plotting/latent_plotting.py +804 -0
- smftools/plotting/plotting_utils.py +243 -0
- smftools/plotting/position_stats.py +16 -8
- smftools/plotting/preprocess_plotting.py +281 -0
- smftools/plotting/qc_plotting.py +8 -3
- smftools/plotting/spatial_plotting.py +1134 -0
- smftools/plotting/variant_plotting.py +1231 -0
- smftools/preprocessing/__init__.py +4 -0
- smftools/preprocessing/append_base_context.py +18 -18
- smftools/preprocessing/append_mismatch_frequency_sites.py +187 -0
- smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
- smftools/preprocessing/append_variant_call_layer.py +480 -0
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/preprocessing/flag_duplicate_reads.py +4 -4
- smftools/preprocessing/invert_adata.py +1 -0
- smftools/readwrite.py +159 -99
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +10 -0
- smftools/tools/calculate_knn.py +121 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +130 -0
- smftools/tools/calculate_pca.py +180 -0
- smftools/tools/calculate_umap.py +79 -80
- smftools/tools/position_stats.py +4 -4
- smftools/tools/rolling_nn_distance.py +872 -0
- smftools/tools/sequence_alignment.py +140 -0
- smftools/tools/tensor_factorization.py +217 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/METADATA +9 -5
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/RECORD +66 -45
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.0.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/plotting/plotting_utils.py
ADDED

@@ -0,0 +1,243 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+import pandas as pd
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+sns = require("seaborn", extra="plotting", purpose="plot styling")
+
+logger = get_logger(__name__)
+
+if TYPE_CHECKING:
+    import anndata as ad
+
+
+def _fixed_tick_positions(n_positions: int, n_ticks: int) -> np.ndarray:
+    """
+    Return indices for ~n_ticks evenly spaced labels across [0, n_positions-1].
+    Always includes 0 and n_positions-1 when possible.
+    """
+    n_ticks = int(max(2, n_ticks))
+    if n_positions <= n_ticks:
+        return np.arange(n_positions)
+
+    pos = np.linspace(0, n_positions - 1, n_ticks)
+    return np.unique(np.round(pos).astype(int))
+
+
+def _select_labels(
+    subset: "ad.AnnData", sites: np.ndarray, reference: str, index_col_suffix: str | None
+) -> np.ndarray:
+    """
+    Select tick labels for the heatmap axis.
+
+    Parameters
+    ----------
+    subset : AnnData view
+        The per-bin subset of the AnnData.
+    sites : np.ndarray[int]
+        Indices of the subset.var positions to annotate.
+    reference : str
+        Reference name (e.g., '6B6_top').
+    index_col_suffix : None or str
+        If None → use subset.var_names
+        Else → use subset.var[f"{reference}_{index_col_suffix}"]
+
+    Returns
+    -------
+    np.ndarray[str]
+        The labels to use for tick positions.
+    """
+    if sites.size == 0:
+        return np.array([])
+
+    if index_col_suffix is None:
+        return subset.var_names[sites].astype(str)
+
+    colname = f"{reference}_{index_col_suffix}"
+
+    if colname not in subset.var:
+        raise KeyError(
+            f"index_col_suffix='{index_col_suffix}' requires var column '{colname}', "
+            f"but it is not present in adata.var."
+        )
+
+    labels = subset.var[colname].astype(str).values
+    return labels[sites]
+
+
+def normalized_mean(matrix: np.ndarray, *, ignore_nan: bool = True) -> np.ndarray:
+    """Compute normalized column means for a matrix.
+
+    Args:
+        matrix: Input matrix.
+
+    Returns:
+        1D array of normalized means.
+    """
+    mean = np.nanmean(matrix, axis=0) if ignore_nan else np.mean(matrix, axis=0)
+    denom = (mean.max() - mean.min()) + 1e-9
+    return (mean - mean.min()) / denom
+
+
+def _layer_to_numpy(
+    subset: "ad.AnnData",
+    layer_name: str,
+    sites: np.ndarray | None = None,
+    *,
+    fill_nan_strategy: str = "value",
+    fill_nan_value: float = -1,
+) -> np.ndarray:
+    """Return a (copied) numpy array for a layer with optional NaN filling."""
+    if sites is not None:
+        layer_data = subset[:, sites].layers[layer_name]
+    else:
+        layer_data = subset.layers[layer_name]
+
+    if hasattr(layer_data, "toarray"):
+        arr = layer_data.toarray()
+    else:
+        arr = np.asarray(layer_data)
+
+    arr = np.array(arr, copy=True)
+
+    if fill_nan_strategy == "none":
+        return arr
+
+    if fill_nan_strategy not in {"value", "col_mean"}:
+        raise ValueError("fill_nan_strategy must be 'none', 'value', or 'col_mean'.")
+
+    arr = arr.astype(float, copy=False)
+
+    if fill_nan_strategy == "value":
+        return np.where(np.isnan(arr), fill_nan_value, arr)
+
+    col_mean = np.nanmean(arr, axis=0)
+    if np.any(np.isnan(col_mean)):
+        col_mean = np.where(np.isnan(col_mean), fill_nan_value, col_mean)
+    return np.where(np.isnan(arr), col_mean, arr)
+
+
+def _infer_zero_is_valid(layer_name: str | None, matrix: np.ndarray) -> bool:
+    """Infer whether zeros should count as valid (unmethylated) values."""
+    if layer_name and "nan0_0minus1" in layer_name:
+        return False
+    if np.isnan(matrix).any():
+        return True
+    if np.any(matrix < 0):
+        return False
+    return True
+
+
+def methylation_fraction(
+    matrix: np.ndarray, *, ignore_nan: bool = True, zero_is_valid: bool = False
+) -> np.ndarray:
+    """
+    Fraction methylated per column.
+    Methylated = 1
+    Valid = finite AND not 0 (unless zero_is_valid=True)
+    """
+    matrix = np.asarray(matrix)
+    if not ignore_nan:
+        matrix = np.where(np.isnan(matrix), 0, matrix)
+    finite_mask = np.isfinite(matrix)
+    valid_mask = finite_mask if zero_is_valid else (finite_mask & (matrix != 0))
+    methyl_mask = (matrix == 1) & np.isfinite(matrix)
+
+    methylated = methyl_mask.sum(axis=0)
+    valid = valid_mask.sum(axis=0)
+
+    return np.divide(
+        methylated, valid, out=np.zeros_like(methylated, dtype=float), where=valid != 0
+    )
+
+
+def _methylation_fraction_for_layer(
+    matrix: np.ndarray,
+    layer_name: str | None,
+    *,
+    ignore_nan: bool = True,
+    zero_is_valid: bool | None = None,
+) -> np.ndarray:
+    """Compute methylation fractions with layer-aware zero handling."""
+    matrix = np.asarray(matrix)
+    if zero_is_valid is None:
+        zero_is_valid = _infer_zero_is_valid(layer_name, matrix)
+    return methylation_fraction(matrix, ignore_nan=ignore_nan, zero_is_valid=zero_is_valid)
+
+
+def clean_barplot(
+    ax,
+    mean_values,
+    title,
+    *,
+    y_max: float | None = 1.0,
+    y_label: str = "Mean",
+    y_ticks: list[float] | None = None,
+):
+    """Format a barplot with consistent axes and labels.
+
+    Args:
+        ax: Matplotlib axes.
+        mean_values: Values to plot.
+        title: Plot title.
+        y_max: Optional y-axis max; inferred from data if not provided.
+        y_label: Y-axis label.
+        y_ticks: Optional y-axis ticks.
+    """
+    logger.debug("Formatting barplot '%s' with %s values.", title, len(mean_values))
+    x = np.arange(len(mean_values))
+    ax.bar(x, mean_values, color="gray", width=1.0, align="edge")
+    ax.set_xlim(0, len(mean_values))
+    if y_ticks is None and y_max == 1.0:
+        y_ticks = [0.0, 0.5, 1.0]
+    if y_max is None:
+        y_max = np.nanmax(mean_values) if len(mean_values) else 1.0
+        if not np.isfinite(y_max) or y_max <= 0:
+            y_max = 1.0
+        y_max *= 1.05
+    ax.set_ylim(0, y_max)
+    if y_ticks is not None:
+        ax.set_yticks(y_ticks)
+    ax.set_ylabel(y_label)
+    ax.set_title(title, fontsize=12, pad=2)
+
+    for spine_name, spine in ax.spines.items():
+        spine.set_visible(spine_name == "left")
+
+    ax.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
+
+
+def make_row_colors(meta: pd.DataFrame) -> pd.DataFrame:
+    """
+    Convert metadata columns to RGB colors without invoking pandas Categorical.map
+    (MultiIndex-safe, category-safe).
+    """
+    row_colors = pd.DataFrame(index=meta.index)
+
+    for col in meta.columns:
+        s = meta[col].astype("object")
+
+        def _to_label(x: Any) -> str:
+            if x is None:
+                return "NA"
+            if isinstance(x, float) and np.isnan(x):
+                return "NA"
+            if isinstance(x, pd.MultiIndex):
+                return "MultiIndex"
+            if isinstance(x, tuple):
+                return "|".join(map(str, x))
+            return str(x)
+
+        labels = np.array([_to_label(x) for x in s.to_numpy()], dtype=object)
+        uniq = pd.unique(labels)
+        palette = dict(zip(uniq, sns.color_palette(n_colors=len(uniq))))
+
+        colors = [palette.get(lbl, (0.7, 0.7, 0.7)) for lbl in labels]
+        row_colors[col] = colors
+
+    return row_colors
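
A minimal usage sketch of these new helpers, assuming they are importable from smftools.plotting.plotting_utils (the module path is inferred from the file summary above and may differ in the released wheel):

import numpy as np
import pandas as pd

from smftools.plotting.plotting_utils import (  # assumed import path
    _fixed_tick_positions,
    make_row_colors,
    methylation_fraction,
    normalized_mean,
)

# Toy 4-read x 5-position matrix: 1 = methylated, 0 = unmethylated, NaN = no call.
mat = np.array(
    [
        [1.0, 0.0, np.nan, 1.0, 0.0],
        [1.0, 1.0, 0.0, np.nan, 0.0],
        [0.0, 1.0, 1.0, 1.0, np.nan],
        [np.nan, 0.0, 1.0, 1.0, 0.0],
    ]
)

frac = methylation_fraction(mat, zero_is_valid=True)  # per-column methylated / valid calls
norm = normalized_mean(mat)                           # min-max scaled column means
ticks = _fixed_tick_positions(n_positions=mat.shape[1], n_ticks=3)  # ~3 evenly spaced ticks

# Per-read metadata mapped to RGB row colors for a seaborn clustermap.
meta = pd.DataFrame({"Sample_Names": ["s1", "s1", "s2", "s2"]})
row_colors = make_row_colors(meta)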
smftools/plotting/position_stats.py
CHANGED

@@ -1,7 +1,10 @@
 from __future__ import annotations

+from smftools.logging_utils import get_logger
 from smftools.optional_imports import require

+logger = get_logger(__name__)
+

 def plot_volcano_relative_risk(
     results_dict,
@@ -29,10 +32,11 @@ def plot_volcano_relative_risk(

     plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")

+    logger.info("Plotting volcano relative risk plots.")
     for ref, group_results in results_dict.items():
         for group_label, (results_df, _) in group_results.items():
             if results_df.empty:
-
+                logger.warning("Skipping empty results for %s / %s.", ref, group_label)
                 continue

             # Split by site type
@@ -100,7 +104,7 @@ def plot_volcano_relative_risk(
             )
             out_file = os.path.join(save_path, f"{safe_name}.png")
             plt.savefig(out_file, dpi=300)
-
+            logger.info("Saved volcano relative risk plot to %s.", out_file)

     plt.show()

@@ -131,10 +135,11 @@ def plot_bar_relative_risk(

     plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")

+    logger.info("Plotting bar relative risk plots.")
     for ref, group_data in results_dict.items():
         for group_label, (df, _) in group_data.items():
             if df.empty:
-
+                logger.warning("Skipping empty result for %s / %s.", ref, group_label)
                 continue

             df = df.copy()
@@ -206,7 +211,7 @@ def plot_bar_relative_risk(
             )
             out_file = os.path.join(save_path, f"{safe_name}.png")
             plt.savefig(out_file, dpi=300)
-
+            logger.info("Saved bar relative risk plot to %s.", out_file)

     plt.show()

@@ -240,6 +245,8 @@ def plot_positionwise_matrix(
     plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
     sns = require("seaborn", extra="plotting", purpose="position stats plots")

+    logger.info("Plotting positionwise matrices for key '%s'.", key)
+
     def find_closest_index(index, target):
         """Find the index value closest to a target value."""
         index_vals = pd.to_numeric(index, errors="coerce")
@@ -357,7 +364,7 @@ def plot_positionwise_matrix(
                     va="center",
                     fontsize=10,
                 )
-
+                logger.warning("Error plotting line for %s=%s: %s", highlight_axis, pos, e)

         line_ax.set_title(f"{highlight_axis.capitalize()} Profile(s)")
         line_ax.set_xlabel(f"{'Column' if highlight_axis == 'row' else 'Row'} position")
@@ -373,7 +380,7 @@ def plot_positionwise_matrix(
         safe_name = group.replace("=", "").replace("__", "_").replace(",", "_")
         out_file = os.path.join(save_path, f"{key}_{safe_name}.png")
         plt.savefig(out_file, dpi=300)
-
+        logger.info("Saved positionwise matrix plot to %s.", out_file)

     plt.show()

@@ -423,6 +430,7 @@ def plot_positionwise_matrix_grid(
     grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="position stats plots")
     GridSpec = grid_spec.GridSpec

+    logger.info("Plotting positionwise matrix grid for key '%s'.", key)
     matrices = adata.uns[key]
     group_labels = list(matrices.keys())

@@ -515,7 +523,7 @@ def plot_positionwise_matrix_grid(
             os.makedirs(save_path, exist_ok=True)
             fname = outer_label.replace("_", "").replace("=", "") + ".png"
             plt.savefig(os.path.join(save_path, fname), dpi=300, bbox_inches="tight")
-
+            logger.info("Saved positionwise matrix grid plot to %s.", fname)

         plt.close(fig)

@@ -527,4 +535,4 @@ def plot_positionwise_matrix_grid(
     for outer_label in parsed["outer"].unique():
         plot_one_grid(outer_label)

-
+    logger.info("Finished plotting all grids.")
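
These changes route progress and skip messages through smftools.logging_utils.get_logger instead of passing silently. A minimal sketch for surfacing the new messages in a script or notebook, assuming get_logger returns standard logging.Logger objects named after the importing module:

import logging

# Emit timestamps, logger names, and levels for all INFO-and-above records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

# Narrow the output to the plotting modules shown in this diff if desired.
logging.getLogger("smftools.plotting").setLevel(logging.INFO)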
smftools/plotting/preprocess_plotting.py
ADDED

@@ -0,0 +1,281 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Dict, List, Sequence
+
+import numpy as np
+import pandas as pd
+import scipy.cluster.hierarchy as sch
+
+from smftools.logging_utils import get_logger
+from smftools.optional_imports import require
+
+colors = require("matplotlib.colors", extra="plotting", purpose="plot rendering")
+plt = require("matplotlib.pyplot", extra="plotting", purpose="plot rendering")
+sns = require("seaborn", extra="plotting", purpose="plot styling")
+
+logger = get_logger(__name__)
+
+
+def plot_read_span_quality_clustermaps(
+    adata,
+    sample_col: str = "Sample_Names",
+    reference_col: str = "Reference_strand",
+    quality_layer: str = "base_quality_scores",
+    read_span_layer: str = "read_span_mask",
+    quality_cmap: str = "viridis",
+    read_span_color: str = "#2ca25f",
+    max_nan_fraction: float | None = None,
+    min_quality: float | None = None,
+    min_length: int | None = None,
+    min_mapped_length_to_reference_length_ratio: float | None = None,
+    demux_types: Sequence[str] = ("single", "double", "already"),
+    max_reads: int | None = None,
+    xtick_step: int | None = None,
+    xtick_rotation: int = 90,
+    xtick_fontsize: int = 9,
+    show_position_axis: bool = False,
+    position_axis_tick_target: int = 25,
+    save_path: str | Path | None = None,
+) -> List[Dict[str, Any]]:
+    """Plot read-span mask and base quality clustermaps side by side.
+
+    Clustering is performed using the base-quality layer ordering, which is then
+    applied to the read-span mask to keep the two panels aligned.
+
+    Args:
+        adata: AnnData with read-span and base-quality layers.
+        sample_col: Column in ``adata.obs`` that identifies samples.
+        reference_col: Column in ``adata.obs`` that identifies references.
+        quality_layer: Layer name containing base-quality scores.
+        read_span_layer: Layer name containing read-span masks.
+        quality_cmap: Colormap for base-quality scores.
+        read_span_color: Color for read-span mask (1-values); 0-values are white.
+        max_nan_fraction: Optional maximum fraction of NaNs allowed per position; positions
+            above this threshold are excluded.
+        min_quality: Optional minimum read quality filter.
+        min_length: Optional minimum mapped length filter.
+        min_mapped_length_to_reference_length_ratio: Optional min length ratio filter.
+        demux_types: Allowed ``demux_type`` values, if present in ``adata.obs``.
+        max_reads: Optional maximum number of reads to plot per sample/reference.
+        xtick_step: Spacing between x-axis tick labels (None = no labels).
+        xtick_rotation: Rotation for x-axis tick labels.
+        xtick_fontsize: Font size for x-axis tick labels.
+        show_position_axis: Whether to draw a position axis with tick labels.
+        position_axis_tick_target: Approximate number of ticks to show when auto-sizing.
+        save_path: Optional output directory for saving plots.
+
+    Returns:
+        List of dictionaries with per-plot metadata and output paths.
+    """
+    logger.info("Plotting read span and quality clustermaps.")
+
+    def _mask_or_true(series_name: str, predicate):
+        if series_name not in adata.obs:
+            return pd.Series(True, index=adata.obs.index)
+        s = adata.obs[series_name]
+        try:
+            return predicate(s)
+        except Exception:
+            return pd.Series(True, index=s.index)
+
+    def _resolve_xtick_step(n_positions: int) -> int | None:
+        if xtick_step is not None:
+            return xtick_step
+        if not show_position_axis:
+            return None
+        return max(1, int(np.ceil(n_positions / position_axis_tick_target)))
+
+    def _fill_nan_with_col_means(matrix: np.ndarray) -> np.ndarray:
+        filled = matrix.copy()
+        col_means = np.nanmean(filled, axis=0)
+        col_means = np.where(np.isnan(col_means), 0.0, col_means)
+        nan_rows, nan_cols = np.where(np.isnan(filled))
+        filled[nan_rows, nan_cols] = col_means[nan_cols]
+        return filled
+
+    if quality_layer not in adata.layers:
+        raise KeyError(f"Layer '{quality_layer}' not found in adata.layers")
+    if read_span_layer not in adata.layers:
+        raise KeyError(f"Layer '{read_span_layer}' not found in adata.layers")
+    if max_nan_fraction is not None and not (0 <= max_nan_fraction <= 1):
+        raise ValueError("max_nan_fraction must be between 0 and 1.")
+    if position_axis_tick_target < 1:
+        raise ValueError("position_axis_tick_target must be at least 1.")
+
+    results: List[Dict[str, Any]] = []
+    save_path = Path(save_path) if save_path is not None else None
+    if save_path is not None:
+        save_path.mkdir(parents=True, exist_ok=True)
+
+    for col in (sample_col, reference_col):
+        if col not in adata.obs:
+            raise KeyError(f"{col} not in adata.obs")
+        if not isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
+            adata.obs[col] = adata.obs[col].astype("category")
+
+    for ref in adata.obs[reference_col].cat.categories:
+        for sample in adata.obs[sample_col].cat.categories:
+            qmask = _mask_or_true(
+                "read_quality",
+                (lambda s: s >= float(min_quality))
+                if (min_quality is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+            lm_mask = _mask_or_true(
+                "mapped_length",
+                (lambda s: s >= float(min_length))
+                if (min_length is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+            lrr_mask = _mask_or_true(
+                "mapped_length_to_reference_length_ratio",
+                (lambda s: s >= float(min_mapped_length_to_reference_length_ratio))
+                if (min_mapped_length_to_reference_length_ratio is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+            demux_mask = _mask_or_true(
+                "demux_type",
+                (lambda s: s.astype("string").isin(list(demux_types)))
+                if (demux_types is not None)
+                else (lambda s: pd.Series(True, index=s.index)),
+            )
+
+            row_mask = (
+                (adata.obs[reference_col] == ref)
+                & (adata.obs[sample_col] == sample)
+                & qmask
+                & lm_mask
+                & lrr_mask
+                & demux_mask
+            )
+            if not bool(row_mask.any()):
+                continue
+
+            subset = adata[row_mask, :].copy()
+            quality_matrix = np.asarray(subset.layers[quality_layer]).astype(float)
+            quality_matrix[quality_matrix < 0] = np.nan
+            read_span_matrix = np.asarray(subset.layers[read_span_layer]).astype(float)
+
+            if max_nan_fraction is not None:
+                nan_mask = np.isnan(quality_matrix) | np.isnan(read_span_matrix)
+                nan_fraction = nan_mask.mean(axis=0)
+                keep_columns = nan_fraction <= max_nan_fraction
+                if not np.any(keep_columns):
+                    continue
+                quality_matrix = quality_matrix[:, keep_columns]
+                read_span_matrix = read_span_matrix[:, keep_columns]
+                subset = subset[:, keep_columns].copy()
+
+            if max_reads is not None and quality_matrix.shape[0] > max_reads:
+                quality_matrix = quality_matrix[:max_reads]
+                read_span_matrix = read_span_matrix[:max_reads]
+                subset = subset[:max_reads, :].copy()
+
+            if quality_matrix.size == 0:
+                continue
+
+            quality_filled = _fill_nan_with_col_means(quality_matrix)
+            linkage = sch.linkage(quality_filled, method="ward")
+            order = sch.leaves_list(linkage)
+
+            quality_matrix = quality_matrix[order]
+            read_span_matrix = read_span_matrix[order]
+
+            fig, axes = plt.subplots(
+                nrows=2,
+                ncols=3,
+                figsize=(18, 6),
+                sharex="col",
+                gridspec_kw={"height_ratios": [1, 4], "width_ratios": [1, 1, 0.05]},
+            )
+            span_bar_ax, quality_bar_ax, bar_spacer_ax = axes[0]
+            span_ax, quality_ax, cbar_ax = axes[1]
+            bar_spacer_ax.set_axis_off()
+
+            span_mean = np.nanmean(read_span_matrix, axis=0)
+            quality_mean = np.nanmean(quality_matrix, axis=0)
+            bar_positions = np.arange(read_span_matrix.shape[1]) + 0.5
+            span_bar_ax.bar(
+                bar_positions,
+                span_mean,
+                color=read_span_color,
+                width=1.0,
+            )
+            span_bar_ax.set_title(f"{read_span_layer} mean")
+            span_bar_ax.set_xlim(0, read_span_matrix.shape[1])
+            span_bar_ax.tick_params(axis="x", labelbottom=False)
+
+            quality_bar_ax.bar(
+                bar_positions,
+                quality_mean,
+                color="#4c72b0",
+                width=1.0,
+            )
+            quality_bar_ax.set_title(f"{quality_layer} mean")
+            quality_bar_ax.set_xlim(0, quality_matrix.shape[1])
+            quality_bar_ax.tick_params(axis="x", labelbottom=False)
+
+            span_cmap = colors.ListedColormap(["white", read_span_color])
+            span_norm = colors.BoundaryNorm([-0.5, 0.5, 1.5], span_cmap.N)
+            sns.heatmap(
+                read_span_matrix,
+                cmap=span_cmap,
+                norm=span_norm,
+                ax=span_ax,
+                yticklabels=False,
+                cbar=False,
+            )
+            span_ax.set_title(read_span_layer)
+
+            sns.heatmap(
+                quality_matrix,
+                cmap=quality_cmap,
+                ax=quality_ax,
+                yticklabels=False,
+                cbar=True,
+                cbar_ax=cbar_ax,
+            )
+            quality_ax.set_title(quality_layer)
+
+            resolved_step = _resolve_xtick_step(quality_matrix.shape[1])
+            for axis in (span_ax, quality_ax):
+                if resolved_step is not None and resolved_step > 0:
+                    sites = np.arange(0, quality_matrix.shape[1], resolved_step)
+                    axis.set_xticks(sites)
+                    axis.set_xticklabels(
+                        subset.var_names[sites].astype(str),
+                        rotation=xtick_rotation,
+                        fontsize=xtick_fontsize,
+                    )
+                else:
+                    axis.set_xticks([])
+                if show_position_axis or xtick_step is not None:
+                    axis.set_xlabel("Position")
+
+            n_reads = quality_matrix.shape[0]
+            fig.suptitle(f"{sample} - {ref} - {n_reads} reads")
+            fig.tight_layout(rect=(0, 0, 1, 0.95))
+
+            out_file = None
+            if save_path is not None:
+                safe_name = f"{ref}__{sample}__read_span_quality".replace("=", "").replace(",", "_")
+                out_file = save_path / f"{safe_name}.png"
+                fig.savefig(out_file, dpi=300, bbox_inches="tight")
+                plt.close(fig)
+                logger.info("Saved read span/quality clustermap to %s.", out_file)
+            else:
+                plt.show()
+
+            results.append(
+                {
+                    "reference": str(ref),
+                    "sample": str(sample),
+                    "quality_layer": quality_layer,
+                    "read_span_layer": read_span_layer,
+                    "n_positions": int(quality_matrix.shape[1]),
+                    "output_path": str(out_file) if out_file is not None else None,
+                }
+            )
+
+    return results
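
A hypothetical call to the new plot_read_span_quality_clustermaps, importing it directly from the module added above (smftools.plotting may also re-export it, but that is not shown in this diff); the .h5ad path is a placeholder:

import anndata as ad

from smftools.plotting.preprocess_plotting import plot_read_span_quality_clustermaps

adata = ad.read_h5ad("experiment.h5ad")  # placeholder input file

results = plot_read_span_quality_clustermaps(
    adata,
    sample_col="Sample_Names",
    reference_col="Reference_strand",
    quality_layer="base_quality_scores",
    read_span_layer="read_span_mask",
    min_quality=10,          # keep reads with read_quality >= 10
    max_nan_fraction=0.5,    # drop positions with more than 50% missing values
    max_reads=2000,          # cap reads per sample/reference panel
    show_position_axis=True,
    save_path="qc_plots/read_span_quality",
)
for entry in results:
    print(entry["sample"], entry["reference"], entry["output_path"])
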
smftools/plotting/qc_plotting.py
CHANGED
@@ -5,10 +5,13 @@ import os
 import numpy as np
 import pandas as pd

+from smftools.logging_utils import get_logger
 from smftools.optional_imports import require

 plt = require("matplotlib.pyplot", extra="plotting", purpose="QC plots")

+logger = get_logger(__name__)
+

 def plot_read_qc_histograms(
     adata,
@@ -53,6 +56,7 @@ def plot_read_qc_histograms(
    dpi : int
        Figure resolution.
    """
+    logger.info("Plotting read QC histograms to %s.", outdir)
    os.makedirs(outdir, exist_ok=True)

    if sample_key not in adata.obs.columns:
@@ -60,7 +64,7 @@ def plot_read_qc_histograms(

    # Ensure sample_key is categorical for stable ordering
    samples = adata.obs[sample_key]
-    if not pd.
+    if not isinstance(samples.dtype, pd.CategoricalDtype):
        samples = samples.astype("category")
    sample_levels = list(samples.cat.categories)

@@ -69,14 +73,14 @@ def plot_read_qc_histograms(
    is_numeric = {}
    for key in obs_keys:
        if key not in adata.obs.columns:
-
+            logger.warning("'%s' not found in obs; skipping.", key)
            continue
        s = adata.obs[key]
        num = pd.api.types.is_numeric_dtype(s)
        valid_keys.append(key)
        is_numeric[key] = num
    if not valid_keys:
-
+        logger.warning("No valid obs_keys to plot.")
        return

    # Precompute global numeric ranges (after clipping) so rows share x-axis per column
@@ -174,6 +178,7 @@ def plot_read_qc_histograms(
        page = start // rows_per_fig + 1
        out_png = os.path.join(outdir, f"qc_grid_{_sanitize(sample_key)}_page{page}.png")
        plt.savefig(out_png, bbox_inches="tight")
+        logger.info("Saved QC histogram page to %s.", out_png)
        plt.close(fig)

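
The completed dtype check above uses isinstance with pd.CategoricalDtype, the same pattern applied in preprocess_plotting.py; a standalone illustration of that pattern (independent of smftools):

import pandas as pd

# Coerce a plain object Series to categorical only when it is not one already.
samples = pd.Series(["s1", "s2", "s1"])
if not isinstance(samples.dtype, pd.CategoricalDtype):
    samples = samples.astype("category")

print(list(samples.cat.categories))  # ['s1', 's2']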