halib 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. halib/__init__.py +94 -0
  2. halib/common/__init__.py +0 -0
  3. halib/common/common.py +326 -0
  4. halib/common/rich_color.py +285 -0
  5. halib/common.py +151 -0
  6. halib/csvfile.py +48 -0
  7. halib/cuda.py +39 -0
  8. halib/dataset.py +209 -0
  9. halib/exp/__init__.py +0 -0
  10. halib/exp/core/__init__.py +0 -0
  11. halib/exp/core/base_config.py +167 -0
  12. halib/exp/core/base_exp.py +147 -0
  13. halib/exp/core/param_gen.py +170 -0
  14. halib/exp/core/wandb_op.py +117 -0
  15. halib/exp/data/__init__.py +0 -0
  16. halib/exp/data/dataclass_util.py +41 -0
  17. halib/exp/data/dataset.py +208 -0
  18. halib/exp/data/torchloader.py +165 -0
  19. halib/exp/perf/__init__.py +0 -0
  20. halib/exp/perf/flop_calc.py +190 -0
  21. halib/exp/perf/gpu_mon.py +58 -0
  22. halib/exp/perf/perfcalc.py +470 -0
  23. halib/exp/perf/perfmetrics.py +137 -0
  24. halib/exp/perf/perftb.py +778 -0
  25. halib/exp/perf/profiler.py +507 -0
  26. halib/exp/viz/__init__.py +0 -0
  27. halib/exp/viz/plot.py +754 -0
  28. halib/filesys.py +117 -0
  29. halib/filetype/__init__.py +0 -0
  30. halib/filetype/csvfile.py +192 -0
  31. halib/filetype/ipynb.py +61 -0
  32. halib/filetype/jsonfile.py +19 -0
  33. halib/filetype/textfile.py +12 -0
  34. halib/filetype/videofile.py +266 -0
  35. halib/filetype/yamlfile.py +87 -0
  36. halib/gdrive.py +179 -0
  37. halib/gdrive_mkdir.py +41 -0
  38. halib/gdrive_test.py +37 -0
  39. halib/jsonfile.py +22 -0
  40. halib/listop.py +13 -0
  41. halib/online/__init__.py +0 -0
  42. halib/online/gdrive.py +229 -0
  43. halib/online/gdrive_mkdir.py +53 -0
  44. halib/online/gdrive_test.py +50 -0
  45. halib/online/projectmake.py +131 -0
  46. halib/online/tele_noti.py +165 -0
  47. halib/plot.py +301 -0
  48. halib/projectmake.py +115 -0
  49. halib/research/__init__.py +0 -0
  50. halib/research/base_config.py +100 -0
  51. halib/research/base_exp.py +157 -0
  52. halib/research/benchquery.py +131 -0
  53. halib/research/core/__init__.py +0 -0
  54. halib/research/core/base_config.py +144 -0
  55. halib/research/core/base_exp.py +157 -0
  56. halib/research/core/param_gen.py +108 -0
  57. halib/research/core/wandb_op.py +117 -0
  58. halib/research/data/__init__.py +0 -0
  59. halib/research/data/dataclass_util.py +41 -0
  60. halib/research/data/dataset.py +208 -0
  61. halib/research/data/torchloader.py +165 -0
  62. halib/research/dataset.py +208 -0
  63. halib/research/flop_csv.py +34 -0
  64. halib/research/flops.py +156 -0
  65. halib/research/metrics.py +137 -0
  66. halib/research/mics.py +74 -0
  67. halib/research/params_gen.py +108 -0
  68. halib/research/perf/__init__.py +0 -0
  69. halib/research/perf/flop_calc.py +190 -0
  70. halib/research/perf/gpu_mon.py +58 -0
  71. halib/research/perf/perfcalc.py +363 -0
  72. halib/research/perf/perfmetrics.py +137 -0
  73. halib/research/perf/perftb.py +778 -0
  74. halib/research/perf/profiler.py +301 -0
  75. halib/research/perfcalc.py +361 -0
  76. halib/research/perftb.py +780 -0
  77. halib/research/plot.py +758 -0
  78. halib/research/profiler.py +300 -0
  79. halib/research/torchloader.py +162 -0
  80. halib/research/viz/__init__.py +0 -0
  81. halib/research/viz/plot.py +754 -0
  82. halib/research/wandb_op.py +116 -0
  83. halib/rich_color.py +285 -0
  84. halib/sys/__init__.py +0 -0
  85. halib/sys/cmd.py +8 -0
  86. halib/sys/filesys.py +124 -0
  87. halib/system/__init__.py +0 -0
  88. halib/system/_list_pc.csv +6 -0
  89. halib/system/cmd.py +8 -0
  90. halib/system/filesys.py +164 -0
  91. halib/system/path.py +106 -0
  92. halib/tele_noti.py +166 -0
  93. halib/textfile.py +13 -0
  94. halib/torchloader.py +162 -0
  95. halib/utils/__init__.py +0 -0
  96. halib/utils/dataclass_util.py +40 -0
  97. halib/utils/dict.py +317 -0
  98. halib/utils/dict_op.py +9 -0
  99. halib/utils/gpu_mon.py +58 -0
  100. halib/utils/list.py +17 -0
  101. halib/utils/listop.py +13 -0
  102. halib/utils/slack.py +86 -0
  103. halib/utils/tele_noti.py +166 -0
  104. halib/utils/video.py +82 -0
  105. halib/videofile.py +139 -0
  106. halib-0.2.30.dist-info/METADATA +237 -0
  107. halib-0.2.30.dist-info/RECORD +110 -0
  108. halib-0.2.30.dist-info/WHEEL +5 -0
  109. halib-0.2.30.dist-info/licenses/LICENSE.txt +17 -0
  110. halib-0.2.30.dist-info/top_level.txt +1 -0
@@ -0,0 +1,147 @@
+ from abc import ABC, abstractmethod
+ from typing import Tuple, Any, Optional
+ from .base_config import ExpBaseCfg
+ from ..perf.perfcalc import PerfCalc
+ from ..perf.perfmetrics import MetricsBackend
+
+
+ class ExpHook:
+     """Base interface for all experiment hooks."""
+     def on_before_run(self, exp): pass
+     def on_after_run(self, exp, results): pass
+
+
+ # ! SEE https://github.com/hahv/base_exp for sample usage
+ class BaseExp(PerfCalc, ABC):
+     """
+     Base class for experiments.
+     Orchestrates the experiment pipeline using a pluggable metrics backend.
+     """
+
+     def __init__(self, config: ExpBaseCfg):
+         self.config = config
+         self.metric_backend = None
+         # Flag to track if init_general/prepare_dataset has run
+         self._is_env_ready = False
+         self.hooks = []
+
+     def register_hook(self, hook: ExpHook):
+         self.hooks.append(hook)
+
+     def _trigger_hooks(self, method_name: str, *args, **kwargs):
+         for hook in self.hooks:
+             method = getattr(hook, method_name, None)
+             if callable(method):
+                 method(*args, **kwargs)
+
+     # -----------------------
+     # PerfCalc Required Methods
+     # -----------------------
+     def get_dataset_name(self):
+         return self.config.get_dataset_cfg().get_name()
+
+     def get_experiment_name(self):
+         return self.config.get_cfg_name()
+
+     def get_metric_backend(self):
+         if not self.metric_backend:
+             self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
+         return self.metric_backend
+
+     # -----------------------
+     # Abstract Experiment Steps
+     # -----------------------
+     @abstractmethod
+     def init_general(self, general_cfg):
+         """Setup general settings like SEED, logging, env variables."""
+         pass
+
+     @abstractmethod
+     def prepare_dataset(self, dataset_cfg):
+         """Load/prepare dataset."""
+         pass
+
+     @abstractmethod
+     def prepare_metrics(self, metric_cfg) -> MetricsBackend:
+         """
+         Prepare the metrics for the experiment.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def exec_exp(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
+         """Run the experiment process, e.g. a training/evaluation loop.
+         Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
+         """
+         pass
+
+     # -----------------------
+     # Internal Helpers
+     # -----------------------
+     def _validate_and_unpack(self, results):
+         if results is None:
+             return None
+         if not isinstance(results, (tuple, list)) or len(results) != 2:
+             raise ValueError("exec_exp must return (metrics_data, extra_data)")
+         return results[0], results[1]
+
+     def _prepare_environment(self, force_reload: bool = False):
+         """
+         Common setup. Skips if already initialized, unless force_reload is True.
+         """
+         if self._is_env_ready and not force_reload:
+             # Environment is already prepared, skipping setup.
+             return
+
+         # 1. Run Setup
+         self.init_general(self.config.get_general_cfg())
+         self.prepare_dataset(self.config.get_dataset_cfg())
+
+         # 2. Update metric backend (refresh if needed)
+         self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
+
+         # 3. Mark as ready
+         self._is_env_ready = True
+
+     # -----------------------
+     # Main Experiment Runner
+     # -----------------------
+     def run_exp(self, should_calc_metrics=True, reload_env=False, *args, **kwargs):
+         """
+         Run the whole experiment pipeline.
+         :param reload_env: If True, forces dataset/general init to run again.
+         :param should_calc_metrics: Whether to calculate and save metrics after execution.
+         :kwargs Params:
+             + 'outfile' to save csv file results,
+             + 'outdir' to set output directory for experiment results.
+             + 'return_df' to return a DataFrame of results instead of a dictionary.
+
+         Full pipeline:
+             1. Init
+             2. Prepare Environment (General + Dataset + Metrics)
+             3. Save Config
+             4. Execute
+             5. Calculate & Save Metrics
+         """
+         self._prepare_environment(force_reload=reload_env)
+
+         self._trigger_hooks("on_before_run", self)
+
+         # Save config before running
+         self.config.save_to_outdir()
+
+         # Execute experiment
+         results = self.exec_exp(*args, **kwargs)
+
+         if should_calc_metrics and results is not None:
+             metrics_data, extra_data = self._validate_and_unpack(results)
+             # Calculate & Save metrics
+             perf_results = self.calc_perfs(
+                 raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
+             )
+             self._trigger_hooks("on_after_run", self, perf_results)
+             return perf_results
+         else:
+             self._trigger_hooks("on_after_run", self, results)
+             return results
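BaseExp is abstract: a concrete experiment implements the four steps above (init_general, prepare_dataset, prepare_metrics, exec_exp). A minimal, hypothetical sketch of such a subclass follows; MyExp, MyBackend, PrintHook, and the returned values are illustrative assumptions, not part of the package.

class MyBackend(MetricsBackend):  # hypothetical backend; the real MetricsBackend API lives in perfmetrics.py
    ...

class PrintHook(ExpHook):
    def on_after_run(self, exp, results):
        print("experiment finished:", results)

class MyExp(BaseExp):
    def init_general(self, general_cfg):
        pass  # e.g. set seeds, configure logging

    def prepare_dataset(self, dataset_cfg):
        self.data = ...  # load the dataset described by dataset_cfg

    def prepare_metrics(self, metric_cfg) -> MetricsBackend:
        return MyBackend()

    def exec_exp(self, *args, **kwargs):
        # run training/evaluation; return (raw_metrics_data, extra_data) or None
        return {"acc": 0.9}, {"epochs": 10}

exp = MyExp(config=my_cfg)      # my_cfg: an ExpBaseCfg instance (assumed to exist)
exp.register_hook(PrintHook())  # hooks receive on_before_run / on_after_run callbacks
perf = exp.run_exp(should_calc_metrics=True)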
@@ -0,0 +1,170 @@
+ import os
+ import copy
+ import numpy as np
+ from itertools import product
+ from typing import Dict, Any, List, Iterator, Optional
+ from ...filetype import yamlfile
+ from ...utils.dict import DictUtils
+
+ # Assuming DictUtils is available in the scope or imported
+ # from .dict_utils import DictUtils
+
+
+ class ParamGen:
+     """
+     A flexible parameter grid generator for hyperparameter tuning and experiment management.
+
+     This class generates a Cartesian product of parameters from a "sweep configuration"
+     and optionally merges them into a "base configuration". It abstracts away the complexity
+     of handling nested dictionaries and range generation.
+
+     Key Features:
+     -----------
+     1. **Flexible Syntax**: Define parameters using standard nested dictionaries or
+        dot-notation keys (e.g., `'model.backbone.layers'`).
+     2. **Range Shortcuts**:
+        - **Choices**: Standard lists `[1, 2, 3]`.
+        - **String Ranges**: `"start:stop:step"` (e.g., `"0:10:2"` -> `[0, 2, 4, 6, 8]`).
+        - **Dict Ranges**: `{'start': 0, 'stop': 1, 'step': 0.1}`.
+     3. **Deep Merging**: Automatically updates deep keys in `base_cfg` without overwriting siblings.
+
+     Example:
+     --------
+     >>> base = {'model': {'name': 'resnet', 'dropout': 0.1}, 'seed': 42}
+     >>> sweep = {
+     ...     'model.name': ['resnet', 'vit'],   # Dot notation
+     ...     'model.dropout': "0.1:0.3:0.1",    # Range string
+     ...     'seed': [42, 100]                  # Simple choice
+     ... }
+     >>> grid = ParamGen(sweep, base)
+     >>> configs = grid.expand()
+     >>> print(len(configs))  # Outputs: 8 (2 models * 2 dropouts * 2 seeds)
+     Attributes:
+         keys (List[str]): List of flattened dot-notation keys being swept.
+         values (List[List[Any]]): List of value options for each key.
+     """
+
+     def __init__(
+         self, sweep_cfg: Dict[str, Any], base_cfg: Optional[Dict[str, Any]] = None
+     ):
+         """
+         Args:
+             sweep_cfg: The dictionary defining parameters to sweep.
+             base_cfg: (Optional) The base config to merge sweep parameters into.
+                 If None, expand() behaves like expand_sweep().
+         """
+         self.base_cfg = base_cfg if base_cfg is not None else {}
+
+         # Recursively flatten the nested sweep config into dot-notation keys
+         # Refactored to use DictUtils, passing our custom leaf logic
+         flat_sweep = DictUtils.flatten(sweep_cfg, is_leaf_predicate=self._is_sweep_leaf)
+
+         # Expand values (ranges, strings) which DictUtils leaves as-is
+         self.param_space = {k: self._expand_val(v) for k, v in flat_sweep.items()}
+
+         self.keys = list(self.param_space.keys())
+         self.values = list(self.param_space.values())
+
+     def get_param_space(self) -> Dict[str, List[Any]]:
+         """Returns the parameter space as a dictionary of dot-notation keys to value lists."""
+         return self.param_space
+
+     def __iter__(self) -> Iterator[Dict[str, Any]]:
+         """Yields fully merged configurations one by one."""
+         for combination in product(*self.values):
+             # 1. Create the flat sweep dict (dot notation)
+             flat_params = dict(zip(self.keys, combination))
+
+             # 2. Deep copy base and update with current params
+             new_cfg = copy.deepcopy(self.base_cfg)
+
+             # Refactored: Unflatten the specific params, then deep merge
+             update_structure = DictUtils.unflatten(flat_params)
+             DictUtils.deep_update(new_cfg, update_structure)
+
+             # 3. Store metadata (Optional)
+             # if "_meta" not in new_cfg:
+             #     new_cfg["_meta"] = {}
+             # We unflatten the sweep params here so the log is readable
+             # new_cfg["_meta"]["sweep_params"] = DictUtils.unflatten(flat_params)
+
+             yield new_cfg
+
+     # ! --- Factory Methods ---
+     @classmethod
+     def from_dicts(
+         cls, sweep_cfg: Dict[str, Any], base_cfg: Optional[Dict[str, Any]] = None
+     ):
+         """
+         Load from dictionaries.
+         Args:
+             sweep_cfg: The dictionary defining parameters to sweep.
+             base_cfg: (Optional) The base config to merge sweep parameters into.
+         """
+         return cls(sweep_cfg, base_cfg)
+
+     @classmethod
+     def from_files(cls, sweep_yaml: str, base_yaml: Optional[str] = None):
+         """
+         Load from files.
+         Args:
+             sweep_yaml: Path to sweep config.
+             base_yaml: (Optional) Path to base config.
+         """
+         assert os.path.isfile(sweep_yaml), f"Sweep file not found: {sweep_yaml}"
+         sweep_dict = yamlfile.load_yaml(sweep_yaml, to_dict=True)
+         base_dict = None
+         if base_yaml:
+             base_dict = yamlfile.load_yaml(base_yaml, to_dict=True)
+             if "__base__" in base_dict:
+                 del base_dict["__base__"]
+
+         return cls(sweep_dict, base_dict)
+
+     def expand(self) -> List[Dict[str, Any]]:
+         """Generates and returns the full list of MERGED configurations."""
+         return list(self)
+
+     def expand_sweep_flat(self) -> List[Dict[str, Any]]:
+         """
+         Returns a list of ONLY the sweep parameters, formatted as FLAT dot-notation dictionaries.
+
+         Returns:
+             [{'exp_params.model': 'resnet', 'exp_params.lr': 0.01}, ...]
+         """
+         combinations = []
+         for combination in product(*self.values):
+             flat_dict = dict(zip(self.keys, combination))
+             combinations.append(flat_dict)
+         return combinations
+
+     # Note: _unflatten, _flatten_params, and _apply_updates have been removed
+     # as they are replaced by DictUtils methods.
+
+     def _is_sweep_leaf(self, val: Any) -> bool:
+         if isinstance(val, list):
+             return True
+         if isinstance(val, str) and ":" in val:
+             return True
+         if isinstance(val, dict) and "start" in val and "stop" in val:
+             return True
+         return False
+
+     def _expand_val(self, val: Any) -> List[Any]:
+         if isinstance(val, list):
+             return val
+
+         if isinstance(val, str) and ":" in val:
+             try:
+                 parts = [float(x) for x in val.split(":")]
+                 if len(parts) == 3:
+                     arr = np.arange(parts[0], parts[1], parts[2])
+                     return [float(f"{x:.6g}") for x in arr]
+             except ValueError:
+                 pass
+
+         if isinstance(val, dict) and "start" in val:
+             step = val.get("step", 1)
+             return np.arange(val["start"], val["stop"], step).tolist()
+
+         return [val]
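ParamGen can be driven from dictionaries (as in the docstring example) or from YAML files via from_files. A short usage sketch follows; the file names, keys, and run_experiment function are made up for illustration.

# sweep.yaml (hypothetical):
#   model.name: [resnet, vit]
#   train.lr: "0.001:0.01:0.002"
# base.yaml (hypothetical) holds the full default config.
grid = ParamGen.from_files("sweep.yaml", "base.yaml")
print(grid.get_param_space())    # {'model.name': [...], 'train.lr': [...]}
for cfg in grid:                 # lazily yields one merged config per combination
    run_experiment(cfg)          # run_experiment: user-defined (assumed)
configs = grid.expand()          # or materialize the full list at once
flat = grid.expand_sweep_flat()  # only the swept keys, as flat dot-notation dicts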
@@ -0,0 +1,117 @@
+ import os
+ import glob
+ import wandb
+ import argparse
+ import subprocess
+
+ from tqdm import tqdm
+ from rich.console import Console
+
+ console = Console()
+
+ def sync_runs(outdir):
+     outdir = os.path.abspath(outdir)
+     assert os.path.exists(outdir), f"Output directory {outdir} does not exist."
+     sub_dirs = [name for name in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, name))]
+     assert len(sub_dirs) > 0, f"No subdirectories found in {outdir}."
+     console.rule("Parent Directory")
+     console.print(f"[yellow]{outdir}[/yellow]")
+
+     exp_dirs = [os.path.join(outdir, sub_dir) for sub_dir in sub_dirs]
+     wandb_dirs = []
+     for exp_dir in exp_dirs:
+         wandb_dirs.extend(glob.glob(f"{exp_dir}/wandb/*run-*"))
+     if len(wandb_dirs) == 0:
+         console.print(f"No wandb runs found in {outdir}.")
+         return
+     else:
+         console.print(f"Found [bold]{len(wandb_dirs)}[/bold] wandb runs in {outdir}.")
+     for i, wandb_dir in enumerate(wandb_dirs):
+         console.rule(f"Syncing wandb run {i + 1}/{len(wandb_dirs)}")
+         console.print(f"Syncing: {wandb_dir}")
+         process = subprocess.Popen(
+             ["wandb", "sync", wandb_dir],
+             stdout=subprocess.PIPE,
+             stderr=subprocess.STDOUT,
+             text=True,
+         )
+
+         for line in process.stdout:
+             console.print(line.strip())
+             if " ERROR Error while calling W&B API" in line:
+                 break
+         process.stdout.close()
+         process.wait()
+         if process.returncode != 0:
+             console.print(f"[red]Error syncing {wandb_dir}. Return code: {process.returncode}[/red]")
+         else:
+             console.print(f"Successfully synced {wandb_dir}.")
+
+ def delete_runs(project, pattern=None):
+     console.rule("Delete W&B Runs")
+     confirm_msg = "Are you sure you want to delete all runs in"
+     confirm_msg += f" \n\tproject: [red]{project}[/red]"
+     if pattern:
+         confirm_msg += f"\n\tpattern: [blue]{pattern}[/blue]"
+
+     console.print(confirm_msg)
+     confirmation = input("This action cannot be undone. [y/N]: ").strip().lower()
+     if confirmation != "y":
+         print("Cancelled.")
+         return
+
+     print("Confirmed. Proceeding...")
+     api = wandb.Api()
+     runs = api.runs(project)
+
+     deleted = 0
+     console.rule("Deleting W&B Runs")
+     if len(runs) == 0:
+         print("No runs found in the project.")
+         return
+     for run in tqdm(runs):
+         if pattern is None or pattern in run.name:
+             run.delete()
+             console.print(f"Deleted run: [red]{run.name}[/red]")
+             deleted += 1
+
+     console.print(f"Total runs deleted: {deleted}")
+
+
+ def valid_argument(args):
+     if args.op == "sync":
+         assert os.path.exists(args.outdir), f"Output directory {args.outdir} does not exist."
+     elif args.op == "delete":
+         assert isinstance(args.project, str) and len(args.project.strip()) > 0, "Project name must be a non-empty string."
+     else:
+         raise ValueError(f"Unknown operation: {args.op}")
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Operations on W&B runs")
+     parser.add_argument("-op", "--op", type=str, help="Operation to perform", default="sync", choices=["delete", "sync"])
+     parser.add_argument("-prj", "--project", type=str, default="fire-paper2-2025", help="W&B project name")
+     parser.add_argument("-outdir", "--outdir", type=str, help="parent output directory containing per-experiment wandb folders", default="./zout/train")
+     parser.add_argument("-pt", "--pattern",
+         type=str,
+         default=None,
+         help="Run name pattern to match for deletion",
+     )
+
+     return parser.parse_args()
+
+
+ def main():
+     args = parse_args()
+     # Validate arguments, stop if invalid
+     valid_argument(args)
+
+     op = args.op
+     if op == "sync":
+         sync_runs(args.outdir)
+     elif op == "delete":
+         delete_runs(args.project, args.pattern)
+     else:
+         raise ValueError(f"Unknown operation: {op}")
+
+ if __name__ == "__main__":
+     main()
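The module can also be used programmatically; a hedged sketch (directory and project path are placeholders), equivalent to `python -m halib.exp.core.wandb_op -op sync -outdir ./zout/train`:

from halib.exp.core.wandb_op import sync_runs, delete_runs

sync_runs("./zout/train")                     # syncs every <exp>/wandb/*run-* folder found under this directory
delete_runs("my-entity/my-project", "debug")  # deletes runs whose name contains "debug", after an interactive confirmation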
File without changes
@@ -0,0 +1,41 @@
+ import yaml
+ from typing import Any
+
+ from rich.pretty import pprint
+ from dataclasses import make_dataclass
+
+ from ...filetype import yamlfile
+
+ def dict_to_dataclass(name: str, data: dict):
+     fields = []
+     values = {}
+
+     for key, value in data.items():
+         if isinstance(value, dict):
+             sub_dc = dict_to_dataclass(key.capitalize(), value)
+             fields.append((key, type(sub_dc)))
+             values[key] = sub_dc
+         else:
+             field_type = type(value) if value is not None else Any
+             fields.append((key, field_type))
+             values[key] = value
+
+     DC = make_dataclass(name.capitalize(), fields)
+     return DC(**values)
+
+ def yaml_to_dataclass(name: str, yaml_str: str):
+     data = yaml.safe_load(yaml_str)
+     return dict_to_dataclass(name, data)
+
+
+ def yamlfile_to_dataclass(name: str, file_path: str):
+     data_dict = yamlfile.load_yaml(file_path, to_dict=True)
+     if "__base__" in data_dict:
+         del data_dict["__base__"]
+     return dict_to_dataclass(name, data_dict)
+
+ if __name__ == "__main__":
+     cfg = yamlfile_to_dataclass("Config", "test/dataclass_util_test_cfg.yaml")
+
+     # ! NOTICE: after printing this dataclass, copy the output and paste it into ChatGPT to generate the needed dataclass classes using `from dataclass_wizard import YAMLWizard`
+     pprint(cfg)
@@ -0,0 +1,208 @@
+ # This script creates a test version
+ # of the watcam (wc) dataset
+ # for testing the tflite model
+
+ from argparse import ArgumentParser
+
+ import os
+ import click
+ import shutil
+ from tqdm import tqdm
+ from rich import inspect
+ from rich.pretty import pprint
+ from torchvision.datasets import ImageFolder
+ from sklearn.model_selection import StratifiedShuffleSplit, ShuffleSplit
+
+ from ...common.common import console, seed_everything, ConsoleLog
+ from ...system import filesys as fs
+
+ def parse_args():
+     parser = ArgumentParser(description="Split an ImageFolder-style classification dataset into train/val sets")
+     parser.add_argument(
+         "-indir",
+         "--indir",
+         type=str,
+         help="original dataset path",
+     )
+     parser.add_argument(
+         "-outdir",
+         "--outdir",
+         type=str,
+         help="dataset out path",
+         default=".",  # default to current dir
+     )
+     parser.add_argument(
+         "-val_size",
+         "--val_size",
+         type=float,
+         help="validation size",  # fraction of samples moved to val
+         default=0.2,
+     )
+     # add using StratifiedShuffleSplit or ShuffleSplit
+     parser.add_argument(
+         "-seed",
+         "--seed",
+         type=int,
+         help="random seed",
+         default=42,
+     )
+     parser.add_argument(
+         "-inplace",
+         "--inplace",
+         action="store_true",
+         help="inplace operation, will overwrite the outdir if exists",
+     )
+
+     parser.add_argument(
+         "-stratified",
+         "--stratified",
+         action="store_true",
+         help="use StratifiedShuffleSplit instead of ShuffleSplit",
+     )
+     parser.add_argument(
+         "-no_train",
+         "--no_train",
+         action="store_true",
+         help="only create test set, no train set",
+     )
+     parser.add_argument(
+         "-reverse",
+         "--reverse",
+         action="store_true",
+         help="combine train and val set back to original dataset",
+     )
+     return parser.parse_args()
+
+
+ def move_images(image_paths, target_set_dir):
+     for img_path in tqdm(image_paths):
+         # get folder name of the image
+         img_dir = os.path.dirname(img_path)
+         out_cls_dir = os.path.join(target_set_dir, os.path.basename(img_dir))
+         if not os.path.exists(out_cls_dir):
+             os.makedirs(out_cls_dir)
+         # move the image to the class folder
+         shutil.move(img_path, out_cls_dir)
+
+
+ def split_dataset_cls(
+     indir, outdir, val_size, seed, inplace, stratified_split, no_train
+ ):
+     seed_everything(seed)
+     console.rule("Config confirm?")
+     pprint(locals())
+     click.confirm("Continue?", abort=True)
+     assert os.path.exists(indir), f"{indir} does not exist"
+
+     if not inplace:
+         assert (not inplace) and (
+             not os.path.exists(outdir)
+         ), f"{outdir} already exists; SKIP ...."
+
+     if inplace:
+         outdir = indir
+     if not os.path.exists(outdir):
+         os.makedirs(outdir)
+
+     console.rule("Creating train/val dataset")
+
+     sss = (
+         ShuffleSplit(n_splits=1, test_size=val_size)
+         if not stratified_split
+         else StratifiedShuffleSplit(n_splits=1, test_size=val_size)
+     )
+
+     pprint({"split strategy": sss, "indir": indir, "outdir": outdir})
+     dataset = ImageFolder(
+         root=indir,
+         transform=None,
+     )
+     train_dataset_indices = None
+     val_dataset_indices = None  # val here means test
+     for train_indices, val_indices in sss.split(dataset.samples, dataset.targets):
+         train_dataset_indices = train_indices
+         val_dataset_indices = val_indices
+
+     # get image paths for train/val split dataset
+     train_image_paths = [dataset.imgs[i][0] for i in train_dataset_indices]
+     val_image_paths = [dataset.imgs[i][0] for i in val_dataset_indices]
+
+     # start creating train/val folders then move images
+     out_train_dir = os.path.join(outdir, "train")
+     out_val_dir = os.path.join(outdir, "val")
+     if inplace:
+         assert not os.path.exists(out_train_dir), f"{out_train_dir} already exists"
+         assert not os.path.exists(out_val_dir), f"{out_val_dir} already exists"
+
+     os.makedirs(out_train_dir)
+     os.makedirs(out_val_dir)
+
+     if not no_train:
+         with ConsoleLog(f"Moving train images to {out_train_dir} "):
+             move_images(train_image_paths, out_train_dir)
+     else:
+         pprint("test only, skip moving train images")
+         # remove out_train_dir
+         shutil.rmtree(out_train_dir)
+
+     with ConsoleLog(f"Moving val images to {out_val_dir} "):
+         move_images(val_image_paths, out_val_dir)
+
+     if inplace:
+         pprint("remove all folders, except train and val")
+         for cls_dir in os.listdir(outdir):
+             if cls_dir not in ["train", "val"]:
+                 shutil.rmtree(os.path.join(indir, cls_dir))
+
+
+ def reverse_split_ds(indir):
+     console.rule(f"Reversing split dataset <{indir}>...")
+     ls_dirs = os.listdir(indir)
+     # make sure there are only two dirs 'train' and 'val'
+     assert len(ls_dirs) == 2, f"Expected exactly 2 dirs (train/val), found {len(ls_dirs)}"
+     assert "train" in ls_dirs, f"train dir not found in {indir}"
+     assert "val" in ls_dirs, f"val dir not found in {indir}"
+     train_dir = os.path.join(indir, "train")
+     val_dir = os.path.join(indir, "val")
+     all_train_files = fs.filter_files_by_extension(
+         train_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+     )
+     all_val_files = fs.filter_files_by_extension(
+         val_dir, ["jpg", "jpeg", "png", "bmp", "gif", "tiff"]
+     )
+     # move all files from train to indir
+     with ConsoleLog(f"Moving train images to {indir} "):
+         move_images(all_train_files, indir)
+     with ConsoleLog(f"Moving val images to {indir} "):
+         move_images(all_val_files, indir)
+     with ConsoleLog("Removing train and val dirs"):
+         # remove train and val dirs
+         shutil.rmtree(train_dir)
+         shutil.rmtree(val_dir)
+
+
+ def main():
+     args = parse_args()
+     indir = args.indir
+     outdir = args.outdir
+     if outdir == ".":
+         # get current folder of the indir
+         indir_parent_dir = os.path.dirname(os.path.normpath(indir))
+         indir_name = os.path.basename(indir)
+         outdir = os.path.join(indir_parent_dir, f"{indir_name}_split")
+     val_size = args.val_size
+     seed = args.seed
+     inplace = args.inplace
+     stratified_split = args.stratified
+     no_train = args.no_train
+     reverse = args.reverse
+     if not reverse:
+         split_dataset_cls(
+             indir, outdir, val_size, seed, inplace, stratified_split, no_train
+         )
+     else:
+         reverse_split_ds(indir)
+
+
+ if __name__ == "__main__":
+     main()
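A usage sketch for this split script, assuming it is importable as halib.exp.data.dataset (the data paths are placeholders; the CLI form would be `python -m halib.exp.data.dataset -indir ./data/watcam --stratified`):

from halib.exp.data.dataset import split_dataset_cls, reverse_split_ds

# move images from ./data/watcam into ./data/watcam_split/{train,val} using a stratified
# 80/20 split; the function prompts for confirmation before touching any files
split_dataset_cls(
    indir="./data/watcam",
    outdir="./data/watcam_split",
    val_size=0.2,
    seed=42,
    inplace=False,
    stratified_split=True,
    no_train=False,
)

# merge train/ and val/ of a split directory back into plain class folders
reverse_split_ds("./data/watcam_split")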