vidformer 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vidformer/vf.py DELETED
@@ -1,1112 +0,0 @@
1
- """
2
- vidformer-py is a Python 🐍 interface for [vidformer](https://github.com/ixlab/vidformer).
3
-
4
- **Quick links:**
5
- * [📦 PyPI](https://pypi.org/project/vidformer/)
6
- * [📘 Documentation - vidformer-py](https://ixlab.github.io/vidformer/vidformer-py/)
7
- * [📘 Documentation - vidformer.cv2](https://ixlab.github.io/vidformer/vidformer-py-cv2/)
8
- * [🧑‍💻 Source Code](https://github.com/ixlab/vidformer/tree/main/vidformer-py/)
9
- """
10
-
11
- import subprocess
12
- from fractions import Fraction
13
- import random
14
- import time
15
- import json
16
- import socket
17
- import os
18
- import sys
19
- import multiprocessing
20
- import uuid
21
- import threading
22
- import gzip
23
- import base64
24
- import re
25
-
26
- import requests
27
- import msgpack
28
- import numpy as np
29
-
30
- from . import __version__
31
-
32
# Detect whether this module is running inside an IPython kernel (notebook).
# The flag is read later when deciding which CLI flags to pass to the server.
_in_notebook = False
try:
    from IPython import get_ipython

    if "IPKernelApp" in get_ipython().config:
        _in_notebook = True
except Exception:
    # Best effort: IPython may be missing (ImportError), or get_ipython() may
    # not return a shell object here (AttributeError on ``.config``). Either
    # way we are not in a notebook. Unlike a bare ``except:``, this lets
    # KeyboardInterrupt and SystemExit propagate.
    pass
40
-
41
-
42
def _check_hls_link_exists(url, max_attempts=150, delay=0.1, timeout=5.0):
    """Poll ``url`` until it answers with HTTP 200 and return the body.

    Args:
        url: Address to probe with a GET request.
        max_attempts: Maximum number of requests before giving up.
        delay: Seconds to sleep after each failed attempt.
        timeout: Per-request timeout in seconds. Without it a single stalled
            connection would block the polling loop indefinitely.

    Returns:
        The stripped response text of the first 200 response, or ``None`` if
        every attempt failed.
    """
    for _ in range(max_attempts):
        try:
            response = requests.get(url, timeout=timeout)
        except requests.exceptions.RequestException:
            # Connection refused, timeout, etc.: treat like a non-200 reply.
            response = None
        if response is not None and response.status_code == 200:
            return response.text.strip()
        time.sleep(delay)
    return None
53
-
54
-
55
def _play(namespace, hls_video_url, hls_js_url, method="html", status_url=None):
    """Build a notebook player (or a plain link) for an HLS stream.

    Args:
        namespace: Unique suffix for DOM element ids, so multiple videos in
            one tab don't conflict.
        hls_video_url: URL of the HLS stream to play.
        hls_js_url: URL of the hls.js script to load.
        method: ``"html"`` returns an ``IPython.display.HTML`` player;
            ``"link"`` returns ``hls_video_url`` unchanged.
        status_url: Optional URL that is polled (every 250 ms) and must answer
            with JSON whose ``ready`` field is truthy before the video element
            is created.

    Returns:
        An ``IPython.display.HTML`` object or the stream URL string.

    Raises:
        ValueError: If ``method`` is neither ``"html"`` nor ``"link"``.
    """
    if method == "link":
        return hls_video_url
    if method != "html":
        raise ValueError("Invalid method")

    from IPython.display import HTML

    # Each script is wrapped in an immediately-invoked function so that its
    # variables and helper functions do not become page globals that another
    # player in the same tab could overwrite.
    if not status_url:
        html_code = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>HLS Video Player</title>
            <!-- Include hls.js library -->
            <script src="{hls_js_url}"></script>
        </head>
        <body>
            <video id="video-{namespace}" controls width="640" height="360" autoplay></video>
            <script>
                (function() {{
                    var video = document.getElementById('video-{namespace}');
                    var videoSrc = '{hls_video_url}';

                    if (Hls.isSupported()) {{
                        var hls = new Hls();
                        hls.loadSource(videoSrc);
                        hls.attachMedia(video);
                        hls.on(Hls.Events.MANIFEST_PARSED, function() {{
                            video.play();
                        }});
                    }} else if (video.canPlayType('application/vnd.apple.mpegurl')) {{
                        video.src = videoSrc;
                        video.addEventListener('loadedmetadata', function() {{
                            video.play();
                        }});
                    }} else {{
                        console.error('This browser does not appear to support HLS.');
                    }}
                }})();
            </script>
        </body>
        </html>
        """
        return HTML(data=html_code)

    # Polling variant: the container id carries the namespace, otherwise every
    # player on the page would resolve to the first "container" element.
    html_code = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>HLS Video Player</title>
        <script src="{hls_js_url}"></script>
    </head>
    <body>
        <div id="container-{namespace}"></div>
        <script>
            (function() {{
                var statusUrl = '{status_url}';
                var videoSrc = '{hls_video_url}';
                var videoNamespace = '{namespace}';
                var containerId = 'container-' + videoNamespace;

                function showWaiting() {{
                    document.getElementById(containerId).textContent = 'Waiting...';
                    pollStatus();
                }}

                function pollStatus() {{
                    setTimeout(function() {{
                        fetch(statusUrl)
                            .then(r => r.json())
                            .then(res => {{
                                if (res.ready) {{
                                    document.getElementById(containerId).textContent = '';
                                    attachHls();
                                }} else {{
                                    pollStatus();
                                }}
                            }})
                            .catch(e => {{
                                console.error(e);
                                pollStatus();
                            }});
                    }}, 250);
                }}

                function attachHls() {{
                    var container = document.getElementById(containerId);
                    container.textContent = '';
                    var video = document.createElement('video');
                    video.id = 'video-' + videoNamespace;
                    video.controls = true;
                    video.width = 640;
                    video.height = 360;
                    container.appendChild(video);
                    if (Hls.isSupported()) {{
                        var hls = new Hls();
                        hls.loadSource(videoSrc);
                        hls.attachMedia(video);
                        hls.on(Hls.Events.MANIFEST_PARSED, function() {{
                            video.play();
                        }});
                    }} else if (video.canPlayType('application/vnd.apple.mpegurl')) {{
                        video.src = videoSrc;
                        video.addEventListener('loadedmetadata', function() {{
                            video.play();
                        }});
                    }}
                }}

                fetch(statusUrl)
                    .then(r => r.json())
                    .then(res => {{
                        if (res.ready) {{
                            attachHls();
                        }} else {{
                            showWaiting();
                        }}
                    }})
                    .catch(e => {{
                        console.error(e);
                        showWaiting();
                    }});
            }})();
        </script>
    </body>
    </html>
    """
    return HTML(data=html_code)
181
-
182
-
183
class Spec:
    """
    A video transformation specification.

    See https://ixlab.github.io/vidformer/concepts.html for more information.

    A spec holds a ``domain`` (list of ``Fraction`` timestamps), a ``render``
    callable invoked as ``render(t, i)`` with each timestamp and its index in
    the domain, and an output format dict ``fmt`` that is handed to the server.
    """

    def __init__(self, domain: list[Fraction], render, fmt: dict):
        self._domain = domain
        self._render = render
        self._fmt = fmt

    def _repr_line(self, i, t):
        """Format one ``timestamp => frame expression`` line for ``__repr__``."""
        frame_expr = self._render(t, i)
        return f"{t.numerator}/{t.denominator} => {frame_expr}"

    def __repr__(self):
        total = len(self._domain)
        if total <= 20:
            return "\n".join(
                self._repr_line(i, t) for i, t in enumerate(self._domain)
            )

        # Long specs: show the first and last ten frames only.
        lines = [self._repr_line(i, t) for i, t in enumerate(self._domain[:10])]
        lines.append("...")
        # Start the enumeration at the real position of the tail so render()
        # receives each frame's actual domain index, not 0..9.
        lines.extend(
            self._repr_line(i, t)
            for i, t in enumerate(self._domain[-10:], start=total - 10)
        )
        return "\n".join(lines)

    def _sources(self):
        """Return the set of sources referenced by any frame of the spec."""
        s = set()
        for i, t in enumerate(self._domain):
            s.update(self._render(t, i)._sources())
        return s

    def _to_json_spec(self):
        """Render every frame and return ``(spec_dict, sources, filters)``.

        ``spec_dict`` is ``{"frames": [[[num, den], frame_json], ...]}``;
        ``sources`` is a set and ``filters`` a dict keyed by filter name.
        """
        frames = []
        s = set()
        f = {}
        for i, t in enumerate(self._domain):
            frame_expr = self._render(t, i)
            # Update in place instead of rebuilding the set/dict per frame.
            s.update(frame_expr._sources())
            f.update(frame_expr._filters())
            frames.append([[t.numerator, t.denominator], frame_expr._to_json_spec()])
        return {"frames": frames}, s, f

    @staticmethod
    def _encode_spec(spec):
        """Return ``spec`` as JSON, gzip-compressed and base64-encoded (str)."""
        spec_json_bytes = json.dumps(spec).encode("utf-8")
        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
        return base64.b64encode(spec_obj_json_gzip).decode("utf-8")

    @staticmethod
    def _describe_sources(sources):
        """Convert source objects into the JSON dicts sent to the server."""
        return [
            {
                "name": s._name,
                "path": s._path,
                "stream": s._stream,
                "service": s._service.as_json() if s._service is not None else None,
            }
            for s in sources
        ]

    @staticmethod
    def _describe_filters(filters):
        """Convert a ``{name: filter}`` dict into the server's JSON form."""
        return {
            k: {
                "filter": v._func,
                "args": v._kwargs,
            }
            for k, v in filters.items()
        }

    def play(self, server, method="html", verbose=False):
        """Play the video live in the notebook.

        ``method`` selects the return value: ``"link"`` (stream URL),
        ``"player"`` (player URL), ``"iframe"`` (IPython ``IFrame``) or
        ``"html"`` (IPython ``HTML``). Any other value returns the player URL.
        """

        spec, sources, filters = self._to_json_spec()
        spec_obj_json_gzip_b64 = self._encode_spec(spec)
        sources = self._describe_sources(sources)
        filters = self._describe_filters(filters)
        arrays = []

        if verbose:
            print(f"Sending to server. Spec is {len(spec_obj_json_gzip_b64)} bytes")

        resp = server._new(spec_obj_json_gzip_b64, sources, filters, arrays, self._fmt)
        hls_video_url = resp["stream_url"]
        hls_player_url = resp["player_url"]
        namespace = resp["namespace"]
        hls_js_url = server.hls_js_url()

        if method == "link":
            return hls_video_url
        if method == "player":
            return hls_player_url
        if method == "iframe":
            from IPython.display import IFrame

            return IFrame(hls_player_url, width=1280, height=720)
        if method == "html":
            from IPython.display import HTML

            # We add a namespace to the video element to avoid conflicts with other videos
            html_code = f"""
<!DOCTYPE html>
<html>
<head>
    <title>HLS Video Player</title>
    <!-- Include hls.js library -->
    <script src="{hls_js_url}"></script>
</head>
<body>
    <!-- Video element -->
    <video id="video-{namespace}" controls width="640" height="360" autoplay></video>
    <script>
        var video = document.getElementById('video-{namespace}');
        var videoSrc = '{hls_video_url}';
        var hls = new Hls();
        hls.loadSource(videoSrc);
        hls.attachMedia(video);
        hls.on(Hls.Events.MANIFEST_PARSED, function() {{
            video.play();
        }});
    </script>
</body>
</html>
"""
            return HTML(data=html_code)
        else:
            return hls_player_url

    def load(self, server):
        """Register the spec with ``server`` and return a ``Loader`` over it."""
        spec, sources, filters = self._to_json_spec()
        spec_obj_json_gzip_b64 = self._encode_spec(spec)
        sources = self._describe_sources(sources)
        filters = self._describe_filters(filters)
        arrays = []

        resp = server._new(spec_obj_json_gzip_b64, sources, filters, arrays, self._fmt)
        namespace = resp["namespace"]
        return Loader(server, namespace, self._domain)

    def save(self, server, pth, encoder=None, encoder_opts=None, format=None):
        """Save the video to a file.

        ``encoder`` must be a ``str`` or ``None``; ``encoder_opts`` a
        ``dict[str, str]`` or ``None``. Returns the server's export response.
        """

        assert encoder is None or type(encoder) == str
        assert encoder_opts is None or type(encoder_opts) == dict
        if encoder_opts is not None:
            for k, v in encoder_opts.items():
                assert type(k) == str and type(v) == str

        spec, sources, filters = self._to_json_spec()
        spec_obj_json_gzip_b64 = self._encode_spec(spec)
        sources = self._describe_sources(sources)
        filters = self._describe_filters(filters)
        arrays = []

        resp = server._export(
            pth,
            spec_obj_json_gzip_b64,
            sources,
            filters,
            arrays,
            self._fmt,
            encoder,
            encoder_opts,
            format,
        )

        return resp

    def _vrod_bench(self, server):
        """Time spec creation, registration and first-segment fetch."""
        out = {}
        pth = "spec.json"
        start_t = time.time()
        with open(pth, "w") as outfile:
            spec, sources, filters = self._to_json_spec()
            outfile.write(json.dumps(spec))

        sources = self._describe_sources(sources)
        filters = self._describe_filters(filters)
        arrays = []
        end_t = time.time()
        out["vrod_create_spec"] = end_t - start_t

        start = time.time()
        # NOTE(review): this passes the spec *file path* where play()/load()
        # pass the gzip+base64 spec itself; confirm the server accepts it.
        resp = server._new(pth, sources, filters, arrays, self._fmt)
        end = time.time()
        out["vrod_register"] = end - start

        stream_url = resp["stream_url"]
        first_segment = stream_url.replace("stream.m3u8", "segment-0.ts")

        start = time.time()
        r = requests.get(first_segment)
        r.raise_for_status()
        end = time.time()
        out["vrod_first_segment"] = end - start
        return out

    def _dve2_bench(self, server):
        """Time spec creation and a full export."""
        pth = "spec.json"
        out = {}
        start_t = time.time()
        with open(pth, "w") as outfile:
            spec, sources, filters = self._to_json_spec()
            outfile.write(json.dumps(spec))

        sources = self._describe_sources(sources)
        filters = self._describe_filters(filters)
        arrays = []
        end_t = time.time()
        out["dve2_create_spec"] = end_t - start_t

        start = time.time()
        # NOTE(review): save() calls _export with nine arguments (output path,
        # spec, ..., format); this call passes seven. Left unchanged because
        # the intended output path/spec are not evident -- confirm and fix.
        resp = server._export(pth, sources, filters, arrays, self._fmt, None, None)
        end = time.time()
        out["dve2_exec"] = end - start
        return out
463
-
464
-
465
class Loader:
    """Index-based access to the raw bytes of a registered spec's frames.

    Frames are fetched through ``server._raw(namespace, start_i, end_i)``,
    which is called with an inclusive end index.
    """

    def __init__(self, server, namespace: str, domain):
        self._server = server
        self._namespace = namespace
        self._domain = domain

    def _chunk(self, start_i, end_i):
        """Fetch the raw bytes for frames ``start_i``..``end_i`` (inclusive)."""
        return self._server._raw(self._namespace, start_i, end_i)

    def __len__(self):
        return len(self._domain)

    def _find_index_by_rational(self, value):
        """Return the domain index of timestamp ``value``.

        Raises:
            ValueError: If ``value`` is not in the domain.
        """
        if value not in self._domain:
            raise ValueError(f"Rational timestamp {value} is not in the domain")
        return self._domain.index(value)

    def __getitem__(self, index):
        """Return one frame (``int`` index) or a list of frames (``slice``).

        Slices use non-negative ``start``/``stop`` only; ``step`` is ignored.
        The returned bytes of a slice are split into equally sized frames.
        """
        if isinstance(index, slice):
            total = len(self._domain)
            start = index.start if index.start is not None else 0
            end = index.stop if index.stop is not None else total
            if start == end and 0 <= start <= total:
                # Empty slice: nothing to fetch. Previously this requested the
                # range start..start-1 and then divided by zero.
                return []
            assert start >= 0 and start < total
            assert end >= 0 and end <= total
            assert start <= end
            num_frames = end - start
            all_bytes = self._chunk(start, end - 1)
            all_bytes_len = len(all_bytes)
            assert all_bytes_len % num_frames == 0
            frame_size = all_bytes_len // num_frames
            return [
                all_bytes[i * frame_size : (i + 1) * frame_size]
                for i in range(num_frames)
            ]
        elif isinstance(index, int):
            assert index >= 0 and index < len(self._domain)
            return self._chunk(index, index)
        else:
            raise TypeError(
                "Invalid argument type for iloc. Use a slice or an integer."
            )
510
-
511
-
512
class YrdenServer:
    """
    A connection to a Yrden server.

    A yrden server is the main API for local use of vidformer.
    """

    def __init__(self, domain=None, port=None, bin=None, hls_prefix=None):
        """
        Connect to a Yrden server

        Can either connect to an existing server, if domain and port are provided, or start a new server using the provided binary.
        If no domain or binary is provided, the `VIDFORMER_BIN` environment variable is used.
        """

        self._domain = domain
        self._port = port
        self._proc = None

        if self._port is None:
            # No port given: launch our own server on a random high port.
            if bin is None:
                env_bin = os.getenv("VIDFORMER_BIN")
                bin = env_bin if env_bin is not None else "vidformer-cli"

            self._domain = "localhost"
            self._port = random.randint(49152, 65535)
            cmd = [bin, "yrden", "--port", str(self._port)]
            if _in_notebook:
                # We need to print the URL in the notebook
                # This is a trick to get VS Code to forward the port
                cmd.append("--print-url")

            if hls_prefix is not None:
                if type(hls_prefix) is not str:
                    raise Exception("hls_prefix must be a string")
                cmd.extend(["--hls-prefix", hls_prefix])

            self._proc = subprocess.Popen(cmd)

        # The root endpoint answers with a version string once the server is up.
        version = _check_hls_link_exists(self._url(""))
        if version is None:
            raise Exception("Failed to connect to server")

        expected_version = f"vidformer-yrden v{__version__}"
        if version != expected_version:
            print(
                f"Warning: Expected version `{expected_version}`, got `{version}`. API may not be compatible!"
            )

    def _url(self, route):
        """Return the absolute URL of ``route`` on this server."""
        return f"http://{self._domain}:{self._port}/{route}"

    def _post_json(self, route, payload):
        """POST ``payload`` as JSON to ``route``; raise on a non-OK reply."""
        reply = requests.post(self._url(route), json=payload)
        if not reply.ok:
            raise Exception(reply.text)
        return reply.json()

    @staticmethod
    def _render_request(spec, sources, filters, arrays, fmt):
        """Build the request body shared by the ``new`` and ``export`` calls."""
        return {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
        }

    def _source(self, name: str, path: str, stream: int, service):
        """Register a source and return its description.

        The ``ts`` entry of the reply is converted to a list of ``Fraction``.
        """
        payload = {
            "name": name,
            "path": path,
            "stream": stream,
            "service": service.as_json() if service is not None else None,
        }
        info = self._post_json("source", payload)
        info["ts"] = [Fraction(pair[0], pair[1]) for pair in info["ts"]]
        return info

    def _new(self, spec, sources, filters, arrays, fmt):
        """Register a spec with the server and return the JSON reply."""
        body = self._render_request(spec, sources, filters, arrays, fmt)
        return self._post_json("new", body)

    def _export(
        self, pth, spec, sources, filters, arrays, fmt, encoder, encoder_opts, format
    ):
        """Ask the server to export a spec to ``pth``; return the JSON reply."""
        body = self._render_request(spec, sources, filters, arrays, fmt)
        body["output_path"] = pth
        body["encoder"] = encoder
        body["encoder_opts"] = encoder_opts
        body["format"] = format
        return self._post_json("export", body)

    def _raw(self, namespace, start_i, end_i):
        """Fetch the raw bytes of frames ``start_i``-``end_i`` of ``namespace``."""
        reply = requests.get(self._url(f"{namespace}/raw/{start_i}-{end_i}"))
        if not reply.ok:
            raise Exception(reply.text)
        return reply.content

    def hls_js_url(self):
        """Return the link to the yrden-hosted hls.js file"""
        return self._url("hls.js")

    def __del__(self):
        # Only kill a server process that this object started itself.
        if self._proc is not None:
            self._proc.kill()
634
-
635
-
636
class SourceExpr:
    """A reference to one frame of a source.

    The frame is addressed either by integer position (``is_iloc`` true) or
    by a timestamp exposing ``numerator``/``denominator``.
    """

    def __init__(self, source, idx, is_iloc):
        self._source = source
        self._idx = idx
        self._is_iloc = is_iloc

    def __repr__(self):
        accessor = ".iloc" if self._is_iloc else ""
        return f"{self._source._name}{accessor}[{self._idx}]"

    def _to_json_spec(self):
        """Return the JSON form: ``{"Source": {"video": ..., "index": ...}}``."""
        video = self._source._name
        if self._is_iloc:
            index = {"ILoc": int(self._idx)}
        else:
            index = {"T": [self._idx.numerator, self._idx.denominator]}
        return {"Source": {"video": video, "index": index}}

    def _sources(self):
        """Return a one-element set holding the referenced source."""
        return {self._source}

    def _filters(self):
        """A plain source frame involves no filters."""
        return {}
669
-
670
-
671
class SourceILoc:
    """Integer-position indexer: ``obj[i]`` yields an iloc ``SourceExpr``."""

    def __init__(self, source):
        self._source = source

    def __getitem__(self, idx):
        # Exact type check on purpose: bool and other int subclasses are rejected.
        if type(idx) is int:
            return SourceExpr(self._source, idx, True)
        raise Exception(f"Source iloc index must be an integer, got a {type(idx)}")
679
-
680
-
681
class Source:
    """A video source."""

    def __init__(
        self, server: "YrdenServer", name: str, path: str, stream: int, service=None
    ):
        """Register the source with ``server`` and cache its description.

        When no ``service`` is given and ``path`` is an http(s) URL, the URL is
        split into an endpoint (``scheme://host``) used to build an "http"
        ``StorageService`` and a path without its leading slash. For example
        "https://f.dominik.win/data/dve2/tos_720p.mp4" becomes endpoint
        "https://f.dominik.win" and path "data/dve2/tos_720p.mp4".
        """
        if service is None:
            url = re.match(r"(http|https)://([^/]+)(.*)", path)
            if url is not None:
                scheme, host, remainder = url.groups()
                endpoint = f"{scheme}://{host}"
                # remove leading slash
                path = remainder[1:] if remainder.startswith("/") else remainder
                service = StorageService("http", endpoint=endpoint)

        self._server = server
        self._name = name
        self._path = path
        self._stream = stream
        self._service = service

        self.iloc = SourceILoc(self)

        self._src = self._server._source(
            self._name, self._path, self._stream, self._service
        )

    def fmt(self):
        """Return the source's ``width``, ``height`` and ``pix_fmt`` as a dict."""
        return {key: self._src[key] for key in ("width", "height", "pix_fmt")}

    def ts(self):
        """Return the source's timestamps as reported by the server."""
        return self._src["ts"]

    def __len__(self):
        return len(self._src["ts"])

    def __getitem__(self, idx):
        # Exact type check: only Fraction itself is accepted as a timestamp.
        if type(idx) is not Fraction:
            raise Exception("Source index must be a Fraction")
        return SourceExpr(self, idx, False)

    def play(self, *args, **kwargs):
        """Play the video live in the notebook."""

        def render(t, i):
            return self[t]

        spec = Spec(self.ts(), render, self.fmt())
        return spec.play(*args, **kwargs)
739
-
740
-
741
class StorageService:
    """A named storage service plus its string-valued configuration."""

    def __init__(self, service: str, **kwargs):
        if type(service) is not str:
            raise Exception("Service name must be a string")
        self._service = service
        # Every configuration value must be exactly a str.
        for key in kwargs:
            if type(kwargs[key]) is not str:
                raise Exception(f"Value of {key} must be a string")
        self._config = kwargs

    def as_json(self):
        """Return ``{"service": name, "config": {...}}``."""
        return {"service": self._service, "config": self._config}

    def __repr__(self):
        return f"{self._service}(config={self._config})"
756
-
757
-
758
def _json_arg(arg, skip_data_anot=False):
    """Convert a filter argument into its JSON representation.

    Frame expressions become ``{"Frame": ...}``. Plain values are tagged by
    their exact type (``Int``, ``String``, ``Bytes``, ``Float``, ``Bool``,
    ``List``) and wrapped in ``{"Data": ...}`` unless ``skip_data_anot`` is
    true, which is how list elements are encoded.

    Raises:
        Exception: If the argument's type is not supported.
    """
    kind = type(arg)
    if kind is FilterExpr or kind is SourceExpr:
        return {"Frame": arg._to_json_spec()}

    # Exact-type dispatch: bool is not treated as int, subclasses are rejected.
    scalar_tags = {int: "Int", str: "String", float: "Float", bool: "Bool"}
    if kind in scalar_tags:
        tagged = {scalar_tags[kind]: arg}
    elif kind is bytes:
        tagged = {"Bytes": list(arg)}
    elif kind is tuple or kind is list:
        tagged = {"List": [_json_arg(item, True) for item in arg]}
    else:
        raise Exception(f"Unknown arg type: {type(arg)}")

    return tagged if skip_data_anot else {"Data": tagged}
788
-
789
-
790
class Filter:
    """A video filter."""

    def __init__(self, name: str, tl_func=None, **kwargs):
        """Create a filter called ``name``.

        ``tl_func`` is the top level func, which is the true implementation,
        not just a pretty name; it defaults to ``name``. ``kwargs`` are filter
        infra args (not invocation args) and must all be strings.
        """
        self._name = name
        self._func = name if tl_func is None else tl_func

        for key, value in kwargs.items():
            if type(value) is not str:
                raise Exception(f"Value of {key} must be a string")
        self._kwargs = kwargs

    def __call__(self, *args, **kwargs):
        """Apply the filter to arguments, producing a ``FilterExpr``."""
        return FilterExpr(self, args, kwargs)
810
-
811
-
812
class FilterExpr:
    """The application of a filter to positional and keyword arguments."""

    def __init__(self, filter: "Filter", args, kwargs):
        self._filter = filter
        self._args = args
        self._kwargs = kwargs

    def __repr__(self):
        def show(value):
            # Strings are shown double-quoted, everything else via str().
            return f'"{value}"' if type(value) == str else str(value)

        parts = [show(arg) for arg in self._args]
        parts.extend(f"{k}={show(v)}" for k, v in self._kwargs.items())
        return f"{self._filter._name}({', '.join(parts)})"

    def _to_json_spec(self):
        """Return the JSON form: ``{"Filter": {"name", "args", "kwargs"}}``."""
        args = [_json_arg(arg) for arg in self._args]
        kwargs = {k: _json_arg(v) for k, v in self._kwargs.items()}
        return {"Filter": {"name": self._filter._name, "args": args, "kwargs": kwargs}}

    def _sources(self):
        """Collect the sources of all direct frame-expression arguments."""
        found = set()
        for arg in (*self._args, *self._kwargs.values()):
            if type(arg) == FilterExpr or type(arg) == SourceExpr:
                found.update(arg._sources())
        return found

    def _filters(self):
        """Return ``{name: filter}`` for this filter and nested filter args."""
        merged = {self._filter._name: self._filter}
        for arg in (*self._args, *self._kwargs.values()):
            if type(arg) == FilterExpr:
                merged.update(arg._filters())
        return merged
856
-
857
-
858
class UDF:
    """User-defined filter superclass

    Subclasses implement ``filter`` and ``filter_type``. ``into_filter``
    starts a host process that serves those two methods over a Unix socket
    using length-prefixed msgpack messages (4-byte big-endian length).
    """

    def __init__(self, name: str):
        self._name = name
        self._socket_path = None
        self._p = None

    def filter(self, *args, **kwargs):
        raise Exception("User must implement the filter method")

    def filter_type(self, *args, **kwargs):
        raise Exception("User must implement the filter_type method")

    def into_filter(self):
        """Start the host process and return a ``Filter`` bound to its socket.

        May only be called once per instance (asserted).
        """
        assert self._socket_path is None
        self._socket_path = f"/tmp/vidformer-{self._name}-{str(uuid.uuid4())}.sock"
        self._p = multiprocessing.Process(
            target=_run_udf_host, args=(self, self._socket_path)
        )
        self._p.start()
        return Filter(
            name=self._name, tl_func="IPC", socket=self._socket_path, func=self._name
        )

    @staticmethod
    def _recv_exact(connection, n):
        """Read ``n`` bytes from ``connection``, looping over short reads.

        Returns fewer than ``n`` bytes only if the peer closed the connection
        first (``b""`` if it closed before sending anything).
        """
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = connection.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    @staticmethod
    def _kwargs_items(f_kwargs):
        """Return ``(key, value)`` pairs for a mapping or a sequence of pairs.

        Iterating a dict directly yields only its keys, so a dict must go
        through ``.items()``; an iterable of pairs is passed through as-is.
        """
        return f_kwargs.items() if isinstance(f_kwargs, dict) else f_kwargs

    def _handle_connection(self, connection):
        """Serve requests on ``connection`` until the peer disconnects.

        Raises:
            Exception: On a truncated message, an unknown operation, or when
                the user's ``filter``/``filter_type`` returns the wrong type
                or a non-"rgb24" pixel format.
        """
        try:
            while True:
                # A stream socket may deliver the 4-byte header in pieces.
                header = self._recv_exact(connection, 4)
                if len(header) != 4:
                    break
                frame_len = int.from_bytes(header, byteorder="big")
                data = self._recv_exact(connection, frame_len)
                if not data:
                    break
                if len(data) < frame_len:
                    raise Exception("Partial data received")

                obj = msgpack.unpackb(data, raw=False)
                f_func, f_op, f_args, f_kwargs = (
                    obj["func"],
                    obj["op"],
                    obj["args"],
                    obj["kwargs"],
                )

                response = None
                if f_op == "filter":
                    f_args = [self._deser_filter(x) for x in f_args]
                    f_kwargs = {
                        k: self._deser_filter(v)
                        for k, v in self._kwargs_items(f_kwargs)
                    }
                    response = self.filter(*f_args, **f_kwargs)
                    if type(response) != UDFFrame:
                        raise Exception(
                            f"filter must return a UDFFrame, got {type(response)}"
                        )
                    if response.frame_type().pix_fmt() != "rgb24":
                        raise Exception(
                            f"filter must return a frame with pix_fmt 'rgb24', got {response.frame_type().pix_fmt()}"
                        )

                    response = response._response_ser()
                elif f_op == "filter_type":
                    f_args = [self._deser_filter_type(x) for x in f_args]
                    f_kwargs = {
                        k: self._deser_filter_type(v)
                        for k, v in self._kwargs_items(f_kwargs)
                    }
                    response = self.filter_type(*f_args, **f_kwargs)
                    if type(response) != UDFFrameType:
                        raise Exception(
                            f"filter_type must return a UDFFrameType, got {type(response)}"
                        )
                    if response.pix_fmt() != "rgb24":
                        raise Exception(
                            f"filter_type must return a frame with pix_fmt 'rgb24', got {response.pix_fmt()}"
                        )
                    response = response._response_ser()
                else:
                    raise Exception(f"Unknown operation: {f_op}")

                response = msgpack.packb(response, use_bin_type=True)
                response_len = len(response).to_bytes(4, byteorder="big")
                connection.sendall(response_len)
                connection.sendall(response)
        finally:
            connection.close()

    def _deser_filter_type(self, obj):
        """Decode one ``filter_type`` argument: a single-key tagged dict."""
        assert type(obj) == dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["FrameType", "String", "Int", "Bool"]

        if type_key == "FrameType":
            frame = obj[type_key]
            assert type(frame) == dict
            assert "width" in frame
            assert "height" in frame
            assert "format" in frame
            assert type(frame["width"]) == int
            assert type(frame["height"]) == int
            assert frame["format"] == 2  # AV_PIX_FMT_RGB24
            return UDFFrameType(frame["width"], frame["height"], "rgb24")
        elif type_key == "String":
            assert type(obj[type_key]) == str
            return obj[type_key]
        elif type_key == "Int":
            assert type(obj[type_key]) == int
            return obj[type_key]
        elif type_key == "Bool":
            assert type(obj[type_key]) == bool
            return obj[type_key]
        else:
            assert False, f"Unknown type: {type_key}"

    def _deser_filter(self, obj):
        """Decode one ``filter`` argument: a single-key tagged dict.

        A ``Frame`` carries raw bytes that are reshaped to
        ``(height, width, 3)`` uint8 and wrapped in a ``UDFFrame``.
        """
        assert type(obj) == dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["Frame", "String", "Int", "Bool"]

        if type_key == "Frame":
            frame = obj[type_key]
            assert type(frame) == dict
            assert "data" in frame
            assert "width" in frame
            assert "height" in frame
            assert "format" in frame
            assert type(frame["width"]) == int
            assert type(frame["height"]) == int
            assert frame["format"] == "rgb24"
            assert type(frame["data"]) == bytes

            data = np.frombuffer(frame["data"], dtype=np.uint8)
            data = data.reshape(frame["height"], frame["width"], 3)
            return UDFFrame(
                data, UDFFrameType(frame["width"], frame["height"], "rgb24")
            )
        elif type_key == "String":
            assert type(obj[type_key]) == str
            return obj[type_key]
        elif type_key == "Int":
            assert type(obj[type_key]) == int
            return obj[type_key]
        elif type_key == "Bool":
            assert type(obj[type_key]) == bool
            return obj[type_key]
        else:
            assert False, f"Unknown type: {type_key}"

    def _host(self, socket_path: str):
        """Listen on ``socket_path`` forever, one thread per connection."""
        if os.path.exists(socket_path):
            os.remove(socket_path)

        # start listening on the socket
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.bind(socket_path)
        sock.listen(1)

        while True:
            # accept incoming connection
            connection, client_address = sock.accept()
            thread = threading.Thread(
                target=self._handle_connection, args=(connection,)
            )
            thread.start()

    def __del__(self):
        # getattr: __del__ can run on an instance whose __init__ never ran.
        socket_path = getattr(self, "_socket_path", None)
        if socket_path is not None:
            self._p.terminate()
            if os.path.exists(socket_path):
                # it's possible the process hasn't even created the socket yet
                os.remove(socket_path)
1034
-
1035
-
1036
class UDFFrameType:
    """
    Frame type for use in UDFs.

    Holds a width, a height and a pixel-format name.
    """

    def __init__(self, width: int, height: int, pix_fmt: str):
        # Exact type checks, performed in argument order.
        for value, expected in ((width, int), (height, int), (pix_fmt, str)):
            assert type(value) == expected

        self._width = width
        self._height = height
        self._pix_fmt = pix_fmt

    def width(self):
        """Return the frame width."""
        return self._width

    def height(self):
        """Return the frame height."""
        return self._height

    def pix_fmt(self):
        """Return the pixel-format name."""
        return self._pix_fmt

    def _response_ser(self):
        """Return the wire form; the format is always sent as code 2."""
        frame_type = {
            "width": self._width,
            "height": self._height,
            "format": 2,  # AV_PIX_FMT_RGB24
        }
        return {"frame_type": frame_type}

    def __repr__(self):
        return f"FrameType<{self._width}x{self._height}, {self._pix_fmt}>"
1070
-
1071
-
1072
class UDFFrame:
    """A symbolic reference to a frame for use in UDFs."""

    def __init__(self, data: np.ndarray, f_type: UDFFrameType):
        """Wrap ``data`` after checking it against ``f_type``.

        ``data`` must be a uint8 array whose first three dimensions are
        ``(f_type.height(), f_type.width(), 3)``, and ``f_type`` must be
        "rgb24". Violations raise ``AssertionError``.
        """
        assert type(data) == np.ndarray
        assert type(f_type) == UDFFrameType

        # We only support RGB24 for now
        assert data.dtype == np.uint8
        assert data.shape[2] == 3

        # check type matches
        shape = data.shape
        assert shape[0] == f_type.height()
        assert shape[1] == f_type.width()
        assert f_type.pix_fmt() == "rgb24"

        self._data = data
        self._f_type = f_type

    def data(self):
        """Return the wrapped array."""
        return self._data

    def frame_type(self):
        """Return the frame's ``UDFFrameType``."""
        return self._f_type

    def _response_ser(self):
        """Return the wire form with the pixel data as raw bytes."""
        frame_type = self._f_type
        payload = {
            "data": self._data.tobytes(),
            "width": frame_type.width(),
            "height": frame_type.height(),
            "format": "rgb24",
        }
        return {"frame": payload}

    def __repr__(self):
        frame_type = self._f_type
        return f"Frame<{frame_type.width()}x{frame_type.height()}, {frame_type.pix_fmt()}>"
1109
-
1110
-
1111
def _run_udf_host(udf: "UDF", socket_path: str):
    """Process entry point: make ``udf`` serve requests on ``socket_path``.

    Kept at module level so it can be used as a ``multiprocessing.Process``
    target.
    """
    udf._host(socket_path)