yuclid 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yuclid/__init__.py +1 -0
- yuclid/cli.py +229 -0
- yuclid/log.py +56 -0
- yuclid/plot.py +1009 -0
- yuclid/run.py +1239 -0
- yuclid/spread.py +152 -0
- yuclid-0.1.0.dist-info/METADATA +15 -0
- yuclid-0.1.0.dist-info/RECORD +11 -0
- yuclid-0.1.0.dist-info/WHEEL +5 -0
- yuclid-0.1.0.dist-info/entry_points.txt +2 -0
- yuclid-0.1.0.dist-info/top_level.txt +1 -0
yuclid/plot.py
ADDED
@@ -0,0 +1,1009 @@
|
|
1
|
+
from yuclid.log import report, LogLevel
|
2
|
+
import yuclid.cli
|
3
|
+
import matplotlib.gridspec as gridspec
|
4
|
+
import matplotlib.lines as mlines
|
5
|
+
import matplotlib.pyplot as plt
|
6
|
+
import yuclid.spread as spread
|
7
|
+
import seaborn as sns
|
8
|
+
import pandas as pd
|
9
|
+
import numpy as np
|
10
|
+
import scipy.stats
|
11
|
+
import subprocess
|
12
|
+
import threading
|
13
|
+
import itertools
|
14
|
+
import pathlib
|
15
|
+
import hashlib
|
16
|
+
import time
|
17
|
+
import math
|
18
|
+
import sys
|
19
|
+
|
20
|
+
|
21
|
+
def get_current_config(ctx):
    """Return the currently selected value of every free dimension.

    Reads ``ctx["free_dims"]``, ``ctx["domains"]`` and ``ctx["position"]``
    and returns ``{dim: domains[dim][position[dim]]}`` for each free
    dimension, in the order of ``ctx["free_dims"]``.

    ``ctx["df"]`` is intentionally not read: the selection only depends on
    the domains and positions, so this also works before a dataframe is set.
    """
    domains = ctx["domains"]
    position = ctx["position"]
    return {d: domains[d][position[d]] for d in ctx["free_dims"]}
|
31
|
+
|
32
|
+
|
33
|
+
def get_config(point, keys):
    """Map each key to the coordinate of ``point`` at the same index.

    Keys whose index is past the end of ``point`` are mapped to ``None``.
    """
    available = len(point)
    return {
        key: (point[idx] if idx < available else None)
        for idx, key in enumerate(keys)
    }
|
41
|
+
|
42
|
+
|
43
|
+
def get_projection(df, config):
    """Select the rows of ``df`` whose columns equal the values in ``config``.

    An empty ``config`` returns ``df`` itself (not a copy); otherwise a copy
    of the matching rows is returned.
    """
    if not config:
        return df
    columns = list(config)
    wanted = pd.Series(config)
    # Row-wise: every constrained column must match its requested value.
    row_matches = df[columns].eq(wanted).all(axis=1)
    return df[row_matches].copy()
|
49
|
+
|
50
|
+
|
51
|
+
def group_normalization(norm_axis, df, config, args, y_axis):
    """Normalize ``y_axis`` of the selected rows against a per-group reference.

    The selected rows are ``get_projection(df, config)``. The reference rows
    are those matching ``config`` overridden by the ``key=value`` pairs of
    ``args.x_norm`` (``norm_axis == "x"``) or ``args.z_norm``
    (``norm_axis == "z"``). The reference is aggregated per group of all
    non-Y columns with the geometric mean when ``args.geomean`` is set and
    the median otherwise. Each selected row is divided by the reference value
    sharing its ``args.x`` (resp. ``args.z``) value; with
    ``args.norm_reverse`` the ratio is inverted.

    Returns a new dataframe; ``df`` is never modified.

    Raises:
        ValueError: if ``norm_axis`` is neither ``"x"`` nor ``"z"``.
    """
    if norm_axis == "x":
        pairs, axis_col = args.x_norm, args.x
    elif norm_axis == "z":
        pairs, axis_col = args.z_norm, args.z
    else:
        raise ValueError(f"norm_axis must be 'x' or 'z', got {norm_axis!r}")

    sub_df = get_projection(df, config)
    if sub_df is df:
        # get_projection hands back df itself for an empty config; copy so the
        # assignment below cannot overwrite (and re-normalize) the caller's data.
        sub_df = df.copy()

    ref_config = dict(config)
    # maxsplit=1 keeps values that themselves contain '='.
    ref_config.update(pair.split("=", 1) for pair in pairs)

    # Command-line values arrive as strings: cast them to each column's dtype.
    ref_config = {k: df[k].dtype.type(v) for k, v in ref_config.items()}

    ref_df = get_projection(df, ref_config)
    estimator = scipy.stats.gmean if args.geomean else np.median
    gb_cols = df.columns.difference(args.y).tolist()
    ref = ref_df.groupby(gb_cols)[y_axis].apply(estimator).reset_index()

    def y_ref_at(value):
        # First aggregated reference whose axis value matches this row's.
        return ref[ref[axis_col] == value][y_axis].values[0]

    y_ref = sub_df[axis_col].map(y_ref_at)
    if args.norm_reverse:
        sub_df[y_axis] = y_ref / sub_df[y_axis]
    else:
        sub_df[y_axis] = sub_df[y_axis] / y_ref
    return sub_df
|
79
|
+
|
80
|
+
|
81
|
+
def ref_normalization(df, config, args, y_axis):
    """Normalize ``y_axis`` of the selected rows against one reference value.

    The selected rows are ``get_projection(df, config)``. The reference rows
    are those matching ``config`` overridden by the ``key=value`` pairs of
    ``args.ref_norm``; they are aggregated per group of all non-Y columns
    (geometric mean when ``args.geomean`` is set, median otherwise) and the
    first aggregated value is used as the divisor. With
    ``args.norm_reverse`` the ratio is inverted.

    Returns a new dataframe; ``df`` is never modified.
    """
    sub_df = get_projection(df, config)
    if sub_df is df:
        # get_projection hands back df itself for an empty config; copy so the
        # assignment below cannot overwrite (and re-normalize) the caller's data.
        sub_df = df.copy()

    ref_config = dict(config)
    # maxsplit=1 keeps values that themselves contain '='.
    ref_config.update(pair.split("=", 1) for pair in args.ref_norm)

    # Command-line values arrive as strings: cast them to each column's dtype.
    ref_config = {k: df[k].dtype.type(v) for k, v in ref_config.items()}

    ref_df = get_projection(df, ref_config)
    estimator = scipy.stats.gmean if args.geomean else np.median
    gb_cols = df.columns.difference(args.y).tolist()
    ref = ref_df.groupby(gb_cols)[y_axis].apply(estimator).reset_index()
    y_ref = ref[y_axis].values[0]
    if args.norm_reverse:
        sub_df[y_axis] = y_ref / sub_df[y_axis]
    else:
        sub_df[y_axis] = sub_df[y_axis] / y_ref
    return sub_df
|
101
|
+
|
102
|
+
|
103
|
+
def validate_files(ctx):
    """Keep the input files with a supported extension.

    Files from ``ctx["args"].files`` ending in ``.json`` or ``.csv`` are
    stored, in order, in ``ctx["valid_files"]``; an error is reported for
    every other file.
    """
    supported = [".json", ".csv"]
    accepted = []
    for candidate in ctx["args"].files:
        if pathlib.Path(candidate).suffix not in supported:
            report(LogLevel.ERROR, f"unsupported file format {candidate}")
            continue
        accepted.append(candidate)
    ctx["valid_files"] = accepted
|
113
|
+
|
114
|
+
|
115
|
+
def get_local_mirror(rfile):
    """Return the local file name used to mirror a remote file spec.

    ``rfile`` is split at its first ``:`` and the file name of the part after
    it is returned, so ``user@host:/data/run.csv`` maps to ``run.csv``.
    Only the first colon separates host from path, so colons inside the
    remote path are preserved. A spec without any colon falls back to the
    file name of the whole string instead of raising ``IndexError``.
    """
    _, separator, remote_path = rfile.partition(":")
    if not separator:
        remote_path = rfile
    return pathlib.Path(remote_path).name
|
117
|
+
|
118
|
+
|
119
|
+
def locate_files(ctx):
    """Store in ``ctx["local_files"]`` the local path of every valid file.

    Remote files (per ``is_remote``) are replaced by their local mirror name;
    other files are kept as given.
    """
    ctx["local_files"] = [
        get_local_mirror(entry) if is_remote(entry) else entry
        for entry in ctx["valid_files"]
    ]
|
128
|
+
|
129
|
+
|
130
|
+
def set_axes_style(ctx):
    """Resize ``ctx["fig"]`` to 12x10 inches and apply the seaborn whitegrid theme."""
    ctx["fig"].set_size_inches(12, 10)
    sns.set_theme(style="whitegrid")
|
134
|
+
|
135
|
+
|
136
|
+
def initialize_figure(ctx):
    """Create the two-row figure (plot above, table below) and wire its events.

    Stores the figure in ``ctx["fig"]`` and the axes in ``ctx["ax_plot"]`` and
    ``ctx["ax_table"]``, draws a light grey separator line just above the
    table axes, and connects the key-press and close handlers.
    """
    fig, (ax_plot, ax_table) = plt.subplots(
        2, 1, gridspec_kw={"height_ratios": [20, 1]}
    )
    ctx["fig"] = fig
    ax_plot.grid(axis="y")
    set_axes_style(ctx)

    # The separator height is taken from the table position before the
    # subplot margins are adjusted below.
    separator_y = ax_table.get_position().y1 + 0.03
    separator = mlines.Line2D(
        [0.05, 0.95],
        [separator_y, separator_y],
        linewidth=4,
        transform=fig.transFigure,
        color="lightgrey",
    )
    fig.add_artist(separator)
    fig.subplots_adjust(top=0.92, bottom=0.1, hspace=0.3)

    def handle_key(event):
        return on_key(event, ctx)

    def handle_close(event):
        return on_close(event, ctx)

    fig.canvas.mpl_connect("key_press_event", handle_key)
    fig.canvas.mpl_connect("close_event", handle_close)
    ctx["ax_plot"] = ax_plot
    ctx["ax_table"] = ax_table
|
153
|
+
|
154
|
+
|
155
|
+
def generate_dataframe(ctx):
    """Load ``ctx["local_files"]`` into one dataframe stored in ``ctx["df"]``.

    ``.json`` files are read as JSON lines and ``.csv`` files as CSV; a file
    that cannot be read is reported as an error and skipped. When
    ``args.no_merge_inputs`` is set the file stem is kept in a ``file``
    column, otherwise the inputs are concatenated with a fresh index.

    ``args.filter`` is an optional list of ``column=v1,v2,...`` pairs; only
    rows whose column value is one of the listed values are kept. Values are
    cast to the column dtype before comparison.

    If no file could be loaded the process exits with status 1. If no rows
    remain, a fatal error is reported, ``ctx["alive"]`` is cleared and
    ``ctx["df"]`` is left untouched.
    """
    args = ctx["args"]
    local_files = ctx["local_files"]
    dfs = dict()
    for file in local_files:
        file = pathlib.Path(file)
        try:
            if file.suffix == ".json":
                dfs[file.stem] = pd.read_json(file, lines=True)
            elif file.suffix == ".csv":
                dfs[file.stem] = pd.read_csv(file)
        except Exception:
            # Not a bare ``except``: Ctrl-C and SystemExit must still propagate.
            report(LogLevel.ERROR, f"could not open {file}")

    if len(dfs) == 0:
        report(LogLevel.ERROR, "no valid source of data")
        ctx["alive"] = False
        sys.exit(1)

    df = pd.concat(dfs)

    if args.no_merge_inputs:
        df = df.reset_index(level=0, names=["file"])
    else:
        df = df.reset_index(drop=True)

    if args.filter is None:
        user_filter = dict()
    else:
        # maxsplit=1 keeps values that themselves contain '='.
        user_filter = dict(pair.split("=", 1) for pair in args.filter)
        for k, v_list in user_filter.items():
            v_list = v_list.split(",")
            # Command-line values are strings: cast to the column dtype.
            user_filter[k] = [df[k].dtype.type(v) for v in v_list]

    if user_filter:
        user_filter_mask = np.ones(len(df), dtype=bool)
        for k, v_list in user_filter.items():
            user_filter_mask &= df[k].isin(v_list)
        df = df[user_filter_mask]

    if len(df) == 0:
        if args.filter:
            report(LogLevel.FATAL, "no valid data after filtering")
        else:
            report(LogLevel.FATAL, "no valid data found in the files")
        ctx["alive"] = False
        return

    ctx["df"] = df
|
204
|
+
|
205
|
+
|
206
|
+
def rescale(ctx):
    """Multiply every Y column of ``ctx["df"]`` by ``args.rescale``.

    The Y columns are the names in ``args.y``; the dataframe is updated
    column by column and stored back in ``ctx["df"]``.
    """
    options = ctx["args"]
    frame = ctx["df"]
    for column in options.y:
        frame[column] = frame[column] * options.rescale
    ctx["df"] = frame
|
212
|
+
|
213
|
+
|
214
|
+
def draw(fig, ax, cli_args):
    """Render a plot onto caller-supplied ``fig``/``ax`` from CLI-style arguments.

    ``cli_args`` is a list of arguments for the ``plot`` sub-command; it is
    parsed with the yuclid CLI parser and the non-interactive part of the
    pipeline is run: file validation and location, styling, dataframe
    loading, argument validation, space generation and plotting.
    """
    args = yuclid.cli.get_parser().parse_args(["plot"] + cli_args)
    ctx = {"args": args, "fig": fig, "ax_plot": ax}
    yuclid.log.init(ignore_errors=args.ignore_errors)
    pipeline = (
        validate_files,
        locate_files,
        set_axes_style,
        generate_dataframe,
        validate_args,
        generate_space,
        update_plot,
    )
    for stage in pipeline:
        stage(ctx)
|
229
|
+
|
230
|
+
|
231
|
+
def generate_space(ctx):
    """Derive the navigable space from ``ctx["df"]`` and store it in ``ctx``.

    Sets:
      * ``z_size``: number of distinct values of the Z column;
      * ``free_dims``: columns that are neither X, Z nor a Y column;
      * ``selected_index``: 0 when there is a free dimension, else ``None``;
      * ``domains``: distinct values of every column;
      * ``position``: current index into each domain, all reset to 0;
      * ``z_dom``: distinct values of the Z column.
    """
    args = ctx["args"]
    df = ctx["df"]
    z_column = df[args.z]
    free_dims = list(df.columns.difference([args.x, args.z] + args.y))
    ctx.update(
        z_size=z_column.nunique(),
        free_dims=free_dims,
        selected_index=0 if free_dims else None,
        domains={col: df[col].unique() for col in df.columns},
        position={col: 0 for col in df.columns},
        z_dom=z_column.unique(),
    )
|
255
|
+
|
256
|
+
|
257
|
+
def file_monitor(ctx):
    """Poll the local input files once per second and refresh on change.

    While ``ctx["alive"]`` is true, the MD5 digests of all
    ``ctx["local_files"]`` are concatenated into a fingerprint (``None`` if a
    file is missing). Whenever the fingerprint differs from the previous one
    the dataframe, scaling, space and Y limits are recomputed, the space
    sizes and the number of missing experiments are reported, and the table
    and plot are redrawn.
    """

    def fingerprint():
        digests = []
        for path in ctx["local_files"]:
            with open(path, "rb") as handle:
                digests.append(hashlib.md5(handle.read()).hexdigest())
        return "".join(digests)

    previous = None
    while ctx["alive"]:
        try:
            current = fingerprint()
        except FileNotFoundError:
            current = None

        if current != previous:
            generate_dataframe(ctx)
            rescale(ctx)
            generate_space(ctx)
            compute_ylimits(ctx)
            frame = ctx["df"]
            space_columns = frame.columns.difference([ctx["y_axis"]])
            sizes = [f"{dim}={frame[dim].nunique()}" for dim in space_columns]
            missing = compute_missing(ctx)
            report(LogLevel.INFO, "space sizes", " | ".join(sizes))
            if len(missing) > 0:
                report(LogLevel.WARNING, f"at least {len(missing)} missing experiments")
            update_table(ctx)
            update_plot(ctx)
            previous = current

        time.sleep(1)
|
283
|
+
|
284
|
+
|
285
|
+
def update_table(ctx):
    """Redraw the table listing every free dimension and its current value.

    The table axes are cleared and hidden first; nothing more is drawn when
    there are no free dimensions. The selected dimension is marked with
    up/down arrows in the third row.
    """
    table_ax = ctx["ax_table"]
    free_dims = ctx["free_dims"]
    domains = ctx["domains"]
    position = ctx["position"]
    selected_index = ctx["selected_index"]

    table_ax.clear()
    table_ax.axis("off")
    if not free_dims:
        return

    both_arrows = "\u2191" + "\u2193"
    header, current, markers = [], [], []
    for dim in free_dims:
        value = domains[dim][position[dim]]
        is_selected = dim == free_dims[selected_index]
        header.append(rf"$\mathbf{{{dim}}}$")
        # Only the selected entry is converted to text; others are passed as-is.
        current.append(f"{value}" if is_selected else value)
        markers.append(both_arrows if is_selected else "")

    table_ax.table(
        cellText=[header, current, markers],
        cellLoc="center",
        edges="open",
        loc="center",
    )
    ctx["fig"].canvas.draw_idle()
|
314
|
+
|
315
|
+
|
316
|
+
def is_remote(file):
    """Tell whether ``file`` is a remote spec, i.e. contains an ``@``."""
    has_at_sign = "@" in file
    return has_at_sign
|
318
|
+
|
319
|
+
|
320
|
+
def sync_files(ctx):
    """Fetch remote input files and keep their local mirrors up to date.

    Every remote file in ``ctx["valid_files"]`` is first copied with ``scp``
    to its local mirror; a failed copy is reported and exits with status 1.
    Then one daemon thread per remote file re-runs ``rsync`` every
    ``args.rsync_interval`` seconds for as long as ``ctx["alive"]`` is true.
    """
    args = ctx["args"]
    transfers = []
    for source in ctx["valid_files"]:
        if not is_remote(source):
            continue
        target = get_local_mirror(source)
        completed = subprocess.run(["scp", source, target])
        if completed.returncode != 0:
            report(LogLevel.ERROR, f"scp transfer failed for {source}")
            sys.exit(1)
        transfers.append((source, target))

    def keep_in_sync(src, dst):
        command = ["rsync", "-z", "--checksum", src, dst]
        while ctx["alive"]:
            subprocess.run(
                command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )
            time.sleep(args.rsync_interval)

    for transfer in transfers:
        worker = threading.Thread(target=keep_in_sync, daemon=True, args=transfer)
        worker.start()
|
344
|
+
|
345
|
+
|
346
|
+
def fontsize_to_y_units(ctx, fontsize):
    """Convert a font size in points to a height in data units of the plot's Y axis.

    The size is converted to pixels with the figure DPI (72 points per inch)
    and mapped through the inverse of the plot axes' data transform.
    """
    pixels = fontsize * ctx["fig"].dpi / 72
    to_data = ctx["ax_plot"].transData.inverted()
    _, base = to_data.transform((0, 0))
    _, raised = to_data.transform((0, pixels))
    return raised - base
|
356
|
+
|
357
|
+
|
358
|
+
def autospace_annotations(ctx, x_domain, ys, fontsize, padding_factor=1.10):
    """Compute non-overlapping Y positions for the annotations at each X.

    ``ys`` maps each Z value to a mapping from X to Y. For every X the labels
    are visited from lowest to highest Y; a label whose box (text height
    times ``padding_factor``, centred on Y) would start below the top of the
    previous box is shifted up just enough to clear it.

    Returns ``{z: {x: adjusted_y}}``.
    """
    box_height = fontsize_to_y_units(ctx, fontsize) * padding_factor

    adjusted = {z: dict() for z in ys}
    for x in x_domain:
        candidates = [(z, ys[z][x]) for z in ys]
        candidates.sort(key=lambda pair: pair[1])
        floor = -float("inf")  # top of the highest box placed so far
        for z, y in candidates:
            bottom = y - box_height / 2
            top = y + box_height / 2
            if bottom < floor:
                lift = floor - bottom
                placed = y + lift
                floor = top + lift
            else:
                placed = y
                floor = top
            adjusted[z][x] = placed

    return adjusted
|
378
|
+
|
379
|
+
|
380
|
+
def annotate(ctx, plot_type, sub_df, y_axis, palette):
    """Write value labels next to the plotted points or bars.

    Nothing is done unless one of ``args.annotate_max``,
    ``args.annotate_min`` or ``args.annotate`` is set. For every Z value the
    per-X aggregate (geometric mean with ``args.geomean``, median otherwise)
    is labelled at its maximum, its minimum and/or at every X, using the Z
    value's palette colour. Label heights are spread out with
    ``autospace_annotations``; for bars the X position is the centre of the
    corresponding patch.
    """
    args = ctx["args"]
    ax_plot = ctx["ax_plot"]

    if not args.annotate_max and not args.annotate_min and not args.annotate:
        return

    base_style = dict(
        ha="center",
        va="bottom",
        color="black",
        fontsize=12,
        fontweight="normal",
        xytext=(0, 5),
        textcoords="offset points",
    )

    estimator = scipy.stats.gmean if args.geomean else np.median
    z_domain = sub_df[args.z].unique()
    x_domain = sub_df[args.x].unique()

    ys = dict()
    for z in z_domain:
        rows_for_z = sub_df[sub_df[args.z] == z]
        ys[z] = rows_for_z.groupby(args.x)[y_axis].apply(estimator)

    x_adjust = {z: dict() for z in z_domain}
    y_adjust = autospace_annotations(ctx, x_domain, ys, base_style["fontsize"])

    # adjust x positions for annotations based on the plot type
    if plot_type == "lines":
        for z in z_domain:
            for x in x_domain:
                x_adjust[z][x] = x  # no adjustment needed for lines
    elif plot_type == "bars":
        # Centres of the drawn bars (positive, non-NaN height), in patch order.
        bar_centres = (
            patch.get_x() + patch.get_width() / 2
            for patch in ax_plot.patches
            if not np.isnan(patch.get_height()) and patch.get_height() > 0
        )
        for z in z_domain:
            for x in x_domain:
                x_adjust[z][x] = next(bar_centres)

    for z in z_domain:
        style = dict(base_style, color=palette[z])
        series = ys[z]

        def put_label(x, y):
            anchor = (x_adjust[z][x], y_adjust[z][x])
            ax_plot.annotate(f"{y:.2f}", anchor, **style)

        if args.annotate_max:
            put_label(series.idxmax(), series.max())
        if args.annotate_min:
            put_label(series.idxmin(), series.min())
        if args.annotate:
            for x, y in series.items():
                put_label(x, y)
|
461
|
+
|
462
|
+
|
463
|
+
def to_engineering_si(x, precision=0, unit=None):
    """Format ``x`` in engineering notation with an SI prefix.

    The exponent is the largest multiple of 3 not above ``log10(|x|)``,
    clamped to the available prefixes (``y`` = 1e-24 ... ``Y`` = 1e24). The
    coefficient is printed with ``precision`` decimals, followed by the
    prefix and the optional ``unit``.

    Zero is printed without a prefix but with the unit, like every other
    value. NaN and infinities cannot be scaled and are printed as-is
    followed by the unit, instead of raising from the logarithm.
    """
    unit = unit or ""
    if x == 0:
        return f"{0:.{precision}f}{unit}"
    if not math.isfinite(x):
        return f"{x}{unit}"
    si_prefixes = {
        -24: "y",
        -21: "z",
        -18: "a",
        -15: "f",
        -12: "p",
        -9: "n",
        -6: "µ",
        -3: "m",
        0: "",
        3: "k",
        6: "M",
        9: "G",
        12: "T",
        15: "P",
        18: "E",
        21: "Z",
        24: "Y",
    }
    exp = int(math.floor(math.log10(abs(x)) // 3 * 3))
    exp = max(min(exp, 24), -24)  # clamp to available prefixes
    coeff = x / (10**exp)
    prefix = si_prefixes.get(exp, f"e{exp:+03d}")
    return f"{coeff:.{precision}f}{prefix}{unit}"
|
491
|
+
|
492
|
+
|
493
|
+
def get_palette(values, colorblind=False):
    """Assign a colour to every value.

    With ``colorblind`` the seaborn ``colorblind`` palette is used, sized to
    the number of values. Otherwise colours come from a fixed list of twelve
    preferred colours, in order; the list is cycled when there are more
    values than colours, so any number of values can be coloured.

    Returns a dict mapping each value to its colour.
    """
    if colorblind:
        palette = sns.color_palette("colorblind", n_colors=len(values))
        return {v: palette[i] for i, v in enumerate(values)}

    preferred_colors = [
        "#5588dd",
        "#882255",
        "#33bb88",
        "#9624e1",
        "#BBBB41",
        "#ed5a15",
        "#aa44ff",
        "#448811",
        "#3fa7d6",
        "#e94f37",
        "#6cc551",
        "#dabef9",
    ]
    # cycle() wraps around instead of running out after the twelfth value.
    return dict(zip(values, itertools.cycle(preferred_colors)))
|
514
|
+
|
515
|
+
|
516
|
+
def update_plot(ctx, padding_factor=1.05):
    """Redraw the main plot for the currently selected configuration.

    The rows matching the current free-dimension selection are plotted as
    lines (``args.lines``) or bars, grouped by ``args.z`` along ``args.x``.
    Optional steps, all driven by ``args``: normalization (``x_norm``,
    ``z_norm`` or ``ref_norm``), an extra ``geomean`` X group, spread
    indicators, Y-axis unit label and value annotations. The figure title
    lists the Y columns (the active one in bold) and the raw Y range of the
    selection.

    ``padding_factor`` scales ``ctx["top"]`` (when set) to obtain the upper
    Y limit.
    """
    args = ctx["args"]
    df = ctx["df"]
    y_axis = ctx["y_axis"]
    ax_plot = ctx["ax_plot"]
    top = ctx.get("top", None)
    normalized = bool(args.x_norm or args.z_norm or args.ref_norm)

    config = get_current_config(ctx)
    sub_df = get_projection(df, config)

    ax_plot.clear()

    # set figure title (the range is taken from the raw, pre-normalization values)
    y_left, y_right = sub_df[y_axis].min(), sub_df[y_axis].max()
    y_range = "[{} - {}]".format(
        to_engineering_si(y_left, unit=args.unit),
        to_engineering_si(y_right, unit=args.unit),
    )
    title_parts = []
    for i, y in enumerate(args.y, start=1):
        if y == y_axis:
            title_parts.append(rf"{i}: $\mathbf{{{y}}}$")
        else:
            title_parts.append(f"{i}: {y}")
    title = " | ".join(title_parts) + "\n" + y_range
    ctx["fig"].suptitle(title)

    if args.x_norm:
        sub_df = group_normalization("x", df, config, args, y_axis)
    elif args.z_norm:
        sub_df = group_normalization("z", df, config, args, y_axis)
    elif args.ref_norm:
        sub_df = ref_normalization(df, config, args, y_axis)

    if args.geomean:
        # Duplicate every row under a synthetic "geomean" X value so the
        # estimator aggregates across all X values in one extra group.
        gm_df = sub_df.copy()
        gm_df[args.x] = "geomean"
        sub_df = pd.concat([sub_df, gm_df])

    # draw horizontal line at y=1.0
    if normalized:
        ax_plot.axhline(y=1.0, linestyle="-", linewidth=4, color="lightgrey")

    def custom_error(data):
        d = pd.DataFrame(data)
        return (
            spread.lower(args.spread_measure)(d),
            spread.upper(args.spread_measure)(d),
        )

    palette = get_palette(ctx["z_dom"], colorblind=args.colorblind)

    # main plot generation
    if args.lines:
        sns.lineplot(
            data=sub_df,
            x=args.x,
            y=y_axis,
            hue=args.z,
            palette=palette,
            lw=2,
            linestyle="-",
            marker="o",
            errorbar=None,
            ax=ax_plot,
            estimator=np.median,
        )
        if args.spread_measure != "none":
            spread.draw(
                ax_plot,
                [args.spread_measure],
                sub_df,
                x=args.x,
                y=y_axis,
                z=args.z,
                palette=palette,
            )
    else:
        sns.barplot(
            data=sub_df,
            ax=ax_plot,
            estimator=scipy.stats.gmean if args.geomean else np.median,
            palette=palette,
            legend=True,
            x=args.x,
            y=y_axis,
            hue=args.z,
            errorbar=custom_error if args.spread_measure != "none" else None,
            alpha=0.6,
            err_kws={
                "color": "black",
                "alpha": 1.0,
                "linewidth": 2.0,
                "solid_capstyle": "round",
                "solid_joinstyle": "round",
            },
        )

        # draw vertical line to separate geomean
        if args.geomean:
            pp = sorted(ax_plot.patches, key=lambda x: x.get_x())
            z_size = ctx["z_size"]
            x = pp[-z_size].get_x() + pp[-z_size - 1].get_x() + pp[-z_size - 1].get_width()
            # Draw on ax_plot explicitly: the pyplot "current axes" need not
            # be the axes this context plots into (e.g. caller-supplied axes).
            ax_plot.axvline(x=x / 2, color="grey", linewidth=1, linestyle="-")

    # set y-axis label
    def format_ylabel(label):
        if args.unit is None:
            return label
        elif normalized:
            return label
        else:
            return f"{label} [{args.unit}]"

    if top is not None:
        ax_plot.set_ylim(top=top * padding_factor, bottom=0.0)

    if normalized:
        if args.norm_reverse:
            normalized_label = f"{y_axis} (gain)"
        else:
            normalized_label = f"{y_axis} (normalized)"
        ax_plot.set_ylabel(format_ylabel(normalized_label))
    else:
        ax_plot.set_ylabel(format_ylabel(y_axis))

    # format y-tick labels with 'x' suffix for normalized plots
    if normalized:
        # use FuncFormatter to append 'x' to tick labels
        from matplotlib.ticker import FuncFormatter

        def format_with_x(x, pos):
            return f"{x:.2f}x"

        ax_plot.yaxis.set_major_formatter(FuncFormatter(format_with_x))
        # Make sure the 1.0 baseline always has a tick.
        ax_plot.set_yticks(sorted(set(list(ax_plot.get_yticks()) + [1.0])))

    if args.lines:
        annotate(ctx, "lines", sub_df, y_axis, palette)
    else:
        annotate(ctx, "bars", sub_df, y_axis, palette)

    ctx["fig"].canvas.draw_idle()
|
659
|
+
|
660
|
+
|
661
|
+
def get_config_name(ctx):
    """Build an underscore-joined name for the current view.

    The name starts with the active Y column, followed by ``gain`` or
    ``normalized`` when a normalization is active (``gain`` with
    ``args.norm_reverse``), then the current value of every free dimension.
    """
    y_axis = ctx["y_axis"]
    args = ctx["args"]
    config = get_current_config(ctx)

    parts = [f"{y_axis}"]
    if args.ref_norm or args.x_norm or args.z_norm:
        parts.append("gain" if args.norm_reverse else "normalized")
    parts.extend(str(value) for value in config.values())
    return "_".join(parts)
|
675
|
+
|
676
|
+
|
677
|
+
def get_status_description(ctx):
    """Describe the current selection as ``dim=value`` pairs joined by `` | ``.

    One pair is produced per free dimension. When the Z column has a single
    value (``ctx["z_size"] == 1``) a `` | z=value`` suffix is appended.
    """
    args = ctx["args"]
    domains = ctx["domains"]

    pairs = [
        "{}={}".format(dim, domains[dim][ctx["position"][dim]])
        for dim in ctx["free_dims"]
    ]
    description = " | ".join(pairs)

    if ctx["z_size"] == 1:
        only_z = ctx["df"][args.z].unique()[0]
        description = description + f" | {args.z}={only_z}"

    return description
|
693
|
+
|
694
|
+
|
695
|
+
def save_to_file(ctx, outfile=None):
    """Save the plot axes to ``outfile`` (default: ``<config name>.pdf``).

    The legend is hidden when the Z column has a single value. The figure
    title is replaced by the bold Y column name, a normalization note when
    one is active, and the status description on a second line. Only the
    area around the plot axes (expanded by 20%) is written.
    """
    ax_plot = ctx["ax_plot"]
    args = ctx["args"]
    if not outfile:
        outfile = get_config_name(ctx) + ".pdf"

    if ctx["z_size"] == 1:
        legend = ax_plot.get_legend()
        if legend:
            legend.set_visible(False)

    name = str(ctx["y_axis"])
    mode = "gain" if args.norm_reverse else "normalized"
    title = rf"$\mathbf{{{name}}}$"
    # Same precedence as the checks elsewhere: ref-norm, then x-norm, then z-norm.
    active_norm = args.ref_norm or args.x_norm or args.z_norm
    if active_norm:
        wrt = " | ".join(active_norm)
        title = f"{title} ({mode} w.r.t {wrt})"

    title += "\n" + get_status_description(ctx)
    fig = ctx["fig"]
    fig.suptitle(title)
    extent = ax_plot.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    fig.savefig(outfile, bbox_inches=extent.expanded(1.2, 1.2))
    report(LogLevel.INFO, f"saved to '{outfile}'")
|
725
|
+
|
726
|
+
|
727
|
+
def on_key(event, ctx):
    """Handle a key press on the figure.

    * ``enter``, space, ``up`` / ``down``: move to the next / previous value
      of the selected free dimension and redraw plot and table.
    * ``left`` / ``right``: select the previous / next free dimension.
    * ``1``-``9``: switch the Y axis to the n-th Y column, if it exists.
    * ``.``: save the current plot to a file.

    Any other key, including a missing key (``event.key`` is ``None``), is
    ignored. Keys are compared for equality, never as substrings.
    """
    selected_index = ctx["selected_index"]
    free_dims = ctx["free_dims"]
    domains = ctx["domains"]
    position = ctx["position"]
    y_dims = ctx["y_dims"]
    key = event.key

    if key in ("enter", " ", "up", "down"):
        if selected_index is None:
            return
        step = -1 if key == "down" else 1
        selected_dim = free_dims[selected_index]
        # Wrap around at both ends of the domain.
        position[selected_dim] = (position[selected_dim] + step) % domains[selected_dim].size
        update_plot(ctx)
        update_table(ctx)
    elif key in ("left", "right"):
        if selected_index is None:
            return
        step = -1 if key == "left" else 1
        ctx["selected_index"] = (selected_index + step) % len(free_dims)
        update_table(ctx)
    elif key in tuple("123456789"):
        new_idx = int(key) - 1
        if new_idx < len(y_dims):
            ctx["y_axis"] = y_dims[new_idx]
            compute_ylimits(ctx)
            update_plot(ctx)
    elif key == ".":
        save_to_file(ctx)
|
760
|
+
|
761
|
+
|
762
|
+
def on_close(event, ctx):
    """Mark the application as stopped by clearing ``ctx["alive"]``."""
    ctx.update(alive=False)
|
764
|
+
|
765
|
+
|
766
|
+
def compute_missing(ctx):
    """List the combinations of non-Y column values that have no row.

    The expected set is the cartesian product of the distinct values of
    every column outside ``ctx["y_dims"]``; the observed set is the distinct
    rows of those columns. Returns the difference as a dataframe with one
    row per missing combination.
    """
    df = ctx["df"]
    space_columns = df.columns.difference(ctx["y_dims"])
    distinct_values = [df[column].unique() for column in space_columns]
    full_grid = set(itertools.product(*distinct_values))
    seen = {tuple(row) for row in df[space_columns].drop_duplicates().values}
    absent = full_grid - seen
    return pd.DataFrame(list(absent), columns=space_columns)
|
774
|
+
|
775
|
+
|
776
|
+
def validate_dimensions(ctx, dims):
    """Report a fatal error for every name in ``dims`` that is not a column.

    The columns are those of ``ctx["df"]``; the error hint lists the
    available columns. Column labels are converted to text for the hint so
    that non-string labels do not break the message. Only ``ctx["df"]`` is
    read.
    """
    columns = ctx["df"].columns
    for col in dims:
        if col not in columns:
            hint = "available columns: {}".format(", ".join(map(str, columns)))
            report(LogLevel.FATAL, "invalid column", col, hint=hint)
|
784
|
+
|
785
|
+
|
786
|
+
def validate_args(ctx):
    """Validate the plotting arguments against the data and fill in defaults.

    Checks, in order: the X column; the Y columns (defaulting to every
    numeric column other than X, and required to be numeric and different
    from X); the Z column (defaulting to the remaining column with the
    fewest distinct values, and required to differ from X and Y); the
    ``--geomean`` / ``--lines`` combination; suspiciously large numeric
    non-Y columns; the normalization options and their ``key=value`` pairs;
    and the spread measure. Problems are reported through ``report``.

    On return ``ctx["y_dims"]`` holds the Y columns and ``ctx["y_axis"]``
    the first of them. With ``args.show_missing`` the missing experiments
    are reported as warnings.
    """
    args = ctx["args"]
    df = ctx["df"]

    def fatal(*message, **details):
        report(LogLevel.FATAL, *message, **details)

    def warn(*message, **details):
        report(LogLevel.WARNING, *message, **details)

    validate_dimensions(ctx, [args.x])

    # Y-axis
    without_x = df.drop(columns=[args.x])
    numeric_cols = without_x.select_dtypes(include=[np.number]).columns.tolist()
    if args.y is None:
        # find the floating point numeric columns
        if not numeric_cols:
            fatal(
                "No numeric columns found in the data",
                hint="use -y to specify a Y-axis",
            )
        joined = ", ".join(numeric_cols)
        report(LogLevel.INFO, f"Using '{joined}' as Y-axis")
        args.y = numeric_cols
    validate_dimensions(ctx, args.y)
    for y in args.y:
        if pd.api.types.is_numeric_dtype(df[y]):
            continue
        t = df[y].dtype
        if numeric_cols:
            suggestion = (
                numeric_cols[0] if len(numeric_cols) == 1 else ", ".join(numeric_cols)
            )
            hint = f"try {suggestion}"
        else:
            hint = "use -y to specify a Y-axis"
        fatal(f"Y-axis must have a numeric type. '{y}' has type '{t}'", hint=hint)

    if args.x in args.y:
        fatal("X-axis and Y-axis must be different dimensions")

    # Z-axis
    # check that there are at least two dimensions other than args.y
    if len(df.columns.difference(args.y)) < 2:
        fatal("there must be at least two dimensions other than the Y-axis")
    if args.z is None:
        # pick the remaining column with the fewest distinct values
        available = df.columns.difference([args.x] + args.y)
        cardinalities = [df[col].nunique() for col in available]
        args.z = available[np.argmin(cardinalities)]
        report(LogLevel.INFO, f"Using '{args.z}' as Z-axis")
    else:
        validate_dimensions(ctx, [args.z])
    zdom = df[args.z].unique()
    if len(zdom) == 1 and args.geomean:
        warn(
            f"--geomean is superfluous because '{zdom[0]}' is the only value in the '{args.z}' group",
        )

    # all axis
    if args.x == args.z or args.z in args.y:
        fatal(
            "the -z dimension must be different from the dimension used on the X or Y axis",
        )

    # geomean and lines
    if args.geomean and args.lines:
        fatal("--geomean and --lines cannot be used together")
    for d in df.columns.difference(args.y):
        n = df[d].nunique()
        if n > 100 and pd.api.types.is_numeric_dtype(df[d]):
            warn(
                f"'{d}' seems to have many ({n}) numeric values. Are you sure this is not supposed to be the Y-axis?",
            )

    # normalization
    def validate_pairs(norm_args):
        for arg in norm_args:
            if "=" not in arg:
                fatal(
                    f"invalid normalization argument '{arg}', expected format 'key=value'",
                )
        parsed = dict()
        for pair in norm_args:
            pieces = pair.split("=")
            parsed[pieces[0]] = pieces[1]
        return parsed

    # At most one of the three normalization methods may be active.
    active_methods = sum(1 for m in (args.x_norm, args.z_norm, args.ref_norm) if m)
    if active_methods > 1:
        fatal(
            "only one normalization method can be used at a time: --x-norm, --z-norm, or --ref-norm",
        )
    if args.ref_norm:
        keys = validate_pairs(args.ref_norm).keys()
        if not (args.x in keys and args.z in keys):
            hint = f"try adding '{args.x}=<value>' or '{args.z}=<value>' to --ref-norm"
            fatal(
                "--ref-norm pairs must include both the X-axis and Z-axis dimensions",
                hint=hint,
            )
    elif args.x_norm:
        keys = validate_pairs(args.x_norm).keys()
        if args.z not in keys:
            hint = f"try adding '{args.z}=<value>' to --x-norm"
            fatal("--x-norm pairs must include the Z-axis dimension", hint=hint)
        if args.x in keys:
            hint = f"try removing '{args.x}=<value>' from --x-norm"
            fatal("--x-norm pairs must not include the X-axis dimension", hint=hint)
    elif args.z_norm:
        keys = validate_pairs(args.z_norm).keys()
        if args.x not in keys:
            hint = f"try adding '{args.x}=<value>' to --z-norm"
            fatal("--z-norm pairs must include the X-axis dimension", hint=hint)
        if args.z in keys:
            hint = f"try removing '{args.z}=<value>' from --z-norm"
            fatal("--z-norm pairs must not include the Z-axis dimension", hint=hint)
    if not (args.x_norm or args.z_norm or args.ref_norm) and args.norm_reverse:
        warn("--norm-reverse is ignored because no normalization is applied")

    if args.spread_measure != "none":
        if not spread.assert_validity(args.spread_measure):
            args.spread_measure = "none"

    ctx["y_dims"] = args.y
    ctx["y_axis"] = args.y[0]

    if args.show_missing:
        missing = compute_missing(ctx)
        if len(missing) > 0:
            warn("missing experiments:")
            warn("\n" + missing.to_string(index=False))
            warn("")
|
952
|
+
|
953
|
+
|
954
|
+
def start_gui(ctx):
    """Draw the initial view, start the file monitor thread and show the window.

    ``ctx["alive"]`` is set so the monitor keeps running; ``plt.show()`` is
    called after a one-second pause.
    """
    ctx["alive"] = True

    update_plot(ctx)
    update_table(ctx)
    monitor = threading.Thread(target=file_monitor, args=(ctx,), daemon=True)
    monitor.start()
    report(LogLevel.INFO, "application running")
    time.sleep(1.0)  # wait for the GUI to initialize
    plt.show()
|
963
|
+
|
964
|
+
|
965
|
+
def compute_ylimits(ctx):
    """Compute the common upper Y limit and store it in ``ctx["top"]``.

    * No free dimensions: ``None`` (no fixed limit).
    * No normalization: the maximum of the active Y column.
    * With normalization: the largest value over every combination of
      free-dimension values, where each combination is normalized and then
      reduced per (Z, X) group with the upper spread bound
      (``args.spread_measure``) or the plain maximum when the spread measure
      is ``"none"``.
    """
    args = ctx["args"]
    free_dims = ctx["free_dims"]
    df = ctx["df"]
    y_axis = ctx["y_axis"]
    domains = ctx["domains"]

    if len(free_dims) == 0:
        ctx["top"] = None
        return

    if not (args.x_norm or args.z_norm or args.ref_norm):
        ctx["top"] = df[y_axis].max()
        return

    free_domains = {k: v for k, v in domains.items() if k in free_dims}
    top = 0
    for point in itertools.product(*free_domains.values()):
        # The normalization helpers select the rows for this point themselves,
        # so no separate row mask is built here.
        config = get_config(point, free_domains.keys())
        if args.ref_norm:
            df_config = ref_normalization(df, config, args, y_axis)
        elif args.x_norm:
            df_config = group_normalization("x", df, config, args, y_axis)
        elif args.z_norm:
            df_config = group_normalization("z", df, config, args, y_axis)
        zx = df_config.groupby([args.z, args.x])[y_axis]
        if args.spread_measure != "none":
            t = zx.apply(spread.upper(args.spread_measure))
        else:
            t = zx.max()
        top = max(top, t.max())
    ctx["top"] = top
|
996
|
+
|
997
|
+
|
998
|
+
def launch(args):
    """Run the interactive plot application for the parsed ``args``.

    Builds the shared context and runs every stage in order: file
    validation, location and synchronization, dataframe loading, argument
    validation, rescaling, space generation, Y-limit computation, figure
    setup and finally the GUI.
    """
    ctx = {"args": args, "alive": True}
    stages = (
        validate_files,
        locate_files,
        sync_files,
        generate_dataframe,
        validate_args,
        rescale,
        generate_space,
        compute_ylimits,
        initialize_figure,
        start_gui,
    )
    for stage in stages:
        stage(ctx)
|