vidformer 0.8.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vidformer/vf.py DELETED
@@ -1,984 +0,0 @@
1
- """
2
- vidformer-py is a Python 🐍 interface for [vidformer](https://github.com/ixlab/vidformer).
3
-
4
- **Quick links:**
5
- * [📦 PyPI](https://pypi.org/project/vidformer/)
6
- * [📘 Documentation - vidformer-py](https://ixlab.github.io/vidformer/vidformer-py/)
7
- * [📘 Documentation - vidformer.cv2](https://ixlab.github.io/vidformer/vidformer-py-cv2/)
8
- * [🧑‍💻 Source Code](https://github.com/ixlab/vidformer/tree/main/vidformer-py/)
9
- """
10
-
11
- import subprocess
12
- from fractions import Fraction
13
- import random
14
- import time
15
- import json
16
- import socket
17
- import os
18
- import sys
19
- import multiprocessing
20
- import uuid
21
- import threading
22
- import gzip
23
- import base64
24
- import re
25
-
26
- import requests
27
- import msgpack
28
- import numpy as np
29
-
30
- from . import __version__
31
-
32
- _in_notebook = False
33
- try:
34
- from IPython import get_ipython
35
-
36
- if "IPKernelApp" in get_ipython().config:
37
- _in_notebook = True
38
- except:
39
- pass
40
-
41
-
42
def _check_hls_link_exists(url, max_attempts=150, delay=0.1, timeout=5.0):
    """Poll ``url`` until it answers with HTTP 200.

    Args:
        url: The URL to request with GET.
        max_attempts: Maximum number of requests before giving up.
        delay: Seconds to sleep after each unsuccessful attempt.
        timeout: Per-request timeout in seconds, so that a connection which is
            accepted but never answered cannot block the poll forever.

    Returns:
        The stripped response body of the first 200 response, or ``None`` if
        no attempt succeeded.
    """
    for _ in range(max_attempts):
        try:
            response = requests.get(url, timeout=timeout)
        except requests.exceptions.RequestException:
            # Connection refused / timed out: the server may still be starting.
            pass
        else:
            if response.status_code == 200:
                return response.text.strip()
        time.sleep(delay)
    return None
53
-
54
-
55
class Spec:
    """
    A video transformation specification.

    A spec couples a domain (the output timestamps), a ``render`` callable
    which maps ``(timestamp, index)`` to a frame expression, and an output
    format dict with the keys ``width``, ``height`` and ``pix_fmt``.

    See https://ixlab.github.io/vidformer/concepts.html for more information.
    """

    def __init__(self, domain: list[Fraction], render, fmt: dict):
        self._domain = domain
        self._render = render
        self._fmt = fmt

    def _format_frames(self, start, stop):
        """Return ``"num/den => expr"`` lines for domain entries ``start..stop-1``.

        ``render`` receives the true position of each timestamp in the domain.
        """
        return [
            f"{t.numerator}/{t.denominator} => {self._render(t, i)}"
            for i, t in enumerate(self._domain[start:stop], start=start)
        ]

    def __repr__(self):
        count = len(self._domain)
        if count <= 20:
            return "\n".join(self._format_frames(0, count))
        # Long specs are abbreviated to the first and last ten frames. The
        # tail must be rendered with its real indices (count-10 .. count-1),
        # not 0..9, otherwise index-dependent render functions show the
        # wrong expression.
        head = self._format_frames(0, 10)
        tail = self._format_frames(count - 10, count)
        return "\n".join(head + ["..."] + tail)

    def _sources(self):
        """Return the set of all sources referenced by any frame of the spec."""
        found = set()
        for i, t in enumerate(self._domain):
            found.update(self._render(t, i)._sources())
        return found

    def _to_json_spec(self):
        """Render every frame and return ``(spec_dict, sources, filters)``.

        ``spec_dict`` is ``{"frames": [[[num, den], frame_json], ...]}``;
        ``sources`` is a set and ``filters`` a name -> filter dict collected
        from the rendered frame expressions.
        """
        frames = []
        sources = set()
        filters = {}
        for i, t in enumerate(self._domain):
            frame_expr = self._render(t, i)
            # In-place updates avoid rebuilding the accumulators per frame.
            sources.update(frame_expr._sources())
            filters.update(frame_expr._filters())
            frames.append([[t.numerator, t.denominator], frame_expr._to_json_spec()])
        return {"frames": frames}, sources, filters

    @staticmethod
    def _encode_spec(spec):
        """Serialize ``spec`` as JSON, gzip it and return it base64-encoded."""
        raw = json.dumps(spec).encode("utf-8")
        return base64.b64encode(gzip.compress(raw, compresslevel=1)).decode("utf-8")

    @staticmethod
    def _describe_sources(sources):
        """Convert source objects into the list of dicts sent to the server."""
        return [
            {
                "name": s._name,
                "path": s._path,
                "stream": s._stream,
                "service": s._service.as_json() if s._service is not None else None,
            }
            for s in sources
        ]

    @staticmethod
    def _describe_filters(filters):
        """Convert a name -> filter mapping into the dict sent to the server."""
        return {
            name: {"filter": flt._func, "args": flt._kwargs}
            for name, flt in filters.items()
        }

    def _build_request_parts(self):
        """Return ``(encoded_spec, sources, filters, arrays)`` for a server call."""
        spec, sources, filters = self._to_json_spec()
        return (
            self._encode_spec(spec),
            self._describe_sources(sources),
            self._describe_filters(filters),
            [],
        )

    def play(self, server, method="html", verbose=False):
        """Play the video live in the notebook.

        ``method`` selects the return value: ``"link"`` returns the stream
        URL, ``"player"`` the player URL, ``"iframe"`` an IPython ``IFrame``
        and ``"html"`` an IPython ``HTML`` object embedding an hls.js player.
        Any other value returns the player URL.
        """

        encoded_spec, sources, filters, arrays = self._build_request_parts()

        if verbose:
            print(f"Sending to server. Spec is {len(encoded_spec)} bytes")

        resp = server._new(encoded_spec, sources, filters, arrays, self._fmt)
        hls_video_url = resp["stream_url"]
        hls_player_url = resp["player_url"]
        namespace = resp["namespace"]
        hls_js_url = server.hls_js_url()

        if method == "link":
            return hls_video_url
        if method == "player":
            return hls_player_url
        if method == "iframe":
            from IPython.display import IFrame

            return IFrame(hls_player_url, width=1280, height=720)
        if method == "html":
            from IPython.display import HTML

            # We add a namespace to the video element to avoid conflicts with other videos
            html_code = f"""
<!DOCTYPE html>
<html>
<head>
    <title>HLS Video Player</title>
    <!-- Include hls.js library -->
    <script src="{hls_js_url}"></script>
</head>
<body>
    <!-- Video element -->
    <video id="video-{namespace}" controls width="640" height="360" autoplay></video>
    <script>
        var video = document.getElementById('video-{namespace}');
        var videoSrc = '{hls_video_url}';
        var hls = new Hls();
        hls.loadSource(videoSrc);
        hls.attachMedia(video);
        hls.on(Hls.Events.MANIFEST_PARSED, function() {{
            video.play();
        }});
    </script>
</body>
</html>
"""
            return HTML(data=html_code)
        return hls_player_url

    def load(self, server):
        """Register the spec on ``server`` and return a ``Loader`` for its frames."""
        encoded_spec, sources, filters, arrays = self._build_request_parts()
        resp = server._new(encoded_spec, sources, filters, arrays, self._fmt)
        return Loader(server, resp["namespace"], self._domain)

    def save(self, server, pth, encoder=None, encoder_opts=None, format=None):
        """Save the video to a file.

        ``encoder`` must be a string or ``None``; ``encoder_opts`` must be
        ``None`` or a dict mapping strings to strings. Returns the server's
        response to the export request.
        """

        assert encoder is None or isinstance(encoder, str)
        assert encoder_opts is None or isinstance(encoder_opts, dict)
        if encoder_opts is not None:
            for k, v in encoder_opts.items():
                assert isinstance(k, str) and isinstance(v, str)

        encoded_spec, sources, filters, arrays = self._build_request_parts()
        return server._export(
            pth,
            encoded_spec,
            sources,
            filters,
            arrays,
            self._fmt,
            encoder,
            encoder_opts,
            format,
        )

    def _vrod_bench(self, server):
        """Time spec creation, registration and first-segment fetch.

        Writes the JSON spec to ``spec.json`` in the working directory and
        passes that path to the server in place of an encoded spec. Returns a
        dict of elapsed seconds keyed by phase.
        """
        out = {}
        pth = "spec.json"
        start_t = time.time()
        with open(pth, "w") as outfile:
            spec, sources, filters = self._to_json_spec()
            outfile.write(json.dumps(spec))

        sources = self._describe_sources(sources)
        filters = self._describe_filters(filters)
        arrays = []
        end_t = time.time()
        out["vrod_create_spec"] = end_t - start_t

        start = time.time()
        resp = server._new(pth, sources, filters, arrays, self._fmt)
        end = time.time()
        out["vrod_register"] = end - start

        stream_url = resp["stream_url"]
        first_segment = stream_url.replace("stream.m3u8", "segment-0.ts")

        start = time.time()
        r = requests.get(first_segment)
        r.raise_for_status()
        end = time.time()
        out["vrod_first_segment"] = end - start
        return out

    def _dve2_bench(self, server):
        """Time spec creation and a full export.

        Writes the JSON spec to ``spec.json`` in the working directory.
        Returns a dict of elapsed seconds keyed by phase.
        """
        pth = "spec.json"
        out = {}
        start_t = time.time()
        with open(pth, "w") as outfile:
            spec, sources, filters = self._to_json_spec()
            outfile.write(json.dumps(spec))

        sources = self._describe_sources(sources)
        filters = self._describe_filters(filters)
        arrays = []
        end_t = time.time()
        out["dve2_create_spec"] = end_t - start_t

        start = time.time()
        # NOTE(review): this passes 7 positional arguments; a server whose
        # _export takes (pth, spec, sources, filters, arrays, fmt, encoder,
        # encoder_opts, format) will reject the call. The intended output path
        # is not visible here, so the call is left as-is -- confirm and fix.
        server._export(pth, sources, filters, arrays, self._fmt, None, None)
        end = time.time()
        out["dve2_exec"] = end - start
        return out
335
-
336
-
337
class Loader:
    """Index-based access to the raw frame bytes of a spec held by a server.

    ``domain`` is the list of timestamps of the spec; its length defines the
    number of addressable frames.
    """

    def __init__(self, server, namespace: str, domain):
        self._server = server
        self._namespace = namespace
        self._domain = domain

    def _chunk(self, start_i, end_i):
        """Fetch the raw bytes for frames ``start_i..end_i`` (both inclusive)."""
        return self._server._raw(self._namespace, start_i, end_i)

    def __len__(self):
        return len(self._domain)

    def _find_index_by_rational(self, value):
        """Return the position of timestamp ``value`` in the domain.

        Raises:
            ValueError: if ``value`` is not part of the domain.
        """
        try:
            # A single scan instead of a membership test followed by index().
            return self._domain.index(value)
        except ValueError:
            raise ValueError(
                f"Rational timestamp {value} is not in the domain"
            ) from None

    def __getitem__(self, index):
        """Return raw frame bytes for an integer, or a list of them for a slice.

        Slices must use non-negative bounds within the domain; the slice step
        is ignored. An empty slice returns ``[]`` without contacting the
        server.
        """
        if isinstance(index, slice):
            total = len(self._domain)
            start = index.start if index.start is not None else 0
            end = index.stop if index.stop is not None else total
            assert 0 <= start <= end <= total
            num_frames = end - start
            if num_frames == 0:
                # Previously this requested the range start..start-1 and then
                # divided by zero.
                return []
            all_bytes = self._chunk(start, end - 1)
            # The server returns all frames concatenated; they are equally sized.
            assert len(all_bytes) % num_frames == 0
            frame_size = len(all_bytes) // num_frames
            return [
                all_bytes[i * frame_size : (i + 1) * frame_size]
                for i in range(num_frames)
            ]
        elif isinstance(index, int):
            assert index >= 0 and index < len(self._domain)
            return self._chunk(index, index)
        else:
            raise TypeError(
                "Invalid argument type for iloc. Use a slice or an integer."
            )
382
-
383
-
384
class YrdenServer:
    """
    A connection to a Yrden server.

    A yrden server is the main API for local use of vidformer.
    """

    def __init__(self, domain=None, port=None, bin=None, hls_prefix=None):
        """
        Connect to a Yrden server

        With a ``port``, an already running server at ``domain:port`` is used.
        Without one, a new server process is launched on a random port on
        localhost, using ``bin``, else the `VIDFORMER_BIN` environment
        variable, else ``vidformer-cli``.
        """

        self._domain = domain
        self._port = port
        self._proc = None

        if self._port is None:
            if bin is None:
                env_bin = os.getenv("VIDFORMER_BIN")
                bin = env_bin if env_bin is not None else "vidformer-cli"

            self._domain = "localhost"
            self._port = random.randint(49152, 65535)
            cmd = [bin, "yrden", "--port", str(self._port)]

            # We need to print the URL in the notebook
            # This is a trick to get VS Code to forward the port
            if _in_notebook:
                cmd.append("--print-url")

            if hls_prefix is not None:
                if type(hls_prefix) != str:
                    raise Exception("hls_prefix must be a string")
                cmd.extend(["--hls-prefix", hls_prefix])

            self._proc = subprocess.Popen(cmd)

        # The root endpoint answers with the server's version string.
        version = _check_hls_link_exists(f"http://{self._domain}:{self._port}/")
        if version is None:
            raise Exception("Failed to connect to server")

        expected_version = f"vidformer-yrden v{__version__}"
        if version != expected_version:
            print(
                f"Warning: Expected version `{expected_version}`, got `{version}`. API may not be compatible!"
            )

    def _source(self, name: str, path: str, stream: int, service):
        """POST a source description and return the reply with ``ts`` as Fractions."""
        payload = {
            "name": name,
            "path": path,
            "stream": stream,
            "service": None if service is None else service.as_json(),
        }
        r = requests.post(f"http://{self._domain}:{self._port}/source", json=payload)
        if not r.ok:
            raise Exception(r.text)

        resp = r.json()
        resp["ts"] = [Fraction(pair[0], pair[1]) for pair in resp["ts"]]
        return resp

    def _new(self, spec, sources, filters, arrays, fmt):
        """POST a spec to ``/new`` and return the decoded JSON reply."""
        payload = {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
        }
        r = requests.post(f"http://{self._domain}:{self._port}/new", json=payload)
        if r.ok:
            return r.json()
        raise Exception(r.text)

    def _export(
        self, pth, spec, sources, filters, arrays, fmt, encoder, encoder_opts, format
    ):
        """POST an export request to ``/export`` and return the decoded JSON reply."""
        payload = {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
            "output_path": pth,
            "encoder": encoder,
            "encoder_opts": encoder_opts,
            "format": format,
        }
        r = requests.post(f"http://{self._domain}:{self._port}/export", json=payload)
        if r.ok:
            return r.json()
        raise Exception(r.text)

    def _raw(self, namespace, start_i, end_i):
        """GET the raw bytes of frames ``start_i-end_i`` within ``namespace``."""
        url = f"http://{self._domain}:{self._port}/{namespace}/raw/{start_i}-{end_i}"
        r = requests.get(url)
        if r.ok:
            return r.content
        raise Exception(r.text)

    def hls_js_url(self):
        """Return the link to the yrden-hosted hls.js file"""
        return f"http://{self._domain}:{self._port}/hls.js"

    def __del__(self):
        # Only a server process started by this object is killed.
        proc = self._proc
        if proc is not None:
            proc.kill()
506
-
507
-
508
class SourceExpr:
    """A reference to one frame of a source, by position or by timestamp."""

    def __init__(self, source, idx, is_iloc):
        self._source = source
        self._idx = idx
        # True: ``idx`` is an integer position; False: ``idx`` is a timestamp.
        self._is_iloc = is_iloc

    def __repr__(self):
        accessor = ".iloc" if self._is_iloc else ""
        return f"{self._source._name}{accessor}[{self._idx}]"

    def _to_json_spec(self):
        """Return the JSON form of this frame reference."""
        if self._is_iloc:
            index = {"ILoc": int(self._idx)}
        else:
            index = {"T": [self._idx.numerator, self._idx.denominator]}
        return {"Source": {"video": self._source._name, "index": index}}

    def _sources(self):
        """Return a one-element set holding the referenced source."""
        return {self._source}

    def _filters(self):
        """A plain source reference involves no filters."""
        return {}
541
-
542
-
543
class SourceILoc:
    """Integer-position indexer for a source (``source.iloc[i]``)."""

    def __init__(self, source):
        self._source = source

    def __getitem__(self, idx):
        # Exact type match: bools and other int-like values are rejected.
        if type(idx) is int:
            return SourceExpr(self._source, idx, True)
        raise Exception(f"Source iloc index must be an integer, got a {type(idx)}")
551
-
552
-
553
class Source:
    """A video source."""

    def __init__(
        self, server: YrdenServer, name: str, path: str, stream: int, service=None
    ):
        # Without an explicit service, an http(s) URL is split automatically:
        # "https://host/some/file.mp4" becomes an "http" storage service with
        # endpoint "https://host" and the path "some/file.mp4".
        if service is None:
            url_parts = re.match(r"(http|https)://([^/]+)(.*)", path)
            if url_parts is not None:
                scheme, host, remainder = url_parts.groups()
                endpoint = f"{scheme}://{host}"
                # remove leading slash
                path = remainder[1:] if remainder.startswith("/") else remainder
                service = StorageService("http", endpoint=endpoint)

        self._server = server
        self._name = name
        self._path = path
        self._stream = stream
        self._service = service

        self.iloc = SourceILoc(self)

        # Register the source with the server and keep its reply.
        self._src = self._server._source(
            self._name, self._path, self._stream, self._service
        )

    def fmt(self):
        """Return the source's ``width``, ``height`` and ``pix_fmt`` as a dict."""
        return {key: self._src[key] for key in ("width", "height", "pix_fmt")}

    def ts(self):
        """Return the source's timestamps as reported by the server."""
        return self._src["ts"]

    def __len__(self):
        return len(self._src["ts"])

    def __getitem__(self, idx):
        if type(idx) is Fraction:
            return SourceExpr(self, idx, False)
        raise Exception("Source index must be a Fraction")

    def play(self, *args, **kwargs):
        """Play the video live in the notebook."""

        def render(t, i):
            # Identity spec: each output frame is the source frame at t.
            return self[t]

        spec = Spec(self.ts(), render, self.fmt())
        return spec.play(*args, **kwargs)
611
-
612
-
613
class StorageService:
    """A named storage service plus its string-valued configuration."""

    def __init__(self, service: str, **kwargs):
        if type(service) is not str:
            raise Exception("Service name must be a string")
        self._service = service

        # Every configuration value has to be exactly a str.
        offending = [key for key, value in kwargs.items() if type(value) is not str]
        if offending:
            raise Exception(f"Value of {offending[0]} must be a string")
        self._config = kwargs

    def as_json(self):
        """Return the ``{"service": ..., "config": ...}`` dict form."""
        return {"service": self._service, "config": self._config}

    def __repr__(self):
        return f"{self._service}(config={self._config})"
628
-
629
-
630
def _json_arg(arg, skip_data_anot=False):
    """Convert a filter argument into its tagged JSON form.

    Frame expressions become ``{"Frame": ...}``. Plain values become
    ``{"<Tag>": value}``, wrapped in ``{"Data": ...}`` unless
    ``skip_data_anot`` is set (as it is for the elements of a list/tuple).
    Types are matched exactly; any other type raises an ``Exception``.
    """
    kind = type(arg)

    if kind is FilterExpr or kind is SourceExpr:
        return {"Frame": arg._to_json_spec()}

    if kind is int:
        tagged = {"Int": arg}
    elif kind is str:
        tagged = {"String": arg}
    elif kind is bytes:
        tagged = {"Bytes": list(arg)}
    elif kind is float:
        tagged = {"Float": arg}
    elif kind is bool:
        tagged = {"Bool": arg}
    elif kind is tuple or kind is list:
        tagged = {"List": [_json_arg(item, True) for item in arg]}
    else:
        raise Exception(f"Unknown arg type: {type(arg)}")

    return tagged if skip_data_anot else {"Data": tagged}
660
-
661
-
662
class Filter:
    """A video filter."""

    def __init__(self, name: str, tl_func=None, **kwargs):
        self._name = name

        # tl_func is the top level func, which is the true implementation, not just a pretty name
        self._func = name if tl_func is None else tl_func

        # filter infra args, not invocation args
        offending = [key for key, value in kwargs.items() if type(value) is not str]
        if offending:
            raise Exception(f"Value of {offending[0]} must be a string")
        self._kwargs = kwargs

    def __call__(self, *args, **kwargs):
        """Apply the filter to arguments, producing a frame expression."""
        return FilterExpr(self, args, kwargs)
682
-
683
-
684
class FilterExpr:
    """The application of a filter to positional and keyword arguments."""

    def __init__(self, filter: Filter, args, kwargs):
        self._filter = filter
        self._args = args
        self._kwargs = kwargs

    def __repr__(self):
        def show(value):
            # Strings are quoted, everything else uses str().
            return f'"{value}"' if type(value) is str else str(value)

        parts = [show(arg) for arg in self._args]
        parts.extend(f"{key}={show(value)}" for key, value in self._kwargs.items())
        return f"{self._filter._name}({', '.join(parts)})"

    def _to_json_spec(self):
        """Return the JSON form of this filter application."""
        positional = [_json_arg(arg) for arg in self._args]
        keyword = {key: _json_arg(value) for key, value in self._kwargs.items()}
        return {
            "Filter": {
                "name": self._filter._name,
                "args": positional,
                "kwargs": keyword,
            }
        }

    def _sources(self):
        """Collect the sources of all directly nested frame expressions."""
        found = set()
        for arg in (*self._args, *self._kwargs.values()):
            if type(arg) is FilterExpr or type(arg) is SourceExpr:
                found |= arg._sources()
        return found

    def _filters(self):
        """Collect this filter and those of directly nested filter expressions."""
        found = {self._filter._name: self._filter}
        for arg in (*self._args, *self._kwargs.values()):
            if type(arg) is FilterExpr:
                found.update(arg._filters())
        return found
728
-
729
-
730
class UDF:
    """User-defined filter superclass

    Subclasses implement ``filter`` and ``filter_type``. ``into_filter``
    starts a separate process which serves both over a Unix domain socket
    using length-prefixed msgpack messages.
    """

    def __init__(self, name: str):
        self._name = name
        self._socket_path = None
        self._p = None

    def filter(self, *args, **kwargs):
        raise Exception("User must implement the filter method")

    def filter_type(self, *args, **kwargs):
        raise Exception("User must implement the filter_type method")

    def into_filter(self):
        """Start the host process and return a ``Filter`` bound to its socket.

        May only be called once per instance.
        """
        assert self._socket_path is None
        self._socket_path = f"/tmp/vidformer-{self._name}-{str(uuid.uuid4())}.sock"
        self._p = multiprocessing.Process(
            target=_run_udf_host, args=(self, self._socket_path)
        )
        self._p.start()
        return Filter(
            name=self._name, tl_func="IPC", socket=self._socket_path, func=self._name
        )

    @staticmethod
    def _recv_exact(connection, n):
        """Read up to ``n`` bytes, looping until ``n`` arrived or the peer closed.

        ``recv`` may legitimately return fewer bytes than requested, so a
        single call is not enough even for the 4-byte length header. The
        result is shorter than ``n`` only if the connection was closed.
        """
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = connection.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    @staticmethod
    def _kwarg_items(f_kwargs):
        """Return ``(key, value)`` pairs for a dict or an iterable of pairs."""
        # Iterating a dict directly yields only its keys, so dicts need .items().
        return f_kwargs.items() if isinstance(f_kwargs, dict) else f_kwargs

    def _handle_connection(self, connection):
        """Serve requests on one connection until the peer disconnects.

        Each message is a 4-byte big-endian length followed by a msgpack map
        with the keys ``func``, ``op``, ``args`` and ``kwargs``; the reply
        uses the same framing. The connection is always closed on exit.
        """
        try:
            while True:
                header = self._recv_exact(connection, 4)
                if len(header) != 4:
                    # Peer closed the connection (possibly mid-header).
                    break
                frame_len = int.from_bytes(header, byteorder="big")
                data = self._recv_exact(connection, frame_len)
                if not data:
                    break
                if len(data) < frame_len:
                    raise Exception("Partial data received")

                obj = msgpack.unpackb(data, raw=False)
                # "func" is required to be present but is not used further.
                f_func, f_op, f_args, f_kwargs = (
                    obj["func"],
                    obj["op"],
                    obj["args"],
                    obj["kwargs"],
                )

                if f_op == "filter":
                    f_args = [self._deser_filter(x) for x in f_args]
                    f_kwargs = {
                        k: self._deser_filter(v)
                        for k, v in self._kwarg_items(f_kwargs)
                    }
                    response = self.filter(*f_args, **f_kwargs)
                    if type(response) != UDFFrame:
                        raise Exception(
                            f"filter must return a UDFFrame, got {type(response)}"
                        )
                    if response.frame_type().pix_fmt() != "rgb24":
                        raise Exception(
                            f"filter must return a frame with pix_fmt 'rgb24', got {response.frame_type().pix_fmt()}"
                        )
                    response = response._response_ser()
                elif f_op == "filter_type":
                    f_args = [self._deser_filter_type(x) for x in f_args]
                    f_kwargs = {
                        k: self._deser_filter_type(v)
                        for k, v in self._kwarg_items(f_kwargs)
                    }
                    response = self.filter_type(*f_args, **f_kwargs)
                    if type(response) != UDFFrameType:
                        raise Exception(
                            f"filter_type must return a UDFFrameType, got {type(response)}"
                        )
                    if response.pix_fmt() != "rgb24":
                        raise Exception(
                            f"filter_type must return a frame with pix_fmt 'rgb24', got {response.pix_fmt()}"
                        )
                    response = response._response_ser()
                else:
                    raise Exception(f"Unknown operation: {f_op}")

                response = msgpack.packb(response, use_bin_type=True)
                response_len = len(response).to_bytes(4, byteorder="big")
                connection.sendall(response_len)
                connection.sendall(response)
        finally:
            connection.close()

    def _deser_filter_type(self, obj):
        """Decode one single-key tagged value of a ``filter_type`` request.

        ``FrameType`` values become ``UDFFrameType``; ``String``, ``Int`` and
        ``Bool`` values are returned unchanged after a type check.
        """
        assert type(obj) == dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["FrameType", "String", "Int", "Bool"]

        if type_key == "FrameType":
            frame = obj[type_key]
            assert type(frame) == dict
            assert "width" in frame
            assert "height" in frame
            assert "format" in frame
            assert type(frame["width"]) == int
            assert type(frame["height"]) == int
            assert frame["format"] == 2  # AV_PIX_FMT_RGB24
            return UDFFrameType(frame["width"], frame["height"], "rgb24")
        elif type_key == "String":
            assert type(obj[type_key]) == str
            return obj[type_key]
        elif type_key == "Int":
            assert type(obj[type_key]) == int
            return obj[type_key]
        elif type_key == "Bool":
            assert type(obj[type_key]) == bool
            return obj[type_key]
        else:
            assert False, f"Unknown type: {type_key}"

    def _deser_filter(self, obj):
        """Decode one single-key tagged value of a ``filter`` request.

        ``Frame`` values become a ``UDFFrame`` backed by a
        ``(height, width, 3)`` uint8 array; ``String``, ``Int`` and ``Bool``
        values are returned unchanged after a type check.
        """
        assert type(obj) == dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["Frame", "String", "Int", "Bool"]

        if type_key == "Frame":
            frame = obj[type_key]
            assert type(frame) == dict
            assert "data" in frame
            assert "width" in frame
            assert "height" in frame
            assert "format" in frame
            assert type(frame["width"]) == int
            assert type(frame["height"]) == int
            assert frame["format"] == "rgb24"
            assert type(frame["data"]) == bytes

            data = np.frombuffer(frame["data"], dtype=np.uint8)
            data = data.reshape(frame["height"], frame["width"], 3)
            return UDFFrame(
                data, UDFFrameType(frame["width"], frame["height"], "rgb24")
            )
        elif type_key == "String":
            assert type(obj[type_key]) == str
            return obj[type_key]
        elif type_key == "Int":
            assert type(obj[type_key]) == int
            return obj[type_key]
        elif type_key == "Bool":
            assert type(obj[type_key]) == bool
            return obj[type_key]
        else:
            assert False, f"Unknown type: {type_key}"

    def _host(self, socket_path: str):
        """Listen on ``socket_path`` forever, one handler thread per connection."""
        if os.path.exists(socket_path):
            os.remove(socket_path)

        # start listening on the socket
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.bind(socket_path)
        sock.listen(1)

        while True:
            # accept incoming connection
            connection, client_address = sock.accept()
            thread = threading.Thread(
                target=self._handle_connection, args=(connection,)
            )
            thread.start()

    def __del__(self):
        # getattr guards against subclasses whose __init__ never called
        # UDF.__init__, in which case the attributes do not exist.
        socket_path = getattr(self, "_socket_path", None)
        if socket_path is not None:
            proc = getattr(self, "_p", None)
            if proc is not None:
                proc.terminate()
            if os.path.exists(socket_path):
                # it's possible the process hasn't even created the socket yet
                os.remove(socket_path)
906
-
907
-
908
class UDFFrameType:
    """
    Frame type for use in UDFs.

    Holds a width, a height and a pixel format name.
    """

    def __init__(self, width: int, height: int, pix_fmt: str):
        # Exact type matches are required for all three fields.
        assert type(width) is int
        assert type(height) is int
        assert type(pix_fmt) is str

        self._width, self._height, self._pix_fmt = width, height, pix_fmt

    def width(self):
        return self._width

    def height(self):
        return self._height

    def pix_fmt(self):
        return self._pix_fmt

    def _response_ser(self):
        """Return the dict form used as the reply to a ``filter_type`` request."""
        frame_type = {
            "width": self._width,
            "height": self._height,
            "format": 2,  # AV_PIX_FMT_RGB24
        }
        return {"frame_type": frame_type}

    def __repr__(self):
        return f"FrameType<{self._width}x{self._height}, {self._pix_fmt}>"
942
-
943
-
944
class UDFFrame:
    """A symbolic reference to a frame for use in UDFs."""

    def __init__(self, data: np.ndarray, f_type: UDFFrameType):
        assert type(data) is np.ndarray
        assert type(f_type) is UDFFrameType

        # We only support RGB24 for now
        assert data.dtype == np.uint8
        assert data.shape[2] == 3

        # check type matches: the array is laid out as (height, width, 3)
        assert data.shape[0] == f_type.height()
        assert data.shape[1] == f_type.width()
        assert f_type.pix_fmt() == "rgb24"

        self._data = data
        self._f_type = f_type

    def data(self):
        return self._data

    def frame_type(self):
        return self._f_type

    def _response_ser(self):
        """Return the dict form used as the reply to a ``filter`` request."""
        frame_type = self._f_type
        payload = {
            "data": self._data.tobytes(),
            "width": frame_type.width(),
            "height": frame_type.height(),
            "format": "rgb24",
        }
        return {"frame": payload}

    def __repr__(self):
        return f"Frame<{self._f_type.width()}x{self._f_type.height()}, {self._f_type.pix_fmt()}>"
981
-
982
-
983
def _run_udf_host(udf: UDF, socket_path: str):
    """Process entry point: serve ``udf`` on ``socket_path``.

    Kept as a module-level function so it can be handed to
    ``multiprocessing.Process`` as its target.
    """
    host = udf._host
    host(socket_path)