pythonflex 0.3.3__tar.gz → 0.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. pythonflex-0.3.3/src/pythonflex/plotting.py → pythonflex-0.4/.codex_backups/plotting.py.20260601-121332.bak +394 -165
  2. {pythonflex-0.3.3 → pythonflex-0.4}/.gitignore +3 -1
  3. pythonflex-0.4/LICENSE +7 -0
  4. {pythonflex-0.3.3 → pythonflex-0.4}/PKG-INFO +9 -4
  5. {pythonflex-0.3.3 → pythonflex-0.4}/README.md +7 -3
  6. {pythonflex-0.3.3 → pythonflex-0.4}/pyproject.toml +69 -77
  7. pythonflex-0.4/src/pythonflex/__init__.py +44 -0
  8. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/analysis.py +287 -578
  9. pythonflex-0.4/src/pythonflex/examples/basic_usage.py +95 -0
  10. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/examples/manuscript.py +37 -42
  11. pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark.py +218 -0
  12. pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_10_runs_memmap.py +534 -0
  13. pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_corum_njobs.py +245 -0
  14. pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_gobp_njobs_chunks.py +319 -0
  15. pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_gobp_optimization.py +417 -0
  16. pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_repeated.py +347 -0
  17. pythonflex-0.4/src/pythonflex/old_functions.py +422 -0
  18. pythonflex-0.4/src/pythonflex/plotting.py +2418 -0
  19. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/preprocessing.py +62 -60
  20. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/utils.py +36 -9
  21. pythonflex-0.4/todo.txt +2 -0
  22. pythonflex-0.3.3/src/pythonflex/__init__.py +0 -20
  23. pythonflex-0.3.3/src/pythonflex/examples/basic_usage.py +0 -87
  24. pythonflex-0.3.3/todo.txt +0 -1
  25. {pythonflex-0.3.3 → pythonflex-0.4}/.python-version +0 -0
  26. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/__init__.py +0 -0
  27. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/dataset/__init__.py +0 -0
  28. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/dataset/liver_cell_lines_500_genes.csv +0 -0
  29. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/dataset/melanoma_cell_lines_500_genes.csv +0 -0
  30. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/dataset/neuroblastoma_cell_lines_500_genes.csv +0 -0
  31. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/CORUM.parquet +0 -0
  32. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/GOBP.parquet +0 -0
  33. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/PATHWAY.parquet +0 -0
  34. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/__init__.py +0 -0
  35. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/corum.csv +0 -0
  36. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/gobp.csv +0 -0
  37. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/pathway.csv +0 -0
  38. {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/logging_config.py +0 -0
  39. {pythonflex-0.3.3 → pythonflex-0.4}/uv.lock +0 -0
@@ -9,6 +9,7 @@ import pandas as pd
9
9
  import matplotlib.pyplot as plt
10
10
  from matplotlib import patches
11
11
  from matplotlib.cm import get_cmap
12
+ from matplotlib.lines import Line2D
12
13
  from matplotlib.ticker import NullFormatter, NullLocator
13
14
 
14
15
  # Completely disable LaTeX and clear all font cache/references
@@ -33,7 +34,7 @@ mpl.rcParams['mathtext.default'] = 'regular'
33
34
  # Force font manager to rebuild with system fonts only
34
35
  try:
35
36
  fm.fontManager.__init__()
36
- except:
37
+ except Exception:
37
38
  pass
38
39
 
39
40
  # Local modules
@@ -279,13 +280,27 @@ def plot_all_runs_pra(pra_list, mean_df=None, line_width=2.0, hide_minor_ticks=T
279
280
  plt.show()
280
281
  plt.close(fig)
281
282
 
282
- def plot_percomplex_scatter(n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD', label_color='black', border_color='black', border_width=1.0, show_text_background=True):
283
- config = dload("config")
284
- plot_config = config["plotting"]
285
- rdict = dload("pra_percomplex")
283
+ def plot_percomplex_scatter(
284
+ n_top=10,
285
+ sig_color='black',
286
+ nonsig_color='none',
287
+ label_color='black',
288
+ border_color='black',
289
+ border_width=1.0,
290
+ nonsig_border_color="#7F7F7F",
291
+ nonsig_border_width=0.5,
292
+ show_text_background=True,
293
+ ):
294
+ config = dload("config")
295
+ plot_config = config["plotting"]
296
+ rdict = dload("pra_percomplex")
297
+ input_colors = dload("input", "colors")
298
+ input_colors = {_sanitize(k): v for k, v in input_colors.items()} if input_colors else {}
286
299
 
287
300
  if len(rdict) < 2:
288
- print("Skipping plot: At least two datasets are required for per-complex scatter plot.")
301
+ log.warning(
302
+ "Skipping plot: at least two datasets are required for per-complex scatter plot."
303
+ )
289
304
  return
290
305
 
291
306
  column_pairs = list(combinations(rdict.keys(), 2))
@@ -299,32 +314,61 @@ def plot_percomplex_scatter(n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD
299
314
  df = pd.concat([df, val[key]], axis=1)
300
315
 
301
316
  for pair in column_pairs:
302
- extreme_indices_0 = df[pair[0]].sort_values(ascending=False).head(n_top).index
303
- extreme_indices_1 = df[pair[1]].sort_values(ascending=False).head(n_top).index
304
- significant_indices = extreme_indices_0.union(extreme_indices_1)
305
-
306
- bg_df = df.drop(index=significant_indices)
307
- sig_df = df.loc[significant_indices]
308
-
309
- # Create square figure
310
- fig, ax = plt.subplots(figsize=(6, 6))
311
-
312
- # Background cloud (filled dots with black borders, not rasterized)
313
- bg_sizes = (bg_df['n_used_genes'] if 'n_used_genes' in bg_df else pd.Series(1, index=bg_df.index)) * 5
314
- ax.scatter(
315
- bg_df[pair[0]], bg_df[pair[1]],
316
- facecolors=nonsig_color, edgecolors=border_color,
317
- s=bg_sizes, linewidth=border_width, alpha=1.0,
318
- zorder=0
319
- )
320
-
321
- # Significant points (filled dots with black borders)
322
- sig_sizes = (sig_df['n_used_genes'] if 'n_used_genes' in sig_df else pd.Series(1, index=sig_df.index)) * 8
323
- ax.scatter(
324
- sig_df[pair[0]], sig_df[pair[1]],
325
- facecolors=sig_color, edgecolors=border_color,
326
- s=sig_sizes, linewidth=border_width, zorder=2
327
- )
317
+ extreme_indices_0 = df[pair[0]].sort_values(ascending=False).head(n_top).index
318
+ extreme_indices_1 = df[pair[1]].sort_values(ascending=False).head(n_top).index
319
+ significant_indices = extreme_indices_0.union(extreme_indices_1)
320
+ significant_in_both = extreme_indices_0.intersection(extreme_indices_1)
321
+ significant_pair0_only = extreme_indices_0.difference(extreme_indices_1)
322
+ significant_pair1_only = extreme_indices_1.difference(extreme_indices_0)
323
+
324
+ bg_df = df.drop(index=significant_indices)
325
+ sig_df = df.loc[significant_indices]
326
+ sig_sizes = (
327
+ sig_df['n_used_genes']
328
+ if 'n_used_genes' in sig_df
329
+ else pd.Series(1, index=sig_df.index)
330
+ ) * 8
331
+
332
+ # Create square figure
333
+ fig, ax = plt.subplots(figsize=(6, 6))
334
+
335
+ # Background cloud: non-significant complexes are open circles.
336
+ bg_sizes = (bg_df['n_used_genes'] if 'n_used_genes' in bg_df else pd.Series(1, index=bg_df.index)) * 5
337
+ ax.scatter(
338
+ bg_df[pair[0]], bg_df[pair[1]],
339
+ facecolors="none", edgecolors=nonsig_border_color,
340
+ s=bg_sizes, linewidth=nonsig_border_width, alpha=0.8,
341
+ zorder=0
342
+ )
343
+
344
+ def scatter_significant(indices, color, zorder=2):
345
+ if len(indices) == 0:
346
+ return
347
+ point_df = df.loc[indices]
348
+ point_sizes = (
349
+ point_df['n_used_genes']
350
+ if 'n_used_genes' in point_df
351
+ else pd.Series(1, index=point_df.index)
352
+ ) * 8
353
+ ax.scatter(
354
+ point_df[pair[0]], point_df[pair[1]],
355
+ facecolors=color, edgecolors=color,
356
+ s=point_sizes, linewidth=border_width, zorder=zorder
357
+ )
358
+
359
+ # Dataset-specific significant complexes use the dataset input color.
360
+ scatter_significant(
361
+ significant_pair0_only,
362
+ input_colors.get(_sanitize(pair[0]), sig_color),
363
+ zorder=2,
364
+ )
365
+ scatter_significant(
366
+ significant_pair1_only,
367
+ input_colors.get(_sanitize(pair[1]), sig_color),
368
+ zorder=2,
369
+ )
370
+ # Complexes significant in both datasets stay black to avoid ambiguous color mixing.
371
+ scatter_significant(significant_in_both, "black", zorder=3)
328
372
 
329
373
  # Improved label positioning with adaptive spacing
330
374
  coords = sorted(
@@ -894,16 +938,28 @@ def position_cluster_labels(cluster, cluster_id, max_y, effective_max_y, label_c
894
938
  clip_on=True, bbox=bbox_props
895
939
  )
896
940
 
897
- def plot_percomplex_scatter_bysize(n_labels=10, n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD',
898
- label_color='black', border_color='black', border_width=1.0,
899
- show_text_background=True):
900
- config = dload("config")
901
- plot_config = config["plotting"]
902
- rdict = dload("pra_percomplex")
903
-
904
- for key, per_complex in rdict.items():
905
- sorted_pc = per_complex.sort_values(by="auc_score", ascending=False, na_position="last")
906
- top_labels, rest = sorted_pc.head(n_labels), sorted_pc.iloc[n_labels:]
941
+ def plot_percomplex_scatter_bysize(
942
+ n_labels=10,
943
+ n_top=10,
944
+ sig_color='black',
945
+ nonsig_color='none',
946
+ label_color='black',
947
+ border_color='black',
948
+ border_width=1.0,
949
+ nonsig_border_color="#7F7F7F",
950
+ nonsig_border_width=0.5,
951
+ show_text_background=True,
952
+ ):
953
+ config = dload("config")
954
+ plot_config = config["plotting"]
955
+ rdict = dload("pra_percomplex")
956
+ input_colors = dload("input", "colors")
957
+ input_colors = {_sanitize(k): v for k, v in input_colors.items()} if input_colors else {}
958
+
959
+ for key, per_complex in rdict.items():
960
+ dataset_color = input_colors.get(_sanitize(key), sig_color)
961
+ sorted_pc = per_complex.sort_values(by="auc_score", ascending=False, na_position="last")
962
+ top_labels, rest = sorted_pc.head(n_labels), sorted_pc.iloc[n_labels:]
907
963
 
908
964
  # Calculate data range for appropriate figure sizing
909
965
  max_genes = sorted_pc.n_used_genes.max()
@@ -914,22 +970,22 @@ def plot_percomplex_scatter_bysize(n_labels=10, n_top=10, sig_color='#B71A2A', n
914
970
  fig_height = min(max(4, aspect_ratio), 8) # Between 4-8 inches
915
971
  fig, ax = plt.subplots(figsize=(6, fig_height))
916
972
 
917
- # Background (REST): filled dots with black borders, not rasterized
918
- ax.scatter(
919
- rest.auc_score, rest.n_used_genes,
920
- facecolors=nonsig_color, edgecolors=border_color,
921
- linewidth=border_width, s=rest.n_used_genes * 5,
922
- alpha=1.0, label="Other Complexes",
923
- zorder=0
924
- )
925
-
926
- # Top N: filled dots with black borders
927
- ax.scatter(
928
- top_labels.auc_score, top_labels.n_used_genes,
929
- facecolors=sig_color, edgecolors=border_color,
930
- linewidth=border_width, s=top_labels.n_used_genes * 8,
931
- label=f"Top {n_labels} AUC Scores", alpha=1.0, zorder=2
932
- )
973
+ # Background: non-significant complexes are open circles.
974
+ ax.scatter(
975
+ rest.auc_score, rest.n_used_genes,
976
+ facecolors="none", edgecolors=nonsig_border_color,
977
+ linewidth=nonsig_border_width, s=rest.n_used_genes * 5,
978
+ alpha=0.8, label="Other Complexes",
979
+ zorder=0
980
+ )
981
+
982
+ # Top N complexes are filled circles drawn in the dataset's input color (sig_color when none is configured).
983
+ ax.scatter(
984
+ top_labels.auc_score, top_labels.n_used_genes,
985
+ facecolors=dataset_color, edgecolors=dataset_color,
986
+ linewidth=border_width, s=top_labels.n_used_genes * 8,
987
+ label=f"Top {n_labels} AUC Scores", alpha=1.0, zorder=2
988
+ )
933
989
 
934
990
  # Enhanced anti-overlap labeling system
935
991
  coords = [(row.auc_score, row.n_used_genes, idx) for idx, row in top_labels.iterrows()]
@@ -1010,6 +1066,7 @@ def plot_complex_contributions(
1010
1066
  tmp = np.tile(x, (mx, 1))
1011
1067
  x = cont_stepwise_mat.values / tmp
1012
1068
  x_df = pd.DataFrame(x, index=cont_stepwise_anno, columns=cont_stepwise_mat.columns)
1069
+
1013
1070
  ind_for_mean = y >= (last_prec_value - min_precision_cutoff)
1014
1071
  if sum(ind_for_mean) == 0:
1015
1072
  log.info("No values above 'min.precision.cutoff'"); return False
@@ -1094,15 +1151,23 @@ def plot_significant_complexes():
1094
1151
  input_colors = {_sanitize(k): v for k, v in input_colors.items()}
1095
1152
 
1096
1153
  thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
1154
+ if not isinstance(pra_percomplex, dict) or not pra_percomplex:
1155
+ log.warning("No per-complex PRA data found. Run pra_percomplex() first.")
1156
+ return pd.DataFrame(index=thresholds)
1157
+
1097
1158
  datasets = list(pra_percomplex.keys())
1098
1159
  num_datasets = len(datasets)
1099
1160
 
1161
+ if num_datasets == 0:
1162
+ return pd.DataFrame(index=thresholds)
1163
+
1100
1164
  df = pd.DataFrame(index=thresholds)
1101
1165
  for key, complex_data in pra_percomplex.items():
1102
1166
  if "corrected_auc_score" in complex_data.columns:
1103
1167
  score_col = "corrected_auc_score"
1104
1168
  else:
1105
1169
  score_col = "auc_score"
1170
+
1106
1171
  df[key] = [complex_data.query(f'{score_col} >= {t}').shape[0] for t in thresholds]
1107
1172
 
1108
1173
  fig, ax = plt.subplots()
@@ -1221,15 +1286,20 @@ def plot_auc_scores():
1221
1286
  return pra_dict
1222
1287
 
1223
1288
 
1224
- def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1289
+ def plot_mpr_complex_auc_scores(variant: str = "unfiltered", save=None, outname=None):
1225
1290
  """Plot AUC scores for the mPR complexes curve (Fig 1F-style).
1226
1291
 
1227
1292
  Requires `mpr_prepare()` to have been run for each dataset.
1228
1293
 
1229
1294
  Parameters
1230
1295
  ----------
1231
- filter_key : str
1232
- One of: "all", "no_mtRibo_ETCI", "no_small_highAUPRC".
1296
+ variant : str
1297
+ One of: "unfiltered", "without_mt_ribo_etci",
1298
+ "without_small_high_auprc".
1299
+ save : bool, optional
1300
+ Whether to save the figure. If None, uses config["plotting"]["save_plot"].
1301
+ outname : str, optional
1302
+ Output filename. If None, auto-generated.
1233
1303
 
1234
1304
  Returns
1235
1305
  -------
@@ -1250,12 +1320,14 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1250
1320
  )
1251
1321
  return pd.Series(dtype=float)
1252
1322
 
1323
+ variant_key = _normalize_mpr_variant(variant)
1324
+
1253
1325
  # Build Series: dataset -> auc
1254
1326
  auc_by_dataset = {}
1255
1327
  for dataset, per_filter in mpr_auc_dict.items():
1256
1328
  if not isinstance(per_filter, dict):
1257
1329
  continue
1258
- val = per_filter.get(filter_key)
1330
+ val = per_filter.get(variant_key)
1259
1331
  if val is None:
1260
1332
  continue
1261
1333
  try:
@@ -1265,7 +1337,8 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1265
1337
 
1266
1338
  if not auc_by_dataset:
1267
1339
  log.warning(
1268
- f"No mPR complexes AUC scores found for filter '{filter_key}'. Available filters: {list(FILTER_STYLES.keys())}"
1340
+ f"No mPR complex AUC scores found for variant '{variant}'. "
1341
+ f"Available variants: {list(PUBLIC_MPR_VARIANTS.keys())}"
1269
1342
  )
1270
1343
  return pd.Series(dtype=float)
1271
1344
 
@@ -1308,11 +1381,16 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1308
1381
  ax.spines["top"].set_visible(False)
1309
1382
  ax.spines["right"].set_visible(False)
1310
1383
 
1311
- if plot_config.get("save_plot", False):
1384
+ should_save = plot_config.get("save_plot", False) if save is None else bool(save)
1385
+ if should_save:
1312
1386
  output_type = plot_config.get("output_type", "pdf")
1313
1387
  output_folder = Path(config["output_folder"])
1314
1388
  output_folder.mkdir(parents=True, exist_ok=True)
1315
- output_path = output_folder / f"mpr_complexes_auc_{filter_key}.{output_type}"
1389
+ if outname is None:
1390
+ outname = f"mpr_complexes_auc_{variant_key}.{output_type}"
1391
+ output_path = Path(outname)
1392
+ if len(output_path.parts) == 1:
1393
+ output_path = output_folder / outname
1316
1394
  plt.savefig(output_path, bbox_inches="tight", format=output_type)
1317
1395
 
1318
1396
  if plot_config.get("show_plot", True):
@@ -1321,6 +1399,13 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1321
1399
  plt.close(fig)
1322
1400
  return s
1323
1401
 
1402
+
1403
def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
    """Legacy-name entry point that forwards to plot_mpr_complex_auc_scores().

    Parameters
    ----------
    filter_key : str
        Old-style filter key. It is translated with
        ``_legacy_filter_to_variant`` (a ``None`` key becomes "unfiltered")
        and passed on as the ``variant`` argument.

    Returns
    -------
    Whatever ``plot_mpr_complex_auc_scores`` returns for that variant.
    """
    public_variant = _legacy_filter_to_variant(filter_key, default="unfiltered")
    return plot_mpr_complex_auc_scores(variant=public_variant)
1408
+
1324
1409
  # -----------------------------------------------------------------------------
1325
1410
  # mPR plots (Fig. 1E and Fig. 1F)
1326
1411
  # -----------------------------------------------------------------------------
@@ -1475,26 +1560,6 @@ def plot_mpr_tp(name, ax=None, save=True, outname=None):
1475
1560
 
1476
1561
  return ax
1477
1562
 
1478
- """
1479
- Multi-dataset mPR plotting functions.
1480
-
1481
- Usage:
1482
- from pythonflex.plotting import plot_mpr_tp_multi, plot_mpr_complexes_multi
1483
-
1484
- # Plot multiple datasets
1485
- plot_mpr_tp_multi(["19Q2", "19Q4", "20Q1"])
1486
- plot_mpr_complexes_multi(["19Q2", "19Q4", "20Q1"])
1487
- """
1488
-
1489
- import numpy as np
1490
- import pandas as pd
1491
- import matplotlib.pyplot as plt
1492
- from matplotlib.lines import Line2D
1493
- from pathlib import Path
1494
-
1495
- from .utils import dload
1496
- from .logging_config import log
1497
-
1498
1563
  # Default color palette (colorblind-friendly)
1499
1564
  DEFAULT_COLORS = [
1500
1565
  "#4E79A7", # blue
@@ -1509,40 +1574,101 @@ DEFAULT_COLORS = [
1509
1574
  "#BAB0AC", # gray
1510
1575
  ]
1511
1576
 
1512
- # Filter line styles
1513
- FILTER_STYLES = {
1577
+ # Public mPR variant names map to the internal keys stored by mpr_prepare().
1578
+ PUBLIC_MPR_VARIANTS = {
1579
+ "unfiltered": "all",
1580
+ "without_mt_ribo_etci": "no_mtRibo_ETCI",
1581
+ "without_small_high_auprc": "no_small_highAUPRC",
1582
+ }
1583
+ INTERNAL_MPR_VARIANTS = {v: k for k, v in PUBLIC_MPR_VARIANTS.items()}
1584
+
1585
+ # mPR variant line styles keyed by internal storage names.
1586
+ MPR_VARIANT_STYLES = {
1514
1587
  "all": {"linestyle": "-", "label": "all data"},
1515
1588
  "no_mtRibo_ETCI": {"linestyle": "--", "label": "no mtRibo, ETC I"},
1516
1589
  "no_small_highAUPRC": {"linestyle": "dotted", "label": "no small, high AUPRC"},
1517
1590
  }
1518
1591
 
1592
+ # Compatibility alias for users who imported this internal constant.
1593
+ FILTER_STYLES = MPR_VARIANT_STYLES
1594
+
1595
+
1596
def _normalize_mpr_variant(variant):
    """Translate one mPR variant name into its internal storage key.

    Public names (keys of ``PUBLIC_MPR_VARIANTS``) are mapped to the
    corresponding internal key. Names that are already internal keys
    (keys of ``MPR_VARIANT_STYLES``) are accepted as well.

    Raises
    ------
    ValueError
        If ``variant`` is neither a public name nor an internal key.
    """
    # Public name: direct lookup.
    try:
        return PUBLIC_MPR_VARIANTS[variant]
    except KeyError:
        pass

    if variant not in MPR_VARIANT_STYLES:
        raise ValueError(
            "Unknown mPR variant "
            f"{variant!r}. Use one of {list(PUBLIC_MPR_VARIANTS.keys())}."
        )

    # In this single-name helper "all" is treated as the unfiltered variant,
    # so it is resolved through the public "unfiltered" entry.
    if variant == "all":
        return PUBLIC_MPR_VARIANTS["unfiltered"]
    return variant
1519
1608
 
1520
- def _normalize_show_filters(show_filters):
1521
- """Normalize show_filters to an ordered tuple of filter keys.
1522
1609
 
1523
- Common footgun: passing a single string (e.g. "no_mtRibo_ETCI") is iterable,
1524
- which would otherwise be treated as a sequence of characters.
1525
- """
1610
def _normalize_mpr_variants(variants):
    """Convert a variant selection into a tuple of internal storage keys.

    ``variants`` may be ``None``, a single string, or an iterable of names.
    ``None`` is treated like the string "all". At this level "all" expands
    to every value of ``PUBLIC_MPR_VARIANTS``; any other name is resolved
    with ``_normalize_mpr_variant`` (which raises ``ValueError`` for unknown
    names). A non-iterable, non-string value is treated as a single name.

    Returns
    -------
    tuple
        Internal keys in first-seen order, without duplicates.
    """
    if variants is None:
        requested = ("all",)
    elif isinstance(variants, str):
        # A bare string must not be iterated character by character.
        requested = (variants,)
    else:
        try:
            requested = tuple(variants)
        except TypeError:
            requested = (variants,)

    resolved = []
    for name in requested:
        if name != "all":
            resolved.append(_normalize_mpr_variant(name))
            continue
        resolved += list(PUBLIC_MPR_VARIANTS.values())

    # dict keys keep insertion order, which de-duplicates while preserving
    # the order the caller asked for.
    return tuple(dict.fromkeys(resolved))
1631
+
1632
+
1633
+ def _legacy_filter_to_variant(filter_key, default=None):
1634
+ """Map old filter-key names to public variant names."""
1635
+ if filter_key is None:
1636
+ return default if default is not None else "all"
1637
+ mapping = {
1638
+ "all": "unfiltered",
1639
+ "no_mtRibo_ETCI": "without_mt_ribo_etci",
1640
+ "no_small_highAUPRC": "without_small_high_auprc",
1641
+ }
1642
+ return mapping.get(filter_key, filter_key)
1643
+
1644
+
1645
+ def _legacy_filters_to_variants(show_filters):
1646
+ """Map old show_filters values to public variant names."""
1526
1647
  if show_filters is None:
1527
- return tuple(FILTER_STYLES.keys())
1648
+ return "all"
1528
1649
  if isinstance(show_filters, str):
1529
- return (show_filters,)
1650
+ return _legacy_filter_to_variant(show_filters)
1530
1651
  try:
1531
- return tuple(show_filters)
1652
+ return tuple(_legacy_filter_to_variant(item) for item in show_filters)
1532
1653
  except TypeError:
1533
- return (show_filters,)
1654
+ return (_legacy_filter_to_variant(show_filters),)
1534
1655
 
1535
- def plot_mpr_tp_multi(
1656
+
1657
def _normalize_show_filters(show_filters):
    """Normalize old-style ``show_filters`` values to internal variant keys.

    Kept for backward compatibility: the legacy keys are first translated
    to public variant names and then normalized like any other variant
    selection.
    """
    public_variants = _legacy_filters_to_variants(show_filters)
    return _normalize_mpr_variants(public_variants)
1660
+
1661
+ def plot_mpr_true_positive_curve(
1536
1662
  dataset_names=None,
1537
1663
  colors=None,
1538
1664
  ax=None,
1539
1665
  save=True,
1540
1666
  outname=None,
1541
1667
  linewidth=1.8,
1542
- show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
1668
+ variants="unfiltered",
1543
1669
  ):
1544
1670
  """
1545
- Plot TP vs precision curves for multiple datasets.
1671
+ Plot mPR true-positive vs precision curves for multiple datasets.
1546
1672
 
1547
1673
  Can auto-detect datasets or use provided dataset names.
1548
1674
  Each dataset gets one color, each filter type gets one line style.
@@ -1562,8 +1688,9 @@ def plot_mpr_tp_multi(
1562
1688
  Output filename. If None, auto-generated.
1563
1689
  linewidth : float
1564
1690
  Line width for all curves
1565
- show_filters : tuple of str
1566
- Which filters to show. Default is all three.
1691
+ variants : str or iterable of str
1692
+ Which mPR variants to show. Use "unfiltered",
1693
+ "without_mt_ribo_etci", "without_small_high_auprc", or "all".
1567
1694
 
1568
1695
  Returns
1569
1696
  -------
@@ -1573,7 +1700,7 @@ def plot_mpr_tp_multi(
1573
1700
  plot_config = config["plotting"]
1574
1701
  input_colors = dload("input", "colors")
1575
1702
 
1576
- show_filters = _normalize_show_filters(show_filters)
1703
+ variant_keys = _normalize_mpr_variants(variants)
1577
1704
 
1578
1705
  # Sanitize color keys
1579
1706
  if input_colors:
@@ -1641,13 +1768,13 @@ def plot_mpr_tp_multi(
1641
1768
  tp_curves = mpr["tp_curves"]
1642
1769
  color = colors[i % len(colors)]
1643
1770
 
1644
- for filter_key in show_filters:
1645
- if filter_key not in tp_curves:
1771
+ for variant_key in variant_keys:
1772
+ if variant_key not in tp_curves:
1646
1773
  continue
1647
1774
 
1648
- data = tp_curves[filter_key]
1775
+ data = tp_curves[variant_key]
1649
1776
  if not isinstance(data, dict) or "tp" not in data or "precision" not in data:
1650
- log.warning(f"Invalid tp_curves data structure for '{name}' filter '{filter_key}', skipping.")
1777
+ log.warning(f"Invalid tp_curves data structure for '{name}' variant '{variant_key}', skipping.")
1651
1778
  continue
1652
1779
 
1653
1780
  tp = np.asarray(data["tp"], dtype=float)
@@ -1661,7 +1788,7 @@ def plot_mpr_tp_multi(
1661
1788
  prec_plot = prec[mask]
1662
1789
  xmax = max(xmax, float(tp_plot.max()))
1663
1790
 
1664
- style = FILTER_STYLES.get(filter_key, {})
1791
+ style = MPR_VARIANT_STYLES.get(variant_key, {})
1665
1792
  ax.plot(
1666
1793
  tp_plot,
1667
1794
  prec_plot,
@@ -1694,7 +1821,7 @@ def plot_mpr_tp_multi(
1694
1821
  ax.spines['right'].set_visible(False)
1695
1822
 
1696
1823
  # Create vertically stacked legends
1697
- _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth)
1824
+ _add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth)
1698
1825
 
1699
1826
  # Save
1700
1827
  if save:
@@ -1713,7 +1840,8 @@ def plot_mpr_tp_multi(
1713
1840
 
1714
1841
  return ax
1715
1842
 
1716
- def plot_mpr_complexes_multi(
1843
+
1844
+ def plot_mpr_tp_multi(
1717
1845
  dataset_names=None,
1718
1846
  colors=None,
1719
1847
  ax=None,
@@ -1721,11 +1849,31 @@ def plot_mpr_complexes_multi(
1721
1849
  outname=None,
1722
1850
  linewidth=1.8,
1723
1851
  show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
1852
+ ):
1853
+ """Backward-compatible wrapper for plot_mpr_true_positive_curve()."""
1854
+ return plot_mpr_true_positive_curve(
1855
+ dataset_names=dataset_names,
1856
+ colors=colors,
1857
+ ax=ax,
1858
+ save=save,
1859
+ outname=outname,
1860
+ linewidth=linewidth,
1861
+ variants=_legacy_filters_to_variants(show_filters),
1862
+ )
1863
+
1864
+ def plot_mpr_complex_coverage_curve(
1865
+ dataset_names=None,
1866
+ colors=None,
1867
+ ax=None,
1868
+ save=True,
1869
+ outname=None,
1870
+ linewidth=1.8,
1871
+ variants="unfiltered",
1724
1872
  show_markers="auto",
1725
1873
  marker_size=20,
1726
1874
  ):
1727
1875
  """
1728
- Plot module-level PR (#complexes vs precision) for multiple datasets.
1876
+ Plot mPR complex-coverage vs precision curves for multiple datasets.
1729
1877
 
1730
1878
  Can auto-detect datasets or use provided dataset names.
1731
1879
  Each dataset gets one color, each filter type gets one line style.
@@ -1745,8 +1893,9 @@ def plot_mpr_complexes_multi(
1745
1893
  Output filename. If None, auto-generated.
1746
1894
  linewidth : float
1747
1895
  Line width for all curves
1748
- show_filters : tuple of str
1749
- Which filters to show. Default is all three.
1896
+ variants : str or iterable of str
1897
+ Which mPR variants to show. Use "unfiltered",
1898
+ "without_mt_ribo_etci", "without_small_high_auprc", or "all".
1750
1899
  show_markers : bool or "auto"
1751
1900
  If True, draw markers on curves to make short curves visible.
1752
1901
  If "auto" (default), markers are drawn only for curves with <= 10 points.
@@ -1761,7 +1910,7 @@ def plot_mpr_complexes_multi(
1761
1910
  plot_config = config["plotting"]
1762
1911
  input_colors = dload("input", "colors")
1763
1912
 
1764
- show_filters = _normalize_show_filters(show_filters)
1913
+ variant_keys = _normalize_mpr_variants(variants)
1765
1914
 
1766
1915
  # Sanitize color keys
1767
1916
  if input_colors:
@@ -1812,32 +1961,61 @@ def plot_mpr_complexes_multi(
1812
1961
  else:
1813
1962
  fig = ax.figure
1814
1963
 
1815
- # Plot each dataset
1964
+ # First pass: determine max coverage across all datasets/filters for adaptive x-axis
1965
+ max_cov_global = 0
1966
+ _mpr_cache = {}
1816
1967
  for i, name in enumerate(dataset_names):
1817
1968
  mpr = dload("mpr", name)
1969
+ _mpr_cache[name] = mpr
1970
+ if mpr is not None:
1971
+ for variant_key in variant_keys:
1972
+ arr = mpr["coverage_curves"].get(variant_key)
1973
+ if arr is not None:
1974
+ max_cov_global = max(max_cov_global, float(np.asarray(arr).max()))
1975
+
1976
+ # Build adaptive x-axis limits and ticks
1977
+ import math
1978
+ if max_cov_global <= 200:
1979
+ # Original fixed range — keeps CORUM plots identical to before
1980
+ x_max_plot = 200
1981
+ tick_positions = [1, 2, 20, 200]
1982
+ tick_labels = ["0", "2", "20", "200"]
1983
+ else:
1984
+ # Round up to the next power of 10 so the max bar has breathing room
1985
+ x_max_plot = 10 ** math.ceil(math.log10(max_cov_global + 1))
1986
+ tick_positions = [1, 2]
1987
+ v = 10
1988
+ while v <= x_max_plot:
1989
+ tick_positions.append(v)
1990
+ v *= 10
1991
+ tick_labels = ["0"] + [str(t) for t in tick_positions[1:]]
1992
+
1993
+ # Plot each dataset
1994
+ for i, name in enumerate(dataset_names):
1995
+ mpr = _mpr_cache[name]
1818
1996
  if mpr is None:
1819
1997
  log.warning(f"mPR data for '{name}' not found, skipping.")
1820
1998
  continue
1821
-
1999
+
1822
2000
  precision_cutoffs = np.asarray(mpr["precision_cutoffs"], dtype=float)
1823
2001
  coverage = mpr["coverage_curves"]
1824
2002
  color = colors[i % len(colors)]
1825
-
1826
- for filter_key in show_filters:
1827
- if filter_key not in coverage:
2003
+
2004
+ for variant_key in variant_keys:
2005
+ if variant_key not in coverage:
1828
2006
  continue
1829
-
1830
- cov = np.asarray(coverage[filter_key], dtype=float)
1831
-
1832
- # Keep only positive coverage up to 200 complexes
1833
- mask = (cov > 0) & (cov <= 200)
2007
+
2008
+ cov = np.asarray(coverage[variant_key], dtype=float)
2009
+
2010
+ # Keep only positive coverage within the visible x range
2011
+ mask = (cov > 0) & (cov <= x_max_plot)
1834
2012
  if not mask.any():
1835
2013
  continue
1836
-
2014
+
1837
2015
  cov_plot = cov[mask]
1838
2016
  prec_plot = precision_cutoffs[mask]
1839
-
1840
- style = FILTER_STYLES.get(filter_key, {})
2017
+
2018
+ style = MPR_VARIANT_STYLES.get(variant_key, {})
1841
2019
 
1842
2020
  # Decide marker visibility
1843
2021
  if show_markers == "auto":
@@ -1858,17 +2036,15 @@ def plot_mpr_complexes_multi(
1858
2036
  marker=("o" if use_markers else None),
1859
2037
  markersize=(3 if use_markers else None),
1860
2038
  )
1861
-
2039
+
1862
2040
  # Configure axes
1863
2041
  ax.set_xscale("log")
1864
- ax.set_xlim(1, 200)
2042
+ ax.set_xlim(1, x_max_plot)
1865
2043
  ax.set_xlabel("# complexes")
1866
2044
  ax.set_ylabel("Precision")
1867
2045
  ax.set_ylim(0.0, 1.05)
1868
-
1869
- # Custom x-ticks
1870
- tick_positions = [1, 2, 20, 200]
1871
- tick_labels = ["0", "2", "20", "200"]
2046
+
2047
+ # Adaptive x-ticks
1872
2048
  ax.set_xticks(tick_positions)
1873
2049
  ax.set_xticklabels(tick_labels)
1874
2050
 
@@ -1877,7 +2053,7 @@ def plot_mpr_complexes_multi(
1877
2053
  ax.spines['right'].set_visible(False)
1878
2054
 
1879
2055
  # Create vertically stacked legends
1880
- _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth)
2056
+ _add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth)
1881
2057
 
1882
2058
  # Save
1883
2059
  if save:
@@ -1896,11 +2072,71 @@ def plot_mpr_complexes_multi(
1896
2072
 
1897
2073
  return ax
1898
2074
 
1899
- def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
2075
+
2076
def plot_mpr_complexes_multi(
    dataset_names=None,
    colors=None,
    ax=None,
    save=True,
    outname=None,
    linewidth=1.8,
    show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
    show_markers="auto",
    marker_size=20,
):
    """Legacy-name entry point for plot_mpr_complex_coverage_curve().

    All arguments are forwarded unchanged except ``show_filters``, which
    uses the old filter keys and is translated to public variant names
    with ``_legacy_filters_to_variants`` before being passed as
    ``variants``.

    Returns
    -------
    Whatever ``plot_mpr_complex_coverage_curve`` returns.
    """
    public_variants = _legacy_filters_to_variants(show_filters)
    return plot_mpr_complex_coverage_curve(
        dataset_names=dataset_names,
        colors=colors,
        ax=ax,
        save=save,
        outname=outname,
        linewidth=linewidth,
        variants=public_variants,
        show_markers=show_markers,
        marker_size=marker_size,
    )
2099
+
2100
+
2101
def plot_mpr_summary(
    dataset_names=None,
    colors=None,
    variants="unfiltered",
    save=True,
    linewidth=1.8,
    show_markers="auto",
    marker_size=20,
    auc_variant=None,
):
    """Generate the standard mPR summary plots and return complex AUC scores.

    Draws, in order, the true-positive curve, the complex-coverage curve
    and the complex AUC score plot.

    Parameters
    ----------
    dataset_names, colors, linewidth
        Forwarded to both curve plots.
    variants : str or iterable of str
        mPR variants to draw in the two curve plots.
    save : bool
        Forwarded to all three plotting calls.
    show_markers, marker_size
        Forwarded to the complex-coverage plot only.
    auc_variant : str, optional
        Variant used for the AUC score plot. If None, the first variant
        resolved from ``variants`` is used; when ``variants`` resolves to
        nothing (e.g. an empty list) it falls back to "unfiltered".

    Returns
    -------
    The return value of ``plot_mpr_complex_auc_scores``.
    """
    plot_mpr_true_positive_curve(
        dataset_names=dataset_names,
        colors=colors,
        save=save,
        linewidth=linewidth,
        variants=variants,
    )
    plot_mpr_complex_coverage_curve(
        dataset_names=dataset_names,
        colors=colors,
        save=save,
        linewidth=linewidth,
        variants=variants,
        show_markers=show_markers,
        marker_size=marker_size,
    )

    if auc_variant is None:
        variant_keys = _normalize_mpr_variants(variants)
        # An empty selection yields an empty tuple; indexing it would raise
        # IndexError after the curves were already drawn, so fall back to
        # the same default that an unmapped key gets.
        first_key = variant_keys[0] if variant_keys else None
        auc_variant = INTERNAL_MPR_VARIANTS.get(first_key, "unfiltered")

    return plot_mpr_complex_auc_scores(variant=auc_variant, save=save)
2134
+
2135
+ def _add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth):
1900
2136
  """
1901
- Add vertically stacked legends: Dataset on top, Filter below.
2137
+ Add vertically stacked legends: Dataset on top, mPR variant below.
1902
2138
  """
1903
- show_filters = _normalize_show_filters(show_filters)
2139
+ variant_keys = _normalize_show_filters(variant_keys)
1904
2140
  # Legend 1: Datasets (colors) - solid lines
1905
2141
  dataset_handles = []
1906
2142
  for i, name in enumerate(dataset_names):
@@ -1908,19 +2144,19 @@ def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
1908
2144
  handle = Line2D([0], [0], color=color, linewidth=linewidth, linestyle="-")
1909
2145
  dataset_handles.append(handle)
1910
2146
 
1911
- # Legend 2: Filters (line styles) - black lines
1912
- filter_handles = []
1913
- filter_labels = []
1914
- for filter_key in show_filters:
1915
- style = FILTER_STYLES.get(filter_key, {})
2147
+ # Legend 2: mPR variants (line styles) - black lines
2148
+ variant_handles = []
2149
+ variant_labels = []
2150
+ for variant_key in variant_keys:
2151
+ style = MPR_VARIANT_STYLES.get(variant_key, {})
1916
2152
  handle = Line2D(
1917
2153
  [0], [0],
1918
2154
  color="black",
1919
2155
  linewidth=linewidth,
1920
2156
  linestyle=style.get("linestyle", "-")
1921
2157
  )
1922
- filter_handles.append(handle)
1923
- filter_labels.append(style.get("label", filter_key))
2158
+ variant_handles.append(handle)
2159
+ variant_labels.append(style.get("label", variant_key))
1924
2160
 
1925
2161
  # Position legends vertically with proper alignment
1926
2162
  # Dataset legend on upper right
@@ -1938,19 +2174,19 @@ def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
1938
2174
 
1939
2175
  # Variant legend below the dataset legend, aligned properly without title
1940
2176
  legend2 = ax.legend(
1941
- filter_handles,
1942
- filter_labels,
2177
+ variant_handles,
2178
+ variant_labels,
1943
2179
  loc="upper left",
1944
2180
  frameon=False,
1945
2181
  fontsize=7,
1946
2182
  bbox_to_anchor=(1.05, 1.0 - len(dataset_names) * 0.06 - 0.1)
1947
2183
  )
1948
2184
 
1949
- def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
2185
+ def _add_dual_legend(ax, dataset_names, colors, variant_keys, linewidth):
1950
2186
  """
1951
- Add two legends: one for datasets (colors), one for filters (line styles).
2187
+ Add two legends: one for datasets (colors), one for mPR variants (line styles).
1952
2188
  """
1953
- show_filters = _normalize_show_filters(show_filters)
2189
+ variant_keys = _normalize_show_filters(variant_keys)
1954
2190
  # Legend 1: Datasets (colors) - solid lines
1955
2191
  dataset_handles = []
1956
2192
  for i, name in enumerate(dataset_names):
@@ -1958,19 +2194,19 @@ def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
1958
2194
  handle = Line2D([0], [0], color=color, linewidth=linewidth, linestyle="-")
1959
2195
  dataset_handles.append(handle)
1960
2196
 
1961
- # Legend 2: Filters (line styles) - black lines
1962
- filter_handles = []
1963
- filter_labels = []
1964
- for filter_key in show_filters:
1965
- style = FILTER_STYLES.get(filter_key, {})
2197
+ # Legend 2: mPR variants (line styles) - black lines
2198
+ variant_handles = []
2199
+ variant_labels = []
2200
+ for variant_key in variant_keys:
2201
+ style = MPR_VARIANT_STYLES.get(variant_key, {})
1966
2202
  handle = Line2D(
1967
2203
  [0], [0],
1968
2204
  color="black",
1969
2205
  linewidth=linewidth,
1970
2206
  linestyle=style.get("linestyle", "-")
1971
2207
  )
1972
- filter_handles.append(handle)
1973
- filter_labels.append(style.get("label", filter_key))
2208
+ variant_handles.append(handle)
2209
+ variant_labels.append(style.get("label", variant_key))
1974
2210
 
1975
2211
  # Position legends
1976
2212
  # Dataset legend on upper right
@@ -1987,19 +2223,12 @@ def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
1987
2223
 
1988
2224
  # Variant legend on lower left or right depending on plot type
1989
2225
  legend2 = ax.legend(
1990
- filter_handles,
1991
- filter_labels,
2226
+ variant_handles,
2227
+ variant_labels,
1992
2228
  loc="lower left",
1993
2229
  frameon=False,
1994
- title="Filter",
2230
+ title="Variant",
1995
2231
  fontsize=7,
1996
2232
  title_fontsize=8,
1997
2233
  )
1998
2234
 
1999
- # ============================================================================
2000
- # Single dataset functions are now obsolete
2001
- # ============================================================================
2002
-
2003
- # Note: The original single dataset functions plot_mpr_tp() and plot_mpr_complexes()
2004
- # have been replaced by the multi functions that now auto-detect available datasets.
2005
- # Use plot_mpr_tp_multi() and plot_mpr_complexes_multi() instead.