shot-detection 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shot_detection-0.1.0/.github/workflows/publish-pypi.yml +54 -0
- shot_detection-0.1.0/PKG-INFO +106 -0
- shot_detection-0.1.0/README.md +79 -0
- shot_detection-0.1.0/bin/publish-pypi +11 -0
- shot_detection-0.1.0/pyproject.toml +47 -0
- shot_detection-0.1.0/src/shot_detection/__init__.py +24 -0
- shot_detection-0.1.0/src/shot_detection/detector.py +207 -0
- shot_detection-0.1.0/src/shot_detection/errors.py +2 -0
- shot_detection-0.1.0/src/shot_detection/ffmpeg.py +281 -0
- shot_detection-0.1.0/src/shot_detection/inference.py +65 -0
- shot_detection-0.1.0/src/shot_detection/model_manager.py +75 -0
- shot_detection-0.1.0/src/shot_detection/py.typed +1 -0
- shot_detection-0.1.0/src/shot_detection/types.py +35 -0
- shot_detection-0.1.0/tests/test_detector.py +33 -0
- shot_detection-0.1.0/tests/test_model_manager.py +45 -0
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
name: Publish PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
workflow_dispatch:
|
|
5
|
+
release:
|
|
6
|
+
types: [published]
|
|
7
|
+
|
|
8
|
+
permissions:
|
|
9
|
+
contents: read
|
|
10
|
+
|
|
11
|
+
jobs:
|
|
12
|
+
build:
|
|
13
|
+
name: Build distributions
|
|
14
|
+
runs-on: ubuntu-latest
|
|
15
|
+
|
|
16
|
+
steps:
|
|
17
|
+
- uses: actions/checkout@v5
|
|
18
|
+
|
|
19
|
+
- uses: actions/setup-python@v5
|
|
20
|
+
with:
|
|
21
|
+
python-version: "3.12"
|
|
22
|
+
|
|
23
|
+
- name: Build distributions
|
|
24
|
+
run: |
|
|
25
|
+
python -m pip install --upgrade build twine
|
|
26
|
+
python -m build
|
|
27
|
+
python -m twine check dist/*
|
|
28
|
+
|
|
29
|
+
- name: Upload distributions
|
|
30
|
+
uses: actions/upload-artifact@v4
|
|
31
|
+
with:
|
|
32
|
+
name: release-dists
|
|
33
|
+
path: dist/
|
|
34
|
+
|
|
35
|
+
publish:
|
|
36
|
+
name: Publish to PyPI
|
|
37
|
+
runs-on: ubuntu-latest
|
|
38
|
+
needs: build
|
|
39
|
+
permissions:
|
|
40
|
+
id-token: write
|
|
41
|
+
|
|
42
|
+
environment:
|
|
43
|
+
name: pypi
|
|
44
|
+
url: https://pypi.org/project/shot-detection/
|
|
45
|
+
|
|
46
|
+
steps:
|
|
47
|
+
- name: Download distributions
|
|
48
|
+
uses: actions/download-artifact@v5
|
|
49
|
+
with:
|
|
50
|
+
name: release-dists
|
|
51
|
+
path: dist/
|
|
52
|
+
|
|
53
|
+
- name: Publish distributions to PyPI
|
|
54
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: shot-detection
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Standalone Python shot boundary detection package for splitting videos into shots
|
|
5
|
+
Project-URL: Homepage, https://github.com/Seeknetic/shot-detection-python
|
|
6
|
+
Project-URL: Repository, https://github.com/Seeknetic/shot-detection-python
|
|
7
|
+
Author: Seeknetic
|
|
8
|
+
License-Expression: Apache-2.0
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Topic :: Multimedia :: Video
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Classifier: Typing :: Typed
|
|
20
|
+
Requires-Python: >=3.9
|
|
21
|
+
Requires-Dist: numpy<3,>=1.24
|
|
22
|
+
Requires-Dist: onnxruntime<2,>=1.17
|
|
23
|
+
Requires-Dist: typing-extensions<5,>=4.9
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest<9,>=8; extra == 'dev'
|
|
26
|
+
Description-Content-Type: text/markdown
|
|
27
|
+
|
|
28
|
+
# Shot Detection Python
|
|
29
|
+
|
|
30
|
+
Standalone Python package for shot boundary detection.
|
|
31
|
+
|
|
32
|
+
It takes a video file, runs TransNetV2-style ONNX inference on low-resolution RGB frames, and returns shot ranges with millisecond timestamps.
|
|
33
|
+
|
|
34
|
+
Built by [Seeknetic](https://www.seeknetic.com/). If you want to make video shots searchable and enable more professional video understanding workflows, visit [seeknetic.com](https://www.seeknetic.com/).
|
|
35
|
+
|
|
36
|
+
## What it does
|
|
37
|
+
|
|
38
|
+
- Probes the input video with `ffprobe`
|
|
39
|
+
- Extracts low-resolution frames with `ffmpeg`
|
|
40
|
+
- Runs sliding-window ONNX inference
|
|
41
|
+
- Finds shot boundaries
|
|
42
|
+
- Converts them into `{start_ms, end_ms}` shot segments
|
|
43
|
+
|
|
44
|
+
## Installation
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
pip install shot-detection
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
Import from `shot_detection`:
|
|
51
|
+
|
|
52
|
+
```python
|
|
53
|
+
from shot_detection import ShotDetector
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
You also need:
|
|
57
|
+
|
|
58
|
+
- `ffmpeg`
|
|
59
|
+
- `ffprobe`
|
|
60
|
+
|
|
61
|
+
On first run, the package downloads the default model automatically:
|
|
62
|
+
|
|
63
|
+
- URL: `https://download.shotai.io/model/shot-detection/transnetv2_open_fp16.onnx`
|
|
64
|
+
- cache dir:
|
|
65
|
+
- Linux/macOS: `~/.cache/shot-detection/models/`
|
|
66
|
+
- Windows: `%LOCALAPPDATA%\\shot-detection\\models\\`
|
|
67
|
+
|
|
68
|
+
You can override the cache root with `SHOT_DETECTION_CACHE_DIR`.
|
|
69
|
+
|
|
70
|
+
## Usage
|
|
71
|
+
|
|
72
|
+
```python
|
|
73
|
+
from shot_detection import ShotDetector
|
|
74
|
+
|
|
75
|
+
detector = ShotDetector()
|
|
76
|
+
shots = detector.detect("/path/to/video.mp4")
|
|
77
|
+
|
|
78
|
+
for shot in shots:
|
|
79
|
+
print(shot.start_ms, shot.end_ms)
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
## Advanced usage
|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
from shot_detection import detect_shots
|
|
86
|
+
|
|
87
|
+
shots = detect_shots(
|
|
88
|
+
video_path="/path/to/video.mp4",
|
|
89
|
+
threshold=0.5,
|
|
90
|
+
min_shot_duration_ms=500,
|
|
91
|
+
)
|
|
92
|
+
```
|
|
93
|
+
|
|
94
|
+
## Custom model path
|
|
95
|
+
|
|
96
|
+
```python
|
|
97
|
+
from shot_detection import ShotDetector
|
|
98
|
+
|
|
99
|
+
detector = ShotDetector(model_path="/path/to/custom-transnetv2.onnx")
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
## Notes
|
|
103
|
+
|
|
104
|
+
- The package expects the ONNX model input to accept 100-frame windows at `48x27` RGB.
|
|
105
|
+
- `ffmpeg` decode is adaptive: it prefers system-native hardware decoding when available and falls back to software decoding automatically.
|
|
106
|
+
- CUDA is intentionally not part of the default decode plan.
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# Shot Detection Python
|
|
2
|
+
|
|
3
|
+
Standalone Python package for shot boundary detection.
|
|
4
|
+
|
|
5
|
+
It takes a video file, runs TransNetV2-style ONNX inference on low-resolution RGB frames, and returns shot ranges with millisecond timestamps.
|
|
6
|
+
|
|
7
|
+
Built by [Seeknetic](https://www.seeknetic.com/). If you want to make video shots searchable and enable more professional video understanding workflows, visit [seeknetic.com](https://www.seeknetic.com/).
|
|
8
|
+
|
|
9
|
+
## What it does
|
|
10
|
+
|
|
11
|
+
- Probes the input video with `ffprobe`
|
|
12
|
+
- Extracts low-resolution frames with `ffmpeg`
|
|
13
|
+
- Runs sliding-window ONNX inference
|
|
14
|
+
- Finds shot boundaries
|
|
15
|
+
- Converts them into `{start_ms, end_ms}` shot segments
|
|
16
|
+
|
|
17
|
+
## Installation
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install shot-detection
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
Import from `shot_detection`:
|
|
24
|
+
|
|
25
|
+
```python
|
|
26
|
+
from shot_detection import ShotDetector
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
You also need:
|
|
30
|
+
|
|
31
|
+
- `ffmpeg`
|
|
32
|
+
- `ffprobe`
|
|
33
|
+
|
|
34
|
+
On first run, the package downloads the default model automatically:
|
|
35
|
+
|
|
36
|
+
- URL: `https://download.shotai.io/model/shot-detection/transnetv2_open_fp16.onnx`
|
|
37
|
+
- cache dir:
|
|
38
|
+
- Linux/macOS: `~/.cache/shot-detection/models/`
|
|
39
|
+
- Windows: `%LOCALAPPDATA%\\shot-detection\\models\\`
|
|
40
|
+
|
|
41
|
+
You can override the cache root with `SHOT_DETECTION_CACHE_DIR`.
|
|
42
|
+
|
|
43
|
+
## Usage
|
|
44
|
+
|
|
45
|
+
```python
|
|
46
|
+
from shot_detection import ShotDetector
|
|
47
|
+
|
|
48
|
+
detector = ShotDetector()
|
|
49
|
+
shots = detector.detect("/path/to/video.mp4")
|
|
50
|
+
|
|
51
|
+
for shot in shots:
|
|
52
|
+
print(shot.start_ms, shot.end_ms)
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Advanced usage
|
|
56
|
+
|
|
57
|
+
```python
|
|
58
|
+
from shot_detection import detect_shots
|
|
59
|
+
|
|
60
|
+
shots = detect_shots(
|
|
61
|
+
video_path="/path/to/video.mp4",
|
|
62
|
+
threshold=0.5,
|
|
63
|
+
min_shot_duration_ms=500,
|
|
64
|
+
)
|
|
65
|
+
```
|
|
66
|
+
|
|
67
|
+
## Custom model path
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
from shot_detection import ShotDetector
|
|
71
|
+
|
|
72
|
+
detector = ShotDetector(model_path="/path/to/custom-transnetv2.onnx")
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Notes
|
|
76
|
+
|
|
77
|
+
- The package expects the ONNX model input to accept 100-frame windows at `48x27` RGB.
|
|
78
|
+
- `ffmpeg` decode is adaptive: it prefers system-native hardware decoding when available and falls back to software decoding automatically.
|
|
79
|
+
- CUDA is intentionally not part of the default decode plan.
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "shot-detection"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Standalone Python shot boundary detection package for splitting videos into shots"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
license = "Apache-2.0"
|
|
7
|
+
authors = [
|
|
8
|
+
{ name = "Seeknetic" },
|
|
9
|
+
]
|
|
10
|
+
requires-python = ">=3.9"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"numpy>=1.24,<3",
|
|
13
|
+
"onnxruntime>=1.17,<2",
|
|
14
|
+
"typing-extensions>=4.9,<5",
|
|
15
|
+
]
|
|
16
|
+
classifiers = [
|
|
17
|
+
"Typing :: Typed",
|
|
18
|
+
"Intended Audience :: Developers",
|
|
19
|
+
"Programming Language :: Python :: 3",
|
|
20
|
+
"Programming Language :: Python :: 3.9",
|
|
21
|
+
"Programming Language :: Python :: 3.10",
|
|
22
|
+
"Programming Language :: Python :: 3.11",
|
|
23
|
+
"Programming Language :: Python :: 3.12",
|
|
24
|
+
"Operating System :: OS Independent",
|
|
25
|
+
"License :: OSI Approved :: Apache Software License",
|
|
26
|
+
"Topic :: Multimedia :: Video",
|
|
27
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
[project.urls]
|
|
31
|
+
Homepage = "https://github.com/Seeknetic/shot-detection-python"
|
|
32
|
+
Repository = "https://github.com/Seeknetic/shot-detection-python"
|
|
33
|
+
|
|
34
|
+
[project.optional-dependencies]
|
|
35
|
+
dev = [
|
|
36
|
+
"pytest>=8,<9",
|
|
37
|
+
]
|
|
38
|
+
|
|
39
|
+
[build-system]
|
|
40
|
+
requires = ["hatchling>=1.26,<2"]
|
|
41
|
+
build-backend = "hatchling.build"
|
|
42
|
+
|
|
43
|
+
[tool.hatch.build.targets.wheel]
|
|
44
|
+
packages = ["src/shot_detection"]
|
|
45
|
+
|
|
46
|
+
[tool.pytest.ini_options]
|
|
47
|
+
testpaths = ["tests"]
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from .detector import ShotDetector, detect_shots, detect_shots_detailed
|
|
2
|
+
from .errors import ShotDetectionError
|
|
3
|
+
from .model_manager import (
|
|
4
|
+
DEFAULT_MODEL_URL,
|
|
5
|
+
ensure_default_model,
|
|
6
|
+
get_default_cache_dir,
|
|
7
|
+
get_default_model_path,
|
|
8
|
+
)
|
|
9
|
+
from .types import DetectionResult, ShotBoundary, ShotSegment, VideoMetadata
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"ShotDetector",
|
|
13
|
+
"ShotDetectionError",
|
|
14
|
+
"DEFAULT_MODEL_URL",
|
|
15
|
+
"ShotBoundary",
|
|
16
|
+
"ShotSegment",
|
|
17
|
+
"VideoMetadata",
|
|
18
|
+
"DetectionResult",
|
|
19
|
+
"ensure_default_model",
|
|
20
|
+
"get_default_cache_dir",
|
|
21
|
+
"get_default_model_path",
|
|
22
|
+
"detect_shots",
|
|
23
|
+
"detect_shots_detailed",
|
|
24
|
+
]
|
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from .errors import ShotDetectionError
|
|
9
|
+
from .ffmpeg import MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH, extract_low_res_rgb_frames, probe_video
|
|
10
|
+
from .inference import MODEL_INPUT_CHANNELS, WINDOW_SIZE, TransNetV2ONNX
|
|
11
|
+
from .model_manager import ensure_default_model
|
|
12
|
+
from .types import DetectionResult, ShotBoundary, ShotSegment
|
|
13
|
+
|
|
14
|
+
WINDOW_STEP = WINDOW_SIZE - 10
|
|
15
|
+
DEFAULT_THRESHOLD = 0.5
|
|
16
|
+
DEFAULT_MIN_SHOT_DURATION_MS = 500
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ShotDetector:
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
*,
|
|
23
|
+
model_path: str | Path | None = None,
|
|
24
|
+
threshold: float = DEFAULT_THRESHOLD,
|
|
25
|
+
min_shot_duration_ms: int = DEFAULT_MIN_SHOT_DURATION_MS,
|
|
26
|
+
model_cache_dir: str | Path | None = None,
|
|
27
|
+
) -> None:
|
|
28
|
+
resolved_model_path = ensure_default_model(model_path=model_path, cache_dir=model_cache_dir)
|
|
29
|
+
self.model = TransNetV2ONNX(resolved_model_path)
|
|
30
|
+
self.threshold = threshold
|
|
31
|
+
self.min_shot_duration_ms = min_shot_duration_ms
|
|
32
|
+
|
|
33
|
+
def detect(self, video_path: str | Path) -> list[ShotSegment]:
|
|
34
|
+
return self.detect_detailed(video_path).shots
|
|
35
|
+
|
|
36
|
+
def detect_detailed(self, video_path: str | Path) -> DetectionResult:
|
|
37
|
+
metadata = probe_video(video_path)
|
|
38
|
+
raw_frames = extract_low_res_rgb_frames(video_path)
|
|
39
|
+
frame_size = MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT * MODEL_INPUT_CHANNELS
|
|
40
|
+
total_frames = len(raw_frames) // frame_size
|
|
41
|
+
|
|
42
|
+
if total_frames <= 0:
|
|
43
|
+
raise ShotDetectionError("No frames were extracted from the input video")
|
|
44
|
+
|
|
45
|
+
frames = np.frombuffer(raw_frames[: total_frames * frame_size], dtype=np.uint8).reshape(
|
|
46
|
+
total_frames,
|
|
47
|
+
MODEL_INPUT_HEIGHT,
|
|
48
|
+
MODEL_INPUT_WIDTH,
|
|
49
|
+
MODEL_INPUT_CHANNELS,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
all_probs: list[float] = []
|
|
53
|
+
|
|
54
|
+
for start in range(0, total_frames, WINDOW_STEP):
|
|
55
|
+
end = min(start + WINDOW_SIZE, total_frames)
|
|
56
|
+
window_frames = end - start
|
|
57
|
+
if window_frames < 10:
|
|
58
|
+
break
|
|
59
|
+
|
|
60
|
+
buffer = np.zeros(
|
|
61
|
+
(1, WINDOW_SIZE, MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH, MODEL_INPUT_CHANNELS),
|
|
62
|
+
dtype=np.uint8,
|
|
63
|
+
)
|
|
64
|
+
buffer[0, :window_frames] = frames[start:end]
|
|
65
|
+
|
|
66
|
+
probs = self.model.infer(buffer)
|
|
67
|
+
for offset in range(window_frames):
|
|
68
|
+
frame_index = start + offset
|
|
69
|
+
if frame_index >= len(all_probs):
|
|
70
|
+
all_probs.append(_to_boundary_probability(float(probs[offset])))
|
|
71
|
+
|
|
72
|
+
peak_indices = find_peaks(np.asarray(all_probs, dtype=np.float32), self.threshold)
|
|
73
|
+
boundaries = [
|
|
74
|
+
ShotBoundary(
|
|
75
|
+
frame_index=frame_index,
|
|
76
|
+
timestamp_ms=int(round((frame_index / metadata.fps) * 1000)),
|
|
77
|
+
confidence=float(all_probs[frame_index]),
|
|
78
|
+
)
|
|
79
|
+
for frame_index in peak_indices
|
|
80
|
+
]
|
|
81
|
+
shots = boundaries_to_shots(boundaries, metadata.duration_ms, self.min_shot_duration_ms)
|
|
82
|
+
|
|
83
|
+
return DetectionResult(
|
|
84
|
+
boundaries=boundaries,
|
|
85
|
+
shots=shots,
|
|
86
|
+
total_frames=total_frames,
|
|
87
|
+
fps=metadata.fps,
|
|
88
|
+
duration_ms=metadata.duration_ms,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def detect_shots(
|
|
93
|
+
*,
|
|
94
|
+
video_path: str | Path,
|
|
95
|
+
model_path: str | Path | None = None,
|
|
96
|
+
threshold: float = DEFAULT_THRESHOLD,
|
|
97
|
+
min_shot_duration_ms: int = DEFAULT_MIN_SHOT_DURATION_MS,
|
|
98
|
+
model_cache_dir: str | Path | None = None,
|
|
99
|
+
) -> list[ShotSegment]:
|
|
100
|
+
detector = ShotDetector(
|
|
101
|
+
model_path=model_path,
|
|
102
|
+
threshold=threshold,
|
|
103
|
+
min_shot_duration_ms=min_shot_duration_ms,
|
|
104
|
+
model_cache_dir=model_cache_dir,
|
|
105
|
+
)
|
|
106
|
+
return detector.detect(video_path)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def detect_shots_detailed(
|
|
110
|
+
*,
|
|
111
|
+
video_path: str | Path,
|
|
112
|
+
model_path: str | Path | None = None,
|
|
113
|
+
threshold: float = DEFAULT_THRESHOLD,
|
|
114
|
+
min_shot_duration_ms: int = DEFAULT_MIN_SHOT_DURATION_MS,
|
|
115
|
+
model_cache_dir: str | Path | None = None,
|
|
116
|
+
) -> DetectionResult:
|
|
117
|
+
detector = ShotDetector(
|
|
118
|
+
model_path=model_path,
|
|
119
|
+
threshold=threshold,
|
|
120
|
+
min_shot_duration_ms=min_shot_duration_ms,
|
|
121
|
+
model_cache_dir=model_cache_dir,
|
|
122
|
+
)
|
|
123
|
+
return detector.detect_detailed(video_path)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _to_boundary_probability(value: float) -> float:
|
|
127
|
+
if 0.0 <= value <= 1.0:
|
|
128
|
+
return value
|
|
129
|
+
return 1.0 / (1.0 + math.exp(-value))
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def find_peaks(probs: np.ndarray, threshold: float, min_distance: int = 10) -> list[int]:
|
|
133
|
+
peaks: list[int] = []
|
|
134
|
+
|
|
135
|
+
for index in range(1, len(probs) - 1):
|
|
136
|
+
current = float(probs[index])
|
|
137
|
+
previous = float(probs[index - 1])
|
|
138
|
+
following = float(probs[index + 1])
|
|
139
|
+
|
|
140
|
+
if current > threshold and current > previous and current > following:
|
|
141
|
+
if not peaks or index - peaks[-1] >= min_distance:
|
|
142
|
+
peaks.append(index)
|
|
143
|
+
elif current > float(probs[peaks[-1]]):
|
|
144
|
+
peaks[-1] = index
|
|
145
|
+
|
|
146
|
+
return peaks
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def boundaries_to_shots(
|
|
150
|
+
boundaries: list[ShotBoundary],
|
|
151
|
+
duration_ms: int,
|
|
152
|
+
min_shot_duration_ms: int = DEFAULT_MIN_SHOT_DURATION_MS,
|
|
153
|
+
) -> list[ShotSegment]:
|
|
154
|
+
start_points = [0, *[boundary.timestamp_ms for boundary in boundaries]]
|
|
155
|
+
end_points = [*[boundary.timestamp_ms for boundary in boundaries], duration_ms]
|
|
156
|
+
|
|
157
|
+
shots = [
|
|
158
|
+
ShotSegment(index=index, start_ms=start_points[index], end_ms=end_points[index])
|
|
159
|
+
for index in range(len(start_points))
|
|
160
|
+
]
|
|
161
|
+
return merge_short_shots(shots, min_shot_duration_ms)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def merge_short_shots(shots: list[ShotSegment], min_shot_duration_ms: int) -> list[ShotSegment]:
|
|
165
|
+
if len(shots) <= 1:
|
|
166
|
+
return shots
|
|
167
|
+
|
|
168
|
+
merged = [
|
|
169
|
+
{"index": shot.index, "start_ms": shot.start_ms, "end_ms": shot.end_ms}
|
|
170
|
+
for shot in shots
|
|
171
|
+
]
|
|
172
|
+
|
|
173
|
+
has_short = True
|
|
174
|
+
while has_short:
|
|
175
|
+
has_short = False
|
|
176
|
+
next_shots: list[dict[str, int]] = []
|
|
177
|
+
|
|
178
|
+
for index, shot in enumerate(merged):
|
|
179
|
+
duration = shot["end_ms"] - shot["start_ms"]
|
|
180
|
+
if duration >= min_shot_duration_ms:
|
|
181
|
+
next_shots.append(shot)
|
|
182
|
+
continue
|
|
183
|
+
|
|
184
|
+
has_short = True
|
|
185
|
+
previous = next_shots[-1] if next_shots else None
|
|
186
|
+
following = merged[index + 1] if index + 1 < len(merged) else None
|
|
187
|
+
|
|
188
|
+
if previous is None and following is not None:
|
|
189
|
+
following["start_ms"] = shot["start_ms"]
|
|
190
|
+
elif previous is not None and following is None:
|
|
191
|
+
previous["end_ms"] = shot["end_ms"]
|
|
192
|
+
elif previous is not None and following is not None:
|
|
193
|
+
previous_duration = previous["end_ms"] - previous["start_ms"]
|
|
194
|
+
following_duration = following["end_ms"] - following["start_ms"]
|
|
195
|
+
if previous_duration <= following_duration:
|
|
196
|
+
previous["end_ms"] = shot["end_ms"]
|
|
197
|
+
else:
|
|
198
|
+
following["start_ms"] = shot["start_ms"]
|
|
199
|
+
|
|
200
|
+
merged = next_shots
|
|
201
|
+
if len(merged) <= 1:
|
|
202
|
+
break
|
|
203
|
+
|
|
204
|
+
return [
|
|
205
|
+
ShotSegment(index=index, start_ms=shot["start_ms"], end_ms=shot["end_ms"])
|
|
206
|
+
for index, shot in enumerate(merged)
|
|
207
|
+
]
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import shutil
|
|
5
|
+
import subprocess
|
|
6
|
+
import sys
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
from .errors import ShotDetectionError
|
|
11
|
+
from .types import VideoMetadata
|
|
12
|
+
|
|
13
|
+
MODEL_INPUT_WIDTH = 48
|
|
14
|
+
MODEL_INPUT_HEIGHT = 27
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True)
|
|
18
|
+
class FfmpegRuntimeCapabilities:
|
|
19
|
+
hwaccels: set[str]
|
|
20
|
+
decoders: set[str]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True)
|
|
24
|
+
class FfmpegInputPlan:
|
|
25
|
+
label: str
|
|
26
|
+
input_options: list[str]
|
|
27
|
+
uses_hardware_decode: bool
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def resolve_binary(name: str) -> str:
|
|
31
|
+
binary = shutil.which(name)
|
|
32
|
+
if binary is None:
|
|
33
|
+
raise ShotDetectionError(
|
|
34
|
+
f"Required binary '{name}' was not found on PATH. Please install FFmpeg before running shot detection."
|
|
35
|
+
)
|
|
36
|
+
return binary
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def run_command(command: list[str]) -> subprocess.CompletedProcess[bytes]:
|
|
40
|
+
try:
|
|
41
|
+
return subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
|
42
|
+
except subprocess.CalledProcessError as exc:
|
|
43
|
+
message = exc.stderr.decode("utf-8", errors="replace").strip() or str(exc)
|
|
44
|
+
raise ShotDetectionError(message) from exc
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _parse_ffmpeg_list(output: str) -> set[str]:
|
|
48
|
+
values: set[str] = set()
|
|
49
|
+
for line in output.splitlines():
|
|
50
|
+
trimmed = line.strip()
|
|
51
|
+
if not trimmed:
|
|
52
|
+
continue
|
|
53
|
+
|
|
54
|
+
parts = trimmed.split()
|
|
55
|
+
if len(parts) >= 2 and parts[0].isupper():
|
|
56
|
+
values.add(parts[1].lower())
|
|
57
|
+
elif trimmed.isidentifier():
|
|
58
|
+
values.add(trimmed.lower())
|
|
59
|
+
|
|
60
|
+
return values
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def get_runtime_capabilities(ffmpeg_path: str) -> FfmpegRuntimeCapabilities:
|
|
64
|
+
hwaccels = _parse_ffmpeg_list(run_command([ffmpeg_path, "-hide_banner", "-hwaccels"]).stdout.decode("utf-8"))
|
|
65
|
+
decoders = _parse_ffmpeg_list(run_command([ffmpeg_path, "-hide_banner", "-decoders"]).stdout.decode("utf-8"))
|
|
66
|
+
return FfmpegRuntimeCapabilities(hwaccels=hwaccels, decoders=decoders)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _linux_vaapi_options() -> list[str] | None:
|
|
70
|
+
dri_dir = Path("/dev/dri")
|
|
71
|
+
if not dri_dir.exists():
|
|
72
|
+
return None
|
|
73
|
+
|
|
74
|
+
render_nodes = sorted(dri_dir.glob("renderD*"))
|
|
75
|
+
if not render_nodes:
|
|
76
|
+
return None
|
|
77
|
+
|
|
78
|
+
return [
|
|
79
|
+
"-threads",
|
|
80
|
+
"1",
|
|
81
|
+
"-hwaccel",
|
|
82
|
+
"vaapi",
|
|
83
|
+
"-hwaccel_output_format",
|
|
84
|
+
"vaapi",
|
|
85
|
+
"-vaapi_device",
|
|
86
|
+
str(render_nodes[0]),
|
|
87
|
+
]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_decode_plans(platform_name: str | None = None) -> list[FfmpegInputPlan]:
|
|
91
|
+
ffmpeg_path = resolve_binary("ffmpeg")
|
|
92
|
+
capabilities = get_runtime_capabilities(ffmpeg_path)
|
|
93
|
+
platform_name = platform_name or sys.platform
|
|
94
|
+
|
|
95
|
+
plans: list[FfmpegInputPlan] = []
|
|
96
|
+
|
|
97
|
+
if platform_name == "darwin" and "videotoolbox" in capabilities.hwaccels:
|
|
98
|
+
plans.append(
|
|
99
|
+
FfmpegInputPlan(
|
|
100
|
+
label="videotoolbox-hardware",
|
|
101
|
+
input_options=["-threads", "1", "-hwaccel", "videotoolbox"],
|
|
102
|
+
uses_hardware_decode=True,
|
|
103
|
+
)
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
if platform_name == "win32":
|
|
107
|
+
if "d3d12va" in capabilities.hwaccels:
|
|
108
|
+
plans.append(
|
|
109
|
+
FfmpegInputPlan(
|
|
110
|
+
label="d3d12va-hardware",
|
|
111
|
+
input_options=["-threads", "1", "-hwaccel", "d3d12va", "-hwaccel_output_format", "d3d12"],
|
|
112
|
+
uses_hardware_decode=True,
|
|
113
|
+
)
|
|
114
|
+
)
|
|
115
|
+
if "d3d11va" in capabilities.hwaccels:
|
|
116
|
+
plans.append(
|
|
117
|
+
FfmpegInputPlan(
|
|
118
|
+
label="d3d11va-hardware",
|
|
119
|
+
input_options=["-threads", "1", "-hwaccel", "d3d11va", "-hwaccel_output_format", "d3d11"],
|
|
120
|
+
uses_hardware_decode=True,
|
|
121
|
+
)
|
|
122
|
+
)
|
|
123
|
+
if "dxva2" in capabilities.hwaccels:
|
|
124
|
+
plans.append(
|
|
125
|
+
FfmpegInputPlan(
|
|
126
|
+
label="dxva2-hardware",
|
|
127
|
+
input_options=["-threads", "1", "-hwaccel", "dxva2", "-hwaccel_output_format", "dxva2_vld"],
|
|
128
|
+
uses_hardware_decode=True,
|
|
129
|
+
)
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
if platform_name.startswith("linux") and "vaapi" in capabilities.hwaccels:
|
|
133
|
+
vaapi_options = _linux_vaapi_options()
|
|
134
|
+
if vaapi_options is not None:
|
|
135
|
+
plans.append(
|
|
136
|
+
FfmpegInputPlan(
|
|
137
|
+
label="vaapi-hardware",
|
|
138
|
+
input_options=vaapi_options,
|
|
139
|
+
uses_hardware_decode=True,
|
|
140
|
+
)
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
plans.append(
|
|
144
|
+
FfmpegInputPlan(
|
|
145
|
+
label="software-fallback",
|
|
146
|
+
input_options=["-threads", "1"],
|
|
147
|
+
uses_hardware_decode=False,
|
|
148
|
+
)
|
|
149
|
+
)
|
|
150
|
+
return plans
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _requires_hwdownload(plan: FfmpegInputPlan, platform_name: str | None = None) -> bool:
|
|
154
|
+
platform_name = platform_name or sys.platform
|
|
155
|
+
if platform_name == "darwin" and plan.label == "videotoolbox-hardware":
|
|
156
|
+
return False
|
|
157
|
+
return plan.uses_hardware_decode
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def probe_video(video_path: str | Path) -> VideoMetadata:
|
|
161
|
+
ffprobe_path = resolve_binary("ffprobe")
|
|
162
|
+
result = run_command(
|
|
163
|
+
[
|
|
164
|
+
ffprobe_path,
|
|
165
|
+
"-v",
|
|
166
|
+
"error",
|
|
167
|
+
"-show_streams",
|
|
168
|
+
"-show_format",
|
|
169
|
+
"-print_format",
|
|
170
|
+
"json",
|
|
171
|
+
str(video_path),
|
|
172
|
+
]
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
try:
|
|
176
|
+
payload = json.loads(result.stdout.decode("utf-8"))
|
|
177
|
+
except json.JSONDecodeError as exc:
|
|
178
|
+
raise ShotDetectionError("Failed to parse ffprobe output") from exc
|
|
179
|
+
|
|
180
|
+
streams = payload.get("streams")
|
|
181
|
+
if not isinstance(streams, list):
|
|
182
|
+
raise ShotDetectionError("ffprobe did not return stream metadata")
|
|
183
|
+
|
|
184
|
+
video_stream = next((stream for stream in streams if stream.get("codec_type") == "video"), None)
|
|
185
|
+
if not isinstance(video_stream, dict):
|
|
186
|
+
raise ShotDetectionError("No video stream found in the input file")
|
|
187
|
+
|
|
188
|
+
width = int(video_stream.get("width") or 0)
|
|
189
|
+
height = int(video_stream.get("height") or 0)
|
|
190
|
+
fps = _parse_frame_rate(video_stream.get("avg_frame_rate") or video_stream.get("r_frame_rate"))
|
|
191
|
+
duration_seconds = _parse_duration_seconds(payload.get("format"), video_stream)
|
|
192
|
+
|
|
193
|
+
if width <= 0 or height <= 0 or fps <= 0 or duration_seconds <= 0:
|
|
194
|
+
raise ShotDetectionError("Unable to determine valid video metadata")
|
|
195
|
+
|
|
196
|
+
return VideoMetadata(
|
|
197
|
+
duration_ms=int(round(duration_seconds * 1000)),
|
|
198
|
+
width=width,
|
|
199
|
+
height=height,
|
|
200
|
+
fps=fps,
|
|
201
|
+
codec=video_stream.get("codec_name"),
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _parse_frame_rate(value: object) -> float:
|
|
206
|
+
if not isinstance(value, str) or not value:
|
|
207
|
+
return 0.0
|
|
208
|
+
|
|
209
|
+
if "/" in value:
|
|
210
|
+
numerator_text, denominator_text = value.split("/", 1)
|
|
211
|
+
try:
|
|
212
|
+
numerator = float(numerator_text)
|
|
213
|
+
denominator = float(denominator_text)
|
|
214
|
+
except ValueError:
|
|
215
|
+
return 0.0
|
|
216
|
+
return 0.0 if denominator == 0 else numerator / denominator
|
|
217
|
+
|
|
218
|
+
try:
|
|
219
|
+
return float(value)
|
|
220
|
+
except ValueError:
|
|
221
|
+
return 0.0
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _parse_duration_seconds(format_info: object, video_stream: dict[str, object]) -> float:
|
|
225
|
+
if isinstance(format_info, dict):
|
|
226
|
+
try:
|
|
227
|
+
duration = float(format_info.get("duration") or 0.0)
|
|
228
|
+
except (TypeError, ValueError):
|
|
229
|
+
duration = 0.0
|
|
230
|
+
if duration > 0:
|
|
231
|
+
return duration
|
|
232
|
+
|
|
233
|
+
try:
|
|
234
|
+
return float(video_stream.get("duration") or 0.0)
|
|
235
|
+
except (TypeError, ValueError):
|
|
236
|
+
return 0.0
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def extract_low_res_rgb_frames(
|
|
240
|
+
video_path: str | Path,
|
|
241
|
+
*,
|
|
242
|
+
width: int = MODEL_INPUT_WIDTH,
|
|
243
|
+
height: int = MODEL_INPUT_HEIGHT,
|
|
244
|
+
) -> bytes:
|
|
245
|
+
ffmpeg_path = resolve_binary("ffmpeg")
|
|
246
|
+
last_error: ShotDetectionError | None = None
|
|
247
|
+
|
|
248
|
+
for plan in get_decode_plans():
|
|
249
|
+
scale_filter = f"scale={width}:{height}:flags=fast_bilinear"
|
|
250
|
+
if _requires_hwdownload(plan):
|
|
251
|
+
filter_chain = f"hwdownload,format=nv12,{scale_filter},format=rgb24"
|
|
252
|
+
else:
|
|
253
|
+
filter_chain = f"{scale_filter},format=rgb24"
|
|
254
|
+
|
|
255
|
+
command = [
|
|
256
|
+
ffmpeg_path,
|
|
257
|
+
"-v",
|
|
258
|
+
"error",
|
|
259
|
+
*plan.input_options,
|
|
260
|
+
"-i",
|
|
261
|
+
str(video_path),
|
|
262
|
+
"-vf",
|
|
263
|
+
filter_chain,
|
|
264
|
+
"-pix_fmt",
|
|
265
|
+
"rgb24",
|
|
266
|
+
"-vsync",
|
|
267
|
+
"0",
|
|
268
|
+
"-f",
|
|
269
|
+
"rawvideo",
|
|
270
|
+
"pipe:1",
|
|
271
|
+
]
|
|
272
|
+
|
|
273
|
+
try:
|
|
274
|
+
return run_command(command).stdout
|
|
275
|
+
except ShotDetectionError as exc:
|
|
276
|
+
last_error = exc
|
|
277
|
+
|
|
278
|
+
if last_error is not None:
|
|
279
|
+
raise last_error
|
|
280
|
+
|
|
281
|
+
raise ShotDetectionError("No FFmpeg decode plan was available")
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import onnxruntime as ort
|
|
7
|
+
|
|
8
|
+
from .errors import ShotDetectionError
|
|
9
|
+
|
|
10
|
+
WINDOW_SIZE = 100
|
|
11
|
+
MODEL_INPUT_WIDTH = 48
|
|
12
|
+
MODEL_INPUT_HEIGHT = 27
|
|
13
|
+
MODEL_INPUT_CHANNELS = 3
|
|
14
|
+
MODEL_INPUT_SIZE = WINDOW_SIZE * MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT * MODEL_INPUT_CHANNELS
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TransNetV2ONNX:
    """Thin ONNX Runtime wrapper around the TransNetV2 shot-boundary model."""

    def __init__(self, model_path: str | Path) -> None:
        self.model_path = str(model_path)
        self._session = self._create_session(self.model_path)
        model_inputs = self._session.get_inputs()
        self._input_name = model_inputs[0].name
        self._output_name = self._session.get_outputs()[0].name
        self._input_type = model_inputs[0].type

    def _create_session(self, model_path: str) -> ort.InferenceSession:
        """Build an inference session, preferring accelerated providers when present.

        Raises ShotDetectionError if the model file is missing, no execution
        provider is available, or session construction fails.
        """
        if not Path(model_path).is_file():
            raise ShotDetectionError(f"Model file not found: {model_path}")

        available = ort.get_available_providers()
        # Keep only providers we know how to rank, in preference order.
        chosen = [
            name
            for name in ("CoreMLExecutionProvider", "DmlExecutionProvider", "CPUExecutionProvider")
            if name in available
        ]
        if not chosen:
            chosen = available
        if not chosen:
            raise ShotDetectionError("No ONNX Runtime execution provider is available")

        try:
            return ort.InferenceSession(model_path, providers=chosen)
        except Exception as exc:  # pragma: no cover
            raise ShotDetectionError(f"Failed to initialize ONNX Runtime session: {exc}") from exc

    def infer(self, frames: np.ndarray) -> np.ndarray:
        """Run the model on one window of frames and return per-frame scores.

        Expects *frames* shaped (1, WINDOW_SIZE, H, W, C); returns a flat array
        of exactly WINDOW_SIZE scores.
        """
        expected = (1, WINDOW_SIZE, MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH, MODEL_INPUT_CHANNELS)
        if frames.shape != expected:
            raise ShotDetectionError(f"Unexpected frame tensor shape: {frames.shape}")

        # Match the model's declared input dtype, avoiding a copy when possible.
        target_dtype = np.uint8 if self._input_type == "tensor(uint8)" else np.float32
        input_data = frames.astype(target_dtype, copy=False)

        try:
            outputs = self._session.run([self._output_name], {self._input_name: input_data})
        except Exception as exc:  # pragma: no cover
            raise ShotDetectionError(f"ONNX inference failed: {exc}") from exc

        scores = np.asarray(outputs[0]).reshape(-1)
        if scores.size < WINDOW_SIZE:
            raise ShotDetectionError(f"Model output is too short: {scores.size}")

        return scores[:WINDOW_SIZE]
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import tempfile
|
|
5
|
+
import urllib.request
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .errors import ShotDetectionError
|
|
9
|
+
|
|
10
|
+
# Remote location of the default TransNetV2 ONNX model (fp16 export).
DEFAULT_MODEL_URL = "https://download.shotai.io/model/shot-detection/transnetv2_open_fp16.onnx"
# File name the downloaded model is stored under inside the cache.
DEFAULT_MODEL_FILENAME = "transnetv2_open_fp16.onnx"
# Subdirectory name used under the platform cache root.
DEFAULT_CACHE_DIRNAME = "shot-detection"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_default_cache_dir() -> Path:
|
|
16
|
+
override = os.environ.get("SHOT_DETECTION_CACHE_DIR")
|
|
17
|
+
if override:
|
|
18
|
+
return Path(override).expanduser()
|
|
19
|
+
|
|
20
|
+
if os.name == "nt":
|
|
21
|
+
local_app_data = os.environ.get("LOCALAPPDATA")
|
|
22
|
+
if local_app_data:
|
|
23
|
+
return Path(local_app_data) / DEFAULT_CACHE_DIRNAME
|
|
24
|
+
|
|
25
|
+
xdg_cache_home = os.environ.get("XDG_CACHE_HOME")
|
|
26
|
+
if xdg_cache_home:
|
|
27
|
+
return Path(xdg_cache_home) / DEFAULT_CACHE_DIRNAME
|
|
28
|
+
|
|
29
|
+
return Path.home() / ".cache" / DEFAULT_CACHE_DIRNAME
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_default_model_path(cache_dir: str | Path | None = None) -> Path:
    """Return the on-disk location of the default model inside the cache tree."""
    if cache_dir is None:
        base_dir = get_default_cache_dir()
    else:
        base_dir = Path(cache_dir).expanduser()
    return base_dir / "models" / DEFAULT_MODEL_FILENAME
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def ensure_default_model(
    *,
    model_path: str | Path | None = None,
    cache_dir: str | Path | None = None,
    model_url: str = DEFAULT_MODEL_URL,
    timeout_seconds: int = 60,
) -> Path:
    """Return a usable local model path, downloading the default model on first use.

    The download is staged in a temp file next to the destination and moved into
    place with os.replace, so callers never observe a partially written model.
    Raises ShotDetectionError if the download fails or yields an empty file.
    """
    if model_path is not None:
        destination = Path(model_path).expanduser()
    else:
        destination = get_default_model_path(cache_dir)
    if destination.is_file():
        return destination

    destination.parent.mkdir(parents=True, exist_ok=True)
    handle = tempfile.NamedTemporaryFile(
        prefix=destination.stem + ".",
        suffix=".tmp",
        dir=destination.parent,
        delete=False,
    )
    staging = Path(handle.name)
    handle.close()

    try:
        with urllib.request.urlopen(model_url, timeout=timeout_seconds) as response:
            with staging.open("wb") as sink:
                # Stream in 1 MiB chunks to keep memory bounded.
                while chunk := response.read(1024 * 1024):
                    sink.write(chunk)

        if staging.stat().st_size == 0:
            raise ShotDetectionError(f"Downloaded model is empty: {model_url}")

        os.replace(staging, destination)
        return destination
    except Exception as exc:
        staging.unlink(missing_ok=True)
        raise ShotDetectionError(
            f"Failed to download the default shot detection model from {model_url}: {exc}"
        ) from exc
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass(frozen=True)
class VideoMetadata:
    """Basic properties of a probed video stream."""

    duration_ms: int  # total duration in milliseconds
    width: int  # frame width in pixels
    height: int  # frame height in pixels
    fps: float  # frames per second
    codec: str | None = None  # codec name when known, else None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class ShotBoundary:
    """A detected cut between two shots."""

    frame_index: int  # frame at which the boundary occurs
    timestamp_ms: int  # boundary position in milliseconds
    confidence: float  # model confidence score for this boundary
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class ShotSegment:
    """A contiguous shot; adjacent shots share their boundary timestamp."""

    index: int  # zero-based position of this shot in the sequence
    start_ms: int  # shot start time in milliseconds
    end_ms: int  # shot end time in milliseconds
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass(frozen=True)
class DetectionResult:
    """Full output of a shot-detection run."""

    boundaries: list[ShotBoundary]  # detected cuts
    shots: list[ShotSegment]  # shot segments derived from the boundaries
    total_frames: int  # number of frames in the analyzed video
    fps: float  # frames per second of the analyzed video
    duration_ms: int  # analyzed duration in milliseconds
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import numpy as np

from shot_detection.detector import boundaries_to_shots, find_peaks, merge_short_shots
from shot_detection.types import ShotBoundary, ShotSegment
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def test_find_peaks_prefers_stronger_peak_when_too_close() -> None:
    """When two peaks fall within min_distance, only the stronger one survives."""
    # Use a proper top-level numpy import instead of the inline __import__ hack.
    scores = np.array([0.0, 0.7, 0.0, 0.8, 0.0], dtype="float32")
    peaks = find_peaks(scores, 0.5, min_distance=10)
    assert peaks == [3]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def test_boundaries_to_shots_returns_ranges_in_ms() -> None:
    """Two boundaries split a 4-second video into three abutting segments."""
    cuts = [
        ShotBoundary(frame_index=10, timestamp_ms=1000, confidence=0.9),
        ShotBoundary(frame_index=20, timestamp_ms=2500, confidence=0.8),
    ]
    expected = [
        ShotSegment(index=0, start_ms=0, end_ms=1000),
        ShotSegment(index=1, start_ms=1000, end_ms=2500),
        ShotSegment(index=2, start_ms=2500, end_ms=4000),
    ]
    assert boundaries_to_shots(cuts, duration_ms=4000, min_shot_duration_ms=100) == expected
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_merge_short_shots_merges_short_middle_segment() -> None:
    """A 200ms middle shot is absorbed into its predecessor and indices renumbered."""
    segments = [
        ShotSegment(index=0, start_ms=0, end_ms=1000),
        ShotSegment(index=1, start_ms=1000, end_ms=1200),
        ShotSegment(index=2, start_ms=1200, end_ms=3000),
    ]
    assert merge_short_shots(segments, min_shot_duration_ms=500) == [
        ShotSegment(index=0, start_ms=0, end_ms=1200),
        ShotSegment(index=1, start_ms=1200, end_ms=3000),
    ]
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from shot_detection.model_manager import ensure_default_model, get_default_model_path
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class _FakeResponse:
|
|
10
|
+
def __init__(self, payload: bytes) -> None:
|
|
11
|
+
self._buffer = io.BytesIO(payload)
|
|
12
|
+
|
|
13
|
+
def read(self, size: int = -1) -> bytes:
|
|
14
|
+
return self._buffer.read(size)
|
|
15
|
+
|
|
16
|
+
def __enter__(self) -> "_FakeResponse":
|
|
17
|
+
return self
|
|
18
|
+
|
|
19
|
+
def __exit__(self, exc_type, exc, tb) -> None:
|
|
20
|
+
return None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_get_default_model_path_uses_cache_dir(tmp_path: Path) -> None:
    """An explicit cache dir places the model under <cache>/models/."""
    expected = tmp_path / "models" / "transnetv2_open_fp16.onnx"
    assert get_default_model_path(tmp_path) == expected
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_ensure_default_model_downloads_once(monkeypatch, tmp_path: Path) -> None:
    """The model is fetched on first call and served from cache afterwards."""
    payload = b"fake-onnx-model"
    requested_urls: list[str] = []

    def fake_urlopen(url: str, timeout: int = 60) -> _FakeResponse:
        requested_urls.append(url)
        return _FakeResponse(payload)

    monkeypatch.setattr("shot_detection.model_manager.urllib.request.urlopen", fake_urlopen)

    first_path = ensure_default_model(cache_dir=tmp_path)
    assert first_path.is_file()
    assert first_path.read_bytes() == payload
    assert len(requested_urls) == 1

    # A second call must hit the cache and not download again.
    assert ensure_default_model(cache_dir=tmp_path) == first_path
    assert len(requested_urls) == 1
|