flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/simple_unet.py +5 -5
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +33 -27
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +48 -31
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/data/__init__.py
CHANGED
@@ -0,0 +1,443 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Benchmark script to test for memory leaks and performance in decord library.
|
4
|
+
|
5
|
+
This script specifically targets the read_av function and provides comprehensive
|
6
|
+
memory usage tracking and performance metrics.
|
7
|
+
"""
|
8
|
+
|
9
|
+
import os
|
10
|
+
import sys
|
11
|
+
import time
|
12
|
+
import random
|
13
|
+
import gc
|
14
|
+
import argparse
|
15
|
+
import numpy as np
|
16
|
+
import matplotlib.pyplot as plt
|
17
|
+
import psutil
|
18
|
+
from tqdm import tqdm
|
19
|
+
|
20
|
+
try:
|
21
|
+
from decord import AVReader, VideoReader, cpu, gpu
|
22
|
+
HAS_DECORD = True
|
23
|
+
except ImportError:
|
24
|
+
print("Warning: decord library not found. Only OpenCV mode will be available.")
|
25
|
+
HAS_DECORD = False
|
26
|
+
|
27
|
+
import cv2
|
28
|
+
|
29
|
+
|
30
|
+
def gather_video_paths(directory, extensions=None):
    """Gather all video file paths in a directory (recursively).

    Args:
        directory: Directory to search for video files.
        extensions: Optional iterable of file suffixes (e.g. ``'.mp4'``) to
            match, compared case-insensitively. Defaults to the common
            video suffixes ``.mp4``, ``.avi``, ``.mov``, ``.mkv``, ``.webm``.

    Returns:
        Sorted list of matching file paths. The list is empty when nothing
        matches or when ``directory`` cannot be walked.
    """
    if extensions is None:
        extensions = ('.mp4', '.avi', '.mov', '.mkv', '.webm')
    # str.endswith accepts a tuple, so one call covers every suffix.
    suffixes = tuple(ext.lower() for ext in extensions)

    video_paths = [
        os.path.join(root, name)
        for root, _, files in os.walk(directory)
        for name in files
        if name.lower().endswith(suffixes)
    ]

    # os.walk yields entries in filesystem order; sorting makes the result
    # (and any seeded random sampling done on it) reproducible across runs.
    video_paths.sort()
    return video_paths
|
48
|
+
|
49
|
+
|
50
|
+
def read_av_standard(path, start=0, end=None, ctx=None):
    """Load a slice of a video file through decord's ``AVReader``.

    Args:
        path: Path to the video file.
        start: Start frame index of the slice.
        end: End frame index of the slice (``None`` leaves it open-ended).
        ctx: Decord context; ``cpu(0)`` is used when none is given.

    Returns:
        Tuple ``(audio, video)``: the audio part exactly as returned by the
        reader, and the video part converted with ``asnumpy()``.

    Raises:
        ImportError: If decord could not be imported.
    """
    if not HAS_DECORD:
        raise ImportError("decord library not installed")

    if not ctx:
        ctx = cpu(0)

    reader = AVReader(path, ctx=ctx)
    audio, frames = reader[start:end]
    return audio, frames.asnumpy()
|
69
|
+
|
70
|
+
|
71
|
+
def read_av_cleanup(path, start=0, end=None, ctx=None):
    """Load a slice of a video file and drop the decord reader explicitly.

    Same loading path as the standard variant, but the reader reference is
    deleted before returning so the effect of explicit cleanup can be
    compared.

    Args:
        path: Path to the video file.
        start: Start frame index of the slice.
        end: End frame index of the slice (``None`` leaves it open-ended).
        ctx: Decord context; ``cpu(0)`` is used when none is given.

    Returns:
        Tuple ``(audio_list, video_np)``: a new list of the audio elements
        and the video part converted with ``asnumpy()``.

    Raises:
        ImportError: If decord could not be imported.
    """
    if not HAS_DECORD:
        raise ImportError("decord library not installed")

    if not ctx:
        ctx = cpu(0)

    reader = AVReader(path, ctx=ctx)
    audio, frames = reader[start:end]

    # list() builds a new list holding the same audio elements; it is a
    # shallow copy, the elements themselves are not duplicated.
    audio_list = list(audio)
    video_np = frames.asnumpy()

    # Drop our reference to the reader before handing the data back.
    del reader
    return audio_list, video_np
|
93
|
+
|
94
|
+
|
95
|
+
def read_video_opencv(path, max_frames=None):
    """Read video using OpenCV instead of decord.

    Args:
        path: Path to the video file.
        max_frames: Maximum number of frames to read. ``None`` (or ``0``)
            means no limit.

    Returns:
        Video frames stacked along a new leading axis
        ([num_frames, height, width, channels], converted from BGR to RGB),
        or an empty array if no frames were read.
    """
    cap = cv2.VideoCapture(path)
    frames = []

    # Release the capture even if decoding or colour conversion raises, so
    # repeated failures do not leak capture handles.
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Convert BGR to RGB
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            if max_frames and len(frames) >= max_frames:
                break
    finally:
        cap.release()

    if frames:
        return np.stack(frames, axis=0)
    return np.array([])  # Empty array if no frames were read
|
127
|
+
|
128
|
+
|
129
|
+
def get_memory_usage():
    """Report the resident set size (RSS) of the current process.

    Returns:
        RSS of this process, in MB (bytes divided by 1024 * 1024).
    """
    bytes_per_mb = 1024 * 1024
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_mb
|
138
|
+
|
139
|
+
|
140
|
+
def test_for_memory_leak(video_paths, method='standard', num_iterations=100, sample_size=20):
    """Test for memory leaks by repeatedly loading videos.

    Args:
        video_paths: List of video file paths.
        method: Method to use for loading videos ('standard', 'cleanup', or 'opencv').
        num_iterations: Number of iterations to run.
        sample_size: Number of video paths to sample from.

    Returns:
        List of memory usage measurements in MB: the baseline, one entry per
        successfully loaded video, and a final entry after garbage
        collection. Empty if ``video_paths`` is empty.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
        ImportError: If a decord-based method is requested without decord.
    """
    # Validate the configuration once. Raising inside the loop would be
    # swallowed by the per-video error handler and repeated every iteration.
    loaders = {
        'standard': read_av_standard,
        'cleanup': read_av_cleanup,
        'opencv': read_video_opencv,
    }
    if method not in loaders:
        raise ValueError(f"Unknown method: {method}")
    if method != 'opencv' and not HAS_DECORD:
        raise ImportError("decord library not installed")
    load = loaders[method]

    sample_paths = random.sample(video_paths, min(sample_size, len(video_paths)))
    if not sample_paths:
        # random.choice would raise IndexError on an empty sample.
        print("No video paths provided; skipping memory leak test.")
        return []

    memory_usage = []

    # Record baseline memory usage
    gc.collect()
    baseline_memory = get_memory_usage()
    memory_usage.append(baseline_memory)

    print(f"Initial memory usage: {baseline_memory:.2f} MB")

    # Load videos repeatedly and track memory usage
    for i in tqdm(range(num_iterations), desc=f"Testing {method} method"):
        path = random.choice(sample_paths)

        try:
            # Load and immediately drop the data so only leaked memory
            # accumulates across iterations.
            loaded = load(path)
            del loaded

            # Periodic garbage collection
            if i % 5 == 0:
                gc.collect()

            # Record memory
            memory_usage.append(get_memory_usage())

        except Exception as e:
            # Best effort: a single unreadable video must not abort the run.
            print(f"Error processing video {path}: {e}")
            continue

    # Final cleanup
    gc.collect()
    final_memory = get_memory_usage()
    memory_usage.append(final_memory)

    print(f"Final memory usage: {final_memory:.2f} MB")
    print(f"Memory change: {final_memory - baseline_memory:.2f} MB")

    return memory_usage
|
200
|
+
|
201
|
+
|
202
|
+
def benchmark_loading_speed(video_paths, method='standard', num_videos=30):
    """Benchmark video loading speed.

    Args:
        video_paths: List of video file paths.
        method: Method to use for loading videos ('standard', 'cleanup', or 'opencv').
        num_videos: Number of videos to benchmark.

    Returns:
        Tuple of (load times in seconds, video sizes in MB), one entry per
        successfully loaded video. Both lists are empty if nothing loaded.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
        ImportError: If a decord-based method is requested without decord.
    """
    # Validate the configuration once. Raising inside the loop would be
    # swallowed by the per-video error handler.
    if method not in ('standard', 'cleanup', 'opencv'):
        raise ValueError(f"Unknown method: {method}")
    if method != 'opencv' and not HAS_DECORD:
        raise ImportError("decord library not installed")

    # Select random videos to load
    selected_paths = random.sample(video_paths, min(num_videos, len(video_paths)))

    load_times = []
    video_sizes = []

    print(f"Benchmarking {method} method...")

    for path in tqdm(selected_paths, desc=f"Benchmarking {method}"):
        try:
            # perf_counter is monotonic and high resolution, unlike time.time.
            start_time = time.perf_counter()

            if method == 'standard':
                audio, video = read_av_standard(path)
            elif method == 'cleanup':
                audio, video = read_av_cleanup(path)
            else:
                video = read_video_opencv(path)
                audio = None

            load_time = time.perf_counter() - start_time

            # An empty array means nothing was decoded; counting it as a
            # near-instant 0 MB load would skew the statistics.
            if video.size == 0:
                raise ValueError("no frames were decoded")

            # Get video size in MB
            video_size = video.nbytes / (1024 * 1024)  # Convert bytes to MB

            # Append both metrics together so the lists stay aligned.
            load_times.append(load_time)
            video_sizes.append(video_size)

            # Cleanup
            del video, audio

            if len(load_times) % 10 == 0:
                gc.collect()

        except Exception as e:
            # Best effort: a single unreadable video must not abort the run.
            print(f"Error benchmarking {path}: {e}")
            continue

    if not load_times:
        print("No videos were successfully processed.")
        return [], []

    # Calculate statistics
    total_time = sum(load_times)
    total_size = sum(video_sizes)
    avg_time = total_time / len(load_times)
    avg_size = total_size / len(video_sizes)
    avg_speed = total_size / total_time if total_time > 0 else 0  # MB/s

    print(f"Average load time: {avg_time:.4f} seconds")
    print(f"Average video size: {avg_size:.2f} MB")
    print(f"Average loading speed: {avg_speed:.2f} MB/s")

    return load_times, video_sizes
|
272
|
+
|
273
|
+
|
274
|
+
def plot_memory_usage(results, output_dir=None):
    """Draw one memory-usage curve per loading method.

    Args:
        results: Mapping from method name to its sequence of memory
            measurements (MB).
        output_dir: If given, the figure is also saved there as
            ``memory_usage.png`` before being shown.
    """
    plt.figure(figsize=(12, 6))

    for label, series in results.items():
        plt.plot(series, label=label)

    plt.title('Memory Usage During Repeated Video Loading')
    plt.xlabel('Iteration')
    plt.ylabel('Memory Usage (MB)')
    plt.legend()
    plt.grid(True)

    if output_dir:
        target = os.path.join(output_dir, 'memory_usage.png')
        plt.savefig(target)

    plt.show()
|
296
|
+
|
297
|
+
|
298
|
+
def plot_loading_speed(results, output_dir=None):
    """Draw a three-panel comparison of loading performance per method.

    Panels: box plot of load times, scatter of load time against video
    size, and box plot of per-video throughput (MB/s).

    Args:
        results: Mapping from method name to a ``(times, sizes)`` pair of
            equally ordered sequences (seconds, MB).
        output_dir: If given, the figure is also saved there as
            ``loading_speed.png`` before being shown.
    """
    methods = list(results)
    times = [results[name][0] for name in methods]
    sizes = [results[name][1] for name in methods]

    plt.figure(figsize=(15, 5))

    # Panel 1: distribution of load times per method.
    # NOTE(review): newer Matplotlib releases rename boxplot's `labels`
    # argument to `tick_labels` - confirm the supported version range
    # before switching.
    plt.subplot(1, 3, 1)
    plt.boxplot(times, labels=methods)
    plt.title('Load Time Comparison')
    plt.ylabel('Time (seconds)')

    # Panel 2: load time against decoded video size.
    plt.subplot(1, 3, 2)
    for name, method_sizes, method_times in zip(methods, sizes, times):
        plt.scatter(method_sizes, method_times, alpha=0.7, label=name)
    plt.title('Load Time vs. Video Size')
    plt.xlabel('Video Size (MB)')
    plt.ylabel('Time (seconds)')
    plt.legend()

    # Panel 3: per-video throughput in MB/s; zero-duration samples are
    # dropped to avoid dividing by zero.
    plt.subplot(1, 3, 3)
    speeds = [
        [size / duration for size, duration in zip(method_sizes, method_times) if duration > 0]
        for method_sizes, method_times in zip(sizes, times)
    ]

    plt.boxplot(speeds, labels=methods)
    plt.title('Loading Speed Comparison')
    plt.ylabel('Speed (MB/s)')

    plt.tight_layout()

    if output_dir:
        target = os.path.join(output_dir, 'loading_speed.png')
        plt.savefig(target)

    plt.show()
|
344
|
+
|
345
|
+
|
346
|
+
def run_full_benchmark(videos_dir, output_dir=None, iterations=100, num_videos=30, sample_size=20):
    """Run a full benchmark suite.

    Runs the memory leak test and the loading speed benchmark for every
    available method, plots both result sets, and optionally writes the
    raw numbers to text files.

    Args:
        videos_dir: Directory containing video files.
        output_dir: Directory to save results to.
        iterations: Number of iterations for memory leak test.
        num_videos: Number of videos for performance benchmark.
        sample_size: Sample size for memory leak test.
    """
    # Create the output directory; exist_ok avoids the check-then-create
    # race of testing os.path.exists first.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Gather video paths
    print(f"Searching for videos in {videos_dir}...")
    video_paths = gather_video_paths(videos_dir)
    print(f"Found {len(video_paths)} videos.")

    if not video_paths:
        print("No videos found. Exiting.")
        return

    # Test all methods if decord is available, otherwise OpenCV only.
    methods = ['standard', 'cleanup', 'opencv'] if HAS_DECORD else ['opencv']

    # Memory leak tests
    print("\n=== Running memory leak tests ===\n")
    memory_results = {}

    for method in methods:
        print(f"\nTesting {method} method for memory leaks...")
        memory_results[method] = test_for_memory_leak(
            video_paths,
            method=method,
            num_iterations=iterations,
            sample_size=sample_size,
        )

    # Plot memory usage results
    plot_memory_usage(memory_results, output_dir)

    # Performance benchmarks
    print("\n=== Running performance benchmarks ===\n")
    performance_results = {}

    for method in methods:
        print(f"\nBenchmarking {method} method...")
        times, sizes = benchmark_loading_speed(
            video_paths,
            method=method,
            num_videos=num_videos,
        )
        performance_results[method] = (times, sizes)

    # Plot performance results
    plot_loading_speed(performance_results, output_dir)

    # Save results to files if output_dir is specified
    if output_dir:
        # Save memory results: one measurement per line.
        for method, usage in memory_results.items():
            with open(os.path.join(output_dir, f'memory_{method}.txt'), 'w') as f:
                f.write('\n'.join(str(x) for x in usage))

        # Save performance results as "time,size" rows under a header.
        for method, (times, sizes) in performance_results.items():
            with open(os.path.join(output_dir, f'performance_{method}.txt'), 'w') as f:
                f.write('time,size\n')
                f.writelines(f'{t},{s}\n' for t, s in zip(times, sizes))

    print("\nBenchmark complete.")
|
421
|
+
|
422
|
+
|
423
|
+
def main():
    """Parse command-line options and launch the full benchmark."""
    parser = argparse.ArgumentParser(
        description='Benchmark decord and OpenCV video loading.',
    )
    parser.add_argument(
        '--videos_dir', '-d',
        required=True,
        help='Directory containing video files',
    )
    parser.add_argument(
        '--output_dir', '-o',
        help='Directory to save results to',
    )
    parser.add_argument(
        '--iterations', '-i',
        type=int,
        default=100,
        help='Number of iterations for memory leak test',
    )
    parser.add_argument(
        '--num_videos', '-n',
        type=int,
        default=30,
        help='Number of videos for performance benchmark',
    )
    parser.add_argument(
        '--sample_size', '-s',
        type=int,
        default=20,
        help='Sample size for memory leak test',
    )
    options = parser.parse_args()

    run_full_benchmark(
        videos_dir=options.videos_dir,
        output_dir=options.output_dir,
        iterations=options.iterations,
        num_videos=options.num_videos,
        sample_size=options.sample_size,
    )
|
440
|
+
|
441
|
+
|
442
|
+
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|