webgpu 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
webgpu/main.py ADDED
@@ -0,0 +1,177 @@
1
+ """Main file for the webgpu example, creates a small 2d mesh and renders it using WebGPU"""
2
+
3
+ import urllib.parse
4
+
5
+ import js
6
+ import ngsolve as ngs
7
+ from netgen.occ import unit_cube
8
+ from netgen.geom2d import unit_square
9
+ from pyodide.ffi import create_proxy
10
+
11
+ from .gpu import WebGPU, init_webgpu
12
+ from .mesh import *
13
+ from .lic import LineIntegralConvolutionRenderObject
14
+
15
+ from .webgpu_api import BackendType, RenderPassEncoder, ColorTargetState
16
+
17
+ # s = ColorTargetState()
18
+ # print("BackendType", BackendType)
19
+ # print("ColorTargetState", s)
20
+
21
+
22
def f(encoder: RenderPassEncoder):
    """Placeholder render callback that ignores *encoder* and does nothing."""
    return None
24
+
25
+
26
# Module-level state of the example (used by main(), cleanup() and reload()).
gpu: WebGPU = None
mesh_object: RenderObject = None
# Optional render objects; None until created.
elements_object = point_number_object = None

# Coefficient function to visualise and the JS proxy of the per-frame callback.
cf = render_function = None
33
+
34
+
35
async def main():
    """Build the demo mesh, create its render objects and start the render loop.

    Side effects:
      * sets the module-level ``gpu``, ``cf``, ``render_function``,
        ``elements_object`` and ``point_number_object``;
      * schedules frames with ``js.requestAnimationFrame`` (the loop stops by
        itself after 20 frames);
      * if ``js.gui`` / ``js.window.folder`` exist, registers the "shrink"
        parameter (range 0.1-1.0) with the folder and re-renders on change.
    """
    # elements_object / point_number_object used to be assigned as locals,
    # which left the module-level names of the same name stuck at None.
    global gpu, mesh_object, cf, render_function
    global elements_object, point_number_object

    # ``pyodide`` itself is not imported explicitly in this module (only
    # ``create_proxy`` from ``pyodide.ffi``), but ``pyodide.ffi.to_js`` is used
    # below; import it here so that call cannot fail with NameError.
    import pyodide.ffi

    gpu = await init_webgpu(js.document.getElementById("canvas"))
    # print("DEVICE", dir(gpu.native_device))

    point_number_object = None
    elements_object = None

    if 1:
        from ngsolve.meshes import MakeStructured3DMesh

        # create new ngsolve mesh and evaluate arbitrary function on it
        # mesh = ngs.Mesh(unit_cube.GenerateMesh(maxh=0.2))
        # mesh = MakeStructured3DMesh(True, 10, prism=True)

        import netgen.occ as occ
        from netgen.meshing import IdentificationType

        # Unit box glued inside a copy of itself scaled by 1.1 about its center.
        idtype = IdentificationType.CLOSESURFACES
        inner = occ.Box((0, 0, 0), (1, 1, 1))
        trafo = occ.gp_Trsf().Scale(inner.center, 1.1)
        outer = trafo(inner)

        # inner.Identify(outer, "", idtype, trafo)
        shape = occ.Glue([outer - inner, inner])

        geo = occ.OCCGeometry(shape)
        mesh = geo.GenerateMesh(maxh=0.3)

        # mesh = unit_square.GenerateMesh(maxh=0.3)
        mesh = ngs.Mesh(mesh)

        order = 3
        # Keep a coefficient function that was assigned before, else use a default.
        cf = cf or ngs.sin(10 * ngs.x) * ngs.sin(10 * ngs.y)
        # cf = ngs.x
        data = MeshData(mesh, cf, order)
        gpu.u_function.min = -1
        gpu.u_function.max = 1
    else:
        # use compute shader to create a unit_square mesh
        # but has always P1 and 'x' hard-coded as function
        query = urllib.parse.parse_qs(js.location.search[1:])
        N = 10
        N = int(query.get("n", [N])[0])
        data = create_testing_square_mesh(gpu, N)
        gpu.u_function.min = 0
        gpu.u_function.max = 1

    # lic = LineIntegralConvolutionRenderObject(gpu, 1000, 800)
    # print("LIC", lic)

    # mesh_object = MeshRenderObject(gpu, data)
    # mesh_object = MeshRenderObjectIndexed(gpu, data) # function values are wrong, due to ngsolve vertex numbering order
    # mesh_object = MeshRenderObjectDeferred(
    #     gpu, data
    # )  # function values are wrong, due to ngsolve vertex numbering order
    point_number_object = PointNumbersRenderObject(gpu, data, font_size=16)
    elements_object = Mesh3dElementsRenderObject(gpu, data)

    t_last = 0
    fps = 0
    frame_counter = 0
    params = pyodide.ffi.to_js({"shrink": 0.3})

    def render(time):
        """Per-frame callback: upload uniforms and draw all render objects."""
        nonlocal t_last, fps, frame_counter

        # Anything that is not a float frame timestamp is treated as time 0.
        if not isinstance(time, float):
            time = 0

        print("params", params.shrink)
        dt = time - t_last
        t_last = time
        frame_counter += 1
        print(f"frame time {dt:.2f} ms")

        gpu.u_mesh.shrink = params.shrink

        # copy camera position etc. to GPU
        gpu.update_uniforms()

        command_encoder = gpu.device.createCommandEncoder()

        if mesh_object is not None:
            mesh_object.render(command_encoder)

        if elements_object is not None:
            elements_object.render(command_encoder)

        if point_number_object is not None:
            point_number_object.render(command_encoder)

        gpu.device.queue.submit([command_encoder.finish()])
        # Keep the animation loop alive only for the first 20 frames.
        if frame_counter < 20:
            js.requestAnimationFrame(render_function)

    render_function = create_proxy(render)
    gpu.input_handler._update_uniforms()
    gpu.input_handler.render_function = render_function

    render_function.request_id = js.requestAnimationFrame(render_function)

    # GUI is optional (best effort): report problems but keep the demo running.
    try:
        gui = js.gui

        gui.reset(recursive=True)

        folder = js.window.folder
        folder.reset()
        folder.add(params, "shrink", 0.1, 1.0)
        gui.onChange(render_function)
    except Exception as e:
        print(e)
148
+
149
+
150
def cleanup():
    """Remove the module-level ``gpu`` and ``mesh_object`` bindings, if present."""
    print("cleanup")
    namespace = globals()
    for name in ("gpu", "mesh_object"):
        # Missing names are simply skipped, so calling this twice is harmless.
        namespace.pop(name, None)
157
+
158
+
159
async def user_function(data):
    """Rebuild a marshalled function from *data* and call it.

    Args:
        data: ``(code, expr)`` pair. ``code`` is a base64 string of a
            ``marshal``-serialised code object. ``expr`` is passed as the
            single argument to the rebuilt function.

    The function is created with this module's globals. If calling it returns
    an awaitable (e.g. the code came from an ``async def``), the awaitable is
    awaited. Previously such a coroutine was created but never run.

    SECURITY: ``marshal.loads`` plus executing the code object runs arbitrary
    code. Only feed data from a trusted peer into this function.
    """
    code, expr = data
    import base64
    import inspect
    import marshal
    import types

    code = base64.b64decode(code.encode("utf-8"))
    code = marshal.loads(code)
    func = types.FunctionType(code, globals(), "user_function")
    result = func(expr)
    if inspect.isawaitable(result):
        await result
169
+
170
+
171
async def reload(*args, **kwargs):
    """Tear down the demo, reload the package and run the fresh ``main()``.

    All arguments are accepted (so this can serve as an event callback) and
    ignored. ``reload_package`` is not defined in this module; it presumably
    comes from a star import. TODO confirm.
    """
    print("reload")
    cleanup()
    reload_package("webgpu")
    # Import only after reloading so the newly loaded module's main() is used.
    from webgpu.main import main as fresh_main

    await fresh_main()
webgpu/platform.py ADDED
@@ -0,0 +1,125 @@
1
+ """Platform specific code, currenty there are two possibilities:
2
+
3
+ 1. Running in a browser with Pyodide
4
+ WebGPU calls are made directly in the browser through the Pyodide interface provided by the "js" module.
5
+
6
+ 2. Running in a Python environment with a websocket connection to a browser
7
+ WebGPU calls are forwarded over a websocket connection to the browser environment (using the webgpu.link.websocket module)
8
+ """
9
+
10
+ from collections.abc import Mapping
11
+
12
# Defaults; overwritten below depending on whether we run inside Pyodide.
is_pyodide = False
create_proxy = destroy_proxy = None
js = websocket_server = link = None
18
+
19
try:
    # These imports only succeed inside Pyodide; elsewhere ImportError is
    # swallowed and the websocket fallback below takes over.
    import js as pyodide_js
    import pyodide.ffi
    from pyodide.ffi import create_proxy, JsPromise, JsProxy

    def destroy_proxy(proxy):
        """Release a proxy previously created with ``create_proxy``."""
        proxy.destroy()

    is_pyodide = True

    def _default_converter(value, a, b):
        """``to_js`` fallback: unwrap WebGPU handles/objects, otherwise None."""
        from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject

        if isinstance(value, BaseWebGPUHandle):
            return pyodide.ffi.to_js(value.handle)
        return value.__dict__ if isinstance(value, BaseWebGPUObject) else None

    def _convert(d):
        """Recursively replace WebGPU wrappers by plain Python data.

        Handles become their ``handle``; objects become their (converted)
        attribute dict. Mappings drop entries that convert to None, and empty
        mappings/objects become None. Lists are converted element-wise (None
        elements are kept). Everything else is returned unchanged.
        """
        from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject

        if d is None:
            return None
        if isinstance(d, BaseWebGPUHandle):
            return d.handle
        if isinstance(d, BaseWebGPUObject):
            attrs = d.__dict__
            return _convert(attrs) if attrs else None
        if isinstance(d, Mapping):
            if not d:
                return None
            converted = {}
            for field in d:
                item = _convert(d[field])
                if item is not None:
                    converted[field] = item
            return converted
        if isinstance(d, list):
            return list(map(_convert, d))
        return d

    def toJS(value):
        """Convert *value* (possibly containing WebGPU wrappers) to a JS value."""
        return pyodide.ffi.to_js(
            _convert(value),
            dict_converter=pyodide_js.Object.fromEntries,
            default_converter=_default_converter,
            create_pyproxies=False,
        )

except ImportError:
    pass
73
+
74
if not is_pyodide:
    # Outside the browser, JS calls are routed through a websocket link.
    from .link.proxy import Proxy as JsProxy
    from .link.websocket import WebsocketLinkServer

    def toJS(x):
        """Identity: no Pyodide ``to_js`` conversion happens in this mode."""
        return x

    link = websocket_server = WebsocketLinkServer()
    create_proxy = websocket_server.create_proxy
    destroy_proxy = websocket_server.destroy_proxy

    class JsPromise:
        """Placeholder so the name ``JsPromise`` exists outside Pyodide."""
87
+
88
+
89
if is_pyodide:
    import json
    from .link.base import LinkBase

    def _jsproxy_to_python(v):
        """Serialise a JsProxy by round-tripping it through ``JSON.stringify``."""
        return json.loads(pyodide_js.JSON.stringify(v))

    LinkBase.register_serializer(JsProxy, _jsproxy_to_python)
94
+
95
+
96
def init():
    """Set up the websocket link to the browser.

    Does nothing inside Pyodide or when ``js`` is already set. Otherwise:
      1. calls ``websocket_server.wait_for_connection()``;
      2. stores ``websocket_server.get(None, None)`` in the module-level
         ``js`` (presumably the browser's JS root object);
      3. registers serializers for the WebGPU wrapper types;
      4. sets ``websocket_server._start_handling_messages``.
    """
    global js
    if not is_pyodide and js is None:
        websocket_server.wait_for_connection()
        js = websocket_server.get(None, None)

        from .link.base import LinkBase
        from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject

        def serialize_handle(v):
            return v.handle

        def serialize_object(v):
            return v.__dict__ or None

        LinkBase.register_serializer(BaseWebGPUHandle, serialize_handle)
        LinkBase.register_serializer(BaseWebGPUObject, serialize_object)

        websocket_server._start_handling_messages.set()
111
+
112
def init_pyodide(link_):
    """Use *link_* as the module-level ``link`` and fetch ``js`` through it.

    Also registers serializers for the WebGPU wrapper types on ``LinkBase``.
    """
    global link, js

    link = link_
    print("init pyodide with link")
    js = link.get(None, None)
    print("JS:", js)

    from .link.base import LinkBase
    from .webgpu_api import BaseWebGPUHandle, BaseWebGPUObject

    # Registration order matters only as far as LinkBase cares; kept as before.
    serializers = {
        BaseWebGPUHandle: lambda v: v.handle,
        BaseWebGPUObject: lambda v: v.__dict__ or None,
    }
    for wrapper_type, serialize in serializers.items():
        LinkBase.register_serializer(wrapper_type, serialize)
125
+
@@ -0,0 +1,159 @@
1
+ import uuid
2
+ from typing import Callable
3
+
4
+ from .camera import Camera
5
+ from .canvas import Canvas
6
+ from .light import Light
7
+ from .utils import BaseBinding, is_pyodide, create_bind_group, get_device
8
+ from .webgpu_api import (
9
+ CommandEncoder,
10
+ CompareFunction,
11
+ DepthStencilState,
12
+ Device,
13
+ FragmentState,
14
+ PrimitiveState,
15
+ PrimitiveTopology,
16
+ VertexState,
17
+ )
18
+
19
+
20
class RenderOptions:
    """Per-canvas rendering state shared by the render objects of a scene.

    Holds the canvas plus a ``Light`` and a ``Camera``, and opens render
    passes that target the canvas.
    """

    viewport: tuple[int, int, int, int, float, float]
    canvas: Canvas

    def __init__(self, canvas):
        self.canvas = canvas
        # ``self.device`` reads ``self.canvas.device``, so canvas is set first.
        self.light = Light(self.device)
        self.camera = Camera(canvas)

    @property
    def device(self) -> Device:
        """GPU device of the attached canvas."""
        return self.canvas.device

    def update_buffers(self):
        """Upload the camera uniforms (the light is not touched here)."""
        self.camera._update_uniforms()

    def get_bindings(self):
        """Return the light's bindings followed by the camera's, as a list."""
        bindings = list(self.light.get_bindings())
        bindings.extend(self.camera.get_bindings())
        return bindings

    def begin_render_pass(self, command_encoder: CommandEncoder, **kwargs):
        """Begin a render pass on the canvas' color and depth attachments.

        Extra keyword arguments are forwarded to ``beginRenderPass``. The
        viewport covers the whole canvas with depth range [0, 1].
        """
        load_op = command_encoder.getLoadOp()
        color_attachments = self.canvas.color_attachments(load_op)
        depth_attachment = self.canvas.depth_stencil_attachment(load_op)

        pass_encoder = command_encoder.beginRenderPass(
            color_attachments, depth_attachment, **kwargs
        )
        width, height = self.canvas.width, self.canvas.height
        pass_encoder.setViewport(0, 0, width, height, 0.0, 1.0)
        return pass_encoder
56
+
57
+
58
class BaseRenderObject:
    """Common interface of everything a ``Scene`` can draw.

    Subclasses implement ``create_render_pipeline``, ``render``,
    ``get_bindings`` and ``get_shader_code``. ``options`` must be assigned
    (``Scene.init`` does this) before ``update`` is called.
    """

    options: RenderOptions
    label: str = ""
    _timestamp: float = -1
    active: bool = True

    def __init__(self, label=None):
        # Fall back to the concrete class name when no label is given.
        self.label = self.__class__.__name__ if label is None else label

    def get_bounding_box(self) -> tuple[list[float], list[float]] | None:
        """Default bounding box: the unit cube, returned as ``(pmin, pmax)``."""
        lower = [0.0, 0.0, 0.0]
        upper = [1.0, 1.0, 1.0]
        return (lower, upper)

    def update(self, timestamp):
        """Rebuild the render pipeline unless it was already built for *timestamp*."""
        if timestamp != self._timestamp:
            self._timestamp = timestamp
            self.create_render_pipeline()

    @property
    def device(self) -> Device:
        """The GPU device as returned by ``get_device()``."""
        return get_device()

    @property
    def canvas(self) -> Canvas:
        """Canvas of the currently assigned ``options``."""
        return self.options.canvas

    def create_render_pipeline(self) -> None:
        """Create GPU pipeline state; must be provided by subclasses."""
        raise NotImplementedError

    def render(self, encoder: CommandEncoder):
        """Record draw commands into *encoder*; must be provided by subclasses."""
        raise NotImplementedError

    def get_bindings(self) -> list[BaseBinding]:
        """Return the resource bindings; must be provided by subclasses."""
        raise NotImplementedError

    def get_shader_code(self) -> str:
        """Return the WGSL source; must be provided by subclasses."""
        raise NotImplementedError

    def add_options_to_gui(self, gui):
        """Hook for adding GUI controls; the default adds nothing."""
        return None
101
+
102
+
103
class MultipleRenderObject(BaseRenderObject):
    """Group several render objects so they are updated and drawn together."""

    def __init__(self, render_objects, label=None):
        # Run the base initialiser so ``label`` defaults to the class name like
        # every other render object (it previously stayed the class-level "").
        super().__init__(label)
        self.render_objects = render_objects

    def get_bounding_box(self):
        """Return the union of the children's bounding boxes as ``(pmin, pmax)``.

        Children returning None are skipped. If no child provides a box, the
        base-class default is returned. Previously the base-class unit cube was
        always reported, so cameras were fitted to the wrong region. Assumes
        3-component corners, like the base-class default.
        """
        boxes = [
            box
            for box in (r.get_bounding_box() for r in self.render_objects)
            if box is not None
        ]
        if not boxes:
            return super().get_bounding_box()
        pmin = [min(float(lo[i]) for lo, _ in boxes) for i in range(3)]
        pmax = [max(float(hi[i]) for _, hi in boxes) for i in range(3)]
        return pmin, pmax

    def update(self, timestamp):
        """Share this object's ``options`` with every child and update them."""
        for r in self.render_objects:
            r.options = self.options
            r.update(timestamp=timestamp)

    def redraw(self, timestamp=None):
        """Forward ``redraw`` to every child."""
        for r in self.render_objects:
            r.redraw(timestamp=timestamp)

    def render(self, encoder):
        """Render every *active* child into *encoder*."""
        for r in self.render_objects:
            # Honour per-child ``active`` flags, as Scene does for its objects.
            if r.active:
                r.render(encoder)
119
+
120
+
121
class RenderObject(BaseRenderObject):
    """Base class for render objects drawn with a single render pipeline.

    Subclasses set ``n_vertices`` / ``n_instances`` and provide shader code and
    bindings. The pipeline writes depth with a ``less`` comparison.
    """

    n_vertices: int = 0
    n_instances: int = 1
    topology: PrimitiveTopology = PrimitiveTopology.triangle_list
    depthBias: int = 0
    vertex_entry_point: str = "vertex_main"
    fragment_entry_point: str = "fragment_main"

    def create_render_pipeline(self) -> None:
        """Compile the shader, build the bind group and create ``self.pipeline``.

        The bind group is kept in ``self.group`` for use in ``render``.
        """
        device = self.device
        canvas = self.options.canvas

        shader_module = device.createShaderModule(self.get_shader_code())
        layout, self.group = create_bind_group(device, self.get_bindings())

        pipeline_layout = device.createPipelineLayout([layout])
        vertex_state = VertexState(
            module=shader_module, entryPoint=self.vertex_entry_point
        )
        fragment_state = FragmentState(
            module=shader_module,
            entryPoint=self.fragment_entry_point,
            targets=[canvas.color_target],
        )
        primitive_state = PrimitiveState(topology=self.topology)
        depth_state = DepthStencilState(
            format=canvas.depth_format,
            depthWriteEnabled=True,
            depthCompare=CompareFunction.less,
            depthBias=self.depthBias,
        )
        self.pipeline = device.createRenderPipeline(
            pipeline_layout,
            vertex=vertex_state,
            fragment=fragment_state,
            primitive=primitive_state,
            depthStencil=depth_state,
            multisample=canvas.multisample,
        )

    def render(self, encoder: CommandEncoder) -> None:
        """Record one pass drawing ``n_vertices`` vertices ``n_instances`` times."""
        pass_encoder = self.options.begin_render_pass(encoder)
        pass_encoder.setPipeline(self.pipeline)
        pass_encoder.setBindGroup(0, self.group)
        pass_encoder.draw(self.n_vertices, self.n_instances)
        pass_encoder.end()
webgpu/scene.py ADDED
@@ -0,0 +1,206 @@
1
+ from threading import Timer
2
+ import time
3
+
4
+ from .canvas import Canvas
5
+ from .render_object import BaseRenderObject, RenderOptions
6
+ from .utils import is_pyodide, max_bounding_box
7
+ from .webgpu_api import *
8
+ from .platform import create_proxy, destroy_proxy
9
+ from . import platform
10
+ import math
11
+
12
_TARGET_FPS = 60


def debounce(render_function):
    """Throttle *render_function* to at most one run per ``1 / _TARGET_FPS`` s.

    A call schedules the function on a ``threading.Timer``. Further calls made
    while a run is pending for the same key are dropped. The key is the
    identity of the first positional argument, i.e. ``self`` when decorating a
    method, so separate instances (e.g. several Scenes) don't swallow each
    other's requests.

    The pending slot is always cleared after the run, even if it raises.
    Previously an exception left the timer set and blocked all future calls.
    """
    import functools

    # key -> pending Timer
    pending = {}

    @functools.wraps(render_function)
    def debounced(*args, **kwargs):
        key = id(args[0]) if args else None
        if key in pending:
            # a run for this key is already scheduled, so do nothing
            return

        def fire():
            t = time.time()
            try:
                render_function(*args, **kwargs)
            finally:
                # Free the slot so the next call can schedule a new run.
                debounced.t_last = t
                pending.pop(key, None)

        t_wait = max(1 / _TARGET_FPS - (time.time() - debounced.t_last), 0)
        timer = Timer(t_wait, fire)
        pending[key] = timer
        timer.start()

    debounced.t_last = time.time()
    return debounced
36
+
37
+
38
class Scene:
    """A list of render objects drawn together onto one canvas.

    ``init(canvas)`` creates the shared ``RenderOptions`` (camera and light),
    builds every render object's pipeline and fits the camera to the combined
    bounding box. ``render()`` is throttled by ``debounce``.
    """

    canvas: Canvas = None
    render_objects: list[BaseRenderObject]
    options: RenderOptions
    gui: object = None

    def __init__(
        self,
        render_objects: list[BaseRenderObject],
        id: str | None = None,
        canvas: Canvas | None = None,
    ):
        if id is None:
            import uuid

            id = str(uuid.uuid4())

        self._id = id
        self.render_objects = render_objects

        if is_pyodide:
            _scenes_by_id[id] = self
        if canvas is not None:
            self.init(canvas)

        self.t_last = 0

    def __repr__(self):
        # Intentionally empty (presumably to keep interactive output quiet).
        return ""

    @property
    def id(self) -> str:
        """Identifier of this scene (a random UUID string unless given)."""
        return self._id

    @property
    def device(self) -> Device:
        """GPU device of the attached canvas."""
        return self.canvas.device

    def init(self, canvas):
        """Attach the scene to *canvas* and prepare everything for rendering."""
        self.canvas = canvas
        self.options = RenderOptions(self.canvas)

        timestamp = time.time()
        for obj in self.render_objects:
            obj.options = self.options
            obj.update(timestamp=timestamp)

        pmin, pmax = max_bounding_box(
            [o.get_bounding_box() for o in self.render_objects]
        )
        camera = self.options.camera
        camera.transform._center = 0.5 * (pmin + pmax)

        def norm(v):
            return math.sqrt(v[0] ** 2 + v[1] ** 2 + v[2] ** 2)

        extent = norm(pmax - pmin)
        # A degenerate box (e.g. a single point) has zero extent and the
        # division used to raise ZeroDivisionError; keep the current scale then.
        if extent > 0:
            camera.transform._scale = 2 / extent
        if not (pmin[2] == 0 and pmax[2] == 0):
            # Content is not flat in the z = 0 plane: apply an initial rotation.
            camera.transform.rotate(270, 0)
            camera.transform.rotate(0, -20)
            camera.transform.rotate(20, 0)

        self._js_render = create_proxy(self._render_direct)
        camera.register_callbacks(canvas.input_handler, self.render)
        self.options.update_buffers()
        if is_pyodide:
            _scenes_by_id[self.id] = self

        canvas.on_resize(self.render)

    def redraw(self):
        """Call ``redraw`` on every render object, then request a render.

        The unreachable code that used to follow the ``return`` here has been
        removed.
        """
        ts = time.time()
        for obj in self.render_objects:
            obj.redraw(timestamp=ts)

        self.render()

    def _render(self):
        """Ask the browser to call ``_render_direct`` on the next animation frame."""
        platform.js.requestAnimationFrame(self._js_render)

    def _render_direct(self, t=0):
        """Draw all active objects, then copy ``canvas.target_texture`` to the
        context's current texture and submit."""
        encoder = self.device.createCommandEncoder()

        for obj in self.render_objects:
            if obj.active:
                obj.render(encoder)

        encoder.copyTextureToTexture(
            TexelCopyTextureInfo(self.canvas.target_texture),
            TexelCopyTextureInfo(self.canvas.context.getCurrentTexture()),
            [self.canvas.width, self.canvas.height, 1],
        )
        self.device.queue.submit([encoder.finish()])

    @debounce
    def render(self, t=0):
        """Draw all active render objects (throttled via ``debounce``).

        In Pyodide this only requests an animation frame (see ``_render``).
        Otherwise the commands are recorded and submitted here, and the device
        handle, context and target texture are passed to
        ``platform.js.patchedRequestAnimationFrame``.
        """
        if is_pyodide:
            self._render()
            return

        encoder = self.device.createCommandEncoder()

        for obj in self.render_objects:
            if obj.active:
                obj.render(encoder)

        self.device.queue.submit([encoder.finish()])

        platform.js.patchedRequestAnimationFrame(
            self.canvas.device.handle,
            self.canvas.context,
            self.canvas.target_texture,
        )

    def cleanup(self):
        """Detach the scene from its canvas and release the JS render proxy."""
        for obj in self.render_objects:
            obj.options = None

        self.options.camera.unregister_callbacks(self.canvas.input_handler)
        self.options.camera._render_function = None
        self.canvas.input_handler.unregister_callbacks()
        destroy_proxy(self._js_render)
        del self._js_render
        self.canvas._on_resize_callbacks.remove(self.render)
        self.canvas = None

        if is_pyodide:
            del _scenes_by_id[self.id]
197
+
198
+
199
if is_pyodide:
    # Registry of Scene objects by id; filled by Scene, read by the helpers below.
    _scenes_by_id: dict[str, Scene] = {}

    def get_scene(id: str) -> Scene:
        """Return the registered scene with this id (raises KeyError if unknown)."""
        scenes = _scenes_by_id
        return scenes[id]

    def redraw_scene(id: str):
        """Look up the scene with this id and call its ``redraw()``."""
        scene = get_scene(id)
        scene.redraw()
File without changes
@@ -0,0 +1,21 @@
1
// Camera data shared by the shaders. The field order and sizes must match
// the bytes written into the uniform buffer bound below.
struct CameraUniforms {
  model_view: mat4x4f,
  model_view_projection: mat4x4f,
  normal_mat: mat4x4f,
  aspect: f32,

  // Explicit padding: brings the struct size to 208 bytes (a multiple of 16).
  padding0: u32,
  padding1: u32,
  padding2: u32,
};

@group(0) @binding(0) var<uniform> u_camera: CameraUniforms;
13
+
14
// Map a 3D point (as a homogeneous position with w = 1) to clip space.
fn cameraMapPoint(p: vec3f) -> vec4f {
  let position = vec4f(p, 1.0);
  return u_camera.model_view_projection * position;
}
17
+
18
// Transform a normal (direction) vector with the camera's normal matrix.
// Directions are extended with w = 0 so that any translation stored in the
// matrix's last column cannot leak into the xyz components; with w = 1 a
// translated matrix shifted every normal.
fn cameraMapNormal(n: vec3f) -> vec4f {
  return u_camera.normal_mat * vec4f(n, 0.0);
}
21
+