pythonflex 0.3.3__tar.gz → 0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pythonflex-0.3.3/src/pythonflex/plotting.py → pythonflex-0.4/.codex_backups/plotting.py.20260601-121332.bak +394 -165
- {pythonflex-0.3.3 → pythonflex-0.4}/.gitignore +3 -1
- pythonflex-0.4/LICENSE +7 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/PKG-INFO +9 -4
- {pythonflex-0.3.3 → pythonflex-0.4}/README.md +7 -3
- {pythonflex-0.3.3 → pythonflex-0.4}/pyproject.toml +69 -77
- pythonflex-0.4/src/pythonflex/__init__.py +44 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/analysis.py +287 -578
- pythonflex-0.4/src/pythonflex/examples/basic_usage.py +95 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/examples/manuscript.py +37 -42
- pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark.py +218 -0
- pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_10_runs_memmap.py +534 -0
- pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_corum_njobs.py +245 -0
- pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_gobp_njobs_chunks.py +319 -0
- pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_gobp_optimization.py +417 -0
- pythonflex-0.4/src/pythonflex/examples/runtime/runtime_benchmark_repeated.py +347 -0
- pythonflex-0.4/src/pythonflex/old_functions.py +422 -0
- pythonflex-0.4/src/pythonflex/plotting.py +2418 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/preprocessing.py +62 -60
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/utils.py +36 -9
- pythonflex-0.4/todo.txt +2 -0
- pythonflex-0.3.3/src/pythonflex/__init__.py +0 -20
- pythonflex-0.3.3/src/pythonflex/examples/basic_usage.py +0 -87
- pythonflex-0.3.3/todo.txt +0 -1
- {pythonflex-0.3.3 → pythonflex-0.4}/.python-version +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/__init__.py +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/dataset/__init__.py +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/dataset/liver_cell_lines_500_genes.csv +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/dataset/melanoma_cell_lines_500_genes.csv +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/dataset/neuroblastoma_cell_lines_500_genes.csv +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/CORUM.parquet +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/GOBP.parquet +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/PATHWAY.parquet +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/__init__.py +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/corum.csv +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/gobp.csv +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/data/gold_standard/pathway.csv +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/src/pythonflex/logging_config.py +0 -0
- {pythonflex-0.3.3 → pythonflex-0.4}/uv.lock +0 -0
|
@@ -9,6 +9,7 @@ import pandas as pd
|
|
|
9
9
|
import matplotlib.pyplot as plt
|
|
10
10
|
from matplotlib import patches
|
|
11
11
|
from matplotlib.cm import get_cmap
|
|
12
|
+
from matplotlib.lines import Line2D
|
|
12
13
|
from matplotlib.ticker import NullFormatter, NullLocator
|
|
13
14
|
|
|
14
15
|
# Completely disable LaTeX and clear all font cache/references
|
|
@@ -33,7 +34,7 @@ mpl.rcParams['mathtext.default'] = 'regular'
|
|
|
33
34
|
# Force font manager to rebuild with system fonts only
|
|
34
35
|
try:
|
|
35
36
|
fm.fontManager.__init__()
|
|
36
|
-
except:
|
|
37
|
+
except Exception:
|
|
37
38
|
pass
|
|
38
39
|
|
|
39
40
|
# Local modules
|
|
@@ -279,13 +280,27 @@ def plot_all_runs_pra(pra_list, mean_df=None, line_width=2.0, hide_minor_ticks=T
|
|
|
279
280
|
plt.show()
|
|
280
281
|
plt.close(fig)
|
|
281
282
|
|
|
282
|
-
def plot_percomplex_scatter(
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
283
|
+
def plot_percomplex_scatter(
|
|
284
|
+
n_top=10,
|
|
285
|
+
sig_color='black',
|
|
286
|
+
nonsig_color='none',
|
|
287
|
+
label_color='black',
|
|
288
|
+
border_color='black',
|
|
289
|
+
border_width=1.0,
|
|
290
|
+
nonsig_border_color="#7F7F7F",
|
|
291
|
+
nonsig_border_width=0.5,
|
|
292
|
+
show_text_background=True,
|
|
293
|
+
):
|
|
294
|
+
config = dload("config")
|
|
295
|
+
plot_config = config["plotting"]
|
|
296
|
+
rdict = dload("pra_percomplex")
|
|
297
|
+
input_colors = dload("input", "colors")
|
|
298
|
+
input_colors = {_sanitize(k): v for k, v in input_colors.items()} if input_colors else {}
|
|
286
299
|
|
|
287
300
|
if len(rdict) < 2:
|
|
288
|
-
|
|
301
|
+
log.warning(
|
|
302
|
+
"Skipping plot: at least two datasets are required for per-complex scatter plot."
|
|
303
|
+
)
|
|
289
304
|
return
|
|
290
305
|
|
|
291
306
|
column_pairs = list(combinations(rdict.keys(), 2))
|
|
@@ -299,32 +314,61 @@ def plot_percomplex_scatter(n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD
|
|
|
299
314
|
df = pd.concat([df, val[key]], axis=1)
|
|
300
315
|
|
|
301
316
|
for pair in column_pairs:
|
|
302
|
-
extreme_indices_0 = df[pair[0]].sort_values(ascending=False).head(n_top).index
|
|
303
|
-
extreme_indices_1 = df[pair[1]].sort_values(ascending=False).head(n_top).index
|
|
304
|
-
significant_indices = extreme_indices_0.union(extreme_indices_1)
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
)
|
|
317
|
+
extreme_indices_0 = df[pair[0]].sort_values(ascending=False).head(n_top).index
|
|
318
|
+
extreme_indices_1 = df[pair[1]].sort_values(ascending=False).head(n_top).index
|
|
319
|
+
significant_indices = extreme_indices_0.union(extreme_indices_1)
|
|
320
|
+
significant_in_both = extreme_indices_0.intersection(extreme_indices_1)
|
|
321
|
+
significant_pair0_only = extreme_indices_0.difference(extreme_indices_1)
|
|
322
|
+
significant_pair1_only = extreme_indices_1.difference(extreme_indices_0)
|
|
323
|
+
|
|
324
|
+
bg_df = df.drop(index=significant_indices)
|
|
325
|
+
sig_df = df.loc[significant_indices]
|
|
326
|
+
sig_sizes = (
|
|
327
|
+
sig_df['n_used_genes']
|
|
328
|
+
if 'n_used_genes' in sig_df
|
|
329
|
+
else pd.Series(1, index=sig_df.index)
|
|
330
|
+
) * 8
|
|
331
|
+
|
|
332
|
+
# Create square figure
|
|
333
|
+
fig, ax = plt.subplots(figsize=(6, 6))
|
|
334
|
+
|
|
335
|
+
# Background cloud: non-significant complexes are open circles.
|
|
336
|
+
bg_sizes = (bg_df['n_used_genes'] if 'n_used_genes' in bg_df else pd.Series(1, index=bg_df.index)) * 5
|
|
337
|
+
ax.scatter(
|
|
338
|
+
bg_df[pair[0]], bg_df[pair[1]],
|
|
339
|
+
facecolors="none", edgecolors=nonsig_border_color,
|
|
340
|
+
s=bg_sizes, linewidth=nonsig_border_width, alpha=0.8,
|
|
341
|
+
zorder=0
|
|
342
|
+
)
|
|
343
|
+
|
|
344
|
+
def scatter_significant(indices, color, zorder=2):
|
|
345
|
+
if len(indices) == 0:
|
|
346
|
+
return
|
|
347
|
+
point_df = df.loc[indices]
|
|
348
|
+
point_sizes = (
|
|
349
|
+
point_df['n_used_genes']
|
|
350
|
+
if 'n_used_genes' in point_df
|
|
351
|
+
else pd.Series(1, index=point_df.index)
|
|
352
|
+
) * 8
|
|
353
|
+
ax.scatter(
|
|
354
|
+
point_df[pair[0]], point_df[pair[1]],
|
|
355
|
+
facecolors=color, edgecolors=color,
|
|
356
|
+
s=point_sizes, linewidth=border_width, zorder=zorder
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
# Dataset-specific significant complexes use the dataset input color.
|
|
360
|
+
scatter_significant(
|
|
361
|
+
significant_pair0_only,
|
|
362
|
+
input_colors.get(_sanitize(pair[0]), sig_color),
|
|
363
|
+
zorder=2,
|
|
364
|
+
)
|
|
365
|
+
scatter_significant(
|
|
366
|
+
significant_pair1_only,
|
|
367
|
+
input_colors.get(_sanitize(pair[1]), sig_color),
|
|
368
|
+
zorder=2,
|
|
369
|
+
)
|
|
370
|
+
# Complexes significant in both datasets stay black to avoid ambiguous color mixing.
|
|
371
|
+
scatter_significant(significant_in_both, "black", zorder=3)
|
|
328
372
|
|
|
329
373
|
# Improved label positioning with adaptive spacing
|
|
330
374
|
coords = sorted(
|
|
@@ -894,16 +938,28 @@ def position_cluster_labels(cluster, cluster_id, max_y, effective_max_y, label_c
|
|
|
894
938
|
clip_on=True, bbox=bbox_props
|
|
895
939
|
)
|
|
896
940
|
|
|
897
|
-
def plot_percomplex_scatter_bysize(
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
941
|
+
def plot_percomplex_scatter_bysize(
|
|
942
|
+
n_labels=10,
|
|
943
|
+
n_top=10,
|
|
944
|
+
sig_color='black',
|
|
945
|
+
nonsig_color='none',
|
|
946
|
+
label_color='black',
|
|
947
|
+
border_color='black',
|
|
948
|
+
border_width=1.0,
|
|
949
|
+
nonsig_border_color="#7F7F7F",
|
|
950
|
+
nonsig_border_width=0.5,
|
|
951
|
+
show_text_background=True,
|
|
952
|
+
):
|
|
953
|
+
config = dload("config")
|
|
954
|
+
plot_config = config["plotting"]
|
|
955
|
+
rdict = dload("pra_percomplex")
|
|
956
|
+
input_colors = dload("input", "colors")
|
|
957
|
+
input_colors = {_sanitize(k): v for k, v in input_colors.items()} if input_colors else {}
|
|
958
|
+
|
|
959
|
+
for key, per_complex in rdict.items():
|
|
960
|
+
dataset_color = input_colors.get(_sanitize(key), sig_color)
|
|
961
|
+
sorted_pc = per_complex.sort_values(by="auc_score", ascending=False, na_position="last")
|
|
962
|
+
top_labels, rest = sorted_pc.head(n_labels), sorted_pc.iloc[n_labels:]
|
|
907
963
|
|
|
908
964
|
# Calculate data range for appropriate figure sizing
|
|
909
965
|
max_genes = sorted_pc.n_used_genes.max()
|
|
@@ -914,22 +970,22 @@ def plot_percomplex_scatter_bysize(n_labels=10, n_top=10, sig_color='#B71A2A', n
|
|
|
914
970
|
fig_height = min(max(4, aspect_ratio), 8) # Between 4-8 inches
|
|
915
971
|
fig, ax = plt.subplots(figsize=(6, fig_height))
|
|
916
972
|
|
|
917
|
-
# Background
|
|
918
|
-
ax.scatter(
|
|
919
|
-
rest.auc_score, rest.n_used_genes,
|
|
920
|
-
facecolors=
|
|
921
|
-
linewidth=
|
|
922
|
-
alpha=
|
|
923
|
-
zorder=0
|
|
924
|
-
)
|
|
925
|
-
|
|
926
|
-
# Top N
|
|
927
|
-
ax.scatter(
|
|
928
|
-
top_labels.auc_score, top_labels.n_used_genes,
|
|
929
|
-
facecolors=
|
|
930
|
-
linewidth=border_width, s=top_labels.n_used_genes * 8,
|
|
931
|
-
label=f"Top {n_labels} AUC Scores", alpha=1.0, zorder=2
|
|
932
|
-
)
|
|
973
|
+
# Background: non-significant complexes are open circles.
|
|
974
|
+
ax.scatter(
|
|
975
|
+
rest.auc_score, rest.n_used_genes,
|
|
976
|
+
facecolors="none", edgecolors=nonsig_border_color,
|
|
977
|
+
linewidth=nonsig_border_width, s=rest.n_used_genes * 5,
|
|
978
|
+
alpha=0.8, label="Other Complexes",
|
|
979
|
+
zorder=0
|
|
980
|
+
)
|
|
981
|
+
|
|
982
|
+
# Top N/significant complexes are filled black circles.
|
|
983
|
+
ax.scatter(
|
|
984
|
+
top_labels.auc_score, top_labels.n_used_genes,
|
|
985
|
+
facecolors=dataset_color, edgecolors=dataset_color,
|
|
986
|
+
linewidth=border_width, s=top_labels.n_used_genes * 8,
|
|
987
|
+
label=f"Top {n_labels} AUC Scores", alpha=1.0, zorder=2
|
|
988
|
+
)
|
|
933
989
|
|
|
934
990
|
# Enhanced anti-overlap labeling system
|
|
935
991
|
coords = [(row.auc_score, row.n_used_genes, idx) for idx, row in top_labels.iterrows()]
|
|
@@ -1010,6 +1066,7 @@ def plot_complex_contributions(
|
|
|
1010
1066
|
tmp = np.tile(x, (mx, 1))
|
|
1011
1067
|
x = cont_stepwise_mat.values / tmp
|
|
1012
1068
|
x_df = pd.DataFrame(x, index=cont_stepwise_anno, columns=cont_stepwise_mat.columns)
|
|
1069
|
+
|
|
1013
1070
|
ind_for_mean = y >= (last_prec_value - min_precision_cutoff)
|
|
1014
1071
|
if sum(ind_for_mean) == 0:
|
|
1015
1072
|
log.info("No values above 'min.precision.cutoff'"); return False
|
|
@@ -1094,15 +1151,23 @@ def plot_significant_complexes():
|
|
|
1094
1151
|
input_colors = {_sanitize(k): v for k, v in input_colors.items()}
|
|
1095
1152
|
|
|
1096
1153
|
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
|
|
1154
|
+
if not isinstance(pra_percomplex, dict) or not pra_percomplex:
|
|
1155
|
+
log.warning("No per-complex PRA data found. Run pra_percomplex() first.")
|
|
1156
|
+
return pd.DataFrame(index=thresholds)
|
|
1157
|
+
|
|
1097
1158
|
datasets = list(pra_percomplex.keys())
|
|
1098
1159
|
num_datasets = len(datasets)
|
|
1099
1160
|
|
|
1161
|
+
if num_datasets == 0:
|
|
1162
|
+
return pd.DataFrame(index=thresholds)
|
|
1163
|
+
|
|
1100
1164
|
df = pd.DataFrame(index=thresholds)
|
|
1101
1165
|
for key, complex_data in pra_percomplex.items():
|
|
1102
1166
|
if "corrected_auc_score" in complex_data.columns:
|
|
1103
1167
|
score_col = "corrected_auc_score"
|
|
1104
1168
|
else:
|
|
1105
1169
|
score_col = "auc_score"
|
|
1170
|
+
|
|
1106
1171
|
df[key] = [complex_data.query(f'{score_col} >= {t}').shape[0] for t in thresholds]
|
|
1107
1172
|
|
|
1108
1173
|
fig, ax = plt.subplots()
|
|
@@ -1221,15 +1286,20 @@ def plot_auc_scores():
|
|
|
1221
1286
|
return pra_dict
|
|
1222
1287
|
|
|
1223
1288
|
|
|
1224
|
-
def
|
|
1289
|
+
def plot_mpr_complex_auc_scores(variant: str = "unfiltered", save=None, outname=None):
|
|
1225
1290
|
"""Plot AUC scores for the mPR complexes curve (Fig 1F-style).
|
|
1226
1291
|
|
|
1227
1292
|
Requires `mpr_prepare()` to have been run for each dataset.
|
|
1228
1293
|
|
|
1229
1294
|
Parameters
|
|
1230
1295
|
----------
|
|
1231
|
-
|
|
1232
|
-
One of: "
|
|
1296
|
+
variant : str
|
|
1297
|
+
One of: "unfiltered", "without_mt_ribo_etci",
|
|
1298
|
+
"without_small_high_auprc".
|
|
1299
|
+
save : bool, optional
|
|
1300
|
+
Whether to save the figure. If None, uses config["plotting"]["save_plot"].
|
|
1301
|
+
outname : str, optional
|
|
1302
|
+
Output filename. If None, auto-generated.
|
|
1233
1303
|
|
|
1234
1304
|
Returns
|
|
1235
1305
|
-------
|
|
@@ -1250,12 +1320,14 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
|
|
|
1250
1320
|
)
|
|
1251
1321
|
return pd.Series(dtype=float)
|
|
1252
1322
|
|
|
1323
|
+
variant_key = _normalize_mpr_variant(variant)
|
|
1324
|
+
|
|
1253
1325
|
# Build Series: dataset -> auc
|
|
1254
1326
|
auc_by_dataset = {}
|
|
1255
1327
|
for dataset, per_filter in mpr_auc_dict.items():
|
|
1256
1328
|
if not isinstance(per_filter, dict):
|
|
1257
1329
|
continue
|
|
1258
|
-
val = per_filter.get(
|
|
1330
|
+
val = per_filter.get(variant_key)
|
|
1259
1331
|
if val is None:
|
|
1260
1332
|
continue
|
|
1261
1333
|
try:
|
|
@@ -1265,7 +1337,8 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
|
|
|
1265
1337
|
|
|
1266
1338
|
if not auc_by_dataset:
|
|
1267
1339
|
log.warning(
|
|
1268
|
-
f"No mPR
|
|
1340
|
+
f"No mPR complex AUC scores found for variant '{variant}'. "
|
|
1341
|
+
f"Available variants: {list(PUBLIC_MPR_VARIANTS.keys())}"
|
|
1269
1342
|
)
|
|
1270
1343
|
return pd.Series(dtype=float)
|
|
1271
1344
|
|
|
@@ -1308,11 +1381,16 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
|
|
|
1308
1381
|
ax.spines["top"].set_visible(False)
|
|
1309
1382
|
ax.spines["right"].set_visible(False)
|
|
1310
1383
|
|
|
1311
|
-
|
|
1384
|
+
should_save = plot_config.get("save_plot", False) if save is None else bool(save)
|
|
1385
|
+
if should_save:
|
|
1312
1386
|
output_type = plot_config.get("output_type", "pdf")
|
|
1313
1387
|
output_folder = Path(config["output_folder"])
|
|
1314
1388
|
output_folder.mkdir(parents=True, exist_ok=True)
|
|
1315
|
-
|
|
1389
|
+
if outname is None:
|
|
1390
|
+
outname = f"mpr_complexes_auc_{variant_key}.{output_type}"
|
|
1391
|
+
output_path = Path(outname)
|
|
1392
|
+
if len(output_path.parts) == 1:
|
|
1393
|
+
output_path = output_folder / outname
|
|
1316
1394
|
plt.savefig(output_path, bbox_inches="tight", format=output_type)
|
|
1317
1395
|
|
|
1318
1396
|
if plot_config.get("show_plot", True):
|
|
@@ -1321,6 +1399,13 @@ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
|
|
|
1321
1399
|
plt.close(fig)
|
|
1322
1400
|
return s
|
|
1323
1401
|
|
|
1402
|
+
|
|
1403
|
+
def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
|
|
1404
|
+
"""Backward-compatible wrapper for plot_mpr_complex_auc_scores()."""
|
|
1405
|
+
return plot_mpr_complex_auc_scores(
|
|
1406
|
+
variant=_legacy_filter_to_variant(filter_key, default="unfiltered")
|
|
1407
|
+
)
|
|
1408
|
+
|
|
1324
1409
|
# -----------------------------------------------------------------------------
|
|
1325
1410
|
# mPR plots (Fig. 1E and Fig. 1F)
|
|
1326
1411
|
# -----------------------------------------------------------------------------
|
|
@@ -1475,26 +1560,6 @@ def plot_mpr_tp(name, ax=None, save=True, outname=None):
|
|
|
1475
1560
|
|
|
1476
1561
|
return ax
|
|
1477
1562
|
|
|
1478
|
-
"""
|
|
1479
|
-
Multi-dataset mPR plotting functions.
|
|
1480
|
-
|
|
1481
|
-
Usage:
|
|
1482
|
-
from pythonflex.plotting import plot_mpr_tp_multi, plot_mpr_complexes_multi
|
|
1483
|
-
|
|
1484
|
-
# Plot multiple datasets
|
|
1485
|
-
plot_mpr_tp_multi(["19Q2", "19Q4", "20Q1"])
|
|
1486
|
-
plot_mpr_complexes_multi(["19Q2", "19Q4", "20Q1"])
|
|
1487
|
-
"""
|
|
1488
|
-
|
|
1489
|
-
import numpy as np
|
|
1490
|
-
import pandas as pd
|
|
1491
|
-
import matplotlib.pyplot as plt
|
|
1492
|
-
from matplotlib.lines import Line2D
|
|
1493
|
-
from pathlib import Path
|
|
1494
|
-
|
|
1495
|
-
from .utils import dload
|
|
1496
|
-
from .logging_config import log
|
|
1497
|
-
|
|
1498
1563
|
# Default color palette (colorblind-friendly)
|
|
1499
1564
|
DEFAULT_COLORS = [
|
|
1500
1565
|
"#4E79A7", # blue
|
|
@@ -1509,40 +1574,101 @@ DEFAULT_COLORS = [
|
|
|
1509
1574
|
"#BAB0AC", # gray
|
|
1510
1575
|
]
|
|
1511
1576
|
|
|
1512
|
-
#
|
|
1513
|
-
|
|
1577
|
+
# Public mPR variant names map to the internal keys stored by mpr_prepare().
|
|
1578
|
+
PUBLIC_MPR_VARIANTS = {
|
|
1579
|
+
"unfiltered": "all",
|
|
1580
|
+
"without_mt_ribo_etci": "no_mtRibo_ETCI",
|
|
1581
|
+
"without_small_high_auprc": "no_small_highAUPRC",
|
|
1582
|
+
}
|
|
1583
|
+
INTERNAL_MPR_VARIANTS = {v: k for k, v in PUBLIC_MPR_VARIANTS.items()}
|
|
1584
|
+
|
|
1585
|
+
# mPR variant line styles keyed by internal storage names.
|
|
1586
|
+
MPR_VARIANT_STYLES = {
|
|
1514
1587
|
"all": {"linestyle": "-", "label": "all data"},
|
|
1515
1588
|
"no_mtRibo_ETCI": {"linestyle": "--", "label": "no mtRibo, ETC I"},
|
|
1516
1589
|
"no_small_highAUPRC": {"linestyle": "dotted", "label": "no small, high AUPRC"},
|
|
1517
1590
|
}
|
|
1518
1591
|
|
|
1592
|
+
# Compatibility alias for users who imported this internal constant.
|
|
1593
|
+
FILTER_STYLES = MPR_VARIANT_STYLES
|
|
1594
|
+
|
|
1595
|
+
|
|
1596
|
+
def _normalize_mpr_variant(variant):
|
|
1597
|
+
"""Return the internal mPR variant key for one public variant name."""
|
|
1598
|
+
if variant in PUBLIC_MPR_VARIANTS:
|
|
1599
|
+
return PUBLIC_MPR_VARIANTS[variant]
|
|
1600
|
+
if variant in MPR_VARIANT_STYLES:
|
|
1601
|
+
if variant == "all":
|
|
1602
|
+
return PUBLIC_MPR_VARIANTS["unfiltered"]
|
|
1603
|
+
return variant
|
|
1604
|
+
raise ValueError(
|
|
1605
|
+
"Unknown mPR variant "
|
|
1606
|
+
f"{variant!r}. Use one of {list(PUBLIC_MPR_VARIANTS.keys())}."
|
|
1607
|
+
)
|
|
1519
1608
|
|
|
1520
|
-
def _normalize_show_filters(show_filters):
|
|
1521
|
-
"""Normalize show_filters to an ordered tuple of filter keys.
|
|
1522
1609
|
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1610
|
+
def _normalize_mpr_variants(variants):
|
|
1611
|
+
"""Normalize public mPR variant names to internal storage keys."""
|
|
1612
|
+
if variants is None:
|
|
1613
|
+
raw_variants = ("all",)
|
|
1614
|
+
elif isinstance(variants, str):
|
|
1615
|
+
raw_variants = (variants,)
|
|
1616
|
+
else:
|
|
1617
|
+
try:
|
|
1618
|
+
raw_variants = tuple(variants)
|
|
1619
|
+
except TypeError:
|
|
1620
|
+
raw_variants = (variants,)
|
|
1621
|
+
|
|
1622
|
+
out = []
|
|
1623
|
+
for variant in raw_variants:
|
|
1624
|
+
if variant == "all":
|
|
1625
|
+
out.extend(PUBLIC_MPR_VARIANTS.values())
|
|
1626
|
+
else:
|
|
1627
|
+
out.append(_normalize_mpr_variant(variant))
|
|
1628
|
+
|
|
1629
|
+
# Preserve user order while removing duplicates.
|
|
1630
|
+
return tuple(dict.fromkeys(out))
|
|
1631
|
+
|
|
1632
|
+
|
|
1633
|
+
def _legacy_filter_to_variant(filter_key, default=None):
|
|
1634
|
+
"""Map old filter-key names to public variant names."""
|
|
1635
|
+
if filter_key is None:
|
|
1636
|
+
return default if default is not None else "all"
|
|
1637
|
+
mapping = {
|
|
1638
|
+
"all": "unfiltered",
|
|
1639
|
+
"no_mtRibo_ETCI": "without_mt_ribo_etci",
|
|
1640
|
+
"no_small_highAUPRC": "without_small_high_auprc",
|
|
1641
|
+
}
|
|
1642
|
+
return mapping.get(filter_key, filter_key)
|
|
1643
|
+
|
|
1644
|
+
|
|
1645
|
+
def _legacy_filters_to_variants(show_filters):
|
|
1646
|
+
"""Map old show_filters values to public variant names."""
|
|
1526
1647
|
if show_filters is None:
|
|
1527
|
-
return
|
|
1648
|
+
return "all"
|
|
1528
1649
|
if isinstance(show_filters, str):
|
|
1529
|
-
return (show_filters
|
|
1650
|
+
return _legacy_filter_to_variant(show_filters)
|
|
1530
1651
|
try:
|
|
1531
|
-
return tuple(show_filters)
|
|
1652
|
+
return tuple(_legacy_filter_to_variant(item) for item in show_filters)
|
|
1532
1653
|
except TypeError:
|
|
1533
|
-
return (show_filters,)
|
|
1654
|
+
return (_legacy_filter_to_variant(show_filters),)
|
|
1534
1655
|
|
|
1535
|
-
|
|
1656
|
+
|
|
1657
|
+
def _normalize_show_filters(show_filters):
|
|
1658
|
+
"""Backward-compatible normalizer for old internal filter keys."""
|
|
1659
|
+
return _normalize_mpr_variants(_legacy_filters_to_variants(show_filters))
|
|
1660
|
+
|
|
1661
|
+
def plot_mpr_true_positive_curve(
|
|
1536
1662
|
dataset_names=None,
|
|
1537
1663
|
colors=None,
|
|
1538
1664
|
ax=None,
|
|
1539
1665
|
save=True,
|
|
1540
1666
|
outname=None,
|
|
1541
1667
|
linewidth=1.8,
|
|
1542
|
-
|
|
1668
|
+
variants="unfiltered",
|
|
1543
1669
|
):
|
|
1544
1670
|
"""
|
|
1545
|
-
Plot
|
|
1671
|
+
Plot mPR true-positive vs precision curves for multiple datasets.
|
|
1546
1672
|
|
|
1547
1673
|
Can auto-detect datasets or use provided dataset names.
|
|
1548
1674
|
Each dataset gets one color, each filter type gets one line style.
|
|
@@ -1562,8 +1688,9 @@ def plot_mpr_tp_multi(
|
|
|
1562
1688
|
Output filename. If None, auto-generated.
|
|
1563
1689
|
linewidth : float
|
|
1564
1690
|
Line width for all curves
|
|
1565
|
-
|
|
1566
|
-
Which
|
|
1691
|
+
variants : str or iterable of str
|
|
1692
|
+
Which mPR variants to show. Use "unfiltered",
|
|
1693
|
+
"without_mt_ribo_etci", "without_small_high_auprc", or "all".
|
|
1567
1694
|
|
|
1568
1695
|
Returns
|
|
1569
1696
|
-------
|
|
@@ -1573,7 +1700,7 @@ def plot_mpr_tp_multi(
|
|
|
1573
1700
|
plot_config = config["plotting"]
|
|
1574
1701
|
input_colors = dload("input", "colors")
|
|
1575
1702
|
|
|
1576
|
-
|
|
1703
|
+
variant_keys = _normalize_mpr_variants(variants)
|
|
1577
1704
|
|
|
1578
1705
|
# Sanitize color keys
|
|
1579
1706
|
if input_colors:
|
|
@@ -1641,13 +1768,13 @@ def plot_mpr_tp_multi(
|
|
|
1641
1768
|
tp_curves = mpr["tp_curves"]
|
|
1642
1769
|
color = colors[i % len(colors)]
|
|
1643
1770
|
|
|
1644
|
-
for
|
|
1645
|
-
if
|
|
1771
|
+
for variant_key in variant_keys:
|
|
1772
|
+
if variant_key not in tp_curves:
|
|
1646
1773
|
continue
|
|
1647
1774
|
|
|
1648
|
-
data = tp_curves[
|
|
1775
|
+
data = tp_curves[variant_key]
|
|
1649
1776
|
if not isinstance(data, dict) or "tp" not in data or "precision" not in data:
|
|
1650
|
-
log.warning(f"Invalid tp_curves data structure for '{name}'
|
|
1777
|
+
log.warning(f"Invalid tp_curves data structure for '{name}' variant '{variant_key}', skipping.")
|
|
1651
1778
|
continue
|
|
1652
1779
|
|
|
1653
1780
|
tp = np.asarray(data["tp"], dtype=float)
|
|
@@ -1661,7 +1788,7 @@ def plot_mpr_tp_multi(
|
|
|
1661
1788
|
prec_plot = prec[mask]
|
|
1662
1789
|
xmax = max(xmax, float(tp_plot.max()))
|
|
1663
1790
|
|
|
1664
|
-
style =
|
|
1791
|
+
style = MPR_VARIANT_STYLES.get(variant_key, {})
|
|
1665
1792
|
ax.plot(
|
|
1666
1793
|
tp_plot,
|
|
1667
1794
|
prec_plot,
|
|
@@ -1694,7 +1821,7 @@ def plot_mpr_tp_multi(
|
|
|
1694
1821
|
ax.spines['right'].set_visible(False)
|
|
1695
1822
|
|
|
1696
1823
|
# Create vertically stacked legends
|
|
1697
|
-
_add_vertical_legend(ax, dataset_names, colors,
|
|
1824
|
+
_add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth)
|
|
1698
1825
|
|
|
1699
1826
|
# Save
|
|
1700
1827
|
if save:
|
|
@@ -1713,7 +1840,8 @@ def plot_mpr_tp_multi(
|
|
|
1713
1840
|
|
|
1714
1841
|
return ax
|
|
1715
1842
|
|
|
1716
|
-
|
|
1843
|
+
|
|
1844
|
+
def plot_mpr_tp_multi(
|
|
1717
1845
|
dataset_names=None,
|
|
1718
1846
|
colors=None,
|
|
1719
1847
|
ax=None,
|
|
@@ -1721,11 +1849,31 @@ def plot_mpr_complexes_multi(
|
|
|
1721
1849
|
outname=None,
|
|
1722
1850
|
linewidth=1.8,
|
|
1723
1851
|
show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
|
|
1852
|
+
):
|
|
1853
|
+
"""Backward-compatible wrapper for plot_mpr_true_positive_curve()."""
|
|
1854
|
+
return plot_mpr_true_positive_curve(
|
|
1855
|
+
dataset_names=dataset_names,
|
|
1856
|
+
colors=colors,
|
|
1857
|
+
ax=ax,
|
|
1858
|
+
save=save,
|
|
1859
|
+
outname=outname,
|
|
1860
|
+
linewidth=linewidth,
|
|
1861
|
+
variants=_legacy_filters_to_variants(show_filters),
|
|
1862
|
+
)
|
|
1863
|
+
|
|
1864
|
+
def plot_mpr_complex_coverage_curve(
|
|
1865
|
+
dataset_names=None,
|
|
1866
|
+
colors=None,
|
|
1867
|
+
ax=None,
|
|
1868
|
+
save=True,
|
|
1869
|
+
outname=None,
|
|
1870
|
+
linewidth=1.8,
|
|
1871
|
+
variants="unfiltered",
|
|
1724
1872
|
show_markers="auto",
|
|
1725
1873
|
marker_size=20,
|
|
1726
1874
|
):
|
|
1727
1875
|
"""
|
|
1728
|
-
Plot
|
|
1876
|
+
Plot mPR complex-coverage vs precision curves for multiple datasets.
|
|
1729
1877
|
|
|
1730
1878
|
Can auto-detect datasets or use provided dataset names.
|
|
1731
1879
|
Each dataset gets one color, each filter type gets one line style.
|
|
@@ -1745,8 +1893,9 @@ def plot_mpr_complexes_multi(
|
|
|
1745
1893
|
Output filename. If None, auto-generated.
|
|
1746
1894
|
linewidth : float
|
|
1747
1895
|
Line width for all curves
|
|
1748
|
-
|
|
1749
|
-
Which
|
|
1896
|
+
variants : str or iterable of str
|
|
1897
|
+
Which mPR variants to show. Use "unfiltered",
|
|
1898
|
+
"without_mt_ribo_etci", "without_small_high_auprc", or "all".
|
|
1750
1899
|
show_markers : bool or "auto"
|
|
1751
1900
|
If True, draw markers on curves to make short curves visible.
|
|
1752
1901
|
If "auto" (default), markers are drawn only for curves with <= 10 points.
|
|
@@ -1761,7 +1910,7 @@ def plot_mpr_complexes_multi(
|
|
|
1761
1910
|
plot_config = config["plotting"]
|
|
1762
1911
|
input_colors = dload("input", "colors")
|
|
1763
1912
|
|
|
1764
|
-
|
|
1913
|
+
variant_keys = _normalize_mpr_variants(variants)
|
|
1765
1914
|
|
|
1766
1915
|
# Sanitize color keys
|
|
1767
1916
|
if input_colors:
|
|
@@ -1812,32 +1961,61 @@ def plot_mpr_complexes_multi(
|
|
|
1812
1961
|
else:
|
|
1813
1962
|
fig = ax.figure
|
|
1814
1963
|
|
|
1815
|
-
#
|
|
1964
|
+
# First pass: determine max coverage across all datasets/filters for adaptive x-axis
|
|
1965
|
+
max_cov_global = 0
|
|
1966
|
+
_mpr_cache = {}
|
|
1816
1967
|
for i, name in enumerate(dataset_names):
|
|
1817
1968
|
mpr = dload("mpr", name)
|
|
1969
|
+
_mpr_cache[name] = mpr
|
|
1970
|
+
if mpr is not None:
|
|
1971
|
+
for variant_key in variant_keys:
|
|
1972
|
+
arr = mpr["coverage_curves"].get(variant_key)
|
|
1973
|
+
if arr is not None:
|
|
1974
|
+
max_cov_global = max(max_cov_global, float(np.asarray(arr).max()))
|
|
1975
|
+
|
|
1976
|
+
# Build adaptive x-axis limits and ticks
|
|
1977
|
+
import math
|
|
1978
|
+
if max_cov_global <= 200:
|
|
1979
|
+
# Original fixed range — keeps CORUM plots identical to before
|
|
1980
|
+
x_max_plot = 200
|
|
1981
|
+
tick_positions = [1, 2, 20, 200]
|
|
1982
|
+
tick_labels = ["0", "2", "20", "200"]
|
|
1983
|
+
else:
|
|
1984
|
+
# Round up to the next power of 10 so the max bar has breathing room
|
|
1985
|
+
x_max_plot = 10 ** math.ceil(math.log10(max_cov_global + 1))
|
|
1986
|
+
tick_positions = [1, 2]
|
|
1987
|
+
v = 10
|
|
1988
|
+
while v <= x_max_plot:
|
|
1989
|
+
tick_positions.append(v)
|
|
1990
|
+
v *= 10
|
|
1991
|
+
tick_labels = ["0"] + [str(t) for t in tick_positions[1:]]
|
|
1992
|
+
|
|
1993
|
+
# Plot each dataset
|
|
1994
|
+
for i, name in enumerate(dataset_names):
|
|
1995
|
+
mpr = _mpr_cache[name]
|
|
1818
1996
|
if mpr is None:
|
|
1819
1997
|
log.warning(f"mPR data for '{name}' not found, skipping.")
|
|
1820
1998
|
continue
|
|
1821
|
-
|
|
1999
|
+
|
|
1822
2000
|
precision_cutoffs = np.asarray(mpr["precision_cutoffs"], dtype=float)
|
|
1823
2001
|
coverage = mpr["coverage_curves"]
|
|
1824
2002
|
color = colors[i % len(colors)]
|
|
1825
|
-
|
|
1826
|
-
for
|
|
1827
|
-
if
|
|
2003
|
+
|
|
2004
|
+
for variant_key in variant_keys:
|
|
2005
|
+
if variant_key not in coverage:
|
|
1828
2006
|
continue
|
|
1829
|
-
|
|
1830
|
-
cov = np.asarray(coverage[
|
|
1831
|
-
|
|
1832
|
-
# Keep only positive coverage
|
|
1833
|
-
mask = (cov > 0) & (cov <=
|
|
2007
|
+
|
|
2008
|
+
cov = np.asarray(coverage[variant_key], dtype=float)
|
|
2009
|
+
|
|
2010
|
+
# Keep only positive coverage within the visible x range
|
|
2011
|
+
mask = (cov > 0) & (cov <= x_max_plot)
|
|
1834
2012
|
if not mask.any():
|
|
1835
2013
|
continue
|
|
1836
|
-
|
|
2014
|
+
|
|
1837
2015
|
cov_plot = cov[mask]
|
|
1838
2016
|
prec_plot = precision_cutoffs[mask]
|
|
1839
|
-
|
|
1840
|
-
style =
|
|
2017
|
+
|
|
2018
|
+
style = MPR_VARIANT_STYLES.get(variant_key, {})
|
|
1841
2019
|
|
|
1842
2020
|
# Decide marker visibility
|
|
1843
2021
|
if show_markers == "auto":
|
|
@@ -1858,17 +2036,15 @@ def plot_mpr_complexes_multi(
|
|
|
1858
2036
|
marker=("o" if use_markers else None),
|
|
1859
2037
|
markersize=(3 if use_markers else None),
|
|
1860
2038
|
)
|
|
1861
|
-
|
|
2039
|
+
|
|
1862
2040
|
# Configure axes
|
|
1863
2041
|
ax.set_xscale("log")
|
|
1864
|
-
ax.set_xlim(1,
|
|
2042
|
+
ax.set_xlim(1, x_max_plot)
|
|
1865
2043
|
ax.set_xlabel("# complexes")
|
|
1866
2044
|
ax.set_ylabel("Precision")
|
|
1867
2045
|
ax.set_ylim(0.0, 1.05)
|
|
1868
|
-
|
|
1869
|
-
#
|
|
1870
|
-
tick_positions = [1, 2, 20, 200]
|
|
1871
|
-
tick_labels = ["0", "2", "20", "200"]
|
|
2046
|
+
|
|
2047
|
+
# Adaptive x-ticks
|
|
1872
2048
|
ax.set_xticks(tick_positions)
|
|
1873
2049
|
ax.set_xticklabels(tick_labels)
|
|
1874
2050
|
|
|
@@ -1877,7 +2053,7 @@ def plot_mpr_complexes_multi(
|
|
|
1877
2053
|
ax.spines['right'].set_visible(False)
|
|
1878
2054
|
|
|
1879
2055
|
# Create vertically stacked legends
|
|
1880
|
-
_add_vertical_legend(ax, dataset_names, colors,
|
|
2056
|
+
_add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth)
|
|
1881
2057
|
|
|
1882
2058
|
# Save
|
|
1883
2059
|
if save:
|
|
@@ -1896,11 +2072,71 @@ def plot_mpr_complexes_multi(
|
|
|
1896
2072
|
|
|
1897
2073
|
return ax
|
|
1898
2074
|
|
|
1899
|
-
|
|
2075
|
+
|
|
2076
|
+
def plot_mpr_complexes_multi(
|
|
2077
|
+
dataset_names=None,
|
|
2078
|
+
colors=None,
|
|
2079
|
+
ax=None,
|
|
2080
|
+
save=True,
|
|
2081
|
+
outname=None,
|
|
2082
|
+
linewidth=1.8,
|
|
2083
|
+
show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
|
|
2084
|
+
show_markers="auto",
|
|
2085
|
+
marker_size=20,
|
|
2086
|
+
):
|
|
2087
|
+
"""Backward-compatible wrapper for plot_mpr_complex_coverage_curve()."""
|
|
2088
|
+
return plot_mpr_complex_coverage_curve(
|
|
2089
|
+
dataset_names=dataset_names,
|
|
2090
|
+
colors=colors,
|
|
2091
|
+
ax=ax,
|
|
2092
|
+
save=save,
|
|
2093
|
+
outname=outname,
|
|
2094
|
+
linewidth=linewidth,
|
|
2095
|
+
variants=_legacy_filters_to_variants(show_filters),
|
|
2096
|
+
show_markers=show_markers,
|
|
2097
|
+
marker_size=marker_size,
|
|
2098
|
+
)
|
|
2099
|
+
|
|
2100
|
+
|
|
2101
|
+
def plot_mpr_summary(
|
|
2102
|
+
dataset_names=None,
|
|
2103
|
+
colors=None,
|
|
2104
|
+
variants="unfiltered",
|
|
2105
|
+
save=True,
|
|
2106
|
+
linewidth=1.8,
|
|
2107
|
+
show_markers="auto",
|
|
2108
|
+
marker_size=20,
|
|
2109
|
+
auc_variant=None,
|
|
2110
|
+
):
|
|
2111
|
+
"""Generate the standard mPR summary plots and return complex AUC scores."""
|
|
2112
|
+
plot_mpr_true_positive_curve(
|
|
2113
|
+
dataset_names=dataset_names,
|
|
2114
|
+
colors=colors,
|
|
2115
|
+
save=save,
|
|
2116
|
+
linewidth=linewidth,
|
|
2117
|
+
variants=variants,
|
|
2118
|
+
)
|
|
2119
|
+
plot_mpr_complex_coverage_curve(
|
|
2120
|
+
dataset_names=dataset_names,
|
|
2121
|
+
colors=colors,
|
|
2122
|
+
save=save,
|
|
2123
|
+
linewidth=linewidth,
|
|
2124
|
+
variants=variants,
|
|
2125
|
+
show_markers=show_markers,
|
|
2126
|
+
marker_size=marker_size,
|
|
2127
|
+
)
|
|
2128
|
+
|
|
2129
|
+
if auc_variant is None:
|
|
2130
|
+
variant_keys = _normalize_mpr_variants(variants)
|
|
2131
|
+
auc_variant = INTERNAL_MPR_VARIANTS.get(variant_keys[0], "unfiltered")
|
|
2132
|
+
|
|
2133
|
+
return plot_mpr_complex_auc_scores(variant=auc_variant, save=save)
|
|
2134
|
+
|
|
2135
|
+
def _add_vertical_legend(ax, dataset_names, colors, variant_keys, linewidth):
|
|
1900
2136
|
"""
|
|
1901
|
-
Add vertically stacked legends: Dataset on top,
|
|
2137
|
+
Add vertically stacked legends: Dataset on top, mPR variant below.
|
|
1902
2138
|
"""
|
|
1903
|
-
|
|
2139
|
+
variant_keys = _normalize_show_filters(variant_keys)
|
|
1904
2140
|
# Legend 1: Datasets (colors) - solid lines
|
|
1905
2141
|
dataset_handles = []
|
|
1906
2142
|
for i, name in enumerate(dataset_names):
|
|
@@ -1908,19 +2144,19 @@ def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
|
|
|
1908
2144
|
handle = Line2D([0], [0], color=color, linewidth=linewidth, linestyle="-")
|
|
1909
2145
|
dataset_handles.append(handle)
|
|
1910
2146
|
|
|
1911
|
-
# Legend 2:
|
|
1912
|
-
|
|
1913
|
-
|
|
1914
|
-
for
|
|
1915
|
-
style =
|
|
2147
|
+
# Legend 2: mPR variants (line styles) - black lines
|
|
2148
|
+
variant_handles = []
|
|
2149
|
+
variant_labels = []
|
|
2150
|
+
for variant_key in variant_keys:
|
|
2151
|
+
style = MPR_VARIANT_STYLES.get(variant_key, {})
|
|
1916
2152
|
handle = Line2D(
|
|
1917
2153
|
[0], [0],
|
|
1918
2154
|
color="black",
|
|
1919
2155
|
linewidth=linewidth,
|
|
1920
2156
|
linestyle=style.get("linestyle", "-")
|
|
1921
2157
|
)
|
|
1922
|
-
|
|
1923
|
-
|
|
2158
|
+
variant_handles.append(handle)
|
|
2159
|
+
variant_labels.append(style.get("label", variant_key))
|
|
1924
2160
|
|
|
1925
2161
|
# Position legends vertically with proper alignment
|
|
1926
2162
|
# Dataset legend on upper right
|
|
@@ -1938,19 +2174,19 @@ def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
|
|
|
1938
2174
|
|
|
1939
2175
|
# Filter legend below the dataset legend, aligned properly without title
|
|
1940
2176
|
legend2 = ax.legend(
|
|
1941
|
-
|
|
1942
|
-
|
|
2177
|
+
variant_handles,
|
|
2178
|
+
variant_labels,
|
|
1943
2179
|
loc="upper left",
|
|
1944
2180
|
frameon=False,
|
|
1945
2181
|
fontsize=7,
|
|
1946
2182
|
bbox_to_anchor=(1.05, 1.0 - len(dataset_names) * 0.06 - 0.1)
|
|
1947
2183
|
)
|
|
1948
2184
|
|
|
1949
|
-
def _add_dual_legend(ax, dataset_names, colors,
|
|
2185
|
+
def _add_dual_legend(ax, dataset_names, colors, variant_keys, linewidth):
|
|
1950
2186
|
"""
|
|
1951
|
-
Add two legends: one for datasets (colors), one for
|
|
2187
|
+
Add two legends: one for datasets (colors), one for mPR variants (line styles).
|
|
1952
2188
|
"""
|
|
1953
|
-
|
|
2189
|
+
variant_keys = _normalize_show_filters(variant_keys)
|
|
1954
2190
|
# Legend 1: Datasets (colors) - solid lines
|
|
1955
2191
|
dataset_handles = []
|
|
1956
2192
|
for i, name in enumerate(dataset_names):
|
|
@@ -1958,19 +2194,19 @@ def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
|
|
|
1958
2194
|
handle = Line2D([0], [0], color=color, linewidth=linewidth, linestyle="-")
|
|
1959
2195
|
dataset_handles.append(handle)
|
|
1960
2196
|
|
|
1961
|
-
# Legend 2:
|
|
1962
|
-
|
|
1963
|
-
|
|
1964
|
-
for
|
|
1965
|
-
style =
|
|
2197
|
+
# Legend 2: mPR variants (line styles) - black lines
|
|
2198
|
+
variant_handles = []
|
|
2199
|
+
variant_labels = []
|
|
2200
|
+
for variant_key in variant_keys:
|
|
2201
|
+
style = MPR_VARIANT_STYLES.get(variant_key, {})
|
|
1966
2202
|
handle = Line2D(
|
|
1967
2203
|
[0], [0],
|
|
1968
2204
|
color="black",
|
|
1969
2205
|
linewidth=linewidth,
|
|
1970
2206
|
linestyle=style.get("linestyle", "-")
|
|
1971
2207
|
)
|
|
1972
|
-
|
|
1973
|
-
|
|
2208
|
+
variant_handles.append(handle)
|
|
2209
|
+
variant_labels.append(style.get("label", variant_key))
|
|
1974
2210
|
|
|
1975
2211
|
# Position legends
|
|
1976
2212
|
# Dataset legend on upper right
|
|
@@ -1987,19 +2223,12 @@ def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
|
|
|
1987
2223
|
|
|
1988
2224
|
# Filter legend on lower left or right depending on plot type
|
|
1989
2225
|
legend2 = ax.legend(
|
|
1990
|
-
|
|
1991
|
-
|
|
2226
|
+
variant_handles,
|
|
2227
|
+
variant_labels,
|
|
1992
2228
|
loc="lower left",
|
|
1993
2229
|
frameon=False,
|
|
1994
|
-
title="
|
|
2230
|
+
title="Variant",
|
|
1995
2231
|
fontsize=7,
|
|
1996
2232
|
title_fontsize=8,
|
|
1997
2233
|
)
|
|
1998
2234
|
|
|
1999
|
-
# ============================================================================
|
|
2000
|
-
# Single dataset functions are now obsolete
|
|
2001
|
-
# ============================================================================
|
|
2002
|
-
|
|
2003
|
-
# Note: The original single dataset functions plot_mpr_tp() and plot_mpr_complexes()
|
|
2004
|
-
# have been replaced by the multi functions that now auto-detect available datasets.
|
|
2005
|
-
# Use plot_mpr_tp_multi() and plot_mpr_complexes_multi() instead.
|