cache-dit 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +19 -17
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +1 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +1 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +1 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +1 -1
- cache_dit/metrics/__init__.py +14 -0
- cache_dit/metrics/config.py +34 -0
- cache_dit/metrics/fid.py +498 -0
- cache_dit/metrics/inception.py +353 -0
- cache_dit/metrics/metrics.py +469 -0
- {cache_dit-0.2.5.dist-info → cache_dit-0.2.7.dist-info}/METADATA +37 -1
- {cache_dit-0.2.5.dist-info → cache_dit-0.2.7.dist-info}/RECORD +17 -11
- cache_dit-0.2.7.dist-info/entry_points.txt +2 -0
- {cache_dit-0.2.5.dist-info → cache_dit-0.2.7.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.5.dist-info → cache_dit-0.2.7.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.5.dist-info → cache_dit-0.2.7.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
cache_dit/cache_factory/dual_block_cache/cache_context.py
CHANGED

```diff
@@ -988,7 +988,7 @@ def get_Bn_encoder_buffer(prefix: str = "Bn"):
 @torch.compiler.disable
 def apply_hidden_states_residual(
     hidden_states: torch.Tensor,
-    encoder_hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
     prefix: str = "Bn",
     encoder_prefix: str = "Bn_encoder",
 ):
@@ -1006,25 +1006,27 @@ def apply_hidden_states_residual(
         # If cache is not residual, we use the hidden states directly
         hidden_states = hidden_states_prev
 
-    if "Bn" in encoder_prefix:
-        encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
-    else:
-        encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
+    hidden_states = hidden_states.contiguous()
 
-    assert (
-        encoder_hidden_states_prev is not None
-    ), f"{prefix}_encoder_buffer must be set before"
+    if encoder_hidden_states is not None:
+        if "Bn" in encoder_prefix:
+            encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
+        else:
+            encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
 
-    if is_encoder_cache_residual():
-        encoder_hidden_states = (
-            encoder_hidden_states_prev + encoder_hidden_states
-        )
-    else:
-        # If encoder cache is not residual, we use the encoder hidden states directly
-        encoder_hidden_states = encoder_hidden_states_prev
+        assert (
+            encoder_hidden_states_prev is not None
+        ), f"{prefix}_encoder_buffer must be set before"
 
-    hidden_states = hidden_states.contiguous()
-    encoder_hidden_states = encoder_hidden_states.contiguous()
+        if is_encoder_cache_residual():
+            encoder_hidden_states = (
+                encoder_hidden_states_prev + encoder_hidden_states
+            )
+        else:
+            # If encoder cache is not residual, we use the encoder hidden states directly
+            encoder_hidden_states = encoder_hidden_states_prev
+
+        encoder_hidden_states = encoder_hidden_states.contiguous()
 
     return hidden_states, encoder_hidden_states
```
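The net effect of this hunk: `encoder_hidden_states` becomes optional, and the encoder-side buffer lookup, assert, and residual add now run only when it is provided, while `hidden_states` is made contiguous unconditionally. Below is a minimal, self-contained sketch of that pattern, with a plain dict standing in for the real Bn/Fn buffer registry and residual caching assumed to be enabled; `apply_residual_sketch` and `_buffers` are illustrative names, not part of the package.

```python
import torch

# Hypothetical stand-in for the Bn/Fn buffer registry in cache_context.py.
_buffers = {}


def apply_residual_sketch(hidden_states, encoder_hidden_states=None, prefix="Bn"):
    # The hidden-states stream is always restored from its cached residual
    # and made contiguous, mirroring the new 0.2.7 code path.
    hidden_states = (_buffers[f"{prefix}_buffer"] + hidden_states).contiguous()
    # The encoder stream is now optional: single-stream transformer blocks
    # pass None and skip the buffer lookup, assert, and residual add entirely.
    if encoder_hidden_states is not None:
        prev = _buffers[f"{prefix}_encoder_buffer"]
        encoder_hidden_states = (prev + encoder_hidden_states).contiguous()
    return hidden_states, encoder_hidden_states


_buffers["Bn_buffer"] = torch.zeros(2, 4)
_buffers["Bn_encoder_buffer"] = torch.zeros(2, 4)
hs, ehs = apply_residual_sketch(torch.ones(2, 4))                    # ehs stays None
hs, ehs = apply_residual_sketch(torch.ones(2, 4), torch.ones(2, 4))  # both restored
```

The four diffusers adapters (cogvideox, flux, hunyuan_video, mochi) each change a single line in this release, which is consistent with threading encoder states through this looser signature, though those hunks are not shown here.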
cache_dit/metrics/__init__.py
ADDED

```diff
@@ -0,0 +1,14 @@
+from cache_dit.metrics.metrics import compute_psnr
+from cache_dit.metrics.metrics import compute_ssim
+from cache_dit.metrics.metrics import compute_mse
+from cache_dit.metrics.metrics import compute_video_psnr
+from cache_dit.metrics.metrics import compute_video_ssim
+from cache_dit.metrics.metrics import compute_video_mse
+from cache_dit.metrics.metrics import entrypoint
+from cache_dit.metrics.fid import FrechetInceptionDistance
+from cache_dit.metrics.config import set_metrics_verbose
+from cache_dit.metrics.config import get_metrics_verbose
+
+
+def main():
+    entrypoint()
```
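`cache_dit/metrics/__init__.py` turns the metrics suite into a public API and gives the new console entry point (see `entry_points.txt`, +2 lines) a `main()` that simply forwards to `entrypoint()`. A usage sketch follows; the paths are illustrative, and since `metrics/metrics.py` itself is not displayed in this diff, the `compute_*` calling conventions should be treated as assumptions:

```python
import cache_dit.metrics as metrics

metrics.set_metrics_verbose(True)
assert metrics.get_metrics_verbose()

# FrechetInceptionDistance comes from fid.py (shown below) and accepts
# single files or directories of images with matching basenames.
FID = metrics.FrechetInceptionDistance()
img_fid, n_pairs = FID.compute_fid("img_true_dir", "img_test_dir")

# compute_psnr / compute_video_psnr etc. live in metrics/metrics.py, which
# this diff does not show; the call below is an assumed signature.
# psnr = metrics.compute_psnr("img_true.png", "img_test.png")
```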
cache_dit/metrics/config.py
ADDED

```diff
@@ -0,0 +1,34 @@
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+_metrics_progress_verbose = False
+
+
+def set_metrics_verbose(verbose: bool):
+    global _metrics_progress_verbose
+    _metrics_progress_verbose = verbose
+    logger.debug(f"Metrics verbose: {verbose}")
+
+
+def get_metrics_verbose() -> bool:
+    global _metrics_progress_verbose
+    return _metrics_progress_verbose
+
+
+_IMAGE_EXTENSIONS = [
+    "bmp",
+    "jpg",
+    "jpeg",
+    "pgm",
+    "png",
+    "ppm",
+    "tif",
+    "tiff",
+    "webp",
+]
+
+_VIDEO_EXTENSIONS = [
+    "mp4",
+]
```
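`config.py` is a small settings module: one global verbosity flag behind a getter/setter pair, plus the extension whitelists that `fid.py` uses when globbing directories. A sketch of how a caller might wire the flag into a tqdm progress bar; `fid.py` takes an explicit `disable_tqdm` argument, so deriving it from the global flag as done here is an assumption about intended use:

```python
from tqdm import tqdm

from cache_dit.metrics.config import get_metrics_verbose, set_metrics_verbose

set_metrics_verbose(True)

# Progress output is shown only when the global verbose flag is on.
for _ in tqdm(range(3), disable=not get_metrics_verbose()):
    pass
```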
cache_dit/metrics/fid.py
ADDED
```diff
@@ -0,0 +1,498 @@
+import os
+import cv2
+import pathlib
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from scipy import linalg
+import torch
+import torchvision.transforms as TF
+from torch.nn.functional import adaptive_avg_pool2d
+from cache_dit.metrics.inception import InceptionV3
+from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+from cache_dit.metrics.config import _VIDEO_EXTENSIONS
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+# Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
+class ImagePathDataset(torch.utils.data.Dataset):
+    def __init__(self, files_or_imgs, transforms=None):
+        self.files_or_imgs = files_or_imgs
+        self.transforms = transforms
+
+    def __len__(self):
+        return len(self.files_or_imgs)
+
+    def __getitem__(self, i):
+        file_or_img = self.files_or_imgs[i]
+        if isinstance(file_or_img, (str, pathlib.Path)):
+            img = Image.open(file_or_img).convert("RGB")
+        elif isinstance(file_or_img, np.ndarray):
+            # Assume the img is a standard OpenCV image.
+            img = cv2.cvtColor(file_or_img, cv2.COLOR_BGR2RGB)
+            img = Image.fromarray(img)
+        else:
+            raise ValueError(
+                "file_or_img must be a file path or an OpenCV image."
+            )
+        if self.transforms is not None:
+            img = self.transforms(img)
+        return img
+
+
+def get_activations(
+    files_or_imgs,
+    model,
+    batch_size=50,
+    dims=2048,
+    device="cpu",
+    num_workers=1,
+    disable_tqdm=True,
+):
+    """Calculates the activations of the pool_3 layer for all images.
+
+    Params:
+    -- files_or_imgs : List of image files paths or OpenCV image
+    -- model         : Instance of inception model
+    -- batch_size    : Batch size of images for the model to process at once.
+                       Make sure that the number of samples is a multiple of
+                       the batch size, otherwise some samples are ignored. This
+                       behavior is retained to match the original FID score
+                       implementation.
+    -- dims          : Dimensionality of features returned by Inception
+    -- device        : Device to run calculations
+    -- num_workers   : Number of parallel dataloader workers
+
+    Returns:
+    -- A numpy array of dimension (num images, dims) that contains the
+       activations of the given tensor when feeding inception with the
+       query tensor.
+    """
+    model.eval()
+
+    if batch_size > len(files_or_imgs):
+        logger.info(
+            (
+                "Warning: batch size is bigger than the data size. "
+                "Setting batch size to data size"
+            )
+        )
+        batch_size = len(files_or_imgs)
+
+    dataset = ImagePathDataset(files_or_imgs, transforms=TF.ToTensor())
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        drop_last=False,
+        num_workers=num_workers,
+    )
+
+    pred_arr = np.empty((len(files_or_imgs), dims))
+
+    start_idx = 0
+
+    for batch in tqdm(dataloader, disable=disable_tqdm):
+        batch = batch.to(device)
+
+        with torch.no_grad():
+            pred = model(batch)[0]
+
+        # If model output is not scalar, apply global spatial average pooling.
+        # This happens if you choose a dimensionality not equal 2048.
+        if pred.size(2) != 1 or pred.size(3) != 1:
+            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+        pred_arr[start_idx : start_idx + pred.shape[0]] = pred
+
+        start_idx = start_idx + pred.shape[0]
+
+    return pred_arr
+
+
+def calculate_frechet_distance(
+    mu1,
+    sigma1,
+    mu2,
+    sigma2,
+    eps=1e-6,
+):
+    """Numpy implementation of the Frechet Distance.
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+    Stable version by Dougal J. Sutherland.
+
+    Params:
+    -- mu1   : Numpy array containing the activations of a layer of the
+               inception net (like returned by the function 'get_predictions')
+               for generated samples.
+    -- mu2   : The sample mean over activations, precalculated on an
+               representative data set.
+    -- sigma1: The covariance matrix over activations for generated samples.
+    -- sigma2: The covariance matrix over activations, precalculated on an
+               representative data set.
+
+    Returns:
+    --   : The Frechet Distance.
+    """
+
+    mu1 = np.atleast_1d(mu1)
+    mu2 = np.atleast_1d(mu2)
+
+    sigma1 = np.atleast_2d(sigma1)
+    sigma2 = np.atleast_2d(sigma2)
+
+    assert (
+        mu1.shape == mu2.shape
+    ), "Training and test mean vectors have different lengths"
+    assert (
+        sigma1.shape == sigma2.shape
+    ), "Training and test covariances have different dimensions"
+
+    diff = mu1 - mu2
+
+    # Product might be almost singular
+    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+    if not np.isfinite(covmean).all():
+        msg = (
+            "fid calculation produces singular product; "
+            "adding %s to diagonal of cov estimates"
+        ) % eps
+        print(msg)
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+    # Numerical error might give slight imaginary component
+    if np.iscomplexobj(covmean):
+        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+            m = np.max(np.abs(covmean.imag))
+            raise ValueError("Imaginary component {}".format(m))
+        covmean = covmean.real
+
+    tr_covmean = np.trace(covmean)
+
+    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+def calculate_activation_statistics(
+    files_or_imgs,
+    model,
+    batch_size=50,
+    dims=2048,
+    device="cpu",
+    num_workers=1,
+    disable_tqdm=True,
+):
+    """Calculation of the statistics used by the FID.
+    Params:
+    -- files_or_imgs : List of image files paths or OpenCV image
+    -- model         : Instance of inception model
+    -- batch_size    : Batch size of images for the model to process at once.
+                       Make sure that the number of samples is a multiple of
+                       the batch size, otherwise some samples are ignored. This
+                       behavior is retained to match the original FID score
+                       implementation.
+    -- dims          : Dimensionality of features returned by Inception
+    -- device        : Device to run calculations
+    -- num_workers   : Number of parallel dataloader workers
+
+    Returns:
+    -- mu    : The mean over samples of the activations of the pool_3 layer of
+               the inception model.
+    -- sigma : The covariance matrix of the activations of the pool_3 layer of
+               the inception model.
+    """
+    act = get_activations(
+        files_or_imgs,
+        model,
+        batch_size,
+        dims,
+        device,
+        num_workers,
+        disable_tqdm,
+    )
+    mu = np.mean(act, axis=0)
+    sigma = np.cov(act, rowvar=False)
+    return mu, sigma
+
+
+class FrechetInceptionDistance:
+    def __init__(
+        self,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        dims: int = 2048,
+        num_workers: int = 1,
+        batch_size: int = 1,
+        disable_tqdm: bool = True,
+    ):
+        # https://github.com/mseitzer/pytorch-fid/src/pytorch_fid/fid_score.py
+        self.dims = dims
+        self.device = device
+        self.num_workers = num_workers
+        self.batch_size = batch_size
+        self.disable_tqdm = disable_tqdm
+        self.block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
+        self.model = InceptionV3([self.block_idx]).to(self.device)
+        self.model = self.model.eval()
+
+    def compute_fid(
+        self,
+        image_true: np.ndarray | str,
+        image_test: np.ndarray | str,
+    ):
+        """
+        Calculates the FID of two file paths
+        FID = FrechetInceptionDistance()
+        img_fid = FID.compute_fid("img_true.png", "img_test.png")
+        img_dir_fid = FID.compute_fid("img_true_dir", "img_test_dir")
+        """
+        if isinstance(image_true, str) or isinstance(image_test, str):
+            if os.path.isfile(image_true) or os.path.isfile(image_test):
+                assert os.path.exists(image_true)
+                assert os.path.exists(image_test)
+                assert image_true.split(".")[-1] in _IMAGE_EXTENSIONS
+                assert image_test.split(".")[-1] in _IMAGE_EXTENSIONS
+                image_true_files = [image_true]
+                image_test_files = [image_test]
+            else:
+                # glob image files from dir
+                assert os.path.isdir(image_true)
+                assert os.path.isdir(image_test)
+                image_true_dir = pathlib.Path(image_true)
+                image_true_files = sorted(
+                    [
+                        file
+                        for ext in _IMAGE_EXTENSIONS
+                        for file in image_true_dir.rglob("*.{}".format(ext))
+                    ]
+                )
+                image_test_dir = pathlib.Path(image_test)
+                image_test_files = sorted(
+                    [
+                        file
+                        for ext in _IMAGE_EXTENSIONS
+                        for file in image_test_dir.rglob("*.{}".format(ext))
+                    ]
+                )
+                image_true_files = [
+                    file.as_posix() for file in image_true_files
+                ]
+                image_test_files = [
+                    file.as_posix() for file in image_test_files
+                ]
+
+                # select valid files
+                image_true_files_selected = []
+                image_test_files_selected = []
+                for i in range(
+                    min(len(image_true_files), len(image_test_files))
+                ):
+                    selected_image_true = image_true_files[i]
+                    selected_image_test = image_test_files[i]
+                    # Image pair must have the same basename
+                    if os.path.basename(
+                        selected_image_test
+                    ) == os.path.basename(selected_image_true):
+                        image_true_files_selected.append(selected_image_true)
+                        image_test_files_selected.append(selected_image_test)
+                image_true_files = image_true_files_selected.copy()
+                image_test_files = image_test_files_selected.copy()
+                if len(image_true_files) == 0:
+                    logger.error(
+                        "No valid Image pairs, please note that Image "
+                        "pairs must have the same basename."
+                    )
+                    return None, None
+
+                logger.debug(f"image_true_files: {image_true_files}")
+                logger.debug(f"image_test_files: {image_test_files}")
+        else:
+            image_true_files = [image_true]
+            image_test_files = [image_test]
+
+        batch_size = min(16, self.batch_size)
+        batch_size = min(batch_size, len(image_test_files))
+        m1, s1 = calculate_activation_statistics(
+            image_true_files,
+            self.model,
+            batch_size,
+            self.dims,
+            self.device,
+            self.num_workers,
+            self.disable_tqdm,
+        )
+        m2, s2 = calculate_activation_statistics(
+            image_test_files,
+            self.model,
+            batch_size,
+            self.dims,
+            self.device,
+            self.num_workers,
+            self.disable_tqdm,
+        )
+        fid_value = calculate_frechet_distance(
+            m1,
+            s1,
+            m2,
+            s2,
+        )
+
+        return fid_value, len(image_true_files)
+
+    def compute_video_fid(
+        self,
+        # file or dir
+        video_true: str,
+        video_test: str,
+    ):
+        if os.path.isfile(video_true) and os.path.isfile(video_test):
+            video_true_frames, video_test_frames, valid_frames = (
+                self._fetch_video_frames(
+                    video_true=video_true,
+                    video_test=video_test,
+                )
+            )
+        elif os.path.isdir(video_true) and os.path.isdir(video_test):
+            # Glob videos
+            video_true_dir: pathlib.Path = pathlib.Path(video_true)
+            video_true_files = sorted(
+                [
+                    file
+                    for ext in _VIDEO_EXTENSIONS
+                    for file in video_true_dir.rglob("*.{}".format(ext))
+                ]
+            )
+            video_test_dir: pathlib.Path = pathlib.Path(video_test)
+            video_test_files = sorted(
+                [
+                    file
+                    for ext in _VIDEO_EXTENSIONS
+                    for file in video_test_dir.rglob("*.{}".format(ext))
+                ]
+            )
+            video_true_files = [file.as_posix() for file in video_true_files]
+            video_test_files = [file.as_posix() for file in video_test_files]
+
+            # select valid video files
+            video_true_files_selected = []
+            video_test_files_selected = []
+            for i in range(min(len(video_true_files), len(video_test_files))):
+                selected_video_true = video_true_files[i]
+                selected_video_test = video_test_files[i]
+                # Video pair must have the same basename
+                if os.path.basename(selected_video_test) == os.path.basename(
+                    selected_video_true
+                ):
+                    video_true_files_selected.append(selected_video_true)
+                    video_test_files_selected.append(selected_video_test)
+
+            video_true_files = video_true_files_selected.copy()
+            video_test_files = video_test_files_selected.copy()
+            if len(video_true_files) == 0:
+                logger.error(
+                    "No valid Video pairs, please note that Video "
+                    "pairs must have the same basename."
+                )
+                return None, None
+            logger.debug(f"video_true_files: {video_true_files}")
+            logger.debug(f"video_test_files: {video_test_files}")
+
+            # Fetch all frames
+            video_true_frames = []
+            video_test_frames = []
+            valid_frames = 0
+
+            for video_true_, video_test_ in zip(
+                video_true_files, video_test_files
+            ):
+                video_true_frames_, video_test_frames_, valid_frames_ = (
+                    self._fetch_video_frames(
+                        video_true=video_true_, video_test=video_test_
+                    )
+                )
+                video_true_frames.extend(video_true_frames_)
+                video_test_frames.extend(video_test_frames_)
+                valid_frames += valid_frames_
+        else:
+            raise ValueError("video_true and video_test must be files or dirs.")
+
+        if valid_frames <= 0:
+            logger.debug("No valid frames to compare")
+            return None, None
+
+        batch_size = min(16, self.batch_size)
+        m1, s1 = calculate_activation_statistics(
+            video_true_frames,
+            self.model,
+            batch_size,
+            self.dims,
+            self.device,
+            self.num_workers,
+            self.disable_tqdm,
+        )
+        m2, s2 = calculate_activation_statistics(
+            video_test_frames,
+            self.model,
+            batch_size,
+            self.dims,
+            self.device,
+            self.num_workers,
+            self.disable_tqdm,
+        )
+        fid_value = calculate_frechet_distance(
+            m1,
+            s1,
+            m2,
+            s2,
+        )
+
+        return fid_value, valid_frames
+
+    def _fetch_video_frames(
+        self,
+        video_true: str,
+        video_test: str,
+    ):
+        cap1 = cv2.VideoCapture(video_true)
+        cap2 = cv2.VideoCapture(video_test)
+
+        if not cap1.isOpened() or not cap2.isOpened():
+            logger.error("Could not open video files")
+            return [], [], 0
+
+        frame_count = min(
+            int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
+            int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)),
+        )
+
+        valid_frames = 0
+        video_true_frames = []
+        video_test_frames = []
+
+        logger.debug(f"Total frames: {frame_count}")
+
+        while True:
+            ret1, frame1 = cap1.read()
+            ret2, frame2 = cap2.read()
+
+            if not ret1 or not ret2:
+                break
+
+            video_true_frames.append(frame1)
+            video_test_frames.append(frame2)
+
+            valid_frames += 1
+
+        cap1.release()
+        cap2.release()
+
+        if valid_frames <= 0:
+            return [], [], 0
+
+        return video_true_frames, video_test_frames, valid_frames
```
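`fid.py` is a close adaptation of pytorch-fid: activations are pooled InceptionV3 features, and the score is the Frechet distance d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)), with an eps*I offset added to the covariances when their product is near-singular. Pairing is by identical basenames in directory mode, and videos are read frame-by-frame in lockstep until the shorter stream ends. A usage sketch with illustrative paths:

```python
from cache_dit.metrics import FrechetInceptionDistance

# Note: compute_fid/compute_video_fid internally cap the batch size at 16,
# and the device falls back to CPU when CUDA is unavailable.
FID = FrechetInceptionDistance(batch_size=4, disable_tqdm=False)

# Directory mode: images are paired by basename, e.g.
# true_dir/0001.png <-> test_dir/0001.png (paths here are illustrative).
img_fid, n_pairs = FID.compute_fid("true_dir", "test_dir")
if img_fid is not None:
    print(f"FID over {n_pairs} image pairs: {img_fid:.3f}")

# Video mode: accepts two .mp4 files or two directories of videos.
video_fid, n_frames = FID.compute_video_fid("true.mp4", "test.mp4")
```

One design note: `compute_video_fid` keeps every decoded frame in memory before computing activation statistics, so long or high-resolution videos can be memory-hungry.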