cache-dit: cache_dit-0.2.10-py3-none-any.whl → cache_dit-0.2.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.10'
21
- __version_tuple__ = version_tuple = (0, 2, 10)
20
+ __version__ = version = '0.2.11'
21
+ __version_tuple__ = version_tuple = (0, 2, 11)
cache_dit/compile/utils.py CHANGED
@@ -39,6 +39,14 @@ def set_custom_compile_configs(
39
39
  # https://github.com/pytorch/pytorch/issues/153791
40
40
  torch._inductor.config.autotune_local_cache = False
41
41
 
42
+ if dist.is_initialized():
43
+ # Enable compute comm overlap
44
+ torch._inductor.config.reorder_for_compute_comm_overlap = True
45
+ # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
46
+ torch._inductor.config.intra_node_bw = (
47
+ 64 if "L20" in torch.cuda.get_device_name() else 300
48
+ )
49
+
42
50
  FORCE_DISABLE_CUSTOM_COMPILE_CONFIG = (
43
51
  os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0")
44
52
  == "1"
@@ -51,14 +59,6 @@ def set_custom_compile_configs(
51
59
  )
52
60
  return
53
61
 
54
- if dist.is_initialized():
55
- # Enable compute comm overlap
56
- torch._inductor.config.reorder_for_compute_comm_overlap = True
57
- # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
58
- torch._inductor.config.intra_node_bw = (
59
- 64 if "L20" in torch.cuda.get_device_name() else 300
60
- )
61
-
62
62
  # Below are default settings for torch.compile, you can change
63
63
  # them to your needs and test the performance
64
64
  torch._inductor.config.max_fusion_size = 64
cache_dit/metrics/metrics.py CHANGED
@@ -334,19 +334,22 @@ compute_video_mse = partial(
334
334
  )
335
335
 
336
336
 
337
+ METRICS_CHOICES = [
338
+ "psnr",
339
+ "ssim",
340
+ "mse",
341
+ "fid",
342
+ "all",
343
+ ]
344
+
345
+
337
346
  # Entrypoints
338
347
  def get_args():
348
+ global METRICS_CHOICES
339
349
  parser = argparse.ArgumentParser(
340
350
  description="CacheDiT's Metrics CLI",
341
351
  formatter_class=argparse.ArgumentDefaultsHelpFormatter,
342
352
  )
343
- METRICS_CHOICES = [
344
- "psnr",
345
- "ssim",
346
- "mse",
347
- "fid",
348
- "all",
349
- ]
350
353
  parser.add_argument(
351
354
  "metrics",
352
355
  type=str,
@@ -383,13 +386,6 @@ def get_args():
383
386
  default=None,
384
387
  help="Path to predicted video or Dir to predicted videos",
385
388
  )
386
- parser.add_argument(
387
- "--enable-verbose",
388
- "-verbose",
389
- action="store_true",
390
- default=False,
391
- help="Show metrics progress verbose",
392
- )
393
389
 
394
390
  # Image 1 vs N pattern
395
391
  parser.add_argument(
@@ -431,10 +427,29 @@ def get_args():
431
427
  default=1,
432
428
  help="Batch size for FID compute",
433
429
  )
430
+
431
+ # Verbose
432
+ parser.add_argument(
433
+ "--enable-verbose",
434
+ "-verbose",
435
+ action="store_true",
436
+ default=False,
437
+ help="Show metrics progress verbose",
438
+ )
439
+
440
+ # Format output
441
+ parser.add_argument(
442
+ "--sort-output",
443
+ "-sort",
444
+ action="store_true",
445
+ default=False,
446
+ help="Sort the outupt metrics results",
447
+ )
434
448
  return parser.parse_args()
435
449
 
436
450
 
437
451
  def entrypoint():
452
+ global METRICS_CHOICES
438
453
  args = get_args()
439
454
  logger.debug(args)
440
455
 
@@ -449,16 +464,19 @@ def entrypoint():
449
464
  batch_size=args.fid_batch_size,
450
465
  )
451
466
 
467
+ METRICS_META: dict[str, float] = {}
468
+
452
469
  # run one metric
453
470
  def _run_metric(
454
- mertric: str,
471
+ metric: str,
455
472
  img_true: str = None,
456
473
  img_test: str = None,
457
474
  video_true: str = None,
458
475
  video_test: str = None,
459
476
  ) -> None:
460
477
  nonlocal FID
461
- mertric = mertric.lower()
478
+ nonlocal METRICS_META
479
+ metric = metric.lower()
462
480
  if img_true is not None and img_test is not None:
463
481
  if any(
464
482
  (
@@ -470,30 +488,30 @@ def entrypoint():
470
488
  # img_true and img_test can be files or dirs
471
489
  img_true_info = os.path.basename(img_true)
472
490
  img_test_info = os.path.basename(img_test)
473
- if mertric == "psnr" or mertric == "all":
474
- img_psnr, n = compute_psnr(img_true, img_test)
475
- logger.info(
491
+
492
+ def _logging_msg(value: float, n: int):
493
+ if value is None or n is None:
494
+ return
495
+ msg = (
476
496
  f"{img_true_info} vs {img_test_info}, "
477
- f"Num: {n}, PSNR: {img_psnr}"
497
+ f"Num: {n}, {metric.upper()}: {value:.5f}"
478
498
  )
479
- if mertric == "ssim" or mertric == "all":
499
+ METRICS_META[msg] = value
500
+ logger.info(msg)
501
+
502
+ if metric == "psnr" or metric == "all":
503
+ img_psnr, n = compute_psnr(img_true, img_test)
504
+ _logging_msg(img_psnr, n)
505
+ if metric == "ssim" or metric == "all":
480
506
  img_ssim, n = compute_ssim(img_true, img_test)
481
- logger.info(
482
- f"{img_true_info} vs {img_test_info}, "
483
- f"Num: {n}, SSIM: {img_ssim}"
484
- )
485
- if mertric == "mse" or mertric == "all":
507
+ _logging_msg(img_ssim, n)
508
+ if metric == "mse" or metric == "all":
486
509
  img_mse, n = compute_mse(img_true, img_test)
487
- logger.info(
488
- f"{img_true_info} vs {img_test_info}, "
489
- f"Num: {n}, MSE: {img_mse}"
490
- )
491
- if mertric == "fid" or mertric == "all":
510
+ _logging_msg(img_mse, n)
511
+ if metric == "fid" or metric == "all":
492
512
  img_fid, n = FID.compute_fid(img_true, img_test)
493
- logger.info(
494
- f"{img_true_info} vs {img_test_info}, "
495
- f"Num: {n}, FID: {img_fid}"
496
- )
513
+ _logging_msg(img_fid, n)
514
+
497
515
  if video_true is not None and video_test is not None:
498
516
  if any(
499
517
  (
@@ -502,33 +520,33 @@ def entrypoint():
502
520
  )
503
521
  ):
504
522
  return
523
+
505
524
  # video_true and video_test can be files or dirs
506
525
  video_true_info = os.path.basename(video_true)
507
526
  video_test_info = os.path.basename(video_test)
508
- if mertric == "psnr" or mertric == "all":
509
- video_psnr, n = compute_video_psnr(video_true, video_test)
510
- logger.info(
527
+
528
+ def _logging_msg(value: float, n: int):
529
+ if value is None or n is None:
530
+ return
531
+ msg = (
511
532
  f"{video_true_info} vs {video_test_info}, "
512
- f"Frames: {n}, PSNR: {video_psnr}"
533
+ f"Frames: {n}, {metric.upper()}: {value:.5f}"
513
534
  )
514
- if mertric == "ssim" or mertric == "all":
535
+ METRICS_META[msg] = value
536
+ logger.info(msg)
537
+
538
+ if metric == "psnr" or metric == "all":
539
+ video_psnr, n = compute_video_psnr(video_true, video_test)
540
+ _logging_msg(video_psnr, n)
541
+ if metric == "ssim" or metric == "all":
515
542
  video_ssim, n = compute_video_ssim(video_true, video_test)
516
- logger.info(
517
- f"{video_true_info} vs {video_test_info}, "
518
- f"Frames: {n}, SSIM: {video_ssim}"
519
- )
520
- if mertric == "mse" or mertric == "all":
543
+ _logging_msg(video_ssim, n)
544
+ if metric == "mse" or metric == "all":
521
545
  video_mse, n = compute_video_mse(video_true, video_test)
522
- logger.info(
523
- f"{video_true_info} vs {video_test_info}, "
524
- f"Frames: {n}, MSE: {video_mse}"
525
- )
526
- if mertric == "fid" or mertric == "all":
546
+ _logging_msg(video_mse, n)
547
+ if metric == "fid" or metric == "all":
527
548
  video_fid, n = FID.compute_video_fid(video_true, video_test)
528
- logger.info(
529
- f"{video_true_info} vs {video_test_info}, "
530
- f"Frames: {n}, FID: {video_fid}"
531
- )
549
+ _logging_msg(video_fid, n)
532
550
 
533
551
  # run selected metrics
534
552
  if not DISABLE_VERBOSE:
@@ -574,7 +592,7 @@ def entrypoint():
574
592
  for metric in args.metrics:
575
593
  for img_test_dir in directories:
576
594
  _run_metric(
577
- mertric=metric,
595
+ metric=metric,
578
596
  img_true=args.ref_img_dir,
579
597
  img_test=img_test_dir,
580
598
  )
@@ -619,7 +637,7 @@ def entrypoint():
619
637
  for metric in args.metrics:
620
638
  for video_test in video_source_selected:
621
639
  _run_metric(
622
- mertric=metric,
640
+ metric=metric,
623
641
  video_true=args.ref_video,
624
642
  video_test=video_test,
625
643
  )
@@ -627,13 +645,84 @@ def entrypoint():
627
645
  else:
628
646
  for metric in args.metrics:
629
647
  _run_metric(
630
- mertric=metric,
648
+ metric=metric,
631
649
  img_true=args.img_true,
632
650
  img_test=args.img_test,
633
651
  video_true=args.video_true,
634
652
  video_test=args.video_test,
635
653
  )
636
654
 
655
+ if args.sort_output:
656
+
657
+ def _parse_value(
658
+ text: str,
659
+ tag: str = "Num",
660
+ ) -> float:
661
+ import re
662
+
663
+ pattern = re.compile(
664
+ rf"{re.escape(tag)}:\s*(\d+\.?\d*)", re.IGNORECASE
665
+ )
666
+
667
+ match = pattern.search(text)
668
+
669
+ if not match:
670
+ return None
671
+
672
+ if tag.lower() in METRICS_CHOICES:
673
+ return float(match.group(1))
674
+ return int(match.group(1))
675
+
676
+ def _format_item(
677
+ key: str,
678
+ metric: str,
679
+ value: float,
680
+ max_key_len: int,
681
+ ):
682
+ # U1-Q0-C0-NONE vs U4-Q1-C1-NONE
683
+ header = key.split(",")[0].strip()
684
+ # Num / Frames
685
+ if n := _parse_value(key, "Num"):
686
+ print(
687
+ f"{header:<{max_key_len}} Num: {n} "
688
+ f"{metric.upper()}: {value:<.4f}"
689
+ )
690
+ elif n := _parse_value(key, "Frames"):
691
+ print(
692
+ f"{header:<{max_key_len}} Frames: {n} "
693
+ f"{metric.upper()}: {value:<.4f}"
694
+ )
695
+ else:
696
+ raise ValueError("Num or Frames can not be NoneType.")
697
+
698
+ for metric in args.metrics:
699
+ selected_items = {}
700
+ for key in METRICS_META.keys():
701
+ if metric.upper() in key or metric.lower() in key:
702
+ selected_items[key] = METRICS_META[key]
703
+
704
+ reverse = True if metric.lower() in ["psnr", "ssim"] else False
705
+ sorted_items = sorted(
706
+ selected_items.items(), key=lambda x: x[1], reverse=reverse
707
+ )
708
+ selected_keys = [
709
+ key.split(",")[0].strip() for key in selected_items.keys()
710
+ ]
711
+ max_key_len = max(len(key) for key in selected_keys)
712
+
713
+ format_len = int(max_key_len * 1.5)
714
+ res_len = format_len - len(f"Summary: {metric.upper()}")
715
+ left_len = res_len // 2
716
+ right_len = res_len - left_len
717
+ print("-" * format_len)
718
+ print(
719
+ " " * left_len + f"Summary: {metric.upper()}" + " " * right_len
720
+ )
721
+ print("-" * format_len)
722
+ for key, value in sorted_items:
723
+ _format_item(key, metric, value, max_key_len)
724
+ print("-" * format_len)
725
+
637
726
 
638
727
  if __name__ == "__main__":
639
728
  entrypoint()
cache_dit-0.2.11.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.10
3
+ Version: 0.2.11
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -63,9 +63,8 @@ Dynamic: requires-python
63
63
  </div>
64
64
 
65
65
  ## 🔥News🔥
66
-
67
- - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of [huggingface/flux-fast](https://github.com/huggingface/flux-fast) that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20 while still maintaining **high precision**.
68
-
66
+ - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, please check [PR](https://github.com/huggingface/flux-fast/pull/13).
67
+ - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
69
68
 
70
69
  ## 🤗 Introduction
71
70
 
cache_dit-0.2.11.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
1
1
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=fMOyoyXAggjNTgl2YJ-8HW1bnjjDPiNACUsDoNufScI,513
2
+ cache_dit/_version.py,sha256=Y72g1mojWf0yRnnMW5zEUr6skXUSsqAdPjRJUrxXSYc,513
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
5
  cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
@@ -31,17 +31,17 @@ cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sh
31
31
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
32
32
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
33
33
  cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k,63
34
- cache_dit/compile/utils.py,sha256=OTvkwcezSrApZ2M1IMkYtkEmFbkfpTknhHMgoBApd6U,3786
34
+ cache_dit/compile/utils.py,sha256=N4A55_8uIbEd-S4xyJPcrdKceI2MGM9BTIhJE63jyL4,3786
35
35
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
36
  cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
37
  cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
38
38
  cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
39
39
  cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
40
40
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
41
- cache_dit/metrics/metrics.py,sha256=O9a8qV6deQDWEoez7UZ_aqDLQ9rJXAJUMHGnJM7RUMs,19927
42
- cache_dit-0.2.10.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
- cache_dit-0.2.10.dist-info/METADATA,sha256=f5PF-lhexcLdB2HmBWETBEyg011eZOc99tlVI1lozYA,28002
44
- cache_dit-0.2.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
- cache_dit-0.2.10.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
- cache_dit-0.2.10.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
- cache_dit-0.2.10.dist-info/RECORD,,
41
+ cache_dit/metrics/metrics.py,sha256=PAzyhJawos1UeMnHsxcu4edkwCSYMBmjDGRR_--I104,22410
42
+ cache_dit-0.2.11.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
+ cache_dit-0.2.11.dist-info/METADATA,sha256=HExb-ldgaYzZRSCy6Tg6syONjd0uDoB8_8AShvbfA-0,28213
44
+ cache_dit-0.2.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
+ cache_dit-0.2.11.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
+ cache_dit-0.2.11.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
+ cache_dit-0.2.11.dist-info/RECORD,,