webgpu 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
webgpu/lilgui.py ADDED
@@ -0,0 +1,73 @@
1
+ from typing import Callable
2
+
3
+ from . import platform
4
+
5
+
6
+ class Folder:
7
+ def __init__(self, label: str | None, container, scene):
8
+ self.label = label
9
+ self.container = container
10
+ self.scene = scene
11
+ self.gui = None
12
+
13
+ def folder(self, label: str, closed=False):
14
+ folder = Folder(label, self.container, self.scene)
15
+ folder.gui = self.gui.addFolder(label)
16
+ if closed:
17
+ folder.gui.close()
18
+ return folder
19
+
20
+ def add(self, label: str, value, func: Callable, *args):
21
+ def f(*args):
22
+ func(*args)
23
+ self.scene.render()
24
+
25
+ return self.gui.add({label: value}, label, *args).onChange(platform.create_proxy(f))
26
+
27
+ def checkbox(
28
+ self,
29
+ label: str,
30
+ value: bool,
31
+ func: Callable[[bool], None],
32
+ ):
33
+ return self.add(label, value, func)
34
+
35
+ def value(
36
+ self,
37
+ label: str,
38
+ value: object,
39
+ func: Callable[[object], None],
40
+ ):
41
+ return self.add(label, value, func)
42
+
43
+ def dropdown(
44
+ self,
45
+ values: dict[str, object],
46
+ func: Callable[[object], None],
47
+ value: str | None = None,
48
+ label="Dropdown",
49
+ ):
50
+ if value is None:
51
+ value = list(values.keys())[0]
52
+
53
+ return self.add(label, value, func, values)
54
+
55
+ def slider(
56
+ self,
57
+ value: float,
58
+ func: Callable[[float], None],
59
+ min=0.0,
60
+ max=1.0,
61
+ step=None,
62
+ label="Slider",
63
+ ):
64
+ if step is None:
65
+ step = (max - min) / 100
66
+
67
+ return self.add(label, value, func, min, max, step)
68
+
69
+
70
class LilGUI(Folder):
    """Root lil-gui panel created inside *container*; its controls re-render *scene*."""

    def __init__(self, container, scene):
        # The root panel has no folder title of its own.
        super().__init__(label=None, container=container, scene=scene)
        options = {"container": container}
        # Panel constructed by a JS helper reached through platform.js.
        self.gui = platform.js.createLilGUI(options)
@@ -0,0 +1,3 @@
1
+ from pathlib import Path
2
+
3
# JavaScript counterpart of the link, loaded from link.js next to this module,
# with every "export " substring removed.
# The file is decoded as UTF-8 explicitly so the result does not depend on the
# platform's locale encoding (e.g. cp1252 on Windows).
# NOTE(review): the replace is purely textual, so "export " inside strings or
# comments of link.js is stripped as well.
js_code = (Path(__file__).parent / "link.js").read_text(encoding="utf-8").replace("export ", "")
webgpu/link/base.py ADDED
@@ -0,0 +1,431 @@
1
+ import asyncio
2
+ import base64
3
+ import itertools
4
+ import json
5
+ import threading
6
+ from collections.abc import Mapping
7
+ from typing import Callable
8
+
9
+
10
class AttrDict(dict):
    """A ``dict`` whose items are also reachable as attributes (``d.x`` == ``d["x"]``)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the instance namespace at the dict itself so attribute access and
        # item access share one storage.
        self.__dict__ = self
14
+
15
+
16
+ class LinkBase:
17
+ _request_id: itertools.count
18
+ _requests: dict
19
+ _objects: dict
20
+
21
+ _serializers: dict[type, Callable] = {}
22
+
23
+ def _send_data(self, data):
24
+ raise NotImplementedError
25
+
26
+ @staticmethod
27
+ def register_serializer(type_, serializer):
28
+ LinkBase._serializers[type_] = serializer
29
+
30
+ def __init__(self):
31
+ self._request_id = itertools.count()
32
+ self._requests = {}
33
+ self._objects = {}
34
+
35
+ def _call_data(self, id, prop, args, ignore_result=False):
36
+ return {
37
+ "request_id": next(self._request_id) if not ignore_result else None,
38
+ "prop": prop,
39
+ "type": "call",
40
+ "id": id,
41
+ "args": self._dump_data(args),
42
+ }
43
+
44
+ def call_new(self, id=None, prop=None, args=[], ignore_result=False):
45
+ return self._send_data(self._call_data(id, prop, args, ignore_result) | {"type": "new"})
46
+
47
+ def call_method(self, id=None, prop=None, args=[], ignore_result=False):
48
+ return self._send_data(self._call_data(id, prop, args, ignore_result))
49
+
50
+ def call_method_ignore_return(self, id=None, prop=None, args=[]):
51
+ return self.call(id, prop, args, ignore_result=True)
52
+
53
+ def call(self, id, args=[], parent_id=None, ignore_result=False):
54
+ return self._send_data(
55
+ self._call_data(id, None, args, ignore_result) | {"parent_id": parent_id}
56
+ )
57
+
58
+ def set_item(self, id, key, value):
59
+ return self._send_data(
60
+ {
61
+ "type": "set",
62
+ "id": id,
63
+ "key": key,
64
+ "value": self._dump_data(value),
65
+ }
66
+ )
67
+
68
+ def set(self, id, prop, value):
69
+ return self._send_data(
70
+ {
71
+ "type": "set",
72
+ "id": id,
73
+ "prop": prop,
74
+ "value": self._dump_data(value),
75
+ }
76
+ )
77
+
78
+ def get_keys(self, id):
79
+ return self._send_data(
80
+ {
81
+ "request_id": next(self._request_id),
82
+ "type": "get_keys",
83
+ "id": id,
84
+ }
85
+ )
86
+
87
+ def get_item(self, id, key):
88
+ return self._send_data(
89
+ {
90
+ "request_id": next(self._request_id),
91
+ "type": "get",
92
+ "id": id,
93
+ "key": key,
94
+ }
95
+ )
96
+
97
+ def get(self, id, prop: str | None = None):
98
+ return self._send_data(
99
+ {
100
+ "request_id": next(self._request_id),
101
+ "type": "get",
102
+ "id": id,
103
+ "prop": prop,
104
+ }
105
+ )
106
+
107
+ def create_handle(self, obj):
108
+ id_ = id(obj)
109
+ self._objects[id_] = obj
110
+ return {"__is_crosslink_type__": True, "type": "proxy", "id": id_}
111
+
112
+ def _dump_data(self, data):
113
+ from .proxy import Proxy
114
+
115
+ type_ = type(data)
116
+ for ser_type in self._serializers:
117
+ if issubclass(type_, ser_type):
118
+ data = self._serializers[ser_type](data)
119
+ break
120
+
121
+ if isinstance(data, (int, float, str, bool, type(None))):
122
+ return data
123
+
124
+ if isinstance(data, (bytes, memoryview)):
125
+ return {
126
+ "__is_crosslink_type__": True,
127
+ "type": "bytes",
128
+ "value": base64.b64encode(data).decode(),
129
+ }
130
+
131
+ if isinstance(data, dict):
132
+ return {k: self._dump_data(data[k]) for k in list(data.keys())}
133
+
134
+ if isinstance(data, Mapping):
135
+ return {k: self._dump_data(data[k]) for k in list(data.keys())}
136
+
137
+ if isinstance(data, (list, tuple)):
138
+ return [self._dump_data(v) for v in data]
139
+
140
+ if isinstance(data, Proxy):
141
+ return {
142
+ "__is_crosslink_type__": True,
143
+ "type": "object",
144
+ "id": data._id,
145
+ "parent_id": data._parent_id,
146
+ }
147
+
148
+ # complex type - store it in objects and only send its id
149
+ # print("complex type", data)
150
+ id_ = id(data)
151
+ self._objects[id_] = data
152
+ return {"__is_crosslink_type__": True, "type": "proxy", "id": id_}
153
+
154
+ def _load_data(self, data):
155
+ """Parse the result of a message from the remote environment"""
156
+ from .proxy import Proxy
157
+
158
+ # print("load data", data, type(data))
159
+
160
+ if isinstance(data, list):
161
+ return [self._load_data(v) for v in data]
162
+
163
+ if not isinstance(data, dict):
164
+ return data
165
+
166
+ if not data.get("__is_crosslink_type__", False):
167
+ return AttrDict({k: self._load_data(v) for k, v in data.items()})
168
+
169
+ if data["type"] == "object":
170
+ return self._objects[data["id"]]
171
+
172
+ if data["type"] == "proxy":
173
+
174
+ return Proxy(self, data.get("parent_id", None), data.get("id", None))
175
+
176
+ if data["type"] == "bytes":
177
+ return base64.b64decode(data["value"])
178
+
179
+ raise Exception(f"Unknown result type: {data}")
180
+
181
+ def expose(self, name: str, obj):
182
+ self._objects[str(name)] = obj
183
+
184
+ def create_proxy(self, func, ignore_return_value=False):
185
+ raise NotImplementedError
186
+
187
+ def destroy_proxy(self, proxy):
188
+ del self._objects[proxy["id"]]
189
+
190
+ def _send_response(self, request_id, data):
191
+ if type(data) is bytes:
192
+ data = request_id.to_bytes(4, "big") + data
193
+ else:
194
+ data = {
195
+ "request_id": request_id,
196
+ "type": "response",
197
+ "value": self._dump_data(data),
198
+ }
199
+
200
+ self._send_data(data)
201
+
202
+ def _get_obj(self, data):
203
+ obj = self._objects
204
+ id_ = data.get("id", None)
205
+ prop = data.get("prop", None)
206
+ key = data.get("key", None)
207
+
208
+ if id_ is not None:
209
+ obj = obj[data["id"]]
210
+ if prop is not None:
211
+ obj = obj.__getattribute__(prop)
212
+ if key is not None:
213
+ obj = obj[data["key"]]
214
+ return obj
215
+
216
+ async def _on_message_async(self, message: str):
217
+ data = json.loads(message)
218
+ try:
219
+ msg_type = data.get("type", None)
220
+ request_id = data.get("request_id", None)
221
+
222
+ response = None
223
+
224
+ match msg_type:
225
+ case "response":
226
+ event = self._requests[request_id]
227
+ self._requests[request_id] = self._load_data(data.get("value", None))
228
+ event.set()
229
+ return
230
+
231
+ case "call":
232
+ func = self._get_obj(data)
233
+ args = self._load_data(data["args"])
234
+ response = func(*args)
235
+ try:
236
+ response = await response
237
+ except TypeError:
238
+ pass
239
+ except Exception as e:
240
+ print("error in call", type(e), str(e))
241
+
242
+ case "get":
243
+ response = self._get_obj(data)
244
+
245
+ case "get_keys":
246
+ response = []
247
+
248
+ case "set":
249
+ prop = data.pop("prop", None)
250
+ key = data.pop("key", None)
251
+ obj = self._get_obj(data)
252
+ if prop is not None:
253
+ obj.__setattr__(prop, data["value"])
254
+ elif key is not None:
255
+ obj[key] = self._load_data(data["value"])
256
+
257
+ case _:
258
+ print("unknown message type", msg_type)
259
+
260
+ if request_id is not None:
261
+ self._send_response(request_id, response)
262
+ except Exception as e:
263
+ from webapp_client.utils import print_exception
264
+
265
+ print("error in on_message", data, type(e), str(e))
266
+ if "id" in data and data["id"] in self._objects:
267
+ print("object", data["id"], self._objects[data["id"]])
268
+ print_exception(e)
269
+
270
+ def _on_message(self, message: str):
271
+ data = json.loads(message)
272
+ try:
273
+ msg_type = data.get("type", None)
274
+ request_id = data.get("request_id", None)
275
+
276
+ response = None
277
+
278
+ match msg_type:
279
+ case "response":
280
+ event = self._requests[request_id]
281
+ self._requests[request_id] = self._load_data(data.get("value", None))
282
+ event.set()
283
+ return
284
+
285
+ case "call":
286
+ func = self._get_obj(data)
287
+ args = self._load_data(data["args"])
288
+ # print("call", func, args)
289
+ response = func(*args)
290
+
291
+ case "get":
292
+ response = self._get_obj(data)
293
+
294
+ case "get_keys":
295
+ response = []
296
+
297
+ case "set":
298
+ prop = data.pop("prop", None)
299
+ key = data.pop("key", None)
300
+ obj = self._get_obj(data)
301
+ if prop is not None:
302
+ obj.__setattr__(prop, data["value"])
303
+ elif key is not None:
304
+ obj[key] = self._load_data(data["value"])
305
+
306
+ case _:
307
+ print("unknown message type", msg_type)
308
+
309
+ if request_id is not None:
310
+ self._send_response(request_id, response)
311
+ except Exception as e:
312
+ from webapp_client.utils import print_exception
313
+
314
+ print("error in on_message", data, type(e), str(e))
315
+ print_exception(e)
316
+
317
+
318
class PyodideLink(LinkBase):
    """Link used from Pyodide, where requests block synchronously via ``js.Atomics``.

    ``send_message`` posts a message to the other side. For a request that
    expects a result, the reply's byte length is read from ``size_buffer[0]``
    and its UTF-8 JSON from ``result_buffer`` — presumably both are views over a
    SharedArrayBuffer written by the other side (required by Atomics.wait;
    confirm with the JS counterpart).
    """

    # Milliseconds to block in Atomics.wait for a reply before giving up.
    response_timeout_ms = 10000

    def __init__(self, send_message, size_buffer, result_buffer):
        super().__init__()
        self._send_message = send_message
        self._size_buffer = size_buffer
        self._result_buffer = result_buffer

    def create_proxy(self, func, ignore_return_value=False):
        """Expose *func* to the remote side and return its proxy handle."""
        id_ = id(func)
        self._objects[id_] = func
        return {
            "__is_crosslink_type__": True,
            "type": "proxy",
            "id": id_,
            "ignore_return_value": ignore_return_value,
        }

    def _send_data(self, data):
        """Send *data*; for a request expecting a result, block for the reply and return it.

        Raises:
            TimeoutError: if no reply was written within ``response_timeout_ms``.
        """
        if type(data) is bytes:
            self._send_message(data)
            return None

        expects_reply = (
            data.get("request_id", None) is not None
            and data["type"] != "response"
            and not data.get("ignore_return_value", False)
        )
        if not expects_reply:
            self._send_message(json.dumps(data))
            return None

        import js

        # Clear the size slot first so the wait below blocks until the reply arrives.
        js.Atomics.store(self._size_buffer, 0, 0)
        self._send_message(json.dumps(data))
        js.Atomics.wait(self._size_buffer, 0, 0, self.response_timeout_ms)
        n = self._size_buffer[0]
        if n == 0:
            # A reply is JSON and therefore never empty; 0 means nothing was written
            # (previously this surfaced as an opaque JSONDecodeError on "").
            raise TimeoutError(
                f"no response to request {data['request_id']} "
                f"within {self.response_timeout_ms} ms"
            )
        res = bytes(self._result_buffer.slice(0, n))
        reply = json.loads(res.decode("utf-8"))
        return self._load_data(reply.get("value", None))
356
+
357
+
358
class LinkBaseAsync(LinkBase):
    """Link base for transports driven by asyncio event loops in other threads.

    Outgoing messages are scheduled on ``_send_loop`` through ``_send_async``
    (running that loop is presumably the subclass's job — not visible here).
    Callbacks wrapped by ``create_proxy`` are queued and executed, one at a time,
    on a daemon thread that runs ``_callback_loop``.
    """

    _send_loop: asyncio.AbstractEventLoop
    _callback_loop: asyncio.AbstractEventLoop
    _callback_queue: asyncio.Queue
    _callback_thread: threading.Thread

    def __init__(self):
        super().__init__()
        self._send_loop = asyncio.new_event_loop()
        # Created once here and only *run* by the callback thread, so
        # create_proxy wrappers can never schedule onto a loop that is later
        # replaced (and no unused loop is leaked).
        self._callback_loop = asyncio.new_event_loop()
        self._callback_queue = asyncio.Queue()

        self._callback_thread = threading.Thread(target=self._start_callback_thread, daemon=True)
        self._callback_thread.start()

    def wait_for_connection(self):
        """Block until the transport is connected; implemented by subclasses."""
        raise NotImplementedError

    def create_proxy(self, func, ignore_return_value=False):
        """Return a handle whose remote invocations enqueue ``func(*args)`` on the callback thread.

        The wrapper does not wait for *func*, so the remote caller always sees None.
        """

        def wrapper(*args):
            asyncio.run_coroutine_threadsafe(
                self._callback_queue.put((func, args)), self._callback_loop
            )

        id_ = id(wrapper)
        self._objects[id_] = wrapper
        return {
            "__is_crosslink_type__": True,
            "type": "proxy",
            "id": id_,
            "ignore_return_value": ignore_return_value,
        }

    def _send_data(self, data):
        """Send data to the remote environment.

        If *data* is a request carrying a request_id, block until the matching
        response has been stored in ``_requests`` and return it.
        """
        if type(data) is bytes:
            # Binary replies built by _send_response have no JSON envelope.
            # NOTE(review): assumes the subclass's _send_async accepts bytes.
            asyncio.run_coroutine_threadsafe(self._send_async(data), self._send_loop)
            return None

        request_id = data.get("request_id", None)
        msg_type = data.get("type", None)
        message = json.dumps(data)
        event = None
        if msg_type != "response" and request_id is not None:
            event = threading.Event()
            self._requests[request_id] = event

        asyncio.run_coroutine_threadsafe(self._send_async(message), self._send_loop)
        if event is None:
            return None
        event.wait()
        return self._requests.pop(request_id)

    async def _send_async(self, message):
        """Deliver one serialized *message*; implemented by subclasses."""
        raise NotImplementedError

    def _start_callback_thread(self):
        """Thread target: run ``_callback_loop`` forever, executing queued callbacks in order."""

        async def handle_callbacks():
            while True:
                try:
                    func, args = await self._callback_queue.get()
                    func(*args)
                except Exception as e:
                    print("error in callback", type(e), str(e))

        try:
            asyncio.set_event_loop(self._callback_loop)
            # Keep a reference: the event loop holds tasks only weakly.
            handler_task = self._callback_loop.create_task(handle_callbacks())
            self._callback_loop.run_forever()
            del handler_task
        except Exception as e:
            print("exception in _start_callback_thread", e)