halib 0.1.91__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. halib/__init__.py +12 -6
  2. halib/common/__init__.py +0 -0
  3. halib/common/common.py +207 -0
  4. halib/common/rich_color.py +285 -0
  5. halib/common.py +53 -10
  6. halib/exp/__init__.py +0 -0
  7. halib/exp/core/__init__.py +0 -0
  8. halib/exp/core/base_config.py +167 -0
  9. halib/exp/core/base_exp.py +147 -0
  10. halib/exp/core/param_gen.py +189 -0
  11. halib/exp/core/wandb_op.py +117 -0
  12. halib/exp/data/__init__.py +0 -0
  13. halib/exp/data/dataclass_util.py +41 -0
  14. halib/exp/data/dataset.py +208 -0
  15. halib/exp/data/torchloader.py +165 -0
  16. halib/exp/perf/__init__.py +0 -0
  17. halib/exp/perf/flop_calc.py +190 -0
  18. halib/exp/perf/gpu_mon.py +58 -0
  19. halib/exp/perf/perfcalc.py +440 -0
  20. halib/exp/perf/perfmetrics.py +137 -0
  21. halib/exp/perf/perftb.py +778 -0
  22. halib/exp/perf/profiler.py +507 -0
  23. halib/exp/viz/__init__.py +0 -0
  24. halib/exp/viz/plot.py +754 -0
  25. halib/filetype/csvfile.py +3 -9
  26. halib/filetype/ipynb.py +61 -0
  27. halib/filetype/jsonfile.py +0 -3
  28. halib/filetype/textfile.py +0 -1
  29. halib/filetype/videofile.py +119 -3
  30. halib/filetype/yamlfile.py +16 -1
  31. halib/online/projectmake.py +7 -6
  32. halib/online/tele_noti.py +165 -0
  33. halib/research/base_exp.py +75 -18
  34. halib/research/core/__init__.py +0 -0
  35. halib/research/core/base_config.py +144 -0
  36. halib/research/core/base_exp.py +157 -0
  37. halib/research/core/param_gen.py +108 -0
  38. halib/research/core/wandb_op.py +117 -0
  39. halib/research/data/__init__.py +0 -0
  40. halib/research/data/dataclass_util.py +41 -0
  41. halib/research/data/dataset.py +208 -0
  42. halib/research/data/torchloader.py +165 -0
  43. halib/research/dataset.py +6 -7
  44. halib/research/flop_csv.py +34 -0
  45. halib/research/flops.py +156 -0
  46. halib/research/metrics.py +4 -0
  47. halib/research/mics.py +59 -1
  48. halib/research/perf/__init__.py +0 -0
  49. halib/research/perf/flop_calc.py +190 -0
  50. halib/research/perf/gpu_mon.py +58 -0
  51. halib/research/perf/perfcalc.py +363 -0
  52. halib/research/perf/perfmetrics.py +137 -0
  53. halib/research/perf/perftb.py +778 -0
  54. halib/research/perf/profiler.py +301 -0
  55. halib/research/perfcalc.py +60 -35
  56. halib/research/perftb.py +2 -1
  57. halib/research/plot.py +480 -218
  58. halib/research/viz/__init__.py +0 -0
  59. halib/research/viz/plot.py +754 -0
  60. halib/system/_list_pc.csv +6 -0
  61. halib/system/filesys.py +60 -20
  62. halib/system/path.py +106 -0
  63. halib/utils/dict.py +9 -0
  64. halib/utils/list.py +12 -0
  65. halib/utils/video.py +6 -0
  66. halib-0.2.21.dist-info/METADATA +192 -0
  67. halib-0.2.21.dist-info/RECORD +109 -0
  68. halib-0.1.91.dist-info/METADATA +0 -201
  69. halib-0.1.91.dist-info/RECORD +0 -61
  70. {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/WHEEL +0 -0
  71. {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/licenses/LICENSE.txt +0 -0
  72. {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/top_level.txt +0 -0
halib/filetype/csvfile.py CHANGED
@@ -1,19 +1,13 @@
+ import csv
+ import textwrap
  import pandas as pd
+ import pygwalker as pyg
  from tabulate import tabulate
  from rich.console import Console
- from rich import print as rprint
- from rich import inspect
- from rich.pretty import pprint
- from tqdm import tqdm
- from loguru import logger
  from itables import init_notebook_mode, show
- import pygwalker as pyg
- import textwrap
- import csv

  console = Console()

-
  def read(file, separator=","):
      df = pd.read_csv(file, separator)
      return df
halib/filetype/ipynb.py ADDED
@@ -0,0 +1,61 @@
+ import ipynbname
+ from pathlib import Path
+ from contextlib import contextmanager
+
+ from ..common.common import now_str
+
+ @contextmanager
+ def gen_ipynb_name(
+     filename,
+     add_time_stamp=False,
+     nb_prefix="nb__",
+     separator="__",
+ ):
+     """
+     Context manager that prefixes the filename with the notebook name.
+     Output: NotebookName_OriginalName.ext
+     """
+     try:
+         nb_name = ipynbname.name()
+     except FileNotFoundError:
+         nb_name = "script"  # Fallback
+
+     p = Path(filename)
+
+     # --- FIX START ---
+
+     # 1. Get the parts separately
+     original_stem = p.stem  # "test" (no extension)
+     extension = p.suffix  # ".csv"
+
+     now_string = now_str() if add_time_stamp else ""
+
+     # 2. Construct the base name (Notebook + Separator + OriginalName)
+     base_name = f"{nb_prefix}{nb_name}{separator}{original_stem}"
+
+     # 3. Append timestamp if needed
+     if now_string:
+         base_name = f"{base_name}{separator}{now_string}"
+
+     # 4. Add the extension at the VERY END
+     new_filename = f"{base_name}{extension}"
+
+     # --- FIX END ---
+
+     final_path = p.parent / new_filename
+
+     # Assuming you use 'rich' console based on your snippet
+     # console.rule()
+     # print(f"📝 Saving as: {final_path}")
+
+     yield str(final_path)
+
+
+ if __name__ == "__main__":
+     # --- Usage Example ---
+     # Assume Notebook Name is: "MyThesisWork"
+     filename = "results.csv"
+     with gen_ipynb_name(filename) as filename_ipynb:
+         # filename_ipynb is now: "MyThesisWork_results.csv"
+         print(f"File to save: {filename_ipynb}")
+         # df.to_csv(filename_ipynb)
halib/filetype/jsonfile.py CHANGED
@@ -1,17 +1,14 @@
  import json

-
  def read(file):
      with open(file) as f:
          data = json.load(f)
      return data

-
  def write(data_dict, outfile):
      with open(outfile, "w") as json_file:
          json.dump(data_dict, json_file)

-
  def beautify(json_str):
      formatted_json = json_str
      try:
halib/filetype/textfile.py CHANGED
@@ -4,7 +4,6 @@ def read_line_by_line(file_path):
      lines = [line.rstrip() for line in lines]
      return lines

-
  def write(lines, outfile, append=False):
      mode = "a" if append else "w"
      with open(outfile, mode, encoding="utf-8") as f:
halib/filetype/videofile.py CHANGED
@@ -1,11 +1,128 @@
+ import os
  import cv2
- import textfile
  import enlighten
+
  from enum import Enum
- from ..system import filesys
  from tube_dl import Youtube, Playlist
  from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip

+ from . import textfile
+ from . import csvfile
+ from ..system import filesys
+
+ class VideoUtils:
+     @staticmethod
+     def _default_meta_extractor(video_path):
+         """Default video metadata extractor function."""
+         # Open the video file
+         cap = cv2.VideoCapture(video_path)
+
+         # Check if the video was opened successfully
+         if not cap.isOpened():
+             print(f"Error: Could not open video file {video_path}")
+             return None
+
+         # Get the frame count
+         frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+         # Get the FPS
+         fps = cap.get(cv2.CAP_PROP_FPS)
+
+         # get frame size
+         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+         # Release the video capture object
+         cap.release()
+
+         meta_dict = {
+             "video_path": video_path,
+             "width": width,
+             "height": height,
+             "frame_count": frame_count,
+             "fps": fps,
+         }
+         return meta_dict
+
+     @staticmethod
+     def get_video_meta_dict(video_path, meta_dict_extractor_func=None):
+         assert os.path.exists(video_path), f"Video file {video_path} does not exist"
+         if meta_dict_extractor_func and callable(meta_dict_extractor_func):
+             assert (
+                 meta_dict_extractor_func.__code__.co_argcount == 1
+             ), "meta_dict_extractor_func must take exactly one argument (video_path)"
+             meta_dict = meta_dict_extractor_func(video_path)
+             assert isinstance(
+                 meta_dict, dict
+             ), "meta_dict_extractor_func must return a dictionary"
+             assert "video_path" in meta_dict, "meta_dict must contain 'video_path'"
+         else:
+             meta_dict = VideoUtils._default_meta_extractor(video_path=video_path)
+         return meta_dict
+
+     @staticmethod
+     def get_video_dir_meta_df(
+         video_dir,
+         video_exts=[".mp4", ".avi", ".mov", ".mkv"],
+         search_recursive=False,
+         csv_outfile=None,
+     ):
+         assert os.path.exists(video_dir), f"Video directory {video_dir} does not exist"
+         video_files = filesys.filter_files_by_extension(
+             video_dir, video_exts, recursive=search_recursive
+         )
+         assert (
+             len(video_files) > 0
+         ), f"No video files found in {video_dir} with extensions {video_exts}"
+         video_meta_list = []
+         for vfile in video_files:
+             meta_dict = VideoUtils.get_video_meta_dict(vfile)
+             if meta_dict:
+                 video_meta_list.append(meta_dict)
+         dfmk = csvfile.DFCreator()
+         columns = list(video_meta_list[0].keys())
+         assert len(columns) > 0, "No video metadata found"
+         assert "video_path" in columns, "video_path column not found in video metadata"
+         # move video_path to the first column
+         columns.remove("video_path")
+         columns.insert(0, "video_path")
+         dfmk.create_table("video_meta", columns)
+         rows = [[meta[col] for col in columns] for meta in video_meta_list]
+         dfmk.insert_rows("video_meta", rows)
+         dfmk.fill_table_from_row_pool("video_meta")
+
+         if csv_outfile:
+             dfmk["video_meta"].to_csv(csv_outfile, index=False, sep=";")
+         return dfmk["video_meta"].copy()
+
+
+     # -----------------------------
+     # FFmpeg Horizontal Stack
+     # -----------------------------
+     @staticmethod
+     def hstack(video_files, output_file):
+         """Horizontally stack multiple videos using FFmpeg."""
+         tmp_file = "video_list.txt"
+         try:
+             with open(tmp_file, "w") as f:
+                 for video in video_files:
+                     f.write(f"file '{video}'\n")
+
+             ffmpeg_cmd = (
+                 f"ffmpeg -f concat -safe 0 -i {tmp_file} "
+                 f'-filter_complex "[0:v][1:v][2:v]hstack=inputs={len(video_files)}[v]" '
+                 f'-map "[v]" -c:v libx264 -preset fast -crf 22 {output_file}'
+             )
+
+             os.system(ffmpeg_cmd)
+             print(f"[INFO] Video stacked successfully: {output_file}")
+
+         except Exception as e:
+             print(f"[ERROR] Video stacking failed: {e}")
+         finally:
+             if os.path.exists(tmp_file):
+                 os.remove(tmp_file)
+

  class VideoResolution(Enum):
      VR480p = "720x480"
@@ -57,7 +174,6 @@ def trim_video(source, destination, start_time, end_time)

  progress_bar = None

-
  def on_progress(bytes_done, total_bytes):
      global progress_bar
      if progress_bar is None:
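The new VideoUtils helpers above are plain static methods, so they can be called directly; a minimal usage sketch (the directory and file names below are placeholders, hstack assumes ffmpeg is on PATH, and the filter string in the diff is written for exactly three inputs):

import halib  # assumes the wheel is installed
from halib.filetype.videofile import VideoUtils

# Build a metadata DataFrame (video_path, width, height, frame_count, fps)
# for every video under a folder, optionally dumping it to a ';'-separated CSV.
meta_df = VideoUtils.get_video_dir_meta_df(
    video_dir="data/videos",          # placeholder directory
    search_recursive=True,
    csv_outfile="video_meta.csv",     # optional
)
print(meta_df.head())

# Horizontally stack three clips with FFmpeg (placeholder file names).
VideoUtils.hstack(["a.mp4", "b.mp4", "c.mp4"], "stacked.mp4")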
halib/filetype/yamlfile.py CHANGED
@@ -2,10 +2,13 @@ import time
  import networkx as nx
  from rich import inspect
  from rich.pretty import pprint
- from omegaconf import OmegaConf
  from rich.console import Console
+
+ from omegaconf import OmegaConf
  from argparse import ArgumentParser

+ from ..system.path import *
+
  console = Console()


@@ -52,6 +55,18 @@ def load_yaml(yaml_file, to_dict=False, log_info=False):
      return omgconf


+ def load_yaml_with_PC_abbr(
+     yaml_file, abbr_disk_map=ABBR_DISK_MAP
+ ):
+     # load yaml file
+     data_dict = load_yaml(yaml_file=yaml_file, to_dict=True)
+     # Normalize paths in the loaded data
+     data_dict = normalize_paths(
+         data_dict, get_working_disk(abbr_disk_map), get_os_platform()
+     )
+     return data_dict
+
+
  def parse_args():
      parser = ArgumentParser(description="desc text")
      parser.add_argument(
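A hedged usage sketch of the new load_yaml_with_PC_abbr helper ("config.yaml" is a placeholder; ABBR_DISK_MAP, get_working_disk, normalize_paths and get_os_platform come from the halib.system.path star-import shown above and are not part of this hunk):

from halib.filetype.yamlfile import load_yaml_with_PC_abbr

# Load a YAML config and rewrite any disk-abbreviated paths for the
# current machine/OS before returning a plain dict.
cfg = load_yaml_with_PC_abbr("config.yaml")
print(cfg)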
halib/online/projectmake.py CHANGED
@@ -1,17 +1,18 @@
  # coding=utf-8
- import json
+
  import os
+ import json
+ import pycurl
  import shutil
- from argparse import ArgumentParser
- from io import BytesIO
+ import certifi
  import subprocess
+ from io import BytesIO
+
+ from argparse import ArgumentParser

- import certifi
- import pycurl
  from ..filetype import jsonfile
  from ..system import filesys

-
  def get_curl(url, user_and_pass, verbose=True):
      c = pycurl.Curl()
      c.setopt(pycurl.VERBOSE, verbose)
halib/online/tele_noti.py ADDED
@@ -0,0 +1,165 @@
+ # Watch a log file and send a telegram message when train reaches a certain epoch or end
+
+ import os
+ import yaml
+ import asyncio
+ import telegram
+ import pandas as pd
+
+ from rich.pretty import pprint
+ from rich.console import Console
+ import plotly.graph_objects as go
+
+ from ..system import filesys as fs
+ from ..filetype import textfile, csvfile
+
+ from argparse import ArgumentParser
+
+ tele_console = Console()
+
+
+ def parse_args():
+     parser = ArgumentParser(description="desc text")
+     parser.add_argument(
+         "-cfg",
+         "--cfg",
+         type=str,
+         help="yaml file for tele",
+         default=r"E:\Dev\__halib\halib\online\tele_noti_cfg.yaml",
+     )
+
+     return parser.parse_args()
+
+ def get_watcher_message_df(target_file, num_last_lines):
+     file_ext = fs.get_file_name(target_file, split_file_ext=True)[1]
+     supported_ext = [".txt", ".log", ".csv"]
+     assert (
+         file_ext in supported_ext
+     ), f"File extension {file_ext} not supported. Supported extensions are {supported_ext}"
+     last_lines_df = None
+     if file_ext in [".txt", ".log"]:
+         lines = textfile.read_line_by_line(target_file)
+         if num_last_lines > len(lines):
+             num_last_lines = len(lines)
+         last_line_arr = lines[-num_last_lines:]
+         # add a line start with word "epoch"
+         epoch_info_list = "Epoch: n/a"
+         for line in reversed(lines):
+             if "epoch" in line.lower():
+                 epoch_info_list = line
+                 break
+         last_line_arr.insert(0, epoch_info_list)  # insert at the beginning
+         dfCreator = csvfile.DFCreator()
+         dfCreator.create_table("last_lines", ["line"])
+         last_line_arr = [[line] for line in last_line_arr]
+         dfCreator.insert_rows("last_lines", last_line_arr)
+         dfCreator.fill_table_from_row_pool("last_lines")
+         last_lines_df = dfCreator["last_lines"].copy()
+     else:
+         df = pd.read_csv(target_file)
+         num_rows = len(df)
+         if num_last_lines > num_rows:
+             num_last_lines = num_rows
+         last_lines_df = df.tail(num_last_lines)
+     return last_lines_df
+
+
+ def df2img(df: pd.DataFrame, output_img_dir, decimal_places, out_img_scale):
+     df = df.round(decimal_places)
+     fig = go.Figure(
+         data=[
+             go.Table(
+                 header=dict(values=list(df.columns), align="center"),
+                 cells=dict(
+                     values=df.values.transpose(),
+                     fill_color=[["white", "lightgrey"] * df.shape[0]],
+                     align="center",
+                 ),
+             )
+         ]
+     )
+     if not os.path.exists(output_img_dir):
+         os.makedirs(output_img_dir)
+     img_path = os.path.normpath(os.path.join(output_img_dir, "last_lines.png"))
+     fig.write_image(img_path, scale=out_img_scale)
+     return img_path
+
+
+ def compose_message_and_img_path(
+     target_file, project, num_last_lines, decimal_places, out_img_scale, output_img_dir
+ ):
+     context_msg = f">> Project: {project} \n>> File: {target_file} \n>> Last {num_last_lines} lines:"
+     msg_df = get_watcher_message_df(target_file, num_last_lines)
+     try:
+         img_path = df2img(msg_df, output_img_dir, decimal_places, out_img_scale)
+     except Exception as e:
+         pprint(f"Error: {e}")
+         img_path = None
+     return context_msg, img_path
+
+
+ async def send_to_telegram(cfg_dict, interval_in_sec):
+     # pprint(cfg_dict)
+     token = cfg_dict["telegram"]["token"]
+     chat_id = cfg_dict["telegram"]["chat_id"]
+
+     noti_settings = cfg_dict["noti_settings"]
+     project = noti_settings["project"]
+     target_file = noti_settings["target_file"]
+     num_last_lines = noti_settings["num_last_lines"]
+     output_img_dir = noti_settings["output_img_dir"]
+     decimal_places = noti_settings["decimal_places"]
+     out_img_scale = noti_settings["out_img_scale"]
+
+     bot = telegram.Bot(token=token)
+     async with bot:
+         try:
+             context_msg, img_path = compose_message_and_img_path(
+                 target_file,
+                 project,
+                 num_last_lines,
+                 decimal_places,
+                 out_img_scale,
+                 output_img_dir,
+             )
+             time_now = next_time = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
+             sep_line = "-" * 50
+             context_msg = f"{sep_line}\n>> Time: {time_now}\n{context_msg}"
+             # calculate the next time to send message
+             next_time = pd.Timestamp.now() + pd.Timedelta(seconds=interval_in_sec)
+             next_time = next_time.strftime("%Y-%m-%d %H:%M:%S")
+             next_time_info = f"Next msg: {next_time}"
+             tele_console.rule()
+             tele_console.print("[green] Send message to telegram [/green]")
+             tele_console.print(
+                 f"[red] Next message will be sent at <{next_time}> [/red]"
+             )
+             await bot.send_message(text=context_msg, chat_id=chat_id)
+             if img_path:
+                 await bot.send_photo(chat_id=chat_id, photo=open(img_path, "rb"))
+             await bot.send_message(text=next_time_info, chat_id=chat_id)
+         except Exception as e:
+             pprint(f"Error: {e}")
+             pprint("Message not sent to telegram")
+
+
+ async def run_forever(cfg_path):
+     cfg_dict = yaml.safe_load(open(cfg_path, "r"))
+     noti_settings = cfg_dict["noti_settings"]
+     interval_in_min = noti_settings["interval_in_min"]
+     interval_in_sec = int(interval_in_min * 60)
+     pprint(
+         f"Message will be sent every {interval_in_min} minutes or {interval_in_sec} seconds"
+     )
+     while True:
+         await send_to_telegram(cfg_dict, interval_in_sec)
+         await asyncio.sleep(interval_in_sec)
+
+
+ async def main():
+     args = parse_args()
+     await run_forever(args.cfg)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
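The watcher is driven entirely by a YAML config whose keys are read in send_to_telegram and run_forever above. A minimal sketch of how it might be wired up; the token, chat id, paths and values below are placeholders, not values from this diff:

import asyncio
from halib.online import tele_noti

# Keys mirror what the module reads: telegram.token / telegram.chat_id
# plus the noti_settings block. All values here are made up.
CFG_TEXT = """
telegram:
  token: "<bot-token>"
  chat_id: "<chat-id>"
noti_settings:
  project: "my_project"
  target_file: "logs/train.log"
  num_last_lines: 10
  output_img_dir: "tele_imgs"
  decimal_places: 4
  out_img_scale: 2
  interval_in_min: 30
"""

with open("tele_noti_cfg.yaml", "w") as f:
    f.write(CFG_TEXT)

# run_forever loops indefinitely, posting the last lines of the target
# file every interval_in_min minutes.
asyncio.run(tele_noti.run_forever("tele_noti_cfg.yaml"))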
halib/research/base_exp.py CHANGED
@@ -1,5 +1,5 @@
  from abc import ABC, abstractmethod
-
+ from typing import Tuple, Any, Optional
  from ..research.base_config import ExpBaseConfig
  from ..research.perfcalc import PerfCalc
  from ..research.metrics import MetricsBackend
@@ -14,6 +14,8 @@ class BaseExperiment(PerfCalc, ABC):
      def __init__(self, config: ExpBaseConfig):
          self.config = config
          self.metric_backend = None
+         # Flag to track if init_general/prepare_dataset has run
+         self._is_env_ready = False

      # -----------------------
      # PerfCalc Required Methods
@@ -51,50 +53,105 @@
          pass

      @abstractmethod
-     def exec_exp(self, *args, **kwargs):
+     def before_exec_exp_once(self, *args, **kwargs):
+         """Optional: any setup before exec_exp. Note this is called once per run_exp."""
+         pass
+
+     @abstractmethod
+     def exec_exp(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
          """Run experiment process, e.g.: training/evaluation loop.
-         Return: raw_metrics_data, and extra_data as input for calc_and_save_exp_perfs
+         Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
          """
          pass

-     def eval_exp(self):
-         """Optional: re-run evaluation from saved results."""
+     @abstractmethod
+     def exec_eval(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
+         """Run evaluation process.
+         Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
+         """
          pass

+     # -----------------------
+     # Internal Helpers
+     # -----------------------
+     def _validate_and_unpack(self, results):
+         if results is None:
+             return None
+         if not isinstance(results, (tuple, list)) or len(results) != 2:
+             raise ValueError("exec must return (metrics_data, extra_data)")
+         return results[0], results[1]
+
+     def _prepare_environment(self, force_reload: bool = False):
+         """
+         Common setup. Skips if already initialized, unless force_reload is True.
+         """
+         if self._is_env_ready and not force_reload:
+             # Environment is already prepared, skipping setup.
+             return
+
+         # 1. Run Setup
+         self.init_general(self.config.get_general_cfg())
+         self.prepare_dataset(self.config.get_dataset_cfg())
+
+         # 2. Update metric backend (refresh if needed)
+         self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
+
+         # 3. Mark as ready
+         self._is_env_ready = True
+
      # -----------------------
      # Main Experiment Runner
      # -----------------------
-     def run_exp(self, do_calc_metrics=True, *args, **kwargs):
+     def run_exp(self, should_calc_metrics=True, reload_env=False, *args, **kwargs):
          """
          Run the whole experiment pipeline.
-         Params:
+         :param reload_env: If True, forces dataset/general init to run again.
+         :param should_calc_metrics: Whether to calculate and save metrics after execution.
+         :kwargs Params:
          + 'outfile' to save csv file results,
          + 'outdir' to set output directory for experiment results.
          + 'return_df' to return a DataFrame of results instead of a dictionary.

          Full pipeline:
          1. Init
-         2. Dataset
-         3. Metrics Preparation
-         4. Save Config
-         5. Execute
-         6. Calculate & Save Metrics
+         2. Prepare Environment (General + Dataset + Metrics)
+         3. Save Config
+         4. Execute
+         5. Calculate & Save Metrics
          """
-         self.init_general(self.config.get_general_cfg())
-         self.prepare_dataset(self.config.get_dataset_cfg())
-         self.prepare_metrics(self.config.get_metric_cfg())
+         self._prepare_environment(force_reload=reload_env)

+         # Any pre-exec setup (loading models, etc)
+         self.before_exec_exp_once(*args, **kwargs)
          # Save config before running
          self.config.save_to_outdir()

          # Execute experiment
          results = self.exec_exp(*args, **kwargs)
-         if do_calc_metrics:
-             metrics_data, extra_data = results
+
+         if should_calc_metrics and results is not None:
+             metrics_data, extra_data = self._validate_and_unpack(results)
              # Calculate & Save metrics
-             perf_results = self.calc_and_save_exp_perfs(
+             perf_results = self.calc_perfs(
                  raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
              )
              return perf_results
          else:
              return results
+
+     # -----------------------
+     # Main Experiment Evaluator
+     # -----------------------
+     def eval_exp(self, reload_env=False, *args, **kwargs):
+         """
+         Run evaluation only.
+         :param reload_env: If True, forces dataset/general init to run again.
+         """
+         self._prepare_environment(force_reload=reload_env)
+         results = self.exec_eval(*args, **kwargs)
+         if results is not None:
+             metrics_data, extra_data = self._validate_and_unpack(results)
+             return self.calc_perfs(
+                 raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
+             )
+         return None
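The reworked BaseExperiment now requires exec_eval and before_exec_exp_once in addition to exec_exp, and both run_exp and eval_exp funnel setup through _prepare_environment. A sketch of the new contract, limited to the hooks visible in this hunk; the PerfCalc-required methods, the concrete ExpBaseConfig, and build_model are omitted or hypothetical, so a real subclass would need more than this:

from halib.research.base_exp import BaseExperiment

class MyExperiment(BaseExperiment):
    def before_exec_exp_once(self, *args, **kwargs):
        # one-time setup per run_exp call, e.g. constructing the model
        self.model = build_model(self.config)   # hypothetical helper

    def exec_exp(self, *args, **kwargs):
        # training loop; return (raw_metrics_data, extra_data) or None
        return {"acc": 0.90}, {"ckpt": "last.pt"}   # placeholder values

    def exec_eval(self, *args, **kwargs):
        # evaluation-only path used by eval_exp()
        return {"acc": 0.88}, None

# exp = MyExperiment(my_config)            # my_config: an ExpBaseConfig subclass
# exp.run_exp(should_calc_metrics=True)    # env setup -> save config -> exec_exp -> calc_perfs
# exp.eval_exp(reload_env=True)            # force env re-init, then exec_eval -> calc_perfs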