halib 0.1.47__py3-none-any.whl → 0.1.49__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/research/__init__.py +0 -0
- halib/research/benchquery.py +131 -0
- halib/research/dataset.py +209 -0
- halib/research/plot.py +301 -0
- halib/research/torchloader.py +162 -0
- halib/research/wandb_op.py +116 -0
- halib/utils/__init__.py +0 -0
- halib/utils/listop.py +13 -0
- halib/utils/tele_noti.py +166 -0
- {halib-0.1.47.dist-info → halib-0.1.49.dist-info}/METADATA +156 -154
- {halib-0.1.47.dist-info → halib-0.1.49.dist-info}/RECORD +14 -5
- {halib-0.1.47.dist-info → halib-0.1.49.dist-info}/WHEEL +1 -1
- {halib-0.1.47.dist-info → halib-0.1.49.dist-info}/LICENSE.txt +0 -0
- {halib-0.1.47.dist-info → halib-0.1.49.dist-info}/top_level.txt +0 -0
File without changes
|
@@ -0,0 +1,131 @@
|
|
1
|
+
import pandas as pd
|
2
|
+
from rich.pretty import pprint
|
3
|
+
from argparse import ArgumentParser
|
4
|
+
|
5
|
+
def cols_to_col_groups(df):
    """Split a benchmark table's columns into shared columns and per-dataset columns.

    Consecutive columns are grouped together when the later one is a
    placeholder header (its name contains "unnamed", as pandas produces for
    empty header cells) or when both share the same base name (the text
    before the first "_", compared case-insensitively).

    Groups with no "unnamed" column are returned unchanged in ``common_cols``.
    Every other group is renamed to ``{name}_{i}`` for ``i`` in
    ``0..len(group)-1`` and collected in ``db_cols``. Here ``name`` is the
    group's first non-"unnamed" column, or the group's first column when all
    of its columns are "unnamed".

    Args:
        df: DataFrame whose ``columns`` are inspected; it is not modified.

    Returns:
        dict with keys ``"common_cols"`` and ``"db_cols"``, each a list of
        column names in their original left-to-right order.
    """
    # Compare on strings so non-str headers (e.g. ints) do not crash .lower().
    columns = [str(col) for col in df.columns]

    def base_name(col):
        # Text before the first "_" (the whole name if there is none).
        return col.split("_")[0].lower()

    def is_unnamed(col):
        return "unnamed" in col.lower()

    col_groups = []
    current_group = []
    for col in columns:
        # Extend the open group if this column is a placeholder header or
        # shares the base name of the column right before it.
        if current_group and (
            is_unnamed(col) or base_name(col) == base_name(current_group[-1])
        ):
            current_group.append(col)
            continue
        if current_group:
            col_groups.append(current_group)
        current_group = [col]
    if current_group:
        col_groups.append(current_group)

    meta_dict = {"common_cols": [], "db_cols": []}
    for group in col_groups:
        if not any(is_unnamed(col) for col in group):
            meta_dict["common_cols"].extend(group)
            continue
        # Name the group after its first real header. Fall back to the first
        # column so an all-"unnamed" group never turns into "None_<i>".
        named_col = next((col for col in group if not is_unnamed(col)), group[0])
        meta_dict["db_cols"].extend(f"{named_col}_{i}" for i in range(len(group)))
    return meta_dict
|
53
|
+
|
54
|
+
# def bech_by_db_name(df, db_list="db1, db2", key_metrics="p, r, f1, acc"):
|
55
|
+
|
56
|
+
|
57
|
+
def str_2_list(input_str, sep=","):
    """Split ``input_str`` on ``sep`` into a list of stripped, non-empty items.

    A blank (empty or whitespace-only) string gives ``[]``. A string without
    ``sep`` gives a one-element list holding the stripped string.
    """
    if not input_str.strip():
        return []
    if sep in input_str:
        return [token for token in map(str.strip, input_str.split(sep)) if token]
    return [input_str.strip()]
|
67
|
+
|
68
|
+
|
69
|
+
def filter_bech_df_by_db_and_metrics(df, db_list="", key_metrics=""):
    """Keep the shared columns plus the per-dataset columns matching the filters.

    Columns are first grouped and renamed with ``cols_to_col_groups``. The
    result holds:

    * every common column;
    * the dataset columns whose renamed header contains, case-insensitively,
      one of the comma-separated names in ``db_list`` (all dataset columns
      when ``db_list`` is blank).

    When ``key_metrics`` is non-blank, dataset columns are further limited to
    those whose value in the first data row (the row holding the metric
    names) equals, case-insensitively, one of the comma-separated metrics.

    Args:
        df: benchmark DataFrame as read from the CSV.
        db_list: comma-separated dataset names, e.g. ``"db1, db2"``.
        key_metrics: comma-separated metric names, e.g. ``"p, r, f1, acc"``.

    Returns:
        A new DataFrame with the selected (renamed) columns; ``df`` is not
        modified.
    """
    meta_cols_dict = cols_to_col_groups(df)
    common_cols = list(meta_cols_dict["common_cols"])
    db_cols = list(meta_cols_dict["db_cols"])

    # Rename in the ORIGINAL column order. Common columns keep their names and
    # each dataset column takes the next renamed header. Assigning
    # ``common_cols + db_cols`` would silently mislabel columns whenever a
    # dataset group comes before a common group.
    common_set = set(common_cols)
    db_iter = iter(db_cols)
    renamed = []
    for col in df.columns:
        name = str(col)
        renamed.append(name if name in common_set else next(db_iter, name))
    op_df = df.copy()
    op_df.columns = renamed

    selected_db_list = [name.lower() for name in str_2_list(db_list)]
    if selected_db_list:
        # Iterate the cleaned names, not the raw string. A blank entry (for
        # example from a trailing comma) would otherwise match every column.
        db_filted_cols = [
            col
            for db_name in selected_db_list
            for col in db_cols
            if db_name in col.lower()
        ]
        # A column matching several names (e.g. "fire" and "bowfire") is kept once.
        db_filted_cols = list(dict.fromkeys(db_filted_cols))
    else:
        db_filted_cols = db_cols

    df_filtered = op_df[common_cols + db_filted_cols].copy()

    selected_metrics_ls = [metric.lower() for metric in str_2_list(key_metrics)]
    if selected_metrics_ls:
        # The first data row (the CSV's second line) names the metric stored
        # in each dataset column. str() guards against NaN / numeric cells.
        metrics_row = df_filtered.iloc[0]
        keep_metrics_cols = [
            col
            for col in db_filted_cols
            if str(metrics_row[col]).strip().lower() in selected_metrics_ls
        ]
    else:
        pprint("No metrics selected, keeping all db columns")
        keep_metrics_cols = db_filted_cols

    return df_filtered[common_cols + keep_metrics_cols].copy()
|
114
|
+
|
115
|
+
|
116
|
+
def parse_args():
    """Read the benchmark-query CLI options (currently only the CSV path)."""
    default_csv = r"E:\Dev\__halib\test\bench.csv"
    arg_parser = ArgumentParser(description="desc text")
    arg_parser.add_argument(
        "-csv",
        "--csv",
        type=str,
        default=default_csv,
        help="CSV file path",
    )
    return arg_parser.parse_args()
|
121
|
+
|
122
|
+
|
123
|
+
def main():
    """Load the ';'-separated benchmark CSV and print its bowfire/acc columns."""
    csv_path = parse_args().csv
    bench_df = pd.read_csv(csv_path, sep=";", encoding="utf-8")
    print(filter_bech_df_by_db_and_metrics(bench_df, "bowfire", "acc"))


if __name__ == "__main__":
    main()
|
@@ -0,0 +1,209 @@
|
|
1
|
+
# This script creates a test version
|
2
|
+
# of the watcam (wc) dataset
|
3
|
+
# for testing the tflite model
|
4
|
+
|
5
|
+
from argparse import ArgumentParser
|
6
|
+
|
7
|
+
from rich import inspect
|
8
|
+
from common import console, seed_everything, ConsoleLog
|
9
|
+
from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
|
10
|
+
from tqdm import tqdm
|
11
|
+
import os
|
12
|
+
import click
|
13
|
+
from torchvision.datasets import ImageFolder
|
14
|
+
import shutil
|
15
|
+
from rich.pretty import pprint
|
16
|
+
from system import filesys as fs
|
17
|
+
import glob
|
18
|
+
|
19
|
+
|
20
|
+
def parse_args():
    """Parse the command-line options of the dataset split script.

    ``--indir`` is required because every mode of ``main`` reads it. Leaving
    it out would otherwise fail only later, with a less helpful error.
    """
    parser = ArgumentParser(description="desc text")
    parser.add_argument(
        "-indir",
        "--indir",
        type=str,
        required=True,
        help="orignal dataset path",
    )
    parser.add_argument(
        "-outdir",
        "--outdir",
        type=str,
        help="dataset out path",
        default=".",  # "." makes main() use "<indir>_split" next to indir
    )
    parser.add_argument(
        "-val_size",
        "--val_size",
        type=float,
        help="validation size",
        default=0.2,
    )
    parser.add_argument(
        "-seed",
        "--seed",
        type=int,
        help="random seed",
        default=42,
    )
    parser.add_argument(
        "-inplace",
        "--inplace",
        action="store_true",
        help="inplace operation, will overwrite the outdir if exists",
    )
    # Selects StratifiedShuffleSplit (class-balanced) instead of ShuffleSplit.
    parser.add_argument(
        "-stratified",
        "--stratified",
        action="store_true",
        help="use StratifiedShuffleSplit instead of ShuffleSplit",
    )
    parser.add_argument(
        "-no_train",
        "--no_train",
        action="store_true",
        help="only create test set, no train set",
    )
    parser.add_argument(
        "-reverse",
        "--reverse",
        action="store_true",
        help="combine train and val set back to original dataset",
    )
    return parser.parse_args()
|
76
|
+
|
77
|
+
|
78
|
+
def move_images(image_paths, target_set_dir):
    """Move each image into ``target_set_dir/<class>/``, keeping its class folder.

    The class name is the image's immediate parent directory, so
    ``root/cat/1.jpg`` ends up at ``target_set_dir/cat/1.jpg``. Files are
    moved, not copied. Progress is shown with tqdm.

    Args:
        image_paths: iterable of image file paths.
        target_set_dir: destination root (e.g. a ``train`` or ``val`` folder).
    """
    for img_path in tqdm(image_paths):
        class_name = os.path.basename(os.path.dirname(img_path))
        out_cls_dir = os.path.join(target_set_dir, class_name)
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(out_cls_dir, exist_ok=True)
        shutil.move(img_path, out_cls_dir)
|
87
|
+
|
88
|
+
|
89
|
+
def split_dataset_cls(
    indir, outdir, val_size, seed, inplace, stratified_split, no_train
):
    """Split an ImageFolder-style classification dataset into train/val folders.

    Images are MOVED (not copied) out of ``indir``:
    ``indir/<class>/x.jpg`` goes to ``outdir/train/<class>/x.jpg`` or
    ``outdir/val/<class>/x.jpg``. With ``inplace`` the split is written inside
    ``indir`` itself, and the emptied class folders are removed afterwards.
    The user must confirm the printed config first (``click.confirm`` aborts
    on "no").

    Args:
        indir: dataset root with one sub-folder per class.
        outdir: destination root. It must not exist unless ``inplace`` is set,
            in which case it is replaced by ``indir``.
        val_size: passed to scikit-learn as ``test_size``.
        seed: seed for ``seed_everything`` and for the splitter's
            ``random_state``.
        inplace: write ``train``/``val`` inside ``indir``.
        stratified_split: use ``StratifiedShuffleSplit`` instead of
            ``ShuffleSplit``.
        no_train: only create the val split; no ``train`` folder is written.

    Raises:
        AssertionError: if ``indir`` is missing, ``outdir`` already exists
            (non-inplace), or ``train``/``val`` already exist (inplace).
    """
    seed_everything(seed)
    console.rule("Config confirm?")
    pprint(locals())  # only the arguments are bound at this point
    click.confirm("Continue?", abort=True)
    assert os.path.exists(indir), f"{indir} does not exist"

    if inplace:
        outdir = indir
    else:
        assert not os.path.exists(outdir), f"{outdir} already exists; SKIP ...."
    os.makedirs(outdir, exist_ok=True)

    console.rule("Creating train/val dataset")
    # Seed the splitter explicitly so the split does not depend on global RNG state.
    splitter_cls = StratifiedShuffleSplit if stratified_split else ShuffleSplit
    sss = splitter_cls(n_splits=1, test_size=val_size, random_state=seed)
    pprint({"split strategy": sss, "indir": indir, "outdir": outdir})

    dataset = ImageFolder(root=indir, transform=None)
    # n_splits=1, so split() yields exactly one (train, val) index pair.
    # "val" here is the held-out test part.
    train_dataset_indices, val_dataset_indices = next(
        sss.split(dataset.samples, dataset.targets)
    )
    train_image_paths = [dataset.imgs[i][0] for i in train_dataset_indices]
    val_image_paths = [dataset.imgs[i][0] for i in val_dataset_indices]

    out_train_dir = os.path.join(outdir, "train")
    out_val_dir = os.path.join(outdir, "val")
    if inplace:
        assert not os.path.exists(out_train_dir), f"{out_train_dir} already exists"
        assert not os.path.exists(out_val_dir), f"{out_val_dir} already exists"

    os.makedirs(out_val_dir)
    if not no_train:
        os.makedirs(out_train_dir)
        with ConsoleLog(f"Moving train images to {out_train_dir} "):
            move_images(train_image_paths, out_train_dir)
    else:
        pprint("test only, skip moving train images")

    with ConsoleLog(f"Moving val images to {out_val_dir} "):
        move_images(val_image_paths, out_val_dir)

    if inplace:
        pprint("remove all folders, except train and val")
        # NOTE(review): with no_train=True the train images are still inside
        # these class folders, so they are deleted here. Confirm this is intended.
        for entry in os.listdir(indir):
            entry_path = os.path.join(indir, entry)
            # Remove only class *folders*. Stray files (e.g. a README) would
            # make rmtree raise after the images have already been moved.
            if entry not in ("train", "val") and os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
|
157
|
+
|
158
|
+
|
159
|
+
def reverse_split_ds(indir):
    """Undo a train/val split by moving every image back to ``indir/<class>/``.

    ``indir`` must contain exactly two entries, ``train`` and ``val``. Image
    files found under both are moved into class folders directly below
    ``indir``, then the two split folders are deleted.

    Raises:
        AssertionError: if ``indir`` does not hold exactly ``train`` and ``val``.
    """
    console.rule(f"Reversing split dataset <{indir}>...")
    entries = os.listdir(indir)
    assert len(entries) == 2, f"Found more than 2 dirs: {len(entries) } dirs"
    for split_name in ("train", "val"):
        assert split_name in entries, f"{split_name} dir not found in {indir}"

    image_exts = ("jpg", "jpeg", "png", "bmp", "gif", "tiff")
    train_dir, val_dir = (os.path.join(indir, name) for name in ("train", "val"))
    # Collect both file lists before anything is moved.
    all_train_files = fs.filter_files_by_extension(train_dir, list(image_exts))
    all_val_files = fs.filter_files_by_extension(val_dir, list(image_exts))

    for split_name, split_files in (("train", all_train_files), ("val", all_val_files)):
        with ConsoleLog(f"Moving {split_name} images to {indir} "):
            move_images(split_files, indir)

    with ConsoleLog(f"Removing train and val dirs"):
        for split_dir in (train_dir, val_dir):
            shutil.rmtree(split_dir)
|
183
|
+
|
184
|
+
|
185
|
+
def main():
    """CLI entry point: split ``--indir`` into train/val, or undo it with ``--reverse``."""
    args = parse_args()
    indir = args.indir
    outdir = args.outdir
    if outdir == ".":
        # Default output is a sibling folder named "<indir>_split". Normalize
        # first so a trailing separator ("data/") still gives "data_split"
        # rather than an empty basename ("_split").
        norm_indir = os.path.normpath(indir)
        outdir = os.path.join(
            os.path.dirname(norm_indir), f"{os.path.basename(norm_indir)}_split"
        )

    if args.reverse:
        reverse_split_ds(indir)
    else:
        split_dataset_cls(
            indir,
            outdir,
            args.val_size,
            args.seed,
            args.inplace,
            args.stratified,
            args.no_train,
        )


if __name__ == "__main__":
    main()
|
halib/research/plot.py
ADDED
@@ -0,0 +1,301 @@
|
|
1
|
+
from ..common import now_str, norm_str, ConsoleLog
|
2
|
+
from ..filetype import csvfile
|
3
|
+
from ..system import filesys as fs
|
4
|
+
from functools import partial
|
5
|
+
from rich.console import Console
|
6
|
+
from rich.pretty import pprint
|
7
|
+
import click
|
8
|
+
import csv
|
9
|
+
import matplotlib
|
10
|
+
import matplotlib.pyplot as plt
|
11
|
+
import numpy as np
|
12
|
+
import os
|
13
|
+
import pandas as pd
|
14
|
+
import seaborn as sns
|
15
|
+
|
16
|
+
|
17
|
+
# Module-level rich console instance.
console = Console()
# Default output directory for saved plots: the current user's Desktop
# (used as the default of the CLI's --outdir option).
desktop_path = os.path.expanduser("~/Desktop")
# Columns every training-log CSV must provide. By default verify_csv() returns
# exactly these columns, in this order.
REQUIRED_COLUMNS = ["epoch", "train_loss", "val_loss", "train_acc", "val_acc"]
|
20
|
+
|
21
|
+
import csv
|
22
|
+
|
23
|
+
|
24
|
+
def get_delimiter(file_path, bytes=4096):
    """Guess the field delimiter of a CSV file with ``csv.Sniffer``.

    Args:
        file_path: path of the CSV file.
        bytes: number of characters read from the start of the file for
            sniffing. The name shadows the builtin but is kept so keyword
            callers keep working.

    Returns:
        The detected one-character delimiter, e.g. ``","`` or ``";"``.

    Raises:
        csv.Error: if the sample does not let the sniffer decide.
    """
    # A context manager closes the handle deterministically instead of leaving
    # it to garbage collection.
    with open(file_path, "r") as f:
        sample = f.read(bytes)
    return csv.Sniffer().sniff(sample).delimiter
|
29
|
+
|
30
|
+
|
31
|
+
# Verify that the CSV has all required columns and return a DataFrame restricted to exactly those columns
|
32
|
+
def verify_csv(csv_file, required_columns=REQUIRED_COLUMNS):
    """Load a training-log CSV and return only its required columns.

    The delimiter is auto-detected. Header names are lower-cased and stripped
    of surrounding whitespace before matching, so headers such as ``"Epoch"``
    or ``" val_loss"`` (after ``", "`` separators) are accepted.

    Args:
        csv_file: path to the CSV file.
        required_columns: lower-case column names that must be present.

    Returns:
        A new DataFrame holding exactly ``required_columns``, in that order.

    Raises:
        ValueError: if any required column is missing.
    """
    delimiter = get_delimiter(csv_file)
    df = pd.read_csv(csv_file, sep=delimiter)
    df.columns = [str(col).strip().lower() for col in df.columns]
    required = list(required_columns)
    missing = [col for col in required if col not in df.columns]
    if missing:
        # Report against the columns actually requested, not the module default.
        raise ValueError(
            f"Required columns are: {required}, missing {missing}, "
            f"but found {list(df.columns)}"
        )
    return df[required].copy()
|
44
|
+
|
45
|
+
|
46
|
+
def get_valid_tags(csv_files, tags):
    """Return one plot label per CSV file.

    If ``tags`` is given (non-empty), it is validated and returned as-is. Each
    element must be a non-empty ``str``; violations raise ``AssertionError``.
    Otherwise a label is derived from each file name (extension dropped,
    normalized with ``norm_str``).
    """
    if tags is None or len(tags) == 0:
        return [
            norm_str(fs.get_file_name(path, split_file_ext=True)[0])
            for path in csv_files
        ]
    assert all(isinstance(label, str) for label in tags), "tags must be a list of strings"
    assert all(len(label) > 0 for label in tags), "tags must be a list of non-empty strings"
    return tags
|
62
|
+
|
63
|
+
|
64
|
+
def plot_ax(df, ax, metric="loss", tag=""):
    """Draw the train and val curves of one metric against epoch on ``ax``.

    Args:
        df: DataFrame with an ``epoch`` column plus ``train_<metric>`` and
            ``val_<metric>`` columns (as returned by ``verify_csv``).
        ax: matplotlib Axes to draw on.
        metric: ``"loss"`` or ``"acc"``.
        tag: prefix of the legend labels, which read ``"<tag>_<part>_<metric>"``.

    Returns:
        The same ``ax``, for chaining.

    Raises:
        AssertionError: if ``metric`` is neither ``"loss"`` nor ``"acc"``.
    """
    # Validate first. No debug dump of the arguments: it printed the whole
    # DataFrame on every call.
    assert metric in ["loss", "acc"], "metric must be either 'loss' or 'acc'"
    for part in ("train", "val"):
        ax.plot(df["epoch"], df[f"{part}_{metric}"], label=f"{tag}_{part}_{metric}")
    return ax
|
73
|
+
|
74
|
+
|
75
|
+
def _split_tagged_col(col, tags):
    """Split a merged column name ``"<tag>_<part>_<metric>"`` into its pieces.

    The longest tag that prefixes ``col`` (followed by ``"_"``) wins, so tag
    ``"exp"`` never claims the columns of tag ``"exp_1"``. Returns
    ``(tag, part, metric)``, or ``None`` when no tag matches.
    """
    matching = [tag for tag in tags if col.startswith(f"{tag}_")]
    if not matching:
        return None
    tag = max(matching, key=len)
    part, _, metric = col[len(tag) + 1 :].partition("_")
    return tag, part, metric


def actual_plot_seaborn(frame, csv_files, axes, tags, log):
    """Plot the loss and accuracy curves of every CSV file with seaborn.

    This is called directly and as the ``FuncAnimation`` callback for live
    updates, so the CSV files are re-read and ``axes`` redrawn on every call.

    Curves are matched by the exact ``"<tag>_<part>_<metric>"`` column
    structure, not by substring search. A tag that contains "acc", "loss" or
    another tag's name therefore cannot pull curves into the wrong plot.
    Each tag gets one marker and one sequential palette; its train and val
    curves use two different shades of that palette. Markers and palettes are
    reused cyclically when there are more tags than available.

    Args:
        frame: animation frame index supplied by ``FuncAnimation`` (unused).
        csv_files: list of CSV paths, each validated by ``verify_csv``.
        axes: two matplotlib Axes; ``axes[0]`` gets loss, ``axes[1]`` accuracy.
        tags: optional per-file labels; see ``get_valid_tags``.
        log: when true, print each loaded table.
    """
    # Start from blank axes: under FuncAnimation this runs on every refresh.
    for ax in axes:
        ax.clear()
    valid_tags = get_valid_tags(csv_files, tags)
    ls_df = []
    for csv_file in csv_files:
        df = verify_csv(csv_file)
        if log:
            with ConsoleLog(f"plotting {csv_file}"):
                csvfile.fn_display_df(df)
        ls_df.append(df)

    # Prefix every column except "epoch" with its file's tag, then outer-join
    # all files on epoch so runs of different lengths share one table.
    for df_item, tag in zip(ls_df, valid_tags):
        df_item.columns = [
            f"{tag}_{col}" if col != "epoch" else col for col in df_item.columns
        ]
    df_combined = ls_df[0]
    for df_item in ls_df[1:]:
        df_combined = pd.merge(df_combined, df_item, on="epoch", how="outer")

    ls_metrics = ["loss", "acc"]
    parts = ["train", "val"]

    # Lines from the same CSV file (same tag) share a marker and a palette.
    all_markers = [
        marker for marker in plt.Line2D.markers if marker and marker != " "
    ]
    sequential_palettes = [
        "Reds",
        "Greens",
        "Blues",
        "Oranges",
        "Purples",
        "Greys",
        "BuGn",
        "BuPu",
        "GnBu",
        "OrRd",
        "PuBu",
        "PuRd",
        "RdPu",
        "YlGn",
        "PuBuGn",
        "YlGnBu",
        "YlOrBr",
        "YlOrRd",
    ]
    total_colors = 10
    tag2marker = {
        tag: all_markers[i % len(all_markers)] for i, tag in enumerate(valid_tags)
    }
    tag2colors = {
        tag: sns.color_palette(
            sequential_palettes[i % len(sequential_palettes)], total_colors
        ).as_hex()
        for i, tag in enumerate(valid_tags)
    }
    # Split each palette into one band per part; a part uses its band's
    # middle shade (indices 2 and 7 of 10).
    shade_bands = np.array_split(np.arange(total_colors), len(parts))
    shade_idx = {part: int(band.mean()) for part, band in zip(parts, shade_bands)}
    tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]

    for i, metric in enumerate(ls_metrics):
        title = f"{tags_str}_{metric}-by-epoch"
        col_info = {}
        for col in df_combined.columns:
            if col == "epoch":
                continue
            info = _split_tagged_col(col, valid_tags)
            if info is not None and info[2] == metric:
                col_info[col] = info
        cols = sorted(col_info)
        plot_data = df_combined[cols]

        # Markers and colors must follow `cols` order, because seaborn assigns
        # them to the wide-form columns in that order.
        plot_markers = [tag2marker[col_info[col][0]] for col in cols]
        plot_colors = []
        for col in cols:
            tag, part, _ = col_info[col]
            shade = shade_idx.get(part, shade_idx[parts[0]])
            plot_colors.append(tag2colors[tag][shade])

        sns.lineplot(
            data=plot_data,
            markers=plot_markers,
            palette=plot_colors,
            ax=axes[i],
            dashes=False,
        )
        axes[i].set(xlabel="epoch", ylabel=metric, title=title)
        axes[i].legend()
        axes[i].grid()
|
174
|
+
|
175
|
+
|
176
|
+
def actual_plot(frame, csv_files, axes, tags, log):
    """Plot train/val loss and accuracy of every CSV file with plain matplotlib.

    It has the same signature as ``actual_plot_seaborn`` (``frame`` is
    accepted but unused). Unlike that function, ``axes`` is not cleared
    first. ``axes[0]`` receives the loss curves and ``axes[1]`` the accuracy
    curves.
    """
    valid_tags = get_valid_tags(csv_files, tags)
    tables = []
    for path in csv_files:
        table = verify_csv(path)
        if log:
            with ConsoleLog(f"plotting {path}"):
                csvfile.fn_display_df(table)
        tables.append(table)

    for ax_idx, metric in enumerate(("loss", "acc")):
        for table, tag in zip(tables, valid_tags):
            metric_ax = plot_ax(table, axes[ax_idx], metric, tag)

        joined_tags = valid_tags[0] if len(valid_tags) <= 1 else "+".join(valid_tags)
        metric_ax.set(
            xlabel="epoch", ylabel=metric, title=f"{joined_tags}_{metric}-by-epoch"
        )
        metric_ax.legend()
        metric_ax.grid()
|
198
|
+
|
199
|
+
|
200
|
+
def plot_csv_files(
    csv_files,
    outdir="./out/plot",
    tags=None,
    log=False,
    save_fig=False,
    update_in_min=1,
):
    """Show (and optionally save) loss/accuracy plots for training-log CSV files.

    Args:
        csv_files: one path or a list of paths.
        outdir: folder for saved figures (created when ``save_fig`` is set).
        tags: one label or a list of labels; derived from file names if empty.
        log: print each loaded table.
        save_fig: after the window closes, save ``<timestamp>_<tags>_plot.png``
            and ``..._plot.pdf`` into ``outdir``. The PDF is written after
            ``enable_plot_pgf()`` changes the global matplotlib settings.
        update_in_min: refresh interval in minutes; a value that rounds to
            0 ms disables live updating.

    Returns:
        The ``FuncAnimation`` when live updating is enabled, otherwise ``None``.
    """
    csv_files = [csv_files] if isinstance(csv_files, str) else csv_files
    tags = [tags] if isinstance(tags, str) else tags
    valid_tags = get_valid_tags(csv_files, tags)
    assert len(valid_tags) == len(csv_files), "Unable to determine tags for each csv file"

    live_update_in_ms = int(update_in_min * 60 * 1000)
    fig, axes = plt.subplots(2, 1, figsize=(10, 17))
    anim = None
    if live_update_in_ms:
        from matplotlib.animation import FuncAnimation

        redraw = partial(
            actual_plot_seaborn, csv_files=csv_files, axes=axes, tags=tags, log=log
        )
        # Keep a reference: the animation stops if it is garbage-collected.
        anim = FuncAnimation(
            fig,
            redraw,
            interval=live_update_in_ms,
            blit=False,
            cache_frame_data=False,
        )
    else:
        actual_plot_seaborn(None, csv_files, axes, tags, log)
    plt.show()

    if save_fig:
        os.makedirs(outdir, exist_ok=True)
        tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
        stem = f"{outdir}/{now_str()}_{tags_str}"
        fig.savefig(f"{stem}_plot.png")
        enable_plot_pgf()
        fig.savefig(f"{stem}_plot.pdf")
    return anim
|
246
|
+
|
247
|
+
|
248
|
+
def enable_plot_pgf():
    """Switch matplotlib to the "pdf" backend with LaTeX-rendered serif text.

    This mutates global matplotlib state (backend and rcParams), and
    ``text.usetex`` needs a working LaTeX installation. The ``pgf.*`` keys
    take effect only when a figure is written with the pgf backend, e.g.
    when saving to a ``.pgf`` file.
    """
    latex_settings = {
        "pgf.texsystem": "pdflatex",
        "font.family": "serif",
        "text.usetex": True,
        "pgf.rcfonts": False,
    }
    matplotlib.use("pdf")
    matplotlib.rcParams.update(latex_settings)
|
258
|
+
|
259
|
+
|
260
|
+
def save_fig_latex_pgf(filename, directory="."):
    """Save the current pyplot figure as a ``.pgf`` file for inclusion in LaTeX.

    Args:
        filename: file name, with or without the ``.pgf`` extension. It is
            always resolved against ``directory``; an absolute ``filename``
            wins, as with ``os.path.join``.
        directory: output folder, created if missing.

    Returns:
        The path the figure was written to.
    """
    enable_plot_pgf()
    if not filename.endswith(".pgf"):
        filename = f"{filename}.pgf"
    # Join in both cases so ``directory`` is honored even when the name
    # already ends in ".pgf".
    out_path = os.path.join(directory, filename)
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path)
    return out_path
|
265
|
+
|
266
|
+
|
267
|
+
# Click API reference: https://click.palletsprojects.com/en/8.1.x/api/
@click.command()
@click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
@click.option(
    "--outdir",
    "-o",
    type=str,
    help="output directory for the plot",
    default=str(desktop_path),
)
@click.option(
    "--tags", "-t", multiple=True, type=str, help="tags for the csv files", default=[]
)
@click.option("--log", "-l", is_flag=True, help="log the csv files")
@click.option("--save_fig", "-s", is_flag=True, help="save the plot as a file")
@click.option(
    "--update_in_min",
    "-u",
    type=float,
    help="update the plot every x minutes",
    default=0.0,
)
def main(csvfiles, outdir, tags, log, save_fig, update_in_min):
    """Plot the given training-log CSV files (thin CLI wrapper over plot_csv_files)."""
    # click hands multi-value options over as tuples; plot_csv_files gets lists.
    plot_csv_files(
        csv_files=list(csvfiles),
        outdir=outdir,
        tags=list(tags),
        log=log,
        save_fig=save_fig,
        update_in_min=update_in_min,
    )


if __name__ == "__main__":
    main()
|