cache-dit 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.12'
21
- __version_tuple__ = version_tuple = (0, 2, 12)
20
+ __version__ = version = '0.2.13'
21
+ __version_tuple__ = version_tuple = (0, 2, 13)
cache_dit/metrics/lpips.py ADDED
@@ -0,0 +1,43 @@
1
+ import builtins as __builtin__
2
+ import contextlib
3
+ import warnings
4
+
5
+ import lpips
6
+ import torch
7
+
8
+ warnings.filterwarnings("ignore")
9
+
10
+ lpips_loss_fn_vgg = None
11
+ lpips_loss_fn_alex = None
12
+
13
+
14
+ def dummy_print(*args, **kwargs):
15
+ pass
16
+
17
+
18
+ @contextlib.contextmanager
19
+ def disable_print():
20
+ origin_print = __builtin__.print
21
+ __builtin__.print = dummy_print
22
+ yield
23
+ __builtin__.print = origin_print
24
+
25
+
26
+ def compute_lpips_img(img0, img1, net: str = "alex"):
27
+ global lpips_loss_fn_vgg
28
+ global lpips_loss_fn_alex
29
+ if net.lower() == "alex":
30
+ if lpips_loss_fn_alex is None:
31
+ with disable_print():
32
+ lpips_loss_fn_alex = lpips.LPIPS(net="alex")
33
+ loss_fn = lpips_loss_fn_alex
34
+ elif net.lower() == "vgg":
35
+ if lpips_loss_fn_vgg is None:
36
+ with disable_print():
37
+ lpips_loss_fn_vgg = lpips.LPIPS(net="vgg")
38
+ loss_fn = lpips_loss_fn_vgg
39
+ else:
40
+ assert False, f"unsupport net {net}"
41
+
42
+ with torch.no_grad():
43
+ return loss_fn(img0, img1).item()
cache_dit/metrics/metrics.py CHANGED
@@ -14,6 +14,7 @@ from cache_dit.metrics.config import get_metrics_verbose
14
14
  from cache_dit.metrics.config import _IMAGE_EXTENSIONS
15
15
  from cache_dit.metrics.config import _VIDEO_EXTENSIONS
16
16
  from cache_dit.logger import init_logger
17
+ from cache_dit.metrics.lpips import compute_lpips_img
17
18
 
18
19
  logger = init_logger(__name__)
19
20
 
@@ -21,6 +22,35 @@ logger = init_logger(__name__)
21
22
  DISABLE_VERBOSE = not get_metrics_verbose()
22
23
 
23
24
 
25
+ def compute_lpips_file(
26
+ image_true: np.ndarray | str,
27
+ image_test: np.ndarray | str,
28
+ ) -> float:
29
+ import torch
30
+ from PIL import Image
31
+ from torchvision.transforms.v2.functional import (
32
+ convert_image_dtype,
33
+ normalize,
34
+ pil_to_tensor,
35
+ )
36
+
37
+ def load_img_as_tensor(path):
38
+ pil = Image.open(path)
39
+ img = pil_to_tensor(pil)
40
+ img = convert_image_dtype(img, dtype=torch.float32)
41
+ img = normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
42
+ return img
43
+
44
+ if isinstance(image_true, str):
45
+ image_true = load_img_as_tensor(image_true)
46
+ if isinstance(image_test, str):
47
+ image_test = load_img_as_tensor(image_test)
48
+ return compute_lpips_img(
49
+ image_true,
50
+ image_test,
51
+ )
52
+
53
+
24
54
  def compute_psnr_file(
25
55
  image_true: np.ndarray | str,
26
56
  image_test: np.ndarray | str,
@@ -305,6 +335,11 @@ def compute_video_metric(
305
335
  return None, None
306
336
 
307
337
 
338
+ compute_lpips = partial(
339
+ compute_dir_metric,
340
+ compute_file_func=compute_lpips_file,
341
+ )
342
+
308
343
  compute_psnr = partial(
309
344
  compute_dir_metric,
310
345
  compute_file_func=compute_psnr_file,
@@ -320,6 +355,10 @@ compute_mse = partial(
320
355
  compute_file_func=compute_mse_file,
321
356
  )
322
357
 
358
+ compute_video_lpips = partial(
359
+ compute_video_metric,
360
+ compute_frame_func=compute_lpips_file,
361
+ )
323
362
  compute_video_psnr = partial(
324
363
  compute_video_metric,
325
364
  compute_frame_func=compute_psnr_file,
@@ -335,6 +374,7 @@ compute_video_mse = partial(
335
374
 
336
375
 
337
376
  METRICS_CHOICES = [
377
+ "lpips",
338
378
  "psnr",
339
379
  "ssim",
340
380
  "mse",
@@ -522,6 +562,9 @@ def entrypoint():
522
562
  METRICS_META[msg] = value
523
563
  logger.info(msg)
524
564
 
565
+ if metric == "lpips" or metric == "all":
566
+ img_lpips, n = compute_lpips(img_true, img_test)
567
+ _logging_msg(img_lpips, "lpips", n)
525
568
  if metric == "psnr" or metric == "all":
526
569
  img_psnr, n = compute_psnr(img_true, img_test)
527
570
  _logging_msg(img_psnr, "psnr", n)
@@ -558,6 +601,9 @@ def entrypoint():
558
601
  METRICS_META[msg] = value
559
602
  logger.info(msg)
560
603
 
604
+ if metric == "lpips" or metric == "all":
605
+ video_lpips, n = compute_video_lpips(video_true, video_test)
606
+ _logging_msg(video_lpips, "lpips", n)
561
607
  if metric == "psnr" or metric == "all":
562
608
  video_psnr, n = compute_video_psnr(video_true, video_test)
563
609
  _logging_msg(video_psnr, "psnr", n)
cache_dit-0.2.13.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.12
3
+ Version: 0.2.13
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -16,6 +16,7 @@ Requires-Dist: transformers>=4.51.3
16
16
  Requires-Dist: diffusers>=0.33.1
17
17
  Requires-Dist: scikit-image
18
18
  Requires-Dist: scipy
19
+ Requires-Dist: lpips==0.1.4
19
20
  Provides-Extra: all
20
21
  Provides-Extra: dev
21
22
  Requires-Dist: pre-commit; extra == "dev"
cache_dit-0.2.13.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
1
1
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=7CFcHKqzy7OglwTX58ipGg1TD8MpDORqWvEBO3W1dHI,513
2
+ cache_dit/_version.py,sha256=2ECxD0Bipdh9vxnyteM0k9jxi9NOpPR7YxTi7Ad1ors,513
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
5
  cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
@@ -38,10 +38,11 @@ cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE
38
38
  cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
39
39
  cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
40
40
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
41
- cache_dit/metrics/metrics.py,sha256=1TTbfaj_-vdUfxopLnc5kVrXs5rMpAoSi8D0ItYdPu8,26439
42
- cache_dit-0.2.12.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
- cache_dit-0.2.12.dist-info/METADATA,sha256=-AIWGVOFsY-nhMkDeFErUFcELTWmza96-0IUN3od88A,28219
44
- cache_dit-0.2.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
- cache_dit-0.2.12.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
- cache_dit-0.2.12.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
- cache_dit-0.2.12.dist-info/RECORD,,
41
+ cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
42
+ cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
43
+ cache_dit-0.2.13.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
44
+ cache_dit-0.2.13.dist-info/METADATA,sha256=at8DNFeGI5aVnBTi7_6zJgAi_QdgsItpBMzSGl8HEME,28247
45
+ cache_dit-0.2.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
46
+ cache_dit-0.2.13.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
47
+ cache_dit-0.2.13.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
48
+ cache_dit-0.2.13.dist-info/RECORD,,