cache-dit 0.2.11-py3-none-any.whl → 0.2.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic (the link to further details was not preserved in this copy of the page).

cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.11'
21
- __version_tuple__ = version_tuple = (0, 2, 11)
20
+ __version__ = version = '0.2.12'
21
+ __version_tuple__ = version_tuple = (0, 2, 12)
cache_dit/metrics/metrics.py CHANGED
@@ -439,11 +439,34 @@ def get_args():
439
439
 
440
440
  # Format output
441
441
  parser.add_argument(
442
- "--sort-output",
443
- "-sort",
442
+ "--summary",
443
+ "-s",
444
444
  action="store_true",
445
445
  default=False,
446
- help="Sort the outupt metrics results",
446
+ help="Summary the outupt metrics results",
447
+ )
448
+
449
+ # Addtional perf log
450
+ parser.add_argument(
451
+ "--perf-log",
452
+ "-plog",
453
+ type=str,
454
+ default=None,
455
+ help="Path to addtional perf log",
456
+ )
457
+ parser.add_argument(
458
+ "--perf-tag",
459
+ "-ptag",
460
+ type=str,
461
+ default=None,
462
+ help="Tag to parse perf time from perf log",
463
+ )
464
+ parser.add_argument(
465
+ "--extra-perf-tags",
466
+ "-extra-ptags",
467
+ nargs="+",
468
+ default=[],
469
+ help="Extra tags to parse perf time from perf log",
447
470
  )
448
471
  return parser.parse_args()
449
472
 
@@ -489,28 +512,28 @@ def entrypoint():
489
512
  img_true_info = os.path.basename(img_true)
490
513
  img_test_info = os.path.basename(img_test)
491
514
 
492
- def _logging_msg(value: float, n: int):
515
+ def _logging_msg(value: float, name, n: int):
493
516
  if value is None or n is None:
494
517
  return
495
518
  msg = (
496
519
  f"{img_true_info} vs {img_test_info}, "
497
- f"Num: {n}, {metric.upper()}: {value:.5f}"
520
+ f"Num: {n}, {name.upper()}: {value:.5f}"
498
521
  )
499
522
  METRICS_META[msg] = value
500
523
  logger.info(msg)
501
524
 
502
525
  if metric == "psnr" or metric == "all":
503
526
  img_psnr, n = compute_psnr(img_true, img_test)
504
- _logging_msg(img_psnr, n)
527
+ _logging_msg(img_psnr, "psnr", n)
505
528
  if metric == "ssim" or metric == "all":
506
529
  img_ssim, n = compute_ssim(img_true, img_test)
507
- _logging_msg(img_ssim, n)
530
+ _logging_msg(img_ssim, "ssim", n)
508
531
  if metric == "mse" or metric == "all":
509
532
  img_mse, n = compute_mse(img_true, img_test)
510
- _logging_msg(img_mse, n)
533
+ _logging_msg(img_mse, "mse", n)
511
534
  if metric == "fid" or metric == "all":
512
535
  img_fid, n = FID.compute_fid(img_true, img_test)
513
- _logging_msg(img_fid, n)
536
+ _logging_msg(img_fid, "fid", n)
514
537
 
515
538
  if video_true is not None and video_test is not None:
516
539
  if any(
@@ -525,28 +548,28 @@ def entrypoint():
525
548
  video_true_info = os.path.basename(video_true)
526
549
  video_test_info = os.path.basename(video_test)
527
550
 
528
- def _logging_msg(value: float, n: int):
551
+ def _logging_msg(value: float, name, n: int):
529
552
  if value is None or n is None:
530
553
  return
531
554
  msg = (
532
555
  f"{video_true_info} vs {video_test_info}, "
533
- f"Frames: {n}, {metric.upper()}: {value:.5f}"
556
+ f"Frames: {n}, {name.upper()}: {value:.5f}"
534
557
  )
535
558
  METRICS_META[msg] = value
536
559
  logger.info(msg)
537
560
 
538
561
  if metric == "psnr" or metric == "all":
539
562
  video_psnr, n = compute_video_psnr(video_true, video_test)
540
- _logging_msg(video_psnr, n)
563
+ _logging_msg(video_psnr, "psnr", n)
541
564
  if metric == "ssim" or metric == "all":
542
565
  video_ssim, n = compute_video_ssim(video_true, video_test)
543
- _logging_msg(video_ssim, n)
566
+ _logging_msg(video_ssim, "ssim", n)
544
567
  if metric == "mse" or metric == "all":
545
568
  video_mse, n = compute_video_mse(video_true, video_test)
546
- _logging_msg(video_mse, n)
569
+ _logging_msg(video_mse, "mse", n)
547
570
  if metric == "fid" or metric == "all":
548
571
  video_fid, n = FID.compute_video_fid(video_true, video_test)
549
- _logging_msg(video_fid, n)
572
+ _logging_msg(video_fid, "fid", n)
550
573
 
551
574
  # run selected metrics
552
575
  if not DISABLE_VERBOSE:
@@ -652,16 +675,44 @@ def entrypoint():
652
675
  video_test=args.video_test,
653
676
  )
654
677
 
655
- if args.sort_output:
678
+ if args.summary:
679
+
680
+ def _fetch_perf():
681
+ if args.perf_log is None or args.perf_tag is None:
682
+ return []
683
+ if not os.path.exists(args.perf_log):
684
+ return []
685
+ perf_texts = []
686
+ with open(args.perf_log, "r") as file:
687
+ perf_lines = file.readlines()
688
+ for line in perf_lines:
689
+ line = line.strip()
690
+ if args.perf_tag.lower() in line.lower():
691
+ if len(args.extra_perf_tags) == 0:
692
+ perf_texts.append(line)
693
+ else:
694
+ has_all_extra_tag = True
695
+ for ext_tag in args.extra_perf_tags:
696
+ if ext_tag.lower() not in line.lower():
697
+ has_all_extra_tag = False
698
+ break
699
+ if has_all_extra_tag:
700
+ perf_texts.append(line)
701
+ return perf_texts
702
+
703
+ PERF_TEXTS: list[str] = _fetch_perf()
656
704
 
657
705
  def _parse_value(
658
706
  text: str,
659
707
  tag: str = "Num",
660
- ) -> float:
708
+ ) -> float | None:
661
709
  import re
662
710
 
711
+ escaped_tag = re.escape(tag)
712
+ processed_tag = escaped_tag.replace(r"\ ", r"\s+")
713
+
663
714
  pattern = re.compile(
664
- rf"{re.escape(tag)}:\s*(\d+\.?\d*)", re.IGNORECASE
715
+ rf"{processed_tag}:\s*(\d+\.?\d*)\D*", re.IGNORECASE
665
716
  )
666
717
 
667
718
  match = pattern.search(text)
@@ -669,9 +720,30 @@ def entrypoint():
669
720
  if not match:
670
721
  return None
671
722
 
672
- if tag.lower() in METRICS_CHOICES:
673
- return float(match.group(1))
674
- return int(match.group(1))
723
+ value_str = match.group(1)
724
+ try:
725
+ if tag.lower() in METRICS_CHOICES:
726
+ return float(value_str)
727
+ if args.perf_tag is not None:
728
+ if tag.lower() == args.perf_tag.lower():
729
+ return float(value_str)
730
+ return int(value_str)
731
+ except ValueError:
732
+ return None
733
+
734
+ def _parse_perf(
735
+ compare_tag: str,
736
+ ) -> float | None:
737
+ nonlocal PERF_TEXTS
738
+ perf_times = []
739
+ for line in PERF_TEXTS:
740
+ if compare_tag in line:
741
+ perf_time = _parse_value(line, args.perf_tag)
742
+ if perf_time is not None:
743
+ perf_times.append(perf_time)
744
+ if len(perf_times) == 0:
745
+ return None
746
+ return sum(perf_times) / len(perf_times)
675
747
 
676
748
  def _format_item(
677
749
  key: str,
@@ -679,23 +751,52 @@ def entrypoint():
679
751
  value: float,
680
752
  max_key_len: int,
681
753
  ):
754
+ nonlocal PERF_TEXTS
682
755
  # U1-Q0-C0-NONE vs U4-Q1-C1-NONE
683
756
  header = key.split(",")[0].strip()
757
+ compare_tag = header.split("vs")[1].strip() # U4-Q1-C1-NONE
758
+ has_perf_texts = len(PERF_TEXTS) > 0
759
+ format_str = ""
684
760
  # Num / Frames
685
761
  if n := _parse_value(key, "Num"):
686
- print(
687
- f"{header:<{max_key_len}} Num: {n} "
688
- f"{metric.upper()}: {value:<.4f}"
689
- )
762
+ if not has_perf_texts:
763
+ format_str = (
764
+ f"{header:<{max_key_len}} Num: {n} "
765
+ f"{metric.upper()}: {value:<7.4f}"
766
+ )
767
+ else:
768
+ perf_time = _parse_perf(compare_tag)
769
+ perf_time = f"{perf_time:<.2f}" if perf_time else None
770
+ format_str = (
771
+ f"{header:<{max_key_len}} Num: {n} "
772
+ f"{metric.upper()}: {value:<7.4f} "
773
+ f"Perf: {perf_time}"
774
+ )
690
775
  elif n := _parse_value(key, "Frames"):
691
- print(
692
- f"{header:<{max_key_len}} Frames: {n} "
693
- f"{metric.upper()}: {value:<.4f}"
694
- )
776
+ if not has_perf_texts:
777
+ format_str = (
778
+ f"{header:<{max_key_len}} Frames: {n} "
779
+ f"{metric.upper()}: {value:<7.4f}"
780
+ )
781
+ else:
782
+ perf_time = _parse_perf(compare_tag)
783
+ perf_time = f"{perf_time:<.2f}" if perf_time else None
784
+ format_str = (
785
+ f"{header:<{max_key_len}} Frames: {n} "
786
+ f"{metric.upper()}: {value:<7.4f} "
787
+ f"Perf: {perf_time}"
788
+ )
695
789
  else:
696
790
  raise ValueError("Num or Frames can not be NoneType.")
697
791
 
698
- for metric in args.metrics:
792
+ return format_str
793
+
794
+ selected_metrics = args.metrics
795
+ if "all" in selected_metrics:
796
+ selected_metrics = METRICS_CHOICES.copy()
797
+ selected_metrics.remove("all")
798
+
799
+ for metric in selected_metrics:
699
800
  selected_items = {}
700
801
  for key in METRICS_META.keys():
701
802
  if metric.upper() in key or metric.lower() in key:
@@ -710,7 +811,14 @@ def entrypoint():
710
811
  ]
711
812
  max_key_len = max(len(key) for key in selected_keys)
712
813
 
713
- format_len = int(max_key_len * 1.5)
814
+ format_strs = []
815
+ for key, value in sorted_items:
816
+ format_strs.append(
817
+ _format_item(key, metric, value, max_key_len)
818
+ )
819
+
820
+ format_len = max(len(format_str) for format_str in format_strs)
821
+
714
822
  res_len = format_len - len(f"Summary: {metric.upper()}")
715
823
  left_len = res_len // 2
716
824
  right_len = res_len - left_len
@@ -719,8 +827,8 @@ def entrypoint():
719
827
  " " * left_len + f"Summary: {metric.upper()}" + " " * right_len
720
828
  )
721
829
  print("-" * format_len)
722
- for key, value in sorted_items:
723
- _format_item(key, metric, value, max_key_len)
830
+ for format_str in format_strs:
831
+ print(format_str)
724
832
  print("-" * format_len)
725
833
 
726
834
 
cache_dit-0.2.11.dist-info/METADATA → cache_dit-0.2.12.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.11
3
+ Version: 0.2.12
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -63,7 +63,7 @@ Dynamic: requires-python
63
63
  </div>
64
64
 
65
65
  ## 🔥News🔥
66
- - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, please check [PR](https://github.com/huggingface/flux-fast/pull/13).
66
+ - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, also check the **[PR](https://github.com/huggingface/flux-fast/pull/13)**.
67
67
  - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
68
68
 
69
69
  ## 🤗 Introduction
cache_dit-0.2.11.dist-info/RECORD → cache_dit-0.2.12.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
1
1
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=Y72g1mojWf0yRnnMW5zEUr6skXUSsqAdPjRJUrxXSYc,513
2
+ cache_dit/_version.py,sha256=7CFcHKqzy7OglwTX58ipGg1TD8MpDORqWvEBO3W1dHI,513
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
5
  cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
@@ -38,10 +38,10 @@ cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE
38
38
  cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
39
39
  cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
40
40
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
41
- cache_dit/metrics/metrics.py,sha256=PAzyhJawos1UeMnHsxcu4edkwCSYMBmjDGRR_--I104,22410
42
- cache_dit-0.2.11.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
- cache_dit-0.2.11.dist-info/METADATA,sha256=HExb-ldgaYzZRSCy6Tg6syONjd0uDoB8_8AShvbfA-0,28213
44
- cache_dit-0.2.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
- cache_dit-0.2.11.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
- cache_dit-0.2.11.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
- cache_dit-0.2.11.dist-info/RECORD,,
41
+ cache_dit/metrics/metrics.py,sha256=1TTbfaj_-vdUfxopLnc5kVrXs5rMpAoSi8D0ItYdPu8,26439
42
+ cache_dit-0.2.12.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
+ cache_dit-0.2.12.dist-info/METADATA,sha256=-AIWGVOFsY-nhMkDeFErUFcELTWmza96-0IUN3od88A,28219
44
+ cache_dit-0.2.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
+ cache_dit-0.2.12.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
+ cache_dit-0.2.12.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
+ cache_dit-0.2.12.dist-info/RECORD,,