halib 0.1.7__py3-none-any.whl → 0.1.99__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +84 -0
- halib/common.py +151 -0
- halib/cuda.py +39 -0
- halib/dataset.py +209 -0
- halib/filetype/csvfile.py +151 -45
- halib/filetype/ipynb.py +63 -0
- halib/filetype/jsonfile.py +1 -1
- halib/filetype/textfile.py +4 -4
- halib/filetype/videofile.py +44 -33
- halib/filetype/yamlfile.py +95 -0
- halib/gdrive.py +1 -1
- halib/online/gdrive.py +104 -54
- halib/online/gdrive_mkdir.py +29 -17
- halib/online/gdrive_test.py +31 -18
- halib/online/projectmake.py +58 -43
- halib/plot.py +296 -11
- halib/projectmake.py +1 -1
- halib/research/__init__.py +0 -0
- halib/research/base_config.py +100 -0
- halib/research/base_exp.py +100 -0
- halib/research/benchquery.py +131 -0
- halib/research/dataset.py +208 -0
- halib/research/flop_csv.py +34 -0
- halib/research/flops.py +156 -0
- halib/research/metrics.py +133 -0
- halib/research/mics.py +68 -0
- halib/research/params_gen.py +108 -0
- halib/research/perfcalc.py +336 -0
- halib/research/perftb.py +780 -0
- halib/research/plot.py +758 -0
- halib/research/profiler.py +300 -0
- halib/research/torchloader.py +162 -0
- halib/research/wandb_op.py +116 -0
- halib/rich_color.py +285 -0
- halib/sys/filesys.py +17 -10
- halib/system/__init__.py +0 -0
- halib/system/cmd.py +8 -0
- halib/system/filesys.py +124 -0
- halib/tele_noti.py +166 -0
- halib/torchloader.py +162 -0
- halib/utils/__init__.py +0 -0
- halib/utils/dataclass_util.py +40 -0
- halib/utils/dict_op.py +9 -0
- halib/utils/gpu_mon.py +58 -0
- halib/utils/listop.py +13 -0
- halib/utils/tele_noti.py +166 -0
- halib/utils/video.py +82 -0
- halib/videofile.py +1 -1
- halib-0.1.99.dist-info/METADATA +209 -0
- halib-0.1.99.dist-info/RECORD +64 -0
- {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/WHEEL +1 -1
- halib-0.1.7.dist-info/METADATA +0 -59
- halib-0.1.7.dist-info/RECORD +0 -30
- {halib-0.1.7.dist-info → halib-0.1.99.dist-info/licenses}/LICENSE.txt +0 -0
- {halib-0.1.7.dist-info → halib-0.1.99.dist-info}/top_level.txt +0 -0
halib/plot.py
CHANGED
|
@@ -1,16 +1,301 @@
|
|
|
1
|
+
from .common import now_str, norm_str, ConsoleLog
|
|
2
|
+
from .filetype import csvfile
|
|
3
|
+
from .system import filesys as fs
|
|
4
|
+
from functools import partial
|
|
5
|
+
from rich.console import Console
|
|
6
|
+
from rich.pretty import pprint
|
|
7
|
+
import click
|
|
8
|
+
import csv
|
|
9
|
+
import matplotlib
|
|
1
10
|
import matplotlib.pyplot as plt
|
|
11
|
+
import numpy as np
|
|
12
|
+
import os
|
|
13
|
+
import pandas as pd
|
|
2
14
|
import seaborn as sns
|
|
3
|
-
import matplotlib
|
|
4
15
|
|
|
5
16
|
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
17
|
+
# Column names a training-log CSV must provide (checked by verify_csv, which
# lower-cases the CSV headers first); the loaded frame keeps exactly these
# columns, in this order. Must stay a list: pandas treats a tuple as one key.
REQUIRED_COLUMNS = ["epoch", "train_loss", "val_loss", "train_acc", "val_acc"]

# Default output directory of the command-line tool.
desktop_path = os.path.expanduser("~/Desktop")

# Module-wide rich console.
console = Console()
|
|
20
|
+
|
|
21
|
+
import csv
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_delimiter(file_path, bytes=4096):
    """Detect the field delimiter of a CSV file.

    Reads at most ``bytes`` characters from the start of ``file_path`` and lets
    ``csv.Sniffer`` guess the dialect.

    Args:
        file_path: path of the CSV file to inspect.
        bytes: number of characters to sample. The name shadows the builtin
            but is kept for backward compatibility with keyword callers.

    Returns:
        The single-character delimiter detected by ``csv.Sniffer``.

    Raises:
        csv.Error: if the sniffer cannot determine a delimiter.
    """
    sniffer = csv.Sniffer()
    # Close the handle deterministically instead of leaking it until GC.
    with open(file_path, "r") as f:
        sample = f.read(bytes)
    return sniffer.sniff(sample).delimiter
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def verify_csv(csv_file, required_columns=REQUIRED_COLUMNS):
    """Load ``csv_file`` and keep only the required training-log columns.

    The delimiter is auto-detected with ``get_delimiter``. Column headers are
    stripped and lower-cased before checking, and so are the required names,
    so "Train_Loss " in the file matches "train_loss".

    Args:
        csv_file: path of the CSV file.
        required_columns: column names that must be present. The returned frame
            contains exactly these columns (lower-cased), in this order.

    Returns:
        A new DataFrame restricted to ``required_columns``.

    Raises:
        ValueError: if any required column is missing.
    """
    delimiter = get_delimiter(csv_file)
    df = pd.read_csv(csv_file, sep=delimiter)
    # Normalize headers: surrounding whitespace is common with "; " separators.
    df.columns = [str(col).strip().lower() for col in df.columns]
    wanted = [col.lower() for col in required_columns]
    missing = [col for col in wanted if col not in df.columns]
    if missing:
        # Report the columns actually requested, not the module default.
        raise ValueError(
            f"Required columns are: {list(required_columns)}, missing {missing}, "
            f"but found {list(df.columns)}"
        )
    return df[wanted].copy()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_valid_tags(csv_files, tags):
    """Return one label ("tag") per CSV file.

    If ``tags`` is a non-empty sequence of non-empty strings, that same object is
    returned unchanged. Otherwise a tag is derived from each file's base name
    (extension removed), normalized with ``norm_str``.

    Raises:
        AssertionError: if the provided ``tags`` are not all non-empty strings.
    """
    if tags is None or len(tags) == 0:
        # No usable tags given: build them from the file names.
        return [
            norm_str(fs.get_file_name(path, split_file_ext=True)[0])
            for path in csv_files
        ]

    assert all(isinstance(t, str) for t in tags), "tags must be a list of strings"
    assert all(len(t) > 0 for t in tags), "tags must be a list of non-empty strings"
    return tags
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def plot_ax(df, ax, metric="loss", tag=""):
    """Plot the train and validation curves of one metric on ``ax``.

    Args:
        df: DataFrame with an ``epoch`` column plus ``train_<metric>`` and
            ``val_<metric>`` columns (as returned by ``verify_csv``).
        ax: matplotlib axes (anything with a ``plot`` method) to draw on.
        metric: either ``"loss"`` or ``"acc"``.
        tag: prefix of the legend labels, e.g. the run name.

    Returns:
        ``ax``, to allow chaining.
    """
    # The former ``pprint(locals())`` debug dump (which printed the whole
    # DataFrame on every call) has been removed.
    assert metric in ["loss", "acc"], "metric must be either 'loss' or 'acc'"
    for part in ("train", "val"):
        ax.plot(
            df["epoch"], df[f"{part}_{metric}"], label=f"{tag}_{part}_{metric}"
        )
    return ax
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def actual_plot_seaborn(frame, csv_files, axes, tags, log):
    """Draw the loss and accuracy curves of several training-log CSV files.

    Usable both for a one-shot plot and as the frame callback of
    ``matplotlib.animation.FuncAnimation`` (hence the unused ``frame``
    argument). Every call clears ``axes`` and re-reads all CSV files.

    Args:
        frame: animation frame index (ignored).
        csv_files: non-empty list of CSV paths (see ``verify_csv``).
        axes: sequence of at least two axes. ``axes[0]`` receives the loss
            curves and ``axes[1]`` the accuracy curves.
        tags: optional labels, one per CSV file (see ``get_valid_tags``).
        log: if true, display each loaded DataFrame on the console.

    Raises:
        ValueError: if ``csv_files`` is empty.
    """
    # Clear previous drawings: in live mode this runs once per animation frame.
    for ax in axes:
        ax.clear()
    if not csv_files:
        raise ValueError("No csv files to plot")
    valid_tags = get_valid_tags(csv_files, tags)

    # Load every file and prefix its metric columns with the file's tag.
    # col_info maps each prefixed column to (tag, original column) so later
    # lookups are exact. Substring checks would let tag "a" match columns of
    # tag "ab", or metric "acc" match a tag such as "accnet".
    col_info = {}
    ls_df = []
    for csv_file, tag in zip(csv_files, valid_tags):
        df = verify_csv(csv_file)
        if log:
            with ConsoleLog(f"plotting {csv_file}"):
                csvfile.fn_display_df(df)
        renamed = {col: f"{tag}_{col}" for col in df.columns if col != "epoch"}
        for col, new_col in renamed.items():
            col_info[new_col] = (tag, col)
        ls_df.append(df.rename(columns=renamed))

    # Outer-join on epoch so runs of different lengths share one x-axis. Index
    # by epoch because wide-form seaborn plots use the index as the x values.
    df_combined = ls_df[0]
    for df_item in ls_df[1:]:
        df_combined = pd.merge(df_combined, df_item, on="epoch", how="outer")
    df_combined = df_combined.set_index("epoch").sort_index()

    # Lines from the same csv file (same tag) share one marker. Entries that
    # draw nothing (falsy, blank, "None") are skipped. Markers and palettes are
    # cycled so any number of files can be plotted.
    all_markers = [
        marker
        for marker in plt.Line2D.markers
        if marker and marker != " " and marker != "None"
    ]
    sequential_palettes = [
        "Reds",
        "Greens",
        "Blues",
        "Oranges",
        "Purples",
        "Greys",
        "BuGn",
        "BuPu",
        "GnBu",
        "OrRd",
        "PuBu",
        "PuRd",
        "RdPu",
        "YlGn",
        "PuBuGn",
        "YlGnBu",
        "YlOrBr",
        "YlOrRd",
    ]
    parts = ("train", "val")
    total_colors = 10
    # One shade per part: the middle of each equal slice of the palette
    # (indices 2 and 7 of 10), a lighter shade for train and a darker one for val.
    shade_idx = [
        int(chunk.mean())
        for chunk in np.array_split(np.arange(total_colors), len(parts))
    ]
    tag2marker = {}
    tag2colors = {}
    for k, tag in enumerate(valid_tags):
        tag2marker[tag] = all_markers[k % len(all_markers)]
        palette = sequential_palettes[k % len(sequential_palettes)]
        colors = sns.color_palette(palette, total_colors).as_hex()
        tag2colors[tag] = [colors[s] for s in shade_idx]

    tags_str = "+".join(valid_tags)
    for i, metric in enumerate(["loss", "acc"]):
        cols = sorted(
            col for col, (_, base) in col_info.items() if base.endswith(f"_{metric}")
        )
        # Build markers and colors in the same order as ``cols``: seaborn maps
        # list palettes/markers onto the wide-form columns positionally.
        plot_markers = []
        plot_colors = []
        for col in cols:
            tag, base = col_info[col]
            part_idx = next(
                (k for k, part in enumerate(parts) if base.startswith(f"{part}_")),
                0,
            )
            plot_markers.append(tag2marker[tag])
            plot_colors.append(tag2colors[tag][part_idx])

        sns.lineplot(
            data=df_combined[cols],
            markers=plot_markers,
            palette=plot_colors,
            ax=axes[i],
            dashes=False,
        )
        axes[i].set(
            xlabel="epoch", ylabel=metric, title=f"{tags_str}_{metric}-by-epoch"
        )
        axes[i].legend()
        # grid(True), not grid(): a bare call toggles the grid.
        axes[i].grid(True)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def actual_plot(frame, csv_files, axes, tags, log):
    """Plain-matplotlib variant of ``actual_plot_seaborn``.

    Has the same signature, so it can also serve as a ``FuncAnimation`` frame
    callback (``frame`` is ignored). Each file's train/val curves are drawn
    with ``plot_ax``: loss on ``axes[0]`` and accuracy on ``axes[1]``.

    Args:
        frame: animation frame index (ignored).
        csv_files: non-empty list of CSV paths (see ``verify_csv``).
        axes: sequence of at least two axes.
        tags: optional labels, one per CSV file (see ``get_valid_tags``).
        log: if true, display each loaded DataFrame on the console.

    Raises:
        ValueError: if ``csv_files`` is empty. Previously this surfaced as an
            ``UnboundLocalError`` on ``metric_ax``.
    """
    # Start from clean axes so repeated calls (live mode) do not stack lines.
    for ax in axes:
        ax.clear()
    if not csv_files:
        raise ValueError("No csv files to plot")
    valid_tags = get_valid_tags(csv_files, tags)

    ls_df = []
    for csv_file in csv_files:
        df = verify_csv(csv_file)
        if log:
            with ConsoleLog(f"plotting {csv_file}"):
                csvfile.fn_display_df(df)
        ls_df.append(df)

    tags_str = "+".join(valid_tags)
    for i, metric in enumerate(["loss", "acc"]):
        metric_ax = axes[i]
        for df_item, tag in zip(ls_df, valid_tags):
            plot_ax(df_item, metric_ax, metric, tag)
        metric_ax.set(
            xlabel="epoch", ylabel=metric, title=f"{tags_str}_{metric}-by-epoch"
        )
        metric_ax.legend()
        # grid(True), not grid(): a bare call toggles the grid.
        metric_ax.grid(True)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def plot_csv_files(
    csv_files,
    outdir="./out/plot",
    tags=None,
    log=False,
    save_fig=False,
    update_in_min=1,
):
    """Plot loss/accuracy curves from one or more training-log CSV files.

    Args:
        csv_files: a path or a list of paths to CSV files (see ``verify_csv``
            for the expected columns).
        outdir: directory where figures are written when ``save_fig`` is true.
        tags: optional label or list of labels, one per CSV file. Derived from
            the file names when omitted.
        log: display each loaded DataFrame on the console.
        save_fig: after the plot window is closed, save the figure as PNG and
            PDF into ``outdir``.
        update_in_min: refresh period in minutes. A positive value re-reads the
            CSV files periodically (live monitoring); otherwise plot once.

    Returns:
        The ``FuncAnimation`` in live-update mode (keep a reference so it is not
        garbage-collected), otherwise ``None``.
    """
    # Accept single values for convenience.
    if isinstance(csv_files, str):
        csv_files = [csv_files]
    if isinstance(tags, str):
        tags = [tags]
    valid_tags = get_valid_tags(csv_files, tags)
    assert len(valid_tags) == len(
        csv_files
    ), "Unable to determine tags for each csv file"

    live_update_in_ms = int(update_in_min * 60 * 1000)
    fig, axes = plt.subplots(2, 1, figsize=(10, 17))
    anim = None
    # Only a strictly positive period enables live updates.
    if live_update_in_ms > 0:
        from matplotlib.animation import FuncAnimation

        anim = FuncAnimation(
            fig,
            partial(
                actual_plot_seaborn, csv_files=csv_files, axes=axes, tags=tags, log=log
            ),
            interval=live_update_in_ms,
            blit=False,
            cache_frame_data=False,
        )
    else:
        actual_plot_seaborn(None, csv_files, axes, tags, log)
    plt.show()

    if save_fig:
        os.makedirs(outdir, exist_ok=True)
        tags_str = "+".join(valid_tags)
        stem = f"{now_str()}_{tags_str}"
        fig.savefig(os.path.join(outdir, f"{stem}_plot.png"))
        # savefig picks the PDF writer from the extension. The backend is not
        # switched globally (the former enable_plot_pgf() call did that): the
        # switch closes open figures, leaves later plt.show() calls
        # non-interactive, and makes saving depend on a LaTeX installation.
        fig.savefig(os.path.join(outdir, f"{stem}_plot.pdf"))
    return anim
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def enable_plot_pgf():
    """Configure matplotlib globally for LaTeX-typeset output.

    Selects the non-interactive ``"pdf"`` backend and turns on LaTeX text
    rendering with serif fonts. Both settings persist for the rest of the
    process.

    NOTE(review): despite the name, the backend chosen here is "pdf", not
    "pgf". ``savefig`` still infers the output format from the file extension.
    ``text.usetex`` requires a working LaTeX installation.
    """
    latex_settings = {
        "pgf.texsystem": "pdflatex",
        "font.family": "serif",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
    matplotlib.use("pdf")
    matplotlib.rcParams.update(latex_settings)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def save_fig_latex_pgf(filename, directory="."):
    """Save the current pyplot figure as a ``.pgf`` file for LaTeX documents.

    Args:
        filename: output file name. ``.pgf`` is appended unless it already ends
            with it.
        directory: directory to write into. It is always honoured (previously it
            was ignored when ``filename`` contained ".pgf"); an absolute
            ``filename`` still takes precedence, as with ``os.path.join``.

    Returns:
        The path of the written file.
    """
    enable_plot_pgf()
    # Suffix test instead of a substring test: "a.pgfx" is not a .pgf name.
    if not filename.endswith(".pgf"):
        filename = f"{filename}.pgf"
    path = os.path.join(directory, filename)
    plt.savefig(path)
    return path
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
# Command-line interface; option reference:
# https://click.palletsprojects.com/en/8.1.x/api/
@click.command()
@click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
@click.option(
    "--outdir",
    "-o",
    type=str,
    help="output directory for the plot",
    default=str(desktop_path),
)
@click.option(
    "--tags", "-t", multiple=True, type=str, help="tags for the csv files", default=[]
)
@click.option("--log", "-l", is_flag=True, help="log the csv files")
@click.option("--save_fig", "-s", is_flag=True, help="save the plot as a file")
@click.option(
    "--update_in_min",
    "-u",
    type=float,
    help="update the plot every x minutes",
    default=0.0,
)
def main(csvfiles, outdir, tags, log, save_fig, update_in_min):
    # click delivers ``multiple=True`` options as tuples; plot_csv_files
    # expects lists.
    plot_csv_files(
        csv_files=list(csvfiles),
        outdir=outdir,
        tags=list(tags),
        log=log,
        save_fig=save_fig,
        update_in_min=update_in_min,
    )


if __name__ == "__main__":
    main()
|
halib/projectmake.py
CHANGED
|
File without changes
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from rich.pretty import pprint
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclass_wizard import YAMLWizard
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class NamedConfig(ABC):
    """Abstract base for configuration sections that are identified by a name.

    Concrete subclasses must implement :meth:`get_name`.
    """

    @abstractmethod
    def get_name(self):
        """Return the name of this configuration (to be implemented by subclasses)."""
        ...
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ExpBaseConfig(ABC, YAMLWizard):
    """
    Base class for experiment configuration objects.

    What a cfg class must have:
        1 - a dataset cfg
        2 - a metric cfg
        3 - a method cfg
    plus general settings and an output directory (see the abstract methods
    below). YAML serialization comes from ``YAMLWizard``; ``to_yaml_file`` is
    used by :meth:`save_to_outdir`.
    """

    # Save to a YAML file
    def save_to_outdir(
        self, filename: str = "__config.yaml", outdir=None, override: bool = False
    ) -> None:
        """
        Save the configuration as YAML into the output directory.

        Args:
            filename: name of the YAML file to write inside the directory.
            outdir: target directory; defaults to :meth:`get_outdir`. It is
                created if needed.
            override: overwrite an existing file; otherwise a notice is printed
                and nothing is written.

        Raises:
            ValueError: if neither ``outdir`` nor :meth:`get_outdir` provides a
                directory.
        """
        output_dir = outdir if outdir is not None else self.get_outdir()
        # Check before os.makedirs, which would otherwise fail first with an
        # unhelpful TypeError on None (the old post-makedirs assert was dead code).
        if output_dir is None:
            raise ValueError(
                "No output directory: pass 'outdir' or implement get_outdir()."
            )
        # With exist_ok=True, makedirs still raises if the path exists but is
        # not a directory, so no separate isdir check is needed.
        os.makedirs(output_dir, exist_ok=True)
        file_path = os.path.join(output_dir, filename)
        if os.path.exists(file_path) and not override:
            pprint(
                f"File '{file_path}' already exists. Use 'override=True' to overwrite."
            )
            return
        # method of YAMLWizard
        self.to_yaml_file(file_path)

    # Load from a custom YAML file
    @classmethod
    @abstractmethod
    def from_custom_yaml_file(cls, yaml_file: str):
        """Load a configuration from a custom YAML file."""
        pass

    @abstractmethod
    def get_cfg_name(self):
        """
        Get the name of the configuration.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def get_outdir(self):
        """
        Get the output directory for the configuration.
        This method should be implemented in subclasses.
        """
        return None

    @abstractmethod
    def get_general_cfg(self):
        """
        Get the general configuration like output directory, log settings, SEED, etc.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def get_dataset_cfg(self) -> NamedConfig:
        """
        Get the dataset configuration.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def get_metric_cfg(self) -> NamedConfig:
        """
        Get the metric configuration.
        This method should be implemented in subclasses.
        """
        pass
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from ..research.base_config import ExpBaseConfig
|
|
4
|
+
from ..research.perfcalc import PerfCalc
|
|
5
|
+
from ..research.metrics import MetricsBackend
|
|
6
|
+
|
|
7
|
+
# ! SEE https://github.com/hahv/base_exp for sample usage
class BaseExperiment(PerfCalc, ABC):
    """
    Base class for experiments.
    Orchestrates the experiment pipeline using a pluggable metrics backend.
    """

    def __init__(self, config: ExpBaseConfig):
        self.config = config
        # Created by run_exp() or lazily by get_metric_backend().
        self.metric_backend = None

    # -----------------------
    # PerfCalc Required Methods
    # -----------------------
    def get_dataset_name(self):
        """Return the dataset name taken from the dataset config."""
        return self.config.get_dataset_cfg().get_name()

    def get_experiment_name(self):
        """Return the experiment name taken from the config."""
        return self.config.get_cfg_name()

    def get_metric_backend(self):
        """Return the metrics backend, preparing it on first use."""
        # `is None` rather than truthiness: a valid backend could be falsy
        # (e.g. one defining __len__ that returns 0).
        if self.metric_backend is None:
            self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
        return self.metric_backend

    # -----------------------
    # Abstract Experiment Steps
    # -----------------------
    @abstractmethod
    def init_general(self, general_cfg):
        """Setup general settings like SEED, logging, env variables."""
        pass

    @abstractmethod
    def prepare_dataset(self, dataset_cfg):
        """Load/prepare dataset."""
        pass

    @abstractmethod
    def prepare_metrics(self, metric_cfg) -> MetricsBackend:
        """
        Prepare the metrics for the experiment.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def exec_exp(self, *args, **kwargs):
        """Run experiment process, e.g.: training/evaluation loop.
        Return: raw_metrics_data, and extra_data as input for calc_and_save_exp_perfs
        """
        pass

    def eval_exp(self):
        """Optional: re-run evaluation from saved results."""
        pass

    # -----------------------
    # Main Experiment Runner
    # -----------------------
    def run_exp(self, do_calc_metrics=True, *args, **kwargs):
        """
        Run the whole experiment pipeline.
        Params:
            + 'outfile' to save csv file results,
            + 'outdir' to set output directory for experiment results.
            + 'return_df' to return a DataFrame of results instead of a dictionary.

        Full pipeline:
            1. Init
            2. Dataset
            3. Metrics Preparation
            4. Save Config
            5. Execute
            6. Calculate & Save Metrics

        Returns:
            The result of ``calc_and_save_exp_perfs`` when ``do_calc_metrics``
            is true, otherwise whatever ``exec_exp`` returned.
        """
        self.init_general(self.config.get_general_cfg())
        self.prepare_dataset(self.config.get_dataset_cfg())
        # Keep the prepared backend; previously it was discarded, so a later
        # get_metric_backend() call prepared a second one.
        self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())

        # Save config before running
        self.config.save_to_outdir()

        # Execute experiment
        results = self.exec_exp(*args, **kwargs)
        if not do_calc_metrics:
            return results

        # exec_exp must return a (raw_metrics_data, extra_data) pair here.
        metrics_data, extra_data = results
        # NOTE(review): *args/**kwargs are forwarded to both exec_exp and
        # calc_and_save_exp_perfs. Extra positional args bind before the
        # keywords below and may clash with them depending on PerfCalc's
        # signature; confirm against PerfCalc.
        return self.calc_and_save_exp_perfs(
            raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
        )
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from rich.pretty import pprint
|
|
3
|
+
from argparse import ArgumentParser
|
|
4
|
+
|
|
5
|
+
def cols_to_col_groups(df):
    """Split benchmark-table columns into common columns and per-dataset columns.

    Consecutive columns form a group when a column is "Unnamed" (how pandas
    names the empty header cells of a merged header) or shares its base name
    (text before the first "_", case-insensitive) with the previous column.

    - A group without "Unnamed" columns is kept as common columns, with their
      original names.
    - A group with "Unnamed" columns is a dataset (db) block. Its columns are
      renamed ``<first named column>_<k>`` for k = 0..len(group)-1.
    - A group made only of "Unnamed" columns (e.g. a leading index column) has
      no name to borrow. It is kept as common columns instead of producing
      "None_<k>", which also misaligned the renaming done by callers.

    Args:
        df: DataFrame whose column labels are strings.

    Returns:
        dict with keys ``"common_cols"`` and ``"db_cols"`` (lists of names).
    """
    columns = list(df.columns)

    def is_unnamed(col):
        return "unnamed" in col.lower()

    def base_name(col):
        # "acc_1" -> "acc"; a column without "_" is its own base name.
        return col.split("_")[0].lower()

    col_groups = []
    current_group = []
    for i, col in enumerate(columns):
        starts_new_group = current_group and not (
            is_unnamed(col) or base_name(col) == base_name(columns[i - 1])
        )
        if starts_new_group:
            col_groups.append(current_group)
            current_group = []
        current_group.append(col)
    if current_group:
        col_groups.append(current_group)

    meta_dict = {"common_cols": [], "db_cols": []}
    for group in col_groups:
        # The first *named* column labels the db block.
        named_col = next((col for col in group if not is_unnamed(col)), None)
        has_unnamed = any(is_unnamed(col) for col in group)
        if not has_unnamed or named_col is None:
            meta_dict["common_cols"].extend(group)
        else:
            meta_dict["db_cols"].extend(
                f"{named_col}_{k}" for k in range(len(group))
            )
    return meta_dict
|
|
53
|
+
|
|
54
|
+
# def bech_by_db_name(df, db_list="db1, db2", key_metrics="p, r, f1, acc"):
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def str_2_list(input_str, sep=","):
    """Split ``input_str`` on ``sep`` into stripped, non-empty items.

    A blank string gives ``[]``. A string without ``sep`` gives a single
    stripped item.
    """
    if not input_str.strip():
        return []
    if sep not in input_str:
        return [input_str.strip()]
    pieces = (piece.strip() for piece in input_str.split(sep))
    return [piece for piece in pieces if piece]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def filter_bech_df_by_db_and_metrics(df, db_list="", key_metrics=""):
    """Select benchmark-table columns by dataset (db) name and metric name.

    The table is expected to have a two-row header. The CSV header holds each
    dataset name above its block of columns (pandas reports the rest of the
    block as "Unnamed: ..."). The first data row holds the metric name of every
    db column (e.g. "acc", "f1"). Grouping is done by ``cols_to_col_groups``.

    Args:
        df: benchmark DataFrame as loaded by ``pd.read_csv``.
        db_list: comma-separated dataset names. A db column is kept when any
            name occurs (case-insensitively) in its renamed column name. Empty
            keeps all db columns.
        key_metrics: comma-separated metric names, matched case-insensitively
            against the metric row. Empty keeps all selected db columns.

    Returns:
        A new DataFrame: the common columns, then the selected db columns
        (renamed ``<db>_<k>``), ordered by ``db_list``.
    """
    meta_cols_dict = cols_to_col_groups(df)
    common_cols = list(meta_cols_dict["common_cols"])
    db_cols = list(meta_cols_dict["db_cols"])

    # Common columns keep their original names, so the remaining columns (in
    # their original order) are exactly the db columns. Reorder before renaming
    # so names stay aligned even when a common column follows a db block.
    common_set = set(common_cols)
    db_source_cols = [col for col in df.columns if col not in common_set]
    op_df = df[common_cols + db_source_cols].copy()
    op_df.columns = common_cols + db_cols

    # Use the parsed list: the raw split kept empty entries ("a," -> ""),
    # and an empty name matched every column.
    selected_db_list = [name.lower() for name in str_2_list(db_list)]
    if selected_db_list:
        db_selected_cols = []
        for db_name in selected_db_list:
            for col_name in db_cols:
                # Skip duplicates when several db names match the same column.
                if db_name in col_name.lower() and col_name not in db_selected_cols:
                    db_selected_cols.append(col_name)
    else:
        db_selected_cols = db_cols

    df_filtered = op_df[common_cols + db_selected_cols].copy()

    selected_metrics = [metric.lower() for metric in str_2_list(key_metrics)]
    if selected_metrics:
        # The first data row (second line of the CSV) holds the metric names.
        metrics_row = df_filtered.iloc[0]
        keep_metrics_cols = [
            col_name
            for col_name in db_selected_cols
            # str(): metric cells may be NaN or numbers.
            if str(metrics_row[col_name]).strip().lower() in selected_metrics
        ]
    else:
        pprint("No metrics selected, keeping all db columns")
        keep_metrics_cols = db_selected_cols

    return df_filtered[common_cols + keep_metrics_cols].copy()
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def parse_args(argv=None):
    """Parse command-line options of the benchmark-query script.

    Args:
        argv: list of arguments to parse; ``None`` (default) means
            ``sys.argv[1:]``, as with ``ArgumentParser.parse_args``.

    Returns:
        ``argparse.Namespace`` with a ``csv`` attribute.
    """
    parser = ArgumentParser(
        description="Filter a benchmark CSV table by dataset and metric."
    )
    # NOTE: the default is a developer-specific path; pass --csv explicitly.
    parser.add_argument(
        "-csv",
        "--csv",
        type=str,
        help="CSV file path",
        default=r"E:\Dev\__halib\test\bench.csv",
    )
    return parser.parse_args(argv)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def main():
    # Demo: keep only the "acc" columns of the "bowfire" dataset from a
    # semicolon-separated benchmark table.
    csv_path = parse_args().csv
    bench_df = pd.read_csv(csv_path, sep=";", encoding="utf-8")
    print(filter_bech_df_by_db_and_metrics(bench_df, db_list="bowfire", key_metrics="acc"))


if __name__ == "__main__":
    main()
|