smftools 0.2.5__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/__init__.py +39 -7
- smftools/_settings.py +2 -0
- smftools/_version.py +3 -1
- smftools/cli/__init__.py +1 -0
- smftools/cli/archived/cli_flows.py +2 -0
- smftools/cli/helpers.py +34 -6
- smftools/cli/hmm_adata.py +239 -33
- smftools/cli/latent_adata.py +318 -0
- smftools/cli/load_adata.py +167 -131
- smftools/cli/preprocess_adata.py +180 -53
- smftools/cli/spatial_adata.py +152 -100
- smftools/cli_entry.py +38 -1
- smftools/config/__init__.py +2 -0
- smftools/config/conversion.yaml +11 -1
- smftools/config/default.yaml +42 -2
- smftools/config/experiment_config.py +59 -1
- smftools/constants.py +65 -0
- smftools/datasets/__init__.py +2 -0
- smftools/hmm/HMM.py +97 -3
- smftools/hmm/__init__.py +24 -13
- smftools/hmm/archived/apply_hmm_batched.py +2 -0
- smftools/hmm/archived/calculate_distances.py +2 -0
- smftools/hmm/archived/call_hmm_peaks.py +2 -0
- smftools/hmm/archived/train_hmm.py +2 -0
- smftools/hmm/call_hmm_peaks.py +5 -2
- smftools/hmm/display_hmm.py +4 -1
- smftools/hmm/hmm_readwrite.py +7 -2
- smftools/hmm/nucleosome_hmm_refinement.py +2 -0
- smftools/informatics/__init__.py +59 -34
- smftools/informatics/archived/bam_conversion.py +2 -0
- smftools/informatics/archived/bam_direct.py +2 -0
- smftools/informatics/archived/basecall_pod5s.py +2 -0
- smftools/informatics/archived/basecalls_to_adata.py +2 -0
- smftools/informatics/archived/conversion_smf.py +2 -0
- smftools/informatics/archived/deaminase_smf.py +1 -0
- smftools/informatics/archived/direct_smf.py +2 -0
- smftools/informatics/archived/fast5_to_pod5.py +2 -0
- smftools/informatics/archived/helpers/archived/__init__.py +2 -0
- smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
- smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
- smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
- smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
- smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
- smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
- smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
- smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
- smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
- smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
- smftools/informatics/archived/helpers/archived/informatics.py +2 -0
- smftools/informatics/archived/helpers/archived/load_adata.py +2 -0
- smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
- smftools/informatics/archived/helpers/archived/modQC.py +2 -0
- smftools/informatics/archived/helpers/archived/modcall.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
- smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
- smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
- smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +2 -0
- smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
- smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
- smftools/informatics/archived/print_bam_query_seq.py +2 -0
- smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
- smftools/informatics/archived/subsample_pod5.py +2 -0
- smftools/informatics/bam_functions.py +1093 -176
- smftools/informatics/basecalling.py +2 -0
- smftools/informatics/bed_functions.py +271 -61
- smftools/informatics/binarize_converted_base_identities.py +3 -0
- smftools/informatics/complement_base_list.py +2 -0
- smftools/informatics/converted_BAM_to_adata.py +641 -176
- smftools/informatics/fasta_functions.py +94 -10
- smftools/informatics/h5ad_functions.py +123 -4
- smftools/informatics/modkit_extract_to_adata.py +1019 -431
- smftools/informatics/modkit_functions.py +2 -0
- smftools/informatics/ohe.py +2 -0
- smftools/informatics/pod5_functions.py +3 -2
- smftools/informatics/sequence_encoding.py +72 -0
- smftools/logging_utils.py +21 -2
- smftools/machine_learning/__init__.py +22 -6
- smftools/machine_learning/data/__init__.py +2 -0
- smftools/machine_learning/data/anndata_data_module.py +18 -4
- smftools/machine_learning/data/preprocessing.py +2 -0
- smftools/machine_learning/evaluation/__init__.py +2 -0
- smftools/machine_learning/evaluation/eval_utils.py +2 -0
- smftools/machine_learning/evaluation/evaluators.py +14 -9
- smftools/machine_learning/inference/__init__.py +2 -0
- smftools/machine_learning/inference/inference_utils.py +2 -0
- smftools/machine_learning/inference/lightning_inference.py +6 -1
- smftools/machine_learning/inference/sklearn_inference.py +2 -0
- smftools/machine_learning/inference/sliding_window_inference.py +2 -0
- smftools/machine_learning/models/__init__.py +2 -0
- smftools/machine_learning/models/base.py +7 -2
- smftools/machine_learning/models/cnn.py +7 -2
- smftools/machine_learning/models/lightning_base.py +16 -11
- smftools/machine_learning/models/mlp.py +5 -1
- smftools/machine_learning/models/positional.py +7 -2
- smftools/machine_learning/models/rnn.py +5 -1
- smftools/machine_learning/models/sklearn_models.py +14 -9
- smftools/machine_learning/models/transformer.py +7 -2
- smftools/machine_learning/models/wrappers.py +6 -2
- smftools/machine_learning/training/__init__.py +2 -0
- smftools/machine_learning/training/train_lightning_model.py +13 -3
- smftools/machine_learning/training/train_sklearn_model.py +2 -0
- smftools/machine_learning/utils/__init__.py +2 -0
- smftools/machine_learning/utils/device.py +5 -1
- smftools/machine_learning/utils/grl.py +5 -1
- smftools/metadata.py +1 -1
- smftools/optional_imports.py +31 -0
- smftools/plotting/__init__.py +41 -31
- smftools/plotting/autocorrelation_plotting.py +9 -5
- smftools/plotting/classifiers.py +16 -4
- smftools/plotting/general_plotting.py +2415 -629
- smftools/plotting/hmm_plotting.py +97 -9
- smftools/plotting/position_stats.py +15 -7
- smftools/plotting/qc_plotting.py +6 -1
- smftools/preprocessing/__init__.py +36 -37
- smftools/preprocessing/append_base_context.py +17 -17
- smftools/preprocessing/append_mismatch_frequency_sites.py +158 -0
- smftools/preprocessing/archived/add_read_length_and_mapping_qc.py +2 -0
- smftools/preprocessing/archived/calculate_complexity.py +2 -0
- smftools/preprocessing/archived/mark_duplicates.py +2 -0
- smftools/preprocessing/archived/preprocessing.py +2 -0
- smftools/preprocessing/archived/remove_duplicates.py +2 -0
- smftools/preprocessing/binary_layers_to_ohe.py +2 -1
- smftools/preprocessing/calculate_complexity_II.py +4 -1
- smftools/preprocessing/calculate_consensus.py +1 -1
- smftools/preprocessing/calculate_pairwise_differences.py +2 -0
- smftools/preprocessing/calculate_pairwise_hamming_distances.py +3 -0
- smftools/preprocessing/calculate_position_Youden.py +9 -2
- smftools/preprocessing/calculate_read_modification_stats.py +6 -1
- smftools/preprocessing/filter_reads_on_length_quality_mapping.py +2 -0
- smftools/preprocessing/filter_reads_on_modification_thresholds.py +2 -0
- smftools/preprocessing/flag_duplicate_reads.py +42 -54
- smftools/preprocessing/make_dirs.py +2 -1
- smftools/preprocessing/min_non_diagonal.py +2 -0
- smftools/preprocessing/recipes.py +2 -0
- smftools/readwrite.py +53 -17
- smftools/schema/anndata_schema_v1.yaml +15 -1
- smftools/tools/__init__.py +30 -18
- smftools/tools/archived/apply_hmm.py +2 -0
- smftools/tools/archived/classifiers.py +2 -0
- smftools/tools/archived/classify_methylated_features.py +2 -0
- smftools/tools/archived/classify_non_methylated_features.py +2 -0
- smftools/tools/archived/subset_adata_v1.py +2 -0
- smftools/tools/archived/subset_adata_v2.py +2 -0
- smftools/tools/calculate_leiden.py +57 -0
- smftools/tools/calculate_nmf.py +119 -0
- smftools/tools/calculate_umap.py +93 -8
- smftools/tools/cluster_adata_on_methylation.py +7 -1
- smftools/tools/position_stats.py +17 -27
- smftools/tools/rolling_nn_distance.py +235 -0
- smftools/tools/tensor_factorization.py +169 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/METADATA +69 -33
- smftools-0.3.1.dist-info/RECORD +189 -0
- smftools-0.2.5.dist-info/RECORD +0 -181
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/WHEEL +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/entry_points.txt +0 -0
- {smftools-0.2.5.dist-info → smftools-0.3.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,9 +1,21 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import math
|
|
2
|
-
from typing import Optional, Tuple, Union
|
|
4
|
+
from typing import Optional, Sequence, Tuple, Union
|
|
3
5
|
|
|
4
|
-
import matplotlib.pyplot as plt
|
|
5
6
|
import numpy as np
|
|
6
|
-
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
from smftools.optional_imports import require
|
|
10
|
+
|
|
11
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM plots")
|
|
12
|
+
mpl_colors = require("matplotlib.colors", extra="plotting", purpose="HMM plots")
|
|
13
|
+
pdf_backend = require(
|
|
14
|
+
"matplotlib.backends.backend_pdf",
|
|
15
|
+
extra="plotting",
|
|
16
|
+
purpose="PDF output",
|
|
17
|
+
)
|
|
18
|
+
PdfPages = pdf_backend.PdfPages
|
|
7
19
|
|
|
8
20
|
|
|
9
21
|
def plot_hmm_size_contours(
|
|
@@ -22,6 +34,9 @@ def plot_hmm_size_contours(
|
|
|
22
34
|
dpi: int = 150,
|
|
23
35
|
vmin: Optional[float] = None,
|
|
24
36
|
vmax: Optional[float] = None,
|
|
37
|
+
feature_ranges: Optional[Sequence[Tuple[int, int, str]]] = None,
|
|
38
|
+
zero_color: str = "#f5f1e8",
|
|
39
|
+
nan_color: str = "#E6E6E6",
|
|
25
40
|
# ---------------- smoothing params ----------------
|
|
26
41
|
smoothing_sigma: Optional[Union[float, Tuple[float, float]]] = None,
|
|
27
42
|
normalize_after_smoothing: bool = True,
|
|
@@ -30,6 +45,9 @@ def plot_hmm_size_contours(
|
|
|
30
45
|
"""
|
|
31
46
|
Create contour/pcolormesh plots of P(length | position) using a length-encoded HMM layer.
|
|
32
47
|
Optional Gaussian smoothing applied to the 2D probability grid before plotting.
|
|
48
|
+
When feature_ranges is provided, each length row is assigned a base color based
|
|
49
|
+
on the matching (min_len, max_len) range and the probability value modulates
|
|
50
|
+
the color intensity.
|
|
33
51
|
|
|
34
52
|
smoothing_sigma: None or 0 -> no smoothing.
|
|
35
53
|
float -> same sigma applied to (length_axis, position_axis)
|
|
@@ -38,6 +56,51 @@ def plot_hmm_size_contours(
|
|
|
38
56
|
|
|
39
57
|
Other args are the same as prior function.
|
|
40
58
|
"""
|
|
59
|
+
feature_ranges = tuple(feature_ranges or ())
|
|
60
|
+
|
|
61
|
+
def _resolve_length_color(length: int, fallback: str) -> Tuple[float, float, float, float]:
|
|
62
|
+
for min_len, max_len, color in feature_ranges:
|
|
63
|
+
if min_len <= length <= max_len:
|
|
64
|
+
return mpl_colors.to_rgba(color)
|
|
65
|
+
return mpl_colors.to_rgba(fallback)
|
|
66
|
+
|
|
67
|
+
def _build_length_facecolors(
|
|
68
|
+
Z_values: np.ndarray,
|
|
69
|
+
lengths: np.ndarray,
|
|
70
|
+
fallback_color: str,
|
|
71
|
+
*,
|
|
72
|
+
vmin_local: Optional[float],
|
|
73
|
+
vmax_local: Optional[float],
|
|
74
|
+
) -> np.ndarray:
|
|
75
|
+
zero_rgba = np.array(mpl_colors.to_rgba(zero_color))
|
|
76
|
+
nan_rgba = np.array(mpl_colors.to_rgba(nan_color))
|
|
77
|
+
base_colors = np.array(
|
|
78
|
+
[_resolve_length_color(int(length), fallback_color) for length in lengths],
|
|
79
|
+
dtype=float,
|
|
80
|
+
)
|
|
81
|
+
base_colors[:, 3] = 1.0
|
|
82
|
+
|
|
83
|
+
scale = np.array(Z_values, copy=True, dtype=float)
|
|
84
|
+
finite_mask = np.isfinite(scale)
|
|
85
|
+
if not finite_mask.any():
|
|
86
|
+
facecolors = np.zeros(scale.shape + (4,), dtype=float)
|
|
87
|
+
facecolors[:] = nan_rgba
|
|
88
|
+
return facecolors.reshape(-1, 4)
|
|
89
|
+
|
|
90
|
+
vmin_use = np.nanmin(scale) if vmin_local is None else vmin_local
|
|
91
|
+
vmax_use = np.nanmax(scale) if vmax_local is None else vmax_local
|
|
92
|
+
denom = vmax_use - vmin_use
|
|
93
|
+
if denom <= 0:
|
|
94
|
+
norm = np.zeros_like(scale)
|
|
95
|
+
else:
|
|
96
|
+
norm = (scale - vmin_use) / denom
|
|
97
|
+
norm = np.clip(norm, 0, 1)
|
|
98
|
+
|
|
99
|
+
row_colors = base_colors[:, None, :]
|
|
100
|
+
facecolors = zero_rgba + norm[..., None] * (row_colors - zero_rgba)
|
|
101
|
+
facecolors[..., 3] = 1.0
|
|
102
|
+
facecolors[~finite_mask] = nan_rgba
|
|
103
|
+
return facecolors.reshape(-1, 4)
|
|
41
104
|
|
|
42
105
|
# --- helper: gaussian smoothing (scipy fallback -> numpy separable conv) ---
|
|
43
106
|
def _gaussian_1d_kernel(sigma: float, eps: float = 1e-12):
|
|
@@ -140,7 +203,8 @@ def plot_hmm_size_contours(
|
|
|
140
203
|
figs = []
|
|
141
204
|
|
|
142
205
|
# decide global max length to allocate y axis (cap to avoid huge memory)
|
|
143
|
-
|
|
206
|
+
finite_lengths = full_layer[np.isfinite(full_layer) & (full_layer > 0)]
|
|
207
|
+
observed_max_len = int(np.nanmax(finite_lengths)) if finite_lengths.size > 0 else 0
|
|
144
208
|
if max_length_cap is None:
|
|
145
209
|
max_len = observed_max_len
|
|
146
210
|
else:
|
|
@@ -195,10 +259,15 @@ def plot_hmm_size_contours(
|
|
|
195
259
|
ax.text(0.5, 0.5, "no data", ha="center", va="center")
|
|
196
260
|
ax.set_title(f"{sample} / {ref}")
|
|
197
261
|
continue
|
|
262
|
+
valid_lengths = sub[np.isfinite(sub) & (sub > 0)]
|
|
263
|
+
if valid_lengths.size == 0:
|
|
264
|
+
ax.text(0.5, 0.5, "no data", ha="center", va="center")
|
|
265
|
+
ax.set_title(f"{sample} / {ref}")
|
|
266
|
+
continue
|
|
198
267
|
|
|
199
268
|
# compute counts per length per position
|
|
200
269
|
n_positions = sub.shape[1]
|
|
201
|
-
max_len_local = int(
|
|
270
|
+
max_len_local = int(valid_lengths.max()) if valid_lengths.size > 0 else 0
|
|
202
271
|
max_len_here = min(max_len, max_len_local)
|
|
203
272
|
|
|
204
273
|
lengths_range = np.arange(1, max_len_here + 1, dtype=int)
|
|
@@ -209,7 +278,7 @@ def plot_hmm_size_contours(
|
|
|
209
278
|
# fill Z by efficient bincount across columns
|
|
210
279
|
for j in range(n_positions):
|
|
211
280
|
col_vals = sub[:, j]
|
|
212
|
-
pos_vals = col_vals[col_vals > 0].astype(int)
|
|
281
|
+
pos_vals = col_vals[np.isfinite(col_vals) & (col_vals > 0)].astype(int)
|
|
213
282
|
if pos_vals.size == 0:
|
|
214
283
|
continue
|
|
215
284
|
clipped = np.clip(pos_vals, 1, max_len_here)
|
|
@@ -248,9 +317,28 @@ def plot_hmm_size_contours(
|
|
|
248
317
|
dy = 1.0
|
|
249
318
|
y_edges = np.concatenate([y - 0.5, [y[-1] + 0.5]])
|
|
250
319
|
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
320
|
+
if feature_ranges:
|
|
321
|
+
fallback_color = mpl_colors.to_rgba(plt.get_cmap(cmap)(1.0))
|
|
322
|
+
facecolors = _build_length_facecolors(
|
|
323
|
+
Z_plot,
|
|
324
|
+
lengths_range,
|
|
325
|
+
fallback_color,
|
|
326
|
+
vmin_local=vmin,
|
|
327
|
+
vmax_local=vmax,
|
|
328
|
+
)
|
|
329
|
+
pcm = ax.pcolormesh(
|
|
330
|
+
x_edges,
|
|
331
|
+
y_edges,
|
|
332
|
+
Z_plot,
|
|
333
|
+
shading="auto",
|
|
334
|
+
vmin=vmin,
|
|
335
|
+
vmax=vmax,
|
|
336
|
+
facecolors=facecolors,
|
|
337
|
+
)
|
|
338
|
+
else:
|
|
339
|
+
pcm = ax.pcolormesh(
|
|
340
|
+
x_edges, y_edges, Z_plot, cmap=cmap, shading="auto", vmin=vmin, vmax=vmax
|
|
341
|
+
)
|
|
254
342
|
ax.set_title(f"{sample} / {ref}")
|
|
255
343
|
ax.set_ylabel("length")
|
|
256
344
|
if i_row == rows_on_page - 1:
|
|
@@ -1,3 +1,8 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from smftools.optional_imports import require
|
|
4
|
+
|
|
5
|
+
|
|
1
6
|
def plot_volcano_relative_risk(
|
|
2
7
|
results_dict,
|
|
3
8
|
save_path=None,
|
|
@@ -22,7 +27,7 @@ def plot_volcano_relative_risk(
|
|
|
22
27
|
"""
|
|
23
28
|
import os
|
|
24
29
|
|
|
25
|
-
|
|
30
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
|
|
26
31
|
|
|
27
32
|
for ref, group_results in results_dict.items():
|
|
28
33
|
for group_label, (results_df, _) in group_results.items():
|
|
@@ -124,7 +129,7 @@ def plot_bar_relative_risk(
|
|
|
124
129
|
"""
|
|
125
130
|
import os
|
|
126
131
|
|
|
127
|
-
|
|
132
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="relative risk plots")
|
|
128
133
|
|
|
129
134
|
for ref, group_data in results_dict.items():
|
|
130
135
|
for group_label, (df, _) in group_data.items():
|
|
@@ -229,10 +234,11 @@ def plot_positionwise_matrix(
|
|
|
229
234
|
"""
|
|
230
235
|
import os
|
|
231
236
|
|
|
232
|
-
import matplotlib.pyplot as plt
|
|
233
237
|
import numpy as np
|
|
234
238
|
import pandas as pd
|
|
235
|
-
|
|
239
|
+
|
|
240
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
|
|
241
|
+
sns = require("seaborn", extra="plotting", purpose="position stats plots")
|
|
236
242
|
|
|
237
243
|
def find_closest_index(index, target):
|
|
238
244
|
"""Find the index value closest to a target value."""
|
|
@@ -408,12 +414,14 @@ def plot_positionwise_matrix_grid(
|
|
|
408
414
|
"""
|
|
409
415
|
import os
|
|
410
416
|
|
|
411
|
-
import matplotlib.pyplot as plt
|
|
412
417
|
import numpy as np
|
|
413
418
|
import pandas as pd
|
|
414
|
-
import seaborn as sns
|
|
415
419
|
from joblib import Parallel, delayed
|
|
416
|
-
|
|
420
|
+
|
|
421
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="position stats plots")
|
|
422
|
+
sns = require("seaborn", extra="plotting", purpose="position stats plots")
|
|
423
|
+
grid_spec = require("matplotlib.gridspec", extra="plotting", purpose="position stats plots")
|
|
424
|
+
GridSpec = grid_spec.GridSpec
|
|
417
425
|
|
|
418
426
|
matrices = adata.uns[key]
|
|
419
427
|
group_labels = list(matrices.keys())
|
smftools/plotting/qc_plotting.py
CHANGED
|
@@ -1,9 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import os
|
|
2
4
|
|
|
3
|
-
import matplotlib.pyplot as plt
|
|
4
5
|
import numpy as np
|
|
5
6
|
import pandas as pd
|
|
6
7
|
|
|
8
|
+
from smftools.optional_imports import require
|
|
9
|
+
|
|
10
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="QC plots")
|
|
11
|
+
|
|
7
12
|
|
|
8
13
|
def plot_read_qc_histograms(
|
|
9
14
|
adata,
|
|
@@ -1,38 +1,37 @@
|
|
|
1
|
-
from
|
|
2
|
-
from .append_binary_layer_by_base_context import append_binary_layer_by_base_context
|
|
3
|
-
from .binarize import binarize_adata
|
|
4
|
-
from .binarize_on_Youden import binarize_on_Youden
|
|
5
|
-
from .calculate_complexity_II import calculate_complexity_II
|
|
6
|
-
from .calculate_coverage import calculate_coverage
|
|
7
|
-
from .calculate_position_Youden import calculate_position_Youden
|
|
8
|
-
from .calculate_read_length_stats import calculate_read_length_stats
|
|
9
|
-
from .calculate_read_modification_stats import calculate_read_modification_stats
|
|
10
|
-
from .clean_NaN import clean_NaN
|
|
11
|
-
from .filter_adata_by_nan_proportion import filter_adata_by_nan_proportion
|
|
12
|
-
from .filter_reads_on_length_quality_mapping import filter_reads_on_length_quality_mapping
|
|
13
|
-
from .filter_reads_on_modification_thresholds import filter_reads_on_modification_thresholds
|
|
14
|
-
from .flag_duplicate_reads import flag_duplicate_reads
|
|
15
|
-
from .invert_adata import invert_adata
|
|
16
|
-
from .load_sample_sheet import load_sample_sheet
|
|
17
|
-
from .reindex_references_adata import reindex_references_adata
|
|
18
|
-
from .subsample_adata import subsample_adata
|
|
1
|
+
from __future__ import annotations
|
|
19
2
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
"
|
|
24
|
-
"
|
|
25
|
-
"
|
|
26
|
-
"
|
|
27
|
-
"
|
|
28
|
-
"
|
|
29
|
-
"
|
|
30
|
-
"
|
|
31
|
-
"
|
|
32
|
-
"
|
|
33
|
-
"
|
|
34
|
-
"
|
|
35
|
-
"
|
|
36
|
-
"
|
|
37
|
-
"
|
|
38
|
-
|
|
3
|
+
from importlib import import_module
|
|
4
|
+
|
|
5
|
+
# Public name -> module that defines it. Submodules are not imported until the
# corresponding name is first accessed (see ``__getattr__`` below).
_LAZY_ATTRS = {
    "append_base_context": "smftools.preprocessing.append_base_context",
    "append_binary_layer_by_base_context": "smftools.preprocessing.append_binary_layer_by_base_context",
    "append_mismatch_frequency_sites": "smftools.preprocessing.append_mismatch_frequency_sites",
    "binarize_adata": "smftools.preprocessing.binarize",
    "binarize_on_Youden": "smftools.preprocessing.binarize_on_Youden",
    "calculate_complexity_II": "smftools.preprocessing.calculate_complexity_II",
    "calculate_coverage": "smftools.preprocessing.calculate_coverage",
    "calculate_position_Youden": "smftools.preprocessing.calculate_position_Youden",
    "calculate_read_length_stats": "smftools.preprocessing.calculate_read_length_stats",
    "calculate_read_modification_stats": "smftools.preprocessing.calculate_read_modification_stats",
    "clean_NaN": "smftools.preprocessing.clean_NaN",
    "filter_adata_by_nan_proportion": "smftools.preprocessing.filter_adata_by_nan_proportion",
    "filter_reads_on_length_quality_mapping": "smftools.preprocessing.filter_reads_on_length_quality_mapping",
    "filter_reads_on_modification_thresholds": "smftools.preprocessing.filter_reads_on_modification_thresholds",
    "flag_duplicate_reads": "smftools.preprocessing.flag_duplicate_reads",
    "invert_adata": "smftools.preprocessing.invert_adata",
    "load_sample_sheet": "smftools.preprocessing.load_sample_sheet",
    "reindex_references_adata": "smftools.preprocessing.reindex_references_adata",
    "subsample_adata": "smftools.preprocessing.subsample_adata",
}


def __getattr__(name: str):
    """Import and return a lazily-exposed preprocessing function (PEP 562 hook).

    Python calls this only when ``name`` is not already a module global. The
    resolved object is cached in ``globals()`` so later lookups bypass this hook.

    Raises:
        AttributeError: If ``name`` is not one of the lazily-exposed names.
    """
    if name in _LAZY_ATTRS:
        module = import_module(_LAZY_ATTRS[name])
        attr = getattr(module, name)
        globals()[name] = attr
        return attr
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


def __dir__() -> list[str]:
    """List module globals plus the lazily-exposed names.

    Without this hook, ``dir()`` and interactive tab completion would not show
    the public functions until each had been accessed once.
    """
    return sorted(set(globals()) | set(_LAZY_ATTRS))


__all__ = list(_LAZY_ATTRS.keys())
|
|
@@ -133,23 +133,23 @@ def append_base_context(
|
|
|
133
133
|
adata.var[f"{ref}_{site_type}_valid_coverage"] = (
|
|
134
134
|
(adata.var[f"{ref}_{site_type}"]) & (adata.var[f"position_in_{ref}"])
|
|
135
135
|
)
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
else:
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
if native:
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
else:
|
|
152
|
-
|
|
136
|
+
# if native:
|
|
137
|
+
# adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
|
|
138
|
+
# :, adata.var[f"{ref}_{site_type}_valid_coverage"]
|
|
139
|
+
# ].layers["binarized_methylation"]
|
|
140
|
+
# else:
|
|
141
|
+
# adata.obsm[f"{ref}_{site_type}_valid_coverage"] = adata[
|
|
142
|
+
# :, adata.var[f"{ref}_{site_type}_valid_coverage"]
|
|
143
|
+
# ].X
|
|
144
|
+
# else:
|
|
145
|
+
# pass
|
|
146
|
+
|
|
147
|
+
# if native:
|
|
148
|
+
# adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].layers[
|
|
149
|
+
# "binarized_methylation"
|
|
150
|
+
# ]
|
|
151
|
+
# else:
|
|
152
|
+
# adata.obsm[f"{ref}_{site_type}"] = adata[:, adata.var[f"{ref}_{site_type}"]].X
|
|
153
153
|
|
|
154
154
|
# mark as done
|
|
155
155
|
adata.uns[uns_flag] = True
|
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, Iterable, Sequence
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from smftools.constants import MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT
|
|
9
|
+
from smftools.logging_utils import get_logger
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
import anndata as ad
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _dense_rows(layer, row_mask: np.ndarray) -> np.ndarray:
    """Return the rows of ``layer`` selected by ``row_mask`` as a dense ndarray.

    ``np.asarray`` on a scipy sparse matrix yields a 0-d object array rather than
    its values, so any selection exposing ``toarray`` is densified explicitly.
    """
    subset = layer[row_mask]
    if hasattr(subset, "toarray"):
        subset = subset.toarray()
    return np.asarray(subset)


def _reference_position_mask(var: pd.DataFrame, ref, n_vars: int) -> np.ndarray:
    """Return a boolean mask of var positions that belong to ``ref``.

    Reads ``var[f"position_in_{ref}"]`` when present. Missing values (NaN/NA) are
    treated as False; a plain ``astype(bool)`` would coerce NaN to True. When the
    column is absent, every position is considered part of the reference.
    """
    column = var.get(f"position_in_{ref}")
    if column is None:
        return np.ones(n_vars, dtype=bool)
    present = column.notna().to_numpy()
    positions = np.zeros(n_vars, dtype=bool)
    positions[present] = column[present].astype(bool).to_numpy()
    return positions


def _write_mismatch_columns(
    adata: "ad.AnnData",
    ref,
    frequency_values: np.ndarray,
    variable_flags: np.ndarray,
    base_frequencies: list[list[tuple[str, float]]],
) -> None:
    """Store the three per-reference mismatch columns on ``adata.var``."""
    index = adata.var.index
    adata.var[f"{ref}_mismatch_frequency"] = pd.Series(frequency_values, index=index)
    adata.var[f"{ref}_variable_sequence_site"] = pd.Series(variable_flags, index=index)
    adata.var[f"{ref}_mismatch_base_frequencies"] = pd.Series(base_frequencies, index=index)


def append_mismatch_frequency_sites(
    adata: "ad.AnnData",
    ref_column: str = "Reference_strand",
    mismatch_layer: str = "mismatch_integer_encoding",
    read_span_layer: str = "read_span_mask",
    mismatch_frequency_range: Sequence[float] | None = (0.05, 0.95),
    uns_flag: str = "append_mismatch_frequency_sites_performed",
    force_redo: bool = False,
    bypass: bool = False,
) -> None:
    """Append mismatch frequency metadata and variable-site flags per reference.

    For every reference in ``adata.obs[ref_column]`` three ``adata.var`` columns
    are written:

    * ``{ref}_mismatch_frequency``: fraction of covered reads whose code is
      neither the ``N`` nor the ``PAD`` code. NaN where no read is covered or the
      position is not part of the reference.
    * ``{ref}_variable_sequence_site``: True where that frequency lies within
      ``mismatch_frequency_range`` (inclusive) at an in-reference position.
    * ``{ref}_mismatch_base_frequencies``: per-position list of
      ``(base_label, frequency)`` tuples for each base code observed.

    Args:
        adata: AnnData object.
        ref_column: Obs column defining references. Categorical columns use their
            categories (unused categories get all-NaN columns); other columns use
            their observed non-null values.
        mismatch_layer: Layer containing mismatch integer encodings (dense or sparse).
        read_span_layer: Layer containing read span masks (1=covered, 0=not covered).
        mismatch_frequency_range: Lower/upper bounds (inclusive) for variable site
            flagging. ``None`` means ``(0.0, 1.0)``.
        uns_flag: Flag in ``adata.uns`` indicating prior completion.
        force_redo: Whether to rerun even if ``uns_flag`` is set.
        bypass: Whether to skip running this step.

    Raises:
        ValueError: If ``mismatch_frequency_range`` does not hold exactly two values
            or its lower bound exceeds its upper bound.
    """
    if bypass:
        return

    already = bool(adata.uns.get(uns_flag, False))
    if already and not force_redo:
        return

    if mismatch_layer not in adata.layers:
        logger.debug(
            "Mismatch layer '%s' not found; skipping mismatch frequency step.", mismatch_layer
        )
        return

    mismatch_map = adata.uns.get("mismatch_integer_encoding_map", {})
    if not mismatch_map:
        logger.debug("Mismatch encoding map not found; skipping mismatch frequency step.")
        return

    # Consult the fallback constants only when the stored map lacks the key
    # (``dict.get`` would evaluate the default lookup on every call).
    n_value = mismatch_map["N"] if "N" in mismatch_map else MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT["N"]
    pad_value = (
        mismatch_map["PAD"] if "PAD" in mismatch_map else MODKIT_EXTRACT_SEQUENCE_BASE_TO_INT["PAD"]
    )

    base_int_to_label = {
        int(value): str(base)
        for base, value in mismatch_map.items()
        if base not in {"N", "PAD"} and isinstance(value, (int, np.integer))
    }
    if not base_int_to_label:
        logger.debug("Mismatch encoding map missing base labels; skipping mismatch frequency step.")
        return

    has_span_mask = read_span_layer in adata.layers
    if not has_span_mask:
        logger.debug(
            "Read span mask '%s' not found; mismatch frequencies will be computed over all reads.",
            read_span_layer,
        )

    ref_series = adata.obs[ref_column]
    if isinstance(ref_series.dtype, pd.CategoricalDtype):
        references = ref_series.cat.categories
    else:
        # ``.cat`` only exists on categorical columns; fall back to observed values.
        references = pd.unique(ref_series.dropna())

    if mismatch_frequency_range is None:
        mismatch_frequency_range = (0.0, 1.0)
    lower_bound, upper_bound = mismatch_frequency_range
    if lower_bound > upper_bound:
        raise ValueError(
            f"mismatch_frequency_range lower bound ({lower_bound}) exceeds "
            f"upper bound ({upper_bound})."
        )

    n_vars = adata.shape[1]
    ignored_codes = [n_value, pad_value]

    for ref in references:
        row_mask = (ref_series == ref).to_numpy(dtype=bool)
        ref_positions = _reference_position_mask(adata.var, ref, n_vars)

        frequency_values = np.full(n_vars, np.nan, dtype=float)
        variable_flags = np.zeros(n_vars, dtype=bool)
        base_frequencies: list[list[tuple[str, float]]] = [[] for _ in range(n_vars)]

        n_ref_reads = int(row_mask.sum())
        if n_ref_reads == 0:
            # No reads for this reference: store all-NaN / all-False / empty columns.
            _write_mismatch_columns(adata, ref, frequency_values, variable_flags, base_frequencies)
            continue

        mismatch_matrix = _dense_rows(adata.layers[mismatch_layer], row_mask)
        if has_span_mask:
            coverage_mask = _dense_rows(adata.layers[read_span_layer], row_mask) > 0
            coverage_counts = coverage_mask.sum(axis=0).astype(float)
        else:
            coverage_mask = np.ones(mismatch_matrix.shape, dtype=bool)
            coverage_counts = np.full(n_vars, n_ref_reads, dtype=float)

        # Any code other than N/PAD is a mismatch call. NaN is not a base code, so it
        # is excluded too; this keeps the total consistent with the per-base
        # breakdown below, where NaN never matches any base.
        mismatch_mask = ~np.isin(mismatch_matrix, ignored_codes) & coverage_mask
        if np.issubdtype(mismatch_matrix.dtype, np.floating):
            mismatch_mask &= np.isfinite(mismatch_matrix)

        frequency_values = np.divide(
            mismatch_mask.sum(axis=0),
            coverage_counts,
            out=np.full(n_vars, np.nan, dtype=float),
            where=coverage_counts > 0,
        )
        frequency_values[~ref_positions] = np.nan

        # NaN frequencies compare False, so uncovered positions are never flagged.
        variable_flags = (
            (frequency_values >= lower_bound) & (frequency_values <= upper_bound) & ref_positions
        )

        base_counts = {
            base_int: ((mismatch_matrix == base_int) & coverage_mask).sum(axis=0)
            for base_int in base_int_to_label
        }
        for idx in np.flatnonzero(ref_positions & (coverage_counts > 0)):
            depth = coverage_counts[idx]
            base_frequencies[idx] = [
                (label, float(base_counts[base_int][idx] / depth))
                for base_int, label in base_int_to_label.items()
                if base_counts[base_int][idx] > 0
            ]

        _write_mismatch_columns(adata, ref, frequency_values, variable_flags, base_frequencies)

    adata.uns[uns_flag] = True
|
|
@@ -3,6 +3,8 @@ from __future__ import annotations
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
from typing import TYPE_CHECKING, Optional
|
|
5
5
|
|
|
6
|
+
from smftools.optional_imports import require
|
|
7
|
+
|
|
6
8
|
if TYPE_CHECKING:
|
|
7
9
|
import anndata as ad
|
|
8
10
|
|
|
@@ -46,11 +48,12 @@ def calculate_complexity_II(
|
|
|
46
48
|
"""
|
|
47
49
|
import os
|
|
48
50
|
|
|
49
|
-
import matplotlib.pyplot as plt
|
|
50
51
|
import numpy as np
|
|
51
52
|
import pandas as pd
|
|
52
53
|
from scipy.optimize import curve_fit
|
|
53
54
|
|
|
55
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="complexity plots")
|
|
56
|
+
|
|
54
57
|
# early exits
|
|
55
58
|
already = bool(adata.uns.get(uns_flag, False))
|
|
56
59
|
if already and not force_redo:
|
|
@@ -53,4 +53,4 @@ def calculate_consensus(
|
|
|
53
53
|
else:
|
|
54
54
|
adata.var[f"{reference}_consensus_across_samples"] = consensus_sequence_list
|
|
55
55
|
|
|
56
|
-
adata.uns[f"{reference}_consensus_sequence"] = consensus_sequence_list
|
|
56
|
+
adata.uns[f"{reference}_consensus_sequence"] = str(consensus_sequence_list)
|
|
@@ -6,6 +6,7 @@ from pathlib import Path
|
|
|
6
6
|
from typing import TYPE_CHECKING
|
|
7
7
|
|
|
8
8
|
from smftools.logging_utils import get_logger
|
|
9
|
+
from smftools.optional_imports import require
|
|
9
10
|
|
|
10
11
|
if TYPE_CHECKING:
|
|
11
12
|
import anndata as ad
|
|
@@ -40,9 +41,15 @@ def calculate_position_Youden(
|
|
|
40
41
|
save: Whether to save ROC plots to disk.
|
|
41
42
|
output_directory: Output directory for ROC plots.
|
|
42
43
|
"""
|
|
43
|
-
import matplotlib.pyplot as plt
|
|
44
44
|
import numpy as np
|
|
45
|
-
|
|
45
|
+
|
|
46
|
+
plt = require("matplotlib.pyplot", extra="plotting", purpose="Youden ROC plots")
|
|
47
|
+
sklearn_metrics = require(
|
|
48
|
+
"sklearn.metrics",
|
|
49
|
+
extra="ml-base",
|
|
50
|
+
purpose="Youden ROC curve calculation",
|
|
51
|
+
)
|
|
52
|
+
roc_curve = sklearn_metrics.roc_curve
|
|
46
53
|
|
|
47
54
|
control_samples = [positive_control_sample, negative_control_sample]
|
|
48
55
|
references = adata.obs[ref_column].cat.categories
|
|
@@ -20,6 +20,7 @@ def calculate_read_modification_stats(
|
|
|
20
20
|
force_redo: bool = False,
|
|
21
21
|
valid_sites_only: bool = False,
|
|
22
22
|
valid_site_suffix: str = "_valid_coverage",
|
|
23
|
+
smf_modality: str = "conversion",
|
|
23
24
|
) -> None:
|
|
24
25
|
"""Add methylation/deamination statistics for each read.
|
|
25
26
|
|
|
@@ -80,8 +81,12 @@ def calculate_read_modification_stats(
|
|
|
80
81
|
for ref in references:
|
|
81
82
|
ref_subset = adata[adata.obs[reference_column] == ref]
|
|
82
83
|
for site_type in site_types:
|
|
84
|
+
site_subset = ref_subset[:, ref_subset.var[f"{ref}_{site_type}{valid_site_suffix}"]]
|
|
83
85
|
logger.info("Iterating over %s_%s", ref, site_type)
|
|
84
|
-
|
|
86
|
+
if smf_modality == "native":
|
|
87
|
+
observation_matrix = site_subset.layers["binarized_methylation"]
|
|
88
|
+
else:
|
|
89
|
+
observation_matrix = site_subset.X
|
|
85
90
|
total_positions_in_read = np.nansum(~np.isnan(observation_matrix), axis=1)
|
|
86
91
|
total_positions_in_reference = observation_matrix.shape[1]
|
|
87
92
|
fraction_valid_positions_in_read_vs_ref = (
|