halib 0.1.47__py3-none-any.whl → 0.1.49__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,162 @@
+ """
+ * @author Hoang Van-Ha
+ * @email hoangvanhauit@gmail.com
+ * @create date 2024-03-27 15:40:22
+ * @modify date 2024-03-27 15:40:22
+ * @desc utility module for finding the DataLoader configuration (num_workers, batch_size, pin_memory, etc.) that best fits your hardware.
+ """
+ from argparse import ArgumentParser
+ from ..common import *
+ from ..filetype import csvfile
+ from ..filetype.yamlfile import load_yaml
+ from rich import inspect
+ from torch.utils.data import DataLoader
+ from torchvision import datasets, transforms
+ from tqdm import tqdm
+ from typing import Union
+ import itertools as it  # for the cartesian product of search values
+ import os
+ import time
+ import traceback
+
+
+ def parse_args():
+     parser = ArgumentParser(description="Search for the best DataLoader configuration")
+     parser.add_argument("-cfg", "--cfg", type=str, help="cfg file for searching")
+     return parser.parse_args()
+
+
+ def get_test_range(cfg: dict, search_item="num_workers"):
+     item_search_cfg = cfg["search_space"].get(search_item, None)
+     if item_search_cfg is None:
+         raise ValueError(f"search_item: {search_item} not found in cfg")
+     if isinstance(item_search_cfg, list):
+         return item_search_cfg
+     elif isinstance(item_search_cfg, dict):
+         if "mode" in item_search_cfg:
+             mode = item_search_cfg["mode"]
+             assert mode in ["range", "list"], f"mode: {mode} not supported"
+             value_in_mode = item_search_cfg.get(mode, None)
+             if value_in_mode is None:
+                 raise ValueError(f"mode<{mode}>: data not found in <{search_item}>")
+             if mode == "range":
+                 assert len(value_in_mode) == 3, "range must have 3 values: start, stop, step"
+                 start, stop, step = value_in_mode
+                 return list(range(start, stop, step))
+             elif mode == "list":
+                 return item_search_cfg["list"]
+         raise ValueError(f"dict config for <{search_item}> must contain a 'mode' key")
+     else:
+         return [item_search_cfg]  # for int, float, str, bool, etc.
+
+
+ def load_a_batch(loader_iter):
+     """Return the time (in seconds) taken to load one batch from the iterator."""
+     start = time.time()
+     next(loader_iter)
+     end = time.time()
+     return end - start
+
+
+ def test_dataloader_with_cfg(origin_dataloader: DataLoader, cfg: Union[dict, str]):
+     try:
+         if isinstance(cfg, str):
+             cfg = load_yaml(cfg, to_dict=True)
+         dfmk = csvfile.DFCreator()
+         search_items = ["batch_size", "num_workers", "persistent_workers", "pin_memory"]
+         batch_limit = cfg["general"]["batch_limit"]
+         csv_cfg = cfg["general"]["to_csv"]
+         log_batch_info = cfg["general"]["log_batch_info"]
+
+         save_to_csv = csv_cfg["enabled"]
+         log_dir = csv_cfg["log_dir"]
+         filename = csv_cfg["filename"]
+         filename = f"{now_str()}_{filename}.csv"
+         outfile = os.path.join(log_dir, filename)
+
+         dfmk.create_table(
+             "cfg_search",
+             (search_items + ["avg_time_taken"]),
+         )
+         ls_range_test = []
+         for item in search_items:
+             range_test = get_test_range(cfg, search_item=item)
+             range_test = [(item, i) for i in range_test]
+             ls_range_test.append(range_test)
+
+         all_combinations = list(it.product(*ls_range_test))
+
+         rows = []
+         for cfg_idx, combine in enumerate(all_combinations):
+             console.rule(f"Testing cfg {cfg_idx + 1}/{len(all_combinations)}")
+             inspect(combine)
+             batch_size = combine[search_items.index("batch_size")][1]
+             num_workers = combine[search_items.index("num_workers")][1]
+             persistent_workers = combine[search_items.index("persistent_workers")][1]
+             pin_memory = combine[search_items.index("pin_memory")][1]
+
+             test_dataloader = DataLoader(
+                 origin_dataloader.dataset,
+                 batch_size=batch_size,
+                 num_workers=num_workers,
+                 persistent_workers=persistent_workers,
+                 pin_memory=pin_memory,
+                 shuffle=True,
+             )
+             row = [
+                 batch_size,
+                 num_workers,
+                 persistent_workers,
+                 pin_memory,
+                 0.0,
+             ]
+
+             # average time taken to load <batch_limit> batches
+             trainiter = iter(test_dataloader)
+             time_elapsed = 0
+             pprint("Start testing...")
+             for i in tqdm(range(batch_limit)):
+                 single_batch_time = load_a_batch(trainiter)
+                 if log_batch_info:
+                     pprint(f"Batch {i + 1} took {single_batch_time:.4f} seconds to load")
+                 time_elapsed += single_batch_time
+             row[-1] = time_elapsed / batch_limit
+             rows.append(row)
+         dfmk.insert_rows("cfg_search", rows)
+         dfmk.fill_table_from_row_pool("cfg_search")
+         with ConsoleLog("results"):
+             csvfile.fn_display_df(dfmk["cfg_search"])
+         if save_to_csv:
+             dfmk["cfg_search"].to_csv(outfile, index=False)
+             console.print(f"[red] Data saved to <{outfile}> [/red]")
+
+     except Exception as e:
+         traceback.print_exc()
+         print(e)
+         # get the directory containing this python file
+         current_dir = os.path.dirname(os.path.realpath(__file__))
+         standard_cfg_path = os.path.join(current_dir, "torchloader_search.yaml")
+         pprint(
+             f"Make sure you use the right <cfg.yaml> file. An example <cfg.yaml> can be found at: {standard_cfg_path}"
+         )
+         return
+
+
+ def main():
+     args = parse_args()
+     cfg_yaml = args.cfg
+     cfg_dict = load_yaml(cfg_yaml, to_dict=True)
+
+     # Define transforms for data augmentation and normalization
+     transform = transforms.Compose(
+         [
+             transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
+             transforms.RandomRotation(10),  # Randomly rotate images by up to 10 degrees
+             transforms.ToTensor(),  # Convert images to PyTorch tensors
+             transforms.Normalize(
+                 (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
+             ),  # Normalize pixel values to [-1, 1]
+         ]
+     )
+     test_dataset = datasets.CIFAR10(
+         root="./data", train=False, download=True, transform=transform
+     )
+     batch_size = 64
+     train_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
+     test_dataloader_with_cfg(train_loader, cfg_dict)
+
+
+ if __name__ == "__main__":
+     main()
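
`test_dataloader_with_cfg` accepts either a YAML path or a plain dict. The sketch below is a hypothetical config assembled from the keys the function reads (`general.batch_limit`, `general.to_csv`, `general.log_batch_info`, and one `search_space` entry per searched item); the authoritative schema is the bundled `torchloader_search.yaml`, and all values here are illustrative only.

```python
# Hypothetical config mirroring the YAML schema this module reads;
# key names come from the code above, values are placeholders.
cfg_dict = {
    "general": {
        "batch_limit": 20,        # batches timed per combination
        "log_batch_info": False,  # per-batch timing logs
        "to_csv": {"enabled": True, "log_dir": "./zout", "filename": "loader_search"},
    },
    "search_space": {
        "batch_size": {"mode": "list", "list": [32, 64, 128]},
        "num_workers": {"mode": "range", "range": [2, 9, 2]},  # start, stop, step -> 2, 4, 6, 8
        "persistent_workers": [True, False],  # a bare list is used as-is
        "pin_memory": True,                   # a scalar becomes a one-element list
    },
}
# train_loader: any existing DataLoader whose dataset you want to profile
test_dataloader_with_cfg(train_loader, cfg_dict)
```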
@@ -0,0 +1,116 @@
+ import glob
+ import os
+ import subprocess
+ import argparse
+ import wandb
+ from tqdm import tqdm
+ from rich.console import Console
+
+ console = Console()
+
+
+ def sync_runs(outdir):
+     outdir = os.path.abspath(outdir)
+     assert os.path.exists(outdir), f"Output directory {outdir} does not exist."
+     sub_dirs = [name for name in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, name))]
+     assert len(sub_dirs) > 0, f"No subdirectories found in {outdir}."
+     console.rule("Parent Directory")
+     console.print(f"[yellow]{outdir}[/yellow]")
+
+     exp_dirs = [os.path.join(outdir, sub_dir) for sub_dir in sub_dirs]
+     wandb_dirs = []
+     for exp_dir in exp_dirs:
+         wandb_dirs.extend(glob.glob(f"{exp_dir}/wandb/*run-*"))
+     if len(wandb_dirs) == 0:
+         console.print(f"No wandb runs found in {outdir}.")
+         return
+     else:
+         console.print(f"Found [bold]{len(wandb_dirs)}[/bold] wandb runs in {outdir}.")
+     for i, wandb_dir in enumerate(wandb_dirs):
+         console.rule(f"Syncing wandb run {i + 1}/{len(wandb_dirs)}")
+         console.print(f"Syncing: {wandb_dir}")
+         process = subprocess.Popen(
+             ["wandb", "sync", wandb_dir],
+             stdout=subprocess.PIPE,
+             stderr=subprocess.STDOUT,
+             text=True,
+         )
+
+         for line in process.stdout:
+             console.print(line.strip())
+             if " ERROR Error while calling W&B API" in line:
+                 break
+         process.stdout.close()
+         process.wait()
+         if process.returncode != 0:
+             console.print(f"[red]Error syncing {wandb_dir}. Return code: {process.returncode}[/red]")
+         else:
+             console.print(f"Successfully synced {wandb_dir}.")
+
+
+ def delete_runs(project, pattern=None):
+     console.rule("Delete W&B Runs")
+     confirm_msg = "Are you sure you want to delete all runs in"
+     confirm_msg += f" \n\tproject: [red]{project}[/red]"
+     if pattern:
+         confirm_msg += f"\n\tpattern: [blue]{pattern}[/blue]"
+
+     console.print(confirm_msg)
+     confirmation = input("This action cannot be undone. [y/N]: ").strip().lower()
+     if confirmation != "y":
+         print("Cancelled.")
+         return
+
+     print("Confirmed. Proceeding...")
+     api = wandb.Api()
+     runs = api.runs(project)
+
+     deleted = 0
+     console.rule("Deleting W&B Runs")
+     if len(runs) == 0:
+         print("No runs found in the project.")
+         return
+     for run in tqdm(runs):
+         if pattern is None or pattern in run.name:
+             run.delete()
+             console.print(f"Deleted run: [red]{run.name}[/red]")
+             deleted += 1
+
+     console.print(f"Total runs deleted: {deleted}")
+
+
+ def validate_args(args):
+     if args.op == "sync":
+         assert os.path.exists(args.outdir), f"Output directory {args.outdir} does not exist."
+     elif args.op == "delete":
+         assert isinstance(args.project, str) and len(args.project.strip()) > 0, "Project name must be a non-empty string."
+     else:
+         raise ValueError(f"Unknown operation: {args.op}")
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Operations on W&B runs")
+     parser.add_argument("-op", "--op", type=str, help="Operation to perform", default="sync", choices=["delete", "sync"])
+     parser.add_argument("-prj", "--project", type=str, default="fire-paper2-2025", help="W&B project name")
+     parser.add_argument("-outdir", "--outdir", type=str, help="Parent directory whose subdirectories contain wandb runs", default="./zout/train")
+     parser.add_argument(
+         "-pt",
+         "--pattern",
+         type=str,
+         default=None,
+         help="Run name pattern to match for deletion",
+     )
+
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     # Validate arguments; stop if invalid
+     validate_args(args)
+
+     op = args.op
+     if op == "sync":
+         sync_runs(args.outdir)
+     elif op == "delete":
+         delete_runs(args.project, args.pattern)
+     else:
+         raise ValueError(f"Unknown operation: {op}")
+
+
+ if __name__ == "__main__":
+     main()
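
The two operations can also be driven from Python rather than the CLI. The import path below is hypothetical, since this file's name is not shown in the diff; only the two function signatures are taken from the code above.

```python
# Hypothetical import path; the diff does not show this file's name.
from halib.utils.wandb_ops import sync_runs, delete_runs

# Sync every offline wandb run found under ./zout/train/<exp>/wandb/*run-*
sync_runs("./zout/train")

# Interactively delete all runs in a project whose name contains "debug"
delete_runs("fire-paper2-2025", pattern="debug")
```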
halib/utils/listop.py ADDED
@@ -0,0 +1,13 @@
+ def subtract(list_a, list_b):
+     return [item for item in list_a if item not in list_b]
+
+
+ def union(list_a, list_b, no_duplicate=False):
+     if no_duplicate:
+         return list(set(list_a) | set(list_b))
+     else:
+         return list_a + list_b
+
+
+ def intersection(list_a, list_b):
+     return list(set(list_a) & set(list_b))
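
Note that `subtract` preserves the order of `list_a`, while `union(no_duplicate=True)` and `intersection` go through `set` and therefore return elements in arbitrary order. A quick illustration:

```python
from halib.utils import listop

listop.subtract([1, 2, 3, 4], [2, 4])           # [1, 3] (order preserved)
listop.union([1, 2], [2, 3])                    # [1, 2, 2, 3] (duplicates kept)
sorted(listop.union([1, 2], [2, 3], True))      # [1, 2, 3] (deduplicated via set)
sorted(listop.intersection([1, 2, 3], [2, 3]))  # [2, 3] (set order is arbitrary)
```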
@@ -0,0 +1,166 @@
+ # Watch a log file and send a Telegram message when training reaches a certain epoch or ends
+
+ import os
+ import yaml
+ import asyncio
+ import telegram
+ import pandas as pd
+
+ from rich.pretty import pprint
+ from rich.console import Console
+ import plotly.graph_objects as go
+
+ from ..system import filesys as fs
+ from ..filetype import textfile, csvfile
+
+ from argparse import ArgumentParser
+
+ tele_console = Console()
+
+
+ def parse_args():
+     parser = ArgumentParser(description="Telegram notifier for training progress")
+     parser.add_argument(
+         "-cfg",
+         "--cfg",
+         type=str,
+         help="yaml file for tele",
+         default=r"E:\Dev\halib\cfg_tele_noti.yaml",
+     )
+
+     return parser.parse_args()
+
+
+ def get_watcher_message_df(target_file, num_last_lines):
+     file_ext = fs.get_file_name(target_file, split_file_ext=True)[1]
+     supported_ext = [".txt", ".log", ".csv"]
+     assert (
+         file_ext in supported_ext
+     ), f"File extension {file_ext} not supported. Supported extensions are {supported_ext}"
+     last_lines_df = None
+     if file_ext in [".txt", ".log"]:
+         lines = textfile.read_line_by_line(target_file)
+         if num_last_lines > len(lines):
+             num_last_lines = len(lines)
+         last_line_arr = lines[-num_last_lines:]
+         # find the most recent line that mentions "epoch"
+         epoch_info = "Epoch: n/a"
+         for line in reversed(lines):
+             if "epoch" in line.lower():
+                 epoch_info = line
+                 break
+         last_line_arr.insert(0, epoch_info)  # insert at the beginning
+         dfCreator = csvfile.DFCreator()
+         dfCreator.create_table("last_lines", ["line"])
+         last_line_arr = [[line] for line in last_line_arr]
+         dfCreator.insert_rows("last_lines", last_line_arr)
+         dfCreator.fill_table_from_row_pool("last_lines")
+         last_lines_df = dfCreator["last_lines"].copy()
+     else:
+         df = pd.read_csv(target_file)
+         num_rows = len(df)
+         if num_last_lines > num_rows:
+             num_last_lines = num_rows
+         last_lines_df = df.tail(num_last_lines)
+     return last_lines_df
+
+
+ def df2img(df: pd.DataFrame, output_img_dir, decimal_places, out_img_scale):
+     df = df.round(decimal_places)
+     fig = go.Figure(
+         data=[
+             go.Table(
+                 header=dict(values=list(df.columns), align="center"),
+                 cells=dict(
+                     values=df.values.transpose(),
+                     fill_color=[["white", "lightgrey"] * df.shape[0]],
+                     align="center",
+                 ),
+             )
+         ]
+     )
+     if not os.path.exists(output_img_dir):
+         os.makedirs(output_img_dir)
+     img_path = os.path.normpath(os.path.join(output_img_dir, "last_lines.png"))
+     fig.write_image(img_path, scale=out_img_scale)
+     return img_path
+
+
+ def compose_message_and_img_path(
+     target_file, project, num_last_lines, decimal_places, out_img_scale, output_img_dir
+ ):
+     context_msg = f">> Project: {project} \n>> File: {target_file} \n>> Last {num_last_lines} lines:"
+     msg_df = get_watcher_message_df(target_file, num_last_lines)
+     try:
+         img_path = df2img(msg_df, output_img_dir, decimal_places, out_img_scale)
+     except Exception as e:
+         pprint(f"Error: {e}")
+         img_path = None
+     return context_msg, img_path
+
+
+ async def send_to_telegram(cfg_dict, interval_in_sec):
+     token = cfg_dict["telegram"]["token"]
+     chat_id = cfg_dict["telegram"]["chat_id"]
+
+     noti_settings = cfg_dict["noti_settings"]
+     project = noti_settings["project"]
+     target_file = noti_settings["target_file"]
+     num_last_lines = noti_settings["num_last_lines"]
+     output_img_dir = noti_settings["output_img_dir"]
+     decimal_places = noti_settings["decimal_places"]
+     out_img_scale = noti_settings["out_img_scale"]
+
+     bot = telegram.Bot(token=token)
+     async with bot:
+         try:
+             context_msg, img_path = compose_message_and_img_path(
+                 target_file,
+                 project,
+                 num_last_lines,
+                 decimal_places,
+                 out_img_scale,
+                 output_img_dir,
+             )
+             time_now = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
+             sep_line = "-" * 50
+             context_msg = f"{sep_line}\n>> Time: {time_now}\n{context_msg}"
+             # calculate the next time a message will be sent
+             next_time = pd.Timestamp.now() + pd.Timedelta(seconds=interval_in_sec)
+             next_time = next_time.strftime("%Y-%m-%d %H:%M:%S")
+             next_time_info = f"Next msg: {next_time}"
+             tele_console.rule()
+             tele_console.print("[green] Send message to telegram [/green]")
+             tele_console.print(
+                 f"[red] Next message will be sent at <{next_time}> [/red]"
+             )
+             await bot.send_message(text=context_msg, chat_id=chat_id)
+             if img_path:
+                 with open(img_path, "rb") as img_file:
+                     await bot.send_photo(chat_id=chat_id, photo=img_file)
+             await bot.send_message(text=next_time_info, chat_id=chat_id)
+         except Exception as e:
+             pprint(f"Error: {e}")
+             pprint("Message not sent to telegram")
+
+
+ async def run_forever(cfg_path):
+     with open(cfg_path, "r") as f:
+         cfg_dict = yaml.safe_load(f)
+     noti_settings = cfg_dict["noti_settings"]
+     interval_in_min = noti_settings["interval_in_min"]
+     interval_in_sec = int(interval_in_min * 60)
+     pprint(
+         f"Message will be sent every {interval_in_min} minutes or {interval_in_sec} seconds"
+     )
+     while True:
+         await send_to_telegram(cfg_dict, interval_in_sec)
+         await asyncio.sleep(interval_in_sec)
+
+
+ async def main():
+     args = parse_args()
+     await run_forever(args.cfg)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
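
The watcher reads its settings from a YAML file. The dict below sketches the schema implied by the keys `send_to_telegram` and `run_forever` look up, shown as the structure `yaml.safe_load` would produce; every value is a placeholder (the token and chat id in particular are illustrative, not real).

```python
# Hypothetical cfg_tele_noti.yaml content, as a Python dict;
# key names come from the code above, values are placeholders.
cfg_dict = {
    "telegram": {
        "token": "123456:ABC-your-bot-token",  # placeholder bot token
        "chat_id": "123456789",                # placeholder chat id
    },
    "noti_settings": {
        "project": "my-experiment",
        "target_file": "./logs/train_metrics.csv",  # .txt, .log, or .csv
        "num_last_lines": 5,      # lines/rows included in each message
        "output_img_dir": "./zout/tele",
        "decimal_places": 4,      # rounding for the rendered table image
        "out_img_scale": 2,       # plotly write_image scale factor
        "interval_in_min": 30,    # minutes between messages
    },
}
```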