cache-dit 0.2.34__py3-none-any.whl → 0.2.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,21 +5,26 @@ import argparse
 import numpy as np
 from tqdm import tqdm
 from functools import partial
+from typing import Callable, Union, Tuple, List
 from skimage.metrics import mean_squared_error
 from skimage.metrics import peak_signal_noise_ratio
 from skimage.metrics import structural_similarity
-from cache_dit.metrics.fid import FrechetInceptionDistance
 from cache_dit.metrics.config import set_metrics_verbose
 from cache_dit.metrics.config import get_metrics_verbose
 from cache_dit.metrics.config import _IMAGE_EXTENSIONS
 from cache_dit.metrics.config import _VIDEO_EXTENSIONS
 from cache_dit.logger import init_logger
+from cache_dit.metrics.fid import compute_fid
+from cache_dit.metrics.fid import compute_video_fid
 from cache_dit.metrics.lpips import compute_lpips_img
+from cache_dit.metrics.clip_score import compute_clip_score
+from cache_dit.metrics.image_reward import compute_reward_score
 
 logger = init_logger(__name__)
 
 
 DISABLE_VERBOSE = not get_metrics_verbose()
+PSNR_TYPE = "custom"
 
 
 def compute_lpips_file(
@@ -51,6 +56,35 @@ def compute_lpips_file(
 )
 
 
+def set_psnr_type(psnr_type: str):
+    global PSNR_TYPE
+    # Validate before mutating the global so a bad value is never stored.
+    assert psnr_type in ["skimage", "custom"]
+    PSNR_TYPE = psnr_type
+
+
+def get_psnr_type():
+    global PSNR_TYPE
+    return PSNR_TYPE
+
+
+def calculate_psnr(
+    image_true: np.ndarray,
+    image_test: np.ndarray,
+):
+    """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+    Args:
+        image_true (ndarray): Images with range [0, 255].
+        image_test (ndarray): Images with range [0, 255].
+    """
+    # Cast to float first: uint8 arrays (e.g. from cv2.imread) would wrap
+    # around in the subtraction and corrupt the MSE.
+    mse = np.mean(
+        (image_true.astype(np.float64) - image_test.astype(np.float64)) ** 2
+    )
+    if mse == 0:
+        return float("inf")
+    return 20 * np.log10(255.0 / np.sqrt(mse))
+
+
 def compute_psnr_file(
     image_true: np.ndarray | str,
     image_test: np.ndarray | str,
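Note: the custom PSNR above is numerically identical to skimage's peak_signal_noise_ratio with data_range=255, since 10·log10(255²/mse) = 20·log10(255/√mse). A quick check with synthetic arrays (not part of the package):

    import numpy as np
    from skimage.metrics import peak_signal_noise_ratio

    rng = np.random.default_rng(0)
    a = rng.integers(0, 256, (64, 64, 3)).astype(np.float64)
    b = np.clip(a + rng.normal(0.0, 5.0, a.shape), 0, 255)

    mse = np.mean((a - b) ** 2)
    custom = 20 * np.log10(255.0 / np.sqrt(mse))  # formula used above
    reference = peak_signal_noise_ratio(a, b, data_range=255)
    assert abs(custom - reference) < 1e-9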
@@ -64,10 +98,13 @@ def compute_psnr_file(
         image_true = cv2.imread(image_true)
     if isinstance(image_test, str):
         image_test = cv2.imread(image_test)
-    return peak_signal_noise_ratio(
-        image_true,
-        image_test,
-    )
+    if get_psnr_type() == "skimage":
+        return peak_signal_noise_ratio(
+            image_true,
+            image_test,
+        )
+    else:
+        return calculate_psnr(image_true, image_test)
 
 
 def compute_mse_file(
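The PSNR backend is process-global state, so it can be flipped between calls. A usage sketch; the cache_dit.metrics.metrics import path is an assumption, only the function names appear in this diff:

    # Hypothetical import path -- not shown in this diff.
    from cache_dit.metrics.metrics import set_psnr_type, compute_psnr_file

    set_psnr_type("skimage")  # route through skimage.peak_signal_noise_ratio
    p_sk = compute_psnr_file("ref.png", "gen.png")

    set_psnr_type("custom")   # route through calculate_psnr (the default)
    p_cu = compute_psnr_file("ref.png", "gen.png")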
@@ -114,7 +151,7 @@ def compute_dir_metric(
     image_true_dir: np.ndarray | str,
     image_test_dir: np.ndarray | str,
     compute_file_func: callable = compute_psnr_file,
-) -> float:
+) -> Union[Tuple[float, int], Tuple[None, None]]:
     # Image
     if isinstance(image_true_dir, np.ndarray) or isinstance(
         image_test_dir, np.ndarray
@@ -235,7 +272,7 @@ def compute_video_metric(
     video_true: str,
     video_test: str,
     compute_frame_func: callable = compute_psnr_file,
-) -> float:
+) -> Union[Tuple[float, int], Tuple[None, None]]:
     """
     video_true = "video_true.mp4"
     video_test = "video_test.mp4"
@@ -335,51 +372,69 @@ def compute_video_metric(
     return None, None
 
 
-compute_lpips = partial(
-    compute_dir_metric,
-    compute_file_func=compute_lpips_file,
+compute_lpips: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
+    partial(
+        compute_dir_metric,
+        compute_file_func=compute_lpips_file,
+    )
 )
 
-compute_psnr = partial(
-    compute_dir_metric,
-    compute_file_func=compute_psnr_file,
+compute_psnr: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
+    partial(
+        compute_dir_metric,
+        compute_file_func=compute_psnr_file,
+    )
 )
 
-compute_ssim = partial(
-    compute_dir_metric,
-    compute_file_func=compute_ssim_file,
+compute_ssim: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
+    partial(
+        compute_dir_metric,
+        compute_file_func=compute_ssim_file,
+    )
 )
 
-compute_mse = partial(
-    compute_dir_metric,
-    compute_file_func=compute_mse_file,
+compute_mse: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
+    partial(
+        compute_dir_metric,
+        compute_file_func=compute_mse_file,
+    )
 )
 
-compute_video_lpips = partial(
+compute_video_lpips: Callable[
+    ..., Union[Tuple[float, int], Tuple[None, None]]
+] = partial(
     compute_video_metric,
     compute_frame_func=compute_lpips_file,
 )
-compute_video_psnr = partial(
+compute_video_psnr: Callable[
+    ..., Union[Tuple[float, int], Tuple[None, None]]
+] = partial(
     compute_video_metric,
     compute_frame_func=compute_psnr_file,
 )
-compute_video_ssim = partial(
+compute_video_ssim: Callable[
+    ..., Union[Tuple[float, int], Tuple[None, None]]
+] = partial(
     compute_video_metric,
     compute_frame_func=compute_ssim_file,
 )
-compute_video_mse = partial(
+compute_video_mse: Callable[
+    ..., Union[Tuple[float, int], Tuple[None, None]]
+] = partial(
     compute_video_metric,
     compute_frame_func=compute_mse_file,
 )
 
 
 METRICS_CHOICES = [
-    "lpips",
-    "psnr",
-    "ssim",
-    "mse",
-    "fid",
-    "all",
+    "lpips",  # img vs img
+    "psnr",  # img vs img
+    "ssim",  # img vs img
+    "mse",  # img vs img
+    "fid",  # img vs img
+    "all",  # img vs img
+    "clip_score",  # img vs prompt
+    "image_reward",  # img vs prompt
 ]
 
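With the annotations above, every compute_* entry point shares the (value, num_compared) return convention, where (None, None) means nothing comparable was found. A sketch with hypothetical directory names:

    # true_dir/ and test_dir/ are placeholders for two folders of frames.
    psnr, n = compute_psnr("true_dir/", "test_dir/")
    if psnr is not None:
        print(f"PSNR over {n} image pairs: {psnr:.4f}")
    else:
        print("no matching image pairs found")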
 
@@ -405,6 +460,13 @@ def get_args():
         default=None,
         help="Path to ground truth image or Dir to ground truth images",
     )
+    parser.add_argument(
+        "--prompt-true",
+        "-p",
+        type=str,
+        default=None,
+        help="Path to ground truth prompt file for CLIP Score and Image Reward Score.",
+    )
     parser.add_argument(
         "--img-test",
         "-i2",
@@ -442,6 +504,13 @@ def get_args():
         default=None,
         help="Path to ref dir that contains ground truth images",
     )
+    parser.add_argument(
+        "--ref-prompt-true",
+        "-rp",
+        type=str,
+        default=None,
+        help="Path to ground truth prompt file for CLIP Score and Image Reward Score.",
+    )
 
     # Video 1 vs N pattern
     parser.add_argument(
@@ -495,10 +564,11 @@ def get_args():
         help="Path to additional perf log",
     )
     parser.add_argument(
-        "--perf-tag",
-        "-ptag",
+        "--perf-tags",
+        "-ptags",
+        nargs="+",
         type=str,
-        default=None,
+        default=[],
         help="Tags to parse perf values from perf log",
     )
     parser.add_argument(
@@ -508,6 +578,26 @@ def get_args():
         default=[],
         help="Extra tags to parse perf time from perf log",
     )
+    parser.add_argument(
+        "--psnr-type",
+        type=str,
+        default="custom",
+        choices=["custom", "skimage"],
+        help="PSNR computation backend: custom or skimage",
+    )
+    parser.add_argument(
+        "--cal-speedup",
+        action="store_true",
+        default=False,
+        help="Calculate performance speedup.",
+    )
+    parser.add_argument(
+        "--gen-markdown-table",
+        "-table",
+        action="store_true",
+        default=False,
+        help="Generate performance markdown table",
+    )
     return parser.parse_args()
 
 
@@ -516,16 +606,16 @@ def entrypoint():
     args = get_args()
     logger.debug(args)
 
+    # NOTE: args.metrics is a list, so membership is checked per item.
+    if any(m in ["clip_score", "image_reward"] for m in args.metrics):
+        assert args.prompt_true is not None or args.ref_prompt_true is not None
+        assert args.img_test is not None or args.img_source_dir is not None
+
     if args.enable_verbose:
         global DISABLE_VERBOSE
         set_metrics_verbose(True)
         DISABLE_VERBOSE = not get_metrics_verbose()
 
-    if "all" in args.metrics or "fid" in args.metrics:
-        FID = FrechetInceptionDistance(
-            disable_tqdm=DISABLE_VERBOSE,
-            batch_size=args.fid_batch_size,
-        )
+    set_psnr_type(args.psnr_type)
 
     METRICS_META: dict[str, float] = {}
@@ -533,11 +623,11 @@ def entrypoint():
     def _run_metric(
         metric: str,
         img_true: str = None,
+        prompt_true: str = None,
         img_test: str = None,
         video_true: str = None,
         video_test: str = None,
     ) -> None:
-        nonlocal FID
         nonlocal METRICS_META
         metric = metric.lower()
         if img_true is not None and img_test is not None:
@@ -575,9 +665,39 @@ def entrypoint():
                 img_mse, n = compute_mse(img_true, img_test)
                 _logging_msg(img_mse, "mse", n)
             if metric == "fid" or metric == "all":
-                img_fid, n = FID.compute_fid(img_true, img_test)
+                img_fid, n = compute_fid(img_true, img_test)
                 _logging_msg(img_fid, "fid", n)
 
+        if prompt_true is not None and img_test is not None:
+            if any(
+                (
+                    not os.path.exists(prompt_true),  # file
+                    not os.path.exists(img_test),  # dir
+                )
+            ):
+                return
+
+            # prompt_true is a prompt file; img_test can be a file or a dir
+            prompt_true_info = os.path.basename(prompt_true)
+            img_test_info = os.path.basename(img_test)
+
+            def _logging_msg(value: float, name, n: int):
+                if value is None or n is None:
+                    return
+                msg = (
+                    f"{prompt_true_info} vs {img_test_info}, "
+                    f"Num: {n}, {name.upper()}: {value:.5f}"
+                )
+                METRICS_META[msg] = value
+                logger.info(msg)
+
+            if metric == "clip_score":
+                clip_score, n = compute_clip_score(img_test, prompt_true)
+                _logging_msg(clip_score, "clip_score", n)
+            if metric == "image_reward":
+                image_reward, n = compute_reward_score(img_test, prompt_true)
+                _logging_msg(image_reward, "image_reward", n)
+
         if video_true is not None and video_test is not None:
             if any(
                 (
@@ -614,7 +734,7 @@ def entrypoint():
                 video_mse, n = compute_video_mse(video_true, video_test)
                 _logging_msg(video_mse, "mse", n)
             if metric == "fid" or metric == "all":
-                video_fid, n = FID.compute_video_fid(video_true, video_test)
+                video_fid, n = compute_video_fid(video_true, video_test)
                 _logging_msg(video_fid, "fid", n)
 
     # run selected metrics
@@ -627,7 +747,18 @@ def entrypoint():
     def _is_video_1vsN_pattern() -> bool:
         return args.video_source_dir is not None and args.ref_video is not None
 
-    assert not all((_is_image_1vsN_pattern(), _is_video_1vsN_pattern()))
+    def _is_prompt_1vsN_pattern() -> bool:
+        return (
+            args.img_source_dir is not None and args.ref_prompt_true is not None
+        )
+
+    assert not all(
+        (
+            _is_image_1vsN_pattern(),
+            _is_video_1vsN_pattern(),
+            _is_prompt_1vsN_pattern(),
+        )
+    )
 
     if _is_image_1vsN_pattern():
         # Glob Image dirs
@@ -711,11 +842,42 @@ def entrypoint():
                     video_test=video_test,
                 )
 
+    elif _is_prompt_1vsN_pattern():
+        # Glob image dirs
+        if not os.path.exists(args.img_source_dir):
+            logger.error(f"{args.img_source_dir} does not exist!")
+            return
+
+        directories = []
+        for item in os.listdir(args.img_source_dir):
+            item_path = os.path.join(args.img_source_dir, item)
+            if os.path.isdir(item_path):
+                directories.append(item_path)
+
+        if len(directories) == 0:
+            return
+
+        directories = sorted(directories)
+        if not DISABLE_VERBOSE:
+            logger.info(
+                f"Compare {args.ref_prompt_true} vs {directories}, "
+                f"Num compares: {len(directories)}"
+            )
+
+        for metric in args.metrics:
+            for img_test_dir in directories:
+                _run_metric(
+                    metric=metric,
+                    prompt_true=args.ref_prompt_true,
+                    img_test=img_test_dir,
+                )
+
     else:
         for metric in args.metrics:
             _run_metric(
                 metric=metric,
                 img_true=args.img_true,
+                prompt_true=args.prompt_true,
                 img_test=args.img_test,
                 video_true=args.video_true,
                 video_test=args.video_test,
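A prompt 1-vs-N run scores every subdirectory of generated images against a single prompt file, reusing -rp from above. A hypothetical invocation (entry-point name and --img-source-dir spelling assumed):

    cache-dit-metrics-cli clip_score --img-source-dir ./generated -rp prompts.txt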
@@ -724,7 +886,7 @@ def entrypoint():
     if args.summary:
 
         def _fetch_perf():
-            if args.perf_log is None or args.perf_tag is None:
+            if args.perf_log is None or len(args.perf_tags) == 0:
                 return []
             if not os.path.exists(args.perf_log):
                 return []
@@ -733,17 +895,20 @@ def entrypoint():
                 perf_lines = file.readlines()
             for line in perf_lines:
                 line = line.strip()
-                if args.perf_tag.lower() in line.lower():
-                    if len(args.extra_perf_tags) == 0:
-                        perf_texts.append(line)
-                    else:
-                        has_all_extra_tag = True
-                        for ext_tag in args.extra_perf_tags:
-                            if ext_tag.lower() not in line.lower():
-                                has_all_extra_tag = False
-                                break
-                        if has_all_extra_tag:
-                            perf_texts.append(line)
+                for perf_tag in args.perf_tags:
+                    if perf_tag.lower() in line.lower():
+                        if len(args.extra_perf_tags) == 0:
+                            perf_texts.append(line)
+                            break
+                        else:
+                            has_all_extra_tag = True
+                            for ext_tag in args.extra_perf_tags:
+                                if ext_tag.lower() not in line.lower():
+                                    has_all_extra_tag = False
+                                    break
+                            if has_all_extra_tag:
+                                perf_texts.append(line)
+                                break
             return perf_texts
 
         PERF_TEXTS: list[str] = _fetch_perf()
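In short, a log line is kept when it contains at least one of --perf-tags and all of --extra-perf-tags. A sketch of the rule with invented tags and an invented log line:

    perf_tags = ["time", "tflops"]   # any one must match ...
    extra_perf_tags = ["steps:28"]   # ... and all of these must match
    line = "U4-Q1-C1-NONE | time: 12.50 | steps:28"
    keep = any(t in line.lower() for t in perf_tags) and all(
        e in line.lower() for e in extra_perf_tags
    )
    assert keep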
@@ -770,8 +935,9 @@ def entrypoint():
             try:
                 if tag.lower() in METRICS_CHOICES:
                     return float(value_str)
-                if args.perf_tag is not None:
-                    if tag.lower() == args.perf_tag.lower():
+                if len(args.perf_tags) > 0:
+                    perf_tags = [t.lower() for t in args.perf_tags]
+                    if tag.lower() in perf_tags:
                         return float(value_str)
                 return int(value_str)
             except ValueError:
@@ -779,17 +945,37 @@ def entrypoint():
         def _parse_perf(
             compare_tag: str,
+            perf_tag: str,
         ) -> float | None:
             nonlocal PERF_TEXTS
-            perf_times = []
+            perf_values = []
             for line in PERF_TEXTS:
                 if compare_tag in line:
-                    perf_time = _parse_value(line, args.perf_tag)
-                    if perf_time is not None:
-                        perf_times.append(perf_time)
-            if len(perf_times) == 0:
+                    perf_value = _parse_value(line, perf_tag)
+                    if perf_value is not None:
+                        perf_values.append(perf_value)
+            if len(perf_values) == 0:
                 return None
-            return sum(perf_times) / len(perf_times)
+            return sum(perf_values) / len(perf_values)
+
+        def _ref_perf(
+            key: str,
+        ):
+            # e.g. "U1-Q0-C0-NONE vs U4-Q1-C1-NONE"
+            header = key.split(",")[0].strip()
+            reference_tag = None
+            if args.prompt_true is None:
+                reference_tag = header.split("vs")[0].strip()  # U1-Q0-C0-NONE
+
+            if reference_tag is None:
+                return []
+
+            ref_perf_values = []
+            for perf_tag in args.perf_tags:
+                perf_value = _parse_perf(reference_tag, perf_tag)
+                ref_perf_values.append(perf_value)
+
+            return ref_perf_values
 
         def _format_item(
             key: str,
@@ -802,40 +988,129 @@ def entrypoint():
             header = key.split(",")[0].strip()
             compare_tag = header.split("vs")[1].strip()  # U4-Q1-C1-NONE
             has_perf_texts = len(PERF_TEXTS) > 0
+
+            def _perf_msg(perf_tag: str):
+                if "time" in perf_tag.lower():
+                    perf_msg = "Latency(s)"
+                elif "tflops" in perf_tag.lower():
+                    perf_msg = "TFLOPs"
+                elif "flops" in perf_tag.lower():
+                    perf_msg = "FLOPs"
+                else:
+                    perf_msg = perf_tag.upper()
+                return perf_msg
+
             format_str = ""
             # Num / Frames
+            perf_values = []
+            perf_msgs = []
             if n := _parse_value(key, "Num"):
                 if not has_perf_texts:
                     format_str = (
-                        f"{header:<{max_key_len}} Num: {n} "
+                        f"{header:<{max_key_len}}, Num: {n}, "
                         f"{metric.upper()}: {value:<7.4f}"
                     )
                 else:
-                    perf_time = _parse_perf(compare_tag)
-                    perf_time = f"{perf_time:<.2f}" if perf_time else None
                     format_str = (
-                        f"{header:<{max_key_len}} Num: {n} "
-                        f"{metric.upper()}: {value:<7.4f} "
-                        f"Perf: {perf_time}"
+                        f"{header:<{max_key_len}}, Num: {n}, "
+                        f"{metric.upper()}: {value:<7.4f}, "
                     )
+                    for perf_tag in args.perf_tags:
+                        perf_value = _parse_perf(compare_tag, perf_tag)
+                        perf_values.append(perf_value)
+
+                        perf_value = (
+                            f"{perf_value:<.2f}" if perf_value else None
+                        )
+                        perf_msg = _perf_msg(perf_tag)
+                        format_str += f"{perf_msg}: {perf_value}, "
+
+                        perf_msgs.append(perf_msg)
+
+                    if not args.cal_speedup:
+                        format_str = format_str.removesuffix(", ")
+
             elif n := _parse_value(key, "Frames"):
                 if not has_perf_texts:
                     format_str = (
-                        f"{header:<{max_key_len}} Frames: {n} "
+                        f"{header:<{max_key_len}}, Frames: {n}, "
                         f"{metric.upper()}: {value:<7.4f}"
                     )
                 else:
-                    perf_time = _parse_perf(compare_tag)
-                    perf_time = f"{perf_time:<.2f}" if perf_time else None
                     format_str = (
-                        f"{header:<{max_key_len}} Frames: {n} "
-                        f"{metric.upper()}: {value:<7.4f} "
-                        f"Perf: {perf_time}"
+                        f"{header:<{max_key_len}}, Frames: {n}, "
+                        f"{metric.upper()}: {value:<7.4f}, "
                     )
+                    for perf_tag in args.perf_tags:
+                        perf_value = _parse_perf(compare_tag, perf_tag)
+                        perf_values.append(perf_value)
+
+                        perf_value = (
+                            f"{perf_value:<.2f}" if perf_value else None
+                        )
+                        perf_msg = _perf_msg(perf_tag)
+                        format_str += f"{perf_msg}: {perf_value}, "
+                        perf_msgs.append(perf_msg)
+
+                    if not args.cal_speedup:
+                        format_str = format_str.removesuffix(", ")
             else:
                 raise ValueError("Num or Frames can not be NoneType.")
 
-            return format_str
+            return format_str, perf_values, perf_msgs
+
+        def _format_table(format_strs: List[str], metric: str):
+            if not format_strs:
+                return ""
+
+            metric_upper = metric.upper()
+            all_headers = {"Config", metric_upper}
+            row_data = []
+
+            for line in format_strs:
+                parts = [p.strip() for p in line.split(",")]
+
+                config_part = parts[0].strip()
+                if "vs" in config_part:
+                    config = config_part.split("vs", 1)[1].strip()
+                    if "_DBCACHE_" in config:
+                        config = config.split("_DBCACHE_", 1)[1].strip()
+                else:
+                    config = config_part
+
+                metric_value = next(
+                    p.split(":")[1].strip()
+                    for p in parts
+                    if p.startswith(metric_upper)
+                )
+
+                perf_data = {}
+                for part in parts:
+                    if part.startswith(("Num:", "Frames:", metric_upper)):
+                        continue
+                    if ":" in part:
+                        key, value = part.split(":", 1)
+                        key = key.strip()
+                        value = value.strip()
+                        perf_data[key] = value
+                        all_headers.add(key)
+
+                row_data.append(
+                    {"Config": config, metric_upper: metric_value, **perf_data}
+                )
+
+            sorted_headers = ["Config", metric_upper] + sorted(
+                [h for h in all_headers if h not in ["Config", metric_upper]]
+            )
+
+            table = "| " + " | ".join(sorted_headers) + " |\n"
+            table += "| " + " | ".join(["---"] * len(sorted_headers)) + " |\n"
+
+            for row in row_data:
+                row_values = [row.get(header, "") for header in sorted_headers]
+                table += "| " + " | ".join(row_values) + " |\n"
+
+            return table.strip()
 
         selected_metrics = args.metrics
         if "all" in selected_metrics:
@@ -848,7 +1123,17 @@ def entrypoint():
                 if metric.upper() in key or metric.lower() in key:
                     selected_items[key] = METRICS_META[key]
 
-            reverse = True if metric.lower() in ["psnr", "ssim"] else False
+            reverse = (
+                True
+                if metric.lower()
+                in [
+                    "psnr",
+                    "ssim",
+                    "clip_score",
+                    "image_reward",
+                ]
+                else False
+            )
             sorted_items = sorted(
                 selected_items.items(), key=lambda x: x[1], reverse=reverse
             )
@@ -857,12 +1142,65 @@ def entrypoint():
             ]
             max_key_len = max(len(key) for key in selected_keys)
 
+            ref_perf_values = _ref_perf(key=selected_keys[0])
+            max_perf_values: List[float] = []
+
+            if ref_perf_values and None not in ref_perf_values:
+                max_perf_values = ref_perf_values.copy()
+
+            # First pass: collect per-tag maxima as the speedup baseline.
+            for key, value in sorted_items:
+                format_str, perf_values, perf_msgs = _format_item(
+                    key, metric, value, max_key_len
+                )
+                # skip 'None' msg but not 'NONE', 'NONE' means w/o cache
+                if "None" in format_str:
+                    continue
+
+                if (
+                    not perf_values
+                    or None in perf_values
+                    or not perf_msgs
+                    or not args.cal_speedup
+                ):
+                    continue
+
+                if not max_perf_values:
+                    max_perf_values = perf_values
+                else:
+                    for i in range(len(max_perf_values)):
+                        max_perf_values[i] = max(
+                            max_perf_values[i], perf_values[i]
+                        )
+
             format_strs = []
             for key, value in sorted_items:
-                format_strs.append(
-                    _format_item(key, metric, value, max_key_len)
+                format_str, perf_values, perf_msgs = _format_item(
+                    key, metric, value, max_key_len
                 )
 
+                # skip 'None' msg but not 'NONE', 'NONE' means w/o cache
+                if "None" in format_str:
+                    continue
+
+                if (
+                    not perf_values
+                    or None in perf_values
+                    or not perf_msgs
+                    or not max_perf_values
+                    or not args.cal_speedup
+                ):
+                    format_strs.append(format_str)
+                    continue
+
+                for perf_value, perf_msg, max_perf_value in zip(
+                    perf_values, perf_msgs, max_perf_values
+                ):
+                    perf_speedup = max_perf_value / perf_value
+                    format_str += f"{perf_msg}(↑): {perf_speedup:<.2f}, "
+
+                format_str = format_str.removesuffix(", ")
+                format_strs.append(format_str)
+
             format_len = max(len(format_str) for format_str in format_strs)
 
             res_len = format_len - len(f"Summary: {metric.upper()}")
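Note on the speedup columns: with --cal-speedup, each perf value is divided into the per-tag maximum (the no-cache reference when one is found), so a 20.00 s baseline against a 12.50 s cached run prints Latency(s)(↑): 1.60.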
@@ -877,6 +1215,12 @@ def entrypoint():
             print(format_str)
             print("-" * format_len)
 
+            if args.gen_markdown_table:
+                table = _format_table(format_strs, metric)
+                print("-" * format_len)
+                print(f"{table}")
+                print("-" * format_len)
+
 
 if __name__ == "__main__":
     entrypoint()