cache-dit 0.2.34-py3-none-any.whl → 0.2.37-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Potentially problematic release.
- cache_dit/__init__.py +5 -3
- cache_dit/_version.py +2 -2
- cache_dit/metrics/clip_score.py +135 -0
- cache_dit/metrics/fid.py +42 -0
- cache_dit/metrics/image_reward.py +177 -0
- cache_dit/metrics/lpips.py +2 -14
- cache_dit/metrics/metrics.py +449 -93
- cache_dit/utils.py +15 -0
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/METADATA +142 -35
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/RECORD +14 -12
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.34.dist-info → cache_dit-0.2.37.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED

@@ -4,6 +4,10 @@ except ImportError:
     __version__ = "unknown version"
     version_tuple = (0, 0, "unknown version")
 
+from cache_dit.utils import summary
+from cache_dit.utils import strify
+from cache_dit.utils import disable_print
+from cache_dit.logger import init_logger
 from cache_dit.cache_factory import load_options
 from cache_dit.cache_factory import enable_cache
 from cache_dit.cache_factory import disable_cache
@@ -18,9 +22,7 @@ from cache_dit.cache_factory import supported_pipelines
 from cache_dit.cache_factory import get_adapter
 from cache_dit.compile import set_compile_configs
 from cache_dit.quantize import quantize
-
-from cache_dit.utils import strify
-from cache_dit.logger import init_logger
+
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
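The net effect of this hunk is that `summary` and `disable_print` join `strify` as top-level re-exports. A minimal sketch of the new import surface, assuming only what the diff shows (the suppressed print below is illustrative):

```python
import cache_dit

# New in 0.2.37: summary and disable_print are re-exported at the package
# root, alongside strify (which was already exported in 0.2.34).
from cache_dit import summary, strify, disable_print

# disable_print is used as a context manager elsewhere in this diff, so it
# presumably silences print() inside the block:
with disable_print():
    print("this line should be swallowed")
```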
cache_dit/_version.py
CHANGED

@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.34'
-__version_tuple__ = version_tuple = (0, 2, 34)
+__version__ = version = '0.2.37'
+__version_tuple__ = version_tuple = (0, 2, 37)
 
 __commit_id__ = commit_id = None
cache_dit/metrics/clip_score.py
ADDED

@@ -0,0 +1,135 @@
+import os
+import re
+import pathlib
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import torch
+from transformers import CLIPProcessor, CLIPModel
+
+from typing import Tuple, Union
+from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+from cache_dit.metrics.config import get_metrics_verbose
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+DISABLE_VERBOSE = not get_metrics_verbose()
+
+
+class CLIPScore:
+    def __init__(
+        self,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        clip_model_path: str = None,
+    ):
+        self.device = device
+        if clip_model_path is None:
+            clip_model_path = os.environ.get(
+                "CLIP_MODEL_DIR", "laion/CLIP-ViT-g-14-laion2B-s12B-b42K"
+            )
+
+        # Load models
+        self.clip_model = CLIPModel.from_pretrained(clip_model_path)
+        self.clip_model = self.clip_model.to(device) # type: ignore
+        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_path)
+
+    @torch.no_grad()
+    def compute_clip_score(
+        self,
+        img: Image.Image | np.ndarray,
+        prompt: str,
+    ) -> float:
+        if isinstance(img, Image.Image):
+            img_pil = img.convert("RGB")
+        elif isinstance(img, np.ndarray):
+            img_pil = Image.fromarray(img).convert("RGB")
+        else:
+            img_pil = Image.open(img).convert("RGB")
+        with torch.no_grad():
+            inputs = self.clip_processor(
+                text=prompt,
+                images=img_pil,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+            ).to(self.device)
+            outputs = self.clip_model(**inputs)
+            return outputs.logits_per_image.item()
+
+
+clip_score_instance: CLIPScore = None
+
+
+def compute_clip_score_img(
+    img: Image.Image | np.ndarray | str,
+    prompt: str,
+    clip_model_path: str = None,
+) -> float:
+    global clip_score_instance
+    if clip_score_instance is None:
+        clip_score_instance = CLIPScore(clip_model_path=clip_model_path)
+    assert clip_score_instance is not None
+    return clip_score_instance.compute_clip_score(img, prompt)
+
+
+def compute_clip_score(
+    img_dir: Image.Image | np.ndarray | str,
+    prompts: str | list[str],
+    clip_model_path: str = None,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    if not os.path.isdir(img_dir) or (
+        not isinstance(prompts, list) and not os.path.isfile(prompts)
+    ):
+        return (
+            compute_clip_score_img(
+                img_dir,
+                prompts,
+                clip_model_path=clip_model_path,
+            ),
+            1,
+        )
+
+    # compute dir metric
+    def natural_sort_key(filename):
+        match = re.search(r"(\d+)\D*$", filename)
+        return int(match.group(1)) if match else filename
+
+    img_dir: pathlib.Path = pathlib.Path(img_dir)
+    img_files = [
+        file
+        for ext in _IMAGE_EXTENSIONS
+        for file in img_dir.rglob("*.{}".format(ext))
+    ]
+    img_files = [file.as_posix() for file in img_files]
+    img_files = sorted(img_files, key=natural_sort_key)
+
+    if os.path.isfile(prompts):
+        """Load prompts from file"""
+        with open(prompts, "r", encoding="utf-8") as f:
+            prompts_load = [line.strip() for line in f.readlines()]
+            prompts = prompts_load.copy()
+
+    vaild_len = min(len(img_files), len(prompts))
+    img_files = img_files[:vaild_len]
+    prompts = prompts[:vaild_len]
+
+    clip_scores = []
+
+    for img_file, prompt in tqdm(
+        zip(img_files, prompts),
+        total=vaild_len,
+        disable=not get_metrics_verbose(),
+    ):
+        clip_scores.append(
+            compute_clip_score_img(
+                img_file,
+                prompt,
+                clip_model_path=clip_model_path,
+            )
+        )
+
+    if vaild_len > 0:
+        return np.mean(clip_scores), vaild_len
+    return None, None
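A usage sketch for the new module, based on the code above; the paths and prompt are illustrative. In directory mode, images are discovered via `_IMAGE_EXTENSIONS`, naturally sorted, and paired with one prompt per line from a prompts file (or a prompt list); any non-directory input falls through to `compute_clip_score_img` and returns `(score, 1)`:

```python
from cache_dit.metrics.clip_score import compute_clip_score

# Directory mode: mean CLIP score over min(len(images), len(prompts)) pairs.
# CLIP_MODEL_DIR may point to a local checkpoint; the default is
# "laion/CLIP-ViT-g-14-laion2B-s12B-b42K" from the Hugging Face Hub.
score, n = compute_clip_score("outputs/", "prompts.txt")
if score is not None:
    print(f"CLIP score: {score:.4f} over {n} images")

# Single-image mode: a non-directory input is scored against one prompt.
score, n = compute_clip_score("outputs/0001.png", "a cat wearing a hat")
```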
cache_dit/metrics/fid.py
CHANGED

@@ -1,6 +1,8 @@
 import os
 import cv2
 import pathlib
+import warnings
+
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
@@ -8,13 +10,21 @@ from scipy import linalg
 import torch
 import torchvision.transforms as TF
 from torch.nn.functional import adaptive_avg_pool2d
+
+from typing import Tuple, Union
 from cache_dit.metrics.inception import InceptionV3
 from cache_dit.metrics.config import _IMAGE_EXTENSIONS
 from cache_dit.metrics.config import _VIDEO_EXTENSIONS
+from cache_dit.metrics.config import get_metrics_verbose
+from cache_dit.utils import disable_print
 from cache_dit.logger import init_logger
 
+warnings.filterwarnings("ignore")
+
 logger = init_logger(__name__)
 
+DISABLE_VERBOSE = not get_metrics_verbose()
+
 
 # Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
 class ImagePathDataset(torch.utils.data.Dataset):
@@ -496,3 +506,35 @@ class FrechetInceptionDistance:
             return [], [], 0
 
         return video_true_frames, video_test_frames, valid_frames
+
+
+fid_instance: FrechetInceptionDistance = None
+
+
+def compute_fid(
+    image_true: np.ndarray | str,
+    image_test: np.ndarray | str,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    global fid_instance
+    if fid_instance is None:
+        with disable_print():
+            fid_instance = FrechetInceptionDistance(
+                disable_tqdm=not get_metrics_verbose(),
+            )
+    assert fid_instance is not None
+    return fid_instance.compute_fid(image_true, image_test)
+
+
+def compute_video_fid(
+    # file or dir
+    video_true: str,
+    video_test: str,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    global fid_instance
+    if fid_instance is None:
+        with disable_print():
+            fid_instance = FrechetInceptionDistance(
+                disable_tqdm=not get_metrics_verbose(),
+            )
+    assert fid_instance is not None
+    return fid_instance.compute_fid(video_true, video_test)
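A usage sketch for the two new module-level wrappers; the paths are illustrative. Both lazily construct a single shared `FrechetInceptionDistance` instance, silencing its construction output with `disable_print()`:

```python
from cache_dit.metrics.fid import compute_fid, compute_video_fid

# Image FID between two directories of images (or two single images).
fid, n = compute_fid("ref_images/", "gen_images/")

# Video FID between two video files or two directories of videos; note
# that, as written, the wrapper delegates to the same compute_fid method
# on the shared instance rather than a video-specific entry point.
video_fid, m = compute_video_fid("ref.mp4", "gen.mp4")
```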
cache_dit/metrics/image_reward.py
ADDED

@@ -0,0 +1,177 @@
+import os
+import re
+import pathlib
+import warnings
+
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import torch
+import ImageReward as RM
+import torchvision.transforms.v2.functional as TF
+import torchvision.transforms.v2 as T
+
+from typing import Tuple, Union
+from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+from cache_dit.metrics.config import get_metrics_verbose
+from cache_dit.utils import disable_print
+from cache_dit.logger import init_logger
+
+warnings.filterwarnings("ignore")
+
+
+logger = init_logger(__name__)
+
+
+DISABLE_VERBOSE = not get_metrics_verbose()
+
+
+class ImageRewardScore:
+    def __init__(
+        self,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        imagereward_model_path: str = None,
+    ):
+        self.device = device
+        if imagereward_model_path is None:
+            imagereward_model_path = os.environ.get(
+                "IMAGEREWARD_MODEL_DIR", None
+            )
+
+        # Load ImageReward model
+        self.med_config = os.path.join(
+            imagereward_model_path, "med_config.json"
+        )
+        self.imagereward_path = os.path.join(
+            imagereward_model_path, "ImageReward.pt"
+        )
+        if imagereward_model_path is not None:
+            self.imagereward_model = RM.load(
+                self.imagereward_path,
+                download_root=imagereward_model_path,
+                med_config=self.med_config,
+            ).to(self.device)
+        else:
+            self.imagereward_model = RM.load(
+                "ImageReward-v1.0", # download from huggingface
+            ).to(self.device)
+
+        # ImageReward transform
+        self.reward_transform = T.Compose(
+            [
+                T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
+                T.CenterCrop(224),
+                T.ToImage(),
+                T.ToDtype(torch.float32, scale=True),
+                T.Normalize(
+                    (0.48145466, 0.4578275, 0.40821073),
+                    (0.26862954, 0.26130258, 0.27577711),
+                ),
+            ]
+        )
+
+    @torch.no_grad()
+    def compute_reward_score(
+        self,
+        img: Image.Image | np.ndarray,
+        prompt: str,
+    ) -> float:
+        if isinstance(img, Image.Image):
+            img_pil = img.convert("RGB")
+        elif isinstance(img, np.ndarray):
+            img_pil = Image.fromarray(img).convert("RGB")
+        else:
+            img_pil = Image.open(img).convert("RGB")
+        with torch.no_grad():
+            img_tensor = TF.pil_to_tensor(img_pil).unsqueeze(0).to(self.device)
+            img_reward = self.reward_transform(img_tensor)
+            inputs = self.imagereward_model.blip.tokenizer(
+                [prompt],
+                padding="max_length",
+                truncation=True,
+                max_length=512,
+                return_tensors="pt",
+            ).to(self.device)
+            score = self.imagereward_model.score_gard(
+                inputs.input_ids, inputs.attention_mask, img_reward
+            )
+        return score.item()
+
+
+image_reward_score_instance: ImageRewardScore = None
+
+
+def compute_reward_score_img(
+    img: Image.Image | np.ndarray | str,
+    prompt: str,
+    imagereward_model_path: str = None,
+) -> float:
+    global image_reward_score_instance
+    if image_reward_score_instance is None:
+        with disable_print():
+            image_reward_score_instance = ImageRewardScore(
+                imagereward_model_path=imagereward_model_path
+            )
+    assert image_reward_score_instance is not None
+    return image_reward_score_instance.compute_reward_score(img, prompt)
+
+
+def compute_reward_score(
+    img_dir: Image.Image | np.ndarray | str,
+    prompts: str | list[str],
+    imagereward_model_path: str = None,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    if not os.path.isdir(img_dir) or (
+        not isinstance(prompts, list) and not os.path.isfile(prompts)
+    ):
+        return (
+            compute_reward_score_img(
+                img_dir,
+                prompts,
+                imagereward_model_path=imagereward_model_path,
+            ),
+            1,
+        )
+
+    # compute dir metric
+    def natural_sort_key(filename):
+        match = re.search(r"(\d+)\D*$", filename)
+        return int(match.group(1)) if match else filename
+
+    img_dir: pathlib.Path = pathlib.Path(img_dir)
+    img_files = [
+        file
+        for ext in _IMAGE_EXTENSIONS
+        for file in img_dir.rglob("*.{}".format(ext))
+    ]
+    img_files = [file.as_posix() for file in img_files]
+    img_files = sorted(img_files, key=natural_sort_key)
+
+    if os.path.isfile(prompts):
+        """Load prompts from file"""
+        with open(prompts, "r", encoding="utf-8") as f:
+            prompts_load = [line.strip() for line in f.readlines()]
+            prompts = prompts_load.copy()
+
+    vaild_len = min(len(img_files), len(prompts))
+    img_files = img_files[:vaild_len]
+    prompts = prompts[:vaild_len]
+
+    reward_scores = []
+
+    for img_file, prompt in tqdm(
+        zip(img_files, prompts),
+        total=vaild_len,
+        disable=not get_metrics_verbose(),
+    ):
+        reward_scores.append(
+            compute_reward_score_img(
+                img_file,
+                prompt,
+                imagereward_model_path=imagereward_model_path,
+            )
+        )
+
+    if vaild_len > 0:
+        return np.mean(reward_scores), vaild_len
+    return None, None
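A usage sketch for the new ImageReward wrapper; the paths are illustrative. One caveat visible in the code above: `__init__` joins `imagereward_model_path` into the `med_config.json`/`ImageReward.pt` paths before the `is not None` check, so in practice `IMAGEREWARD_MODEL_DIR` (or an explicit `imagereward_model_path`) appears to be required for construction to succeed:

```python
import os
from cache_dit.metrics.image_reward import compute_reward_score

# Point at a local directory containing med_config.json and ImageReward.pt
# (illustrative path).
os.environ["IMAGEREWARD_MODEL_DIR"] = "/models/ImageReward"

# Same directory-vs-single-image dispatch as compute_clip_score.
score, n = compute_reward_score("outputs/", "prompts.txt")
if score is not None:
    print(f"ImageReward: {score:.4f} over {n} images")
```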
cache_dit/metrics/lpips.py
CHANGED

@@ -1,9 +1,9 @@
-import builtins as __builtin__
-import contextlib
 import warnings
 
 import lpips
 import torch
+from cache_dit.utils import disable_print
+
 
 warnings.filterwarnings("ignore")
 
@@ -11,18 +11,6 @@ lpips_loss_fn_vgg = None
 lpips_loss_fn_alex = None
 
 
-def dummy_print(*args, **kwargs):
-    pass
-
-
-@contextlib.contextmanager
-def disable_print():
-    origin_print = __builtin__.print
-    __builtin__.print = dummy_print
-    yield
-    __builtin__.print = origin_print
-
-
 def compute_lpips_img(img0, img1, net: str = "alex"):
     global lpips_loss_fn_vgg
     global lpips_loss_fn_alex