flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/attention.py +22 -16
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/common.py +8 -18
- flaxdiff/models/simple_unet.py +6 -17
- flaxdiff/models/simple_vit.py +9 -13
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +35 -29
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +51 -16
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/data/sources/videos.py (added, +250 lines)
@@ -0,0 +1,250 @@
import cv2
import jax.numpy as jnp
import grain.python as pygrain
from flaxdiff.utils import AutoTextTokenizer, AutoAudioTokenizer
from typing import Dict, Any, Callable, List, Optional, Tuple
import random
import os
import numpy as np
from functools import partial
from .base import DataSource, DataAugmenter
import subprocess
import shutil
from .av_utils import read_av_random_clip

# ----------------------------------------------------------------------------------
# Video augmentation utilities
# ----------------------------------------------------------------------------------
def gather_video_paths_iter(input_dir, extensions=['.mp4', '.avi', '.mov', '.webm']):
    # Ensure extensions have dots at the beginning and are lowercase
    extensions = {ext.lower() if ext.startswith('.') else f'.{ext}'.lower() for ext in extensions}

    for root, _, files in os.walk(input_dir):
        for file in sorted(files):
            _, ext = os.path.splitext(file)
            if ext.lower() in extensions:
                video_input = os.path.join(root, file)
                yield video_input

def gather_video_paths(input_dir, extensions=['.mp4', '.avi', '.mov', '.webm']):
    """Gather video paths from a directory."""
    video_paths = []
    for video_input in gather_video_paths_iter(input_dir, extensions):
        video_paths.append(video_input)

    # Sort the video paths
    video_paths.sort()
    return video_paths
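
The extension normalization above means bare suffixes work as well as dotted ones; a minimal sketch (the directory is made up):

```python
# Hypothetical usage of the path gatherer; '/data/clips' is made up.
from flaxdiff.data.sources.videos import gather_video_paths

paths = gather_video_paths('/data/clips', extensions=['.mp4', 'webm'])
# 'webm' is normalized to '.webm', so both spellings match the same files.
print(f"Found {len(paths)} videos")
```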

# ----------------------------------------------------------------------------------
# TFDS Video Source
# ----------------------------------------------------------------------------------

class VideoTFDSSource(DataSource):
    """Data source for TensorFlow Datasets (TFDS) video datasets."""

    def __init__(self, name: str, use_tf: bool = True, split: str = "train"):
        """Initialize a TFDS video data source.

        Args:
            name: Name of the TFDS dataset.
            use_tf: Whether to use TensorFlow for loading.
            split: Dataset split to use.
        """
        self.name = name
        self.use_tf = use_tf
        self.split = split

    def get_source(self, path_override: str) -> Any:
        """Get the TFDS video data source.

        Args:
            path_override: Override path for the dataset.

        Returns:
            A TFDS dataset.
        """
        import tensorflow_datasets as tfds
        if self.use_tf:
            return tfds.load(self.name, split=self.split, shuffle_files=True)
        else:
            return tfds.data_source(self.name, split=self.split, try_gcs=False)


# ----------------------------------------------------------------------------------
# Local Video Source
# ----------------------------------------------------------------------------------

class VideoLocalSource(DataSource):
    """Data source for local video files."""

    def __init__(
        self,
        directory: str = "",
        extensions: List[str] = ['.mp4', '.avi', '.mov', '.webm'],
        clear_cache: bool = False,
        cache_dir: Optional[str] = './cache',
    ):
        """Initialize a local video data source.

        Args:
            directory: Directory containing video files.
            extensions: List of valid video file extensions.
            clear_cache: Whether to clear the cache on initialization.
            cache_dir: Directory to cache video paths.
        """
        self.extensions = extensions
        self.cache_dir = cache_dir
        self.directory = None  # set by load_paths()
        self.video_paths = []
        if directory:
            self.load_paths(directory, clear_cache)

    def load_paths(self, directory: str, clear_cache: bool = False):
        """Load video paths from a directory, using an on-disk cache when available."""
        if self.directory == directory and not clear_cache:
            # The directory hasn't changed and the cache is not cleared; keep cached paths.
            return
        self.directory = directory

        # Generate a stable hash for the directory to use as a cache key.
        # (The built-in hash() is randomized per process, so it cannot key
        # a cache that outlives the current run.)
        import hashlib
        self.directory_hash = hashlib.md5(directory.encode('utf-8')).hexdigest()

        # Check if the cache directory exists
        if self.cache_dir and os.path.exists(self.cache_dir):
            # Load cached video paths if available
            cache_file = os.path.join(self.cache_dir, f"video_paths_{self.directory_hash}.pkl")
            import pickle
            if os.path.exists(cache_file) and not clear_cache:
                with open(cache_file, 'rb') as f:
                    video_paths = pickle.load(f)
                print(f"Loaded cached video paths from {cache_file}")
            else:
                # If no cache file, gather video paths and save them
                print(f"Cache file not found or clear_cache is True. Gathering video paths from {directory}")
                video_paths = gather_video_paths(directory, self.extensions)
                with open(cache_file, 'wb') as f:
                    pickle.dump(video_paths, f)
                print(f"Cached video paths to {cache_file}")
        else:
            # No cache directory available; gather the paths without caching.
            video_paths = gather_video_paths(directory, self.extensions)

        self.video_paths = video_paths

    def get_source(self, path_override: str = None) -> List[Dict[str, Any]]:
        """Get the local video data source.

        Args:
            path_override: Override directory path.

        Returns:
            A list of dictionaries with video paths.
        """
        if path_override:
            self.load_paths(path_override)

        video_paths = self.video_paths
        dataset = []
        for video_path in video_paths:
            dataset.append({"video_path": video_path})
        return dataset
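
`get_source()` returns plain records rather than a loader, so any random-access pipeline can index it directly; a sketch with a made-up path:

```python
# Hypothetical usage; the directory is made up.
from flaxdiff.data.sources.videos import VideoLocalSource

source = VideoLocalSource(directory='/data/videos/train', cache_dir='./cache')
records = source.get_source()
# records is a list like [{'video_path': '/data/videos/train/clip0.mp4'}, ...]
print(len(records))
```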

# ----------------------------------------------------------------------------------
# Video Augmenter
# ----------------------------------------------------------------------------------

class AudioVideoAugmenter(DataAugmenter):
    """Augmenter for audio-video datasets."""

    def __init__(self,
                 preprocess_fn: Callable = None):
        """Initialize an AV augmenter.

        Args:
            preprocess_fn: Optional function to preprocess video frames.
        """
        self.preprocess_fn = preprocess_fn

    def create_transform(
        self,
        frame_size: int = 256,
        sequence_length: int = 16,
        audio_frame_padding: int = 3,
        method: Any = cv2.INTER_AREA,
    ) -> Callable[[], pygrain.MapTransform]:
        """Create a transform for audio-video datasets.

        Args:
            frame_size: Size to scale video frames to.
            sequence_length: Number of frames to sample from each video.
            audio_frame_padding: Frames of audio padding around the sampled clip.
            method: Interpolation method for resizing.

        Returns:
            A callable that returns a pygrain.MapTransform.
        """
        num_frames = sequence_length

        class AudioVideoTransform(pygrain.RandomMapTransform):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.tokenize = AutoAudioTokenizer(tensor_type="np")

            def random_map(self, element, rng: np.random.Generator) -> Dict[str, jnp.array]:
                video_path = element["video_path"]
                random_seed = rng.integers(0, 2**32 - 1)
                # Read a random audio-video clip from the file.
                framewise_audio, full_audio, video_frames = read_av_random_clip(
                    video_path,
                    num_frames=num_frames,
                    audio_frame_padding=audio_frame_padding,
                    random_seed=random_seed,
                )

                # Tokenize the full audio clip.
                results = self.tokenize(full_audio)

                return {
                    "video": video_frames,
                    "audio": {
                        "input_ids": results['input_ids'][0],
                        "attention_mask": results['attention_mask'][0],
                        "full_audio": full_audio,
                        "framewise_audio": framewise_audio,
                    }
                }

        return AudioVideoTransform


# ----------------------------------------------------------------------------------
# Helper functions for video datasets
# ----------------------------------------------------------------------------------

# def create_video_dataset_from_directory(
#     directory: str,
#     extensions: List[str] = ['.mp4', '.avi', '.mov', '.webm'],
#     frame_size: int = 256,
# ) -> Tuple[List[Dict[str, Any]], AudioVideoAugmenter]:
#     """Create a video dataset from a directory of video files.

#     Args:
#         directory: Directory containing video files.
#         extensions: List of valid video file extensions.
#         frame_size: Size to scale video frames to.
#         num_frames: Number of frames to sample from each video.

#     Returns:
#         Tuple of (dataset, augmenter) for the video dataset.
#     """
#     source = VideoLocalSource(
#         directory=directory,
#         extensions=extensions,
#     )

#     augmenter = AudioVideoAugmenter(
#         num_frames=num_frames
#     )

#     dataset = source.get_source()
#     return dataset, augmenter

flaxdiff/data/sources/voxceleb2.py (added, +412 lines)
@@ -0,0 +1,412 @@
import os
import random
from logging import warning
import einops
import numpy as np
from os.path import join
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional as F
import decord
from decord import VideoReader, AudioReader, cpu
import traceback

from d2lv2_lightning.config import DataConfig
from d2lv2_lightning.utils import dist_util
from .face_mask import FaceMaskGenerator
from .prompt_templates import TEMPLATE_MAP
from .utils import ImageProcessor
from .audio import crop_wav_window, melspectrogram, crop_mel_window, get_segmented_wavs, get_segmented_mels

class Voxceleb2Decord(Dataset):
    """
    A dataset module for video-to-video (audio-guided) diffusion training.
    This implementation uses decord to load videos and audio on the fly.
    """
    default_video_fps = 25
    default_mel_steps_per_sec = 80.

    def __init__(
        self,
        split,
        data_config: DataConfig,  # expects attributes like: data_root, filelists_path, nframes, syncnet_mel_step_size, image_size, face_hide_percentage, video_fps, etc.
        tokenizer=None,
        token_map: dict = None,
        use_template: str = None,
        audio_format: str = "mel",
        h_flip: bool = True,
        color_jitter: bool = False,
        blur_amount: int = 70,
        sample_rate: int = 16000,
        shared_audio_dict=None,
        val_ratio: float = 0.001,
        num_val_ids: int = -1,
        val_split_seed: int = 787,
        dataset_name: str = "voxceleb2",
        face_mask_type: str = "fixed",
    ):
        random.seed(dist_util.get_rank() + 1)
        print(f"Dataset split: {split}, rank: {dist_util.get_rank() + 1}")
        self.split = split
        self.data_config = data_config
        self.tokenizer = tokenizer
        self.token_map = token_map
        self.use_template = use_template
        self.audio_format = audio_format
        self.h_flip = h_flip
        self.color_jitter = color_jitter
        self.blur_amount = blur_amount
        self.sample_rate = sample_rate
        self.shared_audio_dict = shared_audio_dict if shared_audio_dict is not None else {}
        self.val_ratio = val_ratio
        self.num_val_ids = num_val_ids
        self.val_split_seed = val_split_seed
        self.dataset_name = dataset_name
        self.face_mask_type = face_mask_type
        self.worker_id = 0  # overwritten by worker_init_fn in multi-worker loading

        decord.bridge.set_bridge('torch')

        # Video properties (either from args or defaults)
        self.video_fps = getattr(data_config, "video_fps", self.__class__.default_video_fps)
        self.mel_steps_per_sec = self.__class__.default_mel_steps_per_sec

        # Set the data root based on the split.
        if split in ["train", "trainfull"]:
            self.data_root = os.path.join(data_config.data_root, "train")
        else:
            self.data_root = os.path.join(data_config.data_root, "test")
            # self.data_root = data_config.data_root

        # Determine the file list path.
        if hasattr(data_config, "filelists_path") and data_config.filelists_path is not None:
            self.filelists_path = data_config.filelists_path
        else:
            self.filelists_path = os.path.join('./data/voxceleb2/', "filelists")
            # Warn the user that the default filelists path is being used.
            warning(f"Using default filelists path: {self.filelists_path}. Please set data_config.filelists_path to a custom path if needed.")
        os.makedirs(self.filelists_path, exist_ok=True)

        filelist_file = join(self.filelists_path, f"{dataset_name}_{split}.txt")
        if not os.path.exists(filelist_file):
            warning(f"File list {filelist_file} not found. Creating a new file list. Please make sure the data_root: {data_config.data_root} is correct for the split {split}.")
            self.all_videos = self.create_filelist()
        else:
            self.all_videos = self.get_video_list(filelist_file)
            print(f"Using file list: {filelist_file} with {len(self.all_videos)} videos.")

        # Image transforms (assumes 3-channel images)
        size = data_config.resolution
        self.size = size
        self.image_transforms = ImageProcessor(size)
        self.mask_transforms = ImageProcessor(size)

        if use_template is not None:
            assert token_map is not None, "token_map must be provided if using a template."
            self.templates = TEMPLATE_MAP[use_template]

    def worker_init_fn(self, worker_id):
        self.worker_id = worker_id
        if self.face_mask_type != "fixed":
            # Initialize the dynamic face mask generator.
            self.mask_generator = FaceMaskGenerator(
                video_mode=False,
                mask_type=self.face_mask_type,
            )

    def get_video_list(self, filelist_file):
        videos = []
        with open(filelist_file, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    # Each line is relative to data_root.
                    videos.append(os.path.join(self.data_root, line))
        return videos

    def create_filelist(self):
        # Create a filelist by scanning the directory structure.
        # (This assumes VoxCeleb2 videos are stored under data_root/id/vid/utterance.mp4)
        all_videos = []
        print("Creating filelist for dataset", self.dataset_name)
        if self.dataset_name == 'voxceleb2':
            for identity in os.listdir(self.data_root):
                id_path = os.path.join(self.data_root, identity)
                if not os.path.isdir(id_path):
                    continue
                for vid in os.listdir(id_path):
                    vid_path = os.path.join(id_path, vid)
                    if not os.path.isdir(vid_path):
                        continue
                    for utt in os.listdir(vid_path):
                        if utt.endswith(".mp4") or utt.endswith(".avi"):
                            # Save the relative path (so that data_root can be prepended).
                            all_videos.append(os.path.join(identity, vid, utt))
        else:
            raise NotImplementedError("Filelist creation for this dataset is not implemented.")
        print("Total videos found:", len(all_videos))
        # Write the filelist to disk.
        filelist_file = join(self.filelists_path, f"{self.dataset_name}_{self.split}.txt")
        with open(filelist_file, "w") as f:
            for v in all_videos:
                f.write(v + "\n")
        # Return full paths.
        return [os.path.join(self.data_root, v) for v in all_videos]

    def get_masks(self, imgs, pad=0):
        if hasattr(self, 'mask_generator'):
            try:
                if imgs.shape[-1] == 3:
                    B, H, W, C = imgs.shape
                else:
                    B, C, H, W = imgs.shape
                    imgs = einops.rearrange(imgs, "b c h w -> b h w c")
                masks = self.mask_generator.generate_mask_video(imgs.numpy(), mask_expansion=10, expansion_factor=1.1)
                return torch.from_numpy(np.stack(masks, axis=0).astype(np.float16).reshape(B, 1, H, W) // 255)
            except Exception as e:
                print(f"Error generating masks with mask_generator: {e}")
                # Fall back to simple mask generation if the generator fails.
                print("Falling back to simple mask generation.")
                return self.get_simple_mask(pad)
        else:
            return self.get_simple_mask(pad)

    def get_simple_mask(self, pad=0):
        # Cache masks per pad value so dilated and non-dilated requests
        # don't return each other's cached result.
        if getattr(self, 'mask_cache', None) is None:
            self.mask_cache = {}
        if pad in self.mask_cache:
            return self.mask_cache[pad]
        H = W = self.size
        # Define a crop region similar to the original crop function.
        y1, y2 = 0, H - int(H * 2.36 / 8)
        x1, x2 = int(W * 1.8 / 8), W - int(W * 1.8 / 8)
        # Apply face_hide_percentage to determine the mask region.
        y1 = y2 - int(np.ceil(self.data_config.face_hide_percentage * (y2 - y1)))
        if pad:
            y1 = max(y1 - pad, 0)
            y2 = min(y2 + pad, H)
            x1 = max(x1 - pad, 0)
            x2 = min(x2 + pad, W)
        msk = Image.new("L", (W, H), 0)
        msk_arr = np.array(msk).astype(np.float16)
        msk_arr[y1:y2, x1:x2] = 255

        msk_arr = msk_arr // 255

        # msk = Image.fromarray(msk_arr)
        # msk = self.mask_transforms.preprocess_frames(msk) * 0.5 + 0.5  # normalize to [0,1]
        # Duplicate the mask for each frame.
        mask = torch.from_numpy(msk_arr).to(torch.float16).unsqueeze(0).repeat(self.data_config.nframes, 1, 1, 1)
        # Cache the mask for this pad value.
        self.mask_cache[pad] = mask
        return mask
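
To make the fixed-mask geometry concrete, here is the arithmetic for a 256-pixel frame with an illustrative face_hide_percentage of 0.5:

```python
import numpy as np

H = W = 256
face_hide_percentage = 0.5                        # illustrative value
y1, y2 = 0, H - int(H * 2.36 / 8)                 # y2 = 256 - 75 = 181
x1, x2 = int(W * 1.8 / 8), W - int(W * 1.8 / 8)   # x1, x2 = 57, 199
y1 = y2 - int(np.ceil(face_hide_percentage * (y2 - y1)))  # y1 = 181 - 91 = 90
# The mask covers rows 90:181 and columns 57:199 -- the lower-center
# region of the frame, i.e. the mouth/jaw area the mask is meant to hide.
```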

    def read_frames(self, videoreader: VideoReader, start_frame, num_frames):
        """
        Read a batch of frames from the video using decord.
        Returns a tuple (instance frames, non-overlapping "wrong" frames),
        or (None, None) on failure.
        """
        try:
            total_frames = len(videoreader)
            if total_frames < num_frames:
                return None, None
            # Get the target window of frames.
            frame_indices = list(range(start_frame, start_frame + num_frames))
            frames_array = videoreader.get_batch(frame_indices)  # shape: (num_frames, H, W, C)

            # Determine valid start indices for a "wrong" window that does not overlap the instance window.
            valid_starts = []
            # Left interval: ensure wrong_start + num_frames - 1 < start_frame.
            left_max = start_frame - num_frames
            if left_max >= 0:
                valid_starts.extend(range(0, left_max + 1))
            # Right interval: ensure wrong_start > start_frame + num_frames - 1.
            right_min = start_frame + num_frames
            if right_min <= total_frames - num_frames:
                valid_starts.extend(range(right_min, total_frames - num_frames + 1))

            if not valid_starts:
                # Fallback: if no valid index is available, choose the farthest possible window.
                wrong_start = 0 if start_frame > total_frames // 2 else total_frames - num_frames
            else:
                wrong_start = random.choice(valid_starts)

            wrong_indices = list(range(wrong_start, wrong_start + num_frames))
            wrong_array = videoreader.get_batch(wrong_indices)
            return frames_array, wrong_array
        except Exception as e:
            print(f"Error reading frames from {videoreader}: {e}")
            return None, None
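
The two intervals above guarantee the "wrong" window never overlaps the instance window; a quick check with made-up numbers:

```python
# Illustration of the non-overlap logic for a 100-frame video,
# num_frames=16, start_frame=40 (hypothetical values):
total_frames, num_frames, start_frame = 100, 16, 40
left_max = start_frame - num_frames    # 24 -> left windows start at 0..24
right_min = start_frame + num_frames   # 56 -> right windows start at 56..84
valid_starts = list(range(0, left_max + 1)) + \
               list(range(right_min, total_frames - num_frames + 1))
# Every window [s, s+15] for s in valid_starts is disjoint from [40, 55].
assert all(s + num_frames <= start_frame or s >= start_frame + num_frames
           for s in valid_starts)
```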

    def read_audio(self, video_path):
        try:
            ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.sample_rate)
            audio = ar[:].squeeze()  # assume mono
            del ar
            return audio
        except Exception as e:
            print(f"Error reading audio from {video_path}: {e}")
            return None

    def compute_mel(self, audio):
        try:
            mel = melspectrogram(audio)
            return mel.T
        except Exception as e:
            print("Error computing mel spectrogram:", e)
            return None

    def get_mel(self, audio, path):
        # First try to find the mel in the cache directory.
        cache_dir = getattr(self.data_config, "data_cache_path", None) or os.path.join(self.data_root, "cache")
        cache_dir = os.path.join(cache_dir, self.split)
        # Use a .npy suffix so np.save does not silently append one and
        # break the existence check below.
        cache_path = os.path.join(cache_dir, os.path.basename(path) + ".mel.npy")
        if os.path.exists(cache_path):
            mel = np.load(cache_path)
            return mel
        # If not found, compute the mel and save it to the cache.
        mel = self.compute_mel(audio)
        if mel is None:
            return None
        os.makedirs(cache_dir, exist_ok=True)
        np.save(cache_path, mel)
        return mel

    def __len__(self):
        return len(self.all_videos)

    def __getitem__(self, index):
        """
        Returns a dictionary with:
          - instance_images: [F, C, H, W]
          - reference_images: [F, C, H, W]
          - mask: [F, 1, H, W]
          - instance_masks: same as mask
          - (optionally) instance_masks_dilated
          - instance_masked_images: instance_images * (mask < 0.5)
          - instance_prompt_ids: tokenized caption
          - raw_audio / indiv_raw_audios, mels / indiv_mels if audio_format is specified.
        """
        example = {}
        attempt = 0
        while True:
            attempt += 1
            if attempt > 10:
                raise RuntimeError("Failed to get a valid sample after multiple attempts.")
            try:
                # Select a random video.
                video_idx = random.randint(0, len(self.all_videos) - 1)
                video_path = self.all_videos[video_idx]
                vr = VideoReader(video_path, ctx=cpu(self.worker_id))
                total_frames = len(vr)
                if total_frames < 3 * self.data_config.nframes:
                    continue

                # Randomly choose a start frame, ensuring enough frames for the window.
                start_frame = random.randint(self.data_config.nframes // 2, total_frames - self.data_config.nframes - self.data_config.nframes // 2)
                inst_frames, ref_frames = self.read_frames(vr, start_frame, self.data_config.nframes)
                if inst_frames is None or ref_frames is None:
                    continue

                vr.seek(0)  # avoid memory leak
                del vr

                # Generate masks.
                masks = self.get_masks(inst_frames)
                masks = self.image_transforms.resize(masks)

                dilated_masks = None
                if getattr(self.data_config, "dilate_masked_loss", False):
                    dilated_masks = self.get_masks(inst_frames, pad=self.data_config.resolution // 10)
                    dilated_masks = self.image_transforms.resize(dilated_masks)

                # Preprocess frames.
                inst_frames = self.image_transforms.preprocess_frames(inst_frames)
                ref_frames = self.image_transforms.preprocess_frames(ref_frames)

                # Optionally apply a horizontal flip.
                if self.h_flip and random.random() > 0.5:
                    inst_frames = F.hflip(inst_frames)
                    ref_frames = F.hflip(ref_frames)
                    masks = F.hflip(masks)
                    if dilated_masks is not None:
                        dilated_masks = F.hflip(dilated_masks)

                # Audio processing.
                if "wav" in self.audio_format or "mel" in self.audio_format:
                    audio = self.read_audio(video_path)

                    audio_chunk = crop_wav_window(
                        audio,
                        start_frame=start_frame,
                        nframes=self.data_config.nframes,
                        video_fps=self.video_fps,
                        sample_rate=self.sample_rate,
                    )
                    if audio_chunk is None:
                        continue
                    example["raw_audio"] = audio_chunk
                    if getattr(self.data_config, "use_indiv_audio", False):
                        indiv_audios = get_segmented_wavs(
                            audio,
                            start_frame,
                            self.data_config.nframes,
                            self.video_fps,
                            self.sample_rate,
                            indiv_audio_mode=self.data_config.indiv_audio_mode,
                        )
                        example["indiv_raw_audios"] = torch.FloatTensor(indiv_audios)
                    if "mel" in self.audio_format:
                        mel = self.get_mel(audio, video_path)
                        if mel is None:
                            continue
                        mel_window = crop_mel_window(
                            mel,
                            start_frame,
                            self.mel_steps_per_sec,
                            self.data_config.syncnet_mel_step_size,
                            self.video_fps,
                        )
                        if mel_window.shape[0] != self.data_config.syncnet_mel_step_size:
                            continue
                        example["mels"] = torch.FloatTensor(mel_window.T).unsqueeze(0)
                        indiv_mels = get_segmented_mels(
                            mel,
                            start_frame,
                            self.data_config.nframes,
                            self.mel_steps_per_sec,
                            self.data_config.syncnet_mel_step_size,
                            self.video_fps,
                        )
                        if indiv_mels is None:
                            continue
                        example["indiv_mels"] = torch.FloatTensor(indiv_mels)

                example["instance_images"] = inst_frames  # [F, C, H, W]
                example["reference_images"] = ref_frames  # [F, C, H, W]
                example["mask"] = masks  # [F, 1, H, W]
                example["instance_masks"] = example["mask"]
                if dilated_masks is not None:
                    example["instance_masks_dilated"] = dilated_masks
                example["instance_masked_images"] = example["instance_images"] * (example["mask"] < 0.5)

                # Process the caption prompt.
                if self.use_template and self.tokenizer is not None:
                    input_tok = list(self.token_map.values())[0]
                    text = random.choice(self.templates).format(input_tok)
                    example["instance_prompt_ids"] = self.tokenizer(
                        text,
                        padding="do_not_pad",
                        truncation=True,
                        max_length=self.tokenizer.model_max_length,
                    ).input_ids
                # else:
                #     raise NotImplementedError("Only template-based captions are supported.")
                return example
            except Exception as e:
                print("Exception in __getitem__:", e)
                traceback.print_exc()
                continue
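
A sketch of driving this dataset from a PyTorch DataLoader, routing worker_init_fn through get_worker_info so each worker's own dataset copy receives its id (config values are made up):

```python
# Hypothetical wiring; `data_config` is a d2lv2_lightning DataConfig
# with data_root, nframes, resolution, face_hide_percentage, etc.
import torch
from torch.utils.data import DataLoader

def init_worker(worker_id):
    # Reach the worker process's own copy of the dataset so that
    # worker_id is set on the instance that actually loads data.
    info = torch.utils.data.get_worker_info()
    info.dataset.worker_init_fn(worker_id)

dataset = Voxceleb2Decord(split="train", data_config=data_config)
loader = DataLoader(dataset, batch_size=4, num_workers=4,
                    worker_init_fn=init_worker)
batch = next(iter(loader))
print(batch["instance_images"].shape)  # [B, F, C, H, W]
```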