webgpu 0.0.1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
webgpu/colormap.py ADDED
@@ -0,0 +1,327 @@
1
+ from .webgpu_api import (
2
+ TexelCopyBufferLayout,
3
+ TexelCopyTextureInfo,
4
+ Texture,
5
+ TextureFormat,
6
+ TextureUsage,
7
+ )
8
+
9
+ from .labels import Labels
10
+ from .render_object import MultipleRenderObject, RenderObject
11
+ from .uniforms import Binding, UniformBase, ct
12
+ from .utils import SamplerBinding, TextureBinding, format_number, read_shader_file
13
+
14
+
15
class ColormapUniforms(UniformBase):
    """Uniform buffer for the colour bar shader (``colormap.wgsl``).

    Every field is copied from the attribute of the same meaning on
    ``Colorbar`` in ``Colorbar.update``.
    """

    _binding = Binding.COLORMAP
    # NOTE(review): UniformBase presumably behaves like a ctypes Structure, so
    # the order below is the byte layout the WGSL struct expects -- do not
    # reorder without updating colormap.wgsl (confirm there).
    _fields_ = [
        ("min", ct.c_float),  # value mapped to the start of the colormap
        ("max", ct.c_float),  # value mapped to the end of the colormap
        ("position_x", ct.c_float),  # bar anchor x (Colorbar default -0.9)
        ("position_y", ct.c_float),  # bar anchor y (Colorbar default 0.9)
        ("discrete", ct.c_uint32),  # Colorbar default 0; presumably a 0/1 flag
        ("n_colors", ct.c_uint32),  # number of colour cells drawn in the bar
        ("width", ct.c_float),  # bar width (Colorbar default 1.0)
        ("height", ct.c_float),  # bar height (Colorbar default 0.05)
    ]
27
+
28
+
29
class Colorbar(RenderObject):
    """Colour bar made of ``n_colors`` cells sampled from a colormap texture.

    The colormap is stored in a 1D ``rgba8unorm`` texture (see
    :meth:`set_colormap`). The value range and layout of the bar are mirrored
    into a :class:`ColormapUniforms` buffer for the ``colormap.wgsl`` shader.
    """

    texture: Texture
    vertex_entry_point: str = "colormap_vertex"
    fragment_entry_point: str = "colormap_fragment"
    # 3 vertices per instance; update() requests 2 * n_colors instances,
    # presumably two triangles per colour cell (see colormap.wgsl).
    n_vertices: int = 3

    def __init__(self, minval=0, maxval=1):
        """Create a colour bar for the value range ``[minval, maxval]``.

        GPU resources (uniform buffer, sampler, texture) are created lazily
        in :meth:`update`, which reads ``self.device`` (presumably assigned
        by the render framework before the first update -- confirm).
        """
        self.texture = None
        self.minval = minval
        self.maxval = maxval
        self.position_x = -0.9
        self.position_y = 0.9
        self.discrete = 0
        self.n_colors = 8
        self.width = 1.0
        self.height = 0.05
        self.uniforms = None
        self.sampler = None
        self.autoupdate = True

    def update(self, timestamp):
        """Upload the current settings to the GPU and rebuild the pipeline.

        Does nothing if this ``timestamp`` was already processed. On the
        first call it also creates the uniform buffer and the linear
        sampler, and loads the default ``"matlab:jet"`` colormap.
        """
        if timestamp == self._timestamp:
            return
        self._timestamp = timestamp

        if self.uniforms is None:
            self.uniforms = ColormapUniforms(self.device)
        self.uniforms.min = self.minval
        self.uniforms.max = self.maxval
        self.uniforms.position_x = self.position_x
        self.uniforms.position_y = self.position_y
        self.uniforms.discrete = self.discrete
        self.uniforms.n_colors = self.n_colors
        self.uniforms.width = self.width
        self.uniforms.height = self.height
        self.n_instances = 2 * self.n_colors
        self.uniforms.update_buffer()

        if self.sampler is None:
            self.sampler = self.device.createSampler(
                magFilter="linear",
                minFilter="linear",
            )

        if self.texture is None:
            self.set_colormap("matlab:jet")
        self.create_render_pipeline()

    def get_bounding_box(self):
        """Return ``None``: the colour bar has no bounding box."""
        return None

    def set_n_colors(self, n_colors):
        """Set the number of colour cells in the bar.

        The value is stored on ``self.n_colors``. Otherwise the next
        :meth:`update` would overwrite the uniform and ``n_instances`` with
        the stale value.
        """
        self.n_colors = n_colors
        self.n_instances = 2 * n_colors
        if self.uniforms is not None:
            self.uniforms.n_colors = n_colors
            self.uniforms.update_buffer()

    def set_min_max(self, minval, maxval, set_autoupdate=True):
        """Set the value range shown by the bar.

        Note the flag's meaning: ``set_autoupdate=True`` turns
        ``self.autoupdate`` *off*. This is presumably so that an explicitly
        chosen range is not overridden later.
        """
        self.minval = minval
        self.maxval = maxval
        if set_autoupdate:
            self.autoupdate = False
        if self.uniforms is not None:
            self.uniforms.min = minval
            self.uniforms.max = maxval
            self.uniforms.update_buffer()

    def get_bindings(self):
        """Return the texture, sampler and uniform bindings for the shader."""
        return [
            TextureBinding(Binding.COLORMAP_TEXTURE, self.texture),
            SamplerBinding(Binding.COLORMAP_SAMPLER, self.sampler),
            *self.uniforms.get_bindings(),
        ]

    def get_shader_code(self):
        """Return the WGSL source of the colour bar shader."""
        return read_shader_file("colormap.wgsl", __file__)

    def set_colormap(self, name: str):
        """Replace the colormap texture with the table ``_colormaps[name]``.

        Raises:
            KeyError: if ``name`` is not a known colormap. The lookup happens
                before the current texture is destroyed, so the bar keeps a
                valid texture on failure.
        """
        try:
            rgb_rows = _colormaps[name]
        except KeyError:
            raise KeyError(
                f"Unknown colormap {name!r}; available: {sorted(_colormaps)}"
            ) from None

        n = len(rgb_rows)
        # Flatten to r, g, b, a, r, g, b, a, ... with fully opaque alpha.
        data = [channel for rgb in rgb_rows for channel in (*rgb, 255)]

        if self.texture is not None:
            self.texture.destroy()

        self.texture = self.device.createTexture(
            size=[n, 1, 1],
            usage=TextureUsage.TEXTURE_BINDING | TextureUsage.COPY_DST,
            format=TextureFormat.rgba8unorm,
            dimension="1d",
        )
        self.device.queue.writeTexture(
            TexelCopyTextureInfo(self.texture),
            data,
            TexelCopyBufferLayout(bytesPerRow=n * 4),
            [n, 1, 1],
        )
126
+
127
+
128
class Colormap(MultipleRenderObject):
    """A :class:`Colorbar` with evenly spaced numeric tick labels below it."""

    # number of tick labels, placed from minval (left) to maxval (right)
    n_labels: int = 5

    def __init__(self):
        self.colorbar = Colorbar()
        self.labels = Labels([], [], font_size=14, h_align="center", v_align="top")
        self.update_labels()
        super().__init__([self.colorbar, self.labels])

    @property
    def autoupdate(self):
        """Forward ``autoupdate`` to the underlying colour bar."""
        return self.colorbar.autoupdate

    @autoupdate.setter
    def autoupdate(self, value):
        self.colorbar.autoupdate = value

    def get_shader_code(self):
        """Return the colour bar's shader source."""
        return self.colorbar.get_shader_code()

    def get_bindings(self):
        """Return the colour bar's GPU bindings."""
        return self.colorbar.get_bindings()

    def set_min_max(self, min, max, set_autoupdate=True):
        """Set the value range of the bar and refresh the tick labels.

        ``min``/``max`` shadow builtins but are kept for keyword compatibility.
        See ``Colorbar.set_min_max`` for the meaning of ``set_autoupdate``.
        """
        self.colorbar.set_min_max(min, max, set_autoupdate)
        self.update_labels()

    def update_labels(self):
        """Recompute label texts and positions from the colour bar's state.

        Values and positions are generated from the same tick count, so
        every label has exactly one position. Previously 6 values were
        produced for 5 positions.
        """
        bar = self.colorbar
        n = self.n_labels
        denom = max(n - 1, 1)  # guard n == 1 against division by zero
        value_span = bar.maxval - bar.minval
        self.labels.labels = [
            format_number(bar.minval + i / denom * value_span) for i in range(n)
        ]
        # Labels sit slightly below the bar's anchor y (v_align="top").
        self.labels.positions = [
            (bar.position_x + i * bar.width / denom, bar.position_y - 0.01, 0)
            for i in range(n)
        ]

    def get_bounding_box(self):
        """Return ``None``: the colormap overlay has no bounding box."""
        return None
173
+
174
+
175
# Colormap lookup tables: name -> 32 RGB triples with integer channels in
# 0..255. Colorbar.set_colormap appends alpha 255 to each row and uploads the
# table as a 1D rgba8unorm texture. The table was generated by the __main__
# block of this module, which samples the `cmap` package at i / 32 for
# i in range(32).
_colormaps = {
    "viridis": [
        [68, 1, 84],
        [71, 13, 96],
        [72, 24, 106],
        [72, 35, 116],
        [71, 45, 123],
        [69, 55, 129],
        [66, 64, 134],
        [62, 73, 137],
        [59, 82, 139],
        [55, 91, 141],
        [51, 99, 141],
        [47, 107, 142],
        [44, 114, 142],
        [41, 122, 142],
        [38, 130, 142],
        [35, 137, 142],
        [33, 145, 140],
        [31, 152, 139],
        [31, 160, 136],
        [34, 167, 133],
        [40, 174, 128],
        [50, 182, 122],
        [63, 188, 115],
        [78, 195, 107],
        [94, 201, 98],
        [112, 207, 87],
        [132, 212, 75],
        [152, 216, 62],
        [173, 220, 48],
        [194, 223, 35],
        [216, 226, 25],
        [236, 229, 27],
    ],
    "plasma": [
        [13, 8, 135],
        [34, 6, 144],
        [49, 5, 151],
        [63, 4, 156],
        [76, 2, 161],
        [89, 1, 165],
        [102, 0, 167],
        [114, 1, 168],
        [126, 3, 168],
        [138, 9, 165],
        [149, 17, 161],
        [160, 26, 156],
        [170, 35, 149],
        [179, 44, 142],
        [188, 53, 135],
        [196, 62, 127],
        [204, 71, 120],
        [211, 81, 113],
        [218, 90, 106],
        [224, 99, 99],
        [230, 108, 92],
        [235, 118, 85],
        [240, 128, 78],
        [245, 139, 71],
        [248, 149, 64],
        [251, 161, 57],
        [253, 172, 51],
        [254, 184, 44],
        [253, 197, 39],
        [252, 210, 37],
        [248, 223, 37],
        [244, 237, 39],
    ],
    "cet_l20": [
        [48, 48, 48],
        [55, 51, 69],
        [60, 54, 89],
        [64, 57, 108],
        [66, 61, 127],
        [67, 65, 145],
        [67, 69, 162],
        [65, 75, 176],
        [63, 81, 188],
        [59, 88, 197],
        [55, 97, 201],
        [50, 107, 197],
        [41, 119, 183],
        [34, 130, 166],
        [37, 139, 149],
        [49, 147, 133],
        [66, 154, 118],
        [85, 160, 103],
        [108, 165, 87],
        [130, 169, 72],
        [150, 173, 58],
        [170, 176, 43],
        [190, 179, 29],
        [211, 181, 19],
        [230, 183, 19],
        [241, 188, 20],
        [248, 194, 20],
        [252, 202, 20],
        [254, 211, 19],
        [255, 220, 17],
        [254, 230, 15],
        [252, 240, 13],
    ],
    # default colormap used by Colorbar.update when no texture is set yet
    "matlab:jet": [
        [0, 0, 128],
        [0, 0, 164],
        [0, 0, 200],
        [0, 0, 237],
        [0, 1, 255],
        [0, 33, 255],
        [0, 65, 255],
        [0, 96, 255],
        [0, 129, 255],
        [0, 161, 255],
        [0, 193, 255],
        [0, 225, 251],
        [22, 255, 225],
        [48, 255, 199],
        [73, 255, 173],
        [99, 255, 148],
        [125, 255, 122],
        [151, 255, 96],
        [177, 255, 70],
        [202, 255, 44],
        [228, 255, 19],
        [254, 237, 0],
        [255, 208, 0],
        [255, 178, 0],
        [255, 148, 0],
        [255, 119, 0],
        [255, 89, 0],
        [255, 59, 0],
        [255, 30, 0],
        [232, 0, 0],
        [196, 0, 0],
        [159, 0, 0],
    ],
}
313
+
314
+
315
if __name__ == "__main__":
    # Regenerate the `_colormaps` table from the third-party `cmap` package:
    # each map is sampled at step / 32 for 32 steps, and every float channel
    # is rounded to an integer in 0..255.
    from cmap import Colormap

    N_SAMPLES = 32
    print("_colormaps = {")
    for cmap_name in ("viridis", "plasma", "cet_l20", "matlab:jet"):
        print(f" '{cmap_name}' : [")
        colormap = Colormap(cmap_name)
        for step in range(N_SAMPLES):
            color = colormap(step / N_SAMPLES)
            red, green, blue = (int(255 * color[k] + 0.5) for k in range(3))
            print(f" [{red}, {green}, {blue}],")
        print(" ],")
    print("}")
webgpu/draw.py ADDED
@@ -0,0 +1,35 @@
1
+ from .canvas import Canvas
2
+ from .lilgui import LilGUI
3
+ from .render_object import BaseRenderObject
4
+ from .scene import Scene
5
+ from .utils import max_bounding_box
6
+
7
+
8
def Draw(
    scene: Scene | BaseRenderObject | list[BaseRenderObject],
    canvas: Canvas,
    lilgui=True,
) -> Scene:
    """Attach ``scene`` to ``canvas``, fit the camera to it and render once.

    Args:
        scene: a :class:`Scene`, a single render object, or a list of render
            objects. Objects and lists are wrapped in a new :class:`Scene`.
        canvas: the canvas to draw on.
        lilgui: if true, attach a :class:`LilGUI` control panel to the scene.

    Returns:
        The (possibly newly created) scene.
    """
    import numpy as np

    if isinstance(scene, BaseRenderObject):
        scene = Scene([scene])
    elif isinstance(scene, list):
        scene = Scene(scene)
    scene.init(canvas)
    if lilgui:
        scene.gui = LilGUI(canvas.canvas.id, scene._id)

    objects = scene.render_objects
    pmin, pmax = max_bounding_box([o.get_bounding_box() for o in objects])
    # make the arithmetic below element-wise even if plain sequences come back
    pmin = np.asarray(pmin)
    pmax = np.asarray(pmax)

    camera = scene.options.camera
    camera.transform._center = 0.5 * (pmin + pmax)
    diagonal = np.linalg.norm(pmax - pmin)
    # A degenerate box (e.g. a single point) has zero diagonal; 2 / 0 would
    # give an infinite scale, so fall back to unit scale instead.
    if np.isfinite(diagonal) and diagonal > 0:
        camera.transform._scale = 2 / diagonal
    else:
        camera.transform._scale = 1.0

    # Only tilt the view when the scene is not flat in the z = 0 plane.
    if not (pmin[2] == 0 and pmax[2] == 0):
        camera.transform.rotate(30, -20)
    camera._update_uniforms()
    scene.render()

    return scene
webgpu/font.py ADDED
@@ -0,0 +1,164 @@
1
+ import base64
2
+ import json
3
+ import os
4
+ import zlib
5
+
6
+ from .uniforms import Binding, UniformBase, ct
7
+ from .utils import Device, TextureBinding, read_shader_file
8
+ from .webgpu_api import *
9
+
10
+
11
def _load_font_table():
    """Read the pre-rendered font atlases from ``fonts.json`` next to this module."""
    path = os.path.join(os.path.dirname(__file__), "fonts.json")
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _closest_font_size(fonts, size, max_distance=20):
    """Return the available font size closest to ``size``.

    ``size`` itself is tried first. Then, for distance 1..max_distance,
    ``size + d`` is tried before ``size - d`` (larger sizes win ties).

    Raises:
        ValueError: if no size within ``max_distance`` is available.
    """
    for dist in range(max_distance + 1):
        for candidate in (size + dist, size - dist):
            if str(candidate) in fonts:
                return candidate
    raise ValueError(f"Font size {size} not found")


def create_font_texture(device: Device, size: int = 15):
    """Create a 2D ``r8unorm`` texture holding the glyph atlas for ``size``.

    The atlas contains the printable ASCII characters (codes 32..126) side by
    side in a single row, each ``width`` pixels wide. If ``size`` is not
    pre-rendered, the closest available size is used instead.

    Raises:
        ValueError: if no pre-rendered size lies within 20 of ``size``.
    """
    fonts = _load_font_table()
    font = fonts[str(_closest_font_size(fonts, size))]

    data = zlib.decompress(base64.b64decode(font["data"]))
    w = font["width"]
    h = font["height"]

    tex_width = w * (127 - 32)  # one cell per printable ASCII character

    texture = device.createTexture(
        size=[tex_width, h, 1],
        usage=TextureUsage.TEXTURE_BINDING | TextureUsage.COPY_DST,
        format=TextureFormat.r8unorm,
        label="font",
    )
    device.queue.writeTexture(
        TexelCopyTextureInfo(texture),
        data,
        TexelCopyBufferLayout(bytesPerRow=tex_width),  # 1 byte per r8 texel
        size=[tex_width, h, 1],
    )

    return texture
50
+
51
+
52
def _get_default_font():
    """Pick a monospace TrueType font from matplotlib's font registry.

    Fonts whose family name contains "mono" are candidates, and the last one
    seen is used. The scan stops early as soon as ``DejaVuSansMono.ttf`` is
    found. Returns ``""`` if matplotlib knows no such font.
    """
    from matplotlib import font_manager

    chosen = ""
    for entry in font_manager.fontManager.ttflist:
        if "mono" not in entry.name.lower():
            continue
        chosen = entry.fname
        if entry.fname.lower().endswith("DejaVuSansMono.ttf".lower()):
            break
    return chosen
67
+
68
+
69
def create_font_data(size: int = 15, font_file: str = ""):
    """Rasterize printable ASCII (codes 32..126) into a one-row glyph atlas.

    Args:
        size: font size passed to ``ImageFont.truetype``.
        font_file: path to a TrueType font. If empty, a monospace font is
            chosen via ``_get_default_font``.

    Returns:
        ``(pixels, cell_width, cell_height)``. ``pixels`` are the raw bytes
        of an 8-bit greyscale ("L") image of size
        ``(95 * cell_width, cell_height)``, used as alpha on the GPU.
    """
    from PIL import Image, ImageDraw, ImageFont

    if not font_file:
        font_file = _get_default_font()
    glyphs = "".join(map(chr, range(32, 127)))

    # Switch off OpenType shaping (ligatures, kerning, ...): it would merge
    # neighbouring characters, which breaks the one-cell-per-glyph layout.
    # Kept as a list because Pillow expects a list for `features`.
    disabled_features = [
        "-liga",
        "-kern",
        "-calt",
        "-clig",
        "-ccmp",
        "-locl",
        "-mark",
        "-mkmk",
        "-rlig",
    ]

    font = ImageFont.truetype(font_file, size)
    # The cell size is taken from the ink box of "$". The height is usually a
    # few pixels smaller than the nominal font size.
    # NOTE(review): this is the ink width, not the advance width; wide glyphs
    # (e.g. "W", "M") may bleed into neighbouring cells -- confirm.
    left, top, right, bottom = font.getbbox("$", features=disabled_features)
    cell_height = round(bottom - top)
    cell_width = round(right - left)

    atlas = Image.new("L", (len(glyphs) * cell_width, cell_height), 0)
    pen = ImageDraw.Draw(atlas)
    for index, glyph in enumerate(glyphs):
        # shift up by `top` so the top of the "$" box lands on row 0
        pen.text(
            (index * cell_width, -top),
            glyph,
            font=font,
            fill=255,
            features=disabled_features,
        )

    return atlas.tobytes(), cell_width, cell_height
104
+
105
+
106
class FontUniforms(UniformBase):
    """Uniform buffer with the glyph cell size of the active font atlas.

    Filled by ``Font.set_font_size`` (pixel sizes) and ``Font.update``
    (sizes normalized to the canvas).
    """

    _binding = Binding.FONT
    # NOTE(review): field order presumably defines the byte layout read by
    # font.wgsl -- do not reorder without checking the shader.
    _fields_ = [
        ("width", ct.c_uint32),  # glyph cell width in texels
        ("height", ct.c_uint32),  # glyph cell height in texels
        ("width_normalized", ct.c_float),  # 2 * width / canvas.width
        ("height_normalized", ct.c_float),  # 2 * height / canvas.height
    ]
114
+
115
+
116
class Font:
    """Monospace bitmap font: a glyph-atlas texture plus its size uniforms.

    The atlas (see ``create_font_texture``) holds the printable ASCII
    characters 32..126 side by side in one row of equally wide cells.
    """

    # number of glyph cells in the atlas: chr(32) .. chr(126)
    _N_GLYPHS = 127 - 32

    def __init__(self, canvas, size=15):
        self.canvas = canvas
        self._texture = None
        self.uniforms = FontUniforms(canvas.device)
        self.set_font_size(size)

        # keep the normalized glyph size in sync with the canvas size
        self.canvas.on_resize(self.update)

    def get_bindings(self):
        """Return the atlas texture binding and the font uniform bindings."""
        return [
            TextureBinding(Binding.FONT_TEXTURE, self._texture, dim=2),
            *self.uniforms.get_bindings(),
        ]

    def get_shader_code(self):
        """Return the WGSL source of the font shader."""
        return read_shader_file("font.wgsl", __file__)

    def set_font_size(self, font_size: int):
        """Load the atlas closest to ``font_size`` and update the uniforms.

        The previous atlas texture is destroyed (as ``Colorbar.set_colormap``
        does) instead of being leaked. The new texture is created first, so
        a failing size lookup leaves the current font intact.
        """
        new_texture = create_font_texture(self.canvas.device, font_size)
        old_texture = getattr(self, "_texture", None)
        if old_texture is not None:
            old_texture.destroy()
        self._texture = new_texture

        self.uniforms.width = self._texture.width // self._N_GLYPHS
        self.uniforms.height = self._texture.height
        self.update()

    def update(self):
        """Recompute the glyph size relative to the canvas and upload it.

        The factor 2 scales pixel sizes to a span-2 coordinate range
        (presumably normalized device coordinates -- confirm in font.wgsl).
        """
        self.uniforms.width_normalized = 2.0 * self.uniforms.width / self.canvas.width
        self.uniforms.height_normalized = (
            2.0 * self.uniforms.height / self.canvas.height
        )
        self.uniforms.update_buffer()
149
+
150
+
151
if __name__ == "__main__":
    # Pre-render the font atlases and store them as JSON, because they
    # cannot be generated inside pyodide (no PIL/matplotlib font access).
    fonts = {}

    for size in [*range(8, 21, 2), 25, 30, 40]:
        data, w, h = create_font_data(size)
        fonts[size] = {
            "data": base64.b64encode(zlib.compress(data)).decode("utf-8"),
            "width": w,
            "height": h,
        }

    # Write next to this module, which is where create_font_texture() reads
    # fonts.json from. The current working directory is not that place when
    # the module is run with `python -m`.
    out_path = os.path.join(os.path.dirname(__file__), "fonts.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(fonts, f, indent=2)