cache-dit 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.10'
21
- __version_tuple__ = version_tuple = (0, 2, 10)
20
+ __version__ = version = '0.2.12'
21
+ __version_tuple__ = version_tuple = (0, 2, 12)
@@ -39,6 +39,14 @@ def set_custom_compile_configs(
39
39
  # https://github.com/pytorch/pytorch/issues/153791
40
40
  torch._inductor.config.autotune_local_cache = False
41
41
 
42
+ if dist.is_initialized():
43
+ # Enable compute comm overlap
44
+ torch._inductor.config.reorder_for_compute_comm_overlap = True
45
+ # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
46
+ torch._inductor.config.intra_node_bw = (
47
+ 64 if "L20" in torch.cuda.get_device_name() else 300
48
+ )
49
+
42
50
  FORCE_DISABLE_CUSTOM_COMPILE_CONFIG = (
43
51
  os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0")
44
52
  == "1"
@@ -51,14 +59,6 @@ def set_custom_compile_configs(
51
59
  )
52
60
  return
53
61
 
54
- if dist.is_initialized():
55
- # Enable compute comm overlap
56
- torch._inductor.config.reorder_for_compute_comm_overlap = True
57
- # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
58
- torch._inductor.config.intra_node_bw = (
59
- 64 if "L20" in torch.cuda.get_device_name() else 300
60
- )
61
-
62
62
  # Below are default settings for torch.compile, you can change
63
63
  # them to your needs and test the performance
64
64
  torch._inductor.config.max_fusion_size = 64
@@ -334,19 +334,22 @@ compute_video_mse = partial(
334
334
  )
335
335
 
336
336
 
337
+ METRICS_CHOICES = [
338
+ "psnr",
339
+ "ssim",
340
+ "mse",
341
+ "fid",
342
+ "all",
343
+ ]
344
+
345
+
337
346
  # Entrypoints
338
347
  def get_args():
348
+ global METRICS_CHOICES
339
349
  parser = argparse.ArgumentParser(
340
350
  description="CacheDiT's Metrics CLI",
341
351
  formatter_class=argparse.ArgumentDefaultsHelpFormatter,
342
352
  )
343
- METRICS_CHOICES = [
344
- "psnr",
345
- "ssim",
346
- "mse",
347
- "fid",
348
- "all",
349
- ]
350
353
  parser.add_argument(
351
354
  "metrics",
352
355
  type=str,
@@ -383,13 +386,6 @@ def get_args():
383
386
  default=None,
384
387
  help="Path to predicted video or Dir to predicted videos",
385
388
  )
386
- parser.add_argument(
387
- "--enable-verbose",
388
- "-verbose",
389
- action="store_true",
390
- default=False,
391
- help="Show metrics progress verbose",
392
- )
393
389
 
394
390
  # Image 1 vs N pattern
395
391
  parser.add_argument(
@@ -431,10 +427,52 @@ def get_args():
431
427
  default=1,
432
428
  help="Batch size for FID compute",
433
429
  )
430
+
431
+ # Verbose
432
+ parser.add_argument(
433
+ "--enable-verbose",
434
+ "-verbose",
435
+ action="store_true",
436
+ default=False,
437
+ help="Show metrics progress verbose",
438
+ )
439
+
440
+ # Format output
441
+ parser.add_argument(
442
+ "--summary",
443
+ "-s",
444
+ action="store_true",
445
+ default=False,
446
+ help="Summary the outupt metrics results",
447
+ )
448
+
449
+ # Addtional perf log
450
+ parser.add_argument(
451
+ "--perf-log",
452
+ "-plog",
453
+ type=str,
454
+ default=None,
455
+ help="Path to addtional perf log",
456
+ )
457
+ parser.add_argument(
458
+ "--perf-tag",
459
+ "-ptag",
460
+ type=str,
461
+ default=None,
462
+ help="Tag to parse perf time from perf log",
463
+ )
464
+ parser.add_argument(
465
+ "--extra-perf-tags",
466
+ "-extra-ptags",
467
+ nargs="+",
468
+ default=[],
469
+ help="Extra tags to parse perf time from perf log",
470
+ )
434
471
  return parser.parse_args()
435
472
 
436
473
 
437
474
  def entrypoint():
475
+ global METRICS_CHOICES
438
476
  args = get_args()
439
477
  logger.debug(args)
440
478
 
@@ -449,16 +487,19 @@ def entrypoint():
449
487
  batch_size=args.fid_batch_size,
450
488
  )
451
489
 
490
+ METRICS_META: dict[str, float] = {}
491
+
452
492
  # run one metric
453
493
  def _run_metric(
454
- mertric: str,
494
+ metric: str,
455
495
  img_true: str = None,
456
496
  img_test: str = None,
457
497
  video_true: str = None,
458
498
  video_test: str = None,
459
499
  ) -> None:
460
500
  nonlocal FID
461
- mertric = mertric.lower()
501
+ nonlocal METRICS_META
502
+ metric = metric.lower()
462
503
  if img_true is not None and img_test is not None:
463
504
  if any(
464
505
  (
@@ -470,30 +511,30 @@ def entrypoint():
470
511
  # img_true and img_test can be files or dirs
471
512
  img_true_info = os.path.basename(img_true)
472
513
  img_test_info = os.path.basename(img_test)
473
- if mertric == "psnr" or mertric == "all":
474
- img_psnr, n = compute_psnr(img_true, img_test)
475
- logger.info(
514
+
515
+ def _logging_msg(value: float, name, n: int):
516
+ if value is None or n is None:
517
+ return
518
+ msg = (
476
519
  f"{img_true_info} vs {img_test_info}, "
477
- f"Num: {n}, PSNR: {img_psnr}"
520
+ f"Num: {n}, {name.upper()}: {value:.5f}"
478
521
  )
479
- if mertric == "ssim" or mertric == "all":
522
+ METRICS_META[msg] = value
523
+ logger.info(msg)
524
+
525
+ if metric == "psnr" or metric == "all":
526
+ img_psnr, n = compute_psnr(img_true, img_test)
527
+ _logging_msg(img_psnr, "psnr", n)
528
+ if metric == "ssim" or metric == "all":
480
529
  img_ssim, n = compute_ssim(img_true, img_test)
481
- logger.info(
482
- f"{img_true_info} vs {img_test_info}, "
483
- f"Num: {n}, SSIM: {img_ssim}"
484
- )
485
- if mertric == "mse" or mertric == "all":
530
+ _logging_msg(img_ssim, "ssim", n)
531
+ if metric == "mse" or metric == "all":
486
532
  img_mse, n = compute_mse(img_true, img_test)
487
- logger.info(
488
- f"{img_true_info} vs {img_test_info}, "
489
- f"Num: {n}, MSE: {img_mse}"
490
- )
491
- if mertric == "fid" or mertric == "all":
533
+ _logging_msg(img_mse, "mse", n)
534
+ if metric == "fid" or metric == "all":
492
535
  img_fid, n = FID.compute_fid(img_true, img_test)
493
- logger.info(
494
- f"{img_true_info} vs {img_test_info}, "
495
- f"Num: {n}, FID: {img_fid}"
496
- )
536
+ _logging_msg(img_fid, "fid", n)
537
+
497
538
  if video_true is not None and video_test is not None:
498
539
  if any(
499
540
  (
@@ -502,33 +543,33 @@ def entrypoint():
502
543
  )
503
544
  ):
504
545
  return
546
+
505
547
  # video_true and video_test can be files or dirs
506
548
  video_true_info = os.path.basename(video_true)
507
549
  video_test_info = os.path.basename(video_test)
508
- if mertric == "psnr" or mertric == "all":
509
- video_psnr, n = compute_video_psnr(video_true, video_test)
510
- logger.info(
550
+
551
+ def _logging_msg(value: float, name, n: int):
552
+ if value is None or n is None:
553
+ return
554
+ msg = (
511
555
  f"{video_true_info} vs {video_test_info}, "
512
- f"Frames: {n}, PSNR: {video_psnr}"
556
+ f"Frames: {n}, {name.upper()}: {value:.5f}"
513
557
  )
514
- if mertric == "ssim" or mertric == "all":
558
+ METRICS_META[msg] = value
559
+ logger.info(msg)
560
+
561
+ if metric == "psnr" or metric == "all":
562
+ video_psnr, n = compute_video_psnr(video_true, video_test)
563
+ _logging_msg(video_psnr, "psnr", n)
564
+ if metric == "ssim" or metric == "all":
515
565
  video_ssim, n = compute_video_ssim(video_true, video_test)
516
- logger.info(
517
- f"{video_true_info} vs {video_test_info}, "
518
- f"Frames: {n}, SSIM: {video_ssim}"
519
- )
520
- if mertric == "mse" or mertric == "all":
566
+ _logging_msg(video_ssim, "ssim", n)
567
+ if metric == "mse" or metric == "all":
521
568
  video_mse, n = compute_video_mse(video_true, video_test)
522
- logger.info(
523
- f"{video_true_info} vs {video_test_info}, "
524
- f"Frames: {n}, MSE: {video_mse}"
525
- )
526
- if mertric == "fid" or mertric == "all":
569
+ _logging_msg(video_mse, "mse", n)
570
+ if metric == "fid" or metric == "all":
527
571
  video_fid, n = FID.compute_video_fid(video_true, video_test)
528
- logger.info(
529
- f"{video_true_info} vs {video_test_info}, "
530
- f"Frames: {n}, FID: {video_fid}"
531
- )
572
+ _logging_msg(video_fid, "fid", n)
532
573
 
533
574
  # run selected metrics
534
575
  if not DISABLE_VERBOSE:
@@ -574,7 +615,7 @@ def entrypoint():
574
615
  for metric in args.metrics:
575
616
  for img_test_dir in directories:
576
617
  _run_metric(
577
- mertric=metric,
618
+ metric=metric,
578
619
  img_true=args.ref_img_dir,
579
620
  img_test=img_test_dir,
580
621
  )
@@ -619,7 +660,7 @@ def entrypoint():
619
660
  for metric in args.metrics:
620
661
  for video_test in video_source_selected:
621
662
  _run_metric(
622
- mertric=metric,
663
+ metric=metric,
623
664
  video_true=args.ref_video,
624
665
  video_test=video_test,
625
666
  )
@@ -627,13 +668,169 @@ def entrypoint():
627
668
  else:
628
669
  for metric in args.metrics:
629
670
  _run_metric(
630
- mertric=metric,
671
+ metric=metric,
631
672
  img_true=args.img_true,
632
673
  img_test=args.img_test,
633
674
  video_true=args.video_true,
634
675
  video_test=args.video_test,
635
676
  )
636
677
 
678
+ if args.summary:
679
+
680
+ def _fetch_perf():
681
+ if args.perf_log is None or args.perf_tag is None:
682
+ return []
683
+ if not os.path.exists(args.perf_log):
684
+ return []
685
+ perf_texts = []
686
+ with open(args.perf_log, "r") as file:
687
+ perf_lines = file.readlines()
688
+ for line in perf_lines:
689
+ line = line.strip()
690
+ if args.perf_tag.lower() in line.lower():
691
+ if len(args.extra_perf_tags) == 0:
692
+ perf_texts.append(line)
693
+ else:
694
+ has_all_extra_tag = True
695
+ for ext_tag in args.extra_perf_tags:
696
+ if ext_tag.lower() not in line.lower():
697
+ has_all_extra_tag = False
698
+ break
699
+ if has_all_extra_tag:
700
+ perf_texts.append(line)
701
+ return perf_texts
702
+
703
+ PERF_TEXTS: list[str] = _fetch_perf()
704
+
705
+ def _parse_value(
706
+ text: str,
707
+ tag: str = "Num",
708
+ ) -> float | None:
709
+ import re
710
+
711
+ escaped_tag = re.escape(tag)
712
+ processed_tag = escaped_tag.replace(r"\ ", r"\s+")
713
+
714
+ pattern = re.compile(
715
+ rf"{processed_tag}:\s*(\d+\.?\d*)\D*", re.IGNORECASE
716
+ )
717
+
718
+ match = pattern.search(text)
719
+
720
+ if not match:
721
+ return None
722
+
723
+ value_str = match.group(1)
724
+ try:
725
+ if tag.lower() in METRICS_CHOICES:
726
+ return float(value_str)
727
+ if args.perf_tag is not None:
728
+ if tag.lower() == args.perf_tag.lower():
729
+ return float(value_str)
730
+ return int(value_str)
731
+ except ValueError:
732
+ return None
733
+
734
+ def _parse_perf(
735
+ compare_tag: str,
736
+ ) -> float | None:
737
+ nonlocal PERF_TEXTS
738
+ perf_times = []
739
+ for line in PERF_TEXTS:
740
+ if compare_tag in line:
741
+ perf_time = _parse_value(line, args.perf_tag)
742
+ if perf_time is not None:
743
+ perf_times.append(perf_time)
744
+ if len(perf_times) == 0:
745
+ return None
746
+ return sum(perf_times) / len(perf_times)
747
+
748
+ def _format_item(
749
+ key: str,
750
+ metric: str,
751
+ value: float,
752
+ max_key_len: int,
753
+ ):
754
+ nonlocal PERF_TEXTS
755
+ # U1-Q0-C0-NONE vs U4-Q1-C1-NONE
756
+ header = key.split(",")[0].strip()
757
+ compare_tag = header.split("vs")[1].strip() # U4-Q1-C1-NONE
758
+ has_perf_texts = len(PERF_TEXTS) > 0
759
+ format_str = ""
760
+ # Num / Frames
761
+ if n := _parse_value(key, "Num"):
762
+ if not has_perf_texts:
763
+ format_str = (
764
+ f"{header:<{max_key_len}} Num: {n} "
765
+ f"{metric.upper()}: {value:<7.4f}"
766
+ )
767
+ else:
768
+ perf_time = _parse_perf(compare_tag)
769
+ perf_time = f"{perf_time:<.2f}" if perf_time else None
770
+ format_str = (
771
+ f"{header:<{max_key_len}} Num: {n} "
772
+ f"{metric.upper()}: {value:<7.4f} "
773
+ f"Perf: {perf_time}"
774
+ )
775
+ elif n := _parse_value(key, "Frames"):
776
+ if not has_perf_texts:
777
+ format_str = (
778
+ f"{header:<{max_key_len}} Frames: {n} "
779
+ f"{metric.upper()}: {value:<7.4f}"
780
+ )
781
+ else:
782
+ perf_time = _parse_perf(compare_tag)
783
+ perf_time = f"{perf_time:<.2f}" if perf_time else None
784
+ format_str = (
785
+ f"{header:<{max_key_len}} Frames: {n} "
786
+ f"{metric.upper()}: {value:<7.4f} "
787
+ f"Perf: {perf_time}"
788
+ )
789
+ else:
790
+ raise ValueError("Num or Frames can not be NoneType.")
791
+
792
+ return format_str
793
+
794
+ selected_metrics = args.metrics
795
+ if "all" in selected_metrics:
796
+ selected_metrics = METRICS_CHOICES.copy()
797
+ selected_metrics.remove("all")
798
+
799
+ for metric in selected_metrics:
800
+ selected_items = {}
801
+ for key in METRICS_META.keys():
802
+ if metric.upper() in key or metric.lower() in key:
803
+ selected_items[key] = METRICS_META[key]
804
+
805
+ reverse = True if metric.lower() in ["psnr", "ssim"] else False
806
+ sorted_items = sorted(
807
+ selected_items.items(), key=lambda x: x[1], reverse=reverse
808
+ )
809
+ selected_keys = [
810
+ key.split(",")[0].strip() for key in selected_items.keys()
811
+ ]
812
+ max_key_len = max(len(key) for key in selected_keys)
813
+
814
+ format_strs = []
815
+ for key, value in sorted_items:
816
+ format_strs.append(
817
+ _format_item(key, metric, value, max_key_len)
818
+ )
819
+
820
+ format_len = max(len(format_str) for format_str in format_strs)
821
+
822
+ res_len = format_len - len(f"Summary: {metric.upper()}")
823
+ left_len = res_len // 2
824
+ right_len = res_len - left_len
825
+ print("-" * format_len)
826
+ print(
827
+ " " * left_len + f"Summary: {metric.upper()}" + " " * right_len
828
+ )
829
+ print("-" * format_len)
830
+ for format_str in format_strs:
831
+ print(format_str)
832
+ print("-" * format_len)
833
+
637
834
 
638
835
  if __name__ == "__main__":
639
836
  entrypoint()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.10
3
+ Version: 0.2.12
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -63,9 +63,8 @@ Dynamic: requires-python
63
63
  </div>
64
64
 
65
65
  ## 🔥News🔥
66
-
67
- - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of [huggingface/flux-fast](https://github.com/huggingface/flux-fast) that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20 while still maintaining **high precision**.
68
-
66
+ - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, also check the **[PR](https://github.com/huggingface/flux-fast/pull/13)**.
67
+ - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
69
68
 
70
69
  ## 🤗 Introduction
71
70
 
@@ -1,5 +1,5 @@
1
1
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=fMOyoyXAggjNTgl2YJ-8HW1bnjjDPiNACUsDoNufScI,513
2
+ cache_dit/_version.py,sha256=7CFcHKqzy7OglwTX58ipGg1TD8MpDORqWvEBO3W1dHI,513
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
5
  cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
@@ -31,17 +31,17 @@ cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sh
31
31
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
32
32
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
33
33
  cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k,63
34
- cache_dit/compile/utils.py,sha256=OTvkwcezSrApZ2M1IMkYtkEmFbkfpTknhHMgoBApd6U,3786
34
+ cache_dit/compile/utils.py,sha256=N4A55_8uIbEd-S4xyJPcrdKceI2MGM9BTIhJE63jyL4,3786
35
35
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
36
  cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
37
  cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
38
38
  cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
39
39
  cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
40
40
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
41
- cache_dit/metrics/metrics.py,sha256=O9a8qV6deQDWEoez7UZ_aqDLQ9rJXAJUMHGnJM7RUMs,19927
42
- cache_dit-0.2.10.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
- cache_dit-0.2.10.dist-info/METADATA,sha256=f5PF-lhexcLdB2HmBWETBEyg011eZOc99tlVI1lozYA,28002
44
- cache_dit-0.2.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
- cache_dit-0.2.10.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
- cache_dit-0.2.10.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
- cache_dit-0.2.10.dist-info/RECORD,,
41
+ cache_dit/metrics/metrics.py,sha256=1TTbfaj_-vdUfxopLnc5kVrXs5rMpAoSi8D0ItYdPu8,26439
42
+ cache_dit-0.2.12.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
+ cache_dit-0.2.12.dist-info/METADATA,sha256=-AIWGVOFsY-nhMkDeFErUFcELTWmza96-0IUN3od88A,28219
44
+ cache_dit-0.2.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
+ cache_dit-0.2.12.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
+ cache_dit-0.2.12.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
+ cache_dit-0.2.12.dist-info/RECORD,,