tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +1 -1
- tinygrad/apps/llm.py +206 -0
- tinygrad/codegen/__init__.py +116 -0
- tinygrad/codegen/devectorizer.py +315 -172
- tinygrad/codegen/expander.py +8 -16
- tinygrad/codegen/gpudims.py +89 -0
- tinygrad/codegen/linearize.py +205 -203
- tinygrad/codegen/lowerer.py +92 -139
- tinygrad/codegen/opt/__init__.py +38 -0
- tinygrad/codegen/opt/heuristic.py +125 -0
- tinygrad/codegen/opt/kernel.py +510 -0
- tinygrad/{engine → codegen/opt}/search.py +51 -35
- tinygrad/codegen/opt/swizzler.py +134 -0
- tinygrad/codegen/opt/tc.py +127 -0
- tinygrad/codegen/quantize.py +67 -0
- tinygrad/device.py +122 -132
- tinygrad/dtype.py +152 -35
- tinygrad/engine/jit.py +81 -54
- tinygrad/engine/memory.py +46 -27
- tinygrad/engine/realize.py +82 -41
- tinygrad/engine/schedule.py +70 -445
- tinygrad/frontend/__init__.py +0 -0
- tinygrad/frontend/onnx.py +1253 -0
- tinygrad/frontend/torch.py +5 -0
- tinygrad/gradient.py +19 -27
- tinygrad/helpers.py +95 -47
- tinygrad/nn/__init__.py +7 -8
- tinygrad/nn/optim.py +72 -41
- tinygrad/nn/state.py +37 -23
- tinygrad/renderer/__init__.py +40 -60
- tinygrad/renderer/cstyle.py +143 -128
- tinygrad/renderer/llvmir.py +113 -62
- tinygrad/renderer/ptx.py +50 -32
- tinygrad/renderer/wgsl.py +27 -23
- tinygrad/runtime/autogen/am/am.py +5861 -0
- tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
- tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
- tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
- tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
- tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
- tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
- tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
- tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
- tinygrad/runtime/autogen/comgr.py +35 -9
- tinygrad/runtime/autogen/comgr_3.py +906 -0
- tinygrad/runtime/autogen/cuda.py +2419 -494
- tinygrad/runtime/autogen/hsa.py +57 -16
- tinygrad/runtime/autogen/ib.py +7171 -0
- tinygrad/runtime/autogen/io_uring.py +917 -118
- tinygrad/runtime/autogen/kfd.py +748 -26
- tinygrad/runtime/autogen/libc.py +613 -218
- tinygrad/runtime/autogen/libusb.py +1643 -0
- tinygrad/runtime/autogen/nv/nv.py +8602 -0
- tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
- tinygrad/runtime/autogen/opencl.py +2 -4
- tinygrad/runtime/autogen/sqtt.py +1789 -0
- tinygrad/runtime/autogen/vfio.py +3 -3
- tinygrad/runtime/autogen/webgpu.py +273 -264
- tinygrad/runtime/graph/cuda.py +3 -3
- tinygrad/runtime/graph/hcq.py +68 -29
- tinygrad/runtime/graph/metal.py +29 -13
- tinygrad/runtime/graph/remote.py +114 -0
- tinygrad/runtime/ops_amd.py +537 -320
- tinygrad/runtime/ops_cpu.py +108 -7
- tinygrad/runtime/ops_cuda.py +12 -14
- tinygrad/runtime/ops_disk.py +13 -10
- tinygrad/runtime/ops_dsp.py +47 -40
- tinygrad/runtime/ops_gpu.py +13 -11
- tinygrad/runtime/ops_hip.py +6 -9
- tinygrad/runtime/ops_llvm.py +35 -15
- tinygrad/runtime/ops_metal.py +29 -19
- tinygrad/runtime/ops_npy.py +5 -3
- tinygrad/runtime/ops_null.py +28 -0
- tinygrad/runtime/ops_nv.py +306 -234
- tinygrad/runtime/ops_python.py +62 -52
- tinygrad/runtime/ops_qcom.py +28 -39
- tinygrad/runtime/ops_remote.py +482 -0
- tinygrad/runtime/ops_webgpu.py +28 -28
- tinygrad/runtime/support/am/amdev.py +114 -249
- tinygrad/runtime/support/am/ip.py +211 -172
- tinygrad/runtime/support/amd.py +138 -0
- tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
- tinygrad/runtime/support/compiler_cuda.py +8 -11
- tinygrad/runtime/support/elf.py +2 -1
- tinygrad/runtime/support/hcq.py +184 -97
- tinygrad/runtime/support/ib.py +172 -0
- tinygrad/runtime/support/llvm.py +3 -4
- tinygrad/runtime/support/memory.py +251 -0
- tinygrad/runtime/support/nv/__init__.py +0 -0
- tinygrad/runtime/support/nv/ip.py +581 -0
- tinygrad/runtime/support/nv/nvdev.py +183 -0
- tinygrad/runtime/support/system.py +170 -0
- tinygrad/runtime/support/usb.py +268 -0
- tinygrad/runtime/support/webgpu.py +18 -0
- tinygrad/schedule/__init__.py +0 -0
- tinygrad/schedule/grouper.py +119 -0
- tinygrad/schedule/kernelize.py +368 -0
- tinygrad/schedule/multi.py +231 -0
- tinygrad/shape/shapetracker.py +40 -46
- tinygrad/shape/view.py +88 -52
- tinygrad/tensor.py +968 -542
- tinygrad/uop/__init__.py +117 -0
- tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
- tinygrad/uop/mathtraits.py +169 -0
- tinygrad/uop/ops.py +1021 -0
- tinygrad/uop/spec.py +228 -0
- tinygrad/{codegen → uop}/symbolic.py +239 -216
- tinygrad/uop/upat.py +163 -0
- tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
- tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
- tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
- tinygrad/viz/index.html +203 -403
- tinygrad/viz/js/index.js +718 -0
- tinygrad/viz/js/worker.js +29 -0
- tinygrad/viz/serve.py +224 -102
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
- tinygrad-0.11.0.dist-info/RECORD +141 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
- tinygrad/codegen/kernel.py +0 -693
- tinygrad/engine/multi.py +0 -161
- tinygrad/ops.py +0 -1003
- tinygrad/runtime/ops_cloud.py +0 -220
- tinygrad/runtime/support/allocator.py +0 -94
- tinygrad/spec.py +0 -155
- tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
- tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
- tinygrad/viz/perfetto.html +0 -178
- tinygrad-0.10.2.dist-info/RECORD +0 -99
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
tinygrad/runtime/support/hcq.py
CHANGED
@@ -1,13 +1,23 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import cast, Type, TypeVar, Generic, Any
|
3
|
-
import contextlib, decimal, statistics, time, ctypes, array, os,
|
4
|
-
|
2
|
+
from typing import cast, Callable, Type, TypeVar, Generic, Any
|
3
|
+
import contextlib, decimal, statistics, time, ctypes, array, os, struct, traceback, collections
|
4
|
+
try: import fcntl # windows misses that
|
5
|
+
except ImportError: fcntl = None #type:ignore[assignment]
|
6
|
+
from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
|
5
7
|
from tinygrad.renderer import Renderer
|
6
|
-
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator,
|
7
|
-
from tinygrad.ops import sym_infer, sint, Variable, UOp
|
8
|
+
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
|
9
|
+
from tinygrad.uop.ops import sym_infer, sint, Variable, UOp
|
8
10
|
from tinygrad.runtime.autogen import libc
|
9
11
|
|
10
|
-
class
|
12
|
+
class MMIOInterface:
|
13
|
+
def __init__(self, addr:int, nbytes:int, fmt='B'): self.mv, self.addr, self.nbytes, self.fmt = to_mv(addr, nbytes).cast(fmt), addr, nbytes, fmt
|
14
|
+
def __len__(self): return self.nbytes // struct.calcsize(self.fmt)
|
15
|
+
def __getitem__(self, k): return (bytes(self.mv[k]) if self.fmt == 'B' else self.mv[k].tolist()) if isinstance(k, slice) else self.mv[k]
|
16
|
+
def __setitem__(self, k, v): self.mv[k] = v
|
17
|
+
def view(self, offset:int=0, size:int|None=None, fmt=None) -> MMIOInterface:
|
18
|
+
return MMIOInterface(self.addr+offset, size or (self.nbytes - offset), fmt=fmt or self.fmt)
|
19
|
+
|
20
|
+
class FileIOInterface:
|
11
21
|
"""
|
12
22
|
Hardware Abstraction Layer for HCQ devices. The class provides a unified interface for interacting with hardware devices.
|
13
23
|
"""
|
@@ -18,7 +28,10 @@ class HWInterface:
|
|
18
28
|
def __del__(self):
|
19
29
|
if hasattr(self, 'fd'): os.close(self.fd)
|
20
30
|
def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
|
21
|
-
def mmap(self, start, sz, prot, flags, offset):
|
31
|
+
def mmap(self, start, sz, prot, flags, offset):
|
32
|
+
x = libc.mmap(start, sz, prot, flags, self.fd, offset)
|
33
|
+
if x == 0xffffffffffffffff: raise OSError(f"Failed to mmap {sz} bytes at {hex(start)}: {os.strerror(ctypes.get_errno())}")
|
34
|
+
return x
|
22
35
|
def read(self, size=None, binary=False, offset=None):
|
23
36
|
if offset is not None: self.seek(offset)
|
24
37
|
with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size)
|
@@ -28,7 +41,10 @@ class HWInterface:
|
|
28
41
|
def listdir(self): return os.listdir(self.path)
|
29
42
|
def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET)
|
30
43
|
@staticmethod
|
31
|
-
def anon_mmap(start, sz, prot, flags, offset):
|
44
|
+
def anon_mmap(start, sz, prot, flags, offset):
|
45
|
+
x = libc.mmap(start, sz, prot, flags, -1, offset)
|
46
|
+
if x == 0xffffffffffffffff: raise OSError(f"Failed to mmap {sz} bytes at {hex(start)}: {os.strerror(ctypes.get_errno())}")
|
47
|
+
return x
|
32
48
|
@staticmethod
|
33
49
|
def munmap(buf, sz): return libc.munmap(buf, sz)
|
34
50
|
@staticmethod
|
@@ -36,14 +52,14 @@ class HWInterface:
|
|
36
52
|
@staticmethod
|
37
53
|
def readlink(path): return os.readlink(path)
|
38
54
|
@staticmethod
|
39
|
-
def eventfd(initval, flags=None): return
|
55
|
+
def eventfd(initval, flags=None): return FileIOInterface(fd=os.eventfd(initval, flags)) # type: ignore[attr-defined]
|
40
56
|
|
41
|
-
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import
|
57
|
+
if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockFileIOInterface as FileIOInterface # noqa: F401 # pylint: disable=unused-import
|
42
58
|
|
43
59
|
# **************** for HCQ Compatible Devices ****************
|
44
60
|
|
45
61
|
SignalType = TypeVar('SignalType', bound='HCQSignal')
|
46
|
-
|
62
|
+
HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQCompiled')
|
47
63
|
ProgramType = TypeVar('ProgramType', bound='HCQProgram')
|
48
64
|
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
|
49
65
|
QueueType = TypeVar('QueueType', bound='HWQueue')
|
@@ -57,16 +73,16 @@ class BumpAllocator:
|
|
57
73
|
self.ptr = (res:=round_up(self.ptr, alignment)) + size
|
58
74
|
return res + self.base
|
59
75
|
|
60
|
-
class HWQueue(Generic[SignalType,
|
76
|
+
class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
|
61
77
|
"""
|
62
78
|
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
|
63
79
|
"""
|
64
80
|
|
65
81
|
def __init__(self):
|
66
82
|
self._q:Any = []
|
67
|
-
self.binded_device:
|
83
|
+
self.binded_device:HCQDeviceType|None = None
|
68
84
|
self.q_sints:list[tuple[int, int]] = []
|
69
|
-
self.mv_sints:list[tuple[
|
85
|
+
self.mv_sints:list[tuple[MMIOInterface, int, int, int|None]] = []
|
70
86
|
self.syms:list[sint] = []
|
71
87
|
self._prev_resolved_syms:list[int|None] = []
|
72
88
|
|
@@ -150,7 +166,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
|
150
166
|
|
151
167
|
# *** submit and bind commands ***
|
152
168
|
|
153
|
-
def bind(self, dev:
|
169
|
+
def bind(self, dev:HCQDeviceType):
|
154
170
|
"""
|
155
171
|
Associates the queue with a specific device for optimized execution.
|
156
172
|
|
@@ -165,13 +181,13 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
|
165
181
|
"""
|
166
182
|
|
167
183
|
def bind_args_state(self, args_state:ArgsStateType):
|
168
|
-
for vals,
|
184
|
+
for vals, mem, fmt in args_state.bind_data: self.bind_sints_to_mem(*vals, mem=mem, fmt=fmt)
|
169
185
|
|
170
|
-
def bind_sints(self, *vals:sint,
|
171
|
-
self.
|
186
|
+
def bind_sints(self, *vals:sint, mem:MMIOInterface, struct_t:Type[ctypes.Structure], start_field:str, fmt, mask:int|None=None):
|
187
|
+
self.bind_sints_to_mem(*vals, mem=mem, fmt=fmt, mask=mask, offset=getattr(struct_t, start_field).offset)
|
172
188
|
|
173
|
-
def
|
174
|
-
mv =
|
189
|
+
def bind_sints_to_mem(self, *vals:sint, mem:MMIOInterface, fmt, mask:int|None=None, offset:int=0):
|
190
|
+
mv = mem.view(offset=offset, size=len(vals)*8, fmt=fmt)
|
175
191
|
for i, val in enumerate(vals):
|
176
192
|
if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
|
177
193
|
else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
|
@@ -189,7 +205,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
|
189
205
|
|
190
206
|
self._prev_resolved_syms = cast(list[int|None], resolved_syms)
|
191
207
|
|
192
|
-
def submit(self, dev:
|
208
|
+
def submit(self, dev:HCQDeviceType, var_vals:dict[Variable, int]|None=None):
|
193
209
|
"""
|
194
210
|
Submits the command queue to a specific device for execution.
|
195
211
|
|
@@ -200,18 +216,21 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
|
|
200
216
|
if var_vals is not None: self._apply_var_vals(var_vals)
|
201
217
|
self._submit(dev)
|
202
218
|
return self
|
203
|
-
def _submit(self, dev:
|
219
|
+
def _submit(self, dev:HCQDeviceType): raise NotImplementedError("need _submit")
|
204
220
|
|
205
|
-
class HCQSignal(Generic[
|
206
|
-
def __init__(self,
|
207
|
-
self.
|
221
|
+
class HCQSignal(Generic[HCQDeviceType]):
|
222
|
+
def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1000):
|
223
|
+
self.base_buf, self.value_addr, self.timestamp_addr, self.owner = base_buf, base_buf.va_addr+0, base_buf.va_addr+8, owner
|
224
|
+
self.is_timeline = is_timeline
|
208
225
|
self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
|
209
|
-
self.timeline_for_device:DeviceType|None = timeline_for_device
|
210
226
|
|
211
|
-
if isinstance(
|
212
|
-
self.value_mv, self.timestamp_mv =
|
227
|
+
if isinstance(self.base_buf.va_addr, int):
|
228
|
+
self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(0, 8, 'Q'), self.base_buf.cpu_view().view(8, 8, 'Q')
|
213
229
|
self.value_mv[0] = value
|
214
230
|
|
231
|
+
def __del__(self):
|
232
|
+
if isinstance(self.base_buf.va_addr, int) and self.owner is not None: HCQCompiled.signal_pool[self.owner.peer_group].append(self.base_buf)
|
233
|
+
|
215
234
|
@property
|
216
235
|
def value(self) -> int: return self.value_mv[0]
|
217
236
|
|
@@ -241,54 +260,57 @@ class HCQSignal(Generic[DeviceType]):
|
|
241
260
|
|
242
261
|
Args:
|
243
262
|
value: The value to wait for.
|
244
|
-
timeout: Maximum time to wait in milliseconds. Defaults to
|
263
|
+
timeout: Maximum time to wait in milliseconds. Defaults to 30s.
|
245
264
|
"""
|
246
265
|
start_time = int(time.perf_counter() * 1000)
|
247
|
-
while self.value < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
|
266
|
+
while (not_passed:=(prev_value:=self.value) < value) and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
|
248
267
|
self._sleep(time_spent)
|
249
|
-
|
268
|
+
if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
|
269
|
+
if not_passed and self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
|
250
270
|
|
251
271
|
@contextlib.contextmanager
|
252
|
-
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:
|
253
|
-
st, en = (dev.
|
272
|
+
def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]|None=None, queue:HWQueue|None=None):
|
273
|
+
st, en = (dev.new_signal(), dev.new_signal()) if enabled else (None, None)
|
254
274
|
|
255
275
|
if enabled and queue is not None: queue.timestamp(st)
|
256
276
|
elif enabled:
|
257
277
|
assert queue_type is not None
|
258
|
-
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.
|
259
|
-
dev.timeline_value += 1
|
278
|
+
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.next_timeline()).submit(dev)
|
260
279
|
|
261
280
|
try: yield (st, en)
|
262
281
|
finally:
|
263
282
|
if enabled and queue is not None: queue.timestamp(en)
|
264
283
|
elif enabled:
|
265
284
|
assert queue_type is not None
|
266
|
-
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.
|
267
|
-
dev.timeline_value += 1
|
285
|
+
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.next_timeline()).submit(dev)
|
268
286
|
|
269
287
|
if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))
|
270
288
|
|
271
289
|
class HCQArgsState(Generic[ProgramType]):
|
272
|
-
def __init__(self,
|
273
|
-
self.
|
274
|
-
self.bind_data:list[tuple[tuple[sint, ...],
|
290
|
+
def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
|
291
|
+
self.buf, self.prg, self.bufs, self.vals = buf, prg, bufs, vals
|
292
|
+
self.bind_data:list[tuple[tuple[sint, ...], MMIOInterface, str]] = []
|
275
293
|
|
276
|
-
def
|
294
|
+
def bind_sints_to_buf(self, *vals:sint, buf:HCQBuffer, fmt, offset=0): self.bind_data.append((vals, buf.cpu_view().view(offset=offset), fmt))
|
277
295
|
|
278
296
|
class CLikeArgsState(HCQArgsState[ProgramType]):
|
279
|
-
def __init__(self,
|
280
|
-
super().__init__(
|
297
|
+
def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
|
298
|
+
super().__init__(buf, prg, bufs, vals=vals)
|
281
299
|
|
282
|
-
if prefix is not None:
|
300
|
+
if prefix is not None: self.buf.cpu_view().view(size=len(prefix) * 4, fmt='I')[:] = array.array('I', prefix)
|
283
301
|
|
284
|
-
self.
|
285
|
-
self.
|
302
|
+
self.bind_sints_to_buf(*[b.va_addr for b in bufs], buf=self.buf, fmt='Q', offset=len(prefix or []) * 4)
|
303
|
+
self.bind_sints_to_buf(*vals, buf=self.buf, fmt='I', offset=len(prefix or []) * 4 + len(bufs) * 8)
|
286
304
|
|
287
|
-
class HCQProgram(Generic[
|
288
|
-
def __init__(self, args_state_t:Type[HCQArgsState], dev:
|
305
|
+
class HCQProgram(Generic[HCQDeviceType]):
|
306
|
+
def __init__(self, args_state_t:Type[HCQArgsState], dev:HCQDeviceType, name:str, kernargs_alloc_size:int, lib:bytes|None=None, base:int|None=None):
|
289
307
|
self.args_state_t, self.dev, self.name, self.kernargs_alloc_size = args_state_t, dev, name, kernargs_alloc_size
|
308
|
+
if PROFILE: Compiled.profile_events += [ProfileProgramEvent(dev.device, name, lib, base)]
|
309
|
+
|
310
|
+
@staticmethod
|
311
|
+
def _fini(dev, buf, spec): dev.allocator.free(buf, buf.size, spec)
|
290
312
|
|
291
|
-
def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(),
|
313
|
+
def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
|
292
314
|
"""
|
293
315
|
Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
|
294
316
|
Args:
|
@@ -298,7 +320,9 @@ class HCQProgram(Generic[DeviceType]):
|
|
298
320
|
Returns:
|
299
321
|
Arguments state with the given buffers and values set for the program.
|
300
322
|
"""
|
301
|
-
|
323
|
+
argsbuf = kernargs or self.dev.kernargs_buf.offset(offset=self.dev.kernargs_offset_allocator.alloc(self.kernargs_alloc_size),
|
324
|
+
size=self.kernargs_alloc_size)
|
325
|
+
return self.args_state_t(argsbuf, self, bufs, vals=vals)
|
302
326
|
|
303
327
|
def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
|
304
328
|
vals:tuple[int, ...]=(), wait:bool=False) -> float|None:
|
@@ -322,8 +346,7 @@ class HCQProgram(Generic[DeviceType]):
|
|
322
346
|
with hcq_profile(self.dev, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
|
323
347
|
q.exec(self, kernargs, global_size, local_size)
|
324
348
|
|
325
|
-
q.signal(self.dev.timeline_signal, self.dev.
|
326
|
-
self.dev.timeline_value += 1
|
349
|
+
q.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
|
327
350
|
|
328
351
|
if wait: self.dev.synchronize()
|
329
352
|
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
|
@@ -332,25 +355,41 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
|
332
355
|
"""
|
333
356
|
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
|
334
357
|
"""
|
335
|
-
|
358
|
+
peer_groups: dict[str, list[HCQCompiled]] = collections.defaultdict(list)
|
359
|
+
signal_pages: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
|
360
|
+
signal_pool: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
|
361
|
+
cpu_devices: list[HCQCompiled] = []
|
336
362
|
|
337
363
|
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
|
338
|
-
comp_queue_t:
|
364
|
+
comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
|
339
365
|
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
|
366
|
+
|
367
|
+
from tinygrad.runtime.graph.hcq import HCQGraph
|
368
|
+
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
|
369
|
+
|
370
|
+
# TODO: peer logic is determined based on device name.
|
371
|
+
self.peer_group = device.split(":")[0]
|
372
|
+
HCQCompiled.peer_groups[self.peer_group].append(self)
|
373
|
+
|
374
|
+
# Map signals if any
|
375
|
+
for sig_page in HCQCompiled.signal_pages[self.peer_group]: cast(HCQAllocator, self.allocator).map(sig_page)
|
376
|
+
|
377
|
+
self.sigalloc_size = sigalloc_size
|
340
378
|
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
|
341
379
|
self.timeline_value:int = 1
|
342
|
-
self.timeline_signal
|
343
|
-
self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
|
380
|
+
self.timeline_signal, self._shadow_timeline_signal = self.new_signal(value=0, is_timeline=True), self.new_signal(value=0, is_timeline=True)
|
344
381
|
self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []
|
345
382
|
|
346
|
-
|
347
|
-
|
383
|
+
self.kernargs_buf:HCQBuffer = self.allocator.alloc(kernargs_size, BufferSpec(cpu_access=True))
|
384
|
+
self.kernargs_offset_allocator:BumpAllocator = BumpAllocator(self.kernargs_buf.size, wrap=True)
|
348
385
|
|
349
|
-
self.
|
350
|
-
self.kernargs_allocator:BumpAllocator = BumpAllocator(self.kernargs_page.size, base=cast(int, self.kernargs_page.va_addr), wrap=True)
|
351
|
-
self.devices.append(self)
|
386
|
+
if self._is_cpu(): HCQCompiled.cpu_devices.append(self)
|
352
387
|
|
353
388
|
def synchronize(self):
|
389
|
+
# If we have any work on CPU devices, need to synchronize them. This is just an optimization to release GIL allowing to finish faster.
|
390
|
+
if not self._is_cpu():
|
391
|
+
for dev in HCQCompiled.cpu_devices: dev.synchronize()
|
392
|
+
|
354
393
|
try: self.timeline_signal.wait(self.timeline_value - 1)
|
355
394
|
except RuntimeError as e:
|
356
395
|
if hasattr(self, 'on_device_hang'): self.on_device_hang()
|
@@ -361,10 +400,22 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
|
361
400
|
Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
|
362
401
|
self.sig_prof_records = []
|
363
402
|
|
403
|
+
def next_timeline(self):
|
404
|
+
self.timeline_value += 1
|
405
|
+
return self.timeline_value - 1
|
406
|
+
|
407
|
+
def new_signal(self, **kwargs) -> SignalType:
|
408
|
+
if not HCQCompiled.signal_pool[pg:=self.peer_group]:
|
409
|
+
HCQCompiled.signal_pages[pg].append(alc:=self.allocator.alloc(self.sigalloc_size, BufferSpec(host=True, uncached=True, cpu_access=True)))
|
410
|
+
HCQCompiled.signal_pool[pg] += [alc.offset(offset=off, size=16) for off in range(0, alc.size, 16)]
|
411
|
+
for dev in HCQCompiled.peer_groups[pg]: cast(HCQAllocator, dev.allocator).map(alc)
|
412
|
+
return self.signal_t(base_buf=HCQCompiled.signal_pool[pg].pop(), owner=self, **kwargs)
|
413
|
+
|
364
414
|
def _at_profile_finalize(self):
|
365
|
-
|
366
|
-
|
367
|
-
|
415
|
+
self.synchronize() # Expect device to be synchronizes
|
416
|
+
|
417
|
+
def _sync(d:HCQCompiled, q_t:Callable[[], HWQueue]):
|
418
|
+
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.next_timeline()).submit(d)
|
368
419
|
st = time.perf_counter_ns()
|
369
420
|
d.timeline_signal.wait(d.timeline_value - 1) # average of the two
|
370
421
|
et = time.perf_counter_ns()
|
@@ -386,41 +437,82 @@ class HCQCompiled(Compiled, Generic[SignalType]):
|
|
386
437
|
except MemoryError: buf, realloced = self.allocator.alloc(oldbuf.size if oldbuf is not None else new_size, options=options), False
|
387
438
|
return buf, realloced
|
388
439
|
|
440
|
+
def _select_iface(self, *ifaces:Type):
|
441
|
+
errs:str = ""
|
442
|
+
if val:=getenv(f'{type(self).__name__[:-6].upper()}_IFACE', ""): ifaces = tuple(x for x in ifaces if x.__name__.startswith(val.upper()))
|
443
|
+
for iface_t in ifaces:
|
444
|
+
try: return iface_t(self, self.device_id)
|
445
|
+
except Exception: errs += f"\n{iface_t.__name__}: {traceback.format_exc()}"
|
446
|
+
raise RuntimeError(f"Cannot find a usable interface for {type(self).__name__[:-6]}:{self.device_id}:\n{errs}")
|
447
|
+
|
448
|
+
def _is_cpu(self) -> bool: return hasattr(self, 'device') and self.device.split(":")[0] in ("CPU", "LLVM")
|
449
|
+
|
450
|
+
def finalize(self):
|
451
|
+
try: self.synchronize() # Try to finalize device in any case.
|
452
|
+
except RuntimeError as e: print(f"{self.device} synchronization failed before finalizing: {e}")
|
453
|
+
|
454
|
+
# If the device has an interface, call its device_fini method to clean up resources.
|
455
|
+
if hasattr(self, 'iface') and hasattr(self.iface, 'device_fini'): self.iface.device_fini()
|
456
|
+
|
389
457
|
class HCQBuffer:
|
390
|
-
def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None
|
391
|
-
|
458
|
+
def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None,
|
459
|
+
owner:HCQCompiled|None=None):
|
460
|
+
self.va_addr, self.size, self.texture_info, self.meta, self._base, self.view = va_addr, size, texture_info, meta, _base, view
|
461
|
+
self._devs, self.owner = ([owner] if owner is not None else []), owner
|
462
|
+
self._mappings:dict[HCQCompiled, HCQBuffer] = {} # mapping to the other devices
|
463
|
+
|
464
|
+
def offset(self, offset:int=0, size:int|None=None) -> HCQBuffer:
|
465
|
+
return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, texture_info=self.texture_info, meta=self.meta,
|
466
|
+
_base=self._base or self, view=(self.view.view(offset=offset, size=size) if self.view is not None else None))
|
467
|
+
|
468
|
+
def cpu_view(self) -> MMIOInterface:
|
469
|
+
assert self.view is not None, "buffer has no cpu_view"
|
470
|
+
return self.view
|
392
471
|
|
393
|
-
|
472
|
+
@property
|
473
|
+
def mappings(self): return self._mappings if self._base is None else self._base._mappings
|
474
|
+
|
475
|
+
@property
|
476
|
+
def mapped_devs(self): return self._devs if self._base is None else self._base._devs
|
477
|
+
|
478
|
+
class HCQAllocatorBase(LRUAllocator[HCQDeviceType], Generic[HCQDeviceType]):
|
394
479
|
"""
|
395
480
|
A base allocator class compatible with the HCQ (Hardware Command Queue) API.
|
396
481
|
|
397
482
|
This class implements basic copy operations following the HCQ API, utilizing both types of `HWQueue`.
|
398
483
|
"""
|
399
484
|
|
400
|
-
def __init__(self, dev:
|
401
|
-
|
402
|
-
self.b = [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
|
403
|
-
self.b_timeline, self.b_next = [0] * len(self.b), 0
|
404
|
-
super().__init__()
|
485
|
+
def __init__(self, dev:HCQDeviceType, batch_size:int=(2 << 20), batch_cnt:int=32, copy_bufs=None, max_copyout_size:int|None=None):
|
486
|
+
super().__init__(dev)
|
487
|
+
self.b = copy_bufs or [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
|
488
|
+
self.b_timeline, self.b_next, self.max_copyout_size = [0] * len(self.b), 0, max_copyout_size
|
405
489
|
|
406
|
-
def map(self, buf:HCQBuffer):
|
490
|
+
def map(self, buf:HCQBuffer):
|
491
|
+
if self.dev in buf.mapped_devs: return
|
492
|
+
if buf.owner is None: raise RuntimeError(f"map failed: buffer {buf.va_addr} has no owner, it's a virtual buffer")
|
493
|
+
if not hasattr(self, '_map'): raise NotImplementedError("map failed: no method implemented")
|
407
494
|
|
408
|
-
|
409
|
-
|
495
|
+
# Since it's unified memory space, any buffer mapping is valid for all devices after successful map.
|
496
|
+
# Devices can save mappings and internal metadata as a new buffer.
|
497
|
+
if (mb:=self._map(buf)) is not None: buf.mappings[self.dev] = mb
|
498
|
+
buf.mapped_devs.append(self.dev)
|
410
499
|
|
411
|
-
|
500
|
+
def _offset(self, buf, size:int, offset:int) -> HCQBuffer: return buf.offset(offset=offset, size=size)
|
501
|
+
|
502
|
+
class HCQAllocator(HCQAllocatorBase, Generic[HCQDeviceType]):
|
412
503
|
def _copyin(self, dest:HCQBuffer, src:memoryview):
|
413
504
|
assert self.dev.hw_copy_queue_t is not None
|
414
|
-
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"
|
505
|
+
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"TINY -> {self.dev.device}", enabled=PROFILE):
|
415
506
|
for i in range(0, src.nbytes, self.b[0].size):
|
416
507
|
self.b_next = (self.b_next + 1) % len(self.b)
|
417
508
|
self.dev.timeline_signal.wait(self.b_timeline[self.b_next])
|
418
|
-
|
509
|
+
|
510
|
+
lsize = min(self.b[self.b_next].size, src.nbytes - i)
|
511
|
+
self.b[self.b_next].cpu_view().view(size=lsize, fmt='B')[:] = src[i:i+lsize]
|
419
512
|
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
420
513
|
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
|
421
|
-
.signal(self.dev.timeline_signal, self.dev.
|
422
|
-
self.b_timeline[self.b_next] = self.dev.timeline_value
|
423
|
-
self.dev.timeline_value += 1
|
514
|
+
.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
|
515
|
+
self.b_timeline[self.b_next] = self.dev.timeline_value - 1
|
424
516
|
|
425
517
|
def copy_from_disk(self, dest:HCQBuffer, src, size):
|
426
518
|
def _get_temp_buf():
|
@@ -435,25 +527,22 @@ class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
|
|
435
527
|
for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
|
436
528
|
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
437
529
|
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
|
438
|
-
.signal(self.dev.timeline_signal, self.dev.
|
439
|
-
self.b_timeline[batch_info[1]] = self.dev.timeline_value
|
440
|
-
self.dev.timeline_value += 1
|
530
|
+
.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
|
531
|
+
self.b_timeline[batch_info[1]] = self.dev.timeline_value - 1
|
441
532
|
|
442
533
|
def _copyout(self, dest:memoryview, src:HCQBuffer):
|
443
534
|
self.dev.synchronize()
|
444
535
|
|
445
536
|
assert self.dev.hw_copy_queue_t is not None
|
446
|
-
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} ->
|
447
|
-
for i in range(0, dest.nbytes, self.b[0].size):
|
537
|
+
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> TINY", enabled=PROFILE):
|
538
|
+
for i in range(0, dest.nbytes, cp_size:=(self.max_copyout_size or self.b[0].size)):
|
448
539
|
self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
|
449
|
-
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(
|
450
|
-
.signal(self.dev.timeline_signal, self.dev.
|
451
|
-
self.dev.timeline_signal.wait(self.dev.timeline_value)
|
452
|
-
self.
|
453
|
-
|
454
|
-
ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
|
540
|
+
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(cp_size, dest.nbytes-i)) \
|
541
|
+
.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
|
542
|
+
self.dev.timeline_signal.wait(self.dev.timeline_value - 1)
|
543
|
+
dest[i:i+lsize] = self.b[0].cpu_view().view(size=lsize, fmt='B')[:]
|
455
544
|
|
456
|
-
def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:
|
545
|
+
def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:HCQDeviceType, dest_dev:HCQDeviceType):
|
457
546
|
cast(HCQAllocator, src_dev.allocator).map(dest)
|
458
547
|
|
459
548
|
assert src_dev.hw_copy_queue_t is not None
|
@@ -461,11 +550,9 @@ class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
|
|
461
550
|
src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
|
462
551
|
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
|
463
552
|
.copy(dest.va_addr, src.va_addr, sz) \
|
464
|
-
.signal(src_dev.timeline_signal, src_dev.
|
465
|
-
src_dev.timeline_value += 1
|
553
|
+
.signal(src_dev.timeline_signal, src_dev.next_timeline()).submit(src_dev)
|
466
554
|
|
467
555
|
if src_dev != dest_dev:
|
468
556
|
dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
|
469
557
|
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
|
470
|
-
.signal(dest_dev.timeline_signal, dest_dev.
|
471
|
-
dest_dev.timeline_value += 1
|
558
|
+
.signal(dest_dev.timeline_signal, dest_dev.next_timeline()).submit(dest_dev)
|