halib 0.1.65 (py3-none-any.whl) → 0.1.67 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
halib/research/metrics.py (new file)
@@ -0,0 +1,134 @@
1
+ # -------------------------------
2
+ # Metrics Backend Interface
3
+ # -------------------------------
4
+ import inspect
5
+ from typing import Dict, Union, List, Any
6
+ from abc import ABC, abstractmethod
7
+
8
+ class MetricsBackend(ABC):
9
+ """Interface for pluggable metrics computation backends."""
10
+
11
+ def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
12
+ """
13
+ Initialize the backend with metrics_info: a list of metric names, or a dict mapping metric names to metric objects.
14
+ """
15
+ self.metric_info = metrics_info
16
+ self.validate_metrics_info(self.metric_info)
17
+
18
+ @property
19
+ def metric_names(self) -> List[str]:
20
+ """
21
+ Return a list of metric names.
22
+ If metric_info is a dict, return its keys; if it's a list, return it directly.
23
+ """
24
+ if isinstance(self.metric_info, dict):
25
+ return list(self.metric_info.keys())
26
+ elif isinstance(self.metric_info, list):
27
+ return self.metric_info
28
+ else:
29
+ raise TypeError("metric_info must be a list or a dict")
30
+
31
+ def validate_metrics_info(self, metrics_info):
32
+ if isinstance(metrics_info, list):
33
+ return metrics_info
34
+ elif isinstance(metrics_info, dict):
35
+ return {k: v for k, v in metrics_info.items() if isinstance(k, str)}
36
+ else:
37
+ raise TypeError(
38
+ "metrics_info must be a list of strings or a dict with string keys"
39
+ )
40
+
41
+ @abstractmethod
42
+ def compute_metrics(
43
+ self, metrics_info: Union[List[str], Dict[str, Any]], metrics_data_dict: Dict[str, Any], *args, **kwargs
44
+ ) -> Dict[str, Any]:
45
+ pass
46
+
47
+ def prepare_metrics_backend_data(
48
+ self, raw_metric_data, *args, **kwargs
49
+ ):
50
+ """
51
+ Prepare the data for the metrics backend.
52
+ This method can be overridden by subclasses to customize data preparation.
53
+ """
54
+ return raw_metric_data
55
+
56
+ def calc_metrics(
57
+ self, metrics_data_dict: Dict[str, Any], *args, **kwargs
58
+ ) -> Dict[str, Any]:
59
+ """
60
+ Calculate metrics based on the provided metrics_info and data.
61
+ It validates the inputs, prepares them via prepare_metrics_backend_data(), and delegates the computation to compute_metrics(); subclasses override those methods rather than this one.
62
+ """
63
+ # prevalidate the metrics_data_dict
64
+ for metric in self.metric_names:
65
+ if metric not in metrics_data_dict:
66
+ raise ValueError(f"Metric '{metric}' not found in provided data.")
67
+ # Prepare the data for the backend
68
+ metrics_data_dict = self.prepare_metrics_backend_data(
69
+ metrics_data_dict, *args, **kwargs
70
+ )
71
+ # Call the abstract method to compute metrics
72
+ return self.compute_metrics(self.metric_info, metrics_data_dict, *args, **kwargs)
73
+
74
+ class TorchMetricsBackend(MetricsBackend):
75
+ """TorchMetrics-based backend implementation."""
76
+
77
+ def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
78
+ try:
79
+ import torch
80
+ from torchmetrics import Metric
81
+ except ImportError:
82
+ raise ImportError(
83
+ "TorchMetricsBackend requires torch and torchmetrics to be installed."
84
+ )
85
+ self.metric_info = metrics_info
86
+ self.torch = torch
87
+ self.Metric = Metric
88
+ self.validate_metrics_info(metrics_info)
89
+
90
+ def validate_metrics_info(self, metrics_info):
91
+ if not isinstance(metrics_info, dict):
92
+ raise TypeError(
93
+ "TorchMetricsBackend requires metrics_info as a dict {name: MetricInstance}"
94
+ )
95
+ for k, v in metrics_info.items():
96
+ if not isinstance(k, str):
97
+ raise TypeError(f"Key '{k}' is not a string")
98
+ if not isinstance(v, self.Metric):
99
+ raise TypeError(f"Value for key '{k}' must be a torchmetrics.Metric")
100
+ return metrics_info
101
+
102
+ def compute_metrics(self, metrics_info, metrics_data_dict, *args, **kwargs):
103
+ out_dict = {}
104
+ for metric, metric_instance in metrics_info.items():
105
+ if metric not in metrics_data_dict:
106
+ raise ValueError(f"Metric '{metric}' not found in provided data.")
107
+
108
+ metric_data = metrics_data_dict[metric]
109
+ sig = inspect.signature(metric_instance.update)
110
+ expected_args = list(sig.parameters.values())
111
+
112
+ if isinstance(metric_data, dict):
113
+ args = [metric_data[param.name] for param in expected_args]
114
+ elif isinstance(metric_data, (list, tuple)):
115
+ args = metric_data
116
+ else:
117
+ raise TypeError(f"Unsupported data format for metric '{metric}'")
118
+
119
+ if len(expected_args) == 1:
120
+ metric_instance.update(args)
121
+ else:
122
+ metric_instance.update(*args)
123
+
124
+ computed_value = metric_instance.compute()
125
+ if isinstance(computed_value, self.torch.Tensor):
126
+ computed_value = (
127
+ computed_value.item()
128
+ if computed_value.numel() == 1
129
+ else computed_value.tolist()
130
+ )
131
+
132
+
133
+ out_dict[metric] = computed_value
134
+ return out_dict
halib/research/perfcalc.py
@@ -1,71 +1,27 @@
1
1
  import os
2
2
  import glob
3
- import inspect
3
+ from typing import Optional, Tuple
4
4
  import pandas as pd
5
5
 
6
- from typing import Dict
7
- from functools import wraps
8
6
  from rich.pretty import pprint
9
7
 
10
8
  from abc import ABC, abstractmethod
9
+ from collections import OrderedDict
11
10
 
12
11
  from ..filetype import csvfile
12
+ from ..system import filesys as fs
13
13
  from ..common import now_str
14
14
  from ..research.perftb import PerfTB
15
- from collections import OrderedDict
15
+ from ..research.metrics import *
16
16
 
17
- # try to import torch, and torchmetrics
18
- try:
19
- import torch
20
- import torchmetrics
21
- from torchmetrics import Metric
22
- except ImportError:
23
- raise ImportError("Please install torch and torchmetrics to use this module.")
24
-
25
- def validate_torch_metrics(fn):
26
- @wraps(fn)
27
- def wrapper(self, *args, **kwargs):
28
- result = fn(self, *args, **kwargs)
29
-
30
- if not isinstance(result, dict):
31
- raise TypeError("torch_metrics() must return a dictionary")
32
-
33
- for k, v in result.items():
34
- if not isinstance(k, str):
35
- raise TypeError(f"Key '{k}' is not a string")
36
- if not isinstance(v, Metric):
37
- raise TypeError(
38
- f"Value for key '{k}' is not a torchmetrics.Metric (got {type(v).__name__})"
39
- )
40
-
41
- return result
42
-
43
- return wrapper
44
- def valid_custom_fields(fn):
45
- @wraps(fn)
46
- def wrapper(self, *args, **kwargs):
47
- rs = fn(self, *args, **kwargs)
48
- if not isinstance(rs, tuple) or len(rs) != 2:
49
- raise ValueError("Function must return a tuple (outdict, custom_fields)")
50
- outdict, custom_fields = rs
51
- if not isinstance(outdict, dict):
52
- raise TypeError("Output must be a dictionary")
53
- if not isinstance(custom_fields, list):
54
- raise TypeError("Custom fields must be a list")
55
- for field in custom_fields:
56
- if not isinstance(field, str):
57
- raise TypeError(f"Custom field '{field}' is not a string")
58
- return outdict, custom_fields
59
-
60
- return wrapper
61
17
 
62
18
  REQUIRED_COLS = ["experiment", "dataset"]
63
19
  CSV_FILE_POSTFIX = "__perf"
20
+ METRIC_PREFIX = "metric_"
64
21
 
65
- class PerfCalc(ABC): # Abstract base class for performance calculation
66
-
22
+ class PerfCalc(ABC): # Abstract base class for performance calculation
67
23
  @abstractmethod
68
- def get_experiment_name(self):
24
+ def get_experiment_name(self) -> str:
69
25
  """
70
26
  Return the name of the experiment.
71
27
  This method must be implemented by the subclass.
@@ -73,7 +29,7 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
73
29
  pass
74
30
 
75
31
  @abstractmethod
76
- def get_dataset_name(self):
32
+ def get_dataset_name(self) -> str:
77
33
  """
78
34
  Return the name of the dataset.
79
35
  This method must be implemented by the subclass.
@@ -81,161 +37,135 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
81
37
  pass
82
38
 
83
39
  @abstractmethod
84
- def get_metrics_info(self):
40
+ def get_metric_backend(self) -> MetricsBackend:
85
41
  """
86
42
  Return the MetricsBackend instance used for performance calculation, e.g. a TorchMetricsBackend built from a dictionary mapping metric names to torchmetrics.Metric instances such as {"accuracy": Accuracy(), "precision": Precision()}.
87
43
 
88
44
  """
89
45
  pass
90
46
 
91
- def calc_exp_outdict_custom_fields(self, outdict, *args, **kwargs):
92
- """Can be overridden by the subclass to add custom fields to the output dictionary.
93
- ! must return the modified outdict, and a ordered list of custom fields to be added to the output dictionary.
94
- """
95
- return outdict, []
96
-
97
- # ! can be override, but ONLY if torchmetrics are used
98
- # Prepare the exp data for torch metrics.
99
- def prepare_torch_metrics_exp_data(self, metric_names, *args, **kwargs):
100
- """
101
- Prepare the data for metrics.
102
- This function should be overridden by the subclass if needed.
103
- Must return a dictionary with keys as metric names and values as the data to be used for those metrics.
104
- NOTE: that the data (for each metric) must be in the format expected by the torchmetrics instance (for that metric). E.g: {"accuracy": {"preds": [...], "target": [...]}, ...} since torchmetrics expects the data in a specific format.
105
- """
106
- pass
107
-
108
- def __validate_metrics_info(self, metrics_info):
109
- """
110
- Validate the metrics_info to ensure it is a list or a dictionary with valid metric names and instances.
111
- """
112
- if not isinstance(metrics_info, (list, dict)):
113
- raise TypeError(f"Metrics info must be a list or a dictionary, got {type(metrics_info).__name__}")
114
-
115
- if isinstance(metrics_info, dict):
116
- for k, v in metrics_info.items():
117
- if not isinstance(k, str):
118
- raise TypeError(f"Key '{k}' is not a string")
119
- if not isinstance(v, Metric):
120
- raise TypeError(f"Value for key '{k}' is not a torchmetrics.Metric (got {type(v).__name__})")
121
- elif isinstance(metrics_info, list):
122
- for metric in metrics_info:
123
- if not isinstance(metric, str):
124
- raise TypeError(f"Metric '{metric}' is not a string")
125
- return metrics_info
126
- def __calc_exp_perf_metrics(self, *args, **kwargs):
127
- """
128
- Calculate the performance metrics for the experiment.
129
- """
130
- metrics_info = self.__validate_metrics_info(self.get_metrics_info())
131
- USED_TORCHMETRICS = isinstance(metrics_info, dict)
132
- metric_names = metrics_info if isinstance(metrics_info, list) else list(metrics_info.keys())
133
- out_dict = {metric: None for metric in metric_names}
134
- out_dict["dataset"] = self.get_dataset_name()
135
- out_dict["experiment"] = self.get_experiment_name()
136
- out_dict, custom_fields = self.calc_exp_outdict_custom_fields(
137
- outdict=out_dict, *args, **kwargs
138
- )
139
- if USED_TORCHMETRICS:
140
- torch_metrics_dict = self.get_metrics_info()
141
- all_metric_data = self.prepare_torch_metrics_exp_data(
142
- metric_names, *args, **kwargs
47
+ def valid_proc_extra_data(
48
+ self, proc_extra_data
49
+ ):
50
+ # make sure that all items in proc_extra_data are dictionaries, with same keys
51
+ if proc_extra_data is None or len(proc_extra_data) == 0:
52
+ return
53
+ if not all(isinstance(item, dict) for item in proc_extra_data):
54
+ raise TypeError("All items in proc_extra_data must be dictionaries")
55
+
56
+ if not all(item.keys() == proc_extra_data[0].keys() for item in proc_extra_data):
57
+ raise ValueError("All dictionaries in proc_extra_data must have the same keys")
58
+
59
+ def valid_proc_metric_raw_data(
60
+ self, metric_names, proc_metric_raw_data
61
+ ):
62
+ # make sure that all items in proc_metric_raw_data are dictionaries, with same keys as metric_names
63
+ assert isinstance(proc_metric_raw_data, list) and len(proc_metric_raw_data) > 0, \
64
+ "raw_data_for_metrics must be a non-empty list of dictionaries"
65
+
66
+ # make sure that all items in proc_metric_raw_data are dictionaries with keys as metric_names
67
+ if not all(isinstance(item, dict) for item in proc_metric_raw_data):
68
+ raise TypeError("All items in raw_data_for_metrics must be dictionaries")
69
+ if not all( set(item.keys()) == set(metric_names) for item in proc_metric_raw_data):
70
+ raise ValueError(
71
+ "All dictionaries in raw_data_for_metrics must have the same keys as metric_names"
143
72
  )
144
- metric_col_names = []
145
- for metric in metric_names:
146
- if metric not in all_metric_data:
147
- raise ValueError(f"Metric '{metric}' not found in provided data.")
148
- tmetric = torch_metrics_dict[metric] # torchmetrics instance
149
- metric_data = all_metric_data[metric] # should be a dict of args/kwargs
150
- # Inspect expected parameters for the metric's update() method
151
- sig = inspect.signature(tmetric.update)
152
- expected_args = list(sig.parameters.values())
153
- # Prepare args in correct order
154
- if isinstance(metric_data, dict):
155
- # Match dict keys to parameter names
156
- args = [metric_data[param.name] for param in expected_args]
157
- elif isinstance(metric_data, (list, tuple)):
158
- args = metric_data
159
- else:
160
- raise TypeError(f"Unsupported data format for metric '{metric}'")
161
-
162
- # Call update and compute
163
- if len(expected_args) == 1:
164
- tmetric.update(args) # pass as single argument
165
- else:
166
- tmetric.update(*args) # unpack multiple arguments
167
- computed_value = tmetric.compute()
168
- # ensure the computed value converted to a scala value or list array
169
- if isinstance(computed_value, torch.Tensor):
170
- if computed_value.numel() == 1:
171
- computed_value = computed_value.item()
172
- else:
173
- computed_value = computed_value.tolist()
174
- col_name = f"metric_{metric}" if "metric_" not in metric else metric
175
- metric_col_names.append(col_name)
176
- out_dict[col_name] = computed_value
177
- else:
178
- # If torchmetrics are not used, calculate metrics using the custom method
179
- metric_rs_dict = self.calc_exp_perf_metrics(
180
- metric_names, *args, **kwargs)
181
- for metric in metric_names:
182
- if metric not in metric_rs_dict:
183
- raise ValueError(f"Metric '{metric}' not found in provided data.")
184
- col_name = f"metric_{metric}" if "metric_" not in metric else metric
185
- out_dict[col_name] = metric_rs_dict[metric]
186
- metric_col_names = [f"metric_{metric}" for metric in metric_names]
187
- ordered_cols = REQUIRED_COLS + custom_fields + metric_col_names
188
- # create a new ordered dictionary with the correct order
189
- out_dict = OrderedDict((col, out_dict[col]) for col in ordered_cols if col in out_dict)
190
- return out_dict
191
73
 
192
74
  # ! computes per-run metrics via the configured MetricsBackend; normally there is no need to override this method
193
- def calc_exp_perf_metrics(self, metric_names, *args, **kwargs):
194
- """
195
- Calculate the performance metrics for the experiment, but not using torchmetrics.
196
- This function should be overridden by the subclass if needed.
197
- Must return a dictionary with keys as metric names and values as the calculated metrics.
198
- """
199
- raise NotImplementedError("calc_exp_perf_metrics() must be overridden by the subclass if torchmetrics are not used.")
75
+ def calc_exp_perf_metrics(
76
+ self, metric_names, raw_metrics_data, extra_data=None, *args, **kwargs
77
+ ):
78
+ assert isinstance(raw_metrics_data, dict) or isinstance(raw_metrics_data, list), \
79
+ "raw_data_for_metrics must be a dictionary or a list"
80
+
81
+ if extra_data is not None:
82
+ assert isinstance(extra_data, type(raw_metrics_data)), \
83
+ "extra_data must be of the same type as raw_data_for_metrics (dict or list)"
84
+ # prepare raw_metric data for processing
85
+ proc_metric_raw_data_ls = raw_metrics_data if isinstance(raw_metrics_data, list) else [raw_metrics_data.copy()]
86
+ self.valid_proc_metric_raw_data(metric_names, proc_metric_raw_data_ls)
87
+ # prepare extra data for processing
88
+ proc_extra_data_ls = []
89
+ if extra_data is not None:
90
+ proc_extra_data_ls = extra_data if isinstance(extra_data, list) else [extra_data.copy()]
91
+ assert len(proc_extra_data_ls) == len(proc_metric_raw_data_ls), \
92
+ "extra_data must have the same length as raw_data_for_metrics if it is a list"
93
+ # validate the extra_data
94
+ self.valid_proc_extra_data(proc_extra_data_ls)
95
+
96
+ # calculate the metrics output results
97
+ metrics_backend = self.get_metric_backend()
98
+ proc_outdict_list = []
99
+ for idx, raw_metrics_data in enumerate(proc_metric_raw_data_ls):
100
+ out_dict = {
101
+ "dataset": self.get_dataset_name(),
102
+ "experiment": self.get_experiment_name(),
103
+ }
104
+ custom_fields = []
105
+ if len(proc_extra_data_ls)> 0:
106
+ # add extra data to the output dictionary
107
+ extra_data_item = proc_extra_data_ls[idx]
108
+ out_dict.update(extra_data_item)
109
+ custom_fields = list(extra_data_item.keys())
110
+ metric_results = metrics_backend.calc_metrics(
111
+ metrics_data_dict=raw_metrics_data, *args, **kwargs
112
+ )
113
+ metric_results_prefix = {f"metric_{k}": v for k, v in metric_results.items()}
114
+ out_dict.update(metric_results_prefix)
115
+ ordered_cols = (
116
+ REQUIRED_COLS + custom_fields + list(metric_results_prefix.keys())
117
+ )
118
+ out_dict = OrderedDict(
119
+ (col, out_dict[col]) for col in ordered_cols if col in out_dict
120
+ )
121
+ proc_outdict_list.append(out_dict)
200
122
 
123
+ return proc_outdict_list
201
124
 
202
125
  #! custom kwargs:
203
126
  #! outfile - if provided, will save the output to a CSV file with the given path
204
127
  #! outdir - if provided, will save the output to a CSV file in the given directory with a generated filename
205
128
  #! return_df - if True, will return a DataFrame instead of a dictionary
206
-
207
- def calc_save_exp_perfs(self, *args, **kwargs):
129
+ def calc_and_save_exp_perfs(
130
+ self,
131
+ raw_metrics_data: Union[List[dict], dict],
132
+ extra_data: Optional[Union[List[dict], dict]] = None,
133
+ *args,
134
+ **kwargs,
135
+ ) -> Tuple[Union[List[OrderedDict], pd.DataFrame], Optional[str]]:
208
136
  """
209
137
  Calculate the metrics.
210
138
  Computes the metrics for one or more runs via the configured MetricsBackend and optionally saves them to a CSV file.
211
139
  Returns a tuple (results, csv_path): results is a list of ordered dicts, or a DataFrame if return_df=True; csv_path is the saved CSV file path, or None if neither outfile nor outdir was given.
212
140
  """
213
- out_dict = self.__calc_exp_perf_metrics(*args, **kwargs)
214
- # pprint(f"Output Dictionary: {out_dict}")
215
- # check if any kwargs named "outfile"
141
+ metric_names = self.get_metric_backend().metric_names
142
+ out_dict_list = self.calc_exp_perf_metrics(
143
+ metric_names=metric_names, raw_metrics_data=raw_metrics_data,
144
+ extra_data=extra_data,
145
+ *args, **kwargs
146
+ )
216
147
  csv_outfile = kwargs.get("outfile", None)
217
148
  if csv_outfile is not None:
218
149
  filePathNoExt, _ = os.path.splitext(csv_outfile)
219
150
  # pprint(f"CSV Outfile Path (No Ext): {filePathNoExt}")
220
- csv_outfile = f'{filePathNoExt}{CSV_FILE_POSTFIX}.csv'
151
+ csv_outfile = f"{filePathNoExt}{CSV_FILE_POSTFIX}.csv"
221
152
  elif "outdir" in kwargs:
222
153
  csvoutdir = kwargs["outdir"]
223
154
  csvfilename = f"{now_str()}_{self.get_dataset_name()}_{self.get_experiment_name()}_{CSV_FILE_POSTFIX}.csv"
224
155
  csv_outfile = os.path.join(csvoutdir, csvfilename)
225
156
 
226
157
  # convert out_dict to a DataFrame
227
- df = pd.DataFrame([out_dict])
158
+ df = pd.DataFrame(out_dict_list)
228
159
  # get the orders of the columns as the orders or the keys in out_dict
229
- ordered_cols = list(out_dict.keys())
160
+ ordered_cols = list(out_dict_list[0].keys())
230
161
  df = df[ordered_cols] # reorder columns
231
-
232
162
  if csv_outfile:
233
163
  df.to_csv(csv_outfile, index=False, sep=";", encoding="utf-8")
234
164
  return_df = kwargs.get("return_df", False)
235
- if return_df: # return DataFrame instead of dict if requested
165
+ if return_df: # return DataFrame instead of dict if requested
236
166
  return df, csv_outfile
237
167
  else:
238
- return out_dict, csv_outfile
168
+ return out_dict_list, csv_outfile
239
169
 
240
170
  @staticmethod
241
171
  def default_exp_csv_filter_fn(exp_file_name: str) -> bool:
@@ -247,29 +177,37 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
247
177
 
248
178
  @classmethod
249
179
  def gen_perf_report_for_multip_exps(
250
- cls, indir: str, exp_csv_filter_fn=default_exp_csv_filter_fn, csv_sep=";"
180
+ cls, indir: str, exp_csv_filter_fn=default_exp_csv_filter_fn, include_file_name=False, csv_sep=";"
251
181
  ) -> PerfTB:
252
182
  """
253
183
  Generate a performance report by scanning experiment subdirectories.
254
184
  Must return a dictionary with keys as metric names and values as performance tables.
255
185
  """
256
- def get_df_for_all_exp_perf(csv_perf_files, csv_sep=';'):
186
+ def get_df_for_all_exp_perf(csv_perf_files, csv_sep=";"):
257
187
  """
258
188
  Create a single DataFrame from all CSV files.
259
189
  The CSV files MAY each contain a different set of metric columns
260
190
  """
261
191
  cols = []
192
+ FILE_NAME_COL = "file_name" if include_file_name else None
193
+
262
194
  for csv_file in csv_perf_files:
263
195
  temp_df = pd.read_csv(csv_file, sep=csv_sep)
196
+ if FILE_NAME_COL:
197
+ temp_df[FILE_NAME_COL] = fs.get_file_name(csv_file, split_file_ext=False)
198
+ # csvfile.fn_display_df(temp_df)
264
199
  temp_df_cols = temp_df.columns.tolist()
265
200
  for col in temp_df_cols:
266
201
  if col not in cols:
267
202
  cols.append(col)
203
+
268
204
  df = pd.DataFrame(columns=cols)
269
205
  for csv_file in csv_perf_files:
270
206
  temp_df = pd.read_csv(csv_file, sep=csv_sep)
207
+ if FILE_NAME_COL:
208
+ temp_df[FILE_NAME_COL] = fs.get_file_name(csv_file, split_file_ext=False)
271
209
  # Drop all-NA columns to avoid dtype inconsistency
272
- temp_df = temp_df.dropna(axis=1, how='all')
210
+ temp_df = temp_df.dropna(axis=1, how="all")
273
211
  # ensure all columns are present in the final DataFrame
274
212
  for col in cols:
275
213
  if col not in temp_df.columns:
@@ -277,24 +215,36 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
277
215
  df = pd.concat([df, temp_df], ignore_index=True)
278
216
  # assert that REQUIRED_COLS are present in the DataFrame
279
217
  # pprint(df.columns.tolist())
280
- for col in REQUIRED_COLS:
218
+ sticky_cols = REQUIRED_COLS + ([FILE_NAME_COL] if include_file_name else []) # columns that must always be present
219
+ for col in sticky_cols:
281
220
  if col not in df.columns:
282
- raise ValueError(f"Required column '{col}' is missing from the DataFrame. REQUIRED_COLS = {REQUIRED_COLS}")
283
- metric_cols = [col for col in df.columns if col.startswith('metric_')]
284
- assert len(metric_cols) > 0, "No metric columns found in the DataFrame. Ensure that the CSV files contain metric columns starting with 'metric_'."
285
- final_cols = REQUIRED_COLS + metric_cols
221
+ raise ValueError(
222
+ f"Required column '{col}' is missing from the DataFrame. REQUIRED_COLS = {sticky_cols}"
223
+ )
224
+ metric_cols = [col for col in df.columns if col.startswith(METRIC_PREFIX)]
225
+ assert (
226
+ len(metric_cols) > 0
227
+ ), "No metric columns found in the DataFrame. Ensure that the CSV files contain metric columns starting with 'metric_'."
228
+ final_cols = sticky_cols + metric_cols
286
229
  df = df[final_cols]
230
+ # !hahv debug
231
+ pprint("------ Final DataFrame Columns ------")
232
+ csvfile.fn_display_df(df)
287
233
  # ! validate all rows in df before returning
288
234
  # make sure all rows will have at least values for REQUIRED_COLS and at least one metric column
289
235
  for index, row in df.iterrows():
290
- if not all(col in row and pd.notna(row[col]) for col in REQUIRED_COLS):
291
- raise ValueError(f"Row {index} is missing required columns or has NaN values in required columns: {row}")
236
+ if not all(col in row and pd.notna(row[col]) for col in sticky_cols):
237
+ raise ValueError(
238
+ f"Row {index} is missing required columns or has NaN values in required columns: {row}"
239
+ )
292
240
  if not any(pd.notna(row[col]) for col in metric_cols):
293
241
  raise ValueError(f"Row {index} has no metric values: {row}")
294
242
  # make sure there is no (experiment, dataset) pair that is duplicated
295
- duplicates = df.duplicated(subset=['experiment', 'dataset'], keep=False)
243
+ duplicates = df.duplicated(subset=sticky_cols, keep=False)
296
244
  if duplicates.any():
297
- raise ValueError("Duplicate (experiment, dataset) pairs found in the DataFrame. Please ensure that each experiment-dataset combination is unique.")
245
+ raise ValueError(
246
+ "Duplicate (experiment, dataset) pairs found in the DataFrame. Please ensure that each experiment-dataset combination is unique."
247
+ )
298
248
  return df
299
249
 
300
250
  def mk_perftb_report(df):
@@ -304,9 +254,9 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
304
254
  """
305
255
  perftb = PerfTB()
306
256
  # find all "dataset" values (unique)
307
- dataset_names = list(df['dataset'].unique())
308
- # find all columns that start with "metric_"
309
- metric_cols = [col for col in df.columns if col.startswith('metric_')]
257
+ dataset_names = list(df["dataset"].unique())
258
+ # find all columns that start with METRIC_PREFIX
259
+ metric_cols = [col for col in df.columns if col.startswith(METRIC_PREFIX)]
310
260
 
311
261
  # Determine which metrics are associated with each dataset.
312
262
  # Since a dataset may appear in multiple rows and may not include all metrics in each, identify the row with the same dataset that contains the most non-NaN metric values. The set of metrics for that dataset is defined by the non-NaN metrics in that row.
@@ -316,7 +266,11 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
316
266
  dataset_rows = df[df["dataset"] == dataset_name]
317
267
  # Find the row with the most non-NaN metric values
318
268
  max_non_nan_row = dataset_rows[metric_cols].count(axis=1).idxmax()
319
- metrics_for_dataset = dataset_rows.loc[max_non_nan_row, metric_cols].dropna().index.tolist()
269
+ metrics_for_dataset = (
270
+ dataset_rows.loc[max_non_nan_row, metric_cols]
271
+ .dropna()
272
+ .index.tolist()
273
+ )
320
274
  dataset_metrics[dataset_name] = metrics_for_dataset
321
275
 
322
276
  for dataset_name, metrics in dataset_metrics.items():
@@ -324,11 +278,11 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
324
278
  perftb.add_dataset(dataset_name, metrics)
325
279
 
326
280
  for _, row in df.iterrows():
327
- dataset_name = row['dataset']
281
+ dataset_name = row["dataset"]
328
282
  ds_metrics = dataset_metrics.get(dataset_name)
329
283
  if dataset_name in dataset_metrics:
330
284
  # Add the metrics for this row to the performance table
331
- exp_name = row.get('experiment')
285
+ exp_name = row.get("experiment")
332
286
  exp_metric_values = {}
333
287
  for metric in ds_metrics:
334
288
  if metric in row and pd.notna(row[metric]):
@@ -336,7 +290,7 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
336
290
  perftb.add_experiment(
337
291
  experiment_name=exp_name,
338
292
  dataset_name=dataset_name,
339
- metrics=exp_metric_values
293
+ metrics=exp_metric_values,
340
294
  )
341
295
 
342
296
  return perftb
@@ -351,9 +305,7 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
351
305
  if os.path.isdir(os.path.join(indir, d))
352
306
  ]
353
307
  if len(exp_dirs) == 0:
354
- csv_perf_files = glob.glob(
355
- os.path.join(indir, f"*.csv")
356
- )
308
+ csv_perf_files = glob.glob(os.path.join(indir, f"*.csv"))
357
309
  csv_perf_files = [
358
310
  file_item
359
311
  for file_item in csv_perf_files
@@ -364,13 +316,9 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
364
316
  # Collect all matching CSV files in those subdirs
365
317
  for exp_dir in exp_dirs:
366
318
  # pprint(f"Searching in experiment directory: {exp_dir}")
367
- matched = glob.glob(
368
- os.path.join(exp_dir, f"*.csv")
369
- )
319
+ matched = glob.glob(os.path.join(exp_dir, f"*.csv"))
370
320
  matched = [
371
- file_item
372
- for file_item in matched
373
- if exp_csv_filter_fn(file_item)
321
+ file_item for file_item in matched if exp_csv_filter_fn(file_item)
374
322
  ]
375
323
  csv_perf_files.extend(matched)
376
324
 
@@ -378,9 +326,11 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
378
326
  len(csv_perf_files) > 0
379
327
  ), f"No CSV files matching pattern '{exp_csv_filter_fn}' found in the experiment directories."
380
328
 
381
- assert len(csv_perf_files) > 0, f"No CSV files matching pattern '{exp_csv_filter_fn}' found in the experiment directories."
329
+ assert (
330
+ len(csv_perf_files) > 0
331
+ ), f"No CSV files matching pattern '{exp_csv_filter_fn}' found in the experiment directories."
382
332
 
383
333
  all_exp_perf_df = get_df_for_all_exp_perf(csv_perf_files, csv_sep=csv_sep)
384
- csvfile.fn_display_df(all_exp_perf_df)
334
+ # csvfile.fn_display_df(all_exp_perf_df)
385
335
  perf_tb = mk_perftb_report(all_exp_perf_df)
386
- return perf_tb
336
+ return perf_tb
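
To show how the reworked `PerfCalc` API fits together, below is a hypothetical subclass wired to a `TorchMetricsBackend`. Only the abstract method names and the `calc_and_save_exp_perfs` signature come from the diff; the class name, metrics, tensors, `extra_data`, and output directory are illustrative assumptions.

```python
# Hypothetical PerfCalc subclass (not part of halib); names and data are illustrative.
import torch
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score

from halib.research.metrics import TorchMetricsBackend
from halib.research.perfcalc import PerfCalc

class ToyPerfCalc(PerfCalc):
    def get_experiment_name(self) -> str:
        return "baseline_run"

    def get_dataset_name(self) -> str:
        return "toy_dataset"

    def get_metric_backend(self) -> TorchMetricsBackend:
        return TorchMetricsBackend(
            {"accuracy": BinaryAccuracy(), "f1": BinaryF1Score()}
        )

preds = torch.tensor([1, 0, 1, 1])
target = torch.tensor([1, 0, 0, 1])

# One dict keyed by metric name (a list of such dicts would describe several runs);
# extra_data adds custom columns such as "epoch" to the output rows.
raw = {
    "accuracy": {"preds": preds, "target": target},
    "f1": {"preds": preds, "target": target},
}

calc = ToyPerfCalc()
rows, csv_path = calc.calc_and_save_exp_perfs(raw, extra_data={"epoch": 10}, outdir=".")
# rows     -> [OrderedDict(experiment=..., dataset=..., epoch=10, metric_accuracy=..., metric_f1=...)]
# csv_path -> something like ./<timestamp>_toy_dataset_baseline_run___perf.csv
```
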
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: halib
3
- Version: 0.1.65
3
+ Version: 0.1.67
4
4
  Summary: Small library for common tasks
5
5
  Author: Hoang Van Ha
6
6
  Author-email: hoangvanhauit@gmail.com
@@ -52,7 +52,7 @@ Dynamic: summary
52
52
 
53
53
  Helper package for coding and automation
54
54
 
55
- **Version 0.1.65**
55
+ **Version 0.1.67**
56
56
 
57
57
  + now use `uv` for venv management
58
58
  + `research/perfcalc`: support both torchmetrics and custom metrics for performance calculation
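
The "custom metrics" path mentioned in the changelog above presumably amounts to subclassing `MetricsBackend` without torchmetrics. The following is a hypothetical plain-Python backend; the class name, metric names, and data layout are invented for illustration and are not defined in halib.

```python
# Hypothetical custom backend (not shipped with halib): plain-Python metrics, no torchmetrics.
from statistics import mean

from halib.research.metrics import MetricsBackend  # path taken from the wheel RECORD

class PlainPythonBackend(MetricsBackend):
    """Computes a couple of toy metrics from (preds, target) pairs."""

    def compute_metrics(self, metrics_info, metrics_data_dict, *args, **kwargs):
        out = {}
        for name in self.metric_names:
            preds, target = metrics_data_dict[name]
            if name == "accuracy":
                out[name] = mean(int(p == t) for p, t in zip(preds, target))
            elif name == "mae":
                out[name] = mean(abs(p - t) for p, t in zip(preds, target))
        return out

backend = PlainPythonBackend(["accuracy", "mae"])  # a plain list of metric names is enough here
data = {"accuracy": ([1, 0, 1], [1, 1, 1]), "mae": ([0.5, 2.0], [1.0, 2.0])}
print(backend.calc_metrics(data))  # roughly {'accuracy': 0.667, 'mae': 0.25}
```
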
@@ -30,7 +30,8 @@ halib/online/projectmake.py,sha256=Zrs96WgXvO4nIrwxnCOletL4aTBge-EoF0r7hpKO1w8,4
30
30
  halib/research/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
31
  halib/research/benchquery.py,sha256=FuKnbWQtCEoRRtJAfN-zaN-jPiO_EzsakmTOMiqi7GQ,4626
32
32
  halib/research/dataset.py,sha256=QU0Hr5QFb8_XlvnOMgC9QJGIpwXAZ9lDd0RdQi_QRec,6743
33
- halib/research/perfcalc.py,sha256=F1BYbxQohbS7u3iRtqnKgPmMrWneV6_bEdBumto8h58,18403
33
+ halib/research/metrics.py,sha256=Xgv0GUGo-o-RJaBOmkRCRpQJaYijF_1xeKkyYU_Bv4U,5249
34
+ halib/research/perfcalc.py,sha256=qDa0sqfpWrwGZVJtjuUVFK7JX6j8xyXP9OnnfYmdamg,15898
34
35
  halib/research/perftb.py,sha256=vazU-dYBJhfc4sK4zFgxOvzeXGi-5TyPHCt20ItiWhY,30463
35
36
  halib/research/plot.py,sha256=-pDUk4z3C_GnyJ5zWmf-mGMdT4gaipVJWzIgcpIPiRk,9448
36
37
  halib/research/torchloader.py,sha256=yqUjcSiME6H5W210363HyRUrOi3ISpUFAFkTr1w4DCw,6503
@@ -48,8 +49,8 @@ halib/utils/gpu_mon.py,sha256=vD41_ZnmPLKguuq9X44SB_vwd9JrblO4BDzHLXZhhFY,2233
48
49
  halib/utils/listop.py,sha256=Vpa8_2fI0wySpB2-8sfTBkyi_A4FhoFVVvFiuvW8N64,339
49
50
  halib/utils/tele_noti.py,sha256=-4WXZelCA4W9BroapkRyIdUu9cUVrcJJhegnMs_WpGU,5928
50
51
  halib/utils/video.py,sha256=ZqzNVPgc1RZr_T0OlHvZ6SzyBpL7O27LtB86JMbBuR0,3059
51
- halib-0.1.65.dist-info/licenses/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
52
- halib-0.1.65.dist-info/METADATA,sha256=clqI54I9dybegyKBsPwaJM7cYBOCLYHdzXHEpFAKv_4,5541
53
- halib-0.1.65.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
54
- halib-0.1.65.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
55
- halib-0.1.65.dist-info/RECORD,,
52
+ halib-0.1.67.dist-info/licenses/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
53
+ halib-0.1.67.dist-info/METADATA,sha256=Zk22ct5W95qBzGkz0tNepuAdfUwPJTbVO7Nb4L_hFTQ,5541
54
+ halib-0.1.67.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
55
+ halib-0.1.67.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
56
+ halib-0.1.67.dist-info/RECORD,,