yuclid 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yuclid/plot.py ADDED
@@ -0,0 +1,1009 @@
1
+ from yuclid.log import report, LogLevel
2
+ import yuclid.cli
3
+ import matplotlib.gridspec as gridspec
4
+ import matplotlib.lines as mlines
5
+ import matplotlib.pyplot as plt
6
+ import yuclid.spread as spread
7
+ import seaborn as sns
8
+ import pandas as pd
9
+ import numpy as np
10
+ import scipy.stats
11
+ import subprocess
12
+ import threading
13
+ import itertools
14
+ import pathlib
15
+ import hashlib
16
+ import time
17
+ import math
18
+ import sys
19
+
20
+
21
def get_current_config(ctx):
    """Return the currently selected value of every free dimension.

    Args:
        ctx: application context; reads ``domains`` (dimension -> array of
            distinct values), ``position`` (dimension -> selected index) and
            ``free_dims`` (dimensions not mapped to the X, Z or Y axes).

    Returns:
        dict mapping each free dimension, in ``free_dims`` order, to the value
        at its selected position.
    """
    domains = ctx["domains"]
    position = ctx["position"]
    return {d: domains[d][position[d]] for d in ctx["free_dims"]}
31
+
32
+
33
def get_config(point, keys):
    """Pair each key with the coordinate of ``point`` at the same position.

    Keys past the end of ``point`` are mapped to ``None``; extra coordinates
    are ignored.
    """
    n_coords = len(point)
    return {
        key: (point[idx] if idx < n_coords else None)
        for idx, key in enumerate(keys)
    }
41
+
42
+
43
def get_projection(df, config):
    """Select the rows of ``df`` whose columns equal every value in ``config``.

    An empty ``config`` returns ``df`` itself (not a copy). Otherwise a copy of
    the matching rows is returned, so callers may modify it freely.
    """
    if not config:
        return df
    columns = list(config)
    wanted = pd.Series(config)
    matches = (df[columns] == wanted).all(axis=1)
    return df[matches].copy()
49
+
50
+
51
def group_normalization(norm_axis, df, config, args, y_axis):
    """Normalize ``y_axis`` against a reference group, per X or per Z value.

    The rows of ``df`` matching ``config`` are divided by (or, with
    ``args.norm_reverse``, divide) a reference value. The reference rows are
    those matching ``config`` updated with the ``key=value`` pairs of
    ``args.x_norm`` (``norm_axis == "x"``) or ``args.z_norm``
    (``norm_axis == "z"``). They are aggregated with the geometric mean
    (``args.geomean``) or the median, then matched row by row on the X (resp.
    Z) value.

    Rows whose X/Z value has no reference measurement become NaN instead of
    raising, so one missing reference does not abort the whole plot.

    Returns:
        A copy of the selected rows with ``y_axis`` normalized.

    Raises:
        ValueError: if ``norm_axis`` is neither ``"x"`` nor ``"z"``.
    """
    if norm_axis == "x":
        pairs, key_col = args.x_norm, args.x
    elif norm_axis == "z":
        pairs, key_col = args.z_norm, args.z
    else:
        raise ValueError(f"unknown normalization axis {norm_axis!r}")

    sub_df = get_projection(df, config)
    ref_config = dict(config)
    # split on the first "=" only, so selector values may contain "="
    ref_config.update(pair.split("=", 1) for pair in pairs)
    # CLI values arrive as strings: cast them to the column dtype so the
    # equality test in get_projection can match
    ref_config = {k: df[k].dtype.type(v) for k, v in ref_config.items()}

    ref_df = get_projection(df, ref_config)
    estimator = scipy.stats.gmean if args.geomean else np.median
    gb_cols = df.columns.difference(args.y).tolist()
    ref = ref_df.groupby(gb_cols)[y_axis].apply(estimator).reset_index()

    # first reference value per X (or Z) value, looked up in a single pass
    lookup = ref.drop_duplicates(subset=key_col).set_index(key_col)[y_axis]
    y_ref = sub_df[key_col].map(lookup)
    if args.norm_reverse:
        sub_df[y_axis] = y_ref / sub_df[y_axis]
    else:
        sub_df[y_axis] = sub_df[y_axis] / y_ref
    return sub_df
79
+
80
+
81
def ref_normalization(df, config, args, y_axis):
    """Normalize ``y_axis`` against one single reference measurement.

    The rows of ``df`` matching ``config`` are divided by (or, with
    ``args.norm_reverse``, divide) the value obtained by aggregating, with
    the geometric mean (``args.geomean``) or the median, the rows that match
    ``config`` updated with the ``key=value`` pairs of ``args.ref_norm``.

    Returns:
        A copy of the selected rows with ``y_axis`` normalized.

    Raises:
        IndexError: if no row matches the reference configuration.
    """
    sub_df = get_projection(df, config)
    ref_config = dict(config)
    # split on the first "=" only, so reference values may contain "="
    ref_config.update(pair.split("=", 1) for pair in args.ref_norm)
    # CLI values arrive as strings: cast them to the column dtype
    ref_config = {k: df[k].dtype.type(v) for k, v in ref_config.items()}

    ref_df = get_projection(df, ref_config)
    estimator = scipy.stats.gmean if args.geomean else np.median
    gb_cols = df.columns.difference(args.y).tolist()
    ref = ref_df.groupby(gb_cols)[y_axis].apply(estimator).reset_index()
    if ref.empty:
        raise IndexError(f"no reference data for {ref_config}")
    y_ref = ref[y_axis].values[0]
    if args.norm_reverse:
        sub_df[y_axis] = y_ref / sub_df[y_axis]
    else:
        sub_df[y_axis] = sub_df[y_axis] / y_ref
    return sub_df
101
+
102
+
103
def validate_files(ctx):
    """Keep only the input files whose extension is supported (.json, .csv).

    Reads ``ctx["args"].files``. The accepted paths, in their original order,
    go to ``ctx["valid_files"]``. Every rejected path is reported as an error.
    """
    supported = (".json", ".csv")
    accepted = []
    for path in ctx["args"].files:
        if pathlib.Path(path).suffix not in supported:
            report(LogLevel.ERROR, f"unsupported file format {path}")
            continue
        accepted.append(path)
    ctx["valid_files"] = accepted
113
+
114
+
115
def get_local_mirror(rfile):
    """Return the local file name that mirrors the remote input ``rfile``.

    ``rfile`` has the scp form ``[user@]host:path``. The mirror is the base
    name of ``path``. Only the first ``:`` separates host from path, so colons
    inside the path are preserved. A string without ``:`` raises IndexError.
    """
    remote_path = rfile.split(":", 1)[1]
    return pathlib.Path(remote_path).name
117
+
118
+
119
def locate_files(ctx):
    """Map every validated input to the path it is actually read from.

    Remote inputs (see ``is_remote``) are read from their local mirror and
    local inputs from their own path. The list goes to ``ctx["local_files"]``.
    """
    ctx["local_files"] = [
        get_local_mirror(path) if is_remote(path) else path
        for path in ctx["valid_files"]
    ]
128
+
129
+
130
def set_axes_style(ctx):
    """Resize ``ctx["fig"]`` to 12x10 inches and apply seaborn's "whitegrid" theme."""
    ctx["fig"].set_size_inches(12, 10)
    sns.set_theme(style="whitegrid")
134
+
135
+
136
def initialize_figure(ctx):
    """Create the interactive figure: a tall plot area above a thin table strip.

    Stores the figure and both axes in ``ctx`` (``fig``, ``ax_plot``,
    ``ax_table``). Draws a light-grey separator line above the table strip and
    connects the key-press and close events to ``on_key`` / ``on_close``.
    """
    fig, (ax_plot, ax_table) = plt.subplots(
        2, 1, gridspec_kw={"height_ratios": [20, 1]}
    )
    ctx["fig"] = fig
    ax_plot.grid(axis="y")
    set_axes_style(ctx)

    # separator position is taken from the table axes before subplots_adjust
    separator_y = ax_table.get_position().y1 + 0.03
    separator = mlines.Line2D(
        [0.05, 0.95],
        [separator_y, separator_y],
        linewidth=4,
        transform=fig.transFigure,
        color="lightgrey",
    )
    fig.add_artist(separator)
    fig.subplots_adjust(top=0.92, bottom=0.1, hspace=0.3)

    canvas = fig.canvas
    canvas.mpl_connect("key_press_event", lambda event: on_key(event, ctx))
    canvas.mpl_connect("close_event", lambda event: on_close(event, ctx))
    ctx["ax_plot"] = ax_plot
    ctx["ax_table"] = ax_table
153
+
154
+
155
def generate_dataframe(ctx):
    """Load every local input file into a single DataFrame and apply ``--filter``.

    Reads ``ctx["args"]`` and ``ctx["local_files"]``. ``.json`` files are read
    as JSON Lines and ``.csv`` files with ``pd.read_csv``. Files that cannot be
    read are reported and skipped. With ``--no-merge-inputs`` a ``file``
    column records the input each row came from (the file stem, or the full
    path when two inputs share a stem).

    ``--filter`` entries have the form ``column=v1,v2,...``. Values are cast
    to the column dtype, and only rows matching every entry are kept.

    On success the frame is stored in ``ctx["df"]``. If no file could be read,
    the process exits with status 1. If no row survives, a fatal error is
    reported, ``ctx["alive"]`` is cleared and ``ctx["df"]`` is left as it was.
    """
    args = ctx["args"]
    frames = {}
    for name in ctx["local_files"]:
        path = pathlib.Path(name)
        # two inputs with the same stem (a/out.csv, b/out.csv) must not
        # overwrite each other
        label = path.stem if path.stem not in frames else str(path)
        try:
            if path.suffix == ".json":
                frames[label] = pd.read_json(path, lines=True)
            elif path.suffix == ".csv":
                frames[label] = pd.read_csv(path)
        except Exception:  # parse errors, unreadable or vanished files, ...
            report(LogLevel.ERROR, f"could not open {path}")

    if not frames:
        report(LogLevel.ERROR, "no valid source of data")
        ctx["alive"] = False
        sys.exit(1)

    df = pd.concat(frames)
    if args.no_merge_inputs:
        # level 0 of the concat index is the file label. The per-file row
        # numbers are dropped as well so the index stays unique.
        df = df.reset_index(level=0, names=["file"]).reset_index(drop=True)
    else:
        df = df.reset_index(drop=True)

    user_filter = {}
    for pair in args.filter or []:
        # split on the first "=" only: filter values may contain "="
        key, raw_values = pair.split("=", 1)
        if key not in df.columns:
            hint = "available columns: {}".format(", ".join(map(str, df.columns)))
            report(LogLevel.FATAL, "invalid filter column", key, hint=hint)
            continue
        cast = df[key].dtype.type
        user_filter[key] = [cast(v) for v in raw_values.split(",")]

    if user_filter:
        mask = pd.Series(True, index=df.index)
        for key, allowed in user_filter.items():
            mask &= df[key].isin(allowed)
        df = df[mask]

    if len(df) == 0:
        if args.filter:
            report(LogLevel.FATAL, "no valid data after filtering")
        else:
            report(LogLevel.FATAL, "no valid data found in the files")
        ctx["alive"] = False
        return

    ctx["df"] = df
204
+
205
+
206
def rescale(ctx):
    """Multiply every Y column of ``ctx["df"]`` by ``args.rescale``.

    The scaling is done on a copy, so a frame produced by boolean filtering
    is never written through (pandas would emit SettingWithCopyWarning). The
    rescaled frame replaces ``ctx["df"]``.
    """
    args = ctx["args"]
    df = ctx["df"].copy()
    for column in args.y:
        df[column] = df[column] * args.rescale
    ctx["df"] = df
212
+
213
+
214
def draw(fig, ax, cli_args):
    """Render a yuclid plot once onto an existing matplotlib figure and axes.

    ``cli_args`` (any sequence of strings) are parsed like the arguments of
    ``yuclid plot``. The data pipeline mirrors ``launch``, including
    rescaling and the Y-limit computation, but there is no table, no file
    monitoring and no key bindings. Returns early if no data survives
    loading/filtering; the fatal error has already been reported by then.
    """
    args = yuclid.cli.get_parser().parse_args(["plot"] + list(cli_args))
    ctx = {"args": args, "fig": fig, "ax_plot": ax, "alive": True}
    yuclid.log.init(ignore_errors=args.ignore_errors)
    validate_files(ctx)
    locate_files(ctx)
    set_axes_style(ctx)
    generate_dataframe(ctx)
    if "df" not in ctx:
        return
    validate_args(ctx)
    rescale(ctx)
    generate_space(ctx)
    compute_ylimits(ctx)
    update_plot(ctx)
229
+
230
+
231
def generate_space(ctx):
    """Build the navigable configuration space from ``ctx["df"]``.

    Stores in ``ctx``:
      - ``free_dims``: columns that are neither X, Z nor a Y column;
      - ``domains``: distinct values of every column (order of appearance);
      - ``position``: selected index into each domain;
      - ``selected_index``: index in ``free_dims`` of the dimension the
        up/down keys act on (``None`` when there are no free dimensions);
      - ``z_dom`` / ``z_size``: the Z values and their number.

    When the space is rebuilt (e.g. by the file monitor after the inputs
    change), the previously selected values and dimension are kept wherever
    they still exist. A live reload therefore no longer sends the user back to
    the first configuration. On the first call every position is 0, as before.
    """
    args = ctx["args"]
    df = ctx["df"]
    free_dims = list(df.columns.difference([args.x, args.z] + args.y))
    domains = {d: df[d].unique() for d in df.columns}

    old_domains = ctx.get("domains") or {}
    old_position = ctx.get("position") or {}
    position = {
        d: _carry_over_position(dom, old_domains.get(d), old_position.get(d))
        for d, dom in domains.items()
    }

    old_free_dims = ctx.get("free_dims") or []
    old_index = ctx.get("selected_index")
    if not free_dims:
        selected_index = None
    elif (
        old_index is not None
        and old_index < len(old_free_dims)
        and old_free_dims[old_index] in free_dims
    ):
        selected_index = free_dims.index(old_free_dims[old_index])
    else:
        selected_index = 0

    ctx.update(
        {
            "z_size": df[args.z].nunique(),
            "free_dims": free_dims,
            "selected_index": selected_index,
            "domains": domains,
            "position": position,
            "z_dom": df[args.z].unique(),
        }
    )


def _carry_over_position(domain, old_domain, old_pos):
    """Return the index in ``domain`` of the previously selected value, else 0."""
    if old_domain is None or old_pos is None or not 0 <= old_pos < len(old_domain):
        return 0
    previous = old_domain[old_pos]
    for idx, value in enumerate(domain):
        if value == previous:
            return idx
    return 0
255
+
256
+
257
def file_monitor(ctx):
    """Poll the local input files every second and refresh the view when they change.

    Runs until ``ctx["alive"]`` is cleared. Changes are detected by comparing
    the concatenated MD5 digests of ``ctx["local_files"]``. While any file is
    missing (e.g. while it is being replaced), nothing is reloaded. A reload
    that raises is reported as a warning and retried on the next content
    change, instead of silently terminating this background thread.
    """
    last_hash = None
    while ctx["alive"]:
        current_hash = _hash_input_files(ctx["local_files"])
        if current_hash is not None and current_hash != last_hash:
            try:
                _reload_and_redraw(ctx)
            except Exception as err:  # keep the monitor alive on a bad reload
                report(LogLevel.WARNING, f"could not refresh the plot: {err}")
            last_hash = current_hash
        time.sleep(1)


def _hash_input_files(paths):
    """Return the concatenated MD5 digests of ``paths``, or None if one is missing."""
    digests = []
    try:
        for path in paths:
            with open(path, "rb") as f:
                digests.append(hashlib.md5(f.read()).hexdigest())
    except FileNotFoundError:
        return None
    return "".join(digests)


def _reload_and_redraw(ctx):
    """Reload the inputs, rebuild space and limits, report sizes, redraw table and plot."""
    generate_dataframe(ctx)
    if not ctx["alive"]:
        # no data survived loading/filtering; ctx["df"] was not replaced
        return
    rescale(ctx)
    generate_space(ctx)
    compute_ylimits(ctx)
    df = ctx["df"]
    # the space is spanned by every non-Y column (all Y columns excluded)
    space_columns = df.columns.difference(ctx["y_dims"])
    sizes = ["{}={}".format(d, df[d].nunique()) for d in space_columns]
    missing = compute_missing(ctx)
    report(LogLevel.INFO, "space sizes", " | ".join(sizes))
    if len(missing) > 0:
        report(LogLevel.WARNING, f"at least {len(missing)} missing experiments")
    update_table(ctx)
    update_plot(ctx)
283
+
284
+
285
def update_table(ctx):
    """Redraw the strip below the plot that shows the current free-dimension values.

    There is one column per free dimension: its bold name, its current value
    and, under the dimension the up/down keys act on, an up/down arrow.
    The strip is left empty when there are no free dimensions.
    """
    ax_table = ctx["ax_table"]
    free_dims = ctx["free_dims"]
    domains = ctx["domains"]
    position = ctx["position"]
    selected_index = ctx["selected_index"]
    ax_table.clear()
    ax_table.axis("off")
    if not free_dims:
        return
    fields, values, arrows = [], [], []
    for idx, d in enumerate(free_dims):
        # "_" starts a mathtext subscript; escape it so "num_threads" renders as-is
        name = str(d).replace("_", r"\_")
        fields.append(rf"$\mathbf{{{name}}}$")
        values.append(str(domains[d][position[d]]))
        arrows.append("\u2191\u2193" if idx == selected_index else "")
    ax_table.table(
        cellText=[fields, values, arrows], cellLoc="center", edges="open", loc="center"
    )
    ctx["fig"].canvas.draw_idle()
314
+
315
+
316
def is_remote(file):
    """Tell whether ``file`` names a remote (scp-style) input: any path containing "@"."""
    return file.find("@") >= 0
318
+
319
+
320
def sync_files(ctx):
    """Mirror remote inputs locally and keep the mirrors up to date.

    Each remote file is first copied with ``scp``; the process exits with
    status 1 if that fails. Then one daemon thread per file re-runs
    ``rsync -z --checksum`` (output discarded) every ``args.rsync_interval``
    seconds until ``ctx["alive"]`` is cleared.
    """
    args = ctx["args"]
    jobs = []
    for source in ctx["valid_files"]:
        if not is_remote(source):
            continue
        target = get_local_mirror(source)
        if subprocess.run(["scp", source, target]).returncode != 0:
            report(LogLevel.ERROR, f"scp transfer failed for {source}")
            sys.exit(1)
        jobs.append((source, target))

    def keep_in_sync(src, dst):
        while ctx["alive"]:
            subprocess.run(
                ["rsync", "-z", "--checksum", src, dst],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            time.sleep(args.rsync_interval)

    for job in jobs:
        threading.Thread(target=keep_in_sync, daemon=True, args=job).start()
344
+
345
+
346
def fontsize_to_y_units(ctx, fontsize):
    """Convert a font size in points into a height in Y data units of ``ctx["ax_plot"]``.

    Points become display pixels through the figure DPI (72 points per
    inch), and are then mapped back through the inverse data transform.
    """
    pixels = fontsize * ctx["fig"].dpi / 72
    to_data = ctx["ax_plot"].transData.inverted()
    bottom = to_data.transform((0, 0))[1]
    top = to_data.transform((0, pixels))[1]
    return top - bottom
356
+
357
+
358
def autospace_annotations(ctx, x_domain, ys, fontsize, padding_factor=1.10):
    """Shift overlapping annotation labels upwards, one X position at a time.

    Args:
        ctx: context used to convert ``fontsize`` into Y data units.
        x_domain: X positions to process.
        ys: mapping z -> mapping (e.g. a Series) x -> label y position.
        fontsize: label font size in points.
        padding_factor: vertical room reserved per label, relative to the
            text height (1.10 = 10% extra).

    Returns:
        dict z -> dict x -> adjusted y. For each X the labels are visited from
        the lowest y upwards, and each is raised just enough to clear the label
        below it. (z, x) pairs absent from ``ys`` (a Z group without data at
        that X), or whose value is NaN, are skipped instead of raising.
    """
    h = fontsize_to_y_units(ctx, fontsize) * padding_factor
    y_adjust = {z: {} for z in ys}
    for x in x_domain:
        column = [
            (z, ys[z][x]) for z in ys if x in ys[z] and not pd.isna(ys[z][x])
        ]
        lower_bound = -math.inf
        for z, y in sorted(column, key=lambda item: item[1]):
            # push the label up only if its box would overlap the previous one
            shift = max(0.0, lower_bound - (y - h / 2))
            y_adjust[z][x] = y + shift
            lower_bound = y + h / 2 + shift
    return y_adjust
378
+
379
+
380
def annotate(ctx, plot_type, sub_df, y_axis, palette):
    """Write aggregated Y values next to the points or bars of the current plot.

    Depending on ``ctx["args"]``, labels the maximum (``annotate_max``), the
    minimum (``annotate_min``) and/or every (``annotate``) value of each Z
    group. Values are aggregated per (Z, X) with the geometric mean
    (``--geomean``) or the median. Labels are de-overlapped with
    ``autospace_annotations`` and coloured with ``palette[z]``.

    Args:
        plot_type: "lines" (labels at the X data value) or "bars" (labels at
            the centre of the matching bar).
        sub_df: the data being plotted.
        y_axis: the Y column.
        palette: mapping Z value -> text colour.
    """
    args = ctx["args"]
    ax_plot = ctx["ax_plot"]

    if not (args.annotate_max or args.annotate_min or args.annotate):
        return

    annotation_kwargs = {
        "ha": "center",
        "va": "bottom",
        "color": "black",
        "fontsize": 12,
        "fontweight": "normal",
        "xytext": (0, 5),
        "textcoords": "offset points",
    }
    estimator = scipy.stats.gmean if args.geomean else np.median

    z_domain = sub_df[args.z].unique()
    x_domain = sub_df[args.x].unique()
    # aggregated value per X, for each Z group
    ys = {
        z: sub_df[sub_df[args.z] == z].groupby(args.x)[y_axis].apply(estimator)
        for z in z_domain
    }

    x_adjust = {z: {} for z in z_domain}
    y_adjust = autospace_annotations(ctx, x_domain, ys, annotation_kwargs["fontsize"])

    if plot_type == "lines":
        # points sit at their X data value
        for z in z_domain:
            for x in ys[z].index:
                x_adjust[z][x] = x
    elif plot_type == "bars":
        # Pair bar centres with (Z, X) in the order the bars are assumed to be
        # drawn: hue level by hue level, categories left to right. NaN-height
        # placeholder bars are skipped, together with the (Z, X) combinations
        # that have no data. Zero/negative bars are kept.
        centres = iter(
            [
                p.get_x() + p.get_width() / 2
                for p in ax_plot.patches
                if not np.isnan(p.get_height())
            ]
        )
        for z in _seaborn_level_order(sub_df[args.z]):
            for x in _seaborn_level_order(sub_df[args.x]):
                if x in ys[z].index:
                    x_adjust[z][x] = next(centres, None)

    for z in z_domain:
        values = ys[z].dropna()
        if values.empty:
            continue
        kwargs_z = dict(annotation_kwargs, color=palette[z])
        targets = []
        if args.annotate_max:
            targets.append(values.idxmax())
        if args.annotate_min:
            targets.append(values.idxmin())
        if args.annotate:
            targets.extend(values.index)
        for x in targets:
            xa = x_adjust[z].get(x)
            ya = y_adjust.get(z, {}).get(x)
            if xa is None or ya is None:
                continue
            ax_plot.annotate(f"{values.loc[x]:.2f}", (xa, ya), **kwargs_z)


def _seaborn_level_order(values):
    """Return the distinct non-null ``values``, sorted if all numeric, else in order of appearance.

    NOTE(review): intended to match seaborn's default ``order``/``hue_order``
    inference for categorical plots; verify against the installed seaborn.
    """
    levels = [v for v in pd.unique(values) if pd.notna(v)]
    if levels and all(pd.api.types.is_number(v) for v in levels):
        levels = sorted(levels)
    return levels
461
+
462
+
463
def to_engineering_si(x, precision=0, unit=None):
    """Format ``x`` with an SI prefix whose exponent is a multiple of three.

    >>> to_engineering_si(0.0012, unit="s")
    '1ms'
    >>> to_engineering_si(-2500, precision=1)
    '-2.5k'

    Args:
        x: value to format.
        precision: digits after the decimal point.
        unit: optional unit appended after the prefix.

    Returns:
        The formatted string. Zero keeps the unit but gets no prefix. NaN and
        infinities are returned as ``str(x)`` instead of raising. A
        coefficient that would round to 1000 (e.g. 999.6k) moves up to the
        next prefix (1M). Magnitudes outside the yocto..yotta range keep the
        extreme prefix.
    """
    unit = unit or ""
    if x == 0:
        return f"{0:.{precision}f}{unit}"
    if not math.isfinite(x):
        return f"{x}"
    si_prefixes = {
        -24: "y",
        -21: "z",
        -18: "a",
        -15: "f",
        -12: "p",
        -9: "n",
        -6: "µ",
        -3: "m",
        0: "",
        3: "k",
        6: "M",
        9: "G",
        12: "T",
        15: "P",
        18: "E",
        21: "Z",
        24: "Y",
    }
    exp = int(math.floor(math.log10(abs(x)) // 3 * 3))
    exp = max(min(exp, 24), -24)  # clamp to available prefixes
    coeff = x / (10**exp)
    # rounding may carry into the next group (999.6 -> "1000")
    if abs(round(coeff, precision)) >= 1000 and exp < 24:
        exp += 3
        coeff = x / (10**exp)
    prefix = si_prefixes.get(exp, f"e{exp:+03d}")
    return f"{coeff:.{precision}f}{prefix}{unit}"
491
+
492
+
493
def get_palette(values, colorblind=False):
    """Map each value in ``values`` to a plotting colour.

    With ``colorblind`` the seaborn "colorblind" palette is used. Otherwise
    the values take the hand-picked preferred colours in order. When there
    are more values than preferred colours, the remaining values get distinct
    colours from seaborn's "husl" palette (previously this raised
    StopIteration).
    """
    if colorblind:
        palette = sns.color_palette("colorblind", n_colors=len(values))
        return {v: palette[i] for i, v in enumerate(values)}
    preferred_colors = [
        "#5588dd",
        "#882255",
        "#33bb88",
        "#9624e1",
        "#BBBB41",
        "#ed5a15",
        "#aa44ff",
        "#448811",
        "#3fa7d6",
        "#e94f37",
        "#6cc551",
        "#dabef9",
    ]
    colors = list(preferred_colors)
    extra = len(values) - len(colors)
    if extra > 0:
        colors += list(sns.color_palette("husl", n_colors=extra))
    return dict(zip(values, colors))
514
+
515
+
516
def update_plot(ctx, padding_factor=1.05):
    """Redraw the main plot for the current configuration.

    Steps:
      - project ``ctx["df"]`` on the selected free-dimension values;
      - apply the requested normalization;
      - optionally append a "geomean" X group;
      - draw bars (or lines with ``--lines``) of ``ctx["y_axis"]`` over X,
        grouped by Z, with the spread;
      - set labels and ticks, then annotate.

    Args:
        ctx: application context.
        padding_factor: head-room above ``ctx["top"]`` when a fixed upper Y
            limit has been computed.
    """
    args = ctx["args"]
    df = ctx["df"]
    y_axis = ctx["y_axis"]
    ax_plot = ctx["ax_plot"]
    top = ctx.get("top", None)
    normalized = bool(args.x_norm or args.z_norm or args.ref_norm)

    config = get_current_config(ctx)
    sub_df = get_projection(df, config)

    ax_plot.clear()

    # figure title: all Y columns (the displayed one in bold) and the raw Y range
    y_range = "[{} - {}]".format(
        to_engineering_si(sub_df[y_axis].min(), unit=args.unit),
        to_engineering_si(sub_df[y_axis].max(), unit=args.unit),
    )
    title_parts = []
    for i, y in enumerate(args.y, start=1):
        if y == y_axis:
            # "_" would start a mathtext subscript
            bold = str(y).replace("_", r"\_")
            title_parts.append(rf"{i}: $\mathbf{{{bold}}}$")
        else:
            title_parts.append(f"{i}: {y}")
    ctx["fig"].suptitle(" | ".join(title_parts) + "\n" + y_range)

    if args.x_norm:
        sub_df = group_normalization("x", df, config, args, y_axis)
    elif args.z_norm:
        sub_df = group_normalization("z", df, config, args, y_axis)
    elif args.ref_norm:
        sub_df = ref_normalization(df, config, args, y_axis)

    if args.geomean:
        # extra X group "geomean" that contains every row
        gm_df = sub_df.copy()
        gm_df[args.x] = "geomean"
        sub_df = pd.concat([sub_df, gm_df])

    if normalized:
        # reference level
        ax_plot.axhline(y=1.0, linestyle="-", linewidth=4, color="lightgrey")

    def custom_error(data):
        d = pd.DataFrame(data)
        return (
            spread.lower(args.spread_measure)(d),
            spread.upper(args.spread_measure)(d),
        )

    palette = get_palette(ctx["z_dom"], colorblind=args.colorblind)

    if args.lines:
        sns.lineplot(
            data=sub_df,
            x=args.x,
            y=y_axis,
            hue=args.z,
            palette=palette,
            lw=2,
            linestyle="-",
            marker="o",
            errorbar=None,
            ax=ax_plot,
            estimator=np.median,
        )
        if args.spread_measure != "none":
            spread.draw(
                ax_plot,
                [args.spread_measure],
                sub_df,
                x=args.x,
                y=y_axis,
                z=args.z,
                palette=palette,
            )
    else:
        sns.barplot(
            data=sub_df,
            ax=ax_plot,
            estimator=scipy.stats.gmean if args.geomean else np.median,
            palette=palette,
            legend=True,
            x=args.x,
            y=y_axis,
            hue=args.z,
            errorbar=custom_error if args.spread_measure != "none" else None,
            alpha=0.6,
            err_kws={
                "color": "black",
                "alpha": 1.0,
                "linewidth": 2.0,
                "solid_capstyle": "round",
                "solid_joinstyle": "round",
            },
        )

    if args.geomean:
        # Separator between the last regular X group and the geomean group.
        # Assumes the geomean group has one bar per Z value present in this
        # view. Drawn on ax_plot explicitly: plt.axvline would target pyplot's
        # current axes, which need not be ax_plot.
        bars = sorted(ax_plot.patches, key=lambda p: p.get_x())
        n_geo = sub_df[args.z].nunique()
        if 0 < n_geo < len(bars):
            first_geo = bars[-n_geo]
            last_regular = bars[-n_geo - 1]
            x_sep = (
                first_geo.get_x() + last_regular.get_x() + last_regular.get_width()
            ) / 2
            ax_plot.axvline(x=x_sep, color="grey", linewidth=1, linestyle="-")

    if top is not None:
        ax_plot.set_ylim(top=top * padding_factor, bottom=0.0)

    if normalized:
        suffix = "gain" if args.norm_reverse else "normalized"
        # normalized values are unit-less ratios
        ax_plot.set_ylabel(f"{y_axis} ({suffix})")

        from matplotlib.ticker import FuncFormatter

        ax_plot.yaxis.set_major_formatter(FuncFormatter(lambda v, pos: f"{v:.2f}x"))
        # Always show a tick at the reference level 1.0. Only ticks inside the
        # current limits are kept, because set_yticks widens the view to fit
        # every given tick, which would override the limits set above.
        low, high = ax_plot.get_ylim()
        ticks = {t for t in ax_plot.get_yticks() if low <= t <= high}
        ticks.add(1.0)
        ax_plot.set_yticks(sorted(ticks))
    elif args.unit is None:
        ax_plot.set_ylabel(y_axis)
    else:
        ax_plot.set_ylabel(f"{y_axis} [{args.unit}]")

    annotate(ctx, "lines" if args.lines else "bars", sub_df, y_axis, palette)

    ctx["fig"].canvas.draw_idle()
659
+
660
+
661
def get_config_name(ctx):
    """Build the default output file stem for the current view.

    The stem joins, with "_":
      - the Y column;
      - "gain" or "normalized" when a normalization is active;
      - the selected value of every free dimension.

    Path separators inside the parts are replaced by "-", so the name always
    refers to a file in the current directory.
    """
    args = ctx["args"]
    parts = [f"{ctx['y_axis']}"]
    if args.ref_norm or args.x_norm or args.z_norm:
        parts.append("gain" if args.norm_reverse else "normalized")
    domains = ctx["domains"]
    position = ctx["position"]
    parts += [str(domains[d][position[d]]) for d in ctx["free_dims"]]
    name = "_".join(parts)
    for separator in ("/", "\\"):
        name = name.replace(separator, "-")
    return name
675
+
676
+
677
def get_status_description(ctx):
    """Describe the current view as ``"dim=value | dim=value | ..."``.

    One entry per free dimension. When there is a single Z value,
    ``" | z=value"`` is appended as well.
    """
    args = ctx["args"]
    domains = ctx["domains"]
    description = " | ".join(
        f"{dim}={domains[dim][ctx['position'][dim]]}" for dim in ctx["free_dims"]
    )
    if ctx["z_size"] == 1:
        only_z = ctx["df"][args.z].unique()[0]
        description += f" | {args.z}={only_z}"
    return description
693
+
694
+
695
def save_to_file(ctx, outfile=None):
    """Save the current plot to ``outfile`` (default: ``get_config_name(ctx) + ".pdf"``).

    For the export:
      - the figure title is replaced by the bold Y column, the normalization
        in use and the current configuration;
      - the legend is hidden when there is a single Z value;
      - the saved area is the plot axes' extent grown by 20%.

    The interactive view is redrawn afterwards, even if saving fails, so
    these export-only changes do not stay on screen.
    """
    ax_plot = ctx["ax_plot"]
    args = ctx["args"]
    fig = ctx["fig"]
    outfile = outfile or get_config_name(ctx) + ".pdf"
    if ctx["z_size"] == 1:
        legend = ax_plot.get_legend()
        if legend:
            legend.set_visible(False)

    # "_" would start a mathtext subscript
    name = str(ctx["y_axis"]).replace("_", r"\_")
    title = rf"$\mathbf{{{name}}}$"
    s = "gain" if args.norm_reverse else "normalized"
    for norm_pairs in (args.ref_norm, args.x_norm, args.z_norm):
        if norm_pairs:
            title += f" ({s} w.r.t {' | '.join(norm_pairs)})"
            break
    title += "\n" + get_status_description(ctx)
    fig.suptitle(title)

    extent = ax_plot.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    try:
        fig.savefig(outfile, bbox_inches=extent.expanded(1.2, 1.2))
    finally:
        # restore the interactive title and legend
        update_plot(ctx)
    report(LogLevel.INFO, f"saved to '{outfile}'")
725
+
726
+
727
def on_key(event, ctx):
    """Handle a key press in the plot window.

    Key bindings:
      - enter / space / up: next value of the selected free dimension;
      - down: previous value (both wrap around);
      - left / right: select the previous / next free dimension;
      - 1-9: show the N-th Y column (if it exists);
      - ".": save the current plot to a file.

    Any other key is ignored, including a ``None`` or empty key. The old
    substring test ``key in "123456789"`` raised TypeError / ValueError for
    those.
    """
    key = event.key
    selected_index = ctx["selected_index"]
    free_dims = ctx["free_dims"]
    domains = ctx["domains"]
    position = ctx["position"]
    y_dims = ctx["y_dims"]

    if key in ("enter", " ", "up", "down"):
        if selected_index is None:
            return
        step = -1 if key == "down" else 1
        dim = free_dims[selected_index]
        position[dim] = (position[dim] + step) % domains[dim].size
        update_plot(ctx)
        update_table(ctx)
    elif key in ("left", "right"):
        if selected_index is None:
            return
        step = -1 if key == "left" else 1
        ctx["selected_index"] = (selected_index + step) % len(free_dims)
        update_table(ctx)
    elif key in tuple("123456789"):
        new_idx = int(key) - 1
        if new_idx < len(y_dims):
            ctx["y_axis"] = y_dims[new_idx]
            compute_ylimits(ctx)
            update_plot(ctx)
    elif key == ".":
        save_to_file(ctx)
760
+
761
+
762
def on_close(event, ctx):
    """Figure-close callback: clear ``ctx["alive"]`` so the polling loops that check it stop."""
    ctx.update(alive=False)
764
+
765
+
766
def compute_missing(ctx):
    """List the combinations of non-Y values that have no measurement.

    The expected space is the Cartesian product of the distinct values of
    every column not in ``ctx["y_dims"]``. Missing combinations are returned
    in product order (each column's values in order of first appearance)
    rather than in arbitrary set order, so the report is reproducible.

    Returns:
        DataFrame with one row per missing combination and one column per
        space dimension.
    """
    df = ctx["df"]
    space_columns = df.columns.difference(ctx["y_dims"])
    observed = set(map(tuple, df[space_columns].drop_duplicates().values))
    domains = [df[col].unique() for col in space_columns]
    missing = [
        combo for combo in itertools.product(*domains) if combo not in observed
    ]
    return pd.DataFrame(missing, columns=space_columns)
774
+
775
+
776
def validate_dimensions(ctx, dims):
    """Report a fatal error for every name in ``dims`` that is not a column of ``ctx["df"]``.

    The hint lists the available columns.
    """
    df = ctx["df"]
    for col in dims:
        if col in df.columns:
            continue
        available = list(df.columns)
        hint = "available columns: {}".format(", ".join(available))
        report(LogLevel.FATAL, "invalid column", col, hint=hint)
784
+
785
+
786
def validate_args(ctx):
    """Check the CLI arguments against the loaded data and fill in defaults.

    What it does:
      - validates the X, Y and Z dimensions;
      - when omitted, chooses Y = every numeric column except X, and Z = the
        remaining column with the fewest distinct values;
      - validates the normalization options and the spread measure;
      - stores ``ctx["y_dims"]`` (all Y columns) and ``ctx["y_axis"]`` (the
        one shown first).

    Problems are reported through ``report``: FATAL for invalid combinations,
    WARNING for suspicious ones. ``args`` is updated in place.
    """
    args = ctx["args"]
    df = ctx["df"]

    validate_dimensions(ctx, [args.x])

    # --- Y-axis ---------------------------------------------------------
    numeric_cols = (
        df.drop(columns=[args.x]).select_dtypes(include=[np.number]).columns.tolist()
    )
    if args.y is None:
        # default: every numeric column except X
        if not numeric_cols:
            report(
                LogLevel.FATAL,
                "No numeric columns found in the data",
                hint="use -y to specify a Y-axis",
            )
        report(LogLevel.INFO, "Using '{}' as Y-axis".format(", ".join(numeric_cols)))
        args.y = numeric_cols
    validate_dimensions(ctx, args.y)
    for y in args.y:
        if pd.api.types.is_numeric_dtype(df[y]):
            continue
        if numeric_cols:
            hint = "try {}".format(", ".join(map(str, numeric_cols)))
        else:
            hint = "use -y to specify a Y-axis"
        report(
            LogLevel.FATAL,
            f"Y-axis must have a numeric type. '{y}' has type '{df[y].dtype}'",
            hint=hint,
        )

    if args.x in args.y:
        report(LogLevel.FATAL, "X-axis and Y-axis must be different dimensions")

    # --- Z-axis ---------------------------------------------------------
    # X and Z each need a non-Y column
    if len(df.columns.difference(args.y)) < 2:
        report(
            LogLevel.FATAL,
            "there must be at least two dimensions other than the Y-axis",
        )
    if args.z is None:
        # default: the remaining column with the fewest distinct values
        candidates = df.columns.difference([args.x] + args.y)
        args.z = candidates[np.argmin([df[col].nunique() for col in candidates])]
        report(LogLevel.INFO, "Using '{}' as Z-axis".format(args.z))
    else:
        validate_dimensions(ctx, [args.z])
    zdom = df[args.z].unique()
    if len(zdom) == 1 and args.geomean:
        report(
            LogLevel.WARNING,
            "--geomean is superfluous because '{}' is the only value in the '{}' group".format(
                zdom[0], args.z
            ),
        )

    # --- all axes -------------------------------------------------------
    if args.x == args.z or args.z in args.y:
        report(
            LogLevel.FATAL,
            "the -z dimension must be different from the dimension used on the X or Y axis",
        )

    if args.geomean and args.lines:
        report(LogLevel.FATAL, "--geomean and --lines cannot be used together")
    for d in df.columns.difference(args.y):
        n = df[d].nunique()
        if n > 100 and pd.api.types.is_numeric_dtype(df[d]):
            report(
                LogLevel.WARNING,
                f"'{d}' seems to have many ({n}) numeric values. Are you sure this is not supposed to be the Y-axis?",
            )

    # --- normalization --------------------------------------------------
    def pair_keys(norm_args):
        """Report malformed ``key=value`` pairs and return the set of keys."""
        for arg in norm_args:
            if "=" not in arg:
                report(
                    LogLevel.FATAL,
                    f"invalid normalization argument '{arg}', expected format 'key=value'",
                )
        # partition never raises, even if report() returned for a bad pair
        return {arg.partition("=")[0] for arg in norm_args}

    if (
        (args.x_norm and args.z_norm)
        or (args.x_norm and args.ref_norm)
        or (args.z_norm and args.ref_norm)
    ):
        report(
            LogLevel.FATAL,
            "only one normalization method can be used at a time: --x-norm, --z-norm, or --ref-norm",
        )
    if args.ref_norm:
        keys = pair_keys(args.ref_norm)
        if args.x not in keys or args.z not in keys:
            hint = "try adding '{}=<value>' or '{}=<value>' to --ref-norm".format(
                args.x, args.z
            )
            report(
                LogLevel.FATAL,
                "--ref-norm pairs must include both the X-axis and Z-axis dimensions",
                hint=hint,
            )
    elif args.x_norm:
        keys = pair_keys(args.x_norm)
        if args.z not in keys:
            hint = "try adding '{}=<value>' to --x-norm".format(args.z)
            report(
                LogLevel.FATAL,
                "--x-norm pairs must include the Z-axis dimension",
                hint=hint,
            )
        if args.x in keys:
            hint = "try removing '{}=<value>' from --x-norm".format(args.x)
            report(
                LogLevel.FATAL,
                "--x-norm pairs must not include the X-axis dimension",
                hint=hint,
            )
    elif args.z_norm:
        keys = pair_keys(args.z_norm)
        if args.x not in keys:
            hint = "try adding '{}=<value>' to --z-norm".format(args.x)
            report(
                LogLevel.FATAL,
                "--z-norm pairs must include the X-axis dimension",
                hint=hint,
            )
        if args.z in keys:
            hint = "try removing '{}=<value>' from --z-norm".format(args.z)
            report(
                LogLevel.FATAL,
                "--z-norm pairs must not include the Z-axis dimension",
                hint=hint,
            )
    if not (args.x_norm or args.z_norm or args.ref_norm) and args.norm_reverse:
        report(
            LogLevel.WARNING,
            "--norm-reverse is ignored because no normalization is applied",
        )

    if args.spread_measure != "none" and not spread.assert_validity(args.spread_measure):
        args.spread_measure = "none"

    ctx["y_dims"] = args.y
    ctx["y_axis"] = args.y[0]

    if args.show_missing:
        missing = compute_missing(ctx)
        if len(missing) > 0:
            report(LogLevel.WARNING, "missing experiments:")
            report(LogLevel.WARNING, "\n" + missing.to_string(index=False))
            report(LogLevel.WARNING, "")
952
+
953
+
954
def start_gui(ctx):
    """Draw the initial view, start the file-monitor thread, then call ``plt.show()``.

    ``ctx["alive"]`` is set first; the daemon monitor thread polls while it
    stays true.
    """
    ctx["alive"] = True

    update_plot(ctx)
    update_table(ctx)
    monitor = threading.Thread(target=file_monitor, daemon=True, args=(ctx,))
    monitor.start()
    report(LogLevel.INFO, "application running")
    time.sleep(1.0)  # wait for the GUI to initialize
    plt.show()
963
+
964
+
965
def compute_ylimits(ctx):
    """Compute a fixed upper Y limit shared by every configuration.

    Sets ``ctx["top"]`` to:
      - ``None`` when there are no free dimensions (the plot autoscales);
      - without normalization: the maximum of ``ctx["y_axis"]`` over the
        whole data set;
      - with normalization: the largest normalized value over every
        combination of free-dimension values, or its spread upper bound when
        a spread measure is active.

    Combinations whose normalization raises IndexError (no reference data)
    or yields no rows are skipped, instead of aborting start-up.
    """
    args = ctx["args"]
    free_dims = ctx["free_dims"]
    df = ctx["df"]
    y_axis = ctx["y_axis"]
    domains = ctx["domains"]
    if len(free_dims) == 0:
        ctx["top"] = None
        return
    if not (args.x_norm or args.z_norm or args.ref_norm):
        ctx["top"] = df[y_axis].max()
        return

    free_domains = {k: v for k, v in domains.items() if k in free_dims}
    top = 0
    for point in itertools.product(*free_domains.values()):
        config = get_config(point, free_domains.keys())
        try:
            if args.ref_norm:
                df_config = ref_normalization(df, config, args, y_axis)
            elif args.x_norm:
                df_config = group_normalization("x", df, config, args, y_axis)
            else:
                df_config = group_normalization("z", df, config, args, y_axis)
        except IndexError:
            # no reference measurement for this combination
            continue
        if df_config.empty:
            continue
        per_group = df_config.groupby([args.z, args.x])[y_axis]
        if args.spread_measure != "none":
            t = per_group.apply(spread.upper(args.spread_measure))
        else:
            t = per_group.max()
        top = max(top, t.max())
    ctx["top"] = top
996
+
997
+
998
def launch(args):
    """Run the interactive ``yuclid plot`` application for parsed ``args``.

    Creates the context, then runs each pipeline stage in order:
      - validate and locate the input files, start mirroring remote ones;
      - load the data and validate the arguments;
      - rescale, build the navigation space and the Y limits;
      - create the figure and enter the GUI.
    """
    ctx = {"args": args, "alive": True}
    pipeline = (
        validate_files,
        locate_files,
        sync_files,
        generate_dataframe,
        validate_args,
        rescale,
        generate_space,
        compute_ylimits,
        initialize_figure,
        start_gui,
    )
    for stage in pipeline:
        stage(ctx)