cache-dit 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +19 -17
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +1 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +1 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +1 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +1 -1
- cache_dit/metrics/__init__.py +2 -0
- cache_dit/metrics/config.py +34 -0
- cache_dit/metrics/fid.py +145 -56
- cache_dit/metrics/metrics.py +165 -52
- {cache_dit-0.2.6.dist-info → cache_dit-0.2.7.dist-info}/METADATA +10 -12
- {cache_dit-0.2.6.dist-info → cache_dit-0.2.7.dist-info}/RECORD +16 -15
- {cache_dit-0.2.6.dist-info → cache_dit-0.2.7.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.6.dist-info → cache_dit-0.2.7.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.6.dist-info → cache_dit-0.2.7.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.6.dist-info → cache_dit-0.2.7.dist-info}/top_level.txt +0 -0
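cache_dit/_version.py
CHANGED
(the +2 -2 hunk for this file is not included in this extract; the hunks that follow belong to the next file)
cache_dit/cache_factory/dual_block_cache/cache_context.py
CHANGED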
|
@@ -988,7 +988,7 @@ def get_Bn_encoder_buffer(prefix: str = "Bn"):
|
|
|
988
988
|
@torch.compiler.disable
|
|
989
989
|
def apply_hidden_states_residual(
|
|
990
990
|
hidden_states: torch.Tensor,
|
|
991
|
-
encoder_hidden_states: torch.Tensor,
|
|
991
|
+
encoder_hidden_states: torch.Tensor = None,
|
|
992
992
|
prefix: str = "Bn",
|
|
993
993
|
encoder_prefix: str = "Bn_encoder",
|
|
994
994
|
):
|
|
@@ -1006,25 +1006,27 @@ def apply_hidden_states_residual(
|
|
|
1006
1006
|
# If cache is not residual, we use the hidden states directly
|
|
1007
1007
|
hidden_states = hidden_states_prev
|
|
1008
1008
|
|
|
1009
|
-
|
|
1010
|
-
encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
|
|
1011
|
-
else:
|
|
1012
|
-
encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
|
|
1009
|
+
hidden_states = hidden_states.contiguous()
|
|
1013
1010
|
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1011
|
+
if encoder_hidden_states is not None:
|
|
1012
|
+
if "Bn" in encoder_prefix:
|
|
1013
|
+
encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
|
|
1014
|
+
else:
|
|
1015
|
+
encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
|
|
1017
1016
|
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
)
|
|
1022
|
-
else:
|
|
1023
|
-
# If encoder cache is not residual, we use the encoder hidden states directly
|
|
1024
|
-
encoder_hidden_states = encoder_hidden_states_prev
|
|
1017
|
+
assert (
|
|
1018
|
+
encoder_hidden_states_prev is not None
|
|
1019
|
+
), f"{prefix}_encoder_buffer must be set before"
|
|
1025
1020
|
|
|
1026
|
-
|
|
1027
|
-
|
|
1021
|
+
if is_encoder_cache_residual():
|
|
1022
|
+
encoder_hidden_states = (
|
|
1023
|
+
encoder_hidden_states_prev + encoder_hidden_states
|
|
1024
|
+
)
|
|
1025
|
+
else:
|
|
1026
|
+
# If encoder cache is not residual, we use the encoder hidden states directly
|
|
1027
|
+
encoder_hidden_states = encoder_hidden_states_prev
|
|
1028
|
+
|
|
1029
|
+
encoder_hidden_states = encoder_hidden_states.contiguous()
|
|
1028
1030
|
|
|
1029
1031
|
return hidden_states, encoder_hidden_states
|
|
1030
1032
|
|
cache_dit/metrics/__init__.py
CHANGED
|
@@ -6,6 +6,8 @@ from cache_dit.metrics.metrics import compute_video_ssim
|
|
|
6
6
|
from cache_dit.metrics.metrics import compute_video_mse
|
|
7
7
|
from cache_dit.metrics.metrics import entrypoint
|
|
8
8
|
from cache_dit.metrics.fid import FrechetInceptionDistance
|
|
9
|
+
from cache_dit.metrics.config import set_metrics_verbose
|
|
10
|
+
from cache_dit.metrics.config import get_metrics_verbose
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
def main():
|
|
cache_dit/metrics/config.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from cache_dit.logger import init_logger
|
|
2
|
+
|
|
3
|
+
logger = init_logger(__name__)
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
_metrics_progress_verbose = False
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def set_metrics_verbose(verbose: bool):
|
|
10
|
+
global _metrics_progress_verbose
|
|
11
|
+
_metrics_progress_verbose = verbose
|
|
12
|
+
logger.debug(f"Metrics verbose: {verbose}")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_metrics_verbose() -> bool:
|
|
16
|
+
global _metrics_progress_verbose
|
|
17
|
+
return _metrics_progress_verbose
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
_IMAGE_EXTENSIONS = [
|
|
21
|
+
"bmp",
|
|
22
|
+
"jpg",
|
|
23
|
+
"jpeg",
|
|
24
|
+
"pgm",
|
|
25
|
+
"png",
|
|
26
|
+
"ppm",
|
|
27
|
+
"tif",
|
|
28
|
+
"tiff",
|
|
29
|
+
"webp",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
_VIDEO_EXTENSIONS = [
|
|
33
|
+
"mp4",
|
|
34
|
+
]
|
cache_dit/metrics/fid.py
CHANGED
|
@@ -9,11 +9,14 @@ import torch
|
|
|
9
9
|
import torchvision.transforms as TF
|
|
10
10
|
from torch.nn.functional import adaptive_avg_pool2d
|
|
11
11
|
from cache_dit.metrics.inception import InceptionV3
|
|
12
|
+
from cache_dit.metrics.config import _IMAGE_EXTENSIONS
|
|
13
|
+
from cache_dit.metrics.config import _VIDEO_EXTENSIONS
|
|
12
14
|
from cache_dit.logger import init_logger
|
|
13
15
|
|
|
14
16
|
logger = init_logger(__name__)
|
|
15
17
|
|
|
16
18
|
|
|
19
|
+
# Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
|
|
17
20
|
class ImagePathDataset(torch.utils.data.Dataset):
|
|
18
21
|
def __init__(self, files_or_imgs, transforms=None):
|
|
19
22
|
self.files_or_imgs = files_or_imgs
|
|
@@ -219,22 +222,7 @@ def calculate_activation_statistics(
|
|
|
219
222
|
return mu, sigma
|
|
220
223
|
|
|
221
224
|
|
|
222
|
-
_IMAGE_EXTENSIONS = {
|
|
223
|
-
"bmp",
|
|
224
|
-
"jpg",
|
|
225
|
-
"jpeg",
|
|
226
|
-
"pgm",
|
|
227
|
-
"png",
|
|
228
|
-
"ppm",
|
|
229
|
-
"tif",
|
|
230
|
-
"tiff",
|
|
231
|
-
"webp",
|
|
232
|
-
}
|
|
233
|
-
|
|
234
|
-
|
|
235
225
|
class FrechetInceptionDistance:
|
|
236
|
-
IMAGE_EXTENSIONS = _IMAGE_EXTENSIONS
|
|
237
|
-
|
|
238
226
|
def __init__(
|
|
239
227
|
self,
|
|
240
228
|
device="cuda" if torch.cuda.is_available() else "cpu",
|
|
@@ -258,7 +246,8 @@ class FrechetInceptionDistance:
|
|
|
258
246
|
image_true: np.ndarray | str,
|
|
259
247
|
image_test: np.ndarray | str,
|
|
260
248
|
):
|
|
261
|
-
"""
|
|
249
|
+
"""
|
|
250
|
+
Calculates the FID of two file paths
|
|
262
251
|
FID = FrechetInceptionDistance()
|
|
263
252
|
img_fid = FID.compute_fid("img_true.png", "img_test.png")
|
|
264
253
|
img_dir_fid = FID.compute_fid("img_true_dir", "img_test_dir")
|
|
@@ -267,8 +256,8 @@ class FrechetInceptionDistance:
|
|
|
267
256
|
if os.path.isfile(image_true) or os.path.isfile(image_test):
|
|
268
257
|
assert os.path.exists(image_true)
|
|
269
258
|
assert os.path.exists(image_test)
|
|
270
|
-
assert image_true.split(".")[-1] in
|
|
271
|
-
assert image_test.split(".")[-1] in
|
|
259
|
+
assert image_true.split(".")[-1] in _IMAGE_EXTENSIONS
|
|
260
|
+
assert image_test.split(".")[-1] in _IMAGE_EXTENSIONS
|
|
272
261
|
image_true_files = [image_true]
|
|
273
262
|
image_test_files = [image_test]
|
|
274
263
|
else:
|
|
@@ -279,7 +268,7 @@ class FrechetInceptionDistance:
|
|
|
279
268
|
image_true_files = sorted(
|
|
280
269
|
[
|
|
281
270
|
file
|
|
282
|
-
for ext in
|
|
271
|
+
for ext in _IMAGE_EXTENSIONS
|
|
283
272
|
for file in image_true_dir.rglob("*.{}".format(ext))
|
|
284
273
|
]
|
|
285
274
|
)
|
|
@@ -287,7 +276,7 @@ class FrechetInceptionDistance:
|
|
|
287
276
|
image_test_files = sorted(
|
|
288
277
|
[
|
|
289
278
|
file
|
|
290
|
-
for ext in
|
|
279
|
+
for ext in _IMAGE_EXTENSIONS
|
|
291
280
|
for file in image_test_dir.rglob("*.{}".format(ext))
|
|
292
281
|
]
|
|
293
282
|
)
|
|
@@ -297,15 +286,32 @@ class FrechetInceptionDistance:
|
|
|
297
286
|
image_test_files = [
|
|
298
287
|
file.as_posix() for file in image_test_files
|
|
299
288
|
]
|
|
289
|
+
|
|
290
|
+
# select valid files
|
|
291
|
+
image_true_files_selected = []
|
|
292
|
+
image_test_files_selected = []
|
|
293
|
+
for i in range(
|
|
294
|
+
min(len(image_true_files), len(image_test_files))
|
|
295
|
+
):
|
|
296
|
+
selected_image_true = image_true_files[i]
|
|
297
|
+
selected_image_test = image_test_files[i]
|
|
298
|
+
# Image pair must have the same basename
|
|
299
|
+
if os.path.basename(
|
|
300
|
+
selected_image_test
|
|
301
|
+
) == os.path.basename(selected_image_true):
|
|
302
|
+
image_true_files_selected.append(selected_image_true)
|
|
303
|
+
image_test_files_selected.append(selected_image_test)
|
|
304
|
+
image_true_files = image_true_files_selected.copy()
|
|
305
|
+
image_test_files = image_test_files_selected.copy()
|
|
306
|
+
if len(image_true_files) == 0:
|
|
307
|
+
logger.error(
|
|
308
|
+
"No valid Image pairs, please note that Image "
|
|
309
|
+
"pairs must have the same basename."
|
|
310
|
+
)
|
|
311
|
+
return None, None
|
|
312
|
+
|
|
300
313
|
logger.debug(f"image_true_files: {image_true_files}")
|
|
301
314
|
logger.debug(f"image_test_files: {image_test_files}")
|
|
302
|
-
assert len(image_true_files) == len(image_test_files)
|
|
303
|
-
for image_true, image_test in zip(
|
|
304
|
-
image_true_files, image_test_files
|
|
305
|
-
):
|
|
306
|
-
assert os.path.basename(image_true) == os.path.basename(
|
|
307
|
-
image_test
|
|
308
|
-
), f"image_true:{image_true} != image_test: {image_test}"
|
|
309
315
|
else:
|
|
310
316
|
image_true_files = [image_true]
|
|
311
317
|
image_test_files = [image_test]
|
|
@@ -340,6 +346,115 @@ class FrechetInceptionDistance:
|
|
|
340
346
|
return fid_value, len(image_true_files)
|
|
341
347
|
|
|
342
348
|
def compute_video_fid(
|
|
349
|
+
self,
|
|
350
|
+
# file or dir
|
|
351
|
+
video_true: str,
|
|
352
|
+
video_test: str,
|
|
353
|
+
):
|
|
354
|
+
if os.path.isfile(video_true) and os.path.isfile(video_test):
|
|
355
|
+
video_true_frames, video_test_frames, valid_frames = (
|
|
356
|
+
self._fetch_video_frames(
|
|
357
|
+
video_true=video_true,
|
|
358
|
+
video_test=video_test,
|
|
359
|
+
)
|
|
360
|
+
)
|
|
361
|
+
elif os.path.isdir(video_true) and os.path.isdir(video_test):
|
|
362
|
+
# Glob videos
|
|
363
|
+
video_true_dir: pathlib.Path = pathlib.Path(video_true)
|
|
364
|
+
video_true_files = sorted(
|
|
365
|
+
[
|
|
366
|
+
file
|
|
367
|
+
for ext in _VIDEO_EXTENSIONS
|
|
368
|
+
for file in video_true_dir.rglob("*.{}".format(ext))
|
|
369
|
+
]
|
|
370
|
+
)
|
|
371
|
+
video_test_dir: pathlib.Path = pathlib.Path(video_test)
|
|
372
|
+
video_test_files = sorted(
|
|
373
|
+
[
|
|
374
|
+
file
|
|
375
|
+
for ext in _VIDEO_EXTENSIONS
|
|
376
|
+
for file in video_test_dir.rglob("*.{}".format(ext))
|
|
377
|
+
]
|
|
378
|
+
)
|
|
379
|
+
video_true_files = [file.as_posix() for file in video_true_files]
|
|
380
|
+
video_test_files = [file.as_posix() for file in video_test_files]
|
|
381
|
+
|
|
382
|
+
# select valid video files
|
|
383
|
+
video_true_files_selected = []
|
|
384
|
+
video_test_files_selected = []
|
|
385
|
+
for i in range(min(len(video_true_files), len(video_test_files))):
|
|
386
|
+
selected_video_true = video_true_files[i]
|
|
387
|
+
selected_video_test = video_test_files[i]
|
|
388
|
+
# Video pair must have the same basename
|
|
389
|
+
if os.path.basename(selected_video_test) == os.path.basename(
|
|
390
|
+
selected_video_true
|
|
391
|
+
):
|
|
392
|
+
video_true_files_selected.append(selected_video_true)
|
|
393
|
+
video_test_files_selected.append(selected_video_test)
|
|
394
|
+
|
|
395
|
+
video_true_files = video_true_files_selected.copy()
|
|
396
|
+
video_test_files = video_test_files_selected.copy()
|
|
397
|
+
if len(video_true_files) == 0:
|
|
398
|
+
logger.error(
|
|
399
|
+
"No valid Video pairs, please note that Video "
|
|
400
|
+
"pairs must have the same basename."
|
|
401
|
+
)
|
|
402
|
+
return None, None
|
|
403
|
+
logger.debug(f"video_true_files: {video_true_files}")
|
|
404
|
+
logger.debug(f"video_test_files: {video_test_files}")
|
|
405
|
+
|
|
406
|
+
# Fetch all frames
|
|
407
|
+
video_true_frames = []
|
|
408
|
+
video_test_frames = []
|
|
409
|
+
valid_frames = 0
|
|
410
|
+
|
|
411
|
+
for video_true_, video_test_ in zip(
|
|
412
|
+
video_true_files, video_test_files
|
|
413
|
+
):
|
|
414
|
+
video_true_frames_, video_test_frames_, valid_frames_ = (
|
|
415
|
+
self._fetch_video_frames(
|
|
416
|
+
video_true=video_true_, video_test=video_test_
|
|
417
|
+
)
|
|
418
|
+
)
|
|
419
|
+
video_true_frames.extend(video_true_frames_)
|
|
420
|
+
video_test_frames.extend(video_test_frames_)
|
|
421
|
+
valid_frames += valid_frames_
|
|
422
|
+
else:
|
|
423
|
+
raise ValueError("video_true and video_test must be files or dirs.")
|
|
424
|
+
|
|
425
|
+
if valid_frames <= 0:
|
|
426
|
+
logger.debug("No valid frames to compare")
|
|
427
|
+
return None, None
|
|
428
|
+
|
|
429
|
+
batch_size = min(16, self.batch_size)
|
|
430
|
+
m1, s1 = calculate_activation_statistics(
|
|
431
|
+
video_true_frames,
|
|
432
|
+
self.model,
|
|
433
|
+
batch_size,
|
|
434
|
+
self.dims,
|
|
435
|
+
self.device,
|
|
436
|
+
self.num_workers,
|
|
437
|
+
self.disable_tqdm,
|
|
438
|
+
)
|
|
439
|
+
m2, s2 = calculate_activation_statistics(
|
|
440
|
+
video_test_frames,
|
|
441
|
+
self.model,
|
|
442
|
+
batch_size,
|
|
443
|
+
self.dims,
|
|
444
|
+
self.device,
|
|
445
|
+
self.num_workers,
|
|
446
|
+
self.disable_tqdm,
|
|
447
|
+
)
|
|
448
|
+
fid_value = calculate_frechet_distance(
|
|
449
|
+
m1,
|
|
450
|
+
s1,
|
|
451
|
+
m2,
|
|
452
|
+
s2,
|
|
453
|
+
)
|
|
454
|
+
|
|
455
|
+
return fid_value, valid_frames
|
|
456
|
+
|
|
457
|
+
def _fetch_video_frames(
|
|
343
458
|
self,
|
|
344
459
|
video_true: str,
|
|
345
460
|
video_test: str,
|
|
@@ -349,7 +464,7 @@ class FrechetInceptionDistance:
|
|
|
349
464
|
|
|
350
465
|
if not cap1.isOpened() or not cap2.isOpened():
|
|
351
466
|
logger.error("Could not open video files")
|
|
352
|
-
return
|
|
467
|
+
return [], [], 0
|
|
353
468
|
|
|
354
469
|
frame_count = min(
|
|
355
470
|
int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
|
|
@@ -378,32 +493,6 @@ class FrechetInceptionDistance:
|
|
|
378
493
|
cap2.release()
|
|
379
494
|
|
|
380
495
|
if valid_frames <= 0:
|
|
381
|
-
return
|
|
496
|
+
return [], [], 0
|
|
382
497
|
|
|
383
|
-
|
|
384
|
-
m1, s1 = calculate_activation_statistics(
|
|
385
|
-
video_true_frames,
|
|
386
|
-
self.model,
|
|
387
|
-
batch_size,
|
|
388
|
-
self.dims,
|
|
389
|
-
self.device,
|
|
390
|
-
self.num_workers,
|
|
391
|
-
self.disable_tqdm,
|
|
392
|
-
)
|
|
393
|
-
m2, s2 = calculate_activation_statistics(
|
|
394
|
-
video_test_frames,
|
|
395
|
-
self.model,
|
|
396
|
-
batch_size,
|
|
397
|
-
self.dims,
|
|
398
|
-
self.device,
|
|
399
|
-
self.num_workers,
|
|
400
|
-
self.disable_tqdm,
|
|
401
|
-
)
|
|
402
|
-
fid_value = calculate_frechet_distance(
|
|
403
|
-
m1,
|
|
404
|
-
s1,
|
|
405
|
-
m2,
|
|
406
|
-
s2,
|
|
407
|
-
)
|
|
408
|
-
|
|
409
|
-
return fid_value, valid_frames
|
|
498
|
+
return video_true_frames, video_test_frames, valid_frames
|
cache_dit/metrics/metrics.py
CHANGED
|
@@ -3,16 +3,24 @@ import cv2
|
|
|
3
3
|
import pathlib
|
|
4
4
|
import argparse
|
|
5
5
|
import numpy as np
|
|
6
|
+
from tqdm import tqdm
|
|
6
7
|
from functools import partial
|
|
7
8
|
from skimage.metrics import mean_squared_error
|
|
8
9
|
from skimage.metrics import peak_signal_noise_ratio
|
|
9
10
|
from skimage.metrics import structural_similarity
|
|
10
11
|
from cache_dit.metrics.fid import FrechetInceptionDistance
|
|
12
|
+
from cache_dit.metrics.config import set_metrics_verbose
|
|
13
|
+
from cache_dit.metrics.config import get_metrics_verbose
|
|
14
|
+
from cache_dit.metrics.config import _IMAGE_EXTENSIONS
|
|
15
|
+
from cache_dit.metrics.config import _VIDEO_EXTENSIONS
|
|
11
16
|
from cache_dit.logger import init_logger
|
|
12
17
|
|
|
13
18
|
logger = init_logger(__name__)
|
|
14
19
|
|
|
15
20
|
|
|
21
|
+
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
22
|
+
|
|
23
|
+
|
|
16
24
|
def compute_psnr_file(
|
|
17
25
|
image_true: np.ndarray | str,
|
|
18
26
|
image_test: np.ndarray | str,
|
|
@@ -72,19 +80,6 @@ def compute_ssim_file(
|
|
|
72
80
|
)
|
|
73
81
|
|
|
74
82
|
|
|
75
|
-
_IMAGE_EXTENSIONS = {
|
|
76
|
-
"bmp",
|
|
77
|
-
"jpg",
|
|
78
|
-
"jpeg",
|
|
79
|
-
"pgm",
|
|
80
|
-
"png",
|
|
81
|
-
"ppm",
|
|
82
|
-
"tif",
|
|
83
|
-
"tiff",
|
|
84
|
-
"webp",
|
|
85
|
-
}
|
|
86
|
-
|
|
87
|
-
|
|
88
83
|
def compute_dir_metric(
|
|
89
84
|
image_true_dir: np.ndarray | str,
|
|
90
85
|
image_test_dir: np.ndarray | str,
|
|
@@ -117,17 +112,38 @@ def compute_dir_metric(
|
|
|
117
112
|
)
|
|
118
113
|
image_true_files = [file.as_posix() for file in image_true_files]
|
|
119
114
|
image_test_files = [file.as_posix() for file in image_test_files]
|
|
115
|
+
|
|
116
|
+
# select valid files
|
|
117
|
+
image_true_files_selected = []
|
|
118
|
+
image_test_files_selected = []
|
|
119
|
+
for i in range(min(len(image_true_files), len(image_test_files))):
|
|
120
|
+
selected_image_true = image_true_files[i]
|
|
121
|
+
selected_image_test = image_test_files[i]
|
|
122
|
+
# Image pair must have the same basename
|
|
123
|
+
if os.path.basename(selected_image_test) == os.path.basename(
|
|
124
|
+
selected_image_true
|
|
125
|
+
):
|
|
126
|
+
image_true_files_selected.append(selected_image_true)
|
|
127
|
+
image_test_files_selected.append(selected_image_test)
|
|
128
|
+
image_true_files = image_true_files_selected.copy()
|
|
129
|
+
image_test_files = image_test_files_selected.copy()
|
|
130
|
+
if len(image_true_files) == 0:
|
|
131
|
+
logger.error(
|
|
132
|
+
"No valid Image pairs, please note that Image "
|
|
133
|
+
"pairs must have the same basename."
|
|
134
|
+
)
|
|
135
|
+
return None, None
|
|
136
|
+
|
|
120
137
|
logger.debug(f"image_true_files: {image_true_files}")
|
|
121
138
|
logger.debug(f"image_test_files: {image_test_files}")
|
|
122
|
-
assert len(image_true_files) == len(image_test_files)
|
|
123
|
-
for image_true, image_test in zip(image_true_files, image_test_files):
|
|
124
|
-
assert os.path.basename(image_true) == os.path.basename(
|
|
125
|
-
image_test
|
|
126
|
-
), f"image_true:{image_true} != image_test: {image_test}"
|
|
127
139
|
|
|
128
140
|
total_metric = 0.0
|
|
129
141
|
valid_files = 0
|
|
130
|
-
for image_true, image_test in
|
|
142
|
+
for image_true, image_test in tqdm(
|
|
143
|
+
zip(image_true_files, image_test_files),
|
|
144
|
+
total=len(image_true_files),
|
|
145
|
+
disable=DISABLE_VERBOSE,
|
|
146
|
+
):
|
|
131
147
|
metric = compute_file_func(image_true, image_test)
|
|
132
148
|
if metric != float("inf"):
|
|
133
149
|
total_metric += metric
|
|
@@ -142,30 +158,25 @@ def compute_dir_metric(
|
|
|
142
158
|
return None, None
|
|
143
159
|
|
|
144
160
|
|
|
145
|
-
def compute_video_metric(
|
|
161
|
+
def _fetch_video_frames(
|
|
146
162
|
video_true: str,
|
|
147
163
|
video_test: str,
|
|
148
|
-
|
|
149
|
-
) -> float:
|
|
150
|
-
"""
|
|
151
|
-
video_true = "video_true.mp4"
|
|
152
|
-
video_test = "video_test.mp4"
|
|
153
|
-
PSNR = compute_video_psnr(video_true, video_test)
|
|
154
|
-
"""
|
|
164
|
+
):
|
|
155
165
|
cap1 = cv2.VideoCapture(video_true)
|
|
156
166
|
cap2 = cv2.VideoCapture(video_test)
|
|
157
167
|
|
|
158
168
|
if not cap1.isOpened() or not cap2.isOpened():
|
|
159
169
|
logger.error("Could not open video files")
|
|
160
|
-
return
|
|
170
|
+
return [], [], 0
|
|
161
171
|
|
|
162
172
|
frame_count = min(
|
|
163
173
|
int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
|
|
164
174
|
int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)),
|
|
165
175
|
)
|
|
166
176
|
|
|
167
|
-
total_metric = 0.0
|
|
168
177
|
valid_frames = 0
|
|
178
|
+
video_true_frames = []
|
|
179
|
+
video_test_frames = []
|
|
169
180
|
|
|
170
181
|
logger.debug(f"Total frames: {frame_count}")
|
|
171
182
|
|
|
@@ -176,18 +187,115 @@ def compute_video_metric(
|
|
|
176
187
|
if not ret1 or not ret2:
|
|
177
188
|
break
|
|
178
189
|
|
|
179
|
-
|
|
190
|
+
video_true_frames.append(frame1)
|
|
191
|
+
video_test_frames.append(frame2)
|
|
180
192
|
|
|
181
|
-
|
|
182
|
-
total_metric += metric
|
|
183
|
-
valid_frames += 1
|
|
184
|
-
|
|
185
|
-
if valid_frames % 10 == 0:
|
|
186
|
-
logger.debug(f"Processed {valid_frames}/{frame_count} frames")
|
|
193
|
+
valid_frames += 1
|
|
187
194
|
|
|
188
195
|
cap1.release()
|
|
189
196
|
cap2.release()
|
|
190
197
|
|
|
198
|
+
if valid_frames <= 0:
|
|
199
|
+
return [], [], 0
|
|
200
|
+
|
|
201
|
+
return video_true_frames, video_test_frames, valid_frames
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def compute_video_metric(
|
|
205
|
+
video_true: str,
|
|
206
|
+
video_test: str,
|
|
207
|
+
compute_frame_func: callable = compute_psnr_file,
|
|
208
|
+
) -> float:
|
|
209
|
+
"""
|
|
210
|
+
video_true = "video_true.mp4"
|
|
211
|
+
video_test = "video_test.mp4"
|
|
212
|
+
PSNR = compute_video_psnr(video_true, video_test)
|
|
213
|
+
"""
|
|
214
|
+
if os.path.isfile(video_true) and os.path.isfile(video_test):
|
|
215
|
+
video_true_frames, video_test_frames, valid_frames = (
|
|
216
|
+
_fetch_video_frames(
|
|
217
|
+
video_true=video_true,
|
|
218
|
+
video_test=video_test,
|
|
219
|
+
)
|
|
220
|
+
)
|
|
221
|
+
elif os.path.isdir(video_true) and os.path.isdir(video_test):
|
|
222
|
+
# Glob videos
|
|
223
|
+
video_true_dir: pathlib.Path = pathlib.Path(video_true)
|
|
224
|
+
video_true_files = sorted(
|
|
225
|
+
[
|
|
226
|
+
file
|
|
227
|
+
for ext in _VIDEO_EXTENSIONS
|
|
228
|
+
for file in video_true_dir.rglob("*.{}".format(ext))
|
|
229
|
+
]
|
|
230
|
+
)
|
|
231
|
+
video_test_dir: pathlib.Path = pathlib.Path(video_test)
|
|
232
|
+
video_test_files = sorted(
|
|
233
|
+
[
|
|
234
|
+
file
|
|
235
|
+
for ext in _VIDEO_EXTENSIONS
|
|
236
|
+
for file in video_test_dir.rglob("*.{}".format(ext))
|
|
237
|
+
]
|
|
238
|
+
)
|
|
239
|
+
video_true_files = [file.as_posix() for file in video_true_files]
|
|
240
|
+
video_test_files = [file.as_posix() for file in video_test_files]
|
|
241
|
+
|
|
242
|
+
# select valid video files
|
|
243
|
+
video_true_files_selected = []
|
|
244
|
+
video_test_files_selected = []
|
|
245
|
+
for i in range(min(len(video_true_files), len(video_test_files))):
|
|
246
|
+
selected_video_true = video_true_files[i]
|
|
247
|
+
selected_video_test = video_test_files[i]
|
|
248
|
+
# Video pair must have the same basename
|
|
249
|
+
if os.path.basename(selected_video_test) == os.path.basename(
|
|
250
|
+
selected_video_true
|
|
251
|
+
):
|
|
252
|
+
video_true_files_selected.append(selected_video_true)
|
|
253
|
+
video_test_files_selected.append(selected_video_test)
|
|
254
|
+
|
|
255
|
+
video_true_files = video_true_files_selected.copy()
|
|
256
|
+
video_test_files = video_test_files_selected.copy()
|
|
257
|
+
if len(video_true_files) == 0:
|
|
258
|
+
logger.error(
|
|
259
|
+
"No valid Video pairs, please note that Video "
|
|
260
|
+
"pairs must have the same basename."
|
|
261
|
+
)
|
|
262
|
+
return None, None
|
|
263
|
+
logger.debug(f"video_true_files: {video_true_files}")
|
|
264
|
+
logger.debug(f"video_test_files: {video_test_files}")
|
|
265
|
+
|
|
266
|
+
# Fetch all frames
|
|
267
|
+
video_true_frames = []
|
|
268
|
+
video_test_frames = []
|
|
269
|
+
valid_frames = 0
|
|
270
|
+
|
|
271
|
+
for video_true_, video_test_ in zip(video_true_files, video_test_files):
|
|
272
|
+
video_true_frames_, video_test_frames_, valid_frames_ = (
|
|
273
|
+
_fetch_video_frames(
|
|
274
|
+
video_true=video_true_, video_test=video_test_
|
|
275
|
+
)
|
|
276
|
+
)
|
|
277
|
+
video_true_frames.extend(video_true_frames_)
|
|
278
|
+
video_test_frames.extend(video_test_frames_)
|
|
279
|
+
valid_frames += valid_frames_
|
|
280
|
+
else:
|
|
281
|
+
raise ValueError("video_true and video_test must be files or dirs.")
|
|
282
|
+
|
|
283
|
+
if valid_frames <= 0:
|
|
284
|
+
logger.debug("No valid frames to compare")
|
|
285
|
+
return None, None
|
|
286
|
+
|
|
287
|
+
total_metric = 0.0
|
|
288
|
+
valid_frames = 0 # reset
|
|
289
|
+
for frame1, frame2 in tqdm(
|
|
290
|
+
zip(video_true_frames, video_test_frames),
|
|
291
|
+
total=len(video_true_frames),
|
|
292
|
+
disable=DISABLE_VERBOSE,
|
|
293
|
+
):
|
|
294
|
+
metric = compute_frame_func(frame1, frame2)
|
|
295
|
+
if metric != float("inf"):
|
|
296
|
+
total_metric += metric
|
|
297
|
+
valid_frames += 1
|
|
298
|
+
|
|
191
299
|
if valid_frames > 0:
|
|
192
300
|
average_metric = total_metric / valid_frames
|
|
193
301
|
logger.debug(f"Average: {average_metric:.2f}")
|
|
@@ -265,14 +373,21 @@ def get_args():
|
|
|
265
373
|
"-v1",
|
|
266
374
|
type=str,
|
|
267
375
|
default=None,
|
|
268
|
-
help="Path to ground truth video",
|
|
376
|
+
help="Path to ground truth video or Dir to ground truth videos",
|
|
269
377
|
)
|
|
270
378
|
parser.add_argument(
|
|
271
379
|
"--video-test",
|
|
272
380
|
"-v2",
|
|
273
381
|
type=str,
|
|
274
382
|
default=None,
|
|
275
|
-
help="Path to predicted video",
|
|
383
|
+
help="Path to predicted video or Dir to predicted videos",
|
|
384
|
+
)
|
|
385
|
+
parser.add_argument(
|
|
386
|
+
"--enable-verbose",
|
|
387
|
+
"-verbose",
|
|
388
|
+
action="store_true",
|
|
389
|
+
default=False,
|
|
390
|
+
help="Show metrics progress verbose",
|
|
276
391
|
)
|
|
277
392
|
return parser.parse_args()
|
|
278
393
|
|
|
@@ -281,6 +396,11 @@ def entrypoint():
|
|
|
281
396
|
args = get_args()
|
|
282
397
|
logger.debug(args)
|
|
283
398
|
|
|
399
|
+
if args.enable_verbose:
|
|
400
|
+
global DISABLE_VERBOSE
|
|
401
|
+
set_metrics_verbose(True)
|
|
402
|
+
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
403
|
+
|
|
284
404
|
if args.img_true is not None and args.img_test is not None:
|
|
285
405
|
if any(
|
|
286
406
|
(
|
|
@@ -306,7 +426,7 @@ def entrypoint():
|
|
|
306
426
|
f"{args.img_true} vs {args.img_test}, Num: {n}, MSE: {img_mse}"
|
|
307
427
|
)
|
|
308
428
|
if args.metric == "fid" or args.metric == "all":
|
|
309
|
-
FID = FrechetInceptionDistance()
|
|
429
|
+
FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
|
|
310
430
|
img_fid, n = FID.compute_fid(args.img_true, args.img_test)
|
|
311
431
|
logger.info(
|
|
312
432
|
f"{args.img_true} vs {args.img_test}, Num: {n}, FID: {img_fid}"
|
|
@@ -319,36 +439,29 @@ def entrypoint():
|
|
|
319
439
|
)
|
|
320
440
|
):
|
|
321
441
|
return
|
|
442
|
+
# video_true and video_test can be files or dirs
|
|
322
443
|
if args.metric == "psnr" or args.metric == "all":
|
|
323
|
-
assert not os.path.isdir(args.video_true)
|
|
324
|
-
assert not os.path.isdir(args.video_test)
|
|
325
444
|
video_psnr, n = compute_video_psnr(args.video_true, args.video_test)
|
|
326
445
|
logger.info(
|
|
327
|
-
f"{args.video_true} vs {args.video_test},
|
|
446
|
+
f"{args.video_true} vs {args.video_test}, Frames: {n}, PSNR: {video_psnr}"
|
|
328
447
|
)
|
|
329
448
|
if args.metric == "ssim" or args.metric == "all":
|
|
330
|
-
assert not os.path.isdir(args.video_true)
|
|
331
|
-
assert not os.path.isdir(args.video_test)
|
|
332
449
|
video_ssim, n = compute_video_ssim(args.video_true, args.video_test)
|
|
333
450
|
logger.info(
|
|
334
|
-
f"{args.video_true} vs {args.video_test},
|
|
451
|
+
f"{args.video_true} vs {args.video_test}, Frames: {n}, SSIM: {video_ssim}"
|
|
335
452
|
)
|
|
336
453
|
if args.metric == "mse" or args.metric == "all":
|
|
337
|
-
assert not os.path.isdir(args.video_true)
|
|
338
|
-
assert not os.path.isdir(args.video_test)
|
|
339
454
|
video_mse, n = compute_video_mse(args.video_true, args.video_test)
|
|
340
455
|
logger.info(
|
|
341
|
-
f"{args.video_true} vs {args.video_test},
|
|
456
|
+
f"{args.video_true} vs {args.video_test}, Frames: {n}, MSE: {video_mse}"
|
|
342
457
|
)
|
|
343
458
|
if args.metric == "fid" or args.metric == "all":
|
|
344
|
-
|
|
345
|
-
assert not os.path.isdir(args.video_test)
|
|
346
|
-
FID = FrechetInceptionDistance()
|
|
459
|
+
FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
|
|
347
460
|
video_fid, n = FID.compute_video_fid(
|
|
348
461
|
args.video_true, args.video_test
|
|
349
462
|
)
|
|
350
463
|
logger.info(
|
|
351
|
-
f"{args.video_true} vs {args.video_test},
|
|
464
|
+
f"{args.video_true} vs {args.video_test}, Frames: {n}, FID: {video_fid}"
|
|
352
465
|
)
|
|
353
466
|
|
|
354
467
|
|
|
{cache_dit-0.2.6.dist-info → cache_dit-0.2.7.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cache_dit
|
|
3
|
-
Version: 0.2.6
|
|
3
|
+
Version: 0.2.7
|
|
4
4
|
Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
|
|
5
5
|
Author: DefTruth, vipshop.com, etc.
|
|
6
6
|
Maintainer: DefTruth, vipshop.com, etc
|
|
@@ -519,23 +519,21 @@ from cache_dit.metrics import FrechetInceptionDistance # FID
|
|
|
519
519
|
|
|
520
520
|
FID = FrechetInceptionDistance()
|
|
521
521
|
image_psnr, n = compute_psnr("true.png", "test.png") # Num: n
|
|
522
|
-
image_fid, n = FID.compute_fid("
|
|
523
|
-
video_psnr, n = compute_video_psnr("true.mp4", "test.mp4")
|
|
522
|
+
image_fid, n = FID.compute_fid("true_dir", "test_dir")
|
|
523
|
+
video_psnr, n = compute_video_psnr("true.mp4", "test.mp4") # Frames: n
|
|
524
524
|
```
|
|
525
525
|
|
|
526
526
|
Please check [test_metrics.py](./tests/test_metrics.py) for more details. Or, you can use `cache-dit-metrics-cli` tool. For examples:
|
|
527
527
|
|
|
528
528
|
```bash
|
|
529
529
|
cache-dit-metrics-cli -h # show usage
|
|
530
|
-
|
|
531
|
-
cache-dit-metrics-cli all
|
|
532
|
-
cache-dit-metrics-cli all
|
|
533
|
-
cache-dit-metrics-cli all
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
INFO 07-09 21:00:45 [metrics.py:305] BASELINE vs OPTIMIZED, Num: 1000, MSE: 12.287594770695606
|
|
538
|
-
INFO 07-09 21:01:04 [metrics.py:311] BASELINE vs OPTIMIZED, Num: 1000, FID: 5.983550108647762
|
|
530
|
+
# all: PSNR, FID, SSIM, MSE, ..., etc.
|
|
531
|
+
cache-dit-metrics-cli all -i1 true.png -i2 test.png # image
|
|
532
|
+
cache-dit-metrics-cli all -i1 true_dir -i2 test_dir # image dir
|
|
533
|
+
cache-dit-metrics-cli all -v1 true.mp4 -v2 test.mp4 # video
|
|
534
|
+
cache-dit-metrics-cli all -v1 true_dir -v2 test_dir # video dir
|
|
535
|
+
cache-dit-metrics-cli fid -i1 true_dir -i2 test_dir # FID
|
|
536
|
+
cache-dit-metrics-cli psnr -i1 true_dir -i2 test_dir # PSNR
|
|
539
537
|
```
|
|
540
538
|
|
|
541
539
|
## 👋Contribute
|
|
{cache_dit-0.2.6.dist-info → cache_dit-0.2.7.dist-info}/RECORD
CHANGED
|
@@ -1,17 +1,17 @@
|
|
|
1
1
|
cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
cache_dit/_version.py,sha256=
|
|
2
|
+
cache_dit/_version.py,sha256=Xk20v7uvkFqkpy9aLJzVngs1eKQn0FYUP2oyA1MEQUU,511
|
|
3
3
|
cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
|
|
4
4
|
cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
|
|
5
5
|
cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
|
|
6
6
|
cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
|
|
7
7
|
cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
8
|
cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
|
-
cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=
|
|
9
|
+
cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=7kMk6hvMDi-m2HP1qlj4p6qJhzZPjJol6IBneVGDs3E,71396
|
|
10
10
|
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
|
|
11
|
-
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=
|
|
12
|
-
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=
|
|
13
|
-
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=
|
|
14
|
-
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=
|
|
11
|
+
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=3xUjvDzor9AkBkDUc0N7kZqM86MIdajuigesnicNzXE,2260
|
|
12
|
+
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=cIsov6Pf0dRyddqkzTA2CU-jSDotof8LQr-HIoY9T9M,2615
|
|
13
|
+
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=SO4q39PQuQ5QVHy5Z-ubiKdstzvQPedONN2J5oiGUh0,9955
|
|
14
|
+
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=8W9m-WeEVE2ytYi9udKEA8Wtb0EnvP3eT2A1Tu-d29k,2252
|
|
15
15
|
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EREHM5E1wxnL-uRXRAEege4HXraRp1oD_r1Zx4CsiKk,2596
|
|
16
16
|
cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
17
17
|
cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=so1wGdb8W0ATwrjv7E5IEZLPcobybaY1HJa6hBYlOOQ,34698
|
|
@@ -33,13 +33,14 @@ cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k
|
|
|
33
33
|
cache_dit/compile/utils.py,sha256=KU60xc474Anbj7Y_FLRFmNxEjVYLLXkhbtCLXO7o_Tc,3699
|
|
34
34
|
cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
35
35
|
cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
36
|
-
cache_dit/metrics/__init__.py,sha256=
|
|
37
|
-
cache_dit/metrics/
|
|
36
|
+
cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
|
|
37
|
+
cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
|
|
38
|
+
cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
|
|
38
39
|
cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
|
|
39
|
-
cache_dit/metrics/metrics.py,sha256=
|
|
40
|
-
cache_dit-0.2.
|
|
41
|
-
cache_dit-0.2.
|
|
42
|
-
cache_dit-0.2.
|
|
43
|
-
cache_dit-0.2.
|
|
44
|
-
cache_dit-0.2.
|
|
45
|
-
cache_dit-0.2.
|
|
40
|
+
cache_dit/metrics/metrics.py,sha256=tzAtG_-fM1xPIBfRVFIBupvOWYzIO3xDq29Vy5rOBWc,14730
|
|
41
|
+
cache_dit-0.2.7.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
|
|
42
|
+
cache_dit-0.2.7.dist-info/METADATA,sha256=S0C1VGcXoWVjxmfX_755xttdoE9J0toSuHCZW9xaUBM,27608
|
|
43
|
+
cache_dit-0.2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
44
|
+
cache_dit-0.2.7.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
|
|
45
|
+
cache_dit-0.2.7.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
|
|
46
|
+
cache_dit-0.2.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|