halib 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +94 -0
- halib/common/__init__.py +0 -0
- halib/common/common.py +326 -0
- halib/common/rich_color.py +285 -0
- halib/common.py +151 -0
- halib/csvfile.py +48 -0
- halib/cuda.py +39 -0
- halib/dataset.py +209 -0
- halib/exp/__init__.py +0 -0
- halib/exp/core/__init__.py +0 -0
- halib/exp/core/base_config.py +167 -0
- halib/exp/core/base_exp.py +147 -0
- halib/exp/core/param_gen.py +170 -0
- halib/exp/core/wandb_op.py +117 -0
- halib/exp/data/__init__.py +0 -0
- halib/exp/data/dataclass_util.py +41 -0
- halib/exp/data/dataset.py +208 -0
- halib/exp/data/torchloader.py +165 -0
- halib/exp/perf/__init__.py +0 -0
- halib/exp/perf/flop_calc.py +190 -0
- halib/exp/perf/gpu_mon.py +58 -0
- halib/exp/perf/perfcalc.py +470 -0
- halib/exp/perf/perfmetrics.py +137 -0
- halib/exp/perf/perftb.py +778 -0
- halib/exp/perf/profiler.py +507 -0
- halib/exp/viz/__init__.py +0 -0
- halib/exp/viz/plot.py +754 -0
- halib/filesys.py +117 -0
- halib/filetype/__init__.py +0 -0
- halib/filetype/csvfile.py +192 -0
- halib/filetype/ipynb.py +61 -0
- halib/filetype/jsonfile.py +19 -0
- halib/filetype/textfile.py +12 -0
- halib/filetype/videofile.py +266 -0
- halib/filetype/yamlfile.py +87 -0
- halib/gdrive.py +179 -0
- halib/gdrive_mkdir.py +41 -0
- halib/gdrive_test.py +37 -0
- halib/jsonfile.py +22 -0
- halib/listop.py +13 -0
- halib/online/__init__.py +0 -0
- halib/online/gdrive.py +229 -0
- halib/online/gdrive_mkdir.py +53 -0
- halib/online/gdrive_test.py +50 -0
- halib/online/projectmake.py +131 -0
- halib/online/tele_noti.py +165 -0
- halib/plot.py +301 -0
- halib/projectmake.py +115 -0
- halib/research/__init__.py +0 -0
- halib/research/base_config.py +100 -0
- halib/research/base_exp.py +157 -0
- halib/research/benchquery.py +131 -0
- halib/research/core/__init__.py +0 -0
- halib/research/core/base_config.py +144 -0
- halib/research/core/base_exp.py +157 -0
- halib/research/core/param_gen.py +108 -0
- halib/research/core/wandb_op.py +117 -0
- halib/research/data/__init__.py +0 -0
- halib/research/data/dataclass_util.py +41 -0
- halib/research/data/dataset.py +208 -0
- halib/research/data/torchloader.py +165 -0
- halib/research/dataset.py +208 -0
- halib/research/flop_csv.py +34 -0
- halib/research/flops.py +156 -0
- halib/research/metrics.py +137 -0
- halib/research/mics.py +74 -0
- halib/research/params_gen.py +108 -0
- halib/research/perf/__init__.py +0 -0
- halib/research/perf/flop_calc.py +190 -0
- halib/research/perf/gpu_mon.py +58 -0
- halib/research/perf/perfcalc.py +363 -0
- halib/research/perf/perfmetrics.py +137 -0
- halib/research/perf/perftb.py +778 -0
- halib/research/perf/profiler.py +301 -0
- halib/research/perfcalc.py +361 -0
- halib/research/perftb.py +780 -0
- halib/research/plot.py +758 -0
- halib/research/profiler.py +300 -0
- halib/research/torchloader.py +162 -0
- halib/research/viz/__init__.py +0 -0
- halib/research/viz/plot.py +754 -0
- halib/research/wandb_op.py +116 -0
- halib/rich_color.py +285 -0
- halib/sys/__init__.py +0 -0
- halib/sys/cmd.py +8 -0
- halib/sys/filesys.py +124 -0
- halib/system/__init__.py +0 -0
- halib/system/_list_pc.csv +6 -0
- halib/system/cmd.py +8 -0
- halib/system/filesys.py +164 -0
- halib/system/path.py +106 -0
- halib/tele_noti.py +166 -0
- halib/textfile.py +13 -0
- halib/torchloader.py +162 -0
- halib/utils/__init__.py +0 -0
- halib/utils/dataclass_util.py +40 -0
- halib/utils/dict.py +317 -0
- halib/utils/dict_op.py +9 -0
- halib/utils/gpu_mon.py +58 -0
- halib/utils/list.py +17 -0
- halib/utils/listop.py +13 -0
- halib/utils/slack.py +86 -0
- halib/utils/tele_noti.py +166 -0
- halib/utils/video.py +82 -0
- halib/videofile.py +139 -0
- halib-0.2.30.dist-info/METADATA +237 -0
- halib-0.2.30.dist-info/RECORD +110 -0
- halib-0.2.30.dist-info/WHEEL +5 -0
- halib-0.2.30.dist-info/licenses/LICENSE.txt +17 -0
- halib-0.2.30.dist-info/top_level.txt +1 -0
halib/system/filesys.py
ADDED
@@ -0,0 +1,164 @@
import glob
import os
import shutil
from concurrent.futures import ThreadPoolExecutor

COMMON_IMG_EXT = ["jpg", "jpeg", "png", "bmp", "tiff", "gif"]
COMMON_VIDEO_EXT = ["mp4", "avi", "mov", "mkv", "flv", "wmv"]

def is_exist(path):
    return os.path.exists(path)


def is_dir(path):
    return os.path.isdir(path)


def get_current_dir():
    return os.getcwd()


def change_current_dir(new_dir):
    if is_dir(new_dir):
        os.chdir(new_dir)


def get_dir_name(directory):
    return os.path.basename(os.path.normpath(directory))


def get_parent_dir(directory, return_full_path=False):
    if not return_full_path:
        return os.path.basename(os.path.dirname(directory))
    else:
        return os.path.dirname(directory)


def make_dir(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)


def copy_dir(src_dir, dst_dir, dirs_exist_ok=True, ignore_patterns=None):
    shutil.copytree(
        src_dir, dst_dir, dirs_exist_ok=dirs_exist_ok, ignore=ignore_patterns
    )


def delete_dir(directory):
    shutil.rmtree(directory)


def list_dirs(directory):
    folders = list(
        filter(
            lambda x: os.path.isdir(os.path.join(directory, x)), os.listdir(directory)
        )
    )
    return folders


def list_files(directory):
    files = list(
        filter(
            lambda x: os.path.isfile(os.path.join(directory, x)), os.listdir(directory)
        )
    )
    return files

def filter_files_by_extension(directory, ext=None, recursive=True, num_workers=0):
    """
    Filters files using glob and multithreading.
    If ext is None, returns ALL files.

    Args:
        directory (str): Path to search.
        ext (str, list, or None): Extension(s) to find. If None, return all files.
        recursive (bool): Whether to search subdirectories.
        num_workers (int): Number of threads for checking file existence.
    """
    assert os.path.exists(directory) and os.path.isdir(
        directory
    ), "Directory does not exist"

    # 1. Normalize extensions to a tuple (only if ext is provided)
    extensions = None
    if ext is not None:
        if isinstance(ext, list):
            extensions = tuple(ext)
        else:
            extensions = (ext,)

    # 2. Define pattern
    pattern = (
        os.path.join(directory, "**", "*")
        if recursive
        else os.path.join(directory, "*")
    )

    # 3. Helper function for the thread workers
    def validate_file(path):
        if os.path.isfile(path):
            return path
        return None

    result_files = []
    if num_workers <= 0:
        num_workers = os.cpu_count() or 4
    # 4. Initialize ThreadPool with user-defined workers
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Step A: Get the iterator (lazy evaluation)
        all_paths = glob.iglob(pattern, recursive=recursive)

        # Step B: Apply filter
        if extensions is None:
            # If ext is None, we skip the endswith check and pass everything
            candidate_paths = all_paths
        else:
            # Filter by extension string FIRST (fast CPU op)
            candidate_paths = (p for p in all_paths if p.endswith(extensions))

        # Step C: Parallelize the disk check (slow I/O op)
        for result in executor.map(validate_file, candidate_paths):
            if result:
                result_files.append(result)

    return result_files


def is_file(path):
    return os.path.isfile(path)


def get_file_name(file_path, split_file_ext=False):
    if is_file(file_path):
        if split_file_ext:
            filename, file_extension = os.path.splitext(os.path.basename(file_path))
            return filename, file_extension
        else:
            return os.path.basename(file_path)
    else:
        raise OSError("Not a file")


def get_absolute_path(file_path):
    return os.path.abspath(file_path)


# dest can be a directory
def copy_file(source, dest):
    shutil.copy2(source, dest)


def delete_file(path):
    if is_file(path):
        os.remove(path)


def rename_dir_or_file(old, new):
    os.renames(old, new)


def move_dir_or_file(source, destination):
    shutil.move(source, destination)
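
For reference, a minimal usage sketch of filter_files_by_extension above (the dataset path is hypothetical). Note that matching uses str.endswith, so passing extensions with a leading dot (".jpg" rather than "jpg") avoids accidental matches like "xjpg":

    from halib.system import filesys as fs

    # recursively collect image files under a dataset root with 8 worker threads
    imgs = fs.filter_files_by_extension(
        "datasets/DFire",                          # hypothetical directory
        ext=[f".{e}" for e in fs.COMMON_IMG_EXT],  # ".jpg", ".jpeg", ".png", ...
        recursive=True,
        num_workers=8,
    )
    print(len(imgs), "image files found")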
halib/system/path.py
ADDED
@@ -0,0 +1,106 @@
from ..common.common import *
from ..filetype import csvfile
import pandas as pd
import platform
import re  # <--- [FIX 1] Added missing import
import csv
from importlib import resources

PC_TO_ABBR = {}
ABBR_DISK_MAP = {}
pc_df = None
cPlatform = platform.system().lower()


def load_pc_meta_info():
    # 1. Define the package where the file lives (dotted notation)
    # Since the file is in 'halib/system/', the package is 'halib.system'
    package_name = "halib.system"
    file_name = "_list_pc.csv"

    # 2. Locate the file
    csv_path = resources.files(package_name).joinpath(file_name)
    global PC_TO_ABBR, ABBR_DISK_MAP, pc_df
    pc_df = pd.read_csv(csv_path, sep=';', encoding='utf-8')  # ty:ignore[no-matching-overload]
    PC_TO_ABBR = dict(zip(pc_df['pc_name'], pc_df['abbr']))
    ABBR_DISK_MAP = dict(zip(pc_df['abbr'], pc_df['working_disk']))
    # pprint("Loaded PC meta info:")
    # pprint(ABBR_DISK_MAP)
    # pprint(PC_TO_ABBR)
# ! must be called at module load time
load_pc_meta_info()


def list_PCs(show=True):
    global pc_df
    if show:
        csvfile.fn_display_df(pc_df)
    return pc_df


def get_PC_name():
    return platform.node()


def get_PC_abbr_name():
    pc_name = get_PC_name()
    return PC_TO_ABBR.get(pc_name, "Unknown")


def get_os_platform():
    return platform.system().lower()


def get_working_disk(abbr_disk_map=ABBR_DISK_MAP):
    pc_abbr = get_PC_abbr_name()
    return abbr_disk_map.get(pc_abbr, None)

cDisk = get_working_disk()

# ! This function searches for full paths in the obj and normalizes them according to the current platform and working disk
# ! E.g.: "E:/zdataset/DFire", but working_disk: "D:", current_platform: "windows" => "D:/zdataset/DFire"
# ! E.g.: "E:/zdataset/DFire", but working_disk: "D:", current_platform: "linux" => "/mnt/d/zdataset/DFire"
def normalize_paths(obj, working_disk=cDisk, current_platform=cPlatform):
    # [FIX 3] Resolve defaults inside the function to be safer/cleaner
    if working_disk is None:
        working_disk = get_working_disk()
    if current_platform is None:
        current_platform = get_os_platform()

    # [FIX 2] If the PC is unknown, working_disk is None. Return early to avoid a crash.
    if working_disk is None:
        return obj

    if isinstance(obj, dict):
        for key, value in obj.items():
            obj[key] = normalize_paths(value, working_disk, current_platform)
        return obj
    elif isinstance(obj, list):
        for i, item in enumerate(obj):
            obj[i] = normalize_paths(item, working_disk, current_platform)
        return obj
    elif isinstance(obj, str):
        # Normalize backslashes to forward slashes for consistency
        obj = obj.replace("\\", "/")

        # Regex for Windows-style path: e.g., "E:/zdataset/DFire"
        win_match = re.match(r"^([A-Z]):/(.*)$", obj)
        # Regex for Linux-style path: e.g., "/mnt/e/zdataset/DFire"
        lin_match = re.match(r"^/mnt/([a-z])/(.*)$", obj)

        if win_match or lin_match:
            rest = win_match.group(2) if win_match else lin_match.group(2)

            if current_platform == "windows":
                # working_disk is like "D:", so "D:/" + rest
                new_path = f"{working_disk}/{rest}"
            elif current_platform == "linux":
                # Extract drive letter from working_disk (e.g., "D:" -> "d")
                drive_letter = working_disk[0].lower()
                new_path = f"/mnt/{drive_letter}/{rest}"
            else:
                return obj
            return new_path

    # For non-strings or non-path strings, return as is
    return obj
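
A short sketch of what normalize_paths does to drive-rooted paths (values hypothetical). Passing working_disk and current_platform explicitly sidesteps the _list_pc.csv lookup, which depends on the local machine name:

    from halib.system.path import normalize_paths

    cfg = {
        "train": {"root": "E:/zdataset/DFire", "splits": ["E:/zdataset/DFire/train"]},
        "note": "not a path, left unchanged",
    }
    print(normalize_paths(cfg, working_disk="D:", current_platform="linux"))
    # => {'train': {'root': '/mnt/d/zdataset/DFire',
    #               'splits': ['/mnt/d/zdataset/DFire/train']},
    #     'note': 'not a path, left unchanged'}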
halib/tele_noti.py
ADDED
@@ -0,0 +1,166 @@
# Watch a log file and send a telegram message when training reaches a certain epoch or ends

import os
import yaml
import asyncio
import telegram
import pandas as pd

from rich.pretty import pprint
from rich.console import Console
import plotly.graph_objects as go

from .system import filesys as fs
from .filetype import textfile, csvfile

from argparse import ArgumentParser

tele_console = Console()


def parse_args():
    parser = ArgumentParser(description="desc text")
    parser.add_argument(
        "-cfg",
        "--cfg",
        type=str,
        help="yaml file for tele",
        default=r"E:\Dev\halib\cfg_tele_noti.yaml",
    )

    return parser.parse_args()


def get_watcher_message_df(target_file, num_last_lines):
    file_ext = fs.get_file_name(target_file, split_file_ext=True)[1]
    supported_ext = [".txt", ".log", ".csv"]
    assert (
        file_ext in supported_ext
    ), f"File extension {file_ext} not supported. Supported extensions are {supported_ext}"
    last_lines_df = None
    if file_ext in [".txt", ".log"]:
        lines = textfile.read_line_by_line(target_file)
        if num_last_lines > len(lines):
            num_last_lines = len(lines)
        last_line_arr = lines[-num_last_lines:]
        # find the most recent line mentioning "epoch" and prepend it
        epoch_info_list = "Epoch: n/a"
        for line in reversed(lines):
            if "epoch" in line.lower():
                epoch_info_list = line
                break
        last_line_arr.insert(0, epoch_info_list)  # insert at the beginning
        dfCreator = csvfile.DFCreator()
        dfCreator.create_table("last_lines", ["line"])
        last_line_arr = [[line] for line in last_line_arr]
        dfCreator.insert_rows("last_lines", last_line_arr)
        dfCreator.fill_table_from_row_pool("last_lines")
        last_lines_df = dfCreator["last_lines"].copy()
    else:
        df = pd.read_csv(target_file)
        num_rows = len(df)
        if num_last_lines > num_rows:
            num_last_lines = num_rows
        last_lines_df = df.tail(num_last_lines)
    return last_lines_df


def df2img(df: pd.DataFrame, output_img_dir, decimal_places, out_img_scale):
    df = df.round(decimal_places)
    fig = go.Figure(
        data=[
            go.Table(
                header=dict(values=list(df.columns), align="center"),
                cells=dict(
                    values=df.values.transpose(),
                    fill_color=[["white", "lightgrey"] * df.shape[0]],
                    align="center",
                ),
            )
        ]
    )
    if not os.path.exists(output_img_dir):
        os.makedirs(output_img_dir)
    img_path = os.path.normpath(os.path.join(output_img_dir, "last_lines.png"))
    fig.write_image(img_path, scale=out_img_scale)
    return img_path


def compose_message_and_img_path(
    target_file, project, num_last_lines, decimal_places, out_img_scale, output_img_dir
):
    context_msg = f">> Project: {project} \n>> File: {target_file} \n>> Last {num_last_lines} lines:"
    msg_df = get_watcher_message_df(target_file, num_last_lines)
    try:
        img_path = df2img(msg_df, output_img_dir, decimal_places, out_img_scale)
    except Exception as e:
        pprint(f"Error: {e}")
        img_path = None
    return context_msg, img_path


async def send_to_telegram(cfg_dict, interval_in_sec):
    # pprint(cfg_dict)
    token = cfg_dict["telegram"]["token"]
    chat_id = cfg_dict["telegram"]["chat_id"]

    noti_settings = cfg_dict["noti_settings"]
    project = noti_settings["project"]
    target_file = noti_settings["target_file"]
    num_last_lines = noti_settings["num_last_lines"]
    output_img_dir = noti_settings["output_img_dir"]
    decimal_places = noti_settings["decimal_places"]
    out_img_scale = noti_settings["out_img_scale"]

    bot = telegram.Bot(token=token)
    async with bot:
        try:
            context_msg, img_path = compose_message_and_img_path(
                target_file,
                project,
                num_last_lines,
                decimal_places,
                out_img_scale,
                output_img_dir,
            )
            time_now = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
            sep_line = "-" * 50
            context_msg = f"{sep_line}\n>> Time: {time_now}\n{context_msg}"
            # calculate the next time to send a message
            next_time = pd.Timestamp.now() + pd.Timedelta(seconds=interval_in_sec)
            next_time = next_time.strftime("%Y-%m-%d %H:%M:%S")
            next_time_info = f"Next msg: {next_time}"
            tele_console.rule()
            tele_console.print("[green] Send message to telegram [/green]")
            tele_console.print(
                f"[red] Next message will be sent at <{next_time}> [/red]"
            )
            await bot.send_message(text=context_msg, chat_id=chat_id)
            if img_path:
                with open(img_path, "rb") as img_file:
                    await bot.send_photo(chat_id=chat_id, photo=img_file)
            await bot.send_message(text=next_time_info, chat_id=chat_id)
        except Exception as e:
            pprint(f"Error: {e}")
            pprint("Message not sent to telegram")


async def run_forever(cfg_path):
    with open(cfg_path, "r") as f:
        cfg_dict = yaml.safe_load(f)
    noti_settings = cfg_dict["noti_settings"]
    interval_in_min = noti_settings["interval_in_min"]
    interval_in_sec = int(interval_in_min * 60)
    pprint(
        f"Message will be sent every {interval_in_min} minutes or {interval_in_sec} seconds"
    )
    while True:
        await send_to_telegram(cfg_dict, interval_in_sec)
        await asyncio.sleep(interval_in_sec)


async def main():
    args = parse_args()
    await run_forever(args.cfg)


if __name__ == "__main__":
    asyncio.run(main())
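
The config keys that send_to_telegram and run_forever read are shown below as the dict that yaml.safe_load would produce; all values are placeholders:

    cfg_dict = {
        "telegram": {
            "token": "<bot-token>",        # placeholder
            "chat_id": "<chat-id>",        # placeholder
        },
        "noti_settings": {
            "project": "my_project",
            "target_file": "train.log",    # .txt, .log, or .csv
            "num_last_lines": 10,
            "output_img_dir": "tele_imgs",
            "decimal_places": 4,
            "out_img_scale": 2,
            "interval_in_min": 30,
        },
    }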
halib/textfile.py
ADDED
@@ -0,0 +1,13 @@
def read_line_by_line(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()
        lines = [line.rstrip() for line in lines]
    return lines


def write(lines, outfile, append=False):
    mode = 'a' if append else 'w'
    with open(outfile, mode, encoding='utf-8') as f:
        for line in lines:
            f.write(line)
            f.write('\n')
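
A quick round-trip sketch of the two helpers (file name hypothetical):

    from halib import textfile

    textfile.write(["line 1", "line 2"], "notes.txt")     # overwrite
    textfile.write(["line 3"], "notes.txt", append=True)  # append
    print(textfile.read_line_by_line("notes.txt"))        # ['line 1', 'line 2', 'line 3']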
halib/torchloader.py
ADDED
@@ -0,0 +1,162 @@
"""
* @author Hoang Van-Ha
* @email hoangvanhauit@gmail.com
* @create date 2024-03-27 15:40:22
* @modify date 2024-03-27 15:40:22
* @desc this module works as a utility tool for finding the best configuration for a dataloader (num_workers, batch_size, pin_memory, etc.) that fits your hardware.
"""
from argparse import ArgumentParser
from .common import *
from .filetype import csvfile
from .filetype.yamlfile import load_yaml
from rich import inspect
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from typing import Union
import itertools as it  # for cartesian product
import os
import time
import traceback


def parse_args():
    parser = ArgumentParser(description="desc text")
    parser.add_argument("-cfg", "--cfg", type=str, help="cfg file for searching")
    return parser.parse_args()


def get_test_range(cfg: dict, search_item="num_workers"):
    item_search_cfg = cfg["search_space"].get(search_item, None)
    if item_search_cfg is None:
        raise ValueError(f"search_item: {search_item} not found in cfg")
    if isinstance(item_search_cfg, list):
        return item_search_cfg
    elif isinstance(item_search_cfg, dict):
        if "mode" in item_search_cfg:
            mode = item_search_cfg["mode"]
            assert mode in ["range", "list"], f"mode: {mode} not supported"
            value_in_mode = item_search_cfg.get(mode, None)
            if value_in_mode is None:
                raise ValueError(f"mode<{mode}>: data not found in <{search_item}>")
            if mode == "range":
                assert len(value_in_mode) == 3, "range must have 3 values: start, stop, step"
                start = value_in_mode[0]
                stop = value_in_mode[1]
                step = value_in_mode[2]
                return list(range(start, stop, step))
            elif mode == "list":
                return item_search_cfg["list"]
    else:
        return [item_search_cfg]  # for int, float, str, bool, etc.


def load_an_batch(loader_iter):
    start = time.time()
    next(loader_iter)
    end = time.time()
    return end - start


def test_dataloader_with_cfg(origin_dataloader: DataLoader, cfg: Union[dict, str]):
    try:
        if isinstance(cfg, str):
            cfg = load_yaml(cfg, to_dict=True)
        dfmk = csvfile.DFCreator()
        search_items = ["batch_size", "num_workers", "persistent_workers", "pin_memory"]
        batch_limit = cfg["general"]["batch_limit"]
        csv_cfg = cfg["general"]["to_csv"]
        log_batch_info = cfg["general"]["log_batch_info"]

        save_to_csv = csv_cfg["enabled"]
        log_dir = csv_cfg["log_dir"]
        filename = csv_cfg["filename"]
        filename = f"{now_str()}_{filename}.csv"
        outfile = os.path.join(log_dir, filename)

        dfmk.create_table(
            "cfg_search",
            (search_items + ["avg_time_taken"]),
        )
        ls_range_test = []
        for item in search_items:
            range_test = get_test_range(cfg, search_item=item)
            range_test = [(item, i) for i in range_test]
            ls_range_test.append(range_test)

        all_combinations = list(it.product(*ls_range_test))

        rows = []
        for cfg_idx, combine in enumerate(all_combinations):
            console.rule(f"Testing cfg {cfg_idx+1}/{len(all_combinations)}")
            inspect(combine)
            batch_size = combine[search_items.index("batch_size")][1]
            num_workers = combine[search_items.index("num_workers")][1]
            persistent_workers = combine[search_items.index("persistent_workers")][1]
            pin_memory = combine[search_items.index("pin_memory")][1]

            test_dataloader = DataLoader(
                origin_dataloader.dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                persistent_workers=persistent_workers,
                pin_memory=pin_memory,
                shuffle=True,
            )
            row = [
                batch_size,
                num_workers,
                persistent_workers,
                pin_memory,
                0.0,
            ]

            # calculate the avg time taken to load the data for <batch_limit> batches
            trainiter = iter(test_dataloader)
            time_elapsed = 0
            pprint('Start testing...')
            for i in tqdm(range(batch_limit)):
                single_batch_time = load_an_batch(trainiter)
                if log_batch_info:
                    pprint(f"Batch {i+1} took {single_batch_time:.4f} seconds to load")
                time_elapsed += single_batch_time
            row[-1] = time_elapsed / batch_limit
            rows.append(row)
        dfmk.insert_rows('cfg_search', rows)
        dfmk.fill_table_from_row_pool('cfg_search')
        with ConsoleLog("results"):
            csvfile.fn_display_df(dfmk['cfg_search'])
        if save_to_csv:
            dfmk["cfg_search"].to_csv(outfile, index=False)
            console.print(f"[red] Data saved to <{outfile}> [/red]")

    except Exception as e:
        traceback.print_exc()
        print(e)
        # get the current directory of this python file
        current_dir = os.path.dirname(os.path.realpath(__file__))
        standard_cfg_path = os.path.join(current_dir, "torchloader_search.yaml")
        pprint(
            f"Make sure you get the right <cfg.yaml> file. An example of <cfg.yaml> file can be found at this path: {standard_cfg_path}"
        )
        return

def main():
    args = parse_args()
    cfg_yaml = args.cfg
    cfg_dict = load_yaml(cfg_yaml, to_dict=True)

    # Define transforms for data augmentation and normalization
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
            transforms.RandomRotation(10),  # Randomly rotate images by 10 degrees
            transforms.ToTensor(),  # Convert images to PyTorch tensors
            transforms.Normalize(
                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
            ),  # Normalize pixel values to [-1, 1]
        ]
    )
    test_dataset = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    batch_size = 64
    train_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader_with_cfg(train_loader, cfg_dict)


if __name__ == "__main__":
    main()
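
The search cfg that get_test_range and test_dataloader_with_cfg consume has this shape, shown here as the dict equivalent of the YAML file; all values are illustrative:

    cfg = {
        "general": {
            "batch_limit": 20,  # batches timed per combination
            "log_batch_info": False,
            "to_csv": {"enabled": True, "log_dir": "logs", "filename": "dl_search"},
        },
        "search_space": {
            "batch_size": {"mode": "list", "list": [32, 64, 128]},
            "num_workers": {"mode": "range", "range": [0, 9, 2]},  # start, stop, step
            "persistent_workers": [True],  # a plain list is also accepted
            "pin_memory": True,            # a scalar is wrapped into a one-item list
        },
    }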
halib/utils/__init__.py
ADDED
File without changes

halib/utils/dataclass_util.py
ADDED
@@ -0,0 +1,40 @@
import yaml
from typing import Any
from rich.pretty import pprint
from ..filetype import yamlfile
# from halib.filetype import yamlfile
from dataclasses import make_dataclass

def dict_to_dataclass(name: str, data: dict):
    fields = []
    values = {}

    for key, value in data.items():
        if isinstance(value, dict):
            sub_dc = dict_to_dataclass(key.capitalize(), value)
            fields.append((key, type(sub_dc)))
            values[key] = sub_dc
        else:
            field_type = type(value) if value is not None else Any
            fields.append((key, field_type))
            values[key] = value

    DC = make_dataclass(name.capitalize(), fields)
    return DC(**values)

def yaml_to_dataclass(name: str, yaml_str: str):
    data = yaml.safe_load(yaml_str)
    return dict_to_dataclass(name, data)


def yamlfile_to_dataclass(name: str, file_path: str):
    data_dict = yamlfile.load_yaml(file_path, to_dict=True)
    if "__base__" in data_dict:
        del data_dict["__base__"]
    return dict_to_dataclass(name, data_dict)

if __name__ == "__main__":
    cfg = yamlfile_to_dataclass("Config", "test/dataclass_util_test_cfg.yaml")

    # ! NOTICE: after printing out this dataclass, we can copy the output and paste it into ChatGPT to generate the needed dataclass classes using `from dataclass_wizard import YAMLWizard`
    pprint(cfg)
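
A quick sketch of dict_to_dataclass on a nested dict: nested dict values become nested, dynamically generated dataclasses whose names are the capitalized keys:

    from halib.utils.dataclass_util import dict_to_dataclass

    cfg = dict_to_dataclass("config", {"train": {"lr": 0.001, "epochs": 10}, "seed": 42})
    print(cfg.train.lr)                                  # 0.001
    print(type(cfg).__name__, type(cfg.train).__name__)  # Config Train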