cache-dit 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.2.5'
- __version_tuple__ = version_tuple = (0, 2, 5)
+ __version__ = version = '0.2.7'
+ __version_tuple__ = version_tuple = (0, 2, 7)
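
To confirm the upgrade took effect, the generated attributes can be read straight from the module shown above (a minimal check, assuming the 0.2.7 wheel is installed):

    from cache_dit._version import __version__, version_tuple

    print(__version__)    # expected: 0.2.7
    print(version_tuple)  # expected: (0, 2, 7)
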
@@ -988,7 +988,7 @@ def get_Bn_encoder_buffer(prefix: str = "Bn"):
  @torch.compiler.disable
  def apply_hidden_states_residual(
      hidden_states: torch.Tensor,
-     encoder_hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor = None,
      prefix: str = "Bn",
      encoder_prefix: str = "Bn_encoder",
  ):
@@ -1006,25 +1006,27 @@ def apply_hidden_states_residual(
          # If cache is not residual, we use the hidden states directly
          hidden_states = hidden_states_prev

-     if "Bn" in encoder_prefix:
-         encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
-     else:
-         encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
+     hidden_states = hidden_states.contiguous()

-     assert (
-         encoder_hidden_states_prev is not None
-     ), f"{prefix}_encoder_buffer must be set before"
+     if encoder_hidden_states is not None:
+         if "Bn" in encoder_prefix:
+             encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
+         else:
+             encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)

-     if is_encoder_cache_residual():
-         encoder_hidden_states = (
-             encoder_hidden_states_prev + encoder_hidden_states
-         )
-     else:
-         # If encoder cache is not residual, we use the encoder hidden states directly
-         encoder_hidden_states = encoder_hidden_states_prev
+         assert (
+             encoder_hidden_states_prev is not None
+         ), f"{prefix}_encoder_buffer must be set before"

-     hidden_states = hidden_states.contiguous()
-     encoder_hidden_states = encoder_hidden_states.contiguous()
+         if is_encoder_cache_residual():
+             encoder_hidden_states = (
+                 encoder_hidden_states_prev + encoder_hidden_states
+             )
+         else:
+             # If encoder cache is not residual, we use the encoder hidden states directly
+             encoder_hidden_states = encoder_hidden_states_prev
+
+         encoder_hidden_states = encoder_hidden_states.contiguous()

      return hidden_states, encoder_hidden_states
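
Note on the hunk above: encoder_hidden_states is now optional, so transformer blocks without a text/encoder branch can go through the cached-residual path without tripping the encoder-buffer assertion. Below is a minimal, self-contained sketch of the new control flow; the buffer tensors and the _sketch function are illustrative stand-ins, not the package's actual buffer getters or implementation.

    import torch

    # Illustrative stand-ins for the cached residual buffers that the real code
    # fetches via the get_Bn_*/get_Fn_* buffer getters.
    _HIDDEN_RESIDUAL = torch.zeros(1, 4, 8)
    _ENCODER_RESIDUAL = torch.zeros(1, 2, 8)

    def apply_hidden_states_residual_sketch(hidden_states, encoder_hidden_states=None):
        # Hidden states are always fused with their cached residual.
        hidden_states = (_HIDDEN_RESIDUAL + hidden_states).contiguous()
        # New in 0.2.7: the encoder branch is skipped entirely when no tensor is given.
        if encoder_hidden_states is not None:
            encoder_hidden_states = (_ENCODER_RESIDUAL + encoder_hidden_states).contiguous()
        return hidden_states, encoder_hidden_states

    hs, ehs = apply_hidden_states_residual_sketch(torch.randn(1, 4, 8))
    assert ehs is None  # an encoder-free call no longer requires an encoder buffer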
 
@@ -82,6 +82,6 @@ def apply_db_cache_on_pipe(
      pipe.__class__._is_cached = True

      if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+         apply_db_cache_on_transformer(pipe.transformer)

      return pipe

@@ -93,6 +93,6 @@ def apply_db_cache_on_pipe(
      pipe.__class__._is_cached = True

      if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+         apply_db_cache_on_transformer(pipe.transformer)

      return pipe

@@ -289,6 +289,6 @@ def apply_db_cache_on_pipe(
      pipe.__class__._is_cached = True

      if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+         apply_db_cache_on_transformer(pipe.transformer)

      return pipe

@@ -82,6 +82,6 @@ def apply_db_cache_on_pipe(
      pipe.__class__._is_cached = True

      if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+         apply_db_cache_on_transformer(pipe.transformer)

      return pipe
@@ -0,0 +1,14 @@
+ from cache_dit.metrics.metrics import compute_psnr
+ from cache_dit.metrics.metrics import compute_ssim
+ from cache_dit.metrics.metrics import compute_mse
+ from cache_dit.metrics.metrics import compute_video_psnr
+ from cache_dit.metrics.metrics import compute_video_ssim
+ from cache_dit.metrics.metrics import compute_video_mse
+ from cache_dit.metrics.metrics import entrypoint
+ from cache_dit.metrics.fid import FrechetInceptionDistance
+ from cache_dit.metrics.config import set_metrics_verbose
+ from cache_dit.metrics.config import get_metrics_verbose
+
+
+ def main():
+     entrypoint()
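
Assuming this new module is cache_dit/metrics/__init__.py (the re-exports plus the main() -> entrypoint() wrapper suggest it exposes the metrics API and backs a console script), a minimal usage sketch follows; the file paths are placeholders and the argument order (reference first, test second) is an assumption, since the compute_* implementations live in cache_dit.metrics.metrics, which is not shown in this diff:

    from cache_dit.metrics import compute_psnr, compute_video_psnr

    image_psnr = compute_psnr("baseline.png", "cached.png")        # single image pair
    video_psnr = compute_video_psnr("baseline.mp4", "cached.mp4")  # frame-by-frame video pair
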
@@ -0,0 +1,34 @@
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ _metrics_progress_verbose = False
+
+
+ def set_metrics_verbose(verbose: bool):
+     global _metrics_progress_verbose
+     _metrics_progress_verbose = verbose
+     logger.debug(f"Metrics verbose: {verbose}")
+
+
+ def get_metrics_verbose() -> bool:
+     global _metrics_progress_verbose
+     return _metrics_progress_verbose
+
+
+ _IMAGE_EXTENSIONS = [
+     "bmp",
+     "jpg",
+     "jpeg",
+     "pgm",
+     "png",
+     "ppm",
+     "tif",
+     "tiff",
+     "webp",
+ ]
+
+ _VIDEO_EXTENSIONS = [
+     "mp4",
+ ]
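
The verbosity toggle added here is a plain module-level flag; a short usage sketch, grounded in the two functions above and assuming the module lands at cache_dit/metrics/config.py (as the imports elsewhere in this diff suggest):

    from cache_dit.metrics.config import set_metrics_verbose, get_metrics_verbose

    set_metrics_verbose(True)            # e.g. let metric helpers show progress
    assert get_metrics_verbose() is True
    set_metrics_verbose(False)           # back to the quiet default
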
@@ -0,0 +1,498 @@
+ import os
+ import cv2
+ import pathlib
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm
+ from scipy import linalg
+ import torch
+ import torchvision.transforms as TF
+ from torch.nn.functional import adaptive_avg_pool2d
+ from cache_dit.metrics.inception import InceptionV3
+ from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+ from cache_dit.metrics.config import _VIDEO_EXTENSIONS
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ # Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
+ class ImagePathDataset(torch.utils.data.Dataset):
+     def __init__(self, files_or_imgs, transforms=None):
+         self.files_or_imgs = files_or_imgs
+         self.transforms = transforms
+
+     def __len__(self):
+         return len(self.files_or_imgs)
+
+     def __getitem__(self, i):
+         file_or_img = self.files_or_imgs[i]
+         if isinstance(file_or_img, (str, pathlib.Path)):
+             img = Image.open(file_or_img).convert("RGB")
+         elif isinstance(file_or_img, np.ndarray):
+             # Assume the img is a standard OpenCV image.
+             img = cv2.cvtColor(file_or_img, cv2.COLOR_BGR2RGB)
+             img = Image.fromarray(img)
+         else:
+             raise ValueError(
+                 "file_or_img must be a file path or an OpenCV image."
+             )
+         if self.transforms is not None:
+             img = self.transforms(img)
+         return img
+
+
+ def get_activations(
+     files_or_imgs,
+     model,
+     batch_size=50,
+     dims=2048,
+     device="cpu",
+     num_workers=1,
+     disable_tqdm=True,
+ ):
+     """Calculates the activations of the pool_3 layer for all images.
+
+     Params:
+     -- files_or_imgs : List of image files paths or OpenCV image
+     -- model         : Instance of inception model
+     -- batch_size    : Batch size of images for the model to process at once.
+                        Make sure that the number of samples is a multiple of
+                        the batch size, otherwise some samples are ignored. This
+                        behavior is retained to match the original FID score
+                        implementation.
+     -- dims          : Dimensionality of features returned by Inception
+     -- device        : Device to run calculations
+     -- num_workers   : Number of parallel dataloader workers
+
+     Returns:
+     -- A numpy array of dimension (num images, dims) that contains the
+        activations of the given tensor when feeding inception with the
+        query tensor.
+     """
+     model.eval()
+
+     if batch_size > len(files_or_imgs):
+         logger.info(
+             (
+                 "Warning: batch size is bigger than the data size. "
+                 "Setting batch size to data size"
+             )
+         )
+         batch_size = len(files_or_imgs)
+
+     dataset = ImagePathDataset(files_or_imgs, transforms=TF.ToTensor())
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         drop_last=False,
+         num_workers=num_workers,
+     )
+
+     pred_arr = np.empty((len(files_or_imgs), dims))
+
+     start_idx = 0
+
+     for batch in tqdm(dataloader, disable=disable_tqdm):
+         batch = batch.to(device)
+
+         with torch.no_grad():
+             pred = model(batch)[0]
+
+         # If model output is not scalar, apply global spatial average pooling.
+         # This happens if you choose a dimensionality not equal 2048.
+         if pred.size(2) != 1 or pred.size(3) != 1:
+             pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+         pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+         pred_arr[start_idx : start_idx + pred.shape[0]] = pred
+
+         start_idx = start_idx + pred.shape[0]
+
+     return pred_arr
+
+
+ def calculate_frechet_distance(
+     mu1,
+     sigma1,
+     mu2,
+     sigma2,
+     eps=1e-6,
+ ):
+     """Numpy implementation of the Frechet Distance.
+     The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+     and X_2 ~ N(mu_2, C_2) is
+     d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+     Stable version by Dougal J. Sutherland.
+
+     Params:
+     -- mu1   : Numpy array containing the activations of a layer of the
+                inception net (like returned by the function 'get_predictions')
+                for generated samples.
+     -- mu2   : The sample mean over activations, precalculated on an
+                representative data set.
+     -- sigma1: The covariance matrix over activations for generated samples.
+     -- sigma2: The covariance matrix over activations, precalculated on an
+                representative data set.
+
+     Returns:
+     --   : The Frechet Distance.
+     """
+
+     mu1 = np.atleast_1d(mu1)
+     mu2 = np.atleast_1d(mu2)
+
+     sigma1 = np.atleast_2d(sigma1)
+     sigma2 = np.atleast_2d(sigma2)
+
+     assert (
+         mu1.shape == mu2.shape
+     ), "Training and test mean vectors have different lengths"
+     assert (
+         sigma1.shape == sigma2.shape
+     ), "Training and test covariances have different dimensions"
+
+     diff = mu1 - mu2
+
+     # Product might be almost singular
+     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+     if not np.isfinite(covmean).all():
+         msg = (
+             "fid calculation produces singular product; "
+             "adding %s to diagonal of cov estimates"
+         ) % eps
+         print(msg)
+         offset = np.eye(sigma1.shape[0]) * eps
+         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+     # Numerical error might give slight imaginary component
+     if np.iscomplexobj(covmean):
+         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+             m = np.max(np.abs(covmean.imag))
+             raise ValueError("Imaginary component {}".format(m))
+         covmean = covmean.real
+
+     tr_covmean = np.trace(covmean)
+
+     return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+ def calculate_activation_statistics(
+     files_or_imgs,
+     model,
+     batch_size=50,
+     dims=2048,
+     device="cpu",
+     num_workers=1,
+     disable_tqdm=True,
+ ):
+     """Calculation of the statistics used by the FID.
+     Params:
+     -- files_or_imgs : List of image files paths or OpenCV image
+     -- model         : Instance of inception model
+     -- batch_size    : Batch size of images for the model to process at once.
+                        Make sure that the number of samples is a multiple of
+                        the batch size, otherwise some samples are ignored. This
+                        behavior is retained to match the original FID score
+                        implementation.
+     -- dims          : Dimensionality of features returned by Inception
+     -- device        : Device to run calculations
+     -- num_workers   : Number of parallel dataloader workers
+
+     Returns:
+     -- mu    : The mean over samples of the activations of the pool_3 layer of
+                the inception model.
+     -- sigma : The covariance matrix of the activations of the pool_3 layer of
+                the inception model.
+     """
+     act = get_activations(
+         files_or_imgs,
+         model,
+         batch_size,
+         dims,
+         device,
+         num_workers,
+         disable_tqdm,
+     )
+     mu = np.mean(act, axis=0)
+     sigma = np.cov(act, rowvar=False)
+     return mu, sigma
+
+
+ class FrechetInceptionDistance:
+     def __init__(
+         self,
+         device="cuda" if torch.cuda.is_available() else "cpu",
+         dims: int = 2048,
+         num_workers: int = 1,
+         batch_size: int = 1,
+         disable_tqdm: bool = True,
+     ):
+         # https://github.com/mseitzer/pytorch-fid/src/pytorch_fid/fid_score.py
+         self.dims = dims
+         self.device = device
+         self.num_workers = num_workers
+         self.batch_size = batch_size
+         self.disable_tqdm = disable_tqdm
+         self.block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
+         self.model = InceptionV3([self.block_idx]).to(self.device)
+         self.model = self.model.eval()
+
+     def compute_fid(
+         self,
+         image_true: np.ndarray | str,
+         image_test: np.ndarray | str,
+     ):
+         """
+         Calculates the FID of two file paths
+         FID = FrechetInceptionDistance()
+         img_fid = FID.compute_fid("img_true.png", "img_test.png")
+         img_dir_fid = FID.compute_fid("img_true_dir", "img_test_dir")
+         """
+         if isinstance(image_true, str) or isinstance(image_test, str):
+             if os.path.isfile(image_true) or os.path.isfile(image_test):
+                 assert os.path.exists(image_true)
+                 assert os.path.exists(image_test)
+                 assert image_true.split(".")[-1] in _IMAGE_EXTENSIONS
+                 assert image_test.split(".")[-1] in _IMAGE_EXTENSIONS
+                 image_true_files = [image_true]
+                 image_test_files = [image_test]
+             else:
+                 # glob image files from dir
+                 assert os.path.isdir(image_true)
+                 assert os.path.isdir(image_test)
+                 image_true_dir = pathlib.Path(image_true)
+                 image_true_files = sorted(
+                     [
+                         file
+                         for ext in _IMAGE_EXTENSIONS
+                         for file in image_true_dir.rglob("*.{}".format(ext))
+                     ]
+                 )
+                 image_test_dir = pathlib.Path(image_test)
+                 image_test_files = sorted(
+                     [
+                         file
+                         for ext in _IMAGE_EXTENSIONS
+                         for file in image_test_dir.rglob("*.{}".format(ext))
+                     ]
+                 )
+                 image_true_files = [
+                     file.as_posix() for file in image_true_files
+                 ]
+                 image_test_files = [
+                     file.as_posix() for file in image_test_files
+                 ]
+
+                 # select valid files
+                 image_true_files_selected = []
+                 image_test_files_selected = []
+                 for i in range(
+                     min(len(image_true_files), len(image_test_files))
+                 ):
+                     selected_image_true = image_true_files[i]
+                     selected_image_test = image_test_files[i]
+                     # Image pair must have the same basename
+                     if os.path.basename(
+                         selected_image_test
+                     ) == os.path.basename(selected_image_true):
+                         image_true_files_selected.append(selected_image_true)
+                         image_test_files_selected.append(selected_image_test)
+                 image_true_files = image_true_files_selected.copy()
+                 image_test_files = image_test_files_selected.copy()
+                 if len(image_true_files) == 0:
+                     logger.error(
+                         "No valid Image pairs, please note that Image "
+                         "pairs must have the same basename."
+                     )
+                     return None, None
+
+                 logger.debug(f"image_true_files: {image_true_files}")
+                 logger.debug(f"image_test_files: {image_test_files}")
+         else:
+             image_true_files = [image_true]
+             image_test_files = [image_test]
+
+         batch_size = min(16, self.batch_size)
+         batch_size = min(batch_size, len(image_test_files))
+         m1, s1 = calculate_activation_statistics(
+             image_true_files,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         m2, s2 = calculate_activation_statistics(
+             image_test_files,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         fid_value = calculate_frechet_distance(
+             m1,
+             s1,
+             m2,
+             s2,
+         )
+
+         return fid_value, len(image_true_files)
+
+     def compute_video_fid(
+         self,
+         # file or dir
+         video_true: str,
+         video_test: str,
+     ):
+         if os.path.isfile(video_true) and os.path.isfile(video_test):
+             video_true_frames, video_test_frames, valid_frames = (
+                 self._fetch_video_frames(
+                     video_true=video_true,
+                     video_test=video_test,
+                 )
+             )
+         elif os.path.isdir(video_true) and os.path.isdir(video_test):
+             # Glob videos
+             video_true_dir: pathlib.Path = pathlib.Path(video_true)
+             video_true_files = sorted(
+                 [
+                     file
+                     for ext in _VIDEO_EXTENSIONS
+                     for file in video_true_dir.rglob("*.{}".format(ext))
+                 ]
+             )
+             video_test_dir: pathlib.Path = pathlib.Path(video_test)
+             video_test_files = sorted(
+                 [
+                     file
+                     for ext in _VIDEO_EXTENSIONS
+                     for file in video_test_dir.rglob("*.{}".format(ext))
+                 ]
+             )
+             video_true_files = [file.as_posix() for file in video_true_files]
+             video_test_files = [file.as_posix() for file in video_test_files]
+
+             # select valid video files
+             video_true_files_selected = []
+             video_test_files_selected = []
+             for i in range(min(len(video_true_files), len(video_test_files))):
+                 selected_video_true = video_true_files[i]
+                 selected_video_test = video_test_files[i]
+                 # Video pair must have the same basename
+                 if os.path.basename(selected_video_test) == os.path.basename(
+                     selected_video_true
+                 ):
+                     video_true_files_selected.append(selected_video_true)
+                     video_test_files_selected.append(selected_video_test)
+
+             video_true_files = video_true_files_selected.copy()
+             video_test_files = video_test_files_selected.copy()
+             if len(video_true_files) == 0:
+                 logger.error(
+                     "No valid Video pairs, please note that Video "
+                     "pairs must have the same basename."
+                 )
+                 return None, None
+             logger.debug(f"video_true_files: {video_true_files}")
+             logger.debug(f"video_test_files: {video_test_files}")
+
+             # Fetch all frames
+             video_true_frames = []
+             video_test_frames = []
+             valid_frames = 0
+
+             for video_true_, video_test_ in zip(
+                 video_true_files, video_test_files
+             ):
+                 video_true_frames_, video_test_frames_, valid_frames_ = (
+                     self._fetch_video_frames(
+                         video_true=video_true_, video_test=video_test_
+                     )
+                 )
+                 video_true_frames.extend(video_true_frames_)
+                 video_test_frames.extend(video_test_frames_)
+                 valid_frames += valid_frames_
+         else:
+             raise ValueError("video_true and video_test must be files or dirs.")
+
+         if valid_frames <= 0:
+             logger.debug("No valid frames to compare")
+             return None, None
+
+         batch_size = min(16, self.batch_size)
+         m1, s1 = calculate_activation_statistics(
+             video_true_frames,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         m2, s2 = calculate_activation_statistics(
+             video_test_frames,
+             self.model,
+             batch_size,
+             self.dims,
+             self.device,
+             self.num_workers,
+             self.disable_tqdm,
+         )
+         fid_value = calculate_frechet_distance(
+             m1,
+             s1,
+             m2,
+             s2,
+         )
+
+         return fid_value, valid_frames
+
+     def _fetch_video_frames(
+         self,
+         video_true: str,
+         video_test: str,
+     ):
+         cap1 = cv2.VideoCapture(video_true)
+         cap2 = cv2.VideoCapture(video_test)
+
+         if not cap1.isOpened() or not cap2.isOpened():
+             logger.error("Could not open video files")
+             return [], [], 0
+
+         frame_count = min(
+             int(cap1.get(cv2.CAP_PROP_FRAME_COUNT)),
+             int(cap2.get(cv2.CAP_PROP_FRAME_COUNT)),
+         )
+
+         valid_frames = 0
+         video_true_frames = []
+         video_test_frames = []
+
+         logger.debug(f"Total frames: {frame_count}")
+
+         while True:
+             ret1, frame1 = cap1.read()
+             ret2, frame2 = cap2.read()
+
+             if not ret1 or not ret2:
+                 break
+
+             video_true_frames.append(frame1)
+             video_test_frames.append(frame2)
+
+             valid_frames += 1
+
+         cap1.release()
+         cap2.release()
+
+         if valid_frames <= 0:
+             return [], [], 0
+
+         return video_true_frames, video_test_frames, valid_frames
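
A short usage sketch for the new FID helper, following the pattern in the compute_fid docstring above; the paths are placeholders, and per the new __init__ the class is also re-exported as cache_dit.metrics.FrechetInceptionDistance:

    from cache_dit.metrics.fid import FrechetInceptionDistance

    FID = FrechetInceptionDistance()  # InceptionV3 pool_3 features, dims=2048 by default

    # Single image pair or paired directories (pairs are matched by basename).
    img_fid, n_images = FID.compute_fid("img_true.png", "img_test.png")
    dir_fid, n_pairs = FID.compute_fid("img_true_dir", "img_test_dir")

    # Paired videos: frames are read with OpenCV and compared frame by frame.
    video_fid, n_frames = FID.compute_video_fid("video_true.mp4", "video_test.mp4")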