halib 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. halib/__init__.py +94 -0
  2. halib/common/__init__.py +0 -0
  3. halib/common/common.py +326 -0
  4. halib/common/rich_color.py +285 -0
  5. halib/common.py +151 -0
  6. halib/csvfile.py +48 -0
  7. halib/cuda.py +39 -0
  8. halib/dataset.py +209 -0
  9. halib/exp/__init__.py +0 -0
  10. halib/exp/core/__init__.py +0 -0
  11. halib/exp/core/base_config.py +167 -0
  12. halib/exp/core/base_exp.py +147 -0
  13. halib/exp/core/param_gen.py +170 -0
  14. halib/exp/core/wandb_op.py +117 -0
  15. halib/exp/data/__init__.py +0 -0
  16. halib/exp/data/dataclass_util.py +41 -0
  17. halib/exp/data/dataset.py +208 -0
  18. halib/exp/data/torchloader.py +165 -0
  19. halib/exp/perf/__init__.py +0 -0
  20. halib/exp/perf/flop_calc.py +190 -0
  21. halib/exp/perf/gpu_mon.py +58 -0
  22. halib/exp/perf/perfcalc.py +470 -0
  23. halib/exp/perf/perfmetrics.py +137 -0
  24. halib/exp/perf/perftb.py +778 -0
  25. halib/exp/perf/profiler.py +507 -0
  26. halib/exp/viz/__init__.py +0 -0
  27. halib/exp/viz/plot.py +754 -0
  28. halib/filesys.py +117 -0
  29. halib/filetype/__init__.py +0 -0
  30. halib/filetype/csvfile.py +192 -0
  31. halib/filetype/ipynb.py +61 -0
  32. halib/filetype/jsonfile.py +19 -0
  33. halib/filetype/textfile.py +12 -0
  34. halib/filetype/videofile.py +266 -0
  35. halib/filetype/yamlfile.py +87 -0
  36. halib/gdrive.py +179 -0
  37. halib/gdrive_mkdir.py +41 -0
  38. halib/gdrive_test.py +37 -0
  39. halib/jsonfile.py +22 -0
  40. halib/listop.py +13 -0
  41. halib/online/__init__.py +0 -0
  42. halib/online/gdrive.py +229 -0
  43. halib/online/gdrive_mkdir.py +53 -0
  44. halib/online/gdrive_test.py +50 -0
  45. halib/online/projectmake.py +131 -0
  46. halib/online/tele_noti.py +165 -0
  47. halib/plot.py +301 -0
  48. halib/projectmake.py +115 -0
  49. halib/research/__init__.py +0 -0
  50. halib/research/base_config.py +100 -0
  51. halib/research/base_exp.py +157 -0
  52. halib/research/benchquery.py +131 -0
  53. halib/research/core/__init__.py +0 -0
  54. halib/research/core/base_config.py +144 -0
  55. halib/research/core/base_exp.py +157 -0
  56. halib/research/core/param_gen.py +108 -0
  57. halib/research/core/wandb_op.py +117 -0
  58. halib/research/data/__init__.py +0 -0
  59. halib/research/data/dataclass_util.py +41 -0
  60. halib/research/data/dataset.py +208 -0
  61. halib/research/data/torchloader.py +165 -0
  62. halib/research/dataset.py +208 -0
  63. halib/research/flop_csv.py +34 -0
  64. halib/research/flops.py +156 -0
  65. halib/research/metrics.py +137 -0
  66. halib/research/mics.py +74 -0
  67. halib/research/params_gen.py +108 -0
  68. halib/research/perf/__init__.py +0 -0
  69. halib/research/perf/flop_calc.py +190 -0
  70. halib/research/perf/gpu_mon.py +58 -0
  71. halib/research/perf/perfcalc.py +363 -0
  72. halib/research/perf/perfmetrics.py +137 -0
  73. halib/research/perf/perftb.py +778 -0
  74. halib/research/perf/profiler.py +301 -0
  75. halib/research/perfcalc.py +361 -0
  76. halib/research/perftb.py +780 -0
  77. halib/research/plot.py +758 -0
  78. halib/research/profiler.py +300 -0
  79. halib/research/torchloader.py +162 -0
  80. halib/research/viz/__init__.py +0 -0
  81. halib/research/viz/plot.py +754 -0
  82. halib/research/wandb_op.py +116 -0
  83. halib/rich_color.py +285 -0
  84. halib/sys/__init__.py +0 -0
  85. halib/sys/cmd.py +8 -0
  86. halib/sys/filesys.py +124 -0
  87. halib/system/__init__.py +0 -0
  88. halib/system/_list_pc.csv +6 -0
  89. halib/system/cmd.py +8 -0
  90. halib/system/filesys.py +164 -0
  91. halib/system/path.py +106 -0
  92. halib/tele_noti.py +166 -0
  93. halib/textfile.py +13 -0
  94. halib/torchloader.py +162 -0
  95. halib/utils/__init__.py +0 -0
  96. halib/utils/dataclass_util.py +40 -0
  97. halib/utils/dict.py +317 -0
  98. halib/utils/dict_op.py +9 -0
  99. halib/utils/gpu_mon.py +58 -0
  100. halib/utils/list.py +17 -0
  101. halib/utils/listop.py +13 -0
  102. halib/utils/slack.py +86 -0
  103. halib/utils/tele_noti.py +166 -0
  104. halib/utils/video.py +82 -0
  105. halib/videofile.py +139 -0
  106. halib-0.2.30.dist-info/METADATA +237 -0
  107. halib-0.2.30.dist-info/RECORD +110 -0
  108. halib-0.2.30.dist-info/WHEEL +5 -0
  109. halib-0.2.30.dist-info/licenses/LICENSE.txt +17 -0
  110. halib-0.2.30.dist-info/top_level.txt +1 -0
halib/online/gdrive_test.py ADDED
@@ -0,0 +1,50 @@
+from argparse import ArgumentParser
+import gdrive
+
+
+def parse_args():
+    parser = ArgumentParser(description="Upload local folder to Google Drive")
+    parser.add_argument(
+        "-a",
+        "--authFile",
+        type=str,
+        help="authentication file for Google Drive",
+        default="settings.yaml",
+    )
+    parser.add_argument("-s", "--source", type=str, help="Folder to upload")
+    parser.add_argument(
+        "-d", "--destination", type=str, help="Destination folder ID in Google Drive"
+    )
+    parser.add_argument(
+        "-c",
+        "--contentOnly",
+        type=str,
+        help="upload only the folder's contents, not the folder itself (True/False)",
+        default="True",
+    )
+    parser.add_argument(
+        "-i",
+        "--ignoreFile",
+        type=str,
+        help="file containing files/folders to ignore",
+        default=None,
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    auth_file = args.authFile
+    local_folder = args.source
+    gg_folder_id = args.destination
+    content_only = args.contentOnly.lower() == "true"
+    ignore_file = args.ignoreFile
+    gdrive.get_gg_drive(auth_file)
+    gdrive.upload_folder_to_drive(
+        local_folder, gg_folder_id, content_only=content_only, ignore_file=ignore_file
+    )
+
+
+if __name__ == "__main__":
+    main()
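
For reference, the same upload can be driven from Python rather than the CLI. A minimal sketch, assuming the module is importable as halib.online.gdrive and that a valid settings.yaml auth file exists; the folder path and Drive folder ID below are placeholders:

# Hypothetical programmatic use of the two gdrive calls made by main() above.
from halib.online import gdrive

gdrive.get_gg_drive("settings.yaml")  # authenticate with the auth file
gdrive.upload_folder_to_drive(
    "./experiments",       # placeholder local folder to upload
    "1AbCdEfGhIjKlMnOp",   # placeholder destination Drive folder ID
    content_only=True,     # upload the folder's contents, not the folder itself
    ignore_file=None,      # optional file listing paths to skip
)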
halib/online/projectmake.py ADDED
@@ -0,0 +1,131 @@
+# coding=utf-8
+
+import os
+import json
+import pycurl
+import shutil
+import certifi
+import subprocess
+from io import BytesIO
+
+from argparse import ArgumentParser
+
+from ..filetype import jsonfile
+from ..system import filesys
+
+def get_curl(url, user_and_pass, verbose=True):
+    c = pycurl.Curl()
+    c.setopt(pycurl.VERBOSE, verbose)
+    c.setopt(pycurl.CAINFO, certifi.where())
+    c.setopt(pycurl.URL, url)
+    c.setopt(pycurl.USERPWD, user_and_pass)
+    return c
+
+
+def get_user_and_pass(username, appPass):
+    return f"{username}:{appPass}"
+
+
+def create_repo(
+    username, appPass, repo_name, workspace, proj_name, template_repo="py-proj-template"
+):
+    buffer = BytesIO()
+    url = f"https://api.bitbucket.org/2.0/repositories/{workspace}/{repo_name}"
+    data = json.dumps({"scm": "git", "project": {"key": f"{proj_name}"}})
+
+    user_and_pass = get_user_and_pass(username, appPass)
+    c = get_curl(url, user_and_pass)
+    c.setopt(pycurl.WRITEDATA, buffer)
+    c.setopt(pycurl.POST, 1)
+    c.setopt(pycurl.POSTFIELDS, data)
+    c.setopt(pycurl.HTTPHEADER, ["Accept: application/json"])
+    c.perform()
+    RESPOND_CODE = c.getinfo(pycurl.HTTP_CODE)
+    c.close()
+    # log info
+    body = buffer.getvalue()
+    msg = body.decode("iso-8859-1")
+    successful = str(RESPOND_CODE) == "200"
+
+    if successful and template_repo:
+        template_repo_url = f"https://{username}:{appPass}@bitbucket.org/{workspace}/{template_repo}.git"
+        git_clone(template_repo_url)
+        template_folder = f"./{template_repo}"
+
+        created_repo_url = (
+            f"https://{username}:{appPass}@bitbucket.org/{workspace}/{repo_name}.git"
+        )
+        git_clone(created_repo_url)
+        created_folder = f"./{repo_name}"
+        shutil.copytree(
+            template_folder,
+            created_folder,
+            dirs_exist_ok=True,
+            ignore=shutil.ignore_patterns(".git"),
+        )
+        os.system('rmdir /S /Q "{}"'.format(template_folder))  # Windows-only cleanup
+        project_folder = "project_name"
+
+        filesys.change_current_dir(created_folder)
+        filesys.rename_dir_or_file(project_folder, repo_name)
+        # push to remote
+        subprocess.check_call(["C:/batch/gitp.bat", "init proj from template"])
+
+    return successful, msg
+
+
+def parse_args():
+    parser = ArgumentParser(description="Create a Bitbucket repository, optionally from a template")
+    parser.add_argument(
+        "-a",
+        "--authFile",
+        type=str,
+        help="authentication file (json) for Bitbucket",
+        default="bitbucket.json",
+    )
+    parser.add_argument(
+        "-r", "--repoName", type=str, help="Repository name", default="hahv-proj"
+    )
+    parser.add_argument(
+        "-t", "--templateRepo", type=str, help="template repo to fork", default="True"
+    )
+    return parser.parse_args()
+
+
+def git_clone(url):
+    subprocess.check_call(["git", "clone", url])
+
+
+def main():
+    args = parse_args()
+    authFile = args.authFile
+    repo_name = args.repoName
+
+    authInfo = jsonfile.read(authFile)
+    username = authInfo["username"]
+    appPass = authInfo["appPass"]
+    workspace_id = authInfo["workspace_id"]
+    project_id = authInfo["project_id"]
+    use_template = args.templateRepo.lower() == "true"
+    template_repo = authInfo["template_repo"] if use_template else ""
+
+    extra_info = f"[Use template project {template_repo}]" if use_template else ""
+    print(f"[BitBucket] creating {repo_name} Project in Bitbucket {extra_info}")
+
+    successful, msg = create_repo(
+        username,
+        appPass,
+        repo_name,
+        workspace_id,
+        project_id,
+        template_repo=template_repo,
+    )
+    if successful:
+        print(f"[Bitbucket] {repo_name} created successfully.{extra_info}")
+    else:
+        formatted_msg = jsonfile.beautify(msg)
+        print(f"[Bitbucket] {repo_name} creation failed. Details:\n{formatted_msg}")
+
+
+if __name__ == "__main__":
+    main()
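
The script reads its credentials from the --authFile json; the keys below mirror exactly the fields main() accesses. A sketch that writes a placeholder bitbucket.json (all values are dummies, not real credentials):

import json

# Keys taken from the authInfo[...] lookups in main(); values are placeholders.
auth = {
    "username": "your-bitbucket-username",
    "appPass": "your-app-password",
    "workspace_id": "your-workspace",
    "project_id": "PROJKEY",
    "template_repo": "py-proj-template",
}
with open("bitbucket.json", "w") as f:
    json.dump(auth, f, indent=2)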
halib/online/tele_noti.py ADDED
@@ -0,0 +1,165 @@
+# Watch a log file and send a Telegram message when training reaches a certain epoch or ends
+
+import os
+import yaml
+import asyncio
+import telegram
+import pandas as pd
+
+from rich.pretty import pprint
+from rich.console import Console
+import plotly.graph_objects as go
+
+from ..system import filesys as fs
+from ..filetype import textfile, csvfile
+
+from argparse import ArgumentParser
+
+tele_console = Console()
+
+
+def parse_args():
+    parser = ArgumentParser(description="Watch a log file and send Telegram notifications")
+    parser.add_argument(
+        "-cfg",
+        "--cfg",
+        type=str,
+        help="yaml config file for the Telegram watcher",
+        default=r"E:\Dev\__halib\halib\online\tele_noti_cfg.yaml",
+    )
+
+    return parser.parse_args()
+
+def get_watcher_message_df(target_file, num_last_lines):
+    file_ext = fs.get_file_name(target_file, split_file_ext=True)[1]
+    supported_ext = [".txt", ".log", ".csv"]
+    assert (
+        file_ext in supported_ext
+    ), f"File extension {file_ext} not supported. Supported extensions are {supported_ext}"
+    last_lines_df = None
+    if file_ext in [".txt", ".log"]:
+        lines = textfile.read_line_by_line(target_file)
+        if num_last_lines > len(lines):
+            num_last_lines = len(lines)
+        last_line_arr = lines[-num_last_lines:]
+        # find the most recent line containing the word "epoch"
+        epoch_info = "Epoch: n/a"
+        for line in reversed(lines):
+            if "epoch" in line.lower():
+                epoch_info = line
+                break
+        last_line_arr.insert(0, epoch_info)  # insert at the beginning
+        dfCreator = csvfile.DFCreator()
+        dfCreator.create_table("last_lines", ["line"])
+        last_line_arr = [[line] for line in last_line_arr]
+        dfCreator.insert_rows("last_lines", last_line_arr)
+        dfCreator.fill_table_from_row_pool("last_lines")
+        last_lines_df = dfCreator["last_lines"].copy()
+    else:
+        df = pd.read_csv(target_file)
+        num_rows = len(df)
+        if num_last_lines > num_rows:
+            num_last_lines = num_rows
+        last_lines_df = df.tail(num_last_lines)
+    return last_lines_df
+
+
+def df2img(df: pd.DataFrame, output_img_dir, decimal_places, out_img_scale):
+    df = df.round(decimal_places)
+    fig = go.Figure(
+        data=[
+            go.Table(
+                header=dict(values=list(df.columns), align="center"),
+                cells=dict(
+                    values=df.values.transpose(),
+                    fill_color=[["white", "lightgrey"] * df.shape[0]],
+                    align="center",
+                ),
+            )
+        ]
+    )
+    if not os.path.exists(output_img_dir):
+        os.makedirs(output_img_dir)
+    img_path = os.path.normpath(os.path.join(output_img_dir, "last_lines.png"))
+    fig.write_image(img_path, scale=out_img_scale)
+    return img_path
+
+
+def compose_message_and_img_path(
+    target_file, project, num_last_lines, decimal_places, out_img_scale, output_img_dir
+):
+    context_msg = f">> Project: {project} \n>> File: {target_file} \n>> Last {num_last_lines} lines:"
+    msg_df = get_watcher_message_df(target_file, num_last_lines)
+    try:
+        img_path = df2img(msg_df, output_img_dir, decimal_places, out_img_scale)
+    except Exception as e:
+        pprint(f"Error: {e}")
+        img_path = None
+    return context_msg, img_path
+
+
+async def send_to_telegram(cfg_dict, interval_in_sec):
+    # pprint(cfg_dict)
+    token = cfg_dict["telegram"]["token"]
+    chat_id = cfg_dict["telegram"]["chat_id"]
+
+    noti_settings = cfg_dict["noti_settings"]
+    project = noti_settings["project"]
+    target_file = noti_settings["target_file"]
+    num_last_lines = noti_settings["num_last_lines"]
+    output_img_dir = noti_settings["output_img_dir"]
+    decimal_places = noti_settings["decimal_places"]
+    out_img_scale = noti_settings["out_img_scale"]
+
+    bot = telegram.Bot(token=token)
+    async with bot:
+        try:
+            context_msg, img_path = compose_message_and_img_path(
+                target_file,
+                project,
+                num_last_lines,
+                decimal_places,
+                out_img_scale,
+                output_img_dir,
+            )
+            time_now = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
+            sep_line = "-" * 50
+            context_msg = f"{sep_line}\n>> Time: {time_now}\n{context_msg}"
+            # calculate the next time a message will be sent
+            next_time = pd.Timestamp.now() + pd.Timedelta(seconds=interval_in_sec)
+            next_time = next_time.strftime("%Y-%m-%d %H:%M:%S")
+            next_time_info = f"Next msg: {next_time}"
+            tele_console.rule()
+            tele_console.print("[green] Send message to telegram [/green]")
+            tele_console.print(
+                f"[red] Next message will be sent at <{next_time}> [/red]"
+            )
+            await bot.send_message(text=context_msg, chat_id=chat_id)
+            if img_path:
+                await bot.send_photo(chat_id=chat_id, photo=open(img_path, "rb"))
+            await bot.send_message(text=next_time_info, chat_id=chat_id)
+        except Exception as e:
+            pprint(f"Error: {e}")
+            pprint("Message not sent to telegram")
+
+
+async def run_forever(cfg_path):
+    cfg_dict = yaml.safe_load(open(cfg_path, "r"))
+    noti_settings = cfg_dict["noti_settings"]
+    interval_in_min = noti_settings["interval_in_min"]
+    interval_in_sec = int(interval_in_min * 60)
+    pprint(
+        f"Message will be sent every {interval_in_min} minutes or {interval_in_sec} seconds"
+    )
+    while True:
+        await send_to_telegram(cfg_dict, interval_in_sec)
+        await asyncio.sleep(interval_in_sec)
+
+
+async def main():
+    args = parse_args()
+    await run_forever(args.cfg)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
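
The watcher's YAML config must provide the keys that send_to_telegram() and run_forever() look up. A sketch that generates a placeholder tele_noti_cfg.yaml (token, chat id, and paths are dummies):

import yaml

# Keys mirror the cfg_dict[...] lookups above; all values are placeholders.
cfg = {
    "telegram": {
        "token": "123456:ABC-bot-token",
        "chat_id": "123456789",
    },
    "noti_settings": {
        "project": "my-train-run",
        "target_file": "./logs/train.csv",  # .txt, .log, and .csv are supported
        "num_last_lines": 5,
        "output_img_dir": "./out/tele",
        "decimal_places": 4,
        "out_img_scale": 2,
        "interval_in_min": 30,
    },
}
with open("tele_noti_cfg.yaml", "w") as f:
    yaml.safe_dump(cfg, f)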
halib/plot.py ADDED
@@ -0,0 +1,301 @@
+from .common import now_str, norm_str, ConsoleLog
+from .filetype import csvfile
+from .system import filesys as fs
+from functools import partial
+from rich.console import Console
+from rich.pretty import pprint
+import click
+import csv
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import pandas as pd
+import seaborn as sns
+
+
+console = Console()
+desktop_path = os.path.expanduser("~/Desktop")
+REQUIRED_COLUMNS = ["epoch", "train_loss", "val_loss", "train_acc", "val_acc"]
+
+
+def get_delimiter(file_path, num_bytes=4096):
+    sniffer = csv.Sniffer()
+    with open(file_path, "r") as f:
+        data = f.read(num_bytes)
+    return sniffer.sniff(data).delimiter
+
+
+# Verify that the csv has the required columns, and keep only those columns
+def verify_csv(csv_file, required_columns=REQUIRED_COLUMNS):
+    delimiter = get_delimiter(csv_file)
+    df = pd.read_csv(csv_file, sep=delimiter)
+    # change the column names to lower case
+    df.columns = [col.lower() for col in df.columns]
+    for col in required_columns:
+        if col not in df.columns:
+            raise ValueError(
+                f"Required columns are: {REQUIRED_COLUMNS}, but found {df.columns}"
+            )
+    df = df[required_columns].copy()
+    return df
+
+
+def get_valid_tags(csv_files, tags):
+    if tags is not None and len(tags) > 0:
+        assert all(
+            isinstance(tag, str) for tag in tags
+        ), "tags must be a list of strings"
+        assert all(
+            len(tag) > 0 for tag in tags
+        ), "tags must be a list of non-empty strings"
+        valid_tags = tags
+    else:
+        valid_tags = []
+        for csv_file in csv_files:
+            file_name = fs.get_file_name(csv_file, split_file_ext=True)[0]
+            tag = norm_str(file_name)
+            valid_tags.append(tag)
+    return valid_tags
+
+
+def plot_ax(df, ax, metric="loss", tag=""):
+    pprint(locals())
+    # only the loss and acc metrics are supported
+    assert metric in ["loss", "acc"], "metric must be either 'loss' or 'acc'"
+    part = ["train", "val"]
+    for p in part:
+        label = f"{tag}_{p}_{metric}"
+        ax.plot(df["epoch"], df[f"{p}_{metric}"], label=label)
+    return ax
+
+
+def actual_plot_seaborn(frame, csv_files, axes, tags, log):  # frame: supplied by FuncAnimation, unused
+    # clear the axes
+    for ax in axes:
+        ax.clear()
+    ls_df = []
+    valid_tags = get_valid_tags(csv_files, tags)
+    for csv_file in csv_files:
+        df = verify_csv(csv_file)
+        if log:
+            with ConsoleLog(f"plotting {csv_file}"):
+                csvfile.fn_display_df(df)
+        ls_df.append(df)
+
+    ls_metrics = ["loss", "acc"]
+    for df_item, tag in zip(ls_df, valid_tags):
+        # add the tag to every column except epoch
+        df_item.columns = [
+            f"{tag}_{col}" if col != "epoch" else col for col in df_item.columns
+        ]
+    # merge the dataframes on the epoch column
+    df_combined = ls_df[0]
+    for df_item in ls_df[1:]:
+        df_combined = pd.merge(df_combined, df_item, on="epoch", how="outer")
+    # csvfile.fn_display_df(df_combined)
+
+    for i, metric in enumerate(ls_metrics):
+        tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
+        title = f"{tags_str}_{metric}-by-epoch"
+        cols = [col for col in df_combined.columns if col != "epoch" and metric in col]
+        cols = sorted(cols)
+        # pprint(cols)
+        plot_data = df_combined[cols]
+
+        # lines from the same csv file (same tag) share the same marker
+        all_markers = [
+            marker for marker in plt.Line2D.markers if marker and marker != " "
+        ]
+        tag2marker = {tag: marker for tag, marker in zip(valid_tags, all_markers)}
+        plot_markers = []
+        for col in cols:
+            # find the tag:
+            tag = None
+            for valid_tag in valid_tags:
+                if valid_tag in col:
+                    tag = valid_tag
+                    break
+            plot_markers.append(tag2marker[tag])
+        # pprint(list(zip(cols, plot_markers)))
+
+        # create colors
+        sequential_palettes = [
+            "Reds",
+            "Greens",
+            "Blues",
+            "Oranges",
+            "Purples",
+            "Greys",
+            "BuGn",
+            "BuPu",
+            "GnBu",
+            "OrRd",
+            "PuBu",
+            "PuRd",
+            "RdPu",
+            "YlGn",
+            "PuBuGn",
+            "YlGnBu",
+            "YlOrBr",
+            "YlOrRd",
+        ]
+        # each csv file (tag) gets its own sequential palette
+        tag2palette = {
+            tag: palette for tag, palette in zip(valid_tags, sequential_palettes)
+        }
+        plot_colors = []
+        for tag in valid_tags:
+            palette = tag2palette[tag]
+            total_colors = 10
+            ls_colors = sns.color_palette(palette, total_colors).as_hex()
+            num_part = len(ls_metrics)
+            subarr = np.array_split(np.arange(total_colors), num_part)
+            for idx, col in enumerate(cols):
+                if tag in col:
+                    chosen_color = ls_colors[
+                        subarr[int(idx % num_part)].mean().astype(int)
+                    ]
+                    plot_colors.append(chosen_color)
+
+        # pprint(list(zip(cols, plot_colors)))
+        sns.lineplot(
+            data=plot_data,
+            markers=plot_markers,
+            palette=plot_colors,
+            ax=axes[i],
+            dashes=False,
+        )
+        axes[i].set(xlabel="epoch", ylabel=metric, title=title)
+        axes[i].legend()
+        axes[i].grid()
+
+
+def actual_plot(frame, csv_files, axes, tags, log):
+    ls_df = []
+    valid_tags = get_valid_tags(csv_files, tags)
+    for csv_file in csv_files:
+        df = verify_csv(csv_file)
+        if log:
+            with ConsoleLog(f"plotting {csv_file}"):
+                csvfile.fn_display_df(df)
+        ls_df.append(df)
+
+    metric_values = ["loss", "acc"]
+    for i, metric in enumerate(metric_values):
+        for df_item, tag in zip(ls_df, valid_tags):
+            metric_ax = plot_ax(df_item, axes[i], metric, tag)
+
+        # set the title, xlabel, ylabel, legend, and grid
+        tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
+        metric_ax.set(
+            xlabel="epoch", ylabel=metric, title=f"{tags_str}_{metric}-by-epoch"
+        )
+        metric_ax.legend()
+        metric_ax.grid()
+
+
+def plot_csv_files(
+    csv_files,
+    outdir="./out/plot",
+    tags=None,
+    log=False,
+    save_fig=False,
+    update_in_min=1,
+):
+    # if csv_files is a string, convert it to a list
+    if isinstance(csv_files, str):
+        csv_files = [csv_files]
+    # if tags is a string, convert it to a list
+    if isinstance(tags, str):
+        tags = [tags]
+    valid_tags = get_valid_tags(csv_files, tags)
+    assert len(valid_tags) == len(
+        csv_files
+    ), "Unable to determine tags for each csv file"
+    live_update_in_ms = int(update_in_min * 60 * 1000)
+    fig, axes = plt.subplots(2, 1, figsize=(10, 17))
+    if live_update_in_ms:  # live update in min should be > 0
+        from matplotlib.animation import FuncAnimation
+
+        anim = FuncAnimation(
+            fig,
+            partial(
+                actual_plot_seaborn, csv_files=csv_files, axes=axes, tags=tags, log=log
+            ),
+            interval=live_update_in_ms,
+            blit=False,
+            cache_frame_data=False,
+        )
+        plt.show()
+    else:
+        actual_plot_seaborn(None, csv_files, axes, tags, log)
+        plt.show()
+
+    if save_fig:
+        os.makedirs(outdir, exist_ok=True)
+        tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
+        tag = f"{now_str()}_{tags_str}"
+        fig.savefig(f"{outdir}/{tag}_plot.png")
+        enable_plot_pgf()
+        fig.savefig(f"{outdir}/{tag}_plot.pdf")
+    if live_update_in_ms:
+        return anim
+
+
+def enable_plot_pgf():
+    matplotlib.use("pdf")
+    matplotlib.rcParams.update(
+        {
+            "pgf.texsystem": "pdflatex",
+            "font.family": "serif",
+            "text.usetex": True,
+            "pgf.rcfonts": False,
+        }
+    )
+
+
+def save_fig_latex_pgf(filename, directory="."):
+    enable_plot_pgf()
+    if ".pgf" not in filename:
+        filename = f"{directory}/{filename}.pgf"
+    plt.savefig(filename)
+
+
+# https://click.palletsprojects.com/en/8.1.x/api/
+@click.command()
+@click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
+@click.option(
+    "--outdir",
+    "-o",
+    type=str,
+    help="output directory for the plot",
+    default=str(desktop_path),
+)
+@click.option(
+    "--tags", "-t", multiple=True, type=str, help="tags for the csv files", default=[]
+)
+@click.option("--log", "-l", is_flag=True, help="log the csv files")
+@click.option("--save_fig", "-s", is_flag=True, help="save the plot as a file")
+@click.option(
+    "--update_in_min",
+    "-u",
+    type=float,
+    help="update the plot every x minutes",
+    default=0.0,
+)
+def main(
+    csvfiles,
+    outdir,
+    tags,
+    log,
+    save_fig,
+    update_in_min,
+):
+    plot_csv_files(list(csvfiles), outdir, list(tags), log, save_fig, update_in_min)
+
+
+if __name__ == "__main__":
+    main()
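
Besides the click CLI, plot_csv_files can be called directly. A minimal sketch, assuming halib is installed and the csv files contain the REQUIRED_COLUMNS (epoch, train_loss, val_loss, train_acc, val_acc); the file paths are placeholders:

from halib.plot import plot_csv_files

plot_csv_files(
    ["./logs/run_a.csv", "./logs/run_b.csv"],  # placeholder csv paths
    outdir="./out/plot",
    tags=["run_a", "run_b"],  # one tag per csv; derived from filenames if omitted
    save_fig=True,            # also write png and pdf copies of the figure
    update_in_min=0,          # 0 disables live updating
)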