pythonflex 0.3.4__py3-none-any.whl → 0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pythonflex/__init__.py +28 -4
- pythonflex/analysis.py +287 -579
- pythonflex/examples/basic_usage.py +38 -30
- pythonflex/examples/manuscript.py +37 -43
- pythonflex/examples/runtime/runtime_benchmark.py +218 -0
- pythonflex/examples/runtime/runtime_benchmark_10_runs_memmap.py +534 -0
- pythonflex/examples/runtime/runtime_benchmark_corum_njobs.py +245 -0
- pythonflex/examples/runtime/runtime_benchmark_gobp_njobs_chunks.py +319 -0
- pythonflex/examples/runtime/runtime_benchmark_gobp_optimization.py +417 -0
- pythonflex/examples/runtime/runtime_benchmark_repeated.py +347 -0
- pythonflex/old_functions.py +422 -0
- pythonflex/plotting.py +655 -242
- pythonflex/preprocessing.py +54 -216
- pythonflex/utils.py +36 -9
- {pythonflex-0.3.4.dist-info → pythonflex-0.4.dist-info}/METADATA +8 -6
- pythonflex-0.4.dist-info/RECORD +32 -0
- {pythonflex-0.3.4.dist-info → pythonflex-0.4.dist-info}/WHEEL +1 -1
- pythonflex-0.4.dist-info/licenses/LICENSE +7 -0
- pythonflex-0.3.4.dist-info/RECORD +0 -24
- {pythonflex-0.3.4.dist-info → pythonflex-0.4.dist-info}/entry_points.txt +0 -0
pythonflex/plotting.py
CHANGED
|
@@ -9,6 +9,7 @@ import pandas as pd
|
|
|
9
9
|
import matplotlib.pyplot as plt
|
|
10
10
|
from matplotlib import patches
|
|
11
11
|
from matplotlib.cm import get_cmap
|
|
12
|
+
from matplotlib.lines import Line2D
|
|
12
13
|
from matplotlib.ticker import NullFormatter, NullLocator
|
|
13
14
|
|
|
14
15
|
# Completely disable LaTeX and clear all font cache/references
|
|
@@ -26,14 +27,17 @@ mpl.rcParams['font.cursive'] = ['Apple Chancery', 'Textile', 'Zapf Chancery', 'S
|
|
|
26
27
|
mpl.rcParams['font.fantasy'] = ['Comic Sans MS', 'Chicago', 'Charcoal', 'Impact', 'Western', 'Humor Sans', 'fantasy']
|
|
27
28
|
mpl.rcParams['font.monospace'] = ['DejaVu Sans Mono', 'Bitstream Vera Sans Mono', 'Computer Modern Typewriter', 'Andale Mono', 'Nimbus Mono L', 'Courier New', 'Courier', 'Fixed', 'Terminal', 'monospace']
|
|
28
29
|
|
|
29
|
-
# Remove any LaTeX-specific math font settings
|
|
30
|
-
mpl.rcParams['mathtext.fontset'] = 'dejavusans'
|
|
31
|
-
mpl.rcParams['mathtext.default'] = 'regular'
|
|
30
|
+
# Remove any LaTeX-specific math font settings
|
|
31
|
+
mpl.rcParams['mathtext.fontset'] = 'dejavusans'
|
|
32
|
+
mpl.rcParams['mathtext.default'] = 'regular'
|
|
33
|
+
mpl.rcParams['pdf.fonttype'] = 42
|
|
34
|
+
mpl.rcParams['ps.fonttype'] = 42
|
|
35
|
+
mpl.rcParams['svg.fonttype'] = 'none'
|
|
32
36
|
|
|
33
37
|
# Force font manager to rebuild with system fonts only
|
|
34
38
|
try:
|
|
35
39
|
fm.fontManager.__init__()
|
|
36
|
-
except:
|
|
40
|
+
except Exception:
|
|
37
41
|
pass
|
|
38
42
|
|
|
39
43
|
# Local modules
|
|
@@ -275,17 +279,233 @@ def plot_all_runs_pra(pra_list, mean_df=None, line_width=2.0, hide_minor_ticks=T
|
|
|
275
279
|
output_path = Path(config["output_folder"]) / f"aggregated_all_runs_precision_recall_curve.{output_type}"
|
|
276
280
|
fig.savefig(output_path, bbox_inches="tight", format=output_type)
|
|
277
281
|
|
|
278
|
-
if plot_config.get("show_plot", True):
|
|
279
|
-
plt.show()
|
|
280
|
-
plt.close(fig)
|
|
281
|
-
|
|
282
|
-
def
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
282
|
+
if plot_config.get("show_plot", True):
|
|
283
|
+
plt.show()
|
|
284
|
+
plt.close(fig)
|
|
285
|
+
|
|
286
|
+
def _short_scatter_label(label, max_chars=10):
|
|
287
|
+
label = str(label)
|
|
288
|
+
return label[:max_chars] + "." if len(label) > max_chars else label
|
|
289
|
+
|
|
290
|
+
def _bbox_overlap_area(box_a, box_b):
|
|
291
|
+
x_overlap = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
|
|
292
|
+
y_overlap = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
|
|
293
|
+
return x_overlap * y_overlap
|
|
294
|
+
|
|
295
|
+
def _line_intersects_bbox(start_px, end_px, box):
|
|
296
|
+
samples = np.linspace(0.05, 0.95, 12)
|
|
297
|
+
xs = start_px[0] + (end_px[0] - start_px[0]) * samples
|
|
298
|
+
ys = start_px[1] + (end_px[1] - start_px[1]) * samples
|
|
299
|
+
return np.any(
|
|
300
|
+
(xs >= box[0]) & (xs <= box[2]) &
|
|
301
|
+
(ys >= box[1]) & (ys <= box[3])
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
def _label_alignment(dx, dy):
|
|
305
|
+
if abs(dx) < 1e-9:
|
|
306
|
+
ha = "center"
|
|
307
|
+
else:
|
|
308
|
+
ha = "left" if dx > 0 else "right"
|
|
309
|
+
|
|
310
|
+
if abs(dy) < 1e-9:
|
|
311
|
+
va = "center"
|
|
312
|
+
else:
|
|
313
|
+
va = "bottom" if dy > 0 else "top"
|
|
314
|
+
|
|
315
|
+
return ha, va
|
|
316
|
+
|
|
317
|
+
def _label_display_bbox(ax, anchor_xy, width_px, height_px, ha, va):
|
|
318
|
+
anchor_x, anchor_y = ax.transData.transform(anchor_xy)
|
|
319
|
+
|
|
320
|
+
if ha == "left":
|
|
321
|
+
x0, x1 = anchor_x, anchor_x + width_px
|
|
322
|
+
elif ha == "right":
|
|
323
|
+
x0, x1 = anchor_x - width_px, anchor_x
|
|
324
|
+
else:
|
|
325
|
+
x0, x1 = anchor_x - width_px / 2.0, anchor_x + width_px / 2.0
|
|
326
|
+
|
|
327
|
+
if va == "bottom":
|
|
328
|
+
y0, y1 = anchor_y, anchor_y + height_px
|
|
329
|
+
elif va == "top":
|
|
330
|
+
y0, y1 = anchor_y - height_px, anchor_y
|
|
331
|
+
else:
|
|
332
|
+
y0, y1 = anchor_y - height_px / 2.0, anchor_y + height_px / 2.0
|
|
333
|
+
|
|
334
|
+
return (x0, y0, x1, y1)
|
|
335
|
+
|
|
336
|
+
def _measure_label_size_pixels(ax, label, fontsize, bbox_props, renderer):
|
|
337
|
+
probe = ax.text(
|
|
338
|
+
0.5, 0.5, label,
|
|
339
|
+
fontsize=fontsize,
|
|
340
|
+
ha="left",
|
|
341
|
+
va="bottom",
|
|
342
|
+
linespacing=1,
|
|
343
|
+
alpha=0.0,
|
|
344
|
+
bbox=bbox_props,
|
|
345
|
+
)
|
|
346
|
+
bbox = probe.get_window_extent(renderer=renderer).expanded(1.12, 1.35)
|
|
347
|
+
probe.remove()
|
|
348
|
+
return bbox.width, bbox.height
|
|
349
|
+
|
|
350
|
+
def _place_scatter_labels_radially(ax, label_items, obstacle_points, fontsize=4, bbox_props=None):
    """Choose label anchors by scanning candidate positions around each point.

    For every labelled point, candidate anchors are generated on rings around
    the point (24 directions, 15 degrees apart, at six increasing offsets that
    are added to the point's data coordinates). Candidates whose label box
    would leave the axes area are discarded; the remaining ones are scored and
    the lowest score wins. Labels are placed greedily, one after another, so
    earlier labels act as obstacles for later ones.

    Parameters
    ----------
    ax : matplotlib axes
        Axes the labels will be drawn on. Its figure canvas is drawn once here
        to obtain a renderer.
    label_items : iterable of (x, y, label)
        Points to label, in data coordinates. Items with a non-finite x or y
        are skipped.
    obstacle_points : array-like
        Points that labels should avoid covering. Rows containing non-finite
        values are ignored.
        NOTE(review): expected to be shape (n, 2) in data coordinates - the
        code filters with ``all(axis=1)`` and indexes columns 0 and 1.
    fontsize : float
        Font size used when measuring each label.
    bbox_props : dict or None
        Text box properties used when measuring each label.

    Returns
    -------
    list of dict
        One entry per placed label with keys "point_x", "point_y", "text_x",
        "text_y", "label", "ha", "va", "box" and "score". The order follows the
        internal placement order, not the order of ``label_items``.
    """
    fig = ax.figure
    # A draw is needed so the renderer and the axes extent below are current.
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()
    axis_box = ax.get_window_extent(renderer=renderer)

    obstacle_points = np.asarray(obstacle_points, dtype=float)
    if obstacle_points.size == 0:
        obstacle_points = np.empty((0, 2), dtype=float)
    else:
        # Keep only rows whose coordinates are all finite.
        obstacle_points = obstacle_points[np.isfinite(obstacle_points).all(axis=1)]
    obstacle_pixels = ax.transData.transform(obstacle_points) if len(obstacle_points) else np.empty((0, 2))

    # Candidate directions (every 15 degrees) and offsets (added in data units).
    angles = np.deg2rad(np.arange(0, 360, 15))
    distances = np.array([0.035, 0.055, 0.08, 0.11, 0.145, 0.18])
    axis_margin_px = 2.0

    valid_items = [
        (x, y, label)
        for x, y, label in label_items
        if np.isfinite(x) and np.isfinite(y)
    ]
    # Placement order: smallest of x, 1-x, y, 1-y first (i.e. nearest to a side
    # of the unit square - presumably the axes run 0..1; confirm with caller),
    # ties broken by larger y, then larger x.
    valid_items = sorted(
        valid_items,
        key=lambda item: (
            min(item[0], 1.0 - item[0], item[1], 1.0 - item[1]),
            -item[1],
            -item[0],
        ),
    )

    placed = []
    placed_boxes = []

    for point_x, point_y, label in valid_items:
        label_width, label_height = _measure_label_size_pixels(
            ax, label, fontsize, bbox_props, renderer
        )
        point_px = ax.transData.transform((point_x, point_y))
        best_candidate = None

        for distance in distances:
            for angle in angles:
                dx = float(np.cos(angle) * distance)
                dy = float(np.sin(angle) * distance)
                text_x = point_x + dx
                text_y = point_y + dy
                ha, va = _label_alignment(dx, dy)
                box = _label_display_bbox(
                    ax, (text_x, text_y), label_width, label_height, ha, va
                )

                # Reject candidates whose box comes within the margin of, or
                # crosses, any side of the axes extent.
                outside_axes = (
                    box[0] < axis_box.x0 + axis_margin_px or
                    box[1] < axis_box.y0 + axis_margin_px or
                    box[2] > axis_box.x1 - axis_margin_px or
                    box[3] > axis_box.y1 - axis_margin_px
                )
                if outside_axes:
                    continue

                # Overlap with labels that were already placed.
                overlaps = [_bbox_overlap_area(box, placed_box) for placed_box in placed_boxes]
                overlap_hits = sum(area > 0 for area in overlaps)
                overlap_area = sum(overlaps)

                if len(obstacle_pixels):
                    # Number of obstacle points covered by the candidate box.
                    dot_hits = np.count_nonzero(
                        (obstacle_pixels[:, 0] >= box[0]) &
                        (obstacle_pixels[:, 0] <= box[2]) &
                        (obstacle_pixels[:, 1] >= box[1]) &
                        (obstacle_pixels[:, 1] <= box[3])
                    )
                    # When the labelled point itself lies inside the box, one
                    # hit is discounted (presumably the point is also among the
                    # obstacles - verify against caller).
                    if box[0] <= point_px[0] <= box[2] and box[1] <= point_px[1] <= box[3]:
                        dot_hits = max(0, dot_hits - 1)
                else:
                    dot_hits = 0

                text_px = ax.transData.transform((text_x, text_y))
                # Count placed labels crossed by the point-to-anchor connector.
                connector_hits = sum(
                    _line_intersects_bbox(point_px, text_px, placed_box)
                    for placed_box in placed_boxes
                )
                # Smallest distance between the box and any side of the axes.
                edge_gap = min(
                    box[0] - axis_box.x0,
                    box[1] - axis_box.y0,
                    axis_box.x1 - box[2],
                    axis_box.y1 - box[3],
                )

                # Lower is better. Overlapping another label dominates, then
                # covered points and crossed labels, then connector length and
                # closeness to the axes border.
                score = (
                    overlap_hits * 100000.0 +
                    overlap_area * 0.5 +
                    dot_hits * 1500.0 +
                    connector_hits * 1200.0 +
                    np.hypot(text_px[0] - point_px[0], text_px[1] - point_px[1]) * 0.05 +
                    10.0 / max(edge_gap, 1.0)
                )

                # Strict "<" keeps the first candidate among equal scores.
                if best_candidate is None or score < best_candidate["score"]:
                    best_candidate = {
                        "point_x": point_x,
                        "point_y": point_y,
                        "text_x": text_x,
                        "text_y": text_y,
                        "label": label,
                        "ha": ha,
                        "va": va,
                        "box": box,
                        "score": score,
                    }

        if best_candidate is None:
            # Extremely rare fallback for very long labels or tight axes.
            # The offset is negative when the coordinate exceeds 0.5 and
            # positive otherwise; the anchor is clamped to [0.02, 0.98].
            fallback_dx = -0.055 if point_x > 0.5 else 0.055
            fallback_dy = -0.055 if point_y > 0.5 else 0.055
            ha, va = _label_alignment(fallback_dx, fallback_dy)
            text_x = max(0.02, min(0.98, point_x + fallback_dx))
            text_y = max(0.02, min(0.98, point_y + fallback_dy))
            best_candidate = {
                "point_x": point_x,
                "point_y": point_y,
                "text_x": text_x,
                "text_y": text_y,
                "label": label,
                "ha": ha,
                "va": va,
                "box": _label_display_bbox(
                    ax, (text_x, text_y), label_width, label_height, ha, va
                ),
                "score": float("inf"),
            }

        # The chosen box becomes an obstacle for all labels placed afterwards.
        placed_boxes.append(best_candidate["box"])
        placed.append(best_candidate)

    return placed
|
|
487
|
+
|
|
488
|
+
def plot_percomplex_scatter(
|
|
489
|
+
n_top=10,
|
|
490
|
+
sig_color='black',
|
|
491
|
+
nonsig_color='none',
|
|
492
|
+
label_color='black',
|
|
493
|
+
border_color='black',
|
|
494
|
+
border_width=1.0,
|
|
495
|
+
nonsig_border_color="#7F7F7F",
|
|
496
|
+
nonsig_border_width=0.5,
|
|
497
|
+
show_text_background=True,
|
|
498
|
+
):
|
|
499
|
+
config = dload("config")
|
|
500
|
+
plot_config = config["plotting"]
|
|
501
|
+
rdict = dload("pra_percomplex")
|
|
502
|
+
input_colors = dload("input", "colors")
|
|
503
|
+
input_colors = {_sanitize(k): v for k, v in input_colors.items()} if input_colors else {}
|
|
286
504
|
|
|
287
505
|
if len(rdict) < 2:
|
|
288
|
-
|
|
506
|
+
log.warning(
|
|
507
|
+
"Skipping plot: at least two datasets are required for per-complex scatter plot."
|
|
508
|
+
)
|
|
289
509
|
return
|
|
290
510
|
|
|
291
511
|
column_pairs = list(combinations(rdict.keys(), 2))
|
|
@@ -299,93 +519,61 @@ def plot_percomplex_scatter(n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD
|
|
|
299
519
|
df = pd.concat([df, val[key]], axis=1)
|
|
300
520
|
|
|
301
521
|
for pair in column_pairs:
|
|
302
|
-
extreme_indices_0 = df[pair[0]].sort_values(ascending=False).head(n_top).index
|
|
303
|
-
extreme_indices_1 = df[pair[1]].sort_values(ascending=False).head(n_top).index
|
|
304
|
-
significant_indices = extreme_indices_0.union(extreme_indices_1)
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
#
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
# Draw connector line
|
|
359
|
-
ax.plot([x, x], [y, extended_adj_y],
|
|
360
|
-
color=label_color, linewidth=0.6, alpha=0.15, zorder=3)
|
|
361
|
-
|
|
362
|
-
# Position text at the end of extended line with small offset
|
|
363
|
-
text_y_offset = 0.01 if direction == "up" else -0.01
|
|
364
|
-
final_text_y = extended_adj_y + text_y_offset
|
|
365
|
-
|
|
366
|
-
# Final clip to ensure text stays within 0-1 range
|
|
367
|
-
final_text_y = max(0.02, min(final_text_y, 0.98))
|
|
368
|
-
|
|
369
|
-
# Prepare text bbox settings (can be turned on/off)
|
|
370
|
-
bbox_props = dict(facecolor="white", alpha=0.7, edgecolor="none", pad=1) if show_text_background else None
|
|
371
|
-
|
|
372
|
-
ax.text(
|
|
373
|
-
x, final_text_y,
|
|
374
|
-
df.loc[idx, 'Name'][:10] + '.' if len(df.loc[idx, 'Name']) > 10 else df.loc[idx, 'Name'],
|
|
375
|
-
fontsize=4,
|
|
376
|
-
ha='left',
|
|
377
|
-
va='bottom' if direction == "up" else 'top',
|
|
378
|
-
color=label_color,
|
|
379
|
-
linespacing=1,
|
|
380
|
-
zorder=4,
|
|
381
|
-
clip_on=True, # Enable clipping to axes bounds
|
|
382
|
-
bbox=bbox_props
|
|
383
|
-
)
|
|
384
|
-
|
|
385
|
-
# Diagonal & axes cosmetics
|
|
386
|
-
ax.plot([0, 1], [0, 1], linestyle='-', color='lightgray', alpha=0.4, linewidth=0.5, zorder=1)
|
|
387
|
-
|
|
388
|
-
# Force square aspect ratio and exact 0-1 range
|
|
522
|
+
extreme_indices_0 = df[pair[0]].sort_values(ascending=False).head(n_top).index
|
|
523
|
+
extreme_indices_1 = df[pair[1]].sort_values(ascending=False).head(n_top).index
|
|
524
|
+
significant_indices = extreme_indices_0.union(extreme_indices_1)
|
|
525
|
+
significant_in_both = extreme_indices_0.intersection(extreme_indices_1)
|
|
526
|
+
significant_pair0_only = extreme_indices_0.difference(extreme_indices_1)
|
|
527
|
+
significant_pair1_only = extreme_indices_1.difference(extreme_indices_0)
|
|
528
|
+
|
|
529
|
+
bg_df = df.drop(index=significant_indices)
|
|
530
|
+
sig_df = df.loc[significant_indices]
|
|
531
|
+
|
|
532
|
+
# Create square figure
|
|
533
|
+
fig, ax = plt.subplots(figsize=(6, 6))
|
|
534
|
+
|
|
535
|
+
# Background cloud: non-significant complexes are open circles.
|
|
536
|
+
bg_sizes = (bg_df['n_used_genes'] if 'n_used_genes' in bg_df else pd.Series(1, index=bg_df.index)) * 5
|
|
537
|
+
ax.scatter(
|
|
538
|
+
bg_df[pair[0]], bg_df[pair[1]],
|
|
539
|
+
facecolors="none", edgecolors=nonsig_border_color,
|
|
540
|
+
s=bg_sizes, linewidth=nonsig_border_width, alpha=0.8,
|
|
541
|
+
zorder=0
|
|
542
|
+
)
|
|
543
|
+
|
|
544
|
+
def scatter_significant(indices, color, zorder=2):
|
|
545
|
+
if len(indices) == 0:
|
|
546
|
+
return
|
|
547
|
+
point_df = df.loc[indices]
|
|
548
|
+
point_sizes = (
|
|
549
|
+
point_df['n_used_genes']
|
|
550
|
+
if 'n_used_genes' in point_df
|
|
551
|
+
else pd.Series(1, index=point_df.index)
|
|
552
|
+
) * 8
|
|
553
|
+
ax.scatter(
|
|
554
|
+
point_df[pair[0]], point_df[pair[1]],
|
|
555
|
+
facecolors=color, edgecolors=color,
|
|
556
|
+
s=point_sizes, linewidth=border_width, zorder=zorder
|
|
557
|
+
)
|
|
558
|
+
|
|
559
|
+
# Dataset-specific significant complexes use the dataset input color.
|
|
560
|
+
scatter_significant(
|
|
561
|
+
significant_pair0_only,
|
|
562
|
+
input_colors.get(_sanitize(pair[0]), sig_color),
|
|
563
|
+
zorder=2,
|
|
564
|
+
)
|
|
565
|
+
scatter_significant(
|
|
566
|
+
significant_pair1_only,
|
|
567
|
+
input_colors.get(_sanitize(pair[1]), sig_color),
|
|
568
|
+
zorder=2,
|
|
569
|
+
)
|
|
570
|
+
# Complexes significant in both datasets stay black to avoid ambiguous color mixing.
|
|
571
|
+
scatter_significant(significant_in_both, "black", zorder=3)
|
|
572
|
+
|
|
573
|
+
# Diagonal & axes cosmetics
|
|
574
|
+
ax.plot([0, 1], [0, 1], linestyle='-', color='lightgray', alpha=0.4, linewidth=0.5, zorder=1)
|
|
575
|
+
|
|
576
|
+
# Force square aspect ratio and exact 0-1 range
|
|
389
577
|
ax.set_xlim(0, 1)
|
|
390
578
|
ax.set_ylim(0, 1)
|
|
391
579
|
ax.set_aspect('equal', adjustable='box')
|
|
@@ -400,15 +588,55 @@ def plot_percomplex_scatter(n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD
|
|
|
400
588
|
#ax.set_title(f"{pair[0]} vs {pair[1]} - Comparison of complex performance")
|
|
401
589
|
|
|
402
590
|
# Nature style: no grid, open top/right spines
|
|
403
|
-
ax.grid(False)
|
|
404
|
-
ax.spines['top'].set_visible(False)
|
|
405
|
-
ax.spines['right'].set_visible(False)
|
|
406
|
-
|
|
407
|
-
plt.tight_layout()
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
591
|
+
ax.grid(False)
|
|
592
|
+
ax.spines['top'].set_visible(False)
|
|
593
|
+
ax.spines['right'].set_visible(False)
|
|
594
|
+
|
|
595
|
+
plt.tight_layout()
|
|
596
|
+
|
|
597
|
+
# Radial label positioning searches all directions around each significant point.
|
|
598
|
+
label_items = []
|
|
599
|
+
for idx in sig_df.index:
|
|
600
|
+
label_items.append((
|
|
601
|
+
float(sig_df.loc[idx, pair[0]]),
|
|
602
|
+
float(sig_df.loc[idx, pair[1]]),
|
|
603
|
+
_short_scatter_label(df.loc[idx, 'Name']),
|
|
604
|
+
))
|
|
605
|
+
all_points = df[[pair[0], pair[1]]].dropna().to_numpy(dtype=float)
|
|
606
|
+
label_positions = _place_scatter_labels_radially(
|
|
607
|
+
ax,
|
|
608
|
+
label_items,
|
|
609
|
+
all_points,
|
|
610
|
+
fontsize=4,
|
|
611
|
+
bbox_props=None,
|
|
612
|
+
)
|
|
613
|
+
|
|
614
|
+
for label_pos in label_positions:
|
|
615
|
+
ax.plot(
|
|
616
|
+
[label_pos["point_x"], label_pos["text_x"]],
|
|
617
|
+
[label_pos["point_y"], label_pos["text_y"]],
|
|
618
|
+
color=label_color,
|
|
619
|
+
linewidth=0.6,
|
|
620
|
+
alpha=0.15,
|
|
621
|
+
zorder=3,
|
|
622
|
+
)
|
|
623
|
+
ax.text(
|
|
624
|
+
label_pos["text_x"],
|
|
625
|
+
label_pos["text_y"],
|
|
626
|
+
label_pos["label"],
|
|
627
|
+
fontsize=4,
|
|
628
|
+
ha=label_pos["ha"],
|
|
629
|
+
va=label_pos["va"],
|
|
630
|
+
color=label_color,
|
|
631
|
+
linespacing=1,
|
|
632
|
+
zorder=4,
|
|
633
|
+
clip_on=True,
|
|
634
|
+
bbox=None,
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
if plot_config["save_plot"]:
|
|
638
|
+
output_type = plot_config["output_type"]
|
|
639
|
+
output_path = Path(config["output_folder"]) / f"percomplex_scatter_{pair[0]}_vs_{pair[1]}.{output_type}"
|
|
412
640
|
fig.savefig(output_path, bbox_inches="tight", format=output_type)
|
|
413
641
|
|
|
414
642
|
if plot_config.get("show_plot", True):
|
|
@@ -894,16 +1122,28 @@ def position_cluster_labels(cluster, cluster_id, max_y, effective_max_y, label_c
|
|
|
894
1122
|
clip_on=True, bbox=bbox_props
|
|
895
1123
|
)
|
|
896
1124
|
|
|
897
|
-
def plot_percomplex_scatter_bysize(
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
1125
|
+
def plot_percomplex_scatter_bysize(
|
|
1126
|
+
n_labels=10,
|
|
1127
|
+
n_top=10,
|
|
1128
|
+
sig_color='black',
|
|
1129
|
+
nonsig_color='none',
|
|
1130
|
+
label_color='black',
|
|
1131
|
+
border_color='black',
|
|
1132
|
+
border_width=1.0,
|
|
1133
|
+
nonsig_border_color="#7F7F7F",
|
|
1134
|
+
nonsig_border_width=0.5,
|
|
1135
|
+
show_text_background=True,
|
|
1136
|
+
):
|
|
1137
|
+
config = dload("config")
|
|
1138
|
+
plot_config = config["plotting"]
|
|
1139
|
+
rdict = dload("pra_percomplex")
|
|
1140
|
+
input_colors = dload("input", "colors")
|
|
1141
|
+
input_colors = {_sanitize(k): v for k, v in input_colors.items()} if input_colors else {}
|
|
1142
|
+
|
|
1143
|
+
for key, per_complex in rdict.items():
|
|
1144
|
+
dataset_color = input_colors.get(_sanitize(key), sig_color)
|
|
1145
|
+
sorted_pc = per_complex.sort_values(by="auc_score", ascending=False, na_position="last")
|
|
1146
|
+
top_labels, rest = sorted_pc.head(n_labels), sorted_pc.iloc[n_labels:]
|
|
907
1147
|
|
|
908
1148
|
# Calculate data range for appropriate figure sizing
|
|
909
1149
|
max_genes = sorted_pc.n_used_genes.max()
|
|
@@ -914,22 +1154,22 @@ def plot_percomplex_scatter_bysize(n_labels=10, n_top=10, sig_color='#B71A2A', n
|
|
|
914
1154
|
fig_height = min(max(4, aspect_ratio), 8) # Between 4-8 inches
|
|
915
1155
|
fig, ax = plt.subplots(figsize=(6, fig_height))
|
|
916
1156
|
|
|
917
|
-
# Background
|
|
918
|
-
ax.scatter(
|
|
919
|
-
rest.auc_score, rest.n_used_genes,
|
|
920
|
-
facecolors=
|
|
921
|
-
linewidth=
|
|
922
|
-
alpha=
|
|
923
|
-
zorder=0
|
|
924
|
-
)
|
|
925
|
-
|
|
926
|
-
# Top N
|
|
927
|
-
ax.scatter(
|
|
928
|
-
top_labels.auc_score, top_labels.n_used_genes,
|
|
929
|
-
facecolors=
|
|
930
|
-
linewidth=border_width, s=top_labels.n_used_genes * 8,
|
|
931
|
-
label=f"Top {n_labels} AUC Scores", alpha=1.0, zorder=2
|
|
932
|
-
)
|
|
1157
|
+
# Background: non-significant complexes are open circles.
|
|
1158
|
+
ax.scatter(
|
|
1159
|
+
rest.auc_score, rest.n_used_genes,
|
|
1160
|
+
facecolors="none", edgecolors=nonsig_border_color,
|
|
1161
|
+
linewidth=nonsig_border_width, s=rest.n_used_genes * 5,
|
|
1162
|
+
alpha=0.8, label="Other Complexes",
|
|
1163
|
+
zorder=0
|
|
1164
|
+
)
|
|
1165
|
+
|
|
1166
|
+
# Top N/significant complexes are filled black circles.
|
|
1167
|
+
ax.scatter(
|
|
1168
|
+
top_labels.auc_score, top_labels.n_used_genes,
|
|
1169
|
+
facecolors=dataset_color, edgecolors=dataset_color,
|
|
1170
|
+
linewidth=border_width, s=top_labels.n_used_genes * 8,
|
|
1171
|
+
label=f"Top {n_labels} AUC Scores", alpha=1.0, zorder=2
|
|
1172
|
+
)
|
|
933
1173
|
|
|
934
1174
|
# Enhanced anti-overlap labeling system
|
|
935
1175
|
coords = [(row.auc_score, row.n_used_genes, idx) for idx, row in top_labels.iterrows()]
|
|
@@ -1010,6 +1250,7 @@ def plot_complex_contributions(
|
|
|
1010
1250
|
tmp = np.tile(x, (mx, 1))
|
|
1011
1251
|
x = cont_stepwise_mat.values / tmp
|
|
1012
1252
|
x_df = pd.DataFrame(x, index=cont_stepwise_anno, columns=cont_stepwise_mat.columns)
|
|
1253
|
+
|
|
1013
1254
|
ind_for_mean = y >= (last_prec_value - min_precision_cutoff)
|
|
1014
1255
|
if sum(ind_for_mean) == 0:
|
|
1015
1256
|
log.info("No values above 'min.precision.cutoff'"); return False
|
|
@@ -1094,15 +1335,23 @@ def plot_significant_complexes():
|
|
|
1094
1335
|
input_colors = {_sanitize(k): v for k, v in input_colors.items()}
|
|
1095
1336
|
|
|
1096
1337
|
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
|
|
1338
|
+
if not isinstance(pra_percomplex, dict) or not pra_percomplex:
|
|
1339
|
+
log.warning("No per-complex PRA data found. Run pra_percomplex() first.")
|
|
1340
|
+
return pd.DataFrame(index=thresholds)
|
|
1341
|
+
|
|
1097
1342
|
datasets = list(pra_percomplex.keys())
|
|
1098
1343
|
num_datasets = len(datasets)
|
|
1099
1344
|
|
|
1345
|
+
if num_datasets == 0:
|
|
1346
|
+
return pd.DataFrame(index=thresholds)
|
|
1347
|
+
|
|
1100
1348
|
df = pd.DataFrame(index=thresholds)
|
|
1101
1349
|
for key, complex_data in pra_percomplex.items():
|
|
1102
1350
|
if "corrected_auc_score" in complex_data.columns:
|
|
1103
1351
|
score_col = "corrected_auc_score"
|
|
1104
1352
|
else:
|
|
1105
1353
|
score_col = "auc_score"
|
|
1354
|
+
|
|
1106
1355
|
df[key] = [complex_data.query(f'{score_col} >= {t}').shape[0] for t in thresholds]
|
|
1107
1356
|
|
|
1108
1357
|
fig, ax = plt.subplots()
|
|
@@ -1221,15 +1470,20 @@ def plot_auc_scores():
|
|
|
1221
1470
|
return pra_dict
|
|
1222
1471
|
|
|
1223
1472
|
|
|
1224
|
-
def
|
|
1473
|
+
def plot_mpr_complex_auc_scores(variant: str = "unfiltered", save=None, outname=None):
|
|
1225
1474
|
"""Plot AUC scores for the mPR complexes curve (Fig 1F-style).
|
|
1226
1475
|
|
|
1227
1476
|
Requires `mpr_prepare()` to have been run for each dataset.
|
|
1228
1477
|
|
|
1229
1478
|
Parameters
|
|
1230
1479
|
----------
|
|
1231
|
-
|
|
1232
|
-
One of: "
|
|
1480
|
+
variant : str
|
|
1481
|
+
One of: "unfiltered", "without_mt_ribo_etci",
|
|
1482
|
+
"without_small_high_auprc".
|
|
1483
|
+
save : bool, optional
|
|
1484
|
+
Whether to save the figure. If None, uses config["plotting"]["save_plot"].
|
|
1485
|
+
outname : str, optional
|
|
1486
|
+
Output filename. If None, auto-generated.
|
|
1233
1487
|
|
|
1234
1488
|
Returns
|
|
1235
1489
|
-------
|
|
@@ -1250,12 +1504,14 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
|
|
|
1250
1504
|
)
|
|
1251
1505
|
return pd.Series(dtype=float)
|
|
1252
1506
|
|
|
1507
|
+
variant_key = _normalize_mpr_variant(variant)
|
|
1508
|
+
|
|
1253
1509
|
# Build Series: dataset -> auc
|
|
1254
1510
|
auc_by_dataset = {}
|
|
1255
1511
|
for dataset, per_filter in mpr_auc_dict.items():
|
|
1256
1512
|
if not isinstance(per_filter, dict):
|
|
1257
1513
|
continue
|
|
1258
|
-
val = per_filter.get(
|
|
1514
|
+
val = per_filter.get(variant_key)
|
|
1259
1515
|
if val is None:
|
|
1260
1516
|
continue
|
|
1261
1517
|
try:
|
|
@@ -1265,7 +1521,8 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
|
|
|
1265
1521
|
|
|
1266
1522
|
if not auc_by_dataset:
|
|
1267
1523
|
log.warning(
|
|
1268
|
-
f"No mPR
|
|
1524
|
+
f"No mPR complex AUC scores found for variant '{variant}'. "
|
|
1525
|
+
f"Available variants: {list(PUBLIC_MPR_VARIANTS.keys())}"
|
|
1269
1526
|
)
|
|
1270
1527
|
return pd.Series(dtype=float)
|
|
1271
1528
|
|
|
@@ -1308,11 +1565,16 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
|
|
|
1308
1565
|
ax.spines["top"].set_visible(False)
|
|
1309
1566
|
ax.spines["right"].set_visible(False)
|
|
1310
1567
|
|
|
1311
|
-
|
|
1568
|
+
should_save = plot_config.get("save_plot", False) if save is None else bool(save)
|
|
1569
|
+
if should_save:
|
|
1312
1570
|
output_type = plot_config.get("output_type", "pdf")
|
|
1313
1571
|
output_folder = Path(config["output_folder"])
|
|
1314
1572
|
output_folder.mkdir(parents=True, exist_ok=True)
|
|
1315
|
-
|
|
1573
|
+
if outname is None:
|
|
1574
|
+
outname = f"mpr_complexes_auc_{variant_key}.{output_type}"
|
|
1575
|
+
output_path = Path(outname)
|
|
1576
|
+
if len(output_path.parts) == 1:
|
|
1577
|
+
output_path = output_folder / outname
|
|
1316
1578
|
plt.savefig(output_path, bbox_inches="tight", format=output_type)
|
|
1317
1579
|
|
|
1318
1580
|
if plot_config.get("show_plot", True):
|
|
@@ -1321,6 +1583,13 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
|
|
|
1321
1583
|
plt.close(fig)
|
|
1322
1584
|
return s
|
|
1323
1585
|
|
|
1586
|
+
|
|
1587
|
+
def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
    """Legacy entry point that forwards to plot_mpr_complex_auc_scores().

    ``filter_key`` is translated with ``_legacy_filter_to_variant`` (using
    "unfiltered" as the default) and passed on as ``variant``; the result of
    plot_mpr_complex_auc_scores() is returned unchanged.
    """
    public_variant = _legacy_filter_to_variant(filter_key, default="unfiltered")
    return plot_mpr_complex_auc_scores(variant=public_variant)
|
|
1592
|
+
|
|
1324
1593
|
# -----------------------------------------------------------------------------
|
|
1325
1594
|
# mPR plots (Fig. 1E and Fig. 1F)
|
|
1326
1595
|
# -----------------------------------------------------------------------------
|
|
@@ -1475,26 +1744,6 @@ def plot_mpr_tp(name, ax=None, save=True, outname=None):
|
|
|
1475
1744
|
|
|
1476
1745
|
return ax
|
|
1477
1746
|
|
|
1478
|
-
"""
|
|
1479
|
-
Multi-dataset mPR plotting functions.
|
|
1480
|
-
|
|
1481
|
-
Usage:
|
|
1482
|
-
from pythonflex.plotting import plot_mpr_tp_multi, plot_mpr_complexes_multi
|
|
1483
|
-
|
|
1484
|
-
# Plot multiple datasets
|
|
1485
|
-
plot_mpr_tp_multi(["19Q2", "19Q4", "20Q1"])
|
|
1486
|
-
plot_mpr_complexes_multi(["19Q2", "19Q4", "20Q1"])
|
|
1487
|
-
"""
|
|
1488
|
-
|
|
1489
|
-
import numpy as np
|
|
1490
|
-
import pandas as pd
|
|
1491
|
-
import matplotlib.pyplot as plt
|
|
1492
|
-
from matplotlib.lines import Line2D
|
|
1493
|
-
from pathlib import Path
|
|
1494
|
-
|
|
1495
|
-
from .utils import dload
|
|
1496
|
-
from .logging_config import log
|
|
1497
|
-
|
|
1498
1747
|
# Default color palette (colorblind-friendly)
|
|
1499
1748
|
DEFAULT_COLORS = [
|
|
1500
1749
|
"#4E79A7", # blue
|
|
@@ -1509,40 +1758,101 @@ DEFAULT_COLORS = [
|
|
|
1509
1758
|
"#BAB0AC", # gray
|
|
1510
1759
|
]
|
|
1511
1760
|
|
|
1512
|
-
#
|
|
1513
|
-
|
|
1761
|
+
# Public mPR variant names map to the internal keys stored by mpr_prepare().
PUBLIC_MPR_VARIANTS = {
    "unfiltered": "all",
    "without_mt_ribo_etci": "no_mtRibo_ETCI",
    "without_small_high_auprc": "no_small_highAUPRC",
}
# Reverse lookup: internal storage key -> public variant name.
INTERNAL_MPR_VARIANTS = {v: k for k, v in PUBLIC_MPR_VARIANTS.items()}

# mPR variant line styles keyed by internal storage names.
# Each entry holds a "linestyle" and a legend "label".
MPR_VARIANT_STYLES = {
    "all": {"linestyle": "-", "label": "all data"},
    "no_mtRibo_ETCI": {"linestyle": "--", "label": "no mtRibo, ETC I"},
    "no_small_highAUPRC": {"linestyle": "dotted", "label": "no small, high AUPRC"},
}

# Compatibility alias for users who imported this internal constant.
# This is the same dict object as MPR_VARIANT_STYLES, not a copy.
FILTER_STYLES = MPR_VARIANT_STYLES
|
|
1778
|
+
|
|
1779
|
+
|
|
1780
|
+
def _normalize_mpr_variant(variant):
    """Translate a single mPR variant name into its internal storage key.

    A key of ``PUBLIC_MPR_VARIANTS`` is replaced by the value stored for it.
    A key of ``MPR_VARIANT_STYLES`` is accepted as well: "all" resolves
    through the "unfiltered" entry of ``PUBLIC_MPR_VARIANTS``, any other such
    key is returned unchanged.

    Raises
    ------
    ValueError
        If ``variant`` is in neither mapping.
    """
    if variant in PUBLIC_MPR_VARIANTS:
        return PUBLIC_MPR_VARIANTS[variant]

    if variant not in MPR_VARIANT_STYLES:
        raise ValueError(
            "Unknown mPR variant "
            f"{variant!r}. Use one of {list(PUBLIC_MPR_VARIANTS.keys())}."
        )

    # Already an internal key.
    return PUBLIC_MPR_VARIANTS["unfiltered"] if variant == "all" else variant
|
|
1519
1792
|
|
|
1520
|
-
def _normalize_show_filters(show_filters):
|
|
1521
|
-
"""Normalize show_filters to an ordered tuple of filter keys.
|
|
1522
1793
|
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1794
|
+
def _normalize_mpr_variants(variants):
    """Turn one or many mPR variant names into a tuple of internal keys.

    ``None`` is treated like "all". A string is treated as a single name, any
    other iterable as a sequence of names, and a non-iterable value as a
    single name. The name "all" expands to every value of
    ``PUBLIC_MPR_VARIANTS``; every other name goes through
    ``_normalize_mpr_variant``.

    Returns
    -------
    tuple
        Internal keys in first-seen order, without duplicates.
    """
    if variants is None or isinstance(variants, str):
        requested = ("all",) if variants is None else (variants,)
    else:
        try:
            requested = tuple(variants)
        except TypeError:
            # Not iterable: treat the value itself as one variant name.
            requested = (variants,)

    # A dict keeps insertion order and drops repeats.
    ordered = {}
    for name in requested:
        if name == "all":
            resolved = PUBLIC_MPR_VARIANTS.values()
        else:
            resolved = (_normalize_mpr_variant(name),)
        for key in resolved:
            ordered.setdefault(key, None)

    return tuple(ordered)
|
|
1815
|
+
|
|
1816
|
+
|
|
1817
|
+
def _legacy_filter_to_variant(filter_key, default=None):
|
|
1818
|
+
"""Map old filter-key names to public variant names."""
|
|
1819
|
+
if filter_key is None:
|
|
1820
|
+
return default if default is not None else "all"
|
|
1821
|
+
mapping = {
|
|
1822
|
+
"all": "unfiltered",
|
|
1823
|
+
"no_mtRibo_ETCI": "without_mt_ribo_etci",
|
|
1824
|
+
"no_small_highAUPRC": "without_small_high_auprc",
|
|
1825
|
+
}
|
|
1826
|
+
return mapping.get(filter_key, filter_key)
|
|
1827
|
+
|
|
1828
|
+
|
|
1829
|
+
def _legacy_filters_to_variants(show_filters):
|
|
1830
|
+
"""Map old show_filters values to public variant names."""
|
|
1526
1831
|
if show_filters is None:
|
|
1527
|
-
return
|
|
1832
|
+
return "all"
|
|
1528
1833
|
if isinstance(show_filters, str):
|
|
1529
|
-
return (show_filters
|
|
1834
|
+
return _legacy_filter_to_variant(show_filters)
|
|
1530
1835
|
try:
|
|
1531
|
-
return tuple(show_filters)
|
|
1836
|
+
return tuple(_legacy_filter_to_variant(item) for item in show_filters)
|
|
1532
1837
|
except TypeError:
|
|
1533
|
-
return (show_filters,)
|
|
1838
|
+
return (_legacy_filter_to_variant(show_filters),)
|
|
1534
1839
|
|
|
1535
|
-
|
|
1840
|
+
|
|
1841
|
+
def _normalize_show_filters(show_filters):
|
|
1842
|
+
"""Backward-compatible normalizer for old internal filter keys."""
|
|
1843
|
+
return _normalize_mpr_variants(_legacy_filters_to_variants(show_filters))
|
|
1844
|
+
|
|
1845
|
+
def plot_mpr_true_positive_curve(
|
|
1536
1846
|
dataset_names=None,
|
|
1537
1847
|
colors=None,
|
|
1538
1848
|
ax=None,
|
|
1539
1849
|
save=True,
|
|
1540
1850
|
outname=None,
|
|
1541
1851
|
linewidth=1.8,
|
|
1542
|
-
|
|
1852
|
+
variants="unfiltered",
|
|
1543
1853
|
):
|
|
1544
1854
|
"""
|
|
1545
|
-
Plot
|
|
1855
|
+
Plot mPR true-positive vs precision curves for multiple datasets.
|
|
1546
1856
|
|
|
1547
1857
|
Can auto-detect datasets or use provided dataset names.
|
|
1548
1858
|
Each dataset gets one color, each filter type gets one line style.
|
|
@@ -1562,8 +1872,9 @@ def plot_mpr_tp_multi(
|
|
|
1562
1872
|
Output filename. If None, auto-generated.
|
|
1563
1873
|
linewidth : float
|
|
1564
1874
|
Line width for all curves
|
|
1565
|
-
|
|
1566
|
-
Which
|
|
1875
|
+
variants : str or iterable of str
|
|
1876
|
+
Which mPR variants to show. Use "unfiltered",
|
|
1877
|
+
"without_mt_ribo_etci", "without_small_high_auprc", or "all".
|
|
1567
1878
|
|
|
1568
1879
|
Returns
|
|
1569
1880
|
-------
|
|
@@ -1573,7 +1884,7 @@ def plot_mpr_tp_multi(
|
|
|
1573
1884
|
plot_config = config["plotting"]
|
|
1574
1885
|
input_colors = dload("input", "colors")
|
|
1575
1886
|
|
|
1576
|
-
|
|
1887
|
+
variant_keys = _normalize_mpr_variants(variants)
|
|
1577
1888
|
|
|
1578
1889
|
# Sanitize color keys
|
|
1579
1890
|
if input_colors:
|
|
@@ -1641,13 +1952,13 @@ def plot_mpr_tp_multi(
|
|
|
1641
1952
|
tp_curves = mpr["tp_curves"]
|
|
1642
1953
|
color = colors[i % len(colors)]
|
|
1643
1954
|
|
|
1644
|
-
for
|
|
1645
|
-
if
|
|
1955
|
+
for variant_key in variant_keys:
|
|
1956
|
+
if variant_key not in tp_curves:
|
|
1646
1957
|
continue
|
|
1647
1958
|
|
|
1648
|
-
data = tp_curves[
|
|
1959
|
+
data = tp_curves[variant_key]
|
|
1649
1960
|
if not isinstance(data, dict) or "tp" not in data or "precision" not in data:
|
|
1650
|
-
log.warning(f"Invalid tp_curves data structure for '{name}'
|
|
1961
|
+
log.warning(f"Invalid tp_curves data structure for '{name}' variant '{variant_key}', skipping.")
|
|
1651
1962
|
continue
|
|
1652
1963
|
|
|
1653
1964
|
tp = np.asarray(data["tp"], dtype=float)
|
|
@@ -1661,7 +1972,7 @@ def plot_mpr_tp_multi(
|
|
|
1661
1972
|
prec_plot = prec[mask]
|
|
1662
1973
|
xmax = max(xmax, float(tp_plot.max()))
|
|
1663
1974
|
|
|
1664
|
-
style =
|
|
1975
|
+
style = MPR_VARIANT_STYLES.get(variant_key, {})
|
|
1665
1976
|
ax.plot(
|
|
1666
1977
|
tp_plot,
|
|
1667
1978
|
prec_plot,
|
|
@@ -1694,7 +2005,7 @@ def plot_mpr_tp_multi(
|
|
|
1694
2005
|
ax.spines['right'].set_visible(False)
|
|
1695
2006
|
|
|
1696
2007
|
# Create vertically stacked legends
|
|
1697
|
-
_add_vertical_legend(ax, dataset_names, colors,
|
|
2008
|
+
_add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth)
|
|
1698
2009
|
|
|
1699
2010
|
# Save
|
|
1700
2011
|
if save:
|
|
@@ -1713,7 +2024,8 @@ def plot_mpr_tp_multi(
|
|
|
1713
2024
|
|
|
1714
2025
|
return ax
|
|
1715
2026
|
|
|
1716
|
-
|
|
2027
|
+
|
|
2028
|
+
def plot_mpr_tp_multi(
|
|
1717
2029
|
dataset_names=None,
|
|
1718
2030
|
colors=None,
|
|
1719
2031
|
ax=None,
|
|
@@ -1721,11 +2033,31 @@ def plot_mpr_complexes_multi(
|
|
|
1721
2033
|
outname=None,
|
|
1722
2034
|
linewidth=1.8,
|
|
1723
2035
|
show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
|
|
2036
|
+
):
|
|
2037
|
+
"""Backward-compatible wrapper for plot_mpr_true_positive_curve()."""
|
|
2038
|
+
return plot_mpr_true_positive_curve(
|
|
2039
|
+
dataset_names=dataset_names,
|
|
2040
|
+
colors=colors,
|
|
2041
|
+
ax=ax,
|
|
2042
|
+
save=save,
|
|
2043
|
+
outname=outname,
|
|
2044
|
+
linewidth=linewidth,
|
|
2045
|
+
variants=_legacy_filters_to_variants(show_filters),
|
|
2046
|
+
)
|
|
2047
|
+
|
|
2048
|
+
def plot_mpr_complex_coverage_curve(
|
|
2049
|
+
dataset_names=None,
|
|
2050
|
+
colors=None,
|
|
2051
|
+
ax=None,
|
|
2052
|
+
save=True,
|
|
2053
|
+
outname=None,
|
|
2054
|
+
linewidth=1.8,
|
|
2055
|
+
variants="unfiltered",
|
|
1724
2056
|
show_markers="auto",
|
|
1725
2057
|
marker_size=20,
|
|
1726
2058
|
):
|
|
1727
2059
|
"""
|
|
1728
|
-
Plot
|
|
2060
|
+
Plot mPR complex-coverage vs precision curves for multiple datasets.
|
|
1729
2061
|
|
|
1730
2062
|
Can auto-detect datasets or use provided dataset names.
|
|
1731
2063
|
Each dataset gets one color, each filter type gets one line style.
|
|
@@ -1745,8 +2077,9 @@ def plot_mpr_complexes_multi(
|
|
|
1745
2077
|
Output filename. If None, auto-generated.
|
|
1746
2078
|
linewidth : float
|
|
1747
2079
|
Line width for all curves
|
|
1748
|
-
|
|
1749
|
-
Which
|
|
2080
|
+
variants : str or iterable of str
|
|
2081
|
+
Which mPR variants to show. Use "unfiltered",
|
|
2082
|
+
"without_mt_ribo_etci", "without_small_high_auprc", or "all".
|
|
1750
2083
|
show_markers : bool or "auto"
|
|
1751
2084
|
If True, draw markers on curves to make short curves visible.
|
|
1752
2085
|
If "auto" (default), markers are drawn only for curves with <= 10 points.
|
|
@@ -1761,7 +2094,7 @@ def plot_mpr_complexes_multi(
|
|
|
1761
2094
|
plot_config = config["plotting"]
|
|
1762
2095
|
input_colors = dload("input", "colors")
|
|
1763
2096
|
|
|
1764
|
-
|
|
2097
|
+
variant_keys = _normalize_mpr_variants(variants)
|
|
1765
2098
|
|
|
1766
2099
|
# Sanitize color keys
|
|
1767
2100
|
if input_colors:
|
|
@@ -1812,32 +2145,61 @@ def plot_mpr_complexes_multi(
|
|
|
1812
2145
|
else:
|
|
1813
2146
|
fig = ax.figure
|
|
1814
2147
|
|
|
1815
|
-
#
|
|
2148
|
+
# First pass: determine max coverage across all datasets/filters for adaptive x-axis
|
|
2149
|
+
max_cov_global = 0
|
|
2150
|
+
_mpr_cache = {}
|
|
1816
2151
|
for i, name in enumerate(dataset_names):
|
|
1817
2152
|
mpr = dload("mpr", name)
|
|
2153
|
+
_mpr_cache[name] = mpr
|
|
2154
|
+
if mpr is not None:
|
|
2155
|
+
for variant_key in variant_keys:
|
|
2156
|
+
arr = mpr["coverage_curves"].get(variant_key)
|
|
2157
|
+
if arr is not None:
|
|
2158
|
+
max_cov_global = max(max_cov_global, float(np.asarray(arr).max()))
|
|
2159
|
+
|
|
2160
|
+
# Build adaptive x-axis limits and ticks
|
|
2161
|
+
import math
|
|
2162
|
+
if max_cov_global <= 200:
|
|
2163
|
+
# Original fixed range — keeps CORUM plots identical to before
|
|
2164
|
+
x_max_plot = 200
|
|
2165
|
+
tick_positions = [1, 2, 20, 200]
|
|
2166
|
+
tick_labels = ["0", "2", "20", "200"]
|
|
2167
|
+
else:
|
|
2168
|
+
# Round up to the next power of 10 so the max bar has breathing room
|
|
2169
|
+
x_max_plot = 10 ** math.ceil(math.log10(max_cov_global + 1))
|
|
2170
|
+
tick_positions = [1, 2]
|
|
2171
|
+
v = 10
|
|
2172
|
+
while v <= x_max_plot:
|
|
2173
|
+
tick_positions.append(v)
|
|
2174
|
+
v *= 10
|
|
2175
|
+
tick_labels = ["0"] + [str(t) for t in tick_positions[1:]]
|
|
2176
|
+
|
|
2177
|
+
# Plot each dataset
|
|
2178
|
+
for i, name in enumerate(dataset_names):
|
|
2179
|
+
mpr = _mpr_cache[name]
|
|
1818
2180
|
if mpr is None:
|
|
1819
2181
|
log.warning(f"mPR data for '{name}' not found, skipping.")
|
|
1820
2182
|
continue
|
|
1821
|
-
|
|
2183
|
+
|
|
1822
2184
|
precision_cutoffs = np.asarray(mpr["precision_cutoffs"], dtype=float)
|
|
1823
2185
|
coverage = mpr["coverage_curves"]
|
|
1824
2186
|
color = colors[i % len(colors)]
|
|
1825
|
-
|
|
1826
|
-
for
|
|
1827
|
-
if
|
|
2187
|
+
|
|
2188
|
+
for variant_key in variant_keys:
|
|
2189
|
+
if variant_key not in coverage:
|
|
1828
2190
|
continue
|
|
1829
|
-
|
|
1830
|
-
cov = np.asarray(coverage[
|
|
1831
|
-
|
|
1832
|
-
# Keep only positive coverage
|
|
1833
|
-
mask = (cov > 0) & (cov <=
|
|
2191
|
+
|
|
2192
|
+
cov = np.asarray(coverage[variant_key], dtype=float)
|
|
2193
|
+
|
|
2194
|
+
# Keep only positive coverage within the visible x range
|
|
2195
|
+
mask = (cov > 0) & (cov <= x_max_plot)
|
|
1834
2196
|
if not mask.any():
|
|
1835
2197
|
continue
|
|
1836
|
-
|
|
2198
|
+
|
|
1837
2199
|
cov_plot = cov[mask]
|
|
1838
2200
|
prec_plot = precision_cutoffs[mask]
|
|
1839
|
-
|
|
1840
|
-
style =
|
|
2201
|
+
|
|
2202
|
+
style = MPR_VARIANT_STYLES.get(variant_key, {})
|
|
1841
2203
|
|
|
1842
2204
|
# Decide marker visibility
|
|
1843
2205
|
if show_markers == "auto":
|
|
@@ -1858,17 +2220,15 @@ def plot_mpr_complexes_multi(
|
|
|
1858
2220
|
marker=("o" if use_markers else None),
|
|
1859
2221
|
markersize=(3 if use_markers else None),
|
|
1860
2222
|
)
|
|
1861
|
-
|
|
2223
|
+
|
|
1862
2224
|
# Configure axes
|
|
1863
2225
|
ax.set_xscale("log")
|
|
1864
|
-
ax.set_xlim(1,
|
|
2226
|
+
ax.set_xlim(1, x_max_plot)
|
|
1865
2227
|
ax.set_xlabel("# complexes")
|
|
1866
2228
|
ax.set_ylabel("Precision")
|
|
1867
2229
|
ax.set_ylim(0.0, 1.05)
|
|
1868
|
-
|
|
1869
|
-
#
|
|
1870
|
-
tick_positions = [1, 2, 20, 200]
|
|
1871
|
-
tick_labels = ["0", "2", "20", "200"]
|
|
2230
|
+
|
|
2231
|
+
# Adaptive x-ticks
|
|
1872
2232
|
ax.set_xticks(tick_positions)
|
|
1873
2233
|
ax.set_xticklabels(tick_labels)
|
|
1874
2234
|
|
|
@@ -1877,7 +2237,7 @@ def plot_mpr_complexes_multi(
|
|
|
1877
2237
|
ax.spines['right'].set_visible(False)
|
|
1878
2238
|
|
|
1879
2239
|
# Create vertically stacked legends
|
|
1880
|
-
_add_vertical_legend(ax, dataset_names, colors,
|
|
2240
|
+
_add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth)
|
|
1881
2241
|
|
|
1882
2242
|
# Save
|
|
1883
2243
|
if save:
|
|
@@ -1896,11 +2256,71 @@ def plot_mpr_complexes_multi(
|
|
|
1896
2256
|
|
|
1897
2257
|
return ax
|
|
1898
2258
|
|
|
1899
|
-
|
|
2259
|
+
|
|
2260
|
+
def plot_mpr_complexes_multi(
|
|
2261
|
+
dataset_names=None,
|
|
2262
|
+
colors=None,
|
|
2263
|
+
ax=None,
|
|
2264
|
+
save=True,
|
|
2265
|
+
outname=None,
|
|
2266
|
+
linewidth=1.8,
|
|
2267
|
+
show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
|
|
2268
|
+
show_markers="auto",
|
|
2269
|
+
marker_size=20,
|
|
2270
|
+
):
|
|
2271
|
+
"""Backward-compatible wrapper for plot_mpr_complex_coverage_curve()."""
|
|
2272
|
+
return plot_mpr_complex_coverage_curve(
|
|
2273
|
+
dataset_names=dataset_names,
|
|
2274
|
+
colors=colors,
|
|
2275
|
+
ax=ax,
|
|
2276
|
+
save=save,
|
|
2277
|
+
outname=outname,
|
|
2278
|
+
linewidth=linewidth,
|
|
2279
|
+
variants=_legacy_filters_to_variants(show_filters),
|
|
2280
|
+
show_markers=show_markers,
|
|
2281
|
+
marker_size=marker_size,
|
|
2282
|
+
)
|
|
2283
|
+
|
|
2284
|
+
|
|
2285
|
+
def plot_mpr_summary(
|
|
2286
|
+
dataset_names=None,
|
|
2287
|
+
colors=None,
|
|
2288
|
+
variants="unfiltered",
|
|
2289
|
+
save=True,
|
|
2290
|
+
linewidth=1.8,
|
|
2291
|
+
show_markers="auto",
|
|
2292
|
+
marker_size=20,
|
|
2293
|
+
auc_variant=None,
|
|
2294
|
+
):
|
|
2295
|
+
"""Generate the standard mPR summary plots and return complex AUC scores."""
|
|
2296
|
+
plot_mpr_true_positive_curve(
|
|
2297
|
+
dataset_names=dataset_names,
|
|
2298
|
+
colors=colors,
|
|
2299
|
+
save=save,
|
|
2300
|
+
linewidth=linewidth,
|
|
2301
|
+
variants=variants,
|
|
2302
|
+
)
|
|
2303
|
+
plot_mpr_complex_coverage_curve(
|
|
2304
|
+
dataset_names=dataset_names,
|
|
2305
|
+
colors=colors,
|
|
2306
|
+
save=save,
|
|
2307
|
+
linewidth=linewidth,
|
|
2308
|
+
variants=variants,
|
|
2309
|
+
show_markers=show_markers,
|
|
2310
|
+
marker_size=marker_size,
|
|
2311
|
+
)
|
|
2312
|
+
|
|
2313
|
+
if auc_variant is None:
|
|
2314
|
+
variant_keys = _normalize_mpr_variants(variants)
|
|
2315
|
+
auc_variant = INTERNAL_MPR_VARIANTS.get(variant_keys[0], "unfiltered")
|
|
2316
|
+
|
|
2317
|
+
return plot_mpr_complex_auc_scores(variant=auc_variant, save=save)
|
|
2318
|
+
|
|
2319
|
+
def _add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth):
|
|
1900
2320
|
"""
|
|
1901
|
-
Add vertically stacked legends: Dataset on top,
|
|
2321
|
+
Add vertically stacked legends: Dataset on top, mPR variant below.
|
|
1902
2322
|
"""
|
|
1903
|
-
|
|
2323
|
+
variant_keys = _normalize_show_filters(variant_keys)
|
|
1904
2324
|
# Legend 1: Datasets (colors) - solid lines
|
|
1905
2325
|
dataset_handles = []
|
|
1906
2326
|
for i, name in enumerate(dataset_names):
|
|
@@ -1908,19 +2328,19 @@ def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
|
|
|
1908
2328
|
handle = Line2D([0], [0], color=color, linewidth=linewidth, linestyle="-")
|
|
1909
2329
|
dataset_handles.append(handle)
|
|
1910
2330
|
|
|
1911
|
-
# Legend 2:
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
for
|
|
1915
|
-
style =
|
|
2331
|
+
# Legend 2: mPR variants (line styles) - black lines
|
|
2332
|
+
variant_handles = []
|
|
2333
|
+
variant_labels = []
|
|
2334
|
+
for variant_key in variant_keys:
|
|
2335
|
+
style = MPR_VARIANT_STYLES.get(variant_key, {})
|
|
1916
2336
|
handle = Line2D(
|
|
1917
2337
|
[0], [0],
|
|
1918
2338
|
color="black",
|
|
1919
2339
|
linewidth=linewidth,
|
|
1920
2340
|
linestyle=style.get("linestyle", "-")
|
|
1921
2341
|
)
|
|
1922
|
-
|
|
1923
|
-
|
|
2342
|
+
variant_handles.append(handle)
|
|
2343
|
+
variant_labels.append(style.get("label", variant_key))
|
|
1924
2344
|
|
|
1925
2345
|
# Position legends vertically with proper alignment
|
|
1926
2346
|
# Dataset legend on upper right
|
|
@@ -1938,19 +2358,19 @@ def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
|
|
|
1938
2358
|
|
|
1939
2359
|
# Filter legend below the dataset legend, aligned properly without title
|
|
1940
2360
|
legend2 = ax.legend(
|
|
1941
|
-
|
|
1942
|
-
|
|
2361
|
+
variant_handles,
|
|
2362
|
+
variant_labels,
|
|
1943
2363
|
loc="upper left",
|
|
1944
2364
|
frameon=False,
|
|
1945
2365
|
fontsize=7,
|
|
1946
2366
|
bbox_to_anchor=(1.05, 1.0 - len(dataset_names) * 0.06 - 0.1)
|
|
1947
2367
|
)
|
|
1948
2368
|
|
|
1949
|
-
def _add_dual_legend(ax, dataset_names, colors,
|
|
2369
|
+
def _add_dual_legend(ax, dataset_names, colors, variant_keys, linewidth):
|
|
1950
2370
|
"""
|
|
1951
|
-
Add two legends: one for datasets (colors), one for
|
|
2371
|
+
Add two legends: one for datasets (colors), one for mPR variants (line styles).
|
|
1952
2372
|
"""
|
|
1953
|
-
|
|
2373
|
+
variant_keys = _normalize_show_filters(variant_keys)
|
|
1954
2374
|
# Legend 1: Datasets (colors) - solid lines
|
|
1955
2375
|
dataset_handles = []
|
|
1956
2376
|
for i, name in enumerate(dataset_names):
|
|
@@ -1958,19 +2378,19 @@ def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
|
|
|
1958
2378
|
handle = Line2D([0], [0], color=color, linewidth=linewidth, linestyle="-")
|
|
1959
2379
|
dataset_handles.append(handle)
|
|
1960
2380
|
|
|
1961
|
-
# Legend 2:
|
|
1962
|
-
|
|
1963
|
-
|
|
1964
|
-
for
|
|
1965
|
-
style =
|
|
2381
|
+
# Legend 2: mPR variants (line styles) - black lines
|
|
2382
|
+
variant_handles = []
|
|
2383
|
+
variant_labels = []
|
|
2384
|
+
for variant_key in variant_keys:
|
|
2385
|
+
style = MPR_VARIANT_STYLES.get(variant_key, {})
|
|
1966
2386
|
handle = Line2D(
|
|
1967
2387
|
[0], [0],
|
|
1968
2388
|
color="black",
|
|
1969
2389
|
linewidth=linewidth,
|
|
1970
2390
|
linestyle=style.get("linestyle", "-")
|
|
1971
2391
|
)
|
|
1972
|
-
|
|
1973
|
-
|
|
2392
|
+
variant_handles.append(handle)
|
|
2393
|
+
variant_labels.append(style.get("label", variant_key))
|
|
1974
2394
|
|
|
1975
2395
|
# Position legends
|
|
1976
2396
|
# Dataset legend on upper right
|
|
@@ -1987,19 +2407,12 @@ def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
|
|
|
1987
2407
|
|
|
1988
2408
|
# Filter legend on lower left or right depending on plot type
|
|
1989
2409
|
legend2 = ax.legend(
|
|
1990
|
-
|
|
1991
|
-
|
|
2410
|
+
variant_handles,
|
|
2411
|
+
variant_labels,
|
|
1992
2412
|
loc="lower left",
|
|
1993
2413
|
frameon=False,
|
|
1994
|
-
title="
|
|
2414
|
+
title="Variant",
|
|
1995
2415
|
fontsize=7,
|
|
1996
2416
|
title_fontsize=8,
|
|
1997
2417
|
)
|
|
1998
2418
|
|
|
1999
|
-
# ============================================================================
|
|
2000
|
-
# Single dataset functions are now obsolete
|
|
2001
|
-
# ============================================================================
|
|
2002
|
-
|
|
2003
|
-
# Note: The original single dataset functions plot_mpr_tp() and plot_mpr_complexes()
|
|
2004
|
-
# have been replaced by the multi functions that now auto-detect available datasets.
|
|
2005
|
-
# Use plot_mpr_tp_multi() and plot_mpr_complexes_multi() instead.
|