webgpu 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
webgpu/utils.py ADDED
@@ -0,0 +1,379 @@
1
+ import base64
2
+ import zlib
3
+ from pathlib import Path
4
+
5
+ from . import platform
6
+ from .webgpu_api import *
7
+ from .webgpu_api import toJS as to_js
8
+
9
# Module-wide cache for the WebGPU device. It stays None until
# init_device_sync() or init_device() has stored a device in it.
_device: "Device | None" = None
10
+
11
+
12
def init_device_sync():
    """Create the module-wide WebGPU device without awaiting and return it.

    The device is cached in the module-level ``_device``; later calls return
    the cached object unchanged.

    A low-power adapter is requested through ``platform.js.navigator.gpu``.
    The device is requested with ``maxBufferSize`` and
    ``maxStorageBufferBindingSize`` set to 1 GiB minus 16 bytes.

    If ``navigator.gpu`` is falsy or no adapter is returned, an alert is shown
    and the interpreter exits with status 1 (``SystemExit``).

    NOTE(review): the results of requestAdapter/requestDevice are used without
    awaiting - presumably ``platform.js`` resolves them synchronously on this
    code path; confirm.
    """
    global _device

    # sys is not imported at module level in this file, so bring it in here.
    import sys

    if _device is not None:
        return _device

    def _abort_unsupported():
        # Tell the user, then stop: nothing else can work without WebGPU.
        platform.js.alert("WebGPU is not supported")
        sys.exit(1)

    if not platform.js.navigator.gpu:
        _abort_unsupported()

    reqAdapter = platform.js.navigator.gpu.requestAdapter
    options = RequestAdapterOptions(
        powerPreference=PowerPreference.low_power,
    ).toJS()
    adapter = reqAdapter(options)
    if not adapter:
        _abort_unsupported()

    one_gig = 1024**3
    _device = Device(
        adapter.requestDevice(
            [],
            Limits(
                maxBufferSize=one_gig - 16,
                maxStorageBufferBindingSize=one_gig - 16,
            ),
            None,
            "WebGPU device",
        )
    )
    return _device
42
+
43
+
44
async def init_device() -> Device:
    """Request a low-power adapter and a device, cache the device, return it.

    The device is stored in the module-level ``_device``; later calls return
    the cached object without contacting the adapter again.

    If the adapter reports the ``"timestamp-query"`` feature, that feature is
    requested on the device as well. ``maxBufferSize`` and
    ``maxStorageBufferBindingSize`` are requested as 1 GiB minus 16 bytes.
    The resulting limits and the adapter info are logged to the JS console,
    and the two size limits are printed in MB.

    Errors raised while awaiting the adapter or the device propagate to the
    caller, and ``_device`` is only assigned once the device request has
    completed.
    """
    global _device

    import inspect

    if _device is not None:
        return _device

    adapter = requestAdapter(powerPreference=PowerPreference.low_power)
    # The call may return either a finished adapter or an awaitable; only
    # await in the latter case so that genuine failures are not swallowed.
    if inspect.isawaitable(adapter):
        adapter = await adapter

    required_features = []
    if "timestamp-query" in adapter.features:
        print("have timestamp query")
        required_features.append("timestamp-query")
    else:
        print("no timestamp query")

    one_meg = 1024**2
    one_gig = 1024**3
    device = adapter.requestDevice(
        label="WebGPU device",
        # Actually request what was detected above; without this the
        # timestamp-query feature is never enabled on the device.
        requiredFeatures=required_features,
        requiredLimits=Limits(
            maxBufferSize=one_gig - 16,
            maxStorageBufferBindingSize=one_gig - 16,
        ),
    )
    if inspect.isawaitable(device):
        device = await device

    # Publish only a fully resolved device, never a pending awaitable.
    _device = device

    limits = _device.limits
    # NOTE(review): `js` is not imported in this file; presumably it arrives
    # through the star import from .webgpu_api - confirm.
    js.console.log("device limits\n", limits)
    js.console.log("adapter info\n", adapter.info)

    print(f"max storage buffer binding size {limits.maxStorageBufferBindingSize / one_meg:.2f} MB")
    print(f"max buffer size {limits.maxBufferSize / one_meg:.2f} MB")

    return _device
85
+
86
+
87
def get_device() -> Device:
    """Return the cached module-wide device.

    Raises:
        RuntimeError: if no device has been stored in ``_device`` yet.
    """
    if _device is not None:
        return _device
    raise RuntimeError("Device not initialized")
91
+
92
+
93
class Pyodide:
    """Inert object that accepts and discards every attribute assignment.

    NOTE(review): judging by the name this is a stand-in used where a pyodide
    object is expected - confirm against the callers.
    """

    def __init__(self):
        """Nothing to set up."""

    def __setattr__(self, key, value):
        """Ignore the assignment; nothing is stored on the instance."""
99
+
100
+
101
def find_shader_file(file_name, module_file) -> Path:
    """Locate ``file_name`` inside a ``shaders`` directory.

    Two locations are tried in order: the ``shaders`` folder next to
    ``module_file``, then the ``shaders`` folder next to this module.

    Returns:
        The first candidate path that exists.

    Raises:
        FileNotFoundError: if neither candidate exists.
    """
    search_roots = (module_file, __file__)
    candidates = (Path(root).parent / "shaders" / file_name for root in search_roots)
    # The generator is lazy, so the second location is only checked when the
    # first one does not exist.
    match = next((candidate for candidate in candidates if candidate.exists()), None)
    if match is None:
        raise FileNotFoundError(f"Shader file {file_name} not found")
    return match
108
+
109
+
110
def read_shader_file(file_name, module_file) -> str:
    """Read a shader file and expand its ``#import`` lines recursively.

    The file is located with ``find_shader_file`` and decoded as UTF-8. If the
    text contains no ``#import`` it is returned unchanged. Otherwise every
    line that starts with ``#import`` is replaced by the expanded content of
    ``<name>.wgsl`` (``name`` is the second whitespace-separated token on the
    line), wrapped in ``// start file`` / ``// end file`` marker comments. All
    other lines are kept, each terminated with a newline.

    Raises:
        FileNotFoundError: if this file or an imported file cannot be found.
        ValueError: if an ``#import`` line does not name a file.
    """
    # Decode explicitly instead of relying on the platform's locale encoding.
    code = find_shader_file(file_name, module_file).read_text(encoding="utf-8")
    if "#import" not in code:
        return code

    parts = []
    for line in code.split("\n"):
        if not line.startswith("#import"):
            parts.append(line + "\n")
            continue

        tokens = line.split()
        if len(tokens) < 2:
            raise ValueError(f"Malformed #import directive in {file_name}: {line!r}")
        imported_file = tokens[1] + ".wgsl"
        parts.append(f"// start file {imported_file}\n")
        parts.append(read_shader_file(imported_file, module_file) + "\n")
        parts.append(f"// end file {imported_file}\n")
    return "".join(parts)
125
+
126
+
127
def encode_bytes(data: bytes) -> str:
    """Compress ``data`` with zlib and return it as base64 text.

    Empty input maps to the empty string rather than to an encoded empty
    zlib stream.
    """
    if data != b"":
        compressed = zlib.compress(data)
        return base64.b64encode(compressed).decode("utf-8")
    return ""
131
+
132
+
133
def decode_bytes(data: str) -> bytes:
    """Turn base64 text holding a zlib stream back into the raw bytes.

    The empty string maps to empty bytes.
    """
    if data != "":
        compressed = base64.b64decode(data.encode())
        return zlib.decompress(compressed)
    return b""
137
+
138
+
139
class BaseBinding:
    """Base class for any object that has a binding number (uniform, storage buffer, texture etc.)"""

    def __init__(self, nr, visibility=ShaderStage.ALL, resource=None, layout=None, binding=None):
        """Store the binding number, the shader-stage visibility and the entry data.

        Falsy ``resource``, ``layout`` and ``binding`` arguments are replaced
        by fresh empty dicts.
        """
        self.nr = nr
        self.visibility = visibility
        self._resource = resource if resource else {}
        self._layout_data = layout if layout else {}
        # Stored only; neither property of this class reads it.
        self._binding_data = binding if binding else {}

    @property
    def layout(self):
        """Layout entry data: binding number and visibility merged with the layout dict.

        Keys in the layout dict win over the two base keys.
        """
        base_entry = {"binding": self.nr, "visibility": self.visibility}
        return base_entry | self._layout_data

    @property
    def binding(self):
        """Bind group entry data: the binding number together with the resource."""
        return {"binding": self.nr, "resource": self._resource}
169
+
170
+
171
class UniformBinding(BaseBinding):
    """Binding for a buffer declared with layout type ``"uniform"``."""

    def __init__(self, nr, buffer, visibility=ShaderStage.ALL):
        uniform_layout = {"buffer": {"type": "uniform"}}
        buffer_resource = {"buffer": buffer}
        super().__init__(
            nr=nr,
            visibility=visibility,
            resource=buffer_resource,
            layout=uniform_layout,
        )
179
+
180
+
181
class StorageTextureBinding(BaseBinding):
    """Binding for a storage texture; the resource is a view created from ``texture``."""

    def __init__(
        self,
        nr,
        texture,
        visibility=ShaderStage.COMPUTE,
        dim=2,
        access="write-only",
    ):
        # ``dim`` is turned into the view dimension string, e.g. 2 -> "2d".
        storage_layout = {
            "access": access,
            "format": texture.format,
            "viewDimension": f"{dim}d",
        }
        view = texture.createView()
        super().__init__(
            nr=nr,
            visibility=visibility,
            layout={"storageTexture": storage_layout},
            resource=view,
        )
202
+
203
+
204
class TextureBinding(BaseBinding):
    """Binding for a sampled texture; the resource is a view created from ``texture``."""

    def __init__(
        self,
        nr,
        texture,
        visibility=ShaderStage.FRAGMENT,
        sample_type="float",
        dim=1,
        multisamples=False,
    ):
        """Build the texture layout entry and bind a fresh view of ``texture``.

        Args:
            nr: binding number.
            texture: object providing ``createView()``.
            visibility: shader stages that may access the binding.
            sample_type: value for the layout's ``sampleType``.
            dim: number used to form ``viewDimension`` (e.g. 1 -> ``"1d"``).
            multisamples: whether the texture binding is multisampled.
        """
        super().__init__(
            nr=nr,
            visibility=visibility,
            layout={
                "texture": {
                    "sampleType": sample_type,
                    "viewDimension": f"{dim}d",
                    # The WebGPU field is named "multisampled"; the former
                    # key "multisamples" is not part of the texture layout.
                    "multisampled": multisamples,
                }
            },
            resource=texture.createView(),
        )
226
+
227
+
228
class SamplerBinding(BaseBinding):
    """Binding for a sampler, declared with layout type ``"filtering"``."""

    def __init__(self, nr, sampler, visibility=ShaderStage.FRAGMENT):
        sampler_layout = {"sampler": {"type": "filtering"}}
        super().__init__(
            nr=nr,
            visibility=visibility,
            resource=sampler,
            layout=sampler_layout,
        )
236
+
237
+
238
class BufferBinding(BaseBinding):
    """Binding for a storage buffer, read-only by default."""

    def __init__(self, nr, buffer, read_only=True, visibility=ShaderStage.ALL):
        """Bind ``buffer`` as ``"read-only-storage"`` or ``"storage"``.

        For a writable buffer the ``visibility`` argument is ignored and the
        binding is made visible to the compute stage only.
        """
        if read_only:
            storage_type = "read-only-storage"
        else:
            storage_type = "storage"
            visibility = ShaderStage.COMPUTE

        super().__init__(
            nr=nr,
            visibility=visibility,
            resource={"buffer": buffer},
            layout={"buffer": {"type": storage_type}},
        )
249
+
250
+
251
def create_bind_group(device, bindings: list, label=""):
    """creates bind group layout and bind group from a list of BaseBinding objects

    Each binding contributes one ``BindGroupLayoutEntry`` (from its ``layout``)
    and one ``BindGroupEntry`` (from its ``binding``). The same ``label`` is
    used for the layout and the group.

    Returns:
        The tuple ``(layout, group)``.
    """
    # Build both entries per binding, in binding order.
    entry_pairs = [
        (BindGroupLayoutEntry(**item.layout), BindGroupEntry(**item.binding))
        for item in bindings
    ]
    layout_entries = [pair[0] for pair in entry_pairs]
    group_entries = [pair[1] for pair in entry_pairs]

    layout = device.createBindGroupLayout(entries=layout_entries, label=label)
    group = device.createBindGroup(label=label, layout=layout, entries=group_entries)
    return layout, group
266
+
267
+
268
class TimeQuery:
    """Holds a two-entry timestamp query set and a 16-byte readback buffer."""

    def __init__(self, device):
        self.device = device
        query_descriptor = to_js({"type": "timestamp", "count": 2})
        self.query_set = device.createQuerySet(query_descriptor)
        # 16 bytes for the two queries - presumably 8 bytes per timestamp.
        readback_usage = BufferUsage.COPY_DST | BufferUsage.MAP_READ
        self.buffer = device.createBuffer(size=16, usage=readback_usage)
276
+
277
+
278
def reload_package(package_name):
    """Reload python package and all submodules (searches in modules for references to other submodules)

    The package is imported and reloaded. Then every module object found in a
    reloaded module's namespace is reloaded too, provided its ``__file__``
    lies inside the package directory. Each module is reloaded at most once,
    also when modules reference each other.

    Args:
        package_name: importable name of the package.

    Returns:
        Dict mapping module file names to the reloaded module objects.

    Raises:
        ValueError: if the package has no ``__file__``, so no package
            directory can be derived.
    """
    import importlib
    import os
    import types

    package = importlib.import_module(package_name)
    file_name = getattr(package, "__file__", None)
    if file_name is None:
        raise ValueError(f"Cannot reload {package_name!r}: the module has no __file__")
    package_dir = os.path.dirname(file_name) + os.sep
    reloaded_modules = {file_name: package}

    def reload_recursive(module):
        module = importlib.reload(module)

        # Iterate over a snapshot: reloading a child may add names to this
        # module's namespace while we are looping.
        for var in list(vars(module).values()):
            if not isinstance(var, types.ModuleType):
                continue
            child_file = getattr(var, "__file__", None)
            if child_file is None or not child_file.startswith(package_dir):
                continue
            if child_file in reloaded_modules:
                continue
            # Register before recursing so that modules referencing each
            # other do not trigger endless recursion.
            reloaded_modules[child_file] = var
            reloaded_modules[child_file] = reload_recursive(var)

        return module

    reload_recursive(package)
    return reloaded_modules
304
+
305
+
306
def run_compute_shader(encoder, code, bindings, n_workgroups, label="compute", entry_point="main"):
    """Record one compute pass that runs ``code`` with the given bindings.

    A shader module, bind group and compute pipeline are created on the
    module-wide device for every call. The pass is recorded on ``encoder``
    only; nothing is submitted here.

    Args:
        encoder: command encoder providing ``beginComputePass()``.
        code: shader source handed to ``createShaderModule``.
        bindings: binding objects accepted by ``create_bind_group``; they are
            bound as group 0.
        n_workgroups: iterable of workgroup counts, unpacked into
            ``dispatchWorkgroups``.
        label: label for the bind group, pipeline layout and pipeline.
        entry_point: name of the shader entry point.

    Raises:
        RuntimeError: from ``get_device`` if no device was initialized.
    """
    # create_bind_group and get_device are defined in this module; importing
    # them again through the absolute package name could load a second copy
    # of the module that has its own, uninitialized device.
    device = get_device()

    shader_module = device.createShaderModule(code)

    layout, bind_group = create_bind_group(device, bindings, label)
    pipeline = device.createComputePipeline(
        device.createPipelineLayout([layout], label),
        ComputeState(
            shader_module,
            entry_point,
        ),
        label,
    )

    pass_encoder = encoder.beginComputePass()
    pass_encoder.setPipeline(pipeline)
    pass_encoder.setBindGroup(0, bind_group)
    pass_encoder.dispatchWorkgroups(*n_workgroups)
    pass_encoder.end()
328
+
329
+
330
def buffer_from_array(array, usage=BufferUsage.STORAGE | BufferUsage.COPY_DST):
    """Create a device buffer sized for ``array`` and upload the array's bytes.

    The buffer size is ``array.size * array.itemsize`` and the content written
    at offset 0 is ``array.tobytes()``.

    Returns:
        The newly created buffer.
    """
    gpu = get_device()
    byte_count = array.size * array.itemsize
    result = gpu.createBuffer(byte_count, usage=usage)
    payload = array.tobytes()
    gpu.queue.writeBuffer(result, 0, payload)
    return result
335
+
336
+
337
def uniform_from_array(array):
    """Upload ``array`` into a new buffer with UNIFORM | COPY_DST usage."""
    uniform_usage = BufferUsage.UNIFORM | BufferUsage.COPY_DST
    return buffer_from_array(array, usage=uniform_usage)
339
+
340
+
341
class ReadBuffer:
    """Copies a buffer into a mappable staging buffer so it can be read back."""

    def __init__(self, buffer, encoder):
        """Create the staging buffer and record the copy on ``encoder``.

        The staging buffer has the same size as ``buffer`` and usage
        MAP_READ | COPY_DST; the whole of ``buffer`` is copied into it.
        """
        self.buffer = buffer
        staging_usage = BufferUsage.MAP_READ | BufferUsage.COPY_DST
        self.read_buffer = get_device().createBuffer(buffer.size, staging_usage)
        encoder.copyBufferToBuffer(self.buffer, 0, self.read_buffer, 0, buffer.size)

    def get_array(self, dtype):
        """Map the staging buffer and return its content as a numpy array of ``dtype``.

        NOTE(review): the result of mapAsync is not awaited, and the array is
        built before unmap() is called. Whether the returned array stays valid
        after unmapping depends on what getMappedRange returns - confirm.
        """
        import numpy as np

        staging = self.read_buffer
        staging.handle.mapAsync(MapMode.READ, 0, staging.size)
        mapped = staging.handle.getMappedRange(0, staging.size)
        values = np.frombuffer(mapped, dtype=dtype)
        staging.unmap()
        return values
357
+
358
+
359
def max_bounding_box(boxes):
    """Return the smallest box that encloses all given boxes.

    Args:
        boxes: iterable of ``(pmin, pmax)`` pairs of coordinate sequences;
            ``None`` entries are skipped.

    Returns:
        ``(pmin, pmax)`` as numpy arrays holding the element-wise minimum of
        all lower corners and the element-wise maximum of all upper corners,
        or ``None`` when there is no box to combine (empty input or only
        ``None`` entries).
    """
    import numpy as np

    valid_boxes = [b for b in boxes if b is not None]
    if not valid_boxes:
        # Nothing to enclose; previously this raised an IndexError.
        return None

    pmin = np.array(valid_boxes[0][0])
    pmax = np.array(valid_boxes[0][1])
    for box in valid_boxes[1:]:
        pmin = np.minimum(pmin, np.array(box[0]))
        pmax = np.maximum(pmax, np.array(box[1]))
    return (pmin, pmax)
369
+
370
+
371
def format_number(n):
    """Format a number compactly for display.

    Zero becomes ``"0"``. Values whose magnitude is below 1e-2 or at least 1e3
    use scientific notation with two decimals (``.2e``); everything else is
    shown with three significant digits (``.3g``).
    """
    if n == 0:
        return "0"
    magnitude = abs(n)
    # Scientific notation outside the range [0.01, 1000).
    spec = ".2e" if (magnitude < 1e-2 or magnitude >= 1e3) else ".3g"
    return format(n, spec)
webgpu/vectors.py ADDED
@@ -0,0 +1,101 @@
1
+ import numpy as np
2
+
3
+ from . import BufferBinding, Colormap, RenderObject, read_shader_file
4
+ from .uniforms import UniformBase, ct
5
+ from .utils import buffer_from_array
6
+ from .webgpu_api import PrimitiveTopology
7
+
8
+
9
class Binding:
    """Binding numbers used by the vector render objects in this module."""

    # Storage buffer with the point coordinates.
    POINTS = 81
    # Storage buffer with the vector components.
    VECTORS = 82
    # Uniform block described by VectorUniform.
    OPTIONS = 83
13
+
14
+
15
class VectorUniform(UniformBase):
    """Uniform block with the vector rendering options, bound at ``Binding.OPTIONS``."""

    _binding = Binding.OPTIONS
    # Four 4-byte members: the size, the scale-with-length flag and two
    # padding floats (presumably to fill a 16-byte block - confirm).
    _fields_ = [
        ("size", ct.c_float),
        ("scale_veclen", ct.c_uint32),
        ("_padding2", ct.c_float),
        ("_padding3", ct.c_float),
    ]
23
+
24
+
25
class BaseVectorRenderObject(RenderObject):
    """Shared logic for render objects that draw vectors colored by a colormap.

    Drawn as a triangle strip with 10 vertices. Subclasses are expected to
    provide ``self._buffers`` (keys ``"points"`` and ``"vectors"``) and
    ``self.vec_uniforms`` before ``get_bindings`` is called.
    """

    topology = PrimitiveTopology.triangle_strip
    n_vertices = 10

    def __init__(self, label="VectorField"):
        super().__init__(label=label)
        self.colormap = Colormap()

    def update(self, timestamp):
        """Pass the options to the colormap and update it, once per timestamp."""
        if timestamp != self._timestamp:
            self._timestamp = timestamp

            colormap = self.colormap
            colormap.options = self.options
            colormap.update(timestamp)

    def get_bindings(self):
        """Collect camera, light, point/vector buffer, uniform and colormap bindings."""
        bindings = []
        bindings.extend(self.options.camera.get_bindings())
        bindings.extend(self.options.light.get_bindings())
        bindings.append(BufferBinding(Binding.POINTS, self._buffers["points"]))
        bindings.append(BufferBinding(Binding.VECTORS, self._buffers["vectors"]))
        bindings.extend(self.vec_uniforms.get_bindings())
        bindings.extend(self.colormap.get_bindings())
        return bindings

    def create_vector_data(self):
        """Not implemented in this base class."""
        raise NotImplementedError

    def get_shader_code(self):
        """Return ``vector.wgsl`` followed by the camera, light and colormap shader code."""
        sections = [
            read_shader_file("vector.wgsl", __file__),
            self.options.camera.get_shader_code(),
            self.options.light.get_shader_code(),
            self.colormap.get_shader_code(),
        ]
        return "".join(sections)

    def render(self, encoder):
        """Delegate rendering to the parent class."""
        super().render(encoder)
63
+
64
+
65
class VectorRenderer(BaseVectorRenderObject):
    """Renders one vector per point from flat float32 point and vector arrays."""

    def __init__(self, points, vectors, size=None, scale_with_vector_length=False):
        """Store the data and derive the bounding box and the default size.

        Args:
            points: point coordinates, three values per point; flattened to
                a float32 array.
            vectors: vector components, three values per vector; flattened to
                a float32 array.
            size: size value written to the uniform block. When falsy, one
                tenth of the bounding box diagonal is used.
            scale_with_vector_length: flag written to the ``scale_veclen``
                uniform.
        """
        super().__init__(label="VectorField")
        self.scale_with_vector_length = scale_with_vector_length
        self.points = np.asarray(points, dtype=np.float32).reshape(-1)
        self.vectors = np.asarray(vectors, dtype=np.float32).reshape(-1)
        # View the points as (N, 3) once and reuse it for both corners.
        xyz = self.points.reshape(-1, 3)
        self.bounding_box = xyz.min(axis=0), xyz.max(axis=0)
        self.size = size or 1 / 10 * np.linalg.norm(self.bounding_box[1] - self.bounding_box[0])

    def update(self, timestamp):
        """Upload buffers and uniforms and rebuild the pipeline, once per timestamp."""
        if timestamp == self._timestamp:
            return

        super().update(timestamp)

        self._buffers = {
            "points": buffer_from_array(self.points),
            "vectors": buffer_from_array(self.vectors),
        }
        self.vec_uniforms = VectorUniform(self.device)
        self.vec_uniforms.size = self.size
        self.vec_uniforms.scale_veclen = self.scale_with_vector_length
        self.vec_uniforms.update_buffer()

        # Compute the vector lengths once and take both extremes from them.
        lengths = np.linalg.norm(self.vectors.reshape(-1, 3), axis=1)
        min_vec, max_vec = lengths.min(), lengths.max()
        if self.colormap.autoupdate:
            self.colormap.set_min_max(min_vec, max_vec, set_autoupdate=False)
            # A different timestamp is passed so the colormap updates again
            # after its range has changed.
            self.colormap.update(timestamp + 1)
        self.n_instances = len(self.points) // 3
        self.create_render_pipeline()

    def get_bounding_box(self):
        """Return the ``(min, max)`` corner arrays computed in the constructor."""
        return self.bounding_box