reproplot 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reproplot-0.0.1/LICENSE +21 -0
- reproplot-0.0.1/PKG-INFO +33 -0
- reproplot-0.0.1/README.md +1 -0
- reproplot-0.0.1/pyproject.toml +25 -0
- reproplot-0.0.1/src/reproplot_oscarvlld/__init__.py +0 -0
- reproplot-0.0.1/src/reproplot_oscarvlld/plotting.py +191 -0
- reproplot-0.0.1/src/reproplot_oscarvlld/reprod_exps.py +141 -0
- reproplot-0.0.1/src/reproplot_oscarvlld/reprod_plots.py +248 -0
- reproplot-0.0.1/src/reproplot_oscarvlld/reproduce.py +157 -0
- reproplot-0.0.1/src/reproplot_oscarvlld/utils.py +235 -0
reproplot-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Oscar Villemaud, Indy Lab, EPFL
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
reproplot-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: reproplot
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: General package for experimenting and automatically plotting.
|
|
5
|
+
Author-email: Oscar Villemaud <oscar.villemaud@epfl.ch>
|
|
6
|
+
License: MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2024 Oscar Villemaud, Indy Lab, EPFL
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
License-File: LICENSE
|
|
28
|
+
Classifier: Operating System :: OS Independent
|
|
29
|
+
Classifier: Programming Language :: Python :: 3
|
|
30
|
+
Requires-Python: >=3.8
|
|
31
|
+
Description-Content-Type: text/markdown
|
|
32
|
+
|
|
33
|
+
General package for experimenting and automatically plotting.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
General package for experimenting and automatically plotting.
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
[project]
|
|
7
|
+
license = {file = "LICENSE"}
|
|
8
|
+
name = "reproplot"
|
|
9
|
+
version = "0.0.1"
|
|
10
|
+
authors = [
|
|
11
|
+
{ name="Oscar Villemaud", email="oscar.villemaud@epfl.ch" },
|
|
12
|
+
]
|
|
13
|
+
description = "General package for experimenting and automatically plotting. "
|
|
14
|
+
readme = "README.md"
|
|
15
|
+
requires-python = ">=3.8"
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Programming Language :: Python :: 3",
|
|
18
|
+
"Operating System :: OS Independent",
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
[tool.hatch.build.targets.wheel]
|
|
22
|
+
include = [
|
|
23
|
+
"reproplot/*.py",
|
|
24
|
+
"/tests",
|
|
25
|
+
]
|
|
File without changes
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
###
|
|
2
|
+
# @file plotting.py
|
|
3
|
+
# @author Oscar Villemaud <oscar.villemaud@epfl.ch>
|
|
4
|
+
#
|
|
5
|
+
# @section LICENSE
|
|
6
|
+
#
|
|
7
|
+
# Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# @section DESCRIPTION
|
|
11
|
+
#
|
|
12
|
+
# Plotting functions based on pyplot.
|
|
13
|
+
###
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from matplotlib import pyplot as plt
|
|
17
|
+
from matplotlib import colors
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _finalize_plot(title, xlab, ylab, fontsize, savepath, extension, show):
    """ Add title and axis labels to the current figure, then save / show / close it.

    Args:
        - title (str) : title of the plot (drawn with the default font size)
        - xlab (str) : label of the x axis
        - ylab (str) : label of the y axis
        - fontsize (int) : font size of both axis labels
        - savepath (str or None) : path without extension, nothing is saved if None
        - extension (str) : file format appended to savepath, e.g. "png"
        - show (bool) : True to display the figure before closing it
    """
    plt.title(title)
    for set_axis_label, text in ((plt.xlabel, xlab), (plt.ylabel, ylab)):
        set_axis_label(text, size=fontsize)
    if savepath is not None:
        target_file = f"{savepath}.{extension}"
        plt.savefig(target_file)
    if show:
        plt.show()
    # always close, so that repeated calls do not accumulate open figures
    plt.close()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def plot(data, legend=None, title="", log=False):
    """ Quick plot of one curve or of several curves, then show and close the figure.

    Args:
        - data : either one sequence of values, or a sequence of such sequences
            (decided by whether data[0] is iterable)
        - legend (list or None) : labels forwarded to plt.legend, no legend if None
        - title (str) : title of the plot
        - log (bool) : True for a logarithmic y axis
    """
    draw = plt.semilogy if log else plt.plot
    curves = data if hasattr(data[0], '__iter__') else [data]
    for curve in curves:
        draw(curve)
    if legend is not None:
        plt.legend(legend)
    plt.title(title)
    plt.show()
    plt.close()
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def seeds_plot(
        list_list, legend=None, x_vals=None, color=None, style=None,
        xlog=False, ylog=False, confidence=True, std=False, plot_all=False, x_vlines=False):
    """ plots one line and confidence interval from multiple seeds

    Args:
        - list_list (float list list) : one sublist of values per seed; sublists must have
            the same length unless plot_all is True. Missing values may be NaN.
        - legend (str or None) : label of the line
        - x_vals (list or None) : x coordinates, default is 0 to n-1
        - color (str) : matplotlib color of the line
        - style (str) : matplotlib linestyle of the line
        - xlog (bool) : True for x axis log scale
        - ylog (bool) : True for y axis log scale
        - confidence (bool) : True to draw the 95% confidence interval of the mean, as two
            thin lines at mean +/- 1.96 * std / sqrt(number of non-NaN seeds at that x)
        - std (bool) : True to shade +/- one standard deviation across seeds
        - plot_all (bool) : True to draw every seed instead of the mean
            (confidence and std are then ignored)
        - x_vlines (bool) : True to draw a thin vertical line at each x coordinate
    """
    if len(list_list) == 0:
        return

    if ylog and xlog:
        plot_func = plt.loglog
    elif ylog:
        plot_func = plt.semilogy
    elif xlog:
        plot_func = plt.semilogx
    else:
        plot_func = plt.plot

    def _draw(values, xs, label):
        plot_func(xs, values, label=label, linestyle=style, color=color)

    default_x = list(range(len(list_list[0])))
    if plot_all:
        for i, values in enumerate(list_list):
            # each seed gets its own default x axis, so seeds of different lengths work
            seed_x = x_vals if x_vals is not None else list(range(len(values)))
            # only the first seed carries the label: one legend entry per line, not per seed
            _draw(values, seed_x, legend if i == 0 else None)
    else:
        xs = x_vals if x_vals is not None else default_x
        arr = np.array(list_list, dtype=float)
        vals = np.nanmean(arr, axis=0)
        _draw(vals, xs, legend)
        if confidence:
            # number of seeds that actually have a value at each x (missing runs are NaN),
            # so the interval is not artificially narrowed by missing seeds
            nb_samples = np.sum(~np.isnan(arr), axis=0)
            with np.errstate(divide="ignore", invalid="ignore"):
                confs = 1.96 * np.nanstd(arr, axis=0) / np.sqrt(nb_samples)
            plt.plot(xs, vals - confs, linestyle=style, color=color, linewidth=0.3)
            plt.plot(xs, vals + confs, linestyle=style, color=color, linewidth=0.3)
        if std:
            stds = np.nanstd(arr, axis=0)
            plt.fill_between(list(xs), vals - stds, vals + stds, alpha=0.1, color=color)
    if x_vlines:
        for x in (x_vals if x_vals is not None else default_x):
            plt.axvline(x, linewidth=0.3)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def seeds_plot_together(
        all_curves, legends=None, title="", xlog=False, ylog=False, confidence=True, std=False,
        vlines=None, x_vals=None, xlab=None, ylab=None, fontsize=11, xlims=None, ylims=None,
        savepath=None, figsize=(8, 5), show=False, plot_all=False, x_vlines=False, extension="png",
        ):
    """
    Plots several lines of several seeds (for each line average and confidence interval)
    Args:
        - all_curves (float list list list) : order 3 array/list of lists of lists
            one sublist is one line, one subsublist is one seed
        - legends (str list or None) : labels to use for each line (in order),
            None to draw no legend
        - title (str) : title of the plot
        - xlog (bool) : True for x axis log scale
        - ylog (bool) : True for y axis log scale
        - confidence (bool) : True to display 95% mean estimate confidence intervals
        - std (bool) : True to display standard deviation across seeds
        - vlines (float list or None) : list of x coordinates where to add vertical lines
        - x_vals (list) : x axis values, default is 0 to n
        - xlab (str) : label of x axis
        - ylab (str) : label of y axis
        - fontsize (int) : font size
        - xlims (float pair) : plot limits for x axis, None for auto
        - ylims (float pair) : plot limits for y axis, None for auto
        - savepath (str) : path where to save the plot, not saved if None
        - figsize (int pair) : dimensions of the plot
        - show (bool) : True to show the plot
        - plot_all (bool) : True to display one line for each seed instead of average
        - x_vlines (bool) : True to draw a vertical line at each data point
        - extension (str) : format for the saved image file
    """
    plt.figure(figsize=figsize)
    # line i uses palette[i % 6] and linestyles[i % 5]; indexing modulo the length keeps
    # cycling instead of silently dropping lines past a fixed count
    palette = ["orange", "green", "blue", "red", "purple", "black"]
    linestyles = ["-", "--", "-.", ":", "-"]
    has_legends = legends is not None
    if not has_legends:
        # None labels keep the lines out of the legend
        legends = [None] * len(all_curves)
    for i, (curve, legend) in enumerate(zip(all_curves, legends)):
        seeds_plot(
            curve, color=palette[i % len(palette)], x_vals=x_vals,
            style=linestyles[i % len(linestyles)], legend=legend,
            xlog=xlog, ylog=ylog, confidence=confidence, std=std, plot_all=plot_all, x_vlines=x_vlines)
    for x in (vlines if vlines is not None else []):
        plt.axvline(x)
    plt.xlim(xlims)
    plt.ylim(ylims)
    if has_legends:
        plt.legend(prop={'size': fontsize})
    _finalize_plot(title, xlab, ylab, fontsize, savepath, extension, show)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def seeds_plot_color3d(
        all_seeds, x_vals, y_vals, title, xlab, ylab,
        savepath, show=False, label="", fontsize=11, std=False,
        xlog=False, ylog=False, zlog=False, extension="png",
        **kwargs):
    """ 3d color plot: heat map (plt.pcolor) of the mean over seeds.

    Args:
        - all_seeds : nested lists / array indexed [seed][y][x]; NaN entries are ignored
        - x_vals (list) : coordinates of the columns
        - y_vals (list) : coordinates of the rows
        - title (str) : title of the plot
        - xlab, ylab (str) : axis labels
        - savepath (str or None) : path without extension, not saved if None
        - show (bool) : True to show the plot
        - label (str) : label of the colorbar
        - fontsize (int) : font size of the axis labels
        - std (bool) : True to write the standard deviation across seeds in each cell
        - xlog, ylog (bool) : True for log scale on the x / y axis
        - zlog (bool) : True for a logarithmic color scale
        - extension (str) : format for the saved image file
        - kwargs : extra options, accepted and ignored
    """
    stacked = np.array(all_seeds)
    mean_grid = np.nanmean(stacked, axis=0)
    norm_kwargs = {"norm": colors.LogNorm()} if zlog else {}
    plt.pcolor(x_vals, y_vals, mean_grid, **norm_kwargs)

    if xlog:
        plt.xscale('log')
    if ylog:
        plt.yscale('log')
    if std:
        std_grid = np.nanstd(stacked, axis=0)
        # one text annotation per cell, centred on the cell's (x, y) coordinates
        for (row, col), spread in np.ndenumerate(std_grid):
            plt.text(x_vals[col], y_vals[row],
                     f"+{round(spread, 2)}", ha='center',
                     va='center', color='black')
    plt.colorbar(label=label)
    _finalize_plot(title, xlab, ylab, fontsize, savepath, extension, show)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def seeds_plot_surface3d(
        all_seeds, x_vals, y_vals, title, xlab, ylab,
        savepath, show=False, label="", fontsize=11,
        xlog=False, ylog=False, zlog=False, angle=None, extension="png",
        **kwargs):
    """ 3d surface plot of the mean over seeds.

    Args:
        - all_seeds : nested lists / array indexed [seed][y][x]; NaN entries are ignored
        - x_vals (list) : x coordinates (columns of the mean grid)
        - y_vals (list) : y coordinates (rows of the mean grid)
        - title (str) : title of the plot
        - xlab, ylab (str) : axis labels
        - savepath (str or None) : path without extension, not saved if None
        - show (bool) : True to show the plot
        - label (str) : label of the z axis
        - fontsize (int) : font size of the axis labels
        - xlog, ylog (bool) : accepted but currently not applied to the 3d axes
        - zlog (bool) : True to request a log scale on the z axis
        - angle (tuple or None) : arguments forwarded to ax.view_init, None keeps the default view
        - extension (str) : format for the saved image file
        - kwargs : extra options, accepted and ignored
    """
    mean_grid = np.nanmean(np.array(all_seeds), axis=0)
    ax = plt.axes(projection='3d')
    if angle is not None:
        ax.view_init(*angle)
    # grid_x[i, j] == x_vals[j] and grid_y[i, j] == y_vals[i], both of shape (len(y_vals), len(x_vals))
    grid_x, grid_y = np.meshgrid(x_vals, y_vals)
    if zlog:
        ax.set_zscale('log')
    ax.plot_surface(grid_x, grid_y, mean_grid, cmap='viridis', edgecolor='green')
    ax.set_zlabel(label, size=fontsize)
    _finalize_plot(title, xlab, ylab, fontsize, savepath, extension, show)
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
###
|
|
2
|
+
# @file reprod_exps.py
|
|
3
|
+
# @author Oscar Villemaud <oscar.villemaud@epfl.ch>
|
|
4
|
+
#
|
|
5
|
+
# @section LICENSE
|
|
6
|
+
#
|
|
7
|
+
# Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# @section DESCRIPTION
|
|
11
|
+
#
|
|
12
|
+
# Experiment running.
|
|
13
|
+
###
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import time
|
|
17
|
+
from tqdm import tqdm
|
|
18
|
+
import traceback
|
|
19
|
+
import multiprocessing as mp
|
|
20
|
+
|
|
21
|
+
from .utils import seedall, dump_json, load_json, make_exp_name, make_grid, update_params, scan_runs
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _run_one_run(experiment_func, params, seed, exp_name, run_path, verb):
    """ Run one seed of one experiment and save its metrics as json.

    Args:
        - experiment_func (func) : called as experiment_func(**params); a non-dict
            return value is wrapped as {"metric": value}
        - params (dict) : keyword arguments of experiment_func
        - seed (int) : random seed given to seedall before running
        - exp_name (str) : experiment name, only used in messages
        - run_path (str) : existing run directory (normally ending with "/"),
            where metrics.json is written
        - verb (container of str) : "runs" in verb enables per-run messages

    Returns:
        - (int) 0 if the run completed, 1 if it failed; on failure the run directory is
            renamed to "<run_path>_failed_<k>" and gets a traceback.txt

    Raises:
        - KeyboardInterrupt : re-raised (with its traceback) after the directory is renamed
    """
    if "runs" in verb:
        print(f"running {exp_name} seed {seed}.")

    def _handle_error(exc):
        # strip the trailing separator so the failed directory is a sibling of run_path
        base_path = run_path.rstrip("/")
        count = 1
        while os.path.isdir(f"{base_path}_failed_{count}"):
            count += 1
        new_path = f"{base_path}_failed_{count}"
        os.rename(run_path, new_path)
        print(f"{exp_name} seed {seed} failed with error: {exc}, renaming to: failed_{count}")
        # only called from inside an except clause, so format_exc() reports the active exception
        with open(os.path.join(new_path, "traceback.txt"), 'w') as f:
            f.write(traceback.format_exc())

    try:
        start_time = time.time()
        seedall(seed)
        metrics = experiment_func(**params)
        if not isinstance(metrics, dict):
            metrics = {"metric": metrics}
        run_time = time.time() - start_time
        metrics["run_time"] = run_time
        dump_json(metrics, run_path + "metrics.json")
        if "runs" in verb:
            print(f"{exp_name} seed {seed} saved. ({round(run_time)} secs)")
    except Exception as exc:
        _handle_error(exc)
        return 1
    except KeyboardInterrupt:
        _handle_error("Keyboard Interrupt")
        raise  # bare raise keeps the original traceback
    return 0
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def run_experiments(
        experiment_func=None, res_dir=None,
        seeds=None, nametag="", params_common=None, diff_plots=None, same_plot=None, same_line=None,
        exp_tags=None, set_depending_params=None, verb=None):
    """ run an experiment series given a grid of parameters, saves results in json files

    Each combination of diff_plots x same_plot x same_line (on top of params_common) is one
    experiment stored in <dir_path>/<res_dir>/<exp_name>/ with its params.json, where dir_path
    is the parent of this module's directory. Each seed gets a seed_<seed>/ sub-directory
    holding metrics.json; seeds whose directory already exists are not run again.

    Args:
        - experiment_func (func) : function running a (random) experiment
            and outputting a dictionary of metrics
        - res_dir (str) : name of the directory where to store results as json
        - seeds (int list) : list of random seeds to use for reproducibility
        - nametag (str) : prefix that identifies the series of experiments
        - params_common (dict) : dictionary of {"param_name": value}
            that are the default parameters of experiment_func
        - diff_plots (list dict) : dictionary of {"param_name": list of values} each value combination on a different plot
        - same_plot (list dict) : dictionary of {"param_name": list of values} each value combination on a different line of the same plot
        - same_line (list dict) : dictionary of {"param_name": list of values}, values on the x axis of the plot (only one parameter)
        - exp_tags (str list) : list of parameters to put in experiment names
        - set_depending_params (func or None) : function editing in-place a dictionary of parameters,
            None when no parameter depends on the others
        - verb (str list or None) : verbosity codes, "runs" for per-run messages and
            "pbar" for a progress bar; None is silent
    """
    verb = [] if verb is None else verb
    if set_depending_params is None:
        def _no_depending_params(params):
            """ default: no parameter is derived from the others """
        set_depending_params = _no_depending_params

    multiprocess = 0  # number of worker processes, 0 runs everything sequentially
    # results are stored relative to the parent of this module's directory
    dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    nb_runs_found, nb_runs_total, confs_paths = scan_runs(
        res_dir=res_dir, dir_path=dir_path,
        seeds=seeds, nametag=nametag, params_common=params_common,
        diff_plots=diff_plots, same_plot=same_plot, same_line=same_line,
        exp_tags=exp_tags, set_depending_params=set_depending_params)
    nb_runs_needed = nb_runs_total - nb_runs_found
    print(f"Recovered {nb_runs_found}/{nb_runs_total} runs. Running the remaining {nb_runs_needed}.")
    nb_runs_ran = 0
    nb_failed = 0
    init_time = time.time()
    # TODO use confs_paths ?
    # NOTE(review): with apply_async, results and exceptions of workers are never collected
    mp_pool = mp.Pool(processes=multiprocess) if multiprocess else None
    with tqdm(total=nb_runs_needed, leave=True, disable=("pbar" not in verb)) as pbar:
        for params1 in make_grid(diff_plots):
            for params2 in make_grid(same_plot):
                for params3 in make_grid(same_line):
                    params = update_params(params_common, params1, params2, params3)
                    set_depending_params(params)
                    exp_name = make_exp_name(params, nametag, exp_tags)
                    exp_dir = f"{dir_path}/{res_dir}/{exp_name}/"
                    can_run = True
                    try:  # check if parameters are matching
                        loaded_params = load_json(exp_dir + "params.json")
                        if loaded_params != params:
                            print("WARNING : old and new parameters don't match despite same experiment name !")
                            print("found :", loaded_params)
                            print("but has :", params)
                            can_run = False
                    except OSError:
                        # no readable params.json: the directory itself may already exist
                        os.makedirs(exp_dir, exist_ok=True)
                        dump_json(params, exp_dir + "params.json", indent=2)
                    if not can_run:
                        print("Skipping experiment")
                        continue
                    for seed in seeds:
                        run_path = exp_dir + f"seed_{seed}/"
                        if os.path.isdir(run_path):
                            if "runs" in verb:
                                print(exp_name, "seed", seed, "already exists.")
                            continue
                        os.makedirs(run_path)
                        if mp_pool is not None:
                            mp_pool.apply_async(
                                func=_run_one_run,
                                args=(experiment_func, params, seed, exp_name, run_path, verb))
                        else:
                            nb_failed += _run_one_run(experiment_func, params, seed, exp_name, run_path, verb)
                            pbar.update()
                        nb_runs_ran += 1
    if mp_pool is not None:
        mp_pool.close()
        mp_pool.join()
    time_taken = time.time() - init_time
    if mp_pool is not None:
        print("Unable to count fails in multiprocess mode")
        print(f"{nb_runs_ran} runs ran and {nb_runs_found} recovered in {round(time_taken)} seconds")
    else:
        print(f"{nb_runs_ran - nb_failed} runs ran, {nb_runs_found} recovered and {nb_failed} failed in {round(time_taken)} seconds")
|
|
141
|
+
|
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
###
|
|
2
|
+
# @file reprod_plots.py
|
|
3
|
+
# @author Oscar Villemaud <oscar.villemaud@epfl.ch>
|
|
4
|
+
#
|
|
5
|
+
# @section LICENSE
|
|
6
|
+
#
|
|
7
|
+
# Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# @section DESCRIPTION
|
|
11
|
+
#
|
|
12
|
+
# Data retrieval for plotting.
|
|
13
|
+
###
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import copy
|
|
17
|
+
from tqdm import tqdm
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from .utils import load_json, make_exp_name, make_grid, make_legend, make_plot_name, update_params, make_title
|
|
22
|
+
from .plotting import seeds_plot_together, seeds_plot_color3d, seeds_plot_surface3d
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _extract_line_data(lineconf, res_dir, nametag, exp_tags, ignore_missing):
    """ extracts data from json files and puts it in list of lists
    Args:
        - lineconf (dict) : dictionary containing directions to the data to plot on one line
            lineconf1 = {"confs": conf_list, "seeds":[1, 2], "metric": "metric"}
                one value per conf (a list-valued metric is reduced to its last element),
                a missing run gives NaN
            or
            lineconf1 = {"conf": conf, "seeds":[1, 2], "metric": "metric"}
                the whole metric of the single conf, a missing seed is left out
        - res_dir (str) : directory containing the data to plot
        - nametag (str) : prefix identifying the series of experiments
        - exp_tags (str list) : list of parameters used in experiment names
        - ignore_missing (bool) : True to plot despite missing data

    Returns:
        - (float list list) list of lists of values of the metric, one sublist is one seed

    Raises:
        - FileNotFoundError : if a metrics.json cannot be read and ignore_missing is False
        - ValueError : if lineconf has neither a "confs" nor a "conf" entry
    """
    dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

    def _handle_missing(ignore_missing, path):
        if ignore_missing:
            print(f"{path} not found, plotting without it.")
        else:
            raise FileNotFoundError("{} not found, use ignore_missing=True to plot anyway".format(path))

    def _seed_path(params, seed):
        # <dir_path>/<res_dir>/<exp_name>/seed_<seed>/
        exp_name = make_exp_name(params, nametag, exp_tags)
        return f"{dir_path}/{res_dir}/{exp_name}/seed_{seed}/"

    all_seeds = []
    if "confs" in lineconf:  # if the metric chosen gives one value (eg final loss)
        for seed in lineconf["seeds"]:
            one_seed = []
            for params in lineconf["confs"]:
                path = _seed_path(params, seed)
                try:
                    value = load_json(path + "metrics.json")[lineconf["metric"]]
                except OSError:
                    _handle_missing(ignore_missing, path)
                    value = np.nan
                else:
                    if type(value) is list:  # implicitly taking value at end of training
                        value = value[-1]
                one_seed.append(value)
            all_seeds.append(one_seed)
    elif "conf" in lineconf:  # if the metric chosen gives a list of values (eg training loss)
        params = lineconf["conf"]
        for seed in lineconf["seeds"]:
            path = _seed_path(params, seed)
            try:
                one_seed = load_json(path + "metrics.json")[lineconf["metric"]]
            except OSError:
                _handle_missing(ignore_missing, path)
            else:
                all_seeds.append(one_seed)
    else:
        raise ValueError("lineconf must contain either a 'confs' or a 'conf' entry")
    return all_seeds
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _plot_from_conf(
        plot_conf, res_dir, plot_dir, nametag, exp_tags,
        same_line, metric, ignore_missing, plot_kwargs=None,
        custom_xlab=None, custom_ylab=None):
    """ creates and saves a plot from a config
    Args:
        - plot_conf (dict) : configuration of the plot, including instructions on the data to use
            exp : plot_conf = { "filename" : "plot", "title": "title_foo", "type": "lines2d",
                                "lines": { "legend1" : lineconf1, "legend2" : lineconf2}}
            3D plots use "type": "colors3d" or "surface3d" and "rows": [lineconf1, lineconf2]
        - res_dir (str) : directory containing the data to plot
        - plot_dir (str) : directory where to save the plot
        - nametag (str) : prefix identifying the series of experiments
        - exp_tags (str list) : list of parameters used in experiment names
        - same_line (list dict or empty dict) : {"param_name": list of values} values on x axis
            (first entry) and, for 3D plots, on y axis (second entry);
            empty dict will use training steps as x axis
        - metric (str) : name of the metric to plot
        - ignore_missing (bool) : True to plot despite missing data
        - plot_kwargs (dict or None) : parameters forwarded to the plotting function
        - custom_xlab (func or None) : gives the x label from the x parameter name,
            None uses the name itself
        - custom_ylab (func or None) : gives the y label from the metric name (2D) or the
            y parameter name (3D), None uses the name itself

    Raises:
        - ValueError : if plot_conf["type"] is not "lines2d", "colors3d" or "surface3d"
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    xlab_func = (lambda name: name) if custom_xlab is None else custom_xlab
    ylab_func = (lambda name: name) if custom_ylab is None else custom_ylab
    dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    os.makedirs(f"{dir_path}/{plot_dir}/", exist_ok=True)
    path = f"{dir_path}/{plot_dir}/{plot_conf['filename']}"
    plot_type = plot_conf["type"]
    progress_desc = f"Plotting {plot_conf['filename']}"
    if plot_type == "lines2d":  # 2D line plot
        all_lines = []
        all_legends = []
        for legend, lineconf in tqdm(plot_conf["lines"].items(), desc=progress_desc):
            all_lines.append(_extract_line_data(lineconf, res_dir, nametag, exp_tags, ignore_missing))
            all_legends.append(legend)
        if not same_line:  # implicit values for x axis
            xlab, xvals = "steps", None
        else:  # explicit values for x axis: the first same_line parameter
            xlab, xvals = next(iter(same_line.items()))
        seeds_plot_together(
            all_lines, all_legends, title=plot_conf["title"], savepath=path,
            ylab=ylab_func(metric), xlab=xlab_func(xlab), x_vals=xvals, **plot_kwargs
        )
    elif plot_type in ("colors3d", "surface3d"):  # 3D plots
        (xlab, xvals), (ylab, yvals) = list(same_line.items())[:2]
        all_rows = [
            _extract_line_data(lineconf, res_dir, nametag, exp_tags, ignore_missing)
            for lineconf in tqdm(plot_conf["rows"], desc=progress_desc)
        ]
        # each row is a [seed][x] list, so all_rows is [y][seed][x];
        # the 3d plotting functions average over the first axis, hence [seed][y][x]
        all_seeds = np.transpose(np.array(all_rows), (1, 0, 2))
        plot_funcs = {"colors3d": seeds_plot_color3d, "surface3d": seeds_plot_surface3d}
        plot_funcs[plot_type](
            all_seeds, x_vals=xvals, y_vals=yvals, label=metric,
            title=plot_conf["title"], xlab=xlab_func(xlab), ylab=ylab_func(ylab),
            savepath=path, std=False, **plot_kwargs)
    else:
        raise ValueError(
            f"unknown plot type {plot_type!r}, expected 'lines2d', 'colors3d' or 'surface3d'")
|
|
128
|
+
|
|
129
|
+
# plot_conf = { "filename" : "plot", "title": "title_foo", "type":"3d",
|
|
130
|
+
# "rows": [lineconf1, lineconf2]}
|
|
131
|
+
def plot_experiments(
        res_dir=None, plot_dir=None, metrics=None,
        seeds=None, nametag="", params_common=None,
        diff_plots=None, same_plot=None, same_line=None, set_depending_params=None,
        exp_tags=None, ignore_missing=None, style3Dplot=None,
        custom_title=None, custom_legend=None,
        custom_xlab=None, custom_ylab=None,
        ):
    """ Plot a series of experiments

    One figure is made per entry of metrics and per combination of diff_plots. With fewer
    than two same_line parameters it is a 2D line plot (one line per same_plot combination
    and metric); with two it is a 3D plot of type style3Dplot. Plot options are read from
    "plot_config.json" in the current working directory when that file exists.

    Args:
        - res_dir (str) : directory containing the data to plot
        - plot_dir (str) : directory where to save the plot
        - metrics ((str | str tuple) list) : list of metrics to plot (a tuple puts several metrics on the same plot)
        - seeds (int list) : seeds of runs to plot
        - nametag (str) : prefix identifying the series of experiments
        - params_common (dict) : dictionary of {"param_name": value}
            that are the default parameters of experiment_func
        - diff_plots (list dict) : dictionary of {"param_name": list of values} each value combination on a different plot
        - same_plot (list dict) : dictionary of {"param_name": list of values} each value combination on a different line of the same plot
        - same_line (list dict or empty dict) : {"param_name": list of values} values on the x axis
            (and y axis for a 3D plot with two parameters); empty dict uses training steps as x axis
        - set_depending_params (func or None) : function editing in-place a dictionary of parameters,
            None when no parameter depends on the others
        - exp_tags (str list) : list of parameters used in experiment names
        - ignore_missing (bool) : True to plot despite missing data
        - style3Dplot (str) : 'surface3d' or 'colors3d'
        - custom_title (func or str) : (optional) function that gives a title from (params, metrics)
        - custom_legend (func) : (optional) function that gives a legend from (params, metric)
        - custom_xlab (func or str) : (optional) gives the x label from the x parameter name,
            default is the name itself
        - custom_ylab (func or str) : (optional) gives the y label from the metric / y parameter name,
            default is the name itself
    """
    # handling constant functions (a str becomes a constant function, None the identity)
    if isinstance(custom_title, str):
        _title = custom_title
        custom_title = lambda params, metrics_plot: _title

    def _as_label_func(lab):
        if lab is None:
            return lambda name: name
        if isinstance(lab, str):
            return lambda name: lab
        return lab

    custom_xlab = _as_label_func(custom_xlab)
    custom_ylab = _as_label_func(custom_ylab)
    if set_depending_params is None:
        def _no_depending_params(params):
            """ default: no parameter is derived from the others """
        set_depending_params = _no_depending_params

    # infering plot type
    plot_type = "lines2d" if len(same_line) < 2 else style3Dplot
    # searching for config options
    if os.path.exists("plot_config.json"):
        plot_params = load_json("plot_config.json")[plot_type]
    else:
        print("plot_config.json not found, plotting without config")
        plot_params = {"_overwrite": {}}
    overwrite = plot_params.get("_overwrite", {})
    plot_tags = list(diff_plots.keys())
    legend_tags = list(same_plot.keys())

    def _pick_legend(params, metric, several_metrics):
        if custom_legend is not None:
            return custom_legend(params, metric)
        # with several metrics on one plot, each legend is prefixed by its own metric
        prefix = metric + ", " if several_metrics else ""
        return prefix + make_legend(params, legend_tags)

    # creating plot configs and plotting
    for metrics_plot in metrics:
        if isinstance(metrics_plot, str):  # if only one metric on the plot
            metrics_plot = [metrics_plot]
        several_metrics = len(metrics_plot) > 1
        for params1 in make_grid(diff_plots):
            if custom_title is None:
                title = make_title(params1, plot_tags)
            else:
                title = custom_title(params1, metrics_plot)
            plot_conf = {
                "title": title,
                "filename": make_plot_name(params1, nametag, metrics_plot[0], plot_tags)}
            if plot_type == "lines2d":  # if 2D line plot
                plot_conf["lines"] = {}
                for params2 in make_grid(same_plot):
                    for metric in metrics_plot:
                        if len(same_line) == 0:
                            params = update_params(params_common, params1, params2)
                            set_depending_params(params)
                            plot_conf["lines"][_pick_legend(params, metric, several_metrics)] = {
                                "seeds": seeds, "metric": copy.deepcopy(metric), "conf": params
                            }
                        else:
                            params_list = []
                            for params3 in make_grid(same_line):
                                params = update_params(params_common, params1, params2, params3)
                                set_depending_params(params)
                                params_list.append(params)
                            # params is the last same_line combination; make_legend only
                            # receives the same_plot tags
                            plot_conf["lines"][_pick_legend(params, metric, several_metrics)] = {
                                "seeds": seeds, "metric": copy.deepcopy(metric), "confs": params_list,
                            }
            elif plot_type in ("colors3d", "surface3d"):  # if 3d plot
                (x_param, x_values), (y_param, y_values) = list(same_line.items())[:2]
                # NOTE(review): rows are rebuilt for every metric, so only the last metric of
                # metrics_plot ends up plotted (under the filename of the first one)
                for metric in metrics_plot:
                    plot_conf["rows"] = []
                    for y_value in y_values:
                        params_list = []
                        for x_value in x_values:
                            params = update_params(params_common, params1, {x_param: x_value, y_param: y_value})
                            set_depending_params(params)
                            params_list.append(params)
                        plot_conf["rows"].append(
                            {"seeds": seeds, "metric": copy.deepcopy(metric), "confs": params_list})
            plot_conf["type"] = plot_type
            # plot options: entry of the (last) metric in plot_config.json, then "_overwrite"
            metric = metrics_plot[-1]
            plot_kwargs = update_params(plot_params.get(metric, {}), overwrite)
            _plot_from_conf(
                plot_conf, res_dir, plot_dir,
                nametag, exp_tags,
                same_line, metric, ignore_missing, plot_kwargs,
                custom_xlab, custom_ylab
            )
|
|
248
|
+
# lineconf = {"confs": conf_list, "seeds":[1, 2], "metric": "metric"}
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
###
|
|
2
|
+
# @file reproduce.py
|
|
3
|
+
# @author Oscar Villemaud <oscar.villemaud@epfl.ch>
|
|
4
|
+
#
|
|
5
|
+
# @section LICENSE
|
|
6
|
+
#
|
|
7
|
+
# Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# @section DESCRIPTION
|
|
11
|
+
#
|
|
12
|
+
# reproplot main functions to run, plot and manage experiments.
|
|
13
|
+
###
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import random
|
|
17
|
+
from tqdm import tqdm
|
|
18
|
+
|
|
19
|
+
from .utils import load_json, make_exp_name, check_compatibility
|
|
20
|
+
from .reprod_exps import run_experiments
|
|
21
|
+
from .reprod_plots import plot_experiments
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def rename_exps(
        directory, new_tag, new_exp_tags, old_tag=None):
    """Rename experiment directories using the parameters saved inside them.

    Each sub-directory of ``directory`` is read for its ``params.json`` and
    renamed to the name produced by ``make_exp_name``.

    Args:
        - directory (str) : directory of experiments
        - new_tag (str) : new prefix for experiment names
        - new_exp_tags (str list) : names of the parameters to put in the new experiment names
        - old_tag (str) : if given, only rename experiments whose prefix
          (the text before the first "-") equals this tag
    """
    print("Renaming experiments of directory :", directory)
    subdirs = [(entry.name, entry.path) for entry in os.scandir(directory) if entry.is_dir()]
    n_renamed = 0
    for old_name, old_path in subdirs:
        # Guard clause: skip experiments from another series when a tag filter is set.
        if old_tag is not None and old_name.split("-")[0] != old_tag:
            continue
        saved_params = load_json(old_path + "/params.json")
        new_name = make_exp_name(saved_params, new_tag, tags_list=new_exp_tags)
        os.rename(old_path, directory + "/" + new_name)
        n_renamed += 1
        print("renaming:", old_name, "\n into: ", new_name)
    print(f"Renamed {n_renamed} experiments")
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _param_sort_key(value):
    """Sort key ordering the distinct values observed for one parameter.

    ``None`` ranks as 0, numbers (bools included) rank by their value, and
    any other value (strings, tuples converted from JSON lists) ranks by its
    length. The second tuple element breaks ties deterministically, so the
    printed order does not depend on set iteration order.
    """
    if value is None:
        return (0, "")
    if isinstance(value, (int, float)):
        return (value, "")
    return (len(value), str(value))


def index_exps(directory, nametag=None, param_requirements=None):
    """Print a summary of the experiments available in ``directory``.

    Each sub-directory is treated as an experiment, and its ``params.json``
    is loaded. For every selected experiment, the set of values seen for each
    parameter and the seeds found in its sub-directories are collected and
    printed.

    Args:
        - directory (str) : directory of experiments
        - nametag (str) : if given, only include experiments whose prefix
          (the text before the first "-") equals this tag
        - param_requirements (bool func) : function taking the loaded params
          as input and returning True if the experiment should be included
    """
    exp_names_paths = [(f.name, f.path) for f in os.scandir(directory) if f.is_dir()]
    all_params = {}
    seeds = set()
    nb_runs, nb_exps = 0, 0
    for exp_name, exp_path in tqdm(exp_names_paths):
        loaded_params = load_json(exp_path + "/params.json")
        select = True
        if param_requirements is not None:
            select = param_requirements(loaded_params)
        if nametag is not None:
            prefix = exp_name.split("-")[0]
            select = select and (prefix == nametag)
        if not select:
            continue
        nb_exps += 1
        for name, value in loaded_params.items():
            # JSON lists are unhashable; store them as tuples so they fit in a set.
            if isinstance(value, list):
                value = tuple(value)
            all_params.setdefault(name, set()).add(value)
        for seed_name in [f.name for f in os.scandir(exp_path) if f.is_dir()]:
            # Directories containing "failed" are skipped. Other names are parsed
            # from their 6th character on (e.g. "seed_12" -> 12).
            if "failed" not in seed_name:
                seeds.add(int(seed_name[5:]))
                nb_runs += 1
    print(f"found {nb_runs} runs grouped in {nb_exps}/{len(exp_names_paths)} experiments with parameters :")
    for name, values in all_params.items():
        print(name, sorted(values, key=_param_sort_key))
    print("seeds", sorted(seeds))
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def run_and_plot(
        seeds=None, res_dir="results_RPP", plot_dir="plots_RPP",
        metrics=None, experiment_func=None,
        nametag="", params_common=None,
        diff_plots=None, same_plot=None, same_line=None,
        exp_tags=None, set_depending_params=lambda x : None,
        no_run=False, no_plot=False,
        ignore_missing=False, style3Dplot="colors3d", verb=None,
        custom_title=None, custom_legend=None, custom_xlab=lambda x:x, custom_ylab=lambda x:x,
        ):
    """Run and plot experiments over all hyperparameter combinations (grid search).

    Args:
        - seeds (int list) : seeds of the runs; if None, a single random seed is drawn
        - res_dir (str) : directory containing the experiment results
        - plot_dir (str) : directory where to save the plots
        - metrics ((str | str tuple) list) : metrics to plot (a tuple groups
          metrics on the same plot); defaults to ["metric"]
        - experiment_func (func) : function running a single experiment,
          forwarded to run_experiments
        - nametag (str) : prefix identifying the series of experiments
        - params_common (dict) : dictionary of {"param_name": value} holding
          the default parameters of experiment_func; defaults to {}
        - diff_plots (list dict) : dictionary of {"param_name": list of values};
          each value combination goes on a different plot; defaults to {}
        - same_plot (list dict) : dictionary of {"param_name": list of values};
          each value combination goes on a different line of the same plot; defaults to {}
        - same_line (list dict) : dictionary of {"param_name": list of values}
          giving the values on the x axis of the plot. An empty dict (the
          default) uses training steps as x axis. Two parameters appear to
          select a 3D plot (see style3Dplot).
        - exp_tags (str list) : parameters to put in experiment names; if None,
          the sorted names of all parameters of diff_plots, same_plot and same_line
        - set_depending_params (func) : function editing a dictionary of parameters in place
        - no_run (bool) : True to disable running new experiments
        - no_plot (bool) : True to disable plotting
        - ignore_missing (bool) : True to plot despite missing data
        - style3Dplot (str) : 'surface3d' or 'colors3d'
        - verb (str list) : string codes selecting verbosity; defaults to ["pbar"]
        - custom_title (func or str) : (optional) function that gives a title from parameters
        - custom_legend (func) : (optional) function that gives a legend from parameters
        - custom_xlab (func or str) : (optional) function that gives the x label from the x parameter name
        - custom_ylab (func or str) : (optional) function that gives the y label from the y parameter name
    """
    # Build fresh containers on every call. Mutable default arguments are
    # evaluated once and would be shared (and leak edits) across calls.
    if metrics is None:
        metrics = ["metric"]
    if params_common is None:
        params_common = {}
    if diff_plots is None:
        diff_plots = {}
    if same_plot is None:
        same_plot = {}
    if same_line is None:
        same_line = {}
    if verb is None:
        verb = ["pbar"]
    if exp_tags is None:
        exp_tags = sorted(set(diff_plots) | set(same_plot) | set(same_line))
    if seeds is None:
        seeds = [random.randint(1, 99999)]
        print(f"No seeds specified, using random seed {seeds[0]}")
    check_compatibility(diff_plots, same_plot, same_line, exp_tags)

    if not no_run:
        run_experiments(
            experiment_func=experiment_func, res_dir=res_dir, seeds=seeds, nametag=nametag,
            params_common=params_common, diff_plots=diff_plots, same_plot=same_plot, same_line=same_line,
            exp_tags=exp_tags, set_depending_params=set_depending_params, verb=verb,
        )

    if not no_plot:
        plot_experiments(
            res_dir=res_dir, plot_dir=plot_dir, seeds=seeds,
            nametag=nametag, metrics=metrics, same_line=same_line,
            params_common=params_common, diff_plots=diff_plots, same_plot=same_plot,
            exp_tags=exp_tags, set_depending_params=set_depending_params,
            ignore_missing=ignore_missing, style3Dplot=style3Dplot,
            custom_title=custom_title, custom_legend=custom_legend,
            custom_xlab=custom_xlab, custom_ylab=custom_ylab,
        )
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
###
|
|
2
|
+
# @file utils.py
|
|
3
|
+
# @author Oscar Villemaud <oscar.villemaud@epfl.ch>
|
|
4
|
+
#
|
|
5
|
+
# @section LICENSE
|
|
6
|
+
#
|
|
7
|
+
# Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
|
|
8
|
+
# All rights reserved.
|
|
9
|
+
#
|
|
10
|
+
# @section DESCRIPTION
|
|
11
|
+
#
|
|
12
|
+
# Utilitary functions.
|
|
13
|
+
###
|
|
14
|
+
|
|
15
|
+
import os
|
|
16
|
+
import json
|
|
17
|
+
import pickle
|
|
18
|
+
import random
|
|
19
|
+
from itertools import product
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def dump_json(object, path, indent=0):
    """Serialize ``object`` as JSON into the file at ``path``.

    Args:
        - object : JSON-serializable python object to save
        - path (str) : destination file, overwritten if it already exists
        - indent (int) : indentation level passed to json.dump (for readability)
    """
    with open(path, mode="w") as handle:
        json.dump(object, handle, indent=indent)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def load_json(path):
    """Read the JSON file at ``path`` and return the decoded python object.

    Args:
        - path (str) : path of the JSON file to read
    Returns:
        - the deserialized object
    """
    with open(path, mode="r") as handle:
        content = json.load(handle)
    return content
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def dump_pickle(object, path):
    """Serialize ``object`` with pickle into the file at ``path``.

    Args:
        - object : picklable python object to save
        - path (str) : destination file, overwritten if it already exists
    """
    with open(path, mode="wb") as handle:
        pickle.dump(object, handle)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def load_pickle(path):
    """Read the pickle file at ``path`` and return the stored python object.

    NOTE(review): unpickling can execute arbitrary code. Only use this on
    files produced by this project, never on untrusted input.

    Args:
        - path (str) : path of the pickle file to read
    Returns:
        - the deserialized object
    """
    with open(path, mode="rb") as handle:
        content = pickle.load(handle)
    return content
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def seedall(seed):
    """Seed the ``random``, NumPy and PyTorch global generators with one value.

    Args:
        - seed (int) : seed to use for all three generators
    """
    # The three generators are independent, so the seeding order does not matter.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def make_grid(list_dic):
    """Iterate over every combination of parameter values (grid search).

    Args:
        - list_dic (list dict): dictionary of {"param_name": list of values}

    Yields:
        - (dict) one {"param_name": value} mapping per combination, in
          itertools.product order (an empty dict yields a single empty mapping)
    """
    for combination in product(*list_dic.values()):
        yield dict(zip(list_dic.keys(), combination))
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def count_combinations(list_dic):
    """Count how many parameter combinations a grid produces.

    Args:
        - list_dic : dictionary of {"param_name": list of values}

    Returns:
        - (int) product of the lengths of all value lists (1 for an empty dict)
    """
    total = 1
    for values in list_dic.values():
        total = total * len(values)
    return total
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def update_params(params, *updates):
    """Return a copy of ``params`` with each update dictionary applied in order.

    The input dictionary is not modified. Later updates override earlier ones.

    Args:
        - params : dictionary of parameters
        - *updates : dictionaries of additional or replacement parameters
    Return:
        - (dict) the merged dictionary of parameters
    """
    merged = params.copy()
    for overrides in updates:
        merged.update(overrides)
    return merged
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def make_exp_name(params, nametag, tags_list):
    """Build an experiment name from a prefix and selected parameter values.

    Args:
        - params : dictionary of {"param_name": value}
        - nametag (str) : prefix identifying the experiment series
        - tags_list (str list) : names of the parameters to put in the experiment name

    Returns:
        - (str) "<nametag>-<name1>_<value1>-<name2>_<value2>..."
    """
    return nametag + "".join(f"-{tag}_{params[tag]}" for tag in tags_list)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def make_plot_name(params, nametag, metric, tags_list):
    """Build a plot name from the series prefix, the metric and selected parameters.

    The tags are expected to be the parameters that vary across plots (diff_plots).

    Args:
        - params : dictionary of {"param_name": value}
        - nametag (str) : prefix identifying the experiment series
        - metric (str) : metric plotted
        - tags_list (str list) : names of the parameters to put in the plot name

    Returns:
        - (str) "<nametag>-<metric>-<name1>_<value1>..."
    """
    suffix = "".join(f"-{tag}_{params[tag]}" for tag in tags_list)
    return f"{nametag}-{metric}" + suffix
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def make_title(params, tags_list):
    """Build a plot title listing selected parameter values.

    Args:
        - params : dictionary of {"param_name": value}
        - tags_list (str list) : names of the parameters to put in the title

    Returns:
        - (str) "name1=value1, name2=value2, ..." or "X" when tags_list is empty
    """
    if len(tags_list) == 0:
        return "X"
    return ", ".join(f"{tag}={params[tag]}" for tag in tags_list)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def make_legend(params, tags_list):
    """Build the legend of one plot line from selected parameter values.

    Args:
        - params : dictionary of {"param_name": value}
        - tags_list (str list) : names of the parameters to put in the legend

    Returns:
        - (str) "name1=value1, name2=value2, ..." or "X" when tags_list is empty
    """
    if len(tags_list) == 0:
        return "X"
    entries = [f"{tag}={params[tag]}" for tag in tags_list]
    return ", ".join(entries)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def check_compatibility(
        diff_plots, same_plot, same_line, exp_tags
        ):
    """Warn about swept parameters missing from the experiment names.

    If a parameter varies in the grid but is not part of ``exp_tags``,
    different runs could be saved under the same experiment name.

    Args:
        - diff_plots (list dict) : dictionary of {"param_name": list of values},
          each value combination on a different plot
        - same_plot (list dict) : dictionary of {"param_name": list of values},
          each value combination on a different line of the same plot
        - same_line (list dict) : dictionary of {"param_name": list of values},
          values on the x axis of the plot
        - exp_tags (str list) : parameters put in experiment names
    """
    swept = [*diff_plots, *same_plot, *same_line]
    for param in swept:
        if param in exp_tags:
            continue
        print(f"WARNING : Experiment names should depend on -{param} to avoid having the same name")
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def scan_runs(
        res_dir, dir_path=None,
        seeds=None, nametag="", params_common=None,
        diff_plots=None, same_plot=None, same_line=None,
        exp_tags=None, set_depending_params=lambda x: None):
    """Find the existing and missing runs of a grid of experiments.

    Every combination of diff_plots x same_plot x same_line is merged over
    params_common and completed by set_depending_params. It is then mapped to
    an experiment directory named by make_exp_name, and each seed of that
    experiment is looked up as a ``seed_<seed>/`` sub-directory.

    Args:
        - res_dir (str) : name of the directory where results are stored
        - dir_path (str) : path to the directory containing -res_dir;
          None resolves -res_dir relative to the current working directory
        - seeds (int list) : random seeds of the runs (required)
        - nametag (str) : prefix that identifies the series of experiments
        - params_common (dict) : dictionary of {"param_name": value} holding
          the default parameters of the experiment function; defaults to {}
        - diff_plots (list dict) : dictionary of {"param_name": list of values},
          each value combination on a different plot; defaults to {}
        - same_plot (list dict) : dictionary of {"param_name": list of values},
          each value combination on a different line of the same plot; defaults to {}
        - same_line (list dict) : dictionary of {"param_name": list of values},
          values on the x axis of the plot; defaults to {}
        - exp_tags (str list) : parameters put in experiment names; if None,
          the sorted names of all parameters of diff_plots, same_plot and same_line
        - set_depending_params (func) : function editing a dictionary of parameters in place

    Returns:
        - count_done (int) : number of runs whose seed directory already exists
        - count_total (int) : total number of runs in the grid
        - confs_paths (list of (dict, str)) : (parameters, seed directory path)
          for every missing run

    Raises:
        - ValueError : if seeds is None
    """
    if seeds is None:
        raise ValueError("scan_runs requires an explicit list of seeds")
    # None defaults previously crashed (AttributeError/TypeError); normalize them.
    params_common = {} if params_common is None else params_common
    diff_plots = {} if diff_plots is None else diff_plots
    same_plot = {} if same_plot is None else same_plot
    same_line = {} if same_line is None else same_line
    if exp_tags is None:
        exp_tags = sorted(set(diff_plots) | set(same_plot) | set(same_line))
    # Keep the original "<dir_path>/<res_dir>" layout when dir_path is given.
    # Avoid a literal "None/" prefix otherwise.
    base_dir = res_dir if dir_path is None else f"{dir_path}/{res_dir}"

    confs_paths = []
    count_done, count_total = 0, 0
    for params1 in make_grid(diff_plots):
        for params2 in make_grid(same_plot):
            for params3 in make_grid(same_line):
                params = update_params(params_common, params1, params2, params3)
                set_depending_params(params)
                exp_name = make_exp_name(params, nametag, exp_tags)
                exp_dir = f"{base_dir}/{exp_name}/"
                for seed in seeds:
                    path = exp_dir + f"seed_{seed}/"
                    if os.path.isdir(path):
                        count_done += 1
                    else:
                        # Copy so each missing run keeps its own parameter dict.
                        confs_paths.append((params.copy(), path))
                    count_total += 1
    return count_done, count_total, confs_paths
|