cache-dit 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/metrics/lpips.py +43 -0
- cache_dit/metrics/metrics.py +46 -0
- {cache_dit-0.2.12.dist-info → cache_dit-0.2.13.dist-info}/METADATA +2 -1
- {cache_dit-0.2.12.dist-info → cache_dit-0.2.13.dist-info}/RECORD +9 -8
- {cache_dit-0.2.12.dist-info → cache_dit-0.2.13.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.12.dist-info → cache_dit-0.2.13.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.12.dist-info → cache_dit-0.2.13.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.12.dist-info → cache_dit-0.2.13.dist-info}/top_level.txt +0 -0
cache_dit/metrics/lpips.py
CHANGED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import builtins as __builtin__
|
|
2
|
+
import contextlib
|
|
3
|
+
import warnings
|
|
4
|
+
|
|
5
|
+
import lpips
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
# NOTE(review): this installs an "ignore" filter for every warning category,
# process-wide, as a side effect of importing this module -- not only around
# the LPIPS model construction. Consider scoping it with
# warnings.catch_warnings() if global suppression is not intended.
warnings.simplefilter("ignore")

# Lazily built LPIPS models, one per backbone. They start unset and are filled
# in on first use by compute_lpips_img(), then reused by later calls.
lpips_loss_fn_vgg = None
lpips_loss_fn_alex = None
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def dummy_print(*args, **kwargs):
    """No-op replacement for ``print``: accept any arguments and discard them."""
    pass


@contextlib.contextmanager
def disable_print():
    """Temporarily silence ``print`` by swapping out ``builtins.print``.

    Inside the ``with`` body every ``print`` call is routed to
    :func:`dummy_print` and discarded. The original ``print`` is restored on
    exit -- including when the body raises, so a failure inside the block
    (for example while constructing a model) cannot leave ``print``
    permanently disabled for the rest of the process.
    """
    origin_print = __builtin__.print
    __builtin__.print = dummy_print
    try:
        yield
    finally:
        # Runs on both normal exit and exception propagation out of the body.
        __builtin__.print = origin_print
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def compute_lpips_img(img0, img1, net: str = "alex"):
    """Return the LPIPS score between two images as a Python float.

    The ``lpips.LPIPS`` model for each backbone is constructed once (with its
    console output suppressed) and cached in the module-level
    ``lpips_loss_fn_alex`` / ``lpips_loss_fn_vgg`` globals for reuse.

    Args:
        img0: first image, passed directly to the ``lpips.LPIPS`` model.
            NOTE(review): compute_lpips_file passes float32 tensors normalized
            to [-1, 1]; other callers presumably must match -- confirm.
        img1: second image, same expectations as ``img0``.
        net: backbone name, ``"alex"`` (default) or ``"vgg"``; case-insensitive.

    Returns:
        The model output reduced with ``.item()``.

    Raises:
        ValueError: if ``net`` is not a supported backbone.
    """
    global lpips_loss_fn_vgg
    global lpips_loss_fn_alex
    backbone = net.lower()
    if backbone == "alex":
        if lpips_loss_fn_alex is None:
            with disable_print():
                lpips_loss_fn_alex = lpips.LPIPS(net="alex")
        loss_fn = lpips_loss_fn_alex
    elif backbone == "vgg":
        if lpips_loss_fn_vgg is None:
            with disable_print():
                lpips_loss_fn_vgg = lpips.LPIPS(net="vgg")
        loss_fn = lpips_loss_fn_vgg
    else:
        # Raise explicitly rather than `assert False`: asserts are stripped
        # under `python -O`, which would fall through to an UnboundLocalError
        # on `loss_fn` below.
        raise ValueError(f"unsupported net {net!r}; expected 'alex' or 'vgg'")

    # The cached model is created on the default (CPU) device; move it to the
    # inputs' device so GPU tensors also work. Module.to() is a no-op when the
    # model is already on that device.
    if isinstance(img0, torch.Tensor):
        loss_fn = loss_fn.to(img0.device)

    with torch.no_grad():
        return loss_fn(img0, img1).item()
|
cache_dit/metrics/metrics.py
CHANGED
|
@@ -14,6 +14,7 @@ from cache_dit.metrics.config import get_metrics_verbose
|
|
|
14
14
|
from cache_dit.metrics.config import _IMAGE_EXTENSIONS
|
|
15
15
|
from cache_dit.metrics.config import _VIDEO_EXTENSIONS
|
|
16
16
|
from cache_dit.logger import init_logger
|
|
17
|
+
from cache_dit.metrics.lpips import compute_lpips_img
|
|
17
18
|
|
|
18
19
|
logger = init_logger(__name__)
|
|
19
20
|
|
|
@@ -21,6 +22,35 @@ logger = init_logger(__name__)
|
|
|
21
22
|
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
22
23
|
|
|
23
24
|
|
|
25
|
+
def compute_lpips_file(
    image_true: np.ndarray | str,
    image_test: np.ndarray | str,
    net: str = "alex",
) -> float:
    """Compute LPIPS between two images given as file paths or NumPy arrays.

    Each path or array is converted to a 3-channel float32 ``(C, H, W)``
    tensor, scaled to [0, 1] and then normalized to [-1, 1]. The pair is then
    scored with ``compute_lpips_img``. Inputs that are neither ``str`` nor
    ``np.ndarray`` (e.g. already-prepared tensors) are passed through
    unchanged.

    Args:
        image_true: reference image, as a file path or a NumPy array.
        image_test: image to compare, as a file path or a NumPy array.
        net: LPIPS backbone forwarded to ``compute_lpips_img``
            (``"alex"`` by default, or ``"vgg"``).

    Returns:
        The LPIPS score as a float.
    """
    import torch
    from PIL import Image
    from torchvision.transforms.v2.functional import (
        convert_image_dtype,
        normalize,
        pil_to_tensor,
    )

    def pil_to_lpips_tensor(pil_img):
        # Force 3 channels so grayscale / RGBA / palette images line up with
        # the 3-channel mean/std normalization below.
        img = pil_to_tensor(pil_img.convert("RGB"))
        # uint8 [0, 255] -> float32 [0, 1]
        img = convert_image_dtype(img, dtype=torch.float32)
        # [0, 1] -> [-1, 1]
        return normalize(img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    def to_tensor(image):
        if isinstance(image, str):
            # The context manager closes the file handle Image.open keeps open.
            with Image.open(image) as pil:
                return pil_to_lpips_tensor(pil)
        if isinstance(image, np.ndarray):
            # NOTE(review): assumes an (H, W) or (H, W, C) uint8 array, as
            # Image.fromarray requires. Channel order is used as-is (frames
            # decoded with OpenCV would be BGR) -- confirm against callers.
            return pil_to_lpips_tensor(Image.fromarray(image))
        return image

    return compute_lpips_img(
        to_tensor(image_true),
        to_tensor(image_test),
        net=net,
    )
|
|
52
|
+
|
|
53
|
+
|
|
24
54
|
def compute_psnr_file(
|
|
25
55
|
image_true: np.ndarray | str,
|
|
26
56
|
image_test: np.ndarray | str,
|
|
@@ -305,6 +335,11 @@ def compute_video_metric(
|
|
|
305
335
|
return None, None
|
|
306
336
|
|
|
307
337
|
|
|
338
|
+
# Directory-level LPIPS: compute_dir_metric bound to compute_lpips_file as its
# per-file metric function.
compute_lpips = partial(compute_dir_metric, compute_file_func=compute_lpips_file)
|
|
342
|
+
|
|
308
343
|
compute_psnr = partial(
|
|
309
344
|
compute_dir_metric,
|
|
310
345
|
compute_file_func=compute_psnr_file,
|
|
@@ -320,6 +355,10 @@ compute_mse = partial(
|
|
|
320
355
|
compute_file_func=compute_mse_file,
|
|
321
356
|
)
|
|
322
357
|
|
|
358
|
+
# Video-level LPIPS: compute_video_metric bound to compute_lpips_file as its
# per-frame metric function.
compute_video_lpips = partial(
    compute_video_metric, compute_frame_func=compute_lpips_file
)
|
|
323
362
|
compute_video_psnr = partial(
|
|
324
363
|
compute_video_metric,
|
|
325
364
|
compute_frame_func=compute_psnr_file,
|
|
@@ -335,6 +374,7 @@ compute_video_mse = partial(
|
|
|
335
374
|
|
|
336
375
|
|
|
337
376
|
METRICS_CHOICES = [
|
|
377
|
+
"lpips",
|
|
338
378
|
"psnr",
|
|
339
379
|
"ssim",
|
|
340
380
|
"mse",
|
|
@@ -522,6 +562,9 @@ def entrypoint():
|
|
|
522
562
|
METRICS_META[msg] = value
|
|
523
563
|
logger.info(msg)
|
|
524
564
|
|
|
565
|
+
if metric == "lpips" or metric == "all":
|
|
566
|
+
img_lpips, n = compute_lpips(img_true, img_test)
|
|
567
|
+
_logging_msg(img_lpips, "lpips", n)
|
|
525
568
|
if metric == "psnr" or metric == "all":
|
|
526
569
|
img_psnr, n = compute_psnr(img_true, img_test)
|
|
527
570
|
_logging_msg(img_psnr, "psnr", n)
|
|
@@ -558,6 +601,9 @@ def entrypoint():
|
|
|
558
601
|
METRICS_META[msg] = value
|
|
559
602
|
logger.info(msg)
|
|
560
603
|
|
|
604
|
+
if metric == "lpips" or metric == "all":
|
|
605
|
+
video_lpips, n = compute_video_lpips(video_true, video_test)
|
|
606
|
+
_logging_msg(video_lpips, "lpips", n)
|
|
561
607
|
if metric == "psnr" or metric == "all":
|
|
562
608
|
video_psnr, n = compute_video_psnr(video_true, video_test)
|
|
563
609
|
_logging_msg(video_psnr, "psnr", n)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cache_dit
|
|
3
|
-
Version: 0.2.12
|
|
3
|
+
Version: 0.2.13
|
|
4
4
|
Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
|
|
5
5
|
Author: DefTruth, vipshop.com, etc.
|
|
6
6
|
Maintainer: DefTruth, vipshop.com, etc
|
|
@@ -16,6 +16,7 @@ Requires-Dist: transformers>=4.51.3
|
|
|
16
16
|
Requires-Dist: diffusers>=0.33.1
|
|
17
17
|
Requires-Dist: scikit-image
|
|
18
18
|
Requires-Dist: scipy
|
|
19
|
+
Requires-Dist: lpips==0.1.4
|
|
19
20
|
Provides-Extra: all
|
|
20
21
|
Provides-Extra: dev
|
|
21
22
|
Requires-Dist: pre-commit; extra == "dev"
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
cache_dit/_version.py,sha256=
|
|
2
|
+
cache_dit/_version.py,sha256=2ECxD0Bipdh9vxnyteM0k9jxi9NOpPR7YxTi7Ad1ors,513
|
|
3
3
|
cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
|
|
4
4
|
cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
|
|
5
5
|
cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
|
|
@@ -38,10 +38,11 @@ cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE
|
|
|
38
38
|
cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
|
|
39
39
|
cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
|
|
40
40
|
cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
|
|
41
|
-
cache_dit/metrics/
|
|
42
|
-
cache_dit
|
|
43
|
-
cache_dit-0.2.
|
|
44
|
-
cache_dit-0.2.
|
|
45
|
-
cache_dit-0.2.
|
|
46
|
-
cache_dit-0.2.
|
|
47
|
-
cache_dit-0.2.
|
|
41
|
+
cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
|
|
42
|
+
cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
|
|
43
|
+
cache_dit-0.2.13.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
|
|
44
|
+
cache_dit-0.2.13.dist-info/METADATA,sha256=at8DNFeGI5aVnBTi7_6zJgAi_QdgsItpBMzSGl8HEME,28247
|
|
45
|
+
cache_dit-0.2.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
46
|
+
cache_dit-0.2.13.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
|
|
47
|
+
cache_dit-0.2.13.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
|
|
48
|
+
cache_dit-0.2.13.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|