halib 0.1.65__py3-none-any.whl → 0.1.67__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/research/metrics.py +134 -0
- halib/research/perfcalc.py +151 -201
- {halib-0.1.65.dist-info → halib-0.1.67.dist-info}/METADATA +2 -2
- {halib-0.1.65.dist-info → halib-0.1.67.dist-info}/RECORD +7 -6
- {halib-0.1.65.dist-info → halib-0.1.67.dist-info}/WHEEL +0 -0
- {halib-0.1.65.dist-info → halib-0.1.67.dist-info}/licenses/LICENSE.txt +0 -0
- {halib-0.1.65.dist-info → halib-0.1.67.dist-info}/top_level.txt +0 -0
halib/research/metrics.py
ADDED
@@ -0,0 +1,134 @@
+# -------------------------------
+# Metrics Backend Interface
+# -------------------------------
+import inspect
+from typing import Dict, Union, List, Any
+from abc import ABC, abstractmethod
+
+class MetricsBackend(ABC):
+    """Interface for pluggable metrics computation backends."""
+
+    def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
+        """
+        Initialize the backend with optional metrics_info.
+        """
+        self.metric_info = metrics_info
+        self.validate_metrics_info(self.metric_info)
+
+    @property
+    def metric_names(self) -> List[str]:
+        """
+        Return a list of metric names.
+        If metric_info is a dict, return its keys; if it's a list, return it directly.
+        """
+        if isinstance(self.metric_info, dict):
+            return list(self.metric_info.keys())
+        elif isinstance(self.metric_info, list):
+            return self.metric_info
+        else:
+            raise TypeError("metric_info must be a list or a dict")
+
+    def validate_metrics_info(self, metrics_info):
+        if isinstance(metrics_info, list):
+            return metrics_info
+        elif isinstance(metrics_info, dict):
+            return {k: v for k, v in metrics_info.items() if isinstance(k, str)}
+        else:
+            raise TypeError(
+                "metrics_info must be a list of strings or a dict with string keys"
+            )
+
+    @abstractmethod
+    def compute_metrics(
+        self, metrics_info: Union[List[str], Dict[str, Any]], metrics_data_dict: Dict[str, Any], *args, **kwargs
+    ) -> Dict[str, Any]:
+        pass
+
+    def prepare_metrics_backend_data(
+        self, raw_metric_data, *args, **kwargs
+    ):
+        """
+        Prepare the data for the metrics backend.
+        This method can be overridden by subclasses to customize data preparation.
+        """
+        return raw_metric_data
+
+    def calc_metrics(
+        self, metrics_data_dict: Dict[str, Any], *args, **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Calculate metrics based on the provided metrics_info and data.
+        This method should be overridden by subclasses to implement specific metric calculations.
+        """
+        # prevalidate the metrics_data_dict
+        for metric in self.metric_names:
+            if metric not in metrics_data_dict:
+                raise ValueError(f"Metric '{metric}' not found in provided data.")
+        # Prepare the data for the backend
+        metrics_data_dict = self.prepare_metrics_backend_data(
+            metrics_data_dict, *args, **kwargs
+        )
+        # Call the abstract method to compute metrics
+        return self.compute_metrics(self.metric_info, metrics_data_dict, *args, **kwargs)
+
+class TorchMetricsBackend(MetricsBackend):
+    """TorchMetrics-based backend implementation."""
+
+    def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
+        try:
+            import torch
+            from torchmetrics import Metric
+        except ImportError:
+            raise ImportError(
+                "TorchMetricsBackend requires torch and torchmetrics to be installed."
+            )
+        self.metric_info = metrics_info
+        self.torch = torch
+        self.Metric = Metric
+        self.validate_metrics_info(metrics_info)
+
+    def validate_metrics_info(self, metrics_info):
+        if not isinstance(metrics_info, dict):
+            raise TypeError(
+                "TorchMetricsBackend requires metrics_info as a dict {name: MetricInstance}"
+            )
+        for k, v in metrics_info.items():
+            if not isinstance(k, str):
+                raise TypeError(f"Key '{k}' is not a string")
+            if not isinstance(v, self.Metric):
+                raise TypeError(f"Value for key '{k}' must be a torchmetrics.Metric")
+        return metrics_info
+
+    def compute_metrics(self, metrics_info, metrics_data_dict, *args, **kwargs):
+        out_dict = {}
+        for metric, metric_instance in metrics_info.items():
+            if metric not in metrics_data_dict:
+                raise ValueError(f"Metric '{metric}' not found in provided data.")
+
+            metric_data = metrics_data_dict[metric]
+            sig = inspect.signature(metric_instance.update)
+            expected_args = list(sig.parameters.values())
+
+            if isinstance(metric_data, dict):
+                args = [metric_data[param.name] for param in expected_args]
+            elif isinstance(metric_data, (list, tuple)):
+                args = metric_data
+            else:
+                raise TypeError(f"Unsupported data format for metric '{metric}'")
+
+            if len(expected_args) == 1:
+                metric_instance.update(args)
+            else:
+                metric_instance.update(*args)
+
+            computed_value = metric_instance.compute()
+            if isinstance(computed_value, self.torch.Tensor):
+                computed_value = (
+                    computed_value.item()
+                    if computed_value.numel() == 1
+                    else computed_value.tolist()
+                )
+
+
+            out_dict[metric] = computed_value
+        return out_dict
halib/research/perfcalc.py
CHANGED
@@ -1,71 +1,27 @@
 import os
 import glob
-import
+from typing import Optional, Tuple
 import pandas as pd

-from typing import Dict
-from functools import wraps
 from rich.pretty import pprint

 from abc import ABC, abstractmethod
+from collections import OrderedDict

 from ..filetype import csvfile
+from ..system import filesys as fs
 from ..common import now_str
 from ..research.perftb import PerfTB
-from
+from ..research.metrics import *

-# try to import torch, and torchmetrics
-try:
-    import torch
-    import torchmetrics
-    from torchmetrics import Metric
-except ImportError:
-    raise ImportError("Please install torch and torchmetrics to use this module.")
-
-def validate_torch_metrics(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        result = fn(self, *args, **kwargs)
-
-        if not isinstance(result, dict):
-            raise TypeError("torch_metrics() must return a dictionary")
-
-        for k, v in result.items():
-            if not isinstance(k, str):
-                raise TypeError(f"Key '{k}' is not a string")
-            if not isinstance(v, Metric):
-                raise TypeError(
-                    f"Value for key '{k}' is not a torchmetrics.Metric (got {type(v).__name__})"
-                )
-
-        return result
-
-    return wrapper
-def valid_custom_fields(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
-        rs = fn(self, *args, **kwargs)
-        if not isinstance(rs, tuple) or len(rs) != 2:
-            raise ValueError("Function must return a tuple (outdict, custom_fields)")
-        outdict, custom_fields = rs
-        if not isinstance(outdict, dict):
-            raise TypeError("Output must be a dictionary")
-        if not isinstance(custom_fields, list):
-            raise TypeError("Custom fields must be a list")
-        for field in custom_fields:
-            if not isinstance(field, str):
-                raise TypeError(f"Custom field '{field}' is not a string")
-        return outdict, custom_fields
-
-    return wrapper

 REQUIRED_COLS = ["experiment", "dataset"]
 CSV_FILE_POSTFIX = "__perf"
+METRIC_PREFIX = "metric_"

-class PerfCalc(ABC):
-
+class PerfCalc(ABC):  # Abstract base class for performance calculation
     @abstractmethod
-    def get_experiment_name(self):
+    def get_experiment_name(self) -> str:
         """
         Return the name of the experiment.
         This function should be overridden by the subclass if needed.
@@ -73,7 +29,7 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
         pass

     @abstractmethod
-    def get_dataset_name(self):
+    def get_dataset_name(self) -> str:
         """
         Return the name of the dataset.
         This function should be overridden by the subclass if needed.
@@ -81,161 +37,135 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
         pass

     @abstractmethod
-    def
+    def get_metric_backend(self) -> MetricsBackend:
         """
         Return a list of metric names to be used for performance calculation OR a dictionaray with keys as metric names and values as metric instances of torchmetrics.Metric. For example: {"accuracy": Accuracy(), "precision": Precision()}

         """
         pass

-    def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        for k, v in metrics_info.items():
-            if not isinstance(k, str):
-                raise TypeError(f"Key '{k}' is not a string")
-            if not isinstance(v, Metric):
-                raise TypeError(f"Value for key '{k}' is not a torchmetrics.Metric (got {type(v).__name__})")
-        elif isinstance(metrics_info, list):
-            for metric in metrics_info:
-                if not isinstance(metric, str):
-                    raise TypeError(f"Metric '{metric}' is not a string")
-        return metrics_info
-    def __calc_exp_perf_metrics(self, *args, **kwargs):
-        """
-        Calculate the performance metrics for the experiment.
-        """
-        metrics_info = self.__validate_metrics_info(self.get_metrics_info())
-        USED_TORCHMETRICS = isinstance(metrics_info, dict)
-        metric_names = metrics_info if isinstance(metrics_info, list) else list(metrics_info.keys())
-        out_dict = {metric: None for metric in metric_names}
-        out_dict["dataset"] = self.get_dataset_name()
-        out_dict["experiment"] = self.get_experiment_name()
-        out_dict, custom_fields = self.calc_exp_outdict_custom_fields(
-            outdict=out_dict, *args, **kwargs
-        )
-        if USED_TORCHMETRICS:
-            torch_metrics_dict = self.get_metrics_info()
-            all_metric_data = self.prepare_torch_metrics_exp_data(
-                metric_names, *args, **kwargs
+    def valid_proc_extra_data(
+        self, proc_extra_data
+    ):
+        # make sure that all items in proc_extra_data are dictionaries, with same keys
+        if proc_extra_data is None or len(proc_extra_data) == 0:
+            return
+        if not all(isinstance(item, dict) for item in proc_extra_data):
+            raise TypeError("All items in proc_extra_data must be dictionaries")
+
+        if not all(item.keys() == proc_extra_data[0].keys() for item in proc_extra_data):
+            raise ValueError("All dictionaries in proc_extra_data must have the same keys")
+
+    def valid_proc_metric_raw_data(
+        self, metric_names, proc_metric_raw_data
+    ):
+        # make sure that all items in proc_metric_raw_data are dictionaries, with same keys as metric_names
+        assert isinstance(proc_metric_raw_data, list) and len(proc_metric_raw_data) > 0, \
+            "raw_data_for_metrics must be a non-empty list of dictionaries"
+
+        # make sure that all items in proc_metric_raw_data are dictionaries with keys as metric_names
+        if not all(isinstance(item, dict) for item in proc_metric_raw_data):
+            raise TypeError("All items in raw_data_for_metrics must be dictionaries")
+        if not all( set(item.keys()) == set(metric_names) for item in proc_metric_raw_data):
+            raise ValueError(
+                "All dictionaries in raw_data_for_metrics must have the same keys as metric_names"
             )
-            metric_col_names = []
-            for metric in metric_names:
-                if metric not in all_metric_data:
-                    raise ValueError(f"Metric '{metric}' not found in provided data.")
-                tmetric = torch_metrics_dict[metric]  # torchmetrics instance
-                metric_data = all_metric_data[metric]  # should be a dict of args/kwargs
-                # Inspect expected parameters for the metric's update() method
-                sig = inspect.signature(tmetric.update)
-                expected_args = list(sig.parameters.values())
-                # Prepare args in correct order
-                if isinstance(metric_data, dict):
-                    # Match dict keys to parameter names
-                    args = [metric_data[param.name] for param in expected_args]
-                elif isinstance(metric_data, (list, tuple)):
-                    args = metric_data
-                else:
-                    raise TypeError(f"Unsupported data format for metric '{metric}'")
-
-                # Call update and compute
-                if len(expected_args) == 1:
-                    tmetric.update(args)  # pass as single argument
-                else:
-                    tmetric.update(*args)  # unpack multiple arguments
-                computed_value = tmetric.compute()
-                # ensure the computed value converted to a scala value or list array
-                if isinstance(computed_value, torch.Tensor):
-                    if computed_value.numel() == 1:
-                        computed_value = computed_value.item()
-                    else:
-                        computed_value = computed_value.tolist()
-                col_name = f"metric_{metric}" if "metric_" not in metric else metric
-                metric_col_names.append(col_name)
-                out_dict[col_name] = computed_value
-        else:
-            # If torchmetrics are not used, calculate metrics using the custom method
-            metric_rs_dict = self.calc_exp_perf_metrics(
-                metric_names, *args, **kwargs)
-            for metric in metric_names:
-                if metric not in metric_rs_dict:
-                    raise ValueError(f"Metric '{metric}' not found in provided data.")
-                col_name = f"metric_{metric}" if "metric_" not in metric else metric
-                out_dict[col_name] = metric_rs_dict[metric]
-            metric_col_names = [f"metric_{metric}" for metric in metric_names]
-        ordered_cols = REQUIRED_COLS + custom_fields + metric_col_names
-        # create a new ordered dictionary with the correct order
-        out_dict = OrderedDict((col, out_dict[col]) for col in ordered_cols if col in out_dict)
-        return out_dict

     # ! only need to override this method if torchmetrics are not used
-    def calc_exp_perf_metrics(
-
-
-
-
-
-
+    def calc_exp_perf_metrics(
+        self, metric_names, raw_metrics_data, extra_data=None, *args, **kwargs
+    ):
+        assert isinstance(raw_metrics_data, dict) or isinstance(raw_metrics_data, list), \
+            "raw_data_for_metrics must be a dictionary or a list"
+
+        if extra_data is not None:
+            assert isinstance(extra_data, type(raw_metrics_data)), \
+                "extra_data must be of the same type as raw_data_for_metrics (dict or list)"
+        # prepare raw_metric data for processing
+        proc_metric_raw_data_ls = raw_metrics_data if isinstance(raw_metrics_data, list) else [raw_metrics_data.copy()]
+        self.valid_proc_metric_raw_data(metric_names, proc_metric_raw_data_ls)
+        # prepare extra data for processing
+        proc_extra_data_ls = []
+        if extra_data is not None:
+            proc_extra_data_ls = extra_data if isinstance(extra_data, list) else [extra_data.copy()]
+            assert len(proc_extra_data_ls) == len(proc_metric_raw_data_ls), \
+                "extra_data must have the same length as raw_data_for_metrics if it is a list"
+            # validate the extra_data
+            self.valid_proc_extra_data(proc_extra_data_ls)
+
+        # calculate the metrics output results
+        metrics_backend = self.get_metric_backend()
+        proc_outdict_list = []
+        for idx, raw_metrics_data in enumerate(proc_metric_raw_data_ls):
+            out_dict = {
+                "dataset": self.get_dataset_name(),
+                "experiment": self.get_experiment_name(),
+            }
+            custom_fields = []
+            if len(proc_extra_data_ls)> 0:
+                # add extra data to the output dictionary
+                extra_data_item = proc_extra_data_ls[idx]
+                out_dict.update(extra_data_item)
+                custom_fields = list(extra_data_item.keys())
+            metric_results = metrics_backend.calc_metrics(
+                metrics_data_dict=raw_metrics_data, *args, **kwargs
+            )
+            metric_results_prefix = {f"metric_{k}": v for k, v in metric_results.items()}
+            out_dict.update(metric_results_prefix)
+            ordered_cols = (
+                REQUIRED_COLS + custom_fields + list(metric_results_prefix.keys())
+            )
+            out_dict = OrderedDict(
+                (col, out_dict[col]) for col in ordered_cols if col in out_dict
+            )
+            proc_outdict_list.append(out_dict)

+        return proc_outdict_list

     #! custom kwargs:
     #! outfile - if provided, will save the output to a CSV file with the given path
     #! outdir - if provided, will save the output to a CSV file in the given directory with a generated filename
     #! return_df - if True, will return a DataFrame instead of a dictionary
-
-
+    def calc_and_save_exp_perfs(
+        self,
+        raw_metrics_data: Union[List[dict], dict],
+        extra_data: Optional[Union[List[dict], dict]] = None,
+        *args,
+        **kwargs,
+    ) -> Tuple[Union[List[OrderedDict], pd.DataFrame], Optional[str]]:
         """
         Calculate the metrics.
         This function should be overridden by the subclass if needed.
         Must return a dictionary with keys as metric names and values as the calculated metrics.
         """
-
-
-
+        metric_names = self.get_metric_backend().metric_names
+        out_dict_list = self.calc_exp_perf_metrics(
+            metric_names=metric_names, raw_metrics_data=raw_metrics_data,
+            extra_data=extra_data,
+            *args, **kwargs
+        )
         csv_outfile = kwargs.get("outfile", None)
         if csv_outfile is not None:
             filePathNoExt, _ = os.path.splitext(csv_outfile)
             # pprint(f"CSV Outfile Path (No Ext): {filePathNoExt}")
-            csv_outfile = f
+            csv_outfile = f"{filePathNoExt}{CSV_FILE_POSTFIX}.csv"
         elif "outdir" in kwargs:
             csvoutdir = kwargs["outdir"]
             csvfilename = f"{now_str()}_{self.get_dataset_name()}_{self.get_experiment_name()}_{CSV_FILE_POSTFIX}.csv"
             csv_outfile = os.path.join(csvoutdir, csvfilename)

         # convert out_dict to a DataFrame
-        df = pd.DataFrame(
+        df = pd.DataFrame(out_dict_list)
         # get the orders of the columns as the orders or the keys in out_dict
-        ordered_cols = list(
+        ordered_cols = list(out_dict_list[0].keys())
         df = df[ordered_cols]  # reorder columns
-
         if csv_outfile:
             df.to_csv(csv_outfile, index=False, sep=";", encoding="utf-8")
         return_df = kwargs.get("return_df", False)
-        if return_df:
+        if return_df:  # return DataFrame instead of dict if requested
             return df, csv_outfile
         else:
-            return
+            return out_dict_list, csv_outfile

     @staticmethod
     def default_exp_csv_filter_fn(exp_file_name: str) -> bool:
@@ -247,29 +177,37 @@ class PerfCalc(ABC): # Abstract base class for performance calculation

     @classmethod
     def gen_perf_report_for_multip_exps(
-        cls, indir: str, exp_csv_filter_fn=default_exp_csv_filter_fn, csv_sep=";"
+        cls, indir: str, exp_csv_filter_fn=default_exp_csv_filter_fn, include_file_name=False, csv_sep=";"
     ) -> PerfTB:
         """
         Generate a performance report by scanning experiment subdirectories.
         Must return a dictionary with keys as metric names and values as performance tables.
         """
-        def get_df_for_all_exp_perf(csv_perf_files, csv_sep=
+        def get_df_for_all_exp_perf(csv_perf_files, csv_sep=";"):
             """
             Create a single DataFrame from all CSV files.
             Assumes all CSV files MAY have different metrics
             """
             cols = []
+            FILE_NAME_COL = "file_name" if include_file_name else None
+
             for csv_file in csv_perf_files:
                 temp_df = pd.read_csv(csv_file, sep=csv_sep)
+                if FILE_NAME_COL:
+                    temp_df[FILE_NAME_COL] = fs.get_file_name(csv_file, split_file_ext=False)
+                # csvfile.fn_display_df(temp_df)
                 temp_df_cols = temp_df.columns.tolist()
                 for col in temp_df_cols:
                     if col not in cols:
                         cols.append(col)
+
             df = pd.DataFrame(columns=cols)
             for csv_file in csv_perf_files:
                 temp_df = pd.read_csv(csv_file, sep=csv_sep)
+                if FILE_NAME_COL:
+                    temp_df[FILE_NAME_COL] = fs.get_file_name(csv_file, split_file_ext=False)
                 # Drop all-NA columns to avoid dtype inconsistency
-                temp_df = temp_df.dropna(axis=1, how=
+                temp_df = temp_df.dropna(axis=1, how="all")
                 # ensure all columns are present in the final DataFrame
                 for col in cols:
                     if col not in temp_df.columns:
@@ -277,24 +215,36 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
                 df = pd.concat([df, temp_df], ignore_index=True)
             # assert that REQUIRED_COLS are present in the DataFrame
             # pprint(df.columns.tolist())
-
+            sticky_cols = REQUIRED_COLS + ([FILE_NAME_COL] if include_file_name else [])  # columns that must always be present
+            for col in sticky_cols:
                 if col not in df.columns:
-                    raise ValueError(
-
-
-
+                    raise ValueError(
+                        f"Required column '{col}' is missing from the DataFrame. REQUIRED_COLS = {sticky_cols}"
+                    )
+            metric_cols = [col for col in df.columns if col.startswith(METRIC_PREFIX)]
+            assert (
+                len(metric_cols) > 0
+            ), "No metric columns found in the DataFrame. Ensure that the CSV files contain metric columns starting with 'metric_'."
+            final_cols = sticky_cols + metric_cols
             df = df[final_cols]
+            # !hahv debug
+            pprint("------ Final DataFrame Columns ------")
+            csvfile.fn_display_df(df)
             # ! validate all rows in df before returning
             # make sure all rows will have at least values for REQUIRED_COLS and at least one metric column
             for index, row in df.iterrows():
-                if not all(col in row and pd.notna(row[col]) for col in
-                    raise ValueError(
+                if not all(col in row and pd.notna(row[col]) for col in sticky_cols):
+                    raise ValueError(
+                        f"Row {index} is missing required columns or has NaN values in required columns: {row}"
+                    )
                 if not any(pd.notna(row[col]) for col in metric_cols):
                     raise ValueError(f"Row {index} has no metric values: {row}")
             # make sure these is no (experiment, dataset) pair that is duplicated
-            duplicates = df.duplicated(subset=
+            duplicates = df.duplicated(subset=sticky_cols, keep=False)
             if duplicates.any():
-                raise ValueError(
+                raise ValueError(
+                    "Duplicate (experiment, dataset) pairs found in the DataFrame. Please ensure that each experiment-dataset combination is unique."
+                )
             return df

         def mk_perftb_report(df):
@@ -304,9 +254,9 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
             """
             perftb = PerfTB()
             # find all "dataset" values (unique)
-            dataset_names = list(df[
-            # find all columns that start with
-            metric_cols = [col for col in df.columns if col.startswith(
+            dataset_names = list(df["dataset"].unique())
+            # find all columns that start with METRIC_PREFIX
+            metric_cols = [col for col in df.columns if col.startswith(METRIC_PREFIX)]

             # Determine which metrics are associated with each dataset.
             # Since a dataset may appear in multiple rows and may not include all metrics in each, identify the row with the same dataset that contains the most non-NaN metric values. The set of metrics for that dataset is defined by the non-NaN metrics in that row.
@@ -316,7 +266,11 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
                 dataset_rows = df[df["dataset"] == dataset_name]
                 # Find the row with the most non-NaN metric values
                 max_non_nan_row = dataset_rows[metric_cols].count(axis=1).idxmax()
-                metrics_for_dataset =
+                metrics_for_dataset = (
+                    dataset_rows.loc[max_non_nan_row, metric_cols]
+                    .dropna()
+                    .index.tolist()
+                )
                 dataset_metrics[dataset_name] = metrics_for_dataset

             for dataset_name, metrics in dataset_metrics.items():
@@ -324,11 +278,11 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
                 perftb.add_dataset(dataset_name, metrics)

             for _, row in df.iterrows():
-                dataset_name = row[
+                dataset_name = row["dataset"]
                 ds_metrics = dataset_metrics.get(dataset_name)
                 if dataset_name in dataset_metrics:
                     # Add the metrics for this row to the performance table
-                    exp_name = row.get(
+                    exp_name = row.get("experiment")
                     exp_metric_values = {}
                     for metric in ds_metrics:
                         if metric in row and pd.notna(row[metric]):
@@ -336,7 +290,7 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
                     perftb.add_experiment(
                         experiment_name=exp_name,
                         dataset_name=dataset_name,
-                        metrics=exp_metric_values
+                        metrics=exp_metric_values,
                     )

             return perftb
@@ -351,9 +305,7 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
             if os.path.isdir(os.path.join(indir, d))
         ]
         if len(exp_dirs) == 0:
-            csv_perf_files = glob.glob(
-                os.path.join(indir, f"*.csv")
-            )
+            csv_perf_files = glob.glob(os.path.join(indir, f"*.csv"))
             csv_perf_files = [
                 file_item
                 for file_item in csv_perf_files
@@ -364,13 +316,9 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
         # Collect all matching CSV files in those subdirs
         for exp_dir in exp_dirs:
             # pprint(f"Searching in experiment directory: {exp_dir}")
-            matched = glob.glob(
-                os.path.join(exp_dir, f"*.csv")
-            )
+            matched = glob.glob(os.path.join(exp_dir, f"*.csv"))
             matched = [
-                file_item
-                for file_item in matched
-                if exp_csv_filter_fn(file_item)
+                file_item for file_item in matched if exp_csv_filter_fn(file_item)
             ]
             csv_perf_files.extend(matched)

@@ -378,9 +326,11 @@ class PerfCalc(ABC): # Abstract base class for performance calculation
             len(csv_perf_files) > 0
         ), f"No CSV files matching pattern '{exp_csv_filter_fn}' found in the experiment directories."

-        assert
+        assert (
+            len(csv_perf_files) > 0
+        ), f"No CSV files matching pattern '{exp_csv_filter_fn}' found in the experiment directories."

         all_exp_perf_df = get_df_for_all_exp_perf(csv_perf_files, csv_sep=csv_sep)
-        csvfile.fn_display_df(all_exp_perf_df)
+        # csvfile.fn_display_df(all_exp_perf_df)
         perf_tb = mk_perftb_report(all_exp_perf_df)
-        return perf_tb
+        return perf_tb
{halib-0.1.65.dist-info → halib-0.1.67.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: halib
-Version: 0.1.65
+Version: 0.1.67
 Summary: Small library for common tasks
 Author: Hoang Van Ha
 Author-email: hoangvanhauit@gmail.com
@@ -52,7 +52,7 @@ Dynamic: summary

 Helper package for coding and automation

-**Version 0.1.65**
+**Version 0.1.67**

 + now use `uv` for venv management
 + `research/perfcalc`: support both torchmetrics and custom metrics for performance calculation
{halib-0.1.65.dist-info → halib-0.1.67.dist-info}/RECORD
CHANGED
@@ -30,7 +30,8 @@ halib/online/projectmake.py,sha256=Zrs96WgXvO4nIrwxnCOletL4aTBge-EoF0r7hpKO1w8,4
 halib/research/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 halib/research/benchquery.py,sha256=FuKnbWQtCEoRRtJAfN-zaN-jPiO_EzsakmTOMiqi7GQ,4626
 halib/research/dataset.py,sha256=QU0Hr5QFb8_XlvnOMgC9QJGIpwXAZ9lDd0RdQi_QRec,6743
-halib/research/
+halib/research/metrics.py,sha256=Xgv0GUGo-o-RJaBOmkRCRpQJaYijF_1xeKkyYU_Bv4U,5249
+halib/research/perfcalc.py,sha256=qDa0sqfpWrwGZVJtjuUVFK7JX6j8xyXP9OnnfYmdamg,15898
 halib/research/perftb.py,sha256=vazU-dYBJhfc4sK4zFgxOvzeXGi-5TyPHCt20ItiWhY,30463
 halib/research/plot.py,sha256=-pDUk4z3C_GnyJ5zWmf-mGMdT4gaipVJWzIgcpIPiRk,9448
 halib/research/torchloader.py,sha256=yqUjcSiME6H5W210363HyRUrOi3ISpUFAFkTr1w4DCw,6503
@@ -48,8 +49,8 @@ halib/utils/gpu_mon.py,sha256=vD41_ZnmPLKguuq9X44SB_vwd9JrblO4BDzHLXZhhFY,2233
 halib/utils/listop.py,sha256=Vpa8_2fI0wySpB2-8sfTBkyi_A4FhoFVVvFiuvW8N64,339
 halib/utils/tele_noti.py,sha256=-4WXZelCA4W9BroapkRyIdUu9cUVrcJJhegnMs_WpGU,5928
 halib/utils/video.py,sha256=ZqzNVPgc1RZr_T0OlHvZ6SzyBpL7O27LtB86JMbBuR0,3059
-halib-0.1.
-halib-0.1.
-halib-0.1.
-halib-0.1.
-halib-0.1.
+halib-0.1.67.dist-info/licenses/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
+halib-0.1.67.dist-info/METADATA,sha256=Zk22ct5W95qBzGkz0tNepuAdfUwPJTbVO7Nb4L_hFTQ,5541
+halib-0.1.67.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+halib-0.1.67.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
+halib-0.1.67.dist-info/RECORD,,
File without changes
|
File without changes
|
File without changes
|