cache-dit 0.2.33__py3-none-any.whl → 0.2.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +5 -3
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +14 -20
- cache_dit/cache_factory/block_adapters/block_adapters.py +46 -2
- cache_dit/cache_factory/block_adapters/block_registers.py +3 -2
- cache_dit/cache_factory/cache_adapters.py +8 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +11 -11
- cache_dit/cache_factory/cache_contexts/cache_manager.py +5 -5
- cache_dit/cache_factory/cache_contexts/taylorseer.py +12 -6
- cache_dit/cache_factory/cache_interface.py +9 -9
- cache_dit/cache_factory/patch_functors/__init__.py +1 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +142 -52
- cache_dit/cache_factory/patch_functors/functor_dit.py +130 -0
- cache_dit/metrics/clip_score.py +135 -0
- cache_dit/metrics/fid.py +42 -0
- cache_dit/metrics/image_reward.py +177 -0
- cache_dit/metrics/lpips.py +2 -14
- cache_dit/metrics/metrics.py +420 -76
- cache_dit/utils.py +15 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/METADATA +261 -52
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/RECORD +25 -22
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/top_level.txt +0 -0
cache_dit/metrics/metrics.py
CHANGED
|
@@ -5,21 +5,26 @@ import argparse
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
from tqdm import tqdm
|
|
7
7
|
from functools import partial
|
|
8
|
+
from typing import Callable, Union, Tuple, List
|
|
8
9
|
from skimage.metrics import mean_squared_error
|
|
9
10
|
from skimage.metrics import peak_signal_noise_ratio
|
|
10
11
|
from skimage.metrics import structural_similarity
|
|
11
|
-
from cache_dit.metrics.fid import FrechetInceptionDistance
|
|
12
12
|
from cache_dit.metrics.config import set_metrics_verbose
|
|
13
13
|
from cache_dit.metrics.config import get_metrics_verbose
|
|
14
14
|
from cache_dit.metrics.config import _IMAGE_EXTENSIONS
|
|
15
15
|
from cache_dit.metrics.config import _VIDEO_EXTENSIONS
|
|
16
16
|
from cache_dit.logger import init_logger
|
|
17
|
+
from cache_dit.metrics.fid import compute_fid
|
|
18
|
+
from cache_dit.metrics.fid import compute_video_fid
|
|
17
19
|
from cache_dit.metrics.lpips import compute_lpips_img
|
|
20
|
+
from cache_dit.metrics.clip_score import compute_clip_score
|
|
21
|
+
from cache_dit.metrics.image_reward import compute_reward_score
|
|
18
22
|
|
|
19
23
|
logger = init_logger(__name__)
|
|
20
24
|
|
|
21
25
|
|
|
22
26
|
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
27
|
+
PSNR_TYPE = "custom"
|
|
23
28
|
|
|
24
29
|
|
|
25
30
|
def compute_lpips_file(
|
|
@@ -51,6 +56,35 @@ def compute_lpips_file(
|
|
|
51
56
|
)
|
|
52
57
|
|
|
53
58
|
|
|
59
|
+
def set_psnr_type(psnr_type: str):
    """Select which PSNR implementation the module uses.

    Args:
        psnr_type: Either ``"skimage"`` or ``"custom"``.

    Raises:
        AssertionError: If ``psnr_type`` is not one of the supported values.
            The previously configured type is left untouched in that case.
    """
    global PSNR_TYPE
    # Validate *before* assigning so a rejected value never leaks into the
    # module-level state. The check is an explicit raise (not an ``assert``
    # statement) so it still runs under ``python -O``; AssertionError is kept
    # as the exception type for backward compatibility with existing callers.
    if psnr_type not in ("skimage", "custom"):
        raise AssertionError(
            f"psnr_type must be 'skimage' or 'custom', got {psnr_type!r}"
        )
    PSNR_TYPE = psnr_type
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def get_psnr_type():
    """Return the value currently stored in the module-level ``PSNR_TYPE``."""
    # Reading a module global needs no ``global`` declaration.
    return PSNR_TYPE
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def calculate_psnr(
    image_true: np.ndarray,
    image_test: np.ndarray,
):
    """Calculate PSNR (Peak Signal-to-Noise Ratio) with a peak value of 255.

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        image_true (ndarray): Images with range [0, 255].
        image_test (ndarray): Images with range [0, 255].

    Returns:
        ``20 * log10(255 / sqrt(mse))``, or ``float("inf")`` when the two
        inputs are identical (mse == 0).
    """
    # Promote to float64 before subtracting. With fixed-width unsigned
    # integer inputs (e.g. uint8 images loaded from disk) both the
    # subtraction and the squaring wrap around modulo 2**bits, which
    # silently produces a wrong MSE and therefore a wrong PSNR.
    diff = np.asarray(image_true, dtype=np.float64) - np.asarray(
        image_test, dtype=np.float64
    )
    mse = np.mean(diff**2)
    if mse == 0:
        return float("inf")
    return 20 * np.log10(255.0 / np.sqrt(mse))
|
|
86
|
+
|
|
87
|
+
|
|
54
88
|
def compute_psnr_file(
|
|
55
89
|
image_true: np.ndarray | str,
|
|
56
90
|
image_test: np.ndarray | str,
|
|
@@ -64,10 +98,13 @@ def compute_psnr_file(
|
|
|
64
98
|
image_true = cv2.imread(image_true)
|
|
65
99
|
if isinstance(image_test, str):
|
|
66
100
|
image_test = cv2.imread(image_test)
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
101
|
+
if get_psnr_type() == "skimage":
|
|
102
|
+
return peak_signal_noise_ratio(
|
|
103
|
+
image_true,
|
|
104
|
+
image_test,
|
|
105
|
+
)
|
|
106
|
+
else:
|
|
107
|
+
return calculate_psnr(image_true, image_test)
|
|
71
108
|
|
|
72
109
|
|
|
73
110
|
def compute_mse_file(
|
|
@@ -114,7 +151,7 @@ def compute_dir_metric(
|
|
|
114
151
|
image_true_dir: np.ndarray | str,
|
|
115
152
|
image_test_dir: np.ndarray | str,
|
|
116
153
|
compute_file_func: callable = compute_psnr_file,
|
|
117
|
-
) -> float:
|
|
154
|
+
) -> Union[Tuple[float, int], Tuple[None, None]]:
|
|
118
155
|
# Image
|
|
119
156
|
if isinstance(image_true_dir, np.ndarray) or isinstance(
|
|
120
157
|
image_test_dir, np.ndarray
|
|
@@ -235,7 +272,7 @@ def compute_video_metric(
|
|
|
235
272
|
video_true: str,
|
|
236
273
|
video_test: str,
|
|
237
274
|
compute_frame_func: callable = compute_psnr_file,
|
|
238
|
-
) -> float:
|
|
275
|
+
) -> Union[Tuple[float, int], Tuple[None, None]]:
|
|
239
276
|
"""
|
|
240
277
|
video_true = "video_true.mp4"
|
|
241
278
|
video_test = "video_test.mp4"
|
|
@@ -335,51 +372,69 @@ def compute_video_metric(
|
|
|
335
372
|
return None, None
|
|
336
373
|
|
|
337
374
|
|
|
338
|
-
compute_lpips =
|
|
339
|
-
|
|
340
|
-
|
|
375
|
+
compute_lpips: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
|
|
376
|
+
partial(
|
|
377
|
+
compute_dir_metric,
|
|
378
|
+
compute_file_func=compute_lpips_file,
|
|
379
|
+
)
|
|
341
380
|
)
|
|
342
381
|
|
|
343
|
-
compute_psnr =
|
|
344
|
-
|
|
345
|
-
|
|
382
|
+
compute_psnr: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
|
|
383
|
+
partial(
|
|
384
|
+
compute_dir_metric,
|
|
385
|
+
compute_file_func=compute_psnr_file,
|
|
386
|
+
)
|
|
346
387
|
)
|
|
347
388
|
|
|
348
|
-
compute_ssim =
|
|
349
|
-
|
|
350
|
-
|
|
389
|
+
compute_ssim: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
|
|
390
|
+
partial(
|
|
391
|
+
compute_dir_metric,
|
|
392
|
+
compute_file_func=compute_ssim_file,
|
|
393
|
+
)
|
|
351
394
|
)
|
|
352
395
|
|
|
353
|
-
compute_mse =
|
|
354
|
-
|
|
355
|
-
|
|
396
|
+
compute_mse: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
|
|
397
|
+
partial(
|
|
398
|
+
compute_dir_metric,
|
|
399
|
+
compute_file_func=compute_mse_file,
|
|
400
|
+
)
|
|
356
401
|
)
|
|
357
402
|
|
|
358
|
-
compute_video_lpips
|
|
403
|
+
compute_video_lpips: Callable[
|
|
404
|
+
..., Union[Tuple[float, int], Tuple[None, None]]
|
|
405
|
+
] = partial(
|
|
359
406
|
compute_video_metric,
|
|
360
407
|
compute_frame_func=compute_lpips_file,
|
|
361
408
|
)
|
|
362
|
-
compute_video_psnr
|
|
409
|
+
compute_video_psnr: Callable[
|
|
410
|
+
..., Union[Tuple[float, int], Tuple[None, None]]
|
|
411
|
+
] = partial(
|
|
363
412
|
compute_video_metric,
|
|
364
413
|
compute_frame_func=compute_psnr_file,
|
|
365
414
|
)
|
|
366
|
-
compute_video_ssim
|
|
415
|
+
compute_video_ssim: Callable[
|
|
416
|
+
..., Union[Tuple[float, int], Tuple[None, None]]
|
|
417
|
+
] = partial(
|
|
367
418
|
compute_video_metric,
|
|
368
419
|
compute_frame_func=compute_ssim_file,
|
|
369
420
|
)
|
|
370
|
-
compute_video_mse
|
|
421
|
+
compute_video_mse: Callable[
|
|
422
|
+
..., Union[Tuple[float, int], Tuple[None, None]]
|
|
423
|
+
] = partial(
|
|
371
424
|
compute_video_metric,
|
|
372
425
|
compute_frame_func=compute_mse_file,
|
|
373
426
|
)
|
|
374
427
|
|
|
375
428
|
|
|
376
429
|
METRICS_CHOICES = [
|
|
377
|
-
"lpips",
|
|
378
|
-
"psnr",
|
|
379
|
-
"ssim",
|
|
380
|
-
"mse",
|
|
381
|
-
"fid",
|
|
382
|
-
"all",
|
|
430
|
+
"lpips", # img vs img
|
|
431
|
+
"psnr", # img vs img
|
|
432
|
+
"ssim", # img vs img
|
|
433
|
+
"mse", # img vs img
|
|
434
|
+
"fid", # img vs img
|
|
435
|
+
"all", # img vs img
|
|
436
|
+
"clip_score", # img vs prompt
|
|
437
|
+
"image_reward", # img vs prompt
|
|
383
438
|
]
|
|
384
439
|
|
|
385
440
|
|
|
@@ -405,6 +460,13 @@ def get_args():
|
|
|
405
460
|
default=None,
|
|
406
461
|
help="Path to ground truth image or Dir to ground truth images",
|
|
407
462
|
)
|
|
463
|
+
parser.add_argument(
|
|
464
|
+
"--prompt-true",
|
|
465
|
+
"-p",
|
|
466
|
+
type=str,
|
|
467
|
+
default=None,
|
|
468
|
+
help="Path to ground truth prompt file for CLIP Score and Image Reward Score.",
|
|
469
|
+
)
|
|
408
470
|
parser.add_argument(
|
|
409
471
|
"--img-test",
|
|
410
472
|
"-i2",
|
|
@@ -442,6 +504,13 @@ def get_args():
|
|
|
442
504
|
default=None,
|
|
443
505
|
help="Path to ref dir that contains ground truth images",
|
|
444
506
|
)
|
|
507
|
+
parser.add_argument(
|
|
508
|
+
"--ref-prompt-true",
|
|
509
|
+
"-rp",
|
|
510
|
+
type=str,
|
|
511
|
+
default=None,
|
|
512
|
+
help="Path to ground truth prompt file for CLIP Score and Image Reward Score.",
|
|
513
|
+
)
|
|
445
514
|
|
|
446
515
|
# Video 1 vs N pattern
|
|
447
516
|
parser.add_argument(
|
|
@@ -495,10 +564,11 @@ def get_args():
|
|
|
495
564
|
help="Path to addtional perf log",
|
|
496
565
|
)
|
|
497
566
|
parser.add_argument(
|
|
498
|
-
"--perf-
|
|
499
|
-
"-
|
|
567
|
+
"--perf-tags",
|
|
568
|
+
"-ptags",
|
|
569
|
+
nargs="+",
|
|
500
570
|
type=str,
|
|
501
|
-
default=
|
|
571
|
+
default=[],
|
|
502
572
|
help="Tag to parse perf time from perf log",
|
|
503
573
|
)
|
|
504
574
|
parser.add_argument(
|
|
@@ -508,6 +578,26 @@ def get_args():
|
|
|
508
578
|
default=[],
|
|
509
579
|
help="Extra tags to parse perf time from perf log",
|
|
510
580
|
)
|
|
581
|
+
parser.add_argument(
|
|
582
|
+
"--psnr-type",
|
|
583
|
+
type=str,
|
|
584
|
+
default="custom",
|
|
585
|
+
choices=["custom", "skimage"],
|
|
586
|
+
help="The compute type of PSNR, [custom, skimage]",
|
|
587
|
+
)
|
|
588
|
+
parser.add_argument(
|
|
589
|
+
"--cal-speedup",
|
|
590
|
+
action="store_true",
|
|
591
|
+
default=False,
|
|
592
|
+
help="Calculate performance speedup.",
|
|
593
|
+
)
|
|
594
|
+
parser.add_argument(
|
|
595
|
+
"--gen-markdown-table",
|
|
596
|
+
"-table",
|
|
597
|
+
action="store_true",
|
|
598
|
+
default=False,
|
|
599
|
+
help="Generate performance markdown table",
|
|
600
|
+
)
|
|
511
601
|
return parser.parse_args()
|
|
512
602
|
|
|
513
603
|
|
|
@@ -516,16 +606,16 @@ def entrypoint():
|
|
|
516
606
|
args = get_args()
|
|
517
607
|
logger.debug(args)
|
|
518
608
|
|
|
609
|
+
if args.metrics in ["clip_score", "image_reward"]:
|
|
610
|
+
assert args.prompt_true is not None or args.ref_prompt_true is not None
|
|
611
|
+
assert args.img_test is not None or args.img_source_dir is not None
|
|
612
|
+
|
|
519
613
|
if args.enable_verbose:
|
|
520
614
|
global DISABLE_VERBOSE
|
|
521
615
|
set_metrics_verbose(True)
|
|
522
616
|
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
523
617
|
|
|
524
|
-
|
|
525
|
-
FID = FrechetInceptionDistance(
|
|
526
|
-
disable_tqdm=DISABLE_VERBOSE,
|
|
527
|
-
batch_size=args.fid_batch_size,
|
|
528
|
-
)
|
|
618
|
+
set_psnr_type(args.psnr_type)
|
|
529
619
|
|
|
530
620
|
METRICS_META: dict[str, float] = {}
|
|
531
621
|
|
|
@@ -533,11 +623,11 @@ def entrypoint():
|
|
|
533
623
|
def _run_metric(
|
|
534
624
|
metric: str,
|
|
535
625
|
img_true: str = None,
|
|
626
|
+
prompt_true: str = None,
|
|
536
627
|
img_test: str = None,
|
|
537
628
|
video_true: str = None,
|
|
538
629
|
video_test: str = None,
|
|
539
630
|
) -> None:
|
|
540
|
-
nonlocal FID
|
|
541
631
|
nonlocal METRICS_META
|
|
542
632
|
metric = metric.lower()
|
|
543
633
|
if img_true is not None and img_test is not None:
|
|
@@ -575,9 +665,39 @@ def entrypoint():
|
|
|
575
665
|
img_mse, n = compute_mse(img_true, img_test)
|
|
576
666
|
_logging_msg(img_mse, "mse", n)
|
|
577
667
|
if metric == "fid" or metric == "all":
|
|
578
|
-
img_fid, n =
|
|
668
|
+
img_fid, n = compute_fid(img_true, img_test)
|
|
579
669
|
_logging_msg(img_fid, "fid", n)
|
|
580
670
|
|
|
671
|
+
if prompt_true is not None and img_test is not None:
|
|
672
|
+
if any(
|
|
673
|
+
(
|
|
674
|
+
not os.path.exists(prompt_true), # file
|
|
675
|
+
not os.path.exists(img_test), # dir
|
|
676
|
+
)
|
|
677
|
+
):
|
|
678
|
+
return
|
|
679
|
+
|
|
680
|
+
# img_true and img_test can be files or dirs
|
|
681
|
+
prompt_true_info = os.path.basename(prompt_true)
|
|
682
|
+
img_test_info = os.path.basename(img_test)
|
|
683
|
+
|
|
684
|
+
def _logging_msg(value: float, name, n: int):
|
|
685
|
+
if value is None or n is None:
|
|
686
|
+
return
|
|
687
|
+
msg = (
|
|
688
|
+
f"{prompt_true_info} vs {img_test_info}, "
|
|
689
|
+
f"Num: {n}, {name.upper()}: {value:.5f}"
|
|
690
|
+
)
|
|
691
|
+
METRICS_META[msg] = value
|
|
692
|
+
logger.info(msg)
|
|
693
|
+
|
|
694
|
+
if metric == "clip_score":
|
|
695
|
+
clip_score, n = compute_clip_score(img_test, prompt_true)
|
|
696
|
+
_logging_msg(clip_score, "clip_score", n)
|
|
697
|
+
if metric == "image_reward":
|
|
698
|
+
image_reward, n = compute_reward_score(img_test, prompt_true)
|
|
699
|
+
_logging_msg(image_reward, "image_reward", n)
|
|
700
|
+
|
|
581
701
|
if video_true is not None and video_test is not None:
|
|
582
702
|
if any(
|
|
583
703
|
(
|
|
@@ -614,7 +734,7 @@ def entrypoint():
|
|
|
614
734
|
video_mse, n = compute_video_mse(video_true, video_test)
|
|
615
735
|
_logging_msg(video_mse, "mse", n)
|
|
616
736
|
if metric == "fid" or metric == "all":
|
|
617
|
-
video_fid, n =
|
|
737
|
+
video_fid, n = compute_video_fid(video_true, video_test)
|
|
618
738
|
_logging_msg(video_fid, "fid", n)
|
|
619
739
|
|
|
620
740
|
# run selected metrics
|
|
@@ -627,7 +747,18 @@ def entrypoint():
|
|
|
627
747
|
def _is_video_1vsN_pattern() -> bool:
|
|
628
748
|
return args.video_source_dir is not None and args.ref_video is not None
|
|
629
749
|
|
|
630
|
-
|
|
750
|
+
def _is_prompt_1vsN_pattern() -> bool:
    """Tell whether both ``args.img_source_dir`` and ``args.ref_prompt_true`` were given."""
    # Guard first, mirroring the short-circuit order of the original ``and``.
    if args.img_source_dir is None:
        return False
    return args.ref_prompt_true is not None
|
|
754
|
+
|
|
755
|
+
assert not all(
|
|
756
|
+
(
|
|
757
|
+
_is_image_1vsN_pattern(),
|
|
758
|
+
_is_video_1vsN_pattern(),
|
|
759
|
+
_is_prompt_1vsN_pattern(),
|
|
760
|
+
)
|
|
761
|
+
)
|
|
631
762
|
|
|
632
763
|
if _is_image_1vsN_pattern():
|
|
633
764
|
# Glob Image dirs
|
|
@@ -711,11 +842,42 @@ def entrypoint():
|
|
|
711
842
|
video_test=video_test,
|
|
712
843
|
)
|
|
713
844
|
|
|
845
|
+
elif _is_prompt_1vsN_pattern():
|
|
846
|
+
# Glob Image dirs
|
|
847
|
+
if not os.path.exists(args.img_source_dir):
|
|
848
|
+
logger.error(f"{args.img_source_dir} not exist!")
|
|
849
|
+
return
|
|
850
|
+
|
|
851
|
+
directories = []
|
|
852
|
+
for item in os.listdir(args.img_source_dir):
|
|
853
|
+
item_path = os.path.join(args.img_source_dir, item)
|
|
854
|
+
if os.path.isdir(item_path):
|
|
855
|
+
directories.append(item_path)
|
|
856
|
+
|
|
857
|
+
if len(directories) == 0:
|
|
858
|
+
return
|
|
859
|
+
|
|
860
|
+
directories = sorted(directories)
|
|
861
|
+
if not DISABLE_VERBOSE:
|
|
862
|
+
logger.info(
|
|
863
|
+
f"Compare {args.ref_prompt_true} vs {directories}, "
|
|
864
|
+
f"Num compares: {len(directories)}"
|
|
865
|
+
)
|
|
866
|
+
|
|
867
|
+
for metric in args.metrics:
|
|
868
|
+
for img_test_dir in directories:
|
|
869
|
+
_run_metric(
|
|
870
|
+
metric=metric,
|
|
871
|
+
prompt_true=args.ref_prompt_true,
|
|
872
|
+
img_test=img_test_dir,
|
|
873
|
+
)
|
|
874
|
+
|
|
714
875
|
else:
|
|
715
876
|
for metric in args.metrics:
|
|
716
877
|
_run_metric(
|
|
717
878
|
metric=metric,
|
|
718
879
|
img_true=args.img_true,
|
|
880
|
+
prompt_true=args.prompt_true,
|
|
719
881
|
img_test=args.img_test,
|
|
720
882
|
video_true=args.video_true,
|
|
721
883
|
video_test=args.video_test,
|
|
@@ -724,7 +886,7 @@ def entrypoint():
|
|
|
724
886
|
if args.summary:
|
|
725
887
|
|
|
726
888
|
def _fetch_perf():
|
|
727
|
-
if args.perf_log is None or args.
|
|
889
|
+
if args.perf_log is None or len(args.perf_tags) == 0:
|
|
728
890
|
return []
|
|
729
891
|
if not os.path.exists(args.perf_log):
|
|
730
892
|
return []
|
|
@@ -733,17 +895,20 @@ def entrypoint():
|
|
|
733
895
|
perf_lines = file.readlines()
|
|
734
896
|
for line in perf_lines:
|
|
735
897
|
line = line.strip()
|
|
736
|
-
|
|
737
|
-
if
|
|
738
|
-
|
|
739
|
-
else:
|
|
740
|
-
has_all_extra_tag = True
|
|
741
|
-
for ext_tag in args.extra_perf_tags:
|
|
742
|
-
if ext_tag.lower() not in line.lower():
|
|
743
|
-
has_all_extra_tag = False
|
|
744
|
-
break
|
|
745
|
-
if has_all_extra_tag:
|
|
898
|
+
for perf_tag in args.perf_tags:
|
|
899
|
+
if perf_tag.lower() in line.lower():
|
|
900
|
+
if len(args.extra_perf_tags) == 0:
|
|
746
901
|
perf_texts.append(line)
|
|
902
|
+
break
|
|
903
|
+
else:
|
|
904
|
+
has_all_extra_tag = True
|
|
905
|
+
for ext_tag in args.extra_perf_tags:
|
|
906
|
+
if ext_tag.lower() not in line.lower():
|
|
907
|
+
has_all_extra_tag = False
|
|
908
|
+
break
|
|
909
|
+
if has_all_extra_tag:
|
|
910
|
+
perf_texts.append(line)
|
|
911
|
+
break
|
|
747
912
|
return perf_texts
|
|
748
913
|
|
|
749
914
|
PERF_TEXTS: list[str] = _fetch_perf()
|
|
@@ -770,8 +935,9 @@ def entrypoint():
|
|
|
770
935
|
try:
|
|
771
936
|
if tag.lower() in METRICS_CHOICES:
|
|
772
937
|
return float(value_str)
|
|
773
|
-
if args.
|
|
774
|
-
|
|
938
|
+
if len(args.perf_tags) > 0:
|
|
939
|
+
perf_tags = [tag.lower() for tag in args.perf_tags]
|
|
940
|
+
if tag.lower() in perf_tags:
|
|
775
941
|
return float(value_str)
|
|
776
942
|
return int(value_str)
|
|
777
943
|
except ValueError:
|
|
@@ -779,17 +945,37 @@ def entrypoint():
|
|
|
779
945
|
|
|
780
946
|
def _parse_perf(
|
|
781
947
|
compare_tag: str,
|
|
948
|
+
perf_tag: str,
|
|
782
949
|
) -> float | None:
|
|
783
950
|
nonlocal PERF_TEXTS
|
|
784
|
-
|
|
951
|
+
perf_values = []
|
|
785
952
|
for line in PERF_TEXTS:
|
|
786
953
|
if compare_tag in line:
|
|
787
|
-
|
|
788
|
-
if
|
|
789
|
-
|
|
790
|
-
if len(
|
|
954
|
+
perf_value = _parse_value(line, perf_tag)
|
|
955
|
+
if perf_value is not None:
|
|
956
|
+
perf_values.append(perf_value)
|
|
957
|
+
if len(perf_values) == 0:
|
|
791
958
|
return None
|
|
792
|
-
return sum(
|
|
959
|
+
return sum(perf_values) / len(perf_values)
|
|
960
|
+
|
|
961
|
+
def _ref_perf(
    key: str,
):
    """Collect one perf value per ``args.perf_tags`` entry for the reference side of ``key``.

    The reference tag is the text before ``vs`` in the first comma-separated
    field of ``key`` (e.g. ``U1-Q0-C0-NONE`` in
    ``U1-Q0-C0-NONE vs U4-Q1-C1-NONE``).

    Returns:
        A list with the ``_parse_perf`` result for each perf tag (entries may
        be ``None``), or an empty list when ``args.prompt_true`` is set.
    """
    # U1-Q0-C0-NONE vs U4-Q1-C1-NONE
    header = key.split(",")[0].strip()

    # No reference tag is derived when a prompt file was supplied.
    if args.prompt_true is not None:
        return []

    reference_tag = header.split("vs")[0].strip()  # U1-Q0-C0-NONE
    return [
        _parse_perf(reference_tag, perf_tag) for perf_tag in args.perf_tags
    ]
|
|
793
979
|
|
|
794
980
|
def _format_item(
|
|
795
981
|
key: str,
|
|
@@ -802,40 +988,129 @@ def entrypoint():
|
|
|
802
988
|
header = key.split(",")[0].strip()
|
|
803
989
|
compare_tag = header.split("vs")[1].strip() # U4-Q1-C1-NONE
|
|
804
990
|
has_perf_texts = len(PERF_TEXTS) > 0
|
|
991
|
+
|
|
992
|
+
def _perf_msg(perf_tag: str):
    """Map a perf tag to the label used for it in the summary output.

    Matching is case-insensitive substring matching; a tag that matches no
    known keyword is returned upper-cased.
    """
    lowered = perf_tag.lower()
    # Order matters: "tflops" must be tested before its substring "flops".
    labels = (
        ("time", "Latency(s)"),
        ("tflops", "TFLOPs"),
        ("flops", "FLOPs"),
    )
    for keyword, label in labels:
        if keyword in lowered:
            return label
    return perf_tag.upper()
|
|
1002
|
+
|
|
805
1003
|
format_str = ""
|
|
806
1004
|
# Num / Frames
|
|
1005
|
+
perf_values = []
|
|
1006
|
+
perf_msgs = []
|
|
807
1007
|
if n := _parse_value(key, "Num"):
|
|
808
1008
|
if not has_perf_texts:
|
|
809
1009
|
format_str = (
|
|
810
|
-
f"{header:<{max_key_len}}
|
|
1010
|
+
f"{header:<{max_key_len}}, Num: {n}, "
|
|
811
1011
|
f"{metric.upper()}: {value:<7.4f}"
|
|
812
1012
|
)
|
|
813
1013
|
else:
|
|
814
|
-
perf_time = _parse_perf(compare_tag)
|
|
815
|
-
perf_time = f"{perf_time:<.2f}" if perf_time else None
|
|
816
1014
|
format_str = (
|
|
817
|
-
f"{header:<{max_key_len}}
|
|
818
|
-
f"{metric.upper()}: {value:<7.4f}
|
|
819
|
-
f"Perf: {perf_time}"
|
|
1015
|
+
f"{header:<{max_key_len}}, Num: {n}, "
|
|
1016
|
+
f"{metric.upper()}: {value:<7.4f}, "
|
|
820
1017
|
)
|
|
1018
|
+
for perf_tag in args.perf_tags:
|
|
1019
|
+
perf_value = _parse_perf(compare_tag, perf_tag)
|
|
1020
|
+
perf_values.append(perf_value)
|
|
1021
|
+
|
|
1022
|
+
perf_value = (
|
|
1023
|
+
f"{perf_value:<.2f}" if perf_value else None
|
|
1024
|
+
)
|
|
1025
|
+
perf_msg = _perf_msg(perf_tag)
|
|
1026
|
+
format_str += f"{perf_msg}: {perf_value}, "
|
|
1027
|
+
|
|
1028
|
+
perf_msgs.append(perf_msg)
|
|
1029
|
+
|
|
1030
|
+
if not args.cal_speedup:
|
|
1031
|
+
format_str = format_str.removesuffix(", ")
|
|
1032
|
+
|
|
821
1033
|
elif n := _parse_value(key, "Frames"):
|
|
822
1034
|
if not has_perf_texts:
|
|
823
1035
|
format_str = (
|
|
824
|
-
f"{header:<{max_key_len}}
|
|
1036
|
+
f"{header:<{max_key_len}}, Frames: {n}, "
|
|
825
1037
|
f"{metric.upper()}: {value:<7.4f}"
|
|
826
1038
|
)
|
|
827
1039
|
else:
|
|
828
|
-
perf_time = _parse_perf(compare_tag)
|
|
829
|
-
perf_time = f"{perf_time:<.2f}" if perf_time else None
|
|
830
1040
|
format_str = (
|
|
831
|
-
f"{header:<{max_key_len}}
|
|
832
|
-
f"{metric.upper()}: {value:<7.4f}
|
|
833
|
-
f"Perf: {perf_time}"
|
|
1041
|
+
f"{header:<{max_key_len}}, Frames: {n}, "
|
|
1042
|
+
f"{metric.upper()}: {value:<7.4f}, "
|
|
834
1043
|
)
|
|
1044
|
+
for perf_tag in args.perf_tags:
|
|
1045
|
+
perf_value = _parse_perf(compare_tag, perf_tag)
|
|
1046
|
+
perf_values.append(perf_value)
|
|
1047
|
+
|
|
1048
|
+
perf_value = (
|
|
1049
|
+
f"{perf_value:<.2f}" if perf_value else None
|
|
1050
|
+
)
|
|
1051
|
+
perf_msg = _perf_msg(perf_tag)
|
|
1052
|
+
format_str += f"{perf_msg}: {perf_value}, "
|
|
1053
|
+
perf_msgs.append(perf_msg)
|
|
1054
|
+
|
|
1055
|
+
if not args.cal_speedup:
|
|
1056
|
+
format_str = format_str.removesuffix(", ")
|
|
835
1057
|
else:
|
|
836
1058
|
raise ValueError("Num or Frames can not be NoneType.")
|
|
837
1059
|
|
|
838
|
-
return format_str
|
|
1060
|
+
return format_str, perf_values, perf_msgs
|
|
1061
|
+
|
|
1062
|
+
def _format_table(format_strs: List[str], metric: str):
    """Render formatted summary lines as a Markdown table.

    Each line is split on commas. The first field is the config header
    (``"<ref> vs <config>"``); the remaining fields are ``"Name: value"``
    pairs such as ``"Num: 10"``, ``"PSNR: 30.1234"`` or ``"Latency(s): 1.23"``.

    Args:
        format_strs: Summary lines to tabulate.
        metric: Metric name; its upper-cased form is the second column.

    Returns:
        The Markdown table as a string (no trailing newline), or ``""`` when
        ``format_strs`` is empty. Columns are ``Config``, the metric, then
        every other ``Name`` found, sorted alphabetically. ``Num``/``Frames``
        fields are not tabulated. Missing cells are left empty.
    """
    if not format_strs:
        return ""

    metric_upper = metric.upper()
    # Match on "METRIC:" rather than the bare name so an unrelated field that
    # merely starts with the same letters is not mistaken for the metric.
    metric_prefix = f"{metric_upper}:"
    skipped_prefixes = ("Num:", "Frames:", metric_prefix)
    extra_headers = set()
    row_data = []

    for line in format_strs:
        # Only the fields after the first one are "Name: value" pairs; the
        # config header is excluded so a ':' inside it cannot create a column.
        config_part, *fields = [p.strip() for p in line.split(",")]

        if "vs" in config_part:
            # Keep the compared side and drop everything up to "_DBCACHE_".
            config = config_part.split("vs", 1)[1].strip()
            if "_DBCACHE_" in config:
                config = config.split("_DBCACHE_", 1)[1].strip()
        else:
            config = config_part

        # Default to an empty cell instead of letting next() raise
        # StopIteration when a line carries no metric field.
        metric_value = next(
            (
                p.split(":", 1)[1].strip()
                for p in fields
                if p.startswith(metric_prefix)
            ),
            "",
        )

        perf_data = {}
        for part in fields:
            if part.startswith(skipped_prefixes) or ":" not in part:
                continue
            name, _, cell = part.partition(":")
            perf_data[name.strip()] = cell.strip()
        extra_headers.update(perf_data)

        row_data.append(
            {"Config": config, metric_upper: metric_value, **perf_data}
        )

    sorted_headers = ["Config", metric_upper] + sorted(
        h for h in extra_headers if h not in ("Config", metric_upper)
    )

    # Assemble all rows and join once rather than growing a string with +=.
    table_lines = [
        "| " + " | ".join(sorted_headers) + " |",
        "| " + " | ".join(["---"] * len(sorted_headers)) + " |",
    ]
    table_lines.extend(
        "| " + " | ".join(row.get(h, "") for h in sorted_headers) + " |"
        for row in row_data
    )
    return "\n".join(table_lines)
|
|
839
1114
|
|
|
840
1115
|
selected_metrics = args.metrics
|
|
841
1116
|
if "all" in selected_metrics:
|
|
@@ -848,7 +1123,17 @@ def entrypoint():
|
|
|
848
1123
|
if metric.upper() in key or metric.lower() in key:
|
|
849
1124
|
selected_items[key] = METRICS_META[key]
|
|
850
1125
|
|
|
851
|
-
reverse =
|
|
1126
|
+
reverse = (
|
|
1127
|
+
True
|
|
1128
|
+
if metric.lower()
|
|
1129
|
+
in [
|
|
1130
|
+
"psnr",
|
|
1131
|
+
"ssim",
|
|
1132
|
+
"clip_score",
|
|
1133
|
+
"image_reward",
|
|
1134
|
+
]
|
|
1135
|
+
else False
|
|
1136
|
+
)
|
|
852
1137
|
sorted_items = sorted(
|
|
853
1138
|
selected_items.items(), key=lambda x: x[1], reverse=reverse
|
|
854
1139
|
)
|
|
@@ -857,12 +1142,65 @@ def entrypoint():
|
|
|
857
1142
|
]
|
|
858
1143
|
max_key_len = max(len(key) for key in selected_keys)
|
|
859
1144
|
|
|
1145
|
+
ref_perf_values = _ref_perf(key=selected_keys[0])
|
|
1146
|
+
max_perf_values: List[float] = []
|
|
1147
|
+
|
|
1148
|
+
if ref_perf_values and None not in ref_perf_values:
|
|
1149
|
+
max_perf_values = ref_perf_values.copy()
|
|
1150
|
+
|
|
1151
|
+
for key, value in sorted_items:
|
|
1152
|
+
format_str, perf_values, perf_msgs = _format_item(
|
|
1153
|
+
key, metric, value, max_key_len
|
|
1154
|
+
)
|
|
1155
|
+
# skip 'None' msg but not 'NONE', 'NONE' means w/o cache
|
|
1156
|
+
if "None" in format_str:
|
|
1157
|
+
continue
|
|
1158
|
+
|
|
1159
|
+
if (
|
|
1160
|
+
not perf_values
|
|
1161
|
+
or None in perf_values
|
|
1162
|
+
or not perf_msgs
|
|
1163
|
+
or not args.cal_speedup
|
|
1164
|
+
):
|
|
1165
|
+
continue
|
|
1166
|
+
|
|
1167
|
+
if not max_perf_values:
|
|
1168
|
+
max_perf_values = perf_values
|
|
1169
|
+
else:
|
|
1170
|
+
for i in range(len(max_perf_values)):
|
|
1171
|
+
max_perf_values[i] = max(
|
|
1172
|
+
max_perf_values[i], perf_values[i]
|
|
1173
|
+
)
|
|
1174
|
+
|
|
860
1175
|
format_strs = []
|
|
861
1176
|
for key, value in sorted_items:
|
|
862
|
-
|
|
863
|
-
|
|
1177
|
+
format_str, perf_values, perf_msgs = _format_item(
|
|
1178
|
+
key, metric, value, max_key_len
|
|
864
1179
|
)
|
|
865
1180
|
|
|
1181
|
+
# skip 'None' msg but not 'NONE', 'NONE' means w/o cache
|
|
1182
|
+
if "None" in format_str:
|
|
1183
|
+
continue
|
|
1184
|
+
|
|
1185
|
+
if (
|
|
1186
|
+
not perf_values
|
|
1187
|
+
or None in perf_values
|
|
1188
|
+
or not perf_msgs
|
|
1189
|
+
or not max_perf_values
|
|
1190
|
+
or not args.cal_speedup
|
|
1191
|
+
):
|
|
1192
|
+
format_strs.append(format_str)
|
|
1193
|
+
continue
|
|
1194
|
+
|
|
1195
|
+
for perf_value, perf_msg, max_perf_value in zip(
|
|
1196
|
+
perf_values, perf_msgs, max_perf_values
|
|
1197
|
+
):
|
|
1198
|
+
perf_speedup = max_perf_value / perf_value
|
|
1199
|
+
format_str += f"{perf_msg}(↑): {perf_speedup:<.2f}, "
|
|
1200
|
+
|
|
1201
|
+
format_str = format_str.removesuffix(", ")
|
|
1202
|
+
format_strs.append(format_str)
|
|
1203
|
+
|
|
866
1204
|
format_len = max(len(format_str) for format_str in format_strs)
|
|
867
1205
|
|
|
868
1206
|
res_len = format_len - len(f"Summary: {metric.upper()}")
|
|
@@ -877,6 +1215,12 @@ def entrypoint():
|
|
|
877
1215
|
print(format_str)
|
|
878
1216
|
print("-" * format_len)
|
|
879
1217
|
|
|
1218
|
+
if args.gen_markdown_table:
|
|
1219
|
+
table = _format_table(format_strs, metric)
|
|
1220
|
+
print("-" * format_len)
|
|
1221
|
+
print(f"{table}")
|
|
1222
|
+
print("-" * format_len)
|
|
1223
|
+
|
|
880
1224
|
|
|
881
1225
|
if __name__ == "__main__":
|
|
882
1226
|
entrypoint()
|