media_engine-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/clip.py +79 -0
- cli/faces.py +91 -0
- cli/metadata.py +68 -0
- cli/motion.py +77 -0
- cli/objects.py +94 -0
- cli/ocr.py +93 -0
- cli/scenes.py +57 -0
- cli/telemetry.py +65 -0
- cli/transcript.py +76 -0
- media_engine/__init__.py +7 -0
- media_engine/_version.py +34 -0
- media_engine/app.py +80 -0
- media_engine/batch/__init__.py +56 -0
- media_engine/batch/models.py +99 -0
- media_engine/batch/processor.py +1131 -0
- media_engine/batch/queue.py +232 -0
- media_engine/batch/state.py +30 -0
- media_engine/batch/timing.py +321 -0
- media_engine/cli.py +17 -0
- media_engine/config.py +674 -0
- media_engine/extractors/__init__.py +75 -0
- media_engine/extractors/clip.py +401 -0
- media_engine/extractors/faces.py +459 -0
- media_engine/extractors/frame_buffer.py +351 -0
- media_engine/extractors/frames.py +402 -0
- media_engine/extractors/metadata/__init__.py +127 -0
- media_engine/extractors/metadata/apple.py +169 -0
- media_engine/extractors/metadata/arri.py +118 -0
- media_engine/extractors/metadata/avchd.py +208 -0
- media_engine/extractors/metadata/avchd_gps.py +270 -0
- media_engine/extractors/metadata/base.py +688 -0
- media_engine/extractors/metadata/blackmagic.py +139 -0
- media_engine/extractors/metadata/camera_360.py +276 -0
- media_engine/extractors/metadata/canon.py +290 -0
- media_engine/extractors/metadata/dji.py +371 -0
- media_engine/extractors/metadata/dv.py +121 -0
- media_engine/extractors/metadata/ffmpeg.py +76 -0
- media_engine/extractors/metadata/generic.py +119 -0
- media_engine/extractors/metadata/gopro.py +256 -0
- media_engine/extractors/metadata/red.py +305 -0
- media_engine/extractors/metadata/registry.py +114 -0
- media_engine/extractors/metadata/sony.py +442 -0
- media_engine/extractors/metadata/tesla.py +157 -0
- media_engine/extractors/motion.py +765 -0
- media_engine/extractors/objects.py +245 -0
- media_engine/extractors/objects_qwen.py +754 -0
- media_engine/extractors/ocr.py +268 -0
- media_engine/extractors/scenes.py +82 -0
- media_engine/extractors/shot_type.py +217 -0
- media_engine/extractors/telemetry.py +262 -0
- media_engine/extractors/transcribe.py +579 -0
- media_engine/extractors/translate.py +121 -0
- media_engine/extractors/vad.py +263 -0
- media_engine/main.py +68 -0
- media_engine/py.typed +0 -0
- media_engine/routers/__init__.py +15 -0
- media_engine/routers/batch.py +78 -0
- media_engine/routers/health.py +93 -0
- media_engine/routers/models.py +211 -0
- media_engine/routers/settings.py +87 -0
- media_engine/routers/utils.py +135 -0
- media_engine/schemas.py +581 -0
- media_engine/utils/__init__.py +5 -0
- media_engine/utils/logging.py +54 -0
- media_engine/utils/memory.py +49 -0
- media_engine-0.1.0.dist-info/METADATA +276 -0
- media_engine-0.1.0.dist-info/RECORD +70 -0
- media_engine-0.1.0.dist-info/WHEEL +4 -0
- media_engine-0.1.0.dist-info/entry_points.txt +11 -0
- media_engine-0.1.0.dist-info/licenses/LICENSE +21 -0

media_engine/extractors/ocr.py
@@ -0,0 +1,268 @@

"""OCR extraction using EasyOCR with fast MSER pre-filtering."""

from __future__ import annotations

import gc
import logging
from typing import Any

import cv2
import numpy as np

from media_engine.config import DeviceType, get_device, get_settings
from media_engine.extractors.frame_buffer import SharedFrameBuffer
from media_engine.schemas import BoundingBox, OcrDetection, OcrResult

logger = logging.getLogger(__name__)

# Singleton reader instance (lazy loaded)
_ocr_reader: Any = None
_ocr_languages: list[str] | None = None

# MSER detector (reusable, no state)
_mser_detector: Any = None


def _get_mser_detector() -> Any:
    """Get or create MSER detector (singleton)."""
    global _mser_detector
    if _mser_detector is None:
        _mser_detector = cv2.MSER_create(  # type: ignore[attr-defined]
            delta=5,  # Stability threshold
            min_area=50,  # Min region size
            max_area=14400,  # Max region size (120x120)
            max_variation=0.25,
        )
    return _mser_detector


def has_text_regions(
    frame: np.ndarray,
    min_regions: int = 3,
    min_aspect_ratio: float = 0.2,
    max_aspect_ratio: float = 15.0,
) -> bool:
    """Fast detection of potential text regions using MSER.

    MSER (Maximally Stable Extremal Regions) is a classic computer vision
    algorithm that finds stable regions in images. Text characters are
    typically stable regions with specific aspect ratios.

    This is ~100x faster than deep learning OCR and can be used to skip
    frames that definitely don't contain text.

    Args:
        frame: BGR image as numpy array
        min_regions: Minimum text-like regions to consider "has text"
        min_aspect_ratio: Minimum width/height ratio for text-like regions
        max_aspect_ratio: Maximum width/height ratio for text-like regions

    Returns:
        True if frame likely contains text, False otherwise
    """
    # Convert to grayscale
    if len(frame.shape) == 3:
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    else:
        gray = frame

    # Detect MSER regions
    mser = _get_mser_detector()
    regions, _ = mser.detectRegions(gray)

    if len(regions) < min_regions:
        return False

    # Filter regions by text-like characteristics
    text_like_count = 0
    for region in regions:
        # Get bounding box
        _, _, w, h = cv2.boundingRect(region)  # type: ignore[call-overload]

        if h == 0:
            continue

        aspect_ratio = w / h

        # Text characters typically have aspect ratios between 0.2 and 15
        # (narrow letters like 'i' to wide text blocks)
        if min_aspect_ratio <= aspect_ratio <= max_aspect_ratio:
            # Additional filter: text regions tend to be small-medium sized
            area = w * h
            if 100 <= area <= 50000:
                text_like_count += 1

        # Early exit if we've found enough
        if text_like_count >= min_regions:
            return True

    return text_like_count >= min_regions


def unload_ocr_model() -> None:
    """Unload the EasyOCR model to free memory."""
    global _ocr_reader, _ocr_languages

    if _ocr_reader is None:
        return

    logger.info("Unloading EasyOCR model to free memory")

    try:
        import torch

        del _ocr_reader
        _ocr_reader = None
        _ocr_languages = None

        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            if hasattr(torch.mps, "synchronize"):
                torch.mps.synchronize()
            if hasattr(torch.mps, "empty_cache"):
                torch.mps.empty_cache()

        gc.collect()
        logger.info("EasyOCR model unloaded")
    except Exception as e:
        logger.warning(f"Error unloading EasyOCR model: {e}")
        _ocr_reader = None
        _ocr_languages = None


def _get_ocr_reader(languages: list[str] | None = None) -> Any:
    """Get or create the EasyOCR reader (singleton).

    Note: Once initialized, the reader keeps its languages.
    To change languages, restart the server.
    """
    global _ocr_reader, _ocr_languages

    if _ocr_reader is not None:
        return _ocr_reader

    import easyocr  # type: ignore[import-not-found]

    if languages is None:
        # Get from settings
        settings = get_settings()
        languages = settings.ocr_languages

    _ocr_languages = languages

    # Enable GPU if available (CUDA only - EasyOCR doesn't support MPS)
    device = get_device()
    use_gpu = device == DeviceType.CUDA
    device_name = "CUDA GPU" if use_gpu else "CPU"

    logger.info(f"Initializing EasyOCR with languages: {languages} on {device_name}")
    _ocr_reader = easyocr.Reader(languages, gpu=use_gpu)

    return _ocr_reader


def extract_ocr(
    file_path: str,
    frame_buffer: SharedFrameBuffer,
    min_confidence: float = 0.5,
    languages: list[str] | None = None,
    skip_prefilter: bool = False,
) -> OcrResult:
    """Extract text from video frames using two-phase OCR.

    Phase 1: Fast MSER-based text detection (~5ms/frame)
    Phase 2: Deep learning OCR on frames with text (~500ms/frame)

    This typically skips 80-90% of frames, providing major speedup.

    Args:
        file_path: Path to video file (used for logging)
        frame_buffer: Pre-decoded frames from SharedFrameBuffer
        min_confidence: Minimum detection confidence
        languages: OCR languages (default from settings)
        skip_prefilter: If True, skip MSER pre-filter and run OCR on all frames

    Returns:
        OcrResult with detected text
    """
    detections: list[OcrDetection] = []
    seen_texts: set[str] = set()  # For deduplication

    # Stats for logging
    frames_checked = 0
    frames_with_text = 0
    frames_skipped = 0

    # Lazy-load OCR reader only if we find frames with text
    reader: Any = None

    def process_frame(frame: np.ndarray, timestamp: float) -> None:
        """Process a single frame for OCR."""
        nonlocal frames_checked, frames_with_text, frames_skipped, reader

        frames_checked += 1

        # Phase 1: Fast MSER pre-filter
        if not skip_prefilter:
            if not has_text_regions(frame):
                frames_skipped += 1
                return

        frames_with_text += 1

        # Phase 2: Deep learning OCR (only on frames that passed pre-filter)
        try:
            # Lazy load OCR reader on first use
            if reader is None:
                reader = _get_ocr_reader(languages)

            results = reader.readtext(frame)

            for bbox_points, text, confidence in results:
                if confidence < min_confidence:
                    continue

                # Skip if we've seen this exact text recently
                text_key = text.strip().lower()
                if text_key in seen_texts:
                    continue
                seen_texts.add(text_key)

                # Convert polygon to bounding box
                # bbox_points is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
                x_coords = [p[0] for p in bbox_points]
                y_coords = [p[1] for p in bbox_points]
                x = int(min(x_coords))
                y = int(min(y_coords))
                width = int(max(x_coords) - x)
                height = int(max(y_coords) - y)

                detection = OcrDetection(
                    timestamp=timestamp,
                    text=text.strip(),
                    confidence=round(float(confidence), 3),
                    bbox=BoundingBox(x=x, y=y, width=width, height=height),
                )
                detections.append(detection)

        except Exception as e:
            logger.warning(f"Failed to process frame at {timestamp}s: {e}")

    # Process frames from shared buffer
    logger.info(f"Processing {len(frame_buffer.frames)} frames for OCR")
    for ts in sorted(frame_buffer.frames.keys()):
        shared_frame = frame_buffer.frames[ts]
        process_frame(shared_frame.bgr, ts)

    # Log stats
    if frames_checked > 0:
        skip_pct = (frames_skipped / frames_checked) * 100
        logger.info(
            f"OCR: {frames_checked} frames checked, {frames_skipped} skipped ({skip_pct:.0f}%), "
            f"{frames_with_text} processed, {len(detections)} text regions found"
        )
    else:
        logger.info("OCR: no frames to process")

    return OcrResult(detections=detections)
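
The MSER pre-filter above depends only on OpenCV and NumPy, so it can be exercised without loading EasyOCR. A minimal sketch, assuming the wheel is installed; the synthetic frames and the expected outcomes are illustrative and not part of the package:

import cv2
import numpy as np

from media_engine.extractors.ocr import has_text_regions

# A flat frame yields no stable regions, so the pre-filter rejects it.
blank = np.zeros((720, 1280, 3), dtype=np.uint8)
print(has_text_regions(blank))  # False

# Rendered text produces several high-contrast regions with text-like
# aspect ratios, so such a frame would normally be passed on to EasyOCR.
titled = blank.copy()
cv2.putText(titled, "SPEED LIMIT 65", (80, 360),
            cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 255, 255), 4)
print(has_text_regions(titled))  # typically True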

media_engine/extractors/scenes.py
@@ -0,0 +1,82 @@

"""Scene detection using PySceneDetect."""

import logging
from pathlib import Path

from media_engine.schemas import SceneDetection, ScenesResult

logger = logging.getLogger(__name__)

# Resolution thresholds for frame skipping (pixel count)
RES_1080P = 1920 * 1080  # ~2M pixels
RES_4K = 3840 * 2160  # ~8.3M pixels
RES_5K = 5120 * 2880  # ~14.7M pixels


def extract_scenes(file_path: str, threshold: float = 27.0) -> ScenesResult:
    """Detect scene boundaries in video file.

    For high-resolution videos (4K+), uses frame skipping to improve performance.

    Args:
        file_path: Path to video file
        threshold: Content detector threshold (lower = more sensitive)

    Returns:
        ScenesResult with detected scene boundaries
    """
    from scenedetect import (  # type: ignore[import-not-found]
        ContentDetector,
        SceneManager,
        open_video,
    )

    path = Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"Video file not found: {file_path}")

    logger.info(f"Detecting scenes in {file_path}")

    # Open video to get resolution
    video = open_video(file_path)
    width, height = video.frame_size
    pixels = width * height

    # Determine frame skip based on resolution
    # Higher resolution = more frame skip for speed
    if pixels > RES_5K:
        frame_skip = 4  # Process every 5th frame for 5K+
        logger.info(f"High-res video ({width}x{height}), using frame_skip=4")
    elif pixels > RES_4K:
        frame_skip = 2  # Process every 3rd frame for 4K+
        logger.info(f"4K video ({width}x{height}), using frame_skip=2")
    elif pixels > RES_1080P:
        frame_skip = 1  # Process every 2nd frame for >1080p
        logger.info(f"High-res video ({width}x{height}), using frame_skip=1")
    else:
        frame_skip = 0  # Process every frame for 1080p and below

    # Use SceneManager API for frame_skip support
    scene_manager = SceneManager()
    scene_manager.add_detector(ContentDetector(threshold=threshold))

    # Detect scenes with frame skipping
    scene_manager.detect_scenes(video, frame_skip=frame_skip)
    scenes = scene_manager.get_scene_list()

    detections = []
    for i, (start_time, end_time) in enumerate(scenes):
        start_sec = start_time.get_seconds()
        end_sec = end_time.get_seconds()
        detections.append(
            SceneDetection(
                index=i,
                start=start_sec,
                end=end_sec,
                duration=round(end_sec - start_sec, 3),
            )
        )

    logger.info(f"Detected {len(detections)} scenes")

    return ScenesResult(count=len(detections), detections=detections)
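
A usage sketch for the detector above, assuming the wheel is installed; "clip.mp4" is a placeholder path, and the printed attributes mirror the SceneDetection fields populated in extract_scenes. Frame skipping for 4K/5K sources is applied automatically inside the function:

from media_engine.extractors.scenes import extract_scenes

# Lower thresholds split scenes more aggressively (default is 27.0).
result = extract_scenes("clip.mp4", threshold=27.0)
print(f"{result.count} scenes")
for scene in result.detections:
    print(f"  #{scene.index}: {scene.start:.2f}s -> {scene.end:.2f}s ({scene.duration:.2f}s)")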

media_engine/extractors/shot_type.py
@@ -0,0 +1,217 @@

"""Shot type detection using CLIP classification."""

import logging
import os
import tempfile
from enum import StrEnum
from pathlib import Path
from typing import Final

import cv2

from media_engine.config import has_cuda, is_apple_silicon
from media_engine.extractors.frames import FrameExtractor
from media_engine.schemas import ShotType

logger = logging.getLogger(__name__)


class ShotTypeLabel(StrEnum):
    """Shot type classification labels."""

    AERIAL = "aerial"
    INTERVIEW = "interview"
    B_ROLL = "b-roll"
    STUDIO = "studio"
    HANDHELD = "handheld"
    STATIC = "static"
    PHONE = "phone"
    DASHCAM = "dashcam"
    SECURITY = "security"
    BROADCAST = "broadcast"
    UNKNOWN = "unknown"


# Shot type labels for CLIP classification
SHOT_TYPE_LABELS: Final[list[str]] = [
    "aerial drone footage from above",
    "interview with person talking to camera",
    "b-roll footage of scenery or environment",
    "studio footage with controlled lighting",
    "handheld camera footage",
    "tripod static shot",
    "phone footage vertical video",
    "dashcam footage from car",
    "security camera footage",
    "news broadcast footage",
]

# Map CLIP labels to simplified shot types
LABEL_TO_TYPE: Final[dict[str, ShotTypeLabel]] = {
    "aerial drone footage from above": ShotTypeLabel.AERIAL,
    "interview with person talking to camera": ShotTypeLabel.INTERVIEW,
    "b-roll footage of scenery or environment": ShotTypeLabel.B_ROLL,
    "studio footage with controlled lighting": ShotTypeLabel.STUDIO,
    "handheld camera footage": ShotTypeLabel.HANDHELD,
    "tripod static shot": ShotTypeLabel.STATIC,
    "phone footage vertical video": ShotTypeLabel.PHONE,
    "dashcam footage from car": ShotTypeLabel.DASHCAM,
    "security camera footage": ShotTypeLabel.SECURITY,
    "news broadcast footage": ShotTypeLabel.BROADCAST,
}


def detect_shot_type(file_path: str, sample_count: int = 5) -> ShotType | None:
    """Detect shot type using CLIP classification.

    Args:
        file_path: Path to video file
        sample_count: Number of frames to sample for classification

    Returns:
        ShotType with primary type and confidence, or None if detection fails
    """
    path = Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    try:
        # Get video duration (0 for images)
        duration = _get_video_duration(file_path)

        # Sample frames at regular intervals (or single frame for images)
        temp_dir = tempfile.mkdtemp(prefix="polybos_shot_")

        try:
            frames: list[str] = []
            if duration == 0:
                # Image or zero-duration: use single frame at timestamp 0
                frame_path = _extract_frame_at(file_path, temp_dir, 0.0)
                if frame_path:
                    frames.append(frame_path)
            else:
                for i in range(sample_count):
                    timestamp = (i + 0.5) * duration / sample_count
                    frame_path = _extract_frame_at(file_path, temp_dir, timestamp)
                    if frame_path:
                        frames.append(frame_path)

            if not frames:
                return None

            # Classify frames using CLIP
            votes = _classify_frames(frames)

            if not votes:
                return None

            # Get most common classification
            best_label = max(votes, key=lambda k: votes.get(k, 0))
            confidence = votes[best_label] / len(frames)

            return ShotType(
                primary=LABEL_TO_TYPE.get(best_label, ShotTypeLabel.UNKNOWN),
                confidence=round(confidence, 3),
                detection_method="clip",
            )

        finally:
            # Clean up
            import shutil

            shutil.rmtree(temp_dir, ignore_errors=True)

    except Exception as e:
        logger.warning(f"Shot type detection failed: {e}")
        return None


def _classify_frames(frame_paths: list[str]) -> dict[str, int]:
    """Classify frames using CLIP and return vote counts."""
    votes: dict[str, int] = {}

    if is_apple_silicon():
        votes = _classify_with_mlx(frame_paths)
    else:
        votes = _classify_with_openclip(frame_paths)

    return votes


def _classify_with_openclip(frame_paths: list[str]) -> dict[str, int]:
    """Classify frames using OpenCLIP."""
    import open_clip  # type: ignore[import-not-found]
    import torch
    from PIL import Image

    device: str = "cuda" if has_cuda() else "cpu"

    model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
    model = model.to(device)
    model.eval()

    tokenizer = open_clip.get_tokenizer("ViT-B-32")
    text_tokens = tokenizer(SHOT_TYPE_LABELS).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text_tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    votes: dict[str, int] = {}

    for frame_path in frame_paths:
        try:
            image = Image.open(frame_path).convert("RGB")
            image_tensor = preprocess(image).unsqueeze(0).to(device)

            with torch.no_grad():
                image_features = model.encode_image(image_tensor)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                similarity = (image_features @ text_features.T).softmax(dim=-1)
                best_idx = similarity.argmax().item()
                best_label = SHOT_TYPE_LABELS[best_idx]

            votes[best_label] = votes.get(best_label, 0) + 1

        except Exception as e:
            logger.warning(f"Failed to classify frame {frame_path}: {e}")

    return votes


def _classify_with_mlx(frame_paths: list[str]) -> dict[str, int]:
    """Classify frames using MLX-CLIP (Apple Silicon)."""
    # Fall back to OpenCLIP for now as MLX-CLIP API may vary
    # TODO: Implement native MLX-CLIP classification
    return _classify_with_openclip(frame_paths)


def _extract_frame_at(file_path: str, output_dir: str, timestamp: float) -> str | None:
    """Extract a single frame at specified timestamp.

    Uses FrameExtractor which handles both videos (via OpenCV/ffmpeg)
    and images (via direct loading).
    """
    output_path = os.path.join(output_dir, f"frame_{timestamp:.3f}.jpg")

    with FrameExtractor(file_path) as extractor:
        frame = extractor.get_frame_at(timestamp)

        if frame is not None:
            cv2.imwrite(output_path, frame, [cv2.IMWRITE_JPEG_QUALITY, 95])
            if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
                return output_path
            else:
                logger.warning(f"Frame at {timestamp}s: could not save to {output_path}")
        else:
            logger.warning(f"Frame at {timestamp}s: extraction failed")

    return None


def _get_video_duration(file_path: str) -> float:
    """Get video/image duration in seconds (0 for images)."""
    from media_engine.extractors.frames import get_video_duration

    return get_video_duration(file_path)
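
A usage sketch for detect_shot_type, assuming the wheel is installed; "drone_flight.mp4" is a placeholder path, and the attributes printed are the ones the function sets on ShotType (primary, confidence, detection_method). Sampling more frames costs one CLIP forward pass each but makes the majority vote more stable:

from media_engine.extractors.shot_type import detect_shot_type

shot = detect_shot_type("drone_flight.mp4", sample_count=5)
if shot is None:
    print("shot type could not be determined")
else:
    print(shot.primary, shot.confidence, shot.detection_method)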