halib 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,778 @@
+ import warnings
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+ import os
+ import random
+ import itertools
+ import plotly.express as px
+ import plotly.graph_objects as go
+ from collections import defaultdict
+ from plotly.subplots import make_subplots
+ from typing import Dict, List, Union, Optional
+
+ import pandas as pd
+ from rich.pretty import pprint
+
+ from ...filetype import csvfile
+ from ...common.common import ConsoleLog
+
+ class DatasetMetrics:
+     """Class to store metrics definitions for a specific dataset."""
+
+     def __init__(self, dataset_name: str, metric_names: List[str]):
+         self.dataset_name = dataset_name
+         self.metric_names = set(metric_names)  # Unique metric names
+         self.experiment_results: Dict[str, Dict[str, Union[float, int, None]]] = (
+             defaultdict(dict)
+         )
+
+     def add_experiment_result(
+         self, experiment_name: str, metrics: Dict[str, Union[float, int]]
+     ) -> None:
+         """Add experiment results for this dataset, only for defined metrics."""
+         # normalize metric names to lowercase
+         metrics = {k.lower(): v for k, v in metrics.items()}
+         # make sure every metric in metrics is defined for this dataset
+         for metric in metrics:
+             assert metric in self.metric_names, (
+                 f"Metric <<{metric}>> not defined for dataset <<{self.dataset_name}>>. "
+                 f"Available metrics: {self.metric_names}"
+             )
+         for metric in self.metric_names:
+             self.experiment_results[experiment_name][metric] = metrics.get(metric)
+
+     def get_metrics(self, experiment_name: str) -> Dict[str, Union[float, int, None]]:
+         """Retrieve metrics for a specific experiment."""
+         return self.experiment_results.get(
+             experiment_name, {metric: None for metric in self.metric_names}
+         )
+
+     def __str__(self) -> str:
+         return f"Dataset: {self.dataset_name}, Metrics: {', '.join(self.metric_names)}"
+
+
+ class PerfTB:
+     """Class to manage performance table data with datasets as primary structure."""
+
+     def __init__(self):
+         # Dictionary of dataset_name -> DatasetMetrics
+         self.datasets: Dict[str, DatasetMetrics] = {}
+         self.experiments: set = set()
+
+     def add_dataset(self, dataset_name: str, metric_names: List[str]) -> None:
+         """
+         Add a new dataset with its associated metrics.
+
+         Args:
+             dataset_name: Name of the dataset
+             metric_names: List of metric names for this dataset
+         """
+         # normalize metric names to lowercase
+         metric_names = [metric.lower() for metric in metric_names]
+         self.datasets[dataset_name] = DatasetMetrics(dataset_name, metric_names)
+
+     def table_meta(self):
+         """
+         Return metadata about the performance table.
+         """
+         return {
+             "num_datasets": len(self.datasets),
+             "num_experiments": len(self.experiments),
+             "datasets_metrics": {
+                 dataset_name: dataset.metric_names
+                 for dataset_name, dataset in self.datasets.items()
+             },
+         }
+
+     def add_experiment(
+         self,
+         experiment_name: str,
+         dataset_name: str,
+         metrics: Dict[str, Union[float, int]],
+     ) -> None:
+         """
+         Add experiment results for a specific dataset.
+
+         Args:
+             experiment_name: Name or identifier of the experiment
+             dataset_name: Name of the dataset
+             metrics: Dictionary of metric names and their values
+         """
+         # normalize metric names to lowercase
+         metrics = {k.lower(): v for k, v in metrics.items()}
+         if dataset_name not in self.datasets:
+             raise ValueError(
+                 f"Dataset <<{dataset_name}>> not defined. Add dataset first."
+             )
+         self.experiments.add(experiment_name)
+         self.datasets[dataset_name].add_experiment_result(experiment_name, metrics)
+
+     def get_metrics_for_dataset(
+         self, experiment_name: str, dataset_name: str
+     ) -> Optional[Dict[str, Union[float, int, None]]]:
+         """
+         Retrieve performance metrics for a specific dataset and experiment.
+
+         Args:
+             experiment_name: Name or identifier of the experiment
+             dataset_name: Name of the dataset
+
+         Returns:
+             Dictionary of metrics or None if dataset not found
+         """
+         dataset = self.datasets.get(dataset_name)
+         if dataset:
+             return dataset.get_metrics(experiment_name)
+         return None
+
+     def get_all_experiments(self) -> List[str]:
+         """Return list of all experiment names."""
+         return sorted(self.experiments)
+
+     def get_all_datasets(self) -> List[str]:
+         """Return list of all dataset names."""
+         return sorted(self.datasets.keys())
+
+     def to_dataframe(self) -> pd.DataFrame:
+         """
+         Convert the performance table to a pandas DataFrame with MultiIndex columns.
+         Level 1: Datasets
+         Level 2: Metrics
+
+         Returns:
+             pandas DataFrame with experiments as rows and (dataset, metric) as columns
+         """
+         # Create MultiIndex for columns (dataset, metric)
+         columns = []
+         for dataset_name in self.get_all_datasets():
+             for metric in sorted(self.datasets[dataset_name].metric_names):
+                 columns.append((dataset_name, metric))
+         columns = pd.MultiIndex.from_tuples(columns, names=["Dataset", "Metric"])
+
+         # Initialize DataFrame with experiments as index
+         df = pd.DataFrame(index=sorted(self.experiments), columns=columns)
+
+         # Populate DataFrame
+         for exp in self.experiments:
+             for dataset_name in self.datasets:
+                 metrics = self.datasets[dataset_name].get_metrics(exp)
+                 for metric, value in metrics.items():
+                     df.loc[exp, (dataset_name, metric)] = value
+
+         return df
+
+     def plot(
+         self,
+         save_path: str,
+         title: Optional[str] = None,
+         custom_highlight_method_fn: Optional[callable] = None,
+         custom_sort_exp_fn: Optional[callable] = None,
+         open_plot: bool = False,
+         show_raw_df: bool = False,
+         experiment_names: Optional[List[str]] = None,
+         datasets: Optional[List[str]] = None,
+         height: int = 400,
+         width: int = 700,
+     ) -> Optional[go.Figure]:
+         """
+         Plot a comparison of experiments across datasets and their metrics using Plotly.
+         One grouped bar subplot is created per metric.
+
+         Args:
+             save_path: File path to save the figure (the extension selects the image format)
+             title: Plot title (default: "Experiment Comparison by Metric Groups")
+             custom_highlight_method_fn: Optional predicate on an experiment name; matching
+                 experiments are highlighted with a bar pattern
+             custom_sort_exp_fn: Optional function that accepts a list of experiment names
+                 and returns them in the desired order
+             open_plot: If True, attempts to open the saved image file (Windows only)
+             show_raw_df: If True, displays the underlying long-format DataFrame
+             experiment_names: List of experiments to compare (default: all)
+             datasets: List of datasets to include (default: all)
+             height: Base height of the plot (scaled by the number of subplot rows)
+             width: Width of the plot
+
+         Returns:
+             The Plotly figure, or None if there was nothing to plot or saving failed
+         """
+         experiment_names = experiment_names or self.get_all_experiments()
+         datasets = datasets or self.get_all_datasets()
+
+         records = []
+
+         for dataset in datasets:
+             if dataset not in self.datasets:
+                 print(f"Warning: Dataset '{dataset}' not found. Skipping...")
+                 continue
+
+             metric_names = sorted(self.datasets[dataset].metric_names)
+             for exp in experiment_names:
+                 metric_values = self.get_metrics_for_dataset(exp, dataset)
+                 if not metric_values:
+                     continue
+                 for metric in metric_names:
+                     value = metric_values.get(metric)
+                     if value is not None:
+                         records.append(
+                             {
+                                 "Experiment": exp,
+                                 "Dataset": dataset,
+                                 "Metric": metric,
+                                 "Value": value,
+                             }
+                         )
+
+         if not records:
+             print("No data found to plot.")
+             return
+
+         df = pd.DataFrame(records)
+         if show_raw_df:
+             with ConsoleLog("PerfTB DF"):
+                 csvfile.fn_display_df(df)
+
+         metric_list = df["Metric"].unique()
+         fig = make_subplots(
+             rows=len(metric_list),
+             cols=1,
+             shared_xaxes=False,
+             subplot_titles=metric_list,
+             vertical_spacing=0.1,
+         )
+
+         unique_experiments = df["Experiment"].unique()
+         color_cycle = itertools.cycle(px.colors.qualitative.Plotly)
+
+         color_map = {
+             exp: color for exp, color in zip(unique_experiments, color_cycle)
+         }
+
+         pattern_shapes = ["x", "-", "/", "\\", "|", "+", "."]
+         pattern_color = "black"  # Color for patterns
+
+         current_our_method = -1  # Start at -1 so the first highlighted experiment gets pattern index 0
+         exp_pattern_dict = {}
+         shown_legends = set()
+         for row_idx, metric in enumerate(metric_list, start=1):
+             metric_df = df[df["Metric"] == metric]
+             list_exp = list(metric_df["Experiment"].unique())
+             if custom_sort_exp_fn:
+                 list_exp = custom_sort_exp_fn(list_exp)
+             for exp in list_exp:
+                 showlegend = exp not in shown_legends
+                 shown_legends.add(exp)  # a set, so duplicates are ignored
+                 should_highlight = (
+                     custom_highlight_method_fn is not None and custom_highlight_method_fn(exp)
+                 )
+                 pattern_shape = ""  # default: no pattern
+                 if should_highlight and exp not in exp_pattern_dict:
+                     current_our_method += 1
+                     pattern_shape = pattern_shapes[
+                         current_our_method % len(pattern_shapes)
+                     ]
+                     exp_pattern_dict[exp] = pattern_shape
+                 elif exp in exp_pattern_dict:
+                     pattern_shape = exp_pattern_dict[exp]
+                 exp_df = metric_df[metric_df["Experiment"] == exp]
+                 fig.add_trace(
+                     go.Bar(
+                         x=exp_df["Dataset"],
+                         y=exp_df["Value"],
+                         name=f"{exp}",
+                         legendgroup=exp,
+                         showlegend=showlegend,  # Show each experiment in the legend only once
+                         marker=dict(
+                             color=color_map[exp],
+                             pattern=(
+                                 dict(shape=pattern_shape, fgcolor=pattern_color)
+                                 if pattern_shape
+                                 else None
+                             ),
+                         ),
+                         text=[f"{v:.5f}" for v in exp_df["Value"]],
+                         textposition="auto",  # position the value labels automatically
+                     ),
+                     row=row_idx,
+                     col=1,
+                 )
+
+         # Manage layout
+         if title is None:
+             title = "Experiment Comparison by Metric Groups"
+         fig.update_layout(
+             height=height * len(metric_list),
+             width=width,
+             title_text=title,
+             barmode="group",
+             showlegend=True,
+         )
+
+         # Save and open plot
+         if save_path:
+             export_success = False
+             try:
+                 fig.write_image(save_path, engine="kaleido", width=width, height=height * len(metric_list))
+                 export_success = True
+                 # pprint(f"Saved: {os.path.abspath(save_path)}")
+             except Exception as e:
+                 print(f"Error saving plot: {e}")
+                 pprint(
+                     "Failed to save plot. Check this link: https://stackoverflow.com/questions/69016568/unable-to-export-plotly-images-to-png-with-kaleido. Maybe you need to downgrade kaleido version to 0.1.* or install it via pip install kaleido==0.1.*"
+                 )
+                 return
+             if export_success and open_plot and os.name == "nt":  # Windows
+                 os.system(f'start "" "{os.path.abspath(save_path)}"')
+         return fig
+
+     def to_csv(self, outfile: str, sep=";", condensed_multiindex: bool = True) -> None:
+         """
+         Save the performance table to a CSV file.
+
+         Args:
+             outfile: Path to the output CSV file
+             sep: Column separator (default: ";")
+             condensed_multiindex: If True, write each dataset name only once in the first
+                 header row and leave repeated cells blank
+         """
+         df = self.to_dataframe()
+         if condensed_multiindex:
+             # Extract levels
+             level0 = df.columns.get_level_values(0)
+             level1 = df.columns.get_level_values(1)
+
+             # Build new level0 with blanks after first appearance
+             new_level0 = []
+             prev = None
+             for val in level0:
+                 if val == prev:
+                     new_level0.append("")
+                 else:
+                     new_level0.append(val)
+                 prev = val
+
+             df.columns = pd.MultiIndex.from_arrays([new_level0, level1])
+         # Write to CSV
+         df.to_csv(outfile, index=True, sep=sep)
+
+     def display(self) -> None:
+         """
+         Display the performance table as a DataFrame.
+         """
+         df = self.to_dataframe()
+         csvfile.fn_display_df(df)
+
+     @classmethod
+     def _read_condensed_multiindex_csv(cls, filepath: str, sep=";", col_exclude_fn: Optional[callable] = None) -> pd.DataFrame:
+         # Read first two header rows
+         df = pd.read_csv(filepath, header=[0, 1], sep=sep)
+         # Extract levels
+         level0 = df.columns.get_level_values(0)
+         level1 = df.columns.get_level_values(1)
+         # pprint(f'{level0=}')
+         # pprint(f'{level1=}')
+         # Fill blank level0 cells with the last non-blank value
+         # (pandas reads blank header cells as "Unnamed: ..." labels)
+         new_level0 = []
+         last_non_blank = level0[0]  # Start with the first value
+         assert last_non_blank != "", (
+             "First level0 value should not be blank. "
+             "Check the CSV file format."
+         )
+         for val in level0:
+             if val == "" or "Unnamed: " in val:
+                 new_level0.append(last_non_blank)
+             else:
+                 new_level0.append(val)
+                 last_non_blank = val
+         # pprint(new_level0)
+         # Rebuild MultiIndex, optionally excluding columns
+         excluded_indices = []
+         if col_exclude_fn:
+             for idx, val in enumerate(new_level0):
+                 if col_exclude_fn(val):
+                     excluded_indices.append(idx)
+             for idx, val in enumerate(level1):
+                 if col_exclude_fn(val):
+                     excluded_indices.append(idx)
+             excluded_indices = list(set(excluded_indices))
+
+         num_prev_cols = len(new_level0)
+         # Remove excluded indices from both levels
+         new_level0 = [
+             val for idx, val in enumerate(new_level0) if idx not in excluded_indices
+         ]
+         new_level1 = [
+             val for idx, val in enumerate(level1) if idx not in excluded_indices
+         ]
+         num_after_cols = len(new_level0)
+         if num_prev_cols != num_after_cols:
+             # Keep only the non-excluded columns in the DataFrame
+             df = df.iloc[:, [i for i in range(len(df.columns)) if i not in excluded_indices]]
+
+         df.columns = pd.MultiIndex.from_arrays([new_level0, new_level1])
+         return df
+
+     @classmethod
+     def from_dataframe(
+         cls,
+         df: pd.DataFrame
+     ) -> "PerfTB":
+         """
+         Load performance table from a DataFrame.
+
+         Args:
+             df: Input DataFrame; the first column holds the experiment (method) names and
+                 the remaining columns form a (dataset, metric) MultiIndex
+         """
+         # console.log('--- PerfTB.from_dataframe ---')
+         # csvfile.fn_display_df(df)
+         cls_instance = cls()
+         # first loop through MultiIndex columns and extract datasets with their metrics
+         dataset_metrics = {}
+         for dataset_name, metric_name in df.columns[1:]:
+             if dataset_name not in dataset_metrics:
+                 dataset_metrics[dataset_name] = []
+             dataset_metrics[dataset_name].append(metric_name)
+         for dataset_name, metric_names in dataset_metrics.items():
+             cls_instance.add_dataset(dataset_name, metric_names)
+
+         def safe_cast(val):
+             try:
+                 return float(val)
+             except (ValueError, TypeError):
+                 return None
+
+         for _, row in df.iterrows():
+             # Extract experiment name from the first column
+             experiment_name = row.iloc[0]
+             # Iterate over the (dataset, metric) columns, one dataset at a time
+             for dataset_name, metric_names in dataset_metrics.items():
+                 metrics = {}
+                 for metric_name in metric_names:
+                     # Get the value for this dataset and metric
+                     value = row[(dataset_name, metric_name)]
+                     # Cast to float or None if not applicable
+                     metrics[metric_name] = safe_cast(value)
+                 cls_instance.add_experiment(
+                     experiment_name=experiment_name,
+                     dataset_name=dataset_name,
+                     metrics=metrics,
+                 )
+
+         return cls_instance
+
+     @classmethod
+     def from_csv(
+         cls,
+         csv_file: str,
+         sep: str = ";",
+         col_exclude_fn: Optional[callable] = None,
+     ) -> "PerfTB":
+         """
+         Load performance table from a CSV file.
+
+         Args:
+             csv_file: Path to the CSV file
+             sep: Separator used in the CSV file
+             col_exclude_fn: Optional predicate on a header name; matching columns are excluded
+         """
+         df = cls._read_condensed_multiindex_csv(csv_file, sep=sep, col_exclude_fn=col_exclude_fn)
+         return cls.from_dataframe(df)
+
+     def filter_index_info(self):
+         """
+         Return the dataset and metric index information used by `filter`.
+         """
+         datasets_metrics = {
+             dataset_name: dataset.metric_names
+             for dataset_name, dataset in self.datasets.items()
+         }
+         meta_dict = {}
+         for i, (dataset_name, metrics) in enumerate(datasets_metrics.items()):
+             sorted_metrics = sorted(metrics)  # sort for deterministic output
+             meta_dict[dataset_name] = {
+                 "index": i,
+                 "metrics": sorted(
+                     list(zip(sorted_metrics, range(len(sorted_metrics))))
+                 ),  # (metric_name, index)
+             }
+         return meta_dict
+
+     def filter(
+         self,
+         dataset_list: Optional[List[Union[str, int]]] = None,  # dataset names or indices
+         metrics_list: Optional[List[Union[list, str]]] = None,
+         experiment_list: Optional[List[str]] = None,
+     ) -> "PerfTB":
+         """
+         Filter the performance table by datasets, metrics, and experiments.
+         Returns a new PerfTB instance with the filtered data.
+
+         Args:
+             dataset_list: List of dataset names or indices to keep (optional)
+             metrics_list: List of metric names to keep (optional). A list of lists can be
+                 passed to select a different set of metrics per dataset; a single flat list
+                 is applied to all datasets.
+             experiment_list: List of experiment NAMES (strings) to keep (optional). Indices
+                 are not supported.
+         """
+         meta_filter_dict = self.filter_index_info()
+
+         if experiment_list is None:
+             experiment_list = self.get_all_experiments()
+         else:
+             # make sure all experiments are found in the performance table
+             for exp in experiment_list:
+                 if exp not in self.experiments:
+                     raise ValueError(
+                         f"Experiment <<{exp}>> not found in the performance table. Available experiments: {self.get_all_experiments()}"
+                     )
+         # pprint(f"Filtering experiments: {experiment_list}")
+         # get dataset list
+         if dataset_list is not None:
+             # if all items in dataset_list are integers, convert them to dataset names
+             if all(isinstance(item, int) and 0 <= item < len(meta_filter_dict) for item in dataset_list):
+                 dataset_list = [
+                     list(meta_filter_dict.keys())[item] for item in dataset_list
+                 ]
+             elif all(isinstance(item, str) for item in dataset_list):
+                 # if all items are strings, use them as dataset names
+                 dataset_list = [
+                     item for item in dataset_list if item in meta_filter_dict
+                 ]
+             else:
+                 raise ValueError(
+                     f"dataset_list should be a list of strings (dataset names) or integers (indices, should be <= {len(meta_filter_dict) - 1}). Got: {dataset_list}"
+                 )
+         else:
+             dataset_list = self.get_all_datasets()
+
+         filter_metrics_ls = []  # [list_metric_db_A, list_metric_db_B, ...]
+         all_ds_metrics = []
+         for dataset_name in dataset_list:
+             ds_meta = meta_filter_dict.get(dataset_name, None)
+             if ds_meta:
+                 ds_metrics = ds_meta["metrics"]
+                 all_ds_metrics.append([metric[0] for metric in ds_metrics])
+
+         if metrics_list is None:
+             filter_metrics_ls = all_ds_metrics
+         elif isinstance(metrics_list, list):
+             all_string = all(isinstance(item, str) for item in metrics_list)
+             if all_string:
+                 # normalize metrics_list to lowercase
+                 metrics_list = [metric.lower() for metric in metrics_list]
+                 filter_metrics_ls = [metrics_list] * len(dataset_list)
+             else:
+                 all_list = all(isinstance(item, list) for item in metrics_list)
+                 if all_list:
+                     if len(metrics_list) != len(dataset_list):
+                         raise ValueError(
+                             f"metrics_list should be a list of strings (metric names) or a list of lists of metric names for each dataset. Got: {len(metrics_list)} metrics for {len(dataset_list)} datasets."
+                         )
+                     # normalize each list of metrics to lowercase
+                     filter_metrics_ls = [
+                         [metric.lower() for metric in item] for item in metrics_list
+                     ]
+                 else:
+                     raise ValueError(
+                         f"metrics_list should be a list of strings (metric names) or a list of lists of metric names for each dataset. Got: {metrics_list}"
+                     )
+
+         # make sure that all metrics in filter_metrics_ls are valid for their datasets
+         final_metrics_list = []
+         for idx, dataset_name in enumerate(dataset_list):
+             valid_metrics_list = all_ds_metrics[idx]
+             current_metrics = filter_metrics_ls[idx]
+             new_valid_ds_metrics = []
+             for metric in current_metrics:
+                 if metric in valid_metrics_list:
+                     new_valid_ds_metrics.append(metric)
+             assert len(new_valid_ds_metrics) > 0, (
+                 f"No valid metrics found for dataset <<{dataset_name}>>. "
+                 f"Available metrics: {valid_metrics_list}. "
+                 f"Filtered metrics: {current_metrics}"
+             )
+             final_metrics_list.append(new_valid_ds_metrics)
+
+         assert len(experiment_list) > 0, "No experiments to filter."
+         assert len(dataset_list) > 0, "No datasets to filter."
+         assert len(final_metrics_list) > 0, "No metrics to filter."
+         filtered_tb = PerfTB()
+         for db, metrics in zip(dataset_list, final_metrics_list):
+             # add dataset with its metrics
+             filtered_tb.add_dataset(db, metrics)
+
+         # now add experiments with their metrics
+         for exp in experiment_list:
+             for db, metrics in zip(dataset_list, final_metrics_list):
+                 # get metrics for this experiment and dataset
+                 metrics_dict = self.get_metrics_for_dataset(exp, db)
+                 if metrics_dict:
+                     # keep only the metrics defined for this dataset
+                     filtered_metrics = {k: v for k, v in metrics_dict.items() if k in metrics}
+                     if filtered_metrics:
+                         filtered_tb.add_experiment(exp, db, filtered_metrics)
+
+         return filtered_tb
+
+
+ def test_perftb_create() -> PerfTB:
+     # Create a performance table
+     perf_table = PerfTB()
+
+     # Define datasets and their metrics first
+     perf_table.add_dataset("dataset1", ["accuracy", "f1_score"])
+     perf_table.add_dataset("dataset2", ["accuracy", "f1_score", "precision"])
+
+     # Add experiment results
+     perf_table.add_experiment(
+         experiment_name="our_method1",
+         dataset_name="dataset1",
+         metrics={"accuracy": 100, "f1_score": 0.93},
+     )
+     perf_table.add_experiment(
+         experiment_name="our_method2",
+         dataset_name="dataset2",
+         metrics={"accuracy": 100, "precision": 0.87},  # Missing f1_score will be None
+     )
+     perf_table.add_experiment(
+         experiment_name="our_method2",
+         dataset_name="dataset1",
+         metrics={"accuracy": 90, "f1_score": 0.85},
+     )
+     method_list = [f"method{idx}" for idx in range(3, 7)]
+     # add random values for methods 3-6
+     for method in method_list:
+         perf_table.add_experiment(
+             experiment_name=method,
+             dataset_name="dataset1",
+             metrics={
+                 "accuracy": random.randint(80, 100),
+                 "f1_score": random.uniform(0.7, 0.95),
+             },
+         )
+         perf_table.add_experiment(
+             experiment_name=method,
+             dataset_name="dataset2",
+             metrics={
+                 "accuracy": random.randint(80, 100),
+                 "precision": random.uniform(0.7, 0.95),
+                 "f1_score": random.uniform(0.7, 0.95),
+             },
+         )
+
+     # Get metrics for a specific experiment and dataset
+     metrics = perf_table.get_metrics_for_dataset("our_method1", "dataset1")
+     if metrics:
+         print(f"\nMetrics for our_method1 on dataset1: {metrics}")
+
+     return perf_table
+
+ def test_perftb_dataframe() -> None:
+     # Create a performance table
+     perf_table = test_perftb_create()
+
+     # Convert to DataFrame
+     df = perf_table.to_dataframe()
+     print("\nPerformance Table as DataFrame:")
+     csvfile.fn_display_df(df)
+
+     # Save to CSV
+     perf_table.to_csv("zout/perf_tb.csv", sep=";")
+
+ def test_perftb_plot() -> None:
+     # Create a performance table
+     perf_table = test_perftb_create()
+
+     # Plot the performance table
+     perf_table.plot(
+         save_path="zout/perf_tb.svg",
+         title="Performance Comparison",
+         custom_highlight_method_fn=lambda exp: exp.startswith("our_method"),
+         custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
+         open_plot=False,
+         show_raw_df=False,
+     )
+
+ def test_load_perftb() -> PerfTB:
+     # Load performance table from CSV
+     def col_exclude_fn(col_name: str) -> bool:
+         # Exclude header columns that are not metrics (metadata columns)
+         return col_name in ["Year", "data split", "test procedure", "code?"]
+
+     perf_table = PerfTB.from_csv("test/bench.csv", sep=";", col_exclude_fn=col_exclude_fn)
+     # print("\nLoaded Performance Table:")
+     # perf_table.display()
+     perf_table.to_csv("zout/loaded_perf_tb.csv", sep=";")
+
+     # Plot loaded performance table
+     perf_table.plot(
+         save_path="zout/loaded_perf_plot.svg",
+         title="Loaded Performance Comparison",
+         custom_highlight_method_fn=lambda exp: exp.startswith("Ours"),
+         custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
+         open_plot=False,
+         show_raw_df=False,
+     )
+     return perf_table
+
+ def test_filtered_perftb() -> None:
+     perf_table_item = test_load_perftb()
+     # pprint(perf_table_item.meta())
+     pprint(perf_table_item.filter_index_info())
+     perf_table_item.filter(
+         dataset_list=[0, 2],  # Use indices of datasets
+         # dataset_list=[
+         #     "BOWFire_dataset_chino2015bowfire (small)",
+         #     "FD-Dataset_li2020efficient (large)",
+         # ],
+         metrics_list=[
+             "acc",
+             "f1",
+         ],  # a single list of metrics for all datasets, or a list of lists (e.g. [["acc"], ["f1"]]) for different metrics per dataset
+         # experiment_list=["ADFireNet_yar2023effective"],
+     ).plot(
+         save_path="zout/filtered_perf_tb.svg",
+         custom_highlight_method_fn=lambda exp: exp.startswith("Ours"),
+         custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
+         title="Filtered Performance Comparison",
+     )
+
+ def test_mics() -> None:
+     # Test reading a CSV with MultiIndex columns
+     perf_table = test_perftb_create()
+     perf_table.display()
+     perf_table.plot(
+         save_path="zout/test1.svg",
+         title="Performance Comparison",
+         custom_highlight_method_fn=lambda exp: exp.startswith("our_"),
+         custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
+         open_plot=False,
+     )
+     perf_table.to_csv("zout/perf_tb1.csv", sep=";")
+     tb = PerfTB.from_csv("./zout/perf_tb1.csv", sep=";")
+     tb.display()
+     ftb = tb.filter(
+         dataset_list=[1],
+         metrics_list=["precision"],
+         experiment_list=["method3", "method6"],
+     )
+     ftb.display()
+
+     ftb.plot(
+         save_path="zout/perf_tb11_plot.svg",
+         title="Performance Comparison",
+         custom_highlight_method_fn=lambda exp: exp.startswith("our_"),
+         custom_sort_exp_fn=lambda exps: sorted(exps, reverse=True),
+         open_plot=True,
+     )
+
+ def test_bench2():
+     perftb = PerfTB.from_csv(
+         "test/bench2.csv",
+         sep=";")
+     perftb.display()
+     perftb.plot(
+         save_path="zout/bench2_plot.svg",
+         title="Bench2 Performance Comparison",
+         open_plot=True,
+     )
+
+
+ # Example usage
+ if __name__ == "__main__":
+     # test_mics()
+     test_bench2()
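
For orientation, the condensed two-row CSV header written by PerfTB.to_csv (dataset names on the first header row with repeated names left blank, metric names on the second row) is the same layout PerfTB.from_csv reads back. A minimal round trip looks roughly like the sketch below, assuming the classes above are importable; the file path and the dataset, experiment, and metric names are made-up illustrations, not values from the package.

    tb = PerfTB()
    tb.add_dataset("dataset1", ["accuracy", "f1_score"])
    tb.add_experiment("baseline", "dataset1", {"accuracy": 91, "f1_score": 0.88})
    tb.to_csv("perf.csv", sep=";")              # writes the condensed MultiIndex header
    tb2 = PerfTB.from_csv("perf.csv", sep=";")  # rebuilds the (dataset, metric) columns
    tb2.display()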