halib 0.1.50__py3-none-any.whl → 0.1.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- halib/research/perftb.py +757 -0
- {halib-0.1.50.dist-info → halib-0.1.51.dist-info}/METADATA +5 -1
- {halib-0.1.50.dist-info → halib-0.1.51.dist-info}/RECORD +6 -5
- {halib-0.1.50.dist-info → halib-0.1.51.dist-info}/LICENSE.txt +0 -0
- {halib-0.1.50.dist-info → halib-0.1.51.dist-info}/WHEEL +0 -0
- {halib-0.1.50.dist-info → halib-0.1.51.dist-info}/top_level.txt +0 -0
halib/research/perftb.py
ADDED
@@ -0,0 +1,757 @@
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import os  # used by PerfTB.plot to save/open figures
import random
import itertools
import plotly.express as px
import plotly.graph_objects as go
from collections import defaultdict
from plotly.subplots import make_subplots
from typing import Dict, List, Union, Optional

from rich.pretty import pprint
import pandas as pd

# from halib import *
# internal imports
from ..filetype import csvfile
from ..common import ConsoleLog

class DatasetMetrics:
    """Class to store metrics definitions for a specific dataset."""

    def __init__(self, dataset_name: str, metric_names: List[str]):
        self.dataset_name = dataset_name
        self.metric_names = set(metric_names)  # Unique metric names
        self.experiment_results: Dict[str, Dict[str, Union[float, int, None]]] = (
            defaultdict(dict)
        )

    def add_experiment_result(
        self, experiment_name: str, metrics: Dict[str, Union[float, int]]
    ) -> None:
        """Add experiment results for this dataset, only for defined metrics."""
        # normalize metric names to lowercase
        metrics = {k.lower(): v for k, v in metrics.items()}
        # make sure every metric in metrics is defined for this dataset
        for metric in metrics:
            assert metric in self.metric_names, (
                f"Metric <<{metric}>> not defined for dataset <<{self.dataset_name}>>. "
                f"Available metrics: {self.metric_names}"
            )
        for metric in self.metric_names:
            self.experiment_results[experiment_name][metric] = metrics.get(metric)

    def get_metrics(self, experiment_name: str) -> Dict[str, Union[float, int, None]]:
        """Retrieve metrics for a specific experiment."""
        return self.experiment_results.get(
            experiment_name, {metric: None for metric in self.metric_names}
        )

    def __str__(self) -> str:
        return f"Dataset: {self.dataset_name}, Metrics: {', '.join(self.metric_names)}"


class PerfTB:
    """Class to manage performance table data with datasets as primary structure."""

    def __init__(self):
        # Dictionary of dataset_name -> DatasetMetrics
        self.datasets: Dict[str, DatasetMetrics] = {}
        self.experiments: set = set()

    def add_dataset(self, dataset_name: str, metric_names: List[str]) -> None:
        """
        Add a new dataset with its associated metrics.

        Args:
            dataset_name: Name of the dataset
            metric_names: List of metric names for this dataset
        """
        # normalize metric names to lowercase
        metric_names = [metric.lower() for metric in metric_names]
        self.datasets[dataset_name] = DatasetMetrics(dataset_name, metric_names)

    def table_meta(self):
        """
        Return metadata about the performance table.
        """
        return {
            "num_datasets": len(self.datasets),
            "num_experiments": len(self.experiments),
            "datasets_metrics": {
                dataset_name: dataset.metric_names
                for dataset_name, dataset in self.datasets.items()
            }
        }

    def add_experiment(
        self,
        experiment_name: str,
        dataset_name: str,
        metrics: Dict[str, Union[float, int]],
    ) -> None:
        """
        Add experiment results for a specific dataset.

        Args:
            experiment_name: Name or identifier of the experiment
            dataset_name: Name of the dataset
            metrics: Dictionary of metric names and their values
        """
        # normalize metric names to lowercase
        metrics = {k.lower(): v for k, v in metrics.items()}
        if dataset_name not in self.datasets:
            raise ValueError(
                f"Dataset <<{dataset_name}>> not defined. Add dataset first."
            )
        self.experiments.add(experiment_name)
        self.datasets[dataset_name].add_experiment_result(experiment_name, metrics)

    def get_metrics_for_dataset(
        self, experiment_name: str, dataset_name: str
    ) -> Optional[Dict[str, Union[float, int, None]]]:
        """
        Retrieve performance metrics for a specific dataset and experiment.

        Args:
            experiment_name: Name or identifier of the experiment
            dataset_name: Name of the dataset

        Returns:
            Dictionary of metrics or None if dataset not found
        """
        dataset = self.datasets.get(dataset_name)
        if dataset:
            return dataset.get_metrics(experiment_name)
        return None

    def get_all_experiments(self) -> List[str]:
        """Return list of all experiment names."""
        return sorted(self.experiments)

    def get_all_datasets(self) -> List[str]:
        """Return list of all dataset names."""
        return sorted(self.datasets.keys())

    def to_dataframe(self) -> pd.DataFrame:
        """
        Convert the performance table to a pandas DataFrame with MultiIndex columns.
        Level 1: Datasets
        Level 2: Metrics

        Returns:
            pandas DataFrame with experiments as rows and (dataset, metric) as columns
        """
        # Create MultiIndex for columns (dataset, metric)
        columns = []
        for dataset_name in self.get_all_datasets():
            for metric in sorted(self.datasets[dataset_name].metric_names):
                columns.append((dataset_name, metric))
        columns = pd.MultiIndex.from_tuples(columns, names=["Dataset", "Metric"])

        # Initialize DataFrame with experiments as index
        df = pd.DataFrame(index=sorted(self.experiments), columns=columns)

        # Populate DataFrame
        for exp in self.experiments:
            for dataset_name in self.datasets:
                metrics = self.datasets[dataset_name].get_metrics(exp)
                for metric, value in metrics.items():
                    df.loc[exp, (dataset_name, metric)] = value

        return df

    def plot(
        self,
        save_path: str,
        title: Optional[str] = None,
        custom_highlight_method_fn: Optional[callable] = None,
        custom_sort_exp_fn: Optional[
            callable
        ] = None,  # function to sort experiments; accepts a list of experiment names and returns a sorted list
        open_plot: bool = False,
        show_raw_df: bool = False,
        experiment_names: Optional[List[str]] = None,
        datasets: Optional[List[str]] = None,
        height: int = 400,
        width: int = 700,
    ) -> Optional[go.Figure]:
        """
        Plot comparison of experiments across datasets and their metrics using Plotly.
        One subplot row is created per metric.

        Args:
            save_path: File path to save the figure
            title: Plot title (a default is used if None)
            custom_highlight_method_fn: Optional predicate; experiments for which it returns True are drawn with a hatch pattern
            custom_sort_exp_fn: Optional function to sort experiment names
            open_plot: If True, attempts to open the saved image file (Windows only)
            show_raw_df: If True, displays the underlying long-format DataFrame
            experiment_names: List of experiments to compare (default: all)
            datasets: List of datasets to include (default: all)
            height: Base height of the plot (scaled by the number of metric rows)
            width: Width of the plot
        """
        experiment_names = experiment_names or self.get_all_experiments()
        datasets = datasets or self.get_all_datasets()

        records = []

        for dataset in datasets:
            if dataset not in self.datasets:
                print(f"Warning: Dataset '{dataset}' not found. Skipping...")
                continue

            metric_names = sorted(self.datasets[dataset].metric_names)
            for exp in experiment_names:
                metric_values = self.get_metrics_for_dataset(exp, dataset)
                if not metric_values:
                    continue
                for metric in metric_names:
                    value = metric_values.get(metric)
                    if value is not None:
                        records.append(
                            {
                                "Experiment": exp,
                                "Dataset": dataset,
                                "Metric": metric,
                                "Value": value,
                            }
                        )

        if not records:
            print("No data found to plot.")
            return

        df = pd.DataFrame(records)
        if show_raw_df:
            with ConsoleLog("PerfTB DF"):
                csvfile.fn_display_df(df)

        metric_list = df["Metric"].unique()
        fig = make_subplots(
            rows=len(metric_list),
            cols=1,
            shared_xaxes=False,
            subplot_titles=metric_list,
            vertical_spacing=0.1,
        )

        unique_experiments = df["Experiment"].unique()
        color_cycle = itertools.cycle(px.colors.qualitative.Plotly)

        color_map = {
            exp: color
            for exp, color in zip(unique_experiments, color_cycle)
        }

        pattern_shapes = ["x", "-", "/", "\\", "|", "+", "."]
        pattern_color = "black"  # Color for patterns

        current_our_method = -1  # Start with -1 to avoid index error
        exp_pattern_dict = {}
        for row_idx, metric in enumerate(metric_list, start=1):
            metric_df = df[df["Metric"] == metric]
            list_exp = list(metric_df["Experiment"].unique())
            if custom_sort_exp_fn:
                list_exp = custom_sort_exp_fn(list_exp)
            for exp in list_exp:
                should_highlight = (
                    custom_highlight_method_fn is not None
                    and custom_highlight_method_fn(exp)
                )
                pattern_shape = ""  # default no pattern
                if should_highlight and exp not in exp_pattern_dict:
                    current_our_method += 1
                    pattern_shape = pattern_shapes[
                        current_our_method % len(pattern_shapes)
                    ]
                    exp_pattern_dict[exp] = pattern_shape
                elif exp in exp_pattern_dict:
                    pattern_shape = exp_pattern_dict[exp]
                exp_df = metric_df[metric_df["Experiment"] == exp]
                fig.add_trace(
                    go.Bar(
                        x=exp_df["Dataset"],
                        y=exp_df["Value"],
                        name=f"{exp}",
                        legendgroup=exp,
                        showlegend=(row_idx == 1),  # Show legend only for the first row
                        marker=dict(
                            color=color_map[exp],
                            pattern=(
                                dict(shape=pattern_shape, fgcolor=pattern_color)
                                if pattern_shape
                                else None
                            ),
                        ),
                        text=[f"{v:.5f}" for v in exp_df["Value"]],
                        textposition="auto",  # position value labels automatically
                    ),
                    row=row_idx,
                    col=1,
                )

        # Manage layout
        if title is None:
            title = "Experiment Comparison by Metric Groups"
        fig.update_layout(
            height=height * len(metric_list),
            width=width,
            title_text=title,
            barmode="group",
            showlegend=True,
        )

        # Save and open plot
        if save_path:
            fig.write_image(save_path, engine="kaleido")
            print(f"Saved: {os.path.abspath(save_path)}")

            if open_plot and os.name == "nt":  # Windows
                os.system(f'start "" "{os.path.abspath(save_path)}"')
        return fig

    def to_csv(self, outfile: str, sep=";", condensed_multiindex: bool = True) -> None:
        """
        Save the performance table to a CSV file.

        Args:
            outfile: Path to the output CSV file
            sep: Separator used in the CSV file
            condensed_multiindex: If True, repeated dataset names in the header are blanked after their first appearance
        """
        df = self.to_dataframe()
        if condensed_multiindex:
            # Extract levels
            level0 = df.columns.get_level_values(0)
            level1 = df.columns.get_level_values(1)

            # Build new level0 with blanks after first appearance
            new_level0 = []
            prev = None
            for val in level0:
                if val == prev:
                    new_level0.append("")
                else:
                    new_level0.append(val)
                    prev = val

            df.columns = pd.MultiIndex.from_arrays([new_level0, level1])

        # Write to CSV
        df.to_csv(outfile, index=True, sep=sep)

    def display(self) -> None:
        """
        Display the performance table as a DataFrame.
        """
        df = self.to_dataframe()
        csvfile.fn_display_df(df)

    @classmethod
    def _read_condensed_multiindex_csv(
        cls, filepath: str, sep=";", col_exclude_fn: Optional[callable] = None
    ) -> pd.DataFrame:
        # Read first two header rows
        df = pd.read_csv(filepath, header=[0, 1], sep=sep)
        # Extract levels
        level0 = df.columns.get_level_values(0)
        level1 = df.columns.get_level_values(1)
        # pprint(f'{level0=}')
        # pprint(f'{level1=}')
        # if blank values in level0, fill them in after the first appearance
        new_level0 = []
        last_non_blank = level0[0]  # Start with the first value
        assert last_non_blank != "", (
            "First level0 value should not be blank. "
            "Check the CSV file format."
        )
        for val in level0:
            if val == "" or "Unnamed: " in val:
                new_level0.append(last_non_blank)
            else:
                new_level0.append(val)
                last_non_blank = val
        # pprint(new_level0)
        # Rebuild MultiIndex
        excluded_indices = []
        if col_exclude_fn:
            for idx, val in enumerate(new_level0):
                if col_exclude_fn(val):
                    excluded_indices.append(idx)
            for idx, val in enumerate(level1):
                if col_exclude_fn(val):
                    excluded_indices.append(idx)
            excluded_indices = list(set(excluded_indices))

        num_prev_cols = len(new_level0)
        # Remove excluded indices from both levels
        new_level0 = [
            val for idx, val in enumerate(new_level0) if idx not in excluded_indices
        ]
        new_level1 = [
            val for idx, val in enumerate(level1) if idx not in excluded_indices
        ]
        num_after_cols = len(new_level0)
        if num_prev_cols != num_after_cols:
            # get df with only the non-excluded columns
            df = df.iloc[
                :, [i for i in range(len(df.columns)) if i not in excluded_indices]
            ]

        df.columns = pd.MultiIndex.from_arrays([new_level0, new_level1])
        return df

    @classmethod
    def from_dataframe(
        cls,
        df: pd.DataFrame
    ) -> "PerfTB":
        """
        Load performance table from a DataFrame.

        Args:
            df: Input DataFrame; the first column holds the experiment (method) names,
                and the remaining columns form a (dataset, metric) MultiIndex
        """
        # console.log('--- PerfTB.from_dataframe ---')
        # csvfile.fn_display_df(df)
        cls_instance = cls()
        # first loop through MultiIndex columns and extract datasets with their metrics
        dataset_metrics = {}
        for (dataset_name, metric_name) in df.columns[1:]:
            if dataset_name not in dataset_metrics:
                dataset_metrics[dataset_name] = []
            dataset_metrics[dataset_name].append(metric_name)
        for dataset_name, metric_names in dataset_metrics.items():
            cls_instance.add_dataset(dataset_name, metric_names)

        def safe_cast(val):
            try:
                return float(val)
            except (ValueError, TypeError):
                return None

        for _, row in df.iterrows():
            # Extract experiment name from first column
            experiment_name = row.iloc[0]
            # Iterate over MultiIndex columns (except first column)
            metrics = {}
            for dataset_name in dataset_metrics.keys():
                for metric_name in dataset_metrics[dataset_name]:
                    # Get the value for this dataset and metric
                    value = row[(dataset_name, metric_name)]
                    # Cast to float or None if not applicable
                    metrics[metric_name] = safe_cast(value)

                cls_instance.add_experiment(
                    experiment_name=experiment_name,
                    dataset_name=dataset_name,
                    metrics=metrics,
                )

        return cls_instance

    @classmethod
    def from_csv(
        cls,
        csv_file: str,
        sep: str = ";",
        col_exclude_fn: Optional[callable] = None,
    ) -> "PerfTB":
        """
        Load performance table from a CSV file.

        Args:
            csv_file: Path to the CSV file
            sep: Separator used in the CSV file
            col_exclude_fn: Optional predicate; columns whose name matches are dropped
        """
        df = cls._read_condensed_multiindex_csv(
            csv_file, sep=sep, col_exclude_fn=col_exclude_fn
        )
        return cls.from_dataframe(df)

    def filter_index_info(self):
        """
        Return index metadata used by `filter()`: for each dataset, its positional
        index and its (metric_name, index) pairs.
        """
        datasets_metrics = {
            dataset_name: dataset.metric_names
            for dataset_name, dataset in self.datasets.items()
        }
        meta_dict = {}
        for i, (dataset_name, metrics) in enumerate(datasets_metrics.items()):
            sorted_metrics = sorted(metrics)  # keep the output deterministic
            meta_dict[dataset_name] = {
                "index": i,
                "metrics": sorted(
                    list(zip(sorted_metrics, range(len(sorted_metrics))))
                ),  # (metric_name, index)
            }
        return meta_dict

    def filter(
        self,
        dataset_list: List[Union[str, int]] = None,  # list of strings or integers
        metrics_list: List[Union[list, str]] = None,
        experiment_list: List[str] = None,
    ) -> "PerfTB":
        """
        Filter the performance table by datasets, metrics, and experiments.
        Returns a new PerfTB instance with the filtered data.

        Args:
            dataset_list: List of dataset names or indices to filter (optional)
            metrics_list: List of metric names to filter (optional). A list of lists
                (one list of metric names per dataset) can also be passed to filter
                each dataset by a different set of metrics; a single flat list is
                applied to all datasets.
            experiment_list: List of experiment NAMES (strings) to filter (optional).
                Indices are not supported.
        """
        meta_filter_dict = self.filter_index_info()

        if experiment_list is None:
            experiment_list = self.get_all_experiments()
        else:
            # make sure all experiments are found in the performance table
            for exp in experiment_list:
                if exp not in self.experiments:
                    raise ValueError(
                        f"Experiment <<{exp}>> not found in the performance table. "
                        f"Available experiments: {self.get_all_experiments()}"
                    )
        # pprint(f"Filtering experiments: {experiment_list}")
        # get dataset list
        if dataset_list is not None:
            # if all items in dataset_list are integers, convert them to dataset names
            if all(
                isinstance(item, int) and 0 <= item < len(meta_filter_dict)
                for item in dataset_list
            ):
                dataset_list = [
                    list(meta_filter_dict.keys())[item] for item in dataset_list
                ]
            elif all(isinstance(item, str) for item in dataset_list):
                # if all items are strings, use them as dataset names
                dataset_list = [
                    item for item in dataset_list if item in meta_filter_dict
                ]
            else:
                raise ValueError(
                    f"dataset_list should be a list of strings (dataset names) or "
                    f"integers (indices, should be <= {len(meta_filter_dict) - 1}). "
                    f"Got: {dataset_list}"
                )
        else:
            dataset_list = self.get_all_datasets()

        filter_metrics_ls = []  # [list_metric_db_A, list_metric_db_B, ...]
        all_ds_metrics = []
        for dataset_name in dataset_list:
            ds_meta = meta_filter_dict.get(dataset_name, None)
            if ds_meta:
                ds_metrics = ds_meta["metrics"]
                all_ds_metrics.append([metric[0] for metric in ds_metrics])

        if metrics_list is None:
            filter_metrics_ls = all_ds_metrics
        elif isinstance(metrics_list, list):
            all_string = all(isinstance(item, str) for item in metrics_list)
            if all_string:
                # normalize metrics_list to lowercase
                metrics_list = [metric.lower() for metric in metrics_list]
                filter_metrics_ls = [metrics_list] * len(dataset_list)
            else:
                all_list = all(isinstance(item, list) for item in metrics_list)
                pprint(f'{all_list=}')
                if all_list:
                    print('b')
                    if len(metrics_list) != len(dataset_list):
                        raise ValueError(
                            f"metrics_list should be a list of strings (metric names) "
                            f"or a list of lists of metric names for each dataset. "
                            f"Got: {len(metrics_list)} metrics for {len(dataset_list)} datasets."
                        )
                    # normalize each list of metrics to lowercase
                    filter_metrics_ls = [
                        [metric.lower() for metric in item] for item in metrics_list
                    ]
                else:
                    raise ValueError(
                        f"metrics_list should be a list of strings (metric names) or "
                        f"a list of lists of metric names for each dataset. Got: {metrics_list}"
                    )

        # make sure that all metrics in filter_metrics_ls are valid for the datasets
        final_metrics_list = []
        for idx, dataset_name in enumerate(dataset_list):
            valid_metrics_list = all_ds_metrics[idx]
            current_metrics = filter_metrics_ls[idx]
            new_valid_ds_metrics = []
            for metric in current_metrics:
                if metric in valid_metrics_list:
                    new_valid_ds_metrics.append(metric)
            assert len(new_valid_ds_metrics) > 0, (
                f"No valid metrics found for dataset <<{dataset_name}>>. "
                f"Available metrics: {valid_metrics_list}. "
                f"Filtered metrics: {current_metrics}"
            )
            final_metrics_list.append(new_valid_ds_metrics)

        assert len(experiment_list) > 0, "No experiments to filter."
        assert len(dataset_list) > 0, "No datasets to filter."
        assert len(final_metrics_list) > 0, "No metrics to filter."
        filtered_tb = PerfTB()
        for db, metrics in zip(dataset_list, final_metrics_list):
            # add dataset with its metrics
            filtered_tb.add_dataset(db, metrics)

        # now add experiments with their metrics
        for exp in experiment_list:
            for db, metrics in zip(dataset_list, final_metrics_list):
                # get metrics for this experiment and dataset
                metrics_dict = self.get_metrics_for_dataset(exp, db)
                if metrics_dict:
                    # filter metrics to only those defined for this dataset
                    filtered_metrics = {
                        k: v for k, v in metrics_dict.items() if k in metrics
                    }
                    if filtered_metrics:
                        filtered_tb.add_experiment(exp, db, filtered_metrics)

        return filtered_tb


def test_perftb_create() -> PerfTB:
    # Create a performance table
    perf_table = PerfTB()

    # Define datasets and their metrics first
    perf_table.add_dataset("dataset1", ["accuracy", "f1_score"])
    perf_table.add_dataset("dataset2", ["accuracy", "f1_score", "precision"])

    # Add experiment results
    perf_table.add_experiment(
        experiment_name="our_method1",
        dataset_name="dataset1",
        metrics={"accuracy": 100, "f1_score": 0.93},
    )
    perf_table.add_experiment(
        experiment_name="our_method2",
        dataset_name="dataset2",
        metrics={"accuracy": 100, "precision": 0.87},  # missing f1_score will be None
    )
    perf_table.add_experiment(
        experiment_name="our_method2",
        dataset_name="dataset1",
        metrics={"accuracy": 90, "f1_score": 0.85},
    )
    method_list = [f"method{idx}" for idx in range(3, 7)]
    # add random values for methods 3-6
    for method in method_list:
        perf_table.add_experiment(
            experiment_name=method,
            dataset_name="dataset1",
            metrics={
                "accuracy": random.randint(80, 100),
                "f1_score": random.uniform(0.7, 0.95),
            },
        )
        perf_table.add_experiment(
            experiment_name=method,
            dataset_name="dataset2",
            metrics={
                "accuracy": random.randint(80, 100),
                "precision": random.uniform(0.7, 0.95),
                "f1_score": random.uniform(0.7, 0.95),
            },
        )

    # Get metrics of one experiment on a specific dataset
    metrics = perf_table.get_metrics_for_dataset("our_method1", "dataset1")
    if metrics:
        print(f"\nMetrics for our_method1 on dataset1: {metrics}")

    return perf_table

def test_perftb_dataframe() -> None:
    # Create a performance table
    perf_table = test_perftb_create()

    # Convert to DataFrame
    df = perf_table.to_dataframe()
    print("\nPerformance Table as DataFrame:")
    csvfile.fn_display_df(df)

    # Save to CSV
    perf_table.to_csv("zout/perf_tb.csv", sep=";")


def test_perftb_plot() -> None:
    # Create a performance table
    perf_table = test_perftb_create()

    # Plot the performance table
    perf_table.plot(
        save_path="zout/perf_tb.svg",
        title="Performance Comparison",
        custom_highlight_method_fn=lambda exp: exp.startswith("our_method"),
        custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
        open_plot=False,
        show_raw_df=False,
    )


def test_load_perftb() -> PerfTB:
    # Load performance table from CSV
    def col_exclude_fn(col_name: str) -> bool:
        # Exclude columns that are not metrics (e.g., "Year", "data split", ...)
        return col_name in ["Year", "data split", "test procedure", "code?"]

    perf_table = PerfTB.from_csv(
        "test/bench.csv", sep=";", col_exclude_fn=col_exclude_fn
    )
    # print("\nLoaded Performance Table:")
    # perf_table.display()
    perf_table.to_csv("zout/loaded_perf_tb.csv", sep=";")

    # Plot loaded performance table
    perf_table.plot(
        save_path="zout/loaded_perf_plot.svg",
        title="Loaded Performance Comparison",
        custom_highlight_method_fn=lambda exp: exp.startswith("Ours"),
        custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
        open_plot=False,
        show_raw_df=False,
    )
    return perf_table

def test_filtered_perftb() -> None:
    perf_table_item = test_load_perftb()
    # pprint(perf_table_item.table_meta())
    pprint(perf_table_item.filter_index_info())
    perf_table_item.filter(
        dataset_list=[0, 2],  # use indices of datasets
        # dataset_list=[
        #     "BOWFire_dataset_chino2015bowfire (small)",
        #     "FD-Dataset_li2020efficient (large)",
        # ],
        metrics_list=[
            "acc",
            "f1",
        ],  # or e.g. [["acc"], ["f1"]]: a single list applies to all datasets, a list of lists gives per-dataset metrics
        # experiment_list=["ADFireNet_yar2023effective"],
    ).plot(
        save_path="zout/filtered_perf_tb.svg",
        custom_highlight_method_fn=lambda exp: exp.startswith("Ours"),
        custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
        title="Filtered Performance Comparison",
    )


def test_mics() -> None:
    # Test reading a CSV with MultiIndex columns
    perf_table = test_perftb_create()
    perf_table.display()
    perf_table.plot(
        save_path="zout/test1.svg",
        title="Performance Comparison",
        custom_highlight_method_fn=lambda exp: exp.startswith("our_"),
        custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
        open_plot=False,
    )
    perf_table.to_csv("zout/perf_tb1.csv", sep=";")
    tb = PerfTB.from_csv("./zout/perf_tb1.csv", sep=";")
    tb.display()
    ftb = tb.filter(
        dataset_list=[1],
        metrics_list=["precision"],
        experiment_list=["method3", "method6"],
    )
    ftb.display()

    ftb.plot(
        save_path="zout/perf_tb11_plot.svg",
        title="Performance Comparison",
        custom_highlight_method_fn=lambda exp: exp.startswith("our_"),
        custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
        open_plot=True,
    )


# Example usage
if __name__ == "__main__":
    test_mics()
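A minimal usage sketch of the new `halib.research.perftb` module, distilled from the test helpers in the file above (`test_perftb_create`, `test_mics`). It assumes halib 0.1.51 is installed, Plotly's kaleido backend is available for image export, and the output directory `zout/` already exists; the experiment names, metric values, and paths are illustrative only, not values shipped with the package.

from halib.research.perftb import PerfTB

# Declare datasets and the metrics tracked for each one.
tb = PerfTB()
tb.add_dataset("dataset1", ["accuracy", "f1_score"])
tb.add_dataset("dataset2", ["accuracy", "f1_score", "precision"])

# Record per-experiment results; metrics not supplied for a dataset are stored as None.
tb.add_experiment("baseline", "dataset1", {"accuracy": 90, "f1_score": 0.85})
tb.add_experiment("our_method", "dataset1", {"accuracy": 95, "f1_score": 0.93})
tb.add_experiment("our_method", "dataset2", {"accuracy": 94, "precision": 0.88})

# Experiments become rows; (dataset, metric) pairs become MultiIndex columns.
df = tb.to_dataframe()
tb.to_csv("zout/perf_tb.csv", sep=";")  # condensed two-row header by default

# Grouped bar chart with one subplot row per metric; highlighted methods get a hatch pattern.
tb.plot(
    save_path="zout/perf_tb.svg",
    custom_highlight_method_fn=lambda exp: exp.startswith("our_"),
)

The `custom_highlight_method_fn` hook is the same mechanism the module's own tests use to hatch the `our_*` / `Ours*` bars.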
{halib-0.1.50.dist-info → halib-0.1.51.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: halib
-Version: 0.1.50
+Version: 0.1.51
 Summary: Small library for common tasks
 Author: Hoang Van Ha
 Author-email: hoangvanhauit@gmail.com
@@ -44,6 +44,10 @@ Requires-Dist: wandb
 
 Helper package for coding and automation
 
+**Version 0.1.51**
+
++ add `research/perftb` module to allow creating and managing performance tables for experiments, including filtering by datasets, metrics, and experiments.
+
 **Version 0.1.50**
 
 + add `pprint_local_path` to print local path (file/directory) in clickable link (as file URI)
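The changelog entry above points to the new module's filtering support; the sketch below shows the CSV round trip and the `filter()` call, modeled on `test_mics()` in `halib/research/perftb.py`. The file path is hypothetical and assumes a table previously exported with `to_csv` (for example, the one from the sketch above).

from halib.research.perftb import PerfTB

# Reload a previously exported table; the condensed two-row header is expanded
# back into a (dataset, metric) MultiIndex.
tb = PerfTB.from_csv("zout/perf_tb.csv", sep=";")

# Keep only selected datasets (by index or name), metrics, and experiments;
# a new PerfTB instance is returned.
sub = tb.filter(
    dataset_list=[1],            # or e.g. ["dataset2"]
    metrics_list=["precision"],  # one flat list, or one list per dataset
    experiment_list=["our_method"],
)
sub.display()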
{halib-0.1.50.dist-info → halib-0.1.51.dist-info}/RECORD
@@ -30,6 +30,7 @@ halib/online/projectmake.py,sha256=Zrs96WgXvO4nIrwxnCOletL4aTBge-EoF0r7hpKO1w8,4
 halib/research/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 halib/research/benchquery.py,sha256=FuKnbWQtCEoRRtJAfN-zaN-jPiO_EzsakmTOMiqi7GQ,4626
 halib/research/dataset.py,sha256=QU0Hr5QFb8_XlvnOMgC9QJGIpwXAZ9lDd0RdQi_QRec,6743
+halib/research/perftb.py,sha256=hptHgCMLWsGKSeGz8g7oB3XdMy9eK5DzASewn0RALOg,29997
 halib/research/plot.py,sha256=-pDUk4z3C_GnyJ5zWmf-mGMdT4gaipVJWzIgcpIPiRk,9448
 halib/research/torchloader.py,sha256=yqUjcSiME6H5W210363HyRUrOi3ISpUFAFkTr1w4DCw,6503
 halib/research/wandb_op.py,sha256=YzLEqME5kIRxi3VvjFkW83wnFrsn92oYeqYuNwtYRkY,4188
@@ -42,8 +43,8 @@ halib/system/filesys.py,sha256=ERpnELLDKJoTIIKf-AajgkY62nID4qmqmX5TkE95APU,2931
 halib/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 halib/utils/listop.py,sha256=Vpa8_2fI0wySpB2-8sfTBkyi_A4FhoFVVvFiuvW8N64,339
 halib/utils/tele_noti.py,sha256=-4WXZelCA4W9BroapkRyIdUu9cUVrcJJhegnMs_WpGU,5928
-halib-0.1.
-halib-0.1.
-halib-0.1.
-halib-0.1.
-halib-0.1.
+halib-0.1.51.dist-info/LICENSE.txt,sha256=qZssdna4aETiR8znYsShUjidu-U4jUT9Q-EWNlZ9yBQ,1100
+halib-0.1.51.dist-info/METADATA,sha256=V6qfi8fSUhVYiRhXS8wlG1F3f91zzmZBVnKZiwD3sAU,4365
+halib-0.1.51.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
+halib-0.1.51.dist-info/top_level.txt,sha256=7AD6PLaQTreE0Fn44mdZsoHBe_Zdd7GUmjsWPyQ7I-k,6
+halib-0.1.51.dist-info/RECORD,,

{halib-0.1.50.dist-info → halib-0.1.51.dist-info}/LICENSE.txt
File without changes

{halib-0.1.50.dist-info → halib-0.1.51.dist-info}/WHEEL
File without changes

{halib-0.1.50.dist-info → halib-0.1.51.dist-info}/top_level.txt
File without changes