flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/simple_unet.py +5 -5
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +33 -27
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +48 -31
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
|
|
1
|
+
"""
|
2
|
+
Audio utilities for efficiently loading audio data from video files.
|
3
|
+
This module provides alternatives to decord's AudioReader/AVReader (which have memory leaks).
|
4
|
+
"""
|
5
|
+
|
6
|
+
import os
|
7
|
+
import tempfile
|
8
|
+
import subprocess
|
9
|
+
import numpy as np
|
10
|
+
from typing import Tuple, Optional, Union
|
11
|
+
|
12
|
+
|
13
|
+
def read_audio_ffmpeg(
    video_path: str,
    start_time: Optional[float] = None,
    duration: Optional[float] = None,
    target_sr: int = 16000
) -> Tuple[np.ndarray, int]:
    """
    Extract audio from video file using ffmpeg subprocess calls.

    ffmpeg writes a mono, 16-bit PCM WAV file at ``target_sr`` into a
    temporary file, which is then decoded and deleted.

    Args:
        video_path: Path to the video file.
        start_time: Start time in seconds (optional).
        duration: Duration to extract in seconds (optional).
        target_sr: Target sample rate for the audio.

    Returns:
        Tuple of (audio_data, sample_rate) where audio_data is a 1-D float32
        numpy array scaled to [-1, 1) and sample_rate is ``target_sr``.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        FileNotFoundError: If the ffmpeg executable cannot be found.
    """
    # Imported locally so this function does not depend on the module's import block.
    import wave

    # Reserve a temporary path for ffmpeg to write to. The handle is closed
    # immediately; ffmpeg overwrites the (empty) file because of '-y'.
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
        tmp_path = tmp_file.name

    try:
        # Build the ffmpeg command
        cmd = ['ffmpeg', '-y', '-i', video_path]

        # Add time parameters if specified
        if start_time is not None:
            cmd.extend(['-ss', str(start_time)])

        if duration is not None:
            cmd.extend(['-t', str(duration)])

        # Set output parameters (mono, target sample rate, 16-bit PCM)
        cmd.extend([
            '-ac', '1',  # mono
            '-ar', str(target_sr),  # sample rate
            '-vn',  # no video
            '-acodec', 'pcm_s16le',  # make the 16-bit sample format explicit
            '-f', 'wav',  # wav format
            tmp_path
        ])

        # Execute the command; output is captured so it is attached to any
        # CalledProcessError instead of being printed.
        subprocess.run(cmd, check=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE)

        # Parse the WAV container rather than reading the raw file bytes:
        # the file starts with a RIFF header (and possibly metadata chunks)
        # that must not be interpreted as audio samples.
        with wave.open(tmp_path, 'rb') as wav_file:
            pcm_bytes = wav_file.readframes(wav_file.getnframes())

        # Little-endian int16 -> float32 in [-1, 1)
        audio_data = np.frombuffer(pcm_bytes, dtype='<i2').astype(np.float32) / 32768.0

        return audio_data, target_sr

    finally:
        # Always clean up the temporary file; a failed delete is not fatal.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
|
69
|
+
|
70
|
+
|
71
|
+
def read_audio_moviepy(
    video_path: str,
    start_time: Optional[float] = None,
    duration: Optional[float] = None,
    target_sr: int = 16000
) -> Tuple[np.ndarray, int]:
    """
    Extract audio from video file using moviepy.
    Requires the moviepy package: pip install moviepy

    Args:
        video_path: Path to the video file.
        start_time: Start time in seconds (optional).
        duration: Duration to extract in seconds (optional).
        target_sr: Target sample rate for the audio.

    Returns:
        Tuple of (audio_data, sample_rate) where audio_data is a numpy array.
        Multi-channel audio is averaged down to a single channel.

    Raises:
        ImportError: If moviepy is not installed.
        ValueError: If the clip exposes no audio track.
    """
    try:
        from moviepy import VideoFileClip
    except ImportError as err:
        raise ImportError("moviepy is not installed. Install it with 'pip install moviepy'") from err

    # Load video file
    clip = VideoFileClip(video_path)
    try:
        if start_time is not None or duration is not None:
            start_t = start_time if start_time is not None else 0
            # end_t of None means "until the end of the clip"
            end_t = start_t + duration if duration is not None else None
            clip = clip.subclipped(start_t, end_t)

        # Fail with a clear message instead of an AttributeError on None
        if clip.audio is None:
            raise ValueError(f"No audio track found in {video_path}")

        # Extract audio resampled to the requested rate
        audio = clip.audio.with_fps(target_sr)

        # Get audio data
        audio_data = audio.to_soundarray()

        # Convert to mono if stereo
        if audio_data.ndim > 1 and audio_data.shape[1] > 1:
            audio_data = np.mean(audio_data, axis=1)

        return audio_data, target_sr
    finally:
        # Clean up even when loading or decoding fails
        clip.close()
|
116
|
+
|
117
|
+
|
118
|
+
# Helper function that dispatches to the extraction backend named by `method`
|
119
|
+
def read_audio(
    video_path: str,
    start_time: Optional[float] = None,
    duration: Optional[float] = None,
    target_sr: int = 16000,
    method: str = 'ffmpeg'
) -> Tuple[np.ndarray, int]:
    """
    Extract audio from a video file with the requested backend.

    Args:
        video_path: Path to the video file.
        start_time: Start time in seconds (optional).
        duration: Duration to extract in seconds (optional).
        target_sr: Target sample rate for the audio.
        method: Backend to use. 'moviepy' selects read_audio_moviepy; any
            other value selects read_audio_ffmpeg.

    Returns:
        Tuple of (audio_data, sample_rate) where audio_data is a numpy array.
    """
    # Anything other than 'moviepy' falls back to the ffmpeg backend.
    backend = read_audio_moviepy if method == 'moviepy' else read_audio_ffmpeg
    return backend(video_path, start_time, duration, target_sr)
|
@@ -0,0 +1,125 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Example script demonstrating how to use the memory-leak-free audio-video reading functions.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import os
|
7
|
+
import time
|
8
|
+
import numpy as np
|
9
|
+
import matplotlib.pyplot as plt
|
10
|
+
from av_utils import read_av_improved, read_av_batch
|
11
|
+
from audio_utils import read_audio
|
12
|
+
import argparse
|
13
|
+
|
14
|
+
|
15
|
+
def visualize_av_data(audio_data, video_frames, output_path=None, sample_rate=16000):
    """
    Visualize audio and video data.

    The top row shows the audio waveform (first 10000 samples) and a
    spectrogram; the bottom row shows up to four evenly spaced video frames.

    Args:
        audio_data: Audio data as numpy array or list.
        video_frames: Video frames as numpy array.
        output_path: Path to save visualization (optional).
        sample_rate: Sampling rate passed to the spectrogram (default 16000).
    """
    plt.figure(figsize=(12, 6))

    # Number of frames to show
    total_frames = len(video_frames)
    num_frames = min(4, total_frames)

    # The top row always needs two slots (waveform + spectrogram), so the
    # grid must be at least two columns wide even with fewer than two frames.
    num_cols = max(2, num_frames)

    # Plot audio waveform
    plt.subplot(2, num_cols, 1)
    plt.plot(audio_data[:10000])
    plt.title('Audio Waveform')
    plt.grid(True)

    # Plot audio spectrogram
    plt.subplot(2, num_cols, 2)
    plt.specgram(audio_data, NFFT=1024, Fs=sample_rate)
    plt.title('Audio Spectrogram')

    # Plot sample frames on the second row, evenly spaced through the video
    for i in range(num_frames):
        frame_idx = i * total_frames // num_frames
        plt.subplot(2, num_cols, num_cols + i + 1)
        plt.imshow(video_frames[frame_idx])
        plt.title(f'Frame {frame_idx}')
        plt.axis('off')

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path)
        print(f"Visualization saved to {output_path}")

    plt.show()
|
54
|
+
|
55
|
+
|
56
|
+
def benchmark_av_reading(video_path, num_iterations=10, use_batch=False):
    """
    Benchmark audio-video reading performance.

    Performs one untimed warm-up read, then times ``num_iterations`` reads of
    the same file and prints and returns the average.

    Args:
        video_path: Path to the video file.
        num_iterations: Number of iterations for benchmarking (must be > 0).
        use_batch: Whether to use batch reading.

    Returns:
        Average time per read in seconds.

    Raises:
        ValueError: If num_iterations is not positive.
    """
    # Reject invalid counts before doing any work; the average below would
    # otherwise divide by zero.
    if num_iterations <= 0:
        raise ValueError("num_iterations must be a positive integer")

    print(f"Benchmarking {'batch' if use_batch else 'single'} reading...")

    # Perform warmup
    if use_batch:
        read_av_batch([video_path])
    else:
        read_av_improved(video_path)

    # Measure performance with a monotonic, high-resolution clock
    start_time = time.perf_counter()

    for _ in range(num_iterations):
        if use_batch:
            read_av_batch([video_path])
        else:
            read_av_improved(video_path)

    end_time = time.perf_counter()
    avg_time = (end_time - start_time) / num_iterations

    print(f"Average time per read: {avg_time:.4f} seconds")

    return avg_time
|
88
|
+
|
89
|
+
|
90
|
+
def main():
    """
    Command-line entry point for the demo.

    Parses the arguments, reads the audio and video of the given file,
    prints their sizes, shows (and optionally saves) a visualization, and
    runs the single/batch benchmarks when --benchmark is given.
    """
    cli = argparse.ArgumentParser(description="Demo for memory-leak-free audio-video reading")
    cli.add_argument("--video", "-v", required=True, help="Path to the video file")
    cli.add_argument("--output", "-o", help="Path to save visualization")
    cli.add_argument("--benchmark", "-b", action="store_true", help="Run benchmarks")
    cli.add_argument("--iterations", "-i", type=int, default=10, help="Number of benchmark iterations")
    opts = cli.parse_args()

    video_file = opts.video

    # Bail out early when the input file does not exist
    if not os.path.exists(video_file):
        print(f"Error: Video file not found: {video_file}")
        return

    # Read both streams in one call
    print(f"Reading audio-video data from {video_file}...")
    audio, video = read_av_improved(video_file)

    print(f"Video shape: {video.shape}")
    print(f"Audio length: {len(audio)}")

    # Show the waveform, spectrogram and sample frames
    visualize_av_data(audio, video, opts.output)

    # Benchmarks are opt-in
    if not opts.benchmark:
        return

    print("\nRunning benchmarks...")
    iterations = opts.iterations
    single_time = benchmark_av_reading(video_file, iterations, use_batch=False)
    batch_time = benchmark_av_reading(video_file, iterations, use_batch=True)

    print("\nBenchmark results:")
    print(f"Single reading: {single_time:.4f} seconds per video")
    print(f"Batch reading: {batch_time:.4f} seconds per video")


# Run the demo only when executed as a script, not when imported
if __name__ == "__main__":
    main()
|