halib-0.2.30-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. halib/__init__.py +94 -0
  2. halib/common/__init__.py +0 -0
  3. halib/common/common.py +326 -0
  4. halib/common/rich_color.py +285 -0
  5. halib/common.py +151 -0
  6. halib/csvfile.py +48 -0
  7. halib/cuda.py +39 -0
  8. halib/dataset.py +209 -0
  9. halib/exp/__init__.py +0 -0
  10. halib/exp/core/__init__.py +0 -0
  11. halib/exp/core/base_config.py +167 -0
  12. halib/exp/core/base_exp.py +147 -0
  13. halib/exp/core/param_gen.py +170 -0
  14. halib/exp/core/wandb_op.py +117 -0
  15. halib/exp/data/__init__.py +0 -0
  16. halib/exp/data/dataclass_util.py +41 -0
  17. halib/exp/data/dataset.py +208 -0
  18. halib/exp/data/torchloader.py +165 -0
  19. halib/exp/perf/__init__.py +0 -0
  20. halib/exp/perf/flop_calc.py +190 -0
  21. halib/exp/perf/gpu_mon.py +58 -0
  22. halib/exp/perf/perfcalc.py +470 -0
  23. halib/exp/perf/perfmetrics.py +137 -0
  24. halib/exp/perf/perftb.py +778 -0
  25. halib/exp/perf/profiler.py +507 -0
  26. halib/exp/viz/__init__.py +0 -0
  27. halib/exp/viz/plot.py +754 -0
  28. halib/filesys.py +117 -0
  29. halib/filetype/__init__.py +0 -0
  30. halib/filetype/csvfile.py +192 -0
  31. halib/filetype/ipynb.py +61 -0
  32. halib/filetype/jsonfile.py +19 -0
  33. halib/filetype/textfile.py +12 -0
  34. halib/filetype/videofile.py +266 -0
  35. halib/filetype/yamlfile.py +87 -0
  36. halib/gdrive.py +179 -0
  37. halib/gdrive_mkdir.py +41 -0
  38. halib/gdrive_test.py +37 -0
  39. halib/jsonfile.py +22 -0
  40. halib/listop.py +13 -0
  41. halib/online/__init__.py +0 -0
  42. halib/online/gdrive.py +229 -0
  43. halib/online/gdrive_mkdir.py +53 -0
  44. halib/online/gdrive_test.py +50 -0
  45. halib/online/projectmake.py +131 -0
  46. halib/online/tele_noti.py +165 -0
  47. halib/plot.py +301 -0
  48. halib/projectmake.py +115 -0
  49. halib/research/__init__.py +0 -0
  50. halib/research/base_config.py +100 -0
  51. halib/research/base_exp.py +157 -0
  52. halib/research/benchquery.py +131 -0
  53. halib/research/core/__init__.py +0 -0
  54. halib/research/core/base_config.py +144 -0
  55. halib/research/core/base_exp.py +157 -0
  56. halib/research/core/param_gen.py +108 -0
  57. halib/research/core/wandb_op.py +117 -0
  58. halib/research/data/__init__.py +0 -0
  59. halib/research/data/dataclass_util.py +41 -0
  60. halib/research/data/dataset.py +208 -0
  61. halib/research/data/torchloader.py +165 -0
  62. halib/research/dataset.py +208 -0
  63. halib/research/flop_csv.py +34 -0
  64. halib/research/flops.py +156 -0
  65. halib/research/metrics.py +137 -0
  66. halib/research/mics.py +74 -0
  67. halib/research/params_gen.py +108 -0
  68. halib/research/perf/__init__.py +0 -0
  69. halib/research/perf/flop_calc.py +190 -0
  70. halib/research/perf/gpu_mon.py +58 -0
  71. halib/research/perf/perfcalc.py +363 -0
  72. halib/research/perf/perfmetrics.py +137 -0
  73. halib/research/perf/perftb.py +778 -0
  74. halib/research/perf/profiler.py +301 -0
  75. halib/research/perfcalc.py +361 -0
  76. halib/research/perftb.py +780 -0
  77. halib/research/plot.py +758 -0
  78. halib/research/profiler.py +300 -0
  79. halib/research/torchloader.py +162 -0
  80. halib/research/viz/__init__.py +0 -0
  81. halib/research/viz/plot.py +754 -0
  82. halib/research/wandb_op.py +116 -0
  83. halib/rich_color.py +285 -0
  84. halib/sys/__init__.py +0 -0
  85. halib/sys/cmd.py +8 -0
  86. halib/sys/filesys.py +124 -0
  87. halib/system/__init__.py +0 -0
  88. halib/system/_list_pc.csv +6 -0
  89. halib/system/cmd.py +8 -0
  90. halib/system/filesys.py +164 -0
  91. halib/system/path.py +106 -0
  92. halib/tele_noti.py +166 -0
  93. halib/textfile.py +13 -0
  94. halib/torchloader.py +162 -0
  95. halib/utils/__init__.py +0 -0
  96. halib/utils/dataclass_util.py +40 -0
  97. halib/utils/dict.py +317 -0
  98. halib/utils/dict_op.py +9 -0
  99. halib/utils/gpu_mon.py +58 -0
  100. halib/utils/list.py +17 -0
  101. halib/utils/listop.py +13 -0
  102. halib/utils/slack.py +86 -0
  103. halib/utils/tele_noti.py +166 -0
  104. halib/utils/video.py +82 -0
  105. halib/videofile.py +139 -0
  106. halib-0.2.30.dist-info/METADATA +237 -0
  107. halib-0.2.30.dist-info/RECORD +110 -0
  108. halib-0.2.30.dist-info/WHEEL +5 -0
  109. halib-0.2.30.dist-info/licenses/LICENSE.txt +17 -0
  110. halib-0.2.30.dist-info/top_level.txt +1 -0
halib/system/filesys.py ADDED
@@ -0,0 +1,164 @@
+ import glob
+ import os
+ import shutil
+ from distutils.dir_util import copy_tree  # unused here; note distutils is removed in Python 3.12
+ from concurrent.futures import ThreadPoolExecutor
+
+ COMMON_IMG_EXT = ["jpg", "jpeg", "png", "bmp", "tiff", "gif"]
+ COMMON_VIDEO_EXT = ["mp4", "avi", "mov", "mkv", "flv", "wmv"]
+
+ def is_exist(path):
+     return os.path.exists(path)
+
+
+ def is_dir(path):
+     return os.path.isdir(path)
+
+
+ def get_current_dir():
+     return os.getcwd()
+
+
+ def change_current_dir(new_dir):
+     if is_dir(new_dir):
+         os.chdir(new_dir)
+
+
+ def get_dir_name(directory):
+     return os.path.basename(os.path.normpath(directory))
+
+
+ def get_parent_dir(directory, return_full_path=False):
+     if not return_full_path:
+         return os.path.basename(os.path.dirname(directory))
+     else:
+         return os.path.dirname(directory)
+
+
+ def make_dir(directory):
+     if not os.path.exists(directory):
+         os.makedirs(directory)
+
+
+ def copy_dir(src_dir, dst_dir, dirs_exist_ok=True, ignore_patterns=None):
+     shutil.copytree(
+         src_dir, dst_dir, dirs_exist_ok=dirs_exist_ok, ignore=ignore_patterns
+     )
+
+
+ def delete_dir(directory):
+     shutil.rmtree(directory)
+
+
+ def list_dirs(directory):
+     folders = list(
+         filter(
+             lambda x: os.path.isdir(os.path.join(directory, x)), os.listdir(directory)
+         )
+     )
+     return folders
+
+
+ def list_files(directory):
+     files = list(
+         filter(
+             lambda x: os.path.isfile(os.path.join(directory, x)), os.listdir(directory)
+         )
+     )
+     return files
+
+ def filter_files_by_extension(directory, ext=None, recursive=True, num_workers=0):
+     """
+     Filters files using glob and multithreading.
+     If ext is None, returns ALL files.
+
+     Args:
+         directory (str): Path to search.
+         ext (str, list, or None): Extension(s) to find. If None, return all files.
+         recursive (bool): Whether to search subdirectories.
+         num_workers (int): Number of threads for checking file existence.
+     """
+     assert os.path.exists(directory) and os.path.isdir(
+         directory
+     ), "Directory does not exist"
+
+     # 1. Normalize extensions to a tuple (only if ext is provided)
+     extensions = None
+     if ext is not None:
+         if isinstance(ext, list):
+             extensions = tuple(ext)
+         else:
+             extensions = (ext,)
+
+     # 2. Define pattern
+     pattern = (
+         os.path.join(directory, "**", "*")
+         if recursive
+         else os.path.join(directory, "*")
+     )
+
+     # 3. Helper function for the thread workers
+     def validate_file(path):
+         if os.path.isfile(path):
+             return path
+         return None
+
+     result_files = []
+     if num_workers <= 0:
+         num_workers = os.cpu_count() or 4
+     # 4. Initialize ThreadPool with user-defined workers
+     with ThreadPoolExecutor(max_workers=num_workers) as executor:
+         # Step A: Get the iterator (lazy evaluation)
+         all_paths = glob.iglob(pattern, recursive=recursive)
+
+         # Step B: Apply filter
+         if extensions is None:
+             # If ext is None, we skip the endswith check and pass everything
+             candidate_paths = all_paths
+         else:
+             # Filter by extension string FIRST (fast CPU op)
+             candidate_paths = (p for p in all_paths if p.endswith(extensions))
+
+         # Step C: Parallelize the disk check (slow I/O op)
+         for result in executor.map(validate_file, candidate_paths):
+             if result:
+                 result_files.append(result)
+
+     return result_files
+
+
+ def is_file(path):
+     return os.path.isfile(path)
+
+
+ def get_file_name(file_path, split_file_ext=False):
+     if is_file(file_path):
+         if split_file_ext:
+             filename, file_extension = os.path.splitext(os.path.basename(file_path))
+             return filename, file_extension
+         else:
+             return os.path.basename(file_path)
+     else:
+         raise OSError("Not a file")
+
+
+ def get_absolute_path(file_path):
+     return os.path.abspath(file_path)
+
+
+ # dest can be a directory
+ def copy_file(source, dest):
+     shutil.copy2(source, dest)
+
+
+ def delete_file(path):
+     if is_file(path):
+         os.remove(path)
+
+
+ def rename_dir_or_file(old, new):
+     os.renames(old, new)
+
+
+ def move_dir_or_file(source, destination):
+     shutil.move(source, destination)
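
For reference, a minimal usage sketch of filter_files_by_extension from the module above (the "./data" directory and the import alias are assumptions for illustration):

    from halib.system import filesys as fs

    # Recursively collect jpg/png files under ./data; with num_workers=0 the
    # function falls back to os.cpu_count() threads for the isfile() checks.
    images = fs.filter_files_by_extension("./data", ext=["jpg", "png"], recursive=True)

    # With ext=None, every file under the directory is returned.
    all_files = fs.filter_files_by_extension("./data", ext=None)

Note that extensions are matched with str.endswith, so "jpg" would also match a file named "photo.myjpg"; passing dotted extensions such as ".jpg" avoids that.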
halib/system/path.py ADDED
@@ -0,0 +1,106 @@
+ from ..common.common import *
+ from ..filetype import csvfile
+ import pandas as pd
+ import platform
+ import re  # <--- [FIX 1] Added missing import
+ import csv
+ from importlib import resources
+
+ PC_TO_ABBR = {}
+ ABBR_DISK_MAP = {}
+ pc_df = None
+ cPlatform = platform.system().lower()
+
+
+ def load_pc_meta_info():
+     # 1. Define the package where the file lives (dotted notation).
+     #    Since the file is in 'halib/system/', the package is 'halib.system'.
+     package_name = "halib.system"
+     file_name = "_list_pc.csv"
+
+     # 2. Locate the file
+     csv_path = resources.files(package_name).joinpath(file_name)
+     global PC_TO_ABBR, ABBR_DISK_MAP, pc_df
+     pc_df = pd.read_csv(csv_path, sep=';', encoding='utf-8')  # ty:ignore[no-matching-overload]
+     PC_TO_ABBR = dict(zip(pc_df['pc_name'], pc_df['abbr']))
+     ABBR_DISK_MAP = dict(zip(pc_df['abbr'], pc_df['working_disk']))
+     # pprint("Loaded PC meta info:")
+     # pprint(ABBR_DISK_MAP)
+     # pprint(PC_TO_ABBR)
+ # ! must be called at module load time
+ load_pc_meta_info()
+
+
+ def list_PCs(show=True):
+     global pc_df
+     if show:
+         csvfile.fn_display_df(pc_df)
+     return pc_df
+
+
+ def get_PC_name():
+     return platform.node()
+
+
+ def get_PC_abbr_name():
+     pc_name = get_PC_name()
+     return PC_TO_ABBR.get(pc_name, "Unknown")
+
+
+ def get_os_platform():
+     return platform.system().lower()
+
+
+ def get_working_disk(abbr_disk_map=ABBR_DISK_MAP):
+     pc_abbr = get_PC_abbr_name()
+     return abbr_disk_map.get(pc_abbr, None)
+
+ cDisk = get_working_disk()
+
+ # ! This function searches for full paths in obj and normalizes them according to the current platform and working disk.
+ # ! E.g.: "E:/zdataset/DFire" with working_disk "D:" and current_platform "windows" => "D:/zdataset/DFire"
+ # ! E.g.: "E:/zdataset/DFire" with working_disk "D:" and current_platform "linux" => "/mnt/d/zdataset/DFire"
+ def normalize_paths(obj, working_disk=cDisk, current_platform=cPlatform):
+     # [FIX 3] Resolve defaults inside the function to be safer/cleaner
+     if working_disk is None:
+         working_disk = get_working_disk()
+     if current_platform is None:
+         current_platform = get_os_platform()
+
+     # [FIX 2] If the PC is unknown, working_disk is None. Return early to avoid a crash.
+     if working_disk is None:
+         return obj
+
+     if isinstance(obj, dict):
+         for key, value in obj.items():
+             obj[key] = normalize_paths(value, working_disk, current_platform)
+         return obj
+     elif isinstance(obj, list):
+         for i, item in enumerate(obj):
+             obj[i] = normalize_paths(item, working_disk, current_platform)
+         return obj
+     elif isinstance(obj, str):
+         # Normalize backslashes to forward slashes for consistency
+         obj = obj.replace("\\", "/")
+
+         # Regex for a Windows-style path, e.g., "E:/zdataset/DFire"
+         win_match = re.match(r"^([A-Z]):/(.*)$", obj)
+         # Regex for a Linux-style path, e.g., "/mnt/e/zdataset/DFire"
+         lin_match = re.match(r"^/mnt/([a-z])/(.*)$", obj)
+
+         if win_match or lin_match:
+             rest = win_match.group(2) if win_match else lin_match.group(2)
+
+             if current_platform == "windows":
+                 # working_disk is like "D:", so "D:/" + rest
+                 new_path = f"{working_disk}/{rest}"
+             elif current_platform == "linux":
+                 # Extract the drive letter from working_disk (e.g., "D:" -> "d")
+                 drive_letter = working_disk[0].lower()
+                 new_path = f"/mnt/{drive_letter}/{rest}"
+             else:
+                 return obj
+             return new_path
+
+     # For non-strings or non-path strings, return as is
+     return obj
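
A minimal sketch of how normalize_paths is meant to be used, assuming _list_pc.csv maps the current PC to working disk "D:" (the config values below are invented for illustration):

    from halib.system import path as hpath

    cfg = {
        "data_dir": "E:/zdataset/DFire",
        "ckpts": ["E:/runs/best.pt", "relative/path.txt"],
    }
    # On a Windows PC with working disk "D:", drive-letter paths are rewritten
    # to "D:/...". On Linux the same entries become "/mnt/d/...". Relative
    # paths and non-path strings are returned unchanged.
    cfg = hpath.normalize_paths(cfg)

The function mutates dicts and lists in place and also returns them, so the reassignment is optional.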
halib/tele_noti.py ADDED
@@ -0,0 +1,166 @@
+ # Watch a log file and send a telegram message when training reaches a certain epoch or ends
+
+ import os
+ import yaml
+ import asyncio
+ import telegram
+ import pandas as pd
+
+ from rich.pretty import pprint
+ from rich.console import Console
+ import plotly.graph_objects as go
+
+ from .system import filesys as fs
+ from .filetype import textfile, csvfile
+
+ from argparse import ArgumentParser
+
+ tele_console = Console()
+
+
+ def parse_args():
+     parser = ArgumentParser(description="desc text")
+     parser.add_argument(
+         "-cfg",
+         "--cfg",
+         type=str,
+         help="yaml file for tele",
+         default=r"E:\Dev\halib\cfg_tele_noti.yaml",
+     )
+
+     return parser.parse_args()
+
+
+ def get_watcher_message_df(target_file, num_last_lines):
+     file_ext = fs.get_file_name(target_file, split_file_ext=True)[1]
+     supported_ext = [".txt", ".log", ".csv"]
+     assert (
+         file_ext in supported_ext
+     ), f"File extension {file_ext} not supported. Supported extensions are {supported_ext}"
+     last_lines_df = None
+     if file_ext in [".txt", ".log"]:
+         lines = textfile.read_line_by_line(target_file)
+         if num_last_lines > len(lines):
+             num_last_lines = len(lines)
+         last_line_arr = lines[-num_last_lines:]
+         # add the most recent line that contains the word "epoch"
+         epoch_info_list = "Epoch: n/a"
+         for line in reversed(lines):
+             if "epoch" in line.lower():
+                 epoch_info_list = line
+                 break
+         last_line_arr.insert(0, epoch_info_list)  # insert at the beginning
+         dfCreator = csvfile.DFCreator()
+         dfCreator.create_table("last_lines", ["line"])
+         last_line_arr = [[line] for line in last_line_arr]
+         dfCreator.insert_rows("last_lines", last_line_arr)
+         dfCreator.fill_table_from_row_pool("last_lines")
+         last_lines_df = dfCreator["last_lines"].copy()
+     else:
+         df = pd.read_csv(target_file)
+         num_rows = len(df)
+         if num_last_lines > num_rows:
+             num_last_lines = num_rows
+         last_lines_df = df.tail(num_last_lines)
+     return last_lines_df
+
+
+ def df2img(df: pd.DataFrame, output_img_dir, decimal_places, out_img_scale):
+     df = df.round(decimal_places)
+     fig = go.Figure(
+         data=[
+             go.Table(
+                 header=dict(values=list(df.columns), align="center"),
+                 cells=dict(
+                     values=df.values.transpose(),
+                     fill_color=[["white", "lightgrey"] * df.shape[0]],
+                     align="center",
+                 ),
+             )
+         ]
+     )
+     if not os.path.exists(output_img_dir):
+         os.makedirs(output_img_dir)
+     img_path = os.path.normpath(os.path.join(output_img_dir, "last_lines.png"))
+     fig.write_image(img_path, scale=out_img_scale)
+     return img_path
+
+
+ def compose_message_and_img_path(
+     target_file, project, num_last_lines, decimal_places, out_img_scale, output_img_dir
+ ):
+     context_msg = f">> Project: {project} \n>> File: {target_file} \n>> Last {num_last_lines} lines:"
+     msg_df = get_watcher_message_df(target_file, num_last_lines)
+     try:
+         img_path = df2img(msg_df, output_img_dir, decimal_places, out_img_scale)
+     except Exception as e:
+         pprint(f"Error: {e}")
+         img_path = None
+     return context_msg, img_path
+
+
+ async def send_to_telegram(cfg_dict, interval_in_sec):
+     # pprint(cfg_dict)
+     token = cfg_dict["telegram"]["token"]
+     chat_id = cfg_dict["telegram"]["chat_id"]
+
+     noti_settings = cfg_dict["noti_settings"]
+     project = noti_settings["project"]
+     target_file = noti_settings["target_file"]
+     num_last_lines = noti_settings["num_last_lines"]
+     output_img_dir = noti_settings["output_img_dir"]
+     decimal_places = noti_settings["decimal_places"]
+     out_img_scale = noti_settings["out_img_scale"]
+
+     bot = telegram.Bot(token=token)
+     async with bot:
+         try:
+             context_msg, img_path = compose_message_and_img_path(
+                 target_file,
+                 project,
+                 num_last_lines,
+                 decimal_places,
+                 out_img_scale,
+                 output_img_dir,
+             )
+             time_now = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
+             sep_line = "-" * 50
+             context_msg = f"{sep_line}\n>> Time: {time_now}\n{context_msg}"
+             # calculate the next time a message will be sent
+             next_time = pd.Timestamp.now() + pd.Timedelta(seconds=interval_in_sec)
+             next_time = next_time.strftime("%Y-%m-%d %H:%M:%S")
+             next_time_info = f"Next msg: {next_time}"
+             tele_console.rule()
+             tele_console.print("[green] Send message to telegram [/green]")
+             tele_console.print(
+                 f"[red] Next message will be sent at <{next_time}> [/red]"
+             )
+             await bot.send_message(text=context_msg, chat_id=chat_id)
+             if img_path:
+                 await bot.send_photo(chat_id=chat_id, photo=open(img_path, "rb"))
+             await bot.send_message(text=next_time_info, chat_id=chat_id)
+         except Exception as e:
+             pprint(f"Error: {e}")
+             pprint("Message not sent to telegram")
+
+
+ async def run_forever(cfg_path):
+     cfg_dict = yaml.safe_load(open(cfg_path, "r"))
+     noti_settings = cfg_dict["noti_settings"]
+     interval_in_min = noti_settings["interval_in_min"]
+     interval_in_sec = int(interval_in_min * 60)
+     pprint(
+         f"Message will be sent every {interval_in_min} minutes or {interval_in_sec} seconds"
+     )
+     while True:
+         await send_to_telegram(cfg_dict, interval_in_sec)
+         await asyncio.sleep(interval_in_sec)
+
+
+ async def main():
+     args = parse_args()
+     await run_forever(args.cfg)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
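
The script reads its settings from the YAML file given by --cfg. A sketch of the expected structure, shown as the equivalent Python dict since send_to_telegram consumes the parsed dict directly; the key names come from the code above, the values are placeholders:

    cfg_dict = {
        "telegram": {
            "token": "<bot-token>",
            "chat_id": "<chat-id>",
        },
        "noti_settings": {
            "project": "my-project",
            "target_file": "train.log",  # must be a .txt, .log, or .csv file
            "num_last_lines": 10,        # lines/rows included per message
            "output_img_dir": "tele_imgs",
            "decimal_places": 4,
            "out_img_scale": 2,
            "interval_in_min": 30,       # polling interval used by run_forever()
        },
    }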
halib/textfile.py ADDED
@@ -0,0 +1,13 @@
+ def read_line_by_line(file_path):
+     with open(file_path, 'r') as file:
+         lines = file.readlines()
+         lines = [line.rstrip() for line in lines]
+         return lines
+
+
+ def write(lines, outfile, append=False):
+     mode = 'a' if append else 'w'
+     with open(outfile, mode, encoding='utf-8') as f:
+         for line in lines:
+             f.write(line)
+             f.write('\n')
halib/torchloader.py ADDED
@@ -0,0 +1,162 @@
+ """
+ * @author Hoang Van-Ha
+ * @email hoangvanhauit@gmail.com
+ * @create date 2024-03-27 15:40:22
+ * @modify date 2024-03-27 15:40:22
+ * @desc this module works as a utility tool for finding the best dataloader configuration (num_workers, batch_size, pin_memory, etc.) that fits your hardware.
+ """
+ from argparse import ArgumentParser
+ from .common import *
+ from .filetype import csvfile
+ from .filetype.yamlfile import load_yaml
+ from rich import inspect
+ from torch.utils.data import DataLoader
+ from torchvision import datasets, transforms
+ from tqdm import tqdm
+ from typing import Union
+ import itertools as it  # for cartesian product
+ import os
+ import time
+ import traceback
+
+
+ def parse_args():
+     parser = ArgumentParser(description="desc text")
+     parser.add_argument("-cfg", "--cfg", type=str, help="cfg file for searching")
+     return parser.parse_args()
+
+
+ def get_test_range(cfg: dict, search_item="num_workers"):
+     item_search_cfg = cfg["search_space"].get(search_item, None)
+     if item_search_cfg is None:
+         raise ValueError(f"search_item: {search_item} not found in cfg")
+     if isinstance(item_search_cfg, list):
+         return item_search_cfg
+     elif isinstance(item_search_cfg, dict):
+         if "mode" in item_search_cfg:
+             mode = item_search_cfg["mode"]
+             assert mode in ["range", "list"], f"mode: {mode} not supported"
+             value_in_mode = item_search_cfg.get(mode, None)
+             if value_in_mode is None:
+                 raise ValueError(f"mode<{mode}>: data not found in <{search_item}>")
+             if mode == "range":
+                 assert len(value_in_mode) == 3, "range must have 3 values: start, stop, step"
+                 start = value_in_mode[0]
+                 stop = value_in_mode[1]
+                 step = value_in_mode[2]
+                 return list(range(start, stop, step))
+             elif mode == "list":
+                 return item_search_cfg["list"]
+     else:
+         return [item_search_cfg]  # for int, float, str, bool, etc.
+
+
+ def load_an_batch(loader_iter):
+     start = time.time()
+     next(loader_iter)
+     end = time.time()
+     return end - start
+
+
+ def test_dataloader_with_cfg(origin_dataloader: DataLoader, cfg: Union[dict, str]):
+     try:
+         if isinstance(cfg, str):
+             cfg = load_yaml(cfg, to_dict=True)
+         dfmk = csvfile.DFCreator()
+         search_items = ["batch_size", "num_workers", "persistent_workers", "pin_memory"]
+         batch_limit = cfg["general"]["batch_limit"]
+         csv_cfg = cfg["general"]["to_csv"]
+         log_batch_info = cfg["general"]["log_batch_info"]
+
+         save_to_csv = csv_cfg["enabled"]
+         log_dir = csv_cfg["log_dir"]
+         filename = csv_cfg["filename"]
+         filename = f"{now_str()}_{filename}.csv"
+         outfile = os.path.join(log_dir, filename)
+
+         dfmk.create_table(
+             "cfg_search",
+             (search_items + ["avg_time_taken"]),
+         )
+         ls_range_test = []
+         for item in search_items:
+             range_test = get_test_range(cfg, search_item=item)
+             range_test = [(item, i) for i in range_test]
+             ls_range_test.append(range_test)
+
+         all_combinations = list(it.product(*ls_range_test))
+
+         rows = []
+         for cfg_idx, combine in enumerate(all_combinations):
+             console.rule(f"Testing cfg {cfg_idx+1}/{len(all_combinations)}")
+             inspect(combine)
+             batch_size = combine[search_items.index("batch_size")][1]
+             num_workers = combine[search_items.index("num_workers")][1]
+             persistent_workers = combine[search_items.index("persistent_workers")][1]
+             pin_memory = combine[search_items.index("pin_memory")][1]
+
+             test_dataloader = DataLoader(origin_dataloader.dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=persistent_workers, pin_memory=pin_memory, shuffle=True)
+             row = [
+                 batch_size,
+                 num_workers,
+                 persistent_workers,
+                 pin_memory,
+                 0.0,
+             ]
+
+             # calculate the avg time taken to load the data for <batch_limit> batches
+             trainiter = iter(test_dataloader)
+             time_elapsed = 0
+             pprint('Start testing...')
+             for i in tqdm(range(batch_limit)):
+                 single_batch_time = load_an_batch(trainiter)
+                 if log_batch_info:
+                     pprint(f"Batch {i+1} took {single_batch_time:.4f} seconds to load")
+                 time_elapsed += single_batch_time
+             row[-1] = time_elapsed / batch_limit
+             rows.append(row)
+         dfmk.insert_rows('cfg_search', rows)
+         dfmk.fill_table_from_row_pool('cfg_search')
+         with ConsoleLog("results"):
+             csvfile.fn_display_df(dfmk['cfg_search'])
+         if save_to_csv:
+             dfmk["cfg_search"].to_csv(outfile, index=False)
+             console.print(f"[red] Data saved to <{outfile}> [/red]")
+
+     except Exception as e:
+         traceback.print_exc()
+         print(e)
+         # get the directory of this python file
+         current_dir = os.path.dirname(os.path.realpath(__file__))
+         standard_cfg_path = os.path.join(current_dir, "torchloader_search.yaml")
+         pprint(
+             f"Make sure you get the right <cfg.yaml> file. An example <cfg.yaml> can be found at: {standard_cfg_path}"
+         )
+         return
+
+ def main():
+     args = parse_args()
+     cfg_yaml = args.cfg
+     cfg_dict = load_yaml(cfg_yaml, to_dict=True)
+
+     # Define transforms for data augmentation and normalization
+     transform = transforms.Compose(
+         [
+             transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
+             transforms.RandomRotation(10),  # Randomly rotate images by 10 degrees
+             transforms.ToTensor(),  # Convert images to PyTorch tensors
+             transforms.Normalize(
+                 (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
+             ),  # Normalize pixel values to [-1, 1]
+         ]
+     )
+     test_dataset = datasets.CIFAR10(
+         root="./data", train=False, download=True, transform=transform
+     )
+     batch_size = 64
+     train_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
+     test_dataloader_with_cfg(train_loader, cfg_dict)
+
+
+ if __name__ == "__main__":
+     main()
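
A sketch of a search config accepted by test_dataloader_with_cfg, written as the equivalent Python dict (a YAML file would mirror this structure; the values are placeholders). Per get_test_range above, each search_space entry may be a plain list, a dict with mode "range" or "list", or a scalar:

    cfg = {
        "general": {
            "batch_limit": 20,        # batches timed per combination
            "log_batch_info": False,
            "to_csv": {"enabled": True, "log_dir": "logs", "filename": "loader_search"},
        },
        "search_space": {
            "batch_size": [32, 64, 128],                          # plain list
            "num_workers": {"mode": "range", "range": [2, 9, 2]}, # -> 2, 4, 6, 8
            "persistent_workers": [True, False],
            "pin_memory": True,                                   # scalar -> single value
        },
    }
    test_dataloader_with_cfg(train_loader, cfg)

The cartesian product of these ranges (3 * 4 * 2 * 1 = 24 combinations here) is then timed for batch_limit batches each.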
halib/utils/dataclass_util.py ADDED
@@ -0,0 +1,40 @@
+ import yaml
+ from typing import Any
+ from rich.pretty import pprint
+ from ..filetype import yamlfile
+ # from halib.filetype import yamlfile
+ from dataclasses import make_dataclass
+
+ def dict_to_dataclass(name: str, data: dict):
+     fields = []
+     values = {}
+
+     for key, value in data.items():
+         if isinstance(value, dict):
+             sub_dc = dict_to_dataclass(key.capitalize(), value)
+             fields.append((key, type(sub_dc)))
+             values[key] = sub_dc
+         else:
+             field_type = type(value) if value is not None else Any
+             fields.append((key, field_type))
+             values[key] = value
+
+     DC = make_dataclass(name.capitalize(), fields)
+     return DC(**values)
+
+ def yaml_to_dataclass(name: str, yaml_str: str):
+     data = yaml.safe_load(yaml_str)
+     return dict_to_dataclass(name, data)
+
+
+ def yamlfile_to_dataclass(name: str, file_path: str):
+     data_dict = yamlfile.load_yaml(file_path, to_dict=True)
+     if "__base__" in data_dict:
+         del data_dict["__base__"]
+     return dict_to_dataclass(name, data_dict)
+
+ if __name__ == "__main__":
+     cfg = yamlfile_to_dataclass("Config", "test/dataclass_util_test_cfg.yaml")
+
+     # ! NOTICE: after printing this dataclass, we can copy the output and paste it into ChatGPT to generate the needed dataclass definitions using `from dataclass_wizard import YAMLWizard`
+     pprint(cfg)
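
A minimal sketch of what dict_to_dataclass produces; nested dicts become nested, dynamically created dataclasses (the field names here are invented for illustration):

    cfg = dict_to_dataclass("config", {"lr": 0.001, "model": {"name": "resnet18", "layers": 18}})
    pprint(cfg)            # Config(lr=0.001, model=Model(name='resnet18', layers=18))
    print(cfg.model.name)  # -> resnet18

Note that the function returns an instance rather than a class, and each call builds fresh dataclass types via make_dataclass.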