halib 0.1.99__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/__init__.py +3 -3
- halib/common/__init__.py +0 -0
- halib/common/common.py +178 -0
- halib/common/rich_color.py +285 -0
- halib/filetype/csvfile.py +3 -9
- halib/filetype/ipynb.py +3 -5
- halib/filetype/jsonfile.py +0 -3
- halib/filetype/textfile.py +0 -1
- halib/filetype/videofile.py +91 -2
- halib/filetype/yamlfile.py +3 -3
- halib/online/projectmake.py +7 -6
- halib/online/tele_noti.py +165 -0
- halib/research/base_exp.py +75 -18
- halib/research/core/__init__.py +0 -0
- halib/research/core/base_config.py +144 -0
- halib/research/core/base_exp.py +157 -0
- halib/research/core/param_gen.py +108 -0
- halib/research/core/wandb_op.py +117 -0
- halib/research/data/__init__.py +0 -0
- halib/research/data/dataclass_util.py +41 -0
- halib/research/data/dataset.py +208 -0
- halib/research/data/torchloader.py +165 -0
- halib/research/dataset.py +1 -1
- halib/research/metrics.py +4 -0
- halib/research/mics.py +8 -2
- halib/research/perf/__init__.py +0 -0
- halib/research/perf/flop_calc.py +190 -0
- halib/research/perf/gpu_mon.py +58 -0
- halib/research/perf/perfcalc.py +363 -0
- halib/research/perf/perfmetrics.py +137 -0
- halib/research/perf/perftb.py +778 -0
- halib/research/perf/profiler.py +301 -0
- halib/research/perfcalc.py +57 -32
- halib/research/viz/__init__.py +0 -0
- halib/research/viz/plot.py +754 -0
- halib/system/filesys.py +60 -20
- halib/system/path.py +73 -0
- halib/utils/dict.py +9 -0
- halib/utils/list.py +12 -0
- {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/METADATA +7 -1
- halib-0.2.2.dist-info/RECORD +89 -0
- halib-0.1.99.dist-info/RECORD +0 -64
- {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/WHEEL +0 -0
- {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/licenses/LICENSE.txt +0 -0
- {halib-0.1.99.dist-info → halib-0.2.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
# -------------------------------
|
|
2
|
+
# Metrics Backend Interface
|
|
3
|
+
# -------------------------------
|
|
4
|
+
import inspect
|
|
5
|
+
from typing import Dict, Union, List, Any
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
|
|
8
|
+
class MetricsBackend(ABC):
    """Interface for pluggable metrics computation backends.

    A backend is configured once with ``metrics_info`` and then asked to
    compute those metrics through :meth:`calc_metrics`, which validates the
    incoming data, lets the subclass reshape it, and finally delegates to
    the abstract :meth:`compute_metrics`.
    """

    def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
        """Store and validate the metrics configuration.

        Args:
            metrics_info: Either
                - a list of metric names, e.g. ``["accuracy", "precision"]``, or
                - a dict mapping metric names to the object that defines how
                  to compute them, e.g.
                  ``{"accuracy": torchmetrics.Accuracy(), ...}``.

        Raises:
            TypeError: If ``metrics_info`` is neither a list nor a dict.
        """
        # Assign first so an overriding validator may read self.metric_info.
        self.metric_info = metrics_info
        validated = self.validate_metrics_info(metrics_info)
        # Keep what the validator returns (e.g. the dict stripped of
        # non-string keys). An override that validates without returning
        # anything leaves the raw value in place.
        if validated is not None:
            self.metric_info = validated

    @property
    def metric_names(self) -> List[str]:
        """Return the configured metric names as a new list.

        For a dict configuration these are its keys; for a list
        configuration, its items. A copy is returned so callers cannot
        mutate the backend's configuration through the result.

        Raises:
            TypeError: If ``metric_info`` is neither a list nor a dict.
        """
        if isinstance(self.metric_info, (dict, list)):
            # list() over a dict yields its keys, over a list its items.
            return list(self.metric_info)
        raise TypeError("metric_info must be a list or a dict")

    def validate_metrics_info(self, metrics_info):
        """Validate ``metrics_info`` and return the usable configuration.

        A list is returned unchanged. A dict is returned without the entries
        whose key is not a string.

        Raises:
            TypeError: If ``metrics_info`` is neither a list nor a dict.
        """
        if isinstance(metrics_info, list):
            return metrics_info
        if isinstance(metrics_info, dict):
            return {k: v for k, v in metrics_info.items() if isinstance(k, str)}
        raise TypeError(
            "metrics_info must be a list of strings or a dict with string keys"
        )

    @abstractmethod
    def compute_metrics(
        self,
        metrics_info: Union[List[str], Dict[str, Any]],
        metrics_data_dict: Dict[str, Any],
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """Compute the metrics; must be implemented by subclasses.

        Args:
            metrics_info: The backend's metrics configuration.
            metrics_data_dict: Per-metric input data, as returned by
                :meth:`prepare_metrics_backend_data`.

        Returns:
            A dict of computed results.
        """

    def prepare_metrics_backend_data(self, raw_metric_data, *args, **kwargs):
        """Prepare the data for the metrics backend.

        The default implementation returns ``raw_metric_data`` unchanged;
        subclasses can override it to customize data preparation.
        """
        return raw_metric_data

    def calc_metrics(
        self, metrics_data_dict: Dict[str, Any], *args, **kwargs
    ) -> Dict[str, Any]:
        """Validate the data, prepare it, and compute every configured metric.

        This is the driver; subclasses customize it through
        :meth:`prepare_metrics_backend_data` and :meth:`compute_metrics`.

        Args:
            metrics_data_dict: Mapping that must contain an entry for each
                name in :attr:`metric_names`.
            *args, **kwargs: Forwarded to both hooks.

        Returns:
            Whatever :meth:`compute_metrics` returns.

        Raises:
            ValueError: If a configured metric has no entry in
                ``metrics_data_dict``.
        """
        # Pre-validate on the raw data, before any subclass preparation.
        for metric in self.metric_names:
            if metric not in metrics_data_dict:
                raise ValueError(f"Metric '{metric}' not found in provided data.")
        prepared = self.prepare_metrics_backend_data(
            metrics_data_dict, *args, **kwargs
        )
        return self.compute_metrics(self.metric_info, prepared, *args, **kwargs)
|
|
77
|
+
|
|
78
|
+
class TorchMetricsBackend(MetricsBackend):
    """TorchMetrics-based backend implementation.

    ``metrics_info`` must be a dict mapping metric names to
    ``torchmetrics.Metric`` instances. ``torch`` and ``torchmetrics`` are
    imported lazily in ``__init__`` so the module can be imported without
    them installed.
    """

    def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
        """Import the torch dependencies, then store and validate the config.

        Raises:
            ImportError: If ``torch`` or ``torchmetrics`` is not installed.
            TypeError: If ``metrics_info`` fails :meth:`validate_metrics_info`.
        """
        try:
            import torch
            from torchmetrics import Metric
        except ImportError as err:
            # Chain the cause so the name of the missing package is kept.
            raise ImportError(
                "TorchMetricsBackend requires torch and torchmetrics to be installed."
            ) from err
        # The base __init__ is deliberately not called: it would run the
        # validator before self.Metric exists.
        self.metric_info = metrics_info
        self.torch = torch
        self.Metric = Metric
        self.validate_metrics_info(metrics_info)

    def validate_metrics_info(self, metrics_info):
        """Require a ``{str: torchmetrics.Metric}`` dict and return it.

        Raises:
            TypeError: If ``metrics_info`` is not a dict, a key is not a
                string, or a value is not a ``torchmetrics.Metric``.
        """
        if not isinstance(metrics_info, dict):
            raise TypeError(
                "TorchMetricsBackend requires metrics_info as a dict {name: MetricInstance}"
            )
        for k, v in metrics_info.items():
            if not isinstance(k, str):
                raise TypeError(f"Key '{k}' is not a string")
            if not isinstance(v, self.Metric):
                raise TypeError(f"Value for key '{k}' must be a torchmetrics.Metric")
        return metrics_info

    @staticmethod
    def _bind_update_arguments(metric, parameters, metric_data):
        """Map dict-shaped ``metric_data`` onto the parameters of ``update``.

        Returns a ``(positional, keywords)`` pair. Positional-only parameters
        are passed positionally, everything else by keyword. Parameters that
        have a default may be omitted from ``metric_data``. If ``update``
        accepts ``**kwargs``, entries that match no named parameter are
        forwarded as well.

        Raises:
            KeyError: If a parameter without a default has no entry.
        """
        positional = []
        keywords = {}
        accepts_extra_keywords = False
        consumed = set()
        for param in parameters:
            if param.kind is param.VAR_POSITIONAL:
                continue
            if param.kind is param.VAR_KEYWORD:
                accepts_extra_keywords = True
                continue
            if param.name not in metric_data:
                optional = param.default is not param.empty
                if optional and param.kind is not param.POSITIONAL_ONLY:
                    continue
                raise KeyError(
                    f"Data for metric '{metric}' is missing required "
                    f"argument '{param.name}'"
                )
            consumed.add(param.name)
            if param.kind is param.POSITIONAL_ONLY:
                positional.append(metric_data[param.name])
            else:
                keywords[param.name] = metric_data[param.name]
        if accepts_extra_keywords:
            for key, value in metric_data.items():
                if key not in consumed:
                    keywords[key] = value
        return positional, keywords

    def compute_metrics(self, metrics_info, metrics_data_dict, *args, **kwargs):
        """Feed each metric its data and return the computed values.

        For every ``name -> metric`` pair in ``metrics_info`` the entry
        ``metrics_data_dict[name]`` is passed to ``metric.update``:

        - a dict is matched against the parameter names of ``update``;
        - anything else is passed as the single argument when ``update``
          takes exactly one parameter, and unpacked with ``*`` otherwise.

        ``metric.compute()`` is then called. A tensor result is converted to
        a Python scalar when it has one element, and to a nested list
        otherwise.

        Args:
            metrics_info: Dict of metric name to metric instance.
            metrics_data_dict: Dict of metric name to the data for ``update``.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Dict of metric name to computed value.

        Raises:
            ValueError: If a metric has no entry in ``metrics_data_dict``.
            KeyError: If dict-shaped data lacks a required ``update`` argument.
        """
        out_dict = {}
        for metric, metric_instance in metrics_info.items():
            if metric not in metrics_data_dict:
                raise ValueError(f"Metric '{metric}' not found in provided data.")

            metric_data = metrics_data_dict[metric]
            sig = inspect.signature(metric_instance.update)
            expected_args = list(sig.parameters.values())

            # NOTE(review): the metric is not reset here, so whether repeated
            # calls accumulate depends on the metric object - confirm that
            # this is intended.
            if isinstance(metric_data, dict):
                positional, keywords = self._bind_update_arguments(
                    metric, expected_args, metric_data
                )
                metric_instance.update(*positional, **keywords)
            elif len(expected_args) == 1:
                # Single-parameter update: hand over the data object as is.
                metric_instance.update(metric_data)
            else:
                metric_instance.update(*metric_data)

            computed_value = metric_instance.compute()
            if isinstance(computed_value, self.torch.Tensor):
                computed_value = (
                    computed_value.item()
                    if computed_value.numel() == 1
                    else computed_value.tolist()
                )

            out_dict[metric] = computed_value
        return out_dict
|