cache-dit 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit has been flagged as potentially problematic.

cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.2.5'
- __version_tuple__ = version_tuple = (0, 2, 5)
+ __version__ = version = '0.2.6'
+ __version_tuple__ = version_tuple = (0, 2, 6)
cache_dit/metrics/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from cache_dit.metrics.metrics import compute_psnr
+ from cache_dit.metrics.metrics import compute_ssim
+ from cache_dit.metrics.metrics import compute_mse
+ from cache_dit.metrics.metrics import compute_video_psnr
+ from cache_dit.metrics.metrics import compute_video_ssim
+ from cache_dit.metrics.metrics import compute_video_mse
+ from cache_dit.metrics.metrics import entrypoint
+ from cache_dit.metrics.fid import FrechetInceptionDistance
+
+
+ def main():
+     entrypoint()
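
For context: the `main()` wrapper above is the target of the new `cache-dit-metrics-cli` console script registered in `entry_points.txt` at the end of this diff. A minimal sketch of driving the same entrypoint programmatically; the `sys.argv` patching and file paths are illustrative only, not part of the package:

```python
import sys
from cache_dit.metrics import main

# entrypoint() parses sys.argv via argparse, so emulate a CLI invocation.
# Equivalent to: cache-dit-metrics-cli psnr -i1 true.png -i2 test.png
sys.argv = ["cache-dit-metrics-cli", "psnr", "-i1", "true.png", "-i2", "test.png"]
main()
```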
cache_dit/metrics/fid.py ADDED
@@ -0,0 +1,409 @@
+ import os
+ import cv2
+ import pathlib
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+ from scipy import linalg
+ import torch
+ import torchvision.transforms as TF
+ from torch.nn.functional import adaptive_avg_pool2d
+ from cache_dit.metrics.inception import InceptionV3
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class ImagePathDataset(torch.utils.data.Dataset):
+     def __init__(self, files_or_imgs, transforms=None):
+         self.files_or_imgs = files_or_imgs
+         self.transforms = transforms
+
+     def __len__(self):
+         return len(self.files_or_imgs)
+
+     def __getitem__(self, i):
+         file_or_img = self.files_or_imgs[i]
+         if isinstance(file_or_img, (str, pathlib.Path)):
+             img = Image.open(file_or_img).convert("RGB")
+         elif isinstance(file_or_img, np.ndarray):
+             # Assume the img is a standard OpenCV image.
+             img = cv2.cvtColor(file_or_img, cv2.COLOR_BGR2RGB)
+             img = Image.fromarray(img)
+         else:
+             raise ValueError(
+                 "file_or_img must be a file path or an OpenCV image."
+             )
+         if self.transforms is not None:
+             img = self.transforms(img)
+         return img
+
+
+ def get_activations(
+     files_or_imgs,
+     model,
+     batch_size=50,
+     dims=2048,
+     device="cpu",
+     num_workers=1,
+     disable_tqdm=True,
+ ):
+     """Calculates the activations of the pool_3 layer for all images.
+
+     Params:
+     -- files_or_imgs : List of image file paths or OpenCV images
+     -- model         : Instance of inception model
+     -- batch_size    : Batch size of images for the model to process at once.
+                        Make sure that the number of samples is a multiple of
+                        the batch size, otherwise some samples are ignored. This
+                        behavior is retained to match the original FID score
+                        implementation.
+     -- dims          : Dimensionality of features returned by Inception
+     -- device        : Device to run calculations
+     -- num_workers   : Number of parallel dataloader workers
+
+     Returns:
+     -- A numpy array of dimension (num images, dims) that contains the
+        activations of the given tensor when feeding inception with the
+        query tensor.
+     """
+     model.eval()
+
+     if batch_size > len(files_or_imgs):
+         logger.info(
+             (
+                 "Warning: batch size is bigger than the data size. "
+                 "Setting batch size to data size"
+             )
+         )
+         batch_size = len(files_or_imgs)
+
+     dataset = ImagePathDataset(files_or_imgs, transforms=TF.ToTensor())
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         drop_last=False,
+         num_workers=num_workers,
+     )
+
+     pred_arr = np.empty((len(files_or_imgs), dims))
+
+     start_idx = 0
+
+     for batch in tqdm(dataloader, disable=disable_tqdm):
+         batch = batch.to(device)
+
+         with torch.no_grad():
+             pred = model(batch)[0]
+
+         # If model output is not scalar, apply global spatial average pooling.
+         # This happens if you choose a dimensionality not equal 2048.
+         if pred.size(2) != 1 or pred.size(3) != 1:
+             pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+         pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+         pred_arr[start_idx : start_idx + pred.shape[0]] = pred
+
+         start_idx = start_idx + pred.shape[0]
+
+     return pred_arr
+
+
+ def calculate_frechet_distance(
+     mu1,
+     sigma1,
+     mu2,
+     sigma2,
+     eps=1e-6,
+ ):
+     """Numpy implementation of the Frechet Distance.
+     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+     and X_2 ~ N(mu_2, C_2) is
+     d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+     Stable version by Dougal J. Sutherland.
+
+     Params:
+     -- mu1   : Numpy array containing the activations of a layer of the
+                inception net (like returned by the function 'get_predictions')
+                for generated samples.
+     -- mu2   : The sample mean over activations, precalculated on a
+                representative data set.
+     -- sigma1: The covariance matrix over activations for generated samples.
+     -- sigma2: The covariance matrix over activations, precalculated on a
+                representative data set.
+
+     Returns:
+     --   : The Frechet Distance.
+     """
+
+     mu1 = np.atleast_1d(mu1)
+     mu2 = np.atleast_1d(mu2)
+
+     sigma1 = np.atleast_2d(sigma1)
+     sigma2 = np.atleast_2d(sigma2)
+
+     assert (
+         mu1.shape == mu2.shape
+     ), "Training and test mean vectors have different lengths"
+     assert (
+         sigma1.shape == sigma2.shape
+     ), "Training and test covariances have different dimensions"
+
+     diff = mu1 - mu2
+
+     # Product might be almost singular
+     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+     if not np.isfinite(covmean).all():
+         msg = (
+             "fid calculation produces singular product; "
+             "adding %s to diagonal of cov estimates"
+         ) % eps
+         print(msg)
+         offset = np.eye(sigma1.shape[0]) * eps
+         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+     # Numerical error might give slight imaginary component
+     if np.iscomplexobj(covmean):
+         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+             m = np.max(np.abs(covmean.imag))
+             raise ValueError("Imaginary component {}".format(m))
+         covmean = covmean.real
+
+     tr_covmean = np.trace(covmean)
+
+     return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+ def calculate_activation_statistics(
+     files_or_imgs,
+     model,
+     batch_size=50,
+     dims=2048,
+     device="cpu",
+     num_workers=1,
+     disable_tqdm=True,
+ ):
+     """Calculation of the statistics used by the FID.
+     Params:
+     -- files_or_imgs : List of image file paths or OpenCV images
+     -- model         : Instance of inception model
+     -- batch_size    : Batch size of images for the model to process at once.
+                        Make sure that the number of samples is a multiple of
+                        the batch size, otherwise some samples are ignored. This
+                        behavior is retained to match the original FID score
+                        implementation.
+     -- dims          : Dimensionality of features returned by Inception
+     -- device        : Device to run calculations
+     -- num_workers   : Number of parallel dataloader workers
+
+     Returns:
+     -- mu    : The mean over samples of the activations of the pool_3 layer of
+                the inception model.
+     -- sigma : The covariance matrix of the activations of the pool_3 layer of
+                the inception model.
+     """
+     act = get_activations(
+         files_or_imgs,
+         model,
+         batch_size,
+         dims,
+         device,
+         num_workers,
+         disable_tqdm,
+     )
+     mu = np.mean(act, axis=0)
+     sigma = np.cov(act, rowvar=False)
+     return mu, sigma
+
+
+ _IMAGE_EXTENSIONS = {
+     "bmp",
+     "jpg",
+     "jpeg",
+     "pgm",
+     "png",
+     "ppm",
+     "tif",
+     "tiff",
+     "webp",
+ }
+
+
+ class FrechetInceptionDistance:
+     IMAGE_EXTENSIONS = _IMAGE_EXTENSIONS
+
+     def __init__(
+         self,
+         device="cuda" if torch.cuda.is_available() else "cpu",
+         dims: int = 2048,
+         num_workers: int = 1,
+         batch_size: int = 1,
+         disable_tqdm: bool = True,
+     ):
+         # https://github.com/mseitzer/pytorch-fid/src/pytorch_fid/fid_score.py
+         self.dims = dims
+         self.device = device
+         self.num_workers = num_workers
+         self.batch_size = batch_size
+         self.disable_tqdm = disable_tqdm
+         self.block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
+         self.model = InceptionV3([self.block_idx]).to(self.device)
+         self.model = self.model.eval()
+
+     def compute_fid(
+         self,
+         image_true: np.ndarray | str,
+         image_test: np.ndarray | str,
+     ):
+         """Calculates the FID of two file paths
+         FID = FrechetInceptionDistance()
+         img_fid = FID.compute_fid("img_true.png", "img_test.png")
+         img_dir_fid = FID.compute_fid("img_true_dir", "img_test_dir")
+         """
+         if isinstance(image_true, str) or isinstance(image_test, str):
+             if os.path.isfile(image_true) or os.path.isfile(image_test):
+                 assert os.path.exists(image_true)
+                 assert os.path.exists(image_test)
+                 assert image_true.split(".")[-1] in self.IMAGE_EXTENSIONS
+                 assert image_test.split(".")[-1] in self.IMAGE_EXTENSIONS
+                 image_true_files = [image_true]
+                 image_test_files = [image_test]
+             else:
+                 # glob image files from dir
+                 assert os.path.isdir(image_true)
+                 assert os.path.isdir(image_test)
+                 image_true_dir = pathlib.Path(image_true)
+                 image_true_files = sorted(
+                     [
+                         file
+                         for ext in self.IMAGE_EXTENSIONS
+                         for file in image_true_dir.rglob("*.{}".format(ext))
+                     ]
+                 )
+                 image_test_dir = pathlib.Path(image_test)
+                 image_test_files = sorted(
+                     [
+                         file
+                         for ext in self.IMAGE_EXTENSIONS
+                         for file in image_test_dir.rglob("*.{}".format(ext))
+                     ]
+                 )
+                 image_true_files = [
+                     file.as_posix() for file in image_true_files
+                 ]
+                 image_test_files = [
+                     file.as_posix() for file in image_test_files
+                 ]
+                 logger.debug(f"image_true_files: {image_true_files}")
+                 logger.debug(f"image_test_files: {image_test_files}")
+                 assert len(image_true_files) == len(image_test_files)
+                 for image_true, image_test in zip(
+                     image_true_files, image_test_files
+                 ):
+                     assert os.path.basename(image_true) == os.path.basename(
+                         image_test
+                     ), f"image_true:{image_true} != image_test: {image_test}"
+         else:
+             image_true_files = [image_true]
+             image_test_files = [image_test]
+
+         batch_size = min(16, self.batch_size)
+         batch_size = min(batch_size, len(image_test_files))
+         m1, s1 = calculate_activation_statistics(
+             image_true_files,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         m2, s2 = calculate_activation_statistics(
+             image_test_files,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         fid_value = calculate_frechet_distance(
+             m1,
+             s1,
+             m2,
+             s2,
+         )
+
+         return fid_value, len(image_true_files)
+
+     def compute_video_fid(
+         self,
+         video_true: str,
+         video_test: str,
+     ):
+         cap1 = cv2.VideoCapture(video_true)
+         cap2 = cv2.VideoCapture(video_test)
+
+         if not cap1.isOpened() or not cap2.isOpened():
+             logger.error("Could not open video files")
+             return None, None
+
+         frame_count = min(
+             int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
+             int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)),
+         )
+
+         valid_frames = 0
+         video_true_frames = []
+         video_test_frames = []
+
+         logger.debug(f"Total frames: {frame_count}")
+
+         while True:
+             ret1, frame1 = cap1.read()
+             ret2, frame2 = cap2.read()
+
+             if not ret1 or not ret2:
+                 break
+
+             video_true_frames.append(frame1)
+             video_test_frames.append(frame2)
+
+             valid_frames += 1
+
+         cap1.release()
+         cap2.release()
+
+         if valid_frames <= 0:
+             return None, None
+
+         batch_size = min(16, self.batch_size)
+         m1, s1 = calculate_activation_statistics(
+             video_true_frames,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         m2, s2 = calculate_activation_statistics(
+             video_test_frames,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         fid_value = calculate_frechet_distance(
+             m1,
+             s1,
+             m2,
+             s2,
+         )
+
+         return fid_value, valid_frames
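
As a usage note: both code paths above feed `calculate_frechet_distance`, i.e. `d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2))` over the pooled Inception activations, and every method returns the metric together with the number of samples compared. A minimal sketch, with placeholder paths:

```python
from cache_dit.metrics.fid import FrechetInceptionDistance

FID = FrechetInceptionDistance()  # InceptionV3 on CUDA when available

# Directories are globbed recursively for known image extensions and
# paired by basename before activation statistics are computed.
dir_fid, n = FID.compute_fid("true_dir", "test_dir")

# Videos are decoded frame by frame until the shorter stream ends.
video_fid, n_frames = FID.compute_video_fid("true.mp4", "test.mp4")
```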
cache_dit/metrics/inception.py ADDED
@@ -0,0 +1,353 @@
+ import torch
+ import torchvision
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ # Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
+ try:
+     from torchvision.models.utils import load_state_dict_from_url
+ except ImportError:
+     from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+ # Inception weights ported to Pytorch from
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ FID_WEIGHTS_URL = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth"  # noqa: E501
+
+
+ class InceptionV3(nn.Module):
+     """Pretrained InceptionV3 network returning feature maps"""
+
+     # Index of default block of inception to return,
+     # corresponds to output of final average pooling
+     DEFAULT_BLOCK_INDEX = 3
+
+     # Maps feature dimensionality to their output blocks indices
+     BLOCK_INDEX_BY_DIM = {
+         64: 0,  # First max pooling features
+         192: 1,  # Second max pooling features
+         768: 2,  # Pre-aux classifier features
+         2048: 3,  # Final average pooling features
+     }
+
+     def __init__(
+         self,
+         output_blocks=(DEFAULT_BLOCK_INDEX,),
+         resize_input=True,
+         normalize_input=True,
+         requires_grad=False,
+         use_fid_inception=True,
+     ):
+         """Build pretrained InceptionV3
+
+         Parameters
+         ----------
+         output_blocks : list of int
+             Indices of blocks to return features of. Possible values are:
+                 - 0: corresponds to output of first max pooling
+                 - 1: corresponds to output of second max pooling
+                 - 2: corresponds to output which is fed to aux classifier
+                 - 3: corresponds to output of final average pooling
+         resize_input : bool
+             If true, bilinearly resizes input to width and height 299 before
+             feeding input to model. As the network without fully connected
+             layers is fully convolutional, it should be able to handle inputs
+             of arbitrary size, so resizing might not be strictly needed
+         normalize_input : bool
+             If true, scales the input from range (0, 1) to the range the
+             pretrained Inception network expects, namely (-1, 1)
+         requires_grad : bool
+             If true, parameters of the model require gradients. Possibly useful
+             for finetuning the network
+         use_fid_inception : bool
+             If true, uses the pretrained Inception model used in Tensorflow's
+             FID implementation. If false, uses the pretrained Inception model
+             available in torchvision. The FID Inception model has different
+             weights and a slightly different structure from torchvision's
+             Inception model. If you want to compute FID scores, you are
+             strongly advised to set this parameter to true to get comparable
+             results.
+         """
+         super(InceptionV3, self).__init__()
+
+         self.resize_input = resize_input
+         self.normalize_input = normalize_input
+         self.output_blocks = sorted(output_blocks)
+         self.last_needed_block = max(output_blocks)
+
+         assert (
+             self.last_needed_block <= 3
+         ), "Last possible output block index is 3"
+
+         self.blocks = nn.ModuleList()
+
+         if use_fid_inception:
+             inception = fid_inception_v3()
+         else:
+             inception = _inception_v3(weights="DEFAULT")
+
+         # Block 0: input to maxpool1
+         block0 = [
+             inception.Conv2d_1a_3x3,
+             inception.Conv2d_2a_3x3,
+             inception.Conv2d_2b_3x3,
+             nn.MaxPool2d(kernel_size=3, stride=2),
+         ]
+         self.blocks.append(nn.Sequential(*block0))
+
+         # Block 1: maxpool1 to maxpool2
+         if self.last_needed_block >= 1:
+             block1 = [
+                 inception.Conv2d_3b_1x1,
+                 inception.Conv2d_4a_3x3,
+                 nn.MaxPool2d(kernel_size=3, stride=2),
+             ]
+             self.blocks.append(nn.Sequential(*block1))
+
+         # Block 2: maxpool2 to aux classifier
+         if self.last_needed_block >= 2:
+             block2 = [
+                 inception.Mixed_5b,
+                 inception.Mixed_5c,
+                 inception.Mixed_5d,
+                 inception.Mixed_6a,
+                 inception.Mixed_6b,
+                 inception.Mixed_6c,
+                 inception.Mixed_6d,
+                 inception.Mixed_6e,
+             ]
+             self.blocks.append(nn.Sequential(*block2))
+
+         # Block 3: aux classifier to final avgpool
+         if self.last_needed_block >= 3:
+             block3 = [
+                 inception.Mixed_7a,
+                 inception.Mixed_7b,
+                 inception.Mixed_7c,
+                 nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+             ]
+             self.blocks.append(nn.Sequential(*block3))
+
+         for param in self.parameters():
+             param.requires_grad = requires_grad
+
+     def forward(self, inp):
+         """Get Inception feature maps
+
+         Parameters
+         ----------
+         inp : torch.autograd.Variable
+             Input tensor of shape Bx3xHxW. Values are expected to be in
+             range (0, 1)
+
+         Returns
+         -------
+         List of torch.autograd.Variable, corresponding to the selected output
+         block, sorted ascending by index
+         """
+         outp = []
+         x = inp
+
+         if self.resize_input:
+             x = F.interpolate(
+                 x, size=(299, 299), mode="bilinear", align_corners=False
+             )
+
+         if self.normalize_input:
+             x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
+
+         for idx, block in enumerate(self.blocks):
+             x = block(x)
+             if idx in self.output_blocks:
+                 outp.append(x)
+
+             if idx == self.last_needed_block:
+                 break
+
+         return outp
+
+
+ def _inception_v3(*args, **kwargs):
+     """Wraps `torchvision.models.inception_v3`"""
+     try:
+         version = tuple(map(int, torchvision.__version__.split(".")[:2]))
+     except ValueError:
+         # Just a caution against weird version strings
+         version = (0,)
+
+     # Skips default weight initialization if supported by torchvision
+     # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
+     if version >= (0, 6):
+         kwargs["init_weights"] = False
+
+     # Backwards compatibility: `weights` argument was handled by `pretrained`
+     # argument prior to version 0.13.
+     if version < (0, 13) and "weights" in kwargs:
+         if kwargs["weights"] == "DEFAULT":
+             kwargs["pretrained"] = True
+         elif kwargs["weights"] is None:
+             kwargs["pretrained"] = False
+         else:
+             raise ValueError(
+                 "weights=={} not supported in torchvision {}".format(
+                     kwargs["weights"], torchvision.__version__
+                 )
+             )
+         del kwargs["weights"]
+
+     return torchvision.models.inception_v3(*args, **kwargs)
+
+
+ def fid_inception_v3():
+     """Build pretrained Inception model for FID computation
+
+     The Inception model for FID computation uses a different set of weights
+     and has a slightly different structure than torchvision's Inception.
+
+     This method first constructs torchvision's Inception and then patches the
+     necessary parts that are different in the FID Inception model.
+     """
+     inception = _inception_v3(num_classes=1008, aux_logits=False, weights=None)
+     inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+     inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+     inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+     inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+     inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+     inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+     inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+     inception.Mixed_7b = FIDInceptionE_1(1280)
+     inception.Mixed_7c = FIDInceptionE_2(2048)
+
+     state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=False)
+     inception.load_state_dict(state_dict)
+     return inception
+
+
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
+     """InceptionA block patched for FID computation"""
+
+     def __init__(self, in_channels, pool_features):
+         super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+     def forward(self, x):
+         branch1x1 = self.branch1x1(x)
+
+         branch5x5 = self.branch5x5_1(x)
+         branch5x5 = self.branch5x5_2(branch5x5)
+
+         branch3x3dbl = self.branch3x3dbl_1(x)
+         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+         branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+         # Patch: Tensorflow's average pool does not use the padded zeros in
+         # its average calculation
+         branch_pool = F.avg_pool2d(
+             x, kernel_size=3, stride=1, padding=1, count_include_pad=False
+         )
+         branch_pool = self.branch_pool(branch_pool)
+
+         outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+         return torch.cat(outputs, 1)
+
+
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
+     """InceptionC block patched for FID computation"""
+
+     def __init__(self, in_channels, channels_7x7):
+         super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+     def forward(self, x):
+         branch1x1 = self.branch1x1(x)
+
+         branch7x7 = self.branch7x7_1(x)
+         branch7x7 = self.branch7x7_2(branch7x7)
+         branch7x7 = self.branch7x7_3(branch7x7)
+
+         branch7x7dbl = self.branch7x7dbl_1(x)
+         branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+         branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+         branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+         branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+         # Patch: Tensorflow's average pool does not use the padded zeros in
+         # its average calculation
+         branch_pool = F.avg_pool2d(
+             x, kernel_size=3, stride=1, padding=1, count_include_pad=False
+         )
+         branch_pool = self.branch_pool(branch_pool)
+
+         outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+         return torch.cat(outputs, 1)
+
+
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
+     """First InceptionE block patched for FID computation"""
+
+     def __init__(self, in_channels):
+         super(FIDInceptionE_1, self).__init__(in_channels)
+
+     def forward(self, x):
+         branch1x1 = self.branch1x1(x)
+
+         branch3x3 = self.branch3x3_1(x)
+         branch3x3 = [
+             self.branch3x3_2a(branch3x3),
+             self.branch3x3_2b(branch3x3),
+         ]
+         branch3x3 = torch.cat(branch3x3, 1)
+
+         branch3x3dbl = self.branch3x3dbl_1(x)
+         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+         branch3x3dbl = [
+             self.branch3x3dbl_3a(branch3x3dbl),
+             self.branch3x3dbl_3b(branch3x3dbl),
+         ]
+         branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+         # Patch: Tensorflow's average pool does not use the padded zeros in
+         # its average calculation
+         branch_pool = F.avg_pool2d(
+             x, kernel_size=3, stride=1, padding=1, count_include_pad=False
+         )
+         branch_pool = self.branch_pool(branch_pool)
+
+         outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+         return torch.cat(outputs, 1)
+
+
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
+     """Second InceptionE block patched for FID computation"""
+
+     def __init__(self, in_channels):
+         super(FIDInceptionE_2, self).__init__(in_channels)
+
+     def forward(self, x):
+         branch1x1 = self.branch1x1(x)
+
+         branch3x3 = self.branch3x3_1(x)
+         branch3x3 = [
+             self.branch3x3_2a(branch3x3),
+             self.branch3x3_2b(branch3x3),
+         ]
+         branch3x3 = torch.cat(branch3x3, 1)
+
+         branch3x3dbl = self.branch3x3dbl_1(x)
+         branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+         branch3x3dbl = [
+             self.branch3x3dbl_3a(branch3x3dbl),
+             self.branch3x3dbl_3b(branch3x3dbl),
+         ]
+         branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+         # Patch: The FID Inception model uses max pooling instead of average
+         # pooling. This is likely an error in this specific Inception
+         # implementation, as other Inception models use average pooling here
+         # (which matches the description in the paper).
+         branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+         branch_pool = self.branch_pool(branch_pool)
+
+         outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+         return torch.cat(outputs, 1)
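
For orientation, a minimal sketch of how `fid.py` consumes this module; the random input is illustrative only (real inputs are `(0, 1)`-ranged image tensors):

```python
import torch
from cache_dit.metrics.inception import InceptionV3

# Select the 2048-d final-average-pool block, as FrechetInceptionDistance does.
# Construction downloads the ported FID Inception weights on first use.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
model = InceptionV3([block_idx]).eval()

x = torch.rand(2, 3, 299, 299)  # Bx3xHxW, values in range (0, 1)
with torch.no_grad():
    feats = model(x)[0]  # forward returns one tensor per requested block
print(feats.shape)  # torch.Size([2, 2048, 1, 1])
```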
cache_dit/metrics/metrics.py ADDED
@@ -0,0 +1,356 @@
+ import os
+ import cv2
+ import pathlib
+ import argparse
+ import numpy as np
+ from functools import partial
+ from skimage.metrics import mean_squared_error
+ from skimage.metrics import peak_signal_noise_ratio
+ from skimage.metrics import structural_similarity
+ from cache_dit.metrics.fid import FrechetInceptionDistance
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ def compute_psnr_file(
+     image_true: np.ndarray | str,
+     image_test: np.ndarray | str,
+ ) -> float:
+     """
+     img_true = cv2.imread(img_true_file)
+     img_test = cv2.imread(img_test_file)
+     PSNR = compute_psnr(img_true, img_test)
+     """
+     if isinstance(image_true, str):
+         image_true = cv2.imread(image_true)
+     if isinstance(image_test, str):
+         image_test = cv2.imread(image_test)
+     return peak_signal_noise_ratio(
+         image_true,
+         image_test,
+     )
+
+
+ def compute_mse_file(
+     image_true: np.ndarray | str,
+     image_test: np.ndarray | str,
+ ) -> float:
+     """
+     img_true = cv2.imread(img_true_file)
+     img_test = cv2.imread(img_test_file)
+     MSE = compute_mse(img_true, img_test)
+     """
+     if isinstance(image_true, str):
+         image_true = cv2.imread(image_true)
+     if isinstance(image_test, str):
+         image_test = cv2.imread(image_test)
+     return mean_squared_error(
+         image_true,
+         image_test,
+     )
+
+
+ def compute_ssim_file(
+     image_true: np.ndarray | str,
+     image_test: np.ndarray | str,
+ ) -> float:
+     """
+     img_true = cv2.imread(img_true_file)
+     img_test = cv2.imread(img_test_file)
+     SSIM = compute_ssim(img_true, img_test)
+     """
+     if isinstance(image_true, str):
+         image_true = cv2.imread(image_true)
+     if isinstance(image_test, str):
+         image_test = cv2.imread(image_test)
+     return structural_similarity(
+         image_true,
+         image_test,
+         multichannel=True,
+         channel_axis=2,
+     )
+
+
+ _IMAGE_EXTENSIONS = {
+     "bmp",
+     "jpg",
+     "jpeg",
+     "pgm",
+     "png",
+     "ppm",
+     "tif",
+     "tiff",
+     "webp",
+ }
+
+
+ def compute_dir_metric(
+     image_true_dir: np.ndarray | str,
+     image_test_dir: np.ndarray | str,
+     compute_file_func: callable = compute_psnr_file,
+ ) -> float:
+     # Image
+     if isinstance(image_true_dir, np.ndarray) or isinstance(
+         image_test_dir, np.ndarray
+     ):
+         return compute_file_func(image_true_dir, image_test_dir), 1
+     # File
+     if not os.path.isdir(image_true_dir) or not os.path.isdir(image_test_dir):
+         return compute_file_func(image_true_dir, image_test_dir), 1
+     # Dir
+     image_true_dir: pathlib.Path = pathlib.Path(image_true_dir)
+     image_true_files = sorted(
+         [
+             file
+             for ext in _IMAGE_EXTENSIONS
+             for file in image_true_dir.rglob("*.{}".format(ext))
+         ]
+     )
+     image_test_dir: pathlib.Path = pathlib.Path(image_test_dir)
+     image_test_files = sorted(
+         [
+             file
+             for ext in _IMAGE_EXTENSIONS
+             for file in image_test_dir.rglob("*.{}".format(ext))
+         ]
+     )
+     image_true_files = [file.as_posix() for file in image_true_files]
+     image_test_files = [file.as_posix() for file in image_test_files]
+     logger.debug(f"image_true_files: {image_true_files}")
+     logger.debug(f"image_test_files: {image_test_files}")
+     assert len(image_true_files) == len(image_test_files)
+     for image_true, image_test in zip(image_true_files, image_test_files):
+         assert os.path.basename(image_true) == os.path.basename(
+             image_test
+         ), f"image_true:{image_true} != image_test: {image_test}"
+
+     total_metric = 0.0
+     valid_files = 0
+     for image_true, image_test in zip(image_true_files, image_test_files):
+         metric = compute_file_func(image_true, image_test)
+         if metric != float("inf"):
+             total_metric += metric
+             valid_files += 1
+
+     if valid_files > 0:
+         average_metric = total_metric / valid_files
+         logger.debug(f"Average: {average_metric:.2f}")
+         return average_metric, valid_files
+     else:
+         logger.debug("No valid files to compare")
+         return None, None
+
+
+ def compute_video_metric(
+     video_true: str,
+     video_test: str,
+     compute_frame_func: callable = compute_psnr_file,
+ ) -> float:
+     """
+     video_true = "video_true.mp4"
+     video_test = "video_test.mp4"
+     PSNR = compute_video_psnr(video_true, video_test)
+     """
+     cap1 = cv2.VideoCapture(video_true)
+     cap2 = cv2.VideoCapture(video_test)
+
+     if not cap1.isOpened() or not cap2.isOpened():
+         logger.error("Could not open video files")
+         return None
+
+     frame_count = min(
+         int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
+         int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)),
+     )
+
+     total_metric = 0.0
+     valid_frames = 0
+
+     logger.debug(f"Total frames: {frame_count}")
+
+     while True:
+         ret1, frame1 = cap1.read()
+         ret2, frame2 = cap2.read()
+
+         if not ret1 or not ret2:
+             break
+
+         metric = compute_frame_func(frame1, frame2)
+
+         if metric != float("inf"):
+             total_metric += metric
+             valid_frames += 1
+
+             if valid_frames % 10 == 0:
+                 logger.debug(f"Processed {valid_frames}/{frame_count} frames")
+
+     cap1.release()
+     cap2.release()
+
+     if valid_frames > 0:
+         average_metric = total_metric / valid_frames
+         logger.debug(f"Average: {average_metric:.2f}")
+         return average_metric, valid_frames
+     else:
+         logger.debug("No valid frames to compare")
+         return None, None
+
+
+ compute_psnr = partial(
+     compute_dir_metric,
+     compute_file_func=compute_psnr_file,
+ )
+
+ compute_ssim = partial(
+     compute_dir_metric,
+     compute_file_func=compute_ssim_file,
+ )
+
+ compute_mse = partial(
+     compute_dir_metric,
+     compute_file_func=compute_mse_file,
+ )
+
+ compute_video_psnr = partial(
+     compute_video_metric,
+     compute_frame_func=compute_psnr_file,
+ )
+ compute_video_ssim = partial(
+     compute_video_metric,
+     compute_frame_func=compute_ssim_file,
+ )
+ compute_video_mse = partial(
+     compute_video_metric,
+     compute_frame_func=compute_mse_file,
+ )
+
+
+ # Entrypoints
+ def get_args():
+     parser = argparse.ArgumentParser(
+         description="CacheDiT's Metrics CLI",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+     METRICS_CHOICES = [
+         "psnr",
+         "ssim",
+         "mse",
+         "fid",
+         "all",
+     ]
+     parser.add_argument(
+         "metric",
+         type=str,
+         default="psnr",
+         choices=METRICS_CHOICES,
+         help=f"Metric choices: {METRICS_CHOICES}",
+     )
+     parser.add_argument(
+         "--img-true",
+         "-i1",
+         type=str,
+         default=None,
+         help="Path to ground truth image or Dir to ground truth images",
+     )
+     parser.add_argument(
+         "--img-test",
+         "-i2",
+         type=str,
+         default=None,
+         help="Path to predicted image or Dir to predicted images",
+     )
+     parser.add_argument(
+         "--video-true",
+         "-v1",
+         type=str,
+         default=None,
+         help="Path to ground truth video",
+     )
+     parser.add_argument(
+         "--video-test",
+         "-v2",
+         type=str,
+         default=None,
+         help="Path to predicted video",
+     )
+     return parser.parse_args()
+
+
+ def entrypoint():
+     args = get_args()
+     logger.debug(args)
+
+     if args.img_true is not None and args.img_test is not None:
+         if any(
+             (
+                 not os.path.exists(args.img_true),
+                 not os.path.exists(args.img_test),
+             )
+         ):
+             return
+         # img_true and img_test can be files or dirs
+         if args.metric == "psnr" or args.metric == "all":
+             img_psnr, n = compute_psnr(args.img_true, args.img_test)
+             logger.info(
+                 f"{args.img_true} vs {args.img_test}, Num: {n}, PSNR: {img_psnr}"
+             )
+         if args.metric == "ssim" or args.metric == "all":
+             img_ssim, n = compute_ssim(args.img_true, args.img_test)
+             logger.info(
+                 f"{args.img_true} vs {args.img_test}, Num: {n}, SSIM: {img_ssim}"
+             )
+         if args.metric == "mse" or args.metric == "all":
+             img_mse, n = compute_mse(args.img_true, args.img_test)
+             logger.info(
+                 f"{args.img_true} vs {args.img_test}, Num: {n}, MSE: {img_mse}"
+             )
+         if args.metric == "fid" or args.metric == "all":
+             FID = FrechetInceptionDistance()
+             img_fid, n = FID.compute_fid(args.img_true, args.img_test)
+             logger.info(
+                 f"{args.img_true} vs {args.img_test}, Num: {n}, FID: {img_fid}"
+             )
+     if args.video_true is not None and args.video_test is not None:
+         if any(
+             (
+                 not os.path.exists(args.video_true),
+                 not os.path.exists(args.video_test),
+             )
+         ):
+             return
+         if args.metric == "psnr" or args.metric == "all":
+             assert not os.path.isdir(args.video_true)
+             assert not os.path.isdir(args.video_test)
+             video_psnr, n = compute_video_psnr(args.video_true, args.video_test)
+             logger.info(
+                 f"{args.video_true} vs {args.video_test}, Num: {n}, PSNR: {video_psnr}"
+             )
+         if args.metric == "ssim" or args.metric == "all":
+             assert not os.path.isdir(args.video_true)
+             assert not os.path.isdir(args.video_test)
+             video_ssim, n = compute_video_ssim(args.video_true, args.video_test)
+             logger.info(
+                 f"{args.video_true} vs {args.video_test}, Num: {n}, SSIM: {video_ssim}"
+             )
+         if args.metric == "mse" or args.metric == "all":
+             assert not os.path.isdir(args.video_true)
+             assert not os.path.isdir(args.video_test)
+             video_mse, n = compute_video_mse(args.video_true, args.video_test)
+             logger.info(
+                 f"{args.video_true} vs {args.video_test}, Num: {n}, MSE: {video_mse}"
+             )
+         if args.metric == "fid" or args.metric == "all":
+             assert not os.path.isdir(args.video_true)
+             assert not os.path.isdir(args.video_test)
+             FID = FrechetInceptionDistance()
+             video_fid, n = FID.compute_video_fid(
+                 args.video_true, args.video_test
+             )
+             logger.info(
+                 f"{args.video_true} vs {args.video_test}, Num: {n}, FID: {video_fid}"
+             )
+
+
+ if __name__ == "__main__":
+     entrypoint()
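
Worth noting: `compute_psnr`, `compute_ssim`, and `compute_mse` above are `functools.partial` bindings over the single `compute_dir_metric` dispatcher, so one call site handles an image file, a directory of images (paired by basename), or a raw OpenCV array. A minimal sketch with placeholder paths:

```python
import cv2
from cache_dit.metrics.metrics import compute_psnr, compute_video_psnr

psnr, n = compute_psnr("true.png", "test.png")      # file vs file, n == 1
dir_psnr, n = compute_psnr("true_dir", "test_dir")  # dir vs dir, n == pairs

# Raw OpenCV (BGR) arrays take the same code path.
arr_psnr, _ = compute_psnr(cv2.imread("true.png"), cv2.imread("test.png"))

# Videos are compared frame by frame until the shorter stream ends.
video_psnr, n_frames = compute_video_psnr("true.mp4", "test.mp4")
```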
cache_dit-0.2.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: cache_dit
- Version: 0.2.5
+ Version: 0.2.6
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
  Author: DefTruth, vipshop.com, etc.
  Maintainer: DefTruth, vipshop.com, etc
@@ -13,6 +13,8 @@ Requires-Dist: packaging
  Requires-Dist: torch>=2.5.1
  Requires-Dist: transformers>=4.51.3
  Requires-Dist: diffusers>=0.33.1
+ Requires-Dist: scikit-image
+ Requires-Dist: scipy
  Provides-Extra: all
  Provides-Extra: dev
  Requires-Dist: pre-commit; extra == "dev"
@@ -28,6 +30,8 @@ Requires-Dist: protobuf; extra == "dev"
  Requires-Dist: sentencepiece; extra == "dev"
  Requires-Dist: opencv-python-headless; extra == "dev"
  Requires-Dist: ftfy; extra == "dev"
+ Requires-Dist: scikit-image; extra == "dev"
+ Requires-Dist: pytorch-fid; extra == "dev"
  Dynamic: license-file
  Dynamic: provides-extra
  Dynamic: requires-dist
@@ -159,6 +163,7 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
  - [⚡️Dynamic Block Prune](#dbprune)
  - [🎉Context Parallelism](#context-parallelism)
  - [🔥Torch Compile](#compile)
+ - [⚙️Metrics CLI](#metrics)
  - [👋Contribute](#contribute)
  - [©️License](#license)

@@ -500,6 +505,39 @@ torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256

  Please check [bench.py](./bench/bench.py) for more details.

+
+ ## ⚙️Metrics CLI
+
+ <div id="metrics"></div>
+
+ You can use the APIs provided by CacheDiT to quickly evaluate the accuracy loss introduced by different cache configurations. For example:
+
+ ```python
+ from cache_dit.metrics import compute_psnr
+ from cache_dit.metrics import compute_video_psnr
+ from cache_dit.metrics import FrechetInceptionDistance  # FID
+
+ FID = FrechetInceptionDistance()
+ image_psnr, n = compute_psnr("true.png", "test.png")  # Num: n
+ image_fid, n = FID.compute_fid("true.png", "test.png")
+ video_psnr, n = compute_video_psnr("true.mp4", "test.mp4")
+ ```
+
+ Please check [test_metrics.py](./tests/test_metrics.py) for more details. Or you can use the `cache-dit-metrics-cli` tool. For example:
+
+ ```bash
+ cache-dit-metrics-cli -h  # show usage
+ cache-dit-metrics-cli all -v1 true.mp4 -v2 test.mp4  # compare videos
+ cache-dit-metrics-cli all -i1 true.png -i2 test.png  # compare images
+ cache-dit-metrics-cli all -i1 true_dir -i2 test_dir  # compare image dirs
+ cache-dit-metrics-cli all -i1 BASELINE -i2 OPTIMIZED  # compare image dirs
+
+ INFO 07-09 20:59:40 [metrics.py:295] BASELINE vs OPTIMIZED, Num: 1000, PSNR: 38.742413478199005
+ INFO 07-09 21:00:32 [metrics.py:300] BASELINE vs OPTIMIZED, Num: 1000, SSIM: 0.9863484896791567
+ INFO 07-09 21:00:45 [metrics.py:305] BASELINE vs OPTIMIZED, Num: 1000, MSE: 12.287594770695606
+ INFO 07-09 21:01:04 [metrics.py:311] BASELINE vs OPTIMIZED, Num: 1000, FID: 5.983550108647762
+ ```
+
  ## 👋Contribute
  <div id="contribute"></div>

cache_dit-0.2.6.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit/_version.py,sha256=N3oBwJUFmS-AwCjqOcSlRW4GvSq-uJJMaBvoGfv1-hM,511
+ cache_dit/_version.py,sha256=nObnONsicQ3YX6SG5MVBxmIp5dmRacXDauSqZijWQbY,511
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
  cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
@@ -33,8 +33,13 @@ cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k
  cache_dit/compile/utils.py,sha256=KU60xc474Anbj7Y_FLRFmNxEjVYLLXkhbtCLXO7o_Tc,3699
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cache_dit-0.2.5.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
- cache_dit-0.2.5.dist-info/METADATA,sha256=J37Waq-cMbuFfTrngXuxqouXpjHK9qhR_MZHlE2odmY,26249
- cache_dit-0.2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- cache_dit-0.2.5.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
- cache_dit-0.2.5.dist-info/RECORD,,
+ cache_dit/metrics/__init__.py,sha256=5yavk_6b50EhaHij0C-7nQegOA0uD79-qX96dI1i_8s,461
+ cache_dit/metrics/fid.py,sha256=RIC-wW0RFv6W9IW6nc6Ih4dAxunUTnzcugMbYZjdqRM,12891
+ cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
+ cache_dit/metrics/metrics.py,sha256=FUrpc58ofg0LKyM3Y_7kfUVbevMfCFFcaeaxeOrj9iY,10498
+ cache_dit-0.2.6.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+ cache_dit-0.2.6.dist-info/METADATA,sha256=aTTyx_2gM7Z-mcIYd3_LlGljiZOwIq9iadDXOFPDzwA,27848
+ cache_dit-0.2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ cache_dit-0.2.6.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+ cache_dit-0.2.6.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+ cache_dit-0.2.6.dist-info/RECORD,,
cache_dit-0.2.6.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ cache-dit-metrics-cli = cache_dit.metrics:main