tinygrad 0.10.2__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. tinygrad/__init__.py +1 -1
  2. tinygrad/apps/llm.py +206 -0
  3. tinygrad/codegen/__init__.py +116 -0
  4. tinygrad/codegen/devectorizer.py +315 -172
  5. tinygrad/codegen/expander.py +8 -16
  6. tinygrad/codegen/gpudims.py +89 -0
  7. tinygrad/codegen/linearize.py +205 -203
  8. tinygrad/codegen/lowerer.py +92 -139
  9. tinygrad/codegen/opt/__init__.py +38 -0
  10. tinygrad/codegen/opt/heuristic.py +125 -0
  11. tinygrad/codegen/opt/kernel.py +510 -0
  12. tinygrad/{engine → codegen/opt}/search.py +51 -35
  13. tinygrad/codegen/opt/swizzler.py +134 -0
  14. tinygrad/codegen/opt/tc.py +127 -0
  15. tinygrad/codegen/quantize.py +67 -0
  16. tinygrad/device.py +122 -132
  17. tinygrad/dtype.py +152 -35
  18. tinygrad/engine/jit.py +81 -54
  19. tinygrad/engine/memory.py +46 -27
  20. tinygrad/engine/realize.py +82 -41
  21. tinygrad/engine/schedule.py +70 -445
  22. tinygrad/frontend/__init__.py +0 -0
  23. tinygrad/frontend/onnx.py +1253 -0
  24. tinygrad/frontend/torch.py +5 -0
  25. tinygrad/gradient.py +19 -27
  26. tinygrad/helpers.py +95 -47
  27. tinygrad/nn/__init__.py +7 -8
  28. tinygrad/nn/optim.py +72 -41
  29. tinygrad/nn/state.py +37 -23
  30. tinygrad/renderer/__init__.py +40 -60
  31. tinygrad/renderer/cstyle.py +143 -128
  32. tinygrad/renderer/llvmir.py +113 -62
  33. tinygrad/renderer/ptx.py +50 -32
  34. tinygrad/renderer/wgsl.py +27 -23
  35. tinygrad/runtime/autogen/am/am.py +5861 -0
  36. tinygrad/runtime/autogen/am/pm4_nv.py +962 -0
  37. tinygrad/runtime/autogen/am/pm4_soc15.py +931 -0
  38. tinygrad/runtime/autogen/am/sdma_4_0_0.py +5209 -0
  39. tinygrad/runtime/autogen/am/sdma_4_4_2.py +5209 -0
  40. tinygrad/runtime/autogen/am/sdma_5_0_0.py +7103 -0
  41. tinygrad/runtime/autogen/am/sdma_6_0_0.py +8085 -0
  42. tinygrad/runtime/autogen/am/smu_v13_0_0.py +3068 -0
  43. tinygrad/runtime/autogen/am/smu_v14_0_2.py +3605 -0
  44. tinygrad/runtime/autogen/amd_gpu.py +1433 -67197
  45. tinygrad/runtime/autogen/comgr.py +35 -9
  46. tinygrad/runtime/autogen/comgr_3.py +906 -0
  47. tinygrad/runtime/autogen/cuda.py +2419 -494
  48. tinygrad/runtime/autogen/hsa.py +57 -16
  49. tinygrad/runtime/autogen/ib.py +7171 -0
  50. tinygrad/runtime/autogen/io_uring.py +917 -118
  51. tinygrad/runtime/autogen/kfd.py +748 -26
  52. tinygrad/runtime/autogen/libc.py +613 -218
  53. tinygrad/runtime/autogen/libusb.py +1643 -0
  54. tinygrad/runtime/autogen/nv/nv.py +8602 -0
  55. tinygrad/runtime/autogen/nv_gpu.py +7218 -2072
  56. tinygrad/runtime/autogen/opencl.py +2 -4
  57. tinygrad/runtime/autogen/sqtt.py +1789 -0
  58. tinygrad/runtime/autogen/vfio.py +3 -3
  59. tinygrad/runtime/autogen/webgpu.py +273 -264
  60. tinygrad/runtime/graph/cuda.py +3 -3
  61. tinygrad/runtime/graph/hcq.py +68 -29
  62. tinygrad/runtime/graph/metal.py +29 -13
  63. tinygrad/runtime/graph/remote.py +114 -0
  64. tinygrad/runtime/ops_amd.py +537 -320
  65. tinygrad/runtime/ops_cpu.py +108 -7
  66. tinygrad/runtime/ops_cuda.py +12 -14
  67. tinygrad/runtime/ops_disk.py +13 -10
  68. tinygrad/runtime/ops_dsp.py +47 -40
  69. tinygrad/runtime/ops_gpu.py +13 -11
  70. tinygrad/runtime/ops_hip.py +6 -9
  71. tinygrad/runtime/ops_llvm.py +35 -15
  72. tinygrad/runtime/ops_metal.py +29 -19
  73. tinygrad/runtime/ops_npy.py +5 -3
  74. tinygrad/runtime/ops_null.py +28 -0
  75. tinygrad/runtime/ops_nv.py +306 -234
  76. tinygrad/runtime/ops_python.py +62 -52
  77. tinygrad/runtime/ops_qcom.py +28 -39
  78. tinygrad/runtime/ops_remote.py +482 -0
  79. tinygrad/runtime/ops_webgpu.py +28 -28
  80. tinygrad/runtime/support/am/amdev.py +114 -249
  81. tinygrad/runtime/support/am/ip.py +211 -172
  82. tinygrad/runtime/support/amd.py +138 -0
  83. tinygrad/runtime/support/{compiler_hip.py → compiler_amd.py} +40 -8
  84. tinygrad/runtime/support/compiler_cuda.py +8 -11
  85. tinygrad/runtime/support/elf.py +2 -1
  86. tinygrad/runtime/support/hcq.py +184 -97
  87. tinygrad/runtime/support/ib.py +172 -0
  88. tinygrad/runtime/support/llvm.py +3 -4
  89. tinygrad/runtime/support/memory.py +251 -0
  90. tinygrad/runtime/support/nv/__init__.py +0 -0
  91. tinygrad/runtime/support/nv/ip.py +581 -0
  92. tinygrad/runtime/support/nv/nvdev.py +183 -0
  93. tinygrad/runtime/support/system.py +170 -0
  94. tinygrad/runtime/support/usb.py +268 -0
  95. tinygrad/runtime/support/webgpu.py +18 -0
  96. tinygrad/schedule/__init__.py +0 -0
  97. tinygrad/schedule/grouper.py +119 -0
  98. tinygrad/schedule/kernelize.py +368 -0
  99. tinygrad/schedule/multi.py +231 -0
  100. tinygrad/shape/shapetracker.py +40 -46
  101. tinygrad/shape/view.py +88 -52
  102. tinygrad/tensor.py +968 -542
  103. tinygrad/uop/__init__.py +117 -0
  104. tinygrad/{codegen/transcendental.py → uop/decompositions.py} +125 -38
  105. tinygrad/uop/mathtraits.py +169 -0
  106. tinygrad/uop/ops.py +1021 -0
  107. tinygrad/uop/spec.py +228 -0
  108. tinygrad/{codegen → uop}/symbolic.py +239 -216
  109. tinygrad/uop/upat.py +163 -0
  110. tinygrad/viz/assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/x86asm.min.js +19 -0
  111. tinygrad/viz/assets/d3js.org/d3.v7.min.js +2 -0
  112. tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js +801 -0
  113. tinygrad/viz/index.html +203 -403
  114. tinygrad/viz/js/index.js +718 -0
  115. tinygrad/viz/js/worker.js +29 -0
  116. tinygrad/viz/serve.py +224 -102
  117. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/METADATA +24 -16
  118. tinygrad-0.11.0.dist-info/RECORD +141 -0
  119. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/WHEEL +1 -1
  120. tinygrad/codegen/kernel.py +0 -693
  121. tinygrad/engine/multi.py +0 -161
  122. tinygrad/ops.py +0 -1003
  123. tinygrad/runtime/ops_cloud.py +0 -220
  124. tinygrad/runtime/support/allocator.py +0 -94
  125. tinygrad/spec.py +0 -155
  126. tinygrad/viz/assets/d3js.org/d3.v5.min.js +0 -2
  127. tinygrad/viz/assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js +0 -4816
  128. tinygrad/viz/perfetto.html +0 -178
  129. tinygrad-0.10.2.dist-info/RECORD +0 -99
  130. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info/licenses}/LICENSE +0 -0
  131. {tinygrad-0.10.2.dist-info → tinygrad-0.11.0.dist-info}/top_level.txt +0 -0
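Note on the layout change visible in the list above: several core modules moved in 0.11.0 (tinygrad/ops.py became tinygrad/uop/ops.py, codegen/symbolic.py became uop/symbolic.py, engine/search.py became codegen/opt/search.py, and codegen/transcendental.py became uop/decompositions.py). A minimal sketch of what that means for imports, assuming only the renames shown in this listing (launch_viz is the one symbol the device.py diff below shows moving):

    # 0.10.2 (old location, removed in 0.11.0)
    # from tinygrad.ops import launch_viz
    # 0.11.0
    from tinygrad.uop.ops import launch_viz   # tinygrad/ops.py -> tinygrad/uop/ops.py
    # other module moves from the listing, for reference:
    #   tinygrad/codegen/symbolic.py -> tinygrad/uop/symbolic.py
    #   tinygrad/engine/search.py    -> tinygrad/codegen/opt/search.py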
tinygrad/device.py CHANGED
@@ -1,11 +1,11 @@
  from __future__ import annotations
  from dataclasses import dataclass, replace
  from collections import defaultdict
- from typing import Optional, Any, Iterator, Generator
- import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
- from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
- cpu_time_execution, colored, Context, round_up
- from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
+ from typing import Any, Generic, TypeVar, Iterator
+ import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time
+ from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, \
+ Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
+ from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
  from tinygrad.renderer import Renderer

  # **************** Device ****************
@@ -15,18 +15,18 @@ class _Device:
  def __init__(self) -> None:
  self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
  self._opened_devices:set[str] = set()
- @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
+ @functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
  def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
  # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
- def canonicalize(self, device:Optional[str]) -> str: return self._canonicalize(device) if device is not None else Device.DEFAULT
+ def canonicalize(self, device:str|None) -> str: return self._canonicalize(device if device is not None else Device.DEFAULT)
  def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
- @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
+ @functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
  def __get_canonicalized_item(self, ix:str) -> Compiled:
- cpn = multiprocessing.current_process().name
- assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
- x = ix.split(":")[0].upper()
- ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) \
- if (cname.lower() == x.lower() + "device")][0](ix)
+ assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
+ base = (__package__ or __name__).split('.')[0] # tinygrad
+ x = ix.split(":")[0].lower()
+ ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.runtime.ops_{x}')) \
+ if (cname.lower() == x + "device")][0](ix)
  if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
  self._opened_devices.add(ix)
  return ret
@@ -37,7 +37,10 @@ class _Device:
  with contextlib.suppress(Exception): yield self[device].device
  @functools.cached_property
  def DEFAULT(self) -> str:
- if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
+ dev = [dev] if (dev:=getenv("DEV", "").upper()) else []
+ from_env = dedup(dev + [d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1])
+ assert len(from_env) < 2, f"multiple devices set in env: {from_env}"
+ if len(from_env) == 1: return from_env[0]
  try:
  device = next(self.get_available_devices())
  os.environ[device] = "1" # we set this in environment for spawned children
@@ -48,14 +51,12 @@ atexit.register(lambda: [Device[dn].finalize() for dn in Device._opened_devices]

  # **************** Profile ****************

- class ProfileEvent: pass
-
  @dataclass(frozen=True)
  class ProfileDeviceEvent(ProfileEvent):
  device:str; comp_tdiff:decimal.Decimal=decimal.Decimal(0); copy_tdiff:decimal.Decimal=decimal.Decimal(0) # noqa: E702

  @dataclass(frozen=True)
- class ProfileRangeEvent(ProfileEvent): device:str; name:str; st:decimal.Decimal; en:decimal.Decimal; is_copy:bool # noqa: E702
+ class ProfileProgramEvent(ProfileEvent): device:str; name:str; lib:bytes|None; base:int|None # noqa: E702

  @dataclass(frozen=True)
  class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:bool # noqa: E702
@@ -63,39 +64,42 @@ class ProfileGraphEntry: device:str; name:str; st_id:int; en_id:int; is_copy:boo
  @dataclass(frozen=True)
  class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[list[int]]; sigs:list[decimal.Decimal] # noqa: E702

- @dataclass
- class ProfileResult: st:Optional[int]=None; en:Optional[int]=None # noqa: E702
-
- @contextlib.contextmanager
- def cpu_profile(name, device="CPU", is_copy=False, display=True) -> Generator[ProfileResult, None, None]:
- yield (res:=ProfileResult(st:=time.perf_counter_ns()))
- res.en = en = time.perf_counter_ns()
- if PROFILE and display:
- Compiled.profile_events += [ProfileRangeEvent(device, name, decimal.Decimal(st) / 1000, decimal.Decimal(en) / 1000, is_copy=is_copy)]
-
  # **************** Buffer + Allocators ****************

-
  @dataclass(frozen=True, eq=True)
  class BufferSpec:
  # TODO: move device, size, dtype here?
- image: Optional[ImageDType] = None
+ image: ImageDType|None = None
  uncached: bool = False
  cpu_access: bool = False
  host: bool = False
  nolru: bool = False
- external_ptr: Optional[int] = None
+ external_ptr: int|None = None
+
+ class MultiBuffer:
+ def __init__(self, device:tuple[str, ...], size:int, dtype:DType):
+ self.bufs = [Buffer(d, size, dtype) for d in device]
+ @property
+ def size(self): return self.bufs[0].size
+ @property
+ def dtype(self): return self.bufs[0].dtype
+ def ref(self, cnt):
+ for b in self.bufs: b.ref(cnt)
+ return self
+ def is_allocated(self): return all(x.is_allocated() for x in self.bufs)
+ def __repr__(self): return f"<multibuf real:{self.is_allocated()} device:{tuple(x.device for x in self.bufs)} size:{self.size} dtype:{self.dtype}>"

  class Buffer:
- def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferSpec]=None, initial_value:Optional[bytes]=None,
- lb_refcount=0, base:Optional[Buffer]=None, offset:int=0, preallocate=False):
+ profile_events:list[ProfileEvent] = []
+ def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:BufferSpec|None=None, initial_value:bytes|None=None,
+ uop_refcount=0, base:Buffer|None=None, offset:int=0, preallocate=False):
  if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be?
  else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
- self.device, self.size, self.dtype, self.options, self.offset = device, size, dtype, options, offset
+ self.device, self.size, self.dtype, self.options, self.offset, self.allocated_views = device, size, dtype, options, offset, 0
  if base is None:
  assert offset == 0, "base buffers can't have offset"
  self._base = None
- self._lb_refcount = lb_refcount
+ self._uop_refcount = uop_refcount
  if opaque is not None: self.allocate(opaque)
  if initial_value is not None:
  self.allocate()
@@ -108,60 +112,86 @@ class Buffer:
  @property
  def base(self) -> Buffer: return self._base if self._base is not None else self
  @property
- def lb_refcount(self): return self.base._lb_refcount
- def ref(self, cnt): self.base._lb_refcount += cnt
- def is_allocated(self) -> bool: return hasattr(self, '_buf')
- def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_allocated() else self
+ def uop_refcount(self): return self.base._uop_refcount
+ def ref(self, cnt):
+ self.base._uop_refcount += cnt
+ return self
+ # check if the underlying buffer is allocated and the current buffer/view is initialized
+ def is_initialized(self) -> bool: return self.is_allocated() and hasattr(self, '_buf')
+ # check if the underlying buffer is allocated, possibly from the base object
+ def is_allocated(self) -> bool: return self.base.is_allocated() if self._base is not None else hasattr(self, '_buf')
+ def ensure_allocated(self) -> Buffer: return self.allocate() if not self.is_initialized() else self
  def allocate(self, opaque=None, external_ptr=None) -> Buffer:
- assert not self.is_allocated(), "can't allocate already allocated buffer"
+ assert not self.is_initialized(), "can't allocate already allocated buffer"
+ if DEBUG >= 7: print(f"buffer: allocate {self.nbytes} bytes on {self.device}")
+ if MAX_BUFFER_SIZE > 0 and self.size > MAX_BUFFER_SIZE: raise RuntimeError(f"buffer of size {self.size/1e6:.2f}M is too large")
  self.allocator:Allocator = Device[self.device].allocator
  if external_ptr is not None:
  self.options = replace(self.options, external_ptr=external_ptr) if self.options else BufferSpec(external_ptr=external_ptr)
  if self._base is not None:
  self._base.ensure_allocated()
+ self._base.allocated_views += 1
  assert hasattr(self.allocator, "_offset"), "offset function required for view"
  self._buf: Any = self.allocator._offset(self.base._buf, self.nbytes, self.offset)
  else:
  self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, self.options)
  if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
+ if PROFILE:
+ self._prof_num = num = len(Buffer.profile_events)
+ ts = decimal.Decimal(time.perf_counter_ns())/1000
+ Buffer.profile_events.append(ProfilePointEvent(self.device, "alloc", ts, num, {"dtype":str(self.dtype),"sz":self.size,"nbytes":self.nbytes}))
  return self
  def deallocate(self):
- assert self.is_allocated(), "buffer must be allocated to deallocate"
+ assert hasattr(self, '_buf'), "buffer must be allocated to deallocate"
+ if DEBUG is not None and DEBUG >= 7: print(f"buffer: deallocate {self.nbytes} bytes on {self.device}")
  if self._base is None and (self.options is None or self.options.external_ptr is None):
- if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
+ if GlobalCounters is not None and not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.nbytes
+ if PROFILE: Buffer.profile_events.append(ProfilePointEvent(self.device, "free", decimal.Decimal(time.perf_counter_ns())/1000, self._prof_num))
  self.allocator.free(self._buf, self.nbytes, self.options)
+ elif self._base is not None: self._base.allocated_views -= 1
  del self._buf
  def __reduce__(self):
  buf = None
  if self._base is not None:
  return self.__class__, (self.device, self.size, self.dtype, None, None, None, 0, self.base, self.offset, self.is_allocated())
- if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.lb_refcount)
+ if self.device == "NPY": return self.__class__, (self.device, self.size, self.dtype, self._buf, self.options, None, self.uop_refcount)
  if self.is_allocated():
  buf = bytearray(self.nbytes)
  self.copyout(memoryview(buf))
- return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.lb_refcount)
+ return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf, self.uop_refcount)
  @property
  def nbytes(self): return self.size*self.dtype.itemsize
- def __del__(self): (not self.is_allocated()) or self.deallocate()
+ def __del__(self): (not hasattr(self, '_buf')) or self.deallocate()
  def __repr__(self):
  return f"<buf real:{self.is_allocated()} device:{self.device} size:{self.size} dtype:{self.dtype}" + \
- (f" offset:{self.offset}" if hasattr(self, "base") else "") + (f" {self.options=}" if self.options is not None else "") + ">"
+ (f" offset:{self.offset}" if self._base is not None else "") + (f" {self.options=}" if self.options is not None else "") + ">"
+ def as_dmaref(self) -> DMARef:
+ assert hasattr(self.allocator, "_as_dmaref"), f"Device {self.device} doesn't support DMA"
+ return self.allocator._as_dmaref(self._buf)
  def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
  # zero copy with as_buffer (disabled by default due to use after free)
  if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
  return self.allocator._as_buffer(self._buf)
  assert not force_zero_copy, "force zero copy was passed, but copy is required"
  return self.copyout(memoryview(bytearray(self.nbytes)))
+ def as_typed_buffer(self, shape=None, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
+ assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
+ assert self.dtype.base.fmt != "e" or sys.version_info >= (3, 12)
+ return self.as_buffer(allow_zero_copy, force_zero_copy).cast(self.dtype.base.fmt, shape if shape is not None else (self.size,))
+ def numpy(self) -> 'np.ndarray': # type: ignore [name-defined] # noqa: F821
+ import numpy as np
+ assert _to_np_dtype(self.dtype.base) is not None, f"no np dtype for {self.dtype.base}"
+ return np.frombuffer(self.as_buffer(), dtype=_to_np_dtype(self.dtype.base))
  def copyin(self, mv:memoryview):
  mv = flat_mv(mv)
  assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
- assert self.is_allocated(), "can't copyin to unallocated buffer"
+ assert self.is_initialized(), "can't copyin to unallocated buffer"
  self.allocator._copyin(self._buf, mv)
  return self
  def copyout(self, mv:memoryview) -> memoryview:
  mv = flat_mv(mv)
  assert len(mv) == self.nbytes, f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
- assert self.is_allocated(), "can't copyout unallocated buffer"
+ assert self.is_initialized(), "can't copyout unallocated buffer"
  self.allocator._copyout(mv, self._buf)
  return mv
  def view(self, size:int, dtype:DType, offset:int) -> Buffer:
@@ -169,13 +199,33 @@ class Buffer:
  if self._base is not None: return Buffer(self.device, size, dtype, base=self._base, offset=self.offset+offset)
  return Buffer(self.device, size, dtype, base=self, offset=offset)

+ @dataclass(frozen=True)
+ class DMACPURef:
+ addr: int
+ size: int
+
+ @dataclass(frozen=True)
+ class DMAFdRef:
+ fd: int
+ offset: int
+ size: int
+
+ DMARef = DMACPURef|DMAFdRef
+
+ DeviceType = TypeVar('DeviceType', bound='Compiled')
+
  # TODO: size, dest, src are the same type. can we enforce this?
- class Allocator:
+ class Allocator(Generic[DeviceType]):
+ def __init__(self, dev:DeviceType):
+ self.dev: DeviceType = dev
+ self.default_buffer_spec: BufferSpec = BufferSpec()
+ self.supports_copy_from_disk: bool = True
  # overridden in LRUAllocator
- def alloc(self, size:int, options:Optional[BufferSpec]=None):
+ def alloc(self, size:int, options:BufferSpec|None=None):
  assert size > 0, f"alloc size must be positive, getting {size}"
- return self._alloc(size, options if options is not None else BufferSpec())
- def free(self, opaque, size:int, options:Optional[BufferSpec]=None): self._free(opaque, options if options is not None else BufferSpec())
+ return self._alloc(size, options if options is not None else self.default_buffer_spec)
+ def free(self, opaque, size:int, options:BufferSpec|None=None):
+ self._free(opaque, options if options is not None else self.default_buffer_spec)

  # implemented by the runtime
  def _alloc(self, size:int, options:BufferSpec): raise NotImplementedError("need alloc")
@@ -186,13 +236,15 @@ class Allocator:
  # def _offset(self, buf, size:int, offset:int):
  # def _transfer(self, dest, src, sz:int, src_dev, dest_dev):

- class LRUAllocator(Allocator):
+ class LRUAllocator(Allocator, Generic[DeviceType]):
  """
  The LRU Allocator is responsible for caching buffers.
  It ensures that buffers are not freed until it is absolutely necessary, optimizing performance.
  """
- def __init__(self): self.cache: dict[tuple[int, Optional[BufferSpec]], Any] = defaultdict(list)
- def alloc(self, size:int, options:Optional[BufferSpec]=None):
+ def __init__(self, dev:DeviceType):
+ self.cache: dict[tuple[int, BufferSpec|None], Any] = defaultdict(list)
+ super().__init__(dev)
+ def alloc(self, size:int, options:BufferSpec|None=None):
  if len(c := self.cache[(size, options)]): return c.pop()
  try: return super().alloc(size, options)
  except (RuntimeError, MemoryError):
@@ -202,84 +254,16 @@ class LRUAllocator(Allocator):
  for (sz,options),opaques in self.cache.items():
  for opaque in opaques: super().free(opaque, sz, options)
  opaques.clear()
- def free(self, opaque:Any, size:int, options:Optional[BufferSpec]=None):
+ def free(self, opaque:Any, size:int, options:BufferSpec|None=None):
  if LRU and (options is None or not options.nolru): self.cache[(size, options)].append(opaque)
  else: super().free(opaque, size, options)

- class _MallocAllocator(LRUAllocator):
- def _alloc(self, size:int, options:BufferSpec):
- # must be aligned to 0x20 for 256-bit ymm registers
- # TODO: investigate if this is the cause of nondeterminism in speed
- alignment = 0x1000 if size >= 0x1000 else 0x20
- return (ctypes.c_uint8 * size).from_address(options.external_ptr) if options.external_ptr else self._alloc_aligned(size, alignment)
- def _alloc_aligned(self, size:int, alignment:int):
- buffer = (ctypes.c_uint8 * (size + alignment))()
- offset = round_up(ctypes.addressof(buffer), alignment) - ctypes.addressof(buffer)
- return (ctypes.c_uint8 * size).from_buffer(buffer, offset)
- def _as_buffer(self, src) -> memoryview: return flat_mv(memoryview(src))
- def _copyin(self, dest, src:memoryview): ctypes.memmove(dest, from_mv(src), len(src))
- def _copyout(self, dest:memoryview, src): ctypes.memmove(from_mv(dest), src, len(dest))
- def _offset(self, buf, size:int, offset:int): return from_mv(self._as_buffer(buf)[offset:offset+size])
-
- MallocAllocator = _MallocAllocator()
-
- # NOTE: MAP_JIT is added to mmap module in python 3.13
- MAP_JIT = 0x0800
-
- # CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
- class CPUProgram:
- rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or sys.platform == "win32" else 'libgcc_s.so.1')
- atomic_lib = ctypes.CDLL(ctypes.util.find_library('atomic')) if sys.platform == "linux" else None
-
- def __init__(self, name:str, lib:bytes):
- if sys.platform == "win32":
- PAGE_EXECUTE_READWRITE = 0x40
- MEM_COMMIT = 0x1000
- MEM_RESERVE = 0x2000
- ctypes.windll.kernel32.VirtualAlloc.restype = ctypes.c_void_p
- self.mem = ctypes.windll.kernel32.VirtualAlloc(ctypes.c_void_p(0), ctypes.c_size_t(len(lib)), MEM_COMMIT | MEM_RESERVE, PAGE_EXECUTE_READWRITE)
- ctypes.memmove(self.mem, lib, len(lib))
- ctypes.windll.kernel32.GetCurrentProcess.restype = ctypes.c_void_p
- proc = ctypes.windll.kernel32.GetCurrentProcess()
- ctypes.windll.kernel32.FlushInstructionCache(ctypes.c_void_p(proc), ctypes.c_void_p(self.mem), ctypes.c_size_t(len(lib)))
- self.fxn = ctypes.CFUNCTYPE(None)(self.mem)
- else:
- from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
- # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
- # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
- self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
-
- if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(False)
- self.mem.write(lib)
- if OSX: CPUProgram.rt_lib.pthread_jit_write_protect_np(True)
-
- # __clear_cache isn't a normal libc function, but a compiler support routine found in libgcc_s for gcc and compiler-rt for clang.
- # libgcc_s comes as shared library but compiler-rt is only a bunch of static library archives which we can't directly load, but fortunately
- # it somehow found its way into libSystem on macos (likely because it used __builtin_clear_cache) and libgcc_s is ~always present on linux
- # Using ["name"] instead of .name because otherwise name is getting mangled: https://docs.python.org/3.12/reference/expressions.html#index-5
- CPUProgram.rt_lib["__clear_cache"](ctypes.c_void_p(mv_address(self.mem)), ctypes.c_void_p(mv_address(self.mem) + len(lib)))
-
- self.fxn = ctypes.CFUNCTYPE(None)(mv_address(self.mem))
-
- def __call__(self, *bufs, vals=(), wait=False):
- args = list(bufs) + list(vals)
- # NOTE: replace this by --target={host's triple}-elf in clang args once we only support macos sequoia and later.
- # Apple relaxes abi requirement for stack arguments to always be at least 8 byte aligned on arm64
- # https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms
- # This hack is required because clang/llvm bug doesn't allow us to just use {host's triple}+'-elf' (relocation failures)
- # The bug was fixed in https://github.com/llvm/llvm-project/commit/454cc36630296262cdb6360b60f90a64a97f7f1a but was only backported to xcode 16+
- if platform.machine() == "arm64" and OSX: args = args[:8] + [ctypes.c_int64(a) if isinstance(a, int) else a for a in args[8:]]
- return cpu_time_execution(lambda: self.fxn(*args), enable=wait)
-
- def __del__(self):
- if sys.platform == 'win32': ctypes.windll.kernel32.VirtualFree(ctypes.c_void_p(self.mem), ctypes.c_size_t(0), 0x8000) #0x8000 - MEM_RELEASE
-
  # **************** for Compiled Devices ****************

  class CompileError(Exception): pass

  class Compiler:
- def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
+ def __init__(self, cachekey:str|None=None): self.cachekey = None if DISABLE_COMPILER_CACHE else cachekey
  def compile(self, src:str) -> bytes: return src.encode() # NOTE: empty compiler is the default
  def compile_cached(self, src:str) -> bytes:
  if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
@@ -292,9 +276,9 @@ class Compiler:
  class Compiled:
  profile_events:list[ProfileEvent] = [ProfileDeviceEvent("CPU")] # NOTE: CPU is the default device.

- def __init__(self, device:str, allocator:Allocator, renderer:Optional[Renderer], compiler:Optional[Compiler], runtime, graph=None):
+ def __init__(self, device:str, allocator:Allocator, renderer:Renderer|None, compiler:Compiler|None, runtime, graph=None, group_id=None):
  self.device, self.allocator, self.compiler, self.runtime, self.graph = device, allocator, compiler or Compiler(), runtime, graph
- self.renderer = renderer or Renderer()
+ self.renderer, self.group_id = renderer or Renderer(), group_id
  def synchronize(self):
  """
  Synchronize all pending operations on the device.
@@ -314,11 +298,16 @@ class Compiled:
  # override this in your device implementation

  # TODO: move this to each Device
- def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
+ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
  if device is None: device = Device.DEFAULT
  if dtype == dtypes.bfloat16:
- # NOTE: this requires bf16 buffer support
- return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
+ if device == "METAL": return not CI
+ if device in {"CUDA", "NV"}: return not CI and not getenv("PTX")
+ if device in {"CPU", "LLVM"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
+ return device == "AMD"
+ if dtype in dtypes.fp8s:
+ # not supported yet - in progress
+ return False
  if device == "WEBGPU": return dtype in [dtypes.bool, dtypes.char, dtypes.uchar, dtypes.short,
  dtypes.ushort, dtypes.float, dtypes.int32, dtypes.uint32, dtypes.half]
  # for CI GPU and OSX, cl_khr_fp16 isn't supported
@@ -340,10 +329,11 @@ if PROFILE:
  for dev in devs: dev.synchronize()
  for dev in devs: dev._at_profile_finalize()

- with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(Compiled.profile_events, f)
+ with open(fn:=temp("profile.pkl", append_user=True), "wb") as f: pickle.dump(cpu_events+Compiled.profile_events+Buffer.profile_events, f)

- from tinygrad.ops import launch_viz
- launch_viz("PROFILE", fn)
+ if not getenv("SQTT", 0):
+ from tinygrad.uop.ops import launch_viz
+ launch_viz(PROFILE, fn)

  if __name__ == "__main__":
  for device in ALL_DEVICES:
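The hunks above rework the Buffer API: lb_refcount becomes uop_refcount, allocation state is split into is_allocated() (follows the base buffer) and is_initialized() (this buffer or view has its own _buf), and as_typed_buffer(), numpy(), and MultiBuffer are new. A minimal usage sketch, not part of the diff, assuming a tinygrad 0.11.0 install with a working CPU backend whose allocator supports _offset, and NumPy available for numpy(); exact runtime behaviour is not guaranteed by the diff alone:

    import struct
    from tinygrad.device import Buffer, MultiBuffer
    from tinygrad.dtype import dtypes

    # base buffer: passing initial_value allocates it and copies the bytes in
    buf = Buffer("CPU", 4, dtypes.int32, initial_value=struct.pack("4i", 1, 2, 3, 4))
    print(buf.as_typed_buffer().tolist())   # [1, 2, 3, 4], cast through dtype.base.fmt
    print(buf.numpy())                      # same data via _to_np_dtype

    # a view shares the base allocation: it counts as allocated once its base is,
    # but is not initialized until its own _buf is materialized via allocator._offset
    v = buf.view(2, dtypes.int32, offset=8)
    print(v.is_allocated(), v.is_initialized())   # True False
    v.ensure_allocated()
    print(v.as_typed_buffer().tolist())           # [3, 4]

    # MultiBuffer fans one logical buffer out across several devices
    mb = MultiBuffer(("CPU", "CPU"), 16, dtypes.float32).ref(1)
    print(mb.is_allocated())                      # False until each per-device Buffer is allocated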