tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
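Much of the file list above reflects a reorganization of tinygrad's internal modules: tinygrad/ops.py and tinygrad/spec.py move under tinygrad/uop/, kernel optimization moves from tinygrad/codegen/kernel.py and tinygrad/engine/search.py into tinygrad/codegen/opt/, and scheduling moves into tinygrad/schedule/. The public tinygrad.Tensor API is unchanged, but code importing these private modules directly needs new paths. A minimal sketch inferred from the moves in this listing (the symbol names Kernel and beam_search are illustrative, not taken from a migration guide):

# imports pinned to 0.10.2 internals              # likely 0.11.0 equivalents (illustrative)
from tinygrad.ops import UOp, Ops                 # -> from tinygrad.uop.ops import UOp, Ops
from tinygrad.codegen.kernel import Kernel        # -> from tinygrad.codegen.opt.kernel import Kernel
from tinygrad.engine.search import beam_search    # -> from tinygrad.codegen.opt.search import beam_search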
tinygrad/runtime/ops_remote.py
ADDED
@@ -0,0 +1,482 @@
+# the REMOTE=1 device is a process boundary between the frontend/runtime
+# normally tinygrad is frontend <-> middleware <-> runtime <-> hardware
+# with REMOTE tinygrad is frontend <-> middleware <-> RemoteDevice ///HTTP/// remote_server <-> runtime <-> hardware
+# this client and server can be on the same machine, same network, or just same internet
+# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
+
+from __future__ import annotations
+from typing import Callable, Iterator, Any, cast
+from collections import defaultdict
+from dataclasses import dataclass, field, replace
+import multiprocessing, threading, functools, itertools, asyncio, http, http.client, hashlib, time, os, binascii, struct, ast, contextlib, weakref
+import traceback, builtins
+from tinygrad.renderer import Renderer, ProgramSpec
+from tinygrad.dtype import DTYPES_DICT, dtypes
+from tinygrad.uop.ops import UOp, Ops, Variable, sint
+from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, LazySeq, Timing
+from tinygrad.engine.jit import GraphRunner, MultiGraphRunner, ExecItem, graph_class
+from tinygrad.engine.realize import CompiledRunner, BufferXfer
+from tinygrad.device import Compiled, Buffer, Allocator, Compiler, Device, BufferSpec
+from tinygrad.runtime.support.ib import IBCtx, IBConn, SGE
+
+# ***** API *****
+
+@dataclass(frozen=True)
+class SessionKey: host: str; idx: int; nonce: str # noqa: E702
+
+@dataclass(frozen=True)
+class RemoteRequest: session: SessionKey|None = field(default=None, kw_only=True)
+
+@dataclass(frozen=True)
+class SessionFree(RemoteRequest): pass
+
+@dataclass(frozen=True)
+class RemoteProperties:
+  real_device: str
+  renderer: tuple[str, str, tuple[Any, ...]]
+  offset_supported: bool
+  graph_supported: bool
+  graph_supports_multi: bool
+  ib_gid: bytes|None
+
+@dataclass(frozen=True)
+class RemoteException:
+  exc: Exception
+  trace: str = ""
+
+@dataclass(frozen=True)
+class GetProperties(RemoteRequest): pass
+
+@dataclass(frozen=True)
+class Event(RemoteRequest): event_session: SessionKey; event: int # noqa: E702
+
+@dataclass(frozen=True)
+class Wait(RemoteRequest): event: int
+
+@dataclass(frozen=True)
+class IBConnect(RemoteRequest): host: str; gid: bytes; qp_num: int # noqa: E702
+
+@dataclass(frozen=True)
+class BufferAlloc(RemoteRequest): buffer_num: int; size: int; options: BufferSpec # noqa: E702
+
+@dataclass(frozen=True)
+class BufferOffset(RemoteRequest): buffer_num: int; size: int; offset: int; sbuffer_num: int # noqa: E702
+
+@dataclass(frozen=True)
+class BufferIOVAS(RemoteRequest): buffer_nums: list[tuple[SessionKey, int]] # noqa: E702
+
+@dataclass(frozen=True)
+class BufferFree(RemoteRequest): buffer_num: int # noqa: E702
+
+@dataclass(frozen=True)
+class CopyIn(RemoteRequest): buffer_num: int; datahash: str # noqa: E702
+
+@dataclass(frozen=True)
+class CopyOut(RemoteRequest): buffer_num: int
+
+@dataclass(frozen=True)
+class Transfer(RemoteRequest): buffer_num: int; dsession: SessionKey; dbuffer_num: int # noqa: E702
+
+@dataclass(frozen=True)
+class BatchTransfer(RemoteRequest):
+  sbuffer_nums: list[tuple[SessionKey, int]]
+  dbuffer_nums: list[tuple[SessionKey, int]]
+
+@dataclass(frozen=True)
+class ProgramAlloc(RemoteRequest): name: str; datahash: str # noqa: E702
+
+@dataclass(frozen=True)
+class ProgramFree(RemoteRequest): name: str; datahash: str # noqa: E702
+
+@dataclass(frozen=True)
+class ProgramExec(RemoteRequest):
+  name: str; datahash: str; bufs: tuple[int, ...]; vals: tuple[int, ...] # noqa: E702
+  global_size: tuple[int, ...]|None; local_size: tuple[int, ...]|None; wait: bool # noqa: E702
+
+@dataclass(frozen=True)
+class GraphComputeItem:
+  session: SessionKey
+  name: str
+  datahash: str
+  bufs: tuple[int, ...]
+  vars: tuple[Variable, ...]
+  fixedvars: dict[Variable, int]
+  ins: tuple[int, ...]
+  outs: tuple[int, ...]
+  global_size: tuple[sint, ...]|None
+  local_size: tuple[sint, ...]|None
+
+@dataclass(frozen=True)
+class GraphAlloc(RemoteRequest):
+  graph_num: int
+  jit_cache: tuple[GraphComputeItem|Transfer, ...]
+  bufs: tuple[tuple[SessionKey, int], ...]
+  var_vals: dict[Variable, int]
+
+@dataclass(frozen=True)
+class GraphFree(RemoteRequest):
+  graph_num: int
+
+@dataclass(frozen=True)
+class GraphExec(RemoteRequest):
+  graph_num: int
+  bufs: tuple[tuple[SessionKey, int], ...]
+  var_vals: dict[Variable, int]
+  wait: bool
+
+# for safe deserialization
+eval_excs = [v for k,v in builtins.__dict__.items() if isinstance(v, type) and issubclass(v, Exception) and not k.endswith("Warning")]
+eval_globals = {x.__name__:x for x in [SessionKey, SessionFree, RemoteProperties, GetProperties, Event, Wait, BufferAlloc, BufferOffset, BufferIOVAS,
+                                       BufferFree, CopyIn, CopyOut, Transfer, BatchTransfer, IBConnect, ProgramAlloc, ProgramFree, ProgramExec,
+                                       GraphComputeItem, GraphAlloc, GraphFree, GraphExec, BufferSpec, UOp, Ops, dtypes, RemoteException] + eval_excs}
+attribute_whitelist: dict[Any, set[str]] = {dtypes: {*DTYPES_DICT.keys(), 'imagef', 'imageh'}, Ops: {x.name for x in Ops}}
+eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
+             ast.Dict: lambda x: {safe_eval(k):safe_eval(v) for k,v in zip(x.keys, x.values)},
+             ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
+             ast.Name: lambda x: eval_globals[x.id], ast.Attribute: lambda x: safe_getattr(safe_eval(x.value), x.attr)}
+def safe_getattr(value, attr):
+  assert attr in attribute_whitelist.get(value, set()), f'getattr({value}, {repr(attr)}) is not whitelisted'
+  return getattr(value, attr)
+def safe_eval(node): return eval_fxns[node.__class__](node)
+
+class BatchRequest:
+  def __init__(self):
+    self._q: list[RemoteRequest] = []
+    self._h: dict[str, bytes] = {}
+  def h(self, d:bytes|memoryview) -> str:
+    datahash = hashlib.sha256(d).hexdigest() # NOTE: this is very slow, should use blake3 on gpu instead
+    if datahash not in self._h:
+      self._h[datahash] = bytes.fromhex(datahash)+struct.pack("<Q", len(d))+bytes(d)
+    return datahash
+  def q(self, x:RemoteRequest): self._q.append(x)
+  def serialize(self) -> bytes:
+    self.h(repr(self._q).encode())
+    return b''.join(self._h.values())
+  def deserialize(self, dat:bytes) -> BatchRequest:
+    ptr = 0
+    while ptr < len(dat):
+      datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
+      self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
+      ptr += 0x28+datalen
+    self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
+    return self
+
+# ***** backend *****
+
+@dataclass
+class RemoteSession:
+  programs: dict[tuple[str, str], Any] = field(default_factory=dict)
+  graphs: dict[int, GraphRunner] = field(default_factory=dict)
+  buffers: dict[int, Buffer] = field(default_factory=dict)
+  events: defaultdict[int, asyncio.Event] = field(default_factory=functools.partial(defaultdict, asyncio.Event))
+
+class RemoteHandler:
+  def __init__(self, base_device: str):
+    self.base_device = base_device
+    self.sessions: defaultdict[SessionKey, RemoteSession] = defaultdict(RemoteSession)
+
+    try: self.ib_ctx: IBCtx|None = IBCtx(getenv("IB_DEV", 0))
+    except (IndexError, AttributeError): self.ib_ctx = None
+    self.ib_lock = asyncio.Lock()
+    self.ib_conns: dict[str, IBConn|None] = {}
+    self.iova_cache: dict[tuple[SessionKey, int], tuple[int, int, int]] = {}
+
+  async def __call__(self, reader:asyncio.StreamReader, writer:asyncio.StreamWriter):
+    while (req_hdr:=(await reader.readline()).decode().strip()):
+      req_method, req_path, _ = req_hdr.split(' ')
+      req_headers = {}
+      while (hdr:=(await reader.readline()).decode().strip()):
+        key, value = hdr.split(':', 1)
+        req_headers[key.lower()] = value.strip()
+      req_body = await reader.readexactly(int(req_headers.get("content-length", "0")))
+      try: res_status, res_body = await self.handle(req_method, req_path, req_body)
+      except Exception as e:
+        res_status, res_body = http.HTTPStatus.INTERNAL_SERVER_ERROR, repr(RemoteException(e, traceback.format_exc())).encode()
+        print(f"{traceback.format_exc()}", flush=True)
+      writer.write(f"HTTP/1.1 {res_status.value} {res_status.phrase}\r\nContent-Length: {len(res_body)}\r\n\r\n".encode() + res_body)
+
+  async def ib_connect(self, ssession:SessionKey, dsession:SessionKey) -> IBConn|None:
+    if self.ib_ctx is None: return None
+    await self.ib_lock.acquire()
+    conn = RemoteConnection(dsession.host)
+    if dsession.host not in self.ib_conns:
+      props = safe_eval(ast.parse(conn.q(GetProperties(session=dsession), wait=True), mode="eval").body)
+      if props.ib_gid is not None:
+        self.ib_conns[dsession.host] = ib_conn = IBConn(self.ib_ctx)
+        ibxc_ret = conn.q(IBConnect(ssession.host, ib_conn.gid, ib_conn.qp_num, session=dsession), wait=True)
+        ib_conn.connect(*struct.unpack('<16sQ', ibxc_ret))
+      else:
+        self.ib_conns[dsession.host] = None
+    self.ib_lock.release()
+    return self.ib_conns[dsession.host]
+
+  async def get_iovas(self, bufs:list[tuple[SessionKey, int]]) -> list[tuple[int, int, int]]:
+    await self.ib_lock.acquire()
+    if (rbufs:=[buf for buf in bufs if buf not in self.iova_cache]):
+      conn = RemoteConnection(rbufs[0][0].host)
+      resp = await conn.aq(BufferIOVAS(rbufs, session=rbufs[0][0]), wait=True)
+      self.iova_cache.update({rbuf: struct.unpack('<QQQ', resp[i*24:(i+1)*24]) for i,rbuf in enumerate(rbufs)})
+    self.ib_lock.release()
+    return [self.iova_cache[buf] for buf in bufs]
+
+  async def handle(self, method:str, path:str, body:bytes) -> tuple[http.HTTPStatus, bytes]:
+    status, ret = http.HTTPStatus.OK, b""
+    if path == "/batch" and method == "POST":
+      # TODO: streaming deserialize?
+      req = BatchRequest().deserialize(body)
+      # the cmds are always last (currently in datahash)
+      for c in req._q:
+        if DEBUG >= 1: print(c)
+        session, dev = self.sessions[unwrap(c.session)], Device[f"{self.base_device}:{unwrap(c.session).idx}"]
+        match c:
+          case SessionFree(): del self.sessions[unwrap(c.session)]
+          case GetProperties():
+            cls, args = dev.renderer.__reduce__()
+            graph_cls = graph_class(Device[self.base_device])
+            rp = RemoteProperties(
+              real_device=dev.device, renderer=(cls.__module__, cls.__name__, args), offset_supported=hasattr(dev.allocator, '_offset'),
+              graph_supported=graph_cls is not None,
+              graph_supports_multi=graph_cls is not None and issubclass(graph_cls, MultiGraphRunner) and hasattr(dev.allocator, '_transfer'),
+              ib_gid=bytes(self.ib_ctx.gid_attr.raw) if self.ib_ctx is not None else None,
+            )
+            ret = repr(rp).encode()
+          case Event():
+            if c.session == c.event_session:
+              session.events[c.event].set()
+            else:
+              for d in Device._opened_devices: Device[d].synchronize() # wait for device*s* to finish executing previous stuff
+              # TODO: don't wait, just send
+              await RemoteConnection(c.event_session.host).aq(Event(c.event_session, c.event, session=c.event_session), wait=True)
+          case Wait():
+            assert await session.events[c.event].wait()
+            del session.events[c.event] # do not leak memory
+          case IBConnect():
+            self.ib_conns[c.host] = ibc = IBConn(unwrap(self.ib_ctx))
+            ibc.connect(c.gid, c.qp_num)
+            ret = struct.pack('<16sQ', ibc.gid, ibc.qp_num)
+          case BufferAlloc():
+            assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
+            session.buffers[c.buffer_num] = Buffer(dev.device, c.size, dtypes.uint8, options=c.options, preallocate=True)
+          case BufferIOVAS():
+            rets = []
+            for buffer_session,buffer_num in c.buffer_nums:
+              iova, mr = unwrap(self.ib_ctx).reg(buf:=self.sessions[buffer_session].buffers[buffer_num])
+              rets.append(struct.pack("<QQQ", iova, mr.contents.rkey, buf.nbytes))
+            ret = b"".join(rets)
+          case BufferOffset():
+            assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already exists"
+            session.buffers[c.buffer_num] = session.buffers[c.sbuffer_num].view(c.size, dtypes.uint8, c.offset).allocate()
+          case BufferFree(): del session.buffers[c.buffer_num]
+          case CopyIn(): session.buffers[c.buffer_num].copyin(memoryview(bytearray(req._h[c.datahash])))
+          case CopyOut(): session.buffers[c.buffer_num].copyout(memoryview(ret:=bytearray(session.buffers[c.buffer_num].nbytes)))
+          case Transfer():
+            if c.dsession.host == unwrap(c.session).host:
+              dsession, ddev = self.sessions[c.dsession], Device[f"{self.base_device}:{unwrap(c.dsession).idx}"]
+              dbuf, sbuf = dsession.buffers[c.dbuffer_num], session.buffers[c.buffer_num]
+              if hasattr(ddev.allocator, '_transfer'):
+                assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
+                ddev.allocator._transfer(dbuf._buf, sbuf._buf, dbuf.nbytes, dest_dev=ddev, src_dev=dev)
+              else:
+                sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
+                dbuf.copyin(data)
+            else:
+              conn, ib_conn = RemoteConnection(c.dsession.host), await self.ib_connect(unwrap(c.session), c.dsession)
+              sbuf = session.buffers[c.buffer_num]
+              if ib_conn is not None:
+                src_iova, src_mr = unwrap(self.ib_ctx).reg(sbuf)
+                dst_iova, dst_key, dst_size = (await self.get_iovas([(c.dsession, c.dbuffer_num)]))[0]
+                assert sbuf.nbytes == dst_size, f"{sbuf.nbytes} != {dst_size}"
+                for d in Device._opened_devices: Device[d].synchronize()
+                ib_conn.rdma_write([SGE(dst_iova, dst_key, src_iova, src_mr.contents.lkey, dst_size)])
+              else:
+                sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
+                await conn.aq(CopyIn(c.dbuffer_num, conn.req.h(data), session=c.dsession), wait=True)
+          case BatchTransfer():
+            conn, ib_conn = RemoteConnection(c.dbuffer_nums[0][0].host), await self.ib_connect(c.sbuffer_nums[0][0], c.dbuffer_nums[0][0])
+            if ib_conn is not None:
+              sbufs = [unwrap(self.ib_ctx).reg(self.sessions[s].buffers[bi]) for s,bi in c.sbuffer_nums]
+              dbufs = await self.get_iovas(c.dbuffer_nums)
+              for d in Device._opened_devices: Device[d].synchronize()
+              ib_conn.rdma_write([SGE(di, dk, si, sm.contents.lkey, ds) for (di,dk,ds),(si,sm) in zip(dbufs, sbufs)])
+            else:
+              for (sbuf_session,sbuf_num),(dbuf_session,dbuf_num) in zip(c.sbuffer_nums, c.dbuffer_nums):
+                sbuf = self.sessions[sbuf_session].buffers[sbuf_num]
+                sbuf.copyout(data:=memoryview(bytearray(sbuf.nbytes)))
+                await conn.aq(CopyIn(dbuf_num, conn.req.h(data), session=dbuf_session), wait=True)
+          case ProgramAlloc():
+            lib = dev.compiler.compile_cached(req._h[c.datahash].decode())
+            session.programs[(c.name, c.datahash)] = dev.runtime(c.name, lib)
+          case ProgramFree(): del session.programs[(c.name, c.datahash)]
+          case ProgramExec():
+            bufs = [session.buffers[x]._buf for x in c.bufs]
+            extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
+            r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
+            if r is not None: ret = str(r).encode()
+          case GraphAlloc():
+            graph_fn: Callable = unwrap(dev.graph)
+            def _parse_ji(gi: GraphComputeItem|Transfer):
+              match gi:
+                case GraphComputeItem():
+                  prg = self.sessions[gi.session].programs[(gi.name, gi.datahash)]
+                  ps = ProgramSpec(gi.name, '', f"{self.base_device}:{gi.session.idx}", UOp(Ops.NOOP),
+                                   vars=list(gi.vars), ins=list(gi.ins), outs=list(gi.outs),
+                                   global_size=list(cast(tuple[int], gi.global_size)) if gi.global_size is not None else None,
+                                   local_size=list(cast(tuple[int], gi.local_size)) if gi.local_size is not None else None)
+                  return ExecItem(CompiledRunner(ps, precompiled=b'', prg=prg), [self.sessions[gi.session].buffers[buf] for buf in gi.bufs],
+                                  fixedvars=gi.fixedvars)
+                case Transfer():
+                  dbuf, sbuf = self.sessions[gi.dsession].buffers[gi.dbuffer_num], self.sessions[unwrap(gi.session)].buffers[gi.buffer_num]
+                  assert dbuf.nbytes == sbuf.nbytes, f"{dbuf.nbytes} != {sbuf.nbytes}"
+                  return ExecItem(BufferXfer(dbuf.nbytes, dbuf.device, sbuf.device), [dbuf, sbuf])
+            assert c.graph_num not in session.graphs, f"graph {c.graph_num} already allocated"
+            session.graphs[c.graph_num] = graph_fn(list(map(_parse_ji, c.jit_cache)), [self.sessions[s].buffers[i] for s,i in c.bufs], c.var_vals)
+          case GraphFree(): del session.graphs[c.graph_num]
+          case GraphExec():
+            r = session.graphs[c.graph_num]([self.sessions[s].buffers[i] for s,i in c.bufs], c.var_vals, wait=c.wait)
+            if r is not None: ret = str(r).encode()
+    else: status, ret = http.HTTPStatus.NOT_FOUND, b"Not Found"
+    return status, ret
+
+def remote_server(port:int):
+  device = getenv("REMOTEDEV", next(Device.get_available_devices()) if Device.DEFAULT == "REMOTE" else Device.DEFAULT)
+  async def _inner_async(port:int, device:str):
+    print(f"start remote server on {port} with device {device}")
+    await (await asyncio.start_server(RemoteHandler(device), host='', port=port)).serve_forever()
+  asyncio.run(_inner_async(port, device))
+
+# ***** frontend *****
+
+class RemoteAllocator(Allocator['RemoteDevice']):
+  def __init__(self, dev:RemoteDevice):
+    if dev.properties.offset_supported: self._offset = self._dyn_offset
+    super().__init__(dev)
+  # TODO: ideally we shouldn't have to deal with images here
+  def _alloc(self, size:int, options:BufferSpec) -> int:
+    self.dev.q(BufferAlloc(buffer_num:=next(self.dev.buffer_num), size, options))
+    return buffer_num
+  # TODO: options should not be here in any Allocator
+  def _free(self, opaque:int, options):
+    try: self.dev.q(BufferFree(opaque))
+    except (TypeError, AttributeError): pass
+  def _copyin(self, dest:int, src:memoryview): self.dev.q(CopyIn(dest, self.dev.conn.req.h(src)))
+  def _copyout(self, dest:memoryview, src:int):
+    resp = self.dev.q(CopyOut(src), wait=True)
+    assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
+    dest[:] = resp
+  def _transfer(self, dest, src, sz, src_dev, dest_dev):
+    if dest_dev.conn != src_dev.conn:
+      dest_dev.q(Event(src_dev.session, start_event:=next(src_dev.event_num)))
+      src_dev.q(Wait(start_event))
+    src_dev.q(Transfer(src, dest_dev.session, dest))
+    if dest_dev.conn != src_dev.conn:
+      src_dev.q(Event(dest_dev.session, end_event:=next(dest_dev.event_num)))
+      dest_dev.q(Wait(end_event))
+    if DEBUG >= 2: dest_dev.conn.batch_submit()
+  def _dyn_offset(self, opaque:int, size:int, offset:int) -> int:
+    self.dev.q(BufferOffset(buffer_num:=next(self.dev.buffer_num), size, offset, opaque))
+    return buffer_num
+
+class RemoteProgram:
+  def __init__(self, dev:RemoteDevice, name:str, lib:bytes):
+    self.dev, self.name = dev, name
+    self.datahash = self.dev.conn.req.h(lib)
+    self.dev.q(ProgramAlloc(self.name, self.datahash))
+    super().__init__()
+    weakref.finalize(self, self._fini, self.dev, self.name, self.datahash)
+
+  @staticmethod
+  def _fini(dev:RemoteDevice, name:str, datahash:str): dev.q(ProgramFree(name, datahash))
+
+  def __call__(self, *bufs, global_size=None, local_size=None, vals:tuple[int, ...]=(), wait=False):
+    ret = self.dev.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait), wait=wait)
+    if wait: return float(ret)
+
+@functools.cache
+class RemoteConnection:
+  q_lock = threading.Lock()
+  all: dict[RemoteConnection, None] = {} # dict instead of set for deterministic ordering
+
+  def __init__(self, host:str):
+    if DEBUG >= 1: print(f"remote with host {host}")
+    while 1:
+      try:
+        self.conn = http.client.HTTPConnection(host, timeout=getenv("REMOTE_TIMEOUT", 300.0))
+        self.conn.connect()
+        break
+      except Exception as e:
+        print(e)
+        time.sleep(0.1)
+    self.req: BatchRequest = BatchRequest()
+    RemoteConnection.all[self] = None
+
+  def q(self, x:RemoteRequest, wait:bool=False):
+    with RemoteConnection.q_lock:
+      self.req.q(x)
+      if wait: return self.batch_submit(take_q=False)
+
+  async def aq(self, x:RemoteRequest, wait:bool=False): return await asyncio.to_thread(self.q, x, wait=wait)
+
+  def batch_submit(self, take_q:bool=True):
+    if take_q: RemoteConnection.q_lock.acquire()
+    conns = RemoteConnection.all.keys()
+    datas = {conn: conn.req.serialize() for conn in conns}
+    reqs, hashes, hash_datas = sum(len(c.req._q) for c in conns), sum(len(c.req._h) for c in conns), sum(len(data) for data in datas.values())
+    with Timing(f"*** send {reqs:-3d} requests {hashes:-3d} hashes with len {hash_datas/1024:.2f} kB in ", enabled=DEBUG>=3):
+      for conn,data in datas.items(): conn.conn.request("POST", "/batch", data)
+      for conn in datas.keys():
+        response = conn.conn.getresponse()
+        resp = response.read()
+        conn.req = BatchRequest() # no matter what response, reset conn
+        if response.status == http.HTTPStatus.INTERNAL_SERVER_ERROR:
+          exc_wrapper = safe_eval(ast.parse(resp.decode(), mode="eval").body)
+          exc_wrapper.exc.add_note(exc_wrapper.trace)
+          raise exc_wrapper.exc
+        assert response.status == http.HTTPStatus.OK, f"POST /batch failed: {resp.decode()}"
+        if conn == self: ret = resp
+    if take_q: RemoteConnection.q_lock.release()
+    return ret
+
+def parse_hosts(hs:str) -> list[tuple[str, int]]|LazySeq[tuple[str, int]]:
+  hosts = [(unwrap(h), int(c) if c is not None else c) for h,c in ((h.split("*", maxsplit=1)+[None,])[:2] for h in hs.split(","))]
+  if len(hosts) == 1 and hosts[0][1] is None: return LazySeq(lambda idx: (hosts[0][0], idx))
+  return [(h, i) for h,c in hosts for i in range(unwrap(c))]
+
+class RemoteDevice(Compiled):
+  devices = parse_hosts(getenv("HOST", ""))
+
+  def __init__(self, device:str):
+    host, idx = RemoteDevice.devices[int(device.split(":")[1]) if ":" in device else 0]
+
+    # connection is shared between sessions on the same host
+    self.session: SessionKey = SessionKey(host or RemoteDevice.local_server(), idx, binascii.hexlify(os.urandom(0x10)).decode())
+    self.conn: RemoteConnection = RemoteConnection(self.session.host)
+
+    # state for the session
+    self.buffer_num: Iterator[int] = itertools.count(0)
+    self.graph_num: Iterator[int] = itertools.count(0)
+    self.event_num: Iterator[int] = itertools.count(0)
+
+    self.properties: RemoteProperties = safe_eval(ast.parse(self.q(GetProperties(), wait=True), mode="eval").body)
+    if DEBUG >= 1: print(f"remote has device {self.properties.real_device}")
+    # TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
+    renderer = self.properties.renderer
+    if not renderer[0].startswith("tinygrad.") or not renderer[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {renderer}")
+    renderer_class = fromimport(renderer[0], renderer[1]) # TODO: is this secure?
+    if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {renderer}")
+    renderer_instance = renderer_class(*renderer[2])
+    renderer_instance.device = device
+    graph = fromimport('tinygrad.runtime.graph.remote', "RemoteGraph") if self.properties.graph_supported else None
+    super().__init__(device, RemoteAllocator(self), renderer_instance, Compiler(), functools.partial(RemoteProgram, self), graph, id(self.conn))

+  def finalize(self):
+    with contextlib.suppress(ConnectionError, http.client.HTTPException): self.q(SessionFree(), wait=True)
+
+  def q(self, x:RemoteRequest, wait:bool=False): return self.conn.q(replace(x, session=self.session), wait=wait)
+
+  @functools.cache
+  @staticmethod
+  def local_server():
+    multiprocessing.Process(target=remote_server, args=(6667,), name="MainProcess", daemon=True).start()
+    return "127.0.0.1:6667"
+
+if __name__ == "__main__": remote_server(getenv("PORT", 6667))
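The new ops_remote.py above supersedes the removed ops_cloud.py: remote_server() answers batched, hash-deduplicated requests on POST /batch, and RemoteDevice is the client side of that boundary. A minimal usage sketch based only on the entry point and environment variables visible in the code (PORT, REMOTEDEV, HOST); the address below is hypothetical:

# on the machine that has the hardware (tinygrad 0.11.0 installed); REMOTEDEV is optional
#   PORT=6667 REMOTEDEV=METAL python -m tinygrad.runtime.ops_remote
# on the client, point the REMOTE backend at that server
import os
os.environ["HOST"] = "192.168.1.42:6667"    # hypothetical address of the machine running remote_server
from tinygrad import Tensor
t = Tensor.arange(4, device="REMOTE") + 1   # requests are batched into a single POST /batch on realization
print(t.tolist())                           # [1, 2, 3, 4]

If HOST is unset, RemoteDevice.local_server() spawns a local remote_server on 127.0.0.1:6667, so device="REMOTE" also works on a single machine.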
tinygrad/runtime/ops_webgpu.py
CHANGED
@@ -1,35 +1,34 @@
 import functools, struct
-from tinygrad.device import Compiled, Allocator, Compiler
+from tinygrad.device import Compiled, Allocator, Compiler, BufferSpec
 from tinygrad.renderer.wgsl import WGSLRenderer
-from tinygrad.helpers import round_up,
+from tinygrad.helpers import round_up, suppress_finalizing
 from tinygrad.runtime.autogen import webgpu
-from typing import List, Any
+from typing import List, Any, TypeAlias
 import ctypes
 import os
 
+WGPUDevPtr: TypeAlias = webgpu.WGPUDevice # type: ignore
+WGPUBufPtr: TypeAlias = webgpu.WGPUBuffer # type: ignore
+
 backend_types = {v: k for k, v in webgpu.WGPUBackendType__enumvalues.items() }
 
-
-  instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
-except AttributeError:
-  raise RuntimeError("Cannot find dawn library. Install it with: " + ("brew tap wpmed92/dawn && brew install dawn" if OSX else
-  "sudo curl -L https://github.com/wpmed92/pydawn/releases/download/v0.1.6/libwebgpu_dawn.so -o /usr/lib/libwebgpu_dawn.so"))
+instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
 
-def to_c_string(_str): return ctypes.create_string_buffer(_str.encode('utf-8'))
+def to_c_string(_str:str) -> ctypes.Array: return ctypes.create_string_buffer(_str.encode('utf-8'))
 
-def from_wgpu_str(string_view): return ctypes.string_at(string_view.data, string_view.length).decode("utf-8")
+def from_wgpu_str(string_view:webgpu.struct_WGPUStringView) -> str: return ctypes.string_at(string_view.data, string_view.length).decode("utf-8")
 
-def to_wgpu_str(_str):
+def to_wgpu_str(_str:str) -> webgpu.struct_WGPUStringView:
   return webgpu.WGPUStringView(data=ctypes.cast(ctypes.pointer(to_c_string(_str)), ctypes.POINTER(ctypes.c_char)), length=len(_str))
 
-def _wait(future):
+def _wait(future:webgpu.struct_WGPUFuture):
   assert webgpu.wgpuInstanceWaitAny(instance, 1, webgpu.WGPUFutureWaitInfo(future=future), 2**64-1) == webgpu.WGPUWaitStatus_Success, "Future failed"
 
-def write_buffer(device, buf, offset, src):
+def write_buffer(device:WGPUDevPtr, buf:WGPUBufPtr, offset:int, src:memoryview|bytearray|bytes):
   src = bytearray(src)
   webgpu.wgpuQueueWriteBuffer(webgpu.wgpuDeviceGetQueue(device), buf, offset, (ctypes.c_uint8 * len(src)).from_buffer(src), len(src))
 
-def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx, msg_idx, *params):
+def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx:int|None, msg_idx:int|None, *params):
   result: List[Any] = []
 
   def cb(*params):
@@ -42,7 +41,7 @@ def _run(async_fun, cb_info_type, cb_type, status_enum, res_idx, msg_idx, *param
   if result[0] != 1: raise RuntimeError(f"[{status_enum[result[0]] if status_enum else 'ERROR'}]{result[msg_idx] if msg_idx else ''}")
   return result[res_idx] if res_idx else None
 
-def copy_buffer_to_buffer(dev, src, src_offset, dst, dst_offset, size):
+def copy_buffer_to_buffer(dev:WGPUDevPtr, src:WGPUBufPtr, src_offset:int, dst:WGPUBufPtr, dst_offset:int, size:int):
   encoder = webgpu.wgpuDeviceCreateCommandEncoder(dev, webgpu.WGPUCommandEncoderDescriptor())
   webgpu.wgpuCommandEncoderCopyBufferToBuffer(encoder, src, src_offset, dst, dst_offset, size)
   cb = webgpu.wgpuCommandEncoderFinish(encoder, webgpu.WGPUCommandBufferDescriptor())
@@ -50,7 +49,7 @@ def copy_buffer_to_buffer(dev, src, src_offset, dst, dst_offset, size):
   webgpu.wgpuCommandBufferRelease(cb)
   webgpu.wgpuCommandEncoderRelease(encoder)
 
-def read_buffer(dev, buf):
+def read_buffer(dev:WGPUDevPtr, buf:WGPUBufPtr) -> memoryview:
   size = webgpu.wgpuBufferGetSize(buf)
   tmp_buffer = webgpu.wgpuDeviceCreateBuffer(dev, webgpu.WGPUBufferDescriptor(size=size,
     usage=webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_MapRead, mappedAtCreation=False))
@@ -63,17 +62,17 @@ def read_buffer(dev, buf):
   webgpu.wgpuBufferDestroy(tmp_buffer)
   return memoryview(buf_copy).cast("B")
 
-def pop_error(device):
+def pop_error(device:WGPUDevPtr) -> str:
   return _run(webgpu.wgpuDevicePopErrorScopeF, webgpu.WGPUPopErrorScopeCallbackInfo, webgpu.WGPUPopErrorScopeCallback, None, 2, 2, device)
 
-def create_uniform(wgpu_device, val):
+def create_uniform(wgpu_device:WGPUDevPtr, val:int|float) -> WGPUBufPtr:
   buf = webgpu.wgpuDeviceCreateBuffer(wgpu_device,
     webgpu.WGPUBufferDescriptor(size=4, usage=webgpu.WGPUBufferUsage_Uniform | webgpu.WGPUBufferUsage_CopyDst))
   write_buffer(wgpu_device, buf, 0, val.to_bytes(4, "little") if isinstance(val, int) else struct.pack('<f', val))
   return buf
 
 class WebGPUProgram:
-  def __init__(self, dev, name:str, lib:bytes):
+  def __init__(self, dev:tuple[WGPUDevPtr, bool], name:str, lib:bytes):
     (self.dev, self.timestamp_supported) = dev
 
     # Creating shader module
@@ -89,14 +88,15 @@ class WebGPUProgram:
     if err := pop_error(self.dev): raise RuntimeError(f"Shader compilation failed: {err}")
 
     self.name, self.lib, self.prg = name, lib, shader_module
-  def __call__(self, *bufs, global_size=(1,1,1), local_size=(1,1,1),
+  def __call__(self, *bufs:WGPUBufPtr, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
+               vals:tuple[int, ...]=(), wait=False) -> float|None:
     wait = wait and self.timestamp_supported
     tmp_bufs = [*bufs]
     buf_patch = False
 
     # WebGPU does not allow using the same buffer for input and output
     for i in range(1, len(bufs)):
-      if bufs[i] == bufs[0]:
+      if ctypes.addressof(bufs[i]) == ctypes.addressof(bufs[0]):
        tmp_bufs[0] = webgpu.wgpuDeviceCreateBuffer(self.dev,
          webgpu.WGPUBufferDescriptor(size=webgpu.wgpuBufferGetSize(bufs[0]), usage=webgpu.wgpuBufferGetUsage(bufs[0])))
        buf_patch = True
@@ -173,23 +173,23 @@ class WebGPUProgram:
       webgpu.wgpuBufferDestroy(query_buf)
       webgpu.wgpuQuerySetDestroy(query_set)
       return time
+    return None
 
-class WebGpuAllocator(Allocator):
-  def
-  def _alloc(self, size: int, options):
+class WebGpuAllocator(Allocator['WGPUDevPtr']):
+  def _alloc(self, size:int, options:BufferSpec) -> WGPUBufPtr:
     # WebGPU buffers have to be 4-byte aligned
     return webgpu.wgpuDeviceCreateBuffer(self.dev, webgpu.WGPUBufferDescriptor(size=round_up(size, 4),
       usage=webgpu.WGPUBufferUsage_Storage | webgpu.WGPUBufferUsage_CopyDst | webgpu.WGPUBufferUsage_CopySrc))
-  def _copyin(self, dest, src:
+  def _copyin(self, dest:WGPUBufPtr, src:memoryview):
     if src.nbytes % 4:
       padded_src = bytearray(round_up(src.nbytes, 4))
       padded_src[:src.nbytes] = src
     write_buffer(self.dev, dest, 0, padded_src if src.nbytes % 4 else src)
-  def _copyout(self, dest:
+  def _copyout(self, dest:memoryview, src:WGPUBufPtr):
    buffer_data = read_buffer(self.dev, src)
    dest[:] = buffer_data[:dest.nbytes] if webgpu.wgpuBufferGetSize(src) > dest.nbytes else buffer_data
-
-
+  @suppress_finalizing
+  def _free(self, opaque:WGPUBufPtr, options:BufferSpec): webgpu.wgpuBufferDestroy(opaque)
 
 class WebGpuDevice(Compiled):
   def __init__(self, device:str):