tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
@@ -1,13 +1,23 @@
1
1
  from __future__ import annotations
2
- from typing import cast, Type, TypeVar, Generic, Any
3
- import contextlib, decimal, statistics, time, ctypes, array, os, fcntl
4
- from tinygrad.helpers import PROFILE, from_mv, getenv, to_mv, round_up
2
+ from typing import cast, Callable, Type, TypeVar, Generic, Any
3
+ import contextlib, decimal, statistics, time, ctypes, array, os, struct, traceback, collections
4
+ try: import fcntl # windows misses that
5
+ except ImportError: fcntl = None #type:ignore[assignment]
6
+ from tinygrad.helpers import PROFILE, getenv, to_mv, round_up, ProfileRangeEvent
5
7
  from tinygrad.renderer import Renderer
6
- from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileRangeEvent, ProfileDeviceEvent
7
- from tinygrad.ops import sym_infer, sint, Variable, UOp
8
+ from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator, ProfileDeviceEvent, ProfileProgramEvent
9
+ from tinygrad.uop.ops import sym_infer, sint, Variable, UOp
8
10
  from tinygrad.runtime.autogen import libc
9
11
 
10
- class HWInterface:
12
+ class MMIOInterface:
13
+ def __init__(self, addr:int, nbytes:int, fmt='B'): self.mv, self.addr, self.nbytes, self.fmt = to_mv(addr, nbytes).cast(fmt), addr, nbytes, fmt
14
+ def __len__(self): return self.nbytes // struct.calcsize(self.fmt)
15
+ def __getitem__(self, k): return (bytes(self.mv[k]) if self.fmt == 'B' else self.mv[k].tolist()) if isinstance(k, slice) else self.mv[k]
16
+ def __setitem__(self, k, v): self.mv[k] = v
17
+ def view(self, offset:int=0, size:int|None=None, fmt=None) -> MMIOInterface:
18
+ return MMIOInterface(self.addr+offset, size or (self.nbytes - offset), fmt=fmt or self.fmt)
19
+
20
+ class FileIOInterface:
11
21
  """
12
22
  Hardware Abstraction Layer for HCQ devices. The class provides a unified interface for interacting with hardware devices.
13
23
  """
@@ -18,7 +28,10 @@ class HWInterface:
18
28
  def __del__(self):
19
29
  if hasattr(self, 'fd'): os.close(self.fd)
20
30
  def ioctl(self, request, arg): return fcntl.ioctl(self.fd, request, arg)
21
- def mmap(self, start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, self.fd, offset)
31
+ def mmap(self, start, sz, prot, flags, offset):
32
+ x = libc.mmap(start, sz, prot, flags, self.fd, offset)
33
+ if x == 0xffffffffffffffff: raise OSError(f"Failed to mmap {sz} bytes at {hex(start)}: {os.strerror(ctypes.get_errno())}")
34
+ return x
22
35
  def read(self, size=None, binary=False, offset=None):
23
36
  if offset is not None: self.seek(offset)
24
37
  with open(self.fd, "rb" if binary else "r", closefd=False) as file: return file.read(size)
@@ -28,7 +41,10 @@ class HWInterface:
28
41
  def listdir(self): return os.listdir(self.path)
29
42
  def seek(self, offset): os.lseek(self.fd, offset, os.SEEK_SET)
30
43
  @staticmethod
31
- def anon_mmap(start, sz, prot, flags, offset): return libc.mmap(start, sz, prot, flags, -1, offset)
44
+ def anon_mmap(start, sz, prot, flags, offset):
45
+ x = libc.mmap(start, sz, prot, flags, -1, offset)
46
+ if x == 0xffffffffffffffff: raise OSError(f"Failed to mmap {sz} bytes at {hex(start)}: {os.strerror(ctypes.get_errno())}")
47
+ return x
32
48
  @staticmethod
33
49
  def munmap(buf, sz): return libc.munmap(buf, sz)
34
50
  @staticmethod
@@ -36,14 +52,14 @@ class HWInterface:
36
52
  @staticmethod
37
53
  def readlink(path): return os.readlink(path)
38
54
  @staticmethod
39
- def eventfd(initval, flags=None): return HWInterface(fd=os.eventfd(initval, flags)) # type: ignore[attr-defined]
55
+ def eventfd(initval, flags=None): return FileIOInterface(fd=os.eventfd(initval, flags)) # type: ignore[attr-defined]
40
56
 
41
- if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockHWInterface as HWInterface # noqa: F401 # pylint: disable=unused-import
57
+ if MOCKGPU:=getenv("MOCKGPU"): from test.mockgpu.mockgpu import MockFileIOInterface as FileIOInterface # noqa: F401 # pylint: disable=unused-import
42
58
 
43
59
  # **************** for HCQ Compatible Devices ****************
44
60
 
45
61
  SignalType = TypeVar('SignalType', bound='HCQSignal')
46
- DeviceType = TypeVar('DeviceType', bound='HCQCompiled')
62
+ HCQDeviceType = TypeVar('HCQDeviceType', bound='HCQCompiled')
47
63
  ProgramType = TypeVar('ProgramType', bound='HCQProgram')
48
64
  ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
49
65
  QueueType = TypeVar('QueueType', bound='HWQueue')
@@ -57,16 +73,16 @@ class BumpAllocator:
57
73
  self.ptr = (res:=round_up(self.ptr, alignment)) + size
58
74
  return res + self.base
59
75
 
60
- class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
76
+ class HWQueue(Generic[SignalType, HCQDeviceType, ProgramType, ArgsStateType]):
61
77
  """
62
78
  A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
63
79
  """
64
80
 
65
81
  def __init__(self):
66
82
  self._q:Any = []
67
- self.binded_device:DeviceType|None = None
83
+ self.binded_device:HCQDeviceType|None = None
68
84
  self.q_sints:list[tuple[int, int]] = []
69
- self.mv_sints:list[tuple[memoryview, int, int, int|None]] = []
85
+ self.mv_sints:list[tuple[MMIOInterface, int, int, int|None]] = []
70
86
  self.syms:list[sint] = []
71
87
  self._prev_resolved_syms:list[int|None] = []
72
88
 
@@ -150,7 +166,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
150
166
 
151
167
  # *** submit and bind commands ***
152
168
 
153
- def bind(self, dev:DeviceType):
169
+ def bind(self, dev:HCQDeviceType):
154
170
  """
155
171
  Associates the queue with a specific device for optimized execution.
156
172
 
@@ -165,13 +181,13 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
165
181
  """
166
182
 
167
183
  def bind_args_state(self, args_state:ArgsStateType):
168
- for vals, ptr, fmt in args_state.bind_data: self.bind_sints_to_ptr(*vals, ptr=ptr, fmt=fmt)
184
+ for vals, mem, fmt in args_state.bind_data: self.bind_sints_to_mem(*vals, mem=mem, fmt=fmt)
169
185
 
170
- def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:int|None=None):
171
- self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask)
186
+ def bind_sints(self, *vals:sint, mem:MMIOInterface, struct_t:Type[ctypes.Structure], start_field:str, fmt, mask:int|None=None):
187
+ self.bind_sints_to_mem(*vals, mem=mem, fmt=fmt, mask=mask, offset=getattr(struct_t, start_field).offset)
172
188
 
173
- def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt, mask:int|None=None):
174
- mv = to_mv(ptr, 8*len(vals)).cast(fmt)
189
+ def bind_sints_to_mem(self, *vals:sint, mem:MMIOInterface, fmt, mask:int|None=None, offset:int=0):
190
+ mv = mem.view(offset=offset, size=len(vals)*8, fmt=fmt)
175
191
  for i, val in enumerate(vals):
176
192
  if isinstance(val, int): mv[i] = val if mask is None else ((mv[i] & ~mask) | val)
177
193
  else: self.mv_sints.append((mv, i, self._new_sym(val), mask))
@@ -189,7 +205,7 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
189
205
 
190
206
  self._prev_resolved_syms = cast(list[int|None], resolved_syms)
191
207
 
192
- def submit(self, dev:DeviceType, var_vals:dict[Variable, int]|None=None):
208
+ def submit(self, dev:HCQDeviceType, var_vals:dict[Variable, int]|None=None):
193
209
  """
194
210
  Submits the command queue to a specific device for execution.
195
211
 
@@ -200,18 +216,21 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
200
216
  if var_vals is not None: self._apply_var_vals(var_vals)
201
217
  self._submit(dev)
202
218
  return self
203
- def _submit(self, dev:DeviceType): raise NotImplementedError("need _submit")
219
+ def _submit(self, dev:HCQDeviceType): raise NotImplementedError("need _submit")
204
220
 
205
- class HCQSignal(Generic[DeviceType]):
206
- def __init__(self, base_addr:sint=0, value:int=0, timeline_for_device:DeviceType|None=None, timestamp_divider=1, value_off=0, timestamp_off=8):
207
- self.base_addr, self.value_addr, self.timestamp_addr = base_addr, base_addr+value_off, base_addr+timestamp_off
221
+ class HCQSignal(Generic[HCQDeviceType]):
222
+ def __init__(self, base_buf:HCQBuffer, value:int=0, owner:HCQDeviceType|None=None, is_timeline:bool=False, timestamp_divider=1000):
223
+ self.base_buf, self.value_addr, self.timestamp_addr, self.owner = base_buf, base_buf.va_addr+0, base_buf.va_addr+8, owner
224
+ self.is_timeline = is_timeline
208
225
  self.timestamp_divider:decimal.Decimal = decimal.Decimal(timestamp_divider)
209
- self.timeline_for_device:DeviceType|None = timeline_for_device
210
226
 
211
- if isinstance(base_addr, int):
212
- self.value_mv, self.timestamp_mv = to_mv(self.value_addr, 8).cast('Q'), to_mv(self.timestamp_addr, 8).cast('Q')
227
+ if isinstance(self.base_buf.va_addr, int):
228
+ self.value_mv, self.timestamp_mv = self.base_buf.cpu_view().view(0, 8, 'Q'), self.base_buf.cpu_view().view(8, 8, 'Q')
213
229
  self.value_mv[0] = value
214
230
 
231
+ def __del__(self):
232
+ if isinstance(self.base_buf.va_addr, int) and self.owner is not None: HCQCompiled.signal_pool[self.owner.peer_group].append(self.base_buf)
233
+
215
234
  @property
216
235
  def value(self) -> int: return self.value_mv[0]
217
236
 
@@ -241,54 +260,57 @@ class HCQSignal(Generic[DeviceType]):
241
260
 
242
261
  Args:
243
262
  value: The value to wait for.
244
- timeout: Maximum time to wait in milliseconds. Defaults to 10s.
263
+ timeout: Maximum time to wait in milliseconds. Defaults to 30s.
245
264
  """
246
265
  start_time = int(time.perf_counter() * 1000)
247
- while self.value < value and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
266
+ while (not_passed:=(prev_value:=self.value) < value) and (time_spent:=int(time.perf_counter() * 1000) - start_time) < timeout:
248
267
  self._sleep(time_spent)
249
- if self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
268
+ if self.value != prev_value: start_time = int(time.perf_counter() * 1000) # progress was made, reset timer
269
+ if not_passed and self.value < value: raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
250
270
 
251
271
  @contextlib.contextmanager
252
- def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Type[HWQueue]|None=None, queue:HWQueue|None=None):
253
- st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
272
+ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]|None=None, queue:HWQueue|None=None):
273
+ st, en = (dev.new_signal(), dev.new_signal()) if enabled else (None, None)
254
274
 
255
275
  if enabled and queue is not None: queue.timestamp(st)
256
276
  elif enabled:
257
277
  assert queue_type is not None
258
- queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
259
- dev.timeline_value += 1
278
+ queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.next_timeline()).submit(dev)
260
279
 
261
280
  try: yield (st, en)
262
281
  finally:
263
282
  if enabled and queue is not None: queue.timestamp(en)
264
283
  elif enabled:
265
284
  assert queue_type is not None
266
- queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
267
- dev.timeline_value += 1
285
+ queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.next_timeline()).submit(dev)
268
286
 
269
287
  if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))
270
288
 
271
289
  class HCQArgsState(Generic[ProgramType]):
272
- def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
273
- self.ptr, self.prg = ptr, prg
274
- self.bind_data:list[tuple[tuple[sint, ...], int, str]] = []
290
+ def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
291
+ self.buf, self.prg, self.bufs, self.vals = buf, prg, bufs, vals
292
+ self.bind_data:list[tuple[tuple[sint, ...], MMIOInterface, str]] = []
275
293
 
276
- def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt): self.bind_data.append((vals, ptr, fmt))
294
+ def bind_sints_to_buf(self, *vals:sint, buf:HCQBuffer, fmt, offset=0): self.bind_data.append((vals, buf.cpu_view().view(offset=offset), fmt))
277
295
 
278
296
  class CLikeArgsState(HCQArgsState[ProgramType]):
279
- def __init__(self, ptr:int, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
280
- super().__init__(ptr, prg, bufs, vals=vals)
297
+ def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
298
+ super().__init__(buf, prg, bufs, vals=vals)
281
299
 
282
- if prefix is not None: to_mv(self.ptr, len(prefix) * 4).cast('I')[:] = array.array('I', prefix)
300
+ if prefix is not None: self.buf.cpu_view().view(size=len(prefix) * 4, fmt='I')[:] = array.array('I', prefix)
283
301
 
284
- self.bind_sints_to_ptr(*[b.va_addr for b in bufs], ptr=self.ptr + len(prefix or []) * 4, fmt='Q')
285
- self.bind_sints_to_ptr(*vals, ptr=self.ptr + len(prefix or []) * 4 + len(bufs) * 8, fmt='I')
302
+ self.bind_sints_to_buf(*[b.va_addr for b in bufs], buf=self.buf, fmt='Q', offset=len(prefix or []) * 4)
303
+ self.bind_sints_to_buf(*vals, buf=self.buf, fmt='I', offset=len(prefix or []) * 4 + len(bufs) * 8)
286
304
 
287
- class HCQProgram(Generic[DeviceType]):
288
- def __init__(self, args_state_t:Type[HCQArgsState], dev:DeviceType, name:str, kernargs_alloc_size:int):
305
+ class HCQProgram(Generic[HCQDeviceType]):
306
+ def __init__(self, args_state_t:Type[HCQArgsState], dev:HCQDeviceType, name:str, kernargs_alloc_size:int, lib:bytes|None=None, base:int|None=None):
289
307
  self.args_state_t, self.dev, self.name, self.kernargs_alloc_size = args_state_t, dev, name, kernargs_alloc_size
308
+ if PROFILE: Compiled.profile_events += [ProfileProgramEvent(dev.device, name, lib, base)]
309
+
310
+ @staticmethod
311
+ def _fini(dev, buf, spec): dev.allocator.free(buf, buf.size, spec)
290
312
 
291
- def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs_ptr:int|None=None) -> HCQArgsState:
313
+ def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
292
314
  """
293
315
  Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
294
316
  Args:
@@ -298,7 +320,9 @@ class HCQProgram(Generic[DeviceType]):
298
320
  Returns:
299
321
  Arguments state with the given buffers and values set for the program.
300
322
  """
301
- return self.args_state_t(kernargs_ptr or self.dev.kernargs_allocator.alloc(self.kernargs_alloc_size), self, bufs, vals=vals)
323
+ argsbuf = kernargs or self.dev.kernargs_buf.offset(offset=self.dev.kernargs_offset_allocator.alloc(self.kernargs_alloc_size),
324
+ size=self.kernargs_alloc_size)
325
+ return self.args_state_t(argsbuf, self, bufs, vals=vals)
302
326
 
303
327
  def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
304
328
  vals:tuple[int, ...]=(), wait:bool=False) -> float|None:
@@ -322,8 +346,7 @@ class HCQProgram(Generic[DeviceType]):
322
346
  with hcq_profile(self.dev, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
323
347
  q.exec(self, kernargs, global_size, local_size)
324
348
 
325
- q.signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
326
- self.dev.timeline_value += 1
349
+ q.signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
327
350
 
328
351
  if wait: self.dev.synchronize()
329
352
  return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
@@ -332,25 +355,41 @@ class HCQCompiled(Compiled, Generic[SignalType]):
332
355
  """
333
356
  A base class for devices compatible with the HCQ (Hardware Command Queue) API.
334
357
  """
335
- devices: list[HCQCompiled] = []
358
+ peer_groups: dict[str, list[HCQCompiled]] = collections.defaultdict(list)
359
+ signal_pages: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
360
+ signal_pool: dict[str, list[HCQBuffer]] = collections.defaultdict(list) # per peer group
361
+ cpu_devices: list[HCQCompiled] = []
336
362
 
337
363
  def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
338
- comp_queue_t:Type[HWQueue], copy_queue_t:Type[HWQueue]|None):
364
+ comp_queue_t:Callable[[], HWQueue], copy_queue_t:Callable[[], HWQueue]|None=None, kernargs_size=(16 << 20), sigalloc_size=0x1000):
339
365
  self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
366
+
367
+ from tinygrad.runtime.graph.hcq import HCQGraph
368
+ super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
369
+
370
+ # TODO: peer logic is determined based on device name.
371
+ self.peer_group = device.split(":")[0]
372
+ HCQCompiled.peer_groups[self.peer_group].append(self)
373
+
374
+ # Map signals if any
375
+ for sig_page in HCQCompiled.signal_pages[self.peer_group]: cast(HCQAllocator, self.allocator).map(sig_page)
376
+
377
+ self.sigalloc_size = sigalloc_size
340
378
  self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
341
379
  self.timeline_value:int = 1
342
- self.timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
343
- self._shadow_timeline_signal:SignalType = self.signal_t(value=0, timeline_for_device=self)
380
+ self.timeline_signal, self._shadow_timeline_signal = self.new_signal(value=0, is_timeline=True), self.new_signal(value=0, is_timeline=True)
344
381
  self.sig_prof_records:list[tuple[HCQSignal, HCQSignal, str, bool]] = []
345
382
 
346
- from tinygrad.runtime.graph.hcq import HCQGraph
347
- super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
383
+ self.kernargs_buf:HCQBuffer = self.allocator.alloc(kernargs_size, BufferSpec(cpu_access=True))
384
+ self.kernargs_offset_allocator:BumpAllocator = BumpAllocator(self.kernargs_buf.size, wrap=True)
348
385
 
349
- self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
350
- self.kernargs_allocator:BumpAllocator = BumpAllocator(self.kernargs_page.size, base=cast(int, self.kernargs_page.va_addr), wrap=True)
351
- self.devices.append(self)
386
+ if self._is_cpu(): HCQCompiled.cpu_devices.append(self)
352
387
 
353
388
  def synchronize(self):
389
+ # If we have any work on CPU devices, need to synchronize them. This is just an optimization to release GIL allowing to finish faster.
390
+ if not self._is_cpu():
391
+ for dev in HCQCompiled.cpu_devices: dev.synchronize()
392
+
354
393
  try: self.timeline_signal.wait(self.timeline_value - 1)
355
394
  except RuntimeError as e:
356
395
  if hasattr(self, 'on_device_hang'): self.on_device_hang()
@@ -361,10 +400,22 @@ class HCQCompiled(Compiled, Generic[SignalType]):
361
400
  Compiled.profile_events += [ProfileRangeEvent(self.device, name, st.timestamp, en.timestamp, cp) for st,en,name,cp in self.sig_prof_records]
362
401
  self.sig_prof_records = []
363
402
 
403
+ def next_timeline(self):
404
+ self.timeline_value += 1
405
+ return self.timeline_value - 1
406
+
407
+ def new_signal(self, **kwargs) -> SignalType:
408
+ if not HCQCompiled.signal_pool[pg:=self.peer_group]:
409
+ HCQCompiled.signal_pages[pg].append(alc:=self.allocator.alloc(self.sigalloc_size, BufferSpec(host=True, uncached=True, cpu_access=True)))
410
+ HCQCompiled.signal_pool[pg] += [alc.offset(offset=off, size=16) for off in range(0, alc.size, 16)]
411
+ for dev in HCQCompiled.peer_groups[pg]: cast(HCQAllocator, dev.allocator).map(alc)
412
+ return self.signal_t(base_buf=HCQCompiled.signal_pool[pg].pop(), owner=self, **kwargs)
413
+
364
414
  def _at_profile_finalize(self):
365
- def _sync(d:HCQCompiled, q_t:Type[HWQueue]):
366
- q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
367
- d.timeline_value += 1
415
+ self.synchronize() # Expect device to be synchronizes
416
+
417
+ def _sync(d:HCQCompiled, q_t:Callable[[], HWQueue]):
418
+ q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.next_timeline()).submit(d)
368
419
  st = time.perf_counter_ns()
369
420
  d.timeline_signal.wait(d.timeline_value - 1) # average of the two
370
421
  et = time.perf_counter_ns()
@@ -386,41 +437,82 @@ class HCQCompiled(Compiled, Generic[SignalType]):
386
437
  except MemoryError: buf, realloced = self.allocator.alloc(oldbuf.size if oldbuf is not None else new_size, options=options), False
387
438
  return buf, realloced
388
439
 
440
+ def _select_iface(self, *ifaces:Type):
441
+ errs:str = ""
442
+ if val:=getenv(f'{type(self).__name__[:-6].upper()}_IFACE', ""): ifaces = tuple(x for x in ifaces if x.__name__.startswith(val.upper()))
443
+ for iface_t in ifaces:
444
+ try: return iface_t(self, self.device_id)
445
+ except Exception: errs += f"\n{iface_t.__name__}: {traceback.format_exc()}"
446
+ raise RuntimeError(f"Cannot find a usable interface for {type(self).__name__[:-6]}:{self.device_id}:\n{errs}")
447
+
448
+ def _is_cpu(self) -> bool: return hasattr(self, 'device') and self.device.split(":")[0] in ("CPU", "LLVM")
449
+
450
+ def finalize(self):
451
+ try: self.synchronize() # Try to finalize device in any case.
452
+ except RuntimeError as e: print(f"{self.device} synchronization failed before finalizing: {e}")
453
+
454
+ # If the device has an interface, call its device_fini method to clean up resources.
455
+ if hasattr(self, 'iface') and hasattr(self.iface, 'device_fini'): self.iface.device_fini()
456
+
389
457
  class HCQBuffer:
390
- def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None):
391
- self.va_addr, self.size, self.texture_info, self.meta, self._base = va_addr, size, texture_info, meta, _base
458
+ def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None,
459
+ owner:HCQCompiled|None=None):
460
+ self.va_addr, self.size, self.texture_info, self.meta, self._base, self.view = va_addr, size, texture_info, meta, _base, view
461
+ self._devs, self.owner = ([owner] if owner is not None else []), owner
462
+ self._mappings:dict[HCQCompiled, HCQBuffer] = {} # mapping to the other devices
463
+
464
+ def offset(self, offset:int=0, size:int|None=None) -> HCQBuffer:
465
+ return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, texture_info=self.texture_info, meta=self.meta,
466
+ _base=self._base or self, view=(self.view.view(offset=offset, size=size) if self.view is not None else None))
467
+
468
+ def cpu_view(self) -> MMIOInterface:
469
+ assert self.view is not None, "buffer has no cpu_view"
470
+ return self.view
392
471
 
393
- class HCQAllocatorBase(LRUAllocator, Generic[DeviceType]):
472
+ @property
473
+ def mappings(self): return self._mappings if self._base is None else self._base._mappings
474
+
475
+ @property
476
+ def mapped_devs(self): return self._devs if self._base is None else self._base._devs
477
+
478
+ class HCQAllocatorBase(LRUAllocator[HCQDeviceType], Generic[HCQDeviceType]):
394
479
  """
395
480
  A base allocator class compatible with the HCQ (Hardware Command Queue) API.
396
481
 
397
482
  This class implements basic copy operations following the HCQ API, utilizing both types of `HWQueue`.
398
483
  """
399
484
 
400
- def __init__(self, dev:DeviceType, batch_size:int=(2 << 20), batch_cnt:int=32):
401
- self.dev:DeviceType = dev
402
- self.b = [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
403
- self.b_timeline, self.b_next = [0] * len(self.b), 0
404
- super().__init__()
485
+ def __init__(self, dev:HCQDeviceType, batch_size:int=(2 << 20), batch_cnt:int=32, copy_bufs=None, max_copyout_size:int|None=None):
486
+ super().__init__(dev)
487
+ self.b = copy_bufs or [self._alloc(batch_size, BufferSpec(host=True)) for _ in range(batch_cnt)]
488
+ self.b_timeline, self.b_next, self.max_copyout_size = [0] * len(self.b), 0, max_copyout_size
405
489
 
406
- def map(self, buf:HCQBuffer): pass
490
+ def map(self, buf:HCQBuffer):
491
+ if self.dev in buf.mapped_devs: return
492
+ if buf.owner is None: raise RuntimeError(f"map failed: buffer {buf.va_addr} has no owner, it's a virtual buffer")
493
+ if not hasattr(self, '_map'): raise NotImplementedError("map failed: no method implemented")
407
494
 
408
- def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
409
- return HCQBuffer(va_addr=buf.va_addr + offset, size=size, texture_info=buf.texture_info, meta=buf.meta, _base=buf._base or buf)
495
+ # Since it's unified memory space, any buffer mapping is valid for all devices after successful map.
496
+ # Devices can save mappings and internal metadata as a new buffer.
497
+ if (mb:=self._map(buf)) is not None: buf.mappings[self.dev] = mb
498
+ buf.mapped_devs.append(self.dev)
410
499
 
411
- class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
500
+ def _offset(self, buf, size:int, offset:int) -> HCQBuffer: return buf.offset(offset=offset, size=size)
501
+
502
+ class HCQAllocator(HCQAllocatorBase, Generic[HCQDeviceType]):
412
503
  def _copyin(self, dest:HCQBuffer, src:memoryview):
413
504
  assert self.dev.hw_copy_queue_t is not None
414
- with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"CPU -> {self.dev.device}", enabled=PROFILE):
505
+ with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"TINY -> {self.dev.device}", enabled=PROFILE):
415
506
  for i in range(0, src.nbytes, self.b[0].size):
416
507
  self.b_next = (self.b_next + 1) % len(self.b)
417
508
  self.dev.timeline_signal.wait(self.b_timeline[self.b_next])
418
- ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
509
+
510
+ lsize = min(self.b[self.b_next].size, src.nbytes - i)
511
+ self.b[self.b_next].cpu_view().view(size=lsize, fmt='B')[:] = src[i:i+lsize]
419
512
  self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
420
513
  .copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
421
- .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
422
- self.b_timeline[self.b_next] = self.dev.timeline_value
423
- self.dev.timeline_value += 1
514
+ .signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
515
+ self.b_timeline[self.b_next] = self.dev.timeline_value - 1
424
516
 
425
517
  def copy_from_disk(self, dest:HCQBuffer, src, size):
426
518
  def _get_temp_buf():
@@ -435,25 +527,22 @@ class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
435
527
  for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
436
528
  self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
437
529
  .copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
438
- .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
439
- self.b_timeline[batch_info[1]] = self.dev.timeline_value
440
- self.dev.timeline_value += 1
530
+ .signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
531
+ self.b_timeline[batch_info[1]] = self.dev.timeline_value - 1
441
532
 
442
533
  def _copyout(self, dest:memoryview, src:HCQBuffer):
443
534
  self.dev.synchronize()
444
535
 
445
536
  assert self.dev.hw_copy_queue_t is not None
446
- with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> CPU", enabled=PROFILE):
447
- for i in range(0, dest.nbytes, self.b[0].size):
537
+ with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"{self.dev.device} -> TINY", enabled=PROFILE):
538
+ for i in range(0, dest.nbytes, cp_size:=(self.max_copyout_size or self.b[0].size)):
448
539
  self.dev.hw_copy_queue_t().wait(self.dev.timeline_signal, self.dev.timeline_value - 1) \
449
- .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
450
- .signal(self.dev.timeline_signal, self.dev.timeline_value).submit(self.dev)
451
- self.dev.timeline_signal.wait(self.dev.timeline_value)
452
- self.dev.timeline_value += 1
453
-
454
- ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
540
+ .copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(cp_size, dest.nbytes-i)) \
541
+ .signal(self.dev.timeline_signal, self.dev.next_timeline()).submit(self.dev)
542
+ self.dev.timeline_signal.wait(self.dev.timeline_value - 1)
543
+ dest[i:i+lsize] = self.b[0].cpu_view().view(size=lsize, fmt='B')[:]
455
544
 
456
- def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:DeviceType, dest_dev:DeviceType):
545
+ def _transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev:HCQDeviceType, dest_dev:HCQDeviceType):
457
546
  cast(HCQAllocator, src_dev.allocator).map(dest)
458
547
 
459
548
  assert src_dev.hw_copy_queue_t is not None
@@ -461,11 +550,9 @@ class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
461
550
  src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
462
551
  .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
463
552
  .copy(dest.va_addr, src.va_addr, sz) \
464
- .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
465
- src_dev.timeline_value += 1
553
+ .signal(src_dev.timeline_signal, src_dev.next_timeline()).submit(src_dev)
466
554
 
467
555
  if src_dev != dest_dev:
468
556
  dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
469
557
  .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
470
- .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
471
- dest_dev.timeline_value += 1
558
+ .signal(dest_dev.timeline_signal, dest_dev.next_timeline()).submit(dest_dev)