reproplot 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Oscar Villemaud, Indy Lab, EPFL
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,33 @@
1
+ Metadata-Version: 2.1
2
+ Name: reproplot
3
+ Version: 0.0.1
4
+ Summary: General package for experimenting and automatically plotting.
5
+ Author-email: Oscar Villemaud <oscar.villemaud@epfl.ch>
6
+ License: MIT License
7
+
8
+ Copyright (c) 2024 Oscar Villemaud, Indy Lab, EPFL
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+ License-File: LICENSE
28
+ Classifier: Operating System :: OS Independent
29
+ Classifier: Programming Language :: Python :: 3
30
+ Requires-Python: >=3.8
31
+ Description-Content-Type: text/markdown
32
+
33
+ General package for experimenting and automatically plotting.
@@ -0,0 +1 @@
1
+ General package for experimenting and automatically plotting.
@@ -0,0 +1,25 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+
6
+ [project]
7
+ license = {file = "LICENSE"}
8
+ name = "reproplot"
9
+ version = "0.0.1"
10
+ authors = [
11
+ { name="Oscar Villemaud", email="oscar.villemaud@epfl.ch" },
12
+ ]
13
+ description = "General package for experimenting and automatically plotting. "
14
+ readme = "README.md"
15
+ requires-python = ">=3.8"
16
+ classifiers = [
17
+ "Programming Language :: Python :: 3",
18
+ "Operating System :: OS Independent",
19
+ ]
20
+
21
+ [tool.hatch.build.targets.wheel]
22
+ include = [
23
+ "reproplot/*.py",
24
+ "/tests",
25
+ ]
File without changes
@@ -0,0 +1,191 @@
1
+ ###
2
+ # @file plotting.py
3
+ # @author Oscar Villemaud <oscar.villemaud@epfl.ch>
4
+ #
5
+ # @section LICENSE
6
+ #
7
+ # Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
8
+ # All rights reserved.
9
+ #
10
+ # @section DESCRIPTION
11
+ #
12
+ # Plotting functions based on pyplot.
13
+ ###
14
+
15
+ import numpy as np
16
+ from matplotlib import pyplot as plt
17
+ from matplotlib import colors
18
+
19
+
20
def _finalize_plot(title, xlab, ylab, fontsize, savepath, extension, show):
    """ decorate the current figure, then save and/or show it and close it

    Args:
        - title (str) : title of the plot
        - xlab (str) : label of the x axis
        - ylab (str) : label of the y axis
        - fontsize (int) : font size used for both axis labels
        - savepath (str) : path (without extension) where to save the figure, not saved if None
        - extension (str) : file extension / image format appended to savepath
        - show (bool) : True to display the figure before closing it
    """
    plt.title(title)
    for set_axis_label, text in ((plt.xlabel, xlab), (plt.ylabel, ylab)):
        set_axis_label(text, size=fontsize)
    if savepath is not None:
        plt.savefig("{}.{}".format(savepath, extension))
    if show:
        plt.show()
    # close the current figure so that successive plots start from a clean state
    plt.close()
29
+
30
+
31
def plot(data, legend=None, title="", log=False):
    """ simple plotting function: draw one or several curves and show them

    Args:
        - data (list or list of lists) : a single sequence of values, or a sequence of
          sequences (one curve each); whether data[0] is iterable decides which case applies
        - legend (str list) : labels of the curves, no legend if None
        - title (str) : title of the plot
        - log (bool) : True for a log-scaled y axis
    """
    draw = plt.semilogy if log else plt.plot
    curves = data if hasattr(data[0], '__iter__') else [data]
    for curve in curves:
        draw(curve)
    if legend is not None:
        plt.legend(legend)
    plt.title(title)
    plt.show()
    plt.close()
48
+
49
+
50
def seeds_plot(
        list_list, legend=None, x_vals=None, color=None, style=None,
        xlog=False, ylog=False, confidence=True, std=False, plot_all=False, x_vlines=False):
    """ plots one line and confidence interval from multiple seeds

    Args:
        - list_list (float list list) : one sublist of values per seed,
          NaN values are ignored in the statistics
        - legend (str) : label of the line, None for no label
        - x_vals (list) : x axis values, default is 0 to n-1 (n = length of the first seed)
        - color (str) : color of the line
        - style (str) : line style of the line
        - xlog (bool) : True for x axis log scale
        - ylog (bool) : True for y axis log scale
        - confidence (bool) : True to draw the 95% confidence interval of the mean (thin lines)
        - std (bool) : True to shade one standard deviation across seeds around the mean
        - plot_all (bool) : True to draw one line per seed instead of the mean
        - x_vlines (bool) : True to draw a thin vertical line at each x value
    """
    if len(list_list) == 0:
        return
    if ylog and xlog:
        plot_func = plt.loglog
    elif ylog:
        plot_func = plt.semilogy
    elif xlog:
        plot_func = plt.semilogx
    else:
        plot_func = plt.plot

    if plot_all:
        for i, values in enumerate(list_list):
            if x_vals is None:
                x_vals = list(range(len(values)))
            # label only the first seed so the legend shows a single entry for this line
            plot_func(x_vals, values, label=legend if i == 0 else None,
                      linestyle=style, color=color)
    else:
        if x_vals is None:
            x_vals = list(range(len(list_list[0])))
        arr = np.array(list_list, dtype=float)
        means = np.nanmean(arr, axis=0)
        plot_func(x_vals, means, label=legend, linestyle=style, color=color)
        if confidence:
            # nanstd ignores missing (NaN) seeds, so the sample size must ignore them too,
            # otherwise the interval is too narrow where data is missing
            counts = np.sum(~np.isnan(arr), axis=0)
            with np.errstate(divide="ignore", invalid="ignore"):
                half_width = 1.96 * np.nanstd(arr, axis=0) / np.sqrt(counts)
            plt.plot(x_vals, means - half_width, linestyle=style, color=color, linewidth=0.3)
            plt.plot(x_vals, means + half_width, linestyle=style, color=color, linewidth=0.3)
        if std:
            stds = np.nanstd(arr, axis=0)
            plt.fill_between(list(x_vals), means - stds, means + stds, alpha=0.1, color=color)
    if x_vlines:
        for x in x_vals:
            plt.axvline(x, linewidth=0.3)
87
+
88
+
89
def seeds_plot_together(
        all_curves, legends=None, title="", xlog=False, ylog=False, confidence=True, std=False,
        vlines=(), x_vals=None, xlab=None, ylab=None, fontsize=11, xlims=None, ylims=None,
        savepath=None, figsize=(8, 5), show=False, plot_all=False, x_vlines=False, extension="png",
        ):
    """
    Plots several lines of several seeds (for each line average and confidence interval)
    Args:
        - all_curves (float list list list) : order 3 array/list of lists of lists
            one sublist is one line, one subsublist is one seed
        - legends (str list) : labels to use for each line (in order), no legend if None;
            lines beyond the end of legends are drawn without a label
        - title (str) : title of the plot
        - xlog (bool) : True for x axis log scale
        - ylog (bool) : True for y axis log scale
        - confidence (bool) : True to display 95% mean estimate confidence intervals
        - std (bool) : True to display standard deviation across seeds
        - vlines (float list) : list of x coordinates where to add vertical lines
        - x_vals (list) : x axis values, default is 0 to n
        - xlab (str) : label of x axis
        - ylab (str) : label of y axis
        - fontsize (int) : font size
        - xlims (float pair) : plot limits for x axis, None for auto
        - ylims (float pair) : plot limits for y axis, None for auto
        - savepath (str) : path where to save the plot, not saved if None
        - figsize (int pair) : dimensions of the plot
        - show (bool) : True to show the plot
        - plot_all (bool) : True to display one line for each seed instead of average
        - x_vlines (bool) : True to draw a vertical line at each data point
        - extension (str) : format for the saved image file
    """
    plt.figure(figsize=figsize)
    # colors and styles are cycled independently, so any number of lines can be drawn
    palette = ["orange", "green", "blue", "red", "purple", "black"]
    line_styles = ["-", "--", "-.", ":", "-"]
    for i, curve in enumerate(all_curves):
        label = None
        if legends is not None and i < len(legends):
            label = legends[i]
        seeds_plot(
            curve, color=palette[i % len(palette)], x_vals=x_vals,
            style=line_styles[i % len(line_styles)], legend=label,
            xlog=xlog, ylog=ylog, confidence=confidence, std=std,
            plot_all=plot_all, x_vlines=x_vlines)
    if vlines is not None:
        for x in vlines:
            plt.axvline(x)
    if xlims is not None:
        plt.xlim(xlims)
    if ylims is not None:
        plt.ylim(ylims)
    if legends is not None:
        plt.legend(prop={'size': fontsize})
    _finalize_plot(title, xlab, ylab, fontsize, savepath, extension, show)
135
+
136
+
137
def seeds_plot_color3d(
        all_seeds, x_vals, y_vals, title, xlab, ylab,
        savepath, show=False, label="", fontsize=11, std=False,
        xlog=False, ylog=False, zlog=False, extension="png",
        **kwargs):
    """ 3d color plot: seed-averaged values drawn as a pseudocolor grid

    Args:
        - all_seeds (float array) : values indexed [seed][y][x], NaN values are ignored
        - x_vals (list) : x coordinates of the grid columns
        - y_vals (list) : y coordinates of the grid rows
        - title (str) : title of the plot
        - xlab (str) : label of x axis
        - ylab (str) : label of y axis
        - savepath (str) : path (without extension) where to save the plot, not saved if None
        - show (bool) : True to show the plot
        - label (str) : label of the color bar
        - fontsize (int) : font size of the axis labels
        - std (bool) : True to write the standard deviation across seeds in each cell
        - xlog (bool) : True for x axis log scale
        - ylog (bool) : True for y axis log scale
        - zlog (bool) : True for a logarithmic color scale
        - extension (str) : format for the saved image file
        - kwargs : accepted for interface compatibility and ignored
    """
    seeds_array = np.array(all_seeds)
    mean_grid = np.nanmean(seeds_array, axis=0)
    norm_kwargs = {"norm": colors.LogNorm()} if zlog else {}
    plt.pcolor(x_vals, y_vals, mean_grid, **norm_kwargs)

    if xlog:
        plt.xscale('log')
    if ylog:
        plt.yscale('log')
    if std:
        std_grid = np.nanstd(seeds_array, axis=0)
        for row_idx, row in enumerate(std_grid):
            for col_idx, spread in enumerate(row):
                # "+<std>" annotation centered on its cell
                plt.text(x_vals[col_idx], y_vals[row_idx], f"+{round(spread, 2)}",
                         ha='center', va='center', color='black')
    plt.colorbar(label=label)
    _finalize_plot(title, xlab, ylab, fontsize, savepath, extension, show)
163
+
164
+
165
def seeds_plot_surface3d(
        all_seeds, x_vals, y_vals, title, xlab, ylab,
        savepath, show=False, label="", fontsize=11,
        xlog=False, ylog=False, zlog=False, angle=None, extension="png",
        **kwargs):
    """ 3d surface plot of seed-averaged values

    Args:
        - all_seeds (float array) : values indexed [seed][y][x], NaN values are ignored
        - x_vals (list) : x coordinates of the surface
        - y_vals (list) : y coordinates of the surface
        - title (str) : title of the plot
        - xlab (str) : label of x axis
        - ylab (str) : label of y axis
        - savepath (str) : path (without extension) where to save the plot, not saved if None
        - show (bool) : True to show the plot
        - label (str) : label of the z axis
        - fontsize (int) : font size of the axis labels
        - xlog, ylog (bool) : currently ignored (x/y log scaling is disabled for this plot)
        - zlog (bool) : True for z axis log scale
        - angle (tuple) : arguments forwarded to view_init (elevation, azimuth, ...), None for default
        - extension (str) : format for the saved image file
        - kwargs : accepted for interface compatibility and ignored
    """
    mean_grid = np.nanmean(np.array(all_seeds), axis=0)
    ax = plt.axes(projection='3d')
    if angle is not None:
        ax.view_init(*angle)
    # both grids have shape (len(y_vals), len(x_vals)):
    # grid_x[i, j] == x_vals[j] and grid_y[i, j] == y_vals[i]
    grid_x, grid_y = np.meshgrid(x_vals, y_vals)
    if zlog:
        # NOTE(review): log scaling of 3D axes depends on the matplotlib version -- confirm
        ax.set_zscale('log')
    ax.plot_surface(grid_x, grid_y, mean_grid, cmap='viridis', edgecolor='green')
    ax.set_zlabel(label, size=fontsize)
    _finalize_plot(title, xlab, ylab, fontsize, savepath, extension, show)
@@ -0,0 +1,141 @@
1
+ ###
2
+ # @file reprod_exps.py
3
+ # @author Oscar Villemaud <oscar.villemaud@epfl.ch>
4
+ #
5
+ # @section LICENSE
6
+ #
7
+ # Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
8
+ # All rights reserved.
9
+ #
10
+ # @section DESCRIPTION
11
+ #
12
+ # Experiment running.
13
+ ###
14
+
15
+ import os
16
+ import time
17
+ from tqdm import tqdm
18
+ import traceback
19
+ import multiprocessing as mp
20
+
21
+ from .utils import seedall, dump_json, load_json, make_exp_name, make_grid, update_params, scan_runs
22
+
23
+
24
def _run_one_run(experiment_func, params, seed, exp_name, run_path, verb):
    """ run one seeded experiment and save its metrics

    Args:
        - experiment_func (func) : experiment to run, called as experiment_func(**params)
        - params (dict) : keyword arguments of experiment_func
        - seed (int) : random seed applied with seedall() before the run
        - exp_name (str) : name of the experiment (used in messages)
        - run_path (str) : existing directory of this run, ending with "/"
        - verb (str list) : verbosity codes, "runs" prints per-run messages (None for silent)

    Returns:
        - (int) 0 if the run completed, 1 if it failed

    Raises:
        - KeyboardInterrupt : re-raised (with its traceback) after the run was marked as failed
    """
    verb = () if verb is None else verb
    if "runs" in verb:
        print(f"running {exp_name} seed {seed}.")

    def _handle_error(exc):
        # move the partial run to a fresh "<run>_failed_<k>" directory (keeping it for
        # inspection) so that run_path no longer exists as a completed run
        base_path = run_path[:-1]  # strip the trailing "/"
        count = 1
        while os.path.isdir(f"{base_path}_failed_{count}"):
            count += 1
        new_path = f"{base_path}_failed_{count}"
        os.rename(run_path, new_path)
        print(f"{exp_name} seed {seed} failed with error: {exc}, renaming to: failed_{count}")
        with open(new_path + "/" + "traceback.txt", 'w') as f:
            f.write(traceback.format_exc())

    try:
        start_time = time.time()
        seedall(seed)
        metrics = experiment_func(**params)
        if not isinstance(metrics, dict):
            metrics = {"metric": metrics}
        run_time = time.time() - start_time
        metrics["run_time"] = run_time
        dump_json(metrics, run_path + "metrics.json")
        if "runs" in verb:
            print(f"{exp_name} seed {seed} saved. ({round(run_time)} secs)")
    except Exception as exc:
        _handle_error(exc)
        return 1
    except KeyboardInterrupt:
        _handle_error("Keyboard Interrupt")
        raise  # keep the original interrupt and its traceback
    return 0
55
+
56
+
57
def _check_or_save_params(exp_dir, params):
    """ make sure exp_dir can hold runs for params

    If exp_dir has no readable params.json yet, the directory is created (if needed) and
    params are saved there. Otherwise the saved parameters are compared with params after
    a JSON round-trip, so values that json stores differently (tuples become lists,
    non-string keys become strings) are not reported as a mismatch.

    Args:
        - exp_dir (str) : directory of the experiment, ending with "/"
        - params (dict) : full parameters of the experiment

    Returns:
        - (bool) True if runs can be added to exp_dir, False on a parameter mismatch
    """
    import json

    params_path = exp_dir + "params.json"
    try:
        loaded_params = load_json(params_path)
    except OSError:
        os.makedirs(exp_dir, exist_ok=True)
        dump_json(params, params_path, indent=2)
        return True
    try:
        expected = json.loads(json.dumps(params))
    except (TypeError, ValueError):  # not json-serializable: compare as-is
        expected = params
    if loaded_params == expected:
        return True
    print("WARNING : old and new parameters don't match despite same experiment name !")
    print("found :", loaded_params)
    print("but has :", params)
    return False


def run_experiments(
        experiment_func=None, res_dir=None,
        seeds=None, nametag="", params_common=None, diff_plots=None, same_plot=None, same_line=None,
        exp_tags=None, set_depending_params=None, verb=None):
    """ run an experiment series given a grid of parameters, saves results in json files
    Args:
        - experiment_func (func) : function running a (random) experiment
            and outputting a dictionary of metrics
        - res_dir (str) : name of the directory where to store results as json
        - seeds (int list) : list of random seeds to use for reproducibility
        - nametag (str) : prefix that identifies the series of experiments
        - params_common (dict) : dictionary of {"param_name": value}
            that are the default parameters of experiment_func
        - diff_plots (list dict) : dictionary of {"param_name": list of values} each value combination on a different plot
        - same_plot (list dict) : dictionary of {"param_name": list of values} each value combination on a different line of the same plot
        - same_line (list dict) : dictionary of {"param_name": list of values}, values on the x axis of the plot (only one parameter)
        - exp_tags (str list) : list of parameters to put in experiment names
        - set_depending_params (func) : function editing in-place a dictionary of parameters, None for no-op
        - verb (str list) : string codes to indicate verbose ("pbar", "runs"), None for none
    """
    verb = [] if verb is None else verb
    diff_plots = {} if diff_plots is None else diff_plots
    same_plot = {} if same_plot is None else same_plot
    same_line = {} if same_line is None else same_line
    if set_depending_params is None:
        set_depending_params = lambda params: None
    multiprocess = 0  # number of worker processes, 0 runs everything in this process
    # NOTE(review): results are stored under the parent directory of this package, not the
    # working directory -- confirm this is intended when reproplot is pip-installed.
    dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    nb_runs_found, nb_runs_total, _confs_paths = scan_runs(
        res_dir=res_dir, dir_path=dir_path,
        seeds=seeds, nametag=nametag, params_common=params_common,
        diff_plots=diff_plots, same_plot=same_plot, same_line=same_line,
        exp_tags=exp_tags, set_depending_params=set_depending_params)
    # TODO use confs_paths ?
    nb_runs_needed = nb_runs_total - nb_runs_found
    print(f"Recovered {nb_runs_found}/{nb_runs_total} runs. Running the remaining {nb_runs_needed}.")
    nb_runs_ran, nb_failed = 0, 0
    init_time = time.time()
    mp_pool = mp.Pool(processes=multiprocess) if multiprocess else None
    with tqdm(total=nb_runs_needed, leave=True, disable=("pbar" not in verb)) as pbar:
        for params1 in make_grid(diff_plots):
            for params2 in make_grid(same_plot):
                for params3 in make_grid(same_line):
                    params = update_params(params_common, params1, params2, params3)
                    set_depending_params(params)
                    exp_name = make_exp_name(params, nametag, exp_tags)
                    exp_dir = f"{dir_path}/{res_dir}/{exp_name}/"
                    if not _check_or_save_params(exp_dir, params):
                        print("Skipping experiment")
                        continue
                    for seed in seeds:
                        run_path = exp_dir + f"seed_{seed}/"
                        if os.path.isdir(run_path):
                            if "runs" in verb:
                                print(exp_name, "seed", seed, "already exists.")
                            continue
                        os.makedirs(run_path)
                        run_args = (experiment_func, params, seed, exp_name, run_path, verb)
                        if mp_pool is not None:
                            mp_pool.apply_async(func=_run_one_run, args=run_args)
                        else:
                            nb_failed += _run_one_run(*run_args)
                        pbar.update()
                        nb_runs_ran += 1
    if mp_pool is not None:
        mp_pool.close()
        mp_pool.join()
    time_taken = time.time() - init_time
    if multiprocess:
        print("Unable to count fails in multiprocess mode")
        print(f"{nb_runs_ran} runs ran and {nb_runs_found} recovered in {round(time_taken)} seconds")
    else:
        print(f"{nb_runs_ran - nb_failed} runs ran, {nb_runs_found} recovered and {nb_failed} failed in {round(time_taken)} seconds")
141
+
@@ -0,0 +1,248 @@
1
+ ###
2
+ # @file reprod_plots.py
3
+ # @author Oscar Villemaud <oscar.villemaud@epfl.ch>
4
+ #
5
+ # @section LICENSE
6
+ #
7
+ # Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
8
+ # All rights reserved.
9
+ #
10
+ # @section DESCRIPTION
11
+ #
12
+ # Data retrieval for plotting.
13
+ ###
14
+
15
+ import os
16
+ import copy
17
+ from tqdm import tqdm
18
+
19
+ import numpy as np
20
+
21
+ from .utils import load_json, make_exp_name, make_grid, make_legend, make_plot_name, update_params, make_title
22
+ from .plotting import seeds_plot_together, seeds_plot_color3d, seeds_plot_surface3d
23
+
24
+
25
def _extract_line_data(lineconf, res_dir, nametag, exp_tags, ignore_missing):
    """ extracts data from json files and puts it in list of lists
    Args:
        - lineconf (dict) : dictionary containing directions to the data to plot on one line
            lineconf1 = {"confs": conf_list, "seeds":[1, 2], "metric": "metric"}
            (one value per configuration, lists are reduced to their last value)
            or
            lineconf1 = {"conf": conf, "seeds":[1, 2], "metric": "metric"}
            (the stored metric of a single configuration, e.g. a training curve)
        - res_dir (str) : directory containing the data to plot
        - nametag (str) : prefix identifying the series of experiments
        - exp_tags (str list) : list of parameters used in experiment names
        - ignore_missing (bool) : True to plot despite missing data

    Returns:
        - (float list list) list of lists of values of the metric, one sublist is one seed.
          With "confs", a missing run becomes NaN; with "conf", a missing seed is left out.

    Raises:
        - FileNotFoundError : a metrics.json is missing and ignore_missing is False
        - ValueError : lineconf has neither a "confs" nor a "conf" entry
    """
    if "confs" not in lineconf and "conf" not in lineconf:
        raise ValueError('lineconf must contain either a "confs" or a "conf" entry')
    dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

    def _run_path(params, seed):
        exp_name = make_exp_name(params, nametag, exp_tags)
        return f"{dir_path}/{res_dir}/{exp_name}/seed_{seed}/"

    def _handle_missing(path):
        if not ignore_missing:
            raise FileNotFoundError(f"{path} not found, use ignore_missing=True to plot anyway")
        print(f"{path} not found, plotting without it.")

    all_seeds = []
    if "confs" in lineconf:  # if the metric chosen gives one value (eg final loss)
        for seed in lineconf["seeds"]:
            one_seed = []
            for params in lineconf["confs"]:
                path = _run_path(params, seed)
                try:
                    value = load_json(path + "metrics.json")[lineconf["metric"]]
                except OSError:
                    _handle_missing(path)
                    value = np.nan
                if isinstance(value, list):  # implicitly taking value at end of training
                    value = value[-1]
                one_seed.append(value)
            all_seeds.append(one_seed)
    else:  # "conf": the metric chosen gives a list of values (eg training loss)
        params = lineconf["conf"]
        for seed in lineconf["seeds"]:
            path = _run_path(params, seed)
            try:
                all_seeds.append(load_json(path + "metrics.json")[lineconf["metric"]])
            except OSError:
                _handle_missing(path)
    return all_seeds
76
+
77
+
78
def _plot_from_conf(
        plot_conf, res_dir, plot_dir, nametag, exp_tags,
        same_line, metric, ignore_missing, plot_kwargs=None,
        custom_xlab=None, custom_ylab=None):
    """ creates and saves a plot from a config
    Args:
        - plot_conf (dict) : configuration of the plot, including instructions on the data to use
            ex : plot_conf = {"filename": "plot", "title": "title_foo", "type": "lines2d",
                              "lines": {"legend1": lineconf1, "legend2": lineconf2}}
            or   plot_conf = {"filename": "plot", "title": "title_foo", "type": "colors3d",
                              "rows": [lineconf1, lineconf2]}
        - res_dir (str) : directory containing the data to plot
        - plot_dir (str) : directory where to save the plot
        - nametag (str) : prefix identifying the series of experiments
        - exp_tags (str list) : list of parameters used in experiment names
        - same_line (list dict or empty dict) : {"param_name": list of values} values on x axis,
            empty dict will use training steps as x axis; 3D plots use the first two entries
            as x and y axes
        - metric (str) : name of the metric to plot
        - ignore_missing (bool) : True to plot despite missing data
        - plot_kwargs (dict) : dictionary of parameters to forward to the plotting function
        - custom_xlab (func) : function that gives x label from x parameter name, None for identity
        - custom_ylab (func) : function that gives y label from y parameter name, None for identity

    Raises:
        - ValueError : unknown plot type, or 3D plot with fewer than two same_line parameters
    """
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    same_line = {} if same_line is None else same_line
    if custom_xlab is None:
        custom_xlab = lambda name: name
    if custom_ylab is None:
        custom_ylab = lambda name: name
    plot_type = plot_conf["type"]
    dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
    os.makedirs(f"{dir_path}/{plot_dir}/", exist_ok=True)
    path = f"{dir_path}/{plot_dir}/{plot_conf['filename']}"
    progress_desc = f"Plotting {plot_conf['filename']}"
    if plot_type == "lines2d":  # 2D line plot
        all_lines, all_legends = [], []
        for legend, lineconf in tqdm(plot_conf["lines"].items(), desc=progress_desc):
            all_lines.append(_extract_line_data(lineconf, res_dir, nametag, exp_tags, ignore_missing))
            all_legends.append(legend)
        if len(same_line) == 0:  # implicit values for x axis
            xlab, xvals = "steps", None
        else:  # explicit values for x axis
            xlab, xvals = next(iter(same_line.items()))
        seeds_plot_together(
            all_lines, all_legends, title=plot_conf["title"], savepath=path,
            ylab=custom_ylab(metric), xlab=custom_xlab(xlab), x_vals=xvals, **plot_kwargs
        )
    elif plot_type in ("colors3d", "surface3d"):  # 3D plots
        if len(same_line) < 2:
            raise ValueError(f"a '{plot_type}' plot needs two same_line parameters, got {len(same_line)}")
        (xlab, xvals), (ylab, yvals) = list(same_line.items())[:2]
        all_rows = [
            _extract_line_data(lineconf, res_dir, nametag, exp_tags, ignore_missing)
            for lineconf in tqdm(plot_conf["rows"], desc=progress_desc)
        ]
        # each row holds one list per seed, so the stack is (row, seed, value);
        # move the seed axis first for the 3D plotting functions
        all_seeds = np.transpose(np.array(all_rows), (1, 0, 2))
        plot_funcs = {"colors3d": seeds_plot_color3d, "surface3d": seeds_plot_surface3d}
        plot_funcs[plot_type](
            all_seeds, x_vals=xvals, y_vals=yvals, label=metric,
            title=plot_conf["title"], xlab=custom_xlab(xlab), ylab=custom_ylab(ylab),
            savepath=path, std=False, **plot_kwargs)
    else:
        raise ValueError(
            f"unknown plot type {plot_type!r}, expected 'lines2d', 'colors3d' or 'surface3d'")
128
+
129
def _make_lines_conf(metrics_plot, params1, params_common, same_plot, same_line, seeds,
                     set_depending_params, legend_tags, custom_legend):
    """ builds the "lines" entry of a 2D plot config
    Args:
        - metrics_plot (str list) : metrics drawn on this plot
        - params1 (dict) : parameters identifying this plot (one diff_plots combination)
        - params_common (dict) : default parameters
        - same_plot (list dict) : {"param_name": list of values}, one line per combination
        - same_line (list dict or empty dict) : {"param_name": list of values} on the x axis
        - seeds (int list) : seeds of runs to plot
        - set_depending_params (func) : function editing in-place a dictionary of parameters
        - legend_tags (str list) : parameters shown in default legends
        - custom_legend (func) : function giving a legend from (params, metric), None for default

    Returns:
        - (dict) {legend: lineconf}, lineconf = {"seeds": ..., "metric": ..., "conf": params}
          when same_line is empty, else {"seeds": ..., "metric": ..., "confs": params_list}
    """
    lines = {}
    for params2 in make_grid(same_plot):
        for metric in metrics_plot:
            # name the metric in the legend only when several metrics share the plot
            prefix = metric + ", " if len(metrics_plot) > 1 else ""
            if not same_line:
                params = update_params(params_common, params1, params2)
                set_depending_params(params)
                lineconf = {"seeds": seeds, "metric": metric, "conf": params}
            else:
                params_list = []
                for params3 in make_grid(same_line):
                    params = update_params(params_common, params1, params2, params3)
                    set_depending_params(params)
                    params_list.append(params)
                lineconf = {"seeds": seeds, "metric": metric, "confs": params_list}
            # the legend is built from the last parameter set of the line
            if custom_legend is None:
                legend = prefix + make_legend(params, legend_tags)
            else:
                legend = custom_legend(params, metric)
            lines[legend] = lineconf
    return lines


def _make_rows_conf(metric, params1, params_common, same_line, seeds, set_depending_params):
    """ builds the "rows" entry of a 3D plot config
    Args:
        - metric (str) : metric to plot
        - params1 (dict) : parameters identifying this plot (one diff_plots combination)
        - params_common (dict) : default parameters
        - same_line (list dict) : its first entry gives the x axis, its second the y axis
        - seeds (int list) : seeds of runs to plot
        - set_depending_params (func) : function editing in-place a dictionary of parameters

    Returns:
        - (dict list) one lineconf per y value, each with one parameter dict per x value
    """
    (x_param, x_values), (y_param, y_values) = list(same_line.items())[:2]
    rows = []
    for y_value in y_values:
        params_list = []
        for x_value in x_values:
            params = update_params(params_common, params1, {x_param: x_value, y_param: y_value})
            set_depending_params(params)
            params_list.append(params)
        rows.append({"seeds": seeds, "metric": metric, "confs": params_list})
    return rows


def plot_experiments(
        res_dir=None, plot_dir=None, metrics=None,
        seeds=None, nametag="", params_common=None,
        diff_plots=None, same_plot=None, same_line=None, set_depending_params=None,
        exp_tags=None, ignore_missing=None, style3Dplot=None,
        custom_title=None, custom_legend=None,
        custom_xlab=None, custom_ylab=None,
        ):
    """ Plot a series of experiments
    Args:
        - res_dir (str) : directory containing the data to plot
        - plot_dir (str) : directory where to save the plot
        - metrics ((str | str tuple) list) : list of metrics to plot (batch to put on same plot)
        - seeds (int list) : seeds of runs to plot
        - nametag (str) : prefix identifying the series of experiments
        - params_common (dict) : dictionary of {"param_name": value}
            that are the default parameters of experiment_func
        - diff_plots (list dict) : dictionary of {"param_name": list of values} each value combination on a different plot
        - same_plot (list dict) : dictionary of {"param_name": list of values} each value combination on a different line of the same plot
        - same_line (list dict or empty dict) : {"param_name": list of values} values on x axis,
            empty dict will use training steps as x axis, two entries give a 3D plot
        - set_depending_params (func) : function editing in-place a dictionary of parameters, None for no-op
        - exp_tags (str list) : list of parameters used in experiment names
        - ignore_missing (bool) : True to plot despite missing data
        - style3Dplot (str) : 'surface3d' or 'colors3d'
        - custom_title (func or str) : (optional) function that gives a title from parameters
        - custom_legend (func) : (optional) function that gives a legend from parameters
        - custom_xlab (func or str) : function that gives x label from x parameter name, None for identity
        - custom_ylab (func or str) : function that gives y label from y parameter name, None for identity

    Notes:
        - when a batch holds several metrics, the last one selects the per-metric entry of
          plot_config.json and the y label; a 3D plot only draws that last metric
    """
    metrics = [] if metrics is None else metrics
    diff_plots = {} if diff_plots is None else diff_plots
    same_plot = {} if same_plot is None else same_plot
    same_line = {} if same_line is None else same_line
    if set_depending_params is None:
        set_depending_params = lambda params: None
    # handling constant functions (a plain string is used as is)
    if isinstance(custom_title, str):
        title_text = custom_title
        custom_title = lambda params, metrics_plot: title_text
    if isinstance(custom_xlab, str):
        xlab_text = custom_xlab
        custom_xlab = lambda name: xlab_text
    elif custom_xlab is None:
        custom_xlab = lambda name: name
    if isinstance(custom_ylab, str):
        ylab_text = custom_ylab
        custom_ylab = lambda name: ylab_text
    elif custom_ylab is None:
        custom_ylab = lambda name: name
    # inferring plot type
    plot_type = "lines2d" if len(same_line) < 2 else style3Dplot
    # searching for config options
    if os.path.exists("plot_config.json"):
        plot_params = load_json("plot_config.json").get(plot_type)
    else:
        print("plot_config.json not found, plotting without config")
        plot_params = {"_overwrite": {}}
    plot_tags = list(diff_plots.keys())
    legend_tags = list(same_plot.keys())
    # creating plot configs and plotting
    for metrics_plot in metrics:
        if isinstance(metrics_plot, str):  # if only one metric on the plot
            metrics_plot = [metrics_plot]
        if len(metrics_plot) == 0:
            continue
        metric = metrics_plot[-1]
        if plot_params is not None:
            plot_kwargs = update_params(plot_params.get(metric, {}), plot_params.get("_overwrite", {}))
        else:
            plot_kwargs = {}
        for params1 in make_grid(diff_plots):
            if custom_title is None:
                title = make_title(params1, plot_tags)
            else:
                title = custom_title(params1, metrics_plot)
            plot_conf = {
                "title": title,
                "filename": make_plot_name(params1, nametag, metrics_plot[0], plot_tags),
                "type": plot_type,
            }
            if plot_type == "lines2d":
                plot_conf["lines"] = _make_lines_conf(
                    metrics_plot, params1, params_common, same_plot, same_line, seeds,
                    set_depending_params, legend_tags, custom_legend)
            elif plot_type in ("colors3d", "surface3d"):
                plot_conf["rows"] = _make_rows_conf(
                    metric, params1, params_common, same_line, seeds, set_depending_params)
            _plot_from_conf(
                plot_conf, res_dir, plot_dir,
                nametag, exp_tags,
                same_line, metric, ignore_missing, plot_kwargs,
                custom_xlab, custom_ylab
            )
248
+ # lineconf = {"confs": conf_list, "seeds":[1, 2], "metric": "metric"}
@@ -0,0 +1,157 @@
1
+ ###
2
+ # @file reproduce.py
3
+ # @author Oscar Villemaud <oscar.villemaud@epfl.ch>
4
+ #
5
+ # @section LICENSE
6
+ #
7
+ # Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
8
+ # All rights reserved.
9
+ #
10
+ # @section DESCRIPTION
11
+ #
12
+ # reproplot main functions to run, plot and manage experiments.
13
+ ###
14
+
15
+ import os
16
+ import random
17
+ from tqdm import tqdm
18
+
19
+ from .utils import load_json, make_exp_name, check_compatibility
20
+ from .reprod_exps import run_experiments
21
+ from .reprod_plots import plot_experiments
22
+
23
+
24
def rename_exps(
        directory, new_tag, new_exp_tags, old_tag=None):
    """ rename experiment directories using the parameters saved in their params.json
    Args:
        - directory (str) : directory containing one sub-directory per experiment
        - new_tag (str) : new prefix for experiment names
        - new_exp_tags (str list) : names of the parameters to put in the new experiment names
        - old_tag (str) : specify to rename only experiments whose prefix (text before the
            first "-") is that tag
    """
    print("Renaming experiments of directory :", directory)
    with os.scandir(directory) as entries:
        exp_names_paths = [(entry.name, entry.path) for entry in entries if entry.is_dir()]
    counter = 0
    for exp_name, exp_path in exp_names_paths:
        if old_tag is not None and exp_name.split("-")[0] != old_tag:
            continue
        params = load_json(exp_path + "/params.json")
        exp_name_new = make_exp_name(params, new_tag, tags_list=new_exp_tags)
        new_path = directory + "/" + exp_name_new
        if exp_name_new != exp_name and os.path.exists(new_path):
            # two experiments map to the same new name: os.rename would raise (or replace an
            # empty target on POSIX), so leave this one untouched instead of stopping half-way
            print(f"WARNING : cannot rename {exp_name} into {exp_name_new}, target already exists. Skipping.")
            continue
        os.rename(exp_path, new_path)
        counter += 1
        print("renaming:", exp_name, "\n into: ", exp_name_new)
    print(f"Renamed {counter} experiments")
48
+
49
+
50
def index_exps(directory, nametag=None, param_requirements=None):
    """ gives a summary of available experiments

    Prints the number of completed runs and selected experiments, then for each parameter
    the sorted set of values found, then the sorted set of seeds.

    Args:
        - directory (str) : directory of experiments
        - nametag (str) : specify to see only experiments with that tag
        - param_requirements (bool func) : function taking params as input
            and outputting True if this experiment should be included
    """
    with os.scandir(directory) as entries:
        exp_names_paths = [(entry.name, entry.path) for entry in entries if entry.is_dir()]
    all_params = {}
    seeds = set()
    nb_runs, nb_exps = 0, 0
    for exp_name, exp_path in tqdm(exp_names_paths):
        loaded_params = load_json(exp_path + "/params.json")
        select = True
        if param_requirements is not None:
            select = param_requirements(loaded_params)
        if nametag is not None:
            select = select and exp_name.split("-")[0] == nametag
        if not select:
            continue
        nb_exps += 1
        for name, value in loaded_params.items():
            if isinstance(value, list):  # lists are not hashable
                value = tuple(value)
            all_params.setdefault(name, set()).add(value)
        with os.scandir(exp_path) as run_entries:
            run_names = [entry.name for entry in run_entries if entry.is_dir()]
        for seed_name in run_names:
            # completed runs are "seed_<int>"; failed ones are "seed_<int>_failed_<k>"
            if "failed" in seed_name or not seed_name.startswith("seed_"):
                continue
            try:
                seed = int(seed_name[len("seed_"):])
            except ValueError:
                continue
            seeds.add(seed)
            nb_runs += 1
    print(f"found {nb_runs} runs grouped in {nb_exps}/{len(exp_names_paths)} experiments with parameters :")

    def _sort_key(obj):
        # None first, numbers by value, sized values (str, tuple) by length;
        # the second element breaks ties deterministically and keeps keys comparable
        if obj is None:
            return (0, "")
        if isinstance(obj, (int, float)):
            return (obj, "")
        if hasattr(obj, "__len__"):
            return (len(obj), str(obj))
        return (0, str(obj))

    for name, values in all_params.items():
        print(name, sorted(values, key=_sort_key))
    print("seeds", sorted(seeds))
95
+
96
+
97
def run_and_plot(
        seeds=None, res_dir="results_RPP", plot_dir="plots_RPP",
        metrics=None, experiment_func=None,
        nametag="", params_common=None,
        diff_plots=None, same_plot=None, same_line=None,
        exp_tags=None, set_depending_params=lambda x: None,
        no_run=False, no_plot=False,
        ignore_missing=False, style3Dplot="colors3d", verb=None,
        custom_title=None, custom_legend=None, custom_xlab=lambda x: x, custom_ylab=lambda x: x,
        ):
    """ run and plot experiments by using all hyperparameters in a grid-search fashion
    Args:
        - seeds (int list) : seeds of runs to use, None for one random seed
        - res_dir (str) : directory containing the data to plot
        - plot_dir (str) : directory where to save the plot
        - metrics ((str | str tuple) list) : list of metrics to plot (batch to put on same plot),
            None for ["metric"]
        - experiment_func (func) : function running one experiment from keyword parameters
        - nametag (str) : prefix identifying the series of experiments
        - params_common (dict) : dictionary of {"param_name": value}
            that are the default parameters of experiment_func, None for {}
        - diff_plots (list dict) : dictionary of {"param_name": list of values} each value combination on a different plot
        - same_plot (list dict) : dictionary of {"param_name": list of values} each value combination on a different line of the same plot
        - same_line (list dict or empty dict) : {"param_name": list of values} values on x axis,
            empty dict will use training steps as x axis
        - exp_tags (str list) : list of parameters to put in experiment names,
            None for all parameters of diff_plots, same_plot and same_line (sorted)
        - set_depending_params (func) : function editing in-place a dictionary of parameters
        - no_run (bool) : True to disable running new experiments
        - no_plot (bool) : True to disable plotting
        - ignore_missing (bool) : True to plot despite missing data
        - style3Dplot (str) : 'surface3d' or 'colors3d'
        - verb (str list) : string codes to indicate verbose, None for ["pbar"]
        - custom_title (func or str) : (optional) function that gives a title from parameters
        - custom_legend (func) : (optional) function that gives a legend from parameters
        - custom_xlab (func or str) : (optional) function that gives x label from x parameter name
        - custom_ylab (func or str) : (optional) function that gives y label from y parameter name
    """
    # fresh containers on every call instead of shared mutable defaults
    metrics = ["metric"] if metrics is None else metrics
    params_common = {} if params_common is None else params_common
    diff_plots = {} if diff_plots is None else diff_plots
    same_plot = {} if same_plot is None else same_plot
    same_line = {} if same_line is None else same_line
    verb = ["pbar"] if verb is None else verb
    if exp_tags is None:
        exp_tags = sorted(set(diff_plots) | set(same_plot) | set(same_line))
    if seeds is None:
        seeds = [random.randint(1, 99999)]
        print(f"No seeds specified, using random seed {seeds[0]}")
    check_compatibility(diff_plots, same_plot, same_line, exp_tags)

    if not no_run:
        run_experiments(
            experiment_func=experiment_func, res_dir=res_dir, seeds=seeds, nametag=nametag,
            params_common=params_common, diff_plots=diff_plots, same_plot=same_plot, same_line=same_line,
            exp_tags=exp_tags, set_depending_params=set_depending_params, verb=verb,
        )

    if not no_plot:
        plot_experiments(
            res_dir=res_dir, plot_dir=plot_dir, seeds=seeds,
            nametag=nametag, metrics=metrics, same_line=same_line,
            params_common=params_common, diff_plots=diff_plots, same_plot=same_plot,
            exp_tags=exp_tags, set_depending_params=set_depending_params,
            ignore_missing=ignore_missing, style3Dplot=style3Dplot,
            custom_title=custom_title, custom_legend=custom_legend,
            custom_xlab=custom_xlab, custom_ylab=custom_ylab,
        )
@@ -0,0 +1,235 @@
1
+ ###
2
+ # @file utils.py
3
+ # @author Oscar Villemaud <oscar.villemaud@epfl.ch>
4
+ #
5
+ # @section LICENSE
6
+ #
7
+ # Copyright © 2024-2026 École Polytechnique Fédérale de Lausanne (EPFL).
8
+ # All rights reserved.
9
+ #
10
+ # @section DESCRIPTION
11
+ #
12
+ # Utility functions.
13
+ ###
14
+
15
+ import os
16
+ import json
17
+ import pickle
18
+ import random
19
+ from itertools import product
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+
25
def dump_json(object, path, indent=0):
    """Write a Python object to a JSON file.

    The file at ``path`` is created, or truncated if it already exists.

    Args:
        - object : JSON-serializable python object to write
        - path (str) : destination file path
        - indent (int) : indentation level forwarded to ``json.dump``
          (for readability; 0 puts every element on its own line)
    """
    with open(path, mode="w") as json_file:
        json.dump(obj=object, fp=json_file, indent=indent)
34
+
35
+
36
def load_json(path):
    """Load a Python object from a JSON file.

    Args:
        - path (str) : path of the JSON file to read

    Returns:
        - the deserialized python object
    """
    # JSON text is UTF-8 (RFC 8259): decode it explicitly instead of relying
    # on the platform's locale encoding, which is not UTF-8 everywhere.
    with open(path, "r", encoding="utf-8") as openfile:
        return json.load(openfile)
43
+
44
+
45
def dump_pickle(object, path):
    """Write a Python object to a pickle file (binary mode).

    Uses the default pickle protocol; the file at ``path`` is created,
    or truncated if it already exists.

    Args:
        - object : picklable python object to write
        - path (str) : destination file path
    """
    with open(path, mode="wb") as binary_file:
        pickler = pickle.Pickler(binary_file)
        pickler.dump(object)
53
+
54
+
55
def load_pickle(path):
    """Load a Python object from a pickle file.

    Warning: unpickling can execute arbitrary code. Only load files this
    project wrote itself (e.g. with ``dump_pickle``), never untrusted data.

    Args:
        - path (str) : path of the pickle file to read

    Returns:
        - the deserialized python object
    """
    with open(path, "rb") as openfile:
        return pickle.load(openfile)
62
+
63
+
64
def seedall(seed):
    """Seed NumPy's global RNG, PyTorch and Python's ``random`` at once.

    Args:
        - seed (int) : seed given to all three generators
    """
    # Same order as before: numpy, then torch, then the stdlib module.
    for seeder in (np.random.seed, torch.manual_seed, random.seed):
        seeder(seed)
72
+
73
+
74
def make_grid(list_dic):
    """Iterate over every combination of parameter values (grid search).

    Args:
        - list_dic (dict) : dictionary of {"param_name": list of values}

    Yields:
        - (dict) one {"param_name": value} mapping per combination, in
          ``itertools.product`` order (the last key varies fastest)
    """
    names = list_dic.keys()
    for combination in product(*list_dic.values()):
        yield dict(zip(names, combination))
84
+
85
+
86
def count_combinations(list_dic):
    """Count how many parameter combinations a grid contains.

    Args:
        - list_dic (dict) : dictionary of {"param_name": list of values}

    Returns:
        - (int) product of the lengths of all value lists (1 for an empty dict)
    """
    sizes = [len(values) for values in list_dic.values()]
    total = 1
    for size in sizes:
        total *= size
    return total
99
+
100
+
101
def update_params(params, *updates):
    """Return a copy of ``params`` with extra parameters added or replaced.

    Args:
        - params : dictionary of parameters (left unmodified)
        - *updates : dictionaries of additional parameters, applied in order

    Returns:
        - (dict) new dictionary of parameters
    """
    result = params.copy()  # shallow copy: the caller's dict is not mutated
    for override in updates:
        result.update(override)  # later dictionaries win on key clashes
    return result
113
+
114
+
115
def make_exp_name(params, nametag, tags_list):
    """Build an experiment name from a prefix and selected parameter values.

    The name has the form ``"{nametag}-{name1}_{value1}-{name2}_{value2}..."``.

    Args:
        - params (dict) : dictionary of {"param_name": value}
        - nametag (str) : prefix identifying the experiment series
        - tags_list (str list) : names of the parameters to include, in order

    Returns:
        - (str) experiment name

    Raises:
        - KeyError : if a name of ``tags_list`` is missing from ``params``
    """
    return nametag + "".join(f"-{name}_{params[name]}" for name in tags_list)
129
+
130
+
131
def make_plot_name(params, nametag, metric, tags_list):
    """Build a plot name from a prefix, a metric and selected parameter values.

    The name has the form ``"{nametag}-{metric}-{name1}_{value1}..."``.
    ``tags_list`` should contain the parameters that differ between plots
    (the ``diff_plots`` keys) so that each plot gets a distinct name.

    Args:
        - params (dict) : dictionary of {"param_name": value}
        - nametag (str) : prefix identifying the experiment series
        - metric (str) : metric plotted
        - tags_list (str list) : names of the parameters to include, in order

    Returns:
        - (str) plot name

    Raises:
        - KeyError : if a name of ``tags_list`` is missing from ``params``
    """
    suffix = "".join(f"-{name}_{params[name]}" for name in tags_list)
    return f"{nametag}-{metric}{suffix}"
146
+
147
+
148
def make_title(params, tags_list):
    """Build a plot title listing selected parameter values.

    Args:
        - params (dict) : dictionary of {"param_name": value}
        - tags_list (str list) : names of the parameters to show, in order

    Returns:
        - (str) ``"name1=value1, name2=value2, ..."``, or ``"X"`` when
          ``tags_list`` is empty

    Raises:
        - KeyError : if a name of ``tags_list`` is missing from ``params``
    """
    if len(tags_list) == 0:
        return "X"
    return ", ".join(f"{name}={params[name]}" for name in tags_list)
166
+
167
+
168
def make_legend(params, tags_list):
    """Build the legend of one plotted line from selected parameter values.

    Args:
        - params (dict) : dictionary of {"param_name": value}
        - tags_list (str list) : names of the parameters to show, in order

    Returns:
        - (str) ``"name1=value1, name2=value2, ..."``, or ``"X"`` when
          ``tags_list`` is empty

    Raises:
        - KeyError : if a name of ``tags_list`` is missing from ``params``
    """
    if len(tags_list) == 0:
        return "X"
    return ", ".join(f"{name}={params[name]}" for name in tags_list)
184
+
185
+
186
def check_compatibility(
    diff_plots, same_plot, same_line, exp_tags
):
    """Warn about varied parameters that are left out of experiment names.

    Every parameter that takes several values (the keys of ``diff_plots``,
    ``same_plot`` and ``same_line``) should appear in ``exp_tags``.
    Otherwise two different configurations can be given the same
    experiment name, and hence the same result directory. One warning is
    printed per missing parameter; nothing is raised.

    Args:
        - diff_plots (dict) : {"param_name": list of values}, each value combination on a different plot
        - same_plot (dict) : {"param_name": list of values}, each value combination on a different line of the same plot
        - same_line (dict) : {"param_name": list of values}, values on the x axis of the plot (only one parameter)
        - exp_tags (str list) : list of parameters to put in experiment names
    """
    for group in (diff_plots, same_plot, same_line):
        for param in group:
            if param not in exp_tags:
                print(f"WARNING : Experiment names should depend on -{param} to avoid having the same name")
199
+
200
+
201
def scan_runs(
        res_dir, dir_path=None,
        seeds=None, nametag="", params_common=None,
        diff_plots=None, same_plot=None, same_line=None,
        exp_tags=None, set_depending_params=lambda x: None):
    """Find existing and missing runs.

    Each combination of ``diff_plots`` x ``same_plot`` x ``same_line``
    values is merged on top of ``params_common`` and completed by
    ``set_depending_params``. It is then turned into an experiment name.
    A run is counted as done when the directory
    ``"{dir_path}/{res_dir}/{exp_name}/seed_{seed}/"`` exists.

    Args:
        - res_dir (str) : name of the directory where results are stored as json
        - dir_path (str) : path to the directory containing -res_dir
        - seeds (int list) : random seeds of the runs to look for (required)
        - nametag (str) : prefix that identifies the series of experiments
        - params_common (dict) : {"param_name": value} default parameters of
          the experiment function (None means no defaults)
        - diff_plots (dict) : {"param_name": list of values}, each value
          combination on a different plot (None means none)
        - same_plot (dict) : {"param_name": list of values}, each value
          combination on a different line of the same plot (None means none)
        - same_line (dict) : {"param_name": list of values}, values on the x
          axis of the plot (None means none)
        - exp_tags (str list) : parameters to put in experiment names; when
          None, the sorted keys of diff_plots, same_plot and same_line
        - set_depending_params (func) : function editing in-place a dictionary
          of parameters (None means no editing)

    Returns:
        - (int) number of runs whose directory already exists
        - (int) total number of runs scanned
        - (list of (dict, str)) a (parameters, directory path) pair for each
          missing run

    Raises:
        - TypeError : if ``seeds`` is None
    """
    if seeds is None:
        raise TypeError("scan_runs requires an explicit list of seeds")
    params_common = {} if params_common is None else params_common
    diff_plots = {} if diff_plots is None else diff_plots
    same_plot = {} if same_plot is None else same_plot
    same_line = {} if same_line is None else same_line
    if exp_tags is None:
        # Mirrors the default used by the top-level experiment function.
        exp_tags = sorted(set(diff_plots) | set(same_plot) | set(same_line))
    # NOTE(review): with dir_path=None the paths below start with "None/".
    # Kept unchanged because result paths are presumably built the same way
    # elsewhere in the package — confirm before changing.

    confs_paths = []
    count_done, count_total = 0, 0
    # product keeps the original nesting order (same_line varies fastest).
    for params1, params2, params3 in product(
            make_grid(diff_plots), make_grid(same_plot), make_grid(same_line)):
        # update_params returns a fresh dict, so the grid dicts are not mutated.
        params = update_params(params_common, params1, params2, params3)
        if set_depending_params is not None:
            set_depending_params(params)
        exp_name = make_exp_name(params, nametag, exp_tags)
        exp_dir = f"{dir_path}/{res_dir}/{exp_name}/"
        for seed in seeds:
            path = exp_dir + f"seed_{seed}/"
            if os.path.isdir(path):
                count_done += 1
            else:
                # Copy so that each missing run owns its parameter dict.
                confs_paths.append((params.copy(), path))
            count_total += 1
    return count_done, count_total, confs_paths