halib 0.1.99__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +12 -6
- halib/common/__init__.py +0 -0
- halib/common/common.py +207 -0
- halib/common/rich_color.py +285 -0
- halib/exp/__init__.py +0 -0
- halib/exp/core/__init__.py +0 -0
- halib/exp/core/base_config.py +167 -0
- halib/exp/core/base_exp.py +147 -0
- halib/exp/core/param_gen.py +189 -0
- halib/exp/core/wandb_op.py +117 -0
- halib/exp/data/__init__.py +0 -0
- halib/exp/data/dataclass_util.py +41 -0
- halib/exp/data/dataset.py +208 -0
- halib/exp/data/torchloader.py +165 -0
- halib/exp/perf/__init__.py +0 -0
- halib/exp/perf/flop_calc.py +190 -0
- halib/exp/perf/gpu_mon.py +58 -0
- halib/exp/perf/perfcalc.py +440 -0
- halib/exp/perf/perfmetrics.py +137 -0
- halib/exp/perf/perftb.py +778 -0
- halib/exp/perf/profiler.py +507 -0
- halib/exp/viz/__init__.py +0 -0
- halib/exp/viz/plot.py +754 -0
- halib/filetype/csvfile.py +3 -9
- halib/filetype/ipynb.py +3 -5
- halib/filetype/jsonfile.py +0 -3
- halib/filetype/textfile.py +0 -1
- halib/filetype/videofile.py +119 -3
- halib/filetype/yamlfile.py +8 -16
- halib/online/projectmake.py +7 -6
- halib/online/tele_noti.py +165 -0
- halib/research/base_exp.py +75 -18
- halib/research/core/__init__.py +0 -0
- halib/research/core/base_config.py +144 -0
- halib/research/core/base_exp.py +157 -0
- halib/research/core/param_gen.py +108 -0
- halib/research/core/wandb_op.py +117 -0
- halib/research/data/__init__.py +0 -0
- halib/research/data/dataclass_util.py +41 -0
- halib/research/data/dataset.py +208 -0
- halib/research/data/torchloader.py +165 -0
- halib/research/dataset.py +1 -1
- halib/research/metrics.py +4 -0
- halib/research/mics.py +8 -2
- halib/research/perf/__init__.py +0 -0
- halib/research/perf/flop_calc.py +190 -0
- halib/research/perf/gpu_mon.py +58 -0
- halib/research/perf/perfcalc.py +363 -0
- halib/research/perf/perfmetrics.py +137 -0
- halib/research/perf/perftb.py +778 -0
- halib/research/perf/profiler.py +301 -0
- halib/research/perfcalc.py +57 -32
- halib/research/viz/__init__.py +0 -0
- halib/research/viz/plot.py +754 -0
- halib/system/_list_pc.csv +6 -0
- halib/system/filesys.py +60 -20
- halib/system/path.py +106 -0
- halib/utils/dict.py +9 -0
- halib/utils/list.py +12 -0
- halib-0.2.21.dist-info/METADATA +192 -0
- halib-0.2.21.dist-info/RECORD +109 -0
- halib-0.1.99.dist-info/METADATA +0 -209
- halib-0.1.99.dist-info/RECORD +0 -64
- {halib-0.1.99.dist-info → halib-0.2.21.dist-info}/WHEEL +0 -0
- {halib-0.1.99.dist-info → halib-0.2.21.dist-info}/licenses/LICENSE.txt +0 -0
- {halib-0.1.99.dist-info → halib-0.2.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,167 @@ halib/exp/core/base_config.py (new file)

```python
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, TypeVar, Generic

from rich.pretty import pprint
from dataclass_wizard import YAMLWizard


class NamedCfg(ABC):
    """
    Base class for named configurations.
    All configurations should have a name.
    """

    @abstractmethod
    def get_name(self):
        """
        Get the name of the configuration.
        This method should be implemented in subclasses.
        """
        pass


@dataclass
class AutoNamedCfg(YAMLWizard, NamedCfg):
    """
    Mixin that automatically implements get_name() by returning self.name.
    Classes using this MUST have a 'name' field.
    """

    name: Optional[str] = None

    def get_name(self):
        return self.name

    def __post_init__(self):
        # Enforce the "MUST" rule here.
        if self.name is None:
            # We allow None during the initial load, but the name must be set
            # before usage or handled by the loader.
            pass


T = TypeVar("T", bound=AutoNamedCfg)


class BaseSelectorCfg(Generic[T]):
    """
    Base class to handle the logic of selecting an item from a list by name.
    """

    def _resolve_selection(self, items: List[T], selected_name: str, context: str) -> T:
        if selected_name is None:
            raise ValueError(f"No {context} selected in the configuration.")

        # A lookup dict would give O(1) access; plain iteration is fine for short lists.
        for item in items:
            if item.name == selected_name:
                return item

        raise ValueError(
            f"{context.capitalize()} '{selected_name}' not found in the configuration list."
        )


class ExpBaseCfg(ABC, YAMLWizard):
    """
    Base class for configuration objects.
    What a cfg class must have:
        1 - a dataset cfg
        2 - a metric cfg
        3 - a method cfg
    """

    cfg_name: Optional[str] = None

    # Save to a YAML file
    def save_to_outdir(
        self, filename: str = "__config.yaml", outdir=None, override: bool = False
    ) -> None:
        """
        Save the configuration to the output directory.
        """
        output_dir = outdir if outdir is not None else self.get_outdir()
        assert output_dir is not None, "Output directory is not set."
        os.makedirs(output_dir, exist_ok=True)
        assert os.path.isdir(
            output_dir
        ), f"Output directory '{output_dir}' does not exist or is not a directory."
        file_path = os.path.join(output_dir, filename)
        if os.path.exists(file_path) and not override:
            pprint(
                f"File '{file_path}' already exists. Use 'override=True' to overwrite."
            )
        else:
            # to_yaml_file is provided by YAMLWizard
            self.to_yaml_file(file_path)

    @classmethod
    @abstractmethod
    def from_custom_yaml_file(cls, yaml_file: str):
        """Load a configuration from a custom YAML file."""
        pass

    def get_cfg_name(self, sep: str = "__", *args, **kwargs) -> str:
        # Auto-generate the canonical config name from the general, dataset,
        # and method configs.
        general_info = self.get_general_cfg().get_name()
        dataset_info = self.get_dataset_cfg().get_name()
        method_info = self.get_method_cfg().get_name()
        name_parts = [
            general_info,
            f"ds_{dataset_info}",
            f"mt_{method_info}",
        ]
        if "extra" in kwargs:
            extra_info = kwargs["extra"]
            assert isinstance(extra_info, str), "'extra' kwarg must be a string."
            name_parts.append(extra_info)
        self.cfg_name = sep.join(name_parts)
        return self.cfg_name

    @abstractmethod
    def get_outdir(self):
        """
        Get the output directory for the configuration.
        This method should be implemented in subclasses.
        """
        return None

    @abstractmethod
    def get_general_cfg(self) -> NamedCfg:
        """
        Get the general configuration: output directory, log settings, SEED, etc.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def get_dataset_cfg(self) -> NamedCfg:
        """
        Get the dataset configuration.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def get_method_cfg(self) -> NamedCfg:
        """
        Get the method configuration.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def get_metric_cfg(self) -> NamedCfg:
        """
        Get the metric configuration.
        This method should be implemented in subclasses.
        """
        pass
```
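For orientation, here is a minimal sketch of how these bases are meant to be combined, assuming `dataclass_wizard` is installed. `DatasetCfg`, `MethodCfg`, and `MyCfg` are hypothetical names invented for this example, and `from_yaml_file` is the stock YAMLWizard loader:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical concrete configs; every class and field name here is illustrative.

@dataclass
class DatasetCfg(AutoNamedCfg):
    root: str = "./data"        # inherits 'name' and get_name() from AutoNamedCfg

@dataclass
class MethodCfg(AutoNamedCfg):
    lr: float = 1e-3

@dataclass
class MyCfg(ExpBaseCfg):
    general: Optional[AutoNamedCfg] = None
    dataset: Optional[DatasetCfg] = None
    method: Optional[MethodCfg] = None
    metric: Optional[AutoNamedCfg] = None
    outdir: str = "./zout/demo"

    @classmethod
    def from_custom_yaml_file(cls, yaml_file: str):
        return cls.from_yaml_file(yaml_file)   # stock YAMLWizard loader

    def get_outdir(self):
        return self.outdir

    def get_general_cfg(self):
        return self.general

    def get_dataset_cfg(self):
        return self.dataset

    def get_method_cfg(self):
        return self.method

    def get_metric_cfg(self):
        return self.metric


cfg = MyCfg(
    general=AutoNamedCfg(name="v1"),
    dataset=DatasetCfg(name="cifar10"),
    method=MethodCfg(name="resnet18"),
    metric=AutoNamedCfg(name="acc"),
)
print(cfg.get_cfg_name())   # v1__ds_cifar10__mt_resnet18
```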
@@ -0,0 +1,147 @@ halib/exp/core/base_exp.py (new file)

```python
from abc import ABC, abstractmethod
from typing import Tuple, Any, Optional
from .base_config import ExpBaseCfg
from ..perf.perfcalc import PerfCalc
from ..perf.perfmetrics import MetricsBackend


class ExpHook:
    """Base interface for all experiment hooks."""
    def on_before_run(self, exp): pass
    def on_after_run(self, exp, results): pass


# ! SEE https://github.com/hahv/base_exp for sample usage
class BaseExp(PerfCalc, ABC):
    """
    Base class for experiments.
    Orchestrates the experiment pipeline using a pluggable metrics backend.
    """

    def __init__(self, config: ExpBaseCfg):
        self.config = config
        self.metric_backend = None
        # Flag to track if init_general/prepare_dataset has run
        self._is_env_ready = False
        self.hooks = []

    def register_hook(self, hook: ExpHook):
        self.hooks.append(hook)

    def _trigger_hooks(self, method_name: str, *args, **kwargs):
        for hook in self.hooks:
            method = getattr(hook, method_name, None)
            if callable(method):
                method(*args, **kwargs)

    # -----------------------
    # PerfCalc Required Methods
    # -----------------------
    def get_dataset_name(self):
        return self.config.get_dataset_cfg().get_name()

    def get_experiment_name(self):
        return self.config.get_cfg_name()

    def get_metric_backend(self):
        if not self.metric_backend:
            self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
        return self.metric_backend

    # -----------------------
    # Abstract Experiment Steps
    # -----------------------
    @abstractmethod
    def init_general(self, general_cfg):
        """Setup general settings like SEED, logging, env variables."""
        pass

    @abstractmethod
    def prepare_dataset(self, dataset_cfg):
        """Load/prepare dataset."""
        pass

    @abstractmethod
    def prepare_metrics(self, metric_cfg) -> MetricsBackend:
        """
        Prepare the metrics for the experiment.
        This method should be implemented in subclasses.
        """
        pass

    @abstractmethod
    def exec_exp(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
        """Run experiment process, e.g.: training/evaluation loop.
        Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
        """
        pass

    # -----------------------
    # Internal Helpers
    # -----------------------
    def _validate_and_unpack(self, results):
        if results is None:
            return None
        if not isinstance(results, (tuple, list)) or len(results) != 2:
            raise ValueError("exec_exp must return (metrics_data, extra_data)")
        return results[0], results[1]

    def _prepare_environment(self, force_reload: bool = False):
        """
        Common setup. Skips if already initialized, unless force_reload is True.
        """
        if self._is_env_ready and not force_reload:
            # Environment is already prepared, skipping setup.
            return

        # 1. Run Setup
        self.init_general(self.config.get_general_cfg())
        self.prepare_dataset(self.config.get_dataset_cfg())

        # 2. Update metric backend (refresh if needed)
        self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())

        # 3. Mark as ready
        self._is_env_ready = True

    # -----------------------
    # Main Experiment Runner
    # -----------------------
    def run_exp(self, should_calc_metrics=True, reload_env=False, *args, **kwargs):
        """
        Run the whole experiment pipeline.
        :param reload_env: If True, forces dataset/general init to run again.
        :param should_calc_metrics: Whether to calculate and save metrics after execution.
        :kwargs Params:
            + 'outfile' to save csv file results,
            + 'outdir' to set output directory for experiment results.
            + 'return_df' to return a DataFrame of results instead of a dictionary.

        Full pipeline:
            1. Init
            2. Prepare Environment (General + Dataset + Metrics)
            3. Save Config
            4. Execute
            5. Calculate & Save Metrics
        """
        self._prepare_environment(force_reload=reload_env)

        # Hook methods are named on_before_run/on_after_run (see ExpHook above).
        self._trigger_hooks("on_before_run", self)

        # Save config before running
        self.config.save_to_outdir()

        # Execute experiment
        results = self.exec_exp(*args, **kwargs)

        if should_calc_metrics and results is not None:
            metrics_data, extra_data = self._validate_and_unpack(results)
            # Calculate & Save metrics
            perf_results = self.calc_perfs(
                raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
            )
            self._trigger_hooks("on_after_run", self, perf_results)
            return perf_results
        else:
            self._trigger_hooks("on_after_run", self, results)
            return results
```
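Continuing the hypothetical `MyCfg` sketch above, here is one way a concrete experiment and a hook could be wired into this pipeline, assuming `PerfCalc` imposes no extra constructor requirements. Metrics calculation is skipped so the stub `prepare_metrics` never has to produce a real `MetricsBackend`:

```python
import time

# Hypothetical hook that times the run via the on_before_run/on_after_run callbacks.
class TimerHook(ExpHook):
    def on_before_run(self, exp):
        self._t0 = time.perf_counter()

    def on_after_run(self, exp, results):
        print(f"run_exp finished in {time.perf_counter() - self._t0:.2f}s")


class MyExp(BaseExp):
    def init_general(self, general_cfg):
        pass                                # e.g. set SEED, configure logging

    def prepare_dataset(self, dataset_cfg):
        self.data = list(range(10))         # stand-in for real dataset loading

    def prepare_metrics(self, metric_cfg):
        return None                         # would return a project MetricsBackend

    def exec_exp(self, *args, **kwargs):
        preds, labels = [0, 1, 1], [0, 1, 0]             # stand-in eval loop
        return {"preds": preds, "labels": labels}, None  # (raw_metrics_data, extra_data)


exp = MyExp(cfg)                            # 'cfg' from the MyCfg sketch above
exp.register_hook(TimerHook())
raw = exp.run_exp(should_calc_metrics=False)  # skips calc_perfs; hooks still fire
```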
@@ -0,0 +1,189 @@ halib/exp/core/param_gen.py (new file)

```python
import os
import copy
import numpy as np
from itertools import product
from typing import Dict, Any, List, Iterator, Optional
from ...filetype import yamlfile


class ParamGen:
    """
    A flexible parameter grid generator for hyperparameter tuning and experiment management.

    This class generates a Cartesian product of parameters from a "sweep configuration"
    and optionally merges them into a "base configuration". It abstracts away the complexity
    of handling nested dictionaries and range generation.

    Key Features:
    -------------
    1. **Flexible Syntax**: Define parameters using standard nested dictionaries or
       dot-notation keys (e.g., `'model.backbone.layers'`).
    2. **Range Shortcuts**:
       - **Choices**: Standard lists `[1, 2, 3]`.
       - **String Ranges**: `"start:stop:step"` (e.g., `"0:10:2"` -> `[0, 2, 4, 6, 8]`).
       - **Dict Ranges**: `{'start': 0, 'stop': 1, 'step': 0.1}`.
    3. **Deep Merging**: Automatically updates deep keys in `base_cfg` without overwriting siblings.

    Example:
    --------
    >>> base = {'model': {'name': 'resnet', 'dropout': 0.1}, 'seed': 42}
    >>> sweep = {
    ...     'model.name': ['resnet', 'vit'],   # Dot notation
    ...     'model.dropout': "0.1:0.3:0.1",    # Range string
    ...     'seed': [42, 100]                  # Simple choice
    ... }
    >>> grid = ParamGen(sweep, base)
    >>> configs = grid.expand()
    >>> print(len(configs))  # Outputs: 8 (2 models * 2 dropouts * 2 seeds)

    Attributes:
        keys (List[str]): List of flattened dot-notation keys being swept.
        values (List[List[Any]]): List of value options for each key.
    """

    def __init__(
        self, sweep_cfg: Dict[str, Any], base_cfg: Optional[Dict[str, Any]] = None
    ):
        """
        Args:
            sweep_cfg: The dictionary defining parameters to sweep.
            base_cfg: (Optional) The base config to merge sweep parameters into.
                      If None, expand() behaves like expand_sweep_flat().
        """
        self.base_cfg = base_cfg if base_cfg is not None else {}

        # Recursively flatten the nested sweep config into dot-notation keys
        self.param_space = self._flatten_params(sweep_cfg)
        self.keys = list(self.param_space.keys())
        self.values = list(self.param_space.values())

    def get_param_space(self) -> Dict[str, List[Any]]:
        """Returns the parameter space as a dictionary of dot-notation keys to value lists."""
        return self.param_space

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yields fully merged configurations one by one."""
        for combination in product(*self.values):
            # 1. Create the flat sweep dict (dot notation)
            flat_params = dict(zip(self.keys, combination))

            # 2. Deep copy base and update with current params
            new_cfg = copy.deepcopy(self.base_cfg)
            new_cfg = self._apply_updates(new_cfg, flat_params)

            # 3. Store metadata (Optional)
            # if "_meta" not in new_cfg:
            #     new_cfg["_meta"] = {}
            # We unflatten the sweep params here so the log is readable
            # new_cfg["_meta"]["sweep_params"] = self._unflatten(flat_params)

            yield new_cfg

    # ! --- Factory Methods ---
    @classmethod
    def from_dicts(
        cls, sweep_cfg: Dict[str, Any], base_cfg: Optional[Dict[str, Any]] = None
    ):
        """
        Load from dictionaries.
        Args:
            sweep_cfg: The dictionary defining parameters to sweep.
            base_cfg: (Optional) The base config to merge sweep parameters into.
        """
        return cls(sweep_cfg, base_cfg)

    @classmethod
    def from_files(cls, sweep_yaml: str, base_yaml: Optional[str] = None):
        """
        Load from files.
        Args:
            sweep_yaml: Path to sweep config.
            base_yaml: (Optional) Path to base config.
        """
        assert os.path.isfile(sweep_yaml), f"Sweep file not found: {sweep_yaml}"
        sweep_dict = yamlfile.load_yaml(sweep_yaml, to_dict=True)
        base_dict = None
        if base_yaml:
            base_dict = yamlfile.load_yaml(base_yaml, to_dict=True)
            if "__base__" in base_dict:
                del base_dict["__base__"]

        return cls(sweep_dict, base_dict)

    def expand(self) -> List[Dict[str, Any]]:
        """Generates and returns the full list of MERGED configurations."""
        return list(self)

    def expand_sweep_flat(self) -> List[Dict[str, Any]]:
        """
        Returns a list of ONLY the sweep parameters, formatted as FLAT dot-notation dictionaries.

        Returns:
            [{'exp_params.model': 'resnet', 'exp_params.lr': 0.01}, ...]
        """
        combinations = []
        for combination in product(*self.values):
            flat_dict = dict(zip(self.keys, combination))
            combinations.append(flat_dict)
        return combinations

    def _unflatten(self, flat_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Converts {'a.b': 1} back to {'a': {'b': 1}}."""
        nested = {}
        self._apply_updates(nested, flat_dict)
        return nested

    def _flatten_params(
        self, cfg: Dict[str, Any], parent_key: str = ""
    ) -> Dict[str, List[Any]]:
        """Recursively converts nested dicts into flat dot-notation keys."""
        flat = {}
        for key, val in cfg.items():
            current_key = f"{parent_key}.{key}" if parent_key else key

            if self._is_sweep_leaf(val):
                flat[current_key] = self._expand_val(val)
            elif isinstance(val, dict):
                flat.update(self._flatten_params(val, current_key))
            else:
                flat[current_key] = [val]
        return flat

    def _is_sweep_leaf(self, val: Any) -> bool:
        if isinstance(val, list):
            return True
        if isinstance(val, str) and ":" in val:
            return True
        if isinstance(val, dict) and "start" in val and "stop" in val:
            return True
        return False

    def _expand_val(self, val: Any) -> List[Any]:
        if isinstance(val, list):
            return val

        if isinstance(val, str) and ":" in val:
            try:
                parts = [float(x) for x in val.split(":")]
                if len(parts) == 3:
                    arr = np.arange(parts[0], parts[1], parts[2])
                    return [float(f"{x:.6g}") for x in arr]
            except ValueError:
                pass

        if isinstance(val, dict) and "start" in val:
            step = val.get("step", 1)
            return np.arange(val["start"], val["stop"], step).tolist()

        return [val]

    def _apply_updates(
        self, cfg: Dict[str, Any], updates: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Deep merges dot-notation updates into cfg."""
        for key, val in updates.items():
            parts = key.split(".")
            target = cfg
            for part in parts[:-1]:
                if part not in target:
                    target[part] = {}
                target = target[part]
            target[parts[-1]] = val
        return cfg
```
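To make the range shortcuts and deep merging concrete, here is a small self-contained sketch using plain dictionaries (so it avoids the `yamlfile` dependency); all config keys are invented for illustration:

```python
base = {"train": {"lr": 0.1, "epochs": 10}, "model": "resnet"}
sweep = {
    "train.epochs": {"start": 10, "stop": 40, "step": 10},  # dict range -> [10, 20, 30]
    "model": ["resnet", "vit"],                             # plain choices
}

gen = ParamGen(sweep, base)
print(len(gen.expand()))            # 6 (3 epoch values * 2 models)

# Flat view of only the swept values, keyed by dot notation:
print(gen.expand_sweep_flat()[0])   # {'train.epochs': 10, 'model': 'resnet'}

# Merged view: 'train.epochs' is updated in place, sibling 'train.lr' is preserved.
first = next(iter(gen))
print(first)                        # {'train': {'lr': 0.1, 'epochs': 10}, 'model': 'resnet'}
```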
@@ -0,0 +1,117 @@ wandb_op.py (new file; added as halib/exp/core/wandb_op.py and halib/research/core/wandb_op.py, each +117)

```python
import os
import glob
import wandb
import argparse
import subprocess

from tqdm import tqdm
from rich.console import Console

console = Console()


def sync_runs(outdir):
    outdir = os.path.abspath(outdir)
    assert os.path.exists(outdir), f"Output directory {outdir} does not exist."
    sub_dirs = [name for name in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, name))]
    assert len(sub_dirs) > 0, f"No subdirectories found in {outdir}."
    console.rule("Parent Directory")
    console.print(f"[yellow]{outdir}[/yellow]")

    exp_dirs = [os.path.join(outdir, sub_dir) for sub_dir in sub_dirs]
    wandb_dirs = []
    for exp_dir in exp_dirs:
        wandb_dirs.extend(glob.glob(f"{exp_dir}/wandb/*run-*"))
    if len(wandb_dirs) == 0:
        console.print(f"No wandb runs found in {outdir}.")
        return
    else:
        console.print(f"Found [bold]{len(wandb_dirs)}[/bold] wandb runs in {outdir}.")
    for i, wandb_dir in enumerate(wandb_dirs):
        console.rule(f"Syncing wandb run {i + 1}/{len(wandb_dirs)}")
        console.print(f"Syncing: {wandb_dir}")
        process = subprocess.Popen(
            ["wandb", "sync", wandb_dir],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )

        for line in process.stdout:
            console.print(line.strip())
            if " ERROR Error while calling W&B API" in line:
                break
        process.stdout.close()
        process.wait()
        if process.returncode != 0:
            console.print(f"[red]Error syncing {wandb_dir}. Return code: {process.returncode}[/red]")
        else:
            console.print(f"Successfully synced {wandb_dir}.")


def delete_runs(project, pattern=None):
    console.rule("Delete W&B Runs")
    confirm_msg = "Are you sure you want to delete all runs in"
    confirm_msg += f" \n\tproject: [red]{project}[/red]"
    if pattern:
        confirm_msg += f"\n\tpattern: [blue]{pattern}[/blue]"

    console.print(confirm_msg)
    confirmation = input("This action cannot be undone. [y/N]: ").strip().lower()
    if confirmation != "y":
        print("Cancelled.")
        return

    print("Confirmed. Proceeding...")
    api = wandb.Api()
    runs = api.runs(project)

    deleted = 0
    console.rule("Deleting W&B Runs")
    if len(runs) == 0:
        print("No runs found in the project.")
        return
    for run in tqdm(runs):
        if pattern is None or pattern in run.name:
            run.delete()
            console.print(f"Deleted run: [red]{run.name}[/red]")
            deleted += 1

    console.print(f"Total runs deleted: {deleted}")


def valid_argument(args):
    if args.op == "sync":
        assert os.path.exists(args.outdir), f"Output directory {args.outdir} does not exist."
    elif args.op == "delete":
        assert isinstance(args.project, str) and len(args.project.strip()) > 0, "Project name must be a non-empty string."
    else:
        raise ValueError(f"Unknown operation: {args.op}")


def parse_args():
    parser = argparse.ArgumentParser(description="Operations on W&B runs")
    parser.add_argument("-op", "--op", type=str, help="Operation to perform", default="sync", choices=["delete", "sync"])
    parser.add_argument("-prj", "--project", type=str, default="fire-paper2-2025", help="W&B project name")
    parser.add_argument("-outdir", "--outdir", type=str, help="Parent directory containing experiment run subdirectories", default="./zout/train")
    parser.add_argument("-pt", "--pattern",
        type=str,
        default=None,
        help="Run name pattern to match for deletion",
    )

    return parser.parse_args()


def main():
    args = parse_args()
    # Validate arguments, stop if invalid
    valid_argument(args)

    op = args.op
    if op == "sync":
        sync_runs(args.outdir)
    elif op == "delete":
        delete_runs(args.project, args.pattern)
    else:
        raise ValueError(f"Unknown operation: {op}")


if __name__ == "__main__":
    main()
```
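Both operations can be driven through the argparse CLI above or called directly with the functions in scope; a short sketch, where the directory, project, and pattern are placeholders:

```python
# Sync every offline run found under <outdir>/<exp_dir>/wandb/*run-*
# (CLI equivalent: -op sync -outdir ./zout/train)
sync_runs("./zout/train")

# Delete runs whose names contain "debug" from a W&B project,
# asking for interactive confirmation first.
# (CLI equivalent: -op delete -prj my-entity/my-project -pt debug)
delete_runs("my-entity/my-project", pattern="debug")
```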
@@ -0,0 +1,41 @@ dataclass_util.py (new file; added as halib/exp/data/dataclass_util.py and halib/research/data/dataclass_util.py, each +41)

```python
import yaml
from typing import Any

from rich.pretty import pprint
from dataclasses import make_dataclass

from ...filetype import yamlfile


def dict_to_dataclass(name: str, data: dict):
    fields = []
    values = {}

    for key, value in data.items():
        if isinstance(value, dict):
            sub_dc = dict_to_dataclass(key.capitalize(), value)
            fields.append((key, type(sub_dc)))
            values[key] = sub_dc
        else:
            field_type = type(value) if value is not None else Any
            fields.append((key, field_type))
            values[key] = value

    DC = make_dataclass(name.capitalize(), fields)
    return DC(**values)


def yaml_to_dataclass(name: str, yaml_str: str):
    data = yaml.safe_load(yaml_str)
    return dict_to_dataclass(name, data)


def yamlfile_to_dataclass(name: str, file_path: str):
    data_dict = yamlfile.load_yaml(file_path, to_dict=True)
    if "__base__" in data_dict:
        del data_dict["__base__"]
    return dict_to_dataclass(name, data_dict)


if __name__ == "__main__":
    cfg = yamlfile_to_dataclass("Config", "test/dataclass_util_test_cfg.yaml")

    # ! NOTICE: after printing this dataclass, the output can be copied into ChatGPT
    # to generate the matching dataclass definitions with `from dataclass_wizard import YAMLWizard`
    pprint(cfg)
```