webgpu 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
webgpu/colormap.py ADDED
@@ -0,0 +1,325 @@
1
+ from .labels import Labels
2
+ from .render_object import MultipleRenderObject, RenderObject
3
+ from .uniforms import Binding, UniformBase, ct
4
+ from .utils import SamplerBinding, TextureBinding, format_number, read_shader_file
5
+ from .webgpu_api import (
6
+ TexelCopyBufferLayout,
7
+ TexelCopyTextureInfo,
8
+ Texture,
9
+ TextureFormat,
10
+ TextureUsage,
11
+ )
12
+
13
+
14
class ColormapUniforms(UniformBase):
    """Uniform values describing the colormap range and the colorbar geometry.

    The structure is bound at ``Binding.COLORMAP``.  The fields are filled in
    by ``Colorbar.update`` from the attributes of the same meaning.

    NOTE(review): the field order and C types presumably have to match the
    corresponding struct in ``colormap.wgsl`` -- confirm against the shader
    before reordering anything.
    """

    _binding = Binding.COLORMAP
    _fields_ = [
        # value range of the colormap (Colorbar.minval / Colorbar.maxval)
        ("min", ct.c_float),
        ("max", ct.c_float),
        # position of the colorbar (Colorbar.position_x / Colorbar.position_y)
        ("position_x", ct.c_float),
        ("position_y", ct.c_float),
        # flag and color count copied from Colorbar.discrete / Colorbar.n_colors
        ("discrete", ct.c_uint32),
        ("n_colors", ct.c_uint32),
        # extent of the colorbar (Colorbar.width / Colorbar.height)
        ("width", ct.c_float),
        ("height", ct.c_float),
    ]
26
+
27
+
28
class Colorbar(RenderObject):
    """Render object that draws the colorbar of the active colormap.

    The colormap is uploaded as a 1D ``rgba8unorm`` texture and the value
    range / geometry settings are mirrored into a ``ColormapUniforms``
    buffer.  GPU resources (uniforms, sampler, texture) are created lazily
    in ``update``.

    NOTE(review): ``self.device``, ``self._timestamp`` and
    ``create_render_pipeline`` are not defined here; they are presumably
    provided by ``RenderObject`` -- confirm.
    """

    texture: Texture
    vertex_entry_point: str = "colormap_vertex"
    fragment_entry_point: str = "colormap_fragment"
    n_vertices: int = 3

    def __init__(self, minval=0, maxval=1):
        """Store the initial value range and the default colorbar settings.

        Args:
            minval: value mapped to the start of the colormap.
            maxval: value mapped to the end of the colormap.
        """
        self.texture = None
        self.minval = minval
        self.maxval = maxval
        self.position_x = -0.9
        self.position_y = 0.9
        self.discrete = 0
        self.n_colors = 8
        self.width = 1.0
        self.height = 0.05
        self.uniforms = None
        self.sampler = None
        self.autoupdate = True

    def update(self, timestamp):
        """Synchronise the GPU state with the Python-side settings.

        Does nothing when the object was already updated for ``timestamp``.
        Otherwise the uniforms are (created and) rewritten, the sampler and
        the default ``"matlab:jet"`` texture are created if missing, and the
        render pipeline is rebuilt.
        """
        if timestamp == self._timestamp:
            return
        self._timestamp = timestamp

        if self.uniforms is None:
            self.uniforms = ColormapUniforms(self.device)
        self.uniforms.min = self.minval
        self.uniforms.max = self.maxval
        self.uniforms.position_x = self.position_x
        self.uniforms.position_y = self.position_y
        self.uniforms.discrete = self.discrete
        self.uniforms.n_colors = self.n_colors
        self.uniforms.width = self.width
        self.uniforms.height = self.height
        self.n_instances = 2 * self.n_colors
        self.uniforms.update_buffer()

        if self.sampler is None:
            self.sampler = self.device.createSampler(
                magFilter="linear",
                minFilter="linear",
            )

        if self.texture is None:
            self.set_colormap("matlab:jet")
        self.create_render_pipeline()

    def get_bounding_box(self):
        """Return ``None``: the colorbar reports no bounding box."""
        return None

    def set_n_colors(self, n_colors):
        """Change the number of colors shown by the colorbar.

        The value is stored on the instance as well, because ``update``
        rewrites the uniform and ``n_instances`` from ``self.n_colors``;
        without storing it the change would be lost on the next update.
        """
        self.n_colors = n_colors
        self.n_instances = 2 * n_colors
        if self.uniforms is not None:
            self.uniforms.n_colors = n_colors
            self.uniforms.update_buffer()

    def set_min_max(self, minval, maxval, set_autoupdate=True):
        """Set the value range of the colormap.

        Args:
            minval: new lower bound.
            maxval: new upper bound.
            set_autoupdate: when true, ``self.autoupdate`` is switched off.
        """
        self.minval = minval
        self.maxval = maxval
        if set_autoupdate:
            self.autoupdate = False
        if self.uniforms is not None:
            self.uniforms.min = minval
            self.uniforms.max = maxval
            self.uniforms.update_buffer()

    def get_bindings(self):
        """Return the texture, sampler and uniform bindings of the colorbar."""
        return [
            TextureBinding(Binding.COLORMAP_TEXTURE, self.texture),
            SamplerBinding(Binding.COLORMAP_SAMPLER, self.sampler),
            *self.uniforms.get_bindings(),
        ]

    def get_shader_code(self):
        """Return the shader source read from ``colormap.wgsl``."""
        return read_shader_file("colormap.wgsl", __file__)

    def set_colormap(self, name: str):
        """Upload the colormap ``name`` from ``_colormaps`` as a 1D texture.

        Raises:
            KeyError: if ``name`` is not a key of ``_colormaps``.  The lookup
                happens before the current texture is destroyed, so a failed
                call leaves the existing texture usable.
        """
        colors = _colormaps[name]
        n = len(colors)
        # flatten to r, g, b, a bytes; every entry gets an opaque alpha of 255
        data = [channel for rgb in colors for channel in (*rgb, 255)]

        if self.texture is not None:
            self.texture.destroy()

        self.texture = self.device.createTexture(
            size=[n, 1, 1],
            usage=TextureUsage.TEXTURE_BINDING | TextureUsage.COPY_DST,
            format=TextureFormat.rgba8unorm,
            dimension="1d",
        )
        self.device.queue.writeTexture(
            TexelCopyTextureInfo(self.texture),
            data,
            TexelCopyBufferLayout(bytesPerRow=n * 4),
            [n, 1, 1],
        )
125
+
126
+
127
class Colormap(MultipleRenderObject):
    """Colorbar combined with its numeric tick labels.

    Groups a ``Colorbar`` and a ``Labels`` object and keeps the label texts
    and positions consistent with the colorbar's value range and geometry.
    """

    def __init__(self):
        """Create the colorbar and its labels and register both as children."""
        self.colorbar = Colorbar()
        self.labels = Labels([], [], font_size=14, h_align="center", v_align="top")
        self.update_labels()
        super().__init__([self.colorbar, self.labels])

    @property
    def autoupdate(self):
        """Forwarded ``autoupdate`` flag of the wrapped colorbar."""
        return self.colorbar.autoupdate

    @autoupdate.setter
    def autoupdate(self, value):
        self.colorbar.autoupdate = value

    def get_shader_code(self):
        """Return the shader code of the wrapped colorbar."""
        return self.colorbar.get_shader_code()

    def get_bindings(self):
        """Return the bindings of the wrapped colorbar."""
        return self.colorbar.get_bindings()

    def set_min_max(self, min, max, set_autoupdate=True):
        """Set the value range on the colorbar and regenerate the labels."""
        self.colorbar.set_min_max(min, max, set_autoupdate)
        self.update_labels()

    def update_labels(self):
        """Recompute the tick label texts and positions.

        Five ticks are placed at 0, 1/4, ..., 1 of the colorbar width and
        labelled with the matching values between ``minval`` and ``maxval``.
        Labels and positions are generated from the same tick range so the
        two lists always have equal length.
        """
        bar = self.colorbar
        ticks = range(5)
        self.labels.labels = [
            format_number(bar.minval + i / 4 * (bar.maxval - bar.minval))
            for i in ticks
        ]
        self.labels.positions = [
            (
                bar.position_x + i * bar.width / 4,
                # slightly below position_y; labels are created with v_align="top"
                bar.position_y - 0.01,
                0,
            )
            for i in ticks
        ]

    def get_bounding_box(self):
        """Return ``None``: the colormap reports no bounding box."""
        return None
171
+
172
+
173
# Built-in colormaps: colormap name -> list of 32 ``[r, g, b]`` entries with
# integer channels in 0..255.  ``Colorbar.set_colormap`` appends an alpha of
# 255 to every entry and uploads the result as a 1D texture.  The table has
# the format printed by the ``__main__`` block at the end of this module.
_colormaps = {
    "viridis": [
        [68, 1, 84],
        [71, 13, 96],
        [72, 24, 106],
        [72, 35, 116],
        [71, 45, 123],
        [69, 55, 129],
        [66, 64, 134],
        [62, 73, 137],
        [59, 82, 139],
        [55, 91, 141],
        [51, 99, 141],
        [47, 107, 142],
        [44, 114, 142],
        [41, 122, 142],
        [38, 130, 142],
        [35, 137, 142],
        [33, 145, 140],
        [31, 152, 139],
        [31, 160, 136],
        [34, 167, 133],
        [40, 174, 128],
        [50, 182, 122],
        [63, 188, 115],
        [78, 195, 107],
        [94, 201, 98],
        [112, 207, 87],
        [132, 212, 75],
        [152, 216, 62],
        [173, 220, 48],
        [194, 223, 35],
        [216, 226, 25],
        [236, 229, 27],
    ],
    "plasma": [
        [13, 8, 135],
        [34, 6, 144],
        [49, 5, 151],
        [63, 4, 156],
        [76, 2, 161],
        [89, 1, 165],
        [102, 0, 167],
        [114, 1, 168],
        [126, 3, 168],
        [138, 9, 165],
        [149, 17, 161],
        [160, 26, 156],
        [170, 35, 149],
        [179, 44, 142],
        [188, 53, 135],
        [196, 62, 127],
        [204, 71, 120],
        [211, 81, 113],
        [218, 90, 106],
        [224, 99, 99],
        [230, 108, 92],
        [235, 118, 85],
        [240, 128, 78],
        [245, 139, 71],
        [248, 149, 64],
        [251, 161, 57],
        [253, 172, 51],
        [254, 184, 44],
        [253, 197, 39],
        [252, 210, 37],
        [248, 223, 37],
        [244, 237, 39],
    ],
    "cet_l20": [
        [48, 48, 48],
        [55, 51, 69],
        [60, 54, 89],
        [64, 57, 108],
        [66, 61, 127],
        [67, 65, 145],
        [67, 69, 162],
        [65, 75, 176],
        [63, 81, 188],
        [59, 88, 197],
        [55, 97, 201],
        [50, 107, 197],
        [41, 119, 183],
        [34, 130, 166],
        [37, 139, 149],
        [49, 147, 133],
        [66, 154, 118],
        [85, 160, 103],
        [108, 165, 87],
        [130, 169, 72],
        [150, 173, 58],
        [170, 176, 43],
        [190, 179, 29],
        [211, 181, 19],
        [230, 183, 19],
        [241, 188, 20],
        [248, 194, 20],
        [252, 202, 20],
        [254, 211, 19],
        [255, 220, 17],
        [254, 230, 15],
        [252, 240, 13],
    ],
    "matlab:jet": [
        [0, 0, 128],
        [0, 0, 164],
        [0, 0, 200],
        [0, 0, 237],
        [0, 1, 255],
        [0, 33, 255],
        [0, 65, 255],
        [0, 96, 255],
        [0, 129, 255],
        [0, 161, 255],
        [0, 193, 255],
        [0, 225, 251],
        [22, 255, 225],
        [48, 255, 199],
        [73, 255, 173],
        [99, 255, 148],
        [125, 255, 122],
        [151, 255, 96],
        [177, 255, 70],
        [202, 255, 44],
        [228, 255, 19],
        [254, 237, 0],
        [255, 208, 0],
        [255, 178, 0],
        [255, 148, 0],
        [255, 119, 0],
        [255, 89, 0],
        [255, 59, 0],
        [255, 30, 0],
        [232, 0, 0],
        [196, 0, 0],
        [159, 0, 0],
    ],
}
311
+
312
+
313
if __name__ == "__main__":
    # Print a ``_colormaps`` table to stdout by sampling each named colormap
    # of the third-party ``cmap`` package at 32 evenly spaced parameters
    # 0/32, 1/32, ..., 31/32 and rounding the channels to 0..255 integers.
    from cmap import Colormap

    n_samples = 32
    print("_colormaps = {")
    for name in ("viridis", "plasma", "cet_l20", "matlab:jet"):
        print(f" '{name}' : [")
        sampler = Colormap(name)
        for step in range(n_samples):
            color = sampler(step / n_samples)
            # + 0.5 before int() rounds to the nearest integer
            r, g, b = (int(255 * color[channel] + 0.5) for channel in range(3))
            print(f" [{r}, {g}, {b}],")
        print(" ],")
    print("}")
webgpu/draw.py ADDED
@@ -0,0 +1,35 @@
1
+ from .canvas import Canvas
2
+ from .lilgui import LilGUI
3
+ from .render_object import BaseRenderObject
4
+ from .scene import Scene
5
+ from .utils import max_bounding_box
6
+
7
+
8
def Draw(
    scene: Scene | BaseRenderObject | list[BaseRenderObject],
    canvas: Canvas,
    lilgui=True,
) -> Scene:
    """Initialise a scene on ``canvas``, fit the camera to it and render once.

    Args:
        scene: an existing ``Scene``, a single render object, or a list of
            render objects; the latter two are wrapped in a new ``Scene``.
        canvas: canvas the scene is initialised on.
        lilgui: when true, a ``LilGUI`` is created and stored in ``scene.gui``.

    Returns:
        The (possibly newly created) scene.
    """
    import numpy as np

    if isinstance(scene, BaseRenderObject):
        scene = Scene([scene])
    elif isinstance(scene, list):
        scene = Scene(scene)
    scene.init(canvas)
    if lilgui:
        scene.gui = LilGUI(canvas.canvas.id, scene._id)

    boxes = [obj.get_bounding_box() for obj in scene.render_objects]
    pmin, pmax = max_bounding_box(boxes)

    camera = scene.options.camera
    camera.transform._center = 0.5 * (pmin + pmax)
    # A degenerate bounding box (pmin == pmax) has a zero diagonal; dividing
    # by it would yield a non-finite scale, so the current scale is kept.
    diagonal = np.linalg.norm(pmax - pmin)
    if diagonal > 0:
        camera.transform._scale = 2 / diagonal

    # Only rotate the view when the bounding box is not confined to z == 0.
    if not (pmin[2] == 0 and pmax[2] == 0):
        camera.transform.rotate(30, -20)
    camera._update_uniforms()
    scene.render()

    return scene
webgpu/font.py ADDED
@@ -0,0 +1,162 @@
1
+ import base64
2
+ import json
3
+ import os
4
+ import zlib
5
+
6
+ from .uniforms import Binding, UniformBase, ct
7
+ from .utils import Device, TextureBinding, read_shader_file
8
+ from .webgpu_api import *
9
+
10
+
11
def create_font_texture(device: Device, size: int = 15):
    """Create the font atlas texture for ``size`` from the bundled ``fonts.json``.

    ``fonts.json`` is read from the directory of this module.  When the
    requested size is not stored, the closest stored size within a distance
    of 20 is used; at equal distance the larger size wins.

    Args:
        device: device used to create the texture and to upload its data.
        size: requested font size.

    Returns:
        An ``r8unorm`` texture of ``width * (127 - 32)`` by ``height`` texels,
        filled with the decompressed font data.

    Raises:
        ValueError: if no stored size lies within a distance of 20.
    """
    fonts_path = os.path.join(os.path.dirname(__file__), "fonts.json")
    # context manager so the file handle is closed deterministically
    with open(fonts_path, encoding="utf-8") as fonts_file:
        fonts = json.load(fonts_file)

    dist = 0
    while str(size) not in fonts:
        # try to find the closest available font size
        dist += 1
        if dist > 20:
            raise ValueError(f"Font size {size} not found")

        if str(size + dist) in fonts:
            size += dist
            break

        if str(size - dist) in fonts:
            size -= dist
            break

    font = fonts[str(size)]
    # the data is stored zlib-compressed and base64-encoded
    data = zlib.decompress(base64.b64decode(font["data"]))
    w = font["width"]
    h = font["height"]

    # one cell of width w per character code in range(32, 127)
    tex_width = w * (127 - 32)

    texture = device.createTexture(
        size=[tex_width, h, 1],
        usage=TextureUsage.TEXTURE_BINDING | TextureUsage.COPY_DST,
        format=TextureFormat.r8unorm,
        label="font",
    )
    device.queue.writeTexture(
        TexelCopyTextureInfo(texture),
        data,
        TexelCopyBufferLayout(bytesPerRow=tex_width),
        size=[tex_width, h, 1],
    )

    return texture
50
+
51
+
52
def _get_default_font():
    """Return the file name of a monospace font known to matplotlib.

    The list of fonts of matplotlib's font manager is scanned; every font
    whose name contains "mono" replaces the current candidate, and the scan
    stops at the first file whose name ends with "DejaVuSansMono.ttf"
    (compared case-insensitively).  An empty string is returned when no
    font name contains "mono".
    """
    import os

    # The hard-coded candidate path is empty, so the existence check fails
    # and the matplotlib lookup below is what actually selects the font.
    font_path = ""
    if os.path.exists(font_path):
        return font_path

    from matplotlib import font_manager

    preferred_suffix = "DejaVuSansMono.ttf".lower()
    for entry in font_manager.fontManager.ttflist:
        if "mono" in entry.name.lower():
            font_path = entry.fname
        if entry.fname.lower().endswith(preferred_suffix):
            break

    return font_path
67
+
68
+
69
def create_font_data(size: int = 15, font_file: str = ""):
    """Render the printable ASCII characters into a single-row greyscale image.

    Args:
        size: font size passed to ``ImageFont.truetype``.
        font_file: path of the font file; when empty, the result of
            ``_get_default_font()`` is used.

    Returns:
        A tuple ``(data, w, h)``: the raw bytes of the image and the width
        and height of one character cell.  Cell width and height are taken
        from the bounding box of the character "$".
    """
    from PIL import Image, ImageDraw, ImageFont

    font_file = font_file or _get_default_font()
    # printable ascii characters
    charset = "".join(map(chr, range(32, 127)))

    # Ligatures and similar layout features merge characters, which is not
    # wanted when every character must occupy its own cell of the texture.
    disabled_features = [
        "-liga",
        "-kern",
        "-calt",
        "-clig",
        "-ccmp",
        "-locl",
        "-mark",
        "-mkmk",
        "-rlig",
    ]

    typeface = ImageFont.truetype(font_file, size)
    left, top, right, bottom = typeface.getbbox("$", features=disabled_features)

    # the actual height is usually a few pixels less than the font size
    h = round(bottom - top)
    w = round(right - left)

    # greyscale image; it will be used as alpha channel on the gpu
    atlas = Image.new("L", (len(charset) * w, h), 0)
    pen = ImageDraw.Draw(atlas)
    for index, char in enumerate(charset):
        pen.text(
            (index * w, -top),
            char,
            font=typeface,
            fill=255,
            features=disabled_features,
        )

    return atlas.tobytes(), w, h
104
+
105
+
106
class FontUniforms(UniformBase):
    """Uniform values describing the size of one character cell of the font.

    The structure is bound at ``Binding.FONT`` and is filled in by ``Font``.

    NOTE(review): the field order and C types presumably have to match the
    corresponding struct in ``font.wgsl`` -- confirm against the shader
    before reordering anything.
    """

    _binding = Binding.FONT
    _fields_ = [
        # character cell size as set by Font.set_font_size
        ("width", ct.c_uint32),
        ("height", ct.c_uint32),
        # 2.0 * cell size / canvas size, written by Font.update
        ("width_normalized", ct.c_float),
        ("height_normalized", ct.c_float),
    ]
114
+
115
+
116
class Font:
    """Font atlas texture plus the uniforms describing its character cells.

    The normalized cell size depends on the canvas size, so ``update`` is
    registered as a resize callback of the canvas.
    """

    def __init__(self, canvas, size=15):
        """Create the uniforms and the texture for ``size`` on ``canvas``.

        Args:
            canvas: object providing ``device``, ``width``, ``height`` and
                ``on_resize``.
            size: requested font size.
        """
        self.canvas = canvas
        self.uniforms = FontUniforms(canvas.device)
        self.set_font_size(size)

        self.canvas.on_resize(self.update)

    def get_bindings(self):
        """Return the font texture binding followed by the uniform bindings."""
        return [
            TextureBinding(Binding.FONT_TEXTURE, self._texture, dim=2),
            *self.uniforms.get_bindings(),
        ]

    def get_shader_code(self):
        """Return the shader source read from ``font.wgsl``."""
        return read_shader_file("font.wgsl", __file__)

    def set_font_size(self, font_size: int):
        """Load the texture for ``font_size`` and refresh the uniforms.

        ``create_font_texture`` is defined in this module, so it is called
        directly rather than re-imported from the module itself.
        """
        self._texture = create_font_texture(self.canvas.device, font_size)
        # the texture holds one cell per character code in range(32, 127)
        char_width = self._texture.width // (127 - 32)
        char_height = self._texture.height
        self.uniforms.width = char_width
        self.uniforms.height = char_height
        self.update()

    def update(self):
        """Recompute the normalized cell size from the canvas size and upload it."""
        self.uniforms.width_normalized = 2.0 * self.uniforms.width / self.canvas.width
        self.uniforms.height_normalized = 2.0 * self.uniforms.height / self.canvas.height
        self.uniforms.update_buffer()
147
+
148
+
149
if __name__ == "__main__":
    # create font data and store it as json because we cannot generate this in pyodide

    fonts = {}

    for size in list(range(8, 21, 2)) + [25, 30, 40]:
        data, w, h = create_font_data(size)
        # the integer size keys become strings when dumped to json
        fonts[size] = {
            "data": base64.b64encode(zlib.compress(data)).decode("utf-8"),
            "width": w,
            "height": h,
        }

    # context manager so the output is flushed and closed before exiting
    with open("fonts.json", "w", encoding="utf-8") as fonts_file:
        json.dump(fonts, fonts_file, indent=2)