flaxdiff 0.1.38__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/PKG-INFO +1 -1
  2. flaxdiff-0.2.0/flaxdiff/data/__init__.py +5 -0
  3. flaxdiff-0.2.0/flaxdiff/data/benchmark_decord.py +443 -0
  4. flaxdiff-0.2.0/flaxdiff/data/dataloaders.py +608 -0
  5. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/data/dataset_map.py +61 -6
  6. flaxdiff-0.2.0/flaxdiff/data/online_loader.py +992 -0
  7. flaxdiff-0.2.0/flaxdiff/data/sources/audio_utils.py +142 -0
  8. flaxdiff-0.2.0/flaxdiff/data/sources/av_example.py +125 -0
  9. flaxdiff-0.2.0/flaxdiff/data/sources/av_utils.py +590 -0
  10. flaxdiff-0.2.0/flaxdiff/data/sources/base.py +129 -0
  11. flaxdiff-0.2.0/flaxdiff/data/sources/images.py +309 -0
  12. flaxdiff-0.2.0/flaxdiff/data/sources/utils.py +158 -0
  13. flaxdiff-0.2.0/flaxdiff/data/sources/videos.py +250 -0
  14. flaxdiff-0.2.0/flaxdiff/data/sources/voxceleb2.py +412 -0
  15. flaxdiff-0.2.0/flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff-0.2.0/flaxdiff/inference/utils.py +320 -0
  17. flaxdiff-0.2.0/flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff-0.2.0/flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff-0.2.0/flaxdiff/metrics/ssim.py +0 -0
  20. flaxdiff-0.2.0/flaxdiff/models/__init__.py +2 -0
  21. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/models/attention.py +22 -16
  22. flaxdiff-0.2.0/flaxdiff/models/autoencoder/autoencoder.py +151 -0
  23. flaxdiff-0.2.0/flaxdiff/models/autoencoder/diffusers.py +154 -0
  24. flaxdiff-0.2.0/flaxdiff/models/autoencoder/simple_autoenc.py +58 -0
  25. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/models/common.py +8 -18
  26. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/models/simple_unet.py +6 -17
  27. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/models/simple_vit.py +9 -13
  28. flaxdiff-0.2.0/flaxdiff/models/unet_3d.py +446 -0
  29. flaxdiff-0.2.0/flaxdiff/models/unet_3d_blocks.py +505 -0
  30. flaxdiff-0.2.0/flaxdiff/samplers/common.py +433 -0
  31. flaxdiff-0.2.0/flaxdiff/samplers/ddim.py +49 -0
  32. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/schedulers/karras.py +20 -12
  33. flaxdiff-0.2.0/flaxdiff/trainer/__init__.py +3 -0
  34. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/trainer/autoencoder_trainer.py +1 -2
  35. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/trainer/diffusion_trainer.py +35 -29
  36. flaxdiff-0.2.0/flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  37. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/trainer/simple_trainer.py +51 -16
  38. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/utils.py +128 -57
  39. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff.egg-info/PKG-INFO +1 -1
  40. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff.egg-info/SOURCES.txt +19 -5
  41. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/pyproject.toml +1 -1
  42. flaxdiff-0.1.38/flaxdiff/data/__init__.py +0 -1
  43. flaxdiff-0.1.38/flaxdiff/data/datasets.py +0 -169
  44. flaxdiff-0.1.38/flaxdiff/data/online_loader.py +0 -363
  45. flaxdiff-0.1.38/flaxdiff/data/sources/gcs.py +0 -81
  46. flaxdiff-0.1.38/flaxdiff/data/sources/tfds.py +0 -79
  47. flaxdiff-0.1.38/flaxdiff/models/__init__.py +0 -1
  48. flaxdiff-0.1.38/flaxdiff/models/autoencoder/autoencoder.py +0 -19
  49. flaxdiff-0.1.38/flaxdiff/models/autoencoder/diffusers.py +0 -91
  50. flaxdiff-0.1.38/flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  51. flaxdiff-0.1.38/flaxdiff/samplers/common.py +0 -171
  52. flaxdiff-0.1.38/flaxdiff/samplers/ddim.py +0 -10
  53. flaxdiff-0.1.38/flaxdiff/trainer/__init__.py +0 -2
  54. flaxdiff-0.1.38/flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  55. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/README.md +0 -0
  56. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/__init__.py +0 -0
  57. /flaxdiff-0.1.38/flaxdiff/metrics/psnr.py → /flaxdiff-0.2.0/flaxdiff/inference/__init__.py +0 -0
  58. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/metrics/inception.py +0 -0
  59. /flaxdiff-0.1.38/flaxdiff/metrics/ssim.py → /flaxdiff-0.2.0/flaxdiff/metrics/psnr.py +0 -0
  60. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/metrics/utils.py +0 -0
  61. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/models/autoencoder/__init__.py +0 -0
  62. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/models/favor_fastattn.py +0 -0
  63. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/models/general.py +0 -0
  64. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/predictors/__init__.py +0 -0
  65. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/samplers/__init__.py +0 -0
  66. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/samplers/ddpm.py +0 -0
  67. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/samplers/euler.py +0 -0
  68. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/samplers/heun_sampler.py +0 -0
  69. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/samplers/multistep_dpm.py +0 -0
  70. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/samplers/rk4_sampler.py +0 -0
  71. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/schedulers/__init__.py +0 -0
  72. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/schedulers/common.py +0 -0
  73. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/schedulers/continuous.py +0 -0
  74. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/schedulers/cosine.py +0 -0
  75. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/schedulers/discrete.py +0 -0
  76. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/schedulers/exp.py +0 -0
  77. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/schedulers/linear.py +0 -0
  78. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff/schedulers/sqrt.py +0 -0
  79. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff.egg-info/dependency_links.txt +0 -0
  80. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff.egg-info/requires.txt +0 -0
  81. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/flaxdiff.egg-info/top_level.txt +0 -0
  82. {flaxdiff-0.1.38 → flaxdiff-0.2.0}/setup.cfg +0 -0
{flaxdiff-0.1.38 → flaxdiff-0.2.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.1.38
+ Version: 0.2.0
  Summary: A versatile and easy to understand Diffusion library
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
  License-Expression: MIT
flaxdiff-0.2.0/flaxdiff/data/__init__.py
@@ -0,0 +1,5 @@
+ from .online_loader import *
+ from .dataloaders import *
+ from .sources.base import *
+ from .sources.images import *
+ from .sources.videos import *
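Taken together, these wildcard re-exports flatten the new data package: anything public in online_loader, dataloaders, and the sources submodules becomes importable straight from flaxdiff.data. A minimal sketch of what that enables (the loader name below is hypothetical; this diff does not show what dataloaders.py actually defines):

# Hypothetical usage sketch: `SomeDataLoader` stands in for whatever public
# symbol dataloaders.py exports; only the re-export path comes from the diff above.
from flaxdiff.data import SomeDataLoader                 # via the package-level re-export
from flaxdiff.data.dataloaders import SomeDataLoader     # equivalent direct import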
flaxdiff-0.2.0/flaxdiff/data/benchmark_decord.py
@@ -0,0 +1,443 @@
+ #!/usr/bin/env python3
+ """
+ Benchmark script to test for memory leaks and performance in decord library.
+
+ This script specifically targets the read_av function and provides comprehensive
+ memory usage tracking and performance metrics.
+ """
+
+ import os
+ import sys
+ import time
+ import random
+ import gc
+ import argparse
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import psutil
+ from tqdm import tqdm
+
+ try:
+     from decord import AVReader, VideoReader, cpu, gpu
+     HAS_DECORD = True
+ except ImportError:
+     print("Warning: decord library not found. Only OpenCV mode will be available.")
+     HAS_DECORD = False
+
+ import cv2
+
+
+ def gather_video_paths(directory):
+     """Gather all video file paths in a directory (recursively).
+
+     Args:
+         directory: Directory to search for video files.
+
+     Returns:
+         List of video file paths.
+     """
+     video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.webm']
+     video_paths = []
+
+     for root, _, files in os.walk(directory):
+         for file in files:
+             if any(file.lower().endswith(ext) for ext in video_extensions):
+                 video_paths.append(os.path.join(root, file))
+
+     return video_paths
+
+
+ def read_av_standard(path, start=0, end=None, ctx=None):
+     """Read audio-video with standard decord approach.
+
+     Args:
+         path: Path to the video file.
+         start: Start frame index.
+         end: End frame index.
+         ctx: Decord context (CPU or GPU).
+
+     Returns:
+         Tuple of (audio, video) arrays.
+     """
+     if not HAS_DECORD:
+         raise ImportError("decord library not installed")
+
+     ctx = ctx or cpu(0)
+     vr = AVReader(path, ctx=ctx)
+     audio, video = vr[start:end]
+     return audio, video.asnumpy()
+
+
+ def read_av_cleanup(path, start=0, end=None, ctx=None):
+     """Read audio-video with explicit cleanup of decord objects.
+
+     Args:
+         path: Path to the video file.
+         start: Start frame index.
+         end: End frame index.
+         ctx: Decord context (CPU or GPU).
+
+     Returns:
+         Tuple of (audio, video) arrays.
+     """
+     if not HAS_DECORD:
+         raise ImportError("decord library not installed")
+
+     ctx = ctx or cpu(0)
+     vr = AVReader(path, ctx=ctx)
+     audio, video = vr[start:end]
+     audio_list = list(audio)  # Copy audio data
+     video_np = video.asnumpy()  # Convert to numpy array
+     del vr  # Explicitly delete AVReader object
+     return audio_list, video_np
+
+
+ def read_video_opencv(path, max_frames=None):
+     """Read video using OpenCV instead of decord.
+
+     Args:
+         path: Path to the video file.
+         max_frames: Maximum number of frames to read.
+
+     Returns:
+         Video frames as numpy array.
+     """
+     cap = cv2.VideoCapture(path)
+     frames = []
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Convert BGR to RGB
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         frames.append(frame)
+
+         if max_frames and len(frames) >= max_frames:
+             break
+
+     cap.release()
+
+     # Stack frames into a video tensor [num_frames, height, width, channels]
+     if frames:
+         return np.stack(frames, axis=0)
+     else:
+         return np.array([])  # Empty array if no frames were read
+
+
+ def get_memory_usage():
+     """Get current memory usage in MB.
+
+     Returns:
+         Current memory usage in MB.
+     """
+     process = psutil.Process(os.getpid())
+     mem_info = process.memory_info()
+     return mem_info.rss / (1024 * 1024)  # Convert bytes to MB
+
+
+ def test_for_memory_leak(video_paths, method='standard', num_iterations=100, sample_size=20):
+     """Test for memory leaks by repeatedly loading videos.
+
+     Args:
+         video_paths: List of video file paths.
+         method: Method to use for loading videos ('standard', 'cleanup', or 'opencv').
+         num_iterations: Number of iterations to run.
+         sample_size: Number of video paths to sample from.
+
+     Returns:
+         List of memory usage measurements.
+     """
+     memory_usage = []
+     sample_paths = random.sample(video_paths, min(sample_size, len(video_paths)))
+
+     # Record baseline memory usage
+     gc.collect()
+     baseline_memory = get_memory_usage()
+     memory_usage.append(baseline_memory)
+
+     print(f"Initial memory usage: {baseline_memory:.2f} MB")
+
+     # Load videos repeatedly and track memory usage
+     for i in tqdm(range(num_iterations), desc=f"Testing {method} method"):
+         path = random.choice(sample_paths)
+
+         try:
+             # Load the video using the specified method
+             if method == 'standard' and HAS_DECORD:
+                 audio, video = read_av_standard(path)
+                 del audio, video
+             elif method == 'cleanup' and HAS_DECORD:
+                 audio, video = read_av_cleanup(path)
+                 del audio, video
+             elif method == 'opencv':
+                 video = read_video_opencv(path)
+                 del video
+             else:
+                 raise ValueError(f"Unknown method: {method}")
+
+             # Periodic garbage collection
+             if i % 5 == 0:
+                 gc.collect()
+
+             # Record memory
+             memory_usage.append(get_memory_usage())
+
+         except Exception as e:
+             print(f"Error processing video {path}: {e}")
+             continue
+
+     # Final cleanup
+     gc.collect()
+     final_memory = get_memory_usage()
+     memory_usage.append(final_memory)
+
+     print(f"Final memory usage: {final_memory:.2f} MB")
+     print(f"Memory change: {final_memory - baseline_memory:.2f} MB")
+
+     return memory_usage
+
+
+ def benchmark_loading_speed(video_paths, method='standard', num_videos=30):
+     """Benchmark video loading speed.
+
+     Args:
+         video_paths: List of video file paths.
+         method: Method to use for loading videos ('standard', 'cleanup', or 'opencv').
+         num_videos: Number of videos to benchmark.
+
+     Returns:
+         Tuple of (load times, video sizes).
+     """
+     # Select random videos to load
+     selected_paths = random.sample(video_paths, min(num_videos, len(video_paths)))
+
+     load_times = []
+     video_sizes = []
+
+     print(f"Benchmarking {method} method...")
+
+     for path in tqdm(selected_paths, desc=f"Benchmarking {method}"):
+         try:
+             start_time = time.time()
+
+             # Load the video using specified method
+             if method == 'standard' and HAS_DECORD:
+                 audio, video = read_av_standard(path)
+             elif method == 'cleanup' and HAS_DECORD:
+                 audio, video = read_av_cleanup(path)
+             elif method == 'opencv':
+                 video = read_video_opencv(path)
+                 audio = None
+             else:
+                 raise ValueError(f"Unknown method: {method}")
+
+             end_time = time.time()
+
+             # Calculate and store metrics
+             load_time = end_time - start_time
+             load_times.append(load_time)
+
+             # Get video size in MB
+             video_size = video.nbytes / (1024 * 1024)  # Convert bytes to MB
+             video_sizes.append(video_size)
+
+             # Cleanup
+             del video
+             if audio is not None:
+                 del audio
+
+             if len(load_times) % 10 == 0:
+                 gc.collect()
+
+         except Exception as e:
+             print(f"Error benchmarking {path}: {e}")
+             continue
+
+     if not load_times:
+         print("No videos were successfully processed.")
+         return [], []
+
+     # Calculate statistics
+     avg_time = sum(load_times) / len(load_times)
+     avg_size = sum(video_sizes) / len(video_sizes) if video_sizes else 0
+     avg_speed = sum(video_sizes) / sum(load_times) if sum(load_times) > 0 else 0  # MB/s
+
+     print(f"Average load time: {avg_time:.4f} seconds")
+     print(f"Average video size: {avg_size:.2f} MB")
+     print(f"Average loading speed: {avg_speed:.2f} MB/s")
+
+     return load_times, video_sizes
+
+
+ def plot_memory_usage(results, output_dir=None):
+     """Plot memory usage over time.
+
+     Args:
+         results: Dictionary of memory usage results.
+         output_dir: Directory to save plots to.
+     """
+     plt.figure(figsize=(12, 6))
+
+     for method, memory_usage in results.items():
+         plt.plot(memory_usage, label=method)
+
+     plt.title('Memory Usage During Repeated Video Loading')
+     plt.xlabel('Iteration')
+     plt.ylabel('Memory Usage (MB)')
+     plt.legend()
+     plt.grid(True)
+
+     if output_dir:
+         plt.savefig(os.path.join(output_dir, 'memory_usage.png'))
+
+     plt.show()
+
+
+ def plot_loading_speed(results, output_dir=None):
+     """Plot loading speed comparison.
+
+     Args:
+         results: Dictionary of loading speed results.
+         output_dir: Directory to save plots to.
+     """
+     methods = list(results.keys())
+     times = [results[m][0] for m in methods]
+     sizes = [results[m][1] for m in methods]
+
+     plt.figure(figsize=(15, 5))
+
+     # Plot 1: Load time comparison (box plot)
+     plt.subplot(1, 3, 1)
+     plt.boxplot(times, labels=methods)
+     plt.title('Load Time Comparison')
+     plt.ylabel('Time (seconds)')
+
+     # Plot 2: Load time vs video size (scatter)
+     plt.subplot(1, 3, 2)
+     for i, method in enumerate(methods):
+         plt.scatter(sizes[i], times[i], alpha=0.7, label=method)
+     plt.title('Load Time vs. Video Size')
+     plt.xlabel('Video Size (MB)')
+     plt.ylabel('Time (seconds)')
+     plt.legend()
+
+     # Plot 3: Loading speed comparison (box plot)
+     plt.subplot(1, 3, 3)
+     speeds = []
+     for i in range(len(methods)):
+         # Calculate MB/s for each video
+         speed = [s/t for s, t in zip(sizes[i], times[i]) if t > 0]
+         speeds.append(speed)
+
+     plt.boxplot(speeds, labels=methods)
+     plt.title('Loading Speed Comparison')
+     plt.ylabel('Speed (MB/s)')
+
+     plt.tight_layout()
+
+     if output_dir:
+         plt.savefig(os.path.join(output_dir, 'loading_speed.png'))
+
+     plt.show()
+
+
+ def run_full_benchmark(videos_dir, output_dir=None, iterations=100, num_videos=30, sample_size=20):
+     """Run a full benchmark suite.
+
+     Args:
+         videos_dir: Directory containing video files.
+         output_dir: Directory to save results to.
+         iterations: Number of iterations for memory leak test.
+         num_videos: Number of videos for performance benchmark.
+         sample_size: Sample size for memory leak test.
+     """
+     # Create output directory if it doesn't exist
+     if output_dir and not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+
+     # Gather video paths
+     print(f"Searching for videos in {videos_dir}...")
+     video_paths = gather_video_paths(videos_dir)
+     print(f"Found {len(video_paths)} videos.")
+
+     if not video_paths:
+         print("No videos found. Exiting.")
+         return
+
+     # Memory leak tests
+     print("\n=== Running memory leak tests ===\n")
+     memory_results = {}
+
+     methods = ['opencv']
+     if HAS_DECORD:
+         methods = ['standard', 'cleanup', 'opencv']  # Test all methods if decord is available
+
+     for method in methods:
+         print(f"\nTesting {method} method for memory leaks...")
+         memory_usage = test_for_memory_leak(
+             video_paths,
+             method=method,
+             num_iterations=iterations,
+             sample_size=sample_size
+         )
+         memory_results[method] = memory_usage
+
+     # Plot memory usage results
+     plot_memory_usage(memory_results, output_dir)
+
+     # Performance benchmarks
+     print("\n=== Running performance benchmarks ===\n")
+     performance_results = {}
+
+     for method in methods:
+         print(f"\nBenchmarking {method} method...")
+         times, sizes = benchmark_loading_speed(
+             video_paths,
+             method=method,
+             num_videos=num_videos
+         )
+         performance_results[method] = (times, sizes)
+
+     # Plot performance results
+     plot_loading_speed(performance_results, output_dir)
+
+     # Save results to files if output_dir is specified
+     if output_dir:
+         # Save memory results
+         for method, usage in memory_results.items():
+             with open(os.path.join(output_dir, f'memory_{method}.txt'), 'w') as f:
+                 f.write('\n'.join(str(x) for x in usage))
+
+         # Save performance results
+         for method, (times, sizes) in performance_results.items():
+             with open(os.path.join(output_dir, f'performance_{method}.txt'), 'w') as f:
+                 f.write('time,size\n')
+                 for t, s in zip(times, sizes):
+                     f.write(f'{t},{s}\n')
+
+     print("\nBenchmark complete.")
+
+
+ def main():
+     """Main function."""
+     parser = argparse.ArgumentParser(description='Benchmark decord and OpenCV video loading.')
+     parser.add_argument('--videos_dir', '-d', required=True, help='Directory containing video files')
+     parser.add_argument('--output_dir', '-o', help='Directory to save results to')
+     parser.add_argument('--iterations', '-i', type=int, default=100, help='Number of iterations for memory leak test')
+     parser.add_argument('--num_videos', '-n', type=int, default=30, help='Number of videos for performance benchmark')
+     parser.add_argument('--sample_size', '-s', type=int, default=20, help='Sample size for memory leak test')
+     args = parser.parse_args()
+
+     run_full_benchmark(
+         args.videos_dir,
+         args.output_dir,
+         args.iterations,
+         args.num_videos,
+         args.sample_size
+     )
+
+
+ if __name__ == '__main__':
+     main()
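For reference, the script is driven entirely by the argparse flags defined in main() above, and the same entry point can be called directly. A minimal usage sketch (the directory paths below are placeholders, not taken from the diff):

# Equivalent CLI invocation, using the flags registered in main():
#   python benchmark_decord.py -d /data/videos -o results -i 100 -n 30 -s 20
#
# Programmatic use of the same entry point (paths are placeholders):
from flaxdiff.data.benchmark_decord import run_full_benchmark

run_full_benchmark(
    videos_dir="/data/videos",   # scanned recursively for .mp4/.avi/.mov/.mkv/.webm files
    output_dir="results",        # plots plus per-method memory_*.txt / performance_*.txt land here
    iterations=100,              # repeated-load iterations for the memory-leak test
    num_videos=30,               # videos sampled for the loading-speed benchmark
    sample_size=20,              # pool size sampled from during the memory-leak test
)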