webgpu 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
webgpu/main.py ADDED
@@ -0,0 +1,177 @@
1
+ """Main file for the webgpu example, creates a small 2d mesh and renders it using WebGPU"""
2
+
3
+ import urllib.parse
4
+
5
+ import js
6
+ import ngsolve as ngs
7
+ from netgen.occ import unit_cube
8
+ from netgen.geom2d import unit_square
9
+ from pyodide.ffi import create_proxy
10
+
11
+ from .gpu import WebGPU, init_webgpu
12
+ from .mesh import *
13
+ from .lic import LineIntegralConvolutionRenderObject
14
+
15
+ from .webgpu_api import BackendType, RenderPassEncoder, ColorTargetState
16
+
17
+ # s = ColorTargetState()
18
+ # print("BackendType", BackendType)
19
+ # print("ColorTargetState", s)
20
+
21
+
22
def f(encoder: RenderPassEncoder):
    """No-op placeholder that accepts a render pass encoder and returns ``None``."""
24
+
25
+
26
# Module-level state shared between main(), cleanup() and reload().
gpu: WebGPU = None  # assigned by main() from init_webgpu(); removed again by cleanup()
mesh_object: RenderObject = None  # optional mesh renderer; main() currently never assigns it
elements_object = None  # NOTE: main() uses a local of the same name, so this global stays None
point_number_object = None  # NOTE: likewise shadowed by a local in main()

cf = None  # function to visualize; main() substitutes a demo function when it is unset
render_function = None  # Pyodide proxy of main()'s per-frame render callback
33
+
34
+
35
async def main():
    """Build a demo mesh, create its render objects and start the render loop.

    Sets the module globals ``gpu``, ``mesh_object``, ``cf`` and
    ``render_function``. Frames are requested with ``js.requestAnimationFrame``;
    the loop stops re-requesting after 20 frames. If the page provides
    ``js.gui`` and ``js.window.folder``, a "shrink" slider is added whose
    changes trigger a re-render.
    """
    # Only `create_proxy` is imported from pyodide at module level; bring the
    # package itself into scope for `pyodide.ffi.to_js` below.
    import pyodide.ffi

    global gpu, mesh_object, cf, render_function

    gpu = await init_webgpu(js.document.getElementById("canvas"))
    # print("DEVICE", dir(gpu.native_device))

    point_number_object = None

    if 1:
        from ngsolve.meshes import MakeStructured3DMesh

        # create new ngsolve mesh and evaluate arbitrary function on it
        # mesh = ngs.Mesh(unit_cube.GenerateMesh(maxh=0.2))
        # mesh = MakeStructured3DMesh(True, 10, prism=True)

        import netgen.occ as occ
        from netgen.meshing import IdentificationType

        # Unit cube glued inside a copy of itself scaled by 1.1 about its center.
        idtype = IdentificationType.CLOSESURFACES
        inner = occ.Box((0, 0, 0), (1, 1, 1))
        trafo = occ.gp_Trsf().Scale(inner.center, 1.1)
        outer = trafo(inner)

        # inner.Identify(outer, "", idtype, trafo)
        shape = occ.Glue([outer - inner, inner])

        geo = occ.OCCGeometry(shape)
        mesh = geo.GenerateMesh(maxh=0.3)

        # mesh = unit_square.GenerateMesh(maxh=0.3)
        mesh = ngs.Mesh(mesh)

        order = 3
        # Explicit None check: only fall back to the demo function when no
        # function was provided (avoids relying on the truthiness of `cf`).
        if cf is None:
            cf = ngs.sin(10 * ngs.x) * ngs.sin(10 * ngs.y)
        # cf = ngs.x
        data = MeshData(mesh, cf, order)
        gpu.u_function.min = -1
        gpu.u_function.max = 1
    else:
        # use compute shader to create a unit_square mesh
        # but has always P1 and 'x' hard-coded as function
        query = urllib.parse.parse_qs(js.location.search[1:])
        N = 10
        N = int(query.get("n", [N])[0])
        data = create_testing_square_mesh(gpu, N)
        gpu.u_function.min = 0
        gpu.u_function.max = 1

    # lic = LineIntegralConvolutionRenderObject(gpu, 1000, 800)
    # print("LIC", lic)

    # mesh_object = MeshRenderObject(gpu, data)
    # mesh_object = MeshRenderObjectIndexed(gpu, data) # function values are wrong, due to ngsolve vertex numbering order
    # mesh_object = MeshRenderObjectDeferred(
    #     gpu, data
    # )  # function values are wrong, due to ngsolve vertex numbering order
    point_number_object = PointNumbersRenderObject(gpu, data, font_size=16)
    elements_object = Mesh3dElementsRenderObject(gpu, data)

    t_last = 0
    fps = 0
    frame_counter = 0
    # to_js converts a dict to a JS Map by default. The GUI (`folder.add(params,
    # "shrink", ...)`) and the `params.shrink` reads below need a plain JS object
    # with a `shrink` property, so build it with Object.fromEntries.
    params = pyodide.ffi.to_js({"shrink": 0.3}, dict_converter=js.Object.fromEntries)

    def render(time):
        """Per-frame callback: update uniforms, record all render objects, submit."""
        # Called by requestAnimationFrame with a timestamp in ms; other callers
        # (e.g. GUI change events) pass something else, which is treated as 0.
        if not isinstance(time, float):
            time = 0

        nonlocal t_last, fps, frame_counter
        print("params", params.shrink)
        dt = time - t_last
        t_last = time
        frame_counter += 1
        print(f"frame time {dt:.2f} ms")

        gpu.u_mesh.shrink = params.shrink

        # copy camera position etc. to GPU
        gpu.update_uniforms()

        command_encoder = gpu.device.createCommandEncoder()

        if mesh_object is not None:
            mesh_object.render(command_encoder)

        if elements_object is not None:
            elements_object.render(command_encoder)

        if point_number_object is not None:
            point_number_object.render(command_encoder)

        gpu.device.queue.submit([command_encoder.finish()])
        # keep the animation loop alive only for the first 20 frames
        if frame_counter < 20:
            js.requestAnimationFrame(render_function)

    render_function = create_proxy(render)
    gpu.input_handler._update_uniforms()
    gpu.input_handler.render_function = render_function

    render_function.request_id = js.requestAnimationFrame(render_function)

    # Best effort: the GUI only exists on pages that define js.gui / js.window.folder.
    try:
        gui = js.gui

        gui.reset(recursive=True)

        folder = js.window.folder
        folder.reset()
        folder.add(params, "shrink", 0.1, 1.0)
        gui.onChange(render_function)
    except Exception as e:
        print(e)
148
+
149
+
150
def cleanup():
    """Forget the module-level ``gpu`` and ``mesh_object`` references, if present.

    The names are removed from this module's namespace entirely (not set to
    ``None``); calling this when they are already gone is harmless.
    """
    print("cleanup")
    namespace = globals()
    for global_name in ("gpu", "mesh_object"):
        namespace.pop(global_name, None)
157
+
158
+
159
async def user_function(data):
    """Execute a marshalled code object received from the host.

    ``data`` is a pair ``(code, expr)``. ``code`` is a base64 string holding a
    ``marshal``-serialized code object. It is turned into a function whose
    globals are this module's namespace, and that function is called with
    ``expr``. The function's return value is discarded.

    SECURITY: ``marshal.loads`` followed by executing the result runs arbitrary
    code. Only feed this from a trusted peer.
    """
    import base64
    import marshal
    import types

    encoded_code, expr = data
    raw_bytes = base64.b64decode(encoded_code.encode("utf-8"))
    compiled = marshal.loads(raw_bytes)
    types.FunctionType(compiled, globals(), "user_function")(expr)
169
+
170
+
171
async def reload(*args, **kwargs):
    """Tear down, reload the ``webgpu`` package and run the freshly imported ``main``.

    Any positional or keyword arguments (e.g. from an event callback) are ignored.

    NOTE(review): ``reload_package`` is neither defined nor explicitly imported
    in this module's visible code. It must come from the ``from .mesh import *``
    star import or be injected into this module's globals. Confirm this,
    otherwise the call raises NameError.
    """
    print("reload")
    cleanup()
    reload_package("webgpu")

    from webgpu.main import main as fresh_main

    await fresh_main()
webgpu/platform.py ADDED
@@ -0,0 +1,129 @@
"""Platform specific code, currently there are two possibilities:

1. Running in a browser with Pyodide
   Webgpu calls are done directly in the browser using the Pyodide interface provided by the "js" module.

2. Running in a Python environment with a websocket connection to a browser
   Webgpu calls are transferred via a websocket connection to the browser environment (using the webgpu.link.websocket module)
"""

from collections.abc import Mapping

# Defaults, overwritten further down:
# - in Pyodide, the import block below sets is_pyodide, create_proxy and destroy_proxy;
# - outside Pyodide, init() sets js, create_proxy, destroy_proxy, websocket_server and link;
# - init_pyodide() sets js and link.
is_pyodide = False
create_proxy = None
destroy_proxy = None
js = None
websocket_server = None
link = None
18
+
19
+ try:
20
+ import js as pyodide_js
21
+ import pyodide.ffi
22
+ from pyodide.ffi import JsPromise, JsProxy, create_proxy
23
+
24
+ def destroy_proxy(proxy):
25
+ proxy.destroy()
26
+
27
+ is_pyodide = True
28
+
29
+ def _default_converter(value, a, b):
30
+ from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject
31
+
32
+ if isinstance(value, BaseWebGPUHandle):
33
+ return pyodide.ffi.to_js(value.handle)
34
+ if isinstance(value, BaseWebGPUObject):
35
+ return value.__dict__
36
+
37
+ def _convert(d):
38
+ from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject
39
+
40
+ if d is None:
41
+ return None
42
+ if isinstance(d, BaseWebGPUHandle):
43
+ return d.handle
44
+ if isinstance(d, BaseWebGPUObject):
45
+ return _convert(d.__dict__) if d.__dict__ else None
46
+ if isinstance(d, Mapping):
47
+ if not d:
48
+ return None
49
+ ret = {}
50
+ for key in d:
51
+ value = _convert(d[key])
52
+ if value is not None:
53
+ ret[key] = value
54
+ return ret
55
+
56
+ if isinstance(d, list):
57
+ return [_convert(value) for value in d]
58
+
59
+ return d
60
+
61
+ def toJS(value):
62
+ value = _convert(value)
63
+ ret = pyodide.ffi.to_js(
64
+ value,
65
+ dict_converter=pyodide_js.Object.fromEntries,
66
+ default_converter=_default_converter,
67
+ create_pyproxies=False,
68
+ )
69
+ return ret
70
+
71
+ except ImportError:
72
+ pass
73
+
74
if not is_pyodide:
    # Outside the browser, JS objects are represented by link proxies, and
    # WebGPU calls are forwarded to a browser over a websocket (see init()).
    from .link.proxy import Proxy as JsProxy
    from .link.websocket import WebsocketLinkServer

    def toJS(x):
        """Identity conversion. The link layer's registered serializers handle the rest."""
        return x

    class JsPromise:
        """Placeholder so that the name ``JsPromise`` exists outside Pyodide."""
82
+
83
+
84
if is_pyodide:
    import json

    from .link.base import LinkBase

    def _jsproxy_to_python(value):
        """Serialize a JS object for the link by round-tripping it through JSON."""
        return json.loads(pyodide_js.JSON.stringify(value))

    LinkBase.register_serializer(JsProxy, _jsproxy_to_python)
90
+
91
+
92
def init(before_wait_for_connection=None):
    """Set up the websocket link to a browser when not running inside Pyodide.

    Does nothing in Pyodide, or when ``js`` is already set (already initialised).

    Otherwise it:
    - creates a ``WebsocketLinkServer`` and publishes it as ``websocket_server``
      and ``link``, with its ``create_proxy``/``destroy_proxy`` as the module's
      proxy helpers;
    - calls ``before_wait_for_connection(server)`` if given;
    - waits via ``wait_for_connection()`` and stores ``server.get(None, None)``
      as ``js``;
    - registers serializers for WebGPU handles and objects;
    - only then lets the server start handling messages.
    """
    global js, create_proxy, destroy_proxy, websocket_server, link
    if is_pyodide or js is not None:
        return

    server = WebsocketLinkServer()
    websocket_server = server
    create_proxy = server.create_proxy
    destroy_proxy = server.destroy_proxy
    link = server

    if before_wait_for_connection:
        before_wait_for_connection(server)

    server.wait_for_connection()
    js = server.get(None, None)

    from .link.base import LinkBase
    from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject

    def _serialize_handle(v):
        return v.handle

    def _serialize_object(v):
        # empty attribute dicts are sent as None
        return v.__dict__ or None

    LinkBase.register_serializer(BaseWebGPUHandle, _serialize_handle)
    LinkBase.register_serializer(BaseWebGPUObject, _serialize_object)

    server._start_handling_messages.set()
115
+
116
+
117
def init_pyodide(link_):
    """Attach to an existing link from inside Pyodide.

    Stores the link as ``link``, fetches ``link.get(None, None)`` as ``js`` and
    registers serializers for WebGPU handles and objects.
    """
    global js, link
    link = link_
    print("init pyodide with link")
    js = link.get(None, None)
    print("JS:", js)

    from .link.base import LinkBase
    from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject

    serializers = (
        (BaseWebGPUHandle, lambda v: v.handle),
        (BaseWebGPUObject, lambda v: v.__dict__ or None),
    )
    for cls, serializer in serializers:
        LinkBase.register_serializer(cls, serializer)
@@ -0,0 +1,155 @@
1
+ import uuid
2
+ from typing import Callable
3
+
4
+ from .camera import Camera
5
+ from .canvas import Canvas
6
+ from .light import Light
7
+ from .utils import BaseBinding, create_bind_group, get_device, is_pyodide
8
+ from .webgpu_api import (
9
+ CommandEncoder,
10
+ CompareFunction,
11
+ DepthStencilState,
12
+ Device,
13
+ FragmentState,
14
+ PrimitiveState,
15
+ PrimitiveTopology,
16
+ VertexState,
17
+ )
18
+
19
+
20
class RenderOptions:
    """Rendering state for one canvas, handed to render objects as ``options``.

    Owns a ``Light`` and a ``Camera``. ``get_bindings`` returns their bindings,
    and ``begin_render_pass`` opens a pass on the canvas targets.
    """

    viewport: tuple[int, int, int, int, float, float]
    canvas: Canvas

    def __init__(self, canvas):
        self.canvas = canvas
        self.light = Light(self.device)
        self.camera = Camera(canvas)

    @property
    def device(self) -> Device:
        """GPU device of the canvas."""
        return self.canvas.device

    def update_buffers(self):
        """Push the camera's current state to its uniform buffer."""
        self.camera._update_uniforms()

    def get_bindings(self):
        """Return the light bindings followed by the camera bindings."""
        bindings = list(self.light.get_bindings())
        bindings.extend(self.camera.get_bindings())
        return bindings

    def begin_render_pass(self, command_encoder: CommandEncoder, **kwargs):
        """Begin a render pass on the canvas attachments with a full-canvas viewport.

        Extra keyword arguments are forwarded to ``beginRenderPass``.
        """
        load_op = command_encoder.getLoadOp()
        color_attachments = self.canvas.color_attachments(load_op)
        depth_attachment = self.canvas.depth_stencil_attachment(load_op)

        pass_encoder = command_encoder.beginRenderPass(
            color_attachments, depth_attachment, **kwargs
        )
        pass_encoder.setViewport(0, 0, self.canvas.width, self.canvas.height, 0.0, 1.0)
        return pass_encoder
54
+
55
+
56
class BaseRenderObject:
    """Common interface of everything a scene can draw.

    The owner assigns ``options`` before calling ``update``/``render``.
    Subclasses implement ``create_render_pipeline``, ``render``,
    ``get_bindings`` and ``get_shader_code``.
    """

    options: RenderOptions
    label: str = ""
    _timestamp: float = -1  # timestamp of the last update(); -1 means never updated
    active: bool = True  # objects with active == False are skipped when rendering

    def __init__(self, label=None):
        # the label defaults to the concrete class name
        self.label = self.__class__.__name__ if label is None else label

    def get_bounding_box(self) -> tuple[list[float], list[float]] | None:
        """Return ``(pmin, pmax)`` of the drawn geometry; defaults to the unit cube."""
        return ([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])

    def update(self, timestamp):
        """Rebuild the render pipeline at most once per distinct *timestamp*."""
        if timestamp != self._timestamp:
            self._timestamp = timestamp
            self.create_render_pipeline()

    @property
    def device(self) -> Device:
        """The global GPU device."""
        return get_device()

    @property
    def canvas(self) -> Canvas:
        """Canvas of the assigned render options."""
        return self.options.canvas

    def create_render_pipeline(self) -> None:
        """Create the GPU pipeline(s) for this object. Must be overridden."""
        raise NotImplementedError

    def render(self, encoder: CommandEncoder):
        """Record draw commands into *encoder*. Must be overridden."""
        raise NotImplementedError

    def get_bindings(self) -> list[BaseBinding]:
        """Return the resources bound for this object's shaders. Must be overridden."""
        raise NotImplementedError

    def get_shader_code(self) -> str:
        """Return the shader source. Must be overridden."""
        raise NotImplementedError

    def add_options_to_gui(self, gui):
        """Hook for adding GUI controls; the default adds nothing."""
99
+
100
+
101
class MultipleRenderObject(BaseRenderObject):
    """Composite render object that forwards options, updates and rendering to its children."""

    def __init__(self, render_objects, label=None):
        # initialise the base so `label` defaults to the class name like other
        # render objects (previously it stayed the empty class default)
        super().__init__(label=label)
        self.render_objects = render_objects

    def get_bounding_box(self):
        """Return the box ``(pmin, pmax)`` enclosing all children's bounding boxes.

        Children returning ``None`` are ignored. If no child reports a box, the
        base-class default is returned.
        """
        boxes = [
            box
            for box in (r.get_bounding_box() for r in self.render_objects)
            if box is not None
        ]
        if not boxes:
            return super().get_bounding_box()
        pmin = [float(min(coords)) for coords in zip(*(box[0] for box in boxes))]
        pmax = [float(max(coords)) for coords in zip(*(box[1] for box in boxes))]
        return (pmin, pmax)

    def update(self, timestamp):
        """Hand our options to every child, then update each child."""
        for r in self.render_objects:
            r.options = self.options
            r.update(timestamp=timestamp)

    def redraw(self, timestamp=None):
        """Call ``redraw`` on every child."""
        for r in self.render_objects:
            r.redraw(timestamp=timestamp)

    def render(self, encoder):
        """Render every child into *encoder*, in order."""
        for r in self.render_objects:
            r.render(encoder)
117
+
118
+
119
class RenderObject(BaseRenderObject):
    """Render object drawn with one render pipeline and one bind group (group 0).

    Subclasses provide ``get_shader_code``/``get_bindings`` and set
    ``n_vertices``. They may also set ``n_instances``, ``topology``,
    ``depthBias`` and the shader entry point names.
    """

    n_vertices: int = 0
    n_instances: int = 1
    topology: PrimitiveTopology = PrimitiveTopology.triangle_list
    depthBias: int = 0
    vertex_entry_point: str = "vertex_main"
    fragment_entry_point: str = "fragment_main"

    def create_render_pipeline(self) -> None:
        """(Re)create ``self.group`` and ``self.pipeline`` from the shader code and bindings."""
        device = self.device
        shader_module = device.createShaderModule(self.get_shader_code())
        layout, self.group = create_bind_group(device, self.get_bindings())

        canvas = self.options.canvas
        pipeline_layout = device.createPipelineLayout([layout])
        vertex_state = VertexState(module=shader_module, entryPoint=self.vertex_entry_point)
        fragment_state = FragmentState(
            module=shader_module,
            entryPoint=self.fragment_entry_point,
            targets=[canvas.color_target],
        )
        primitive_state = PrimitiveState(topology=self.topology)
        depth_state = DepthStencilState(
            format=canvas.depth_format,
            depthWriteEnabled=True,
            depthCompare=CompareFunction.less,
            depthBias=self.depthBias,
        )

        self.pipeline = device.createRenderPipeline(
            pipeline_layout,
            vertex=vertex_state,
            fragment=fragment_state,
            primitive=primitive_state,
            depthStencil=depth_state,
            multisample=canvas.multisample,
        )

    def render(self, encoder: CommandEncoder) -> None:
        """Record one render pass drawing ``n_vertices`` vertices ``n_instances`` times."""
        pass_encoder = self.options.begin_render_pass(encoder)
        pass_encoder.setPipeline(self.pipeline)
        pass_encoder.setBindGroup(0, self.group)
        pass_encoder.draw(self.n_vertices, self.n_instances)
        pass_encoder.end()
webgpu/scene.py ADDED
@@ -0,0 +1,201 @@
1
+ import math
2
+ import time
3
+ from threading import Timer
4
+
5
+ from . import platform
6
+ from .canvas import Canvas
7
+ from .render_object import BaseRenderObject, RenderOptions
8
+ from .utils import is_pyodide, max_bounding_box
9
+ from .webgpu_api import *
10
+
11
_TARGET_FPS = 60


def debounce(render_function):
    """Throttle ``render_function`` to at most one call per ``1 / _TARGET_FPS`` s.

    Calls made while a call is already scheduled for the same receiver are
    dropped. The scheduled call runs with the arguments of the call that
    scheduled it. State is kept per first positional argument (``self`` when
    decorating a method), so one instance's pending render no longer swallows
    another instance's render.

    The delayed call runs on a ``threading.Timer`` thread. If the timer thread
    cannot be started (``Timer.start`` raises ``RuntimeError``, e.g. in
    environments without thread support), the call is made synchronously
    instead.
    """
    import functools

    # receiver key -> Timer of a render that is scheduled or running
    pending = {}
    # receiver key -> start time of that receiver's last render
    last_start = {}
    created = time.time()

    @functools.wraps(render_function)
    def debounced(*args, **kwargs):
        key = id(args[0]) if args else None
        if key in pending:
            # we already have a render scheduled, so do nothing
            return

        def run():
            t = time.time()
            try:
                render_function(*args, **kwargs)
            finally:
                # Always free the slot, even if rendering raised; otherwise no
                # further render would ever be scheduled for this receiver.
                pending.pop(key, None)
                last_start[key] = t

        elapsed = time.time() - last_start.get(key, created)
        timer = Timer(max(1 / _TARGET_FPS - elapsed, 0), run)
        pending[key] = timer
        try:
            timer.start()
        except RuntimeError:
            # no threads available: render right away instead of never
            run()

    # exposed for introspection/tests: currently scheduled timers
    debounced._pending = pending
    return debounced
35
+
36
+
37
class Scene:
    """A collection of render objects drawn together onto one canvas.

    ``init(canvas)`` binds the scene to a canvas:
    - creates the shared ``RenderOptions`` and builds every object's pipeline;
    - frames the camera on the objects' combined bounding box;
    - hooks camera input and canvas resize to ``render``.

    ``cleanup()`` undoes those hooks.
    """

    canvas: Canvas = None
    render_objects: list[BaseRenderObject]
    options: RenderOptions
    gui: object = None

    def __init__(
        self,
        render_objects: list[BaseRenderObject],
        id: str | None = None,
        canvas: Canvas | None = None,
    ):
        if id is None:
            import uuid

            id = str(uuid.uuid4())

        self._id = id
        self.render_objects = render_objects

        if is_pyodide:
            # make the scene reachable through get_scene()/redraw_scene()
            _scenes_by_id[id] = self
        if canvas is not None:
            self.init(canvas)

        self.t_last = 0

    def __repr__(self):
        # NOTE(review): deliberately empty; presumably keeps notebook output clean. Confirm.
        return ""

    @property
    def id(self) -> str:
        return self._id

    @property
    def device(self) -> Device:
        return self.canvas.device

    def init(self, canvas):
        """Bind the scene to *canvas* and prepare everything needed to render."""
        self.canvas = canvas
        self.options = RenderOptions(self.canvas)

        timestamp = time.time()
        for obj in self.render_objects:
            obj.options = self.options
            obj.update(timestamp=timestamp)

        pmin, pmax = max_bounding_box([o.get_bounding_box() for o in self.render_objects])
        camera = self.options.camera
        camera.transform._center = 0.5 * (pmin + pmax)

        def norm(v):
            return math.sqrt(v[0] ** 2 + v[1] ** 2 + v[2] ** 2)

        diameter = norm(pmax - pmin)
        # A degenerate (single-point) bounding box would divide by zero;
        # keep the camera's current scale in that case.
        if diameter > 0:
            camera.transform._scale = 2 / diameter
        if not (pmin[2] == 0 and pmax[2] == 0):
            # content is not flat in the z = 0 plane: apply an initial rotation
            camera.transform.rotate(270, 0)
            camera.transform.rotate(0, -20)
            camera.transform.rotate(20, 0)

        self._js_render = platform.create_proxy(self._render_direct)
        camera.register_callbacks(canvas.input_handler, self.render)
        self.options.update_buffers()
        if is_pyodide:
            _scenes_by_id[self.id] = self

        canvas.on_resize(self.render)

    def redraw(self):
        """Let every render object refresh itself, then schedule a render.

        (Unreachable code that followed an early ``return`` here was removed.)
        """
        ts = time.time()
        for obj in self.render_objects:
            obj.redraw(timestamp=ts)

        self.render()

    def _render(self):
        """Schedule ``_render_direct`` for the browser's next animation frame."""
        platform.js.requestAnimationFrame(self._js_render)

    def _render_direct(self, t=0):
        """Record all active objects and copy the result to the canvas' current texture."""
        encoder = self.device.createCommandEncoder()

        for obj in self.render_objects:
            if obj.active:
                obj.render(encoder)

        encoder.copyTextureToTexture(
            TexelCopyTextureInfo(self.canvas.target_texture),
            TexelCopyTextureInfo(self.canvas.context.getCurrentTexture()),
            [self.canvas.width, self.canvas.height, 1],
        )
        self.device.queue.submit([encoder.finish()])

    @debounce
    def render(self, t=0):
        """Render all active objects (throttled by ``debounce``).

        In Pyodide this only schedules ``_render_direct`` for the next animation
        frame. Otherwise the commands are recorded and submitted here, and then
        ``platform.js.patchedRequestAnimationFrame`` is called with the device
        handle, canvas context and target texture.
        """
        if is_pyodide:
            self._render()
            return

        encoder = self.device.createCommandEncoder()

        for obj in self.render_objects:
            if obj.active:
                obj.render(encoder)

        self.device.queue.submit([encoder.finish()])

        platform.js.patchedRequestAnimationFrame(
            self.canvas.device.handle,
            self.canvas.context,
            self.canvas.target_texture,
        )

    def cleanup(self):
        """Detach the scene from its canvas and release the JS render proxy.

        Safe to call on a scene that was never initialised or was already cleaned up.
        """
        for obj in self.render_objects:
            obj.options = None

        if self.canvas is not None:
            self.options.camera.unregister_callbacks(self.canvas.input_handler)
            self.options.camera._render_function = None
            self.canvas.input_handler.unregister_callbacks()
            platform.destroy_proxy(self._js_render)
            del self._js_render
            self.canvas._on_resize_callbacks.remove(self.render)
            self.canvas = None

        if is_pyodide:
            _scenes_by_id.pop(self.id, None)
192
+
193
+
194
if is_pyodide:
    # Registry of scenes by id, so code running in Pyodide can look a scene up by its id.
    _scenes_by_id: dict[str, Scene] = {}

    def get_scene(id: str) -> Scene:
        """Return the registered scene with this id (raises KeyError if unknown)."""
        return _scenes_by_id[id]

    def redraw_scene(id: str):
        """Redraw the registered scene with this id."""
        scene = get_scene(id)
        scene.redraw()
File without changes
@@ -0,0 +1,21 @@
1
// Camera uniform block, bound as u_camera at @group(0) @binding(0).
struct CameraUniforms {
    model_view: mat4x4<f32>,
    model_view_projection: mat4x4<f32>,
    normal_mat: mat4x4<f32>,
    aspect: f32,

    // Explicit padding: 3 * 64 + 4 + 3 * 4 = 208 bytes, a multiple of 16.
    // NOTE(review): presumably mirrors the host-side buffer layout; confirm.
    padding0: u32,
    padding1: u32,
    padding2: u32,
};

@group(0) @binding(0) var<uniform> u_camera : CameraUniforms;
13
+
14
// Map point p to clip space with the camera's model-view-projection matrix (w = 1).
fn cameraMapPoint(p: vec3f) -> vec4f {
    let homogeneous = vec4f(p, 1.0);
    return u_camera.model_view_projection * homogeneous;
}
17
+
18
// Transform normal n with the camera's normal matrix.
// Normals are directions, so they are extended with w = 0: any translation
// stored in normal_mat must not shift them (w = 1 would add it to the result).
fn cameraMapNormal(n: vec3f) -> vec4f {
    return u_camera.normal_mat * vec4f(n, 0.0);
}
21
+