tinygrad 0.10.1__py3-none-any.whl → 0.10.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries, and is provided for informational purposes only.
Files changed (62)
  1. tinygrad/codegen/devectorizer.py +247 -0
  2. tinygrad/codegen/expander.py +121 -0
  3. tinygrad/codegen/kernel.py +35 -37
  4. tinygrad/codegen/linearize.py +19 -10
  5. tinygrad/codegen/lowerer.py +31 -8
  6. tinygrad/codegen/symbolic.py +476 -0
  7. tinygrad/codegen/transcendental.py +10 -0
  8. tinygrad/device.py +28 -11
  9. tinygrad/dtype.py +12 -3
  10. tinygrad/engine/jit.py +3 -2
  11. tinygrad/engine/multi.py +0 -1
  12. tinygrad/engine/realize.py +7 -4
  13. tinygrad/engine/schedule.py +227 -255
  14. tinygrad/engine/search.py +20 -27
  15. tinygrad/gradient.py +3 -0
  16. tinygrad/helpers.py +7 -4
  17. tinygrad/nn/state.py +2 -2
  18. tinygrad/ops.py +64 -329
  19. tinygrad/renderer/__init__.py +19 -3
  20. tinygrad/renderer/cstyle.py +39 -18
  21. tinygrad/renderer/llvmir.py +55 -18
  22. tinygrad/renderer/ptx.py +6 -2
  23. tinygrad/renderer/wgsl.py +20 -12
  24. tinygrad/runtime/autogen/libc.py +404 -71
  25. tinygrad/runtime/autogen/{libpciaccess.py → pci.py} +25 -715
  26. tinygrad/runtime/autogen/webgpu.py +6985 -0
  27. tinygrad/runtime/graph/metal.py +28 -29
  28. tinygrad/runtime/ops_amd.py +37 -34
  29. tinygrad/runtime/{ops_clang.py → ops_cpu.py} +4 -2
  30. tinygrad/runtime/ops_disk.py +1 -1
  31. tinygrad/runtime/ops_dsp.py +59 -33
  32. tinygrad/runtime/ops_llvm.py +14 -12
  33. tinygrad/runtime/ops_metal.py +78 -62
  34. tinygrad/runtime/ops_nv.py +9 -6
  35. tinygrad/runtime/ops_python.py +5 -5
  36. tinygrad/runtime/ops_webgpu.py +200 -38
  37. tinygrad/runtime/support/am/amdev.py +23 -11
  38. tinygrad/runtime/support/am/ip.py +10 -10
  39. tinygrad/runtime/support/elf.py +2 -0
  40. tinygrad/runtime/support/hcq.py +7 -5
  41. tinygrad/runtime/support/llvm.py +8 -14
  42. tinygrad/shape/shapetracker.py +3 -2
  43. tinygrad/shape/view.py +2 -3
  44. tinygrad/spec.py +21 -20
  45. tinygrad/tensor.py +150 -90
  46. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js +1232 -0
  47. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/cpp.min.js +47 -0
  48. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js +42 -0
  49. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css +9 -0
  50. tinygrad/viz/assets/d3js.org/d3.v5.min.js +2 -0
  51. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +4816 -0
  52. tinygrad/viz/assets/unpkg.com/@highlightjs/cdn-assets@11.10.0/styles/tokyo-night-dark.min.css +8 -0
  53. tinygrad/viz/index.html +544 -0
  54. tinygrad/viz/perfetto.html +178 -0
  55. tinygrad/viz/serve.py +205 -0
  56. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/METADATA +20 -8
  57. tinygrad-0.10.2.dist-info/RECORD +99 -0
  58. tinygrad/codegen/rewriter.py +0 -516
  59. tinygrad-0.10.1.dist-info/RECORD +0 -86
  60. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/LICENSE +0 -0
  61. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/WHEEL +0 -0
  62. {tinygrad-0.10.1.dist-info → tinygrad-0.10.2.dist-info}/top_level.txt +0 -0
tinygrad/runtime/ops_metal.py

@@ -1,4 +1,4 @@
- import os, pathlib, struct, ctypes, tempfile, functools, decimal
+ import os, pathlib, struct, ctypes, tempfile, functools, contextlib, decimal, platform
  from typing import Any, Union, cast
  from tinygrad.helpers import prod, to_mv, getenv, round_up, cache_dir, T, init_c_struct_t, PROFILE
  from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, cpu_profile, ProfileDeviceEvent, ProfileRangeEvent
@@ -9,10 +9,11 @@ class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response
  def __eq__(self, other): return self.value == other.value

  class objc_instance(objc_id): # method with name "new", "alloc" should be freed after use
- def __del__(self): msg(self, "release")
-
- @functools.lru_cache(None)
- def sel(name: str): return libobjc.sel_registerName(name.encode())
+ def __del__(self):
+ # CPython doesn't make any guarantees about order in which globals (like `msg` or `libobjc`) are destroyed when the interpreter shuts down
+ # https://github.com/tinygrad/tinygrad/pull/8949 triggered the unlucky ordering which lead to a bunch of errors at exit
+ # TODO: Why isn't `sys.is_finalizing` working?
+ if msg is not None and libobjc is not None: msg("release")(self)

  class MTLResourceOptions:
  MTLResourceCPUCacheModeDefaultCache = 0
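
The guard above exists because CPython may tear down module globals (rebinding them to None) before the last `objc_instance` finalizers run. A minimal standalone sketch of the same defensive pattern; `release_resource` is a hypothetical stand-in for module-level helpers like `msg`/`libobjc`:

    # sketch: a finalizer that survives interpreter shutdown
    def release_resource(handle):
        print(f"released {handle}")

    class Managed:
        def __init__(self, handle): self.handle = handle
        def __del__(self):
            # at shutdown this global may already be None, and calling None raises;
            # checking first turns the finalizer into a no-op instead of an error
            if release_resource is not None: release_resource(self.handle)

    obj = Managed(42)  # if still alive at exit, __del__ runs during teardown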
@@ -26,46 +27,47 @@ REQUEST_TYPE_COMPILE = 13

  libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
  libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
- compiler = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
  # Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
  ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
  libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
  libobjc.objc_getClass.restype = objc_id
  libobjc.sel_registerName.restype = objc_id
  libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
- compiler.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
  libdispatch.dispatch_data_create.restype = objc_instance

- # Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
- def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) -> T: # type: ignore [assignment]
+ @functools.lru_cache(None)
+ def msg(selector: str, restype: type[T] = objc_id): # type: ignore [assignment]
+ resname = libobjc.sel_registerName(selector.encode())
  sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
  sender.restype = restype
- return sender(ptr, sel(selector), *args)
+ def _msg(ptr: objc_id, *args: Any) -> T: return sender(ptr, resname, *args)
+ return _msg

- def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
- def from_ns_str(s): return bytes(msg(s, "UTF8String", restype=ctypes.c_char_p)).decode()
+ @functools.lru_cache(None)
+ def to_ns_str(s: str): return msg("stringWithUTF8String:", objc_instance)(libobjc.objc_getClass(b"NSString"), s.encode())
+ def from_ns_str(s): return bytes(msg("UTF8String", ctypes.c_char_p)(s)).decode()

  def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t)

  def wait_check(cbuf: Any):
- msg(cbuf, "waitUntilCompleted")
- error_check(msg(cbuf, "error", restype=objc_instance))
+ msg("waitUntilCompleted")(cbuf)
+ error_check(msg("error", objc_instance)(cbuf))

- def cmdbuf_label(cbuf: objc_id) -> str|None: return from_ns_str(label) if (label:=msg(cbuf, "label", restype=objc_id)).value is not None else None
- def cmdbuf_st_time(cbuf: objc_id) -> float: return cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
- def cmdbuf_en_time(cbuf: objc_id) -> float: return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double))
+ def cmdbuf_label(cbuf: objc_id) -> str|None: return from_ns_str(label) if (label:=msg("label", objc_id)(cbuf)).value is not None else None
+ def cmdbuf_st_time(cbuf: objc_id) -> float: return cast(float, msg("GPUStartTime", ctypes.c_double)(cbuf))
+ def cmdbuf_en_time(cbuf: objc_id) -> float: return cast(float, msg("GPUEndTime", ctypes.c_double)(cbuf))

  def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
  if error.value is None: return None
- raise error_constructor(from_ns_str(msg(error, "localizedDescription", restype=objc_instance)))
+ raise error_constructor(from_ns_str(msg("localizedDescription", objc_instance)(error)))

  class MetalDevice(Compiled):
  def __init__(self, device:str):
  self.sysdevice = libmetal.MTLCreateSystemDefaultDevice()
- self.mtl_queue = msg(self.sysdevice, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
+ self.mtl_queue = msg("newCommandQueueWithMaxCommandBufferCount:", objc_instance)(self.sysdevice, 1024)
  if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
  self.mtl_buffers_in_flight: list[Any] = []
- self.timeline_signal = msg(self.sysdevice, "newSharedEvent", restype=objc_instance)
+ self.timeline_signal = msg("newSharedEvent", objc_instance)(self.sysdevice)
  self.timeline_value = 0

  Compiled.profile_events += [ProfileDeviceEvent(device)]
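
The refactor above turns `msg` from a do-everything call (`msg(ptr, selector, *args, restype=...)`) into a cached factory: the selector is registered and `objc_msgSend`'s restype is bound once per `(selector, restype)` pair, and the returned closure is reused on every send. A condensed, runnable-on-macOS sketch of the same currying pattern, mirroring the diffed code:

    import ctypes, functools

    class objc_id(ctypes.c_void_p): pass  # keeps values pointer-sized when passed back into calls

    libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
    libobjc.objc_getClass.restype = objc_id
    libobjc.sel_registerName.restype = objc_id

    @functools.lru_cache(None)
    def msg(selector: str, restype=objc_id):
        sel = libobjc.sel_registerName(selector.encode())  # selector lookup happens once, not per call
        sender = libobjc["objc_msgSend"]  # item access returns a fresh function object, so setting restype is safe
        sender.restype = restype
        def _msg(receiver, *args): return sender(receiver, sel, *args)
        return _msg

    # usage: round-trip a string through NSString
    NSString = libobjc.objc_getClass(b"NSString")
    s = msg("stringWithUTF8String:")(NSString, b"hello")
    print(msg("UTF8String", ctypes.c_char_p)(s).decode())  # -> hello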
@@ -83,16 +85,25 @@ class MetalDevice(Compiled):
  self.mtl_buffers_in_flight.clear()

  def metal_src_to_library(device:MetalDevice, src:str) -> objc_instance:
- options = msg(libobjc.objc_getClass(b"MTLCompileOptions"), "new", restype=objc_instance)
- msg(options, "setFastMathEnabled:", getenv("METAL_FAST_MATH"))
- library = msg(device.sysdevice, "newLibraryWithSource:options:error:", to_ns_str(src), options,
- ctypes.byref(compileError:=objc_instance()), restype=objc_instance)
+ options = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLCompileOptions"))
+ msg("setFastMathEnabled:")(options, getenv("METAL_FAST_MATH"))
+ library = msg("newLibraryWithSource:options:error:", objc_instance)(device.sysdevice, to_ns_str(src),
+ options, ctypes.byref(compileError:=objc_instance()))
  error_check(compileError, CompileError)
  return library

  class MetalCompiler(Compiler):
+ # Opening METAL after LLVM doesn't fail because ctypes.CDLL opens with RTLD_LOCAL but MTLCompiler opens it's own llvm with RTLD_GLOBAL
+ # This means that MTLCompiler's llvm will create it's own instances of global state because RTLD_LOCAL doesn't export symbols, but if RTLD_GLOBAL
+ # library is loaded first then RTLD_LOCAL library will just use it's symbols. On linux there is RTLD_DEEPBIND to prevent that, but on macos there
+ # doesn't seem to be anything we can do.
+ with contextlib.suppress(FileNotFoundError):
+ import tinygrad.runtime.autogen.llvm # noqa: F401
+ support = ctypes.CDLL("/System/Library/PrivateFrameworks/MTLCompiler.framework/MTLCompiler")
+ support.MTLCodeGenServiceCreate.restype = ctypes.c_void_p
+
  def __init__(self):
- self.cgs = ctypes.c_void_p(compiler.MTLCodeGenServiceCreate(b"tinygrad"))
+ self.cgs = ctypes.c_void_p(MetalCompiler.support.MTLCodeGenServiceCreate(b"tinygrad"))
  super().__init__("compile_metal_direct")
  def __reduce__(self): return (MetalCompiler,()) # force pickle to create new instance for each multiprocessing fork
  def compile(self, src:str) -> bytes:
@@ -106,9 +117,14 @@ class MetalCompiler(Compiler):
  ret = reply[sum(struct.unpack('<LL', reply[8:16])):]
  else:
  ret = CompileError(errorMessage.decode())
+
+ # no changes for compute in 2.0 - 2.4 specs, use 2.0 as default for old versions.
+ macos_major = int(platform.mac_ver()[0].split('.')[0])
+ metal_version = "metal3.1" if macos_major >= 14 else "metal3.0" if macos_major >= 13 else "macos-metal2.0"
+
  # llvm will create modules.timestamp in cache path and cache compilation of metal stdlib (250ms => 8ms compilation time)
  # note that llvm won't necessarily create anything else here as apple has prebuilt versions of many standard libraries
- params = f'-fno-fast-math -std=metal3.1 --driver-mode=metal -x metal -fmodules-cache-path="{cache_dir}" -fno-caret-diagnostics'
+ params = f'-fno-fast-math -std={metal_version} --driver-mode=metal -x metal -fmodules-cache-path="{cache_dir}" -fno-caret-diagnostics'
  # source blob has to be padded to multiple of 4 but at least one 'b\x00' should be added, params blob just has to be null terminated
  src_padded, params_padded = src.encode() + b'\x00'*(round_up(len(src) + 1, 4) - len(src)), params.encode() + b'\x00'
  request = struct.pack('<QQ', len(src_padded), len(params_padded)) + src_padded + params_padded
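
`platform.mac_ver()[0]` returns a string like "14.5", so the new code keys the `-std=` flag off the macOS major version; the cutoffs are taken from the diff itself (Metal 3.1 on macOS 14+, Metal 3.0 on 13, the 2.0 syntax otherwise). Worked through as a standalone sketch:

    import platform

    def pick_metal_std(mac_ver: str) -> str:
        major = int(mac_ver.split('.')[0])
        return "metal3.1" if major >= 14 else "metal3.0" if major >= 13 else "macos-metal2.0"

    assert pick_metal_std("14.5") == "metal3.1"
    assert pick_metal_std("13.2") == "metal3.0"
    assert pick_metal_std("12.7.1") == "macos-metal2.0"
    print(pick_metal_std(platform.mac_ver()[0] or "12.0"))  # mac_ver() is ('', ...) off macOS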
@@ -116,7 +132,7 @@ class MetalCompiler(Compiler):
  # See https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level for struct layout.
  # Fields other than invoke are unused in this case so we can just use ctypes.byref with negative offset to invoke field, add blockptr as a first
  # argument and pretend it's a normal callback
- compiler.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
+ MetalCompiler.support.MTLCodeGenServiceBuildRequest(self.cgs, None, REQUEST_TYPE_COMPILE, request, len(request), ctypes.byref(callback, -0x10))
  if isinstance(ret, Exception): raise ret
  assert ret[:4] == b"MTLB" and ret[-4:] == b"ENDT", f"Invalid Metal library. {ret!r}"
  return ret
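
The `ctypes.byref(callback, -0x10)` trick relies on the Apple blocks ABI linked in the comment: a block literal starts with an 8-byte `isa`, two 4-byte `flags`/`reserved` fields, and then the `invoke` function pointer at offset 0x10, so a pointer aimed 0x10 bytes before the callback puts the callback exactly where the callee reads `invoke`. A sketch of the layout:

    import ctypes

    class BlockLiteral(ctypes.Structure):
        # field layout per https://clang.llvm.org/docs/Block-ABI-Apple.html#high-level
        _fields_ = [("isa", ctypes.c_void_p),      # offset 0x00
                    ("flags", ctypes.c_int32),     # offset 0x08
                    ("reserved", ctypes.c_int32),  # offset 0x0c
                    ("invoke", ctypes.c_void_p),   # offset 0x10 <- the only field the callee reads here
                    ("descriptor", ctypes.c_void_p)]

    assert BlockLiteral.invoke.offset == 0x10
    # byref(callback, -0x10) therefore forges a "block" whose invoke slot lands on the
    # callback; isa/flags/reserved are garbage, which is fine since they are never read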
@@ -133,35 +149,36 @@ class MetalProgram:
  if lib[:4] == b"MTLB":
  # binary metal library
  data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
- self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_lib:=objc_instance()), restype=objc_instance)
+ self.library = msg("newLibraryWithData:error:", objc_instance)(self.dev.sysdevice, data, ctypes.byref(error_lib:=objc_instance()))
  error_check(error_lib)
  else:
  # metal source. rely on OS caching
  try: self.library = metal_src_to_library(self.dev, lib.decode())
  except CompileError as e: raise RuntimeError from e
- self.fxn = msg(self.library, "newFunctionWithName:", to_ns_str(name), restype=objc_instance)
- descriptor = msg(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"), "new", restype=objc_instance)
- msg(descriptor, "setComputeFunction:", self.fxn)
- msg(descriptor, "setSupportIndirectCommandBuffers:", True)
- self.pipeline_state = msg(self.dev.sysdevice, "newComputePipelineStateWithDescriptor:options:reflection:error:",
- descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()), restype=objc_instance)
+ self.fxn = msg("newFunctionWithName:", objc_instance)(self.library, to_ns_str(name))
+ descriptor = msg("new", objc_instance)(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"))
+ msg("setComputeFunction:")(descriptor, self.fxn)
+ msg("setSupportIndirectCommandBuffers:")(descriptor, True)
+ self.pipeline_state = msg("newComputePipelineStateWithDescriptor:options:reflection:error:", objc_instance)(self.dev.sysdevice,
+ descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation:=objc_instance()))
  error_check(error_pipeline_creation)
+ # cache these msg calls
+ self.max_total_threads: int = cast(int, msg("maxTotalThreadsPerThreadgroup", ctypes.c_ulong)(self.pipeline_state))

  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
- max_total_threads = msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong)
- if prod(local_size) > cast(int, max_total_threads):
- exec_width = msg(self.pipeline_state, "threadExecutionWidth", restype=ctypes.c_ulong)
- memory_length = msg(self.pipeline_state, "staticThreadgroupMemoryLength", restype=ctypes.c_ulong)
- raise RuntimeError(f"local size {local_size} bigger than {max_total_threads} with exec width {exec_width} memory length {memory_length}")
- command_buffer = msg(self.dev.mtl_queue, "commandBuffer", restype=objc_instance)
- encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
- msg(encoder, "setComputePipelineState:", self.pipeline_state)
- for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
- for i,a in enumerate(vals, start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
- msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
- msg(encoder, "endEncoding")
- msg(command_buffer, "setLabel:", to_ns_str(self.name))
- msg(command_buffer, "commit")
+ if prod(local_size) > self.max_total_threads:
+ exec_width = msg("threadExecutionWidth", ctypes.c_ulong)(self.pipeline_state)
+ memory_length = msg("staticThreadgroupMemoryLength", ctypes.c_ulong)(self.pipeline_state)
+ raise RuntimeError(f"local size {local_size} bigger than {self.max_total_threads} with exec width {exec_width} memory length {memory_length}")
+ command_buffer = msg("commandBuffer", objc_instance)(self.dev.mtl_queue)
+ encoder = msg("computeCommandEncoder", objc_instance)(command_buffer)
+ msg("setComputePipelineState:")(encoder, self.pipeline_state)
+ for i,a in enumerate(bufs): msg("setBuffer:offset:atIndex:")(encoder, a.buf, a.offset, i)
+ for i,a in enumerate(vals, start=len(bufs)): msg("setBytes:length:atIndex:")(encoder, bytes(ctypes.c_int(a)), 4, i)
+ msg("dispatchThreadgroups:threadsPerThreadgroup:")(encoder, to_struct(*global_size), to_struct(*local_size))
+ msg("endEncoding")(encoder)
+ msg("setLabel:")(command_buffer, to_ns_str(self.name)) # TODO: is this always needed?
+ msg("commit")(command_buffer)
  self.dev.mtl_buffers_in_flight.append(command_buffer)
  if wait:
  wait_check(command_buffer)
@@ -176,33 +193,32 @@ class MetalAllocator(LRUAllocator):
  super().__init__()
  def _alloc(self, size:int, options) -> MetalBuffer:
  # Buffer is explicitly released in _free() rather than garbage collected via reference count
- ret = msg(self.dev.sysdevice, "newBufferWithLength:options:", ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared,
- restype=objc_id)
+ ret = msg("newBufferWithLength:options:", objc_id)(self.dev.sysdevice, ctypes.c_ulong(size), MTLResourceOptions.MTLResourceStorageModeShared)
  if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
  return MetalBuffer(ret, size)
- def _free(self, opaque:MetalBuffer, options): msg(opaque.buf, "release")
+ def _free(self, opaque:MetalBuffer, options): msg("release")(opaque.buf)
  def _transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
  dest_dev.synchronize()
- src_command_buffer = msg(src_dev.mtl_queue, "commandBuffer", restype=objc_instance)
- encoder = msg(src_command_buffer, "blitCommandEncoder", restype=objc_instance)
- msg(encoder, "copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:", src.buf, ctypes.c_ulong(src.offset),
+ src_command_buffer = msg("commandBuffer", objc_instance)(src_dev.mtl_queue)
+ encoder = msg("blitCommandEncoder", objc_instance)(src_command_buffer)
+ msg("copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:")(encoder, src.buf, ctypes.c_ulong(src.offset),
  dest.buf, ctypes.c_ulong(dest.offset), ctypes.c_ulong(sz))
- msg(encoder, "endEncoding")
+ msg("endEncoding")(encoder)
  if src_dev != dest_dev:
- msg(src_command_buffer, "encodeSignalEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
- dest_command_buffer = msg(dest_dev.mtl_queue, "commandBuffer", restype=objc_instance)
- msg(dest_command_buffer, "encodeWaitForEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
- msg(dest_command_buffer, "commit")
+ msg("encodeSignalEvent:value:")(src_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
+ dest_command_buffer = msg("commandBuffer", objc_instance)(dest_dev.mtl_queue)
+ msg("encodeWaitForEvent:value:")(dest_command_buffer, src_dev.timeline_signal, src_dev.timeline_value)
+ msg("commit")(dest_command_buffer)
  dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
  src_dev.timeline_value += 1
- msg(src_command_buffer, "setLabel:", to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
- msg(src_command_buffer, "commit")
+ msg("setLabel:")(src_command_buffer, to_ns_str(f"COPY {src_dev.device} -> {dest_dev.device}"))
+ msg("commit")(src_command_buffer)
  src_dev.mtl_buffers_in_flight.append(src_command_buffer)
  def _cp_mv(self, dst, src, prof_desc):
  with cpu_profile(prof_desc, self.dev.device, is_copy=True): dst[:] = src
  def _as_buffer(self, src:MetalBuffer) -> memoryview:
  self.dev.synchronize()
- return to_mv(cast(int, msg(src.buf, "contents", restype=objc_id).value), src.size + src.offset)[src.offset:]
+ return to_mv(cast(int, msg("contents", objc_id)(src.buf).value), src.size + src.offset)[src.offset:]
  def _copyin(self, dest:MetalBuffer, src:memoryview): self._cp_mv(self._as_buffer(dest), src, "CPU -> METAL")
  def _copyout(self, dest:memoryview, src:MetalBuffer): self._cp_mv(dest, self._as_buffer(src), "METAL -> CPU")
  def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
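
The cross-device branch of `_transfer` uses the shared timeline event created in `MetalDevice.__init__`: the source queue signals a monotonically increasing value after the blit, and the destination queue waits for that value before its own work may run. A toy CPU-side model of the pattern using threads (illustration only, no Metal involved):

    import threading

    class TimelineEvent:
        # minimal model of an MTLSharedEvent: block until the signaled value reaches a target
        def __init__(self):
            self.value, self.cv = 0, threading.Condition()
        def signal(self, value):
            with self.cv:
                self.value = max(self.value, value)
                self.cv.notify_all()
        def wait(self, value):
            with self.cv:
                self.cv.wait_for(lambda: self.value >= value)

    event, order = TimelineEvent(), []
    def dest_queue(): event.wait(1); order.append("consume on dest")

    t = threading.Thread(target=dest_queue); t.start()
    order.append("copy on src"); event.signal(1); t.join()
    print(order)  # always ['copy on src', 'consume on dest']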
tinygrad/runtime/ops_nv.py

@@ -6,8 +6,8 @@ from dataclasses import dataclass
  from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal, BumpAllocator
  from tinygrad.runtime.support.hcq import HWInterface, MOCKGPU
  from tinygrad.ops import sint
- from tinygrad.device import BufferSpec
- from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod
+ from tinygrad.device import BufferSpec, CPUProgram
+ from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod, OSX
  from tinygrad.renderer.ptx import PTXRenderer
  from tinygrad.renderer.cstyle import NVRenderer
  from tinygrad.runtime.support.compiler_cuda import CUDACompiler, PTXCompiler, PTX, NVPTXCompiler, NVCompiler
@@ -122,6 +122,8 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):

  gpfifo.ring[gpfifo.put_value % gpfifo.entries_count] = (cmdq_addr//4 << 2) | (len(self._q) << 42) | (1 << 41)
  gpfifo.controls.GPPut = (gpfifo.put_value + 1) % gpfifo.entries_count
+
+ if CPUProgram.atomic_lib is not None: CPUProgram.atomic_lib.atomic_thread_fence(__ATOMIC_SEQ_CST:=5)
  dev.gpu_mmio[0x90 // 4] = gpfifo.token
  gpfifo.put_value += 1
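
The new fence guarantees the ring entry and `GPPut` stores are visible before the doorbell write to `gpu_mmio`, so the GPU cannot observe a doorbell that points at a half-written descriptor. The bare `5` is the value GCC/Clang assign to `__ATOMIC_SEQ_CST` in their atomics ABI. A hedged sketch of issuing the same fence from Python, assuming GNU libatomic (which exports `atomic_thread_fence`) is present, much as tinygrad's `CPUProgram.atomic_lib` does:

    import ctypes, ctypes.util

    # GCC/Clang encode the __ATOMIC_* memory orders as 0..5
    RELAXED, CONSUME, ACQUIRE, RELEASE, ACQ_REL, SEQ_CST = range(6)

    path = ctypes.util.find_library("atomic")  # typically libatomic.so.1 on Linux
    if path is not None:
        ctypes.CDLL(path).atomic_thread_fence(SEQ_CST)  # full barrier: all earlier stores become visible first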
@@ -141,6 +143,7 @@ class NVComputeQueue(NVCommandQueue):

  self.bind_sints_to_ptr(*global_size, ptr=qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_RASTER_WIDTH[1] // 8, fmt='I')
  self.bind_sints_to_ptr(*local_size, ptr=qmd_addr + nv_gpu.NVC6C0_QMDV03_00_CTA_THREAD_DIMENSION0[1] // 8, fmt='H')
+ self.bind_sints_to_ptr(*local_size, *global_size, ptr=args_state.ptr, fmt='I')
  qmd.constant_buffer_addr_upper_0, qmd.constant_buffer_addr_lower_0 = data64(args_state.ptr)

  if self.active_qmd is None:
@@ -188,7 +191,7 @@ class NVCopyQueue(NVCommandQueue):

  class NVArgsState(CLikeArgsState):
  def __init__(self, ptr:int, prg:NVProgram, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=()):
- if MOCKGPU: prg.constbuffer_0[0:2] = [len(bufs), len(vals)]
+ if MOCKGPU: prg.constbuffer_0[80:82] = [len(bufs), len(vals)]
  super().__init__(ptr, prg, bufs, vals=vals, prefix=prg.constbuffer_0)

  class NVProgram(HCQProgram):
@@ -292,8 +295,8 @@ class NVDevice(HCQCompiled[NVSignal]):
  # TODO: Need a proper allocator for va addresses
  # 0x1000000000 - 0x2000000000, reserved for system/cpu mappings
  # VA space is 48bits.
- low_uvm_vaddr_allocator: BumpAllocator = BumpAllocator(size=0x1000000000, base=0x1000000000, wrap=False)
- uvm_vaddr_allocator: BumpAllocator = BumpAllocator(size=(1 << 48) - 1, base=0x2000000000, wrap=False)
+ low_uvm_vaddr_allocator: BumpAllocator = BumpAllocator(size=0x1000000000, base=0x8000000000 if OSX else 0x1000000000, wrap=False)
+ uvm_vaddr_allocator: BumpAllocator = BumpAllocator(size=(1 << 48) - 1, base=low_uvm_vaddr_allocator.base + low_uvm_vaddr_allocator.size, wrap=False)
  host_object_enumerator: int = 0x1000

  def _new_gpu_fd(self):
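
`BumpAllocator` is tinygrad's trivial monotonic address allocator: it hands out aligned offsets from `base` upward and, with `wrap=False`, simply runs out instead of recycling. The OSX change moves the CPU-visible window to `0x8000000000` and stacks the main UVM range directly after it. A hypothetical re-implementation to show the behavior (not the actual class from tinygrad.runtime.support.hcq):

    def round_up(x: int, align: int) -> int: return (x + align - 1) // align * align

    class BumpAllocator:
        def __init__(self, size: int, base: int = 0, wrap: bool = True):
            self.size, self.base, self.wrap, self.ptr = size, base, wrap, 0
        def alloc(self, size: int, alignment: int = 1) -> int:
            self.ptr = round_up(self.ptr, alignment)
            if self.ptr + size > self.size:
                if not self.wrap: raise RuntimeError("bump allocator exhausted")
                self.ptr = 0  # wrap back to the start and reuse the region
            self.ptr += size
            return self.base + self.ptr - size

    low = BumpAllocator(size=0x1000000000, base=0x8000000000, wrap=False)  # the OSX base from the diff
    print(hex(low.alloc(0x10000, alignment=0x10000)))  # 0x8000000000
    print(hex(low.alloc(0x4000, alignment=0x10000)))   # 0x8000010000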
@@ -311,7 +314,7 @@ class NVDevice(HCQCompiled[NVSignal]):

  def _gpu_alloc(self, size:int, host=False, uncached=False, cpu_access=False, contiguous=False, map_flags=0, tag="") -> HCQBuffer:
  # Uncached memory is "system". Use huge pages only for gpu memory.
- page_size = (4 << 10) if uncached or host else ((2 << 20) if size >= (8 << 20) else (4 << 10))
+ page_size = (4 << (12 if OSX else 10)) if uncached or host else ((2 << 20) if size >= (8 << 20) else (4 << (12 if OSX else 10)))
  size = round_up(size, page_size)
  va_addr = self._alloc_gpu_vaddr(size, alignment=page_size, force_low=cpu_access)
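
The shift tweak is about Apple Silicon's 16 KiB pages: `4 << 10` is 4096 while `4 << 12` is 16384, so on OSX both the uncached/host path and the small-allocation path round up to 16 KiB, and allocations of 8 MiB or more still use 2 MiB huge pages. The arithmetic:

    assert 4 << 10 == 4096             # 4 KiB pages elsewhere
    assert 4 << 12 == 16384            # 16 KiB pages on OSX (Apple Silicon)
    assert 2 << 20 == 2 * 1024 * 1024  # 2 MiB huge pages for gpu allocations >= 8 MiB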
tinygrad/runtime/ops_python.py

@@ -40,7 +40,7 @@ class PythonProgram:
  loop_ends: dict[int, int] = {}
  while i < len(self.uops):
  uop, dtype, idp, arg = self.uops[i]
- void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF}
+ void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.NAME}
  if uop is Ops.DEFINE_ACC: idp = [idp[0]]
  inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
  dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
@@ -60,7 +60,7 @@ class PythonProgram:
  loop_ends[idp[0]] = i
  i = idp[0]
  continue
- if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF):
+ if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF, Ops.NAME):
  # in the python emulator, the warp is always in sync
  i += 1
  continue
@@ -173,7 +173,7 @@ class PythonProgram:
  # C, D (8 elements on 8 threads)
  def c_map(lane, elem): return (lane, elem)
  ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
- elif arg[4] == "CLANG":
+ elif arg[4] == "CPU":
  def elem(x, col, row, _): return x[col+row][0] # k is always 0
  def c_map(_, elem): return (elem%16, elem//16)
  ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
@@ -194,9 +194,9 @@ class PythonRenderer(Renderer):
  if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm80
  if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75
  if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
- if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CLANG", ClangRenderer.tensor_cores
+ if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", ClangRenderer.tensor_cores

- def render(self, name:str, uops:list[UOp]) -> str:
+ def render(self, uops:list[UOp]) -> str:
  lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
  return base64.b64encode(pickle.dumps(lops)).decode()
tinygrad/runtime/ops_webgpu.py

@@ -1,63 +1,225 @@
  import functools, struct
  from tinygrad.device import Compiled, Allocator, Compiler
  from tinygrad.renderer.wgsl import WGSLRenderer
- from tinygrad.helpers import round_up
- import wgpu
+ from tinygrad.helpers import round_up, OSX
+ from tinygrad.runtime.autogen import webgpu
+ from typing import List, Any
+ import ctypes
+ import os

- def create_uniform(wgpu_device, val) -> wgpu.GPUBuffer:
- buf = wgpu_device.create_buffer(size=4, usage=wgpu.BufferUsage.UNIFORM | wgpu.BufferUsage.COPY_DST)
- wgpu_device.queue.write_buffer(buf, 0, val.to_bytes(4, "little") if isinstance(val, int) else struct.pack('<f', val))
+ backend_types = {v: k for k, v in webgpu.WGPUBackendType__enumvalues.items() }
+
+ try:
+ instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
+ except AttributeError:
+ raise RuntimeError("Cannot find dawn library. Install it with: " + ("brew tap wpmed92/dawn && brew install dawn" if OSX else
+ "sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so"))
+
+ def to_c_string(_str): return ctypes.create_string_buffer(_str.encode('utf-8'))
+
+ def from_wgpu_str(string_view): return ctypes.string_at(string_view.data, string_view.length).decode("utf-8")
+
+ def to_wgpu_str(_str):
+ return webgpu.WGPUStringView(data=ctypes.cast(ctypes.pointer(to_c_string(_str)), ctypes.POINTER(ctypes.c_char)), length=len(_str))
+
+ def _wait(future):
+ assert webgpu.wgpuInstanceWaitAny(instance, 1, webgpu.WGPUFutureWaitInfo(future=future), 2**64-1) == webgpu.WGPUWaitStatus_Success, "Future failed"
+
+ def write_buffer(device, buf, offset, src):
+ src = bytearray(src)
+ webgpu.wgpuQueueWriteBuffer(webgpu.wgpuDeviceGetQueue(device), buf, offset, (ctypes.c_uint8 * len(src)).from_buffer(src), len(src))
+
+ def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx, msg_idx, *params):
+ result: List[Any] = []
+
+ def cb(*params):
+ result[:] = params
+ if msg_idx: result[msg_idx] = from_wgpu_str(result[msg_idx])
+
+ cb_info = cb_info_type(nextInChain=None, mode=webgpu.WGPUCallbackMode_WaitAnyOnly, callback=cb_type(cb))
+ _wait(async_fun(*params, cb_info))
+
+ if result[0] != 1: raise RuntimeError(f"[{status_enum[result[0]] if status_enum else 'ERROR'}]{result[msg_idx] if msg_idx else ''}")
+ return result[res_idx] if res_idx else None
+
+ def copy_buffer_to_buffer(dev, src, src_offset, dst, dst_offset, size):
+ encoder = webgpu.wgpuDeviceCreateCommandEncoder(dev, webgpu.WGPUCommandEncoderDescriptor())
+ webgpu.wgpuCommandEncoderCopyBufferToBuffer(encoder, src, src_offset, dst, dst_offset, size)
+ cb = webgpu.wgpuCommandEncoderFinish(encoder, webgpu.WGPUCommandBufferDescriptor())
+ webgpu.wgpuQueueSubmit(webgpu.wgpuDeviceGetQueue(dev), 1, (webgpu.WGPUCommandBuffer*1)(cb))
+ webgpu.wgpuCommandBufferRelease(cb)
+ webgpu.wgpuCommandEncoderRelease(encoder)
+
+ def read_buffer(dev, buf):
+ size = webgpu.wgpuBufferGetSize(buf)
+ tmp_buffer = webgpu.wgpuDeviceCreateBuffer(dev, webgpu.WGPUBufferDescriptor(size=size,
+ usage=webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_MapRead, mappedAtCreation=False))
+ copy_buffer_to_buffer(dev, buf, 0, tmp_buffer, 0, size)
+ _run(webgpu.wgpuBufferMapAsync2, webgpu.WGPUBufferMapCallbackInfo2, webgpu.WGPUBufferMapCallback2, webgpu.WGPUBufferMapAsyncStatus__enumvalues,
+ None, 0, tmp_buffer, webgpu.WGPUMapMode_Read, 0, size)
+ void_ptr = ctypes.cast(webgpu.wgpuBufferGetConstMappedRange(tmp_buffer, 0, size), ctypes.c_void_p)
+ buf_copy = bytearray((ctypes.c_uint8 * size).from_address(void_ptr.value))
+ webgpu.wgpuBufferUnmap(tmp_buffer)
+ webgpu.wgpuBufferDestroy(tmp_buffer)
+ return memoryview(buf_copy).cast("B")
+
+ def pop_error(device):
+ return _run(webgpu.wgpuDevicePopErrorScopeF, webgpu.WGPUPopErrorScopeCallbackInfo, webgpu.WGPUPopErrorScopeCallback, None, 2, 2, device)
+
+ def create_uniform(wgpu_device, val):
+ buf = webgpu.wgpuDeviceCreateBuffer(wgpu_device,
+ webgpu.WGPUBufferDescriptor(size=4, usage=webgpu.WGPUBufferUsage_Uniform | webgpu.WGPUBufferUsage_CopyDst))
+ write_buffer(wgpu_device, buf, 0, val.to_bytes(4, "little") if isinstance(val, int) else struct.pack('<f', val))
  return buf

  class WebGPUProgram:
  def __init__(self, dev, name:str, lib:bytes):
  (self.dev, self.timestamp_supported) = dev
- self.name, self.lib, self.prg = name, lib, self.dev.create_shader_module(code=lib.decode()) # NOTE: this is the compiler
+
+ # Creating shader module
+ shader = webgpu.WGPUShaderModuleWGSLDescriptor(code=to_wgpu_str(lib.decode()),
+ chain=webgpu.WGPUChainedStruct(sType=webgpu.WGPUSType_ShaderSourceWGSL))
+ module = webgpu.WGPUShaderModuleDescriptor()
+ module.nextInChain = ctypes.cast(ctypes.pointer(shader), ctypes.POINTER(webgpu.struct_WGPUChainedStruct))
+
+ # Check compiler error
+ webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
+ shader_module = webgpu.wgpuDeviceCreateShaderModule(self.dev, module)
+
+ if err := pop_error(self.dev): raise RuntimeError(f"Shader compilation failed: {err}")
+
+ self.name, self.lib, self.prg = name, lib, shader_module
  def __call__(self, *bufs, global_size=(1,1,1), local_size=(1,1,1), vals=(), wait=False):
  wait = wait and self.timestamp_supported
- binding_layouts = [{"binding": 0, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.uniform }}]
- binding_layouts += [{"binding": i+1, "visibility": wgpu.ShaderStage.COMPUTE,
- "buffer": {"type": wgpu.BufferBindingType.uniform if i >= len(bufs) else wgpu.BufferBindingType.storage }} for i in range(len(bufs)+len(vals))] # noqa: E501
- bindings = [{"binding": 0, "resource": {"buffer": create_uniform(self.dev, float('inf')), "offset": 0, "size": 4}}]
- bindings += [{"binding": i+1, "resource": {"buffer": create_uniform(self.dev, x) if i >= len(bufs) else x, "offset": 0,
- "size": 4 if i >= len(bufs) else x.size}} for i,x in enumerate(bufs+vals)] # noqa: E501
- bind_group_layout = self.dev.create_bind_group_layout(entries=binding_layouts)
- pipeline_layout = self.dev.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
- bind_group = self.dev.create_bind_group(layout=bind_group_layout, entries=bindings)
- compute_pipeline = self.dev.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
- command_encoder = self.dev.create_command_encoder()
+ tmp_bufs = [*bufs]
+ buf_patch = False
+
+ # WebGPU does not allow using the same buffer for input and output
+ for i in range(1, len(bufs)):
+ if bufs[i] == bufs[0]:
+ tmp_bufs[0] = webgpu.wgpuDeviceCreateBuffer(self.dev,
+ webgpu.WGPUBufferDescriptor(size=webgpu.wgpuBufferGetSize(bufs[0]), usage=webgpu.wgpuBufferGetUsage(bufs[0])))
+ buf_patch = True
+
+ # Creating bind group layout
+ binding_layouts = [webgpu.WGPUBindGroupLayoutEntry(binding=0, visibility= webgpu.WGPUShaderStage_Compute,
+ buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform))]
+ binding_layouts += [webgpu.WGPUBindGroupLayoutEntry(binding=i+1, visibility=webgpu.WGPUShaderStage_Compute,
+ buffer=webgpu.WGPUBufferBindingLayout(type=webgpu.WGPUBufferBindingType_Uniform if i >= len(tmp_bufs)
+ else webgpu.WGPUBufferBindingType_Storage)) for i in range(len(tmp_bufs)+len(vals))]
+
+ bl_arr_type = webgpu.WGPUBindGroupLayoutEntry * len(binding_layouts)
+ webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
+ bind_group_layouts = [webgpu.wgpuDeviceCreateBindGroupLayout(self.dev, webgpu.WGPUBindGroupLayoutDescriptor(
+ entryCount=len(binding_layouts), entries=ctypes.cast(bl_arr_type(*binding_layouts), ctypes.POINTER(webgpu.WGPUBindGroupLayoutEntry))))]
+
+ if bg_layout_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group layout: {bg_layout_err}")
+
+ # Creating pipeline layout
+ pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor(bindGroupLayoutCount=len(bind_group_layouts),
+ bindGroupLayouts = (webgpu.WGPUBindGroupLayout * len(bind_group_layouts))(*bind_group_layouts))
+
+ webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
+ pipeline_layout = webgpu.wgpuDeviceCreatePipelineLayout(self.dev, pipeline_layout_desc)
+
+ if pipe_err := pop_error(self.dev): raise RuntimeError(f"Error creating pipeline layout: {pipe_err}")
+
+ # Creating bind group
+ bindings = [webgpu.WGPUBindGroupEntry(binding=0, buffer=create_uniform(self.dev, float('inf')), offset=0, size=4)]
+ bindings += [webgpu.WGPUBindGroupEntry(binding=i+1, buffer=create_uniform(self.dev, x) if i >= len(tmp_bufs) else x, offset=0,
+ size=4 if i >= len(tmp_bufs) else webgpu.wgpuBufferGetSize(x)) for i,x in enumerate(tuple(tmp_bufs)+vals)]
+
+ bg_arr_type = webgpu.WGPUBindGroupEntry * len(bindings)
+ bind_group_desc = webgpu.WGPUBindGroupDescriptor(layout=bind_group_layouts[0], entryCount=len(bindings), entries=bg_arr_type(*bindings))
+ webgpu.wgpuDevicePushErrorScope(self.dev, webgpu.WGPUErrorFilter_Validation)
+ bind_group = webgpu.wgpuDeviceCreateBindGroup(self.dev, bind_group_desc)
+
+ if bind_err := pop_error(self.dev): raise RuntimeError(f"Error creating bind group: {bind_err}")
+
+ # Creating compute pipeline
+ compute_desc = webgpu.WGPUComputePipelineDescriptor(layout=pipeline_layout,
+ compute=webgpu.WGPUComputeState(module=self.prg, entryPoint=to_wgpu_str(self.name)))
+ pipeline_result = _run(webgpu.wgpuDeviceCreateComputePipelineAsync2, webgpu.WGPUCreateComputePipelineAsyncCallbackInfo2,
+ webgpu.WGPUCreateComputePipelineAsyncCallback2, webgpu.WGPUCreatePipelineAsyncStatus__enumvalues, 1, None, self.dev, compute_desc)
+
+ command_encoder = webgpu.wgpuDeviceCreateCommandEncoder(self.dev, webgpu.WGPUCommandEncoderDescriptor())
+ comp_pass_desc = webgpu.WGPUComputePassDescriptor(nextInChain=None)
+
  if wait:
- query_set = self.dev.create_query_set(type=wgpu.QueryType.timestamp, count=2)
- query_buf = self.dev.create_buffer(size=16, usage=wgpu.BufferUsage.QUERY_RESOLVE | wgpu.BufferUsage.COPY_SRC)
- timestamp_writes = {"query_set": query_set, "beginning_of_pass_write_index": 0, "end_of_pass_write_index": 1}
- compute_pass = command_encoder.begin_compute_pass(timestamp_writes=timestamp_writes if wait else None) # pylint: disable=E0606
- compute_pass.set_pipeline(compute_pipeline)
- compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used
- compute_pass.dispatch_workgroups(*global_size) # x y z
- compute_pass.end()
+ query_set = webgpu.wgpuDeviceCreateQuerySet(self.dev, webgpu.WGPUQuerySetDescriptor(type=webgpu.WGPUQueryType_Timestamp, count=2))
+ query_buf = webgpu.wgpuDeviceCreateBuffer(self.dev,
+ webgpu.WGPUBufferDescriptor(size=16, usage=webgpu.WGPUBufferUsage_QueryResolve | webgpu.WGPUBufferUsage_CopySrc))
+ comp_pass_desc.timestampWrites = ctypes.pointer(webgpu.WGPUComputePassTimestampWrites(
+ querySet=query_set, beginningOfPassWriteIndex=0, endOfPassWriteIndex=1))
+
+ # Begin compute pass
+ compute_pass = webgpu.wgpuCommandEncoderBeginComputePass(command_encoder, comp_pass_desc)
+ webgpu.wgpuComputePassEncoderSetPipeline(compute_pass, pipeline_result)
+ webgpu.wgpuComputePassEncoderSetBindGroup(compute_pass, 0, bind_group, 0, None)
+ webgpu.wgpuComputePassEncoderDispatchWorkgroups(compute_pass, *global_size)
+ webgpu.wgpuComputePassEncoderEnd(compute_pass)
+
+ if wait: webgpu.wgpuCommandEncoderResolveQuerySet(command_encoder, query_set, 0, 2, query_buf, 0)
+
+ cmd_buf = webgpu.wgpuCommandEncoderFinish(command_encoder, webgpu.WGPUCommandBufferDescriptor())
+ webgpu.wgpuQueueSubmit(webgpu.wgpuDeviceGetQueue(self.dev), 1, (webgpu.WGPUCommandBuffer*1)(cmd_buf))
+
+ if buf_patch:
+ copy_buffer_to_buffer(self.dev, tmp_bufs[0], 0, bufs[0], 0, webgpu.wgpuBufferGetSize(bufs[0]))
+ webgpu.wgpuBufferDestroy(tmp_bufs[0])
+
  if wait:
- command_encoder.resolve_query_set(query_set=query_set, first_query=0, query_count=2, destination=query_buf, destination_offset=0)
- self.dev.queue.submit([command_encoder.finish()])
- return ((timestamps:=self.dev.queue.read_buffer(query_buf).cast("Q").tolist())[1] - timestamps[0]) / 1e9 if wait else None
+ time = ((timestamps:=read_buffer(self.dev, query_buf).cast("Q").tolist())[1] - timestamps[0]) / 1e9
+ webgpu.wgpuBufferDestroy(query_buf)
+ webgpu.wgpuQuerySetDestroy(query_set)
+ return time

- # WebGPU buffers have to be 4-byte aligned
  class WebGpuAllocator(Allocator):
  def __init__(self, dev): self.dev = dev
  def _alloc(self, size: int, options):
- return self.dev.create_buffer(size=round_up(size, 4), usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)
+ # WebGPU buffers have to be 4-byte aligned
+ return webgpu.wgpuDeviceCreateBuffer(self.dev, webgpu.WGPUBufferDescriptor(size=round_up(size, 4),
+ usage=webgpu.WGPUBufferUsage_Storage | webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_CopySrc))
  def _copyin(self, dest, src: memoryview):
  if src.nbytes % 4:
  padded_src = bytearray(round_up(src.nbytes, 4))
  padded_src[:src.nbytes] = src
- self.dev.queue.write_buffer(dest, 0, padded_src if src.nbytes % 4 else src)
+ write_buffer(self.dev, dest, 0, padded_src if src.nbytes % 4 else src)
  def _copyout(self, dest: memoryview, src):
- buffer_data = self.dev.queue.read_buffer(src, 0)
- dest[:] = buffer_data[:dest.nbytes] if src._nbytes > dest.nbytes else buffer_data
+ buffer_data = read_buffer(self.dev, src)
+ dest[:] = buffer_data[:dest.nbytes] if webgpu.wgpuBufferGetSize(src) > dest.nbytes else buffer_data
+ def _free(self, opaque, options):
+ webgpu.wgpuBufferDestroy(opaque)

  class WebGpuDevice(Compiled):
  def __init__(self, device:str):
- adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance")
- timestamp_supported = wgpu.FeatureName.timestamp_query in adapter.features
- wgpu_device = adapter.request_device_sync(required_features=[wgpu.FeatureName.timestamp_query] if timestamp_supported else [])
- super().__init__(device, WebGpuAllocator(wgpu_device), WGSLRenderer(), Compiler(),
- functools.partial(WebGPUProgram, (wgpu_device, timestamp_supported)))
+ # Requesting an adapter
+ adapter_res = _run(webgpu.wgpuInstanceRequestAdapterF, webgpu.WGPURequestAdapterCallbackInfo, webgpu.WGPURequestAdapterCallback,
+ webgpu.WGPURequestAdapterStatus__enumvalues, 1, 2, instance,
+
+ webgpu.WGPURequestAdapterOptions(powerPreference=webgpu.WGPUPowerPreference_HighPerformance,
+ backendType=backend_types.get(os.getenv("WEBGPU_BACKEND", ""), 0)))
+
+ # Get supported features
+ supported_features = webgpu.WGPUSupportedFeatures()
+ webgpu.wgpuAdapterGetFeatures(adapter_res, supported_features)
+ supported = [supported_features.features[i] for i in range(supported_features.featureCount)]
+ features = [feat for feat in [webgpu.WGPUFeatureName_TimestampQuery, webgpu.WGPUFeatureName_ShaderF16] if feat in supported]
+ dev_desc = webgpu.WGPUDeviceDescriptor(requiredFeatureCount=len(features),requiredFeatures=(webgpu.WGPUFeatureName * len(features))(*features))
+
+ # Limits
+ supported_limits = webgpu.WGPUSupportedLimits()
+ webgpu.wgpuAdapterGetLimits(adapter_res, ctypes.cast(ctypes.pointer(supported_limits),ctypes.POINTER(webgpu.struct_WGPUSupportedLimits)))
+ limits = webgpu.WGPURequiredLimits(limits=supported_limits.limits)
+ dev_desc.requiredLimits = ctypes.cast(ctypes.pointer(limits),ctypes.POINTER(webgpu.struct_WGPURequiredLimits))
+
+ # Requesting a device
+ device_res = _run(webgpu.wgpuAdapterRequestDeviceF, webgpu.WGPURequestDeviceCallbackInfo, webgpu.WGPURequestDeviceCallback,
+ webgpu.WGPURequestDeviceStatus__enumvalues, 1, 2, adapter_res, dev_desc)
+
+ super().__init__(device, WebGpuAllocator(device_res), WGSLRenderer(), Compiler(),
+ functools.partial(WebGPUProgram, (device_res, webgpu.WGPUFeatureName_TimestampQuery in supported)))
+
+ def synchronize(self):
+ _run(webgpu.wgpuQueueOnSubmittedWorkDone2, webgpu.WGPUQueueWorkDoneCallbackInfo2, webgpu.WGPUQueueWorkDoneCallback2,
+ webgpu.WGPUQueueWorkDoneStatus__enumvalues, None, None, webgpu.wgpuDeviceGetQueue(self.runtime.args[0][0]))
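
Nearly every Dawn entry point used above (`wgpuInstanceRequestAdapterF`, `wgpuAdapterRequestDeviceF`, `wgpuBufferMapAsync2`, ...) is asynchronous: it takes a callback-info struct and returns a WGPUFuture that `_run` blocks on via `wgpuInstanceWaitAny`, then it picks the interesting callback argument out by index (`res_idx`) and decodes an optional message (`msg_idx`). A library-free sketch of the same callback-to-synchronous bridge; `fake_request_adapter` is a made-up stand-in for the generated bindings:

    from typing import Any, Callable, List

    def run_sync(async_fun: Callable, res_idx: int, *params) -> Any:
        result: List[Any] = []
        def cb(*cb_params): result[:] = cb_params  # stash whatever the callback was invoked with
        async_fun(*params, cb)  # stands in for: wait on the WGPUFuture the real call returns
        if result[0] != 1: raise RuntimeError(f"async call failed with status {result[0]}")
        return result[res_idx]

    # toy "async" API whose future is already resolved when it returns, like Dawn in WaitAnyOnly mode
    def fake_request_adapter(options, cb): cb(1, f"adapter-for-{options}")
    print(run_sync(fake_request_adapter, 1, "high-performance"))  # -> adapter-for-high-performance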