flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
1
+ """
2
+ Audio utilities for efficiently loading audio data from video files.
3
+ This module provides alternatives to decord's AudioReader/AVReader (which have memory leaks).
4
+ """
5
+
6
+ import os
7
+ import tempfile
8
+ import subprocess
9
+ import numpy as np
10
+ from typing import Tuple, Optional, Union
11
+
12
+
13
def read_audio_ffmpeg(
    video_path: str,
    start_time: Optional[float] = None,
    duration: Optional[float] = None,
    target_sr: int = 16000
) -> Tuple[np.ndarray, int]:
    """
    Extract audio from video file using ffmpeg subprocess calls.

    ffmpeg writes a mono, 16-bit PCM WAV file to a temporary path. That file
    is then parsed with the standard-library ``wave`` module so that only the
    sample data is returned (the RIFF header and any metadata chunks are not
    interpreted as audio samples).

    Args:
        video_path: Path to the video file.
        start_time: Start time in seconds (optional).
        duration: Duration to extract in seconds (optional).
        target_sr: Target sample rate for the audio.

    Returns:
        Tuple of (audio_data, sample_rate) where audio_data is a 1-D float32
        numpy array scaled into [-1, 1) and sample_rate is ``target_sr``.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        ValueError: If the WAV file produced by ffmpeg is not 16-bit PCM.
    """
    # Local import: only this function needs the WAV parser.
    import wave

    # Reserve a temporary path for ffmpeg to write to. The handle is closed
    # right away (delete=False) so ffmpeg can overwrite the file via '-y'.
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
        tmp_path = tmp_file.name

    try:
        # Build the ffmpeg command
        cmd = ['ffmpeg', '-y', '-i', video_path]

        # Add time parameters if specified
        if start_time is not None:
            cmd.extend(['-ss', str(start_time)])

        if duration is not None:
            cmd.extend(['-t', str(duration)])

        # Set output parameters (mono, target sample rate, 16-bit PCM)
        cmd.extend([
            '-ac', '1',               # mono
            '-ar', str(target_sr),    # sample rate
            '-vn',                    # no video
            '-acodec', 'pcm_s16le',   # make the int16 sample format explicit
            '-f', 'wav',              # wav format
            tmp_path
        ])

        # Execute the command; output is captured so it does not spam the console
        subprocess.run(cmd, check=True, stderr=subprocess.PIPE, stdout=subprocess.PIPE)

        # Parse the WAV container instead of reading the file as raw int16:
        # a raw read would turn the header bytes into bogus leading samples.
        with wave.open(tmp_path, 'rb') as wav_file:
            sample_width = wav_file.getsampwidth()
            if sample_width != 2:
                raise ValueError(
                    f"Expected 16-bit PCM audio from ffmpeg, got {8 * sample_width}-bit samples"
                )
            raw_frames = wav_file.readframes(wav_file.getnframes())

        # Convert little-endian int16 samples to float in [-1, 1)
        audio_data = np.frombuffer(raw_frames, dtype='<i2').astype(np.float32) / 32768.0

        return audio_data, target_sr

    finally:
        # Always clean up the temporary file; removal is best-effort
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
69
+
70
+
71
def read_audio_moviepy(
    video_path: str,
    start_time: Optional[float] = None,
    duration: Optional[float] = None,
    target_sr: int = 16000
) -> Tuple[np.ndarray, int]:
    """
    Extract audio from video file using moviepy.
    Requires the moviepy package: pip install moviepy

    Args:
        video_path: Path to the video file.
        start_time: Start time in seconds (optional). Treated as 0 when only
            ``duration`` is given.
        duration: Duration to extract in seconds (optional). When omitted the
            clip runs to its end.
        target_sr: Target sample rate for the audio.

    Returns:
        Tuple of (audio_data, sample_rate) where audio_data is a numpy array
        (averaged down to one channel when several are present).

    Raises:
        ImportError: If moviepy is not installed.
        ValueError: If the clip has no audio track.
    """
    try:
        from moviepy import VideoFileClip
    except ImportError as err:
        raise ImportError("moviepy is not installed. Install it with 'pip install moviepy'") from err

    # Load video file
    source_clip = VideoFileClip(video_path)
    video = source_clip
    try:
        # Restrict to the requested window, if any
        if start_time is not None or duration is not None:
            start_t = start_time if start_time is not None else 0
            end_t = start_t + duration if duration is not None else None
            video = source_clip.subclipped(start_t, end_t)

        # Fail with a clear message instead of an AttributeError on None
        if video.audio is None:
            raise ValueError(f"No audio track found in {video_path}")

        # Extract audio at the requested sample rate
        audio = video.audio.with_fps(target_sr)

        # Get audio data
        audio_data = audio.to_soundarray()

        # Convert to mono if stereo
        if audio_data.ndim > 1 and audio_data.shape[1] > 1:
            audio_data = np.mean(audio_data, axis=1)

        return audio_data, target_sr
    finally:
        # Clean up even when extraction fails: close the clip that was read
        # from, plus the clip it was cut from when they are distinct objects.
        video.close()
        if video is not source_clip:
            source_clip.close()
116
+
117
+
118
+ # Helper function to choose the best available method
119
def read_audio(
    video_path: str,
    start_time: Optional[float] = None,
    duration: Optional[float] = None,
    target_sr: int = 16000,
    method: str = 'ffmpeg'
) -> Tuple[np.ndarray, int]:
    """
    Dispatch audio extraction to one of this module's reader backends.

    Args:
        video_path: Path to the video file.
        start_time: Start time in seconds (optional).
        duration: Duration to extract in seconds (optional).
        target_sr: Target sample rate for the audio.
        method: Backend selector. 'moviepy' picks the moviepy reader; every
            other value is routed to the ffmpeg reader.

    Returns:
        Tuple of (audio_data, sample_rate) exactly as produced by the chosen
        backend.
    """
    # Only the literal 'moviepy' switches backend; ffmpeg is the catch-all.
    backend = read_audio_moviepy if method == 'moviepy' else read_audio_ffmpeg
    return backend(video_path, start_time, duration, target_sr)
@@ -0,0 +1,125 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Example script demonstrating how to use the memory-leak-free audio-video reading functions.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from av_utils import read_av_improved, read_av_batch
11
+ from audio_utils import read_audio
12
+ import argparse
13
+
14
+
15
def visualize_av_data(audio_data, video_frames, output_path=None, sample_rate=16000):
    """
    Visualize audio and video data.

    Draws a two-row figure: the top row holds the audio waveform (first
    10000 samples) and a spectrogram, the bottom row holds up to four frames
    sampled evenly across ``video_frames``.

    Args:
        audio_data: Audio data as numpy array or list.
        video_frames: Video frames as numpy array.
        output_path: Path to save visualization (optional).
        sample_rate: Sampling rate passed to the spectrogram as ``Fs``.
            Defaults to 16000, the value that used to be hard-coded.
    """
    plt.figure(figsize=(12, 6))

    # Number of frames to show
    num_frames = min(4, len(video_frames))

    # The top row always needs two cells (waveform + spectrogram). Using at
    # least two columns keeps the spectrogram and the frame from landing on
    # the same axes when there is a single frame, and avoids an invalid
    # zero-column grid when there are no frames at all.
    num_cols = max(num_frames, 2)

    # Plot audio waveform
    plt.subplot(2, num_cols, 1)
    plt.plot(audio_data[:10000])
    plt.title('Audio Waveform')
    plt.grid(True)

    # Plot audio spectrogram
    plt.subplot(2, num_cols, 2)
    plt.specgram(audio_data, NFFT=1024, Fs=sample_rate)
    plt.title('Audio Spectrogram')

    # Plot sample frames on the second row
    for i in range(num_frames):
        frame_idx = i * len(video_frames) // num_frames
        plt.subplot(2, num_cols, num_cols + i + 1)
        plt.imshow(video_frames[frame_idx])
        plt.title(f'Frame {frame_idx}')
        plt.axis('off')

    plt.tight_layout()

    if output_path:
        plt.savefig(output_path)
        print(f"Visualization saved to {output_path}")

    plt.show()
54
+
55
+
56
def benchmark_av_reading(video_path, num_iterations=10, use_batch=False):
    """
    Benchmark audio-video reading performance.

    Performs one untimed warmup read, then times ``num_iterations`` reads and
    reports the mean duration.

    Args:
        video_path: Path to the video file.
        num_iterations: Number of iterations for benchmarking (must be >= 1).
        use_batch: Whether to use batch reading.

    Returns:
        Average time per read in seconds.

    Raises:
        ValueError: If ``num_iterations`` is less than 1.
    """
    # Validate first so a bad count fails fast instead of after the warmup
    # read (and instead of a ZeroDivisionError when averaging).
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be at least 1, got {num_iterations}")

    print(f"Benchmarking {'batch' if use_batch else 'single'} reading...")

    # Pick the reader once so the timed loop contains only the read itself
    if use_batch:
        def read_once():
            read_av_batch([video_path])
    else:
        def read_once():
            read_av_improved(video_path)

    # Perform warmup
    read_once()

    # Measure performance; perf_counter is the monotonic, high-resolution
    # clock intended for timing short intervals
    start_time = time.perf_counter()

    for _ in range(num_iterations):
        read_once()

    end_time = time.perf_counter()
    avg_time = (end_time - start_time) / num_iterations

    print(f"Average time per read: {avg_time:.4f} seconds")

    return avg_time
88
+
89
+
90
def main():
    """
    Command-line entry point for the demo.

    Parses the CLI options, reads the audio/video data for the given file,
    prints its basic dimensions, shows (and optionally saves) a visualization,
    and, when requested, benchmarks single versus batch reading.
    """
    cli = argparse.ArgumentParser(description="Demo for memory-leak-free audio-video reading")
    cli.add_argument("--video", "-v", required=True, help="Path to the video file")
    cli.add_argument("--output", "-o", help="Path to save visualization")
    cli.add_argument("--benchmark", "-b", action="store_true", help="Run benchmarks")
    cli.add_argument("--iterations", "-i", type=int, default=10, help="Number of benchmark iterations")
    opts = cli.parse_args()
    video_path = opts.video

    # Bail out early when the input file is missing
    if not os.path.exists(video_path):
        print(f"Error: Video file not found: {video_path}")
        return

    # Read the data and report its size
    print(f"Reading audio-video data from {video_path}...")
    audio_track, frames = read_av_improved(video_path)
    print(f"Video shape: {frames.shape}")
    print(f"Audio length: {len(audio_track)}")

    # Show (and optionally save) the visualization
    visualize_av_data(audio_track, frames, opts.output)

    # Benchmarks are opt-in
    if not opts.benchmark:
        return

    print("\nRunning benchmarks...")
    # Single-read benchmark runs first, then the batch one
    single_time, batch_time = (
        benchmark_av_reading(video_path, opts.iterations, use_batch=batch_mode)
        for batch_mode in (False, True)
    )

    print("\nBenchmark results:")
    print(f"Single reading: {single_time:.4f} seconds per video")
    print(f"Batch reading: {batch_time:.4f} seconds per video")


if __name__ == "__main__":
    main()