smftools 0.2.4__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. smftools/__init__.py +43 -13
  2. smftools/_settings.py +6 -6
  3. smftools/_version.py +3 -1
  4. smftools/cli/__init__.py +1 -0
  5. smftools/cli/archived/cli_flows.py +2 -0
  6. smftools/cli/helpers.py +9 -1
  7. smftools/cli/hmm_adata.py +905 -242
  8. smftools/cli/load_adata.py +432 -280
  9. smftools/cli/preprocess_adata.py +287 -171
  10. smftools/cli/spatial_adata.py +141 -53
  11. smftools/cli_entry.py +119 -178
  12. smftools/config/__init__.py +3 -1
  13. smftools/config/conversion.yaml +5 -1
  14. smftools/config/deaminase.yaml +1 -1
  15. smftools/config/default.yaml +26 -18
  16. smftools/config/direct.yaml +8 -3
  17. smftools/config/discover_input_files.py +19 -5
  18. smftools/config/experiment_config.py +511 -276
  19. smftools/constants.py +37 -0
  20. smftools/datasets/__init__.py +4 -8
  21. smftools/datasets/datasets.py +32 -18
  22. smftools/hmm/HMM.py +2133 -1428
  23. smftools/hmm/__init__.py +24 -14
  24. smftools/hmm/archived/apply_hmm_batched.py +2 -0
  25. smftools/hmm/archived/calculate_distances.py +2 -0
  26. smftools/hmm/archived/call_hmm_peaks.py +18 -1
  27. smftools/hmm/archived/train_hmm.py +2 -0
  28. smftools/hmm/call_hmm_peaks.py +176 -193
  29. smftools/hmm/display_hmm.py +23 -7
  30. smftools/hmm/hmm_readwrite.py +20 -6
  31. smftools/hmm/nucleosome_hmm_refinement.py +104 -14
  32. smftools/informatics/__init__.py +55 -13
  33. smftools/informatics/archived/bam_conversion.py +2 -0
  34. smftools/informatics/archived/bam_direct.py +2 -0
  35. smftools/informatics/archived/basecall_pod5s.py +2 -0
  36. smftools/informatics/archived/basecalls_to_adata.py +2 -0
  37. smftools/informatics/archived/conversion_smf.py +2 -0
  38. smftools/informatics/archived/deaminase_smf.py +1 -0
  39. smftools/informatics/archived/direct_smf.py +2 -0
  40. smftools/informatics/archived/fast5_to_pod5.py +2 -0
  41. smftools/informatics/archived/helpers/archived/__init__.py +2 -0
  42. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +16 -1
  43. smftools/informatics/archived/helpers/archived/aligned_BAM_to_bed.py +2 -0
  44. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  45. smftools/informatics/archived/helpers/archived/bed_to_bigwig.py +2 -0
  46. smftools/informatics/archived/helpers/archived/canoncall.py +2 -0
  47. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  48. smftools/informatics/archived/helpers/archived/converted_BAM_to_adata.py +2 -0
  49. smftools/informatics/archived/helpers/archived/count_aligned_reads.py +2 -0
  50. smftools/informatics/archived/helpers/archived/demux_and_index_BAM.py +2 -0
  51. smftools/informatics/archived/helpers/archived/extract_base_identities.py +2 -0
  52. smftools/informatics/archived/helpers/archived/extract_mods.py +2 -0
  53. smftools/informatics/archived/helpers/archived/extract_read_features_from_bam.py +2 -0
  54. smftools/informatics/archived/helpers/archived/extract_read_lengths_from_bed.py +2 -0
  55. smftools/informatics/archived/helpers/archived/extract_readnames_from_BAM.py +2 -0
  56. smftools/informatics/archived/helpers/archived/find_conversion_sites.py +2 -0
  57. smftools/informatics/archived/helpers/archived/generate_converted_FASTA.py +2 -0
  58. smftools/informatics/archived/helpers/archived/get_chromosome_lengths.py +2 -0
  59. smftools/informatics/archived/helpers/archived/get_native_references.py +2 -0
  60. smftools/informatics/archived/helpers/archived/index_fasta.py +2 -0
  61. smftools/informatics/archived/helpers/archived/informatics.py +2 -0
  62. smftools/informatics/archived/helpers/archived/load_adata.py +5 -3
  63. smftools/informatics/archived/helpers/archived/make_modbed.py +2 -0
  64. smftools/informatics/archived/helpers/archived/modQC.py +2 -0
  65. smftools/informatics/archived/helpers/archived/modcall.py +2 -0
  66. smftools/informatics/archived/helpers/archived/ohe_batching.py +2 -0
  67. smftools/informatics/archived/helpers/archived/ohe_layers_decode.py +2 -0
  68. smftools/informatics/archived/helpers/archived/one_hot_decode.py +2 -0
  69. smftools/informatics/archived/helpers/archived/one_hot_encode.py +2 -0
  70. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +5 -1
  71. smftools/informatics/archived/helpers/archived/separate_bam_by_bc.py +2 -0
  72. smftools/informatics/archived/helpers/archived/split_and_index_BAM.py +2 -0
  73. smftools/informatics/archived/print_bam_query_seq.py +9 -1
  74. smftools/informatics/archived/subsample_fasta_from_bed.py +2 -0
  75. smftools/informatics/archived/subsample_pod5.py +2 -0
  76. smftools/informatics/bam_functions.py +1059 -269
  77. smftools/informatics/basecalling.py +53 -9
  78. smftools/informatics/bed_functions.py +357 -114
  79. smftools/informatics/binarize_converted_base_identities.py +21 -7
  80. smftools/informatics/complement_base_list.py +9 -6
  81. smftools/informatics/converted_BAM_to_adata.py +324 -137
  82. smftools/informatics/fasta_functions.py +251 -89
  83. smftools/informatics/h5ad_functions.py +202 -30
  84. smftools/informatics/modkit_extract_to_adata.py +623 -274
  85. smftools/informatics/modkit_functions.py +87 -44
  86. smftools/informatics/ohe.py +46 -21
  87. smftools/informatics/pod5_functions.py +114 -74
  88. smftools/informatics/run_multiqc.py +20 -14
  89. smftools/logging_utils.py +51 -0
  90. smftools/machine_learning/__init__.py +23 -12
  91. smftools/machine_learning/data/__init__.py +2 -0
  92. smftools/machine_learning/data/anndata_data_module.py +157 -50
  93. smftools/machine_learning/data/preprocessing.py +4 -1
  94. smftools/machine_learning/evaluation/__init__.py +3 -1
  95. smftools/machine_learning/evaluation/eval_utils.py +13 -14
  96. smftools/machine_learning/evaluation/evaluators.py +52 -34
  97. smftools/machine_learning/inference/__init__.py +3 -1
  98. smftools/machine_learning/inference/inference_utils.py +9 -4
  99. smftools/machine_learning/inference/lightning_inference.py +14 -13
  100. smftools/machine_learning/inference/sklearn_inference.py +8 -8
  101. smftools/machine_learning/inference/sliding_window_inference.py +37 -25
  102. smftools/machine_learning/models/__init__.py +12 -5
  103. smftools/machine_learning/models/base.py +34 -43
  104. smftools/machine_learning/models/cnn.py +22 -13
  105. smftools/machine_learning/models/lightning_base.py +78 -42
  106. smftools/machine_learning/models/mlp.py +18 -5
  107. smftools/machine_learning/models/positional.py +10 -4
  108. smftools/machine_learning/models/rnn.py +8 -3
  109. smftools/machine_learning/models/sklearn_models.py +46 -24
  110. smftools/machine_learning/models/transformer.py +75 -55
  111. smftools/machine_learning/models/wrappers.py +8 -3
  112. smftools/machine_learning/training/__init__.py +4 -2
  113. smftools/machine_learning/training/train_lightning_model.py +42 -23
  114. smftools/machine_learning/training/train_sklearn_model.py +11 -15
  115. smftools/machine_learning/utils/__init__.py +3 -1
  116. smftools/machine_learning/utils/device.py +12 -5
  117. smftools/machine_learning/utils/grl.py +8 -2
  118. smftools/metadata.py +443 -0
  119. smftools/optional_imports.py +31 -0
  120. smftools/plotting/__init__.py +32 -17
  121. smftools/plotting/autocorrelation_plotting.py +153 -48
  122. smftools/plotting/classifiers.py +175 -73
  123. smftools/plotting/general_plotting.py +350 -168
  124. smftools/plotting/hmm_plotting.py +53 -14
  125. smftools/plotting/position_stats.py +155 -87
  126. smftools/plotting/qc_plotting.py +25 -12
  127. smftools/preprocessing/__init__.py +35 -37
  128. smftools/preprocessing/append_base_context.py +105 -79
  129. smftools/preprocessing/append_binary_layer_by_base_context.py +75 -37
  130. smftools/preprocessing/{archives → archived}/add_read_length_and_mapping_qc.py +2 -0
  131. smftools/preprocessing/{archives → archived}/calculate_complexity.py +5 -1
  132. smftools/preprocessing/{archives → archived}/mark_duplicates.py +2 -0
  133. smftools/preprocessing/{archives → archived}/preprocessing.py +10 -6
  134. smftools/preprocessing/{archives → archived}/remove_duplicates.py +2 -0
  135. smftools/preprocessing/binarize.py +21 -4
  136. smftools/preprocessing/binarize_on_Youden.py +127 -31
  137. smftools/preprocessing/binary_layers_to_ohe.py +18 -11
  138. smftools/preprocessing/calculate_complexity_II.py +89 -59
  139. smftools/preprocessing/calculate_consensus.py +28 -19
  140. smftools/preprocessing/calculate_coverage.py +44 -22
  141. smftools/preprocessing/calculate_pairwise_differences.py +4 -1
  142. smftools/preprocessing/calculate_pairwise_hamming_distances.py +7 -3
  143. smftools/preprocessing/calculate_position_Youden.py +110 -55
  144. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  145. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  146. smftools/preprocessing/clean_NaN.py +38 -28
  147. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  148. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +72 -37
  149. smftools/preprocessing/filter_reads_on_modification_thresholds.py +183 -73
  150. smftools/preprocessing/flag_duplicate_reads.py +708 -303
  151. smftools/preprocessing/invert_adata.py +26 -11
  152. smftools/preprocessing/load_sample_sheet.py +40 -22
  153. smftools/preprocessing/make_dirs.py +9 -3
  154. smftools/preprocessing/min_non_diagonal.py +4 -1
  155. smftools/preprocessing/recipes.py +58 -23
  156. smftools/preprocessing/reindex_references_adata.py +93 -27
  157. smftools/preprocessing/subsample_adata.py +33 -16
  158. smftools/readwrite.py +264 -109
  159. smftools/schema/__init__.py +11 -0
  160. smftools/schema/anndata_schema_v1.yaml +227 -0
  161. smftools/tools/__init__.py +25 -18
  162. smftools/tools/archived/apply_hmm.py +2 -0
  163. smftools/tools/archived/classifiers.py +165 -0
  164. smftools/tools/archived/classify_methylated_features.py +2 -0
  165. smftools/tools/archived/classify_non_methylated_features.py +2 -0
  166. smftools/tools/archived/subset_adata_v1.py +12 -1
  167. smftools/tools/archived/subset_adata_v2.py +14 -1
  168. smftools/tools/calculate_umap.py +56 -15
  169. smftools/tools/cluster_adata_on_methylation.py +122 -47
  170. smftools/tools/general_tools.py +70 -25
  171. smftools/tools/position_stats.py +220 -99
  172. smftools/tools/read_stats.py +50 -29
  173. smftools/tools/spatial_autocorrelation.py +365 -192
  174. smftools/tools/subset_adata.py +23 -21
  175. smftools-0.3.0.dist-info/METADATA +147 -0
  176. smftools-0.3.0.dist-info/RECORD +182 -0
  177. smftools-0.2.4.dist-info/METADATA +0 -141
  178. smftools-0.2.4.dist-info/RECORD +0 -176
  179. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/WHEEL +0 -0
  180. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/entry_points.txt +0 -0
  181. {smftools-0.2.4.dist-info → smftools-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,35 +1,53 @@
1
+ from __future__ import annotations
1
2
 
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
- import torch
5
3
  import os
6
4
 
5
+ import numpy as np
6
+
7
+ from smftools.optional_imports import require
8
+
9
+ plt = require("matplotlib.pyplot", extra="plotting", purpose="model plots")
10
+ torch = require("torch", extra="ml-base", purpose="model saliency plots")
11
+
12
+
7
13
  def plot_model_performance(metrics, save_path=None):
8
- import matplotlib.pyplot as plt
14
+ """Plot ROC and precision-recall curves for model metrics.
15
+
16
+ Args:
17
+ metrics: Dictionary of model metrics by reference.
18
+ save_path: Optional path to save plots.
19
+ """
9
20
  import os
21
+
10
22
  for ref in metrics.keys():
11
23
  plt.figure(figsize=(12, 5))
12
24
 
13
25
  # ROC Curve
14
26
  plt.subplot(1, 2, 1)
15
27
  for model_name, vals in metrics[ref].items():
16
- model_type = model_name.split('_')[0]
28
+ model_type = model_name.split("_")[0]
17
29
  data_type = model_name.split(f"{model_type}_")[1]
18
- plt.plot(vals['fpr'], vals['tpr'], label=f"{model_type.upper()} - AUC: {vals['auc']:.4f}")
19
- plt.xlabel('False Positive Rate')
20
- plt.ylabel('True Positive Rate')
21
- plt.title(f'{data_type} ROC Curve ({ref})')
30
+ plt.plot(
31
+ vals["fpr"], vals["tpr"], label=f"{model_type.upper()} - AUC: {vals['auc']:.4f}"
32
+ )
33
+ plt.xlabel("False Positive Rate")
34
+ plt.ylabel("True Positive Rate")
35
+ plt.title(f"{data_type} ROC Curve ({ref})")
22
36
  plt.legend()
23
37
 
24
38
  # PR Curve
25
39
  plt.subplot(1, 2, 2)
26
40
  for model_name, vals in metrics[ref].items():
27
- model_type = model_name.split('_')[0]
41
+ model_type = model_name.split("_")[0]
28
42
  data_type = model_name.split(f"{model_type}_")[1]
29
- plt.plot(vals['recall'], vals['precision'], label=f"{model_type.upper()} - F1: {vals['f1']:.4f}")
30
- plt.xlabel('Recall')
31
- plt.ylabel('Precision')
32
- plt.title(f'{data_type} Precision-Recall Curve ({ref})')
43
+ plt.plot(
44
+ vals["recall"],
45
+ vals["precision"],
46
+ label=f"{model_type.upper()} - F1: {vals['f1']:.4f}",
47
+ )
48
+ plt.xlabel("Recall")
49
+ plt.ylabel("Precision")
50
+ plt.title(f"{data_type} Precision-Recall Curve ({ref})")
33
51
  plt.legend()
34
52
 
35
53
  plt.tight_layout()
@@ -42,13 +60,14 @@ def plot_model_performance(metrics, save_path=None):
42
60
  plt.savefig(out_file, dpi=300)
43
61
  print(f"📁 Saved: {out_file}")
44
62
  plt.show()
45
-
63
+
46
64
  # Confusion Matrices
47
65
  for model_name, vals in metrics[ref].items():
48
66
  print(f"Confusion Matrix for {ref} - {model_name.upper()}:")
49
- print(vals['confusion_matrix'])
67
+ print(vals["confusion_matrix"])
50
68
  print()
51
69
 
70
+
52
71
  def plot_feature_importances_or_saliency(
53
72
  models,
54
73
  positions,
@@ -57,18 +76,31 @@ def plot_feature_importances_or_saliency(
57
76
  adata=None,
58
77
  layer_name=None,
59
78
  save_path=None,
60
- shaded_regions=None
79
+ shaded_regions=None,
61
80
  ):
62
- import torch
63
- import numpy as np
64
- import matplotlib.pyplot as plt
81
+ """Plot feature importances or saliency for trained models.
82
+
83
+ Args:
84
+ models: Mapping of trained models.
85
+ positions: Mapping of positions per reference.
86
+ tensors: Mapping of input tensors per reference.
87
+ site_config: Site configuration mapping.
88
+ adata: Optional AnnData object.
89
+ layer_name: Optional layer name for plotting.
90
+ save_path: Optional path to save plots.
91
+ shaded_regions: Optional list of regions to highlight.
92
+ """
65
93
  import os
66
94
 
95
+ import numpy as np
96
+
67
97
  # Select device for NN models
68
98
  device = (
69
- torch.device('cuda') if torch.cuda.is_available() else
70
- torch.device('mps') if torch.backends.mps.is_available() else
71
- torch.device('cpu')
99
+ torch.device("cuda")
100
+ if torch.cuda.is_available()
101
+ else torch.device("mps")
102
+ if torch.backends.mps.is_available()
103
+ else torch.device("cpu")
72
104
  )
73
105
 
74
106
  for ref, model_dict in models.items():
@@ -90,7 +122,9 @@ def plot_feature_importances_or_saliency(
90
122
  other_sites = set()
91
123
 
92
124
  if adata is None:
93
- print("⚠️ AnnData object is required to classify site types. Skipping site type markers.")
125
+ print(
126
+ "⚠️ AnnData object is required to classify site types. Skipping site type markers."
127
+ )
94
128
  else:
95
129
  gpc_col = f"{ref}_GpC_site"
96
130
  cpg_col = f"{ref}_CpG_site"
@@ -146,20 +180,46 @@ def plot_feature_importances_or_saliency(
146
180
  plt.figure(figsize=(12, 4))
147
181
  for pos, imp in zip(positions_sorted, importances_sorted):
148
182
  if pos in cpg_sites:
149
- plt.plot(pos, imp, marker='*', color='black', markersize=10, linestyle='None',
150
- label='CpG site' if 'CpG site' not in plt.gca().get_legend_handles_labels()[1] else "")
183
+ plt.plot(
184
+ pos,
185
+ imp,
186
+ marker="*",
187
+ color="black",
188
+ markersize=10,
189
+ linestyle="None",
190
+ label="CpG site"
191
+ if "CpG site" not in plt.gca().get_legend_handles_labels()[1]
192
+ else "",
193
+ )
151
194
  elif pos in gpc_sites:
152
- plt.plot(pos, imp, marker='o', color='blue', markersize=6, linestyle='None',
153
- label='GpC site' if 'GpC site' not in plt.gca().get_legend_handles_labels()[1] else "")
195
+ plt.plot(
196
+ pos,
197
+ imp,
198
+ marker="o",
199
+ color="blue",
200
+ markersize=6,
201
+ linestyle="None",
202
+ label="GpC site"
203
+ if "GpC site" not in plt.gca().get_legend_handles_labels()[1]
204
+ else "",
205
+ )
154
206
  else:
155
- plt.plot(pos, imp, marker='.', color='gray', linestyle='None',
156
- label='Other' if 'Other' not in plt.gca().get_legend_handles_labels()[1] else "")
157
-
158
- plt.plot(positions_sorted, importances_sorted, linestyle='-', alpha=0.5, color='black')
207
+ plt.plot(
208
+ pos,
209
+ imp,
210
+ marker=".",
211
+ color="gray",
212
+ linestyle="None",
213
+ label="Other"
214
+ if "Other" not in plt.gca().get_legend_handles_labels()[1]
215
+ else "",
216
+ )
217
+
218
+ plt.plot(positions_sorted, importances_sorted, linestyle="-", alpha=0.5, color="black")
159
219
 
160
220
  if shaded_regions:
161
- for (start, end) in shaded_regions:
162
- plt.axvspan(start, end, color='gray', alpha=0.3)
221
+ for start, end in shaded_regions:
222
+ plt.axvspan(start, end, color="gray", alpha=0.3)
163
223
 
164
224
  plt.xlabel("Genomic Position")
165
225
  plt.ylabel(y_label)
@@ -170,31 +230,50 @@ def plot_feature_importances_or_saliency(
170
230
 
171
231
  if save_path:
172
232
  os.makedirs(save_path, exist_ok=True)
173
- safe_name = plot_title.replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
233
+ safe_name = (
234
+ plot_title.replace("=", "")
235
+ .replace("__", "_")
236
+ .replace(",", "_")
237
+ .replace(" ", "_")
238
+ )
174
239
  out_file = os.path.join(save_path, f"{safe_name}.png")
175
240
  plt.savefig(out_file, dpi=300)
176
241
  print(f"📁 Saved: {out_file}")
177
242
 
178
243
  plt.show()
179
244
 
245
+
180
246
  def plot_model_curves_from_adata(
181
- adata,
182
- label_col='activity_status',
183
- model_names = ["cnn", "mlp", "rf"],
184
- suffix='GpC_site_CpG_site',
185
- omit_training=True,
186
- save_path=None,
187
- ylim_roc=(0.0, 1.05),
188
- ylim_pr=(0.0, 1.05)):
189
-
190
- from sklearn.metrics import precision_recall_curve, roc_curve, auc
191
- import matplotlib.pyplot as plt
192
- import seaborn as sns
247
+ adata,
248
+ label_col="activity_status",
249
+ model_names=["cnn", "mlp", "rf"],
250
+ suffix="GpC_site_CpG_site",
251
+ omit_training=True,
252
+ save_path=None,
253
+ ylim_roc=(0.0, 1.05),
254
+ ylim_pr=(0.0, 1.05),
255
+ ):
256
+ """Plot ROC and PR curves using AnnData model outputs.
257
+
258
+ Args:
259
+ adata: AnnData containing model outputs.
260
+ label_col: Ground-truth label column.
261
+ model_names: Model name prefixes.
262
+ suffix: Prediction column suffix.
263
+ omit_training: Whether to omit training rows.
264
+ save_path: Optional path to save the plot.
265
+ ylim_roc: Y-axis limits for ROC curve.
266
+ ylim_pr: Y-axis limits for PR curve.
267
+ """
268
+ sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model curves")
269
+ auc = sklearn_metrics.auc
270
+ precision_recall_curve = sklearn_metrics.precision_recall_curve
271
+ roc_curve = sklearn_metrics.roc_curve
193
272
 
194
273
  if omit_training:
195
- subset = adata[adata.obs['used_for_training'].astype(bool) == False]
274
+ subset = adata[~adata.obs["used_for_training"].astype(bool)]
196
275
 
197
- label = subset.obs[label_col].map({'Active': 1, 'Silent': 0}).values
276
+ label = subset.obs[label_col].map({"Active": 1, "Silent": 0}).values
198
277
 
199
278
  positive_ratio = np.sum(label.astype(int)) / len(label)
200
279
 
@@ -210,7 +289,7 @@ def plot_model_curves_from_adata(
210
289
  roc_auc = auc(fpr, tpr)
211
290
  plt.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
212
291
 
213
- plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
292
+ plt.plot([0, 1], [0, 1], "k--", alpha=0.5)
214
293
  plt.xlabel("False Positive Rate")
215
294
  plt.ylabel("True Positive Rate")
216
295
  plt.title("ROC Curve")
@@ -230,13 +309,13 @@ def plot_model_curves_from_adata(
230
309
  plt.xlabel("Recall")
231
310
  plt.ylabel("Precision")
232
311
  plt.ylim(*ylim_pr)
233
- plt.axhline(y=positive_ratio, linestyle='--', color='gray', label='Random Baseline')
312
+ plt.axhline(y=positive_ratio, linestyle="--", color="gray", label="Random Baseline")
234
313
  plt.title("Precision-Recall Curve")
235
314
  plt.legend()
236
315
 
237
316
  plt.tight_layout()
238
317
  if save_path:
239
- save_name = f"ROC_PR_curves"
318
+ save_name = "ROC_PR_curves"
240
319
  os.makedirs(save_path, exist_ok=True)
241
320
  safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
242
321
  out_file = os.path.join(save_path, f"{safe_name}.png")
@@ -244,11 +323,12 @@ def plot_model_curves_from_adata(
244
323
  print(f"📁 Saved: {out_file}")
245
324
  plt.show()
246
325
 
326
+
247
327
  def plot_model_curves_from_adata_with_frequency_grid(
248
328
  adata,
249
- label_col='activity_status',
329
+ label_col="activity_status",
250
330
  model_names=["cnn", "mlp", "rf"],
251
- suffix='GpC_site_CpG_site',
331
+ suffix="GpC_site_CpG_site",
252
332
  omit_training=True,
253
333
  save_path=None,
254
334
  ylim_roc=(0.0, 1.05),
@@ -256,22 +336,42 @@ def plot_model_curves_from_adata_with_frequency_grid(
256
336
  pos_sample_count=500,
257
337
  pos_freq_list=[0.01, 0.05, 0.1],
258
338
  show_f1_iso_curves=False,
259
- f1_levels=None):
260
- import numpy as np
261
- import matplotlib.pyplot as plt
262
- import seaborn as sns
339
+ f1_levels=None,
340
+ ):
341
+ """Plot ROC/PR curves with frequency grid overlays.
342
+
343
+ Args:
344
+ adata: AnnData containing model outputs.
345
+ label_col: Ground-truth label column.
346
+ model_names: Model name prefixes.
347
+ suffix: Prediction column suffix.
348
+ omit_training: Whether to omit training rows.
349
+ save_path: Optional path to save the plot.
350
+ ylim_roc: Y-axis limits for ROC curve.
351
+ ylim_pr: Y-axis limits for PR curve.
352
+ pos_sample_count: Sample count for positive baseline.
353
+ pos_freq_list: List of positive class frequencies to plot.
354
+ show_f1_iso_curves: Whether to show F1 iso-curves.
355
+ f1_levels: F1 levels to plot if enabled.
356
+ """
263
357
  import os
264
- from sklearn.metrics import precision_recall_curve, roc_curve, auc
265
-
358
+
359
+ import numpy as np
360
+
361
+ sklearn_metrics = require("sklearn.metrics", extra="ml-base", purpose="model curves")
362
+ auc = sklearn_metrics.auc
363
+ precision_recall_curve = sklearn_metrics.precision_recall_curve
364
+ roc_curve = sklearn_metrics.roc_curve
365
+
266
366
  if f1_levels is None:
267
367
  f1_levels = np.linspace(0.2, 0.9, 8)
268
-
368
+
269
369
  if omit_training:
270
- subset = adata[adata.obs['used_for_training'].astype(bool) == False]
370
+ subset = adata[~adata.obs["used_for_training"].astype(bool)]
271
371
  else:
272
372
  subset = adata
273
373
 
274
- label = subset.obs[label_col].map({'Active': 1, 'Silent': 0}).values
374
+ label = subset.obs[label_col].map({"Active": 1, "Silent": 0}).values
275
375
  subset = subset.copy()
276
376
  subset.obs["__label__"] = label
277
377
 
@@ -280,7 +380,7 @@ def plot_model_curves_from_adata_with_frequency_grid(
280
380
 
281
381
  n_rows = len(pos_freq_list)
282
382
  fig, axes = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows))
283
- fig.suptitle(f'{suffix} Performance metrics')
383
+ fig.suptitle(f"{suffix} Performance metrics")
284
384
 
285
385
  for row_idx, pos_freq in enumerate(pos_freq_list):
286
386
  desired_total = int(pos_sample_count / pos_freq)
@@ -308,14 +408,14 @@ def plot_model_curves_from_adata_with_frequency_grid(
308
408
  fpr, tpr, _ = roc_curve(y_true, probs)
309
409
  roc_auc = auc(fpr, tpr)
310
410
  ax_roc.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
311
- ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.5)
411
+ ax_roc.plot([0, 1], [0, 1], "k--", alpha=0.5)
312
412
  ax_roc.set_xlabel("False Positive Rate")
313
413
  ax_roc.set_ylabel("True Positive Rate")
314
414
  ax_roc.set_ylim(*ylim_roc)
315
415
  ax_roc.set_title(f"ROC Curve (Pos Freq: {pos_freq:.2%})")
316
416
  ax_roc.legend()
317
- ax_roc.spines['top'].set_visible(False)
318
- ax_roc.spines['right'].set_visible(False)
417
+ ax_roc.spines["top"].set_visible(False)
418
+ ax_roc.spines["right"].set_visible(False)
319
419
 
320
420
  # PR Curve
321
421
  for model in model_names:
@@ -325,26 +425,28 @@ def plot_model_curves_from_adata_with_frequency_grid(
325
425
  precision, recall, _ = precision_recall_curve(y_true, probs)
326
426
  pr_auc = auc(recall, precision)
327
427
  ax_pr.plot(recall, precision, label=f"{model.upper()} (AUC={pr_auc:.4f})")
328
- ax_pr.axhline(y=pos_freq, linestyle='--', color='gray', label='Random Baseline')
428
+ ax_pr.axhline(y=pos_freq, linestyle="--", color="gray", label="Random Baseline")
329
429
 
330
430
  if show_f1_iso_curves:
331
431
  recall_vals = np.linspace(0.01, 1, 500)
332
432
  for f1 in f1_levels:
333
433
  precision_vals = (f1 * recall_vals) / (2 * recall_vals - f1)
334
434
  precision_vals[precision_vals < 0] = np.nan # Avoid plotting invalid values
335
- ax_pr.plot(recall_vals, precision_vals, color='gray', linestyle=':', linewidth=1, alpha=0.6)
435
+ ax_pr.plot(
436
+ recall_vals, precision_vals, color="gray", linestyle=":", linewidth=1, alpha=0.6
437
+ )
336
438
  x_val = 0.9
337
439
  y_val = (f1 * x_val) / (2 * x_val - f1)
338
440
  if 0 < y_val < 1:
339
- ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color='gray')
441
+ ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color="gray")
340
442
 
341
443
  ax_pr.set_xlabel("Recall")
342
444
  ax_pr.set_ylabel("Precision")
343
445
  ax_pr.set_ylim(*ylim_pr)
344
446
  ax_pr.set_title(f"PR Curve (Pos Freq: {pos_freq:.2%})")
345
447
  ax_pr.legend()
346
- ax_pr.spines['top'].set_visible(False)
347
- ax_pr.spines['right'].set_visible(False)
448
+ ax_pr.spines["top"].set_visible(False)
449
+ ax_pr.spines["right"].set_visible(False)
348
450
 
349
451
  plt.tight_layout(rect=[0, 0, 1, 0.97])
350
452
  if save_path:
@@ -352,4 +454,4 @@ def plot_model_curves_from_adata_with_frequency_grid(
352
454
  out_file = os.path.join(save_path, "ROC_PR_grid.png")
353
455
  plt.savefig(out_file, dpi=300)
354
456
  print(f"📁 Saved: {out_file}")
355
- plt.show()
457
+ plt.show()