tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,482 @@
1
+ # the REMOTE=1 device is a process boundary between the frontend and the runtime
2
+ # normally tinygrad is frontend <-> middleware <-> runtime <-> hardware
3
+ # with REMOTE tinygrad is frontend <-> middleware <-> RemoteDevice ///HTTP/// remote_server <-> runtime <-> hardware
4
+ # this client and server can be on the same machine, same network, or just same internet
5
+ # it should be a secure boundary (e.g. no use of pickle). HTTP is used for RPC
6
+
7
+ from __future__ import annotations
8
+ from typing import Callable, Iterator, Any, cast
9
+ from collections import defaultdict
10
+ from dataclasses import dataclass, field, replace
11
+ import multiprocessing, threading, functools, itertools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref
12
+ import traceback, builtins
13
+ from tinygrad.renderer import Renderer, ProgramSpec
14
+ from tinygrad.dtype import DTYPES_DICT, dtypes
15
+ from tinygrad.uop.ops import UOp, Ops, Variable, sint
16
+ from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, LazySeq, Timing
17
+ from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class
18
+ from tinygrad.engine.realize import CompiledRunner, BufferXfer
19
+ from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec
20
+ from tinygrad.runtime.support.ib import IBCtx, IBConn, SGE
21
+
22
+ # ***** API *****
23
+
24
+ @dataclass(frozen=True)
25
+ class SessionKey: host: str; idx: int; nonce: str # noqa: E702
26
+
27
+ @dataclass(frozen=True)
28
+ class RemoteRequest: session: SessionKey|None = field(default=None, kw_only=True)
29
+
30
+ @dataclass(frozen=True)
31
+ class SessionFree(RemoteRequest): pass
32
+
33
+ @dataclass(frozen=True)
34
+ class RemoteProperties:
35
+ real_device: str
36
+ renderer: tuple[str, str, tuple[Any, ...]]
37
+ offset_supported: bool
38
+ graph_supported: bool
39
+ graph_supports_multi: bool
40
+ ib_gid: bytes|None
41
+
42
+ @dataclass(frozen=True)
43
+ class RemoteException:
44
+ exc: Exception
45
+ trace: str = ""
46
+
47
+ @dataclass(frozen=True)
48
+ class GetProperties(RemoteRequest): pass
49
+
50
+ @dataclass(frozen=True)
51
+ class Event(RemoteRequest): event_session: SessionKey; event: int # noqa: E702
52
+
53
+ @dataclass(frozen=True)
54
+ class Wait(RemoteRequest): event: int
55
+
56
+ @dataclass(frozen=True)
57
+ class IBConnect(RemoteRequest): host: str; gid: bytes; qp_num: int # noqa: E702
58
+
59
+ @dataclass(frozen=True)
60
+ class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
61
+
62
+ @dataclass(frozen=True)
63
+ class BufferOffset(RemoteRequest): buffer_num: int; size: int; offset: int; sbuffer_num: int # noqa: E702
64
+
65
+ @dataclass(frozen=True)
66
+ class BufferIOVAS(RemoteRequest): buffer_nums: list[tuple[SessionKey, int]] # noqa: E702
67
+
68
+ @dataclass(frozen=True)
69
+ class BufferFree(RemoteRequest): buffer_num: int # noqa: E702
70
+
71
+ @dataclass(frozen=True)
72
+ class CopyIn(RemoteRequest): buffer_num: int; datahash: str # noqa: E702
73
+
74
+ @dataclass(frozen=True)
75
+ class CopyOut(RemoteRequest): buffer_num: int
76
+
77
+ @dataclass(frozen=True)
78
+ class Transfer(RemoteRequest): buffer_num: int; dsession: SessionKey; dbuffer_num: int # noqa: E702
79
+
80
+ @dataclass(frozen=True)
81
+ class BatchTransfer(RemoteRequest):
82
+ sbuffer_nums: list[tuple[SessionKey, int]]
83
+ dbuffer_nums: list[tuple[SessionKey, int]]
84
+
85
+ @dataclass(frozen=True)
86
+ class ProgramAlloc(RemoteRequest): name: str; datahash: str # noqa: E702
87
+
88
+ @dataclass(frozen=True)
89
+ class ProgramFree(RemoteRequest): name: str; datahash: str # noqa: E702
90
+
91
+ @dataclass(frozen=True)
92
+ class ProgramExec(RemoteRequest):
93
+ name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702
94
+ global_size: tuple[int, ...]|None; local_size: tuple[int, ...]|None; wait: bool # noqa: E702
95
+
96
+ @dataclass(frozen=True)
97
+ class GraphComputeItem:
98
+ session: SessionKey
99
+ name: str
100
+ datahash: str
101
+ bufs: tuple[int, ...]
102
+ vars: tuple[Variable, ...]
103
+ fixedvars: dict[Variable, int]
104
+ ins: tuple[int, ...]
105
+ outs: tuple[int, ...]
106
+ global_size: tuple[sint, ...]|None
107
+ local_size: tuple[sint, ...]|None
108
+
109
+ @dataclass(frozen=True)
110
+ class GraphAlloc(RemoteRequest):
111
+ graph_num: int
112
+ jit_cache: tuple[GraphComputeItem|Transfer, ...]
113
+ bufs: tuple[tuple[SessionKey, int], ...]
114
+ var_vals: dict[Variable, int]
115
+
116
+ @dataclass(frozen=True)
117
+ class GraphFree(RemoteRequest):
118
+ graph_num: int
119
+
120
+ @dataclass(frozen=True)
121
+ class GraphExec(RemoteRequest):
122
+ graph_num: int
123
+ bufs: tuple[tuple[SessionKey, int], ...]
124
+ var_vals: dict[Variable, int]
125
+ wait: bool
126
+
# ***** safe deserialization *****
# Requests and replies travel as repr() text. Instead of eval(), the text is parsed with ast and only the node types in
# eval_fxns, the names in eval_globals and the attributes in attribute_whitelist are interpreted.

# builtin exception classes (excluding *Warning) so RemoteException payloads can be rebuilt on the receiving side
eval_excs = [v for k,v in builtins.__dict__.items() if isinstance(v, type) and issubclass(v, Exception) and not k.endswith("Warning")]
# the only names an ast.Name node may resolve to; each of them may also be invoked by an ast.Call node
eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, Event, Wait, BufferAlloc, BufferOffset, BufferIOVAS,
                                       BufferFree, CopyIn, CopyOut, Transfer, BatchTransfer, IBConnect, ProgramAlloc, ProgramFree, ProgramExec,
                                       GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes, RemoteException] + eval_excs}
# object -> attribute names that an ast.Attribute node may read from it
attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
# ast node type -> interpreter; any other node type is rejected by the KeyError from the lookup in safe_eval
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
             ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
             ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
             ast.Name: lambda x: eval_globals[x.id], ast.Attribute: lambda x: safe_getattr(safe_eval(x.value), x.attr)}

def safe_getattr(value, attr):
  """Return getattr(value, attr) only if `attr` is whitelisted for `value` in attribute_whitelist.

  Raises:
    AssertionError: if the attribute is not whitelisted. This is an explicit raise rather than an `assert` statement
    because `assert` is stripped under `python -O`, which would silently remove this security check.
  """
  try: allowed = attribute_whitelist.get(value, set())
  except TypeError: allowed = set()  # unhashable values (lists, dicts, ...) can never be whitelist keys
  if attr not in allowed: raise AssertionError(f'getattr({value}, {repr(attr)}) is not whitelisted')
  return getattr(value, attr)

def safe_eval(node):
  """Evaluate an ast expression node using only the constructs in eval_fxns (KeyError for anything else)."""
  return eval_fxns[node.__class__](node)
141
+
class BatchRequest:
  """A batch of RemoteRequests plus the binary blobs they reference, packed into one HTTP body.

  Wire format: a sequence of records `sha256 digest (32 bytes) | little-endian u64 byte length | payload`.
  serialize() always appends the repr() of the request list as the final record, which deserialize() parses back.
  """
  def __init__(self):
    self._q: list[RemoteRequest] = []
    # hex digest -> full record (when sending) or raw payload (after deserialize)
    self._h: dict[str, bytes] = {}

  def h(self, d:bytes|memoryview) -> str:
    """Register blob `d` for sending (deduplicated by content) and return its hex sha256 digest."""
    datahash = hashlib.sha256(d).hexdigest() # NOTE: this is very slow, should use blake3 on gpu instead
    if datahash not in self._h:
      payload = bytes(d)
      # use the byte length of the payload: len() of a non-'B' memoryview counts elements, not bytes
      self._h[datahash] = bytes.fromhex(datahash)+struct.pack("<Q", len(payload))+payload
    return datahash

  def q(self, x:RemoteRequest): self._q.append(x)

  def serialize(self) -> bytes:
    """Append the request list as the final record and return all records concatenated."""
    self.h(repr(self._q).encode())
    return b''.join(self._h.values())

  def deserialize(self, dat:bytes) -> BatchRequest:
    """Parse bytes produced by serialize(): payloads go into _h, the last record is parsed into _q.

    Raises:
      ValueError: if `dat` is empty or a record header/payload is truncated.
    """
    ptr, datahash = 0, None
    while ptr < len(dat):
      if ptr+0x28 > len(dat): raise ValueError(f"truncated batch record header at offset {ptr}")
      datahash, datalen = dat[ptr:ptr+0x20].hex(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
      if ptr+0x28+datalen > len(dat): raise ValueError(f"truncated batch record payload at offset {ptr}")
      self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
      ptr += 0x28+datalen
    if datahash is None: raise ValueError("empty batch request")
    # the command list is always the last record
    self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
    return self
163
+
164
+ # ***** backend *****
165
+
@dataclass
class RemoteSession:
  """All server-side state owned by one client session (one SessionKey)."""
  # (name, datahash) -> program object created by the device runtime on ProgramAlloc
  programs: dict[tuple[str, str], Any] = field(default_factory=dict)
  # graph_num -> graph runner built by GraphAlloc
  graphs: dict[int, GraphRunner] = field(default_factory=dict)
  # buffer_num -> device buffer
  buffers: dict[int, Buffer] = field(default_factory=dict)
  # event number -> asyncio.Event, created lazily by whichever of Event/Wait touches it first
  events: defaultdict[int, asyncio.Event] = field(default_factory=lambda: defaultdict(asyncio.Event))
172
+
173
+ class RemoteHandler:
174
+ def __init__(self, base_device: str):
175
+ self.base_device = base_device
176
+ self.sessions: defaultdict[SessionKey, RemoteSession] = defaultdict(RemoteSession)
177
+
178
+ try: self.ib_ctx: IBCtx|None = IBCtx(getenv("IB_DEV", 0))
179
+ except (IndexError, AttributeError): self.ib_ctx = None
180
+ self.ib_lock = asyncio.Lock()
181
+ self.ib_conns: dict[str, IBConn|None] = {}
182
+ self.iova_cache: dict[tuple[SessionKey, int], tuple[int, int, int]] = {}
183
+
184
+ async def __call__(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
185
+ while (req_hdr:=(await reader.readline()).decode().strip()):
186
+ req_method, req_path, _ = req_hdr.split(' ')
187
+ req_headers = {}
188
+ while (hdr:=(await reader.readline()).decode().strip()):
189
+ key, value = hdr.split(':', 1)
190
+ req_headers[key.lower()] = value.strip()
191
+ req_body = await reader.readexactly(int(req_headers.get("content-length", "0")))
192
+ try: res_status, res_body = await self.handle(req_method, req_path, req_body)
193
+ except Exception as e:
194
+ res_status, res_body = http.HTTPStatus.INTERNAL_SERVER_ERROR, repr(RemoteException(e, traceback.format_exc())).encode()
195
+ print(f"{traceback.format_exc()}", flush=True)
196
+ writer.write(f"HTTP/1.1 {res_status.value} {res_status.phrase}\r\nContent-Length: {len(res_body)}\r\n\r\n".encode() + res_body)
197
+
198
+ async def ib_connect(self, ssession:SessionKey, dsession:SessionKey) -> IBConn|None:
199
+ if self.ib_ctx is None: return None
200
+ await self.ib_lock.acquire()
201
+ conn = RemoteConnection(dsession.host)
202
+ if dsession.host not in self.ib_conns:
203
+ props = safe_eval(ast.parse(conn.q(GetProperties(session=dsession), wait=True), mode="eval").body)
204
+ if props.ib_gid is not None:
205
+ self.ib_conns[dsession.host] = ib_conn = IBConn(self.ib_ctx)
206
+ ibxc_ret = conn.q(IBConnect(ssession.host, ib_conn.gid, ib_conn.qp_num, session=dsession), wait=True)
207
+ ib_conn.connect(*struct.unpack('<16sQ', ibxc_ret))
208
+ else:
209
+ self.ib_conns[dsession.host] = None
210
+ self.ib_lock.release()
211
+ return self.ib_conns[dsession.host]
212
+
213
+ async def get_iovas(self, bufs:list[tuple[SessionKey, int]]) -> list[tuple[int, int, int]]:
214
+ await self.ib_lock.acquire()
215
+ if (rbufs:=[buf for buf in bufs if buf not in self.iova_cache]):
216
+ conn = RemoteConnection(rbufs[0][0].host)
217
+ resp = await conn.aq(BufferIOVAS(rbufs, session=rbufs[0][0]), wait=True)
218
+ self.iova_cache.update({rbuf: struct.unpack('<QQQ', resp[i*24:(i+1)*24]) for i,rbuf in enumerate(rbufs)})
219
+ self.ib_lock.release()
220
+ return [self.iova_cache[buf] for buf in bufs]
221
+
222
+ async def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
223
+ status, ret = http.HTTPStatus.OK, b""
224
+ if path == "/batch" and method == "POST":
225
+ # TODO: streaming deserialize?
226
+ req = BatchRequest().deserialize(body)
227
+ # the cmds are always last (currently in datahash)
228
+ for c in req._q:
229
+ if DEBUG >= 1: print(c)
230
+ session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session).idx}"]
231
+ match c:
232
+ case SessionFree(): del self.sessions[unwrap(c.session)]
233
+ case GetProperties():
234
+ cls, args = dev.renderer.__reduce__()
235
+ graph_cls = graph_class(Device[self.base_device])
236
+ rp = RemoteProperties(
237
+ real_device=dev.device, renderer=(cls.__module__, cls.__name__, args), offset_supported=hasattr(dev.allocator, '_offset'),
238
+ graph_supported=graph_cls is not None,
239
+ graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner) and hasattr(dev.allocator, '_transfer'),
240
+ ib_gid=bytes(self.ib_ctx.gid_attr.raw) if self.ib_ctx is not None else None,
241
+ )
242
+ ret = repr(rp).encode()
243
+ case Event():
244
+ if c.session == c.event_session:
245
+ session.events[c.event].set()
246
+ else:
247
+ for d in Device._opened_devices: Device[d].synchronize() # wait for device*s* to finish executing previous stuff
248
+ # TODO: don't wait, just send
249
+ await RemoteConnection(c.event_session.host).aq(Event(c.event_session, c.event, session=c.event_session), wait=True)
250
+ case Wait():
251
+ assert await session.events[c.event].wait()
252
+ del session.events[c.event] # do not leak memory
253
+ case IBConnect():
254
+ self.ib_conns[c.host] = ibc = IBConn(unwrap(self.ib_ctx))
255
+ ibc.connect(c.gid, c.qp_num)
256
+ ret = struct.pack('<16sQ', ibc.gid, ibc.qp_num)
257
+ case BufferAlloc():
258
+ assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
259
+ session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
260
+ case BufferIOVAS():
261
+ rets = []
262
+ for buffer_session,buffer_num in c.buffer_nums:
263
+ iova, mr = unwrap(self.ib_ctx).reg(buf:=self.sessions[buffer_session].buffers[buffer_num])
264
+ rets.append(struct.pack("<QQQ", iova, mr.contents.rkey, buf.nbytes))
265
+ ret = b"".join(rets)
266
+ case BufferOffset():
267
+ assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already exists"
268
+ session.buffers[c.buffer_num] = session.buffers[c.sbuffer_num].view(c.size, dtypes.uint8, c.offset).allocate()
269
+ case BufferFree(): del session.buffers[c.buffer_num]
270
+ case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
271
+ case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
272
+ case Transfer():
273
+ if c.dsession.host == unwrap(c.session).host:
274
+ dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
275
+ dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
276
+ if hasattr(ddev.allocator, '_transfer'):
277
+ assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
278
+ ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
279
+ else:
280
+ sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
281
+ dbuf.copyin(data)
282
+ else:
283
+ conn, ib_conn = RemoteConnection(c.dsession.host), await self.ib_connect(unwrap(c.session), c.dsession)
284
+ sbuf = session.buffers[c.buffer_num]
285
+ if ib_conn is not None:
286
+ src_iova, src_mr = unwrap(self.ib_ctx).reg(sbuf)
287
+ dst_iova, dst_key, dst_size = (await self.get_iovas([(c.dsession, c.dbuffer_num)]))[0]
288
+ assert sbuf.nbytes == dst_size, f"{sbuf.nbytes} != {dst_size}"
289
+ for d in Device._opened_devices: Device[d].synchronize()
290
+ ib_conn.rdma_write([SGE(dst_iova, dst_key, src_iova, src_mr.contents.lkey, dst_size)])
291
+ else:
292
+ sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
293
+ await conn.aq(CopyIn(c.dbuffer_num, conn.req.h(data), session=c.dsession), wait=True)
294
+ case BatchTransfer():
295
+ conn, ib_conn = RemoteConnection(c.dbuffer_nums[0][0].host), await self.ib_connect(c.sbuffer_nums[0][0], c.dbuffer_nums[0][0])
296
+ if ib_conn is not None:
297
+ sbufs = [unwrap(self.ib_ctx).reg(self.sessions[s].buffers[bi]) for s,bi in c.sbuffer_nums]
298
+ dbufs = await self.get_iovas(c.dbuffer_nums)
299
+ for d in Device._opened_devices: Device[d].synchronize()
300
+ ib_conn.rdma_write([SGE(di, dk, si, sm.contents.lkey, ds) for (di,dk,ds),(si,sm) in zip(dbufs, sbufs)])
301
+ else:
302
+ for (sbuf_session,sbuf_num),(dbuf_session,dbuf_num) in zip(c.sbuffer_nums, c.dbuffer_nums):
303
+ sbuf = self.sessions[sbuf_session].buffers[sbuf_num]
304
+ sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
305
+ await conn.aq(CopyIn(dbuf_num, conn.req.h(data), session=dbuf_session), wait=True)
306
+ case ProgramAlloc():
307
+ lib = dev.compiler.compile_cached(req._h[c.datahash].decode())
308
+ session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib)
309
+ case ProgramFree(): del session.programs[(c.name, c.datahash)]
310
+ case ProgramExec():
311
+ bufs = [session.buffers[x]._buf for x in c.bufs]
312
+ extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
313
+ r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
314
+ if r is not None: ret = str(r).encode()
315
+ case GraphAlloc():
316
+ graph_fn: Callable = unwrap(dev.graph)
317
+ def _parse_ji(gi: GraphComputeItem|Transfer):
318
+ match gi:
319
+ case GraphComputeItem():
320
+ prg = self.sessions[gi.session].programs[(gi.name, gi.datahash)]
321
+ ps = ProgramSpec(gi.name, '', f"{self.base_device}:{gi.session.idx}", UOp(Ops.NOOP),
322
+ vars=list(gi.vars), ins=list(gi.ins), outs=list(gi.outs),
323
+ global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None,
324
+ local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None)
325
+ return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [self.sessions[gi.session].buffers[buf] for buf in gi.bufs],
326
+ fixedvars=gi.fixedvars)
327
+ case Transfer():
328
+ dbuf, sbuf = self.sessions[gi.dsession].buffers[gi.dbuffer_num], self.sessions[unwrap(gi.session)].buffers[gi.buffer_num]
329
+ assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
330
+ return ExecItem(BufferXfer(dbuf.nbytes, dbuf.device, sbuf.device), [dbuf, sbuf])
331
+ assert c.graph_num not in session.graphs, f"graph {c.graph_num} already allocated"
332
+ session.graphs[c.graph_num] = graph_fn(list(map(_parse_ji, c.jit_cache)), [self.sessions[s].buffers[i] for s,i in c.bufs], c.var_vals)
333
+ case GraphFree(): del session.graphs[c.graph_num]
334
+ case GraphExec():
335
+ r = session.graphs[c.graph_num]([self.sessions[s].buffers[i] for s,i in c.bufs], c.var_vals, wait=c.wait)
336
+ if r is not None: ret = str(r).encode()
337
+ else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found"
338
+ return status, ret
339
+
def remote_server(port:int):
  """Run a blocking asyncio server on `port` that executes REMOTE requests on a real device.

  The device comes from REMOTEDEV; without it, the default device is used, or the first available device when the
  default device is REMOTE itself.
  """
  if Device.DEFAULT == "REMOTE": fallback = next(Device.get_available_devices())
  else: fallback = Device.DEFAULT
  device = getenv("REMOTEDEV", fallback)

  async def _serve(port:int, device:str):
    print(f"start remote server on {port} with device {device}")
    server = await asyncio.start_server(RemoteHandler(device), host='', port=port)
    await server.serve_forever()

  asyncio.run(_serve(port, device))
346
+
347
+ # ***** frontend *****
348
+
class RemoteAllocator(Allocator['RemoteDevice']):
  """Allocator for buffers that live on the remote server; an opaque buffer handle is a per-session integer id."""
  def __init__(self, dev:RemoteDevice):
    # only expose _offset (sub-buffer views) when the server's allocator has one
    if dev.properties.offset_supported: self._offset = self._dyn_offset
    super().__init__(dev)

  # TODO: ideally we shouldn't have to deal with images here
  def _alloc(self, size:int, options:BufferSpec) -> int:
    """Queue a BufferAlloc under a fresh buffer id and return that id."""
    num = next(self.dev.buffer_num)
    self.dev.q(BufferAlloc(num, size, options))
    return num

  # TODO: options should not be here in any Allocator
  def _free(self, opaque:int, options):
    """Queue a BufferFree, best-effort: TypeError/AttributeError are ignored (presumably frees during teardown)."""
    with contextlib.suppress(TypeError, AttributeError): self.dev.q(BufferFree(opaque))

  def _copyin(self, dest:int, src:memoryview):
    """Queue a CopyIn; the data itself travels as a blob of the next batch, referenced by its hash."""
    datahash = self.dev.conn.req.h(src)
    self.dev.q(CopyIn(dest, datahash))

  def _copyout(self, dest:memoryview, src:int):
    """Synchronously fetch buffer `src` from the server into `dest`."""
    data = self.dev.q(CopyOut(src), wait=True)
    assert len(data) == len(dest), f"buffer length mismatch {len(data)} != {len(dest)}"
    dest[:] = data

  def _transfer(self, dest, src, sz, src_dev, dest_dev):
    """Copy buffer `src` of src_dev into buffer `dest` of dest_dev; the source server executes the Transfer.

    When the devices use different connections, an Event/Wait pair before the transfer makes the source wait on the
    destination, and another pair after it makes the destination wait on the source.
    """
    separate_conns = dest_dev.conn != src_dev.conn
    if separate_conns:
      start_event = next(src_dev.event_num)
      dest_dev.q(Event(src_dev.session, start_event))
      src_dev.q(Wait(start_event))
    src_dev.q(Transfer(src, dest_dev.session, dest))
    if separate_conns:
      end_event = next(dest_dev.event_num)
      src_dev.q(Event(dest_dev.session, end_event))
      dest_dev.q(Wait(end_event))
    if DEBUG >= 2: dest_dev.conn.batch_submit()

  def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
    """Queue a BufferOffset making a view of `opaque` under a fresh buffer id and return that id."""
    num = next(self.dev.buffer_num)
    self.dev.q(BufferOffset(num, size, offset, opaque))
    return num
378
+
class RemoteProgram:
  """Handle for a program loaded on the remote server, identified there by (name, hash of `lib`)."""
  def __init__(self, dev:RemoteDevice, name:str, lib:bytes):
    self.dev, self.name = dev, name
    # `lib` rides along as a blob of the next batch; the server compiles it when it handles ProgramAlloc
    self.datahash = self.dev.conn.req.h(lib)
    self.dev.q(ProgramAlloc(self.name, self.datahash))
    super().__init__()
    # free the server-side program once this handle is garbage collected
    weakref.finalize(self, self._fini, self.dev, self.name, self.datahash)

  @staticmethod
  def _fini(dev:RemoteDevice, name:str, datahash:str):
    dev.q(ProgramFree(name, datahash))

  def __call__(self, *bufs, global_size=None, local_size=None, vals:tuple[int, ...]=(), wait=False):
    """Queue a ProgramExec; with wait=True, submit now and return the server's result as a float (presumably a time)."""
    reply = self.dev.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait), wait=wait)
    return float(reply) if wait else None
393
+
@functools.cache
class RemoteConnection:
  """One keep-alive HTTP connection per host; the class is cached, so RemoteConnection(host) returns a singleton.

  Requests are queued into a BatchRequest and only sent when a reply is needed (wait=True) or batch_submit is called.
  A submit flushes the queues of *all* connections, presumably so that Event/Wait pairs queued on different hosts are
  delivered together.
  """
  q_lock = threading.Lock()  # guards every connection's queue; not reentrant, hence batch_submit's take_q flag
  all: dict[RemoteConnection, None] = {} # dict instead of set for deterministic ordering

  def __init__(self, host:str):
    if DEBUG >= 1: print(f"remote with host {host}")
    # keep retrying until the server accepts (e.g. a freshly spawned local server may not be listening yet)
    while 1:
      try:
        self.conn = http.client.HTTPConnection(host, timeout=getenv("REMOTE_TIMEOUT", 300.0))
        self.conn.connect()
        break
      except Exception as e:
        print(e)
        time.sleep(0.1)
    self.req: BatchRequest = BatchRequest()
    RemoteConnection.all[self] = None

  def q(self, x:RemoteRequest, wait:bool=False):
    """Queue `x`; with wait=True, also submit every pending batch and return this connection's reply body."""
    with RemoteConnection.q_lock:
      self.req.q(x)
      if wait: return self.batch_submit(take_q=False)

  async def aq(self, x:RemoteRequest, wait:bool=False):
    """Async variant of q that runs the blocking call in a worker thread."""
    return await asyncio.to_thread(self.q, x, wait=wait)

  def batch_submit(self, take_q:bool=True):
    """Send the pending batch of every connection and return this connection's reply body.

    Args:
      take_q: acquire q_lock here; pass False when the caller already holds it (q_lock is not reentrant).
    Raises:
      the server's exception (with its traceback attached as a note) when a server answers 500, and AssertionError
      for any other non-200 status.
    The lock is released and every reply is read even when a request fails, so no connection is left with an unread
    response and later calls do not deadlock.
    """
    with (RemoteConnection.q_lock if take_q else contextlib.nullcontext()):
      conns = RemoteConnection.all.keys()
      datas = {conn: conn.req.serialize() for conn in conns}
      reqs, hashes, hash_datas = sum(len(c.req._q) for c in conns), sum(len(c.req._h) for c in conns), sum(len(data) for data in datas.values())
      with Timing(f"*** send {reqs:-3d} requests {hashes:-3d} hashes with len {hash_datas/1024:.2f} kB in ", enabled=DEBUG>=3):
        for conn,data in datas.items(): conn.conn.request("POST", "/batch", data)
        # read every reply before raising so each HTTPConnection is ready for its next request
        replies = {}
        for conn in datas.keys():
          response = conn.conn.getresponse()
          replies[conn] = (response.status, response.read())
          conn.req = BatchRequest() # no matter what response, reset conn
      ret = None
      for conn,(status,resp) in replies.items():
        if status == http.HTTPStatus.INTERNAL_SERVER_ERROR:
          exc_wrapper = safe_eval(ast.parse(resp.decode(), mode="eval").body)
          exc_wrapper.exc.add_note(exc_wrapper.trace)
          raise exc_wrapper.exc
        assert status == http.HTTPStatus.OK, f"POST /batch failed: {resp.decode()}"
        if conn == self: ret = resp
      return ret
438
+
def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]:
  """Parse HOST, a comma-separated list of `host[*count]` entries, into (host, device index) pairs.

  A single entry without a count maps every index lazily to that host. Otherwise each `host*count` expands to indices
  0..count-1, and every entry must carry a count.
  """
  hosts: list[tuple[str, int|None]] = []
  for entry in hs.split(","):
    name, count = (entry.split("*", maxsplit=1) + [None])[:2]
    hosts.append((unwrap(name), int(count) if count is not None else count))
  if len(hosts) == 1 and hosts[0][1] is None:
    only_host = hosts[0][0]
    return LazySeq(lambda idx: (only_host, idx))
  return [(name, i) for name, count in hosts for i in range(unwrap(count))]
443
+
class RemoteDevice(Compiled):
  """Frontend device that forwards all work to a remote_server over HTTP.

  `REMOTE:n` picks the n-th (host, index) pair parsed from the HOST env var; an empty host means a local server
  process is spawned on demand.
  """
  devices = parse_hosts(getenv("HOST", ""))

  def __init__(self, device:str):
    dev_idx = int(device.split(":")[1]) if ":" in device else 0
    host, idx = RemoteDevice.devices[dev_idx]

    # connection is shared between sessions on the same host; each device gets its own random session nonce
    server = host or RemoteDevice.local_server()
    self.session: SessionKey = SessionKey(server, idx, binascii.hexlify(os.urandom(0x10)).decode())
    self.conn: RemoteConnection = RemoteConnection(self.session.host)

    # per-session id counters
    self.buffer_num: Iterator[int] = itertools.count(0)
    self.graph_num: Iterator[int] = itertools.count(0)
    self.event_num: Iterator[int] = itertools.count(0)

    props_repr = self.q(GetProperties(), wait=True)
    self.properties: RemoteProperties = safe_eval(ast.parse(props_repr, mode="eval").body)
    if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")

    # TODO: how do we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
    renderer = self.properties.renderer
    mod_name, cls_name = renderer[0], renderer[1]
    # the server names the renderer class; only tinygrad.* classes whose name ends in "Renderer" are imported
    if not mod_name.startswith("tinygrad.") or not cls_name.endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
    renderer_class = fromimport(mod_name, cls_name) # TODO: is this secure?
    if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
    renderer_instance = renderer_class(*renderer[2])
    renderer_instance.device = device

    graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None
    super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph, id(self.conn))

  def finalize(self):
    """Tell the server to drop this session; connection/HTTP errors are ignored since the server may be gone."""
    try: self.q(SessionFree(), wait=True)
    except (ConnectionError, http.client.HTTPException): pass

  def q(self, x:RemoteRequest, wait:bool=False):
    """Queue `x` tagged with this device's session; see RemoteConnection.q for `wait`."""
    return self.conn.q(replace(x, session=self.session), wait=wait)

  @staticmethod
  @functools.cache
  def local_server():
    """Spawn a daemon remote_server process on port 6667 (once per process) and return its address."""
    multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start()
    return "127.0.0.1:6667"
481
+
if __name__ == "__main__":
  # standalone server; PORT defaults to 6667
  remote_server(getenv("PORT", 6667))
@@ -1,35 +1,34 @@
1
1
  import functools, struct
2
- from tinygrad.device import Compiled, Allocator, Compiler
2
+ from tinygrad.device import Compiled, Allocator, Compiler, BufferSpec
3
3
  from tinygrad.renderer.wgsl import WGSLRenderer
4
- from tinygrad.helpers import round_up, OSX
4
+ from tinygrad.helpers import round_up, suppress_finalizing
5
5
  from tinygrad.runtime.autogen import webgpu
6
- from typing import List, Any
6
+ from typing import List, Any, TypeAlias
7
7
  import ctypes
8
8
  import os
9
9
 
10
+ WGPUDevPtr: TypeAlias = webgpu.WGPUDevice # type: ignore
11
+ WGPUBufPtr: TypeAlias = webgpu.WGPUBuffer # type: ignore
12
+
10
13
  backend_types = {v: k for k, v in webgpu.WGPUBackendType__enumvalues.items() }
11
14
 
12
- try:
13
- instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
14
- except AttributeError:
15
- raise RuntimeError("Cannot find dawn library. Install it with: " + ("brew tap wpmed92/dawn && brew install dawn" if OSX else
16
- "sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so"))
15
+ instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
17
16
 
18
- def to_c_string(_str): return ctypes.create_string_buffer(_str.encode('utf-8'))
17
+ def to_c_string(_str:str) -> ctypes.Array: return ctypes.create_string_buffer(_str.encode('utf-8'))
19
18
 
20
- def from_wgpu_str(string_view): return ctypes.string_at(string_view.data, string_view.length).decode("utf-8")
19
+ def from_wgpu_str(string_view:webgpu.struct_WGPUStringView) -> str: return ctypes.string_at(string_view.data, string_view.length).decode("utf-8")
21
20
 
22
- def to_wgpu_str(_str):
21
+ def to_wgpu_str(_str:str) -> webgpu.struct_WGPUStringView:
23
22
  return webgpu.WGPUStringView(data=ctypes.cast(ctypes.pointer(to_c_string(_str)), ctypes.POINTER(ctypes.c_char)), length=len(_str))
24
23
 
25
- def _wait(future):
24
+ def _wait(future:webgpu.struct_WGPUFuture):
26
25
  assert webgpu.wgpuInstanceWaitAny(instance, 1, webgpu.WGPUFutureWaitInfo(future=future), 2**64-1) == webgpu.WGPUWaitStatus_Success, "Future failed"
27
26
 
28
- def write_buffer(device, buf, offset, src):
27
+ def write_buffer(device:WGPUDevPtr, buf:WGPUBufPtr, offset:int, src:memoryview|bytearray|bytes):
29
28
  src = bytearray(src)
30
29
  webgpu.wgpuQueueWriteBuffer(webgpu.wgpuDeviceGetQueue(device), buf, offset, (ctypes.c_uint8 * len(src)).from_buffer(src), len(src))
31
30
 
32
- def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx, msg_idx, *params):
31
+ def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx:int|None, msg_idx:int|None, *params):
33
32
  result: List[Any] = []
34
33
 
35
34
  def cb(*params):
@@ -42,7 +41,7 @@ def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx, msg_idx, *param
42
41
  if result[0] != 1: raise RuntimeError(f"[{status_enum[result[0]] if status_enum else 'ERROR'}]{result[msg_idx] if msg_idx else ''}")
43
42
  return result[res_idx] if res_idx else None
44
43
 
45
- def copy_buffer_to_buffer(dev, src, src_offset, dst, dst_offset, size):
44
+ def copy_buffer_to_buffer(dev:WGPUDevPtr, src:WGPUBufPtr, src_offset:int, dst:WGPUBufPtr, dst_offset:int, size:int):
46
45
  encoder = webgpu.wgpuDeviceCreateCommandEncoder(dev, webgpu.WGPUCommandEncoderDescriptor())
47
46
  webgpu.wgpuCommandEncoderCopyBufferToBuffer(encoder, src, src_offset, dst, dst_offset, size)
48
47
  cb = webgpu.wgpuCommandEncoderFinish(encoder, webgpu.WGPUCommandBufferDescriptor())
@@ -50,7 +49,7 @@ def copy_buffer_to_buffer(dev, src, src_offset, dst, dst_offset, size):
50
49
  webgpu.wgpuCommandBufferRelease(cb)
51
50
  webgpu.wgpuCommandEncoderRelease(encoder)
52
51
 
53
- def read_buffer(dev, buf):
52
+ def read_buffer(dev:WGPUDevPtr, buf:WGPUBufPtr) -> memoryview:
54
53
  size = webgpu.wgpuBufferGetSize(buf)
55
54
  tmp_buffer = webgpu.wgpuDeviceCreateBuffer(dev, webgpu.WGPUBufferDescriptor(size=size,
56
55
  usage=webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_MapRead, mappedAtCreation=False))
@@ -63,17 +62,17 @@ def read_buffer(dev, buf):
63
62
  webgpu.wgpuBufferDestroy(tmp_buffer)
64
63
  return memoryview(buf_copy).cast("B")
65
64
 
66
- def pop_error(device):
65
+ def pop_error(device:WGPUDevPtr) -> str:
67
66
  return _run(webgpu.wgpuDevicePopErrorScopeF, webgpu.WGPUPopErrorScopeCallbackInfo, webgpu.WGPUPopErrorScopeCallback, None, 2, 2, device)
68
67
 
69
- def create_uniform(wgpu_device, val):
68
+ def create_uniform(wgpu_device:WGPUDevPtr, val:int|float) -> WGPUBufPtr:
70
69
  buf = webgpu.wgpuDeviceCreateBuffer(wgpu_device,
71
70
  webgpu.WGPUBufferDescriptor(size=4, usage=webgpu.WGPUBufferUsage_Uniform | webgpu.WGPUBufferUsage_CopyDst))
72
71
  write_buffer(wgpu_device, buf, 0, val.to_bytes(4, "little") if isinstance(val, int) else struct.pack('<f', val))
73
72
  return buf
74
73
 
75
74
  class WebGPUProgram:
76
- def __init__(self, dev, name:str, lib:bytes):
75
+ def __init__(self, dev:tuple[WGPUDevPtr, bool], name:str, lib:bytes):
77
76
  (self.dev, self.timestamp_supported) = dev
78
77
 
79
78
  # Creating shader module
@@ -89,14 +88,15 @@ class WebGPUProgram:
89
88
  if err := pop_error(self.dev): raise RuntimeError(f"Shader compilation failed: {err}")
90
89
 
91
90
  self.name, self.lib, self.prg = name, lib, shader_module
92
- def __call__(self, *bufs, global_size=(1,1,1), local_size=(1,1,1), vals=(), wait=False):
91
+ def __call__(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
92
+ vals:tuple[int, ...]=(), wait=False) -> float|None:
93
93
  wait = wait and self.timestamp_supported
94
94
  tmp_bufs = [*bufs]
95
95
  buf_patch = False
96
96
 
97
97
  # WebGPU does not allow using the same buffer for input and output
98
98
  for i in range(1, len(bufs)):
99
- if bufs[i] == bufs[0]:
99
+ if ctypes.addressof(bufs[i]) == ctypes.addressof(bufs[0]):
100
100
  tmp_bufs[0] = webgpu.wgpuDeviceCreateBuffer(self.dev,
101
101
  webgpu.WGPUBufferDescriptor(size=webgpu.wgpuBufferGetSize(bufs[0]), usage=webgpu.wgpuBufferGetUsage(bufs[0])))
102
102
  buf_patch = True
@@ -173,23 +173,23 @@ class WebGPUProgram:
173
173
  webgpu.wgpuBufferDestroy(query_buf)
174
174
  webgpu.wgpuQuerySetDestroy(query_set)
175
175
  return time
176
+ return None
176
177
 
177
- class WebGpuAllocator(Allocator):
178
- def __init__(self, dev): self.dev = dev
179
- def _alloc(self, size: int, options):
178
+ class WebGpuAllocator(Allocator['WGPUDevPtr']):
179
+ def _alloc(self, size:int, options:BufferSpec) -> WGPUBufPtr:
180
180
  # WebGPU buffers have to be 4-byte aligned
181
181
  return webgpu.wgpuDeviceCreateBuffer(self.dev, webgpu.WGPUBufferDescriptor(size=round_up(size, 4),
182
182
  usage=webgpu.WGPUBufferUsage_Storage | webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_CopySrc))
183
- def _copyin(self, dest, src: memoryview):
183
+ def _copyin(self, dest:WGPUBufPtr, src:memoryview):
184
184
  if src.nbytes % 4:
185
185
  padded_src = bytearray(round_up(src.nbytes, 4))
186
186
  padded_src[:src.nbytes] = src
187
187
  write_buffer(self.dev, dest, 0, padded_src if src.nbytes % 4 else src)
188
- def _copyout(self, dest: memoryview, src):
188
+ def _copyout(self, dest:memoryview, src:WGPUBufPtr):
189
189
  buffer_data = read_buffer(self.dev, src)
190
190
  dest[:] = buffer_data[:dest.nbytes] if webgpu.wgpuBufferGetSize(src) > dest.nbytes else buffer_data
191
- def _free(self, opaque, options):
192
- webgpu.wgpuBufferDestroy(opaque)
191
+ @suppress_finalizing
192
+ def _free(self, opaque:WGPUBufPtr, options:BufferSpec): webgpu.wgpuBufferDestroy(opaque)
193
193
 
194
194
  class WebGpuDevice(Compiled):
195
195
  def __init__(self, device:str):