cache-dit 0.2.33__py3-none-any.whl → 0.2.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,177 @@
+ import os
+ import re
+ import pathlib
+ import warnings
+
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+ import torch
+ import ImageReward as RM
+ import torchvision.transforms.v2.functional as TF
+ import torchvision.transforms.v2 as T
+
+ from typing import Tuple, Union
+ from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+ from cache_dit.metrics.config import get_metrics_verbose
+ from cache_dit.utils import disable_print
+ from cache_dit.logger import init_logger
+
+ warnings.filterwarnings("ignore")
+
+
+ logger = init_logger(__name__)
+
+
+ DISABLE_VERBOSE = not get_metrics_verbose()
+
+
+ class ImageRewardScore:
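+     """Scores how well a generated image matches its text prompt using the ImageReward model."""
+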
+     def __init__(
+         self,
+         device="cuda" if torch.cuda.is_available() else "cpu",
+         imagereward_model_path: str = None,
+     ):
+         self.device = device
+         if imagereward_model_path is None:
+             imagereward_model_path = os.environ.get(
+                 "IMAGEREWARD_MODEL_DIR", None
+             )
+
+         # Load the ImageReward model: from a local directory when a path is
+         # given (or found via IMAGEREWARD_MODEL_DIR), otherwise download the
+         # default "ImageReward-v1.0" weights from huggingface.
+         if imagereward_model_path is not None:
+             self.med_config = os.path.join(
+                 imagereward_model_path, "med_config.json"
+             )
+             self.imagereward_path = os.path.join(
+                 imagereward_model_path, "ImageReward.pt"
+             )
+             self.imagereward_model = RM.load(
+                 self.imagereward_path,
+                 download_root=imagereward_model_path,
+                 med_config=self.med_config,
+             ).to(self.device)
+         else:
+             self.imagereward_model = RM.load(
+                 "ImageReward-v1.0",  # download from huggingface
+             ).to(self.device)
+
+         # ImageReward transform: 224x224 bicubic resize + center crop with
+         # CLIP-style normalization.
+         self.reward_transform = T.Compose(
+             [
+                 T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
+                 T.CenterCrop(224),
+                 T.ToImage(),
+                 T.ToDtype(torch.float32, scale=True),
+                 T.Normalize(
+                     (0.48145466, 0.4578275, 0.40821073),
+                     (0.26862954, 0.26130258, 0.27577711),
+                 ),
+             ]
+         )
+
+     @torch.no_grad()
+     def compute_reward_score(
+         self,
+         img: Image.Image | np.ndarray,
+         prompt: str,
+     ) -> float:
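+         """Return the ImageReward score of a single image for `prompt`.
+
+         `img` may be a PIL image, a NumPy array, or anything that
+         `PIL.Image.open` accepts (e.g. a file path).
+         """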
+         if isinstance(img, Image.Image):
+             img_pil = img.convert("RGB")
+         elif isinstance(img, np.ndarray):
+             img_pil = Image.fromarray(img).convert("RGB")
+         else:
+             img_pil = Image.open(img).convert("RGB")
+         with torch.no_grad():
+             img_tensor = TF.pil_to_tensor(img_pil).unsqueeze(0).to(self.device)
+             img_reward = self.reward_transform(img_tensor)
+             inputs = self.imagereward_model.blip.tokenizer(
+                 [prompt],
+                 padding="max_length",
+                 truncation=True,
+                 max_length=512,
+                 return_tensors="pt",
+             ).to(self.device)
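+             # `score_gard` (sic) is the method name used by the upstream
+             # ImageReward package for scoring pre-tokenized prompts and
+             # image tensors.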
+             score = self.imagereward_model.score_gard(
+                 inputs.input_ids, inputs.attention_mask, img_reward
+             )
+         return score.item()
+
+
+ image_reward_score_instance: ImageRewardScore = None
+
+
+ def compute_reward_score_img(
+     img: Image.Image | np.ndarray | str,
+     prompt: str,
+     imagereward_model_path: str = None,
+ ) -> float:
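+     """Score one image against one prompt with ImageReward.
+
+     The `ImageRewardScore` instance is created lazily on first use and
+     cached in a module-level global, so model weights load only once.
+     """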
+     global image_reward_score_instance
+     if image_reward_score_instance is None:
+         with disable_print():
+             image_reward_score_instance = ImageRewardScore(
+                 imagereward_model_path=imagereward_model_path
+             )
+     assert image_reward_score_instance is not None
+     return image_reward_score_instance.compute_reward_score(img, prompt)
+
+
+ def compute_reward_score(
+     img_dir: Image.Image | np.ndarray | str,
+     prompts: str | list[str],
+     imagereward_model_path: str = None,
+ ) -> Union[Tuple[float, int], Tuple[None, None]]:
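+     """Score a directory of images (or a single image) with ImageReward.
+
+     If `img_dir` is a directory and `prompts` is a list or a text file
+     (one prompt per line), images are paired with prompts in natural
+     sort order and the mean score and number of scored pairs are
+     returned. Otherwise the single-image score is returned with a count
+     of 1. Returns (None, None) when nothing could be scored.
+     """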
+     if (
+         not isinstance(img_dir, (str, pathlib.Path))
+         or not os.path.isdir(img_dir)
+         or (not isinstance(prompts, list) and not os.path.isfile(prompts))
+     ):
+         return (
+             compute_reward_score_img(
+                 img_dir,
+                 prompts,
+                 imagereward_model_path=imagereward_model_path,
+             ),
+             1,
+         )
+
+     # Directory mode: collect images, pair them with prompts in natural
+     # sort order, and average the per-image scores.
+     def natural_sort_key(filename):
+         # Sort by the trailing number in the filename, e.g. "img_2.png"
+         # before "img_10.png".
+         match = re.search(r"(\d+)\D*$", filename)
+         return int(match.group(1)) if match else filename
+
+     img_dir: pathlib.Path = pathlib.Path(img_dir)
+     img_files = [
+         file
+         for ext in _IMAGE_EXTENSIONS
+         for file in img_dir.rglob("*.{}".format(ext))
+     ]
+     img_files = [file.as_posix() for file in img_files]
+     img_files = sorted(img_files, key=natural_sort_key)
+
+     if isinstance(prompts, str) and os.path.isfile(prompts):
+         # Load prompts from file, one prompt per line
+         with open(prompts, "r", encoding="utf-8") as f:
+             prompts_load = [line.strip() for line in f.readlines()]
+         prompts = prompts_load.copy()
+
+     valid_len = min(len(img_files), len(prompts))
+     img_files = img_files[:valid_len]
+     prompts = prompts[:valid_len]
+
+     reward_scores = []
+
+     for img_file, prompt in tqdm(
+         zip(img_files, prompts),
+         total=valid_len,
+         disable=not get_metrics_verbose(),
+     ):
+         reward_scores.append(
+             compute_reward_score_img(
+                 img_file,
+                 prompt,
+                 imagereward_model_path=imagereward_model_path,
+             )
+         )
+
+     if valid_len > 0:
+         return np.mean(reward_scores), valid_len
+     return None, None
@@ -1,9 +1,9 @@
- import builtins as __builtin__
- import contextlib
  import warnings
 
  import lpips
  import torch
+ from cache_dit.utils import disable_print
+
 
  warnings.filterwarnings("ignore")
 
@@ -11,18 +11,6 @@ lpips_loss_fn_vgg = None
  lpips_loss_fn_alex = None
 
 
- def dummy_print(*args, **kwargs):
-     pass
-
-
- @contextlib.contextmanager
- def disable_print():
-     origin_print = __builtin__.print
-     __builtin__.print = dummy_print
-     yield
-     __builtin__.print = origin_print
-
-
  def compute_lpips_img(img0, img1, net: str = "alex"):
      global lpips_loss_fn_vgg
      global lpips_loss_fn_alex