halib 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +94 -0
- halib/common/__init__.py +0 -0
- halib/common/common.py +326 -0
- halib/common/rich_color.py +285 -0
- halib/common.py +151 -0
- halib/csvfile.py +48 -0
- halib/cuda.py +39 -0
- halib/dataset.py +209 -0
- halib/exp/__init__.py +0 -0
- halib/exp/core/__init__.py +0 -0
- halib/exp/core/base_config.py +167 -0
- halib/exp/core/base_exp.py +147 -0
- halib/exp/core/param_gen.py +170 -0
- halib/exp/core/wandb_op.py +117 -0
- halib/exp/data/__init__.py +0 -0
- halib/exp/data/dataclass_util.py +41 -0
- halib/exp/data/dataset.py +208 -0
- halib/exp/data/torchloader.py +165 -0
- halib/exp/perf/__init__.py +0 -0
- halib/exp/perf/flop_calc.py +190 -0
- halib/exp/perf/gpu_mon.py +58 -0
- halib/exp/perf/perfcalc.py +470 -0
- halib/exp/perf/perfmetrics.py +137 -0
- halib/exp/perf/perftb.py +778 -0
- halib/exp/perf/profiler.py +507 -0
- halib/exp/viz/__init__.py +0 -0
- halib/exp/viz/plot.py +754 -0
- halib/filesys.py +117 -0
- halib/filetype/__init__.py +0 -0
- halib/filetype/csvfile.py +192 -0
- halib/filetype/ipynb.py +61 -0
- halib/filetype/jsonfile.py +19 -0
- halib/filetype/textfile.py +12 -0
- halib/filetype/videofile.py +266 -0
- halib/filetype/yamlfile.py +87 -0
- halib/gdrive.py +179 -0
- halib/gdrive_mkdir.py +41 -0
- halib/gdrive_test.py +37 -0
- halib/jsonfile.py +22 -0
- halib/listop.py +13 -0
- halib/online/__init__.py +0 -0
- halib/online/gdrive.py +229 -0
- halib/online/gdrive_mkdir.py +53 -0
- halib/online/gdrive_test.py +50 -0
- halib/online/projectmake.py +131 -0
- halib/online/tele_noti.py +165 -0
- halib/plot.py +301 -0
- halib/projectmake.py +115 -0
- halib/research/__init__.py +0 -0
- halib/research/base_config.py +100 -0
- halib/research/base_exp.py +157 -0
- halib/research/benchquery.py +131 -0
- halib/research/core/__init__.py +0 -0
- halib/research/core/base_config.py +144 -0
- halib/research/core/base_exp.py +157 -0
- halib/research/core/param_gen.py +108 -0
- halib/research/core/wandb_op.py +117 -0
- halib/research/data/__init__.py +0 -0
- halib/research/data/dataclass_util.py +41 -0
- halib/research/data/dataset.py +208 -0
- halib/research/data/torchloader.py +165 -0
- halib/research/dataset.py +208 -0
- halib/research/flop_csv.py +34 -0
- halib/research/flops.py +156 -0
- halib/research/metrics.py +137 -0
- halib/research/mics.py +74 -0
- halib/research/params_gen.py +108 -0
- halib/research/perf/__init__.py +0 -0
- halib/research/perf/flop_calc.py +190 -0
- halib/research/perf/gpu_mon.py +58 -0
- halib/research/perf/perfcalc.py +363 -0
- halib/research/perf/perfmetrics.py +137 -0
- halib/research/perf/perftb.py +778 -0
- halib/research/perf/profiler.py +301 -0
- halib/research/perfcalc.py +361 -0
- halib/research/perftb.py +780 -0
- halib/research/plot.py +758 -0
- halib/research/profiler.py +300 -0
- halib/research/torchloader.py +162 -0
- halib/research/viz/__init__.py +0 -0
- halib/research/viz/plot.py +754 -0
- halib/research/wandb_op.py +116 -0
- halib/rich_color.py +285 -0
- halib/sys/__init__.py +0 -0
- halib/sys/cmd.py +8 -0
- halib/sys/filesys.py +124 -0
- halib/system/__init__.py +0 -0
- halib/system/_list_pc.csv +6 -0
- halib/system/cmd.py +8 -0
- halib/system/filesys.py +164 -0
- halib/system/path.py +106 -0
- halib/tele_noti.py +166 -0
- halib/textfile.py +13 -0
- halib/torchloader.py +162 -0
- halib/utils/__init__.py +0 -0
- halib/utils/dataclass_util.py +40 -0
- halib/utils/dict.py +317 -0
- halib/utils/dict_op.py +9 -0
- halib/utils/gpu_mon.py +58 -0
- halib/utils/list.py +17 -0
- halib/utils/listop.py +13 -0
- halib/utils/slack.py +86 -0
- halib/utils/tele_noti.py +166 -0
- halib/utils/video.py +82 -0
- halib/videofile.py +139 -0
- halib-0.2.30.dist-info/METADATA +237 -0
- halib-0.2.30.dist-info/RECORD +110 -0
- halib-0.2.30.dist-info/WHEEL +5 -0
- halib-0.2.30.dist-info/licenses/LICENSE.txt +17 -0
- halib-0.2.30.dist-info/top_level.txt +1 -0
halib/online/gdrive_test.py
ADDED

@@ -0,0 +1,50 @@
from argparse import ArgumentParser
import gdrive


def parse_args():
    parser = ArgumentParser(description="Upload local folder to Google Drive")
    parser.add_argument(
        "-a",
        "--authFile",
        type=str,
        help="authentication file for Google Drive",
        default="settings.yaml",
    )
    parser.add_argument("-s", "--source", type=str, help="Folder to upload")
    parser.add_argument(
        "-d", "--destination", type=str, help="Destination folder ID in Google Drive"
    )
    parser.add_argument(
        "-c",
        "--contentOnly",
        type=str,
        help="if true, upload only the folder's contents, not the folder itself",
        default="True",
    )
    parser.add_argument(
        "-i",
        "--ignoreFile",
        type=str,
        help="file containing files/folders to ignore",
        default=None,
    )

    return parser.parse_args()


def main():
    args = parse_args()
    auth_file = args.authFile
    local_folder = args.source
    gg_folder_id = args.destination
    content_only = args.contentOnly.lower() == "true"
    ignore_file = args.ignoreFile
    gdrive.get_gg_drive(auth_file)
    gdrive.upload_folder_to_drive(
        local_folder, gg_folder_id, content_only=content_only, ignore_file=ignore_file
    )


if __name__ == "__main__":
    main()
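For context, the script drives two helpers from the package's gdrive module. A minimal programmatic equivalent is sketched below, assuming it runs next to gdrive.py as the script does; the local path and Drive folder ID are placeholders, not values from the package:

# Hypothetical direct use of the gdrive helpers called by the script above.
import gdrive

gdrive.get_gg_drive("settings.yaml")  # authenticate using the settings file
gdrive.upload_folder_to_drive(
    "./my_local_folder",   # placeholder local folder
    "<drive-folder-id>",   # placeholder destination folder ID
    content_only=True,
    ignore_file=None,
)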
halib/online/projectmake.py
ADDED

@@ -0,0 +1,131 @@
# coding=utf-8

import os
import json
import pycurl
import shutil
import certifi
import subprocess
from io import BytesIO

from argparse import ArgumentParser

from ..filetype import jsonfile
from ..system import filesys


def get_curl(url, user_and_pass, verbose=True):
    c = pycurl.Curl()
    c.setopt(pycurl.VERBOSE, verbose)
    c.setopt(pycurl.CAINFO, certifi.where())
    c.setopt(pycurl.URL, url)
    c.setopt(pycurl.USERPWD, user_and_pass)
    return c


def get_user_and_pass(username, appPass):
    return f"{username}:{appPass}"


def create_repo(
    username, appPass, repo_name, workspace, proj_name, template_repo="py-proj-template"
):
    buffer = BytesIO()
    url = f"https://api.bitbucket.org/2.0/repositories/{workspace}/{repo_name}"
    data = json.dumps({"scm": "git", "project": {"key": f"{proj_name}"}})

    user_and_pass = get_user_and_pass(username, appPass)
    c = get_curl(url, user_and_pass)
    c.setopt(pycurl.WRITEDATA, buffer)
    c.setopt(pycurl.POST, 1)
    c.setopt(pycurl.POSTFIELDS, data)
    c.setopt(pycurl.HTTPHEADER, ["Accept: application/json"])
    c.perform()
    RESPOND_CODE = c.getinfo(pycurl.HTTP_CODE)
    c.close()
    # log info
    body = buffer.getvalue()
    msg = body.decode("iso-8859-1")
    successful = str(RESPOND_CODE) == "200"

    if successful and template_repo:
        template_repo_url = f"https://{username}:{appPass}@bitbucket.org/{workspace}/{template_repo}.git"
        git_clone(template_repo_url)
        template_folder = f"./{template_repo}"

        created_repo_url = (
            f"https://{username}:{appPass}@bitbucket.org/{workspace}/{repo_name}.git"
        )
        git_clone(created_repo_url)
        created_folder = f"./{repo_name}"
        shutil.copytree(
            template_folder,
            created_folder,
            dirs_exist_ok=True,
            ignore=shutil.ignore_patterns(".git"),
        )
        # Windows-specific cleanup of the cloned template folder
        os.system('rmdir /S /Q "{}"'.format(template_folder))
        project_folder = "project_name"

        filesys.change_current_dir(created_folder)
        filesys.rename_dir_or_file(project_folder, repo_name)
        # push to remote
        subprocess.check_call(["C:/batch/gitp.bat", "init proj from template"])

    return successful, msg


def parse_args():
    parser = ArgumentParser(description="Create a Bitbucket repository, optionally from a template repo")
    parser.add_argument(
        "-a",
        "--authFile",
        type=str,
        help="authentication file (json) for Bitbucket",
        default="bitbucket.json",
    )
    parser.add_argument(
        "-r", "--repoName", type=str, help="Repository name", default="hahv-proj"
    )
    parser.add_argument(
        "-t", "--templateRepo", type=str, help="template repo to fork", default="True"
    )
    return parser.parse_args()


def git_clone(url):
    subprocess.check_call(["git", "clone", url])


def main():
    args = parse_args()
    authFile = args.authFile
    repo_name = args.repoName

    authInfo = jsonfile.read(authFile)
    username = authInfo["username"]
    appPass = authInfo["appPass"]
    workspace_id = authInfo["workspace_id"]
    project_id = authInfo["project_id"]
    use_template = args.templateRepo.lower() == "true"
    template_repo = authInfo["template_repo"] if use_template else ""

    extra_info = f"[Use template project {template_repo}]" if use_template else ""
    print(f"[Bitbucket] creating {repo_name} project in Bitbucket {extra_info}")

    successful, msg = create_repo(
        username,
        appPass,
        repo_name,
        workspace_id,
        project_id,
        template_repo=template_repo,
    )
    if successful:
        print(f"[Bitbucket] {repo_name} created successfully. {extra_info}")
    else:
        formatted_msg = jsonfile.beautify(msg)
        print(f"[Bitbucket] {repo_name} creation failed. Details:\n{formatted_msg}")


if __name__ == "__main__":
    main()
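main() reads five keys from the auth file. A sketch of the expected bitbucket.json follows; the key names are taken from the reads in main() above, but every value is a placeholder (the template_repo default comes from create_repo's signature):

{
    "username": "<bitbucket-username>",
    "appPass": "<bitbucket-app-password>",
    "workspace_id": "<workspace-id>",
    "project_id": "<project-key>",
    "template_repo": "py-proj-template"
}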
halib/online/tele_noti.py
ADDED

@@ -0,0 +1,165 @@
# Watch a log file and send a Telegram message when training reaches a certain epoch or ends

import os
import yaml
import asyncio
import telegram
import pandas as pd

from rich.pretty import pprint
from rich.console import Console
import plotly.graph_objects as go

from ..system import filesys as fs
from ..filetype import textfile, csvfile

from argparse import ArgumentParser

tele_console = Console()


def parse_args():
    parser = ArgumentParser(description="Watch a log file and send Telegram notifications")
    parser.add_argument(
        "-cfg",
        "--cfg",
        type=str,
        help="YAML config file for the Telegram watcher",
        default=r"E:\Dev\__halib\halib\online\tele_noti_cfg.yaml",
    )

    return parser.parse_args()


def get_watcher_message_df(target_file, num_last_lines):
    file_ext = fs.get_file_name(target_file, split_file_ext=True)[1]
    supported_ext = [".txt", ".log", ".csv"]
    assert (
        file_ext in supported_ext
    ), f"File extension {file_ext} not supported. Supported extensions are {supported_ext}"
    last_lines_df = None
    if file_ext in [".txt", ".log"]:
        lines = textfile.read_line_by_line(target_file)
        if num_last_lines > len(lines):
            num_last_lines = len(lines)
        last_line_arr = lines[-num_last_lines:]
        # find the most recent line containing the word "epoch"
        epoch_info = "Epoch: n/a"
        for line in reversed(lines):
            if "epoch" in line.lower():
                epoch_info = line
                break
        last_line_arr.insert(0, epoch_info)  # insert at the beginning
        dfCreator = csvfile.DFCreator()
        dfCreator.create_table("last_lines", ["line"])
        last_line_arr = [[line] for line in last_line_arr]
        dfCreator.insert_rows("last_lines", last_line_arr)
        dfCreator.fill_table_from_row_pool("last_lines")
        last_lines_df = dfCreator["last_lines"].copy()
    else:
        df = pd.read_csv(target_file)
        num_rows = len(df)
        if num_last_lines > num_rows:
            num_last_lines = num_rows
        last_lines_df = df.tail(num_last_lines)
    return last_lines_df


def df2img(df: pd.DataFrame, output_img_dir, decimal_places, out_img_scale):
    df = df.round(decimal_places)
    fig = go.Figure(
        data=[
            go.Table(
                header=dict(values=list(df.columns), align="center"),
                cells=dict(
                    values=df.values.transpose(),
                    fill_color=[["white", "lightgrey"] * df.shape[0]],
                    align="center",
                ),
            )
        ]
    )
    if not os.path.exists(output_img_dir):
        os.makedirs(output_img_dir)
    img_path = os.path.normpath(os.path.join(output_img_dir, "last_lines.png"))
    fig.write_image(img_path, scale=out_img_scale)
    return img_path


def compose_message_and_img_path(
    target_file, project, num_last_lines, decimal_places, out_img_scale, output_img_dir
):
    context_msg = f">> Project: {project} \n>> File: {target_file} \n>> Last {num_last_lines} lines:"
    msg_df = get_watcher_message_df(target_file, num_last_lines)
    try:
        img_path = df2img(msg_df, output_img_dir, decimal_places, out_img_scale)
    except Exception as e:
        pprint(f"Error: {e}")
        img_path = None
    return context_msg, img_path


async def send_to_telegram(cfg_dict, interval_in_sec):
    token = cfg_dict["telegram"]["token"]
    chat_id = cfg_dict["telegram"]["chat_id"]

    noti_settings = cfg_dict["noti_settings"]
    project = noti_settings["project"]
    target_file = noti_settings["target_file"]
    num_last_lines = noti_settings["num_last_lines"]
    output_img_dir = noti_settings["output_img_dir"]
    decimal_places = noti_settings["decimal_places"]
    out_img_scale = noti_settings["out_img_scale"]

    bot = telegram.Bot(token=token)
    async with bot:
        try:
            context_msg, img_path = compose_message_and_img_path(
                target_file,
                project,
                num_last_lines,
                decimal_places,
                out_img_scale,
                output_img_dir,
            )
            time_now = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
            sep_line = "-" * 50
            context_msg = f"{sep_line}\n>> Time: {time_now}\n{context_msg}"
            # calculate the next time to send a message
            next_time = pd.Timestamp.now() + pd.Timedelta(seconds=interval_in_sec)
            next_time = next_time.strftime("%Y-%m-%d %H:%M:%S")
            next_time_info = f"Next msg: {next_time}"
            tele_console.rule()
            tele_console.print("[green] Send message to Telegram [/green]")
            tele_console.print(
                f"[red] Next message will be sent at <{next_time}> [/red]"
            )
            await bot.send_message(text=context_msg, chat_id=chat_id)
            if img_path:
                await bot.send_photo(chat_id=chat_id, photo=open(img_path, "rb"))
            await bot.send_message(text=next_time_info, chat_id=chat_id)
        except Exception as e:
            pprint(f"Error: {e}")
            pprint("Message not sent to Telegram")


async def run_forever(cfg_path):
    with open(cfg_path, "r") as f:
        cfg_dict = yaml.safe_load(f)
    noti_settings = cfg_dict["noti_settings"]
    interval_in_min = noti_settings["interval_in_min"]
    interval_in_sec = int(interval_in_min * 60)
    pprint(
        f"Message will be sent every {interval_in_min} minutes or {interval_in_sec} seconds"
    )
    while True:
        await send_to_telegram(cfg_dict, interval_in_sec)
        await asyncio.sleep(interval_in_sec)


async def main():
    args = parse_args()
    await run_forever(args.cfg)


if __name__ == "__main__":
    asyncio.run(main())
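The watcher is driven entirely by its YAML config. A minimal sketch of tele_noti_cfg.yaml follows; the key names are those read in send_to_telegram() and run_forever() above, and all values are placeholders:

telegram:
  token: "<bot-token>"
  chat_id: "<chat-id>"
noti_settings:
  project: my_project
  target_file: ./logs/train_log.csv
  num_last_lines: 10
  output_img_dir: ./out/tele
  decimal_places: 4
  out_img_scale: 2
  interval_in_min: 30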
halib/plot.py
ADDED
@@ -0,0 +1,301 @@
from .common import now_str, norm_str, ConsoleLog
from .filetype import csvfile
from .system import filesys as fs
from functools import partial
from rich.console import Console
from rich.pretty import pprint
import click
import csv
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns


console = Console()
desktop_path = os.path.expanduser("~/Desktop")
REQUIRED_COLUMNS = ["epoch", "train_loss", "val_loss", "train_acc", "val_acc"]


def get_delimiter(file_path, bytes=4096):
    # sniff the delimiter from the first few KB of the file
    sniffer = csv.Sniffer()
    with open(file_path, "r") as f:
        data = f.read(bytes)
    delimiter = sniffer.sniff(data).delimiter
    return delimiter


# Verify that the DataFrame has the required columns, and keep only those columns
def verify_csv(csv_file, required_columns=REQUIRED_COLUMNS):
    delimiter = get_delimiter(csv_file)
    df = pd.read_csv(csv_file, sep=delimiter)
    # change the column names to lower case
    df.columns = [col.lower() for col in df.columns]
    for col in required_columns:
        if col not in df.columns:
            raise ValueError(
                f"Required columns are: {REQUIRED_COLUMNS}, but found {df.columns}"
            )
    df = df[required_columns].copy()
    return df


def get_valid_tags(csv_files, tags):
    if tags is not None and len(tags) > 0:
        assert all(
            isinstance(tag, str) for tag in tags
        ), "tags must be a list of strings"
        assert all(
            len(tag) > 0 for tag in tags
        ), "tags must be a list of non-empty strings"
        valid_tags = tags
    else:
        valid_tags = []
        for csv_file in csv_files:
            file_name = fs.get_file_name(csv_file, split_file_ext=True)[0]
            tag = norm_str(file_name)
            valid_tags.append(tag)
    return valid_tags


def plot_ax(df, ax, metric="loss", tag=""):
    pprint(locals())
    assert metric in ["loss", "acc"], "metric must be either 'loss' or 'acc'"
    part = ["train", "val"]
    for p in part:
        label = f"{tag}_{p}_{metric}"
        ax.plot(df["epoch"], df[f"{p}_{metric}"], label=label)
    return ax


def actual_plot_seaborn(frame, csv_files, axes, tags, log):
    # clear the axes
    for ax in axes:
        ax.clear()
    ls_df = []
    valid_tags = get_valid_tags(csv_files, tags)
    for csv_file in csv_files:
        df = verify_csv(csv_file)
        if log:
            with ConsoleLog(f"plotting {csv_file}"):
                csvfile.fn_display_df(df)
        ls_df.append(df)

    ls_metrics = ["loss", "acc"]
    for df_item, tag in zip(ls_df, valid_tags):
        # prefix each column with the tag, except epoch
        df_item.columns = [
            f"{tag}_{col}" if col != "epoch" else col for col in df_item.columns
        ]
    # merge the dataframes on the epoch column
    df_combined = ls_df[0]
    for df_item in ls_df[1:]:
        df_combined = pd.merge(df_combined, df_item, on="epoch", how="outer")

    for i, metric in enumerate(ls_metrics):
        tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
        title = f"{tags_str}_{metric}-by-epoch"
        cols = [col for col in df_combined.columns if col != "epoch" and metric in col]
        cols = sorted(cols)
        plot_data = df_combined[cols]

        # lines from the same csv file (same tag) should share a marker
        all_markers = [
            marker for marker in plt.Line2D.markers if marker and marker != " "
        ]
        tag2marker = {tag: marker for tag, marker in zip(valid_tags, all_markers)}
        plot_markers = []
        for col in cols:
            # find the tag this column belongs to
            tag = None
            for valid_tag in valid_tags:
                if valid_tag in col:
                    tag = valid_tag
                    break
            plot_markers.append(tag2marker[tag])

        # each csv file (tag) gets a unique sequential color palette
        sequential_palettes = [
            "Reds",
            "Greens",
            "Blues",
            "Oranges",
            "Purples",
            "Greys",
            "BuGn",
            "BuPu",
            "GnBu",
            "OrRd",
            "PuBu",
            "PuRd",
            "RdPu",
            "YlGn",
            "PuBuGn",
            "YlGnBu",
            "YlOrBr",
            "YlOrRd",
        ]
        tag2palette = {
            tag: palette for tag, palette in zip(valid_tags, sequential_palettes)
        }
        plot_colors = []
        for tag in valid_tags:
            palette = tag2palette[tag]
            total_colors = 10
            ls_colors = sns.color_palette(palette, total_colors).as_hex()
            num_part = len(ls_metrics)
            subarr = np.array_split(np.arange(total_colors), num_part)
            for idx, col in enumerate(cols):
                if tag in col:
                    chosen_color = ls_colors[
                        subarr[int(idx % num_part)].mean().astype(int)
                    ]
                    plot_colors.append(chosen_color)

        sns.lineplot(
            data=plot_data,
            markers=plot_markers,
            palette=plot_colors,
            ax=axes[i],
            dashes=False,
        )
        axes[i].set(xlabel="epoch", ylabel=metric, title=title)
        axes[i].legend()
        axes[i].grid()


def actual_plot(frame, csv_files, axes, tags, log):
    ls_df = []
    valid_tags = get_valid_tags(csv_files, tags)
    for csv_file in csv_files:
        df = verify_csv(csv_file)
        if log:
            with ConsoleLog(f"plotting {csv_file}"):
                csvfile.fn_display_df(df)
        ls_df.append(df)

    metric_values = ["loss", "acc"]
    for i, metric in enumerate(metric_values):
        for df_item, tag in zip(ls_df, valid_tags):
            metric_ax = plot_ax(df_item, axes[i], metric, tag)

        # set the title, xlabel, ylabel, legend, and grid
        tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
        metric_ax.set(
            xlabel="epoch", ylabel=metric, title=f"{tags_str}_{metric}-by-epoch"
        )
        metric_ax.legend()
        metric_ax.grid()


def plot_csv_files(
    csv_files,
    outdir="./out/plot",
    tags=None,
    log=False,
    save_fig=False,
    update_in_min=1,
):
    # if csv_files is a string, convert it to a list
    if isinstance(csv_files, str):
        csv_files = [csv_files]
    # if tags is a string, convert it to a list
    if isinstance(tags, str):
        tags = [tags]
    valid_tags = get_valid_tags(csv_files, tags)
    assert len(valid_tags) == len(
        csv_files
    ), "Unable to determine tags for each csv file"
    live_update_in_ms = int(update_in_min * 60 * 1000)
    fig, axes = plt.subplots(2, 1, figsize=(10, 17))
    if live_update_in_ms:  # update_in_min should be > 0 for live updating
        from matplotlib.animation import FuncAnimation

        anim = FuncAnimation(
            fig,
            partial(
                actual_plot_seaborn, csv_files=csv_files, axes=axes, tags=tags, log=log
            ),
            interval=live_update_in_ms,
            blit=False,
            cache_frame_data=False,
        )
        plt.show()
    else:
        actual_plot_seaborn(None, csv_files, axes, tags, log)
        plt.show()

    if save_fig:
        os.makedirs(outdir, exist_ok=True)
        tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
        tag = f"{now_str()}_{tags_str}"
        fig.savefig(f"{outdir}/{tag}_plot.png")
        enable_plot_pgf()
        fig.savefig(f"{outdir}/{tag}_plot.pdf")
    if live_update_in_ms:
        return anim


def enable_plot_pgf():
    matplotlib.use("pdf")
    matplotlib.rcParams.update(
        {
            "pgf.texsystem": "pdflatex",
            "font.family": "serif",
            "text.usetex": True,
            "pgf.rcfonts": False,
        }
    )


def save_fig_latex_pgf(filename, directory="."):
    enable_plot_pgf()
    if ".pgf" not in filename:
        filename = f"{directory}/{filename}.pgf"
    plt.savefig(filename)


# https://click.palletsprojects.com/en/8.1.x/api/
@click.command()
@click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
@click.option(
    "--outdir",
    "-o",
    type=str,
    help="output directory for the plot",
    default=str(desktop_path),
)
@click.option(
    "--tags", "-t", multiple=True, type=str, help="tags for the csv files", default=[]
)
@click.option("--log", "-l", is_flag=True, help="log the csv files")
@click.option("--save_fig", "-s", is_flag=True, help="save the plot as a file")
@click.option(
    "--update_in_min",
    "-u",
    type=float,
    help="update the plot every x minutes",
    default=0.0,
)
def main(
    csvfiles,
    outdir,
    tags,
    log,
    save_fig,
    update_in_min,
):
    plot_csv_files(list(csvfiles), outdir, list(tags), log, save_fig, update_in_min)


if __name__ == "__main__":
    main()
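As a usage sketch, plot_csv_files can also be called from Python rather than through the click CLI, assuming the package's top-level layout (halib/plot.py). The CSV path and tag below are placeholders, and the input file must contain the REQUIRED_COLUMNS listed at the top of the module:

# Hypothetical call; train_log.csv and "baseline" are placeholders.
from halib.plot import plot_csv_files

plot_csv_files(
    ["train_log.csv"],
    outdir="./out/plot",
    tags=["baseline"],
    log=True,
    save_fig=True,
    update_in_min=0.0,  # 0 disables live updating; > 0 redraws via FuncAnimation
)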