halib 0.1.99__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +3 -3
- halib/common/__init__.py +0 -0
- halib/common/common.py +178 -0
- halib/common/rich_color.py +285 -0
- halib/filetype/csvfile.py +3 -9
- halib/filetype/ipynb.py +3 -5
- halib/filetype/jsonfile.py +0 -3
- halib/filetype/textfile.py +0 -1
- halib/filetype/videofile.py +91 -2
- halib/filetype/yamlfile.py +3 -3
- halib/online/projectmake.py +7 -6
- halib/online/tele_noti.py +165 -0
- halib/research/base_exp.py +75 -18
- halib/research/core/__init__.py +0 -0
- halib/research/core/base_config.py +144 -0
- halib/research/core/base_exp.py +157 -0
- halib/research/core/param_gen.py +108 -0
- halib/research/core/wandb_op.py +117 -0
- halib/research/data/__init__.py +0 -0
- halib/research/data/dataclass_util.py +41 -0
- halib/research/data/dataset.py +208 -0
- halib/research/data/torchloader.py +165 -0
- halib/research/dataset.py +1 -1
- halib/research/metrics.py +4 -0
- halib/research/mics.py +8 -2
- halib/research/perf/__init__.py +0 -0
- halib/research/perf/flop_calc.py +190 -0
- halib/research/perf/gpu_mon.py +58 -0
- halib/research/perf/perfcalc.py +363 -0
- halib/research/perf/perfmetrics.py +137 -0
- halib/research/perf/perftb.py +778 -0
- halib/research/perf/profiler.py +301 -0
- halib/research/perfcalc.py +57 -32
- halib/research/viz/__init__.py +0 -0
- halib/research/viz/plot.py +754 -0
- halib/system/filesys.py +60 -20
- halib/system/path.py +73 -0
- halib/utils/dict.py +9 -0
- halib/utils/list.py +12 -0
- {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/METADATA +7 -1
- halib-0.2.2.dist-info/RECORD +89 -0
- halib-0.1.99.dist-info/RECORD +0 -64
- {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/WHEEL +0 -0
- {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/licenses/LICENSE.txt +0 -0
- {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/top_level.txt +0 -0
halib/research/core/param_gen.py
ADDED

@@ -0,0 +1,108 @@
+import os
+import yaml
+import numpy as np
+from typing import Dict, Any, List
+
+from ...common.common import *
+from ...filetype import yamlfile
+
+class ParamGen:
+    @staticmethod
+    def build_from_file(params_file):
+        builder = ParamGen(params_file)
+        return builder.params
+
+    def __init__(self, params_file=None):
+        self.params = {}
+        assert os.path.isfile(params_file), f"params_file not found: {params_file}"
+        self.params = self._build(params_file)
+
+    def _expand_param(self, param_name: str, config: Dict[str, Any]) -> List[Any]:
+        """
+        Validates and expands the values for a single parameter configuration.
+
+        Args:
+            param_name: The name of the parameter being processed.
+            config: The configuration dictionary for this parameter.
+
+        Returns:
+            A list of the expanded values for the parameter.
+
+        Raises:
+            TypeError: If the configuration or its values have an incorrect type.
+            ValueError: If the configuration is missing keys or has an invalid structure.
+        """
+        # 1. Validate the configuration structure
+        if not isinstance(config, dict):
+            raise TypeError(f"Config for '{param_name}' must be a dictionary.")
+
+        if "type" not in config or "values" not in config:
+            raise ValueError(
+                f"Config for '{param_name}' must contain 'type' and 'values' keys."
+            )
+
+        gen_type = config["type"]
+        values = config["values"]
+
+        # 2. Handle the generation based on type
+        if gen_type == "list":
+            # Ensure values are returned as a list, even if a single item was provided
+            return values if isinstance(values, list) else [values]
+
+        elif gen_type == "range":
+            if not isinstance(values, list) or len(values) != 3:
+                raise ValueError(
+                    f"For 'range' type on '{param_name}', 'values' must be a list of 3 numbers "
+                    f"[start, end, step], but got: {values}"
+                )
+
+            start, end, step = values
+            if all(isinstance(v, int) for v in values):
+                return list(range(start, end, step))
+            elif all(isinstance(v, (int, float)) for v in values):
+                # Use numpy for floating point ranges
+                temp_list = list(np.arange(start, end, step))
+                # convert to float (not np.float)
+                return [float(v) for v in temp_list]
+            else:
+                raise TypeError(
+                    f"All 'values' for 'range' on '{param_name}' must be numbers."
+                )
+
+        else:
+            raise ValueError(
+                f"Invalid 'type' for '{param_name}': '{gen_type}'. Must be 'list' or 'range'."
+            )
+
+    def _build(self, params_file):
+        """
+        Builds a full optimization configuration by expanding parameter values based on their type.
+
+        This function processes a dictionary where each key is a parameter name and each value
+        is a config dict specifying the 'type' ('list' or 'range') and 'values' for generation.
+
+        Args:
+            opt_cfg: The input configuration dictionary.
+            Example:
+                {
+                    "learning_rate": {"type": "range", "values": [0.01, 0.1, 0.01]},
+                    "optimizer": {"type": "list", "values": ["adam", "sgd"]},
+                    "epochs": {"type": "list", "values": 100}
+                }
+
+        Returns:
+            A dictionary with parameter names mapped to their fully expanded list of values.
+        """
+        cfg_raw_dict = yamlfile.load_yaml(params_file, to_dict=True)
+        if not isinstance(cfg_raw_dict, dict):
+            raise TypeError("The entire opt_cfg must be a dictionary.")
+
+        # Use a dictionary comprehension for a clean and efficient build
+        return {
+            param_name: self._expand_param(param_name, config)
+            for param_name, config in cfg_raw_dict.items()
+        }
+
+    def save(self, outfile):
+        with open(outfile, "w") as f:
+            yaml.dump(self.params, f)
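For orientation, a minimal usage sketch of the new ParamGen builder; the params.yaml path and grid values are hypothetical, mirroring the docstring example above.

    # params.yaml (hypothetical), laid out like the docstring example:
    #   learning_rate: {type: range, values: [0.01, 0.1, 0.01]}
    #   optimizer:     {type: list,  values: [adam, sgd]}
    #   epochs:        {type: list,  values: 100}
    from halib.research.core.param_gen import ParamGen

    params = ParamGen.build_from_file("params.yaml")
    # -> {"learning_rate": [0.01, 0.02, ...], "optimizer": ["adam", "sgd"], "epochs": [100]}

    gen = ParamGen("params.yaml")
    gen.save("params_expanded.yaml")  # persist the expanded grid back to YAML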
halib/research/core/wandb_op.py
ADDED

@@ -0,0 +1,117 @@
+import os
+import glob
+import wandb
+import argparse
+import subprocess
+
+from tqdm import tqdm
+from rich.console import Console
+
+console = Console()
+
+def sync_runs(outdir):
+    outdir = os.path.abspath(outdir)
+    assert os.path.exists(outdir), f"Output directory {outdir} does not exist."
+    sub_dirs = [name for name in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, name))]
+    assert len(sub_dirs) > 0, f"No subdirectories found in {outdir}."
+    console.rule("Parent Directory")
+    console.print(f"[yellow]{outdir}[/yellow]")
+
+    exp_dirs = [os.path.join(outdir, sub_dir) for sub_dir in sub_dirs]
+    wandb_dirs = []
+    for exp_dir in exp_dirs:
+        wandb_dirs.extend(glob.glob(f"{exp_dir}/wandb/*run-*"))
+    if len(wandb_dirs) == 0:
+        console.print(f"No wandb runs found in {outdir}.")
+        return
+    else:
+        console.print(f"Found [bold]{len(wandb_dirs)}[/bold] wandb runs in {outdir}.")
+    for i, wandb_dir in enumerate(wandb_dirs):
+        console.rule(f"Syncing wandb run {i + 1}/{len(wandb_dirs)}")
+        console.print(f"Syncing: {wandb_dir}")
+        process = subprocess.Popen(
+            ["wandb", "sync", wandb_dir],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+        )
+
+        for line in process.stdout:
+            console.print(line.strip())
+            if " ERROR Error while calling W&B API" in line:
+                break
+        process.stdout.close()
+        process.wait()
+        if process.returncode != 0:
+            console.print(f"[red]Error syncing {wandb_dir}. Return code: {process.returncode}[/red]")
+        else:
+            console.print(f"Successfully synced {wandb_dir}.")
+
+def delete_runs(project, pattern=None):
+    console.rule("Delete W&B Runs")
+    confirm_msg = f"Are you sure you want to delete all runs in"
+    confirm_msg += f" \n\tproject: [red]{project}[/red]"
+    if pattern:
+        confirm_msg += f"\n\tpattern: [blue]{pattern}[/blue]"
+
+    console.print(confirm_msg)
+    confirmation = input(f"This action cannot be undone. [y/N]: ").strip().lower()
+    if confirmation != "y":
+        print("Cancelled.")
+        return
+
+    print("Confirmed. Proceeding...")
+    api = wandb.Api()
+    runs = api.runs(project)
+
+    deleted = 0
+    console.rule("Deleting W&B Runs")
+    if len(runs) == 0:
+        print("No runs found in the project.")
+        return
+    for run in tqdm(runs):
+        if pattern is None or pattern in run.name:
+            run.delete()
+            console.print(f"Deleted run: [red]{run.name}[/red]")
+            deleted += 1
+
+    console.print(f"Total runs deleted: {deleted}")
+
+
+def valid_argument(args):
+    if args.op == "sync":
+        assert os.path.exists(args.outdir), f"Output directory {args.outdir} does not exist."
+    elif args.op == "delete":
+        assert isinstance(args.project, str) and len(args.project.strip()) > 0, "Project name must be a non-empty string."
+    else:
+        raise ValueError(f"Unknown operation: {args.op}")
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Operations on W&B runs")
+    parser.add_argument("-op", "--op", type=str, help="Operation to perform", default="sync", choices=["delete", "sync"])
+    parser.add_argument("-prj", "--project", type=str, default="fire-paper2-2025", help="W&B project name")
+    parser.add_argument("-outdir", "--outdir", type=str, help="arg1 description", default="./zout/train")
+    parser.add_argument("-pt", "--pattern",
+        type=str,
+        default=None,
+        help="Run name pattern to match for deletion",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    # Validate arguments, stop if invalid
+    valid_argument(args)
+
+    op = args.op
+    if op == "sync":
+        sync_runs(args.outdir)
+    elif op == "delete":
+        delete_runs(args.project, args.pattern)
+    else:
+        raise ValueError(f"Unknown operation: {op}")
+
+if __name__ == "__main__":
+    main()
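For orientation, the two operations above can also be called directly from Python rather than through the CLI; a minimal sketch, with the output directory, project name, and pattern all hypothetical.

    from halib.research.core import wandb_op

    # sync every offline run found under ./zout/train/<experiment>/wandb/*run-*
    wandb_op.sync_runs("./zout/train")

    # delete runs whose names contain "debug" from a project (prompts for confirmation first)
    wandb_op.delete_runs("my-project", pattern="debug")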
halib/research/data/__init__.py
File without changes
halib/research/data/dataclass_util.py
ADDED

@@ -0,0 +1,41 @@
+import yaml
+from typing import Any
+
+from rich.pretty import pprint
+from dataclasses import make_dataclass
+
+from ...filetype import yamlfile
+
+def dict_to_dataclass(name: str, data: dict):
+    fields = []
+    values = {}
+
+    for key, value in data.items():
+        if isinstance(value, dict):
+            sub_dc = dict_to_dataclass(key.capitalize(), value)
+            fields.append((key, type(sub_dc)))
+            values[key] = sub_dc
+        else:
+            field_type = type(value) if value is not None else Any
+            fields.append((key, field_type))
+            values[key] = value
+
+    DC = make_dataclass(name.capitalize(), fields)
+    return DC(**values)
+
+def yaml_to_dataclass(name: str, yaml_str: str):
+    data = yaml.safe_load(yaml_str)
+    return dict_to_dataclass(name, data)
+
+
+def yamlfile_to_dataclass(name: str, file_path: str):
+    data_dict = yamlfile.load_yaml(file_path, to_dict=True)
+    if "__base__" in data_dict:
+        del data_dict["__base__"]
+    return dict_to_dataclass(name, data_dict)
+
+if __name__ == "__main__":
+    cfg = yamlfile_to_dataclass("Config", "test/dataclass_util_test_cfg.yaml")
+
+    # ! NOTICE: after print out this dataclass, we can copy the output and paste it into CHATGPT to generate a list of needed dataclass classes using `from dataclass_wizard import YAMLWizard`
+    pprint(cfg)
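For orientation, a minimal sketch of dict_to_dataclass on a nested dict; the keys and values are hypothetical. Nested dicts are recursively turned into nested dataclass instances, so attribute access works all the way down.

    from halib.research.data.dataclass_util import dict_to_dataclass

    cfg = dict_to_dataclass("config", {
        "train": {"lr": 0.001, "epochs": 100},
        "model": {"name": "resnet18"},
    })
    print(cfg.train.lr)    # 0.001
    print(cfg.model.name)  # resnet18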
halib/research/data/dataset.py
ADDED

@@ -0,0 +1,208 @@
+# This script create a test version
+# of the watcam (wc) dataset
+# for testing the tflite model
+
+from argparse import ArgumentParser
+
+import os
+import click
+import shutil
+from tqdm import tqdm
+from rich import inspect
+from rich.pretty import pprint
+from torchvision.datasets import ImageFolder
+from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
+
+from ...common.common import console, seed_everything, ConsoleLog
+from ...system import filesys as fs
+
+def parse_args():
+    parser = ArgumentParser(description="desc text")
+    parser.add_argument(
+        "-indir",
+        "--indir",
+        type=str,
+        help="orignal dataset path",
+    )
+    parser.add_argument(
+        "-outdir",
+        "--outdir",
+        type=str,
+        help="dataset out path",
+        default=".",  # default to current dir
+    )
+    parser.add_argument(
+        "-val_size",
+        "--val_size",
+        type=float,
+        help="validation size",  # no default value to force user to input
+        default=0.2,
+    )
+    # add using StratifiedShuffleSplit or ShuffleSplit
+    parser.add_argument(
+        "-seed",
+        "--seed",
+        type=int,
+        help="random seed",
+        default=42,
+    )
+    parser.add_argument(
+        "-inplace",
+        "--inplace",
+        action="store_true",
+        help="inplace operation, will overwrite the outdir if exists",
+    )
+
+    parser.add_argument(
+        "-stratified",
+        "--stratified",
+        action="store_true",
+        help="use StratifiedShuffleSplit instead of ShuffleSplit",
+    )
+    parser.add_argument(
+        "-no_train",
+        "--no_train",
+        action="store_true",
+        help="only create test set, no train set",
+    )
+    parser.add_argument(
+        "-reverse",
+        "--reverse",
+        action="store_true",
+        help="combine train and val set back to original dataset",
+    )
+    return parser.parse_args()
+
+
+def move_images(image_paths, target_set_dir):
+    for img_path in tqdm(image_paths):
+        # get folder name of the image
+        img_dir = os.path.dirname(img_path)
+        out_cls_dir = os.path.join(target_set_dir, os.path.basename(img_dir))
+        if not os.path.exists(out_cls_dir):
+            os.makedirs(out_cls_dir)
+        # move the image to the class folder
+        shutil.move(img_path, out_cls_dir)
+
+
+def split_dataset_cls(
+    indir, outdir, val_size, seed, inplace, stratified_split, no_train
+):
+    seed_everything(seed)
+    console.rule("Config confirm?")
+    pprint(locals())
+    click.confirm("Continue?", abort=True)
+    assert os.path.exists(indir), f"{indir} does not exist"
+
+    if not inplace:
+        assert (not inplace) and (
+            not os.path.exists(outdir)
+        ), f"{outdir} already exists; SKIP ...."
+
+    if inplace:
+        outdir = indir
+    if not os.path.exists(outdir):
+        os.makedirs(outdir)
+
+    console.rule(f"Creating train/val dataset")
+
+    sss = (
+        ShuffleSplit(n_splits=1, test_size=val_size)
+        if not stratified_split
+        else StratifiedShuffleSplit(n_splits=1, test_size=val_size)
+    )
+
+    pprint({"split strategy": sss, "indir": indir, "outdir": outdir})
+    dataset = ImageFolder(
+        root=indir,
+        transform=None,
+    )
+    train_dataset_indices = None
+    val_dataset_indices = None  # val here means test
+    for train_indices, val_indices in sss.split(dataset.samples, dataset.targets):
+        train_dataset_indices = train_indices
+        val_dataset_indices = val_indices
+
+    # get image paths for train/val split dataset
+    train_image_paths = [dataset.imgs[i][0] for i in train_dataset_indices]
+    val_image_paths = [dataset.imgs[i][0] for i in val_dataset_indices]
+
+    # start creating train/val folders then move images
+    out_train_dir = os.path.join(outdir, "train")
+    out_val_dir = os.path.join(outdir, "val")
+    if inplace:
+        assert os.path.exists(out_train_dir) == False, f"{out_train_dir} already exists"
+        assert os.path.exists(out_val_dir) == False, f"{out_val_dir} already exists"
+
+    os.makedirs(out_train_dir)
+    os.makedirs(out_val_dir)
+
+    if not no_train:
+        with ConsoleLog(f"Moving train images to {out_train_dir} "):
+            move_images(train_image_paths, out_train_dir)
+    else:
+        pprint("test only, skip moving train images")
+        # remove out_train_dir
+        shutil.rmtree(out_train_dir)
+
+    with ConsoleLog(f"Moving val images to {out_val_dir} "):
+        move_images(val_image_paths, out_val_dir)
+
+    if inplace:
+        pprint(f"remove all folders, except train and val")
+        for cls_dir in os.listdir(outdir):
+            if cls_dir not in ["train", "val"]:
+                shutil.rmtree(os.path.join(indir, cls_dir))
+
+
+def reverse_split_ds(indir):
+    console.rule(f"Reversing split dataset <{indir}>...")
+    ls_dirs = os.listdir(indir)
+    # make sure there are only two dirs 'train' and 'val'
+    assert len(ls_dirs) == 2, f"Found more than 2 dirs: {len(ls_dirs) } dirs"
+    assert "train" in ls_dirs, f"train dir not found in {indir}"
+    assert "val" in ls_dirs, f"val dir not found in {indir}"
+    train_dir = os.path.join(indir, "train")
+    val_dir = os.path.join(indir, "val")
+    all_train_files = fs.filter_files_by_extension(
+        train_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+    )
+    all_val_files = fs.filter_files_by_extension(
+        val_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+    )
+    # move all files from train to indir
+    with ConsoleLog(f"Moving train images to {indir} "):
+        move_images(all_train_files, indir)
+    with ConsoleLog(f"Moving val images to {indir} "):
+        move_images(all_val_files, indir)
+    with ConsoleLog(f"Removing train and val dirs"):
+        # remove train and val dirs
+        shutil.rmtree(train_dir)
+        shutil.rmtree(val_dir)
+
+
+def main():
+    args = parse_args()
+    indir = args.indir
+    outdir = args.outdir
+    if outdir == ".":
+        # get current folder of the indir
+        indir_parent_dir = os.path.dirname(os.path.normpath(indir))
+        indir_name = os.path.basename(indir)
+        outdir = os.path.join(indir_parent_dir, f"{indir_name}_split")
+    val_size = args.val_size
+    seed = args.seed
+    inplace = args.inplace
+    stratified_split = args.stratified
+    no_train = args.no_train
+    reverse = args.reverse
+    if not reverse:
+        split_dataset_cls(
+            indir, outdir, val_size, seed, inplace, stratified_split, no_train
+        )
+    else:
+        reverse_split_ds(indir)
+
+
+if __name__ == "__main__":
+    main()
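For orientation, the splitter can be driven programmatically as well as from the command line; a minimal sketch with all paths hypothetical. Note that split_dataset_cls asks for confirmation before touching any files, and that images are moved (not copied) into the train/ and val/ folders.

    from halib.research.data import dataset as ds_split

    # split an ImageFolder-style dataset (one subfolder per class) into train/ and val/
    ds_split.split_dataset_cls(
        indir="./data/my_dataset",
        outdir="./data/my_dataset_split",
        val_size=0.2,
        seed=42,
        inplace=False,
        stratified_split=True,
        no_train=False,
    )

    # undo the split: move the images back and remove the train/ and val/ folders
    ds_split.reverse_split_ds("./data/my_dataset_split")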
halib/research/data/torchloader.py
ADDED

@@ -0,0 +1,165 @@
+"""
+* @author Hoang Van-Ha
+* @email hoangvanhauit@gmail.com
+* @create date 2024-03-27 15:40:22
+* @modify date 2024-03-27 15:40:22
+* @desc this module works as a utility tools for finding the best configuration for dataloader (num_workers, batch_size, pin_menory, etc.) that fits your hardware.
+"""
+from argparse import ArgumentParser
+
+import os
+import time
+import traceback
+
+from tqdm import tqdm
+from rich import inspect
+from typing import Union
+import itertools as it  # for cartesian product
+
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+from ...common.common import *
+from ...filetype import csvfile
+from ...filetype.yamlfile import load_yaml
+
+def parse_args():
+    parser = ArgumentParser(description="desc text")
+    parser.add_argument("-cfg", "--cfg", type=str, help="cfg file for searching")
+    return parser.parse_args()
+
+
+def get_test_range(cfg: dict, search_item="num_workers"):
+    item_search_cfg = cfg["search_space"].get(search_item, None)
+    if item_search_cfg is None:
+        raise ValueError(f"search_item: {search_item} not found in cfg")
+    if isinstance(item_search_cfg, list):
+        return item_search_cfg
+    elif isinstance(item_search_cfg, dict):
+        if "mode" in item_search_cfg:
+            mode = item_search_cfg["mode"]
+            assert mode in ["range", "list"], f"mode: {mode} not supported"
+            value_in_mode = item_search_cfg.get(mode, None)
+            if value_in_mode is None:
+                raise ValueError(f"mode<{mode}>: data not found in <{search_item}>")
+            if mode == "range":
+                assert len(value_in_mode) == 3, f"range must have 3 values: start, stop, step"
+                start = value_in_mode[0]
+                stop = value_in_mode[1]
+                step = value_in_mode[2]
+                return list(range(start, stop, step))
+            elif mode == "list":
+                return item_search_cfg["list"]
+    else:
+        return [item_search_cfg]  # for int, float, str, bool, etc.
+
+
+def load_an_batch(loader_iter):
+    start = time.time()
+    next(loader_iter)
+    end = time.time()
+    return end - start
+
+
+def test_dataloader_with_cfg(origin_dataloader: DataLoader, cfg: Union[dict, str]):
+    try:
+        if isinstance(cfg, str):
+            cfg = load_yaml(cfg, to_dict=True)
+        dfmk = csvfile.DFCreator()
+        search_items = ["batch_size", "num_workers", "persistent_workers", "pin_memory"]
+        batch_limit = cfg["general"]["batch_limit"]
+        csv_cfg = cfg["general"]["to_csv"]
+        log_batch_info = cfg["general"]["log_batch_info"]
+
+        save_to_csv = csv_cfg["enabled"]
+        log_dir = csv_cfg["log_dir"]
+        filename = csv_cfg["filename"]
+        filename = f"{now_str()}_{filename}.csv"
+        outfile = os.path.join(log_dir, filename)
+
+        dfmk.create_table(
+            "cfg_search",
+            (search_items + ["avg_time_taken"]),
+        )
+        ls_range_test = []
+        for item in search_items:
+            range_test = get_test_range(cfg, search_item=item)
+            range_test = [(item, i) for i in range_test]
+            ls_range_test.append(range_test)
+
+        all_combinations = list(it.product(*ls_range_test))
+
+        rows = []
+        for cfg_idx, combine in enumerate(all_combinations):
+            console.rule(f"Testing cfg {cfg_idx+1}/{len(all_combinations)}")
+            inspect(combine)
+            batch_size = combine[search_items.index("batch_size")][1]
+            num_workers = combine[search_items.index("num_workers")][1]
+            persistent_workers = combine[search_items.index("persistent_workers")][1]
+            pin_memory = combine[search_items.index("pin_memory")][1]
+
+            test_dataloader = DataLoader(origin_dataloader.dataset, batch_size=batch_size, num_workers=num_workers, persistent_workers=persistent_workers, pin_memory=pin_memory, shuffle=True)
+            row = [
+                batch_size,
+                num_workers,
+                persistent_workers,
+                pin_memory,
+                0.0,
+            ]
+
+            # calculate the avg time taken to load the data for <batch_limit> batches
+            trainiter = iter(test_dataloader)
+            time_elapsed = 0
+            pprint('Start testing...')
+            for i in tqdm(range(batch_limit)):
+                single_batch_time = load_an_batch(trainiter)
+                if log_batch_info:
+                    pprint(f"Batch {i+1} took {single_batch_time:.4f} seconds to load")
+                time_elapsed += single_batch_time
+            row[-1] = time_elapsed / batch_limit
+            rows.append(row)
+        dfmk.insert_rows('cfg_search', rows)
+        dfmk.fill_table_from_row_pool('cfg_search')
+        with ConsoleLog("results"):
+            csvfile.fn_display_df(dfmk['cfg_search'])
+        if save_to_csv:
+            dfmk["cfg_search"].to_csv(outfile, index=False)
+            console.print(f"[red] Data saved to <{outfile}> [/red]")
+
+    except Exception as e:
+        traceback.print_exc()
+        print(e)
+        # get current directory of this python file
+        current_dir = os.path.dirname(os.path.realpath(__file__))
+        standar_cfg_path = os.path.join(current_dir, "torchloader_search.yaml")
+        pprint(
+            f"Make sure you get the right <cfg.yaml> file. An example of <cfg.yaml> file can be found at this path: {standar_cfg_path}"
+        )
+        return
+
+def main():
+    args = parse_args()
+    cfg_yaml = args.cfg
+    cfg_dict = load_yaml(cfg_yaml, to_dict=True)
+
+    # Define transforms for data augmentation and normalization
+    transform = transforms.Compose(
+        [
+            transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
+            transforms.RandomRotation(10),  # Randomly rotate images by 10 degrees
+            transforms.ToTensor(),  # Convert images to PyTorch tensors
+            transforms.Normalize(
+                (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
+            ),  # Normalize pixel values to [-1, 1]
+        ]
+    )
+    test_dataset = datasets.CIFAR10(
+        root="./data", train=False, download=True, transform=transform
+    )
+    batch_size = 64
+    train_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
+    test_dataloader_with_cfg(train_loader, cfg_dict)
+
+
+if __name__ == "__main__":
+    main()
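For orientation, test_dataloader_with_cfg also accepts the config as a plain dict; a minimal sketch of the search space it reads (all values hypothetical), timed against a throwaway FakeData loader instead of a real dataset.

    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    from halib.research.data.torchloader import test_dataloader_with_cfg

    cfg = {
        "general": {
            "batch_limit": 20,       # number of batches timed per combination
            "log_batch_info": False,
            "to_csv": {"enabled": False, "log_dir": ".", "filename": "loader_search"},
        },
        "search_space": {
            "batch_size": {"mode": "list", "list": [32, 64, 128]},
            "num_workers": {"mode": "range", "range": [2, 9, 2]},  # expands to 2, 4, 6, 8
            "persistent_workers": [True],
            "pin_memory": True,      # a scalar is treated as a single-value search range
        },
    }

    loader = DataLoader(datasets.FakeData(size=5000, transform=transforms.ToTensor()), batch_size=64)
    test_dataloader_with_cfg(loader, cfg)  # prints a table of avg load time per combination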
halib/research/dataset.py
CHANGED

@@ -13,8 +13,8 @@ from rich.pretty import pprint
 from torchvision.datasets import ImageFolder
 from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
 
-from ..system import filesys as fs
 from ..common import console, seed_everything, ConsoleLog
+from ..system import filesys as fs
 
 def parse_args():
     parser = ArgumentParser(description="desc text")
halib/research/metrics.py
CHANGED

@@ -11,6 +11,10 @@ class MetricsBackend(ABC):
     def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
         """
         Initialize the backend with optional metrics_info.
+        `metrics_info` can be either:
+        - A list of metric names (strings), e.g. ["accuracy", "precision"]
+        - A dict mapping metric names to objects that define how to compute them, e.g. {"accuracy": torchmetrics.Accuracy(), "precision": torchmetrics.Precision()}
+
         """
         self.metric_info = metrics_info
         self.validate_metrics_info(self.metric_info)
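For orientation, the two accepted shapes of metrics_info per the docstring added above; the concrete MetricsBackend subclass is not shown in this hunk, so only the argument shapes are illustrated, and the torchmetrics constructor arguments are an assumption for recent torchmetrics versions.

    import torchmetrics

    # 1) plain metric names
    metrics_info = ["accuracy", "precision"]

    # 2) metric names mapped to objects that compute them
    metrics_info = {
        "accuracy": torchmetrics.Accuracy(task="multiclass", num_classes=10),
        "precision": torchmetrics.Precision(task="multiclass", num_classes=10),
    }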