cache-dit 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/_version.py +2 -2
- cache_dit/compile/utils.py +8 -8
- cache_dit/metrics/metrics.py +146 -57
- {cache_dit-0.2.10.dist-info → cache_dit-0.2.11.dist-info}/METADATA +3 -4
- {cache_dit-0.2.10.dist-info → cache_dit-0.2.11.dist-info}/RECORD +9 -9
- {cache_dit-0.2.10.dist-info → cache_dit-0.2.11.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.10.dist-info → cache_dit-0.2.11.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.10.dist-info → cache_dit-0.2.11.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.10.dist-info → cache_dit-0.2.11.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
cache_dit/compile/utils.py
CHANGED
|
@@ -39,6 +39,14 @@ def set_custom_compile_configs(
|
|
|
39
39
|
# https://github.com/pytorch/pytorch/issues/153791
|
|
40
40
|
torch._inductor.config.autotune_local_cache = False
|
|
41
41
|
|
|
42
|
+
if dist.is_initialized():
|
|
43
|
+
# Enable compute comm overlap
|
|
44
|
+
torch._inductor.config.reorder_for_compute_comm_overlap = True
|
|
45
|
+
# L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
|
|
46
|
+
torch._inductor.config.intra_node_bw = (
|
|
47
|
+
64 if "L20" in torch.cuda.get_device_name() else 300
|
|
48
|
+
)
|
|
49
|
+
|
|
42
50
|
FORCE_DISABLE_CUSTOM_COMPILE_CONFIG = (
|
|
43
51
|
os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0")
|
|
44
52
|
== "1"
|
|
@@ -51,14 +59,6 @@ def set_custom_compile_configs(
|
|
|
51
59
|
)
|
|
52
60
|
return
|
|
53
61
|
|
|
54
|
-
if dist.is_initialized():
|
|
55
|
-
# Enable compute comm overlap
|
|
56
|
-
torch._inductor.config.reorder_for_compute_comm_overlap = True
|
|
57
|
-
# L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
|
|
58
|
-
torch._inductor.config.intra_node_bw = (
|
|
59
|
-
64 if "L20" in torch.cuda.get_device_name() else 300
|
|
60
|
-
)
|
|
61
|
-
|
|
62
62
|
# Below are default settings for torch.compile, you can change
|
|
63
63
|
# them to your needs and test the performance
|
|
64
64
|
torch._inductor.config.max_fusion_size = 64
|
cache_dit/metrics/metrics.py
CHANGED
|
@@ -334,19 +334,22 @@ compute_video_mse = partial(
|
|
|
334
334
|
)
|
|
335
335
|
|
|
336
336
|
|
|
337
|
+
METRICS_CHOICES = [
|
|
338
|
+
"psnr",
|
|
339
|
+
"ssim",
|
|
340
|
+
"mse",
|
|
341
|
+
"fid",
|
|
342
|
+
"all",
|
|
343
|
+
]
|
|
344
|
+
|
|
345
|
+
|
|
337
346
|
# Entrypoints
|
|
338
347
|
def get_args():
|
|
348
|
+
global METRICS_CHOICES
|
|
339
349
|
parser = argparse.ArgumentParser(
|
|
340
350
|
description="CacheDiT's Metrics CLI",
|
|
341
351
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
342
352
|
)
|
|
343
|
-
METRICS_CHOICES = [
|
|
344
|
-
"psnr",
|
|
345
|
-
"ssim",
|
|
346
|
-
"mse",
|
|
347
|
-
"fid",
|
|
348
|
-
"all",
|
|
349
|
-
]
|
|
350
353
|
parser.add_argument(
|
|
351
354
|
"metrics",
|
|
352
355
|
type=str,
|
|
@@ -383,13 +386,6 @@ def get_args():
|
|
|
383
386
|
default=None,
|
|
384
387
|
help="Path to predicted video or Dir to predicted videos",
|
|
385
388
|
)
|
|
386
|
-
parser.add_argument(
|
|
387
|
-
"--enable-verbose",
|
|
388
|
-
"-verbose",
|
|
389
|
-
action="store_true",
|
|
390
|
-
default=False,
|
|
391
|
-
help="Show metrics progress verbose",
|
|
392
|
-
)
|
|
393
389
|
|
|
394
390
|
# Image 1 vs N pattern
|
|
395
391
|
parser.add_argument(
|
|
@@ -431,10 +427,29 @@ def get_args():
|
|
|
431
427
|
default=1,
|
|
432
428
|
help="Batch size for FID compute",
|
|
433
429
|
)
|
|
430
|
+
|
|
431
|
+
# Verbose
|
|
432
|
+
parser.add_argument(
|
|
433
|
+
"--enable-verbose",
|
|
434
|
+
"-verbose",
|
|
435
|
+
action="store_true",
|
|
436
|
+
default=False,
|
|
437
|
+
help="Show metrics progress verbose",
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
# Format output
|
|
441
|
+
parser.add_argument(
|
|
442
|
+
"--sort-output",
|
|
443
|
+
"-sort",
|
|
444
|
+
action="store_true",
|
|
445
|
+
default=False,
|
|
446
|
+
help="Sort the outupt metrics results",
|
|
447
|
+
)
|
|
434
448
|
return parser.parse_args()
|
|
435
449
|
|
|
436
450
|
|
|
437
451
|
def entrypoint():
|
|
452
|
+
global METRICS_CHOICES
|
|
438
453
|
args = get_args()
|
|
439
454
|
logger.debug(args)
|
|
440
455
|
|
|
@@ -449,16 +464,19 @@ def entrypoint():
|
|
|
449
464
|
batch_size=args.fid_batch_size,
|
|
450
465
|
)
|
|
451
466
|
|
|
467
|
+
METRICS_META: dict[str, float] = {}
|
|
468
|
+
|
|
452
469
|
# run one metric
|
|
453
470
|
def _run_metric(
|
|
454
|
-
|
|
471
|
+
metric: str,
|
|
455
472
|
img_true: str = None,
|
|
456
473
|
img_test: str = None,
|
|
457
474
|
video_true: str = None,
|
|
458
475
|
video_test: str = None,
|
|
459
476
|
) -> None:
|
|
460
477
|
nonlocal FID
|
|
461
|
-
|
|
478
|
+
nonlocal METRICS_META
|
|
479
|
+
metric = metric.lower()
|
|
462
480
|
if img_true is not None and img_test is not None:
|
|
463
481
|
if any(
|
|
464
482
|
(
|
|
@@ -470,30 +488,30 @@ def entrypoint():
|
|
|
470
488
|
# img_true and img_test can be files or dirs
|
|
471
489
|
img_true_info = os.path.basename(img_true)
|
|
472
490
|
img_test_info = os.path.basename(img_test)
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
491
|
+
|
|
492
|
+
def _logging_msg(value: float, n: int):
|
|
493
|
+
if value is None or n is None:
|
|
494
|
+
return
|
|
495
|
+
msg = (
|
|
476
496
|
f"{img_true_info} vs {img_test_info}, "
|
|
477
|
-
f"Num: {n},
|
|
497
|
+
f"Num: {n}, {metric.upper()}: {value:.5f}"
|
|
478
498
|
)
|
|
479
|
-
|
|
499
|
+
METRICS_META[msg] = value
|
|
500
|
+
logger.info(msg)
|
|
501
|
+
|
|
502
|
+
if metric == "psnr" or metric == "all":
|
|
503
|
+
img_psnr, n = compute_psnr(img_true, img_test)
|
|
504
|
+
_logging_msg(img_psnr, n)
|
|
505
|
+
if metric == "ssim" or metric == "all":
|
|
480
506
|
img_ssim, n = compute_ssim(img_true, img_test)
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
f"Num: {n}, SSIM: {img_ssim}"
|
|
484
|
-
)
|
|
485
|
-
if mertric == "mse" or mertric == "all":
|
|
507
|
+
_logging_msg(img_ssim, n)
|
|
508
|
+
if metric == "mse" or metric == "all":
|
|
486
509
|
img_mse, n = compute_mse(img_true, img_test)
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
f"Num: {n}, MSE: {img_mse}"
|
|
490
|
-
)
|
|
491
|
-
if mertric == "fid" or mertric == "all":
|
|
510
|
+
_logging_msg(img_mse, n)
|
|
511
|
+
if metric == "fid" or metric == "all":
|
|
492
512
|
img_fid, n = FID.compute_fid(img_true, img_test)
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
f"Num: {n}, FID: {img_fid}"
|
|
496
|
-
)
|
|
513
|
+
_logging_msg(img_fid, n)
|
|
514
|
+
|
|
497
515
|
if video_true is not None and video_test is not None:
|
|
498
516
|
if any(
|
|
499
517
|
(
|
|
@@ -502,33 +520,33 @@ def entrypoint():
|
|
|
502
520
|
)
|
|
503
521
|
):
|
|
504
522
|
return
|
|
523
|
+
|
|
505
524
|
# video_true and video_test can be files or dirs
|
|
506
525
|
video_true_info = os.path.basename(video_true)
|
|
507
526
|
video_test_info = os.path.basename(video_test)
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
527
|
+
|
|
528
|
+
def _logging_msg(value: float, n: int):
|
|
529
|
+
if value is None or n is None:
|
|
530
|
+
return
|
|
531
|
+
msg = (
|
|
511
532
|
f"{video_true_info} vs {video_test_info}, "
|
|
512
|
-
f"Frames: {n},
|
|
533
|
+
f"Frames: {n}, {metric.upper()}: {value:.5f}"
|
|
513
534
|
)
|
|
514
|
-
|
|
535
|
+
METRICS_META[msg] = value
|
|
536
|
+
logger.info(msg)
|
|
537
|
+
|
|
538
|
+
if metric == "psnr" or metric == "all":
|
|
539
|
+
video_psnr, n = compute_video_psnr(video_true, video_test)
|
|
540
|
+
_logging_msg(video_psnr, n)
|
|
541
|
+
if metric == "ssim" or metric == "all":
|
|
515
542
|
video_ssim, n = compute_video_ssim(video_true, video_test)
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
f"Frames: {n}, SSIM: {video_ssim}"
|
|
519
|
-
)
|
|
520
|
-
if mertric == "mse" or mertric == "all":
|
|
543
|
+
_logging_msg(video_ssim, n)
|
|
544
|
+
if metric == "mse" or metric == "all":
|
|
521
545
|
video_mse, n = compute_video_mse(video_true, video_test)
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
f"Frames: {n}, MSE: {video_mse}"
|
|
525
|
-
)
|
|
526
|
-
if mertric == "fid" or mertric == "all":
|
|
546
|
+
_logging_msg(video_mse, n)
|
|
547
|
+
if metric == "fid" or metric == "all":
|
|
527
548
|
video_fid, n = FID.compute_video_fid(video_true, video_test)
|
|
528
|
-
|
|
529
|
-
f"{video_true_info} vs {video_test_info}, "
|
|
530
|
-
f"Frames: {n}, FID: {video_fid}"
|
|
531
|
-
)
|
|
549
|
+
_logging_msg(video_fid, n)
|
|
532
550
|
|
|
533
551
|
# run selected metrics
|
|
534
552
|
if not DISABLE_VERBOSE:
|
|
@@ -574,7 +592,7 @@ def entrypoint():
|
|
|
574
592
|
for metric in args.metrics:
|
|
575
593
|
for img_test_dir in directories:
|
|
576
594
|
_run_metric(
|
|
577
|
-
|
|
595
|
+
metric=metric,
|
|
578
596
|
img_true=args.ref_img_dir,
|
|
579
597
|
img_test=img_test_dir,
|
|
580
598
|
)
|
|
@@ -619,7 +637,7 @@ def entrypoint():
|
|
|
619
637
|
for metric in args.metrics:
|
|
620
638
|
for video_test in video_source_selected:
|
|
621
639
|
_run_metric(
|
|
622
|
-
|
|
640
|
+
metric=metric,
|
|
623
641
|
video_true=args.ref_video,
|
|
624
642
|
video_test=video_test,
|
|
625
643
|
)
|
|
@@ -627,13 +645,84 @@ def entrypoint():
|
|
|
627
645
|
else:
|
|
628
646
|
for metric in args.metrics:
|
|
629
647
|
_run_metric(
|
|
630
|
-
|
|
648
|
+
metric=metric,
|
|
631
649
|
img_true=args.img_true,
|
|
632
650
|
img_test=args.img_test,
|
|
633
651
|
video_true=args.video_true,
|
|
634
652
|
video_test=args.video_test,
|
|
635
653
|
)
|
|
636
654
|
|
|
655
|
+
if args.sort_output:
|
|
656
|
+
|
|
657
|
+
def _parse_value(
|
|
658
|
+
text: str,
|
|
659
|
+
tag: str = "Num",
|
|
660
|
+
) -> float:
|
|
661
|
+
import re
|
|
662
|
+
|
|
663
|
+
pattern = re.compile(
|
|
664
|
+
rf"{re.escape(tag)}:\s*(\d+\.?\d*)", re.IGNORECASE
|
|
665
|
+
)
|
|
666
|
+
|
|
667
|
+
match = pattern.search(text)
|
|
668
|
+
|
|
669
|
+
if not match:
|
|
670
|
+
return None
|
|
671
|
+
|
|
672
|
+
if tag.lower() in METRICS_CHOICES:
|
|
673
|
+
return float(match.group(1))
|
|
674
|
+
return int(match.group(1))
|
|
675
|
+
|
|
676
|
+
def _format_item(
|
|
677
|
+
key: str,
|
|
678
|
+
metric: str,
|
|
679
|
+
value: float,
|
|
680
|
+
max_key_len: int,
|
|
681
|
+
):
|
|
682
|
+
# U1-Q0-C0-NONE vs U4-Q1-C1-NONE
|
|
683
|
+
header = key.split(",")[0].strip()
|
|
684
|
+
# Num / Frames
|
|
685
|
+
if n := _parse_value(key, "Num"):
|
|
686
|
+
print(
|
|
687
|
+
f"{header:<{max_key_len}} Num: {n} "
|
|
688
|
+
f"{metric.upper()}: {value:<.4f}"
|
|
689
|
+
)
|
|
690
|
+
elif n := _parse_value(key, "Frames"):
|
|
691
|
+
print(
|
|
692
|
+
f"{header:<{max_key_len}} Frames: {n} "
|
|
693
|
+
f"{metric.upper()}: {value:<.4f}"
|
|
694
|
+
)
|
|
695
|
+
else:
|
|
696
|
+
raise ValueError("Num or Frames can not be NoneType.")
|
|
697
|
+
|
|
698
|
+
for metric in args.metrics:
|
|
699
|
+
selected_items = {}
|
|
700
|
+
for key in METRICS_META.keys():
|
|
701
|
+
if metric.upper() in key or metric.lower() in key:
|
|
702
|
+
selected_items[key] = METRICS_META[key]
|
|
703
|
+
|
|
704
|
+
reverse = True if metric.lower() in ["psnr", "ssim"] else False
|
|
705
|
+
sorted_items = sorted(
|
|
706
|
+
selected_items.items(), key=lambda x: x[1], reverse=reverse
|
|
707
|
+
)
|
|
708
|
+
selected_keys = [
|
|
709
|
+
key.split(",")[0].strip() for key in selected_items.keys()
|
|
710
|
+
]
|
|
711
|
+
max_key_len = max(len(key) for key in selected_keys)
|
|
712
|
+
|
|
713
|
+
format_len = int(max_key_len * 1.5)
|
|
714
|
+
res_len = format_len - len(f"Summary: {metric.upper()}")
|
|
715
|
+
left_len = res_len // 2
|
|
716
|
+
right_len = res_len - left_len
|
|
717
|
+
print("-" * format_len)
|
|
718
|
+
print(
|
|
719
|
+
" " * left_len + f"Summary: {metric.upper()}" + " " * right_len
|
|
720
|
+
)
|
|
721
|
+
print("-" * format_len)
|
|
722
|
+
for key, value in sorted_items:
|
|
723
|
+
_format_item(key, metric, value, max_key_len)
|
|
724
|
+
print("-" * format_len)
|
|
725
|
+
|
|
637
726
|
|
|
638
727
|
if __name__ == "__main__":
|
|
639
728
|
entrypoint()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cache_dit
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.11
|
|
4
4
|
Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
|
|
5
5
|
Author: DefTruth, vipshop.com, etc.
|
|
6
6
|
Maintainer: DefTruth, vipshop.com, etc
|
|
@@ -63,9 +63,8 @@ Dynamic: requires-python
|
|
|
63
63
|
</div>
|
|
64
64
|
|
|
65
65
|
## 🔥News🔥
|
|
66
|
-
|
|
67
|
-
- [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of [huggingface/flux-fast](https://github.com/huggingface/flux-fast) that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20
|
|
68
|
-
|
|
66
|
+
- [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, please check [PR](https://github.com/huggingface/flux-fast/pull/13).
|
|
67
|
+
- [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
|
|
69
68
|
|
|
70
69
|
## 🤗 Introduction
|
|
71
70
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
cache_dit/_version.py,sha256=
|
|
2
|
+
cache_dit/_version.py,sha256=Y72g1mojWf0yRnnMW5zEUr6skXUSsqAdPjRJUrxXSYc,513
|
|
3
3
|
cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
|
|
4
4
|
cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
|
|
5
5
|
cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
|
|
@@ -31,17 +31,17 @@ cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sh
|
|
|
31
31
|
cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
|
|
32
32
|
cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
|
|
33
33
|
cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k,63
|
|
34
|
-
cache_dit/compile/utils.py,sha256=
|
|
34
|
+
cache_dit/compile/utils.py,sha256=N4A55_8uIbEd-S4xyJPcrdKceI2MGM9BTIhJE63jyL4,3786
|
|
35
35
|
cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
36
36
|
cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
37
37
|
cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
|
|
38
38
|
cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
|
|
39
39
|
cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
|
|
40
40
|
cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
|
|
41
|
-
cache_dit/metrics/metrics.py,sha256=
|
|
42
|
-
cache_dit-0.2.
|
|
43
|
-
cache_dit-0.2.
|
|
44
|
-
cache_dit-0.2.
|
|
45
|
-
cache_dit-0.2.
|
|
46
|
-
cache_dit-0.2.
|
|
47
|
-
cache_dit-0.2.
|
|
41
|
+
cache_dit/metrics/metrics.py,sha256=PAzyhJawos1UeMnHsxcu4edkwCSYMBmjDGRR_--I104,22410
|
|
42
|
+
cache_dit-0.2.11.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
|
|
43
|
+
cache_dit-0.2.11.dist-info/METADATA,sha256=HExb-ldgaYzZRSCy6Tg6syONjd0uDoB8_8AShvbfA-0,28213
|
|
44
|
+
cache_dit-0.2.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
45
|
+
cache_dit-0.2.11.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
|
|
46
|
+
cache_dit-0.2.11.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
|
|
47
|
+
cache_dit-0.2.11.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|