webgpu 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
webgpu/utils.py ADDED
@@ -0,0 +1,385 @@
1
+ import base64
2
+ import zlib
3
+ from pathlib import Path
4
+
5
+ from .webgpu_api import *
6
+ from .webgpu_api import toJS as to_js
7
+ from . import platform
8
+
9
# Module-level cache for the WebGPU device: assigned by init_device() /
# init_device_sync() and returned by get_device(), which raises while it is None.
_device: Device = None
10
+
11
+
12
def init_device_sync():
    """Synchronously create (or return the cached) WebGPU device.

    Talks to the JavaScript side through ``platform.js``: requests a
    low-power adapter, then a device whose ``maxBufferSize`` and
    ``maxStorageBufferBindingSize`` limits are 1 GiB minus 16 bytes. The
    result is cached in the module-level ``_device``.

    If ``navigator.gpu`` or the adapter is falsy, a "WebGPU is not supported"
    alert is shown and the process exits with status 1.

    Returns:
        The cached Device.
    """
    # ``sys`` is not imported explicitly at module level (at most it could leak
    # in through the star import), so import it here to make the exit paths work.
    import sys

    global _device
    if _device is not None:
        return _device

    if not platform.js.navigator.gpu:
        platform.js.alert("WebGPU is not supported")
        sys.exit(1)

    reqAdapter = platform.js.navigator.gpu.requestAdapter
    options = RequestAdapterOptions(
        powerPreference=PowerPreference.low_power,
    ).toJS()
    adapter = reqAdapter(options)
    if not adapter:
        platform.js.alert("WebGPU is not supported")
        sys.exit(1)

    one_gig = 1024**3
    # NOTE(review): positional arguments presumably follow GPUDeviceDescriptor
    # order (requiredFeatures, requiredLimits, defaultQueue, label) — confirm.
    _device = Device(
        adapter.requestDevice(
            [],
            Limits(
                maxBufferSize=one_gig - 16,
                maxStorageBufferBindingSize=one_gig - 16,
            ),
            None,
            "WebGPU device",
        )
    )
    return _device
42
+
43
+
44
async def init_device() -> Device:
    """Asynchronously create (or return the cached) WebGPU device.

    Requests a low-power adapter and enables the ``timestamp-query`` feature
    when the adapter reports it. It then requests a device whose
    ``maxBufferSize`` and ``maxStorageBufferBindingSize`` limits are 1 GiB
    minus 16 bytes. The device is cached module-wide, so later calls and
    get_device() return the same object. Device limits and adapter info are
    logged to the JS console and stdout.

    Returns:
        The initialized Device.
    """
    global _device

    if _device is not None:
        return _device

    adapter = requestAdapter(powerPreference=PowerPreference.low_power)
    # Best-effort await: if awaiting fails, the returned object is used as-is.
    # ``except Exception`` (not a bare ``except``) so asyncio.CancelledError and
    # KeyboardInterrupt still propagate.
    # NOTE(review): presumably some backends return a non-awaitable here — confirm.
    try:
        adapter = await adapter
    except Exception:
        pass

    required_features = []
    if "timestamp-query" in adapter.features:
        print("have timestamp query")
        required_features.append("timestamp-query")
    else:
        print("no timestamp query")

    one_meg = 1024**2
    one_gig = 1024**3
    device = adapter.requestDevice(
        label="WebGPU device",
        # Request the optional features detected above; otherwise the device
        # never has "timestamp-query" enabled (needed by timestamp query sets).
        requiredFeatures=required_features,
        requiredLimits=Limits(
            maxBufferSize=one_gig - 16,
            maxStorageBufferBindingSize=one_gig - 16,
        ),
    )
    try:
        device = await device
    except Exception:
        pass
    # Publish the device globally only after the await above, so concurrent
    # callers of get_device() never observe the pending request object.
    _device = device

    limits = _device.limits
    platform.js.console.log("device limits\n", limits)
    platform.js.console.log("adapter info\n", adapter.info)

    print(
        f"max storage buffer binding size {limits.maxStorageBufferBindingSize / one_meg:.2f} MB"
    )
    print(f"max buffer size {limits.maxBufferSize / one_meg:.2f} MB")

    return _device
87
+
88
+
89
def get_device() -> Device:
    """Return the device created by init_device() or init_device_sync().

    Raises:
        RuntimeError: if no device has been initialized yet.
    """
    device = _device
    if device is not None:
        return device
    raise RuntimeError("Device not initialized")
93
+
94
+
95
class Pyodide:
    """Inert placeholder object: every attribute assignment is silently dropped."""

    def __init__(self):
        return None

    def __setattr__(self, key, value):
        # Intentionally a no-op, so nothing ever lands in the instance __dict__.
        return None
101
+
102
+
103
def find_shader_file(file_name, module_file) -> Path:
    """Locate ``shaders/<file_name>`` next to *module_file*, falling back to this module.

    The directory of *module_file* is checked first, then the directory of
    this file.

    Raises:
        FileNotFoundError: if neither location contains the shader.
    """
    candidates = (
        Path(origin).parent / "shaders" / file_name for origin in (module_file, __file__)
    )
    found = next((candidate for candidate in candidates if candidate.exists()), None)
    if found is None:
        raise FileNotFoundError(f"Shader file {file_name} not found")
    return found
110
+
111
+
112
def read_shader_file(file_name, module_file) -> str:
    """Read a WGSL shader file and inline its ``#import <name>`` directives.

    Lines starting with ``#import`` are replaced by the contents of
    ``<name>.wgsl`` (resolved via find_shader_file, recursively). The inlined
    contents are wrapped in ``// start file`` / ``// end file`` comments. A
    file without any ``#import`` is returned unchanged. Otherwise every output
    line is newline-terminated.

    Raises:
        FileNotFoundError: if a shader file cannot be located.
        ValueError: on a cyclic import chain or an ``#import`` without a name.
    """
    return _read_shader_file(file_name, module_file, ())


def _read_shader_file(file_name, module_file, active) -> str:
    """Recursive worker for read_shader_file; *active* is the current import chain."""
    if file_name in active:
        chain = " -> ".join((*active, file_name))
        raise ValueError(f"Circular shader #import: {chain}")

    code = find_shader_file(file_name, module_file).read_text()
    if "#import" not in code:
        return code

    active = (*active, file_name)
    parts = []
    for line in code.split("\n"):
        if line.startswith("#import"):
            tokens = line.split()
            if len(tokens) < 2:
                raise ValueError(
                    f"Malformed #import directive in {file_name}: {line!r}"
                )
            imported_file = tokens[1] + ".wgsl"
            parts.append(f"// start file {imported_file}\n")
            parts.append(_read_shader_file(imported_file, module_file, active) + "\n")
            parts.append(f"// end file {imported_file}\n")
        else:
            parts.append(line + "\n")
    return "".join(parts)
127
+
128
+
129
def encode_bytes(data: bytes) -> str:
    """zlib-compress *data* and return it as base64 text ("" for empty input)."""
    if data == b"":
        return ""
    packed = zlib.compress(data)
    return str(base64.b64encode(packed), "utf-8")
133
+
134
+
135
def decode_bytes(data: str) -> bytes:
    """Inverse of encode_bytes: base64-decode, then zlib-decompress ("" gives b"")."""
    if data == "":
        return b""
    raw = base64.b64decode(data.encode())
    return zlib.decompress(raw)
139
+
140
+
141
class BaseBinding:
    """Base class for any object that has a binding number (uniform, storage buffer, texture etc.)

    Args:
        nr: binding slot number.
        visibility: shader-stage mask (default ``ShaderStage.ALL``).
        resource: object placed under ``"resource"`` in the bind-group entry.
        layout: extra keys merged into the bind-group-layout entry.
        binding: extra keys merged into the bind-group entry.
    """

    def __init__(
        self,
        nr,
        visibility=ShaderStage.ALL,
        resource=None,
        layout=None,
        binding=None,
    ):
        self.nr = nr
        self.visibility = visibility
        self._layout_data = layout or {}
        self._binding_data = binding or {}
        self._resource = resource or {}

    @property
    def layout(self):
        """Keyword arguments for a BindGroupLayoutEntry (binding, visibility + extras)."""
        return {
            "binding": self.nr,
            "visibility": self.visibility,
        } | self._layout_data

    @property
    def binding(self):
        """Keyword arguments for a BindGroupEntry (binding, resource + extras)."""
        # Merge the extra ``binding=`` data, mirroring how ``layout`` merges
        # ``_layout_data``; previously the constructor argument was ignored.
        return {
            "binding": self.nr,
            "resource": self._resource,
        } | self._binding_data
171
+
172
+
173
class UniformBinding(BaseBinding):
    """Bind *buffer* as a uniform buffer at slot *nr*."""

    def __init__(self, nr, buffer, visibility=ShaderStage.ALL):
        uniform_layout = {"buffer": {"type": "uniform"}}
        super().__init__(
            resource={"buffer": buffer},
            layout=uniform_layout,
            visibility=visibility,
            nr=nr,
        )
181
+
182
+
183
class StorageTextureBinding(BaseBinding):
    """Storage-texture binding; compute-visible, write-only, 2D by default.

    The layout uses the texture's own format, and the bound resource is a
    fresh view created from *texture*.
    """

    def __init__(
        self,
        nr,
        texture,
        visibility=ShaderStage.COMPUTE,
        dim=2,
        access="write-only",
    ):
        storage_layout = {
            "access": access,
            "format": texture.format,
            "viewDimension": f"{dim}d",
        }
        view = texture.createView()
        super().__init__(
            nr=nr,
            visibility=visibility,
            layout={"storageTexture": storage_layout},
            resource=view,
        )
204
+
205
+
206
class TextureBinding(BaseBinding):
    """Sampled-texture binding (fragment-visible, float samples by default).

    Args:
        nr: binding slot number.
        texture: texture whose newly created view is bound.
        visibility: shader-stage mask.
        sample_type: WebGPU texture sample type, e.g. ``"float"``.
        dim: view dimensionality, emitted as ``"<dim>d"``.
        multisamples: whether the texture is multisampled.
    """

    def __init__(
        self,
        nr,
        texture,
        visibility=ShaderStage.FRAGMENT,
        sample_type="float",
        dim=1,
        multisamples=False,
    ):
        super().__init__(
            nr=nr,
            visibility=visibility,
            layout={
                "texture": {
                    "sampleType": sample_type,
                    "viewDimension": f"{dim}d",
                    # The GPUTextureBindingLayout member is spelled
                    # "multisampled"; the Python parameter keeps its old name.
                    "multisampled": multisamples,
                }
            },
            resource=texture.createView(),
        )
228
+
229
+
230
class SamplerBinding(BaseBinding):
    """Bind *sampler* at slot *nr* with a filtering sampler layout."""

    def __init__(self, nr, sampler, visibility=ShaderStage.FRAGMENT):
        sampler_layout = {"sampler": {"type": "filtering"}}
        super().__init__(
            resource=sampler,
            layout=sampler_layout,
            visibility=visibility,
            nr=nr,
        )
238
+
239
+
240
class BufferBinding(BaseBinding):
    """Storage-buffer binding; read-only unless *read_only* is false."""

    def __init__(self, nr, buffer, read_only=True, visibility=ShaderStage.ALL):
        if read_only:
            buffer_type = "read-only-storage"
        else:
            buffer_type = "storage"
            # Writable storage buffers are forced to compute-only visibility,
            # overriding whatever the caller passed.
            visibility = ShaderStage.COMPUTE
        super().__init__(
            resource={"buffer": buffer},
            layout={"buffer": {"type": buffer_type}},
            visibility=visibility,
            nr=nr,
        )
251
+
252
+
253
def create_bind_group(device, bindings: list, label=""):
    """Create a bind-group layout and a matching bind group.

    Args:
        device: device used to create both objects.
        bindings: BaseBinding-like objects exposing ``layout`` and ``binding`` dicts.
        label: label given to both the layout and the group.

    Returns:
        ``(layout, group)``.
    """
    layout_entries, group_entries = [], []
    for item in bindings:
        layout_entries.append(BindGroupLayoutEntry(**item.layout))
        group_entries.append(BindGroupEntry(**item.binding))

    group_layout = device.createBindGroupLayout(entries=layout_entries, label=label)
    group = device.createBindGroup(
        label=label,
        layout=group_layout,
        entries=group_entries,
    )
    return group_layout, group
268
+
269
+
270
class TimeQuery:
    """Timestamp query set with two entries plus a 16-byte COPY_DST|MAP_READ buffer.

    NOTE(review): the 16 bytes presumably hold two 64-bit timestamps, and the
    device presumably needs the "timestamp-query" feature — confirm.
    """

    def __init__(self, device):
        self.device = device
        query_descriptor = to_js({"type": "timestamp", "count": 2})
        self.query_set = self.device.createQuerySet(query_descriptor)
        readback_usage = BufferUsage.COPY_DST | BufferUsage.MAP_READ
        self.buffer = self.device.createBuffer(size=16, usage=readback_usage)
280
+
281
+
282
def reload_package(package_name):
    """Reload python package and all submodules (searches in modules for references to other submodules)

    Starting from the package itself, every module object found among a
    reloaded module's globals whose ``__file__`` lies inside the package
    directory is reloaded too. The walk is depth-first and reloads each file
    at most once.

    Args:
        package_name: importable dotted name of the package.

    Returns:
        dict mapping each reloaded module's file path to the module object.

    Raises:
        ValueError: if the package has no ``__file__`` (e.g. a namespace package).
    """
    import importlib
    import os
    import types

    package = importlib.import_module(package_name)
    root_file = getattr(package, "__file__", None)
    if root_file is None:
        raise ValueError(
            f"{package_name!r} has no __file__; cannot determine its directory"
        )
    package_dir = os.path.dirname(root_file) + os.sep
    reloaded_modules = {root_file: package}

    def reload_recursive(module):
        module = importlib.reload(module)

        # Snapshot the namespace: reloading children may mutate it.
        for var in list(vars(module).values()):
            if not isinstance(var, types.ModuleType):
                continue
            file_name = getattr(var, "__file__", None)
            if file_name is None or not file_name.startswith(package_dir):
                continue
            if file_name in reloaded_modules:
                continue
            # Mark before recursing so mutually-importing submodules
            # (a imports b, b imports a) do not recurse forever.
            reloaded_modules[file_name] = var
            reloaded_modules[file_name] = reload_recursive(var)

        return module

    reload_recursive(package)
    return reloaded_modules
308
+
309
+
310
def run_compute_shader(
    encoder, code, bindings, n_workgroups, label="compute", entry_point="main"
):
    """Record a single compute dispatch of *code* into *encoder*.

    Compiles *code*, builds a bind group (group 0) from *bindings*, creates a
    compute pipeline using *entry_point*, and records one compute pass that
    dispatches *n_workgroups*.

    Args:
        encoder: command encoder the compute pass is recorded into.
        code: WGSL source of the compute shader.
        bindings: BaseBinding objects for bind group 0.
        n_workgroups: workgroup counts ``(x[, y[, z]])``; a single integer is
            treated as ``(n,)``.
        label: label for the bind group, pipeline layout and pipeline.
        entry_point: shader entry point name.
    """
    import numbers

    # create_bind_group and get_device live in this module, so no self-import
    # (the previous ``from webgpu.utils import ...``) is needed.
    device = get_device()

    shader_module = device.createShaderModule(code)

    layout, bind_group = create_bind_group(device, bindings, label)
    pipeline = device.createComputePipeline(
        device.createPipelineLayout([layout], label),
        ComputeState(
            shader_module,
            entry_point,
        ),
        label,
    )

    if isinstance(n_workgroups, numbers.Integral):
        n_workgroups = (n_workgroups,)

    pass_encoder = encoder.beginComputePass()
    pass_encoder.setPipeline(pipeline)
    pass_encoder.setBindGroup(0, bind_group)
    pass_encoder.dispatchWorkgroups(*n_workgroups)
    pass_encoder.end()
334
+
335
+
336
def buffer_from_array(array, usage=BufferUsage.STORAGE | BufferUsage.COPY_DST):
    """Create a GPU buffer holding the raw bytes of *array* and upload them.

    Args:
        array: numpy-like array (uses ``size``, ``itemsize`` and ``tobytes()``).
        usage: buffer usage flags (default STORAGE | COPY_DST).

    Returns:
        The new buffer. Its size is the array's byte size rounded up to a
        multiple of 4; any padding bytes are zero.
    """
    device = get_device()
    nbytes = array.size * array.itemsize
    data = array.tobytes()
    # GPUQueue.writeBuffer requires the written byte count to be a multiple of
    # 4, so zero-pad odd-sized payloads (e.g. uint8 or float16 arrays).
    data += bytes(-len(data) % 4)
    buffer = device.createBuffer(nbytes + (-nbytes % 4), usage=usage)
    device.queue.writeBuffer(buffer, 0, data)
    return buffer
341
+
342
+
343
def uniform_from_array(array):
    """Upload *array* into a new buffer usable as a uniform (UNIFORM | COPY_DST)."""
    uniform_usage = BufferUsage.UNIFORM | BufferUsage.COPY_DST
    return buffer_from_array(array, usage=uniform_usage)
345
+
346
+
347
class ReadBuffer:
    """Stage a GPU buffer for CPU readback.

    The constructor allocates a MAP_READ | COPY_DST buffer of the same size
    and records a full copy of *buffer* into it on *encoder*.
    NOTE(review): the encoder presumably has to be submitted before
    get_array() is called — confirm with callers.
    """

    def __init__(self, buffer, encoder):
        self.buffer = buffer
        self.read_buffer = get_device().createBuffer(
            buffer.size, BufferUsage.MAP_READ | BufferUsage.COPY_DST
        )
        encoder.copyBufferToBuffer(self.buffer, 0, self.read_buffer, 0, buffer.size)

    def get_array(self, dtype):
        """Map the staging buffer and return its contents as a numpy array of *dtype*."""
        import numpy as np

        # NOTE(review): mapAsync's result is not awaited; this presumably relies
        # on the platform bridge resolving it synchronously — confirm.
        self.read_buffer.handle.mapAsync(MapMode.READ, 0, self.read_buffer.size)
        data = self.read_buffer.handle.getMappedRange(0, self.read_buffer.size)
        # Copy before unmapping: unmap() detaches the mapped range, so the
        # returned array must not alias that memory.
        res = np.frombuffer(data, dtype=dtype).copy()
        self.read_buffer.unmap()
        return res
363
+
364
+
365
def max_bounding_box(boxes):
    """Return the axis-aligned box enclosing all given boxes.

    Args:
        boxes: iterable of ``(pmin, pmax)`` pairs of array-likes; ``None``
            entries are skipped.

    Returns:
        ``(pmin, pmax)`` as numpy arrays (element-wise min of the minima and
        max of the maxima), or ``None`` when no non-None box is given.
    """
    import numpy as np

    boxes = [b for b in boxes if b is not None]
    if not boxes:
        return None
    pmin = np.array(boxes[0][0])
    pmax = np.array(boxes[0][1])
    for b in boxes[1:]:
        pmin = np.minimum(pmin, np.array(b[0]))
        pmax = np.maximum(pmax, np.array(b[1]))
    return (pmin, pmax)
375
+
376
+
377
def format_number(n):
    """Format *n* compactly for display.

    Zero becomes ``"0"``. Magnitudes below 0.01 or at least 1000 use
    scientific notation with two decimals (``"1.23e+03"``). Anything else
    uses three significant digits (``"12.3"``).
    """
    if n == 0:
        return "0"
    abs_n = abs(n)
    # Use scientific notation for magnitudes smaller than 0.01 or >= 1000
    if abs_n < 1e-2 or abs_n >= 1e3:
        return f"{n:.2e}"
    return f"{n:.3g}"
webgpu/vectors.py ADDED
@@ -0,0 +1,103 @@
1
+ import numpy as np
2
+
3
+ from . import BufferBinding, Colormap, RenderObject, read_shader_file
4
+ from .uniforms import UniformBase, ct
5
+ from .utils import buffer_from_array
6
+ from .webgpu_api import PrimitiveTopology
7
+
8
+
9
class Binding:
    """Binding slot numbers for the vector renderer's shader resources."""

    # Storage buffers for points and vectors, then the VectorUniform options.
    POINTS, VECTORS, OPTIONS = 81, 82, 83
13
+
14
+
15
class VectorUniform(UniformBase):
    """Uniform options for the vector renderer, bound at Binding.OPTIONS."""

    _binding = Binding.OPTIONS
    # NOTE(review): presumably ``ct`` is ctypes and UniformBase uploads these
    # fields in declaration order, so the order must match the shader-side
    # struct — confirm against vector.wgsl before reordering anything.
    _fields_ = [
        ("size", ct.c_float),  # arrow size; set from VectorRenderer.size
        ("scale_veclen", ct.c_uint32),  # set from VectorRenderer.scale_with_vector_length
        ("_padding2", ct.c_float),  # padding, presumably to a 16-byte struct
        ("_padding3", ct.c_float),
    ]
23
+
24
+
25
class BaseVectorRenderObject(RenderObject):
    """Shared rendering logic for instanced vector (arrow) glyphs.

    Subclasses must set ``self._buffers`` (keys ``"points"`` and
    ``"vectors"``) and ``self.vec_uniforms`` before get_bindings() is called.
    """

    topology = PrimitiveTopology.triangle_strip
    n_vertices = 10

    def __init__(self, label="VectorField"):
        super().__init__(label=label)
        self.colormap = Colormap()

    def update(self, timestamp):
        """Sync the colormap with this object's options once per timestamp."""
        if timestamp != self._timestamp:
            self._timestamp = timestamp
            self.colormap.options = self.options
            self.colormap.update(timestamp)

    def get_bindings(self):
        """Collect camera, light, vector-data, uniform and colormap bindings, in that order."""
        bindings = list(self.options.camera.get_bindings())
        bindings += self.options.light.get_bindings()
        bindings.append(BufferBinding(Binding.POINTS, self._buffers["points"]))
        bindings.append(BufferBinding(Binding.VECTORS, self._buffers["vectors"]))
        bindings += self.vec_uniforms.get_bindings()
        bindings += self.colormap.get_bindings()
        return bindings

    def create_vector_data(self):
        raise NotImplementedError

    def get_shader_code(self):
        """Concatenate vector.wgsl with the camera, light and colormap shader code."""
        return "".join(
            [
                read_shader_file("vector.wgsl", __file__),
                self.options.camera.get_shader_code(),
                self.options.light.get_shader_code(),
                self.colormap.get_shader_code(),
            ]
        )

    def render(self, encoder):
        super().render(encoder)
63
+
64
+
65
class VectorRenderer(BaseVectorRenderObject):
    """Render one arrow per (point, vector) pair.

    Args:
        points: array-like, flattened to float32 and read as xyz triples.
        vectors: array-like, flattened to float32; one xyz vector per point.
        size: arrow size; if falsy, 1/10 of the bounding-box diagonal is used.
        scale_with_vector_length: stored in the ``scale_veclen`` uniform field.

    Raises:
        ValueError: if *vectors* and *points* hold different numbers of values,
            or if the points cannot be reshaped into xyz triples.
    """

    def __init__(self, points, vectors, size=None, scale_with_vector_length=False):
        super().__init__(label="VectorField")
        self.scale_with_vector_length = scale_with_vector_length
        self.points = np.asarray(points, dtype=np.float32).reshape(-1)
        self.vectors = np.asarray(vectors, dtype=np.float32).reshape(-1)
        if self.vectors.size != self.points.size:
            raise ValueError(
                "vectors must provide one xyz vector per point "
                f"(got {self.vectors.size} vector values for "
                f"{self.points.size} point coordinates)"
            )
        xyz = self.points.reshape(-1, 3)
        self.bounding_box = xyz.min(axis=0), xyz.max(axis=0)
        self.size = size or 1 / 10 * np.linalg.norm(
            self.bounding_box[1] - self.bounding_box[0]
        )

    def update(self, timestamp):
        """Upload points/vectors, refresh uniforms and colormap range, rebuild the pipeline."""
        if timestamp == self._timestamp:
            return

        super().update(timestamp)

        self._buffers = {
            "points": buffer_from_array(self.points),
            "vectors": buffer_from_array(self.vectors),
        }
        self.vec_uniforms = VectorUniform(self.device)
        self.vec_uniforms.size = self.size
        self.vec_uniforms.scale_veclen = self.scale_with_vector_length
        self.vec_uniforms.update_buffer()

        # Compute the vector lengths once and reuse them for both extremes.
        lengths = np.linalg.norm(self.vectors.reshape(-1, 3), axis=1)
        min_vec, max_vec = lengths.min(), lengths.max()
        if self.colormap.autoupdate:
            self.colormap.set_min_max(min_vec, max_vec, set_autoupdate=False)
            # Force the colormap to refresh with the new range: the base update
            # already ran colormap.update(timestamp).
            self.colormap.update(timestamp + 1)
        self.n_instances = len(self.points) // 3
        self.create_render_pipeline()

    def get_bounding_box(self):
        return self.bounding_box