cache-dit 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.11'
21
- __version_tuple__ = version_tuple = (0, 2, 11)
20
+ __version__ = version = '0.2.13'
21
+ __version_tuple__ = version_tuple = (0, 2, 13)
cache_dit/metrics/lpips.py ADDED
@@ -0,0 +1,43 @@
1
+ import builtins as __builtin__
2
+ import contextlib
3
+ import warnings
4
+
5
+ import lpips
6
+ import torch
7
+
8
+ warnings.filterwarnings("ignore")
9
+
10
+ lpips_loss_fn_vgg = None
11
+ lpips_loss_fn_alex = None
12
+
13
+
14
+ def dummy_print(*args, **kwargs):
15
+ pass
16
+
17
+
18
+ @contextlib.contextmanager
19
+ def disable_print():
20
+ origin_print = __builtin__.print
21
+ __builtin__.print = dummy_print
22
+ yield
23
+ __builtin__.print = origin_print
24
+
25
+
26
+ def compute_lpips_img(img0, img1, net: str = "alex"):
27
+ global lpips_loss_fn_vgg
28
+ global lpips_loss_fn_alex
29
+ if net.lower() == "alex":
30
+ if lpips_loss_fn_alex is None:
31
+ with disable_print():
32
+ lpips_loss_fn_alex = lpips.LPIPS(net="alex")
33
+ loss_fn = lpips_loss_fn_alex
34
+ elif net.lower() == "vgg":
35
+ if lpips_loss_fn_vgg is None:
36
+ with disable_print():
37
+ lpips_loss_fn_vgg = lpips.LPIPS(net="vgg")
38
+ loss_fn = lpips_loss_fn_vgg
39
+ else:
40
+ assert False, f"unsupport net {net}"
41
+
42
+ with torch.no_grad():
43
+ return loss_fn(img0, img1).item()
cache_dit/metrics/metrics.py CHANGED
@@ -14,6 +14,7 @@ from cache_dit.metrics.config import get_metrics_verbose
14
14
  from cache_dit.metrics.config import _IMAGE_EXTENSIONS
15
15
  from cache_dit.metrics.config import _VIDEO_EXTENSIONS
16
16
  from cache_dit.logger import init_logger
17
+ from cache_dit.metrics.lpips import compute_lpips_img
17
18
 
18
19
  logger = init_logger(__name__)
19
20
 
@@ -21,6 +22,35 @@ logger = init_logger(__name__)
21
22
  DISABLE_VERBOSE = not get_metrics_verbose()
22
23
 
23
24
 
25
+ def compute_lpips_file(
26
+ image_true: np.ndarray | str,
27
+ image_test: np.ndarray | str,
28
+ ) -> float:
29
+ import torch
30
+ from PIL import Image
31
+ from torchvision.transforms.v2.functional import (
32
+ convert_image_dtype,
33
+ normalize,
34
+ pil_to_tensor,
35
+ )
36
+
37
+ def load_img_as_tensor(path):
38
+ pil = Image.open(path)
39
+ img = pil_to_tensor(pil)
40
+ img = convert_image_dtype(img, dtype=torch.float32)
41
+ img = normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
42
+ return img
43
+
44
+ if isinstance(image_true, str):
45
+ image_true = load_img_as_tensor(image_true)
46
+ if isinstance(image_test, str):
47
+ image_test = load_img_as_tensor(image_test)
48
+ return compute_lpips_img(
49
+ image_true,
50
+ image_test,
51
+ )
52
+
53
+
24
54
  def compute_psnr_file(
25
55
  image_true: np.ndarray | str,
26
56
  image_test: np.ndarray | str,
@@ -305,6 +335,11 @@ def compute_video_metric(
305
335
  return None, None
306
336
 
307
337
 
338
+ compute_lpips = partial(
339
+ compute_dir_metric,
340
+ compute_file_func=compute_lpips_file,
341
+ )
342
+
308
343
  compute_psnr = partial(
309
344
  compute_dir_metric,
310
345
  compute_file_func=compute_psnr_file,
@@ -320,6 +355,10 @@ compute_mse = partial(
320
355
  compute_file_func=compute_mse_file,
321
356
  )
322
357
 
358
+ compute_video_lpips = partial(
359
+ compute_video_metric,
360
+ compute_frame_func=compute_lpips_file,
361
+ )
323
362
  compute_video_psnr = partial(
324
363
  compute_video_metric,
325
364
  compute_frame_func=compute_psnr_file,
@@ -335,6 +374,7 @@ compute_video_mse = partial(
335
374
 
336
375
 
337
376
  METRICS_CHOICES = [
377
+ "lpips",
338
378
  "psnr",
339
379
  "ssim",
340
380
  "mse",
@@ -439,11 +479,34 @@ def get_args():
439
479
 
440
480
  # Format output
441
481
  parser.add_argument(
442
- "--sort-output",
443
- "-sort",
482
+ "--summary",
483
+ "-s",
444
484
  action="store_true",
445
485
  default=False,
446
- help="Sort the outupt metrics results",
486
+ help="Summary the outupt metrics results",
487
+ )
488
+
489
+ # Addtional perf log
490
+ parser.add_argument(
491
+ "--perf-log",
492
+ "-plog",
493
+ type=str,
494
+ default=None,
495
+ help="Path to addtional perf log",
496
+ )
497
+ parser.add_argument(
498
+ "--perf-tag",
499
+ "-ptag",
500
+ type=str,
501
+ default=None,
502
+ help="Tag to parse perf time from perf log",
503
+ )
504
+ parser.add_argument(
505
+ "--extra-perf-tags",
506
+ "-extra-ptags",
507
+ nargs="+",
508
+ default=[],
509
+ help="Extra tags to parse perf time from perf log",
447
510
  )
448
511
  return parser.parse_args()
449
512
 
@@ -489,28 +552,31 @@ def entrypoint():
489
552
  img_true_info = os.path.basename(img_true)
490
553
  img_test_info = os.path.basename(img_test)
491
554
 
492
- def _logging_msg(value: float, n: int):
555
+ def _logging_msg(value: float, name, n: int):
493
556
  if value is None or n is None:
494
557
  return
495
558
  msg = (
496
559
  f"{img_true_info} vs {img_test_info}, "
497
- f"Num: {n}, {metric.upper()}: {value:.5f}"
560
+ f"Num: {n}, {name.upper()}: {value:.5f}"
498
561
  )
499
562
  METRICS_META[msg] = value
500
563
  logger.info(msg)
501
564
 
565
+ if metric == "lpips" or metric == "all":
566
+ img_lpips, n = compute_lpips(img_true, img_test)
567
+ _logging_msg(img_lpips, "lpips", n)
502
568
  if metric == "psnr" or metric == "all":
503
569
  img_psnr, n = compute_psnr(img_true, img_test)
504
- _logging_msg(img_psnr, n)
570
+ _logging_msg(img_psnr, "psnr", n)
505
571
  if metric == "ssim" or metric == "all":
506
572
  img_ssim, n = compute_ssim(img_true, img_test)
507
- _logging_msg(img_ssim, n)
573
+ _logging_msg(img_ssim, "ssim", n)
508
574
  if metric == "mse" or metric == "all":
509
575
  img_mse, n = compute_mse(img_true, img_test)
510
- _logging_msg(img_mse, n)
576
+ _logging_msg(img_mse, "mse", n)
511
577
  if metric == "fid" or metric == "all":
512
578
  img_fid, n = FID.compute_fid(img_true, img_test)
513
- _logging_msg(img_fid, n)
579
+ _logging_msg(img_fid, "fid", n)
514
580
 
515
581
  if video_true is not None and video_test is not None:
516
582
  if any(
@@ -525,28 +591,31 @@ def entrypoint():
525
591
  video_true_info = os.path.basename(video_true)
526
592
  video_test_info = os.path.basename(video_test)
527
593
 
528
- def _logging_msg(value: float, n: int):
594
+ def _logging_msg(value: float, name, n: int):
529
595
  if value is None or n is None:
530
596
  return
531
597
  msg = (
532
598
  f"{video_true_info} vs {video_test_info}, "
533
- f"Frames: {n}, {metric.upper()}: {value:.5f}"
599
+ f"Frames: {n}, {name.upper()}: {value:.5f}"
534
600
  )
535
601
  METRICS_META[msg] = value
536
602
  logger.info(msg)
537
603
 
604
+ if metric == "lpips" or metric == "all":
605
+ video_lpips, n = compute_video_lpips(video_true, video_test)
606
+ _logging_msg(video_lpips, "lpips", n)
538
607
  if metric == "psnr" or metric == "all":
539
608
  video_psnr, n = compute_video_psnr(video_true, video_test)
540
- _logging_msg(video_psnr, n)
609
+ _logging_msg(video_psnr, "psnr", n)
541
610
  if metric == "ssim" or metric == "all":
542
611
  video_ssim, n = compute_video_ssim(video_true, video_test)
543
- _logging_msg(video_ssim, n)
612
+ _logging_msg(video_ssim, "ssim", n)
544
613
  if metric == "mse" or metric == "all":
545
614
  video_mse, n = compute_video_mse(video_true, video_test)
546
- _logging_msg(video_mse, n)
615
+ _logging_msg(video_mse, "mse", n)
547
616
  if metric == "fid" or metric == "all":
548
617
  video_fid, n = FID.compute_video_fid(video_true, video_test)
549
- _logging_msg(video_fid, n)
618
+ _logging_msg(video_fid, "fid", n)
550
619
 
551
620
  # run selected metrics
552
621
  if not DISABLE_VERBOSE:
@@ -652,16 +721,44 @@ def entrypoint():
652
721
  video_test=args.video_test,
653
722
  )
654
723
 
655
- if args.sort_output:
724
+ if args.summary:
725
+
726
+ def _fetch_perf():
727
+ if args.perf_log is None or args.perf_tag is None:
728
+ return []
729
+ if not os.path.exists(args.perf_log):
730
+ return []
731
+ perf_texts = []
732
+ with open(args.perf_log, "r") as file:
733
+ perf_lines = file.readlines()
734
+ for line in perf_lines:
735
+ line = line.strip()
736
+ if args.perf_tag.lower() in line.lower():
737
+ if len(args.extra_perf_tags) == 0:
738
+ perf_texts.append(line)
739
+ else:
740
+ has_all_extra_tag = True
741
+ for ext_tag in args.extra_perf_tags:
742
+ if ext_tag.lower() not in line.lower():
743
+ has_all_extra_tag = False
744
+ break
745
+ if has_all_extra_tag:
746
+ perf_texts.append(line)
747
+ return perf_texts
748
+
749
+ PERF_TEXTS: list[str] = _fetch_perf()
656
750
 
657
751
  def _parse_value(
658
752
  text: str,
659
753
  tag: str = "Num",
660
- ) -> float:
754
+ ) -> float | None:
661
755
  import re
662
756
 
757
+ escaped_tag = re.escape(tag)
758
+ processed_tag = escaped_tag.replace(r"\ ", r"\s+")
759
+
663
760
  pattern = re.compile(
664
- rf"{re.escape(tag)}:\s*(\d+\.?\d*)", re.IGNORECASE
761
+ rf"{processed_tag}:\s*(\d+\.?\d*)\D*", re.IGNORECASE
665
762
  )
666
763
 
667
764
  match = pattern.search(text)
@@ -669,9 +766,30 @@ def entrypoint():
669
766
  if not match:
670
767
  return None
671
768
 
672
- if tag.lower() in METRICS_CHOICES:
673
- return float(match.group(1))
674
- return int(match.group(1))
769
+ value_str = match.group(1)
770
+ try:
771
+ if tag.lower() in METRICS_CHOICES:
772
+ return float(value_str)
773
+ if args.perf_tag is not None:
774
+ if tag.lower() == args.perf_tag.lower():
775
+ return float(value_str)
776
+ return int(value_str)
777
+ except ValueError:
778
+ return None
779
+
780
+ def _parse_perf(
781
+ compare_tag: str,
782
+ ) -> float | None:
783
+ nonlocal PERF_TEXTS
784
+ perf_times = []
785
+ for line in PERF_TEXTS:
786
+ if compare_tag in line:
787
+ perf_time = _parse_value(line, args.perf_tag)
788
+ if perf_time is not None:
789
+ perf_times.append(perf_time)
790
+ if len(perf_times) == 0:
791
+ return None
792
+ return sum(perf_times) / len(perf_times)
675
793
 
676
794
  def _format_item(
677
795
  key: str,
@@ -679,23 +797,52 @@ def entrypoint():
679
797
  value: float,
680
798
  max_key_len: int,
681
799
  ):
800
+ nonlocal PERF_TEXTS
682
801
  # U1-Q0-C0-NONE vs U4-Q1-C1-NONE
683
802
  header = key.split(",")[0].strip()
803
+ compare_tag = header.split("vs")[1].strip() # U4-Q1-C1-NONE
804
+ has_perf_texts = len(PERF_TEXTS) > 0
805
+ format_str = ""
684
806
  # Num / Frames
685
807
  if n := _parse_value(key, "Num"):
686
- print(
687
- f"{header:<{max_key_len}} Num: {n} "
688
- f"{metric.upper()}: {value:<.4f}"
689
- )
808
+ if not has_perf_texts:
809
+ format_str = (
810
+ f"{header:<{max_key_len}} Num: {n} "
811
+ f"{metric.upper()}: {value:<7.4f}"
812
+ )
813
+ else:
814
+ perf_time = _parse_perf(compare_tag)
815
+ perf_time = f"{perf_time:<.2f}" if perf_time else None
816
+ format_str = (
817
+ f"{header:<{max_key_len}} Num: {n} "
818
+ f"{metric.upper()}: {value:<7.4f} "
819
+ f"Perf: {perf_time}"
820
+ )
690
821
  elif n := _parse_value(key, "Frames"):
691
- print(
692
- f"{header:<{max_key_len}} Frames: {n} "
693
- f"{metric.upper()}: {value:<.4f}"
694
- )
822
+ if not has_perf_texts:
823
+ format_str = (
824
+ f"{header:<{max_key_len}} Frames: {n} "
825
+ f"{metric.upper()}: {value:<7.4f}"
826
+ )
827
+ else:
828
+ perf_time = _parse_perf(compare_tag)
829
+ perf_time = f"{perf_time:<.2f}" if perf_time else None
830
+ format_str = (
831
+ f"{header:<{max_key_len}} Frames: {n} "
832
+ f"{metric.upper()}: {value:<7.4f} "
833
+ f"Perf: {perf_time}"
834
+ )
695
835
  else:
696
836
  raise ValueError("Num or Frames can not be NoneType.")
697
837
 
698
- for metric in args.metrics:
838
+ return format_str
839
+
840
+ selected_metrics = args.metrics
841
+ if "all" in selected_metrics:
842
+ selected_metrics = METRICS_CHOICES.copy()
843
+ selected_metrics.remove("all")
844
+
845
+ for metric in selected_metrics:
699
846
  selected_items = {}
700
847
  for key in METRICS_META.keys():
701
848
  if metric.upper() in key or metric.lower() in key:
@@ -710,7 +857,14 @@ def entrypoint():
710
857
  ]
711
858
  max_key_len = max(len(key) for key in selected_keys)
712
859
 
713
- format_len = int(max_key_len * 1.5)
860
+ format_strs = []
861
+ for key, value in sorted_items:
862
+ format_strs.append(
863
+ _format_item(key, metric, value, max_key_len)
864
+ )
865
+
866
+ format_len = max(len(format_str) for format_str in format_strs)
867
+
714
868
  res_len = format_len - len(f"Summary: {metric.upper()}")
715
869
  left_len = res_len // 2
716
870
  right_len = res_len - left_len
@@ -719,8 +873,8 @@ def entrypoint():
719
873
  " " * left_len + f"Summary: {metric.upper()}" + " " * right_len
720
874
  )
721
875
  print("-" * format_len)
722
- for key, value in sorted_items:
723
- _format_item(key, metric, value, max_key_len)
876
+ for format_str in format_strs:
877
+ print(format_str)
724
878
  print("-" * format_len)
725
879
 
726
880
 
{cache_dit-0.2.11.dist-info → cache_dit-0.2.13.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.11
3
+ Version: 0.2.13
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -16,6 +16,7 @@ Requires-Dist: transformers>=4.51.3
16
16
  Requires-Dist: diffusers>=0.33.1
17
17
  Requires-Dist: scikit-image
18
18
  Requires-Dist: scipy
19
+ Requires-Dist: lpips==0.1.4
19
20
  Provides-Extra: all
20
21
  Provides-Extra: dev
21
22
  Requires-Dist: pre-commit; extra == "dev"
@@ -63,7 +64,7 @@ Dynamic: requires-python
63
64
  </div>
64
65
 
65
66
  ## 🔥News🔥
66
- - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, please check [PR](https://github.com/huggingface/flux-fast/pull/13).
67
+ - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, also check the **[PR](https://github.com/huggingface/flux-fast/pull/13)**.
67
68
  - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
68
69
 
69
70
  ## 🤗 Introduction
{cache_dit-0.2.11.dist-info → cache_dit-0.2.13.dist-info}/RECORD CHANGED
@@ -1,5 +1,5 @@
1
1
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=Y72g1mojWf0yRnnMW5zEUr6skXUSsqAdPjRJUrxXSYc,513
2
+ cache_dit/_version.py,sha256=2ECxD0Bipdh9vxnyteM0k9jxi9NOpPR7YxTi7Ad1ors,513
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
5
  cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
@@ -38,10 +38,11 @@ cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE
38
38
  cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
39
39
  cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
40
40
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
41
- cache_dit/metrics/metrics.py,sha256=PAzyhJawos1UeMnHsxcu4edkwCSYMBmjDGRR_--I104,22410
42
- cache_dit-0.2.11.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
- cache_dit-0.2.11.dist-info/METADATA,sha256=HExb-ldgaYzZRSCy6Tg6syONjd0uDoB8_8AShvbfA-0,28213
44
- cache_dit-0.2.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
- cache_dit-0.2.11.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
- cache_dit-0.2.11.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
- cache_dit-0.2.11.dist-info/RECORD,,
41
+ cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
42
+ cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
43
+ cache_dit-0.2.13.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
44
+ cache_dit-0.2.13.dist-info/METADATA,sha256=at8DNFeGI5aVnBTi7_6zJgAi_QdgsItpBMzSGl8HEME,28247
45
+ cache_dit-0.2.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
46
+ cache_dit-0.2.13.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
47
+ cache_dit-0.2.13.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
48
+ cache_dit-0.2.13.dist-info/RECORD,,