shot-detection 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,54 @@
1
# GitHub Actions workflow: build sdist/wheel and publish to PyPI.
# Triggered manually (workflow_dispatch) or when a GitHub release is published.
name: Publish PyPI

on:
  workflow_dispatch:
  release:
    types: [published]

# Default token only needs read access; publishing uses OIDC (see below).
permissions:
  contents: read

jobs:
  build:
    name: Build distributions
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v5

      - uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      # Build sdist + wheel, then sanity-check the metadata with twine.
      - name: Build distributions
        run: |
          python -m pip install --upgrade build twine
          python -m build
          python -m twine check dist/*

      - name: Upload distributions
        uses: actions/upload-artifact@v4
        with:
          name: release-dists
          path: dist/

  publish:
    name: Publish to PyPI
    runs-on: ubuntu-latest
    needs: build
    # id-token: write enables OIDC "trusted publishing" — no API token stored.
    permissions:
      id-token: write

    environment:
      name: pypi
      url: https://pypi.org/project/shot-detection/

    steps:
      - name: Download distributions
        uses: actions/download-artifact@v5
        with:
          name: release-dists
          path: dist/

      - name: Publish distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,106 @@
1
+ Metadata-Version: 2.4
2
+ Name: shot-detection
3
+ Version: 0.1.0
4
+ Summary: Standalone Python shot boundary detection package for splitting videos into shots
5
+ Project-URL: Homepage, https://github.com/Seeknetic/shot-detection-python
6
+ Project-URL: Repository, https://github.com/Seeknetic/shot-detection-python
7
+ Author: Seeknetic
8
+ License-Expression: Apache-2.0
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: License :: OSI Approved :: Apache Software License
11
+ Classifier: Operating System :: OS Independent
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.9
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Topic :: Multimedia :: Video
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Classifier: Typing :: Typed
20
+ Requires-Python: >=3.9
21
+ Requires-Dist: numpy<3,>=1.24
22
+ Requires-Dist: onnxruntime<2,>=1.17
23
+ Requires-Dist: typing-extensions<5,>=4.9
24
+ Provides-Extra: dev
25
+ Requires-Dist: pytest<9,>=8; extra == 'dev'
26
+ Description-Content-Type: text/markdown
27
+
28
+ # Shot Detection Python
29
+
30
+ Standalone Python package for shot boundary detection.
31
+
32
+ It takes a video file, runs TransNetV2-style ONNX inference on low-resolution RGB frames, and returns shot ranges with millisecond timestamps.
33
+
34
+ Built by [Seeknetic](https://www.seeknetic.com/). If you want to make video shots searchable and enable more professional video understanding workflows, visit [seeknetic.com](https://www.seeknetic.com/).
35
+
36
+ ## What it does
37
+
38
+ - Probes the input video with `ffprobe`
39
+ - Extracts low-resolution frames with `ffmpeg`
40
+ - Runs sliding-window ONNX inference
41
+ - Finds shot boundaries
42
+ - Converts them into `{start_ms, end_ms}` shot segments
43
+
44
+ ## Installation
45
+
46
+ ```bash
47
+ pip install shot-detection
48
+ ```
49
+
50
+ Import from `shot_detection`:
51
+
52
+ ```python
53
+ from shot_detection import ShotDetector
54
+ ```
55
+
56
+ You also need:
57
+
58
+ - `ffmpeg`
59
+ - `ffprobe`
60
+
61
+ On first run, the package downloads the default model automatically:
62
+
63
+ - URL: `https://download.shotai.io/model/shot-detection/transnetv2_open_fp16.onnx`
64
+ - cache dir:
65
+ - Linux/macOS: `~/.cache/shot-detection/models/`
66
+ - Windows: `%LOCALAPPDATA%\shot-detection\models\`
67
+
68
+ You can override the cache root with `SHOT_DETECTION_CACHE_DIR`.
69
+
70
+ ## Usage
71
+
72
+ ```python
73
+ from shot_detection import ShotDetector
74
+
75
+ detector = ShotDetector()
76
+ shots = detector.detect("/path/to/video.mp4")
77
+
78
+ for shot in shots:
79
+ print(shot.start_ms, shot.end_ms)
80
+ ```
81
+
82
+ ## Advanced usage
83
+
84
+ ```python
85
+ from shot_detection import detect_shots
86
+
87
+ shots = detect_shots(
88
+ video_path="/path/to/video.mp4",
89
+ threshold=0.5,
90
+ min_shot_duration_ms=500,
91
+ )
92
+ ```
93
+
94
+ ## Custom model path
95
+
96
+ ```python
97
+ from shot_detection import ShotDetector
98
+
99
+ detector = ShotDetector(model_path="/path/to/custom-transnetv2.onnx")
100
+ ```
101
+
102
+ ## Notes
103
+
104
+ - The package expects the ONNX model input to accept 100-frame windows at `48x27` RGB.
105
+ - `ffmpeg` decode is adaptive: it prefers system-native hardware decoding when available and falls back to software decoding automatically.
106
+ - CUDA is intentionally not part of the default decode plan.
@@ -0,0 +1,79 @@
1
+ # Shot Detection Python
2
+
3
+ Standalone Python package for shot boundary detection.
4
+
5
+ It takes a video file, runs TransNetV2-style ONNX inference on low-resolution RGB frames, and returns shot ranges with millisecond timestamps.
6
+
7
+ Built by [Seeknetic](https://www.seeknetic.com/). If you want to make video shots searchable and enable more professional video understanding workflows, visit [seeknetic.com](https://www.seeknetic.com/).
8
+
9
+ ## What it does
10
+
11
+ - Probes the input video with `ffprobe`
12
+ - Extracts low-resolution frames with `ffmpeg`
13
+ - Runs sliding-window ONNX inference
14
+ - Finds shot boundaries
15
+ - Converts them into `{start_ms, end_ms}` shot segments
16
+
17
+ ## Installation
18
+
19
+ ```bash
20
+ pip install shot-detection
21
+ ```
22
+
23
+ Import from `shot_detection`:
24
+
25
+ ```python
26
+ from shot_detection import ShotDetector
27
+ ```
28
+
29
+ You also need:
30
+
31
+ - `ffmpeg`
32
+ - `ffprobe`
33
+
34
+ On first run, the package downloads the default model automatically:
35
+
36
+ - URL: `https://download.shotai.io/model/shot-detection/transnetv2_open_fp16.onnx`
37
+ - cache dir:
38
+ - Linux/macOS: `~/.cache/shot-detection/models/`
39
+ - Windows: `%LOCALAPPDATA%\shot-detection\models\`
40
+
41
+ You can override the cache root with `SHOT_DETECTION_CACHE_DIR`.
42
+
43
+ ## Usage
44
+
45
+ ```python
46
+ from shot_detection import ShotDetector
47
+
48
+ detector = ShotDetector()
49
+ shots = detector.detect("/path/to/video.mp4")
50
+
51
+ for shot in shots:
52
+ print(shot.start_ms, shot.end_ms)
53
+ ```
54
+
55
+ ## Advanced usage
56
+
57
+ ```python
58
+ from shot_detection import detect_shots
59
+
60
+ shots = detect_shots(
61
+ video_path="/path/to/video.mp4",
62
+ threshold=0.5,
63
+ min_shot_duration_ms=500,
64
+ )
65
+ ```
66
+
67
+ ## Custom model path
68
+
69
+ ```python
70
+ from shot_detection import ShotDetector
71
+
72
+ detector = ShotDetector(model_path="/path/to/custom-transnetv2.onnx")
73
+ ```
74
+
75
+ ## Notes
76
+
77
+ - The package expects the ONNX model input to accept 100-frame windows at `48x27` RGB.
78
+ - `ffmpeg` decode is adaptive: it prefers system-native hardware decoding when available and falls back to software decoding automatically.
79
+ - CUDA is intentionally not part of the default decode plan.
@@ -0,0 +1,11 @@
1
#!/usr/bin/env bash
# Build and publish the package to PyPI using uv.
# If PYPI_TOKEN is set it is passed explicitly; otherwise uv publish relies on
# ambient credentials (trusted publishing or a configured keyring).

set -eux
rm -rf dist
mkdir -p dist
uv build
if [ -n "${PYPI_TOKEN:-}" ]; then
  uv publish --token="$PYPI_TOKEN"
else
  uv publish
fi
@@ -0,0 +1,47 @@
1
# Packaging configuration for the shot-detection distribution (PEP 621).

[project]
name = "shot-detection"
version = "0.1.0"
description = "Standalone Python shot boundary detection package for splitting videos into shots"
readme = "README.md"
license = "Apache-2.0"
authors = [
    { name = "Seeknetic" },
]
requires-python = ">=3.9"
# Runtime dependencies; upper bounds guard against breaking major releases.
dependencies = [
    "numpy>=1.24,<3",
    "onnxruntime>=1.17,<2",
    "typing-extensions>=4.9,<5",
]
classifiers = [
    "Typing :: Typed",
    "Intended Audience :: Developers",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.9",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Operating System :: OS Independent",
    "License :: OSI Approved :: Apache Software License",
    "Topic :: Multimedia :: Video",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
]

[project.urls]
Homepage = "https://github.com/Seeknetic/shot-detection-python"
Repository = "https://github.com/Seeknetic/shot-detection-python"

# Development-only extras: install with `pip install shot-detection[dev]`.
[project.optional-dependencies]
dev = [
    "pytest>=8,<9",
]

[build-system]
requires = ["hatchling>=1.26,<2"]
build-backend = "hatchling.build"

# Ship the package from the src/ layout.
[tool.hatch.build.targets.wheel]
packages = ["src/shot_detection"]

[tool.pytest.ini_options]
testpaths = ["tests"]
@@ -0,0 +1,24 @@
1
"""Public API for the shot_detection package.

Re-exports the detector entry points, the package error type, the
model-management helpers, and the result dataclasses so callers can import
everything from the top level.
"""

from .detector import ShotDetector, detect_shots, detect_shots_detailed
from .errors import ShotDetectionError
from .model_manager import (
    DEFAULT_MODEL_URL,
    ensure_default_model,
    get_default_cache_dir,
    get_default_model_path,
)
from .types import DetectionResult, ShotBoundary, ShotSegment, VideoMetadata

# Explicit public API surface.
__all__ = [
    "ShotDetector",
    "ShotDetectionError",
    "DEFAULT_MODEL_URL",
    "ShotBoundary",
    "ShotSegment",
    "VideoMetadata",
    "DetectionResult",
    "ensure_default_model",
    "get_default_cache_dir",
    "get_default_model_path",
    "detect_shots",
    "detect_shots_detailed",
]
@@ -0,0 +1,207 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+
8
+ from .errors import ShotDetectionError
9
+ from .ffmpeg import MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH, extract_low_res_rgb_frames, probe_video
10
+ from .inference import MODEL_INPUT_CHANNELS, WINDOW_SIZE, TransNetV2ONNX
11
+ from .model_manager import ensure_default_model
12
+ from .types import DetectionResult, ShotBoundary, ShotSegment
13
+
14
# Stride between successive inference windows; windows overlap by 10 frames
# so predictions exist for every frame up to the end of the video.
WINDOW_STEP = WINDOW_SIZE - 10
# Default boundary-probability threshold for accepting a peak as a cut.
DEFAULT_THRESHOLD = 0.5
# Shots shorter than this are merged into a neighbour (see merge_short_shots).
DEFAULT_MIN_SHOT_DURATION_MS = 500
17
+
18
+
19
class ShotDetector:
    """Detects shot boundaries in a video with a TransNetV2-style ONNX model.

    Pipeline: probe the video with ffprobe, extract low-resolution RGB frames
    with ffmpeg, run sliding-window ONNX inference, pick probability peaks,
    and convert them into millisecond shot segments.
    """

    def __init__(
        self,
        *,
        model_path: str | Path | None = None,
        threshold: float = DEFAULT_THRESHOLD,
        min_shot_duration_ms: int = DEFAULT_MIN_SHOT_DURATION_MS,
        model_cache_dir: str | Path | None = None,
    ) -> None:
        """Load the ONNX model (downloading the default if needed).

        Args:
            model_path: Explicit model file; when None the default model is
                downloaded to / read from the cache directory.
            threshold: Minimum boundary probability for a peak to count as a cut.
            min_shot_duration_ms: Shots shorter than this are merged away.
            model_cache_dir: Override for the model cache root.
        """
        resolved_model_path = ensure_default_model(model_path=model_path, cache_dir=model_cache_dir)
        self.model = TransNetV2ONNX(resolved_model_path)
        self.threshold = threshold
        self.min_shot_duration_ms = min_shot_duration_ms

    def detect(self, video_path: str | Path) -> list[ShotSegment]:
        """Return only the shot segments for *video_path*."""
        return self.detect_detailed(video_path).shots

    def detect_detailed(self, video_path: str | Path) -> DetectionResult:
        """Run the full pipeline and return boundaries, shots, and metadata.

        Raises:
            ShotDetectionError: if no frames could be extracted from the video.
        """
        metadata = probe_video(video_path)
        raw_frames = extract_low_res_rgb_frames(video_path)
        # One frame is W*H*C raw rgb24 bytes from ffmpeg's rawvideo output.
        frame_size = MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT * MODEL_INPUT_CHANNELS
        total_frames = len(raw_frames) // frame_size

        if total_frames <= 0:
            raise ShotDetectionError("No frames were extracted from the input video")

        frames = np.frombuffer(raw_frames[: total_frames * frame_size], dtype=np.uint8).reshape(
            total_frames,
            MODEL_INPUT_HEIGHT,
            MODEL_INPUT_WIDTH,
            MODEL_INPUT_CHANNELS,
        )

        all_probs: list[float] = []

        # Slide a WINDOW_SIZE window across the video with a 10-frame overlap.
        for start in range(0, total_frames, WINDOW_STEP):
            end = min(start + WINDOW_SIZE, total_frames)
            window_frames = end - start
            # A tail shorter than the overlap was already fully covered by the
            # previous window, so stop.
            if window_frames < 10:
                break

            # Zero-pad the final partial window to the fixed model input shape.
            buffer = np.zeros(
                (1, WINDOW_SIZE, MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH, MODEL_INPUT_CHANNELS),
                dtype=np.uint8,
            )
            buffer[0, :window_frames] = frames[start:end]

            probs = self.model.infer(buffer)
            # Keep the first prediction seen per frame; overlapping frames
            # from later windows are ignored.
            for offset in range(window_frames):
                frame_index = start + offset
                if frame_index >= len(all_probs):
                    all_probs.append(_to_boundary_probability(float(probs[offset])))

        peak_indices = find_peaks(np.asarray(all_probs, dtype=np.float32), self.threshold)
        boundaries = [
            ShotBoundary(
                frame_index=frame_index,
                timestamp_ms=int(round((frame_index / metadata.fps) * 1000)),
                confidence=float(all_probs[frame_index]),
            )
            for frame_index in peak_indices
        ]
        shots = boundaries_to_shots(boundaries, metadata.duration_ms, self.min_shot_duration_ms)

        return DetectionResult(
            boundaries=boundaries,
            shots=shots,
            total_frames=total_frames,
            fps=metadata.fps,
            duration_ms=metadata.duration_ms,
        )
90
+
91
+
92
def detect_shots(
    *,
    video_path: str | Path,
    model_path: str | Path | None = None,
    threshold: float = DEFAULT_THRESHOLD,
    min_shot_duration_ms: int = DEFAULT_MIN_SHOT_DURATION_MS,
    model_cache_dir: str | Path | None = None,
) -> list[ShotSegment]:
    """Convenience wrapper: build a ShotDetector and return its shot list."""
    return ShotDetector(
        model_path=model_path,
        threshold=threshold,
        min_shot_duration_ms=min_shot_duration_ms,
        model_cache_dir=model_cache_dir,
    ).detect(video_path)
107
+
108
+
109
def detect_shots_detailed(
    *,
    video_path: str | Path,
    model_path: str | Path | None = None,
    threshold: float = DEFAULT_THRESHOLD,
    min_shot_duration_ms: int = DEFAULT_MIN_SHOT_DURATION_MS,
    model_cache_dir: str | Path | None = None,
) -> DetectionResult:
    """Convenience wrapper: build a ShotDetector and return the full result."""
    return ShotDetector(
        model_path=model_path,
        threshold=threshold,
        min_shot_duration_ms=min_shot_duration_ms,
        model_cache_dir=model_cache_dir,
    ).detect_detailed(video_path)
124
+
125
+
126
+ def _to_boundary_probability(value: float) -> float:
127
+ if 0.0 <= value <= 1.0:
128
+ return value
129
+ return 1.0 / (1.0 + math.exp(-value))
130
+
131
+
132
def find_peaks(probs: np.ndarray, threshold: float, min_distance: int = 10) -> list[int]:
    """Return indices of local maxima in *probs* that exceed *threshold*.

    A peak closer than *min_distance* frames to the previously accepted peak
    is only kept when it is stronger, in which case it replaces that peak.
    The first and last positions can never be peaks.
    """
    accepted: list[int] = []

    for idx in range(1, len(probs) - 1):
        value = float(probs[idx])
        is_local_max = value > float(probs[idx - 1]) and value > float(probs[idx + 1])
        if not is_local_max or value <= threshold:
            continue

        if accepted and idx - accepted[-1] < min_distance:
            # Too close to the previous peak: keep whichever is stronger.
            if value > float(probs[accepted[-1]]):
                accepted[-1] = idx
        else:
            accepted.append(idx)

    return accepted
147
+
148
+
149
def boundaries_to_shots(
    boundaries: list[ShotBoundary],
    duration_ms: int,
    min_shot_duration_ms: int = DEFAULT_MIN_SHOT_DURATION_MS,
) -> list[ShotSegment]:
    """Turn cut timestamps into contiguous millisecond segments.

    N boundaries produce N + 1 segments spanning [0, duration_ms]; segments
    shorter than *min_shot_duration_ms* are then merged into neighbours.
    """
    cut_points = [boundary.timestamp_ms for boundary in boundaries]
    shots = [
        ShotSegment(index=position, start_ms=start, end_ms=end)
        for position, (start, end) in enumerate(zip([0, *cut_points], [*cut_points, duration_ms]))
    ]
    return merge_short_shots(shots, min_shot_duration_ms)
162
+
163
+
164
def merge_short_shots(shots: list[ShotSegment], min_shot_duration_ms: int) -> list[ShotSegment]:
    """Merge shots shorter than *min_shot_duration_ms* into a neighbour.

    A short shot is absorbed by whichever adjacent shot is currently shorter
    (ties go to the previous one). Passes repeat until no short shot remains
    or only one shot is left. Indices are renumbered in the returned list.

    Bug fix vs. the original: a short shot with no neighbour on either side
    (e.g. when every preceding shot was merged forward and this is the final
    one) was silently discarded, losing coverage and potentially returning an
    empty list for an all-short input. Such a shot is now kept as-is.
    """
    if len(shots) <= 1:
        return shots

    # Work on mutable dicts so neighbours can be extended in place.
    merged = [
        {"index": shot.index, "start_ms": shot.start_ms, "end_ms": shot.end_ms}
        for shot in shots
    ]

    has_short = True
    while has_short:
        has_short = False
        next_shots: list[dict[str, int]] = []

        for index, shot in enumerate(merged):
            duration = shot["end_ms"] - shot["start_ms"]
            if duration >= min_shot_duration_ms:
                next_shots.append(shot)
                continue

            has_short = True
            previous = next_shots[-1] if next_shots else None
            following = merged[index + 1] if index + 1 < len(merged) else None

            if previous is None and following is not None:
                # No earlier shot kept yet: extend the next shot backwards.
                following["start_ms"] = shot["start_ms"]
            elif previous is not None and following is None:
                # Last shot: extend the previous shot forwards.
                previous["end_ms"] = shot["end_ms"]
            elif previous is not None and following is not None:
                # Absorb into the currently shorter neighbour.
                previous_duration = previous["end_ms"] - previous["start_ms"]
                following_duration = following["end_ms"] - following["start_ms"]
                if previous_duration <= following_duration:
                    previous["end_ms"] = shot["end_ms"]
                else:
                    following["start_ms"] = shot["start_ms"]
            else:
                # No neighbour on either side: keep the shot instead of
                # silently dropping the segment and losing coverage.
                next_shots.append(shot)

        merged = next_shots
        if len(merged) <= 1:
            break

    return [
        ShotSegment(index=index, start_ms=shot["start_ms"], end_ms=shot["end_ms"])
        for index, shot in enumerate(merged)
    ]
@@ -0,0 +1,2 @@
1
class ShotDetectionError(RuntimeError):
    """Raised when shot detection preprocessing or inference fails.

    Single package-level error type: missing ffmpeg/ffprobe binaries, probe
    or frame-extraction failures, model download problems, and ONNX Runtime
    session/inference errors are all reported through this exception.
    """
@@ -0,0 +1,281 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import shutil
5
+ import subprocess
6
+ import sys
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+
10
+ from .errors import ShotDetectionError
11
+ from .types import VideoMetadata
12
+
13
# Spatial resolution frames are scaled to for the model (48x27 RGB).
MODEL_INPUT_WIDTH = 48
MODEL_INPUT_HEIGHT = 27
15
+
16
+
17
@dataclass(frozen=True)
class FfmpegRuntimeCapabilities:
    """Feature sets reported by the local ffmpeg binary."""

    hwaccels: set[str]  # names from `ffmpeg -hwaccels`, lower-cased
    decoders: set[str]  # names from `ffmpeg -decoders`, lower-cased


@dataclass(frozen=True)
class FfmpegInputPlan:
    """One candidate set of ffmpeg input options to try when decoding."""

    label: str  # human-readable identifier, e.g. "vaapi-hardware"
    input_options: list[str]  # options placed before -i on the command line
    uses_hardware_decode: bool  # whether the plan decodes on hardware
28
+
29
+
30
def resolve_binary(name: str) -> str:
    """Locate *name* on PATH, raising ShotDetectionError when it is absent."""
    located = shutil.which(name)
    if located is not None:
        return located
    raise ShotDetectionError(
        f"Required binary '{name}' was not found on PATH. Please install FFmpeg before running shot detection."
    )
37
+
38
+
39
def run_command(command: list[str]) -> subprocess.CompletedProcess[bytes]:
    """Run *command* capturing stdout/stderr; wrap failures in ShotDetectionError."""
    try:
        completed = subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as exc:
        stderr_text = exc.stderr.decode("utf-8", errors="replace").strip()
        raise ShotDetectionError(stderr_text or str(exc)) from exc
    return completed
45
+
46
+
47
+ def _parse_ffmpeg_list(output: str) -> set[str]:
48
+ values: set[str] = set()
49
+ for line in output.splitlines():
50
+ trimmed = line.strip()
51
+ if not trimmed:
52
+ continue
53
+
54
+ parts = trimmed.split()
55
+ if len(parts) >= 2 and parts[0].isupper():
56
+ values.add(parts[1].lower())
57
+ elif trimmed.isidentifier():
58
+ values.add(trimmed.lower())
59
+
60
+ return values
61
+
62
+
63
def get_runtime_capabilities(ffmpeg_path: str) -> FfmpegRuntimeCapabilities:
    """Query the ffmpeg binary for its hardware accelerators and decoders."""

    def _query(flag: str) -> set[str]:
        raw_output = run_command([ffmpeg_path, "-hide_banner", flag]).stdout.decode("utf-8")
        return _parse_ffmpeg_list(raw_output)

    return FfmpegRuntimeCapabilities(hwaccels=_query("-hwaccels"), decoders=_query("-decoders"))
67
+
68
+
69
def _linux_vaapi_options() -> list[str] | None:
    """Build VAAPI input options for the first DRM render node, if any exists."""
    dri_dir = Path("/dev/dri")
    render_nodes = sorted(dri_dir.glob("renderD*")) if dri_dir.exists() else []
    if not render_nodes:
        return None

    # Single decode thread; keep frames on the VAAPI surface for hwdownload later.
    return [
        "-threads", "1",
        "-hwaccel", "vaapi",
        "-hwaccel_output_format", "vaapi",
        "-vaapi_device", str(render_nodes[0]),
    ]
88
+
89
+
90
def get_decode_plans(platform_name: str | None = None) -> list[FfmpegInputPlan]:
    """Build an ordered list of decode plans to attempt, best first.

    Hardware plans (VideoToolbox on macOS; D3D12/D3D11/DXVA2 on Windows;
    VAAPI on Linux) are only offered when the local ffmpeg advertises the
    corresponding hwaccel. A software plan is always appended last so that
    extraction can never run out of options. CUDA is deliberately excluded.

    Args:
        platform_name: Override for sys.platform (useful in tests).
    """
    ffmpeg_path = resolve_binary("ffmpeg")
    capabilities = get_runtime_capabilities(ffmpeg_path)
    platform_name = platform_name or sys.platform

    plans: list[FfmpegInputPlan] = []

    if platform_name == "darwin" and "videotoolbox" in capabilities.hwaccels:
        plans.append(
            FfmpegInputPlan(
                label="videotoolbox-hardware",
                input_options=["-threads", "1", "-hwaccel", "videotoolbox"],
                uses_hardware_decode=True,
            )
        )

    if platform_name == "win32":
        # Prefer the newest Direct3D decode API first.
        if "d3d12va" in capabilities.hwaccels:
            plans.append(
                FfmpegInputPlan(
                    label="d3d12va-hardware",
                    input_options=["-threads", "1", "-hwaccel", "d3d12va", "-hwaccel_output_format", "d3d12"],
                    uses_hardware_decode=True,
                )
            )
        if "d3d11va" in capabilities.hwaccels:
            plans.append(
                FfmpegInputPlan(
                    label="d3d11va-hardware",
                    input_options=["-threads", "1", "-hwaccel", "d3d11va", "-hwaccel_output_format", "d3d11"],
                    uses_hardware_decode=True,
                )
            )
        if "dxva2" in capabilities.hwaccels:
            plans.append(
                FfmpegInputPlan(
                    label="dxva2-hardware",
                    input_options=["-threads", "1", "-hwaccel", "dxva2", "-hwaccel_output_format", "dxva2_vld"],
                    uses_hardware_decode=True,
                )
            )

    # startswith covers both "linux" and the historical "linux2".
    if platform_name.startswith("linux") and "vaapi" in capabilities.hwaccels:
        vaapi_options = _linux_vaapi_options()
        if vaapi_options is not None:
            plans.append(
                FfmpegInputPlan(
                    label="vaapi-hardware",
                    input_options=vaapi_options,
                    uses_hardware_decode=True,
                )
            )

    # Always include a plain software plan so decoding can fall back.
    plans.append(
        FfmpegInputPlan(
            label="software-fallback",
            input_options=["-threads", "1"],
            uses_hardware_decode=False,
        )
    )
    return plans
151
+
152
+
153
+ def _requires_hwdownload(plan: FfmpegInputPlan, platform_name: str | None = None) -> bool:
154
+ platform_name = platform_name or sys.platform
155
+ if platform_name == "darwin" and plan.label == "videotoolbox-hardware":
156
+ return False
157
+ return plan.uses_hardware_decode
158
+
159
+
160
def probe_video(video_path: str | Path) -> VideoMetadata:
    """Read duration, dimensions, fps, and codec from a video via ffprobe.

    Raises:
        ShotDetectionError: when ffprobe is missing, its JSON output cannot
            be parsed, no video stream exists, or the metadata is invalid.
    """
    ffprobe_path = resolve_binary("ffprobe")
    result = run_command(
        [
            ffprobe_path,
            "-v",
            "error",
            "-show_streams",
            "-show_format",
            "-print_format",
            "json",
            str(video_path),
        ]
    )

    try:
        payload = json.loads(result.stdout.decode("utf-8"))
    except json.JSONDecodeError as exc:
        raise ShotDetectionError("Failed to parse ffprobe output") from exc

    streams = payload.get("streams")
    if not isinstance(streams, list):
        raise ShotDetectionError("ffprobe did not return stream metadata")

    # Use the first stream whose codec_type is "video".
    video_stream = next((stream for stream in streams if stream.get("codec_type") == "video"), None)
    if not isinstance(video_stream, dict):
        raise ShotDetectionError("No video stream found in the input file")

    width = int(video_stream.get("width") or 0)
    height = int(video_stream.get("height") or 0)
    # avg_frame_rate is preferred; r_frame_rate is the fallback.
    fps = _parse_frame_rate(video_stream.get("avg_frame_rate") or video_stream.get("r_frame_rate"))
    duration_seconds = _parse_duration_seconds(payload.get("format"), video_stream)

    if width <= 0 or height <= 0 or fps <= 0 or duration_seconds <= 0:
        raise ShotDetectionError("Unable to determine valid video metadata")

    return VideoMetadata(
        duration_ms=int(round(duration_seconds * 1000)),
        width=width,
        height=height,
        fps=fps,
        codec=video_stream.get("codec_name"),
    )
203
+
204
+
205
+ def _parse_frame_rate(value: object) -> float:
206
+ if not isinstance(value, str) or not value:
207
+ return 0.0
208
+
209
+ if "/" in value:
210
+ numerator_text, denominator_text = value.split("/", 1)
211
+ try:
212
+ numerator = float(numerator_text)
213
+ denominator = float(denominator_text)
214
+ except ValueError:
215
+ return 0.0
216
+ return 0.0 if denominator == 0 else numerator / denominator
217
+
218
+ try:
219
+ return float(value)
220
+ except ValueError:
221
+ return 0.0
222
+
223
+
224
+ def _parse_duration_seconds(format_info: object, video_stream: dict[str, object]) -> float:
225
+ if isinstance(format_info, dict):
226
+ try:
227
+ duration = float(format_info.get("duration") or 0.0)
228
+ except (TypeError, ValueError):
229
+ duration = 0.0
230
+ if duration > 0:
231
+ return duration
232
+
233
+ try:
234
+ return float(video_stream.get("duration") or 0.0)
235
+ except (TypeError, ValueError):
236
+ return 0.0
237
+
238
+
239
def extract_low_res_rgb_frames(
    video_path: str | Path,
    *,
    width: int = MODEL_INPUT_WIDTH,
    height: int = MODEL_INPUT_HEIGHT,
) -> bytes:
    """Decode *video_path* into concatenated raw rgb24 frames at width x height.

    Tries each plan from get_decode_plans in order and returns the stdout of
    the first ffmpeg invocation that succeeds.

    Raises:
        ShotDetectionError: the last plan's error when every plan fails, or a
            generic error if no plan was produced at all.
    """
    ffmpeg_path = resolve_binary("ffmpeg")
    last_error: ShotDetectionError | None = None

    for plan in get_decode_plans():
        scale_filter = f"scale={width}:{height}:flags=fast_bilinear"
        if _requires_hwdownload(plan):
            # Hardware surfaces must be copied back to system memory (nv12)
            # before software scaling and rgb24 conversion can run.
            filter_chain = f"hwdownload,format=nv12,{scale_filter},format=rgb24"
        else:
            filter_chain = f"{scale_filter},format=rgb24"

        # NOTE(review): "-vsync 0" (passthrough) is deprecated in newer ffmpeg
        # in favour of "-fps_mode passthrough", but is still accepted.
        command = [
            ffmpeg_path,
            "-v",
            "error",
            *plan.input_options,
            "-i",
            str(video_path),
            "-vf",
            filter_chain,
            "-pix_fmt",
            "rgb24",
            "-vsync",
            "0",
            "-f",
            "rawvideo",
            "pipe:1",
        ]

        try:
            return run_command(command).stdout
        except ShotDetectionError as exc:
            last_error = exc

    if last_error is not None:
        raise last_error

    # get_decode_plans always appends a software plan, so this is effectively
    # unreachable; kept as a defensive guard.
    raise ShotDetectionError("No FFmpeg decode plan was available")
@@ -0,0 +1,65 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+
8
+ from .errors import ShotDetectionError
9
+
10
# Number of frames per inference window.
WINDOW_SIZE = 100
# Spatial shape of each frame fed to the model (matches ffmpeg extraction).
MODEL_INPUT_WIDTH = 48
MODEL_INPUT_HEIGHT = 27
MODEL_INPUT_CHANNELS = 3
# Total elements in one flattened input window.
MODEL_INPUT_SIZE = WINDOW_SIZE * MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT * MODEL_INPUT_CHANNELS
15
+
16
+
17
class TransNetV2ONNX:
    """Thin wrapper around an ONNX Runtime session for TransNetV2 inference."""

    def __init__(self, model_path: str | Path) -> None:
        """Open the model and cache its input/output tensor names and type."""
        self.model_path = str(model_path)
        self._session = self._create_session(self.model_path)
        self._input_name = self._session.get_inputs()[0].name
        self._output_name = self._session.get_outputs()[0].name
        # Input element type string, e.g. "tensor(uint8)" or "tensor(float)".
        self._input_type = self._session.get_inputs()[0].type

    def _create_session(self, model_path: str) -> ort.InferenceSession:
        """Create an InferenceSession preferring CoreML/DirectML over CPU.

        Raises:
            ShotDetectionError: when the model file is missing, no execution
                provider is available, or session creation fails.
        """
        path = Path(model_path)
        if not path.is_file():
            raise ShotDetectionError(f"Model file not found: {model_path}")

        available_providers = ort.get_available_providers()
        # CUDA is intentionally omitted from the preference list.
        preferred_order = [
            "CoreMLExecutionProvider",
            "DmlExecutionProvider",
            "CPUExecutionProvider",
        ]
        providers = [provider for provider in preferred_order if provider in available_providers]
        if not providers:
            providers = available_providers
        if not providers:
            raise ShotDetectionError("No ONNX Runtime execution provider is available")

        try:
            return ort.InferenceSession(model_path, providers=providers)
        except Exception as exc:  # pragma: no cover
            raise ShotDetectionError(f"Failed to initialize ONNX Runtime session: {exc}") from exc

    def infer(self, frames: np.ndarray) -> np.ndarray:
        """Run one window and return WINDOW_SIZE per-frame scores.

        Args:
            frames: tensor of shape (1, WINDOW_SIZE, H, W, C); cast to the
                dtype the model declares before running.

        Raises:
            ShotDetectionError: on a shape mismatch, an inference failure, or
                a model output shorter than WINDOW_SIZE.
        """
        if frames.shape != (1, WINDOW_SIZE, MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH, MODEL_INPUT_CHANNELS):
            raise ShotDetectionError(f"Unexpected frame tensor shape: {frames.shape}")

        # Match the model's declared input dtype (uint8 vs float32 exports).
        if self._input_type == "tensor(uint8)":
            input_data = frames.astype(np.uint8, copy=False)
        else:
            input_data = frames.astype(np.float32, copy=False)

        try:
            outputs = self._session.run([self._output_name], {self._input_name: input_data})
        except Exception as exc:  # pragma: no cover
            raise ShotDetectionError(f"ONNX inference failed: {exc}") from exc

        output = np.asarray(outputs[0]).reshape(-1)
        if output.size < WINDOW_SIZE:
            raise ShotDetectionError(f"Model output is too short: {output.size}")

        return output[:WINDOW_SIZE]
@@ -0,0 +1,75 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import tempfile
5
+ import urllib.request
6
+ from pathlib import Path
7
+
8
+ from .errors import ShotDetectionError
9
+
10
# Where the default model is downloaded from and cached.
DEFAULT_MODEL_URL = "https://download.shotai.io/model/shot-detection/transnetv2_open_fp16.onnx"
DEFAULT_MODEL_FILENAME = "transnetv2_open_fp16.onnx"
DEFAULT_CACHE_DIRNAME = "shot-detection"


def get_default_cache_dir() -> Path:
    """Resolve the cache root for downloaded models.

    Precedence: SHOT_DETECTION_CACHE_DIR env var, then %LOCALAPPDATA% on
    Windows, then $XDG_CACHE_HOME, then ~/.cache.
    """
    explicit_override = os.environ.get("SHOT_DETECTION_CACHE_DIR")
    if explicit_override:
        return Path(explicit_override).expanduser()

    if os.name == "nt":
        windows_base = os.environ.get("LOCALAPPDATA")
        if windows_base:
            return Path(windows_base) / DEFAULT_CACHE_DIRNAME

    xdg_base = os.environ.get("XDG_CACHE_HOME")
    if xdg_base:
        return Path(xdg_base) / DEFAULT_CACHE_DIRNAME
    return Path.home() / ".cache" / DEFAULT_CACHE_DIRNAME


def get_default_model_path(cache_dir: str | Path | None = None) -> Path:
    """Return the on-disk path of the default model under *cache_dir*."""
    if cache_dir is None:
        base_dir = get_default_cache_dir()
    else:
        base_dir = Path(cache_dir).expanduser()
    return base_dir / "models" / DEFAULT_MODEL_FILENAME
35
+
36
+
37
def ensure_default_model(
    *,
    model_path: str | Path | None = None,
    cache_dir: str | Path | None = None,
    model_url: str = DEFAULT_MODEL_URL,
    timeout_seconds: int = 60,
) -> Path:
    """Return a usable model file path, downloading the default if needed.

    When *model_path* is given it is used as-is (after ~ expansion);
    otherwise the default cache location is used. The download is written to
    a temporary file in the destination directory and moved into place with
    os.replace, so an interrupted download never leaves a partial model at
    the final path.

    Args:
        model_path: Explicit model file path; skips download when it exists.
        cache_dir: Cache root override, used when model_path is None.
        model_url: URL to fetch the model from.
        timeout_seconds: Socket timeout for the download.

    Raises:
        ShotDetectionError: when the download fails or yields an empty file.
    """
    resolved_path = Path(model_path).expanduser() if model_path is not None else get_default_model_path(cache_dir)
    if resolved_path.is_file():
        return resolved_path

    resolved_path.parent.mkdir(parents=True, exist_ok=True)
    # Create the temp file next to the target so os.replace stays on one filesystem.
    temp_file_handle = tempfile.NamedTemporaryFile(
        prefix=resolved_path.stem + ".",
        suffix=".tmp",
        dir=resolved_path.parent,
        delete=False,
    )
    temp_path = Path(temp_file_handle.name)
    temp_file_handle.close()

    try:
        with urllib.request.urlopen(model_url, timeout=timeout_seconds) as response, temp_path.open("wb") as file_obj:
            # Stream in 1 MiB chunks to avoid holding the whole model in memory.
            while True:
                chunk = response.read(1024 * 1024)
                if not chunk:
                    break
                file_obj.write(chunk)

        if temp_path.stat().st_size == 0:
            raise ShotDetectionError(f"Downloaded model is empty: {model_url}")

        os.replace(temp_path, resolved_path)
        return resolved_path
    except ShotDetectionError:
        # Already a descriptive package error (e.g. empty download): clean up
        # and re-raise without wrapping it in a misleading second message.
        temp_path.unlink(missing_ok=True)
        raise
    except Exception as exc:
        temp_path.unlink(missing_ok=True)
        raise ShotDetectionError(
            f"Failed to download the default shot detection model from {model_url}: {exc}"
        ) from exc
@@ -0,0 +1,35 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+
6
@dataclass(frozen=True)
class VideoMetadata:
    """Container/stream properties probed from the input video."""

    duration_ms: int  # total duration in milliseconds
    width: int  # frame width in pixels
    height: int  # frame height in pixels
    fps: float  # average frames per second
    codec: str | None = None  # codec name reported by ffprobe, if any


@dataclass(frozen=True)
class ShotBoundary:
    """A detected cut between two shots."""

    frame_index: int  # index of the boundary frame in the decoded sequence
    timestamp_ms: int  # boundary position in milliseconds
    confidence: float  # model probability for this boundary


@dataclass(frozen=True)
class ShotSegment:
    """One contiguous shot expressed as a millisecond range."""

    index: int  # position of the shot within the video
    start_ms: int
    end_ms: int


@dataclass(frozen=True)
class DetectionResult:
    """Full output of a detection run."""

    boundaries: list[ShotBoundary]
    shots: list[ShotSegment]
    total_frames: int  # number of low-resolution frames actually decoded
    fps: float
    duration_ms: int
@@ -0,0 +1,33 @@
1
+ from shot_detection.detector import boundaries_to_shots, find_peaks, merge_short_shots
2
+ from shot_detection.types import ShotBoundary, ShotSegment
3
+
4
+
5
def test_find_peaks_prefers_stronger_peak_when_too_close() -> None:
    """Two peaks within min_distance collapse to the stronger (later) one."""
    # Local import replaces the original opaque __import__("numpy") call.
    import numpy as np

    probabilities = np.array([0.0, 0.7, 0.0, 0.8, 0.0], dtype="float32")
    assert find_peaks(probabilities, 0.5, min_distance=10) == [3]
8
+
9
+
10
def test_boundaries_to_shots_returns_ranges_in_ms() -> None:
    """Two boundaries split the 4 s duration into three contiguous segments."""
    boundaries = [
        ShotBoundary(frame_index=10, timestamp_ms=1000, confidence=0.9),
        ShotBoundary(frame_index=20, timestamp_ms=2500, confidence=0.8),
    ]
    shots = boundaries_to_shots(boundaries, duration_ms=4000, min_shot_duration_ms=100)
    assert shots == [
        ShotSegment(index=0, start_ms=0, end_ms=1000),
        ShotSegment(index=1, start_ms=1000, end_ms=2500),
        ShotSegment(index=2, start_ms=2500, end_ms=4000),
    ]


def test_merge_short_shots_merges_short_middle_segment() -> None:
    """A 200 ms middle shot is absorbed and the remaining shots renumbered."""
    shots = [
        ShotSegment(index=0, start_ms=0, end_ms=1000),
        ShotSegment(index=1, start_ms=1000, end_ms=1200),
        ShotSegment(index=2, start_ms=1200, end_ms=3000),
    ]
    merged = merge_short_shots(shots, min_shot_duration_ms=500)
    assert merged == [
        ShotSegment(index=0, start_ms=0, end_ms=1200),
        ShotSegment(index=1, start_ms=1200, end_ms=3000),
    ]
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ from pathlib import Path
5
+
6
+ from shot_detection.model_manager import ensure_default_model, get_default_model_path
7
+
8
+
9
class _FakeResponse:
    """Minimal stand-in for the object returned by urllib.request.urlopen."""

    def __init__(self, payload: bytes) -> None:
        self._buffer = io.BytesIO(payload)

    def read(self, size: int = -1) -> bytes:
        # Mirrors the chunked read() interface the downloader uses.
        return self._buffer.read(size)

    def __enter__(self) -> "_FakeResponse":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        return None


def test_get_default_model_path_uses_cache_dir(tmp_path: Path) -> None:
    """The model path is <cache_dir>/models/<default filename>."""
    model_path = get_default_model_path(tmp_path)
    assert model_path == tmp_path / "models" / "transnetv2_open_fp16.onnx"


def test_ensure_default_model_downloads_once(monkeypatch, tmp_path: Path) -> None:
    """The model is fetched on first use and served from the cache afterwards."""
    payload = b"fake-onnx-model"
    calls: list[str] = []

    def fake_urlopen(url: str, timeout: int = 60) -> _FakeResponse:
        calls.append(url)
        return _FakeResponse(payload)

    monkeypatch.setattr("shot_detection.model_manager.urllib.request.urlopen", fake_urlopen)

    model_path = ensure_default_model(cache_dir=tmp_path)
    assert model_path.is_file()
    assert model_path.read_bytes() == payload
    assert len(calls) == 1

    # Second call must hit the cached file without re-downloading.
    same_path = ensure_default_model(cache_dir=tmp_path)
    assert same_path == model_path
    assert len(calls) == 1