vidformer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
1
+ Metadata-Version: 2.1
2
+ Name: vidformer
3
+ Version: 0.1.0
4
+ Summary: A Python library for creating and viewing videos with vidformer.
5
+ Author-email: Dominik Winecki <dominikwinecki@gmail.com>
6
+ Requires-Python: >=3.8
7
+ Description-Content-Type: text/markdown
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Operating System :: OS Independent
10
+ Requires-Dist: requests
11
+ Requires-Dist: msgpack
12
+ Requires-Dist: numpy
13
+
14
+ # vidformer-py
15
+
@@ -0,0 +1,4 @@
1
+ vidformer.py,sha256=bLGZ4zQVkU6etFj0jWuJM8h0sJqcZXwkqA0VfdtCyk0,23257
2
+ vidformer-0.1.0.dist-info/WHEEL,sha256=EZbGkh7Ie4PoZfRQ8I0ZuP9VklN_TvcZ6DSE5Uar4z4,81
3
+ vidformer-0.1.0.dist-info/METADATA,sha256=2YNtgTcee8N9LM8lzMvNFLndbcNy88Fqk6mwFYhz8pQ,427
4
+ vidformer-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: flit 3.9.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
vidformer.py ADDED
@@ -0,0 +1,759 @@
1
+ """A Python library for creating and viewing videos with vidformer."""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ import subprocess
6
+ from fractions import Fraction
7
+ import random
8
+ import time
9
+ import json
10
+ import socket
11
+ import os
12
+ import sys
13
+ import multiprocessing
14
+ import uuid
15
+ import threading
16
+
17
+ import requests
18
+ import msgpack
19
+ import numpy as np
20
+
21
# Best-effort detection of a Jupyter/IPython kernel. Any failure here (IPython not
# installed, or no usable IPython session behind get_ipython()) simply means
# "not in a notebook".
_in_notebook = False
try:
    from IPython import get_ipython

    if "IPKernelApp" in get_ipython().config:
        _in_notebook = True
except Exception:
    # Deliberately broad so this probe never breaks importing the module, but it
    # no longer swallows KeyboardInterrupt / SystemExit like a bare ``except:``.
    pass
29
+
30
+
31
def _check_hls_link_exists(url, max_attempts=150, delay=0.1, timeout=5.0):
    """Poll ``url`` until it answers with HTTP 200.

    Args:
        url: URL to GET.
        max_attempts: Maximum number of GET requests to make.
        delay: Seconds to sleep after each failed attempt (non-200 status or
            request error).
        timeout: Per-request timeout in seconds. Without it a server that accepts
            the connection but never replies would block this loop forever.

    Returns:
        True as soon as a request returns status 200, False if every attempt failed.
    """
    for _ in range(max_attempts):
        try:
            response = requests.get(url, timeout=timeout)
        except requests.exceptions.RequestException:
            # Not reachable yet (connection refused, timed out, ...): retry below.
            pass
        else:
            if response.status_code == 200:
                return True
        time.sleep(delay)
    return False
42
+
43
+
44
class Spec:
    """A video specification: output timestamps, a frame renderer and an output format.

    ``render(t, i)`` is called with each timestamp ``t`` of ``domain`` and its index
    ``i`` and must return a frame expression (an object providing ``_sources()``,
    ``_filters()`` and ``_to_json_spec()``, such as ``SourceExpr`` or ``FilterExpr``).
    ``fmt`` must contain the output ``width``, ``height`` and ``pix_fmt``.
    """

    def __init__(self, domain: "list[Fraction]", render, fmt: dict):
        # The annotation is quoted because ``list[...]`` is not subscriptable at
        # runtime on Python 3.8, which this package declares support for.
        self._domain = domain
        self._render = render
        self._fmt = fmt

    def __repr__(self):
        lines = []
        for i, t in enumerate(self._domain):
            frame_expr = self._render(t, i)
            lines.append(f"{t.numerator}/{t.denominator} => {frame_expr}")
        return "\n".join(lines)

    def _sources(self):
        """Return the set of sources referenced by any rendered frame."""
        sources = set()
        for i, t in enumerate(self._domain):
            sources.update(self._render(t, i)._sources())
        return sources

    def _to_json_spec(self):
        """Render every frame of the spec.

        Returns:
            ``(spec, sources, filters)``:
            - ``spec`` is ``{"frames": [...]}`` with one ``[[num, den], frame_json]``
              entry per timestamp.
            - ``sources`` is the set of referenced sources.
            - ``filters`` maps each filter name to its filter object; later frames
              win on duplicate names.
        """
        frames = []
        sources = set()
        filters = {}
        for i, t in enumerate(self._domain):
            frame_expr = self._render(t, i)
            sources.update(frame_expr._sources())
            filters.update(frame_expr._filters())
            frames.append([[t.numerator, t.denominator], frame_expr._to_json_spec()])
        return {"frames": frames}, sources, filters

    @staticmethod
    def _source_json(source):
        """Serialize one source into the dict sent to the server."""
        return {
            "name": source._name,
            "path": source._path,
            "stream": source._stream,
            "service": source._service.as_json() if source._service is not None else None,
        }

    @staticmethod
    def _filter_json(flt):
        """Serialize one filter into the dict sent to the server."""
        return {"filter": flt._func, "args": flt._kwargs}

    def _write_spec(self, spec_pth):
        """Render the spec and write its frames as JSON to ``spec_pth``.

        Returns:
            ``(sources, filters, arrays)``, the JSON-ready payload sent to the
            server alongside the spec file path.
        """
        spec, sources, filters = self._to_json_spec()
        with open(spec_pth, "w") as outfile:
            outfile.write(json.dumps(spec))
        sources_json = [self._source_json(s) for s in sources]
        filters_json = {k: self._filter_json(v) for k, v in filters.items()}
        arrays = []
        return sources_json, filters_json, arrays

    def play(self, server, keep_spec=False):
        """Play the video live in the notebook.

        Registers the spec with ``server`` and returns an IPython ``HTML`` object
        embedding an hls.js player for the resulting stream. The temporary spec
        file is removed afterwards, even if the request fails, unless
        ``keep_spec`` is true.
        """

        from IPython.display import HTML

        spec_pth = f"spec-{str(uuid.uuid4())}.json"
        try:
            sources, filters, arrays = self._write_spec(spec_pth)
            print("Sending to server")
            resp = server._new(spec_pth, sources, filters, arrays, self._fmt)
        finally:
            if not keep_spec and os.path.exists(spec_pth):
                os.remove(spec_pth)

        hls_video_url = resp["stream_url"]
        namespace = resp["namespace"]
        hls_js_url = server.hls_js_url()

        # We add a namespace to the video element to avoid conflicts with other videos
        html_code = f"""
<!DOCTYPE html>
<html>
<head>
    <title>HLS Video Player</title>
    <!-- Include hls.js library -->
    <script src="{hls_js_url}"></script>
</head>
<body>
    <!-- Video element -->
    <video id="video-{namespace}" controls width="640" height="360" autoplay></video>
    <script>
        var video = document.getElementById('video-{namespace}');
        var videoSrc = '{hls_video_url}';
        var hls = new Hls();
        hls.loadSource(videoSrc);
        hls.attachMedia(video);
        hls.on(Hls.Events.MANIFEST_PARSED, function() {{
            video.play();
        }});
    </script>
</body>
</html>
"""
        return HTML(data=html_code)

    def save(self, server, pth, keep_spec=False):
        """Save the video to a file.

        Asks ``server`` to export the rendered video to ``pth`` and returns the
        server's JSON response. The temporary spec file is removed afterwards,
        even if the export fails, unless ``keep_spec`` is true.
        """

        spec_pth = f"spec-{str(uuid.uuid4())}.json"
        try:
            sources, filters, arrays = self._write_spec(spec_pth)
            return server._export(pth, spec_pth, sources, filters, arrays, self._fmt)
        finally:
            if not keep_spec and os.path.exists(spec_pth):
                os.remove(spec_pth)

    def _vrod_bench(self, server):
        """Time spec creation, stream registration and the first HLS segment fetch.

        Writes the spec to ``spec.json`` (left in place) and returns a dict of
        elapsed seconds.
        """
        out = {}
        pth = "spec.json"

        start = time.time()
        sources, filters, arrays = self._write_spec(pth)
        out["vrod_create_spec"] = time.time() - start

        start = time.time()
        resp = server._new(pth, sources, filters, arrays, self._fmt)
        out["vrod_register"] = time.time() - start

        first_segment = resp["stream_url"].replace("stream.m3u8", "segment-0.ts")

        start = time.time()
        r = requests.get(first_segment)
        r.raise_for_status()
        out["vrod_first_segment"] = time.time() - start
        return out

    def _dve2_bench(self, server, out_pth="dve2_bench.mp4"):
        """Time spec creation and a full export to ``out_pth``.

        Writes the spec to ``spec.json`` (left in place) and returns a dict of
        elapsed seconds. ``out_pth`` is passed to ``server._export`` as the output
        path; the previous call omitted it and did not match ``_export``'s
        signature.
        """
        pth = "spec.json"
        out = {}

        start = time.time()
        sources, filters, arrays = self._write_spec(pth)
        out["dve2_create_spec"] = time.time() - start

        start = time.time()
        server._export(out_pth, pth, sources, filters, arrays, self._fmt)
        out["dve2_exec"] = time.time() - start
        return out
252
+
253
+
254
class YrdenServer:
    """A connection to a Yrden server"""

    def __init__(self, domain=None, port=None, bin="vidformer-cli"):
        """Connect to a Yrden server

        Can either connect to an existing server, if domain and port are provided, or start a new server using the provided binary

        Raises:
            RuntimeError: if the server never answers on ``http://domain:port/``.
        """

        self._domain = domain
        self._port = port
        self._proc = None
        if self._port is None:
            assert bin is not None
            self._domain = "localhost"
            # Random port from the IANA dynamic/private range.
            self._port = random.randint(49152, 65535)
            cmd = [bin, "yrden", "--port", str(self._port)]
            if _in_notebook:
                # We need to print the URL in the notebook
                # This is also a trick to get VS Code to forward the port
                cmd += ["--print-url"]
            self._proc = subprocess.Popen(cmd)

        # Explicit check rather than ``assert``: under ``python -O`` an assert (and
        # the readiness poll inside it) would be skipped entirely.
        if not _check_hls_link_exists(self._url("")):
            if self._proc is not None:
                self._proc.kill()
                self._proc.wait()
                self._proc = None
            raise RuntimeError(f"Yrden server at {self._url('')} did not become ready")

    def _url(self, route):
        """Return the full HTTP URL of ``route`` on this server."""
        return f"http://{self._domain}:{self._port}/{route}"

    def _post(self, route, req):
        """POST ``req`` as JSON to ``route``; return the decoded JSON reply.

        Raises a plain ``Exception`` carrying the response body on non-2xx status.
        """
        r = requests.post(self._url(route), json=req)
        if not r.ok:
            raise Exception(r.text)
        return r.json()

    def _source(self, name: str, path: str, stream: int, service):
        """Register a source; returns its metadata with ``ts`` converted to Fractions."""
        resp = self._post(
            "source",
            {
                "name": name,
                "path": path,
                "stream": stream,
                "service": service.as_json() if service is not None else None,
            },
        )
        resp["ts"] = [Fraction(x[0], x[1]) for x in resp["ts"]]
        return resp

    def _new(self, spec, sources, filters, arrays, fmt):
        """Register a spec for live streaming; returns the server's JSON reply."""
        req = {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
        }
        return self._post("new", req)

    def _export(self, pth, spec_pth, sources, filters, arrays, fmt):
        """Ask the server to render ``spec_pth`` to the file ``pth``."""
        req = {
            "spec": spec_pth,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
            "output_path": pth,
        }
        return self._post("export", req)

    def hls_js_url(self):
        """Return the link to the yrden-hosted hls.js file"""
        return self._url("hls.js")

    def __del__(self):
        # getattr: __del__ can run on a partially constructed instance.
        proc = getattr(self, "_proc", None)
        if proc is not None:
            proc.kill()
            # Reap the child so it doesn't linger as a zombie.
            proc.wait()
338
+
339
+
340
class SourceExpr:
    """A frame expression selecting a single frame of a source.

    When ``is_iloc`` is true, ``idx`` is an integer frame position
    (``source.iloc[i]``). Otherwise it is a ``Fraction`` timestamp
    (``source[t]``).
    """

    def __init__(self, source, idx, is_iloc):
        self._source = source
        self._idx = idx
        self._is_iloc = is_iloc

    def __repr__(self):
        # The source stores its name as ``_name``; there is no public ``name``
        # attribute, so reading ``.name`` raised AttributeError.
        if self._is_iloc:
            return f"{self._source._name}.iloc[{self._idx}]"
        return f"{self._source._name}[{self._idx}]"

    def _to_json_spec(self):
        """Serialize as ``{"Source": {"video": name, "index": {"ILoc"|"T": ...}}}``."""
        if self._is_iloc:
            index = {"ILoc": int(self._idx)}
        else:
            index = {"T": [self._idx.numerator, self._idx.denominator]}
        return {"Source": {"video": self._source._name, "index": index}}

    def _sources(self):
        return {self._source}

    def _filters(self):
        # A bare source reference uses no filters.
        return {}
373
+
374
+
375
class SourceILoc:
    """``source.iloc`` accessor: select a source frame by integer position."""

    def __init__(self, source):
        self._source = source

    def __getitem__(self, idx):
        # Exact type test: only built-in ``int`` qualifies (``bool`` does not).
        if type(idx) is int:
            return SourceExpr(self._source, idx, True)
        raise Exception("Source iloc index must be an integer")
383
+
384
+
385
class Source:
    """A video stream registered with a Yrden server.

    Construction immediately registers the stream via ``server._source``. The
    reply is kept in ``_src`` and backs ``fmt()`` and ``ts()``. Index by
    ``Fraction`` timestamp (``source[t]``) or by position (``source.iloc[i]``).
    """

    def __init__(
        self, server: YrdenServer, name: str, path: str, stream: int, service=None
    ):
        self._server = server
        self._name = name
        self._path = path
        self._stream = stream
        self._service = service

        self.iloc = SourceILoc(self)

        # Register with the server and keep its metadata reply.
        self._src = server._source(name, path, stream, service)

    def fmt(self):
        """Return the stream's ``width``, ``height`` and ``pix_fmt``."""
        info = self._src
        return {key: info[key] for key in ("width", "height", "pix_fmt")}

    def ts(self):
        """Return the frame timestamps reported by the server."""
        return self._src["ts"]

    def __getitem__(self, idx):
        if type(idx) is not Fraction:
            raise Exception("Source index must be a Fraction")
        return SourceExpr(self, idx, False)
415
+
416
+
417
class StorageService:
    """Names a storage backend plus its string-valued configuration options."""

    def __init__(self, service: str, **kwargs):
        if type(service) is not str:
            raise Exception("Service name must be a string")
        self._service = service

        # Report the first option (in keyword order) whose value isn't a string.
        offending = next((key for key, val in kwargs.items() if type(val) is not str), None)
        if offending is not None:
            raise Exception(f"Value of {offending} must be a string")
        self._config = kwargs

    def as_json(self):
        """Return ``{"service": name, "config": options}``."""
        return dict(service=self._service, config=self._config)

    def __repr__(self):
        return "{}(config={})".format(self._service, self._config)
432
+
433
+
434
+ def _json_arg(arg):
435
+ if type(arg) == FilterExpr or type(arg) == SourceExpr:
436
+ return {"Frame": arg._to_json_spec()}
437
+ elif type(arg) == int:
438
+ return {"Data": {"Int": arg}}
439
+ elif type(arg) == str:
440
+ return {"Data": {"String": arg}}
441
+ elif type(arg) == bool:
442
+ return {"Data": {"Bool": arg}}
443
+ else:
444
+ assert False
445
+
446
+
447
class Filter:
    """A named video filter; calling it builds a ``FilterExpr``.

    ``tl_func`` names the top-level function that actually implements the filter;
    when omitted, ``name`` is used. Extra keyword arguments configure the filter
    infrastructure (they are not invocation arguments) and must all be strings.
    """

    def __init__(self, name: str, tl_func=None, **kwargs):
        self._name = name
        self._func = name if tl_func is None else tl_func

        for key, value in kwargs.items():
            if type(value) is not str:
                raise Exception(f"Value of {key} must be a string")
        self._kwargs = kwargs

    def __call__(self, *args, **kwargs):
        """Apply the filter to frame expressions / data arguments."""
        return FilterExpr(self, args, kwargs)
465
+
466
+
467
class FilterExpr:
    """An application of a ``Filter`` to positional and keyword arguments.

    Arguments may be frame expressions (``FilterExpr``/``SourceExpr``) or
    plain values.
    """

    def __init__(self, filter: Filter, args, kwargs):
        self._filter = filter
        self._args = args
        self._kwargs = kwargs

    def __repr__(self):
        def show(value):
            # Strings are shown quoted; everything else via str().
            return f'"{value}"' if type(value) == str else str(value)

        parts = [show(a) for a in self._args]
        parts += [f"{k}={show(v)}" for k, v in self._kwargs.items()]
        return f"{self._filter._name}({', '.join(parts)})"

    def _to_json_spec(self):
        """Serialize as ``{"Filter": {"name": ..., "args": [...], "kwargs": {...}}}``."""
        return {
            "Filter": {
                "name": self._filter._name,
                "args": [_json_arg(a) for a in self._args],
                "kwargs": {k: _json_arg(v) for k, v in self._kwargs.items()},
            }
        }

    def _sources(self):
        """Collect the sources referenced by nested frame-expression arguments."""
        found = set()
        for arg in (*self._args, *self._kwargs.values()):
            if type(arg) in (FilterExpr, SourceExpr):
                found = found.union(arg._sources())
        return found

    def _filters(self):
        """Map filter name to filter for this filter and any nested filter calls."""
        found = {self._filter._name: self._filter}
        for arg in (*self._args, *self._kwargs.values()):
            if type(arg) == FilterExpr:
                found.update(arg._filters())
        return found
511
+
512
+
513
class UDF:
    """Base class for user-defined filters served to vidformer over a Unix socket.

    Subclasses implement ``filter`` (returns an rgb24 ``UDFFrame``) and
    ``filter_type`` (returns an rgb24 ``UDFFrameType``). ``into_filter`` starts
    a host process serving the UDF and returns a ``Filter`` that targets it.

    Each request and each response on the socket is a 4-byte big-endian length
    followed by that many bytes of msgpack.
    """

    def __init__(self, name: str):
        self._name = name
        # Set by into_filter(): the socket path and the process serving it.
        self._socket_path = None
        self._p = None

    def filter(self, *args, **kwargs):
        """Compute the output frame; must be overridden by subclasses."""
        raise Exception("User must implement the filter method")

    def filter_type(self, *args, **kwargs):
        """Compute the output frame type; must be overridden by subclasses."""
        raise Exception("User must implement the filter_type method")

    def into_filter(self):
        """Start a host process for this UDF and return an IPC ``Filter`` for it.

        May be called at most once per instance.
        """
        assert self._socket_path is None
        self._socket_path = f"/tmp/vidformer-{self._name}-{str(uuid.uuid4())}.sock"
        self._p = multiprocessing.Process(
            target=_run_udf_host, args=(self, self._socket_path)
        )
        self._p.start()
        return Filter(
            name=self._name, tl_func="IPC", socket=self._socket_path, func=self._name
        )

    @staticmethod
    def _recv_exact(connection, n):
        """Read exactly ``n`` bytes from ``connection``.

        ``recv`` may return fewer bytes than requested, so keep reading until
        ``n`` bytes have arrived. Returns None if the peer closes the
        connection first.
        """
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = connection.recv(remaining)
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    @staticmethod
    def _kwarg_items(f_kwargs):
        # Iterating a dict yields only its keys, so use .items() for a mapping.
        # A sequence of (key, value) pairs is accepted too. NOTE(review): the
        # wire type is expected to be a msgpack map — confirm against the server.
        return f_kwargs.items() if isinstance(f_kwargs, dict) else f_kwargs

    def _run_filter(self, f_args, f_kwargs):
        """Deserialize args, call ``self.filter`` and serialize the resulting frame."""
        args = [self._deser_filter(x) for x in f_args]
        kwargs = {k: self._deser_filter(v) for k, v in self._kwarg_items(f_kwargs)}
        response = self.filter(*args, **kwargs)
        if type(response) != UDFFrame:
            raise Exception(f"filter must return a UDFFrame, got {type(response)}")
        if response.frame_type().pix_fmt() != "rgb24":
            raise Exception(
                f"filter must return a frame with pix_fmt 'rgb24', got {response.frame_type().pix_fmt()}"
            )
        return response._response_ser()

    def _run_filter_type(self, f_args, f_kwargs):
        """Deserialize arg types, call ``self.filter_type`` and serialize the result."""
        args = [self._deser_filter_type(x) for x in f_args]
        kwargs = {
            k: self._deser_filter_type(v) for k, v in self._kwarg_items(f_kwargs)
        }
        response = self.filter_type(*args, **kwargs)
        if type(response) != UDFFrameType:
            raise Exception(
                f"filter_type must return a UDFFrameType, got {type(response)}"
            )
        if response.pix_fmt() != "rgb24":
            raise Exception(
                f"filter_type must return a frame with pix_fmt 'rgb24', got {response.pix_fmt()}"
            )
        return response._response_ser()

    def _handle_connection(self, connection):
        """Serve length-prefixed msgpack requests on ``connection`` until it closes."""
        try:
            while True:
                header = self._recv_exact(connection, 4)
                if header is None:
                    break  # peer closed the connection
                frame_len = int.from_bytes(header, byteorder="big")
                data = self._recv_exact(connection, frame_len)
                if data is None:
                    raise Exception("Partial data received")

                obj = msgpack.unpackb(data, raw=False)
                f_op, f_args, f_kwargs = obj["op"], obj["args"], obj["kwargs"]

                if f_op == "filter":
                    response = self._run_filter(f_args, f_kwargs)
                elif f_op == "filter_type":
                    response = self._run_filter_type(f_args, f_kwargs)
                else:
                    raise Exception(f"Unknown operation: {f_op}")

                payload = msgpack.packb(response, use_bin_type=True)
                connection.sendall(len(payload).to_bytes(4, byteorder="big"))
                connection.sendall(payload)
        finally:
            connection.close()

    def _deser_filter_type(self, obj):
        """Decode a single-key ``{"Frame"|"String"|"Int"|"Bool": value}`` argument type.

        Frames become ``UDFFrameType``; only format 2 (AV_PIX_FMT_RGB24) is accepted.
        """
        assert type(obj) == dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["Frame", "String", "Int", "Bool"]

        if type_key == "Frame":
            frame = obj[type_key]
            assert type(frame) == dict
            assert "width" in frame
            assert "height" in frame
            assert "format" in frame
            assert type(frame["width"]) == int
            assert type(frame["height"]) == int
            assert frame["format"] == 2  # AV_PIX_FMT_RGB24
            return UDFFrameType(frame["width"], frame["height"], "rgb24")
        elif type_key == "String":
            assert type(obj[type_key]) == str
            return obj[type_key]
        elif type_key == "Int":
            assert type(obj[type_key]) == int
            return obj[type_key]
        elif type_key == "Bool":
            assert type(obj[type_key]) == bool
            return obj[type_key]
        else:
            assert False

    def _deser_filter(self, obj):
        """Decode a single-key ``{"Frame"|"String"|"Int"|"Bool": value}`` argument.

        Frames become a ``UDFFrame`` backed by a read-only ``uint8`` array of shape
        ``(height, width, 3)`` built from the rgb24 bytes.
        """
        assert type(obj) == dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["Frame", "String", "Int", "Bool"]

        if type_key == "Frame":
            frame = obj[type_key]
            assert type(frame) == dict
            assert "data" in frame
            assert "width" in frame
            assert "height" in frame
            assert "format" in frame
            assert type(frame["width"]) == int
            assert type(frame["height"]) == int
            assert frame["format"] == "rgb24"
            assert type(frame["data"]) == bytes

            data = np.frombuffer(frame["data"], dtype=np.uint8)
            data = data.reshape(frame["height"], frame["width"], 3)
            return UDFFrame(
                data, UDFFrameType(frame["width"], frame["height"], "rgb24")
            )
        elif type_key == "String":
            assert type(obj[type_key]) == str
            return obj[type_key]
        elif type_key == "Int":
            assert type(obj[type_key]) == int
            return obj[type_key]
        elif type_key == "Bool":
            assert type(obj[type_key]) == bool
            return obj[type_key]
        else:
            assert False

    def _host(self, socket_path: str):
        """Serve this UDF on the Unix socket ``socket_path``, one thread per connection.

        Loops forever; the host process is stopped by ``__del__`` via terminate().
        """
        # bind() fails if the path already exists, so clear any leftover file.
        if os.path.exists(socket_path):
            os.remove(socket_path)

        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.bind(socket_path)
            sock.listen(1)
            while True:
                connection, _ = sock.accept()
                threading.Thread(
                    target=self._handle_connection, args=(connection,)
                ).start()

    def __del__(self):
        # getattr: a subclass may skip UDF.__init__, or __init__ may not have run.
        if getattr(self, "_socket_path", None) is not None:
            self._p.terminate()
            if os.path.exists(self._socket_path):
                # it's possible the process hasn't even created the socket yet
                os.remove(self._socket_path)
687
+
688
+
689
class UDFFrameType:
    """Width, height and pixel format of a frame exchanged with a UDF."""

    def __init__(self, width: int, height: int, pix_fmt: str):
        for value, expected in ((width, int), (height, int), (pix_fmt, str)):
            assert type(value) == expected

        self._width, self._height, self._pix_fmt = width, height, pix_fmt

    def width(self):
        return self._width

    def height(self):
        return self._height

    def pix_fmt(self):
        return self._pix_fmt

    def _response_ser(self):
        """Serialize for the IPC reply; ``format`` 2 is AV_PIX_FMT_RGB24."""
        frame_type = dict(width=self._width, height=self._height, format=2)
        return {"frame_type": frame_type}

    def __repr__(self):
        return "FrameType<{}x{}, {}>".format(self._width, self._height, self._pix_fmt)
719
+
720
+
721
class UDFFrame:
    """An rgb24 frame: a ``uint8`` array of shape ``(height, width, 3)`` plus its type."""

    def __init__(self, data: np.ndarray, f_type: UDFFrameType):
        assert type(data) == np.ndarray
        assert type(f_type) == UDFFrameType

        # Only RGB24 is supported for now: 8-bit samples, three channels.
        assert data.dtype == np.uint8
        assert data.shape[2] == 3

        # The array's dimensions and the declared frame type must agree.
        assert data.shape[:2] == (f_type.height(), f_type.width())
        assert f_type.pix_fmt() == "rgb24"

        self._data, self._f_type = data, f_type

    def data(self):
        return self._data

    def frame_type(self):
        return self._f_type

    def _response_ser(self):
        """Serialize for the IPC reply: raw pixel bytes plus dimensions."""
        ft = self._f_type
        frame = {
            "data": self._data.tobytes(),
            "width": ft.width(),
            "height": ft.height(),
            "format": "rgb24",
        }
        return {"frame": frame}

    def __repr__(self):
        ft = self._f_type
        return f"Frame<{ft.width()}x{ft.height()}, {ft.pix_fmt()}>"
756
+
757
+
758
def _run_udf_host(udf: UDF, socket_path: str):
    """Host-process entry point (the ``Process`` target in ``UDF.into_filter``).

    Serves ``udf`` on the Unix socket ``socket_path``. ``UDF._host`` loops
    forever, so this does not return normally.
    """
    serve = udf._host
    serve(socket_path)