halib-0.2.30-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. halib/__init__.py +94 -0
  2. halib/common/__init__.py +0 -0
  3. halib/common/common.py +326 -0
  4. halib/common/rich_color.py +285 -0
  5. halib/common.py +151 -0
  6. halib/csvfile.py +48 -0
  7. halib/cuda.py +39 -0
  8. halib/dataset.py +209 -0
  9. halib/exp/__init__.py +0 -0
  10. halib/exp/core/__init__.py +0 -0
  11. halib/exp/core/base_config.py +167 -0
  12. halib/exp/core/base_exp.py +147 -0
  13. halib/exp/core/param_gen.py +170 -0
  14. halib/exp/core/wandb_op.py +117 -0
  15. halib/exp/data/__init__.py +0 -0
  16. halib/exp/data/dataclass_util.py +41 -0
  17. halib/exp/data/dataset.py +208 -0
  18. halib/exp/data/torchloader.py +165 -0
  19. halib/exp/perf/__init__.py +0 -0
  20. halib/exp/perf/flop_calc.py +190 -0
  21. halib/exp/perf/gpu_mon.py +58 -0
  22. halib/exp/perf/perfcalc.py +470 -0
  23. halib/exp/perf/perfmetrics.py +137 -0
  24. halib/exp/perf/perftb.py +778 -0
  25. halib/exp/perf/profiler.py +507 -0
  26. halib/exp/viz/__init__.py +0 -0
  27. halib/exp/viz/plot.py +754 -0
  28. halib/filesys.py +117 -0
  29. halib/filetype/__init__.py +0 -0
  30. halib/filetype/csvfile.py +192 -0
  31. halib/filetype/ipynb.py +61 -0
  32. halib/filetype/jsonfile.py +19 -0
  33. halib/filetype/textfile.py +12 -0
  34. halib/filetype/videofile.py +266 -0
  35. halib/filetype/yamlfile.py +87 -0
  36. halib/gdrive.py +179 -0
  37. halib/gdrive_mkdir.py +41 -0
  38. halib/gdrive_test.py +37 -0
  39. halib/jsonfile.py +22 -0
  40. halib/listop.py +13 -0
  41. halib/online/__init__.py +0 -0
  42. halib/online/gdrive.py +229 -0
  43. halib/online/gdrive_mkdir.py +53 -0
  44. halib/online/gdrive_test.py +50 -0
  45. halib/online/projectmake.py +131 -0
  46. halib/online/tele_noti.py +165 -0
  47. halib/plot.py +301 -0
  48. halib/projectmake.py +115 -0
  49. halib/research/__init__.py +0 -0
  50. halib/research/base_config.py +100 -0
  51. halib/research/base_exp.py +157 -0
  52. halib/research/benchquery.py +131 -0
  53. halib/research/core/__init__.py +0 -0
  54. halib/research/core/base_config.py +144 -0
  55. halib/research/core/base_exp.py +157 -0
  56. halib/research/core/param_gen.py +108 -0
  57. halib/research/core/wandb_op.py +117 -0
  58. halib/research/data/__init__.py +0 -0
  59. halib/research/data/dataclass_util.py +41 -0
  60. halib/research/data/dataset.py +208 -0
  61. halib/research/data/torchloader.py +165 -0
  62. halib/research/dataset.py +208 -0
  63. halib/research/flop_csv.py +34 -0
  64. halib/research/flops.py +156 -0
  65. halib/research/metrics.py +137 -0
  66. halib/research/mics.py +74 -0
  67. halib/research/params_gen.py +108 -0
  68. halib/research/perf/__init__.py +0 -0
  69. halib/research/perf/flop_calc.py +190 -0
  70. halib/research/perf/gpu_mon.py +58 -0
  71. halib/research/perf/perfcalc.py +363 -0
  72. halib/research/perf/perfmetrics.py +137 -0
  73. halib/research/perf/perftb.py +778 -0
  74. halib/research/perf/profiler.py +301 -0
  75. halib/research/perfcalc.py +361 -0
  76. halib/research/perftb.py +780 -0
  77. halib/research/plot.py +758 -0
  78. halib/research/profiler.py +300 -0
  79. halib/research/torchloader.py +162 -0
  80. halib/research/viz/__init__.py +0 -0
  81. halib/research/viz/plot.py +754 -0
  82. halib/research/wandb_op.py +116 -0
  83. halib/rich_color.py +285 -0
  84. halib/sys/__init__.py +0 -0
  85. halib/sys/cmd.py +8 -0
  86. halib/sys/filesys.py +124 -0
  87. halib/system/__init__.py +0 -0
  88. halib/system/_list_pc.csv +6 -0
  89. halib/system/cmd.py +8 -0
  90. halib/system/filesys.py +164 -0
  91. halib/system/path.py +106 -0
  92. halib/tele_noti.py +166 -0
  93. halib/textfile.py +13 -0
  94. halib/torchloader.py +162 -0
  95. halib/utils/__init__.py +0 -0
  96. halib/utils/dataclass_util.py +40 -0
  97. halib/utils/dict.py +317 -0
  98. halib/utils/dict_op.py +9 -0
  99. halib/utils/gpu_mon.py +58 -0
  100. halib/utils/list.py +17 -0
  101. halib/utils/listop.py +13 -0
  102. halib/utils/slack.py +86 -0
  103. halib/utils/tele_noti.py +166 -0
  104. halib/utils/video.py +82 -0
  105. halib/videofile.py +139 -0
  106. halib-0.2.30.dist-info/METADATA +237 -0
  107. halib-0.2.30.dist-info/RECORD +110 -0
  108. halib-0.2.30.dist-info/WHEEL +5 -0
  109. halib-0.2.30.dist-info/licenses/LICENSE.txt +17 -0
  110. halib-0.2.30.dist-info/top_level.txt +1 -0
@@ -0,0 +1,301 @@
+ import os
+ import time
+ import json
+
+ from pathlib import Path
+ from pprint import pprint
+ from threading import Lock
+ from loguru import logger
+
+ from plotly.subplots import make_subplots
+ import plotly.graph_objects as go
+ import plotly.express as px  # for dynamic color scales
+
+ from ...common.common import ConsoleLog
+
+
+ class zProfiler:
+     """A singleton profiler to measure execution time of contexts and steps.
+
+     Args:
+         interval_report (int): Frequency of periodic reports (0 to disable).
+         stop_to_view (bool): Pause execution to view reports if True (only in debug mode).
+         output_file (str): Path to save the profiling report.
+         report_format (str): Output format for reports ("json" or "csv").
+
+     Example:
+         prof = zProfiler()
+         prof.ctx_start("my_context")
+         prof.step_start("my_context", "step1")
+         time.sleep(0.1)
+         prof.step_end("my_context", "step1")
+         prof.ctx_end("my_context")
+     """
+
+     _instance = None
+     _lock = Lock()
+
+     def __new__(cls, *args, **kwargs):
+         with cls._lock:
+             if cls._instance is None:
+                 cls._instance = super().__new__(cls)
+             return cls._instance
+
+     def __init__(
+         self,
+     ):
+         if not hasattr(self, "_initialized"):
+             self.time_dict = {}
+             self._initialized = True
+
+     def ctx_start(self, ctx_name="ctx_default"):
+         if not isinstance(ctx_name, str) or not ctx_name:
+             raise ValueError("ctx_name must be a non-empty string")
+         if ctx_name not in self.time_dict:
+             self.time_dict[ctx_name] = {
+                 "start": time.perf_counter(),
+                 "step_dict": {},
+                 "report_count": 0,
+             }
+         self.time_dict[ctx_name]["report_count"] += 1
+
+     def ctx_end(self, ctx_name="ctx_default", report_func=None):
+         if ctx_name not in self.time_dict:
+             return
+         self.time_dict[ctx_name]["end"] = time.perf_counter()
+         self.time_dict[ctx_name]["duration"] = (
+             self.time_dict[ctx_name]["end"] - self.time_dict[ctx_name]["start"]
+         )
+
+     def step_start(self, ctx_name, step_name):
+         if not isinstance(step_name, str) or not step_name:
+             raise ValueError("step_name must be a non-empty string")
+         if ctx_name not in self.time_dict:
+             return
+         if step_name not in self.time_dict[ctx_name]["step_dict"]:
+             self.time_dict[ctx_name]["step_dict"][step_name] = []
+         self.time_dict[ctx_name]["step_dict"][step_name].append([time.perf_counter()])
+
+     def step_end(self, ctx_name, step_name):
+         if (
+             ctx_name not in self.time_dict
+             or step_name not in self.time_dict[ctx_name]["step_dict"]
+         ):
+             return
+         self.time_dict[ctx_name]["step_dict"][step_name][-1].append(time.perf_counter())
+
+     def _step_dict_to_detail(self, ctx_step_dict):
+         """
+         Example 'ctx_step_dict':
+             {
+                 'preprocess': [
+                     [278090.947465806, 278090.960484853],
+                     [278091.178424035, 278091.230944486],
+                 ],
+                 'infer': [
+                     [278090.960490534, 278091.178424035],
+                     [278091.230944486, 278091.251378469],
+                 ],
+             }
+         """
+         assert (
+             len(ctx_step_dict.keys()) > 0
+         ), "ctx_step_dict must contain at least one step for detail."
+         normed_ctx_step_dict = {}
+         for step_name, time_list in ctx_step_dict.items():
+             if not isinstance(ctx_step_dict[step_name], list):
+                 raise ValueError(f"Step data for {step_name} must be a list")
+             # step_name = list(ctx_step_dict.keys())[0]  # ! debug
+             normed_time_ls = []
+             for idx, time_data in enumerate(time_list):
+                 elapsed_time = -1
+                 if len(time_data) == 2:
+                     start, end = time_data[0], time_data[1]
+                     elapsed_time = end - start
+                 normed_time_ls.append((idx, elapsed_time))  # (step index, elapsed seconds)
+             normed_ctx_step_dict[step_name] = normed_time_ls
+         return normed_ctx_step_dict
+
+     def get_report_dict(self, with_detail=False):
+         report_dict = {}
+         for ctx_name, ctx_dict in self.time_dict.items():
+             report_dict[ctx_name] = {
+                 "duration": ctx_dict.get("duration", 0.0),
+                 "step_dict": {
+                     "summary": {"avg_time": {}, "percent_time": {}},
+                     "detail": {},
+                 },
+             }
+
+             if with_detail:
+                 report_dict[ctx_name]["step_dict"]["detail"] = (
+                     self._step_dict_to_detail(ctx_dict["step_dict"])
+                 )
+             avg_time_list = []
+             epsilon = 1e-5
+             for step_name, step_list in ctx_dict["step_dict"].items():
+                 durations = []
+                 try:
+                     for time_data in step_list:
+                         if len(time_data) != 2:
+                             continue
+                         start, end = time_data
+                         durations.append(end - start)
+                 except Exception as e:
+                     logger.error(
+                         f"Error processing step {step_name} in context {ctx_name}: {e}"
+                     )
+                     continue
+                 if not durations:
+                     continue
+                 avg_time = sum(durations) / len(durations)
+                 if avg_time < epsilon:
+                     continue
+                 avg_time_list.append((step_name, avg_time))
+             total_avg_time = (
+                 sum(time for _, time in avg_time_list) or 1e-10
+             )  # Avoid division by zero
+             for step_name, avg_time in avg_time_list:
+                 report_dict[ctx_name]["step_dict"]["summary"]["percent_time"][
+                     f"per_{step_name}"
+                 ] = (avg_time / total_avg_time) * 100.0
+                 report_dict[ctx_name]["step_dict"]["summary"]["avg_time"][
+                     f"avg_{step_name}"
+                 ] = avg_time
+             report_dict[ctx_name]["step_dict"]["summary"][
+                 "total_avg_time"
+             ] = total_avg_time
+             report_dict[ctx_name]["step_dict"]["summary"] = dict(
+                 sorted(report_dict[ctx_name]["step_dict"]["summary"].items())
+             )
+         return report_dict
+
+     @classmethod
+     def plot_formatted_data(
+         cls, profiler_data, outdir=None, file_format="png", do_show=False, tag=""
+     ):
+         """
+         Plot each context in a separate figure with bar + pie charts.
+         Save each figure in the specified format (png or svg).
+         """
+
+         if outdir is not None:
+             os.makedirs(outdir, exist_ok=True)
+
+         if file_format.lower() not in ["png", "svg"]:
+             raise ValueError("file_format must be 'png' or 'svg'")
+
+         results = {}  # {context: fig}
+
+         for ctx, ctx_data in profiler_data.items():
+             summary = ctx_data["step_dict"]["summary"]
+             avg_times = summary["avg_time"]
+             percent_times = summary["percent_time"]
+
+             step_names = [s.replace("avg_", "") for s in avg_times.keys()]
+             # pprint(f'{step_names=}')
+             n_steps = len(step_names)
+
+             assert n_steps > 0, "No steps found for context: {}".format(ctx)
+             # Generate dynamic colors
+             colors = (
+                 px.colors.sample_colorscale(
+                     "Viridis", [i / (n_steps - 1) for i in range(n_steps)]
+                 )
+                 if n_steps > 1
+                 else [px.colors.sample_colorscale("Viridis", [0])[0]]
+             )
+             # pprint(f'{len(colors)} colors generated for {n_steps} steps')
+             color_map = dict(zip(step_names, colors))
+
+             # Create figure
+             fig = make_subplots(
+                 rows=1,
+                 cols=2,
+                 subplot_titles=["Avg Time", "% Time"],
+                 specs=[[{"type": "bar"}, {"type": "pie"}]],
+             )
+
+             # Bar chart
+             fig.add_trace(
+                 go.Bar(
+                     x=step_names,
+                     y=list(avg_times.values()),
+                     text=[f"{v*1000:.2f} ms" for v in avg_times.values()],
+                     textposition="outside",
+                     marker=dict(color=[color_map[s] for s in step_names]),
+                     name="",  # unified legend
+                     showlegend=False,
+                 ),
+                 row=1,
+                 col=1,
+             )
+
+             # Pie chart (colors match bar)
+             fig.add_trace(
+                 go.Pie(
+                     labels=step_names,
+                     values=list(percent_times.values()),
+                     marker=dict(colors=[color_map[s] for s in step_names]),
+                     hole=0.4,
+                     name="",
+                     showlegend=True,
+                 ),
+                 row=1,
+                 col=2,
+             )
+             tag_str = tag if tag and len(tag) > 0 else ""
+             # Layout
+             fig.update_layout(
+                 title_text=f"[{tag_str}] Context Profiler: {ctx}",
+                 width=1000,
+                 height=400,
+                 showlegend=True,
+                 legend=dict(title="Steps", x=1.05, y=0.5, traceorder="normal"),
+                 hovermode="x unified",
+             )
+
+             fig.update_xaxes(title_text="Steps", row=1, col=1)
+             fig.update_yaxes(title_text="Avg Time (ms)", row=1, col=1)
+
+             # Show figure
+             if do_show:
+                 fig.show()
+
+             # Save figure
+             if outdir is not None:
+                 file_prefix = ctx if len(tag_str) == 0 else f"{tag_str}_{ctx}"
+                 file_path = os.path.join(
+                     outdir, f"{file_prefix}_summary.{file_format.lower()}"
+                 )
+                 fig.write_image(file_path)
+                 print(f"Saved figure: {file_path}")
+
+             results[ctx] = fig
+
+         return results
+
+     def report_and_plot(self, outdir=None, file_format="png", do_show=False, tag=""):
+         """
+         Generate the profiling report and plot the formatted data.
+
+         Args:
+             outdir (str): Directory to save figures. If None, figures are not saved.
+             file_format (str): Target file format, "png" or "svg". Default is "png".
+             do_show (bool): Whether to display the plots. Default is False.
+         """
+         report = self.get_report_dict(with_detail=False)
+         return self.plot_formatted_data(
+             report, outdir=outdir, file_format=file_format, do_show=do_show, tag=tag
+         )
+
+     def meta_info(self):
+         """
+         Print the structure of the profiler's time dictionary.
+         Useful for debugging and understanding the profiler's internal state.
+         """
+         for ctx_name, ctx_dict in self.time_dict.items():
+             with ConsoleLog(f"Context: {ctx_name}"):
+                 step_names = list(ctx_dict["step_dict"].keys())
+                 for step_name in step_names:
+                     pprint(f"Step: {step_name}")
+
+     def save_report_dict(self, output_file, with_detail=False):
+         try:
+             report = self.get_report_dict(with_detail=with_detail)
+             with open(output_file, "w") as f:
+                 json.dump(report, f, indent=4)
+         except Exception as e:
+             logger.error(f"Failed to save report to {output_file}: {e}")
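Note: judging by its +301 line count in the file list above, this hunk appears to correspond to halib/research/perf/profiler.py (an assumption, since the diff header omits the file path). A minimal, hypothetical usage sketch of the zProfiler API shown above; the context name, step names, sleeps, and output paths are illustrative only, and writing figures via fig.write_image typically requires the kaleido package:

# Hypothetical usage sketch; names, sleeps, and paths are illustrative, not part of the package.
import time

from halib.research.perf.profiler import zProfiler

prof = zProfiler()            # singleton: later calls return the same instance
prof.ctx_start("inference")   # open a timing context
for _ in range(3):
    prof.step_start("inference", "preprocess")
    time.sleep(0.01)          # stand-in for real work
    prof.step_end("inference", "preprocess")
    prof.step_start("inference", "infer")
    time.sleep(0.02)
    prof.step_end("inference", "infer")
prof.ctx_end("inference")     # record the total duration of the context

prof.save_report_dict("profile_report.json")        # per-step avg/percent summary as JSON
figs = prof.report_and_plot(outdir="profile_figs")  # one bar + pie figure per context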
@@ -0,0 +1,361 @@
+ import os
+ import glob
+ from typing import List, Optional, Tuple, Union
+ import pandas as pd
+
+ from abc import ABC, abstractmethod
+ from collections import OrderedDict
+
+ from ..system import filesys as fs
+ from ..common import now_str
+ from ..research.perftb import PerfTB
+ from ..research.metrics import *
+
+
+ REQUIRED_COLS = ["experiment", "dataset"]
+ CSV_FILE_POSTFIX = "__perf"
+ METRIC_PREFIX = "metric_"
+
+
+ class PerfCalc(ABC):  # Abstract base class for performance calculation
+     @abstractmethod
+     def get_experiment_name(self) -> str:
+         """
+         Return the name of the experiment.
+         This function should be overridden by the subclass if needed.
+         """
+         pass
+
+     @abstractmethod
+     def get_dataset_name(self) -> str:
+         """
+         Return the name of the dataset.
+         This function should be overridden by the subclass if needed.
+         """
+         pass
+
+     @abstractmethod
+     def get_metric_backend(self) -> MetricsBackend:
+         """
+         Return the metrics backend used for performance calculation. The backend is
+         built either from a list of metric names or from a dictionary mapping metric
+         names to torchmetrics.Metric instances, e.g. {"accuracy": Accuracy(), "precision": Precision()}.
+         """
+         pass
+
+     def valid_proc_extra_data(self, proc_extra_data):
+         # make sure that all items in proc_extra_data are dictionaries with the same keys
+         if proc_extra_data is None or len(proc_extra_data) == 0:
+             return
+         if not all(isinstance(item, dict) for item in proc_extra_data):
+             raise TypeError("All items in proc_extra_data must be dictionaries")
+
+         if not all(
+             item.keys() == proc_extra_data[0].keys() for item in proc_extra_data
+         ):
+             raise ValueError(
+                 "All dictionaries in proc_extra_data must have the same keys"
+             )
+
+     def valid_proc_metric_raw_data(self, metric_names, proc_metric_raw_data):
+         # make sure that all items in proc_metric_raw_data are dictionaries with the same keys as metric_names
+         assert (
+             isinstance(proc_metric_raw_data, list) and len(proc_metric_raw_data) > 0
+         ), "raw_data_for_metrics must be a non-empty list of dictionaries"
+
+         if not all(isinstance(item, dict) for item in proc_metric_raw_data):
+             raise TypeError("All items in raw_data_for_metrics must be dictionaries")
+         if not all(
+             set(item.keys()) == set(metric_names) for item in proc_metric_raw_data
+         ):
+             raise ValueError(
+                 "All dictionaries in raw_data_for_metrics must have the same keys as metric_names"
+             )
+
+     # ! only need to override this method if torchmetrics are not used
+     def calc_exp_perf_metrics(
+         self, metric_names, raw_metrics_data, extra_data=None, *args, **kwargs
+     ):
+         assert isinstance(raw_metrics_data, dict) or isinstance(
+             raw_metrics_data, list
+         ), "raw_data_for_metrics must be a dictionary or a list"
+
+         if extra_data is not None:
+             assert isinstance(
+                 extra_data, type(raw_metrics_data)
+             ), "extra_data must be of the same type as raw_data_for_metrics (dict or list)"
+         # prepare raw metric data for processing
+         proc_metric_raw_data_ls = (
+             raw_metrics_data
+             if isinstance(raw_metrics_data, list)
+             else [raw_metrics_data.copy()]
+         )
+         self.valid_proc_metric_raw_data(metric_names, proc_metric_raw_data_ls)
+         # prepare extra data for processing
+         proc_extra_data_ls = []
+         if extra_data is not None:
+             proc_extra_data_ls = (
+                 extra_data if isinstance(extra_data, list) else [extra_data.copy()]
+             )
+             assert len(proc_extra_data_ls) == len(
+                 proc_metric_raw_data_ls
+             ), "extra_data must have the same length as raw_data_for_metrics if it is a list"
+             # validate the extra_data
+             self.valid_proc_extra_data(proc_extra_data_ls)
+
+         # calculate the metrics output results
+         metrics_backend = self.get_metric_backend()
+         proc_outdict_list = []
+         for idx, raw_metrics_data in enumerate(proc_metric_raw_data_ls):
+             out_dict = {
+                 "dataset": self.get_dataset_name(),
+                 "experiment": self.get_experiment_name(),
+             }
+             custom_fields = []
+             if len(proc_extra_data_ls) > 0:
+                 # add extra data to the output dictionary
+                 extra_data_item = proc_extra_data_ls[idx]
+                 out_dict.update(extra_data_item)
+                 custom_fields = list(extra_data_item.keys())
+             metric_results = metrics_backend.calc_metrics(
+                 metrics_data_dict=raw_metrics_data, *args, **kwargs
+             )
+             metric_results_prefix = {
+                 f"metric_{k}": v for k, v in metric_results.items()
+             }
+             out_dict.update(metric_results_prefix)
+             ordered_cols = (
+                 REQUIRED_COLS + custom_fields + list(metric_results_prefix.keys())
+             )
+             out_dict = OrderedDict(
+                 (col, out_dict[col]) for col in ordered_cols if col in out_dict
+             )
+             proc_outdict_list.append(out_dict)
+
+         return proc_outdict_list
+
+     #! custom kwargs:
+     #!   outfile   - if provided, save the output to a CSV file at the given path
+     #!   outdir    - if provided, save the output to a CSV file in that directory with a generated filename
+     #!   return_df - if True, return a DataFrame instead of a list of dictionaries
+     def calc_perfs(
+         self,
+         raw_metrics_data: Union[List[dict], dict],
+         extra_data: Optional[Union[List[dict], dict]] = None,
+         *args,
+         **kwargs,
+     ) -> Tuple[Union[List[OrderedDict], pd.DataFrame], Optional[str]]:
+         """
+         Calculate performance metrics for one or more runs.
+         Returns a tuple (results, csv_outfile): results is a list of OrderedDicts
+         (or a DataFrame if return_df=True), and csv_outfile is the path of the saved
+         CSV file, or None if no output file was requested.
+         """
+         metric_names = self.get_metric_backend().metric_names
+         out_dict_list = self.calc_exp_perf_metrics(
+             metric_names=metric_names,
+             raw_metrics_data=raw_metrics_data,
+             extra_data=extra_data,
+             *args,
+             **kwargs,
+         )
+         csv_outfile = kwargs.get("outfile", None)
+         if csv_outfile is not None:
+             filePathNoExt, _ = os.path.splitext(csv_outfile)
+             # pprint(f"CSV Outfile Path (No Ext): {filePathNoExt}")
+             csv_outfile = f"{filePathNoExt}{CSV_FILE_POSTFIX}.csv"
+         elif "outdir" in kwargs:
+             csvoutdir = kwargs["outdir"]
+             csvfilename = f"{now_str()}_{self.get_dataset_name()}_{self.get_experiment_name()}_{CSV_FILE_POSTFIX}.csv"
+             csv_outfile = os.path.join(csvoutdir, csvfilename)
+
+         # convert out_dict_list to a DataFrame
+         df = pd.DataFrame(out_dict_list)
+         # keep the column order of the keys in out_dict_list
+         ordered_cols = list(out_dict_list[0].keys())
+         df = df[ordered_cols]  # reorder columns
+         if csv_outfile:
+             df.to_csv(csv_outfile, index=False, sep=";", encoding="utf-8")
+         return_df = kwargs.get("return_df", False)
+         if return_df:  # return DataFrame instead of dict if requested
+             return df, csv_outfile
+         else:
+             return out_dict_list, csv_outfile
+
+     @staticmethod
+     def default_exp_csv_filter_fn(exp_file_name: str) -> bool:
+         """
+         Default filter function for experiment CSV files.
+         Returns True if the file name contains the performance CSV postfix ("__perf.csv").
+         """
+         return "__perf.csv" in exp_file_name
+
+     @classmethod
+     def get_perftb_for_multi_exps(
+         cls,
+         indir: str,
+         exp_csv_filter_fn=default_exp_csv_filter_fn,
+         include_file_name=False,
+         csv_sep=";",
+     ) -> PerfTB:
+         """
+         Generate a performance report by scanning experiment subdirectories of `indir`
+         (or `indir` itself if it has no subdirectories) for performance CSV files,
+         and return them merged into a single PerfTB performance table.
+         """
+
+         def get_df_for_all_exp_perf(csv_perf_files, csv_sep=";"):
+             """
+             Create a single DataFrame from all CSV files.
+             Assumes the CSV files MAY have different metric columns.
+             """
+             cols = []
+             FILE_NAME_COL = "file_name" if include_file_name else None
+
+             for csv_file in csv_perf_files:
+                 temp_df = pd.read_csv(csv_file, sep=csv_sep)
+                 if FILE_NAME_COL:
+                     temp_df[FILE_NAME_COL] = fs.get_file_name(
+                         csv_file, split_file_ext=False
+                     )
+                 # csvfile.fn_display_df(temp_df)
+                 temp_df_cols = temp_df.columns.tolist()
+                 for col in temp_df_cols:
+                     if col not in cols:
+                         cols.append(col)
+
+             df = pd.DataFrame(columns=cols)
+             for csv_file in csv_perf_files:
+                 temp_df = pd.read_csv(csv_file, sep=csv_sep)
+                 if FILE_NAME_COL:
+                     temp_df[FILE_NAME_COL] = fs.get_file_name(
+                         csv_file, split_file_ext=False
+                     )
+                 # Drop all-NA columns to avoid dtype inconsistency
+                 temp_df = temp_df.dropna(axis=1, how="all")
+                 # ensure all columns are present in the final DataFrame
+                 for col in cols:
+                     if col not in temp_df.columns:
+                         temp_df[col] = None  # fill missing columns with None
+                 df = pd.concat([df, temp_df], ignore_index=True)
+             # assert that the sticky columns are present in the DataFrame
+             # pprint(df.columns.tolist())
+             sticky_cols = REQUIRED_COLS + (
+                 [FILE_NAME_COL] if include_file_name else []
+             )  # columns that must always be present
+             for col in sticky_cols:
+                 if col not in df.columns:
+                     raise ValueError(
+                         f"Required column '{col}' is missing from the DataFrame. REQUIRED_COLS = {sticky_cols}"
+                     )
+             metric_cols = [col for col in df.columns if col.startswith(METRIC_PREFIX)]
+             assert (
+                 len(metric_cols) > 0
+             ), "No metric columns found in the DataFrame. Ensure that the CSV files contain metric columns starting with 'metric_'."
+             final_cols = sticky_cols + metric_cols
+             df = df[final_cols]
+             # # !hahv debug
+             # pprint("------ Final DataFrame Columns ------")
+             # csvfile.fn_display_df(df)
+             # ! validate all rows in df before returning:
+             # every row must have values for the sticky columns and at least one metric column
+             for index, row in df.iterrows():
+                 if not all(col in row and pd.notna(row[col]) for col in sticky_cols):
+                     raise ValueError(
+                         f"Row {index} is missing required columns or has NaN values in required columns: {row}"
+                     )
+                 if not any(pd.notna(row[col]) for col in metric_cols):
+                     raise ValueError(f"Row {index} has no metric values: {row}")
+             # make sure there is no duplicated (experiment, dataset) pair
+             duplicates = df.duplicated(subset=sticky_cols, keep=False)
+             if duplicates.any():
+                 raise ValueError(
+                     "Duplicate (experiment, dataset) pairs found in the DataFrame. Please ensure that each experiment-dataset combination is unique."
+                 )
+             return df
+
+         def mk_perftb_report(df):
+             """
+             Create a performance report table from the DataFrame.
+             This function should be customized based on the specific requirements of the report.
+             """
+             perftb = PerfTB()
+             # find all "dataset" values (unique)
+             dataset_names = list(df["dataset"].unique())
+             # find all columns that start with METRIC_PREFIX
+             metric_cols = [col for col in df.columns if col.startswith(METRIC_PREFIX)]
+
+             # Determine which metrics are associated with each dataset.
+             # A dataset may appear in multiple rows, and a row may not include all metrics,
+             # so use the row with the most non-NaN metric values to define the metric set
+             # for that dataset.
+             dataset_metrics = {}
+             for dataset_name in dataset_names:
+                 dataset_rows = df[df["dataset"] == dataset_name]
+                 # Find the row with the most non-NaN metric values
+                 max_non_nan_row = dataset_rows[metric_cols].count(axis=1).idxmax()
+                 metrics_for_dataset = (
+                     dataset_rows.loc[max_non_nan_row, metric_cols]
+                     .dropna()
+                     .index.tolist()
+                 )
+                 dataset_metrics[dataset_name] = metrics_for_dataset
+
+             for dataset_name, metrics in dataset_metrics.items():
+                 # Register the dataset and its metric columns in the performance table
+                 perftb.add_dataset(dataset_name, metrics)
+
+             for _, row in df.iterrows():
+                 dataset_name = row["dataset"]
+                 ds_metrics = dataset_metrics.get(dataset_name)
+                 if dataset_name in dataset_metrics:
+                     # Add the metrics for this row to the performance table
+                     exp_name = row.get("experiment")
+                     exp_metric_values = {}
+                     for metric in ds_metrics:
+                         if metric in row and pd.notna(row[metric]):
+                             exp_metric_values[metric] = row[metric]
+                     perftb.add_experiment(
+                         experiment_name=exp_name,
+                         dataset_name=dataset_name,
+                         metrics=exp_metric_values,
+                     )
+
+             return perftb
+
+ assert os.path.exists(indir), f"Input directory {indir} does not exist."
324
+
325
+ csv_perf_files = []
326
+ # Find experiment subdirectories
327
+ exp_dirs = [
328
+ os.path.join(indir, d)
329
+ for d in os.listdir(indir)
330
+ if os.path.isdir(os.path.join(indir, d))
331
+ ]
332
+ if len(exp_dirs) == 0:
333
+ csv_perf_files = glob.glob(os.path.join(indir, f"*.csv"))
334
+ csv_perf_files = [
335
+ file_item
336
+ for file_item in csv_perf_files
337
+ if exp_csv_filter_fn(file_item)
338
+ ]
339
+ else:
340
+ # multiple experiment directories found
341
+ # Collect all matching CSV files in those subdirs
342
+ for exp_dir in exp_dirs:
343
+ # pprint(f"Searching in experiment directory: {exp_dir}")
344
+ matched = glob.glob(os.path.join(exp_dir, f"*.csv"))
345
+ matched = [
346
+ file_item for file_item in matched if exp_csv_filter_fn(file_item)
347
+ ]
348
+ csv_perf_files.extend(matched)
349
+
350
+ assert (
351
+ len(csv_perf_files) > 0
352
+ ), f"No CSV files matching pattern '{exp_csv_filter_fn}' found in the experiment directories."
353
+
354
+ assert (
355
+ len(csv_perf_files) > 0
356
+ ), f"No CSV files matching pattern '{exp_csv_filter_fn}' found in the experiment directories."
357
+
358
+ all_exp_perf_df = get_df_for_all_exp_perf(csv_perf_files, csv_sep=csv_sep)
359
+ # csvfile.fn_display_df(all_exp_perf_df)
360
+ perf_tb = mk_perftb_report(all_exp_perf_df)
361
+ return perf_tb
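Note: judging by its +361 line count in the file list above, this hunk appears to be halib/research/perfcalc.py (an assumption; the diff header omits the path). A minimal, hypothetical sketch of how PerfCalc is meant to be subclassed, based only on the abstract methods above; the class name, metric names, data, and paths are illustrative, and the MetricsBackend instance is assumed to be constructed elsewhere since its constructor is not part of this diff:

# Hypothetical subclass sketch; names and data are illustrative, not part of the package.
from halib.research.perfcalc import PerfCalc


class MyExpPerf(PerfCalc):
    def __init__(self, backend):
        # a halib MetricsBackend built by the caller (constructor not shown in this diff)
        self.backend = backend

    def get_experiment_name(self) -> str:
        return "exp_baseline"

    def get_dataset_name(self) -> str:
        return "cifar10"

    def get_metric_backend(self):
        return self.backend


# calc = MyExpPerf(backend=my_metrics_backend)
# # keys of raw_metrics_data must match backend.metric_names; the value format is
# # whatever backend.calc_metrics(metrics_data_dict=...) expects
# rows, csv_path = calc.calc_perfs(
#     raw_metrics_data={"accuracy": {"preds": preds, "target": targets}},
#     extra_data={"epoch": 10},
#     outdir="results",  # writes a *__perf.csv file into results/
# )
# # later, merge every *__perf.csv under results/ into a single PerfTB table:
# perf_tb = PerfCalc.get_perftb_for_multi_exps(indir="results")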