smftools 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. smftools/_version.py +1 -1
  2. smftools/cli/chimeric_adata.py +1563 -0
  3. smftools/cli/helpers.py +18 -2
  4. smftools/cli/hmm_adata.py +18 -1
  5. smftools/cli/latent_adata.py +522 -67
  6. smftools/cli/load_adata.py +2 -2
  7. smftools/cli/preprocess_adata.py +32 -93
  8. smftools/cli/recipes.py +26 -0
  9. smftools/cli/spatial_adata.py +23 -109
  10. smftools/cli/variant_adata.py +423 -0
  11. smftools/cli_entry.py +41 -5
  12. smftools/config/conversion.yaml +0 -10
  13. smftools/config/deaminase.yaml +3 -0
  14. smftools/config/default.yaml +49 -13
  15. smftools/config/experiment_config.py +96 -3
  16. smftools/constants.py +4 -0
  17. smftools/hmm/call_hmm_peaks.py +1 -1
  18. smftools/informatics/binarize_converted_base_identities.py +2 -89
  19. smftools/informatics/converted_BAM_to_adata.py +53 -13
  20. smftools/informatics/h5ad_functions.py +83 -0
  21. smftools/informatics/modkit_extract_to_adata.py +4 -0
  22. smftools/plotting/__init__.py +26 -12
  23. smftools/plotting/autocorrelation_plotting.py +22 -4
  24. smftools/plotting/chimeric_plotting.py +1893 -0
  25. smftools/plotting/classifiers.py +28 -14
  26. smftools/plotting/general_plotting.py +58 -3362
  27. smftools/plotting/hmm_plotting.py +1586 -2
  28. smftools/plotting/latent_plotting.py +804 -0
  29. smftools/plotting/plotting_utils.py +243 -0
  30. smftools/plotting/position_stats.py +16 -8
  31. smftools/plotting/preprocess_plotting.py +281 -0
  32. smftools/plotting/qc_plotting.py +8 -3
  33. smftools/plotting/spatial_plotting.py +1134 -0
  34. smftools/plotting/variant_plotting.py +1231 -0
  35. smftools/preprocessing/__init__.py +3 -0
  36. smftools/preprocessing/append_base_context.py +1 -1
  37. smftools/preprocessing/append_mismatch_frequency_sites.py +35 -6
  38. smftools/preprocessing/append_sequence_mismatch_annotations.py +171 -0
  39. smftools/preprocessing/append_variant_call_layer.py +480 -0
  40. smftools/preprocessing/flag_duplicate_reads.py +4 -4
  41. smftools/preprocessing/invert_adata.py +1 -0
  42. smftools/readwrite.py +109 -85
  43. smftools/tools/__init__.py +6 -0
  44. smftools/tools/calculate_knn.py +121 -0
  45. smftools/tools/calculate_nmf.py +18 -7
  46. smftools/tools/calculate_pca.py +180 -0
  47. smftools/tools/calculate_umap.py +70 -154
  48. smftools/tools/position_stats.py +4 -4
  49. smftools/tools/rolling_nn_distance.py +640 -3
  50. smftools/tools/sequence_alignment.py +140 -0
  51. smftools/tools/tensor_factorization.py +52 -4
  52. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/METADATA +3 -1
  53. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/RECORD +56 -42
  54. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/WHEEL +0 -0
  55. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/entry_points.txt +0 -0
  56. {smftools-0.3.1.dist-info → smftools-0.3.2.dist-info}/licenses/LICENSE +0 -0
@@ -4,11 +4,14 @@ import os
4
4
 
5
5
  import numpy as np
6
6
 
7
+ from smftools.logging_utils import get_logger
7
8
  from smftools.optional_imports import require
8
9
 
9
10
  plt = require("matplotlib.pyplot", extra="plotting", purpose="model plots")
10
11
  torch = require("torch", extra="ml-base", purpose="model saliency plots")
11
12
 
13
+ logger = get_logger(__name__)
14
+
12
15
 
13
16
  def plot_model_performance(metrics, save_path=None):
14
17
  """Plot ROC and precision-recall curves for model metrics.
@@ -19,6 +22,7 @@ def plot_model_performance(metrics, save_path=None):
19
22
  """
20
23
  import os
21
24
 
25
+ logger.info("Plotting model performance curves.")
22
26
  for ref in metrics.keys():
23
27
  plt.figure(figsize=(12, 5))
24
28
 
@@ -58,14 +62,17 @@ def plot_model_performance(metrics, save_path=None):
58
62
  safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
59
63
  out_file = os.path.join(save_path, f"{safe_name}.png")
60
64
  plt.savefig(out_file, dpi=300)
61
- print(f"📁 Saved: {out_file}")
65
+ logger.info("Saved model performance plot to %s.", out_file)
62
66
  plt.show()
63
67
 
64
68
  # Confusion Matrices
65
69
  for model_name, vals in metrics[ref].items():
66
- print(f"Confusion Matrix for {ref} - {model_name.upper()}:")
67
- print(vals["confusion_matrix"])
68
- print()
70
+ logger.info(
71
+ "Confusion Matrix for %s - %s:\n%s",
72
+ ref,
73
+ model_name.upper(),
74
+ vals["confusion_matrix"],
75
+ )
69
76
 
70
77
 
71
78
  def plot_feature_importances_or_saliency(
@@ -94,6 +101,7 @@ def plot_feature_importances_or_saliency(
94
101
 
95
102
  import numpy as np
96
103
 
104
+ logger.info("Plotting feature importances or saliency.")
97
105
  # Select device for NN models
98
106
  device = (
99
107
  torch.device("cuda")
@@ -110,7 +118,7 @@ def plot_feature_importances_or_saliency(
110
118
  suffix = "_".join(site_config[ref]) if ref in site_config else "full"
111
119
 
112
120
  if ref not in positions or suffix not in positions[ref]:
113
- print(f"Positions not found for {ref} with suffix {suffix}. Skipping {ref}.")
121
+ logger.warning("Positions not found for %s with suffix %s. Skipping.", ref, suffix)
114
122
  continue
115
123
 
116
124
  coords_index = positions[ref][suffix]
@@ -122,8 +130,8 @@ def plot_feature_importances_or_saliency(
122
130
  other_sites = set()
123
131
 
124
132
  if adata is None:
125
- print(
126
- "⚠️ AnnData object is required to classify site types. Skipping site type markers."
133
+ logger.warning(
134
+ "AnnData object is required to classify site types. Skipping site type markers."
127
135
  )
128
136
  else:
129
137
  gpc_col = f"{ref}_GpC_site"
@@ -140,7 +148,7 @@ def plot_feature_importances_or_saliency(
140
148
  else:
141
149
  other_sites.add(coord_int)
142
150
  except KeyError:
143
- print(f"⚠️ Index '{idx_str}' not found in adata.var. Skipping.")
151
+ logger.warning("Index '%s' not found in adata.var. Skipping.", idx_str)
144
152
  continue
145
153
 
146
154
  for model_key, model in model_dict.items():
@@ -151,13 +159,17 @@ def plot_feature_importances_or_saliency(
151
159
  if hasattr(model, "feature_importances_"):
152
160
  importances = model.feature_importances_
153
161
  else:
154
- print(f"Random Forest model {model_key} has no feature_importances_. Skipping.")
162
+ logger.warning(
163
+ "Random Forest model %s has no feature_importances_. Skipping.", model_key
164
+ )
155
165
  continue
156
166
  plot_title = f"RF Feature Importances for {ref} ({model_key})"
157
167
  y_label = "Feature Importance"
158
168
  else:
159
169
  if tensors is None or ref not in tensors or suffix not in tensors[ref]:
160
- print(f"No input data provided for NN saliency for {model_key}. Skipping.")
170
+ logger.warning(
171
+ "No input data provided for NN saliency for %s. Skipping.", model_key
172
+ )
161
173
  continue
162
174
  input_tensor = tensors[ref][suffix]
163
175
  model.eval()
@@ -238,7 +250,7 @@ def plot_feature_importances_or_saliency(
238
250
  )
239
251
  out_file = os.path.join(save_path, f"{safe_name}.png")
240
252
  plt.savefig(out_file, dpi=300)
241
- print(f"📁 Saved: {out_file}")
253
+ logger.info("Saved feature importance plot to %s.", out_file)
242
254
 
243
255
  plt.show()
244
256
 
@@ -265,6 +277,7 @@ def plot_model_curves_from_adata(
265
277
  ylim_roc: Y-axis limits for ROC curve.
266
278
  ylim_pr: Y-axis limits for PR curve.
267
279
  """
280
+ logger.info("Plotting model curves from AnnData.")
268
281
  sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model curves")
269
282
  auc = sklearn_metrics.auc
270
283
  precision_recall_curve = sklearn_metrics.precision_recall_curve
@@ -320,7 +333,7 @@ def plot_model_curves_from_adata(
320
333
  safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
321
334
  out_file = os.path.join(save_path, f"{safe_name}.png")
322
335
  plt.savefig(out_file, dpi=300)
323
- print(f"📁 Saved: {out_file}")
336
+ logger.info("Saved model curves plot to %s.", out_file)
324
337
  plt.show()
325
338
 
326
339
 
@@ -358,6 +371,7 @@ def plot_model_curves_from_adata_with_frequency_grid(
358
371
 
359
372
  import numpy as np
360
373
 
374
+ logger.info("Plotting model curves with frequency grid from AnnData.")
361
375
  sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model curves")
362
376
  auc = sklearn_metrics.auc
363
377
  precision_recall_curve = sklearn_metrics.precision_recall_curve
@@ -387,7 +401,7 @@ def plot_model_curves_from_adata_with_frequency_grid(
387
401
  neg_sample_count = desired_total - pos_sample_count
388
402
 
389
403
  if pos_sample_count > len(pos_indices) or neg_sample_count > len(neg_indices):
390
- print(f"⚠️ Skipping frequency {pos_freq:.3f}: not enough samples.")
404
+ logger.warning("Skipping frequency %.3f: not enough samples.", pos_freq)
391
405
  continue
392
406
 
393
407
  sampled_pos = np.random.choice(pos_indices, size=pos_sample_count, replace=False)
@@ -453,5 +467,5 @@ def plot_model_curves_from_adata_with_frequency_grid(
453
467
  os.makedirs(save_path, exist_ok=True)
454
468
  out_file = os.path.join(save_path, "ROC_PR_grid.png")
455
469
  plt.savefig(out_file, dpi=300)
456
- print(f"📁 Saved: {out_file}")
470
+ logger.info("Saved model curves frequency grid to %s.", out_file)
457
471
  plt.show()