halib-0.1.47-py3-none-any.whl → halib-0.1.48-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
halib/research/benchquery.py ADDED
@@ -0,0 +1,131 @@
+ import pandas as pd
+ from rich.pretty import pprint
+ from argparse import ArgumentParser
+
+ def cols_to_col_groups(df):
+     columns = list(df.columns)
+     # pprint(columns)
+
+     col_groups = []
+     current_group = []
+
+     def have_unnamed(col_group):
+         return any("unnamed" in col.lower() for col in col_group)
+
+     for i, col in enumerate(columns):
+         # Add the first column to the current group
+         if not current_group:
+             current_group.append(col)
+             continue
+
+         prev_col = columns[i - 1]
+         # Check if current column is "unnamed" or shares base name with previous
+         # Assuming "equal" means same base name (before any suffix like '_1')
+         base_prev = (
+             prev_col.split("_")[0].lower() if "_" in prev_col else prev_col.lower()
+         )
+         base_col = col.split("_")[0].lower() if "_" in col else col.lower()
+         is_unnamed = "unnamed" in col.lower()
+         is_equal = base_col == base_prev
+
+         if is_unnamed or is_equal:
+             # Add to current group
+             current_group.append(col)
+         else:
+             # Start a new group
+             col_groups.append(current_group)
+             current_group = [col]
+     # Append the last group
+     if current_group:
+         col_groups.append(current_group)
+     meta_dict = {"common_cols": [], "db_cols": []}
+     for group in col_groups:
+         if not have_unnamed(group):
+             meta_dict["common_cols"].extend(group)
+         else:
+             # find the first named (non-"unnamed") column in the group
+             named_col = next(
+                 (col for col in group if "unnamed" not in col.lower()), None
+             )
+             group_cols = [f"{named_col}_{i}" for i in range(len(group))]
+             meta_dict["db_cols"].extend(group_cols)
+     return meta_dict
+
+ # def bech_by_db_name(df, db_list="db1, db2", key_metrics="p, r, f1, acc"):
+
+
+ def str_2_list(input_str, sep=","):
+     out_ls = []
+     if len(input_str.strip()) == 0:
+         return out_ls
+     if sep not in input_str:
+         out_ls.append(input_str.strip())
+         return out_ls
+     else:
+         out_ls = [item.strip() for item in input_str.split(sep) if item.strip()]
+     return out_ls
+
+
+ def filter_bech_df_by_db_and_metrics(df, db_list="", key_metrics=""):
+     meta_cols_dict = cols_to_col_groups(df)
+     op_df = df.copy()
+     op_df.columns = (
+         meta_cols_dict["common_cols"].copy() + meta_cols_dict["db_cols"].copy()
+     )
+     filterd_cols = []
+     filterd_cols.extend(meta_cols_dict["common_cols"])
+
+     selected_db_list = str_2_list(db_list)
+     db_filted_cols = []
+     if len(selected_db_list) > 0:
+         for db_name in db_list.split(","):
+             db_name = db_name.strip()
+             for col_name in meta_cols_dict["db_cols"]:
+                 if db_name.lower() in col_name.lower():
+                     db_filted_cols.append(col_name)
+     else:
+         db_filted_cols = meta_cols_dict["db_cols"]
+
+     filterd_cols.extend(db_filted_cols)
+     df_filtered = op_df[filterd_cols].copy()
+     df_filtered  # no-op; leftover debugging expression
+
+     selected_metrics_ls = str_2_list(key_metrics)
+     if len(selected_metrics_ls) > 0:
+         # the first data row (second row of the CSV) holds the metric names
+         metrics_row = df_filtered.iloc[0].copy()
+         # only get the values in the db_filted_cols columns
+         metrics_values = metrics_row[db_filted_cols].values
+         keep_metrics_cols = []
+         # create a zip of db_filted_cols and metrics_values (in that metrics_row)
+         metrics_list = list(zip(metrics_values, db_filted_cols))
+         selected_metrics_ls = [metric.strip().lower() for metric in selected_metrics_ls]
+         for metric, col_name in metrics_list:
+             if metric.lower() in selected_metrics_ls:
+                 keep_metrics_cols.append(col_name)
+
+     else:
+         pprint("No metrics selected, keeping all db columns")
+         keep_metrics_cols = db_filted_cols
+
+     final_filterd_cols = meta_cols_dict["common_cols"].copy() + keep_metrics_cols
+     df_final = df_filtered[final_filterd_cols].copy()
+     return df_final
+
+
+ def parse_args():
+     parser = ArgumentParser(description="desc text")
+     parser.add_argument('-csv', '--csv', type=str, help='CSV file path', default=r"E:\Dev\__halib\test\bench.csv")
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     csv_file = args.csv
+     df = pd.read_csv(csv_file, sep=";", encoding="utf-8")
+     filtered_df = filter_bech_df_by_db_and_metrics(df, "bowfire", "acc")
+     print(filtered_df)
+
+ if __name__ == "__main__":
+     main()
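
The grouping logic above keys off pandas' auto-generated "Unnamed: N" headers, which appear when a benchmark CSV has a two-row header (dataset names on the first row, metric names on the second). A minimal usage sketch follows, assuming this first added file is importable as halib.research.benchquery (the path listed in the RECORD); the column layout and values are purely illustrative:

import pandas as pd
from halib.research import benchquery  # assumed import path, per the RECORD entries below

# Hypothetical benchmark table: dataset names on the header row, metric names on
# the first data row; pandas fills the blank header cells with "Unnamed: N".
df = pd.DataFrame({
    "method":     ["",    "resnet50", "mobilenet"],
    "backbone":   ["",    "r50",      "mbv3"],
    "db1":        ["acc", 0.91,       0.88],
    "Unnamed: 3": ["f1",  0.90,       0.86],
    "db2":        ["acc", 0.79,       0.81],
    "Unnamed: 5": ["f1",  0.77,       0.80],
})

print(benchquery.cols_to_col_groups(df))
# {'common_cols': ['method', 'backbone'], 'db_cols': ['db1_0', 'db1_1', 'db2_0', 'db2_1']}

# Keep only the db1 columns whose metric row says "acc"
print(benchquery.filter_bech_df_by_db_and_metrics(df, db_list="db1", key_metrics="acc"))
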
halib/research/dataset.py ADDED
@@ -0,0 +1,209 @@
+ # This script creates a test version
+ # of the watcam (wc) dataset
+ # for testing the tflite model
+
+ from argparse import ArgumentParser
+
+ from rich import inspect
+ from common import console, seed_everything, ConsoleLog
+ from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
+ from tqdm import tqdm
+ import os
+ import click
+ from torchvision.datasets import ImageFolder
+ import shutil
+ from rich.pretty import pprint
+ from system import filesys as fs
+ import glob
+
+
+ def parse_args():
+     parser = ArgumentParser(description="desc text")
+     parser.add_argument(
+         "-indir",
+         "--indir",
+         type=str,
+         help="original dataset path",
+     )
+     parser.add_argument(
+         "-outdir",
+         "--outdir",
+         type=str,
+         help="dataset out path",
+         default=".",  # default to current dir
+     )
+     parser.add_argument(
+         "-val_size",
+         "--val_size",
+         type=float,
+         help="validation size",
+         default=0.2,
+     )
+     # add using StratifiedShuffleSplit or ShuffleSplit
+     parser.add_argument(
+         "-seed",
+         "--seed",
+         type=int,
+         help="random seed",
+         default=42,
+     )
+     parser.add_argument(
+         "-inplace",
+         "--inplace",
+         action="store_true",
+         help="inplace operation, will overwrite the outdir if exists",
+     )
+
+     parser.add_argument(
+         "-stratified",
+         "--stratified",
+         action="store_true",
+         help="use StratifiedShuffleSplit instead of ShuffleSplit",
+     )
+     parser.add_argument(
+         "-no_train",
+         "--no_train",
+         action="store_true",
+         help="only create test set, no train set",
+     )
+     parser.add_argument(
+         "-reverse",
+         "--reverse",
+         action="store_true",
+         help="combine train and val set back to original dataset",
+     )
+     return parser.parse_args()
+
+
+ def move_images(image_paths, target_set_dir):
+     for img_path in tqdm(image_paths):
+         # get folder name of the image
+         img_dir = os.path.dirname(img_path)
+         out_cls_dir = os.path.join(target_set_dir, os.path.basename(img_dir))
+         if not os.path.exists(out_cls_dir):
+             os.makedirs(out_cls_dir)
+         # move the image to the class folder
+         shutil.move(img_path, out_cls_dir)
+
+
+ def split_dataset_cls(
+     indir, outdir, val_size, seed, inplace, stratified_split, no_train
+ ):
+     seed_everything(seed)
+     console.rule("Config confirm?")
+     pprint(locals())
+     click.confirm("Continue?", abort=True)
+     assert os.path.exists(indir), f"{indir} does not exist"
+
+     if not inplace:
+         assert (not inplace) and (
+             not os.path.exists(outdir)
+         ), f"{outdir} already exists; SKIP ...."
+
+     if inplace:
+         outdir = indir
+     if not os.path.exists(outdir):
+         os.makedirs(outdir)
+
+     console.rule("Creating train/val dataset")
+
+     sss = (
+         ShuffleSplit(n_splits=1, test_size=val_size)
+         if not stratified_split
+         else StratifiedShuffleSplit(n_splits=1, test_size=val_size)
+     )
+
+     pprint({"split strategy": sss, "indir": indir, "outdir": outdir})
+     dataset = ImageFolder(
+         root=indir,
+         transform=None,
+     )
+     train_dataset_indices = None
+     val_dataset_indices = None  # val here means test
+     for train_indices, val_indices in sss.split(dataset.samples, dataset.targets):
+         train_dataset_indices = train_indices
+         val_dataset_indices = val_indices
+
+     # get image paths for train/val split dataset
+     train_image_paths = [dataset.imgs[i][0] for i in train_dataset_indices]
+     val_image_paths = [dataset.imgs[i][0] for i in val_dataset_indices]
+
+     # start creating train/val folders then move images
+     out_train_dir = os.path.join(outdir, "train")
+     out_val_dir = os.path.join(outdir, "val")
+     if inplace:
+         assert not os.path.exists(out_train_dir), f"{out_train_dir} already exists"
+         assert not os.path.exists(out_val_dir), f"{out_val_dir} already exists"
+
+     os.makedirs(out_train_dir)
+     os.makedirs(out_val_dir)
+
+     if not no_train:
+         with ConsoleLog(f"Moving train images to {out_train_dir} "):
+             move_images(train_image_paths, out_train_dir)
+     else:
+         pprint("test only, skip moving train images")
+         # remove out_train_dir
+         shutil.rmtree(out_train_dir)
+
+     with ConsoleLog(f"Moving val images to {out_val_dir} "):
+         move_images(val_image_paths, out_val_dir)
+
+     if inplace:
+         pprint("remove all folders, except train and val")
+         for cls_dir in os.listdir(outdir):
+             if cls_dir not in ["train", "val"]:
+                 shutil.rmtree(os.path.join(indir, cls_dir))
+
+
+ def reverse_split_ds(indir):
+     console.rule(f"Reversing split dataset <{indir}>...")
+     ls_dirs = os.listdir(indir)
+     # make sure there are only two dirs 'train' and 'val'
+     assert len(ls_dirs) == 2, f"Expected exactly 2 dirs (train, val), found {len(ls_dirs)} dirs"
+     assert "train" in ls_dirs, f"train dir not found in {indir}"
+     assert "val" in ls_dirs, f"val dir not found in {indir}"
+     train_dir = os.path.join(indir, "train")
+     val_dir = os.path.join(indir, "val")
+     all_train_files = fs.filter_files_by_extension(
+         train_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+     )
+     all_val_files = fs.filter_files_by_extension(
+         val_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+     )
+     # move all files from train to indir
+     with ConsoleLog(f"Moving train images to {indir} "):
+         move_images(all_train_files, indir)
+     with ConsoleLog(f"Moving val images to {indir} "):
+         move_images(all_val_files, indir)
+     with ConsoleLog("Removing train and val dirs"):
+         # remove train and val dirs
+         shutil.rmtree(train_dir)
+         shutil.rmtree(val_dir)
+
+
+ def main():
+     args = parse_args()
+     indir = args.indir
+     outdir = args.outdir
+     if outdir == ".":
+         # get current folder of the indir
+         indir_parent_dir = os.path.dirname(os.path.normpath(indir))
+         indir_name = os.path.basename(indir)
+         outdir = os.path.join(indir_parent_dir, f"{indir_name}_split")
+     val_size = args.val_size
+     seed = args.seed
+     inplace = args.inplace
+     stratified_split = args.stratified
+     no_train = args.no_train
+     reverse = args.reverse
+     if not reverse:
+         split_dataset_cls(
+             indir, outdir, val_size, seed, inplace, stratified_split, no_train
+         )
+     else:
+         reverse_split_ds(indir)
+
+
+ if __name__ == "__main__":
+     main()
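
The split itself is plain scikit-learn index splitting over ImageFolder.samples/ImageFolder.targets; a small self-contained sketch of that step, using dummy labels in place of a real image folder:

from sklearn.model_selection import StratifiedShuffleSplit

# Stand-ins for ImageFolder.targets (one class label per image) and sample indices.
targets = [0] * 8 + [1] * 8
samples = list(range(len(targets)))

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, val_idx = next(splitter.split(samples, targets))
print(len(train_idx), len(val_idx))          # 12 train / 4 val
print(sorted(targets[i] for i in val_idx))   # [0, 0, 1, 1] -> class ratio preserved
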
halib/research/plot.py ADDED
@@ -0,0 +1,301 @@
+ from ..common import now_str, norm_str, ConsoleLog
+ from ..filetype import csvfile
+ from ..system import filesys as fs
+ from functools import partial
+ from rich.console import Console
+ from rich.pretty import pprint
+ import click
+ import csv
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import os
+ import pandas as pd
+ import seaborn as sns
+
+
+ console = Console()
+ desktop_path = os.path.expanduser("~/Desktop")
+ REQUIRED_COLUMNS = ["epoch", "train_loss", "val_loss", "train_acc", "val_acc"]
+
+ import csv
+
+
+ def get_delimiter(file_path, bytes=4096):
+     sniffer = csv.Sniffer()
+     data = open(file_path, "r").read(bytes)
+     delimiter = sniffer.sniff(data).delimiter
+     return delimiter
+
+
+ # Function to verify that the DataFrame has the required columns, and only the required columns
+ def verify_csv(csv_file, required_columns=REQUIRED_COLUMNS):
+     delimiter = get_delimiter(csv_file)
+     df = pd.read_csv(csv_file, sep=delimiter)
+     # change the column names to lower case
+     df.columns = [col.lower() for col in df.columns]
+     for col in required_columns:
+         if col not in df.columns:
+             raise ValueError(
+                 f"Required columns are: {REQUIRED_COLUMNS}, but found {df.columns}"
+             )
+     df = df[required_columns].copy()
+     return df
+
+
+ def get_valid_tags(csv_files, tags):
+     if tags is not None and len(tags) > 0:
+         assert all(
+             isinstance(tag, str) for tag in tags
+         ), "tags must be a list of strings"
+         assert all(
+             len(tag) > 0 for tag in tags
+         ), "tags must be a list of non-empty strings"
+         valid_tags = tags
+     else:
+         valid_tags = []
+         for csv_file in csv_files:
+             file_name = fs.get_file_name(csv_file, split_file_ext=True)[0]
+             tag = norm_str(file_name)
+             valid_tags.append(tag)
+     return valid_tags
+
+
+ def plot_ax(df, ax, metric="loss", tag=""):
+     pprint(locals())
+     # reset plt
+     assert metric in ["loss", "acc"], "metric must be either 'loss' or 'acc'"
+     part = ["train", "val"]
+     for p in part:
+         label = f"{tag}_{p}_{metric}"
+         ax.plot(df["epoch"], df[f"{p}_{metric}"], label=label)
+     return ax
+
+
+ def actual_plot_seaborn(frame, csv_files, axes, tags, log):
+     # clear the axes
+     for ax in axes:
+         ax.clear()
+     ls_df = []
+     valid_tags = get_valid_tags(csv_files, tags)
+     for csv_file in csv_files:
+         df = verify_csv(csv_file)
+         if log:
+             with ConsoleLog(f"plotting {csv_file}"):
+                 csvfile.fn_display_df(df)
+         ls_df.append(df)
+
+     ls_metrics = ["loss", "acc"]
+     for df_item, tag in zip(ls_df, valid_tags):
+         # add tag to columns, except epoch
+         df_item.columns = [
+             f"{tag}_{col}" if col != "epoch" else col for col in df_item.columns
+         ]
+     # merge the dataframes on the epoch column
+     df_combined = ls_df[0]
+     for df_item in ls_df[1:]:
+         df_combined = pd.merge(df_combined, df_item, on="epoch", how="outer")
+     # csvfile.fn_display_df(df_combined)
+
+     for i, metric in enumerate(ls_metrics):
+         tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
+         title = f"{tags_str}_{metric}-by-epoch"
+         cols = [col for col in df_combined.columns if col != "epoch" and metric in col]
+         cols = sorted(cols)
+         # pprint(cols)
+         plot_data = df_combined[cols]
+
+         # lines from the same csv file (same tag) should have the same marker
+         all_markers = [
+             marker for marker in plt.Line2D.markers if marker and marker != " "
+         ]
+         tag2marker = {tag: marker for tag, marker in zip(valid_tags, all_markers)}
+         plot_markers = []
+         for col in cols:
+             # find the tag:
+             tag = None
+             for valid_tag in valid_tags:
+                 if valid_tag in col:
+                     tag = valid_tag
+                     break
+             plot_markers.append(tag2marker[tag])
+         # pprint(list(zip(cols, plot_markers)))
+
+         # create color
+         sequential_palettes = [
+             "Reds",
+             "Greens",
+             "Blues",
+             "Oranges",
+             "Purples",
+             "Greys",
+             "BuGn",
+             "BuPu",
+             "GnBu",
+             "OrRd",
+             "PuBu",
+             "PuRd",
+             "RdPu",
+             "YlGn",
+             "PuBuGn",
+             "YlGnBu",
+             "YlOrBr",
+             "YlOrRd",
+         ]
+         # each csvfile (tag) should have a unique color
+         tag2palette = {
+             tag: palette for tag, palette in zip(valid_tags, sequential_palettes)
+         }
+         plot_colors = []
+         for tag in valid_tags:
+             palette = tag2palette[tag]
+             total_colors = 10
+             ls_colors = sns.color_palette(palette, total_colors).as_hex()
+             num_part = len(ls_metrics)
+             subarr = np.array_split(np.arange(total_colors), num_part)
+             for idx, col in enumerate(cols):
+                 if tag in col:
+                     chosen_color = ls_colors[
+                         subarr[int(idx % num_part)].mean().astype(int)
+                     ]
+                     plot_colors.append(chosen_color)
+
+         # pprint(list(zip(cols, plot_colors)))
+         sns.lineplot(
+             data=plot_data,
+             markers=plot_markers,
+             palette=plot_colors,
+             ax=axes[i],
+             dashes=False,
+         )
+         axes[i].set(xlabel="epoch", ylabel=metric, title=title)
+         axes[i].legend()
+         axes[i].grid()
+
+
+ def actual_plot(frame, csv_files, axes, tags, log):
+     ls_df = []
+     valid_tags = get_valid_tags(csv_files, tags)
+     for csv_file in csv_files:
+         df = verify_csv(csv_file)
+         if log:
+             with ConsoleLog(f"plotting {csv_file}"):
+                 csvfile.fn_display_df(df)
+         ls_df.append(df)
+
+     metric_values = ["loss", "acc"]
+     for i, metric in enumerate(metric_values):
+         for df_item, tag in zip(ls_df, valid_tags):
+             metric_ax = plot_ax(df_item, axes[i], metric, tag)
+
+         # set the title, xlabel, ylabel, legend, and grid
+         tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
+         metric_ax.set(
+             xlabel="epoch", ylabel=metric, title=f"{tags_str}_{metric}-by-epoch"
+         )
+         metric_ax.legend()
+         metric_ax.grid()
+
+
+ def plot_csv_files(
+     csv_files,
+     outdir="./out/plot",
+     tags=None,
+     log=False,
+     save_fig=False,
+     update_in_min=1,
+ ):
+     # if csv_files is a string, convert it to a list
+     if isinstance(csv_files, str):
+         csv_files = [csv_files]
+     # if tags is a string, convert it to a list
+     if isinstance(tags, str):
+         tags = [tags]
+     valid_tags = get_valid_tags(csv_files, tags)
+     assert len(valid_tags) == len(
+         csv_files
+     ), "Unable to determine tags for each csv file"
+     live_update_in_ms = int(update_in_min * 60 * 1000)
+     fig, axes = plt.subplots(2, 1, figsize=(10, 17))
+     if live_update_in_ms:  # live update in min should be > 0
+         from matplotlib.animation import FuncAnimation
+
+         anim = FuncAnimation(
+             fig,
+             partial(
+                 actual_plot_seaborn, csv_files=csv_files, axes=axes, tags=tags, log=log
+             ),
+             interval=live_update_in_ms,
+             blit=False,
+             cache_frame_data=False,
+         )
+         plt.show()
+     else:
+         actual_plot_seaborn(None, csv_files, axes, tags, log)
+         plt.show()
+
+     if save_fig:
+         os.makedirs(outdir, exist_ok=True)
+         tags_str = "+".join(valid_tags) if len(valid_tags) > 1 else valid_tags[0]
+         tag = f"{now_str()}_{tags_str}"
+         fig.savefig(f"{outdir}/{tag}_plot.png")
+         enable_plot_pgf()
+         fig.savefig(f"{outdir}/{tag}_plot.pdf")
+     if live_update_in_ms:
+         return anim
+
+
+ def enable_plot_pgf():
+     matplotlib.use("pdf")
+     matplotlib.rcParams.update(
+         {
+             "pgf.texsystem": "pdflatex",
+             "font.family": "serif",
+             "text.usetex": True,
+             "pgf.rcfonts": False,
+         }
+     )
+
+
+ def save_fig_latex_pgf(filename, directory="."):
+     enable_plot_pgf()
+     if ".pgf" not in filename:
+         filename = f"{directory}/{filename}.pgf"
+     plt.savefig(filename)
+
+
+ # https://click.palletsprojects.com/en/8.1.x/api/
+ @click.command()
+ @click.option("--csvfiles", "-f", multiple=True, type=str, help="csv files to plot")
+ @click.option(
+     "--outdir",
+     "-o",
+     type=str,
+     help="output directory for the plot",
+     default=str(desktop_path),
+ )
+ @click.option(
+     "--tags", "-t", multiple=True, type=str, help="tags for the csv files", default=[]
+ )
+ @click.option("--log", "-l", is_flag=True, help="log the csv files")
+ @click.option("--save_fig", "-s", is_flag=True, help="save the plot as a file")
+ @click.option(
+     "--update_in_min",
+     "-u",
+     type=float,
+     help="update the plot every x minutes",
+     default=0.0,
+ )
+ def main(
+     csvfiles,
+     outdir,
+     tags,
+     log,
+     save_fig,
+     update_in_min,
+ ):
+     plot_csv_files(list(csvfiles), outdir, list(tags), log, save_fig, update_in_min)
+
+
+ if __name__ == "__main__":
+     main()
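
A hedged usage sketch for plot_csv_files; the two CSV paths below are placeholders and must contain the REQUIRED_COLUMNS (epoch, train_loss, val_loss, train_acc, val_acc):

from halib.research import plot  # module path as listed in the RECORD below

# train_a.csv / train_b.csv are hypothetical training-log files.
plot.plot_csv_files(
    ["train_a.csv", "train_b.csv"],
    outdir="./out/plot",
    tags=["runA", "runB"],
    log=True,             # pretty-print each dataframe before plotting
    save_fig=True,        # writes <timestamp>_runA+runB_plot.png/.pdf into outdir
    update_in_min=0.0,    # 0 disables the FuncAnimation live refresh
)
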
halib/research/torchloader.py ADDED
@@ -0,0 +1,162 @@
+ """
+ * @author Hoang Van-Ha
+ * @email hoangvanhauit@gmail.com
+ * @create date 2024-03-27 15:40:22
+ * @modify date 2024-03-27 15:40:22
+ * @desc this module works as a utility tool for finding the best configuration for the dataloader (num_workers, batch_size, pin_memory, etc.) that fits your hardware.
+ """
+ from argparse import ArgumentParser
+ from ..common import *
+ from ..filetype import csvfile
+ from ..filetype.yamlfile import load_yaml
+ from rich import inspect
+ from torch.utils.data import DataLoader
+ from torchvision import datasets, transforms
+ from tqdm import tqdm
+ from typing import Union
+ import itertools as it  # for cartesian product
+ import os
+ import time
+ import traceback
+
+
+ def parse_args():
+     parser = ArgumentParser(description="desc text")
+     parser.add_argument("-cfg", "--cfg", type=str, help="cfg file for searching")
+     return parser.parse_args()
+
+
+ def get_test_range(cfg: dict, search_item="num_workers"):
+     item_search_cfg = cfg["search_space"].get(search_item, None)
+     if item_search_cfg is None:
+         raise ValueError(f"search_item: {search_item} not found in cfg")
+     if isinstance(item_search_cfg, list):
+         return item_search_cfg
+     elif isinstance(item_search_cfg, dict):
+         if "mode" in item_search_cfg:
+             mode = item_search_cfg["mode"]
+             assert mode in ["range", "list"], f"mode: {mode} not supported"
+             value_in_mode = item_search_cfg.get(mode, None)
+             if value_in_mode is None:
+                 raise ValueError(f"mode<{mode}>: data not found in <{search_item}>")
+             if mode == "range":
+                 assert len(value_in_mode) == 3, "range must have 3 values: start, stop, step"
+                 start = value_in_mode[0]
+                 stop = value_in_mode[1]
+                 step = value_in_mode[2]
+                 return list(range(start, stop, step))
+             elif mode == "list":
+                 return item_search_cfg["list"]
+     else:
+         return [item_search_cfg]  # for int, float, str, bool, etc.
+
+
+ def load_an_batch(loader_iter):
+     start = time.time()
+     next(loader_iter)
+     end = time.time()
+     return end - start
+
+
+ def test_dataloader_with_cfg(origin_dataloader: DataLoader, cfg: Union[dict, str]):
+     try:
+         if isinstance(cfg, str):
+             cfg = load_yaml(cfg, to_dict=True)
+         dfmk = csvfile.DFCreator()
+         search_items = ["batch_size", "num_workers", "persistent_workers", "pin_memory"]
+         batch_limit = cfg["general"]["batch_limit"]
+         csv_cfg = cfg["general"]["to_csv"]
+         log_batch_info = cfg["general"]["log_batch_info"]
+
+         save_to_csv = csv_cfg["enabled"]
+         log_dir = csv_cfg["log_dir"]
+         filename = csv_cfg["filename"]
+         filename = f"{now_str()}_{filename}.csv"
+         outfile = os.path.join(log_dir, filename)
+
+         dfmk.create_table(
+             "cfg_search",
+             (search_items + ["avg_time_taken"]),
+         )
+         ls_range_test = []
+         for item in search_items:
+             range_test = get_test_range(cfg, search_item=item)
+             range_test = [(item, i) for i in range_test]
+             ls_range_test.append(range_test)
+
+         all_combinations = list(it.product(*ls_range_test))
+
+         rows = []
+         for cfg_idx, combine in enumerate(all_combinations):
+             console.rule(f"Testing cfg {cfg_idx+1}/{len(all_combinations)}")
+             inspect(combine)
+             batch_size = combine[search_items.index("batch_size")][1]
+             num_workers = combine[search_items.index("num_workers")][1]
+             persistent_workers = combine[search_items.index("persistent_workers")][1]
+             pin_memory = combine[search_items.index("pin_memory")][1]
+
+             test_dataloader = DataLoader(origin_dataloader.dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=persistent_workers, pin_memory=pin_memory, shuffle=True)
+             row = [
+                 batch_size,
+                 num_workers,
+                 persistent_workers,
+                 pin_memory,
+                 0.0,
+             ]
+
+             # calculate the avg time taken to load the data for <batch_limit> batches
+             trainiter = iter(test_dataloader)
+             time_elapsed = 0
+             pprint('Start testing...')
+             for i in tqdm(range(batch_limit)):
+                 single_batch_time = load_an_batch(trainiter)
+                 if log_batch_info:
+                     pprint(f"Batch {i+1} took {single_batch_time:.4f} seconds to load")
+                 time_elapsed += single_batch_time
+             row[-1] = time_elapsed / batch_limit
+             rows.append(row)
+         dfmk.insert_rows('cfg_search', rows)
+         dfmk.fill_table_from_row_pool('cfg_search')
+         with ConsoleLog("results"):
+             csvfile.fn_display_df(dfmk['cfg_search'])
+         if save_to_csv:
+             dfmk["cfg_search"].to_csv(outfile, index=False)
+             console.print(f"[red] Data saved to <{outfile}> [/red]")
+
+     except Exception as e:
+         traceback.print_exc()
+         print(e)
+         # get current directory of this python file
+         current_dir = os.path.dirname(os.path.realpath(__file__))
+         standar_cfg_path = os.path.join(current_dir, "torchloader_search.yaml")
+         pprint(
+             f"Make sure you get the right <cfg.yaml> file. An example of <cfg.yaml> file can be found at this path: {standar_cfg_path}"
+         )
+         return
+
+ def main():
+     args = parse_args()
+     cfg_yaml = args.cfg
+     cfg_dict = load_yaml(cfg_yaml, to_dict=True)
+
+     # Define transforms for data augmentation and normalization
+     transform = transforms.Compose(
+         [
+             transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
+             transforms.RandomRotation(10),  # Randomly rotate images by 10 degrees
+             transforms.ToTensor(),  # Convert images to PyTorch tensors
+             transforms.Normalize(
+                 (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
+             ),  # Normalize pixel values to [-1, 1]
+         ]
+     )
+     test_dataset = datasets.CIFAR10(
+         root="./data", train=False, download=True, transform=transform
+     )
+     batch_size = 64
+     train_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
+     test_dataloader_with_cfg(train_loader, cfg_dict)
+
+
+ if __name__ == "__main__":
+     main()
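
The YAML passed via -cfg is consumed by get_test_range and test_dataloader_with_cfg; the dict below is an illustrative guess at that shape (the torchloader_search.yaml referenced in the error handler is the authoritative example):

# Illustrative config, matching only the keys read in test_dataloader_with_cfg above.
cfg = {
    "general": {
        "batch_limit": 20,          # batches timed per combination
        "log_batch_info": False,
        "to_csv": {"enabled": True, "log_dir": "./out", "filename": "loader_search"},
    },
    "search_space": {
        "batch_size": {"mode": "list", "list": [32, 64, 128]},
        "num_workers": {"mode": "range", "range": [2, 9, 2]},  # start, stop, step -> 2, 4, 6, 8
        "persistent_workers": [True, False],   # a plain list is used as-is
        "pin_memory": True,                    # a scalar becomes a single-value range
    },
}
# test_dataloader_with_cfg(existing_loader, cfg) would then time every combination
# (3 x 4 x 2 x 1 = 24 DataLoader configurations in this example).
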
halib/utils/listop.py ADDED
@@ -0,0 +1,13 @@
+ def subtract(list_a, list_b):
+     return [item for item in list_a if item not in list_b]
+
+
+ def union(list_a, list_b, no_duplicate=False):
+     if no_duplicate:
+         return list(set(list_a) | set(list_b))
+     else:
+         return list_a + list_b
+
+
+ def intersection(list_a, list_b):
+     return list(set(list_a) & set(list_b))
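
A quick semantics check for these helpers (assumed import path halib.utils.listop, per the RECORD below); note that the set-based union and intersection do not preserve element order:

from halib.utils import listop

print(listop.subtract([1, 2, 3], [2]))                  # [1, 3]
print(listop.union([1, 2], [2, 3]))                     # [1, 2, 2, 3]
print(listop.union([1, 2], [2, 3], no_duplicate=True))  # e.g. [1, 2, 3], order not guaranteed
print(listop.intersection([1, 2, 3], [2, 4]))           # [2]
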
halib/utils/tele_noti.py ADDED
@@ -0,0 +1,166 @@
+ # Watch a log file and send a telegram message when training reaches a certain epoch or ends
+
+ import os
+ import yaml
+ import asyncio
+ import telegram
+ import pandas as pd
+
+ from rich.pretty import pprint
+ from rich.console import Console
+ import plotly.graph_objects as go
+
+ from ..system import filesys as fs
+ from ..filetype import textfile, csvfile
+
+ from argparse import ArgumentParser
+
+ tele_console = Console()
+
+
+ def parse_args():
+     parser = ArgumentParser(description="desc text")
+     parser.add_argument(
+         "-cfg",
+         "--cfg",
+         type=str,
+         help="yaml file for tele",
+         default=r"E:\Dev\halib\cfg_tele_noti.yaml",
+     )
+
+     return parser.parse_args()
+
+
+ def get_watcher_message_df(target_file, num_last_lines):
+     file_ext = fs.get_file_name(target_file, split_file_ext=True)[1]
+     supported_ext = [".txt", ".log", ".csv"]
+     assert (
+         file_ext in supported_ext
+     ), f"File extension {file_ext} not supported. Supported extensions are {supported_ext}"
+     last_lines_df = None
+     if file_ext in [".txt", ".log"]:
+         lines = textfile.read_line_by_line(target_file)
+         if num_last_lines > len(lines):
+             num_last_lines = len(lines)
+         last_line_arr = lines[-num_last_lines:]
+         # add the most recent line that contains the word "epoch"
+         epoch_info_list = "Epoch: n/a"
+         for line in reversed(lines):
+             if "epoch" in line.lower():
+                 epoch_info_list = line
+                 break
+         last_line_arr.insert(0, epoch_info_list)  # insert at the beginning
+         dfCreator = csvfile.DFCreator()
+         dfCreator.create_table("last_lines", ["line"])
+         last_line_arr = [[line] for line in last_line_arr]
+         dfCreator.insert_rows("last_lines", last_line_arr)
+         dfCreator.fill_table_from_row_pool("last_lines")
+         last_lines_df = dfCreator["last_lines"].copy()
+     else:
+         df = pd.read_csv(target_file)
+         num_rows = len(df)
+         if num_last_lines > num_rows:
+             num_last_lines = num_rows
+         last_lines_df = df.tail(num_last_lines)
+     return last_lines_df
+
+
+ def df2img(df: pd.DataFrame, output_img_dir, decimal_places, out_img_scale):
+     df = df.round(decimal_places)
+     fig = go.Figure(
+         data=[
+             go.Table(
+                 header=dict(values=list(df.columns), align="center"),
+                 cells=dict(
+                     values=df.values.transpose(),
+                     fill_color=[["white", "lightgrey"] * df.shape[0]],
+                     align="center",
+                 ),
+             )
+         ]
+     )
+     if not os.path.exists(output_img_dir):
+         os.makedirs(output_img_dir)
+     img_path = os.path.normpath(os.path.join(output_img_dir, "last_lines.png"))
+     fig.write_image(img_path, scale=out_img_scale)
+     return img_path
+
+
+ def compose_message_and_img_path(
+     target_file, project, num_last_lines, decimal_places, out_img_scale, output_img_dir
+ ):
+     context_msg = f">> Project: {project} \n>> File: {target_file} \n>> Last {num_last_lines} lines:"
+     msg_df = get_watcher_message_df(target_file, num_last_lines)
+     try:
+         img_path = df2img(msg_df, output_img_dir, decimal_places, out_img_scale)
+     except Exception as e:
+         pprint(f"Error: {e}")
+         img_path = None
+     return context_msg, img_path
+
+
+ async def send_to_telegram(cfg_dict, interval_in_sec):
+     # pprint(cfg_dict)
+     token = cfg_dict["telegram"]["token"]
+     chat_id = cfg_dict["telegram"]["chat_id"]
+
+     noti_settings = cfg_dict["noti_settings"]
+     project = noti_settings["project"]
+     target_file = noti_settings["target_file"]
+     num_last_lines = noti_settings["num_last_lines"]
+     output_img_dir = noti_settings["output_img_dir"]
+     decimal_places = noti_settings["decimal_places"]
+     out_img_scale = noti_settings["out_img_scale"]
+
+     bot = telegram.Bot(token=token)
+     async with bot:
+         try:
+             context_msg, img_path = compose_message_and_img_path(
+                 target_file,
+                 project,
+                 num_last_lines,
+                 decimal_places,
+                 out_img_scale,
+                 output_img_dir,
+             )
+             time_now = next_time = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
+             sep_line = "-" * 50
+             context_msg = f"{sep_line}\n>> Time: {time_now}\n{context_msg}"
+             # calculate the next time to send message
+             next_time = pd.Timestamp.now() + pd.Timedelta(seconds=interval_in_sec)
+             next_time = next_time.strftime("%Y-%m-%d %H:%M:%S")
+             next_time_info = f"Next msg: {next_time}"
+             tele_console.rule()
+             tele_console.print("[green] Send message to telegram [/green]")
+             tele_console.print(
+                 f"[red] Next message will be sent at <{next_time}> [/red]"
+             )
+             await bot.send_message(text=context_msg, chat_id=chat_id)
+             if img_path:
+                 await bot.send_photo(chat_id=chat_id, photo=open(img_path, "rb"))
+             await bot.send_message(text=next_time_info, chat_id=chat_id)
+         except Exception as e:
+             pprint(f"Error: {e}")
+             pprint("Message not sent to telegram")
+
+
+ async def run_forever(cfg_path):
+     cfg_dict = yaml.safe_load(open(cfg_path, "r"))
+     noti_settings = cfg_dict["noti_settings"]
+     interval_in_min = noti_settings["interval_in_min"]
+     interval_in_sec = int(interval_in_min * 60)
+     pprint(
+         f"Message will be sent every {interval_in_min} minutes or {interval_in_sec} seconds"
+     )
+     while True:
+         await send_to_telegram(cfg_dict, interval_in_sec)
+         await asyncio.sleep(interval_in_sec)
+
+
+ async def main():
+     args = parse_args()
+     await run_forever(args.cfg)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
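
run_forever loads a YAML config (cfg_tele_noti.yaml); its shape, shown here as the equivalent Python dict with placeholder values, is inferred from the keys read in send_to_telegram and run_forever above, so the real file may carry additional fields:

# Dict produced by yaml.safe_load(cfg_tele_noti.yaml); all values are placeholders.
cfg_dict = {
    "telegram": {"token": "<bot-token>", "chat_id": "<chat-id>"},
    "noti_settings": {
        "project": "my-train-run",
        "target_file": "./logs/train_log.csv",  # .txt/.log/.csv are supported
        "num_last_lines": 5,
        "output_img_dir": "./out/tele",
        "decimal_places": 4,
        "out_img_scale": 2,
        "interval_in_min": 30,
    },
}
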
halib-0.1.47.dist-info/METADATA → halib-0.1.48.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: halib
- Version: 0.1.47
+ Version: 0.1.48
  Summary: Small library for common tasks
  Author: Hoang Van Ha
  Author-email: hoangvanhauit@gmail.com
@@ -43,6 +43,10 @@ Requires-Dist: tube-dl

  Helper package for coding and automation

+ **Version 0.1.48**
+
+ + add `research` module to help with research tasks, including `benchquery` for benchmarking queries from dataframe
+
  **Version 0.1.47**
  + add `pprint_box` to print object/string in a box frame (like in `inspect`)

halib-0.1.47.dist-info/RECORD → halib-0.1.48.dist-info/RECORD
@@ -27,14 +27,22 @@ halib/online/gdrive.py,sha256=RmF4y6UPxektkKIctmfT-pKWZsBM9FVUeld6zZmJkp0,7787
  halib/online/gdrive_mkdir.py,sha256=wSJkQMJCDuS1gxQ2lHQHq_IrJ4xR_SEoPSo9n_2WNFU,1474
  halib/online/gdrive_test.py,sha256=hMWzz4RqZwETHp4GG4WwVNFfYvFQhp2Boz5t-DqwMo0,1342
  halib/online/projectmake.py,sha256=Zrs96WgXvO4nIrwxnCOletL4aTBge-EoF0r7hpKO1w8,4034
+ halib/research/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ halib/research/benchquery.py,sha256=FuKnbWQtCEoRRtJAfN-zaN-jPiO_EzsakmTOMiqi7GQ,4626
+ halib/research/dataset.py,sha256=QU0Hr5QFb8_XlvnOMgC9QJGIpwXAZ9lDd0RdQi_QRec,6743
+ halib/research/plot.py,sha256=-pDUk4z3C_GnyJ5zWmf-mGMdT4gaipVJWzIgcpIPiRk,9448
+ halib/research/torchloader.py,sha256=yqUjcSiME6H5W210363HyRUrOi3ISpUFAFkTr1w4DCw,6503
  halib/sys/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  halib/sys/cmd.py,sha256=b2x7JPcNnFjLGheIESVYvqAb-w2UwBM1PAwYxMZ5YjA,228
  halib/sys/filesys.py,sha256=ERpnELLDKJoTIIKf-AajgkY62nID4qmqmX5TkE95APU,2931
  halib/system/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  halib/system/cmd.py,sha256=b2x7JPcNnFjLGheIESVYvqAb-w2UwBM1PAwYxMZ5YjA,228
  halib/system/filesys.py,sha256=ERpnELLDKJoTIIKf-AajgkY62nID4qmqmX5TkE95APU,2931
- halib-0.1.47.dist-info/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
- halib-0.1.47.dist-info/METADATA,sha256=GsAawspTV3gRGBKKkxsSuTG9IlkE2cAIj3GzdlCNE68,3823
- halib-0.1.47.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
- halib-0.1.47.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
- halib-0.1.47.dist-info/RECORD,,
+ halib/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ halib/utils/listop.py,sha256=Vpa8_2fI0wySpB2-8sfTBkyi_A4FhoFVVvFiuvW8N64,339
+ halib/utils/tele_noti.py,sha256=-4WXZelCA4W9BroapkRyIdUu9cUVrcJJhegnMs_WpGU,5928
+ halib-0.1.48.dist-info/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
+ halib-0.1.48.dist-info/METADATA,sha256=iaGDSmQyhQWr6hLkyRpK6ZpkW6tuoAoGOrPNuK3CQp8,3960
+ halib-0.1.48.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+ halib-0.1.48.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
+ halib-0.1.48.dist-info/RECORD,,