cache-dit 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +138 -33
- cache_dit/cache_factory/first_block_cache/cache_context.py +2 -2
- cache_dit/metrics/__init__.py +12 -0
- cache_dit/metrics/fid.py +409 -0
- cache_dit/metrics/inception.py +353 -0
- cache_dit/metrics/metrics.py +356 -0
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/METADATA +59 -8
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/RECORD +13 -8
- cache_dit-0.2.6.dist-info/entry_points.txt +2 -0
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.4.dist-info → cache_dit-0.2.6.dist-info}/top_level.txt +0 -0
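The headline addition in 0.2.6 is the new cache_dit.metrics package (fid.py, inception.py, metrics.py), which ships a self-contained FID implementation. The closed form it uses, d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1*C2)), appears in calculate_frechet_distance in the added fid.py below. The following sketch is an editorial illustration (not part of the diff) that checks that form on two small synthetic Gaussians:

# Editorial sketch, not part of the released file: verify the closed-form
# Frechet distance used by fid.py on two small synthetic Gaussians.
import numpy as np
from scipy import linalg

mu1, mu2 = np.zeros(4), np.full(4, 0.5)
sigma1, sigma2 = np.eye(4), 2.0 * np.eye(4)

covmean = linalg.sqrtm(sigma1.dot(sigma2))  # sqrt(C1*C2)
covmean = covmean.real if np.iscomplexobj(covmean) else covmean  # as fid.py does
diff = mu1 - mu2
d2 = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
print(d2)  # 1 + 4 + 8 - 8*sqrt(2) ≈ 1.686; identical Gaussians would give 0.0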
cache_dit/metrics/fid.py
ADDED
@@ -0,0 +1,409 @@

import os
import cv2
import pathlib
import numpy as np
from PIL import Image
from tqdm import tqdm
from scipy import linalg
import torch
import torchvision.transforms as TF
from torch.nn.functional import adaptive_avg_pool2d
from cache_dit.metrics.inception import InceptionV3
from cache_dit.logger import init_logger

logger = init_logger(__name__)


class ImagePathDataset(torch.utils.data.Dataset):
    def __init__(self, files_or_imgs, transforms=None):
        self.files_or_imgs = files_or_imgs
        self.transforms = transforms

    def __len__(self):
        return len(self.files_or_imgs)

    def __getitem__(self, i):
        file_or_img = self.files_or_imgs[i]
        if isinstance(file_or_img, (str, pathlib.Path)):
            img = Image.open(file_or_img).convert("RGB")
        elif isinstance(file_or_img, np.ndarray):
            # Assume the img is a standard OpenCV image.
            img = cv2.cvtColor(file_or_img, cv2.COLOR_BGR2RGB)
            img = Image.fromarray(img)
        else:
            raise ValueError(
                "file_or_img must be a file path or an OpenCV image."
            )
        if self.transforms is not None:
            img = self.transforms(img)
        return img


def get_activations(
    files_or_imgs,
    model,
    batch_size=50,
    dims=2048,
    device="cpu",
    num_workers=1,
    disable_tqdm=True,
):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- files_or_imgs : List of image file paths or OpenCV images
    -- model         : Instance of inception model
    -- batch_size    : Batch size of images for the model to process at once.
                       Make sure that the number of samples is a multiple of
                       the batch size, otherwise some samples are ignored. This
                       behavior is retained to match the original FID score
                       implementation.
    -- dims          : Dimensionality of features returned by Inception
    -- device        : Device to run calculations
    -- num_workers   : Number of parallel dataloader workers

    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    if batch_size > len(files_or_imgs):
        logger.info(
            (
                "Warning: batch size is bigger than the data size. "
                "Setting batch size to data size"
            )
        )
        batch_size = len(files_or_imgs)

    dataset = ImagePathDataset(files_or_imgs, transforms=TF.ToTensor())
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
    )

    pred_arr = np.empty((len(files_or_imgs), dims))

    start_idx = 0

    for batch in tqdm(dataloader, disable=disable_tqdm):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch)[0]

        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()

        pred_arr[start_idx : start_idx + pred.shape[0]] = pred

        start_idx = start_idx + pred.shape[0]

    return pred_arr


def calculate_frechet_distance(
    mu1,
    sigma1,
    mu2,
    sigma2,
    eps=1e-6,
):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
    d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Numpy array containing the activations of a layer of the
               inception net (like returned by the function 'get_predictions')
               for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.

    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert (
        mu1.shape == mu2.shape
    ), "Training and test mean vectors have different lengths"
    assert (
        sigma1.shape == sigma2.shape
    ), "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = (
            "fid calculation produces singular product; "
            "adding %s to diagonal of cov estimates"
        ) % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


def calculate_activation_statistics(
    files_or_imgs,
    model,
    batch_size=50,
    dims=2048,
    device="cpu",
    num_workers=1,
    disable_tqdm=True,
):
    """Calculation of the statistics used by the FID.
    Params:
    -- files_or_imgs : List of image file paths or OpenCV images
    -- model         : Instance of inception model
    -- batch_size    : Batch size of images for the model to process at once.
                       Make sure that the number of samples is a multiple of
                       the batch size, otherwise some samples are ignored. This
                       behavior is retained to match the original FID score
                       implementation.
    -- dims          : Dimensionality of features returned by Inception
    -- device        : Device to run calculations
    -- num_workers   : Number of parallel dataloader workers

    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations(
        files_or_imgs,
        model,
        batch_size,
        dims,
        device,
        num_workers,
        disable_tqdm,
    )
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


_IMAGE_EXTENSIONS = {
    "bmp",
    "jpg",
    "jpeg",
    "pgm",
    "png",
    "ppm",
    "tif",
    "tiff",
    "webp",
}


class FrechetInceptionDistance:
    IMAGE_EXTENSIONS = _IMAGE_EXTENSIONS

    def __init__(
        self,
        device="cuda" if torch.cuda.is_available() else "cpu",
        dims: int = 2048,
        num_workers: int = 1,
        batch_size: int = 1,
        disable_tqdm: bool = True,
    ):
        # https://github.com/mseitzer/pytorch-fid/src/pytorch_fid/fid_score.py
        self.dims = dims
        self.device = device
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.disable_tqdm = disable_tqdm
        self.block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
        self.model = InceptionV3([self.block_idx]).to(self.device)
        self.model = self.model.eval()

    def compute_fid(
        self,
        image_true: np.ndarray | str,
        image_test: np.ndarray | str,
    ):
        """Calculates the FID of two file paths
        FID = FrechetInceptionDistance()
        img_fid = FID.compute_fid("img_true.png", "img_test.png")
        img_dir_fid = FID.compute_fid("img_true_dir", "img_test_dir")
        """
        if isinstance(image_true, str) or isinstance(image_test, str):
            if os.path.isfile(image_true) or os.path.isfile(image_test):
                assert os.path.exists(image_true)
                assert os.path.exists(image_test)
                assert image_true.split(".")[-1] in self.IMAGE_EXTENSIONS
                assert image_test.split(".")[-1] in self.IMAGE_EXTENSIONS
                image_true_files = [image_true]
                image_test_files = [image_test]
            else:
                # glob image files from dir
                assert os.path.isdir(image_true)
                assert os.path.isdir(image_test)
                image_true_dir = pathlib.Path(image_true)
                image_true_files = sorted(
                    [
                        file
                        for ext in self.IMAGE_EXTENSIONS
                        for file in image_true_dir.rglob("*.{}".format(ext))
                    ]
                )
                image_test_dir = pathlib.Path(image_test)
                image_test_files = sorted(
                    [
                        file
                        for ext in self.IMAGE_EXTENSIONS
                        for file in image_test_dir.rglob("*.{}".format(ext))
                    ]
                )
                image_true_files = [
                    file.as_posix() for file in image_true_files
                ]
                image_test_files = [
                    file.as_posix() for file in image_test_files
                ]
                logger.debug(f"image_true_files: {image_true_files}")
                logger.debug(f"image_test_files: {image_test_files}")
                assert len(image_true_files) == len(image_test_files)
                for image_true, image_test in zip(
                    image_true_files, image_test_files
                ):
                    assert os.path.basename(image_true) == os.path.basename(
                        image_test
                    ), f"image_true:{image_true} != image_test: {image_test}"
        else:
            image_true_files = [image_true]
            image_test_files = [image_test]

        batch_size = min(16, self.batch_size)
        batch_size = min(batch_size, len(image_test_files))
        m1, s1 = calculate_activation_statistics(
            image_true_files,
            self.model,
            batch_size,
            self.dims,
            self.device,
            self.num_workers,
            self.disable_tqdm,
        )
        m2, s2 = calculate_activation_statistics(
            image_test_files,
            self.model,
            batch_size,
            self.dims,
            self.device,
            self.num_workers,
            self.disable_tqdm,
        )
        fid_value = calculate_frechet_distance(
            m1,
            s1,
            m2,
            s2,
        )

        return fid_value, len(image_true_files)

    def compute_video_fid(
        self,
        video_true: str,
        video_test: str,
    ):
        cap1 = cv2.VideoCapture(video_true)
        cap2 = cv2.VideoCapture(video_test)

        if not cap1.isOpened() or not cap2.isOpened():
            logger.error("Could not open video files")
            return None, None

        frame_count = min(
            int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
            int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)),
        )

        valid_frames = 0
        video_true_frames = []
        video_test_frames = []

        logger.debug(f"Total frames: {frame_count}")

        while True:
            ret1, frame1 = cap1.read()
            ret2, frame2 = cap2.read()

            if not ret1 or not ret2:
                break

            video_true_frames.append(frame1)
            video_test_frames.append(frame2)

            valid_frames += 1

        cap1.release()
        cap2.release()

        if valid_frames <= 0:
            return None, None

        batch_size = min(16, self.batch_size)
        m1, s1 = calculate_activation_statistics(
            video_true_frames,
            self.model,
            batch_size,
            self.dims,
            self.device,
            self.num_workers,
            self.disable_tqdm,
        )
        m2, s2 = calculate_activation_statistics(
            video_test_frames,
            self.model,
            batch_size,
            self.dims,
            self.device,
            self.num_workers,
            self.disable_tqdm,
        )
        fid_value = calculate_frechet_distance(
            m1,
            s1,
            m2,
            s2,
        )

        return fid_value, valid_frames
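For reference, a minimal usage sketch of the class added above (an editorial example, not part of the diff). It assumes the 0.2.6 wheel is installed and imports the class via its module path, since the exact re-exports in cache_dit/metrics/__init__.py are not shown here; the directory and video paths are placeholders.

# Editorial sketch: compute FID between two image folders and between two
# videos using the FrechetInceptionDistance class added in fid.py above.
from cache_dit.metrics.fid import FrechetInceptionDistance

FID = FrechetInceptionDistance()  # defaults: dims=2048, batch_size=1, CUDA if available

# Directory mode: both trees are globbed recursively and must contain the
# same number of identically named images (compute_fid asserts matching basenames).
img_fid, num_images = FID.compute_fid("img_true_dir", "img_test_dir")
print(f"FID over {num_images} images: {img_fid}")

# Video mode: frames are read pairwise until either stream ends.
video_fid, num_frames = FID.compute_video_fid("video_true.mp4", "video_test.mp4")
print(f"FID over {num_frames} frames: {video_fid}")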