halib 0.1.91__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +12 -6
- halib/common/__init__.py +0 -0
- halib/common/common.py +207 -0
- halib/common/rich_color.py +285 -0
- halib/common.py +53 -10
- halib/exp/__init__.py +0 -0
- halib/exp/core/__init__.py +0 -0
- halib/exp/core/base_config.py +167 -0
- halib/exp/core/base_exp.py +147 -0
- halib/exp/core/param_gen.py +189 -0
- halib/exp/core/wandb_op.py +117 -0
- halib/exp/data/__init__.py +0 -0
- halib/exp/data/dataclass_util.py +41 -0
- halib/exp/data/dataset.py +208 -0
- halib/exp/data/torchloader.py +165 -0
- halib/exp/perf/__init__.py +0 -0
- halib/exp/perf/flop_calc.py +190 -0
- halib/exp/perf/gpu_mon.py +58 -0
- halib/exp/perf/perfcalc.py +440 -0
- halib/exp/perf/perfmetrics.py +137 -0
- halib/exp/perf/perftb.py +778 -0
- halib/exp/perf/profiler.py +507 -0
- halib/exp/viz/__init__.py +0 -0
- halib/exp/viz/plot.py +754 -0
- halib/filetype/csvfile.py +3 -9
- halib/filetype/ipynb.py +61 -0
- halib/filetype/jsonfile.py +0 -3
- halib/filetype/textfile.py +0 -1
- halib/filetype/videofile.py +119 -3
- halib/filetype/yamlfile.py +16 -1
- halib/online/projectmake.py +7 -6
- halib/online/tele_noti.py +165 -0
- halib/research/base_exp.py +75 -18
- halib/research/core/__init__.py +0 -0
- halib/research/core/base_config.py +144 -0
- halib/research/core/base_exp.py +157 -0
- halib/research/core/param_gen.py +108 -0
- halib/research/core/wandb_op.py +117 -0
- halib/research/data/__init__.py +0 -0
- halib/research/data/dataclass_util.py +41 -0
- halib/research/data/dataset.py +208 -0
- halib/research/data/torchloader.py +165 -0
- halib/research/dataset.py +6 -7
- halib/research/flop_csv.py +34 -0
- halib/research/flops.py +156 -0
- halib/research/metrics.py +4 -0
- halib/research/mics.py +59 -1
- halib/research/perf/__init__.py +0 -0
- halib/research/perf/flop_calc.py +190 -0
- halib/research/perf/gpu_mon.py +58 -0
- halib/research/perf/perfcalc.py +363 -0
- halib/research/perf/perfmetrics.py +137 -0
- halib/research/perf/perftb.py +778 -0
- halib/research/perf/profiler.py +301 -0
- halib/research/perfcalc.py +60 -35
- halib/research/perftb.py +2 -1
- halib/research/plot.py +480 -218
- halib/research/viz/__init__.py +0 -0
- halib/research/viz/plot.py +754 -0
- halib/system/_list_pc.csv +6 -0
- halib/system/filesys.py +60 -20
- halib/system/path.py +106 -0
- halib/utils/dict.py +9 -0
- halib/utils/list.py +12 -0
- halib/utils/video.py +6 -0
- halib-0.2.21.dist-info/METADATA +192 -0
- halib-0.2.21.dist-info/RECORD +109 -0
- halib-0.1.91.dist-info/METADATA +0 -201
- halib-0.1.91.dist-info/RECORD +0 -61
- {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/WHEEL +0 -0
- {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/licenses/LICENSE.txt +0 -0
- {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,208 @@
+# This script create a test version
+# of the watcam (wc) dataset
+# for testing the tflite model
+
+from argparse import ArgumentParser
+
+import os
+import click
+import shutil
+from tqdm import tqdm
+from rich import inspect
+from rich.pretty import pprint
+from torchvision.datasets import ImageFolder
+from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
+
+from ...common.common import console, seed_everything, ConsoleLog
+from ...system import filesys as fs
+
+def parse_args():
+    parser = ArgumentParser(description="desc text")
+    parser.add_argument(
+        "-indir",
+        "--indir",
+        type=str,
+        help="orignal dataset path",
+    )
+    parser.add_argument(
+        "-outdir",
+        "--outdir",
+        type=str,
+        help="dataset out path",
+        default=".",  # default to current dir
+    )
+    parser.add_argument(
+        "-val_size",
+        "--val_size",
+        type=float,
+        help="validation size",  # no default value to force user to input
+        default=0.2,
+    )
+    # add using StratifiedShuffleSplit or ShuffleSplit
+    parser.add_argument(
+        "-seed",
+        "--seed",
+        type=int,
+        help="random seed",
+        default=42,
+    )
+    parser.add_argument(
+        "-inplace",
+        "--inplace",
+        action="store_true",
+        help="inplace operation, will overwrite the outdir if exists",
+    )
+
+    parser.add_argument(
+        "-stratified",
+        "--stratified",
+        action="store_true",
+        help="use StratifiedShuffleSplit instead of ShuffleSplit",
+    )
+    parser.add_argument(
+        "-no_train",
+        "--no_train",
+        action="store_true",
+        help="only create test set, no train set",
+    )
+    parser.add_argument(
+        "-reverse",
+        "--reverse",
+        action="store_true",
+        help="combine train and val set back to original dataset",
+    )
+    return parser.parse_args()
+
+
+def move_images(image_paths, target_set_dir):
+    for img_path in tqdm(image_paths):
+        # get folder name of the image
+        img_dir = os.path.dirname(img_path)
+        out_cls_dir = os.path.join(target_set_dir, os.path.basename(img_dir))
+        if not os.path.exists(out_cls_dir):
+            os.makedirs(out_cls_dir)
+        # move the image to the class folder
+        shutil.move(img_path, out_cls_dir)
+
+
+def split_dataset_cls(
+    indir, outdir, val_size, seed, inplace, stratified_split, no_train
+):
+    seed_everything(seed)
+    console.rule("Config confirm?")
+    pprint(locals())
+    click.confirm("Continue?", abort=True)
+    assert os.path.exists(indir), f"{indir} does not exist"
+
+    if not inplace:
+        assert (not inplace) and (
+            not os.path.exists(outdir)
+        ), f"{outdir} already exists; SKIP ...."
+
+    if inplace:
+        outdir = indir
+    if not os.path.exists(outdir):
+        os.makedirs(outdir)
+
+    console.rule(f"Creating train/val dataset")
+
+    sss = (
+        ShuffleSplit(n_splits=1, test_size=val_size)
+        if not stratified_split
+        else StratifiedShuffleSplit(n_splits=1, test_size=val_size)
+    )
+
+    pprint({"split strategy": sss, "indir": indir, "outdir": outdir})
+    dataset = ImageFolder(
+        root=indir,
+        transform=None,
+    )
+    train_dataset_indices = None
+    val_dataset_indices = None  # val here means test
+    for train_indices, val_indices in sss.split(dataset.samples, dataset.targets):
+        train_dataset_indices = train_indices
+        val_dataset_indices = val_indices
+
+    # get image paths for train/val split dataset
+    train_image_paths = [dataset.imgs[i][0] for i in train_dataset_indices]
+    val_image_paths = [dataset.imgs[i][0] for i in val_dataset_indices]
+
+    # start creating train/val folders then move images
+    out_train_dir = os.path.join(outdir, "train")
+    out_val_dir = os.path.join(outdir, "val")
+    if inplace:
+        assert os.path.exists(out_train_dir) == False, f"{out_train_dir} already exists"
+        assert os.path.exists(out_val_dir) == False, f"{out_val_dir} already exists"
+
+    os.makedirs(out_train_dir)
+    os.makedirs(out_val_dir)
+
+    if not no_train:
+        with ConsoleLog(f"Moving train images to {out_train_dir} "):
+            move_images(train_image_paths, out_train_dir)
+    else:
+        pprint("test only, skip moving train images")
+        # remove out_train_dir
+        shutil.rmtree(out_train_dir)
+
+    with ConsoleLog(f"Moving val images to {out_val_dir} "):
+        move_images(val_image_paths, out_val_dir)
+
+    if inplace:
+        pprint(f"remove all folders, except train and val")
+        for cls_dir in os.listdir(outdir):
+            if cls_dir not in ["train", "val"]:
+                shutil.rmtree(os.path.join(indir, cls_dir))
+
+
+def reverse_split_ds(indir):
+    console.rule(f"Reversing split dataset <{indir}>...")
+    ls_dirs = os.listdir(indir)
+    # make sure there are only two dirs 'train' and 'val'
+    assert len(ls_dirs) == 2, f"Found more than 2 dirs: {len(ls_dirs) } dirs"
+    assert "train" in ls_dirs, f"train dir not found in {indir}"
+    assert "val" in ls_dirs, f"val dir not found in {indir}"
+    train_dir = os.path.join(indir, "train")
+    val_dir = os.path.join(indir, "val")
+    all_train_files = fs.filter_files_by_extension(
+        train_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+    )
+    all_val_files = fs.filter_files_by_extension(
+        val_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+    )
+    # move all files from train to indir
+    with ConsoleLog(f"Moving train images to {indir} "):
+        move_images(all_train_files, indir)
+    with ConsoleLog(f"Moving val images to {indir} "):
+        move_images(all_val_files, indir)
+    with ConsoleLog(f"Removing train and val dirs"):
+        # remove train and val dirs
+        shutil.rmtree(train_dir)
+        shutil.rmtree(val_dir)
+
+
+def main():
+    args = parse_args()
+    indir = args.indir
+    outdir = args.outdir
+    if outdir == ".":
+        # get current folder of the indir
+        indir_parent_dir = os.path.dirname(os.path.normpath(indir))
+        indir_name = os.path.basename(indir)
+        outdir = os.path.join(indir_parent_dir, f"{indir_name}_split")
+    val_size = args.val_size
+    seed = args.seed
+    inplace = args.inplace
+    stratified_split = args.stratified
+    no_train = args.no_train
+    reverse = args.reverse
+    if not reverse:
+        split_dataset_cls(
+            indir, outdir, val_size, seed, inplace, stratified_split, no_train
+        )
+    else:
+        reverse_split_ds(indir)
+
+
+if __name__ == "__main__":
+    main()
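The file above is a command-line script, but the splitting logic itself lives in split_dataset_cls. A minimal sketch of driving it directly, with hypothetical paths and an assumed import location (the hunk carries no file header; the signature and behavior are taken from the code above). Note that images are moved, not copied.

from halib.exp.data.dataset import split_dataset_cls  # assumed module path

split_dataset_cls(
    indir="datasets/watcam",          # hypothetical ImageFolder root (one subfolder per class)
    outdir="datasets/watcam_split",   # the CLI default derives this as "<indir>_split"
    val_size=0.2,                     # fraction of images moved to val/
    seed=42,
    inplace=False,                    # True carves train/ and val/ inside indir itself
    stratified_split=True,            # StratifiedShuffleSplit keeps class ratios
    no_train=False,                   # True keeps only the val/ (test) split
)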
@@ -0,0 +1,165 @@
+"""
+ * @author Hoang Van-Ha
+ * @email hoangvanhauit@gmail.com
+ * @create date 2024-03-27 15:40:22
+ * @modify date 2024-03-27 15:40:22
+ * @desc this module works as a utility tools for finding the best configuration for dataloader (num_workers, batch_size, pin_menory, etc.) that fits your hardware.
+"""
+from argparse import ArgumentParser
+
+import os
+import time
+import traceback
+
+from tqdm import tqdm
+from rich import inspect
+from typing import Union
+import itertools as it  # for cartesian product
+
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+from ...common.common import *
+from ...filetype import csvfile
+from ...filetype.yamlfile import load_yaml
+
+def parse_args():
+    parser = ArgumentParser(description="desc text")
+    parser.add_argument("-cfg", "--cfg", type=str, help="cfg file for searching")
+    return parser.parse_args()
+
+
+def get_test_range(cfg: dict, search_item="num_workers"):
+    item_search_cfg = cfg["search_space"].get(search_item, None)
+    if item_search_cfg is None:
+        raise ValueError(f"search_item: {search_item} not found in cfg")
+    if isinstance(item_search_cfg, list):
+        return item_search_cfg
+    elif isinstance(item_search_cfg, dict):
+        if "mode" in item_search_cfg:
+            mode = item_search_cfg["mode"]
+            assert mode in ["range", "list"], f"mode: {mode} not supported"
+            value_in_mode = item_search_cfg.get(mode, None)
+            if value_in_mode is None:
+                raise ValueError(f"mode<{mode}>: data not found in <{search_item}>")
+            if mode == "range":
+                assert len(value_in_mode) == 3, f"range must have 3 values: start, stop, step"
+                start = value_in_mode[0]
+                stop = value_in_mode[1]
+                step = value_in_mode[2]
+                return list(range(start, stop, step))
+            elif mode == "list":
+                return item_search_cfg["list"]
+    else:
+        return [item_search_cfg]  # for int, float, str, bool, etc.
+
+
+def load_an_batch(loader_iter):
+    start = time.time()
+    next(loader_iter)
+    end = time.time()
+    return end - start
+
+
+def test_dataloader_with_cfg(origin_dataloader: DataLoader, cfg: Union[dict, str]):
+    try:
+        if isinstance(cfg, str):
+            cfg = load_yaml(cfg, to_dict=True)
+        dfmk = csvfile.DFCreator()
+        search_items = ["batch_size", "num_workers", "persistent_workers", "pin_memory"]
+        batch_limit = cfg["general"]["batch_limit"]
+        csv_cfg = cfg["general"]["to_csv"]
+        log_batch_info = cfg["general"]["log_batch_info"]
+
+        save_to_csv = csv_cfg["enabled"]
+        log_dir = csv_cfg["log_dir"]
+        filename = csv_cfg["filename"]
+        filename = f"{now_str()}_{filename}.csv"
+        outfile = os.path.join(log_dir, filename)
+
+        dfmk.create_table(
+            "cfg_search",
+            (search_items + ["avg_time_taken"]),
+        )
+        ls_range_test = []
+        for item in search_items:
+            range_test = get_test_range(cfg, search_item=item)
+            range_test = [(item, i) for i in range_test]
+            ls_range_test.append(range_test)
+
+        all_combinations = list(it.product(*ls_range_test))
+
+        rows = []
+        for cfg_idx, combine in enumerate(all_combinations):
+            console.rule(f"Testing cfg {cfg_idx+1}/{len(all_combinations)}")
+            inspect(combine)
+            batch_size = combine[search_items.index("batch_size")][1]
+            num_workers = combine[search_items.index("num_workers")][1]
+            persistent_workers = combine[search_items.index("persistent_workers")][1]
+            pin_memory = combine[search_items.index("pin_memory")][1]
+
+            test_dataloader = DataLoader(origin_dataloader.dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=persistent_workers, pin_memory=pin_memory, shuffle=True)
+            row = [
+                batch_size,
+                num_workers,
+                persistent_workers,
+                pin_memory,
+                0.0,
+            ]
+
+            # calculate the avg time taken to load the data for <batch_limit> batches
+            trainiter = iter(test_dataloader)
+            time_elapsed = 0
+            pprint('Start testing...')
+            for i in tqdm(range(batch_limit)):
+                single_batch_time = load_an_batch(trainiter)
+                if log_batch_info:
+                    pprint(f"Batch {i+1} took {single_batch_time:.4f} seconds to load")
+                time_elapsed += single_batch_time
+            row[-1] = time_elapsed / batch_limit
+            rows.append(row)
+        dfmk.insert_rows('cfg_search', rows)
+        dfmk.fill_table_from_row_pool('cfg_search')
+        with ConsoleLog("results"):
+            csvfile.fn_display_df(dfmk['cfg_search'])
+            if save_to_csv:
+                dfmk["cfg_search"].to_csv(outfile, index=False)
+                console.print(f"[red] Data saved to <{outfile}> [/red]")
+
+    except Exception as e:
+        traceback.print_exc()
+        print(e)
+        # get current directory of this python file
+        current_dir = os.path.dirname(os.path.realpath(__file__))
+        standar_cfg_path = os.path.join(current_dir, "torchloader_search.yaml")
+        pprint(
+            f"Make sure you get the right <cfg.yaml> file. An example of <cfg.yaml> file can be found at this path: {standar_cfg_path}"
+        )
+        return
+
+def main():
+    args = parse_args()
+    cfg_yaml = args.cfg
+    cfg_dict = load_yaml(cfg_yaml, to_dict=True)
+
+    # Define transforms for data augmentation and normalization
+    transform = transforms.Compose(
+        [
+            transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
+            transforms.RandomRotation(10),  # Randomly rotate images by 10 degrees
+            transforms.ToTensor(),  # Convert images to PyTorch tensors
+            transforms.Normalize(
+                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
+            ),  # Normalize pixel values to [-1, 1]
+        ]
+    )
+    test_dataset = datasets.CIFAR10(
+        root="./data", train=False, download=True, transform=transform
+    )
+    batch_size = 64
+    train_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
+    test_dataloader_with_cfg(train_loader, cfg_dict)
+
+
+if __name__ == "__main__":
+    main()
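test_dataloader_with_cfg accepts either a YAML path or a dict. A minimal sketch of the structure it reads, reconstructed from the keys accessed in the code above; the concrete values are illustrative, not a config shipped with the package:

cfg = {
    "general": {
        "batch_limit": 50,             # batches timed per combination
        "log_batch_info": False,       # per-batch timing printout
        "to_csv": {
            "enabled": True,
            "log_dir": "zout",         # hypothetical output folder
            "filename": "loader_search",
        },
    },
    "search_space": {
        "batch_size": [32, 64, 128],                            # plain list form
        "num_workers": {"mode": "range", "range": [0, 9, 2]},   # range form: start, stop, step
        "persistent_workers": [True, False],
        "pin_memory": True,                                     # scalar form -> single value
    },
}

get_test_range() expands each entry, and the script times every combination in the resulting Cartesian product.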
halib/research/dataset.py CHANGED
@@ -4,18 +4,17 @@
 
 from argparse import ArgumentParser
 
-from rich import inspect
-from common import console, seed_everything, ConsoleLog
-from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
-from tqdm import tqdm
 import os
 import click
-from torchvision.datasets import ImageFolder
 import shutil
+from tqdm import tqdm
+from rich import inspect
 from rich.pretty import pprint
-from
-import
+from torchvision.datasets import ImageFolder
+from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
 
+from ..common import console, seed_everything, ConsoleLog
+from ..system import filesys as fs
 
 def parse_args():
     parser = ArgumentParser(description="desc text")
halib/research/flop_csv.py ADDED
@@ -0,0 +1,34 @@
+from halib import *
+from flops import _calculate_flops_for_model
+
+from halib import *
+from argparse import ArgumentParser
+
+
+def main():
+    csv_file = "./results-imagenet.csv"
+    df = pd.read_csv(csv_file)
+    # make param_count column as float
+    # df['param_count'] = df['param_count'].astype(float)
+    df['param_count'] = pd.to_numeric(df['param_count'], errors='coerce').fillna(99999).astype(float)
+    df = df[df['param_count'] < 5.0]  # filter models with param_count < 20M
+
+    dict_ls = []
+
+    for index, row in tqdm(df.iterrows()):
+        console.rule(f"Row {index+1}/{len(df)}")
+        model = row['model']
+        num_class = 2
+        _, _, mflops = _calculate_flops_for_model(model, num_class)
+        dict_ls.append({'model': model, 'param_count': row['param_count'], 'mflops': mflops})
+
+    # Create a DataFrame from the list of dictionaries
+    result_df = pd.DataFrame(dict_ls)
+
+    final_df = pd.merge(df, result_df, on=['model', 'param_count'])
+    final_df.sort_values(by='mflops', inplace=True, ascending=True)
+    csvfile.fn_display_df(final_df)
+
+
+if __name__ == "__main__":
+    main()
halib/research/flops.py ADDED
@@ -0,0 +1,156 @@
+import os
+import sys
+import torch
+import timm
+from argparse import ArgumentParser
+from fvcore.nn import FlopCountAnalysis
+from halib import *
+from halib.filetype import csvfile
+from curriculum.utils.config import *
+from curriculum.utils.model_helper import *
+
+
+# ---------------------------------------------------------------------
+# Argument Parser
+# ---------------------------------------------------------------------
+def parse_args():
+    parser = ArgumentParser(description="Calculate FLOPs for TIMM or trained models")
+
+    # Option 1: Direct TIMM model
+    parser.add_argument(
+        "--model_name", type=str, help="TIMM model name (e.g., efficientnet_b0)"
+    )
+    parser.add_argument(
+        "--num_classes", type=int, default=1000, help="Number of output classes"
+    )
+
+    # Option 2: Experiment directory
+    parser.add_argument(
+        "--indir",
+        type=str,
+        default=None,
+        help="Directory containing trained experiment (with .yaml and .pth)",
+    )
+    parser.add_argument(
+        "-o", "--o", action="store_true", help="Open output CSV after saving"
+    )
+    return parser.parse_args()
+
+
+# ---------------------------------------------------------------------
+# Helper Functions
+# ---------------------------------------------------------------------
+def _get_list_of_proc_dirs(indir):
+    assert os.path.exists(indir), f"Input directory {indir} does not exist."
+    pth_files = [f for f in os.listdir(indir) if f.endswith(".pth")]
+    if len(pth_files) > 0:
+        return [indir]
+    return [
+        os.path.join(indir, f)
+        for f in os.listdir(indir)
+        if os.path.isdir(os.path.join(indir, f))
+    ]
+
+
+def _calculate_flops_for_model(model_name, num_classes):
+    """Calculate FLOPs for a plain TIMM model."""
+    try:
+        model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
+        input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+        dummy_input = torch.randn(1, *input_size)
+        model.eval()  # ! set to eval mode to avoid some warnings or errors
+        flops = FlopCountAnalysis(model, dummy_input)
+        gflops = flops.total() / 1e9
+        mflops = flops.total() / 1e6
+        print(f"\nModel: **{model_name}**, Classes: {num_classes}")
+        print(f"Input size: {input_size}, FLOPs: **{gflops:.3f} GFLOPs**, **{mflops:.3f} MFLOPs**\n")
+        return model_name, gflops, mflops
+    except Exception as e:
+        print(f"[Error] Could not calculate FLOPs for {model_name}: {e}")
+        return model_name, -1, -1
+
+
+def _calculate_flops_for_experiment(exp_dir):
+    """Calculate FLOPs for a trained experiment directory."""
+    yaml_files = [f for f in os.listdir(exp_dir) if f.endswith(".yaml")]
+    pth_files = [f for f in os.listdir(exp_dir) if f.endswith(".pth")]
+
+    assert (
+        len(yaml_files) == 1
+    ), f"Expected 1 YAML file in {exp_dir}, found {len(yaml_files)}"
+    assert (
+        len(pth_files) == 1
+    ), f"Expected 1 PTH file in {exp_dir}, found {len(pth_files)}"
+
+    exp_cfg_yaml = os.path.join(exp_dir, yaml_files[0])
+    cfg = ExpConfig.from_yaml(exp_cfg_yaml)
+    ds_label_list = cfg.dataset.get_label_list()
+
+    try:
+        model = build_model(
+            cfg.model.name, num_classes=len(ds_label_list), pretrained=True
+        )
+        model_weights_path = os.path.join(exp_dir, pth_files[0])
+        model.load_state_dict(torch.load(model_weights_path, map_location="cpu"))
+        model.eval()
+
+        input_size = timm.data.resolve_data_config(model.default_cfg)["input_size"]
+        dummy_input = torch.randn(1, *input_size)
+        flops = FlopCountAnalysis(model, dummy_input)
+        gflops = flops.total() / 1e9
+        mflops = flops.total() / 1e6
+
+        return str(cfg), cfg.model.name, gflops, mflops
+    except Exception as e:
+        console.print(f"[red] Error processing {exp_dir}: {e}[/red]")
+        return str(cfg), cfg.model.name, -1, -1
+
+
+# ---------------------------------------------------------------------
+# Main Entry
+# ---------------------------------------------------------------------
+def main():
+    args = parse_args()
+
+    # Case 1: Direct TIMM model input
+    if args.model_name:
+        _calculate_flops_for_model(args.model_name, args.num_classes)
+        return
+
+    # Case 2: Experiment directory input
+    if args.indir is None:
+        print("[Error] Either --model_name or --indir must be specified.")
+        return
+
+    proc_dirs = _get_list_of_proc_dirs(args.indir)
+    pprint(proc_dirs)
+
+    dfmk = csvfile.DFCreator()
+    TABLE_NAME = "model_flops_results"
+    dfmk.create_table(TABLE_NAME, ["exp_name", "model_name", "gflops", "mflops"])
+
+    console.rule(f"Calculating FLOPs for models in {len(proc_dirs)} dir(s)...")
+    rows = []
+    for exp_dir in tqdm(proc_dirs):
+        dir_name = os.path.basename(exp_dir)
+        console.rule(f"{dir_name}")
+        exp_name, model_name, gflops, mflops = _calculate_flops_for_experiment(exp_dir)
+        rows.append([exp_name, model_name, gflops, mflops])
+
+    dfmk.insert_rows(TABLE_NAME, rows)
+    dfmk.fill_table_from_row_pool(TABLE_NAME)
+
+    outfile = f"zout/zreport/{now_str()}_model_flops_results.csv"
+    dfmk[TABLE_NAME].to_csv(outfile, sep=";", index=False)
+    csvfile.fn_display_df(dfmk[TABLE_NAME])
+
+    if args.o:
+        os.system(f"start {outfile}")
+
+
+# ---------------------------------------------------------------------
+# Script Entry
+# ---------------------------------------------------------------------
+if __name__ == "__main__":
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+    main()
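For reference, the fvcore recipe that _calculate_flops_for_model uses, inlined as a sketch so it runs without the script's project-specific curriculum.* imports; the model name is just the example from the --model_name help text:

import timm
import torch
from fvcore.nn import FlopCountAnalysis
from timm.data import resolve_data_config

model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=2)
model.eval()  # eval mode, as in the helper above
input_size = resolve_data_config(model.default_cfg)["input_size"]
flops = FlopCountAnalysis(model, torch.randn(1, *input_size))
print(f"{flops.total() / 1e9:.3f} GFLOPs, {flops.total() / 1e6:.3f} MFLOPs")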
halib/research/metrics.py CHANGED
@@ -11,6 +11,10 @@ class MetricsBackend(ABC):
     def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
         """
         Initialize the backend with optional metrics_info.
+        `metrics_info` can be either:
+        - A list of metric names (strings). e.g., ["accuracy", "precision"]
+        - A dict mapping metric names with object that defines how to compute them. e.g: {"accuracy": torchmetrics.Accuracy(), "precision": torchmetrics.Precision()}
+
         """
         self.metric_info = metrics_info
         self.validate_metrics_info(self.metric_info)
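The two accepted forms of metrics_info from the new docstring, written out as literals; recent torchmetrics versions require the task/num_classes arguments shown here, and num_classes is illustrative:

metrics_as_names = ["accuracy", "precision"]

import torchmetrics
metrics_as_objects = {
    "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10),
    "precision": torchmetrics.Precision(task="multiclass", num_classes=10),
}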
halib/research/mics.py CHANGED
@@ -1,16 +1,74 @@
+from ..common import *
+from ..filetype import csvfile
+import pandas as pd
 import platform
 
+
 PC_NAME_TO_ABBR = {
     "DESKTOP-JQD9K01": "MainPC",
     "DESKTOP-5IRHU87": "MSI_Laptop",
     "DESKTOP-96HQCNO": "4090_SV",
     "DESKTOP-Q2IKLC0": "4GPU_SV",
-    "DESKTOP-QNS3DNF": "1GPU_SV"
+    "DESKTOP-QNS3DNF": "1GPU_SV",
+}
+
+DEFAULT_ABBR_WORKING_DISK = {
+    "MainPC": "E:",
+    "MSI_Laptop": "D:",
+    "4090_SV": "E:",
+    "4GPU_SV": "D:",
 }
 
+
+def list_PCs(show=True):
+    df = pd.DataFrame(
+        list(PC_NAME_TO_ABBR.items()), columns=["PC Name", "Abbreviation"]
+    )
+    if show:
+        csvfile.fn_display_df(df)
+    return df
+
+
 def get_PC_name():
     return platform.node()
 
+
 def get_PC_abbr_name():
     pc_name = get_PC_name()
     return PC_NAME_TO_ABBR.get(pc_name, "Unknown")
+
+
+# ! This funcction search for full paths in the obj and normalize them according to the current platform and working disk
+# ! E.g: "E:/zdataset/DFire", but working_disk: "D:", current_platform: "windows" => "D:/zdataset/DFire"
+# ! E.g: "E:/zdataset/DFire", but working_disk: "D:", current_platform: "linux" => "/mnt/d/zdataset/DFire"
+def normalize_paths(obj, working_disk, current_platform):
+    if isinstance(obj, dict):
+        for key, value in obj.items():
+            obj[key] = normalize_paths(value, working_disk, current_platform)
+        return obj
+    elif isinstance(obj, list):
+        for i, item in enumerate(obj):
+            obj[i] = normalize_paths(item, working_disk, current_platform)
+        return obj
+    elif isinstance(obj, str):
+        # Normalize backslashes to forward slashes for consistency
+        obj = obj.replace("\\", "/")
+        # Regex for Windows-style path: e.g., "E:/zdataset/DFire"
+        win_match = re.match(r"^([A-Z]):/(.*)$", obj)
+        # Regex for Linux-style path: e.g., "/mnt/e/zdataset/DFire"
+        lin_match = re.match(r"^/mnt/([a-z])/(.*)$", obj)
+        if win_match or lin_match:
+            rest = win_match.group(2) if win_match else lin_match.group(2)
+            if current_platform == "windows":
+                # working_disk is like "D:", so "D:/" + rest
+                new_path = working_disk + "/" + rest
+            elif current_platform == "linux":
+                # Extract drive letter from working_disk (e.g., "D:" -> "d")
+                drive_letter = working_disk[0].lower()
+                new_path = "/mnt/" + drive_letter + "/" + rest
+            else:
+                # Unknown platform, return original
+                return obj
+            return new_path
+    # For non-strings or non-path strings, return as is
+    return obj
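A minimal sketch of the remapping normalize_paths performs, using the example paths from the comments above; the nested structure is illustrative:

cfg = {"dataset": {"root": "E:/zdataset/DFire"}, "logs": ["E:/zout/run1"]}  # hypothetical config

# On a Windows machine whose working disk is "D:"
normalize_paths(cfg, working_disk="D:", current_platform="windows")
# -> {"dataset": {"root": "D:/zdataset/DFire"}, "logs": ["D:/zout/run1"]}

# Starting from the original strings with current_platform="linux", the same paths
# map to WSL-style mounts, e.g. "E:/zdataset/DFire" -> "/mnt/d/zdataset/DFire".
# Note that dicts and lists are rewritten in place as well as returned.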