pythonflex 0.3.4__py3-none-any.whl → 0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pythonflex/plotting.py CHANGED
@@ -9,6 +9,7 @@ import pandas as pd
9
9
  import matplotlib.pyplot as plt
10
10
  from matplotlib import patches
11
11
  from matplotlib.cm import get_cmap
12
+ from matplotlib.lines import Line2D
12
13
  from matplotlib.ticker import NullFormatter, NullLocator
13
14
 
14
15
  # Completely disable LaTeX and clear all font cache/references
@@ -26,14 +27,17 @@ mpl.rcParams['font.cursive'] = ['Apple Chancery', 'Textile', 'Zapf Chancery', 'S
26
27
  mpl.rcParams['font.fantasy'] = ['Comic Sans MS', 'Chicago', 'Charcoal', 'Impact', 'Western', 'Humor Sans', 'fantasy']
27
28
  mpl.rcParams['font.monospace'] = ['DejaVu Sans Mono', 'Bitstream Vera Sans Mono', 'Computer Modern Typewriter', 'Andale Mono', 'Nimbus Mono L', 'Courier New', 'Courier', 'Fixed', 'Terminal', 'monospace']
28
29
 
29
- # Remove any LaTeX-specific math font settings
30
- mpl.rcParams['mathtext.fontset'] = 'dejavusans'
31
- mpl.rcParams['mathtext.default'] = 'regular'
30
+ # Remove any LaTeX-specific math font settings
31
+ mpl.rcParams['mathtext.fontset'] = 'dejavusans'
32
+ mpl.rcParams['mathtext.default'] = 'regular'
33
+ mpl.rcParams['pdf.fonttype'] = 42
34
+ mpl.rcParams['ps.fonttype'] = 42
35
+ mpl.rcParams['svg.fonttype'] = 'none'
32
36
 
33
37
  # Force font manager to rebuild with system fonts only
34
38
  try:
35
39
  fm.fontManager.__init__()
36
- except:
40
+ except Exception:
37
41
  pass
38
42
 
39
43
  # Local modules
@@ -275,17 +279,233 @@ def plot_all_runs_pra(pra_list, mean_df=None, line_width=2.0, hide_minor_ticks=T
275
279
  output_path = Path(config["output_folder"]) / f"aggregated_all_runs_precision_recall_curve.{output_type}"
276
280
  fig.savefig(output_path, bbox_inches="tight", format=output_type)
277
281
 
278
- if plot_config.get("show_plot", True):
279
- plt.show()
280
- plt.close(fig)
281
-
282
- def plot_percomplex_scatter(n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD', label_color='black', border_color='black', border_width=1.0, show_text_background=True):
283
- config = dload("config")
284
- plot_config = config["plotting"]
285
- rdict = dload("pra_percomplex")
282
+ if plot_config.get("show_plot", True):
283
+ plt.show()
284
+ plt.close(fig)
285
+
286
+ def _short_scatter_label(label, max_chars=10):
287
+ label = str(label)
288
+ return label[:max_chars] + "." if len(label) > max_chars else label
289
+
290
+ def _bbox_overlap_area(box_a, box_b):
291
+ x_overlap = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
292
+ y_overlap = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
293
+ return x_overlap * y_overlap
294
+
295
+ def _line_intersects_bbox(start_px, end_px, box):
296
+ samples = np.linspace(0.05, 0.95, 12)
297
+ xs = start_px[0] + (end_px[0] - start_px[0]) * samples
298
+ ys = start_px[1] + (end_px[1] - start_px[1]) * samples
299
+ return np.any(
300
+ (xs >= box[0]) & (xs <= box[2]) &
301
+ (ys >= box[1]) & (ys <= box[3])
302
+ )
303
+
304
+ def _label_alignment(dx, dy):
305
+ if abs(dx) < 1e-9:
306
+ ha = "center"
307
+ else:
308
+ ha = "left" if dx > 0 else "right"
309
+
310
+ if abs(dy) < 1e-9:
311
+ va = "center"
312
+ else:
313
+ va = "bottom" if dy > 0 else "top"
314
+
315
+ return ha, va
316
+
317
+ def _label_display_bbox(ax, anchor_xy, width_px, height_px, ha, va):
318
+ anchor_x, anchor_y = ax.transData.transform(anchor_xy)
319
+
320
+ if ha == "left":
321
+ x0, x1 = anchor_x, anchor_x + width_px
322
+ elif ha == "right":
323
+ x0, x1 = anchor_x - width_px, anchor_x
324
+ else:
325
+ x0, x1 = anchor_x - width_px / 2.0, anchor_x + width_px / 2.0
326
+
327
+ if va == "bottom":
328
+ y0, y1 = anchor_y, anchor_y + height_px
329
+ elif va == "top":
330
+ y0, y1 = anchor_y - height_px, anchor_y
331
+ else:
332
+ y0, y1 = anchor_y - height_px / 2.0, anchor_y + height_px / 2.0
333
+
334
+ return (x0, y0, x1, y1)
335
+
336
+ def _measure_label_size_pixels(ax, label, fontsize, bbox_props, renderer):
337
+ probe = ax.text(
338
+ 0.5, 0.5, label,
339
+ fontsize=fontsize,
340
+ ha="left",
341
+ va="bottom",
342
+ linespacing=1,
343
+ alpha=0.0,
344
+ bbox=bbox_props,
345
+ )
346
+ bbox = probe.get_window_extent(renderer=renderer).expanded(1.12, 1.35)
347
+ probe.remove()
348
+ return bbox.width, bbox.height
349
+
350
def _place_scatter_labels_radially(ax, label_items, obstacle_points, fontsize=4, bbox_props=None):
    """Choose label anchors by scanning candidate positions around each point.

    For every ``(x, y, label)`` in *label_items* (data coordinates),
    candidate anchors are tried on rings of increasing radius (``distances``,
    in data units) at 15-degree steps. Each candidate's text box is computed
    in display pixels and scored as a weighted sum (lower is better) of:

    * the number of already-placed label boxes it overlaps, and the total
      overlap area;
    * the number of *obstacle_points* inside it, not counting the labelled
      point itself;
    * how many placed label boxes the connector line would cross;
    * the connector length;
    * a penalty that grows as the box approaches the axes frame.

    Candidates whose box would cross the axes frame (minus a 2 px margin)
    are skipped. If no candidate survives, a fixed diagonal offset clamped to
    ``[0.02, 0.98]`` is used instead.

    Items with non-finite coordinates are dropped. Items closest to an edge
    of the unit square are placed first, then higher ``y`` before lower,
    then higher ``x`` before lower. The sort key assumes 0-1 data limits.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes. Its figure is drawn once to obtain a renderer.
    label_items : iterable of (float, float, str)
        Points to label and their label text.
    obstacle_points : array-like
        Data points the labels should avoid, shaped ``(N, 2)``. A single
        ``(x, y)`` pair or an empty sequence is also accepted.
    fontsize : float
        Font size used when measuring label extents.
    bbox_props : dict or None
        Text ``bbox`` settings used when measuring label extents.

    Returns
    -------
    list of dict
        One dict per placed label with keys ``point_x``, ``point_y``,
        ``text_x``, ``text_y``, ``label``, ``ha``, ``va``, ``box`` (display
        pixels) and ``score``.
    """
    fig = ax.figure
    # A draw is needed before a renderer and final axes extents are available.
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()
    axis_box = ax.get_window_extent(renderer=renderer)

    # atleast_2d lets a single (x, y) pair through; without it the row-wise
    # isfinite(...).all(axis=1) filter raises on 1-D input.
    obstacle_points = np.atleast_2d(np.asarray(obstacle_points, dtype=float))
    if obstacle_points.size == 0:
        obstacle_points = np.empty((0, 2), dtype=float)
    else:
        obstacle_points = obstacle_points[np.isfinite(obstacle_points).all(axis=1)]
    obstacle_pixels = (
        ax.transData.transform(obstacle_points)
        if len(obstacle_points)
        else np.empty((0, 2))
    )

    # Search grid: 24 directions x 6 radii (radii in data units).
    angles = np.deg2rad(np.arange(0, 360, 15))
    distances = np.array([0.035, 0.055, 0.08, 0.11, 0.145, 0.18])
    axis_margin_px = 2.0

    valid_items = [
        (x, y, label)
        for x, y, label in label_items
        if np.isfinite(x) and np.isfinite(y)
    ]
    # Labels nearest an edge of the unit square go first, presumably because
    # they have the fewest usable directions. Ties: top-to-bottom, then
    # right-to-left.
    valid_items = sorted(
        valid_items,
        key=lambda item: (
            min(item[0], 1.0 - item[0], item[1], 1.0 - item[1]),
            -item[1],
            -item[0],
        ),
    )

    placed = []
    placed_boxes = []

    for point_x, point_y, label in valid_items:
        label_width, label_height = _measure_label_size_pixels(
            ax, label, fontsize, bbox_props, renderer
        )
        point_px = ax.transData.transform((point_x, point_y))
        best_candidate = None

        for distance in distances:
            for angle in angles:
                dx = float(np.cos(angle) * distance)
                dy = float(np.sin(angle) * distance)
                text_x = point_x + dx
                text_y = point_y + dy
                ha, va = _label_alignment(dx, dy)
                box = _label_display_bbox(
                    ax, (text_x, text_y), label_width, label_height, ha, va
                )

                # Hard constraint: the box must stay inside the axes frame.
                outside_axes = (
                    box[0] < axis_box.x0 + axis_margin_px or
                    box[1] < axis_box.y0 + axis_margin_px or
                    box[2] > axis_box.x1 - axis_margin_px or
                    box[3] > axis_box.y1 - axis_margin_px
                )
                if outside_axes:
                    continue

                # Collisions with labels placed so far.
                overlaps = [_bbox_overlap_area(box, placed_box) for placed_box in placed_boxes]
                overlap_hits = sum(area > 0 for area in overlaps)
                overlap_area = sum(overlaps)

                # Data points covered by the box, excluding the labelled point.
                if len(obstacle_pixels):
                    dot_hits = np.count_nonzero(
                        (obstacle_pixels[:, 0] >= box[0]) &
                        (obstacle_pixels[:, 0] <= box[2]) &
                        (obstacle_pixels[:, 1] >= box[1]) &
                        (obstacle_pixels[:, 1] <= box[3])
                    )
                    if box[0] <= point_px[0] <= box[2] and box[1] <= point_px[1] <= box[3]:
                        dot_hits = max(0, dot_hits - 1)
                else:
                    dot_hits = 0

                # Connector line from the point to the text anchor crossing placed labels.
                text_px = ax.transData.transform((text_x, text_y))
                connector_hits = sum(
                    _line_intersects_bbox(point_px, text_px, placed_box)
                    for placed_box in placed_boxes
                )
                edge_gap = min(
                    box[0] - axis_box.x0,
                    box[1] - axis_box.y0,
                    axis_box.x1 - box[2],
                    axis_box.y1 - box[3],
                )

                score = (
                    overlap_hits * 100000.0 +
                    overlap_area * 0.5 +
                    dot_hits * 1500.0 +
                    connector_hits * 1200.0 +
                    np.hypot(text_px[0] - point_px[0], text_px[1] - point_px[1]) * 0.05 +
                    10.0 / max(edge_gap, 1.0)
                )

                if best_candidate is None or score < best_candidate["score"]:
                    best_candidate = {
                        "point_x": point_x,
                        "point_y": point_y,
                        "text_x": text_x,
                        "text_y": text_y,
                        "label": label,
                        "ha": ha,
                        "va": va,
                        "box": box,
                        "score": score,
                    }

        if best_candidate is None:
            # Fallback for very long labels or tight axes: offset diagonally
            # toward the centre and clamp into [0.02, 0.98].
            fallback_dx = -0.055 if point_x > 0.5 else 0.055
            fallback_dy = -0.055 if point_y > 0.5 else 0.055
            ha, va = _label_alignment(fallback_dx, fallback_dy)
            text_x = max(0.02, min(0.98, point_x + fallback_dx))
            text_y = max(0.02, min(0.98, point_y + fallback_dy))
            best_candidate = {
                "point_x": point_x,
                "point_y": point_y,
                "text_x": text_x,
                "text_y": text_y,
                "label": label,
                "ha": ha,
                "va": va,
                "box": _label_display_bbox(
                    ax, (text_x, text_y), label_width, label_height, ha, va
                ),
                "score": float("inf"),
            }

        placed_boxes.append(best_candidate["box"])
        placed.append(best_candidate)

    return placed
487
+
488
+ def plot_percomplex_scatter(
489
+ n_top=10,
490
+ sig_color='black',
491
+ nonsig_color='none',
492
+ label_color='black',
493
+ border_color='black',
494
+ border_width=1.0,
495
+ nonsig_border_color="#7F7F7F",
496
+ nonsig_border_width=0.5,
497
+ show_text_background=True,
498
+ ):
499
+ config = dload("config")
500
+ plot_config = config["plotting"]
501
+ rdict = dload("pra_percomplex")
502
+ input_colors = dload("input", "colors")
503
+ input_colors = {_sanitize(k): v for k, v in input_colors.items()} if input_colors else {}
286
504
 
287
505
  if len(rdict) < 2:
288
- print("Skipping plot: At least two datasets are required for per-complex scatter plot.")
506
+ log.warning(
507
+ "Skipping plot: at least two datasets are required for per-complex scatter plot."
508
+ )
289
509
  return
290
510
 
291
511
  column_pairs = list(combinations(rdict.keys(), 2))
@@ -299,93 +519,61 @@ def plot_percomplex_scatter(n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD
299
519
  df = pd.concat([df, val[key]], axis=1)
300
520
 
301
521
  for pair in column_pairs:
302
- extreme_indices_0 = df[pair[0]].sort_values(ascending=False).head(n_top).index
303
- extreme_indices_1 = df[pair[1]].sort_values(ascending=False).head(n_top).index
304
- significant_indices = extreme_indices_0.union(extreme_indices_1)
305
-
306
- bg_df = df.drop(index=significant_indices)
307
- sig_df = df.loc[significant_indices]
308
-
309
- # Create square figure
310
- fig, ax = plt.subplots(figsize=(6, 6))
311
-
312
- # Background cloud (filled dots with black borders, not rasterized)
313
- bg_sizes = (bg_df['n_used_genes'] if 'n_used_genes' in bg_df else pd.Series(1, index=bg_df.index)) * 5
314
- ax.scatter(
315
- bg_df[pair[0]], bg_df[pair[1]],
316
- facecolors=nonsig_color, edgecolors=border_color,
317
- s=bg_sizes, linewidth=border_width, alpha=1.0,
318
- zorder=0
319
- )
320
-
321
- # Significant points (filled dots with black borders)
322
- sig_sizes = (sig_df['n_used_genes'] if 'n_used_genes' in sig_df else pd.Series(1, index=sig_df.index)) * 8
323
- ax.scatter(
324
- sig_df[pair[0]], sig_df[pair[1]],
325
- facecolors=sig_color, edgecolors=border_color,
326
- s=sig_sizes, linewidth=border_width, zorder=2
327
- )
328
-
329
- # Improved label positioning with adaptive spacing
330
- coords = sorted(
331
- [(sig_df.loc[idx, pair[0]], sig_df.loc[idx, pair[1]], idx) for idx in sig_df.index],
332
- key=lambda c: (-c[1], -c[0])
333
- )
334
-
335
- # Calculate proper parameters for normalized coordinate system (0-1 range)
336
- max_y = 1.0 # Normalized plots use 0-1 range
337
- scale_factor = 1.0 # Standard scaling for normalized plots
338
- min_distance = 0.08 # Increased spacing for 0-1 range to avoid overlap
339
-
340
- adjusted_coords = adjust_text_positions_improved(
341
- coords, sig_sizes,
342
- min_distance=min_distance,
343
- max_y=max_y,
344
- scale_factor=scale_factor,
345
- y_threshold=0.8 # Points above this will have labels below
346
- )
347
-
348
- for x, adj_y, idx, direction in adjusted_coords:
349
- y = df.loc[idx, pair[1]]
350
-
351
- # Calculate connector line extension, but constrain within plot bounds
352
- line_extension_factor = 1.5 # Reduced from 2.5 to keep labels in bounds
353
- extended_adj_y = y + (adj_y - y) * line_extension_factor
354
-
355
- # Clip to ensure connector stays within 0-1 range
356
- extended_adj_y = max(0.02, min(extended_adj_y, 0.98))
357
-
358
- # Draw connector line
359
- ax.plot([x, x], [y, extended_adj_y],
360
- color=label_color, linewidth=0.6, alpha=0.15, zorder=3)
361
-
362
- # Position text at the end of extended line with small offset
363
- text_y_offset = 0.01 if direction == "up" else -0.01
364
- final_text_y = extended_adj_y + text_y_offset
365
-
366
- # Final clip to ensure text stays within 0-1 range
367
- final_text_y = max(0.02, min(final_text_y, 0.98))
368
-
369
- # Prepare text bbox settings (can be turned on/off)
370
- bbox_props = dict(facecolor="white", alpha=0.7, edgecolor="none", pad=1) if show_text_background else None
371
-
372
- ax.text(
373
- x, final_text_y,
374
- df.loc[idx, 'Name'][:10] + '.' if len(df.loc[idx, 'Name']) > 10 else df.loc[idx, 'Name'],
375
- fontsize=4,
376
- ha='left',
377
- va='bottom' if direction == "up" else 'top',
378
- color=label_color,
379
- linespacing=1,
380
- zorder=4,
381
- clip_on=True, # Enable clipping to axes bounds
382
- bbox=bbox_props
383
- )
384
-
385
- # Diagonal & axes cosmetics
386
- ax.plot([0, 1], [0, 1], linestyle='-', color='lightgray', alpha=0.4, linewidth=0.5, zorder=1)
387
-
388
- # Force square aspect ratio and exact 0-1 range
522
+ extreme_indices_0 = df[pair[0]].sort_values(ascending=False).head(n_top).index
523
+ extreme_indices_1 = df[pair[1]].sort_values(ascending=False).head(n_top).index
524
+ significant_indices = extreme_indices_0.union(extreme_indices_1)
525
+ significant_in_both = extreme_indices_0.intersection(extreme_indices_1)
526
+ significant_pair0_only = extreme_indices_0.difference(extreme_indices_1)
527
+ significant_pair1_only = extreme_indices_1.difference(extreme_indices_0)
528
+
529
+ bg_df = df.drop(index=significant_indices)
530
+ sig_df = df.loc[significant_indices]
531
+
532
+ # Create square figure
533
+ fig, ax = plt.subplots(figsize=(6, 6))
534
+
535
+ # Background cloud: non-significant complexes are open circles.
536
+ bg_sizes = (bg_df['n_used_genes'] if 'n_used_genes' in bg_df else pd.Series(1, index=bg_df.index)) * 5
537
+ ax.scatter(
538
+ bg_df[pair[0]], bg_df[pair[1]],
539
+ facecolors="none", edgecolors=nonsig_border_color,
540
+ s=bg_sizes, linewidth=nonsig_border_width, alpha=0.8,
541
+ zorder=0
542
+ )
543
+
544
+ def scatter_significant(indices, color, zorder=2):
545
+ if len(indices) == 0:
546
+ return
547
+ point_df = df.loc[indices]
548
+ point_sizes = (
549
+ point_df['n_used_genes']
550
+ if 'n_used_genes' in point_df
551
+ else pd.Series(1, index=point_df.index)
552
+ ) * 8
553
+ ax.scatter(
554
+ point_df[pair[0]], point_df[pair[1]],
555
+ facecolors=color, edgecolors=color,
556
+ s=point_sizes, linewidth=border_width, zorder=zorder
557
+ )
558
+
559
+ # Dataset-specific significant complexes use the dataset input color.
560
+ scatter_significant(
561
+ significant_pair0_only,
562
+ input_colors.get(_sanitize(pair[0]), sig_color),
563
+ zorder=2,
564
+ )
565
+ scatter_significant(
566
+ significant_pair1_only,
567
+ input_colors.get(_sanitize(pair[1]), sig_color),
568
+ zorder=2,
569
+ )
570
+ # Complexes significant in both datasets stay black to avoid ambiguous color mixing.
571
+ scatter_significant(significant_in_both, "black", zorder=3)
572
+
573
+ # Diagonal & axes cosmetics
574
+ ax.plot([0, 1], [0, 1], linestyle='-', color='lightgray', alpha=0.4, linewidth=0.5, zorder=1)
575
+
576
+ # Force square aspect ratio and exact 0-1 range
389
577
  ax.set_xlim(0, 1)
390
578
  ax.set_ylim(0, 1)
391
579
  ax.set_aspect('equal', adjustable='box')
@@ -400,15 +588,55 @@ def plot_percomplex_scatter(n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD
400
588
  #ax.set_title(f"{pair[0]} vs {pair[1]} - Comparison of complex performance")
401
589
 
402
590
  # Nature style: no grid, open top/right spines
403
- ax.grid(False)
404
- ax.spines['top'].set_visible(False)
405
- ax.spines['right'].set_visible(False)
406
-
407
- plt.tight_layout()
408
-
409
- if plot_config["save_plot"]:
410
- output_type = plot_config["output_type"]
411
- output_path = Path(config["output_folder"]) / f"percomplex_scatter_{pair[0]}_vs_{pair[1]}.{output_type}"
591
+ ax.grid(False)
592
+ ax.spines['top'].set_visible(False)
593
+ ax.spines['right'].set_visible(False)
594
+
595
+ plt.tight_layout()
596
+
597
+ # Radial label positioning searches all directions around each significant point.
598
+ label_items = []
599
+ for idx in sig_df.index:
600
+ label_items.append((
601
+ float(sig_df.loc[idx, pair[0]]),
602
+ float(sig_df.loc[idx, pair[1]]),
603
+ _short_scatter_label(df.loc[idx, 'Name']),
604
+ ))
605
+ all_points = df[[pair[0], pair[1]]].dropna().to_numpy(dtype=float)
606
+ label_positions = _place_scatter_labels_radially(
607
+ ax,
608
+ label_items,
609
+ all_points,
610
+ fontsize=4,
611
+ bbox_props=None,
612
+ )
613
+
614
+ for label_pos in label_positions:
615
+ ax.plot(
616
+ [label_pos["point_x"], label_pos["text_x"]],
617
+ [label_pos["point_y"], label_pos["text_y"]],
618
+ color=label_color,
619
+ linewidth=0.6,
620
+ alpha=0.15,
621
+ zorder=3,
622
+ )
623
+ ax.text(
624
+ label_pos["text_x"],
625
+ label_pos["text_y"],
626
+ label_pos["label"],
627
+ fontsize=4,
628
+ ha=label_pos["ha"],
629
+ va=label_pos["va"],
630
+ color=label_color,
631
+ linespacing=1,
632
+ zorder=4,
633
+ clip_on=True,
634
+ bbox=None,
635
+ )
636
+
637
+ if plot_config["save_plot"]:
638
+ output_type = plot_config["output_type"]
639
+ output_path = Path(config["output_folder"]) / f"percomplex_scatter_{pair[0]}_vs_{pair[1]}.{output_type}"
412
640
  fig.savefig(output_path, bbox_inches="tight", format=output_type)
413
641
 
414
642
  if plot_config.get("show_plot", True):
@@ -894,16 +1122,28 @@ def position_cluster_labels(cluster, cluster_id, max_y, effective_max_y, label_c
894
1122
  clip_on=True, bbox=bbox_props
895
1123
  )
896
1124
 
897
- def plot_percomplex_scatter_bysize(n_labels=10, n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD',
898
- label_color='black', border_color='black', border_width=1.0,
899
- show_text_background=True):
900
- config = dload("config")
901
- plot_config = config["plotting"]
902
- rdict = dload("pra_percomplex")
903
-
904
- for key, per_complex in rdict.items():
905
- sorted_pc = per_complex.sort_values(by="auc_score", ascending=False, na_position="last")
906
- top_labels, rest = sorted_pc.head(n_labels), sorted_pc.iloc[n_labels:]
1125
+ def plot_percomplex_scatter_bysize(
1126
+ n_labels=10,
1127
+ n_top=10,
1128
+ sig_color='black',
1129
+ nonsig_color='none',
1130
+ label_color='black',
1131
+ border_color='black',
1132
+ border_width=1.0,
1133
+ nonsig_border_color="#7F7F7F",
1134
+ nonsig_border_width=0.5,
1135
+ show_text_background=True,
1136
+ ):
1137
+ config = dload("config")
1138
+ plot_config = config["plotting"]
1139
+ rdict = dload("pra_percomplex")
1140
+ input_colors = dload("input", "colors")
1141
+ input_colors = {_sanitize(k): v for k, v in input_colors.items()} if input_colors else {}
1142
+
1143
+ for key, per_complex in rdict.items():
1144
+ dataset_color = input_colors.get(_sanitize(key), sig_color)
1145
+ sorted_pc = per_complex.sort_values(by="auc_score", ascending=False, na_position="last")
1146
+ top_labels, rest = sorted_pc.head(n_labels), sorted_pc.iloc[n_labels:]
907
1147
 
908
1148
  # Calculate data range for appropriate figure sizing
909
1149
  max_genes = sorted_pc.n_used_genes.max()
@@ -914,22 +1154,22 @@ def plot_percomplex_scatter_bysize(n_labels=10, n_top=10, sig_color='#B71A2A', n
914
1154
  fig_height = min(max(4, aspect_ratio), 8) # Between 4-8 inches
915
1155
  fig, ax = plt.subplots(figsize=(6, fig_height))
916
1156
 
917
- # Background (REST): filled dots with black borders, not rasterized
918
- ax.scatter(
919
- rest.auc_score, rest.n_used_genes,
920
- facecolors=nonsig_color, edgecolors=border_color,
921
- linewidth=border_width, s=rest.n_used_genes * 5,
922
- alpha=1.0, label="Other Complexes",
923
- zorder=0
924
- )
925
-
926
- # Top N: filled dots with black borders
927
- ax.scatter(
928
- top_labels.auc_score, top_labels.n_used_genes,
929
- facecolors=sig_color, edgecolors=border_color,
930
- linewidth=border_width, s=top_labels.n_used_genes * 8,
931
- label=f"Top {n_labels} AUC Scores", alpha=1.0, zorder=2
932
- )
1157
+ # Background: non-significant complexes are open circles.
1158
+ ax.scatter(
1159
+ rest.auc_score, rest.n_used_genes,
1160
+ facecolors="none", edgecolors=nonsig_border_color,
1161
+ linewidth=nonsig_border_width, s=rest.n_used_genes * 5,
1162
+ alpha=0.8, label="Other Complexes",
1163
+ zorder=0
1164
+ )
1165
+
1166
+ # Top N/significant complexes are filled black circles.
1167
+ ax.scatter(
1168
+ top_labels.auc_score, top_labels.n_used_genes,
1169
+ facecolors=dataset_color, edgecolors=dataset_color,
1170
+ linewidth=border_width, s=top_labels.n_used_genes * 8,
1171
+ label=f"Top {n_labels} AUC Scores", alpha=1.0, zorder=2
1172
+ )
933
1173
 
934
1174
  # Enhanced anti-overlap labeling system
935
1175
  coords = [(row.auc_score, row.n_used_genes, idx) for idx, row in top_labels.iterrows()]
@@ -1010,6 +1250,7 @@ def plot_complex_contributions(
1010
1250
  tmp = np.tile(x, (mx, 1))
1011
1251
  x = cont_stepwise_mat.values / tmp
1012
1252
  x_df = pd.DataFrame(x, index=cont_stepwise_anno, columns=cont_stepwise_mat.columns)
1253
+
1013
1254
  ind_for_mean = y >= (last_prec_value - min_precision_cutoff)
1014
1255
  if sum(ind_for_mean) == 0:
1015
1256
  log.info("No values above 'min.precision.cutoff'"); return False
@@ -1094,15 +1335,23 @@ def plot_significant_complexes():
1094
1335
  input_colors = {_sanitize(k): v for k, v in input_colors.items()}
1095
1336
 
1096
1337
  thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
1338
+ if not isinstance(pra_percomplex, dict) or not pra_percomplex:
1339
+ log.warning("No per-complex PRA data found. Run pra_percomplex() first.")
1340
+ return pd.DataFrame(index=thresholds)
1341
+
1097
1342
  datasets = list(pra_percomplex.keys())
1098
1343
  num_datasets = len(datasets)
1099
1344
 
1345
+ if num_datasets == 0:
1346
+ return pd.DataFrame(index=thresholds)
1347
+
1100
1348
  df = pd.DataFrame(index=thresholds)
1101
1349
  for key, complex_data in pra_percomplex.items():
1102
1350
  if "corrected_auc_score" in complex_data.columns:
1103
1351
  score_col = "corrected_auc_score"
1104
1352
  else:
1105
1353
  score_col = "auc_score"
1354
+
1106
1355
  df[key] = [complex_data.query(f'{score_col} >= {t}').shape[0] for t in thresholds]
1107
1356
 
1108
1357
  fig, ax = plt.subplots()
@@ -1221,15 +1470,20 @@ def plot_auc_scores():
1221
1470
  return pra_dict
1222
1471
 
1223
1472
 
1224
- def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1473
+ def plot_mpr_complex_auc_scores(variant: str = "unfiltered", save=None, outname=None):
1225
1474
  """Plot AUC scores for the mPR complexes curve (Fig 1F-style).
1226
1475
 
1227
1476
  Requires `mpr_prepare()` to have been run for each dataset.
1228
1477
 
1229
1478
  Parameters
1230
1479
  ----------
1231
- filter_key : str
1232
- One of: "all", "no_mtRibo_ETCI", "no_small_highAUPRC".
1480
+ variant : str
1481
+ One of: "unfiltered", "without_mt_ribo_etci",
1482
+ "without_small_high_auprc".
1483
+ save : bool, optional
1484
+ Whether to save the figure. If None, uses config["plotting"]["save_plot"].
1485
+ outname : str, optional
1486
+ Output filename. If None, auto-generated.
1233
1487
 
1234
1488
  Returns
1235
1489
  -------
@@ -1250,12 +1504,14 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1250
1504
  )
1251
1505
  return pd.Series(dtype=float)
1252
1506
 
1507
+ variant_key = _normalize_mpr_variant(variant)
1508
+
1253
1509
  # Build Series: dataset -> auc
1254
1510
  auc_by_dataset = {}
1255
1511
  for dataset, per_filter in mpr_auc_dict.items():
1256
1512
  if not isinstance(per_filter, dict):
1257
1513
  continue
1258
- val = per_filter.get(filter_key)
1514
+ val = per_filter.get(variant_key)
1259
1515
  if val is None:
1260
1516
  continue
1261
1517
  try:
@@ -1265,7 +1521,8 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1265
1521
 
1266
1522
  if not auc_by_dataset:
1267
1523
  log.warning(
1268
- f"No mPR complexes AUC scores found for filter '{filter_key}'. Available filters: {list(FILTER_STYLES.keys())}"
1524
+ f"No mPR complex AUC scores found for variant '{variant}'. "
1525
+ f"Available variants: {list(PUBLIC_MPR_VARIANTS.keys())}"
1269
1526
  )
1270
1527
  return pd.Series(dtype=float)
1271
1528
 
@@ -1308,11 +1565,16 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1308
1565
  ax.spines["top"].set_visible(False)
1309
1566
  ax.spines["right"].set_visible(False)
1310
1567
 
1311
- if plot_config.get("save_plot", False):
1568
+ should_save = plot_config.get("save_plot", False) if save is None else bool(save)
1569
+ if should_save:
1312
1570
  output_type = plot_config.get("output_type", "pdf")
1313
1571
  output_folder = Path(config["output_folder"])
1314
1572
  output_folder.mkdir(parents=True, exist_ok=True)
1315
- output_path = output_folder / f"mpr_complexes_auc_{filter_key}.{output_type}"
1573
+ if outname is None:
1574
+ outname = f"mpr_complexes_auc_{variant_key}.{output_type}"
1575
+ output_path = Path(outname)
1576
+ if len(output_path.parts) == 1:
1577
+ output_path = output_folder / outname
1316
1578
  plt.savefig(output_path, bbox_inches="tight", format=output_type)
1317
1579
 
1318
1580
  if plot_config.get("show_plot", True):
@@ -1321,6 +1583,13 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1321
1583
  plt.close(fig)
1322
1584
  return s
1323
1585
 
1586
+
1587
def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
    """Backward-compatible alias of plot_mpr_complex_auc_scores().

    The old ``filter_key`` is translated with _legacy_filter_to_variant():
    "all", "no_mtRibo_ETCI" and "no_small_highAUPRC" become their public
    variant names, ``None`` becomes "unfiltered", and any other value is
    passed through unchanged. The result of the new function is returned.
    """
    variant = _legacy_filter_to_variant(filter_key, default="unfiltered")
    return plot_mpr_complex_auc_scores(variant=variant)
1592
+
1324
1593
  # -----------------------------------------------------------------------------
1325
1594
  # mPR plots (Fig. 1E and Fig. 1F)
1326
1595
  # -----------------------------------------------------------------------------
@@ -1475,26 +1744,6 @@ def plot_mpr_tp(name, ax=None, save=True, outname=None):
1475
1744
 
1476
1745
  return ax
1477
1746
 
1478
- """
1479
- Multi-dataset mPR plotting functions.
1480
-
1481
- Usage:
1482
- from pythonflex.plotting import plot_mpr_tp_multi, plot_mpr_complexes_multi
1483
-
1484
- # Plot multiple datasets
1485
- plot_mpr_tp_multi(["19Q2", "19Q4", "20Q1"])
1486
- plot_mpr_complexes_multi(["19Q2", "19Q4", "20Q1"])
1487
- """
1488
-
1489
- import numpy as np
1490
- import pandas as pd
1491
- import matplotlib.pyplot as plt
1492
- from matplotlib.lines import Line2D
1493
- from pathlib import Path
1494
-
1495
- from .utils import dload
1496
- from .logging_config import log
1497
-
1498
1747
  # Default color palette (colorblind-friendly)
1499
1748
  DEFAULT_COLORS = [
1500
1749
  "#4E79A7", # blue
@@ -1509,40 +1758,101 @@ DEFAULT_COLORS = [
1509
1758
  "#BAB0AC", # gray
1510
1759
  ]
1511
1760
 
1512
- # Filter line styles
1513
- FILTER_STYLES = {
1761
# Public mPR variant names map to the internal keys stored by mpr_prepare().
# Keys are the names accepted by the public plotting API; values are the keys
# looked up in the stored per-dataset mPR results (e.g. tp_curves, AUC dicts).
PUBLIC_MPR_VARIANTS = {
    "unfiltered": "all",
    "without_mt_ribo_etci": "no_mtRibo_ETCI",
    "without_small_high_auprc": "no_small_highAUPRC",
}
# Reverse lookup: internal storage key -> public variant name.
INTERNAL_MPR_VARIANTS = {v: k for k, v in PUBLIC_MPR_VARIANTS.items()}

# mPR variant line styles keyed by internal storage names.
# Each entry holds a matplotlib ``linestyle`` and a human-readable ``label``.
MPR_VARIANT_STYLES = {
    "all": {"linestyle": "-", "label": "all data"},
    "no_mtRibo_ETCI": {"linestyle": "--", "label": "no mtRibo, ETC I"},
    "no_small_highAUPRC": {"linestyle": "dotted", "label": "no small, high AUPRC"},
}

# Compatibility alias for users who imported this internal constant.
# NOTE: this is the same dict object, not a copy.
FILTER_STYLES = MPR_VARIANT_STYLES
1778
+
1779
+
1780
def _normalize_mpr_variant(variant):
    """Resolve one mPR variant name to its internal storage key.

    Public names from PUBLIC_MPR_VARIANTS are translated. Internal keys
    already present in MPR_VARIANT_STYLES are accepted as-is for backward
    compatibility. Any other value raises ValueError.
    """
    if variant in PUBLIC_MPR_VARIANTS:
        return PUBLIC_MPR_VARIANTS[variant]
    if variant not in MPR_VARIANT_STYLES:
        raise ValueError(
            f"Unknown mPR variant {variant!r}. Use one of {list(PUBLIC_MPR_VARIANTS.keys())}."
        )
    # Legacy internal key; "all" is routed through the public table (same key).
    return PUBLIC_MPR_VARIANTS["unfiltered"] if variant == "all" else variant
1519
1792
 
1520
- def _normalize_show_filters(show_filters):
1521
- """Normalize show_filters to an ordered tuple of filter keys.
1522
1793
 
1523
- Common footgun: passing a single string (e.g. "no_mtRibo_ETCI") is iterable,
1524
- which would otherwise be treated as a sequence of characters.
1525
- """
1794
def _normalize_mpr_variants(variants):
    """Resolve one or many mPR variant names to a tuple of internal keys.

    ``None`` and the special name "all" expand to every variant, in
    PUBLIC_MPR_VARIANTS order. Other names go through
    _normalize_mpr_variant(), which raises ValueError for unknown names. A
    single string, or any non-iterable value, is treated as one name.
    Duplicates are removed, keeping the first occurrence's position.
    """
    if variants is None:
        requested = ("all",)
    elif isinstance(variants, str):
        requested = (variants,)
    else:
        try:
            requested = tuple(variants)
        except TypeError:
            requested = (variants,)

    def _expand(name):
        # "all" here means every variant, unlike _normalize_mpr_variant("all").
        if name == "all":
            return tuple(PUBLIC_MPR_VARIANTS.values())
        return (_normalize_mpr_variant(name),)

    first_seen = {}
    for name in requested:
        for key in _expand(name):
            first_seen.setdefault(key, None)
    return tuple(first_seen)
1815
+
1816
+
1817
+ def _legacy_filter_to_variant(filter_key, default=None):
1818
+ """Map old filter-key names to public variant names."""
1819
+ if filter_key is None:
1820
+ return default if default is not None else "all"
1821
+ mapping = {
1822
+ "all": "unfiltered",
1823
+ "no_mtRibo_ETCI": "without_mt_ribo_etci",
1824
+ "no_small_highAUPRC": "without_small_high_auprc",
1825
+ }
1826
+ return mapping.get(filter_key, filter_key)
1827
+
1828
+
1829
def _legacy_filters_to_variants(show_filters):
    """Translate an old ``show_filters`` argument into public variant names.

    ``None`` becomes the string "all". A single string is translated and
    returned as a string. Any other iterable becomes a tuple of translated
    names. A value that cannot be iterated over becomes a one-element tuple.
    """
    if show_filters is None:
        return "all"
    if isinstance(show_filters, str):
        return _legacy_filter_to_variant(show_filters)
    try:
        translated = tuple(map(_legacy_filter_to_variant, show_filters))
    except TypeError:
        translated = (_legacy_filter_to_variant(show_filters),)
    return translated
1534
1839
 
1535
- def plot_mpr_tp_multi(
1840
+
1841
def _normalize_show_filters(show_filters):
    """Resolve legacy ``show_filters`` values to internal mPR variant keys.

    Kept for backward compatibility: old filter names are first mapped to
    public variant names, then normalized exactly like a ``variants`` value.
    """
    public_names = _legacy_filters_to_variants(show_filters)
    return _normalize_mpr_variants(public_names)
1844
+
1845
+ def plot_mpr_true_positive_curve(
1536
1846
  dataset_names=None,
1537
1847
  colors=None,
1538
1848
  ax=None,
1539
1849
  save=True,
1540
1850
  outname=None,
1541
1851
  linewidth=1.8,
1542
- show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
1852
+ variants="unfiltered",
1543
1853
  ):
1544
1854
  """
1545
- Plot TP vs precision curves for multiple datasets.
1855
+ Plot mPR true-positive vs precision curves for multiple datasets.
1546
1856
 
1547
1857
  Can auto-detect datasets or use provided dataset names.
1548
1858
  Each dataset gets one color, each filter type gets one line style.
@@ -1562,8 +1872,9 @@ def plot_mpr_tp_multi(
1562
1872
  Output filename. If None, auto-generated.
1563
1873
  linewidth : float
1564
1874
  Line width for all curves
1565
- show_filters : tuple of str
1566
- Which filters to show. Default is all three.
1875
+ variants : str or iterable of str
1876
+ Which mPR variants to show. Use "unfiltered",
1877
+ "without_mt_ribo_etci", "without_small_high_auprc", or "all".
1567
1878
 
1568
1879
  Returns
1569
1880
  -------
@@ -1573,7 +1884,7 @@ def plot_mpr_tp_multi(
1573
1884
  plot_config = config["plotting"]
1574
1885
  input_colors = dload("input", "colors")
1575
1886
 
1576
- show_filters = _normalize_show_filters(show_filters)
1887
+ variant_keys = _normalize_mpr_variants(variants)
1577
1888
 
1578
1889
  # Sanitize color keys
1579
1890
  if input_colors:
@@ -1641,13 +1952,13 @@ def plot_mpr_tp_multi(
1641
1952
  tp_curves = mpr["tp_curves"]
1642
1953
  color = colors[i % len(colors)]
1643
1954
 
1644
- for filter_key in show_filters:
1645
- if filter_key not in tp_curves:
1955
+ for variant_key in variant_keys:
1956
+ if variant_key not in tp_curves:
1646
1957
  continue
1647
1958
 
1648
- data = tp_curves[filter_key]
1959
+ data = tp_curves[variant_key]
1649
1960
  if not isinstance(data, dict) or "tp" not in data or "precision" not in data:
1650
- log.warning(f"Invalid tp_curves data structure for '{name}' filter '{filter_key}', skipping.")
1961
+ log.warning(f"Invalid tp_curves data structure for '{name}' variant '{variant_key}', skipping.")
1651
1962
  continue
1652
1963
 
1653
1964
  tp = np.asarray(data["tp"], dtype=float)
@@ -1661,7 +1972,7 @@ def plot_mpr_tp_multi(
1661
1972
  prec_plot = prec[mask]
1662
1973
  xmax = max(xmax, float(tp_plot.max()))
1663
1974
 
1664
- style = FILTER_STYLES.get(filter_key, {})
1975
+ style = MPR_VARIANT_STYLES.get(variant_key, {})
1665
1976
  ax.plot(
1666
1977
  tp_plot,
1667
1978
  prec_plot,
@@ -1694,7 +2005,7 @@ def plot_mpr_tp_multi(
1694
2005
  ax.spines['right'].set_visible(False)
1695
2006
 
1696
2007
  # Create vertically stacked legends
1697
- _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth)
2008
+ _add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth)
1698
2009
 
1699
2010
  # Save
1700
2011
  if save:
@@ -1713,7 +2024,8 @@ def plot_mpr_tp_multi(
1713
2024
 
1714
2025
  return ax
1715
2026
 
1716
- def plot_mpr_complexes_multi(
2027
+
2028
+ def plot_mpr_tp_multi(
1717
2029
  dataset_names=None,
1718
2030
  colors=None,
1719
2031
  ax=None,
@@ -1721,11 +2033,31 @@ def plot_mpr_complexes_multi(
1721
2033
  outname=None,
1722
2034
  linewidth=1.8,
1723
2035
  show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
2036
+ ):
2037
+ """Backward-compatible wrapper for plot_mpr_true_positive_curve()."""
2038
+ return plot_mpr_true_positive_curve(
2039
+ dataset_names=dataset_names,
2040
+ colors=colors,
2041
+ ax=ax,
2042
+ save=save,
2043
+ outname=outname,
2044
+ linewidth=linewidth,
2045
+ variants=_legacy_filters_to_variants(show_filters),
2046
+ )
2047
+
2048
+ def plot_mpr_complex_coverage_curve(
2049
+ dataset_names=None,
2050
+ colors=None,
2051
+ ax=None,
2052
+ save=True,
2053
+ outname=None,
2054
+ linewidth=1.8,
2055
+ variants="unfiltered",
1724
2056
  show_markers="auto",
1725
2057
  marker_size=20,
1726
2058
  ):
1727
2059
  """
1728
- Plot module-level PR (#complexes vs precision) for multiple datasets.
2060
+ Plot mPR complex-coverage vs precision curves for multiple datasets.
1729
2061
 
1730
2062
  Can auto-detect datasets or use provided dataset names.
1731
2063
  Each dataset gets one color, each filter type gets one line style.
@@ -1745,8 +2077,9 @@ def plot_mpr_complexes_multi(
1745
2077
  Output filename. If None, auto-generated.
1746
2078
  linewidth : float
1747
2079
  Line width for all curves
1748
- show_filters : tuple of str
1749
- Which filters to show. Default is all three.
2080
+ variants : str or iterable of str
2081
+ Which mPR variants to show. Use "unfiltered",
2082
+ "without_mt_ribo_etci", "without_small_high_auprc", or "all".
1750
2083
  show_markers : bool or "auto"
1751
2084
  If True, draw markers on curves to make short curves visible.
1752
2085
  If "auto" (default), markers are drawn only for curves with <= 10 points.
@@ -1761,7 +2094,7 @@ def plot_mpr_complexes_multi(
1761
2094
  plot_config = config["plotting"]
1762
2095
  input_colors = dload("input", "colors")
1763
2096
 
1764
- show_filters = _normalize_show_filters(show_filters)
2097
+ variant_keys = _normalize_mpr_variants(variants)
1765
2098
 
1766
2099
  # Sanitize color keys
1767
2100
  if input_colors:
@@ -1812,32 +2145,61 @@ def plot_mpr_complexes_multi(
1812
2145
  else:
1813
2146
  fig = ax.figure
1814
2147
 
1815
- # Plot each dataset
2148
+ # First pass: determine max coverage across all datasets/filters for adaptive x-axis
2149
+ max_cov_global = 0
2150
+ _mpr_cache = {}
1816
2151
  for i, name in enumerate(dataset_names):
1817
2152
  mpr = dload("mpr", name)
2153
+ _mpr_cache[name] = mpr
2154
+ if mpr is not None:
2155
+ for variant_key in variant_keys:
2156
+ arr = mpr["coverage_curves"].get(variant_key)
2157
+ if arr is not None:
2158
+ max_cov_global = max(max_cov_global, float(np.asarray(arr).max()))
2159
+
2160
+ # Build adaptive x-axis limits and ticks
2161
+ import math
2162
+ if max_cov_global <= 200:
2163
+ # Original fixed range — keeps CORUM plots identical to before
2164
+ x_max_plot = 200
2165
+ tick_positions = [1, 2, 20, 200]
2166
+ tick_labels = ["0", "2", "20", "200"]
2167
+ else:
2168
+ # Round up to the next power of 10 so the max bar has breathing room
2169
+ x_max_plot = 10 ** math.ceil(math.log10(max_cov_global + 1))
2170
+ tick_positions = [1, 2]
2171
+ v = 10
2172
+ while v <= x_max_plot:
2173
+ tick_positions.append(v)
2174
+ v *= 10
2175
+ tick_labels = ["0"] + [str(t) for t in tick_positions[1:]]
2176
+
2177
+ # Plot each dataset
2178
+ for i, name in enumerate(dataset_names):
2179
+ mpr = _mpr_cache[name]
1818
2180
  if mpr is None:
1819
2181
  log.warning(f"mPR data for '{name}' not found, skipping.")
1820
2182
  continue
1821
-
2183
+
1822
2184
  precision_cutoffs = np.asarray(mpr["precision_cutoffs"], dtype=float)
1823
2185
  coverage = mpr["coverage_curves"]
1824
2186
  color = colors[i % len(colors)]
1825
-
1826
- for filter_key in show_filters:
1827
- if filter_key not in coverage:
2187
+
2188
+ for variant_key in variant_keys:
2189
+ if variant_key not in coverage:
1828
2190
  continue
1829
-
1830
- cov = np.asarray(coverage[filter_key], dtype=float)
1831
-
1832
- # Keep only positive coverage up to 200 complexes
1833
- mask = (cov > 0) & (cov <= 200)
2191
+
2192
+ cov = np.asarray(coverage[variant_key], dtype=float)
2193
+
2194
+ # Keep only positive coverage within the visible x range
2195
+ mask = (cov > 0) & (cov <= x_max_plot)
1834
2196
  if not mask.any():
1835
2197
  continue
1836
-
2198
+
1837
2199
  cov_plot = cov[mask]
1838
2200
  prec_plot = precision_cutoffs[mask]
1839
-
1840
- style = FILTER_STYLES.get(filter_key, {})
2201
+
2202
+ style = MPR_VARIANT_STYLES.get(variant_key, {})
1841
2203
 
1842
2204
  # Decide marker visibility
1843
2205
  if show_markers == "auto":
@@ -1858,17 +2220,15 @@ def plot_mpr_complexes_multi(
1858
2220
  marker=("o" if use_markers else None),
1859
2221
  markersize=(3 if use_markers else None),
1860
2222
  )
1861
-
2223
+
1862
2224
  # Configure axes
1863
2225
  ax.set_xscale("log")
1864
- ax.set_xlim(1, 200)
2226
+ ax.set_xlim(1, x_max_plot)
1865
2227
  ax.set_xlabel("# complexes")
1866
2228
  ax.set_ylabel("Precision")
1867
2229
  ax.set_ylim(0.0, 1.05)
1868
-
1869
- # Custom x-ticks
1870
- tick_positions = [1, 2, 20, 200]
1871
- tick_labels = ["0", "2", "20", "200"]
2230
+
2231
+ # Adaptive x-ticks
1872
2232
  ax.set_xticks(tick_positions)
1873
2233
  ax.set_xticklabels(tick_labels)
1874
2234
 
@@ -1877,7 +2237,7 @@ def plot_mpr_complexes_multi(
1877
2237
  ax.spines['right'].set_visible(False)
1878
2238
 
1879
2239
  # Create vertically stacked legends
1880
- _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth)
2240
+ _add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth)
1881
2241
 
1882
2242
  # Save
1883
2243
  if save:
@@ -1896,11 +2256,71 @@ def plot_mpr_complexes_multi(
1896
2256
 
1897
2257
  return ax
1898
2258
 
1899
- def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
2259
+
2260
+ def plot_mpr_complexes_multi(
2261
+ dataset_names=None,
2262
+ colors=None,
2263
+ ax=None,
2264
+ save=True,
2265
+ outname=None,
2266
+ linewidth=1.8,
2267
+ show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
2268
+ show_markers="auto",
2269
+ marker_size=20,
2270
+ ):
2271
+ """Backward-compatible wrapper for plot_mpr_complex_coverage_curve()."""
2272
+ return plot_mpr_complex_coverage_curve(
2273
+ dataset_names=dataset_names,
2274
+ colors=colors,
2275
+ ax=ax,
2276
+ save=save,
2277
+ outname=outname,
2278
+ linewidth=linewidth,
2279
+ variants=_legacy_filters_to_variants(show_filters),
2280
+ show_markers=show_markers,
2281
+ marker_size=marker_size,
2282
+ )
2283
+
2284
+
2285
+ def plot_mpr_summary(
2286
+ dataset_names=None,
2287
+ colors=None,
2288
+ variants="unfiltered",
2289
+ save=True,
2290
+ linewidth=1.8,
2291
+ show_markers="auto",
2292
+ marker_size=20,
2293
+ auc_variant=None,
2294
+ ):
2295
+ """Generate the standard mPR summary plots and return complex AUC scores."""
2296
+ plot_mpr_true_positive_curve(
2297
+ dataset_names=dataset_names,
2298
+ colors=colors,
2299
+ save=save,
2300
+ linewidth=linewidth,
2301
+ variants=variants,
2302
+ )
2303
+ plot_mpr_complex_coverage_curve(
2304
+ dataset_names=dataset_names,
2305
+ colors=colors,
2306
+ save=save,
2307
+ linewidth=linewidth,
2308
+ variants=variants,
2309
+ show_markers=show_markers,
2310
+ marker_size=marker_size,
2311
+ )
2312
+
2313
+ if auc_variant is None:
2314
+ variant_keys = _normalize_mpr_variants(variants)
2315
+ auc_variant = INTERNAL_MPR_VARIANTS.get(variant_keys[0], "unfiltered")
2316
+
2317
+ return plot_mpr_complex_auc_scores(variant=auc_variant, save=save)
2318
+
2319
+ def _add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth):
1900
2320
  """
1901
- Add vertically stacked legends: Dataset on top, Filter below.
2321
+ Add vertically stacked legends: Dataset on top, mPR variant below.
1902
2322
  """
1903
- show_filters = _normalize_show_filters(show_filters)
2323
+ variant_keys = _normalize_show_filters(variant_keys)
1904
2324
  # Legend 1: Datasets (colors) - solid lines
1905
2325
  dataset_handles = []
1906
2326
  for i, name in enumerate(dataset_names):
@@ -1908,19 +2328,19 @@ def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
1908
2328
  handle = Line2D([0], [0], color=color, linewidth=linewidth, linestyle="-")
1909
2329
  dataset_handles.append(handle)
1910
2330
 
1911
- # Legend 2: Filters (line styles) - black lines
1912
- filter_handles = []
1913
- filter_labels = []
1914
- for filter_key in show_filters:
1915
- style = FILTER_STYLES.get(filter_key, {})
2331
+ # Legend 2: mPR variants (line styles) - black lines
2332
+ variant_handles = []
2333
+ variant_labels = []
2334
+ for variant_key in variant_keys:
2335
+ style = MPR_VARIANT_STYLES.get(variant_key, {})
1916
2336
  handle = Line2D(
1917
2337
  [0], [0],
1918
2338
  color="black",
1919
2339
  linewidth=linewidth,
1920
2340
  linestyle=style.get("linestyle", "-")
1921
2341
  )
1922
- filter_handles.append(handle)
1923
- filter_labels.append(style.get("label", filter_key))
2342
+ variant_handles.append(handle)
2343
+ variant_labels.append(style.get("label", variant_key))
1924
2344
 
1925
2345
  # Position legends vertically with proper alignment
1926
2346
  # Dataset legend on upper right
@@ -1938,19 +2358,19 @@ def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
1938
2358
 
1939
2359
  # Filter legend below the dataset legend, aligned properly without title
1940
2360
  legend2 = ax.legend(
1941
- filter_handles,
1942
- filter_labels,
2361
+ variant_handles,
2362
+ variant_labels,
1943
2363
  loc="upper left",
1944
2364
  frameon=False,
1945
2365
  fontsize=7,
1946
2366
  bbox_to_anchor=(1.05, 1.0 - len(dataset_names) * 0.06 - 0.1)
1947
2367
  )
1948
2368
 
1949
- def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
2369
+ def _add_dual_legend(ax, dataset_names, colors, variant_keys, linewidth):
1950
2370
  """
1951
- Add two legends: one for datasets (colors), one for filters (line styles).
2371
+ Add two legends: one for datasets (colors), one for mPR variants (line styles).
1952
2372
  """
1953
- show_filters = _normalize_show_filters(show_filters)
2373
+ variant_keys = _normalize_show_filters(variant_keys)
1954
2374
  # Legend 1: Datasets (colors) - solid lines
1955
2375
  dataset_handles = []
1956
2376
  for i, name in enumerate(dataset_names):
@@ -1958,19 +2378,19 @@ def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
1958
2378
  handle = Line2D([0], [0], color=color, linewidth=linewidth, linestyle="-")
1959
2379
  dataset_handles.append(handle)
1960
2380
 
1961
- # Legend 2: Filters (line styles) - black lines
1962
- filter_handles = []
1963
- filter_labels = []
1964
- for filter_key in show_filters:
1965
- style = FILTER_STYLES.get(filter_key, {})
2381
+ # Legend 2: mPR variants (line styles) - black lines
2382
+ variant_handles = []
2383
+ variant_labels = []
2384
+ for variant_key in variant_keys:
2385
+ style = MPR_VARIANT_STYLES.get(variant_key, {})
1966
2386
  handle = Line2D(
1967
2387
  [0], [0],
1968
2388
  color="black",
1969
2389
  linewidth=linewidth,
1970
2390
  linestyle=style.get("linestyle", "-")
1971
2391
  )
1972
- filter_handles.append(handle)
1973
- filter_labels.append(style.get("label", filter_key))
2392
+ variant_handles.append(handle)
2393
+ variant_labels.append(style.get("label", variant_key))
1974
2394
 
1975
2395
  # Position legends
1976
2396
  # Dataset legend on upper right
@@ -1987,19 +2407,12 @@ def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
1987
2407
 
1988
2408
  # Filter legend on lower left or right depending on plot type
1989
2409
  legend2 = ax.legend(
1990
- filter_handles,
1991
- filter_labels,
2410
+ variant_handles,
2411
+ variant_labels,
1992
2412
  loc="lower left",
1993
2413
  frameon=False,
1994
- title="Filter",
2414
+ title="Variant",
1995
2415
  fontsize=7,
1996
2416
  title_fontsize=8,
1997
2417
  )
1998
2418
 
1999
- # ============================================================================
2000
- # Single dataset functions are now obsolete
2001
- # ============================================================================
2002
-
2003
- # Note: The original single dataset functions plot_mpr_tp() and plot_mpr_complexes()
2004
- # have been replaced by the multi functions that now auto-detect available datasets.
2005
- # Use plot_mpr_tp_multi() and plot_mpr_complexes_multi() instead.