flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,590 @@
1
+ """
2
+ Functions for reading audio-video data without memory leaks.
3
+ """
4
+ import cv2
5
+ import os
6
+ import shutil
7
+ import subprocess
8
+ import numpy as np
9
+ from typing import Tuple, Optional, Union, List
10
+ from video_reader import PyVideoReader
11
+ from .audio_utils import read_audio
12
+
13
def get_video_fps(video_path: str) -> float:
    """Return the frame rate OpenCV reports for ``video_path``.

    The capture handle is released in a ``finally`` block so it is freed even
    if querying the property raises.

    Args:
        video_path: Path to the video file.

    Returns:
        The value of ``cv2.CAP_PROP_FPS`` for the file.
        NOTE(review): OpenCV generally reports 0.0 for a file it cannot open
        instead of raising - confirm callers handle that value.
    """
    cam = cv2.VideoCapture(video_path)
    try:
        return cam.get(cv2.CAP_PROP_FPS)
    finally:
        cam.release()
18
+
19
def read_video(video_path: str, change_fps=False, reader="rsreader"):
    """Decode every frame of ``video_path`` with the selected backend.

    Args:
        video_path: Path to the input video.
        change_fps: If True, first re-encode the video to 25 fps with ffmpeg
            into a private temporary directory and decode that copy instead.
            The temporary directory is removed afterwards.
        reader: Backend name: ``"rsreader"``, ``"rsreader_fast"``,
            ``"decord"`` or ``"opencv"``.

    Returns:
        Whatever the selected backend reader returns for the decoded video.

    Raises:
        ValueError: If ``reader`` is not a supported name. This is checked
            before any ffmpeg work or filesystem changes happen.
        subprocess.CalledProcessError: If the ffmpeg re-encode fails.
        FileNotFoundError: If ``change_fps`` is set and ffmpeg is not installed.
    """
    import tempfile

    if reader not in ("rsreader", "rsreader_fast", "decord", "opencv"):
        raise ValueError(f"Unknown reader: {reader}")

    temp_dir = None
    try:
        if change_fps:
            print(f"Changing fps of {video_path} to 25")
            # A fresh private directory: never clobbers a user's ./temp and is
            # safe when several workers convert videos at the same time.
            temp_dir = tempfile.mkdtemp(prefix="read_video_")
            target_video_path = os.path.join(temp_dir, "video.mp4")
            # Argument list with shell=False: paths containing spaces or shell
            # metacharacters are passed to ffmpeg verbatim, never interpreted.
            subprocess.run(
                [
                    "ffmpeg", "-loglevel", "error", "-y", "-nostdin",
                    "-i", video_path,
                    "-r", "25", "-crf", "18",
                    target_video_path,
                ],
                check=True,
            )
        else:
            target_video_path = video_path

        if reader == "rsreader":
            return read_video_rsreader(target_video_path)
        if reader == "rsreader_fast":
            return read_video_rsreader(target_video_path, fast=True)
        if reader == "decord":
            return read_video_decord(target_video_path)
        return read_video_opencv(target_video_path)
    finally:
        # Remove the re-encoded temporary copy once decoding is finished.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
50
+
51
def read_video_decord(video_path: str):
    """Decode all frames of ``video_path`` with decord and return them as numpy.

    After the full read the reader is rewound to frame 0 before the decoded
    array is returned.
    """
    from decord import VideoReader

    decoder = VideoReader(video_path)
    all_frames = decoder[:].asnumpy()
    decoder.seek(0)  # rewind after the full read
    return all_frames
57
+
58
+ # Fixed OpenCV video reader - properly release resources
59
def read_video_opencv(video_path):
    """Decode every frame of ``video_path`` with OpenCV and return RGB frames.

    Args:
        video_path: Path to the video file.

    Returns:
        A C-contiguous array of the stacked decoded frames. The channel axis is
        reversed, which turns OpenCV's BGR order into RGB.

    Raises:
        ValueError: If no frame could be decoded (missing, unreadable or empty
            file), instead of an opaque ``IndexError`` from indexing an empty
            array.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(frame)
    finally:
        # Always free the native capture handle, even if decoding raises.
        cap.release()

    if not frames:
        raise ValueError(f"No frames could be read from {video_path}")
    # BGR -> RGB. Copy into a contiguous buffer because the reversed-stride
    # view is rejected by some consumers (e.g. torch.from_numpy).
    return np.ascontiguousarray(np.array(frames)[:, :, :, ::-1])
71
+
72
def read_video_rsreader(video_path, fast=False):
    """Decode ``video_path`` with video_reader's ``PyVideoReader``.

    Args:
        video_path: Path to the video file.
        fast: When True use ``decode_fast()``; otherwise use ``decode()``.
    """
    from video_reader import PyVideoReader

    decoder = PyVideoReader(video_path)
    if fast:
        return decoder.decode_fast()
    return decoder.decode()
76
+
77
def read_audio_decord(audio_path: str):
    """Read the whole audio track of ``audio_path`` with decord as numpy.

    After the full read the reader is rewound to the start before the samples
    are returned.
    """
    from decord import AudioReader

    source = AudioReader(audio_path)
    samples = source[:].asnumpy()
    source.seek(0)  # rewind after the full read
    return samples
83
+
84
def read_av_decord(path: str, start: int = 0, end: int = None, ctx=None):
    """Read audio and video for the frame slice ``[start:end]`` with decord.

    Args:
        path: Media file path.
        start: First frame index of the slice.
        end: Slice end (``None`` reads to the end of the file).
        ctx: decord context; ``cpu(0)`` is used when None.

    Returns:
        ``(audio, video)``: ``audio`` is returned exactly as ``AVReader``
        slicing yields it, and ``video`` is converted with ``.asnumpy()``.
        Audio is requested at a 16000 Hz sample rate.
    """
    from decord import AVReader, cpu

    av_reader = AVReader(path, ctx=cpu(0) if ctx is None else ctx, sample_rate=16000)
    audio_part, video_part = av_reader[start:end]
    return audio_part, video_part.asnumpy()
91
+
92
def read_av_improved(
    path: str,
    start: int = 0,
    end: Optional[int] = None,
    fps: float = 25.0,
    target_sr: int = 16000,
    audio_method: str = 'ffmpeg'
) -> Tuple[Union[List, np.ndarray], np.ndarray]:
    """
    Read audio-video data for the frame window ``[start, end)``.

    Video is decoded with ``PyVideoReader``. Audio is read with ``read_audio``
    over the same window, converted to seconds using ``fps``.

    Args:
        path: Path to the video file.
        start: Start frame index.
        end: End frame index (or None to read until the end).
        fps: Video frames per second, used to convert frame indices into the
            audio time window.
        target_sr: Target audio sample rate.
        audio_method: Method passed to ``read_audio`` ('ffmpeg' or 'moviepy').

    Returns:
        Tuple of (audio_data, video_frames). ``audio_data`` is a list built
        from the audio returned by ``read_audio``, for compatibility with the
        original read_av. ``video_frames`` is what ``PyVideoReader.decode``
        returns for the window.
    """
    # Convert the frame window into the time window used for audio.
    start_time = start / fps if start > 0 else 0
    duration = None if end is None else (end - start) / fps

    # Honour ``end`` for video as well. Decoding with end_frame=None would
    # read to the end of the file while the audio stops at ``end``.
    vr = PyVideoReader(path)
    video = vr.decode(start_frame=start, end_frame=end)

    audio, _ = read_audio(
        path,
        start_time=start_time,
        duration=duration,
        target_sr=target_sr,
        method=audio_method
    )

    # Convert audio to list for API compatibility with original read_av
    return list(audio), video
139
+
140
def read_av_moviepy(
    video_path: str,
    start_idx: Optional[int] = None,
    end_idx: Optional[int] = None,
    target_fps: float = 25.0,
    target_sr: int = 16000,
):
    """
    Read audio-video data using moviepy.

    Args:
        video_path: Path to the video file.
        start_idx: Start frame index (optional; defaults to the beginning).
        end_idx: End frame index (optional; defaults to the end of the clip).
        target_fps: Frame rate used to convert frame indices to times and to
            sample video frames.
        target_sr: Target sample rate for the audio.

    Returns:
        Tuple of (audio_data, video_frames). ``audio_data`` comes from
        ``to_soundarray()`` and is averaged across channels when there is more
        than one. ``video_frames`` is a uint8 numpy array of the sampled frames.

    Raises:
        ValueError: If the file has no audio track.
    """
    from moviepy import VideoFileClip

    source = VideoFileClip(video_path)
    try:
        clip = source.with_fps(target_fps)

        # Convert frame indexes to time
        start_time = start_idx / target_fps if start_idx is not None else 0
        end_time = end_idx / target_fps if end_idx is not None else None
        clip = clip.subclipped(start_time, end_time)

        if clip.audio is None:
            raise ValueError(f"No audio track found in {video_path}")
        audio_data = clip.audio.with_fps(target_sr).to_soundarray()
        if audio_data.ndim > 1 and audio_data.shape[1] > 1:
            audio_data = np.mean(audio_data, axis=1)

        video_frames = np.array(list(clip.iter_frames(fps=target_fps, dtype='uint8')))
    finally:
        # Release the ffmpeg readers on every path, including errors. The
        # derived clips are shallow copies that share these readers.
        source.close()
    return audio_data, video_frames
182
def read_av_random_clip_moviepy(
    video_path: str,
    num_frames: int = 16,
    audio_frames_per_video_frame: int = 1,
    audio_frame_padding: int = 0,
    target_sr: int = 16000,
    target_fps: float = 25.0,
    random_seed: Optional[int] = None,
):
    """
    Read a random clip of video frames plus matching padded audio with moviepy.

    Notation: N = num_frames;
    P = max(audio_frame_padding, audio_frames_per_video_frame // 2);
    K = round(target_sr / target_fps), the number of audio samples per frame.

    A start frame is drawn at random so that the N video frames plus P padding
    frames on each side lie inside the video. The audio for that P + N + P
    frame window is read, mixed down to mono and split into one row of K
    samples per frame.

    Args:
        video_path: Path to the video file.
        num_frames: Number of video frames to read (N).
        audio_frames_per_video_frame: Number of audio frames per video frame.
            Only 1 is currently supported.
        audio_frame_padding: Minimum audio padding, in frames, on each side.
        target_sr: Target sample rate for the audio.
        target_fps: Target frames per second for the video.
        random_seed: If given, the start frame is drawn from a private
            ``np.random.RandomState(random_seed)``. This yields the same index
            as seeding the global RNG, but leaves the global NumPy RNG untouched.

    Returns:
        Tuple of (frame_wise_audio, full_padded_audio, video_frames):
          * frame_wise_audio: shape (1, N, 1, K), the audio rows of the N clip frames.
          * full_padded_audio: shape (P + N + P, K).
          * video_frames: uint8 array of the N decoded frames, presumably
            (N, H, W, 3).

    Raises:
        NotImplementedError: If ``audio_frames_per_video_frame > 1``.
        ValueError: If any of the following holds:
            * the video is too short;
            * fewer frames than requested could be decoded;
            * the file has no audio track;
            * the audio window is out of bounds or too short.
    """
    import itertools
    from moviepy import VideoFileClip

    # Fail fast, before any decoding work.
    if audio_frames_per_video_frame > 1:
        raise NotImplementedError("Frame-wise audio extraction is not implemented yet.")

    rng = np.random if random_seed is None else np.random.RandomState(random_seed)

    video = VideoFileClip(video_path).with_fps(target_fps)
    clip = None
    try:
        original_duration = video.duration
        total_frames = video.n_frames

        # Calculate effective padding needed based on audio segmentation
        effective_padding = max(audio_frame_padding, audio_frames_per_video_frame // 2)
        num_audio_frames = num_frames + 2 * effective_padding

        if total_frames < num_audio_frames:
            raise ValueError(f"Video has only {total_frames} frames, but {num_audio_frames} were requested (including effective padding)")

        min_start_idx = effective_padding
        max_start_idx = total_frames - num_frames - effective_padding
        # NOTE(review): randint's upper bound is exclusive, so max_start_idx
        # itself is never drawn. Kept as-is because it also leaves a margin when
        # the reported frame count overestimates the decodable frames.
        start_idx = rng.randint(min_start_idx, max_start_idx) if max_start_idx > min_start_idx else min_start_idx
        end_idx = start_idx + num_frames

        # Decode from the first frame and keep frames [start_idx, end_idx).
        frame_iter = video.iter_frames(fps=target_fps, dtype='uint8')
        video_frames = np.array(list(itertools.islice(frame_iter, start_idx, end_idx)))
        if len(video_frames) != num_frames:
            raise ValueError(f"Decoded {len(video_frames)} frames starting at {start_idx}, expected {num_frames}")

        audio_start_time = (start_idx - effective_padding) / target_fps
        audio_end_time = (end_idx + effective_padding) / target_fps
        audio_duration = audio_end_time - audio_start_time
        if audio_start_time < 0 or audio_end_time > original_duration:
            raise ValueError(f"Audio start time {audio_start_time} or end time {audio_end_time} is out of bounds for video duration {original_duration}")
        if video.audio is None:
            raise ValueError(f"No audio track found in {video_path}")

        clip = video.subclipped(audio_start_time, audio_end_time)
        audio_data = clip.audio.with_fps(target_sr).to_soundarray()
        num_audio_samples_required = int(round(audio_duration * target_sr))
        if len(audio_data) < num_audio_samples_required:
            raise ValueError(f"Audio data length {len(audio_data)} is less than required {num_audio_samples_required}")
        audio_data = audio_data[:num_audio_samples_required]
        # Convert to mono if stereo
        if audio_data.ndim > 1 and audio_data.shape[1] > 1:
            audio_data = np.mean(audio_data, axis=1)
    finally:
        # Close the clips on every path so their ffmpeg readers are released.
        if clip is not None:
            clip.close()
        video.close()

    audio_data_per_frame = int(round(target_sr / target_fps))
    needed_total = num_audio_frames * audio_data_per_frame
    # When target_sr / target_fps is not an integer, the window length can
    # differ from needed_total by a few rounding samples. Zero-pad or trim so
    # the reshape below always succeeds. In the integral case this is a no-op.
    flat = np.asarray(audio_data).reshape(-1)
    if len(flat) < needed_total:
        flat = np.pad(flat, (0, needed_total - len(flat)), "constant")
    audio_data = flat[:needed_total].reshape(num_audio_frames, audio_data_per_frame)

    # Drop the padding rows to keep only the audio aligned with the N frames.
    central_audio = audio_data[effective_padding:effective_padding + num_frames]
    frame_wise_audio = central_audio.reshape(1, num_frames, 1, audio_data_per_frame)

    return frame_wise_audio, audio_data, video_frames
301
+
302
+
303
def read_av_random_clip_alt(
    video_path: str,
    num_frames: int = 16,
    audio_frames_per_video_frame: int = 1,
    audio_frame_padding: int = 0,
    target_sr: int = 16000,
    target_fps: float = 25.0,
    random_seed: Optional[int] = None,
):
    """
    Read a random clip of video frames plus matching padded audio.

    Video is decoded with ``PyVideoReader``. Audio is read through moviepy.

    Notation: N = num_frames;
    P = max(audio_frame_padding, audio_frames_per_video_frame // 2);
    K = round(target_sr / target_fps), the number of audio samples per frame.

    A start frame is drawn at random so that the N video frames plus P padding
    frames on each side lie inside the video. Only frames [start, start + N)
    are decoded. The audio for the P + N + P frame window is mixed down to mono
    and split into one row of K samples per frame.

    Args:
        video_path: Path to the video file.
        num_frames: Number of video frames to read (N).
        audio_frames_per_video_frame: Number of audio frames per video frame.
            Only 1 is currently supported.
        audio_frame_padding: Minimum audio padding, in frames, on each side.
        target_sr: Target sample rate for the audio.
        target_fps: Target frames per second for the video.
        random_seed: If given, the start frame is drawn from a private
            ``np.random.RandomState(random_seed)``. This yields the same index
            as seeding the global RNG, but leaves the global NumPy RNG untouched.

    Returns:
        Tuple of (frame_wise_audio, full_padded_audio, video_frames):
          * frame_wise_audio: shape (1, N, 1, K).
          * full_padded_audio: shape (P + N + P, K).
          * video_frames: what ``PyVideoReader.decode`` returns for the clip.

    Raises:
        NotImplementedError: If ``audio_frames_per_video_frame > 1``.
        ValueError: If the video is too short, the audio window is invalid,
            the file has no audio track, or too little audio was read.
    """
    from moviepy import VideoFileClip
    from video_reader import PyVideoReader

    # Fail fast, before any decoding work.
    if audio_frames_per_video_frame > 1:
        raise NotImplementedError("Frame-wise audio extraction is not implemented yet.")

    rng = np.random if random_seed is None else np.random.RandomState(random_seed)

    vr = PyVideoReader(video_path)
    total_frames = int(vr.get_info()['frame_count'])

    # Calculate effective padding needed based on audio segmentation
    effective_padding = max(audio_frame_padding, audio_frames_per_video_frame // 2)
    num_audio_frames = num_frames + 2 * effective_padding

    if total_frames < num_audio_frames:
        raise ValueError(f"Video has only {total_frames} frames, but {num_audio_frames} were requested (including effective padding)")

    min_start_idx = effective_padding
    max_start_idx = total_frames - num_frames - effective_padding
    # NOTE(review): randint's upper bound is exclusive, so max_start_idx
    # itself is never drawn. Kept as-is because it also leaves a margin when
    # the metadata frame_count overestimates the decodable frames.
    start_idx = rng.randint(min_start_idx, max_start_idx) if max_start_idx > min_start_idx else min_start_idx
    end_idx = start_idx + num_frames

    video_frames = vr.decode(start_idx, end_idx)
    del vr

    audio_start_time = (start_idx - effective_padding) / target_fps
    audio_end_time = (end_idx + effective_padding) / target_fps
    audio_duration = audio_end_time - audio_start_time

    # Explicit checks instead of ``assert``, which ``python -O`` strips.
    if audio_duration <= 0:
        raise ValueError(f"Audio duration {audio_duration} is not positive")
    if audio_start_time < 0:
        raise ValueError(f"Audio start time {audio_start_time} is negative")

    source = VideoFileClip(video_path)
    audio_clip = None
    try:
        if source.audio is None:
            raise ValueError(f"No audio track found in {video_path}")
        audio_clip = source.audio.with_fps(target_sr).subclipped(audio_start_time, audio_end_time)
        audio_data = audio_clip.to_soundarray()
    finally:
        # Close the audio sub-clip AND the whole VideoFileClip (its video
        # reader included) on every path, so no ffmpeg reader is left open.
        if audio_clip is not None:
            audio_clip.close()
        source.close()

    num_audio_samples_required = int(round(audio_duration * target_sr))
    if len(audio_data) < num_audio_samples_required:
        raise ValueError(f"Audio data length {len(audio_data)} is less than required {num_audio_samples_required}")
    audio_data = audio_data[:num_audio_samples_required]
    # Convert to mono if stereo
    if audio_data.ndim > 1 and audio_data.shape[1] > 1:
        audio_data = np.mean(audio_data, axis=1)

    audio_data_per_frame = int(round(target_sr / target_fps))
    needed_total = num_audio_frames * audio_data_per_frame
    # When target_sr / target_fps is not an integer, the window length can
    # differ from needed_total by a few rounding samples. Zero-pad or trim so
    # the reshape below always succeeds. In the integral case this is a no-op.
    flat = np.asarray(audio_data).reshape(-1)
    if len(flat) < needed_total:
        flat = np.pad(flat, (0, needed_total - len(flat)), "constant")
    audio_data = flat[:needed_total].reshape(num_audio_frames, audio_data_per_frame)

    # Drop the padding rows to keep only the audio aligned with the N frames.
    central_audio = audio_data[effective_padding:effective_padding + num_frames]
    frame_wise_audio = central_audio.reshape(1, num_frames, 1, audio_data_per_frame)

    return frame_wise_audio, audio_data, video_frames
404
+
405
def read_av_random_clip_pyav(
    video_path: str,
    num_frames: int = 16,
    audio_frames_per_video_frame: int = 1,
    audio_frame_padding: int = 0,
    target_sr: int = 16000,
    target_fps: float = 25.0,
    random_seed: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Decode a random video clip and its corresponding audio from ``video_path``.

    Video frames are decoded with ``PyVideoReader``. Audio is handled as follows:
      1. Decoded with PyAV and resampled by ``av.AudioResampler`` to mono,
         signed 16-bit, at ``target_sr``.
      2. Sliced to the clip's time window, padded by
         ``eff_pad = max(audio_frame_padding, audio_frames_per_video_frame // 2)``
         video frames on each side.
      3. Zero-padded or trimmed to an exact length.
      4. Scaled to float32 in [-1, 1).

    Args:
        video_path: Path to the media file.
        num_frames: Number of video frames in the clip.
        audio_frames_per_video_frame: Only 1 is currently supported.
        audio_frame_padding: Minimum audio padding, in video frames, on each side.
        target_sr: Sample rate the audio is resampled to.
        target_fps: Frame rate used to map frame indices to audio times.
        random_seed: If given, the start frame is drawn from a private
            ``np.random.RandomState(random_seed)``. This yields the same index
            as seeding the global RNG, but leaves the global NumPy RNG untouched.

    Returns:
        (frame_wise_audio, full_padded_audio, video_frames)
          * frame_wise_audio: (1, num_frames, 1, audio_data_per_frame)
          * full_padded_audio: (num_frames + 2*eff_pad, audio_data_per_frame)
          * video_frames: what ``PyVideoReader.decode`` returns for the clip,
            presumably (num_frames, H, W, 3).

    Raises:
        NotImplementedError: If ``audio_frames_per_video_frame > 1``.
        ValueError: If the video is too short, has no audio stream, no audio
            frames could be decoded, or the requested window holds no audio.
    """
    from video_reader import PyVideoReader
    import av

    # Fail fast, before any decoding work.
    if audio_frames_per_video_frame > 1:
        raise NotImplementedError("Multiple audio frames per video frame not supported.")

    rng = np.random if random_seed is None else np.random.RandomState(random_seed)

    # --- 1) Determine which video frames to read ---
    vr = PyVideoReader(video_path)
    total_frames = int(vr.get_info()["frame_count"])
    eff_pad = max(audio_frame_padding, audio_frames_per_video_frame // 2)
    needed_frames = num_frames + 2 * eff_pad
    if total_frames < needed_frames:
        raise ValueError(
            f"Video has only {total_frames} frames but needs {needed_frames} (with padding)."
        )

    min_start = eff_pad
    max_start = total_frames - num_frames - eff_pad
    # NOTE(review): randint's upper bound is exclusive, so max_start itself is
    # never drawn. Kept as-is because it also leaves a margin when the metadata
    # frame_count overestimates the decodable frames.
    start_idx = rng.randint(min_start, max_start) if max_start > min_start else min_start
    end_idx = start_idx + num_frames

    # --- 2) Decode the chosen video frames ---
    video_frames = vr.decode(start_idx, end_idx)
    del vr

    # --- 3) Define audio time window ---
    audio_start_time = max(0.0, (start_idx - eff_pad) / target_fps)
    audio_end_time = (end_idx + eff_pad) / target_fps

    # --- 4) Decode all audio, resample to s16 mono @ target_sr ---
    audio_segments = []
    segment_times = []
    with av.open(video_path) as container:
        audio_stream = next((s for s in container.streams if s.type == "audio"), None)
        if audio_stream is None:
            raise ValueError("No audio stream found in the file.")

        resampler = av.AudioResampler(format="s16", layout="mono", rate=target_sr)
        for packet in container.demux(audio_stream):
            for frame in packet.decode():
                if frame.pts is None:
                    continue
                out = resampler.resample(frame)
                # Depending on the PyAV version, resample() returns a list, a
                # single frame, or None while it buffers input.
                for oframe in (out if isinstance(out, list) else [out]):
                    if oframe is None or oframe.pts is None:
                        continue
                    # Resampled frames carry their own time base, which need not
                    # equal the stream's. Fall back to the stream's time base
                    # only when the frame has none.
                    time_base = oframe.time_base if oframe.time_base is not None else audio_stream.time_base
                    start_t = float(oframe.pts * time_base)
                    end_t = start_t + oframe.samples / oframe.sample_rate
                    audio_segments.append(oframe.to_ndarray().flatten().astype(np.int16))
                    segment_times.append((start_t, end_t))
        del resampler

    if not audio_segments:
        raise ValueError("No audio frames were decoded.")

    full_audio = np.concatenate(audio_segments, axis=0)
    seg_lens = [len(seg) for seg in audio_segments]
    offsets = np.cumsum([0] + seg_lens)

    def time_to_sample(t):
        """Map time ``t`` (seconds) to an index into ``full_audio``."""
        if t <= segment_times[0][0]:
            return 0
        if t >= segment_times[-1][1]:
            return len(full_audio)
        for i, (st, ed) in enumerate(segment_times):
            if t < st:
                # ``t`` falls in a timestamp gap before segment ``i``: snap to
                # that segment's first sample rather than the end of the audio.
                return int(offsets[i])
            if t < ed:
                # Offsets inside a segment count resampled samples, i.e. they
                # are at target_sr rather than at the source stream's rate.
                seg_offset = int(round((t - st) * target_sr))
                return int(offsets[i]) + min(seg_offset, seg_lens[i] - 1)
        return len(full_audio)

    start_sample = time_to_sample(audio_start_time)
    end_sample = time_to_sample(audio_end_time)
    if end_sample <= start_sample:
        raise ValueError("No audio in the requested range.")

    sliced_audio = full_audio[start_sample:end_sample]

    # --- 5) Pad or trim to the window length, then convert to float32 ---
    needed_samples_window = int(round((audio_end_time - audio_start_time) * target_sr))
    if len(sliced_audio) < needed_samples_window:
        sliced_audio = np.pad(sliced_audio, (0, needed_samples_window - len(sliced_audio)), "constant")
    else:
        sliced_audio = sliced_audio[:needed_samples_window]
    # int16 -> float32 in [-1, 1)
    sliced_audio = sliced_audio.astype(np.float32) / 32768.0

    # We ultimately need (num_frames + 2*pad) * audio_data_per_frame samples.
    num_audio_frames = num_frames + 2 * eff_pad
    audio_data_per_frame = int(round(target_sr / target_fps))
    needed_total_samples = num_audio_frames * audio_data_per_frame

    # Final pad/trim to the exact shape
    if len(sliced_audio) < needed_total_samples:
        sliced_audio = np.pad(sliced_audio, (0, needed_total_samples - len(sliced_audio)), "constant")
    else:
        sliced_audio = sliced_audio[:needed_total_samples]

    full_padded_audio = sliced_audio.reshape(num_audio_frames, audio_data_per_frame)

    # --- 6) Extract the clip's central audio and reshape for per-frame use ---
    center = full_padded_audio[eff_pad:eff_pad + num_frames]
    frame_wise_audio = center.reshape(1, num_frames, 1, audio_data_per_frame)

    return frame_wise_audio, full_padded_audio, video_frames
542
+
543
# Registry that maps a method name to its random-clip reader. read_av_random_clip
# looks up its ``method`` argument here.
CLIP_READERS = dict(
    moviepy=read_av_random_clip_moviepy,
    alt=read_av_random_clip_alt,
    pyav=read_av_random_clip_pyav,
)
549
+
550
def read_av_random_clip(
    path: str,
    num_frames: int = 16,
    audio_frames_per_video_frame: int = 1,
    audio_frame_padding: int = 0,
    target_sr: int = 16000,
    target_fps: float = 25.0,
    random_seed: Optional[int] = None,
    method: str = 'alt'
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Read a random audio/video clip with the reader registered under ``method``.

    Args:
        path: Path to the media file.
        num_frames: Number of video frames to read.
        audio_frames_per_video_frame: Number of audio frames per video frame.
        audio_frame_padding: Padding for audio frames.
        target_sr: Target sample rate for audio.
        target_fps: Target frames per second for video.
        random_seed: Seed for the start-frame draw.
        method: Key into ``CLIP_READERS``: 'moviepy', 'alt' or 'pyav'.

    Returns:
        The selected reader's (frame_wise_audio, full_padded_audio, video_frames):
          - frame_wise_audio: Shape (1, num_frames, 1, audio_data_per_frame)
          - full_padded_audio: Shape (num_frames + 2*padding, audio_data_per_frame)
          - video_frames: the decoded clip frames

    Raises:
        ValueError: If ``method`` is not a registered reader.
    """
    clip_reader = CLIP_READERS.get(method)
    if clip_reader is None:
        raise ValueError(f"Unknown method: {method}. Available methods: {list(CLIP_READERS.keys())}")

    reader_kwargs = dict(
        num_frames=num_frames,
        audio_frames_per_video_frame=audio_frames_per_video_frame,
        audio_frame_padding=audio_frame_padding,
        target_sr=target_sr,
        target_fps=target_fps,
        random_seed=random_seed,
    )
    return clip_reader(path, **reader_kwargs)