cache-dit 0.2.33__py3-none-any.whl → 0.2.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +5 -3
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +14 -20
- cache_dit/cache_factory/block_adapters/block_adapters.py +46 -2
- cache_dit/cache_factory/block_adapters/block_registers.py +3 -2
- cache_dit/cache_factory/cache_adapters.py +8 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +11 -11
- cache_dit/cache_factory/cache_contexts/cache_manager.py +5 -5
- cache_dit/cache_factory/cache_contexts/taylorseer.py +12 -6
- cache_dit/cache_factory/cache_interface.py +9 -9
- cache_dit/cache_factory/patch_functors/__init__.py +1 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +142 -52
- cache_dit/cache_factory/patch_functors/functor_dit.py +130 -0
- cache_dit/metrics/clip_score.py +135 -0
- cache_dit/metrics/fid.py +42 -0
- cache_dit/metrics/image_reward.py +177 -0
- cache_dit/metrics/lpips.py +2 -14
- cache_dit/metrics/metrics.py +420 -76
- cache_dit/utils.py +15 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/METADATA +261 -52
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/RECORD +25 -22
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import pathlib
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from PIL import Image
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
import torch
|
|
10
|
+
import ImageReward as RM
|
|
11
|
+
import torchvision.transforms.v2.functional as TF
|
|
12
|
+
import torchvision.transforms.v2 as T
|
|
13
|
+
|
|
14
|
+
from typing import Tuple, Union
|
|
15
|
+
from cache_dit.metrics.config import _IMAGE_EXTENSIONS
|
|
16
|
+
from cache_dit.metrics.config import get_metrics_verbose
|
|
17
|
+
from cache_dit.utils import disable_print
|
|
18
|
+
from cache_dit.logger import init_logger
|
|
19
|
+
|
|
20
|
+
# Suppress every warning category once this module is imported.
# NOTE(review): this is a process-wide side effect, not scoped to this module.
warnings.filterwarnings(action="ignore")


logger = init_logger(__name__)


# Inverse of the metrics verbosity flag, captured once at import time.
DISABLE_VERBOSE: bool = not get_metrics_verbose()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ImageRewardScore:
    """Score how well an image matches a text prompt using the ImageReward model.

    The model is loaded once in ``__init__`` and reused by every call to
    :meth:`compute_reward_score`.
    """

    def __init__(
        self,
        device="cuda" if torch.cuda.is_available() else "cpu",
        imagereward_model_path: str = None,
    ):
        """Load the ImageReward model and build the image preprocessing pipeline.

        Args:
            device: Torch device the model and input tensors are moved to. The
                default is evaluated once, when the class is defined.
            imagereward_model_path: Local directory containing ``ImageReward.pt``
                and ``med_config.json``. When ``None``, the
                ``IMAGEREWARD_MODEL_DIR`` environment variable is used; if that
                is unset as well, ``"ImageReward-v1.0"`` is passed to ``RM.load``.
        """
        self.device = device
        if imagereward_model_path is None:
            imagereward_model_path = os.environ.get("IMAGEREWARD_MODEL_DIR", None)

        if imagereward_model_path is not None:
            # Load ImageReward model from a local directory.
            self.med_config = os.path.join(imagereward_model_path, "med_config.json")
            self.imagereward_path = os.path.join(
                imagereward_model_path, "ImageReward.pt"
            )
            self.imagereward_model = RM.load(
                self.imagereward_path,
                download_root=imagereward_model_path,
                med_config=self.med_config,
            ).to(self.device)
        else:
            # Only build local paths when a directory is known:
            # os.path.join(None, ...) raises TypeError and would make this
            # fallback unreachable.
            self.med_config = None
            self.imagereward_path = None
            self.imagereward_model = RM.load(
                "ImageReward-v1.0",  # download from huggingface
            ).to(self.device)

        # ImageReward transform: 224x224 bicubic resize + center crop, then
        # float32 in [0, 1] normalized with the given per-channel mean/std.
        self.reward_transform = T.Compose(
            [
                T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
                T.CenterCrop(224),
                T.ToImage(),
                T.ToDtype(torch.float32, scale=True),
                T.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )

    @torch.no_grad()
    def compute_reward_score(
        self,
        img: Image.Image | np.ndarray | str,
        prompt: str,
    ) -> float:
        """Return the ImageReward score of ``img`` for ``prompt``.

        Args:
            img: A PIL image, a numpy array accepted by ``Image.fromarray``, or
                anything ``Image.open`` accepts (e.g. a file path).
            prompt: Text prompt the image is scored against.

        Returns:
            The scalar score as a Python float.
        """
        if isinstance(img, Image.Image):
            img_pil = img.convert("RGB")
        elif isinstance(img, np.ndarray):
            img_pil = Image.fromarray(img).convert("RGB")
        else:
            # Close the file handle deterministically; convert() returns a
            # new image, so it stays valid after the file is closed.
            with Image.open(img) as opened:
                img_pil = opened.convert("RGB")

        # Add a batch dimension, move to device, then preprocess.
        img_tensor = TF.pil_to_tensor(img_pil).unsqueeze(0).to(self.device)
        img_reward = self.reward_transform(img_tensor)
        inputs = self.imagereward_model.blip.tokenizer(
            [prompt],
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)
        # `score_gard` is the method name exposed by the ImageReward model;
        # it yields a single-element tensor here.
        score = self.imagereward_model.score_gard(
            inputs.input_ids, inputs.attention_mask, img_reward
        )
        return score.item()
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# Process-wide cache of the loaded scorer so the model is loaded only once.
image_reward_score_instance: ImageRewardScore = None
# Model directory explicitly requested when the cached scorer was built
# (None means "default resolution" inside ImageRewardScore).
_image_reward_model_path: str = None


def compute_reward_score_img(
    img: Image.Image | np.ndarray | str,
    prompt: str,
    imagereward_model_path: str = None,
) -> float:
    """Score a single image against ``prompt`` with a cached ImageReward model.

    The scorer is created on first use and reused afterwards. If a caller
    explicitly asks for a different ``imagereward_model_path`` than the one
    the cached scorer was built with, the scorer is rebuilt instead of
    silently scoring with the previously loaded model. Passing ``None``
    reuses whatever scorer is cached.

    Args:
        img: Image, numpy array, or path accepted by
            ``ImageRewardScore.compute_reward_score``.
        prompt: Text prompt to score against.
        imagereward_model_path: Optional local model directory.

    Returns:
        The ImageReward score as a float.
    """
    global image_reward_score_instance, _image_reward_model_path
    needs_load = image_reward_score_instance is None or (
        imagereward_model_path is not None
        and imagereward_model_path != _image_reward_model_path
    )
    if needs_load:
        with disable_print():
            image_reward_score_instance = ImageRewardScore(
                imagereward_model_path=imagereward_model_path
            )
        _image_reward_model_path = imagereward_model_path
    return image_reward_score_instance.compute_reward_score(img, prompt)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
_TRAILING_NUMBER_RE = re.compile(r"(\d+)\D*$")


def _natural_sort_key(filename: str):
    """Sort key ordering paths by their last run of digits.

    Paths containing a number sort (numerically) before paths without one,
    which sort lexicographically. Tagging each key with 0/1 keeps ints and
    strings from ever being compared with each other, which raised TypeError
    when a directory mixed numbered and unnumbered file names.
    """
    match = _TRAILING_NUMBER_RE.search(filename)
    if match:
        return (0, int(match.group(1)))
    return (1, filename)


def compute_reward_score(
    img_dir: Image.Image | np.ndarray | str,
    prompts: str | list[str],
    imagereward_model_path: str = None,
) -> Union[Tuple[float, int], Tuple[None, None]]:
    """Compute the ImageReward score for one image or a directory of images.

    Single-image mode is used when ``img_dir`` is not a directory, or when
    ``prompts`` is neither a list nor an existing file. In that mode the
    result is ``(score, 1)``.

    Directory mode: images under ``img_dir`` (recursively, matching
    ``_IMAGE_EXTENSIONS``) are sorted by their trailing number and paired
    with ``prompts``, which is either a list of strings or a UTF-8 text file
    with one prompt per line. Only the first ``min(#images, #prompts)`` pairs
    are scored.

    Returns:
        ``(mean_score, count)`` for the scored pairs, or ``(None, None)`` if
        nothing could be paired.
    """
    # os.path.isdir/isfile raise TypeError on non-path objects (PIL images,
    # arrays, lists), so check the type before touching the filesystem.
    is_dir = isinstance(img_dir, (str, bytes, os.PathLike)) and os.path.isdir(img_dir)
    prompts_is_list = isinstance(prompts, list)
    if not is_dir or (not prompts_is_list and not os.path.isfile(prompts)):
        return (
            compute_reward_score_img(
                img_dir,
                prompts,
                imagereward_model_path=imagereward_model_path,
            ),
            1,
        )

    # compute dir metric
    img_root = pathlib.Path(img_dir)
    img_files = sorted(
        (
            file.as_posix()
            for ext in _IMAGE_EXTENSIONS
            for file in img_root.rglob("*.{}".format(ext))
        ),
        key=_natural_sort_key,
    )

    if not prompts_is_list:
        # Reaching here means `prompts` is an existing file: one prompt per line.
        with open(prompts, "r", encoding="utf-8") as f:
            prompts = [line.strip() for line in f]

    vaild_len = min(len(img_files), len(prompts))
    img_files = img_files[:vaild_len]
    prompts = prompts[:vaild_len]  # slicing copies; the caller's list is untouched

    reward_scores = [
        compute_reward_score_img(
            img_file,
            prompt,
            imagereward_model_path=imagereward_model_path,
        )
        for img_file, prompt in tqdm(
            zip(img_files, prompts),
            total=vaild_len,
            disable=not get_metrics_verbose(),
        )
    ]

    if vaild_len > 0:
        return np.mean(reward_scores), vaild_len
    return None, None
|
cache_dit/metrics/lpips.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
import builtins as __builtin__
|
|
2
|
-
import contextlib
|
|
3
1
|
import warnings
|
|
4
2
|
|
|
5
3
|
import lpips
|
|
6
4
|
import torch
|
|
5
|
+
from cache_dit.utils import disable_print
|
|
6
|
+
|
|
7
7
|
|
|
8
8
|
warnings.filterwarnings("ignore")
|
|
9
9
|
|
|
@@ -11,18 +11,6 @@ lpips_loss_fn_vgg = None
|
|
|
11
11
|
# Module-level slot for the AlexNet-backbone LPIPS loss function; declared
# `global` in compute_lpips_img. NOTE(review): presumably populated lazily on
# first use — the body that assigns it is not visible here, confirm.
lpips_loss_fn_alex = None
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
def dummy_print(*args, **kwargs):
    """No-op stand-in for ``print``; accepts and ignores any arguments."""
    pass


@contextlib.contextmanager
def disable_print():
    """Temporarily replace the builtin ``print`` with :func:`dummy_print`.

    The original ``print`` is restored when the ``with`` block exits, including
    when it raises; without the ``finally``, an exception inside the block
    would leave ``print`` silenced for the rest of the process.
    """
    origin_print = __builtin__.print
    __builtin__.print = dummy_print
    try:
        yield
    finally:
        __builtin__.print = origin_print
|
|
24
|
-
|
|
25
|
-
|
|
26
14
|
def compute_lpips_img(img0, img1, net: str = "alex"):
|
|
27
15
|
global lpips_loss_fn_vgg
|
|
28
16
|
global lpips_loss_fn_alex
|