cache-dit 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,353 @@
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from cache_dit.logger import init_logger

logger = init_logger(__name__)


# Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Inception weights ported to PyTorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth"  # noqa: E501

class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of the default block of inception to return;
    # corresponds to the output of the final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output block indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,  # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3,  # Final average pooling features
    }

    def __init__(
        self,
        output_blocks=(DEFAULT_BLOCK_INDEX,),
        resize_input=True,
        normalize_input=True,
        requires_grad=False,
        use_fid_inception=True,
    ):
        """Build pretrained InceptionV3

        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
            - 0: corresponds to the output of the first max pooling
            - 1: corresponds to the output of the second max pooling
            - 2: corresponds to the output which is fed to the aux classifier
            - 3: corresponds to the output of the final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding it to the model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed.
        normalize_input : bool
            If true, scales the input from range (0, 1) to the range the
            pretrained Inception network expects, namely (-1, 1).
        requires_grad : bool
            If true, parameters of the model require gradients. Possibly
            useful for finetuning the network.
        use_fid_inception : bool
            If true, uses the pretrained Inception model used in Tensorflow's
            FID implementation. If false, uses the pretrained Inception model
            available in torchvision. The FID Inception model has different
            weights and a slightly different structure from torchvision's
            Inception model. If you want to compute FID scores, you are
            strongly advised to set this parameter to true to get comparable
            results.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert (
            self.last_needed_block <= 3
        ), "Last possible output block index is 3"

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            inception = _inception_v3(weights="DEFAULT")

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2),
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)

        Returns
        -------
        List of torch.Tensor, corresponding to the selected output
        blocks, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            x = F.interpolate(
                x, size=(299, 299), mode="bilinear", align_corners=False
            )

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp
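
For orientation, a brief usage sketch of this network (editor's illustration, not part of the packaged file; the import path is an assumption based on the package layout seen elsewhere in this diff). It pulls the 2048-dimensional final-average-pool features that FID is computed from; the first call downloads the FID weights.

import torch
from cache_dit.metrics.inception import InceptionV3  # hypothetical module path

block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]        # final avgpool block
model = InceptionV3(output_blocks=(block_idx,)).eval()  # FID weights by default
with torch.no_grad():
    batch = torch.rand(4, 3, 299, 299)  # values in (0, 1), as the docstring expects
    features = model(batch)[0]          # shape (4, 2048, 1, 1)
features = features.squeeze(-1).squeeze(-1)  # (4, 2048) per-image feature vectors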


def _inception_v3(*args, **kwargs):
    """Wraps `torchvision.models.inception_v3`"""
    try:
        version = tuple(map(int, torchvision.__version__.split(".")[:2]))
    except ValueError:
        # Just a caution against weird version strings
        version = (0,)

    # Skips default weight initialization if supported by the torchvision
    # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
    if version >= (0, 6):
        kwargs["init_weights"] = False

    # Backwards compatibility: the `weights` argument was handled by the
    # `pretrained` argument prior to version 0.13.
    if version < (0, 13) and "weights" in kwargs:
        if kwargs["weights"] == "DEFAULT":
            kwargs["pretrained"] = True
        elif kwargs["weights"] is None:
            kwargs["pretrained"] = False
        else:
            raise ValueError(
                "weights=={} not supported in torchvision {}".format(
                    kwargs["weights"], torchvision.__version__
                )
            )
        del kwargs["weights"]

    return torchvision.models.inception_v3(*args, **kwargs)


def fid_inception_v3():
    """Build pretrained Inception model for FID computation

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None)
    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False)
    inception.load_state_dict(state_dict)
    return inception


class FIDInceptionA(torchvision.models.inception.InceptionA):
    """InceptionA block patched for FID computation"""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(torchvision.models.inception.InceptionC):
    """InceptionC block patched for FID computation"""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(torchvision.models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(
            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
        )
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(torchvision.models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
        # implementation, as other Inception models use average pooling here
        # (which matches the description in the paper).
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
@@ -0,0 +1,356 @@
import os
import cv2
import pathlib
import argparse
import numpy as np
from functools import partial
from skimage.metrics import mean_squared_error
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity
from cache_dit.metrics.fid import FrechetInceptionDistance
from cache_dit.logger import init_logger

logger = init_logger(__name__)


def compute_psnr_file(
    image_true: np.ndarray | str,
    image_test: np.ndarray | str,
) -> float:
    """Compute PSNR between two images given as arrays or file paths.

    img_true = cv2.imread(img_true_file)
    img_test = cv2.imread(img_test_file)
    PSNR = compute_psnr_file(img_true, img_test)
    """
    if isinstance(image_true, str):
        image_true = cv2.imread(image_true)
    if isinstance(image_test, str):
        image_test = cv2.imread(image_test)
    return peak_signal_noise_ratio(
        image_true,
        image_test,
    )


def compute_mse_file(
    image_true: np.ndarray | str,
    image_test: np.ndarray | str,
) -> float:
    """Compute MSE between two images given as arrays or file paths.

    img_true = cv2.imread(img_true_file)
    img_test = cv2.imread(img_test_file)
    MSE = compute_mse_file(img_true, img_test)
    """
    if isinstance(image_true, str):
        image_true = cv2.imread(image_true)
    if isinstance(image_test, str):
        image_test = cv2.imread(image_test)
    return mean_squared_error(
        image_true,
        image_test,
    )


def compute_ssim_file(
    image_true: np.ndarray | str,
    image_test: np.ndarray | str,
) -> float:
    """Compute SSIM between two images given as arrays or file paths.

    img_true = cv2.imread(img_true_file)
    img_test = cv2.imread(img_test_file)
    SSIM = compute_ssim_file(img_true, img_test)
    """
    if isinstance(image_true, str):
        image_true = cv2.imread(image_true)
    if isinstance(image_test, str):
        image_test = cv2.imread(image_test)
    # `multichannel=True` was deprecated and later removed in scikit-image;
    # `channel_axis` is the supported way to mark the color axis.
    return structural_similarity(
        image_true,
        image_test,
        channel_axis=2,
    )
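
A quick sketch of the single-image helpers (editor's illustration with synthetic arrays; the helpers accept arrays or file paths interchangeably):

img_a = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
img_b = img_a.copy()
img_b[0, 0, 0] ^= 1                     # perturb one value so PSNR stays finite
print(compute_psnr_file(img_a, img_b))  # large PSNR for near-identical images
print(compute_mse_file(img_a, img_b))   # tiny MSE
print(compute_ssim_file(img_a, img_b))  # SSIM close to 1.0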


_IMAGE_EXTENSIONS = {
    "bmp",
    "jpg",
    "jpeg",
    "pgm",
    "png",
    "ppm",
    "tif",
    "tiff",
    "webp",
}


def compute_dir_metric(
    image_true_dir: np.ndarray | str,
    image_test_dir: np.ndarray | str,
    compute_file_func: callable = compute_psnr_file,
) -> tuple:
    # In-memory image
    if isinstance(image_true_dir, np.ndarray) or isinstance(
        image_test_dir, np.ndarray
    ):
        return compute_file_func(image_true_dir, image_test_dir), 1
    # Single file
    if not os.path.isdir(image_true_dir) or not os.path.isdir(image_test_dir):
        return compute_file_func(image_true_dir, image_test_dir), 1
    # Directory
    image_true_dir: pathlib.Path = pathlib.Path(image_true_dir)
    image_true_files = sorted(
        [
            file
            for ext in _IMAGE_EXTENSIONS
            for file in image_true_dir.rglob("*.{}".format(ext))
        ]
    )
    image_test_dir: pathlib.Path = pathlib.Path(image_test_dir)
    image_test_files = sorted(
        [
            file
            for ext in _IMAGE_EXTENSIONS
            for file in image_test_dir.rglob("*.{}".format(ext))
        ]
    )
    image_true_files = [file.as_posix() for file in image_true_files]
    image_test_files = [file.as_posix() for file in image_test_files]
    logger.debug(f"image_true_files: {image_true_files}")
    logger.debug(f"image_test_files: {image_test_files}")
    assert len(image_true_files) == len(image_test_files)
    for image_true, image_test in zip(image_true_files, image_test_files):
        assert os.path.basename(image_true) == os.path.basename(
            image_test
        ), f"image_true:{image_true} != image_test: {image_test}"

    total_metric = 0.0
    valid_files = 0
    for image_true, image_test in zip(image_true_files, image_test_files):
        metric = compute_file_func(image_true, image_test)
        if metric != float("inf"):
            total_metric += metric
            valid_files += 1

    if valid_files > 0:
        average_metric = total_metric / valid_files
        logger.debug(f"Average: {average_metric:.2f}")
        return average_metric, valid_files
    else:
        logger.debug("No valid files to compare")
        return None, None
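
Directory mode pairs files by basename after a recursive glob, so both trees must contain identically named images. A sketch (editor's illustration; the paths are hypothetical and must exist):

avg_psnr, num_images = compute_dir_metric(
    "images/ground_truth",  # hypothetical ground-truth directory
    "images/predicted",     # hypothetical test directory
    compute_file_func=compute_psnr_file,
)
print(f"PSNR over {num_images} images: {avg_psnr:.2f}")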


def compute_video_metric(
    video_true: str,
    video_test: str,
    compute_frame_func: callable = compute_psnr_file,
) -> tuple:
    """
    video_true = "video_true.mp4"
    video_test = "video_test.mp4"
    PSNR, n = compute_video_psnr(video_true, video_test)
    """
    cap1 = cv2.VideoCapture(video_true)
    cap2 = cv2.VideoCapture(video_test)

    if not cap1.isOpened() or not cap2.isOpened():
        logger.error("Could not open video files")
        return None, None

    frame_count = min(
        int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
        int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)),
    )

    total_metric = 0.0
    valid_frames = 0

    logger.debug(f"Total frames: {frame_count}")

    while True:
        ret1, frame1 = cap1.read()
        ret2, frame2 = cap2.read()

        if not ret1 or not ret2:
            break

        metric = compute_frame_func(frame1, frame2)

        if metric != float("inf"):
            total_metric += metric
            valid_frames += 1

            if valid_frames % 10 == 0:
                logger.debug(f"Processed {valid_frames}/{frame_count} frames")

    cap1.release()
    cap2.release()

    if valid_frames > 0:
        average_metric = total_metric / valid_frames
        logger.debug(f"Average: {average_metric:.2f}")
        return average_metric, valid_frames
    else:
        logger.debug("No valid frames to compare")
        return None, None


compute_psnr = partial(
    compute_dir_metric,
    compute_file_func=compute_psnr_file,
)

compute_ssim = partial(
    compute_dir_metric,
    compute_file_func=compute_ssim_file,
)

compute_mse = partial(
    compute_dir_metric,
    compute_file_func=compute_mse_file,
)

compute_video_psnr = partial(
    compute_video_metric,
    compute_frame_func=compute_psnr_file,
)
compute_video_ssim = partial(
    compute_video_metric,
    compute_frame_func=compute_ssim_file,
)
compute_video_mse = partial(
    compute_video_metric,
    compute_frame_func=compute_mse_file,
)
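
These partials give every metric the same (average, count) return shape, e.g. (editor's illustration; inputs are hypothetical):

avg_psnr, num_images = compute_psnr("images/ground_truth", "images/predicted")
avg_ssim, num_frames = compute_video_ssim("video_true.mp4", "video_test.mp4")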


# Entrypoints
def get_args():
    parser = argparse.ArgumentParser(
        description="CacheDiT's Metrics CLI",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    METRICS_CHOICES = [
        "psnr",
        "ssim",
        "mse",
        "fid",
        "all",
    ]
    parser.add_argument(
        "metric",
        type=str,
        default="psnr",
        choices=METRICS_CHOICES,
        help=f"Metric choices: {METRICS_CHOICES}",
    )
    parser.add_argument(
        "--img-true",
        "-i1",
        type=str,
        default=None,
        help="Path to a ground truth image or a directory of ground truth images",
    )
    parser.add_argument(
        "--img-test",
        "-i2",
        type=str,
        default=None,
        help="Path to a predicted image or a directory of predicted images",
    )
    parser.add_argument(
        "--video-true",
        "-v1",
        type=str,
        default=None,
        help="Path to ground truth video",
    )
    parser.add_argument(
        "--video-test",
        "-v2",
        type=str,
        default=None,
        help="Path to predicted video",
    )
    return parser.parse_args()


def entrypoint():
    args = get_args()
    logger.debug(args)

    if args.img_true is not None and args.img_test is not None:
        if any(
            (
                not os.path.exists(args.img_true),
                not os.path.exists(args.img_test),
            )
        ):
            return
        # img_true and img_test can be files or dirs
        if args.metric == "psnr" or args.metric == "all":
            img_psnr, n = compute_psnr(args.img_true, args.img_test)
            logger.info(
                f"{args.img_true} vs {args.img_test}, Num: {n}, PSNR: {img_psnr}"
            )
        if args.metric == "ssim" or args.metric == "all":
            img_ssim, n = compute_ssim(args.img_true, args.img_test)
            logger.info(
                f"{args.img_true} vs {args.img_test}, Num: {n}, SSIM: {img_ssim}"
            )
        if args.metric == "mse" or args.metric == "all":
            img_mse, n = compute_mse(args.img_true, args.img_test)
            logger.info(
                f"{args.img_true} vs {args.img_test}, Num: {n}, MSE: {img_mse}"
            )
        if args.metric == "fid" or args.metric == "all":
            FID = FrechetInceptionDistance()
            img_fid, n = FID.compute_fid(args.img_true, args.img_test)
            logger.info(
                f"{args.img_true} vs {args.img_test}, Num: {n}, FID: {img_fid}"
            )
    if args.video_true is not None and args.video_test is not None:
        if any(
            (
                not os.path.exists(args.video_true),
                not os.path.exists(args.video_test),
            )
        ):
            return
        if args.metric == "psnr" or args.metric == "all":
            assert not os.path.isdir(args.video_true)
            assert not os.path.isdir(args.video_test)
            video_psnr, n = compute_video_psnr(args.video_true, args.video_test)
            logger.info(
                f"{args.video_true} vs {args.video_test}, Num: {n}, PSNR: {video_psnr}"
            )
        if args.metric == "ssim" or args.metric == "all":
            assert not os.path.isdir(args.video_true)
            assert not os.path.isdir(args.video_test)
            video_ssim, n = compute_video_ssim(args.video_true, args.video_test)
            logger.info(
                f"{args.video_true} vs {args.video_test}, Num: {n}, SSIM: {video_ssim}"
            )
        if args.metric == "mse" or args.metric == "all":
            assert not os.path.isdir(args.video_true)
            assert not os.path.isdir(args.video_test)
            video_mse, n = compute_video_mse(args.video_true, args.video_test)
            logger.info(
                f"{args.video_true} vs {args.video_test}, Num: {n}, MSE: {video_mse}"
            )
        if args.metric == "fid" or args.metric == "all":
            assert not os.path.isdir(args.video_true)
            assert not os.path.isdir(args.video_test)
            FID = FrechetInceptionDistance()
            video_fid, n = FID.compute_video_fid(
                args.video_true, args.video_test
            )
            logger.info(
                f"{args.video_true} vs {args.video_test}, Num: {n}, FID: {video_fid}"
            )


if __name__ == "__main__":
    entrypoint()
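
A sketch of driving the CLI in-process (editor's illustration; the installed console-script name is not shown in this diff, and the paths are hypothetical and must exist for any output to be produced):

import sys

# argparse reads sys.argv[1:], so patching it exercises the same code path
# as a command-line invocation.
sys.argv = ["metrics", "psnr", "-i1", "images/ground_truth", "-i2", "images/predicted"]
entrypoint()  # logs "<true> vs <test>, Num: N, PSNR: ..." per the format above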