cache-dit 0.2.34__py3-none-any.whl → 0.2.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +5 -3
- cache_dit/_version.py +2 -2
- cache_dit/metrics/clip_score.py +135 -0
- cache_dit/metrics/fid.py +42 -0
- cache_dit/metrics/image_reward.py +177 -0
- cache_dit/metrics/lpips.py +2 -14
- cache_dit/metrics/metrics.py +449 -93
- cache_dit/utils.py +15 -0
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/METADATA +142 -35
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/RECORD +14 -12
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/top_level.txt +0 -0
cache_dit/metrics/metrics.py
CHANGED
|
@@ -1,25 +1,31 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import re
|
|
2
3
|
import cv2
|
|
3
4
|
import pathlib
|
|
4
5
|
import argparse
|
|
5
6
|
import numpy as np
|
|
6
7
|
from tqdm import tqdm
|
|
7
8
|
from functools import partial
|
|
9
|
+
from typing import Callable, Union, Tuple, List
|
|
8
10
|
from skimage.metrics import mean_squared_error
|
|
9
11
|
from skimage.metrics import peak_signal_noise_ratio
|
|
10
12
|
from skimage.metrics import structural_similarity
|
|
11
|
-
from cache_dit.metrics.fid import FrechetInceptionDistance
|
|
12
13
|
from cache_dit.metrics.config import set_metrics_verbose
|
|
13
14
|
from cache_dit.metrics.config import get_metrics_verbose
|
|
14
15
|
from cache_dit.metrics.config import _IMAGE_EXTENSIONS
|
|
15
16
|
from cache_dit.metrics.config import _VIDEO_EXTENSIONS
|
|
16
17
|
from cache_dit.logger import init_logger
|
|
18
|
+
from cache_dit.metrics.fid import compute_fid
|
|
19
|
+
from cache_dit.metrics.fid import compute_video_fid
|
|
17
20
|
from cache_dit.metrics.lpips import compute_lpips_img
|
|
21
|
+
from cache_dit.metrics.clip_score import compute_clip_score
|
|
22
|
+
from cache_dit.metrics.image_reward import compute_reward_score
|
|
18
23
|
|
|
19
24
|
logger = init_logger(__name__)
|
|
20
25
|
|
|
21
26
|
|
|
22
27
|
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
28
|
+
PSNR_TYPE = "custom"
|
|
23
29
|
|
|
24
30
|
|
|
25
31
|
def compute_lpips_file(
|
|
@@ -51,6 +57,35 @@ def compute_lpips_file(
|
|
|
51
57
|
)
|
|
52
58
|
|
|
53
59
|
|
|
60
|
+
def set_psnr_type(psnr_type: str):
    """Select which PSNR implementation the module-level ``PSNR_TYPE`` names.

    Args:
        psnr_type: Either ``"skimage"`` or ``"custom"``.

    Raises:
        ValueError: If ``psnr_type`` is not one of the supported names. The
            previously stored value is left untouched in that case.
    """
    global PSNR_TYPE
    # Validate BEFORE assigning so a rejected value never leaks into module
    # state, and raise explicitly because `assert` is removed under `-O`.
    if psnr_type not in ("skimage", "custom"):
        raise ValueError(
            f"psnr_type must be 'skimage' or 'custom', got {psnr_type!r}"
        )
    PSNR_TYPE = psnr_type
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def get_psnr_type():
    """Return the value currently stored in the module-level ``PSNR_TYPE``."""
    # Reading a module global does not require a `global` declaration.
    return PSNR_TYPE
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def calculate_psnr(
    image_true: np.ndarray,
    image_test: np.ndarray,
):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        image_true (ndarray): Images with range [0, 255].
        image_test (ndarray): Images with range [0, 255].

    Returns:
        ``float("inf")`` when the two images are identical, otherwise
        ``20 * log10(255 / sqrt(mse))``.
    """
    # Promote to float64 before any arithmetic: with integer inputs such as
    # uint8 the subtraction and the squaring are carried out in that dtype and
    # wrap around modulo 256, which corrupts the MSE (a per-pixel difference
    # of 16 would even square to 0 and be reported as "identical").
    true_f = np.asarray(image_true, dtype=np.float64)
    test_f = np.asarray(image_test, dtype=np.float64)
    mse = np.mean((true_f - test_f) ** 2)
    if mse == 0:
        return float("inf")
    return 20 * np.log10(255.0 / np.sqrt(mse))
|
|
87
|
+
|
|
88
|
+
|
|
54
89
|
def compute_psnr_file(
|
|
55
90
|
image_true: np.ndarray | str,
|
|
56
91
|
image_test: np.ndarray | str,
|
|
@@ -64,10 +99,13 @@ def compute_psnr_file(
|
|
|
64
99
|
image_true = cv2.imread(image_true)
|
|
65
100
|
if isinstance(image_test, str):
|
|
66
101
|
image_test = cv2.imread(image_test)
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
102
|
+
if get_psnr_type() == "skimage":
|
|
103
|
+
return peak_signal_noise_ratio(
|
|
104
|
+
image_true,
|
|
105
|
+
image_test,
|
|
106
|
+
)
|
|
107
|
+
else:
|
|
108
|
+
return calculate_psnr(image_true, image_test)
|
|
71
109
|
|
|
72
110
|
|
|
73
111
|
def compute_mse_file(
|
|
@@ -114,7 +152,7 @@ def compute_dir_metric(
|
|
|
114
152
|
image_true_dir: np.ndarray | str,
|
|
115
153
|
image_test_dir: np.ndarray | str,
|
|
116
154
|
compute_file_func: callable = compute_psnr_file,
|
|
117
|
-
) -> float:
|
|
155
|
+
) -> Union[Tuple[float, int], Tuple[None, None]]:
|
|
118
156
|
# Image
|
|
119
157
|
if isinstance(image_true_dir, np.ndarray) or isinstance(
|
|
120
158
|
image_test_dir, np.ndarray
|
|
@@ -123,25 +161,30 @@ def compute_dir_metric(
|
|
|
123
161
|
# File
|
|
124
162
|
if not os.path.isdir(image_true_dir) or not os.path.isdir(image_test_dir):
|
|
125
163
|
return compute_file_func(image_true_dir, image_test_dir), 1
|
|
164
|
+
|
|
126
165
|
# Dir
|
|
166
|
+
# compute dir metric
|
|
167
|
+
def natural_sort_key(filename):
    """Sort key that orders names by the last run of digits they contain.

    The key is a tagged tuple so that any two keys can be compared:
    ``(0, number)`` for names with a trailing digit run and ``(1, filename)``
    for names without one. Returning a bare ``int`` for some names and a bare
    ``str`` for others makes ``sorted`` raise ``TypeError`` as soon as both
    kinds occur in the same listing.

    Args:
        filename: Name (or path string) to derive the key from.

    Returns:
        A 2-tuple; numbered names sort first, by numeric value, and
        un-numbered names sort after them, alphabetically.
    """
    # Last group of digits that is followed only by non-digits up to the end.
    match = re.search(r"(\d+)\D*$", filename)
    if match:
        return (0, int(match.group(1)))
    return (1, filename)
|
|
170
|
+
|
|
127
171
|
image_true_dir: pathlib.Path = pathlib.Path(image_true_dir)
|
|
128
|
-
image_true_files =
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
]
|
|
134
|
-
)
|
|
135
|
-
image_test_dir: pathlib.Path = pathlib.Path(image_test_dir)
|
|
136
|
-
image_test_files = sorted(
|
|
137
|
-
[
|
|
138
|
-
file
|
|
139
|
-
for ext in _IMAGE_EXTENSIONS
|
|
140
|
-
for file in image_test_dir.rglob("*.{}".format(ext))
|
|
141
|
-
]
|
|
142
|
-
)
|
|
172
|
+
image_true_files = [
|
|
173
|
+
file
|
|
174
|
+
for ext in _IMAGE_EXTENSIONS
|
|
175
|
+
for file in image_true_dir.rglob("*.{}".format(ext))
|
|
176
|
+
]
|
|
143
177
|
image_true_files = [file.as_posix() for file in image_true_files]
|
|
178
|
+
image_true_files = sorted(image_true_files, key=natural_sort_key)
|
|
179
|
+
|
|
180
|
+
image_test_dir: pathlib.Path = pathlib.Path(image_test_dir)
|
|
181
|
+
image_test_files = [
|
|
182
|
+
file
|
|
183
|
+
for ext in _IMAGE_EXTENSIONS
|
|
184
|
+
for file in image_test_dir.rglob("*.{}".format(ext))
|
|
185
|
+
]
|
|
144
186
|
image_test_files = [file.as_posix() for file in image_test_files]
|
|
187
|
+
image_test_files = sorted(image_test_files, key=natural_sort_key)
|
|
145
188
|
|
|
146
189
|
# select valid files
|
|
147
190
|
image_true_files_selected = []
|
|
@@ -155,6 +198,7 @@ def compute_dir_metric(
|
|
|
155
198
|
):
|
|
156
199
|
image_true_files_selected.append(selected_image_true)
|
|
157
200
|
image_test_files_selected.append(selected_image_test)
|
|
201
|
+
|
|
158
202
|
image_true_files = image_true_files_selected.copy()
|
|
159
203
|
image_test_files = image_test_files_selected.copy()
|
|
160
204
|
if len(image_true_files) == 0:
|
|
@@ -169,20 +213,22 @@ def compute_dir_metric(
|
|
|
169
213
|
|
|
170
214
|
total_metric = 0.0
|
|
171
215
|
valid_files = 0
|
|
216
|
+
total_files = 0
|
|
172
217
|
for image_true, image_test in tqdm(
|
|
173
218
|
zip(image_true_files, image_test_files),
|
|
174
219
|
total=len(image_true_files),
|
|
175
220
|
disable=DISABLE_VERBOSE,
|
|
176
221
|
):
|
|
177
222
|
metric = compute_file_func(image_true, image_test)
|
|
178
|
-
if metric != float("inf"):
|
|
223
|
+
if metric != float("inf"): # means no cache apply to image_test
|
|
179
224
|
total_metric += metric
|
|
180
225
|
valid_files += 1
|
|
226
|
+
total_files += 1
|
|
181
227
|
|
|
182
228
|
if valid_files > 0:
|
|
183
229
|
average_metric = total_metric / valid_files
|
|
184
230
|
logger.debug(f"Average: {average_metric:.2f}")
|
|
185
|
-
return average_metric,
|
|
231
|
+
return average_metric, total_files
|
|
186
232
|
else:
|
|
187
233
|
logger.debug("No valid files to compare")
|
|
188
234
|
return None, None
|
|
@@ -235,7 +281,7 @@ def compute_video_metric(
|
|
|
235
281
|
video_true: str,
|
|
236
282
|
video_test: str,
|
|
237
283
|
compute_frame_func: callable = compute_psnr_file,
|
|
238
|
-
) -> float:
|
|
284
|
+
) -> Union[Tuple[float, int], Tuple[None, None]]:
|
|
239
285
|
"""
|
|
240
286
|
video_true = "video_true.mp4"
|
|
241
287
|
video_test = "video_test.mp4"
|
|
@@ -335,51 +381,69 @@ def compute_video_metric(
|
|
|
335
381
|
return None, None
|
|
336
382
|
|
|
337
383
|
|
|
338
|
-
compute_lpips =
|
|
339
|
-
|
|
340
|
-
|
|
384
|
+
compute_lpips: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
|
|
385
|
+
partial(
|
|
386
|
+
compute_dir_metric,
|
|
387
|
+
compute_file_func=compute_lpips_file,
|
|
388
|
+
)
|
|
341
389
|
)
|
|
342
390
|
|
|
343
|
-
compute_psnr =
|
|
344
|
-
|
|
345
|
-
|
|
391
|
+
compute_psnr: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
|
|
392
|
+
partial(
|
|
393
|
+
compute_dir_metric,
|
|
394
|
+
compute_file_func=compute_psnr_file,
|
|
395
|
+
)
|
|
346
396
|
)
|
|
347
397
|
|
|
348
|
-
compute_ssim =
|
|
349
|
-
|
|
350
|
-
|
|
398
|
+
compute_ssim: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
|
|
399
|
+
partial(
|
|
400
|
+
compute_dir_metric,
|
|
401
|
+
compute_file_func=compute_ssim_file,
|
|
402
|
+
)
|
|
351
403
|
)
|
|
352
404
|
|
|
353
|
-
compute_mse =
|
|
354
|
-
|
|
355
|
-
|
|
405
|
+
compute_mse: Callable[..., Union[Tuple[float, int], Tuple[None, None]]] = (
|
|
406
|
+
partial(
|
|
407
|
+
compute_dir_metric,
|
|
408
|
+
compute_file_func=compute_mse_file,
|
|
409
|
+
)
|
|
356
410
|
)
|
|
357
411
|
|
|
358
|
-
compute_video_lpips
|
|
412
|
+
compute_video_lpips: Callable[
|
|
413
|
+
..., Union[Tuple[float, int], Tuple[None, None]]
|
|
414
|
+
] = partial(
|
|
359
415
|
compute_video_metric,
|
|
360
416
|
compute_frame_func=compute_lpips_file,
|
|
361
417
|
)
|
|
362
|
-
compute_video_psnr
|
|
418
|
+
compute_video_psnr: Callable[
|
|
419
|
+
..., Union[Tuple[float, int], Tuple[None, None]]
|
|
420
|
+
] = partial(
|
|
363
421
|
compute_video_metric,
|
|
364
422
|
compute_frame_func=compute_psnr_file,
|
|
365
423
|
)
|
|
366
|
-
compute_video_ssim
|
|
424
|
+
compute_video_ssim: Callable[
|
|
425
|
+
..., Union[Tuple[float, int], Tuple[None, None]]
|
|
426
|
+
] = partial(
|
|
367
427
|
compute_video_metric,
|
|
368
428
|
compute_frame_func=compute_ssim_file,
|
|
369
429
|
)
|
|
370
|
-
compute_video_mse
|
|
430
|
+
compute_video_mse: Callable[
|
|
431
|
+
..., Union[Tuple[float, int], Tuple[None, None]]
|
|
432
|
+
] = partial(
|
|
371
433
|
compute_video_metric,
|
|
372
434
|
compute_frame_func=compute_mse_file,
|
|
373
435
|
)
|
|
374
436
|
|
|
375
437
|
|
|
376
438
|
METRICS_CHOICES = [
|
|
377
|
-
"lpips",
|
|
378
|
-
"psnr",
|
|
379
|
-
"ssim",
|
|
380
|
-
"mse",
|
|
381
|
-
"fid",
|
|
382
|
-
"all",
|
|
439
|
+
"lpips", # img vs img
|
|
440
|
+
"psnr", # img vs img
|
|
441
|
+
"ssim", # img vs img
|
|
442
|
+
"mse", # img vs img
|
|
443
|
+
"fid", # img vs img
|
|
444
|
+
"all", # img vs img
|
|
445
|
+
"clip_score", # img vs prompt
|
|
446
|
+
"image_reward", # img vs prompt
|
|
383
447
|
]
|
|
384
448
|
|
|
385
449
|
|
|
@@ -405,6 +469,13 @@ def get_args():
|
|
|
405
469
|
default=None,
|
|
406
470
|
help="Path to ground truth image or Dir to ground truth images",
|
|
407
471
|
)
|
|
472
|
+
parser.add_argument(
|
|
473
|
+
"--prompt-true",
|
|
474
|
+
"-p",
|
|
475
|
+
type=str,
|
|
476
|
+
default=None,
|
|
477
|
+
help="Path to ground truth prompt file for CLIP Score and Image Reward Score.",
|
|
478
|
+
)
|
|
408
479
|
parser.add_argument(
|
|
409
480
|
"--img-test",
|
|
410
481
|
"-i2",
|
|
@@ -442,6 +513,13 @@ def get_args():
|
|
|
442
513
|
default=None,
|
|
443
514
|
help="Path to ref dir that contains ground truth images",
|
|
444
515
|
)
|
|
516
|
+
parser.add_argument(
|
|
517
|
+
"--ref-prompt-true",
|
|
518
|
+
"-rp",
|
|
519
|
+
type=str,
|
|
520
|
+
default=None,
|
|
521
|
+
help="Path to ground truth prompt file for CLIP Score and Image Reward Score.",
|
|
522
|
+
)
|
|
445
523
|
|
|
446
524
|
# Video 1 vs N pattern
|
|
447
525
|
parser.add_argument(
|
|
@@ -495,10 +573,11 @@ def get_args():
|
|
|
495
573
|
help="Path to addtional perf log",
|
|
496
574
|
)
|
|
497
575
|
parser.add_argument(
|
|
498
|
-
"--perf-
|
|
499
|
-
"-
|
|
576
|
+
"--perf-tags",
|
|
577
|
+
"-ptags",
|
|
578
|
+
nargs="+",
|
|
500
579
|
type=str,
|
|
501
|
-
default=
|
|
580
|
+
default=[],
|
|
502
581
|
help="Tag to parse perf time from perf log",
|
|
503
582
|
)
|
|
504
583
|
parser.add_argument(
|
|
@@ -508,6 +587,26 @@ def get_args():
|
|
|
508
587
|
default=[],
|
|
509
588
|
help="Extra tags to parse perf time from perf log",
|
|
510
589
|
)
|
|
590
|
+
parser.add_argument(
|
|
591
|
+
"--psnr-type",
|
|
592
|
+
type=str,
|
|
593
|
+
default="custom",
|
|
594
|
+
choices=["custom", "skimage"],
|
|
595
|
+
help="The compute type of PSNR, [custom, skimage]",
|
|
596
|
+
)
|
|
597
|
+
parser.add_argument(
|
|
598
|
+
"--cal-speedup",
|
|
599
|
+
action="store_true",
|
|
600
|
+
default=False,
|
|
601
|
+
help="Calculate performance speedup.",
|
|
602
|
+
)
|
|
603
|
+
parser.add_argument(
|
|
604
|
+
"--gen-markdown-table",
|
|
605
|
+
"-table",
|
|
606
|
+
action="store_true",
|
|
607
|
+
default=False,
|
|
608
|
+
help="Generate performance markdown table",
|
|
609
|
+
)
|
|
511
610
|
return parser.parse_args()
|
|
512
611
|
|
|
513
612
|
|
|
@@ -516,16 +615,16 @@ def entrypoint():
|
|
|
516
615
|
args = get_args()
|
|
517
616
|
logger.debug(args)
|
|
518
617
|
|
|
618
|
+
if args.metrics in ["clip_score", "image_reward"]:
|
|
619
|
+
assert args.prompt_true is not None or args.ref_prompt_true is not None
|
|
620
|
+
assert args.img_test is not None or args.img_source_dir is not None
|
|
621
|
+
|
|
519
622
|
if args.enable_verbose:
|
|
520
623
|
global DISABLE_VERBOSE
|
|
521
624
|
set_metrics_verbose(True)
|
|
522
625
|
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
523
626
|
|
|
524
|
-
|
|
525
|
-
FID = FrechetInceptionDistance(
|
|
526
|
-
disable_tqdm=DISABLE_VERBOSE,
|
|
527
|
-
batch_size=args.fid_batch_size,
|
|
528
|
-
)
|
|
627
|
+
set_psnr_type(args.psnr_type)
|
|
529
628
|
|
|
530
629
|
METRICS_META: dict[str, float] = {}
|
|
531
630
|
|
|
@@ -533,11 +632,11 @@ def entrypoint():
|
|
|
533
632
|
def _run_metric(
|
|
534
633
|
metric: str,
|
|
535
634
|
img_true: str = None,
|
|
635
|
+
prompt_true: str = None,
|
|
536
636
|
img_test: str = None,
|
|
537
637
|
video_true: str = None,
|
|
538
638
|
video_test: str = None,
|
|
539
639
|
) -> None:
|
|
540
|
-
nonlocal FID
|
|
541
640
|
nonlocal METRICS_META
|
|
542
641
|
metric = metric.lower()
|
|
543
642
|
if img_true is not None and img_test is not None:
|
|
@@ -575,9 +674,39 @@ def entrypoint():
|
|
|
575
674
|
img_mse, n = compute_mse(img_true, img_test)
|
|
576
675
|
_logging_msg(img_mse, "mse", n)
|
|
577
676
|
if metric == "fid" or metric == "all":
|
|
578
|
-
img_fid, n =
|
|
677
|
+
img_fid, n = compute_fid(img_true, img_test)
|
|
579
678
|
_logging_msg(img_fid, "fid", n)
|
|
580
679
|
|
|
680
|
+
if prompt_true is not None and img_test is not None:
|
|
681
|
+
if any(
|
|
682
|
+
(
|
|
683
|
+
not os.path.exists(prompt_true), # file
|
|
684
|
+
not os.path.exists(img_test), # dir
|
|
685
|
+
)
|
|
686
|
+
):
|
|
687
|
+
return
|
|
688
|
+
|
|
689
|
+
# img_true and img_test can be files or dirs
|
|
690
|
+
prompt_true_info = os.path.basename(prompt_true)
|
|
691
|
+
img_test_info = os.path.basename(img_test)
|
|
692
|
+
|
|
693
|
+
def _logging_msg(value: float, name, n: int):
|
|
694
|
+
if value is None or n is None:
|
|
695
|
+
return
|
|
696
|
+
msg = (
|
|
697
|
+
f"{prompt_true_info} vs {img_test_info}, "
|
|
698
|
+
f"Num: {n}, {name.upper()}: {value:.5f}"
|
|
699
|
+
)
|
|
700
|
+
METRICS_META[msg] = value
|
|
701
|
+
logger.info(msg)
|
|
702
|
+
|
|
703
|
+
if metric == "clip_score":
|
|
704
|
+
clip_score, n = compute_clip_score(img_test, prompt_true)
|
|
705
|
+
_logging_msg(clip_score, "clip_score", n)
|
|
706
|
+
if metric == "image_reward":
|
|
707
|
+
image_reward, n = compute_reward_score(img_test, prompt_true)
|
|
708
|
+
_logging_msg(image_reward, "image_reward", n)
|
|
709
|
+
|
|
581
710
|
if video_true is not None and video_test is not None:
|
|
582
711
|
if any(
|
|
583
712
|
(
|
|
@@ -614,7 +743,7 @@ def entrypoint():
|
|
|
614
743
|
video_mse, n = compute_video_mse(video_true, video_test)
|
|
615
744
|
_logging_msg(video_mse, "mse", n)
|
|
616
745
|
if metric == "fid" or metric == "all":
|
|
617
|
-
video_fid, n =
|
|
746
|
+
video_fid, n = compute_video_fid(video_true, video_test)
|
|
618
747
|
_logging_msg(video_fid, "fid", n)
|
|
619
748
|
|
|
620
749
|
# run selected metrics
|
|
@@ -627,7 +756,18 @@ def entrypoint():
|
|
|
627
756
|
def _is_video_1vsN_pattern() -> bool:
|
|
628
757
|
return args.video_source_dir is not None and args.ref_video is not None
|
|
629
758
|
|
|
630
|
-
|
|
759
|
+
def _is_prompt_1vsN_pattern() -> bool:
    """Tell whether both ``--img-source-dir`` and ``--ref-prompt-true`` were given.

    Returns:
        True only when ``args.img_source_dir`` and ``args.ref_prompt_true``
        are both not None.
    """
    if args.img_source_dir is None:
        return False
    return args.ref_prompt_true is not None
|
|
763
|
+
|
|
764
|
+
assert not all(
|
|
765
|
+
(
|
|
766
|
+
_is_image_1vsN_pattern(),
|
|
767
|
+
_is_video_1vsN_pattern(),
|
|
768
|
+
_is_prompt_1vsN_pattern(),
|
|
769
|
+
)
|
|
770
|
+
)
|
|
631
771
|
|
|
632
772
|
if _is_image_1vsN_pattern():
|
|
633
773
|
# Glob Image dirs
|
|
@@ -711,11 +851,42 @@ def entrypoint():
|
|
|
711
851
|
video_test=video_test,
|
|
712
852
|
)
|
|
713
853
|
|
|
854
|
+
elif _is_prompt_1vsN_pattern():
|
|
855
|
+
# Glob Image dirs
|
|
856
|
+
if not os.path.exists(args.img_source_dir):
|
|
857
|
+
logger.error(f"{args.img_source_dir} not exist!")
|
|
858
|
+
return
|
|
859
|
+
|
|
860
|
+
directories = []
|
|
861
|
+
for item in os.listdir(args.img_source_dir):
|
|
862
|
+
item_path = os.path.join(args.img_source_dir, item)
|
|
863
|
+
if os.path.isdir(item_path):
|
|
864
|
+
directories.append(item_path)
|
|
865
|
+
|
|
866
|
+
if len(directories) == 0:
|
|
867
|
+
return
|
|
868
|
+
|
|
869
|
+
directories = sorted(directories)
|
|
870
|
+
if not DISABLE_VERBOSE:
|
|
871
|
+
logger.info(
|
|
872
|
+
f"Compare {args.ref_prompt_true} vs {directories}, "
|
|
873
|
+
f"Num compares: {len(directories)}"
|
|
874
|
+
)
|
|
875
|
+
|
|
876
|
+
for metric in args.metrics:
|
|
877
|
+
for img_test_dir in directories:
|
|
878
|
+
_run_metric(
|
|
879
|
+
metric=metric,
|
|
880
|
+
prompt_true=args.ref_prompt_true,
|
|
881
|
+
img_test=img_test_dir,
|
|
882
|
+
)
|
|
883
|
+
|
|
714
884
|
else:
|
|
715
885
|
for metric in args.metrics:
|
|
716
886
|
_run_metric(
|
|
717
887
|
metric=metric,
|
|
718
888
|
img_true=args.img_true,
|
|
889
|
+
prompt_true=args.prompt_true,
|
|
719
890
|
img_test=args.img_test,
|
|
720
891
|
video_true=args.video_true,
|
|
721
892
|
video_test=args.video_test,
|
|
@@ -724,7 +895,7 @@ def entrypoint():
|
|
|
724
895
|
if args.summary:
|
|
725
896
|
|
|
726
897
|
def _fetch_perf():
|
|
727
|
-
if args.perf_log is None or args.
|
|
898
|
+
if args.perf_log is None or len(args.perf_tags) == 0:
|
|
728
899
|
return []
|
|
729
900
|
if not os.path.exists(args.perf_log):
|
|
730
901
|
return []
|
|
@@ -733,17 +904,20 @@ def entrypoint():
|
|
|
733
904
|
perf_lines = file.readlines()
|
|
734
905
|
for line in perf_lines:
|
|
735
906
|
line = line.strip()
|
|
736
|
-
|
|
737
|
-
if
|
|
738
|
-
|
|
739
|
-
else:
|
|
740
|
-
has_all_extra_tag = True
|
|
741
|
-
for ext_tag in args.extra_perf_tags:
|
|
742
|
-
if ext_tag.lower() not in line.lower():
|
|
743
|
-
has_all_extra_tag = False
|
|
744
|
-
break
|
|
745
|
-
if has_all_extra_tag:
|
|
907
|
+
for perf_tag in args.perf_tags:
|
|
908
|
+
if perf_tag.lower() in line.lower():
|
|
909
|
+
if len(args.extra_perf_tags) == 0:
|
|
746
910
|
perf_texts.append(line)
|
|
911
|
+
break
|
|
912
|
+
else:
|
|
913
|
+
has_all_extra_tag = True
|
|
914
|
+
for ext_tag in args.extra_perf_tags:
|
|
915
|
+
if ext_tag.lower() not in line.lower():
|
|
916
|
+
has_all_extra_tag = False
|
|
917
|
+
break
|
|
918
|
+
if has_all_extra_tag:
|
|
919
|
+
perf_texts.append(line)
|
|
920
|
+
break
|
|
747
921
|
return perf_texts
|
|
748
922
|
|
|
749
923
|
PERF_TEXTS: list[str] = _fetch_perf()
|
|
@@ -770,8 +944,9 @@ def entrypoint():
|
|
|
770
944
|
try:
|
|
771
945
|
if tag.lower() in METRICS_CHOICES:
|
|
772
946
|
return float(value_str)
|
|
773
|
-
if args.
|
|
774
|
-
|
|
947
|
+
if len(args.perf_tags) > 0:
|
|
948
|
+
perf_tags = [tag.lower() for tag in args.perf_tags]
|
|
949
|
+
if tag.lower() in perf_tags:
|
|
775
950
|
return float(value_str)
|
|
776
951
|
return int(value_str)
|
|
777
952
|
except ValueError:
|
|
@@ -779,17 +954,37 @@ def entrypoint():
|
|
|
779
954
|
|
|
780
955
|
def _parse_perf(
|
|
781
956
|
compare_tag: str,
|
|
957
|
+
perf_tag: str,
|
|
782
958
|
) -> float | None:
|
|
783
959
|
nonlocal PERF_TEXTS
|
|
784
|
-
|
|
960
|
+
perf_values = []
|
|
785
961
|
for line in PERF_TEXTS:
|
|
786
962
|
if compare_tag in line:
|
|
787
|
-
|
|
788
|
-
if
|
|
789
|
-
|
|
790
|
-
if len(
|
|
963
|
+
perf_value = _parse_value(line, perf_tag)
|
|
964
|
+
if perf_value is not None:
|
|
965
|
+
perf_values.append(perf_value)
|
|
966
|
+
if len(perf_values) == 0:
|
|
791
967
|
return None
|
|
792
|
-
return sum(
|
|
968
|
+
return sum(perf_values) / len(perf_values)
|
|
969
|
+
|
|
970
|
+
def _ref_perf(
    key: str,
):
    """Collect the perf values of the reference side of a summary key.

    Args:
        key: Summary key whose first comma-separated field has the form
            ``"<reference> vs <compared>"``,
            e.g. ``"U1-Q0-C0-NONE vs U4-Q1-C1-NONE, ..."``.

    Returns:
        One ``_parse_perf`` result per entry of ``args.perf_tags`` for the
        reference tag, or an empty list when ``args.prompt_true`` is set
        (no reference tag is derived in that case).
    """
    header = key.split(",")[0].strip()
    if args.prompt_true is not None:
        return []

    # Left-hand side of "A vs B" is the reference, e.g. U1-Q0-C0-NONE.
    reference_tag = header.split("vs")[0].strip()
    return [
        _parse_perf(reference_tag, perf_tag) for perf_tag in args.perf_tags
    ]
|
|
793
988
|
|
|
794
989
|
def _format_item(
|
|
795
990
|
key: str,
|
|
@@ -802,40 +997,129 @@ def entrypoint():
|
|
|
802
997
|
header = key.split(",")[0].strip()
|
|
803
998
|
compare_tag = header.split("vs")[1].strip() # U4-Q1-C1-NONE
|
|
804
999
|
has_perf_texts = len(PERF_TEXTS) > 0
|
|
1000
|
+
|
|
1001
|
+
def _perf_msg(perf_tag: str):
|
|
1002
|
+
if "time" in perf_tag.lower():
|
|
1003
|
+
perf_msg = "Latency(s)"
|
|
1004
|
+
elif "tflops" in perf_tag.lower():
|
|
1005
|
+
perf_msg = "TFLOPs"
|
|
1006
|
+
elif "flops" in perf_tag.lower():
|
|
1007
|
+
perf_msg = "FLOPs"
|
|
1008
|
+
else:
|
|
1009
|
+
perf_msg = perf_tag.upper()
|
|
1010
|
+
return perf_msg
|
|
1011
|
+
|
|
805
1012
|
format_str = ""
|
|
806
1013
|
# Num / Frames
|
|
1014
|
+
perf_values = []
|
|
1015
|
+
perf_msgs = []
|
|
807
1016
|
if n := _parse_value(key, "Num"):
|
|
808
1017
|
if not has_perf_texts:
|
|
809
1018
|
format_str = (
|
|
810
|
-
f"{header:<{max_key_len}}
|
|
1019
|
+
f"{header:<{max_key_len}}, Num: {n}, "
|
|
811
1020
|
f"{metric.upper()}: {value:<7.4f}"
|
|
812
1021
|
)
|
|
813
1022
|
else:
|
|
814
|
-
perf_time = _parse_perf(compare_tag)
|
|
815
|
-
perf_time = f"{perf_time:<.2f}" if perf_time else None
|
|
816
1023
|
format_str = (
|
|
817
|
-
f"{header:<{max_key_len}}
|
|
818
|
-
f"{metric.upper()}: {value:<7.4f}
|
|
819
|
-
f"Perf: {perf_time}"
|
|
1024
|
+
f"{header:<{max_key_len}}, Num: {n}, "
|
|
1025
|
+
f"{metric.upper()}: {value:<7.4f}, "
|
|
820
1026
|
)
|
|
1027
|
+
for perf_tag in args.perf_tags:
|
|
1028
|
+
perf_value = _parse_perf(compare_tag, perf_tag)
|
|
1029
|
+
perf_values.append(perf_value)
|
|
1030
|
+
|
|
1031
|
+
perf_value = (
|
|
1032
|
+
f"{perf_value:<.2f}" if perf_value else None
|
|
1033
|
+
)
|
|
1034
|
+
perf_msg = _perf_msg(perf_tag)
|
|
1035
|
+
format_str += f"{perf_msg}: {perf_value}, "
|
|
1036
|
+
|
|
1037
|
+
perf_msgs.append(perf_msg)
|
|
1038
|
+
|
|
1039
|
+
if not args.cal_speedup:
|
|
1040
|
+
format_str = format_str.removesuffix(", ")
|
|
1041
|
+
|
|
821
1042
|
elif n := _parse_value(key, "Frames"):
|
|
822
1043
|
if not has_perf_texts:
|
|
823
1044
|
format_str = (
|
|
824
|
-
f"{header:<{max_key_len}}
|
|
1045
|
+
f"{header:<{max_key_len}}, Frames: {n}, "
|
|
825
1046
|
f"{metric.upper()}: {value:<7.4f}"
|
|
826
1047
|
)
|
|
827
1048
|
else:
|
|
828
|
-
perf_time = _parse_perf(compare_tag)
|
|
829
|
-
perf_time = f"{perf_time:<.2f}" if perf_time else None
|
|
830
1049
|
format_str = (
|
|
831
|
-
f"{header:<{max_key_len}}
|
|
832
|
-
f"{metric.upper()}: {value:<7.4f}
|
|
833
|
-
f"Perf: {perf_time}"
|
|
1050
|
+
f"{header:<{max_key_len}}, Frames: {n}, "
|
|
1051
|
+
f"{metric.upper()}: {value:<7.4f}, "
|
|
834
1052
|
)
|
|
1053
|
+
for perf_tag in args.perf_tags:
|
|
1054
|
+
perf_value = _parse_perf(compare_tag, perf_tag)
|
|
1055
|
+
perf_values.append(perf_value)
|
|
1056
|
+
|
|
1057
|
+
perf_value = (
|
|
1058
|
+
f"{perf_value:<.2f}" if perf_value else None
|
|
1059
|
+
)
|
|
1060
|
+
perf_msg = _perf_msg(perf_tag)
|
|
1061
|
+
format_str += f"{perf_msg}: {perf_value}, "
|
|
1062
|
+
perf_msgs.append(perf_msg)
|
|
1063
|
+
|
|
1064
|
+
if not args.cal_speedup:
|
|
1065
|
+
format_str = format_str.removesuffix(", ")
|
|
835
1066
|
else:
|
|
836
1067
|
raise ValueError("Num or Frames can not be NoneType.")
|
|
837
1068
|
|
|
838
|
-
return format_str
|
|
1069
|
+
return format_str, perf_values, perf_msgs
|
|
1070
|
+
|
|
1071
|
+
def _format_table(format_strs: List[str], metric: str):
|
|
1072
|
+
if not format_strs:
|
|
1073
|
+
return ""
|
|
1074
|
+
|
|
1075
|
+
metric_upper = metric.upper()
|
|
1076
|
+
all_headers = {"Config", metric_upper}
|
|
1077
|
+
row_data = []
|
|
1078
|
+
|
|
1079
|
+
for line in format_strs:
|
|
1080
|
+
parts = [p.strip() for p in line.split(",")]
|
|
1081
|
+
|
|
1082
|
+
config_part = parts[0].strip()
|
|
1083
|
+
if "vs" in config_part:
|
|
1084
|
+
config = config_part.split("vs", 1)[1].strip()
|
|
1085
|
+
if "_DBCACHE_" in config:
|
|
1086
|
+
config = config.split("_DBCACHE_", 1)[1].strip()
|
|
1087
|
+
else:
|
|
1088
|
+
config = config_part
|
|
1089
|
+
|
|
1090
|
+
metric_value = next(
|
|
1091
|
+
p.split(":")[1].strip()
|
|
1092
|
+
for p in parts
|
|
1093
|
+
if p.startswith(metric_upper)
|
|
1094
|
+
)
|
|
1095
|
+
|
|
1096
|
+
perf_data = {}
|
|
1097
|
+
for part in parts:
|
|
1098
|
+
if part.startswith(("Num:", "Frames:", metric_upper)):
|
|
1099
|
+
continue
|
|
1100
|
+
if ":" in part:
|
|
1101
|
+
key, value = part.split(":", 1)
|
|
1102
|
+
key = key.strip()
|
|
1103
|
+
value = value.strip()
|
|
1104
|
+
perf_data[key] = value
|
|
1105
|
+
all_headers.add(key)
|
|
1106
|
+
|
|
1107
|
+
row_data.append(
|
|
1108
|
+
{"Config": config, metric_upper: metric_value, **perf_data}
|
|
1109
|
+
)
|
|
1110
|
+
|
|
1111
|
+
sorted_headers = ["Config", metric_upper] + sorted(
|
|
1112
|
+
[h for h in all_headers if h not in ["Config", metric_upper]]
|
|
1113
|
+
)
|
|
1114
|
+
|
|
1115
|
+
table = "| " + " | ".join(sorted_headers) + " |\n"
|
|
1116
|
+
table += "| " + " | ".join(["---"] * len(sorted_headers)) + " |\n"
|
|
1117
|
+
|
|
1118
|
+
for row in row_data:
|
|
1119
|
+
row_values = [row.get(header, "") for header in sorted_headers]
|
|
1120
|
+
table += "| " + " | ".join(row_values) + " |\n"
|
|
1121
|
+
|
|
1122
|
+
return table.strip()
|
|
839
1123
|
|
|
840
1124
|
selected_metrics = args.metrics
|
|
841
1125
|
if "all" in selected_metrics:
|
|
@@ -848,7 +1132,17 @@ def entrypoint():
|
|
|
848
1132
|
if metric.upper() in key or metric.lower() in key:
|
|
849
1133
|
selected_items[key] = METRICS_META[key]
|
|
850
1134
|
|
|
851
|
-
reverse =
|
|
1135
|
+
reverse = (
|
|
1136
|
+
True
|
|
1137
|
+
if metric.lower()
|
|
1138
|
+
in [
|
|
1139
|
+
"psnr",
|
|
1140
|
+
"ssim",
|
|
1141
|
+
"clip_score",
|
|
1142
|
+
"image_reward",
|
|
1143
|
+
]
|
|
1144
|
+
else False
|
|
1145
|
+
)
|
|
852
1146
|
sorted_items = sorted(
|
|
853
1147
|
selected_items.items(), key=lambda x: x[1], reverse=reverse
|
|
854
1148
|
)
|
|
@@ -857,12 +1151,65 @@ def entrypoint():
|
|
|
857
1151
|
]
|
|
858
1152
|
max_key_len = max(len(key) for key in selected_keys)
|
|
859
1153
|
|
|
1154
|
+
ref_perf_values = _ref_perf(key=selected_keys[0])
|
|
1155
|
+
max_perf_values: List[float] = []
|
|
1156
|
+
|
|
1157
|
+
if ref_perf_values and None not in ref_perf_values:
|
|
1158
|
+
max_perf_values = ref_perf_values.copy()
|
|
1159
|
+
|
|
1160
|
+
for key, value in sorted_items:
|
|
1161
|
+
format_str, perf_values, perf_msgs = _format_item(
|
|
1162
|
+
key, metric, value, max_key_len
|
|
1163
|
+
)
|
|
1164
|
+
# skip 'None' msg but not 'NONE', 'NONE' means w/o cache
|
|
1165
|
+
if "None" in format_str:
|
|
1166
|
+
continue
|
|
1167
|
+
|
|
1168
|
+
if (
|
|
1169
|
+
not perf_values
|
|
1170
|
+
or None in perf_values
|
|
1171
|
+
or not perf_msgs
|
|
1172
|
+
or not args.cal_speedup
|
|
1173
|
+
):
|
|
1174
|
+
continue
|
|
1175
|
+
|
|
1176
|
+
if not max_perf_values:
|
|
1177
|
+
max_perf_values = perf_values
|
|
1178
|
+
else:
|
|
1179
|
+
for i in range(len(max_perf_values)):
|
|
1180
|
+
max_perf_values[i] = max(
|
|
1181
|
+
max_perf_values[i], perf_values[i]
|
|
1182
|
+
)
|
|
1183
|
+
|
|
860
1184
|
format_strs = []
|
|
861
1185
|
for key, value in sorted_items:
|
|
862
|
-
|
|
863
|
-
|
|
1186
|
+
format_str, perf_values, perf_msgs = _format_item(
|
|
1187
|
+
key, metric, value, max_key_len
|
|
864
1188
|
)
|
|
865
1189
|
|
|
1190
|
+
# skip 'None' msg but not 'NONE', 'NONE' means w/o cache
|
|
1191
|
+
if "None" in format_str:
|
|
1192
|
+
continue
|
|
1193
|
+
|
|
1194
|
+
if (
|
|
1195
|
+
not perf_values
|
|
1196
|
+
or None in perf_values
|
|
1197
|
+
or not perf_msgs
|
|
1198
|
+
or not max_perf_values
|
|
1199
|
+
or not args.cal_speedup
|
|
1200
|
+
):
|
|
1201
|
+
format_strs.append(format_str)
|
|
1202
|
+
continue
|
|
1203
|
+
|
|
1204
|
+
for perf_value, perf_msg, max_perf_value in zip(
|
|
1205
|
+
perf_values, perf_msgs, max_perf_values
|
|
1206
|
+
):
|
|
1207
|
+
perf_speedup = max_perf_value / perf_value
|
|
1208
|
+
format_str += f"{perf_msg}(↑): {perf_speedup:<.2f}, "
|
|
1209
|
+
|
|
1210
|
+
format_str = format_str.removesuffix(", ")
|
|
1211
|
+
format_strs.append(format_str)
|
|
1212
|
+
|
|
866
1213
|
format_len = max(len(format_str) for format_str in format_strs)
|
|
867
1214
|
|
|
868
1215
|
res_len = format_len - len(f"Summary: {metric.upper()}")
|
|
@@ -877,6 +1224,15 @@ def entrypoint():
|
|
|
877
1224
|
print(format_str)
|
|
878
1225
|
print("-" * format_len)
|
|
879
1226
|
|
|
1227
|
+
if args.gen_markdown_table:
|
|
1228
|
+
table = _format_table(format_strs, metric)
|
|
1229
|
+
table = table.replace("Latency(s)(↑)", "SpeedUp(↑)")
|
|
1230
|
+
table = table.replace("TFLOPs(↑)", "SpeedUp(↑)")
|
|
1231
|
+
table = table.replace("FLOPs(↑)", "SpeedUp(↑)")
|
|
1232
|
+
print("-" * format_len)
|
|
1233
|
+
print(f"{table}")
|
|
1234
|
+
print("-" * format_len)
|
|
1235
|
+
|
|
880
1236
|
|
|
881
1237
|
if __name__ == "__main__":
|
|
882
1238
|
entrypoint()
|