cache-dit 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit has been flagged as potentially problematic.

@@ -0,0 +1,409 @@
+ import os
+ import cv2
+ import pathlib
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+ from scipy import linalg
+ import torch
+ import torchvision.transforms as TF
+ from torch.nn.functional import adaptive_avg_pool2d
+ from cache_dit.metrics.inception import InceptionV3
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class ImagePathDataset(torch.utils.data.Dataset):
+     def __init__(self, files_or_imgs, transforms=None):
+         self.files_or_imgs = files_or_imgs
+         self.transforms = transforms
+
+     def __len__(self):
+         return len(self.files_or_imgs)
+
+     def __getitem__(self, i):
+         file_or_img = self.files_or_imgs[i]
+         if isinstance(file_or_img, (str, pathlib.Path)):
+             img = Image.open(file_or_img).convert("RGB")
+         elif isinstance(file_or_img, np.ndarray):
+             # Assume the img is a standard OpenCV image (BGR channel order).
+             img = cv2.cvtColor(file_or_img, cv2.COLOR_BGR2RGB)
+             img = Image.fromarray(img)
+         else:
+             raise ValueError(
+                 "file_or_img must be a file path or an OpenCV image."
+             )
+         if self.transforms is not None:
+             img = self.transforms(img)
+         return img
+
+
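
For orientation (reviewer commentary, not part of the diff): the dataset above accepts both file paths and raw OpenCV frames, which is what lets the video FID path further down reuse the same pipeline. A minimal sketch, using only the imports already in this file:

# Illustrative sketch only -- not part of the package.
frame = np.zeros((64, 64, 3), dtype=np.uint8)  # a dummy BGR frame
ds = ImagePathDataset([frame], transforms=TF.ToTensor())
tensor = ds[0]  # torch.FloatTensor of shape (3, 64, 64), values in [0, 1]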
+ def get_activations(
+     files_or_imgs,
+     model,
+     batch_size=50,
+     dims=2048,
+     device="cpu",
+     num_workers=1,
+     disable_tqdm=True,
+ ):
+     """Calculates the activations of the pool_3 layer for all images.
+
+     Params:
+     -- files_or_imgs : List of image file paths or OpenCV images
+     -- model         : Instance of inception model
+     -- batch_size    : Batch size of images for the model to process at once.
+                        Make sure that the number of samples is a multiple of
+                        the batch size, otherwise some samples are ignored. This
+                        behavior is retained to match the original FID score
+                        implementation.
+     -- dims          : Dimensionality of features returned by Inception
+     -- device        : Device to run calculations
+     -- num_workers   : Number of parallel dataloader workers
+
+     Returns:
+     -- A numpy array of dimension (num images, dims) that contains the
+        activations of the given tensor when feeding inception with the
+        query tensor.
+     """
+     model.eval()
+
+     if batch_size > len(files_or_imgs):
+         logger.warning(
+             "batch size is bigger than the data size; "
+             "setting batch size to data size"
+         )
+         batch_size = len(files_or_imgs)
+
+     dataset = ImagePathDataset(files_or_imgs, transforms=TF.ToTensor())
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         drop_last=False,
+         num_workers=num_workers,
+     )
+
+     pred_arr = np.empty((len(files_or_imgs), dims))
+
+     start_idx = 0
+
+     for batch in tqdm(dataloader, disable=disable_tqdm):
+         batch = batch.to(device)
+
+         with torch.no_grad():
+             pred = model(batch)[0]
+
+         # If model output is not scalar, apply global spatial average pooling.
+         # This happens if you choose a dimensionality not equal 2048.
+         if pred.size(2) != 1 or pred.size(3) != 1:
+             pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+         pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+         pred_arr[start_idx : start_idx + pred.shape[0]] = pred
+
+         start_idx = start_idx + pred.shape[0]
+
+     return pred_arr
+
+
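
A minimal sketch of driving get_activations directly (reviewer commentary, not part of the package; the samples/ directory is a hypothetical path containing PNG files). The constructor call mirrors how this file builds its own InceptionV3 wrapper further down:

# Illustrative sketch only -- "samples" is a hypothetical directory.
model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[2048]]).to("cpu")
files = sorted(str(p) for p in pathlib.Path("samples").glob("*.png"))
acts = get_activations(files, model, batch_size=8, dims=2048, device="cpu")
print(acts.shape)  # (len(files), 2048)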
+ def calculate_frechet_distance(
+     mu1,
+     sigma1,
+     mu2,
+     sigma2,
+     eps=1e-6,
+ ):
+     """Numpy implementation of the Frechet Distance.
+     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+     and X_2 ~ N(mu_2, C_2) is
+     d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+     Stable version by Dougal J. Sutherland.
+
+     Params:
+     -- mu1   : Numpy array containing the activations of a layer of the
+                inception net (like returned by the function 'get_predictions')
+                for generated samples.
+     -- mu2   : The sample mean over activations, precalculated on a
+                representative data set.
+     -- sigma1: The covariance matrix over activations for generated samples.
+     -- sigma2: The covariance matrix over activations, precalculated on a
+                representative data set.
+
+     Returns:
+     --   : The Frechet Distance.
+     """
+
+     mu1 = np.atleast_1d(mu1)
+     mu2 = np.atleast_1d(mu2)
+
+     sigma1 = np.atleast_2d(sigma1)
+     sigma2 = np.atleast_2d(sigma2)
+
+     assert (
+         mu1.shape == mu2.shape
+     ), "Training and test mean vectors have different lengths"
+     assert (
+         sigma1.shape == sigma2.shape
+     ), "Training and test covariances have different dimensions"
+
+     diff = mu1 - mu2
+
+     # Product might be almost singular
+     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+     if not np.isfinite(covmean).all():
+         msg = (
+             "fid calculation produces singular product; "
+             "adding %s to diagonal of cov estimates"
+         ) % eps
+         logger.warning(msg)
+         offset = np.eye(sigma1.shape[0]) * eps
+         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+     # Numerical error might give slight imaginary component
+     if np.iscomplexobj(covmean):
+         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+             m = np.max(np.abs(covmean.imag))
+             raise ValueError("Imaginary component {}".format(m))
+         covmean = covmean.real
+
+     tr_covmean = np.trace(covmean)
+
+     return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
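
The closed form above is easy to sanity-check numerically (reviewer commentary, illustrative only): identical Gaussians give a distance of zero, and with equal covariances the trace terms cancel, leaving just the squared mean difference.

# Illustrative sanity check -- not part of the package.
rng = np.random.default_rng(0)
act = rng.standard_normal((512, 8))
mu, sigma = act.mean(axis=0), np.cov(act, rowvar=False)
assert abs(calculate_frechet_distance(mu, sigma, mu, sigma)) < 1e-6
d2 = calculate_frechet_distance(mu + 1.0, sigma, mu, sigma)
print(d2)  # ~8.0: the squared norm of a shift of 1.0 in each of 8 dimensions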
+ def calculate_activation_statistics(
+     files_or_imgs,
+     model,
+     batch_size=50,
+     dims=2048,
+     device="cpu",
+     num_workers=1,
+     disable_tqdm=True,
+ ):
+     """Calculation of the statistics used by the FID.
+     Params:
+     -- files_or_imgs : List of image file paths or OpenCV images
+     -- model         : Instance of inception model
+     -- batch_size    : Batch size of images for the model to process at once.
+                        Make sure that the number of samples is a multiple of
+                        the batch size, otherwise some samples are ignored. This
+                        behavior is retained to match the original FID score
+                        implementation.
+     -- dims          : Dimensionality of features returned by Inception
+     -- device        : Device to run calculations
+     -- num_workers   : Number of parallel dataloader workers
+
+     Returns:
+     -- mu    : The mean over samples of the activations of the pool_3 layer of
+                the inception model.
+     -- sigma : The covariance matrix of the activations of the pool_3 layer of
+                the inception model.
+     """
+     act = get_activations(
+         files_or_imgs,
+         model,
+         batch_size,
+         dims,
+         device,
+         num_workers,
+         disable_tqdm,
+     )
+     mu = np.mean(act, axis=0)
+     sigma = np.cov(act, rowvar=False)
+     return mu, sigma
+
+
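
Because the reference statistics depend only on the reference set, a common pattern (a sketch under assumed names, not something this module implements) is to compute them once and cache them to disk, reusing the model and files from the sketch above:

# Illustrative sketch only -- "ref_stats.npz" is a hypothetical path.
mu_ref, sigma_ref = calculate_activation_statistics(files, model, 8, 2048, "cpu")
np.savez("ref_stats.npz", mu=mu_ref, sigma=sigma_ref)
cached = np.load("ref_stats.npz")
fid = calculate_frechet_distance(mu_ref, sigma_ref, cached["mu"], cached["sigma"])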
+ _IMAGE_EXTENSIONS = {
+     "bmp",
+     "jpg",
+     "jpeg",
+     "pgm",
+     "png",
+     "ppm",
+     "tif",
+     "tiff",
+     "webp",
+ }
+
+
+ class FrechetInceptionDistance:
+     IMAGE_EXTENSIONS = _IMAGE_EXTENSIONS
+
+     def __init__(
+         self,
+         device="cuda" if torch.cuda.is_available() else "cpu",
+         dims: int = 2048,
+         num_workers: int = 1,
+         batch_size: int = 1,
+         disable_tqdm: bool = True,
+     ):
+         # https://github.com/mseitzer/pytorch-fid/src/pytorch_fid/fid_score.py
+         self.dims = dims
+         self.device = device
+         self.num_workers = num_workers
+         self.batch_size = batch_size
+         self.disable_tqdm = disable_tqdm
+         self.block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
+         self.model = InceptionV3([self.block_idx]).to(self.device)
+         self.model = self.model.eval()
+
+     def compute_fid(
+         self,
+         image_true: np.ndarray | str,
+         image_test: np.ndarray | str,
+     ):
+         """Calculates the FID between two image files, two directories of
+         images, or two in-memory OpenCV images.
+         FID = FrechetInceptionDistance()
+         img_fid = FID.compute_fid("img_true.png", "img_test.png")
+         img_dir_fid = FID.compute_fid("img_true_dir", "img_test_dir")
+         """
+         if isinstance(image_true, str) or isinstance(image_test, str):
+             if os.path.isfile(image_true) or os.path.isfile(image_test):
+                 assert os.path.exists(image_true)
+                 assert os.path.exists(image_test)
+                 assert image_true.split(".")[-1] in self.IMAGE_EXTENSIONS
+                 assert image_test.split(".")[-1] in self.IMAGE_EXTENSIONS
+                 image_true_files = [image_true]
+                 image_test_files = [image_test]
+             else:
+                 # glob image files from dir
+                 assert os.path.isdir(image_true)
+                 assert os.path.isdir(image_test)
+                 image_true_dir = pathlib.Path(image_true)
+                 image_true_files = sorted(
+                     [
+                         file
+                         for ext in self.IMAGE_EXTENSIONS
+                         for file in image_true_dir.rglob("*.{}".format(ext))
+                     ]
+                 )
+                 image_test_dir = pathlib.Path(image_test)
+                 image_test_files = sorted(
+                     [
+                         file
+                         for ext in self.IMAGE_EXTENSIONS
+                         for file in image_test_dir.rglob("*.{}".format(ext))
+                     ]
+                 )
+                 image_true_files = [
+                     file.as_posix() for file in image_true_files
+                 ]
+                 image_test_files = [
+                     file.as_posix() for file in image_test_files
+                 ]
+                 logger.debug(f"image_true_files: {image_true_files}")
+                 logger.debug(f"image_test_files: {image_test_files}")
+                 assert len(image_true_files) == len(image_test_files)
+                 # Pair files by basename so the two directories line up.
+                 for file_true, file_test in zip(
+                     image_true_files, image_test_files
+                 ):
+                     assert os.path.basename(file_true) == os.path.basename(
+                         file_test
+                     ), f"image_true:{file_true} != image_test: {file_test}"
+         else:
+             image_true_files = [image_true]
+             image_test_files = [image_test]
+
+         batch_size = min(16, self.batch_size)
+         batch_size = min(batch_size, len(image_test_files))
+         m1, s1 = calculate_activation_statistics(
+             image_true_files,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         m2, s2 = calculate_activation_statistics(
+             image_test_files,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         fid_value = calculate_frechet_distance(
+             m1,
+             s1,
+             m2,
+             s2,
+         )
+
+         return fid_value, len(image_true_files)
+
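
As the docstring shows, compute_fid accepts single files, in-memory arrays, or whole directories; note that the directory path asserts the two directories contain identically named files, paired one-to-one. A directory-vs-directory run (reviewer commentary, with hypothetical paths) looks like:

# Illustrative sketch only -- directory names are hypothetical.
FID = FrechetInceptionDistance(device="cpu", batch_size=8)
fid_value, n_images = FID.compute_fid("renders_baseline", "renders_cached")
print(f"FID over {n_images} paired images: {fid_value:.3f}")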
+     def compute_video_fid(
+         self,
+         video_true: str,
+         video_test: str,
+     ):
+         cap1 = cv2.VideoCapture(video_true)
+         cap2 = cv2.VideoCapture(video_test)
+
+         if not cap1.isOpened() or not cap2.isOpened():
+             logger.error("Could not open video files")
+             return None, None
+
+         frame_count = min(
+             int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
+             int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)),
+         )
+
+         valid_frames = 0
+         video_true_frames = []
+         video_test_frames = []
+
+         logger.debug(f"Total frames: {frame_count}")
+
+         while True:
+             ret1, frame1 = cap1.read()
+             ret2, frame2 = cap2.read()
+
+             if not ret1 or not ret2:
+                 break
+
+             video_true_frames.append(frame1)
+             video_test_frames.append(frame2)
+
+             valid_frames += 1
+
+         cap1.release()
+         cap2.release()
+
+         if valid_frames <= 0:
+             return None, None
+
+         batch_size = min(16, self.batch_size)
+         m1, s1 = calculate_activation_statistics(
+             video_true_frames,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         m2, s2 = calculate_activation_statistics(
+             video_test_frames,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         fid_value = calculate_frechet_distance(
+             m1,
+             s1,
+             m2,
+             s2,
+         )
+
+         return fid_value, valid_frames
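
Finally, the frame-wise video variant pairs the two clips frame by frame and stops at the shorter one; a usage sketch (reviewer commentary; file names are hypothetical):

# Illustrative sketch only -- video file names are hypothetical.
FID = FrechetInceptionDistance(device="cpu", batch_size=16)
fid_value, n_frames = FID.compute_video_fid("baseline.mp4", "cache_dit.mp4")
if fid_value is not None:
    print(f"FID over {n_frames} shared frames: {fid_value:.3f}")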