cache-dit 0.2.34__py3-none-any.whl → 0.2.37__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.


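In short, this release reworks the metrics CLI: the `FrechetInceptionDistance` class is replaced by functional `compute_fid`/`compute_video_fid` helpers, CLIP Score and Image Reward are added as image-vs-prompt metrics, PSNR gains a selectable backend (`custom` or `skimage`), directory comparisons sort frames in natural numeric order, perf-log parsing accepts multiple tags, and the summary can report speedups and emit a markdown table.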
@@ -1,25 +1,31 @@
  import os
+ import re
  import cv2
  import pathlib
  import argparse
  import numpy as np
  from tqdm import tqdm
  from functools import partial
+ from typing import Callable, Union, Tuple, List
  from skimage.metrics import mean_squared_error
  from skimage.metrics import peak_signal_noise_ratio
  from skimage.metrics import structural_similarity
- from cache_dit.metrics.fid import FrechetInceptionDistance
  from cache_dit.metrics.config import set_metrics_verbose
  from cache_dit.metrics.config import get_metrics_verbose
  from cache_dit.metrics.config import _IMAGE_EXTENSIONS
  from cache_dit.metrics.config import _VIDEO_EXTENSIONS
  from cache_dit.logger import init_logger
+ from cache_dit.metrics.fid import compute_fid
+ from cache_dit.metrics.fid import compute_video_fid
  from cache_dit.metrics.lpips import compute_lpips_img
+ from cache_dit.metrics.clip_score import compute_clip_score
+ from cache_dit.metrics.image_reward import compute_reward_score
 
  logger = init_logger(__name__)
 
 
  DISABLE_VERBOSE = not get_metrics_verbose()
+ PSNR_TYPE = "custom"
 
 
  def compute_lpips_file(
@@ -51,6 +57,35 @@ def compute_lpips_file(
      )
 
 
+ def set_psnr_type(psnr_type: str):
+     global PSNR_TYPE
+     PSNR_TYPE = psnr_type
+     assert PSNR_TYPE in ["skimage", "custom"]
+
+
+ def get_psnr_type():
+     global PSNR_TYPE
+     return PSNR_TYPE
+
+
+ def calculate_psnr(
+     image_true: np.ndarray,
+     image_test: np.ndarray,
+ ):
+     """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+     Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+     Args:
+         image_true (ndarray): Images with range [0, 255].
+         image_test (ndarray): Images with range [0, 255].
+     """
+     mse = np.mean((image_true - image_test) ** 2)
+     if mse == 0:
+         return float("inf")
+     return 20 * np.log10(255.0 / np.sqrt(mse))
+
+
  def compute_psnr_file(
      image_true: np.ndarray | str,
      image_test: np.ndarray | str,
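The new `calculate_psnr` implements the textbook formula PSNR = 20 * log10(255 / sqrt(MSE)). One caveat worth flagging (not addressed in this diff): `compute_psnr_file` feeds it `cv2.imread` results, which are uint8 arrays, and uint8 subtraction in NumPy wraps modulo 256, silently corrupting the MSE. A minimal sketch of a safe variant; `calculate_psnr_safe` and the float cast are my additions, not part of the release:

    import numpy as np

    def calculate_psnr_safe(image_true: np.ndarray, image_test: np.ndarray) -> float:
        # Cast to float64 first: uint8 differences wrap modulo 256 in NumPy.
        diff = image_true.astype(np.float64) - image_test.astype(np.float64)
        mse = np.mean(diff**2)
        if mse == 0:
            return float("inf")  # identical images
        return 20 * np.log10(255.0 / np.sqrt(mse))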
@@ -64,10 +99,13 @@ def compute_psnr_file(
          image_true = cv2.imread(image_true)
      if isinstance(image_test, str):
          image_test = cv2.imread(image_test)
-     return peak_signal_noise_ratio(
-         image_true,
-         image_test,
-     )
+     if get_psnr_type() == "skimage":
+         return peak_signal_noise_ratio(
+             image_true,
+             image_test,
+         )
+     else:
+         return calculate_psnr(image_true, image_test)
 
 
  def compute_mse_file(
@@ -114,7 +152,7 @@ def compute_dir_metric(
      image_true_dir: np.ndarray | str,
      image_test_dir: np.ndarray | str,
      compute_file_func: callable = compute_psnr_file,
- ) -> float:
+ ) -> Union[Tuple[float, int], Tuple[None, None]]:
      # Image
      if isinstance(image_true_dir, np.ndarray) or isinstance(
          image_test_dir, np.ndarray
@@ -123,25 +161,30 @@ def compute_dir_metric(
      # File
      if not os.path.isdir(image_true_dir) or not os.path.isdir(image_test_dir):
          return compute_file_func(image_true_dir, image_test_dir), 1
+
      # Dir
+     # compute dir metric
+     def natural_sort_key(filename):
+         match = re.search(r"(\d+)\D*$", filename)
+         return int(match.group(1)) if match else filename
+
      image_true_dir: pathlib.Path = pathlib.Path(image_true_dir)
-     image_true_files = sorted(
-         [
-             file
-             for ext in _IMAGE_EXTENSIONS
-             for file in image_true_dir.rglob("*.{}".format(ext))
-         ]
-     )
-     image_test_dir: pathlib.Path = pathlib.Path(image_test_dir)
-     image_test_files = sorted(
-         [
-             file
-             for ext in _IMAGE_EXTENSIONS
-             for file in image_test_dir.rglob("*.{}".format(ext))
-         ]
-     )
+     image_true_files = [
+         file
+         for ext in _IMAGE_EXTENSIONS
+         for file in image_true_dir.rglob("*.{}".format(ext))
+     ]
      image_true_files = [file.as_posix() for file in image_true_files]
+     image_true_files = sorted(image_true_files, key=natural_sort_key)
+
+     image_test_dir: pathlib.Path = pathlib.Path(image_test_dir)
+     image_test_files = [
+         file
+         for ext in _IMAGE_EXTENSIONS
+         for file in image_test_dir.rglob("*.{}".format(ext))
+     ]
      image_test_files = [file.as_posix() for file in image_test_files]
+     image_test_files = sorted(image_test_files, key=natural_sort_key)
 
      # select valid files
      image_true_files_selected = []
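`natural_sort_key` keys each filename by its last run of digits, so frame files compare numerically instead of lexicographically. A quick illustration (filenames invented):

    import re

    def natural_sort_key(filename):
        match = re.search(r"(\d+)\D*$", filename)
        return int(match.group(1)) if match else filename

    files = ["frame_10.png", "frame_2.png", "frame_1.png"]
    print(sorted(files))                        # ['frame_1.png', 'frame_10.png', 'frame_2.png']
    print(sorted(files, key=natural_sort_key))  # ['frame_1.png', 'frame_2.png', 'frame_10.png']

Note that a directory mixing numbered and unnumbered filenames would make the key yield both int and str, which sorted() cannot compare; the frame dirs here are assumed to be uniformly numbered.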
@@ -155,6 +198,7 @@ def compute_dir_metric(
      ):
          image_true_files_selected.append(selected_image_true)
          image_test_files_selected.append(selected_image_test)
+
      image_true_files = image_true_files_selected.copy()
      image_test_files = image_test_files_selected.copy()
      if len(image_true_files) == 0:
@@ -169,20 +213,22 @@ def compute_dir_metric(
 
      total_metric = 0.0
      valid_files = 0
+     total_files = 0
      for image_true, image_test in tqdm(
          zip(image_true_files, image_test_files),
          total=len(image_true_files),
          disable=DISABLE_VERBOSE,
      ):
          metric = compute_file_func(image_true, image_test)
-         if metric != float("inf"):
+         if metric != float("inf"):  # inf means no cache was applied to image_test
              total_metric += metric
              valid_files += 1
+         total_files += 1
 
      if valid_files > 0:
          average_metric = total_metric / valid_files
          logger.debug(f"Average: {average_metric:.2f}")
-         return average_metric, valid_files
+         return average_metric, total_files
      else:
          logger.debug("No valid files to compare")
          return None, None
@@ -235,7 +281,7 @@ def compute_video_metric(
      video_true: str,
      video_test: str,
      compute_frame_func: callable = compute_psnr_file,
- ) -> float:
+ ) -> Union[Tuple[float, int], Tuple[None, None]]:
      """
      video_true = "video_true.mp4"
      video_test = "video_test.mp4"
@@ -335,51 +381,69 @@ def compute_video_metric(
          return None, None
 
 
- compute_lpips = partial(
-     compute_dir_metric,
-     compute_file_func=compute_lpips_file,
+ compute_lpips: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
+     partial(
+         compute_dir_metric,
+         compute_file_func=compute_lpips_file,
+     )
  )
 
- compute_psnr = partial(
-     compute_dir_metric,
-     compute_file_func=compute_psnr_file,
+ compute_psnr: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
+     partial(
+         compute_dir_metric,
+         compute_file_func=compute_psnr_file,
+     )
  )
 
- compute_ssim = partial(
-     compute_dir_metric,
-     compute_file_func=compute_ssim_file,
+ compute_ssim: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
+     partial(
+         compute_dir_metric,
+         compute_file_func=compute_ssim_file,
+     )
  )
 
- compute_mse = partial(
-     compute_dir_metric,
-     compute_file_func=compute_mse_file,
+ compute_mse: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
+     partial(
+         compute_dir_metric,
+         compute_file_func=compute_mse_file,
+     )
  )
 
- compute_video_lpips = partial(
+ compute_video_lpips: Callable[
+     ..., Union[Tuple[float, int], Tuple[None, None]]
+ ] = partial(
      compute_video_metric,
      compute_frame_func=compute_lpips_file,
  )
- compute_video_psnr = partial(
+ compute_video_psnr: Callable[
+     ..., Union[Tuple[float, int], Tuple[None, None]]
+ ] = partial(
      compute_video_metric,
      compute_frame_func=compute_psnr_file,
  )
- compute_video_ssim = partial(
+ compute_video_ssim: Callable[
+     ..., Union[Tuple[float, int], Tuple[None, None]]
+ ] = partial(
      compute_video_metric,
      compute_frame_func=compute_ssim_file,
  )
- compute_video_mse = partial(
+ compute_video_mse: Callable[
+     ..., Union[Tuple[float, int], Tuple[None, None]]
+ ] = partial(
      compute_video_metric,
      compute_frame_func=compute_mse_file,
  )
 
 
  METRICS_CHOICES = [
-     "lpips",
-     "psnr",
-     "ssim",
-     "mse",
-     "fid",
-     "all",
+     "lpips",  # img vs img
+     "psnr",  # img vs img
+     "ssim",  # img vs img
+     "mse",  # img vs img
+     "fid",  # img vs img
+     "all",  # img vs img
+     "clip_score",  # img vs prompt
+     "image_reward",  # img vs prompt
  ]
 
 
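With the new annotations, every helper advertises the same (value, count) contract: a float average plus the number of files compared, or (None, None) when nothing matched. A hypothetical call (paths illustrative):

    psnr, n = compute_psnr("./true_images", "./test_images")
    if psnr is not None:
        print(f"PSNR over {n} image pairs: {psnr:.2f} dB")
    else:
        print("no comparable image pairs found")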
@@ -405,6 +469,13 @@ def get_args():
          default=None,
          help="Path to ground truth image or Dir to ground truth images",
      )
+     parser.add_argument(
+         "--prompt-true",
+         "-p",
+         type=str,
+         default=None,
+         help="Path to ground truth prompt file for CLIP Score and Image Reward Score.",
+     )
      parser.add_argument(
          "--img-test",
          "-i2",
@@ -442,6 +513,13 @@ def get_args():
          default=None,
          help="Path to ref dir that contains ground truth images",
      )
+     parser.add_argument(
+         "--ref-prompt-true",
+         "-rp",
+         type=str,
+         default=None,
+         help="Path to ground truth prompt file for CLIP Score and Image Reward Score.",
+     )
 
      # Video 1 vs N pattern
      parser.add_argument(
@@ -495,10 +573,11 @@ def get_args():
          help="Path to additional perf log",
      )
      parser.add_argument(
-         "--perf-tag",
-         "-ptag",
+         "--perf-tags",
+         "-ptags",
+         nargs="+",
          type=str,
-         default=None,
+         default=[],
          help="Tags to parse perf time from perf log",
      )
      parser.add_argument(
@@ -508,6 +587,26 @@ def get_args():
          default=[],
          help="Extra tags to parse perf time from perf log",
      )
+     parser.add_argument(
+         "--psnr-type",
+         type=str,
+         default="custom",
+         choices=["custom", "skimage"],
+         help="The compute type of PSNR, [custom, skimage]",
+     )
+     parser.add_argument(
+         "--cal-speedup",
+         action="store_true",
+         default=False,
+         help="Calculate performance speedup.",
+     )
+     parser.add_argument(
+         "--gen-markdown-table",
+         "-table",
+         action="store_true",
+         default=False,
+         help="Generate performance markdown table",
+     )
      return parser.parse_args()
 
 
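Putting the new flags together, an invocation might look like the following; the `cache-dit-metrics-cli` entry-point name, the positional metrics argument, and `--summary` sit outside this diff and are assumptions on my part:

    cache-dit-metrics-cli all \
        -i1 ./true_images -i2 ./test_images \
        --psnr-type skimage \
        --perf-log ./perf.log --perf-tags time tflops \
        --cal-speedup --gen-markdown-table --summary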
@@ -516,16 +615,16 @@ def entrypoint():
      args = get_args()
      logger.debug(args)
 
+     if args.metrics in ["clip_score", "image_reward"]:
+         assert args.prompt_true is not None or args.ref_prompt_true is not None
+         assert args.img_test is not None or args.img_source_dir is not None
+
      if args.enable_verbose:
          global DISABLE_VERBOSE
          set_metrics_verbose(True)
          DISABLE_VERBOSE = not get_metrics_verbose()
 
-     if "all" in args.metrics or "fid" in args.metrics:
-         FID = FrechetInceptionDistance(
-             disable_tqdm=DISABLE_VERBOSE,
-             batch_size=args.fid_batch_size,
-         )
+     set_psnr_type(args.psnr_type)
 
      METRICS_META: dict[str, float] = {}
@@ -533,11 +632,11 @@ def entrypoint():
      def _run_metric(
          metric: str,
          img_true: str = None,
+         prompt_true: str = None,
          img_test: str = None,
          video_true: str = None,
          video_test: str = None,
      ) -> None:
-         nonlocal FID
          nonlocal METRICS_META
          metric = metric.lower()
          if img_true is not None and img_test is not None:
@@ -575,9 +674,39 @@ def entrypoint():
                  img_mse, n = compute_mse(img_true, img_test)
                  _logging_msg(img_mse, "mse", n)
              if metric == "fid" or metric == "all":
-                 img_fid, n = FID.compute_fid(img_true, img_test)
+                 img_fid, n = compute_fid(img_true, img_test)
                  _logging_msg(img_fid, "fid", n)
 
+         if prompt_true is not None and img_test is not None:
+             if any(
+                 (
+                     not os.path.exists(prompt_true),  # file
+                     not os.path.exists(img_test),  # dir
+                 )
+             ):
+                 return
+
+             # prompt_true is a file; img_test can be a file or dir
+             prompt_true_info = os.path.basename(prompt_true)
+             img_test_info = os.path.basename(img_test)
+
+             def _logging_msg(value: float, name, n: int):
+                 if value is None or n is None:
+                     return
+                 msg = (
+                     f"{prompt_true_info} vs {img_test_info}, "
+                     f"Num: {n}, {name.upper()}: {value:.5f}"
+                 )
+                 METRICS_META[msg] = value
+                 logger.info(msg)
+
+             if metric == "clip_score":
+                 clip_score, n = compute_clip_score(img_test, prompt_true)
+                 _logging_msg(clip_score, "clip_score", n)
+             if metric == "image_reward":
+                 image_reward, n = compute_reward_score(img_test, prompt_true)
+                 _logging_msg(image_reward, "image_reward", n)
+
          if video_true is not None and video_test is not None:
              if any(
                  (
@@ -614,7 +743,7 @@ def entrypoint():
                  video_mse, n = compute_video_mse(video_true, video_test)
                  _logging_msg(video_mse, "mse", n)
              if metric == "fid" or metric == "all":
-                 video_fid, n = FID.compute_video_fid(video_true, video_test)
+                 video_fid, n = compute_video_fid(video_true, video_test)
                  _logging_msg(video_fid, "fid", n)
 
      # run selected metrics
@@ -627,7 +756,18 @@ def entrypoint():
      def _is_video_1vsN_pattern() -> bool:
          return args.video_source_dir is not None and args.ref_video is not None
 
-     assert not all((_is_image_1vsN_pattern(), _is_video_1vsN_pattern()))
+     def _is_prompt_1vsN_pattern() -> bool:
+         return (
+             args.img_source_dir is not None and args.ref_prompt_true is not None
+         )
+
+     assert not all(
+         (
+             _is_image_1vsN_pattern(),
+             _is_video_1vsN_pattern(),
+             _is_prompt_1vsN_pattern(),
+         )
+     )
 
      if _is_image_1vsN_pattern():
          # Glob Image dirs
@@ -711,11 +851,42 @@ def entrypoint():
                      video_test=video_test,
                  )
 
+     elif _is_prompt_1vsN_pattern():
+         # Glob Image dirs
+         if not os.path.exists(args.img_source_dir):
+             logger.error(f"{args.img_source_dir} does not exist!")
+             return
+
+         directories = []
+         for item in os.listdir(args.img_source_dir):
+             item_path = os.path.join(args.img_source_dir, item)
+             if os.path.isdir(item_path):
+                 directories.append(item_path)
+
+         if len(directories) == 0:
+             return
+
+         directories = sorted(directories)
+         if not DISABLE_VERBOSE:
+             logger.info(
+                 f"Compare {args.ref_prompt_true} vs {directories}, "
+                 f"Num compares: {len(directories)}"
+             )
+
+         for metric in args.metrics:
+             for img_test_dir in directories:
+                 _run_metric(
+                     metric=metric,
+                     prompt_true=args.ref_prompt_true,
+                     img_test=img_test_dir,
+                 )
+
      else:
          for metric in args.metrics:
              _run_metric(
                  metric=metric,
                  img_true=args.img_true,
+                 prompt_true=args.prompt_true,
                  img_test=args.img_test,
                  video_true=args.video_true,
                  video_test=args.video_test,
@@ -724,7 +895,7 @@ def entrypoint():
      if args.summary:
 
          def _fetch_perf():
-             if args.perf_log is None or args.perf_tag is None:
+             if args.perf_log is None or len(args.perf_tags) == 0:
                  return []
              if not os.path.exists(args.perf_log):
                  return []
@@ -733,17 +904,20 @@ def entrypoint():
                  perf_lines = file.readlines()
              for line in perf_lines:
                  line = line.strip()
-                 if args.perf_tag.lower() in line.lower():
-                     if len(args.extra_perf_tags) == 0:
-                         perf_texts.append(line)
-                     else:
-                         has_all_extra_tag = True
-                         for ext_tag in args.extra_perf_tags:
-                             if ext_tag.lower() not in line.lower():
-                                 has_all_extra_tag = False
-                                 break
-                         if has_all_extra_tag:
+                 for perf_tag in args.perf_tags:
+                     if perf_tag.lower() in line.lower():
+                         if len(args.extra_perf_tags) == 0:
                              perf_texts.append(line)
+                             break
+                         else:
+                             has_all_extra_tag = True
+                             for ext_tag in args.extra_perf_tags:
+                                 if ext_tag.lower() not in line.lower():
+                                     has_all_extra_tag = False
+                                     break
+                             if has_all_extra_tag:
+                                 perf_texts.append(line)
+                                 break
              return perf_texts
 
          PERF_TEXTS: list[str] = _fetch_perf()
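The rewritten loop keeps a log line if any entry in `--perf-tags` matches it and, when `--extra-perf-tags` is given, all of those also appear. The same rule as a self-contained sketch (log lines invented):

    perf_tags = ["time", "tflops"]
    extra_perf_tags = ["u4"]
    lines = [
        "U4-Q1-C1 time: 12.31 tflops: 98.7",   # kept: matches a perf tag and contains 'u4'
        "U1-Q0-C0 time: 45.02 tflops: 391.2",  # dropped: missing 'u4'
    ]
    kept = [
        line
        for line in lines
        if any(tag in line.lower() for tag in perf_tags)
        and all(ext in line.lower() for ext in extra_perf_tags)
    ]
    print(kept)  # first line only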
@@ -770,8 +944,9 @@ def entrypoint():
              try:
                  if tag.lower() in METRICS_CHOICES:
                      return float(value_str)
-                 if args.perf_tag is not None:
-                     if tag.lower() == args.perf_tag.lower():
+                 if len(args.perf_tags) > 0:
+                     perf_tags = [tag.lower() for tag in args.perf_tags]
+                     if tag.lower() in perf_tags:
                          return float(value_str)
                  return int(value_str)
              except ValueError:
@@ -779,17 +954,37 @@ def entrypoint():
          def _parse_perf(
              compare_tag: str,
+             perf_tag: str,
          ) -> float | None:
              nonlocal PERF_TEXTS
-             perf_times = []
+             perf_values = []
              for line in PERF_TEXTS:
                  if compare_tag in line:
-                     perf_time = _parse_value(line, args.perf_tag)
-                     if perf_time is not None:
-                         perf_times.append(perf_time)
-             if len(perf_times) == 0:
+                     perf_value = _parse_value(line, perf_tag)
+                     if perf_value is not None:
+                         perf_values.append(perf_value)
+             if len(perf_values) == 0:
                  return None
-             return sum(perf_times) / len(perf_times)
+             return sum(perf_values) / len(perf_values)
+
+         def _ref_perf(
+             key: str,
+         ):
+             # U1-Q0-C0-NONE vs U4-Q1-C1-NONE
+             header = key.split(",")[0].strip()
+             reference_tag = None
+             if args.prompt_true is None:
+                 reference_tag = header.split("vs")[0].strip()  # U1-Q0-C0-NONE
+
+             if reference_tag is None:
+                 return []
+
+             ref_perf_values = []
+             for perf_tag in args.perf_tags:
+                 perf_value = _parse_perf(reference_tag, perf_tag)
+                 ref_perf_values.append(perf_value)
+
+             return ref_perf_values
 
          def _format_item(
              key: str,
@@ -802,40 +997,129 @@ def entrypoint():
              header = key.split(",")[0].strip()
              compare_tag = header.split("vs")[1].strip()  # U4-Q1-C1-NONE
              has_perf_texts = len(PERF_TEXTS) > 0
+
+             def _perf_msg(perf_tag: str):
+                 if "time" in perf_tag.lower():
+                     perf_msg = "Latency(s)"
+                 elif "tflops" in perf_tag.lower():
+                     perf_msg = "TFLOPs"
+                 elif "flops" in perf_tag.lower():
+                     perf_msg = "FLOPs"
+                 else:
+                     perf_msg = perf_tag.upper()
+                 return perf_msg
+
              format_str = ""
              # Num / Frames
+             perf_values = []
+             perf_msgs = []
              if n := _parse_value(key, "Num"):
                  if not has_perf_texts:
                      format_str = (
-                         f"{header:<{max_key_len}} Num: {n} "
+                         f"{header:<{max_key_len}}, Num: {n}, "
                          f"{metric.upper()}: {value:<7.4f}"
                      )
                  else:
-                     perf_time = _parse_perf(compare_tag)
-                     perf_time = f"{perf_time:<.2f}" if perf_time else None
                      format_str = (
-                         f"{header:<{max_key_len}} Num: {n} "
-                         f"{metric.upper()}: {value:<7.4f} "
-                         f"Perf: {perf_time}"
+                         f"{header:<{max_key_len}}, Num: {n}, "
+                         f"{metric.upper()}: {value:<7.4f}, "
                      )
+                     for perf_tag in args.perf_tags:
+                         perf_value = _parse_perf(compare_tag, perf_tag)
+                         perf_values.append(perf_value)
+
+                         perf_value = (
+                             f"{perf_value:<.2f}" if perf_value else None
+                         )
+                         perf_msg = _perf_msg(perf_tag)
+                         format_str += f"{perf_msg}: {perf_value}, "
+
+                         perf_msgs.append(perf_msg)
+
+                     if not args.cal_speedup:
+                         format_str = format_str.removesuffix(", ")
+
              elif n := _parse_value(key, "Frames"):
                  if not has_perf_texts:
                      format_str = (
-                         f"{header:<{max_key_len}} Frames: {n} "
+                         f"{header:<{max_key_len}}, Frames: {n}, "
                          f"{metric.upper()}: {value:<7.4f}"
                      )
                  else:
-                     perf_time = _parse_perf(compare_tag)
-                     perf_time = f"{perf_time:<.2f}" if perf_time else None
                      format_str = (
-                         f"{header:<{max_key_len}} Frames: {n} "
-                         f"{metric.upper()}: {value:<7.4f} "
-                         f"Perf: {perf_time}"
+                         f"{header:<{max_key_len}}, Frames: {n}, "
+                         f"{metric.upper()}: {value:<7.4f}, "
                      )
+                     for perf_tag in args.perf_tags:
+                         perf_value = _parse_perf(compare_tag, perf_tag)
+                         perf_values.append(perf_value)
+
+                         perf_value = (
+                             f"{perf_value:<.2f}" if perf_value else None
+                         )
+                         perf_msg = _perf_msg(perf_tag)
+                         format_str += f"{perf_msg}: {perf_value}, "
+                         perf_msgs.append(perf_msg)
+
+                     if not args.cal_speedup:
+                         format_str = format_str.removesuffix(", ")
              else:
                  raise ValueError("Num or Frames can not be NoneType.")
 
-             return format_str
+             return format_str, perf_values, perf_msgs
+
+         def _format_table(format_strs: List[str], metric: str):
+             if not format_strs:
+                 return ""
+
+             metric_upper = metric.upper()
+             all_headers = {"Config", metric_upper}
+             row_data = []
+
+             for line in format_strs:
+                 parts = [p.strip() for p in line.split(",")]
+
+                 config_part = parts[0].strip()
+                 if "vs" in config_part:
+                     config = config_part.split("vs", 1)[1].strip()
+                     if "_DBCACHE_" in config:
+                         config = config.split("_DBCACHE_", 1)[1].strip()
+                 else:
+                     config = config_part
+
+                 metric_value = next(
+                     p.split(":")[1].strip()
+                     for p in parts
+                     if p.startswith(metric_upper)
+                 )
+
+                 perf_data = {}
+                 for part in parts:
+                     if part.startswith(("Num:", "Frames:", metric_upper)):
+                         continue
+                     if ":" in part:
+                         key, value = part.split(":", 1)
+                         key = key.strip()
+                         value = value.strip()
+                         perf_data[key] = value
+                         all_headers.add(key)
+
+                 row_data.append(
+                     {"Config": config, metric_upper: metric_value, **perf_data}
+                 )
+
+             sorted_headers = ["Config", metric_upper] + sorted(
+                 [h for h in all_headers if h not in ["Config", metric_upper]]
+             )
+
+             table = "| " + " | ".join(sorted_headers) + " |\n"
+             table += "| " + " | ".join(["---"] * len(sorted_headers)) + " |\n"
+
+             for row in row_data:
+                 row_values = [row.get(header, "") for header in sorted_headers]
+                 table += "| " + " | ".join(row_values) + " |\n"
+
+             return table.strip()
 
          selected_metrics = args.metrics
          if "all" in selected_metrics:
@@ -848,7 +1132,17 @@ def entrypoint():
                  if metric.upper() in key or metric.lower() in key:
                      selected_items[key] = METRICS_META[key]
 
-             reverse = True if metric.lower() in ["psnr", "ssim"] else False
+             reverse = (
+                 True
+                 if metric.lower()
+                 in [
+                     "psnr",
+                     "ssim",
+                     "clip_score",
+                     "image_reward",
+                 ]
+                 else False
+             )
              sorted_items = sorted(
                  selected_items.items(), key=lambda x: x[1], reverse=reverse
              )
@@ -857,12 +1151,65 @@ def entrypoint():
              ]
              max_key_len = max(len(key) for key in selected_keys)
 
+             ref_perf_values = _ref_perf(key=selected_keys[0])
+             max_perf_values: List[float] = []
+
+             if ref_perf_values and None not in ref_perf_values:
+                 max_perf_values = ref_perf_values.copy()
+
+             for key, value in sorted_items:
+                 format_str, perf_values, perf_msgs = _format_item(
+                     key, metric, value, max_key_len
+                 )
+                 # skip 'None' msg but not 'NONE', 'NONE' means w/o cache
+                 if "None" in format_str:
+                     continue
+
+                 if (
+                     not perf_values
+                     or None in perf_values
+                     or not perf_msgs
+                     or not args.cal_speedup
+                 ):
+                     continue
+
+                 if not max_perf_values:
+                     max_perf_values = perf_values
+                 else:
+                     for i in range(len(max_perf_values)):
+                         max_perf_values[i] = max(
+                             max_perf_values[i], perf_values[i]
+                         )
+
              format_strs = []
              for key, value in sorted_items:
-                 format_strs.append(
-                     _format_item(key, metric, value, max_key_len)
+                 format_str, perf_values, perf_msgs = _format_item(
+                     key, metric, value, max_key_len
                  )
 
+                 # skip 'None' msg but not 'NONE', 'NONE' means w/o cache
+                 if "None" in format_str:
+                     continue
+
+                 if (
+                     not perf_values
+                     or None in perf_values
+                     or not perf_msgs
+                     or not max_perf_values
+                     or not args.cal_speedup
+                 ):
+                     format_strs.append(format_str)
+                     continue
+
+                 for perf_value, perf_msg, max_perf_value in zip(
+                     perf_values, perf_msgs, max_perf_values
+                 ):
+                     perf_speedup = max_perf_value / perf_value
+                     format_str += f"{perf_msg}(↑): {perf_speedup:<.2f}, "
+
+                 format_str = format_str.removesuffix(", ")
+                 format_strs.append(format_str)
+
              format_len = max(len(format_str) for format_str in format_strs)
 
              res_len = format_len - len(f"Summary: {metric.upper()}")
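The speedup column divides the reference perf value (seeded from the no-cache row when available, otherwise the per-column maximum) by the row's own value: with an invented 45.02 s baseline and a 12.31 s cached run, `Latency(s)(↑)` reports 45.02 / 12.31 ≈ 3.66.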
@@ -877,6 +1224,15 @@ def entrypoint():
                  print(format_str)
              print("-" * format_len)
 
+             if args.gen_markdown_table:
+                 table = _format_table(format_strs, metric)
+                 table = table.replace("Latency(s)(↑)", "SpeedUp(↑)")
+                 table = table.replace("TFLOPs(↑)", "SpeedUp(↑)")
+                 table = table.replace("FLOPs(↑)", "SpeedUp(↑)")
+                 print("-" * format_len)
+                 print(f"{table}")
+                 print("-" * format_len)
+
 
  if __name__ == "__main__":
      entrypoint()