webgpu 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
webgpu/lilgui.py ADDED
@@ -0,0 +1,75 @@
1
+ from typing import Callable
2
+
3
+ from . import platform
4
+
5
+
6
+ class Folder:
7
+ def __init__(self, label: str | None, container, scene):
8
+ self.label = label
9
+ self.container = container
10
+ self.scene = scene
11
+ self.gui = None
12
+
13
+ def folder(self, label: str, closed=False):
14
+ folder = Folder(label, self.container, self.scene)
15
+ folder.gui = self.gui.addFolder(label)
16
+ if closed:
17
+ folder.gui.close()
18
+ return folder
19
+
20
+ def add(self, label: str, value, func: Callable, *args):
21
+ def f(*args):
22
+ func(*args)
23
+ self.scene.render()
24
+
25
+ return self.gui.add({label: value}, label, *args).onChange(
26
+ platform.create_proxy(f)
27
+ )
28
+
29
+ def checkbox(
30
+ self,
31
+ label: str,
32
+ value: bool,
33
+ func: Callable[[bool], None],
34
+ ):
35
+ return self.add(label, value, func)
36
+
37
+ def value(
38
+ self,
39
+ label: str,
40
+ value: object,
41
+ func: Callable[[object], None],
42
+ ):
43
+ return self.add(label, value, func)
44
+
45
+ def dropdown(
46
+ self,
47
+ values: dict[str, object],
48
+ func: Callable[[object], None],
49
+ value: str | None = None,
50
+ label="Dropdown",
51
+ ):
52
+ if value is None:
53
+ value = list(values.keys())[0]
54
+
55
+ return self.add(label, value, func, values)
56
+
57
+ def slider(
58
+ self,
59
+ value: float,
60
+ func: Callable[[float], None],
61
+ min=0.0,
62
+ max=1.0,
63
+ step=None,
64
+ label="Slider",
65
+ ):
66
+ if step is None:
67
+ step = (max - min) / 100
68
+
69
+ return self.add(label, value, func, min, max, step)
70
+
71
+
72
class LilGUI(Folder):
    """Root lil-gui panel, created through ``platform.js.createLilGUI``.

    It has no label of its own; nested folders are created with ``folder``.
    """

    def __init__(self, container, scene):
        super().__init__(label=None, container=container, scene=scene)
        options = {"container": container}
        self.gui = platform.js.createLilGUI(options)
@@ -0,0 +1,3 @@
1
+ from pathlib import Path
2
+
3
# JavaScript half of the link, read from the sibling ``link.js`` file.
# Every "export " substring is removed (a plain text replace), presumably so
# the code can be evaluated outside of an ES module context.
# The file is decoded explicitly as UTF-8 so the result does not depend on
# the platform's locale encoding (e.g. cp1252 on Windows).
js_code = (
    (Path(__file__).parent / "link.js")
    .read_text(encoding="utf-8")
    .replace("export ", "")
)
webgpu/link/base.py ADDED
@@ -0,0 +1,439 @@
1
+ import asyncio
2
+ import base64
3
+ import itertools
4
+ import json
5
+ import threading
6
+ from collections.abc import Mapping
7
+ from typing import Callable
8
+
9
+
10
class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The instance doubles as its own attribute namespace, so ``d.x`` and
        # ``d["x"]`` refer to the same entry.
        self.__dict__ = self
14
+
15
+
16
+ class LinkBase:
17
+ _request_id: itertools.count
18
+ _requests: dict
19
+ _objects: dict
20
+
21
+ _serializers: dict[type, Callable] = {}
22
+
23
+ def _send_data(self, data):
24
+ raise NotImplementedError
25
+
26
+ @staticmethod
27
+ def register_serializer(type_, serializer):
28
+ LinkBase._serializers[type_] = serializer
29
+
30
+ def __init__(self):
31
+ self._request_id = itertools.count()
32
+ self._requests = {}
33
+ self._objects = {}
34
+
35
+ def _call_data(self, id, prop, args, ignore_result=False):
36
+ return {
37
+ "request_id": next(self._request_id) if not ignore_result else None,
38
+ "prop": prop,
39
+ "type": "call",
40
+ "id": id,
41
+ "args": self._dump_data(args),
42
+ }
43
+
44
+ def call_new(self, id=None, prop=None, args=[], ignore_result=False):
45
+ return self._send_data(
46
+ self._call_data(id, prop, args, ignore_result) | {"type": "new"}
47
+ )
48
+
49
+ def call_method(self, id=None, prop=None, args=[], ignore_result=False):
50
+ return self._send_data(self._call_data(id, prop, args, ignore_result))
51
+
52
+ def call_method_ignore_return(self, id=None, prop=None, args=[]):
53
+ return self.call(id, prop, args, ignore_result=True)
54
+
55
+ def call(self, id, args=[], parent_id=None, ignore_result=False):
56
+ return self._send_data(
57
+ self._call_data(id, None, args, ignore_result) | {"parent_id": parent_id}
58
+ )
59
+
60
+ def set_item(self, id, key, value):
61
+ return self._send_data(
62
+ {
63
+ "type": "set",
64
+ "id": id,
65
+ "key": key,
66
+ "value": self._dump_data(value),
67
+ }
68
+ )
69
+
70
+ def set(self, id, prop, value):
71
+ return self._send_data(
72
+ {
73
+ "type": "set",
74
+ "id": id,
75
+ "prop": prop,
76
+ "value": self._dump_data(value),
77
+ }
78
+ )
79
+
80
+ def get_keys(self, id):
81
+ return self._send_data(
82
+ {
83
+ "request_id": next(self._request_id),
84
+ "type": "get_keys",
85
+ "id": id,
86
+ }
87
+ )
88
+
89
+ def get_item(self, id, key):
90
+ return self._send_data(
91
+ {
92
+ "request_id": next(self._request_id),
93
+ "type": "get",
94
+ "id": id,
95
+ "key": key,
96
+ }
97
+ )
98
+
99
+ def get(self, id, prop: str | None = None):
100
+ return self._send_data(
101
+ {
102
+ "request_id": next(self._request_id),
103
+ "type": "get",
104
+ "id": id,
105
+ "prop": prop,
106
+ }
107
+ )
108
+
109
+ def create_handle(self, obj):
110
+ id_ = id(obj)
111
+ self._objects[id_] = obj
112
+ return {"__is_crosslink_type__": True, "type": "proxy", "id": id_}
113
+
114
+ def _dump_data(self, data):
115
+ from .proxy import Proxy
116
+
117
+ type_ = type(data)
118
+ for ser_type in self._serializers:
119
+ if issubclass(type_, ser_type):
120
+ data = self._serializers[ser_type](data)
121
+ break
122
+
123
+ if isinstance(data, (int, float, str, bool, type(None))):
124
+ return data
125
+
126
+ if isinstance(data, (bytes, memoryview)):
127
+ return {
128
+ "__is_crosslink_type__": True,
129
+ "type": "bytes",
130
+ "value": base64.b64encode(data).decode(),
131
+ }
132
+
133
+ if isinstance(data, dict):
134
+ return {k: self._dump_data(data[k]) for k in list(data.keys())}
135
+
136
+ if isinstance(data, Mapping):
137
+ return {k: self._dump_data(data[k]) for k in list(data.keys())}
138
+
139
+ if isinstance(data, (list, tuple)):
140
+ return [self._dump_data(v) for v in data]
141
+
142
+ if isinstance(data, Proxy):
143
+ return {
144
+ "__is_crosslink_type__": True,
145
+ "type": "object",
146
+ "id": data._id,
147
+ "parent_id": data._parent_id,
148
+ }
149
+
150
+ # complex type - store it in objects and only send its id
151
+ # print("complex type", data)
152
+ id_ = id(data)
153
+ self._objects[id_] = data
154
+ return {"__is_crosslink_type__": True, "type": "proxy", "id": id_}
155
+
156
+ def _load_data(self, data):
157
+ """Parse the result of a message from the remote environment"""
158
+ from .proxy import Proxy
159
+
160
+ # print("load data", data, type(data))
161
+
162
+ if isinstance(data, list):
163
+ return [self._load_data(v) for v in data]
164
+
165
+ if not isinstance(data, dict):
166
+ return data
167
+
168
+ if not data.get("__is_crosslink_type__", False):
169
+ return AttrDict({k: self._load_data(v) for k, v in data.items()})
170
+
171
+ if data["type"] == "object":
172
+ return self._objects[data["id"]]
173
+
174
+ if data["type"] == "proxy":
175
+
176
+ return Proxy(self, data.get("parent_id", None), data.get("id", None))
177
+
178
+ if data["type"] == "bytes":
179
+ return base64.b64decode(data["value"])
180
+
181
+ raise Exception(f"Unknown result type: {data}")
182
+
183
+ def expose(self, name: str, obj):
184
+ self._objects[str(name)] = obj
185
+
186
+ def create_proxy(self, func, ignore_return_value=False):
187
+ raise NotImplementedError
188
+
189
+ def destroy_proxy(self, proxy):
190
+ del self._objects[proxy["id"]]
191
+
192
+ def _send_response(self, request_id, data):
193
+ if type(data) is bytes:
194
+ data = request_id.to_bytes(4, "big") + data
195
+ else:
196
+ data = {
197
+ "request_id": request_id,
198
+ "type": "response",
199
+ "value": self._dump_data(data),
200
+ }
201
+
202
+ self._send_data(data)
203
+
204
+ def _get_obj(self, data):
205
+ obj = self._objects
206
+ id_ = data.get("id", None)
207
+ prop = data.get("prop", None)
208
+ key = data.get("key", None)
209
+
210
+ if id_ is not None:
211
+ obj = obj[data["id"]]
212
+ if prop is not None:
213
+ obj = obj.__getattribute__(prop)
214
+ if key is not None:
215
+ obj = obj[data["key"]]
216
+ return obj
217
+
218
+ async def _on_message_async(self, message: str):
219
+ data = json.loads(message)
220
+ try:
221
+ msg_type = data.get("type", None)
222
+ request_id = data.get("request_id", None)
223
+
224
+ response = None
225
+
226
+ match msg_type:
227
+ case "response":
228
+ event = self._requests[request_id]
229
+ self._requests[request_id] = self._load_data(
230
+ data.get("value", None)
231
+ )
232
+ event.set()
233
+ return
234
+
235
+ case "call":
236
+ func = self._get_obj(data)
237
+ args = self._load_data(data["args"])
238
+ response = func(*args)
239
+ try:
240
+ response = await response
241
+ except TypeError:
242
+ pass
243
+ except Exception as e:
244
+ print("error in call", type(e), str(e))
245
+
246
+ case "get":
247
+ response = self._get_obj(data)
248
+
249
+ case "get_keys":
250
+ response = []
251
+
252
+ case "set":
253
+ prop = data.pop("prop", None)
254
+ key = data.pop("key", None)
255
+ obj = self._get_obj(data)
256
+ if prop is not None:
257
+ obj.__setattr__(prop, data["value"])
258
+ elif key is not None:
259
+ obj[key] = self._load_data(data["value"])
260
+
261
+ case _:
262
+ print("unknown message type", msg_type)
263
+
264
+ if request_id is not None:
265
+ self._send_response(request_id, response)
266
+ except Exception as e:
267
+ from webapp_client.utils import print_exception
268
+
269
+ print("error in on_message", data, type(e), str(e))
270
+ if "id" in data:
271
+ print("id", data["id"], self._objects[data["id"]])
272
+ print_exception(e)
273
+
274
+ def _on_message(self, message: str):
275
+ data = json.loads(message)
276
+ try:
277
+ msg_type = data.get("type", None)
278
+ request_id = data.get("request_id", None)
279
+
280
+ response = None
281
+
282
+ match msg_type:
283
+ case "response":
284
+ event = self._requests[request_id]
285
+ self._requests[request_id] = self._load_data(
286
+ data.get("value", None)
287
+ )
288
+ event.set()
289
+ return
290
+
291
+ case "call":
292
+ func = self._get_obj(data)
293
+ args = self._load_data(data["args"])
294
+ # print("call", func, args)
295
+ response = func(*args)
296
+
297
+ case "get":
298
+ response = self._get_obj(data)
299
+
300
+ case "get_keys":
301
+ response = []
302
+
303
+ case "set":
304
+ prop = data.pop("prop", None)
305
+ key = data.pop("key", None)
306
+ obj = self._get_obj(data)
307
+ if prop is not None:
308
+ obj.__setattr__(prop, data["value"])
309
+ elif key is not None:
310
+ obj[key] = self._load_data(data["value"])
311
+
312
+ case _:
313
+ print("unknown message type", msg_type)
314
+
315
+ if request_id is not None:
316
+ self._send_response(request_id, response)
317
+ except Exception as e:
318
+ from webapp_client.utils import print_exception
319
+
320
+ print("error in on_message", data, type(e), str(e))
321
+ print_exception(e)
322
+
323
+
324
class PyodideLink(LinkBase):
    """Link for a Pyodide environment that blocks on shared buffers for replies.

    Outgoing messages go through ``send_message``. For a request that expects
    a reply, ``size_buffer[0]`` is reset to 0 and the thread waits on it with
    ``js.Atomics.wait``. The reply is then read as ``size_buffer[0]`` bytes of
    UTF-8 JSON from ``result_buffer``.
    """

    # Maximum time (milliseconds) to block in Atomics.wait for a reply.
    _reply_timeout_ms = 10000

    def __init__(self, send_message, size_buffer, result_buffer):
        super().__init__()
        self._send_message = send_message
        self._size_buffer = size_buffer
        self._result_buffer = result_buffer

    def create_proxy(self, func, ignore_return_value=False):
        """Register ``func`` locally and return a proxy descriptor for it."""
        id_ = id(func)
        self._objects[id_] = func
        return {
            "__is_crosslink_type__": True,
            "type": "proxy",
            "id": id_,
            "ignore_return_value": ignore_return_value,
        }

    def _send_data(self, data):
        """Send ``data``; for requests expecting a reply, block and return it.

        Raises:
            TimeoutError: if no reply arrives within ``_reply_timeout_ms``.
        """
        if type(data) is bytes:
            self._send_message(data)
            return None

        expects_reply = (
            data.get("request_id", None) is not None
            and data["type"] != "response"
            and not data.get("ignore_return_value", False)
        )
        if not expects_reply:
            self._send_message(json.dumps(data))
            return None

        import js

        js.Atomics.store(self._size_buffer, 0, 0)
        self._send_message(json.dumps(data))
        status = js.Atomics.wait(self._size_buffer, 0, 0, self._reply_timeout_ms)
        if status == "timed-out":
            # Without this check an empty buffer would be parsed and fail
            # with an unhelpful JSONDecodeError.
            raise TimeoutError(
                f"no reply to request {data['request_id']} "
                f"within {self._reply_timeout_ms} ms"
            )
        n = self._size_buffer[0]
        reply = bytes(self._result_buffer.slice(0, n)).decode("utf-8")
        return self._load_data(json.loads(reply).get("value", None))
362
+
363
+
364
class LinkBaseAsync(LinkBase):
    """Link that sends through an event loop and runs callbacks on a worker thread.

    Subclasses implement ``_send_async`` (it receives a JSON ``str`` or, for
    binary responses, ``bytes``) and ``wait_for_connection``.
    """

    # Loop that executes _send_async. NOTE(review): nothing in this class runs
    # it; presumably a subclass does.
    _send_loop: asyncio.AbstractEventLoop
    _callback_loop: asyncio.AbstractEventLoop  # run forever by _callback_thread
    _callback_queue: asyncio.Queue  # (func, args) pairs awaiting execution
    _callback_thread: threading.Thread

    def __init__(self):
        super().__init__()
        self._send_loop = asyncio.new_event_loop()
        # The callback thread runs exactly this loop (it used to create and
        # swap in a second one, losing callbacks scheduled before the swap).
        # Scheduling onto it before the thread starts is fine: the work is
        # queued until run_forever() begins.
        self._callback_loop = asyncio.new_event_loop()
        self._callback_queue = asyncio.Queue()

        self._callback_thread = threading.Thread(
            target=self._start_callback_thread, daemon=True
        )
        self._callback_thread.start()

    def wait_for_connection(self):
        raise NotImplementedError

    def create_proxy(self, func, ignore_return_value=False):
        """Return a proxy descriptor whose calls are queued for the callback thread."""

        def wrapper(*args):
            asyncio.run_coroutine_threadsafe(
                self._callback_queue.put((func, args)), self._callback_loop
            )

        id_ = id(wrapper)
        self._objects[id_] = wrapper
        return {
            "__is_crosslink_type__": True,
            "type": "proxy",
            "id": id_,
            "ignore_return_value": ignore_return_value,
        }

    def _send_data(self, data):
        """Send data to the remote environment.

        If a (non-response) message carries a request_id, block until the
        response arrives and return it. Raw ``bytes`` (binary responses from
        ``_send_response``) are forwarded unchanged.
        """
        if type(data) is bytes:
            asyncio.run_coroutine_threadsafe(self._send_async(data), self._send_loop)
            return None

        request_id = data.get("request_id", None)
        msg_type = data.get("type", None)
        message = json.dumps(data)
        event = None
        if msg_type != "response" and request_id is not None:
            # _on_message replaces this entry with the result, then sets the event.
            event = threading.Event()
            self._requests[request_id] = event

        asyncio.run_coroutine_threadsafe(self._send_async(message), self._send_loop)
        if event is not None:
            event.wait()
            return self._requests.pop(request_id)
        return None

    async def _send_async(self, message):
        raise NotImplementedError

    def _start_callback_thread(self):
        """Thread target: run queued callbacks on ``_callback_loop`` forever."""

        async def handle_callbacks():
            while True:
                try:
                    func, args = await self._callback_queue.get()
                    func(*args)
                except Exception as e:
                    # A failing callback must not stop the dispatcher.
                    print("error in callback", type(e), str(e))

        try:
            asyncio.set_event_loop(self._callback_loop)
            self._callback_loop.create_task(handle_callbacks())
            self._callback_loop.run_forever()
        except Exception as e:
            print("exception in _start_callback_thread", e)