halib 0.1.89__tar.gz → 0.2.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {halib-0.1.89 → halib-0.2.13}/.gitignore +0 -1
- halib-0.2.13/MANIFEST.in +5 -0
- {halib-0.1.89 → halib-0.2.13}/PKG-INFO +21 -3
- {halib-0.1.89 → halib-0.2.13}/README.md +19 -2
- {halib-0.1.89 → halib-0.2.13}/halib/__init__.py +3 -3
- halib-0.2.13/halib/common/common.py +178 -0
- halib-0.2.13/halib/exp/core/base_config.py +167 -0
- halib-0.2.13/halib/exp/core/base_exp.py +147 -0
- halib-0.2.13/halib/exp/core/param_gen.py +181 -0
- {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/core}/wandb_op.py +5 -4
- {halib-0.1.89/halib/utils → halib-0.2.13/halib/exp/data}/dataclass_util.py +3 -2
- {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/data}/dataset.py +6 -7
- {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/data}/torchloader.py +12 -9
- halib-0.2.13/halib/exp/perf/flop_calc.py +190 -0
- {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/perf}/perfcalc.py +195 -91
- halib-0.1.89/halib/research/metrics.py → halib-0.2.13/halib/exp/perf/perfmetrics.py +4 -0
- {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/perf}/perftb.py +6 -7
- {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/perf}/profiler.py +6 -6
- halib-0.2.13/halib/exp/viz/__init__.py +0 -0
- halib-0.2.13/halib/exp/viz/plot.py +754 -0
- halib-0.2.13/halib/filetype/__init__.py +0 -0
- {halib-0.1.89 → halib-0.2.13}/halib/filetype/csvfile.py +3 -9
- halib-0.2.13/halib/filetype/ipynb.py +61 -0
- {halib-0.1.89 → halib-0.2.13}/halib/filetype/jsonfile.py +0 -3
- {halib-0.1.89 → halib-0.2.13}/halib/filetype/textfile.py +0 -1
- {halib-0.1.89 → halib-0.2.13}/halib/filetype/videofile.py +91 -2
- {halib-0.1.89 → halib-0.2.13}/halib/filetype/yamlfile.py +16 -1
- halib-0.2.13/halib/online/__init__.py +0 -0
- {halib-0.1.89 → halib-0.2.13}/halib/online/projectmake.py +7 -6
- {halib-0.1.89/halib/utils → halib-0.2.13/halib/online}/tele_noti.py +1 -2
- halib-0.2.13/halib/system/__init__.py +0 -0
- halib-0.2.13/halib/system/_list_pc.csv +6 -0
- halib-0.2.13/halib/system/filesys.py +164 -0
- halib-0.2.13/halib/system/path.py +106 -0
- halib-0.2.13/halib/utils/__init__.py +0 -0
- halib-0.1.89/halib/utils/listop.py → halib-0.2.13/halib/utils/list.py +0 -1
- {halib-0.1.89 → halib-0.2.13}/halib.egg-info/PKG-INFO +21 -3
- halib-0.2.13/halib.egg-info/SOURCES.txt +54 -0
- {halib-0.1.89 → halib-0.2.13}/halib.egg-info/requires.txt +1 -0
- {halib-0.1.89 → halib-0.2.13}/setup.py +2 -1
- halib-0.1.89/MANIFEST.in +0 -4
- halib-0.1.89/guide_publish_pip.pdf +0 -0
- halib-0.1.89/halib/common.py +0 -108
- halib-0.1.89/halib/cuda.py +0 -39
- halib-0.1.89/halib/online/gdrive_test.py +0 -50
- halib-0.1.89/halib/research/base_config.py +0 -100
- halib-0.1.89/halib/research/base_exp.py +0 -100
- halib-0.1.89/halib/research/mics.py +0 -16
- halib-0.1.89/halib/research/plot.py +0 -496
- halib-0.1.89/halib/system/filesys.py +0 -124
- halib-0.1.89/halib/utils/video.py +0 -76
- halib-0.1.89/halib.egg-info/SOURCES.txt +0 -49
- {halib-0.1.89 → halib-0.2.13}/GDriveFolder.txt +0 -0
- {halib-0.1.89 → halib-0.2.13}/LICENSE.txt +0 -0
- {halib-0.1.89/halib/filetype → halib-0.2.13/halib/common}/__init__.py +0 -0
- {halib-0.1.89/halib → halib-0.2.13/halib/common}/rich_color.py +0 -0
- {halib-0.1.89/halib/online → halib-0.2.13/halib/exp}/__init__.py +0 -0
- {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/core}/__init__.py +0 -0
- {halib-0.1.89/halib/system → halib-0.2.13/halib/exp/data}/__init__.py +0 -0
- {halib-0.1.89/halib/utils → halib-0.2.13/halib/exp/perf}/__init__.py +0 -0
- {halib-0.1.89/halib/utils → halib-0.2.13/halib/exp/perf}/gpu_mon.py +0 -0
- {halib-0.1.89 → halib-0.2.13}/halib/online/gdrive.py +0 -0
- {halib-0.1.89 → halib-0.2.13}/halib/online/gdrive_mkdir.py +0 -0
- {halib-0.1.89 → halib-0.2.13}/halib/system/cmd.py +0 -0
- /halib-0.1.89/halib/utils/dict_op.py → /halib-0.2.13/halib/utils/dict.py +0 -0
- {halib-0.1.89 → halib-0.2.13}/halib.egg-info/dependency_links.txt +0 -0
- {halib-0.1.89 → halib-0.2.13}/halib.egg-info/top_level.txt +0 -0
- {halib-0.1.89 → halib-0.2.13}/setup.cfg +0 -0
halib-0.2.13/MANIFEST.in
ADDED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: halib
|
|
3
|
-
Version: 0.1.89
|
|
3
|
+
Version: 0.2.13
|
|
4
4
|
Summary: Small library for common tasks
|
|
5
5
|
Author: Hoang Van Ha
|
|
6
6
|
Author-email: hoangvanhauit@gmail.com
|
|
@@ -40,6 +40,7 @@ Requires-Dist: timebudget
|
|
|
40
40
|
Requires-Dist: tqdm
|
|
41
41
|
Requires-Dist: tube_dl
|
|
42
42
|
Requires-Dist: wandb
|
|
43
|
+
Requires-Dist: ipynbname
|
|
43
44
|
Dynamic: author
|
|
44
45
|
Dynamic: author-email
|
|
45
46
|
Dynamic: classifier
|
|
@@ -50,9 +51,26 @@ Dynamic: requires-dist
|
|
|
50
51
|
Dynamic: requires-python
|
|
51
52
|
Dynamic: summary
|
|
52
53
|
|
|
53
|
-
Helper package for coding and automation
|
|
54
|
+
# Helper package for coding and automation
|
|
54
55
|
|
|
55
|
-
**Version 0.
|
|
56
|
+
**Version 0.2.13**
|
|
57
|
+
+ reorganize packages, with most changes in the `research` package; also rename `research` to `exp` (package for experiment management and utilities)
|
|
58
|
+
+ update `exp/perf/perfcalc.py` to allow saving computed performance to a CSV file (without explicitly calling the `calc_perfs` method)
|
|
59
|
+
|
|
60
|
+
**Version 0.2.1**
|
|
61
|
+
+ `research/base_exp`: add `eval_exp` method to evaluate an experiment (e.g., model evaluation on the test set) after the experiment has finished running.
|
|
62
|
+
|
|
63
|
+
**Version 0.1.99**
|
|
64
|
+
+ `filetype/ipynb`: add `gen_ipynb_name` generator to create a file name prefixed with the current notebook name (with an optional timestamp)
|
|
65
|
+
|
|
66
|
+
**Version 0.1.96**
|
|
67
|
+
+ `research/plot`: add `PlotHelper` class to plot train history + plot grid of images (e.g., image samples from dataset or model outputs)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
**Version 0.1.91**
|
|
71
|
+
+ `research/param_gen`: add `ParamGen` class to generate parameter list from yaml file for hyperparameter search (grid search, random search, etc.)
|
|
72
|
+
|
|
73
|
+
**Version 0.1.90**
|
|
56
74
|
|
|
57
75
|
+ `research/profiler`: add `zProfiler` class to measure execution time of contexts and steps, with support for dynamic color scales in plots.
|
|
58
76
|
|
|
@@ -1,6 +1,23 @@
|
|
|
1
|
-
Helper package for coding and automation
|
|
1
|
+
# Helper package for coding and automation
|
|
2
2
|
|
|
3
|
-
**Version 0.
|
|
3
|
+
**Version 0.2.13**
|
|
4
|
+
+ reorganize packages, with most changes in the `research` package; also rename `research` to `exp` (package for experiment management and utilities)
|
|
5
|
+
+ update `exp/perf/perfcalc.py` to allow saving computed performance to a CSV file (without explicitly calling the `calc_perfs` method)
|
|
6
|
+
|
|
7
|
+
**Version 0.2.1**
|
|
8
|
+
+ `research/base_exp`: add `eval_exp` method to evaluate an experiment (e.g., model evaluation on the test set) after the experiment has finished running.
|
|
9
|
+
|
|
10
|
+
**Version 0.1.99**
|
|
11
|
+
+ `filetype/ipynb`: add `gen_ipynb_name` generator to create a file name prefixed with the current notebook name (with an optional timestamp)
|
|
12
|
+
|
|
13
|
+
**Version 0.1.96**
|
|
14
|
+
+ `research/plot`: add `PlotHelper` class to plot train history + plot grid of images (e.g., image samples from dataset or model outputs)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
**Version 0.1.91**
|
|
18
|
+
+ `research/param_gen`: add `ParamGen` class to generate parameter list from yaml file for hyperparameter search (grid search, random search, etc.)
|
|
19
|
+
|
|
20
|
+
**Version 0.1.90**
|
|
4
21
|
|
|
5
22
|
+ `research/profiler`: add `zProfiler` class to measure execution time of contexts and steps, with support for dynamic color scales in plots.
|
|
6
23
|
|
|
@@ -56,8 +56,7 @@ from .filetype.yamlfile import load_yaml
|
|
|
56
56
|
from .system import cmd
|
|
57
57
|
from .system import filesys as fs
|
|
58
58
|
from .filetype import csvfile
|
|
59
|
-
from .
|
|
60
|
-
from .common import (
|
|
59
|
+
from .common.common import (
|
|
61
60
|
console,
|
|
62
61
|
console_log,
|
|
63
62
|
ConsoleLog,
|
|
@@ -65,6 +64,7 @@ from .common import (
|
|
|
65
64
|
norm_str,
|
|
66
65
|
pprint_box,
|
|
67
66
|
pprint_local_path,
|
|
67
|
+
tcuda
|
|
68
68
|
)
|
|
69
69
|
|
|
70
70
|
# for log
|
|
@@ -76,7 +76,7 @@ from timebudget import timebudget
|
|
|
76
76
|
import omegaconf
|
|
77
77
|
from omegaconf import OmegaConf
|
|
78
78
|
from omegaconf.dictconfig import DictConfig
|
|
79
|
-
from .rich_color import rcolor_str, rcolor_palette, rcolor_palette_all, rcolor_all_str
|
|
79
|
+
from .common.rich_color import rcolor_str, rcolor_palette, rcolor_palette_all, rcolor_all_str
|
|
80
80
|
|
|
81
81
|
# for visualization
|
|
82
82
|
import seaborn as sns
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import arrow
|
|
4
|
+
import importlib
|
|
5
|
+
|
|
6
|
+
import rich
|
|
7
|
+
from rich import print
|
|
8
|
+
from rich.panel import Panel
|
|
9
|
+
from rich.console import Console
|
|
10
|
+
from rich.pretty import pprint, Pretty
|
|
11
|
+
|
|
12
|
+
from pathlib import Path, PureWindowsPath
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Module-wide rich console shared by console_rule, pprint_local_path and tcuda.
console = Console()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy and, when available, PyTorch RNGs.

    If torch can be imported, its CPU and CUDA generators are seeded and cuDNN
    is put in deterministic mode with benchmarking disabled. When torch is not
    installed a notice is printed and only ``random``/NumPy are seeded.

    Args:
        seed: The seed value applied to every generator (default 42).
    """
    import random
    import numpy as np

    for seeder in (random.seed, np.random.seed):
        seeder(seed)

    # torch is optional: fall back gracefully when it cannot be imported
    try:
        import torch

        torch.manual_seed(seed)
        for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seeder(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
    except ImportError:
        pprint("torch not imported, skipping torch seed_everything")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def now_str(sep_date_time="."):
    """Return the current local time formatted as ``YYYYMMDD<sep>HHmmss``.

    Args:
        sep_date_time: Separator between the date and time parts; must be
            one of ``"."``, ``"_"`` or ``"-"``.

    Raises:
        AssertionError: If ``sep_date_time`` is not an allowed separator.
    """
    allowed_seps = [".", "_", "-"]
    assert (
        sep_date_time in allowed_seps
    ), "sep_date_time must be one of '.', '_', or '-'"
    fmt = "YYYYMMDD" + sep_date_time + "HHmmss"
    return arrow.now().format(fmt)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def norm_str(in_str):
    """Normalize a string into an underscore-separated token.

    Leading and trailing whitespace is removed first, then every internal run
    of whitespace characters is collapsed into a single underscore.

    Args:
        in_str: The string to normalize.

    Returns:
        The normalized string (``""`` for empty or all-whitespace input).
    """
    # Strip first: once whitespace has been replaced by "_", a later
    # .strip() no longer has any whitespace to remove.
    norm_string = in_str.strip()
    # Replace one or more whitespace characters with a single underscore
    norm_string = re.sub(r"\s+", "_", norm_string)
    return norm_string
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def pprint_box(obj, title="", border_style="green"):
    """Render *obj* with rich's pretty-printer inside a titled, bordered panel.

    Args:
        obj: Any object; it is shown fully expanded.
        title: Panel title (empty by default).
        border_style: rich style applied to the panel border.
    """
    body = Pretty(obj, expand_all=True)
    panel = Panel(body, title=title, border_style=border_style)
    rich.print(panel)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def console_rule(msg, do_norm_msg=True, is_end_tag=False):
    """Draw a horizontal console rule labelled like an XML tag.

    Args:
        msg: Label text for the rule.
        do_norm_msg: If True, pass the label through ``norm_str`` first.
        is_end_tag: If True, render a closing tag ``</msg>`` instead of ``<msg>``.
    """
    label = norm_str(msg) if do_norm_msg else msg
    opener = "</" if is_end_tag else "<"
    console.rule(f"{opener}{label}>")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def console_log(func):
    """Decorator that frames each call of *func* between two console rules.

    An opening rule ``<func_name>`` is drawn before the call and a closing
    rule ``</func_name>`` after it. The closing rule is drawn even when *func*
    raises, matching ``ConsoleLog``. The exception still propagates.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapped callable. It keeps *func*'s name, docstring and other
        metadata via ``functools.wraps``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        console_rule(func.__name__)
        try:
            return func(*args, **kwargs)
        finally:
            # Close the frame on both normal return and exception.
            console_rule(func.__name__, is_end_tag=True)

    return wrapper
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ConsoleLog:
    """Context manager that frames a block of output between two console rules.

    Entering draws ``<message>``; exiting draws ``</message>``. If the block
    raised, the exception type and message are printed as well. The exception
    is not suppressed.
    """

    def __init__(self, message):
        # Label used for both the opening and the closing rule.
        self.message = message

    def __enter__(self):
        console_rule(self.message)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        console_rule(self.message, is_end_tag=True)
        if exc_type is None:
            return
        print(f"An exception of type {exc_type} occurred.")
        print(f"Exception message: {exc_value}")
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def linux_to_wins_path(path: str) -> str:
    """
    Convert a Linux-style WSL path (/mnt/c/... or /mnt/d/...) to a Windows-style path (C:\\...).

    A bare drive mount such as ``/mnt/c`` (what ``Path("/mnt/c/").resolve()``
    produces) maps to the drive root ``C:\\``. Any other path is returned
    unchanged except that its separators are normalized to backslashes.

    Args:
        path: A POSIX-style path string.

    Returns:
        The Windows-style path string.
    """
    # Handle only /mnt/<single ASCII drive letter>[/...] style
    drive_char = path[5] if len(path) > 5 else ""
    is_wsl_mount = (
        path.startswith("/mnt/")
        and drive_char.isascii()
        and drive_char.isalpha()
        and (len(path) == 6 or path[6] == "/")
    )
    if is_wsl_mount:
        drive = drive_char.upper()  # Extract drive letter
        rest = path[6:] or "/"  # "/mnt/c" -> drive root
        win_path = f"{drive}:{rest}"  # Replace "/mnt/c/" with "C:/"
    else:
        win_path = path  # Return unchanged if not a WSL-style path
    # Normalize to Windows-style backslashes
    return str(PureWindowsPath(win_path))
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def pprint_local_path(
    local_path: str, get_wins_path: bool = False, tag: str = ""
) -> str:
    """
    Pretty-print a local path with emoji and clickable file:// URI.

    Args:
        local_path: Path to file or directory (Linux or Windows style). It is
            resolved to an absolute path before printing.
        get_wins_path: If True on a POSIX system, print a Windows-style URI
            for paths under a WSL ``/mnt/<drive>`` mount. Other paths keep
            their POSIX URI.
        tag: Optional console log tag. When given, the output is framed by
            ``ConsoleLog`` rules.

    Returns:
        The file URI string.
    """
    p = Path(local_path).resolve()
    type_str = "📄" if p.is_file() else "📁" if p.is_dir() else "❓"

    file_uri = p.as_uri()
    if get_wins_path and os.name == "posix":
        # Try WSL -> Windows conversion. linux_to_wins_path() always rewrites
        # separators, so comparing its output with str(p) cannot tell whether
        # a drive mapping happened. A non-/mnt path like /home/user would
        # become a drive-less "\home\user", whose as_uri() raises ValueError.
        # Only an absolute Windows path (with a drive) has a file URI.
        win_path = PureWindowsPath(linux_to_wins_path(str(p)))
        if win_path.is_absolute():
            file_uri = win_path.as_uri()

    content_str = f"{type_str} [link={file_uri}]{file_uri}[/link]"

    if tag:
        with ConsoleLog(tag):
            console.print(content_str)
    else:
        console.print(content_str)

    return file_uri
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def tcuda():
    """Report how many GPUs torch and tensorflow can see, if installed.

    The result is printed between ``<CUDA Library Stats>`` rules and returned.

    Returns:
        dict: Maps ``"torch"`` and ``"tensorflow"`` to either
        ``"Not Installed"`` or ``"<n> GPU(s) Available"``.
    """
    # A plain ``import importlib`` does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly before find_spec().
    import importlib.util

    NOT_INSTALLED = "Not Installed"
    GPU_AVAILABLE = "GPU(s) Available"
    ls_lib = ["torch", "tensorflow"]
    lib_stats = {lib: NOT_INSTALLED for lib in ls_lib}
    for lib in ls_lib:
        spec = importlib.util.find_spec(lib)
        if not spec:
            continue  # library not installed; keep NOT_INSTALLED
        if lib == "torch":
            import torch

            lib_stats[lib] = str(torch.cuda.device_count()) + " " + GPU_AVAILABLE
        elif lib == "tensorflow":
            import tensorflow as tf

            lib_stats[lib] = (
                str(len(tf.config.list_physical_devices("GPU")))
                + " "
                + GPU_AVAILABLE
            )
    console.rule("<CUDA Library Stats>")
    pprint(lib_stats)
    console.rule("</CUDA Library Stats>")
    return lib_stats
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from rich.pretty import pprint
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import List, Optional, TypeVar, Generic
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from dataclass_wizard import YAMLWizard
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class NamedCfg(ABC):
    """
    Abstract interface for configuration objects that carry a name.

    Concrete subclasses must override :meth:`get_name`.
    """

    @abstractmethod
    def get_name(self):
        """Return the configuration's name; must be implemented by subclasses."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class AutoNamedCfg(YAMLWizard, NamedCfg):
    """
    Dataclass mixin whose :meth:`get_name` simply returns ``self.name``.

    Classes using this are expected to set the ``name`` field. Because ``name``
    has a default value, dataclass subclasses cannot add fields without
    defaults after it (standard dataclass field-ordering rule).
    """

    name: Optional[str] = None

    def get_name(self):
        return self.name

    def __post_init__(self):
        # No validation on purpose: ``name`` may still be None right after an
        # initial load and must be filled in (by the user or the loader)
        # before the name is relied on.
        return None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Item type selected by BaseSelectorCfg; bound to AutoNamedCfg because items
# are matched on their ``name`` attribute.
T = TypeVar("T", bound=AutoNamedCfg)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class BaseSelectorCfg(Generic[T]):
    """
    Base class to handle the logic of selecting an item from a list by name.
    """

    def _resolve_selection(
        self, items: List[T], selected_name: Optional[str], context: str
    ) -> T:
        """Return the first item in *items* whose ``name`` equals *selected_name*.

        Args:
            items: Candidate configuration items. ``None`` is treated as empty.
            selected_name: Name of the item to select.
            context: Human-readable kind of item (e.g. "dataset"), used in
                error messages.

        Raises:
            ValueError: If *selected_name* is None, or no item matches it
                (including when *items* is empty or None).
        """
        if selected_name is None:
            raise ValueError(f"No {context} selected in the configuration.")

        # Linear scan keeps first-match semantics; config lists are short.
        match = next(
            (item for item in (items or []) if item.name == selected_name), None
        )
        if match is None:
            raise ValueError(
                f"{context.capitalize()} '{selected_name}' not found in the configuration list."
            )
        return match
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ExpBaseCfg(ABC, YAMLWizard):
    """
    Base class for experiment configuration objects.

    What a cfg class must have:
        1 - a dataset cfg
        2 - a metric cfg
        3 - a method cfg
    Subclasses also provide a general cfg and an output directory
    (see the abstract getters below).
    """

    # Canonical config name; computed and cached by get_cfg_name().
    cfg_name: Optional[str] = None

    # Save to a YAML file
    def save_to_outdir(
        self, filename: str = "__config.yaml", outdir=None, override: bool = False
    ) -> None:
        """
        Save the configuration as YAML inside the output directory.

        Args:
            filename: Name of the YAML file written inside the output directory.
            outdir: Explicit output directory. When None, ``self.get_outdir()``
                is used. The directory is created if it does not exist.
            override: If False and the file already exists, only print a
                notice and leave the existing file untouched.

        Raises:
            AssertionError: If no output directory is available or the path
                is not a directory.
        """
        output_dir = outdir if outdir is not None else self.get_outdir()
        missing_dir_msg = (
            f"Output directory '{output_dir}' does not exist or is not a directory."
        )
        # Validate before os.makedirs: makedirs(None) would otherwise fail with
        # an unrelated TypeError and this check would never be reached.
        assert output_dir is not None, missing_dir_msg
        os.makedirs(output_dir, exist_ok=True)
        assert os.path.isdir(output_dir), missing_dir_msg

        file_path = os.path.join(output_dir, filename)
        if os.path.exists(file_path) and not override:
            pprint(
                f"File '{file_path}' already exists. Use 'override=True' to overwrite."
            )
        else:
            # method of YAMLWizard to_yaml_file
            self.to_yaml_file(file_path)

    @classmethod
    @abstractmethod
    # load from a custom YAML file
    def from_custom_yaml_file(cls, yaml_file: str):
        """Load a configuration from a custom YAML file."""
        pass

    def get_cfg_name(self, sep: str = "__", *args, **kwargs) -> str:
        """
        Build, cache on ``self.cfg_name`` and return the canonical config name.

        The name joins the general cfg name, ``ds_<dataset name>`` and
        ``mt_<method name>`` with *sep*. If an ``extra`` keyword argument is
        given, it is appended as the last part.

        Raises:
            AssertionError: If ``extra`` is given but is not a string.
        """
        # Auto-derive the name from the general, dataset and method cfgs
        name_parts = [
            self.get_general_cfg().get_name(),
            f"ds_{self.get_dataset_cfg().get_name()}",
            f"mt_{self.get_method_cfg().get_name()}",
        ]
        if "extra" in kwargs:
            extra_info = kwargs["extra"]
            assert isinstance(extra_info, str), "'extra' kwarg must be a string."
            name_parts.append(extra_info)
        self.cfg_name = sep.join(name_parts)
        return self.cfg_name

    @abstractmethod
    def get_outdir(self):
        """
        Get the output directory for the configuration.
        This method should be implemented in subclasses.
        """
        return None

    @abstractmethod
    def get_general_cfg(self) -> NamedCfg:
        """
        Get the general configuration like output directory, log settings, SEED, etc.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def get_dataset_cfg(self) -> NamedCfg:
        """
        Get the dataset configuration.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def get_method_cfg(self) -> NamedCfg:
        """
        Get the method configuration.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def get_metric_cfg(self) -> NamedCfg:
        """
        Get the metric configuration.
        This method should be implemented in subclasses.
        """
        pass
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Tuple, Any, Optional
|
|
3
|
+
from .base_config import ExpBaseCfg
|
|
4
|
+
from ..perf.perfcalc import PerfCalc
|
|
5
|
+
from ..perf.perfmetrics import MetricsBackend
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ExpHook:
    """Base interface for all experiment hooks.

    Subclasses override whichever callbacks they need; the defaults do nothing.
    """

    def on_before_run(self, exp):
        """Hook point before the experiment *exp* runs (no-op by default)."""
        return None

    def on_after_run(self, exp, results):
        """Hook point after *exp* has run, given its *results* (no-op by default)."""
        return None
|
13
|
+
|
|
14
|
+
# ! SEE https://github.com/hahv/base_exp for sample usage
|
|
15
|
+
class BaseExp(PerfCalc, ABC):
    """
    Base class for experiments.

    Orchestrates the experiment pipeline using a pluggable metrics backend.
    Subclasses implement the abstract steps (``init_general``,
    ``prepare_dataset``, ``prepare_metrics``, ``exec_exp``). ``run_exp`` wires
    them together and notifies registered ``ExpHook`` objects through their
    ``on_before_run`` / ``on_after_run`` callbacks.
    """

    def __init__(self, config: ExpBaseCfg):
        self.config = config
        self.metric_backend = None
        # Flag to track if init_general/prepare_dataset has run
        self._is_env_ready = False
        # Registered ExpHook instances, notified in registration order.
        self.hooks = []

    def register_hook(self, hook: ExpHook):
        """Register *hook* so that ``run_exp`` notifies it."""
        self.hooks.append(hook)

    def _trigger_hooks(self, method_name: str, *args, **kwargs):
        """Call the callback for *method_name* on every registered hook.

        ``ExpHook`` declares ``on_<event>`` callbacks. Both ``"on_before_run"``
        and ``"before_run"`` are accepted here. For each hook the ``on_``
        spelling is preferred, with the bare spelling as a fallback, so
        duck-typed hooks defining either name are called exactly once.
        """
        event = method_name[3:] if method_name.startswith("on_") else method_name
        for hook in self.hooks:
            method = getattr(hook, f"on_{event}", None)
            if not callable(method):
                method = getattr(hook, event, None)
            if callable(method):
                method(*args, **kwargs)

    # -----------------------
    # PerfCalc Required Methods
    # -----------------------
    def get_dataset_name(self):
        return self.config.get_dataset_cfg().get_name()

    def get_experiment_name(self):
        return self.config.get_cfg_name()

    def get_metric_backend(self):
        """Return the metrics backend, preparing it lazily on first use."""
        # Compare to None: a backend object that happens to be falsy must not
        # be rebuilt on every call.
        if self.metric_backend is None:
            self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
        return self.metric_backend

    # -----------------------
    # Abstract Experiment Steps
    # -----------------------
    @abstractmethod
    def init_general(self, general_cfg):
        """Setup general settings like SEED, logging, env variables."""
        pass

    @abstractmethod
    def prepare_dataset(self, dataset_cfg):
        """Load/prepare dataset."""
        pass

    @abstractmethod
    def prepare_metrics(self, metric_cfg) -> MetricsBackend:
        """
        Prepare the metrics for the experiment.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def exec_exp(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
        """Run experiment process, e.g.: training/evaluation loop.

        Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
        """
        pass

    # -----------------------
    # Internal Helpers
    # -----------------------
    def _validate_and_unpack(self, results):
        """Split ``exec_exp`` results into ``(metrics_data, extra_data)``.

        Returns None when *results* is None.

        Raises:
            ValueError: If *results* is not a 2-element tuple/list.
        """
        if results is None:
            return None
        if not isinstance(results, (tuple, list)) or len(results) != 2:
            raise ValueError("exec must return (metrics_data, extra_data)")
        return results[0], results[1]

    def _prepare_environment(self, force_reload: bool = False):
        """
        Common setup. Skips if already initialized, unless force_reload is True.
        """
        if self._is_env_ready and not force_reload:
            # Environment is already prepared, skipping setup.
            return

        # 1. Run Setup
        self.init_general(self.config.get_general_cfg())
        self.prepare_dataset(self.config.get_dataset_cfg())

        # 2. Update metric backend (refresh if needed)
        self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())

        # 3. Mark as ready
        self._is_env_ready = True

    # -----------------------
    # Main Experiment Runner
    # -----------------------
    def run_exp(self, should_calc_metrics=True, reload_env=False, *args, **kwargs):
        """
        Run the whole experiment pipeline.
        :param reload_env: If True, forces dataset/general init to run again.
        :param should_calc_metrics: Whether to calculate and save metrics after execution.
        :kwargs Params:
            + 'outfile' to save csv file results,
            + 'outdir' to set output directory for experiment results.
            + 'return_df' to return a DataFrame of results instead of a dictionary.

        Full pipeline:
            1. Init
            2. Prepare Environment (General + Dataset + Metrics)
            3. Save Config
            4. Execute
            5. Calculate & Save Metrics
        Registered hooks get ``on_before_run(exp)`` after environment setup and
        ``on_after_run(exp, results)`` with the value that is returned.
        """
        self._prepare_environment(force_reload=reload_env)

        # Use ExpHook's callback names so registered hooks actually fire.
        self._trigger_hooks("on_before_run", self)

        # Save config before running
        self.config.save_to_outdir()

        # Execute experiment
        results = self.exec_exp(*args, **kwargs)

        if should_calc_metrics and results is not None:
            metrics_data, extra_data = self._validate_and_unpack(results)
            # Calculate & Save metrics (positional args first, as Python
            # binds them anyway; written this way for readability)
            perf_results = self.calc_perfs(
                *args, raw_metrics_data=metrics_data, extra_data=extra_data, **kwargs
            )
            self._trigger_hooks("on_after_run", self, perf_results)
            return perf_results

        self._trigger_hooks("on_after_run", self, results)
        return results
|