cache-dit 0.2.34-py3-none-any.whl → 0.2.37-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


cache_dit/__init__.py CHANGED
@@ -4,6 +4,10 @@ except ImportError:
     __version__ = "unknown version"
     version_tuple = (0, 0, "unknown version")
 
+from cache_dit.utils import summary
+from cache_dit.utils import strify
+from cache_dit.utils import disable_print
+from cache_dit.logger import init_logger
 from cache_dit.cache_factory import load_options
 from cache_dit.cache_factory import enable_cache
 from cache_dit.cache_factory import disable_cache
@@ -18,9 +22,7 @@ from cache_dit.cache_factory import supported_pipelines
 from cache_dit.cache_factory import get_adapter
 from cache_dit.compile import set_compile_configs
 from cache_dit.quantize import quantize
-from cache_dit.utils import summary
-from cache_dit.utils import strify
-from cache_dit.logger import init_logger
+
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
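
These re-exports make the helpers importable straight from the package root. A minimal sketch, assuming a diffusers pipeline `pipe` that has already been cached via `cache_dit.enable_cache` (the pipeline and its setup are hypothetical):

import cache_dit

# summary()/strify() now resolve from the package root rather than
# requiring an explicit cache_dit.utils import.
stats = cache_dit.summary(pipe)   # cache-hit statistics for a cached pipeline
print(cache_dit.strify(stats))    # compact string rendering of the same stats

logger = cache_dit.init_logger(__name__)  # logger factory re-exported from cache_dit.logger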
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.34'
-__version_tuple__ = version_tuple = (0, 2, 34)
+__version__ = version = '0.2.37'
+__version_tuple__ = version_tuple = (0, 2, 37)
 
 __commit_id__ = commit_id = None
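
Only the stamped version changes here. Downstream code that depends on the metrics entry points introduced in this release could gate on it; a small sketch (note that `version_tuple` is re-exported at the package root, with a string fallback when `_version` is missing):

import cache_dit

if cache_dit.version_tuple[:3] >= (0, 2, 37):
    from cache_dit.metrics.fid import compute_fid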
cache_dit/metrics/clip_score.py ADDED
@@ -0,0 +1,135 @@
+import os
+import re
+import pathlib
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import torch
+from transformers import CLIPProcessor, CLIPModel
+
+from typing import Tuple, Union
+from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+from cache_dit.metrics.config import get_metrics_verbose
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+DISABLE_VERBOSE = not get_metrics_verbose()
+
+
+class CLIPScore:
+    def __init__(
+        self,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        clip_model_path: str = None,
+    ):
+        self.device = device
+        if clip_model_path is None:
+            clip_model_path = os.environ.get(
+                "CLIP_MODEL_DIR", "laion/CLIP-ViT-g-14-laion2B-s12B-b42K"
+            )
+
+        # Load models
+        self.clip_model = CLIPModel.from_pretrained(clip_model_path)
+        self.clip_model = self.clip_model.to(device)  # type: ignore
+        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_path)
+
+    @torch.no_grad()
+    def compute_clip_score(
+        self,
+        img: Image.Image | np.ndarray,
+        prompt: str,
+    ) -> float:
+        if isinstance(img, Image.Image):
+            img_pil = img.convert("RGB")
+        elif isinstance(img, np.ndarray):
+            img_pil = Image.fromarray(img).convert("RGB")
+        else:
+            img_pil = Image.open(img).convert("RGB")
+        with torch.no_grad():
+            inputs = self.clip_processor(
+                text=prompt,
+                images=img_pil,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+            ).to(self.device)
+            outputs = self.clip_model(**inputs)
+            return outputs.logits_per_image.item()
+
+
+clip_score_instance: CLIPScore = None
+
+
+def compute_clip_score_img(
+    img: Image.Image | np.ndarray | str,
+    prompt: str,
+    clip_model_path: str = None,
+) -> float:
+    global clip_score_instance
+    if clip_score_instance is None:
+        clip_score_instance = CLIPScore(clip_model_path=clip_model_path)
+    assert clip_score_instance is not None
+    return clip_score_instance.compute_clip_score(img, prompt)
+
+
+def compute_clip_score(
+    img_dir: Image.Image | np.ndarray | str,
+    prompts: str | list[str],
+    clip_model_path: str = None,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    if not os.path.isdir(img_dir) or (
+        not isinstance(prompts, list) and not os.path.isfile(prompts)
+    ):
+        return (
+            compute_clip_score_img(
+                img_dir,
+                prompts,
+                clip_model_path=clip_model_path,
+            ),
+            1,
+        )
+
+    # compute dir metric
+    def natural_sort_key(filename):
+        match = re.search(r"(\d+)\D*$", filename)
+        return int(match.group(1)) if match else filename
+
+    img_dir: pathlib.Path = pathlib.Path(img_dir)
+    img_files = [
+        file
+        for ext in _IMAGE_EXTENSIONS
+        for file in img_dir.rglob("*.{}".format(ext))
+    ]
+    img_files = [file.as_posix() for file in img_files]
+    img_files = sorted(img_files, key=natural_sort_key)
+
+    if os.path.isfile(prompts):
+        """Load prompts from file"""
+        with open(prompts, "r", encoding="utf-8") as f:
+            prompts_load = [line.strip() for line in f.readlines()]
+        prompts = prompts_load.copy()
+
+    vaild_len = min(len(img_files), len(prompts))
+    img_files = img_files[:vaild_len]
+    prompts = prompts[:vaild_len]
+
+    clip_scores = []
+
+    for img_file, prompt in tqdm(
+        zip(img_files, prompts),
+        total=vaild_len,
+        disable=not get_metrics_verbose(),
+    ):
+        clip_scores.append(
+            compute_clip_score_img(
+                img_file,
+                prompt,
+                clip_model_path=clip_model_path,
+            )
+        )
+
+    if vaild_len > 0:
+        return np.mean(clip_scores), vaild_len
+    return None, None
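
A usage sketch for the new module, based on the signatures above (the directory, prompt file, and image names are hypothetical; `CLIP_MODEL_DIR` overrides the default `laion/CLIP-ViT-g-14-laion2B-s12B-b42K` checkpoint):

from cache_dit.metrics.clip_score import compute_clip_score

# Single image vs. prompt: returns (score, 1).
score, n = compute_clip_score("sample_0.png", "a cat wearing sunglasses")

# Directory vs. prompt file (one prompt per line): images are
# natural-sorted by trailing digits, zipped with the prompts, and
# the mean CLIP score plus pair count is returned.
mean_score, n = compute_clip_score("./samples", "prompts.txt")

Note the singleton pattern: the first call instantiates one `CLIPScore` and later calls reuse it, so `clip_model_path` only takes effect on the first call.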
cache_dit/metrics/fid.py CHANGED
@@ -1,6 +1,8 @@
 import os
 import cv2
 import pathlib
+import warnings
+
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
@@ -8,13 +10,21 @@ from scipy import linalg
 import torch
 import torchvision.transforms as TF
 from torch.nn.functional import adaptive_avg_pool2d
+
+from typing import Tuple, Union
 from cache_dit.metrics.inception import InceptionV3
 from cache_dit.metrics.config import _IMAGE_EXTENSIONS
 from cache_dit.metrics.config import _VIDEO_EXTENSIONS
+from cache_dit.metrics.config import get_metrics_verbose
+from cache_dit.utils import disable_print
 from cache_dit.logger import init_logger
 
+warnings.filterwarnings("ignore")
+
 logger = init_logger(__name__)
 
+DISABLE_VERBOSE = not get_metrics_verbose()
+
 
 # Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
 class ImagePathDataset(torch.utils.data.Dataset):
@@ -496,3 +506,35 @@ class FrechetInceptionDistance:
             return [], [], 0
 
         return video_true_frames, video_test_frames, valid_frames
+
+
+fid_instance: FrechetInceptionDistance = None
+
+
+def compute_fid(
+    image_true: np.ndarray | str,
+    image_test: np.ndarray | str,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    global fid_instance
+    if fid_instance is None:
+        with disable_print():
+            fid_instance = FrechetInceptionDistance(
+                disable_tqdm=not get_metrics_verbose(),
+            )
+    assert fid_instance is not None
+    return fid_instance.compute_fid(image_true, image_test)
+
+
+def compute_video_fid(
+    # file or dir
+    video_true: str,
+    video_test: str,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    global fid_instance
+    if fid_instance is None:
+        with disable_print():
+            fid_instance = FrechetInceptionDistance(
+                disable_tqdm=not get_metrics_verbose(),
+            )
+    assert fid_instance is not None
+    return fid_instance.compute_fid(video_true, video_test)
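
A sketch of the new module-level wrappers; the paths are hypothetical:

from cache_dit.metrics.fid import compute_fid, compute_video_fid

# Arguments may be single files/arrays or directories; the shared
# FrechetInceptionDistance instance is built lazily, with its
# construction chatter silenced by disable_print().
fid, n = compute_fid("./ref_images", "./gen_images")
vfid, vn = compute_video_fid("ref.mp4", "gen.mp4")

As written, both wrappers delegate to the instance's `compute_fid`; `compute_video_fid` does not call a video-specific method, relying on `compute_fid` to handle the file-or-dir inputs noted in its comment.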
cache_dit/metrics/image_reward.py ADDED
@@ -0,0 +1,177 @@
+import os
+import re
+import pathlib
+import warnings
+
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import torch
+import ImageReward as RM
+import torchvision.transforms.v2.functional as TF
+import torchvision.transforms.v2 as T
+
+from typing import Tuple, Union
+from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+from cache_dit.metrics.config import get_metrics_verbose
+from cache_dit.utils import disable_print
+from cache_dit.logger import init_logger
+
+warnings.filterwarnings("ignore")
+
+
+logger = init_logger(__name__)
+
+
+DISABLE_VERBOSE = not get_metrics_verbose()
+
+
+class ImageRewardScore:
+    def __init__(
+        self,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        imagereward_model_path: str = None,
+    ):
+        self.device = device
+        if imagereward_model_path is None:
+            imagereward_model_path = os.environ.get(
+                "IMAGEREWARD_MODEL_DIR", None
+            )
+
+        # Load ImageReward model
+        self.med_config = os.path.join(
+            imagereward_model_path, "med_config.json"
+        )
+        self.imagereward_path = os.path.join(
+            imagereward_model_path, "ImageReward.pt"
+        )
+        if imagereward_model_path is not None:
+            self.imagereward_model = RM.load(
+                self.imagereward_path,
+                download_root=imagereward_model_path,
+                med_config=self.med_config,
+            ).to(self.device)
+        else:
+            self.imagereward_model = RM.load(
+                "ImageReward-v1.0",  # download from huggingface
+            ).to(self.device)
+
+        # ImageReward transform
+        self.reward_transform = T.Compose(
+            [
+                T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
+                T.CenterCrop(224),
+                T.ToImage(),
+                T.ToDtype(torch.float32, scale=True),
+                T.Normalize(
+                    (0.48145466, 0.4578275, 0.40821073),
+                    (0.26862954, 0.26130258, 0.27577711),
+                ),
+            ]
+        )
+
+    @torch.no_grad()
+    def compute_reward_score(
+        self,
+        img: Image.Image | np.ndarray,
+        prompt: str,
+    ) -> float:
+        if isinstance(img, Image.Image):
+            img_pil = img.convert("RGB")
+        elif isinstance(img, np.ndarray):
+            img_pil = Image.fromarray(img).convert("RGB")
+        else:
+            img_pil = Image.open(img).convert("RGB")
+        with torch.no_grad():
+            img_tensor = TF.pil_to_tensor(img_pil).unsqueeze(0).to(self.device)
+            img_reward = self.reward_transform(img_tensor)
+            inputs = self.imagereward_model.blip.tokenizer(
+                [prompt],
+                padding="max_length",
+                truncation=True,
+                max_length=512,
+                return_tensors="pt",
+            ).to(self.device)
+            score = self.imagereward_model.score_gard(
+                inputs.input_ids, inputs.attention_mask, img_reward
+            )
+            return score.item()
+
+
+image_reward_score_instance: ImageRewardScore = None
+
+
+def compute_reward_score_img(
+    img: Image.Image | np.ndarray | str,
+    prompt: str,
+    imagereward_model_path: str = None,
+) -> float:
+    global image_reward_score_instance
+    if image_reward_score_instance is None:
+        with disable_print():
+            image_reward_score_instance = ImageRewardScore(
+                imagereward_model_path=imagereward_model_path
+            )
+    assert image_reward_score_instance is not None
+    return image_reward_score_instance.compute_reward_score(img, prompt)
+
+
+def compute_reward_score(
+    img_dir: Image.Image | np.ndarray | str,
+    prompts: str | list[str],
+    imagereward_model_path: str = None,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    if not os.path.isdir(img_dir) or (
+        not isinstance(prompts, list) and not os.path.isfile(prompts)
+    ):
+        return (
+            compute_reward_score_img(
+                img_dir,
+                prompts,
+                imagereward_model_path=imagereward_model_path,
+            ),
+            1,
+        )
+
+    # compute dir metric
+    def natural_sort_key(filename):
+        match = re.search(r"(\d+)\D*$", filename)
+        return int(match.group(1)) if match else filename
+
+    img_dir: pathlib.Path = pathlib.Path(img_dir)
+    img_files = [
+        file
+        for ext in _IMAGE_EXTENSIONS
+        for file in img_dir.rglob("*.{}".format(ext))
+    ]
+    img_files = [file.as_posix() for file in img_files]
+    img_files = sorted(img_files, key=natural_sort_key)
+
+    if os.path.isfile(prompts):
+        """Load prompts from file"""
+        with open(prompts, "r", encoding="utf-8") as f:
+            prompts_load = [line.strip() for line in f.readlines()]
+        prompts = prompts_load.copy()
+
+    vaild_len = min(len(img_files), len(prompts))
+    img_files = img_files[:vaild_len]
+    prompts = prompts[:vaild_len]
+
+    reward_scores = []
+
+    for img_file, prompt in tqdm(
+        zip(img_files, prompts),
+        total=vaild_len,
+        disable=not get_metrics_verbose(),
+    ):
+        reward_scores.append(
+            compute_reward_score_img(
+                img_file,
+                prompt,
+                imagereward_model_path=imagereward_model_path,
+            )
+        )
+
+    if vaild_len > 0:
+        return np.mean(reward_scores), vaild_len
+    return None, None
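
A usage sketch (module path per the header above; all paths hypothetical). Since `__init__` joins `med_config.json` and `ImageReward.pt` onto `imagereward_model_path` before reaching the `is not None` branch, a local model directory is effectively required, supplied either as an argument or via `IMAGEREWARD_MODEL_DIR`:

import os
from cache_dit.metrics.image_reward import compute_reward_score

# Directory holding ImageReward.pt and med_config.json (hypothetical).
os.environ["IMAGEREWARD_MODEL_DIR"] = "/models/ImageReward"

# Directory of generated images scored against a prompt file,
# one prompt per line; returns (mean_reward, num_pairs).
mean_reward, n = compute_reward_score("./samples", "prompts.txt")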
cache_dit/metrics/lpips.py CHANGED
@@ -1,9 +1,9 @@
-import builtins as __builtin__
-import contextlib
 import warnings
 
 import lpips
 import torch
+from cache_dit.utils import disable_print
+
 
 warnings.filterwarnings("ignore")
 
@@ -11,18 +11,6 @@ lpips_loss_fn_vgg = None
 lpips_loss_fn_alex = None
 
 
-def dummy_print(*args, **kwargs):
-    pass
-
-
-@contextlib.contextmanager
-def disable_print():
-    origin_print = __builtin__.print
-    __builtin__.print = dummy_print
-    yield
-    __builtin__.print = origin_print
-
-
 def compute_lpips_img(img0, img1, net: str = "alex"):
     global lpips_loss_fn_vgg
     global lpips_loss_fn_alex
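
The `disable_print` helper removed here now lives in `cache_dit.utils` (and is re-exported from the package root, per the `__init__.py` hunk above). A minimal sketch matching the removed implementation:

from cache_dit.utils import disable_print

with disable_print():
    # builtins.print is swapped for a no-op inside the block, keeping
    # chatty model loaders (lpips, InceptionV3, ImageReward) quiet.
    print("this line is suppressed")
print("printing is restored here")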