flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/attention.py +22 -16
  21. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  22. flaxdiff/models/autoencoder/diffusers.py +88 -25
  23. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  24. flaxdiff/models/common.py +8 -18
  25. flaxdiff/models/simple_unet.py +6 -17
  26. flaxdiff/models/simple_vit.py +9 -13
  27. flaxdiff/models/unet_3d.py +446 -0
  28. flaxdiff/models/unet_3d_blocks.py +505 -0
  29. flaxdiff/samplers/common.py +358 -96
  30. flaxdiff/samplers/ddim.py +44 -5
  31. flaxdiff/schedulers/karras.py +20 -12
  32. flaxdiff/trainer/__init__.py +2 -1
  33. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  34. flaxdiff/trainer/diffusion_trainer.py +35 -29
  35. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  36. flaxdiff/trainer/simple_trainer.py +51 -16
  37. flaxdiff/utils.py +128 -57
  38. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  39. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  40. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  41. flaxdiff/data/datasets.py +0 -169
  42. flaxdiff/data/sources/gcs.py +0 -81
  43. flaxdiff/data/sources/tfds.py +0 -79
  44. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  45. flaxdiff-0.1.38.dist-info/RECORD +0 -50
  46. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/data/sources/videos.py
@@ -0,0 +1,250 @@
+ import cv2
+ import jax.numpy as jnp
+ import grain.python as pygrain
+ # The transform below instantiates AutoAudioTokenizer, so that is what must be
+ # imported here; the original AutoTextTokenizer import was unused.
+ from flaxdiff.utils import AutoAudioTokenizer
+ from typing import Dict, Any, Callable, List, Optional, Tuple
+ import random
+ import os
+ import numpy as np
+ from functools import partial
+ from .base import DataSource, DataAugmenter
+ import subprocess
+ import shutil
+ from .av_utils import read_av_random_clip
+
+ # ----------------------------------------------------------------------------------
+ # Video augmentation utilities
+ # ----------------------------------------------------------------------------------
+ def gather_video_paths_iter(input_dir, extensions=['.mp4', '.avi', '.mov', '.webm']):
+     # Ensure extensions have dots at the beginning and are lowercase
+     extensions = {ext.lower() if ext.startswith('.') else f'.{ext}'.lower() for ext in extensions}
+
+     for root, _, files in os.walk(input_dir):
+         for file in sorted(files):
+             _, ext = os.path.splitext(file)
+             if ext.lower() in extensions:
+                 video_input = os.path.join(root, file)
+                 yield video_input
+
+ def gather_video_paths(input_dir, extensions=['.mp4', '.avi', '.mov', '.webm']):
+     """Gather video paths from a directory."""
+     video_paths = []
+     for video_input in gather_video_paths_iter(input_dir, extensions):
+         video_paths.append(video_input)
+
+     # Sort the video paths
+     video_paths.sort()
+     return video_paths
+
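+ # Example (editor's sketch; the directory path is illustrative):
+ #   paths = gather_video_paths("/data/clips", extensions=['.mp4'])
+ #   # -> ['/data/clips/a/0001.mp4', ...], sorted for deterministic ordering
+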
+ # ----------------------------------------------------------------------------------
+ # TFDS Video Source
+ # ----------------------------------------------------------------------------------
+
+ class VideoTFDSSource(DataSource):
+     """Data source for TensorFlow Datasets (TFDS) video datasets."""
+
+     def __init__(self, name: str, use_tf: bool = True, split: str = "train"):
+         """Initialize a TFDS video data source.
+
+         Args:
+             name: Name of the TFDS dataset.
+             use_tf: Whether to use TensorFlow for loading.
+             split: Dataset split to use.
+         """
+         self.name = name
+         self.use_tf = use_tf
+         self.split = split
+
+     def get_source(self, path_override: str) -> Any:
+         """Get the TFDS video data source.
+
+         Args:
+             path_override: Override path for the dataset.
+
+         Returns:
+             A TFDS dataset.
+         """
+         import tensorflow_datasets as tfds
+         if self.use_tf:
+             return tfds.load(self.name, split=self.split, shuffle_files=True)
+         else:
+             return tfds.data_source(self.name, split=self.split, try_gcs=False)
+
+
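+ # Example (editor's sketch; "ucf101" is just an illustrative TFDS dataset name):
+ #   source = VideoTFDSSource(name="ucf101", use_tf=False)
+ #   ds = source.get_source(None)  # random-access data source usable without TF
+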
+ # ----------------------------------------------------------------------------------
+ # Local Video Source
+ # ----------------------------------------------------------------------------------
+
+ class VideoLocalSource(DataSource):
+     """Data source for local video files."""
+
+     def __init__(
+         self,
+         directory: str = "",
+         extensions: List[str] = ['.mp4', '.avi', '.mov', '.webm'],
+         clear_cache: bool = False,
+         cache_dir: Optional[str] = './cache',
+     ):
+         """Initialize a local video data source.
+
+         Args:
+             directory: Directory containing video files.
+             extensions: List of valid video file extensions.
+             clear_cache: Whether to clear the cache on initialization.
+             cache_dir: Directory to cache video paths.
+         """
+         self.extensions = extensions
+         self.cache_dir = cache_dir
+         # Initialize so load_paths can safely compare against the previous value.
+         self.directory = None
+         self.video_paths = []
+         if directory:
+             self.load_paths(directory, clear_cache)
+
+     def load_paths(self, directory: str, clear_cache: bool = False):
+         """Load video paths from a directory."""
+         if self.directory == directory and not clear_cache:
+             # The directory hasn't changed and the cache is not cleared; keep cached paths.
+             return
+         self.directory = directory
+
+         # Gather all video paths and cache them on disk for future runs,
+         # keyed by a hash of the directory.
+         self.directory_hash = hash(directory)
+
+         if self.cache_dir:
+             os.makedirs(self.cache_dir, exist_ok=True)
+             cache_file = os.path.join(self.cache_dir, f"video_paths_{self.directory_hash}.txt")
+             import pickle
+             if os.path.exists(cache_file) and not clear_cache:
+                 # Load cached video paths if available.
+                 with open(cache_file, 'rb') as f:
+                     video_paths = pickle.load(f)
+                 print(f"Loaded cached video paths from {cache_file}")
+             else:
+                 # No usable cache file: gather video paths and save them.
+                 print(f"Cache file not found or clear_cache is True. Gathering video paths from {directory}")
+                 video_paths = gather_video_paths(directory, self.extensions)
+                 with open(cache_file, 'wb') as f:
+                     pickle.dump(video_paths, f)
+                 print(f"Cached video paths to {cache_file}")
+         else:
+             # No cache directory configured: always gather from disk.
+             video_paths = gather_video_paths(directory, self.extensions)
+
+         self.video_paths = video_paths
+
+     def get_source(self, path_override: Optional[str] = None) -> List[Dict[str, Any]]:
+         """Get the local video data source.
+
+         Args:
+             path_override: Override directory path.
+
+         Returns:
+             A list of dictionaries with video paths.
+         """
+         if path_override:
+             self.load_paths(path_override)
+
+         return [{"video_path": video_path} for video_path in self.video_paths]
+
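+ # Example (editor's sketch; the directory path is illustrative):
+ #   source = VideoLocalSource(directory="/data/clips", cache_dir="./cache")
+ #   items = source.get_source()  # [{"video_path": "/data/clips/a/0001.mp4"}, ...]
+ # Discovered paths are pickled under cache_dir keyed by hash(directory), so
+ # repeated runs skip the os.walk scan.
+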
+ # ----------------------------------------------------------------------------------
+ # Video Augmenter
+ # ----------------------------------------------------------------------------------
+
+ class AudioVideoAugmenter(DataAugmenter):
+     """Augmenter for audio-video datasets."""
+
+     def __init__(self,
+                  preprocess_fn: Callable = None):
+         """Initialize an AV augmenter.
+
+         Args:
+             preprocess_fn: Optional function to preprocess video frames.
+         """
+         self.preprocess_fn = preprocess_fn
+
+     def create_transform(
+         self,
+         frame_size: int = 256,
+         sequence_length: int = 16,
+         audio_frame_padding: int = 3,
+         method: Any = cv2.INTER_AREA,
+     ) -> Callable[[], pygrain.MapTransform]:
+         """Create a transform for audio-video datasets.
+
+         Args:
+             frame_size: Size to scale video frames to.
+             sequence_length: Number of frames to sample from each video.
+             audio_frame_padding: Number of extra audio frames to pad around the clip.
+             method: Interpolation method for resizing.
+
+         Returns:
+             A callable that returns a pygrain.MapTransform.
+         """
+         num_frames = sequence_length
+
+         class AudioVideoTransform(pygrain.RandomMapTransform):
+             def __init__(self, *args, **kwargs):
+                 super().__init__(*args, **kwargs)
+                 self.tokenize = AutoAudioTokenizer(tensor_type="np")
+
+             def random_map(self, element, rng: np.random.Generator) -> Dict[str, jnp.array]:
+                 video_path = element["video_path"]
+                 random_seed = rng.integers(0, 2**32 - 1)
+                 # Read a random audio-video clip from the file.
+                 framewise_audio, full_audio, video_frames = read_av_random_clip(
+                     video_path,
+                     num_frames=num_frames,
+                     audio_frame_padding=audio_frame_padding,
+                     random_seed=random_seed,
+                 )
+
+                 # Tokenize the audio track.
+                 results = self.tokenize(full_audio)
+
+                 return {
+                     "video": video_frames,
+                     "audio": {
+                         "input_ids": results['input_ids'][0],
+                         "attention_mask": results['attention_mask'][0],
+                         "full_audio": full_audio,
+                         "framewise_audio": framewise_audio,
+                     }
+                 }
+
+         return AudioVideoTransform
+
+
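+ # Example pipeline wiring (editor's sketch; it assumes grain's standard
+ # DataLoader/IndexSampler/Batch API and an illustrative directory path):
+ #   source = VideoLocalSource(directory="/data/clips")
+ #   items = source.get_source()
+ #   Transform = AudioVideoAugmenter().create_transform(sequence_length=16)
+ #   loader = pygrain.DataLoader(
+ #       data_source=items,
+ #       sampler=pygrain.IndexSampler(len(items), shard_options=pygrain.NoSharding(),
+ #                                    shuffle=True, seed=0),
+ #       operations=[Transform(), pygrain.Batch(batch_size=8)],
+ #   )
+ #   batch = next(iter(loader))  # {"video": ..., "audio": {...}}
+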
+ # ----------------------------------------------------------------------------------
+ # Helper functions for video datasets
+ # ----------------------------------------------------------------------------------
+
+ # def create_video_dataset_from_directory(
+ #     directory: str,
+ #     extensions: List[str] = ['.mp4', '.avi', '.mov', '.webm'],
+ #     frame_size: int = 256,
+ # ) -> Tuple[List[Dict[str, Any]], AudioVideoAugmenter]:
+ #     """Create a video dataset from a directory of video files.
+
+ #     Args:
+ #         directory: Directory containing video files.
+ #         extensions: List of valid video file extensions.
+ #         frame_size: Size to scale video frames to.
+
+ #     Returns:
+ #         Tuple of (dataset, augmenter) for the video dataset.
+ #     """
+ #     source = VideoLocalSource(
+ #         directory=directory,
+ #         extensions=extensions,
+ #     )
+
+ #     augmenter = AudioVideoAugmenter()
+
+ #     dataset = source.get_source()
+ #     return dataset, augmenter
flaxdiff/data/sources/voxceleb2.py
@@ -0,0 +1,412 @@
+ from logging import warning
+ import os
+ import random
+ import einops
+ import numpy as np
+ from os.path import join
+ from PIL import Image
+ import torch
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from torchvision.transforms import functional as F
+ import decord
+ from decord import VideoReader, AudioReader, cpu
+ import traceback
+
+ from d2lv2_lightning.config import DataConfig
+ from d2lv2_lightning.utils import dist_util
+ from .face_mask import FaceMaskGenerator
+ from .prompt_templates import TEMPLATE_MAP
+ from .utils import ImageProcessor
+ from .audio import crop_wav_window, melspectrogram, crop_mel_window, get_segmented_wavs, get_segmented_mels
+
+ class Voxceleb2Decord(Dataset):
+     """
+     A dataset module for video-to-video (audio-guided) diffusion training.
+     This implementation uses decord to load videos and audio on the fly.
+     """
+     default_video_fps = 25
+     default_mel_steps_per_sec = 80.
+
+     def __init__(
+         self,
+         split,
+         data_config: DataConfig,  # expects attributes like: data_root, filelists_path, nframes, syncnet_mel_step_size, image_size, face_hide_percentage, video_fps, etc.
+         tokenizer=None,
+         token_map: dict = None,
+         use_template: str = None,
+         audio_format: str = "mel",
+         h_flip: bool = True,
+         color_jitter: bool = False,
+         blur_amount: int = 70,
+         sample_rate: int = 16000,
+         shared_audio_dict=None,
+         val_ratio: float = 0.001,
+         num_val_ids: int = -1,
+         val_split_seed: int = 787,
+         dataset_name: str = "voxceleb2",
+         face_mask_type: str = "fixed",
+     ):
+         random.seed(dist_util.get_rank() + 1)
+         print(f"Dataset split: {split}, rank: {dist_util.get_rank() + 1}")
+         self.split = split
+         self.data_config = data_config
+         self.tokenizer = tokenizer
+         self.token_map = token_map
+         self.use_template = use_template
+         self.audio_format = audio_format
+         self.h_flip = h_flip
+         self.color_jitter = color_jitter
+         self.blur_amount = blur_amount
+         self.sample_rate = sample_rate
+         self.shared_audio_dict = shared_audio_dict if shared_audio_dict is not None else {}
+         self.val_ratio = val_ratio
+         self.num_val_ids = num_val_ids
+         self.val_split_seed = val_split_seed
+         self.dataset_name = dataset_name
+         self.face_mask_type = face_mask_type
+         # Default worker id; overwritten by worker_init_fn inside a DataLoader,
+         # but needed so read_audio/__getitem__ also work single-process.
+         self.worker_id = 0
+
+         decord.bridge.set_bridge('torch')
+
+         # Video properties (either from args or defaults)
+         self.video_fps = getattr(data_config, "video_fps", self.__class__.default_video_fps)
+         self.mel_steps_per_sec = self.__class__.default_mel_steps_per_sec
+
+         # Set the data root based on the split.
+         if split in ["train", "trainfull"]:
+             self.data_root = os.path.join(data_config.data_root, "train")
+         else:
+             self.data_root = os.path.join(data_config.data_root, "test")
+         # self.data_root = data_config.data_root
+
+         # Determine the file list path.
+         if hasattr(data_config, "filelists_path") and data_config.filelists_path is not None:
+             self.filelists_path = data_config.filelists_path
+         else:
+             self.filelists_path = os.path.join('./data/voxceleb2/', "filelists")
+             # Warn the user that the default filelists path is being used.
+             warning(f"Using default filelists path: {self.filelists_path}. Please set data_config.filelists_path to a custom path if needed.")
+         os.makedirs(self.filelists_path, exist_ok=True)
+
+         filelist_file = join(self.filelists_path, f"{dataset_name}_{split}.txt")
+         if not os.path.exists(filelist_file):
+             warning(f"File list {filelist_file} not found. Creating a new file list. Please make sure the data_root {data_config.data_root} is correct for the split {split}.")
+             self.all_videos = self.create_filelist()
+         else:
+             self.all_videos = self.get_video_list(filelist_file)
+             print(f"Using file list: {filelist_file} with {len(self.all_videos)} videos.")
+
+         # Image transforms (assumes 3-channel images)
+         size = data_config.resolution
+         self.size = size
+         self.image_transforms = ImageProcessor(size)
+         self.mask_transforms = ImageProcessor(size)
+
+         if use_template is not None:
+             assert token_map is not None, "token_map must be provided if using a template."
+             self.templates = TEMPLATE_MAP[use_template]
+
+     def worker_init_fn(self, worker_id):
+         self.worker_id = worker_id
+         if self.face_mask_type != "fixed":
+             # Initialize the dynamic face mask generator.
+             self.mask_generator = FaceMaskGenerator(
+                 video_mode=False,
+                 mask_type=self.face_mask_type,
+             )
+
+
+     def get_video_list(self, filelist_file):
+         videos = []
+         with open(filelist_file, "r") as f:
+             for line in f:
+                 line = line.strip()
+                 if line:
+                     # Each line is relative to data_root.
+                     videos.append(os.path.join(self.data_root, line))
+         return videos
+
+     def create_filelist(self):
+         # Create a filelist by scanning the directory structure.
+         # (This assumes VoxCeleb2 videos are stored under data_root/id/vid/utterance.mp4)
+         all_videos = []
+         print("Creating filelist for dataset", self.dataset_name)
+         if self.dataset_name == 'voxceleb2':
+             for identity in os.listdir(self.data_root):
+                 id_path = os.path.join(self.data_root, identity)
+                 if not os.path.isdir(id_path):
+                     continue
+                 for vid in os.listdir(id_path):
+                     vid_path = os.path.join(id_path, vid)
+                     if not os.path.isdir(vid_path):
+                         continue
+                     for utt in os.listdir(vid_path):
+                         if utt.endswith(".mp4") or utt.endswith(".avi"):
+                             # Save the relative path (so that data_root can be prepended).
+                             all_videos.append(os.path.join(identity, vid, utt))
+         else:
+             raise NotImplementedError("Filelist creation for this dataset is not implemented.")
+         print("Total videos found:", len(all_videos))
+         # Write the filelist to disk.
+         filelist_file = join(self.filelists_path, f"{self.dataset_name}_{self.split}.txt")
+         with open(filelist_file, "w") as f:
+             for v in all_videos:
+                 f.write(v + "\n")
+         # Return full paths.
+         return [os.path.join(self.data_root, v) for v in all_videos]
+
+     def get_masks(self, imgs, pad=0):
+         if hasattr(self, 'mask_generator'):
+             try:
+                 if imgs.shape[-1] == 3:
+                     B, H, W, C = imgs.shape
+                 else:
+                     B, C, H, W = imgs.shape
+                     imgs = einops.rearrange(imgs, "b c h w -> b h w c")
+                 masks = self.mask_generator.generate_mask_video(imgs.numpy(), mask_expansion=10, expansion_factor=1.1)
+                 return torch.from_numpy(np.stack(masks, axis=0, dtype=np.float16).reshape(B, 1, H, W) // 255)
+             except Exception as e:
+                 print(f"Error generating masks with mask_generator: {e}")
+                 # Fall back to simple mask generation if the generator fails.
+                 print("Falling back to simple mask generation.")
+                 return self.get_simple_mask(pad)
+         else:
+             return self.get_simple_mask(pad)
+
+     def get_simple_mask(self, pad=0):
+         # Only the unpadded mask is cached; padded (dilated) masks are rebuilt
+         # on each call so the pad argument is not silently ignored.
+         if pad == 0 and getattr(self, 'mask_cache', None) is not None:
+             return self.mask_cache
+         H = W = self.size
+         # Define a crop region similar to the original crop function.
+         y1, y2 = 0, H - int(H * 2.36 / 8)
+         x1, x2 = int(W * 1.8 / 8), W - int(W * 1.8 / 8)
+         # Apply face_hide_percentage to determine the mask region.
+         y1 = y2 - int(np.ceil(self.data_config.face_hide_percentage * (y2 - y1)))
+         if pad:
+             y1 = max(y1 - pad, 0)
+             y2 = min(y2 + pad, H)
+             x1 = max(x1 - pad, 0)
+             x2 = min(x2 + pad, W)
+         msk = Image.new("L", (W, H), 0)
+         msk_arr = np.array(msk).astype(np.float16)
+         msk_arr[y1:y2, x1:x2] = 255
+
+         msk_arr = msk_arr // 255
+
+         # msk = Image.fromarray(msk_arr)
+         # msk = self.mask_transforms.preprocess_frames(msk) * 0.5 + 0.5  # normalize to [0,1]
+         # Duplicate the mask for each frame.
+         mask = torch.from_numpy(msk_arr).to(torch.float16).unsqueeze(0).repeat(self.data_config.nframes, 1, 1, 1)
+         # Cache the mask, which is identical for every sample.
+         if pad == 0:
+             self.mask_cache = mask
+         return mask
+
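+     # Worked example of the mask geometry above (editor's note), for
+     # size=256, face_hide_percentage=0.5, pad=0:
+     #   y2 = 256 - int(256 * 2.36 / 8) = 181
+     #   x1 = int(256 * 1.8 / 8) = 57,  x2 = 256 - 57 = 199
+     #   y1 = 181 - ceil(0.5 * 181) = 90
+     # so the fixed mask covers rows 90:181 and cols 57:199, roughly the
+     # mouth/jaw region of a face-cropped frame.
+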
+     def read_frames(self, videoreader: VideoReader, start_frame, num_frames):
+         """
+         Read a batch of frames from the video using decord.
+         Returns a tuple (instance frames, wrong/reference frames), each of shape
+         (num_frames, H, W, C), or (None, None) on failure.
+         """
+         try:
+             total_frames = len(videoreader)
+             if total_frames < num_frames:
+                 return None, None
+             # Get the target window of frames.
+             frame_indices = list(range(start_frame, start_frame + num_frames))
+             frames_array = videoreader.get_batch(frame_indices)  # shape: (num_frames, H, W, C)
+
+             # Determine valid start indices for a "wrong" window that does not overlap the instance window.
+             valid_starts = []
+             # Left interval: ensure wrong_start + num_frames - 1 < start_frame.
+             left_max = start_frame - num_frames
+             if left_max >= 0:
+                 valid_starts.extend(range(0, left_max + 1))
+             # Right interval: ensure wrong_start > start_frame + num_frames - 1.
+             right_min = start_frame + num_frames
+             if right_min <= total_frames - num_frames:
+                 valid_starts.extend(range(right_min, total_frames - num_frames + 1))
+
+             if not valid_starts:
+                 # Fallback: if no valid index is available, choose the farthest possible window.
+                 wrong_start = 0 if start_frame > total_frames // 2 else total_frames - num_frames
+             else:
+                 wrong_start = random.choice(valid_starts)
+
+             wrong_indices = list(range(wrong_start, wrong_start + num_frames))
+             wrong_array = videoreader.get_batch(wrong_indices)
+             return frames_array, wrong_array
+         except Exception as e:
+             print(f"Error reading frames from {videoreader}: {e}")
+             return None, None
+
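+     # Worked example of the "wrong" window selection above (editor's note):
+     # with total_frames=100, start_frame=40, num_frames=16, the instance window
+     # is frames 40..55; valid wrong starts are 0..24 (left_max = 40 - 16 = 24)
+     # and 56..84 (right_min = 56, total_frames - num_frames = 84), so the wrong
+     # clip never overlaps the instance clip.
+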
+     def read_audio(self, video_path):
+         try:
+             ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.sample_rate)
+             audio = ar[:].squeeze()  # assume mono
+             del ar
+             return audio
+         except Exception as e:
+             print(f"Error reading audio from {video_path}: {e}")
+             return None
+
+     def compute_mel(self, audio):
+         try:
+             mel = melspectrogram(audio)
+             return mel.T
+         except Exception as e:
+             print("Error computing mel spectrogram:", e)
+             return None
+
+     def get_mel(self, audio, path):
+         # First, try to find the mel in the cache directory.
+         cache_dir = self.data_config.data_cache_path if self.data_config.data_cache_path else os.path.join(self.data_root, "cache")
+         cache_dir = os.path.join(cache_dir, self.split)
+         # Use a .npy suffix so np.save does not append one: otherwise np.save
+         # would write "<name>.mel.npy" while np.load looked for "<name>.mel",
+         # and the cache would never be hit.
+         cache_path = os.path.join(cache_dir, os.path.basename(path) + ".mel.npy")
+         if os.path.exists(cache_path):
+             mel = np.load(cache_path)
+             return mel
+         # If not found, compute the mel and save it to the cache.
+         mel = self.compute_mel(audio)
+         if mel is None:
+             return None
+         os.makedirs(cache_dir, exist_ok=True)
+         np.save(cache_path, mel)
+         return mel
+
+     def __len__(self):
+         return len(self.all_videos)
+
+     def __getitem__(self, index):
+         """
+         Returns a dictionary with:
+           - instance_images: [F, C, H, W]
+           - reference_images: [F, C, H, W]
+           - mask: [F, 1, H, W]
+           - instance_masks: same as mask
+           - (optionally) instance_masks_dilated
+           - instance_masked_images: instance_images * (mask < 0.5)
+           - instance_prompt_ids: tokenized caption
+           - raw_audio / indiv_raw_audios, mels / indiv_mels if audio_format is specified.
+         """
+         example = {}
+         attempt = 0
+         while True:
+             attempt += 1
+             if attempt > 10:
+                 raise RuntimeError("Failed to get a valid sample after multiple attempts.")
+             try:
+                 # Select a random video.
+                 video_idx = random.randint(0, len(self.all_videos) - 1)
+                 video_path = self.all_videos[video_idx]
+                 vr = VideoReader(video_path, ctx=cpu(self.worker_id))
+                 total_frames = len(vr)
+                 if total_frames < 3 * self.data_config.nframes:
+                     continue
+
+                 # Randomly choose a start frame, leaving enough frames around the window.
+                 start_frame = random.randint(self.data_config.nframes // 2, total_frames - self.data_config.nframes - self.data_config.nframes // 2)
+                 inst_frames, ref_frames = self.read_frames(vr, start_frame, self.data_config.nframes)
+                 if inst_frames is None or ref_frames is None:
+                     continue
+
+                 vr.seek(0)  # avoid memory leak
+                 del vr
+
+                 # Generate masks.
+                 masks = self.get_masks(inst_frames)
+                 masks = self.image_transforms.resize(masks)
+
+                 dilated_masks = None
+                 if getattr(self.data_config, "dilate_masked_loss", False):
+                     dilated_masks = self.get_masks(inst_frames, pad=self.data_config.resolution // 10)
+                     dilated_masks = self.image_transforms.resize(dilated_masks)
+
+                 # Preprocess frames.
+                 inst_frames = self.image_transforms.preprocess_frames(inst_frames)
+                 ref_frames = self.image_transforms.preprocess_frames(ref_frames)
+
+                 # Optionally apply horizontal flip.
+                 if self.h_flip and random.random() > 0.5:
+                     inst_frames = F.hflip(inst_frames)
+                     ref_frames = F.hflip(ref_frames)
+                     masks = F.hflip(masks)
+                     if dilated_masks is not None:
+                         dilated_masks = F.hflip(dilated_masks)
+
+                 # Audio processing.
+                 if "wav" in self.audio_format or "mel" in self.audio_format:
+                     audio = self.read_audio(video_path)
+                     if audio is None:
+                         continue
+
+                     audio_chunk = crop_wav_window(
+                         audio,
+                         start_frame=start_frame,
+                         nframes=self.data_config.nframes,
+                         video_fps=self.video_fps,
+                         sample_rate=self.sample_rate,
+                     )
+                     if audio_chunk is None:
+                         continue
+                     example["raw_audio"] = audio_chunk
+                     if getattr(self.data_config, "use_indiv_audio", False):
+                         indiv_audios = get_segmented_wavs(
+                             audio,
+                             start_frame,
+                             self.data_config.nframes,
+                             self.video_fps,
+                             self.sample_rate,
+                             indiv_audio_mode=self.data_config.indiv_audio_mode,
+                         )
+                         example["indiv_raw_audios"] = torch.FloatTensor(indiv_audios)
+                     if "mel" in self.audio_format:
+                         mel = self.get_mel(audio, video_path)
+                         if mel is None:
+                             continue
+                         mel_window = crop_mel_window(
+                             mel,
+                             start_frame,
+                             self.mel_steps_per_sec,
+                             self.data_config.syncnet_mel_step_size,
+                             self.video_fps,
+                         )
+                         if mel_window.shape[0] != self.data_config.syncnet_mel_step_size:
+                             continue
+                         example["mels"] = torch.FloatTensor(mel_window.T).unsqueeze(0)
+                         indiv_mels = get_segmented_mels(
+                             mel,
+                             start_frame,
+                             self.data_config.nframes,
+                             self.mel_steps_per_sec,
+                             self.data_config.syncnet_mel_step_size,
+                             self.video_fps,
+                         )
+                         if indiv_mels is None:
+                             continue
+                         example["indiv_mels"] = torch.FloatTensor(indiv_mels)
+
+                 example["instance_images"] = inst_frames  # [F, C, H, W]
+                 example["reference_images"] = ref_frames  # [F, C, H, W]
+                 example["mask"] = masks  # [F, 1, H, W]
+                 example["instance_masks"] = example["mask"]
+                 if dilated_masks is not None:
+                     example["instance_masks_dilated"] = dilated_masks
+                 example["instance_masked_images"] = example["instance_images"] * (example["mask"] < 0.5)
+
+                 # Process the caption prompt.
+                 if self.use_template and self.tokenizer is not None:
+                     input_tok = list(self.token_map.values())[0]
+                     text = random.choice(self.templates).format(input_tok)
+                     example["instance_prompt_ids"] = self.tokenizer(
+                         text,
+                         padding="do_not_pad",
+                         truncation=True,
+                         max_length=self.tokenizer.model_max_length,
+                     ).input_ids
+                 # else:
+                 #     raise NotImplementedError("Only template-based captions are supported.")
+                 return example
+             except Exception as e:
+                 print("Exception in __getitem__:", e)
+                 traceback.print_exc()
+                 continue
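+
+ # Example usage (editor's sketch; the DataConfig fields follow the comment in
+ # __init__ above, and the template name and token map are illustrative):
+ #   cfg = DataConfig(data_root="/data/voxceleb2", resolution=256, nframes=5,
+ #                    face_hide_percentage=0.6, syncnet_mel_step_size=16, ...)
+ #   ds = Voxceleb2Decord("train", cfg, tokenizer=tokenizer,
+ #                        token_map={"TOK": "sks"}, use_template="object")
+ #   loader = torch.utils.data.DataLoader(
+ #       ds, batch_size=2, num_workers=4,
+ #       worker_init_fn=ds.worker_init_fn,  # sets worker_id (and the dynamic mask generator)
+ #   )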