vidformer 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vidformer/__init__.py +1397 -3
- vidformer/cv2/__init__.py +858 -1
- vidformer/supervision/__init__.py +529 -0
- {vidformer-0.9.0.dist-info → vidformer-0.10.0.dist-info}/METADATA +7 -5
- vidformer-0.10.0.dist-info/RECORD +6 -0
- vidformer/cv2/vf_cv2.py +0 -810
- vidformer/igni/__init__.py +0 -1
- vidformer/igni/vf_igni.py +0 -285
- vidformer/vf.py +0 -1112
- vidformer-0.9.0.dist-info/RECORD +0 -9
- {vidformer-0.9.0.dist-info → vidformer-0.10.0.dist-info}/WHEEL +0 -0
vidformer/__init__.py
CHANGED
@@ -1,5 +1,1399 @@
|
|
1
|
-
"""
|
1
|
+
"""
|
2
|
+
vidformer-py is a Python 🐍 interface for [vidformer](https://github.com/ixlab/vidformer).
|
2
3
|
|
3
|
-
|
4
|
+
**Quick links:**
|
5
|
+
* [📦 PyPI](https://pypi.org/project/vidformer/)
|
6
|
+
* [📘 Documentation - vidformer-py](https://ixlab.github.io/vidformer/vidformer-py/pdoc/)
|
7
|
+
* [📘 Documentation - vidformer.cv2](https://ixlab.github.io/vidformer/vidformer-py/pdoc/vidformer/cv2.html)
|
8
|
+
* [📘 Documentation - vidformer.supervision](https://ixlab.github.io/vidformer/vidformer-py/pdoc/vidformer/supervision.html)
|
9
|
+
* [🧑💻 Source Code](https://github.com/ixlab/vidformer/tree/main/vidformer-py/)
|
10
|
+
"""
|
4
11
|
|
5
|
-
|
12
|
+
__version__ = "0.10.0"
|
13
|
+
|
14
|
+
|
15
|
+
import subprocess
|
16
|
+
from fractions import Fraction
|
17
|
+
import random
|
18
|
+
import time
|
19
|
+
import json
|
20
|
+
import socket
|
21
|
+
import os
|
22
|
+
import multiprocessing
|
23
|
+
import uuid
|
24
|
+
import threading
|
25
|
+
import gzip
|
26
|
+
import base64
|
27
|
+
import re
|
28
|
+
from urllib.parse import urlparse
|
29
|
+
|
30
|
+
import requests
|
31
|
+
import msgpack
|
32
|
+
import numpy as np
|
33
|
+
|
34
|
+
# Best-effort detection of a Jupyter/IPython kernel. Any failure (IPython not
# installed, no active shell, missing config) leaves the flag at False.
_in_notebook = False
try:
    from IPython import get_ipython

    _in_notebook = "IPKernelApp" in get_ipython().config
except Exception:
    pass
|
42
|
+
|
43
|
+
|
44
|
+
def _wait_for_url(url, max_attempts=150, delay=0.1, timeout=5.0):
    """Poll `url` until it answers with HTTP 200.

    Args:
        url: Address to GET.
        max_attempts: Maximum number of requests to issue.
        delay: Seconds to sleep after each unsuccessful attempt.
        timeout: Per-request timeout in seconds. Without it a connection that
            is accepted but never answered would block forever and
            `max_attempts` would never be reached.

    Returns:
        The stripped response body of the first 200 response, or None if no
        attempt succeeded.
    """
    for _ in range(max_attempts):
        try:
            response = requests.get(url, timeout=timeout)
        except requests.exceptions.RequestException:
            # Not reachable (or timed out) yet; retry after the delay below.
            pass
        else:
            if response.status_code == 200:
                return response.text.strip()
        time.sleep(delay)
    return None
|
55
|
+
|
56
|
+
|
57
|
+
def _play(namespace, hls_video_url, hls_js_url, method="html", status_url=None):
    """Build a player (or link) for an HLS stream.

    Args:
        namespace: Identifier used to build the `<video>` element id, so
            multiple videos in one tab don't conflict.
        hls_video_url: URL of the HLS playlist to play.
        hls_js_url: URL the page loads the hls.js library from.
        method: "html" returns an `IPython.display.HTML` object embedding a
            player page; "link" returns `hls_video_url` unchanged.
        status_url: Optional URL. When given (and method is "html"), the page
            fetches it, expects JSON with a `ready` field, and re-polls every
            250 ms until `ready` is truthy before attaching the player.

    Raises:
        ValueError: If `method` is neither "html" nor "link".
    """
    # The namespace is so multiple videos in one tab don't conflict

    if method == "html":
        # Imported lazily so IPython is only needed for the HTML output path.
        from IPython.display import HTML

        if not status_url:
            # No readiness endpoint: attach the player immediately.
            html_code = f"""
<!DOCTYPE html>
<html>
<head>
    <title>HLS Video Player</title>
    <!-- Include hls.js library -->
    <script src="{hls_js_url}"></script>
</head>
<body>
    <video id="video-{namespace}" controls width="640" height="360" autoplay></video>
    <script>
        var video = document.getElementById('video-{namespace}');
        var videoSrc = '{hls_video_url}';

        if (Hls.isSupported()) {{
            var hls = new Hls();
            hls.loadSource(videoSrc);
            hls.attachMedia(video);
            hls.on(Hls.Events.MANIFEST_PARSED, function() {{
                video.play();
            }});
        }} else if (video.canPlayType('application/vnd.apple.mpegurl')) {{
            video.src = videoSrc;
            video.addEventListener('loadedmetadata', function() {{
                video.play();
            }});
        }} else {{
            console.error('This browser does not appear to support HLS.');
        }}
    </script>
</body>
</html>
"""
            return HTML(data=html_code)
        else:
            # Readiness endpoint given: show "Waiting..." and poll it until it
            # reports ready, then create the <video> element and attach HLS.
            html_code = f"""
<!DOCTYPE html>
<html>
<head>
    <title>HLS Video Player</title>
    <script src="{hls_js_url}"></script>
</head>
<body>
    <div id="container"></div>
    <script>
        var statusUrl = '{status_url}';
        var videoSrc = '{hls_video_url}';
        var videoNamespace = '{namespace}';

        function showWaiting() {{
            document.getElementById('container').textContent = 'Waiting...';
            pollStatus();
        }}

        function pollStatus() {{
            setTimeout(function() {{
                fetch(statusUrl)
                    .then(r => r.json())
                    .then(res => {{
                        if (res.ready) {{
                            document.getElementById('container').textContent = '';
                            attachHls();
                        }} else {{
                            pollStatus();
                        }}
                    }})
                    .catch(e => {{
                        console.error(e);
                        pollStatus();
                    }});
            }}, 250);
        }}

        function attachHls() {{
            var container = document.getElementById('container');
            container.textContent = '';
            var video = document.createElement('video');
            video.id = 'video-' + videoNamespace;
            video.controls = true;
            video.width = 640;
            video.height = 360;
            container.appendChild(video);
            if (Hls.isSupported()) {{
                var hls = new Hls();
                hls.loadSource(videoSrc);
                hls.attachMedia(video);
                hls.on(Hls.Events.MANIFEST_PARSED, function() {{
                    video.play();
                }});
            }} else if (video.canPlayType('application/vnd.apple.mpegurl')) {{
                video.src = videoSrc;
                video.addEventListener('loadedmetadata', function() {{
                    video.play();
                }});
            }}
        }}

        fetch(statusUrl)
            .then(r => r.json())
            .then(res => {{
                if (res.ready) {{
                    attachHls();
                }} else {{
                    showWaiting();
                }}
            }})
            .catch(e => {{
                console.error(e);
                showWaiting();
            }});
    </script>
</body>
</html>
"""
            return HTML(data=html_code)
    elif method == "link":
        return hls_video_url
    else:
        raise ValueError("Invalid method")
|
183
|
+
|
184
|
+
|
185
|
+
class IgniSource:
    """A source described by an Igni server response.

    Holds the source id, its frame format (width, height, pix_fmt) and its
    timestamps, which arrive as `[numerator, denominator]` pairs under
    `src["ts"]` and are stored as Fractions.
    """

    def __init__(self, id: str, src):
        self._name = id
        self._fmt = {key: src[key] for key in ("width", "height", "pix_fmt")}
        self._ts = [Fraction(pair[0], pair[1]) for pair in src["ts"]]
        self.iloc = _SourceILoc(self)

    def id(self) -> str:
        """Return the source id."""
        return self._name

    def fmt(self):
        """Return a copy of the format dict."""
        return dict(self._fmt)

    def ts(self) -> list[Fraction]:
        """Return a copy of the timestamp list."""
        return list(self._ts)

    def __len__(self):
        return len(self._ts)

    def __getitem__(self, idx):
        # Only exact Fraction instances are accepted as timestamps.
        if type(idx) is Fraction:
            return SourceExpr(self, idx, False)
        raise Exception("Source index must be a Fraction")

    def __repr__(self):
        return "IgniSource({})".format(self._name)
|
215
|
+
|
216
|
+
|
217
|
+
class IgniSpec:
    """A spec described by an Igni server response.

    Keeps the spec id, its format (width, height, pix_fmt), the VOD endpoint
    URL from the response, and an hls.js URL derived from that endpoint's
    scheme and host.
    """

    def __init__(self, id: str, src):
        self._id = id
        self._fmt = {key: src[key] for key in ("width", "height", "pix_fmt")}
        self._vod_endpoint = src["vod_endpoint"]
        endpoint_parts = urlparse(self._vod_endpoint)
        self._hls_js_url = "{}://{}/hls.js".format(
            endpoint_parts.scheme, endpoint_parts.netloc
        )

    def id(self) -> str:
        """Return the spec id."""
        return self._id

    def play(self, *args, **kwargs):
        """Play the spec via `_play`, using the endpoint's playlist and status URLs.

        Extra positional and keyword arguments are forwarded to `_play`.
        """
        playlist_url = f"{self._vod_endpoint}playlist.m3u8"
        ready_url = f"{self._vod_endpoint}status"
        return _play(
            self._id, playlist_url, self._hls_js_url, *args, **kwargs, status_url=ready_url
        )
|
237
|
+
|
238
|
+
|
239
|
+
class IgniServer:
    """Client for an Igni server's HTTP API.

    Every request carries the API key as a bearer token. Non-OK responses are
    raised as `Exception(response.text)`.
    """

    def __init__(self, endpoint: str, api_key: str):
        """Validate the endpoint format and verify the API key via `/auth`."""
        if not (endpoint.startswith("http://") or endpoint.startswith("https://")):
            raise Exception("Endpoint must start with http:// or https://")
        if endpoint.endswith("/"):
            raise Exception("Endpoint must not end with /")
        self._endpoint = endpoint

        self._api_key = api_key
        reply = self._decode(
            requests.get(f"{self._endpoint}/auth", headers=self._auth_headers())
        )
        assert reply["status"] == "ok"

    def _auth_headers(self):
        # Bearer-token header sent with every request.
        return {"Authorization": f"Bearer {self._api_key}"}

    @staticmethod
    def _decode(response):
        # Raise the body text on a non-OK response, otherwise parse the JSON.
        if not response.ok:
            raise Exception(response.text)
        return response.json()

    @staticmethod
    def _source_request(name, stream_idx, storage_service, storage_config):
        # Type-check the source parameters and build the request body.
        assert type(name) is str
        assert type(stream_idx) is int
        assert type(storage_service) is str
        assert type(storage_config) is dict
        for k, v in storage_config.items():
            assert type(k) is str
            assert type(v) is str
        return {
            "name": name,
            "stream_idx": stream_idx,
            "storage_service": storage_service,
            "storage_config": storage_config,
        }

    def get_source(self, id: str) -> IgniSource:
        """Fetch the source with the given id."""
        assert type(id) is str
        reply = self._decode(
            requests.get(
                f"{self._endpoint}/source/{id}", headers=self._auth_headers()
            )
        )
        return IgniSource(reply["id"], reply)

    def list_sources(self) -> list[str]:
        """Return the decoded JSON reply of `GET /source`."""
        return self._decode(
            requests.get(f"{self._endpoint}/source", headers=self._auth_headers())
        )

    def delete_source(self, id: str):
        """Delete the source with the given id."""
        assert type(id) is str
        reply = self._decode(
            requests.delete(
                f"{self._endpoint}/source/{id}", headers=self._auth_headers()
            )
        )
        assert reply["status"] == "ok"

    def search_source(
        self, name, stream_idx, storage_service, storage_config
    ) -> list[str]:
        """Return the decoded JSON reply of `POST /source/search` for these parameters."""
        body = self._source_request(name, stream_idx, storage_service, storage_config)
        return self._decode(
            requests.post(
                f"{self._endpoint}/source/search",
                json=body,
                headers=self._auth_headers(),
            )
        )

    def create_source(
        self, name, stream_idx, storage_service, storage_config
    ) -> IgniSource:
        """Create a source on the server and return it."""
        body = self._source_request(name, stream_idx, storage_service, storage_config)
        reply = self._decode(
            requests.post(
                f"{self._endpoint}/source", json=body, headers=self._auth_headers()
            )
        )
        assert reply["status"] == "ok"
        return self.get_source(reply["id"])

    def source(self, name, stream_idx, storage_service, storage_config) -> IgniSource:
        """Convenience function for accessing sources.

        Tries to find a source with the given name, stream_idx, storage_service, and storage_config.
        If no source is found, creates a new source with the given parameters.
        """
        matches = self.search_source(name, stream_idx, storage_service, storage_config)
        if len(matches) != 0:
            return self.get_source(matches[0])
        return self.create_source(name, stream_idx, storage_service, storage_config)

    def get_spec(self, id: str) -> IgniSpec:
        """Fetch the spec with the given id."""
        assert type(id) is str
        reply = self._decode(
            requests.get(f"{self._endpoint}/spec/{id}", headers=self._auth_headers())
        )
        return IgniSpec(reply["id"], reply)

    def list_specs(self) -> list[str]:
        """Return the decoded JSON reply of `GET /spec`."""
        return self._decode(
            requests.get(f"{self._endpoint}/spec", headers=self._auth_headers())
        )

    def create_spec(
        self,
        width,
        height,
        pix_fmt,
        vod_segment_length,
        frame_rate,
        ready_hook=None,
        steer_hook=None,
    ) -> IgniSpec:
        """Create a spec on the server and return it.

        `vod_segment_length` and `frame_rate` must be Fractions; they are sent
        as `[numerator, denominator]` pairs.
        """
        assert type(width) is int
        assert type(height) is int
        assert type(pix_fmt) is str
        assert type(vod_segment_length) is Fraction
        assert type(frame_rate) is Fraction
        assert type(ready_hook) is str or ready_hook is None
        assert type(steer_hook) is str or steer_hook is None

        segment_pair = [vod_segment_length.numerator, vod_segment_length.denominator]
        rate_pair = [frame_rate.numerator, frame_rate.denominator]
        body = {
            "width": width,
            "height": height,
            "pix_fmt": pix_fmt,
            "vod_segment_length": segment_pair,
            "frame_rate": rate_pair,
            "ready_hook": ready_hook,
            "steer_hook": steer_hook,
        }
        reply = self._decode(
            requests.post(
                f"{self._endpoint}/spec", json=body, headers=self._auth_headers()
            )
        )
        assert reply["status"] == "ok"
        return self.get_spec(reply["id"])

    def delete_spec(self, id: str):
        """Delete the spec with the given id."""
        assert type(id) is str
        reply = self._decode(
            requests.delete(
                f"{self._endpoint}/spec/{id}", headers=self._auth_headers()
            )
        )
        assert reply["status"] == "ok"

    def push_spec_part(self, spec_id, pos, frames, terminal):
        """Push a list of `(timestamp, frame_expr_or_None)` tuples to a spec.

        `spec_id` may be an `IgniSpec` or its string id.
        """
        if type(spec_id) is IgniSpec:
            spec_id = spec_id._id
        assert type(spec_id) is str
        assert type(pos) is int
        assert type(frames) is list
        assert type(terminal) is bool

        encoded_frames = []
        for frame in frames:
            assert type(frame) is tuple
            assert len(frame) == 2
            t, f = frame
            assert type(t) is Fraction
            assert f is None or type(f) is SourceExpr or type(f) is FilterExpr
            expr_json = None if f is None else f._to_json_spec()
            encoded_frames.append([[t.numerator, t.denominator], expr_json])

        body = {
            "pos": pos,
            "frames": encoded_frames,
            "terminal": terminal,
        }
        reply = self._decode(
            requests.post(
                f"{self._endpoint}/spec/{spec_id}/part",
                json=body,
                headers=self._auth_headers(),
            )
        )
        assert reply["status"] == "ok"
|
465
|
+
|
466
|
+
|
467
|
+
class YrdenSpec:
    """
    A video transformation specification.

    A spec is a list of timestamps (`domain`), a `render(t, i)` callable that
    returns the frame expression for timestamp `t` at position `i`, and an
    output format dict with `width`, `height` and `pix_fmt`.

    See https://ixlab.github.io/vidformer/concepts.html for more information.
    """

    def __init__(self, domain: list[Fraction], render, fmt: dict):
        self._domain = domain
        self._render = render
        self._fmt = fmt

    def _repr_line(self, i, t):
        # One "<num>/<den> => <frame expr>" line for the frame at position i.
        return f"{t.numerator}/{t.denominator} => {self._render(t, i)}"

    def __repr__(self):
        """List every frame, or the first and last 10 when there are more than 20."""
        n = len(self._domain)
        if n <= 20:
            return "\n".join(
                self._repr_line(i, t) for i, t in enumerate(self._domain)
            )
        lines = [self._repr_line(i, t) for i, t in enumerate(self._domain[:10])]
        lines.append("...")
        # Number the tail with its real positions in the domain, so render()
        # receives the same index it gets everywhere else.
        lines.extend(
            self._repr_line(i, t)
            for i, t in enumerate(self._domain[-10:], start=n - 10)
        )
        return "\n".join(lines)

    def _sources(self):
        """Return the set of all sources referenced by any frame."""
        found = set()
        for i, t in enumerate(self._domain):
            found.update(self._render(t, i)._sources())
        return found

    def _to_json_spec(self):
        """Render every frame.

        Returns:
            A tuple `({"frames": [...]}, sources, filters)`, where each frame
            is `[[numerator, denominator], frame_json]`, `sources` is the set
            of referenced sources and `filters` is the merged filter dict.
        """
        frames = []
        sources = set()
        filters = {}
        for i, t in enumerate(self._domain):
            frame_expr = self._render(t, i)
            sources.update(frame_expr._sources())
            filters.update(frame_expr._filters())
            frames.append([[t.numerator, t.denominator], frame_expr._to_json_spec()])
        return {"frames": frames}, sources, filters

    @staticmethod
    def _sources_json(sources):
        # Describe each source object in the dict form sent to the server.
        return [
            {
                "name": s._name,
                "path": s._path,
                "stream": s._stream,
                "service": s._service.as_json() if s._service is not None else None,
            }
            for s in sources
        ]

    @staticmethod
    def _filters_json(filters):
        # Describe each filter object in the dict form sent to the server.
        return {
            k: {
                "filter": v._func,
                "args": v._kwargs,
            }
            for k, v in filters.items()
        }

    def _server_args(self):
        # Build (spec, sources, filters, arrays) for server._new / server._export.
        # The spec JSON is gzip-compressed and base64-encoded.
        spec, sources, filters = self._to_json_spec()
        spec_json_bytes = json.dumps(spec).encode("utf-8")
        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
        return (
            spec_obj_json_gzip_b64,
            self._sources_json(sources),
            self._filters_json(filters),
            [],
        )

    def play(self, server, method="html", verbose=False):
        """Play the video live in the notebook.

        Args:
            server: Server object providing `_new(...)` and `hls_js_url()`.
            method: "html" returns an embedded hls.js player, "iframe" an
                IFrame of the server's player page, "link" the stream URL,
                "player" the player URL. Any other value also returns the
                player URL.
            verbose: Print the encoded spec size before sending.
        """
        spec_b64, sources, filters, arrays = self._server_args()

        if verbose:
            print(f"Sending to server. Spec is {len(spec_b64)} bytes")

        resp = server._new(spec_b64, sources, filters, arrays, self._fmt)
        hls_video_url = resp["stream_url"]
        hls_player_url = resp["player_url"]
        namespace = resp["namespace"]
        hls_js_url = server.hls_js_url()

        if method == "link":
            return hls_video_url
        if method == "player":
            return hls_player_url
        if method == "iframe":
            from IPython.display import IFrame

            return IFrame(hls_player_url, width=1280, height=720)
        if method == "html":
            from IPython.display import HTML

            # We add a namespace to the video element to avoid conflicts with other videos
            html_code = f"""
<!DOCTYPE html>
<html>
<head>
    <title>HLS Video Player</title>
    <!-- Include hls.js library -->
    <script src="{hls_js_url}"></script>
</head>
<body>
    <!-- Video element -->
    <video id="video-{namespace}" controls width="640" height="360" autoplay></video>
    <script>
        var video = document.getElementById('video-{namespace}');
        var videoSrc = '{hls_video_url}';
        var hls = new Hls();
        hls.loadSource(videoSrc);
        hls.attachMedia(video);
        hls.on(Hls.Events.MANIFEST_PARSED, function() {{
            video.play();
        }});
    </script>
</body>
</html>
"""
            return HTML(data=html_code)
        else:
            return hls_player_url

    def load(self, server):
        """Register the spec on the server and return a `YrdenLoader` for it."""
        spec_b64, sources, filters, arrays = self._server_args()
        resp = server._new(spec_b64, sources, filters, arrays, self._fmt)
        namespace = resp["namespace"]
        return YrdenLoader(server, namespace, self._domain)

    def save(self, server, pth, encoder=None, encoder_opts=None, format=None):
        """Save the video to a file.

        `encoder` must be a string or None; `encoder_opts` must be None or a
        dict mapping strings to strings. Returns the result of
        `server._export`.
        """
        assert encoder is None or type(encoder) is str
        assert encoder_opts is None or type(encoder_opts) is dict
        if encoder_opts is not None:
            for k, v in encoder_opts.items():
                assert type(k) is str and type(v) is str

        spec_b64, sources, filters, arrays = self._server_args()
        resp = server._export(
            pth,
            spec_b64,
            sources,
            filters,
            arrays,
            self._fmt,
            encoder,
            encoder_opts,
            format,
        )

        return resp

    def _vrod_bench(self, server):
        """Time spec creation, registration and the first-segment fetch.

        Writes the spec JSON to `spec.json` in the working directory and
        returns a dict of elapsed seconds per phase.
        """
        out = {}
        pth = "spec.json"
        start_t = time.time()
        with open(pth, "w") as outfile:
            spec, sources, filters = self._to_json_spec()
            outfile.write(json.dumps(spec))

        sources = self._sources_json(sources)
        filters = self._filters_json(filters)
        arrays = []
        end_t = time.time()
        out["vrod_create_spec"] = end_t - start_t

        start = time.time()
        # NOTE(review): this passes the spec file path where play()/load()
        # pass the encoded spec itself - confirm the server accepts that.
        resp = server._new(pth, sources, filters, arrays, self._fmt)
        end = time.time()
        out["vrod_register"] = end - start

        stream_url = resp["stream_url"]
        first_segment = stream_url.replace("stream.m3u8", "segment-0.ts")

        start = time.time()
        r = requests.get(first_segment)
        r.raise_for_status()
        end = time.time()
        out["vrod_first_segment"] = end - start
        return out

    def _dve2_bench(self, server):
        """Time spec creation and export.

        Writes the spec JSON to `spec.json` in the working directory and
        returns a dict of elapsed seconds per phase.
        """
        pth = "spec.json"
        out = {}
        start_t = time.time()
        with open(pth, "w") as outfile:
            spec, sources, filters = self._to_json_spec()
            outfile.write(json.dumps(spec))

        sources = self._sources_json(sources)
        filters = self._filters_json(filters)
        arrays = []
        end_t = time.time()
        out["dve2_create_spec"] = end_t - start_t

        start = time.time()
        # NOTE(review): this call passes fewer arguments than
        # YrdenServer._export takes, and expects a response object rather than
        # decoded JSON. Left unchanged; confirm which server API this
        # benchmark targets before relying on it.
        resp = server._export(pth, sources, filters, arrays, self._fmt, None, None)
        resp.raise_for_status()
        end = time.time()
        out["dve2_exec"] = end - start
        return out
|
748
|
+
|
749
|
+
|
750
|
+
class YrdenLoader:
    """Indexable access to the raw frame bytes of a spec registered on a server.

    Frames are fetched through `server._raw(namespace, start_i, end_i)`, where
    both bounds are inclusive.
    """

    def __init__(self, server, namespace: str, domain):
        self._server = server
        self._namespace = namespace
        self._domain = domain

    def _chunk(self, start_i, end_i):
        # Fetch the bytes for frames start_i..end_i (inclusive).
        return self._server._raw(self._namespace, start_i, end_i)

    def __len__(self):
        return len(self._domain)

    def _find_index_by_rational(self, value):
        """Return the position of timestamp `value` in the domain.

        Raises:
            ValueError: If `value` is not in the domain.
        """
        try:
            return self._domain.index(value)
        except ValueError:
            raise ValueError(
                f"Rational timestamp {value} is not in the domain"
            ) from None

    def __getitem__(self, index):
        """Return the bytes of one frame (int index) or a list of frames (slice).

        Slice bounds must be non-negative and within the domain. A positive
        step is applied to the fetched range; an empty range yields an empty
        list.

        Raises:
            TypeError: If `index` is neither a slice nor an int.
            ValueError: If the slice step is not a positive integer.
        """
        if isinstance(index, slice):
            step = 1 if index.step is None else index.step
            if not isinstance(step, int) or step < 1:
                raise ValueError("Slice step must be a positive integer")
            start = index.start if index.start is not None else 0
            end = index.stop if index.stop is not None else len(self._domain)
            assert start >= 0 and start < len(self._domain)
            assert end >= 0 and end <= len(self._domain)
            assert start <= end
            num_frames = end - start
            if num_frames == 0:
                # Nothing to fetch; also avoids dividing by zero below.
                return []
            all_bytes = self._chunk(start, end - 1)
            all_bytes_len = len(all_bytes)
            # The reply is split into num_frames equal-sized pieces.
            assert all_bytes_len % num_frames == 0
            frame_size = all_bytes_len // num_frames
            frames = [
                all_bytes[i * frame_size : (i + 1) * frame_size]
                for i in range(num_frames)
            ]
            return frames[::step]
        elif isinstance(index, int):
            assert index >= 0 and index < len(self._domain)
            return self._chunk(index, index)
        else:
            raise TypeError(
                "Invalid argument type for iloc. Use a slice or an integer."
            )
|
795
|
+
|
796
|
+
|
797
|
+
class YrdenServer:
    """
    A connection to a Yrden server.

    A yrden server is the main API for local use of vidformer.
    """

    def __init__(self, domain=None, port=None, bin=None, hls_prefix=None):
        """
        Connect to a Yrden server

        Can either connect to an existing server, if domain and port are provided, or start a new server using the provided binary.
        If no domain or binary is provided, the `VIDFORMER_BIN` environment variable is used.
        """

        self._domain = domain
        self._port = port
        self._proc = None
        if self._port is None:
            # No port given: launch a local server on a random high port.
            if bin is None:
                bin = os.getenv("VIDFORMER_BIN", "vidformer-cli")

            self._domain = "localhost"
            self._port = random.randint(49152, 65535)
            cmd = [bin, "yrden", "--port", str(self._port)]
            if _in_notebook:
                # We need to print the URL in the notebook
                # This is a trick to get VS Code to forward the port
                cmd.append("--print-url")

            if hls_prefix is not None:
                if type(hls_prefix) is not str:
                    raise Exception("hls_prefix must be a string")
                cmd.extend(["--hls-prefix", hls_prefix])

            self._proc = subprocess.Popen(cmd)

        version = _wait_for_url(self._url(""))
        if version is None:
            raise Exception("Failed to connect to server")

        expected_version = f"vidformer-yrden v{__version__}"
        if version != expected_version:
            print(
                f"Warning: Expected version `{expected_version}`, got `{version}`. API may not be compatible!"
            )

    def _url(self, path):
        # Absolute URL for `path` on this server.
        return f"http://{self._domain}:{self._port}/{path}"

    @staticmethod
    def _json_or_raise(r):
        # Raise the body text on a non-OK response, otherwise parse the JSON.
        if not r.ok:
            raise Exception(r.text)
        return r.json()

    def _source(self, name: str, path: str, stream: int, service):
        """Register a source and return its description.

        The `ts` entries of the reply are converted to Fractions.
        """
        payload = {
            "name": name,
            "path": path,
            "stream": stream,
            "service": service.as_json() if service is not None else None,
        }
        resp = self._json_or_raise(requests.post(self._url("source"), json=payload))
        resp["ts"] = [Fraction(pair[0], pair[1]) for pair in resp["ts"]]
        return resp

    def _new(self, spec, sources, filters, arrays, fmt):
        """POST a spec to `/new` and return the decoded JSON reply."""
        req = {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
        }
        return self._json_or_raise(requests.post(self._url("new"), json=req))

    def _export(
        self, pth, spec, sources, filters, arrays, fmt, encoder, encoder_opts, format
    ):
        """POST a spec to `/export` and return the decoded JSON reply."""
        req = {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
            "output_path": pth,
            "encoder": encoder,
            "encoder_opts": encoder_opts,
            "format": format,
        }
        return self._json_or_raise(requests.post(self._url("export"), json=req))

    def _raw(self, namespace, start_i, end_i):
        """GET the response body for frames `start_i`-`end_i` of a namespace."""
        r = requests.get(self._url(f"{namespace}/raw/{start_i}-{end_i}"))
        if not r.ok:
            raise Exception(r.text)
        return r.content

    def hls_js_url(self):
        """Return the link to the yrden-hosted hls.js file"""
        return self._url("hls.js")

    def __del__(self):
        # Only kill a server process that this object started itself.
        proc = self._proc
        if proc is None:
            return
        proc.kill()
|
919
|
+
|
920
|
+
|
921
|
+
class YrdenSource:
    """A video source."""

    def __init__(
        self, server: YrdenServer, name: str, path: str, stream: int, service=None
    ):
        # With no explicit service, an http(s) URL is split into a storage
        # service (scheme + host) and a path relative to it. For example
        # "https://f.dominik.win/data/dve2/tos_720p.mp4" becomes an "http"
        # service with endpoint "https://f.dominik.win" and the path
        # "data/dve2/tos_720p.mp4".
        if service is None:
            url_parts = re.match(r"(http|https)://([^/]+)(.*)", path)
            if url_parts is not None:
                scheme, host, remainder = url_parts.groups()
                # drop one leading slash so the path is relative to the endpoint
                path = remainder[1:] if remainder.startswith("/") else remainder
                service = YrdenStorageService("http", endpoint=f"{scheme}://{host}")

        self._server = server
        self._name = name
        self._path = path
        self._stream = stream
        self._service = service

        self.iloc = _SourceILoc(self)

        # Ask the server to describe the source; fmt(), ts() and len() read
        # from this response.
        self._src = self._server._source(
            self._name, self._path, self._stream, self._service
        )

    def fmt(self):
        """Return the source's ``width``, ``height`` and ``pix_fmt`` as a dict."""
        return {key: self._src[key] for key in ("width", "height", "pix_fmt")}

    def ts(self):
        """Return the ``ts`` entry of the server's description of this source."""
        return self._src["ts"]

    def __len__(self):
        return len(self._src["ts"])

    def __getitem__(self, idx):
        """Return a SourceExpr for timestamp ``idx``, which must be exactly a Fraction."""
        if type(idx) is Fraction:
            return SourceExpr(self, idx, False)
        raise Exception("Source index must be a Fraction")

    def play(self, *args, **kwargs):
        """Play the video live in the notebook."""

        def passthrough(t, _i):
            # each output frame is simply the source frame at the same timestamp
            return self[t]

        timeline = self.ts()
        return YrdenSpec(timeline, passthrough, self.fmt()).play(*args, **kwargs)
|
982
|
+
|
983
|
+
|
984
|
+
class YrdenStorageService:
    """A named storage service plus its string-valued configuration."""

    def __init__(self, service: str, **kwargs):
        if type(service) is not str:
            raise Exception("Service name must be a string")

        # every config value must be exactly a str; report the first offender
        offending = [key for key, value in kwargs.items() if type(value) is not str]
        if offending:
            raise Exception(f"Value of {offending[0]} must be a string")

        self._service = service
        self._config = kwargs

    def as_json(self):
        """Return the service as a ``{"service": ..., "config": ...}`` dict."""
        return dict(service=self._service, config=self._config)

    def __repr__(self):
        return "{}(config={})".format(self._service, self._config)
|
999
|
+
|
1000
|
+
|
1001
|
+
class SourceExpr:
    """Reference to one frame of a source, addressed by timestamp or by position."""

    def __init__(self, source, idx, is_iloc):
        self._source = source
        self._idx = idx
        # True: idx is an integer position; False: idx is a Fraction timestamp
        self._is_iloc = is_iloc

    def __repr__(self):
        accessor = ".iloc" if self._is_iloc else ""
        return f"{self._source._name}{accessor}[{self._idx}]"

    def _to_json_spec(self):
        """Return the JSON-serializable form of this frame reference."""
        if self._is_iloc:
            index = {"ILoc": int(self._idx)}
        else:
            # timestamps are sent as a [numerator, denominator] pair
            index = {"T": [self._idx.numerator, self._idx.denominator]}
        return {"Source": {"video": self._source._name, "index": index}}

    def _sources(self):
        """Return the set of sources this expression reads (just its own)."""
        return {self._source}

    def _filters(self):
        """Return the filters this expression uses (none)."""
        return {}
|
1034
|
+
|
1035
|
+
|
1036
|
+
class _SourceILoc:
    """Integer-position indexer for a source; indexing yields a positional SourceExpr."""

    def __init__(self, source):
        self._source = source

    def __getitem__(self, idx):
        # exact int only (bool and other subclasses are rejected)
        if type(idx) is int:
            return SourceExpr(self._source, idx, True)
        raise Exception(f"Source iloc index must be an integer, got a {type(idx)}")
|
1044
|
+
|
1045
|
+
|
1046
|
+
def _json_arg(arg, skip_data_anot=False):
    """Convert a filter argument into its tagged JSON form.

    Frame expressions (FilterExpr / SourceExpr) become ``{"Frame": spec}``.
    Plain values become ``{"<Tag>": value}`` with tag ``Int``, ``String``,
    ``Bytes`` (as a list of ints), ``Float``, ``Bool`` or ``List``; unless
    ``skip_data_anot`` is true that dict is additionally wrapped as
    ``{"Data": ...}``. List/tuple elements are always converted without the
    ``Data`` wrapper.

    Raises ``Exception`` for any other argument type. Types are matched
    exactly, so subclasses are not accepted.
    """
    kind = type(arg)

    # frames are never wrapped in "Data"
    if kind is FilterExpr or kind is SourceExpr:
        return {"Frame": arg._to_json_spec()}

    if kind is int:
        tagged = {"Int": arg}
    elif kind is str:
        tagged = {"String": arg}
    elif kind is bytes:
        tagged = {"Bytes": list(arg)}
    elif kind is float:
        tagged = {"Float": arg}
    elif kind is bool:
        tagged = {"Bool": arg}
    elif kind is tuple or kind is list:
        tagged = {"List": [_json_arg(item, True) for item in list(arg)]}
    else:
        raise Exception(f"Unknown arg type: {type(arg)}")

    return tagged if skip_data_anot else {"Data": tagged}
|
1076
|
+
|
1077
|
+
|
1078
|
+
class Filter:
    """A video filter."""

    def __init__(self, name: str, tl_func=None, **kwargs):
        self._name = name

        # The top-level func is the true implementation to invoke; the name is
        # just the pretty name. Without an explicit tl_func they coincide.
        self._func = name if tl_func is None else tl_func

        # Infrastructure options for the filter itself (not per-invocation
        # arguments); every value must be exactly a str.
        for key, value in kwargs.items():
            if type(value) is not str:
                raise Exception(f"Value of {key} must be a string")
        self._kwargs = kwargs

    def __call__(self, *args, **kwargs):
        """Return a FilterExpr applying this filter to the given arguments."""
        return FilterExpr(self, args, kwargs)
|
1098
|
+
|
1099
|
+
|
1100
|
+
class FilterExpr:
    """A Filter together with the positional and keyword arguments it is applied to."""

    def __init__(self, filter: Filter, args, kwargs):
        self._filter = filter
        self._args = args
        self._kwargs = kwargs

    def __repr__(self):
        def show(value):
            # strings are shown double-quoted, everything else via str()
            return f'"{value}"' if type(value) is str else str(value)

        rendered = [show(arg) for arg in self._args]
        rendered.extend(f"{key}={show(value)}" for key, value in self._kwargs.items())
        return f"{self._filter._name}({', '.join(rendered)})"

    def _to_json_spec(self):
        """Return the JSON-serializable form of this filter application."""
        positional = [_json_arg(arg) for arg in self._args]
        keyword = {key: _json_arg(value) for key, value in self._kwargs.items()}
        return {
            "Filter": {
                "name": self._filter._name,
                "args": positional,
                "kwargs": keyword,
            }
        }

    def _sources(self):
        """Return the set of sources read by frame expressions passed directly as arguments."""
        found = set()
        for value in (*self._args, *self._kwargs.values()):
            if type(value) is FilterExpr or type(value) is SourceExpr:
                found.update(value._sources())
        return found

    def _filters(self):
        """Return a name -> Filter dict for this filter and those of nested FilterExpr arguments."""
        collected = {self._filter._name: self._filter}
        for value in (*self._args, *self._kwargs.values()):
            if type(value) is FilterExpr:
                collected.update(value._filters())
        return collected
|
1144
|
+
|
1145
|
+
|
1146
|
+
class UDF:
    """User-defined filter superclass"""

    def __init__(self, name: str):
        self._name = name
        # set by into_filter(): path of the Unix socket and the host process
        self._socket_path = None
        self._p = None

    def filter(self, *args, **kwargs):
        """Compute the output frame; subclasses must override and return a UDFFrame."""
        raise Exception("User must implement the filter method")

    def filter_type(self, *args, **kwargs):
        """Compute the output frame type; subclasses must override and return a UDFFrameType."""
        raise Exception("User must implement the filter_type method")

    def into_filter(self):
        """Start a host process for this UDF and return a Filter that reaches it over a Unix socket.

        May only be called once per instance (asserts no socket path is set yet).
        """
        assert self._socket_path is None
        self._socket_path = f"/tmp/vidformer-{self._name}-{str(uuid.uuid4())}.sock"
        self._p = multiprocessing.Process(
            target=_run_udf_host, args=(self, self._socket_path)
        )
        self._p.start()
        return Filter(
            name=self._name, tl_func="IPC", socket=self._socket_path, func=self._name
        )

    @staticmethod
    def _recv_exact(connection, n):
        """Read up to ``n`` bytes, returning fewer only if the peer closes the stream first."""
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = connection.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    @staticmethod
    def _deser_kwargs(f_kwargs, deser):
        """Apply ``deser`` to every value of ``f_kwargs`` (a mapping or an iterable of pairs)."""
        return {k: deser(v) for k, v in dict(f_kwargs).items()}

    def _handle_connection(self, connection):
        """Serve length-prefixed msgpack requests on ``connection`` until the peer closes it.

        Each request is a 4-byte big-endian length followed by a msgpack map
        with ``op`` (``"filter"`` or ``"filter_type"``), ``args`` and
        ``kwargs``. The reply uses the same framing. The connection is always
        closed on exit, including when an exception propagates.
        """
        try:
            while True:
                # recv() may return fewer bytes than requested on a stream
                # socket, so keep reading until the header is complete.
                header = self._recv_exact(connection, 4)
                if len(header) != 4:
                    break
                frame_len = int.from_bytes(header, byteorder="big")
                data = self._recv_exact(connection, frame_len)
                if not data:
                    break
                if len(data) < frame_len:
                    raise Exception("Partial data received")

                obj = msgpack.unpackb(data, raw=False)
                f_op, f_args, f_kwargs = (
                    obj["op"],
                    obj["args"],
                    obj["kwargs"],
                )

                response = None
                if f_op == "filter":
                    f_args = [self._deser_filter(x) for x in f_args]
                    # iterate key/value pairs (iterating the dict itself yields only keys)
                    f_kwargs = self._deser_kwargs(f_kwargs, self._deser_filter)
                    response = self.filter(*f_args, **f_kwargs)
                    if type(response) is not UDFFrame:
                        raise Exception(
                            f"filter must return a UDFFrame, got {type(response)}"
                        )
                    if response.frame_type().pix_fmt() != "rgb24":
                        raise Exception(
                            f"filter must return a frame with pix_fmt 'rgb24', got {response.frame_type().pix_fmt()}"
                        )

                    response = response._response_ser()
                elif f_op == "filter_type":
                    f_args = [self._deser_filter_type(x) for x in f_args]
                    f_kwargs = self._deser_kwargs(f_kwargs, self._deser_filter_type)
                    response = self.filter_type(*f_args, **f_kwargs)
                    if type(response) is not UDFFrameType:
                        raise Exception(
                            f"filter_type must return a UDFFrameType, got {type(response)}"
                        )
                    if response.pix_fmt() != "rgb24":
                        raise Exception(
                            f"filter_type must return a frame with pix_fmt 'rgb24', got {response.pix_fmt()}"
                        )
                    response = response._response_ser()
                else:
                    raise Exception(f"Unknown operation: {f_op}")

                response = msgpack.packb(response, use_bin_type=True)
                response_len = len(response).to_bytes(4, byteorder="big")
                connection.sendall(response_len)
                connection.sendall(response)
        finally:
            connection.close()

    def _deser_filter_type(self, obj):
        """Decode one ``filter_type`` argument: a single-key dict tagged FrameType/String/Int/Bool."""
        assert type(obj) is dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["FrameType", "String", "Int", "Bool"]

        if type_key == "FrameType":
            frame = obj[type_key]
            assert type(frame) is dict
            assert "width" in frame
            assert "height" in frame
            assert "format" in frame
            assert type(frame["width"]) is int
            assert type(frame["height"]) is int
            assert frame["format"] == 2  # AV_PIX_FMT_RGB24
            return UDFFrameType(frame["width"], frame["height"], "rgb24")
        elif type_key == "String":
            assert type(obj[type_key]) is str
            return obj[type_key]
        elif type_key == "Int":
            assert type(obj[type_key]) is int
            return obj[type_key]
        elif type_key == "Bool":
            assert type(obj[type_key]) is bool
            return obj[type_key]
        else:
            assert False, f"Unknown type: {type_key}"

    def _deser_filter(self, obj):
        """Decode one ``filter`` argument: a single-key dict tagged Frame/String/Int/Bool."""
        assert type(obj) is dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["Frame", "String", "Int", "Bool"]

        if type_key == "Frame":
            frame = obj[type_key]
            assert type(frame) is dict
            assert "data" in frame
            assert "width" in frame
            assert "height" in frame
            assert "format" in frame
            assert type(frame["width"]) is int
            assert type(frame["height"]) is int
            assert frame["format"] == "rgb24"
            assert type(frame["data"]) is bytes

            # raw bytes -> (height, width, 3) uint8 array
            data = np.frombuffer(frame["data"], dtype=np.uint8)
            data = data.reshape(frame["height"], frame["width"], 3)
            return UDFFrame(
                data, UDFFrameType(frame["width"], frame["height"], "rgb24")
            )
        elif type_key == "String":
            assert type(obj[type_key]) is str
            return obj[type_key]
        elif type_key == "Int":
            assert type(obj[type_key]) is int
            return obj[type_key]
        elif type_key == "Bool":
            assert type(obj[type_key]) is bool
            return obj[type_key]
        else:
            assert False, f"Unknown type: {type_key}"

    def _host(self, socket_path: str):
        """Listen on the Unix socket at ``socket_path`` and serve each connection on its own thread.

        Loops forever; the listening socket is closed if the loop exits with an exception.
        """
        # remove a stale socket file left at the same path
        if os.path.exists(socket_path):
            os.remove(socket_path)

        # start listening on the socket
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        with sock:
            sock.bind(socket_path)
            sock.listen(1)

            while True:
                # accept incoming connection
                connection, _client_address = sock.accept()
                thread = threading.Thread(
                    target=self._handle_connection, args=(connection,)
                )
                thread.start()

    def __del__(self):
        """Terminate the host process and remove the socket file, if into_filter() was called."""
        # getattr: the finalizer also runs on instances whose __init__ never
        # ran (e.g. a subclass that did not call super().__init__()).
        socket_path = getattr(self, "_socket_path", None)
        if socket_path is None:
            return
        proc = getattr(self, "_p", None)
        if proc is not None:
            proc.terminate()
        try:
            # it's possible the process hasn't even created the socket yet
            os.remove(socket_path)
        except FileNotFoundError:
            pass
|
1321
|
+
|
1322
|
+
|
1323
|
+
class UDFFrameType:
    """
    Frame type for use in UDFs.
    """

    def __init__(self, width: int, height: int, pix_fmt: str):
        # exact-type checks, in argument order
        for value, expected in ((width, int), (height, int), (pix_fmt, str)):
            assert type(value) is expected

        self._width, self._height, self._pix_fmt = width, height, pix_fmt

    def width(self):
        return self._width

    def height(self):
        return self._height

    def pix_fmt(self):
        return self._pix_fmt

    def _response_ser(self):
        """Return the wire form of this type as ``{"frame_type": {...}}``."""
        described = {
            "width": self._width,
            "height": self._height,
            "format": 2,  # AV_PIX_FMT_RGB24
        }
        return {"frame_type": described}

    def __repr__(self):
        return "FrameType<{}x{}, {}>".format(self._width, self._height, self._pix_fmt)
|
1357
|
+
|
1358
|
+
|
1359
|
+
class UDFFrame:
    """A frame (pixel array plus its UDFFrameType) for use in UDFs."""

    def __init__(self, data: np.ndarray, f_type: UDFFrameType):
        assert type(data) is np.ndarray
        assert type(f_type) is UDFFrameType

        # Only RGB24 is supported for now: uint8 samples, 3 channels
        assert data.dtype == np.uint8
        assert data.shape[2] == 3

        # the array must agree with the declared type (rows = height, cols = width)
        expected_rows, expected_cols = f_type.height(), f_type.width()
        assert data.shape[0] == expected_rows
        assert data.shape[1] == expected_cols
        assert f_type.pix_fmt() == "rgb24"

        self._data, self._f_type = data, f_type

    def data(self):
        return self._data

    def frame_type(self):
        return self._f_type

    def _response_ser(self):
        """Return the wire form of this frame as ``{"frame": {...}}`` with raw pixel bytes."""
        kind = self._f_type
        described = {
            "data": self._data.tobytes(),
            "width": kind.width(),
            "height": kind.height(),
            "format": "rgb24",
        }
        return {"frame": described}

    def __repr__(self):
        kind = self._f_type
        return f"Frame<{kind.width()}x{kind.height()}, {kind.pix_fmt()}>"
|
1396
|
+
|
1397
|
+
|
1398
|
+
def _run_udf_host(udf: UDF, socket_path: str):
    """Run ``udf``'s host loop on ``socket_path`` (module-level so it can be a process target)."""
    serve = udf._host
    serve(socket_path)
|