halib 0.1.89__tar.gz → 0.2.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {halib-0.1.89 → halib-0.2.13}/.gitignore +0 -1
  2. halib-0.2.13/MANIFEST.in +5 -0
  3. {halib-0.1.89 → halib-0.2.13}/PKG-INFO +21 -3
  4. {halib-0.1.89 → halib-0.2.13}/README.md +19 -2
  5. {halib-0.1.89 → halib-0.2.13}/halib/__init__.py +3 -3
  6. halib-0.2.13/halib/common/common.py +178 -0
  7. halib-0.2.13/halib/exp/core/base_config.py +167 -0
  8. halib-0.2.13/halib/exp/core/base_exp.py +147 -0
  9. halib-0.2.13/halib/exp/core/param_gen.py +181 -0
  10. {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/core}/wandb_op.py +5 -4
  11. {halib-0.1.89/halib/utils → halib-0.2.13/halib/exp/data}/dataclass_util.py +3 -2
  12. {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/data}/dataset.py +6 -7
  13. {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/data}/torchloader.py +12 -9
  14. halib-0.2.13/halib/exp/perf/flop_calc.py +190 -0
  15. {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/perf}/perfcalc.py +195 -91
  16. halib-0.1.89/halib/research/metrics.py → halib-0.2.13/halib/exp/perf/perfmetrics.py +4 -0
  17. {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/perf}/perftb.py +6 -7
  18. {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/perf}/profiler.py +6 -6
  19. halib-0.2.13/halib/exp/viz/__init__.py +0 -0
  20. halib-0.2.13/halib/exp/viz/plot.py +754 -0
  21. halib-0.2.13/halib/filetype/__init__.py +0 -0
  22. {halib-0.1.89 → halib-0.2.13}/halib/filetype/csvfile.py +3 -9
  23. halib-0.2.13/halib/filetype/ipynb.py +61 -0
  24. {halib-0.1.89 → halib-0.2.13}/halib/filetype/jsonfile.py +0 -3
  25. {halib-0.1.89 → halib-0.2.13}/halib/filetype/textfile.py +0 -1
  26. {halib-0.1.89 → halib-0.2.13}/halib/filetype/videofile.py +91 -2
  27. {halib-0.1.89 → halib-0.2.13}/halib/filetype/yamlfile.py +16 -1
  28. halib-0.2.13/halib/online/__init__.py +0 -0
  29. {halib-0.1.89 → halib-0.2.13}/halib/online/projectmake.py +7 -6
  30. {halib-0.1.89/halib/utils → halib-0.2.13/halib/online}/tele_noti.py +1 -2
  31. halib-0.2.13/halib/system/__init__.py +0 -0
  32. halib-0.2.13/halib/system/_list_pc.csv +6 -0
  33. halib-0.2.13/halib/system/filesys.py +164 -0
  34. halib-0.2.13/halib/system/path.py +106 -0
  35. halib-0.2.13/halib/utils/__init__.py +0 -0
  36. halib-0.1.89/halib/utils/listop.py → halib-0.2.13/halib/utils/list.py +0 -1
  37. {halib-0.1.89 → halib-0.2.13}/halib.egg-info/PKG-INFO +21 -3
  38. halib-0.2.13/halib.egg-info/SOURCES.txt +54 -0
  39. {halib-0.1.89 → halib-0.2.13}/halib.egg-info/requires.txt +1 -0
  40. {halib-0.1.89 → halib-0.2.13}/setup.py +2 -1
  41. halib-0.1.89/MANIFEST.in +0 -4
  42. halib-0.1.89/guide_publish_pip.pdf +0 -0
  43. halib-0.1.89/halib/common.py +0 -108
  44. halib-0.1.89/halib/cuda.py +0 -39
  45. halib-0.1.89/halib/online/gdrive_test.py +0 -50
  46. halib-0.1.89/halib/research/base_config.py +0 -100
  47. halib-0.1.89/halib/research/base_exp.py +0 -100
  48. halib-0.1.89/halib/research/mics.py +0 -16
  49. halib-0.1.89/halib/research/plot.py +0 -496
  50. halib-0.1.89/halib/system/filesys.py +0 -124
  51. halib-0.1.89/halib/utils/video.py +0 -76
  52. halib-0.1.89/halib.egg-info/SOURCES.txt +0 -49
  53. {halib-0.1.89 → halib-0.2.13}/GDriveFolder.txt +0 -0
  54. {halib-0.1.89 → halib-0.2.13}/LICENSE.txt +0 -0
  55. {halib-0.1.89/halib/filetype → halib-0.2.13/halib/common}/__init__.py +0 -0
  56. {halib-0.1.89/halib → halib-0.2.13/halib/common}/rich_color.py +0 -0
  57. {halib-0.1.89/halib/online → halib-0.2.13/halib/exp}/__init__.py +0 -0
  58. {halib-0.1.89/halib/research → halib-0.2.13/halib/exp/core}/__init__.py +0 -0
  59. {halib-0.1.89/halib/system → halib-0.2.13/halib/exp/data}/__init__.py +0 -0
  60. {halib-0.1.89/halib/utils → halib-0.2.13/halib/exp/perf}/__init__.py +0 -0
  61. {halib-0.1.89/halib/utils → halib-0.2.13/halib/exp/perf}/gpu_mon.py +0 -0
  62. {halib-0.1.89 → halib-0.2.13}/halib/online/gdrive.py +0 -0
  63. {halib-0.1.89 → halib-0.2.13}/halib/online/gdrive_mkdir.py +0 -0
  64. {halib-0.1.89 → halib-0.2.13}/halib/system/cmd.py +0 -0
  65. halib-0.1.89/halib/utils/dict_op.py → halib-0.2.13/halib/utils/dict.py +0 -0
  66. {halib-0.1.89 → halib-0.2.13}/halib.egg-info/dependency_links.txt +0 -0
  67. {halib-0.1.89 → halib-0.2.13}/halib.egg-info/top_level.txt +0 -0
  68. {halib-0.1.89 → halib-0.2.13}/setup.cfg +0 -0
.gitignore
@@ -50,7 +50,6 @@ Thumbs.db
 
  build
  dist
- data
 
  venv*/
 
MANIFEST.in (new file)
@@ -0,0 +1,5 @@
+ prune _archived
+ prune test
+ prune zout
+ exclude *.toml
+ include halib/system/_list_pc.csv
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: halib
- Version: 0.1.89
+ Version: 0.2.13
  Summary: Small library for common tasks
  Author: Hoang Van Ha
  Author-email: hoangvanhauit@gmail.com
@@ -40,6 +40,7 @@ Requires-Dist: timebudget
  Requires-Dist: tqdm
  Requires-Dist: tube_dl
  Requires-Dist: wandb
+ Requires-Dist: ipynbname
  Dynamic: author
  Dynamic: author-email
  Dynamic: classifier
@@ -50,9 +51,26 @@ Dynamic: requires-dist
  Dynamic: requires-python
  Dynamic: summary
 
- Helper package for coding and automation
+ # Helper package for coding and automation
 
- **Version 0.1.89**
+ **Version 0.2.13**
+ + reorganize packages with most changes in `research` package; also rename `research` to `exp` (package for experiment management and utilities)
+ + update `exp/perfcalc.py` to allow save computed performance to csv file (without explicit calling method `calc_perfs`)
+
+ **Version 0.2.1**
+ + `research/base_exp`: add `eval_exp` method to evaluate experiment (e.g., model evaluation on test set) after experiment running is done.
+
+ **Version 0.1.99**
+ + `filetype/ipynb`: add `gen_ipynb_name` generator to create file name based on current notebook name as prefix (with optional timestamp)
+
+ **Version 0.1.96**
+ + `research/plot`: add `PlotHelper` class to plot train history + plot grid of images (e.g., image samples from dataset or model outputs)
+
+
+ **Version 0.1.91**
+ + `research/param_gen`: add `ParamGen` class to generate parameter list from yaml file for hyperparameter search (grid search, random search, etc.)
+
+ **Version 0.1.90**
 
  + `research/profiler`: add `zProfiler` class to measure execution time of contexts and steps, with support for dynamic color scales in plots.
 
README.md
@@ -1,6 +1,23 @@
- Helper package for coding and automation
+ # Helper package for coding and automation
 
- **Version 0.1.89**
+ **Version 0.2.13**
+ + reorganize packages with most changes in `research` package; also rename `research` to `exp` (package for experiment management and utilities)
+ + update `exp/perfcalc.py` to allow save computed performance to csv file (without explicit calling method `calc_perfs`)
+
+ **Version 0.2.1**
+ + `research/base_exp`: add `eval_exp` method to evaluate experiment (e.g., model evaluation on test set) after experiment running is done.
+
+ **Version 0.1.99**
+ + `filetype/ipynb`: add `gen_ipynb_name` generator to create file name based on current notebook name as prefix (with optional timestamp)
+
+ **Version 0.1.96**
+ + `research/plot`: add `PlotHelper` class to plot train history + plot grid of images (e.g., image samples from dataset or model outputs)
+
+
+ **Version 0.1.91**
+ + `research/param_gen`: add `ParamGen` class to generate parameter list from yaml file for hyperparameter search (grid search, random search, etc.)
+
+ **Version 0.1.90**
 
  + `research/profiler`: add `zProfiler` class to measure execution time of contexts and steps, with support for dynamic color scales in plots.
 
halib/__init__.py
@@ -56,8 +56,7 @@ from .filetype.yamlfile import load_yaml
  from .system import cmd
  from .system import filesys as fs
  from .filetype import csvfile
- from .cuda import tcuda
- from .common import (
+ from .common.common import (
      console,
      console_log,
      ConsoleLog,
@@ -65,6 +64,7 @@ from .common import (
      norm_str,
      pprint_box,
      pprint_local_path,
+     tcuda
  )
 
  # for log
@@ -76,7 +76,7 @@ from timebudget import timebudget
  import omegaconf
  from omegaconf import OmegaConf
  from omegaconf.dictconfig import DictConfig
- from .rich_color import rcolor_str, rcolor_palette, rcolor_palette_all, rcolor_all_str
+ from .common.rich_color import rcolor_str, rcolor_palette, rcolor_palette_all, rcolor_all_str
 
  # for visualization
  import seaborn as sns
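The hunks above move `tcuda` out of the deleted `halib/cuda.py` and re-export it from `halib.common.common`, so existing top-level imports keep working. A minimal sketch, assuming halib 0.2.13 is installed; the message string is illustrative only:

from halib import tcuda, ConsoleLog

with ConsoleLog("check gpu"):  # prints <check_gpu> ... </check_gpu> console rules
    tcuda()                    # reports GPUs visible to torch / tensorflow, if installed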
halib/common/common.py (new file)
@@ -0,0 +1,178 @@
+ import os
+ import re
+ import arrow
+ import importlib
+
+ import rich
+ from rich import print
+ from rich.panel import Panel
+ from rich.console import Console
+ from rich.pretty import pprint, Pretty
+
+ from pathlib import Path, PureWindowsPath
+
+
+ console = Console()
+
+
+ def seed_everything(seed=42):
+     import random
+     import numpy as np
+
+     random.seed(seed)
+     np.random.seed(seed)
+     # import torch if it is available
+     try:
+         import torch
+
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
+     except ImportError:
+         pprint("torch not imported, skipping torch seed_everything")
+         pass
+
+
+ def now_str(sep_date_time="."):
+     assert sep_date_time in [
+         ".",
+         "_",
+         "-",
+     ], "sep_date_time must be one of '.', '_', or '-'"
+     now_string = arrow.now().format(f"YYYYMMDD{sep_date_time}HHmmss")
+     return now_string
+
+
+ def norm_str(in_str):
+     # Replace one or more whitespace characters with a single underscore
+     norm_string = re.sub(r"\s+", "_", in_str)
+     # Remove leading and trailing spaces
+     norm_string = norm_string.strip()
+     return norm_string
+
+
+ def pprint_box(obj, title="", border_style="green"):
+     """
+     Pretty print an object in a box.
+     """
+     rich.print(
+         Panel(Pretty(obj, expand_all=True), title=title, border_style=border_style)
+     )
+
+
+ def console_rule(msg, do_norm_msg=True, is_end_tag=False):
+     msg = norm_str(msg) if do_norm_msg else msg
+     if is_end_tag:
+         console.rule(f"</{msg}>")
+     else:
+         console.rule(f"<{msg}>")
+
+
+ def console_log(func):
+     def wrapper(*args, **kwargs):
+         console_rule(func.__name__)
+         result = func(*args, **kwargs)
+         console_rule(func.__name__, is_end_tag=True)
+         return result
+
+     return wrapper
+
+
+ class ConsoleLog:
+     def __init__(self, message):
+         self.message = message
+
+     def __enter__(self):
+         console_rule(self.message)
+         return self
+
+     def __exit__(self, exc_type, exc_value, traceback):
+         console_rule(self.message, is_end_tag=True)
+         if exc_type is not None:
+             print(f"An exception of type {exc_type} occurred.")
+             print(f"Exception message: {exc_value}")
+
+
+ def linux_to_wins_path(path: str) -> str:
+     """
+     Convert a Linux-style WSL path (/mnt/c/... or /mnt/d/...) to a Windows-style path (C:\...).
+     """
+     # Handle only /mnt/<drive>/... style
+     if (
+         path.startswith("/mnt/")
+         and len(path) > 6
+         and path[5].isalpha()
+         and path[6] == "/"
+     ):
+         drive = path[5].upper()  # Extract drive letter
+         win_path = f"{drive}:{path[6:]}"  # Replace "/mnt/c/" with "C:/"
+     else:
+         win_path = path  # Return unchanged if not a WSL-style path
+     # Normalize to Windows-style backslashes
+     return str(PureWindowsPath(win_path))
+
+
+ def pprint_local_path(
+     local_path: str, get_wins_path: bool = False, tag: str = ""
+ ) -> str:
+     """
+     Pretty-print a local path with emoji and clickable file:// URI.
+
+     Args:
+         local_path: Path to file or directory (Linux or Windows style).
+         get_wins_path: If True on Linux, convert WSL-style path to Windows style before printing.
+         tag: Optional console log tag.
+
+     Returns:
+         The file URI string.
+     """
+     p = Path(local_path).resolve()
+     type_str = "📄" if p.is_file() else "📁" if p.is_dir() else "❓"
+
+     if get_wins_path and os.name == "posix":
+         # Try WSL → Windows conversion
+         converted = linux_to_wins_path(str(p))
+         if converted != str(p):  # Conversion happened
+             file_uri = str(PureWindowsPath(converted).as_uri())
+         else:
+             file_uri = p.as_uri()
+     else:
+         file_uri = p.as_uri()
+
+     content_str = f"{type_str} [link={file_uri}]{file_uri}[/link]"
+
+     if tag:
+         with ConsoleLog(tag):
+             console.print(content_str)
+     else:
+         console.print(content_str)
+
+     return file_uri
+
+
+ def tcuda():
+     NOT_INSTALLED = "Not Installed"
+     GPU_AVAILABLE = "GPU(s) Available"
+     ls_lib = ["torch", "tensorflow"]
+     lib_stats = {lib: NOT_INSTALLED for lib in ls_lib}
+     for lib in ls_lib:
+         spec = importlib.util.find_spec(lib)
+         if spec:
+             if lib == "torch":
+                 import torch
+
+                 lib_stats[lib] = str(torch.cuda.device_count()) + " " + GPU_AVAILABLE
+             elif lib == "tensorflow":
+                 import tensorflow as tf
+
+                 lib_stats[lib] = (
+                     str(len(tf.config.list_physical_devices("GPU")))
+                     + " "
+                     + GPU_AVAILABLE
+                 )
+     console.rule("<CUDA Library Stats>")
+     pprint(lib_stats)
+     console.rule("</CUDA Library Stats>")
+     return lib_stats
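The new `halib/common/common.py` above is self-contained console/utility code (it also absorbs the old `halib/cuda.py` as `tcuda`). A short usage sketch of the helpers it defines follows; the values and the run-tag pattern are illustrative only, not part of halib:

from halib.common.common import (
    console_log,
    norm_str,
    now_str,
    pprint_box,
    seed_everything,
)

seed_everything(42)  # seeds random / numpy (and torch, when it is installed)

# build a normalized, timestamped run tag, e.g. baseline_run_20240131.094502
run_tag = f"{norm_str('baseline run')}_{now_str()}"
pprint_box({"run": run_tag, "lr": 1e-3}, title="config")

@console_log  # wraps the call in <train> ... </train> console rules
def train():
    return 0.123

train()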
halib/exp/core/base_config.py (new file)
@@ -0,0 +1,167 @@
+ import os
+ from rich.pretty import pprint
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, TypeVar, Generic
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from dataclass_wizard import YAMLWizard
+
+
+ class NamedCfg(ABC):
+     """
+     Base class for named configurations.
+     All configurations should have a name.
+     """
+
+     @abstractmethod
+     def get_name(self):
+         """
+         Get the name of the configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+
+ @dataclass
+ class AutoNamedCfg(YAMLWizard, NamedCfg):
+     """
+     Mixin that automatically implements get_name() by returning self.name.
+     Classes using this MUST have a 'name' field.
+     """
+
+     name: Optional[str] = None
+
+     def get_name(self):
+         return self.name
+
+     def __post_init__(self):
+         # Enforce the "MUST" rule here
+         if self.name is None:
+             # We allow None during initial load, but it must be set before usage
+             # or handled by the loader.
+             pass
+
+
+ T = TypeVar("T", bound=AutoNamedCfg)
+
+
+ class BaseSelectorCfg(Generic[T]):
+     """
+     Base class to handle the logic of selecting an item from a list by name.
+     """
+
+     def _resolve_selection(self, items: List[T], selected_name: str, context: str) -> T:
+         if selected_name is None:
+             raise ValueError(f"No {context} selected in the configuration.")
+
+         # Create a lookup dict for O(1) access, or just iterate if list is short
+         for item in items:
+             if item.name == selected_name:
+                 return item
+
+         raise ValueError(
+             f"{context.capitalize()} '{selected_name}' not found in the configuration list."
+         )
+
+
+ class ExpBaseCfg(ABC, YAMLWizard):
+     """
+     Base class for configuration objects.
+     What a cfg class must have:
+     1 - a dataset cfg
+     2 - a metric cfg
+     3 - a method cfg
+     """
+
+     cfg_name: Optional[str] = None
+
+     # Save to yaml fil
+     def save_to_outdir(
+         self, filename: str = "__config.yaml", outdir=None, override: bool = False
+     ) -> None:
+         """
+         Save the configuration to the output directory.
+         """
+         if outdir is not None:
+             output_dir = outdir
+         else:
+             output_dir = self.get_outdir()
+         os.makedirs(output_dir, exist_ok=True)
+         assert (output_dir is not None) and (
+             os.path.isdir(output_dir)
+         ), f"Output directory '{output_dir}' does not exist or is not a directory."
+         file_path = os.path.join(output_dir, filename)
+         if os.path.exists(file_path) and not override:
+             pprint(
+                 f"File '{file_path}' already exists. Use 'override=True' to overwrite."
+             )
+         else:
+             # method of YAMLWizard to_yaml_file
+             self.to_yaml_file(file_path)
+
+     @classmethod
+     @abstractmethod
+     # load from a custom YAML file
+     def from_custom_yaml_file(cls, yaml_file: str):
+         """Load a configuration from a custom YAML file."""
+         pass
+
+     def get_cfg_name(self, sep: str = "__", *args, **kwargs) -> str:
+         # auto get the config name from dataset, method, metric
+         # 2. Generate the canonical Config Name
+         name_parts = []
+         general_info = self.get_general_cfg().get_name()
+         dataset_info = self.get_dataset_cfg().get_name()
+         method_info = self.get_method_cfg().get_name()
+         name_parts = [
+             general_info,
+             f"ds_{dataset_info}",
+             f"mt_{method_info}",
+         ]
+         if "extra" in kwargs:
+             extra_info = kwargs["extra"]
+             assert isinstance(extra_info, str), "'extra' kwarg must be a string."
+             name_parts.append(extra_info)
+         self.cfg_name = sep.join(name_parts)
+         return self.cfg_name
+
+     @abstractmethod
+     def get_outdir(self):
+         """
+         Get the output directory for the configuration.
+         This method should be implemented in subclasses.
+         """
+         return None
+
+     @abstractmethod
+     def get_general_cfg(self) -> NamedCfg:
+         """
+         Get the general configuration like output directory, log settings, SEED, etc.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def get_dataset_cfg(self) -> NamedCfg:
+         """
+         Get the dataset configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def get_method_cfg(self) -> NamedCfg:
+         """
+         Get the method configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def get_metric_cfg(self) -> NamedCfg:
+         """
+         Get the metric configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
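The new `halib/exp/core/base_config.py` above only fixes the configuration interface (general/dataset/method/metric sub-configs plus an output directory); concrete configs are left to the user. A minimal sketch follows, assuming halib 0.2.13 and dataclass-wizard are installed; every `Demo*` class, field name, and path below is hypothetical and only serves to show the abstract methods being implemented:

from dataclasses import dataclass, field

from halib.exp.core.base_config import AutoNamedCfg, ExpBaseCfg


@dataclass
class DemoGeneralCfg(AutoNamedCfg):
    seed: int = 42
    outdir: str = "zout/demo"


@dataclass
class DemoDatasetCfg(AutoNamedCfg):
    root: str = "data/cifar10"


@dataclass
class DemoMethodCfg(AutoNamedCfg):
    lr: float = 1e-3


@dataclass
class DemoMetricCfg(AutoNamedCfg):
    metric_name: str = "accuracy"


@dataclass
class DemoCfg(ExpBaseCfg):
    general: DemoGeneralCfg = field(default_factory=lambda: DemoGeneralCfg(name="base"))
    dataset: DemoDatasetCfg = field(default_factory=lambda: DemoDatasetCfg(name="cifar10"))
    method: DemoMethodCfg = field(default_factory=lambda: DemoMethodCfg(name="resnet18"))
    metric: DemoMetricCfg = field(default_factory=lambda: DemoMetricCfg(name="clf_metrics"))

    @classmethod
    def from_custom_yaml_file(cls, yaml_file: str):
        # YAMLWizard also provides from_yaml_file(), the counterpart of the
        # to_yaml_file() used by save_to_outdir() above.
        return cls.from_yaml_file(yaml_file)

    def get_outdir(self):
        return self.general.outdir

    def get_general_cfg(self):
        return self.general

    def get_dataset_cfg(self):
        return self.dataset

    def get_method_cfg(self):
        return self.method

    def get_metric_cfg(self):
        return self.metric


if __name__ == "__main__":
    cfg = DemoCfg()
    print(cfg.get_cfg_name())  # base__ds_cifar10__mt_resnet18
    cfg.save_to_outdir()       # writes zout/demo/__config.yaml (skipped if it exists and override=False)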
halib/exp/core/base_exp.py (new file)
@@ -0,0 +1,147 @@
+ from abc import ABC, abstractmethod
+ from typing import Tuple, Any, Optional
+ from .base_config import ExpBaseCfg
+ from ..perf.perfcalc import PerfCalc
+ from ..perf.perfmetrics import MetricsBackend
+
+
+ class ExpHook:
+     """Base interface for all experiment hooks."""
+     def on_before_run(self, exp): pass
+     def on_after_run(self, exp, results): pass
+
+
+ # ! SEE https://github.com/hahv/base_exp for sample usage
+ class BaseExp(PerfCalc, ABC):
+     """
+     Base class for experiments.
+     Orchestrates the experiment pipeline using a pluggable metrics backend.
+     """
+
+     def __init__(self, config: ExpBaseCfg):
+         self.config = config
+         self.metric_backend = None
+         # Flag to track if init_general/prepare_dataset has run
+         self._is_env_ready = False
+         self.hooks = []
+
+     def register_hook(self, hook: ExpHook):
+         self.hooks.append(hook)
+
+     def _trigger_hooks(self, method_name: str, *args, **kwargs):
+         for hook in self.hooks:
+             method = getattr(hook, method_name, None)
+             if callable(method):
+                 method(*args, **kwargs)
+
+     # -----------------------
+     # PerfCalc Required Methods
+     # -----------------------
+     def get_dataset_name(self):
+         return self.config.get_dataset_cfg().get_name()
+
+     def get_experiment_name(self):
+         return self.config.get_cfg_name()
+
+     def get_metric_backend(self):
+         if not self.metric_backend:
+             self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
+         return self.metric_backend
+
+     # -----------------------
+     # Abstract Experiment Steps
+     # -----------------------
+     @abstractmethod
+     def init_general(self, general_cfg):
+         """Setup general settings like SEED, logging, env variables."""
+         pass
+
+     @abstractmethod
+     def prepare_dataset(self, dataset_cfg):
+         """Load/prepare dataset."""
+         pass
+
+     @abstractmethod
+     def prepare_metrics(self, metric_cfg) -> MetricsBackend:
+         """
+         Prepare the metrics for the experiment.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def exec_exp(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
+         """Run experiment process, e.g.: training/evaluation loop.
+         Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
+         """
+         pass
+
+     # -----------------------
+     # Internal Helpers
+     # -----------------------
+     def _validate_and_unpack(self, results):
+         if results is None:
+             return None
+         if not isinstance(results, (tuple, list)) or len(results) != 2:
+             raise ValueError("exec must return (metrics_data, extra_data)")
+         return results[0], results[1]
+
+     def _prepare_environment(self, force_reload: bool = False):
+         """
+         Common setup. Skips if already initialized, unless force_reload is True.
+         """
+         if self._is_env_ready and not force_reload:
+             # Environment is already prepared, skipping setup.
+             return
+
+         # 1. Run Setup
+         self.init_general(self.config.get_general_cfg())
+         self.prepare_dataset(self.config.get_dataset_cfg())
+
+         # 2. Update metric backend (refresh if needed)
+         self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
+
+         # 3. Mark as ready
+         self._is_env_ready = True
+
+     # -----------------------
+     # Main Experiment Runner
+     # -----------------------
+     def run_exp(self, should_calc_metrics=True, reload_env=False, *args, **kwargs):
+         """
+         Run the whole experiment pipeline.
+         :param reload_env: If True, forces dataset/general init to run again.
+         :param should_calc_metrics: Whether to calculate and save metrics after execution.
+         :kwargs Params:
+             + 'outfile' to save csv file results,
+             + 'outdir' to set output directory for experiment results.
+             + 'return_df' to return a DataFrame of results instead of a dictionary.
+
+         Full pipeline:
+             1. Init
+             2. Prepare Environment (General + Dataset + Metrics)
+             3. Save Config
+             4. Execute
+             5. Calculate & Save Metrics
+         """
+         self._prepare_environment(force_reload=reload_env)
+
+         self._trigger_hooks("before_run", self)
+
+         # Save config before running
+         self.config.save_to_outdir()
+
+         # Execute experiment
+         results = self.exec_exp(*args, **kwargs)
+
+         if should_calc_metrics and results is not None:
+             metrics_data, extra_data = self._validate_and_unpack(results)
+             # Calculate & Save metrics
+             perf_results = self.calc_perfs(
+                 raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
+             )
+             self._trigger_hooks("after_run", self, perf_results)
+             return perf_results
+         else:
+             self._trigger_hooks("after_run", self, results)
+             return results
+ return results