smftools 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. smftools/__init__.py +6 -8
  2. smftools/_settings.py +4 -6
  3. smftools/_version.py +1 -1
  4. smftools/cli/helpers.py +54 -0
  5. smftools/cli/hmm_adata.py +937 -256
  6. smftools/cli/load_adata.py +448 -268
  7. smftools/cli/preprocess_adata.py +469 -263
  8. smftools/cli/spatial_adata.py +536 -319
  9. smftools/cli_entry.py +97 -182
  10. smftools/config/__init__.py +1 -1
  11. smftools/config/conversion.yaml +17 -6
  12. smftools/config/deaminase.yaml +12 -10
  13. smftools/config/default.yaml +142 -33
  14. smftools/config/direct.yaml +11 -3
  15. smftools/config/discover_input_files.py +19 -5
  16. smftools/config/experiment_config.py +594 -264
  17. smftools/constants.py +37 -0
  18. smftools/datasets/__init__.py +2 -8
  19. smftools/datasets/datasets.py +32 -18
  20. smftools/hmm/HMM.py +2128 -1418
  21. smftools/hmm/__init__.py +2 -9
  22. smftools/hmm/archived/call_hmm_peaks.py +121 -0
  23. smftools/hmm/call_hmm_peaks.py +299 -91
  24. smftools/hmm/display_hmm.py +19 -6
  25. smftools/hmm/hmm_readwrite.py +13 -4
  26. smftools/hmm/nucleosome_hmm_refinement.py +102 -14
  27. smftools/informatics/__init__.py +30 -7
  28. smftools/informatics/archived/helpers/archived/align_and_sort_BAM.py +14 -1
  29. smftools/informatics/archived/helpers/archived/bam_qc.py +14 -1
  30. smftools/informatics/archived/helpers/archived/concatenate_fastqs_to_bam.py +8 -1
  31. smftools/informatics/archived/helpers/archived/load_adata.py +3 -3
  32. smftools/informatics/archived/helpers/archived/plot_bed_histograms.py +3 -1
  33. smftools/informatics/archived/print_bam_query_seq.py +7 -1
  34. smftools/informatics/bam_functions.py +397 -175
  35. smftools/informatics/basecalling.py +51 -9
  36. smftools/informatics/bed_functions.py +90 -57
  37. smftools/informatics/binarize_converted_base_identities.py +18 -7
  38. smftools/informatics/complement_base_list.py +7 -6
  39. smftools/informatics/converted_BAM_to_adata.py +265 -122
  40. smftools/informatics/fasta_functions.py +161 -83
  41. smftools/informatics/h5ad_functions.py +196 -30
  42. smftools/informatics/modkit_extract_to_adata.py +609 -270
  43. smftools/informatics/modkit_functions.py +85 -44
  44. smftools/informatics/ohe.py +44 -21
  45. smftools/informatics/pod5_functions.py +112 -73
  46. smftools/informatics/run_multiqc.py +20 -14
  47. smftools/logging_utils.py +51 -0
  48. smftools/machine_learning/__init__.py +2 -7
  49. smftools/machine_learning/data/anndata_data_module.py +143 -50
  50. smftools/machine_learning/data/preprocessing.py +2 -1
  51. smftools/machine_learning/evaluation/__init__.py +1 -1
  52. smftools/machine_learning/evaluation/eval_utils.py +11 -14
  53. smftools/machine_learning/evaluation/evaluators.py +46 -33
  54. smftools/machine_learning/inference/__init__.py +1 -1
  55. smftools/machine_learning/inference/inference_utils.py +7 -4
  56. smftools/machine_learning/inference/lightning_inference.py +9 -13
  57. smftools/machine_learning/inference/sklearn_inference.py +6 -8
  58. smftools/machine_learning/inference/sliding_window_inference.py +35 -25
  59. smftools/machine_learning/models/__init__.py +10 -5
  60. smftools/machine_learning/models/base.py +28 -42
  61. smftools/machine_learning/models/cnn.py +15 -11
  62. smftools/machine_learning/models/lightning_base.py +71 -40
  63. smftools/machine_learning/models/mlp.py +13 -4
  64. smftools/machine_learning/models/positional.py +3 -2
  65. smftools/machine_learning/models/rnn.py +3 -2
  66. smftools/machine_learning/models/sklearn_models.py +39 -22
  67. smftools/machine_learning/models/transformer.py +68 -53
  68. smftools/machine_learning/models/wrappers.py +2 -1
  69. smftools/machine_learning/training/__init__.py +2 -2
  70. smftools/machine_learning/training/train_lightning_model.py +29 -20
  71. smftools/machine_learning/training/train_sklearn_model.py +9 -15
  72. smftools/machine_learning/utils/__init__.py +1 -1
  73. smftools/machine_learning/utils/device.py +7 -4
  74. smftools/machine_learning/utils/grl.py +3 -1
  75. smftools/metadata.py +443 -0
  76. smftools/plotting/__init__.py +19 -5
  77. smftools/plotting/autocorrelation_plotting.py +145 -44
  78. smftools/plotting/classifiers.py +162 -72
  79. smftools/plotting/general_plotting.py +422 -197
  80. smftools/plotting/hmm_plotting.py +42 -13
  81. smftools/plotting/position_stats.py +147 -87
  82. smftools/plotting/qc_plotting.py +20 -12
  83. smftools/preprocessing/__init__.py +10 -12
  84. smftools/preprocessing/append_base_context.py +115 -80
  85. smftools/preprocessing/append_binary_layer_by_base_context.py +77 -39
  86. smftools/preprocessing/{calculate_complexity.py → archived/calculate_complexity.py} +3 -1
  87. smftools/preprocessing/{archives → archived}/preprocessing.py +8 -6
  88. smftools/preprocessing/binarize.py +21 -4
  89. smftools/preprocessing/binarize_on_Youden.py +129 -31
  90. smftools/preprocessing/binary_layers_to_ohe.py +17 -11
  91. smftools/preprocessing/calculate_complexity_II.py +86 -59
  92. smftools/preprocessing/calculate_consensus.py +28 -19
  93. smftools/preprocessing/calculate_coverage.py +50 -25
  94. smftools/preprocessing/calculate_pairwise_differences.py +2 -1
  95. smftools/preprocessing/calculate_pairwise_hamming_distances.py +4 -3
  96. smftools/preprocessing/calculate_position_Youden.py +118 -54
  97. smftools/preprocessing/calculate_read_length_stats.py +52 -23
  98. smftools/preprocessing/calculate_read_modification_stats.py +91 -57
  99. smftools/preprocessing/clean_NaN.py +38 -28
  100. smftools/preprocessing/filter_adata_by_nan_proportion.py +24 -12
  101. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +71 -38
  102. smftools/preprocessing/filter_reads_on_modification_thresholds.py +181 -73
  103. smftools/preprocessing/flag_duplicate_reads.py +689 -272
  104. smftools/preprocessing/invert_adata.py +26 -11
  105. smftools/preprocessing/load_sample_sheet.py +40 -22
  106. smftools/preprocessing/make_dirs.py +8 -3
  107. smftools/preprocessing/min_non_diagonal.py +2 -1
  108. smftools/preprocessing/recipes.py +56 -23
  109. smftools/preprocessing/reindex_references_adata.py +103 -0
  110. smftools/preprocessing/subsample_adata.py +33 -16
  111. smftools/readwrite.py +331 -82
  112. smftools/schema/__init__.py +11 -0
  113. smftools/schema/anndata_schema_v1.yaml +227 -0
  114. smftools/tools/__init__.py +3 -4
  115. smftools/tools/archived/classifiers.py +163 -0
  116. smftools/tools/archived/subset_adata_v1.py +10 -1
  117. smftools/tools/archived/subset_adata_v2.py +12 -1
  118. smftools/tools/calculate_umap.py +54 -15
  119. smftools/tools/cluster_adata_on_methylation.py +115 -46
  120. smftools/tools/general_tools.py +70 -25
  121. smftools/tools/position_stats.py +229 -98
  122. smftools/tools/read_stats.py +50 -29
  123. smftools/tools/spatial_autocorrelation.py +365 -192
  124. smftools/tools/subset_adata.py +23 -21
  125. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/METADATA +17 -39
  126. smftools-0.2.5.dist-info/RECORD +181 -0
  127. smftools-0.2.3.dist-info/RECORD +0 -173
  128. /smftools/cli/{cli_flows.py → archived/cli_flows.py} +0 -0
  129. /smftools/hmm/{apply_hmm_batched.py → archived/apply_hmm_batched.py} +0 -0
  130. /smftools/hmm/{calculate_distances.py → archived/calculate_distances.py} +0 -0
  131. /smftools/hmm/{train_hmm.py → archived/train_hmm.py} +0 -0
  132. /smftools/preprocessing/{add_read_length_and_mapping_qc.py → archived/add_read_length_and_mapping_qc.py} +0 -0
  133. /smftools/preprocessing/{archives → archived}/mark_duplicates.py +0 -0
  134. /smftools/preprocessing/{archives → archived}/remove_duplicates.py +0 -0
  135. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/WHEEL +0 -0
  136. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/entry_points.txt +0 -0
  137. {smftools-0.2.3.dist-info → smftools-0.2.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,35 +1,48 @@
1
+ import os
1
2
 
2
- import numpy as np
3
3
  import matplotlib.pyplot as plt
4
+ import numpy as np
4
5
  import torch
5
- import os
6
+
6
7
 
7
8
  def plot_model_performance(metrics, save_path=None):
8
- import matplotlib.pyplot as plt
9
+ """Plot ROC and precision-recall curves for model metrics.
10
+
11
+ Args:
12
+ metrics: Dictionary of model metrics by reference.
13
+ save_path: Optional path to save plots.
14
+ """
9
15
  import os
16
+
10
17
  for ref in metrics.keys():
11
18
  plt.figure(figsize=(12, 5))
12
19
 
13
20
  # ROC Curve
14
21
  plt.subplot(1, 2, 1)
15
22
  for model_name, vals in metrics[ref].items():
16
- model_type = model_name.split('_')[0]
23
+ model_type = model_name.split("_")[0]
17
24
  data_type = model_name.split(f"{model_type}_")[1]
18
- plt.plot(vals['fpr'], vals['tpr'], label=f"{model_type.upper()} - AUC: {vals['auc']:.4f}")
19
- plt.xlabel('False Positive Rate')
20
- plt.ylabel('True Positive Rate')
21
- plt.title(f'{data_type} ROC Curve ({ref})')
25
+ plt.plot(
26
+ vals["fpr"], vals["tpr"], label=f"{model_type.upper()} - AUC: {vals['auc']:.4f}"
27
+ )
28
+ plt.xlabel("False Positive Rate")
29
+ plt.ylabel("True Positive Rate")
30
+ plt.title(f"{data_type} ROC Curve ({ref})")
22
31
  plt.legend()
23
32
 
24
33
  # PR Curve
25
34
  plt.subplot(1, 2, 2)
26
35
  for model_name, vals in metrics[ref].items():
27
- model_type = model_name.split('_')[0]
36
+ model_type = model_name.split("_")[0]
28
37
  data_type = model_name.split(f"{model_type}_")[1]
29
- plt.plot(vals['recall'], vals['precision'], label=f"{model_type.upper()} - F1: {vals['f1']:.4f}")
30
- plt.xlabel('Recall')
31
- plt.ylabel('Precision')
32
- plt.title(f'{data_type} Precision-Recall Curve ({ref})')
38
+ plt.plot(
39
+ vals["recall"],
40
+ vals["precision"],
41
+ label=f"{model_type.upper()} - F1: {vals['f1']:.4f}",
42
+ )
43
+ plt.xlabel("Recall")
44
+ plt.ylabel("Precision")
45
+ plt.title(f"{data_type} Precision-Recall Curve ({ref})")
33
46
  plt.legend()
34
47
 
35
48
  plt.tight_layout()
@@ -42,13 +55,14 @@ def plot_model_performance(metrics, save_path=None):
42
55
  plt.savefig(out_file, dpi=300)
43
56
  print(f"📁 Saved: {out_file}")
44
57
  plt.show()
45
-
58
+
46
59
  # Confusion Matrices
47
60
  for model_name, vals in metrics[ref].items():
48
61
  print(f"Confusion Matrix for {ref} - {model_name.upper()}:")
49
- print(vals['confusion_matrix'])
62
+ print(vals["confusion_matrix"])
50
63
  print()
51
64
 
65
+
52
66
  def plot_feature_importances_or_saliency(
53
67
  models,
54
68
  positions,
@@ -57,18 +71,31 @@ def plot_feature_importances_or_saliency(
57
71
  adata=None,
58
72
  layer_name=None,
59
73
  save_path=None,
60
- shaded_regions=None
74
+ shaded_regions=None,
61
75
  ):
62
- import torch
63
- import numpy as np
64
- import matplotlib.pyplot as plt
76
+ """Plot feature importances or saliency for trained models.
77
+
78
+ Args:
79
+ models: Mapping of trained models.
80
+ positions: Mapping of positions per reference.
81
+ tensors: Mapping of input tensors per reference.
82
+ site_config: Site configuration mapping.
83
+ adata: Optional AnnData object.
84
+ layer_name: Optional layer name for plotting.
85
+ save_path: Optional path to save plots.
86
+ shaded_regions: Optional list of regions to highlight.
87
+ """
65
88
  import os
66
89
 
90
+ import numpy as np
91
+
67
92
  # Select device for NN models
68
93
  device = (
69
- torch.device('cuda') if torch.cuda.is_available() else
70
- torch.device('mps') if torch.backends.mps.is_available() else
71
- torch.device('cpu')
94
+ torch.device("cuda")
95
+ if torch.cuda.is_available()
96
+ else torch.device("mps")
97
+ if torch.backends.mps.is_available()
98
+ else torch.device("cpu")
72
99
  )
73
100
 
74
101
  for ref, model_dict in models.items():
@@ -90,7 +117,9 @@ def plot_feature_importances_or_saliency(
90
117
  other_sites = set()
91
118
 
92
119
  if adata is None:
93
- print("⚠️ AnnData object is required to classify site types. Skipping site type markers.")
120
+ print(
121
+ "⚠️ AnnData object is required to classify site types. Skipping site type markers."
122
+ )
94
123
  else:
95
124
  gpc_col = f"{ref}_GpC_site"
96
125
  cpg_col = f"{ref}_CpG_site"
@@ -146,20 +175,46 @@ def plot_feature_importances_or_saliency(
146
175
  plt.figure(figsize=(12, 4))
147
176
  for pos, imp in zip(positions_sorted, importances_sorted):
148
177
  if pos in cpg_sites:
149
- plt.plot(pos, imp, marker='*', color='black', markersize=10, linestyle='None',
150
- label='CpG site' if 'CpG site' not in plt.gca().get_legend_handles_labels()[1] else "")
178
+ plt.plot(
179
+ pos,
180
+ imp,
181
+ marker="*",
182
+ color="black",
183
+ markersize=10,
184
+ linestyle="None",
185
+ label="CpG site"
186
+ if "CpG site" not in plt.gca().get_legend_handles_labels()[1]
187
+ else "",
188
+ )
151
189
  elif pos in gpc_sites:
152
- plt.plot(pos, imp, marker='o', color='blue', markersize=6, linestyle='None',
153
- label='GpC site' if 'GpC site' not in plt.gca().get_legend_handles_labels()[1] else "")
190
+ plt.plot(
191
+ pos,
192
+ imp,
193
+ marker="o",
194
+ color="blue",
195
+ markersize=6,
196
+ linestyle="None",
197
+ label="GpC site"
198
+ if "GpC site" not in plt.gca().get_legend_handles_labels()[1]
199
+ else "",
200
+ )
154
201
  else:
155
- plt.plot(pos, imp, marker='.', color='gray', linestyle='None',
156
- label='Other' if 'Other' not in plt.gca().get_legend_handles_labels()[1] else "")
157
-
158
- plt.plot(positions_sorted, importances_sorted, linestyle='-', alpha=0.5, color='black')
202
+ plt.plot(
203
+ pos,
204
+ imp,
205
+ marker=".",
206
+ color="gray",
207
+ linestyle="None",
208
+ label="Other"
209
+ if "Other" not in plt.gca().get_legend_handles_labels()[1]
210
+ else "",
211
+ )
212
+
213
+ plt.plot(positions_sorted, importances_sorted, linestyle="-", alpha=0.5, color="black")
159
214
 
160
215
  if shaded_regions:
161
- for (start, end) in shaded_regions:
162
- plt.axvspan(start, end, color='gray', alpha=0.3)
216
+ for start, end in shaded_regions:
217
+ plt.axvspan(start, end, color="gray", alpha=0.3)
163
218
 
164
219
  plt.xlabel("Genomic Position")
165
220
  plt.ylabel(y_label)
@@ -170,31 +225,47 @@ def plot_feature_importances_or_saliency(
170
225
 
171
226
  if save_path:
172
227
  os.makedirs(save_path, exist_ok=True)
173
- safe_name = plot_title.replace("=", "").replace("__", "_").replace(",", "_").replace(" ", "_")
228
+ safe_name = (
229
+ plot_title.replace("=", "")
230
+ .replace("__", "_")
231
+ .replace(",", "_")
232
+ .replace(" ", "_")
233
+ )
174
234
  out_file = os.path.join(save_path, f"{safe_name}.png")
175
235
  plt.savefig(out_file, dpi=300)
176
236
  print(f"📁 Saved: {out_file}")
177
237
 
178
238
  plt.show()
179
239
 
240
+
180
241
  def plot_model_curves_from_adata(
181
- adata,
182
- label_col='activity_status',
183
- model_names = ["cnn", "mlp", "rf"],
184
- suffix='GpC_site_CpG_site',
185
- omit_training=True,
186
- save_path=None,
187
- ylim_roc=(0.0, 1.05),
188
- ylim_pr=(0.0, 1.05)):
189
-
190
- from sklearn.metrics import precision_recall_curve, roc_curve, auc
191
- import matplotlib.pyplot as plt
192
- import seaborn as sns
242
+ adata,
243
+ label_col="activity_status",
244
+ model_names=["cnn", "mlp", "rf"],
245
+ suffix="GpC_site_CpG_site",
246
+ omit_training=True,
247
+ save_path=None,
248
+ ylim_roc=(0.0, 1.05),
249
+ ylim_pr=(0.0, 1.05),
250
+ ):
251
+ """Plot ROC and PR curves using AnnData model outputs.
252
+
253
+ Args:
254
+ adata: AnnData containing model outputs.
255
+ label_col: Ground-truth label column.
256
+ model_names: Model name prefixes.
257
+ suffix: Prediction column suffix.
258
+ omit_training: Whether to omit training rows.
259
+ save_path: Optional path to save the plot.
260
+ ylim_roc: Y-axis limits for ROC curve.
261
+ ylim_pr: Y-axis limits for PR curve.
262
+ """
263
+ from sklearn.metrics import auc, precision_recall_curve, roc_curve
193
264
 
194
265
  if omit_training:
195
- subset = adata[adata.obs['used_for_training'].astype(bool) == False]
266
+ subset = adata[~adata.obs["used_for_training"].astype(bool)]
196
267
 
197
- label = subset.obs[label_col].map({'Active': 1, 'Silent': 0}).values
268
+ label = subset.obs[label_col].map({"Active": 1, "Silent": 0}).values
198
269
 
199
270
  positive_ratio = np.sum(label.astype(int)) / len(label)
200
271
 
@@ -210,7 +281,7 @@ def plot_model_curves_from_adata(
210
281
  roc_auc = auc(fpr, tpr)
211
282
  plt.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
212
283
 
213
- plt.plot([0, 1], [0, 1], 'k--', alpha=0.5)
284
+ plt.plot([0, 1], [0, 1], "k--", alpha=0.5)
214
285
  plt.xlabel("False Positive Rate")
215
286
  plt.ylabel("True Positive Rate")
216
287
  plt.title("ROC Curve")
@@ -230,13 +301,13 @@ def plot_model_curves_from_adata(
230
301
  plt.xlabel("Recall")
231
302
  plt.ylabel("Precision")
232
303
  plt.ylim(*ylim_pr)
233
- plt.axhline(y=positive_ratio, linestyle='--', color='gray', label='Random Baseline')
304
+ plt.axhline(y=positive_ratio, linestyle="--", color="gray", label="Random Baseline")
234
305
  plt.title("Precision-Recall Curve")
235
306
  plt.legend()
236
307
 
237
308
  plt.tight_layout()
238
309
  if save_path:
239
- save_name = f"ROC_PR_curves"
310
+ save_name = "ROC_PR_curves"
240
311
  os.makedirs(save_path, exist_ok=True)
241
312
  safe_name = save_name.replace("=", "").replace("__", "_").replace(",", "_")
242
313
  out_file = os.path.join(save_path, f"{safe_name}.png")
@@ -244,11 +315,12 @@ def plot_model_curves_from_adata(
244
315
  print(f"📁 Saved: {out_file}")
245
316
  plt.show()
246
317
 
318
+
247
319
  def plot_model_curves_from_adata_with_frequency_grid(
248
320
  adata,
249
- label_col='activity_status',
321
+ label_col="activity_status",
250
322
  model_names=["cnn", "mlp", "rf"],
251
- suffix='GpC_site_CpG_site',
323
+ suffix="GpC_site_CpG_site",
252
324
  omit_training=True,
253
325
  save_path=None,
254
326
  ylim_roc=(0.0, 1.05),
@@ -256,22 +328,38 @@ def plot_model_curves_from_adata_with_frequency_grid(
256
328
  pos_sample_count=500,
257
329
  pos_freq_list=[0.01, 0.05, 0.1],
258
330
  show_f1_iso_curves=False,
259
- f1_levels=None):
260
- import numpy as np
261
- import matplotlib.pyplot as plt
262
- import seaborn as sns
331
+ f1_levels=None,
332
+ ):
333
+ """Plot ROC/PR curves with frequency grid overlays.
334
+
335
+ Args:
336
+ adata: AnnData containing model outputs.
337
+ label_col: Ground-truth label column.
338
+ model_names: Model name prefixes.
339
+ suffix: Prediction column suffix.
340
+ omit_training: Whether to omit training rows.
341
+ save_path: Optional path to save the plot.
342
+ ylim_roc: Y-axis limits for ROC curve.
343
+ ylim_pr: Y-axis limits for PR curve.
344
+ pos_sample_count: Sample count for positive baseline.
345
+ pos_freq_list: List of positive class frequencies to plot.
346
+ show_f1_iso_curves: Whether to show F1 iso-curves.
347
+ f1_levels: F1 levels to plot if enabled.
348
+ """
263
349
  import os
264
- from sklearn.metrics import precision_recall_curve, roc_curve, auc
265
-
350
+
351
+ import numpy as np
352
+ from sklearn.metrics import auc, precision_recall_curve, roc_curve
353
+
266
354
  if f1_levels is None:
267
355
  f1_levels = np.linspace(0.2, 0.9, 8)
268
-
356
+
269
357
  if omit_training:
270
- subset = adata[adata.obs['used_for_training'].astype(bool) == False]
358
+ subset = adata[~adata.obs["used_for_training"].astype(bool)]
271
359
  else:
272
360
  subset = adata
273
361
 
274
- label = subset.obs[label_col].map({'Active': 1, 'Silent': 0}).values
362
+ label = subset.obs[label_col].map({"Active": 1, "Silent": 0}).values
275
363
  subset = subset.copy()
276
364
  subset.obs["__label__"] = label
277
365
 
@@ -280,7 +368,7 @@ def plot_model_curves_from_adata_with_frequency_grid(
280
368
 
281
369
  n_rows = len(pos_freq_list)
282
370
  fig, axes = plt.subplots(n_rows, 2, figsize=(12, 5 * n_rows))
283
- fig.suptitle(f'{suffix} Performance metrics')
371
+ fig.suptitle(f"{suffix} Performance metrics")
284
372
 
285
373
  for row_idx, pos_freq in enumerate(pos_freq_list):
286
374
  desired_total = int(pos_sample_count / pos_freq)
@@ -308,14 +396,14 @@ def plot_model_curves_from_adata_with_frequency_grid(
308
396
  fpr, tpr, _ = roc_curve(y_true, probs)
309
397
  roc_auc = auc(fpr, tpr)
310
398
  ax_roc.plot(fpr, tpr, label=f"{model.upper()} (AUC={roc_auc:.4f})")
311
- ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.5)
399
+ ax_roc.plot([0, 1], [0, 1], "k--", alpha=0.5)
312
400
  ax_roc.set_xlabel("False Positive Rate")
313
401
  ax_roc.set_ylabel("True Positive Rate")
314
402
  ax_roc.set_ylim(*ylim_roc)
315
403
  ax_roc.set_title(f"ROC Curve (Pos Freq: {pos_freq:.2%})")
316
404
  ax_roc.legend()
317
- ax_roc.spines['top'].set_visible(False)
318
- ax_roc.spines['right'].set_visible(False)
405
+ ax_roc.spines["top"].set_visible(False)
406
+ ax_roc.spines["right"].set_visible(False)
319
407
 
320
408
  # PR Curve
321
409
  for model in model_names:
@@ -325,26 +413,28 @@ def plot_model_curves_from_adata_with_frequency_grid(
325
413
  precision, recall, _ = precision_recall_curve(y_true, probs)
326
414
  pr_auc = auc(recall, precision)
327
415
  ax_pr.plot(recall, precision, label=f"{model.upper()} (AUC={pr_auc:.4f})")
328
- ax_pr.axhline(y=pos_freq, linestyle='--', color='gray', label='Random Baseline')
416
+ ax_pr.axhline(y=pos_freq, linestyle="--", color="gray", label="Random Baseline")
329
417
 
330
418
  if show_f1_iso_curves:
331
419
  recall_vals = np.linspace(0.01, 1, 500)
332
420
  for f1 in f1_levels:
333
421
  precision_vals = (f1 * recall_vals) / (2 * recall_vals - f1)
334
422
  precision_vals[precision_vals < 0] = np.nan # Avoid plotting invalid values
335
- ax_pr.plot(recall_vals, precision_vals, color='gray', linestyle=':', linewidth=1, alpha=0.6)
423
+ ax_pr.plot(
424
+ recall_vals, precision_vals, color="gray", linestyle=":", linewidth=1, alpha=0.6
425
+ )
336
426
  x_val = 0.9
337
427
  y_val = (f1 * x_val) / (2 * x_val - f1)
338
428
  if 0 < y_val < 1:
339
- ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color='gray')
429
+ ax_pr.text(x_val, y_val, f"F1={f1:.1f}", fontsize=8, color="gray")
340
430
 
341
431
  ax_pr.set_xlabel("Recall")
342
432
  ax_pr.set_ylabel("Precision")
343
433
  ax_pr.set_ylim(*ylim_pr)
344
434
  ax_pr.set_title(f"PR Curve (Pos Freq: {pos_freq:.2%})")
345
435
  ax_pr.legend()
346
- ax_pr.spines['top'].set_visible(False)
347
- ax_pr.spines['right'].set_visible(False)
436
+ ax_pr.spines["top"].set_visible(False)
437
+ ax_pr.spines["right"].set_visible(False)
348
438
 
349
439
  plt.tight_layout(rect=[0, 0, 1, 0.97])
350
440
  if save_path:
@@ -352,4 +442,4 @@ def plot_model_curves_from_adata_with_frequency_grid(
352
442
  out_file = os.path.join(save_path, "ROC_PR_grid.png")
353
443
  plt.savefig(out_file, dpi=300)
354
444
  print(f"📁 Saved: {out_file}")
355
- plt.show()
445
+ plt.show()