halib 0.1.99__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. halib/__init__.py +3 -3
  2. halib/common/__init__.py +0 -0
  3. halib/common/common.py +178 -0
  4. halib/common/rich_color.py +285 -0
  5. halib/filetype/csvfile.py +3 -9
  6. halib/filetype/ipynb.py +3 -5
  7. halib/filetype/jsonfile.py +0 -3
  8. halib/filetype/textfile.py +0 -1
  9. halib/filetype/videofile.py +91 -2
  10. halib/filetype/yamlfile.py +3 -3
  11. halib/online/projectmake.py +7 -6
  12. halib/online/tele_noti.py +165 -0
  13. halib/research/base_exp.py +75 -18
  14. halib/research/core/__init__.py +0 -0
  15. halib/research/core/base_config.py +144 -0
  16. halib/research/core/base_exp.py +157 -0
  17. halib/research/core/param_gen.py +108 -0
  18. halib/research/core/wandb_op.py +117 -0
  19. halib/research/data/__init__.py +0 -0
  20. halib/research/data/dataclass_util.py +41 -0
  21. halib/research/data/dataset.py +208 -0
  22. halib/research/data/torchloader.py +165 -0
  23. halib/research/dataset.py +1 -1
  24. halib/research/metrics.py +4 -0
  25. halib/research/mics.py +8 -2
  26. halib/research/perf/__init__.py +0 -0
  27. halib/research/perf/flop_calc.py +190 -0
  28. halib/research/perf/gpu_mon.py +58 -0
  29. halib/research/perf/perfcalc.py +363 -0
  30. halib/research/perf/perfmetrics.py +137 -0
  31. halib/research/perf/perftb.py +778 -0
  32. halib/research/perf/profiler.py +301 -0
  33. halib/research/perfcalc.py +57 -32
  34. halib/research/viz/__init__.py +0 -0
  35. halib/research/viz/plot.py +754 -0
  36. halib/system/filesys.py +60 -20
  37. halib/system/path.py +73 -0
  38. halib/utils/dict.py +9 -0
  39. halib/utils/list.py +12 -0
  40. {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/METADATA +7 -1
  41. halib-0.2.2.dist-info/RECORD +89 -0
  42. halib-0.1.99.dist-info/RECORD +0 -64
  43. {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/WHEEL +0 -0
  44. {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/licenses/LICENSE.txt +0 -0
  45. {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/top_level.txt +0 -0
halib/online/tele_noti.py
@@ -0,0 +1,165 @@
+ # Watch a log file and send a telegram message when train reaches a certain epoch or end
+
+ import os
+ import yaml
+ import asyncio
+ import telegram
+ import pandas as pd
+
+ from rich.pretty import pprint
+ from rich.console import Console
+ import plotly.graph_objects as go
+
+ from ..system import filesys as fs
+ from ..filetype import textfile, csvfile
+
+ from argparse import ArgumentParser
+
+ tele_console = Console()
+
+
+ def parse_args():
+     parser = ArgumentParser(description="desc text")
+     parser.add_argument(
+         "-cfg",
+         "--cfg",
+         type=str,
+         help="yaml file for tele",
+         default=r"E:\Dev\__halib\halib\online\tele_noti_cfg.yaml",
+     )
+
+     return parser.parse_args()
+
+ def get_watcher_message_df(target_file, num_last_lines):
+     file_ext = fs.get_file_name(target_file, split_file_ext=True)[1]
+     supported_ext = [".txt", ".log", ".csv"]
+     assert (
+         file_ext in supported_ext
+     ), f"File extension {file_ext} not supported. Supported extensions are {supported_ext}"
+     last_lines_df = None
+     if file_ext in [".txt", ".log"]:
+         lines = textfile.read_line_by_line(target_file)
+         if num_last_lines > len(lines):
+             num_last_lines = len(lines)
+         last_line_arr = lines[-num_last_lines:]
+         # add a line start with word "epoch"
+         epoch_info_list = "Epoch: n/a"
+         for line in reversed(lines):
+             if "epoch" in line.lower():
+                 epoch_info_list = line
+                 break
+         last_line_arr.insert(0, epoch_info_list) # insert at the beginning
+         dfCreator = csvfile.DFCreator()
+         dfCreator.create_table("last_lines", ["line"])
+         last_line_arr = [[line] for line in last_line_arr]
+         dfCreator.insert_rows("last_lines", last_line_arr)
+         dfCreator.fill_table_from_row_pool("last_lines")
+         last_lines_df = dfCreator["last_lines"].copy()
+     else:
+         df = pd.read_csv(target_file)
+         num_rows = len(df)
+         if num_last_lines > num_rows:
+             num_last_lines = num_rows
+         last_lines_df = df.tail(num_last_lines)
+     return last_lines_df
+
+
+ def df2img(df: pd.DataFrame, output_img_dir, decimal_places, out_img_scale):
+     df = df.round(decimal_places)
+     fig = go.Figure(
+         data=[
+             go.Table(
+                 header=dict(values=list(df.columns), align="center"),
+                 cells=dict(
+                     values=df.values.transpose(),
+                     fill_color=[["white", "lightgrey"] * df.shape[0]],
+                     align="center",
+                 ),
+             )
+         ]
+     )
+     if not os.path.exists(output_img_dir):
+         os.makedirs(output_img_dir)
+     img_path = os.path.normpath(os.path.join(output_img_dir, "last_lines.png"))
+     fig.write_image(img_path, scale=out_img_scale)
+     return img_path
+
+
+ def compose_message_and_img_path(
+     target_file, project, num_last_lines, decimal_places, out_img_scale, output_img_dir
+ ):
+     context_msg = f">> Project: {project} \n>> File: {target_file} \n>> Last {num_last_lines} lines:"
+     msg_df = get_watcher_message_df(target_file, num_last_lines)
+     try:
+         img_path = df2img(msg_df, output_img_dir, decimal_places, out_img_scale)
+     except Exception as e:
+         pprint(f"Error: {e}")
+         img_path = None
+     return context_msg, img_path
+
+
+ async def send_to_telegram(cfg_dict, interval_in_sec):
+     # pprint(cfg_dict)
+     token = cfg_dict["telegram"]["token"]
+     chat_id = cfg_dict["telegram"]["chat_id"]
+
+     noti_settings = cfg_dict["noti_settings"]
+     project = noti_settings["project"]
+     target_file = noti_settings["target_file"]
+     num_last_lines = noti_settings["num_last_lines"]
+     output_img_dir = noti_settings["output_img_dir"]
+     decimal_places = noti_settings["decimal_places"]
+     out_img_scale = noti_settings["out_img_scale"]
+
+     bot = telegram.Bot(token=token)
+     async with bot:
+         try:
+             context_msg, img_path = compose_message_and_img_path(
+                 target_file,
+                 project,
+                 num_last_lines,
+                 decimal_places,
+                 out_img_scale,
+                 output_img_dir,
+             )
+             time_now = next_time = pd.Timestamp.now().strftime("%Y-%m-%d %H:%M:%S")
+             sep_line = "-" * 50
+             context_msg = f"{sep_line}\n>> Time: {time_now}\n{context_msg}"
+             # calculate the next time to send message
+             next_time = pd.Timestamp.now() + pd.Timedelta(seconds=interval_in_sec)
+             next_time = next_time.strftime("%Y-%m-%d %H:%M:%S")
+             next_time_info = f"Next msg: {next_time}"
+             tele_console.rule()
+             tele_console.print("[green] Send message to telegram [/green]")
+             tele_console.print(
+                 f"[red] Next message will be sent at <{next_time}> [/red]"
+             )
+             await bot.send_message(text=context_msg, chat_id=chat_id)
+             if img_path:
+                 await bot.send_photo(chat_id=chat_id, photo=open(img_path, "rb"))
+             await bot.send_message(text=next_time_info, chat_id=chat_id)
+         except Exception as e:
+             pprint(f"Error: {e}")
+             pprint("Message not sent to telegram")
+
+
+ async def run_forever(cfg_path):
+     cfg_dict = yaml.safe_load(open(cfg_path, "r"))
+     noti_settings = cfg_dict["noti_settings"]
+     interval_in_min = noti_settings["interval_in_min"]
+     interval_in_sec = int(interval_in_min * 60)
+     pprint(
+         f"Message will be sent every {interval_in_min} minutes or {interval_in_sec} seconds"
+     )
+     while True:
+         await send_to_telegram(cfg_dict, interval_in_sec)
+         await asyncio.sleep(interval_in_sec)
+
+
+ async def main():
+     args = parse_args()
+     await run_forever(args.cfg)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
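For reference, tele_noti.py reads its settings from the YAML file passed via --cfg. A minimal sketch of the expected structure, written here as the equivalent Python dict: the keys are taken from send_to_telegram and run_forever above, while the values are placeholders, not shipped defaults.

# Hypothetical config for tele_noti.py; keys mirror what send_to_telegram()
# and run_forever() look up, values are placeholders.
cfg_dict = {
    "telegram": {
        "token": "123456:ABC-your-bot-token",  # Telegram bot token
        "chat_id": "123456789",                # target chat
    },
    "noti_settings": {
        "project": "my_training_run",
        "target_file": "train_log.csv",        # .txt, .log, or .csv are supported
        "num_last_lines": 10,                  # lines/rows included in each message
        "output_img_dir": "tele_imgs",         # where the rendered table PNG is written
        "decimal_places": 4,
        "out_img_scale": 2,
        "interval_in_min": 30,                 # notification interval
    },
}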
halib/research/base_exp.py
@@ -1,5 +1,5 @@
  from abc import ABC, abstractmethod
-
+ from typing import Tuple, Any, Optional
  from ..research.base_config import ExpBaseConfig
  from ..research.perfcalc import PerfCalc
  from ..research.metrics import MetricsBackend
@@ -14,6 +14,8 @@ class BaseExperiment(PerfCalc, ABC):
      def __init__(self, config: ExpBaseConfig):
          self.config = config
          self.metric_backend = None
+         # Flag to track if init_general/prepare_dataset has run
+         self._is_env_ready = False

      # -----------------------
      # PerfCalc Required Methods
@@ -51,50 +53,105 @@ class BaseExperiment(PerfCalc, ABC):
          pass

      @abstractmethod
-     def exec_exp(self, *args, **kwargs):
+     def before_exec_exp_once(self, *args, **kwargs):
+         """Optional: any setup before exec_exp. Note this is called once per run_exp."""
+         pass
+
+     @abstractmethod
+     def exec_exp(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
          """Run experiment process, e.g.: training/evaluation loop.
-         Return: raw_metrics_data, and extra_data as input for calc_and_save_exp_perfs
+         Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
          """
          pass

-     def eval_exp(self):
-         """Optional: re-run evaluation from saved results."""
+     @abstractmethod
+     def exec_eval(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
+         """Run evaluation process.
+         Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
+         """
          pass

+     # -----------------------
+     # Internal Helpers
+     # -----------------------
+     def _validate_and_unpack(self, results):
+         if results is None:
+             return None
+         if not isinstance(results, (tuple, list)) or len(results) != 2:
+             raise ValueError("exec must return (metrics_data, extra_data)")
+         return results[0], results[1]
+
+     def _prepare_environment(self, force_reload: bool = False):
+         """
+         Common setup. Skips if already initialized, unless force_reload is True.
+         """
+         if self._is_env_ready and not force_reload:
+             # Environment is already prepared, skipping setup.
+             return
+
+         # 1. Run Setup
+         self.init_general(self.config.get_general_cfg())
+         self.prepare_dataset(self.config.get_dataset_cfg())
+
+         # 2. Update metric backend (refresh if needed)
+         self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
+
+         # 3. Mark as ready
+         self._is_env_ready = True
+
      # -----------------------
      # Main Experiment Runner
      # -----------------------
-     def run_exp(self, do_calc_metrics=True, *args, **kwargs):
+     def run_exp(self, should_calc_metrics=True, reload_env=False, *args, **kwargs):
          """
          Run the whole experiment pipeline.
-         Params:
+         :param reload_env: If True, forces dataset/general init to run again.
+         :param should_calc_metrics: Whether to calculate and save metrics after execution.
+         :kwargs Params:
              + 'outfile' to save csv file results,
              + 'outdir' to set output directory for experiment results.
              + 'return_df' to return a DataFrame of results instead of a dictionary.

          Full pipeline:
          1. Init
-         2. Dataset
-         3. Metrics Preparation
-         4. Save Config
-         5. Execute
-         6. Calculate & Save Metrics
+         2. Prepare Environment (General + Dataset + Metrics)
+         3. Save Config
+         4. Execute
+         5. Calculate & Save Metrics
          """
-         self.init_general(self.config.get_general_cfg())
-         self.prepare_dataset(self.config.get_dataset_cfg())
-         self.prepare_metrics(self.config.get_metric_cfg())
+         self._prepare_environment(force_reload=reload_env)

+         # Any pre-exec setup (loading models, etc)
+         self.before_exec_exp_once(*args, **kwargs)
          # Save config before running
          self.config.save_to_outdir()

          # Execute experiment
          results = self.exec_exp(*args, **kwargs)
-         if do_calc_metrics:
-             metrics_data, extra_data = results
+
+         if should_calc_metrics and results is not None:
+             metrics_data, extra_data = self._validate_and_unpack(results)
              # Calculate & Save metrics
-             perf_results = self.calc_and_save_exp_perfs(
+             perf_results = self.calc_perfs(
                  raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
              )
              return perf_results
          else:
              return results
+
+     # -----------------------
+     # Main Experiment Evaluator
+     # -----------------------
+     def eval_exp(self, reload_env=False, *args, **kwargs):
+         """
+         Run evaluation only.
+         :param reload_env: If True, forces dataset/general init to run again.
+         """
+         self._prepare_environment(force_reload=reload_env)
+         results = self.exec_eval(*args, **kwargs)
+         if results is not None:
+             metrics_data, extra_data = self._validate_and_unpack(results)
+             return self.calc_perfs(
+                 raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
+             )
+         return None
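To illustrate the revised BaseExperiment contract, here is a minimal, hypothetical subclass sketch (helpers such as set_seed, load_dataset, build_metrics_backend, build_model, train_loop, evaluate, and MyConfig are placeholders, not halib APIs): exec_exp and exec_eval return either None or a (raw_metrics_data, extra_data) tuple, and before_exec_exp_once runs once per run_exp after the environment has been prepared.

# Hypothetical subclass showing the updated hooks; all helper names below are
# user-side placeholders, only the BaseExperiment/ExpBaseConfig methods are from halib.
class MyExperiment(BaseExperiment):
    def init_general(self, general_cfg):
        set_seed(general_cfg.seed)                 # placeholder: seed RNGs, set up logging

    def prepare_dataset(self, dataset_cfg):
        self.dataset = load_dataset(dataset_cfg)   # placeholder loader

    def prepare_metrics(self, metric_cfg):
        return build_metrics_backend(metric_cfg)   # must return a MetricsBackend

    def before_exec_exp_once(self, *args, **kwargs):
        self.model = build_model(self.config)      # runs once per run_exp()

    def exec_exp(self, *args, **kwargs):
        raw_metrics = train_loop(self.model, self.dataset)
        return raw_metrics, {"ckpt": "last.pt"}    # (raw_metrics_data, extra_data)

    def exec_eval(self, *args, **kwargs):
        raw_metrics = evaluate(self.model, self.dataset)
        return raw_metrics, None


exp = MyExperiment(MyConfig.from_custom_yaml_file("exp.yaml"))
perf = exp.run_exp(should_calc_metrics=True)       # full pipeline, env prepared once
eval_perf = exp.eval_exp()                         # reuses the prepared environment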
halib/research/core/__init__.py: file without changes
halib/research/core/base_config.py
@@ -0,0 +1,144 @@
+ import os
+ from rich.pretty import pprint
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, TypeVar, Generic
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from dataclass_wizard import YAMLWizard
+
+
+ class NamedConfig(ABC):
+     """
+     Base class for named configurations.
+     All configurations should have a name.
+     """
+
+     @abstractmethod
+     def get_name(self):
+         """
+         Get the name of the configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+
+ @dataclass
+ class AutoNamedConfig(YAMLWizard, NamedConfig):
+     """
+     Mixin that automatically implements get_name() by returning self.name.
+     Classes using this MUST have a 'name' field.
+     """
+
+     name: Optional[str] = None
+
+     def get_name(self):
+         return self.name
+
+     def __post_init__(self):
+         # Enforce the "MUST" rule here
+         if self.name is None:
+             # We allow None during initial load, but it must be set before usage
+             # or handled by the loader.
+             pass
+
+ T = TypeVar("T", bound=AutoNamedConfig)
+
+ class BaseSelectorConfig(Generic[T]):
+     """
+     Base class to handle the logic of selecting an item from a list by name.
+     """
+
+     def _resolve_selection(self, items: List[T], selected_name: str, context: str) -> T:
+         if selected_name is None:
+             raise ValueError(f"No {context} selected in the configuration.")
+
+         # Create a lookup dict for O(1) access, or just iterate if list is short
+         for item in items:
+             if item.name == selected_name:
+                 return item
+
+         raise ValueError(
+             f"{context.capitalize()} '{selected_name}' not found in the configuration list."
+         )
+
+
+ class ExpBaseConfig(ABC, YAMLWizard):
+     """
+     Base class for configuration objects.
+     What a cfg class must have:
+     1 - a dataset cfg
+     2 - a metric cfg
+     3 - a method cfg
+     """
+
+     # Save to yaml fil
+     def save_to_outdir(
+         self, filename: str = "__config.yaml", outdir=None, override: bool = False
+     ) -> None:
+         """
+         Save the configuration to the output directory.
+         """
+         if outdir is not None:
+             output_dir = outdir
+         else:
+             output_dir = self.get_outdir()
+         os.makedirs(output_dir, exist_ok=True)
+         assert (output_dir is not None) and (
+             os.path.isdir(output_dir)
+         ), f"Output directory '{output_dir}' does not exist or is not a directory."
+         file_path = os.path.join(output_dir, filename)
+         if os.path.exists(file_path) and not override:
+             pprint(
+                 f"File '{file_path}' already exists. Use 'override=True' to overwrite."
+             )
+         else:
+             # method of YAMLWizard to_yaml_file
+             self.to_yaml_file(file_path)
+
+     @classmethod
+     @abstractmethod
+     # load from a custom YAML file
+     def from_custom_yaml_file(cls, yaml_file: str):
+         """Load a configuration from a custom YAML file."""
+         pass
+
+     @abstractmethod
+     def get_cfg_name(self):
+         """
+         Get the name of the configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def get_outdir(self):
+         """
+         Get the output directory for the configuration.
+         This method should be implemented in subclasses.
+         """
+         return None
+
+     @abstractmethod
+     def get_general_cfg(self):
+         """
+         Get the general configuration like output directory, log settings, SEED, etc.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def get_dataset_cfg(self) -> NamedConfig:
+         """
+         Get the dataset configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def get_metric_cfg(self) -> NamedConfig:
+         """
+         Get the metric configuration.
+         This method should be implemented in subclasses.
+         """
+         pass
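A rough usage sketch of these config base classes (the concrete classes and fields below are hypothetical, not part of halib): AutoNamedConfig gives a dataclass config a name plus YAML (de)serialization via YAMLWizard, and BaseSelectorConfig._resolve_selection picks one entry from a list by that name.

from dataclasses import dataclass, field
from typing import List, Optional

# Hypothetical concrete configs built on AutoNamedConfig / BaseSelectorConfig
# from halib/research/core/base_config.py above; field names are illustrative only.
@dataclass
class DatasetConfig(AutoNamedConfig):
    root_dir: str = "./data"
    batch_size: int = 32


@dataclass
class MyExpConfig(BaseSelectorConfig[DatasetConfig]):
    datasets: List[DatasetConfig] = field(default_factory=list)
    selected_dataset: Optional[str] = None

    def get_dataset_cfg(self) -> DatasetConfig:
        # raises ValueError if nothing is selected or the name is not in the list
        return self._resolve_selection(self.datasets, self.selected_dataset, "dataset")


cfg = MyExpConfig(
    datasets=[DatasetConfig(name="cifar10"), DatasetConfig(name="imagenet")],
    selected_dataset="cifar10",
)
print(cfg.get_dataset_cfg().batch_size)  # -> 32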
halib/research/core/base_exp.py
@@ -0,0 +1,157 @@
+ from abc import ABC, abstractmethod
+ from typing import Tuple, Any, Optional
+ from base_config import ExpBaseConfig
+ from ..perf.perfcalc import PerfCalc
+ from ..perf.perfmetrics import MetricsBackend
+
+ # ! SEE https://github.com/hahv/base_exp for sample usage
+ class BaseExperiment(PerfCalc, ABC):
+     """
+     Base class for experiments.
+     Orchestrates the experiment pipeline using a pluggable metrics backend.
+     """
+
+     def __init__(self, config: ExpBaseConfig):
+         self.config = config
+         self.metric_backend = None
+         # Flag to track if init_general/prepare_dataset has run
+         self._is_env_ready = False
+
+     # -----------------------
+     # PerfCalc Required Methods
+     # -----------------------
+     def get_dataset_name(self):
+         return self.config.get_dataset_cfg().get_name()
+
+     def get_experiment_name(self):
+         return self.config.get_cfg_name()
+
+     def get_metric_backend(self):
+         if not self.metric_backend:
+             self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
+         return self.metric_backend
+
+     # -----------------------
+     # Abstract Experiment Steps
+     # -----------------------
+     @abstractmethod
+     def init_general(self, general_cfg):
+         """Setup general settings like SEED, logging, env variables."""
+         pass
+
+     @abstractmethod
+     def prepare_dataset(self, dataset_cfg):
+         """Load/prepare dataset."""
+         pass
+
+     @abstractmethod
+     def prepare_metrics(self, metric_cfg) -> MetricsBackend:
+         """
+         Prepare the metrics for the experiment.
+         This method should be implemented in subclasses.
+         """
+         pass
+
+     @abstractmethod
+     def before_exec_exp_once(self, *args, **kwargs):
+         """Optional: any setup before exec_exp. Note this is called once per run_exp."""
+         pass
+
+     @abstractmethod
+     def exec_exp(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
+         """Run experiment process, e.g.: training/evaluation loop.
+         Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
+         """
+         pass
+
+     @abstractmethod
+     def exec_eval(self, *args, **kwargs) -> Optional[Tuple[Any, Any]]:
+         """Run evaluation process.
+         Return: either `None` or a tuple of (raw_metrics_data, extra_data) for calc_and_save_exp_perfs
+         """
+         pass
+
+     # -----------------------
+     # Internal Helpers
+     # -----------------------
+     def _validate_and_unpack(self, results):
+         if results is None:
+             return None
+         if not isinstance(results, (tuple, list)) or len(results) != 2:
+             raise ValueError("exec must return (metrics_data, extra_data)")
+         return results[0], results[1]
+
+     def _prepare_environment(self, force_reload: bool = False):
+         """
+         Common setup. Skips if already initialized, unless force_reload is True.
+         """
+         if self._is_env_ready and not force_reload:
+             # Environment is already prepared, skipping setup.
+             return
+
+         # 1. Run Setup
+         self.init_general(self.config.get_general_cfg())
+         self.prepare_dataset(self.config.get_dataset_cfg())
+
+         # 2. Update metric backend (refresh if needed)
+         self.metric_backend = self.prepare_metrics(self.config.get_metric_cfg())
+
+         # 3. Mark as ready
+         self._is_env_ready = True
+
+     # -----------------------
+     # Main Experiment Runner
+     # -----------------------
+     def run_exp(self, should_calc_metrics=True, reload_env=False, *args, **kwargs):
+         """
+         Run the whole experiment pipeline.
+         :param reload_env: If True, forces dataset/general init to run again.
+         :param should_calc_metrics: Whether to calculate and save metrics after execution.
+         :kwargs Params:
+             + 'outfile' to save csv file results,
+             + 'outdir' to set output directory for experiment results.
+             + 'return_df' to return a DataFrame of results instead of a dictionary.
+
+         Full pipeline:
+         1. Init
+         2. Prepare Environment (General + Dataset + Metrics)
+         3. Save Config
+         4. Execute
+         5. Calculate & Save Metrics
+         """
+         self._prepare_environment(force_reload=reload_env)
+
+         # Any pre-exec setup (loading models, etc)
+         self.before_exec_exp_once(*args, **kwargs)
+         # Save config before running
+         self.config.save_to_outdir()
+
+         # Execute experiment
+         results = self.exec_exp(*args, **kwargs)
+
+         if should_calc_metrics and results is not None:
+             metrics_data, extra_data = self._validate_and_unpack(results)
+             # Calculate & Save metrics
+             perf_results = self.calc_perfs(
+                 raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
+             )
+             return perf_results
+         else:
+             return results
+
+     # -----------------------
+     # Main Experiment Evaluator
+     # -----------------------
+     def eval_exp(self, reload_env=False, *args, **kwargs):
+         """
+         Run evaluation only.
+         :param reload_env: If True, forces dataset/general init to run again.
+         """
+         self._prepare_environment(force_reload=reload_env)
+         results = self.exec_eval(*args, **kwargs)
+         if results is not None:
+             metrics_data, extra_data = self._validate_and_unpack(results)
+             return self.calc_perfs(
+                 raw_metrics_data=metrics_data, extra_data=extra_data, *args, **kwargs
+             )
+         return None