cache-dit 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.2.6'
-__version_tuple__ = version_tuple = (0, 2, 6)
+__version__ = version = '0.2.8'
+__version_tuple__ = version_tuple = (0, 2, 8)
@@ -988,7 +988,7 @@ def get_Bn_encoder_buffer(prefix: str = "Bn"):
 @torch.compiler.disable
 def apply_hidden_states_residual(
     hidden_states: torch.Tensor,
-    encoder_hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
     prefix: str = "Bn",
     encoder_prefix: str = "Bn_encoder",
 ):
@@ -1006,25 +1006,27 @@ def apply_hidden_states_residual(
         # If cache is not residual, we use the hidden states directly
         hidden_states = hidden_states_prev
 
-    if "Bn" in encoder_prefix:
-        encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
-    else:
-        encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
+    hidden_states = hidden_states.contiguous()
 
-    assert (
-        encoder_hidden_states_prev is not None
-    ), f"{prefix}_encoder_buffer must be set before"
+    if encoder_hidden_states is not None:
+        if "Bn" in encoder_prefix:
+            encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
+        else:
+            encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
 
-    if is_encoder_cache_residual():
-        encoder_hidden_states = (
-            encoder_hidden_states_prev + encoder_hidden_states
-        )
-    else:
-        # If encoder cache is not residual, we use the encoder hidden states directly
-        encoder_hidden_states = encoder_hidden_states_prev
+        assert (
+            encoder_hidden_states_prev is not None
+        ), f"{prefix}_encoder_buffer must be set before"
 
-    hidden_states = hidden_states.contiguous()
-    encoder_hidden_states = encoder_hidden_states.contiguous()
+        if is_encoder_cache_residual():
+            encoder_hidden_states = (
+                encoder_hidden_states_prev + encoder_hidden_states
+            )
+        else:
+            # If encoder cache is not residual, we use the encoder hidden states directly
+            encoder_hidden_states = encoder_hidden_states_prev
+
+        encoder_hidden_states = encoder_hidden_states.contiguous()
 
     return hidden_states, encoder_hidden_states
 
@@ -1160,7 +1162,6 @@ class DBCachedTransformerBlocks(torch.nn.Module):
 
         torch._dynamo.graph_break()
         if can_use_cache:
-            torch._dynamo.graph_break()
             add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = apply_hidden_states_residual(
@@ -1187,7 +1188,6 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 )
             )
         else:
-            torch._dynamo.graph_break()
             set_Fn_buffer(Fn_hidden_states_residual, prefix="Fn_residual")
             if is_l1_diff_enabled():
                 # for hidden states L1 diff
@@ -1795,7 +1795,6 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 f"the number of single transformer blocks {len(self.single_transformer_blocks)}"
             )
 
-            torch._dynamo.graph_break()
             hidden_states = torch.cat(
                 [encoder_hidden_states, hidden_states], dim=1
             )
@@ -1827,13 +1826,11 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 ],
                 dim=1,
             )
-            torch._dynamo.graph_break()
         else:
            assert Bn_compute_blocks() <= len(self.transformer_blocks), (
                f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
                f"the number of transformer blocks {len(self.transformer_blocks)}"
            )
-           torch._dynamo.graph_break()
            if len(Bn_compute_blocks_ids()) > 0:
                for i, block in enumerate(self._Bn_transformer_blocks()):
                    hidden_states, encoder_hidden_states = (
@@ -1862,7 +1859,6 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                        encoder_hidden_states,
                        hidden_states,
                    )
-            torch._dynamo.graph_break()
 
         hidden_states = (
             hidden_states.reshape(-1)
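
Note: the hunks above only drop redundant `torch._dynamo.graph_break()` calls inside `DBCachedTransformerBlocks`; the break placed just before the cached/uncached branch is kept. As background, a graph break splits a `torch.compile` region into separately captured subgraphs, so removing unnecessary breaks reduces fragmentation. A standalone PyTorch sketch (not cache-dit code) of what such a break does:

```python
import torch

def f(x):
    x = x * 2
    # Explicitly split the compiled region here: Dynamo captures two subgraphs.
    torch._dynamo.graph_break()
    return x + 1

compiled_f = torch.compile(f)
print(compiled_f(torch.ones(2)))  # tensor([3., 3.])
```
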
@@ -82,6 +82,6 @@ def apply_db_cache_on_pipe(
     pipe.__class__._is_cached = True
 
     if not shallow_patch:
-        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+        apply_db_cache_on_transformer(pipe.transformer)
 
     return pipe
@@ -93,6 +93,6 @@ def apply_db_cache_on_pipe(
     pipe.__class__._is_cached = True
 
     if not shallow_patch:
-        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+        apply_db_cache_on_transformer(pipe.transformer)
 
     return pipe
@@ -289,6 +289,6 @@ def apply_db_cache_on_pipe(
     pipe.__class__._is_cached = True
 
     if not shallow_patch:
-        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+        apply_db_cache_on_transformer(pipe.transformer)
 
     return pipe
@@ -82,6 +82,6 @@ def apply_db_cache_on_pipe(
     pipe.__class__._is_cached = True
 
     if not shallow_patch:
-        apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+        apply_db_cache_on_transformer(pipe.transformer)
 
     return pipe
@@ -6,6 +6,8 @@ from cache_dit.metrics.metrics import compute_video_ssim
 from cache_dit.metrics.metrics import compute_video_mse
 from cache_dit.metrics.metrics import entrypoint
 from cache_dit.metrics.fid import FrechetInceptionDistance
+from cache_dit.metrics.config import set_metrics_verbose
+from cache_dit.metrics.config import get_metrics_verbose
 
 
 def main():
@@ -0,0 +1,34 @@
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+_metrics_progress_verbose = False
+
+
+def set_metrics_verbose(verbose: bool):
+    global _metrics_progress_verbose
+    _metrics_progress_verbose = verbose
+    logger.debug(f"Metrics verbose: {verbose}")
+
+
+def get_metrics_verbose() -> bool:
+    global _metrics_progress_verbose
+    return _metrics_progress_verbose
+
+
+_IMAGE_EXTENSIONS = [
+    "bmp",
+    "jpg",
+    "jpeg",
+    "pgm",
+    "png",
+    "ppm",
+    "tif",
+    "tiff",
+    "webp",
+]
+
+_VIDEO_EXTENSIONS = [
+    "mp4",
+]
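
The new-file hunk above corresponds to `cache_dit/metrics/config.py` in the updated RECORD: it centralizes the image/video extension lists and adds a process-wide progress-verbosity flag. A minimal usage sketch, assuming cache-dit >= 0.2.8, where the `__init__.py` hunk above re-exports the two helpers from `cache_dit.metrics`:

```python
# Sketch only: toggling the metrics progress flag introduced in 0.2.8.
from cache_dit.metrics import set_metrics_verbose, get_metrics_verbose

set_metrics_verbose(True)     # sets the module-level _metrics_progress_verbose flag
print(get_metrics_verbose())  # True; the CLI entrypoint re-reads this flag after --enable-verbose
```
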
cache_dit/metrics/fid.py CHANGED
@@ -9,11 +9,14 @@ import torch
 import torchvision.transforms as TF
 from torch.nn.functional import adaptive_avg_pool2d
 from cache_dit.metrics.inception import InceptionV3
+from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+from cache_dit.metrics.config import _VIDEO_EXTENSIONS
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
+# Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
 class ImagePathDataset(torch.utils.data.Dataset):
     def __init__(self, files_or_imgs, transforms=None):
         self.files_or_imgs = files_or_imgs
@@ -219,22 +222,7 @@ def calculate_activation_statistics(
     return mu, sigma
 
 
-_IMAGE_EXTENSIONS = {
-    "bmp",
-    "jpg",
-    "jpeg",
-    "pgm",
-    "png",
-    "ppm",
-    "tif",
-    "tiff",
-    "webp",
-}
-
-
 class FrechetInceptionDistance:
-    IMAGE_EXTENSIONS = _IMAGE_EXTENSIONS
-
     def __init__(
         self,
         device="cuda" if torch.cuda.is_available() else "cpu",
@@ -258,7 +246,8 @@ class FrechetInceptionDistance:
         image_true: np.ndarray | str,
         image_test: np.ndarray | str,
     ):
-        """Calculates the FID of two file paths
+        """
+        Calculates the FID of two file paths
         FID = FrechetInceptionDistance()
         img_fid = FID.compute_fid("img_true.png", "img_test.png")
         img_dir_fid = FID.compute_fid("img_true_dir", "img_test_dir")
@@ -267,8 +256,8 @@ class FrechetInceptionDistance:
         if os.path.isfile(image_true) or os.path.isfile(image_test):
             assert os.path.exists(image_true)
             assert os.path.exists(image_test)
-            assert image_true.split(".")[-1] in self.IMAGE_EXTENSIONS
-            assert image_test.split(".")[-1] in self.IMAGE_EXTENSIONS
+            assert image_true.split(".")[-1] in _IMAGE_EXTENSIONS
+            assert image_test.split(".")[-1] in _IMAGE_EXTENSIONS
             image_true_files = [image_true]
             image_test_files = [image_test]
         else:
@@ -279,7 +268,7 @@ class FrechetInceptionDistance:
             image_true_files = sorted(
                 [
                     file
-                    for ext in self.IMAGE_EXTENSIONS
+                    for ext in _IMAGE_EXTENSIONS
                     for file in image_true_dir.rglob("*.{}".format(ext))
                 ]
             )
@@ -287,7 +276,7 @@ class FrechetInceptionDistance:
             image_test_files = sorted(
                 [
                     file
-                    for ext in self.IMAGE_EXTENSIONS
+                    for ext in _IMAGE_EXTENSIONS
                     for file in image_test_dir.rglob("*.{}".format(ext))
                 ]
             )
@@ -297,15 +286,32 @@ class FrechetInceptionDistance:
             image_test_files = [
                 file.as_posix() for file in image_test_files
             ]
+
+            # select valid files
+            image_true_files_selected = []
+            image_test_files_selected = []
+            for i in range(
+                min(len(image_true_files), len(image_test_files))
+            ):
+                selected_image_true = image_true_files[i]
+                selected_image_test = image_test_files[i]
+                # Image pair must have the same basename
+                if os.path.basename(
+                    selected_image_test
+                ) == os.path.basename(selected_image_true):
+                    image_true_files_selected.append(selected_image_true)
+                    image_test_files_selected.append(selected_image_test)
+            image_true_files = image_true_files_selected.copy()
+            image_test_files = image_test_files_selected.copy()
+            if len(image_true_files) == 0:
+                logger.error(
+                    "No valid Image pairs, please note that Image "
+                    "pairs must have the same basename."
+                )
+                return None, None
+
             logger.debug(f"image_true_files: {image_true_files}")
             logger.debug(f"image_test_files: {image_test_files}")
-            assert len(image_true_files) == len(image_test_files)
-            for image_true, image_test in zip(
-                image_true_files, image_test_files
-            ):
-                assert os.path.basename(image_true) == os.path.basename(
-                    image_test
-                ), f"image_true:{image_true} != image_test: {image_test}"
         else:
             image_true_files = [image_true]
             image_test_files = [image_test]
@@ -340,6 +346,115 @@ class FrechetInceptionDistance:
         return fid_value, len(image_true_files)
 
     def compute_video_fid(
+        self,
+        # file or dir
+        video_true: str,
+        video_test: str,
+    ):
+        if os.path.isfile(video_true) and os.path.isfile(video_test):
+            video_true_frames, video_test_frames, valid_frames = (
+                self._fetch_video_frames(
+                    video_true=video_true,
+                    video_test=video_test,
+                )
+            )
+        elif os.path.isdir(video_true) and os.path.isdir(video_test):
+            # Glob videos
+            video_true_dir: pathlib.Path = pathlib.Path(video_true)
+            video_true_files = sorted(
+                [
+                    file
+                    for ext in _VIDEO_EXTENSIONS
+                    for file in video_true_dir.rglob("*.{}".format(ext))
+                ]
+            )
+            video_test_dir: pathlib.Path = pathlib.Path(video_test)
+            video_test_files = sorted(
+                [
+                    file
+                    for ext in _VIDEO_EXTENSIONS
+                    for file in video_test_dir.rglob("*.{}".format(ext))
+                ]
+            )
+            video_true_files = [file.as_posix() for file in video_true_files]
+            video_test_files = [file.as_posix() for file in video_test_files]
+
+            # select valid video files
+            video_true_files_selected = []
+            video_test_files_selected = []
+            for i in range(min(len(video_true_files), len(video_test_files))):
+                selected_video_true = video_true_files[i]
+                selected_video_test = video_test_files[i]
+                # Video pair must have the same basename
+                if os.path.basename(selected_video_test) == os.path.basename(
+                    selected_video_true
+                ):
+                    video_true_files_selected.append(selected_video_true)
+                    video_test_files_selected.append(selected_video_test)
+
+            video_true_files = video_true_files_selected.copy()
+            video_test_files = video_test_files_selected.copy()
+            if len(video_true_files) == 0:
+                logger.error(
+                    "No valid Video pairs, please note that Video "
+                    "pairs must have the same basename."
+                )
+                return None, None
+            logger.debug(f"video_true_files: {video_true_files}")
+            logger.debug(f"video_test_files: {video_test_files}")
+
+            # Fetch all frames
+            video_true_frames = []
+            video_test_frames = []
+            valid_frames = 0
+
+            for video_true_, video_test_ in zip(
+                video_true_files, video_test_files
+            ):
+                video_true_frames_, video_test_frames_, valid_frames_ = (
+                    self._fetch_video_frames(
+                        video_true=video_true_, video_test=video_test_
+                    )
+                )
+                video_true_frames.extend(video_true_frames_)
+                video_test_frames.extend(video_test_frames_)
+                valid_frames += valid_frames_
+        else:
+            raise ValueError("video_true and video_test must be files or dirs.")
+
+        if valid_frames <= 0:
+            logger.debug("No valid frames to compare")
+            return None, None
+
+        batch_size = min(16, self.batch_size)
+        m1, s1 = calculate_activation_statistics(
+            video_true_frames,
+            self.model,
+            batch_size,
+            self.dims,
+            self.device,
+            self.num_workers,
+            self.disable_tqdm,
+        )
+        m2, s2 = calculate_activation_statistics(
+            video_test_frames,
+            self.model,
+            batch_size,
+            self.dims,
+            self.device,
+            self.num_workers,
+            self.disable_tqdm,
+        )
+        fid_value = calculate_frechet_distance(
+            m1,
+            s1,
+            m2,
+            s2,
+        )
+
+        return fid_value, valid_frames
+
+    def _fetch_video_frames(
         self,
         video_true: str,
         video_test: str,
@@ -349,7 +464,7 @@ class FrechetInceptionDistance:
 
         if not cap1.isOpened() or not cap2.isOpened():
             logger.error("Could not open video files")
-            return None, None
+            return [], [], 0
 
         frame_count = min(
             int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
@@ -378,32 +493,6 @@ class FrechetInceptionDistance:
         cap2.release()
 
         if valid_frames <= 0:
-            return None, None
+            return [], [], 0
 
-        batch_size = min(16, self.batch_size)
-        m1, s1 = calculate_activation_statistics(
-            video_true_frames,
-            self.model,
-            batch_size,
-            self.dims,
-            self.device,
-            self.num_workers,
-            self.disable_tqdm,
-        )
-        m2, s2 = calculate_activation_statistics(
-            video_test_frames,
-            self.model,
-            batch_size,
-            self.dims,
-            self.device,
-            self.num_workers,
-            self.disable_tqdm,
-        )
-        fid_value = calculate_frechet_distance(
-            m1,
-            s1,
-            m2,
-            s2,
-        )
-
-        return fid_value, valid_frames
+        return video_true_frames, video_test_frames, valid_frames
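
The fid.py hunks above split frame decoding into `_fetch_video_frames()` and add `compute_video_fid()`, which accepts either two video files or two directories of videos paired by identical basenames (unmatched names are skipped, and `(None, None)` is returned when no pairs survive). A minimal sketch of the directory mode, assuming hypothetical local folders `true_videos/` and `test_videos/` holding basename-matched `.mp4` files:

```python
# Sketch: directory-to-directory video FID with the API added in 0.2.8.
from cache_dit.metrics import FrechetInceptionDistance

FID = FrechetInceptionDistance(disable_tqdm=True)  # disable_tqdm silences per-batch progress bars
fid_value, valid_frames = FID.compute_video_fid("true_videos", "test_videos")
if fid_value is None:
    print("No valid video pairs or decodable frames found")
else:
    print(f"FID over {valid_frames} frames: {fid_value:.4f}")
```
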
@@ -3,16 +3,24 @@ import cv2
 import pathlib
 import argparse
 import numpy as np
+from tqdm import tqdm
 from functools import partial
 from skimage.metrics import mean_squared_error
 from skimage.metrics import peak_signal_noise_ratio
 from skimage.metrics import structural_similarity
 from cache_dit.metrics.fid import FrechetInceptionDistance
+from cache_dit.metrics.config import set_metrics_verbose
+from cache_dit.metrics.config import get_metrics_verbose
+from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+from cache_dit.metrics.config import _VIDEO_EXTENSIONS
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
+DISABLE_VERBOSE = not get_metrics_verbose()
+
+
 def compute_psnr_file(
     image_true: np.ndarray | str,
     image_test: np.ndarray | str,
@@ -72,19 +80,6 @@ def compute_ssim_file(
     )
 
 
-_IMAGE_EXTENSIONS = {
-    "bmp",
-    "jpg",
-    "jpeg",
-    "pgm",
-    "png",
-    "ppm",
-    "tif",
-    "tiff",
-    "webp",
-}
-
-
 def compute_dir_metric(
     image_true_dir: np.ndarray | str,
     image_test_dir: np.ndarray | str,
@@ -117,17 +112,38 @@ def compute_dir_metric(
     )
     image_true_files = [file.as_posix() for file in image_true_files]
     image_test_files = [file.as_posix() for file in image_test_files]
+
+    # select valid files
+    image_true_files_selected = []
+    image_test_files_selected = []
+    for i in range(min(len(image_true_files), len(image_test_files))):
+        selected_image_true = image_true_files[i]
+        selected_image_test = image_test_files[i]
+        # Image pair must have the same basename
+        if os.path.basename(selected_image_test) == os.path.basename(
+            selected_image_true
+        ):
+            image_true_files_selected.append(selected_image_true)
+            image_test_files_selected.append(selected_image_test)
+    image_true_files = image_true_files_selected.copy()
+    image_test_files = image_test_files_selected.copy()
+    if len(image_true_files) == 0:
+        logger.error(
+            "No valid Image pairs, please note that Image "
+            "pairs must have the same basename."
+        )
+        return None, None
+
     logger.debug(f"image_true_files: {image_true_files}")
     logger.debug(f"image_test_files: {image_test_files}")
-    assert len(image_true_files) == len(image_test_files)
-    for image_true, image_test in zip(image_true_files, image_test_files):
-        assert os.path.basename(image_true) == os.path.basename(
-            image_test
-        ), f"image_true:{image_true} != image_test: {image_test}"
 
     total_metric = 0.0
     valid_files = 0
-    for image_true, image_test in zip(image_true_files, image_test_files):
+    for image_true, image_test in tqdm(
+        zip(image_true_files, image_test_files),
+        total=len(image_true_files),
+        disable=DISABLE_VERBOSE,
+    ):
         metric = compute_file_func(image_true, image_test)
         if metric != float("inf"):
             total_metric += metric
@@ -142,30 +158,25 @@ def compute_dir_metric(
     return None, None
 
 
-def compute_video_metric(
+def _fetch_video_frames(
     video_true: str,
     video_test: str,
-    compute_frame_func: callable = compute_psnr_file,
-) -> float:
-    """
-    video_true = "video_true.mp4"
-    video_test = "video_test.mp4"
-    PSNR = compute_video_psnr(video_true, video_test)
-    """
+):
     cap1 = cv2.VideoCapture(video_true)
     cap2 = cv2.VideoCapture(video_test)
 
     if not cap1.isOpened() or not cap2.isOpened():
         logger.error("Could not open video files")
-        return None
+        return [], [], 0
 
     frame_count = min(
         int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
         int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)),
     )
 
-    total_metric = 0.0
     valid_frames = 0
+    video_true_frames = []
+    video_test_frames = []
 
     logger.debug(f"Total frames: {frame_count}")
 
@@ -176,18 +187,115 @@ def compute_video_metric(
         if not ret1 or not ret2:
             break
 
-        metric = compute_frame_func(frame1, frame2)
+        video_true_frames.append(frame1)
+        video_test_frames.append(frame2)
 
-        if metric != float("inf"):
-            total_metric += metric
-            valid_frames += 1
-
-            if valid_frames % 10 == 0:
-                logger.debug(f"Processed {valid_frames}/{frame_count} frames")
+        valid_frames += 1
 
     cap1.release()
     cap2.release()
 
+    if valid_frames <= 0:
+        return [], [], 0
+
+    return video_true_frames, video_test_frames, valid_frames
+
+
+def compute_video_metric(
+    video_true: str,
+    video_test: str,
+    compute_frame_func: callable = compute_psnr_file,
+) -> float:
+    """
+    video_true = "video_true.mp4"
+    video_test = "video_test.mp4"
+    PSNR = compute_video_psnr(video_true, video_test)
+    """
+    if os.path.isfile(video_true) and os.path.isfile(video_test):
+        video_true_frames, video_test_frames, valid_frames = (
+            _fetch_video_frames(
+                video_true=video_true,
+                video_test=video_test,
+            )
+        )
+    elif os.path.isdir(video_true) and os.path.isdir(video_test):
+        # Glob videos
+        video_true_dir: pathlib.Path = pathlib.Path(video_true)
+        video_true_files = sorted(
+            [
+                file
+                for ext in _VIDEO_EXTENSIONS
+                for file in video_true_dir.rglob("*.{}".format(ext))
+            ]
+        )
+        video_test_dir: pathlib.Path = pathlib.Path(video_test)
+        video_test_files = sorted(
+            [
+                file
+                for ext in _VIDEO_EXTENSIONS
+                for file in video_test_dir.rglob("*.{}".format(ext))
+            ]
+        )
+        video_true_files = [file.as_posix() for file in video_true_files]
+        video_test_files = [file.as_posix() for file in video_test_files]
+
+        # select valid video files
+        video_true_files_selected = []
+        video_test_files_selected = []
+        for i in range(min(len(video_true_files), len(video_test_files))):
+            selected_video_true = video_true_files[i]
+            selected_video_test = video_test_files[i]
+            # Video pair must have the same basename
+            if os.path.basename(selected_video_test) == os.path.basename(
+                selected_video_true
+            ):
+                video_true_files_selected.append(selected_video_true)
+                video_test_files_selected.append(selected_video_test)
+
+        video_true_files = video_true_files_selected.copy()
+        video_test_files = video_test_files_selected.copy()
+        if len(video_true_files) == 0:
+            logger.error(
+                "No valid Video pairs, please note that Video "
+                "pairs must have the same basename."
+            )
+            return None, None
+        logger.debug(f"video_true_files: {video_true_files}")
+        logger.debug(f"video_test_files: {video_test_files}")
+
+        # Fetch all frames
+        video_true_frames = []
+        video_test_frames = []
+        valid_frames = 0
+
+        for video_true_, video_test_ in zip(video_true_files, video_test_files):
+            video_true_frames_, video_test_frames_, valid_frames_ = (
+                _fetch_video_frames(
+                    video_true=video_true_, video_test=video_test_
+                )
+            )
+            video_true_frames.extend(video_true_frames_)
+            video_test_frames.extend(video_test_frames_)
+            valid_frames += valid_frames_
+    else:
+        raise ValueError("video_true and video_test must be files or dirs.")
+
+    if valid_frames <= 0:
+        logger.debug("No valid frames to compare")
+        return None, None
+
+    total_metric = 0.0
+    valid_frames = 0  # reset
+    for frame1, frame2 in tqdm(
+        zip(video_true_frames, video_test_frames),
+        total=len(video_true_frames),
+        disable=DISABLE_VERBOSE,
+    ):
+        metric = compute_frame_func(frame1, frame2)
+        if metric != float("inf"):
+            total_metric += metric
+            valid_frames += 1
+
     if valid_frames > 0:
         average_metric = total_metric / valid_frames
         logger.debug(f"Average: {average_metric:.2f}")
@@ -265,14 +373,21 @@ def get_args():
         "-v1",
         type=str,
         default=None,
-        help="Path to ground truth video",
+        help="Path to ground truth video or Dir to ground truth videos",
     )
     parser.add_argument(
         "--video-test",
         "-v2",
         type=str,
         default=None,
-        help="Path to predicted video",
+        help="Path to predicted video or Dir to predicted videos",
+    )
+    parser.add_argument(
+        "--enable-verbose",
+        "-verbose",
+        action="store_true",
+        default=False,
+        help="Show metrics progress verbose",
     )
     return parser.parse_args()
 
@@ -281,6 +396,11 @@ def entrypoint():
     args = get_args()
     logger.debug(args)
 
+    if args.enable_verbose:
+        global DISABLE_VERBOSE
+        set_metrics_verbose(True)
+        DISABLE_VERBOSE = not get_metrics_verbose()
+
     if args.img_true is not None and args.img_test is not None:
         if any(
             (
@@ -306,7 +426,7 @@ def entrypoint():
                 f"{args.img_true} vs {args.img_test}, Num: {n}, MSE: {img_mse}"
             )
         if args.metric == "fid" or args.metric == "all":
-            FID = FrechetInceptionDistance()
+            FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
             img_fid, n = FID.compute_fid(args.img_true, args.img_test)
             logger.info(
                 f"{args.img_true} vs {args.img_test}, Num: {n}, FID: {img_fid}"
@@ -319,36 +439,29 @@ def entrypoint():
             )
         ):
             return
+        # video_true and video_test can be files or dirs
         if args.metric == "psnr" or args.metric == "all":
-            assert not os.path.isdir(args.video_true)
-            assert not os.path.isdir(args.video_test)
             video_psnr, n = compute_video_psnr(args.video_true, args.video_test)
             logger.info(
-                f"{args.video_true} vs {args.video_test}, Num: {n}, PSNR: {video_psnr}"
+                f"{args.video_true} vs {args.video_test}, Frames: {n}, PSNR: {video_psnr}"
             )
         if args.metric == "ssim" or args.metric == "all":
-            assert not os.path.isdir(args.video_true)
-            assert not os.path.isdir(args.video_test)
             video_ssim, n = compute_video_ssim(args.video_true, args.video_test)
             logger.info(
-                f"{args.video_true} vs {args.video_test}, Num: {n}, SSIM: {video_ssim}"
+                f"{args.video_true} vs {args.video_test}, Frames: {n}, SSIM: {video_ssim}"
            )
        if args.metric == "mse" or args.metric == "all":
-            assert not os.path.isdir(args.video_true)
-            assert not os.path.isdir(args.video_test)
            video_mse, n = compute_video_mse(args.video_true, args.video_test)
            logger.info(
-                f"{args.video_true} vs {args.video_test}, Num: {n}, MSE: {video_mse}"
+                f"{args.video_true} vs {args.video_test}, Frames: {n}, MSE: {video_mse}"
            )
        if args.metric == "fid" or args.metric == "all":
-            assert not os.path.isdir(args.video_true)
-            assert not os.path.isdir(args.video_test)
-            FID = FrechetInceptionDistance()
+            FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
            video_fid, n = FID.compute_video_fid(
                args.video_true, args.video_test
            )
            logger.info(
-                f"{args.video_true} vs {args.video_test}, Num: {n}, FID: {video_fid}"
+                f"{args.video_true} vs {args.video_test}, Frames: {n}, FID: {video_fid}"
            )
 
 
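
metrics.py mirrors the fid.py refactor: `_fetch_video_frames()` collects paired frames, `compute_video_metric()` now accepts video files or directories paired by basename, and the comparison loops gain tqdm progress bars gated by the new `--enable-verbose` / `-verbose` CLI flag. A small sketch under the assumption that basename-matched `.mp4` files live in hypothetical local folders `true_videos/` and `test_videos/`:

```python
# Sketch: video PSNR/SSIM over directories with the 0.2.8 metrics module.
# compute_video_psnr / compute_video_ssim are assumed importable from
# cache_dit.metrics.metrics, as used by the CLI entrypoint in the diff above.
from cache_dit.metrics.metrics import compute_video_psnr, compute_video_ssim

psnr, frames = compute_video_psnr("true_videos", "test_videos")
ssim, _ = compute_video_ssim("true_videos", "test_videos")
print(f"Frames: {frames}, PSNR: {psnr}, SSIM: {ssim}")  # None values mean no valid pairs
```
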
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.6
+Version: 0.2.8
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -519,23 +519,21 @@ from cache_dit.metrics import FrechetInceptionDistance # FID
 
 FID = FrechetInceptionDistance()
 image_psnr, n = compute_psnr("true.png", "test.png") # Num: n
-image_fid, n = FID.compute_fid("true.png", "test.png")
-video_psnr, n = compute_video_psnr("true.mp4", "test.mp4")
+image_fid, n = FID.compute_fid("true_dir", "test_dir")
+video_psnr, n = compute_video_psnr("true.mp4", "test.mp4") # Frames: n
 ```
 
 Please check [test_metrics.py](./tests/test_metrics.py) for more details. Or, you can use `cache-dit-metrics-cli` tool. For examples:
 
 ```bash
 cache-dit-metrics-cli -h # show usage
-cache-dit-metrics-cli all -v1 true.mp4 -v2 test.mp4 # compare video
-cache-dit-metrics-cli all -i1 true.png -i2 test.png # compare image
-cache-dit-metrics-cli all -i1 true_dir -i2 test_dir # compare image dir
-cache-dit-metrics-cli all -i1 BASELINE -i2 OPTIMIZED # compare image dir
-
-INFO 07-09 20:59:40 [metrics.py:295] BASELINE vs OPTIMIZED, Num: 1000, PSNR: 38.742413478199005
-INFO 07-09 21:00:32 [metrics.py:300] BASELINE vs OPTIMIZED, Num: 1000, SSIM: 0.9863484896791567
-INFO 07-09 21:00:45 [metrics.py:305] BASELINE vs OPTIMIZED, Num: 1000, MSE: 12.287594770695606
-INFO 07-09 21:01:04 [metrics.py:311] BASELINE vs OPTIMIZED, Num: 1000, FID: 5.983550108647762
+# all: PSNR, FID, SSIM, MSE, ..., etc.
+cache-dit-metrics-cli all -i1 true.png -i2 test.png # image
+cache-dit-metrics-cli all -i1 true_dir -i2 test_dir # image dir
+cache-dit-metrics-cli all -v1 true.mp4 -v2 test.mp4 # video
+cache-dit-metrics-cli all -v1 true_dir -v2 test_dir # video dir
+cache-dit-metrics-cli fid -i1 true_dir -i2 test_dir # FID
+cache-dit-metrics-cli psnr -i1 true_dir -i2 test_dir # PSNR
 ```
 
 ## 👋Contribute
@@ -1,17 +1,17 @@
 cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/_version.py,sha256=nObnONsicQ3YX6SG5MVBxmIp5dmRacXDauSqZijWQbY,511
+cache_dit/_version.py,sha256=zkhRarrvPoGA1yWjS9_zVM80dWqpDesNn9DiHcF4JWM,511
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
 cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
 cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
 cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=wE_xYp7DRbgB-fD8dpr75o4Cvvl2s-jnT2fRyqWm_RM,71286
+cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=itVEb6gT2eZuncAHUmP51ZS0r6v6cGtRvnPjyeXqKH8,71156
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=fibkeU-FHa30BNT-uPV2Eqcd5IRli07EKb25tMDp23c,2270
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=fddSpTHXU24COMGAY-Z21EmHHAEArZBv_-XLRFD6ADU,2625
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=wcZdBhjUB8WSfz40A268BtSe3nr_hRsIi2BNlg1FHRU,9965
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=Cmy0KHRDgwXqtmqfkrr7kw0CP6CmkSnuz29gDHcD6sQ,2262
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=3xUjvDzor9AkBkDUc0N7kZqM86MIdajuigesnicNzXE,2260
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=cIsov6Pf0dRyddqkzTA2CU-jSDotof8LQr-HIoY9T9M,2615
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=SO4q39PQuQ5QVHy5Z-ubiKdstzvQPedONN2J5oiGUh0,9955
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=8W9m-WeEVE2ytYi9udKEA8Wtb0EnvP3eT2A1Tu-d29k,2252
 cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EREHM5E1wxnL-uRXRAEege4HXraRp1oD_r1Zx4CsiKk,2596
 cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=so1wGdb8W0ATwrjv7E5IEZLPcobybaY1HJa6hBYlOOQ,34698
@@ -33,13 +33,14 @@ cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k
 cache_dit/compile/utils.py,sha256=KU60xc474Anbj7Y_FLRFmNxEjVYLLXkhbtCLXO7o_Tc,3699
 cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/metrics/__init__.py,sha256=5yavk_6b50EhaHij0C-7nQegOA0uD79-qX96dI1i_8s,461
-cache_dit/metrics/fid.py,sha256=RIC-wW0RFv6W9IW6nc6Ih4dAxunUTnzcugMbYZjdqRM,12891
+cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
+cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
+cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
-cache_dit/metrics/metrics.py,sha256=FUrpc58ofg0LKyM3Y_7kfUVbevMfCFFcaeaxeOrj9iY,10498
-cache_dit-0.2.6.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
-cache_dit-0.2.6.dist-info/METADATA,sha256=aTTyx_2gM7Z-mcIYd3_LlGljiZOwIq9iadDXOFPDzwA,27848
-cache_dit-0.2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-cache_dit-0.2.6.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
-cache_dit-0.2.6.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
-cache_dit-0.2.6.dist-info/RECORD,,
+cache_dit/metrics/metrics.py,sha256=tzAtG_-fM1xPIBfRVFIBupvOWYzIO3xDq29Vy5rOBWc,14730
+cache_dit-0.2.8.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.8.dist-info/METADATA,sha256=8E51DpSKDGqk3_cG9buahoXN-7fub6M8VCiPb_Idg64,27608
+cache_dit-0.2.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.8.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.8.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.8.dist-info/RECORD,,