halib 0.1.91__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. halib/__init__.py +12 -6
  2. halib/common/__init__.py +0 -0
  3. halib/common/common.py +207 -0
  4. halib/common/rich_color.py +285 -0
  5. halib/common.py +53 -10
  6. halib/exp/__init__.py +0 -0
  7. halib/exp/core/__init__.py +0 -0
  8. halib/exp/core/base_config.py +167 -0
  9. halib/exp/core/base_exp.py +147 -0
  10. halib/exp/core/param_gen.py +189 -0
  11. halib/exp/core/wandb_op.py +117 -0
  12. halib/exp/data/__init__.py +0 -0
  13. halib/exp/data/dataclass_util.py +41 -0
  14. halib/exp/data/dataset.py +208 -0
  15. halib/exp/data/torchloader.py +165 -0
  16. halib/exp/perf/__init__.py +0 -0
  17. halib/exp/perf/flop_calc.py +190 -0
  18. halib/exp/perf/gpu_mon.py +58 -0
  19. halib/exp/perf/perfcalc.py +440 -0
  20. halib/exp/perf/perfmetrics.py +137 -0
  21. halib/exp/perf/perftb.py +778 -0
  22. halib/exp/perf/profiler.py +507 -0
  23. halib/exp/viz/__init__.py +0 -0
  24. halib/exp/viz/plot.py +754 -0
  25. halib/filetype/csvfile.py +3 -9
  26. halib/filetype/ipynb.py +61 -0
  27. halib/filetype/jsonfile.py +0 -3
  28. halib/filetype/textfile.py +0 -1
  29. halib/filetype/videofile.py +119 -3
  30. halib/filetype/yamlfile.py +16 -1
  31. halib/online/projectmake.py +7 -6
  32. halib/online/tele_noti.py +165 -0
  33. halib/research/base_exp.py +75 -18
  34. halib/research/core/__init__.py +0 -0
  35. halib/research/core/base_config.py +144 -0
  36. halib/research/core/base_exp.py +157 -0
  37. halib/research/core/param_gen.py +108 -0
  38. halib/research/core/wandb_op.py +117 -0
  39. halib/research/data/__init__.py +0 -0
  40. halib/research/data/dataclass_util.py +41 -0
  41. halib/research/data/dataset.py +208 -0
  42. halib/research/data/torchloader.py +165 -0
  43. halib/research/dataset.py +6 -7
  44. halib/research/flop_csv.py +34 -0
  45. halib/research/flops.py +156 -0
  46. halib/research/metrics.py +4 -0
  47. halib/research/mics.py +59 -1
  48. halib/research/perf/__init__.py +0 -0
  49. halib/research/perf/flop_calc.py +190 -0
  50. halib/research/perf/gpu_mon.py +58 -0
  51. halib/research/perf/perfcalc.py +363 -0
  52. halib/research/perf/perfmetrics.py +137 -0
  53. halib/research/perf/perftb.py +778 -0
  54. halib/research/perf/profiler.py +301 -0
  55. halib/research/perfcalc.py +60 -35
  56. halib/research/perftb.py +2 -1
  57. halib/research/plot.py +480 -218
  58. halib/research/viz/__init__.py +0 -0
  59. halib/research/viz/plot.py +754 -0
  60. halib/system/_list_pc.csv +6 -0
  61. halib/system/filesys.py +60 -20
  62. halib/system/path.py +106 -0
  63. halib/utils/dict.py +9 -0
  64. halib/utils/list.py +12 -0
  65. halib/utils/video.py +6 -0
  66. halib-0.2.21.dist-info/METADATA +192 -0
  67. halib-0.2.21.dist-info/RECORD +109 -0
  68. halib-0.1.91.dist-info/METADATA +0 -201
  69. halib-0.1.91.dist-info/RECORD +0 -61
  70. {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/WHEEL +0 -0
  71. {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/licenses/LICENSE.txt +0 -0
  72. {halib-0.1.91.dist-info → halib-0.2.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,440 @@
+ import os
+ import glob
+ from typing import Optional, Tuple, List, Union
+ import pandas as pd
+
+ from abc import ABC, abstractmethod
+ from collections import OrderedDict
+
+
+ from ...common.common import now_str
+ from ...system import filesys as fs
+
+ from .perftb import PerfTB
+ from .perfmetrics import *
+
+
+ REQUIRED_COLS = ["experiment", "dataset"]
+ CSV_FILE_POSTFIX = "__perf"
+ METRIC_PREFIX = "metric_"
+
+
+ class PerfCalc(ABC):  # Abstract base class for performance calculation
+     @abstractmethod
+     def get_experiment_name(self) -> str:
+         """
+         Return the name of the experiment.
+         This function should be overridden by the subclass if needed.
+         """
+         pass
+
+     @abstractmethod
+     def get_dataset_name(self) -> str:
+         """
+         Return the name of the dataset.
+         This function should be overridden by the subclass if needed.
+         """
+         pass
+
+     @abstractmethod
+     def get_metric_backend(self) -> MetricsBackend:
+         """
+         Return the MetricsBackend used for performance calculation. The backend is
+         built from either a list of metric names or a dictionary mapping metric
+         names to torchmetrics.Metric instances, e.g. {"accuracy": Accuracy(), "precision": Precision()}.
+         """
+         pass
+
+     def valid_proc_extra_data(self, proc_extra_data):
+         # make sure that all items in proc_extra_data are dictionaries with the same keys
+         if proc_extra_data is None or len(proc_extra_data) == 0:
+             return
+         if not all(isinstance(item, dict) for item in proc_extra_data):
+             raise TypeError("All items in proc_extra_data must be dictionaries")
+
+         if not all(
+             item.keys() == proc_extra_data[0].keys() for item in proc_extra_data
+         ):
+             raise ValueError(
+                 "All dictionaries in proc_extra_data must have the same keys"
+             )
+
+     def valid_proc_metric_raw_data(self, metric_names, proc_metric_raw_data):
+         # make sure that all items in proc_metric_raw_data are dictionaries whose keys match metric_names
+         assert (
+             isinstance(proc_metric_raw_data, list) and len(proc_metric_raw_data) > 0
+         ), "raw_data_for_metrics must be a non-empty list of dictionaries"
+
+         if not all(isinstance(item, dict) for item in proc_metric_raw_data):
+             raise TypeError("All items in raw_data_for_metrics must be dictionaries")
+         if not all(
+             set(item.keys()) == set(metric_names) for item in proc_metric_raw_data
+         ):
+             raise ValueError(
+                 "All dictionaries in raw_data_for_metrics must have the same keys as metric_names"
+             )
+
+     # =========================================================================
+     # 1. Formatting Logic (Decoupled)
+     # =========================================================================
+     def package_metrics(
+         self,
+         metric_results_list: List[dict],
+         extra_data_list: Optional[List[dict]] = None,
+     ) -> List[OrderedDict]:
+         """
+         Pure formatting function.
+         Takes ALREADY CALCULATED metrics and formats them
+         (adds metadata, prefixes keys, ensures column order).
+         """
+         # Normalize extra_data to a list if provided
+         if extra_data_list is None:
+             extra_data_list = [{} for _ in range(len(metric_results_list))]
+         elif isinstance(extra_data_list, dict):
+             extra_data_list = [extra_data_list]
+
+         assert len(extra_data_list) == len(
+             metric_results_list
+         ), "Length mismatch: metrics vs extra_data"
+
+         proc_outdict_list = []
+
+         for metric_res, extra_item in zip(metric_results_list, extra_data_list):
+             # A. Base Metadata
+             out_dict = {
+                 "dataset": self.get_dataset_name(),
+                 "experiment": self.get_experiment_name(),
+             }
+
+             # B. Attach Extra Data
+             out_dict.update(extra_item)
+             custom_fields = list(extra_item.keys())
+
+             # C. Prefix Metric Keys (e.g., 'acc' -> 'metric_acc')
+             metric_results_prefixed = {f"metric_{k}": v for k, v in metric_res.items()}
+             out_dict.update(metric_results_prefixed)
+
+             # D. Order Columns
+             all_cols = (
+                 REQUIRED_COLS + custom_fields + list(metric_results_prefixed.keys())
+             )
+             ordered_out = OrderedDict(
+                 (col, out_dict[col]) for col in all_cols if col in out_dict
+             )
+             proc_outdict_list.append(ordered_out)
+
+         return proc_outdict_list
+
+     # =========================================================================
+     # 2. Calculation Logic (The Coordinator)
+     # =========================================================================
+     def calc_exp_perf_metrics(
+         self,
+         metric_names: List[str],
+         raw_metrics_data: Union[List[dict], dict],
+         extra_data: Optional[Union[List[dict], dict]] = None,
+         *args,
+         **kwargs,
+     ) -> List[OrderedDict]:
+         """
+         Full workflow: Validates raw data -> Calculates via Backend -> Packages results.
+         """
+         # Prepare Raw Data
+         raw_data_ls = (
+             raw_metrics_data
+             if isinstance(raw_metrics_data, list)
+             else [raw_metrics_data]
+         )
+         self.valid_proc_metric_raw_data(metric_names, raw_data_ls)
+
+         # Prepare Extra Data (Validation only)
+         extra_data_ls = None
+         if extra_data:
+             extra_data_ls = extra_data if isinstance(extra_data, list) else [extra_data]
+             self.valid_proc_extra_data(extra_data_ls)
+
+         # Calculate Metrics via Backend
+         metrics_backend = self.get_metric_backend()
+         calculated_results = []
+
+         for data_item in raw_data_ls:
+             res = metrics_backend.calc_metrics(data_item, *args, **kwargs)
+             calculated_results.append(res)
+
+         # Delegate to Formatting
+         return self.package_metrics(calculated_results, extra_data_ls)
+
+     # =========================================================================
+     # 3. File Saving Logic (Decoupled)
+     # =========================================================================
+     def save_results_to_csv(
+         self, out_dict_list: List[OrderedDict], **kwargs
+     ) -> Tuple[pd.DataFrame, Optional[str]]:
+         """
+         Helper function to convert results to DataFrame and save to CSV.
+         """
+         csv_outfile = kwargs.get("outfile", None)
+
+         # Determine Output Path
+         if csv_outfile is not None:
+             filePathNoExt, _ = os.path.splitext(csv_outfile)
+             csv_outfile = f"{filePathNoExt}{CSV_FILE_POSTFIX}.csv"
+         elif "outdir" in kwargs:
+             csvoutdir = kwargs["outdir"]
+             csvfilename = f"{now_str()}_{self.get_dataset_name()}_{self.get_experiment_name()}_{CSV_FILE_POSTFIX}.csv"
+             csv_outfile = os.path.join(csvoutdir, csvfilename)
+
+         # Convert to DataFrame
+         df = pd.DataFrame(out_dict_list)
+         if out_dict_list:
+             ordered_cols = list(out_dict_list[0].keys())
+             df = df[ordered_cols]
+
+         # Save to File
+         if csv_outfile:
+             df.to_csv(csv_outfile, index=False, sep=";", encoding="utf-8")
+
+         return df, csv_outfile
+
+     # =========================================================================
+     # 4. Public API: Standard Calculation
+     #    raw_metrics_data example: [{"preds": ..., "target": ...}, ...]
+     # =========================================================================
+     def calc_perfs(
+         self,
+         raw_metrics_data: Union[List[dict], dict],
+         extra_data: Optional[Union[List[dict], dict]] = None,
+         *args,
+         **kwargs,
+     ) -> Tuple[Union[List[OrderedDict], pd.DataFrame], Optional[str]]:
+         """
+         Standard use case: Calculate metrics AND save to CSV.
+         """
+         metric_names = self.get_metric_backend().metric_names
+
+         # 1. Calculate & Package
+         out_dict_list = self.calc_exp_perf_metrics(
+             metric_names,
+             raw_metrics_data,
+             extra_data,
+             *args,
+             **kwargs,
+         )
+
+         # 2. Save
+         df, csv_outfile = self.save_results_to_csv(out_dict_list, **kwargs)
+
+         return (
+             (df, csv_outfile)
+             if kwargs.get("return_df", False)
+             else (out_dict_list, csv_outfile)
+         )
+
+     # =========================================================================
+     # 5. Public API: Manual / External Metrics (The Shortcut)
+     # =========================================================================
+     def save_computed_perfs(
+         self,
+         metrics_data: Union[List[dict], dict],
+         extra_data: Optional[Union[List[dict], dict]] = None,
+         **kwargs,
+     ) -> Tuple[Union[List[OrderedDict], pd.DataFrame], Optional[str]]:
+
+         # Ensure list format
+         if isinstance(metrics_data, dict):
+             metrics_data = [metrics_data]
+         if isinstance(extra_data, dict):
+             extra_data = [extra_data]
+
+         # 1. Package (Format)
+         formatted_list = self.package_metrics(metrics_data, extra_data)
+
+         # 2. Save
+         df, csv_outfile = self.save_results_to_csv(formatted_list, **kwargs)
+
+         return (
+             (df, csv_outfile)
+             if kwargs.get("return_df", False)
+             else (formatted_list, csv_outfile)
+         )
+
+     @staticmethod
+     def default_exp_csv_filter_fn(exp_file_name: str) -> bool:
+         """
+         Default filter function for experiment CSV files.
+         Returns True if the file name looks like a performance CSV,
+         i.e. it contains the "__perf.csv" postfix.
+         """
+         return "__perf.csv" in exp_file_name
+
+     @classmethod
+     def get_perftb_for_multi_exps(
+         cls,
+         indir: str,
+         exp_csv_filter_fn=default_exp_csv_filter_fn,
+         include_file_name=False,
+         csv_sep=";",
+     ) -> PerfTB:
+         """
+         Generate a performance report by scanning experiment subdirectories.
+         Returns a PerfTB that aggregates the metrics of all experiments found.
+         """
+
+         def get_df_for_all_exp_perf(csv_perf_files, csv_sep=";"):
+             """
+             Create a single DataFrame from all CSV files.
+             Assumes the CSV files may report different sets of metrics.
+             """
+             cols = []
+             FILE_NAME_COL = "file_name" if include_file_name else None
+
+             for csv_file in csv_perf_files:
+                 temp_df = pd.read_csv(csv_file, sep=csv_sep)
+                 if FILE_NAME_COL:
+                     temp_df[FILE_NAME_COL] = fs.get_file_name(
+                         csv_file, split_file_ext=False
+                     )
+                 temp_df_cols = temp_df.columns.tolist()
+                 for col in temp_df_cols:
+                     if col not in cols:
+                         cols.append(col)
+
+             df = pd.DataFrame(columns=cols)
+             for csv_file in csv_perf_files:
+                 temp_df = pd.read_csv(csv_file, sep=csv_sep)
+                 if FILE_NAME_COL:
+                     temp_df[FILE_NAME_COL] = fs.get_file_name(
+                         csv_file, split_file_ext=False
+                     )
+                 # Drop all-NA columns to avoid dtype inconsistency
+                 temp_df = temp_df.dropna(axis=1, how="all")
+                 # ensure all columns are present in the final DataFrame
+                 for col in cols:
+                     if col not in temp_df.columns:
+                         temp_df[col] = None  # fill missing columns with None
+                 df = pd.concat([df, temp_df], ignore_index=True)
+
+             # assert that the required columns are present in the DataFrame
+             sticky_cols = REQUIRED_COLS + (
+                 [FILE_NAME_COL] if include_file_name else []
+             )  # columns that must always be present
+             for col in sticky_cols:
+                 if col not in df.columns:
+                     raise ValueError(
+                         f"Required column '{col}' is missing from the DataFrame. REQUIRED_COLS = {sticky_cols}"
+                     )
+             metric_cols = [col for col in df.columns if col.startswith(METRIC_PREFIX)]
+             assert (
+                 len(metric_cols) > 0
+             ), "No metric columns found in the DataFrame. Ensure that the CSV files contain metric columns starting with 'metric_'."
+             final_cols = sticky_cols + metric_cols
+             df = df[final_cols]
+
+             # validate all rows before returning: every row must have values for
+             # the sticky columns and at least one metric value
+             for index, row in df.iterrows():
+                 if not all(col in row and pd.notna(row[col]) for col in sticky_cols):
+                     raise ValueError(
+                         f"Row {index} is missing required columns or has NaN values in required columns: {row}"
+                     )
+                 if not any(pd.notna(row[col]) for col in metric_cols):
+                     raise ValueError(f"Row {index} has no metric values: {row}")
+             # make sure there is no duplicated (experiment, dataset) pair
+             duplicates = df.duplicated(subset=sticky_cols, keep=False)
+             if duplicates.any():
+                 raise ValueError(
+                     "Duplicate (experiment, dataset) pairs found in the DataFrame. Please ensure that each experiment-dataset combination is unique."
+                 )
+             return df
+
+         def mk_perftb_report(df):
+             """
+             Create a performance report table from the DataFrame.
+             This function should be customized based on the specific requirements of the report.
+             """
+             perftb = PerfTB()
+             # find all unique "dataset" values
+             dataset_names = list(df["dataset"].unique())
+             # find all columns that start with METRIC_PREFIX
+             metric_cols = [col for col in df.columns if col.startswith(METRIC_PREFIX)]
+
+             # Determine which metrics are associated with each dataset.
+             # A dataset may appear in multiple rows, and not every row reports all
+             # metrics, so pick the row for that dataset with the most non-NaN metric
+             # values; its non-NaN metrics define the metric set for the dataset.
+             dataset_metrics = {}
+             for dataset_name in dataset_names:
+                 dataset_rows = df[df["dataset"] == dataset_name]
+                 # Find the row with the most non-NaN metric values
+                 max_non_nan_row = dataset_rows[metric_cols].count(axis=1).idxmax()
+                 metrics_for_dataset = (
+                     dataset_rows.loc[max_non_nan_row, metric_cols]
+                     .dropna()
+                     .index.tolist()
+                 )
+                 dataset_metrics[dataset_name] = metrics_for_dataset
+
+             for dataset_name, metrics in dataset_metrics.items():
+                 # Register the dataset and its metric columns in the performance table
+                 perftb.add_dataset(dataset_name, metrics)
+
+             for _, row in df.iterrows():
+                 dataset_name = row["dataset"]
+                 ds_metrics = dataset_metrics.get(dataset_name)
+                 if dataset_name in dataset_metrics:
+                     # Add the metrics for this row to the performance table
+                     exp_name = row.get("experiment")
+                     exp_metric_values = {}
+                     for metric in ds_metrics:
+                         if metric in row and pd.notna(row[metric]):
+                             exp_metric_values[metric] = row[metric]
+                     perftb.add_experiment(
+                         experiment_name=exp_name,
+                         dataset_name=dataset_name,
+                         metrics=exp_metric_values,
+                     )
+
+             return perftb
+
+         assert os.path.exists(indir), f"Input directory {indir} does not exist."
+
+         csv_perf_files = []
+         # Find experiment subdirectories
+         exp_dirs = [
+             os.path.join(indir, d)
+             for d in os.listdir(indir)
+             if os.path.isdir(os.path.join(indir, d))
+         ]
+         if len(exp_dirs) == 0:
+             csv_perf_files = glob.glob(os.path.join(indir, "*.csv"))
+             csv_perf_files = [
+                 file_item
+                 for file_item in csv_perf_files
+                 if exp_csv_filter_fn(file_item)
+             ]
+         else:
+             # multiple experiment directories found;
+             # collect all matching CSV files in those subdirs
+             for exp_dir in exp_dirs:
+                 matched = glob.glob(os.path.join(exp_dir, "*.csv"))
+                 matched = [
+                     file_item for file_item in matched if exp_csv_filter_fn(file_item)
+                 ]
+                 csv_perf_files.extend(matched)
+
+         assert (
+             len(csv_perf_files) > 0
+         ), f"No CSV files matching filter '{exp_csv_filter_fn}' found in the experiment directories."
+
+         all_exp_perf_df = get_df_for_all_exp_perf(csv_perf_files, csv_sep=csv_sep)
+         perf_tb = mk_perftb_report(all_exp_perf_df)
+         return perf_tb
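The PerfCalc class above is abstract: a concrete subclass only names the experiment and dataset and supplies a metrics backend, and calc_perfs then validates the raw-data keys against the backend's metric_names, computes each metric, prefixes the result columns with metric_, and writes a semicolon-separated CSV that get_perftb_for_multi_exps can later aggregate into a PerfTB. Below is a minimal usage sketch, not part of the package, assuming this +440 hunk is halib/exp/perf/perfcalc.py (the only +440 entry in the file list) and that torch and torchmetrics are installed; the class name MyExpPerf, the experiment/dataset labels, and the output directory are all illustrative.

import torch
from torchmetrics import Accuracy, Precision

from halib.exp.perf.perfcalc import PerfCalc          # assumed import path
from halib.exp.perf.perfmetrics import TorchMetricsBackend


class MyExpPerf(PerfCalc):  # hypothetical subclass for illustration
    def get_experiment_name(self) -> str:
        return "baseline_v1"  # illustrative experiment label

    def get_dataset_name(self) -> str:
        return "cifar10"  # illustrative dataset label

    def get_metric_backend(self) -> TorchMetricsBackend:
        # dict form: metric name -> torchmetrics.Metric instance
        return TorchMetricsBackend(
            {
                "accuracy": Accuracy(task="multiclass", num_classes=10),
                "precision": Precision(task="multiclass", num_classes=10),
            }
        )


preds = torch.randint(0, 10, (64,))
target = torch.randint(0, 10, (64,))

# One dict per evaluation run, keyed by metric name; each value is fed to Metric.update().
raw = {"accuracy": (preds, target), "precision": (preds, target)}

rows, csv_path = MyExpPerf().calc_perfs(
    raw_metrics_data=raw,
    extra_data={"epoch": 10},
    outdir="results",  # existing directory; a *__perf.csv file is written here
)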
@@ -0,0 +1,137 @@
+ # -------------------------------
+ # Metrics Backend Interface
+ # -------------------------------
+ import inspect
+ from typing import Dict, Union, List, Any
+ from abc import ABC, abstractmethod
+
+
+ class MetricsBackend(ABC):
+     """Interface for pluggable metrics computation backends."""
+
+     def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
+         """
+         Initialize the backend with metrics_info, which can be either:
+         - a list of metric names (strings), e.g. ["accuracy", "precision"], or
+         - a dict mapping metric names to objects that define how to compute them,
+           e.g. {"accuracy": torchmetrics.Accuracy(), "precision": torchmetrics.Precision()}.
+         """
+         self.metric_info = metrics_info
+         self.validate_metrics_info(self.metric_info)
+
+     @property
+     def metric_names(self) -> List[str]:
+         """
+         Return a list of metric names.
+         If metric_info is a dict, return its keys; if it's a list, return it directly.
+         """
+         if isinstance(self.metric_info, dict):
+             return list(self.metric_info.keys())
+         elif isinstance(self.metric_info, list):
+             return self.metric_info
+         else:
+             raise TypeError("metric_info must be a list or a dict")
+
+     def validate_metrics_info(self, metrics_info):
+         if isinstance(metrics_info, list):
+             return metrics_info
+         elif isinstance(metrics_info, dict):
+             return {k: v for k, v in metrics_info.items() if isinstance(k, str)}
+         else:
+             raise TypeError(
+                 "metrics_info must be a list of strings or a dict with string keys"
+             )
+
+     @abstractmethod
+     def compute_metrics(
+         self,
+         metrics_info: Union[List[str], Dict[str, Any]],
+         metrics_data_dict: Dict[str, Any],
+         *args,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         pass
+
+     def prepare_metrics_backend_data(self, raw_metric_data, *args, **kwargs):
+         """
+         Prepare the data for the metrics backend.
+         This method can be overridden by subclasses to customize data preparation.
+         """
+         return raw_metric_data
+
+     def calc_metrics(
+         self, metrics_data_dict: Dict[str, Any], *args, **kwargs
+     ) -> Dict[str, Any]:
+         """
+         Calculate metrics for the provided data.
+         Subclasses implement the actual computation in `compute_metrics`.
+         """
+         # prevalidate the metrics_data_dict
+         for metric in self.metric_names:
+             if metric not in metrics_data_dict:
+                 raise ValueError(f"Metric '{metric}' not found in provided data.")
+         # Prepare the data for the backend
+         metrics_data_dict = self.prepare_metrics_backend_data(
+             metrics_data_dict, *args, **kwargs
+         )
+         # Call the abstract method to compute metrics
+         return self.compute_metrics(
+             self.metric_info, metrics_data_dict, *args, **kwargs
+         )
+
+
+ class TorchMetricsBackend(MetricsBackend):
+     """TorchMetrics-based backend implementation."""
+
+     def __init__(self, metrics_info: Union[List[str], Dict[str, Any]]):
+         try:
+             import torch
+             from torchmetrics import Metric
+         except ImportError:
+             raise ImportError(
+                 "TorchMetricsBackend requires torch and torchmetrics to be installed."
+             )
+         self.metric_info = metrics_info
+         self.torch = torch
+         self.Metric = Metric
+         self.validate_metrics_info(metrics_info)
+
+     def validate_metrics_info(self, metrics_info):
+         if not isinstance(metrics_info, dict):
+             raise TypeError(
+                 "TorchMetricsBackend requires metrics_info as a dict {name: MetricInstance}"
+             )
+         for k, v in metrics_info.items():
+             if not isinstance(k, str):
+                 raise TypeError(f"Key '{k}' is not a string")
+             if not isinstance(v, self.Metric):
+                 raise TypeError(f"Value for key '{k}' must be a torchmetrics.Metric")
+         return metrics_info
+
+     def compute_metrics(self, metrics_info, metrics_data_dict, *args, **kwargs):
+         out_dict = {}
+         for metric, metric_instance in metrics_info.items():
+             if metric not in metrics_data_dict:
+                 raise ValueError(f"Metric '{metric}' not found in provided data.")
+
+             metric_data = metrics_data_dict[metric]
+             # Inspect the metric's update() signature to decide how to pass the data
+             sig = inspect.signature(metric_instance.update)
+             expected_args = list(sig.parameters.values())
+
+             if isinstance(metric_data, dict):
+                 update_args = [metric_data[param.name] for param in expected_args]
+             elif isinstance(metric_data, (list, tuple)):
+                 update_args = metric_data
+             else:
+                 update_args = metric_data
+             if len(expected_args) == 1:
+                 metric_instance.update(update_args)
+             else:
+                 metric_instance.update(*update_args)
+
+             computed_value = metric_instance.compute()
+             if isinstance(computed_value, self.torch.Tensor):
+                 computed_value = (
+                     computed_value.item()
+                     if computed_value.numel() == 1
+                     else computed_value.tolist()
+                 )
+
+             out_dict[metric] = computed_value
+         return out_dict
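MetricsBackend is the pluggable seam here: calc_metrics pre-validates the data keys, calls prepare_metrics_backend_data, then delegates to the abstract compute_metrics. As a sketch of how a non-torchmetrics backend might plug in (hypothetical code, not part of halib; the CallableMetricsBackend name and the accuracy helper are illustrative, and the import path assumes the halib/exp/perf/ copy of this +137 file, which also appears under halib/research/perf/):

from halib.exp.perf.perfmetrics import MetricsBackend  # assumed import path


class CallableMetricsBackend(MetricsBackend):
    """Backend where metrics_info maps metric names to plain Python callables."""

    def compute_metrics(self, metrics_info, metrics_data_dict, *args, **kwargs):
        out = {}
        for name, fn in metrics_info.items():
            data = metrics_data_dict[name]
            # dict data is unpacked as keyword arguments, anything else positionally
            out[name] = fn(**data) if isinstance(data, dict) else fn(*data)
        return out


def accuracy(preds, target):
    # simple fraction of matching predictions
    correct = sum(p == t for p, t in zip(preds, target))
    return correct / len(target)


backend = CallableMetricsBackend({"accuracy": accuracy})
result = backend.calc_metrics(
    {"accuracy": {"preds": [1, 0, 1, 1], "target": [1, 0, 0, 1]}}
)
print(result)  # {'accuracy': 0.75}

Because validate_metrics_info in the base class only requires string keys, a dict of plain callables passes validation, while TorchMetricsBackend tightens the same hook to require torchmetrics.Metric values.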