flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/simple_unet.py +5 -5
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +33 -27
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +48 -31
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,590 @@
|
|
1
|
+
"""
|
2
|
+
Functions for reading audio-video data without memory leaks.
|
3
|
+
"""
|
4
|
+
import cv2
|
5
|
+
import os
|
6
|
+
import shutil
|
7
|
+
import subprocess
|
8
|
+
import numpy as np
|
9
|
+
from typing import Tuple, Optional, Union, List
|
10
|
+
from video_reader import PyVideoReader
|
11
|
+
from .audio_utils import read_audio
|
12
|
+
|
13
|
+
def get_video_fps(video_path: str):
    """Return the frame rate OpenCV reports for ``video_path``.

    The capture handle is released in a ``finally`` block so it is not left
    open if querying the property raises.

    Args:
        video_path: Path to the video file.

    Returns:
        The value of ``cv2.CAP_PROP_FPS`` read from the opened capture.
    """
    cam = cv2.VideoCapture(video_path)
    try:
        return cam.get(cv2.CAP_PROP_FPS)
    finally:
        cam.release()
|
18
|
+
|
19
|
+
def read_video(video_path: str, change_fps=False, reader="rsreader"):
    """Decode all frames of a video, optionally re-encoding it to 25 fps first.

    Args:
        video_path: Path to the input video.
        change_fps: If True, ffmpeg first re-encodes the video to 25 fps
            (``-r 25 -crf 18``) into a private temporary directory, and that
            copy is decoded instead of ``video_path``. The directory is removed
            afterwards.
        reader: Decoding backend: ``"rsreader"``, ``"rsreader_fast"``,
            ``"decord"`` or ``"opencv"``.

    Returns:
        The frames as returned by the selected backend reader.

    Raises:
        ValueError: If ``reader`` is not a known backend (checked before any
            re-encoding is attempted).
        FileNotFoundError: If ``change_fps`` is set and ``ffmpeg`` is not on PATH.
        subprocess.CalledProcessError: If the ffmpeg re-encode fails.
    """
    import functools
    import tempfile

    # Resolve the backend first so an unknown name fails fast instead of after
    # a potentially slow ffmpeg run.
    if reader == "rsreader":
        decode = read_video_rsreader
    elif reader == "rsreader_fast":
        decode = functools.partial(read_video_rsreader, fast=True)
    elif reader == "decord":
        decode = read_video_decord
    elif reader == "opencv":
        decode = read_video_opencv
    else:
        raise ValueError(f"Unknown reader: {reader}")

    if not change_fps:
        return decode(video_path)

    print(f"Changing fps of {video_path} to 25")
    # Use a fresh per-call directory. A fixed "temp" folder in the CWD was
    # wiped unconditionally, which destroyed unrelated data. It was also shared
    # between concurrent callers, who would delete each other's output.
    temp_dir = tempfile.mkdtemp(prefix="read_video_")
    try:
        target_video_path = os.path.join(temp_dir, "video.mp4")
        # Pass an argument list (no shell) so paths containing spaces or shell
        # metacharacters are neither split nor interpreted. check=True turns an
        # ffmpeg failure into an error instead of reading a missing file.
        subprocess.run(
            [
                "ffmpeg", "-loglevel", "error", "-y", "-nostdin",
                "-i", video_path, "-r", "25", "-crf", "18", target_video_path,
            ],
            check=True,
        )
        return decode(target_video_path)
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
|
50
|
+
|
51
|
+
def read_video_decord(video_path: str):
    """Decode every frame of ``video_path`` with decord and return a numpy array.

    Args:
        video_path: Path to the video file.

    Returns:
        The result of ``VideoReader[:].asnumpy()`` for the whole file.
    """
    from decord import VideoReader

    decoder = VideoReader(video_path)
    all_frames = decoder[:].asnumpy()
    # Rewind the reader after the bulk read (kept from the original implementation).
    decoder.seek(0)
    return all_frames
|
57
|
+
|
58
|
+
# Fixed OpenCV video reader - properly release resources
|
59
|
+
def read_video_opencv(video_path):
|
60
|
+
cap = cv2.VideoCapture(video_path)
|
61
|
+
try:
|
62
|
+
frames = []
|
63
|
+
while True:
|
64
|
+
ret, frame = cap.read()
|
65
|
+
if not ret:
|
66
|
+
break
|
67
|
+
frames.append(frame)
|
68
|
+
return np.array(frames)[:, :, :, ::-1]
|
69
|
+
finally:
|
70
|
+
cap.release()
|
71
|
+
|
72
|
+
def read_video_rsreader(video_path, fast=False):
    """Decode a whole video with ``video_reader.PyVideoReader``.

    Args:
        video_path: Path to the video file.
        fast: If True, call ``decode_fast()``; otherwise ``decode()``.

    Returns:
        Whatever the chosen ``PyVideoReader`` decode method returns.
    """
    from video_reader import PyVideoReader

    decoder = PyVideoReader(video_path)
    if fast:
        return decoder.decode_fast()
    return decoder.decode()
|
76
|
+
|
77
|
+
def read_audio_decord(audio_path:str):
    """Read the complete audio track of ``audio_path`` with decord.

    Args:
        audio_path: Path to a media file containing audio.

    Returns:
        The result of ``AudioReader[:].asnumpy()`` for the whole file.
    """
    from decord import AudioReader

    decoder = AudioReader(audio_path)
    samples = decoder[:].asnumpy()
    # Rewind the reader after the bulk read (kept from the original implementation).
    decoder.seek(0)
    return samples
|
83
|
+
|
84
|
+
def read_av_decord(path: str, start: int=0, end: int = None, ctx=None):
    """Read audio and video frames ``[start:end]`` with decord's ``AVReader``.

    Args:
        path: Path to the media file.
        start: First frame index of the slice.
        end: End of the slice (``None`` for the end of the file).
        ctx: decord context; ``cpu(0)`` is used when omitted.

    Returns:
        ``(audio, video_frames)``: ``audio`` exactly as returned by the
        ``AVReader`` slice (audio is requested at 16000 Hz) and the video part
        converted with ``asnumpy()``.
    """
    from decord import AVReader, cpu

    reader = AVReader(path, ctx=cpu(0) if ctx is None else ctx, sample_rate=16000)
    audio_part, video_part = reader[start:end]
    return audio_part, video_part.asnumpy()
|
91
|
+
|
92
|
+
def read_av_improved(
    path: str,
    start: int = 0,
    end: Optional[int] = None,
    fps: float = 25.0,
    target_sr: int = 16000,
    audio_method: str = 'ffmpeg'
) -> Tuple[Union[List, np.ndarray], np.ndarray]:
    """
    Read audio-video data with explicit cleanup and without memory leaks.
    Uses PyVideoReader for video and FFmpeg/moviepy (via ``read_audio``) for
    audio extraction.

    Args:
        path: Path to the video file.
        start: Start frame index.
        end: End frame index (or None to read until the end). The video frames
            and the audio window are both bounded by it.
        fps: Video frames per second (used for audio timing).
        target_sr: Target audio sample rate.
        audio_method: Method to extract audio ('ffmpeg' or 'moviepy').

    Returns:
        Tuple of (audio_data, video_frames) where audio_data is a list (for API
        compatibility with the original read_av) and video_frames is what
        ``PyVideoReader.decode`` returns.

    Raises:
        ValueError: If ``end`` is given and is smaller than ``start``.
    """
    if end is not None and end < start:
        raise ValueError(f"end ({end}) must not be smaller than start ({start})")

    # Time window for the audio, derived from the frame range.
    start_time = start / fps if start > 0 else 0
    duration = None if end is None else (end - start) / fps

    vr = PyVideoReader(path)
    # Bound the video by ``end`` as well. Previously end_frame=None decoded to
    # the end of the file while the audio was cut to ``duration``, so the two
    # streams covered different spans.
    video = vr.decode(start_frame=start, end_frame=end)

    audio, _ = read_audio(
        path,
        start_time=start_time,
        duration=duration,
        target_sr=target_sr,
        method=audio_method
    )

    # Convert audio to list for API compatibility with original read_av
    return list(audio), video
|
139
|
+
|
140
|
+
def read_av_moviepy(
    video_path: str,
    start_idx: Optional[int] = None,
    end_idx: Optional[int] = None,
    target_fps: float = 25.0,
    target_sr: int = 16000,
):
    """
    Read audio-video data using moviepy.

    Args:
        video_path: Path to the video file.
        start_idx: Start frame index (optional; defaults to the beginning).
        end_idx: End frame index (optional; defaults to the end).
        target_fps: Frame rate used to turn frame indices into times and to
            sample the video frames.
        target_sr: Target sample rate for the audio.

    Returns:
        Tuple of (audio_data, video_frames). ``audio_data`` is moviepy's sound
        array, averaged over channels when there is more than one;
        ``video_frames`` is a uint8 numpy array of the sampled frames.

    Raises:
        ValueError: If the file has no audio track.
    """
    from moviepy import VideoFileClip

    # The context manager closes the opened clip even if decoding raises;
    # previously it was only closed on the success path.
    with VideoFileClip(video_path) as source:
        # Convert frame indexes to time
        start_time = start_idx / target_fps if start_idx is not None else 0
        end_time = end_idx / target_fps if end_idx is not None else None
        video = source.with_fps(target_fps).subclipped(start_time, end_time)
        try:
            if video.audio is None:
                raise ValueError(f"No audio track found in {video_path}")
            audio_data = video.audio.with_fps(target_sr).to_soundarray()
            # Down-mix multi-channel audio to mono.
            if audio_data.ndim > 1 and audio_data.shape[1] > 1:
                audio_data = np.mean(audio_data, axis=1)

            video_frames = np.array(
                list(video.iter_frames(fps=target_fps, dtype='uint8'))
            )
        finally:
            video.close()
    return audio_data, video_frames
|
182
|
+
def read_av_random_clip_moviepy(
    video_path: str,
    num_frames: int = 16,
    audio_frames_per_video_frame: int = 1,
    audio_frame_padding: int = 0,
    target_sr: int = 16000,
    target_fps: float = 25.0,
    random_seed: Optional[int] = None,
):
    """
    Read a random clip of audio and video frames using moviepy.

    A random start frame is chosen so that N = ``num_frames`` video frames plus
    P padding frames on either side fit in the video, where
    P = max(audio_frame_padding, audio_frames_per_video_frame // 2). The N video
    frames are read, together with the audio covering the P + N + P frames.

    Args:
        video_path: Path to the video file.
        num_frames: Number of video frames to read (N).
        audio_frames_per_video_frame: Number of audio frames per video frame;
            only 1 is currently supported.
        audio_frame_padding: Padding for audio frames, in video frames.
        target_sr: Target sample rate for the audio.
        target_fps: Target frames per second for the video.
        random_seed: If given, passed to ``np.random.seed`` (which reseeds
            NumPy's global random state) before the start frame is drawn.

    Returns:
        Tuple of (frame_wise_audio, full_padded_audio, video_frames):
            frame_wise_audio: shape (1, N, 1, K), audio of the N central frames.
            full_padded_audio: shape (P + N + P, K).
            video_frames: numpy array of the N uint8 frames.
        Here K = round(target_sr / target_fps) samples per video frame.

    Raises:
        NotImplementedError: If ``audio_frames_per_video_frame > 1``.
        ValueError: If the video is too short, has no audio track, the audio
            window is out of bounds, or too little audio is decoded.
    """
    from moviepy import VideoFileClip

    # Fail before any decoding work for the unsupported mode.
    if audio_frames_per_video_frame > 1:
        raise NotImplementedError("Frame-wise audio extraction is not implemented yet.")

    # Set random seed if provided
    if random_seed is not None:
        np.random.seed(random_seed)

    # The context manager closes the file's readers on every path, including
    # the ValueErrors below (previously those leaked all open clips).
    with VideoFileClip(video_path) as source:
        video = source.with_fps(target_fps)
        original_duration = video.duration
        total_frames = video.n_frames

        # Calculate effective padding needed based on audio segmentation
        effective_padding = max(audio_frame_padding, audio_frames_per_video_frame // 2)
        num_audio_frames = num_frames + 2 * effective_padding

        if total_frames < num_audio_frames:
            raise ValueError(f"Video has only {total_frames} frames, but {num_frames + 2 * effective_padding} were requested (including effective padding)")

        # Select a random start frame that allows for padding on both sides
        min_start_idx = effective_padding
        max_start_idx = total_frames - num_frames - effective_padding
        start_idx = np.random.randint(min_start_idx, max_start_idx) if max_start_idx > min_start_idx else min_start_idx
        end_idx = start_idx + num_frames

        # Fetch only the needed frames instead of decoding every frame before
        # start_idx. The frame at index i is sampled at t = i / target_fps,
        # limited to i < int(duration * target_fps).
        # NOTE(review): this mirrors moviepy's iter_frames(fps=..., dtype='uint8')
        # sampling - confirm against the installed moviepy version.
        num_iterable = int(video.duration * target_fps)
        frames = []
        for frame_idx in range(start_idx, min(end_idx, num_iterable)):
            frame = video.get_frame(frame_idx / target_fps)
            if frame.dtype != np.uint8:
                frame = frame.astype('uint8')
            frames.append(frame)
        video_frames = np.array(frames)

        audio_start_time = (start_idx - effective_padding) / target_fps
        audio_end_time = (end_idx + effective_padding) / target_fps
        # Ensure we don't go out of bounds
        if audio_start_time < 0 or audio_end_time > original_duration:
            raise ValueError(f"Audio start time {audio_start_time} or end time {audio_end_time} is out of bounds for video duration {original_duration}")

        if video.audio is None:
            raise ValueError(f"No audio track found in {video_path}")
        clip = video.subclipped(audio_start_time, audio_end_time)
        try:
            audio_data = clip.audio.with_fps(target_sr).to_soundarray()
        finally:
            clip.close()

    audio_data_per_frame = int(round(target_sr / target_fps))
    # Require exactly num_audio_frames * K samples so the reshape below always
    # matches. For integer target_sr / target_fps this equals the previous
    # round(audio_duration * target_sr); for other ratios the old count made
    # the reshape fail.
    num_audio_samples_required = num_audio_frames * audio_data_per_frame
    if len(audio_data) < num_audio_samples_required:
        raise ValueError(f"Audio data length {len(audio_data)} is less than required {num_audio_samples_required}")
    audio_data = audio_data[:num_audio_samples_required]
    # Convert to mono if stereo
    if audio_data.ndim > 1 and audio_data.shape[1] > 1:
        audio_data = np.mean(audio_data, axis=1)

    audio_data = np.asarray(audio_data).reshape(num_audio_frames, audio_data_per_frame)

    # Extract the central part (for effective frames) and reshape to (1, N, 1, K)
    central_audio = audio_data[effective_padding:effective_padding + num_frames]
    frame_wise_audio = central_audio.reshape(1, num_frames, 1, audio_data_per_frame)

    return frame_wise_audio, audio_data, video_frames
|
301
|
+
|
302
|
+
|
303
|
+
def read_av_random_clip_alt(
    video_path: str,
    num_frames: int = 16,
    audio_frames_per_video_frame: int = 1,
    audio_frame_padding: int = 0,
    target_sr: int = 16000,
    target_fps: float = 25.0,
    random_seed: Optional[int] = None,
):
    """
    Read a random clip of audio and video frames (PyVideoReader for video,
    moviepy for audio).

    A random start frame is chosen so that N = ``num_frames`` video frames plus
    P padding frames on either side fit in the video, where
    P = max(audio_frame_padding, audio_frames_per_video_frame // 2). The N video
    frames are decoded, together with the audio covering the P + N + P frames.

    Args:
        video_path: Path to the video file.
        num_frames: Number of video frames to read (N).
        audio_frames_per_video_frame: Number of audio frames per video frame;
            only 1 is currently supported.
        audio_frame_padding: Padding for audio frames, in video frames.
        target_sr: Target sample rate for the audio.
        target_fps: Target frames per second for the video.
        random_seed: If given, passed to ``np.random.seed`` (which reseeds
            NumPy's global random state) before the start frame is drawn.

    Returns:
        Tuple of (frame_wise_audio, full_padded_audio, video_frames):
            frame_wise_audio: shape (1, N, 1, K), audio of the N central frames.
            full_padded_audio: shape (P + N + P, K).
            video_frames: what ``PyVideoReader.decode(start, end)`` returns.
        Here K = round(target_sr / target_fps) samples per video frame.

    Raises:
        NotImplementedError: If ``audio_frames_per_video_frame > 1``.
        ValueError: If the video is too short, has no audio track, the audio
            window is invalid, or too little audio is decoded.
    """
    from moviepy import VideoFileClip, AudioFileClip
    from video_reader import PyVideoReader

    # Fail before any decoding work for the unsupported mode.
    if audio_frames_per_video_frame > 1:
        raise NotImplementedError("Frame-wise audio extraction is not implemented yet.")

    # Set random seed if provided
    if random_seed is not None:
        np.random.seed(random_seed)

    vr = PyVideoReader(video_path)
    total_frames = int(vr.get_info()['frame_count'])

    # Calculate effective padding needed based on audio segmentation
    effective_padding = max(audio_frame_padding, audio_frames_per_video_frame // 2)
    num_audio_frames = num_frames + 2 * effective_padding

    if total_frames < num_audio_frames:
        raise ValueError(f"Video has only {total_frames} frames, but {num_frames + 2 * effective_padding} were requested (including effective padding)")

    # Select a random start frame that allows for padding on both sides
    min_start_idx = effective_padding
    max_start_idx = total_frames - num_frames - effective_padding
    start_idx = np.random.randint(min_start_idx, max_start_idx) if max_start_idx > min_start_idx else min_start_idx
    end_idx = start_idx + num_frames

    video_frames = vr.decode(start_idx, end_idx)
    # Drop the video decoder before the file is opened again for audio.
    del vr

    audio_start_time = (start_idx - effective_padding) / target_fps
    audio_end_time = (end_idx + effective_padding) / target_fps
    audio_duration = audio_end_time - audio_start_time

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if audio_duration <= 0:
        raise ValueError(f"Audio duration {audio_duration} is not positive")
    if audio_start_time < 0:
        raise ValueError(f"Audio start time {audio_start_time} is negative")

    # The VideoFileClip was previously never closed, leaking its reader; the
    # context manager now closes it on every path.
    with VideoFileClip(video_path) as source:
        if source.audio is None:
            raise ValueError(f"No audio track found in {video_path}")
        audio_clip: AudioFileClip = source.audio.with_fps(target_sr).subclipped(audio_start_time, audio_end_time)
        try:
            audio_data = audio_clip.to_soundarray()
        finally:
            audio_clip.close()

    audio_data_per_frame = int(round(target_sr / target_fps))
    # Require exactly num_audio_frames * K samples so the reshape below always
    # matches. For integer target_sr / target_fps this equals the previous
    # round(audio_duration * target_sr); for other ratios the old count made
    # the reshape fail.
    num_audio_samples_required = num_audio_frames * audio_data_per_frame
    if len(audio_data) < num_audio_samples_required:
        raise ValueError(f"Audio data length {len(audio_data)} is less than required {num_audio_samples_required}")
    audio_data = audio_data[:num_audio_samples_required]
    # Convert to mono if stereo
    if audio_data.ndim > 1 and audio_data.shape[1] > 1:
        audio_data = np.mean(audio_data, axis=1)

    audio_data = np.asarray(audio_data).reshape(num_audio_frames, audio_data_per_frame)

    # Extract the central part (for effective frames) and reshape to (1, N, 1, K)
    central_audio = audio_data[effective_padding:effective_padding + num_frames]
    frame_wise_audio = central_audio.reshape(1, num_frames, 1, audio_data_per_frame)

    return frame_wise_audio, audio_data, video_frames
|
404
|
+
|
405
|
+
def _decode_resampled_audio(container, audio_stream, target_sr):
    """Decode ``audio_stream`` and resample it to mono signed 16-bit at ``target_sr``.

    Returns:
        ``(segments, segment_times)``: a list of 1-D int16 sample arrays and,
        for each, the ``(start, end)`` time in seconds that it covers.
    """
    import av

    resampler = av.AudioResampler(format="s16", layout="mono", rate=target_sr)

    def _resampler_outputs():
        # Yield the raw resampler output for each decoded frame that has a pts,
        # then flush the samples still buffered inside the resampler.
        got_input = False
        for packet in container.demux(audio_stream):
            for frame in packet.decode():
                if frame.pts is None:
                    continue
                got_input = True
                yield resampler.resample(frame)
        if got_input:
            yield resampler.resample(None)

    segments = []
    segment_times = []
    prev_end = None  # end time of the previously collected segment
    for out in _resampler_outputs():
        if out is None:
            continue
        for oframe in (out if isinstance(out, list) else [out]):
            if oframe is None:
                continue
            samples = oframe.to_ndarray().flatten().astype(np.int16)
            if oframe.pts is not None:
                # pts is expressed in the *output* frame's time base, which the
                # resampler may set differently from the input stream's.
                time_base = (
                    oframe.time_base if oframe.time_base is not None
                    else audio_stream.time_base
                )
                start_t = float(oframe.pts * time_base)
            elif prev_end is not None:
                # e.g. flushed samples without a pts: continue the timeline.
                start_t = prev_end
            else:
                continue
            end_t = start_t + oframe.samples / oframe.sample_rate
            segments.append(samples)
            segment_times.append((start_t, end_t))
            prev_end = end_t
    return segments, segment_times


def _time_to_sample(t, segment_times, offsets, sample_rate):
    """Map time ``t`` (seconds) to an index into the concatenated segments.

    ``offsets[i]`` is the index of segment ``i``'s first sample and
    ``offsets[-1]`` the total length. Segment samples are at ``sample_rate``
    and segments are assumed to be in time order. Times before the first
    segment map to 0; times after the last map to the total length; a time
    inside a gap between segments maps to the start of the following segment.
    """
    total = int(offsets[-1])
    if t <= segment_times[0][0]:
        return 0
    if t >= segment_times[-1][1]:
        return total
    for i, (seg_start, seg_end) in enumerate(segment_times):
        if t < seg_start:
            return int(offsets[i])
        if t < seg_end:
            seg_len = int(offsets[i + 1] - offsets[i])
            # The offset within a segment must use the rate of the resampled
            # samples (previously the source stream's rate was used here).
            seg_offset = int(round((t - seg_start) * sample_rate))
            return int(offsets[i]) + min(seg_offset, seg_len - 1)
    return total


def read_av_random_clip_pyav(
    video_path: str,
    num_frames: int = 16,
    audio_frames_per_video_frame: int = 1,
    audio_frame_padding: int = 0,
    target_sr: int = 16000,
    target_fps: float = 25.0,
    random_seed: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Decodes a random video clip and its corresponding audio from `video_path`.

    The audio window is padded on each side by
    P = max(audio_frame_padding, audio_frames_per_video_frame // 2) video
    frames. PyAV's resampler produces mono 16-bit audio at `target_sr`, which
    is scaled to float32 in [-1, 1). The audio is zero-padded at the end if the
    file does not provide enough samples.

    Args:
        video_path: Path to the video file.
        num_frames: Number of video frames to read.
        audio_frames_per_video_frame: Only 1 is currently supported.
        audio_frame_padding: Audio padding on each side, in video frames.
        target_sr: Target audio sample rate.
        target_fps: Frame rate used to convert frame indices to times.
        random_seed: If given, passed to ``np.random.seed`` (which reseeds
            NumPy's global random state) before the start frame is drawn.

    Returns:
        (frame_wise_audio, full_padded_audio, video_frames)
          * frame_wise_audio: (1, num_frames, 1, audio_data_per_frame)
          * full_padded_audio: (num_frames + 2*P, audio_data_per_frame)
          * video_frames: what ``PyVideoReader.decode(start, end)`` returns
        where audio_data_per_frame = round(target_sr / target_fps).

    Raises:
        NotImplementedError: If ``audio_frames_per_video_frame > 1``.
        ValueError: If the video is too short, has no audio stream, or no
            audio covers the requested range.
    """
    from video_reader import PyVideoReader
    import av

    # Fail before any decoding work for the unsupported mode.
    if audio_frames_per_video_frame > 1:
        raise NotImplementedError("Multiple audio frames per video frame not supported.")

    if random_seed is not None:
        np.random.seed(random_seed)

    # --- 1) Determine which video frames to read ---
    vr = PyVideoReader(video_path)
    total_frames = int(vr.get_info()["frame_count"])
    eff_pad = max(audio_frame_padding, audio_frames_per_video_frame // 2)
    needed_frames = num_frames + 2 * eff_pad
    if total_frames < needed_frames:
        raise ValueError(
            f"Video has only {total_frames} frames but needs {needed_frames} (with padding)."
        )

    min_start = eff_pad
    max_start = total_frames - num_frames - eff_pad
    start_idx = (
        np.random.randint(min_start, max_start)
        if max_start > min_start
        else min_start
    )
    end_idx = start_idx + num_frames

    # --- 2) Decode the chosen video frames ---
    video_frames = vr.decode(start_idx, end_idx)
    del vr

    # --- 3) Define audio time window ---
    audio_start_time = max(0.0, (start_idx - eff_pad) / target_fps)
    audio_end_time = (end_idx + eff_pad) / target_fps

    # --- 4) Decode all audio, resampled to s16 mono @ target_sr ---
    with av.open(video_path) as container:
        audio_stream = next((s for s in container.streams if s.type == "audio"), None)
        if audio_stream is None:
            raise ValueError("No audio stream found in the file.")
        audio_segments, segment_times = _decode_resampled_audio(
            container, audio_stream, target_sr
        )

    if not audio_segments:
        raise ValueError("No audio frames were decoded.")

    full_audio = np.concatenate(audio_segments, axis=0)
    offsets = np.cumsum([0] + [len(seg) for seg in audio_segments])

    start_sample = _time_to_sample(audio_start_time, segment_times, offsets, target_sr)
    end_sample = _time_to_sample(audio_end_time, segment_times, offsets, target_sr)
    if end_sample <= start_sample:
        raise ValueError("No audio in the requested range.")

    # --- 5) Take exactly num_audio_frames * K samples, zero-padding at the end ---
    num_audio_frames = num_frames + 2 * eff_pad
    audio_data_per_frame = int(round(target_sr / target_fps))
    needed_total_samples = num_audio_frames * audio_data_per_frame

    # Take real samples up to the needed length (rather than trimming to the
    # window and then zero-padding); pad only if the audio actually runs out.
    sliced_audio = full_audio[start_sample:start_sample + needed_total_samples]
    if len(sliced_audio) < needed_total_samples:
        sliced_audio = np.pad(
            sliced_audio, (0, needed_total_samples - len(sliced_audio)), "constant"
        )
    # Convert int16 to float in [-1, 1)
    sliced_audio = sliced_audio.astype(np.float32) / 32768.0

    full_padded_audio = sliced_audio.reshape(num_audio_frames, audio_data_per_frame)

    # --- 6) Extract the clip's central audio & reshape for per-frame usage ---
    center = full_padded_audio[eff_pad:eff_pad + num_frames]
    frame_wise_audio = center.reshape(1, num_frames, 1, audio_data_per_frame)

    return frame_wise_audio, full_padded_audio, video_frames
|
542
|
+
|
543
|
+
# Registry mapping a method name to its random audio-video clip reader;
# read_av_random_clip dispatches on these keys.
CLIP_READERS = dict(
    moviepy=read_av_random_clip_moviepy,
    alt=read_av_random_clip_alt,
    pyav=read_av_random_clip_pyav,
)
|
549
|
+
|
550
|
+
def read_av_random_clip(
    path: str,
    num_frames: int = 16,
    audio_frames_per_video_frame: int = 1,
    audio_frame_padding: int = 0,
    target_sr: int = 16000,
    target_fps: float = 25.0,
    random_seed: Optional[int] = None,
    method: str = 'alt'
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Read a random audio-video clip with the reader registered under ``method``.

    Args:
        path (str): Path to the media file.
        num_frames (int): Number of video frames to read.
        audio_frames_per_video_frame (int): Number of audio frames per video frame.
        audio_frame_padding (int): Padding for audio frames.
        target_sr (int): Target sample rate for audio.
        target_fps (float): Target frames per second for video.
        random_seed (Optional[int]): Seed for random number generator.
        method (str): Key into ``CLIP_READERS``: 'moviepy', 'alt' or 'pyav'.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: (frame_wise_audio, full_padded_audio, video_frames).
            - frame_wise_audio: Shape (1, num_frames, 1, audio_data_per_frame)
            - full_padded_audio: Shape (num_frames + 2*padding, audio_data_per_frame)
            - video_frames: Shape (num_frames, H, W, 3)

    Raises:
        ValueError: If ``method`` is not a registered reader.
    """
    try:
        clip_reader = CLIP_READERS[method]
    except KeyError:
        raise ValueError(
            f"Unknown method: {method}. Available methods: {list(CLIP_READERS.keys())}"
        ) from None

    reader_options = dict(
        num_frames=num_frames,
        audio_frames_per_video_frame=audio_frames_per_video_frame,
        audio_frame_padding=audio_frame_padding,
        target_sr=target_sr,
        target_fps=target_fps,
        random_seed=random_seed,
    )
    return clip_reader(path, **reader_options)
|