smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- smftools/_version.py +1 -1
- smftools/cli/chimeric_adata.py +1563 -0
- smftools/cli/helpers.py +18 -2
- smftools/cli/hmm_adata.py +18 -1
- smftools/cli/latent_adata.py +522 -67
- smftools/cli/load_adata.py +2 -2
- smftools/cli/preprocess_adata.py +32 -93
- smftools/cli/recipes.py +26 -0
- smftools/cli/spatial_adata.py +23 -109
- smftools/cli/variant_adata.py +423 -0
- smftools/cli_entry.py +41 -5
- smftools/config/conversion.yaml +0 -10
- smftools/config/deaminase.yaml +3 -0
- smftools/config/default.yaml +49 -13
- smftools/config/experiment_config.py +96 -3
- smftools/constants.py +4 -0
- smftools/hmm/call_hmm_peaks.py +1 -1
- smftools/informatics/binarize_converted_base_identities.py +2 -89
- smftools/informatics/converted_BAM_to_adata.py +53 -13
- smftools/informatics/h5ad_functions.py +83 -0
- smftools/informatics/modkit_extract_to_adata.py +4 -0
- smftools/plotting/__init__.py +26 -12
- smftools/plotting/autocorrelation_plotting.py +22 -4
- smftools/plotting/chimeric_plotting.py +1893 -0
- smftools/plotting/classifiers.py +28 -14
- smftools/plotting/general_plotting.py +58 -3362
- smftools/plotting/hmm_plotting.py +1586 -2
- smftools/plotting/latent_plotting.py +804 -0
- smftools/plotting/plotting_utils.py +243 -0
- smftools/plotting/position_stats.py +16 -8
- smftools/plotting/preprocess_plotting.py +281 -0
- smftools/plotting/qc_plotting.py +8 -3
- smftools/plotting/spatial_plotting.py +1134 -0
- smftools/plotting/variant_plotting.py +1231 -0
- smftools/preprocessing/__init__.py +3 -0
- smftools/preprocessing/append_base_context.py +1 -1
- smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
- smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
- smftools/preprocessing/append_variant_call_layer.py +480 -0
- smftools/preprocessing/flag_duplicate_reads.py +4 -4
- smftools/preprocessing/invert_adata.py +1 -0
- smftools/readwrite.py +109 -85
- smftools/tools/__init__.py +6 -0
- smftools/tools/calculate_knn.py +121 -0
- smftools/tools/calculate_nmf.py +18 -7
- smftools/tools/calculate_pca.py +180 -0
- smftools/tools/calculate_umap.py +70 -154
- smftools/tools/position_stats.py +4 -4
- smftools/tools/rolling_nn_distance.py +640 -3
- smftools/tools/sequence_alignment.py +140 -0
- smftools/tools/tensor_factorization.py +52 -4
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
- {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
smftools/plotting/classifiers.py
CHANGED
|
@@ -4,11 +4,14 @@ import os
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
|
+
from smftools.logging_utils import get_logger
|
|
7
8
|
from smftools.optional_imports import require
|
|
8
9
|
|
|
9
10
|
plt = require("matplotlib.pyplot", extra="plotting", purpose="model plots")
|
|
10
11
|
torch = require("torch", extra="ml-base", purpose="model saliency plots")
|
|
11
12
|
|
|
13
|
+
logger = get_logger(__name__)
|
|
14
|
+
|
|
12
15
|
|
|
13
16
|
def plot_model_performance(metrics, save_path=None):
|
|
14
17
|
"""Plot ROC and precision-recall curves for model metrics.
|
|
@@ -19,6 +22,7 @@ def plot_model_performance(metrics, save_path=None):
|
|
|
19
22
|
"""
|
|
20
23
|
import os
|
|
21
24
|
|
|
25
|
+
logger.info("Plotting model performance curves.")
|
|
22
26
|
for ref in metrics.keys():
|
|
23
27
|
plt.figure(figsize=(12, 5))
|
|
24
28
|
|
|
@@ -58,14 +62,17 @@ def plot_model_performance(metrics, save_path=None):
|
|
|
58
62
|
safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
59
63
|
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
60
64
|
plt.savefig(out_file, dpi=300)
|
|
61
|
-
|
|
65
|
+
logger.info("Saved model performance plot to %s.", out_file)
|
|
62
66
|
plt.show()
|
|
63
67
|
|
|
64
68
|
# Confusion Matrices
|
|
65
69
|
for model_name, vals in metrics[ref].items():
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
70
|
+
logger.info(
|
|
71
|
+
"Confusion Matrix for %s - %s:\n%s",
|
|
72
|
+
ref,
|
|
73
|
+
model_name.upper(),
|
|
74
|
+
vals["confusion_matrix"],
|
|
75
|
+
)
|
|
69
76
|
|
|
70
77
|
|
|
71
78
|
def plot_feature_importances_or_saliency(
|
|
@@ -94,6 +101,7 @@ def plot_feature_importances_or_saliency(
|
|
|
94
101
|
|
|
95
102
|
import numpy as np
|
|
96
103
|
|
|
104
|
+
logger.info("Plotting feature importances or saliency.")
|
|
97
105
|
# Select device for NN models
|
|
98
106
|
device = (
|
|
99
107
|
torch.device("cuda")
|
|
@@ -110,7 +118,7 @@ def plot_feature_importances_or_saliency(
|
|
|
110
118
|
suffix = "_".join(site_config[ref]) if ref in site_config else "full"
|
|
111
119
|
|
|
112
120
|
if ref not in positions or suffix not in positions[ref]:
|
|
113
|
-
|
|
121
|
+
logger.warning("Positions not found for %s with suffix %s. Skipping.", ref, suffix)
|
|
114
122
|
continue
|
|
115
123
|
|
|
116
124
|
coords_index = positions[ref][suffix]
|
|
@@ -122,8 +130,8 @@ def plot_feature_importances_or_saliency(
|
|
|
122
130
|
other_sites = set()
|
|
123
131
|
|
|
124
132
|
if adata is None:
|
|
125
|
-
|
|
126
|
-
"
|
|
133
|
+
logger.warning(
|
|
134
|
+
"AnnData object is required to classify site types. Skipping site type markers."
|
|
127
135
|
)
|
|
128
136
|
else:
|
|
129
137
|
gpc_col = f"{ref}_GpC_site"
|
|
@@ -140,7 +148,7 @@ def plot_feature_importances_or_saliency(
|
|
|
140
148
|
else:
|
|
141
149
|
other_sites.add(coord_int)
|
|
142
150
|
except KeyError:
|
|
143
|
-
|
|
151
|
+
logger.warning("Index '%s' not found in adata.var. Skipping.", idx_str)
|
|
144
152
|
continue
|
|
145
153
|
|
|
146
154
|
for model_key, model in model_dict.items():
|
|
@@ -151,13 +159,17 @@ def plot_feature_importances_or_saliency(
|
|
|
151
159
|
if hasattr(model, "feature_importances_"):
|
|
152
160
|
importances = model.feature_importances_
|
|
153
161
|
else:
|
|
154
|
-
|
|
162
|
+
logger.warning(
|
|
163
|
+
"Random Forest model %s has no feature_importances_. Skipping.", model_key
|
|
164
|
+
)
|
|
155
165
|
continue
|
|
156
166
|
plot_title = f"RF Feature Importances for {ref} ({model_key})"
|
|
157
167
|
y_label = "Feature Importance"
|
|
158
168
|
else:
|
|
159
169
|
if tensors is None or ref not in tensors or suffix not in tensors[ref]:
|
|
160
|
-
|
|
170
|
+
logger.warning(
|
|
171
|
+
"No input data provided for NN saliency for %s. Skipping.", model_key
|
|
172
|
+
)
|
|
161
173
|
continue
|
|
162
174
|
input_tensor = tensors[ref][suffix]
|
|
163
175
|
model.eval()
|
|
@@ -238,7 +250,7 @@ def plot_feature_importances_or_saliency(
|
|
|
238
250
|
)
|
|
239
251
|
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
240
252
|
plt.savefig(out_file, dpi=300)
|
|
241
|
-
|
|
253
|
+
logger.info("Saved feature importance plot to %s.", out_file)
|
|
242
254
|
|
|
243
255
|
plt.show()
|
|
244
256
|
|
|
@@ -265,6 +277,7 @@ def plot_model_curves_from_adata(
|
|
|
265
277
|
ylim_roc: Y-axis limits for ROC curve.
|
|
266
278
|
ylim_pr: Y-axis limits for PR curve.
|
|
267
279
|
"""
|
|
280
|
+
logger.info("Plotting model curves from AnnData.")
|
|
268
281
|
sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model curves")
|
|
269
282
|
auc = sklearn_metrics.auc
|
|
270
283
|
precision_recall_curve = sklearn_metrics.precision_recall_curve
|
|
@@ -320,7 +333,7 @@ def plot_model_curves_from_adata(
|
|
|
320
333
|
safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
|
|
321
334
|
out_file = os.path.join(save_path, f"{safe_name}.png")
|
|
322
335
|
plt.savefig(out_file, dpi=300)
|
|
323
|
-
|
|
336
|
+
logger.info("Saved model curves plot to %s.", out_file)
|
|
324
337
|
plt.show()
|
|
325
338
|
|
|
326
339
|
|
|
@@ -358,6 +371,7 @@ def plot_model_curves_from_adata_with_frequency_grid(
|
|
|
358
371
|
|
|
359
372
|
import numpy as np
|
|
360
373
|
|
|
374
|
+
logger.info("Plotting model curves with frequency grid from AnnData.")
|
|
361
375
|
sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model curves")
|
|
362
376
|
auc = sklearn_metrics.auc
|
|
363
377
|
precision_recall_curve = sklearn_metrics.precision_recall_curve
|
|
@@ -387,7 +401,7 @@ def plot_model_curves_from_adata_with_frequency_grid(
|
|
|
387
401
|
neg_sample_count = desired_total - pos_sample_count
|
|
388
402
|
|
|
389
403
|
if pos_sample_count > len(pos_indices) or neg_sample_count > len(neg_indices):
|
|
390
|
-
|
|
404
|
+
logger.warning("Skipping frequency %.3f: not enough samples.", pos_freq)
|
|
391
405
|
continue
|
|
392
406
|
|
|
393
407
|
sampled_pos = np.random.choice(pos_indices, size=pos_sample_count, replace=False)
|
|
@@ -453,5 +467,5 @@ def plot_model_curves_from_adata_with_frequency_grid(
|
|
|
453
467
|
os.makedirs(save_path, exist_ok=True)
|
|
454
468
|
out_file = os.path.join(save_path, "ROC_PR_grid.png")
|
|
455
469
|
plt.savefig(out_file, dpi=300)
|
|
456
|
-
|
|
470
|
+
logger.info("Saved model curves frequency grid to %s.", out_file)
|
|
457
471
|
plt.show()
|